// singularity-forge/packages/pi-coding-agent/src/modes/interactive/components/model-selector.ts
//
// 353 lines
// 11 KiB
// TypeScript

import { type Model, modelsAreEqual } from "@gsd/pi-ai";
import {
Container,
type Focusable,
fuzzyFilter,
getEditorKeybindings,
Input,
Spacer,
Text,
type TUI,
} from "@gsd/pi-tui";
import type { ModelRegistry } from "../../../core/model-registry.js";
import type { SettingsManager } from "../../../core/settings-manager.js";
import { theme } from "../theme/theme.js";
import { DynamicBorder } from "./dynamic-border.js";
import { keyHint } from "./keybinding-hints.js";
/**
 * Formats a token count as a compact label such as "950", "128K" or "1.5M".
 *
 * Whole multiples of the unit are printed without a decimal ("200K", "1M");
 * anything else is rounded to one decimal place ("1.5K", "1.0M").
 *
 * @param count - Token count to format.
 * @returns The compact label; counts below 1000 are returned as-is via `toString()`.
 */
function formatTokenCount(count: number): string {
  const withSuffix = (value: number, suffix: string): string =>
    value % 1 === 0 ? `${value}${suffix}` : `${value.toFixed(1)}${suffix}`;

  if (count >= 1_000_000) {
    return withSuffix(count / 1_000_000, "M");
  }
  if (count >= 1_000) {
    const thousands = count / 1_000;
    // Counts just below one million round up to "1000.0" at one decimal place;
    // promote those to the M unit instead of printing "1000.0K".
    if (thousands.toFixed(1) === "1000.0") {
      return `${(count / 1_000_000).toFixed(1)}M`;
    }
    return withSuffix(thousands, "K");
  }
  return count.toString();
}
/**
 * One selectable row of the model list: a model together with the provider
 * and id strings that are rendered in the list and fed to the fuzzy search.
 */
interface ModelItem {
// Copied from `model.provider` when the item is built.
provider: string;
// Copied from `model.id` when the item is built.
id: string;
model: Model<any>;
}
/**
 * A caller-supplied model that makes up the "scoped" list.
 */
interface ScopedModelItem {
model: Model<any>;
// Not read anywhere in this file; carried for the caller's benefit.
thinkingLevel?: string;
}
/** Which list the selector shows: every available model, or only the scoped ones. */
type ModelScope = "all" | "scoped";
/**
 * Interactive model picker: a search input above a scrollable, fuzzy-filterable
 * list of models.
 *
 * The list shows either every model the registry reports as available ("all")
 * or only the models passed in as `scopedModels` ("scoped"); Tab switches
 * between the two when scoped models exist. Confirming a row saves it as the
 * default model through the settings manager and then invokes `onSelect`.
 */
export class ModelSelectorComponent extends Container implements Focusable {
  private readonly searchInput: Input;

  // Focusable implementation - propagate to searchInput for IME cursor positioning
  private _focused = false;
  get focused(): boolean {
    return this._focused;
  }
  set focused(value: boolean) {
    this._focused = value;
    this.searchInput.focused = value;
  }

  private readonly listContainer: Container;
  /** Every model the registry reported as available, sorted for display. */
  private allModels: ModelItem[] = [];
  /** The caller-supplied scoped models, sorted for display. */
  private scopedModelItems: ModelItem[] = [];
  /** The list for the current scope; the search filter runs over this. */
  private activeModels: ModelItem[] = [];
  /** `activeModels` narrowed by the current search query; this is what gets rendered. */
  private filteredModels: ModelItem[] = [];
  /** Index into `filteredModels` of the highlighted row. */
  private selectedIndex = 0;
  private readonly currentModel?: Model<any>;
  private readonly settingsManager: SettingsManager;
  private readonly modelRegistry: ModelRegistry;
  private readonly onSelectCallback: (model: Model<any>) => void;
  private readonly onCancelCallback: () => void;
  /** When set, rendered below the list in place of the details / "no results" line. */
  private errorMessage?: string;
  private readonly tui: TUI;
  private readonly scopedModels: ReadonlyArray<ScopedModelItem>;
  private scope: ModelScope = "all";
  /** Only created when scoped models were supplied. */
  private readonly scopeText?: Text;
  /** Only created when scoped models were supplied. */
  private readonly scopeHintText?: Text;

  /**
   * Builds the component tree and starts loading models in the background.
   *
   * @param tui - Used to request a re-render once the models have loaded.
   * @param currentModel - The model currently in use; sorted first and marked with a checkmark.
   * @param settingsManager - Receives the chosen model as the new default.
   * @param modelRegistry - Source of the available models.
   * @param scopedModels - Models for the "scoped" list; when non-empty the selector starts in that scope.
   * @param onSelect - Called with the chosen model after it was saved as default.
   * @param onCancel - Called when the cancel key is pressed.
   * @param initialSearchInput - Optional query to pre-fill and apply once models are loaded.
   */
  constructor(
    tui: TUI,
    currentModel: Model<any> | undefined,
    settingsManager: SettingsManager,
    modelRegistry: ModelRegistry,
    scopedModels: ReadonlyArray<ScopedModelItem>,
    onSelect: (model: Model<any>) => void,
    onCancel: () => void,
    initialSearchInput?: string,
  ) {
    super();
    this.tui = tui;
    this.currentModel = currentModel;
    this.settingsManager = settingsManager;
    this.modelRegistry = modelRegistry;
    this.scopedModels = scopedModels;
    this.scope = scopedModels.length > 0 ? "scoped" : "all";
    this.onSelectCallback = onSelect;
    this.onCancelCallback = onCancel;

    // Add top border
    this.addChild(new DynamicBorder());
    this.addChild(new Spacer(1));

    // Add hint about model filtering
    if (scopedModels.length > 0) {
      this.scopeText = new Text(this.getScopeText(), 0, 0);
      this.addChild(this.scopeText);
      this.scopeHintText = new Text(this.getScopeHintText(), 0, 0);
      this.addChild(this.scopeHintText);
    } else {
      const hintText = "Only showing models with configured API keys (see README for details)";
      this.addChild(new Text(theme.fg("warning", hintText), 0, 0));
    }
    this.addChild(new Spacer(1));

    // Create search input
    this.searchInput = new Input();
    if (initialSearchInput) {
      this.searchInput.setValue(initialSearchInput);
    }
    this.searchInput.onSubmit = () => {
      // Enter on search input selects the currently highlighted item
      const highlighted = this.filteredModels[this.selectedIndex];
      if (highlighted) {
        this.handleSelect(highlighted.model);
      }
    };
    this.addChild(this.searchInput);
    this.addChild(new Spacer(1));

    // Create list container
    this.listContainer = new Container();
    this.addChild(this.listContainer);
    this.addChild(new Spacer(1));

    // Add bottom border
    this.addChild(new DynamicBorder());

    // Load models and do the initial render. Both the load and the follow-up
    // render can throw; show that in the list rather than leaving a rejected
    // promise unhandled.
    void this.loadModels()
      .then(() => {
        if (initialSearchInput) {
          this.filterModels(initialSearchInput);
        } else {
          this.updateList();
        }
      })
      .catch((error: unknown) => {
        this.failModelLoad(error);
        this.updateList();
      })
      .then(() => {
        // Request re-render after models are loaded (or the failure was recorded)
        this.tui.requestRender();
      });
  }

  /**
   * Refreshes the registry and rebuilds the "all" and "scoped" lists.
   *
   * A registry load error is recorded in `errorMessage` but does not stop the
   * load. If fetching the available models throws, all lists are emptied and
   * the thrown message is recorded instead.
   */
  private async loadModels(): Promise<void> {
    let models: ModelItem[];

    // Refresh to pick up any changes to models.json
    this.modelRegistry.refresh();

    // Check for models.json errors
    const loadError = this.modelRegistry.getError();
    if (loadError) {
      this.errorMessage = loadError;
    }

    // Load available models (built-in models still work even if models.json failed)
    try {
      const availableModels = this.modelRegistry.getAvailable();
      models = availableModels.map((model: Model<any>) => ({
        provider: model.provider,
        id: model.id,
        model,
      }));
    } catch (error) {
      this.failModelLoad(error);
      return;
    }

    this.allModels = this.sortModels(models);
    this.scopedModelItems = this.sortModels(
      this.scopedModels.map((scoped) => ({
        provider: scoped.model.provider,
        id: scoped.model.id,
        model: scoped.model,
      })),
    );
    this.activeModels = this.scope === "scoped" ? this.scopedModelItems : this.allModels;
    this.filteredModels = this.activeModels;
    this.selectedIndex = Math.min(this.selectedIndex, Math.max(0, this.filteredModels.length - 1));
  }

  /** Empties every model list and records `error` as the message to display. */
  private failModelLoad(error: unknown): void {
    this.allModels = [];
    this.scopedModelItems = [];
    this.activeModels = [];
    this.filteredModels = [];
    this.errorMessage = error instanceof Error ? error.message : String(error);
  }

  /**
   * Returns a sorted copy of `models`; the input array is not mutated.
   *
   * Order: the current model first, then display name descending, then
   * provider ascending as the tie-breaker.
   */
  private sortModels(models: ModelItem[]): ModelItem[] {
    const sorted = [...models];
    // Sort: current model first, then by name descending (newest first), then by provider
    sorted.sort((a, b) => {
      const aIsCurrent = modelsAreEqual(this.currentModel, a.model);
      const bIsCurrent = modelsAreEqual(this.currentModel, b.model);
      if (aIsCurrent && !bIsCurrent) return -1;
      if (!aIsCurrent && bIsCurrent) return 1;
      // Group by model name (display name), newest/largest first
      const nameCmp = b.model.name.localeCompare(a.model.name);
      if (nameCmp !== 0) return nameCmp;
      return a.provider.localeCompare(b.provider);
    });
    return sorted;
  }

  /** Builds the "Scope: all | scoped" line with the active scope in the accent color. */
  private getScopeText(): string {
    const allText = this.scope === "all" ? theme.fg("accent", "all") : theme.fg("muted", "all");
    const scopedText = this.scope === "scoped" ? theme.fg("accent", "scoped") : theme.fg("muted", "scoped");
    return `${theme.fg("muted", "Scope: ")}${allText}${theme.fg("muted", " | ")}${scopedText}`;
  }

  /** Builds the key hint line that advertises Tab for switching scope. */
  private getScopeHintText(): string {
    return keyHint("tab", "scope") + theme.fg("muted", " (all/scoped)");
  }

  /**
   * Switches the visible list to `scope`, moves the selection to the top,
   * re-applies the current search query and refreshes the scope line.
   * Does nothing when `scope` is already active.
   */
  private setScope(scope: ModelScope): void {
    if (this.scope === scope) return;
    this.scope = scope;
    this.activeModels = this.scope === "scoped" ? this.scopedModelItems : this.allModels;
    this.selectedIndex = 0;
    this.filterModels(this.searchInput.getValue());
    if (this.scopeText) {
      this.scopeText.setText(this.getScopeText());
    }
  }

  /**
   * Applies `query` to the active list and re-renders.
   *
   * An empty query shows the whole active list. The selection index is kept
   * and clamped to the new list length.
   */
  private filterModels(query: string): void {
    this.filteredModels = query
      ? fuzzyFilter(this.activeModels, query, ({ id, provider }) => `${id} ${provider}`)
      : this.activeModels;
    this.selectedIndex = Math.min(this.selectedIndex, Math.max(0, this.filteredModels.length - 1));
    this.updateList();
  }

  /**
   * Rebuilds the list container: the visible window of rows, an optional
   * position indicator, and a footer (error text, "no results", or details
   * of the highlighted model).
   */
  private updateList(): void {
    this.listContainer.clear();

    // Keep the selection roughly centered while never scrolling past either end.
    const maxVisible = 10;
    const startIndex = Math.max(
      0,
      Math.min(this.selectedIndex - Math.floor(maxVisible / 2), this.filteredModels.length - maxVisible),
    );
    const endIndex = Math.min(startIndex + maxVisible, this.filteredModels.length);

    // Show visible slice of filtered models
    for (let i = startIndex; i < endIndex; i++) {
      const item = this.filteredModels[i];
      if (!item) continue;
      const isSelected = i === this.selectedIndex;
      const isCurrent = modelsAreEqual(this.currentModel, item.model);
      const ctx = formatTokenCount(item.model.contextWindow);
      const ctxBadge = theme.fg("muted", `${ctx}`);
      const providerBadge = theme.fg("muted", `[${item.provider}]`);
      const checkmark = isCurrent ? theme.fg("success", " ✓") : "";
      let line = "";
      if (isSelected) {
        const prefix = theme.fg("accent", "→ ");
        line = `${prefix}${theme.fg("accent", item.id)} ${ctxBadge} ${providerBadge}${checkmark}`;
      } else {
        line = ` ${item.id} ${ctxBadge} ${providerBadge}${checkmark}`;
      }
      this.listContainer.addChild(new Text(line, 0, 0));
    }

    // Add scroll indicator if needed
    if (startIndex > 0 || endIndex < this.filteredModels.length) {
      const scrollInfo = theme.fg("muted", ` (${this.selectedIndex + 1}/${this.filteredModels.length})`);
      this.listContainer.addChild(new Text(scrollInfo, 0, 0));
    }

    // Show error message or "no results" if empty
    if (this.errorMessage) {
      // Show error in red
      const errorLines = this.errorMessage.split("\n");
      for (const line of errorLines) {
        this.listContainer.addChild(new Text(theme.fg("error", line), 0, 0));
      }
    } else if (this.filteredModels.length === 0) {
      this.listContainer.addChild(new Text(theme.fg("muted", " No matching models"), 0, 0));
    } else {
      const selected = this.filteredModels[this.selectedIndex];
      if (selected) {
        const m = selected.model;
        const details = [
          m.name,
          `ctx: ${formatTokenCount(m.contextWindow)}`,
          `out: ${formatTokenCount(m.maxTokens)}`,
          m.reasoning ? "thinking" : "",
          m.input.includes("image") ? "vision" : "",
        ].filter(Boolean).join(" · ");
        this.listContainer.addChild(new Spacer(1));
        this.listContainer.addChild(new Text(theme.fg("muted", ` ${details}`), 0, 0));
      }
    }
  }

  /**
   * Routes a key press: Tab toggles scope, up/down move the selection with
   * wrap-around, confirm selects the highlighted model, cancel invokes the
   * cancel callback, and everything else is forwarded to the search input
   * followed by a re-filter.
   */
  handleInput(keyData: string): void {
    const kb = getEditorKeybindings();

    if (kb.matches(keyData, "tab")) {
      // Tab is consumed even when there is nothing to toggle.
      if (this.scopedModelItems.length > 0) {
        const nextScope: ModelScope = this.scope === "all" ? "scoped" : "all";
        this.setScope(nextScope);
        if (this.scopeHintText) {
          this.scopeHintText.setText(this.getScopeHintText());
        }
      }
      return;
    }

    // Up arrow - wrap to bottom when at top
    if (kb.matches(keyData, "selectUp")) {
      if (this.filteredModels.length === 0) return;
      this.selectedIndex = this.selectedIndex === 0 ? this.filteredModels.length - 1 : this.selectedIndex - 1;
      this.updateList();
    }
    // Down arrow - wrap to top when at bottom
    else if (kb.matches(keyData, "selectDown")) {
      if (this.filteredModels.length === 0) return;
      this.selectedIndex = this.selectedIndex === this.filteredModels.length - 1 ? 0 : this.selectedIndex + 1;
      this.updateList();
    }
    // Enter
    else if (kb.matches(keyData, "selectConfirm")) {
      const selectedModel = this.filteredModels[this.selectedIndex];
      if (selectedModel) {
        this.handleSelect(selectedModel.model);
      }
    }
    // Escape or Ctrl+C
    else if (kb.matches(keyData, "selectCancel")) {
      this.onCancelCallback();
    }
    // Pass everything else to search input
    else {
      this.searchInput.handleInput(keyData);
      this.filterModels(this.searchInput.getValue());
    }
  }

  /** Saves `model` as the new default, then notifies the select callback. */
  private handleSelect(model: Model<any>): void {
    // Save as new default
    this.settingsManager.setDefaultModelAndProvider(model.provider, model.id);
    this.onSelectCallback(model);
  }

  /** Returns the search input child. */
  getSearchInput(): Input {
    return this.searchInput;
  }
}