fix(auth): use gemini cli credentials outside sf store
This commit is contained in:
parent
cb2ab66d4f
commit
3c3000c25f
4 changed files with 959 additions and 308 deletions
|
|
@ -1,11 +1,19 @@
|
|||
import assert from "node:assert/strict";
|
||||
import { describe, it } from 'vitest';
|
||||
import type { Api, Model, SimpleStreamOptions, Context, AssistantMessageEventStream } from "@singularity-forge/pi-ai";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessageEventStream,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
} from "@singularity-forge/pi-ai";
|
||||
import { getApiProvider } from "@singularity-forge/pi-ai";
|
||||
import { describe, it } from "vitest";
|
||||
import { AuthStorage, type AuthStorageData } from "./auth-storage.js";
|
||||
import { ModelRegistry } from "./model-registry.js";
|
||||
|
||||
function createRegistry(hasAuthFn?: (provider: string) => boolean): ModelRegistry {
|
||||
function createRegistry(
|
||||
hasAuthFn?: (provider: string) => boolean,
|
||||
): ModelRegistry {
|
||||
const authStorage = {
|
||||
setFallbackResolver: () => {},
|
||||
onCredentialChange: () => {},
|
||||
|
|
@ -22,7 +30,12 @@ function createInMemoryRegistry(data: AuthStorageData = {}): ModelRegistry {
|
|||
return new ModelRegistry(AuthStorage.inMemory(data), undefined);
|
||||
}
|
||||
|
||||
function createProviderModel(id: string, api?: string): NonNullable<Parameters<ModelRegistry["registerProvider"]>[1]["models"]>[number] {
|
||||
function createProviderModel(
|
||||
id: string,
|
||||
api?: string,
|
||||
): NonNullable<
|
||||
Parameters<ModelRegistry["registerProvider"]>[1]["models"]
|
||||
>[number] {
|
||||
return {
|
||||
id,
|
||||
name: id,
|
||||
|
|
@ -35,8 +48,14 @@ function createProviderModel(id: string, api?: string): NonNullable<Parameters<M
|
|||
};
|
||||
}
|
||||
|
||||
function findModel(registry: ModelRegistry, provider: string, id: string): Model<Api> | undefined {
|
||||
return registry.getAvailable().find((m) => m.provider === provider && m.id === id);
|
||||
function findModel(
|
||||
registry: ModelRegistry,
|
||||
provider: string,
|
||||
id: string,
|
||||
): Model<Api> | undefined {
|
||||
return registry
|
||||
.getAvailable()
|
||||
.find((m) => m.provider === provider && m.id === id);
|
||||
}
|
||||
|
||||
function makeModel(provider: string, id: string, api: string): Model<Api> {
|
||||
|
|
@ -62,10 +81,33 @@ function makeContext(): Context {
|
|||
}
|
||||
|
||||
/** No-op streamSimple for tests that need one to pass validation but don't inspect it. */
|
||||
const noopStreamSimple = (_model: Model<Api>, _context: Context, _options?: SimpleStreamOptions) => {
|
||||
const noopStreamSimple = (
|
||||
_model: Model<Api>,
|
||||
_context: Context,
|
||||
_options?: SimpleStreamOptions,
|
||||
) => {
|
||||
return {
|
||||
[Symbol.asyncIterator]() { return { next: async () => ({ value: undefined, done: true as const }) }; },
|
||||
result: () => Promise.resolve({ role: "assistant" as const, content: [], api: "test" as Api, provider: "test", model: "test", usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } }, stopReason: "stop" as const, timestamp: Date.now() }),
|
||||
[Symbol.asyncIterator]() {
|
||||
return { next: async () => ({ value: undefined, done: true as const }) };
|
||||
},
|
||||
result: () =>
|
||||
Promise.resolve({
|
||||
role: "assistant" as const,
|
||||
content: [],
|
||||
api: "test" as Api,
|
||||
provider: "test",
|
||||
model: "test",
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop" as const,
|
||||
timestamp: Date.now(),
|
||||
}),
|
||||
push: () => {},
|
||||
end: () => {},
|
||||
} as unknown as AssistantMessageEventStream;
|
||||
|
|
@ -73,16 +115,51 @@ const noopStreamSimple = (_model: Model<Api>, _context: Context, _options?: Simp
|
|||
|
||||
/** Create a spy streamSimple that captures the options it receives and returns a stub stream. */
|
||||
function createStreamSpy(): {
|
||||
streamSimple: (model: Model<Api>, context: Context, options?: SimpleStreamOptions) => AssistantMessageEventStream;
|
||||
streamSimple: (
|
||||
model: Model<Api>,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
) => AssistantMessageEventStream;
|
||||
getCapturedOptions: () => SimpleStreamOptions | undefined;
|
||||
} {
|
||||
let capturedOptions: SimpleStreamOptions | undefined;
|
||||
const streamSimple = (_model: Model<Api>, _context: Context, options?: SimpleStreamOptions) => {
|
||||
const streamSimple = (
|
||||
_model: Model<Api>,
|
||||
_context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
) => {
|
||||
capturedOptions = options;
|
||||
// Return a minimal stub that satisfies AssistantMessageEventStream
|
||||
return {
|
||||
[Symbol.asyncIterator]() { return { next: async () => ({ value: undefined, done: true as const }) }; },
|
||||
result: () => Promise.resolve({ role: "assistant" as const, content: [], api: "test" as Api, provider: "test", model: "test", usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } }, stopReason: "stop" as const, timestamp: Date.now() }),
|
||||
[Symbol.asyncIterator]() {
|
||||
return {
|
||||
next: async () => ({ value: undefined, done: true as const }),
|
||||
};
|
||||
},
|
||||
result: () =>
|
||||
Promise.resolve({
|
||||
role: "assistant" as const,
|
||||
content: [],
|
||||
api: "test" as Api,
|
||||
provider: "test",
|
||||
model: "test",
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
total: 0,
|
||||
},
|
||||
},
|
||||
stopReason: "stop" as const,
|
||||
timestamp: Date.now(),
|
||||
}),
|
||||
push: () => {},
|
||||
end: () => {},
|
||||
} as unknown as AssistantMessageEventStream;
|
||||
|
|
@ -123,102 +200,153 @@ describe("ModelRegistry authMode — registration", () => {
|
|||
|
||||
it("rejects apiKey provider without apiKey or oauth — message mentions authMode", () => {
|
||||
const registry = createRegistry();
|
||||
assert.throws(() => {
|
||||
registry.registerProvider("apikey-provider", {
|
||||
authMode: "apiKey",
|
||||
baseUrl: "https://api.local",
|
||||
api: "openai-completions",
|
||||
models: [createProviderModel("model")],
|
||||
});
|
||||
}, (err: Error) => {
|
||||
assert.ok(err.message.includes("authMode"), "error message must mention authMode");
|
||||
assert.ok(err.message.includes("externalCli"), "error message must suggest externalCli");
|
||||
return true;
|
||||
});
|
||||
assert.throws(
|
||||
() => {
|
||||
registry.registerProvider("apikey-provider", {
|
||||
authMode: "apiKey",
|
||||
baseUrl: "https://api.local",
|
||||
api: "openai-completions",
|
||||
models: [createProviderModel("model")],
|
||||
});
|
||||
},
|
||||
(err: Error) => {
|
||||
assert.ok(
|
||||
err.message.includes("authMode"),
|
||||
"error message must mention authMode",
|
||||
);
|
||||
assert.ok(
|
||||
err.message.includes("externalCli"),
|
||||
"error message must suggest externalCli",
|
||||
);
|
||||
return true;
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
it("rejects provider with no authMode and no apiKey/oauth (defaults to apiKey)", () => {
|
||||
const registry = createRegistry();
|
||||
assert.throws(() => {
|
||||
registry.registerProvider("bare-provider", {
|
||||
baseUrl: "https://api.local",
|
||||
api: "openai-completions",
|
||||
models: [createProviderModel("model")],
|
||||
});
|
||||
}, (err: Error) => {
|
||||
assert.ok(err.message.includes("authMode"), "error message must mention authMode");
|
||||
return true;
|
||||
});
|
||||
assert.throws(
|
||||
() => {
|
||||
registry.registerProvider("bare-provider", {
|
||||
baseUrl: "https://api.local",
|
||||
api: "openai-completions",
|
||||
models: [createProviderModel("model")],
|
||||
});
|
||||
},
|
||||
(err: Error) => {
|
||||
assert.ok(
|
||||
err.message.includes("authMode"),
|
||||
"error message must mention authMode",
|
||||
);
|
||||
return true;
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
it("rejects externalCli provider without streamSimple", () => {
|
||||
const registry = createRegistry();
|
||||
assert.throws(() => {
|
||||
registry.registerProvider("cli-no-stream", {
|
||||
authMode: "externalCli",
|
||||
baseUrl: "https://cli.local",
|
||||
api: "openai-completions",
|
||||
models: [createProviderModel("model")],
|
||||
});
|
||||
}, (err: Error) => {
|
||||
assert.ok(err.message.includes("streamSimple"), "error message must mention streamSimple");
|
||||
assert.ok(err.message.includes("externalCli"), "error message must mention authMode");
|
||||
return true;
|
||||
});
|
||||
assert.throws(
|
||||
() => {
|
||||
registry.registerProvider("cli-no-stream", {
|
||||
authMode: "externalCli",
|
||||
baseUrl: "https://cli.local",
|
||||
api: "openai-completions",
|
||||
models: [createProviderModel("model")],
|
||||
});
|
||||
},
|
||||
(err: Error) => {
|
||||
assert.ok(
|
||||
err.message.includes("streamSimple"),
|
||||
"error message must mention streamSimple",
|
||||
);
|
||||
assert.ok(
|
||||
err.message.includes("externalCli"),
|
||||
"error message must mention authMode",
|
||||
);
|
||||
return true;
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
it("rejects none provider without streamSimple", () => {
|
||||
const registry = createRegistry();
|
||||
assert.throws(() => {
|
||||
registry.registerProvider("none-no-stream", {
|
||||
authMode: "none",
|
||||
baseUrl: "http://localhost:11434",
|
||||
api: "openai-completions",
|
||||
models: [createProviderModel("model")],
|
||||
});
|
||||
}, (err: Error) => {
|
||||
assert.ok(err.message.includes("streamSimple"), "error message must mention streamSimple");
|
||||
assert.ok(err.message.includes("none"), "error message must mention authMode");
|
||||
return true;
|
||||
});
|
||||
assert.throws(
|
||||
() => {
|
||||
registry.registerProvider("none-no-stream", {
|
||||
authMode: "none",
|
||||
baseUrl: "http://localhost:11434",
|
||||
api: "openai-completions",
|
||||
models: [createProviderModel("model")],
|
||||
});
|
||||
},
|
||||
(err: Error) => {
|
||||
assert.ok(
|
||||
err.message.includes("streamSimple"),
|
||||
"error message must mention streamSimple",
|
||||
);
|
||||
assert.ok(
|
||||
err.message.includes("none"),
|
||||
"error message must mention authMode",
|
||||
);
|
||||
return true;
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
it("rejects externalCli provider that also sets apiKey", () => {
|
||||
const registry = createRegistry();
|
||||
const spy = createStreamSpy();
|
||||
assert.throws(() => {
|
||||
registry.registerProvider("cli-with-key", {
|
||||
authMode: "externalCli",
|
||||
baseUrl: "https://cli.local",
|
||||
api: "openai-completions",
|
||||
apiKey: "SHOULD_NOT_EXIST",
|
||||
streamSimple: spy.streamSimple,
|
||||
models: [createProviderModel("model")],
|
||||
});
|
||||
}, (err: Error) => {
|
||||
assert.ok(err.message.includes("apiKey"), "error message must mention apiKey");
|
||||
assert.ok(err.message.includes("externalCli"), "error message must mention authMode");
|
||||
return true;
|
||||
});
|
||||
assert.throws(
|
||||
() => {
|
||||
registry.registerProvider("cli-with-key", {
|
||||
authMode: "externalCli",
|
||||
baseUrl: "https://cli.local",
|
||||
api: "openai-completions",
|
||||
apiKey: "SHOULD_NOT_EXIST",
|
||||
streamSimple: spy.streamSimple,
|
||||
models: [createProviderModel("model")],
|
||||
});
|
||||
},
|
||||
(err: Error) => {
|
||||
assert.ok(
|
||||
err.message.includes("apiKey"),
|
||||
"error message must mention apiKey",
|
||||
);
|
||||
assert.ok(
|
||||
err.message.includes("externalCli"),
|
||||
"error message must mention authMode",
|
||||
);
|
||||
return true;
|
||||
},
|
||||
);
|
||||
});
|
||||
|
||||
it("rejects none provider that also sets apiKey", () => {
|
||||
const registry = createRegistry();
|
||||
const spy = createStreamSpy();
|
||||
assert.throws(() => {
|
||||
registry.registerProvider("none-with-key", {
|
||||
authMode: "none",
|
||||
baseUrl: "http://localhost:11434",
|
||||
api: "openai-completions",
|
||||
apiKey: "SHOULD_NOT_EXIST",
|
||||
streamSimple: spy.streamSimple,
|
||||
models: [createProviderModel("model")],
|
||||
});
|
||||
}, (err: Error) => {
|
||||
assert.ok(err.message.includes("apiKey"), "error message must mention apiKey");
|
||||
assert.ok(err.message.includes("none"), "error message must mention authMode");
|
||||
return true;
|
||||
});
|
||||
assert.throws(
|
||||
() => {
|
||||
registry.registerProvider("none-with-key", {
|
||||
authMode: "none",
|
||||
baseUrl: "http://localhost:11434",
|
||||
api: "openai-completions",
|
||||
apiKey: "SHOULD_NOT_EXIST",
|
||||
streamSimple: spy.streamSimple,
|
||||
models: [createProviderModel("model")],
|
||||
});
|
||||
},
|
||||
(err: Error) => {
|
||||
assert.ok(
|
||||
err.message.includes("apiKey"),
|
||||
"error message must mention apiKey",
|
||||
);
|
||||
assert.ok(
|
||||
err.message.includes("none"),
|
||||
"error message must mention authMode",
|
||||
);
|
||||
return true;
|
||||
},
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -230,6 +358,14 @@ describe("ModelRegistry authMode — getProviderAuthMode", () => {
|
|||
assert.equal(registry.getProviderAuthMode("anthropic"), "apiKey");
|
||||
});
|
||||
|
||||
it("treats google-gemini-cli as external CLI auth", () => {
|
||||
const registry = createRegistry();
|
||||
assert.equal(
|
||||
registry.getProviderAuthMode("google-gemini-cli"),
|
||||
"externalCli",
|
||||
);
|
||||
});
|
||||
|
||||
it("returns explicit authMode when set", () => {
|
||||
const registry = createRegistry();
|
||||
registry.registerProvider("cli", {
|
||||
|
|
@ -258,6 +394,11 @@ describe("ModelRegistry authMode — getProviderAuthMode", () => {
|
|||
// ─── isProviderRequestReady ───────────────────────────────────────────────────
|
||||
|
||||
describe("ModelRegistry authMode — isProviderRequestReady", () => {
|
||||
it("returns true for google-gemini-cli without .sf stored auth", () => {
|
||||
const registry = createRegistry(() => false);
|
||||
assert.equal(registry.isProviderRequestReady("google-gemini-cli"), true);
|
||||
});
|
||||
|
||||
it("returns true for externalCli without stored auth", () => {
|
||||
const registry = createRegistry(() => false);
|
||||
registry.registerProvider("cli", {
|
||||
|
|
@ -391,7 +532,10 @@ describe("ModelRegistry authMode — getAvailable", () => {
|
|||
it("excludes apiKey models without stored auth", () => {
|
||||
const registry = createRegistry(() => false);
|
||||
const available = registry.getAvailable();
|
||||
assert.equal(available.length, 0);
|
||||
assert.equal(
|
||||
available.filter((m) => m.provider !== "google-gemini-cli").length,
|
||||
0,
|
||||
);
|
||||
});
|
||||
|
||||
it("prunes Codex models removed from ChatGPT-backed openai-codex OAuth", () => {
|
||||
|
|
@ -407,7 +551,10 @@ describe("ModelRegistry authMode — getAvailable", () => {
|
|||
|
||||
assert.equal(registry.find("openai-codex", "gpt-5.1-codex-max"), undefined);
|
||||
assert.equal(registry.find("openai-codex", "gpt-5.1"), undefined);
|
||||
assert.equal(findModel(registry, "openai-codex", "gpt-5.2-codex"), undefined);
|
||||
assert.equal(
|
||||
findModel(registry, "openai-codex", "gpt-5.2-codex"),
|
||||
undefined,
|
||||
);
|
||||
assert.ok(registry.find("openai-codex", "gpt-5.4"));
|
||||
assert.ok(findModel(registry, "openai-codex", "gpt-5.4"));
|
||||
});
|
||||
|
|
@ -428,6 +575,22 @@ describe("ModelRegistry authMode — getAvailable", () => {
|
|||
// ─── getApiKey ────────────────────────────────────────────────────────────────
|
||||
|
||||
describe("ModelRegistry authMode — getApiKey", () => {
|
||||
it("returns undefined for google-gemini-cli even when stale .sf auth exists", async () => {
|
||||
const registry = createInMemoryRegistry({
|
||||
"google-gemini-cli": {
|
||||
type: "oauth",
|
||||
access: "",
|
||||
refresh: "",
|
||||
expires: 0,
|
||||
},
|
||||
});
|
||||
|
||||
assert.equal(
|
||||
await registry.getApiKeyForProvider("google-gemini-cli"),
|
||||
undefined,
|
||||
);
|
||||
});
|
||||
|
||||
it("returns undefined for externalCli provider", async () => {
|
||||
const registry = createRegistry();
|
||||
registry.registerProvider("cli", {
|
||||
|
|
@ -480,15 +643,18 @@ describe("ModelRegistry authMode — streamSimple apiKey boundary", () => {
|
|||
const provider = getApiProvider(apiType as Api);
|
||||
assert.ok(provider, "provider must be registered in api registry");
|
||||
|
||||
provider.streamSimple(
|
||||
makeModel("cli-strip", "m", apiType),
|
||||
makeContext(),
|
||||
{ apiKey: "should-be-stripped", maxTokens: 1024 } as SimpleStreamOptions,
|
||||
);
|
||||
provider.streamSimple(makeModel("cli-strip", "m", apiType), makeContext(), {
|
||||
apiKey: "should-be-stripped",
|
||||
maxTokens: 1024,
|
||||
} as SimpleStreamOptions);
|
||||
|
||||
const captured = spy.getCapturedOptions();
|
||||
assert.ok(captured, "streamSimple must have been called");
|
||||
assert.equal("apiKey" in captured, false, "apiKey must not exist in options for externalCli provider");
|
||||
assert.equal(
|
||||
"apiKey" in captured,
|
||||
false,
|
||||
"apiKey must not exist in options for externalCli provider",
|
||||
);
|
||||
assert.equal(captured.maxTokens, 1024, "other options must pass through");
|
||||
});
|
||||
|
||||
|
|
@ -516,7 +682,11 @@ describe("ModelRegistry authMode — streamSimple apiKey boundary", () => {
|
|||
|
||||
const captured = spy.getCapturedOptions();
|
||||
assert.ok(captured, "streamSimple must have been called");
|
||||
assert.equal("apiKey" in captured, false, "apiKey must not exist in options for none provider");
|
||||
assert.equal(
|
||||
"apiKey" in captured,
|
||||
false,
|
||||
"apiKey must not exist in options for none provider",
|
||||
);
|
||||
assert.equal(captured.maxTokens, 2048, "other options must pass through");
|
||||
});
|
||||
|
||||
|
|
@ -544,7 +714,11 @@ describe("ModelRegistry authMode — streamSimple apiKey boundary", () => {
|
|||
|
||||
const captured = spy.getCapturedOptions();
|
||||
assert.ok(captured, "streamSimple must have been called");
|
||||
assert.equal(captured.apiKey, "sk-real-key", "apiKey must be preserved for apiKey provider");
|
||||
assert.equal(
|
||||
captured.apiKey,
|
||||
"sk-real-key",
|
||||
"apiKey must be preserved for apiKey provider",
|
||||
);
|
||||
assert.equal(captured.maxTokens, 4096, "other options must pass through");
|
||||
});
|
||||
|
||||
|
|
@ -572,7 +746,11 @@ describe("ModelRegistry authMode — streamSimple apiKey boundary", () => {
|
|||
|
||||
const captured = spy.getCapturedOptions();
|
||||
assert.ok(captured !== undefined, "streamSimple must have been called");
|
||||
assert.equal("apiKey" in captured, false, "apiKey must not exist even when options is undefined");
|
||||
assert.equal(
|
||||
"apiKey" in captured,
|
||||
false,
|
||||
"apiKey must not exist even when options is undefined",
|
||||
);
|
||||
});
|
||||
|
||||
it("strips apiKey but preserves signal and other fields for externalCli", () => {
|
||||
|
|
@ -595,15 +773,28 @@ describe("ModelRegistry authMode — streamSimple apiKey boundary", () => {
|
|||
provider.streamSimple(
|
||||
makeModel("cli-fields", "m", apiType),
|
||||
makeContext(),
|
||||
{ apiKey: "strip-me", maxTokens: 8192, signal: abortController.signal, reasoning: "high" } as SimpleStreamOptions,
|
||||
{
|
||||
apiKey: "strip-me",
|
||||
maxTokens: 8192,
|
||||
signal: abortController.signal,
|
||||
reasoning: "high",
|
||||
} as SimpleStreamOptions,
|
||||
);
|
||||
|
||||
const captured = spy.getCapturedOptions();
|
||||
assert.ok(captured, "streamSimple must have been called");
|
||||
assert.equal("apiKey" in captured, false, "apiKey must be stripped");
|
||||
assert.equal(captured.maxTokens, 8192, "maxTokens must pass through");
|
||||
assert.equal(captured.signal, abortController.signal, "signal must pass through");
|
||||
assert.equal((captured as Record<string, unknown>).reasoning, "high", "reasoning must pass through");
|
||||
assert.equal(
|
||||
captured.signal,
|
||||
abortController.signal,
|
||||
"signal must pass through",
|
||||
);
|
||||
assert.equal(
|
||||
(captured as Record<string, unknown>).reasoning,
|
||||
"high",
|
||||
"reasoning must pass through",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -633,11 +824,12 @@ describe("ModelRegistry authMode — provider-scoped stream routing", () => {
|
|||
// The built-in handler will throw (no API key), which proves the routing
|
||||
// correctly delegates to the built-in instead of the custom handler.
|
||||
assert.throws(
|
||||
() => provider.streamSimple(
|
||||
makeModel("anthropic", "claude-sonnet-4-6", "anthropic-messages"),
|
||||
makeContext(),
|
||||
{ maxTokens: 4096 } as SimpleStreamOptions,
|
||||
),
|
||||
() =>
|
||||
provider.streamSimple(
|
||||
makeModel("anthropic", "claude-sonnet-4-6", "anthropic-messages"),
|
||||
makeContext(),
|
||||
{ maxTokens: 4096 } as SimpleStreamOptions,
|
||||
),
|
||||
(err: Error) => err.message.includes("API key"),
|
||||
"built-in Anthropic handler must be invoked (throws because no API key in tests)",
|
||||
);
|
||||
|
|
@ -672,7 +864,10 @@ describe("ModelRegistry authMode — provider-scoped stream routing", () => {
|
|||
);
|
||||
|
||||
const captured = customSpy.getCapturedOptions();
|
||||
assert.ok(captured, "custom provider's streamSimple must be called for its own models");
|
||||
assert.ok(
|
||||
captured,
|
||||
"custom provider's streamSimple must be called for its own models",
|
||||
);
|
||||
assert.equal(captured.maxTokens, 2048);
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -2,10 +2,11 @@
|
|||
* Model registry - manages built-in and custom models, provides API key resolution.
|
||||
*/
|
||||
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import {
|
||||
type Api,
|
||||
applyCapabilityPatches,
|
||||
type AssistantMessageEventStream,
|
||||
applyCapabilityPatches,
|
||||
type Context,
|
||||
getApiProvider,
|
||||
getEnvApiKey,
|
||||
|
|
@ -20,14 +21,17 @@ import {
|
|||
resetApiProviders,
|
||||
type SimpleStreamOptions,
|
||||
} from "@singularity-forge/pi-ai";
|
||||
import { registerOAuthProvider, resetOAuthProviders } from "@singularity-forge/pi-ai/oauth";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import {
|
||||
registerOAuthProvider,
|
||||
resetOAuthProviders,
|
||||
} from "@singularity-forge/pi-ai/oauth";
|
||||
import AjvModule from "ajv";
|
||||
import { existsSync, readFileSync } from "fs";
|
||||
import { join } from "path";
|
||||
import { getAgentDir } from "../config.js";
|
||||
import type { AuthStorage } from "./auth-storage.js";
|
||||
import { ModelDiscoveryCache } from "./discovery-cache.js";
|
||||
import { isLocalModel } from "./local-model-check.js";
|
||||
import type { DiscoveryResult } from "./model-discovery.js";
|
||||
import {
|
||||
getDiscoverableCatalogSources,
|
||||
|
|
@ -35,7 +39,6 @@ import {
|
|||
getDiscoveryAdapter,
|
||||
} from "./model-discovery.js";
|
||||
import { resolveConfigValue, resolveHeaders } from "./resolve-config-value.js";
|
||||
import { isLocalModel } from "./local-model-check.js";
|
||||
|
||||
const Ajv = (AjvModule as any).default || AjvModule;
|
||||
const ajv = new Ajv();
|
||||
|
|
@ -70,27 +73,33 @@ export const PROXY_FAMILY_PRIORITY: ReadonlyArray<{
|
|||
family_failover?: string[];
|
||||
}> = [
|
||||
// MiniMax direct (api.minimax.io) → CN endpoint as its direct pair
|
||||
{ match: /^MiniMax-/i, prefix: "MiniMax-", providers: ["minimax", "minimax-cn"] },
|
||||
{
|
||||
match: /^MiniMax-/i,
|
||||
prefix: "MiniMax-",
|
||||
providers: ["minimax", "minimax-cn"],
|
||||
},
|
||||
// ZAI direct API for GLM
|
||||
{ match: /^glm-/i, prefix: "glm-", providers: ["zai"] },
|
||||
{ match: /^glm-/i, prefix: "glm-", providers: ["zai"] },
|
||||
// Kimi Code direct API
|
||||
{ match: /^kimi-/i, prefix: "kimi-", providers: ["kimi-coding"] },
|
||||
{ match: /^kimi-/i, prefix: "kimi-", providers: ["kimi-coding"] },
|
||||
// MiMo/Xiaomi — direct API via Xiaomi MiMo Open Platform (api.xiaomimimo.com)
|
||||
// or the Token Plan endpoint (token-plan-sgp.xiaomimimo.com). Both served
|
||||
// under the `xiaomi` provider namespace.
|
||||
{ match: /^mimo-|^XiaomiMiMo\//i, prefix: "mimo-", providers: ["xiaomi"] },
|
||||
// Gemini/Gemma: google-gemini-cli (OAuth), google (API key), google-vertex
|
||||
// are all FIRST-PARTY Google endpoints. github-copilot re-serves and is
|
||||
// failover only.
|
||||
{ match: /^mimo-|^XiaomiMiMo\//i, prefix: "mimo-", providers: ["xiaomi"] },
|
||||
// Gemini/Gemma: google-gemini-cli (CLI OAuth via ~/.gemini), google
|
||||
// (API key), google-vertex are all FIRST-PARTY Google endpoints.
|
||||
// github-copilot re-serves and is failover only.
|
||||
{
|
||||
match: /^gemini-|^gemma-/i, prefix: "gemini-",
|
||||
match: /^gemini-|^gemma-/i,
|
||||
prefix: "gemini-",
|
||||
providers: ["google-gemini-cli", "google", "google-vertex"],
|
||||
family_failover: ["github-copilot"],
|
||||
},
|
||||
// Claude: Anthropic is the ONLY direct provider. github-copilot re-serves
|
||||
// Claude via GitHub's platform as failover.
|
||||
{
|
||||
match: /^claude-/i, prefix: "claude-",
|
||||
match: /^claude-/i,
|
||||
prefix: "claude-",
|
||||
providers: ["anthropic"],
|
||||
family_failover: ["github-copilot"],
|
||||
},
|
||||
|
|
@ -99,9 +108,14 @@ export const PROXY_FAMILY_PRIORITY: ReadonlyArray<{
|
|||
// the same weights via a different legal/contractual relationship).
|
||||
// github-copilot likewise re-serves.
|
||||
{
|
||||
match: /^gpt-|^o\d|^codex-/i, prefix: "gpt-",
|
||||
match: /^gpt-|^o\d|^codex-/i,
|
||||
prefix: "gpt-",
|
||||
providers: ["openai"],
|
||||
family_failover: ["azure-openai-responses", "openai-codex", "github-copilot"],
|
||||
family_failover: [
|
||||
"azure-openai-responses",
|
||||
"openai-codex",
|
||||
"github-copilot",
|
||||
],
|
||||
},
|
||||
];
|
||||
|
||||
|
|
@ -123,12 +137,23 @@ const OpenAICompletionsCompatSchema = Type.Object({
|
|||
supportsDeveloperRole: Type.Optional(Type.Boolean()),
|
||||
supportsReasoningEffort: Type.Optional(Type.Boolean()),
|
||||
supportsUsageInStreaming: Type.Optional(Type.Boolean()),
|
||||
maxTokensField: Type.Optional(Type.Union([Type.Literal("max_completion_tokens"), Type.Literal("max_tokens")])),
|
||||
maxTokensField: Type.Optional(
|
||||
Type.Union([
|
||||
Type.Literal("max_completion_tokens"),
|
||||
Type.Literal("max_tokens"),
|
||||
]),
|
||||
),
|
||||
requiresToolResultName: Type.Optional(Type.Boolean()),
|
||||
requiresAssistantAfterToolResult: Type.Optional(Type.Boolean()),
|
||||
requiresThinkingAsText: Type.Optional(Type.Boolean()),
|
||||
requiresMistralToolIds: Type.Optional(Type.Boolean()),
|
||||
thinkingFormat: Type.Optional(Type.Union([Type.Literal("openai"), Type.Literal("zai"), Type.Literal("qwen")])),
|
||||
thinkingFormat: Type.Optional(
|
||||
Type.Union([
|
||||
Type.Literal("openai"),
|
||||
Type.Literal("zai"),
|
||||
Type.Literal("qwen"),
|
||||
]),
|
||||
),
|
||||
openRouterRouting: Type.Optional(OpenRouterRoutingSchema),
|
||||
vercelGatewayRouting: Type.Optional(VercelGatewayRoutingSchema),
|
||||
});
|
||||
|
|
@ -137,7 +162,10 @@ const OpenAIResponsesCompatSchema = Type.Object({
|
|||
// Reserved for future use
|
||||
});
|
||||
|
||||
const OpenAICompatSchema = Type.Union([OpenAICompletionsCompatSchema, OpenAIResponsesCompatSchema]);
|
||||
const OpenAICompatSchema = Type.Union([
|
||||
OpenAICompletionsCompatSchema,
|
||||
OpenAIResponsesCompatSchema,
|
||||
]);
|
||||
|
||||
// Schema for custom model definition
|
||||
// Most fields are optional with sensible defaults for local models (Ollama, LM Studio, etc.)
|
||||
|
|
@ -147,7 +175,9 @@ const ModelDefinitionSchema = Type.Object({
|
|||
api: Type.Optional(Type.String({ minLength: 1 })),
|
||||
baseUrl: Type.Optional(Type.String({ minLength: 1 })),
|
||||
reasoning: Type.Optional(Type.Boolean()),
|
||||
input: Type.Optional(Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")]))),
|
||||
input: Type.Optional(
|
||||
Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
|
||||
),
|
||||
cost: Type.Optional(
|
||||
Type.Object({
|
||||
input: Type.Number(),
|
||||
|
|
@ -166,7 +196,9 @@ const ModelDefinitionSchema = Type.Object({
|
|||
const ModelOverrideSchema = Type.Object({
|
||||
name: Type.Optional(Type.String({ minLength: 1 })),
|
||||
reasoning: Type.Optional(Type.Boolean()),
|
||||
input: Type.Optional(Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")]))),
|
||||
input: Type.Optional(
|
||||
Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
|
||||
),
|
||||
cost: Type.Optional(
|
||||
Type.Object({
|
||||
input: Type.Optional(Type.Number()),
|
||||
|
|
@ -190,7 +222,9 @@ const ProviderConfigSchema = Type.Object({
|
|||
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
|
||||
authHeader: Type.Optional(Type.Boolean()),
|
||||
models: Type.Optional(Type.Array(ModelDefinitionSchema)),
|
||||
modelOverrides: Type.Optional(Type.Record(Type.String(), ModelOverrideSchema)),
|
||||
modelOverrides: Type.Optional(
|
||||
Type.Record(Type.String(), ModelOverrideSchema),
|
||||
),
|
||||
});
|
||||
|
||||
const ModelsConfigSchema = Type.Object({
|
||||
|
|
@ -205,7 +239,8 @@ export type ProviderModelAllowList = Record<string, readonly string[]>;
|
|||
|
||||
export type ProviderAuthMode = "apiKey" | "oauth" | "externalCli" | "none";
|
||||
|
||||
type ProviderPolicyModel = Pick<Model<Api>, "provider" | "id"> & Partial<Pick<Model<Api>, "name" | "cost">>;
|
||||
type ProviderPolicyModel = Pick<Model<Api>, "provider" | "id"> &
|
||||
Partial<Pick<Model<Api>, "name" | "cost">>;
|
||||
|
||||
const OPENCODE_FREE_MODEL_IDS = new Set([
|
||||
"big-pickle",
|
||||
|
|
@ -221,7 +256,12 @@ const HIDDEN_MODEL_PROVIDERS = new Set([
|
|||
"xiaomi-token-plan-sgp",
|
||||
]);
|
||||
|
||||
function providerModelAllowEntryMatches(allowedModel: string, modelKey: string): boolean {
|
||||
const BUILTIN_EXTERNAL_CLI_AUTH_PROVIDERS = new Set(["google-gemini-cli"]);
|
||||
|
||||
function providerModelAllowEntryMatches(
|
||||
allowedModel: string,
|
||||
modelKey: string,
|
||||
): boolean {
|
||||
const allowedKey = allowedModel.trim().toLowerCase();
|
||||
if (!allowedKey) return false;
|
||||
if (allowedKey === modelKey) return true;
|
||||
|
|
@ -239,7 +279,13 @@ function hasFreeSkuMarker(value: string | undefined): boolean {
|
|||
}
|
||||
|
||||
function isZeroCost(cost: Model<Api>["cost"] | undefined): boolean {
|
||||
return !!cost && cost.input === 0 && cost.output === 0 && cost.cacheRead === 0 && cost.cacheWrite === 0;
|
||||
return (
|
||||
!!cost &&
|
||||
cost.input === 0 &&
|
||||
cost.output === 0 &&
|
||||
cost.cacheRead === 0 &&
|
||||
cost.cacheWrite === 0
|
||||
);
|
||||
}
|
||||
|
||||
function isMistralSelectionModel(modelId: string): boolean {
|
||||
|
|
@ -259,7 +305,9 @@ function isMistralSelectionModel(modelId: string): boolean {
|
|||
return true;
|
||||
}
|
||||
|
||||
function isModelAllowedByBuiltInProviderPolicy(model: ProviderPolicyModel): boolean {
|
||||
function isModelAllowedByBuiltInProviderPolicy(
|
||||
model: ProviderPolicyModel,
|
||||
): boolean {
|
||||
const provider = model.provider.toLowerCase();
|
||||
const modelKey = model.id.trim().toLowerCase();
|
||||
if (HIDDEN_MODEL_PROVIDERS.has(provider)) {
|
||||
|
|
@ -309,22 +357,35 @@ function mergeCompat(
|
|||
): Model<Api>["compat"] | undefined {
|
||||
if (!overrideCompat) return baseCompat;
|
||||
|
||||
const base = baseCompat as OpenAICompletionsCompat | OpenAIResponsesCompat | undefined;
|
||||
const override = overrideCompat as OpenAICompletionsCompat | OpenAIResponsesCompat;
|
||||
const merged = { ...base, ...override } as OpenAICompletionsCompat | OpenAIResponsesCompat;
|
||||
const base = baseCompat as
|
||||
| OpenAICompletionsCompat
|
||||
| OpenAIResponsesCompat
|
||||
| undefined;
|
||||
const override = overrideCompat as
|
||||
| OpenAICompletionsCompat
|
||||
| OpenAIResponsesCompat;
|
||||
const merged = { ...base, ...override } as
|
||||
| OpenAICompletionsCompat
|
||||
| OpenAIResponsesCompat;
|
||||
|
||||
const baseCompletions = base as OpenAICompletionsCompat | undefined;
|
||||
const overrideCompletions = override as OpenAICompletionsCompat;
|
||||
const mergedCompletions = merged as OpenAICompletionsCompat;
|
||||
|
||||
if (baseCompletions?.openRouterRouting || overrideCompletions.openRouterRouting) {
|
||||
if (
|
||||
baseCompletions?.openRouterRouting ||
|
||||
overrideCompletions.openRouterRouting
|
||||
) {
|
||||
mergedCompletions.openRouterRouting = {
|
||||
...baseCompletions?.openRouterRouting,
|
||||
...overrideCompletions.openRouterRouting,
|
||||
};
|
||||
}
|
||||
|
||||
if (baseCompletions?.vercelGatewayRouting || overrideCompletions.vercelGatewayRouting) {
|
||||
if (
|
||||
baseCompletions?.vercelGatewayRouting ||
|
||||
overrideCompletions.vercelGatewayRouting
|
||||
) {
|
||||
mergedCompletions.vercelGatewayRouting = {
|
||||
...baseCompletions?.vercelGatewayRouting,
|
||||
...overrideCompletions.vercelGatewayRouting,
|
||||
|
|
@ -338,14 +399,19 @@ function mergeCompat(
|
|||
* Deep merge a model override into a model.
|
||||
* Handles nested objects (cost, compat) by merging rather than replacing.
|
||||
*/
|
||||
function applyModelOverride(model: Model<Api>, override: ModelOverride): Model<Api> {
|
||||
function applyModelOverride(
|
||||
model: Model<Api>,
|
||||
override: ModelOverride,
|
||||
): Model<Api> {
|
||||
const result = { ...model };
|
||||
|
||||
// Simple field overrides
|
||||
if (override.name !== undefined) result.name = override.name;
|
||||
if (override.reasoning !== undefined) result.reasoning = override.reasoning;
|
||||
if (override.input !== undefined) result.input = override.input as ("text" | "image")[];
|
||||
if (override.contextWindow !== undefined) result.contextWindow = override.contextWindow;
|
||||
if (override.input !== undefined)
|
||||
result.input = override.input as ("text" | "image")[];
|
||||
if (override.contextWindow !== undefined)
|
||||
result.contextWindow = override.contextWindow;
|
||||
if (override.maxTokens !== undefined) result.maxTokens = override.maxTokens;
|
||||
|
||||
// Merge cost (partial override)
|
||||
|
|
@ -361,7 +427,9 @@ function applyModelOverride(model: Model<Api>, override: ModelOverride): Model<A
|
|||
// Merge headers
|
||||
if (override.headers) {
|
||||
const resolvedHeaders = resolveHeaders(override.headers);
|
||||
result.headers = resolvedHeaders ? { ...model.headers, ...resolvedHeaders } : model.headers;
|
||||
result.headers = resolvedHeaders
|
||||
? { ...model.headers, ...resolvedHeaders }
|
||||
: model.headers;
|
||||
}
|
||||
|
||||
// Deep merge compat
|
||||
|
|
@ -370,7 +438,6 @@ function applyModelOverride(model: Model<Api>, override: ModelOverride): Model<A
|
|||
return result;
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Model registry - loads and manages models, resolves API keys via AuthStorage.
|
||||
*/
|
||||
|
|
@ -384,7 +451,10 @@ export class ModelRegistry {
|
|||
|
||||
constructor(
|
||||
readonly authStorage: AuthStorage,
|
||||
readonly modelsJsonPath: string | undefined = join(getAgentDir(), "models.json"),
|
||||
readonly modelsJsonPath: string | undefined = join(
|
||||
getAgentDir(),
|
||||
"models.json",
|
||||
),
|
||||
discoveryCache?: ModelDiscoveryCache,
|
||||
) {
|
||||
this.discoveryCache = discoveryCache ?? new ModelDiscoveryCache();
|
||||
|
|
@ -437,7 +507,9 @@ export class ModelRegistry {
|
|||
overrides,
|
||||
modelOverrides,
|
||||
error,
|
||||
} = this.modelsJsonPath ? this.loadCustomModels(this.modelsJsonPath) : emptyCustomModelsResult();
|
||||
} = this.modelsJsonPath
|
||||
? this.loadCustomModels(this.modelsJsonPath)
|
||||
: emptyCustomModelsResult();
|
||||
|
||||
if (error) {
|
||||
this.loadError = error;
|
||||
|
|
@ -480,7 +552,9 @@ export class ModelRegistry {
|
|||
model = {
|
||||
...model,
|
||||
baseUrl: providerOverride.baseUrl ?? model.baseUrl,
|
||||
headers: resolvedHeaders ? { ...model.headers, ...resolvedHeaders } : model.headers,
|
||||
headers: resolvedHeaders
|
||||
? { ...model.headers, ...resolvedHeaders }
|
||||
: model.headers,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -496,10 +570,15 @@ export class ModelRegistry {
|
|||
}
|
||||
|
||||
/** Merge custom models into built-in list by provider+id (custom wins on conflicts). */
|
||||
private mergeCustomModels(builtInModels: Model<Api>[], customModels: Model<Api>[]): Model<Api>[] {
|
||||
private mergeCustomModels(
|
||||
builtInModels: Model<Api>[],
|
||||
customModels: Model<Api>[],
|
||||
): Model<Api>[] {
|
||||
const merged = [...builtInModels];
|
||||
for (const customModel of customModels) {
|
||||
const existingIndex = merged.findIndex((m) => m.provider === customModel.provider && m.id === customModel.id);
|
||||
const existingIndex = merged.findIndex(
|
||||
(m) => m.provider === customModel.provider && m.id === customModel.id,
|
||||
);
|
||||
if (existingIndex >= 0) {
|
||||
merged[existingIndex] = customModel;
|
||||
} else {
|
||||
|
|
@ -509,22 +588,32 @@ export class ModelRegistry {
|
|||
return merged;
|
||||
}
|
||||
|
||||
private isProviderModelAllowed(model: ProviderPolicyModel, providerModelAllow?: ProviderModelAllowList): boolean {
|
||||
private isProviderModelAllowed(
|
||||
model: ProviderPolicyModel,
|
||||
providerModelAllow?: ProviderModelAllowList,
|
||||
): boolean {
|
||||
if (!isModelAllowedByBuiltInProviderPolicy(model)) return false;
|
||||
if (!providerModelAllow) return true;
|
||||
const providerKey = model.provider.toLowerCase();
|
||||
const allowedModels = providerModelAllow[providerKey]
|
||||
?? Object.entries(providerModelAllow).find(([key]) => key.toLowerCase() === providerKey)?.[1];
|
||||
const allowedModels =
|
||||
providerModelAllow[providerKey] ??
|
||||
Object.entries(providerModelAllow).find(
|
||||
([key]) => key.toLowerCase() === providerKey,
|
||||
)?.[1];
|
||||
if (allowedModels === undefined) return true;
|
||||
const modelKey = model.id.trim().toLowerCase();
|
||||
return allowedModels.some((allowedModel) => providerModelAllowEntryMatches(allowedModel, modelKey));
|
||||
return allowedModels.some((allowedModel) =>
|
||||
providerModelAllowEntryMatches(allowedModel, modelKey),
|
||||
);
|
||||
}
|
||||
|
||||
private filterProviderModelAllow<T extends Model<Api>>(
|
||||
models: T[],
|
||||
providerModelAllow?: ProviderModelAllowList,
|
||||
): T[] {
|
||||
return models.filter((model) => this.isProviderModelAllowed(model, providerModelAllow));
|
||||
return models.filter((model) =>
|
||||
this.isProviderModelAllowed(model, providerModelAllow),
|
||||
);
|
||||
}
|
||||
|
||||
private loadCustomModels(modelsJsonPath: string): CustomModelsResult {
|
||||
|
|
@ -540,9 +629,12 @@ export class ModelRegistry {
|
|||
const validate = ajv.getSchema("ModelsConfig")!;
|
||||
if (!validate(config)) {
|
||||
const errors =
|
||||
validate.errors?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`).join("\n") ||
|
||||
"Unknown schema error";
|
||||
return emptyCustomModelsResult(`Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`);
|
||||
validate.errors
|
||||
?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`)
|
||||
.join("\n") || "Unknown schema error";
|
||||
return emptyCustomModelsResult(
|
||||
`Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Additional validation
|
||||
|
|
@ -551,9 +643,15 @@ export class ModelRegistry {
|
|||
const overrides = new Map<string, ProviderOverride>();
|
||||
const modelOverrides = new Map<string, Map<string, ModelOverride>>();
|
||||
|
||||
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
||||
for (const [providerName, providerConfig] of Object.entries(
|
||||
config.providers,
|
||||
)) {
|
||||
// Apply provider-level baseUrl/headers/apiKey override to built-in models when configured.
|
||||
if (providerConfig.baseUrl || providerConfig.headers || providerConfig.apiKey) {
|
||||
if (
|
||||
providerConfig.baseUrl ||
|
||||
providerConfig.headers ||
|
||||
providerConfig.apiKey
|
||||
) {
|
||||
overrides.set(providerName, {
|
||||
baseUrl: providerConfig.baseUrl,
|
||||
headers: providerConfig.headers,
|
||||
|
|
@ -567,14 +665,24 @@ export class ModelRegistry {
|
|||
}
|
||||
|
||||
if (providerConfig.modelOverrides) {
|
||||
modelOverrides.set(providerName, new Map(Object.entries(providerConfig.modelOverrides)));
|
||||
modelOverrides.set(
|
||||
providerName,
|
||||
new Map(Object.entries(providerConfig.modelOverrides)),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return { models: this.parseModels(config), overrides, modelOverrides, error: undefined };
|
||||
return {
|
||||
models: this.parseModels(config),
|
||||
overrides,
|
||||
modelOverrides,
|
||||
error: undefined,
|
||||
};
|
||||
} catch (error) {
|
||||
if (error instanceof SyntaxError) {
|
||||
return emptyCustomModelsResult(`Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`);
|
||||
return emptyCustomModelsResult(
|
||||
`Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`,
|
||||
);
|
||||
}
|
||||
return emptyCustomModelsResult(
|
||||
`Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`,
|
||||
|
|
@ -583,24 +691,33 @@ export class ModelRegistry {
|
|||
}
|
||||
|
||||
private validateConfig(config: ModelsConfig): void {
|
||||
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
||||
for (const [providerName, providerConfig] of Object.entries(
|
||||
config.providers,
|
||||
)) {
|
||||
const hasProviderApi = !!providerConfig.api;
|
||||
const models = providerConfig.models ?? [];
|
||||
const hasModelOverrides =
|
||||
providerConfig.modelOverrides && Object.keys(providerConfig.modelOverrides).length > 0;
|
||||
providerConfig.modelOverrides &&
|
||||
Object.keys(providerConfig.modelOverrides).length > 0;
|
||||
|
||||
if (models.length === 0) {
|
||||
// Override-only config: needs baseUrl OR modelOverrides (or both)
|
||||
if (!providerConfig.baseUrl && !hasModelOverrides) {
|
||||
throw new Error(`Provider ${providerName}: must specify "baseUrl", "modelOverrides", or "models".`);
|
||||
throw new Error(
|
||||
`Provider ${providerName}: must specify "baseUrl", "modelOverrides", or "models".`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
// Custom models are merged into provider models and require endpoint + auth.
|
||||
if (!providerConfig.baseUrl) {
|
||||
throw new Error(`Provider ${providerName}: "baseUrl" is required when defining custom models.`);
|
||||
throw new Error(
|
||||
`Provider ${providerName}: "baseUrl" is required when defining custom models.`,
|
||||
);
|
||||
}
|
||||
if (!providerConfig.apiKey) {
|
||||
throw new Error(`Provider ${providerName}: "apiKey" is required when defining custom models.`);
|
||||
throw new Error(
|
||||
`Provider ${providerName}: "apiKey" is required when defining custom models.`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -613,12 +730,17 @@ export class ModelRegistry {
|
|||
);
|
||||
}
|
||||
|
||||
if (!modelDef.id) throw new Error(`Provider ${providerName}: model missing "id"`);
|
||||
if (!modelDef.id)
|
||||
throw new Error(`Provider ${providerName}: model missing "id"`);
|
||||
// Validate contextWindow/maxTokens only if provided (they have defaults)
|
||||
if (modelDef.contextWindow !== undefined && modelDef.contextWindow <= 0)
|
||||
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`);
|
||||
throw new Error(
|
||||
`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`,
|
||||
);
|
||||
if (modelDef.maxTokens !== undefined && modelDef.maxTokens <= 0)
|
||||
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`);
|
||||
throw new Error(
|
||||
`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -626,7 +748,9 @@ export class ModelRegistry {
|
|||
private parseModels(config: ModelsConfig): Model<Api>[] {
|
||||
const models: Model<Api>[] = [];
|
||||
|
||||
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
||||
for (const [providerName, providerConfig] of Object.entries(
|
||||
config.providers,
|
||||
)) {
|
||||
const modelDefs = providerConfig.models ?? [];
|
||||
if (modelDefs.length === 0) continue; // Override-only, no custom models
|
||||
|
||||
|
|
@ -655,7 +779,10 @@ export class ModelRegistry {
|
|||
// Resolve env vars and shell commands in header values
|
||||
const providerHeaders = resolveHeaders(providerConfig.headers);
|
||||
const modelHeaders = resolveHeaders(modelDef.headers);
|
||||
let headers = providerHeaders || modelHeaders ? { ...providerHeaders, ...modelHeaders } : undefined;
|
||||
let headers =
|
||||
providerHeaders || modelHeaders
|
||||
? { ...providerHeaders, ...modelHeaders }
|
||||
: undefined;
|
||||
|
||||
// If authHeader is true, add Authorization header with resolved API key
|
||||
if (providerConfig.authHeader && providerConfig.apiKey) {
|
||||
|
|
@ -667,7 +794,12 @@ export class ModelRegistry {
|
|||
|
||||
// Provider baseUrl is required when custom models are defined.
|
||||
// Individual models can override it with modelDef.baseUrl.
|
||||
const defaultCost = { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 };
|
||||
const defaultCost = {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
};
|
||||
models.push({
|
||||
id: modelDef.id,
|
||||
name: modelDef.name ?? modelDef.id,
|
||||
|
|
@ -712,6 +844,7 @@ export class ModelRegistry {
|
|||
* Defaults to "apiKey" for built-ins and providers without explicit mode.
|
||||
*/
|
||||
getProviderAuthMode(provider: string): ProviderAuthMode {
|
||||
if (BUILTIN_EXTERNAL_CLI_AUTH_PROVIDERS.has(provider)) return "externalCli";
|
||||
const config = this.registeredProviders.get(provider);
|
||||
if (!config) return "apiKey";
|
||||
if (config.authMode) return config.authMode;
|
||||
|
|
@ -745,10 +878,15 @@ export class ModelRegistry {
|
|||
* Returns undefined for externalCli/none providers (no key needed).
|
||||
* @param sessionId - Optional session ID for sticky credential selection
|
||||
*/
|
||||
async getApiKey(model: Model<Api>, sessionId?: string): Promise<string | undefined> {
|
||||
async getApiKey(
|
||||
model: Model<Api>,
|
||||
sessionId?: string,
|
||||
): Promise<string | undefined> {
|
||||
const authMode = this.getProviderAuthMode(model.provider);
|
||||
if (authMode === "externalCli" || authMode === "none") return undefined;
|
||||
return this.authStorage.getApiKey(model.provider, sessionId, { baseUrl: model.baseUrl });
|
||||
return this.authStorage.getApiKey(model.provider, sessionId, {
|
||||
baseUrl: model.baseUrl,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -756,7 +894,10 @@ export class ModelRegistry {
|
|||
* Returns undefined for externalCli/none providers (no key needed).
|
||||
* @param sessionId - Optional session ID for sticky credential selection
|
||||
*/
|
||||
async getApiKeyForProvider(provider: string, sessionId?: string): Promise<string | undefined> {
|
||||
async getApiKeyForProvider(
|
||||
provider: string,
|
||||
sessionId?: string,
|
||||
): Promise<string | undefined> {
|
||||
const authMode = this.getProviderAuthMode(provider);
|
||||
if (authMode === "externalCli" || authMode === "none") return undefined;
|
||||
return this.authStorage.getApiKey(provider, sessionId);
|
||||
|
|
@ -798,7 +939,10 @@ export class ModelRegistry {
|
|||
this.refresh();
|
||||
}
|
||||
|
||||
private applyProviderConfig(providerName: string, config: ProviderConfigInput): void {
|
||||
private applyProviderConfig(
|
||||
providerName: string,
|
||||
config: ProviderConfigInput,
|
||||
): void {
|
||||
// Register OAuth provider if provided
|
||||
if (config.oauth) {
|
||||
// Ensure the OAuth provider ID matches the provider name
|
||||
|
|
@ -811,19 +955,30 @@ export class ModelRegistry {
|
|||
|
||||
if (config.streamSimple) {
|
||||
if (!config.api) {
|
||||
throw new Error(`Provider ${providerName}: "api" is required when registering streamSimple.`);
|
||||
throw new Error(
|
||||
`Provider ${providerName}: "api" is required when registering streamSimple.`,
|
||||
);
|
||||
}
|
||||
const rawStreamSimple = config.streamSimple;
|
||||
const authMode = config.authMode ?? "apiKey";
|
||||
|
||||
// Keyless providers never see apiKey in options — enforced at registration,
|
||||
// not by convention. Prevents undefined from reaching any handler.
|
||||
const streamSimple = (authMode === "externalCli" || authMode === "none")
|
||||
? ((model: Model<Api>, context: Context, options?: SimpleStreamOptions) => {
|
||||
const { apiKey: _, ...opts } = options ?? {};
|
||||
return rawStreamSimple(model, context, opts as SimpleStreamOptions);
|
||||
})
|
||||
: rawStreamSimple;
|
||||
const streamSimple =
|
||||
authMode === "externalCli" || authMode === "none"
|
||||
? (
|
||||
model: Model<Api>,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
) => {
|
||||
const { apiKey: _, ...opts } = options ?? {};
|
||||
return rawStreamSimple(
|
||||
model,
|
||||
context,
|
||||
opts as SimpleStreamOptions,
|
||||
);
|
||||
}
|
||||
: rawStreamSimple;
|
||||
|
||||
// Guard: if there's already a handler registered for this API, wrap
|
||||
// the new one so it only fires for models from this provider and
|
||||
|
|
@ -832,7 +987,11 @@ export class ModelRegistry {
|
|||
// the built-in Anthropic stream handler (#2536).
|
||||
const existingProvider = getApiProvider(config.api as Api);
|
||||
const scopedStream = existingProvider
|
||||
? (model: Model<Api>, context: Context, options?: SimpleStreamOptions): AssistantMessageEventStream => {
|
||||
? (
|
||||
model: Model<Api>,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
if (model.provider === providerName) {
|
||||
return streamSimple(model, context, options);
|
||||
}
|
||||
|
|
@ -840,12 +999,23 @@ export class ModelRegistry {
|
|||
}
|
||||
: streamSimple;
|
||||
|
||||
const newFullStream = (model: Model<Api>, context: Context, options?: SimpleStreamOptions) =>
|
||||
scopedStream(model, context, options as SimpleStreamOptions);
|
||||
const newFullStream = (
|
||||
model: Model<Api>,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
) => scopedStream(model, context, options as SimpleStreamOptions);
|
||||
const scopedFullStream = existingProvider
|
||||
? (model: Model<Api>, context: Context, options?: Record<string, unknown>) => {
|
||||
? (
|
||||
model: Model<Api>,
|
||||
context: Context,
|
||||
options?: Record<string, unknown>,
|
||||
) => {
|
||||
if (model.provider === providerName) {
|
||||
return newFullStream(model, context, options as SimpleStreamOptions);
|
||||
return newFullStream(
|
||||
model,
|
||||
context,
|
||||
options as SimpleStreamOptions,
|
||||
);
|
||||
}
|
||||
return existingProvider.stream(model, context, options);
|
||||
}
|
||||
|
|
@ -872,25 +1042,35 @@ export class ModelRegistry {
|
|||
|
||||
// Validate required fields
|
||||
if (!config.baseUrl) {
|
||||
throw new Error(`Provider ${providerName}: "baseUrl" is required when defining models.`);
|
||||
throw new Error(
|
||||
`Provider ${providerName}: "baseUrl" is required when defining models.`,
|
||||
);
|
||||
}
|
||||
const authMode = config.authMode ?? (config.oauth ? "oauth" : config.apiKey ? "apiKey" : "apiKey");
|
||||
const authMode =
|
||||
config.authMode ??
|
||||
(config.oauth ? "oauth" : config.apiKey ? "apiKey" : "apiKey");
|
||||
if (authMode === "apiKey" && !config.apiKey && !config.oauth) {
|
||||
throw new Error(
|
||||
`Provider ${providerName}: "apiKey" or "oauth" is required when authMode is "apiKey" (the default). ` +
|
||||
`Set authMode to "externalCli" or "none" for keyless providers.`,
|
||||
`Set authMode to "externalCli" or "none" for keyless providers.`,
|
||||
);
|
||||
}
|
||||
if ((authMode === "externalCli" || authMode === "none") && !config.streamSimple) {
|
||||
if (
|
||||
(authMode === "externalCli" || authMode === "none") &&
|
||||
!config.streamSimple
|
||||
) {
|
||||
throw new Error(
|
||||
`Provider ${providerName}: "streamSimple" is required when authMode is "${authMode}". ` +
|
||||
`Keyless providers must supply their own stream handler.`,
|
||||
`Keyless providers must supply their own stream handler.`,
|
||||
);
|
||||
}
|
||||
if ((authMode === "externalCli" || authMode === "none") && config.apiKey) {
|
||||
if (
|
||||
(authMode === "externalCli" || authMode === "none") &&
|
||||
config.apiKey
|
||||
) {
|
||||
throw new Error(
|
||||
`Provider ${providerName}: "apiKey" cannot be set when authMode is "${authMode}". ` +
|
||||
`Keyless providers should not provide API key credentials.`,
|
||||
`Keyless providers should not provide API key credentials.`,
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -898,13 +1078,18 @@ export class ModelRegistry {
|
|||
for (const modelDef of config.models) {
|
||||
const api = modelDef.api || config.api;
|
||||
if (!api) {
|
||||
throw new Error(`Provider ${providerName}, model ${modelDef.id}: no "api" specified.`);
|
||||
throw new Error(
|
||||
`Provider ${providerName}, model ${modelDef.id}: no "api" specified.`,
|
||||
);
|
||||
}
|
||||
|
||||
// Merge headers
|
||||
const providerHeaders = resolveHeaders(config.headers);
|
||||
const modelHeaders = resolveHeaders(modelDef.headers);
|
||||
let headers = providerHeaders || modelHeaders ? { ...providerHeaders, ...modelHeaders } : undefined;
|
||||
let headers =
|
||||
providerHeaders || modelHeaders
|
||||
? { ...providerHeaders, ...modelHeaders }
|
||||
: undefined;
|
||||
|
||||
// If authHeader is true, add Authorization header
|
||||
if (config.authHeader && config.apiKey) {
|
||||
|
|
@ -949,15 +1134,24 @@ export class ModelRegistry {
|
|||
return {
|
||||
...m,
|
||||
baseUrl: config.baseUrl ?? m.baseUrl,
|
||||
headers: resolvedHeaders ? { ...m.headers, ...resolvedHeaders } : m.headers,
|
||||
headers: resolvedHeaders
|
||||
? { ...m.headers, ...resolvedHeaders }
|
||||
: m.headers,
|
||||
};
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
private buildCandidateOrder(modelId: string, overrides: Record<string, string[]>): string[] {
|
||||
const overrideEntry = Object.entries(overrides).find(([k]) => modelId.startsWith(k));
|
||||
const familyEntry = PROXY_FAMILY_PRIORITY.find((r) => r.match.test(modelId));
|
||||
private buildCandidateOrder(
|
||||
modelId: string,
|
||||
overrides: Record<string, string[]>,
|
||||
): string[] {
|
||||
const overrideEntry = Object.entries(overrides).find(([k]) =>
|
||||
modelId.startsWith(k),
|
||||
);
|
||||
const familyEntry = PROXY_FAMILY_PRIORITY.find((r) =>
|
||||
r.match.test(modelId),
|
||||
);
|
||||
// Order: direct family providers → family-scoped failover → global fallback.
|
||||
// Overrides replace only the direct list (keeps family_failover + global
|
||||
// chain intact) so a user pinning "glm- → [zai]" still picks up
|
||||
|
|
@ -992,8 +1186,12 @@ export class ModelRegistry {
|
|||
return (pa === -1 ? Infinity : pa) - (pb === -1 ? Infinity : pb);
|
||||
});
|
||||
|
||||
const withAuth = sorted.filter((m) => this.isProviderRequestReady(m.provider));
|
||||
const withoutAuth = sorted.filter((m) => !this.isProviderRequestReady(m.provider));
|
||||
const withAuth = sorted.filter((m) =>
|
||||
this.isProviderRequestReady(m.provider),
|
||||
);
|
||||
const withoutAuth = sorted.filter(
|
||||
(m) => !this.isProviderRequestReady(m.provider),
|
||||
);
|
||||
return [...withAuth, ...withoutAuth];
|
||||
}
|
||||
|
||||
|
|
@ -1064,7 +1262,9 @@ export class ModelRegistry {
|
|||
}
|
||||
|
||||
// Convert and merge discovered models, then apply capability patches
|
||||
this.discoveredModels = applyCapabilityPatches(this.convertDiscoveredModels(results));
|
||||
this.discoveredModels = applyCapabilityPatches(
|
||||
this.convertDiscoveredModels(results),
|
||||
);
|
||||
return results;
|
||||
}
|
||||
|
||||
|
|
@ -1082,8 +1282,12 @@ export class ModelRegistry {
|
|||
* Discovered models are appended but never override existing models.
|
||||
*/
|
||||
getAllWithDiscovered(): Model<Api>[] {
|
||||
const existingIds = new Set(this.models.map((m) => `${m.provider}/${m.id}`));
|
||||
const unique = this.discoveredModels.filter((m) => !existingIds.has(`${m.provider}/${m.id}`));
|
||||
const existingIds = new Set(
|
||||
this.models.map((m) => `${m.provider}/${m.id}`),
|
||||
);
|
||||
const unique = this.discoveredModels.filter(
|
||||
(m) => !existingIds.has(`${m.provider}/${m.id}`),
|
||||
);
|
||||
return this.filterProviderModelAllow([...this.models, ...unique]);
|
||||
}
|
||||
|
||||
|
|
@ -1094,15 +1298,22 @@ export class ModelRegistry {
|
|||
* providers with the provider's actual `/models` response.
|
||||
* Consumer: cli/list-models.ts when `--discover` or an exact provider query is used.
|
||||
*/
|
||||
getDiscoveredModels(providerModelAllow?: ProviderModelAllowList): Model<Api>[] {
|
||||
return this.filterProviderModelAllow(this.discoveredModels, providerModelAllow);
|
||||
getDiscoveredModels(
|
||||
providerModelAllow?: ProviderModelAllowList,
|
||||
): Model<Api>[] {
|
||||
return this.filterProviderModelAllow(
|
||||
this.discoveredModels,
|
||||
providerModelAllow,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a model was added via discovery (not built-in or custom).
|
||||
*/
|
||||
isDiscovered(model: Model<Api>): boolean {
|
||||
return this.discoveredModels.some((m) => m.provider === model.provider && m.id === model.id);
|
||||
return this.discoveredModels.some(
|
||||
(m) => m.provider === model.provider && m.id === model.id,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -1125,7 +1336,9 @@ export class ModelRegistry {
|
|||
const key = `${provider}/${dm.id}`;
|
||||
if (seen.has(key)) continue;
|
||||
seen.add(key);
|
||||
const known = this.models.find((m) => m.provider === provider && m.id === dm.id);
|
||||
const known = this.models.find(
|
||||
(m) => m.provider === provider && m.id === dm.id,
|
||||
);
|
||||
const discoveredName =
|
||||
dm.name && dm.name !== dm.id ? dm.name : undefined;
|
||||
converted.push({
|
||||
|
|
@ -1137,7 +1350,8 @@ export class ModelRegistry {
|
|||
baseUrl: dm.baseUrl ?? known?.baseUrl ?? "",
|
||||
reasoning: dm.reasoning ?? known?.reasoning ?? false,
|
||||
input: dm.input ?? known?.input ?? ["text"],
|
||||
cost: dm.cost ?? known?.cost ?? { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
cost: dm.cost ??
|
||||
known?.cost ?? { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: dm.contextWindow ?? known?.contextWindow ?? 128000,
|
||||
maxTokens: dm.maxTokens ?? known?.maxTokens ?? 16384,
|
||||
} as Model<Api>);
|
||||
|
|
@ -1177,7 +1391,11 @@ export interface ProviderConfigInput {
|
|||
baseUrl?: string;
|
||||
apiKey?: string;
|
||||
api?: Api;
|
||||
streamSimple?: (model: Model<Api>, context: Context, options?: SimpleStreamOptions) => AssistantMessageEventStream;
|
||||
streamSimple?: (
|
||||
model: Model<Api>,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
) => AssistantMessageEventStream;
|
||||
headers?: Record<string, string>;
|
||||
authHeader?: boolean;
|
||||
/** OAuth provider for /login support */
|
||||
|
|
@ -1189,7 +1407,12 @@ export interface ProviderConfigInput {
|
|||
baseUrl?: string;
|
||||
reasoning: boolean;
|
||||
input: ("text" | "image")[];
|
||||
cost: { input: number; output: number; cacheRead: number; cacheWrite: number };
|
||||
cost: {
|
||||
input: number;
|
||||
output: number;
|
||||
cacheRead: number;
|
||||
cacheWrite: number;
|
||||
};
|
||||
contextWindow: number;
|
||||
maxTokens: number;
|
||||
headers?: Record<string, string>;
|
||||
|
|
|
|||
|
|
@ -6,12 +6,12 @@
|
|||
* downgrade from [1m] to base when no cross-provider fallback exists.
|
||||
*/
|
||||
|
||||
import { vi, describe, it, beforeEach, type Mock } from 'vitest';
|
||||
import assert from "node:assert/strict";
|
||||
import { RetryHandler, type RetryHandlerDeps } from "./retry-handler.js";
|
||||
import type { Api, AssistantMessage, Model } from "@singularity-forge/pi-ai";
|
||||
import { beforeEach, describe, it, type Mock, vi } from "vitest";
|
||||
import type { FallbackResolver } from "./fallback-resolver.js";
|
||||
import type { ModelRegistry } from "./model-registry.js";
|
||||
import { RetryHandler, type RetryHandlerDeps } from "./retry-handler.js";
|
||||
import type { SettingsManager } from "./settings-manager.js";
|
||||
|
||||
// ─── Helpers ────────────────────────────────────────────────────────────────
|
||||
|
|
@ -38,7 +38,14 @@ function errorMessage(msg: string): AssistantMessage {
|
|||
api: "anthropic-messages",
|
||||
provider: "anthropic",
|
||||
model: "claude-opus-4-6[1m]",
|
||||
usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, totalTokens: 0, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "error",
|
||||
errorMessage: msg,
|
||||
timestamp: Date.now(),
|
||||
|
|
@ -52,7 +59,9 @@ interface MockDeps {
|
|||
onModelChangeFn: Mock<(model: Model<any>) => void>;
|
||||
markUsageLimitReached: Mock<(...args: any[]) => boolean>;
|
||||
findFallback: Mock<(...args: any[]) => Promise<any>>;
|
||||
findModel: Mock<(provider: string, modelId: string) => Model<Api> | undefined>;
|
||||
findModel: Mock<
|
||||
(provider: string, modelId: string) => Model<Api> | undefined
|
||||
>;
|
||||
}
|
||||
|
||||
function createMockDeps(overrides?: {
|
||||
|
|
@ -60,14 +69,19 @@ function createMockDeps(overrides?: {
|
|||
retryEnabled?: boolean;
|
||||
markUsageLimitReachedResult?: boolean;
|
||||
fallbackResult?: any;
|
||||
findModelResult?: (provider: string, modelId: string) => Model<Api> | undefined;
|
||||
findModelResult?: (
|
||||
provider: string,
|
||||
modelId: string,
|
||||
) => Model<Api> | undefined;
|
||||
providerAuthMode?: "apiKey" | "oauth" | "externalCli" | "none";
|
||||
retrySettings?: {
|
||||
maxRetries?: number;
|
||||
baseDelayMs?: number;
|
||||
maxDelayMs?: number;
|
||||
};
|
||||
}): MockDeps {
|
||||
const model = overrides?.model ?? createMockModel("anthropic", "claude-opus-4-6[1m]");
|
||||
const model =
|
||||
overrides?.model ?? createMockModel("anthropic", "claude-opus-4-6[1m]");
|
||||
const emittedEvents: Array<Record<string, any>> = [];
|
||||
const continueFn = vi.fn(async () => {});
|
||||
const onModelChangeFn = vi.fn((_model: Model<any>) => {});
|
||||
|
|
@ -76,7 +90,8 @@ function createMockDeps(overrides?: {
|
|||
);
|
||||
const findFallback = vi.fn(async () => overrides?.fallbackResult ?? null);
|
||||
const findModel = vi.fn(
|
||||
overrides?.findModelResult ?? ((_provider: string, _modelId: string) => undefined),
|
||||
overrides?.findModelResult ??
|
||||
((_provider: string, _modelId: string) => undefined),
|
||||
);
|
||||
|
||||
const messages: Array<{ role: string } & Record<string, any>> = [];
|
||||
|
|
@ -105,6 +120,7 @@ function createMockDeps(overrides?: {
|
|||
markUsageLimitReached,
|
||||
},
|
||||
find: findModel,
|
||||
getProviderAuthMode: () => overrides?.providerAuthMode ?? "apiKey",
|
||||
} as unknown as ModelRegistry,
|
||||
fallbackResolver: {
|
||||
findFallback,
|
||||
|
|
@ -115,13 +131,20 @@ function createMockDeps(overrides?: {
|
|||
onModelChange: onModelChangeFn,
|
||||
};
|
||||
|
||||
return { deps, emittedEvents, continueFn, onModelChangeFn, markUsageLimitReached, findFallback, findModel };
|
||||
return {
|
||||
deps,
|
||||
emittedEvents,
|
||||
continueFn,
|
||||
onModelChangeFn,
|
||||
markUsageLimitReached,
|
||||
findFallback,
|
||||
findModel,
|
||||
};
|
||||
}
|
||||
|
||||
// ─── _classifyErrorType (tested via handleRetryableError behavior) ──────────
|
||||
|
||||
describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
|
||||
|
||||
describe("error classification", () => {
|
||||
it("classifies 'Extra usage is required for long context requests' as quota_exhausted, not rate_limit", async () => {
|
||||
// When the error is classified as quota_exhausted AND no alternate credentials
|
||||
|
|
@ -136,7 +159,7 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
|
|||
|
||||
const handler = new RetryHandler(deps);
|
||||
const msg = errorMessage(
|
||||
'429 {"type":"error","error":{"type":"rate_limit_error","message":"Extra usage is required for long context requests."}}'
|
||||
'429 {"type":"error","error":{"type":"rate_limit_error","message":"Extra usage is required for long context requests."}}',
|
||||
);
|
||||
|
||||
const result = await handler.handleRetryableError(msg);
|
||||
|
|
@ -145,11 +168,22 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
|
|||
assert.equal(result, false);
|
||||
|
||||
// Should emit fallback_chain_exhausted (quota_exhausted path), NOT auto_retry_start (backoff path)
|
||||
const chainExhausted = emittedEvents.find((e) => e.type === "fallback_chain_exhausted");
|
||||
assert.ok(chainExhausted, "Expected fallback_chain_exhausted event for entitlement error");
|
||||
const chainExhausted = emittedEvents.find(
|
||||
(e) => e.type === "fallback_chain_exhausted",
|
||||
);
|
||||
assert.ok(
|
||||
chainExhausted,
|
||||
"Expected fallback_chain_exhausted event for entitlement error",
|
||||
);
|
||||
|
||||
const retryStart = emittedEvents.find((e) => e.type === "auto_retry_start");
|
||||
assert.equal(retryStart, undefined, "Should NOT emit auto_retry_start for entitlement error");
|
||||
const retryStart = emittedEvents.find(
|
||||
(e) => e.type === "auto_retry_start",
|
||||
);
|
||||
assert.equal(
|
||||
retryStart,
|
||||
undefined,
|
||||
"Should NOT emit auto_retry_start for entitlement error",
|
||||
);
|
||||
});
|
||||
|
||||
it("still classifies regular 429 rate limits as rate_limit", async () => {
|
||||
|
|
@ -168,7 +202,9 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
|
|||
// Should enter the backoff loop (rate_limit path, not quota_exhausted)
|
||||
assert.equal(result, true);
|
||||
|
||||
const retryStart = emittedEvents.find((e) => e.type === "auto_retry_start");
|
||||
const retryStart = emittedEvents.find(
|
||||
(e) => e.type === "auto_retry_start",
|
||||
);
|
||||
assert.ok(retryStart, "Regular 429 should enter backoff retry");
|
||||
});
|
||||
|
||||
|
|
@ -184,7 +220,7 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
|
|||
|
||||
const handler = new RetryHandler(deps);
|
||||
const msg = errorMessage(
|
||||
'529 {"type":"error","error":{"type":"overloaded_error","message":"The server cluster is currently under high load. Please retry after a short wait and thank you for your patience. (2064) (529)"},"request_id":"062e76f8f25cd919caa3af4baaa49203"}'
|
||||
'529 {"type":"error","error":{"type":"overloaded_error","message":"The server cluster is currently under high load. Please retry after a short wait and thank you for your patience. (2064) (529)"},"request_id":"062e76f8f25cd919caa3af4baaa49203"}',
|
||||
);
|
||||
|
||||
const result = await handler.handleRetryableError(msg);
|
||||
|
|
@ -192,11 +228,18 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
|
|||
// Should enter the backoff loop (rate_limit path, not quota_exhausted)
|
||||
assert.equal(result, true);
|
||||
|
||||
const retryStart = emittedEvents.find((e) => e.type === "auto_retry_start");
|
||||
assert.ok(retryStart, "529 overloaded_error should enter backoff retry as rate_limit");
|
||||
const retryStart = emittedEvents.find(
|
||||
(e) => e.type === "auto_retry_start",
|
||||
);
|
||||
assert.ok(
|
||||
retryStart,
|
||||
"529 overloaded_error should enter backoff retry as rate_limit",
|
||||
);
|
||||
|
||||
// Must NOT be treated as quota_exhausted (would emit fallback_chain_exhausted)
|
||||
const chainExhausted = emittedEvents.find((e) => e.type === "fallback_chain_exhausted");
|
||||
const chainExhausted = emittedEvents.find(
|
||||
(e) => e.type === "fallback_chain_exhausted",
|
||||
);
|
||||
assert.equal(
|
||||
chainExhausted,
|
||||
undefined,
|
||||
|
|
@ -218,27 +261,40 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
|
|||
|
||||
const result = await handler.handleRetryableError(msg);
|
||||
|
||||
assert.equal(result, true, "affordability error should trigger credit-aware retry");
|
||||
const retryStart = emittedEvents.find((e) => e.type === "auto_retry_start");
|
||||
assert.ok(retryStart, "Expected immediate retry after reducing max tokens");
|
||||
assert.equal(
|
||||
result,
|
||||
true,
|
||||
"affordability error should trigger credit-aware retry",
|
||||
);
|
||||
const retryStart = emittedEvents.find(
|
||||
(e) => e.type === "auto_retry_start",
|
||||
);
|
||||
assert.ok(
|
||||
retryStart,
|
||||
"Expected immediate retry after reducing max tokens",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("long-context model downgrade", () => {
|
||||
it("downgrades from [1m] to base model when entitlement error and no fallback", async () => {
|
||||
const baseModel = createMockModel("anthropic", "claude-opus-4-6");
|
||||
const { deps, emittedEvents, onModelChangeFn, continueFn } = createMockDeps({
|
||||
model: createMockModel("anthropic", "claude-opus-4-6[1m]"),
|
||||
markUsageLimitReachedResult: false,
|
||||
fallbackResult: null,
|
||||
findModelResult: (provider: string, modelId: string) => {
|
||||
if (provider === "anthropic" && modelId === "claude-opus-4-6") return baseModel;
|
||||
return undefined;
|
||||
},
|
||||
});
|
||||
const { deps, emittedEvents, onModelChangeFn, continueFn } =
|
||||
createMockDeps({
|
||||
model: createMockModel("anthropic", "claude-opus-4-6[1m]"),
|
||||
markUsageLimitReachedResult: false,
|
||||
fallbackResult: null,
|
||||
findModelResult: (provider: string, modelId: string) => {
|
||||
if (provider === "anthropic" && modelId === "claude-opus-4-6")
|
||||
return baseModel;
|
||||
return undefined;
|
||||
},
|
||||
});
|
||||
|
||||
const handler = new RetryHandler(deps);
|
||||
const msg = errorMessage("Extra usage is required for long context requests.");
|
||||
const msg = errorMessage(
|
||||
"Extra usage is required for long context requests.",
|
||||
);
|
||||
|
||||
const result = await handler.handleRetryableError(msg);
|
||||
|
||||
|
|
@ -253,9 +309,17 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
|
|||
assert.equal(onModelChangeFn.mock.calls.length, 1);
|
||||
|
||||
// Should emit a fallback_provider_switch event indicating downgrade
|
||||
const switchEvent = emittedEvents.find((e) => e.type === "fallback_provider_switch");
|
||||
assert.ok(switchEvent, "Expected fallback_provider_switch event for downgrade");
|
||||
assert.ok(switchEvent!.reason.includes("long context downgrade"), `reason should mention downgrade: ${switchEvent!.reason}`);
|
||||
const switchEvent = emittedEvents.find(
|
||||
(e) => e.type === "fallback_provider_switch",
|
||||
);
|
||||
assert.ok(
|
||||
switchEvent,
|
||||
"Expected fallback_provider_switch event for downgrade",
|
||||
);
|
||||
assert.ok(
|
||||
switchEvent!.reason.includes("long context downgrade"),
|
||||
`reason should mention downgrade: ${switchEvent!.reason}`,
|
||||
);
|
||||
});
|
||||
|
||||
it("emits fallback_chain_exhausted when base model is also unavailable", async () => {
|
||||
|
|
@ -267,13 +331,20 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
|
|||
});
|
||||
|
||||
const handler = new RetryHandler(deps);
|
||||
const msg = errorMessage("Extra usage is required for long context requests.");
|
||||
const msg = errorMessage(
|
||||
"Extra usage is required for long context requests.",
|
||||
);
|
||||
|
||||
const result = await handler.handleRetryableError(msg);
|
||||
|
||||
assert.equal(result, false);
|
||||
const chainExhausted = emittedEvents.find((e) => e.type === "fallback_chain_exhausted");
|
||||
assert.ok(chainExhausted, "Expected fallback_chain_exhausted when base model unavailable");
|
||||
const chainExhausted = emittedEvents.find(
|
||||
(e) => e.type === "fallback_chain_exhausted",
|
||||
);
|
||||
assert.ok(
|
||||
chainExhausted,
|
||||
"Expected fallback_chain_exhausted when base model unavailable",
|
||||
);
|
||||
});
|
||||
|
||||
it("does not attempt downgrade for non-[1m] models", async () => {
|
||||
|
|
@ -286,17 +357,27 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
|
|||
});
|
||||
|
||||
const handler = new RetryHandler(deps);
|
||||
const msg = errorMessage("Extra usage is required for long context requests.");
|
||||
const msg = errorMessage(
|
||||
"Extra usage is required for long context requests.",
|
||||
);
|
||||
|
||||
const result = await handler.handleRetryableError(msg);
|
||||
|
||||
assert.equal(result, false);
|
||||
const chainExhausted = emittedEvents.find((e) => e.type === "fallback_chain_exhausted");
|
||||
const chainExhausted = emittedEvents.find(
|
||||
(e) => e.type === "fallback_chain_exhausted",
|
||||
);
|
||||
assert.ok(chainExhausted);
|
||||
|
||||
// No downgrade switch should occur
|
||||
const switchEvent = emittedEvents.find((e) => e.type === "fallback_provider_switch");
|
||||
assert.equal(switchEvent, undefined, "Should not switch for non-[1m] models");
|
||||
const switchEvent = emittedEvents.find(
|
||||
(e) => e.type === "fallback_provider_switch",
|
||||
);
|
||||
assert.equal(
|
||||
switchEvent,
|
||||
undefined,
|
||||
"Should not switch for non-[1m] models",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -315,9 +396,19 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
|
|||
handler.abortRetry();
|
||||
await new Promise((resolve) => setTimeout(resolve, 10));
|
||||
|
||||
assert.equal(continueFn.mock.calls.length, 0, "cancelled retry must not continue after explicit abort");
|
||||
const endEvents = emittedEvents.filter((e) => e.type === "auto_retry_end");
|
||||
assert.equal(endEvents.length, 1, "retry cancellation should emit a single auto_retry_end event");
|
||||
assert.equal(
|
||||
continueFn.mock.calls.length,
|
||||
0,
|
||||
"cancelled retry must not continue after explicit abort",
|
||||
);
|
||||
const endEvents = emittedEvents.filter(
|
||||
(e) => e.type === "auto_retry_end",
|
||||
);
|
||||
assert.equal(
|
||||
endEvents.length,
|
||||
1,
|
||||
"retry cancellation should emit a single auto_retry_end event",
|
||||
);
|
||||
assert.equal(endEvents[0]?.finalError, "Retry cancelled");
|
||||
});
|
||||
});
|
||||
|
|
@ -346,10 +437,20 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
|
|||
const downgraded = setModelCalls[0][0] as Model<Api>;
|
||||
assert.equal(downgraded.provider, "openrouter");
|
||||
assert.equal(downgraded.id, "openai/gpt-5-pro");
|
||||
assert.equal(downgraded.maxTokens, 297, "expected affordability cap with safety buffer");
|
||||
assert.equal(
|
||||
downgraded.maxTokens,
|
||||
297,
|
||||
"expected affordability cap with safety buffer",
|
||||
);
|
||||
|
||||
assert.equal(onModelChangeFn.mock.calls.length, 1, "should notify about model update");
|
||||
const switchEvent = emittedEvents.find((e) => e.type === "fallback_provider_switch");
|
||||
assert.equal(
|
||||
onModelChangeFn.mock.calls.length,
|
||||
1,
|
||||
"should notify about model update",
|
||||
);
|
||||
const switchEvent = emittedEvents.find(
|
||||
(e) => e.type === "fallback_provider_switch",
|
||||
);
|
||||
assert.ok(switchEvent, "should emit model-adjustment event");
|
||||
assert.ok(
|
||||
String(switchEvent?.reason || "").includes("credit-aware retry"),
|
||||
|
|
@ -373,7 +474,11 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
|
|||
);
|
||||
|
||||
await handler.handleRetryableError(msg);
|
||||
assert.equal(markUsageLimitReached.mock.calls.length, 0, "quota error should skip credential cooldown");
|
||||
assert.equal(
|
||||
markUsageLimitReached.mock.calls.length,
|
||||
0,
|
||||
"quota error should skip credential cooldown",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -381,7 +486,9 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
|
|||
it("considers long-context entitlement error as retryable", () => {
|
||||
const { deps } = createMockDeps();
|
||||
const handler = new RetryHandler(deps);
|
||||
const msg = errorMessage("Extra usage is required for long context requests.");
|
||||
const msg = errorMessage(
|
||||
"Extra usage is required for long context requests.",
|
||||
);
|
||||
assert.equal(handler.isRetryableError(msg), true);
|
||||
});
|
||||
|
||||
|
|
@ -393,7 +500,7 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
|
|||
const handler = new RetryHandler(deps);
|
||||
const msg = errorMessage(
|
||||
'All credentials for "anthropic" are in a cooldown window. ' +
|
||||
'Please wait a moment and try again, or switch to a different provider.',
|
||||
"Please wait a moment and try again, or switch to a different provider.",
|
||||
);
|
||||
assert.equal(handler.isRetryableError(msg), false);
|
||||
});
|
||||
|
|
@ -414,7 +521,8 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
|
|||
const { deps, emittedEvents, onModelChangeFn } = createMockDeps({
|
||||
model: createMockModel("anthropic", "claude-opus-4-6"),
|
||||
findModelResult: (provider: string, modelId: string) => {
|
||||
if (provider === "claude-code" && modelId === "claude-opus-4-6") return ccModel;
|
||||
if (provider === "claude-code" && modelId === "claude-opus-4-6")
|
||||
return ccModel;
|
||||
return undefined;
|
||||
},
|
||||
});
|
||||
|
|
@ -426,9 +534,14 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
|
|||
const result = await handler.handleRetryableError(msg);
|
||||
|
||||
assert.equal(result, true, "should retry via claude-code fallback");
|
||||
const switchEvent = emittedEvents.find((e) => e.type === "fallback_provider_switch");
|
||||
const switchEvent = emittedEvents.find(
|
||||
(e) => e.type === "fallback_provider_switch",
|
||||
);
|
||||
assert.ok(switchEvent, "Expected fallback_provider_switch event");
|
||||
assert.ok(switchEvent!.to.startsWith("claude-code/"), "Should switch to claude-code provider");
|
||||
assert.ok(
|
||||
switchEvent!.to.startsWith("claude-code/"),
|
||||
"Should switch to claude-code provider",
|
||||
);
|
||||
});
|
||||
|
||||
it("switches to claude-code on 'out of extra usage' error (#3772)", async () => {
|
||||
|
|
@ -436,21 +549,29 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
|
|||
const { deps, emittedEvents } = createMockDeps({
|
||||
model: createMockModel("anthropic", "claude-opus-4-6"),
|
||||
findModelResult: (provider: string, modelId: string) => {
|
||||
if (provider === "claude-code" && modelId === "claude-opus-4-6") return ccModel;
|
||||
if (provider === "claude-code" && modelId === "claude-opus-4-6")
|
||||
return ccModel;
|
||||
return undefined;
|
||||
},
|
||||
});
|
||||
deps.isClaudeCodeReady = () => true;
|
||||
|
||||
const handler = new RetryHandler(deps);
|
||||
const msg = errorMessage("You're out of extra usage. Add more at claude.ai/settings/usage and keep going.");
|
||||
const msg = errorMessage(
|
||||
"You're out of extra usage. Add more at claude.ai/settings/usage and keep going.",
|
||||
);
|
||||
|
||||
const result = await handler.handleRetryableError(msg);
|
||||
|
||||
assert.equal(result, true, "should retry via claude-code fallback");
|
||||
const switchEvent = emittedEvents.find((e) => e.type === "fallback_provider_switch");
|
||||
const switchEvent = emittedEvents.find(
|
||||
(e) => e.type === "fallback_provider_switch",
|
||||
);
|
||||
assert.ok(switchEvent, "Expected fallback_provider_switch event");
|
||||
assert.ok(switchEvent!.to.startsWith("claude-code/"), "Should switch to claude-code provider");
|
||||
assert.ok(
|
||||
switchEvent!.to.startsWith("claude-code/"),
|
||||
"Should switch to claude-code provider",
|
||||
);
|
||||
});
|
||||
|
||||
it("does NOT switch to claude-code when current provider is not anthropic", async () => {
|
||||
|
|
@ -458,22 +579,31 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
|
|||
const { deps, emittedEvents } = createMockDeps({
|
||||
model: createMockModel("openai", "gpt-4o"),
|
||||
findModelResult: (provider: string, modelId: string) => {
|
||||
if (provider === "claude-code" && modelId === "gpt-4o") return ccModel;
|
||||
if (provider === "claude-code" && modelId === "gpt-4o")
|
||||
return ccModel;
|
||||
return undefined;
|
||||
},
|
||||
});
|
||||
deps.isClaudeCodeReady = () => true;
|
||||
|
||||
const handler = new RetryHandler(deps);
|
||||
const msg = errorMessage("third-party apps are not supported for this plan");
|
||||
const msg = errorMessage(
|
||||
"third-party apps are not supported for this plan",
|
||||
);
|
||||
|
||||
const result = await handler.handleRetryableError(msg);
|
||||
|
||||
// Should NOT have triggered the claude-code fallback
|
||||
const switchEvent = emittedEvents.find(
|
||||
(e) => e.type === "fallback_provider_switch" && e.to?.startsWith("claude-code/"),
|
||||
(e) =>
|
||||
e.type === "fallback_provider_switch" &&
|
||||
e.to?.startsWith("claude-code/"),
|
||||
);
|
||||
assert.equal(
|
||||
switchEvent,
|
||||
undefined,
|
||||
"Should NOT switch non-anthropic provider to claude-code",
|
||||
);
|
||||
assert.equal(switchEvent, undefined, "Should NOT switch non-anthropic provider to claude-code");
|
||||
});
|
||||
});
|
||||
|
||||
|
|
@ -522,16 +652,60 @@ describe("RetryHandler — long-context entitlement 429 (#2803)", () => {
|
|||
);
|
||||
});
|
||||
|
||||
it("does NOT write credential cooldown for external CLI rate_limit errors", async () => {
|
||||
const fallbackModel = createMockModel("openai", "gpt-4o");
|
||||
const { deps, emittedEvents, markUsageLimitReached, findFallback } =
|
||||
createMockDeps({
|
||||
model: createMockModel("google-gemini-cli", "gemini-2.5-pro"),
|
||||
providerAuthMode: "externalCli",
|
||||
markUsageLimitReachedResult: true,
|
||||
fallbackResult: {
|
||||
model: fallbackModel,
|
||||
reason: "cross-provider fallback",
|
||||
},
|
||||
});
|
||||
|
||||
const handler = new RetryHandler(deps);
|
||||
const msg = errorMessage("429 Too Many Requests");
|
||||
|
||||
const result = await handler.handleRetryableError(msg);
|
||||
|
||||
assert.equal(
|
||||
result,
|
||||
true,
|
||||
"external CLI 429 should still enter fallback/retry handling",
|
||||
);
|
||||
assert.equal(
|
||||
markUsageLimitReached.mock.calls.length,
|
||||
0,
|
||||
"external CLI providers must not be written into .sf credential cooldown",
|
||||
);
|
||||
assert.equal(
|
||||
findFallback.mock.calls.length,
|
||||
1,
|
||||
"429 should still ask for provider fallback",
|
||||
);
|
||||
assert.ok(
|
||||
emittedEvents.some((event) => event.type === "auto_retry_start"),
|
||||
"429 should still emit retry start after fallback selection",
|
||||
);
|
||||
});
|
||||
|
||||
it("still tries cross-provider fallback for quota_exhausted without credential backoff", async () => {
|
||||
const fallbackModel = createMockModel("openai", "gpt-4o");
|
||||
const { deps, markUsageLimitReached, continueFn } = createMockDeps({
|
||||
model: createMockModel("anthropic", "claude-opus-4-6[1m]"),
|
||||
markUsageLimitReachedResult: false,
|
||||
fallbackResult: { model: fallbackModel, reason: "cross-provider fallback" },
|
||||
fallbackResult: {
|
||||
model: fallbackModel,
|
||||
reason: "cross-provider fallback",
|
||||
},
|
||||
});
|
||||
|
||||
const handler = new RetryHandler(deps);
|
||||
const msg = errorMessage("Extra usage is required for long context requests.");
|
||||
const msg = errorMessage(
|
||||
"Extra usage is required for long context requests.",
|
||||
);
|
||||
|
||||
const result = await handler.handleRetryableError(msg);
|
||||
|
||||
|
|
|
|||
|
|
@ -12,12 +12,12 @@
|
|||
import type { Agent } from "@singularity-forge/pi-agent-core";
|
||||
import type { AssistantMessage, Model } from "@singularity-forge/pi-ai";
|
||||
import { isContextOverflow } from "@singularity-forge/pi-ai";
|
||||
import { sleep } from "../utils/sleep.js";
|
||||
import type { AgentSessionEvent } from "./agent-session.js";
|
||||
import type { UsageLimitErrorType } from "./auth-storage.js";
|
||||
import type { FallbackResolver } from "./fallback-resolver.js";
|
||||
import type { ModelRegistry } from "./model-registry.js";
|
||||
import type { SettingsManager } from "./settings-manager.js";
|
||||
import { sleep } from "../utils/sleep.js";
|
||||
import type { AgentSessionEvent } from "./agent-session.js";
|
||||
|
||||
/** Dependencies injected from AgentSession into RetryHandler */
|
||||
export interface RetryHandlerDeps {
|
||||
|
|
@ -41,7 +41,8 @@ export class RetryHandler {
|
|||
private _retryPromise: Promise<void> | undefined = undefined;
|
||||
private _retryResolve: (() => void) | undefined = undefined;
|
||||
private _retryGeneration = 0;
|
||||
private _continueTimeout: ReturnType<typeof setTimeout> | undefined = undefined;
|
||||
private _continueTimeout: ReturnType<typeof setTimeout> | undefined =
|
||||
undefined;
|
||||
|
||||
constructor(private readonly _deps: RetryHandlerDeps) {}
|
||||
|
||||
|
|
@ -70,7 +71,9 @@ export class RetryHandler {
|
|||
* Must be called synchronously from the agent event handler before
|
||||
* any async processing, so that waitForRetry() doesn't miss in-flight retries.
|
||||
*/
|
||||
createRetryPromiseForAgentEnd(messages: Array<{ role: string } & Record<string, any>>): void {
|
||||
createRetryPromiseForAgentEnd(
|
||||
messages: Array<{ role: string } & Record<string, any>>,
|
||||
): void {
|
||||
if (this._retryPromise) return;
|
||||
|
||||
const settings = this._deps.settingsManager.getRetrySettings();
|
||||
|
|
@ -162,7 +165,10 @@ export class RetryHandler {
|
|||
// when provider reports "can only afford N", lower maxTokens and retry
|
||||
// on the same model before rotating credentials/providers.
|
||||
if (isQuotaError) {
|
||||
const adjusted = this._tryAffordableMaxTokensRetry(message, retryGeneration);
|
||||
const adjusted = this._tryAffordableMaxTokensRetry(
|
||||
message,
|
||||
retryGeneration,
|
||||
);
|
||||
if (adjusted) return true;
|
||||
}
|
||||
|
||||
|
|
@ -171,12 +177,16 @@ export class RetryHandler {
|
|||
// gates; rotating to another credential on the same account won't help
|
||||
// and the 30-minute backoff blocks all provider requests needlessly.
|
||||
if (isRateLimit) {
|
||||
const provider = this._deps.getModel()!.provider;
|
||||
const authMode = this._deps.modelRegistry.getProviderAuthMode(provider);
|
||||
const hasAlternate =
|
||||
this._deps.modelRegistry.authStorage.markUsageLimitReached(
|
||||
this._deps.getModel()!.provider,
|
||||
this._deps.getSessionId(),
|
||||
{ errorType },
|
||||
);
|
||||
authMode === "externalCli" || authMode === "none"
|
||||
? false
|
||||
: this._deps.modelRegistry.authStorage.markUsageLimitReached(
|
||||
provider,
|
||||
this._deps.getSessionId(),
|
||||
{ errorType },
|
||||
);
|
||||
|
||||
if (hasAlternate) {
|
||||
this._removeLastAssistantError();
|
||||
|
|
@ -235,7 +245,10 @@ export class RetryHandler {
|
|||
// No fallback available either
|
||||
if (isQuotaError) {
|
||||
// Try long-context model downgrade ([1m] → base) before giving up
|
||||
const downgraded = this._tryLongContextDowngrade(message, retryGeneration);
|
||||
const downgraded = this._tryLongContextDowngrade(
|
||||
message,
|
||||
retryGeneration,
|
||||
);
|
||||
if (downgraded) return true;
|
||||
|
||||
this._deps.emit({
|
||||
|
|
@ -271,7 +284,8 @@ export class RetryHandler {
|
|||
|
||||
// Use server-requested delay when available, capped by maxDelayMs.
|
||||
// Fall back to exponential backoff when no server hint is present.
|
||||
const exponentialDelayMs = settings.baseDelayMs * 2 ** (this._retryAttempt - 1);
|
||||
const exponentialDelayMs =
|
||||
settings.baseDelayMs * 2 ** (this._retryAttempt - 1);
|
||||
let delayMs: number;
|
||||
if (message.retryAfterMs !== undefined) {
|
||||
const cap = settings.maxDelayMs > 0 ? settings.maxDelayMs : Infinity;
|
||||
|
|
@ -280,7 +294,8 @@ export class RetryHandler {
|
|||
type: "auto_retry_end",
|
||||
success: false,
|
||||
attempt: this._retryAttempt - 1,
|
||||
finalError: `Rate limit reset in ${Math.ceil(message.retryAfterMs / 1000)}s (max: ${Math.ceil(cap / 1000)}s). ${message.errorMessage || ""}`.trim(),
|
||||
finalError:
|
||||
`Rate limit reset in ${Math.ceil(message.retryAfterMs / 1000)}s (max: ${Math.ceil(cap / 1000)}s). ${message.errorMessage || ""}`.trim(),
|
||||
});
|
||||
this._retryAttempt = 0;
|
||||
this._resolveRetry();
|
||||
|
|
@ -335,9 +350,9 @@ export class RetryHandler {
|
|||
/** Cancel in-progress retry */
|
||||
abortRetry(): void {
|
||||
const hadRetry =
|
||||
this._retryPromise !== undefined
|
||||
|| this._retryAbortController !== undefined
|
||||
|| this._continueTimeout !== undefined;
|
||||
this._retryPromise !== undefined ||
|
||||
this._retryAbortController !== undefined ||
|
||||
this._continueTimeout !== undefined;
|
||||
if (!hadRetry) return;
|
||||
|
||||
const attempt = this._retryAttempt > 0 ? this._retryAttempt : 1;
|
||||
|
|
@ -417,13 +432,30 @@ export class RetryHandler {
|
|||
const err = errorMessage.toLowerCase();
|
||||
// Long-context entitlement errors are billing gates, not transient rate limits.
|
||||
// Must be checked before the generic 429/rate_limit regex.
|
||||
if (/extra usage is required|long context required/i.test(err)) return "quota_exhausted";
|
||||
if (/requires more credits|can only afford|insufficient credits|not enough credits|credit balance/i.test(err))
|
||||
if (/extra usage is required|long context required/i.test(err))
|
||||
return "quota_exhausted";
|
||||
if (/quota|billing|exceeded.*limit|usage.*limit/i.test(err)) return "quota_exhausted";
|
||||
if (/rate.?limit|too many requests|429|529|overloaded/i.test(err)) return "rate_limit";
|
||||
if (/500|502|503|504|server.?error|internal.?error|service.?unavailable/i.test(err)) return "server_error";
|
||||
if (/401|authentication.*error|invalid.*api.?key|api.?key.*invalid|api.?key.*expired|failed to authenticate|unauthorized/i.test(err)) return "auth_error";
|
||||
if (
|
||||
/requires more credits|can only afford|insufficient credits|not enough credits|credit balance/i.test(
|
||||
err,
|
||||
)
|
||||
)
|
||||
return "quota_exhausted";
|
||||
if (/quota|billing|exceeded.*limit|usage.*limit/i.test(err))
|
||||
return "quota_exhausted";
|
||||
if (/rate.?limit|too many requests|429|529|overloaded/i.test(err))
|
||||
return "rate_limit";
|
||||
if (
|
||||
/500|502|503|504|server.?error|internal.?error|service.?unavailable/i.test(
|
||||
err,
|
||||
)
|
||||
)
|
||||
return "server_error";
|
||||
if (
|
||||
/401|authentication.*error|invalid.*api.?key|api.?key.*invalid|api.?key.*expired|failed to authenticate|unauthorized/i.test(
|
||||
err,
|
||||
)
|
||||
)
|
||||
return "auth_error";
|
||||
return "unknown";
|
||||
}
|
||||
|
||||
|
|
@ -431,7 +463,10 @@ export class RetryHandler {
|
|||
* Attempt a same-model retry by reducing maxTokens when provider reports
|
||||
* an affordability cap (e.g., "can only afford 329").
|
||||
*/
|
||||
private _tryAffordableMaxTokensRetry(message: AssistantMessage, retryGeneration: number): boolean {
|
||||
private _tryAffordableMaxTokensRetry(
|
||||
message: AssistantMessage,
|
||||
retryGeneration: number,
|
||||
): boolean {
|
||||
const currentModel = this._deps.getModel();
|
||||
if (!currentModel || !message.errorMessage) return false;
|
||||
|
||||
|
|
@ -443,9 +478,15 @@ export class RetryHandler {
|
|||
if (!Number.isFinite(affordable) || affordable <= 0) return false;
|
||||
|
||||
// Leave a small buffer so slight input variance doesn't immediately re-fail.
|
||||
const safetyBuffer = Math.min(64, Math.max(16, Math.floor(affordable * 0.1)));
|
||||
const safetyBuffer = Math.min(
|
||||
64,
|
||||
Math.max(16, Math.floor(affordable * 0.1)),
|
||||
);
|
||||
const targetMaxTokens = Math.max(64, affordable - safetyBuffer);
|
||||
const downgradedMaxTokens = Math.min(currentModel.maxTokens, targetMaxTokens);
|
||||
const downgradedMaxTokens = Math.min(
|
||||
currentModel.maxTokens,
|
||||
targetMaxTokens,
|
||||
);
|
||||
if (downgradedMaxTokens >= currentModel.maxTokens) return false;
|
||||
|
||||
const downgradedModel = {
|
||||
|
|
@ -481,7 +522,10 @@ export class RetryHandler {
|
|||
* base model (claude-opus-4-6) when the account lacks the long-context billing
|
||||
* entitlement. Returns true if the downgrade was initiated.
|
||||
*/
|
||||
private _tryLongContextDowngrade(message: AssistantMessage, retryGeneration: number): boolean {
|
||||
private _tryLongContextDowngrade(
|
||||
message: AssistantMessage,
|
||||
retryGeneration: number,
|
||||
): boolean {
|
||||
const currentModel = this._deps.getModel();
|
||||
if (!currentModel) return false;
|
||||
|
||||
|
|
@ -490,7 +534,10 @@ export class RetryHandler {
|
|||
if (!match) return false;
|
||||
|
||||
const baseModelId = match[1];
|
||||
const baseModel = this._deps.modelRegistry.find(currentModel.provider, baseModelId);
|
||||
const baseModel = this._deps.modelRegistry.find(
|
||||
currentModel.provider,
|
||||
baseModelId,
|
||||
);
|
||||
if (!baseModel) return false;
|
||||
|
||||
const previousId = currentModel.id;
|
||||
|
|
@ -525,7 +572,9 @@ export class RetryHandler {
|
|||
* and the "out of extra usage" variant that subscription users receive.
|
||||
*/
|
||||
private _isThirdPartyBlock(errorMessage: string): boolean {
|
||||
return /third[- .]party.*(?:draw from extra|not.*available|plan limits|not permitted|cannot be used|not supported)|(?:out of|no) extra usage/i.test(errorMessage);
|
||||
return /third[- .]party.*(?:draw from extra|not.*available|plan limits|not permitted|cannot be used|not supported)|(?:out of|no) extra usage/i.test(
|
||||
errorMessage,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -533,7 +582,10 @@ export class RetryHandler {
|
|||
* Anthropic provider is blocked by the third-party policy (#3772).
|
||||
* Returns true if the switch was made and retry scheduled.
|
||||
*/
|
||||
private _tryClaudeCodeFallback(message: AssistantMessage, retryGeneration: number): boolean {
|
||||
private _tryClaudeCodeFallback(
|
||||
message: AssistantMessage,
|
||||
retryGeneration: number,
|
||||
): boolean {
|
||||
if (!this._deps.isClaudeCodeReady?.()) return false;
|
||||
|
||||
const currentModel = this._deps.getModel();
|
||||
|
|
@ -544,7 +596,10 @@ export class RetryHandler {
|
|||
if (currentModel.provider !== "anthropic") return false;
|
||||
|
||||
// Find the same model ID under the claude-code provider
|
||||
const ccModel = this._deps.modelRegistry.find("claude-code", currentModel.id);
|
||||
const ccModel = this._deps.modelRegistry.find(
|
||||
"claude-code",
|
||||
currentModel.id,
|
||||
);
|
||||
if (!ccModel) return false;
|
||||
|
||||
const previousProvider = currentModel.provider;
|
||||
|
|
@ -556,7 +611,8 @@ export class RetryHandler {
|
|||
type: "fallback_provider_switch",
|
||||
from: `${previousProvider}/${currentModel.id}`,
|
||||
to: `claude-code/${ccModel.id}`,
|
||||
reason: "Anthropic subscription blocked for third-party apps — routing through Claude Code CLI",
|
||||
reason:
|
||||
"Anthropic subscription blocked for third-party apps — routing through Claude Code CLI",
|
||||
});
|
||||
|
||||
this._deps.emit({
|
||||
|
|
@ -574,7 +630,10 @@ export class RetryHandler {
|
|||
/** Remove the last assistant error message from agent state */
|
||||
private _removeLastAssistantError(): void {
|
||||
const messages = this._deps.agent.state.messages;
|
||||
if (messages.length > 0 && messages[messages.length - 1].role === "assistant") {
|
||||
if (
|
||||
messages.length > 0 &&
|
||||
messages[messages.length - 1].role === "assistant"
|
||||
) {
|
||||
this._deps.agent.replaceMessages(messages.slice(0, -1));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue