subagent: add per-call model override (Phase 1 of skill dispatch)
Adds an optional model param to SubagentParams, TaskItem, and ChainItem so callers can override the agent's default model at dispatch time. This is the primitive that ace-coder's Task() tool exposes via its `model` arg — SF's subagent tool previously ignored model at the tool level, picking it up only from the named agent's .md frontmatter. - SubagentParams.model applies to single mode, or as a batch-level default for tasks/chain steps that don't set their own. - TaskItem.model and ChainItem.model override per-task / per-step. - runSingleAgent and runSingleAgentInCmuxSplit gain a trailing modelOverride parameter that flows into buildSubagentProcessArgs. - buildSubagentProcessArgs uses modelOverride ?? agent.model when picking the --model arg for the child process. Side benefit: retroactively fixes the latent bug where reactive_execution.subagent_model was threaded into prompt instructions but ignored by the actual tool. 9 regression tests added in subagent/tests/model-override.test.ts. All 53 subagent-related tests pass. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
254fba36c0
commit
55ee2cb5c7
2 changed files with 138 additions and 8 deletions
|
|
@ -264,9 +264,11 @@ function buildSubagentProcessArgs(
|
|||
agent: AgentConfig,
|
||||
task: string,
|
||||
tmpPromptPath: string | null,
|
||||
modelOverride?: string,
|
||||
): string[] {
|
||||
const args: string[] = ["--mode", "json", "-p", "--no-session"];
|
||||
if (agent.model) args.push("--model", agent.model);
|
||||
const modelToUse = modelOverride ?? agent.model;
|
||||
if (modelToUse) args.push("--model", modelToUse);
|
||||
if (agent.tools && agent.tools.length > 0) args.push("--tools", agent.tools.join(","));
|
||||
if (tmpPromptPath) args.push("--append-system-prompt", tmpPromptPath);
|
||||
args.push(`Task: ${task}`);
|
||||
|
|
@ -336,6 +338,7 @@ async function runSingleAgent(
|
|||
signal: AbortSignal | undefined,
|
||||
onUpdate: OnUpdateCallback | undefined,
|
||||
makeDetails: (results: SingleResult[]) => SubagentDetails,
|
||||
modelOverride?: string,
|
||||
): Promise<SingleResult> {
|
||||
const agent = agents.find((a) => a.name === agentName);
|
||||
|
||||
|
|
@ -381,7 +384,7 @@ async function runSingleAgent(
|
|||
messages: [],
|
||||
stderr: "",
|
||||
usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, cost: 0, contextTokens: 0, turns: 0 },
|
||||
model: agent.model,
|
||||
model: modelOverride ?? agent.model,
|
||||
step,
|
||||
};
|
||||
|
||||
|
|
@ -400,7 +403,7 @@ async function runSingleAgent(
|
|||
tmpPromptDir = tmp.dir;
|
||||
tmpPromptPath = tmp.filePath;
|
||||
}
|
||||
const args = buildSubagentProcessArgs(agent, task, tmpPromptPath);
|
||||
const args = buildSubagentProcessArgs(agent, task, tmpPromptPath, modelOverride);
|
||||
let wasAborted = false;
|
||||
|
||||
const exitCode = await new Promise<number>((resolve) => {
|
||||
|
|
@ -480,10 +483,11 @@ async function runSingleAgentInCmuxSplit(
|
|||
signal: AbortSignal | undefined,
|
||||
onUpdate: OnUpdateCallback | undefined,
|
||||
makeDetails: (results: SingleResult[]) => SubagentDetails,
|
||||
modelOverride?: string,
|
||||
): Promise<SingleResult> {
|
||||
const agent = agents.find((a) => a.name === agentName);
|
||||
if (!agent) {
|
||||
return runSingleAgent(defaultCwd, agents, agentName, task, cwd, step, signal, onUpdate, makeDetails);
|
||||
return runSingleAgent(defaultCwd, agents, agentName, task, cwd, step, signal, onUpdate, makeDetails, modelOverride);
|
||||
}
|
||||
|
||||
let tmpPromptDir: string | null = null;
|
||||
|
|
@ -498,7 +502,7 @@ async function runSingleAgentInCmuxSplit(
|
|||
messages: [],
|
||||
stderr: "",
|
||||
usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, cost: 0, contextTokens: 0, turns: 0 },
|
||||
model: agent.model,
|
||||
model: modelOverride ?? agent.model,
|
||||
step,
|
||||
};
|
||||
|
||||
|
|
@ -528,12 +532,12 @@ async function runSingleAgentInCmuxSplit(
|
|||
? await cmuxClient.createSplit(directionOrSurfaceId as "right" | "down" | "left" | "up")
|
||||
: directionOrSurfaceId;
|
||||
if (!cmuxSurfaceId) {
|
||||
return runSingleAgent(defaultCwd, agents, agentName, task, cwd, step, signal, onUpdate, makeDetails);
|
||||
return runSingleAgent(defaultCwd, agents, agentName, task, cwd, step, signal, onUpdate, makeDetails, modelOverride);
|
||||
}
|
||||
|
||||
const bundledPaths = (process.env.SF_BUNDLED_EXTENSION_PATHS ?? "").split(path.delimiter).map((s) => s.trim()).filter(Boolean);
|
||||
const extensionArgs = bundledPaths.flatMap((p) => ["--extension", p]);
|
||||
const processArgs = [process.env.SF_BIN_PATH!, ...extensionArgs, ...buildSubagentProcessArgs(agent, task, tmpPromptPath)];
|
||||
const processArgs = [process.env.SF_BIN_PATH!, ...extensionArgs, ...buildSubagentProcessArgs(agent, task, tmpPromptPath, modelOverride)];
|
||||
// Normalize all paths to forward slashes before embedding in bash strings.
|
||||
// On Windows, backslashes are interpreted as escape characters by bash,
|
||||
// mangling paths like C:\Users\user into C:Useruser (#1436).
|
||||
|
|
@ -548,7 +552,7 @@ async function runSingleAgentInCmuxSplit(
|
|||
|
||||
const sent = await cmuxClient.sendSurface(cmuxSurfaceId, `bash -lc ${shellEscape(innerScript)}`);
|
||||
if (!sent) {
|
||||
return runSingleAgent(defaultCwd, agents, agentName, task, cwd, step, signal, onUpdate, makeDetails);
|
||||
return runSingleAgent(defaultCwd, agents, agentName, task, cwd, step, signal, onUpdate, makeDetails, modelOverride);
|
||||
}
|
||||
|
||||
const finished = await waitForFile(exitPath, signal);
|
||||
|
|
@ -595,12 +599,14 @@ const TaskItem = Type.Object({
|
|||
agent: Type.String({ description: "Name of the agent to invoke" }),
|
||||
task: Type.String({ description: "Task to delegate to the agent" }),
|
||||
cwd: Type.Optional(Type.String({ description: "Working directory for the agent process" })),
|
||||
model: Type.Optional(Type.String({ description: "Override the agent's default model for this task" })),
|
||||
});
|
||||
|
||||
const ChainItem = Type.Object({
|
||||
agent: Type.String({ description: "Name of the agent to invoke" }),
|
||||
task: Type.String({ description: "Task with optional {previous} placeholder for prior output" }),
|
||||
cwd: Type.Optional(Type.String({ description: "Working directory for the agent process" })),
|
||||
model: Type.Optional(Type.String({ description: "Override the agent's default model for this step" })),
|
||||
});
|
||||
|
||||
const AgentScopeSchema = StringEnum(["user", "project", "both"] as const, {
|
||||
|
|
@ -611,6 +617,9 @@ const AgentScopeSchema = StringEnum(["user", "project", "both"] as const, {
|
|||
const SubagentParams = Type.Object({
|
||||
agent: Type.Optional(Type.String({ description: "Name of the agent to invoke (for single mode)" })),
|
||||
task: Type.Optional(Type.String({ description: "Task to delegate (for single mode)" })),
|
||||
model: Type.Optional(Type.String({
|
||||
description: "Override the agent's default model. Applies to single mode, or as a default for all tasks/chain steps unless they set their own `model`.",
|
||||
})),
|
||||
tasks: Type.Optional(Type.Array(TaskItem, { description: "Array of {agent, task} for parallel execution" })),
|
||||
chain: Type.Optional(Type.Array(ChainItem, { description: "Array of {agent, task} for sequential execution" })),
|
||||
agentScope: Type.Optional(AgentScopeSchema),
|
||||
|
|
@ -767,6 +776,7 @@ export default function (pi: ExtensionAPI) {
|
|||
signal,
|
||||
chainUpdate,
|
||||
makeDetails("chain"),
|
||||
step.model ?? params.model,
|
||||
);
|
||||
results.push(result);
|
||||
|
||||
|
|
@ -839,6 +849,7 @@ export default function (pi: ExtensionAPI) {
|
|||
: [];
|
||||
const results = await mapWithConcurrencyLimit(params.tasks, MAX_CONCURRENCY, async (t, index) => {
|
||||
const workerId = registerWorker(t.agent, t.task, index, batchSize, batchId);
|
||||
const taskModelOverride = t.model ?? params.model;
|
||||
const runTask = () => cmuxSplitsEnabled
|
||||
? runSingleAgentInCmuxSplit(
|
||||
cmuxClient,
|
||||
|
|
@ -857,6 +868,7 @@ export default function (pi: ExtensionAPI) {
|
|||
}
|
||||
},
|
||||
makeDetails("parallel"),
|
||||
taskModelOverride,
|
||||
)
|
||||
: runSingleAgent(
|
||||
ctx.cwd,
|
||||
|
|
@ -873,6 +885,7 @@ export default function (pi: ExtensionAPI) {
|
|||
}
|
||||
},
|
||||
makeDetails("parallel"),
|
||||
taskModelOverride,
|
||||
);
|
||||
let result = await runTask();
|
||||
|
||||
|
|
@ -931,6 +944,7 @@ export default function (pi: ExtensionAPI) {
|
|||
signal,
|
||||
onUpdate,
|
||||
makeDetails("single"),
|
||||
params.model,
|
||||
)
|
||||
: await runSingleAgent(
|
||||
ctx.cwd,
|
||||
|
|
@ -942,6 +956,7 @@ export default function (pi: ExtensionAPI) {
|
|||
signal,
|
||||
onUpdate,
|
||||
makeDetails("single"),
|
||||
params.model,
|
||||
);
|
||||
|
||||
// Capture and merge delta if isolated
|
||||
|
|
|
|||
115
src/resources/extensions/subagent/tests/model-override.test.ts
Normal file
115
src/resources/extensions/subagent/tests/model-override.test.ts
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
/**
|
||||
* Regression tests for the per-call `model` override on the subagent tool.
|
||||
*
|
||||
* Validates:
|
||||
* - `model` is declared on SubagentParams, TaskItem, and ChainItem schemas.
|
||||
* - `buildSubagentProcessArgs` uses the override when provided, and falls
|
||||
* back to `agent.model` otherwise.
|
||||
* - `runSingleAgent` and `runSingleAgentInCmuxSplit` accept a trailing
|
||||
* `modelOverride` parameter.
|
||||
* - The three dispatch modes (single, parallel, chain) pass the override
|
||||
* through to the runner functions.
|
||||
*
|
||||
* These tests are structural (source-grep) because the runner spawns a
|
||||
* child process that is hard to fake in a unit test; the existing
|
||||
* subagent-model-dispatch.test.ts uses the same pattern.
|
||||
*/
|
||||
|
||||
import test from "node:test";
|
||||
import assert from "node:assert/strict";
|
||||
import { readFileSync } from "node:fs";
|
||||
import { join, dirname } from "node:path";
|
||||
import { fileURLToPath } from "node:url";
|
||||
|
||||
const __dirname = dirname(fileURLToPath(import.meta.url));
|
||||
const subagentSrc = readFileSync(join(__dirname, "..", "index.ts"), "utf-8");
|
||||
|
||||
test("SubagentParams declares optional model override field", () => {
|
||||
const paramsStart = subagentSrc.indexOf("const SubagentParams = Type.Object({");
|
||||
const paramsEnd = subagentSrc.indexOf("});", paramsStart);
|
||||
const paramsBlock = subagentSrc.slice(paramsStart, paramsEnd);
|
||||
assert.match(paramsBlock, /model:\s*Type\.Optional\(Type\.String/);
|
||||
});
|
||||
|
||||
test("TaskItem declares optional model override field", () => {
|
||||
const itemStart = subagentSrc.indexOf("const TaskItem = Type.Object({");
|
||||
const itemEnd = subagentSrc.indexOf("});", itemStart);
|
||||
const itemBlock = subagentSrc.slice(itemStart, itemEnd);
|
||||
assert.match(itemBlock, /model:\s*Type\.Optional\(Type\.String/);
|
||||
});
|
||||
|
||||
test("ChainItem declares optional model override field", () => {
|
||||
const itemStart = subagentSrc.indexOf("const ChainItem = Type.Object({");
|
||||
const itemEnd = subagentSrc.indexOf("});", itemStart);
|
||||
const itemBlock = subagentSrc.slice(itemStart, itemEnd);
|
||||
assert.match(itemBlock, /model:\s*Type\.Optional\(Type\.String/);
|
||||
});
|
||||
|
||||
test("buildSubagentProcessArgs prefers modelOverride over agent.model", () => {
|
||||
const fnStart = subagentSrc.indexOf("function buildSubagentProcessArgs(");
|
||||
const fnEnd = subagentSrc.indexOf("}", subagentSrc.indexOf("return args;", fnStart));
|
||||
const fn = subagentSrc.slice(fnStart, fnEnd);
|
||||
|
||||
assert.match(fn, /modelOverride\?\s*:\s*string/, "signature should accept modelOverride");
|
||||
assert.match(
|
||||
fn,
|
||||
/const\s+modelToUse\s*=\s*modelOverride\s*\?\?\s*agent\.model/,
|
||||
"should coalesce override first, then agent.model",
|
||||
);
|
||||
});
|
||||
|
||||
test("runSingleAgent accepts trailing modelOverride parameter", () => {
|
||||
const fnStart = subagentSrc.indexOf("async function runSingleAgent(");
|
||||
const fnEnd = subagentSrc.indexOf("): Promise<SingleResult>", fnStart);
|
||||
const signature = subagentSrc.slice(fnStart, fnEnd);
|
||||
assert.match(signature, /modelOverride\?\s*:\s*string/);
|
||||
});
|
||||
|
||||
test("runSingleAgentInCmuxSplit accepts trailing modelOverride parameter", () => {
|
||||
const fnStart = subagentSrc.indexOf("async function runSingleAgentInCmuxSplit(");
|
||||
const fnEnd = subagentSrc.indexOf("): Promise<SingleResult>", fnStart);
|
||||
const signature = subagentSrc.slice(fnStart, fnEnd);
|
||||
assert.match(signature, /modelOverride\?\s*:\s*string/);
|
||||
});
|
||||
|
||||
test("single-mode dispatch forwards params.model to the runner", () => {
|
||||
// The single-mode runSingleAgent call lives after `if (params.agent && params.task)`.
|
||||
const singleStart = subagentSrc.indexOf("if (params.agent && params.task) {");
|
||||
const singleEnd = subagentSrc.indexOf("let outputText", singleStart);
|
||||
const block = subagentSrc.slice(singleStart, singleEnd);
|
||||
// Both the cmux and non-cmux branches should pass params.model as the trailing arg.
|
||||
const occurrences = (block.match(/params\.model,/g) ?? []).length;
|
||||
assert.ok(
|
||||
occurrences >= 2,
|
||||
`expected params.model to be passed in both cmux+non-cmux single paths; found ${occurrences}`,
|
||||
);
|
||||
});
|
||||
|
||||
test("parallel-mode dispatch coalesces per-task model with params.model fallback", () => {
|
||||
const parallelStart = subagentSrc.indexOf("if (params.tasks && params.tasks.length > 0) {");
|
||||
// End at the "if (params.agent && params.task)" block — everything before that
|
||||
// is the parallel dispatch path.
|
||||
const parallelEnd = subagentSrc.indexOf("if (params.agent && params.task) {", parallelStart);
|
||||
const block = subagentSrc.slice(parallelStart, parallelEnd);
|
||||
assert.match(
|
||||
block,
|
||||
/const\s+taskModelOverride\s*=\s*t\.model\s*\?\?\s*params\.model/,
|
||||
"per-task override must coalesce with batch-level params.model",
|
||||
);
|
||||
const occurrences = (block.match(/taskModelOverride/g) ?? []).length;
|
||||
assert.ok(
|
||||
occurrences >= 3,
|
||||
`expected taskModelOverride to be declared once and passed into both cmux+non-cmux branches; found ${occurrences}`,
|
||||
);
|
||||
});
|
||||
|
||||
test("chain-mode dispatch coalesces per-step model with params.model fallback", () => {
|
||||
const chainStart = subagentSrc.indexOf("if (params.chain && params.chain.length > 0) {");
|
||||
const chainEnd = subagentSrc.indexOf("return {", chainStart);
|
||||
const block = subagentSrc.slice(chainStart, chainEnd);
|
||||
assert.match(
|
||||
block,
|
||||
/step\.model\s*\?\?\s*params\.model/,
|
||||
"chain step override must coalesce with batch-level params.model",
|
||||
);
|
||||
});
|
||||
Loading…
Add table
Reference in a new issue