fix(auth): use gemini cli credentials outside sf store

This commit is contained in:
Mikael Hugo 2026-05-02 13:08:41 +02:00
parent cb2ab66d4f
commit 3c3000c25f
4 changed files with 959 additions and 308 deletions

View file

@ -1,11 +1,19 @@
import assert from "node:assert/strict";
import { describe, it } from 'vitest';
import type { Api, Model, SimpleStreamOptions, Context, AssistantMessageEventStream } from "@singularity-forge/pi-ai";
import type {
Api,
AssistantMessageEventStream,
Context,
Model,
SimpleStreamOptions,
} from "@singularity-forge/pi-ai";
import { getApiProvider } from "@singularity-forge/pi-ai";
import { describe, it } from "vitest";
import { AuthStorage, type AuthStorageData } from "./auth-storage.js";
import { ModelRegistry } from "./model-registry.js";
function createRegistry(hasAuthFn?: (provider: string) => boolean): ModelRegistry {
function createRegistry(
hasAuthFn?: (provider: string) => boolean,
): ModelRegistry {
const authStorage = {
setFallbackResolver: () => {},
onCredentialChange: () => {},
@ -22,7 +30,12 @@ function createInMemoryRegistry(data: AuthStorageData = {}): ModelRegistry {
return new ModelRegistry(AuthStorage.inMemory(data), undefined);
}
function createProviderModel(id: string, api?: string): NonNullable<Parameters<ModelRegistry["registerProvider"]>[1]["models"]>[number] {
function createProviderModel(
id: string,
api?: string,
): NonNullable<
Parameters<ModelRegistry["registerProvider"]>[1]["models"]
>[number] {
return {
id,
name: id,
@ -35,8 +48,14 @@ function createProviderModel(id: string, api?: string): NonNullable<Parameters<M
};
}
function findModel(registry: ModelRegistry, provider: string, id: string): Model<Api> | undefined {
return registry.getAvailable().find((m) => m.provider === provider && m.id === id);
function findModel(
registry: ModelRegistry,
provider: string,
id: string,
): Model<Api> | undefined {
return registry
.getAvailable()
.find((m) => m.provider === provider && m.id === id);
}
function makeModel(provider: string, id: string, api: string): Model<Api> {
@ -62,10 +81,33 @@ function makeContext(): Context {
}
/** No-op streamSimple for tests that need one to pass validation but don't inspect it. */
const noopStreamSimple = (_model: Model<Api>, _context: Context, _options?: SimpleStreamOptions) => {
const noopStreamSimple = (
_model: Model<Api>,
_context: Context,
_options?: SimpleStreamOptions,
) => {
return {
[Symbol.asyncIterator]() { return { next: async () => ({ value: undefined, done: true as const }) }; },
result: () => Promise.resolve({ role: "assistant" as const, content: [], api: "test" as Api, provider: "test", model: "test", usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } }, stopReason: "stop" as const, timestamp: Date.now() }),
[Symbol.asyncIterator]() {
return { next: async () => ({ value: undefined, done: true as const }) };
},
result: () =>
Promise.resolve({
role: "assistant" as const,
content: [],
api: "test" as Api,
provider: "test",
model: "test",
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop" as const,
timestamp: Date.now(),
}),
push: () => {},
end: () => {},
} as unknown as AssistantMessageEventStream;
@ -73,16 +115,51 @@ const noopStreamSimple = (_model: Model<Api>, _context: Context, _options?: Simp
/** Create a spy streamSimple that captures the options it receives and returns a stub stream. */
function createStreamSpy(): {
streamSimple: (model: Model<Api>, context: Context, options?: SimpleStreamOptions) => AssistantMessageEventStream;
streamSimple: (
model: Model<Api>,
context: Context,
options?: SimpleStreamOptions,
) => AssistantMessageEventStream;
getCapturedOptions: () => SimpleStreamOptions | undefined;
} {
let capturedOptions: SimpleStreamOptions | undefined;
const streamSimple = (_model: Model<Api>, _context: Context, options?: SimpleStreamOptions) => {
const streamSimple = (
_model: Model<Api>,
_context: Context,
options?: SimpleStreamOptions,
) => {
capturedOptions = options;
// Return a minimal stub that satisfies AssistantMessageEventStream
return {
[Symbol.asyncIterator]() { return { next: async () => ({ value: undefined, done: true as const }) }; },
result: () => Promise.resolve({ role: "assistant" as const, content: [], api: "test" as Api, provider: "test", model: "test", usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } }, stopReason: "stop" as const, timestamp: Date.now() }),
[Symbol.asyncIterator]() {
return {
next: async () => ({ value: undefined, done: true as const }),
};
},
result: () =>
Promise.resolve({
role: "assistant" as const,
content: [],
api: "test" as Api,
provider: "test",
model: "test",
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
total: 0,
},
},
stopReason: "stop" as const,
timestamp: Date.now(),
}),
push: () => {},
end: () => {},
} as unknown as AssistantMessageEventStream;
@ -123,102 +200,153 @@ describe("ModelRegistry authMode — registration", () => {
it("rejects apiKey provider without apiKey or oauth — message mentions authMode", () => {
const registry = createRegistry();
assert.throws(() => {
registry.registerProvider("apikey-provider", {
authMode: "apiKey",
baseUrl: "https://api.local",
api: "openai-completions",
models: [createProviderModel("model")],
});
}, (err: Error) => {
assert.ok(err.message.includes("authMode"), "error message must mention authMode");
assert.ok(err.message.includes("externalCli"), "error message must suggest externalCli");
return true;
});
assert.throws(
() => {
registry.registerProvider("apikey-provider", {
authMode: "apiKey",
baseUrl: "https://api.local",
api: "openai-completions",
models: [createProviderModel("model")],
});
},
(err: Error) => {
assert.ok(
err.message.includes("authMode"),
"error message must mention authMode",
);
assert.ok(
err.message.includes("externalCli"),
"error message must suggest externalCli",
);
return true;
},
);
});
it("rejects provider with no authMode and no apiKey/oauth (defaults to apiKey)", () => {
const registry = createRegistry();
assert.throws(() => {
registry.registerProvider("bare-provider", {
baseUrl: "https://api.local",
api: "openai-completions",
models: [createProviderModel("model")],
});
}, (err: Error) => {
assert.ok(err.message.includes("authMode"), "error message must mention authMode");
return true;
});
assert.throws(
() => {
registry.registerProvider("bare-provider", {
baseUrl: "https://api.local",
api: "openai-completions",
models: [createProviderModel("model")],
});
},
(err: Error) => {
assert.ok(
err.message.includes("authMode"),
"error message must mention authMode",
);
return true;
},
);
});
it("rejects externalCli provider without streamSimple", () => {
const registry = createRegistry();
assert.throws(() => {
registry.registerProvider("cli-no-stream", {
authMode: "externalCli",
baseUrl: "https://cli.local",
api: "openai-completions",
models: [createProviderModel("model")],
});
}, (err: Error) => {
assert.ok(err.message.includes("streamSimple"), "error message must mention streamSimple");
assert.ok(err.message.includes("externalCli"), "error message must mention authMode");
return true;
});
assert.throws(
() => {
registry.registerProvider("cli-no-stream", {
authMode: "externalCli",
baseUrl: "https://cli.local",
api: "openai-completions",
models: [createProviderModel("model")],
});
},
(err: Error) => {
assert.ok(
err.message.includes("streamSimple"),
"error message must mention streamSimple",
);
assert.ok(
err.message.includes("externalCli"),
"error message must mention authMode",
);
return true;
},
);
});
it("rejects none provider without streamSimple", () => {
const registry = createRegistry();
assert.throws(() => {
registry.registerProvider("none-no-stream", {
authMode: "none",
baseUrl: "http://localhost:11434",
api: "openai-completions",
models: [createProviderModel("model")],
});
}, (err: Error) => {
assert.ok(err.message.includes("streamSimple"), "error message must mention streamSimple");
assert.ok(err.message.includes("none"), "error message must mention authMode");
return true;
});
assert.throws(
() => {
registry.registerProvider("none-no-stream", {
authMode: "none",
baseUrl: "http://localhost:11434",
api: "openai-completions",
models: [createProviderModel("model")],
});
},
(err: Error) => {
assert.ok(
err.message.includes("streamSimple"),
"error message must mention streamSimple",
);
assert.ok(
err.message.includes("none"),
"error message must mention authMode",
);
return true;
},
);
});
it("rejects externalCli provider that also sets apiKey", () => {
const registry = createRegistry();
const spy = createStreamSpy();
assert.throws(() => {
registry.registerProvider("cli-with-key", {
authMode: "externalCli",
baseUrl: "https://cli.local",
api: "openai-completions",
apiKey: "SHOULD_NOT_EXIST",
streamSimple: spy.streamSimple,
models: [createProviderModel("model")],
});
}, (err: Error) => {
assert.ok(err.message.includes("apiKey"), "error message must mention apiKey");
assert.ok(err.message.includes("externalCli"), "error message must mention authMode");
return true;
});
assert.throws(
() => {
registry.registerProvider("cli-with-key", {
authMode: "externalCli",
baseUrl: "https://cli.local",
api: "openai-completions",
apiKey: "SHOULD_NOT_EXIST",
streamSimple: spy.streamSimple,
models: [createProviderModel("model")],
});
},
(err: Error) => {
assert.ok(
err.message.includes("apiKey"),
"error message must mention apiKey",
);
assert.ok(
err.message.includes("externalCli"),
"error message must mention authMode",
);
return true;
},
);
});
it("rejects none provider that also sets apiKey", () => {
const registry = createRegistry();
const spy = createStreamSpy();
assert.throws(() => {
registry.registerProvider("none-with-key", {
authMode: "none",
baseUrl: "http://localhost:11434",
api: "openai-completions",
apiKey: "SHOULD_NOT_EXIST",
streamSimple: spy.streamSimple,
models: [createProviderModel("model")],
});
}, (err: Error) => {
assert.ok(err.message.includes("apiKey"), "error message must mention apiKey");
assert.ok(err.message.includes("none"), "error message must mention authMode");
return true;
});
assert.throws(
() => {
registry.registerProvider("none-with-key", {
authMode: "none",
baseUrl: "http://localhost:11434",
api: "openai-completions",
apiKey: "SHOULD_NOT_EXIST",
streamSimple: spy.streamSimple,
models: [createProviderModel("model")],
});
},
(err: Error) => {
assert.ok(
err.message.includes("apiKey"),
"error message must mention apiKey",
);
assert.ok(
err.message.includes("none"),
"error message must mention authMode",
);
return true;
},
);
});
});
@ -230,6 +358,14 @@ describe("ModelRegistry authMode — getProviderAuthMode", () => {
assert.equal(registry.getProviderAuthMode("anthropic"), "apiKey");
});
it("treats google-gemini-cli as external CLI auth", () => {
const registry = createRegistry();
assert.equal(
registry.getProviderAuthMode("google-gemini-cli"),
"externalCli",
);
});
it("returns explicit authMode when set", () => {
const registry = createRegistry();
registry.registerProvider("cli", {
@ -258,6 +394,11 @@ describe("ModelRegistry authMode — getProviderAuthMode", () => {
// ─── isProviderRequestReady ───────────────────────────────────────────────────
describe("ModelRegistry authMode — isProviderRequestReady", () => {
it("returns true for google-gemini-cli without .sf stored auth", () => {
const registry = createRegistry(() => false);
assert.equal(registry.isProviderRequestReady("google-gemini-cli"), true);
});
it("returns true for externalCli without stored auth", () => {
const registry = createRegistry(() => false);
registry.registerProvider("cli", {
@ -391,7 +532,10 @@ describe("ModelRegistry authMode — getAvailable", () => {
it("excludes apiKey models without stored auth", () => {
const registry = createRegistry(() => false);
const available = registry.getAvailable();
assert.equal(available.length, 0);
assert.equal(
available.filter((m) => m.provider !== "google-gemini-cli").length,
0,
);
});
it("prunes Codex models removed from ChatGPT-backed openai-codex OAuth", () => {
@ -407,7 +551,10 @@ describe("ModelRegistry authMode — getAvailable", () => {
assert.equal(registry.find("openai-codex", "gpt-5.1-codex-max"), undefined);
assert.equal(registry.find("openai-codex", "gpt-5.1"), undefined);
assert.equal(findModel(registry, "openai-codex", "gpt-5.2-codex"), undefined);
assert.equal(
findModel(registry, "openai-codex", "gpt-5.2-codex"),
undefined,
);
assert.ok(registry.find("openai-codex", "gpt-5.4"));
assert.ok(findModel(registry, "openai-codex", "gpt-5.4"));
});
@ -428,6 +575,22 @@ describe("ModelRegistry authMode — getAvailable", () => {
// ─── getApiKey ────────────────────────────────────────────────────────────────
describe("ModelRegistry authMode — getApiKey", () => {
it("returns undefined for google-gemini-cli even when stale .sf auth exists", async () => {
const registry = createInMemoryRegistry({
"google-gemini-cli": {
type: "oauth",
access: "",
refresh: "",
expires: 0,
},
});
assert.equal(
await registry.getApiKeyForProvider("google-gemini-cli"),
undefined,
);
});
it("returns undefined for externalCli provider", async () => {
const registry = createRegistry();
registry.registerProvider("cli", {
@ -480,15 +643,18 @@ describe("ModelRegistry authMode — streamSimple apiKey boundary", () => {
const provider = getApiProvider(apiType as Api);
assert.ok(provider, "provider must be registered in api registry");
provider.streamSimple(
makeModel("cli-strip", "m", apiType),
makeContext(),
{ apiKey: "should-be-stripped", maxTokens: 1024 } as SimpleStreamOptions,
);
provider.streamSimple(makeModel("cli-strip", "m", apiType), makeContext(), {
apiKey: "should-be-stripped",
maxTokens: 1024,
} as SimpleStreamOptions);
const captured = spy.getCapturedOptions();
assert.ok(captured, "streamSimple must have been called");
assert.equal("apiKey" in captured, false, "apiKey must not exist in options for externalCli provider");
assert.equal(
"apiKey" in captured,
false,
"apiKey must not exist in options for externalCli provider",
);
assert.equal(captured.maxTokens, 1024, "other options must pass through");
});
@ -516,7 +682,11 @@ describe("ModelRegistry authMode — streamSimple apiKey boundary", () => {
const captured = spy.getCapturedOptions();
assert.ok(captured, "streamSimple must have been called");
assert.equal("apiKey" in captured, false, "apiKey must not exist in options for none provider");
assert.equal(
"apiKey" in captured,
false,
"apiKey must not exist in options for none provider",
);
assert.equal(captured.maxTokens, 2048, "other options must pass through");
});
@ -544,7 +714,11 @@ describe("ModelRegistry authMode — streamSimple apiKey boundary", () => {
const captured = spy.getCapturedOptions();
assert.ok(captured, "streamSimple must have been called");
assert.equal(captured.apiKey, "sk-real-key", "apiKey must be preserved for apiKey provider");
assert.equal(
captured.apiKey,
"sk-real-key",
"apiKey must be preserved for apiKey provider",
);
assert.equal(captured.maxTokens, 4096, "other options must pass through");
});
@ -572,7 +746,11 @@ describe("ModelRegistry authMode — streamSimple apiKey boundary", () => {
const captured = spy.getCapturedOptions();
assert.ok(captured !== undefined, "streamSimple must have been called");
assert.equal("apiKey" in captured, false, "apiKey must not exist even when options is undefined");
assert.equal(
"apiKey" in captured,
false,
"apiKey must not exist even when options is undefined",
);
});
it("strips apiKey but preserves signal and other fields for externalCli", () => {
@ -595,15 +773,28 @@ describe("ModelRegistry authMode — streamSimple apiKey boundary", () => {
provider.streamSimple(
makeModel("cli-fields", "m", apiType),
makeContext(),
{ apiKey: "strip-me", maxTokens: 8192, signal: abortController.signal, reasoning: "high" } as SimpleStreamOptions,
{
apiKey: "strip-me",
maxTokens: 8192,
signal: abortController.signal,
reasoning: "high",
} as SimpleStreamOptions,
);
const captured = spy.getCapturedOptions();
assert.ok(captured, "streamSimple must have been called");
assert.equal("apiKey" in captured, false, "apiKey must be stripped");
assert.equal(captured.maxTokens, 8192, "maxTokens must pass through");
assert.equal(captured.signal, abortController.signal, "signal must pass through");
assert.equal((captured as Record<string, unknown>).reasoning, "high", "reasoning must pass through");
assert.equal(
captured.signal,
abortController.signal,
"signal must pass through",
);
assert.equal(
(captured as Record<string, unknown>).reasoning,
"high",
"reasoning must pass through",
);
});
});
@ -633,11 +824,12 @@ describe("ModelRegistry authMode — provider-scoped stream routing", () => {
// The built-in handler will throw (no API key), which proves the routing
// correctly delegates to the built-in instead of the custom handler.
assert.throws(
() => provider.streamSimple(
makeModel("anthropic", "claude-sonnet-4-6", "anthropic-messages"),
makeContext(),
{ maxTokens: 4096 } as SimpleStreamOptions,
),
() =>
provider.streamSimple(
makeModel("anthropic", "claude-sonnet-4-6", "anthropic-messages"),
makeContext(),
{ maxTokens: 4096 } as SimpleStreamOptions,
),
(err: Error) => err.message.includes("API key"),
"built-in Anthropic handler must be invoked (throws because no API key in tests)",
);
@ -672,7 +864,10 @@ describe("ModelRegistry authMode — provider-scoped stream routing", () => {
);
const captured = customSpy.getCapturedOptions();
assert.ok(captured, "custom provider's streamSimple must be called for its own models");
assert.ok(
captured,
"custom provider's streamSimple must be called for its own models",
);
assert.equal(captured.maxTokens, 2048);
});
});

View file

@ -2,10 +2,11 @@
* Model registry - manages built-in and custom models, provides API key resolution.
*/
import { type Static, Type } from "@sinclair/typebox";
import {
type Api,
applyCapabilityPatches,
type AssistantMessageEventStream,
applyCapabilityPatches,
type Context,
getApiProvider,
getEnvApiKey,
@ -20,14 +21,17 @@ import {
resetApiProviders,
type SimpleStreamOptions,
} from "@singularity-forge/pi-ai";
import { registerOAuthProvider, resetOAuthProviders } from "@singularity-forge/pi-ai/oauth";
import { type Static, Type } from "@sinclair/typebox";
import {
registerOAuthProvider,
resetOAuthProviders,
} from "@singularity-forge/pi-ai/oauth";
import AjvModule from "ajv";
import { existsSync, readFileSync } from "fs";
import { join } from "path";
import { getAgentDir } from "../config.js";
import type { AuthStorage } from "./auth-storage.js";
import { ModelDiscoveryCache } from "./discovery-cache.js";
import { isLocalModel } from "./local-model-check.js";
import type { DiscoveryResult } from "./model-discovery.js";
import {
getDiscoverableCatalogSources,
@ -35,7 +39,6 @@ import {
getDiscoveryAdapter,
} from "./model-discovery.js";
import { resolveConfigValue, resolveHeaders } from "./resolve-config-value.js";
import { isLocalModel } from "./local-model-check.js";
const Ajv = (AjvModule as any).default || AjvModule;
const ajv = new Ajv();
@ -70,27 +73,33 @@ export const PROXY_FAMILY_PRIORITY: ReadonlyArray<{
family_failover?: string[];
}> = [
// MiniMax direct (api.minimax.io) → CN endpoint as its direct pair
{ match: /^MiniMax-/i, prefix: "MiniMax-", providers: ["minimax", "minimax-cn"] },
{
match: /^MiniMax-/i,
prefix: "MiniMax-",
providers: ["minimax", "minimax-cn"],
},
// ZAI direct API for GLM
{ match: /^glm-/i, prefix: "glm-", providers: ["zai"] },
{ match: /^glm-/i, prefix: "glm-", providers: ["zai"] },
// Kimi Code direct API
{ match: /^kimi-/i, prefix: "kimi-", providers: ["kimi-coding"] },
{ match: /^kimi-/i, prefix: "kimi-", providers: ["kimi-coding"] },
// MiMo/Xiaomi — direct API via Xiaomi MiMo Open Platform (api.xiaomimimo.com)
// or the Token Plan endpoint (token-plan-sgp.xiaomimimo.com). Both served
// under the `xiaomi` provider namespace.
{ match: /^mimo-|^XiaomiMiMo\//i, prefix: "mimo-", providers: ["xiaomi"] },
// Gemini/Gemma: google-gemini-cli (OAuth), google (API key), google-vertex
// are all FIRST-PARTY Google endpoints. github-copilot re-serves and is
// failover only.
{ match: /^mimo-|^XiaomiMiMo\//i, prefix: "mimo-", providers: ["xiaomi"] },
// Gemini/Gemma: google-gemini-cli (CLI OAuth via ~/.gemini), google
// (API key), google-vertex are all FIRST-PARTY Google endpoints.
// github-copilot re-serves and is failover only.
{
match: /^gemini-|^gemma-/i, prefix: "gemini-",
match: /^gemini-|^gemma-/i,
prefix: "gemini-",
providers: ["google-gemini-cli", "google", "google-vertex"],
family_failover: ["github-copilot"],
},
// Claude: Anthropic is the ONLY direct provider. github-copilot re-serves
// Claude via GitHub's platform as failover.
{
match: /^claude-/i, prefix: "claude-",
match: /^claude-/i,
prefix: "claude-",
providers: ["anthropic"],
family_failover: ["github-copilot"],
},
@ -99,9 +108,14 @@ export const PROXY_FAMILY_PRIORITY: ReadonlyArray<{
// the same weights via a different legal/contractual relationship).
// github-copilot likewise re-serves.
{
match: /^gpt-|^o\d|^codex-/i, prefix: "gpt-",
match: /^gpt-|^o\d|^codex-/i,
prefix: "gpt-",
providers: ["openai"],
family_failover: ["azure-openai-responses", "openai-codex", "github-copilot"],
family_failover: [
"azure-openai-responses",
"openai-codex",
"github-copilot",
],
},
];
@ -123,12 +137,23 @@ const OpenAICompletionsCompatSchema = Type.Object({
supportsDeveloperRole: Type.Optional(Type.Boolean()),
supportsReasoningEffort: Type.Optional(Type.Boolean()),
supportsUsageInStreaming: Type.Optional(Type.Boolean()),
maxTokensField: Type.Optional(Type.Union([Type.Literal("max_completion_tokens"), Type.Literal("max_tokens")])),
maxTokensField: Type.Optional(
Type.Union([
Type.Literal("max_completion_tokens"),
Type.Literal("max_tokens"),
]),
),
requiresToolResultName: Type.Optional(Type.Boolean()),
requiresAssistantAfterToolResult: Type.Optional(Type.Boolean()),
requiresThinkingAsText: Type.Optional(Type.Boolean()),
requiresMistralToolIds: Type.Optional(Type.Boolean()),
thinkingFormat: Type.Optional(Type.Union([Type.Literal("openai"), Type.Literal("zai"), Type.Literal("qwen")])),
thinkingFormat: Type.Optional(
Type.Union([
Type.Literal("openai"),
Type.Literal("zai"),
Type.Literal("qwen"),
]),
),
openRouterRouting: Type.Optional(OpenRouterRoutingSchema),
vercelGatewayRouting: Type.Optional(VercelGatewayRoutingSchema),
});
@ -137,7 +162,10 @@ const OpenAIResponsesCompatSchema = Type.Object({
// Reserved for future use
});
const OpenAICompatSchema = Type.Union([OpenAICompletionsCompatSchema, OpenAIResponsesCompatSchema]);
const OpenAICompatSchema = Type.Union([
OpenAICompletionsCompatSchema,
OpenAIResponsesCompatSchema,
]);
// Schema for custom model definition
// Most fields are optional with sensible defaults for local models (Ollama, LM Studio, etc.)
@ -147,7 +175,9 @@ const ModelDefinitionSchema = Type.Object({
api: Type.Optional(Type.String({ minLength: 1 })),
baseUrl: Type.Optional(Type.String({ minLength: 1 })),
reasoning: Type.Optional(Type.Boolean()),
input: Type.Optional(Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")]))),
input: Type.Optional(
Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
),
cost: Type.Optional(
Type.Object({
input: Type.Number(),
@ -166,7 +196,9 @@ const ModelDefinitionSchema = Type.Object({
const ModelOverrideSchema = Type.Object({
name: Type.Optional(Type.String({ minLength: 1 })),
reasoning: Type.Optional(Type.Boolean()),
input: Type.Optional(Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")]))),
input: Type.Optional(
Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
),
cost: Type.Optional(
Type.Object({
input: Type.Optional(Type.Number()),
@ -190,7 +222,9 @@ const ProviderConfigSchema = Type.Object({
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
authHeader: Type.Optional(Type.Boolean()),
models: Type.Optional(Type.Array(ModelDefinitionSchema)),
modelOverrides: Type.Optional(Type.Record(Type.String(), ModelOverrideSchema)),
modelOverrides: Type.Optional(
Type.Record(Type.String(), ModelOverrideSchema),
),
});
const ModelsConfigSchema = Type.Object({
@ -205,7 +239,8 @@ export type ProviderModelAllowList = Record<string, readonly string[]>;
export type ProviderAuthMode = "apiKey" | "oauth" | "externalCli" | "none";
type ProviderPolicyModel = Pick<Model<Api>, "provider" | "id"> & Partial<Pick<Model<Api>, "name" | "cost">>;
type ProviderPolicyModel = Pick<Model<Api>, "provider" | "id"> &
Partial<Pick<Model<Api>, "name" | "cost">>;
const OPENCODE_FREE_MODEL_IDS = new Set([
"big-pickle",
@ -221,7 +256,12 @@ const HIDDEN_MODEL_PROVIDERS = new Set([
"xiaomi-token-plan-sgp",
]);
function providerModelAllowEntryMatches(allowedModel: string, modelKey: string): boolean {
const BUILTIN_EXTERNAL_CLI_AUTH_PROVIDERS = new Set(["google-gemini-cli"]);
function providerModelAllowEntryMatches(
allowedModel: string,
modelKey: string,
): boolean {
const allowedKey = allowedModel.trim().toLowerCase();
if (!allowedKey) return false;
if (allowedKey === modelKey) return true;
@ -239,7 +279,13 @@ function hasFreeSkuMarker(value: string | undefined): boolean {
}
function isZeroCost(cost: Model<Api>["cost"] | undefined): boolean {
return !!cost && cost.input === 0 && cost.output === 0 && cost.cacheRead === 0 && cost.cacheWrite === 0;
return (
!!cost &&
cost.input === 0 &&
cost.output === 0 &&
cost.cacheRead === 0 &&
cost.cacheWrite === 0
);
}
function isMistralSelectionModel(modelId: string): boolean {
@ -259,7 +305,9 @@ function isMistralSelectionModel(modelId: string): boolean {
return true;
}
function isModelAllowedByBuiltInProviderPolicy(model: ProviderPolicyModel): boolean {
function isModelAllowedByBuiltInProviderPolicy(
model: ProviderPolicyModel,
): boolean {
const provider = model.provider.toLowerCase();
const modelKey = model.id.trim().toLowerCase();
if (HIDDEN_MODEL_PROVIDERS.has(provider)) {
@ -309,22 +357,35 @@ function mergeCompat(
): Model<Api>["compat"] | undefined {
if (!overrideCompat) return baseCompat;
const base = baseCompat as OpenAICompletionsCompat | OpenAIResponsesCompat | undefined;
const override = overrideCompat as OpenAICompletionsCompat | OpenAIResponsesCompat;
const merged = { ...base, ...override } as OpenAICompletionsCompat | OpenAIResponsesCompat;
const base = baseCompat as
| OpenAICompletionsCompat
| OpenAIResponsesCompat
| undefined;
const override = overrideCompat as
| OpenAICompletionsCompat
| OpenAIResponsesCompat;
const merged = { ...base, ...override } as
| OpenAICompletionsCompat
| OpenAIResponsesCompat;
const baseCompletions = base as OpenAICompletionsCompat | undefined;
const overrideCompletions = override as OpenAICompletionsCompat;
const mergedCompletions = merged as OpenAICompletionsCompat;
if (baseCompletions?.openRouterRouting || overrideCompletions.openRouterRouting) {
if (
baseCompletions?.openRouterRouting ||
overrideCompletions.openRouterRouting
) {
mergedCompletions.openRouterRouting = {
...baseCompletions?.openRouterRouting,
...overrideCompletions.openRouterRouting,
};
}
if (baseCompletions?.vercelGatewayRouting || overrideCompletions.vercelGatewayRouting) {
if (
baseCompletions?.vercelGatewayRouting ||
overrideCompletions.vercelGatewayRouting
) {
mergedCompletions.vercelGatewayRouting = {
...baseCompletions?.vercelGatewayRouting,
...overrideCompletions.vercelGatewayRouting,
@ -338,14 +399,19 @@ function mergeCompat(
* Deep merge a model override into a model.
* Handles nested objects (cost, compat) by merging rather than replacing.
*/
function applyModelOverride(model: Model<Api>, override: ModelOverride): Model<Api> {
function applyModelOverride(
model: Model<Api>,
override: ModelOverride,
): Model<Api> {
const result = { ...model };
// Simple field overrides
if (override.name !== undefined) result.name = override.name;
if (override.reasoning !== undefined) result.reasoning = override.reasoning;
if (override.input !== undefined) result.input = override.input as ("text" | "image")[];
if (override.contextWindow !== undefined) result.contextWindow = override.contextWindow;
if (override.input !== undefined)
result.input = override.input as ("text" | "image")[];
if (override.contextWindow !== undefined)
result.contextWindow = override.contextWindow;
if (override.maxTokens !== undefined) result.maxTokens = override.maxTokens;
// Merge cost (partial override)
@ -361,7 +427,9 @@ function applyModelOverride(model: Model<Api>, override: ModelOverride): Model<A
// Merge headers
if (override.headers) {
const resolvedHeaders = resolveHeaders(override.headers);
result.headers = resolvedHeaders ? { ...model.headers, ...resolvedHeaders } : model.headers;
result.headers = resolvedHeaders
? { ...model.headers, ...resolvedHeaders }
: model.headers;
}
// Deep merge compat
@ -370,7 +438,6 @@ function applyModelOverride(model: Model<Api>, override: ModelOverride): Model<A
return result;
}
/**
* Model registry - loads and manages models, resolves API keys via AuthStorage.
*/
@ -384,7 +451,10 @@ export class ModelRegistry {
constructor(
readonly authStorage: AuthStorage,
readonly modelsJsonPath: string | undefined = join(getAgentDir(), "models.json"),
readonly modelsJsonPath: string | undefined = join(
getAgentDir(),
"models.json",
),
discoveryCache?: ModelDiscoveryCache,
) {
this.discoveryCache = discoveryCache ?? new ModelDiscoveryCache();
@ -437,7 +507,9 @@ export class ModelRegistry {
overrides,
modelOverrides,
error,
} = this.modelsJsonPath ? this.loadCustomModels(this.modelsJsonPath) : emptyCustomModelsResult();
} = this.modelsJsonPath
? this.loadCustomModels(this.modelsJsonPath)
: emptyCustomModelsResult();
if (error) {
this.loadError = error;
@ -480,7 +552,9 @@ export class ModelRegistry {
model = {
...model,
baseUrl: providerOverride.baseUrl ?? model.baseUrl,
headers: resolvedHeaders ? { ...model.headers, ...resolvedHeaders } : model.headers,
headers: resolvedHeaders
? { ...model.headers, ...resolvedHeaders }
: model.headers,
};
}
@ -496,10 +570,15 @@ export class ModelRegistry {
}
/** Merge custom models into built-in list by provider+id (custom wins on conflicts). */
private mergeCustomModels(builtInModels: Model<Api>[], customModels: Model<Api>[]): Model<Api>[] {
private mergeCustomModels(
builtInModels: Model<Api>[],
customModels: Model<Api>[],
): Model<Api>[] {
const merged = [...builtInModels];
for (const customModel of customModels) {
const existingIndex = merged.findIndex((m) => m.provider === customModel.provider && m.id === customModel.id);
const existingIndex = merged.findIndex(
(m) => m.provider === customModel.provider && m.id === customModel.id,
);
if (existingIndex >= 0) {
merged[existingIndex] = customModel;
} else {
@ -509,22 +588,32 @@ export class ModelRegistry {
return merged;
}
private isProviderModelAllowed(model: ProviderPolicyModel, providerModelAllow?: ProviderModelAllowList): boolean {
private isProviderModelAllowed(
model: ProviderPolicyModel,
providerModelAllow?: ProviderModelAllowList,
): boolean {
if (!isModelAllowedByBuiltInProviderPolicy(model)) return false;
if (!providerModelAllow) return true;
const providerKey = model.provider.toLowerCase();
const allowedModels = providerModelAllow[providerKey]
?? Object.entries(providerModelAllow).find(([key]) => key.toLowerCase() === providerKey)?.[1];
const allowedModels =
providerModelAllow[providerKey] ??
Object.entries(providerModelAllow).find(
([key]) => key.toLowerCase() === providerKey,
)?.[1];
if (allowedModels === undefined) return true;
const modelKey = model.id.trim().toLowerCase();
return allowedModels.some((allowedModel) => providerModelAllowEntryMatches(allowedModel, modelKey));
return allowedModels.some((allowedModel) =>
providerModelAllowEntryMatches(allowedModel, modelKey),
);
}
private filterProviderModelAllow<T extends Model<Api>>(
models: T[],
providerModelAllow?: ProviderModelAllowList,
): T[] {
return models.filter((model) => this.isProviderModelAllowed(model, providerModelAllow));
return models.filter((model) =>
this.isProviderModelAllowed(model, providerModelAllow),
);
}
private loadCustomModels(modelsJsonPath: string): CustomModelsResult {
@ -540,9 +629,12 @@ export class ModelRegistry {
const validate = ajv.getSchema("ModelsConfig")!;
if (!validate(config)) {
const errors =
validate.errors?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`).join("\n") ||
"Unknown schema error";
return emptyCustomModelsResult(`Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`);
validate.errors
?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`)
.join("\n") || "Unknown schema error";
return emptyCustomModelsResult(
`Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`,
);
}
// Additional validation
@ -551,9 +643,15 @@ export class ModelRegistry {
const overrides = new Map<string, ProviderOverride>();
const modelOverrides = new Map<string, Map<string, ModelOverride>>();
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
for (const [providerName, providerConfig] of Object.entries(
config.providers,
)) {
// Apply provider-level baseUrl/headers/apiKey override to built-in models when configured.
if (providerConfig.baseUrl || providerConfig.headers || providerConfig.apiKey) {
if (
providerConfig.baseUrl ||
providerConfig.headers ||
providerConfig.apiKey
) {
overrides.set(providerName, {
baseUrl: providerConfig.baseUrl,
headers: providerConfig.headers,
@ -567,14 +665,24 @@ export class ModelRegistry {
}
if (providerConfig.modelOverrides) {
modelOverrides.set(providerName, new Map(Object.entries(providerConfig.modelOverrides)));
modelOverrides.set(
providerName,
new Map(Object.entries(providerConfig.modelOverrides)),
);
}
}
return { models: this.parseModels(config), overrides, modelOverrides, error: undefined };
return {
models: this.parseModels(config),
overrides,
modelOverrides,
error: undefined,
};
} catch (error) {
if (error instanceof SyntaxError) {
return emptyCustomModelsResult(`Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`);
return emptyCustomModelsResult(
`Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`,
);
}
return emptyCustomModelsResult(
`Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`,
@ -583,24 +691,33 @@ export class ModelRegistry {
}
private validateConfig(config: ModelsConfig): void {
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
for (const [providerName, providerConfig] of Object.entries(
config.providers,
)) {
const hasProviderApi = !!providerConfig.api;
const models = providerConfig.models ?? [];
const hasModelOverrides =
providerConfig.modelOverrides && Object.keys(providerConfig.modelOverrides).length > 0;
providerConfig.modelOverrides &&
Object.keys(providerConfig.modelOverrides).length > 0;
if (models.length === 0) {
// Override-only config: needs baseUrl OR modelOverrides (or both)
if (!providerConfig.baseUrl && !hasModelOverrides) {
throw new Error(`Provider ${providerName}: must specify "baseUrl", "modelOverrides", or "models".`);
throw new Error(
`Provider ${providerName}: must specify "baseUrl", "modelOverrides", or "models".`,
);
}
} else {
// Custom models are merged into provider models and require endpoint + auth.
if (!providerConfig.baseUrl) {
throw new Error(`Provider ${providerName}: "baseUrl" is required when defining custom models.`);
throw new Error(
`Provider ${providerName}: "baseUrl" is required when defining custom models.`,
);
}
if (!providerConfig.apiKey) {
throw new Error(`Provider ${providerName}: "apiKey" is required when defining custom models.`);
throw new Error(
`Provider ${providerName}: "apiKey" is required when defining custom models.`,
);
}
}
@ -613,12 +730,17 @@ export class ModelRegistry {
);
}
if (!modelDef.id) throw new Error(`Provider ${providerName}: model missing "id"`);
if (!modelDef.id)
throw new Error(`Provider ${providerName}: model missing "id"`);
// Validate contextWindow/maxTokens only if provided (they have defaults)
if (modelDef.contextWindow !== undefined && modelDef.contextWindow <= 0)
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`);
throw new Error(
`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`,
);
if (modelDef.maxTokens !== undefined && modelDef.maxTokens <= 0)
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`);
throw new Error(
`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`,
);
}
}
}
@ -626,7 +748,9 @@ export class ModelRegistry {
private parseModels(config: ModelsConfig): Model<Api>[] {
const models: Model<Api>[] = [];
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
for (const [providerName, providerConfig] of Object.entries(
config.providers,
)) {
const modelDefs = providerConfig.models ?? [];
if (modelDefs.length === 0) continue; // Override-only, no custom models
@ -655,7 +779,10 @@ export class ModelRegistry {
// Resolve env vars and shell commands in header values
const providerHeaders = resolveHeaders(providerConfig.headers);
const modelHeaders = resolveHeaders(modelDef.headers);
let headers = providerHeaders || modelHeaders ? { ...providerHeaders, ...modelHeaders } : undefined;
let headers =
providerHeaders || modelHeaders
? { ...providerHeaders, ...modelHeaders }
: undefined;
// If authHeader is true, add Authorization header with resolved API key
if (providerConfig.authHeader && providerConfig.apiKey) {
@ -667,7 +794,12 @@ export class ModelRegistry {
// Provider baseUrl is required when custom models are defined.
// Individual models can override it with modelDef.baseUrl.
const defaultCost = { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 };
const defaultCost = {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
};
models.push({
id: modelDef.id,
name: modelDef.name ?? modelDef.id,
@ -712,6 +844,7 @@ export class ModelRegistry {
* Defaults to "apiKey" for built-ins and providers without explicit mode.
*/
getProviderAuthMode(provider: string): ProviderAuthMode {
if (BUILTIN_EXTERNAL_CLI_AUTH_PROVIDERS.has(provider)) return "externalCli";
const config = this.registeredProviders.get(provider);
if (!config) return "apiKey";
if (config.authMode) return config.authMode;
@ -745,10 +878,15 @@ export class ModelRegistry {
* Returns undefined for externalCli/none providers (no key needed).
* @param sessionId - Optional session ID for sticky credential selection
*/
async getApiKey(model: Model<Api>, sessionId?: string): Promise<string | undefined> {
async getApiKey(
model: Model<Api>,
sessionId?: string,
): Promise<string | undefined> {
const authMode = this.getProviderAuthMode(model.provider);
if (authMode === "externalCli" || authMode === "none") return undefined;
return this.authStorage.getApiKey(model.provider, sessionId, { baseUrl: model.baseUrl });
return this.authStorage.getApiKey(model.provider, sessionId, {
baseUrl: model.baseUrl,
});
}
/**
@ -756,7 +894,10 @@ export class ModelRegistry {
* Returns undefined for externalCli/none providers (no key needed).
* @param sessionId - Optional session ID for sticky credential selection
*/
async getApiKeyForProvider(provider: string, sessionId?: string): Promise<string | undefined> {
async getApiKeyForProvider(
provider: string,
sessionId?: string,
): Promise<string | undefined> {
const authMode = this.getProviderAuthMode(provider);
if (authMode === "externalCli" || authMode === "none") return undefined;
return this.authStorage.getApiKey(provider, sessionId);
@ -798,7 +939,10 @@ export class ModelRegistry {
this.refresh();
}
private applyProviderConfig(providerName: string, config: ProviderConfigInput): void {
private applyProviderConfig(
providerName: string,
config: ProviderConfigInput,
): void {
// Register OAuth provider if provided
if (config.oauth) {
// Ensure the OAuth provider ID matches the provider name
@ -811,19 +955,30 @@ export class ModelRegistry {
if (config.streamSimple) {
if (!config.api) {
throw new Error(`Provider ${providerName}: "api" is required when registering streamSimple.`);
throw new Error(
`Provider ${providerName}: "api" is required when registering streamSimple.`,
);
}
const rawStreamSimple = config.streamSimple;
const authMode = config.authMode ?? "apiKey";
// Keyless providers never see apiKey in options — enforced at registration,
// not by convention. Prevents undefined from reaching any handler.
const streamSimple = (authMode === "externalCli" || authMode === "none")
? ((model: Model<Api>, context: Context, options?: SimpleStreamOptions) => {
const { apiKey: _, ...opts } = options ?? {};
return rawStreamSimple(model, context, opts as SimpleStreamOptions);
})
: rawStreamSimple;
const streamSimple =
authMode === "externalCli" || authMode === "none"
? (
model: Model<Api>,
context: Context,
options?: SimpleStreamOptions,
) => {
const { apiKey: _, ...opts } = options ?? {};
return rawStreamSimple(
model,
context,
opts as SimpleStreamOptions,
);
}
: rawStreamSimple;
// Guard: if there's already a handler registered for this API, wrap
// the new one so it only fires for models from this provider and
@ -832,7 +987,11 @@ export class ModelRegistry {
// the built-in Anthropic stream handler (#2536).
const existingProvider = getApiProvider(config.api as Api);
const scopedStream = existingProvider
? (model: Model<Api>, context: Context, options?: SimpleStreamOptions): AssistantMessageEventStream => {
? (
model: Model<Api>,
context: Context,
options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
if (model.provider === providerName) {
return streamSimple(model, context, options);
}
@ -840,12 +999,23 @@ export class ModelRegistry {
}
: streamSimple;
const newFullStream = (model: Model<Api>, context: Context, options?: SimpleStreamOptions) =>
scopedStream(model, context, options as SimpleStreamOptions);
const newFullStream = (
model: Model<Api>,
context: Context,
options?: SimpleStreamOptions,
) => scopedStream(model, context, options as SimpleStreamOptions);
const scopedFullStream = existingProvider
? (model: Model<Api>, context: Context, options?: Record<string, unknown>) => {
? (
model: Model<Api>,
context: Context,
options?: Record<string, unknown>,
) => {
if (model.provider === providerName) {
return newFullStream(model, context, options as SimpleStreamOptions);
return newFullStream(
model,
context,
options as SimpleStreamOptions,
);
}
return existingProvider.stream(model, context, options);
}
@ -872,25 +1042,35 @@ export class ModelRegistry {
// Validate required fields
if (!config.baseUrl) {
throw new Error(`Provider ${providerName}: "baseUrl" is required when defining models.`);
throw new Error(
`Provider ${providerName}: "baseUrl" is required when defining models.`,
);
}
const authMode = config.authMode ?? (config.oauth ? "oauth" : config.apiKey ? "apiKey" : "apiKey");
const authMode =
config.authMode ??
(config.oauth ? "oauth" : config.apiKey ? "apiKey" : "apiKey");
if (authMode === "apiKey" && !config.apiKey && !config.oauth) {
throw new Error(
`Provider ${providerName}: "apiKey" or "oauth" is required when authMode is "apiKey" (the default). ` +
`Set authMode to "externalCli" or "none" for keyless providers.`,
`Set authMode to "externalCli" or "none" for keyless providers.`,
);
}
if ((authMode === "externalCli" || authMode === "none") && !config.streamSimple) {
if (
(authMode === "externalCli" || authMode === "none") &&
!config.streamSimple
) {
throw new Error(
`Provider ${providerName}: "streamSimple" is required when authMode is "${authMode}". ` +
`Keyless providers must supply their own stream handler.`,
`Keyless providers must supply their own stream handler.`,
);
}
if ((authMode === "externalCli" || authMode === "none") && config.apiKey) {
if (
(authMode === "externalCli" || authMode === "none") &&
config.apiKey
) {
throw new Error(
`Provider ${providerName}: "apiKey" cannot be set when authMode is "${authMode}". ` +
`Keyless providers should not provide API key credentials.`,
`Keyless providers should not provide API key credentials.`,
);
}
@ -898,13 +1078,18 @@ export class ModelRegistry {
for (const modelDef of config.models) {
const api = modelDef.api || config.api;
if (!api) {
throw new Error(`Provider ${providerName}, model ${modelDef.id}: no "api" specified.`);
throw new Error(
`Provider ${providerName}, model ${modelDef.id}: no "api" specified.`,
);
}
// Merge headers
const providerHeaders = resolveHeaders(config.headers);
const modelHeaders = resolveHeaders(modelDef.headers);
let headers = providerHeaders || modelHeaders ? { ...providerHeaders, ...modelHeaders } : undefined;
let headers =
providerHeaders || modelHeaders
? { ...providerHeaders, ...modelHeaders }
: undefined;
// If authHeader is true, add Authorization header
if (config.authHeader && config.apiKey) {
@ -949,15 +1134,24 @@ export class ModelRegistry {
return {
...m,
baseUrl: config.baseUrl ?? m.baseUrl,
headers: resolvedHeaders ? { ...m.headers, ...resolvedHeaders } : m.headers,
headers: resolvedHeaders
? { ...m.headers, ...resolvedHeaders }
: m.headers,
};
});
}
}
private buildCandidateOrder(modelId: string, overrides: Record<string, string[]>): string[] {
const overrideEntry = Object.entries(overrides).find(([k]) => modelId.startsWith(k));
const familyEntry = PROXY_FAMILY_PRIORITY.find((r) => r.match.test(modelId));
private buildCandidateOrder(
modelId: string,
overrides: Record<string, string[]>,
): string[] {
const overrideEntry = Object.entries(overrides).find(([k]) =>
modelId.startsWith(k),
);
const familyEntry = PROXY_FAMILY_PRIORITY.find((r) =>
r.match.test(modelId),
);
// Order: direct family providers → family-scoped failover → global fallback.
// Overrides replace only the direct list (keeps family_failover + global
// chain intact) so a user pinning "glm- → [zai]" still picks up
@ -992,8 +1186,12 @@ export class ModelRegistry {
return (pa === -1 ? Infinity : pa) - (pb === -1 ? Infinity : pb);
});
const withAuth = sorted.filter((m) => this.isProviderRequestReady(m.provider));
const withoutAuth = sorted.filter((m) => !this.isProviderRequestReady(m.provider));
const withAuth = sorted.filter((m) =>
this.isProviderRequestReady(m.provider),
);
const withoutAuth = sorted.filter(
(m) => !this.isProviderRequestReady(m.provider),
);
return [...withAuth, ...withoutAuth];
}
@ -1064,7 +1262,9 @@ export class ModelRegistry {
}
// Convert and merge discovered models, then apply capability patches
this.discoveredModels = applyCapabilityPatches(this.convertDiscoveredModels(results));
this.discoveredModels = applyCapabilityPatches(
this.convertDiscoveredModels(results),
);
return results;
}
@ -1082,8 +1282,12 @@ export class ModelRegistry {
* Discovered models are appended but never override existing models.
*/
getAllWithDiscovered(): Model<Api>[] {
const existingIds = new Set(this.models.map((m) => `${m.provider}/${m.id}`));
const unique = this.discoveredModels.filter((m) => !existingIds.has(`${m.provider}/${m.id}`));
const existingIds = new Set(
this.models.map((m) => `${m.provider}/${m.id}`),
);
const unique = this.discoveredModels.filter(
(m) => !existingIds.has(`${m.provider}/${m.id}`),
);
return this.filterProviderModelAllow([...this.models, ...unique]);
}
@ -1094,15 +1298,22 @@ export class ModelRegistry {
* providers with the provider's actual `/models` response.
* Consumer: cli/list-models.ts when `--discover` or an exact provider query is used.
*/
getDiscoveredModels(providerModelAllow?: ProviderModelAllowList): Model<Api>[] {
return this.filterProviderModelAllow(this.discoveredModels, providerModelAllow);
getDiscoveredModels(
providerModelAllow?: ProviderModelAllowList,
): Model<Api>[] {
return this.filterProviderModelAllow(
this.discoveredModels,
providerModelAllow,
);
}
/**
* Check if a model was added via discovery (not built-in or custom).
*/
isDiscovered(model: Model<Api>): boolean {
return this.discoveredModels.some((m) => m.provider === model.provider && m.id === model.id);
return this.discoveredModels.some(
(m) => m.provider === model.provider && m.id === model.id,
);
}
/**
@ -1125,7 +1336,9 @@ export class ModelRegistry {
const key = `${provider}/${dm.id}`;
if (seen.has(key)) continue;
seen.add(key);
const known = this.models.find((m) => m.provider === provider && m.id === dm.id);
const known = this.models.find(
(m) => m.provider === provider && m.id === dm.id,
);
const discoveredName =
dm.name && dm.name !== dm.id ? dm.name : undefined;
converted.push({
@ -1137,7 +1350,8 @@ export class ModelRegistry {
baseUrl: dm.baseUrl ?? known?.baseUrl ?? "",
reasoning: dm.reasoning ?? known?.reasoning ?? false,
input: dm.input ?? known?.input ?? ["text"],
cost: dm.cost ?? known?.cost ?? { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
cost: dm.cost ??
known?.cost ?? { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: dm.contextWindow ?? known?.contextWindow ?? 128000,
maxTokens: dm.maxTokens ?? known?.maxTokens ?? 16384,
} as Model<Api>);
@ -1177,7 +1391,11 @@ export interface ProviderConfigInput {
baseUrl?: string;
apiKey?: string;
api?: Api;
streamSimple?: (model: Model<Api>, context: Context, options?: SimpleStreamOptions) => AssistantMessageEventStream;
streamSimple?: (
model: Model<Api>,
context: Context,
options?: SimpleStreamOptions,
) => AssistantMessageEventStream;
headers?: Record<string, string>;
authHeader?: boolean;
/** OAuth provider for /login support */
@ -1189,7 +1407,12 @@ export interface ProviderConfigInput {
baseUrl?: string;
reasoning: boolean;
input: ("text" | "image")[];
cost: { input: number; output: number; cacheRead: number; cacheWrite: number };
cost: {
input: number;
output: number;
cacheRead: number;
cacheWrite: number;
};
contextWindow: number;
maxTokens: number;
headers?: Record<string, string>;

View file

@ -6,12 +6,12 @@
* downgrade from [1m] to base when no cross-provider fallback exists.
*/
import { vi, describe, it, beforeEach, type Mock } from 'vitest';
import assert from "node:assert/strict";
import { RetryHandler, type RetryHandlerDeps } from "./retry-handler.js";
import type { Api, AssistantMessage, Model } from "@singularity-forge/pi-ai";
import { beforeEach, describe, it, type Mock, vi } from "vitest";
import type { FallbackResolver } from "./fallback-resolver.js";
import type { ModelRegistry } from "./model-registry.js";
import { RetryHandler, type RetryHandlerDeps } from "./retry-handler.js";
import type { SettingsManager } from "./settings-manager.js";
// ─── Helpers ────────────────────────────────────────────────────────────────
@ -38,7 +38,14 @@ function errorMessage(msg: string): AssistantMessage {
api: "anthropic-messages",
provider: "anthropic",
model: "claude-opus-4-6[1m]",
usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "error",
errorMessage: msg,
timestamp: Date.now(),
@ -52,7 +59,9 @@ interface MockDeps {
onModelChangeFn: Mock<(model: Model<any>) => void>;
markUsageLimitReached: Mock<(...args: any[]) => boolean>;
findFallback: Mock<(...args: any[]) => Promise<any>>;
findModel: Mock<(provider: string, modelId: string) => Model<Api> | undefined>;
findModel: Mock<
(provider: string, modelId: string) => Model<Api> | undefined
>;
}
function createMockDeps(overrides?: {
@ -60,14 +69,19 @@ function createMockDeps(overrides?: {
retryEnabled?: boolean;
markUsageLimitReachedResult?: boolean;
fallbackResult?: any;
findModelResult?: (provider: string, modelId: string) => Model<Api> | undefined;
findModelResult?: (
provider: string,
modelId: string,
) => Model<Api> | undefined;
providerAuthMode?: "apiKey" | "oauth" | "externalCli" | "none";
retrySettings?: {
maxRetries?: number;
baseDelayMs?: number;
maxDelayMs?: number;
};
}): MockDeps {
const model = overrides?.model ?? createMockModel("anthropic", "claude-opus-4-6[1m]");
const model =
overrides?.model ?? createMockModel("anthropic", "claude-opus-4-6[1m]");
const emittedEvents: Array<Record<string, any>> = [];
const continueFn = vi.fn(async () => {});
const onModelChangeFn = vi.fn((_model: Model<any>) => {});
@ -76,7 +90,8 @@ function createMockDeps(overrides?: {
);
const findFallback = vi.fn(async () => overrides?.fallbackResult ?? null);
const findModel = vi.fn(
overrides?.findModelResult ?? ((_provider: string, _modelId: string) => undefined),
overrides?.findModelResult ??
((_provider: string, _modelId: string) => undefined),
);
const messages: Array<{ role: string } & Record<string, any>> = [];
@ -105,6 +120,7 @@ function createMockDeps(overrides?: {
markUsageLimitReached,
},
find: findModel,
getProviderAuthMode: () => overrides?.providerAuthMode ?? "apiKey",
} as unknown as ModelRegistry,
fallbackResolver: {
findFallback,
@ -115,13 +131,20 @@ function createMockDeps(overrides?: {
onModelChange: onModelChangeFn,
};
return { deps, emittedEvents, continueFn, onModelChangeFn, markUsageLimitReached, findFallback, findModel };
return {
deps,
emittedEvents,
continueFn,
onModelChangeFn,
markUsageLimitReached,
findFallback,
findModel,
};
}
// ─── _classifyErrorType (tested via handleRetryableError behavior) ──────────
describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
describe("error classification", () => {
it("classifies 'Extra usage is required for long context requests' as quota_exhausted, not rate_limit", async () => {
// When the error is classified as quota_exhausted AND no alternate credentials
@ -136,7 +159,7 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
const handler = new RetryHandler(deps);
const msg = errorMessage(
'429 {"type":"error","error":{"type":"rate_limit_error","message":"Extra usage is required for long context requests."}}'
'429 {"type":"error","error":{"type":"rate_limit_error","message":"Extra usage is required for long context requests."}}',
);
const result = await handler.handleRetryableError(msg);
@ -145,11 +168,22 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
assert.equal(result, false);
// Should emit fallback_chain_exhausted (quota_exhausted path), NOT auto_retry_start (backoff path)
const chainExhausted = emittedEvents.find((e) => e.type === "fallback_chain_exhausted");
assert.ok(chainExhausted, "Expected fallback_chain_exhausted event for entitlement error");
const chainExhausted = emittedEvents.find(
(e) => e.type === "fallback_chain_exhausted",
);
assert.ok(
chainExhausted,
"Expected fallback_chain_exhausted event for entitlement error",
);
const retryStart = emittedEvents.find((e) => e.type === "auto_retry_start");
assert.equal(retryStart, undefined, "Should NOT emit auto_retry_start for entitlement error");
const retryStart = emittedEvents.find(
(e) => e.type === "auto_retry_start",
);
assert.equal(
retryStart,
undefined,
"Should NOT emit auto_retry_start for entitlement error",
);
});
it("still classifies regular 429 rate limits as rate_limit", async () => {
@ -168,7 +202,9 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
// Should enter the backoff loop (rate_limit path, not quota_exhausted)
assert.equal(result, true);
const retryStart = emittedEvents.find((e) => e.type === "auto_retry_start");
const retryStart = emittedEvents.find(
(e) => e.type === "auto_retry_start",
);
assert.ok(retryStart, "Regular 429 should enter backoff retry");
});
@ -184,7 +220,7 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
const handler = new RetryHandler(deps);
const msg = errorMessage(
'529 {"type":"error","error":{"type":"overloaded_error","message":"The server cluster is currently under high load. Please retry after a short wait and thank you for your patience. (2064) (529)"},"request_id":"062e76f8f25cd919caa3af4baaa49203"}'
'529 {"type":"error","error":{"type":"overloaded_error","message":"The server cluster is currently under high load. Please retry after a short wait and thank you for your patience. (2064) (529)"},"request_id":"062e76f8f25cd919caa3af4baaa49203"}',
);
const result = await handler.handleRetryableError(msg);
@ -192,11 +228,18 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
// Should enter the backoff loop (rate_limit path, not quota_exhausted)
assert.equal(result, true);
const retryStart = emittedEvents.find((e) => e.type === "auto_retry_start");
assert.ok(retryStart, "529 overloaded_error should enter backoff retry as rate_limit");
const retryStart = emittedEvents.find(
(e) => e.type === "auto_retry_start",
);
assert.ok(
retryStart,
"529 overloaded_error should enter backoff retry as rate_limit",
);
// Must NOT be treated as quota_exhausted (would emit fallback_chain_exhausted)
const chainExhausted = emittedEvents.find((e) => e.type === "fallback_chain_exhausted");
const chainExhausted = emittedEvents.find(
(e) => e.type === "fallback_chain_exhausted",
);
assert.equal(
chainExhausted,
undefined,
@ -218,27 +261,40 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
const result = await handler.handleRetryableError(msg);
assert.equal(result, true, "affordability error should trigger credit-aware retry");
const retryStart = emittedEvents.find((e) => e.type === "auto_retry_start");
assert.ok(retryStart, "Expected immediate retry after reducing max tokens");
assert.equal(
result,
true,
"affordability error should trigger credit-aware retry",
);
const retryStart = emittedEvents.find(
(e) => e.type === "auto_retry_start",
);
assert.ok(
retryStart,
"Expected immediate retry after reducing max tokens",
);
});
});
describe("long-context model downgrade", () => {
it("downgrades from [1m] to base model when entitlement error and no fallback", async () => {
const baseModel = createMockModel("anthropic", "claude-opus-4-6");
const { deps, emittedEvents, onModelChangeFn, continueFn } = createMockDeps({
model: createMockModel("anthropic", "claude-opus-4-6[1m]"),
markUsageLimitReachedResult: false,
fallbackResult: null,
findModelResult: (provider: string, modelId: string) => {
if (provider === "anthropic" && modelId === "claude-opus-4-6") return baseModel;
return undefined;
},
});
const { deps, emittedEvents, onModelChangeFn, continueFn } =
createMockDeps({
model: createMockModel("anthropic", "claude-opus-4-6[1m]"),
markUsageLimitReachedResult: false,
fallbackResult: null,
findModelResult: (provider: string, modelId: string) => {
if (provider === "anthropic" && modelId === "claude-opus-4-6")
return baseModel;
return undefined;
},
});
const handler = new RetryHandler(deps);
const msg = errorMessage("Extra usage is required for long context requests.");
const msg = errorMessage(
"Extra usage is required for long context requests.",
);
const result = await handler.handleRetryableError(msg);
@ -253,9 +309,17 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
assert.equal(onModelChangeFn.mock.calls.length, 1);
// Should emit a fallback_provider_switch event indicating downgrade
const switchEvent = emittedEvents.find((e) => e.type === "fallback_provider_switch");
assert.ok(switchEvent, "Expected fallback_provider_switch event for downgrade");
assert.ok(switchEvent!.reason.includes("long context downgrade"), `reason should mention downgrade: ${switchEvent!.reason}`);
const switchEvent = emittedEvents.find(
(e) => e.type === "fallback_provider_switch",
);
assert.ok(
switchEvent,
"Expected fallback_provider_switch event for downgrade",
);
assert.ok(
switchEvent!.reason.includes("long context downgrade"),
`reason should mention downgrade: ${switchEvent!.reason}`,
);
});
it("emits fallback_chain_exhausted when base model is also unavailable", async () => {
@ -267,13 +331,20 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
});
const handler = new RetryHandler(deps);
const msg = errorMessage("Extra usage is required for long context requests.");
const msg = errorMessage(
"Extra usage is required for long context requests.",
);
const result = await handler.handleRetryableError(msg);
assert.equal(result, false);
const chainExhausted = emittedEvents.find((e) => e.type === "fallback_chain_exhausted");
assert.ok(chainExhausted, "Expected fallback_chain_exhausted when base model unavailable");
const chainExhausted = emittedEvents.find(
(e) => e.type === "fallback_chain_exhausted",
);
assert.ok(
chainExhausted,
"Expected fallback_chain_exhausted when base model unavailable",
);
});
it("does not attempt downgrade for non-[1m] models", async () => {
@ -286,17 +357,27 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
});
const handler = new RetryHandler(deps);
const msg = errorMessage("Extra usage is required for long context requests.");
const msg = errorMessage(
"Extra usage is required for long context requests.",
);
const result = await handler.handleRetryableError(msg);
assert.equal(result, false);
const chainExhausted = emittedEvents.find((e) => e.type === "fallback_chain_exhausted");
const chainExhausted = emittedEvents.find(
(e) => e.type === "fallback_chain_exhausted",
);
assert.ok(chainExhausted);
// No downgrade switch should occur
const switchEvent = emittedEvents.find((e) => e.type === "fallback_provider_switch");
assert.equal(switchEvent, undefined, "Should not switch for non-[1m] models");
const switchEvent = emittedEvents.find(
(e) => e.type === "fallback_provider_switch",
);
assert.equal(
switchEvent,
undefined,
"Should not switch for non-[1m] models",
);
});
});
@ -315,9 +396,19 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
handler.abortRetry();
await new Promise((resolve) => setTimeout(resolve, 10));
assert.equal(continueFn.mock.calls.length, 0, "cancelled retry must not continue after explicit abort");
const endEvents = emittedEvents.filter((e) => e.type === "auto_retry_end");
assert.equal(endEvents.length, 1, "retry cancellation should emit a single auto_retry_end event");
assert.equal(
continueFn.mock.calls.length,
0,
"cancelled retry must not continue after explicit abort",
);
const endEvents = emittedEvents.filter(
(e) => e.type === "auto_retry_end",
);
assert.equal(
endEvents.length,
1,
"retry cancellation should emit a single auto_retry_end event",
);
assert.equal(endEvents[0]?.finalError, "Retry cancelled");
});
});
@ -346,10 +437,20 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
const downgraded = setModelCalls[0][0] as Model<Api>;
assert.equal(downgraded.provider, "openrouter");
assert.equal(downgraded.id, "openai/gpt-5-pro");
assert.equal(downgraded.maxTokens, 297, "expected affordability cap with safety buffer");
assert.equal(
downgraded.maxTokens,
297,
"expected affordability cap with safety buffer",
);
assert.equal(onModelChangeFn.mock.calls.length, 1, "should notify about model update");
const switchEvent = emittedEvents.find((e) => e.type === "fallback_provider_switch");
assert.equal(
onModelChangeFn.mock.calls.length,
1,
"should notify about model update",
);
const switchEvent = emittedEvents.find(
(e) => e.type === "fallback_provider_switch",
);
assert.ok(switchEvent, "should emit model-adjustment event");
assert.ok(
String(switchEvent?.reason || "").includes("credit-aware retry"),
@ -373,7 +474,11 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
);
await handler.handleRetryableError(msg);
assert.equal(markUsageLimitReached.mock.calls.length, 0, "quota error should skip credential cooldown");
assert.equal(
markUsageLimitReached.mock.calls.length,
0,
"quota error should skip credential cooldown",
);
});
});
@ -381,7 +486,9 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
it("considers long-context entitlement error as retryable", () => {
const { deps } = createMockDeps();
const handler = new RetryHandler(deps);
const msg = errorMessage("Extra usage is required for long context requests.");
const msg = errorMessage(
"Extra usage is required for long context requests.",
);
assert.equal(handler.isRetryableError(msg), true);
});
@ -393,7 +500,7 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
const handler = new RetryHandler(deps);
const msg = errorMessage(
'All credentials for "anthropic" are in a cooldown window. ' +
'Please wait a moment and try again, or switch to a different provider.',
"Please wait a moment and try again, or switch to a different provider.",
);
assert.equal(handler.isRetryableError(msg), false);
});
@ -414,7 +521,8 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
const { deps, emittedEvents, onModelChangeFn } = createMockDeps({
model: createMockModel("anthropic", "claude-opus-4-6"),
findModelResult: (provider: string, modelId: string) => {
if (provider === "claude-code" && modelId === "claude-opus-4-6") return ccModel;
if (provider === "claude-code" && modelId === "claude-opus-4-6")
return ccModel;
return undefined;
},
});
@ -426,9 +534,14 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
const result = await handler.handleRetryableError(msg);
assert.equal(result, true, "should retry via claude-code fallback");
const switchEvent = emittedEvents.find((e) => e.type === "fallback_provider_switch");
const switchEvent = emittedEvents.find(
(e) => e.type === "fallback_provider_switch",
);
assert.ok(switchEvent, "Expected fallback_provider_switch event");
assert.ok(switchEvent!.to.startsWith("claude-code/"), "Should switch to claude-code provider");
assert.ok(
switchEvent!.to.startsWith("claude-code/"),
"Should switch to claude-code provider",
);
});
it("switches to claude-code on 'out of extra usage' error (#3772)", async () => {
@ -436,21 +549,29 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
const { deps, emittedEvents } = createMockDeps({
model: createMockModel("anthropic", "claude-opus-4-6"),
findModelResult: (provider: string, modelId: string) => {
if (provider === "claude-code" && modelId === "claude-opus-4-6") return ccModel;
if (provider === "claude-code" && modelId === "claude-opus-4-6")
return ccModel;
return undefined;
},
});
deps.isClaudeCodeReady = () => true;
const handler = new RetryHandler(deps);
const msg = errorMessage("You're out of extra usage. Add more at claude.ai/settings/usage and keep going.");
const msg = errorMessage(
"You're out of extra usage. Add more at claude.ai/settings/usage and keep going.",
);
const result = await handler.handleRetryableError(msg);
assert.equal(result, true, "should retry via claude-code fallback");
const switchEvent = emittedEvents.find((e) => e.type === "fallback_provider_switch");
const switchEvent = emittedEvents.find(
(e) => e.type === "fallback_provider_switch",
);
assert.ok(switchEvent, "Expected fallback_provider_switch event");
assert.ok(switchEvent!.to.startsWith("claude-code/"), "Should switch to claude-code provider");
assert.ok(
switchEvent!.to.startsWith("claude-code/"),
"Should switch to claude-code provider",
);
});
it("does NOT switch to claude-code when current provider is not anthropic", async () => {
@ -458,22 +579,31 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
const { deps, emittedEvents } = createMockDeps({
model: createMockModel("openai", "gpt-4o"),
findModelResult: (provider: string, modelId: string) => {
if (provider === "claude-code" && modelId === "gpt-4o") return ccModel;
if (provider === "claude-code" && modelId === "gpt-4o")
return ccModel;
return undefined;
},
});
deps.isClaudeCodeReady = () => true;
const handler = new RetryHandler(deps);
const msg = errorMessage("third-party apps are not supported for this plan");
const msg = errorMessage(
"third-party apps are not supported for this plan",
);
const result = await handler.handleRetryableError(msg);
// Should NOT have triggered the claude-code fallback
const switchEvent = emittedEvents.find(
(e) => e.type === "fallback_provider_switch" && e.to?.startsWith("claude-code/"),
(e) =>
e.type === "fallback_provider_switch" &&
e.to?.startsWith("claude-code/"),
);
assert.equal(
switchEvent,
undefined,
"Should NOT switch non-anthropic provider to claude-code",
);
assert.equal(switchEvent, undefined, "Should NOT switch non-anthropic provider to claude-code");
});
});
@ -522,16 +652,60 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
);
});
it("does NOT write credential cooldown for external CLI rate_limit errors", async () => {
const fallbackModel = createMockModel("openai", "gpt-4o");
const { deps, emittedEvents, markUsageLimitReached, findFallback } =
createMockDeps({
model: createMockModel("google-gemini-cli", "gemini-2.5-pro"),
providerAuthMode: "externalCli",
markUsageLimitReachedResult: true,
fallbackResult: {
model: fallbackModel,
reason: "cross-provider fallback",
},
});
const handler = new RetryHandler(deps);
const msg = errorMessage("429 Too Many Requests");
const result = await handler.handleRetryableError(msg);
assert.equal(
result,
true,
"external CLI 429 should still enter fallback/retry handling",
);
assert.equal(
markUsageLimitReached.mock.calls.length,
0,
"external CLI providers must not be written into .sf credential cooldown",
);
assert.equal(
findFallback.mock.calls.length,
1,
"429 should still ask for provider fallback",
);
assert.ok(
emittedEvents.some((event) => event.type === "auto_retry_start"),
"429 should still emit retry start after fallback selection",
);
});
it("still tries cross-provider fallback for quota_exhausted without credential backoff", async () => {
const fallbackModel = createMockModel("openai", "gpt-4o");
const { deps, markUsageLimitReached, continueFn } = createMockDeps({
model: createMockModel("anthropic", "claude-opus-4-6[1m]"),
markUsageLimitReachedResult: false,
fallbackResult: { model: fallbackModel, reason: "cross-provider fallback" },
fallbackResult: {
model: fallbackModel,
reason: "cross-provider fallback",
},
});
const handler = new RetryHandler(deps);
const msg = errorMessage("Extra usage is required for long context requests.");
const msg = errorMessage(
"Extra usage is required for long context requests.",
);
const result = await handler.handleRetryableError(msg);

View file

@ -12,12 +12,12 @@
import type { Agent } from "@singularity-forge/pi-agent-core";
import type { AssistantMessage, Model } from "@singularity-forge/pi-ai";
import { isContextOverflow } from "@singularity-forge/pi-ai";
import { sleep } from "../utils/sleep.js";
import type { AgentSessionEvent } from "./agent-session.js";
import type { UsageLimitErrorType } from "./auth-storage.js";
import type { FallbackResolver } from "./fallback-resolver.js";
import type { ModelRegistry } from "./model-registry.js";
import type { SettingsManager } from "./settings-manager.js";
import { sleep } from "../utils/sleep.js";
import type { AgentSessionEvent } from "./agent-session.js";
/** Dependencies injected from AgentSession into RetryHandler */
export interface RetryHandlerDeps {
@ -41,7 +41,8 @@ export class RetryHandler {
private _retryPromise: Promise<void> | undefined = undefined;
private _retryResolve: (() => void) | undefined = undefined;
private _retryGeneration = 0;
private _continueTimeout: ReturnType<typeof setTimeout> | undefined = undefined;
private _continueTimeout: ReturnType<typeof setTimeout> | undefined =
undefined;
constructor(private readonly _deps: RetryHandlerDeps) {}
@ -70,7 +71,9 @@ export class RetryHandler {
* Must be called synchronously from the agent event handler before
* any async processing, so that waitForRetry() doesn't miss in-flight retries.
*/
createRetryPromiseForAgentEnd(messages: Array<{ role: string } & Record<string, any>>): void {
createRetryPromiseForAgentEnd(
messages: Array<{ role: string } & Record<string, any>>,
): void {
if (this._retryPromise) return;
const settings = this._deps.settingsManager.getRetrySettings();
@ -162,7 +165,10 @@ export class RetryHandler {
// when provider reports "can only afford N", lower maxTokens and retry
// on the same model before rotating credentials/providers.
if (isQuotaError) {
const adjusted = this._tryAffordableMaxTokensRetry(message, retryGeneration);
const adjusted = this._tryAffordableMaxTokensRetry(
message,
retryGeneration,
);
if (adjusted) return true;
}
@ -171,12 +177,16 @@ export class RetryHandler {
// gates; rotating to another credential on the same account won't help
// and the 30-minute backoff blocks all provider requests needlessly.
if (isRateLimit) {
const provider = this._deps.getModel()!.provider;
const authMode = this._deps.modelRegistry.getProviderAuthMode(provider);
const hasAlternate =
this._deps.modelRegistry.authStorage.markUsageLimitReached(
this._deps.getModel()!.provider,
this._deps.getSessionId(),
{ errorType },
);
authMode === "externalCli" || authMode === "none"
? false
: this._deps.modelRegistry.authStorage.markUsageLimitReached(
provider,
this._deps.getSessionId(),
{ errorType },
);
if (hasAlternate) {
this._removeLastAssistantError();
@ -235,7 +245,10 @@ export class RetryHandler {
// No fallback available either
if (isQuotaError) {
// Try long-context model downgrade ([1m] → base) before giving up
const downgraded = this._tryLongContextDowngrade(message, retryGeneration);
const downgraded = this._tryLongContextDowngrade(
message,
retryGeneration,
);
if (downgraded) return true;
this._deps.emit({
@ -271,7 +284,8 @@ export class RetryHandler {
// Use server-requested delay when available, capped by maxDelayMs.
// Fall back to exponential backoff when no server hint is present.
const exponentialDelayMs = settings.baseDelayMs * 2 ** (this._retryAttempt - 1);
const exponentialDelayMs =
settings.baseDelayMs * 2 ** (this._retryAttempt - 1);
let delayMs: number;
if (message.retryAfterMs !== undefined) {
const cap = settings.maxDelayMs > 0 ? settings.maxDelayMs : Infinity;
@ -280,7 +294,8 @@ export class RetryHandler {
type: "auto_retry_end",
success: false,
attempt: this._retryAttempt - 1,
finalError: `Rate limit reset in ${Math.ceil(message.retryAfterMs / 1000)}s (max: ${Math.ceil(cap / 1000)}s). ${message.errorMessage || ""}`.trim(),
finalError:
`Rate limit reset in ${Math.ceil(message.retryAfterMs / 1000)}s (max: ${Math.ceil(cap / 1000)}s). ${message.errorMessage || ""}`.trim(),
});
this._retryAttempt = 0;
this._resolveRetry();
@ -335,9 +350,9 @@ export class RetryHandler {
/** Cancel in-progress retry */
abortRetry(): void {
const hadRetry =
this._retryPromise !== undefined
|| this._retryAbortController !== undefined
|| this._continueTimeout !== undefined;
this._retryPromise !== undefined ||
this._retryAbortController !== undefined ||
this._continueTimeout !== undefined;
if (!hadRetry) return;
const attempt = this._retryAttempt > 0 ? this._retryAttempt : 1;
@ -417,13 +432,30 @@ export class RetryHandler {
const err = errorMessage.toLowerCase();
// Long-context entitlement errors are billing gates, not transient rate limits.
// Must be checked before the generic 429/rate_limit regex.
if (/extra usage is required|long context required/i.test(err)) return "quota_exhausted";
if (/requires more credits|can only afford|insufficient credits|not enough credits|credit balance/i.test(err))
if (/extra usage is required|long context required/i.test(err))
return "quota_exhausted";
if (/quota|billing|exceeded.*limit|usage.*limit/i.test(err)) return "quota_exhausted";
if (/rate.?limit|too many requests|429|529|overloaded/i.test(err)) return "rate_limit";
if (/500|502|503|504|server.?error|internal.?error|service.?unavailable/i.test(err)) return "server_error";
if (/401|authentication.*error|invalid.*api.?key|api.?key.*invalid|api.?key.*expired|failed to authenticate|unauthorized/i.test(err)) return "auth_error";
if (
/requires more credits|can only afford|insufficient credits|not enough credits|credit balance/i.test(
err,
)
)
return "quota_exhausted";
if (/quota|billing|exceeded.*limit|usage.*limit/i.test(err))
return "quota_exhausted";
if (/rate.?limit|too many requests|429|529|overloaded/i.test(err))
return "rate_limit";
if (
/500|502|503|504|server.?error|internal.?error|service.?unavailable/i.test(
err,
)
)
return "server_error";
if (
/401|authentication.*error|invalid.*api.?key|api.?key.*invalid|api.?key.*expired|failed to authenticate|unauthorized/i.test(
err,
)
)
return "auth_error";
return "unknown";
}
@ -431,7 +463,10 @@ export class RetryHandler {
* Attempt a same-model retry by reducing maxTokens when provider reports
* an affordability cap (e.g., "can only afford 329").
*/
private _tryAffordableMaxTokensRetry(message: AssistantMessage, retryGeneration: number): boolean {
private _tryAffordableMaxTokensRetry(
message: AssistantMessage,
retryGeneration: number,
): boolean {
const currentModel = this._deps.getModel();
if (!currentModel || !message.errorMessage) return false;
@ -443,9 +478,15 @@ export class RetryHandler {
if (!Number.isFinite(affordable) || affordable <= 0) return false;
// Leave a small buffer so slight input variance doesn't immediately re-fail.
const safetyBuffer = Math.min(64, Math.max(16, Math.floor(affordable * 0.1)));
const safetyBuffer = Math.min(
64,
Math.max(16, Math.floor(affordable * 0.1)),
);
const targetMaxTokens = Math.max(64, affordable - safetyBuffer);
const downgradedMaxTokens = Math.min(currentModel.maxTokens, targetMaxTokens);
const downgradedMaxTokens = Math.min(
currentModel.maxTokens,
targetMaxTokens,
);
if (downgradedMaxTokens >= currentModel.maxTokens) return false;
const downgradedModel = {
@ -481,7 +522,10 @@ export class RetryHandler {
* base model (claude-opus-4-6) when the account lacks the long-context billing
* entitlement. Returns true if the downgrade was initiated.
*/
private _tryLongContextDowngrade(message: AssistantMessage, retryGeneration: number): boolean {
private _tryLongContextDowngrade(
message: AssistantMessage,
retryGeneration: number,
): boolean {
const currentModel = this._deps.getModel();
if (!currentModel) return false;
@ -490,7 +534,10 @@ export class RetryHandler {
if (!match) return false;
const baseModelId = match[1];
const baseModel = this._deps.modelRegistry.find(currentModel.provider, baseModelId);
const baseModel = this._deps.modelRegistry.find(
currentModel.provider,
baseModelId,
);
if (!baseModel) return false;
const previousId = currentModel.id;
@ -525,7 +572,9 @@ export class RetryHandler {
* and the "out of extra usage" variant that subscription users receive.
*/
private _isThirdPartyBlock(errorMessage: string): boolean {
return /third[- .]party.*(?:draw from extra|not.*available|plan limits|not permitted|cannot be used|not supported)|(?:out of|no) extra usage/i.test(errorMessage);
return /third[- .]party.*(?:draw from extra|not.*available|plan limits|not permitted|cannot be used|not supported)|(?:out of|no) extra usage/i.test(
errorMessage,
);
}
/**
@ -533,7 +582,10 @@ export class RetryHandler {
* Anthropic provider is blocked by the third-party policy (#3772).
* Returns true if the switch was made and retry scheduled.
*/
private _tryClaudeCodeFallback(message: AssistantMessage, retryGeneration: number): boolean {
private _tryClaudeCodeFallback(
message: AssistantMessage,
retryGeneration: number,
): boolean {
if (!this._deps.isClaudeCodeReady?.()) return false;
const currentModel = this._deps.getModel();
@ -544,7 +596,10 @@ export class RetryHandler {
if (currentModel.provider !== "anthropic") return false;
// Find the same model ID under the claude-code provider
const ccModel = this._deps.modelRegistry.find("claude-code", currentModel.id);
const ccModel = this._deps.modelRegistry.find(
"claude-code",
currentModel.id,
);
if (!ccModel) return false;
const previousProvider = currentModel.provider;
@ -556,7 +611,8 @@ export class RetryHandler {
type: "fallback_provider_switch",
from: `${previousProvider}/${currentModel.id}`,
to: `claude-code/${ccModel.id}`,
reason: "Anthropic subscription blocked for third-party apps — routing through Claude Code CLI",
reason:
"Anthropic subscription blocked for third-party apps — routing through Claude Code CLI",
});
this._deps.emit({
@ -574,7 +630,10 @@ export class RetryHandler {
/** Remove the last assistant error message from agent state */
private _removeLastAssistantError(): void {
const messages = this._deps.agent.state.messages;
if (messages.length > 0 && messages[messages.length - 1].role === "assistant") {
if (
messages.length > 0 &&
messages[messages.length - 1].role === "assistant"
) {
this._deps.agent.replaceMessages(messages.slice(0, -1));
}
}