From c666ff55eb0dce206fc8bc412601709333b81f67 Mon Sep 17 00:00:00 2001 From: Jeremy Date: Thu, 9 Apr 2026 16:02:45 -0500 Subject: [PATCH] group gsd model picker by provider --- .../extensions/gsd/commands/handlers/core.ts | 77 ++++++++----- .../gsd/tests/core-overlay-fallback.test.ts | 101 ++++++++++++++++++ 2 files changed, 153 insertions(+), 25 deletions(-) diff --git a/src/resources/extensions/gsd/commands/handlers/core.ts b/src/resources/extensions/gsd/commands/handlers/core.ts index ae8da6c60..9b608f166 100644 --- a/src/resources/extensions/gsd/commands/handlers/core.ts +++ b/src/resources/extensions/gsd/commands/handlers/core.ts @@ -194,6 +194,56 @@ function sortModelsForSelection(models: Model[], currentModel: Model | }); } +function buildProviderModelGroups( + models: Model[], + currentModel: Model | undefined, +): Map[]> { + const byProvider = new Map[]>(); + + for (const model of sortModelsForSelection(models, currentModel)) { + let group = byProvider.get(model.provider); + if (!group) { + group = []; + byProvider.set(model.provider, group); + } + group.push(model); + } + return byProvider; +} + +async function selectModelByProvider( + title: string, + models: Model[], + ctx: ExtensionCommandContext, + currentModel: Model | undefined, +): Promise | undefined> { + const byProvider = buildProviderModelGroups(models, currentModel); + const providerOptions = Array.from(byProvider.entries()).map(([provider, group]) => + `${provider} (${group.length} model${group.length === 1 ? "" : "s"})`, + ); + providerOptions.push("(cancel)"); + + const providerChoice = await ctx.ui.select(`${title} — choose provider:`, providerOptions); + if (!providerChoice || typeof providerChoice !== "string" || providerChoice === "(cancel)") return undefined; + + const providerName = providerChoice.replace(/ \(\d+ models?\)$/, ""); + const providerModels = byProvider.get(providerName); + if (!providerModels || providerModels.length === 0) return undefined; + + const optionToModel = new Map>(); + const modelOptions = providerModels.map((model) => { + const isCurrent = currentModel && model.provider === currentModel.provider && model.id === currentModel.id; + const label = `${isCurrent ? "* " : ""}${model.id}`; + optionToModel.set(label, model); + return label; + }); + modelOptions.push("(cancel)"); + + const modelChoice = await ctx.ui.select(`${title} — ${providerName}:`, modelOptions); + if (!modelChoice || typeof modelChoice !== "string" || modelChoice === "(cancel)") return undefined; + return optionToModel.get(modelChoice); +} + async function resolveRequestedModel( query: string, ctx: ExtensionCommandContext, @@ -211,19 +261,7 @@ async function resolveRequestedModel( if (partialMatches.length === 1) return partialMatches[0]; if (partialMatches.length === 0 || !ctx.hasUI) return undefined; - - const sorted = sortModelsForSelection(partialMatches, ctx.model); - const optionToModel = new Map>(); - const options = sorted.map((model) => { - const label = `${model.provider}/${model.id}`; - optionToModel.set(label, model); - return label; - }); - options.push("(cancel)"); - - const choice = await ctx.ui.select(`Multiple models match "${query}" — choose one:`, options); - if (!choice || typeof choice !== "string" || choice === "(cancel)") return undefined; - return optionToModel.get(choice); + return selectModelByProvider(`Multiple models match "${query}"`, partialMatches, ctx, ctx.model); } async function handleModel(trimmedArgs: string, ctx: ExtensionCommandContext, pi: ExtensionAPI | undefined): Promise { @@ -247,18 +285,7 @@ async function handleModel(trimmedArgs: string, ctx: ExtensionCommandContext, pi return; } - const optionToModel = new Map>(); - const options = sortModelsForSelection(availableModels, ctx.model).map((model) => { - const isCurrent = ctx.model && model.provider === ctx.model.provider && model.id === ctx.model.id; - const label = `${isCurrent ? "* " : ""}${model.provider}/${model.id}`; - optionToModel.set(label, model); - return label; - }); - options.push("(cancel)"); - - const choice = await ctx.ui.select("Select session model:", options); - if (!choice || typeof choice !== "string" || choice === "(cancel)") return; - targetModel = optionToModel.get(choice); + targetModel = await selectModelByProvider("Select session model:", availableModels, ctx, ctx.model); } else { targetModel = await resolveRequestedModel(trimmed, ctx); } diff --git a/src/resources/extensions/gsd/tests/core-overlay-fallback.test.ts b/src/resources/extensions/gsd/tests/core-overlay-fallback.test.ts index a6c2dc6d9..9a7a21d16 100644 --- a/src/resources/extensions/gsd/tests/core-overlay-fallback.test.ts +++ b/src/resources/extensions/gsd/tests/core-overlay-fallback.test.ts @@ -74,3 +74,104 @@ test("model command resolves and persists exact provider-qualified selection", a assert.deepEqual(applied, selectedModel); assert.match(notices[0]!.message, /openai\/gpt-5\.4/); }); + +test("interactive model picker chooses provider first, then model", async () => { + const selectedModel = { provider: "openai", id: "gpt-5.4" }; + let applied: typeof selectedModel | null = null; + const selects: Array<{ title: string; options: string[] }> = []; + const notices: Array<{ message: string; type?: string }> = []; + + const ctx = { + hasUI: true, + model: { provider: "anthropic", id: "claude-sonnet-4-6" }, + modelRegistry: { + getAvailable: () => [ + { provider: "openai", id: "gpt-5.4" }, + { provider: "anthropic", id: "claude-opus-4-6" }, + { provider: "openai", id: "gpt-5.3-mini" }, + { provider: "anthropic", id: "claude-sonnet-4-6" }, + ], + }, + ui: { + select: async (title: string, options: string[]) => { + selects.push({ title, options }); + return selects.length === 1 ? "openai (2 models)" : "gpt-5.4"; + }, + notify: (message: string, type?: string) => { + notices.push({ message, type }); + }, + }, + } as any; + + const pi = { + setModel: async (model: typeof selectedModel) => { + applied = model; + return true; + }, + } as any; + + const handled = await handleCoreCommand("model", ctx, pi); + assert.equal(handled, true); + assert.deepEqual(selects, [ + { + title: "Select session model: — choose provider:", + options: ["anthropic (2 models)", "openai (2 models)", "(cancel)"], + }, + { + title: "Select session model: — openai:", + options: ["gpt-5.3-mini", "gpt-5.4", "(cancel)"], + }, + ]); + assert.deepEqual(applied, selectedModel); + assert.match(notices[0]!.message, /openai\/gpt-5\.4/); +}); + +test("ambiguous typed model selection chooses provider first, then model", async () => { + const selectedModel = { provider: "github-copilot", id: "gpt-5" }; + let applied: typeof selectedModel | null = null; + const selects: Array<{ title: string; options: string[] }> = []; + const notices: Array<{ message: string; type?: string }> = []; + + const ctx = { + hasUI: true, + model: { provider: "anthropic", id: "claude-sonnet-4-6" }, + modelRegistry: { + getAvailable: () => [ + { provider: "openai", id: "gpt-5" }, + { provider: "github-copilot", id: "gpt-5" }, + { provider: "openai", id: "gpt-5-mini" }, + ], + }, + ui: { + select: async (title: string, options: string[]) => { + selects.push({ title, options }); + return selects.length === 1 ? "github-copilot (1 model)" : "gpt-5"; + }, + notify: (message: string, type?: string) => { + notices.push({ message, type }); + }, + }, + } as any; + + const pi = { + setModel: async (model: typeof selectedModel) => { + applied = model; + return true; + }, + } as any; + + const handled = await handleCoreCommand("model gpt", ctx, pi); + assert.equal(handled, true); + assert.deepEqual(selects, [ + { + title: "Multiple models match \"gpt\" — choose provider:", + options: ["github-copilot (1 model)", "openai (2 models)", "(cancel)"], + }, + { + title: "Multiple models match \"gpt\" — github-copilot:", + options: ["gpt-5", "(cancel)"], + }, + ]); + assert.deepEqual(applied, selectedModel); + assert.match(notices[0]!.message, /github-copilot\/gpt-5/); +});