feat: dynamic model discovery & provider management UX (#581)
This commit is contained in:
parent
570f6195be
commit
9ed812ed54
25 changed files with 2122 additions and 23 deletions
27
.plans/dynamic-model-discovery.md
Normal file
27
.plans/dynamic-model-discovery.md
Normal file
|
|
@ -0,0 +1,27 @@
|
||||||
|
# Dynamic Model Discovery
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
Runtime model discovery from provider APIs with caching, TUI management, and CLI flags.
|
||||||
|
|
||||||
|
## Components
|
||||||
|
1. **model-discovery.ts** — Provider adapters (OpenAI, Ollama, OpenRouter, Google) + static adapters
|
||||||
|
2. **discovery-cache.ts** — Disk cache at `{agentDir}/discovery-cache.json` with per-provider TTLs
|
||||||
|
3. **models-json-writer.ts** — Safe read-modify-write for `models.json` with file locking
|
||||||
|
4. **provider-manager.ts** — TUI component for provider management (`/provider` command)
|
||||||
|
5. **model-registry.ts** — Extended with `discoverModels()`, `getAllWithDiscovered()`, cache integration
|
||||||
|
6. **settings-manager.ts** — `modelDiscovery` settings (enabled, providers, ttlMinutes, autoRefreshOnModelSelect)
|
||||||
|
7. **args.ts** — `--discover`, `--add-provider`, `--base-url`, `--discover-models` CLI flags
|
||||||
|
8. **list-models.ts** — Rewritten with `[discovered]` badge support
|
||||||
|
9. **main.ts** — CLI handlers for new flags
|
||||||
|
10. **interactive-mode.ts** — `/provider` command handler
|
||||||
|
11. **preferences.ts** — `updatePreferencesModels()` and `validateModelId()` helpers
|
||||||
|
|
||||||
|
## TTL Strategy
|
||||||
|
- Ollama: 5 min (local, models change often)
|
||||||
|
- OpenAI / Google / OpenRouter: 1 hour
|
||||||
|
- Default: 24 hours
|
||||||
|
|
||||||
|
## Merge Rules
|
||||||
|
- Discovered models never override existing built-in or custom models
|
||||||
|
- Discovered models are appended to the registry with `[discovered]` badge
|
||||||
|
- Background discovery is opt-in via `modelDiscovery.enabled` setting
|
||||||
49
.plans/preferences-wizard-completeness.md
Normal file
49
.plans/preferences-wizard-completeness.md
Normal file
|
|
@ -0,0 +1,49 @@
|
||||||
|
# Preferences Wizard Completeness
|
||||||
|
|
||||||
|
## Problem
|
||||||
|
The `/gsd prefs wizard` currently only configures 6 of 18+ preference fields. Users must hand-edit YAML for the rest.
|
||||||
|
|
||||||
|
## Current Wizard Coverage
|
||||||
|
1. Models (per phase) ✓
|
||||||
|
2. Auto-supervisor timeouts ✓
|
||||||
|
3. Git main_branch ✓
|
||||||
|
4. Skill discovery mode ✓
|
||||||
|
5. Unique milestone IDs ✓
|
||||||
|
|
||||||
|
## Missing Fields to Add
|
||||||
|
|
||||||
|
### Group 1: Git Settings (expand existing section)
|
||||||
|
- `auto_push` (boolean) — auto-push commits ✓
|
||||||
|
- `push_branches` (boolean) — push milestone branches ✓
|
||||||
|
- `remote` (string) — git remote name ✓
|
||||||
|
- `snapshots` (boolean) — WIP snapshot commits ✓
|
||||||
|
- `pre_merge_check` (boolean | "auto") — pre-merge validation ✓
|
||||||
|
- `commit_type` (select) — conventional commit prefix ✓
|
||||||
|
- `merge_strategy` (select) — squash vs merge ✓
|
||||||
|
- `isolation` (select) — worktree vs branch ✓
|
||||||
|
|
||||||
|
### Group 2: Budget & Cost Control ✓
|
||||||
|
- `budget_ceiling` (number) — dollar limit
|
||||||
|
- `budget_enforcement` (select: warn/pause/halt)
|
||||||
|
- `context_pause_threshold` (number 0-100)
|
||||||
|
|
||||||
|
### Group 3: Notifications ✓
|
||||||
|
- `notifications.enabled` (boolean)
|
||||||
|
- `notifications.on_complete` (boolean)
|
||||||
|
- `notifications.on_error` (boolean)
|
||||||
|
- `notifications.on_budget` (boolean)
|
||||||
|
- `notifications.on_milestone` (boolean)
|
||||||
|
- `notifications.on_attention` (boolean)
|
||||||
|
|
||||||
|
### Group 4: Behavior Toggles ✓
|
||||||
|
- `uat_dispatch` (boolean)
|
||||||
|
|
||||||
|
### Group 5: Update Serialization Order ✓
|
||||||
|
- Added missing keys to `orderedKeys` in `serializePreferencesToFrontmatter()`
|
||||||
|
|
||||||
|
### Group 6: Update Template & Docs ✓
|
||||||
|
- Updated `templates/preferences.md` with new fields
|
||||||
|
- Updated `docs/preferences-reference.md` with budget, notifications, git, hooks
|
||||||
|
|
||||||
|
### Group 7: Tests ✓
|
||||||
|
- Added `preferences-wizard-fields.test.ts` covering all new fields
|
||||||
|
|
@ -38,6 +38,11 @@ export interface Args {
|
||||||
themes?: string[];
|
themes?: string[];
|
||||||
noThemes?: boolean;
|
noThemes?: boolean;
|
||||||
listModels?: string | true;
|
listModels?: string | true;
|
||||||
|
discover?: boolean;
|
||||||
|
addProvider?: string;
|
||||||
|
addProviderBaseUrl?: string;
|
||||||
|
addProviderApiKey?: string;
|
||||||
|
discoverModels?: string | true;
|
||||||
offline?: boolean;
|
offline?: boolean;
|
||||||
verbose?: boolean;
|
verbose?: boolean;
|
||||||
messages: string[];
|
messages: string[];
|
||||||
|
|
@ -150,6 +155,18 @@ export function parseArgs(args: string[], extensionFlags?: Map<string, { type: "
|
||||||
} else {
|
} else {
|
||||||
result.listModels = true;
|
result.listModels = true;
|
||||||
}
|
}
|
||||||
|
} else if (arg === "--discover") {
|
||||||
|
result.discover = true;
|
||||||
|
} else if (arg === "--add-provider" && i + 1 < args.length) {
|
||||||
|
result.addProvider = args[++i];
|
||||||
|
} else if (arg === "--base-url" && i + 1 < args.length) {
|
||||||
|
result.addProviderBaseUrl = args[++i];
|
||||||
|
} else if (arg === "--discover-models") {
|
||||||
|
if (i + 1 < args.length && !args[i + 1].startsWith("-") && !args[i + 1].startsWith("@")) {
|
||||||
|
result.discoverModels = args[++i];
|
||||||
|
} else {
|
||||||
|
result.discoverModels = true;
|
||||||
|
}
|
||||||
} else if (arg === "--verbose") {
|
} else if (arg === "--verbose") {
|
||||||
result.verbose = true;
|
result.verbose = true;
|
||||||
} else if (arg === "--offline") {
|
} else if (arg === "--offline") {
|
||||||
|
|
@ -219,6 +236,10 @@ ${chalk.bold("Options:")}
|
||||||
--no-themes Disable theme discovery and loading
|
--no-themes Disable theme discovery and loading
|
||||||
--export <file> Export session file to HTML and exit
|
--export <file> Export session file to HTML and exit
|
||||||
--list-models [search] List available models (with optional fuzzy search)
|
--list-models [search] List available models (with optional fuzzy search)
|
||||||
|
--discover Include discovered models in --list-models output
|
||||||
|
--discover-models [provider] Discover models from provider APIs (all or specific)
|
||||||
|
--add-provider <name> Add a provider to models.json (use with --base-url, --api-key)
|
||||||
|
--base-url <url> Base URL for --add-provider
|
||||||
--verbose Force verbose startup (overrides quietStartup setting)
|
--verbose Force verbose startup (overrides quietStartup setting)
|
||||||
--offline Disable startup network operations (same as PI_OFFLINE=1)
|
--offline Disable startup network operations (same as PI_OFFLINE=1)
|
||||||
--help, -h Show this help
|
--help, -h Show this help
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,18 @@
|
||||||
/**
|
/**
|
||||||
* List available models with optional fuzzy search
|
* List available models with optional fuzzy search and discovery support
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import type { Api, Model } from "@gsd/pi-ai";
|
import type { Api, Model } from "@gsd/pi-ai";
|
||||||
import { fuzzyFilter } from "@gsd/pi-tui";
|
import { fuzzyFilter } from "@gsd/pi-tui";
|
||||||
import type { ModelRegistry } from "../core/model-registry.js";
|
import type { ModelRegistry } from "../core/model-registry.js";
|
||||||
|
|
||||||
|
export interface ListModelsOptions {
|
||||||
|
/** Include discovered models in output */
|
||||||
|
discover?: boolean;
|
||||||
|
/** Search pattern for fuzzy filtering */
|
||||||
|
searchPattern?: string;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Format a number as human-readable (e.g., 200000 -> "200K", 1000000 -> "1M")
|
* Format a number as human-readable (e.g., 200000 -> "200K", 1000000 -> "1M")
|
||||||
*/
|
*/
|
||||||
|
|
@ -22,10 +29,48 @@ function formatTokenCount(count: number): string {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* List available models, optionally filtered by search pattern
|
* Discover models from provider APIs and print results.
|
||||||
*/
|
*/
|
||||||
export async function listModels(modelRegistry: ModelRegistry, searchPattern?: string): Promise<void> {
|
export async function discoverAndPrintModels(
|
||||||
const models = modelRegistry.getAvailable();
|
modelRegistry: ModelRegistry,
|
||||||
|
provider?: string,
|
||||||
|
): Promise<void> {
|
||||||
|
const providers = provider ? [provider] : undefined;
|
||||||
|
|
||||||
|
console.log("Discovering models...");
|
||||||
|
const results = await modelRegistry.discoverModels(providers);
|
||||||
|
|
||||||
|
for (const result of results) {
|
||||||
|
if (result.error) {
|
||||||
|
console.log(` ${result.provider}: error - ${result.error}`);
|
||||||
|
} else {
|
||||||
|
console.log(` ${result.provider}: ${result.models.length} models found`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* List available models, optionally filtered by search pattern.
|
||||||
|
* Accepts either a string (backward compat) or ListModelsOptions.
|
||||||
|
*/
|
||||||
|
export async function listModels(
|
||||||
|
modelRegistry: ModelRegistry,
|
||||||
|
optionsOrSearch?: string | ListModelsOptions,
|
||||||
|
): Promise<void> {
|
||||||
|
const options: ListModelsOptions =
|
||||||
|
typeof optionsOrSearch === "string"
|
||||||
|
? { searchPattern: optionsOrSearch }
|
||||||
|
: optionsOrSearch ?? {};
|
||||||
|
|
||||||
|
// If discover flag is set, run discovery first
|
||||||
|
if (options.discover) {
|
||||||
|
await modelRegistry.discoverModels();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get models — include discovered if discovery was run
|
||||||
|
const models = options.discover
|
||||||
|
? modelRegistry.getAllWithDiscovered()
|
||||||
|
: modelRegistry.getAvailable();
|
||||||
|
|
||||||
if (models.length === 0) {
|
if (models.length === 0) {
|
||||||
console.log("No models available. Set API keys in environment variables.");
|
console.log("No models available. Set API keys in environment variables.");
|
||||||
|
|
@ -34,12 +79,12 @@ export async function listModels(modelRegistry: ModelRegistry, searchPattern?: s
|
||||||
|
|
||||||
// Apply fuzzy filter if search pattern provided
|
// Apply fuzzy filter if search pattern provided
|
||||||
let filteredModels: Model<Api>[] = models;
|
let filteredModels: Model<Api>[] = models;
|
||||||
if (searchPattern) {
|
if (options.searchPattern) {
|
||||||
filteredModels = fuzzyFilter(models, searchPattern, (m) => `${m.provider} ${m.id}`);
|
filteredModels = fuzzyFilter(models, options.searchPattern, (m) => `${m.provider} ${m.id}`);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (filteredModels.length === 0) {
|
if (filteredModels.length === 0) {
|
||||||
console.log(`No models matching "${searchPattern}"`);
|
console.log(`No models matching "${options.searchPattern}"`);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -53,15 +98,19 @@ export async function listModels(modelRegistry: ModelRegistry, searchPattern?: s
|
||||||
});
|
});
|
||||||
|
|
||||||
// Calculate column widths
|
// Calculate column widths
|
||||||
const rows = filteredModels.map((m) => ({
|
const rows = filteredModels.map((m) => {
|
||||||
provider: m.provider,
|
const isDiscovered = options.discover && modelRegistry.isDiscovered(m);
|
||||||
model: m.id,
|
return {
|
||||||
name: m.name,
|
provider: m.provider,
|
||||||
context: formatTokenCount(m.contextWindow),
|
model: m.id,
|
||||||
maxOut: formatTokenCount(m.maxTokens),
|
name: m.name,
|
||||||
thinking: m.reasoning ? "yes" : "no",
|
context: formatTokenCount(m.contextWindow),
|
||||||
images: m.input.includes("image") ? "yes" : "no",
|
maxOut: formatTokenCount(m.maxTokens),
|
||||||
}));
|
thinking: m.reasoning ? "yes" : "no",
|
||||||
|
images: m.input.includes("image") ? "yes" : "no",
|
||||||
|
badge: isDiscovered ? "[discovered]" : "",
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
const headers = {
|
const headers = {
|
||||||
provider: "provider",
|
provider: "provider",
|
||||||
|
|
@ -71,6 +120,7 @@ export async function listModels(modelRegistry: ModelRegistry, searchPattern?: s
|
||||||
maxOut: "max-out",
|
maxOut: "max-out",
|
||||||
thinking: "thinking",
|
thinking: "thinking",
|
||||||
images: "images",
|
images: "images",
|
||||||
|
badge: "",
|
||||||
};
|
};
|
||||||
|
|
||||||
const widths = {
|
const widths = {
|
||||||
|
|
@ -105,7 +155,10 @@ export async function listModels(modelRegistry: ModelRegistry, searchPattern?: s
|
||||||
row.maxOut.padEnd(widths.maxOut),
|
row.maxOut.padEnd(widths.maxOut),
|
||||||
row.thinking.padEnd(widths.thinking),
|
row.thinking.padEnd(widths.thinking),
|
||||||
row.images.padEnd(widths.images),
|
row.images.padEnd(widths.images),
|
||||||
].join(" ");
|
row.badge,
|
||||||
|
]
|
||||||
|
.join(" ")
|
||||||
|
.trimEnd();
|
||||||
console.log(line);
|
console.log(line);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
170
packages/pi-coding-agent/src/core/discovery-cache.test.ts
Normal file
170
packages/pi-coding-agent/src/core/discovery-cache.test.ts
Normal file
|
|
@ -0,0 +1,170 @@
|
||||||
|
import assert from "node:assert/strict";
|
||||||
|
import { existsSync, mkdirSync, rmSync, writeFileSync } from "node:fs";
|
||||||
|
import { tmpdir } from "node:os";
|
||||||
|
import { join } from "node:path";
|
||||||
|
import { afterEach, beforeEach, describe, it } from "node:test";
|
||||||
|
import { ModelDiscoveryCache } from "./discovery-cache.js";
|
||||||
|
|
||||||
|
let testDir: string;
|
||||||
|
let cachePath: string;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
testDir = join(tmpdir(), `discovery-cache-test-${Date.now()}-${Math.random().toString(36).slice(2)}`);
|
||||||
|
mkdirSync(testDir, { recursive: true });
|
||||||
|
cachePath = join(testDir, "discovery-cache.json");
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
try {
|
||||||
|
rmSync(testDir, { recursive: true, force: true });
|
||||||
|
} catch {
|
||||||
|
// Cleanup best-effort
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// ─── basic operations ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
describe("ModelDiscoveryCache — basic operations", () => {
|
||||||
|
it("starts with no entries", () => {
|
||||||
|
const cache = new ModelDiscoveryCache(cachePath);
|
||||||
|
assert.equal(cache.get("openai"), undefined);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("stores and retrieves models", () => {
|
||||||
|
const cache = new ModelDiscoveryCache(cachePath);
|
||||||
|
const models = [{ id: "gpt-4o", name: "GPT-4o" }];
|
||||||
|
cache.set("openai", models);
|
||||||
|
|
||||||
|
const entry = cache.get("openai");
|
||||||
|
assert.ok(entry);
|
||||||
|
assert.deepEqual(entry.models, models);
|
||||||
|
assert.ok(entry.fetchedAt > 0);
|
||||||
|
assert.ok(entry.ttlMs > 0);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("persists to disk and reloads", () => {
|
||||||
|
const cache1 = new ModelDiscoveryCache(cachePath);
|
||||||
|
cache1.set("openai", [{ id: "gpt-4o" }]);
|
||||||
|
|
||||||
|
const cache2 = new ModelDiscoveryCache(cachePath);
|
||||||
|
const entry = cache2.get("openai");
|
||||||
|
assert.ok(entry);
|
||||||
|
assert.equal(entry.models[0].id, "gpt-4o");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("clear removes a specific provider", () => {
|
||||||
|
const cache = new ModelDiscoveryCache(cachePath);
|
||||||
|
cache.set("openai", [{ id: "gpt-4o" }]);
|
||||||
|
cache.set("google", [{ id: "gemini-pro" }]);
|
||||||
|
|
||||||
|
cache.clear("openai");
|
||||||
|
assert.equal(cache.get("openai"), undefined);
|
||||||
|
assert.ok(cache.get("google"));
|
||||||
|
});
|
||||||
|
|
||||||
|
it("clear without provider removes all entries", () => {
|
||||||
|
const cache = new ModelDiscoveryCache(cachePath);
|
||||||
|
cache.set("openai", [{ id: "gpt-4o" }]);
|
||||||
|
cache.set("google", [{ id: "gemini-pro" }]);
|
||||||
|
|
||||||
|
cache.clear();
|
||||||
|
assert.equal(cache.get("openai"), undefined);
|
||||||
|
assert.equal(cache.get("google"), undefined);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// ─── staleness ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
describe("ModelDiscoveryCache — staleness", () => {
|
||||||
|
it("newly set entries are not stale", () => {
|
||||||
|
const cache = new ModelDiscoveryCache(cachePath);
|
||||||
|
cache.set("openai", [{ id: "gpt-4o" }]);
|
||||||
|
assert.equal(cache.isStale("openai"), false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("missing providers are stale", () => {
|
||||||
|
const cache = new ModelDiscoveryCache(cachePath);
|
||||||
|
assert.equal(cache.isStale("unknown"), true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("entries with expired TTL are stale", () => {
|
||||||
|
const cache = new ModelDiscoveryCache(cachePath);
|
||||||
|
cache.set("openai", [{ id: "gpt-4o" }], 1); // 1ms TTL
|
||||||
|
|
||||||
|
// Wait for TTL to expire
|
||||||
|
const start = Date.now();
|
||||||
|
while (Date.now() - start < 5) {
|
||||||
|
// busy wait
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.equal(cache.isStale("openai"), true);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// ─── getAll ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
describe("ModelDiscoveryCache — getAll", () => {
|
||||||
|
it("returns non-stale entries by default", () => {
|
||||||
|
const cache = new ModelDiscoveryCache(cachePath);
|
||||||
|
cache.set("openai", [{ id: "gpt-4o" }]);
|
||||||
|
cache.set("stale", [{ id: "old" }], 1);
|
||||||
|
|
||||||
|
// Wait for stale TTL
|
||||||
|
const start = Date.now();
|
||||||
|
while (Date.now() - start < 5) {
|
||||||
|
// busy wait
|
||||||
|
}
|
||||||
|
|
||||||
|
const all = cache.getAll();
|
||||||
|
assert.ok(all.has("openai"));
|
||||||
|
assert.ok(!all.has("stale"));
|
||||||
|
});
|
||||||
|
|
||||||
|
it("returns all entries when includeStale is true", () => {
|
||||||
|
const cache = new ModelDiscoveryCache(cachePath);
|
||||||
|
cache.set("openai", [{ id: "gpt-4o" }]);
|
||||||
|
cache.set("stale", [{ id: "old" }], 1);
|
||||||
|
|
||||||
|
// Wait for stale TTL
|
||||||
|
const start = Date.now();
|
||||||
|
while (Date.now() - start < 5) {
|
||||||
|
// busy wait
|
||||||
|
}
|
||||||
|
|
||||||
|
const all = cache.getAll(true);
|
||||||
|
assert.ok(all.has("openai"));
|
||||||
|
assert.ok(all.has("stale"));
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// ─── edge cases ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
describe("ModelDiscoveryCache — edge cases", () => {
|
||||||
|
it("handles corrupted cache file gracefully", () => {
|
||||||
|
writeFileSync(cachePath, "not valid json", "utf-8");
|
||||||
|
const cache = new ModelDiscoveryCache(cachePath);
|
||||||
|
assert.equal(cache.get("openai"), undefined);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("handles wrong version gracefully", () => {
|
||||||
|
writeFileSync(cachePath, JSON.stringify({ version: 99, entries: {} }), "utf-8");
|
||||||
|
const cache = new ModelDiscoveryCache(cachePath);
|
||||||
|
assert.equal(cache.get("openai"), undefined);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("handles missing cache file", () => {
|
||||||
|
const cache = new ModelDiscoveryCache(join(testDir, "nonexistent", "cache.json"));
|
||||||
|
assert.equal(cache.get("openai"), undefined);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("overwrites existing entry for same provider", () => {
|
||||||
|
const cache = new ModelDiscoveryCache(cachePath);
|
||||||
|
cache.set("openai", [{ id: "gpt-4o" }]);
|
||||||
|
cache.set("openai", [{ id: "gpt-4o-mini" }]);
|
||||||
|
|
||||||
|
const entry = cache.get("openai");
|
||||||
|
assert.ok(entry);
|
||||||
|
assert.equal(entry.models.length, 1);
|
||||||
|
assert.equal(entry.models[0].id, "gpt-4o-mini");
|
||||||
|
});
|
||||||
|
});
|
||||||
97
packages/pi-coding-agent/src/core/discovery-cache.ts
Normal file
97
packages/pi-coding-agent/src/core/discovery-cache.ts
Normal file
|
|
@ -0,0 +1,97 @@
|
||||||
|
/**
|
||||||
|
* Disk-based cache for discovered models.
|
||||||
|
* Stores results at {agentDir}/discovery-cache.json with per-provider TTLs.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
|
||||||
|
import { dirname, join } from "path";
|
||||||
|
import { getAgentDir } from "../config.js";
|
||||||
|
import { type DiscoveredModel, getDefaultTTL } from "./model-discovery.js";
|
||||||
|
|
||||||
|
export interface DiscoveryCacheEntry {
|
||||||
|
models: DiscoveredModel[];
|
||||||
|
fetchedAt: number;
|
||||||
|
ttlMs: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface DiscoveryCacheData {
|
||||||
|
version: 1;
|
||||||
|
entries: Record<string, DiscoveryCacheEntry>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export class ModelDiscoveryCache {
|
||||||
|
private data: DiscoveryCacheData;
|
||||||
|
private cachePath: string;
|
||||||
|
|
||||||
|
constructor(cachePath?: string) {
|
||||||
|
this.cachePath = cachePath ?? join(getAgentDir(), "discovery-cache.json");
|
||||||
|
this.data = { version: 1, entries: {} };
|
||||||
|
this.load();
|
||||||
|
}
|
||||||
|
|
||||||
|
get(provider: string): DiscoveryCacheEntry | undefined {
|
||||||
|
const entry = this.data.entries[provider];
|
||||||
|
return entry;
|
||||||
|
}
|
||||||
|
|
||||||
|
set(provider: string, models: DiscoveredModel[], ttlMs?: number): void {
|
||||||
|
this.data.entries[provider] = {
|
||||||
|
models,
|
||||||
|
fetchedAt: Date.now(),
|
||||||
|
ttlMs: ttlMs ?? getDefaultTTL(provider),
|
||||||
|
};
|
||||||
|
this.save();
|
||||||
|
}
|
||||||
|
|
||||||
|
isStale(provider: string): boolean {
|
||||||
|
const entry = this.data.entries[provider];
|
||||||
|
if (!entry) return true;
|
||||||
|
return Date.now() - entry.fetchedAt > entry.ttlMs;
|
||||||
|
}
|
||||||
|
|
||||||
|
clear(provider?: string): void {
|
||||||
|
if (provider) {
|
||||||
|
delete this.data.entries[provider];
|
||||||
|
} else {
|
||||||
|
this.data.entries = {};
|
||||||
|
}
|
||||||
|
this.save();
|
||||||
|
}
|
||||||
|
|
||||||
|
getAll(includeStale = false): Map<string, DiscoveryCacheEntry> {
|
||||||
|
const result = new Map<string, DiscoveryCacheEntry>();
|
||||||
|
for (const [provider, entry] of Object.entries(this.data.entries)) {
|
||||||
|
if (includeStale || !this.isStale(provider)) {
|
||||||
|
result.set(provider, entry);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
load(): void {
|
||||||
|
try {
|
||||||
|
if (existsSync(this.cachePath)) {
|
||||||
|
const content = readFileSync(this.cachePath, "utf-8");
|
||||||
|
const parsed = JSON.parse(content) as DiscoveryCacheData;
|
||||||
|
if (parsed.version === 1 && parsed.entries) {
|
||||||
|
this.data = parsed;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// Corrupted or unreadable cache — start fresh
|
||||||
|
this.data = { version: 1, entries: {} };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
save(): void {
|
||||||
|
try {
|
||||||
|
const dir = dirname(this.cachePath);
|
||||||
|
if (!existsSync(dir)) {
|
||||||
|
mkdirSync(dir, { recursive: true });
|
||||||
|
}
|
||||||
|
writeFileSync(this.cachePath, JSON.stringify(this.data, null, 2), "utf-8");
|
||||||
|
} catch {
|
||||||
|
// Silently ignore write failures (read-only FS, permissions, etc.)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
125
packages/pi-coding-agent/src/core/model-discovery.test.ts
Normal file
125
packages/pi-coding-agent/src/core/model-discovery.test.ts
Normal file
|
|
@ -0,0 +1,125 @@
|
||||||
|
import assert from "node:assert/strict";
|
||||||
|
import { describe, it } from "node:test";
|
||||||
|
import {
|
||||||
|
DISCOVERY_TTLS,
|
||||||
|
getDefaultTTL,
|
||||||
|
getDiscoverableProviders,
|
||||||
|
getDiscoveryAdapter,
|
||||||
|
} from "./model-discovery.js";
|
||||||
|
|
||||||
|
// ─── getDiscoveryAdapter ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
describe("getDiscoveryAdapter", () => {
|
||||||
|
it("returns an adapter for openai", () => {
|
||||||
|
const adapter = getDiscoveryAdapter("openai");
|
||||||
|
assert.equal(adapter.provider, "openai");
|
||||||
|
assert.equal(adapter.supportsDiscovery, true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("returns an adapter for ollama", () => {
|
||||||
|
const adapter = getDiscoveryAdapter("ollama");
|
||||||
|
assert.equal(adapter.provider, "ollama");
|
||||||
|
assert.equal(adapter.supportsDiscovery, true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("returns an adapter for openrouter", () => {
|
||||||
|
const adapter = getDiscoveryAdapter("openrouter");
|
||||||
|
assert.equal(adapter.provider, "openrouter");
|
||||||
|
assert.equal(adapter.supportsDiscovery, true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("returns an adapter for google", () => {
|
||||||
|
const adapter = getDiscoveryAdapter("google");
|
||||||
|
assert.equal(adapter.provider, "google");
|
||||||
|
assert.equal(adapter.supportsDiscovery, true);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("returns a static adapter for anthropic", () => {
|
||||||
|
const adapter = getDiscoveryAdapter("anthropic");
|
||||||
|
assert.equal(adapter.provider, "anthropic");
|
||||||
|
assert.equal(adapter.supportsDiscovery, false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("returns a static adapter for bedrock", () => {
|
||||||
|
const adapter = getDiscoveryAdapter("bedrock");
|
||||||
|
assert.equal(adapter.provider, "bedrock");
|
||||||
|
assert.equal(adapter.supportsDiscovery, false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("returns a static adapter for unknown providers", () => {
|
||||||
|
const adapter = getDiscoveryAdapter("unknown-provider");
|
||||||
|
assert.equal(adapter.provider, "unknown-provider");
|
||||||
|
assert.equal(adapter.supportsDiscovery, false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("static adapter fetchModels returns empty array", async () => {
|
||||||
|
const adapter = getDiscoveryAdapter("anthropic");
|
||||||
|
const models = await adapter.fetchModels("key");
|
||||||
|
assert.deepEqual(models, []);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// ─── getDiscoverableProviders ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
describe("getDiscoverableProviders", () => {
|
||||||
|
it("returns only providers that support discovery", () => {
|
||||||
|
const providers = getDiscoverableProviders();
|
||||||
|
assert.ok(providers.includes("openai"));
|
||||||
|
assert.ok(providers.includes("ollama"));
|
||||||
|
assert.ok(providers.includes("openrouter"));
|
||||||
|
assert.ok(providers.includes("google"));
|
||||||
|
assert.ok(!providers.includes("anthropic"));
|
||||||
|
assert.ok(!providers.includes("bedrock"));
|
||||||
|
});
|
||||||
|
|
||||||
|
it("returns an array of strings", () => {
|
||||||
|
const providers = getDiscoverableProviders();
|
||||||
|
assert.ok(Array.isArray(providers));
|
||||||
|
for (const p of providers) {
|
||||||
|
assert.equal(typeof p, "string");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// ─── getDefaultTTL ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
describe("getDefaultTTL", () => {
|
||||||
|
it("returns 5 minutes for ollama", () => {
|
||||||
|
assert.equal(getDefaultTTL("ollama"), 5 * 60 * 1000);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("returns 1 hour for openai", () => {
|
||||||
|
assert.equal(getDefaultTTL("openai"), 60 * 60 * 1000);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("returns 1 hour for google", () => {
|
||||||
|
assert.equal(getDefaultTTL("google"), 60 * 60 * 1000);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("returns 1 hour for openrouter", () => {
|
||||||
|
assert.equal(getDefaultTTL("openrouter"), 60 * 60 * 1000);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("returns 24 hours for unknown providers", () => {
|
||||||
|
assert.equal(getDefaultTTL("some-custom"), 24 * 60 * 60 * 1000);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// ─── DISCOVERY_TTLS ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
describe("DISCOVERY_TTLS", () => {
|
||||||
|
it("has expected keys", () => {
|
||||||
|
assert.ok("ollama" in DISCOVERY_TTLS);
|
||||||
|
assert.ok("openai" in DISCOVERY_TTLS);
|
||||||
|
assert.ok("google" in DISCOVERY_TTLS);
|
||||||
|
assert.ok("openrouter" in DISCOVERY_TTLS);
|
||||||
|
assert.ok("default" in DISCOVERY_TTLS);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("all values are positive numbers", () => {
|
||||||
|
for (const [, value] of Object.entries(DISCOVERY_TTLS)) {
|
||||||
|
assert.equal(typeof value, "number");
|
||||||
|
assert.ok(value > 0);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
231
packages/pi-coding-agent/src/core/model-discovery.ts
Normal file
231
packages/pi-coding-agent/src/core/model-discovery.ts
Normal file
|
|
@ -0,0 +1,231 @@
|
||||||
|
/**
|
||||||
|
* Provider discovery adapters for runtime model enumeration.
|
||||||
|
* Each adapter implements ProviderDiscoveryAdapter to fetch models from provider APIs.
|
||||||
|
*/
|
||||||
|
|
||||||
|
export interface DiscoveredModel {
|
||||||
|
id: string;
|
||||||
|
name?: string;
|
||||||
|
contextWindow?: number;
|
||||||
|
maxTokens?: number;
|
||||||
|
reasoning?: boolean;
|
||||||
|
input?: ("text" | "image")[];
|
||||||
|
cost?: { input: number; output: number; cacheRead: number; cacheWrite: number };
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface DiscoveryResult {
|
||||||
|
provider: string;
|
||||||
|
models: DiscoveredModel[];
|
||||||
|
fetchedAt: number;
|
||||||
|
error?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ProviderDiscoveryAdapter {
|
||||||
|
provider: string;
|
||||||
|
supportsDiscovery: boolean;
|
||||||
|
fetchModels(apiKey: string, baseUrl?: string): Promise<DiscoveredModel[]>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Per-provider TTLs in milliseconds */
|
||||||
|
export const DISCOVERY_TTLS: Record<string, number> = {
|
||||||
|
ollama: 5 * 60 * 1000, // 5 minutes (local, models change often)
|
||||||
|
openai: 60 * 60 * 1000, // 1 hour
|
||||||
|
google: 60 * 60 * 1000, // 1 hour
|
||||||
|
openrouter: 60 * 60 * 1000, // 1 hour
|
||||||
|
default: 24 * 60 * 60 * 1000, // 24 hours
|
||||||
|
};
|
||||||
|
|
||||||
|
export function getDefaultTTL(provider: string): number {
|
||||||
|
return DISCOVERY_TTLS[provider] ?? DISCOVERY_TTLS.default;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function fetchWithTimeout(url: string, options: RequestInit = {}, timeoutMs = 5000): Promise<Response> {
|
||||||
|
const controller = new AbortController();
|
||||||
|
const timeout = setTimeout(() => controller.abort(), timeoutMs);
|
||||||
|
try {
|
||||||
|
return await fetch(url, { ...options, signal: controller.signal });
|
||||||
|
} finally {
|
||||||
|
clearTimeout(timeout);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── OpenAI Adapter ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
const OPENAI_EXCLUDED_PREFIXES = ["embedding", "tts", "dall-e", "whisper", "text-embedding", "davinci", "babbage"];
|
||||||
|
|
||||||
|
class OpenAIDiscoveryAdapter implements ProviderDiscoveryAdapter {
|
||||||
|
provider = "openai";
|
||||||
|
supportsDiscovery = true;
|
||||||
|
|
||||||
|
async fetchModels(apiKey: string, baseUrl?: string): Promise<DiscoveredModel[]> {
|
||||||
|
const url = `${baseUrl ?? "https://api.openai.com"}/v1/models`;
|
||||||
|
const response = await fetchWithTimeout(url, {
|
||||||
|
headers: { Authorization: `Bearer ${apiKey}` },
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`OpenAI models API returned ${response.status}: ${response.statusText}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = (await response.json()) as { data: Array<{ id: string; owned_by?: string }> };
|
||||||
|
return data.data
|
||||||
|
.filter((m) => !OPENAI_EXCLUDED_PREFIXES.some((prefix) => m.id.startsWith(prefix)))
|
||||||
|
.map((m) => ({
|
||||||
|
id: m.id,
|
||||||
|
name: m.id,
|
||||||
|
input: ["text" as const, "image" as const],
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Ollama Adapter ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class OllamaDiscoveryAdapter implements ProviderDiscoveryAdapter {
|
||||||
|
provider = "ollama";
|
||||||
|
supportsDiscovery = true;
|
||||||
|
|
||||||
|
async fetchModels(_apiKey: string, baseUrl?: string): Promise<DiscoveredModel[]> {
|
||||||
|
const url = `${baseUrl ?? "http://localhost:11434"}/api/tags`;
|
||||||
|
const response = await fetchWithTimeout(url);
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`Ollama tags API returned ${response.status}: ${response.statusText}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = (await response.json()) as {
|
||||||
|
models: Array<{ name: string; size: number; details?: { parameter_size?: string } }>;
|
||||||
|
};
|
||||||
|
|
||||||
|
return (data.models ?? []).map((m) => ({
|
||||||
|
id: m.name,
|
||||||
|
name: m.name,
|
||||||
|
input: ["text" as const],
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── OpenRouter Adapter ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class OpenRouterDiscoveryAdapter implements ProviderDiscoveryAdapter {
|
||||||
|
provider = "openrouter";
|
||||||
|
supportsDiscovery = true;
|
||||||
|
|
||||||
|
async fetchModels(apiKey: string, baseUrl?: string): Promise<DiscoveredModel[]> {
|
||||||
|
const url = `${baseUrl ?? "https://openrouter.ai"}/api/v1/models`;
|
||||||
|
const response = await fetchWithTimeout(url, {
|
||||||
|
headers: { Authorization: `Bearer ${apiKey}` },
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`OpenRouter models API returned ${response.status}: ${response.statusText}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = (await response.json()) as {
|
||||||
|
data: Array<{
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
context_length?: number;
|
||||||
|
top_provider?: { max_completion_tokens?: number };
|
||||||
|
pricing?: { prompt: string; completion: string };
|
||||||
|
}>;
|
||||||
|
};
|
||||||
|
|
||||||
|
return (data.data ?? []).map((m) => {
|
||||||
|
const cost =
|
||||||
|
m.pricing?.prompt !== undefined && m.pricing?.completion !== undefined
|
||||||
|
? {
|
||||||
|
input: parseFloat(m.pricing.prompt) * 1_000_000,
|
||||||
|
output: parseFloat(m.pricing.completion) * 1_000_000,
|
||||||
|
cacheRead: 0,
|
||||||
|
cacheWrite: 0,
|
||||||
|
}
|
||||||
|
: undefined;
|
||||||
|
|
||||||
|
return {
|
||||||
|
id: m.id,
|
||||||
|
name: m.name,
|
||||||
|
contextWindow: m.context_length,
|
||||||
|
maxTokens: m.top_provider?.max_completion_tokens,
|
||||||
|
cost,
|
||||||
|
input: ["text" as const, "image" as const],
|
||||||
|
};
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Google/Gemini Adapter ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class GoogleDiscoveryAdapter implements ProviderDiscoveryAdapter {
|
||||||
|
provider = "google";
|
||||||
|
supportsDiscovery = true;
|
||||||
|
|
||||||
|
async fetchModels(apiKey: string, baseUrl?: string): Promise<DiscoveredModel[]> {
|
||||||
|
const url = `${baseUrl ?? "https://generativelanguage.googleapis.com"}/v1beta/models?key=${apiKey}`;
|
||||||
|
const response = await fetchWithTimeout(url);
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`Google models API returned ${response.status}: ${response.statusText}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = (await response.json()) as {
|
||||||
|
models: Array<{
|
||||||
|
name: string;
|
||||||
|
displayName: string;
|
||||||
|
supportedGenerationMethods?: string[];
|
||||||
|
inputTokenLimit?: number;
|
||||||
|
outputTokenLimit?: number;
|
||||||
|
}>;
|
||||||
|
};
|
||||||
|
|
||||||
|
return (data.models ?? [])
|
||||||
|
.filter((m) => m.supportedGenerationMethods?.includes("generateContent"))
|
||||||
|
.map((m) => ({
|
||||||
|
id: m.name.replace("models/", ""),
|
||||||
|
name: m.displayName,
|
||||||
|
contextWindow: m.inputTokenLimit,
|
||||||
|
maxTokens: m.outputTokenLimit,
|
||||||
|
input: ["text" as const, "image" as const],
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Static Adapter (no discovery) ───────────────────────────────────────────
|
||||||
|
|
||||||
|
class StaticDiscoveryAdapter implements ProviderDiscoveryAdapter {
|
||||||
|
provider: string;
|
||||||
|
supportsDiscovery = false;
|
||||||
|
|
||||||
|
constructor(provider: string) {
|
||||||
|
this.provider = provider;
|
||||||
|
}
|
||||||
|
|
||||||
|
async fetchModels(): Promise<DiscoveredModel[]> {
|
||||||
|
return [];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Registry ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
const adapters: Record<string, ProviderDiscoveryAdapter> = {
|
||||||
|
openai: new OpenAIDiscoveryAdapter(),
|
||||||
|
ollama: new OllamaDiscoveryAdapter(),
|
||||||
|
openrouter: new OpenRouterDiscoveryAdapter(),
|
||||||
|
google: new GoogleDiscoveryAdapter(),
|
||||||
|
anthropic: new StaticDiscoveryAdapter("anthropic"),
|
||||||
|
bedrock: new StaticDiscoveryAdapter("bedrock"),
|
||||||
|
"azure-openai": new StaticDiscoveryAdapter("azure-openai"),
|
||||||
|
groq: new StaticDiscoveryAdapter("groq"),
|
||||||
|
cerebras: new StaticDiscoveryAdapter("cerebras"),
|
||||||
|
xai: new StaticDiscoveryAdapter("xai"),
|
||||||
|
mistral: new StaticDiscoveryAdapter("mistral"),
|
||||||
|
};
|
||||||
|
|
||||||
|
export function getDiscoveryAdapter(provider: string): ProviderDiscoveryAdapter {
|
||||||
|
return adapters[provider] ?? new StaticDiscoveryAdapter(provider);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function getDiscoverableProviders(): string[] {
|
||||||
|
return Object.entries(adapters)
|
||||||
|
.filter(([, adapter]) => adapter.supportsDiscovery)
|
||||||
|
.map(([name]) => name);
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,135 @@
|
||||||
|
import assert from "node:assert/strict";
|
||||||
|
import { mkdirSync, rmSync, writeFileSync } from "node:fs";
|
||||||
|
import { tmpdir } from "node:os";
|
||||||
|
import { join } from "node:path";
|
||||||
|
import { afterEach, beforeEach, describe, it } from "node:test";
|
||||||
|
import { AuthStorage } from "./auth-storage.js";
|
||||||
|
import { ModelDiscoveryCache } from "./discovery-cache.js";
|
||||||
|
import { getDefaultTTL, getDiscoverableProviders, getDiscoveryAdapter } from "./model-discovery.js";
|
||||||
|
|
||||||
|
let testDir: string;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
testDir = join(tmpdir(), `model-registry-discovery-test-${Date.now()}-${Math.random().toString(36).slice(2)}`);
|
||||||
|
mkdirSync(testDir, { recursive: true });
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
try {
|
||||||
|
rmSync(testDir, { recursive: true, force: true });
|
||||||
|
} catch {
|
||||||
|
// Cleanup best-effort
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// ─── discovery cache integration ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
describe("ModelDiscoveryCache — integration with discovery", () => {
|
||||||
|
it("cache respects provider-specific TTLs", () => {
|
||||||
|
const cachePath = join(testDir, "cache.json");
|
||||||
|
const cache = new ModelDiscoveryCache(cachePath);
|
||||||
|
|
||||||
|
cache.set("ollama", [{ id: "llama2" }]);
|
||||||
|
const entry = cache.get("ollama");
|
||||||
|
assert.ok(entry);
|
||||||
|
assert.equal(entry.ttlMs, getDefaultTTL("ollama"));
|
||||||
|
});
|
||||||
|
|
||||||
|
it("cache uses custom TTL when provided", () => {
|
||||||
|
const cachePath = join(testDir, "cache.json");
|
||||||
|
const cache = new ModelDiscoveryCache(cachePath);
|
||||||
|
|
||||||
|
cache.set("openai", [{ id: "gpt-4o" }], 999);
|
||||||
|
const entry = cache.get("openai");
|
||||||
|
assert.ok(entry);
|
||||||
|
assert.equal(entry.ttlMs, 999);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// ─── adapter resolution ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
describe("Discovery adapter resolution", () => {
|
||||||
|
it("all discoverable providers have adapters", () => {
|
||||||
|
const providers = getDiscoverableProviders();
|
||||||
|
for (const provider of providers) {
|
||||||
|
const adapter = getDiscoveryAdapter(provider);
|
||||||
|
assert.equal(adapter.supportsDiscovery, true, `${provider} should support discovery`);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
it("static adapters return empty model lists", async () => {
|
||||||
|
const staticProviders = ["anthropic", "bedrock", "azure-openai", "groq", "cerebras"];
|
||||||
|
for (const provider of staticProviders) {
|
||||||
|
const adapter = getDiscoveryAdapter(provider);
|
||||||
|
assert.equal(adapter.supportsDiscovery, false, `${provider} should not support discovery`);
|
||||||
|
const models = await adapter.fetchModels("dummy-key");
|
||||||
|
assert.deepEqual(models, [], `${provider} should return empty models`);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// ─── AuthStorage hasAuth for discovery ───────────────────────────────────────
|
||||||
|
|
||||||
|
describe("AuthStorage — hasAuth for discovery providers", () => {
|
||||||
|
it("returns false for providers without auth", () => {
|
||||||
|
const storage = AuthStorage.inMemory({});
|
||||||
|
assert.equal(storage.hasAuth("openai"), false);
|
||||||
|
assert.equal(storage.hasAuth("ollama"), false);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("returns true for providers with stored keys", () => {
|
||||||
|
const storage = AuthStorage.inMemory({
|
||||||
|
openai: { type: "api_key" as const, key: "sk-test" },
|
||||||
|
});
|
||||||
|
assert.equal(storage.hasAuth("openai"), true);
|
||||||
|
assert.equal(storage.hasAuth("ollama"), false);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// ─── cache persistence across instances ──────────────────────────────────────
|
||||||
|
|
||||||
|
describe("ModelDiscoveryCache — persistence", () => {
|
||||||
|
it("data survives across cache instances", () => {
|
||||||
|
const cachePath = join(testDir, "persist.json");
|
||||||
|
|
||||||
|
const cache1 = new ModelDiscoveryCache(cachePath);
|
||||||
|
cache1.set("openai", [
|
||||||
|
{ id: "gpt-4o", name: "GPT-4o", contextWindow: 128000 },
|
||||||
|
{ id: "gpt-4o-mini", name: "GPT-4o Mini" },
|
||||||
|
]);
|
||||||
|
|
||||||
|
const cache2 = new ModelDiscoveryCache(cachePath);
|
||||||
|
const entry = cache2.get("openai");
|
||||||
|
assert.ok(entry);
|
||||||
|
assert.equal(entry.models.length, 2);
|
||||||
|
assert.equal(entry.models[0].contextWindow, 128000);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("clear persists across instances", () => {
|
||||||
|
const cachePath = join(testDir, "clear.json");
|
||||||
|
|
||||||
|
const cache1 = new ModelDiscoveryCache(cachePath);
|
||||||
|
cache1.set("openai", [{ id: "gpt-4o" }]);
|
||||||
|
cache1.clear("openai");
|
||||||
|
|
||||||
|
const cache2 = new ModelDiscoveryCache(cachePath);
|
||||||
|
assert.equal(cache2.get("openai"), undefined);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// ─── discovery TTL values ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
describe("Discovery TTL configuration", () => {
|
||||||
|
it("ollama has shortest TTL (local models change often)", () => {
|
||||||
|
const ollamaTTL = getDefaultTTL("ollama");
|
||||||
|
const openaiTTL = getDefaultTTL("openai");
|
||||||
|
assert.ok(ollamaTTL < openaiTTL, "ollama TTL should be shorter than openai");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("unknown providers get default TTL", () => {
|
||||||
|
const customTTL = getDefaultTTL("my-custom-provider");
|
||||||
|
const defaultTTL = getDefaultTTL("default");
|
||||||
|
// Unknown providers should get the same TTL as the explicit "default" key
|
||||||
|
assert.equal(customTTL, defaultTTL);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
@ -24,6 +24,9 @@ import { existsSync, readFileSync } from "fs";
|
||||||
import { join } from "path";
|
import { join } from "path";
|
||||||
import { getAgentDir } from "../config.js";
|
import { getAgentDir } from "../config.js";
|
||||||
import type { AuthStorage } from "./auth-storage.js";
|
import type { AuthStorage } from "./auth-storage.js";
|
||||||
|
import { ModelDiscoveryCache } from "./discovery-cache.js";
|
||||||
|
import type { DiscoveredModel, DiscoveryResult } from "./model-discovery.js";
|
||||||
|
import { getDefaultTTL, getDiscoverableProviders, getDiscoveryAdapter } from "./model-discovery.js";
|
||||||
import { clearConfigValueCache, resolveConfigValue, resolveHeaders } from "./resolve-config-value.js";
|
import { clearConfigValueCache, resolveConfigValue, resolveHeaders } from "./resolve-config-value.js";
|
||||||
|
|
||||||
const Ajv = (AjvModule as any).default || AjvModule;
|
const Ajv = (AjvModule as any).default || AjvModule;
|
||||||
|
|
@ -221,6 +224,8 @@ export const clearApiKeyCache = clearConfigValueCache;
|
||||||
*/
|
*/
|
||||||
export class ModelRegistry {
|
export class ModelRegistry {
|
||||||
private models: Model<Api>[] = [];
|
private models: Model<Api>[] = [];
|
||||||
|
private discoveredModels: Model<Api>[] = [];
|
||||||
|
private discoveryCache: ModelDiscoveryCache;
|
||||||
private customProviderApiKeys: Map<string, string> = new Map();
|
private customProviderApiKeys: Map<string, string> = new Map();
|
||||||
private registeredProviders: Map<string, ProviderConfigInput> = new Map();
|
private registeredProviders: Map<string, ProviderConfigInput> = new Map();
|
||||||
private loadError: string | undefined = undefined;
|
private loadError: string | undefined = undefined;
|
||||||
|
|
@ -229,6 +234,8 @@ export class ModelRegistry {
|
||||||
readonly authStorage: AuthStorage,
|
readonly authStorage: AuthStorage,
|
||||||
private modelsJsonPath: string | undefined = join(getAgentDir(), "models.json"),
|
private modelsJsonPath: string | undefined = join(getAgentDir(), "models.json"),
|
||||||
) {
|
) {
|
||||||
|
this.discoveryCache = new ModelDiscoveryCache();
|
||||||
|
|
||||||
// Set up fallback resolver for custom provider API keys
|
// Set up fallback resolver for custom provider API keys
|
||||||
this.authStorage.setFallbackResolver((provider) => {
|
this.authStorage.setFallbackResolver((provider) => {
|
||||||
const keyConfig = this.customProviderApiKeys.get(provider);
|
const keyConfig = this.customProviderApiKeys.get(provider);
|
||||||
|
|
@ -666,6 +673,106 @@ export class ModelRegistry {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Discover models from all providers that support discovery.
|
||||||
|
* Results are cached and merged into the registry (never overrides existing models).
|
||||||
|
*/
|
||||||
|
async discoverModels(providers?: string[]): Promise<DiscoveryResult[]> {
|
||||||
|
const targetProviders = providers ?? getDiscoverableProviders();
|
||||||
|
const results: DiscoveryResult[] = [];
|
||||||
|
|
||||||
|
for (const providerName of targetProviders) {
|
||||||
|
const adapter = getDiscoveryAdapter(providerName);
|
||||||
|
if (!adapter.supportsDiscovery) continue;
|
||||||
|
|
||||||
|
// Skip if cache is still fresh
|
||||||
|
if (!this.discoveryCache.isStale(providerName)) {
|
||||||
|
const cached = this.discoveryCache.get(providerName);
|
||||||
|
if (cached) {
|
||||||
|
results.push({
|
||||||
|
provider: providerName,
|
||||||
|
models: cached.models,
|
||||||
|
fetchedAt: cached.fetchedAt,
|
||||||
|
});
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const apiKey = await this.authStorage.getApiKey(providerName);
|
||||||
|
if (!apiKey && providerName !== "ollama") continue;
|
||||||
|
|
||||||
|
const models = await adapter.fetchModels(apiKey ?? "", undefined);
|
||||||
|
this.discoveryCache.set(providerName, models);
|
||||||
|
results.push({
|
||||||
|
provider: providerName,
|
||||||
|
models,
|
||||||
|
fetchedAt: Date.now(),
|
||||||
|
});
|
||||||
|
} catch (error) {
|
||||||
|
results.push({
|
||||||
|
provider: providerName,
|
||||||
|
models: [],
|
||||||
|
fetchedAt: Date.now(),
|
||||||
|
error: error instanceof Error ? error.message : String(error),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert and merge discovered models
|
||||||
|
this.discoveredModels = this.convertDiscoveredModels(results);
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get all models including discovered ones.
|
||||||
|
* Discovered models are appended but never override existing models.
|
||||||
|
*/
|
||||||
|
getAllWithDiscovered(): Model<Api>[] {
|
||||||
|
const existingIds = new Set(this.models.map((m) => `${m.provider}/${m.id}`));
|
||||||
|
const unique = this.discoveredModels.filter((m) => !existingIds.has(`${m.provider}/${m.id}`));
|
||||||
|
return [...this.models, ...unique];
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if a model was added via discovery (not built-in or custom).
|
||||||
|
*/
|
||||||
|
isDiscovered(model: Model<Api>): boolean {
|
||||||
|
return this.discoveredModels.some((m) => m.provider === model.provider && m.id === model.id);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the discovery cache instance.
|
||||||
|
*/
|
||||||
|
getDiscoveryCache(): ModelDiscoveryCache {
|
||||||
|
return this.discoveryCache;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert DiscoveryResult[] into Model<Api>[] with default values.
|
||||||
|
*/
|
||||||
|
private convertDiscoveredModels(results: DiscoveryResult[]): Model<Api>[] {
|
||||||
|
const converted: Model<Api>[] = [];
|
||||||
|
for (const result of results) {
|
||||||
|
if (result.error) continue;
|
||||||
|
for (const dm of result.models) {
|
||||||
|
converted.push({
|
||||||
|
id: dm.id,
|
||||||
|
name: dm.name ?? dm.id,
|
||||||
|
api: "openai" as Api,
|
||||||
|
provider: result.provider,
|
||||||
|
baseUrl: "",
|
||||||
|
reasoning: dm.reasoning ?? false,
|
||||||
|
input: dm.input ?? ["text"],
|
||||||
|
cost: dm.cost ?? { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||||
|
contextWindow: dm.contextWindow ?? 128000,
|
||||||
|
maxTokens: dm.maxTokens ?? 16384,
|
||||||
|
} as Model<Api>);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return converted;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
145
packages/pi-coding-agent/src/core/models-json-writer.test.ts
Normal file
145
packages/pi-coding-agent/src/core/models-json-writer.test.ts
Normal file
|
|
@ -0,0 +1,145 @@
|
||||||
|
import assert from "node:assert/strict";
|
||||||
|
import { existsSync, mkdirSync, readFileSync, rmSync } from "node:fs";
|
||||||
|
import { tmpdir } from "node:os";
|
||||||
|
import { join } from "node:path";
|
||||||
|
import { afterEach, beforeEach, describe, it } from "node:test";
|
||||||
|
import { ModelsJsonWriter } from "./models-json-writer.js";
|
||||||
|
|
||||||
|
let testDir: string;
|
||||||
|
let modelsJsonPath: string;
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
testDir = join(tmpdir(), `models-json-writer-test-${Date.now()}-${Math.random().toString(36).slice(2)}`);
|
||||||
|
mkdirSync(testDir, { recursive: true });
|
||||||
|
modelsJsonPath = join(testDir, "models.json");
|
||||||
|
});
|
||||||
|
|
||||||
|
afterEach(() => {
|
||||||
|
try {
|
||||||
|
rmSync(testDir, { recursive: true, force: true });
|
||||||
|
} catch {
|
||||||
|
// Cleanup best-effort
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
function readModels(): Record<string, unknown> {
|
||||||
|
return JSON.parse(readFileSync(modelsJsonPath, "utf-8"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── addModel ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
describe("ModelsJsonWriter — addModel", () => {
|
||||||
|
it("creates file and adds model to new provider", () => {
|
||||||
|
const writer = new ModelsJsonWriter(modelsJsonPath);
|
||||||
|
writer.addModel("openai", { id: "gpt-4o", name: "GPT-4o" }, { baseUrl: "https://api.openai.com", apiKey: "env:OPENAI_API_KEY", api: "openai" });
|
||||||
|
|
||||||
|
const config = readModels() as any;
|
||||||
|
assert.ok(config.providers.openai);
|
||||||
|
assert.equal(config.providers.openai.models.length, 1);
|
||||||
|
assert.equal(config.providers.openai.models[0].id, "gpt-4o");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("appends model to existing provider", () => {
|
||||||
|
const writer = new ModelsJsonWriter(modelsJsonPath);
|
||||||
|
writer.addModel("openai", { id: "gpt-4o" }, { baseUrl: "https://api.openai.com", apiKey: "env:OPENAI_API_KEY", api: "openai" });
|
||||||
|
writer.addModel("openai", { id: "gpt-4o-mini" });
|
||||||
|
|
||||||
|
const config = readModels() as any;
|
||||||
|
assert.equal(config.providers.openai.models.length, 2);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("replaces model with same id", () => {
|
||||||
|
const writer = new ModelsJsonWriter(modelsJsonPath);
|
||||||
|
writer.addModel("openai", { id: "gpt-4o", name: "Old" }, { baseUrl: "https://api.openai.com", apiKey: "env:OPENAI_API_KEY", api: "openai" });
|
||||||
|
writer.addModel("openai", { id: "gpt-4o", name: "New" });
|
||||||
|
|
||||||
|
const config = readModels() as any;
|
||||||
|
assert.equal(config.providers.openai.models.length, 1);
|
||||||
|
assert.equal(config.providers.openai.models[0].name, "New");
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// ─── removeModel ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
describe("ModelsJsonWriter — removeModel", () => {
|
||||||
|
it("removes a model from provider", () => {
|
||||||
|
const writer = new ModelsJsonWriter(modelsJsonPath);
|
||||||
|
writer.addModel("openai", { id: "gpt-4o" }, { baseUrl: "https://api.openai.com", apiKey: "env:OPENAI_API_KEY", api: "openai" });
|
||||||
|
writer.addModel("openai", { id: "gpt-4o-mini" });
|
||||||
|
|
||||||
|
writer.removeModel("openai", "gpt-4o");
|
||||||
|
|
||||||
|
const config = readModels() as any;
|
||||||
|
assert.equal(config.providers.openai.models.length, 1);
|
||||||
|
assert.equal(config.providers.openai.models[0].id, "gpt-4o-mini");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("removes provider when last model is removed", () => {
|
||||||
|
const writer = new ModelsJsonWriter(modelsJsonPath);
|
||||||
|
writer.addModel("openai", { id: "gpt-4o" }, { baseUrl: "https://api.openai.com", apiKey: "env:OPENAI_API_KEY", api: "openai" });
|
||||||
|
|
||||||
|
writer.removeModel("openai", "gpt-4o");
|
||||||
|
|
||||||
|
const config = readModels() as any;
|
||||||
|
assert.equal(config.providers.openai, undefined);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("handles removing from nonexistent provider", () => {
|
||||||
|
const writer = new ModelsJsonWriter(modelsJsonPath);
|
||||||
|
// Should not throw
|
||||||
|
writer.removeModel("nonexistent", "model-id");
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// ─── setProvider / removeProvider ────────────────────────────────────────────
|
||||||
|
|
||||||
|
describe("ModelsJsonWriter — provider operations", () => {
|
||||||
|
it("sets a provider configuration", () => {
|
||||||
|
const writer = new ModelsJsonWriter(modelsJsonPath);
|
||||||
|
writer.setProvider("custom", {
|
||||||
|
baseUrl: "http://localhost:8080",
|
||||||
|
apiKey: "test-key",
|
||||||
|
api: "openai",
|
||||||
|
models: [{ id: "local-model" }],
|
||||||
|
});
|
||||||
|
|
||||||
|
const config = readModels() as any;
|
||||||
|
assert.ok(config.providers.custom);
|
||||||
|
assert.equal(config.providers.custom.baseUrl, "http://localhost:8080");
|
||||||
|
});
|
||||||
|
|
||||||
|
it("removes a provider", () => {
|
||||||
|
const writer = new ModelsJsonWriter(modelsJsonPath);
|
||||||
|
writer.setProvider("custom", { baseUrl: "http://localhost:8080" });
|
||||||
|
writer.removeProvider("custom");
|
||||||
|
|
||||||
|
const config = readModels() as any;
|
||||||
|
assert.equal(config.providers.custom, undefined);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("handles removing nonexistent provider", () => {
|
||||||
|
const writer = new ModelsJsonWriter(modelsJsonPath);
|
||||||
|
writer.removeProvider("nonexistent");
|
||||||
|
// Should not throw
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
// ─── listProviders ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
describe("ModelsJsonWriter — listProviders", () => {
|
||||||
|
it("returns empty config when file does not exist", () => {
|
||||||
|
const writer = new ModelsJsonWriter(join(testDir, "nonexistent.json"));
|
||||||
|
const config = writer.listProviders();
|
||||||
|
assert.deepEqual(config, { providers: {} });
|
||||||
|
});
|
||||||
|
|
||||||
|
it("returns current provider config", () => {
|
||||||
|
const writer = new ModelsJsonWriter(modelsJsonPath);
|
||||||
|
writer.setProvider("openai", { baseUrl: "https://api.openai.com" });
|
||||||
|
writer.setProvider("ollama", { baseUrl: "http://localhost:11434" });
|
||||||
|
|
||||||
|
const config = writer.listProviders();
|
||||||
|
assert.ok(config.providers.openai);
|
||||||
|
assert.ok(config.providers.ollama);
|
||||||
|
});
|
||||||
|
});
|
||||||
188
packages/pi-coding-agent/src/core/models-json-writer.ts
Normal file
188
packages/pi-coding-agent/src/core/models-json-writer.ts
Normal file
|
|
@ -0,0 +1,188 @@
|
||||||
|
/**
|
||||||
|
* Safe read-modify-write for models.json with file locking.
|
||||||
|
* Prevents concurrent writes from corrupting the config file.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
|
||||||
|
import { dirname, join } from "path";
|
||||||
|
import lockfile from "proper-lockfile";
|
||||||
|
import { getAgentDir } from "../config.js";
|
||||||
|
|
||||||
|
interface ModelDefinition {
|
||||||
|
id: string;
|
||||||
|
name?: string;
|
||||||
|
api?: string;
|
||||||
|
baseUrl?: string;
|
||||||
|
reasoning?: boolean;
|
||||||
|
input?: ("text" | "image")[];
|
||||||
|
cost?: { input: number; output: number; cacheRead: number; cacheWrite: number };
|
||||||
|
contextWindow?: number;
|
||||||
|
maxTokens?: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ProviderConfig {
|
||||||
|
baseUrl?: string;
|
||||||
|
apiKey?: string;
|
||||||
|
api?: string;
|
||||||
|
headers?: Record<string, string>;
|
||||||
|
authHeader?: boolean;
|
||||||
|
models?: ModelDefinition[];
|
||||||
|
modelOverrides?: Record<string, Record<string, unknown>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ModelsConfig {
|
||||||
|
providers: Record<string, ProviderConfig>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export class ModelsJsonWriter {
|
||||||
|
private modelsJsonPath: string;
|
||||||
|
|
||||||
|
constructor(modelsJsonPath?: string) {
|
||||||
|
this.modelsJsonPath = modelsJsonPath ?? join(getAgentDir(), "models.json");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add a model to a provider. Creates the provider if it doesn't exist.
|
||||||
|
*/
|
||||||
|
addModel(provider: string, model: ModelDefinition, providerConfig?: Partial<ProviderConfig>): void {
|
||||||
|
this.withLock((config) => {
|
||||||
|
if (!config.providers[provider]) {
|
||||||
|
config.providers[provider] = {
|
||||||
|
...providerConfig,
|
||||||
|
models: [],
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
const providerEntry = config.providers[provider];
|
||||||
|
if (!providerEntry.models) {
|
||||||
|
providerEntry.models = [];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace existing model with same id, or append
|
||||||
|
const existingIndex = providerEntry.models.findIndex((m) => m.id === model.id);
|
||||||
|
if (existingIndex >= 0) {
|
||||||
|
providerEntry.models[existingIndex] = model;
|
||||||
|
} else {
|
||||||
|
providerEntry.models.push(model);
|
||||||
|
}
|
||||||
|
|
||||||
|
return config;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Remove a model from a provider. Removes the provider if no models remain.
|
||||||
|
*/
|
||||||
|
removeModel(provider: string, modelId: string): void {
|
||||||
|
this.withLock((config) => {
|
||||||
|
const providerEntry = config.providers[provider];
|
||||||
|
if (!providerEntry?.models) return config;
|
||||||
|
|
||||||
|
providerEntry.models = providerEntry.models.filter((m) => m.id !== modelId);
|
||||||
|
|
||||||
|
// Clean up empty provider (no models and no overrides)
|
||||||
|
if (providerEntry.models.length === 0 && !providerEntry.modelOverrides) {
|
||||||
|
delete config.providers[provider];
|
||||||
|
}
|
||||||
|
|
||||||
|
return config;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set or update an entire provider configuration.
|
||||||
|
*/
|
||||||
|
setProvider(provider: string, providerConfig: ProviderConfig): void {
|
||||||
|
this.withLock((config) => {
|
||||||
|
config.providers[provider] = providerConfig;
|
||||||
|
return config;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Remove a provider and all its models.
|
||||||
|
*/
|
||||||
|
removeProvider(provider: string): void {
|
||||||
|
this.withLock((config) => {
|
||||||
|
delete config.providers[provider];
|
||||||
|
return config;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* List all providers and their configurations.
|
||||||
|
*/
|
||||||
|
listProviders(): ModelsConfig {
|
||||||
|
return this.readConfig();
|
||||||
|
}
|
||||||
|
|
||||||
|
private readConfig(): ModelsConfig {
|
||||||
|
if (!existsSync(this.modelsJsonPath)) {
|
||||||
|
return { providers: {} };
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
const content = readFileSync(this.modelsJsonPath, "utf-8");
|
||||||
|
return JSON.parse(content) as ModelsConfig;
|
||||||
|
} catch {
|
||||||
|
return { providers: {} };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private writeConfig(config: ModelsConfig): void {
|
||||||
|
const dir = dirname(this.modelsJsonPath);
|
||||||
|
if (!existsSync(dir)) {
|
||||||
|
mkdirSync(dir, { recursive: true });
|
||||||
|
}
|
||||||
|
writeFileSync(this.modelsJsonPath, JSON.stringify(config, null, 2), "utf-8");
|
||||||
|
}
|
||||||
|
|
||||||
|
private acquireLockWithRetry(): () => void {
|
||||||
|
const maxAttempts = 10;
|
||||||
|
const delayMs = 20;
|
||||||
|
let lastError: unknown;
|
||||||
|
|
||||||
|
// Ensure file exists for locking
|
||||||
|
const dir = dirname(this.modelsJsonPath);
|
||||||
|
if (!existsSync(dir)) {
|
||||||
|
mkdirSync(dir, { recursive: true });
|
||||||
|
}
|
||||||
|
if (!existsSync(this.modelsJsonPath)) {
|
||||||
|
writeFileSync(this.modelsJsonPath, JSON.stringify({ providers: {} }, null, 2), "utf-8");
|
||||||
|
}
|
||||||
|
|
||||||
|
for (let attempt = 1; attempt <= maxAttempts; attempt++) {
|
||||||
|
try {
|
||||||
|
return lockfile.lockSync(this.modelsJsonPath, { realpath: false });
|
||||||
|
} catch (error) {
|
||||||
|
const code =
|
||||||
|
typeof error === "object" && error !== null && "code" in error
|
||||||
|
? String((error as { code?: unknown }).code)
|
||||||
|
: undefined;
|
||||||
|
if (code !== "ELOCKED" || attempt === maxAttempts) {
|
||||||
|
throw error;
|
||||||
|
}
|
||||||
|
lastError = error;
|
||||||
|
const start = Date.now();
|
||||||
|
while (Date.now() - start < delayMs) {
|
||||||
|
// Busy-wait (same pattern as auth-storage.ts)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
throw (lastError as Error) ?? new Error("Failed to acquire models.json lock");
|
||||||
|
}
|
||||||
|
|
||||||
|
private withLock(fn: (config: ModelsConfig) => ModelsConfig): void {
|
||||||
|
let release: (() => void) | undefined;
|
||||||
|
try {
|
||||||
|
release = this.acquireLockWithRetry();
|
||||||
|
const config = this.readConfig();
|
||||||
|
const updated = fn(config);
|
||||||
|
this.writeConfig(updated);
|
||||||
|
} finally {
|
||||||
|
if (release) {
|
||||||
|
release();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -79,6 +79,13 @@ export interface FallbackSettings {
|
||||||
chains?: Record<string, FallbackChainEntry[]>; // keyed by chain name
|
chains?: Record<string, FallbackChainEntry[]>; // keyed by chain name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export interface ModelDiscoverySettings {
|
||||||
|
enabled?: boolean; // default: false
|
||||||
|
providers?: string[]; // limit discovery to specific providers
|
||||||
|
ttlMinutes?: number; // override default TTLs (in minutes)
|
||||||
|
autoRefreshOnModelSelect?: boolean; // default: false - refresh discovery when opening model selector
|
||||||
|
}
|
||||||
|
|
||||||
export type TransportSetting = Transport;
|
export type TransportSetting = Transport;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -134,6 +141,7 @@ export interface Settings {
|
||||||
bashInterceptor?: BashInterceptorSettings;
|
bashInterceptor?: BashInterceptorSettings;
|
||||||
taskIsolation?: TaskIsolationSettings;
|
taskIsolation?: TaskIsolationSettings;
|
||||||
fallback?: FallbackSettings;
|
fallback?: FallbackSettings;
|
||||||
|
modelDiscovery?: ModelDiscoverySettings;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Deep merge settings: project/overrides take precedence, nested objects merge recursively */
|
/** Deep merge settings: project/overrides take precedence, nested objects merge recursively */
|
||||||
|
|
@ -1076,4 +1084,17 @@ export class SettingsManager {
|
||||||
chains: this.getFallbackChains(),
|
chains: this.getFallbackChains(),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getModelDiscoverySettings(): ModelDiscoverySettings {
|
||||||
|
return this.settings.modelDiscovery ?? {};
|
||||||
|
}
|
||||||
|
|
||||||
|
setModelDiscoveryEnabled(enabled: boolean): void {
|
||||||
|
if (!this.globalSettings.modelDiscovery) {
|
||||||
|
this.globalSettings.modelDiscovery = {};
|
||||||
|
}
|
||||||
|
this.globalSettings.modelDiscovery.enabled = enabled;
|
||||||
|
this.markModified("modelDiscovery", "enabled");
|
||||||
|
this.save();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ export const BUILTIN_SLASH_COMMANDS: ReadonlyArray<BuiltinSlashCommand> = [
|
||||||
{ name: "hotkeys", description: "Show all keyboard shortcuts" },
|
{ name: "hotkeys", description: "Show all keyboard shortcuts" },
|
||||||
{ name: "fork", description: "Create a new fork from a previous message" },
|
{ name: "fork", description: "Create a new fork from a previous message" },
|
||||||
{ name: "tree", description: "Navigate session tree (switch branches)" },
|
{ name: "tree", description: "Navigate session tree (switch branches)" },
|
||||||
|
{ name: "provider", description: "Manage provider configuration" },
|
||||||
{ name: "login", description: "Login with OAuth provider" },
|
{ name: "login", description: "Login with OAuth provider" },
|
||||||
{ name: "logout", description: "Logout from OAuth provider" },
|
{ name: "logout", description: "Logout from OAuth provider" },
|
||||||
{ name: "new", description: "Start a new session" },
|
{ name: "new", description: "Start a new session" },
|
||||||
|
|
|
||||||
|
|
@ -143,7 +143,11 @@ export {
|
||||||
// Footer data provider (git branch + extension statuses - data not otherwise available to extensions)
|
// Footer data provider (git branch + extension statuses - data not otherwise available to extensions)
|
||||||
export type { ReadonlyFooterDataProvider } from "./core/footer-data-provider.js";
|
export type { ReadonlyFooterDataProvider } from "./core/footer-data-provider.js";
|
||||||
export { convertToLlm } from "./core/messages.js";
|
export { convertToLlm } from "./core/messages.js";
|
||||||
|
export { ModelDiscoveryCache } from "./core/discovery-cache.js";
|
||||||
|
export type { DiscoveredModel, DiscoveryResult, ProviderDiscoveryAdapter } from "./core/model-discovery.js";
|
||||||
|
export { getDiscoverableProviders, getDiscoveryAdapter } from "./core/model-discovery.js";
|
||||||
export { ModelRegistry } from "./core/model-registry.js";
|
export { ModelRegistry } from "./core/model-registry.js";
|
||||||
|
export { ModelsJsonWriter } from "./core/models-json-writer.js";
|
||||||
export type {
|
export type {
|
||||||
PackageManager,
|
PackageManager,
|
||||||
PathMetadata,
|
PathMetadata,
|
||||||
|
|
@ -307,6 +311,7 @@ export {
|
||||||
LoginDialogComponent,
|
LoginDialogComponent,
|
||||||
ModelSelectorComponent,
|
ModelSelectorComponent,
|
||||||
OAuthSelectorComponent,
|
OAuthSelectorComponent,
|
||||||
|
ProviderManagerComponent,
|
||||||
type RenderDiffOptions,
|
type RenderDiffOptions,
|
||||||
rawKeyHint,
|
rawKeyHint,
|
||||||
renderDiff,
|
renderDiff,
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,7 @@ import { createInterface } from "readline";
|
||||||
import { type Args, parseArgs, printHelp } from "./cli/args.js";
|
import { type Args, parseArgs, printHelp } from "./cli/args.js";
|
||||||
import { selectConfig } from "./cli/config-selector.js";
|
import { selectConfig } from "./cli/config-selector.js";
|
||||||
import { processFileArguments } from "./cli/file-processor.js";
|
import { processFileArguments } from "./cli/file-processor.js";
|
||||||
import { listModels } from "./cli/list-models.js";
|
import { discoverAndPrintModels, listModels } from "./cli/list-models.js";
|
||||||
import { selectSession } from "./cli/session-picker.js";
|
import { selectSession } from "./cli/session-picker.js";
|
||||||
import { APP_NAME, getAgentDir, getModelsPath, VERSION } from "./config.js";
|
import { APP_NAME, getAgentDir, getModelsPath, VERSION } from "./config.js";
|
||||||
import { AuthStorage } from "./core/auth-storage.js";
|
import { AuthStorage } from "./core/auth-storage.js";
|
||||||
|
|
@ -660,9 +660,26 @@ export async function main(args: string[]) {
|
||||||
process.exit(0);
|
process.exit(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (parsed.addProvider) {
|
||||||
|
const { ModelsJsonWriter } = await import("./core/models-json-writer.js");
|
||||||
|
const writer = new ModelsJsonWriter();
|
||||||
|
writer.setProvider(parsed.addProvider, {
|
||||||
|
baseUrl: parsed.addProviderBaseUrl,
|
||||||
|
apiKey: parsed.apiKey,
|
||||||
|
});
|
||||||
|
console.log(`Provider "${parsed.addProvider}" added to models.json`);
|
||||||
|
process.exit(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parsed.discoverModels !== undefined) {
|
||||||
|
const provider = typeof parsed.discoverModels === "string" ? parsed.discoverModels : undefined;
|
||||||
|
await discoverAndPrintModels(modelRegistry, provider);
|
||||||
|
process.exit(0);
|
||||||
|
}
|
||||||
|
|
||||||
if (parsed.listModels !== undefined) {
|
if (parsed.listModels !== undefined) {
|
||||||
const searchPattern = typeof parsed.listModels === "string" ? parsed.listModels : undefined;
|
const searchPattern = typeof parsed.listModels === "string" ? parsed.listModels : undefined;
|
||||||
await listModels(modelRegistry, searchPattern);
|
await listModels(modelRegistry, { searchPattern, discover: parsed.discover });
|
||||||
process.exit(0);
|
process.exit(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ export { appKey, appKeyHint, editorKey, keyHint, rawKeyHint } from "./keybinding
|
||||||
export { LoginDialogComponent } from "./login-dialog.js";
|
export { LoginDialogComponent } from "./login-dialog.js";
|
||||||
export { ModelSelectorComponent } from "./model-selector.js";
|
export { ModelSelectorComponent } from "./model-selector.js";
|
||||||
export { OAuthSelectorComponent } from "./oauth-selector.js";
|
export { OAuthSelectorComponent } from "./oauth-selector.js";
|
||||||
|
export { ProviderManagerComponent } from "./provider-manager.js";
|
||||||
export { type ModelsCallbacks, type ModelsConfig, ScopedModelsSelectorComponent } from "./scoped-models-selector.js";
|
export { type ModelsCallbacks, type ModelsConfig, ScopedModelsSelectorComponent } from "./scoped-models-selector.js";
|
||||||
export { SessionSelectorComponent } from "./session-selector.js";
|
export { SessionSelectorComponent } from "./session-selector.js";
|
||||||
export { type SettingsCallbacks, type SettingsConfig, SettingsSelectorComponent } from "./settings-selector.js";
|
export { type SettingsCallbacks, type SettingsConfig, SettingsSelectorComponent } from "./settings-selector.js";
|
||||||
|
|
|
||||||
|
|
@ -160,7 +160,7 @@ export class ModelSelectorComponent extends Container implements Focusable {
|
||||||
|
|
||||||
// Load available models (built-in models still work even if models.json failed)
|
// Load available models (built-in models still work even if models.json failed)
|
||||||
try {
|
try {
|
||||||
const availableModels = await this.modelRegistry.getAvailable();
|
const availableModels = this.modelRegistry.getAvailable();
|
||||||
models = availableModels.map((model: Model<any>) => ({
|
models = availableModels.map((model: Model<any>) => ({
|
||||||
provider: model.provider,
|
provider: model.provider,
|
||||||
id: model.id,
|
id: model.id,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,163 @@
|
||||||
|
/**
|
||||||
|
* TUI component for managing provider configurations.
|
||||||
|
* Shows providers with auth status, discovery support, and model counts.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import {
|
||||||
|
Container,
|
||||||
|
type Focusable,
|
||||||
|
getEditorKeybindings,
|
||||||
|
Spacer,
|
||||||
|
Text,
|
||||||
|
type TUI,
|
||||||
|
} from "@gsd/pi-tui";
|
||||||
|
import type { AuthStorage } from "../../../core/auth-storage.js";
|
||||||
|
import { getDiscoverableProviders } from "../../../core/model-discovery.js";
|
||||||
|
import type { ModelRegistry } from "../../../core/model-registry.js";
|
||||||
|
import { theme } from "../theme/theme.js";
|
||||||
|
import { rawKeyHint } from "./keybinding-hints.js";
|
||||||
|
|
||||||
|
interface ProviderInfo {
|
||||||
|
name: string;
|
||||||
|
hasAuth: boolean;
|
||||||
|
supportsDiscovery: boolean;
|
||||||
|
modelCount: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
export class ProviderManagerComponent extends Container implements Focusable {
|
||||||
|
private _focused = false;
|
||||||
|
get focused(): boolean {
|
||||||
|
return this._focused;
|
||||||
|
}
|
||||||
|
set focused(value: boolean) {
|
||||||
|
this._focused = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
private providers: ProviderInfo[] = [];
|
||||||
|
private selectedIndex = 0;
|
||||||
|
private listContainer: Container;
|
||||||
|
private tui: TUI;
|
||||||
|
private authStorage: AuthStorage;
|
||||||
|
private modelRegistry: ModelRegistry;
|
||||||
|
private onDone: () => void;
|
||||||
|
private onDiscover: (provider: string) => void;
|
||||||
|
|
||||||
|
constructor(
|
||||||
|
tui: TUI,
|
||||||
|
authStorage: AuthStorage,
|
||||||
|
modelRegistry: ModelRegistry,
|
||||||
|
onDone: () => void,
|
||||||
|
onDiscover: (provider: string) => void,
|
||||||
|
) {
|
||||||
|
super();
|
||||||
|
|
||||||
|
this.tui = tui;
|
||||||
|
this.authStorage = authStorage;
|
||||||
|
this.modelRegistry = modelRegistry;
|
||||||
|
this.onDone = onDone;
|
||||||
|
this.onDiscover = onDiscover;
|
||||||
|
|
||||||
|
// Header
|
||||||
|
this.addChild(new Text(theme.fg("accent", "Provider Manager"), 0, 0));
|
||||||
|
this.addChild(new Spacer(1));
|
||||||
|
|
||||||
|
// Hints
|
||||||
|
const hints = [
|
||||||
|
rawKeyHint("d", "discover"),
|
||||||
|
rawKeyHint("r", "remove auth"),
|
||||||
|
rawKeyHint("esc", "close"),
|
||||||
|
].join(" ");
|
||||||
|
this.addChild(new Text(hints, 0, 0));
|
||||||
|
this.addChild(new Spacer(1));
|
||||||
|
|
||||||
|
// List
|
||||||
|
this.listContainer = new Container();
|
||||||
|
this.addChild(this.listContainer);
|
||||||
|
|
||||||
|
this.loadProviders();
|
||||||
|
this.updateList();
|
||||||
|
}
|
||||||
|
|
||||||
|
private loadProviders(): void {
|
||||||
|
const discoverableSet = new Set(getDiscoverableProviders());
|
||||||
|
const allModels = this.modelRegistry.getAll();
|
||||||
|
|
||||||
|
// Group models by provider
|
||||||
|
const providerModelCounts = new Map<string, number>();
|
||||||
|
for (const model of allModels) {
|
||||||
|
providerModelCounts.set(model.provider, (providerModelCounts.get(model.provider) ?? 0) + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build provider list from all known providers
|
||||||
|
const providerNames = new Set([
|
||||||
|
...providerModelCounts.keys(),
|
||||||
|
...discoverableSet,
|
||||||
|
]);
|
||||||
|
|
||||||
|
this.providers = Array.from(providerNames)
|
||||||
|
.sort()
|
||||||
|
.map((name) => ({
|
||||||
|
name,
|
||||||
|
hasAuth: this.authStorage.hasAuth(name),
|
||||||
|
supportsDiscovery: discoverableSet.has(name),
|
||||||
|
modelCount: providerModelCounts.get(name) ?? 0,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
private updateList(): void {
|
||||||
|
this.listContainer.clear();
|
||||||
|
|
||||||
|
for (let i = 0; i < this.providers.length; i++) {
|
||||||
|
const p = this.providers[i];
|
||||||
|
const isSelected = i === this.selectedIndex;
|
||||||
|
|
||||||
|
const authBadge = p.hasAuth ? theme.fg("success", "[auth]") : theme.fg("muted", "[no auth]");
|
||||||
|
const discoveryBadge = p.supportsDiscovery ? theme.fg("accent", "[discovery]") : "";
|
||||||
|
const countBadge = theme.fg("muted", `(${p.modelCount} models)`);
|
||||||
|
|
||||||
|
const prefix = isSelected ? theme.fg("accent", "> ") : " ";
|
||||||
|
const nameText = isSelected ? theme.fg("accent", p.name) : p.name;
|
||||||
|
|
||||||
|
const parts = [prefix, nameText, " ", authBadge];
|
||||||
|
if (discoveryBadge) parts.push(" ", discoveryBadge);
|
||||||
|
parts.push(" ", countBadge);
|
||||||
|
|
||||||
|
this.listContainer.addChild(new Text(parts.join(""), 0, 0));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.providers.length === 0) {
|
||||||
|
this.listContainer.addChild(new Text(theme.fg("muted", " No providers configured"), 0, 0));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
handleInput(keyData: string): void {
|
||||||
|
const kb = getEditorKeybindings();
|
||||||
|
|
||||||
|
if (kb.matches(keyData, "selectUp")) {
|
||||||
|
if (this.providers.length === 0) return;
|
||||||
|
this.selectedIndex = this.selectedIndex === 0 ? this.providers.length - 1 : this.selectedIndex - 1;
|
||||||
|
this.updateList();
|
||||||
|
this.tui.requestRender();
|
||||||
|
} else if (kb.matches(keyData, "selectDown")) {
|
||||||
|
if (this.providers.length === 0) return;
|
||||||
|
this.selectedIndex = this.selectedIndex === this.providers.length - 1 ? 0 : this.selectedIndex + 1;
|
||||||
|
this.updateList();
|
||||||
|
this.tui.requestRender();
|
||||||
|
} else if (kb.matches(keyData, "selectCancel")) {
|
||||||
|
this.onDone();
|
||||||
|
} else if (keyData === "d" || keyData === "D") {
|
||||||
|
const provider = this.providers[this.selectedIndex];
|
||||||
|
if (provider?.supportsDiscovery) {
|
||||||
|
this.onDiscover(provider.name);
|
||||||
|
}
|
||||||
|
} else if (keyData === "r" || keyData === "R") {
|
||||||
|
const provider = this.providers[this.selectedIndex];
|
||||||
|
if (provider?.hasAuth) {
|
||||||
|
this.authStorage.remove(provider.name);
|
||||||
|
this.loadProviders();
|
||||||
|
this.updateList();
|
||||||
|
this.tui.requestRender();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -83,6 +83,7 @@ import { appKey, appKeyHint, editorKey, formatKeyForDisplay, keyHint, rawKeyHint
|
||||||
import { LoginDialogComponent } from "./components/login-dialog.js";
|
import { LoginDialogComponent } from "./components/login-dialog.js";
|
||||||
import { ModelSelectorComponent } from "./components/model-selector.js";
|
import { ModelSelectorComponent } from "./components/model-selector.js";
|
||||||
import { OAuthSelectorComponent } from "./components/oauth-selector.js";
|
import { OAuthSelectorComponent } from "./components/oauth-selector.js";
|
||||||
|
import { ProviderManagerComponent } from "./components/provider-manager.js";
|
||||||
import { ScopedModelsSelectorComponent } from "./components/scoped-models-selector.js";
|
import { ScopedModelsSelectorComponent } from "./components/scoped-models-selector.js";
|
||||||
import { SessionSelectorComponent } from "./components/session-selector.js";
|
import { SessionSelectorComponent } from "./components/session-selector.js";
|
||||||
import { SelectSubmenu, SettingsSelectorComponent, THINKING_DESCRIPTIONS } from "./components/settings-selector.js";
|
import { SelectSubmenu, SettingsSelectorComponent, THINKING_DESCRIPTIONS } from "./components/settings-selector.js";
|
||||||
|
|
@ -1997,6 +1998,11 @@ export class InteractiveMode {
|
||||||
this.editor.setText("");
|
this.editor.setText("");
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
if (text === "/provider") {
|
||||||
|
this.showProviderManager();
|
||||||
|
this.editor.setText("");
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (text === "/login") {
|
if (text === "/login") {
|
||||||
this.showOAuthSelector("login");
|
this.showOAuthSelector("login");
|
||||||
this.editor.setText("");
|
this.editor.setText("");
|
||||||
|
|
@ -3746,6 +3752,37 @@ export class InteractiveMode {
|
||||||
this.showStatus("Resumed session");
|
this.showStatus("Resumed session");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private showProviderManager(): void {
|
||||||
|
this.showSelector((done) => {
|
||||||
|
const component = new ProviderManagerComponent(
|
||||||
|
this.ui,
|
||||||
|
this.session.modelRegistry.authStorage,
|
||||||
|
this.session.modelRegistry,
|
||||||
|
() => {
|
||||||
|
done();
|
||||||
|
this.ui.requestRender();
|
||||||
|
},
|
||||||
|
async (provider: string) => {
|
||||||
|
this.showStatus(`Discovering models for ${provider}...`);
|
||||||
|
try {
|
||||||
|
const results = await this.session.modelRegistry.discoverModels([provider]);
|
||||||
|
const result = results[0];
|
||||||
|
if (result?.error) {
|
||||||
|
this.showError(`Discovery failed: ${result.error}`);
|
||||||
|
} else {
|
||||||
|
this.showStatus(`Discovered ${result?.models.length ?? 0} models from ${provider}`);
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
this.showError(error instanceof Error ? error.message : String(error));
|
||||||
|
}
|
||||||
|
done();
|
||||||
|
this.ui.requestRender();
|
||||||
|
},
|
||||||
|
);
|
||||||
|
return { component, focus: component };
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
private async showOAuthSelector(mode: "login" | "logout"): Promise<void> {
|
private async showOAuthSelector(mode: "login" | "logout"): Promise<void> {
|
||||||
if (mode === "logout") {
|
if (mode === "logout") {
|
||||||
const providers = this.session.modelRegistry.authStorage.list();
|
const providers = this.session.modelRegistry.authStorage.list();
|
||||||
|
|
|
||||||
|
|
@ -511,8 +511,10 @@ async function handlePrefsWizard(
|
||||||
prefs.auto_supervisor = autoSup;
|
prefs.auto_supervisor = autoSup;
|
||||||
}
|
}
|
||||||
|
|
||||||
// ─── Git main branch ────────────────────────────────────────────────────
|
// ─── Git settings ───────────────────────────────────────────────────────
|
||||||
const git: Record<string, unknown> = (prefs.git as Record<string, unknown>) ?? {};
|
const git: Record<string, unknown> = (prefs.git as Record<string, unknown>) ?? {};
|
||||||
|
|
||||||
|
// main_branch
|
||||||
const currentBranch = git.main_branch ? String(git.main_branch) : "";
|
const currentBranch = git.main_branch ? String(git.main_branch) : "";
|
||||||
const branchInput = await ctx.ui.input(
|
const branchInput = await ctx.ui.input(
|
||||||
`Git main branch${currentBranch ? ` (current: ${currentBranch})` : ""}:`,
|
`Git main branch${currentBranch ? ` (current: ${currentBranch})` : ""}:`,
|
||||||
|
|
@ -526,6 +528,90 @@ async function handlePrefsWizard(
|
||||||
delete git.main_branch;
|
delete git.main_branch;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Boolean git toggles
|
||||||
|
const gitBooleanFields = [
|
||||||
|
{ key: "auto_push", label: "Auto-push commits after committing", defaultVal: false },
|
||||||
|
{ key: "push_branches", label: "Push milestone branches to remote", defaultVal: false },
|
||||||
|
{ key: "snapshots", label: "Create WIP snapshot commits during long tasks", defaultVal: false },
|
||||||
|
] as const;
|
||||||
|
|
||||||
|
for (const field of gitBooleanFields) {
|
||||||
|
const current = git[field.key];
|
||||||
|
const currentStr = current !== undefined ? String(current) : "";
|
||||||
|
const choice = await ctx.ui.select(
|
||||||
|
`${field.label}${currentStr ? ` (current: ${currentStr})` : ` (default: ${field.defaultVal})`}:`,
|
||||||
|
["true", "false", "(keep current)"],
|
||||||
|
);
|
||||||
|
if (choice && choice !== "(keep current)") {
|
||||||
|
git[field.key] = choice === "true";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// remote
|
||||||
|
const currentRemote = git.remote ? String(git.remote) : "";
|
||||||
|
const remoteInput = await ctx.ui.input(
|
||||||
|
`Git remote name${currentRemote ? ` (current: ${currentRemote})` : " (default: origin)"}:`,
|
||||||
|
currentRemote || "origin",
|
||||||
|
);
|
||||||
|
if (remoteInput !== null && remoteInput !== undefined) {
|
||||||
|
const val = remoteInput.trim();
|
||||||
|
if (val && val !== "origin") {
|
||||||
|
git.remote = val;
|
||||||
|
} else if (!val && currentRemote) {
|
||||||
|
delete git.remote;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// pre_merge_check
|
||||||
|
const currentPreMerge = git.pre_merge_check !== undefined ? String(git.pre_merge_check) : "";
|
||||||
|
const preMergeChoice = await ctx.ui.select(
|
||||||
|
`Pre-merge check${currentPreMerge ? ` (current: ${currentPreMerge})` : " (default: false)"}:`,
|
||||||
|
["true", "false", "auto", "(keep current)"],
|
||||||
|
);
|
||||||
|
if (preMergeChoice && preMergeChoice !== "(keep current)") {
|
||||||
|
if (preMergeChoice === "auto") {
|
||||||
|
git.pre_merge_check = "auto";
|
||||||
|
} else {
|
||||||
|
git.pre_merge_check = preMergeChoice === "true";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// commit_type
|
||||||
|
const currentCommitType = git.commit_type ? String(git.commit_type) : "";
|
||||||
|
const commitTypes = ["feat", "fix", "refactor", "docs", "test", "chore", "perf", "ci", "build", "style", "(inferred — default)", "(keep current)"];
|
||||||
|
const commitChoice = await ctx.ui.select(
|
||||||
|
`Default commit type${currentCommitType ? ` (current: ${currentCommitType})` : ""}:`,
|
||||||
|
commitTypes,
|
||||||
|
);
|
||||||
|
if (commitChoice && typeof commitChoice === "string" && commitChoice !== "(keep current)") {
|
||||||
|
if ((commitChoice as string).startsWith("(inferred")) {
|
||||||
|
delete git.commit_type;
|
||||||
|
} else {
|
||||||
|
git.commit_type = commitChoice;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// merge_strategy
|
||||||
|
const currentMerge = git.merge_strategy ? String(git.merge_strategy) : "";
|
||||||
|
const mergeChoice = await ctx.ui.select(
|
||||||
|
`Merge strategy${currentMerge ? ` (current: ${currentMerge})` : ""}:`,
|
||||||
|
["squash", "merge", "(keep current)"],
|
||||||
|
);
|
||||||
|
if (mergeChoice && mergeChoice !== "(keep current)") {
|
||||||
|
git.merge_strategy = mergeChoice;
|
||||||
|
}
|
||||||
|
|
||||||
|
// isolation
|
||||||
|
const currentIsolation = git.isolation ? String(git.isolation) : "";
|
||||||
|
const isolationChoice = await ctx.ui.select(
|
||||||
|
`Git isolation strategy${currentIsolation ? ` (current: ${currentIsolation})` : " (default: worktree)"}:`,
|
||||||
|
["worktree", "branch", "(keep current)"],
|
||||||
|
);
|
||||||
|
if (isolationChoice && isolationChoice !== "(keep current)") {
|
||||||
|
git.isolation = isolationChoice;
|
||||||
|
}
|
||||||
|
|
||||||
// ─── Git commit_docs ────────────────────────────────────────────────────
|
// ─── Git commit_docs ────────────────────────────────────────────────────
|
||||||
const currentCommitDocs = git.commit_docs;
|
const currentCommitDocs = git.commit_docs;
|
||||||
const commitDocsChoice = await ctx.ui.select(
|
const commitDocsChoice = await ctx.ui.select(
|
||||||
|
|
@ -560,6 +646,89 @@ async function handlePrefsWizard(
|
||||||
prefs.unique_milestone_ids = uniqueChoice === "true";
|
prefs.unique_milestone_ids = uniqueChoice === "true";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ─── Budget & cost control ────────────────────────────────────────────
|
||||||
|
const currentCeiling = prefs.budget_ceiling;
|
||||||
|
const ceilingStr = currentCeiling !== undefined ? String(currentCeiling) : "";
|
||||||
|
const ceilingInput = await ctx.ui.input(
|
||||||
|
`Budget ceiling (USD)${ceilingStr ? ` (current: $${ceilingStr})` : " (default: no limit)"}:`,
|
||||||
|
ceilingStr || "",
|
||||||
|
);
|
||||||
|
if (ceilingInput !== null && ceilingInput !== undefined) {
|
||||||
|
const val = ceilingInput.trim().replace(/^\$/, "");
|
||||||
|
if (val && !isNaN(Number(val)) && isFinite(Number(val))) {
|
||||||
|
prefs.budget_ceiling = Number(val);
|
||||||
|
} else if (val && (isNaN(Number(val)) || !isFinite(Number(val)))) {
|
||||||
|
ctx.ui.notify(`Invalid budget ceiling "${val}" — must be a number. Keeping previous value.`, "warning");
|
||||||
|
} else if (!val && ceilingStr) {
|
||||||
|
delete prefs.budget_ceiling;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const currentEnforcement = (prefs.budget_enforcement as string) ?? "";
|
||||||
|
const enforcementChoice = await ctx.ui.select(
|
||||||
|
`Budget enforcement${currentEnforcement ? ` (current: ${currentEnforcement})` : " (default: pause)"}:`,
|
||||||
|
["warn", "pause", "halt", "(keep current)"],
|
||||||
|
);
|
||||||
|
if (enforcementChoice && enforcementChoice !== "(keep current)") {
|
||||||
|
prefs.budget_enforcement = enforcementChoice;
|
||||||
|
}
|
||||||
|
|
||||||
|
const currentContextPause = prefs.context_pause_threshold;
|
||||||
|
const contextPauseStr = currentContextPause !== undefined ? String(currentContextPause) : "";
|
||||||
|
const contextPauseInput = await ctx.ui.input(
|
||||||
|
`Context pause threshold (0-100%, 0=disabled)${contextPauseStr ? ` (current: ${contextPauseStr}%)` : " (default: 0)"}:`,
|
||||||
|
contextPauseStr || "0",
|
||||||
|
);
|
||||||
|
if (contextPauseInput !== null && contextPauseInput !== undefined) {
|
||||||
|
const val = contextPauseInput.trim().replace(/%$/, "");
|
||||||
|
if (val && !isNaN(Number(val)) && Number(val) >= 0 && Number(val) <= 100) {
|
||||||
|
const num = Number(val);
|
||||||
|
if (num === 0) {
|
||||||
|
delete prefs.context_pause_threshold;
|
||||||
|
} else {
|
||||||
|
prefs.context_pause_threshold = num;
|
||||||
|
}
|
||||||
|
} else if (val && (isNaN(Number(val)) || Number(val) < 0 || Number(val) > 100)) {
|
||||||
|
ctx.ui.notify(`Invalid context pause threshold "${val}" — must be 0-100. Keeping previous value.`, "warning");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Notifications ────────────────────────────────────────────────────
|
||||||
|
const notif: Record<string, boolean> = (prefs.notifications as Record<string, boolean>) ?? {};
|
||||||
|
const notifFields = [
|
||||||
|
{ key: "enabled", label: "Notifications enabled (master toggle)", defaultVal: true },
|
||||||
|
{ key: "on_complete", label: "Notify on unit completion", defaultVal: true },
|
||||||
|
{ key: "on_error", label: "Notify on errors", defaultVal: true },
|
||||||
|
{ key: "on_budget", label: "Notify on budget thresholds", defaultVal: true },
|
||||||
|
{ key: "on_milestone", label: "Notify on milestone completion", defaultVal: true },
|
||||||
|
{ key: "on_attention", label: "Notify when manual attention needed", defaultVal: true },
|
||||||
|
] as const;
|
||||||
|
|
||||||
|
for (const field of notifFields) {
|
||||||
|
const current = notif[field.key];
|
||||||
|
const currentStr = current !== undefined ? String(current) : "";
|
||||||
|
const choice = await ctx.ui.select(
|
||||||
|
`${field.label}${currentStr ? ` (current: ${currentStr})` : ` (default: ${field.defaultVal})`}:`,
|
||||||
|
["true", "false", "(keep current)"],
|
||||||
|
);
|
||||||
|
if (choice && choice !== "(keep current)") {
|
||||||
|
notif[field.key] = choice === "true";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (Object.keys(notif).length > 0) {
|
||||||
|
prefs.notifications = notif;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── UAT dispatch ─────────────────────────────────────────────────────
|
||||||
|
const currentUat = prefs.uat_dispatch;
|
||||||
|
const uatChoice = await ctx.ui.select(
|
||||||
|
`UAT dispatch mode${currentUat !== undefined ? ` (current: ${currentUat})` : " (default: false)"}:`,
|
||||||
|
["true", "false", "(keep current)"],
|
||||||
|
);
|
||||||
|
if (uatChoice && uatChoice !== "(keep current)") {
|
||||||
|
prefs.uat_dispatch = uatChoice === "true";
|
||||||
|
}
|
||||||
|
|
||||||
// ─── Serialize to frontmatter ───────────────────────────────────────────
|
// ─── Serialize to frontmatter ───────────────────────────────────────────
|
||||||
prefs.version = prefs.version || 1;
|
prefs.version = prefs.version || 1;
|
||||||
const frontmatter = serializePreferencesToFrontmatter(prefs);
|
const frontmatter = serializePreferencesToFrontmatter(prefs);
|
||||||
|
|
@ -650,7 +819,10 @@ function serializePreferencesToFrontmatter(prefs: Record<string, unknown>): stri
|
||||||
const orderedKeys = [
|
const orderedKeys = [
|
||||||
"version", "always_use_skills", "prefer_skills", "avoid_skills",
|
"version", "always_use_skills", "prefer_skills", "avoid_skills",
|
||||||
"skill_rules", "custom_instructions", "models", "skill_discovery",
|
"skill_rules", "custom_instructions", "models", "skill_discovery",
|
||||||
"auto_supervisor", "uat_dispatch", "unique_milestone_ids", "budget_ceiling", "remote_questions", "git",
|
"auto_supervisor", "uat_dispatch", "unique_milestone_ids",
|
||||||
|
"budget_ceiling", "budget_enforcement", "context_pause_threshold",
|
||||||
|
"notifications", "remote_questions", "git",
|
||||||
|
"post_unit_hooks", "pre_dispatch_hooks",
|
||||||
];
|
];
|
||||||
|
|
||||||
const seen = new Set<string>();
|
const seen = new Set<string>();
|
||||||
|
|
|
||||||
|
|
@ -108,10 +108,51 @@ Setting `prefer_skills: []` does **not** disable skill discovery — it just mea
|
||||||
- `pre_merge_check`: boolean or `"auto"` — run pre-merge checks before merging a worktree back to the integration branch. `true` always runs, `false` never runs, `"auto"` runs when CI is detected. Default: `false`.
|
- `pre_merge_check`: boolean or `"auto"` — run pre-merge checks before merging a worktree back to the integration branch. `true` always runs, `false` never runs, `"auto"` runs when CI is detected. Default: `false`.
|
||||||
- `commit_type`: string — override the conventional commit type prefix. Must be one of: `feat`, `fix`, `refactor`, `docs`, `test`, `chore`, `perf`, `ci`, `build`, `style`. Default: inferred from diff content.
|
- `commit_type`: string — override the conventional commit type prefix. Must be one of: `feat`, `fix`, `refactor`, `docs`, `test`, `chore`, `perf`, `ci`, `build`, `style`. Default: inferred from diff content.
|
||||||
- `main_branch`: string — the primary branch name for new git repos (e.g., `"main"`, `"master"`, `"trunk"`). Also used by `getMainBranch()` as the preferred branch when auto-detection is ambiguous. Default: `"main"`.
|
- `main_branch`: string — the primary branch name for new git repos (e.g., `"main"`, `"master"`, `"trunk"`). Also used by `getMainBranch()` as the preferred branch when auto-detection is ambiguous. Default: `"main"`.
|
||||||
|
- `merge_strategy`: `"squash"` or `"merge"` — controls how worktree branches are merged back. `"squash"` combines all commits into one; `"merge"` preserves individual commits. Default: `"squash"`.
|
||||||
|
- `isolation`: `"worktree"` or `"branch"` — controls auto-mode git isolation strategy. `"worktree"` creates a milestone worktree for isolated work; `"branch"` works directly in the project root (useful for submodule-heavy repos). Default: `"worktree"`.
|
||||||
- `commit_docs`: boolean — when `false`, prevents GSD from committing `.gsd/` planning artifacts to git. The `.gsd/` folder is added to `.gitignore` and kept local-only. Useful for teams where only some members use GSD, or when company policy requires a clean repository. Default: `true`.
|
- `commit_docs`: boolean — when `false`, prevents GSD from committing `.gsd/` planning artifacts to git. The `.gsd/` folder is added to `.gitignore` and kept local-only. Useful for teams where only some members use GSD, or when company policy requires a clean repository. Default: `true`.
|
||||||
|
|
||||||
- `unique_milestone_ids`: boolean — when `true`, generates milestone IDs in `M{seq}-{rand6}` format (e.g. `M001-eh88as`) instead of plain sequential `M001`. Prevents ID collisions in team workflows where multiple contributors create milestones concurrently. Both formats coexist — existing `M001`-style milestones remain valid. Default: `false`.
|
- `unique_milestone_ids`: boolean — when `true`, generates milestone IDs in `M{seq}-{rand6}` format (e.g. `M001-eh88as`) instead of plain sequential `M001`. Prevents ID collisions in team workflows where multiple contributors create milestones concurrently. Both formats coexist — existing `M001`-style milestones remain valid. Default: `false`.
|
||||||
|
|
||||||
|
- `budget_ceiling`: number — maximum dollar amount to spend on auto-mode. When reached, behavior is controlled by `budget_enforcement`. Default: no limit.
|
||||||
|
|
||||||
|
- `budget_enforcement`: `"warn"`, `"pause"`, or `"halt"` — action taken when `budget_ceiling` is reached.
|
||||||
|
- `warn` — log a warning but continue execution.
|
||||||
|
- `pause` — pause auto-mode and wait for user confirmation.
|
||||||
|
- `halt` — stop auto-mode immediately.
|
||||||
|
- Default: `"pause"`.
|
||||||
|
|
||||||
|
- `context_pause_threshold`: number (0-100) — context window usage percentage at which auto-mode should pause to suggest checkpointing. Set to `0` to disable. Default: `0` (disabled).
|
||||||
|
|
||||||
|
- `notifications`: configures desktop notification behavior during auto-mode. Keys:
|
||||||
|
- `enabled`: boolean — master toggle for all notifications. Default: `true`.
|
||||||
|
- `on_complete`: boolean — notify when a unit completes. Default: `true`.
|
||||||
|
- `on_error`: boolean — notify on errors. Default: `true`.
|
||||||
|
- `on_budget`: boolean — notify when budget thresholds are reached. Default: `true`.
|
||||||
|
- `on_milestone`: boolean — notify when a milestone finishes. Default: `true`.
|
||||||
|
- `on_attention`: boolean — notify when manual attention is needed. Default: `true`.
|
||||||
|
|
||||||
|
- `uat_dispatch`: boolean — when `true`, enables UAT (User Acceptance Testing) dispatch mode. Default: `false`.
|
||||||
|
|
||||||
|
- `post_unit_hooks`: array — hooks that fire after a unit completes. Each entry has:
|
||||||
|
- `name`: string — unique hook identifier.
|
||||||
|
- `after`: string[] — unit types that trigger this hook (e.g., `["execute-task"]`).
|
||||||
|
- `prompt`: string — prompt sent to the LLM. Supports `{milestoneId}`, `{sliceId}`, `{taskId}` substitutions.
|
||||||
|
- `max_cycles`: number — max times this hook fires per trigger (default: 1, max: 10).
|
||||||
|
- `model`: string — optional model override.
|
||||||
|
- `artifact`: string — expected output file (skip if exists).
|
||||||
|
- `retry_on`: string — file that triggers re-run of the trigger unit.
|
||||||
|
- `enabled`: boolean — toggle without removing (default: `true`).
|
||||||
|
|
||||||
|
- `pre_dispatch_hooks`: array — hooks that fire before a unit is dispatched. Each entry has:
|
||||||
|
- `name`: string — unique hook identifier.
|
||||||
|
- `before`: string[] — unit types to intercept.
|
||||||
|
- `action`: `"modify"`, `"skip"`, or `"replace"` — what to do with the unit.
|
||||||
|
- `prepend`: string — text prepended to unit prompt (for `"modify"` action).
|
||||||
|
- `append`: string — text appended to unit prompt (for `"modify"` action).
|
||||||
|
- `prompt`: string — replacement prompt (for `"replace"` action).
|
||||||
|
- `enabled`: boolean — toggle without removing (default: `true`).
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Best Practices
|
## Best Practices
|
||||||
|
|
@ -277,3 +318,56 @@ git:
|
||||||
```
|
```
|
||||||
|
|
||||||
All git fields are optional. Omit any field to use the default behavior. Project-level preferences override global preferences on a per-field basis.
|
All git fields are optional. Omit any field to use the default behavior. Project-level preferences override global preferences on a per-field basis.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Budget & Cost Control Example
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
---
|
||||||
|
version: 1
|
||||||
|
budget_ceiling: 10.00
|
||||||
|
budget_enforcement: pause
|
||||||
|
context_pause_threshold: 80
|
||||||
|
---
|
||||||
|
```
|
||||||
|
|
||||||
|
Sets a $10 budget ceiling. Auto-mode pauses when the ceiling is reached. Context window pauses at 80% usage for checkpointing.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Notifications Example
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
---
|
||||||
|
version: 1
|
||||||
|
notifications:
|
||||||
|
enabled: true
|
||||||
|
on_complete: false
|
||||||
|
on_error: true
|
||||||
|
on_budget: true
|
||||||
|
on_milestone: true
|
||||||
|
on_attention: true
|
||||||
|
---
|
||||||
|
```
|
||||||
|
|
||||||
|
Disables per-unit completion notifications (noisy in long runs) while keeping error, budget, milestone, and attention notifications enabled.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Post-Unit Hooks Example
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
---
|
||||||
|
version: 1
|
||||||
|
post_unit_hooks:
|
||||||
|
- name: code-review
|
||||||
|
after:
|
||||||
|
- execute-task
|
||||||
|
prompt: "Review the code changes in {sliceId}/{taskId} for quality, security, and test coverage."
|
||||||
|
max_cycles: 1
|
||||||
|
artifact: REVIEW.md
|
||||||
|
---
|
||||||
|
```
|
||||||
|
|
||||||
|
Runs an automated code review after each task execution. Skips if `REVIEW.md` already exists (idempotent).
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
import { existsSync, readdirSync, readFileSync, statSync } from "node:fs";
|
import { existsSync, readdirSync, readFileSync, statSync, writeFileSync } from "node:fs";
|
||||||
import { homedir } from "node:os";
|
import { homedir } from "node:os";
|
||||||
import { isAbsolute, join } from "node:path";
|
import { isAbsolute, join } from "node:path";
|
||||||
import { getAgentDir } from "@gsd/pi-coding-agent";
|
import { getAgentDir } from "@gsd/pi-coding-agent";
|
||||||
|
|
@ -1252,3 +1252,61 @@ export function resolvePreDispatchHooks(): PreDispatchHookConfig[] {
|
||||||
return (prefs?.preferences.pre_dispatch_hooks ?? [])
|
return (prefs?.preferences.pre_dispatch_hooks ?? [])
|
||||||
.filter(h => h.enabled !== false);
|
.filter(h => h.enabled !== false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate a model ID string.
|
||||||
|
* Returns true if the ID looks like a valid model identifier.
|
||||||
|
*/
|
||||||
|
export function validateModelId(modelId: string): boolean {
|
||||||
|
if (!modelId || typeof modelId !== "string") return false;
|
||||||
|
const trimmed = modelId.trim();
|
||||||
|
if (trimmed.length === 0 || trimmed.length > 256) return false;
|
||||||
|
// Allow alphanumeric, hyphens, underscores, dots, slashes, colons
|
||||||
|
return /^[a-zA-Z0-9\-_./:]+$/.test(trimmed);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Update the models section of the global GSD preferences file.
|
||||||
|
* Performs a safe read-modify-write: reads current content, updates the models
|
||||||
|
* YAML block, and writes back. Creates the file if it doesn't exist.
|
||||||
|
*/
|
||||||
|
export function updatePreferencesModels(models: GSDModelConfigV2): void {
|
||||||
|
const prefsPath = getGlobalGSDPreferencesPath();
|
||||||
|
|
||||||
|
let content = "";
|
||||||
|
if (existsSync(prefsPath)) {
|
||||||
|
content = readFileSync(prefsPath, "utf-8");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the new models block
|
||||||
|
const lines: string[] = ["models:"];
|
||||||
|
for (const [phase, value] of Object.entries(models)) {
|
||||||
|
if (typeof value === "string") {
|
||||||
|
lines.push(` ${phase}: ${value}`);
|
||||||
|
} else if (value && typeof value === "object") {
|
||||||
|
const config = value as GSDPhaseModelConfig;
|
||||||
|
lines.push(` ${phase}:`);
|
||||||
|
lines.push(` model: ${config.model}`);
|
||||||
|
if (config.provider) {
|
||||||
|
lines.push(` provider: ${config.provider}`);
|
||||||
|
}
|
||||||
|
if (config.fallbacks && config.fallbacks.length > 0) {
|
||||||
|
lines.push(` fallbacks:`);
|
||||||
|
for (const fb of config.fallbacks) {
|
||||||
|
lines.push(` - ${fb}`);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const modelsBlock = lines.join("\n");
|
||||||
|
|
||||||
|
// Replace existing models block or append
|
||||||
|
const modelsRegex = /^models:[\s\S]*?(?=\n[a-z_]|\n*$)/m;
|
||||||
|
if (modelsRegex.test(content)) {
|
||||||
|
content = content.replace(modelsRegex, modelsBlock);
|
||||||
|
} else {
|
||||||
|
content = content.trimEnd() + "\n\n" + modelsBlock + "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
writeFileSync(prefsPath, content, "utf-8");
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,21 @@ git:
|
||||||
snapshots:
|
snapshots:
|
||||||
pre_merge_check:
|
pre_merge_check:
|
||||||
commit_type:
|
commit_type:
|
||||||
|
main_branch:
|
||||||
|
merge_strategy:
|
||||||
|
isolation:
|
||||||
unique_milestone_ids:
|
unique_milestone_ids:
|
||||||
|
budget_ceiling:
|
||||||
|
budget_enforcement:
|
||||||
|
context_pause_threshold:
|
||||||
|
notifications:
|
||||||
|
enabled:
|
||||||
|
on_complete:
|
||||||
|
on_error:
|
||||||
|
on_budget:
|
||||||
|
on_milestone:
|
||||||
|
on_attention:
|
||||||
|
uat_dispatch:
|
||||||
---
|
---
|
||||||
|
|
||||||
# GSD Skill Preferences
|
# GSD Skill Preferences
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,168 @@
|
||||||
|
/**
|
||||||
|
* preferences-wizard-fields.test.ts — Validates that all wizard-configurable
|
||||||
|
* preference fields are properly validated and round-trip through the schema.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { createTestContext } from "./test-helpers.ts";
|
||||||
|
import { validatePreferences } from "../preferences.ts";
|
||||||
|
import type { GSDPreferences } from "../preferences.ts";
|
||||||
|
|
||||||
|
const { assertEq, assertTrue, report } = createTestContext();
|
||||||
|
|
||||||
|
async function main(): Promise<void> {
|
||||||
|
console.log("\n=== budget fields validate correctly ===");
|
||||||
|
|
||||||
|
{
|
||||||
|
const { preferences, errors } = validatePreferences({
|
||||||
|
budget_ceiling: 25.50,
|
||||||
|
budget_enforcement: "warn",
|
||||||
|
context_pause_threshold: 80,
|
||||||
|
});
|
||||||
|
assertEq(errors.length, 0, "valid budget fields produce no errors");
|
||||||
|
assertEq(preferences.budget_ceiling, 25.50, "budget_ceiling passes through");
|
||||||
|
assertEq(preferences.budget_enforcement, "warn", "budget_enforcement passes through");
|
||||||
|
assertEq(preferences.context_pause_threshold, 80, "context_pause_threshold passes through");
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
const { preferences, errors } = validatePreferences({
|
||||||
|
budget_enforcement: "pause",
|
||||||
|
});
|
||||||
|
assertEq(errors.length, 0, "budget_enforcement 'pause' is valid");
|
||||||
|
assertEq(preferences.budget_enforcement, "pause", "pause passes through");
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
const { preferences, errors } = validatePreferences({
|
||||||
|
budget_enforcement: "halt",
|
||||||
|
});
|
||||||
|
assertEq(errors.length, 0, "budget_enforcement 'halt' is valid");
|
||||||
|
assertEq(preferences.budget_enforcement, "halt", "halt passes through");
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
const { errors } = validatePreferences({
|
||||||
|
budget_enforcement: "invalid",
|
||||||
|
} as unknown as GSDPreferences);
|
||||||
|
assertTrue(errors.some(e => e.includes("budget_enforcement")), "invalid budget_enforcement rejected");
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log("\n=== notification fields validate correctly ===");
|
||||||
|
|
||||||
|
{
|
||||||
|
const { preferences, errors } = validatePreferences({
|
||||||
|
notifications: {
|
||||||
|
enabled: true,
|
||||||
|
on_complete: false,
|
||||||
|
on_error: true,
|
||||||
|
on_budget: true,
|
||||||
|
on_milestone: false,
|
||||||
|
on_attention: true,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
assertEq(errors.length, 0, "valid notifications produce no errors");
|
||||||
|
assertEq(preferences.notifications?.enabled, true, "notifications.enabled passes through");
|
||||||
|
assertEq(preferences.notifications?.on_complete, false, "notifications.on_complete passes through");
|
||||||
|
assertEq(preferences.notifications?.on_milestone, false, "notifications.on_milestone passes through");
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
const { errors } = validatePreferences({
|
||||||
|
notifications: "invalid",
|
||||||
|
} as unknown as GSDPreferences);
|
||||||
|
assertTrue(errors.some(e => e.includes("notifications")), "invalid notifications rejected");
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log("\n=== git fields validate correctly ===");
|
||||||
|
|
||||||
|
{
|
||||||
|
const { preferences, errors } = validatePreferences({
|
||||||
|
git: {
|
||||||
|
auto_push: true,
|
||||||
|
push_branches: false,
|
||||||
|
remote: "upstream",
|
||||||
|
snapshots: true,
|
||||||
|
pre_merge_check: "auto",
|
||||||
|
commit_type: "feat",
|
||||||
|
main_branch: "develop",
|
||||||
|
merge_strategy: "squash",
|
||||||
|
isolation: "branch",
|
||||||
|
},
|
||||||
|
});
|
||||||
|
assertEq(errors.length, 0, "valid git fields produce no errors");
|
||||||
|
assertEq(preferences.git?.auto_push, true, "git.auto_push passes through");
|
||||||
|
assertEq(preferences.git?.push_branches, false, "git.push_branches passes through");
|
||||||
|
assertEq(preferences.git?.remote, "upstream", "git.remote passes through");
|
||||||
|
assertEq(preferences.git?.snapshots, true, "git.snapshots passes through");
|
||||||
|
assertEq(preferences.git?.pre_merge_check, "auto", "git.pre_merge_check passes through");
|
||||||
|
assertEq(preferences.git?.commit_type, "feat", "git.commit_type passes through");
|
||||||
|
assertEq(preferences.git?.main_branch, "develop", "git.main_branch passes through");
|
||||||
|
assertEq(preferences.git?.merge_strategy, "squash", "git.merge_strategy passes through");
|
||||||
|
assertEq(preferences.git?.isolation, "branch", "git.isolation passes through");
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log("\n=== uat_dispatch validates correctly ===");
|
||||||
|
|
||||||
|
{
|
||||||
|
const { preferences, errors } = validatePreferences({ uat_dispatch: true });
|
||||||
|
assertEq(errors.length, 0, "valid uat_dispatch produces no errors");
|
||||||
|
assertEq(preferences.uat_dispatch, true, "uat_dispatch true passes through");
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
const { preferences, errors } = validatePreferences({ uat_dispatch: false });
|
||||||
|
assertEq(errors.length, 0, "valid uat_dispatch false produces no errors");
|
||||||
|
assertEq(preferences.uat_dispatch, false, "uat_dispatch false passes through");
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log("\n=== unique_milestone_ids validates correctly ===");
|
||||||
|
|
||||||
|
{
|
||||||
|
const { preferences, errors } = validatePreferences({ unique_milestone_ids: true });
|
||||||
|
assertEq(errors.length, 0, "valid unique_milestone_ids produces no errors");
|
||||||
|
assertEq(preferences.unique_milestone_ids, true, "unique_milestone_ids passes through");
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log("\n=== all wizard fields together produce no errors ===");
|
||||||
|
|
||||||
|
{
|
||||||
|
const fullPrefs: GSDPreferences = {
|
||||||
|
version: 1,
|
||||||
|
models: { research: "claude-opus-4-6", planning: "claude-sonnet-4-6" },
|
||||||
|
auto_supervisor: { soft_timeout_minutes: 15, idle_timeout_minutes: 5, hard_timeout_minutes: 25 },
|
||||||
|
git: {
|
||||||
|
main_branch: "main",
|
||||||
|
auto_push: true,
|
||||||
|
push_branches: false,
|
||||||
|
remote: "origin",
|
||||||
|
snapshots: true,
|
||||||
|
pre_merge_check: "auto",
|
||||||
|
commit_type: "feat",
|
||||||
|
merge_strategy: "squash",
|
||||||
|
isolation: "worktree",
|
||||||
|
},
|
||||||
|
skill_discovery: "suggest",
|
||||||
|
unique_milestone_ids: false,
|
||||||
|
budget_ceiling: 50,
|
||||||
|
budget_enforcement: "pause",
|
||||||
|
context_pause_threshold: 75,
|
||||||
|
notifications: {
|
||||||
|
enabled: true,
|
||||||
|
on_complete: true,
|
||||||
|
on_error: true,
|
||||||
|
on_budget: true,
|
||||||
|
on_milestone: true,
|
||||||
|
on_attention: true,
|
||||||
|
},
|
||||||
|
uat_dispatch: false,
|
||||||
|
};
|
||||||
|
const { errors, warnings } = validatePreferences(fullPrefs);
|
||||||
|
const unknownWarnings = warnings.filter(w => w.includes("unknown"));
|
||||||
|
assertEq(errors.length, 0, "full wizard prefs produce no errors");
|
||||||
|
assertEq(unknownWarnings.length, 0, "full wizard prefs produce no unknown-key warnings");
|
||||||
|
}
|
||||||
|
|
||||||
|
report();
|
||||||
|
}
|
||||||
|
|
||||||
|
main();
|
||||||
Loading…
Add table
Reference in a new issue