feat(gsd-uok): enforce model policy filtering before routing

This commit is contained in:
Jeremy McSpadden 2026-04-14 20:41:08 -05:00
parent 00521b1418
commit 414c2ee58c
3 changed files with 124 additions and 3 deletions

View file

@ -10,7 +10,7 @@ import type { ExtensionAPI, ExtensionContext } from "@gsd/pi-coding-agent";
import type { GSDPreferences } from "./preferences.js";
import { resolveModelWithFallbacksForUnit, resolveDynamicRoutingConfig } from "./preferences.js";
import type { ComplexityTier } from "./complexity-classifier.js";
import { classifyUnitComplexity, tierLabel } from "./complexity-classifier.js";
import { classifyUnitComplexity, extractTaskMetadata, tierLabel } from "./complexity-classifier.js";
import { resolveModelForComplexity, escalateTier, getEligibleModels, loadCapabilityOverrides, adjustToolSet, filterToolsForProvider } from "./model-router.js";
import { getLedger, getProjectTotals } from "./metrics.js";
import { unitPhaseLabel } from "./auto-dashboard.js";
@ -120,6 +120,10 @@ export async function selectAndApplyModel(
let routingTierLabel = "";
let routingEligibleModels = availableModels;
const taskMetadataForPolicy = unitType === "execute-task"
? extractTaskMetadata(unitId, basePath)
: undefined;
if (uokFlags.modelPolicy) {
const policy = applyModelPolicyFilter(
availableModels,
@ -128,6 +132,7 @@ export async function selectAndApplyModel(
traceId: modelPolicyTraceId,
turnId: modelPolicyTurnId,
unitType,
taskMetadata: taskMetadataForPolicy,
currentProvider: ctx.model?.provider,
allowCrossProvider: routingConfig.cross_provider !== false,
requiredTools: pi.getActiveTools(),
@ -182,7 +187,13 @@ export async function selectAndApplyModel(
const shouldClassify = !isHook || routingConfig.hooks !== false;
if (shouldClassify) {
let classification = classifyUnitComplexity(unitType, unitId, basePath, budgetPct);
let classification = classifyUnitComplexity(
unitType,
unitId,
basePath,
budgetPct,
taskMetadataForPolicy,
);
const availableModelIds = routingEligibleModels.map(m => m.id);
// Escalate tier on retry when escalate_on_failure is enabled (default: true)
@ -293,7 +304,8 @@ export async function selectAndApplyModel(
let attemptedPolicyEligible = false;
for (const modelId of modelsToTry) {
const model = resolveModelId(modelId, availableModels, ctx.model?.provider);
const resolutionPool = uokFlags.modelPolicy ? routingEligibleModels : availableModels;
const model = resolveModelId(modelId, resolutionPool, ctx.model?.provider);
if (!model) {
if (verbose) ctx.ui.notify(`Model ${modelId} not found, trying fallback.`, "info");

View file

@ -227,6 +227,26 @@ test("model change notify in selectAndApplyModel is gated behind verbose flag",
);
});
test("model policy resolves candidates from the policy-eligible pool", () => {
const src = readFileSync(join(__dirname, "..", "auto-model-selection.ts"), "utf-8");
assert.ok(
src.includes("const resolutionPool = uokFlags.modelPolicy ? routingEligibleModels : availableModels"),
"selectAndApplyModel should resolve model IDs against policy-eligible models when model policy is enabled",
);
});
test("model policy receives task metadata for requirement-vector decisions", () => {
const src = readFileSync(join(__dirname, "..", "auto-model-selection.ts"), "utf-8");
assert.ok(
src.includes("taskMetadata: taskMetadataForPolicy"),
"applyModelPolicyFilter should receive task metadata so requirement vectors are unit-aware",
);
assert.ok(
src.includes("extractTaskMetadata(unitId, basePath)"),
"execute-task dispatch should derive metadata before policy filtering",
);
});
test("resolveModelId: anthropic wins over claude-code when session provider is not claude-code", () => {
const availableModels = [
{ id: "claude-sonnet-4-6", provider: "claude-code" },

View file

@ -0,0 +1,89 @@
import test from "node:test";
import assert from "node:assert/strict";
import { mkdtempSync, mkdirSync, readFileSync, rmSync } from "node:fs";
import { join } from "node:path";
import { tmpdir } from "node:os";
import {
applyModelPolicyFilter,
buildRequirementVector,
} from "../uok/model-policy.ts";
import {
registerToolCompatibility,
resetToolCompatibilityRegistry,
} from "@gsd/pi-coding-agent";
test.afterEach(() => {
resetToolCompatibilityRegistry();
});
test("uok model policy builds requirement vectors from unit metadata", () => {
const requirements = buildRequirementVector("execute-task", {
tags: ["docs"],
fileCount: 8,
estimatedLines: 600,
});
assert.equal(requirements.instruction, 0.9);
assert.equal(requirements.coding, 0.3);
assert.equal(requirements.speed, 0.7);
});
test("uok model policy enforces provider/api/tool constraints and emits decision audit events", () => {
const basePath = mkdtempSync(join(tmpdir(), "gsd-uok-model-policy-"));
try {
mkdirSync(join(basePath, ".gsd"), { recursive: true });
registerToolCompatibility("screenshot", { producesImages: true });
const result = applyModelPolicyFilter(
[
{ id: "openai-image", provider: "openai", api: "openai-responses" },
{ id: "anthropic-ok", provider: "anthropic", api: "anthropic-messages" },
{ id: "gemini-api-deny", provider: "google", api: "google-generative-ai" },
{ id: "blocked-provider", provider: "blocked", api: "anthropic-messages" },
],
{
basePath,
traceId: "trace-model-policy-1",
turnId: "turn-model-policy-1",
unitType: "execute-task",
taskMetadata: { tags: ["docs"] },
allowCrossProvider: true,
requiredTools: ["screenshot"],
allowedApis: ["anthropic-messages", "openai-responses"],
deniedProviders: ["blocked"],
},
);
assert.deepEqual(
result.eligible.map((m) => m.id),
["anthropic-ok"],
"only the policy-compliant anthropic model should remain eligible",
);
assert.equal(result.decisions.length, 4);
assert.equal(result.decisions[0]?.allowed, false);
assert.match(result.decisions[0]?.reason ?? "", /tool policy denied/);
assert.equal(result.decisions[1]?.allowed, true);
assert.equal(result.decisions[2]?.allowed, false);
assert.match(result.decisions[2]?.reason ?? "", /transport\/api denied by policy/);
assert.equal(result.decisions[3]?.allowed, false);
assert.match(result.decisions[3]?.reason ?? "", /provider denied by policy/);
const auditLogPath = join(basePath, ".gsd", "audit", "events.jsonl");
const auditLines = readFileSync(auditLogPath, "utf-8")
.trim()
.split("\n")
.map((line) => JSON.parse(line) as { type: string; payload?: { reason?: string } });
const decisionTypes = auditLines.map((event) => event.type);
assert.equal(auditLines.length, 4);
assert.ok(decisionTypes.includes("model-policy-allow"));
assert.ok(decisionTypes.includes("model-policy-deny"));
assert.ok(
auditLines.some((event) => (event.payload?.reason ?? "").includes("tool policy denied")),
"audit stream should include explicit deny reasons",
);
} finally {
rmSync(basePath, { recursive: true, force: true });
}
});