Merge pull request #3852 from jeremymcs/fix/gsd-model-switching-prefs

fix(gsd): align model switching and prefs surfaces
This commit is contained in:
Jeremy McSpadden 2026-04-09 05:47:22 -05:00 committed by GitHub
commit 5c1ea9d99c
11 changed files with 253 additions and 16 deletions

View file

@ -165,10 +165,10 @@ export function buildCategorySummaries(prefs: Record<string, unknown>): Record<s
const modeSummary = mode ?? "(not set)";
// Models
const models = prefs.models as Record<string, string> | undefined;
const models = prefs.models as Record<string, unknown> | undefined;
let modelsSummary = "(not configured)";
if (models && Object.keys(models).length > 0) {
const parts = Object.entries(models).map(([phase, model]) => `${phase}: ${model}`);
const parts = Object.entries(models).map(([phase, model]) => `${phase}: ${formatConfiguredModel(model)}`);
modelsSummary = parts.join(", ");
}
@ -255,9 +255,38 @@ export function buildCategorySummaries(prefs: Record<string, unknown>): Record<s
// ─── Category configuration functions ────────────────────────────────────────
export function formatConfiguredModel(config: unknown): string {
if (typeof config === "string") return config;
if (!config || typeof config !== "object") return "(invalid)";
const maybeConfig = config as { model?: unknown; provider?: unknown };
if (typeof maybeConfig.model !== "string" || maybeConfig.model.trim() === "") return "(invalid)";
if (typeof maybeConfig.provider === "string" && maybeConfig.provider && !maybeConfig.model.includes("/")) {
return `${maybeConfig.provider}/${maybeConfig.model}`;
}
return maybeConfig.model;
}
export function toPersistedModelId(provider: string, modelId: string): string {
if (!provider.trim()) return modelId;
const normalizedProvider = provider.trim();
const normalizedModelId = modelId.trim();
return normalizedModelId.startsWith(`${normalizedProvider}/`)
? normalizedModelId
: `${normalizedProvider}/${normalizedModelId}`;
}
async function configureModels(ctx: ExtensionCommandContext, prefs: Record<string, unknown>): Promise<void> {
const modelPhases = ["research", "planning", "execution", "completion"] as const;
const models: Record<string, string> = (prefs.models as Record<string, string>) ?? {};
const modelPhases = [
"research",
"planning",
"discuss",
"execution",
"execution_simple",
"completion",
"validation",
"subagent",
] as const;
const models: Record<string, unknown> = (prefs.models as Record<string, unknown>) ?? {};
const availableModels = ctx.modelRegistry.getAvailable();
if (availableModels.length > 0) {
@ -292,7 +321,7 @@ async function configureModels(ctx: ExtensionCommandContext, prefs: Record<strin
providerOptions.push("(keep current)", "(clear)", "(type manually)");
for (const phase of modelPhases) {
const current = models[phase] ?? "";
const current = formatConfiguredModel(models[phase]);
const phaseLabel = `Model for ${phase} phase${current ? ` (current: ${current})` : ""}`;
// Step 1: pick provider
@ -329,13 +358,13 @@ async function configureModels(ctx: ExtensionCommandContext, prefs: Record<strin
if (modelChoice === "(clear)") {
delete models[phase];
} else {
models[phase] = modelChoice;
models[phase] = toPersistedModelId(providerName, modelChoice);
}
}
}
} else {
for (const phase of modelPhases) {
const current = models[phase] ?? "";
const current = formatConfiguredModel(models[phase]);
const input = await ctx.ui.input(
`Model for ${phase} phase${current ? ` (current: ${current})` : ""}:`,
current || "e.g. claude-sonnet-4-20250514",
@ -352,6 +381,8 @@ async function configureModels(ctx: ExtensionCommandContext, prefs: Record<strin
}
if (Object.keys(models).length > 0) {
prefs.models = models;
} else {
delete prefs.models;
}
}

View file

@ -15,7 +15,7 @@ export interface GsdCommandDefinition {
type CompletionMap = Record<string, readonly GsdCommandDefinition[]>;
export const GSD_COMMAND_DESCRIPTION =
"GSD — Get Shit Done: /gsd help|start|templates|next|auto|stop|pause|status|widget|visualize|queue|quick|discuss|capture|triage|dispatch|history|undo|undo-task|reset-slice|rate|skip|export|cleanup|mode|prefs|config|keys|hooks|run-hook|skill-health|doctor|logs|forensics|changelog|migrate|remote|steer|knowledge|new-milestone|parallel|cmux|park|unpark|init|setup|inspect|extensions|update|fast|mcp|rethink|codebase|notifications";
"GSD — Get Shit Done: /gsd help|start|templates|next|auto|stop|pause|status|widget|visualize|queue|quick|discuss|capture|triage|dispatch|history|undo|undo-task|reset-slice|rate|skip|export|cleanup|model|mode|prefs|config|keys|hooks|run-hook|skill-health|doctor|logs|forensics|changelog|migrate|remote|steer|knowledge|new-milestone|parallel|cmux|park|unpark|init|setup|inspect|extensions|update|fast|mcp|rethink|codebase|notifications";
export const TOP_LEVEL_SUBCOMMANDS: readonly GsdCommandDefinition[] = [
{ cmd: "help", desc: "Categorized command reference with descriptions" },
@ -41,6 +41,7 @@ export const TOP_LEVEL_SUBCOMMANDS: readonly GsdCommandDefinition[] = [
{ cmd: "skip", desc: "Prevent a unit from auto-mode dispatch" },
{ cmd: "export", desc: "Export milestone/slice results" },
{ cmd: "cleanup", desc: "Remove merged branches or snapshots" },
{ cmd: "model", desc: "Switch the active session model or open a picker" },
{ cmd: "mode", desc: "Switch workflow mode (solo/team)" },
{ cmd: "prefs", desc: "Manage preferences (model selection, timeouts, etc.)" },
{ cmd: "config", desc: "Set API keys for external tools" },

View file

@ -14,7 +14,7 @@ export async function handleGSDCommand(
const trimmed = (typeof args === "string" ? args : "").trim();
const handlers = [
() => handleCoreCommand(trimmed, ctx),
() => handleCoreCommand(trimmed, ctx, pi),
() => handleAutoCommand(trimmed, ctx, pi),
() => handleParallelCommand(trimmed, ctx, pi),
() => handleWorkflowCommand(trimmed, ctx, pi),
@ -29,4 +29,3 @@ export async function handleGSDCommand(
ctx.ui.notify(`Unknown: /gsd ${trimmed}. Run /gsd help for available commands.`, "warning");
}

View file

@ -1,4 +1,5 @@
import type { ExtensionCommandContext, ExtensionContext } from "@gsd/pi-coding-agent";
import type { ExtensionAPI, ExtensionCommandContext, ExtensionContext } from "@gsd/pi-coding-agent";
import type { Model } from "@gsd/pi-ai";
import type { GSDState } from "../../types.js";
import { computeProgressScore, formatProgressLine } from "../../progress-score.js";
@ -48,6 +49,7 @@ export function showHelp(ctx: ExtensionCommandContext): void {
"SETUP & CONFIGURATION",
" /gsd init Project init wizard — detect, configure, bootstrap .gsd/",
" /gsd setup Global setup status [llm|search|remote|keys|prefs]",
" /gsd model Switch active session model [provider/model|model-id]",
" /gsd mode Set workflow mode (solo/team) [global|project]",
" /gsd prefs Manage preferences [global|project|status|wizard|setup|import-claude]",
" /gsd cmux Manage cmux integration [status|on|off|notifications|sidebar|splits|browser]",
@ -179,7 +181,106 @@ export async function handleSetup(args: string, ctx: ExtensionCommandContext): P
);
}
export async function handleCoreCommand(trimmed: string, ctx: ExtensionCommandContext): Promise<boolean> {
function sortModelsForSelection(models: Model<any>[], currentModel: Model<any> | undefined): Model<any>[] {
return [...models].sort((a, b) => {
const aCurrent = currentModel && a.provider === currentModel.provider && a.id === currentModel.id;
const bCurrent = currentModel && b.provider === currentModel.provider && b.id === currentModel.id;
if (aCurrent && !bCurrent) return -1;
if (!aCurrent && bCurrent) return 1;
const providerCmp = a.provider.localeCompare(b.provider);
if (providerCmp !== 0) return providerCmp;
return a.id.localeCompare(b.id);
});
}
async function resolveRequestedModel(
query: string,
ctx: ExtensionCommandContext,
): Promise<Model<any> | undefined> {
const { resolveModelId } = await import("../../auto-model-selection.js");
const models = ctx.modelRegistry.getAvailable();
const exact = resolveModelId(query, models, ctx.model?.provider);
if (exact) return exact;
const lowerQuery = query.toLowerCase();
const partialMatches = models.filter((model) =>
model.id.toLowerCase().includes(lowerQuery)
|| `${model.provider}/${model.id}`.toLowerCase().includes(lowerQuery),
);
if (partialMatches.length === 1) return partialMatches[0];
if (partialMatches.length === 0 || !ctx.hasUI) return undefined;
const sorted = sortModelsForSelection(partialMatches, ctx.model);
const optionToModel = new Map<string, Model<any>>();
const options = sorted.map((model) => {
const label = `${model.provider}/${model.id}`;
optionToModel.set(label, model);
return label;
});
options.push("(cancel)");
const choice = await ctx.ui.select(`Multiple models match "${query}" — choose one:`, options);
if (!choice || typeof choice !== "string" || choice === "(cancel)") return undefined;
return optionToModel.get(choice);
}
async function handleModel(trimmedArgs: string, ctx: ExtensionCommandContext, pi: ExtensionAPI | undefined): Promise<void> {
const availableModels = ctx.modelRegistry.getAvailable();
if (availableModels.length === 0) {
ctx.ui.notify("No available models found. Check provider auth and model discovery.", "warning");
return;
}
if (!pi) {
ctx.ui.notify("Model switching is unavailable in this context.", "warning");
return;
}
const trimmed = trimmedArgs.trim();
let targetModel: Model<any> | undefined;
if (!trimmed) {
if (!ctx.hasUI) {
const current = ctx.model ? `${ctx.model.provider}/${ctx.model.id}` : "(none)";
ctx.ui.notify(`Current model: ${current}\nUsage: /gsd model <provider/model|model-id>`, "info");
return;
}
const optionToModel = new Map<string, Model<any>>();
const options = sortModelsForSelection(availableModels, ctx.model).map((model) => {
const isCurrent = ctx.model && model.provider === ctx.model.provider && model.id === ctx.model.id;
const label = `${isCurrent ? "* " : ""}${model.provider}/${model.id}`;
optionToModel.set(label, model);
return label;
});
options.push("(cancel)");
const choice = await ctx.ui.select("Select session model:", options);
if (!choice || typeof choice !== "string" || choice === "(cancel)") return;
targetModel = optionToModel.get(choice);
} else {
targetModel = await resolveRequestedModel(trimmed, ctx);
}
if (!targetModel) {
ctx.ui.notify(`Model "${trimmed}" not found. Use /gsd model with an exact provider/model or a unique model ID.`, "warning");
return;
}
const ok = await pi.setModel(targetModel);
if (!ok) {
ctx.ui.notify(`No API key for ${targetModel.provider}/${targetModel.id}`, "warning");
return;
}
ctx.ui.notify(`Model: ${targetModel.provider}/${targetModel.id}`, "info");
}
export async function handleCoreCommand(
trimmed: string,
ctx: ExtensionCommandContext,
pi?: ExtensionAPI,
): Promise<boolean> {
if (trimmed === "help" || trimmed === "h" || trimmed === "?") {
showHelp(ctx);
return true;
@ -203,6 +304,10 @@ export async function handleCoreCommand(trimmed: string, ctx: ExtensionCommandCo
ctx.ui.notify(`Widget: ${getWidgetMode()}`, "info");
return true;
}
if (trimmed === "model" || trimmed.startsWith("model ")) {
await handleModel(trimmed.replace(/^model\s*/, "").trim(), ctx, pi);
return true;
}
if (trimmed === "mode" || trimmed.startsWith("mode ")) {
const modeArgs = trimmed.replace(/^mode\s*/, "").trim();
const scope = modeArgs === "project" ? "project" : "global";

View file

@ -100,6 +100,18 @@ steps: []
// ─── Catalog Registration ────────────────────────────────────────────────
describe("workflow catalog registration", () => {
it("model appears in TOP_LEVEL_SUBCOMMANDS", () => {
const entry = TOP_LEVEL_SUBCOMMANDS.find((c) => c.cmd === "model");
assert.ok(entry, "model should be in TOP_LEVEL_SUBCOMMANDS");
assert.match(entry!.desc, /session model/i);
});
it("getGsdArgumentCompletions('m') includes model", () => {
const completions = getGsdArgumentCompletions("m");
const labels = completions.map((c: any) => c.label);
assert.ok(labels.includes("model"), "should include model completion");
});
it("workflow appears in TOP_LEVEL_SUBCOMMANDS", () => {
const entry = TOP_LEVEL_SUBCOMMANDS.find((c) => c.cmd === "workflow");
assert.ok(entry, "workflow should be in TOP_LEVEL_SUBCOMMANDS");

View file

@ -42,3 +42,35 @@ test("show-config only falls back when ctx.ui.custom() is unavailable", async ()
assert.equal(fallbackCtx.notices.length, 1, "unavailable overlay triggers text fallback");
assert.match(fallbackCtx.notices[0]!.message, /GSD Configuration/);
});
test("model command resolves and persists exact provider-qualified selection", async () => {
const selectedModel = { provider: "openai", id: "gpt-5.4" };
let applied: typeof selectedModel | null = null;
const ctx = {
hasUI: true,
model: { provider: "anthropic", id: "claude-sonnet-4-6" },
modelRegistry: {
getAvailable: () => [
{ provider: "anthropic", id: "claude-sonnet-4-6" },
selectedModel,
],
},
ui: {
notify: (message: string, type?: string) => {
notices.push({ message, type });
},
},
} as any;
const notices: Array<{ message: string; type?: string }> = [];
const pi = {
setModel: async (model: typeof selectedModel) => {
applied = model;
return true;
},
} as any;
const handled = await handleCoreCommand("model openai/gpt-5.4", ctx, pi);
assert.equal(handled, true);
assert.deepEqual(applied, selectedModel);
assert.match(notices[0]!.message, /openai\/gpt-5\.4/);
});

View file

@ -17,6 +17,7 @@ import {
parsePreferencesMarkdown,
_resetParseWarningFlag,
} from "../preferences.ts";
import { formatConfiguredModel, toPersistedModelId } from "../commands-prefs-wizard.ts";
import { _resetLogs, peekLogs } from "../workflow-logger.ts";
import type { GSDPreferences, GSDModelConfigV2, GSDPhaseModelConfig } from "../preferences.ts";
@ -347,6 +348,22 @@ test("handles model config with explicit provider field", () => {
assert.equal(execution.provider, "bedrock");
});
test("formatConfiguredModel renders provider-qualified object config", () => {
assert.equal(
formatConfiguredModel({ model: "claude-opus-4-6", provider: "bedrock" }),
"bedrock/claude-opus-4-6",
);
});
test("toPersistedModelId prefixes provider chosen in prefs wizard", () => {
assert.equal(toPersistedModelId("openai", "gpt-5.4"), "openai/gpt-5.4");
assert.equal(
toPersistedModelId("openai", "openai/gpt-5.4"),
"openai/gpt-5.4",
"already-qualified IDs should be preserved",
);
});
test("handles empty models config", () => {
const prefs = parsePreferencesMarkdown("---\nversion: 1\n---\n");
assert.notEqual(prefs, null);

View file

@ -73,8 +73,23 @@ export async function collectSettingsData(projectCwdOverride?: string): Promise<
'let preferences = null;',
'if (loaded) {',
' const p = loaded.preferences;',
' const models = {};',
' if (p.models && typeof p.models === "object") {',
' for (const [phase, value] of Object.entries(p.models)) {',
' if (typeof value === "string") {',
' models[phase] = value;',
' continue;',
' }',
' if (value && typeof value === "object" && typeof value.model === "string") {',
' models[phase] = typeof value.provider === "string" && value.provider && !value.model.includes("/")',
' ? `${value.provider}/${value.model}`',
' : value.model;',
' }',
' }',
' }',
' preferences = {',
' mode: p.mode,',
' models: Object.keys(models).length > 0 ? models : undefined,',
' budgetCeiling: p.budget_ceiling,',
' budgetEnforcement: p.budget_enforcement,',
' tokenProfile: p.token_profile,',

View file

@ -139,6 +139,24 @@ function SkillBadgeList({ label, skills }: { label: string; skills: string[] | u
)
}
function ModelBadgeList({ models }: { models: Record<string, string> | undefined }) {
if (!models || Object.keys(models).length === 0) return null
return (
<div className="space-y-1">
<span className="text-[11px] text-muted-foreground">Phase Models</span>
<div className="flex flex-wrap gap-1">
{Object.entries(models)
.sort(([a], [b]) => a.localeCompare(b))
.map(([phase, model]) => (
<Badge key={phase} variant="outline" className="text-[10px] px-1.5 py-0 font-mono">
{phase}: {model}
</Badge>
))}
</div>
</div>
)
}
function KvRow({ label, children }: { label: string; children: React.ReactNode }) {
return (
<div className="flex items-center justify-between gap-4 text-xs">
@ -206,12 +224,17 @@ export function PrefsPanel() {
{/* Skills */}
<div className="space-y-2">
<ModelBadgeList models={prefs.models} />
<SkillBadgeList label="Always use" skills={prefs.alwaysUseSkills} />
<SkillBadgeList label="Prefer" skills={prefs.preferSkills} />
<SkillBadgeList label="Avoid" skills={prefs.avoidSkills} />
{!prefs.alwaysUseSkills?.length && !prefs.preferSkills?.length && !prefs.avoidSkills?.length && (
<span className="text-[11px] text-muted-foreground">No skill preferences configured</span>
)}
{!prefs.models || Object.keys(prefs.models).length === 0
? !prefs.alwaysUseSkills?.length && !prefs.preferSkills?.length && !prefs.avoidSkills?.length && (
<span className="text-[11px] text-muted-foreground">No model or skill preferences configured</span>
)
: !prefs.alwaysUseSkills?.length && !prefs.preferSkills?.length && !prefs.avoidSkills?.length && (
<span className="text-[11px] text-muted-foreground">No skill preferences configured</span>
)}
</div>
{/* Toggles */}

View file

@ -126,6 +126,7 @@ const GSD_SURFACE_SUBCOMMANDS = new Map<string, BrowserSlashCommandSurface>([
["history", "gsd-history"],
["undo", "gsd-undo"],
["inspect", "gsd-inspect"],
["model", "model"],
["prefs", "gsd-prefs"],
["config", "gsd-config"],
["hooks", "gsd-hooks"],
@ -153,7 +154,7 @@ export const GSD_HELP_TEXT = `Available /gsd subcommands:
Workflow: next · auto · stop · pause · skip · queue · quick · capture · triage
Diagnostics: status · visualize · forensics · doctor · skill-health · inspect
Context: knowledge · history · undo · discuss
Settings: prefs · config · hooks · mode · steer
Settings: model · prefs · config · hooks · mode · steer
Advanced: export · cleanup · run-hook · migrate · remote
Type /gsd <subcommand> to run. Use /gsd help for this message.`

View file

@ -87,6 +87,7 @@ export interface SettingsProjectTotals {
export interface SettingsPreferencesData {
mode?: SettingsWorkflowMode
models?: Record<string, string>
budgetCeiling?: number
budgetEnforcement?: SettingsBudgetEnforcement
tokenProfile?: SettingsTokenProfile