singularity-forge/packages/pi-coding-agent/src/core/model-registry.ts
2026-05-05 16:42:36 +02:00

1426 lines
44 KiB
TypeScript

/**
* Model registry - manages built-in and custom models, provides API key resolution.
*/
import { existsSync, readFileSync } from "node:fs";
import { join } from "node:path";
import { type Static, Type } from "@sinclair/typebox";
import {
type Api,
type AssistantMessageEventStream,
applyCapabilityPatches,
type Context,
getApiProvider,
getEnvApiKey,
getModels,
getProviders,
type KnownProvider,
type Model,
type OAuthProviderInterface,
type OpenAICompletionsCompat,
type OpenAIResponsesCompat,
registerApiProvider,
resetApiProviders,
type SimpleStreamOptions,
} from "@singularity-forge/pi-ai";
import {
registerOAuthProvider,
resetOAuthProviders,
} from "@singularity-forge/pi-ai/oauth";
import AjvModule from "ajv";
import { getAgentDir } from "../config.js";
import type { AuthStorage } from "./auth-storage.js";
import { ModelDiscoveryCache } from "./discovery-cache.js";
import { isLocalModel } from "./local-model-check.js";
import type { DiscoveryResult } from "./model-discovery.js";
import {
getDiscoverableCatalogSources,
getDiscoverableProviders,
getDiscoveryAdapter,
} from "./model-discovery.js";
import { resolveConfigValue, resolveHeaders } from "./resolve-config-value.js";
const Ajv = (AjvModule as any).default || AjvModule;
const ajv = new Ajv();
// ── Proxy provider priority ──────────────────────────────────────────────────
/** Global fallback chain appended after every family's direct providers. */
export const GLOBAL_PROVIDER_FALLBACK: readonly string[] = [
"opencode",
"opencode-go",
"openrouter",
"ollama-cloud",
];
/**
* Per-family direct-provider priority. Each entry lists only the preferred
* direct providers for that family. GLOBAL_PROVIDER_FALLBACK is always
* appended after these when building the effective resolution order.
*/
export const PROXY_FAMILY_PRIORITY: ReadonlyArray<{
match: RegExp;
/** Canonical key used when matching settings.proxy.providerPriority overrides */
prefix: string;
/** True direct providers — the vendor or first-party endpoint. Tried first. */
providers: string[];
/**
* Family-scoped failover providers — re-servers/proxies that serve this
* family but aren't the vendor. Tried AFTER direct providers but BEFORE
* the generic GLOBAL_PROVIDER_FALLBACK. Kept separate so the config is
* honest about which endpoints are "native" vs "via intermediary".
*/
family_failover?: string[];
/** Disable generic fallback for families that must stay on one provider. */
global_fallback?: boolean;
}> = [
// MiniMax direct (api.minimax.io) → CN endpoint as its direct pair
{
match: /^MiniMax-/i,
prefix: "MiniMax-",
providers: ["minimax", "minimax-cn"],
},
// ZAI direct API for GLM
{ match: /^glm-/i, prefix: "glm-", providers: ["zai"] },
// Kimi Code direct API
{ match: /^kimi-/i, prefix: "kimi-", providers: ["kimi-coding"] },
// MiMo/Xiaomi — direct API via Xiaomi MiMo Open Platform (api.xiaomimimo.com)
// or the Token Plan endpoint (token-plan-sgp.xiaomimimo.com). Both served
// under the `xiaomi` provider namespace.
{ match: /^mimo-|^XiaomiMiMo\//i, prefix: "mimo-", providers: ["xiaomi"] },
// Gemini/Gemma: route bare model IDs through google-gemini-cli only.
// Direct GenAI and Vertex providers stay explicit provider-qualified routes,
// but they are hidden from normal SF/TUI selection and default fallback.
{
match: /^gemini-|^gemma-/i,
prefix: "gemini-",
providers: ["google-gemini-cli"],
global_fallback: false,
},
// Claude: Anthropic is the default provider. Copilot is disabled.
{
match: /^claude-/i,
prefix: "claude-",
providers: ["anthropic"],
},
// GPT / o-series / codex: OpenAI is direct. azure-openai-responses is
// Microsoft's re-serving of OpenAI weights — treated as failover. Copilot is disabled.
{
match: /^gpt-|^o\d|^codex-/i,
prefix: "gpt-",
providers: ["openai"],
family_failover: ["azure-openai-responses", "openai-codex"],
},
];
// ── Schema for OpenRouter routing preferences
const OpenRouterRoutingSchema = Type.Object({
only: Type.Optional(Type.Array(Type.String())),
order: Type.Optional(Type.Array(Type.String())),
});
// Schema for Vercel AI Gateway routing preferences
const VercelGatewayRoutingSchema = Type.Object({
only: Type.Optional(Type.Array(Type.String())),
order: Type.Optional(Type.Array(Type.String())),
});
// Schema for OpenAI compatibility settings
const OpenAICompletionsCompatSchema = Type.Object({
supportsStore: Type.Optional(Type.Boolean()),
supportsDeveloperRole: Type.Optional(Type.Boolean()),
supportsReasoningEffort: Type.Optional(Type.Boolean()),
supportsUsageInStreaming: Type.Optional(Type.Boolean()),
maxTokensField: Type.Optional(
Type.Union([
Type.Literal("max_completion_tokens"),
Type.Literal("max_tokens"),
]),
),
requiresToolResultName: Type.Optional(Type.Boolean()),
requiresAssistantAfterToolResult: Type.Optional(Type.Boolean()),
requiresThinkingAsText: Type.Optional(Type.Boolean()),
requiresMistralToolIds: Type.Optional(Type.Boolean()),
thinkingFormat: Type.Optional(
Type.Union([
Type.Literal("openai"),
Type.Literal("zai"),
Type.Literal("qwen"),
]),
),
openRouterRouting: Type.Optional(OpenRouterRoutingSchema),
vercelGatewayRouting: Type.Optional(VercelGatewayRoutingSchema),
});
const OpenAIResponsesCompatSchema = Type.Object({
// Reserved for future use
});
const OpenAICompatSchema = Type.Union([
OpenAICompletionsCompatSchema,
OpenAIResponsesCompatSchema,
]);
// Schema for custom model definition
// Most fields are optional with sensible defaults for local models (Ollama, LM Studio, etc.)
const ModelDefinitionSchema = Type.Object({
id: Type.String({ minLength: 1 }),
name: Type.Optional(Type.String({ minLength: 1 })),
api: Type.Optional(Type.String({ minLength: 1 })),
baseUrl: Type.Optional(Type.String({ minLength: 1 })),
reasoning: Type.Optional(Type.Boolean()),
input: Type.Optional(
Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
),
cost: Type.Optional(
Type.Object({
input: Type.Number(),
output: Type.Number(),
cacheRead: Type.Number(),
cacheWrite: Type.Number(),
}),
),
contextWindow: Type.Optional(Type.Number()),
maxTokens: Type.Optional(Type.Number()),
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
compat: Type.Optional(OpenAICompatSchema),
});
// Schema for per-model overrides (all fields optional, merged with built-in model)
const ModelOverrideSchema = Type.Object({
name: Type.Optional(Type.String({ minLength: 1 })),
reasoning: Type.Optional(Type.Boolean()),
input: Type.Optional(
Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")])),
),
cost: Type.Optional(
Type.Object({
input: Type.Optional(Type.Number()),
output: Type.Optional(Type.Number()),
cacheRead: Type.Optional(Type.Number()),
cacheWrite: Type.Optional(Type.Number()),
}),
),
contextWindow: Type.Optional(Type.Number()),
maxTokens: Type.Optional(Type.Number()),
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
compat: Type.Optional(OpenAICompatSchema),
});
type ModelOverride = Static<typeof ModelOverrideSchema>;
const ProviderConfigSchema = Type.Object({
baseUrl: Type.Optional(Type.String({ minLength: 1 })),
apiKey: Type.Optional(Type.String({ minLength: 1 })),
api: Type.Optional(Type.String({ minLength: 1 })),
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
authHeader: Type.Optional(Type.Boolean()),
models: Type.Optional(Type.Array(ModelDefinitionSchema)),
modelOverrides: Type.Optional(
Type.Record(Type.String(), ModelOverrideSchema),
),
});
const ModelsConfigSchema = Type.Object({
providers: Type.Record(Type.String(), ProviderConfigSchema),
});
ajv.addSchema(ModelsConfigSchema, "ModelsConfig");
type ModelsConfig = Static<typeof ModelsConfigSchema>;
export type ProviderModelAllowList = Record<string, readonly string[]>;
export type ProviderAuthMode = "apiKey" | "oauth" | "externalCli" | "none";
type ProviderPolicyModel = Pick<Model<Api>, "provider" | "id"> &
Partial<Pick<Model<Api>, "name" | "cost">>;
const OPENCODE_FREE_MODEL_IDS = new Set([
"big-pickle",
"gpt-5-nano",
"minimax-m2.5-free",
"nemotron-3-super-free",
]);
const HIDDEN_MODEL_PROVIDERS = new Set([
"claude-code",
"google",
"google-vertex",
"github-copilot",
"xai",
"xiaomi-token-plan-ams",
"xiaomi-token-plan-cn",
"xiaomi-token-plan-sgp",
]);
const BUILTIN_EXTERNAL_CLI_AUTH_PROVIDERS = new Set(["google-gemini-cli"]);
function providerModelAllowEntryMatches(
allowedModel: string,
modelKey: string,
): boolean {
const allowedKey = allowedModel.trim().toLowerCase();
if (!allowedKey) return false;
if (allowedKey === modelKey) return true;
if (allowedKey.startsWith(":")) return modelKey.endsWith(allowedKey);
if (!allowedKey.includes("*")) return false;
const pattern = `^${allowedKey
.split("*")
.map((part) => part.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"))
.join(".*")}$`;
return new RegExp(pattern).test(modelKey);
}
function hasFreeSkuMarker(value: string | undefined): boolean {
return /(^|[-_:/\s])free($|[-_:/\s])/i.test(value ?? "");
}
function isZeroCost(cost: Model<Api>["cost"] | undefined): boolean {
return (
!!cost &&
cost.input === 0 &&
cost.output === 0 &&
cost.cacheRead === 0 &&
cost.cacheWrite === 0
);
}
function isMistralSelectionModel(modelId: string): boolean {
const modelKey = modelId.trim().toLowerCase();
if (
modelKey.startsWith("ft:") ||
modelKey.includes("embed") ||
modelKey.includes("moderation") ||
modelKey.includes("ocr") ||
modelKey.includes("voxtral") ||
modelKey.includes("transcribe") ||
modelKey.includes("tts") ||
modelKey.includes("realtime")
) {
return false;
}
return true;
}
function isModelAllowedByBuiltInProviderPolicy(
model: ProviderPolicyModel,
): boolean {
const provider = model.provider.toLowerCase();
const modelKey = model.id.trim().toLowerCase();
if (HIDDEN_MODEL_PROVIDERS.has(provider)) {
return false;
}
if (modelKey.includes("grok") || modelKey.startsWith("x-ai/")) {
return false;
}
if (provider === "mistral") {
return isMistralSelectionModel(model.id);
}
if (provider === "openrouter") {
return providerModelAllowEntryMatches(":free", modelKey);
}
if (provider === "opencode") {
return (
OPENCODE_FREE_MODEL_IDS.has(modelKey) ||
hasFreeSkuMarker(model.id) ||
hasFreeSkuMarker(model.name) ||
isZeroCost(model.cost)
);
}
return true;
}
/** Provider override config (baseUrl, headers, apiKey) without custom models */
interface ProviderOverride {
baseUrl?: string;
headers?: Record<string, string>;
apiKey?: string;
}
/** Result of loading custom models from models.json */
interface CustomModelsResult {
models: Model<Api>[];
/** Providers with baseUrl/headers/apiKey overrides for built-in models */
overrides: Map<string, ProviderOverride>;
/** Per-model overrides: provider -> modelId -> override */
modelOverrides: Map<string, Map<string, ModelOverride>>;
error: string | undefined;
}
function emptyCustomModelsResult(error?: string): CustomModelsResult {
return { models: [], overrides: new Map(), modelOverrides: new Map(), error };
}
function mergeCompat(
baseCompat: Model<Api>["compat"],
overrideCompat: ModelOverride["compat"],
): Model<Api>["compat"] | undefined {
if (!overrideCompat) return baseCompat;
const base = baseCompat as
| OpenAICompletionsCompat
| OpenAIResponsesCompat
| undefined;
const override = overrideCompat as
| OpenAICompletionsCompat
| OpenAIResponsesCompat;
const merged = { ...base, ...override } as
| OpenAICompletionsCompat
| OpenAIResponsesCompat;
const baseCompletions = base as OpenAICompletionsCompat | undefined;
const overrideCompletions = override as OpenAICompletionsCompat;
const mergedCompletions = merged as OpenAICompletionsCompat;
if (
baseCompletions?.openRouterRouting ||
overrideCompletions.openRouterRouting
) {
mergedCompletions.openRouterRouting = {
...baseCompletions?.openRouterRouting,
...overrideCompletions.openRouterRouting,
};
}
if (
baseCompletions?.vercelGatewayRouting ||
overrideCompletions.vercelGatewayRouting
) {
mergedCompletions.vercelGatewayRouting = {
...baseCompletions?.vercelGatewayRouting,
...overrideCompletions.vercelGatewayRouting,
};
}
return merged as Model<Api>["compat"];
}
/**
* Deep merge a model override into a model.
* Handles nested objects (cost, compat) by merging rather than replacing.
*/
function applyModelOverride(
model: Model<Api>,
override: ModelOverride,
): Model<Api> {
const result = { ...model };
// Simple field overrides
if (override.name !== undefined) result.name = override.name;
if (override.reasoning !== undefined) result.reasoning = override.reasoning;
if (override.input !== undefined)
result.input = override.input as ("text" | "image")[];
if (override.contextWindow !== undefined)
result.contextWindow = override.contextWindow;
if (override.maxTokens !== undefined) result.maxTokens = override.maxTokens;
// Merge cost (partial override)
if (override.cost) {
result.cost = {
input: override.cost.input ?? model.cost.input,
output: override.cost.output ?? model.cost.output,
cacheRead: override.cost.cacheRead ?? model.cost.cacheRead,
cacheWrite: override.cost.cacheWrite ?? model.cost.cacheWrite,
};
}
// Merge headers
if (override.headers) {
const resolvedHeaders = resolveHeaders(override.headers);
result.headers = resolvedHeaders
? { ...model.headers, ...resolvedHeaders }
: model.headers;
}
// Deep merge compat
result.compat = mergeCompat(model.compat, override.compat);
return result;
}
/**
* Model registry - loads and manages models, resolves API keys via AuthStorage.
*/
export class ModelRegistry {
private models: Model<Api>[] = [];
private discoveredModels: Model<Api>[] = [];
private discoveryCache: ModelDiscoveryCache;
private customProviderApiKeys: Map<string, string> = new Map();
private registeredProviders: Map<string, ProviderConfigInput> = new Map();
private loadError: string | undefined = undefined;
constructor(
readonly authStorage: AuthStorage,
readonly modelsJsonPath: string | undefined = join(
getAgentDir(),
"models.json",
),
discoveryCache?: ModelDiscoveryCache,
) {
this.discoveryCache = discoveryCache ?? new ModelDiscoveryCache();
// Set up fallback resolver for custom provider API keys
this.authStorage.setFallbackResolver((provider) => {
const keyConfig = this.customProviderApiKeys.get(provider);
if (keyConfig) {
return resolveConfigValue(keyConfig);
}
return getEnvApiKey(provider);
});
// Refresh models when credentials change (e.g., OAuth token refresh with new model limits)
this.authStorage.onCredentialChange(() => this.refresh());
// Load models
this.loadModels();
}
/**
* Reload models from disk (built-in + custom from models.json).
*/
refresh(): void {
this.customProviderApiKeys.clear();
this.loadError = undefined;
// Ensure dynamic API/OAuth registrations are rebuilt from current provider state.
resetApiProviders();
resetOAuthProviders();
this.loadModels();
for (const [providerName, config] of this.registeredProviders.entries()) {
this.applyProviderConfig(providerName, config);
}
}
/**
* Get any error from loading models.json (undefined if no error).
*/
getError(): string | undefined {
return this.loadError;
}
private loadModels(): void {
// Load custom models and overrides from models.json
const {
models: customModels,
overrides,
modelOverrides,
error,
} = this.modelsJsonPath
? this.loadCustomModels(this.modelsJsonPath)
: emptyCustomModelsResult();
if (error) {
this.loadError = error;
// Keep built-in models even if custom models failed to load
}
const builtInModels = this.loadBuiltInModels(overrides, modelOverrides);
let combined = this.mergeCustomModels(builtInModels, customModels);
// Let OAuth providers modify their models (e.g., update baseUrl)
for (const oauthProvider of this.authStorage.getOAuthProviders()) {
const cred = this.authStorage.get(oauthProvider.id);
if (cred?.type === "oauth" && oauthProvider.modifyModels) {
combined = oauthProvider.modifyModels(combined, cred);
}
}
// Apply capability patches so custom/discovered/extension models get
// capabilities (supportsXhigh, supportsServiceTier, etc.) that the
// static pi-ai registry applies at module load for built-in models.
this.models = applyCapabilityPatches(combined);
}
/** Load built-in models and apply provider/model overrides */
private loadBuiltInModels(
overrides: Map<string, ProviderOverride>,
modelOverrides: Map<string, Map<string, ModelOverride>>,
): Model<Api>[] {
return getProviders().flatMap((provider) => {
const models = getModels(provider as KnownProvider) as Model<Api>[];
const providerOverride = overrides.get(provider);
const perModelOverrides = modelOverrides.get(provider);
return models.map((m) => {
let model = m;
// Apply provider-level baseUrl/headers override
if (providerOverride) {
const resolvedHeaders = resolveHeaders(providerOverride.headers);
model = {
...model,
baseUrl: providerOverride.baseUrl ?? model.baseUrl,
headers: resolvedHeaders
? { ...model.headers, ...resolvedHeaders }
: model.headers,
};
}
// Apply per-model override
const modelOverride = perModelOverrides?.get(m.id);
if (modelOverride) {
model = applyModelOverride(model, modelOverride);
}
return model;
});
});
}
/** Merge custom models into built-in list by provider+id (custom wins on conflicts). */
private mergeCustomModels(
builtInModels: Model<Api>[],
customModels: Model<Api>[],
): Model<Api>[] {
const merged = [...builtInModels];
for (const customModel of customModels) {
const existingIndex = merged.findIndex(
(m) => m.provider === customModel.provider && m.id === customModel.id,
);
if (existingIndex >= 0) {
merged[existingIndex] = customModel;
} else {
merged.push(customModel);
}
}
return merged;
}
private isProviderModelAllowed(
model: ProviderPolicyModel,
providerModelAllow?: ProviderModelAllowList,
): boolean {
if (!isModelAllowedByBuiltInProviderPolicy(model)) return false;
if (!providerModelAllow) return true;
const providerKey = model.provider.toLowerCase();
const allowedModels =
providerModelAllow[providerKey] ??
Object.entries(providerModelAllow).find(
([key]) => key.toLowerCase() === providerKey,
)?.[1];
if (allowedModels === undefined) return true;
const modelKey = model.id.trim().toLowerCase();
return allowedModels.some((allowedModel) =>
providerModelAllowEntryMatches(allowedModel, modelKey),
);
}
private filterProviderModelAllow<T extends Model<Api>>(
models: T[],
providerModelAllow?: ProviderModelAllowList,
): T[] {
return models.filter((model) =>
this.isProviderModelAllowed(model, providerModelAllow),
);
}
private loadCustomModels(modelsJsonPath: string): CustomModelsResult {
if (!existsSync(modelsJsonPath)) {
return emptyCustomModelsResult();
}
try {
const content = readFileSync(modelsJsonPath, "utf-8");
const config: ModelsConfig = JSON.parse(content);
// Validate schema
const validate = ajv.getSchema("ModelsConfig")!;
if (!validate(config)) {
const errors =
validate.errors
?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`)
.join("\n") || "Unknown schema error";
return emptyCustomModelsResult(
`Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`,
);
}
// Additional validation
this.validateConfig(config);
const overrides = new Map<string, ProviderOverride>();
const modelOverrides = new Map<string, Map<string, ModelOverride>>();
for (const [providerName, providerConfig] of Object.entries(
config.providers,
)) {
// Apply provider-level baseUrl/headers/apiKey override to built-in models when configured.
if (
providerConfig.baseUrl ||
providerConfig.headers ||
providerConfig.apiKey
) {
overrides.set(providerName, {
baseUrl: providerConfig.baseUrl,
headers: providerConfig.headers,
apiKey: providerConfig.apiKey,
});
}
// Store API key for fallback resolver.
if (providerConfig.apiKey) {
this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
}
if (providerConfig.modelOverrides) {
modelOverrides.set(
providerName,
new Map(Object.entries(providerConfig.modelOverrides)),
);
}
}
return {
models: this.parseModels(config),
overrides,
modelOverrides,
error: undefined,
};
} catch (error) {
if (error instanceof SyntaxError) {
return emptyCustomModelsResult(
`Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`,
);
}
return emptyCustomModelsResult(
`Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`,
);
}
}
private validateConfig(config: ModelsConfig): void {
for (const [providerName, providerConfig] of Object.entries(
config.providers,
)) {
const hasProviderApi = !!providerConfig.api;
const models = providerConfig.models ?? [];
const hasModelOverrides =
providerConfig.modelOverrides &&
Object.keys(providerConfig.modelOverrides).length > 0;
if (models.length === 0) {
// Override-only config: needs baseUrl OR modelOverrides (or both)
if (!providerConfig.baseUrl && !hasModelOverrides) {
throw new Error(
`Provider ${providerName}: must specify "baseUrl", "modelOverrides", or "models".`,
);
}
} else {
// Custom models are merged into provider models and require endpoint + auth.
if (!providerConfig.baseUrl) {
throw new Error(
`Provider ${providerName}: "baseUrl" is required when defining custom models.`,
);
}
if (!providerConfig.apiKey) {
throw new Error(
`Provider ${providerName}: "apiKey" is required when defining custom models.`,
);
}
}
for (const modelDef of models) {
const hasModelApi = !!modelDef.api;
if (!hasProviderApi && !hasModelApi) {
throw new Error(
`Provider ${providerName}, model ${modelDef.id}: no "api" specified. Set at provider or model level.`,
);
}
if (!modelDef.id)
throw new Error(`Provider ${providerName}: model missing "id"`);
// Validate contextWindow/maxTokens only if provided (they have defaults)
if (modelDef.contextWindow !== undefined && modelDef.contextWindow <= 0)
throw new Error(
`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`,
);
if (modelDef.maxTokens !== undefined && modelDef.maxTokens <= 0)
throw new Error(
`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`,
);
}
}
}
private parseModels(config: ModelsConfig): Model<Api>[] {
const models: Model<Api>[] = [];
for (const [providerName, providerConfig] of Object.entries(
config.providers,
)) {
const modelDefs = providerConfig.models ?? [];
if (modelDefs.length === 0) continue; // Override-only, no custom models
// Store API key config for fallback resolver
if (providerConfig.apiKey) {
this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
}
// Register custom providers so isProviderRequestReady() can find
// them (#3531). Without this, models.json providers with apiKey
// fail the auth check and are invisible to the fallback resolver.
if (!this.registeredProviders.has(providerName)) {
this.registeredProviders.set(providerName, {
authMode: providerConfig.apiKey ? "apiKey" : "none",
apiKey: providerConfig.apiKey,
baseUrl: providerConfig.baseUrl,
isReady: providerConfig.apiKey ? () => true : undefined,
} as any);
}
for (const modelDef of modelDefs) {
const api = modelDef.api || providerConfig.api;
if (!api) continue;
// Merge headers: provider headers are base, model headers override
// Resolve env vars and shell commands in header values
const providerHeaders = resolveHeaders(providerConfig.headers);
const modelHeaders = resolveHeaders(modelDef.headers);
let headers =
providerHeaders || modelHeaders
? { ...providerHeaders, ...modelHeaders }
: undefined;
// If authHeader is true, add Authorization header with resolved API key
if (providerConfig.authHeader && providerConfig.apiKey) {
const resolvedKey = resolveConfigValue(providerConfig.apiKey);
if (resolvedKey) {
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
}
}
// Provider baseUrl is required when custom models are defined.
// Individual models can override it with modelDef.baseUrl.
const defaultCost = {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
};
models.push({
id: modelDef.id,
name: modelDef.name ?? modelDef.id,
api: api as Api,
provider: providerName,
baseUrl: modelDef.baseUrl ?? providerConfig.baseUrl!,
reasoning: modelDef.reasoning ?? false,
input: (modelDef.input ?? ["text"]) as ("text" | "image")[],
cost: modelDef.cost ?? defaultCost,
contextWindow: modelDef.contextWindow ?? 128000,
maxTokens: modelDef.maxTokens ?? 16384,
headers,
compat: modelDef.compat,
} as Model<Api>);
}
}
return models;
}
/**
* Get all models (built-in + custom).
* If models.json had errors, returns only built-in models.
*/
getAll(providerModelAllow?: ProviderModelAllowList): Model<Api>[] {
return this.filterProviderModelAllow(this.models, providerModelAllow);
}
/**
* Get only models that have auth configured.
* This is a fast check that doesn't refresh OAuth tokens.
*/
getAvailable(providerModelAllow?: ProviderModelAllowList): Model<Api>[] {
return this.filterProviderModelAllow(
this.models.filter((m) => this.isProviderRequestReady(m.provider)),
providerModelAllow,
);
}
/**
* Get auth mode for a provider.
* Defaults to "apiKey" for built-ins and providers without explicit mode.
*/
getProviderAuthMode(provider: string): ProviderAuthMode {
if (BUILTIN_EXTERNAL_CLI_AUTH_PROVIDERS.has(provider)) return "externalCli";
const config = this.registeredProviders.get(provider);
if (!config) return "apiKey";
if (config.authMode) return config.authMode;
if (config.oauth) return "oauth";
if (config.apiKey) return "apiKey";
return "apiKey";
}
/**
* Whether a provider can be used for requests/fallback without hard auth gating.
*/
isProviderRequestReady(provider: string): boolean {
const config = this.registeredProviders.get(provider);
if (config?.isReady) return config.isReady();
const authMode = this.getProviderAuthMode(provider);
if (authMode === "externalCli" || authMode === "none") return true;
return this.authStorage.hasAuth(provider);
}
/**
* Find a model by provider and ID.
*/
find(provider: string, modelId: string): Model<Api> | undefined {
return this.filterProviderModelAllow(
this.models.filter((m) => m.provider === provider && m.id === modelId),
)[0];
}
/**
* Get API key for a model.
* Returns undefined for externalCli/none providers (no key needed).
* @param sessionId - Optional session ID for sticky credential selection
*/
async getApiKey(
model: Model<Api>,
sessionId?: string,
): Promise<string | undefined> {
const authMode = this.getProviderAuthMode(model.provider);
if (authMode === "externalCli" || authMode === "none") return undefined;
return this.authStorage.getApiKey(model.provider, sessionId, {
baseUrl: model.baseUrl,
});
}
/**
* Get API key for a provider.
* Returns undefined for externalCli/none providers (no key needed).
* @param sessionId - Optional session ID for sticky credential selection
*/
async getApiKeyForProvider(
provider: string,
sessionId?: string,
): Promise<string | undefined> {
const authMode = this.getProviderAuthMode(provider);
if (authMode === "externalCli" || authMode === "none") return undefined;
return this.authStorage.getApiKey(provider, sessionId);
}
/**
* Check if a model is using OAuth credentials (subscription).
*/
isUsingOAuth(model: Model<Api>): boolean {
const cred = this.authStorage.get(model.provider);
return cred?.type === "oauth";
}
/**
* Register a provider dynamically (from extensions).
*
* If provider has models: replaces all existing models for this provider.
* If provider has only baseUrl/headers: overrides existing models' URLs.
* If provider has oauth: registers OAuth provider for /login support.
*/
registerProvider(providerName: string, config: ProviderConfigInput): void {
this.registeredProviders.set(providerName, config);
this.applyProviderConfig(providerName, config);
}
/**
* Unregister a previously registered provider.
*
* Removes the provider from the registry and reloads models from disk so that
* built-in models overridden by this provider are restored to their original state.
* Also resets dynamic OAuth and API stream registrations before reapplying
* remaining dynamic providers.
* Has no effect if the provider was never registered.
*/
unregisterProvider(providerName: string): void {
if (!this.registeredProviders.has(providerName)) return;
this.registeredProviders.delete(providerName);
this.customProviderApiKeys.delete(providerName);
this.refresh();
}
private applyProviderConfig(
providerName: string,
config: ProviderConfigInput,
): void {
// Register OAuth provider if provided
if (config.oauth) {
// Ensure the OAuth provider ID matches the provider name
const oauthProvider: OAuthProviderInterface = {
...config.oauth,
id: providerName,
};
registerOAuthProvider(oauthProvider);
}
if (config.streamSimple) {
if (!config.api) {
throw new Error(
`Provider ${providerName}: "api" is required when registering streamSimple.`,
);
}
const rawStreamSimple = config.streamSimple;
const authMode = config.authMode ?? "apiKey";
// Keyless providers never see apiKey in options — enforced at registration,
// not by convention. Prevents undefined from reaching any handler.
const streamSimple =
authMode === "externalCli" || authMode === "none"
? (
model: Model<Api>,
context: Context,
options?: SimpleStreamOptions,
) => {
const { apiKey: _, ...opts } = options ?? {};
return rawStreamSimple(
model,
context,
opts as SimpleStreamOptions,
);
}
: rawStreamSimple;
// Guard: if there's already a handler registered for this API, wrap
// the new one so it only fires for models from this provider and
// delegates to the previous handler for all other providers. Without
// this, a custom provider using api:"anthropic-messages" would clobber
// the built-in Anthropic stream handler (#2536).
const existingProvider = getApiProvider(config.api as Api);
const scopedStream = existingProvider
? (
model: Model<Api>,
context: Context,
options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
if (model.provider === providerName) {
return streamSimple(model, context, options);
}
return existingProvider.streamSimple(model, context, options);
}
: streamSimple;
const newFullStream = (
model: Model<Api>,
context: Context,
options?: SimpleStreamOptions,
) => scopedStream(model, context, options as SimpleStreamOptions);
const scopedFullStream = existingProvider
? (
model: Model<Api>,
context: Context,
options?: Record<string, unknown>,
) => {
if (model.provider === providerName) {
return newFullStream(
model,
context,
options as SimpleStreamOptions,
);
}
return existingProvider.stream(model, context, options);
}
: newFullStream;
registerApiProvider(
{
api: config.api,
stream: scopedFullStream as any,
streamSimple: scopedStream,
},
`provider:${providerName}`,
);
}
// Store API key for auth resolution
if (config.apiKey) {
this.customProviderApiKeys.set(providerName, config.apiKey);
}
if (config.models && config.models.length > 0) {
// Full replacement: remove existing models for this provider
this.models = this.models.filter((m) => m.provider !== providerName);
// Validate required fields
if (!config.baseUrl) {
throw new Error(
`Provider ${providerName}: "baseUrl" is required when defining models.`,
);
}
const authMode =
config.authMode ??
(config.oauth ? "oauth" : config.apiKey ? "apiKey" : "apiKey");
if (authMode === "apiKey" && !config.apiKey && !config.oauth) {
throw new Error(
`Provider ${providerName}: "apiKey" or "oauth" is required when authMode is "apiKey" (the default). ` +
`Set authMode to "externalCli" or "none" for keyless providers.`,
);
}
if (
(authMode === "externalCli" || authMode === "none") &&
!config.streamSimple
) {
throw new Error(
`Provider ${providerName}: "streamSimple" is required when authMode is "${authMode}". ` +
`Keyless providers must supply their own stream handler.`,
);
}
if (
(authMode === "externalCli" || authMode === "none") &&
config.apiKey
) {
throw new Error(
`Provider ${providerName}: "apiKey" cannot be set when authMode is "${authMode}". ` +
`Keyless providers should not provide API key credentials.`,
);
}
// Parse and add new models
for (const modelDef of config.models) {
const api = modelDef.api || config.api;
if (!api) {
throw new Error(
`Provider ${providerName}, model ${modelDef.id}: no "api" specified.`,
);
}
// Merge headers
const providerHeaders = resolveHeaders(config.headers);
const modelHeaders = resolveHeaders(modelDef.headers);
let headers =
providerHeaders || modelHeaders
? { ...providerHeaders, ...modelHeaders }
: undefined;
// If authHeader is true, add Authorization header
if (config.authHeader && config.apiKey) {
const resolvedKey = resolveConfigValue(config.apiKey);
if (resolvedKey) {
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
}
}
this.models.push({
id: modelDef.id,
name: modelDef.name,
api: api as Api,
provider: providerName,
baseUrl: config.baseUrl,
reasoning: modelDef.reasoning,
input: modelDef.input as ("text" | "image")[],
cost: modelDef.cost,
contextWindow: modelDef.contextWindow,
maxTokens: modelDef.maxTokens,
headers,
compat: modelDef.compat,
providerOptions: modelDef.providerOptions,
} as Model<Api>);
}
// Apply OAuth modifyModels if credentials exist (e.g., to update baseUrl)
if (config.oauth?.modifyModels) {
const cred = this.authStorage.get(providerName);
if (cred?.type === "oauth") {
this.models = config.oauth.modifyModels(this.models, cred);
}
}
// Ensure newly added extension models get capability patches
this.models = applyCapabilityPatches(this.models);
} else if (config.baseUrl) {
// Override-only: update baseUrl/headers for existing models
const resolvedHeaders = resolveHeaders(config.headers);
this.models = this.models.map((m) => {
if (m.provider !== providerName) return m;
return {
...m,
baseUrl: config.baseUrl ?? m.baseUrl,
headers: resolvedHeaders
? { ...m.headers, ...resolvedHeaders }
: m.headers,
};
});
}
}
private buildCandidateOrder(
modelId: string,
overrides: Record<string, string[]>,
): string[] {
const overrideEntry = Object.entries(overrides).find(([k]) =>
modelId.startsWith(k),
);
const familyEntry = PROXY_FAMILY_PRIORITY.find((r) =>
r.match.test(modelId),
);
// Order: direct family providers → family-scoped failover → global fallback.
// Overrides replace only the direct list (keeps family_failover + global
// chain intact) so a user pinning "glm- → [zai]" still picks up
// opencode-go / openrouter / ollama-cloud as last resort.
const familyProviders = overrideEntry?.[1] ?? familyEntry?.providers ?? [];
const familyFailover = familyEntry?.family_failover ?? [];
const seen = new Set([...familyProviders, ...familyFailover]);
const globalFallback =
familyEntry?.global_fallback === false
? []
: GLOBAL_PROVIDER_FALLBACK.filter((p) => !seen.has(p));
return [...familyProviders, ...familyFailover, ...globalFallback];
}
/**
* Return all registered candidates for a bare model ID, ordered by family + global priority.
* Candidates with auth configured are placed first within the same priority tier.
* The proxy server iterates this list and falls through to the next entry on 429.
*/
getModelsForProxy(
modelId: string,
overrides: Record<string, string[]> = {},
providerModelAllow?: ProviderModelAllowList,
): Model<Api>[] {
const candidates = this.filterProviderModelAllow(
this.models.filter((m) => m.id === modelId),
providerModelAllow,
);
if (candidates.length === 0) return [];
const order = this.buildCandidateOrder(modelId, overrides);
const sorted = [...candidates].sort((a, b) => {
const pa = order.indexOf(a.provider);
const pb = order.indexOf(b.provider);
return (pa === -1 ? Infinity : pa) - (pb === -1 ? Infinity : pb);
});
const withAuth = sorted.filter((m) =>
this.isProviderRequestReady(m.provider),
);
const withoutAuth = sorted.filter(
(m) => !this.isProviderRequestReady(m.provider),
);
return [...withAuth, ...withoutAuth];
}
/**
* Resolve a bare model ID to the single highest-priority candidate.
* Convenience wrapper over getModelsForProxy for callers that don't need retry.
*/
getModel(
modelId: string,
overrides: Record<string, string[]> = {},
providerModelAllow?: ProviderModelAllowList,
): Model<Api> | undefined {
return this.getModelsForProxy(modelId, overrides, providerModelAllow)[0];
}
/**
* Discover models from all providers that support discovery.
* Results are cached and merged into the registry (never overrides existing models).
*/
async discoverModels(providers?: string[]): Promise<DiscoveryResult[]> {
const targetProviders = providers ?? getDiscoverableProviders();
const results: DiscoveryResult[] = [];
for (const providerName of targetProviders) {
const adapter = getDiscoveryAdapter(providerName);
if (!adapter.supportsDiscovery) continue;
// Skip if cache is still fresh
if (!this.discoveryCache.isStale(providerName)) {
const cached = this.discoveryCache.get(providerName);
if (cached) {
results.push({
provider: providerName,
sourceType: adapter.sourceType ?? "provider",
models: cached.models,
fetchedAt: cached.fetchedAt,
});
continue;
}
}
try {
const apiKey = await this.authStorage.getApiKey(providerName);
if (
!apiKey &&
!this.isProviderRequestReady(providerName) &&
adapter.requiresAuthForDiscovery !== false
)
continue;
const models = await adapter.fetchModels(apiKey ?? "", undefined);
this.discoveryCache.set(providerName, models);
results.push({
provider: providerName,
sourceType: adapter.sourceType ?? "provider",
models,
fetchedAt: Date.now(),
});
} catch (error) {
results.push({
provider: providerName,
sourceType: adapter.sourceType ?? "provider",
models: [],
fetchedAt: Date.now(),
error: error instanceof Error ? error.message : String(error),
});
}
}
// Convert and merge discovered models, then apply capability patches
this.discoveredModels = applyCapabilityPatches(
this.convertDiscoveredModels(results),
);
return results;
}
/**
* Discover models from catalog sources such as Singularity Memory.
* Catalog sources are not AI providers; discovered rows are materialized
* under their execution provider from catalog metadata.
*/
async discoverModelCatalogs(sources?: string[]): Promise<DiscoveryResult[]> {
return this.discoverModels(sources ?? getDiscoverableCatalogSources());
}
/**
* Get all models including discovered ones.
* Discovered models are appended but never override existing models.
*/
getAllWithDiscovered(): Model<Api>[] {
const existingIds = new Set(
this.models.map((m) => `${m.provider}/${m.id}`),
);
const unique = this.discoveredModels.filter(
(m) => !existingIds.has(`${m.provider}/${m.id}`),
);
return this.filterProviderModelAllow([...this.models, ...unique]);
}
/**
* Return only models from the most recent discovery pass.
*
* Purpose: let diagnostic list commands replace stale static rows for live-listed
* providers with the provider's actual `/models` response.
* Consumer: cli/list-models.ts when `--discover` or an exact provider query is used.
*/
getDiscoveredModels(
providerModelAllow?: ProviderModelAllowList,
): Model<Api>[] {
return this.filterProviderModelAllow(
this.discoveredModels,
providerModelAllow,
);
}
/**
* Check if a model was added via discovery (not built-in or custom).
*/
isDiscovered(model: Model<Api>): boolean {
return this.discoveredModels.some(
(m) => m.provider === model.provider && m.id === model.id,
);
}
/**
* Get the discovery cache instance.
*/
getDiscoveryCache(): ModelDiscoveryCache {
return this.discoveryCache;
}
/**
* Convert DiscoveryResult[] into Model<Api>[] with default values.
*/
private convertDiscoveredModels(results: DiscoveryResult[]): Model<Api>[] {
const converted: Model<Api>[] = [];
const seen = new Set<string>();
for (const result of results) {
if (result.error) continue;
for (const dm of result.models) {
const provider = dm.provider ?? result.provider;
const key = `${provider}/${dm.id}`;
if (seen.has(key)) continue;
seen.add(key);
const known = this.models.find(
(m) => m.provider === provider && m.id === dm.id,
);
const discoveredName =
dm.name && dm.name !== dm.id ? dm.name : undefined;
converted.push({
...known,
id: dm.id,
name: discoveredName ?? known?.name ?? dm.id,
api: (dm.api ?? known?.api ?? "openai") as Api,
provider,
baseUrl: dm.baseUrl ?? known?.baseUrl ?? "",
reasoning: dm.reasoning ?? known?.reasoning ?? false,
input: dm.input ?? known?.input ?? ["text"],
cost: dm.cost ??
known?.cost ?? { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
contextWindow: dm.contextWindow ?? known?.contextWindow ?? 128000,
maxTokens: dm.maxTokens ?? known?.maxTokens ?? 16384,
} as Model<Api>);
}
}
return converted;
}
/**
* Check if a model's baseUrl points to a local endpoint.
* Delegates to standalone isLocalModel() function.
*/
static isLocalModel(model: Model<Api>): boolean {
return isLocalModel(model);
}
/**
* Check if all models in the registry are local.
* Returns true only if every model passes isLocalModel().
* Returns false if there are no models.
*/
isAllLocalChain(): boolean {
const models = this.getAll();
if (models.length === 0) return false;
return models.every((m) => isLocalModel(m));
}
}
/**
* Input type for registerProvider API.
*/
export interface ProviderConfigInput {
authMode?: ProviderAuthMode;
/** Optional readiness check. Called by isProviderRequestReady() before default auth checks.
* Trusted at the same level as extension code — extensions already have arbitrary code execution. */
isReady?: () => boolean;
baseUrl?: string;
apiKey?: string;
api?: Api;
streamSimple?: (
model: Model<Api>,
context: Context,
options?: SimpleStreamOptions,
) => AssistantMessageEventStream;
headers?: Record<string, string>;
authHeader?: boolean;
/** OAuth provider for /login support */
oauth?: Omit<OAuthProviderInterface, "id">;
models?: Array<{
id: string;
name: string;
api?: Api;
baseUrl?: string;
reasoning: boolean;
input: ("text" | "image")[];
cost: {
input: number;
output: number;
cacheRead: number;
cacheWrite: number;
};
contextWindow: number;
maxTokens: number;
headers?: Record<string, string>;
compat?: Model<Api>["compat"];
providerOptions?: Record<string, unknown>;
}>;
}