fix(schema): auto-coerce string → [string] for sf_* list fields + provider_model_allow tests

Two codex-rescue tasks landed together:

1. Auto-coerce JSON-schema validator: when a tool field declares
   {type:"array", items:{type:"string"}} and the model sends a single
   string, wrap it in [string] before validation instead of hard-rejecting.
   Fixes the recurring "keyDecisions: must be array" rejection on
   sf_complete_task that wasted retries.

2. `provider_model_allow` filter (proper implementation with helpers):
   - resolveProviderModelAllowList / isProviderModelAllowed /
     filterModelsByProviderModelAllow helpers in preferences-models
   - Wired into model-registry and auto-model-selection
   - New tests/provider-model-allow.test.ts

Tools coerced: sf_complete_task, sf_complete_milestone, sf_plan_milestone,
sf_plan_slice, sf_replan_slice, sf_reassess_roadmap (key list fields).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: OpenAI Codex <noreply@openai.com>
This commit is contained in:
Mikael Hugo 2026-04-28 12:30:55 +02:00
parent f98a1e360e
commit d38e5ea092
9 changed files with 434 additions and 29 deletions

View file

@ -0,0 +1,98 @@
import { describe, it } from "node:test";
import assert from "node:assert/strict";
import { z } from "zod";
import { validateToolArguments } from "../../pi-ai/src/utils/validation.ts";
import { registerWorkflowTools } from "./workflow-tools.ts";
/** One tool registration as captured by the mock server (the handler is not retained). */
type RegisteredTool = {
  name: string;
  description: string;
  params: Record<string, unknown>;
};

/**
 * Builds a stand-in server exposing a `tool(...)` registration function.
 *
 * Every registration is appended to the returned `tools` array so a test can
 * inspect which tools were registered and with which parameter shapes.
 */
function makeMockServer() {
  const registered: RegisteredTool[] = [];

  // The handler is accepted for signature compatibility but never stored or invoked.
  const tool = (
    name: string,
    description: string,
    params: Record<string, unknown>,
    _handler: (args: Record<string, unknown>) => Promise<unknown>,
  ): void => {
    registered.push({ name, description, params });
  };

  return { tools: registered, tool };
}
/**
 * Returns the JSON Schema for the parameters of a registered workflow tool.
 *
 * Registers all workflow tools against a fresh mock server, looks up the tool
 * by name (asserting that it exists), converts its zod raw shape to JSON Schema
 * and removes the top-level `$schema` key before returning.
 */
function workflowToolSchema(toolName: string): Record<string, unknown> {
  const mock = makeMockServer();
  registerWorkflowTools(mock as any);

  const [match] = mock.tools.filter(({ name }) => name === toolName);
  assert.ok(match, `${toolName} should be registered`);

  const shape = match.params as z.ZodRawShape;
  const jsonSchema = z.toJSONSchema(z.object(shape)) as Record<string, unknown>;
  delete jsonSchema.$schema;
  return jsonSchema;
}
/**
 * Builds an `sf_complete_task` tool call with a fixed set of baseline
 * arguments; `overrides` are spread last, so they add to or replace the baseline.
 */
function makeToolCall(overrides: Record<string, unknown>) {
  const baseline = {
    projectDir: "/tmp/sf-project",
    taskId: "T01",
    sliceId: "S01",
    milestoneId: "M001",
    oneLiner: "Completed task",
    narrative: "Did the work.",
    verification: "npm test",
  };
  const mergedArguments = { ...baseline, ...overrides };
  return {
    type: "toolCall" as const,
    id: "call-1",
    name: "sf_complete_task",
    arguments: mergedArguments,
  };
}
// Exercises validateToolArguments against the real sf_complete_task schema to
// check how a string-array field (`keyDecisions`) handles differently shaped inputs.
describe("string-array schema coercion", () => {
  const completeTaskTool = {
    name: "sf_complete_task",
    description: "Record a completed task to the SF database and render its SUMMARY.md.",
    parameters: workflowToolSchema("sf_complete_task") as any,
  };

  // Validates a baseline tool call whose only variation is the keyDecisions value.
  const validateWithKeyDecisions = (keyDecisions: unknown) =>
    validateToolArguments(completeTaskTool, makeToolCall({ keyDecisions }));

  it("coerces a bare string keyDecisions value before validation", () => {
    const validated = validateWithKeyDecisions("single string");
    assert.deepEqual(validated.keyDecisions, ["single string"]);
  });

  it("keeps an array keyDecisions value valid", () => {
    const validated = validateWithKeyDecisions(["a", "b"]);
    assert.deepEqual(validated.keyDecisions, ["a", "b"]);
  });

  it("rejects a non-string, non-array keyDecisions value", () => {
    const attempt = () => validateWithKeyDecisions(42);
    assert.throws(attempt, /keyDecisions: must be array/);
  });

  it("allows an undefined optional keyDecisions value", () => {
    const validated = validateWithKeyDecisions(undefined);
    assert.equal(validated.keyDecisions, undefined);
  });
});

View file

@ -7,6 +7,54 @@ const addFormats = (addFormatsModule as any).default || addFormatsModule;
import type { Tool, ToolCall } from "../types.js";
/** A JSON-schema node (or any plain object) viewed as a string-keyed record. */
type JsonSchemaObject = Record<string, unknown>;

/** Returns true for non-null, non-array objects. */
function isRecord(value: unknown): value is JsonSchemaObject {
  return value !== null && typeof value === "object" && !Array.isArray(value);
}

/** Returns true when `schema` is exactly `{ type: "array", items: { type: "string" } }`-shaped. */
function isStringArraySchema(schema: unknown): schema is JsonSchemaObject {
  if (!isRecord(schema) || schema.type !== "array") return false;
  const items = schema.items;
  return isRecord(items) && items.type === "string";
}

/**
 * Recursively walks `value` alongside `schema` and wraps a bare string in a
 * one-element array wherever the schema declares an array of strings.
 *
 * The walk follows `items` for arrays and `properties` for objects. The input
 * is never mutated: a container is copied only when something inside it was
 * actually coerced, otherwise the very same reference is returned.
 *
 * @param schema JSON-schema node describing `value`; non-object schemas are ignored.
 * @param value Value to coerce.
 * @returns The coerced value, or `value` itself when nothing needed to change.
 */
function coerceSchemaValue(schema: unknown, value: unknown): unknown {
  if (!isRecord(schema)) return value;

  if (typeof value === "string") {
    return isStringArraySchema(schema) ? [value] : value;
  }

  if (Array.isArray(value)) {
    const items = schema.items;
    // Tuple-style or missing `items` gives no single element schema to apply.
    if (!isRecord(items)) return value;
    let changed = false;
    const mapped = value.map((item) => {
      const coerced = coerceSchemaValue(items, item);
      if (coerced !== item) changed = true;
      return coerced;
    });
    // Hand back the original array when no element changed, so the parent's
    // identity comparison does not trigger a needless clone.
    return changed ? mapped : value;
  }

  if (!isRecord(value)) return value;
  const properties = schema.properties;
  if (!isRecord(properties)) return value;

  // Copy-on-write: `next` is created lazily on the first coerced property.
  let next: JsonSchemaObject | null = null;
  for (const [key, propertySchema] of Object.entries(properties)) {
    if (!Object.prototype.hasOwnProperty.call(value, key)) continue;
    const coercedValue = coerceSchemaValue(propertySchema, value[key]);
    if (coercedValue !== value[key]) {
      next ??= { ...value };
      next[key] = coercedValue;
    }
  }
  return next ?? value;
}

/**
 * Wraps bare strings for JSON-schema fields declared as string arrays before AJV validation.
 *
 * @param schema JSON schema describing `params`.
 * @param params Arguments to coerce; not mutated.
 * @returns `params` with bare strings wrapped where the schema expects `string[]`;
 *   the same reference is returned when no coercion was necessary.
 */
export function coerceStringArrays(schema: unknown, params: unknown): unknown {
  return coerceSchemaValue(schema, params);
}
// Detect if we're in a browser extension environment with strict CSP
// Chrome extensions with Manifest V3 don't allow eval/Function constructor
const isBrowserExtension = typeof globalThis !== "undefined" && (globalThis as any).chrome?.runtime?.id !== undefined;
@ -47,7 +95,7 @@ export function validateToolArguments(tool: Tool, toolCall: ToolCall): any {
const validate = ajv.compile(tool.parameters);
// Clone arguments so AJV can safely mutate for type coercion
const args = structuredClone(toolCall.arguments);
const args = coerceStringArrays(tool.parameters, structuredClone(toolCall.arguments));
// Validate the arguments (AJV mutates args in-place for type coercion)
if (validate(args)) {

View file

@ -202,3 +202,44 @@ describe("ModelRegistry.getModel — convenience wrapper", () => {
assert.equal(model.provider, "zai");
});
});
// ── provider_model_allow final filter ─────────────────────────────────────────
// Checks that the allow-list passed to getModelsForProxy keeps allowed
// candidates, drops disallowed ones, and ignores providers it does not mention.
describe("ModelRegistry provider_model_allow filter", () => {
  // Fresh allow-list per call so no object is shared between test cases.
  const minimaxM27Only = () => ({ minimax: ["MiniMax-M2.7"] });
  const labelsOf = (models: ReadonlyArray<{ provider: string; id: string }>) =>
    models.map((m) => `${m.provider}/${m.id}`);

  it("keeps an allowed provider/model candidate", () => {
    const registry = createRegistry();
    registerNone(registry, "minimax", "MiniMax-M2.7");

    const result = registry.getModelsForProxy("MiniMax-M2.7", {}, minimaxM27Only());

    const survived = result.some((model) => model.provider === "minimax" && model.id === "MiniMax-M2.7");
    assert.ok(survived, "allowed minimax/MiniMax-M2.7 candidate should survive filtering");
  });

  it("filters a disallowed provider/model candidate and falls through", () => {
    const registry = createRegistry();
    registerNone(registry, "minimax", "MiniMax-M2");
    registerNone(registry, "minimax-cn", "MiniMax-M2");

    const result = registry.getModelsForProxy("MiniMax-M2", {}, minimaxM27Only());

    assert.deepEqual(labelsOf(result), ["minimax-cn/MiniMax-M2"]);
  });

  it("leaves providers absent from the allow-list unrestricted", () => {
    const registry = createRegistry();
    registerNone(registry, "zai", "glm-4-air");

    const result = registry.getModelsForProxy("glm-4-air", {}, minimaxM27Only());

    assert.deepEqual(labelsOf(result), ["zai/glm-4-air"]);
  });
});

View file

@ -197,6 +197,8 @@ ajv.addSchema(ModelsConfigSchema, "ModelsConfig");
type ModelsConfig = Static<typeof ModelsConfigSchema>;
/** Per-provider model allow-list: maps a provider ID to the model IDs accepted from that provider. */
export type ProviderModelAllowList = Record<string, readonly string[]>;
export type ProviderAuthMode = "apiKey" | "oauth" | "externalCli" | "none";
/** Provider override config (baseUrl, headers, apiKey) without custom models */
@ -425,6 +427,28 @@ export class ModelRegistry {
return merged;
}
/**
 * Decides whether `provider`/`modelId` passes the per-provider allow-list.
 *
 * The provider is matched case-insensitively against the allow-list keys;
 * model IDs are compared after trimming and lowercasing both sides.
 *
 * @param provider Provider ID of the candidate model.
 * @param modelId Model ID of the candidate model.
 * @param providerModelAllow Optional allow-list; when omitted everything is allowed.
 * @returns `true` when no allow-list is given, when the provider has no entry,
 *   or when the model is listed for the provider; `false` otherwise.
 */
private isProviderModelAllowed(
  provider: string,
  modelId: string,
  providerModelAllow?: ProviderModelAllowList,
): boolean {
  if (!providerModelAllow) return true;
  const providerKey = provider.toLowerCase();
  // Only own entries count: a plain `providerModelAllow[providerKey]` would also
  // resolve inherited Object.prototype members (e.g. "constructor"), which are
  // not arrays and would make the `.some` call below throw.
  const direct = Object.prototype.hasOwnProperty.call(providerModelAllow, providerKey)
    ? providerModelAllow[providerKey]
    : undefined;
  const allowedModels = direct
    ?? Object.entries(providerModelAllow).find(([key]) => key.toLowerCase() === providerKey)?.[1];
  // Providers absent from the allow-list are unrestricted.
  if (allowedModels === undefined) return true;
  const modelKey = modelId.trim().toLowerCase();
  return allowedModels.some((allowedModel) => allowedModel.trim().toLowerCase() === modelKey);
}
/**
 * Applies the per-provider allow-list to a list of models.
 *
 * Returns the input array itself when no allow-list is given or it has no
 * keys; otherwise returns a new array holding only the models accepted by
 * `isProviderModelAllowed`, in their original order.
 */
private filterProviderModelAllow<T extends Model<Api>>(
  models: T[],
  providerModelAllow?: ProviderModelAllowList,
): T[] {
  if (!providerModelAllow) return models;
  if (Object.keys(providerModelAllow).length === 0) return models;

  const kept: T[] = [];
  for (const candidate of models) {
    if (this.isProviderModelAllowed(candidate.provider, candidate.id, providerModelAllow)) {
      kept.push(candidate);
    }
  }
  return kept;
}
private loadCustomModels(modelsJsonPath: string): CustomModelsResult {
if (!existsSync(modelsJsonPath)) {
return emptyCustomModelsResult();
@ -590,16 +614,19 @@ export class ModelRegistry {
* Get all models (built-in + custom).
* If models.json had errors, returns only built-in models.
*/
getAll(): Model<Api>[] {
return this.models;
getAll(providerModelAllow?: ProviderModelAllowList): Model<Api>[] {
return this.filterProviderModelAllow(this.models, providerModelAllow);
}
/**
* Get only models that have auth configured.
* This is a fast check that doesn't refresh OAuth tokens.
*/
getAvailable(): Model<Api>[] {
return this.models.filter((m) => this.isProviderRequestReady(m.provider));
getAvailable(providerModelAllow?: ProviderModelAllowList): Model<Api>[] {
return this.filterProviderModelAllow(
this.models.filter((m) => this.isProviderRequestReady(m.provider)),
providerModelAllow,
);
}
/**
@ -867,8 +894,15 @@ export class ModelRegistry {
* Candidates with auth configured are placed first within the same priority tier.
* The proxy server iterates this list and falls through to the next entry on 429.
*/
getModelsForProxy(modelId: string, overrides: Record<string, string[]> = {}): Model<Api>[] {
const candidates = this.models.filter((m) => m.id === modelId);
getModelsForProxy(
modelId: string,
overrides: Record<string, string[]> = {},
providerModelAllow?: ProviderModelAllowList,
): Model<Api>[] {
const candidates = this.filterProviderModelAllow(
this.models.filter((m) => m.id === modelId),
providerModelAllow,
);
if (candidates.length === 0) return [];
const order = this.buildCandidateOrder(modelId, overrides);
@ -887,8 +921,12 @@ export class ModelRegistry {
* Resolve a bare model ID to the single highest-priority candidate.
* Convenience wrapper over getModelsForProxy for callers that don't need retry.
*/
getModel(modelId: string, overrides: Record<string, string[]> = {}): Model<Api> | undefined {
return this.getModelsForProxy(modelId, overrides)[0];
getModel(
modelId: string,
overrides: Record<string, string[]> = {},
providerModelAllow?: ProviderModelAllowList,
): Model<Api> | undefined {
return this.getModelsForProxy(modelId, overrides, providerModelAllow)[0];
}
/**

View file

@ -8,7 +8,7 @@ import type { Api, Model } from "@singularity-forge/pi-ai";
import { getProviderCapabilities } from "@singularity-forge/pi-ai";
import type { ExtensionAPI, ExtensionContext } from "@singularity-forge/pi-coding-agent";
import type { SFPreferences } from "./preferences.js";
import { resolveModelWithFallbacksForUnit, resolveDynamicRoutingConfig, resolvePersistModelChanges } from "./preferences.js";
import { filterModelsByProviderModelAllow, resolveModelWithFallbacksForUnit, resolveDynamicRoutingConfig, resolvePersistModelChanges } from "./preferences.js";
import type { ComplexityTier } from "./complexity-classifier.js";
import { classifyUnitComplexity, extractTaskMetadata, tierLabel } from "./complexity-classifier.js";
import { resolveModelForComplexity, escalateTier, getEligibleModels, loadCapabilityOverrides, adjustToolSet, filterToolsForProvider } from "./model-router.js";
@ -212,16 +212,17 @@ export async function selectAndApplyModel(
// 400 "model not supported" dispatch failures in dr-repo.
const rawAvailable = ctx.modelRegistry.getAvailable();
const allowed = prefs?.allowed_providers;
const availableModels = (allowed && allowed.length > 0)
const providerAllowedModels = (allowed && allowed.length > 0)
? rawAvailable.filter(m => allowed.includes(m.provider.toLowerCase()))
: rawAvailable;
if (allowed && allowed.length > 0 && availableModels.length === 0) {
if (allowed && allowed.length > 0 && providerAllowedModels.length === 0) {
throw new Error(
`allowed_providers filter rejected every available model. ` +
`Configured providers: [${allowed.join(", ")}]. ` +
`Either add a provider to allowed_providers or remove the pref.`,
);
}
const availableModels = filterModelsByProviderModelAllow(providerAllowedModels, prefs?.provider_model_allow);
const modelPolicyTraceId = `model:${ctx.sessionManager.getSessionId()}:${Date.now()}`;
const modelPolicyTurnId = `${unitType}:${unitId}`;
let policyAllowedModelKeys: Set<string> | null = null;
@ -544,7 +545,10 @@ export async function selectAndApplyModel(
} else if (autoModeStartModel) {
// No model preference for this unit type — re-apply the model captured
// at auto-mode start to prevent bleed from shared global settings.json (#650).
const availableModels = ctx.modelRegistry.getAvailable();
const availableModels = filterModelsByProviderModelAllow(
ctx.modelRegistry.getAvailable(),
prefs?.provider_model_allow,
);
const startModel = availableModels.find(
m => m.provider === autoModeStartModel.provider && m.id === autoModeStartModel.id,
);

View file

@ -27,6 +27,37 @@ import { loadEffectiveSFPreferences, getGlobalSFPreferencesPath } from "./prefer
// Re-export types so existing consumers of ./preferences-models.js keep working
export type { SFPhaseModelConfig, SFModelConfig, SFModelConfigV2, ResolvedModelConfig } from "./preferences-types.js";
/** Per-provider model allow-list: maps a provider ID to the model IDs accepted from that provider. */
export type ProviderModelAllowList = Record<string, readonly string[]>;

/**
 * Finds the allow-list entry for `provider`, matching keys case-insensitively.
 *
 * @returns The provider's allowed model IDs, or `undefined` when there is no
 *   allow-list or the provider has no entry in it.
 */
function resolveProviderModelAllowList(
  providerModelAllow: ProviderModelAllowList | undefined,
  provider: string,
): readonly string[] | undefined {
  if (!providerModelAllow) return undefined;
  const providerKey = provider.toLowerCase();
  // Only own entries count: a plain `providerModelAllow[providerKey]` would also
  // resolve inherited Object.prototype members (e.g. "constructor", "__proto__"),
  // which are not arrays and would crash callers that iterate the result.
  const direct = Object.prototype.hasOwnProperty.call(providerModelAllow, providerKey)
    ? providerModelAllow[providerKey]
    : undefined;
  return direct
    ?? Object.entries(providerModelAllow).find(([key]) => key.toLowerCase() === providerKey)?.[1];
}

/**
 * Decides whether `provider`/`modelId` passes the per-provider allow-list.
 *
 * Model IDs are compared after trimming and lowercasing both sides.
 *
 * @returns `true` when the provider has no entry (unrestricted) or the model
 *   is listed for it; `false` when the provider is listed but the model is not.
 */
export function isProviderModelAllowed(
  provider: string,
  modelId: string,
  providerModelAllow: ProviderModelAllowList | undefined,
): boolean {
  const allowedModels = resolveProviderModelAllowList(providerModelAllow, provider);
  if (allowedModels === undefined) return true;
  const modelKey = modelId.trim().toLowerCase();
  return allowedModels.some((allowedModel) => allowedModel.trim().toLowerCase() === modelKey);
}

/**
 * Filters `models` down to those permitted by the per-provider allow-list.
 *
 * @param models Candidate models; never mutated.
 * @param providerModelAllow Allow-list; `undefined` or an empty object disables filtering.
 * @returns A new array (always a copy) with the permitted models in their original order.
 */
export function filterModelsByProviderModelAllow<T extends { provider: string; id: string }>(
  models: readonly T[],
  providerModelAllow: ProviderModelAllowList | undefined,
): T[] {
  if (!providerModelAllow || Object.keys(providerModelAllow).length === 0) return [...models];
  return models.filter((model) => isProviderModelAllowed(model.provider, model.id, providerModelAllow));
}
/**
* Resolve which model ID to use for a given auto-mode unit type.
* Returns undefined if no model preference is set for this unit type.
@ -63,18 +94,11 @@ function resolveAutoBenchmarkPickForUnit(
): ResolvedModelConfig | undefined {
try {
const allowed = prefs?.allowed_providers?.map(s => s.toLowerCase());
const modelAllow = prefs?.provider_model_allow;
const candidates: Array<{ provider: string; id: string }> = [];
for (const provider of getProviders()) {
if (allowed && !allowed.includes(provider.toLowerCase())) continue;
// Per-provider model allow-list: when a provider is listed here, only
// the listed model IDs may be used from that provider.
const providerKey = Object.keys(modelAllow ?? {}).find(
k => k.toLowerCase() === provider.toLowerCase(),
);
const allowedModels = providerKey ? modelAllow![providerKey] : undefined;
for (const model of getModels(provider)) {
if (allowedModels && !allowedModels.includes(model.id)) continue;
if (!isProviderModelAllowed(provider, model.id, prefs?.provider_model_allow)) continue;
candidates.push({ provider, id: model.id });
}
}

View file

@ -426,7 +426,7 @@ export function validatePreferences(preferences: SFPreferences): {
}
// ─── Per-provider model allow-list ──────────────────────────────────
// When a provider has an entry here, ONLY listed model IDs are usable
// When a provider has an entry here, only listed model IDs are usable
// from that provider. Providers absent from the block are unrestricted.
if (preferences.provider_model_allow !== undefined) {
if (
@ -436,12 +436,19 @@ export function validatePreferences(preferences: SFPreferences): {
) {
const cleaned: Record<string, string[]> = {};
for (const [provider, models] of Object.entries(preferences.provider_model_allow as Record<string, unknown>)) {
const providerId = provider.trim().toLowerCase();
if (!providerId) {
errors.push("provider_model_allow provider IDs must be non-empty strings");
continue;
}
if (!Array.isArray(models) || models.some((m: unknown) => typeof m !== "string")) {
errors.push(`provider_model_allow.${provider} must be an array of model ID strings`);
continue;
}
const list = (models as string[]).map(s => s.trim()).filter(s => s.length > 0);
if (list.length > 0) cleaned[provider.toLowerCase()] = list;
const list = (models as string[])
.map(s => s.trim())
.filter(s => s.length > 0);
cleaned[providerId] = Array.from(new Set(list));
}
if (Object.keys(cleaned).length > 0) validated.provider_model_allow = cleaned;
} else {

View file

@ -95,6 +95,8 @@ export {
resolveInlineLevel,
resolveContextSelection,
resolveSearchProviderFromPreferences,
isProviderModelAllowed,
filterModelsByProviderModelAllow,
} from "./preferences-models.js";
// ─── Path Constants & Getters ───────────────────────────────────────────────
@ -509,11 +511,7 @@ function mergePreferences(base: SFPreferences, override: SFPreferences): SFPrefe
// override-wins merge so the preference actually reaches consumers.
allowed_providers: mergeStringLists(base.allowed_providers, override.allowed_providers),
provider_preference: override.provider_preference ?? base.provider_preference,
// Per-provider replace inside the array (project entry replaces global
// for that provider; providers only in base survive).
provider_model_allow: (base.provider_model_allow || override.provider_model_allow)
? { ...(base.provider_model_allow ?? {}), ...(override.provider_model_allow ?? {}) }
: undefined,
provider_model_allow: mergeProviderModelAllow(base.provider_model_allow, override.provider_model_allow),
flat_rate_providers: mergeStringLists(base.flat_rate_providers, override.flat_rate_providers),
stale_commit_threshold_minutes: override.stale_commit_threshold_minutes ?? base.stale_commit_threshold_minutes,
widget_mode: override.widget_mode ?? base.widget_mode,
@ -536,6 +534,23 @@ function mergeStringLists(base?: unknown, override?: unknown): string[] | undefi
return merged.length > 0 ? Array.from(new Set(merged)) : undefined;
}
/**
 * Merges two provider → model-ID allow-lists.
 *
 * Base entries are applied first, then override entries. An override entry
 * replaces the base array for the same provider instead of appending to it;
 * providers present only in `base` survive. Arrays are shallow-copied so the
 * result shares no array with either input.
 *
 * @returns The merged map, or `undefined` when both inputs are missing or the
 *   merge produced no providers.
 */
function mergeProviderModelAllow(
  base?: Record<string, string[]>,
  override?: Record<string, string[]>,
): Record<string, string[]> | undefined {
  if (!base && !override) return undefined;

  const combined: Record<string, string[]> = {};
  // Later layers win per provider, so `override` is processed after `base`.
  for (const layer of [base, override]) {
    for (const [providerId, modelIds] of Object.entries(layer ?? {})) {
      combined[providerId] = [...modelIds];
    }
  }

  if (Object.keys(combined).length === 0) return undefined;
  return combined;
}
function mergePostUnitHooks(
base?: PostUnitHookConfig[],
override?: PostUnitHookConfig[],

View file

@ -0,0 +1,130 @@
import test from "node:test";
import assert from "node:assert/strict";
import { mkdtempSync, mkdirSync, rmSync, writeFileSync } from "node:fs";
import { tmpdir } from "node:os";
import { join } from "node:path";
import { loadEffectiveSFPreferences, validatePreferences } from "../preferences.ts";
import { filterModelsByProviderModelAllow } from "../preferences-models.ts";
// Renders models as "provider/id" labels for compact assertions.
const labelsOf = (models: ReadonlyArray<{ provider: string; id: string }>): string[] =>
  models.map((m) => `${m.provider}/${m.id}`);

// A listed provider keeps only the model IDs named in its allow-list entry.
test("provider_model_allow: provider in allow-list and model in list passes", () => {
  const candidates = [
    { provider: "minimax", id: "MiniMax-M2.7" },
    { provider: "minimax", id: "MiniMax-M2" },
  ];

  const kept = filterModelsByProviderModelAllow(candidates, { minimax: ["MiniMax-M2.7"] });

  assert.deepEqual(labelsOf(kept), ["minimax/MiniMax-M2.7"]);
});

// A listed provider's unlisted model is dropped while other providers remain.
test("provider_model_allow: provider in allow-list and model not in list is filtered", () => {
  const candidates = [
    { provider: "minimax", id: "MiniMax-M2" },
    { provider: "zai", id: "glm-5" },
  ];

  const kept = filterModelsByProviderModelAllow(candidates, { minimax: ["MiniMax-M2.7"] });

  assert.deepEqual(
    labelsOf(kept),
    ["zai/glm-5"],
    "minimax/MiniMax-M2 is removed so selection can fall through to the next provider",
  );
});

// A provider with no allow-list entry is not filtered at all.
test("provider_model_allow: provider absent from allow-list is unrestricted", () => {
  const candidates = [
    { provider: "minimax", id: "MiniMax-M2" },
    { provider: "zai", id: "glm-5" },
  ];

  const kept = filterModelsByProviderModelAllow(candidates, { minimax: ["MiniMax-M2.7"] });

  const zaiSurvived = kept.some((m) => m.provider === "zai" && m.id === "glm-5");
  assert.ok(zaiSurvived);
});
// Provider keys are lowercased and model IDs trimmed by validatePreferences.
test("provider_model_allow: validates shape and normalizes provider IDs", () => {
  const { errors, preferences } = validatePreferences({
    provider_model_allow: {
      MiniMax: [" MiniMax-M2.7 ", "MiniMax-M2.7-highspeed"],
    },
  });

  assert.equal(errors.length, 0);
  assert.deepEqual(preferences.provider_model_allow, {
    minimax: ["MiniMax-M2.7", "MiniMax-M2.7-highspeed"],
  });
});

// A non-array value and an array with a non-string member each report an error,
// and no allow-list is emitted.
test("provider_model_allow: rejects invalid shapes", () => {
  const { errors, preferences } = validatePreferences({
    provider_model_allow: {
      minimax: "MiniMax-M2.7",
      zai: ["glm-5", 42],
    } as any,
  });

  const mentions = (fragment: string) => errors.some((error) => error.includes(fragment));
  assert.ok(mentions("provider_model_allow.minimax"));
  assert.ok(mentions("provider_model_allow.zai"));
  assert.equal(preferences.provider_model_allow, undefined);
});
// End-to-end merge check: writes one preferences file under a temporary SF_HOME
// and one under a temporary project's `.sf` directory, loads the effective
// preferences, and expects the project's `minimax` entry to replace the other
// file's `minimax` entry while the `zai` entry (present only there) is kept.
test("provider_model_allow: project overrides global per provider", () => {
// Captured up front so the finally block can restore process-wide state.
const originalCwd = process.cwd();
const originalSfHome = process.env.SF_HOME;
const tempProject = mkdtempSync(join(tmpdir(), "sf-provider-model-allow-project-"));
const tempHome = mkdtempSync(join(tmpdir(), "sf-provider-model-allow-home-"));
try {
mkdirSync(join(tempProject, ".sf"), { recursive: true });
// Preferences stored under SF_HOME: restrict both minimax and zai.
// NOTE(review): presumably this is what the loader treats as the global file — confirm.
writeFileSync(
join(tempHome, "preferences.md"),
[
"---",
"provider_model_allow:",
" minimax:",
" - MiniMax-M2.7",
" zai:",
" - glm-5",
"---",
].join("\n"),
"utf-8",
);
// Project-level preferences: only a minimax entry, with a different model.
writeFileSync(
join(tempProject, ".sf", "PREFERENCES.md"),
[
"---",
"provider_model_allow:",
" minimax:",
" - MiniMax-M2.7-highspeed",
"---",
].join("\n"),
"utf-8",
);
// Point the process at the temporary home and project before loading.
process.env.SF_HOME = tempHome;
process.chdir(tempProject);
const loaded = loadEffectiveSFPreferences();
assert.notEqual(loaded, null);
// minimax comes from the project file (replaced, not appended); zai survives.
assert.deepEqual(loaded!.preferences.provider_model_allow, {
minimax: ["MiniMax-M2.7-highspeed"],
zai: ["glm-5"],
});
} finally {
// Restore cwd and SF_HOME (deleting it if it was originally unset), then
// remove both temporary directories.
process.chdir(originalCwd);
if (originalSfHome === undefined) delete process.env.SF_HOME;
else process.env.SF_HOME = originalSfHome;
rmSync(tempProject, { recursive: true, force: true });
rmSync(tempHome, { recursive: true, force: true });
}
});