From 9ed812ed54b8e70f871232473314d8b958126147 Mon Sep 17 00:00:00 2001 From: Flux Labs Date: Mon, 16 Mar 2026 07:23:18 -0500 Subject: [PATCH] feat: dynamic model discovery & provider management UX (#581) --- .plans/dynamic-model-discovery.md | 27 ++ .plans/preferences-wizard-completeness.md | 49 ++++ packages/pi-coding-agent/src/cli/args.ts | 21 ++ .../pi-coding-agent/src/cli/list-models.ts | 87 +++++-- .../src/core/discovery-cache.test.ts | 170 +++++++++++++ .../src/core/discovery-cache.ts | 97 ++++++++ .../src/core/model-discovery.test.ts | 125 ++++++++++ .../src/core/model-discovery.ts | 231 ++++++++++++++++++ .../src/core/model-registry-discovery.test.ts | 135 ++++++++++ .../src/core/model-registry.ts | 107 ++++++++ .../src/core/models-json-writer.test.ts | 145 +++++++++++ .../src/core/models-json-writer.ts | 188 ++++++++++++++ .../src/core/settings-manager.ts | 21 ++ .../src/core/slash-commands.ts | 1 + packages/pi-coding-agent/src/index.ts | 5 + packages/pi-coding-agent/src/main.ts | 21 +- .../src/modes/interactive/components/index.ts | 1 + .../interactive/components/model-selector.ts | 2 +- .../components/provider-manager.ts | 163 ++++++++++++ .../src/modes/interactive/interactive-mode.ts | 37 +++ src/resources/extensions/gsd/commands.ts | 176 ++++++++++++- .../gsd/docs/preferences-reference.md | 94 +++++++ src/resources/extensions/gsd/preferences.ts | 60 ++++- .../extensions/gsd/templates/preferences.md | 14 ++ .../tests/preferences-wizard-fields.test.ts | 168 +++++++++++++ 25 files changed, 2122 insertions(+), 23 deletions(-) create mode 100644 .plans/dynamic-model-discovery.md create mode 100644 .plans/preferences-wizard-completeness.md create mode 100644 packages/pi-coding-agent/src/core/discovery-cache.test.ts create mode 100644 packages/pi-coding-agent/src/core/discovery-cache.ts create mode 100644 packages/pi-coding-agent/src/core/model-discovery.test.ts create mode 100644 packages/pi-coding-agent/src/core/model-discovery.ts create mode 100644 
packages/pi-coding-agent/src/core/model-registry-discovery.test.ts create mode 100644 packages/pi-coding-agent/src/core/models-json-writer.test.ts create mode 100644 packages/pi-coding-agent/src/core/models-json-writer.ts create mode 100644 packages/pi-coding-agent/src/modes/interactive/components/provider-manager.ts create mode 100644 src/resources/extensions/gsd/tests/preferences-wizard-fields.test.ts diff --git a/.plans/dynamic-model-discovery.md b/.plans/dynamic-model-discovery.md new file mode 100644 index 000000000..00267f353 --- /dev/null +++ b/.plans/dynamic-model-discovery.md @@ -0,0 +1,27 @@ +# Dynamic Model Discovery + +## Overview +Runtime model discovery from provider APIs with caching, TUI management, and CLI flags. + +## Components +1. **model-discovery.ts** — Provider adapters (OpenAI, Ollama, OpenRouter, Google) + static adapters +2. **discovery-cache.ts** — Disk cache at `{agentDir}/discovery-cache.json` with per-provider TTLs +3. **models-json-writer.ts** — Safe read-modify-write for `models.json` with file locking +4. **provider-manager.ts** — TUI component for provider management (`/provider` command) +5. **model-registry.ts** — Extended with `discoverModels()`, `getAllWithDiscovered()`, cache integration +6. **settings-manager.ts** — `modelDiscovery` settings (enabled, providers, ttlMinutes, autoRefreshOnModelSelect) +7. **args.ts** — `--discover`, `--add-provider`, `--base-url`, `--discover-models` CLI flags +8. **list-models.ts** — Rewritten with `[discovered]` badge support +9. **main.ts** — CLI handlers for new flags +10. **interactive-mode.ts** — `/provider` command handler +11. 
**preferences.ts** — `updatePreferencesModels()` and `validateModelId()` helpers + +## TTL Strategy +- Ollama: 5 min (local, models change often) +- OpenAI / Google / OpenRouter: 1 hour +- Default: 24 hours + +## Merge Rules +- Discovered models never override existing built-in or custom models +- Discovered models are appended to the registry with `[discovered]` badge +- Background discovery is opt-in via `modelDiscovery.enabled` setting diff --git a/.plans/preferences-wizard-completeness.md b/.plans/preferences-wizard-completeness.md new file mode 100644 index 000000000..5709d7f21 --- /dev/null +++ b/.plans/preferences-wizard-completeness.md @@ -0,0 +1,49 @@ +# Preferences Wizard Completeness + +## Problem +The `/gsd prefs wizard` currently only configures 6 of 18+ preference fields. Users must hand-edit YAML for the rest. + +## Current Wizard Coverage +1. Models (per phase) ✓ +2. Auto-supervisor timeouts ✓ +3. Git main_branch ✓ +4. Skill discovery mode ✓ +5. Unique milestone IDs ✓ + +## Missing Fields to Add + +### Group 1: Git Settings (expand existing section) +- `auto_push` (boolean) — auto-push commits ✓ +- `push_branches` (boolean) — push milestone branches ✓ +- `remote` (string) — git remote name ✓ +- `snapshots` (boolean) — WIP snapshot commits ✓ +- `pre_merge_check` (boolean | "auto") — pre-merge validation ✓ +- `commit_type` (select) — conventional commit prefix ✓ +- `merge_strategy` (select) — squash vs merge ✓ +- `isolation` (select) — worktree vs branch ✓ + +### Group 2: Budget & Cost Control ✓ +- `budget_ceiling` (number) — dollar limit +- `budget_enforcement` (select: warn/pause/halt) +- `context_pause_threshold` (number 0-100) + +### Group 3: Notifications ✓ +- `notifications.enabled` (boolean) +- `notifications.on_complete` (boolean) +- `notifications.on_error` (boolean) +- `notifications.on_budget` (boolean) +- `notifications.on_milestone` (boolean) +- `notifications.on_attention` (boolean) + +### Group 4: Behavior Toggles ✓ +- `uat_dispatch` 
(boolean) + +### Group 5: Update Serialization Order ✓ +- Added missing keys to `orderedKeys` in `serializePreferencesToFrontmatter()` + +### Group 6: Update Template & Docs ✓ +- Updated `templates/preferences.md` with new fields +- Updated `docs/preferences-reference.md` with budget, notifications, git, hooks + +### Group 7: Tests ✓ +- Added `preferences-wizard-fields.test.ts` covering all new fields diff --git a/packages/pi-coding-agent/src/cli/args.ts b/packages/pi-coding-agent/src/cli/args.ts index 40306049c..101e67da5 100644 --- a/packages/pi-coding-agent/src/cli/args.ts +++ b/packages/pi-coding-agent/src/cli/args.ts @@ -38,6 +38,11 @@ export interface Args { themes?: string[]; noThemes?: boolean; listModels?: string | true; + discover?: boolean; + addProvider?: string; + addProviderBaseUrl?: string; + addProviderApiKey?: string; + discoverModels?: string | true; offline?: boolean; verbose?: boolean; messages: string[]; @@ -150,6 +155,18 @@ export function parseArgs(args: string[], extensionFlags?: Map Export session file to HTML and exit --list-models [search] List available models (with optional fuzzy search) + --discover Include discovered models in --list-models output + --discover-models [provider] Discover models from provider APIs (all or specific) + --add-provider Add a provider to models.json (use with --base-url, --api-key) + --base-url Base URL for --add-provider --verbose Force verbose startup (overrides quietStartup setting) --offline Disable startup network operations (same as PI_OFFLINE=1) --help, -h Show this help diff --git a/packages/pi-coding-agent/src/cli/list-models.ts b/packages/pi-coding-agent/src/cli/list-models.ts index 72c276cda..b611c271d 100644 --- a/packages/pi-coding-agent/src/cli/list-models.ts +++ b/packages/pi-coding-agent/src/cli/list-models.ts @@ -1,11 +1,18 @@ /** - * List available models with optional fuzzy search + * List available models with optional fuzzy search and discovery support */ import type { Api, Model } from 
"@gsd/pi-ai"; import { fuzzyFilter } from "@gsd/pi-tui"; import type { ModelRegistry } from "../core/model-registry.js"; +export interface ListModelsOptions { + /** Include discovered models in output */ + discover?: boolean; + /** Search pattern for fuzzy filtering */ + searchPattern?: string; +} + /** * Format a number as human-readable (e.g., 200000 -> "200K", 1000000 -> "1M") */ @@ -22,10 +29,48 @@ function formatTokenCount(count: number): string { } /** - * List available models, optionally filtered by search pattern + * Discover models from provider APIs and print results. */ -export async function listModels(modelRegistry: ModelRegistry, searchPattern?: string): Promise { - const models = modelRegistry.getAvailable(); +export async function discoverAndPrintModels( + modelRegistry: ModelRegistry, + provider?: string, +): Promise { + const providers = provider ? [provider] : undefined; + + console.log("Discovering models..."); + const results = await modelRegistry.discoverModels(providers); + + for (const result of results) { + if (result.error) { + console.log(` ${result.provider}: error - ${result.error}`); + } else { + console.log(` ${result.provider}: ${result.models.length} models found`); + } + } +} + +/** + * List available models, optionally filtered by search pattern. + * Accepts either a string (backward compat) or ListModelsOptions. + */ +export async function listModels( + modelRegistry: ModelRegistry, + optionsOrSearch?: string | ListModelsOptions, +): Promise { + const options: ListModelsOptions = + typeof optionsOrSearch === "string" + ? { searchPattern: optionsOrSearch } + : optionsOrSearch ?? {}; + + // If discover flag is set, run discovery first + if (options.discover) { + await modelRegistry.discoverModels(); + } + + // Get models — include discovered if discovery was run + const models = options.discover + ? modelRegistry.getAllWithDiscovered() + : modelRegistry.getAvailable(); if (models.length === 0) { console.log("No models available. 
Set API keys in environment variables."); @@ -34,12 +79,12 @@ export async function listModels(modelRegistry: ModelRegistry, searchPattern?: s // Apply fuzzy filter if search pattern provided let filteredModels: Model[] = models; - if (searchPattern) { - filteredModels = fuzzyFilter(models, searchPattern, (m) => `${m.provider} ${m.id}`); + if (options.searchPattern) { + filteredModels = fuzzyFilter(models, options.searchPattern, (m) => `${m.provider} ${m.id}`); } if (filteredModels.length === 0) { - console.log(`No models matching "${searchPattern}"`); + console.log(`No models matching "${options.searchPattern}"`); return; } @@ -53,15 +98,19 @@ export async function listModels(modelRegistry: ModelRegistry, searchPattern?: s }); // Calculate column widths - const rows = filteredModels.map((m) => ({ - provider: m.provider, - model: m.id, - name: m.name, - context: formatTokenCount(m.contextWindow), - maxOut: formatTokenCount(m.maxTokens), - thinking: m.reasoning ? "yes" : "no", - images: m.input.includes("image") ? "yes" : "no", - })); + const rows = filteredModels.map((m) => { + const isDiscovered = options.discover && modelRegistry.isDiscovered(m); + return { + provider: m.provider, + model: m.id, + name: m.name, + context: formatTokenCount(m.contextWindow), + maxOut: formatTokenCount(m.maxTokens), + thinking: m.reasoning ? "yes" : "no", + images: m.input.includes("image") ? "yes" : "no", + badge: isDiscovered ? 
"[discovered]" : "", + }; + }); const headers = { provider: "provider", @@ -71,6 +120,7 @@ export async function listModels(modelRegistry: ModelRegistry, searchPattern?: s maxOut: "max-out", thinking: "thinking", images: "images", + badge: "", }; const widths = { @@ -105,7 +155,10 @@ export async function listModels(modelRegistry: ModelRegistry, searchPattern?: s row.maxOut.padEnd(widths.maxOut), row.thinking.padEnd(widths.thinking), row.images.padEnd(widths.images), - ].join(" "); + row.badge, + ] + .join(" ") + .trimEnd(); console.log(line); } } diff --git a/packages/pi-coding-agent/src/core/discovery-cache.test.ts b/packages/pi-coding-agent/src/core/discovery-cache.test.ts new file mode 100644 index 000000000..4c5e8a245 --- /dev/null +++ b/packages/pi-coding-agent/src/core/discovery-cache.test.ts @@ -0,0 +1,170 @@ +import assert from "node:assert/strict"; +import { existsSync, mkdirSync, rmSync, writeFileSync } from "node:fs"; +import { tmpdir } from "node:os"; +import { join } from "node:path"; +import { afterEach, beforeEach, describe, it } from "node:test"; +import { ModelDiscoveryCache } from "./discovery-cache.js"; + +let testDir: string; +let cachePath: string; + +beforeEach(() => { + testDir = join(tmpdir(), `discovery-cache-test-${Date.now()}-${Math.random().toString(36).slice(2)}`); + mkdirSync(testDir, { recursive: true }); + cachePath = join(testDir, "discovery-cache.json"); +}); + +afterEach(() => { + try { + rmSync(testDir, { recursive: true, force: true }); + } catch { + // Cleanup best-effort + } +}); + +// ─── basic operations ──────────────────────────────────────────────────────── + +describe("ModelDiscoveryCache — basic operations", () => { + it("starts with no entries", () => { + const cache = new ModelDiscoveryCache(cachePath); + assert.equal(cache.get("openai"), undefined); + }); + + it("stores and retrieves models", () => { + const cache = new ModelDiscoveryCache(cachePath); + const models = [{ id: "gpt-4o", name: "GPT-4o" }]; + 
cache.set("openai", models); + + const entry = cache.get("openai"); + assert.ok(entry); + assert.deepEqual(entry.models, models); + assert.ok(entry.fetchedAt > 0); + assert.ok(entry.ttlMs > 0); + }); + + it("persists to disk and reloads", () => { + const cache1 = new ModelDiscoveryCache(cachePath); + cache1.set("openai", [{ id: "gpt-4o" }]); + + const cache2 = new ModelDiscoveryCache(cachePath); + const entry = cache2.get("openai"); + assert.ok(entry); + assert.equal(entry.models[0].id, "gpt-4o"); + }); + + it("clear removes a specific provider", () => { + const cache = new ModelDiscoveryCache(cachePath); + cache.set("openai", [{ id: "gpt-4o" }]); + cache.set("google", [{ id: "gemini-pro" }]); + + cache.clear("openai"); + assert.equal(cache.get("openai"), undefined); + assert.ok(cache.get("google")); + }); + + it("clear without provider removes all entries", () => { + const cache = new ModelDiscoveryCache(cachePath); + cache.set("openai", [{ id: "gpt-4o" }]); + cache.set("google", [{ id: "gemini-pro" }]); + + cache.clear(); + assert.equal(cache.get("openai"), undefined); + assert.equal(cache.get("google"), undefined); + }); +}); + +// ─── staleness ─────────────────────────────────────────────────────────────── + +describe("ModelDiscoveryCache — staleness", () => { + it("newly set entries are not stale", () => { + const cache = new ModelDiscoveryCache(cachePath); + cache.set("openai", [{ id: "gpt-4o" }]); + assert.equal(cache.isStale("openai"), false); + }); + + it("missing providers are stale", () => { + const cache = new ModelDiscoveryCache(cachePath); + assert.equal(cache.isStale("unknown"), true); + }); + + it("entries with expired TTL are stale", () => { + const cache = new ModelDiscoveryCache(cachePath); + cache.set("openai", [{ id: "gpt-4o" }], 1); // 1ms TTL + + // Wait for TTL to expire + const start = Date.now(); + while (Date.now() - start < 5) { + // busy wait + } + + assert.equal(cache.isStale("openai"), true); + }); +}); + +// ─── getAll 
────────────────────────────────────────────────────────────────── + +describe("ModelDiscoveryCache — getAll", () => { + it("returns non-stale entries by default", () => { + const cache = new ModelDiscoveryCache(cachePath); + cache.set("openai", [{ id: "gpt-4o" }]); + cache.set("stale", [{ id: "old" }], 1); + + // Wait for stale TTL + const start = Date.now(); + while (Date.now() - start < 5) { + // busy wait + } + + const all = cache.getAll(); + assert.ok(all.has("openai")); + assert.ok(!all.has("stale")); + }); + + it("returns all entries when includeStale is true", () => { + const cache = new ModelDiscoveryCache(cachePath); + cache.set("openai", [{ id: "gpt-4o" }]); + cache.set("stale", [{ id: "old" }], 1); + + // Wait for stale TTL + const start = Date.now(); + while (Date.now() - start < 5) { + // busy wait + } + + const all = cache.getAll(true); + assert.ok(all.has("openai")); + assert.ok(all.has("stale")); + }); +}); + +// ─── edge cases ────────────────────────────────────────────────────────────── + +describe("ModelDiscoveryCache — edge cases", () => { + it("handles corrupted cache file gracefully", () => { + writeFileSync(cachePath, "not valid json", "utf-8"); + const cache = new ModelDiscoveryCache(cachePath); + assert.equal(cache.get("openai"), undefined); + }); + + it("handles wrong version gracefully", () => { + writeFileSync(cachePath, JSON.stringify({ version: 99, entries: {} }), "utf-8"); + const cache = new ModelDiscoveryCache(cachePath); + assert.equal(cache.get("openai"), undefined); + }); + + it("handles missing cache file", () => { + const cache = new ModelDiscoveryCache(join(testDir, "nonexistent", "cache.json")); + assert.equal(cache.get("openai"), undefined); + }); + + it("overwrites existing entry for same provider", () => { + const cache = new ModelDiscoveryCache(cachePath); + cache.set("openai", [{ id: "gpt-4o" }]); + cache.set("openai", [{ id: "gpt-4o-mini" }]); + + const entry = cache.get("openai"); + assert.ok(entry); + 
assert.equal(entry.models.length, 1); + assert.equal(entry.models[0].id, "gpt-4o-mini"); + }); +}); diff --git a/packages/pi-coding-agent/src/core/discovery-cache.ts b/packages/pi-coding-agent/src/core/discovery-cache.ts new file mode 100644 index 000000000..a75633c2f --- /dev/null +++ b/packages/pi-coding-agent/src/core/discovery-cache.ts @@ -0,0 +1,97 @@ +/** + * Disk-based cache for discovered models. + * Stores results at {agentDir}/discovery-cache.json with per-provider TTLs. + */ + +import { existsSync, mkdirSync, readFileSync, writeFileSync } from "fs"; +import { dirname, join } from "path"; +import { getAgentDir } from "../config.js"; +import { type DiscoveredModel, getDefaultTTL } from "./model-discovery.js"; + +export interface DiscoveryCacheEntry { + models: DiscoveredModel[]; + fetchedAt: number; + ttlMs: number; +} + +export interface DiscoveryCacheData { + version: 1; + entries: Record; +} + +export class ModelDiscoveryCache { + private data: DiscoveryCacheData; + private cachePath: string; + + constructor(cachePath?: string) { + this.cachePath = cachePath ?? join(getAgentDir(), "discovery-cache.json"); + this.data = { version: 1, entries: {} }; + this.load(); + } + + get(provider: string): DiscoveryCacheEntry | undefined { + const entry = this.data.entries[provider]; + return entry; + } + + set(provider: string, models: DiscoveredModel[], ttlMs?: number): void { + this.data.entries[provider] = { + models, + fetchedAt: Date.now(), + ttlMs: ttlMs ?? 
getDefaultTTL(provider), + }; + this.save(); + } + + isStale(provider: string): boolean { + const entry = this.data.entries[provider]; + if (!entry) return true; + return Date.now() - entry.fetchedAt > entry.ttlMs; + } + + clear(provider?: string): void { + if (provider) { + delete this.data.entries[provider]; + } else { + this.data.entries = {}; + } + this.save(); + } + + getAll(includeStale = false): Map { + const result = new Map(); + for (const [provider, entry] of Object.entries(this.data.entries)) { + if (includeStale || !this.isStale(provider)) { + result.set(provider, entry); + } + } + return result; + } + + load(): void { + try { + if (existsSync(this.cachePath)) { + const content = readFileSync(this.cachePath, "utf-8"); + const parsed = JSON.parse(content) as DiscoveryCacheData; + if (parsed.version === 1 && parsed.entries) { + this.data = parsed; + } + } + } catch { + // Corrupted or unreadable cache — start fresh + this.data = { version: 1, entries: {} }; + } + } + + save(): void { + try { + const dir = dirname(this.cachePath); + if (!existsSync(dir)) { + mkdirSync(dir, { recursive: true }); + } + writeFileSync(this.cachePath, JSON.stringify(this.data, null, 2), "utf-8"); + } catch { + // Silently ignore write failures (read-only FS, permissions, etc.) 
+ } + } +} diff --git a/packages/pi-coding-agent/src/core/model-discovery.test.ts b/packages/pi-coding-agent/src/core/model-discovery.test.ts new file mode 100644 index 000000000..43a35a7a3 --- /dev/null +++ b/packages/pi-coding-agent/src/core/model-discovery.test.ts @@ -0,0 +1,125 @@ +import assert from "node:assert/strict"; +import { describe, it } from "node:test"; +import { + DISCOVERY_TTLS, + getDefaultTTL, + getDiscoverableProviders, + getDiscoveryAdapter, +} from "./model-discovery.js"; + +// ─── getDiscoveryAdapter ───────────────────────────────────────────────────── + +describe("getDiscoveryAdapter", () => { + it("returns an adapter for openai", () => { + const adapter = getDiscoveryAdapter("openai"); + assert.equal(adapter.provider, "openai"); + assert.equal(adapter.supportsDiscovery, true); + }); + + it("returns an adapter for ollama", () => { + const adapter = getDiscoveryAdapter("ollama"); + assert.equal(adapter.provider, "ollama"); + assert.equal(adapter.supportsDiscovery, true); + }); + + it("returns an adapter for openrouter", () => { + const adapter = getDiscoveryAdapter("openrouter"); + assert.equal(adapter.provider, "openrouter"); + assert.equal(adapter.supportsDiscovery, true); + }); + + it("returns an adapter for google", () => { + const adapter = getDiscoveryAdapter("google"); + assert.equal(adapter.provider, "google"); + assert.equal(adapter.supportsDiscovery, true); + }); + + it("returns a static adapter for anthropic", () => { + const adapter = getDiscoveryAdapter("anthropic"); + assert.equal(adapter.provider, "anthropic"); + assert.equal(adapter.supportsDiscovery, false); + }); + + it("returns a static adapter for bedrock", () => { + const adapter = getDiscoveryAdapter("bedrock"); + assert.equal(adapter.provider, "bedrock"); + assert.equal(adapter.supportsDiscovery, false); + }); + + it("returns a static adapter for unknown providers", () => { + const adapter = getDiscoveryAdapter("unknown-provider"); + assert.equal(adapter.provider, 
"unknown-provider"); + assert.equal(adapter.supportsDiscovery, false); + }); + + it("static adapter fetchModels returns empty array", async () => { + const adapter = getDiscoveryAdapter("anthropic"); + const models = await adapter.fetchModels("key"); + assert.deepEqual(models, []); + }); +}); + +// ─── getDiscoverableProviders ──────────────────────────────────────────────── + +describe("getDiscoverableProviders", () => { + it("returns only providers that support discovery", () => { + const providers = getDiscoverableProviders(); + assert.ok(providers.includes("openai")); + assert.ok(providers.includes("ollama")); + assert.ok(providers.includes("openrouter")); + assert.ok(providers.includes("google")); + assert.ok(!providers.includes("anthropic")); + assert.ok(!providers.includes("bedrock")); + }); + + it("returns an array of strings", () => { + const providers = getDiscoverableProviders(); + assert.ok(Array.isArray(providers)); + for (const p of providers) { + assert.equal(typeof p, "string"); + } + }); +}); + +// ─── getDefaultTTL ─────────────────────────────────────────────────────────── + +describe("getDefaultTTL", () => { + it("returns 5 minutes for ollama", () => { + assert.equal(getDefaultTTL("ollama"), 5 * 60 * 1000); + }); + + it("returns 1 hour for openai", () => { + assert.equal(getDefaultTTL("openai"), 60 * 60 * 1000); + }); + + it("returns 1 hour for google", () => { + assert.equal(getDefaultTTL("google"), 60 * 60 * 1000); + }); + + it("returns 1 hour for openrouter", () => { + assert.equal(getDefaultTTL("openrouter"), 60 * 60 * 1000); + }); + + it("returns 24 hours for unknown providers", () => { + assert.equal(getDefaultTTL("some-custom"), 24 * 60 * 60 * 1000); + }); +}); + +// ─── DISCOVERY_TTLS ────────────────────────────────────────────────────────── + +describe("DISCOVERY_TTLS", () => { + it("has expected keys", () => { + assert.ok("ollama" in DISCOVERY_TTLS); + assert.ok("openai" in DISCOVERY_TTLS); + assert.ok("google" in DISCOVERY_TTLS); + 
assert.ok("openrouter" in DISCOVERY_TTLS); + assert.ok("default" in DISCOVERY_TTLS); + }); + + it("all values are positive numbers", () => { + for (const [, value] of Object.entries(DISCOVERY_TTLS)) { + assert.equal(typeof value, "number"); + assert.ok(value > 0); + } + }); +}); diff --git a/packages/pi-coding-agent/src/core/model-discovery.ts b/packages/pi-coding-agent/src/core/model-discovery.ts new file mode 100644 index 000000000..7e8ce3372 --- /dev/null +++ b/packages/pi-coding-agent/src/core/model-discovery.ts @@ -0,0 +1,231 @@ +/** + * Provider discovery adapters for runtime model enumeration. + * Each adapter implements ProviderDiscoveryAdapter to fetch models from provider APIs. + */ + +export interface DiscoveredModel { + id: string; + name?: string; + contextWindow?: number; + maxTokens?: number; + reasoning?: boolean; + input?: ("text" | "image")[]; + cost?: { input: number; output: number; cacheRead: number; cacheWrite: number }; +} + +export interface DiscoveryResult { + provider: string; + models: DiscoveredModel[]; + fetchedAt: number; + error?: string; +} + +export interface ProviderDiscoveryAdapter { + provider: string; + supportsDiscovery: boolean; + fetchModels(apiKey: string, baseUrl?: string): Promise; +} + +/** Per-provider TTLs in milliseconds */ +export const DISCOVERY_TTLS: Record = { + ollama: 5 * 60 * 1000, // 5 minutes (local, models change often) + openai: 60 * 60 * 1000, // 1 hour + google: 60 * 60 * 1000, // 1 hour + openrouter: 60 * 60 * 1000, // 1 hour + default: 24 * 60 * 60 * 1000, // 24 hours +}; + +export function getDefaultTTL(provider: string): number { + return DISCOVERY_TTLS[provider] ?? 
DISCOVERY_TTLS.default; +} + +async function fetchWithTimeout(url: string, options: RequestInit = {}, timeoutMs = 5000): Promise { + const controller = new AbortController(); + const timeout = setTimeout(() => controller.abort(), timeoutMs); + try { + return await fetch(url, { ...options, signal: controller.signal }); + } finally { + clearTimeout(timeout); + } +} + +// ─── OpenAI Adapter ────────────────────────────────────────────────────────── + +const OPENAI_EXCLUDED_PREFIXES = ["embedding", "tts", "dall-e", "whisper", "text-embedding", "davinci", "babbage"]; + +class OpenAIDiscoveryAdapter implements ProviderDiscoveryAdapter { + provider = "openai"; + supportsDiscovery = true; + + async fetchModels(apiKey: string, baseUrl?: string): Promise { + const url = `${baseUrl ?? "https://api.openai.com"}/v1/models`; + const response = await fetchWithTimeout(url, { + headers: { Authorization: `Bearer ${apiKey}` }, + }); + + if (!response.ok) { + throw new Error(`OpenAI models API returned ${response.status}: ${response.statusText}`); + } + + const data = (await response.json()) as { data: Array<{ id: string; owned_by?: string }> }; + return data.data + .filter((m) => !OPENAI_EXCLUDED_PREFIXES.some((prefix) => m.id.startsWith(prefix))) + .map((m) => ({ + id: m.id, + name: m.id, + input: ["text" as const, "image" as const], + })); + } +} + +// ─── Ollama Adapter ────────────────────────────────────────────────────────── + +class OllamaDiscoveryAdapter implements ProviderDiscoveryAdapter { + provider = "ollama"; + supportsDiscovery = true; + + async fetchModels(_apiKey: string, baseUrl?: string): Promise { + const url = `${baseUrl ?? 
"http://localhost:11434"}/api/tags`; + const response = await fetchWithTimeout(url); + + if (!response.ok) { + throw new Error(`Ollama tags API returned ${response.status}: ${response.statusText}`); + } + + const data = (await response.json()) as { + models: Array<{ name: string; size: number; details?: { parameter_size?: string } }>; + }; + + return (data.models ?? []).map((m) => ({ + id: m.name, + name: m.name, + input: ["text" as const], + })); + } +} + +// ─── OpenRouter Adapter ────────────────────────────────────────────────────── + +class OpenRouterDiscoveryAdapter implements ProviderDiscoveryAdapter { + provider = "openrouter"; + supportsDiscovery = true; + + async fetchModels(apiKey: string, baseUrl?: string): Promise { + const url = `${baseUrl ?? "https://openrouter.ai"}/api/v1/models`; + const response = await fetchWithTimeout(url, { + headers: { Authorization: `Bearer ${apiKey}` }, + }); + + if (!response.ok) { + throw new Error(`OpenRouter models API returned ${response.status}: ${response.statusText}`); + } + + const data = (await response.json()) as { + data: Array<{ + id: string; + name: string; + context_length?: number; + top_provider?: { max_completion_tokens?: number }; + pricing?: { prompt: string; completion: string }; + }>; + }; + + return (data.data ?? []).map((m) => { + const cost = + m.pricing?.prompt !== undefined && m.pricing?.completion !== undefined + ? 
{ + input: parseFloat(m.pricing.prompt) * 1_000_000, + output: parseFloat(m.pricing.completion) * 1_000_000, + cacheRead: 0, + cacheWrite: 0, + } + : undefined; + + return { + id: m.id, + name: m.name, + contextWindow: m.context_length, + maxTokens: m.top_provider?.max_completion_tokens, + cost, + input: ["text" as const, "image" as const], + }; + }); + } +} + +// ─── Google/Gemini Adapter ─────────────────────────────────────────────────── + +class GoogleDiscoveryAdapter implements ProviderDiscoveryAdapter { + provider = "google"; + supportsDiscovery = true; + + async fetchModels(apiKey: string, baseUrl?: string): Promise { + const url = `${baseUrl ?? "https://generativelanguage.googleapis.com"}/v1beta/models?key=${apiKey}`; + const response = await fetchWithTimeout(url); + + if (!response.ok) { + throw new Error(`Google models API returned ${response.status}: ${response.statusText}`); + } + + const data = (await response.json()) as { + models: Array<{ + name: string; + displayName: string; + supportedGenerationMethods?: string[]; + inputTokenLimit?: number; + outputTokenLimit?: number; + }>; + }; + + return (data.models ?? 
[]) + .filter((m) => m.supportedGenerationMethods?.includes("generateContent")) + .map((m) => ({ + id: m.name.replace("models/", ""), + name: m.displayName, + contextWindow: m.inputTokenLimit, + maxTokens: m.outputTokenLimit, + input: ["text" as const, "image" as const], + })); + } +} + +// ─── Static Adapter (no discovery) ─────────────────────────────────────────── + +class StaticDiscoveryAdapter implements ProviderDiscoveryAdapter { + provider: string; + supportsDiscovery = false; + + constructor(provider: string) { + this.provider = provider; + } + + async fetchModels(): Promise { + return []; + } +} + +// ─── Registry ──────────────────────────────────────────────────────────────── + +const adapters: Record = { + openai: new OpenAIDiscoveryAdapter(), + ollama: new OllamaDiscoveryAdapter(), + openrouter: new OpenRouterDiscoveryAdapter(), + google: new GoogleDiscoveryAdapter(), + anthropic: new StaticDiscoveryAdapter("anthropic"), + bedrock: new StaticDiscoveryAdapter("bedrock"), + "azure-openai": new StaticDiscoveryAdapter("azure-openai"), + groq: new StaticDiscoveryAdapter("groq"), + cerebras: new StaticDiscoveryAdapter("cerebras"), + xai: new StaticDiscoveryAdapter("xai"), + mistral: new StaticDiscoveryAdapter("mistral"), +}; + +export function getDiscoveryAdapter(provider: string): ProviderDiscoveryAdapter { + return adapters[provider] ?? 
new StaticDiscoveryAdapter(provider); +} + +export function getDiscoverableProviders(): string[] { + return Object.entries(adapters) + .filter(([, adapter]) => adapter.supportsDiscovery) + .map(([name]) => name); +} diff --git a/packages/pi-coding-agent/src/core/model-registry-discovery.test.ts b/packages/pi-coding-agent/src/core/model-registry-discovery.test.ts new file mode 100644 index 000000000..223c5b471 --- /dev/null +++ b/packages/pi-coding-agent/src/core/model-registry-discovery.test.ts @@ -0,0 +1,135 @@ +import assert from "node:assert/strict"; +import { mkdirSync, rmSync, writeFileSync } from "node:fs"; +import { tmpdir } from "node:os"; +import { join } from "node:path"; +import { afterEach, beforeEach, describe, it } from "node:test"; +import { AuthStorage } from "./auth-storage.js"; +import { ModelDiscoveryCache } from "./discovery-cache.js"; +import { getDefaultTTL, getDiscoverableProviders, getDiscoveryAdapter } from "./model-discovery.js"; + +let testDir: string; + +beforeEach(() => { + testDir = join(tmpdir(), `model-registry-discovery-test-${Date.now()}-${Math.random().toString(36).slice(2)}`); + mkdirSync(testDir, { recursive: true }); +}); + +afterEach(() => { + try { + rmSync(testDir, { recursive: true, force: true }); + } catch { + // Cleanup best-effort + } +}); + +// ─── discovery cache integration ───────────────────────────────────────────── + +describe("ModelDiscoveryCache — integration with discovery", () => { + it("cache respects provider-specific TTLs", () => { + const cachePath = join(testDir, "cache.json"); + const cache = new ModelDiscoveryCache(cachePath); + + cache.set("ollama", [{ id: "llama2" }]); + const entry = cache.get("ollama"); + assert.ok(entry); + assert.equal(entry.ttlMs, getDefaultTTL("ollama")); + }); + + it("cache uses custom TTL when provided", () => { + const cachePath = join(testDir, "cache.json"); + const cache = new ModelDiscoveryCache(cachePath); + + cache.set("openai", [{ id: "gpt-4o" }], 999); + const entry = 
cache.get("openai"); + assert.ok(entry); + assert.equal(entry.ttlMs, 999); + }); +}); + +// ─── adapter resolution ───────────────────────────────────────────────────── + +describe("Discovery adapter resolution", () => { + it("all discoverable providers have adapters", () => { + const providers = getDiscoverableProviders(); + for (const provider of providers) { + const adapter = getDiscoveryAdapter(provider); + assert.equal(adapter.supportsDiscovery, true, `${provider} should support discovery`); + } + }); + + it("static adapters return empty model lists", async () => { + const staticProviders = ["anthropic", "bedrock", "azure-openai", "groq", "cerebras"]; + for (const provider of staticProviders) { + const adapter = getDiscoveryAdapter(provider); + assert.equal(adapter.supportsDiscovery, false, `${provider} should not support discovery`); + const models = await adapter.fetchModels("dummy-key"); + assert.deepEqual(models, [], `${provider} should return empty models`); + } + }); +}); + +// ─── AuthStorage hasAuth for discovery ─────────────────────────────────────── + +describe("AuthStorage — hasAuth for discovery providers", () => { + it("returns false for providers without auth", () => { + const storage = AuthStorage.inMemory({}); + assert.equal(storage.hasAuth("openai"), false); + assert.equal(storage.hasAuth("ollama"), false); + }); + + it("returns true for providers with stored keys", () => { + const storage = AuthStorage.inMemory({ + openai: { type: "api_key" as const, key: "sk-test" }, + }); + assert.equal(storage.hasAuth("openai"), true); + assert.equal(storage.hasAuth("ollama"), false); + }); +}); + +// ─── cache persistence across instances ────────────────────────────────────── + +describe("ModelDiscoveryCache — persistence", () => { + it("data survives across cache instances", () => { + const cachePath = join(testDir, "persist.json"); + + const cache1 = new ModelDiscoveryCache(cachePath); + cache1.set("openai", [ + { id: "gpt-4o", name: "GPT-4o", 
contextWindow: 128000 }, + { id: "gpt-4o-mini", name: "GPT-4o Mini" }, + ]); + + const cache2 = new ModelDiscoveryCache(cachePath); + const entry = cache2.get("openai"); + assert.ok(entry); + assert.equal(entry.models.length, 2); + assert.equal(entry.models[0].contextWindow, 128000); + }); + + it("clear persists across instances", () => { + const cachePath = join(testDir, "clear.json"); + + const cache1 = new ModelDiscoveryCache(cachePath); + cache1.set("openai", [{ id: "gpt-4o" }]); + cache1.clear("openai"); + + const cache2 = new ModelDiscoveryCache(cachePath); + assert.equal(cache2.get("openai"), undefined); + }); +}); + +// ─── discovery TTL values ──────────────────────────────────────────────────── + +describe("Discovery TTL configuration", () => { + it("ollama has shortest TTL (local models change often)", () => { + const ollamaTTL = getDefaultTTL("ollama"); + const openaiTTL = getDefaultTTL("openai"); + assert.ok(ollamaTTL < openaiTTL, "ollama TTL should be shorter than openai"); + }); + + it("unknown providers get default TTL", () => { + const customTTL = getDefaultTTL("my-custom-provider"); + const defaultTTL = getDefaultTTL("default"); + // Unknown providers should get the same TTL as the explicit "default" key + assert.equal(customTTL, defaultTTL); + }); +}); diff --git a/packages/pi-coding-agent/src/core/model-registry.ts b/packages/pi-coding-agent/src/core/model-registry.ts index 6d90af67f..a38068ccb 100644 --- a/packages/pi-coding-agent/src/core/model-registry.ts +++ b/packages/pi-coding-agent/src/core/model-registry.ts @@ -24,6 +24,9 @@ import { existsSync, readFileSync } from "fs"; import { join } from "path"; import { getAgentDir } from "../config.js"; import type { AuthStorage } from "./auth-storage.js"; +import { ModelDiscoveryCache } from "./discovery-cache.js"; +import type { DiscoveredModel, DiscoveryResult } from "./model-discovery.js"; +import { getDefaultTTL, getDiscoverableProviders, getDiscoveryAdapter } from "./model-discovery.js"; 
import { clearConfigValueCache, resolveConfigValue, resolveHeaders } from "./resolve-config-value.js"; const Ajv = (AjvModule as any).default || AjvModule; @@ -221,6 +224,8 @@ export const clearApiKeyCache = clearConfigValueCache; */ export class ModelRegistry { private models: Model[] = []; + private discoveredModels: Model[] = []; + private discoveryCache: ModelDiscoveryCache; private customProviderApiKeys: Map = new Map(); private registeredProviders: Map = new Map(); private loadError: string | undefined = undefined; @@ -229,6 +234,8 @@ export class ModelRegistry { readonly authStorage: AuthStorage, private modelsJsonPath: string | undefined = join(getAgentDir(), "models.json"), ) { + this.discoveryCache = new ModelDiscoveryCache(); + // Set up fallback resolver for custom provider API keys this.authStorage.setFallbackResolver((provider) => { const keyConfig = this.customProviderApiKeys.get(provider); @@ -666,6 +673,106 @@ export class ModelRegistry { }); } } + + /** + * Discover models from all providers that support discovery. + * Results are cached and merged into the registry (never overrides existing models). + */ + async discoverModels(providers?: string[]): Promise { + const targetProviders = providers ?? getDiscoverableProviders(); + const results: DiscoveryResult[] = []; + + for (const providerName of targetProviders) { + const adapter = getDiscoveryAdapter(providerName); + if (!adapter.supportsDiscovery) continue; + + // Skip if cache is still fresh + if (!this.discoveryCache.isStale(providerName)) { + const cached = this.discoveryCache.get(providerName); + if (cached) { + results.push({ + provider: providerName, + models: cached.models, + fetchedAt: cached.fetchedAt, + }); + continue; + } + } + + try { + const apiKey = await this.authStorage.getApiKey(providerName); + if (!apiKey && providerName !== "ollama") continue; + + const models = await adapter.fetchModels(apiKey ?? 
"", undefined); + this.discoveryCache.set(providerName, models); + results.push({ + provider: providerName, + models, + fetchedAt: Date.now(), + }); + } catch (error) { + results.push({ + provider: providerName, + models: [], + fetchedAt: Date.now(), + error: error instanceof Error ? error.message : String(error), + }); + } + } + + // Convert and merge discovered models + this.discoveredModels = this.convertDiscoveredModels(results); + return results; + } + + /** + * Get all models including discovered ones. + * Discovered models are appended but never override existing models. + */ + getAllWithDiscovered(): Model[] { + const existingIds = new Set(this.models.map((m) => `${m.provider}/${m.id}`)); + const unique = this.discoveredModels.filter((m) => !existingIds.has(`${m.provider}/${m.id}`)); + return [...this.models, ...unique]; + } + + /** + * Check if a model was added via discovery (not built-in or custom). + */ + isDiscovered(model: Model): boolean { + return this.discoveredModels.some((m) => m.provider === model.provider && m.id === model.id); + } + + /** + * Get the discovery cache instance. + */ + getDiscoveryCache(): ModelDiscoveryCache { + return this.discoveryCache; + } + + /** + * Convert DiscoveryResult[] into Model[] with default values. + */ + private convertDiscoveredModels(results: DiscoveryResult[]): Model[] { + const converted: Model[] = []; + for (const result of results) { + if (result.error) continue; + for (const dm of result.models) { + converted.push({ + id: dm.id, + name: dm.name ?? dm.id, + api: "openai" as Api, + provider: result.provider, + baseUrl: "", + reasoning: dm.reasoning ?? false, + input: dm.input ?? ["text"], + cost: dm.cost ?? { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + contextWindow: dm.contextWindow ?? 128000, + maxTokens: dm.maxTokens ?? 
16384, + } as Model); + } + } + return converted; + } } /** diff --git a/packages/pi-coding-agent/src/core/models-json-writer.test.ts b/packages/pi-coding-agent/src/core/models-json-writer.test.ts new file mode 100644 index 000000000..3dcb0be98 --- /dev/null +++ b/packages/pi-coding-agent/src/core/models-json-writer.test.ts @@ -0,0 +1,145 @@ +import assert from "node:assert/strict"; +import { existsSync, mkdirSync, readFileSync, rmSync } from "node:fs"; +import { tmpdir } from "node:os"; +import { join } from "node:path"; +import { afterEach, beforeEach, describe, it } from "node:test"; +import { ModelsJsonWriter } from "./models-json-writer.js"; + +let testDir: string; +let modelsJsonPath: string; + +beforeEach(() => { + testDir = join(tmpdir(), `models-json-writer-test-${Date.now()}-${Math.random().toString(36).slice(2)}`); + mkdirSync(testDir, { recursive: true }); + modelsJsonPath = join(testDir, "models.json"); +}); + +afterEach(() => { + try { + rmSync(testDir, { recursive: true, force: true }); + } catch { + // Cleanup best-effort + } +}); + +function readModels(): Record { + return JSON.parse(readFileSync(modelsJsonPath, "utf-8")); +} + +// ─── addModel ──────────────────────────────────────────────────────────────── + +describe("ModelsJsonWriter — addModel", () => { + it("creates file and adds model to new provider", () => { + const writer = new ModelsJsonWriter(modelsJsonPath); + writer.addModel("openai", { id: "gpt-4o", name: "GPT-4o" }, { baseUrl: "https://api.openai.com", apiKey: "env:OPENAI_API_KEY", api: "openai" }); + + const config = readModels() as any; + assert.ok(config.providers.openai); + assert.equal(config.providers.openai.models.length, 1); + assert.equal(config.providers.openai.models[0].id, "gpt-4o"); + }); + + it("appends model to existing provider", () => { + const writer = new ModelsJsonWriter(modelsJsonPath); + writer.addModel("openai", { id: "gpt-4o" }, { baseUrl: "https://api.openai.com", apiKey: "env:OPENAI_API_KEY", api: "openai" 
}); + writer.addModel("openai", { id: "gpt-4o-mini" }); + + const config = readModels() as any; + assert.equal(config.providers.openai.models.length, 2); + }); + + it("replaces model with same id", () => { + const writer = new ModelsJsonWriter(modelsJsonPath); + writer.addModel("openai", { id: "gpt-4o", name: "Old" }, { baseUrl: "https://api.openai.com", apiKey: "env:OPENAI_API_KEY", api: "openai" }); + writer.addModel("openai", { id: "gpt-4o", name: "New" }); + + const config = readModels() as any; + assert.equal(config.providers.openai.models.length, 1); + assert.equal(config.providers.openai.models[0].name, "New"); + }); +}); + +// ─── removeModel ───────────────────────────────────────────────────────────── + +describe("ModelsJsonWriter — removeModel", () => { + it("removes a model from provider", () => { + const writer = new ModelsJsonWriter(modelsJsonPath); + writer.addModel("openai", { id: "gpt-4o" }, { baseUrl: "https://api.openai.com", apiKey: "env:OPENAI_API_KEY", api: "openai" }); + writer.addModel("openai", { id: "gpt-4o-mini" }); + + writer.removeModel("openai", "gpt-4o"); + + const config = readModels() as any; + assert.equal(config.providers.openai.models.length, 1); + assert.equal(config.providers.openai.models[0].id, "gpt-4o-mini"); + }); + + it("removes provider when last model is removed", () => { + const writer = new ModelsJsonWriter(modelsJsonPath); + writer.addModel("openai", { id: "gpt-4o" }, { baseUrl: "https://api.openai.com", apiKey: "env:OPENAI_API_KEY", api: "openai" }); + + writer.removeModel("openai", "gpt-4o"); + + const config = readModels() as any; + assert.equal(config.providers.openai, undefined); + }); + + it("handles removing from nonexistent provider", () => { + const writer = new ModelsJsonWriter(modelsJsonPath); + // Should not throw + writer.removeModel("nonexistent", "model-id"); + }); +}); + +// ─── setProvider / removeProvider ──────────────────────────────────────────── + +describe("ModelsJsonWriter — provider 
operations", () => { + it("sets a provider configuration", () => { + const writer = new ModelsJsonWriter(modelsJsonPath); + writer.setProvider("custom", { + baseUrl: "http://localhost:8080", + apiKey: "test-key", + api: "openai", + models: [{ id: "local-model" }], + }); + + const config = readModels() as any; + assert.ok(config.providers.custom); + assert.equal(config.providers.custom.baseUrl, "http://localhost:8080"); + }); + + it("removes a provider", () => { + const writer = new ModelsJsonWriter(modelsJsonPath); + writer.setProvider("custom", { baseUrl: "http://localhost:8080" }); + writer.removeProvider("custom"); + + const config = readModels() as any; + assert.equal(config.providers.custom, undefined); + }); + + it("handles removing nonexistent provider", () => { + const writer = new ModelsJsonWriter(modelsJsonPath); + writer.removeProvider("nonexistent"); + // Should not throw + }); +}); + +// ─── listProviders ─────────────────────────────────────────────────────────── + +describe("ModelsJsonWriter — listProviders", () => { + it("returns empty config when file does not exist", () => { + const writer = new ModelsJsonWriter(join(testDir, "nonexistent.json")); + const config = writer.listProviders(); + assert.deepEqual(config, { providers: {} }); + }); + + it("returns current provider config", () => { + const writer = new ModelsJsonWriter(modelsJsonPath); + writer.setProvider("openai", { baseUrl: "https://api.openai.com" }); + writer.setProvider("ollama", { baseUrl: "http://localhost:11434" }); + + const config = writer.listProviders(); + assert.ok(config.providers.openai); + assert.ok(config.providers.ollama); + }); +}); diff --git a/packages/pi-coding-agent/src/core/models-json-writer.ts b/packages/pi-coding-agent/src/core/models-json-writer.ts new file mode 100644 index 000000000..0d5e643b1 --- /dev/null +++ b/packages/pi-coding-agent/src/core/models-json-writer.ts @@ -0,0 +1,188 @@ +/** + * Safe read-modify-write for models.json with file locking. 
+ * Prevents concurrent writes from corrupting the config file. + */ + +import { existsSync, mkdirSync, readFileSync, writeFileSync } from "fs"; +import { dirname, join } from "path"; +import lockfile from "proper-lockfile"; +import { getAgentDir } from "../config.js"; + +interface ModelDefinition { + id: string; + name?: string; + api?: string; + baseUrl?: string; + reasoning?: boolean; + input?: ("text" | "image")[]; + cost?: { input: number; output: number; cacheRead: number; cacheWrite: number }; + contextWindow?: number; + maxTokens?: number; +} + +interface ProviderConfig { + baseUrl?: string; + apiKey?: string; + api?: string; + headers?: Record; + authHeader?: boolean; + models?: ModelDefinition[]; + modelOverrides?: Record>; +} + +interface ModelsConfig { + providers: Record; +} + +export class ModelsJsonWriter { + private modelsJsonPath: string; + + constructor(modelsJsonPath?: string) { + this.modelsJsonPath = modelsJsonPath ?? join(getAgentDir(), "models.json"); + } + + /** + * Add a model to a provider. Creates the provider if it doesn't exist. + */ + addModel(provider: string, model: ModelDefinition, providerConfig?: Partial): void { + this.withLock((config) => { + if (!config.providers[provider]) { + config.providers[provider] = { + ...providerConfig, + models: [], + }; + } + + const providerEntry = config.providers[provider]; + if (!providerEntry.models) { + providerEntry.models = []; + } + + // Replace existing model with same id, or append + const existingIndex = providerEntry.models.findIndex((m) => m.id === model.id); + if (existingIndex >= 0) { + providerEntry.models[existingIndex] = model; + } else { + providerEntry.models.push(model); + } + + return config; + }); + } + + /** + * Remove a model from a provider. Removes the provider if no models remain. 
+ */ + removeModel(provider: string, modelId: string): void { + this.withLock((config) => { + const providerEntry = config.providers[provider]; + if (!providerEntry?.models) return config; + + providerEntry.models = providerEntry.models.filter((m) => m.id !== modelId); + + // Clean up empty provider (no models and no overrides) + if (providerEntry.models.length === 0 && !providerEntry.modelOverrides) { + delete config.providers[provider]; + } + + return config; + }); + } + + /** + * Set or update an entire provider configuration. + */ + setProvider(provider: string, providerConfig: ProviderConfig): void { + this.withLock((config) => { + config.providers[provider] = providerConfig; + return config; + }); + } + + /** + * Remove a provider and all its models. + */ + removeProvider(provider: string): void { + this.withLock((config) => { + delete config.providers[provider]; + return config; + }); + } + + /** + * List all providers and their configurations. + */ + listProviders(): ModelsConfig { + return this.readConfig(); + } + + private readConfig(): ModelsConfig { + if (!existsSync(this.modelsJsonPath)) { + return { providers: {} }; + } + try { + const content = readFileSync(this.modelsJsonPath, "utf-8"); + return JSON.parse(content) as ModelsConfig; + } catch { + return { providers: {} }; + } + } + + private writeConfig(config: ModelsConfig): void { + const dir = dirname(this.modelsJsonPath); + if (!existsSync(dir)) { + mkdirSync(dir, { recursive: true }); + } + writeFileSync(this.modelsJsonPath, JSON.stringify(config, null, 2), "utf-8"); + } + + private acquireLockWithRetry(): () => void { + const maxAttempts = 10; + const delayMs = 20; + let lastError: unknown; + + // Ensure file exists for locking + const dir = dirname(this.modelsJsonPath); + if (!existsSync(dir)) { + mkdirSync(dir, { recursive: true }); + } + if (!existsSync(this.modelsJsonPath)) { + writeFileSync(this.modelsJsonPath, JSON.stringify({ providers: {} }, null, 2), "utf-8"); + } + + for (let attempt 
= 1; attempt <= maxAttempts; attempt++) { + try { + return lockfile.lockSync(this.modelsJsonPath, { realpath: false }); + } catch (error) { + const code = + typeof error === "object" && error !== null && "code" in error + ? String((error as { code?: unknown }).code) + : undefined; + if (code !== "ELOCKED" || attempt === maxAttempts) { + throw error; + } + lastError = error; + const start = Date.now(); + while (Date.now() - start < delayMs) { + // Busy-wait (same pattern as auth-storage.ts) + } + } + } + + throw (lastError as Error) ?? new Error("Failed to acquire models.json lock"); + } + + private withLock(fn: (config: ModelsConfig) => ModelsConfig): void { + let release: (() => void) | undefined; + try { + release = this.acquireLockWithRetry(); + const config = this.readConfig(); + const updated = fn(config); + this.writeConfig(updated); + } finally { + if (release) { + release(); + } + } + } +} diff --git a/packages/pi-coding-agent/src/core/settings-manager.ts b/packages/pi-coding-agent/src/core/settings-manager.ts index ce1f7bbd7..059b3a0da 100644 --- a/packages/pi-coding-agent/src/core/settings-manager.ts +++ b/packages/pi-coding-agent/src/core/settings-manager.ts @@ -79,6 +79,13 @@ export interface FallbackSettings { chains?: Record; // keyed by chain name } +export interface ModelDiscoverySettings { + enabled?: boolean; // default: false + providers?: string[]; // limit discovery to specific providers + ttlMinutes?: number; // override default TTLs (in minutes) + autoRefreshOnModelSelect?: boolean; // default: false - refresh discovery when opening model selector +} + export type TransportSetting = Transport; /** @@ -134,6 +141,7 @@ export interface Settings { bashInterceptor?: BashInterceptorSettings; taskIsolation?: TaskIsolationSettings; fallback?: FallbackSettings; + modelDiscovery?: ModelDiscoverySettings; } /** Deep merge settings: project/overrides take precedence, nested objects merge recursively */ @@ -1076,4 +1084,17 @@ export class SettingsManager 
{ chains: this.getFallbackChains(), }; } + + getModelDiscoverySettings(): ModelDiscoverySettings { + return this.settings.modelDiscovery ?? {}; + } + + setModelDiscoveryEnabled(enabled: boolean): void { + if (!this.globalSettings.modelDiscovery) { + this.globalSettings.modelDiscovery = {}; + } + this.globalSettings.modelDiscovery.enabled = enabled; + this.markModified("modelDiscovery", "enabled"); + this.save(); + } } diff --git a/packages/pi-coding-agent/src/core/slash-commands.ts b/packages/pi-coding-agent/src/core/slash-commands.ts index fd4b667b5..8c2800811 100644 --- a/packages/pi-coding-agent/src/core/slash-commands.ts +++ b/packages/pi-coding-agent/src/core/slash-commands.ts @@ -28,6 +28,7 @@ export const BUILTIN_SLASH_COMMANDS: ReadonlyArray = [ { name: "hotkeys", description: "Show all keyboard shortcuts" }, { name: "fork", description: "Create a new fork from a previous message" }, { name: "tree", description: "Navigate session tree (switch branches)" }, + { name: "provider", description: "Manage provider configuration" }, { name: "login", description: "Login with OAuth provider" }, { name: "logout", description: "Logout from OAuth provider" }, { name: "new", description: "Start a new session" }, diff --git a/packages/pi-coding-agent/src/index.ts b/packages/pi-coding-agent/src/index.ts index 911431151..86a808a05 100644 --- a/packages/pi-coding-agent/src/index.ts +++ b/packages/pi-coding-agent/src/index.ts @@ -143,7 +143,11 @@ export { // Footer data provider (git branch + extension statuses - data not otherwise available to extensions) export type { ReadonlyFooterDataProvider } from "./core/footer-data-provider.js"; export { convertToLlm } from "./core/messages.js"; +export { ModelDiscoveryCache } from "./core/discovery-cache.js"; +export type { DiscoveredModel, DiscoveryResult, ProviderDiscoveryAdapter } from "./core/model-discovery.js"; +export { getDiscoverableProviders, getDiscoveryAdapter } from "./core/model-discovery.js"; export { ModelRegistry } 
from "./core/model-registry.js"; +export { ModelsJsonWriter } from "./core/models-json-writer.js"; export type { PackageManager, PathMetadata, @@ -307,6 +311,7 @@ export { LoginDialogComponent, ModelSelectorComponent, OAuthSelectorComponent, + ProviderManagerComponent, type RenderDiffOptions, rawKeyHint, renderDiff, diff --git a/packages/pi-coding-agent/src/main.ts b/packages/pi-coding-agent/src/main.ts index 5c39de898..7152f63b3 100644 --- a/packages/pi-coding-agent/src/main.ts +++ b/packages/pi-coding-agent/src/main.ts @@ -11,7 +11,7 @@ import { createInterface } from "readline"; import { type Args, parseArgs, printHelp } from "./cli/args.js"; import { selectConfig } from "./cli/config-selector.js"; import { processFileArguments } from "./cli/file-processor.js"; -import { listModels } from "./cli/list-models.js"; +import { discoverAndPrintModels, listModels } from "./cli/list-models.js"; import { selectSession } from "./cli/session-picker.js"; import { APP_NAME, getAgentDir, getModelsPath, VERSION } from "./config.js"; import { AuthStorage } from "./core/auth-storage.js"; @@ -660,9 +660,26 @@ export async function main(args: string[]) { process.exit(0); } + if (parsed.addProvider) { + const { ModelsJsonWriter } = await import("./core/models-json-writer.js"); + const writer = new ModelsJsonWriter(); + writer.setProvider(parsed.addProvider, { + baseUrl: parsed.addProviderBaseUrl, + apiKey: parsed.apiKey, + }); + console.log(`Provider "${parsed.addProvider}" added to models.json`); + process.exit(0); + } + + if (parsed.discoverModels !== undefined) { + const provider = typeof parsed.discoverModels === "string" ? parsed.discoverModels : undefined; + await discoverAndPrintModels(modelRegistry, provider); + process.exit(0); + } + if (parsed.listModels !== undefined) { const searchPattern = typeof parsed.listModels === "string" ? 
parsed.listModels : undefined; - await listModels(modelRegistry, searchPattern); + await listModels(modelRegistry, { searchPattern, discover: parsed.discover }); process.exit(0); } diff --git a/packages/pi-coding-agent/src/modes/interactive/components/index.ts b/packages/pi-coding-agent/src/modes/interactive/components/index.ts index 78200f36c..16b39a2ec 100644 --- a/packages/pi-coding-agent/src/modes/interactive/components/index.ts +++ b/packages/pi-coding-agent/src/modes/interactive/components/index.ts @@ -18,6 +18,7 @@ export { appKey, appKeyHint, editorKey, keyHint, rawKeyHint } from "./keybinding export { LoginDialogComponent } from "./login-dialog.js"; export { ModelSelectorComponent } from "./model-selector.js"; export { OAuthSelectorComponent } from "./oauth-selector.js"; +export { ProviderManagerComponent } from "./provider-manager.js"; export { type ModelsCallbacks, type ModelsConfig, ScopedModelsSelectorComponent } from "./scoped-models-selector.js"; export { SessionSelectorComponent } from "./session-selector.js"; export { type SettingsCallbacks, type SettingsConfig, SettingsSelectorComponent } from "./settings-selector.js"; diff --git a/packages/pi-coding-agent/src/modes/interactive/components/model-selector.ts b/packages/pi-coding-agent/src/modes/interactive/components/model-selector.ts index 06ef5ac2e..b35895a79 100644 --- a/packages/pi-coding-agent/src/modes/interactive/components/model-selector.ts +++ b/packages/pi-coding-agent/src/modes/interactive/components/model-selector.ts @@ -160,7 +160,7 @@ export class ModelSelectorComponent extends Container implements Focusable { // Load available models (built-in models still work even if models.json failed) try { - const availableModels = await this.modelRegistry.getAvailable(); + const availableModels = this.modelRegistry.getAvailable(); models = availableModels.map((model: Model) => ({ provider: model.provider, id: model.id, diff --git 
a/packages/pi-coding-agent/src/modes/interactive/components/provider-manager.ts b/packages/pi-coding-agent/src/modes/interactive/components/provider-manager.ts new file mode 100644 index 000000000..5944d8c78 --- /dev/null +++ b/packages/pi-coding-agent/src/modes/interactive/components/provider-manager.ts @@ -0,0 +1,163 @@ +/** + * TUI component for managing provider configurations. + * Shows providers with auth status, discovery support, and model counts. + */ + +import { + Container, + type Focusable, + getEditorKeybindings, + Spacer, + Text, + type TUI, +} from "@gsd/pi-tui"; +import type { AuthStorage } from "../../../core/auth-storage.js"; +import { getDiscoverableProviders } from "../../../core/model-discovery.js"; +import type { ModelRegistry } from "../../../core/model-registry.js"; +import { theme } from "../theme/theme.js"; +import { rawKeyHint } from "./keybinding-hints.js"; + +interface ProviderInfo { + name: string; + hasAuth: boolean; + supportsDiscovery: boolean; + modelCount: number; +} + +export class ProviderManagerComponent extends Container implements Focusable { + private _focused = false; + get focused(): boolean { + return this._focused; + } + set focused(value: boolean) { + this._focused = value; + } + + private providers: ProviderInfo[] = []; + private selectedIndex = 0; + private listContainer: Container; + private tui: TUI; + private authStorage: AuthStorage; + private modelRegistry: ModelRegistry; + private onDone: () => void; + private onDiscover: (provider: string) => void; + + constructor( + tui: TUI, + authStorage: AuthStorage, + modelRegistry: ModelRegistry, + onDone: () => void, + onDiscover: (provider: string) => void, + ) { + super(); + + this.tui = tui; + this.authStorage = authStorage; + this.modelRegistry = modelRegistry; + this.onDone = onDone; + this.onDiscover = onDiscover; + + // Header + this.addChild(new Text(theme.fg("accent", "Provider Manager"), 0, 0)); + this.addChild(new Spacer(1)); + + // Hints + const hints = [ + 
rawKeyHint("d", "discover"), + rawKeyHint("r", "remove auth"), + rawKeyHint("esc", "close"), + ].join(" "); + this.addChild(new Text(hints, 0, 0)); + this.addChild(new Spacer(1)); + + // List + this.listContainer = new Container(); + this.addChild(this.listContainer); + + this.loadProviders(); + this.updateList(); + } + + private loadProviders(): void { + const discoverableSet = new Set(getDiscoverableProviders()); + const allModels = this.modelRegistry.getAll(); + + // Group models by provider + const providerModelCounts = new Map(); + for (const model of allModels) { + providerModelCounts.set(model.provider, (providerModelCounts.get(model.provider) ?? 0) + 1); + } + + // Build provider list from all known providers + const providerNames = new Set([ + ...providerModelCounts.keys(), + ...discoverableSet, + ]); + + this.providers = Array.from(providerNames) + .sort() + .map((name) => ({ + name, + hasAuth: this.authStorage.hasAuth(name), + supportsDiscovery: discoverableSet.has(name), + modelCount: providerModelCounts.get(name) ?? 0, + })); + } + + private updateList(): void { + this.listContainer.clear(); + + for (let i = 0; i < this.providers.length; i++) { + const p = this.providers[i]; + const isSelected = i === this.selectedIndex; + + const authBadge = p.hasAuth ? theme.fg("success", "[auth]") : theme.fg("muted", "[no auth]"); + const discoveryBadge = p.supportsDiscovery ? theme.fg("accent", "[discovery]") : ""; + const countBadge = theme.fg("muted", `(${p.modelCount} models)`); + + const prefix = isSelected ? theme.fg("accent", "> ") : " "; + const nameText = isSelected ? 
theme.fg("accent", p.name) : p.name; + + const parts = [prefix, nameText, " ", authBadge]; + if (discoveryBadge) parts.push(" ", discoveryBadge); + parts.push(" ", countBadge); + + this.listContainer.addChild(new Text(parts.join(""), 0, 0)); + } + + if (this.providers.length === 0) { + this.listContainer.addChild(new Text(theme.fg("muted", " No providers configured"), 0, 0)); + } + } + + handleInput(keyData: string): void { + const kb = getEditorKeybindings(); + + if (kb.matches(keyData, "selectUp")) { + if (this.providers.length === 0) return; + this.selectedIndex = this.selectedIndex === 0 ? this.providers.length - 1 : this.selectedIndex - 1; + this.updateList(); + this.tui.requestRender(); + } else if (kb.matches(keyData, "selectDown")) { + if (this.providers.length === 0) return; + this.selectedIndex = this.selectedIndex === this.providers.length - 1 ? 0 : this.selectedIndex + 1; + this.updateList(); + this.tui.requestRender(); + } else if (kb.matches(keyData, "selectCancel")) { + this.onDone(); + } else if (keyData === "d" || keyData === "D") { + const provider = this.providers[this.selectedIndex]; + if (provider?.supportsDiscovery) { + this.onDiscover(provider.name); + } + } else if (keyData === "r" || keyData === "R") { + const provider = this.providers[this.selectedIndex]; + if (provider?.hasAuth) { + this.authStorage.remove(provider.name); + this.loadProviders(); + this.updateList(); + this.tui.requestRender(); + } + } + } +} diff --git a/packages/pi-coding-agent/src/modes/interactive/interactive-mode.ts b/packages/pi-coding-agent/src/modes/interactive/interactive-mode.ts index 3b64c7bc6..e536b63d3 100644 --- a/packages/pi-coding-agent/src/modes/interactive/interactive-mode.ts +++ b/packages/pi-coding-agent/src/modes/interactive/interactive-mode.ts @@ -83,6 +83,7 @@ import { appKey, appKeyHint, editorKey, formatKeyForDisplay, keyHint, rawKeyHint import { LoginDialogComponent } from "./components/login-dialog.js"; import { ModelSelectorComponent } from 
/**
 * Open the provider manager (the `/provider` command) as a selector overlay.
 *
 * The component receives two callbacks:
 * - a close callback that dismisses the selector and requests a render;
 * - a discovery callback that runs model discovery for one provider, reports
 *   the outcome through the status / error line, and then closes the selector.
 */
private showProviderManager(): void {
	this.showSelector((done) => {
		const closeSelector = (): void => {
			done();
			this.ui.requestRender();
		};

		const discoverFor = async (provider: string): Promise<void> => {
			this.showStatus(`Discovering models for ${provider}...`);
			try {
				const results = await this.session.modelRegistry.discoverModels([provider]);
				// Only one provider was requested, so only the first result is inspected.
				const outcome = results[0];
				if (outcome?.error) {
					this.showError(`Discovery failed: ${outcome.error}`);
				} else {
					const found = outcome?.models.length ?? 0;
					this.showStatus(`Discovered ${found} models from ${provider}`);
				}
			} catch (error) {
				let message: string;
				if (error instanceof Error) {
					message = error.message;
				} else {
					message = String(error);
				}
				this.showError(message);
			}
			closeSelector();
		};

		const component = new ProviderManagerComponent(
			this.ui,
			this.session.modelRegistry.authStorage,
			this.session.modelRegistry,
			closeSelector,
			discoverFor,
		);
		return { component, focus: component };
	});
}
String(git.remote) : ""; + const remoteInput = await ctx.ui.input( + `Git remote name${currentRemote ? ` (current: ${currentRemote})` : " (default: origin)"}:`, + currentRemote || "origin", + ); + if (remoteInput !== null && remoteInput !== undefined) { + const val = remoteInput.trim(); + if (val && val !== "origin") { + git.remote = val; + } else if (!val && currentRemote) { + delete git.remote; + } + } + + // pre_merge_check + const currentPreMerge = git.pre_merge_check !== undefined ? String(git.pre_merge_check) : ""; + const preMergeChoice = await ctx.ui.select( + `Pre-merge check${currentPreMerge ? ` (current: ${currentPreMerge})` : " (default: false)"}:`, + ["true", "false", "auto", "(keep current)"], + ); + if (preMergeChoice && preMergeChoice !== "(keep current)") { + if (preMergeChoice === "auto") { + git.pre_merge_check = "auto"; + } else { + git.pre_merge_check = preMergeChoice === "true"; + } + } + + // commit_type + const currentCommitType = git.commit_type ? String(git.commit_type) : ""; + const commitTypes = ["feat", "fix", "refactor", "docs", "test", "chore", "perf", "ci", "build", "style", "(inferred — default)", "(keep current)"]; + const commitChoice = await ctx.ui.select( + `Default commit type${currentCommitType ? ` (current: ${currentCommitType})` : ""}:`, + commitTypes, + ); + if (commitChoice && typeof commitChoice === "string" && commitChoice !== "(keep current)") { + if ((commitChoice as string).startsWith("(inferred")) { + delete git.commit_type; + } else { + git.commit_type = commitChoice; + } + } + + // merge_strategy + const currentMerge = git.merge_strategy ? String(git.merge_strategy) : ""; + const mergeChoice = await ctx.ui.select( + `Merge strategy${currentMerge ? ` (current: ${currentMerge})` : ""}:`, + ["squash", "merge", "(keep current)"], + ); + if (mergeChoice && mergeChoice !== "(keep current)") { + git.merge_strategy = mergeChoice; + } + + // isolation + const currentIsolation = git.isolation ? 
String(git.isolation) : ""; + const isolationChoice = await ctx.ui.select( + `Git isolation strategy${currentIsolation ? ` (current: ${currentIsolation})` : " (default: worktree)"}:`, + ["worktree", "branch", "(keep current)"], + ); + if (isolationChoice && isolationChoice !== "(keep current)") { + git.isolation = isolationChoice; + } + // ─── Git commit_docs ──────────────────────────────────────────────────── const currentCommitDocs = git.commit_docs; const commitDocsChoice = await ctx.ui.select( @@ -560,6 +646,89 @@ async function handlePrefsWizard( prefs.unique_milestone_ids = uniqueChoice === "true"; } + // ─── Budget & cost control ──────────────────────────────────────────── + const currentCeiling = prefs.budget_ceiling; + const ceilingStr = currentCeiling !== undefined ? String(currentCeiling) : ""; + const ceilingInput = await ctx.ui.input( + `Budget ceiling (USD)${ceilingStr ? ` (current: $${ceilingStr})` : " (default: no limit)"}:`, + ceilingStr || "", + ); + if (ceilingInput !== null && ceilingInput !== undefined) { + const val = ceilingInput.trim().replace(/^\$/, ""); + if (val && !isNaN(Number(val)) && isFinite(Number(val))) { + prefs.budget_ceiling = Number(val); + } else if (val && (isNaN(Number(val)) || !isFinite(Number(val)))) { + ctx.ui.notify(`Invalid budget ceiling "${val}" — must be a number. Keeping previous value.`, "warning"); + } else if (!val && ceilingStr) { + delete prefs.budget_ceiling; + } + } + + const currentEnforcement = (prefs.budget_enforcement as string) ?? ""; + const enforcementChoice = await ctx.ui.select( + `Budget enforcement${currentEnforcement ? ` (current: ${currentEnforcement})` : " (default: pause)"}:`, + ["warn", "pause", "halt", "(keep current)"], + ); + if (enforcementChoice && enforcementChoice !== "(keep current)") { + prefs.budget_enforcement = enforcementChoice; + } + + const currentContextPause = prefs.context_pause_threshold; + const contextPauseStr = currentContextPause !== undefined ? 
String(currentContextPause) : ""; + const contextPauseInput = await ctx.ui.input( + `Context pause threshold (0-100%, 0=disabled)${contextPauseStr ? ` (current: ${contextPauseStr}%)` : " (default: 0)"}:`, + contextPauseStr || "0", + ); + if (contextPauseInput !== null && contextPauseInput !== undefined) { + const val = contextPauseInput.trim().replace(/%$/, ""); + if (val && !isNaN(Number(val)) && Number(val) >= 0 && Number(val) <= 100) { + const num = Number(val); + if (num === 0) { + delete prefs.context_pause_threshold; + } else { + prefs.context_pause_threshold = num; + } + } else if (val && (isNaN(Number(val)) || Number(val) < 0 || Number(val) > 100)) { + ctx.ui.notify(`Invalid context pause threshold "${val}" — must be 0-100. Keeping previous value.`, "warning"); + } + } + + // ─── Notifications ──────────────────────────────────────────────────── + const notif: Record = (prefs.notifications as Record) ?? {}; + const notifFields = [ + { key: "enabled", label: "Notifications enabled (master toggle)", defaultVal: true }, + { key: "on_complete", label: "Notify on unit completion", defaultVal: true }, + { key: "on_error", label: "Notify on errors", defaultVal: true }, + { key: "on_budget", label: "Notify on budget thresholds", defaultVal: true }, + { key: "on_milestone", label: "Notify on milestone completion", defaultVal: true }, + { key: "on_attention", label: "Notify when manual attention needed", defaultVal: true }, + ] as const; + + for (const field of notifFields) { + const current = notif[field.key]; + const currentStr = current !== undefined ? String(current) : ""; + const choice = await ctx.ui.select( + `${field.label}${currentStr ? 
` (current: ${currentStr})` : ` (default: ${field.defaultVal})`}:`, + ["true", "false", "(keep current)"], + ); + if (choice && choice !== "(keep current)") { + notif[field.key] = choice === "true"; + } + } + if (Object.keys(notif).length > 0) { + prefs.notifications = notif; + } + + // ─── UAT dispatch ───────────────────────────────────────────────────── + const currentUat = prefs.uat_dispatch; + const uatChoice = await ctx.ui.select( + `UAT dispatch mode${currentUat !== undefined ? ` (current: ${currentUat})` : " (default: false)"}:`, + ["true", "false", "(keep current)"], + ); + if (uatChoice && uatChoice !== "(keep current)") { + prefs.uat_dispatch = uatChoice === "true"; + } + // ─── Serialize to frontmatter ─────────────────────────────────────────── prefs.version = prefs.version || 1; const frontmatter = serializePreferencesToFrontmatter(prefs); @@ -650,7 +819,10 @@ function serializePreferencesToFrontmatter(prefs: Record): stri const orderedKeys = [ "version", "always_use_skills", "prefer_skills", "avoid_skills", "skill_rules", "custom_instructions", "models", "skill_discovery", - "auto_supervisor", "uat_dispatch", "unique_milestone_ids", "budget_ceiling", "remote_questions", "git", + "auto_supervisor", "uat_dispatch", "unique_milestone_ids", + "budget_ceiling", "budget_enforcement", "context_pause_threshold", + "notifications", "remote_questions", "git", + "post_unit_hooks", "pre_dispatch_hooks", ]; const seen = new Set(); diff --git a/src/resources/extensions/gsd/docs/preferences-reference.md b/src/resources/extensions/gsd/docs/preferences-reference.md index a71f06292..8a0b4fd72 100644 --- a/src/resources/extensions/gsd/docs/preferences-reference.md +++ b/src/resources/extensions/gsd/docs/preferences-reference.md @@ -108,10 +108,51 @@ Setting `prefer_skills: []` does **not** disable skill discovery — it just mea - `pre_merge_check`: boolean or `"auto"` — run pre-merge checks before merging a worktree back to the integration branch. 
`true` always runs, `false` never runs, `"auto"` runs when CI is detected. Default: `false`. - `commit_type`: string — override the conventional commit type prefix. Must be one of: `feat`, `fix`, `refactor`, `docs`, `test`, `chore`, `perf`, `ci`, `build`, `style`. Default: inferred from diff content. - `main_branch`: string — the primary branch name for new git repos (e.g., `"main"`, `"master"`, `"trunk"`). Also used by `getMainBranch()` as the preferred branch when auto-detection is ambiguous. Default: `"main"`. + - `merge_strategy`: `"squash"` or `"merge"` — controls how worktree branches are merged back. `"squash"` combines all commits into one; `"merge"` preserves individual commits. Default: `"squash"`. + - `isolation`: `"worktree"` or `"branch"` — controls auto-mode git isolation strategy. `"worktree"` creates a milestone worktree for isolated work; `"branch"` works directly in the project root (useful for submodule-heavy repos). Default: `"worktree"`. - `commit_docs`: boolean — when `false`, prevents GSD from committing `.gsd/` planning artifacts to git. The `.gsd/` folder is added to `.gitignore` and kept local-only. Useful for teams where only some members use GSD, or when company policy requires a clean repository. Default: `true`. - `unique_milestone_ids`: boolean — when `true`, generates milestone IDs in `M{seq}-{rand6}` format (e.g. `M001-eh88as`) instead of plain sequential `M001`. Prevents ID collisions in team workflows where multiple contributors create milestones concurrently. Both formats coexist — existing `M001`-style milestones remain valid. Default: `false`. +- `budget_ceiling`: number — maximum dollar amount to spend on auto-mode. When reached, behavior is controlled by `budget_enforcement`. Default: no limit. + +- `budget_enforcement`: `"warn"`, `"pause"`, or `"halt"` — action taken when `budget_ceiling` is reached. + - `warn` — log a warning but continue execution. + - `pause` — pause auto-mode and wait for user confirmation. 
+ - `halt` — stop auto-mode immediately. + - Default: `"pause"`. + +- `context_pause_threshold`: number (0-100) — context window usage percentage at which auto-mode should pause to suggest checkpointing. Set to `0` to disable. Default: `0` (disabled). + +- `notifications`: configures desktop notification behavior during auto-mode. Keys: + - `enabled`: boolean — master toggle for all notifications. Default: `true`. + - `on_complete`: boolean — notify when a unit completes. Default: `true`. + - `on_error`: boolean — notify on errors. Default: `true`. + - `on_budget`: boolean — notify when budget thresholds are reached. Default: `true`. + - `on_milestone`: boolean — notify when a milestone finishes. Default: `true`. + - `on_attention`: boolean — notify when manual attention is needed. Default: `true`. + +- `uat_dispatch`: boolean — when `true`, enables UAT (User Acceptance Testing) dispatch mode. Default: `false`. + +- `post_unit_hooks`: array — hooks that fire after a unit completes. Each entry has: + - `name`: string — unique hook identifier. + - `after`: string[] — unit types that trigger this hook (e.g., `["execute-task"]`). + - `prompt`: string — prompt sent to the LLM. Supports `{milestoneId}`, `{sliceId}`, `{taskId}` substitutions. + - `max_cycles`: number — max times this hook fires per trigger (default: 1, max: 10). + - `model`: string — optional model override. + - `artifact`: string — expected output file (skip if exists). + - `retry_on`: string — file that triggers re-run of the trigger unit. + - `enabled`: boolean — toggle without removing (default: `true`). + +- `pre_dispatch_hooks`: array — hooks that fire before a unit is dispatched. Each entry has: + - `name`: string — unique hook identifier. + - `before`: string[] — unit types to intercept. + - `action`: `"modify"`, `"skip"`, or `"replace"` — what to do with the unit. + - `prepend`: string — text prepended to unit prompt (for `"modify"` action). 
+ - `append`: string — text appended to unit prompt (for `"modify"` action). + - `prompt`: string — replacement prompt (for `"replace"` action). + - `enabled`: boolean — toggle without removing (default: `true`). + --- ## Best Practices @@ -277,3 +318,56 @@ git: ``` All git fields are optional. Omit any field to use the default behavior. Project-level preferences override global preferences on a per-field basis. + +--- + +## Budget & Cost Control Example + +```yaml +--- +version: 1 +budget_ceiling: 10.00 +budget_enforcement: pause +context_pause_threshold: 80 +--- +``` + +Sets a $10 budget ceiling. Auto-mode pauses when the ceiling is reached. Context window pauses at 80% usage for checkpointing. + +--- + +## Notifications Example + +```yaml +--- +version: 1 +notifications: + enabled: true + on_complete: false + on_error: true + on_budget: true + on_milestone: true + on_attention: true +--- +``` + +Disables per-unit completion notifications (noisy in long runs) while keeping error, budget, milestone, and attention notifications enabled. + +--- + +## Post-Unit Hooks Example + +```yaml +--- +version: 1 +post_unit_hooks: + - name: code-review + after: + - execute-task + prompt: "Review the code changes in {sliceId}/{taskId} for quality, security, and test coverage." + max_cycles: 1 + artifact: REVIEW.md +--- +``` + +Runs an automated code review after each task execution. Skips if `REVIEW.md` already exists (idempotent). 
/**
 * Check whether a string is shaped like a model identifier.
 *
 * The value is trimmed first; the trimmed text must be 1-256 characters long
 * and consist only of ASCII letters, digits, `-`, `_`, `.`, `/` and `:`.
 *
 * @param modelId Candidate identifier.
 * @returns `true` when the trimmed value passes the length and character checks.
 */
export function validateModelId(modelId: string): boolean {
  // Runtime guard: callers may hand over non-string or empty values.
  if (typeof modelId !== "string" || !modelId) {
    return false;
  }

  const candidate = modelId.trim();
  const withinBounds = candidate.length !== 0 && candidate.length <= 256;
  if (!withinBounds) {
    return false;
  }

  // Allow alphanumeric, hyphens, underscores, dots, slashes, colons
  return /^[a-zA-Z0-9\-_./:]+$/.test(candidate);
}
/**
 * Render a models config as the lines of a top-level `models:` YAML block.
 *
 * String values become `phase: value`; object values become a nested mapping
 * with `model`, optional `provider`, and an optional `fallbacks` list.
 */
function renderModelsBlock(models: GSDModelConfigV2): string[] {
  const lines: string[] = ["models:"];
  for (const [phase, value] of Object.entries(models)) {
    if (typeof value === "string") {
      lines.push(`  ${phase}: ${value}`);
    } else if (value && typeof value === "object") {
      const config = value as GSDPhaseModelConfig;
      lines.push(`  ${phase}:`);
      lines.push(`    model: ${config.model}`);
      if (config.provider) {
        lines.push(`    provider: ${config.provider}`);
      }
      if (config.fallbacks && config.fallbacks.length > 0) {
        lines.push(`    fallbacks:`);
        for (const fb of config.fallbacks) {
          lines.push(`      - ${fb}`);
        }
      }
    }
  }
  return lines;
}

/**
 * Return `content` with its `models:` block replaced by (or extended with) `block`.
 *
 * Works line by line instead of with a multi-line regex, so the whole existing
 * block, including every indented child line, is replaced.
 */
function mergeModelsBlock(content: string, block: string[]): string {
  // Nothing on disk yet: emit a minimal frontmatter document.
  // NOTE(review): assumes the loader reads preferences from leading `---`
  // frontmatter, as the preferences template lays them out. Confirm.
  if (content.trim() === "") {
    return ["---", ...block, "---", ""].join("\n");
  }

  const isBlank = (line: string | undefined): boolean => (line ?? "").trim() === "";
  const isIndented = (line: string | undefined): boolean => /^[ \t]/.test(line ?? "");
  const isFence = (line: string | undefined): boolean => (line ?? "").trim() === "---";

  const lines = content.split("\n");

  // Locate the frontmatter (opening fence on the first line plus a closing fence).
  const fenceClose = isFence(lines[0]) ? lines.findIndex((line, i) => i > 0 && isFence(line)) : -1;
  const hasFrontmatter = fenceClose !== -1;
  const searchFrom = hasFrontmatter ? 1 : 0;
  const searchTo = hasFrontmatter ? fenceClose : lines.length;

  let start = -1;
  for (let i = searchFrom; i < searchTo; i++) {
    if ((lines[i] ?? "").startsWith("models:")) {
      start = i;
      break;
    }
  }

  if (start !== -1) {
    // The block runs over every following blank or indented line ...
    let end = start + 1;
    while (end < searchTo && (isBlank(lines[end]) || isIndented(lines[end]))) end++;
    // ... but blank lines trailing the block belong to the surrounding layout.
    while (end > start + 1 && isBlank(lines[end - 1])) end--;
    lines.splice(start, end - start, ...block);
    return lines.join("\n");
  }

  if (hasFrontmatter) {
    // No models key yet: add it as the last frontmatter entry, not after the body.
    lines.splice(fenceClose, 0, ...block);
    return lines.join("\n");
  }

  // Plain file without frontmatter: append, as before.
  return content.trimEnd() + "\n\n" + block.join("\n") + "\n";
}

/**
 * Update the models section of the global GSD preferences file.
 *
 * Reads the current file (if any), replaces the existing top-level `models:`
 * block, or inserts one when absent, and writes the result back. The read and
 * write are separate synchronous calls; no locking is performed.
 *
 * Fixes over the previous regex-based replacement:
 * - the old pattern matched only the `models:` line itself, leaving the
 *   previous nested entries behind underneath the new block;
 * - a missing block was appended after the Markdown body even when the file
 *   keeps its settings in `---` frontmatter.
 *
 * @param models Per-phase model configuration to persist.
 */
export function updatePreferencesModels(models: GSDModelConfigV2): void {
  const prefsPath = getGlobalGSDPreferencesPath();

  let content = "";
  if (existsSync(prefsPath)) {
    content = readFileSync(prefsPath, "utf-8");
  }

  const updated = mergeModelsBlock(content, renderModelsBlock(models));
  writeFileSync(prefsPath, updated, "utf-8");
}
/**
 * Run every wizard-field validation scenario, then print the summary.
 *
 * Each scenario lives in its own block scope, feeds a preferences object to
 * `validatePreferences`, and records its expectations through `assertEq` /
 * `assertTrue` from the shared test context. `report()` is called once at the
 * end, after all scenarios have run.
 */
async function main(): Promise<void> {
  console.log("\n=== budget fields validate correctly ===");

  // All three budget-related fields together.
  {
    const { preferences, errors } = validatePreferences({
      budget_ceiling: 25.50,
      budget_enforcement: "warn",
      context_pause_threshold: 80,
    });
    assertEq(errors.length, 0, "valid budget fields produce no errors");
    assertEq(preferences.budget_ceiling, 25.50, "budget_ceiling passes through");
    assertEq(preferences.budget_enforcement, "warn", "budget_enforcement passes through");
    assertEq(preferences.context_pause_threshold, 80, "context_pause_threshold passes through");
  }

  {
    const { preferences, errors } = validatePreferences({
      budget_enforcement: "pause",
    });
    assertEq(errors.length, 0, "budget_enforcement 'pause' is valid");
    assertEq(preferences.budget_enforcement, "pause", "pause passes through");
  }

  {
    const { preferences, errors } = validatePreferences({
      budget_enforcement: "halt",
    });
    assertEq(errors.length, 0, "budget_enforcement 'halt' is valid");
    assertEq(preferences.budget_enforcement, "halt", "halt passes through");
  }

  // The cast bypasses the static type so the runtime validation is what gets exercised.
  {
    const { errors } = validatePreferences({
      budget_enforcement: "invalid",
    } as unknown as GSDPreferences);
    assertTrue(errors.some(e => e.includes("budget_enforcement")), "invalid budget_enforcement rejected");
  }

  console.log("\n=== notification fields validate correctly ===");

  {
    const { preferences, errors } = validatePreferences({
      notifications: {
        enabled: true,
        on_complete: false,
        on_error: true,
        on_budget: true,
        on_milestone: false,
        on_attention: true,
      },
    });
    assertEq(errors.length, 0, "valid notifications produce no errors");
    assertEq(preferences.notifications?.enabled, true, "notifications.enabled passes through");
    assertEq(preferences.notifications?.on_complete, false, "notifications.on_complete passes through");
    assertEq(preferences.notifications?.on_milestone, false, "notifications.on_milestone passes through");
  }

  // A non-object `notifications` value must be reported as an error.
  {
    const { errors } = validatePreferences({
      notifications: "invalid",
    } as unknown as GSDPreferences);
    assertTrue(errors.some(e => e.includes("notifications")), "invalid notifications rejected");
  }

  console.log("\n=== git fields validate correctly ===");

  {
    const { preferences, errors } = validatePreferences({
      git: {
        auto_push: true,
        push_branches: false,
        remote: "upstream",
        snapshots: true,
        pre_merge_check: "auto",
        commit_type: "feat",
        main_branch: "develop",
        merge_strategy: "squash",
        isolation: "branch",
      },
    });
    assertEq(errors.length, 0, "valid git fields produce no errors");
    assertEq(preferences.git?.auto_push, true, "git.auto_push passes through");
    assertEq(preferences.git?.push_branches, false, "git.push_branches passes through");
    assertEq(preferences.git?.remote, "upstream", "git.remote passes through");
    assertEq(preferences.git?.snapshots, true, "git.snapshots passes through");
    assertEq(preferences.git?.pre_merge_check, "auto", "git.pre_merge_check passes through");
    assertEq(preferences.git?.commit_type, "feat", "git.commit_type passes through");
    assertEq(preferences.git?.main_branch, "develop", "git.main_branch passes through");
    assertEq(preferences.git?.merge_strategy, "squash", "git.merge_strategy passes through");
    assertEq(preferences.git?.isolation, "branch", "git.isolation passes through");
  }

  console.log("\n=== uat_dispatch validates correctly ===");

  {
    const { preferences, errors } = validatePreferences({ uat_dispatch: true });
    assertEq(errors.length, 0, "valid uat_dispatch produces no errors");
    assertEq(preferences.uat_dispatch, true, "uat_dispatch true passes through");
  }

  {
    const { preferences, errors } = validatePreferences({ uat_dispatch: false });
    assertEq(errors.length, 0, "valid uat_dispatch false produces no errors");
    assertEq(preferences.uat_dispatch, false, "uat_dispatch false passes through");
  }

  console.log("\n=== unique_milestone_ids validates correctly ===");

  {
    const { preferences, errors } = validatePreferences({ unique_milestone_ids: true });
    assertEq(errors.length, 0, "valid unique_milestone_ids produces no errors");
    assertEq(preferences.unique_milestone_ids, true, "unique_milestone_ids passes through");
  }

  console.log("\n=== all wizard fields together produce no errors ===");

  // Everything at once: must yield neither errors nor "unknown" warnings.
  {
    const fullPrefs: GSDPreferences = {
      version: 1,
      models: { research: "claude-opus-4-6", planning: "claude-sonnet-4-6" },
      auto_supervisor: { soft_timeout_minutes: 15, idle_timeout_minutes: 5, hard_timeout_minutes: 25 },
      git: {
        main_branch: "main",
        auto_push: true,
        push_branches: false,
        remote: "origin",
        snapshots: true,
        pre_merge_check: "auto",
        commit_type: "feat",
        merge_strategy: "squash",
        isolation: "worktree",
      },
      skill_discovery: "suggest",
      unique_milestone_ids: false,
      budget_ceiling: 50,
      budget_enforcement: "pause",
      context_pause_threshold: 75,
      notifications: {
        enabled: true,
        on_complete: true,
        on_error: true,
        on_budget: true,
        on_milestone: true,
        on_attention: true,
      },
      uat_dispatch: false,
    };
    const { errors, warnings } = validatePreferences(fullPrefs);
    const unknownWarnings = warnings.filter(w => w.includes("unknown"));
    assertEq(errors.length, 0, "full wizard prefs produce no errors");
    assertEq(unknownWarnings.length, 0, "full wizard prefs produce no unknown-key warnings");
  }

  report();
}

main();