From 00cb2f36a8c5f35421cbda236440f85885c4208d Mon Sep 17 00:00:00 2001 From: dan Date: Thu, 12 Mar 2026 20:44:01 -0700 Subject: [PATCH] fix: reuse pi provider config and extension loading --- src/cli.ts | 24 ++++-- src/onboarding.ts | 7 +- src/pi-migration.ts | 22 +++++- src/resource-loader.ts | 78 ++++++++++++++++++- .../extensions/google-search/index.ts | 21 ++++- src/tests/app-smoke.test.ts | 51 ++++++++++++ 6 files changed, 181 insertions(+), 22 deletions(-) diff --git a/src/cli.ts b/src/cli.ts index 6d1737c9a..faf5f81a0 100644 --- a/src/cli.ts +++ b/src/cli.ts @@ -15,7 +15,7 @@ import { agentDir, sessionsDir, authFilePath } from './app-paths.js' import { initResources, buildResourceLoader } from './resource-loader.js' import { ensureManagedTools } from './tool-bootstrap.js' import { loadStoredEnvKeys } from './wizard.js' -import { migratePiCredentials } from './pi-migration.js' +import { getPiDefaultModelAndProvider, migratePiCredentials } from './pi-migration.js' import { shouldRunOnboarding, runOnboarding } from './onboarding.js' // --------------------------------------------------------------------------- @@ -115,22 +115,30 @@ const settingsManager = SettingsManager.create(agentDir) const configuredProvider = settingsManager.getDefaultProvider() const configuredModel = settingsManager.getDefaultModel() const allModels = modelRegistry.getAll() +const availableModels = modelRegistry.getAvailable() const configuredExists = configuredProvider && configuredModel && allModels.some((m) => m.provider === configuredProvider && m.id === configuredModel) +const configuredAvailable = configuredProvider && configuredModel && + availableModels.some((m) => m.provider === configuredProvider && m.id === configuredModel) -if (!configuredModel || !configuredExists) { - // Fallback: pick the best available Anthropic model +if (!configuredModel || !configuredExists || !configuredAvailable) { + const piDefault = getPiDefaultModelAndProvider() const preferred = - allModels.find((m) => m.provider === 'anthropic' && m.id === 'claude-opus-4-6') || - allModels.find((m) => m.provider === 'anthropic' && m.id.includes('opus')) || - allModels.find((m) => m.provider === 'anthropic') + (piDefault + ? availableModels.find((m) => m.provider === piDefault.provider && m.id === piDefault.model) + : undefined) || + availableModels.find((m) => m.provider === 'openai' && m.id === 'gpt-5.4') || + availableModels.find((m) => m.provider === 'openai') || + availableModels.find((m) => m.provider === 'anthropic' && m.id === 'claude-opus-4-6') || + availableModels.find((m) => m.provider === 'anthropic' && m.id.includes('opus')) || + availableModels.find((m) => m.provider === 'anthropic') || + availableModels[0] if (preferred) { settingsManager.setDefaultModelAndProvider(preferred.provider, preferred.id) } } -// Default thinking level: off (always reset if not explicitly set) -if (settingsManager.getDefaultThinkingLevel() !== 'off' && !configuredExists) { +if (settingsManager.getDefaultThinkingLevel() !== 'off' && (!configuredExists || !configuredAvailable)) { settingsManager.setDefaultThinkingLevel('off') } diff --git a/src/onboarding.ts b/src/onboarding.ts index 48345ff33..09dd15ae2 100644 --- a/src/onboarding.ts +++ b/src/onboarding.ts @@ -152,18 +152,17 @@ function isCancelError(p: ClackModule, err: unknown): boolean { * Determine if the onboarding wizard should run. * * Returns true when: - * - No LLM provider has credentials in authStorage + * - No LLM provider auth is available * - We're on a TTY (interactive terminal) * * Returns false (skip wizard) when: - * - Any LLM provider is already authed (returning user) + * - Any LLM provider is already available via auth.json, env vars, runtime overrides, or fallback auth * - Not a TTY (piped input, subagent, CI) */ export function shouldRunOnboarding(authStorage: AuthStorage): boolean { if (!process.stdin.isTTY) return false // Check if any LLM provider has credentials - const authedProviders = authStorage.list() - const hasLlmAuth = authedProviders.some(id => LLM_PROVIDER_IDS.includes(id)) + const hasLlmAuth = LLM_PROVIDER_IDS.some(id => authStorage.hasAuth(id)) return !hasLlmAuth } diff --git a/src/pi-migration.ts b/src/pi-migration.ts index 3fa15902c..93a3c0b74 100644 --- a/src/pi-migration.ts +++ b/src/pi-migration.ts @@ -10,6 +10,7 @@ import { join } from 'node:path' import type { AuthStorage, AuthCredential } from '@mariozechner/pi-coding-agent' const PI_AUTH_PATH = join(homedir(), '.pi', 'agent', 'auth.json') +const PI_SETTINGS_PATH = join(homedir(), '.pi', 'agent', 'settings.json') const LLM_PROVIDER_IDS = [ 'anthropic', @@ -34,7 +35,6 @@ const LLM_PROVIDER_IDS = [ */ export function migratePiCredentials(authStorage: AuthStorage): boolean { try { - // Only migrate when GSD has no LLM providers const existing = authStorage.list() const hasLlm = existing.some(id => LLM_PROVIDER_IDS.includes(id)) if (hasLlm) return false @@ -55,7 +55,25 @@ export function migratePiCredentials(authStorage: AuthStorage): boolean { return migratedLlm } catch { - // Non-fatal — don't block startup return false } } + +export function getPiDefaultModelAndProvider(): { provider: string; model: string } | null { + try { + if (!existsSync(PI_SETTINGS_PATH)) return null + + const raw = readFileSync(PI_SETTINGS_PATH, 'utf-8') + const data = JSON.parse(raw) as { defaultProvider?: unknown; defaultModel?: unknown } + if (typeof data.defaultProvider !== 'string' || typeof data.defaultModel !== 'string') { + return null + } + + return { + provider: data.defaultProvider, + model: data.defaultModel, + } + } catch { + return null + } +} diff --git a/src/resource-loader.ts b/src/resource-loader.ts index d7595dd4d..d11087b75 100644 --- a/src/resource-loader.ts +++ b/src/resource-loader.ts @@ -1,7 +1,7 @@ import { DefaultResourceLoader } from '@mariozechner/pi-coding-agent' import { homedir } from 'node:os' -import { cpSync, existsSync, mkdirSync, readFileSync, writeFileSync } from 'node:fs' -import { dirname, join, resolve } from 'node:path' +import { cpSync, existsSync, mkdirSync, readFileSync, readdirSync, writeFileSync } from 'node:fs' +import { dirname, join, relative, resolve } from 'node:path' import { fileURLToPath } from 'node:url' // Resolves to the bundled src/resources/ inside the npm package at runtime: @@ -9,6 +9,70 @@ import { fileURLToPath } from 'node:url' const resourcesDir = resolve(dirname(fileURLToPath(import.meta.url)), '..', 'src', 'resources') const bundledExtensionsDir = join(resourcesDir, 'extensions') +function isExtensionFile(name: string): boolean { + return name.endsWith('.ts') || name.endsWith('.js') +} + +function resolveExtensionEntries(dir: string): string[] { + const packageJsonPath = join(dir, 'package.json') + if (existsSync(packageJsonPath)) { + try { + const pkg = JSON.parse(readFileSync(packageJsonPath, 'utf-8')) + const declared = pkg?.pi?.extensions + if (Array.isArray(declared)) { + const resolved = declared + .filter((entry: unknown): entry is string => typeof entry === 'string') + .map((entry: string) => resolve(dir, entry)) + .filter((entry: string) => existsSync(entry)) + if (resolved.length > 0) { + return resolved + } + } + } catch { + // Ignore malformed manifests and fall back to index.ts/index.js discovery. + } + } + + const indexTs = join(dir, 'index.ts') + if (existsSync(indexTs)) { + return [indexTs] + } + + const indexJs = join(dir, 'index.js') + if (existsSync(indexJs)) { + return [indexJs] + } + + return [] +} + +export function discoverExtensionEntryPaths(extensionsDir: string): string[] { + if (!existsSync(extensionsDir)) { + return [] + } + + const discovered: string[] = [] + for (const entry of readdirSync(extensionsDir, { withFileTypes: true })) { + const entryPath = join(extensionsDir, entry.name) + + if ((entry.isFile() || entry.isSymbolicLink()) && isExtensionFile(entry.name)) { + discovered.push(entryPath) + continue + } + + if (entry.isDirectory() || entry.isSymbolicLink()) { + discovered.push(...resolveExtensionEntries(entryPath)) + } + } + + return discovered +} + +function getExtensionKey(entryPath: string, extensionsDir: string): string { + const relPath = relative(extensionsDir, entryPath) + return relPath.split(/[\\/]/)[0] +} + /** * Syncs all bundled resources to agentDir (~/.gsd/agent/) on every launch. * @@ -60,9 +124,15 @@ export function initResources(agentDir: string): void { export function buildResourceLoader(agentDir: string): DefaultResourceLoader { const piAgentDir = join(homedir(), '.pi', 'agent') const piExtensionsDir = join(piAgentDir, 'extensions') - + const bundledKeys = new Set( + discoverExtensionEntryPaths(bundledExtensionsDir).map((entryPath) => getExtensionKey(entryPath, bundledExtensionsDir)), + ) + const piExtensionPaths = discoverExtensionEntryPaths(piExtensionsDir).filter( + (entryPath) => !bundledKeys.has(getExtensionKey(entryPath, piExtensionsDir)), + ) + return new DefaultResourceLoader({ agentDir, - additionalExtensionPaths: [piExtensionsDir], + additionalExtensionPaths: piExtensionPaths, }) } diff --git a/src/resources/extensions/google-search/index.ts b/src/resources/extensions/google-search/index.ts index 409ae5b5f..ba8b97b54 100644 --- a/src/resources/extensions/google-search/index.ts +++ b/src/resources/extensions/google-search/index.ts @@ -19,7 +19,6 @@ import { } from "@mariozechner/pi-coding-agent"; import { Text } from "@mariozechner/pi-tui"; import { Type } from "@sinclair/typebox"; -import { GoogleGenAI } from "@google/genai"; // ── Types ──────────────────────────────────────────────────────────────────── @@ -46,10 +45,24 @@ interface SearchDetails { // ── Lazy singleton client ──────────────────────────────────────────────────── -let client: GoogleGenAI | null = null; +type GoogleGenAIClient = { + models: { + generateContent: (args: { + model: string; + contents: string; + config?: { + tools?: Array<{ googleSearch: Record }>; + abortSignal?: AbortSignal; + }; + }) => Promise; + }; +}; -function getClient(): GoogleGenAI { +let client: GoogleGenAIClient | null = null; + +async function getClient(): Promise { if (!client) { + const { GoogleGenAI } = await import("@google/genai"); client = new GoogleGenAI({ apiKey: process.env.GEMINI_API_KEY! }); } return client; @@ -139,7 +152,7 @@ export default function (pi: ExtensionAPI) { // Call Gemini with Google Search grounding let result: SearchResult; try { - const ai = getClient(); + const ai = await getClient(); const response = await ai.models.generateContent({ model: process.env.GEMINI_SEARCH_MODEL || "gemini-2.5-flash", contents: params.query, diff --git a/src/tests/app-smoke.test.ts b/src/tests/app-smoke.test.ts index c71df182e..d156b54fb 100644 --- a/src/tests/app-smoke.test.ts +++ b/src/tests/app-smoke.test.ts @@ -152,6 +152,57 @@ test("initResources syncs extensions, agents, and AGENTS.md to target dir", asyn // 4. wizard loadStoredEnvKeys hydration // ═══════════════════════════════════════════════════════════════════════════ +test("buildResourceLoader expands ~/.pi extension directories into entry files", async () => { + const originalHome = process.env.HOME; + const tmp = mkdtempSync(join(tmpdir(), "gsd-pi-ext-test-")); + const fakeHome = join(tmp, "home"); + const fakeAgentDir = join(tmp, "agent"); + const piExtensionsDir = join(fakeHome, ".pi", "agent", "extensions"); + mkdirSync(piExtensionsDir, { recursive: true }); + mkdirSync(fakeAgentDir, { recursive: true }); + + writeFileSync( + join(piExtensionsDir, "top-level.ts"), + "export default function(pi){ pi.on('agent_start', () => {}); }\n", + ); + + const packagedDir = join(piExtensionsDir, "packaged-ext"); + mkdirSync(packagedDir, { recursive: true }); + writeFileSync( + join(packagedDir, "package.json"), + JSON.stringify({ pi: { extensions: ["./custom-entry.ts"] } }, null, 2), + ); + writeFileSync( + join(packagedDir, "custom-entry.ts"), + "export default function(pi){ pi.on('agent_start', () => {}); }\n", + ); + + process.env.HOME = fakeHome; + + try { + const { buildResourceLoader } = await import("../resource-loader.ts"); + const loader = buildResourceLoader(fakeAgentDir); + await loader.reload(); + const { extensions, errors } = loader.getExtensions(); + + assert.ok( + extensions.some((ext) => ext.path.endsWith("top-level.ts")), + "loads top-level ~/.pi extension files", + ); + assert.ok( + extensions.some((ext) => ext.path.endsWith("packaged-ext/custom-entry.ts")), + "loads packaged ~/.pi extensions via pi.extensions manifest", + ); + assert.ok( + !errors.some((err) => err.path === piExtensionsDir), + "does not try to load the ~/.pi/agent/extensions directory itself as a module", + ); + } finally { + if (originalHome) process.env.HOME = originalHome; else delete process.env.HOME; + rmSync(tmp, { recursive: true, force: true }); + } +}); + test("loadStoredEnvKeys hydrates process.env from auth.json", async () => { const { loadStoredEnvKeys } = await import("../wizard.ts"); const { AuthStorage } = await import("@mariozechner/pi-coding-agent");