group gsd model picker by provider
This commit is contained in:
parent
f5c6c1d94c
commit
c666ff55eb
2 changed files with 153 additions and 25 deletions
|
|
@ -194,6 +194,56 @@ function sortModelsForSelection(models: Model<any>[], currentModel: Model<any> |
|
|||
});
|
||||
}
|
||||
|
||||
function buildProviderModelGroups(
|
||||
models: Model<any>[],
|
||||
currentModel: Model<any> | undefined,
|
||||
): Map<string, Model<any>[]> {
|
||||
const byProvider = new Map<string, Model<any>[]>();
|
||||
|
||||
for (const model of sortModelsForSelection(models, currentModel)) {
|
||||
let group = byProvider.get(model.provider);
|
||||
if (!group) {
|
||||
group = [];
|
||||
byProvider.set(model.provider, group);
|
||||
}
|
||||
group.push(model);
|
||||
}
|
||||
return byProvider;
|
||||
}
|
||||
|
||||
async function selectModelByProvider(
|
||||
title: string,
|
||||
models: Model<any>[],
|
||||
ctx: ExtensionCommandContext,
|
||||
currentModel: Model<any> | undefined,
|
||||
): Promise<Model<any> | undefined> {
|
||||
const byProvider = buildProviderModelGroups(models, currentModel);
|
||||
const providerOptions = Array.from(byProvider.entries()).map(([provider, group]) =>
|
||||
`${provider} (${group.length} model${group.length === 1 ? "" : "s"})`,
|
||||
);
|
||||
providerOptions.push("(cancel)");
|
||||
|
||||
const providerChoice = await ctx.ui.select(`${title} — choose provider:`, providerOptions);
|
||||
if (!providerChoice || typeof providerChoice !== "string" || providerChoice === "(cancel)") return undefined;
|
||||
|
||||
const providerName = providerChoice.replace(/ \(\d+ models?\)$/, "");
|
||||
const providerModels = byProvider.get(providerName);
|
||||
if (!providerModels || providerModels.length === 0) return undefined;
|
||||
|
||||
const optionToModel = new Map<string, Model<any>>();
|
||||
const modelOptions = providerModels.map((model) => {
|
||||
const isCurrent = currentModel && model.provider === currentModel.provider && model.id === currentModel.id;
|
||||
const label = `${isCurrent ? "* " : ""}${model.id}`;
|
||||
optionToModel.set(label, model);
|
||||
return label;
|
||||
});
|
||||
modelOptions.push("(cancel)");
|
||||
|
||||
const modelChoice = await ctx.ui.select(`${title} — ${providerName}:`, modelOptions);
|
||||
if (!modelChoice || typeof modelChoice !== "string" || modelChoice === "(cancel)") return undefined;
|
||||
return optionToModel.get(modelChoice);
|
||||
}
|
||||
|
||||
async function resolveRequestedModel(
|
||||
query: string,
|
||||
ctx: ExtensionCommandContext,
|
||||
|
|
@ -211,19 +261,7 @@ async function resolveRequestedModel(
|
|||
|
||||
if (partialMatches.length === 1) return partialMatches[0];
|
||||
if (partialMatches.length === 0 || !ctx.hasUI) return undefined;
|
||||
|
||||
const sorted = sortModelsForSelection(partialMatches, ctx.model);
|
||||
const optionToModel = new Map<string, Model<any>>();
|
||||
const options = sorted.map((model) => {
|
||||
const label = `${model.provider}/${model.id}`;
|
||||
optionToModel.set(label, model);
|
||||
return label;
|
||||
});
|
||||
options.push("(cancel)");
|
||||
|
||||
const choice = await ctx.ui.select(`Multiple models match "${query}" — choose one:`, options);
|
||||
if (!choice || typeof choice !== "string" || choice === "(cancel)") return undefined;
|
||||
return optionToModel.get(choice);
|
||||
return selectModelByProvider(`Multiple models match "${query}"`, partialMatches, ctx, ctx.model);
|
||||
}
|
||||
|
||||
async function handleModel(trimmedArgs: string, ctx: ExtensionCommandContext, pi: ExtensionAPI | undefined): Promise<void> {
|
||||
|
|
@ -247,18 +285,7 @@ async function handleModel(trimmedArgs: string, ctx: ExtensionCommandContext, pi
|
|||
return;
|
||||
}
|
||||
|
||||
const optionToModel = new Map<string, Model<any>>();
|
||||
const options = sortModelsForSelection(availableModels, ctx.model).map((model) => {
|
||||
const isCurrent = ctx.model && model.provider === ctx.model.provider && model.id === ctx.model.id;
|
||||
const label = `${isCurrent ? "* " : ""}${model.provider}/${model.id}`;
|
||||
optionToModel.set(label, model);
|
||||
return label;
|
||||
});
|
||||
options.push("(cancel)");
|
||||
|
||||
const choice = await ctx.ui.select("Select session model:", options);
|
||||
if (!choice || typeof choice !== "string" || choice === "(cancel)") return;
|
||||
targetModel = optionToModel.get(choice);
|
||||
targetModel = await selectModelByProvider("Select session model:", availableModels, ctx, ctx.model);
|
||||
} else {
|
||||
targetModel = await resolveRequestedModel(trimmed, ctx);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -74,3 +74,104 @@ test("model command resolves and persists exact provider-qualified selection", a
|
|||
assert.deepEqual(applied, selectedModel);
|
||||
assert.match(notices[0]!.message, /openai\/gpt-5\.4/);
|
||||
});
|
||||
|
||||
test("interactive model picker chooses provider first, then model", async () => {
|
||||
const selectedModel = { provider: "openai", id: "gpt-5.4" };
|
||||
let applied: typeof selectedModel | null = null;
|
||||
const selects: Array<{ title: string; options: string[] }> = [];
|
||||
const notices: Array<{ message: string; type?: string }> = [];
|
||||
|
||||
const ctx = {
|
||||
hasUI: true,
|
||||
model: { provider: "anthropic", id: "claude-sonnet-4-6" },
|
||||
modelRegistry: {
|
||||
getAvailable: () => [
|
||||
{ provider: "openai", id: "gpt-5.4" },
|
||||
{ provider: "anthropic", id: "claude-opus-4-6" },
|
||||
{ provider: "openai", id: "gpt-5.3-mini" },
|
||||
{ provider: "anthropic", id: "claude-sonnet-4-6" },
|
||||
],
|
||||
},
|
||||
ui: {
|
||||
select: async (title: string, options: string[]) => {
|
||||
selects.push({ title, options });
|
||||
return selects.length === 1 ? "openai (2 models)" : "gpt-5.4";
|
||||
},
|
||||
notify: (message: string, type?: string) => {
|
||||
notices.push({ message, type });
|
||||
},
|
||||
},
|
||||
} as any;
|
||||
|
||||
const pi = {
|
||||
setModel: async (model: typeof selectedModel) => {
|
||||
applied = model;
|
||||
return true;
|
||||
},
|
||||
} as any;
|
||||
|
||||
const handled = await handleCoreCommand("model", ctx, pi);
|
||||
assert.equal(handled, true);
|
||||
assert.deepEqual(selects, [
|
||||
{
|
||||
title: "Select session model: — choose provider:",
|
||||
options: ["anthropic (2 models)", "openai (2 models)", "(cancel)"],
|
||||
},
|
||||
{
|
||||
title: "Select session model: — openai:",
|
||||
options: ["gpt-5.3-mini", "gpt-5.4", "(cancel)"],
|
||||
},
|
||||
]);
|
||||
assert.deepEqual(applied, selectedModel);
|
||||
assert.match(notices[0]!.message, /openai\/gpt-5\.4/);
|
||||
});
|
||||
|
||||
test("ambiguous typed model selection chooses provider first, then model", async () => {
|
||||
const selectedModel = { provider: "github-copilot", id: "gpt-5" };
|
||||
let applied: typeof selectedModel | null = null;
|
||||
const selects: Array<{ title: string; options: string[] }> = [];
|
||||
const notices: Array<{ message: string; type?: string }> = [];
|
||||
|
||||
const ctx = {
|
||||
hasUI: true,
|
||||
model: { provider: "anthropic", id: "claude-sonnet-4-6" },
|
||||
modelRegistry: {
|
||||
getAvailable: () => [
|
||||
{ provider: "openai", id: "gpt-5" },
|
||||
{ provider: "github-copilot", id: "gpt-5" },
|
||||
{ provider: "openai", id: "gpt-5-mini" },
|
||||
],
|
||||
},
|
||||
ui: {
|
||||
select: async (title: string, options: string[]) => {
|
||||
selects.push({ title, options });
|
||||
return selects.length === 1 ? "github-copilot (1 model)" : "gpt-5";
|
||||
},
|
||||
notify: (message: string, type?: string) => {
|
||||
notices.push({ message, type });
|
||||
},
|
||||
},
|
||||
} as any;
|
||||
|
||||
const pi = {
|
||||
setModel: async (model: typeof selectedModel) => {
|
||||
applied = model;
|
||||
return true;
|
||||
},
|
||||
} as any;
|
||||
|
||||
const handled = await handleCoreCommand("model gpt", ctx, pi);
|
||||
assert.equal(handled, true);
|
||||
assert.deepEqual(selects, [
|
||||
{
|
||||
title: "Multiple models match \"gpt\" — choose provider:",
|
||||
options: ["github-copilot (1 model)", "openai (2 models)", "(cancel)"],
|
||||
},
|
||||
{
|
||||
title: "Multiple models match \"gpt\" — github-copilot:",
|
||||
options: ["gpt-5", "(cancel)"],
|
||||
},
|
||||
]);
|
||||
assert.deepEqual(applied, selectedModel);
|
||||
assert.match(notices[0]!.message, /github-copilot\/gpt-5/);
|
||||
});
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue