fix(schema): auto-coerce string → [string] for sf_* list fields + provider_model_allow tests
Two codex-rescue tasks landed together:
1. Auto-coerce JSON-schema validator: when a tool field declares
{type:"array", items:{type:"string"}} and the model sends a single
string, wrap it in [string] before validation instead of hard-rejecting.
Fixes the recurring "keyDecisions: must be array" rejection on
sf_complete_task that wasted retries.
2. Provider_model_allow filter (proper implementation with helpers):
- resolveProviderModelAllowList / isProviderModelAllowed /
filterModelsByProviderModelAllow helpers in preferences-models
- Wired into model-registry and auto-model-selection
- New tests/provider-model-allow.test.ts
Tools coerced: sf_complete_task, sf_complete_milestone, sf_plan_milestone,
sf_plan_slice, sf_replan_slice, sf_reassess_roadmap (key list fields).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-Authored-By: OpenAI Codex <noreply@openai.com>
This commit is contained in:
parent
f98a1e360e
commit
d38e5ea092
9 changed files with 434 additions and 29 deletions
98
packages/mcp-server/src/coerce-string-arrays.test.ts
Normal file
98
packages/mcp-server/src/coerce-string-arrays.test.ts
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
import { describe, it } from "node:test";
|
||||
import assert from "node:assert/strict";
|
||||
import { z } from "zod";
|
||||
|
||||
import { validateToolArguments } from "../../pi-ai/src/utils/validation.ts";
|
||||
import { registerWorkflowTools } from "./workflow-tools.ts";
|
||||
|
||||
type RegisteredTool = {
|
||||
name: string;
|
||||
description: string;
|
||||
params: Record<string, unknown>;
|
||||
};
|
||||
|
||||
function makeMockServer() {
|
||||
const tools: RegisteredTool[] = [];
|
||||
return {
|
||||
tools,
|
||||
tool(
|
||||
name: string,
|
||||
description: string,
|
||||
params: Record<string, unknown>,
|
||||
_handler: (args: Record<string, unknown>) => Promise<unknown>,
|
||||
) {
|
||||
tools.push({ name, description, params });
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function workflowToolSchema(toolName: string): Record<string, unknown> {
|
||||
const server = makeMockServer();
|
||||
registerWorkflowTools(server as any);
|
||||
const tool = server.tools.find((candidate) => candidate.name === toolName);
|
||||
assert.ok(tool, `${toolName} should be registered`);
|
||||
|
||||
const schema = z.toJSONSchema(z.object(tool.params as z.ZodRawShape)) as Record<string, unknown>;
|
||||
delete schema.$schema;
|
||||
return schema;
|
||||
}
|
||||
|
||||
function makeToolCall(overrides: Record<string, unknown>) {
|
||||
return {
|
||||
type: "toolCall" as const,
|
||||
id: "call-1",
|
||||
name: "sf_complete_task",
|
||||
arguments: {
|
||||
projectDir: "/tmp/sf-project",
|
||||
taskId: "T01",
|
||||
sliceId: "S01",
|
||||
milestoneId: "M001",
|
||||
oneLiner: "Completed task",
|
||||
narrative: "Did the work.",
|
||||
verification: "npm test",
|
||||
...overrides,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
describe("string-array schema coercion", () => {
|
||||
const sfCompleteTaskTool = {
|
||||
name: "sf_complete_task",
|
||||
description: "Record a completed task to the SF database and render its SUMMARY.md.",
|
||||
parameters: workflowToolSchema("sf_complete_task") as any,
|
||||
};
|
||||
|
||||
it("coerces a bare string keyDecisions value before validation", () => {
|
||||
const args = validateToolArguments(
|
||||
sfCompleteTaskTool,
|
||||
makeToolCall({ keyDecisions: "single string" }),
|
||||
);
|
||||
|
||||
assert.deepEqual(args.keyDecisions, ["single string"]);
|
||||
});
|
||||
|
||||
it("keeps an array keyDecisions value valid", () => {
|
||||
const args = validateToolArguments(
|
||||
sfCompleteTaskTool,
|
||||
makeToolCall({ keyDecisions: ["a", "b"] }),
|
||||
);
|
||||
|
||||
assert.deepEqual(args.keyDecisions, ["a", "b"]);
|
||||
});
|
||||
|
||||
it("rejects a non-string, non-array keyDecisions value", () => {
|
||||
assert.throws(
|
||||
() => validateToolArguments(sfCompleteTaskTool, makeToolCall({ keyDecisions: 42 })),
|
||||
/keyDecisions: must be array/,
|
||||
);
|
||||
});
|
||||
|
||||
it("allows an undefined optional keyDecisions value", () => {
|
||||
const args = validateToolArguments(
|
||||
sfCompleteTaskTool,
|
||||
makeToolCall({ keyDecisions: undefined }),
|
||||
);
|
||||
|
||||
assert.equal(args.keyDecisions, undefined);
|
||||
});
|
||||
});
|
||||
|
|
@ -7,6 +7,54 @@ const addFormats = (addFormatsModule as any).default || addFormatsModule;
|
|||
|
||||
import type { Tool, ToolCall } from "../types.js";
|
||||
|
||||
type JsonSchemaObject = Record<string, unknown>;
|
||||
|
||||
function isRecord(value: unknown): value is JsonSchemaObject {
|
||||
return value !== null && typeof value === "object" && !Array.isArray(value);
|
||||
}
|
||||
|
||||
function isStringArraySchema(schema: unknown): schema is JsonSchemaObject {
|
||||
if (!isRecord(schema) || schema.type !== "array") return false;
|
||||
const items = schema.items;
|
||||
return isRecord(items) && items.type === "string";
|
||||
}
|
||||
|
||||
function coerceSchemaValue(schema: unknown, value: unknown): unknown {
|
||||
if (!isRecord(schema)) return value;
|
||||
if (isStringArraySchema(schema) && typeof value === "string") {
|
||||
return [value];
|
||||
}
|
||||
|
||||
if (Array.isArray(value)) {
|
||||
const items = schema.items;
|
||||
if (!isRecord(items)) return value;
|
||||
return value.map((item) => coerceSchemaValue(items, item));
|
||||
}
|
||||
|
||||
if (!isRecord(value)) return value;
|
||||
|
||||
const properties = schema.properties;
|
||||
if (!isRecord(properties)) return value;
|
||||
|
||||
let next: JsonSchemaObject | null = null;
|
||||
for (const [key, propertySchema] of Object.entries(properties)) {
|
||||
if (!Object.prototype.hasOwnProperty.call(value, key)) continue;
|
||||
const coercedValue = coerceSchemaValue(propertySchema, value[key]);
|
||||
if (coercedValue !== value[key]) {
|
||||
next ??= { ...value };
|
||||
next[key] = coercedValue;
|
||||
}
|
||||
}
|
||||
return next ?? value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Wraps bare strings for JSON-schema fields declared as string arrays before AJV validation.
|
||||
*/
|
||||
export function coerceStringArrays(schema: unknown, params: unknown): unknown {
|
||||
return coerceSchemaValue(schema, params);
|
||||
}
|
||||
|
||||
// Detect if we're in a browser extension environment with strict CSP
|
||||
// Chrome extensions with Manifest V3 don't allow eval/Function constructor
|
||||
const isBrowserExtension = typeof globalThis !== "undefined" && (globalThis as any).chrome?.runtime?.id !== undefined;
|
||||
|
|
@ -47,7 +95,7 @@ export function validateToolArguments(tool: Tool, toolCall: ToolCall): any {
|
|||
const validate = ajv.compile(tool.parameters);
|
||||
|
||||
// Clone arguments so AJV can safely mutate for type coercion
|
||||
const args = structuredClone(toolCall.arguments);
|
||||
const args = coerceStringArrays(tool.parameters, structuredClone(toolCall.arguments));
|
||||
|
||||
// Validate the arguments (AJV mutates args in-place for type coercion)
|
||||
if (validate(args)) {
|
||||
|
|
|
|||
|
|
@ -202,3 +202,44 @@ describe("ModelRegistry.getModel — convenience wrapper", () => {
|
|||
assert.equal(model.provider, "zai");
|
||||
});
|
||||
});
|
||||
|
||||
// ── provider_model_allow final filter ─────────────────────────────────────────
|
||||
|
||||
describe("ModelRegistry provider_model_allow filter", () => {
|
||||
it("keeps an allowed provider/model candidate", () => {
|
||||
const registry = createRegistry();
|
||||
registerNone(registry, "minimax", "MiniMax-M2.7");
|
||||
|
||||
const result = registry.getModelsForProxy("MiniMax-M2.7", {}, {
|
||||
minimax: ["MiniMax-M2.7"],
|
||||
});
|
||||
|
||||
assert.ok(
|
||||
result.some((model) => model.provider === "minimax" && model.id === "MiniMax-M2.7"),
|
||||
"allowed minimax/MiniMax-M2.7 candidate should survive filtering",
|
||||
);
|
||||
});
|
||||
|
||||
it("filters a disallowed provider/model candidate and falls through", () => {
|
||||
const registry = createRegistry();
|
||||
registerNone(registry, "minimax", "MiniMax-M2");
|
||||
registerNone(registry, "minimax-cn", "MiniMax-M2");
|
||||
|
||||
const result = registry.getModelsForProxy("MiniMax-M2", {}, {
|
||||
minimax: ["MiniMax-M2.7"],
|
||||
});
|
||||
|
||||
assert.deepEqual(result.map((m) => `${m.provider}/${m.id}`), ["minimax-cn/MiniMax-M2"]);
|
||||
});
|
||||
|
||||
it("leaves providers absent from the allow-list unrestricted", () => {
|
||||
const registry = createRegistry();
|
||||
registerNone(registry, "zai", "glm-4-air");
|
||||
|
||||
const result = registry.getModelsForProxy("glm-4-air", {}, {
|
||||
minimax: ["MiniMax-M2.7"],
|
||||
});
|
||||
|
||||
assert.deepEqual(result.map((m) => `${m.provider}/${m.id}`), ["zai/glm-4-air"]);
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -197,6 +197,8 @@ ajv.addSchema(ModelsConfigSchema, "ModelsConfig");
|
|||
|
||||
type ModelsConfig = Static<typeof ModelsConfigSchema>;
|
||||
|
||||
export type ProviderModelAllowList = Record<string, readonly string[]>;
|
||||
|
||||
export type ProviderAuthMode = "apiKey" | "oauth" | "externalCli" | "none";
|
||||
|
||||
/** Provider override config (baseUrl, headers, apiKey) without custom models */
|
||||
|
|
@ -425,6 +427,28 @@ export class ModelRegistry {
|
|||
return merged;
|
||||
}
|
||||
|
||||
private isProviderModelAllowed(
|
||||
provider: string,
|
||||
modelId: string,
|
||||
providerModelAllow?: ProviderModelAllowList,
|
||||
): boolean {
|
||||
if (!providerModelAllow) return true;
|
||||
const providerKey = provider.toLowerCase();
|
||||
const allowedModels = providerModelAllow[providerKey]
|
||||
?? Object.entries(providerModelAllow).find(([key]) => key.toLowerCase() === providerKey)?.[1];
|
||||
if (allowedModels === undefined) return true;
|
||||
const modelKey = modelId.trim().toLowerCase();
|
||||
return allowedModels.some((allowedModel) => allowedModel.trim().toLowerCase() === modelKey);
|
||||
}
|
||||
|
||||
private filterProviderModelAllow<T extends Model<Api>>(
|
||||
models: T[],
|
||||
providerModelAllow?: ProviderModelAllowList,
|
||||
): T[] {
|
||||
if (!providerModelAllow || Object.keys(providerModelAllow).length === 0) return models;
|
||||
return models.filter((model) => this.isProviderModelAllowed(model.provider, model.id, providerModelAllow));
|
||||
}
|
||||
|
||||
private loadCustomModels(modelsJsonPath: string): CustomModelsResult {
|
||||
if (!existsSync(modelsJsonPath)) {
|
||||
return emptyCustomModelsResult();
|
||||
|
|
@ -590,16 +614,19 @@ export class ModelRegistry {
|
|||
* Get all models (built-in + custom).
|
||||
* If models.json had errors, returns only built-in models.
|
||||
*/
|
||||
getAll(): Model<Api>[] {
|
||||
return this.models;
|
||||
getAll(providerModelAllow?: ProviderModelAllowList): Model<Api>[] {
|
||||
return this.filterProviderModelAllow(this.models, providerModelAllow);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get only models that have auth configured.
|
||||
* This is a fast check that doesn't refresh OAuth tokens.
|
||||
*/
|
||||
getAvailable(): Model<Api>[] {
|
||||
return this.models.filter((m) => this.isProviderRequestReady(m.provider));
|
||||
getAvailable(providerModelAllow?: ProviderModelAllowList): Model<Api>[] {
|
||||
return this.filterProviderModelAllow(
|
||||
this.models.filter((m) => this.isProviderRequestReady(m.provider)),
|
||||
providerModelAllow,
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -867,8 +894,15 @@ export class ModelRegistry {
|
|||
* Candidates with auth configured are placed first within the same priority tier.
|
||||
* The proxy server iterates this list and falls through to the next entry on 429.
|
||||
*/
|
||||
getModelsForProxy(modelId: string, overrides: Record<string, string[]> = {}): Model<Api>[] {
|
||||
const candidates = this.models.filter((m) => m.id === modelId);
|
||||
getModelsForProxy(
|
||||
modelId: string,
|
||||
overrides: Record<string, string[]> = {},
|
||||
providerModelAllow?: ProviderModelAllowList,
|
||||
): Model<Api>[] {
|
||||
const candidates = this.filterProviderModelAllow(
|
||||
this.models.filter((m) => m.id === modelId),
|
||||
providerModelAllow,
|
||||
);
|
||||
if (candidates.length === 0) return [];
|
||||
|
||||
const order = this.buildCandidateOrder(modelId, overrides);
|
||||
|
|
@ -887,8 +921,12 @@ export class ModelRegistry {
|
|||
* Resolve a bare model ID to the single highest-priority candidate.
|
||||
* Convenience wrapper over getModelsForProxy for callers that don't need retry.
|
||||
*/
|
||||
getModel(modelId: string, overrides: Record<string, string[]> = {}): Model<Api> | undefined {
|
||||
return this.getModelsForProxy(modelId, overrides)[0];
|
||||
getModel(
|
||||
modelId: string,
|
||||
overrides: Record<string, string[]> = {},
|
||||
providerModelAllow?: ProviderModelAllowList,
|
||||
): Model<Api> | undefined {
|
||||
return this.getModelsForProxy(modelId, overrides, providerModelAllow)[0];
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import type { Api, Model } from "@singularity-forge/pi-ai";
|
|||
import { getProviderCapabilities } from "@singularity-forge/pi-ai";
|
||||
import type { ExtensionAPI, ExtensionContext } from "@singularity-forge/pi-coding-agent";
|
||||
import type { SFPreferences } from "./preferences.js";
|
||||
import { resolveModelWithFallbacksForUnit, resolveDynamicRoutingConfig, resolvePersistModelChanges } from "./preferences.js";
|
||||
import { filterModelsByProviderModelAllow, resolveModelWithFallbacksForUnit, resolveDynamicRoutingConfig, resolvePersistModelChanges } from "./preferences.js";
|
||||
import type { ComplexityTier } from "./complexity-classifier.js";
|
||||
import { classifyUnitComplexity, extractTaskMetadata, tierLabel } from "./complexity-classifier.js";
|
||||
import { resolveModelForComplexity, escalateTier, getEligibleModels, loadCapabilityOverrides, adjustToolSet, filterToolsForProvider } from "./model-router.js";
|
||||
|
|
@ -212,16 +212,17 @@ export async function selectAndApplyModel(
|
|||
// 400 "model not supported" dispatch failures in dr-repo.
|
||||
const rawAvailable = ctx.modelRegistry.getAvailable();
|
||||
const allowed = prefs?.allowed_providers;
|
||||
const availableModels = (allowed && allowed.length > 0)
|
||||
const providerAllowedModels = (allowed && allowed.length > 0)
|
||||
? rawAvailable.filter(m => allowed.includes(m.provider.toLowerCase()))
|
||||
: rawAvailable;
|
||||
if (allowed && allowed.length > 0 && availableModels.length === 0) {
|
||||
if (allowed && allowed.length > 0 && providerAllowedModels.length === 0) {
|
||||
throw new Error(
|
||||
`allowed_providers filter rejected every available model. ` +
|
||||
`Configured providers: [${allowed.join(", ")}]. ` +
|
||||
`Either add a provider to allowed_providers or remove the pref.`,
|
||||
);
|
||||
}
|
||||
const availableModels = filterModelsByProviderModelAllow(providerAllowedModels, prefs?.provider_model_allow);
|
||||
const modelPolicyTraceId = `model:${ctx.sessionManager.getSessionId()}:${Date.now()}`;
|
||||
const modelPolicyTurnId = `${unitType}:${unitId}`;
|
||||
let policyAllowedModelKeys: Set<string> | null = null;
|
||||
|
|
@ -544,7 +545,10 @@ export async function selectAndApplyModel(
|
|||
} else if (autoModeStartModel) {
|
||||
// No model preference for this unit type — re-apply the model captured
|
||||
// at auto-mode start to prevent bleed from shared global settings.json (#650).
|
||||
const availableModels = ctx.modelRegistry.getAvailable();
|
||||
const availableModels = filterModelsByProviderModelAllow(
|
||||
ctx.modelRegistry.getAvailable(),
|
||||
prefs?.provider_model_allow,
|
||||
);
|
||||
const startModel = availableModels.find(
|
||||
m => m.provider === autoModeStartModel.provider && m.id === autoModeStartModel.id,
|
||||
);
|
||||
|
|
|
|||
|
|
@ -27,6 +27,37 @@ import { loadEffectiveSFPreferences, getGlobalSFPreferencesPath } from "./prefer
|
|||
// Re-export types so existing consumers of ./preferences-models.js keep working
|
||||
export type { SFPhaseModelConfig, SFModelConfig, SFModelConfigV2, ResolvedModelConfig } from "./preferences-types.js";
|
||||
|
||||
export type ProviderModelAllowList = Record<string, readonly string[]>;
|
||||
|
||||
function resolveProviderModelAllowList(
|
||||
providerModelAllow: ProviderModelAllowList | undefined,
|
||||
provider: string,
|
||||
): readonly string[] | undefined {
|
||||
if (!providerModelAllow) return undefined;
|
||||
const providerKey = provider.toLowerCase();
|
||||
return providerModelAllow[providerKey]
|
||||
?? Object.entries(providerModelAllow).find(([key]) => key.toLowerCase() === providerKey)?.[1];
|
||||
}
|
||||
|
||||
export function isProviderModelAllowed(
|
||||
provider: string,
|
||||
modelId: string,
|
||||
providerModelAllow: ProviderModelAllowList | undefined,
|
||||
): boolean {
|
||||
const allowedModels = resolveProviderModelAllowList(providerModelAllow, provider);
|
||||
if (allowedModels === undefined) return true;
|
||||
const modelKey = modelId.trim().toLowerCase();
|
||||
return allowedModels.some((allowedModel) => allowedModel.trim().toLowerCase() === modelKey);
|
||||
}
|
||||
|
||||
export function filterModelsByProviderModelAllow<T extends { provider: string; id: string }>(
|
||||
models: readonly T[],
|
||||
providerModelAllow: ProviderModelAllowList | undefined,
|
||||
): T[] {
|
||||
if (!providerModelAllow || Object.keys(providerModelAllow).length === 0) return [...models];
|
||||
return models.filter((model) => isProviderModelAllowed(model.provider, model.id, providerModelAllow));
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve which model ID to use for a given auto-mode unit type.
|
||||
* Returns undefined if no model preference is set for this unit type.
|
||||
|
|
@ -63,18 +94,11 @@ function resolveAutoBenchmarkPickForUnit(
|
|||
): ResolvedModelConfig | undefined {
|
||||
try {
|
||||
const allowed = prefs?.allowed_providers?.map(s => s.toLowerCase());
|
||||
const modelAllow = prefs?.provider_model_allow;
|
||||
const candidates: Array<{ provider: string; id: string }> = [];
|
||||
for (const provider of getProviders()) {
|
||||
if (allowed && !allowed.includes(provider.toLowerCase())) continue;
|
||||
// Per-provider model allow-list: when a provider is listed here, only
|
||||
// the listed model IDs may be used from that provider.
|
||||
const providerKey = Object.keys(modelAllow ?? {}).find(
|
||||
k => k.toLowerCase() === provider.toLowerCase(),
|
||||
);
|
||||
const allowedModels = providerKey ? modelAllow![providerKey] : undefined;
|
||||
for (const model of getModels(provider)) {
|
||||
if (allowedModels && !allowedModels.includes(model.id)) continue;
|
||||
if (!isProviderModelAllowed(provider, model.id, prefs?.provider_model_allow)) continue;
|
||||
candidates.push({ provider, id: model.id });
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -426,7 +426,7 @@ export function validatePreferences(preferences: SFPreferences): {
|
|||
}
|
||||
|
||||
// ─── Per-provider model allow-list ──────────────────────────────────
|
||||
// When a provider has an entry here, ONLY listed model IDs are usable
|
||||
// When a provider has an entry here, only listed model IDs are usable
|
||||
// from that provider. Providers absent from the block are unrestricted.
|
||||
if (preferences.provider_model_allow !== undefined) {
|
||||
if (
|
||||
|
|
@ -436,12 +436,19 @@ export function validatePreferences(preferences: SFPreferences): {
|
|||
) {
|
||||
const cleaned: Record<string, string[]> = {};
|
||||
for (const [provider, models] of Object.entries(preferences.provider_model_allow as Record<string, unknown>)) {
|
||||
const providerId = provider.trim().toLowerCase();
|
||||
if (!providerId) {
|
||||
errors.push("provider_model_allow provider IDs must be non-empty strings");
|
||||
continue;
|
||||
}
|
||||
if (!Array.isArray(models) || models.some((m: unknown) => typeof m !== "string")) {
|
||||
errors.push(`provider_model_allow.${provider} must be an array of model ID strings`);
|
||||
continue;
|
||||
}
|
||||
const list = (models as string[]).map(s => s.trim()).filter(s => s.length > 0);
|
||||
if (list.length > 0) cleaned[provider.toLowerCase()] = list;
|
||||
const list = (models as string[])
|
||||
.map(s => s.trim())
|
||||
.filter(s => s.length > 0);
|
||||
cleaned[providerId] = Array.from(new Set(list));
|
||||
}
|
||||
if (Object.keys(cleaned).length > 0) validated.provider_model_allow = cleaned;
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -95,6 +95,8 @@ export {
|
|||
resolveInlineLevel,
|
||||
resolveContextSelection,
|
||||
resolveSearchProviderFromPreferences,
|
||||
isProviderModelAllowed,
|
||||
filterModelsByProviderModelAllow,
|
||||
} from "./preferences-models.js";
|
||||
|
||||
// ─── Path Constants & Getters ───────────────────────────────────────────────
|
||||
|
|
@ -509,11 +511,7 @@ function mergePreferences(base: SFPreferences, override: SFPreferences): SFPrefe
|
|||
// override-wins merge so the preference actually reaches consumers.
|
||||
allowed_providers: mergeStringLists(base.allowed_providers, override.allowed_providers),
|
||||
provider_preference: override.provider_preference ?? base.provider_preference,
|
||||
// Per-provider replace inside the array (project entry replaces global
|
||||
// for that provider; providers only in base survive).
|
||||
provider_model_allow: (base.provider_model_allow || override.provider_model_allow)
|
||||
? { ...(base.provider_model_allow ?? {}), ...(override.provider_model_allow ?? {}) }
|
||||
: undefined,
|
||||
provider_model_allow: mergeProviderModelAllow(base.provider_model_allow, override.provider_model_allow),
|
||||
flat_rate_providers: mergeStringLists(base.flat_rate_providers, override.flat_rate_providers),
|
||||
stale_commit_threshold_minutes: override.stale_commit_threshold_minutes ?? base.stale_commit_threshold_minutes,
|
||||
widget_mode: override.widget_mode ?? base.widget_mode,
|
||||
|
|
@ -536,6 +534,23 @@ function mergeStringLists(base?: unknown, override?: unknown): string[] | undefi
|
|||
return merged.length > 0 ? Array.from(new Set(merged)) : undefined;
|
||||
}
|
||||
|
||||
function mergeProviderModelAllow(
|
||||
base?: Record<string, string[]>,
|
||||
override?: Record<string, string[]>,
|
||||
): Record<string, string[]> | undefined {
|
||||
if (!base && !override) return undefined;
|
||||
const merged: Record<string, string[]> = {};
|
||||
for (const [provider, models] of Object.entries(base ?? {})) {
|
||||
merged[provider] = [...models];
|
||||
}
|
||||
for (const [provider, models] of Object.entries(override ?? {})) {
|
||||
// Per-provider replace: a project entry replaces the global array for
|
||||
// that provider instead of appending to it.
|
||||
merged[provider] = [...models];
|
||||
}
|
||||
return Object.keys(merged).length > 0 ? merged : undefined;
|
||||
}
|
||||
|
||||
function mergePostUnitHooks(
|
||||
base?: PostUnitHookConfig[],
|
||||
override?: PostUnitHookConfig[],
|
||||
|
|
|
|||
130
src/resources/extensions/sf/tests/provider-model-allow.test.ts
Normal file
130
src/resources/extensions/sf/tests/provider-model-allow.test.ts
Normal file
|
|
@ -0,0 +1,130 @@
|
|||
import test from "node:test";
|
||||
import assert from "node:assert/strict";
|
||||
import { mkdtempSync, mkdirSync, rmSync, writeFileSync } from "node:fs";
|
||||
import { tmpdir } from "node:os";
|
||||
import { join } from "node:path";
|
||||
|
||||
import { loadEffectiveSFPreferences, validatePreferences } from "../preferences.ts";
|
||||
import { filterModelsByProviderModelAllow } from "../preferences-models.ts";
|
||||
|
||||
test("provider_model_allow: provider in allow-list and model in list passes", () => {
|
||||
const models = [
|
||||
{ provider: "minimax", id: "MiniMax-M2.7" },
|
||||
{ provider: "minimax", id: "MiniMax-M2" },
|
||||
];
|
||||
|
||||
const filtered = filterModelsByProviderModelAllow(models, {
|
||||
minimax: ["MiniMax-M2.7"],
|
||||
});
|
||||
|
||||
assert.deepEqual(filtered.map((m) => `${m.provider}/${m.id}`), ["minimax/MiniMax-M2.7"]);
|
||||
});
|
||||
|
||||
test("provider_model_allow: provider in allow-list and model not in list is filtered", () => {
|
||||
const models = [
|
||||
{ provider: "minimax", id: "MiniMax-M2" },
|
||||
{ provider: "zai", id: "glm-5" },
|
||||
];
|
||||
|
||||
const filtered = filterModelsByProviderModelAllow(models, {
|
||||
minimax: ["MiniMax-M2.7"],
|
||||
});
|
||||
|
||||
assert.deepEqual(
|
||||
filtered.map((m) => `${m.provider}/${m.id}`),
|
||||
["zai/glm-5"],
|
||||
"minimax/MiniMax-M2 is removed so selection can fall through to the next provider",
|
||||
);
|
||||
});
|
||||
|
||||
test("provider_model_allow: provider absent from allow-list is unrestricted", () => {
|
||||
const models = [
|
||||
{ provider: "minimax", id: "MiniMax-M2" },
|
||||
{ provider: "zai", id: "glm-5" },
|
||||
];
|
||||
|
||||
const filtered = filterModelsByProviderModelAllow(models, {
|
||||
minimax: ["MiniMax-M2.7"],
|
||||
});
|
||||
|
||||
assert.ok(filtered.some((m) => m.provider === "zai" && m.id === "glm-5"));
|
||||
});
|
||||
|
||||
test("provider_model_allow: validates shape and normalizes provider IDs", () => {
|
||||
const result = validatePreferences({
|
||||
provider_model_allow: {
|
||||
MiniMax: [" MiniMax-M2.7 ", "MiniMax-M2.7-highspeed"],
|
||||
},
|
||||
});
|
||||
|
||||
assert.equal(result.errors.length, 0);
|
||||
assert.deepEqual(result.preferences.provider_model_allow, {
|
||||
minimax: ["MiniMax-M2.7", "MiniMax-M2.7-highspeed"],
|
||||
});
|
||||
});
|
||||
|
||||
test("provider_model_allow: rejects invalid shapes", () => {
|
||||
const result = validatePreferences({
|
||||
provider_model_allow: {
|
||||
minimax: "MiniMax-M2.7",
|
||||
zai: ["glm-5", 42],
|
||||
} as any,
|
||||
});
|
||||
|
||||
assert.ok(result.errors.some((error) => error.includes("provider_model_allow.minimax")));
|
||||
assert.ok(result.errors.some((error) => error.includes("provider_model_allow.zai")));
|
||||
assert.equal(result.preferences.provider_model_allow, undefined);
|
||||
});
|
||||
|
||||
test("provider_model_allow: project overrides global per provider", () => {
|
||||
const originalCwd = process.cwd();
|
||||
const originalSfHome = process.env.SF_HOME;
|
||||
const tempProject = mkdtempSync(join(tmpdir(), "sf-provider-model-allow-project-"));
|
||||
const tempHome = mkdtempSync(join(tmpdir(), "sf-provider-model-allow-home-"));
|
||||
|
||||
try {
|
||||
mkdirSync(join(tempProject, ".sf"), { recursive: true });
|
||||
|
||||
writeFileSync(
|
||||
join(tempHome, "preferences.md"),
|
||||
[
|
||||
"---",
|
||||
"provider_model_allow:",
|
||||
" minimax:",
|
||||
" - MiniMax-M2.7",
|
||||
" zai:",
|
||||
" - glm-5",
|
||||
"---",
|
||||
].join("\n"),
|
||||
"utf-8",
|
||||
);
|
||||
|
||||
writeFileSync(
|
||||
join(tempProject, ".sf", "PREFERENCES.md"),
|
||||
[
|
||||
"---",
|
||||
"provider_model_allow:",
|
||||
" minimax:",
|
||||
" - MiniMax-M2.7-highspeed",
|
||||
"---",
|
||||
].join("\n"),
|
||||
"utf-8",
|
||||
);
|
||||
|
||||
process.env.SF_HOME = tempHome;
|
||||
process.chdir(tempProject);
|
||||
|
||||
const loaded = loadEffectiveSFPreferences();
|
||||
assert.notEqual(loaded, null);
|
||||
assert.deepEqual(loaded!.preferences.provider_model_allow, {
|
||||
minimax: ["MiniMax-M2.7-highspeed"],
|
||||
zai: ["glm-5"],
|
||||
});
|
||||
} finally {
|
||||
process.chdir(originalCwd);
|
||||
if (originalSfHome === undefined) delete process.env.SF_HOME;
|
||||
else process.env.SF_HOME = originalSfHome;
|
||||
rmSync(tempProject, { recursive: true, force: true });
|
||||
rmSync(tempHome, { recursive: true, force: true });
|
||||
}
|
||||
});
|
||||
Loading…
Add table
Reference in a new issue