From d38e5ea092f20cdfe099295d8d676767ebd42cc8 Mon Sep 17 00:00:00 2001 From: Mikael Hugo Date: Tue, 28 Apr 2026 12:30:55 +0200 Subject: [PATCH] =?UTF-8?q?fix(schema):=20auto-coerce=20string=20=E2=86=92?= =?UTF-8?q?=20[string]=20for=20sf=5F*=20list=20fields=20+=20provider=5Fmod?= =?UTF-8?q?el=5Fallow=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two codex-rescue tasks landed together: 1. Auto-coerce JSON-schema validator: when a tool field declares {type:"array", items:{type:"string"}} and the model sends a single string, wrap it in [string] before validation instead of hard-rejecting. Fixes the recurring "keyDecisions: must be array" rejection on sf_complete_task that wasted retries. 2. Provider_model_allow filter (proper implementation with helpers): - resolveProviderModelAllowList / isProviderModelAllowed / filterModelsByProviderModelAllow helpers in preferences-models - Wired into model-registry and auto-model-selection - New tests/provider-model-allow.test.ts Tools coerced: sf_complete_task, sf_complete_milestone, sf_plan_milestone, sf_plan_slice, sf_replan_slice, sf_reassess_roadmap (key list fields). Co-Authored-By: Claude Sonnet 4.6 Co-Authored-By: OpenAI Codex --- .../src/coerce-string-arrays.test.ts | 98 +++++++++++++ packages/pi-ai/src/utils/validation.ts | 50 ++++++- .../core/model-registry-proxy-routing.test.ts | 41 ++++++ .../src/core/model-registry.ts | 54 ++++++-- .../extensions/sf/auto-model-selection.ts | 12 +- .../extensions/sf/preferences-models.ts | 40 ++++-- .../extensions/sf/preferences-validation.ts | 13 +- src/resources/extensions/sf/preferences.ts | 25 +++- .../sf/tests/provider-model-allow.test.ts | 130 ++++++++++++++++++ 9 files changed, 434 insertions(+), 29 deletions(-) create mode 100644 packages/mcp-server/src/coerce-string-arrays.test.ts create mode 100644 src/resources/extensions/sf/tests/provider-model-allow.test.ts diff --git a/packages/mcp-server/src/coerce-string-arrays.test.ts b/packages/mcp-server/src/coerce-string-arrays.test.ts new file mode 100644 index 000000000..8a64d2b11 --- /dev/null +++ b/packages/mcp-server/src/coerce-string-arrays.test.ts @@ -0,0 +1,98 @@ +import { describe, it } from "node:test"; +import assert from "node:assert/strict"; +import { z } from "zod"; + +import { validateToolArguments } from "../../pi-ai/src/utils/validation.ts"; +import { registerWorkflowTools } from "./workflow-tools.ts"; + +type RegisteredTool = { + name: string; + description: string; + params: Record; +}; + +function makeMockServer() { + const tools: RegisteredTool[] = []; + return { + tools, + tool( + name: string, + description: string, + params: Record, + _handler: (args: Record) => Promise, + ) { + tools.push({ name, description, params }); + }, + }; +} + +function workflowToolSchema(toolName: string): Record { + const server = makeMockServer(); + registerWorkflowTools(server as any); + const tool = server.tools.find((candidate) => candidate.name === toolName); + assert.ok(tool, `${toolName} should be registered`); + + const schema = z.toJSONSchema(z.object(tool.params as z.ZodRawShape)) as Record; + delete schema.$schema; + return schema; +} + +function makeToolCall(overrides: Record) { + return { + type: "toolCall" as const, + id: "call-1", + name: "sf_complete_task", + arguments: { + projectDir: "/tmp/sf-project", + taskId: "T01", + sliceId: "S01", + milestoneId: "M001", + oneLiner: "Completed task", + narrative: "Did the work.", + verification: "npm test", + ...overrides, + }, + }; +} + +describe("string-array schema coercion", () => { + const sfCompleteTaskTool = { + name: "sf_complete_task", + description: "Record a completed task to the SF database and render its SUMMARY.md.", + parameters: workflowToolSchema("sf_complete_task") as any, + }; + + it("coerces a bare string keyDecisions value before validation", () => { + const args = validateToolArguments( + sfCompleteTaskTool, + makeToolCall({ keyDecisions: "single string" }), + ); + + assert.deepEqual(args.keyDecisions, ["single string"]); + }); + + it("keeps an array keyDecisions value valid", () => { + const args = validateToolArguments( + sfCompleteTaskTool, + makeToolCall({ keyDecisions: ["a", "b"] }), + ); + + assert.deepEqual(args.keyDecisions, ["a", "b"]); + }); + + it("rejects a non-string, non-array keyDecisions value", () => { + assert.throws( + () => validateToolArguments(sfCompleteTaskTool, makeToolCall({ keyDecisions: 42 })), + /keyDecisions: must be array/, + ); + }); + + it("allows an undefined optional keyDecisions value", () => { + const args = validateToolArguments( + sfCompleteTaskTool, + makeToolCall({ keyDecisions: undefined }), + ); + + assert.equal(args.keyDecisions, undefined); + }); +}); diff --git a/packages/pi-ai/src/utils/validation.ts b/packages/pi-ai/src/utils/validation.ts index 62c507b9d..915d88408 100644 --- a/packages/pi-ai/src/utils/validation.ts +++ b/packages/pi-ai/src/utils/validation.ts @@ -7,6 +7,54 @@ const addFormats = (addFormatsModule as any).default || addFormatsModule; import type { Tool, ToolCall } from "../types.js"; +type JsonSchemaObject = Record; + +function isRecord(value: unknown): value is JsonSchemaObject { + return value !== null && typeof value === "object" && !Array.isArray(value); +} + +function isStringArraySchema(schema: unknown): schema is JsonSchemaObject { + if (!isRecord(schema) || schema.type !== "array") return false; + const items = schema.items; + return isRecord(items) && items.type === "string"; +} + +function coerceSchemaValue(schema: unknown, value: unknown): unknown { + if (!isRecord(schema)) return value; + if (isStringArraySchema(schema) && typeof value === "string") { + return [value]; + } + + if (Array.isArray(value)) { + const items = schema.items; + if (!isRecord(items)) return value; + return value.map((item) => coerceSchemaValue(items, item)); + } + + if (!isRecord(value)) return value; + + const properties = schema.properties; + if (!isRecord(properties)) return value; + + let next: JsonSchemaObject | null = null; + for (const [key, propertySchema] of Object.entries(properties)) { + if (!Object.prototype.hasOwnProperty.call(value, key)) continue; + const coercedValue = coerceSchemaValue(propertySchema, value[key]); + if (coercedValue !== value[key]) { + next ??= { ...value }; + next[key] = coercedValue; + } + } + return next ?? value; +} + +/** + * Wraps bare strings for JSON-schema fields declared as string arrays before AJV validation. + */ +export function coerceStringArrays(schema: unknown, params: unknown): unknown { + return coerceSchemaValue(schema, params); +} + // Detect if we're in a browser extension environment with strict CSP // Chrome extensions with Manifest V3 don't allow eval/Function constructor const isBrowserExtension = typeof globalThis !== "undefined" && (globalThis as any).chrome?.runtime?.id !== undefined; @@ -47,7 +95,7 @@ export function validateToolArguments(tool: Tool, toolCall: ToolCall): any { const validate = ajv.compile(tool.parameters); // Clone arguments so AJV can safely mutate for type coercion - const args = structuredClone(toolCall.arguments); + const args = coerceStringArrays(tool.parameters, structuredClone(toolCall.arguments)); // Validate the arguments (AJV mutates args in-place for type coercion) if (validate(args)) { diff --git a/packages/pi-coding-agent/src/core/model-registry-proxy-routing.test.ts b/packages/pi-coding-agent/src/core/model-registry-proxy-routing.test.ts index 5a2b62d52..69a4443f2 100644 --- a/packages/pi-coding-agent/src/core/model-registry-proxy-routing.test.ts +++ b/packages/pi-coding-agent/src/core/model-registry-proxy-routing.test.ts @@ -202,3 +202,44 @@ describe("ModelRegistry.getModel — convenience wrapper", () => { assert.equal(model.provider, "zai"); }); }); + +// ── provider_model_allow final filter ───────────────────────────────────────── + +describe("ModelRegistry provider_model_allow filter", () => { + it("keeps an allowed provider/model candidate", () => { + const registry = createRegistry(); + registerNone(registry, "minimax", "MiniMax-M2.7"); + + const result = registry.getModelsForProxy("MiniMax-M2.7", {}, { + minimax: ["MiniMax-M2.7"], + }); + + assert.ok( + result.some((model) => model.provider === "minimax" && model.id === "MiniMax-M2.7"), + "allowed minimax/MiniMax-M2.7 candidate should survive filtering", + ); + }); + + it("filters a disallowed provider/model candidate and falls through", () => { + const registry = createRegistry(); + registerNone(registry, "minimax", "MiniMax-M2"); + registerNone(registry, "minimax-cn", "MiniMax-M2"); + + const result = registry.getModelsForProxy("MiniMax-M2", {}, { + minimax: ["MiniMax-M2.7"], + }); + + assert.deepEqual(result.map((m) => `${m.provider}/${m.id}`), ["minimax-cn/MiniMax-M2"]); + }); + + it("leaves providers absent from the allow-list unrestricted", () => { + const registry = createRegistry(); + registerNone(registry, "zai", "glm-4-air"); + + const result = registry.getModelsForProxy("glm-4-air", {}, { + minimax: ["MiniMax-M2.7"], + }); + + assert.deepEqual(result.map((m) => `${m.provider}/${m.id}`), ["zai/glm-4-air"]); + }); +}); diff --git a/packages/pi-coding-agent/src/core/model-registry.ts b/packages/pi-coding-agent/src/core/model-registry.ts index 6de6e268b..579b868d6 100644 --- a/packages/pi-coding-agent/src/core/model-registry.ts +++ b/packages/pi-coding-agent/src/core/model-registry.ts @@ -197,6 +197,8 @@ ajv.addSchema(ModelsConfigSchema, "ModelsConfig"); type ModelsConfig = Static; +export type ProviderModelAllowList = Record; + export type ProviderAuthMode = "apiKey" | "oauth" | "externalCli" | "none"; /** Provider override config (baseUrl, headers, apiKey) without custom models */ @@ -425,6 +427,28 @@ export class ModelRegistry { return merged; } + private isProviderModelAllowed( + provider: string, + modelId: string, + providerModelAllow?: ProviderModelAllowList, + ): boolean { + if (!providerModelAllow) return true; + const providerKey = provider.toLowerCase(); + const allowedModels = providerModelAllow[providerKey] + ?? Object.entries(providerModelAllow).find(([key]) => key.toLowerCase() === providerKey)?.[1]; + if (allowedModels === undefined) return true; + const modelKey = modelId.trim().toLowerCase(); + return allowedModels.some((allowedModel) => allowedModel.trim().toLowerCase() === modelKey); + } + + private filterProviderModelAllow>( + models: T[], + providerModelAllow?: ProviderModelAllowList, + ): T[] { + if (!providerModelAllow || Object.keys(providerModelAllow).length === 0) return models; + return models.filter((model) => this.isProviderModelAllowed(model.provider, model.id, providerModelAllow)); + } + private loadCustomModels(modelsJsonPath: string): CustomModelsResult { if (!existsSync(modelsJsonPath)) { return emptyCustomModelsResult(); @@ -590,16 +614,19 @@ export class ModelRegistry { * Get all models (built-in + custom). * If models.json had errors, returns only built-in models. */ - getAll(): Model[] { - return this.models; + getAll(providerModelAllow?: ProviderModelAllowList): Model[] { + return this.filterProviderModelAllow(this.models, providerModelAllow); } /** * Get only models that have auth configured. * This is a fast check that doesn't refresh OAuth tokens. */ - getAvailable(): Model[] { - return this.models.filter((m) => this.isProviderRequestReady(m.provider)); + getAvailable(providerModelAllow?: ProviderModelAllowList): Model[] { + return this.filterProviderModelAllow( + this.models.filter((m) => this.isProviderRequestReady(m.provider)), + providerModelAllow, + ); } /** @@ -867,8 +894,15 @@ export class ModelRegistry { * Candidates with auth configured are placed first within the same priority tier. * The proxy server iterates this list and falls through to the next entry on 429. */ - getModelsForProxy(modelId: string, overrides: Record = {}): Model[] { - const candidates = this.models.filter((m) => m.id === modelId); + getModelsForProxy( + modelId: string, + overrides: Record = {}, + providerModelAllow?: ProviderModelAllowList, + ): Model[] { + const candidates = this.filterProviderModelAllow( + this.models.filter((m) => m.id === modelId), + providerModelAllow, + ); if (candidates.length === 0) return []; const order = this.buildCandidateOrder(modelId, overrides); @@ -887,8 +921,12 @@ export class ModelRegistry { * Resolve a bare model ID to the single highest-priority candidate. * Convenience wrapper over getModelsForProxy for callers that don't need retry. */ - getModel(modelId: string, overrides: Record = {}): Model | undefined { - return this.getModelsForProxy(modelId, overrides)[0]; + getModel( + modelId: string, + overrides: Record = {}, + providerModelAllow?: ProviderModelAllowList, + ): Model | undefined { + return this.getModelsForProxy(modelId, overrides, providerModelAllow)[0]; } /** diff --git a/src/resources/extensions/sf/auto-model-selection.ts b/src/resources/extensions/sf/auto-model-selection.ts index ce978a736..e5569ba12 100644 --- a/src/resources/extensions/sf/auto-model-selection.ts +++ b/src/resources/extensions/sf/auto-model-selection.ts @@ -8,7 +8,7 @@ import type { Api, Model } from "@singularity-forge/pi-ai"; import { getProviderCapabilities } from "@singularity-forge/pi-ai"; import type { ExtensionAPI, ExtensionContext } from "@singularity-forge/pi-coding-agent"; import type { SFPreferences } from "./preferences.js"; -import { resolveModelWithFallbacksForUnit, resolveDynamicRoutingConfig, resolvePersistModelChanges } from "./preferences.js"; +import { filterModelsByProviderModelAllow, resolveModelWithFallbacksForUnit, resolveDynamicRoutingConfig, resolvePersistModelChanges } from "./preferences.js"; import type { ComplexityTier } from "./complexity-classifier.js"; import { classifyUnitComplexity, extractTaskMetadata, tierLabel } from "./complexity-classifier.js"; import { resolveModelForComplexity, escalateTier, getEligibleModels, loadCapabilityOverrides, adjustToolSet, filterToolsForProvider } from "./model-router.js"; @@ -212,16 +212,17 @@ export async function selectAndApplyModel( // 400 "model not supported" dispatch failures in dr-repo. const rawAvailable = ctx.modelRegistry.getAvailable(); const allowed = prefs?.allowed_providers; - const availableModels = (allowed && allowed.length > 0) + const providerAllowedModels = (allowed && allowed.length > 0) ? rawAvailable.filter(m => allowed.includes(m.provider.toLowerCase())) : rawAvailable; - if (allowed && allowed.length > 0 && availableModels.length === 0) { + if (allowed && allowed.length > 0 && providerAllowedModels.length === 0) { throw new Error( `allowed_providers filter rejected every available model. ` + `Configured providers: [${allowed.join(", ")}]. ` + `Either add a provider to allowed_providers or remove the pref.`, ); } + const availableModels = filterModelsByProviderModelAllow(providerAllowedModels, prefs?.provider_model_allow); const modelPolicyTraceId = `model:${ctx.sessionManager.getSessionId()}:${Date.now()}`; const modelPolicyTurnId = `${unitType}:${unitId}`; let policyAllowedModelKeys: Set | null = null; @@ -544,7 +545,10 @@ export async function selectAndApplyModel( } else if (autoModeStartModel) { // No model preference for this unit type — re-apply the model captured // at auto-mode start to prevent bleed from shared global settings.json (#650). - const availableModels = ctx.modelRegistry.getAvailable(); + const availableModels = filterModelsByProviderModelAllow( + ctx.modelRegistry.getAvailable(), + prefs?.provider_model_allow, + ); const startModel = availableModels.find( m => m.provider === autoModeStartModel.provider && m.id === autoModeStartModel.id, ); diff --git a/src/resources/extensions/sf/preferences-models.ts b/src/resources/extensions/sf/preferences-models.ts index 50d495edc..2d43b8a79 100644 --- a/src/resources/extensions/sf/preferences-models.ts +++ b/src/resources/extensions/sf/preferences-models.ts @@ -27,6 +27,37 @@ import { loadEffectiveSFPreferences, getGlobalSFPreferencesPath } from "./prefer // Re-export types so existing consumers of ./preferences-models.js keep working export type { SFPhaseModelConfig, SFModelConfig, SFModelConfigV2, ResolvedModelConfig } from "./preferences-types.js"; +export type ProviderModelAllowList = Record; + +function resolveProviderModelAllowList( + providerModelAllow: ProviderModelAllowList | undefined, + provider: string, +): readonly string[] | undefined { + if (!providerModelAllow) return undefined; + const providerKey = provider.toLowerCase(); + return providerModelAllow[providerKey] + ?? Object.entries(providerModelAllow).find(([key]) => key.toLowerCase() === providerKey)?.[1]; +} + +export function isProviderModelAllowed( + provider: string, + modelId: string, + providerModelAllow: ProviderModelAllowList | undefined, +): boolean { + const allowedModels = resolveProviderModelAllowList(providerModelAllow, provider); + if (allowedModels === undefined) return true; + const modelKey = modelId.trim().toLowerCase(); + return allowedModels.some((allowedModel) => allowedModel.trim().toLowerCase() === modelKey); +} + +export function filterModelsByProviderModelAllow( + models: readonly T[], + providerModelAllow: ProviderModelAllowList | undefined, +): T[] { + if (!providerModelAllow || Object.keys(providerModelAllow).length === 0) return [...models]; + return models.filter((model) => isProviderModelAllowed(model.provider, model.id, providerModelAllow)); +} + /** * Resolve which model ID to use for a given auto-mode unit type. * Returns undefined if no model preference is set for this unit type. @@ -63,18 +94,11 @@ function resolveAutoBenchmarkPickForUnit( ): ResolvedModelConfig | undefined { try { const allowed = prefs?.allowed_providers?.map(s => s.toLowerCase()); - const modelAllow = prefs?.provider_model_allow; const candidates: Array<{ provider: string; id: string }> = []; for (const provider of getProviders()) { if (allowed && !allowed.includes(provider.toLowerCase())) continue; - // Per-provider model allow-list: when a provider is listed here, only - // the listed model IDs may be used from that provider. - const providerKey = Object.keys(modelAllow ?? {}).find( - k => k.toLowerCase() === provider.toLowerCase(), - ); - const allowedModels = providerKey ? modelAllow![providerKey] : undefined; for (const model of getModels(provider)) { - if (allowedModels && !allowedModels.includes(model.id)) continue; + if (!isProviderModelAllowed(provider, model.id, prefs?.provider_model_allow)) continue; candidates.push({ provider, id: model.id }); } } diff --git a/src/resources/extensions/sf/preferences-validation.ts b/src/resources/extensions/sf/preferences-validation.ts index 4973629b5..579dbf791 100644 --- a/src/resources/extensions/sf/preferences-validation.ts +++ b/src/resources/extensions/sf/preferences-validation.ts @@ -426,7 +426,7 @@ export function validatePreferences(preferences: SFPreferences): { } // ─── Per-provider model allow-list ────────────────────────────────── - // When a provider has an entry here, ONLY listed model IDs are usable + // When a provider has an entry here, only listed model IDs are usable // from that provider. Providers absent from the block are unrestricted. if (preferences.provider_model_allow !== undefined) { if ( @@ -436,12 +436,19 @@ export function validatePreferences(preferences: SFPreferences): { ) { const cleaned: Record = {}; for (const [provider, models] of Object.entries(preferences.provider_model_allow as Record)) { + const providerId = provider.trim().toLowerCase(); + if (!providerId) { + errors.push("provider_model_allow provider IDs must be non-empty strings"); + continue; + } if (!Array.isArray(models) || models.some((m: unknown) => typeof m !== "string")) { errors.push(`provider_model_allow.${provider} must be an array of model ID strings`); continue; } - const list = (models as string[]).map(s => s.trim()).filter(s => s.length > 0); - if (list.length > 0) cleaned[provider.toLowerCase()] = list; + const list = (models as string[]) + .map(s => s.trim()) + .filter(s => s.length > 0); + cleaned[providerId] = Array.from(new Set(list)); } if (Object.keys(cleaned).length > 0) validated.provider_model_allow = cleaned; } else { diff --git a/src/resources/extensions/sf/preferences.ts b/src/resources/extensions/sf/preferences.ts index effc32c41..be98f3c6c 100644 --- a/src/resources/extensions/sf/preferences.ts +++ b/src/resources/extensions/sf/preferences.ts @@ -95,6 +95,8 @@ export { resolveInlineLevel, resolveContextSelection, resolveSearchProviderFromPreferences, + isProviderModelAllowed, + filterModelsByProviderModelAllow, } from "./preferences-models.js"; // ─── Path Constants & Getters ─────────────────────────────────────────────── @@ -509,11 +511,7 @@ function mergePreferences(base: SFPreferences, override: SFPreferences): SFPrefe // override-wins merge so the preference actually reaches consumers. allowed_providers: mergeStringLists(base.allowed_providers, override.allowed_providers), provider_preference: override.provider_preference ?? base.provider_preference, - // Per-provider replace inside the array (project entry replaces global - // for that provider; providers only in base survive). - provider_model_allow: (base.provider_model_allow || override.provider_model_allow) - ? { ...(base.provider_model_allow ?? {}), ...(override.provider_model_allow ?? {}) } - : undefined, + provider_model_allow: mergeProviderModelAllow(base.provider_model_allow, override.provider_model_allow), flat_rate_providers: mergeStringLists(base.flat_rate_providers, override.flat_rate_providers), stale_commit_threshold_minutes: override.stale_commit_threshold_minutes ?? base.stale_commit_threshold_minutes, widget_mode: override.widget_mode ?? base.widget_mode, @@ -536,6 +534,23 @@ function mergeStringLists(base?: unknown, override?: unknown): string[] | undefi return merged.length > 0 ? Array.from(new Set(merged)) : undefined; } +function mergeProviderModelAllow( + base?: Record, + override?: Record, +): Record | undefined { + if (!base && !override) return undefined; + const merged: Record = {}; + for (const [provider, models] of Object.entries(base ?? {})) { + merged[provider] = [...models]; + } + for (const [provider, models] of Object.entries(override ?? {})) { + // Per-provider replace: a project entry replaces the global array for + // that provider instead of appending to it. + merged[provider] = [...models]; + } + return Object.keys(merged).length > 0 ? merged : undefined; +} + function mergePostUnitHooks( base?: PostUnitHookConfig[], override?: PostUnitHookConfig[], diff --git a/src/resources/extensions/sf/tests/provider-model-allow.test.ts b/src/resources/extensions/sf/tests/provider-model-allow.test.ts new file mode 100644 index 000000000..b491a10f4 --- /dev/null +++ b/src/resources/extensions/sf/tests/provider-model-allow.test.ts @@ -0,0 +1,130 @@ +import test from "node:test"; +import assert from "node:assert/strict"; +import { mkdtempSync, mkdirSync, rmSync, writeFileSync } from "node:fs"; +import { tmpdir } from "node:os"; +import { join } from "node:path"; + +import { loadEffectiveSFPreferences, validatePreferences } from "../preferences.ts"; +import { filterModelsByProviderModelAllow } from "../preferences-models.ts"; + +test("provider_model_allow: provider in allow-list and model in list passes", () => { + const models = [ + { provider: "minimax", id: "MiniMax-M2.7" }, + { provider: "minimax", id: "MiniMax-M2" }, + ]; + + const filtered = filterModelsByProviderModelAllow(models, { + minimax: ["MiniMax-M2.7"], + }); + + assert.deepEqual(filtered.map((m) => `${m.provider}/${m.id}`), ["minimax/MiniMax-M2.7"]); +}); + +test("provider_model_allow: provider in allow-list and model not in list is filtered", () => { + const models = [ + { provider: "minimax", id: "MiniMax-M2" }, + { provider: "zai", id: "glm-5" }, + ]; + + const filtered = filterModelsByProviderModelAllow(models, { + minimax: ["MiniMax-M2.7"], + }); + + assert.deepEqual( + filtered.map((m) => `${m.provider}/${m.id}`), + ["zai/glm-5"], + "minimax/MiniMax-M2 is removed so selection can fall through to the next provider", + ); +}); + +test("provider_model_allow: provider absent from allow-list is unrestricted", () => { + const models = [ + { provider: "minimax", id: "MiniMax-M2" }, + { provider: "zai", id: "glm-5" }, + ]; + + const filtered = filterModelsByProviderModelAllow(models, { + minimax: ["MiniMax-M2.7"], + }); + + assert.ok(filtered.some((m) => m.provider === "zai" && m.id === "glm-5")); +}); + +test("provider_model_allow: validates shape and normalizes provider IDs", () => { + const result = validatePreferences({ + provider_model_allow: { + MiniMax: [" MiniMax-M2.7 ", "MiniMax-M2.7-highspeed"], + }, + }); + + assert.equal(result.errors.length, 0); + assert.deepEqual(result.preferences.provider_model_allow, { + minimax: ["MiniMax-M2.7", "MiniMax-M2.7-highspeed"], + }); +}); + +test("provider_model_allow: rejects invalid shapes", () => { + const result = validatePreferences({ + provider_model_allow: { + minimax: "MiniMax-M2.7", + zai: ["glm-5", 42], + } as any, + }); + + assert.ok(result.errors.some((error) => error.includes("provider_model_allow.minimax"))); + assert.ok(result.errors.some((error) => error.includes("provider_model_allow.zai"))); + assert.equal(result.preferences.provider_model_allow, undefined); +}); + +test("provider_model_allow: project overrides global per provider", () => { + const originalCwd = process.cwd(); + const originalSfHome = process.env.SF_HOME; + const tempProject = mkdtempSync(join(tmpdir(), "sf-provider-model-allow-project-")); + const tempHome = mkdtempSync(join(tmpdir(), "sf-provider-model-allow-home-")); + + try { + mkdirSync(join(tempProject, ".sf"), { recursive: true }); + + writeFileSync( + join(tempHome, "preferences.md"), + [ + "---", + "provider_model_allow:", + " minimax:", + " - MiniMax-M2.7", + " zai:", + " - glm-5", + "---", + ].join("\n"), + "utf-8", + ); + + writeFileSync( + join(tempProject, ".sf", "PREFERENCES.md"), + [ + "---", + "provider_model_allow:", + " minimax:", + " - MiniMax-M2.7-highspeed", + "---", + ].join("\n"), + "utf-8", + ); + + process.env.SF_HOME = tempHome; + process.chdir(tempProject); + + const loaded = loadEffectiveSFPreferences(); + assert.notEqual(loaded, null); + assert.deepEqual(loaded!.preferences.provider_model_allow, { + minimax: ["MiniMax-M2.7-highspeed"], + zai: ["glm-5"], + }); + } finally { + process.chdir(originalCwd); + if (originalSfHome === undefined) delete process.env.SF_HOME; + else process.env.SF_HOME = originalSfHome; + rmSync(tempProject, { recursive: true, force: true }); + rmSync(tempHome, { recursive: true, force: true }); + } +});