From a8eb66b8b3fbbba358fa0d1d0bb97913f34746d2 Mon Sep 17 00:00:00 2001 From: Jeremy McSpadden Date: Tue, 17 Mar 2026 09:23:29 -0500 Subject: [PATCH] feat: group /model selector by provider (#871) --- .../interactive/components/model-selector.ts | 413 +++++++++++++++--- 1 file changed, 349 insertions(+), 64 deletions(-) diff --git a/packages/pi-coding-agent/src/modes/interactive/components/model-selector.ts b/packages/pi-coding-agent/src/modes/interactive/components/model-selector.ts index b35895a79..c86347b6f 100644 --- a/packages/pi-coding-agent/src/modes/interactive/components/model-selector.ts +++ b/packages/pi-coding-agent/src/modes/interactive/components/model-selector.ts @@ -38,10 +38,22 @@ interface ScopedModelItem { thinkingLevel?: string; } +/** + * A navigable row — either a provider group header or a selectable model entry. + */ +type ListRow = + | { kind: "header"; provider: string; count: number } + | { kind: "model"; item: ModelItem }; + type ModelScope = "all" | "scoped"; /** - * Component that renders a model selector with search + * Component that renders a grouped model selector with search. + * + * Browsing (no search): models are grouped under provider headers. + * - Current model's provider is shown first; remaining providers sorted alphabetically. + * - Arrow keys navigate all rows; headers are skipped during selection. + * Searching: reverts to a flat fuzzy-filtered list (same as before), with [provider] badges. */ export class ModelSelectorComponent extends Container implements Focusable { private searchInput: Input; @@ -59,8 +71,17 @@ export class ModelSelectorComponent extends Container implements Focusable { private allModels: ModelItem[] = []; private scopedModelItems: ModelItem[] = []; private activeModels: ModelItem[] = []; + + // Grouped (browse) state + private groupedRows: ListRow[] = []; + private modelRowIndices: number[] = []; // indices into groupedRows that are "model" kind + private selectedGroupIndex: number = 0; // index into groupedRows (can be model or header) + + // Search (flat) state private filteredModels: ModelItem[] = []; - private selectedIndex: number = 0; + private selectedFlatIndex: number = 0; + + private isSearching: boolean = false; private currentModel?: Model; private settingsManager: SettingsManager; private modelRegistry: ModelRegistry; @@ -116,9 +137,13 @@ export class ModelSelectorComponent extends Container implements Focusable { this.searchInput.setValue(initialSearchInput); } this.searchInput.onSubmit = () => { - // Enter on search input selects the first filtered item - if (this.filteredModels[this.selectedIndex]) { - this.handleSelect(this.filteredModels[this.selectedIndex].model); + if (this.isSearching) { + if (this.filteredModels[this.selectedFlatIndex]) { + this.handleSelect(this.filteredModels[this.selectedFlatIndex].model); + } + } else { + const model = this.getSelectedModel(); + if (model) this.handleSelect(model); } }; this.addChild(this.searchInput); @@ -137,8 +162,11 @@ export class ModelSelectorComponent extends Container implements Focusable { // Load models and do initial render this.loadModels().then(() => { if (initialSearchInput) { + this.isSearching = true; this.filterModels(initialSearchInput); } else { + this.buildGroupedRows(); + this.jumpToCurrentModel(); this.updateList(); } // Request re-render after models are loaded @@ -171,12 +199,14 @@ export class ModelSelectorComponent extends Container implements Focusable { this.scopedModelItems = []; this.activeModels = []; this.filteredModels = []; + this.groupedRows = []; + this.modelRowIndices = []; this.errorMessage = error instanceof Error ? error.message : String(error); return; } - this.allModels = this.sortModels(models); - this.scopedModelItems = this.sortModels( + this.allModels = this.sortModelsWithinProvider(models); + this.scopedModelItems = this.sortModelsWithinProvider( this.scopedModels.map((scoped) => ({ provider: scoped.model.provider, id: scoped.model.id, @@ -185,18 +215,20 @@ export class ModelSelectorComponent extends Container implements Focusable { ); this.activeModels = this.scope === "scoped" ? this.scopedModelItems : this.allModels; this.filteredModels = this.activeModels; - this.selectedIndex = Math.min(this.selectedIndex, Math.max(0, this.filteredModels.length - 1)); } - private sortModels(models: ModelItem[]): ModelItem[] { + /** + * Sort models within each provider: current model first, then by name desc. + * Provider ordering is handled separately in buildGroupedRows(). + */ + private sortModelsWithinProvider(models: ModelItem[]): ModelItem[] { const sorted = [...models]; - // Sort: current model first, then by name descending (newest first), then by provider sorted.sort((a, b) => { const aIsCurrent = modelsAreEqual(this.currentModel, a.model); const bIsCurrent = modelsAreEqual(this.currentModel, b.model); if (aIsCurrent && !bIsCurrent) return -1; if (!aIsCurrent && bIsCurrent) return 1; - // Group by model name (display name), newest/largest first + // Within provider: newest/largest model name first const nameCmp = b.model.name.localeCompare(a.model.name); if (nameCmp !== 0) return nameCmp; return a.provider.localeCompare(b.provider); @@ -204,6 +236,79 @@ export class ModelSelectorComponent extends Container implements Focusable { return sorted; } + /** + * Build the grouped rows array for browse mode. + * Current model's provider comes first; remaining providers sorted alphabetically. + */ + private buildGroupedRows(): void { + // Group models by provider + const byProvider = new Map(); + for (const item of this.activeModels) { + let group = byProvider.get(item.provider); + if (!group) { + group = []; + byProvider.set(item.provider, group); + } + group.push(item); + } + + // Determine provider order: current model's provider first, rest alphabetically + const currentProvider = this.currentModel?.provider; + const providers = Array.from(byProvider.keys()).sort((a, b) => { + if (a === currentProvider) return -1; + if (b === currentProvider) return 1; + return a.localeCompare(b); + }); + + const rows: ListRow[] = []; + const modelIndices: number[] = []; + + for (const provider of providers) { + const items = byProvider.get(provider)!; + rows.push({ kind: "header", provider, count: items.length }); + for (const item of items) { + modelIndices.push(rows.length); + rows.push({ kind: "model", item }); + } + } + + this.groupedRows = rows; + this.modelRowIndices = modelIndices; + } + + /** + * Move selectedGroupIndex to point at the current model (or first model). + */ + private jumpToCurrentModel(): void { + if (this.groupedRows.length === 0) { + this.selectedGroupIndex = 0; + return; + } + // Find the current model in grouped rows + for (let i = 0; i < this.groupedRows.length; i++) { + const row = this.groupedRows[i]; + if (row.kind === "model" && modelsAreEqual(this.currentModel, row.item.model)) { + this.selectedGroupIndex = i; + return; + } + } + // Fall back to first model row + if (this.modelRowIndices.length > 0) { + this.selectedGroupIndex = this.modelRowIndices[0]; + } + } + + /** + * Get the currently selected model from grouped or flat state. + */ + private getSelectedModel(): Model | undefined { + if (this.isSearching) { + return this.filteredModels[this.selectedFlatIndex]?.model; + } + const row = this.groupedRows[this.selectedGroupIndex]; + return row?.kind === "model" ? row.item.model : undefined; + } + private getScopeText(): string { const allText = this.scope === "all" ? theme.fg("accent", "all") : theme.fg("muted", "all"); const scopedText = this.scope === "scoped" ? theme.fg("accent", "scoped") : theme.fg("muted", "scoped"); @@ -218,8 +323,16 @@ export class ModelSelectorComponent extends Container implements Focusable { if (this.scope === scope) return; this.scope = scope; this.activeModels = this.scope === "scoped" ? this.scopedModelItems : this.allModels; - this.selectedIndex = 0; - this.filterModels(this.searchInput.getValue()); + + if (this.isSearching) { + this.selectedFlatIndex = 0; + this.filterModels(this.searchInput.getValue()); + } else { + this.buildGroupedRows(); + this.jumpToCurrentModel(); + this.updateList(); + } + if (this.scopeText) { this.scopeText.setText(this.getScopeText()); } @@ -229,26 +342,51 @@ export class ModelSelectorComponent extends Container implements Focusable { this.filteredModels = query ? fuzzyFilter(this.activeModels, query, ({ id, provider }) => `${id} ${provider}`) : this.activeModels; - this.selectedIndex = Math.min(this.selectedIndex, Math.max(0, this.filteredModels.length - 1)); + this.selectedFlatIndex = Math.min(this.selectedFlatIndex, Math.max(0, this.filteredModels.length - 1)); this.updateList(); } private updateList(): void { this.listContainer.clear(); + if (this.errorMessage) { + const errorLines = this.errorMessage.split("\n"); + for (const line of errorLines) { + this.listContainer.addChild(new Text(theme.fg("error", line), 0, 0)); + } + return; + } + + if (this.isSearching) { + this.renderFlatList(); + } else { + this.renderGroupedList(); + } + } + + /** Flat fuzzy-search results, same as original behaviour */ + private renderFlatList(): void { const maxVisible = 10; + + if (this.filteredModels.length === 0) { + this.listContainer.addChild(new Text(theme.fg("muted", " No matching models"), 0, 0)); + return; + } + const startIndex = Math.max( 0, - Math.min(this.selectedIndex - Math.floor(maxVisible / 2), this.filteredModels.length - maxVisible), + Math.min( + this.selectedFlatIndex - Math.floor(maxVisible / 2), + this.filteredModels.length - maxVisible, + ), ); const endIndex = Math.min(startIndex + maxVisible, this.filteredModels.length); - // Show visible slice of filtered models for (let i = startIndex; i < endIndex; i++) { const item = this.filteredModels[i]; if (!item) continue; - const isSelected = i === this.selectedIndex; + const isSelected = i === this.selectedFlatIndex; const isCurrent = modelsAreEqual(this.currentModel, item.model); const ctx = formatTokenCount(item.model.contextWindow); @@ -256,7 +394,7 @@ export class ModelSelectorComponent extends Container implements Focusable { const providerBadge = theme.fg("muted", `[${item.provider}]`); const checkmark = isCurrent ? theme.fg("success", " ✓") : ""; - let line = ""; + let line: string; if (isSelected) { const prefix = theme.fg("accent", "→ "); line = `${prefix}${theme.fg("accent", item.id)} ${ctxBadge} ${providerBadge}${checkmark}`; @@ -267,40 +405,110 @@ export class ModelSelectorComponent extends Container implements Focusable { this.listContainer.addChild(new Text(line, 0, 0)); } - // Add scroll indicator if needed if (startIndex > 0 || endIndex < this.filteredModels.length) { - const scrollInfo = theme.fg("muted", ` (${this.selectedIndex + 1}/${this.filteredModels.length})`); - this.listContainer.addChild(new Text(scrollInfo, 0, 0)); + this.listContainer.addChild( + new Text(theme.fg("muted", ` (${this.selectedFlatIndex + 1}/${this.filteredModels.length})`), 0, 0), + ); } - // Show error message or "no results" if empty - if (this.errorMessage) { - // Show error in red - const errorLines = this.errorMessage.split("\n"); - for (const line of errorLines) { - this.listContainer.addChild(new Text(theme.fg("error", line), 0, 0)); - } - } else if (this.filteredModels.length === 0) { - this.listContainer.addChild(new Text(theme.fg("muted", " No matching models"), 0, 0)); - } else { - const selected = this.filteredModels[this.selectedIndex]; - if (selected) { - const m = selected.model; - const details = [ - m.name, - `ctx: ${formatTokenCount(m.contextWindow)}`, - `out: ${formatTokenCount(m.maxTokens)}`, - m.reasoning ? "thinking" : "", - m.input.includes("image") ? "vision" : "", - ].filter(Boolean).join(" · "); - this.listContainer.addChild(new Spacer(1)); - this.listContainer.addChild(new Text(theme.fg("muted", ` ${details}`), 0, 0)); + // Detail line for selected model + const selected = this.filteredModels[this.selectedFlatIndex]; + if (selected) { + this.listContainer.addChild(new Spacer(1)); + this.listContainer.addChild(new Text(theme.fg("muted", ` ${this.modelDetailLine(selected.model)}`), 0, 0)); + } + } + + /** + * Grouped browse view: provider headers + model rows, windowed around selection. + * Shows enough rows to fill ~10 visible lines; headers count as one line each. + */ + private renderGroupedList(): void { + const maxVisible = 12; + + if (this.groupedRows.length === 0) { + this.listContainer.addChild(new Text(theme.fg("muted", " No models available"), 0, 0)); + return; + } + + // Window around selectedGroupIndex + const startIndex = Math.max( + 0, + Math.min( + this.selectedGroupIndex - Math.floor(maxVisible / 2), + this.groupedRows.length - maxVisible, + ), + ); + const endIndex = Math.min(startIndex + maxVisible, this.groupedRows.length); + + for (let i = startIndex; i < endIndex; i++) { + const row = this.groupedRows[i]; + if (!row) continue; + + if (row.kind === "header") { + // Provider group header — always unselectable + const providerLabel = theme.fg("borderAccent", row.provider); + const count = theme.fg("muted", ` (${row.count})`); + // Add blank line before header if not the very first visible row + if (i > startIndex) { + this.listContainer.addChild(new Text("", 0, 0)); + } + this.listContainer.addChild(new Text(` ${providerLabel}${count}`, 0, 0)); + } else { + // Model row + const isSelected = i === this.selectedGroupIndex; + const isCurrent = modelsAreEqual(this.currentModel, row.item.model); + + const ctx = formatTokenCount(row.item.model.contextWindow); + const ctxBadge = theme.fg("muted", ` ${ctx}`); + const checkmark = isCurrent ? theme.fg("success", " ✓") : ""; + + let line: string; + if (isSelected) { + line = ` ${theme.fg("accent", "→")} ${theme.fg("accent", row.item.id)}${ctxBadge}${checkmark}`; + } else { + line = ` ${row.item.id}${ctxBadge}${checkmark}`; + } + + this.listContainer.addChild(new Text(line, 0, 0)); } } + + // Scroll indicator + if (startIndex > 0 || endIndex < this.groupedRows.length) { + const modelPos = this.modelRowIndices.indexOf(this.selectedGroupIndex) + 1; + const totalModels = this.modelRowIndices.length; + this.listContainer.addChild( + new Text(theme.fg("muted", ` (${modelPos}/${totalModels})`), 0, 0), + ); + } + + // Detail line for selected model + const selectedModel = this.getSelectedModel(); + if (selectedModel) { + this.listContainer.addChild(new Spacer(1)); + this.listContainer.addChild( + new Text(theme.fg("muted", ` ${this.modelDetailLine(selectedModel)}`), 0, 0), + ); + } + } + + private modelDetailLine(m: Model): string { + return [ + m.name, + `ctx: ${formatTokenCount(m.contextWindow)}`, + `out: ${formatTokenCount(m.maxTokens)}`, + m.reasoning ? "thinking" : "", + m.input.includes("image") ? "vision" : "", + ] + .filter(Boolean) + .join(" · "); } handleInput(keyData: string): void { const kb = getEditorKeybindings(); + + // Tab: scope toggle if (kb.matches(keyData, "tab")) { if (this.scopedModelItems.length > 0) { const nextScope: ModelScope = this.scope === "all" ? "scoped" : "all"; @@ -311,34 +519,111 @@ export class ModelSelectorComponent extends Container implements Focusable { } return; } - // Up arrow - wrap to bottom when at top + + // Navigation keys if (kb.matches(keyData, "selectUp")) { - if (this.filteredModels.length === 0) return; - this.selectedIndex = this.selectedIndex === 0 ? this.filteredModels.length - 1 : this.selectedIndex - 1; - this.updateList(); + this.moveUp(); + return; } - // Down arrow - wrap to top when at bottom - else if (kb.matches(keyData, "selectDown")) { - if (this.filteredModels.length === 0) return; - this.selectedIndex = this.selectedIndex === this.filteredModels.length - 1 ? 0 : this.selectedIndex + 1; - this.updateList(); + if (kb.matches(keyData, "selectDown")) { + this.moveDown(); + return; } - // Enter - else if (kb.matches(keyData, "selectConfirm")) { - const selectedModel = this.filteredModels[this.selectedIndex]; - if (selectedModel) { - this.handleSelect(selectedModel.model); + + // Confirm + if (kb.matches(keyData, "selectConfirm")) { + const model = this.getSelectedModel(); + if (model) this.handleSelect(model); + return; + } + + // Cancel + if (kb.matches(keyData, "selectCancel")) { + this.onCancelCallback(); + return; + } + + // Everything else: feed to search input + const prevQuery = this.searchInput.getValue(); + this.searchInput.handleInput(keyData); + const newQuery = this.searchInput.getValue(); + + if (newQuery !== prevQuery) { + const entering = !prevQuery && !!newQuery; + const leaving = !!prevQuery && !newQuery; + + if (entering) { + // Entering search mode: remember current model position + this.isSearching = true; + this.selectedFlatIndex = 0; + } else if (leaving) { + // Leaving search mode: return to grouped view, restore position + this.isSearching = false; + this.buildGroupedRows(); + this.jumpToCurrentModel(); + } + if (this.isSearching) { + this.filterModels(newQuery); + } else { + this.updateList(); } } - // Escape or Ctrl+C - else if (kb.matches(keyData, "selectCancel")) { - this.onCancelCallback(); + } + + /** Move selection up, skipping headers in grouped mode */ + private moveUp(): void { + if (this.isSearching) { + if (this.filteredModels.length === 0) return; + this.selectedFlatIndex = + this.selectedFlatIndex === 0 + ? this.filteredModels.length - 1 + : this.selectedFlatIndex - 1; + this.updateList(); + return; } - // Pass everything else to search input - else { - this.searchInput.handleInput(keyData); - this.filterModels(this.searchInput.getValue()); + + if (this.groupedRows.length === 0) return; + let next = this.selectedGroupIndex - 1; + // Wrap + if (next < 0) next = this.groupedRows.length - 1; + // Skip headers + while (next > 0 && this.groupedRows[next]?.kind === "header") { + next--; } + // If landed on header at 0, wrap to bottom + if (this.groupedRows[next]?.kind === "header") { + next = this.groupedRows.length - 1; + } + this.selectedGroupIndex = next; + this.updateList(); + } + + /** Move selection down, skipping headers in grouped mode */ + private moveDown(): void { + if (this.isSearching) { + if (this.filteredModels.length === 0) return; + this.selectedFlatIndex = + this.selectedFlatIndex === this.filteredModels.length - 1 + ? 0 + : this.selectedFlatIndex + 1; + this.updateList(); + return; + } + + if (this.groupedRows.length === 0) return; + let next = this.selectedGroupIndex + 1; + // Wrap + if (next >= this.groupedRows.length) next = 0; + // Skip headers + while (next < this.groupedRows.length - 1 && this.groupedRows[next]?.kind === "header") { + next++; + } + // If landed on header at end, wrap to first model + if (this.groupedRows[next]?.kind === "header") { + next = this.modelRowIndices[0] ?? 0; + } + this.selectedGroupIndex = next; + this.updateList(); } private handleSelect(model: Model): void {