feat(gsd-uok): enforce model policy filtering before routing
This commit is contained in:
parent
00521b1418
commit
414c2ee58c
3 changed files with 124 additions and 3 deletions
|
|
@ -10,7 +10,7 @@ import type { ExtensionAPI, ExtensionContext } from "@gsd/pi-coding-agent";
|
|||
import type { GSDPreferences } from "./preferences.js";
|
||||
import { resolveModelWithFallbacksForUnit, resolveDynamicRoutingConfig } from "./preferences.js";
|
||||
import type { ComplexityTier } from "./complexity-classifier.js";
|
||||
import { classifyUnitComplexity, tierLabel } from "./complexity-classifier.js";
|
||||
import { classifyUnitComplexity, extractTaskMetadata, tierLabel } from "./complexity-classifier.js";
|
||||
import { resolveModelForComplexity, escalateTier, getEligibleModels, loadCapabilityOverrides, adjustToolSet, filterToolsForProvider } from "./model-router.js";
|
||||
import { getLedger, getProjectTotals } from "./metrics.js";
|
||||
import { unitPhaseLabel } from "./auto-dashboard.js";
|
||||
|
|
@ -120,6 +120,10 @@ export async function selectAndApplyModel(
|
|||
let routingTierLabel = "";
|
||||
let routingEligibleModels = availableModels;
|
||||
|
||||
const taskMetadataForPolicy = unitType === "execute-task"
|
||||
? extractTaskMetadata(unitId, basePath)
|
||||
: undefined;
|
||||
|
||||
if (uokFlags.modelPolicy) {
|
||||
const policy = applyModelPolicyFilter(
|
||||
availableModels,
|
||||
|
|
@ -128,6 +132,7 @@ export async function selectAndApplyModel(
|
|||
traceId: modelPolicyTraceId,
|
||||
turnId: modelPolicyTurnId,
|
||||
unitType,
|
||||
taskMetadata: taskMetadataForPolicy,
|
||||
currentProvider: ctx.model?.provider,
|
||||
allowCrossProvider: routingConfig.cross_provider !== false,
|
||||
requiredTools: pi.getActiveTools(),
|
||||
|
|
@ -182,7 +187,13 @@ export async function selectAndApplyModel(
|
|||
const shouldClassify = !isHook || routingConfig.hooks !== false;
|
||||
|
||||
if (shouldClassify) {
|
||||
let classification = classifyUnitComplexity(unitType, unitId, basePath, budgetPct);
|
||||
let classification = classifyUnitComplexity(
|
||||
unitType,
|
||||
unitId,
|
||||
basePath,
|
||||
budgetPct,
|
||||
taskMetadataForPolicy,
|
||||
);
|
||||
const availableModelIds = routingEligibleModels.map(m => m.id);
|
||||
|
||||
// Escalate tier on retry when escalate_on_failure is enabled (default: true)
|
||||
|
|
@ -293,7 +304,8 @@ export async function selectAndApplyModel(
|
|||
let attemptedPolicyEligible = false;
|
||||
|
||||
for (const modelId of modelsToTry) {
|
||||
const model = resolveModelId(modelId, availableModels, ctx.model?.provider);
|
||||
const resolutionPool = uokFlags.modelPolicy ? routingEligibleModels : availableModels;
|
||||
const model = resolveModelId(modelId, resolutionPool, ctx.model?.provider);
|
||||
|
||||
if (!model) {
|
||||
if (verbose) ctx.ui.notify(`Model ${modelId} not found, trying fallback.`, "info");
|
||||
|
|
|
|||
|
|
@ -227,6 +227,26 @@ test("model change notify in selectAndApplyModel is gated behind verbose flag",
|
|||
);
|
||||
});
|
||||
|
||||
// Source-level guard: with the model policy flag enabled, candidate model IDs
// must be resolved against the policy-eligible pool rather than the full list.
test("model policy resolves candidates from the policy-eligible pool", () => {
  const selectionSourcePath = join(__dirname, "..", "auto-model-selection.ts");
  const selectionSource = readFileSync(selectionSourcePath, "utf-8");
  const expectedSnippet =
    "const resolutionPool = uokFlags.modelPolicy ? routingEligibleModels : availableModels";

  assert.ok(
    selectionSource.includes(expectedSnippet),
    "selectAndApplyModel should resolve model IDs against policy-eligible models when model policy is enabled",
  );
});
|
||||
|
||||
// Source-level guard: the policy filter call site must pass task metadata, and
// that metadata must be derived via extractTaskMetadata(unitId, basePath).
test("model policy receives task metadata for requirement-vector decisions", () => {
  const selectionSource = readFileSync(
    join(__dirname, "..", "auto-model-selection.ts"),
    "utf-8",
  );

  // Checked in order; the first missing snippet fails the test with its message.
  const requiredSnippets = [
    {
      snippet: "taskMetadata: taskMetadataForPolicy",
      message: "applyModelPolicyFilter should receive task metadata so requirement vectors are unit-aware",
    },
    {
      snippet: "extractTaskMetadata(unitId, basePath)",
      message: "execute-task dispatch should derive metadata before policy filtering",
    },
  ];

  for (const { snippet, message } of requiredSnippets) {
    assert.ok(selectionSource.includes(snippet), message);
  }
});
|
||||
|
||||
test("resolveModelId: anthropic wins over claude-code when session provider is not claude-code", () => {
|
||||
const availableModels = [
|
||||
{ id: "claude-sonnet-4-6", provider: "claude-code" },
|
||||
|
|
|
|||
89
src/resources/extensions/gsd/tests/uok-model-policy.test.ts
Normal file
89
src/resources/extensions/gsd/tests/uok-model-policy.test.ts
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
import test from "node:test";
|
||||
import assert from "node:assert/strict";
|
||||
import { mkdtempSync, mkdirSync, readFileSync, rmSync } from "node:fs";
|
||||
import { join } from "node:path";
|
||||
import { tmpdir } from "node:os";
|
||||
|
||||
import {
|
||||
applyModelPolicyFilter,
|
||||
buildRequirementVector,
|
||||
} from "../uok/model-policy.ts";
|
||||
import {
|
||||
registerToolCompatibility,
|
||||
resetToolCompatibilityRegistry,
|
||||
} from "@gsd/pi-coding-agent";
|
||||
|
||||
// Reset the tool-compatibility registry after every test so registrations made
// by one test (e.g. "screenshot") do not carry over into the next.
test.afterEach(function resetToolCompatibility() {
  resetToolCompatibilityRegistry();
});
|
||||
|
||||
// Pins the requirement vector built for an execute-task unit tagged "docs"
// with 8 files and 600 estimated lines.
test("uok model policy builds requirement vectors from unit metadata", () => {
  const { instruction, coding, speed } = buildRequirementVector("execute-task", {
    tags: ["docs"],
    fileCount: 8,
    estimatedLines: 600,
  });

  assert.equal(instruction, 0.9);
  assert.equal(coding, 0.3);
  assert.equal(speed, 0.7);
});
|
||||
|
||||
// End-to-end check of applyModelPolicyFilter against a throwaway project dir:
// four candidate models go in, one stays eligible, each gets a decision with a
// reason, and the decisions are read back from .gsd/audit/events.jsonl.
test("uok model policy enforces provider/api/tool constraints and emits decision audit events", () => {
  const basePath = mkdtempSync(join(tmpdir(), "gsd-uok-model-policy-"));

  try {
    mkdirSync(join(basePath, ".gsd"), { recursive: true });
    registerToolCompatibility("screenshot", { producesImages: true });

    // Literals stay inline so the callee's parameter types drive inference.
    const result = applyModelPolicyFilter(
      [
        { id: "openai-image", provider: "openai", api: "openai-responses" },
        { id: "anthropic-ok", provider: "anthropic", api: "anthropic-messages" },
        { id: "gemini-api-deny", provider: "google", api: "google-generative-ai" },
        { id: "blocked-provider", provider: "blocked", api: "anthropic-messages" },
      ],
      {
        basePath,
        traceId: "trace-model-policy-1",
        turnId: "turn-model-policy-1",
        unitType: "execute-task",
        taskMetadata: { tags: ["docs"] },
        allowCrossProvider: true,
        requiredTools: ["screenshot"],
        allowedApis: ["anthropic-messages", "openai-responses"],
        deniedProviders: ["blocked"],
      },
    );

    const eligibleIds = result.eligible.map((candidate) => candidate.id);
    assert.deepEqual(
      eligibleIds,
      ["anthropic-ok"],
      "only the policy-compliant anthropic model should remain eligible",
    );

    // Decisions are expected in the same order as the candidate list above.
    assert.equal(result.decisions.length, 4);
    const [toolDenied, allowedDecision, apiDenied, providerDenied] = result.decisions;
    assert.equal(toolDenied?.allowed, false);
    assert.match(toolDenied?.reason ?? "", /tool policy denied/);
    assert.equal(allowedDecision?.allowed, true);
    assert.equal(apiDenied?.allowed, false);
    assert.match(apiDenied?.reason ?? "", /transport\/api denied by policy/);
    assert.equal(providerDenied?.allowed, false);
    assert.match(providerDenied?.reason ?? "", /provider denied by policy/);

    // The audit log is JSON Lines: one event object per line.
    const auditLogPath = join(basePath, ".gsd", "audit", "events.jsonl");
    const rawAuditLog = readFileSync(auditLogPath, "utf-8");
    const auditEvents = rawAuditLog
      .trim()
      .split("\n")
      .map((line) => JSON.parse(line) as { type: string; payload?: { reason?: string } });
    const seenEventTypes = new Set(auditEvents.map((event) => event.type));

    assert.equal(auditEvents.length, 4);
    assert.ok(seenEventTypes.has("model-policy-allow"));
    assert.ok(seenEventTypes.has("model-policy-deny"));

    const hasToolDenyReason = auditEvents.some((event) =>
      (event.payload?.reason ?? "").includes("tool policy denied"),
    );
    assert.ok(hasToolDenyReason, "audit stream should include explicit deny reasons");
  } finally {
    // Always remove the temp project directory, even when an assertion fails.
    rmSync(basePath, { recursive: true, force: true });
  }
});
|
||||
Loading…
Add table
Reference in a new issue