diff --git a/src/resources/extensions/gsd/auto-model-selection.ts b/src/resources/extensions/gsd/auto-model-selection.ts index 75f2b434f..57a2b8904 100644 --- a/src/resources/extensions/gsd/auto-model-selection.ts +++ b/src/resources/extensions/gsd/auto-model-selection.ts @@ -10,7 +10,7 @@ import type { ExtensionAPI, ExtensionContext } from "@gsd/pi-coding-agent"; import type { GSDPreferences } from "./preferences.js"; import { resolveModelWithFallbacksForUnit, resolveDynamicRoutingConfig } from "./preferences.js"; import type { ComplexityTier } from "./complexity-classifier.js"; -import { classifyUnitComplexity, tierLabel } from "./complexity-classifier.js"; +import { classifyUnitComplexity, extractTaskMetadata, tierLabel } from "./complexity-classifier.js"; import { resolveModelForComplexity, escalateTier, getEligibleModels, loadCapabilityOverrides, adjustToolSet, filterToolsForProvider } from "./model-router.js"; import { getLedger, getProjectTotals } from "./metrics.js"; import { unitPhaseLabel } from "./auto-dashboard.js"; @@ -120,6 +120,10 @@ export async function selectAndApplyModel( let routingTierLabel = ""; let routingEligibleModels = availableModels; + const taskMetadataForPolicy = unitType === "execute-task" + ? extractTaskMetadata(unitId, basePath) + : undefined; + if (uokFlags.modelPolicy) { const policy = applyModelPolicyFilter( availableModels, @@ -128,6 +132,7 @@ export async function selectAndApplyModel( traceId: modelPolicyTraceId, turnId: modelPolicyTurnId, unitType, + taskMetadata: taskMetadataForPolicy, currentProvider: ctx.model?.provider, allowCrossProvider: routingConfig.cross_provider !== false, requiredTools: pi.getActiveTools(), @@ -182,7 +187,13 @@ export async function selectAndApplyModel( const shouldClassify = !isHook || routingConfig.hooks !== false; if (shouldClassify) { - let classification = classifyUnitComplexity(unitType, unitId, basePath, budgetPct); + let classification = classifyUnitComplexity( + unitType, + unitId, + basePath, + budgetPct, + taskMetadataForPolicy, + ); const availableModelIds = routingEligibleModels.map(m => m.id); // Escalate tier on retry when escalate_on_failure is enabled (default: true) @@ -293,7 +304,8 @@ export async function selectAndApplyModel( let attemptedPolicyEligible = false; for (const modelId of modelsToTry) { - const model = resolveModelId(modelId, availableModels, ctx.model?.provider); + const resolutionPool = uokFlags.modelPolicy ? routingEligibleModels : availableModels; + const model = resolveModelId(modelId, resolutionPool, ctx.model?.provider); if (!model) { if (verbose) ctx.ui.notify(`Model ${modelId} not found, trying fallback.`, "info"); diff --git a/src/resources/extensions/gsd/tests/auto-model-selection.test.ts b/src/resources/extensions/gsd/tests/auto-model-selection.test.ts index 1551888d4..7bb1cb7ba 100644 --- a/src/resources/extensions/gsd/tests/auto-model-selection.test.ts +++ b/src/resources/extensions/gsd/tests/auto-model-selection.test.ts @@ -227,6 +227,26 @@ test("model change notify in selectAndApplyModel is gated behind verbose flag", ); }); +test("model policy resolves candidates from the policy-eligible pool", () => { + const src = readFileSync(join(__dirname, "..", "auto-model-selection.ts"), "utf-8"); + assert.ok( + src.includes("const resolutionPool = uokFlags.modelPolicy ? routingEligibleModels : availableModels"), + "selectAndApplyModel should resolve model IDs against policy-eligible models when model policy is enabled", + ); +}); + +test("model policy receives task metadata for requirement-vector decisions", () => { + const src = readFileSync(join(__dirname, "..", "auto-model-selection.ts"), "utf-8"); + assert.ok( + src.includes("taskMetadata: taskMetadataForPolicy"), + "applyModelPolicyFilter should receive task metadata so requirement vectors are unit-aware", + ); + assert.ok( + src.includes("extractTaskMetadata(unitId, basePath)"), + "execute-task dispatch should derive metadata before policy filtering", + ); +}); + test("resolveModelId: anthropic wins over claude-code when session provider is not claude-code", () => { const availableModels = [ { id: "claude-sonnet-4-6", provider: "claude-code" }, diff --git a/src/resources/extensions/gsd/tests/uok-model-policy.test.ts b/src/resources/extensions/gsd/tests/uok-model-policy.test.ts new file mode 100644 index 000000000..dd2b2b93a --- /dev/null +++ b/src/resources/extensions/gsd/tests/uok-model-policy.test.ts @@ -0,0 +1,89 @@ +import test from "node:test"; +import assert from "node:assert/strict"; +import { mkdtempSync, mkdirSync, readFileSync, rmSync } from "node:fs"; +import { join } from "node:path"; +import { tmpdir } from "node:os"; + +import { + applyModelPolicyFilter, + buildRequirementVector, +} from "../uok/model-policy.ts"; +import { + registerToolCompatibility, + resetToolCompatibilityRegistry, +} from "@gsd/pi-coding-agent"; + +test.afterEach(() => { + resetToolCompatibilityRegistry(); +}); + +test("uok model policy builds requirement vectors from unit metadata", () => { + const requirements = buildRequirementVector("execute-task", { + tags: ["docs"], + fileCount: 8, + estimatedLines: 600, + }); + + assert.equal(requirements.instruction, 0.9); + assert.equal(requirements.coding, 0.3); + assert.equal(requirements.speed, 0.7); +}); + +test("uok model policy enforces provider/api/tool constraints and emits decision audit events", () => { + const basePath = mkdtempSync(join(tmpdir(), "gsd-uok-model-policy-")); + try { + mkdirSync(join(basePath, ".gsd"), { recursive: true }); + registerToolCompatibility("screenshot", { producesImages: true }); + + const result = applyModelPolicyFilter( + [ + { id: "openai-image", provider: "openai", api: "openai-responses" }, + { id: "anthropic-ok", provider: "anthropic", api: "anthropic-messages" }, + { id: "gemini-api-deny", provider: "google", api: "google-generative-ai" }, + { id: "blocked-provider", provider: "blocked", api: "anthropic-messages" }, + ], + { + basePath, + traceId: "trace-model-policy-1", + turnId: "turn-model-policy-1", + unitType: "execute-task", + taskMetadata: { tags: ["docs"] }, + allowCrossProvider: true, + requiredTools: ["screenshot"], + allowedApis: ["anthropic-messages", "openai-responses"], + deniedProviders: ["blocked"], + }, + ); + + assert.deepEqual( + result.eligible.map((m) => m.id), + ["anthropic-ok"], + "only the policy-compliant anthropic model should remain eligible", + ); + assert.equal(result.decisions.length, 4); + assert.equal(result.decisions[0]?.allowed, false); + assert.match(result.decisions[0]?.reason ?? "", /tool policy denied/); + assert.equal(result.decisions[1]?.allowed, true); + assert.equal(result.decisions[2]?.allowed, false); + assert.match(result.decisions[2]?.reason ?? "", /transport\/api denied by policy/); + assert.equal(result.decisions[3]?.allowed, false); + assert.match(result.decisions[3]?.reason ?? "", /provider denied by policy/); + + const auditLogPath = join(basePath, ".gsd", "audit", "events.jsonl"); + const auditLines = readFileSync(auditLogPath, "utf-8") + .trim() + .split("\n") + .map((line) => JSON.parse(line) as { type: string; payload?: { reason?: string } }); + const decisionTypes = auditLines.map((event) => event.type); + + assert.equal(auditLines.length, 4); + assert.ok(decisionTypes.includes("model-policy-allow")); + assert.ok(decisionTypes.includes("model-policy-deny")); + assert.ok( + auditLines.some((event) => (event.payload?.reason ?? "").includes("tool policy denied")), + "audit stream should include explicit deny reasons", + ); + } finally { + rmSync(basePath, { recursive: true, force: true }); + } +});