diff --git a/src/resources/extensions/subagent/index.ts b/src/resources/extensions/subagent/index.ts index ff7feb6a3..118e40ab6 100644 --- a/src/resources/extensions/subagent/index.ts +++ b/src/resources/extensions/subagent/index.ts @@ -264,9 +264,11 @@ function buildSubagentProcessArgs( agent: AgentConfig, task: string, tmpPromptPath: string | null, + modelOverride?: string, ): string[] { const args: string[] = ["--mode", "json", "-p", "--no-session"]; - if (agent.model) args.push("--model", agent.model); + const modelToUse = modelOverride ?? agent.model; + if (modelToUse) args.push("--model", modelToUse); if (agent.tools && agent.tools.length > 0) args.push("--tools", agent.tools.join(",")); if (tmpPromptPath) args.push("--append-system-prompt", tmpPromptPath); args.push(`Task: ${task}`); @@ -336,6 +338,7 @@ async function runSingleAgent( signal: AbortSignal | undefined, onUpdate: OnUpdateCallback | undefined, makeDetails: (results: SingleResult[]) => SubagentDetails, + modelOverride?: string, ): Promise { const agent = agents.find((a) => a.name === agentName); @@ -381,7 +384,7 @@ async function runSingleAgent( messages: [], stderr: "", usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, cost: 0, contextTokens: 0, turns: 0 }, - model: agent.model, + model: modelOverride ?? agent.model, step, }; @@ -400,7 +403,7 @@ async function runSingleAgent( tmpPromptDir = tmp.dir; tmpPromptPath = tmp.filePath; } - const args = buildSubagentProcessArgs(agent, task, tmpPromptPath); + const args = buildSubagentProcessArgs(agent, task, tmpPromptPath, modelOverride); let wasAborted = false; const exitCode = await new Promise((resolve) => { @@ -480,10 +483,11 @@ async function runSingleAgentInCmuxSplit( signal: AbortSignal | undefined, onUpdate: OnUpdateCallback | undefined, makeDetails: (results: SingleResult[]) => SubagentDetails, + modelOverride?: string, ): Promise { const agent = agents.find((a) => a.name === agentName); if (!agent) { - return runSingleAgent(defaultCwd, agents, agentName, task, cwd, step, signal, onUpdate, makeDetails); + return runSingleAgent(defaultCwd, agents, agentName, task, cwd, step, signal, onUpdate, makeDetails, modelOverride); } let tmpPromptDir: string | null = null; @@ -498,7 +502,7 @@ async function runSingleAgentInCmuxSplit( messages: [], stderr: "", usage: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, cost: 0, contextTokens: 0, turns: 0 }, - model: agent.model, + model: modelOverride ?? agent.model, step, }; @@ -528,12 +532,12 @@ async function runSingleAgentInCmuxSplit( ? await cmuxClient.createSplit(directionOrSurfaceId as "right" | "down" | "left" | "up") : directionOrSurfaceId; if (!cmuxSurfaceId) { - return runSingleAgent(defaultCwd, agents, agentName, task, cwd, step, signal, onUpdate, makeDetails); + return runSingleAgent(defaultCwd, agents, agentName, task, cwd, step, signal, onUpdate, makeDetails, modelOverride); } const bundledPaths = (process.env.SF_BUNDLED_EXTENSION_PATHS ?? "").split(path.delimiter).map((s) => s.trim()).filter(Boolean); const extensionArgs = bundledPaths.flatMap((p) => ["--extension", p]); - const processArgs = [process.env.SF_BIN_PATH!, ...extensionArgs, ...buildSubagentProcessArgs(agent, task, tmpPromptPath)]; + const processArgs = [process.env.SF_BIN_PATH!, ...extensionArgs, ...buildSubagentProcessArgs(agent, task, tmpPromptPath, modelOverride)]; // Normalize all paths to forward slashes before embedding in bash strings. // On Windows, backslashes are interpreted as escape characters by bash, // mangling paths like C:\Users\user into C:Useruser (#1436). @@ -548,7 +552,7 @@ async function runSingleAgentInCmuxSplit( const sent = await cmuxClient.sendSurface(cmuxSurfaceId, `bash -lc ${shellEscape(innerScript)}`); if (!sent) { - return runSingleAgent(defaultCwd, agents, agentName, task, cwd, step, signal, onUpdate, makeDetails); + return runSingleAgent(defaultCwd, agents, agentName, task, cwd, step, signal, onUpdate, makeDetails, modelOverride); } const finished = await waitForFile(exitPath, signal); @@ -595,12 +599,14 @@ const TaskItem = Type.Object({ agent: Type.String({ description: "Name of the agent to invoke" }), task: Type.String({ description: "Task to delegate to the agent" }), cwd: Type.Optional(Type.String({ description: "Working directory for the agent process" })), + model: Type.Optional(Type.String({ description: "Override the agent's default model for this task" })), }); const ChainItem = Type.Object({ agent: Type.String({ description: "Name of the agent to invoke" }), task: Type.String({ description: "Task with optional {previous} placeholder for prior output" }), cwd: Type.Optional(Type.String({ description: "Working directory for the agent process" })), + model: Type.Optional(Type.String({ description: "Override the agent's default model for this step" })), }); const AgentScopeSchema = StringEnum(["user", "project", "both"] as const, { @@ -611,6 +617,9 @@ const AgentScopeSchema = StringEnum(["user", "project", "both"] as const, { const SubagentParams = Type.Object({ agent: Type.Optional(Type.String({ description: "Name of the agent to invoke (for single mode)" })), task: Type.Optional(Type.String({ description: "Task to delegate (for single mode)" })), + model: Type.Optional(Type.String({ + description: "Override the agent's default model. Applies to single mode, or as a default for all tasks/chain steps unless they set their own `model`.", + })), tasks: Type.Optional(Type.Array(TaskItem, { description: "Array of {agent, task} for parallel execution" })), chain: Type.Optional(Type.Array(ChainItem, { description: "Array of {agent, task} for sequential execution" })), agentScope: Type.Optional(AgentScopeSchema), @@ -767,6 +776,7 @@ export default function (pi: ExtensionAPI) { signal, chainUpdate, makeDetails("chain"), + step.model ?? params.model, ); results.push(result); @@ -839,6 +849,7 @@ export default function (pi: ExtensionAPI) { : []; const results = await mapWithConcurrencyLimit(params.tasks, MAX_CONCURRENCY, async (t, index) => { const workerId = registerWorker(t.agent, t.task, index, batchSize, batchId); + const taskModelOverride = t.model ?? params.model; const runTask = () => cmuxSplitsEnabled ? runSingleAgentInCmuxSplit( cmuxClient, @@ -857,6 +868,7 @@ export default function (pi: ExtensionAPI) { } }, makeDetails("parallel"), + taskModelOverride, ) : runSingleAgent( ctx.cwd, @@ -873,6 +885,7 @@ export default function (pi: ExtensionAPI) { } }, makeDetails("parallel"), + taskModelOverride, ); let result = await runTask(); @@ -931,6 +944,7 @@ export default function (pi: ExtensionAPI) { signal, onUpdate, makeDetails("single"), + params.model, ) : await runSingleAgent( ctx.cwd, @@ -942,6 +956,7 @@ export default function (pi: ExtensionAPI) { signal, onUpdate, makeDetails("single"), + params.model, ); // Capture and merge delta if isolated diff --git a/src/resources/extensions/subagent/tests/model-override.test.ts b/src/resources/extensions/subagent/tests/model-override.test.ts new file mode 100644 index 000000000..6d3af97a7 --- /dev/null +++ b/src/resources/extensions/subagent/tests/model-override.test.ts @@ -0,0 +1,115 @@ +/** + * Regression tests for the per-call `model` override on the subagent tool. + * + * Validates: + * - `model` is declared on SubagentParams, TaskItem, and ChainItem schemas. + * - `buildSubagentProcessArgs` uses the override when provided, and falls + * back to `agent.model` otherwise. + * - `runSingleAgent` and `runSingleAgentInCmuxSplit` accept a trailing + * `modelOverride` parameter. + * - The three dispatch modes (single, parallel, chain) pass the override + * through to the runner functions. + * + * These tests are structural (source-grep) because the runner spawns a + * child process that is hard to fake in a unit test; the existing + * subagent-model-dispatch.test.ts uses the same pattern. + */ + +import test from "node:test"; +import assert from "node:assert/strict"; +import { readFileSync } from "node:fs"; +import { join, dirname } from "node:path"; +import { fileURLToPath } from "node:url"; + +const __dirname = dirname(fileURLToPath(import.meta.url)); +const subagentSrc = readFileSync(join(__dirname, "..", "index.ts"), "utf-8"); + +test("SubagentParams declares optional model override field", () => { + const paramsStart = subagentSrc.indexOf("const SubagentParams = Type.Object({"); + const paramsEnd = subagentSrc.indexOf("});", paramsStart); + const paramsBlock = subagentSrc.slice(paramsStart, paramsEnd); + assert.match(paramsBlock, /model:\s*Type\.Optional\(Type\.String/); +}); + +test("TaskItem declares optional model override field", () => { + const itemStart = subagentSrc.indexOf("const TaskItem = Type.Object({"); + const itemEnd = subagentSrc.indexOf("});", itemStart); + const itemBlock = subagentSrc.slice(itemStart, itemEnd); + assert.match(itemBlock, /model:\s*Type\.Optional\(Type\.String/); +}); + +test("ChainItem declares optional model override field", () => { + const itemStart = subagentSrc.indexOf("const ChainItem = Type.Object({"); + const itemEnd = subagentSrc.indexOf("});", itemStart); + const itemBlock = subagentSrc.slice(itemStart, itemEnd); + assert.match(itemBlock, /model:\s*Type\.Optional\(Type\.String/); +}); + +test("buildSubagentProcessArgs prefers modelOverride over agent.model", () => { + const fnStart = subagentSrc.indexOf("function buildSubagentProcessArgs("); + const fnEnd = subagentSrc.indexOf("}", subagentSrc.indexOf("return args;", fnStart)); + const fn = subagentSrc.slice(fnStart, fnEnd); + + assert.match(fn, /modelOverride\?\s*:\s*string/, "signature should accept modelOverride"); + assert.match( + fn, + /const\s+modelToUse\s*=\s*modelOverride\s*\?\?\s*agent\.model/, + "should coalesce override first, then agent.model", + ); +}); + +test("runSingleAgent accepts trailing modelOverride parameter", () => { + const fnStart = subagentSrc.indexOf("async function runSingleAgent("); + const fnEnd = subagentSrc.indexOf("): Promise", fnStart); + const signature = subagentSrc.slice(fnStart, fnEnd); + assert.match(signature, /modelOverride\?\s*:\s*string/); +}); + +test("runSingleAgentInCmuxSplit accepts trailing modelOverride parameter", () => { + const fnStart = subagentSrc.indexOf("async function runSingleAgentInCmuxSplit("); + const fnEnd = subagentSrc.indexOf("): Promise", fnStart); + const signature = subagentSrc.slice(fnStart, fnEnd); + assert.match(signature, /modelOverride\?\s*:\s*string/); +}); + +test("single-mode dispatch forwards params.model to the runner", () => { + // The single-mode runSingleAgent call lives after `if (params.agent && params.task)`. + const singleStart = subagentSrc.indexOf("if (params.agent && params.task) {"); + const singleEnd = subagentSrc.indexOf("let outputText", singleStart); + const block = subagentSrc.slice(singleStart, singleEnd); + // Both the cmux and non-cmux branches should pass params.model as the trailing arg. + const occurrences = (block.match(/params\.model,/g) ?? []).length; + assert.ok( + occurrences >= 2, + `expected params.model to be passed in both cmux+non-cmux single paths; found ${occurrences}`, + ); +}); + +test("parallel-mode dispatch coalesces per-task model with params.model fallback", () => { + const parallelStart = subagentSrc.indexOf("if (params.tasks && params.tasks.length > 0) {"); + // End at the "if (params.agent && params.task)" block — everything before that + // is the parallel dispatch path. + const parallelEnd = subagentSrc.indexOf("if (params.agent && params.task) {", parallelStart); + const block = subagentSrc.slice(parallelStart, parallelEnd); + assert.match( + block, + /const\s+taskModelOverride\s*=\s*t\.model\s*\?\?\s*params\.model/, + "per-task override must coalesce with batch-level params.model", + ); + const occurrences = (block.match(/taskModelOverride/g) ?? []).length; + assert.ok( + occurrences >= 3, + `expected taskModelOverride to be declared once and passed into both cmux+non-cmux branches; found ${occurrences}`, + ); +}); + +test("chain-mode dispatch coalesces per-step model with params.model fallback", () => { + const chainStart = subagentSrc.indexOf("if (params.chain && params.chain.length > 0) {"); + const chainEnd = subagentSrc.indexOf("return {", chainStart); + const block = subagentSrc.slice(chainStart, chainEnd); + assert.match( + block, + /step\.model\s*\?\?\s*params\.model/, + "chain step override must coalesce with batch-level params.model", + ); +});