diff --git a/src/resources/extensions/gsd/auto-model-selection.ts b/src/resources/extensions/gsd/auto-model-selection.ts index 7929f94be..c79ab55b2 100644 --- a/src/resources/extensions/gsd/auto-model-selection.ts +++ b/src/resources/extensions/gsd/auto-model-selection.ts @@ -4,6 +4,7 @@ * and fallback chains. */ +import type { Api, Model } from "@gsd/pi-ai"; import type { ExtensionAPI, ExtensionContext } from "@gsd/pi-coding-agent"; import type { GSDPreferences } from "./preferences.js"; import { resolveModelWithFallbacksForUnit, resolveDynamicRoutingConfig } from "./preferences.js"; @@ -16,6 +17,8 @@ import { unitPhaseLabel } from "./auto-dashboard.js"; export interface ModelSelectionResult { /** Routing metadata for metrics recording */ routing: { tier: string; modelDowngraded: boolean } | null; + /** Concrete model applied before dispatch so it can be restored after a fresh session. */ + appliedModel: Model | null; } export function resolvePreferredModelConfig( @@ -58,6 +61,7 @@ export async function selectAndApplyModel( ): Promise { const modelConfig = resolvePreferredModelConfig(unitType, autoModeStartModel); let routing: { tier: string; modelDowngraded: boolean } | null = null; + let appliedModel: Model | null = null; if (modelConfig) { const availableModels = ctx.modelRegistry.getAvailable(); @@ -146,6 +150,7 @@ export async function selectAndApplyModel( const ok = await pi.setModel(model, { persist: false }); if (ok) { + appliedModel = model; const fallbackNote = modelId === effectiveModelConfig.primary ? "" : ` (fallback from ${effectiveModelConfig.primary})`; @@ -172,12 +177,17 @@ export async function selectAndApplyModel( const ok = await pi.setModel(startModel, { persist: false }); if (!ok) { const byId = availableModels.find(m => m.id === autoModeStartModel.id); - if (byId) await pi.setModel(byId, { persist: false }); + if (byId) { + const fallbackOk = await pi.setModel(byId, { persist: false }); + if (fallbackOk) appliedModel = byId; + } + } else { + appliedModel = startModel; } } } - return { routing }; + return { routing, appliedModel }; } /** diff --git a/src/resources/extensions/gsd/auto/loop-deps.ts b/src/resources/extensions/gsd/auto/loop-deps.ts index 6a9ae6eae..565dde5a3 100644 --- a/src/resources/extensions/gsd/auto/loop-deps.ts +++ b/src/resources/extensions/gsd/auto/loop-deps.ts @@ -209,7 +209,10 @@ export interface LoopDeps { verbose: boolean, startModel: { provider: string; id: string } | null, retryContext?: { isRetry: boolean; previousTier?: string }, - ) => Promise<{ routing: { tier: string; modelDowngraded: boolean } | null }>; + ) => Promise<{ + routing: { tier: string; modelDowngraded: boolean } | null; + appliedModel: { provider: string; id: string } | null; + }>; resolveModelId: ( modelId: string, availableModels: T[], diff --git a/src/resources/extensions/gsd/auto/phases.ts b/src/resources/extensions/gsd/auto/phases.ts index c8297ee3c..06778ff1b 100644 --- a/src/resources/extensions/gsd/auto/phases.ts +++ b/src/resources/extensions/gsd/auto/phases.ts @@ -1015,6 +1015,8 @@ export async function runUnitPhase( ); s.currentUnitRouting = modelResult.routing as AutoSession["currentUnitRouting"]; + s.currentUnitModel = + modelResult.appliedModel as AutoSession["currentUnitModel"]; // Apply sidecar/pre-dispatch hook model override (takes priority over standard model selection) const hookModelOverride = sidecarItem?.model ?? iterData.hookModelOverride; @@ -1024,6 +1026,7 @@ export async function runUnitPhase( if (match) { const ok = await pi.setModel(match, { persist: false }); if (ok) { + s.currentUnitModel = match as AutoSession["currentUnitModel"]; ctx.ui.notify(`Hook model override: ${match.provider}/${match.id}`, "info"); } else { ctx.ui.notify( diff --git a/src/resources/extensions/gsd/auto/run-unit.ts b/src/resources/extensions/gsd/auto/run-unit.ts index 47512d395..c9e740171 100644 --- a/src/resources/extensions/gsd/auto/run-unit.ts +++ b/src/resources/extensions/gsd/auto/run-unit.ts @@ -71,6 +71,16 @@ export async function runUnit( return { status: "cancelled" }; } + if (s.currentUnitModel && typeof pi.setModel === "function") { + const restored = await pi.setModel(s.currentUnitModel, { persist: false }); + if (!restored) { + ctx.ui.notify( + `Failed to restore ${s.currentUnitModel.provider}/${s.currentUnitModel.id} after session creation. Using session default.`, + "warning", + ); + } + } + // ── Create the agent_end promise (per-unit one-shot) ── // This happens after newSession completes so session-switch agent_end events // from the previous session cannot resolve the new unit. diff --git a/src/resources/extensions/gsd/auto/session.ts b/src/resources/extensions/gsd/auto/session.ts index 9ca66a963..9d11545e3 100644 --- a/src/resources/extensions/gsd/auto/session.ts +++ b/src/resources/extensions/gsd/auto/session.ts @@ -16,6 +16,7 @@ * `let` or `var` declarations. */ +import type { Api, Model } from "@gsd/pi-ai"; import type { ExtensionCommandContext } from "@gsd/pi-coding-agent"; import type { GitServiceImpl } from "../git-service.js"; import type { CaptureEntry } from "../captures.js"; @@ -103,6 +104,7 @@ export class AutoSession { // ── Model state ────────────────────────────────────────────────────────── autoModeStartModel: StartModel | null = null; + currentUnitModel: Model | null = null; originalModelId: string | null = null; originalModelProvider: string | null = null; lastBudgetAlertLevel: BudgetAlertLevel = 0; @@ -190,6 +192,7 @@ export class AutoSession { // Model this.autoModeStartModel = null; + this.currentUnitModel = null; this.originalModelId = null; this.originalModelProvider = null; this.lastBudgetAlertLevel = 0; diff --git a/src/resources/extensions/gsd/tests/auto-loop.test.ts b/src/resources/extensions/gsd/tests/auto-loop.test.ts index af3fda5ca..007e358c2 100644 --- a/src/resources/extensions/gsd/tests/auto-loop.test.ts +++ b/src/resources/extensions/gsd/tests/auto-loop.test.ts @@ -79,11 +79,17 @@ function makeMockCtx() { */ function makeMockPi() { const calls: unknown[] = []; + const setModelCalls: unknown[] = []; return { sendMessage: (...args: unknown[]) => { calls.push(args); }, + setModel: async (...args: unknown[]) => { + setModelCalls.push(args); + return true; + }, calls, + setModelCalls, } as any; } @@ -227,6 +233,38 @@ test("runUnit only arms resolve after newSession completes", async () => { assert.equal(pi.calls.length, 1); }); +test("runUnit re-applies the selected unit model after newSession before dispatch", async () => { + _resetPendingResolve(); + + const callOrder: string[] = []; + const ctx = makeMockCtx(); + const pi = makeMockPi(); + pi.setModel = async (...args: unknown[]) => { + callOrder.push("setModel"); + pi.setModelCalls.push(args); + return true; + }; + pi.sendMessage = (...args: unknown[]) => { + callOrder.push("sendMessage"); + pi.calls.push(args); + }; + + const s = makeMockSession(); + s.currentUnitModel = { provider: "anthropic", id: "claude-opus-4-6" }; + + const resultPromise = runUnit(ctx, pi, s, "task", "T01", "prompt"); + + await new Promise((r) => setTimeout(r, 10)); + resolveAgentEnd(makeEvent()); + + const result = await resultPromise; + assert.equal(result.status, "completed"); + assert.deepEqual(callOrder, ["setModel", "sendMessage"]); + assert.equal(pi.setModelCalls.length, 1); + assert.deepEqual(pi.setModelCalls[0][0], s.currentUnitModel); + assert.equal(pi.calls.length, 1); +}); + // ─── Structural assertions ─────────────────────────────────────────────────── test("auto-loop.ts exports autoLoop, runUnit, resolveAgentEnd", async () => { @@ -372,7 +410,7 @@ function makeMockDeps( captureAvailableSkills: () => {}, ensurePreconditions: () => {}, updateSliceProgressCache: () => {}, - selectAndApplyModel: async () => ({ routing: null }), + selectAndApplyModel: async () => ({ routing: null, appliedModel: null }), startUnitSupervision: () => {}, getDeepDiagnostic: () => null, isDbAvailable: () => false, diff --git a/src/resources/extensions/gsd/tests/custom-engine-loop-integration.test.ts b/src/resources/extensions/gsd/tests/custom-engine-loop-integration.test.ts index 29e82ac59..0bfa91ed2 100644 --- a/src/resources/extensions/gsd/tests/custom-engine-loop-integration.test.ts +++ b/src/resources/extensions/gsd/tests/custom-engine-loop-integration.test.ts @@ -200,7 +200,7 @@ function makeMockDeps(overrides?: Partial): LoopDeps & { callLog: stri captureAvailableSkills: () => {}, ensurePreconditions: () => {}, updateSliceProgressCache: () => {}, - selectAndApplyModel: async () => ({ routing: null }), + selectAndApplyModel: async () => ({ routing: null, appliedModel: null }), resolveModelId: () => undefined, startUnitSupervision: () => {}, getDeepDiagnostic: () => null, diff --git a/src/resources/extensions/gsd/tests/journal-integration.test.ts b/src/resources/extensions/gsd/tests/journal-integration.test.ts index 8447019ce..846982e26 100644 --- a/src/resources/extensions/gsd/tests/journal-integration.test.ts +++ b/src/resources/extensions/gsd/tests/journal-integration.test.ts @@ -97,7 +97,7 @@ function makeMockDeps( captureAvailableSkills: () => {}, ensurePreconditions: () => {}, updateSliceProgressCache: () => {}, - selectAndApplyModel: async () => ({ routing: null }), + selectAndApplyModel: async () => ({ routing: null, appliedModel: null }), startUnitSupervision: () => {}, getDeepDiagnostic: () => null, isDbAvailable: () => false,