From 3c3000c25ff70474306802c5af5ae568faf51230 Mon Sep 17 00:00:00 2001 From: Mikael Hugo Date: Sat, 2 May 2026 13:08:41 +0200 Subject: [PATCH] fix(auth): use gemini cli credentials outside sf store --- .../src/core/model-registry-auth-mode.test.ts | 409 ++++++++++++----- .../src/core/model-registry.ts | 433 +++++++++++++----- .../src/core/retry-handler.test.ts | 302 +++++++++--- .../pi-coding-agent/src/core/retry-handler.ts | 123 +++-- 4 files changed, 959 insertions(+), 308 deletions(-) diff --git a/packages/pi-coding-agent/src/core/model-registry-auth-mode.test.ts b/packages/pi-coding-agent/src/core/model-registry-auth-mode.test.ts index 317f516a8..3aa7419d3 100644 --- a/packages/pi-coding-agent/src/core/model-registry-auth-mode.test.ts +++ b/packages/pi-coding-agent/src/core/model-registry-auth-mode.test.ts @@ -1,11 +1,19 @@ import assert from "node:assert/strict"; -import { describe, it } from 'vitest'; -import type { Api, Model, SimpleStreamOptions, Context, AssistantMessageEventStream } from "@singularity-forge/pi-ai"; +import type { + Api, + AssistantMessageEventStream, + Context, + Model, + SimpleStreamOptions, +} from "@singularity-forge/pi-ai"; import { getApiProvider } from "@singularity-forge/pi-ai"; +import { describe, it } from "vitest"; import { AuthStorage, type AuthStorageData } from "./auth-storage.js"; import { ModelRegistry } from "./model-registry.js"; -function createRegistry(hasAuthFn?: (provider: string) => boolean): ModelRegistry { +function createRegistry( + hasAuthFn?: (provider: string) => boolean, +): ModelRegistry { const authStorage = { setFallbackResolver: () => {}, onCredentialChange: () => {}, @@ -22,7 +30,12 @@ function createInMemoryRegistry(data: AuthStorageData = {}): ModelRegistry { return new ModelRegistry(AuthStorage.inMemory(data), undefined); } -function createProviderModel(id: string, api?: string): NonNullable[1]["models"]>[number] { +function createProviderModel( + id: string, + api?: string, +): NonNullable< + Parameters[1]["models"] +>[number] { return { id, name: id, @@ -35,8 +48,14 @@ function createProviderModel(id: string, api?: string): NonNullable | undefined { - return registry.getAvailable().find((m) => m.provider === provider && m.id === id); +function findModel( + registry: ModelRegistry, + provider: string, + id: string, +): Model | undefined { + return registry + .getAvailable() + .find((m) => m.provider === provider && m.id === id); } function makeModel(provider: string, id: string, api: string): Model { @@ -62,10 +81,33 @@ function makeContext(): Context { } /** No-op streamSimple for tests that need one to pass validation but don't inspect it. */ -const noopStreamSimple = (_model: Model, _context: Context, _options?: SimpleStreamOptions) => { +const noopStreamSimple = ( + _model: Model, + _context: Context, + _options?: SimpleStreamOptions, +) => { return { - [Symbol.asyncIterator]() { return { next: async () => ({ value: undefined, done: true as const }) }; }, - result: () => Promise.resolve({ role: "assistant" as const, content: [], api: "test" as Api, provider: "test", model: "test", usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } }, stopReason: "stop" as const, timestamp: Date.now() }), + [Symbol.asyncIterator]() { + return { next: async () => ({ value: undefined, done: true as const }) }; + }, + result: () => + Promise.resolve({ + role: "assistant" as const, + content: [], + api: "test" as Api, + provider: "test", + model: "test", + usage: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + stopReason: "stop" as const, + timestamp: Date.now(), + }), push: () => {}, end: () => {}, } as unknown as AssistantMessageEventStream; @@ -73,16 +115,51 @@ const noopStreamSimple = (_model: Model, _context: Context, _options?: Simp /** Create a spy streamSimple that captures the options it receives and returns a stub stream. */ function createStreamSpy(): { - streamSimple: (model: Model, context: Context, options?: SimpleStreamOptions) => AssistantMessageEventStream; + streamSimple: ( + model: Model, + context: Context, + options?: SimpleStreamOptions, + ) => AssistantMessageEventStream; getCapturedOptions: () => SimpleStreamOptions | undefined; } { let capturedOptions: SimpleStreamOptions | undefined; - const streamSimple = (_model: Model, _context: Context, options?: SimpleStreamOptions) => { + const streamSimple = ( + _model: Model, + _context: Context, + options?: SimpleStreamOptions, + ) => { capturedOptions = options; // Return a minimal stub that satisfies AssistantMessageEventStream return { - [Symbol.asyncIterator]() { return { next: async () => ({ value: undefined, done: true as const }) }; }, - result: () => Promise.resolve({ role: "assistant" as const, content: [], api: "test" as Api, provider: "test", model: "test", usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } }, stopReason: "stop" as const, timestamp: Date.now() }), + [Symbol.asyncIterator]() { + return { + next: async () => ({ value: undefined, done: true as const }), + }; + }, + result: () => + Promise.resolve({ + role: "assistant" as const, + content: [], + api: "test" as Api, + provider: "test", + model: "test", + usage: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 0, + cost: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + total: 0, + }, + }, + stopReason: "stop" as const, + timestamp: Date.now(), + }), push: () => {}, end: () => {}, } as unknown as AssistantMessageEventStream; @@ -123,102 +200,153 @@ describe("ModelRegistry authMode — registration", () => { it("rejects apiKey provider without apiKey or oauth — message mentions authMode", () => { const registry = createRegistry(); - assert.throws(() => { - registry.registerProvider("apikey-provider", { - authMode: "apiKey", - baseUrl: "https://api.local", - api: "openai-completions", - models: [createProviderModel("model")], - }); - }, (err: Error) => { - assert.ok(err.message.includes("authMode"), "error message must mention authMode"); - assert.ok(err.message.includes("externalCli"), "error message must suggest externalCli"); - return true; - }); + assert.throws( + () => { + registry.registerProvider("apikey-provider", { + authMode: "apiKey", + baseUrl: "https://api.local", + api: "openai-completions", + models: [createProviderModel("model")], + }); + }, + (err: Error) => { + assert.ok( + err.message.includes("authMode"), + "error message must mention authMode", + ); + assert.ok( + err.message.includes("externalCli"), + "error message must suggest externalCli", + ); + return true; + }, + ); }); it("rejects provider with no authMode and no apiKey/oauth (defaults to apiKey)", () => { const registry = createRegistry(); - assert.throws(() => { - registry.registerProvider("bare-provider", { - baseUrl: "https://api.local", - api: "openai-completions", - models: [createProviderModel("model")], - }); - }, (err: Error) => { - assert.ok(err.message.includes("authMode"), "error message must mention authMode"); - return true; - }); + assert.throws( + () => { + registry.registerProvider("bare-provider", { + baseUrl: "https://api.local", + api: "openai-completions", + models: [createProviderModel("model")], + }); + }, + (err: Error) => { + assert.ok( + err.message.includes("authMode"), + "error message must mention authMode", + ); + return true; + }, + ); }); it("rejects externalCli provider without streamSimple", () => { const registry = createRegistry(); - assert.throws(() => { - registry.registerProvider("cli-no-stream", { - authMode: "externalCli", - baseUrl: "https://cli.local", - api: "openai-completions", - models: [createProviderModel("model")], - }); - }, (err: Error) => { - assert.ok(err.message.includes("streamSimple"), "error message must mention streamSimple"); - assert.ok(err.message.includes("externalCli"), "error message must mention authMode"); - return true; - }); + assert.throws( + () => { + registry.registerProvider("cli-no-stream", { + authMode: "externalCli", + baseUrl: "https://cli.local", + api: "openai-completions", + models: [createProviderModel("model")], + }); + }, + (err: Error) => { + assert.ok( + err.message.includes("streamSimple"), + "error message must mention streamSimple", + ); + assert.ok( + err.message.includes("externalCli"), + "error message must mention authMode", + ); + return true; + }, + ); }); it("rejects none provider without streamSimple", () => { const registry = createRegistry(); - assert.throws(() => { - registry.registerProvider("none-no-stream", { - authMode: "none", - baseUrl: "http://localhost:11434", - api: "openai-completions", - models: [createProviderModel("model")], - }); - }, (err: Error) => { - assert.ok(err.message.includes("streamSimple"), "error message must mention streamSimple"); - assert.ok(err.message.includes("none"), "error message must mention authMode"); - return true; - }); + assert.throws( + () => { + registry.registerProvider("none-no-stream", { + authMode: "none", + baseUrl: "http://localhost:11434", + api: "openai-completions", + models: [createProviderModel("model")], + }); + }, + (err: Error) => { + assert.ok( + err.message.includes("streamSimple"), + "error message must mention streamSimple", + ); + assert.ok( + err.message.includes("none"), + "error message must mention authMode", + ); + return true; + }, + ); }); it("rejects externalCli provider that also sets apiKey", () => { const registry = createRegistry(); const spy = createStreamSpy(); - assert.throws(() => { - registry.registerProvider("cli-with-key", { - authMode: "externalCli", - baseUrl: "https://cli.local", - api: "openai-completions", - apiKey: "SHOULD_NOT_EXIST", - streamSimple: spy.streamSimple, - models: [createProviderModel("model")], - }); - }, (err: Error) => { - assert.ok(err.message.includes("apiKey"), "error message must mention apiKey"); - assert.ok(err.message.includes("externalCli"), "error message must mention authMode"); - return true; - }); + assert.throws( + () => { + registry.registerProvider("cli-with-key", { + authMode: "externalCli", + baseUrl: "https://cli.local", + api: "openai-completions", + apiKey: "SHOULD_NOT_EXIST", + streamSimple: spy.streamSimple, + models: [createProviderModel("model")], + }); + }, + (err: Error) => { + assert.ok( + err.message.includes("apiKey"), + "error message must mention apiKey", + ); + assert.ok( + err.message.includes("externalCli"), + "error message must mention authMode", + ); + return true; + }, + ); }); it("rejects none provider that also sets apiKey", () => { const registry = createRegistry(); const spy = createStreamSpy(); - assert.throws(() => { - registry.registerProvider("none-with-key", { - authMode: "none", - baseUrl: "http://localhost:11434", - api: "openai-completions", - apiKey: "SHOULD_NOT_EXIST", - streamSimple: spy.streamSimple, - models: [createProviderModel("model")], - }); - }, (err: Error) => { - assert.ok(err.message.includes("apiKey"), "error message must mention apiKey"); - assert.ok(err.message.includes("none"), "error message must mention authMode"); - return true; - }); + assert.throws( + () => { + registry.registerProvider("none-with-key", { + authMode: "none", + baseUrl: "http://localhost:11434", + api: "openai-completions", + apiKey: "SHOULD_NOT_EXIST", + streamSimple: spy.streamSimple, + models: [createProviderModel("model")], + }); + }, + (err: Error) => { + assert.ok( + err.message.includes("apiKey"), + "error message must mention apiKey", + ); + assert.ok( + err.message.includes("none"), + "error message must mention authMode", + ); + return true; + }, + ); }); }); @@ -230,6 +358,14 @@ describe("ModelRegistry authMode — getProviderAuthMode", () => { assert.equal(registry.getProviderAuthMode("anthropic"), "apiKey"); }); + it("treats google-gemini-cli as external CLI auth", () => { + const registry = createRegistry(); + assert.equal( + registry.getProviderAuthMode("google-gemini-cli"), + "externalCli", + ); + }); + it("returns explicit authMode when set", () => { const registry = createRegistry(); registry.registerProvider("cli", { @@ -258,6 +394,11 @@ describe("ModelRegistry authMode — getProviderAuthMode", () => { // ─── isProviderRequestReady ─────────────────────────────────────────────────── describe("ModelRegistry authMode — isProviderRequestReady", () => { + it("returns true for google-gemini-cli without .sf stored auth", () => { + const registry = createRegistry(() => false); + assert.equal(registry.isProviderRequestReady("google-gemini-cli"), true); + }); + it("returns true for externalCli without stored auth", () => { const registry = createRegistry(() => false); registry.registerProvider("cli", { @@ -391,7 +532,10 @@ describe("ModelRegistry authMode — getAvailable", () => { it("excludes apiKey models without stored auth", () => { const registry = createRegistry(() => false); const available = registry.getAvailable(); - assert.equal(available.length, 0); + assert.equal( + available.filter((m) => m.provider !== "google-gemini-cli").length, + 0, + ); }); it("prunes Codex models removed from ChatGPT-backed openai-codex OAuth", () => { @@ -407,7 +551,10 @@ describe("ModelRegistry authMode — getAvailable", () => { assert.equal(registry.find("openai-codex", "gpt-5.1-codex-max"), undefined); assert.equal(registry.find("openai-codex", "gpt-5.1"), undefined); - assert.equal(findModel(registry, "openai-codex", "gpt-5.2-codex"), undefined); + assert.equal( + findModel(registry, "openai-codex", "gpt-5.2-codex"), + undefined, + ); assert.ok(registry.find("openai-codex", "gpt-5.4")); assert.ok(findModel(registry, "openai-codex", "gpt-5.4")); }); @@ -428,6 +575,22 @@ describe("ModelRegistry authMode — getAvailable", () => { // ─── getApiKey ──────────────────────────────────────────────────────────────── describe("ModelRegistry authMode — getApiKey", () => { + it("returns undefined for google-gemini-cli even when stale .sf auth exists", async () => { + const registry = createInMemoryRegistry({ + "google-gemini-cli": { + type: "oauth", + access: "", + refresh: "", + expires: 0, + }, + }); + + assert.equal( + await registry.getApiKeyForProvider("google-gemini-cli"), + undefined, + ); + }); + it("returns undefined for externalCli provider", async () => { const registry = createRegistry(); registry.registerProvider("cli", { @@ -480,15 +643,18 @@ describe("ModelRegistry authMode — streamSimple apiKey boundary", () => { const provider = getApiProvider(apiType as Api); assert.ok(provider, "provider must be registered in api registry"); - provider.streamSimple( - makeModel("cli-strip", "m", apiType), - makeContext(), - { apiKey: "should-be-stripped", maxTokens: 1024 } as SimpleStreamOptions, - ); + provider.streamSimple(makeModel("cli-strip", "m", apiType), makeContext(), { + apiKey: "should-be-stripped", + maxTokens: 1024, + } as SimpleStreamOptions); const captured = spy.getCapturedOptions(); assert.ok(captured, "streamSimple must have been called"); - assert.equal("apiKey" in captured, false, "apiKey must not exist in options for externalCli provider"); + assert.equal( + "apiKey" in captured, + false, + "apiKey must not exist in options for externalCli provider", + ); assert.equal(captured.maxTokens, 1024, "other options must pass through"); }); @@ -516,7 +682,11 @@ describe("ModelRegistry authMode — streamSimple apiKey boundary", () => { const captured = spy.getCapturedOptions(); assert.ok(captured, "streamSimple must have been called"); - assert.equal("apiKey" in captured, false, "apiKey must not exist in options for none provider"); + assert.equal( + "apiKey" in captured, + false, + "apiKey must not exist in options for none provider", + ); assert.equal(captured.maxTokens, 2048, "other options must pass through"); }); @@ -544,7 +714,11 @@ describe("ModelRegistry authMode — streamSimple apiKey boundary", () => { const captured = spy.getCapturedOptions(); assert.ok(captured, "streamSimple must have been called"); - assert.equal(captured.apiKey, "sk-real-key", "apiKey must be preserved for apiKey provider"); + assert.equal( + captured.apiKey, + "sk-real-key", + "apiKey must be preserved for apiKey provider", + ); assert.equal(captured.maxTokens, 4096, "other options must pass through"); }); @@ -572,7 +746,11 @@ describe("ModelRegistry authMode — streamSimple apiKey boundary", () => { const captured = spy.getCapturedOptions(); assert.ok(captured !== undefined, "streamSimple must have been called"); - assert.equal("apiKey" in captured, false, "apiKey must not exist even when options is undefined"); + assert.equal( + "apiKey" in captured, + false, + "apiKey must not exist even when options is undefined", + ); }); it("strips apiKey but preserves signal and other fields for externalCli", () => { @@ -595,15 +773,28 @@ describe("ModelRegistry authMode — streamSimple apiKey boundary", () => { provider.streamSimple( makeModel("cli-fields", "m", apiType), makeContext(), - { apiKey: "strip-me", maxTokens: 8192, signal: abortController.signal, reasoning: "high" } as SimpleStreamOptions, + { + apiKey: "strip-me", + maxTokens: 8192, + signal: abortController.signal, + reasoning: "high", + } as SimpleStreamOptions, ); const captured = spy.getCapturedOptions(); assert.ok(captured, "streamSimple must have been called"); assert.equal("apiKey" in captured, false, "apiKey must be stripped"); assert.equal(captured.maxTokens, 8192, "maxTokens must pass through"); - assert.equal(captured.signal, abortController.signal, "signal must pass through"); - assert.equal((captured as Record).reasoning, "high", "reasoning must pass through"); + assert.equal( + captured.signal, + abortController.signal, + "signal must pass through", + ); + assert.equal( + (captured as Record).reasoning, + "high", + "reasoning must pass through", + ); }); }); @@ -633,11 +824,12 @@ describe("ModelRegistry authMode — provider-scoped stream routing", () => { // The built-in handler will throw (no API key), which proves the routing // correctly delegates to the built-in instead of the custom handler. assert.throws( - () => provider.streamSimple( - makeModel("anthropic", "claude-sonnet-4-6", "anthropic-messages"), - makeContext(), - { maxTokens: 4096 } as SimpleStreamOptions, - ), + () => + provider.streamSimple( + makeModel("anthropic", "claude-sonnet-4-6", "anthropic-messages"), + makeContext(), + { maxTokens: 4096 } as SimpleStreamOptions, + ), (err: Error) => err.message.includes("API key"), "built-in Anthropic handler must be invoked (throws because no API key in tests)", ); @@ -672,7 +864,10 @@ describe("ModelRegistry authMode — provider-scoped stream routing", () => { ); const captured = customSpy.getCapturedOptions(); - assert.ok(captured, "custom provider's streamSimple must be called for its own models"); + assert.ok( + captured, + "custom provider's streamSimple must be called for its own models", + ); assert.equal(captured.maxTokens, 2048); }); }); diff --git a/packages/pi-coding-agent/src/core/model-registry.ts b/packages/pi-coding-agent/src/core/model-registry.ts index 5d702a6bf..4cf28af85 100644 --- a/packages/pi-coding-agent/src/core/model-registry.ts +++ b/packages/pi-coding-agent/src/core/model-registry.ts @@ -2,10 +2,11 @@ * Model registry - manages built-in and custom models, provides API key resolution. */ +import { type Static, Type } from "@sinclair/typebox"; import { type Api, - applyCapabilityPatches, type AssistantMessageEventStream, + applyCapabilityPatches, type Context, getApiProvider, getEnvApiKey, @@ -20,14 +21,17 @@ import { resetApiProviders, type SimpleStreamOptions, } from "@singularity-forge/pi-ai"; -import { registerOAuthProvider, resetOAuthProviders } from "@singularity-forge/pi-ai/oauth"; -import { type Static, Type } from "@sinclair/typebox"; +import { + registerOAuthProvider, + resetOAuthProviders, +} from "@singularity-forge/pi-ai/oauth"; import AjvModule from "ajv"; import { existsSync, readFileSync } from "fs"; import { join } from "path"; import { getAgentDir } from "../config.js"; import type { AuthStorage } from "./auth-storage.js"; import { ModelDiscoveryCache } from "./discovery-cache.js"; +import { isLocalModel } from "./local-model-check.js"; import type { DiscoveryResult } from "./model-discovery.js"; import { getDiscoverableCatalogSources, @@ -35,7 +39,6 @@ import { getDiscoveryAdapter, } from "./model-discovery.js"; import { resolveConfigValue, resolveHeaders } from "./resolve-config-value.js"; -import { isLocalModel } from "./local-model-check.js"; const Ajv = (AjvModule as any).default || AjvModule; const ajv = new Ajv(); @@ -70,27 +73,33 @@ export const PROXY_FAMILY_PRIORITY: ReadonlyArray<{ family_failover?: string[]; }> = [ // MiniMax direct (api.minimax.io) → CN endpoint as its direct pair - { match: /^MiniMax-/i, prefix: "MiniMax-", providers: ["minimax", "minimax-cn"] }, + { + match: /^MiniMax-/i, + prefix: "MiniMax-", + providers: ["minimax", "minimax-cn"], + }, // ZAI direct API for GLM - { match: /^glm-/i, prefix: "glm-", providers: ["zai"] }, + { match: /^glm-/i, prefix: "glm-", providers: ["zai"] }, // Kimi Code direct API - { match: /^kimi-/i, prefix: "kimi-", providers: ["kimi-coding"] }, + { match: /^kimi-/i, prefix: "kimi-", providers: ["kimi-coding"] }, // MiMo/Xiaomi — direct API via Xiaomi MiMo Open Platform (api.xiaomimimo.com) // or the Token Plan endpoint (token-plan-sgp.xiaomimimo.com). Both served // under the `xiaomi` provider namespace. - { match: /^mimo-|^XiaomiMiMo\//i, prefix: "mimo-", providers: ["xiaomi"] }, - // Gemini/Gemma: google-gemini-cli (OAuth), google (API key), google-vertex - // are all FIRST-PARTY Google endpoints. github-copilot re-serves and is - // failover only. + { match: /^mimo-|^XiaomiMiMo\//i, prefix: "mimo-", providers: ["xiaomi"] }, + // Gemini/Gemma: google-gemini-cli (CLI OAuth via ~/.gemini), google + // (API key), google-vertex are all FIRST-PARTY Google endpoints. + // github-copilot re-serves and is failover only. { - match: /^gemini-|^gemma-/i, prefix: "gemini-", + match: /^gemini-|^gemma-/i, + prefix: "gemini-", providers: ["google-gemini-cli", "google", "google-vertex"], family_failover: ["github-copilot"], }, // Claude: Anthropic is the ONLY direct provider. github-copilot re-serves // Claude via GitHub's platform as failover. { - match: /^claude-/i, prefix: "claude-", + match: /^claude-/i, + prefix: "claude-", providers: ["anthropic"], family_failover: ["github-copilot"], }, @@ -99,9 +108,14 @@ export const PROXY_FAMILY_PRIORITY: ReadonlyArray<{ // the same weights via a different legal/contractual relationship). // github-copilot likewise re-serves. { - match: /^gpt-|^o\d|^codex-/i, prefix: "gpt-", + match: /^gpt-|^o\d|^codex-/i, + prefix: "gpt-", providers: ["openai"], - family_failover: ["azure-openai-responses", "openai-codex", "github-copilot"], + family_failover: [ + "azure-openai-responses", + "openai-codex", + "github-copilot", + ], }, ]; @@ -123,12 +137,23 @@ const OpenAICompletionsCompatSchema = Type.Object({ supportsDeveloperRole: Type.Optional(Type.Boolean()), supportsReasoningEffort: Type.Optional(Type.Boolean()), supportsUsageInStreaming: Type.Optional(Type.Boolean()), - maxTokensField: Type.Optional(Type.Union([Type.Literal("max_completion_tokens"), Type.Literal("max_tokens")])), + maxTokensField: Type.Optional( + Type.Union([ + Type.Literal("max_completion_tokens"), + Type.Literal("max_tokens"), + ]), + ), requiresToolResultName: Type.Optional(Type.Boolean()), requiresAssistantAfterToolResult: Type.Optional(Type.Boolean()), requiresThinkingAsText: Type.Optional(Type.Boolean()), requiresMistralToolIds: Type.Optional(Type.Boolean()), - thinkingFormat: Type.Optional(Type.Union([Type.Literal("openai"), Type.Literal("zai"), Type.Literal("qwen")])), + thinkingFormat: Type.Optional( + Type.Union([ + Type.Literal("openai"), + Type.Literal("zai"), + Type.Literal("qwen"), + ]), + ), openRouterRouting: Type.Optional(OpenRouterRoutingSchema), vercelGatewayRouting: Type.Optional(VercelGatewayRoutingSchema), }); @@ -137,7 +162,10 @@ const OpenAIResponsesCompatSchema = Type.Object({ // Reserved for future use }); -const OpenAICompatSchema = Type.Union([OpenAICompletionsCompatSchema, OpenAIResponsesCompatSchema]); +const OpenAICompatSchema = Type.Union([ + OpenAICompletionsCompatSchema, + OpenAIResponsesCompatSchema, +]); // Schema for custom model definition // Most fields are optional with sensible defaults for local models (Ollama, LM Studio, etc.) @@ -147,7 +175,9 @@ const ModelDefinitionSchema = Type.Object({ api: Type.Optional(Type.String({ minLength: 1 })), baseUrl: Type.Optional(Type.String({ minLength: 1 })), reasoning: Type.Optional(Type.Boolean()), - input: Type.Optional(Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")]))), + input: Type.Optional( + Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])), + ), cost: Type.Optional( Type.Object({ input: Type.Number(), @@ -166,7 +196,9 @@ const ModelDefinitionSchema = Type.Object({ const ModelOverrideSchema = Type.Object({ name: Type.Optional(Type.String({ minLength: 1 })), reasoning: Type.Optional(Type.Boolean()), - input: Type.Optional(Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")]))), + input: Type.Optional( + Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])), + ), cost: Type.Optional( Type.Object({ input: Type.Optional(Type.Number()), @@ -190,7 +222,9 @@ const ProviderConfigSchema = Type.Object({ headers: Type.Optional(Type.Record(Type.String(), Type.String())), authHeader: Type.Optional(Type.Boolean()), models: Type.Optional(Type.Array(ModelDefinitionSchema)), - modelOverrides: Type.Optional(Type.Record(Type.String(), ModelOverrideSchema)), + modelOverrides: Type.Optional( + Type.Record(Type.String(), ModelOverrideSchema), + ), }); const ModelsConfigSchema = Type.Object({ @@ -205,7 +239,8 @@ export type ProviderModelAllowList = Record; export type ProviderAuthMode = "apiKey" | "oauth" | "externalCli" | "none"; -type ProviderPolicyModel = Pick, "provider" | "id"> & Partial, "name" | "cost">>; +type ProviderPolicyModel = Pick, "provider" | "id"> & + Partial, "name" | "cost">>; const OPENCODE_FREE_MODEL_IDS = new Set([ "big-pickle", @@ -221,7 +256,12 @@ const HIDDEN_MODEL_PROVIDERS = new Set([ "xiaomi-token-plan-sgp", ]); -function providerModelAllowEntryMatches(allowedModel: string, modelKey: string): boolean { +const BUILTIN_EXTERNAL_CLI_AUTH_PROVIDERS = new Set(["google-gemini-cli"]); + +function providerModelAllowEntryMatches( + allowedModel: string, + modelKey: string, +): boolean { const allowedKey = allowedModel.trim().toLowerCase(); if (!allowedKey) return false; if (allowedKey === modelKey) return true; @@ -239,7 +279,13 @@ function hasFreeSkuMarker(value: string | undefined): boolean { } function isZeroCost(cost: Model["cost"] | undefined): boolean { - return !!cost && cost.input === 0 && cost.output === 0 && cost.cacheRead === 0 && cost.cacheWrite === 0; + return ( + !!cost && + cost.input === 0 && + cost.output === 0 && + cost.cacheRead === 0 && + cost.cacheWrite === 0 + ); } function isMistralSelectionModel(modelId: string): boolean { @@ -259,7 +305,9 @@ function isMistralSelectionModel(modelId: string): boolean { return true; } -function isModelAllowedByBuiltInProviderPolicy(model: ProviderPolicyModel): boolean { +function isModelAllowedByBuiltInProviderPolicy( + model: ProviderPolicyModel, +): boolean { const provider = model.provider.toLowerCase(); const modelKey = model.id.trim().toLowerCase(); if (HIDDEN_MODEL_PROVIDERS.has(provider)) { @@ -309,22 +357,35 @@ function mergeCompat( ): Model["compat"] | undefined { if (!overrideCompat) return baseCompat; - const base = baseCompat as OpenAICompletionsCompat | OpenAIResponsesCompat | undefined; - const override = overrideCompat as OpenAICompletionsCompat | OpenAIResponsesCompat; - const merged = { ...base, ...override } as OpenAICompletionsCompat | OpenAIResponsesCompat; + const base = baseCompat as + | OpenAICompletionsCompat + | OpenAIResponsesCompat + | undefined; + const override = overrideCompat as + | OpenAICompletionsCompat + | OpenAIResponsesCompat; + const merged = { ...base, ...override } as + | OpenAICompletionsCompat + | OpenAIResponsesCompat; const baseCompletions = base as OpenAICompletionsCompat | undefined; const overrideCompletions = override as OpenAICompletionsCompat; const mergedCompletions = merged as OpenAICompletionsCompat; - if (baseCompletions?.openRouterRouting || overrideCompletions.openRouterRouting) { + if ( + baseCompletions?.openRouterRouting || + overrideCompletions.openRouterRouting + ) { mergedCompletions.openRouterRouting = { ...baseCompletions?.openRouterRouting, ...overrideCompletions.openRouterRouting, }; } - if (baseCompletions?.vercelGatewayRouting || overrideCompletions.vercelGatewayRouting) { + if ( + baseCompletions?.vercelGatewayRouting || + overrideCompletions.vercelGatewayRouting + ) { mergedCompletions.vercelGatewayRouting = { ...baseCompletions?.vercelGatewayRouting, ...overrideCompletions.vercelGatewayRouting, @@ -338,14 +399,19 @@ function mergeCompat( * Deep merge a model override into a model. * Handles nested objects (cost, compat) by merging rather than replacing. */ -function applyModelOverride(model: Model, override: ModelOverride): Model { +function applyModelOverride( + model: Model, + override: ModelOverride, +): Model { const result = { ...model }; // Simple field overrides if (override.name !== undefined) result.name = override.name; if (override.reasoning !== undefined) result.reasoning = override.reasoning; - if (override.input !== undefined) result.input = override.input as ("text" | "image")[]; - if (override.contextWindow !== undefined) result.contextWindow = override.contextWindow; + if (override.input !== undefined) + result.input = override.input as ("text" | "image")[]; + if (override.contextWindow !== undefined) + result.contextWindow = override.contextWindow; if (override.maxTokens !== undefined) result.maxTokens = override.maxTokens; // Merge cost (partial override) @@ -361,7 +427,9 @@ function applyModelOverride(model: Model, override: ModelOverride): Model, override: ModelOverride): Model[], customModels: Model[]): Model[] { + private mergeCustomModels( + builtInModels: Model[], + customModels: Model[], + ): Model[] { const merged = [...builtInModels]; for (const customModel of customModels) { - const existingIndex = merged.findIndex((m) => m.provider === customModel.provider && m.id === customModel.id); + const existingIndex = merged.findIndex( + (m) => m.provider === customModel.provider && m.id === customModel.id, + ); if (existingIndex >= 0) { merged[existingIndex] = customModel; } else { @@ -509,22 +588,32 @@ export class ModelRegistry { return merged; } - private isProviderModelAllowed(model: ProviderPolicyModel, providerModelAllow?: ProviderModelAllowList): boolean { + private isProviderModelAllowed( + model: ProviderPolicyModel, + providerModelAllow?: ProviderModelAllowList, + ): boolean { if (!isModelAllowedByBuiltInProviderPolicy(model)) return false; if (!providerModelAllow) return true; const providerKey = model.provider.toLowerCase(); - const allowedModels = providerModelAllow[providerKey] - ?? Object.entries(providerModelAllow).find(([key]) => key.toLowerCase() === providerKey)?.[1]; + const allowedModels = + providerModelAllow[providerKey] ?? + Object.entries(providerModelAllow).find( + ([key]) => key.toLowerCase() === providerKey, + )?.[1]; if (allowedModels === undefined) return true; const modelKey = model.id.trim().toLowerCase(); - return allowedModels.some((allowedModel) => providerModelAllowEntryMatches(allowedModel, modelKey)); + return allowedModels.some((allowedModel) => + providerModelAllowEntryMatches(allowedModel, modelKey), + ); } private filterProviderModelAllow>( models: T[], providerModelAllow?: ProviderModelAllowList, ): T[] { - return models.filter((model) => this.isProviderModelAllowed(model, providerModelAllow)); + return models.filter((model) => + this.isProviderModelAllowed(model, providerModelAllow), + ); } private loadCustomModels(modelsJsonPath: string): CustomModelsResult { @@ -540,9 +629,12 @@ export class ModelRegistry { const validate = ajv.getSchema("ModelsConfig")!; if (!validate(config)) { const errors = - validate.errors?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`).join("\n") || - "Unknown schema error"; - return emptyCustomModelsResult(`Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`); + validate.errors + ?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`) + .join("\n") || "Unknown schema error"; + return emptyCustomModelsResult( + `Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`, + ); } // Additional validation @@ -551,9 +643,15 @@ export class ModelRegistry { const overrides = new Map(); const modelOverrides = new Map>(); - for (const [providerName, providerConfig] of Object.entries(config.providers)) { + for (const [providerName, providerConfig] of Object.entries( + config.providers, + )) { // Apply provider-level baseUrl/headers/apiKey override to built-in models when configured. - if (providerConfig.baseUrl || providerConfig.headers || providerConfig.apiKey) { + if ( + providerConfig.baseUrl || + providerConfig.headers || + providerConfig.apiKey + ) { overrides.set(providerName, { baseUrl: providerConfig.baseUrl, headers: providerConfig.headers, @@ -567,14 +665,24 @@ export class ModelRegistry { } if (providerConfig.modelOverrides) { - modelOverrides.set(providerName, new Map(Object.entries(providerConfig.modelOverrides))); + modelOverrides.set( + providerName, + new Map(Object.entries(providerConfig.modelOverrides)), + ); } } - return { models: this.parseModels(config), overrides, modelOverrides, error: undefined }; + return { + models: this.parseModels(config), + overrides, + modelOverrides, + error: undefined, + }; } catch (error) { if (error instanceof SyntaxError) { - return emptyCustomModelsResult(`Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`); + return emptyCustomModelsResult( + `Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`, + ); } return emptyCustomModelsResult( `Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`, @@ -583,24 +691,33 @@ export class ModelRegistry { } private validateConfig(config: ModelsConfig): void { - for (const [providerName, providerConfig] of Object.entries(config.providers)) { + for (const [providerName, providerConfig] of Object.entries( + config.providers, + )) { const hasProviderApi = !!providerConfig.api; const models = providerConfig.models ?? []; const hasModelOverrides = - providerConfig.modelOverrides && Object.keys(providerConfig.modelOverrides).length > 0; + providerConfig.modelOverrides && + Object.keys(providerConfig.modelOverrides).length > 0; if (models.length === 0) { // Override-only config: needs baseUrl OR modelOverrides (or both) if (!providerConfig.baseUrl && !hasModelOverrides) { - throw new Error(`Provider ${providerName}: must specify "baseUrl", "modelOverrides", or "models".`); + throw new Error( + `Provider ${providerName}: must specify "baseUrl", "modelOverrides", or "models".`, + ); } } else { // Custom models are merged into provider models and require endpoint + auth. if (!providerConfig.baseUrl) { - throw new Error(`Provider ${providerName}: "baseUrl" is required when defining custom models.`); + throw new Error( + `Provider ${providerName}: "baseUrl" is required when defining custom models.`, + ); } if (!providerConfig.apiKey) { - throw new Error(`Provider ${providerName}: "apiKey" is required when defining custom models.`); + throw new Error( + `Provider ${providerName}: "apiKey" is required when defining custom models.`, + ); } } @@ -613,12 +730,17 @@ export class ModelRegistry { ); } - if (!modelDef.id) throw new Error(`Provider ${providerName}: model missing "id"`); + if (!modelDef.id) + throw new Error(`Provider ${providerName}: model missing "id"`); // Validate contextWindow/maxTokens only if provided (they have defaults) if (modelDef.contextWindow !== undefined && modelDef.contextWindow <= 0) - throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`); + throw new Error( + `Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`, + ); if (modelDef.maxTokens !== undefined && modelDef.maxTokens <= 0) - throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`); + throw new Error( + `Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`, + ); } } } @@ -626,7 +748,9 @@ export class ModelRegistry { private parseModels(config: ModelsConfig): Model[] { const models: Model[] = []; - for (const [providerName, providerConfig] of Object.entries(config.providers)) { + for (const [providerName, providerConfig] of Object.entries( + config.providers, + )) { const modelDefs = providerConfig.models ?? []; if (modelDefs.length === 0) continue; // Override-only, no custom models @@ -655,7 +779,10 @@ export class ModelRegistry { // Resolve env vars and shell commands in header values const providerHeaders = resolveHeaders(providerConfig.headers); const modelHeaders = resolveHeaders(modelDef.headers); - let headers = providerHeaders || modelHeaders ? { ...providerHeaders, ...modelHeaders } : undefined; + let headers = + providerHeaders || modelHeaders + ? { ...providerHeaders, ...modelHeaders } + : undefined; // If authHeader is true, add Authorization header with resolved API key if (providerConfig.authHeader && providerConfig.apiKey) { @@ -667,7 +794,12 @@ export class ModelRegistry { // Provider baseUrl is required when custom models are defined. // Individual models can override it with modelDef.baseUrl. - const defaultCost = { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }; + const defaultCost = { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + }; models.push({ id: modelDef.id, name: modelDef.name ?? modelDef.id, @@ -712,6 +844,7 @@ export class ModelRegistry { * Defaults to "apiKey" for built-ins and providers without explicit mode. */ getProviderAuthMode(provider: string): ProviderAuthMode { + if (BUILTIN_EXTERNAL_CLI_AUTH_PROVIDERS.has(provider)) return "externalCli"; const config = this.registeredProviders.get(provider); if (!config) return "apiKey"; if (config.authMode) return config.authMode; @@ -745,10 +878,15 @@ export class ModelRegistry { * Returns undefined for externalCli/none providers (no key needed). * @param sessionId - Optional session ID for sticky credential selection */ - async getApiKey(model: Model, sessionId?: string): Promise { + async getApiKey( + model: Model, + sessionId?: string, + ): Promise { const authMode = this.getProviderAuthMode(model.provider); if (authMode === "externalCli" || authMode === "none") return undefined; - return this.authStorage.getApiKey(model.provider, sessionId, { baseUrl: model.baseUrl }); + return this.authStorage.getApiKey(model.provider, sessionId, { + baseUrl: model.baseUrl, + }); } /** @@ -756,7 +894,10 @@ export class ModelRegistry { * Returns undefined for externalCli/none providers (no key needed). * @param sessionId - Optional session ID for sticky credential selection */ - async getApiKeyForProvider(provider: string, sessionId?: string): Promise { + async getApiKeyForProvider( + provider: string, + sessionId?: string, + ): Promise { const authMode = this.getProviderAuthMode(provider); if (authMode === "externalCli" || authMode === "none") return undefined; return this.authStorage.getApiKey(provider, sessionId); @@ -798,7 +939,10 @@ export class ModelRegistry { this.refresh(); } - private applyProviderConfig(providerName: string, config: ProviderConfigInput): void { + private applyProviderConfig( + providerName: string, + config: ProviderConfigInput, + ): void { // Register OAuth provider if provided if (config.oauth) { // Ensure the OAuth provider ID matches the provider name @@ -811,19 +955,30 @@ export class ModelRegistry { if (config.streamSimple) { if (!config.api) { - throw new Error(`Provider ${providerName}: "api" is required when registering streamSimple.`); + throw new Error( + `Provider ${providerName}: "api" is required when registering streamSimple.`, + ); } const rawStreamSimple = config.streamSimple; const authMode = config.authMode ?? "apiKey"; // Keyless providers never see apiKey in options — enforced at registration, // not by convention. Prevents undefined from reaching any handler. - const streamSimple = (authMode === "externalCli" || authMode === "none") - ? ((model: Model, context: Context, options?: SimpleStreamOptions) => { - const { apiKey: _, ...opts } = options ?? {}; - return rawStreamSimple(model, context, opts as SimpleStreamOptions); - }) - : rawStreamSimple; + const streamSimple = + authMode === "externalCli" || authMode === "none" + ? ( + model: Model, + context: Context, + options?: SimpleStreamOptions, + ) => { + const { apiKey: _, ...opts } = options ?? {}; + return rawStreamSimple( + model, + context, + opts as SimpleStreamOptions, + ); + } + : rawStreamSimple; // Guard: if there's already a handler registered for this API, wrap // the new one so it only fires for models from this provider and @@ -832,7 +987,11 @@ export class ModelRegistry { // the built-in Anthropic stream handler (#2536). const existingProvider = getApiProvider(config.api as Api); const scopedStream = existingProvider - ? (model: Model, context: Context, options?: SimpleStreamOptions): AssistantMessageEventStream => { + ? ( + model: Model, + context: Context, + options?: SimpleStreamOptions, + ): AssistantMessageEventStream => { if (model.provider === providerName) { return streamSimple(model, context, options); } @@ -840,12 +999,23 @@ export class ModelRegistry { } : streamSimple; - const newFullStream = (model: Model, context: Context, options?: SimpleStreamOptions) => - scopedStream(model, context, options as SimpleStreamOptions); + const newFullStream = ( + model: Model, + context: Context, + options?: SimpleStreamOptions, + ) => scopedStream(model, context, options as SimpleStreamOptions); const scopedFullStream = existingProvider - ? (model: Model, context: Context, options?: Record) => { + ? ( + model: Model, + context: Context, + options?: Record, + ) => { if (model.provider === providerName) { - return newFullStream(model, context, options as SimpleStreamOptions); + return newFullStream( + model, + context, + options as SimpleStreamOptions, + ); } return existingProvider.stream(model, context, options); } @@ -872,25 +1042,35 @@ export class ModelRegistry { // Validate required fields if (!config.baseUrl) { - throw new Error(`Provider ${providerName}: "baseUrl" is required when defining models.`); + throw new Error( + `Provider ${providerName}: "baseUrl" is required when defining models.`, + ); } - const authMode = config.authMode ?? (config.oauth ? "oauth" : config.apiKey ? "apiKey" : "apiKey"); + const authMode = + config.authMode ?? + (config.oauth ? "oauth" : config.apiKey ? "apiKey" : "apiKey"); if (authMode === "apiKey" && !config.apiKey && !config.oauth) { throw new Error( `Provider ${providerName}: "apiKey" or "oauth" is required when authMode is "apiKey" (the default). ` + - `Set authMode to "externalCli" or "none" for keyless providers.`, + `Set authMode to "externalCli" or "none" for keyless providers.`, ); } - if ((authMode === "externalCli" || authMode === "none") && !config.streamSimple) { + if ( + (authMode === "externalCli" || authMode === "none") && + !config.streamSimple + ) { throw new Error( `Provider ${providerName}: "streamSimple" is required when authMode is "${authMode}". ` + - `Keyless providers must supply their own stream handler.`, + `Keyless providers must supply their own stream handler.`, ); } - if ((authMode === "externalCli" || authMode === "none") && config.apiKey) { + if ( + (authMode === "externalCli" || authMode === "none") && + config.apiKey + ) { throw new Error( `Provider ${providerName}: "apiKey" cannot be set when authMode is "${authMode}". ` + - `Keyless providers should not provide API key credentials.`, + `Keyless providers should not provide API key credentials.`, ); } @@ -898,13 +1078,18 @@ export class ModelRegistry { for (const modelDef of config.models) { const api = modelDef.api || config.api; if (!api) { - throw new Error(`Provider ${providerName}, model ${modelDef.id}: no "api" specified.`); + throw new Error( + `Provider ${providerName}, model ${modelDef.id}: no "api" specified.`, + ); } // Merge headers const providerHeaders = resolveHeaders(config.headers); const modelHeaders = resolveHeaders(modelDef.headers); - let headers = providerHeaders || modelHeaders ? { ...providerHeaders, ...modelHeaders } : undefined; + let headers = + providerHeaders || modelHeaders + ? { ...providerHeaders, ...modelHeaders } + : undefined; // If authHeader is true, add Authorization header if (config.authHeader && config.apiKey) { @@ -949,15 +1134,24 @@ export class ModelRegistry { return { ...m, baseUrl: config.baseUrl ?? m.baseUrl, - headers: resolvedHeaders ? { ...m.headers, ...resolvedHeaders } : m.headers, + headers: resolvedHeaders + ? { ...m.headers, ...resolvedHeaders } + : m.headers, }; }); } } - private buildCandidateOrder(modelId: string, overrides: Record): string[] { - const overrideEntry = Object.entries(overrides).find(([k]) => modelId.startsWith(k)); - const familyEntry = PROXY_FAMILY_PRIORITY.find((r) => r.match.test(modelId)); + private buildCandidateOrder( + modelId: string, + overrides: Record, + ): string[] { + const overrideEntry = Object.entries(overrides).find(([k]) => + modelId.startsWith(k), + ); + const familyEntry = PROXY_FAMILY_PRIORITY.find((r) => + r.match.test(modelId), + ); // Order: direct family providers → family-scoped failover → global fallback. // Overrides replace only the direct list (keeps family_failover + global // chain intact) so a user pinning "glm- → [zai]" still picks up @@ -992,8 +1186,12 @@ export class ModelRegistry { return (pa === -1 ? Infinity : pa) - (pb === -1 ? Infinity : pb); }); - const withAuth = sorted.filter((m) => this.isProviderRequestReady(m.provider)); - const withoutAuth = sorted.filter((m) => !this.isProviderRequestReady(m.provider)); + const withAuth = sorted.filter((m) => + this.isProviderRequestReady(m.provider), + ); + const withoutAuth = sorted.filter( + (m) => !this.isProviderRequestReady(m.provider), + ); return [...withAuth, ...withoutAuth]; } @@ -1064,7 +1262,9 @@ export class ModelRegistry { } // Convert and merge discovered models, then apply capability patches - this.discoveredModels = applyCapabilityPatches(this.convertDiscoveredModels(results)); + this.discoveredModels = applyCapabilityPatches( + this.convertDiscoveredModels(results), + ); return results; } @@ -1082,8 +1282,12 @@ export class ModelRegistry { * Discovered models are appended but never override existing models. */ getAllWithDiscovered(): Model[] { - const existingIds = new Set(this.models.map((m) => `${m.provider}/${m.id}`)); - const unique = this.discoveredModels.filter((m) => !existingIds.has(`${m.provider}/${m.id}`)); + const existingIds = new Set( + this.models.map((m) => `${m.provider}/${m.id}`), + ); + const unique = this.discoveredModels.filter( + (m) => !existingIds.has(`${m.provider}/${m.id}`), + ); return this.filterProviderModelAllow([...this.models, ...unique]); } @@ -1094,15 +1298,22 @@ export class ModelRegistry { * providers with the provider's actual `/models` response. * Consumer: cli/list-models.ts when `--discover` or an exact provider query is used. */ - getDiscoveredModels(providerModelAllow?: ProviderModelAllowList): Model[] { - return this.filterProviderModelAllow(this.discoveredModels, providerModelAllow); + getDiscoveredModels( + providerModelAllow?: ProviderModelAllowList, + ): Model[] { + return this.filterProviderModelAllow( + this.discoveredModels, + providerModelAllow, + ); } /** * Check if a model was added via discovery (not built-in or custom). */ isDiscovered(model: Model): boolean { - return this.discoveredModels.some((m) => m.provider === model.provider && m.id === model.id); + return this.discoveredModels.some( + (m) => m.provider === model.provider && m.id === model.id, + ); } /** @@ -1125,7 +1336,9 @@ export class ModelRegistry { const key = `${provider}/${dm.id}`; if (seen.has(key)) continue; seen.add(key); - const known = this.models.find((m) => m.provider === provider && m.id === dm.id); + const known = this.models.find( + (m) => m.provider === provider && m.id === dm.id, + ); const discoveredName = dm.name && dm.name !== dm.id ? dm.name : undefined; converted.push({ @@ -1137,7 +1350,8 @@ export class ModelRegistry { baseUrl: dm.baseUrl ?? known?.baseUrl ?? "", reasoning: dm.reasoning ?? known?.reasoning ?? false, input: dm.input ?? known?.input ?? ["text"], - cost: dm.cost ?? known?.cost ?? { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + cost: dm.cost ?? + known?.cost ?? { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, contextWindow: dm.contextWindow ?? known?.contextWindow ?? 128000, maxTokens: dm.maxTokens ?? known?.maxTokens ?? 16384, } as Model); @@ -1177,7 +1391,11 @@ export interface ProviderConfigInput { baseUrl?: string; apiKey?: string; api?: Api; - streamSimple?: (model: Model, context: Context, options?: SimpleStreamOptions) => AssistantMessageEventStream; + streamSimple?: ( + model: Model, + context: Context, + options?: SimpleStreamOptions, + ) => AssistantMessageEventStream; headers?: Record; authHeader?: boolean; /** OAuth provider for /login support */ @@ -1189,7 +1407,12 @@ export interface ProviderConfigInput { baseUrl?: string; reasoning: boolean; input: ("text" | "image")[]; - cost: { input: number; output: number; cacheRead: number; cacheWrite: number }; + cost: { + input: number; + output: number; + cacheRead: number; + cacheWrite: number; + }; contextWindow: number; maxTokens: number; headers?: Record; diff --git a/packages/pi-coding-agent/src/core/retry-handler.test.ts b/packages/pi-coding-agent/src/core/retry-handler.test.ts index 107f6966e..757552b74 100644 --- a/packages/pi-coding-agent/src/core/retry-handler.test.ts +++ b/packages/pi-coding-agent/src/core/retry-handler.test.ts @@ -6,12 +6,12 @@ * downgrade from [1m] to base when no cross-provider fallback exists. */ -import { vi, describe, it, beforeEach, type Mock } from 'vitest'; import assert from "node:assert/strict"; -import { RetryHandler, type RetryHandlerDeps } from "./retry-handler.js"; import type { Api, AssistantMessage, Model } from "@singularity-forge/pi-ai"; +import { beforeEach, describe, it, type Mock, vi } from "vitest"; import type { FallbackResolver } from "./fallback-resolver.js"; import type { ModelRegistry } from "./model-registry.js"; +import { RetryHandler, type RetryHandlerDeps } from "./retry-handler.js"; import type { SettingsManager } from "./settings-manager.js"; // ─── Helpers ──────────────────────────────────────────────────────────────── @@ -38,7 +38,14 @@ function errorMessage(msg: string): AssistantMessage { api: "anthropic-messages", provider: "anthropic", model: "claude-opus-4-6[1m]", - usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } }, + usage: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, stopReason: "error", errorMessage: msg, timestamp: Date.now(), @@ -52,7 +59,9 @@ interface MockDeps { onModelChangeFn: Mock<(model: Model) => void>; markUsageLimitReached: Mock<(...args: any[]) => boolean>; findFallback: Mock<(...args: any[]) => Promise>; - findModel: Mock<(provider: string, modelId: string) => Model | undefined>; + findModel: Mock< + (provider: string, modelId: string) => Model | undefined + >; } function createMockDeps(overrides?: { @@ -60,14 +69,19 @@ function createMockDeps(overrides?: { retryEnabled?: boolean; markUsageLimitReachedResult?: boolean; fallbackResult?: any; - findModelResult?: (provider: string, modelId: string) => Model | undefined; + findModelResult?: ( + provider: string, + modelId: string, + ) => Model | undefined; + providerAuthMode?: "apiKey" | "oauth" | "externalCli" | "none"; retrySettings?: { maxRetries?: number; baseDelayMs?: number; maxDelayMs?: number; }; }): MockDeps { - const model = overrides?.model ?? createMockModel("anthropic", "claude-opus-4-6[1m]"); + const model = + overrides?.model ?? createMockModel("anthropic", "claude-opus-4-6[1m]"); const emittedEvents: Array> = []; const continueFn = vi.fn(async () => {}); const onModelChangeFn = vi.fn((_model: Model) => {}); @@ -76,7 +90,8 @@ function createMockDeps(overrides?: { ); const findFallback = vi.fn(async () => overrides?.fallbackResult ?? null); const findModel = vi.fn( - overrides?.findModelResult ?? ((_provider: string, _modelId: string) => undefined), + overrides?.findModelResult ?? + ((_provider: string, _modelId: string) => undefined), ); const messages: Array<{ role: string } & Record> = []; @@ -105,6 +120,7 @@ function createMockDeps(overrides?: { markUsageLimitReached, }, find: findModel, + getProviderAuthMode: () => overrides?.providerAuthMode ?? "apiKey", } as unknown as ModelRegistry, fallbackResolver: { findFallback, @@ -115,13 +131,20 @@ function createMockDeps(overrides?: { onModelChange: onModelChangeFn, }; - return { deps, emittedEvents, continueFn, onModelChangeFn, markUsageLimitReached, findFallback, findModel }; + return { + deps, + emittedEvents, + continueFn, + onModelChangeFn, + markUsageLimitReached, + findFallback, + findModel, + }; } // ─── _classifyErrorType (tested via handleRetryableError behavior) ────────── describe("RetryHandler — long-context entitlement 429 (#2803)", () => { - describe("error classification", () => { it("classifies 'Extra usage is required for long context requests' as quota_exhausted, not rate_limit", async () => { // When the error is classified as quota_exhausted AND no alternate credentials @@ -136,7 +159,7 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => { const handler = new RetryHandler(deps); const msg = errorMessage( - '429 {"type":"error","error":{"type":"rate_limit_error","message":"Extra usage is required for long context requests."}}' + '429 {"type":"error","error":{"type":"rate_limit_error","message":"Extra usage is required for long context requests."}}', ); const result = await handler.handleRetryableError(msg); @@ -145,11 +168,22 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => { assert.equal(result, false); // Should emit fallback_chain_exhausted (quota_exhausted path), NOT auto_retry_start (backoff path) - const chainExhausted = emittedEvents.find((e) => e.type === "fallback_chain_exhausted"); - assert.ok(chainExhausted, "Expected fallback_chain_exhausted event for entitlement error"); + const chainExhausted = emittedEvents.find( + (e) => e.type === "fallback_chain_exhausted", + ); + assert.ok( + chainExhausted, + "Expected fallback_chain_exhausted event for entitlement error", + ); - const retryStart = emittedEvents.find((e) => e.type === "auto_retry_start"); - assert.equal(retryStart, undefined, "Should NOT emit auto_retry_start for entitlement error"); + const retryStart = emittedEvents.find( + (e) => e.type === "auto_retry_start", + ); + assert.equal( + retryStart, + undefined, + "Should NOT emit auto_retry_start for entitlement error", + ); }); it("still classifies regular 429 rate limits as rate_limit", async () => { @@ -168,7 +202,9 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => { // Should enter the backoff loop (rate_limit path, not quota_exhausted) assert.equal(result, true); - const retryStart = emittedEvents.find((e) => e.type === "auto_retry_start"); + const retryStart = emittedEvents.find( + (e) => e.type === "auto_retry_start", + ); assert.ok(retryStart, "Regular 429 should enter backoff retry"); }); @@ -184,7 +220,7 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => { const handler = new RetryHandler(deps); const msg = errorMessage( - '529 {"type":"error","error":{"type":"overloaded_error","message":"The server cluster is currently under high load. Please retry after a short wait and thank you for your patience. (2064) (529)"},"request_id":"062e76f8f25cd919caa3af4baaa49203"}' + '529 {"type":"error","error":{"type":"overloaded_error","message":"The server cluster is currently under high load. Please retry after a short wait and thank you for your patience. (2064) (529)"},"request_id":"062e76f8f25cd919caa3af4baaa49203"}', ); const result = await handler.handleRetryableError(msg); @@ -192,11 +228,18 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => { // Should enter the backoff loop (rate_limit path, not quota_exhausted) assert.equal(result, true); - const retryStart = emittedEvents.find((e) => e.type === "auto_retry_start"); - assert.ok(retryStart, "529 overloaded_error should enter backoff retry as rate_limit"); + const retryStart = emittedEvents.find( + (e) => e.type === "auto_retry_start", + ); + assert.ok( + retryStart, + "529 overloaded_error should enter backoff retry as rate_limit", + ); // Must NOT be treated as quota_exhausted (would emit fallback_chain_exhausted) - const chainExhausted = emittedEvents.find((e) => e.type === "fallback_chain_exhausted"); + const chainExhausted = emittedEvents.find( + (e) => e.type === "fallback_chain_exhausted", + ); assert.equal( chainExhausted, undefined, @@ -218,27 +261,40 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => { const result = await handler.handleRetryableError(msg); - assert.equal(result, true, "affordability error should trigger credit-aware retry"); - const retryStart = emittedEvents.find((e) => e.type === "auto_retry_start"); - assert.ok(retryStart, "Expected immediate retry after reducing max tokens"); + assert.equal( + result, + true, + "affordability error should trigger credit-aware retry", + ); + const retryStart = emittedEvents.find( + (e) => e.type === "auto_retry_start", + ); + assert.ok( + retryStart, + "Expected immediate retry after reducing max tokens", + ); }); }); describe("long-context model downgrade", () => { it("downgrades from [1m] to base model when entitlement error and no fallback", async () => { const baseModel = createMockModel("anthropic", "claude-opus-4-6"); - const { deps, emittedEvents, onModelChangeFn, continueFn } = createMockDeps({ - model: createMockModel("anthropic", "claude-opus-4-6[1m]"), - markUsageLimitReachedResult: false, - fallbackResult: null, - findModelResult: (provider: string, modelId: string) => { - if (provider === "anthropic" && modelId === "claude-opus-4-6") return baseModel; - return undefined; - }, - }); + const { deps, emittedEvents, onModelChangeFn, continueFn } = + createMockDeps({ + model: createMockModel("anthropic", "claude-opus-4-6[1m]"), + markUsageLimitReachedResult: false, + fallbackResult: null, + findModelResult: (provider: string, modelId: string) => { + if (provider === "anthropic" && modelId === "claude-opus-4-6") + return baseModel; + return undefined; + }, + }); const handler = new RetryHandler(deps); - const msg = errorMessage("Extra usage is required for long context requests."); + const msg = errorMessage( + "Extra usage is required for long context requests.", + ); const result = await handler.handleRetryableError(msg); @@ -253,9 +309,17 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => { assert.equal(onModelChangeFn.mock.calls.length, 1); // Should emit a fallback_provider_switch event indicating downgrade - const switchEvent = emittedEvents.find((e) => e.type === "fallback_provider_switch"); - assert.ok(switchEvent, "Expected fallback_provider_switch event for downgrade"); - assert.ok(switchEvent!.reason.includes("long context downgrade"), `reason should mention downgrade: ${switchEvent!.reason}`); + const switchEvent = emittedEvents.find( + (e) => e.type === "fallback_provider_switch", + ); + assert.ok( + switchEvent, + "Expected fallback_provider_switch event for downgrade", + ); + assert.ok( + switchEvent!.reason.includes("long context downgrade"), + `reason should mention downgrade: ${switchEvent!.reason}`, + ); }); it("emits fallback_chain_exhausted when base model is also unavailable", async () => { @@ -267,13 +331,20 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => { }); const handler = new RetryHandler(deps); - const msg = errorMessage("Extra usage is required for long context requests."); + const msg = errorMessage( + "Extra usage is required for long context requests.", + ); const result = await handler.handleRetryableError(msg); assert.equal(result, false); - const chainExhausted = emittedEvents.find((e) => e.type === "fallback_chain_exhausted"); - assert.ok(chainExhausted, "Expected fallback_chain_exhausted when base model unavailable"); + const chainExhausted = emittedEvents.find( + (e) => e.type === "fallback_chain_exhausted", + ); + assert.ok( + chainExhausted, + "Expected fallback_chain_exhausted when base model unavailable", + ); }); it("does not attempt downgrade for non-[1m] models", async () => { @@ -286,17 +357,27 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => { }); const handler = new RetryHandler(deps); - const msg = errorMessage("Extra usage is required for long context requests."); + const msg = errorMessage( + "Extra usage is required for long context requests.", + ); const result = await handler.handleRetryableError(msg); assert.equal(result, false); - const chainExhausted = emittedEvents.find((e) => e.type === "fallback_chain_exhausted"); + const chainExhausted = emittedEvents.find( + (e) => e.type === "fallback_chain_exhausted", + ); assert.ok(chainExhausted); // No downgrade switch should occur - const switchEvent = emittedEvents.find((e) => e.type === "fallback_provider_switch"); - assert.equal(switchEvent, undefined, "Should not switch for non-[1m] models"); + const switchEvent = emittedEvents.find( + (e) => e.type === "fallback_provider_switch", + ); + assert.equal( + switchEvent, + undefined, + "Should not switch for non-[1m] models", + ); }); }); @@ -315,9 +396,19 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => { handler.abortRetry(); await new Promise((resolve) => setTimeout(resolve, 10)); - assert.equal(continueFn.mock.calls.length, 0, "cancelled retry must not continue after explicit abort"); - const endEvents = emittedEvents.filter((e) => e.type === "auto_retry_end"); - assert.equal(endEvents.length, 1, "retry cancellation should emit a single auto_retry_end event"); + assert.equal( + continueFn.mock.calls.length, + 0, + "cancelled retry must not continue after explicit abort", + ); + const endEvents = emittedEvents.filter( + (e) => e.type === "auto_retry_end", + ); + assert.equal( + endEvents.length, + 1, + "retry cancellation should emit a single auto_retry_end event", + ); assert.equal(endEvents[0]?.finalError, "Retry cancelled"); }); }); @@ -346,10 +437,20 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => { const downgraded = setModelCalls[0][0] as Model; assert.equal(downgraded.provider, "openrouter"); assert.equal(downgraded.id, "openai/gpt-5-pro"); - assert.equal(downgraded.maxTokens, 297, "expected affordability cap with safety buffer"); + assert.equal( + downgraded.maxTokens, + 297, + "expected affordability cap with safety buffer", + ); - assert.equal(onModelChangeFn.mock.calls.length, 1, "should notify about model update"); - const switchEvent = emittedEvents.find((e) => e.type === "fallback_provider_switch"); + assert.equal( + onModelChangeFn.mock.calls.length, + 1, + "should notify about model update", + ); + const switchEvent = emittedEvents.find( + (e) => e.type === "fallback_provider_switch", + ); assert.ok(switchEvent, "should emit model-adjustment event"); assert.ok( String(switchEvent?.reason || "").includes("credit-aware retry"), @@ -373,7 +474,11 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => { ); await handler.handleRetryableError(msg); - assert.equal(markUsageLimitReached.mock.calls.length, 0, "quota error should skip credential cooldown"); + assert.equal( + markUsageLimitReached.mock.calls.length, + 0, + "quota error should skip credential cooldown", + ); }); }); @@ -381,7 +486,9 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => { it("considers long-context entitlement error as retryable", () => { const { deps } = createMockDeps(); const handler = new RetryHandler(deps); - const msg = errorMessage("Extra usage is required for long context requests."); + const msg = errorMessage( + "Extra usage is required for long context requests.", + ); assert.equal(handler.isRetryableError(msg), true); }); @@ -393,7 +500,7 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => { const handler = new RetryHandler(deps); const msg = errorMessage( 'All credentials for "anthropic" are in a cooldown window. ' + - 'Please wait a moment and try again, or switch to a different provider.', + "Please wait a moment and try again, or switch to a different provider.", ); assert.equal(handler.isRetryableError(msg), false); }); @@ -414,7 +521,8 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => { const { deps, emittedEvents, onModelChangeFn } = createMockDeps({ model: createMockModel("anthropic", "claude-opus-4-6"), findModelResult: (provider: string, modelId: string) => { - if (provider === "claude-code" && modelId === "claude-opus-4-6") return ccModel; + if (provider === "claude-code" && modelId === "claude-opus-4-6") + return ccModel; return undefined; }, }); @@ -426,9 +534,14 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => { const result = await handler.handleRetryableError(msg); assert.equal(result, true, "should retry via claude-code fallback"); - const switchEvent = emittedEvents.find((e) => e.type === "fallback_provider_switch"); + const switchEvent = emittedEvents.find( + (e) => e.type === "fallback_provider_switch", + ); assert.ok(switchEvent, "Expected fallback_provider_switch event"); - assert.ok(switchEvent!.to.startsWith("claude-code/"), "Should switch to claude-code provider"); + assert.ok( + switchEvent!.to.startsWith("claude-code/"), + "Should switch to claude-code provider", + ); }); it("switches to claude-code on 'out of extra usage' error (#3772)", async () => { @@ -436,21 +549,29 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => { const { deps, emittedEvents } = createMockDeps({ model: createMockModel("anthropic", "claude-opus-4-6"), findModelResult: (provider: string, modelId: string) => { - if (provider === "claude-code" && modelId === "claude-opus-4-6") return ccModel; + if (provider === "claude-code" && modelId === "claude-opus-4-6") + return ccModel; return undefined; }, }); deps.isClaudeCodeReady = () => true; const handler = new RetryHandler(deps); - const msg = errorMessage("You're out of extra usage. Add more at claude.ai/settings/usage and keep going."); + const msg = errorMessage( + "You're out of extra usage. Add more at claude.ai/settings/usage and keep going.", + ); const result = await handler.handleRetryableError(msg); assert.equal(result, true, "should retry via claude-code fallback"); - const switchEvent = emittedEvents.find((e) => e.type === "fallback_provider_switch"); + const switchEvent = emittedEvents.find( + (e) => e.type === "fallback_provider_switch", + ); assert.ok(switchEvent, "Expected fallback_provider_switch event"); - assert.ok(switchEvent!.to.startsWith("claude-code/"), "Should switch to claude-code provider"); + assert.ok( + switchEvent!.to.startsWith("claude-code/"), + "Should switch to claude-code provider", + ); }); it("does NOT switch to claude-code when current provider is not anthropic", async () => { @@ -458,22 +579,31 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => { const { deps, emittedEvents } = createMockDeps({ model: createMockModel("openai", "gpt-4o"), findModelResult: (provider: string, modelId: string) => { - if (provider === "claude-code" && modelId === "gpt-4o") return ccModel; + if (provider === "claude-code" && modelId === "gpt-4o") + return ccModel; return undefined; }, }); deps.isClaudeCodeReady = () => true; const handler = new RetryHandler(deps); - const msg = errorMessage("third-party apps are not supported for this plan"); + const msg = errorMessage( + "third-party apps are not supported for this plan", + ); const result = await handler.handleRetryableError(msg); // Should NOT have triggered the claude-code fallback const switchEvent = emittedEvents.find( - (e) => e.type === "fallback_provider_switch" && e.to?.startsWith("claude-code/"), + (e) => + e.type === "fallback_provider_switch" && + e.to?.startsWith("claude-code/"), + ); + assert.equal( + switchEvent, + undefined, + "Should NOT switch non-anthropic provider to claude-code", ); - assert.equal(switchEvent, undefined, "Should NOT switch non-anthropic provider to claude-code"); }); }); @@ -522,16 +652,60 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => { ); }); + it("does NOT write credential cooldown for external CLI rate_limit errors", async () => { + const fallbackModel = createMockModel("openai", "gpt-4o"); + const { deps, emittedEvents, markUsageLimitReached, findFallback } = + createMockDeps({ + model: createMockModel("google-gemini-cli", "gemini-2.5-pro"), + providerAuthMode: "externalCli", + markUsageLimitReachedResult: true, + fallbackResult: { + model: fallbackModel, + reason: "cross-provider fallback", + }, + }); + + const handler = new RetryHandler(deps); + const msg = errorMessage("429 Too Many Requests"); + + const result = await handler.handleRetryableError(msg); + + assert.equal( + result, + true, + "external CLI 429 should still enter fallback/retry handling", + ); + assert.equal( + markUsageLimitReached.mock.calls.length, + 0, + "external CLI providers must not be written into .sf credential cooldown", + ); + assert.equal( + findFallback.mock.calls.length, + 1, + "429 should still ask for provider fallback", + ); + assert.ok( + emittedEvents.some((event) => event.type === "auto_retry_start"), + "429 should still emit retry start after fallback selection", + ); + }); + it("still tries cross-provider fallback for quota_exhausted without credential backoff", async () => { const fallbackModel = createMockModel("openai", "gpt-4o"); const { deps, markUsageLimitReached, continueFn } = createMockDeps({ model: createMockModel("anthropic", "claude-opus-4-6[1m]"), markUsageLimitReachedResult: false, - fallbackResult: { model: fallbackModel, reason: "cross-provider fallback" }, + fallbackResult: { + model: fallbackModel, + reason: "cross-provider fallback", + }, }); const handler = new RetryHandler(deps); - const msg = errorMessage("Extra usage is required for long context requests."); + const msg = errorMessage( + "Extra usage is required for long context requests.", + ); const result = await handler.handleRetryableError(msg); diff --git a/packages/pi-coding-agent/src/core/retry-handler.ts b/packages/pi-coding-agent/src/core/retry-handler.ts index 4f61902f9..f7bb89640 100644 --- a/packages/pi-coding-agent/src/core/retry-handler.ts +++ b/packages/pi-coding-agent/src/core/retry-handler.ts @@ -12,12 +12,12 @@ import type { Agent } from "@singularity-forge/pi-agent-core"; import type { AssistantMessage, Model } from "@singularity-forge/pi-ai"; import { isContextOverflow } from "@singularity-forge/pi-ai"; +import { sleep } from "../utils/sleep.js"; +import type { AgentSessionEvent } from "./agent-session.js"; import type { UsageLimitErrorType } from "./auth-storage.js"; import type { FallbackResolver } from "./fallback-resolver.js"; import type { ModelRegistry } from "./model-registry.js"; import type { SettingsManager } from "./settings-manager.js"; -import { sleep } from "../utils/sleep.js"; -import type { AgentSessionEvent } from "./agent-session.js"; /** Dependencies injected from AgentSession into RetryHandler */ export interface RetryHandlerDeps { @@ -41,7 +41,8 @@ export class RetryHandler { private _retryPromise: Promise | undefined = undefined; private _retryResolve: (() => void) | undefined = undefined; private _retryGeneration = 0; - private _continueTimeout: ReturnType | undefined = undefined; + private _continueTimeout: ReturnType | undefined = + undefined; constructor(private readonly _deps: RetryHandlerDeps) {} @@ -70,7 +71,9 @@ export class RetryHandler { * Must be called synchronously from the agent event handler before * any async processing, so that waitForRetry() doesn't miss in-flight retries. */ - createRetryPromiseForAgentEnd(messages: Array<{ role: string } & Record>): void { + createRetryPromiseForAgentEnd( + messages: Array<{ role: string } & Record>, + ): void { if (this._retryPromise) return; const settings = this._deps.settingsManager.getRetrySettings(); @@ -162,7 +165,10 @@ export class RetryHandler { // when provider reports "can only afford N", lower maxTokens and retry // on the same model before rotating credentials/providers. if (isQuotaError) { - const adjusted = this._tryAffordableMaxTokensRetry(message, retryGeneration); + const adjusted = this._tryAffordableMaxTokensRetry( + message, + retryGeneration, + ); if (adjusted) return true; } @@ -171,12 +177,16 @@ export class RetryHandler { // gates; rotating to another credential on the same account won't help // and the 30-minute backoff blocks all provider requests needlessly. if (isRateLimit) { + const provider = this._deps.getModel()!.provider; + const authMode = this._deps.modelRegistry.getProviderAuthMode(provider); const hasAlternate = - this._deps.modelRegistry.authStorage.markUsageLimitReached( - this._deps.getModel()!.provider, - this._deps.getSessionId(), - { errorType }, - ); + authMode === "externalCli" || authMode === "none" + ? false + : this._deps.modelRegistry.authStorage.markUsageLimitReached( + provider, + this._deps.getSessionId(), + { errorType }, + ); if (hasAlternate) { this._removeLastAssistantError(); @@ -235,7 +245,10 @@ export class RetryHandler { // No fallback available either if (isQuotaError) { // Try long-context model downgrade ([1m] → base) before giving up - const downgraded = this._tryLongContextDowngrade(message, retryGeneration); + const downgraded = this._tryLongContextDowngrade( + message, + retryGeneration, + ); if (downgraded) return true; this._deps.emit({ @@ -271,7 +284,8 @@ export class RetryHandler { // Use server-requested delay when available, capped by maxDelayMs. // Fall back to exponential backoff when no server hint is present. - const exponentialDelayMs = settings.baseDelayMs * 2 ** (this._retryAttempt - 1); + const exponentialDelayMs = + settings.baseDelayMs * 2 ** (this._retryAttempt - 1); let delayMs: number; if (message.retryAfterMs !== undefined) { const cap = settings.maxDelayMs > 0 ? settings.maxDelayMs : Infinity; @@ -280,7 +294,8 @@ export class RetryHandler { type: "auto_retry_end", success: false, attempt: this._retryAttempt - 1, - finalError: `Rate limit reset in ${Math.ceil(message.retryAfterMs / 1000)}s (max: ${Math.ceil(cap / 1000)}s). ${message.errorMessage || ""}`.trim(), + finalError: + `Rate limit reset in ${Math.ceil(message.retryAfterMs / 1000)}s (max: ${Math.ceil(cap / 1000)}s). ${message.errorMessage || ""}`.trim(), }); this._retryAttempt = 0; this._resolveRetry(); @@ -335,9 +350,9 @@ export class RetryHandler { /** Cancel in-progress retry */ abortRetry(): void { const hadRetry = - this._retryPromise !== undefined - || this._retryAbortController !== undefined - || this._continueTimeout !== undefined; + this._retryPromise !== undefined || + this._retryAbortController !== undefined || + this._continueTimeout !== undefined; if (!hadRetry) return; const attempt = this._retryAttempt > 0 ? this._retryAttempt : 1; @@ -417,13 +432,30 @@ export class RetryHandler { const err = errorMessage.toLowerCase(); // Long-context entitlement errors are billing gates, not transient rate limits. // Must be checked before the generic 429/rate_limit regex. - if (/extra usage is required|long context required/i.test(err)) return "quota_exhausted"; - if (/requires more credits|can only afford|insufficient credits|not enough credits|credit balance/i.test(err)) + if (/extra usage is required|long context required/i.test(err)) return "quota_exhausted"; - if (/quota|billing|exceeded.*limit|usage.*limit/i.test(err)) return "quota_exhausted"; - if (/rate.?limit|too many requests|429|529|overloaded/i.test(err)) return "rate_limit"; - if (/500|502|503|504|server.?error|internal.?error|service.?unavailable/i.test(err)) return "server_error"; - if (/401|authentication.*error|invalid.*api.?key|api.?key.*invalid|api.?key.*expired|failed to authenticate|unauthorized/i.test(err)) return "auth_error"; + if ( + /requires more credits|can only afford|insufficient credits|not enough credits|credit balance/i.test( + err, + ) + ) + return "quota_exhausted"; + if (/quota|billing|exceeded.*limit|usage.*limit/i.test(err)) + return "quota_exhausted"; + if (/rate.?limit|too many requests|429|529|overloaded/i.test(err)) + return "rate_limit"; + if ( + /500|502|503|504|server.?error|internal.?error|service.?unavailable/i.test( + err, + ) + ) + return "server_error"; + if ( + /401|authentication.*error|invalid.*api.?key|api.?key.*invalid|api.?key.*expired|failed to authenticate|unauthorized/i.test( + err, + ) + ) + return "auth_error"; return "unknown"; } @@ -431,7 +463,10 @@ export class RetryHandler { * Attempt a same-model retry by reducing maxTokens when provider reports * an affordability cap (e.g., "can only afford 329"). */ - private _tryAffordableMaxTokensRetry(message: AssistantMessage, retryGeneration: number): boolean { + private _tryAffordableMaxTokensRetry( + message: AssistantMessage, + retryGeneration: number, + ): boolean { const currentModel = this._deps.getModel(); if (!currentModel || !message.errorMessage) return false; @@ -443,9 +478,15 @@ export class RetryHandler { if (!Number.isFinite(affordable) || affordable <= 0) return false; // Leave a small buffer so slight input variance doesn't immediately re-fail. - const safetyBuffer = Math.min(64, Math.max(16, Math.floor(affordable * 0.1))); + const safetyBuffer = Math.min( + 64, + Math.max(16, Math.floor(affordable * 0.1)), + ); const targetMaxTokens = Math.max(64, affordable - safetyBuffer); - const downgradedMaxTokens = Math.min(currentModel.maxTokens, targetMaxTokens); + const downgradedMaxTokens = Math.min( + currentModel.maxTokens, + targetMaxTokens, + ); if (downgradedMaxTokens >= currentModel.maxTokens) return false; const downgradedModel = { @@ -481,7 +522,10 @@ export class RetryHandler { * base model (claude-opus-4-6) when the account lacks the long-context billing * entitlement. Returns true if the downgrade was initiated. */ - private _tryLongContextDowngrade(message: AssistantMessage, retryGeneration: number): boolean { + private _tryLongContextDowngrade( + message: AssistantMessage, + retryGeneration: number, + ): boolean { const currentModel = this._deps.getModel(); if (!currentModel) return false; @@ -490,7 +534,10 @@ export class RetryHandler { if (!match) return false; const baseModelId = match[1]; - const baseModel = this._deps.modelRegistry.find(currentModel.provider, baseModelId); + const baseModel = this._deps.modelRegistry.find( + currentModel.provider, + baseModelId, + ); if (!baseModel) return false; const previousId = currentModel.id; @@ -525,7 +572,9 @@ export class RetryHandler { * and the "out of extra usage" variant that subscription users receive. */ private _isThirdPartyBlock(errorMessage: string): boolean { - return /third[- .]party.*(?:draw from extra|not.*available|plan limits|not permitted|cannot be used|not supported)|(?:out of|no) extra usage/i.test(errorMessage); + return /third[- .]party.*(?:draw from extra|not.*available|plan limits|not permitted|cannot be used|not supported)|(?:out of|no) extra usage/i.test( + errorMessage, + ); } /** @@ -533,7 +582,10 @@ export class RetryHandler { * Anthropic provider is blocked by the third-party policy (#3772). * Returns true if the switch was made and retry scheduled. */ - private _tryClaudeCodeFallback(message: AssistantMessage, retryGeneration: number): boolean { + private _tryClaudeCodeFallback( + message: AssistantMessage, + retryGeneration: number, + ): boolean { if (!this._deps.isClaudeCodeReady?.()) return false; const currentModel = this._deps.getModel(); @@ -544,7 +596,10 @@ export class RetryHandler { if (currentModel.provider !== "anthropic") return false; // Find the same model ID under the claude-code provider - const ccModel = this._deps.modelRegistry.find("claude-code", currentModel.id); + const ccModel = this._deps.modelRegistry.find( + "claude-code", + currentModel.id, + ); if (!ccModel) return false; const previousProvider = currentModel.provider; @@ -556,7 +611,8 @@ export class RetryHandler { type: "fallback_provider_switch", from: `${previousProvider}/${currentModel.id}`, to: `claude-code/${ccModel.id}`, - reason: "Anthropic subscription blocked for third-party apps — routing through Claude Code CLI", + reason: + "Anthropic subscription blocked for third-party apps — routing through Claude Code CLI", }); this._deps.emit({ @@ -574,7 +630,10 @@ export class RetryHandler { /** Remove the last assistant error message from agent state */ private _removeLastAssistantError(): void { const messages = this._deps.agent.state.messages; - if (messages.length > 0 && messages[messages.length - 1].role === "assistant") { + if ( + messages.length > 0 && + messages[messages.length - 1].role === "assistant" + ) { this._deps.agent.replaceMessages(messages.slice(0, -1)); } }