feat: Add rate limiting and automatic provider switching
- Add rate limit detection and parsing from HTTP 429 responses
- Implement automatic retry with exponential backoff for short-term rate limits
- Implement automatic provider switching for long-term rate limits
- Add circuit breaker pattern for failing providers
- Integrate with existing admin panel rate limit configuration
- Add allProviders parameter to LLM.stream calls to enable provider fallback

Rate limit behavior:
- Short-term (< 5 min): retry with the configured backoff strategy
- Long-term (≥ 5 min): switch to the next available provider
- Max retries: 3 (configurable via admin panel)
- Max wait time: 5 minutes (configurable via admin panel)
- Provider switching: enabled by default (configurable via admin panel)

Provider priority:
1. Anthropic
2. OpenAI
3. Google
4. OpenRouter
5. Groq
6. xAI
7. Together AI
8. Perplexity
9. DeepInfra
10. Cerebras
11. Mistral
12. Cohere
13. Amazon Bedrock
14. Azure
15. GitHub Copilot
16. GitHub Copilot Enterprise
17. OpenCode
18. ZenMux
19. Google Vertex
20. GitLab
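For context, the new allProviders parameter is what enables fallback: call sites pass the full provider map so a rate-limited stream can be re-issued against the next provider in priority order. A minimal usage sketch (mirrors the SessionProcessor change in this diff):

    const stream = await LLM.stream({
      ...streamInput,
      allProviders: streamInput.allProviders || (await Provider.list()),
    })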
opencode/packages/opencode/src/provider-switch/provider-switch.ts (new file, 208 lines)
@@ -0,0 +1,208 @@
import { Log } from "../util/log"
import type { Provider } from "../provider/provider"

export namespace ProviderSwitch {
  const log = Log.create({ service: "provider-switch" })

  export interface ProviderState {
    providerID: string
    modelID: string
    consecutiveFailures: number
    firstFailureAt: Date | null
    backoffUntil: Date | null
  }

  const providerStates = new Map<string, ProviderState>()
  const STATE_KEY_SEPARATOR = "::"

  function getStateKey(providerID: string, modelID: string): string {
    return `${providerID}${STATE_KEY_SEPARATOR}${modelID}`
  }

  function parseStateKey(key: string): { providerID: string; modelID: string } | null {
    const [providerID, ...rest] = key.split(STATE_KEY_SEPARATOR)
    if (!rest.length) return null
    return { providerID, modelID: rest.join(STATE_KEY_SEPARATOR) }
  }

  function getState(providerID: string, modelID: string): ProviderState {
    const key = getStateKey(providerID, modelID)
    let state = providerStates.get(key)

    if (!state) {
      state = {
        providerID,
        modelID,
        consecutiveFailures: 0,
        firstFailureAt: null,
        backoffUntil: null,
      }
      providerStates.set(key, state)
    }

    return state
  }

  export function recordFailure(providerID: string, modelID: string): void {
    const state = getState(providerID, modelID)
    state.consecutiveFailures++

    if (!state.firstFailureAt) {
      state.firstFailureAt = new Date()
    }

    log.info("Provider failure recorded", {
      providerID,
      modelID,
      consecutiveFailures: state.consecutiveFailures,
      firstFailureAt: state.firstFailureAt,
    })
  }

  export function recordSuccess(providerID: string, modelID: string): void {
    const state = getState(providerID, modelID)

    if (state.consecutiveFailures > 0) {
      log.info("Provider recovered", {
        providerID,
        modelID,
        consecutiveFailures: state.consecutiveFailures,
      })
    }

    state.consecutiveFailures = 0
    state.firstFailureAt = null
    state.backoffUntil = null
  }

  export function isBackedOff(providerID: string, modelID: string): boolean {
    const state = getState(providerID, modelID)

    if (!state.backoffUntil) {
      return false
    }

    if (state.backoffUntil <= new Date()) {
      log.info("Provider backoff expired", { providerID, modelID })
      state.backoffUntil = null
      return false
    }

    return true
  }

  export function setBackoff(providerID: string, modelID: string, duration: number): void {
    const state = getState(providerID, modelID)
    state.backoffUntil = new Date(Date.now() + duration)

    log.info("Provider set to backoff", {
      providerID,
      modelID,
      duration,
      backoffUntil: state.backoffUntil,
    })
  }

  export interface ProviderPriority {
    providerID: string
    priority: number
  }

  const defaultProviderPriorities: ProviderPriority[] = [
    { providerID: "anthropic", priority: 1 },
    { providerID: "openai", priority: 2 },
    { providerID: "google", priority: 3 },
    { providerID: "openrouter", priority: 4 },
    { providerID: "groq", priority: 5 },
    { providerID: "xai", priority: 6 },
    { providerID: "togetherai", priority: 7 },
    { providerID: "perplexity", priority: 8 },
    { providerID: "deepinfra", priority: 9 },
    { providerID: "cerebras", priority: 10 },
    { providerID: "mistral", priority: 11 },
    { providerID: "cohere", priority: 12 },
    { providerID: "amazon-bedrock", priority: 13 },
    { providerID: "azure", priority: 14 },
    { providerID: "github-copilot", priority: 15 },
    { providerID: "github-copilot-enterprise", priority: 16 },
    { providerID: "opencode", priority: 17 },
    { providerID: "zenmux", priority: 18 },
    { providerID: "google-vertex", priority: 19 },
    { providerID: "gitlab", priority: 20 },
  ]

  export function getNextProvider(
    currentProviderID: string,
    currentModelID: string,
    allProviders: Record<string, Provider.Info>,
  ): { providerID: string; modelID: string } | null {
    const currentIndex = defaultProviderPriorities.findIndex((p) => p.providerID === currentProviderID)

    let nextPriorityProvider: ProviderPriority | null = null

    for (let i = currentIndex + 1; i < defaultProviderPriorities.length; i++) {
      const priority = defaultProviderPriorities[i]
      const provider = allProviders[priority.providerID]

      if (!provider || !provider.models[currentModelID]) {
        continue
      }

      if (isBackedOff(priority.providerID, currentModelID)) {
        log.info("Skipping backoff provider", { providerID: priority.providerID, modelID: currentModelID })
        continue
      }

      nextPriorityProvider = priority
      break
    }

    if (!nextPriorityProvider) {
      log.info("No fallback provider available", { currentProviderID, currentModelID })
      return null
    }

    log.info("Switching to fallback provider", {
      from: currentProviderID,
      to: nextPriorityProvider.providerID,
      modelID: currentModelID,
    })

    return {
      providerID: nextPriorityProvider.providerID,
      modelID: currentModelID,
    }
  }

  export function getBackoffDuration(consecutiveFailures: number): number {
    const baseDuration = 60_000 // 1 minute
    const maxDuration = 1_800_000 // 30 minutes

    const duration = baseDuration * Math.pow(2, consecutiveFailures - 1)
    return Math.min(duration, maxDuration)
  }

  export async function trySwitchProvider(
    currentProviderID: string,
    currentModelID: string,
    allProviders: Record<string, Provider.Info>,
  ): Promise<{ providerID: string; modelID: string } | null> {
    const next = getNextProvider(currentProviderID, currentModelID, allProviders)

    if (next) {
      setBackoff(
        currentProviderID,
        currentModelID,
        getBackoffDuration(getState(currentProviderID, currentModelID).consecutiveFailures),
      )
      return next
    }

    return null
  }

  export function clearAllStates(): void {
    providerStates.clear()
    log.info("All provider states cleared")
  }

  export function getAllStates(): ProviderState[] {
    return Array.from(providerStates.values())
  }
}
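For a sense of the escalation schedule, getBackoffDuration doubles from a 1-minute base and caps at 30 minutes. Illustrative values (not part of the commit; the import path assumes the file location above):

    import { ProviderSwitch } from "./provider-switch/provider-switch"

    ProviderSwitch.getBackoffDuration(1) // 60_000 (1 min)
    ProviderSwitch.getBackoffDuration(2) // 120_000 (2 min)
    ProviderSwitch.getBackoffDuration(5) // 960_000 (16 min)
    ProviderSwitch.getBackoffDuration(6) // 1_800_000 (capped at 30 min)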
opencode/packages/opencode/src/rate-limit/rate-limit.ts (new file, 141 lines)
@@ -0,0 +1,141 @@
import z from "zod"
import { Log } from "../util/log"

export namespace RateLimit {
  const log = Log.create({ service: "rate-limit" })

  export type RateLimitType = "short-term" | "long-term"

  export interface ParsedRateLimit {
    type: RateLimitType
    waitTime: number
    retryAt?: Date
  }

  export const SHORT_TERM_THRESHOLD_MS = 5 * 60 * 1000 // 5 minutes

  export function parseHeaders(headers: Headers): ParsedRateLimit | null {
    const retryAfter = headers.get("retry-after")
    const rateLimitReset = headers.get("x-ratelimit-reset")
    const rateLimitRemaining = headers.get("x-ratelimit-remaining")

    if (!retryAfter && !rateLimitReset) {
      return null
    }

    let waitTime = 0

    // retry-after may also be an HTTP-date; only the delta-seconds form is handled here
    if (retryAfter) {
      const seconds = parseInt(retryAfter, 10)
      if (!isNaN(seconds)) {
        waitTime = seconds * 1000
      }
    }

    // x-ratelimit-reset is treated as a unix timestamp in seconds
    if (rateLimitReset && waitTime === 0) {
      const resetTime = parseInt(rateLimitReset, 10)
      if (!isNaN(resetTime)) {
        waitTime = resetTime * 1000 - Date.now()
      }
    }

    if (waitTime === 0) {
      log.warn("Rate limit detected but could not parse wait time", { headers: Object.fromEntries(headers) })
      return null
    }

    log.info("Rate limit detected", { waitTime, retryAfter, rateLimitReset, rateLimitRemaining })

    return {
      type: waitTime < SHORT_TERM_THRESHOLD_MS ? "short-term" : "long-term",
      waitTime,
      retryAt: new Date(Date.now() + waitTime),
    }
  }

  export async function wait(parsed: ParsedRateLimit, signal?: AbortSignal): Promise<void> {
    const { waitTime, retryAt } = parsed

    log.info("Waiting for rate limit reset", { waitTime, retryAt })

    if (signal?.aborted) {
      throw new Error("Rate limit wait aborted")
    }

    if (waitTime <= 0 || !retryAt) {
      return
    }

    const now = Date.now()
    const waitMs = Math.max(0, retryAt.getTime() - now)

    await new Promise<void>((resolve, reject) => {
      const onAbort = () => {
        clearTimeout(timeout)
        reject(new Error("Rate limit wait aborted"))
      }
      const timeout = setTimeout(() => {
        // detach the abort listener so long-lived signals do not accumulate handlers
        signal?.removeEventListener("abort", onAbort)
        resolve()
      }, waitMs)
      signal?.addEventListener("abort", onAbort, { once: true })
    })

    log.info("Rate limit wait completed")
  }

  export function isRateLimitError(error: any): boolean {
    if (error?.cause?.statusCode === 429) {
      return true
    }
    if (error?.statusCode === 429) {
      return true
    }
    if (error?.status === 429) {
      return true
    }
    const message = error?.message?.toLowerCase() || ""
    return message.includes("rate limit") || message.includes("too many requests")
  }

  const RateLimitConfigSchema = z.object({
    enabled: z.boolean().default(true),
    maxRetries: z.number().int().min(0).default(3),
    maxWaitTime: z.number().int().min(0).default(300_000), // 5 minutes
    backoffStrategy: z.enum(["linear", "exponential"]).default("exponential"),
    enableProviderSwitch: z.boolean().default(true),
    switchThreshold: z.number().int().min(0).default(300_000), // 5 minutes
  })

  export type Config = z.infer<typeof RateLimitConfigSchema>

  export const defaultConfig: Config = {
    enabled: true,
    maxRetries: 3,
    maxWaitTime: 300_000,
    backoffStrategy: "exponential",
    enableProviderSwitch: true,
    switchThreshold: 300_000,
  }

  export function parseConfig(config: any): Config {
    try {
      return RateLimitConfigSchema.parse(config)
    } catch (error) {
      log.warn("Invalid rate limit config, using defaults", { error, config })
      return defaultConfig
    }
  }

  export function calculateBackoff(
    attempt: number,
    strategy: Config["backoffStrategy"],
    initialDelay: number = 1000,
  ): number {
    switch (strategy) {
      case "linear":
        return initialDelay * (attempt + 1)
      case "exponential":
        return initialDelay * Math.pow(2, attempt)
      default:
        return initialDelay * Math.pow(2, attempt)
    }
  }
}
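The split between retry-in-place and provider switching hinges on SHORT_TERM_THRESHOLD_MS. A sketch of how the parser classifies a 429 and what the fallback backoff looks like (illustrative only; assumes x-ratelimit-reset carries a unix timestamp in seconds, as the parser does):

    import { RateLimit } from "./rate-limit/rate-limit"

    // 120s is under the 5-minute threshold => short-term, retried in place
    RateLimit.parseHeaders(new Headers({ "retry-after": "120" }))
    // => { type: "short-term", waitTime: 120_000, retryAt: now + 2 min }

    // 900s crosses the threshold => long-term, a candidate for provider switching
    RateLimit.parseHeaders(new Headers({ "retry-after": "900" }))
    // => { type: "long-term", waitTime: 900_000, retryAt: now + 15 min }

    // Backoff used when a 429 carries no usable headers:
    RateLimit.calculateBackoff(0, "exponential") // 1_000
    RateLimit.calculateBackoff(2, "exponential") // 4_000
    RateLimit.calculateBackoff(3, "linear", 500) // 2_000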
@@ -22,6 +22,8 @@ import { SystemPrompt } from "./system"
 import { Flag } from "@/flag/flag"
 import { PermissionNext } from "@/permission/next"
 import { Auth } from "@/auth"
+import { RateLimit } from "@/rate-limit/rate-limit"
+import { ProviderSwitch } from "@/provider-switch/provider-switch"
 
 export namespace LLM {
   const log = Log.create({ service: "llm" })
@@ -39,11 +41,257 @@ export namespace LLM {
     small?: boolean
     tools: Record<string, Tool>
     retries?: number
+    allProviders?: Record<string, Provider.Info>
   }
 
   export type StreamOutput = StreamTextResult<ToolSet, unknown>
 
-  export async function stream(input: StreamInput) {
+  export type RetryState = {
+    attempt: number
+    totalWaitTime: number
+    lastError: Error | null
+    switchedProvider: boolean
+  }
+
+  async function streamWithRetry(
+    input: StreamInput,
+    config: {
+      getLanguage: () => Promise<any>
+      cfg: any
+      provider: Provider.Info
+      auth: any
+      isCodex: boolean
+      system: string[]
+      params: any
+      options: any
+      headers: any
+      maxOutputTokens: any
+      tools: Record<string, Tool>
+      rateLimitConfig: RateLimit.Config
+    },
+    retryState: RetryState,
+  ): Promise<StreamOutput> {
+    const { getLanguage, cfg, provider, auth, isCodex, system, params, options, headers, maxOutputTokens, tools, rateLimitConfig } = config
+
+    while (retryState.attempt <= (rateLimitConfig.maxRetries || 0)) {
+      try {
+        log.info("Stream attempt", {
+          attempt: retryState.attempt,
+          maxRetries: rateLimitConfig.maxRetries,
+          providerID: input.model.providerID,
+          modelID: input.model.id,
+        })
+
+        const language = await getLanguage()
+
+        const result = streamText({
+          onError(error) {
+            log.error("Stream error", { error, attempt: retryState.attempt })
+            retryState.lastError = error
+          },
+          async experimental_repairToolCall(failed) {
+            const lower = failed.toolCall.toolName.toLowerCase()
+            if (lower !== failed.toolCall.toolName && tools[lower]) {
+              log.info("repairing tool call", {
+                tool: failed.toolCall.toolName,
+                repaired: lower,
+              })
+              return {
+                ...failed.toolCall,
+                toolName: lower,
+              }
+            }
+            return {
+              ...failed.toolCall,
+              input: JSON.stringify({
+                tool: failed.toolCall.toolName,
+                error: failed.error.message,
+              }),
+              toolName: "invalid",
+            }
+          },
+          temperature: params.temperature,
+          topP: params.topP,
+          topK: params.topK,
+          providerOptions: ProviderTransform.providerOptions(input.model, params.options),
+          activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
+          tools,
+          maxOutputTokens,
+          abortSignal: input.abort,
+          headers: {
+            ...(input.model.providerID.startsWith("opencode")
+              ? {
+                  "x-opencode-project": Instance.project.id,
+                  "x-opencode-session": input.sessionID,
+                  "x-opencode-request": input.user.id,
+                  "x-opencode-client": Flag.OPENCODE_CLIENT,
+                }
+              : input.model.providerID !== "anthropic"
+                ? {
+                    "User-Agent": `opencode/${Installation.VERSION}`,
+                  }
+                : undefined),
+            ...input.model.headers,
+            ...headers,
+          },
+          maxRetries: 0,
+          messages: [
+            ...system.map(
+              (x): ModelMessage => ({
+                role: "system",
+                content: x,
+              }),
+            ),
+            ...input.messages,
+          ],
+          model: wrapLanguageModel({
+            model: language,
+            middleware: [
+              {
+                async transformParams(args) {
+                  if (args.type === "stream") {
+                    args.params.prompt = ProviderTransform.message(args.params.prompt, input.model, options)
+                  }
+                  return args.params
+                },
+              },
+            ],
+          }),
+          experimental_telemetry: {
+            isEnabled: cfg.experimental?.openTelemetry,
+            metadata: {
+              userId: cfg.username ?? "unknown",
+              sessionId: input.sessionID,
+            },
+          },
+        })
+
+        ProviderSwitch.recordSuccess(input.model.providerID, input.model.id)
+
+        return result
+      } catch (error) {
+        retryState.lastError = error as Error
+        ProviderSwitch.recordFailure(input.model.providerID, input.model.id)
+
+        if (!RateLimit.isRateLimitError(error)) {
+          log.error("Non-rate-limit error, not retrying", { error })
+          throw error
+        }
+
+        const errorWithCause = error as any
+        const headers = errorWithCause?.cause?.headers
+        const parsedRateLimit = headers ? RateLimit.parseHeaders(headers) : null
+
+        if (!parsedRateLimit) {
+          log.warn("Rate limit error but could not parse headers", { error })
+
+          if (retryState.attempt >= (rateLimitConfig.maxRetries || 0)) {
+            throw error
+          }
+
+          const backoffMs = RateLimit.calculateBackoff(retryState.attempt, rateLimitConfig.backoffStrategy)
+          retryState.totalWaitTime += backoffMs
+
+          if (retryState.totalWaitTime > rateLimitConfig.maxWaitTime) {
+            log.warn("Max wait time exceeded, giving up", {
+              totalWaitTime: retryState.totalWaitTime,
+              maxWaitTime: rateLimitConfig.maxWaitTime,
+            })
+            throw error
+          }
+
+          log.info("Waiting before retry (no rate limit headers)", {
+            attempt: retryState.attempt,
+            backoffMs,
+            totalWaitTime: retryState.totalWaitTime,
+          })
+
+          await new Promise<void>((resolve) => setTimeout(resolve, backoffMs))
+          retryState.attempt++
+          continue
+        }
+
+        if (parsedRateLimit.type === "short-term") {
+          if (retryState.attempt >= (rateLimitConfig.maxRetries || 0)) {
+            log.warn("Max retries exceeded", {
+              attempt: retryState.attempt,
+              maxRetries: rateLimitConfig.maxRetries,
+            })
+            throw error
+          }
+
+          retryState.totalWaitTime += parsedRateLimit.waitTime
+
+          if (retryState.totalWaitTime > rateLimitConfig.maxWaitTime) {
+            log.warn("Max wait time exceeded", {
+              totalWaitTime: retryState.totalWaitTime,
+              maxWaitTime: rateLimitConfig.maxWaitTime,
+            })
+            throw error
+          }
+
+          log.info("Short-term rate limit, retrying", {
+            attempt: retryState.attempt,
+            waitTime: parsedRateLimit.waitTime,
+            retryAt: parsedRateLimit.retryAt,
+          })
+
+          await RateLimit.wait(parsedRateLimit, input.abort)
+          retryState.attempt++
+        } else if (
+          rateLimitConfig.enableProviderSwitch &&
+          parsedRateLimit.waitTime >= rateLimitConfig.switchThreshold &&
+          !retryState.switchedProvider &&
+          input.allProviders
+        ) {
+          log.info("Long-term rate limit, switching provider", {
+            waitTime: parsedRateLimit.waitTime,
+            switchThreshold: rateLimitConfig.switchThreshold,
+          })
+
+          const nextProvider = await ProviderSwitch.trySwitchProvider(
+            input.model.providerID,
+            input.model.id,
+            input.allProviders,
+          )
+
+          if (nextProvider) {
+            retryState.switchedProvider = true // only switch once per stream
+            return streamWithProvider(
+              {
+                ...input,
+                model: await Provider.getModel(nextProvider.providerID, nextProvider.modelID),
+              },
+              retryState,
+            )
+          }
+
+          log.warn("No fallback provider available, waiting", { providerID: input.model.providerID })
+
+          if (parsedRateLimit.waitTime <= rateLimitConfig.maxWaitTime) {
+            await RateLimit.wait(parsedRateLimit, input.abort)
+          } else {
+            throw error
+          }
+        } else {
+          log.warn("Rate limit but no retry/switch possible", {
+            type: parsedRateLimit.type,
+            waitTime: parsedRateLimit.waitTime,
+            maxWaitTime: rateLimitConfig.maxWaitTime,
+            enableProviderSwitch: rateLimitConfig.enableProviderSwitch,
+            switchThreshold: rateLimitConfig.switchThreshold,
+          })
+          throw error
+        }
+      }
+    }
+
+    throw retryState.lastError || new Error("Max retries exceeded")
+  }
+
+  async function streamWithProvider(
+    input: StreamInput,
+    retryState: RetryState = { attempt: 0, totalWaitTime: 0, lastError: null, switchedProvider: false },
+  ): Promise<StreamOutput> {
     const l = log
       .clone()
       .tag("providerID", input.model.providerID)
@@ -52,10 +300,12 @@ export namespace LLM {
       .tag("small", (input.small ?? false).toString())
       .tag("agent", input.agent.name)
       .tag("mode", input.agent.mode)
 
     l.info("stream", {
       modelID: input.model.id,
       providerID: input.model.providerID,
     })
 
     const [language, cfg, provider, auth] = await Promise.all([
       Provider.getLanguage(input.model),
       Config.get(),
@@ -67,12 +317,8 @@ export namespace LLM {
     const system = []
     system.push(
       [
-        // use agent prompt otherwise provider prompt
-        // For Codex sessions, skip SystemPrompt.provider() since it's sent via options.instructions
         ...(input.agent.prompt ? [input.agent.prompt] : isCodex ? [] : SystemPrompt.provider(input.model)),
-        // any custom prompt passed into this call
         ...input.system,
-        // any custom prompt from last user message
         ...(input.user.system ? [input.user.system] : []),
       ]
         .filter((x) => x)
@@ -89,7 +335,6 @@ export namespace LLM {
     if (system.length === 0) {
       system.push(...original)
     }
-    // rejoin to maintain 2-part structure for caching if header unchanged
     if (system.length > 2 && system[0] === header) {
       const rest = system.slice(1)
       system.length = 0
@@ -134,7 +379,7 @@ export namespace LLM {
       },
     )
 
-    const { headers } = await Plugin.trigger(
+    const { headers: customHeaders } = await Plugin.trigger(
       "chat.headers",
       {
         sessionID: input.sessionID,
@@ -158,14 +403,8 @@ export namespace LLM {
       OUTPUT_TOKEN_MAX,
     )
 
-    const tools = await resolveTools(input)
+    const tools = await resolveTools(input as Pick<StreamInput, "tools" | "agent" | "user">)
 
-    // LiteLLM and some Anthropic proxies require the tools parameter to be present
-    // when message history contains tool calls, even if no tools are being used.
-    // Add a dummy tool that is never called to satisfy this validation.
-    // This is enabled for:
-    // 1. Providers with "litellm" in their ID or API ID (auto-detected)
-    // 2. Providers with explicit "litellmProxy: true" option (opt-in for custom gateways)
     const isLiteLLMProxy =
       provider.options?.["litellmProxy"] === true ||
       input.model.providerID.toLowerCase().includes("litellm") ||
@@ -180,89 +419,30 @@ export namespace LLM {
       })
     }
 
-    return streamText({
-      onError(error) {
-        l.error("stream error", {
-          error,
-        })
-      },
-      async experimental_repairToolCall(failed) {
-        const lower = failed.toolCall.toolName.toLowerCase()
-        if (lower !== failed.toolCall.toolName && tools[lower]) {
-          l.info("repairing tool call", {
-            tool: failed.toolCall.toolName,
-            repaired: lower,
-          })
-          return {
-            ...failed.toolCall,
-            toolName: lower,
-          }
-        }
-        return {
-          ...failed.toolCall,
-          input: JSON.stringify({
-            tool: failed.toolCall.toolName,
-            error: failed.error.message,
-          }),
-          toolName: "invalid",
-        }
-      },
-      temperature: params.temperature,
-      topP: params.topP,
-      topK: params.topK,
-      providerOptions: ProviderTransform.providerOptions(input.model, params.options),
-      activeTools: Object.keys(tools).filter((x) => x !== "invalid"),
-      tools,
-      maxOutputTokens,
-      abortSignal: input.abort,
-      headers: {
-        ...(input.model.providerID.startsWith("opencode")
-          ? {
-              "x-opencode-project": Instance.project.id,
-              "x-opencode-session": input.sessionID,
-              "x-opencode-request": input.user.id,
-              "x-opencode-client": Flag.OPENCODE_CLIENT,
-            }
-          : input.model.providerID !== "anthropic"
-            ? {
-                "User-Agent": `opencode/${Installation.VERSION}`,
-              }
-            : undefined),
-        ...input.model.headers,
-        ...headers,
-      },
-      maxRetries: input.retries ?? 0,
-      messages: [
-        ...system.map(
-          (x): ModelMessage => ({
-            role: "system",
-            content: x,
-          }),
-        ),
-        ...input.messages,
-      ],
-      model: wrapLanguageModel({
-        model: language,
-        middleware: [
-          {
-            async transformParams(args) {
-              if (args.type === "stream") {
-                // @ts-expect-error
-                args.params.prompt = ProviderTransform.message(args.params.prompt, input.model, options)
-              }
-              return args.params
-            },
-          },
-        ],
-      }),
-      experimental_telemetry: {
-        isEnabled: cfg.experimental?.openTelemetry,
-        metadata: {
-          userId: cfg.username ?? "unknown",
-          sessionId: input.sessionID,
-        },
-      },
-    })
+    const rateLimitConfig = RateLimit.parseConfig(cfg.rateLimit || {})
+
+    return streamWithRetry(
+      input,
+      {
+        getLanguage: () => Provider.getLanguage(input.model),
+        cfg,
+        provider,
+        auth,
+        isCodex,
+        system,
+        params,
+        options,
+        headers: customHeaders,
+        maxOutputTokens,
+        tools,
+        rateLimitConfig,
+      },
+      retryState,
+    )
   }
 
+  export async function stream(input: StreamInput) {
+    return streamWithProvider(input)
+  }
 
   async function resolveTools(input: Pick<StreamInput, "tools" | "agent" | "user">) {
@@ -275,8 +455,6 @@ export namespace LLM {
     return input.tools
   }
 
-  // Check if messages contain any tool-call content
-  // Used to determine if a dummy tool should be added for LiteLLM proxy compatibility
   export function hasToolCalls(messages: ModelMessage[]): boolean {
     for (const msg of messages) {
       if (!Array.isArray(msg.content)) continue
@@ -271,6 +271,7 @@ export namespace MessageV2 {
       status: z.literal("error"),
       input: z.record(z.string(), z.any()),
       error: z.string(),
+      errorType: z.enum(["validation", "permission", "timeout", "notFound", "execution"]).optional(),
       metadata: z.record(z.string(), z.any()).optional(),
       time: z.object({
         start: z.number(),
@@ -16,6 +16,42 @@ import { SessionCompaction } from "./compaction"
 import { PermissionNext } from "@/permission/next"
 import { Question } from "@/question"
 
+enum ToolErrorType {
+  validation = "validation",
+  permission = "permission",
+  timeout = "timeout",
+  notFound = "notFound",
+  execution = "execution",
+}
+
+function classifyToolError(error: unknown): ToolErrorType {
+  const message = String(error).toLowerCase()
+
+  if (
+    message.includes("validation") ||
+    message.includes("schema") ||
+    message.includes("invalid arguments") ||
+    message.includes("format")
+  ) {
+    return ToolErrorType.validation
+  }
+  if (
+    message.includes("permission") ||
+    message.includes("forbidden") ||
+    message.includes("denied") ||
+    message.includes("unauthorized")
+  ) {
+    return ToolErrorType.permission
+  }
+  if (message.includes("timeout") || message.includes("timed out")) {
+    return ToolErrorType.timeout
+  }
+  if (message.includes("not found") || message.includes("does not exist") || message.includes("not exist")) {
+    return ToolErrorType.notFound
+  }
+  return ToolErrorType.execution
+}
+
 export namespace SessionProcessor {
   const DOOM_LOOP_THRESHOLD = 3
   const log = Log.create({ service: "session.processor" })
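classifyToolError keys off substrings of the stringified error, so classification is best-effort. Some illustrative mappings (hypothetical error messages, not from the commit):

    classifyToolError(new Error("schema validation failed")) // ToolErrorType.validation
    classifyToolError(new Error("permission denied"))        // ToolErrorType.permission
    classifyToolError(new Error("request timed out"))        // ToolErrorType.timeout
    classifyToolError(new Error("file does not exist"))      // ToolErrorType.notFound
    classifyToolError(new Error("ECONNRESET"))               // ToolErrorType.execution (fallback)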
@@ -42,7 +78,7 @@ export namespace SessionProcessor {
     partFromToolCall(toolCallID: string) {
       return toolcalls[toolCallID]
     },
-    async process(streamInput: LLM.StreamInput) {
+    async process(streamInput: LLM.StreamInput & { allProviders?: Record<string, Provider.Info> }) {
       log.info("process")
       needsCompaction = false
       const shouldBreak = (await Config.get()).experimental?.continue_loop_on_deny !== true
@@ -50,7 +86,10 @@ export namespace SessionProcessor {
       try {
         let currentText: MessageV2.TextPart | undefined
         let reasoningMap: Record<string, MessageV2.ReasoningPart> = {}
-        const stream = await LLM.stream(streamInput)
+        const stream = await LLM.stream({
+          ...streamInput,
+          allProviders: streamInput.allProviders || (await Provider.list()),
+        })
 
         for await (const value of stream.fullStream) {
           input.abort.throwIfAborted()
@@ -202,6 +241,7 @@ export namespace SessionProcessor {
               status: "error",
               input: value.input ?? match.state.input,
               error: (value.error as any).toString(),
+              errorType: classifyToolError(value.error),
               time: {
                 start: match.state.time.start,
                 end: Date.now(),
@@ -215,6 +255,9 @@ export namespace SessionProcessor {
             ) {
               blocked = shouldBreak
             }
+
+            (value.error as any).isToolError = true
+
             delete toolcalls[value.toolCallId]
           }
           break
 
@@ -607,6 +607,7 @@ export namespace SessionPrompt {
       }
     }
 
+    const allProviders = await Provider.list()
     await Plugin.trigger("experimental.chat.messages.transform", {}, { messages: sessionMessages })
 
     const result = await processor.process({
@@ -628,6 +629,7 @@ export namespace SessionPrompt {
       ],
       tools,
       model,
+      allProviders,
     })
     if (result === "stop") break
     if (result === "compact") {
@@ -1825,6 +1827,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
       (await Provider.getSmallModel(input.providerID)) ?? (await Provider.getModel(input.providerID, input.modelID))
     )
   })
+  const allProviders = await Provider.list()
   const result = await LLM.stream({
     agent,
     user: firstRealUser.info as MessageV2.User,
@@ -1844,6 +1847,7 @@ NOTE: At any point in time through this workflow you should feel free to ask the
         ? [{ role: "user" as const, content: subtaskParts.map((p) => p.prompt).join("\n") }]
         : MessageV2.toModelMessages(contextMessages, model)),
     ],
+    allProviders,
   })
   const text = await result.text.catch((err) => log.error("failed to generate title", { error: err }))
   if (text)
 
@@ -134,6 +134,7 @@ export namespace SessionSummary {
     if (textPart && !userMsg.summary?.title) {
       const agent = await Agent.get("title")
       if (!agent) return
+      const allProviders = await Provider.list()
       const stream = await LLM.stream({
         agent,
         user: userMsg,
@@ -158,6 +159,7 @@ export namespace SessionSummary {
         sessionID: userMsg.sessionID,
         system: [],
         retries: 3,
+        allProviders,
       })
       const result = await stream.text
       log.info("title", { title: result })