group gsd model picker by provider

This commit is contained in:
Jeremy 2026-04-09 16:02:45 -05:00
parent f5c6c1d94c
commit c666ff55eb
2 changed files with 153 additions and 25 deletions

View file

@ -194,6 +194,56 @@ function sortModelsForSelection(models: Model<any>[], currentModel: Model<any> |
});
}
function buildProviderModelGroups(
models: Model<any>[],
currentModel: Model<any> | undefined,
): Map<string, Model<any>[]> {
const byProvider = new Map<string, Model<any>[]>();
for (const model of sortModelsForSelection(models, currentModel)) {
let group = byProvider.get(model.provider);
if (!group) {
group = [];
byProvider.set(model.provider, group);
}
group.push(model);
}
return byProvider;
}
async function selectModelByProvider(
title: string,
models: Model<any>[],
ctx: ExtensionCommandContext,
currentModel: Model<any> | undefined,
): Promise<Model<any> | undefined> {
const byProvider = buildProviderModelGroups(models, currentModel);
const providerOptions = Array.from(byProvider.entries()).map(([provider, group]) =>
`${provider} (${group.length} model${group.length === 1 ? "" : "s"})`,
);
providerOptions.push("(cancel)");
const providerChoice = await ctx.ui.select(`${title} — choose provider:`, providerOptions);
if (!providerChoice || typeof providerChoice !== "string" || providerChoice === "(cancel)") return undefined;
const providerName = providerChoice.replace(/ \(\d+ models?\)$/, "");
const providerModels = byProvider.get(providerName);
if (!providerModels || providerModels.length === 0) return undefined;
const optionToModel = new Map<string, Model<any>>();
const modelOptions = providerModels.map((model) => {
const isCurrent = currentModel && model.provider === currentModel.provider && model.id === currentModel.id;
const label = `${isCurrent ? "* " : ""}${model.id}`;
optionToModel.set(label, model);
return label;
});
modelOptions.push("(cancel)");
const modelChoice = await ctx.ui.select(`${title}${providerName}:`, modelOptions);
if (!modelChoice || typeof modelChoice !== "string" || modelChoice === "(cancel)") return undefined;
return optionToModel.get(modelChoice);
}
async function resolveRequestedModel(
query: string,
ctx: ExtensionCommandContext,
@ -211,19 +261,7 @@ async function resolveRequestedModel(
if (partialMatches.length === 1) return partialMatches[0];
if (partialMatches.length === 0 || !ctx.hasUI) return undefined;
const sorted = sortModelsForSelection(partialMatches, ctx.model);
const optionToModel = new Map<string, Model<any>>();
const options = sorted.map((model) => {
const label = `${model.provider}/${model.id}`;
optionToModel.set(label, model);
return label;
});
options.push("(cancel)");
const choice = await ctx.ui.select(`Multiple models match "${query}" — choose one:`, options);
if (!choice || typeof choice !== "string" || choice === "(cancel)") return undefined;
return optionToModel.get(choice);
return selectModelByProvider(`Multiple models match "${query}"`, partialMatches, ctx, ctx.model);
}
async function handleModel(trimmedArgs: string, ctx: ExtensionCommandContext, pi: ExtensionAPI | undefined): Promise<void> {
@ -247,18 +285,7 @@ async function handleModel(trimmedArgs: string, ctx: ExtensionCommandContext, pi
return;
}
const optionToModel = new Map<string, Model<any>>();
const options = sortModelsForSelection(availableModels, ctx.model).map((model) => {
const isCurrent = ctx.model && model.provider === ctx.model.provider && model.id === ctx.model.id;
const label = `${isCurrent ? "* " : ""}${model.provider}/${model.id}`;
optionToModel.set(label, model);
return label;
});
options.push("(cancel)");
const choice = await ctx.ui.select("Select session model:", options);
if (!choice || typeof choice !== "string" || choice === "(cancel)") return;
targetModel = optionToModel.get(choice);
targetModel = await selectModelByProvider("Select session model:", availableModels, ctx, ctx.model);
} else {
targetModel = await resolveRequestedModel(trimmed, ctx);
}

View file

@ -74,3 +74,104 @@ test("model command resolves and persists exact provider-qualified selection", a
assert.deepEqual(applied, selectedModel);
assert.match(notices[0]!.message, /openai\/gpt-5\.4/);
});
test("interactive model picker chooses provider first, then model", async () => {
const selectedModel = { provider: "openai", id: "gpt-5.4" };
let applied: typeof selectedModel | null = null;
const selects: Array<{ title: string; options: string[] }> = [];
const notices: Array<{ message: string; type?: string }> = [];
const ctx = {
hasUI: true,
model: { provider: "anthropic", id: "claude-sonnet-4-6" },
modelRegistry: {
getAvailable: () => [
{ provider: "openai", id: "gpt-5.4" },
{ provider: "anthropic", id: "claude-opus-4-6" },
{ provider: "openai", id: "gpt-5.3-mini" },
{ provider: "anthropic", id: "claude-sonnet-4-6" },
],
},
ui: {
select: async (title: string, options: string[]) => {
selects.push({ title, options });
return selects.length === 1 ? "openai (2 models)" : "gpt-5.4";
},
notify: (message: string, type?: string) => {
notices.push({ message, type });
},
},
} as any;
const pi = {
setModel: async (model: typeof selectedModel) => {
applied = model;
return true;
},
} as any;
const handled = await handleCoreCommand("model", ctx, pi);
assert.equal(handled, true);
assert.deepEqual(selects, [
{
title: "Select session model: — choose provider:",
options: ["anthropic (2 models)", "openai (2 models)", "(cancel)"],
},
{
title: "Select session model: — openai:",
options: ["gpt-5.3-mini", "gpt-5.4", "(cancel)"],
},
]);
assert.deepEqual(applied, selectedModel);
assert.match(notices[0]!.message, /openai\/gpt-5\.4/);
});
test("ambiguous typed model selection chooses provider first, then model", async () => {
const selectedModel = { provider: "github-copilot", id: "gpt-5" };
let applied: typeof selectedModel | null = null;
const selects: Array<{ title: string; options: string[] }> = [];
const notices: Array<{ message: string; type?: string }> = [];
const ctx = {
hasUI: true,
model: { provider: "anthropic", id: "claude-sonnet-4-6" },
modelRegistry: {
getAvailable: () => [
{ provider: "openai", id: "gpt-5" },
{ provider: "github-copilot", id: "gpt-5" },
{ provider: "openai", id: "gpt-5-mini" },
],
},
ui: {
select: async (title: string, options: string[]) => {
selects.push({ title, options });
return selects.length === 1 ? "github-copilot (1 model)" : "gpt-5";
},
notify: (message: string, type?: string) => {
notices.push({ message, type });
},
},
} as any;
const pi = {
setModel: async (model: typeof selectedModel) => {
applied = model;
return true;
},
} as any;
const handled = await handleCoreCommand("model gpt", ctx, pi);
assert.equal(handled, true);
assert.deepEqual(selects, [
{
title: "Multiple models match \"gpt\" — choose provider:",
options: ["github-copilot (1 model)", "openai (2 models)", "(cancel)"],
},
{
title: "Multiple models match \"gpt\" — github-copilot:",
options: ["gpt-5", "(cancel)"],
},
]);
assert.deepEqual(applied, selectedModel);
assert.match(notices[0]!.message, /github-copilot\/gpt-5/);
});