Merge branch 'gsd-build:main' into main

This commit is contained in:
Marcel Reschke 2026-03-13 09:35:06 +01:00 committed by GitHub
commit 2140a4de07
310 changed files with 89146 additions and 2652 deletions

View file

@ -1,30 +0,0 @@
name: Publish to npm
on:
push:
tags:
- 'v*'
jobs:
publish:
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version: '22'
registry-url: 'https://registry.npmjs.org'
- name: Install dependencies
run: npm ci
- name: Build
run: npm run build
- name: Publish to npm
run: npm publish --access public
env:
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}

13
.gitignore vendored
View file

@ -28,6 +28,10 @@ coverage/
.cache/
tmp/
# ── Workspace packages ──
packages/*/dist/
packages/*/node_modules/
# ── GSD baseline (auto-generated) ──
dist/
.bg_shell
@ -36,6 +40,15 @@ dist/
AGENTS.md
.bg-shell/
TODOS.md
.planning/
# ── GSD baseline (auto-generated) ──
.gsd/
# ── GSD baseline (auto-generated) ──
.gsd/activity/
.gsd/runtime/
.gsd/worktrees/
.gsd/auto.lock
.gsd/metrics.json
.gsd/STATE.md

View file

@ -6,6 +6,113 @@ Format based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
## [Unreleased]
## [2.7.0] - 2026-03-12
### Changed
- Vendor Pi SDK source (tui, ai, agent-core, coding-agent) into workspace monorepo under `packages/`, replacing the compiled npm dependency and patch-package workflow. Pi internals are now directly modifiable as TypeScript source.
- Existing patches (setModel persist option, Windows VT input caching) applied as source edits.
- Build pipeline runs workspace packages in dependency order before GSD compilation.
- Removed `patch-package` from devDependencies and postinstall.
## [2.6.0] - 2026-03-12
### Added
- Proactive secret management — planning phase forecasts required API keys into a manifest; auto-mode collects pending secrets before dispatching the first slice
- `--continue` / `-c` CLI flag to resume the most recent session
### Fixed
- Doctor post-hook no longer preempts `complete-slice` dispatch
- `main_branch` preference restored; `runPreMergeCheck` implemented for merge safety
- Recovery/retry prompt injection capped to prevent V8 OOM on large sessions
- `.gsd/` excluded from pre-switch auto-commits to prevent squash merge conflicts
## [2.5.1] - 2026-03-12
### Added
- `secure_env_collect` now auto-detects existing keys, destination files, and provides guidance field for better onboarding UX
### Changed
- Right-sized pipeline for simple work — single-slice milestones skip redundant research/plan sessions, reducing 9-10 sessions to 5-6
- Heavyweight plan sections (Proof Level, Integration Closure, Observability) are now conditional, omitted for simple slices
### Fixed
- Squash-merge now aborts cleanly on conflict and stops auto-mode instead of looping with corrupted state
- Resolved baked-in merge conflict markers in loader.ts, logo.ts, and postinstall.js
## [2.5.0] - 2026-03-12
### Added
- Native Anthropic web search — Claude models get server-side web search automatically, no Brave API key required
- GitService fully wired into codebase — programmatic git operations replace shell-based git commands in prompts
- Merge guards prevent slice completion when uncommitted changes or conflicts exist
- Snapshot support for saving and restoring `.gsd/` state
- Auto-push after slice squash-merge to main
- Rich commit messages with structured metadata
### Fixed
- State machine deadlock when units fail to produce expected artifacts — retry and cross-validation now gate completion
- Duplicate Brave search tools when toggling providers repeatedly
- Windows test glob patterns (single quotes → unquoted for shell expansion)
- Conversation replay error caused by thinking blocks in stored history
- Brave search tools removed from API payload when no `BRAVE_API_KEY` is set
- Restore notifications suppressed on session resume to reduce UX noise
## [2.4.0] - 2026-03-12
### Added
- Automatic migration of provider credentials from existing Pi installations — skip re-authentication when switching to GSD
- Pi extensions from `~/.pi/agent/extensions/` discoverable in interactive mode
- GitService core implementation for programmatic git operations
### Changed
- System prompt compressed by 48% (360 → 187 lines) for better context efficiency
- Refined agent character and communication style prompts
- Added craft standards, self-debugging awareness, and work narration to agent prompts
### Fixed
- RPC mode crash when `ctx.ui.theme` is undefined (#121)
## [2.3.11] - 2026-03-12
### Added
- Branded clack-based onboarding wizard on first launch — LLM provider selection (OAuth + API key), optional tool API keys, and setup summary (#118)
- `gsd config` subcommand to re-run the setup wizard anytime
- Shared `src/logo.ts` module as single source of truth for ASCII banner
### Fixed
- Parallel subagent results no longer truncated at 200 characters
### Changed
- `wizard.ts` trimmed to env hydration only — onboarding logic moved to `onboarding.ts`
- First-launch banner removed from `loader.ts` (onboarding wizard handles branding)
## [2.3.10] - 2026-03-12
### Added
- Branded postinstall experience with animated spinners, progress indicators, and clean summary (#115)
### Fixed
- Ctrl+Alt shortcuts (dashboard, bg manager, voice) now show slash-command fallback in terminals that lack Kitty keyboard protocol support — macOS Terminal.app, JetBrains IDEs (#100, #104)
## [2.3.9] - 2026-03-12
### Added
- Tavily as alternative web search provider alongside Brave Search (#102)
- Auto-mode progress widget now shows all stats; footer hidden during auto-mode (#75)
### Fixed
- Auto-mode infinite loop and closeout instability — idempotent unit dispatch, retry caps, and atomic closeout (#96, #109)
- Migration no longer requires ROADMAP.md — milestones inferred from phases/ directory when missing (#93, #90)
- Worktree branch safety — proper namespacing and slice branch base selection (#92)
- Windows: use `execFile` to avoid single-quote shell issues (#103)
- Broken `read @GSD-WORKFLOW.md` references replaced with `/gsd` command (#88)
- Google Search extension updated to use `gemini-2.5-flash` (#83)
- Duplicate `getCurrentBranch` import in auto.ts (#87)
- `formatCost` crash on non-number cost values (#74)
- Avoid `sudo` prompts in postinstall script (#73)
- `.gsd/` folder removed from git tracking; consolidated `.gitignore` (#78)
- Multiple community-reported bugs across CLI, auto-mode, and extensions
## [2.3.8] - 2026-03-11
### Fixed
@ -156,7 +263,15 @@ Format based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
### Changed
- License updated to MIT
[Unreleased]: https://github.com/gsd-build/gsd-2/compare/v2.3.8...HEAD
[Unreleased]: https://github.com/gsd-build/gsd-2/compare/v2.7.0...HEAD
[2.7.0]: https://github.com/gsd-build/gsd-2/compare/v2.6.0...v2.7.0
[2.6.0]: https://github.com/gsd-build/gsd-2/compare/v2.5.1...v2.6.0
[2.5.1]: https://github.com/gsd-build/gsd-2/compare/v2.5.0...v2.5.1
[2.5.0]: https://github.com/gsd-build/gsd-2/compare/v2.4.0...v2.5.0
[2.4.0]: https://github.com/gsd-build/gsd-2/compare/v2.3.11...v2.4.0
[2.3.11]: https://github.com/gsd-build/gsd-2/compare/v2.3.10...v2.3.11
[2.3.10]: https://github.com/gsd-build/gsd-2/compare/v2.3.9...v2.3.10
[2.3.9]: https://github.com/gsd-build/gsd-2/compare/v2.3.8...v2.3.9
[2.3.8]: https://github.com/gsd-build/gsd-2/compare/v2.3.7...v2.3.8
[2.3.7]: https://github.com/gsd-build/gsd-2/compare/v2.3.6...v2.3.7
[2.3.6]: https://github.com/gsd-build/gsd-2/compare/v2.3.5...v2.3.6

213
ISSUE-120-INVESTIGATION.md Normal file
View file

@ -0,0 +1,213 @@
# Issue #120 — GSD Auto Secret Collection Improvements
## Problem Statement
Users report three failures in auto-mode secret handling:
1. **Late discovery** — Secrets aren't gathered until well into execution (e.g., first slice), blocking progress for hours while the user is away
2. **Re-asking across slices** — The same secrets are requested again at the start of later slices
3. **Re-asking within slices** — The same secrets are requested again mid-slice
All three stem from the same architectural gap: GSD has no proactive secret identification, and no reliable persistence of project-specific secrets across fresh sessions.
---
## Current Architecture
### Secret Collection Tool
`src/resources/extensions/get-secrets-from-user.ts`
- `secure_env_collect` — paged, masked-input TUI for collecting env vars
- Writes to three destinations: `.env` (local), Vercel (`vercel env add`), Convex (`npx convex env set`)
- Values are masked in UI and never echoed in tool output
- Well-built tool — the problem isn't collection UX, it's when and how often collection happens
### Secret Persistence (GSD-owned keys only)
`src/wizard.ts` — `loadStoredEnvKeys()`
Runs at CLI startup. Loads a hardcoded list of GSD's own keys from `~/.gsd/agent/auth.json` into `process.env`:
- `BRAVE_API_KEY`, `BRAVE_ANSWERS_KEY`
- `CONTEXT7_API_KEY`, `JINA_API_KEY`, `TAVILY_API_KEY`
- `SLACK_BOT_TOKEN`, `DISCORD_BOT_TOKEN`
**Project-specific secrets** (GitHub tokens, database URLs, OpenAI keys, etc.) collected via `secure_env_collect` to `.env` are NOT loaded by this mechanism.
### Fresh Session Model
`src/resources/extensions/gsd/auto.ts`
Each unit of work (plan slice, execute task, complete slice) gets a fresh session via `ctx.newSession()`. This means:
- Clean context window
- State rebuilt from `.gsd/` artifacts on disk
- No memory of what happened in the previous session
- `process.env` does not include project `.env` contents unless something explicitly loads them
### Prompt Guidance
| File | What it says about secrets |
|------|--------------------------|
| `system.md:26-27` | Never log secrets; use `secure_env_collect` instead of manual `.env` editing |
| `system.md:131` | Routes "Secrets" to `secure_env_collect` |
| `system.md:197` | After applying secrets, rerun the blocked workflow |
| `execute-task.md:30` | Never log secrets/tokens unnecessarily |
| `secure_env_collect` promptGuidelines | Proactively call before first command needing secrets; call when commands fail due to missing env vars |
All guidance is **reactive** — "when you hit an error, collect the secret." Nothing says "identify all secrets upfront before execution begins."
### What's Missing
| Gap | Impact |
|-----|--------|
| No secret identification during research/planning | Secrets discovered reactively during execution, often hours in |
| No `.env` loading across fresh sessions | Previously-collected project secrets invisible to new sessions |
| No "secrets already collected" carry-forward | Agent in fresh session doesn't know what was already gathered |
| No `Required Credentials` section in requirements | No structured place to track what the project needs |
| No deduplication or "already have this" check | Agent re-asks for secrets it already wrote to `.env` |
---
## Root Cause Analysis
### Problem 1: Late Discovery
The research phase (`research-milestone.md`) focuses on codebase exploration, technology assessment, and strategic questions. The planning phase (`plan-milestone.md`, `plan-slice.md`) focuses on task decomposition and verification. Neither phase includes a step to identify required credentials.
The `secure_env_collect` promptGuidelines say "when starting a new project or running setup steps that require secrets, proactively call secure_env_collect before the first command that needs them" — but this fires during task execution, not during planning. By then, the user may be asleep.
### Problem 2: Re-asking Across Slices
When `secure_env_collect` writes a secret to `.env`, that file persists on disk. But when auto-mode spawns a fresh session for the next slice, the new session's `process.env` doesn't include the `.env` contents. The agent in the new session encounters the same "missing env var" error and calls `secure_env_collect` again.
The `loadStoredEnvKeys()` function only loads GSD's own keys from AuthStorage, not project-specific keys from `.env`.
### Problem 3: Re-asking Within Slices
Within a single session, if `secure_env_collect` writes to `.env` but the calling code reads from `process.env` (not the file), the secret appears missing. Additionally, if a task uses a tool that checks `process.env` independently, it won't see the `.env` file contents unless something loads them.
---
## Proposed Solutions
### Solution 1: Proactive Secret Identification During Planning
**Where**: `src/resources/extensions/gsd/prompts/plan-milestone.md`
Add a step after research is consumed and before slice decomposition:
> Identify all secrets, API keys, tokens, credentials, and external service configurations this milestone will require. Consider:
> - APIs being integrated (keys, tokens, OAuth credentials)
> - Databases (connection strings, passwords)
> - Third-party services (webhook secrets, API keys)
> - Deployment targets (platform tokens)
>
> If any secrets are needed, call `secure_env_collect` now to gather them before execution begins. This prevents blocking during unattended execution.
**Also update**: `src/resources/extensions/gsd/templates/requirements.md` — add a `## Required Credentials` section:
```markdown
## Required Credentials
| Key | Purpose | Source | Status |
|-----|---------|--------|--------|
| GITHUB_TOKEN | GitHub API access | User | collected |
| DATABASE_URL | PostgreSQL connection | User | pending |
```
### Solution 2: Load Project `.env` on Fresh Session Start
**Where**: `src/resources/extensions/gsd/auto.ts` — before spawning each fresh session
Before `ctx.newSession()`, read the project's `.env` file and inject its contents into the session's environment. This ensures previously-collected secrets carry forward without re-asking.
Implementation approach:
```typescript
import { readFile } from "node:fs/promises";
import { resolve } from "node:path";
async function loadProjectEnv(cwd: string): Promise<void> {
try {
const envPath = resolve(cwd, ".env");
const content = await readFile(envPath, "utf8");
for (const line of content.split("\n")) {
const trimmed = line.trim();
if (!trimmed || trimmed.startsWith("#")) continue;
const eqIndex = trimmed.indexOf("=");
if (eqIndex === -1) continue;
const key = trimmed.slice(0, eqIndex).trim();
const value = trimmed.slice(eqIndex + 1).trim();
// Don't override explicitly-set env vars
if (!process.env[key]) {
process.env[key] = value;
}
}
} catch {
// No .env file — that's fine
}
}
```
Call this before each fresh session spawn in auto-mode.
**Alternative**: Persist project secrets to AuthStorage alongside GSD's own keys, so `loadStoredEnvKeys()` picks them up. This is cleaner but requires changes to `secure_env_collect` to write to both `.env` and AuthStorage.
### Solution 3: Carry-Forward Context for Collected Secrets
**Where**: `src/resources/extensions/gsd/auto.ts` — in the context/prompt assembly for fresh sessions
Add a section to the injected prompt that lists secrets already collected:
> ## Previously Collected Secrets
> The following env vars have already been collected and are available in `.env`:
> - `GITHUB_TOKEN`
> - `DATABASE_URL`
>
> Do NOT re-ask the user for these. If a command fails due to a missing env var not on this list, use `secure_env_collect`.
This requires scanning `.env` for key names (not values) and including them in the carry-forward context.
### Solution 4: Update Execute-Task Prompt
**Where**: `src/resources/extensions/gsd/prompts/execute-task.md`
Add an early step:
> Before starting work, check if the task requires env vars or secrets. If so, verify they exist in `.env` or `process.env`. If missing, call `secure_env_collect` immediately rather than discovering the need mid-task.
---
## Implementation Priority
| Priority | Solution | Effort | Impact |
|----------|----------|--------|--------|
| 1 | Solution 2: Load `.env` on fresh session start | Small | Eliminates re-asking (Problems 2 & 3) |
| 2 | Solution 3: Carry-forward collected secret names | Small | Prevents agent confusion about what's available |
| 3 | Solution 1: Proactive identification during planning | Medium | Eliminates late discovery (Problem 1) |
| 4 | Solution 4: Execute-task prompt update | Small | Defense-in-depth for Problem 1 |
Solutions 1-3 together fully address the issue. Solution 4 is defense-in-depth.
---
## Files to Modify
| File | Change |
|------|--------|
| `src/resources/extensions/gsd/auto.ts` | Load `.env` before fresh sessions; include collected secret names in carry-forward context |
| `src/resources/extensions/gsd/prompts/plan-milestone.md` | Add proactive secret identification step |
| `src/resources/extensions/gsd/prompts/execute-task.md` | Add early secret verification step |
| `src/resources/extensions/gsd/templates/requirements.md` | Add Required Credentials section |
| `src/resources/extensions/get-secrets-from-user.ts` | (Optional) Dual-write to AuthStorage for cross-project persistence |
---
## Edge Cases to Consider
- **Non-dotenv destinations**: If secrets were sent to Vercel or Convex, the `.env` loading approach won't help. May need to track "collected secrets" in a `.gsd/secrets-manifest.json` (key names only, no values).
- **Multiple `.env` files**: Some projects use `.env.local`, `.env.development`, etc. The loader should check common variants.
- **Secrets that change**: If a user needs to rotate a key, the "don't re-ask" logic should have an escape hatch.
- **Workspace vs global secrets**: Some secrets (like `GITHUB_TOKEN`) are user-global; others (like `DATABASE_URL`) are project-specific. Consider whether global secrets should go to AuthStorage while project secrets stay in `.env`.

View file

@ -48,6 +48,8 @@ GSD v2 solves all of these because it's not a prompt framework anymore — it's
### Migrating from v1
> **Note:** Migration works best with a `ROADMAP.md` file for milestone structure. Without one, milestones are inferred from the `phases/` directory.
If you have projects with `.planning` directories from the original Get Shit Done, you can migrate them to GSD-2's `.gsd` format:
```bash
@ -198,7 +200,7 @@ Both terminals read and write the same `.gsd/` files on disk. Your decisions in
### First launch
On first run, GSD prompts for optional API keys (Brave Search, Google Gemini, Context7, Jina) for web research and documentation tools. All optional — press Enter to skip any.
On first run, GSD launches a branded setup wizard that walks you through LLM provider selection (OAuth or API key), then optional tool API keys (Brave Search, Context7, Jina, Slack, Discord). Every step is skippable — press Enter to skip any. If you have an existing Pi installation, your provider credentials (LLM and tool keys) are imported automatically. Run `gsd config` anytime to re-run the wizard.
### Commands
@ -221,6 +223,8 @@ On first run, GSD prompts for optional API keys (Brave Search, Google Gemini, Co
| `Ctrl+Alt+G` | Toggle dashboard overlay |
| `Ctrl+Alt+V` | Toggle voice transcription |
| `Ctrl+Alt+B` | Show background shell processes |
| `gsd config` | Re-run the setup wizard (LLM provider + tool keys) |
| `gsd --continue` (`-c`) | Resume the most recent session for the current directory |
---
@ -249,17 +253,18 @@ Branch-per-slice with squash merge. Fully automated.
```
main:
feat(M001/S03): auth and session management
docs(M001/S04): workflow documentation and examples
fix(M001/S03): bug fixes and doc corrections
feat(M001/S02): API endpoints and middleware
feat(M001/S01): data model and type system
gsd/M001/S01 (preserved):
gsd/M001/S01 (deleted after merge):
feat(S01/T03): file writer with round-trip fidelity
feat(S01/T02): markdown parser for plan files
feat(S01/T01): core types and interfaces
```
One commit per slice on main. Per-task history preserved on branches. Git bisect works. Individual slices are revertable.
One commit per slice on main. Squash commits are the permanent record — branches are deleted after merge. Git bisect works. Individual slices are revertable.
### Verification
@ -326,7 +331,7 @@ GSD ships with 13 extensions, all loaded automatically:
|-----------|-----------------|
| **GSD** | Core workflow engine, auto mode, commands, dashboard |
| **Browser Tools** | Playwright-based browser for UI verification |
| **Search the Web** | Brave Search + Jina page extraction |
| **Search the Web** | Brave Search, Tavily, or Jina page extraction |
| **Google Search** | Gemini-powered web search with AI-synthesized answers |
| **Context7** | Up-to-date library/framework documentation |
| **Background Shell** | Long-running process management with readiness detection |
@ -358,7 +363,8 @@ GSD is a TypeScript application that embeds the Pi coding agent SDK.
gsd (CLI binary)
└─ loader.ts Sets PI_PACKAGE_DIR, GSD env vars, dynamic-imports cli.ts
└─ cli.ts Wires SDK managers, loads extensions, starts InteractiveMode
├─ wizard.ts First-run API key collection (Brave/Gemini/Context7/Jina)
├─ onboarding.ts First-run setup wizard (LLM provider + tool keys)
├─ wizard.ts Env hydration from stored auth.json credentials
├─ app-paths.ts ~/.gsd/agent/, ~/.gsd/sessions/, auth.json
├─ resource-loader.ts Syncs bundled extensions + agents to ~/.gsd/agent/
└─ src/resources/
@ -386,6 +392,7 @@ gsd (CLI binary)
Optional:
- Brave Search API key (web research)
- Tavily API key (web research — alternative to Brave)
- Google Gemini API key (web research via Gemini Search grounding)
- Context7 API key (library docs)
- Jina API key (page extraction)

1435
package-lock.json generated

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,6 @@
{
"name": "gsd-pi",
"version": "2.3.8",
"version": "2.7.0",
"description": "GSD — Get Shit Done coding agent",
"license": "MIT",
"repository": {
@ -12,13 +12,16 @@
"url": "https://github.com/glittercowboy/gsd-pi/issues"
},
"type": "module",
"workspaces": [
"packages/*"
],
"bin": {
"gsd": "dist/loader.js",
"gsd-cli": "dist/loader.js"
},
"files": [
"dist",
"patches",
"packages",
"pkg",
"src/resources",
"scripts/postinstall.js",
@ -33,9 +36,14 @@
"node": ">=20.6.0"
},
"scripts": {
"build": "tsc && npm run copy-themes",
"copy-themes": "node -e \"const{mkdirSync,cpSync}=require('fs');const{resolve}=require('path');const src=resolve(__dirname,'node_modules/@mariozechner/pi-coding-agent/dist/modes/interactive/theme');mkdirSync('pkg/dist/modes/interactive/theme',{recursive:true});cpSync(src,'pkg/dist/modes/interactive/theme',{recursive:true})\"",
"test": "node --import ./src/resources/extensions/gsd/tests/resolve-ts.mjs --experimental-strip-types --test 'src/resources/extensions/gsd/tests/*.test.ts' 'src/resources/extensions/gsd/tests/*.test.mjs' 'src/tests/*.test.ts'",
"build:pi-tui": "npm run build -w @gsd/pi-tui",
"build:pi-ai": "npm run build -w @gsd/pi-ai",
"build:pi-agent-core": "npm run build -w @gsd/pi-agent-core",
"build:pi-coding-agent": "npm run build -w @gsd/pi-coding-agent",
"build:pi": "npm run build:pi-tui && npm run build:pi-ai && npm run build:pi-agent-core && npm run build:pi-coding-agent",
"build": "npm run build:pi && tsc && npm run copy-themes",
"copy-themes": "node -e \"const{mkdirSync,cpSync}=require('fs');const{resolve}=require('path');const src=resolve(__dirname,'packages/pi-coding-agent/dist/modes/interactive/theme');mkdirSync('pkg/dist/modes/interactive/theme',{recursive:true});cpSync(src,'pkg/dist/modes/interactive/theme',{recursive:true})\"",
"test": "node --import ./src/resources/extensions/gsd/tests/resolve-ts.mjs --experimental-strip-types --test src/resources/extensions/gsd/tests/*.test.ts src/resources/extensions/gsd/tests/*.test.mjs src/tests/*.test.ts",
"dev": "tsc --watch",
"postinstall": "node scripts/postinstall.js",
"pi:install-global": "node scripts/install-pi-global.js",
@ -44,12 +52,12 @@
"prepublishOnly": "npm run sync-pkg-version && npm run build"
},
"dependencies": {
"@mariozechner/pi-coding-agent": "^0.57.1",
"@clack/prompts": "^1.1.0",
"picocolors": "^1.1.1",
"playwright": "^1.58.2"
},
"devDependencies": {
"@types/node": "^22.0.0",
"patch-package": "^8.0.1",
"typescript": "^5.4.0"
},
"overrides": {

View file

@ -0,0 +1,14 @@
{
"name": "@gsd/pi-agent-core",
"version": "0.57.1",
"description": "General-purpose agent core (vendored from pi-mono)",
"type": "module",
"main": "./dist/index.js",
"types": "./dist/index.d.ts",
"scripts": {
"build": "tsc -p tsconfig.json"
},
"dependencies": {
"@gsd/pi-ai": "*"
}
}

View file

@ -0,0 +1,417 @@
/**
* Agent loop that works with AgentMessage throughout.
* Transforms to Message[] only at the LLM call boundary.
*/
import {
type AssistantMessage,
type Context,
EventStream,
streamSimple,
type ToolResultMessage,
validateToolArguments,
} from "@gsd/pi-ai";
import type {
AgentContext,
AgentEvent,
AgentLoopConfig,
AgentMessage,
AgentTool,
AgentToolResult,
StreamFn,
} from "./types.js";
/**
 * Start an agent loop with new prompt message(s).
 * The prompts are added to the context and message events are emitted for each
 * before the first assistant response is streamed.
 *
 * @param prompts  Messages to inject as the opening turn.
 * @param context  Agent context (system prompt, messages, tools). Not mutated:
 *                 a shallow copy with the prompts appended is used internally.
 * @param config   Loop configuration (model, convertToLlm, steering hooks, …).
 * @param signal   Optional abort signal forwarded to the LLM stream and tools.
 * @param streamFn Optional LLM stream function override (defaults to streamSimple).
 * @returns Event stream whose final result is the messages produced by this call.
 */
export function agentLoop(
  prompts: AgentMessage[],
  context: AgentContext,
  config: AgentLoopConfig,
  signal?: AbortSignal,
  streamFn?: StreamFn,
): EventStream<AgentEvent, AgentMessage[]> {
  const stream = createAgentStream();
  // Hoisted out of the IIFE so the rejection handler below can still report
  // whatever messages were accumulated before a failure.
  const newMessages: AgentMessage[] = [...prompts];
  (async () => {
    const currentContext: AgentContext = {
      ...context,
      messages: [...context.messages, ...prompts],
    };
    stream.push({ type: "agent_start" });
    stream.push({ type: "turn_start" });
    for (const prompt of prompts) {
      stream.push({ type: "message_start", message: prompt });
      stream.push({ type: "message_end", message: prompt });
    }
    await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
  })().catch(() => {
    // Fix: previously this fire-and-forget IIFE had no rejection handler. If
    // runLoop threw (e.g. a hook or stream setup failing synchronously), Node
    // raised an unhandled rejection AND the stream never ended, hanging any
    // consumer awaiting its result. Terminate the stream with what we have.
    // NOTE(review): assumes EventStream tolerates end() not having been called
    // yet on this path — runLoop only ends the stream on normal completion.
    stream.push({ type: "agent_end", messages: newMessages });
    stream.end(newMessages);
  });
  return stream;
}
/**
 * Continue an agent loop from the current context without adding a new message.
 * Used for retries - context already has user message or tool results.
 *
 * **Important:** The last message in context must convert to a `user` or `toolResult` message
 * via `convertToLlm`. If it doesn't, the LLM provider will reject the request.
 * This cannot be validated here since `convertToLlm` is only called once per turn.
 *
 * @throws Error if the context has no messages, or if its last message is an
 *         assistant message (nothing to respond to).
 * @returns Event stream whose final result is the messages produced by this call.
 */
export function agentLoopContinue(
  context: AgentContext,
  config: AgentLoopConfig,
  signal?: AbortSignal,
  streamFn?: StreamFn,
): EventStream<AgentEvent, AgentMessage[]> {
  if (context.messages.length === 0) {
    throw new Error("Cannot continue: no messages in context");
  }
  if (context.messages[context.messages.length - 1].role === "assistant") {
    throw new Error("Cannot continue from message role: assistant");
  }
  const stream = createAgentStream();
  // Hoisted out of the IIFE so the rejection handler below can still report
  // whatever messages were accumulated before a failure.
  const newMessages: AgentMessage[] = [];
  (async () => {
    const currentContext: AgentContext = { ...context };
    stream.push({ type: "agent_start" });
    stream.push({ type: "turn_start" });
    await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
  })().catch(() => {
    // Fix: previously this fire-and-forget IIFE had no rejection handler. If
    // runLoop threw, Node raised an unhandled rejection AND the stream never
    // ended, hanging any consumer awaiting its result. Terminate the stream.
    // NOTE(review): assumes EventStream tolerates end() not having been called
    // yet on this path — runLoop only ends the stream on normal completion.
    stream.push({ type: "agent_end", messages: newMessages });
    stream.end(newMessages);
  });
  return stream;
}
/**
 * Build the EventStream shared by both loop entry points.
 * The stream completes when an `agent_end` event arrives, at which point that
 * event's messages become the stream's final result.
 */
function createAgentStream(): EventStream<AgentEvent, AgentMessage[]> {
  const isTerminal = (event: AgentEvent) => event.type === "agent_end";
  const finalResult = (event: AgentEvent) =>
    event.type === "agent_end" ? event.messages : [];
  return new EventStream<AgentEvent, AgentMessage[]>(isTerminal, finalResult);
}
/**
 * Main loop logic shared by agentLoop and agentLoopContinue.
 *
 * Drives turns until the assistant produces no tool calls and no queued
 * steering or follow-up messages remain. Mutates `currentContext.messages`
 * in place and appends every message produced to `newMessages`, which is
 * also the stream's final result.
 *
 * @param currentContext Working context; its messages array is mutated in place.
 * @param newMessages    Accumulator for messages created during this run.
 * @param config         Loop configuration (steering/follow-up hooks, convertToLlm, model, …).
 * @param signal         Optional abort signal forwarded to LLM and tool calls.
 * @param stream         Event stream receiving all lifecycle events.
 * @param streamFn       Optional LLM stream function override.
 */
async function runLoop(
  currentContext: AgentContext,
  newMessages: AgentMessage[],
  config: AgentLoopConfig,
  signal: AbortSignal | undefined,
  stream: EventStream<AgentEvent, AgentMessage[]>,
  streamFn?: StreamFn,
): Promise<void> {
  // Callers have already pushed "turn_start" for the first turn, so the inner
  // loop must skip emitting it on its first iteration.
  let firstTurn = true;
  // Check for steering messages at start (user may have typed while waiting)
  let pendingMessages: AgentMessage[] = (await config.getSteeringMessages?.()) || [];
  // Outer loop: continues when queued follow-up messages arrive after agent would stop
  while (true) {
    let hasMoreToolCalls = true;
    let steeringAfterTools: AgentMessage[] | null = null;
    // Inner loop: process tool calls and steering messages
    while (hasMoreToolCalls || pendingMessages.length > 0) {
      if (!firstTurn) {
        stream.push({ type: "turn_start" });
      } else {
        firstTurn = false;
      }
      // Process pending messages (inject before next assistant response)
      if (pendingMessages.length > 0) {
        for (const message of pendingMessages) {
          stream.push({ type: "message_start", message });
          stream.push({ type: "message_end", message });
          currentContext.messages.push(message);
          newMessages.push(message);
        }
        pendingMessages = [];
      }
      // Stream assistant response
      const message = await streamAssistantResponse(currentContext, config, signal, stream, streamFn);
      newMessages.push(message);
      // Errors and aborts terminate the whole loop immediately: no tool
      // execution, no follow-up check; the stream ends with everything so far.
      if (message.stopReason === "error" || message.stopReason === "aborted") {
        stream.push({ type: "turn_end", message, toolResults: [] });
        stream.push({ type: "agent_end", messages: newMessages });
        stream.end(newMessages);
        return;
      }
      // Check for tool calls
      const toolCalls = message.content.filter((c) => c.type === "toolCall");
      hasMoreToolCalls = toolCalls.length > 0;
      const toolResults: ToolResultMessage[] = [];
      if (hasMoreToolCalls) {
        const toolExecution = await executeToolCalls(
          currentContext.tools,
          message,
          signal,
          stream,
          config.getSteeringMessages,
        );
        toolResults.push(...toolExecution.toolResults);
        // Steering observed during tool execution is held until the turn_end
        // event below has been emitted.
        steeringAfterTools = toolExecution.steeringMessages ?? null;
        for (const result of toolResults) {
          currentContext.messages.push(result);
          newMessages.push(result);
        }
      }
      stream.push({ type: "turn_end", message, toolResults });
      // Get steering messages after turn completes: steering captured during
      // tool execution takes priority; otherwise poll the hook once more.
      if (steeringAfterTools && steeringAfterTools.length > 0) {
        pendingMessages = steeringAfterTools;
        steeringAfterTools = null;
      } else {
        pendingMessages = (await config.getSteeringMessages?.()) || [];
      }
    }
    // Agent would stop here. Check for follow-up messages.
    const followUpMessages = (await config.getFollowUpMessages?.()) || [];
    if (followUpMessages.length > 0) {
      // Set as pending so inner loop processes them
      pendingMessages = followUpMessages;
      continue;
    }
    // No more messages, exit
    break;
  }
  stream.push({ type: "agent_end", messages: newMessages });
  stream.end(newMessages);
}
/**
 * Stream an assistant response from the LLM.
 * This is where AgentMessage[] gets transformed to Message[] for the LLM.
 *
 * The in-flight partial assistant message is pushed into `context.messages`
 * as soon as streaming starts and is then replaced in place as deltas arrive,
 * so observers of the context always see the latest partial state.
 *
 * @returns The final assistant message (also stored as the last context message).
 */
async function streamAssistantResponse(
  context: AgentContext,
  config: AgentLoopConfig,
  signal: AbortSignal | undefined,
  stream: EventStream<AgentEvent, AgentMessage[]>,
  streamFn?: StreamFn,
): Promise<AssistantMessage> {
  // Apply context transform if configured (AgentMessage[] → AgentMessage[])
  let messages = context.messages;
  if (config.transformContext) {
    messages = await config.transformContext(messages, signal);
  }
  // Convert to LLM-compatible messages (AgentMessage[] → Message[])
  const llmMessages = await config.convertToLlm(messages);
  // Build LLM context
  const llmContext: Context = {
    systemPrompt: context.systemPrompt,
    messages: llmMessages,
    tools: context.tools,
  };
  const streamFunction = streamFn || streamSimple;
  // Resolve API key (important for expiring tokens); falls back to the static
  // config.apiKey when no resolver is configured or it yields nothing.
  const resolvedApiKey =
    (config.getApiKey ? await config.getApiKey(config.model.provider) : undefined) || config.apiKey;
  const response = await streamFunction(config.model, llmContext, {
    ...config,
    apiKey: resolvedApiKey,
    signal,
  });
  // Tracks whether the partial was already appended to the context, so the
  // final message replaces it in place instead of being appended twice.
  let partialMessage: AssistantMessage | null = null;
  let addedPartial = false;
  for await (const event of response) {
    switch (event.type) {
      case "start":
        partialMessage = event.partial;
        context.messages.push(partialMessage);
        addedPartial = true;
        stream.push({ type: "message_start", message: { ...partialMessage } });
        break;
      case "text_start":
      case "text_delta":
      case "text_end":
      case "thinking_start":
      case "thinking_delta":
      case "thinking_end":
      case "toolcall_start":
      case "toolcall_delta":
      case "toolcall_end":
        if (partialMessage) {
          partialMessage = event.partial;
          // Replace (not append) the last context entry with the newer partial.
          context.messages[context.messages.length - 1] = partialMessage;
          stream.push({
            type: "message_update",
            assistantMessageEvent: event,
            message: { ...partialMessage },
          });
        }
        break;
      case "done":
      case "error": {
        const finalMessage = await response.result();
        if (addedPartial) {
          context.messages[context.messages.length - 1] = finalMessage;
        } else {
          context.messages.push(finalMessage);
        }
        // If no "start" event ever arrived (e.g. an immediate error), no
        // message_start was emitted yet — emit it before message_end.
        if (!addedPartial) {
          stream.push({ type: "message_start", message: { ...finalMessage } });
        }
        stream.push({ type: "message_end", message: finalMessage });
        return finalMessage;
      }
    }
  }
  // Defensive: stream ended without an explicit done/error event.
  return await response.result();
}
/**
 * Execute every tool call found on an assistant message, in order.
 *
 * For each call the stream receives `tool_execution_start`, optional
 * `tool_execution_update` events (streamed by the tool itself), and
 * `tool_execution_end`, followed by `message_start`/`message_end` for the
 * resulting toolResult message. Unknown tools and thrown errors become
 * error-flagged tool results rather than aborting the batch.
 *
 * After each tool finishes, `getSteeringMessages` (if provided) is polled;
 * when the user has queued a steering message, all remaining calls are
 * skipped via `skipToolCall` and the steering messages are returned.
 */
async function executeToolCalls(
	tools: AgentTool<any>[] | undefined,
	assistantMessage: AssistantMessage,
	signal: AbortSignal | undefined,
	stream: EventStream<AgentEvent, AgentMessage[]>,
	getSteeringMessages?: AgentLoopConfig["getSteeringMessages"],
): Promise<{ toolResults: ToolResultMessage[]; steeringMessages?: AgentMessage[] }> {
	const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall");
	const toolResults: ToolResultMessage[] = [];

	for (const [index, call] of toolCalls.entries()) {
		const tool = tools?.find((t) => t.name === call.name);

		stream.push({
			type: "tool_execution_start",
			toolCallId: call.id,
			toolName: call.name,
			args: call.arguments,
		});

		let outcome: AgentToolResult<any>;
		let failed = false;
		try {
			if (!tool) throw new Error(`Tool ${call.name} not found`);
			const validatedArgs = validateToolArguments(tool, call);
			outcome = await tool.execute(call.id, validatedArgs, signal, (partialResult) => {
				// Forward streaming progress from the tool to listeners.
				stream.push({
					type: "tool_execution_update",
					toolCallId: call.id,
					toolName: call.name,
					args: call.arguments,
					partialResult,
				});
			});
		} catch (e) {
			// Convert any failure (missing tool, validation, execution) into an error result.
			failed = true;
			outcome = {
				content: [{ type: "text", text: e instanceof Error ? e.message : String(e) }],
				details: {},
			};
		}

		stream.push({
			type: "tool_execution_end",
			toolCallId: call.id,
			toolName: call.name,
			result: outcome,
			isError: failed,
		});

		const resultMessage: ToolResultMessage = {
			role: "toolResult",
			toolCallId: call.id,
			toolName: call.name,
			content: outcome.content,
			details: outcome.details,
			isError: failed,
			timestamp: Date.now(),
		};
		toolResults.push(resultMessage);
		stream.push({ type: "message_start", message: resultMessage });
		stream.push({ type: "message_end", message: resultMessage });

		// Poll for user interruptions; a queued steering message skips the rest of the batch.
		if (getSteeringMessages) {
			const steering = await getSteeringMessages();
			if (steering.length > 0) {
				for (const skipped of toolCalls.slice(index + 1)) {
					toolResults.push(skipToolCall(skipped, stream));
				}
				return { toolResults, steeringMessages: steering };
			}
		}
	}
	return { toolResults, steeringMessages: undefined };
}
/**
 * Mark a tool call as skipped (because a steering message interrupted the
 * batch) by emitting the full start/end event pair plus a synthetic,
 * error-flagged toolResult message, so consumers see a consistent lifecycle
 * for every call the model issued.
 */
function skipToolCall(
	toolCall: Extract<AssistantMessage["content"][number], { type: "toolCall" }>,
	stream: EventStream<AgentEvent, AgentMessage[]>,
): ToolResultMessage {
	// Synthetic error result explaining why this call never ran.
	const skippedResult: AgentToolResult<any> = {
		content: [{ type: "text", text: "Skipped due to queued user message." }],
		details: {},
	};
	stream.push({
		type: "tool_execution_start",
		toolCallId: toolCall.id,
		toolName: toolCall.name,
		args: toolCall.arguments,
	});
	stream.push({
		type: "tool_execution_end",
		toolCallId: toolCall.id,
		toolName: toolCall.name,
		result: skippedResult,
		isError: true,
	});
	const skippedMessage: ToolResultMessage = {
		role: "toolResult",
		toolCallId: toolCall.id,
		toolName: toolCall.name,
		content: skippedResult.content,
		details: {},
		isError: true,
		timestamp: Date.now(),
	};
	stream.push({ type: "message_start", message: skippedMessage });
	stream.push({ type: "message_end", message: skippedMessage });
	return skippedMessage;
}

View file

@ -0,0 +1,568 @@
/**
* Agent class that uses the agent-loop directly.
* No transport abstraction - calls streamSimple via the loop.
*/
import {
getModel,
type ImageContent,
type Message,
type Model,
type SimpleStreamOptions,
streamSimple,
type TextContent,
type ThinkingBudgets,
type Transport,
} from "@gsd/pi-ai";
import { agentLoop, agentLoopContinue } from "./agent-loop.js";
import type {
AgentContext,
AgentEvent,
AgentLoopConfig,
AgentMessage,
AgentState,
AgentTool,
StreamFn,
ThinkingLevel,
} from "./types.js";
/**
 * Default convertToLlm: keep only messages the LLM protocol understands
 * (user / assistant / toolResult); custom app message types are dropped.
 */
function defaultConvertToLlm(messages: AgentMessage[]): Message[] {
	const llmRoles: ReadonlySet<string> = new Set(["user", "assistant", "toolResult"]);
	return messages.filter((m): m is Message => llmRoles.has(m.role));
}
export interface AgentOptions {
	/** Partial state merged over the Agent's built-in defaults at construction. */
	initialState?: Partial<AgentState>;
	/**
	 * Converts AgentMessage[] to LLM-compatible Message[] before each LLM call.
	 * Default filters to user/assistant/toolResult. (NOTE(review): the built-in
	 * default only filters; attachment conversion is not visible here — confirm.)
	 */
	convertToLlm?: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
	/**
	 * Optional transform applied to context before convertToLlm.
	 * Use for context pruning, injecting external context, etc.
	 */
	transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
	/**
	 * Steering mode: "all" = send all steering messages at once, "one-at-a-time" = one per turn
	 */
	steeringMode?: "all" | "one-at-a-time";
	/**
	 * Follow-up mode: "all" = send all follow-up messages at once, "one-at-a-time" = one per turn
	 */
	followUpMode?: "all" | "one-at-a-time";
	/**
	 * Custom stream function (for proxy backends, etc.). Default uses streamSimple.
	 */
	streamFn?: StreamFn;
	/**
	 * Optional session identifier forwarded to LLM providers.
	 * Used by providers that support session-based caching (e.g., OpenAI Codex).
	 */
	sessionId?: string;
	/**
	 * Resolves an API key dynamically for each LLM call.
	 * Useful for expiring tokens (e.g., GitHub Copilot OAuth).
	 */
	getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
	/**
	 * Inspect or replace provider payloads before they are sent.
	 */
	onPayload?: SimpleStreamOptions["onPayload"];
	/**
	 * Custom token budgets for thinking levels (token-based providers only).
	 */
	thinkingBudgets?: ThinkingBudgets;
	/**
	 * Preferred transport for providers that support multiple transports.
	 */
	transport?: Transport;
	/**
	 * Maximum delay in milliseconds to wait for a retry when the server requests a long wait.
	 * If the server's requested delay exceeds this value, the request fails immediately,
	 * allowing higher-level retry logic to handle it with user visibility.
	 * Default: 60000 (60 seconds). Set to 0 to disable the cap.
	 */
	maxRetryDelayMs?: number;
}
/**
 * Stateful agent that drives the agent loop directly via streamSimple (or a
 * custom streamFn). Holds the conversation state, forwards loop events to
 * subscribers, and manages two message queues:
 * - steering: injected mid-run, skipping remaining tool calls
 * - follow-up: delivered only once the agent would otherwise stop
 *
 * Only one prompt may run at a time; use steer()/followUp() to queue input
 * while streaming, and waitForIdle() to await completion.
 */
export class Agent {
	private _state: AgentState = {
		systemPrompt: "",
		model: getModel("google", "gemini-2.5-flash-lite-preview-06-17"),
		thinkingLevel: "off",
		tools: [],
		messages: [],
		isStreaming: false,
		streamMessage: null,
		pendingToolCalls: new Set<string>(),
		error: undefined,
	};
	private listeners = new Set<(e: AgentEvent) => void>();
	private abortController?: AbortController;
	private convertToLlm: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
	private transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
	private steeringQueue: AgentMessage[] = [];
	private followUpQueue: AgentMessage[] = [];
	private steeringMode: "all" | "one-at-a-time";
	private followUpMode: "all" | "one-at-a-time";
	public streamFn: StreamFn;
	private _sessionId?: string;
	public getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
	private _onPayload?: SimpleStreamOptions["onPayload"];
	// Resolved when the current _runLoop finishes; backs waitForIdle().
	private runningPrompt?: Promise<void>;
	private resolveRunningPrompt?: () => void;
	private _thinkingBudgets?: ThinkingBudgets;
	private _transport: Transport;
	private _maxRetryDelayMs?: number;
	constructor(opts: AgentOptions = {}) {
		this._state = { ...this._state, ...opts.initialState };
		this.convertToLlm = opts.convertToLlm || defaultConvertToLlm;
		this.transformContext = opts.transformContext;
		this.steeringMode = opts.steeringMode || "one-at-a-time";
		this.followUpMode = opts.followUpMode || "one-at-a-time";
		this.streamFn = opts.streamFn || streamSimple;
		this._sessionId = opts.sessionId;
		this.getApiKey = opts.getApiKey;
		this._onPayload = opts.onPayload;
		this._thinkingBudgets = opts.thinkingBudgets;
		this._transport = opts.transport ?? "sse";
		this._maxRetryDelayMs = opts.maxRetryDelayMs;
	}
	/**
	 * Get the current session ID used for provider caching.
	 */
	get sessionId(): string | undefined {
		return this._sessionId;
	}
	/**
	 * Set the session ID for provider caching.
	 * Call this when switching sessions (new session, branch, resume).
	 */
	set sessionId(value: string | undefined) {
		this._sessionId = value;
	}
	/**
	 * Get the current thinking budgets.
	 */
	get thinkingBudgets(): ThinkingBudgets | undefined {
		return this._thinkingBudgets;
	}
	/**
	 * Set custom thinking budgets for token-based providers.
	 */
	set thinkingBudgets(value: ThinkingBudgets | undefined) {
		this._thinkingBudgets = value;
	}
	/**
	 * Get the current preferred transport.
	 */
	get transport(): Transport {
		return this._transport;
	}
	/**
	 * Set the preferred transport.
	 * NOTE(review): a method rather than a `set transport` accessor — confirm
	 * whether callers rely on this asymmetry before unifying.
	 */
	setTransport(value: Transport) {
		this._transport = value;
	}
	/**
	 * Get the current max retry delay in milliseconds.
	 */
	get maxRetryDelayMs(): number | undefined {
		return this._maxRetryDelayMs;
	}
	/**
	 * Set the maximum delay to wait for server-requested retries.
	 * Set to 0 to disable the cap.
	 */
	set maxRetryDelayMs(value: number | undefined) {
		this._maxRetryDelayMs = value;
	}
	// NOTE: returns the live internal state object, not a copy.
	get state(): AgentState {
		return this._state;
	}
	/** Register an event listener; returns an unsubscribe function. */
	subscribe(fn: (e: AgentEvent) => void): () => void {
		this.listeners.add(fn);
		return () => this.listeners.delete(fn);
	}
	// State mutators
	setSystemPrompt(v: string) {
		this._state.systemPrompt = v;
	}
	setModel(m: Model<any>) {
		this._state.model = m;
	}
	setThinkingLevel(l: ThinkingLevel) {
		this._state.thinkingLevel = l;
	}
	setSteeringMode(mode: "all" | "one-at-a-time") {
		this.steeringMode = mode;
	}
	getSteeringMode(): "all" | "one-at-a-time" {
		return this.steeringMode;
	}
	setFollowUpMode(mode: "all" | "one-at-a-time") {
		this.followUpMode = mode;
	}
	getFollowUpMode(): "all" | "one-at-a-time" {
		return this.followUpMode;
	}
	setTools(t: AgentTool<any>[]) {
		this._state.tools = t;
	}
	replaceMessages(ms: AgentMessage[]) {
		this._state.messages = ms.slice();
	}
	appendMessage(m: AgentMessage) {
		this._state.messages = [...this._state.messages, m];
	}
	/**
	 * Queue a steering message to interrupt the agent mid-run.
	 * Delivered after current tool execution, skips remaining tools.
	 */
	steer(m: AgentMessage) {
		this.steeringQueue.push(m);
	}
	/**
	 * Queue a follow-up message to be processed after the agent finishes.
	 * Delivered only when agent has no more tool calls or steering messages.
	 */
	followUp(m: AgentMessage) {
		this.followUpQueue.push(m);
	}
	clearSteeringQueue() {
		this.steeringQueue = [];
	}
	clearFollowUpQueue() {
		this.followUpQueue = [];
	}
	clearAllQueues() {
		this.steeringQueue = [];
		this.followUpQueue = [];
	}
	hasQueuedMessages(): boolean {
		return this.steeringQueue.length > 0 || this.followUpQueue.length > 0;
	}
	// Pops either the first steering message ("one-at-a-time") or the whole queue ("all").
	private dequeueSteeringMessages(): AgentMessage[] {
		if (this.steeringMode === "one-at-a-time") {
			if (this.steeringQueue.length > 0) {
				const first = this.steeringQueue[0];
				this.steeringQueue = this.steeringQueue.slice(1);
				return [first];
			}
			return [];
		}
		const steering = this.steeringQueue.slice();
		this.steeringQueue = [];
		return steering;
	}
	// Same dequeue policy as steering, applied to the follow-up queue.
	private dequeueFollowUpMessages(): AgentMessage[] {
		if (this.followUpMode === "one-at-a-time") {
			if (this.followUpQueue.length > 0) {
				const first = this.followUpQueue[0];
				this.followUpQueue = this.followUpQueue.slice(1);
				return [first];
			}
			return [];
		}
		const followUp = this.followUpQueue.slice();
		this.followUpQueue = [];
		return followUp;
	}
	clearMessages() {
		this._state.messages = [];
	}
	/** Abort the in-flight run, if any. */
	abort() {
		this.abortController?.abort();
	}
	/** Resolves when the current run (if any) has fully finished. */
	waitForIdle(): Promise<void> {
		return this.runningPrompt ?? Promise.resolve();
	}
	/** Clear conversation state and queues. Does not abort an in-flight run. */
	reset() {
		this._state.messages = [];
		this._state.isStreaming = false;
		this._state.streamMessage = null;
		this._state.pendingToolCalls = new Set<string>();
		this._state.error = undefined;
		this.steeringQueue = [];
		this.followUpQueue = [];
	}
	/** Send a prompt with an AgentMessage */
	async prompt(message: AgentMessage | AgentMessage[]): Promise<void>;
	async prompt(input: string, images?: ImageContent[]): Promise<void>;
	async prompt(input: string | AgentMessage | AgentMessage[], images?: ImageContent[]) {
		if (this._state.isStreaming) {
			throw new Error(
				"Agent is already processing a prompt. Use steer() or followUp() to queue messages, or wait for completion.",
			);
		}
		const model = this._state.model;
		if (!model) throw new Error("No model configured");
		let msgs: AgentMessage[];
		if (Array.isArray(input)) {
			msgs = input;
		} else if (typeof input === "string") {
			// Wrap the plain string (plus optional images) into a single user message.
			const content: Array<TextContent | ImageContent> = [{ type: "text", text: input }];
			if (images && images.length > 0) {
				content.push(...images);
			}
			msgs = [
				{
					role: "user",
					content,
					timestamp: Date.now(),
				},
			];
		} else {
			msgs = [input];
		}
		await this._runLoop(msgs);
	}
	/**
	 * Continue from current context (used for retries and resuming queued messages).
	 */
	async continue() {
		if (this._state.isStreaming) {
			throw new Error("Agent is already processing. Wait for completion before continuing.");
		}
		const messages = this._state.messages;
		if (messages.length === 0) {
			throw new Error("No messages to continue from");
		}
		if (messages[messages.length - 1].role === "assistant") {
			// After a completed assistant turn, resume with any queued input;
			// steering takes priority and skips the loop's first steering poll
			// (the messages are already being delivered here).
			const queuedSteering = this.dequeueSteeringMessages();
			if (queuedSteering.length > 0) {
				await this._runLoop(queuedSteering, { skipInitialSteeringPoll: true });
				return;
			}
			const queuedFollowUp = this.dequeueFollowUpMessages();
			if (queuedFollowUp.length > 0) {
				await this._runLoop(queuedFollowUp);
				return;
			}
			throw new Error("Cannot continue from message role: assistant");
		}
		await this._runLoop(undefined);
	}
	/**
	 * Run the agent loop.
	 * If messages are provided, starts a new conversation turn with those messages.
	 * Otherwise, continues from existing context.
	 */
	private async _runLoop(messages?: AgentMessage[], options?: { skipInitialSteeringPoll?: boolean }) {
		const model = this._state.model;
		if (!model) throw new Error("No model configured");
		this.runningPrompt = new Promise<void>((resolve) => {
			this.resolveRunningPrompt = resolve;
		});
		this.abortController = new AbortController();
		this._state.isStreaming = true;
		this._state.streamMessage = null;
		this._state.error = undefined;
		const reasoning = this._state.thinkingLevel === "off" ? undefined : this._state.thinkingLevel;
		// Snapshot messages so the loop works on a copy; state is updated via events.
		const context: AgentContext = {
			systemPrompt: this._state.systemPrompt,
			messages: this._state.messages.slice(),
			tools: this._state.tools,
		};
		let skipInitialSteeringPoll = options?.skipInitialSteeringPoll === true;
		const config: AgentLoopConfig = {
			model,
			reasoning,
			sessionId: this._sessionId,
			onPayload: this._onPayload,
			transport: this._transport,
			thinkingBudgets: this._thinkingBudgets,
			maxRetryDelayMs: this._maxRetryDelayMs,
			convertToLlm: this.convertToLlm,
			transformContext: this.transformContext,
			getApiKey: this.getApiKey,
			getSteeringMessages: async () => {
				// One-shot skip: used when continue() already dequeued steering messages.
				if (skipInitialSteeringPoll) {
					skipInitialSteeringPoll = false;
					return [];
				}
				return this.dequeueSteeringMessages();
			},
			getFollowUpMessages: async () => this.dequeueFollowUpMessages(),
		};
		let partial: AgentMessage | null = null;
		try {
			const stream = messages
				? agentLoop(messages, context, config, this.abortController.signal, this.streamFn)
				: agentLoopContinue(context, config, this.abortController.signal, this.streamFn);
			for await (const event of stream) {
				// Update internal state based on events
				switch (event.type) {
					case "message_start":
						partial = event.message;
						this._state.streamMessage = event.message;
						break;
					case "message_update":
						partial = event.message;
						this._state.streamMessage = event.message;
						break;
					case "message_end":
						partial = null;
						this._state.streamMessage = null;
						this.appendMessage(event.message);
						break;
					case "tool_execution_start": {
						// Copy-on-write so consumers holding the old Set are not mutated under them.
						const s = new Set(this._state.pendingToolCalls);
						s.add(event.toolCallId);
						this._state.pendingToolCalls = s;
						break;
					}
					case "tool_execution_end": {
						const s = new Set(this._state.pendingToolCalls);
						s.delete(event.toolCallId);
						this._state.pendingToolCalls = s;
						break;
					}
					case "turn_end":
						if (event.message.role === "assistant" && (event.message as any).errorMessage) {
							this._state.error = (event.message as any).errorMessage;
						}
						break;
					case "agent_end":
						this._state.isStreaming = false;
						this._state.streamMessage = null;
						break;
				}
				// Emit to listeners
				this.emit(event);
			}
			// Handle any remaining partial message
			if (partial && partial.role === "assistant" && partial.content.length > 0) {
				// Keep an interrupted partial only if it has any non-whitespace content.
				const onlyEmpty = !partial.content.some(
					(c) =>
						(c.type === "thinking" && c.thinking.trim().length > 0) ||
						(c.type === "text" && c.text.trim().length > 0) ||
						(c.type === "toolCall" && c.name.trim().length > 0),
				);
				if (!onlyEmpty) {
					this.appendMessage(partial);
				} else {
					if (this.abortController?.signal.aborted) {
						throw new Error("Request was aborted");
					}
				}
			}
		} catch (err: any) {
			// Append a synthetic assistant error message so the transcript records the failure.
			const errorMsg: AgentMessage = {
				role: "assistant",
				content: [{ type: "text", text: "" }],
				api: model.api,
				provider: model.provider,
				model: model.id,
				usage: {
					input: 0,
					output: 0,
					cacheRead: 0,
					cacheWrite: 0,
					totalTokens: 0,
					cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
				},
				stopReason: this.abortController?.signal.aborted ? "aborted" : "error",
				errorMessage: err?.message || String(err),
				timestamp: Date.now(),
			} as AgentMessage;
			this.appendMessage(errorMsg);
			this._state.error = err?.message || String(err);
			// NOTE(review): this agent_end is emitted before the finally block resets
			// isStreaming, so listeners observe isStreaming === true here, unlike the
			// success path where the loop's agent_end clears it first — confirm intended.
			this.emit({ type: "agent_end", messages: [errorMsg] });
		} finally {
			this._state.isStreaming = false;
			this._state.streamMessage = null;
			this._state.pendingToolCalls = new Set<string>();
			this.abortController = undefined;
			this.resolveRunningPrompt?.();
			this.runningPrompt = undefined;
			this.resolveRunningPrompt = undefined;
		}
	}
	private emit(e: AgentEvent) {
		for (const listener of this.listeners) {
			listener(e);
		}
	}
}

View file

@ -0,0 +1,8 @@
// Core Agent
export * from "./agent.js";
// Loop functions
export * from "./agent-loop.js";
// Proxy utilities
export * from "./proxy.js";
// Types
export * from "./types.js";

View file

@ -0,0 +1,340 @@
/**
* Proxy stream function for apps that route LLM calls through a server.
* The server manages auth and proxies requests to LLM providers.
*/
// Internal import for JSON parsing utility
import {
type AssistantMessage,
type AssistantMessageEvent,
type Context,
EventStream,
type Model,
parseStreamingJson,
type SimpleStreamOptions,
type StopReason,
type ToolCall,
} from "@gsd/pi-ai";
/**
 * Event stream for proxied assistant responses. The stream terminates on a
 * "done" event (resolving to the completed message) or an "error" event
 * (resolving to the error-carrying message).
 */
class ProxyMessageEventStream extends EventStream<AssistantMessageEvent, AssistantMessage> {
	constructor() {
		super(
			(event) => {
				// Terminal events close the stream.
				return event.type === "done" || event.type === "error";
			},
			(event) => {
				switch (event.type) {
					case "done":
						return event.message;
					case "error":
						return event.error;
					default:
						throw new Error("Unexpected event type");
				}
			},
		);
	}
}
/**
 * Proxy event types - server sends these with partial field stripped to reduce bandwidth.
 * The client rebuilds the partial message locally (see processProxyEvent).
 */
export type ProxyAssistantMessageEvent =
	| { type: "start" }
	// Text content block lifecycle at a given content index
	| { type: "text_start"; contentIndex: number }
	| { type: "text_delta"; contentIndex: number; delta: string }
	| { type: "text_end"; contentIndex: number; contentSignature?: string }
	// Thinking content block lifecycle
	| { type: "thinking_start"; contentIndex: number }
	| { type: "thinking_delta"; contentIndex: number; delta: string }
	| { type: "thinking_end"; contentIndex: number; contentSignature?: string }
	// Tool-call content block lifecycle; deltas carry streamed argument JSON
	| { type: "toolcall_start"; contentIndex: number; id: string; toolName: string }
	| { type: "toolcall_delta"; contentIndex: number; delta: string }
	| { type: "toolcall_end"; contentIndex: number }
	| {
			type: "done";
			reason: Extract<StopReason, "stop" | "length" | "toolUse">;
			usage: AssistantMessage["usage"];
	  }
	| {
			type: "error";
			reason: Extract<StopReason, "aborted" | "error">;
			errorMessage?: string;
			usage: AssistantMessage["usage"];
	  };
/** Options for streamProxy: standard stream options plus proxy auth/endpoint. */
export interface ProxyStreamOptions extends SimpleStreamOptions {
	/** Auth token for the proxy server */
	authToken: string;
	/** Proxy server URL (e.g., "https://genai.example.com") */
	proxyUrl: string;
}
/**
 * Stream function that proxies through a server instead of calling LLM providers directly.
 * The server strips the partial field from delta events to reduce bandwidth.
 * We reconstruct the partial message client-side.
 *
 * Use this as the `streamFn` option when creating an Agent that needs to go through a proxy.
 *
 * @example
 * ```typescript
 * const agent = new Agent({
 *   streamFn: (model, context, options) =>
 *     streamProxy(model, context, {
 *       ...options,
 *       authToken: await getAuthToken(),
 *       proxyUrl: "https://genai.example.com",
 *     }),
 * });
 * ```
 */
export function streamProxy(model: Model<any>, context: Context, options: ProxyStreamOptions): ProxyMessageEventStream {
	const stream = new ProxyMessageEventStream();
	// Fire-and-forget: the async body feeds the returned stream.
	(async () => {
		// Initialize the partial message that we'll build up from events
		const partial: AssistantMessage = {
			role: "assistant",
			stopReason: "stop",
			content: [],
			api: model.api,
			provider: model.provider,
			model: model.id,
			usage: {
				input: 0,
				output: 0,
				cacheRead: 0,
				cacheWrite: 0,
				totalTokens: 0,
				cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
			},
			timestamp: Date.now(),
		};
		let reader: ReadableStreamDefaultReader<Uint8Array> | undefined;
		// On abort, cancel the body reader so the read loop terminates promptly.
		const abortHandler = () => {
			if (reader) {
				reader.cancel("Request aborted by user").catch(() => {});
			}
		};
		if (options.signal) {
			options.signal.addEventListener("abort", abortHandler);
		}
		try {
			const response = await fetch(`${options.proxyUrl}/api/stream`, {
				method: "POST",
				headers: {
					Authorization: `Bearer ${options.authToken}`,
					"Content-Type": "application/json",
				},
				body: JSON.stringify({
					model,
					context,
					options: {
						temperature: options.temperature,
						maxTokens: options.maxTokens,
						reasoning: options.reasoning,
					},
				}),
				signal: options.signal,
			});
			if (!response.ok) {
				// Prefer the server's JSON error body when available.
				let errorMessage = `Proxy error: ${response.status} ${response.statusText}`;
				try {
					const errorData = (await response.json()) as { error?: string };
					if (errorData.error) {
						errorMessage = `Proxy error: ${errorData.error}`;
					}
				} catch {
					// Couldn't parse error response
				}
				throw new Error(errorMessage);
			}
			reader = response.body!.getReader();
			const decoder = new TextDecoder();
			let buffer = "";
			while (true) {
				const { done, value } = await reader.read();
				if (done) break;
				if (options.signal?.aborted) {
					throw new Error("Request aborted by user");
				}
				buffer += decoder.decode(value, { stream: true });
				// SSE lines are newline-delimited; keep the trailing partial line in the buffer.
				// NOTE(review): any bytes still in `buffer` when the reader reports done are
				// discarded — assumes the server terminates the final event with a newline; confirm.
				const lines = buffer.split("\n");
				buffer = lines.pop() || "";
				for (const line of lines) {
					if (line.startsWith("data: ")) {
						const data = line.slice(6).trim();
						if (data) {
							const proxyEvent = JSON.parse(data) as ProxyAssistantMessageEvent;
							const event = processProxyEvent(proxyEvent, partial);
							if (event) {
								stream.push(event);
							}
						}
					}
				}
			}
			if (options.signal?.aborted) {
				throw new Error("Request aborted by user");
			}
			stream.end();
		} catch (error) {
			// Convert any failure (network, HTTP, abort) into a terminal error event.
			const errorMessage = error instanceof Error ? error.message : String(error);
			const reason = options.signal?.aborted ? "aborted" : "error";
			partial.stopReason = reason;
			partial.errorMessage = errorMessage;
			stream.push({
				type: "error",
				reason,
				error: partial,
			});
			stream.end();
		} finally {
			if (options.signal) {
				options.signal.removeEventListener("abort", abortHandler);
			}
		}
	})();
	return stream;
}
/**
 * Translate a bandwidth-reduced proxy event into a full AssistantMessageEvent,
 * mutating `partial` so it always reflects the message reconstructed so far.
 * Returns undefined for events that should not be forwarded downstream.
 */
function processProxyEvent(
	proxyEvent: ProxyAssistantMessageEvent,
	partial: AssistantMessage,
): AssistantMessageEvent | undefined {
	if (proxyEvent.type === "start") {
		return { type: "start", partial };
	}
	if (proxyEvent.type === "text_start") {
		partial.content[proxyEvent.contentIndex] = { type: "text", text: "" };
		return { type: "text_start", contentIndex: proxyEvent.contentIndex, partial };
	}
	if (proxyEvent.type === "text_delta") {
		const block = partial.content[proxyEvent.contentIndex];
		if (block?.type !== "text") throw new Error("Received text_delta for non-text content");
		block.text += proxyEvent.delta;
		return { type: "text_delta", contentIndex: proxyEvent.contentIndex, delta: proxyEvent.delta, partial };
	}
	if (proxyEvent.type === "text_end") {
		const block = partial.content[proxyEvent.contentIndex];
		if (block?.type !== "text") throw new Error("Received text_end for non-text content");
		block.textSignature = proxyEvent.contentSignature;
		return { type: "text_end", contentIndex: proxyEvent.contentIndex, content: block.text, partial };
	}
	if (proxyEvent.type === "thinking_start") {
		partial.content[proxyEvent.contentIndex] = { type: "thinking", thinking: "" };
		return { type: "thinking_start", contentIndex: proxyEvent.contentIndex, partial };
	}
	if (proxyEvent.type === "thinking_delta") {
		const block = partial.content[proxyEvent.contentIndex];
		if (block?.type !== "thinking") throw new Error("Received thinking_delta for non-thinking content");
		block.thinking += proxyEvent.delta;
		return { type: "thinking_delta", contentIndex: proxyEvent.contentIndex, delta: proxyEvent.delta, partial };
	}
	if (proxyEvent.type === "thinking_end") {
		const block = partial.content[proxyEvent.contentIndex];
		if (block?.type !== "thinking") throw new Error("Received thinking_end for non-thinking content");
		block.thinkingSignature = proxyEvent.contentSignature;
		return { type: "thinking_end", contentIndex: proxyEvent.contentIndex, content: block.thinking, partial };
	}
	if (proxyEvent.type === "toolcall_start") {
		// Track streamed argument JSON in a transient partialJson field (removed at toolcall_end).
		partial.content[proxyEvent.contentIndex] = {
			type: "toolCall",
			id: proxyEvent.id,
			name: proxyEvent.toolName,
			arguments: {},
			partialJson: "",
		} satisfies ToolCall & { partialJson: string } as ToolCall;
		return { type: "toolcall_start", contentIndex: proxyEvent.contentIndex, partial };
	}
	if (proxyEvent.type === "toolcall_delta") {
		const block = partial.content[proxyEvent.contentIndex];
		if (block?.type !== "toolCall") throw new Error("Received toolcall_delta for non-toolCall content");
		(block as any).partialJson += proxyEvent.delta;
		block.arguments = parseStreamingJson((block as any).partialJson) || {};
		partial.content[proxyEvent.contentIndex] = { ...block }; // Replace to trigger reactivity in consumers.
		return { type: "toolcall_delta", contentIndex: proxyEvent.contentIndex, delta: proxyEvent.delta, partial };
	}
	if (proxyEvent.type === "toolcall_end") {
		const block = partial.content[proxyEvent.contentIndex];
		// Unlike the delta handlers, a mismatched toolcall_end is silently dropped.
		if (block?.type !== "toolCall") return undefined;
		delete (block as any).partialJson;
		return { type: "toolcall_end", contentIndex: proxyEvent.contentIndex, toolCall: block, partial };
	}
	if (proxyEvent.type === "done") {
		partial.stopReason = proxyEvent.reason;
		partial.usage = proxyEvent.usage;
		return { type: "done", reason: proxyEvent.reason, message: partial };
	}
	if (proxyEvent.type === "error") {
		partial.stopReason = proxyEvent.reason;
		partial.errorMessage = proxyEvent.errorMessage;
		partial.usage = proxyEvent.usage;
		return { type: "error", reason: proxyEvent.reason, error: partial };
	}
	// Exhaustiveness guard: unreachable by type, warn at runtime for unknown server events.
	const _exhaustiveCheck: never = proxyEvent;
	console.warn(`Unhandled proxy event type: ${(proxyEvent as any).type}`);
	return undefined;
}

View file

@ -0,0 +1,194 @@
import type {
AssistantMessageEvent,
ImageContent,
Message,
Model,
SimpleStreamOptions,
streamSimple,
TextContent,
Tool,
ToolResultMessage,
} from "@gsd/pi-ai";
import type { Static, TSchema } from "@sinclair/typebox";
/** Stream function - can return sync or Promise for async config lookup */
export type StreamFn = (
	...args: Parameters<typeof streamSimple>
) => ReturnType<typeof streamSimple> | Promise<ReturnType<typeof streamSimple>>;
/**
 * Configuration for the agent loop.
 */
export interface AgentLoopConfig extends SimpleStreamOptions {
	/** Model to call on each turn. */
	model: Model<any>;
	/**
	 * Converts AgentMessage[] to LLM-compatible Message[] before each LLM call.
	 *
	 * Each AgentMessage must be converted to a UserMessage, AssistantMessage, or ToolResultMessage
	 * that the LLM can understand. AgentMessages that cannot be converted (e.g., UI-only notifications,
	 * status messages) should be filtered out.
	 *
	 * @example
	 * ```typescript
	 * convertToLlm: (messages) => messages.flatMap(m => {
	 *   if (m.role === "custom") {
	 *     // Convert custom message to user message
	 *     return [{ role: "user", content: m.content, timestamp: m.timestamp }];
	 *   }
	 *   if (m.role === "notification") {
	 *     // Filter out UI-only messages
	 *     return [];
	 *   }
	 *   // Pass through standard LLM messages
	 *   return [m];
	 * })
	 * ```
	 */
	convertToLlm: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
	/**
	 * Optional transform applied to the context before `convertToLlm`.
	 *
	 * Use this for operations that work at the AgentMessage level:
	 * - Context window management (pruning old messages)
	 * - Injecting context from external sources
	 *
	 * @example
	 * ```typescript
	 * transformContext: async (messages) => {
	 *   if (estimateTokens(messages) > MAX_TOKENS) {
	 *     return pruneOldMessages(messages);
	 *   }
	 *   return messages;
	 * }
	 * ```
	 */
	transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
	/**
	 * Resolves an API key dynamically for each LLM call.
	 *
	 * Useful for short-lived OAuth tokens (e.g., GitHub Copilot) that may expire
	 * during long-running tool execution phases.
	 */
	getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
	/**
	 * Returns steering messages to inject into the conversation mid-run.
	 *
	 * Called after each tool execution to check for user interruptions.
	 * If messages are returned, remaining tool calls are skipped and
	 * these messages are added to the context before the next LLM call.
	 *
	 * Use this for "steering" the agent while it's working.
	 */
	getSteeringMessages?: () => Promise<AgentMessage[]>;
	/**
	 * Returns follow-up messages to process after the agent would otherwise stop.
	 *
	 * Called when the agent has no more tool calls and no steering messages.
	 * If messages are returned, they're added to the context and the agent
	 * continues with another turn.
	 *
	 * Use this for follow-up messages that should wait until the agent finishes.
	 */
	getFollowUpMessages?: () => Promise<AgentMessage[]>;
}
/**
 * Thinking/reasoning level for models that support it.
 * Note: "xhigh" is only supported by OpenAI gpt-5.1-codex-max, gpt-5.2, gpt-5.2-codex, gpt-5.3, and gpt-5.3-codex models.
 */
export type ThinkingLevel = "off" | "minimal" | "low" | "medium" | "high" | "xhigh";
/**
 * Extensible interface for custom app messages.
 * Apps can extend via declaration merging:
 *
 * @example
 * ```typescript
 * declare module "@mariozechner/agent" {
 *   interface CustomAgentMessages {
 *     artifact: ArtifactMessage;
 *     notification: NotificationMessage;
 *   }
 * }
 * ```
 */
export interface CustomAgentMessages {
	// Empty by default - apps extend via declaration merging
}
/**
 * AgentMessage: Union of LLM messages + custom messages.
 * This abstraction allows apps to add custom message types while maintaining
 * type safety and compatibility with the base LLM messages.
 */
export type AgentMessage = Message | CustomAgentMessages[keyof CustomAgentMessages];
/**
 * Agent state containing all configuration and conversation data.
 */
export interface AgentState {
	systemPrompt: string;
	model: Model<any>;
	thinkingLevel: ThinkingLevel;
	tools: AgentTool<any>[];
	messages: AgentMessage[]; // Can include attachments + custom message types
	isStreaming: boolean; // True while a prompt/continue run is in flight
	streamMessage: AgentMessage | null; // The in-progress (partial) message, if any
	pendingToolCalls: Set<string>; // Tool call IDs currently executing
	error?: string; // Last error message, if the previous run failed
}
export interface AgentToolResult<T> {
	// Content blocks supporting text and images
	content: (TextContent | ImageContent)[];
	// Details to be displayed in a UI or logged
	details: T;
}
// Callback for streaming tool execution updates
export type AgentToolUpdateCallback<T = any> = (partialResult: AgentToolResult<T>) => void;
// AgentTool extends Tool but adds the execute function
export interface AgentTool<TParameters extends TSchema = TSchema, TDetails = any> extends Tool<TParameters> {
	// A human-readable label for the tool to be displayed in UI
	label: string;
	/**
	 * Run the tool.
	 * @param toolCallId Unique id of the originating tool call
	 * @param params Arguments validated against the tool's parameter schema
	 * @param signal Optional abort signal for cancellation
	 * @param onUpdate Optional callback for streaming partial results
	 */
	execute: (
		toolCallId: string,
		params: Static<TParameters>,
		signal?: AbortSignal,
		onUpdate?: AgentToolUpdateCallback<TDetails>,
	) => Promise<AgentToolResult<TDetails>>;
}
// AgentContext is like Context but uses AgentTool
export interface AgentContext {
	systemPrompt: string;
	messages: AgentMessage[];
	tools?: AgentTool<any>[];
}
/**
 * Events emitted by the Agent for UI updates.
 * These events provide fine-grained lifecycle information for messages, turns, and tool executions.
 */
export type AgentEvent =
	// Agent lifecycle
	| { type: "agent_start" }
	| { type: "agent_end"; messages: AgentMessage[] }
	// Turn lifecycle - a turn is one assistant response + any tool calls/results
	| { type: "turn_start" }
	| { type: "turn_end"; message: AgentMessage; toolResults: ToolResultMessage[] }
	// Message lifecycle - emitted for user, assistant, and toolResult messages
	| { type: "message_start"; message: AgentMessage }
	// Only emitted for assistant messages during streaming
	| { type: "message_update"; message: AgentMessage; assistantMessageEvent: AssistantMessageEvent }
	| { type: "message_end"; message: AgentMessage }
	// Tool execution lifecycle
	| { type: "tool_execution_start"; toolCallId: string; toolName: string; args: any }
	| { type: "tool_execution_update"; toolCallId: string; toolName: string; args: any; partialResult: any }
	| { type: "tool_execution_end"; toolCallId: string; toolName: string; result: any; isError: boolean };

View file

@ -0,0 +1,27 @@
{
"compilerOptions": {
"target": "ES2024",
"module": "Node16",
"lib": ["ES2024"],
"strict": true,
"esModuleInterop": true,
"skipLibCheck": true,
"forceConsistentCasingInFileNames": true,
"declaration": true,
"declarationMap": true,
"sourceMap": true,
"inlineSources": true,
"inlineSourceMap": false,
"moduleResolution": "Node16",
"resolveJsonModule": true,
"allowImportingTsExtensions": false,
"experimentalDecorators": true,
"emitDecoratorMetadata": true,
"useDefineForClassFields": false,
"types": ["node"],
"outDir": "./dist",
"rootDir": "./src"
},
"include": ["src/**/*.ts"],
"exclude": ["node_modules", "dist", "**/*.d.ts", "src/**/*.d.ts"]
}

1
packages/pi-ai/bedrock-provider.d.ts vendored Normal file
View file

@ -0,0 +1 @@
// Type shim: forwards declarations from the compiled bedrock provider entry point.
export * from "./dist/bedrock-provider.js";

View file

@ -0,0 +1 @@
// Runtime shim: forwards exports from the compiled bedrock provider entry point.
export * from "./dist/bedrock-provider.js";

View file

@ -0,0 +1,40 @@
{
"name": "@gsd/pi-ai",
"version": "0.57.1",
"description": "Unified LLM API (vendored from pi-mono)",
"type": "module",
"main": "./dist/index.js",
"types": "./dist/index.d.ts",
"exports": {
".": {
"types": "./dist/index.d.ts",
"import": "./dist/index.js"
},
"./oauth": {
"types": "./dist/oauth.d.ts",
"import": "./dist/oauth.js"
},
"./bedrock-provider": {
"types": "./bedrock-provider.d.ts",
"import": "./bedrock-provider.js"
}
},
"scripts": {
"build": "tsc -p tsconfig.json"
},
"dependencies": {
"@anthropic-ai/sdk": "^0.73.0",
"@aws-sdk/client-bedrock-runtime": "^3.983.0",
"@google/genai": "^1.40.0",
"@mistralai/mistralai": "1.14.1",
"@sinclair/typebox": "^0.34.41",
"ajv": "^8.17.1",
"ajv-formats": "^3.0.1",
"chalk": "^5.6.2",
"openai": "6.26.0",
"partial-json": "^0.1.7",
"proxy-agent": "^6.5.0",
"undici": "^7.19.1",
"zod-to-json-schema": "^3.24.6"
}
}

View file

@ -0,0 +1,98 @@
import type {
Api,
AssistantMessageEventStream,
Context,
Model,
SimpleStreamOptions,
StreamFunction,
StreamOptions,
} from "./types.js";
// Type-erased stream function stored in the registry: accepts any Model<Api>.
export type ApiStreamFunction = (
	model: Model<Api>,
	context: Context,
	options?: StreamOptions,
) => AssistantMessageEventStream;
// Type-erased simple-stream variant (normalized options).
export type ApiStreamSimpleFunction = (
	model: Model<Api>,
	context: Context,
	options?: SimpleStreamOptions,
) => AssistantMessageEventStream;
// Public provider contract: strongly typed per api / options.
export interface ApiProvider<TApi extends Api = Api, TOptions extends StreamOptions = StreamOptions> {
	api: TApi;
	stream: StreamFunction<TApi, TOptions>;
	streamSimple: StreamFunction<TApi, SimpleStreamOptions>;
}
// Internal registry entry with erased generics.
interface ApiProviderInternal {
	api: Api;
	stream: ApiStreamFunction;
	streamSimple: ApiStreamSimpleFunction;
}
type RegisteredApiProvider = {
	provider: ApiProviderInternal;
	// Optional tag identifying who registered the provider (e.g. a plugin),
	// so it can later be removed via unregisterApiProviders(sourceId).
	sourceId?: string;
};
// Keyed by api identifier; one provider per api.
const apiProviderRegistry = new Map<string, RegisteredApiProvider>();
/**
 * Adapt a strongly typed stream function to the registry's erased signature,
 * rejecting models whose api tag does not match the provider's api.
 */
function wrapStream<TApi extends Api, TOptions extends StreamOptions>(
	api: TApi,
	stream: StreamFunction<TApi, TOptions>,
): ApiStreamFunction {
	return (model, context, options) => {
		if (model.api === api) {
			return stream(model as Model<TApi>, context, options as TOptions);
		}
		throw new Error(`Mismatched api: ${model.api} expected ${api}`);
	};
}
/**
 * Adapt a strongly typed simple-stream function to the registry's erased
 * signature, rejecting models whose api tag does not match.
 */
function wrapStreamSimple<TApi extends Api>(
	api: TApi,
	streamSimple: StreamFunction<TApi, SimpleStreamOptions>,
): ApiStreamSimpleFunction {
	return (model, context, options) => {
		if (model.api === api) {
			return streamSimple(model as Model<TApi>, context, options);
		}
		throw new Error(`Mismatched api: ${model.api} expected ${api}`);
	};
}
export function registerApiProvider<TApi extends Api, TOptions extends StreamOptions>(
provider: ApiProvider<TApi, TOptions>,
sourceId?: string,
): void {
apiProviderRegistry.set(provider.api, {
provider: {
api: provider.api,
stream: wrapStream(provider.api, provider.stream),
streamSimple: wrapStreamSimple(provider.api, provider.streamSimple),
},
sourceId,
});
}
/** Look up the registered provider for an api, or undefined if none. */
export function getApiProvider(api: Api): ApiProviderInternal | undefined {
	const entry = apiProviderRegistry.get(api);
	return entry ? entry.provider : undefined;
}
/** Return all registered providers, in registration order. */
export function getApiProviders(): ApiProviderInternal[] {
	return [...apiProviderRegistry.values()].map((entry) => entry.provider);
}
/** Remove every provider that was registered under the given sourceId. */
export function unregisterApiProviders(sourceId: string): void {
	// Collect keys first, then delete, rather than deleting mid-iteration.
	const stale: string[] = [];
	for (const [api, entry] of apiProviderRegistry) {
		if (entry.sourceId === sourceId) stale.push(api);
	}
	for (const api of stale) {
		apiProviderRegistry.delete(api);
	}
}
/** Remove every registered API provider. */
export function clearApiProviders(): void {
	apiProviderRegistry.clear();
}

View file

@ -0,0 +1,6 @@
import { streamBedrock, streamSimpleBedrock } from "./providers/amazon-bedrock.js";
// Bundles the Bedrock stream functions behind a single object so the heavy
// AWS SDK dependency can live in its own entry point (see ./bedrock-provider export).
export const bedrockProviderModule = {
	streamBedrock,
	streamSimpleBedrock,
};

133
packages/pi-ai/src/cli.ts Normal file
View file

@ -0,0 +1,133 @@
#!/usr/bin/env node
import { existsSync, readFileSync, writeFileSync } from "fs";
import { createInterface } from "readline";
import { getOAuthProvider, getOAuthProviders } from "./utils/oauth/index.js";
import type { OAuthCredentials, OAuthProviderId } from "./utils/oauth/types.js";
const AUTH_FILE = "auth.json";
const PROVIDERS = getOAuthProviders();
/** Promisify readline's callback-based question() for async/await use. */
function prompt(rl: ReturnType<typeof createInterface>, question: string): Promise<string> {
	return new Promise((resolve) => {
		rl.question(question, (answer) => resolve(answer));
	});
}
/**
 * Read stored OAuth credentials from AUTH_FILE.
 * A missing or unparsable file degrades to an empty credential store.
 */
function loadAuth(): Record<string, { type: "oauth" } & OAuthCredentials> {
	if (!existsSync(AUTH_FILE)) {
		return {};
	}
	try {
		const raw = readFileSync(AUTH_FILE, "utf-8");
		return JSON.parse(raw);
	} catch {
		return {};
	}
}
// Persist the credential store to AUTH_FILE as pretty-printed JSON.
function saveAuth(auth: Record<string, { type: "oauth" } & OAuthCredentials>): void {
	writeFileSync(AUTH_FILE, JSON.stringify(auth, null, 2), "utf-8");
}
/**
 * Run the interactive OAuth login flow for a provider and persist the
 * resulting credentials to AUTH_FILE (merged with any existing entries).
 * Exits the process if the provider id is unknown.
 */
async function login(providerId: OAuthProviderId): Promise<void> {
	const provider = getOAuthProvider(providerId);
	if (!provider) {
		console.error(`Unknown provider: ${providerId}`);
		process.exit(1);
	}
	const rl = createInterface({ input: process.stdin, output: process.stdout });
	const promptFn = (msg: string) => prompt(rl, `${msg} `);
	try {
		const credentials = await provider.login({
			// Show the authorization URL (and any provider-specific instructions).
			onAuth: (info) => {
				console.log(`\nOpen this URL in your browser:\n${info.url}`);
				if (info.instructions) console.log(info.instructions);
				console.log();
			},
			// Forward provider prompts (e.g. paste-the-code) to the terminal.
			onPrompt: async (p) => {
				return await promptFn(`${p.message}${p.placeholder ? ` (${p.placeholder})` : ""}:`);
			},
			onProgress: (msg) => console.log(msg),
		});
		const auth = loadAuth();
		auth[providerId] = { type: "oauth", ...credentials };
		saveAuth(auth);
		console.log(`\nCredentials saved to ${AUTH_FILE}`);
	} finally {
		// Always release stdin, even if login throws.
		rl.close();
	}
}
/**
 * CLI entry point: parses argv and dispatches to help, `list`, or `login`.
 * Interactive `login` without a provider argument shows a numbered menu.
 */
async function main(): Promise<void> {
	const args = process.argv.slice(2);
	const command = args[0];
	if (!command || command === "help" || command === "--help" || command === "-h") {
		const providerList = PROVIDERS.map((p) => ` ${p.id.padEnd(20)} ${p.name}`).join("\n");
		console.log(`Usage: npx @gsd/pi-ai <command> [provider]
Commands:
 login [provider] Login to an OAuth provider
 list List available providers
Providers:
${providerList}
Examples:
 npx @gsd/pi-ai login # interactive provider selection
 npx @gsd/pi-ai login anthropic # login to specific provider
 npx @gsd/pi-ai list # list providers
`);
		return;
	}
	if (command === "list") {
		console.log("Available OAuth providers:\n");
		for (const p of PROVIDERS) {
			console.log(` ${p.id.padEnd(20)} ${p.name}`);
		}
		return;
	}
	if (command === "login") {
		let provider = args[1] as OAuthProviderId | undefined;
		if (!provider) {
			const rl = createInterface({ input: process.stdin, output: process.stdout });
			console.log("Select a provider:\n");
			for (let i = 0; i < PROVIDERS.length; i++) {
				console.log(` ${i + 1}. ${PROVIDERS[i].name}`);
			}
			console.log();
			const choice = await prompt(rl, `Enter number (1-${PROVIDERS.length}): `);
			rl.close();
			const index = parseInt(choice, 10) - 1;
			// Bug fix: non-numeric input makes parseInt return NaN, which passes
			// both range checks (NaN < 0 and NaN >= length are false) and would
			// crash below on PROVIDERS[NaN].id. Reject it explicitly.
			if (!Number.isInteger(index) || index < 0 || index >= PROVIDERS.length) {
				console.error("Invalid selection");
				process.exit(1);
			}
			provider = PROVIDERS[index].id;
		}
		if (!PROVIDERS.some((p) => p.id === provider)) {
			console.error(`Unknown provider: ${provider}`);
			console.error(`Use 'npx @gsd/pi-ai list' to see available providers`);
			process.exit(1);
		}
		console.log(`Logging in to ${provider}...`);
		await login(provider);
		return;
	}
	console.error(`Unknown command: ${command}`);
	console.error(`Use 'npx @gsd/pi-ai --help' for usage`);
	process.exit(1);
}
// Top-level error handler: print the message and exit non-zero.
main().catch((err) => {
	console.error("Error:", err.message);
	process.exit(1);
});

View file

@ -0,0 +1,129 @@
// NEVER convert to top-level imports - breaks browser/Vite builds (web-ui)
let _existsSync: typeof import("node:fs").existsSync | null = null;
let _homedir: typeof import("node:os").homedir | null = null;
let _join: typeof import("node:path").join | null = null;
type DynamicImport = (specifier: string) => Promise<unknown>;
const dynamicImport: DynamicImport = (specifier) => import(specifier);
// Specifiers are concatenated so bundlers cannot statically resolve them
// and pull node builtins into a browser bundle.
const NODE_FS_SPECIFIER = "node:" + "fs";
const NODE_OS_SPECIFIER = "node:" + "os";
const NODE_PATH_SPECIFIER = "node:" + "path";
// Eagerly load in Node.js/Bun environment only. These resolve asynchronously,
// so the module-level _existsSync/_homedir/_join may still be null for a
// short window after startup (callers must handle that).
if (typeof process !== "undefined" && (process.versions?.node || process.versions?.bun)) {
	dynamicImport(NODE_FS_SPECIFIER).then((m) => {
		_existsSync = (m as typeof import("node:fs")).existsSync;
	});
	dynamicImport(NODE_OS_SPECIFIER).then((m) => {
		_homedir = (m as typeof import("node:os")).homedir;
	});
	dynamicImport(NODE_PATH_SPECIFIER).then((m) => {
		_join = (m as typeof import("node:path")).join;
	});
}
import type { KnownProvider } from "./types.js";
// Memoized result of the ADC credential-file check; null = not yet determined.
let cachedVertexAdcCredentialsExists: boolean | null = null;
/**
 * Check whether Google Application Default Credentials exist on disk, either
 * at $GOOGLE_APPLICATION_CREDENTIALS or the default gcloud ADC path.
 * The result is cached after the first successful filesystem check.
 */
function hasVertexAdcCredentials(): boolean {
	if (cachedVertexAdcCredentialsExists === null) {
		// If node modules haven't loaded yet (async import race at startup),
		// return false WITHOUT caching so the next call retries once they're ready.
		// Only cache false permanently in a browser environment where fs is never available.
		if (!_existsSync || !_homedir || !_join) {
			const isNode = typeof process !== "undefined" && (process.versions?.node || process.versions?.bun);
			if (!isNode) {
				// Definitively in a browser — safe to cache false permanently
				cachedVertexAdcCredentialsExists = false;
			}
			return false;
		}
		// Check GOOGLE_APPLICATION_CREDENTIALS env var first (standard way)
		const gacPath = process.env.GOOGLE_APPLICATION_CREDENTIALS;
		if (gacPath) {
			cachedVertexAdcCredentialsExists = _existsSync(gacPath);
		} else {
			// Fall back to default ADC path (lazy evaluation)
			cachedVertexAdcCredentialsExists = _existsSync(
				_join(_homedir(), ".config", "gcloud", "application_default_credentials.json"),
			);
		}
	}
	return cachedVertexAdcCredentialsExists;
}
/**
* Get API key for provider from known environment variables, e.g. OPENAI_API_KEY.
*
* Will not return API keys for providers that require OAuth tokens.
*/
export function getEnvApiKey(provider: KnownProvider): string | undefined;
export function getEnvApiKey(provider: string): string | undefined;
/**
 * Get API key for provider from known environment variables, e.g. OPENAI_API_KEY.
 *
 * Providers with non-key auth (Vertex ADC, Bedrock IAM) return the sentinel
 * "<authenticated>" when their credentials are detected.
 *
 * Improvement: implementation signature typed as `string` instead of `any` —
 * every use below is a string comparison or string-keyed lookup, so `any`
 * only disabled checking without buying anything.
 */
export function getEnvApiKey(provider: string): string | undefined {
	// Fall back to environment variables
	if (provider === "github-copilot") {
		return process.env.COPILOT_GITHUB_TOKEN || process.env.GH_TOKEN || process.env.GITHUB_TOKEN;
	}
	// ANTHROPIC_OAUTH_TOKEN takes precedence over ANTHROPIC_API_KEY
	if (provider === "anthropic") {
		return process.env.ANTHROPIC_OAUTH_TOKEN || process.env.ANTHROPIC_API_KEY;
	}
	// Vertex AI uses Application Default Credentials, not API keys.
	// Auth is configured via `gcloud auth application-default login`.
	if (provider === "google-vertex") {
		const hasCredentials = hasVertexAdcCredentials();
		const hasProject = !!(process.env.GOOGLE_CLOUD_PROJECT || process.env.GCLOUD_PROJECT);
		const hasLocation = !!process.env.GOOGLE_CLOUD_LOCATION;
		if (hasCredentials && hasProject && hasLocation) {
			return "<authenticated>";
		}
	}
	if (provider === "amazon-bedrock") {
		// Amazon Bedrock supports multiple credential sources:
		// 1. AWS_PROFILE - named profile from ~/.aws/credentials
		// 2. AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY - standard IAM keys
		// 3. AWS_BEARER_TOKEN_BEDROCK - Bedrock API keys (bearer token)
		// 4. AWS_CONTAINER_CREDENTIALS_RELATIVE_URI - ECS task roles
		// 5. AWS_CONTAINER_CREDENTIALS_FULL_URI - ECS task roles (full URI)
		// 6. AWS_WEB_IDENTITY_TOKEN_FILE - IRSA (IAM Roles for Service Accounts)
		if (
			process.env.AWS_PROFILE ||
			(process.env.AWS_ACCESS_KEY_ID && process.env.AWS_SECRET_ACCESS_KEY) ||
			process.env.AWS_BEARER_TOKEN_BEDROCK ||
			process.env.AWS_CONTAINER_CREDENTIALS_RELATIVE_URI ||
			process.env.AWS_CONTAINER_CREDENTIALS_FULL_URI ||
			process.env.AWS_WEB_IDENTITY_TOKEN_FILE
		) {
			return "<authenticated>";
		}
	}
	// Simple one-env-var providers.
	const envMap: Record<string, string> = {
		openai: "OPENAI_API_KEY",
		"azure-openai-responses": "AZURE_OPENAI_API_KEY",
		google: "GEMINI_API_KEY",
		groq: "GROQ_API_KEY",
		cerebras: "CEREBRAS_API_KEY",
		xai: "XAI_API_KEY",
		openrouter: "OPENROUTER_API_KEY",
		"vercel-ai-gateway": "AI_GATEWAY_API_KEY",
		zai: "ZAI_API_KEY",
		mistral: "MISTRAL_API_KEY",
		minimax: "MINIMAX_API_KEY",
		"minimax-cn": "MINIMAX_CN_API_KEY",
		huggingface: "HF_TOKEN",
		opencode: "OPENCODE_API_KEY",
		"opencode-go": "OPENCODE_API_KEY",
		"kimi-coding": "KIMI_API_KEY",
	};
	const envVar = envMap[provider];
	return envVar ? process.env[envVar] : undefined;
}

View file

@ -0,0 +1,32 @@
// Package entry point: re-exports the public surface of @gsd/pi-ai.
// Schema building blocks (re-exported so consumers need not depend on typebox directly).
export type { Static, TSchema } from "@sinclair/typebox";
export { Type } from "@sinclair/typebox";
// Registry + key discovery.
export * from "./api-registry.js";
export * from "./env-api-keys.js";
export * from "./models.js";
// Built-in providers. Note: amazon-bedrock is deliberately absent here — it
// ships via the separate ./bedrock-provider subpath entry.
export * from "./providers/anthropic.js";
export * from "./providers/azure-openai-responses.js";
export * from "./providers/google.js";
export * from "./providers/google-gemini-cli.js";
export * from "./providers/google-vertex.js";
export * from "./providers/mistral.js";
export * from "./providers/openai-completions.js";
export * from "./providers/openai-responses.js";
export * from "./providers/register-builtins.js";
// Core streaming API and shared types.
export * from "./stream.js";
export * from "./types.js";
export * from "./utils/event-stream.js";
export * from "./utils/json-parse.js";
// OAuth types only — the runtime OAuth helpers live in the ./oauth subpath.
export type {
	OAuthAuthInfo,
	OAuthCredentials,
	OAuthLoginCallbacks,
	OAuthPrompt,
	OAuthProvider,
	OAuthProviderId,
	OAuthProviderInfo,
	OAuthProviderInterface,
} from "./utils/oauth/types.js";
export * from "./utils/overflow.js";
export * from "./utils/typebox-helpers.js";
export * from "./utils/validation.js";

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,77 @@
import { MODELS } from "./models.generated.js";
import type { Api, KnownProvider, Model, Usage } from "./types.js";
// provider id -> (model id -> model). Seeded once from the generated catalog.
const modelRegistry: Map<string, Map<string, Model<Api>>> = new Map();
// Initialize registry from MODELS on module load
for (const [provider, models] of Object.entries(MODELS)) {
	const providerModels = new Map<string, Model<Api>>();
	for (const [id, model] of Object.entries(models)) {
		providerModels.set(id, model as Model<Api>);
	}
	modelRegistry.set(provider, providerModels);
}
// Extracts the `api` literal type of a specific model from the generated
// MODELS catalog, so getModel can return a precisely typed Model.
type ModelApi<
	TProvider extends KnownProvider,
	TModelId extends keyof (typeof MODELS)[TProvider],
> = (typeof MODELS)[TProvider][TModelId] extends { api: infer TApi } ? (TApi extends Api ? TApi : never) : never;
/**
 * Fetch a model by provider and id from the registry. The lookup itself is
 * untyped; the generic signature restores the precise api type for known ids.
 */
export function getModel<TProvider extends KnownProvider, TModelId extends keyof (typeof MODELS)[TProvider]>(
	provider: TProvider,
	modelId: TModelId,
): Model<ModelApi<TProvider, TModelId>> {
	const entry = modelRegistry.get(provider)?.get(modelId as string);
	return entry as Model<ModelApi<TProvider, TModelId>>;
}
/** List all provider ids known to the registry. */
export function getProviders(): KnownProvider[] {
	return [...modelRegistry.keys()] as KnownProvider[];
}
/** List all models for a provider; unknown providers yield an empty array. */
export function getModels<TProvider extends KnownProvider>(
	provider: TProvider,
): Model<ModelApi<TProvider, keyof (typeof MODELS)[TProvider]>>[] {
	const models = modelRegistry.get(provider);
	if (!models) {
		return [];
	}
	return [...models.values()] as Model<ModelApi<TProvider, keyof (typeof MODELS)[TProvider]>>[];
}
/**
 * Compute dollar cost from token usage. Model prices are quoted per million
 * tokens. Mutates usage.cost in place and returns it (callers rely on the
 * mutation).
 */
export function calculateCost<TApi extends Api>(model: Model<TApi>, usage: Usage): Usage["cost"] {
	const { cost } = usage;
	cost.input = (model.cost.input / 1000000) * usage.input;
	cost.output = (model.cost.output / 1000000) * usage.output;
	cost.cacheRead = (model.cost.cacheRead / 1000000) * usage.cacheRead;
	cost.cacheWrite = (model.cost.cacheWrite / 1000000) * usage.cacheWrite;
	cost.total = cost.input + cost.output + cost.cacheRead + cost.cacheWrite;
	return cost;
}
/**
* Check if a model supports xhigh thinking level.
*
* Supported today:
* - GPT-5.2 / GPT-5.3 / GPT-5.4 model families
* - Anthropic Messages API Opus 4.6 models (xhigh maps to adaptive effort "max")
*/
export function supportsXhigh<TApi extends Api>(model: Model<TApi>): boolean {
	// GPT-5.2/5.3/5.4 families accept xhigh directly.
	const gpt5Families = ["gpt-5.2", "gpt-5.3", "gpt-5.4"];
	if (gpt5Families.some((family) => model.id.includes(family))) {
		return true;
	}
	// On the Anthropic Messages API, only Opus 4.6 supports it (maps to "max").
	return model.api === "anthropic-messages" && (model.id.includes("opus-4-6") || model.id.includes("opus-4.6"));
}
/**
* Check if two models are equal by comparing both their id and provider.
* Returns false if either model is null or undefined.
*/
export function modelsAreEqual<TApi extends Api>(
	a: Model<TApi> | null | undefined,
	b: Model<TApi> | null | undefined,
): boolean {
	// Identity is (id, provider); any nullish side compares unequal.
	return !!a && !!b && a.id === b.id && a.provider === b.provider;
}

View file

@ -0,0 +1 @@
// Subpath entry point (`@gsd/pi-ai/oauth`): re-exports the OAuth utilities.
export * from "./utils/oauth/index.js";

View file

@ -0,0 +1,751 @@
import {
BedrockRuntimeClient,
type BedrockRuntimeClientConfig,
StopReason as BedrockStopReason,
type Tool as BedrockTool,
CachePointType,
CacheTTL,
type ContentBlock,
type ContentBlockDeltaEvent,
type ContentBlockStartEvent,
type ContentBlockStopEvent,
ConversationRole,
ConverseStreamCommand,
type ConverseStreamMetadataEvent,
ImageFormat,
type Message,
type SystemContentBlock,
type ToolChoice,
type ToolConfiguration,
ToolResultStatus,
} from "@aws-sdk/client-bedrock-runtime";
import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
CacheRetention,
Context,
Model,
SimpleStreamOptions,
StopReason,
StreamFunction,
StreamOptions,
TextContent,
ThinkingBudgets,
ThinkingContent,
ThinkingLevel,
Tool,
ToolCall,
ToolResultMessage,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { parseStreamingJson } from "../utils/json-parse.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import { adjustMaxTokensForThinking, buildBaseOptions, clampReasoning } from "./simple-options.js";
import { transformMessages } from "./transform-messages.js";
export interface BedrockOptions extends StreamOptions {
	// AWS region override; otherwise resolved from env / profile / us-east-1.
	region?: string;
	// Named AWS profile from ~/.aws/credentials.
	profile?: string;
	toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
	/* See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-reasoning.html for supported models. */
	reasoning?: ThinkingLevel;
	/* Custom token budgets per thinking level. Overrides default budgets. */
	thinkingBudgets?: ThinkingBudgets;
	/* Only supported by Claude 4.x models, see https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html#claude-messages-extended-thinking-tool-use-interleaved */
	interleavedThinking?: boolean;
}
// Streaming scratch shape: content blocks temporarily carry the Bedrock
// contentBlockIndex and accumulated partial JSON; both are stripped before
// the message is finalized.
type Block = (TextContent | ThinkingContent | ToolCall) & { index?: number; partialJson?: string };
/**
 * Stream a completion via the Bedrock ConverseStream API, translating
 * Bedrock stream events into AssistantMessageEvents. Errors are never thrown
 * to the caller: they are folded into an "error" event and the stream is
 * ended.
 */
export const streamBedrock: StreamFunction<"bedrock-converse-stream", BedrockOptions> = (
	model: Model<"bedrock-converse-stream">,
	context: Context,
	options: BedrockOptions = {},
): AssistantMessageEventStream => {
	const stream = new AssistantMessageEventStream();
	(async () => {
		// Message skeleton mutated in place as stream events arrive.
		const output: AssistantMessage = {
			role: "assistant",
			content: [],
			api: "bedrock-converse-stream" as Api,
			provider: model.provider,
			model: model.id,
			usage: {
				input: 0,
				output: 0,
				cacheRead: 0,
				cacheWrite: 0,
				totalTokens: 0,
				cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
			},
			stopReason: "stop",
			timestamp: Date.now(),
		};
		// Same array as output.content, viewed with streaming scratch fields.
		const blocks = output.content as Block[];
		const config: BedrockRuntimeClientConfig = {
			profile: options.profile,
		};
		// in Node.js/Bun environment only
		if (typeof process !== "undefined" && (process.versions?.node || process.versions?.bun)) {
			// Region resolution: explicit option > env vars > SDK default chain.
			// When AWS_PROFILE is set, we leave region undefined so the SDK can
			// resolve it from aws profile configs. Otherwise fall back to us-east-1.
			const explicitRegion = options.region || process.env.AWS_REGION || process.env.AWS_DEFAULT_REGION;
			if (explicitRegion) {
				config.region = explicitRegion;
			} else if (!process.env.AWS_PROFILE) {
				config.region = "us-east-1";
			}
			// Support proxies that don't need authentication
			if (process.env.AWS_BEDROCK_SKIP_AUTH === "1") {
				config.credentials = {
					accessKeyId: "dummy-access-key",
					secretAccessKey: "dummy-secret-key",
				};
			}
			if (
				process.env.HTTP_PROXY ||
				process.env.HTTPS_PROXY ||
				process.env.NO_PROXY ||
				process.env.http_proxy ||
				process.env.https_proxy ||
				process.env.no_proxy
			) {
				const nodeHttpHandler = await import("@smithy/node-http-handler");
				const proxyAgent = await import("proxy-agent");
				const agent = new proxyAgent.ProxyAgent();
				// Bedrock runtime uses NodeHttp2Handler by default since v3.798.0, which is based
				// on `http2` module and has no support for http agent.
				// Use NodeHttpHandler to support http agent.
				config.requestHandler = new nodeHttpHandler.NodeHttpHandler({
					httpAgent: agent,
					httpsAgent: agent,
				});
			} else if (process.env.AWS_BEDROCK_FORCE_HTTP1 === "1") {
				// Some custom endpoints require HTTP/1.1 instead of HTTP/2
				const nodeHttpHandler = await import("@smithy/node-http-handler");
				config.requestHandler = new nodeHttpHandler.NodeHttpHandler();
			}
		} else {
			// Non-Node environment (browser): fall back to us-east-1 since
			// there's no config file resolution available.
			config.region = options.region || "us-east-1";
		}
		try {
			const client = new BedrockRuntimeClient(config);
			const cacheRetention = resolveCacheRetention(options.cacheRetention);
			let commandInput = {
				modelId: model.id,
				messages: convertMessages(context, model, cacheRetention),
				system: buildSystemPrompt(context.systemPrompt, model, cacheRetention),
				inferenceConfig: { maxTokens: options.maxTokens, temperature: options.temperature },
				toolConfig: convertToolConfig(context.tools, options.toolChoice),
				additionalModelRequestFields: buildAdditionalModelRequestFields(model, options),
			};
			// onPayload hook may inspect/replace the request before sending.
			const nextCommandInput = await options?.onPayload?.(commandInput, model);
			if (nextCommandInput !== undefined) {
				commandInput = nextCommandInput as typeof commandInput;
			}
			const command = new ConverseStreamCommand(commandInput);
			const response = await client.send(command, { abortSignal: options.signal });
			// Translate each Bedrock stream frame into AssistantMessageEvents.
			for await (const item of response.stream!) {
				if (item.messageStart) {
					if (item.messageStart.role !== ConversationRole.ASSISTANT) {
						throw new Error("Unexpected assistant message start but got user message start instead");
					}
					stream.push({ type: "start", partial: output });
				} else if (item.contentBlockStart) {
					handleContentBlockStart(item.contentBlockStart, blocks, output, stream);
				} else if (item.contentBlockDelta) {
					handleContentBlockDelta(item.contentBlockDelta, blocks, output, stream);
				} else if (item.contentBlockStop) {
					handleContentBlockStop(item.contentBlockStop, blocks, output, stream);
				} else if (item.messageStop) {
					output.stopReason = mapStopReason(item.messageStop.stopReason);
				} else if (item.metadata) {
					handleMetadata(item.metadata, model, output);
				} else if (item.internalServerException) {
					throw new Error(`Internal server error: ${item.internalServerException.message}`);
				} else if (item.modelStreamErrorException) {
					throw new Error(`Model stream error: ${item.modelStreamErrorException.message}`);
				} else if (item.validationException) {
					throw new Error(`Validation error: ${item.validationException.message}`);
				} else if (item.throttlingException) {
					throw new Error(`Throttling error: ${item.throttlingException.message}`);
				} else if (item.serviceUnavailableException) {
					throw new Error(`Service unavailable: ${item.serviceUnavailableException.message}`);
				}
			}
			if (options.signal?.aborted) {
				throw new Error("Request was aborted");
			}
			if (output.stopReason === "error" || output.stopReason === "aborted") {
				throw new Error("An unknown error occurred");
			}
			stream.push({ type: "done", reason: output.stopReason, message: output });
			stream.end();
		} catch (error) {
			// Strip streaming scratch fields before surfacing the partial message.
			for (const block of output.content) {
				delete (block as Block).index;
				delete (block as Block).partialJson;
			}
			output.stopReason = options.signal?.aborted ? "aborted" : "error";
			output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
			stream.push({ type: "error", reason: output.stopReason, error: output });
			stream.end();
		}
	})();
	return stream;
};
/**
 * Simple-options entry point: normalizes SimpleStreamOptions into
 * BedrockOptions. For Claude models without adaptive thinking, maxTokens is
 * expanded to leave room for the thinking budget.
 */
export const streamSimpleBedrock: StreamFunction<"bedrock-converse-stream", SimpleStreamOptions> = (
	model: Model<"bedrock-converse-stream">,
	context: Context,
	options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
	const base = buildBaseOptions(model, options, undefined);
	// No reasoning requested: pass through with thinking disabled.
	if (!options?.reasoning) {
		return streamBedrock(model, context, { ...base, reasoning: undefined } satisfies BedrockOptions);
	}
	if (model.id.includes("anthropic.claude") || model.id.includes("anthropic/claude")) {
		// Opus/Sonnet 4.6: adaptive thinking needs no token-budget adjustment.
		if (supportsAdaptiveThinking(model.id)) {
			return streamBedrock(model, context, {
				...base,
				reasoning: options.reasoning,
				thinkingBudgets: options.thinkingBudgets,
			} satisfies BedrockOptions);
		}
		// Older Claude models: grow maxTokens to accommodate the thinking
		// budget and pin the budget for the (clamped) requested level.
		const adjusted = adjustMaxTokensForThinking(
			base.maxTokens || 0,
			model.maxTokens,
			options.reasoning,
			options.thinkingBudgets,
		);
		return streamBedrock(model, context, {
			...base,
			maxTokens: adjusted.maxTokens,
			reasoning: options.reasoning,
			thinkingBudgets: {
				...(options.thinkingBudgets || {}),
				[clampReasoning(options.reasoning)!]: adjusted.thinkingBudget,
			},
		} satisfies BedrockOptions);
	}
	// Non-Claude models: forward reasoning options unchanged.
	return streamBedrock(model, context, {
		...base,
		reasoning: options.reasoning,
		thinkingBudgets: options.thinkingBudgets,
	} satisfies BedrockOptions);
};
/**
 * Handle a Bedrock contentBlockStart frame. Only tool-use blocks announce an
 * explicit start; text and thinking blocks are created lazily on their first
 * delta (see handleContentBlockDelta).
 */
function handleContentBlockStart(
	event: ContentBlockStartEvent,
	blocks: Block[],
	output: AssistantMessage,
	stream: AssistantMessageEventStream,
): void {
	const toolUse = event.start?.toolUse;
	if (!toolUse) {
		return;
	}
	const block: Block = {
		type: "toolCall",
		id: toolUse.toolUseId || "",
		name: toolUse.name || "",
		arguments: {},
		partialJson: "",
		index: event.contentBlockIndex!,
	};
	output.content.push(block);
	stream.push({ type: "toolcall_start", contentIndex: blocks.length - 1, partial: output });
}
/**
 * Handle a Bedrock contentBlockDelta frame. Text and thinking blocks are
 * created here on first delta (Bedrock sends no start frame for them);
 * tool-call deltas accumulate partial JSON that is re-parsed incrementally.
 */
function handleContentBlockDelta(
	event: ContentBlockDeltaEvent,
	blocks: Block[],
	output: AssistantMessage,
	stream: AssistantMessageEventStream,
): void {
	const contentBlockIndex = event.contentBlockIndex!;
	const delta = event.delta;
	// Locate the existing scratch block carrying this Bedrock index (if any).
	let index = blocks.findIndex((b) => b.index === contentBlockIndex);
	let block = blocks[index];
	if (delta?.text !== undefined) {
		// If no text block exists yet, create one, as `handleContentBlockStart` is not sent for text blocks
		if (!block) {
			const newBlock: Block = { type: "text", text: "", index: contentBlockIndex };
			output.content.push(newBlock);
			index = blocks.length - 1;
			block = blocks[index];
			stream.push({ type: "text_start", contentIndex: index, partial: output });
		}
		if (block.type === "text") {
			block.text += delta.text;
			stream.push({ type: "text_delta", contentIndex: index, delta: delta.text, partial: output });
		}
	} else if (delta?.toolUse && block?.type === "toolCall") {
		// Accumulate raw JSON and keep arguments as the best-effort parse of it.
		block.partialJson = (block.partialJson || "") + (delta.toolUse.input || "");
		block.arguments = parseStreamingJson(block.partialJson);
		stream.push({ type: "toolcall_delta", contentIndex: index, delta: delta.toolUse.input || "", partial: output });
	} else if (delta?.reasoningContent) {
		let thinkingBlock = block;
		let thinkingIndex = index;
		// Thinking blocks are also created lazily on first delta.
		if (!thinkingBlock) {
			const newBlock: Block = { type: "thinking", thinking: "", thinkingSignature: "", index: contentBlockIndex };
			output.content.push(newBlock);
			thinkingIndex = blocks.length - 1;
			thinkingBlock = blocks[thinkingIndex];
			stream.push({ type: "thinking_start", contentIndex: thinkingIndex, partial: output });
		}
		if (thinkingBlock?.type === "thinking") {
			if (delta.reasoningContent.text) {
				thinkingBlock.thinking += delta.reasoningContent.text;
				stream.push({
					type: "thinking_delta",
					contentIndex: thinkingIndex,
					delta: delta.reasoningContent.text,
					partial: output,
				});
			}
			// Signature chunks arrive separately and are concatenated silently.
			if (delta.reasoningContent.signature) {
				thinkingBlock.thinkingSignature =
					(thinkingBlock.thinkingSignature || "") + delta.reasoningContent.signature;
			}
		}
	}
}
/**
 * Handle the terminal metadata frame: copy token usage onto the message and
 * derive dollar cost from the model's pricing.
 */
function handleMetadata(
	event: ConverseStreamMetadataEvent,
	model: Model<"bedrock-converse-stream">,
	output: AssistantMessage,
): void {
	const usage = event.usage;
	if (!usage) {
		return;
	}
	output.usage.input = usage.inputTokens || 0;
	output.usage.output = usage.outputTokens || 0;
	output.usage.cacheRead = usage.cacheReadInputTokens || 0;
	output.usage.cacheWrite = usage.cacheWriteInputTokens || 0;
	// Fall back to input+output when Bedrock omits (or zeroes) totalTokens.
	output.usage.totalTokens = usage.totalTokens || output.usage.input + output.usage.output;
	calculateCost(model, output.usage);
}
/**
 * Handle a Bedrock contentBlockStop frame: finalize the matching block,
 * strip its streaming scratch fields, and emit the corresponding *_end event.
 */
function handleContentBlockStop(
	event: ContentBlockStopEvent,
	blocks: Block[],
	output: AssistantMessage,
	stream: AssistantMessageEventStream,
): void {
	const index = blocks.findIndex((b) => b.index === event.contentBlockIndex);
	const block = blocks[index];
	if (!block) {
		return;
	}
	delete (block as Block).index;
	if (block.type === "text") {
		stream.push({ type: "text_end", contentIndex: index, content: block.text, partial: output });
	} else if (block.type === "thinking") {
		stream.push({ type: "thinking_end", contentIndex: index, content: block.thinking, partial: output });
	} else if (block.type === "toolCall") {
		// Final parse of the accumulated JSON before discarding the raw buffer.
		block.arguments = parseStreamingJson(block.partialJson);
		delete (block as Block).partialJson;
		stream.push({ type: "toolcall_end", contentIndex: index, toolCall: block, partial: output });
	}
}
/**
* Check if the model supports adaptive thinking (Opus 4.6 and Sonnet 4.6).
*/
function supportsAdaptiveThinking(modelId: string): boolean {
	// Both dashed and dotted version spellings appear in Bedrock model ids.
	const markers = ["opus-4-6", "opus-4.6", "sonnet-4-6", "sonnet-4.6"];
	return markers.some((marker) => modelId.includes(marker));
}
/**
 * Map a ThinkingLevel to Bedrock's adaptive effort setting. Only Opus 4.6
 * accepts "max"; xhigh degrades to "high" on every other model.
 */
function mapThinkingLevelToEffort(
	level: SimpleStreamOptions["reasoning"],
	modelId: string,
): "low" | "medium" | "high" | "max" {
	const isOpus46 = modelId.includes("opus-4-6") || modelId.includes("opus-4.6");
	if (level === "minimal" || level === "low") return "low";
	if (level === "medium") return "medium";
	if (level === "xhigh") return isOpus46 ? "max" : "high";
	// "high" and any unknown/unset level both map to "high".
	return "high";
}
/**
* Resolve cache retention preference.
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
*/
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
	// Explicit option wins; PI_CACHE_RETENTION=long is a legacy env override.
	if (cacheRetention) {
		return cacheRetention;
	}
	const envWantsLong = typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long";
	return envWantsLong ? "long" : "short";
}
/**
* Check if the model supports prompt caching.
* Supported: Claude 3.5 Haiku, Claude 3.7 Sonnet, Claude 4.x models
*/
function supportsPromptCaching(model: Model<"bedrock-converse-stream">): boolean {
	// Non-zero cache pricing metadata implies caching support.
	if (model.cost.cacheRead || model.cost.cacheWrite) {
		return true;
	}
	const id = model.id.toLowerCase();
	// Claude 4.x models (opus-4, sonnet-4, haiku-4).
	const isClaude4 = id.includes("claude") && (id.includes("-4-") || id.includes("-4."));
	// Plus Claude 3.7 Sonnet and Claude 3.5 Haiku.
	return isClaude4 || id.includes("claude-3-7-sonnet") || id.includes("claude-3-5-haiku");
}
/**
* Check if the model supports thinking signatures in reasoningContent.
* Only Anthropic Claude models support the signature field.
* Other models (OpenAI, Qwen, Minimax, Moonshot, etc.) reject it with:
* "This model doesn't support the reasoningContent.reasoningText.signature field"
*/
function supportsThinkingSignature(model: Model<"bedrock-converse-stream">): boolean {
	// Both "anthropic.claude" and "anthropic/claude" id spellings occur.
	const id = model.id.toLowerCase();
	return ["anthropic.claude", "anthropic/claude"].some((prefix) => id.includes(prefix));
}
/**
 * Build the Bedrock system blocks: sanitized prompt text plus an optional
 * trailing cache point for Claude models that support prompt caching.
 */
function buildSystemPrompt(
	systemPrompt: string | undefined,
	model: Model<"bedrock-converse-stream">,
	cacheRetention: CacheRetention,
): SystemContentBlock[] | undefined {
	if (!systemPrompt) {
		return undefined;
	}
	const blocks: SystemContentBlock[] = [{ text: sanitizeSurrogates(systemPrompt) }];
	if (cacheRetention !== "none" && supportsPromptCaching(model)) {
		// "long" retention upgrades the cache point TTL to one hour.
		const ttl = cacheRetention === "long" ? { ttl: CacheTTL.ONE_HOUR } : {};
		blocks.push({ cachePoint: { type: CachePointType.DEFAULT, ...ttl } });
	}
	return blocks;
}
// Normalize tool call IDs: replace disallowed characters and cap at 64 chars.
// slice(0, 64) is a no-op for shorter strings, so the explicit length check
// was redundant; this also matches the Anthropic provider's implementation.
function normalizeToolCallId(id: string): string {
  return id.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 64);
}
/**
 * Convert the provider-agnostic message history into Bedrock Converse messages.
 *
 * Notable behaviors:
 * - Assistant messages whose content is empty (or becomes empty after
 *   filtering empty text/thinking blocks) are dropped, since Bedrock rejects
 *   empty content arrays.
 * - Consecutive toolResult messages are merged into a single user message,
 *   because Bedrock requires all tool results to be in one message.
 * - When caching is enabled and the model supports it, a cachePoint block is
 *   appended to the last user message to cache conversation history.
 */
function convertMessages(
  context: Context,
  model: Model<"bedrock-converse-stream">,
  cacheRetention: CacheRetention,
): Message[] {
  const result: Message[] = [];
  // Normalize messages for cross-provider compatibility; tool call ids are
  // sanitized via normalizeToolCallId.
  const transformedMessages = transformMessages(context.messages, model, normalizeToolCallId);
  for (let i = 0; i < transformedMessages.length; i++) {
    const m = transformedMessages[i];
    switch (m.role) {
      case "user":
        result.push({
          role: ConversationRole.USER,
          content:
            typeof m.content === "string"
              ? [{ text: sanitizeSurrogates(m.content) }]
              : m.content.map((c) => {
                  switch (c.type) {
                    case "text":
                      return { text: sanitizeSurrogates(c.text) };
                    case "image":
                      return { image: createImageBlock(c.mimeType, c.data) };
                    default:
                      throw new Error("Unknown user content type");
                  }
                }),
        });
        break;
      case "assistant": {
        // Skip assistant messages with empty content (e.g., from aborted requests)
        // Bedrock rejects messages with empty content arrays
        if (m.content.length === 0) {
          continue; // continues the outer for-loop, skipping this message
        }
        const contentBlocks: ContentBlock[] = [];
        for (const c of m.content) {
          switch (c.type) {
            case "text":
              // Skip empty text blocks
              if (c.text.trim().length === 0) continue;
              contentBlocks.push({ text: sanitizeSurrogates(c.text) });
              break;
            case "toolCall":
              contentBlocks.push({
                toolUse: { toolUseId: c.id, name: c.name, input: c.arguments },
              });
              break;
            case "thinking":
              // Skip empty thinking blocks
              if (c.thinking.trim().length === 0) continue;
              // Only Anthropic models support the signature field in reasoningText.
              // For other models, we omit the signature to avoid errors like:
              // "This model doesn't support the reasoningContent.reasoningText.signature field"
              if (supportsThinkingSignature(model)) {
                contentBlocks.push({
                  reasoningContent: {
                    reasoningText: { text: sanitizeSurrogates(c.thinking), signature: c.thinkingSignature },
                  },
                });
              } else {
                contentBlocks.push({
                  reasoningContent: {
                    reasoningText: { text: sanitizeSurrogates(c.thinking) },
                  },
                });
              }
              break;
            default:
              throw new Error("Unknown assistant content type");
          }
        }
        // Skip if all content blocks were filtered out
        if (contentBlocks.length === 0) {
          continue;
        }
        result.push({
          role: ConversationRole.ASSISTANT,
          content: contentBlocks,
        });
        break;
      }
      case "toolResult": {
        // Collect all consecutive toolResult messages into a single user message
        // Bedrock requires all tool results to be in one message
        const toolResults: ContentBlock.ToolResultMember[] = [];
        // Add current tool result with all content blocks combined
        toolResults.push({
          toolResult: {
            toolUseId: m.toolCallId,
            content: m.content.map((c) =>
              c.type === "image"
                ? { image: createImageBlock(c.mimeType, c.data) }
                : { text: sanitizeSurrogates(c.text) },
            ),
            status: m.isError ? ToolResultStatus.ERROR : ToolResultStatus.SUCCESS,
          },
        });
        // Look ahead for consecutive toolResult messages
        let j = i + 1;
        while (j < transformedMessages.length && transformedMessages[j].role === "toolResult") {
          const nextMsg = transformedMessages[j] as ToolResultMessage;
          toolResults.push({
            toolResult: {
              toolUseId: nextMsg.toolCallId,
              content: nextMsg.content.map((c) =>
                c.type === "image"
                  ? { image: createImageBlock(c.mimeType, c.data) }
                  : { text: sanitizeSurrogates(c.text) },
              ),
              status: nextMsg.isError ? ToolResultStatus.ERROR : ToolResultStatus.SUCCESS,
            },
          });
          j++;
        }
        // Skip the messages we've already processed
        i = j - 1;
        result.push({
          role: ConversationRole.USER,
          content: toolResults,
        });
        break;
      }
      default:
        throw new Error("Unknown message role");
    }
  }
  // Add cache point to the last user message for supported Claude models when caching is enabled
  if (cacheRetention !== "none" && supportsPromptCaching(model) && result.length > 0) {
    const lastMessage = result[result.length - 1];
    if (lastMessage.role === ConversationRole.USER && lastMessage.content) {
      (lastMessage.content as ContentBlock[]).push({
        cachePoint: {
          type: CachePointType.DEFAULT,
          ...(cacheRetention === "long" ? { ttl: CacheTTL.ONE_HOUR } : {}),
        },
      });
    }
  }
  return result;
}
/**
 * Build the Bedrock ToolConfiguration from the generic tool list.
 * Returns undefined when there are no tools or tool use is disabled ("none").
 */
function convertToolConfig(
  tools: Tool[] | undefined,
  toolChoice: BedrockOptions["toolChoice"],
): ToolConfiguration | undefined {
  if (!tools?.length || toolChoice === "none") return undefined;
  const toolSpecs: BedrockTool[] = tools.map((tool) => ({
    toolSpec: {
      name: tool.name,
      description: tool.description,
      inputSchema: { json: tool.parameters },
    },
  }));
  let choice: ToolChoice | undefined;
  if (toolChoice === "auto") {
    choice = { auto: {} };
  } else if (toolChoice === "any") {
    choice = { any: {} };
  } else if (toolChoice?.type === "tool") {
    // Force a specific tool by name.
    choice = { tool: { name: toolChoice.name } };
  }
  return { tools: toolSpecs, toolChoice: choice };
}
/**
 * Normalize Bedrock's stop reasons into the provider-agnostic StopReason set.
 * Anything unrecognized (including undefined) is treated as an error.
 */
function mapStopReason(reason: string | undefined): StopReason {
  if (reason === BedrockStopReason.END_TURN || reason === BedrockStopReason.STOP_SEQUENCE) {
    return "stop";
  }
  if (reason === BedrockStopReason.MAX_TOKENS || reason === BedrockStopReason.MODEL_CONTEXT_WINDOW_EXCEEDED) {
    return "length";
  }
  if (reason === BedrockStopReason.TOOL_USE) {
    return "toolUse";
  }
  return "error";
}
/**
 * Build extra request fields for Claude extended thinking on Bedrock.
 * Returns undefined unless reasoning was requested, the model supports it,
 * and the model is an Anthropic Claude variant.
 */
function buildAdditionalModelRequestFields(
  model: Model<"bedrock-converse-stream">,
  options: BedrockOptions,
): Record<string, any> | undefined {
  if (!options.reasoning || !model.reasoning) {
    return undefined;
  }
  const isClaude = model.id.includes("anthropic.claude") || model.id.includes("anthropic/claude");
  if (!isClaude) {
    return undefined;
  }
  if (supportsAdaptiveThinking(model.id)) {
    // Opus/Sonnet 4.6: adaptive thinking with an effort hint; interleaved
    // thinking is built in, so no beta flag is added on this path.
    return {
      thinking: { type: "adaptive" },
      output_config: { effort: mapThinkingLevelToEffort(options.reasoning, model.id) },
    };
  }
  // Older Claude models: budget-based thinking.
  const defaultBudgets: Record<ThinkingLevel, number> = {
    minimal: 1024,
    low: 2048,
    medium: 8192,
    high: 16384,
    xhigh: 16384, // Claude doesn't support xhigh, clamp to high
  };
  // Custom budgets override defaults (xhigh not in ThinkingBudgets, use high)
  const level = options.reasoning === "xhigh" ? "high" : options.reasoning;
  const result: Record<string, any> = {
    thinking: {
      type: "enabled",
      budget_tokens: options.thinkingBudgets?.[level] ?? defaultBudgets[options.reasoning],
    },
  };
  if (options.interleavedThinking ?? true) {
    result.anthropic_beta = ["interleaved-thinking-2025-05-14"];
  }
  return result;
}
/**
 * Build a Bedrock image block from a MIME type and base64-encoded payload.
 * Throws for unsupported MIME types.
 */
function createImageBlock(mimeType: string, data: string) {
  const formatByMime: Record<string, ImageFormat> = {
    "image/jpeg": ImageFormat.JPEG,
    "image/jpg": ImageFormat.JPEG,
    "image/png": ImageFormat.PNG,
    "image/gif": ImageFormat.GIF,
    "image/webp": ImageFormat.WEBP,
  };
  const format = formatByMime[mimeType];
  if (format === undefined) {
    throw new Error(`Unknown image type: ${mimeType}`);
  }
  // Decode base64 into raw bytes; atob yields chars in 0-255, so mapping by
  // char code is an exact byte-for-byte conversion.
  const binaryString = atob(data);
  const bytes = Uint8Array.from(binaryString, (ch) => ch.charCodeAt(0));
  return { source: { bytes }, format };
}

View file

@ -0,0 +1,883 @@
import Anthropic from "@anthropic-ai/sdk";
import type {
ContentBlockParam,
MessageCreateParamsStreaming,
MessageParam,
} from "@anthropic-ai/sdk/resources/messages.js";
import { getEnvApiKey } from "../env-api-keys.js";
import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
CacheRetention,
Context,
ImageContent,
Message,
Model,
SimpleStreamOptions,
StopReason,
StreamFunction,
StreamOptions,
TextContent,
ThinkingContent,
Tool,
ToolCall,
ToolResultMessage,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { parseStreamingJson } from "../utils/json-parse.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import { buildCopilotDynamicHeaders, hasCopilotVisionInput } from "./github-copilot-headers.js";
import { adjustMaxTokensForThinking, buildBaseOptions } from "./simple-options.js";
import { transformMessages } from "./transform-messages.js";
/**
 * Resolve cache retention: explicit option first, then the legacy
 * PI_CACHE_RETENTION=long environment toggle, then the "short" default.
 */
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
  if (cacheRetention) {
    return cacheRetention;
  }
  const envIsLong = typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long";
  return envIsLong ? "long" : "short";
}
/**
 * Derive the cache_control payload for the resolved retention level.
 * The 1h TTL is only attached when talking to the first-party Anthropic API.
 */
function getCacheControl(
  baseUrl: string,
  cacheRetention?: CacheRetention,
): { retention: CacheRetention; cacheControl?: { type: "ephemeral"; ttl?: "1h" } } {
  const retention = resolveCacheRetention(cacheRetention);
  // Caching disabled entirely: no cacheControl at all.
  if (retention === "none") {
    return { retention };
  }
  if (retention === "long" && baseUrl.includes("api.anthropic.com")) {
    return { retention, cacheControl: { type: "ephemeral", ttl: "1h" } };
  }
  return { retention, cacheControl: { type: "ephemeral" } };
}
// Stealth mode: Mimic Claude Code's tool naming exactly
const claudeCodeVersion = "2.1.62";
// Claude Code 2.x tool names (canonical casing)
// Source: https://cchistory.mariozechner.at/data/prompts-2.1.11.md
// To update: https://github.com/badlogic/cchistory
const claudeCodeTools = [
  "Read",
  "Write",
  "Edit",
  "Bash",
  "Grep",
  "Glob",
  "AskUserQuestion",
  "EnterPlanMode",
  "ExitPlanMode",
  "KillShell",
  "NotebookEdit",
  "Skill",
  "Task",
  "TaskOutput",
  "TodoWrite",
  "WebFetch",
  "WebSearch",
];
// Index from lowercased tool name to Claude Code's canonical casing.
const canonicalCcName = new Map<string, string>();
for (const toolName of claudeCodeTools) {
  canonicalCcName.set(toolName.toLowerCase(), toolName);
}
// Convert tool name to CC canonical casing if it matches (case-insensitive)
const toClaudeCodeName = (name: string): string => canonicalCcName.get(name.toLowerCase()) ?? name;
// Map a CC-cased name back to the caller's original tool name, matching case-insensitively.
const fromClaudeCodeName = (name: string, tools?: Tool[]): string => {
  const lowerName = name.toLowerCase();
  const matchedTool = tools?.find((tool) => tool.name.toLowerCase() === lowerName);
  return matchedTool?.name ?? name;
};
/**
 * Convert tool-result content blocks to the Anthropic API shape.
 * Text-only content collapses to one newline-joined string; mixed content
 * becomes a block array, with a placeholder text block prepended when only
 * images are present.
 */
function convertContentBlocks(content: (TextContent | ImageContent)[]):
  | string
  | Array<
      | { type: "text"; text: string }
      | {
          type: "image";
          source: {
            type: "base64";
            media_type: "image/jpeg" | "image/png" | "image/gif" | "image/webp";
            data: string;
          };
        }
    > {
  const hasImages = content.some((block) => block.type === "image");
  // Text-only: a plain string keeps the payload simple.
  if (!hasImages) {
    return sanitizeSurrogates(content.map((block) => (block as TextContent).text).join("\n"));
  }
  const blocks: Array<
    | { type: "text"; text: string }
    | {
        type: "image";
        source: {
          type: "base64";
          media_type: "image/jpeg" | "image/png" | "image/gif" | "image/webp";
          data: string;
        };
      }
  > = [];
  for (const block of content) {
    if (block.type === "text") {
      blocks.push({ type: "text", text: sanitizeSurrogates(block.text) });
    } else {
      blocks.push({
        type: "image",
        source: {
          type: "base64",
          media_type: block.mimeType as "image/jpeg" | "image/png" | "image/gif" | "image/webp",
          data: block.data,
        },
      });
    }
  }
  // If only images (no text), add placeholder text block.
  if (!blocks.some((block) => block.type === "text")) {
    blocks.unshift({ type: "text", text: "(see attached image)" });
  }
  return blocks;
}
/** Effort levels accepted by Anthropic's adaptive thinking ("max" is Opus 4.6 only). */
export type AnthropicEffort = "low" | "medium" | "high" | "max";
/** Anthropic-specific stream options layered on top of the shared StreamOptions. */
export interface AnthropicOptions extends StreamOptions {
  /**
   * Enable extended thinking.
   * For Opus 4.6 and Sonnet 4.6: uses adaptive thinking (model decides when/how much to think).
   * For older models: uses budget-based thinking with thinkingBudgetTokens.
   */
  thinkingEnabled?: boolean;
  /**
   * Token budget for extended thinking (older models only).
   * Ignored for Opus 4.6 and Sonnet 4.6, which use adaptive thinking.
   */
  thinkingBudgetTokens?: number;
  /**
   * Effort level for adaptive thinking (Opus 4.6 and Sonnet 4.6).
   * Controls how much thinking Claude allocates:
   * - "max": Always thinks with no constraints (Opus 4.6 only)
   * - "high": Always thinks, deep reasoning (default)
   * - "medium": Moderate thinking, may skip for simple queries
   * - "low": Minimal thinking, skips for simple tasks
   * Ignored for older models.
   */
  effort?: AnthropicEffort;
  /** Request interleaved thinking on pre-4.6 models (beta header); treated as true when omitted. */
  interleavedThinking?: boolean;
  /** Tool-choice forwarding: "auto" | "any" | "none", or force one tool by name. */
  toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
}
/** Merge header records left-to-right; later sources overwrite earlier keys. */
function mergeHeaders(...headerSources: (Record<string, string> | undefined)[]): Record<string, string> {
  return headerSources.reduce<Record<string, string>>(
    (merged, headers) => (headers ? Object.assign(merged, headers) : merged),
    {},
  );
}
/**
 * Stream a completion from the Anthropic Messages API.
 *
 * Emits start/delta/end events per content block (text, thinking, tool call)
 * on an AssistantMessageEventStream while accumulating the final
 * AssistantMessage (content, usage, stop reason). Failures and aborts are
 * reported as an "error" event on the stream rather than a rejected promise.
 */
export const streamAnthropic: StreamFunction<"anthropic-messages", AnthropicOptions> = (
  model: Model<"anthropic-messages">,
  context: Context,
  options?: AnthropicOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();
  (async () => {
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: model.api as Api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };
    try {
      const apiKey = options?.apiKey ?? getEnvApiKey(model.provider) ?? "";
      let copilotDynamicHeaders: Record<string, string> | undefined;
      if (model.provider === "github-copilot") {
        const hasImages = hasCopilotVisionInput(context.messages);
        copilotDynamicHeaders = buildCopilotDynamicHeaders({
          messages: context.messages,
          hasImages,
        });
      }
      const { client, isOAuthToken } = createClient(
        model,
        apiKey,
        options?.interleavedThinking ?? true,
        options?.headers,
        copilotDynamicHeaders,
      );
      let params = buildParams(model, context, isOAuthToken, options);
      // onPayload may rewrite the request before dispatch.
      const nextParams = await options?.onPayload?.(params, model);
      if (nextParams !== undefined) {
        params = nextParams as MessageCreateParamsStreaming;
      }
      const anthropicStream = client.messages.stream({ ...params, stream: true }, { signal: options?.signal });
      stream.push({ type: "start", partial: output });
      // In-flight blocks carry the API-side `index` so deltas can be routed
      // back to the right block; the index is stripped on content_block_stop.
      type Block = (ThinkingContent | TextContent | (ToolCall & { partialJson: string })) & { index: number };
      const blocks = output.content as Block[];
      for await (const event of anthropicStream) {
        if (event.type === "message_start") {
          // Capture initial token usage from message_start event
          // This ensures we have input token counts even if the stream is aborted early
          output.usage.input = event.message.usage.input_tokens || 0;
          output.usage.output = event.message.usage.output_tokens || 0;
          output.usage.cacheRead = event.message.usage.cache_read_input_tokens || 0;
          output.usage.cacheWrite = event.message.usage.cache_creation_input_tokens || 0;
          // Anthropic doesn't provide total_tokens, compute from components
          output.usage.totalTokens =
            output.usage.input + output.usage.output + output.usage.cacheRead + output.usage.cacheWrite;
          calculateCost(model, output.usage);
        } else if (event.type === "content_block_start") {
          if (event.content_block.type === "text") {
            const block: Block = {
              type: "text",
              text: "",
              index: event.index,
            };
            output.content.push(block);
            stream.push({ type: "text_start", contentIndex: output.content.length - 1, partial: output });
          } else if (event.content_block.type === "thinking") {
            const block: Block = {
              type: "thinking",
              thinking: "",
              thinkingSignature: "",
              index: event.index,
            };
            output.content.push(block);
            stream.push({ type: "thinking_start", contentIndex: output.content.length - 1, partial: output });
          } else if (event.content_block.type === "redacted_thinking") {
            // Redacted thinking arrives as an opaque payload kept in thinkingSignature.
            const block: Block = {
              type: "thinking",
              thinking: "[Reasoning redacted]",
              thinkingSignature: event.content_block.data,
              redacted: true,
              index: event.index,
            };
            output.content.push(block);
            stream.push({ type: "thinking_start", contentIndex: output.content.length - 1, partial: output });
          } else if (event.content_block.type === "tool_use") {
            const block: Block = {
              type: "toolCall",
              id: event.content_block.id,
              name: isOAuthToken
                ? fromClaudeCodeName(event.content_block.name, context.tools)
                : event.content_block.name,
              arguments: (event.content_block.input as Record<string, any>) ?? {},
              partialJson: "",
              index: event.index,
            };
            output.content.push(block);
            stream.push({ type: "toolcall_start", contentIndex: output.content.length - 1, partial: output });
          }
        } else if (event.type === "content_block_delta") {
          if (event.delta.type === "text_delta") {
            const index = blocks.findIndex((b) => b.index === event.index);
            const block = blocks[index];
            if (block && block.type === "text") {
              block.text += event.delta.text;
              stream.push({
                type: "text_delta",
                contentIndex: index,
                delta: event.delta.text,
                partial: output,
              });
            }
          } else if (event.delta.type === "thinking_delta") {
            const index = blocks.findIndex((b) => b.index === event.index);
            const block = blocks[index];
            if (block && block.type === "thinking") {
              block.thinking += event.delta.thinking;
              stream.push({
                type: "thinking_delta",
                contentIndex: index,
                delta: event.delta.thinking,
                partial: output,
              });
            }
          } else if (event.delta.type === "input_json_delta") {
            const index = blocks.findIndex((b) => b.index === event.index);
            const block = blocks[index];
            if (block && block.type === "toolCall") {
              // Tool arguments stream as JSON fragments; reparse on each delta.
              block.partialJson += event.delta.partial_json;
              block.arguments = parseStreamingJson(block.partialJson);
              stream.push({
                type: "toolcall_delta",
                contentIndex: index,
                delta: event.delta.partial_json,
                partial: output,
              });
            }
          } else if (event.delta.type === "signature_delta") {
            const index = blocks.findIndex((b) => b.index === event.index);
            const block = blocks[index];
            if (block && block.type === "thinking") {
              block.thinkingSignature = block.thinkingSignature || "";
              block.thinkingSignature += event.delta.signature;
            }
          }
        } else if (event.type === "content_block_stop") {
          const index = blocks.findIndex((b) => b.index === event.index);
          const block = blocks[index];
          if (block) {
            delete (block as any).index;
            if (block.type === "text") {
              stream.push({
                type: "text_end",
                contentIndex: index,
                content: block.text,
                partial: output,
              });
            } else if (block.type === "thinking") {
              stream.push({
                type: "thinking_end",
                contentIndex: index,
                content: block.thinking,
                partial: output,
              });
            } else if (block.type === "toolCall") {
              block.arguments = parseStreamingJson(block.partialJson);
              delete (block as any).partialJson;
              stream.push({
                type: "toolcall_end",
                contentIndex: index,
                toolCall: block,
                partial: output,
              });
            }
          }
        } else if (event.type === "message_delta") {
          if (event.delta.stop_reason) {
            output.stopReason = mapStopReason(event.delta.stop_reason);
          }
          // Only update usage fields if present (not null).
          // Preserves input_tokens from message_start when proxies omit it in message_delta.
          if (event.usage.input_tokens != null) {
            output.usage.input = event.usage.input_tokens;
          }
          if (event.usage.output_tokens != null) {
            output.usage.output = event.usage.output_tokens;
          }
          if (event.usage.cache_read_input_tokens != null) {
            output.usage.cacheRead = event.usage.cache_read_input_tokens;
          }
          if (event.usage.cache_creation_input_tokens != null) {
            output.usage.cacheWrite = event.usage.cache_creation_input_tokens;
          }
          // Anthropic doesn't provide total_tokens, compute from components
          output.usage.totalTokens =
            output.usage.input + output.usage.output + output.usage.cacheRead + output.usage.cacheWrite;
          calculateCost(model, output.usage);
        }
      }
      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }
      if (output.stopReason === "aborted" || output.stopReason === "error") {
        throw new Error("An unknown error occurred");
      }
      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      // Strip stream-internal bookkeeping before surfacing partial output.
      for (const block of output.content) delete (block as any).index;
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();
  return stream;
};
/**
 * Check if a model supports adaptive thinking (Opus 4.6 and Sonnet 4.6).
 * Matches model ids with or without a date suffix, dotted or dashed.
 */
function supportsAdaptiveThinking(modelId: string): boolean {
  const markers = ["opus-4-6", "opus-4.6", "sonnet-4-6", "sonnet-4.6"];
  return markers.some((marker) => modelId.includes(marker));
}
/**
 * Map ThinkingLevel to Anthropic effort levels for adaptive thinking.
 * Effort "max" is only valid on Opus 4.6, so "xhigh" degrades to "high"
 * on every other model.
 */
function mapThinkingLevelToEffort(level: SimpleStreamOptions["reasoning"], modelId: string): AnthropicEffort {
  if (level === "xhigh") {
    const isOpus46 = modelId.includes("opus-4-6") || modelId.includes("opus-4.6");
    return isOpus46 ? "max" : "high";
  }
  if (level === "minimal" || level === "low") return "low";
  if (level === "medium") return "medium";
  // "high" and any unexpected value default to deep reasoning.
  return "high";
}
/**
 * Simplified entry point: resolves the API key and translates the generic
 * reasoning level into the right Anthropic thinking configuration.
 */
export const streamSimpleAnthropic: StreamFunction<"anthropic-messages", SimpleStreamOptions> = (
  model: Model<"anthropic-messages">,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
  const apiKey = options?.apiKey || getEnvApiKey(model.provider);
  if (!apiKey) {
    throw new Error(`No API key for provider: ${model.provider}`);
  }
  const base = buildBaseOptions(model, options, apiKey);
  // No reasoning requested: plain streaming with thinking disabled.
  if (!options?.reasoning) {
    return streamAnthropic(model, context, { ...base, thinkingEnabled: false } satisfies AnthropicOptions);
  }
  // Opus 4.6 / Sonnet 4.6 use adaptive thinking driven by an effort level.
  if (supportsAdaptiveThinking(model.id)) {
    return streamAnthropic(model, context, {
      ...base,
      thinkingEnabled: true,
      effort: mapThinkingLevelToEffort(options.reasoning, model.id),
    } satisfies AnthropicOptions);
  }
  // Older models: carve an explicit thinking budget out of maxTokens.
  const adjusted = adjustMaxTokensForThinking(
    base.maxTokens || 0,
    model.maxTokens,
    options.reasoning,
    options.thinkingBudgets,
  );
  return streamAnthropic(model, context, {
    ...base,
    maxTokens: adjusted.maxTokens,
    thinkingEnabled: true,
    thinkingBudgetTokens: adjusted.thinkingBudget,
  } satisfies AnthropicOptions);
};
// Anthropic OAuth access tokens embed the "sk-ant-oat" marker.
function isOAuthToken(apiKey: string): boolean {
  return apiKey.indexOf("sk-ant-oat") !== -1;
}
/**
 * Build an Anthropic SDK client configured for one of three auth modes:
 * GitHub Copilot (Bearer token, restricted betas), Anthropic OAuth
 * (Bearer token plus Claude Code identity headers), or plain API key auth.
 *
 * Header precedence (later wins via mergeHeaders): computed defaults,
 * model.headers, dynamic Copilot headers (Copilot branch only), then
 * caller-supplied headers.
 */
function createClient(
  model: Model<"anthropic-messages">,
  apiKey: string,
  interleavedThinking: boolean,
  optionsHeaders?: Record<string, string>,
  dynamicHeaders?: Record<string, string>,
): { client: Anthropic; isOAuthToken: boolean } {
  // Adaptive thinking models (Opus 4.6, Sonnet 4.6) have interleaved thinking built-in.
  // The beta header is deprecated on Opus 4.6 and redundant on Sonnet 4.6, so skip it.
  const needsInterleavedBeta = interleavedThinking && !supportsAdaptiveThinking(model.id);
  // Copilot: Bearer auth, selective betas (no fine-grained-tool-streaming)
  if (model.provider === "github-copilot") {
    const betaFeatures: string[] = [];
    if (needsInterleavedBeta) {
      betaFeatures.push("interleaved-thinking-2025-05-14");
    }
    const client = new Anthropic({
      apiKey: null,
      authToken: apiKey,
      baseURL: model.baseUrl,
      dangerouslyAllowBrowser: true,
      defaultHeaders: mergeHeaders(
        {
          accept: "application/json",
          "anthropic-dangerous-direct-browser-access": "true",
          ...(betaFeatures.length > 0 ? { "anthropic-beta": betaFeatures.join(",") } : {}),
        },
        model.headers,
        dynamicHeaders,
        optionsHeaders,
      ),
    });
    return { client, isOAuthToken: false };
  }
  const betaFeatures = ["fine-grained-tool-streaming-2025-05-14"];
  if (needsInterleavedBeta) {
    betaFeatures.push("interleaved-thinking-2025-05-14");
  }
  // OAuth: Bearer auth, Claude Code identity headers
  if (isOAuthToken(apiKey)) {
    const client = new Anthropic({
      apiKey: null,
      authToken: apiKey,
      baseURL: model.baseUrl,
      dangerouslyAllowBrowser: true,
      defaultHeaders: mergeHeaders(
        {
          accept: "application/json",
          "anthropic-dangerous-direct-browser-access": "true",
          "anthropic-beta": `claude-code-20250219,oauth-2025-04-20,${betaFeatures.join(",")}`,
          "user-agent": `claude-cli/${claudeCodeVersion}`,
          "x-app": "cli",
        },
        model.headers,
        optionsHeaders,
      ),
    });
    return { client, isOAuthToken: true };
  }
  // API key auth
  const client = new Anthropic({
    apiKey,
    baseURL: model.baseUrl,
    dangerouslyAllowBrowser: true,
    defaultHeaders: mergeHeaders(
      {
        accept: "application/json",
        "anthropic-dangerous-direct-browser-access": "true",
        "anthropic-beta": betaFeatures.join(","),
      },
      model.headers,
      optionsHeaders,
    ),
  });
  return { client, isOAuthToken: false };
}
/**
 * Assemble the Messages API request parameters.
 *
 * - OAuth tokens require the Claude Code identity as the first system block.
 * - cache_control markers are attached to system blocks when caching is on.
 * - Temperature is omitted when extended thinking is enabled (incompatible).
 * - Thinking is configured as adaptive (Opus/Sonnet 4.6) or budget-based.
 */
function buildParams(
  model: Model<"anthropic-messages">,
  context: Context,
  isOAuthToken: boolean,
  options?: AnthropicOptions,
): MessageCreateParamsStreaming {
  const { cacheControl } = getCacheControl(model.baseUrl, options?.cacheRetention);
  const params: MessageCreateParamsStreaming = {
    model: model.id,
    messages: convertMessages(context.messages, model, isOAuthToken, cacheControl),
    // "| 0" truncates the fallback (a third of the model max) to an integer.
    max_tokens: options?.maxTokens || (model.maxTokens / 3) | 0,
    stream: true,
  };
  // For OAuth tokens, we MUST include Claude Code identity
  if (isOAuthToken) {
    params.system = [
      {
        type: "text",
        text: "You are Claude Code, Anthropic's official CLI for Claude.",
        ...(cacheControl ? { cache_control: cacheControl } : {}),
      },
    ];
    if (context.systemPrompt) {
      params.system.push({
        type: "text",
        text: sanitizeSurrogates(context.systemPrompt),
        ...(cacheControl ? { cache_control: cacheControl } : {}),
      });
    }
  } else if (context.systemPrompt) {
    // Add cache control to system prompt for non-OAuth tokens
    params.system = [
      {
        type: "text",
        text: sanitizeSurrogates(context.systemPrompt),
        ...(cacheControl ? { cache_control: cacheControl } : {}),
      },
    ];
  }
  // Temperature is incompatible with extended thinking (adaptive or budget-based).
  if (options?.temperature !== undefined && !options?.thinkingEnabled) {
    params.temperature = options.temperature;
  }
  if (context.tools) {
    params.tools = convertTools(context.tools, isOAuthToken);
  }
  // Configure thinking mode: adaptive (Opus 4.6 and Sonnet 4.6) or budget-based (older models)
  if (options?.thinkingEnabled && model.reasoning) {
    if (supportsAdaptiveThinking(model.id)) {
      // Adaptive thinking: Claude decides when and how much to think
      params.thinking = { type: "adaptive" };
      if (options.effort) {
        params.output_config = { effort: options.effort };
      }
    } else {
      // Budget-based thinking for older models
      params.thinking = {
        type: "enabled",
        budget_tokens: options.thinkingBudgetTokens || 1024,
      };
    }
  }
  // Only a string user_id is forwarded from metadata.
  if (options?.metadata) {
    const userId = options.metadata.user_id;
    if (typeof userId === "string") {
      params.metadata = { user_id: userId };
    }
  }
  if (options?.toolChoice) {
    if (typeof options.toolChoice === "string") {
      params.tool_choice = { type: options.toolChoice };
    } else {
      params.tool_choice = options.toolChoice;
    }
  }
  return params;
}
// Normalize tool call IDs to match Anthropic's required pattern and length:
// replace disallowed characters with "_" and cap at 64 characters.
function normalizeToolCallId(id: string): string {
  const sanitized = id.replace(/[^a-zA-Z0-9_-]/g, "_");
  return sanitized.slice(0, 64);
}
/**
 * Convert the provider-agnostic message history into Anthropic MessageParam
 * objects.
 *
 * Notable behaviors:
 * - Empty text/thinking blocks and messages that end up empty are dropped.
 * - Image blocks are filtered out for models without image input support.
 * - Thinking blocks without a signature (e.g. from aborted streams) are
 *   downgraded to plain text so the API does not reject them.
 * - Consecutive toolResult messages are merged into one user message.
 * - When caching is enabled, cache_control is attached to the last user
 *   message to cache the conversation history.
 */
function convertMessages(
  messages: Message[],
  model: Model<"anthropic-messages">,
  isOAuthToken: boolean,
  cacheControl?: { type: "ephemeral"; ttl?: "1h" },
): MessageParam[] {
  const params: MessageParam[] = [];
  // Transform messages for cross-provider compatibility
  const transformedMessages = transformMessages(messages, model, normalizeToolCallId);
  for (let i = 0; i < transformedMessages.length; i++) {
    const msg = transformedMessages[i];
    if (msg.role === "user") {
      if (typeof msg.content === "string") {
        if (msg.content.trim().length > 0) {
          params.push({
            role: "user",
            content: sanitizeSurrogates(msg.content),
          });
        }
      } else {
        const blocks: ContentBlockParam[] = msg.content.map((item) => {
          if (item.type === "text") {
            return {
              type: "text",
              text: sanitizeSurrogates(item.text),
            };
          } else {
            return {
              type: "image",
              source: {
                type: "base64",
                media_type: item.mimeType as "image/jpeg" | "image/png" | "image/gif" | "image/webp",
                data: item.data,
              },
            };
          }
        });
        // Drop images for models without image input, then drop empty text blocks.
        let filteredBlocks = !model?.input.includes("image") ? blocks.filter((b) => b.type !== "image") : blocks;
        filteredBlocks = filteredBlocks.filter((b) => {
          if (b.type === "text") {
            return b.text.trim().length > 0;
          }
          return true;
        });
        if (filteredBlocks.length === 0) continue;
        params.push({
          role: "user",
          content: filteredBlocks,
        });
      }
    } else if (msg.role === "assistant") {
      const blocks: ContentBlockParam[] = [];
      for (const block of msg.content) {
        if (block.type === "text") {
          if (block.text.trim().length === 0) continue;
          blocks.push({
            type: "text",
            text: sanitizeSurrogates(block.text),
          });
        } else if (block.type === "thinking") {
          // Redacted thinking: pass the opaque payload back as redacted_thinking
          if (block.redacted) {
            blocks.push({
              type: "redacted_thinking",
              data: block.thinkingSignature!,
            });
            continue;
          }
          if (block.thinking.trim().length === 0) continue;
          // If thinking signature is missing/empty (e.g., from aborted stream),
          // convert to plain text block without <thinking> tags to avoid API rejection
          // and prevent Claude from mimicking the tags in responses
          if (!block.thinkingSignature || block.thinkingSignature.trim().length === 0) {
            blocks.push({
              type: "text",
              text: sanitizeSurrogates(block.thinking),
            });
          } else {
            blocks.push({
              type: "thinking",
              thinking: sanitizeSurrogates(block.thinking),
              signature: block.thinkingSignature,
            });
          }
        } else if (block.type === "toolCall") {
          blocks.push({
            type: "tool_use",
            id: block.id,
            name: isOAuthToken ? toClaudeCodeName(block.name) : block.name,
            input: block.arguments ?? {},
          });
        }
      }
      // Skip assistant messages whose content was entirely filtered out.
      if (blocks.length === 0) continue;
      params.push({
        role: "assistant",
        content: blocks,
      });
    } else if (msg.role === "toolResult") {
      // Collect all consecutive toolResult messages, needed for z.ai Anthropic endpoint
      const toolResults: ContentBlockParam[] = [];
      // Add the current tool result
      toolResults.push({
        type: "tool_result",
        tool_use_id: msg.toolCallId,
        content: convertContentBlocks(msg.content),
        is_error: msg.isError,
      });
      // Look ahead for consecutive toolResult messages
      let j = i + 1;
      while (j < transformedMessages.length && transformedMessages[j].role === "toolResult") {
        const nextMsg = transformedMessages[j] as ToolResultMessage; // We know it's a toolResult
        toolResults.push({
          type: "tool_result",
          tool_use_id: nextMsg.toolCallId,
          content: convertContentBlocks(nextMsg.content),
          is_error: nextMsg.isError,
        });
        j++;
      }
      // Skip the messages we've already processed
      i = j - 1;
      // Add a single user message with all tool results
      params.push({
        role: "user",
        content: toolResults,
      });
    }
  }
  // Add cache_control to the last user message to cache conversation history
  if (cacheControl && params.length > 0) {
    const lastMessage = params[params.length - 1];
    if (lastMessage.role === "user") {
      if (Array.isArray(lastMessage.content)) {
        const lastBlock = lastMessage.content[lastMessage.content.length - 1];
        if (
          lastBlock &&
          (lastBlock.type === "text" || lastBlock.type === "image" || lastBlock.type === "tool_result")
        ) {
          (lastBlock as any).cache_control = cacheControl;
        }
      } else if (typeof lastMessage.content === "string") {
        // String content must become a block array to carry cache_control.
        lastMessage.content = [
          {
            type: "text",
            text: lastMessage.content,
            cache_control: cacheControl,
          },
        ] as any;
      }
    }
  }
  return params;
}
/**
 * Convert Pi tool definitions to Anthropic's tool schema.
 * Under OAuth (Claude Code identity), tool names are rewritten to Claude
 * Code's canonical casing.
 */
function convertTools(tools: Tool[], isOAuthToken: boolean): Anthropic.Messages.Tool[] {
  if (!tools) return [];
  return tools.map((tool) => {
    // TypeBox already generates JSON Schema, so pass it through.
    const jsonSchema = tool.parameters as any;
    return {
      name: isOAuthToken ? toClaudeCodeName(tool.name) : tool.name,
      description: tool.description,
      input_schema: {
        type: "object" as const,
        properties: jsonSchema.properties || {},
        required: jsonSchema.required || [],
      },
    };
  });
}
/**
 * Normalize Anthropic stop reasons into the provider-agnostic StopReason set.
 * Unknown values throw so brand-new API stop reasons are surfaced loudly
 * instead of being silently mislabeled.
 */
function mapStopReason(reason: Anthropic.Messages.StopReason | string): StopReason {
  const mapping: Record<string, StopReason> = {
    end_turn: "stop",
    pause_turn: "stop", // Stop is good enough -> resubmit
    stop_sequence: "stop", // We don't supply stop sequences, so this should never happen
    max_tokens: "length",
    tool_use: "toolUse",
    refusal: "error",
    sensitive: "error", // Content flagged by safety filters (not yet in SDK types)
  };
  const mapped = mapping[reason];
  if (mapped === undefined) {
    throw new Error(`Unhandled stop reason: ${reason}`);
  }
  return mapped;
}

View file

@ -0,0 +1,259 @@
import { AzureOpenAI } from "openai";
import type { ResponseCreateParamsStreaming } from "openai/resources/responses/responses.js";
import { getEnvApiKey } from "../env-api-keys.js";
import { supportsXhigh } from "../models.js";
import type {
Api,
AssistantMessage,
Context,
Model,
SimpleStreamOptions,
StreamFunction,
StreamOptions,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { convertResponsesMessages, convertResponsesTools, processResponsesStream } from "./openai-responses-shared.js";
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
// Fallback API version when neither options.azureApiVersion nor
// AZURE_OPENAI_API_VERSION is set (see resolveAzureConfig).
const DEFAULT_AZURE_API_VERSION = "v1";
// Providers passed to convertResponsesMessages to select Responses-style
// tool-call message conversion for this endpoint.
const AZURE_TOOL_CALL_PROVIDERS = new Set(["openai", "openai-codex", "opencode", "azure-openai-responses"]);
/**
 * Parse the AZURE_OPENAI_DEPLOYMENT_NAME_MAP environment value into a map of
 * model id -> Azure deployment name.
 *
 * Format: comma-separated "modelId=deploymentName" pairs, e.g.
 * "gpt-4o=my-gpt4o,gpt-4o-mini=mini-deploy". Whitespace around entries and
 * around either side of "=" is ignored; malformed entries are skipped.
 *
 * Only the FIRST "=" separates key from value, so deployment names that
 * themselves contain "=" are preserved. (String.prototype.split with a limit
 * of 2 would silently discard everything after the second "=".)
 */
function parseDeploymentNameMap(value: string | undefined): Map<string, string> {
  const map = new Map<string, string>();
  if (!value) return map;
  for (const entry of value.split(",")) {
    const trimmed = entry.trim();
    if (!trimmed) continue;
    const separator = trimmed.indexOf("=");
    if (separator === -1) continue;
    const modelId = trimmed.slice(0, separator).trim();
    const deploymentName = trimmed.slice(separator + 1).trim();
    if (!modelId || !deploymentName) continue;
    map.set(modelId, deploymentName);
  }
  return map;
}
/**
 * Determine which Azure deployment name to use for a model: an explicit
 * option wins, then the AZURE_OPENAI_DEPLOYMENT_NAME_MAP env mapping, and
 * finally the model id itself.
 */
function resolveDeploymentName(model: Model<"azure-openai-responses">, options?: AzureOpenAIResponsesOptions): string {
  const explicit = options?.azureDeploymentName;
  if (explicit) {
    return explicit;
  }
  const mapped = parseDeploymentNameMap(process.env.AZURE_OPENAI_DEPLOYMENT_NAME_MAP).get(model.id);
  return mapped || model.id;
}
// Azure OpenAI Responses-specific options
export interface AzureOpenAIResponsesOptions extends StreamOptions {
  /** Reasoning effort forwarded to the Responses API. */
  reasoningEffort?: "minimal" | "low" | "medium" | "high" | "xhigh";
  /** Reasoning summary verbosity; defaults to "auto" when reasoning is requested. */
  reasoningSummary?: "auto" | "detailed" | "concise" | null;
  /** API version; falls back to AZURE_OPENAI_API_VERSION, then "v1". */
  azureApiVersion?: string;
  /** Azure resource name used to derive the default base URL (https://<name>.openai.azure.com/openai/v1). */
  azureResourceName?: string;
  /** Full base URL override; takes precedence over azureResourceName. */
  azureBaseUrl?: string;
  /** Explicit deployment name; overrides AZURE_OPENAI_DEPLOYMENT_NAME_MAP and the model id. */
  azureDeploymentName?: string;
}
/**
 * Generate function for Azure OpenAI Responses API
 *
 * Returns immediately with an event stream; the request itself runs in a
 * detached async block that pushes start/streaming events followed by a
 * "done" or "error" event, then ends the stream.
 */
export const streamAzureOpenAIResponses: StreamFunction<"azure-openai-responses", AzureOpenAIResponsesOptions> = (
  model: Model<"azure-openai-responses">,
  context: Context,
  options?: AzureOpenAIResponsesOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();
  // Start async processing
  (async () => {
    const deploymentName = resolveDeploymentName(model, options);
    // Accumulator for the assistant message; mutated in place by
    // processResponsesStream as chunks arrive.
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: "azure-openai-responses" as Api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };
    try {
      // Create Azure OpenAI client
      const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
      const client = createClient(model, apiKey, options);
      let params = buildParams(model, context, options, deploymentName);
      // Give the caller a chance to inspect/rewrite the outgoing payload.
      const nextParams = await options?.onPayload?.(params, model);
      if (nextParams !== undefined) {
        params = nextParams as ResponseCreateParamsStreaming;
      }
      const openaiStream = await client.responses.create(
        params,
        options?.signal ? { signal: options.signal } : undefined,
      );
      stream.push({ type: "start", partial: output });
      await processResponsesStream(openaiStream, output, stream, model);
      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }
      if (output.stopReason === "aborted" || output.stopReason === "error") {
        throw new Error("An unknown error occurred");
      }
      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      // Strip internal streaming indices before surfacing the partial message.
      for (const block of output.content) delete (block as { index?: number }).index;
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();
  return stream;
};
export const streamSimpleAzureOpenAIResponses: StreamFunction<"azure-openai-responses", SimpleStreamOptions> = (
model: Model<"azure-openai-responses">,
context: Context,
options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
if (!apiKey) {
throw new Error(`No API key for provider: ${model.provider}`);
}
const base = buildBaseOptions(model, options, apiKey);
const reasoningEffort = supportsXhigh(model) ? options?.reasoning : clampReasoning(options?.reasoning);
return streamAzureOpenAIResponses(model, context, {
...base,
reasoningEffort,
} satisfies AzureOpenAIResponsesOptions);
};
/** Strip all trailing slashes so path segments can be appended safely. */
function normalizeAzureBaseUrl(baseUrl: string): string {
  let end = baseUrl.length;
  while (end > 0 && baseUrl[end - 1] === "/") {
    end--;
  }
  return baseUrl.slice(0, end);
}
/** Default Azure OpenAI v1 endpoint derived from a resource name. */
function buildDefaultBaseUrl(resourceName: string): string {
  return ["https://", resourceName, ".openai.azure.com/openai/v1"].join("");
}
/**
 * Resolve the Azure endpoint configuration.
 *
 * Base URL precedence: explicit option > AZURE_OPENAI_BASE_URL env > URL
 * derived from the resource name (option or AZURE_OPENAI_RESOURCE_NAME) >
 * model.baseUrl. Throws when no source yields a URL.
 * API version precedence: option > AZURE_OPENAI_API_VERSION env > "v1".
 */
function resolveAzureConfig(
  model: Model<"azure-openai-responses">,
  options?: AzureOpenAIResponsesOptions,
): { baseUrl: string; apiVersion: string } {
  const apiVersion = options?.azureApiVersion || process.env.AZURE_OPENAI_API_VERSION || DEFAULT_AZURE_API_VERSION;
  const resourceName = options?.azureResourceName || process.env.AZURE_OPENAI_RESOURCE_NAME;
  let resolvedBaseUrl = options?.azureBaseUrl?.trim() || process.env.AZURE_OPENAI_BASE_URL?.trim() || undefined;
  if (!resolvedBaseUrl && resourceName) {
    resolvedBaseUrl = buildDefaultBaseUrl(resourceName);
  }
  if (!resolvedBaseUrl && model.baseUrl) {
    resolvedBaseUrl = model.baseUrl;
  }
  if (!resolvedBaseUrl) {
    throw new Error(
      "Azure OpenAI base URL is required. Set AZURE_OPENAI_BASE_URL or AZURE_OPENAI_RESOURCE_NAME, or pass azureBaseUrl, azureResourceName, or model.baseUrl.",
    );
  }
  return { baseUrl: normalizeAzureBaseUrl(resolvedBaseUrl), apiVersion };
}
/**
 * Construct an AzureOpenAI client for the model. Falls back to the
 * AZURE_OPENAI_API_KEY environment variable when no key is supplied;
 * per-request headers from options override model-level headers.
 */
function createClient(model: Model<"azure-openai-responses">, apiKey: string, options?: AzureOpenAIResponsesOptions) {
  if (!apiKey) {
    const envKey = process.env.AZURE_OPENAI_API_KEY;
    if (!envKey) {
      throw new Error(
        "Azure OpenAI API key is required. Set AZURE_OPENAI_API_KEY environment variable or pass it as an argument.",
      );
    }
    apiKey = envKey;
  }
  const { baseUrl, apiVersion } = resolveAzureConfig(model, options);
  return new AzureOpenAI({
    apiKey,
    apiVersion,
    dangerouslyAllowBrowser: true,
    defaultHeaders: { ...model.headers, ...options?.headers },
    baseURL: baseUrl,
  });
}
/**
 * Assemble the streaming request payload for the Responses API.
 *
 * Reasoning handling: when the model supports reasoning and an effort or
 * summary is requested, encrypted reasoning content is included in the
 * response. When reasoning is NOT requested on a gpt-5 family model, a
 * developer message is appended to suppress its default reasoning.
 */
function buildParams(
  model: Model<"azure-openai-responses">,
  context: Context,
  options: AzureOpenAIResponsesOptions | undefined,
  deploymentName: string,
) {
  const input = convertResponsesMessages(model, context, AZURE_TOOL_CALL_PROVIDERS);
  const params: ResponseCreateParamsStreaming = {
    model: deploymentName,
    input,
    stream: true,
    prompt_cache_key: options?.sessionId,
  };
  const maxTokens = options?.maxTokens;
  if (maxTokens) {
    params.max_output_tokens = maxTokens;
  }
  const temperature = options?.temperature;
  if (temperature !== undefined) {
    params.temperature = temperature;
  }
  if (context.tools) {
    params.tools = convertResponsesTools(context.tools);
  }
  if (model.reasoning) {
    const effort = options?.reasoningEffort;
    const summary = options?.reasoningSummary;
    if (effort || summary) {
      params.reasoning = {
        effort: effort || "medium",
        summary: summary || "auto",
      };
      params.include = ["reasoning.encrypted_content"];
    } else if (model.name.toLowerCase().startsWith("gpt-5")) {
      // Jesus Christ, see https://community.openai.com/t/need-reasoning-false-option-for-gpt-5/1351588/7
      input.push({
        role: "developer",
        content: [
          {
            type: "input_text",
            text: "# Juice: 0 !important",
          },
        ],
      });
    }
  }
  return params;
}

View file

@ -0,0 +1,37 @@
import type { Message } from "../types.js";
// Copilot expects X-Initiator to indicate whether the request is user-initiated
// or agent-initiated (e.g. follow-up after assistant/tool messages). An empty
// conversation counts as user-initiated.
export function inferCopilotInitiator(messages: Message[]): "user" | "agent" {
  const last = messages.length > 0 ? messages[messages.length - 1] : undefined;
  return last !== undefined && last.role !== "user" ? "agent" : "user";
}
// Copilot requires Copilot-Vision-Request header when sending images.
// Only user and toolResult messages with array content can carry image blocks.
export function hasCopilotVisionInput(messages: Message[]): boolean {
  const containsImage = (content: unknown): boolean =>
    Array.isArray(content) && content.some((part) => part.type === "image");
  return messages.some((msg) => {
    if (msg.role === "user") {
      return containsImage(msg.content);
    }
    if (msg.role === "toolResult") {
      return containsImage(msg.content);
    }
    return false;
  });
}
/**
 * Build the per-request Copilot headers: X-Initiator (user vs agent), a
 * constant Openai-Intent, and Copilot-Vision-Request when images are present.
 */
export function buildCopilotDynamicHeaders(params: {
  messages: Message[];
  hasImages: boolean;
}): Record<string, string> {
  const { messages, hasImages } = params;
  const baseHeaders: Record<string, string> = {
    "X-Initiator": inferCopilotInitiator(messages),
    "Openai-Intent": "conversation-edits",
  };
  return hasImages ? { ...baseHeaders, "Copilot-Vision-Request": "true" } : baseHeaders;
}

View file

@ -0,0 +1,967 @@
/**
* Google Gemini CLI / Antigravity provider.
* Shared implementation for both google-gemini-cli and google-antigravity providers.
* Uses the Cloud Code Assist API endpoint to access Gemini and Claude models.
*/
import type { Content, ThinkingConfig } from "@google/genai";
import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
Context,
Model,
SimpleStreamOptions,
StreamFunction,
StreamOptions,
TextContent,
ThinkingBudgets,
ThinkingContent,
ThinkingLevel,
ToolCall,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import {
convertMessages,
convertTools,
isThinkingPart,
mapStopReasonString,
mapToolChoice,
retainThoughtSignature,
} from "./google-shared.js";
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
/**
 * Thinking level for Gemini 3 models.
 * Mirrors Google's ThinkingLevel enum values.
 */
export type GoogleThinkingLevel = "THINKING_LEVEL_UNSPECIFIED" | "MINIMAL" | "LOW" | "MEDIUM" | "HIGH";
export interface GoogleGeminiCliOptions extends StreamOptions {
  /** Function-calling mode forwarded to the API via mapToolChoice. */
  toolChoice?: "auto" | "none" | "any";
  /**
   * Thinking/reasoning configuration.
   * - Gemini 2.x models: use `budgetTokens` to set the thinking budget
   * - Gemini 3 models (gemini-3-pro-*, gemini-3-flash-*): use `level` instead
   *
   * When using `streamSimple`, this is handled automatically based on the model.
   */
  thinking?: {
    enabled: boolean;
    /** Thinking budget in tokens. Use for Gemini 2.x models. */
    budgetTokens?: number;
    /** Thinking level. Use for Gemini 3 models (LOW/HIGH for Pro, MINIMAL/LOW/MEDIUM/HIGH for Flash). */
    level?: GoogleThinkingLevel;
  };
  /** GCP project id override — NOTE(review): not read in this chunk; credentials normally carry projectId. */
  projectId?: string;
}
// Cloud Code Assist production endpoint (Gemini CLI).
const DEFAULT_ENDPOINT = "https://cloudcode-pa.googleapis.com";
// Antigravity sandbox endpoints, tried before prod (see fallback list below).
const ANTIGRAVITY_DAILY_ENDPOINT = "https://daily-cloudcode-pa.sandbox.googleapis.com";
const ANTIGRAVITY_AUTOPUSH_ENDPOINT = "https://autopush-cloudcode-pa.sandbox.googleapis.com";
// Endpoint cascade for Antigravity: daily sandbox -> autopush sandbox -> prod.
const ANTIGRAVITY_ENDPOINT_FALLBACKS = [
  ANTIGRAVITY_DAILY_ENDPOINT,
  ANTIGRAVITY_AUTOPUSH_ENDPOINT,
  DEFAULT_ENDPOINT,
] as const;
// Headers for Gemini CLI (prod endpoint)
const GEMINI_CLI_HEADERS = {
  "User-Agent": "google-cloud-sdk vscode_cloudshelleditor/0.1",
  "X-Goog-Api-Client": "gl-node/22.17.0",
  "Client-Metadata": JSON.stringify({
    ideType: "IDE_UNSPECIFIED",
    platform: "PLATFORM_UNSPECIFIED",
    pluginType: "GEMINI",
  }),
};
// Headers for Antigravity (sandbox endpoint) - requires specific User-Agent
const DEFAULT_ANTIGRAVITY_VERSION = "1.18.4";
// The version is overridable via PI_AI_ANTIGRAVITY_VERSION so the User-Agent
// can track upstream Antigravity releases without a code change.
function getAntigravityHeaders() {
  const version = process.env.PI_AI_ANTIGRAVITY_VERSION || DEFAULT_ANTIGRAVITY_VERSION;
  return {
    "User-Agent": `antigravity/${version} darwin/arm64`,
  };
}
// Antigravity system instruction (compact version from CLIProxyAPI).
const ANTIGRAVITY_SYSTEM_INSTRUCTION =
  "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding." +
  "You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question." +
  "**Absolute paths only**" +
  "**Proactiveness**";
// Counter for generating unique tool call IDs
let toolCallCounter = 0;
// Retry configuration
const MAX_RETRIES = 3; // additional HTTP attempts after the first
const BASE_DELAY_MS = 1000; // exponential backoff base for HTTP retries
const MAX_EMPTY_STREAM_RETRIES = 2; // re-requests when a stream yields no content
const EMPTY_STREAM_BASE_DELAY_MS = 500; // backoff base for empty-stream retries
// Anthropic beta header required for Claude reasoning via Antigravity.
const CLAUDE_THINKING_BETA_HEADER = "interleaved-thinking-2025-05-14";
/**
 * Extract retry delay from Gemini error response (in milliseconds).
 * Checks headers first (Retry-After, x-ratelimit-reset, x-ratelimit-reset-after),
 * then parses body patterns like:
 * - "Your quota will reset after 39s"
 * - "Your quota will reset after 18h31m10s"
 * - "Please retry in Xs" or "Please retry in Xms"
 * - "retryDelay": "34.074824224s" (JSON field)
 * Positive delays get a 1s safety margin and are rounded up to whole ms;
 * non-positive delays are treated as "no usable hint".
 */
export function extractRetryDelay(errorText: string, response?: Response | Headers): number | undefined {
  const normalizeDelay = (ms: number): number | undefined => (ms > 0 ? Math.ceil(ms + 1000) : undefined);
  const headers = response instanceof Headers ? response : response?.headers;
  if (headers) {
    // Retry-After: either a number of seconds or an HTTP date.
    const retryAfter = headers.get("retry-after");
    if (retryAfter) {
      const seconds = Number(retryAfter);
      if (Number.isFinite(seconds)) {
        const delay = normalizeDelay(seconds * 1000);
        if (delay !== undefined) {
          return delay;
        }
      }
      const dateMs = new Date(retryAfter).getTime();
      if (!Number.isNaN(dateMs)) {
        const delay = normalizeDelay(dateMs - Date.now());
        if (delay !== undefined) {
          return delay;
        }
      }
    }
    // x-ratelimit-reset: absolute epoch seconds.
    const rateLimitReset = headers.get("x-ratelimit-reset");
    if (rateLimitReset) {
      const resetSeconds = Number.parseInt(rateLimitReset, 10);
      if (!Number.isNaN(resetSeconds)) {
        const delay = normalizeDelay(resetSeconds * 1000 - Date.now());
        if (delay !== undefined) {
          return delay;
        }
      }
    }
    // x-ratelimit-reset-after: relative seconds.
    const rateLimitResetAfter = headers.get("x-ratelimit-reset-after");
    if (rateLimitResetAfter) {
      const afterSeconds = Number(rateLimitResetAfter);
      if (Number.isFinite(afterSeconds)) {
        const delay = normalizeDelay(afterSeconds * 1000);
        if (delay !== undefined) {
          return delay;
        }
      }
    }
  }
  // Pattern 1: "Your quota will reset after ..." ("18h31m10s", "10m15s", "39s").
  const durationMatch = /reset after (?:(\d+)h)?(?:(\d+)m)?(\d+(?:\.\d+)?)s/i.exec(errorText);
  if (durationMatch) {
    const hours = durationMatch[1] ? parseInt(durationMatch[1], 10) : 0;
    const minutes = durationMatch[2] ? parseInt(durationMatch[2], 10) : 0;
    const seconds = parseFloat(durationMatch[3]);
    if (!Number.isNaN(seconds)) {
      const delay = normalizeDelay(((hours * 60 + minutes) * 60 + seconds) * 1000);
      if (delay !== undefined) {
        return delay;
      }
    }
  }
  // Pattern 2: "Please retry in X[ms|s]"; Pattern 3: '"retryDelay": "Xs"'.
  const valueUnitPatterns = [/Please retry in ([0-9.]+)(ms|s)/i, /"retryDelay":\s*"([0-9.]+)(ms|s)"/i];
  for (const pattern of valueUnitPatterns) {
    const match = pattern.exec(errorText);
    if (!match?.[1]) continue;
    const value = parseFloat(match[1]);
    if (Number.isNaN(value) || value <= 0) continue;
    const ms = match[2].toLowerCase() === "ms" ? value : value * 1000;
    const delay = normalizeDelay(ms);
    if (delay !== undefined) {
      return delay;
    }
  }
  return undefined;
}
/**
 * Claude models proxied through Antigravity need the interleaved-thinking
 * beta header when reasoning is enabled on the model.
 */
function needsClaudeThinkingBetaHeader(model: Model<"google-gemini-cli">): boolean {
  if (model.provider !== "google-antigravity") return false;
  if (!model.id.startsWith("claude-")) return false;
  return model.reasoning;
}
/** Matches gemini-3-pro-* and gemini-3.1-pro-* ids (case-insensitive). */
function isGemini3ProModel(modelId: string): boolean {
  return /gemini-3(?:\.1)?-pro/i.test(modelId);
}
/** Matches gemini-3-flash-* and gemini-3.1-flash-* ids (case-insensitive). */
function isGemini3FlashModel(modelId: string): boolean {
  return /gemini-3(?:\.1)?-flash/i.test(modelId);
}
/** True for any Gemini 3 family model (pro or flash). */
function isGemini3Model(modelId: string): boolean {
  return isGemini3ProModel(modelId) || isGemini3FlashModel(modelId);
}
/**
 * Check if an error is retryable (rate limit, server error, network error, etc.)
 * Retryable statuses are 429 and the common 5xx gateway/server errors; other
 * statuses are retryable only when the body text matches a known transient
 * failure pattern.
 */
function isRetryableError(status: number, errorText: string): boolean {
  const retryableStatuses = [429, 500, 502, 503, 504];
  if (retryableStatuses.includes(status)) {
    return true;
  }
  return /resource.?exhausted|rate.?limit|overloaded|service.?unavailable|other.?side.?closed/i.test(errorText);
}
/**
 * Extract a clean, user-friendly error message from Google API error response.
 * JSON bodies shaped like {"error":{"message":...}} yield just the message;
 * anything else (non-JSON, or JSON without that field) is returned unchanged.
 */
function extractErrorMessage(errorText: string): string {
  let parsed: unknown;
  try {
    parsed = JSON.parse(errorText);
  } catch {
    return errorText; // Not JSON, return as-is
  }
  const message = (parsed as { error?: { message?: string } } | null)?.error?.message;
  return message ? message : errorText;
}
/**
 * Sleep for a given number of milliseconds, respecting abort signal.
 *
 * Resolves after `ms` milliseconds, or rejects with "Request was aborted" if
 * the signal is already aborted or fires while waiting. The abort listener is
 * registered with { once: true } and removed when the timer fires, so the
 * retry loops in this file can call sleep repeatedly with the same long-lived
 * signal without accumulating listeners on it.
 */
function sleep(ms: number, signal?: AbortSignal): Promise<void> {
  return new Promise((resolve, reject) => {
    if (signal?.aborted) {
      reject(new Error("Request was aborted"));
      return;
    }
    const onAbort = () => {
      clearTimeout(timeout);
      reject(new Error("Request was aborted"));
    };
    const timeout = setTimeout(() => {
      signal?.removeEventListener("abort", onAbort);
      resolve();
    }, ms);
    signal?.addEventListener("abort", onAbort, { once: true });
  });
}
// Request body for the Cloud Code Assist v1internal:streamGenerateContent endpoint.
interface CloudCodeAssistRequest {
  project: string; // GCP project id taken from the OAuth credentials
  model: string;
  request: {
    contents: Content[];
    sessionId?: string;
    // Must be an object with parts; a plain string is rejected (see buildRequest).
    systemInstruction?: { role?: string; parts: { text: string }[] };
    generationConfig?: {
      maxOutputTokens?: number;
      temperature?: number;
      thinkingConfig?: ThinkingConfig;
    };
    tools?: ReturnType<typeof convertTools>;
    toolConfig?: {
      functionCallingConfig: {
        mode: ReturnType<typeof mapToolChoice>;
      };
    };
  };
  requestType?: string; // set to "agent" for Antigravity requests
  userAgent?: string;
  requestId?: string;
}
// One SSE chunk from the streaming endpoint; the generate-content payload is
// nested under `response`.
interface CloudCodeAssistResponseChunk {
  response?: {
    candidates?: Array<{
      content?: {
        role: string;
        parts?: Array<{
          text?: string;
          thought?: boolean; // presumably marks thinking parts — see isThinkingPart
          thoughtSignature?: string;
          functionCall?: {
            name: string;
            args: Record<string, unknown>;
            id?: string; // may be absent or duplicated; caller synthesizes ids
          };
        }>;
      };
      finishReason?: string;
    }>;
    usageMetadata?: {
      promptTokenCount?: number; // includes cachedContentTokenCount
      candidatesTokenCount?: number;
      thoughtsTokenCount?: number;
      totalTokenCount?: number;
      cachedContentTokenCount?: number;
    };
    modelVersion?: string;
    responseId?: string;
  };
  traceId?: string;
}
/**
 * Stream a completion from the Cloud Code Assist API (Gemini CLI /
 * Antigravity). Returns immediately with an event stream; the request runs in
 * a detached async block that handles OAuth credential parsing, endpoint
 * fallback, HTTP retries, SSE parsing, and empty-stream re-requests before
 * pushing a final "done" or "error" event.
 */
export const streamGoogleGeminiCli: StreamFunction<"google-gemini-cli", GoogleGeminiCliOptions> = (
  model: Model<"google-gemini-cli">,
  context: Context,
  options?: GoogleGeminiCliOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();
  (async () => {
    // Accumulator for the assistant message; mutated in place while streaming.
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: "google-gemini-cli" as Api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };
    try {
      // apiKey is JSON-encoded: { token, projectId }
      const apiKeyRaw = options?.apiKey;
      if (!apiKeyRaw) {
        throw new Error("Google Cloud Code Assist requires OAuth authentication. Use /login to authenticate.");
      }
      let accessToken: string;
      let projectId: string;
      try {
        const parsed = JSON.parse(apiKeyRaw) as { token: string; projectId: string };
        accessToken = parsed.token;
        projectId = parsed.projectId;
      } catch {
        throw new Error("Invalid Google Cloud Code Assist credentials. Use /login to re-authenticate.");
      }
      if (!accessToken || !projectId) {
        throw new Error("Missing token or projectId in Google Cloud credentials. Use /login to re-authenticate.");
      }
      const isAntigravity = model.provider === "google-antigravity";
      const baseUrl = model.baseUrl?.trim();
      // Explicit baseUrl wins; otherwise Antigravity cascades through its
      // sandbox endpoints before prod, Gemini CLI uses prod only.
      const endpoints = baseUrl ? [baseUrl] : isAntigravity ? ANTIGRAVITY_ENDPOINT_FALLBACKS : [DEFAULT_ENDPOINT];
      let requestBody = buildRequest(model, context, projectId, options, isAntigravity);
      // Give the caller a chance to inspect/rewrite the outgoing payload.
      const nextRequestBody = await options?.onPayload?.(requestBody, model);
      if (nextRequestBody !== undefined) {
        requestBody = nextRequestBody as CloudCodeAssistRequest;
      }
      const headers = isAntigravity ? getAntigravityHeaders() : GEMINI_CLI_HEADERS;
      const requestHeaders = {
        Authorization: `Bearer ${accessToken}`,
        "Content-Type": "application/json",
        Accept: "text/event-stream",
        ...headers,
        ...(needsClaudeThinkingBetaHeader(model) ? { "anthropic-beta": CLAUDE_THINKING_BETA_HEADER } : {}),
        ...options?.headers,
      };
      const requestBodyJson = JSON.stringify(requestBody);
      // Fetch with retry logic for rate limits, transient errors, and endpoint fallbacks.
      // On 403/404, immediately try the next endpoint (no delay).
      // On 429/5xx, retry with backoff on the same or next endpoint.
      let response: Response | undefined;
      let lastError: Error | undefined;
      let requestUrl: string | undefined;
      let endpointIndex = 0;
      for (let attempt = 0; attempt <= MAX_RETRIES; attempt++) {
        if (options?.signal?.aborted) {
          throw new Error("Request was aborted");
        }
        try {
          const endpoint = endpoints[endpointIndex];
          requestUrl = `${endpoint}/v1internal:streamGenerateContent?alt=sse`;
          response = await fetch(requestUrl, {
            method: "POST",
            headers: requestHeaders,
            body: requestBodyJson,
            signal: options?.signal,
          });
          if (response.ok) {
            break; // Success, exit retry loop
          }
          const errorText = await response.text();
          // On 403/404, cascade to the next endpoint immediately (no delay)
          if ((response.status === 403 || response.status === 404) && endpointIndex < endpoints.length - 1) {
            endpointIndex++;
            continue;
          }
          // Check if retryable (429, 5xx, network patterns)
          if (attempt < MAX_RETRIES && isRetryableError(response.status, errorText)) {
            // Advance endpoint if possible
            if (endpointIndex < endpoints.length - 1) {
              endpointIndex++;
            }
            // Use server-provided delay or exponential backoff
            const serverDelay = extractRetryDelay(errorText, response);
            const delayMs = serverDelay ?? BASE_DELAY_MS * 2 ** attempt;
            // Check if server delay exceeds max allowed (default: 60s)
            const maxDelayMs = options?.maxRetryDelayMs ?? 60000;
            if (maxDelayMs > 0 && serverDelay && serverDelay > maxDelayMs) {
              const delaySeconds = Math.ceil(serverDelay / 1000);
              throw new Error(
                `Server requested ${delaySeconds}s retry delay (max: ${Math.ceil(maxDelayMs / 1000)}s). ${extractErrorMessage(errorText)}`,
              );
            }
            await sleep(delayMs, options?.signal);
            continue;
          }
          // Not retryable or max retries exceeded
          throw new Error(`Cloud Code Assist API error (${response.status}): ${extractErrorMessage(errorText)}`);
        } catch (error) {
          // Check for abort - fetch throws AbortError, our code throws "Request was aborted"
          if (error instanceof Error) {
            if (error.name === "AbortError" || error.message === "Request was aborted") {
              throw new Error("Request was aborted");
            }
          }
          // Extract detailed error message from fetch errors (Node includes cause)
          lastError = error instanceof Error ? error : new Error(String(error));
          if (lastError.message === "fetch failed" && lastError.cause instanceof Error) {
            lastError = new Error(`Network error: ${lastError.cause.message}`);
          }
          // Network errors are retryable
          if (attempt < MAX_RETRIES) {
            const delayMs = BASE_DELAY_MS * 2 ** attempt;
            await sleep(delayMs, options?.signal);
            continue;
          }
          throw lastError;
        }
      }
      if (!response || !response.ok) {
        throw lastError ?? new Error("Failed to get response after retries");
      }
      // "start" is pushed lazily so an empty stream can be retried without
      // the consumer ever seeing a start event for it.
      let started = false;
      const ensureStarted = () => {
        if (!started) {
          stream.push({ type: "start", partial: output });
          started = true;
        }
      };
      // Reset the accumulator before re-requesting after an empty stream.
      const resetOutput = () => {
        output.content = [];
        output.usage = {
          input: 0,
          output: 0,
          cacheRead: 0,
          cacheWrite: 0,
          totalTokens: 0,
          cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
        };
        output.stopReason = "stop";
        output.errorMessage = undefined;
        output.timestamp = Date.now();
        started = false;
      };
      // Consume one SSE response body; returns whether any content arrived.
      const streamResponse = async (activeResponse: Response): Promise<boolean> => {
        if (!activeResponse.body) {
          throw new Error("No response body");
        }
        let hasContent = false;
        let currentBlock: TextContent | ThinkingContent | null = null;
        const blocks = output.content;
        const blockIndex = () => blocks.length - 1;
        // Read SSE stream
        const reader = activeResponse.body.getReader();
        const decoder = new TextDecoder();
        let buffer = "";
        // Set up abort handler to cancel reader when signal fires
        const abortHandler = () => {
          void reader.cancel().catch(() => {});
        };
        options?.signal?.addEventListener("abort", abortHandler);
        try {
          while (true) {
            // Check abort signal before each read
            if (options?.signal?.aborted) {
              throw new Error("Request was aborted");
            }
            const { done, value } = await reader.read();
            if (done) break;
            buffer += decoder.decode(value, { stream: true });
            const lines = buffer.split("\n");
            // Keep the trailing partial line in the buffer for the next read.
            buffer = lines.pop() || "";
            for (const line of lines) {
              if (!line.startsWith("data:")) continue;
              const jsonStr = line.slice(5).trim();
              if (!jsonStr) continue;
              let chunk: CloudCodeAssistResponseChunk;
              try {
                chunk = JSON.parse(jsonStr);
              } catch {
                continue;
              }
              // Unwrap the response
              const responseData = chunk.response;
              if (!responseData) continue;
              const candidate = responseData.candidates?.[0];
              if (candidate?.content?.parts) {
                for (const part of candidate.content.parts) {
                  if (part.text !== undefined) {
                    hasContent = true;
                    const isThinking = isThinkingPart(part);
                    // Close the current block and open a new one whenever the
                    // part kind (text vs thinking) changes.
                    if (
                      !currentBlock ||
                      (isThinking && currentBlock.type !== "thinking") ||
                      (!isThinking && currentBlock.type !== "text")
                    ) {
                      if (currentBlock) {
                        if (currentBlock.type === "text") {
                          stream.push({
                            type: "text_end",
                            contentIndex: blocks.length - 1,
                            content: currentBlock.text,
                            partial: output,
                          });
                        } else {
                          stream.push({
                            type: "thinking_end",
                            contentIndex: blockIndex(),
                            content: currentBlock.thinking,
                            partial: output,
                          });
                        }
                      }
                      if (isThinking) {
                        currentBlock = { type: "thinking", thinking: "", thinkingSignature: undefined };
                        output.content.push(currentBlock);
                        ensureStarted();
                        stream.push({
                          type: "thinking_start",
                          contentIndex: blockIndex(),
                          partial: output,
                        });
                      } else {
                        currentBlock = { type: "text", text: "" };
                        output.content.push(currentBlock);
                        ensureStarted();
                        stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
                      }
                    }
                    if (currentBlock.type === "thinking") {
                      currentBlock.thinking += part.text;
                      currentBlock.thinkingSignature = retainThoughtSignature(
                        currentBlock.thinkingSignature,
                        part.thoughtSignature,
                      );
                      stream.push({
                        type: "thinking_delta",
                        contentIndex: blockIndex(),
                        delta: part.text,
                        partial: output,
                      });
                    } else {
                      currentBlock.text += part.text;
                      currentBlock.textSignature = retainThoughtSignature(
                        currentBlock.textSignature,
                        part.thoughtSignature,
                      );
                      stream.push({
                        type: "text_delta",
                        contentIndex: blockIndex(),
                        delta: part.text,
                        partial: output,
                      });
                    }
                  }
                  if (part.functionCall) {
                    hasContent = true;
                    // A tool call terminates any open text/thinking block.
                    if (currentBlock) {
                      if (currentBlock.type === "text") {
                        stream.push({
                          type: "text_end",
                          contentIndex: blockIndex(),
                          content: currentBlock.text,
                          partial: output,
                        });
                      } else {
                        stream.push({
                          type: "thinking_end",
                          contentIndex: blockIndex(),
                          content: currentBlock.thinking,
                          partial: output,
                        });
                      }
                      currentBlock = null;
                    }
                    // Synthesize an id when the API omits one or reuses one
                    // already present in this message.
                    const providedId = part.functionCall.id;
                    const needsNewId =
                      !providedId ||
                      output.content.some((b) => b.type === "toolCall" && b.id === providedId);
                    const toolCallId = needsNewId
                      ? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
                      : providedId;
                    const toolCall: ToolCall = {
                      type: "toolCall",
                      id: toolCallId,
                      name: part.functionCall.name || "",
                      arguments: (part.functionCall.args as Record<string, unknown>) ?? {},
                      ...(part.thoughtSignature && { thoughtSignature: part.thoughtSignature }),
                    };
                    output.content.push(toolCall);
                    ensureStarted();
                    // Tool calls arrive whole, so start/delta/end are emitted together.
                    stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output });
                    stream.push({
                      type: "toolcall_delta",
                      contentIndex: blockIndex(),
                      delta: JSON.stringify(toolCall.arguments),
                      partial: output,
                    });
                    stream.push({
                      type: "toolcall_end",
                      contentIndex: blockIndex(),
                      toolCall,
                      partial: output,
                    });
                  }
                }
              }
              if (candidate?.finishReason) {
                output.stopReason = mapStopReasonString(candidate.finishReason);
                // Tool calls override the API finish reason.
                if (output.content.some((b) => b.type === "toolCall")) {
                  output.stopReason = "toolUse";
                }
              }
              if (responseData.usageMetadata) {
                // promptTokenCount includes cachedContentTokenCount, so subtract to get fresh input
                const promptTokens = responseData.usageMetadata.promptTokenCount || 0;
                const cacheReadTokens = responseData.usageMetadata.cachedContentTokenCount || 0;
                output.usage = {
                  input: promptTokens - cacheReadTokens,
                  output:
                    (responseData.usageMetadata.candidatesTokenCount || 0) +
                    (responseData.usageMetadata.thoughtsTokenCount || 0),
                  cacheRead: cacheReadTokens,
                  cacheWrite: 0,
                  totalTokens: responseData.usageMetadata.totalTokenCount || 0,
                  cost: {
                    input: 0,
                    output: 0,
                    cacheRead: 0,
                    cacheWrite: 0,
                    total: 0,
                  },
                };
                calculateCost(model, output.usage);
              }
            }
          }
        } finally {
          options?.signal?.removeEventListener("abort", abortHandler);
        }
        // Close any block still open when the stream ends.
        if (currentBlock) {
          if (currentBlock.type === "text") {
            stream.push({
              type: "text_end",
              contentIndex: blockIndex(),
              content: currentBlock.text,
              partial: output,
            });
          } else {
            stream.push({
              type: "thinking_end",
              contentIndex: blockIndex(),
              content: currentBlock.thinking,
              partial: output,
            });
          }
        }
        return hasContent;
      };
      // The API occasionally returns an entirely empty stream; re-request a
      // bounded number of times with backoff before giving up.
      let receivedContent = false;
      let currentResponse = response;
      for (let emptyAttempt = 0; emptyAttempt <= MAX_EMPTY_STREAM_RETRIES; emptyAttempt++) {
        if (options?.signal?.aborted) {
          throw new Error("Request was aborted");
        }
        if (emptyAttempt > 0) {
          const backoffMs = EMPTY_STREAM_BASE_DELAY_MS * 2 ** (emptyAttempt - 1);
          await sleep(backoffMs, options?.signal);
          if (!requestUrl) {
            throw new Error("Missing request URL");
          }
          currentResponse = await fetch(requestUrl, {
            method: "POST",
            headers: requestHeaders,
            body: requestBodyJson,
            signal: options?.signal,
          });
          if (!currentResponse.ok) {
            const retryErrorText = await currentResponse.text();
            throw new Error(`Cloud Code Assist API error (${currentResponse.status}): ${retryErrorText}`);
          }
        }
        const streamed = await streamResponse(currentResponse);
        if (streamed) {
          receivedContent = true;
          break;
        }
        if (emptyAttempt < MAX_EMPTY_STREAM_RETRIES) {
          resetOutput();
        }
      }
      if (!receivedContent) {
        throw new Error("Cloud Code Assist API returned an empty response");
      }
      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }
      if (output.stopReason === "aborted" || output.stopReason === "error") {
        throw new Error("An unknown error occurred");
      }
      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      // Strip internal streaming indices before surfacing the partial message.
      for (const block of output.content) {
        if ("index" in block) {
          delete (block as { index?: number }).index;
        }
      }
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();
  return stream;
};
/**
 * Simplified entry point: validates the OAuth credential, then translates the
 * generic reasoning effort into the model-appropriate thinking configuration
 * (off, Gemini 3 thinking level, or Gemini 2.x token budget) before delegating
 * to streamGoogleGeminiCli.
 */
export const streamSimpleGoogleGeminiCli: StreamFunction<"google-gemini-cli", SimpleStreamOptions> = (
  model: Model<"google-gemini-cli">,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
  // Credentials must come from the caller; there is no env-var fallback here.
  const apiKey = options?.apiKey;
  if (!apiKey) {
    throw new Error("Google Cloud Code Assist requires OAuth authentication. Use /login to authenticate.");
  }
  const base = buildBaseOptions(model, options, apiKey);
  // No reasoning requested: explicitly disable thinking.
  if (!options?.reasoning) {
    return streamGoogleGeminiCli(model, context, {
      ...base,
      thinking: { enabled: false },
    } satisfies GoogleGeminiCliOptions);
  }
  const effort = clampReasoning(options.reasoning)!;
  // Gemini 3 models take a discrete thinking level rather than a token budget.
  if (isGemini3Model(model.id)) {
    return streamGoogleGeminiCli(model, context, {
      ...base,
      thinking: {
        enabled: true,
        level: getGeminiCliThinkingLevel(effort, model.id),
      },
    } satisfies GoogleGeminiCliOptions);
  }
  // Gemini 2.x: map effort to a token budget, grow maxTokens to leave room
  // for thinking (capped at the model limit), and shrink the budget if needed
  // so at least minOutputTokens remain for the visible answer.
  const defaultBudgets: ThinkingBudgets = {
    minimal: 1024,
    low: 2048,
    medium: 8192,
    high: 16384,
  };
  const budgets = { ...defaultBudgets, ...options.thinkingBudgets };
  const minOutputTokens = 1024;
  let thinkingBudget = budgets[effort]!;
  const maxTokens = Math.min((base.maxTokens || 0) + thinkingBudget, model.maxTokens);
  if (maxTokens <= thinkingBudget) {
    thinkingBudget = Math.max(0, maxTokens - minOutputTokens);
  }
  return streamGoogleGeminiCli(model, context, {
    ...base,
    maxTokens,
    thinking: {
      enabled: true,
      budgetTokens: thinkingBudget,
    },
  } satisfies GoogleGeminiCliOptions);
};
/**
 * Build a Cloud Code Assist request envelope from the generic Context.
 *
 * @param model - Target model; its id doubles as the wire model name.
 * @param context - System prompt, message history and tool definitions.
 * @param projectId - Cloud Code Assist project the request is routed to.
 * @param options - Generation settings (temperature, maxTokens, thinking, ...).
 * @param isAntigravity - When true, shapes the request for the Antigravity
 *   endpoint: an extra system preamble, `requestType: "agent"` and a different
 *   user agent / requestId prefix.
 */
export function buildRequest(
  model: Model<"google-gemini-cli">,
  context: Context,
  projectId: string,
  options: GoogleGeminiCliOptions = {},
  isAntigravity = false,
): CloudCodeAssistRequest {
  const contents = convertMessages(model, context);
  const generationConfig: CloudCodeAssistRequest["request"]["generationConfig"] = {};
  if (options.temperature !== undefined) {
    generationConfig.temperature = options.temperature;
  }
  if (options.maxTokens !== undefined) {
    generationConfig.maxOutputTokens = options.maxTokens;
  }
  // Thinking config — only attached when requested AND the model supports reasoning.
  if (options.thinking?.enabled && model.reasoning) {
    generationConfig.thinkingConfig = {
      includeThoughts: true,
    };
    // Gemini 3 models use thinkingLevel, older models use thinkingBudget
    if (options.thinking.level !== undefined) {
      // Cast to any since our GoogleThinkingLevel mirrors Google's ThinkingLevel enum values
      generationConfig.thinkingConfig.thinkingLevel = options.thinking.level as any;
    } else if (options.thinking.budgetTokens !== undefined) {
      generationConfig.thinkingConfig.thinkingBudget = options.thinking.budgetTokens;
    }
  }
  const request: CloudCodeAssistRequest["request"] = {
    contents,
  };
  // NOTE(review): assigned unconditionally, so the key is present (value
  // `undefined`) even without a session — presumably dropped at JSON
  // serialization; confirm before changing.
  request.sessionId = options.sessionId;
  // System instruction must be object with parts, not plain string
  if (context.systemPrompt) {
    request.systemInstruction = {
      parts: [{ text: sanitizeSurrogates(context.systemPrompt) }],
    };
  }
  // Omit generationConfig entirely when no field was set.
  if (Object.keys(generationConfig).length > 0) {
    request.generationConfig = generationConfig;
  }
  if (context.tools && context.tools.length > 0) {
    // Claude models on Cloud Code Assist need the legacy `parameters` field;
    // the API translates it into Anthropic's `input_schema`.
    const useParameters = model.id.startsWith("claude-");
    request.tools = convertTools(context.tools, useParameters);
    if (options.toolChoice) {
      request.toolConfig = {
        functionCallingConfig: {
          mode: mapToolChoice(options.toolChoice),
        },
      };
    }
  }
  if (isAntigravity) {
    // Antigravity prepends its own system instruction (plus a decoy copy wrapped
    // in [ignore] tags) ahead of whatever system prompt the caller supplied.
    const existingParts = request.systemInstruction?.parts ?? [];
    request.systemInstruction = {
      role: "user",
      parts: [
        { text: ANTIGRAVITY_SYSTEM_INSTRUCTION },
        { text: `Please ignore following [ignore]${ANTIGRAVITY_SYSTEM_INSTRUCTION}[/ignore]` },
        ...existingParts,
      ],
    };
  }
  return {
    project: projectId,
    model: model.id,
    request,
    ...(isAntigravity ? { requestType: "agent" } : {}),
    userAgent: isAntigravity ? "antigravity" : "pi-coding-agent",
    // Unique-ish id: prefix + timestamp + 9 random base36 characters.
    requestId: `${isAntigravity ? "agent" : "pi"}-${Date.now()}-${Math.random().toString(36).slice(2, 11)}`,
  };
}
type ClampedThinkingLevel = Exclude<ThinkingLevel, "xhigh">;

/**
 * Map a clamped reasoning effort onto a Gemini CLI thinking level.
 * Gemini 3 Pro variants are collapsed onto LOW/HIGH only; every other model
 * gets the straightforward 1:1 mapping.
 */
function getGeminiCliThinkingLevel(effort: ClampedThinkingLevel, modelId: string): GoogleThinkingLevel {
  const proLevels: Record<ClampedThinkingLevel, GoogleThinkingLevel> = {
    minimal: "LOW",
    low: "LOW",
    medium: "HIGH",
    high: "HIGH",
  };
  const defaultLevels: Record<ClampedThinkingLevel, GoogleThinkingLevel> = {
    minimal: "MINIMAL",
    low: "LOW",
    medium: "MEDIUM",
    high: "HIGH",
  };
  return isGemini3ProModel(modelId) ? proLevels[effort] : defaultLevels[effort];
}

View file

@ -0,0 +1,313 @@
/**
* Shared utilities for Google Generative AI and Google Cloud Code Assist providers.
*/
import { type Content, FinishReason, FunctionCallingConfigMode, type Part } from "@google/genai";
import type { Context, ImageContent, Model, StopReason, TextContent, Tool } from "../types.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import { transformMessages } from "./transform-messages.js";
type GoogleApiType = "google-generative-ai" | "google-gemini-cli" | "google-vertex";
/**
 * True when a streamed Gemini `Part` carries thinking content.
 *
 * Protocol note (Gemini / Vertex AI thought signatures):
 * - `thought: true` is the one definitive marker for thinking content.
 * - `thoughtSignature` is an encrypted token of the model's internal reasoning,
 *   carried only so it can be replayed on later turns. It may appear on ANY
 *   part type (text, functionCall, ...) — its presence says nothing about
 *   whether the part itself is a thought.
 * - For non-functionCall responses the signature rides on the last part.
 * - Signature-bearing parts must be persisted as-is; never merge or move
 *   signatures between parts.
 *
 * See: https://ai.google.dev/gemini-api/docs/thought-signatures
 */
export function isThinkingPart(part: Pick<Part, "thought" | "thoughtSignature">): boolean {
  const { thought } = part;
  return thought === true;
}
/**
 * Keep the last non-empty thought signature seen for the current streamed block.
 *
 * Some backends emit `thoughtSignature` only on the first delta of a part and
 * omit it afterwards; without this guard, later deltas would wipe the stored
 * value. Signatures are never moved between distinct response parts — this
 * only prevents overwriting with `undefined`/empty within one block.
 */
export function retainThoughtSignature(existing: string | undefined, incoming: string | undefined): string | undefined {
  const hasIncoming = typeof incoming === "string" && incoming.length > 0;
  return hasIncoming ? incoming : existing;
}
// Google APIs require thought signatures to be base64 (TYPE_BYTES on the wire).
const base64SignaturePattern = /^[A-Za-z0-9+/]+={0,2}$/;
// Sentinel that tells the Gemini API to skip thought-signature validation.
// Attached to unsigned function calls (e.g. history replayed from providers
// that never produced signatures).
// See: https://ai.google.dev/gemini-api/docs/thought-signatures
const SKIP_THOUGHT_SIGNATURE = "skip_thought_signature_validator";

/** Non-empty, length a multiple of 4, and matching the base64 alphabet. */
function isValidThoughtSignature(signature: string | undefined): boolean {
  if (!signature || signature.length % 4 !== 0) {
    return false;
  }
  return base64SignaturePattern.test(signature);
}

/**
 * A signature is only replayable to the backend that minted it: keep it only
 * when the message came from the same provider/model AND it is valid base64.
 */
function resolveThoughtSignature(isSameProviderAndModel: boolean, signature: string | undefined): string | undefined {
  if (!isSameProviderAndModel) return undefined;
  return isValidThoughtSignature(signature) ? signature : undefined;
}
/**
 * Some models served through Google APIs (Claude, gpt-oss) require explicit
 * tool-call IDs on function calls and their matching function responses.
 */
export function requiresToolCallId(modelId: string): boolean {
  const prefixes = ["claude-", "gpt-oss-"];
  return prefixes.some((prefix) => modelId.startsWith(prefix));
}
/**
 * Convert internal messages to Gemini Content[] format.
 *
 * Handles three roles:
 * - "user": text/image parts; image parts are dropped for text-only models.
 * - "assistant": text/thinking/toolCall parts. Thinking blocks and thought
 *   signatures are kept only when the history entry comes from the same
 *   provider AND model; otherwise thinking is downgraded to plain text.
 * - "toolResult": functionResponse parts, merged into a single user turn as
 *   the Cloud Code Assist API requires.
 */
export function convertMessages<T extends GoogleApiType>(model: Model<T>, context: Context): Content[] {
  const contents: Content[] = [];
  // Models that need explicit IDs also need them sanitized: [a-zA-Z0-9_-], max 64 chars.
  const normalizeToolCallId = (id: string): string => {
    if (!requiresToolCallId(model.id)) return id;
    return id.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 64);
  };
  const transformedMessages = transformMessages(context.messages, model, normalizeToolCallId);
  for (const msg of transformedMessages) {
    if (msg.role === "user") {
      if (typeof msg.content === "string") {
        contents.push({
          role: "user",
          parts: [{ text: sanitizeSurrogates(msg.content) }],
        });
      } else {
        const parts: Part[] = msg.content.map((item) => {
          if (item.type === "text") {
            return { text: sanitizeSurrogates(item.text) };
          } else {
            return {
              inlineData: {
                mimeType: item.mimeType,
                data: item.data,
              },
            };
          }
        });
        // Drop image parts for models without image input; skip the turn if nothing remains.
        const filteredParts = !model.input.includes("image") ? parts.filter((p) => p.text !== undefined) : parts;
        if (filteredParts.length === 0) continue;
        contents.push({
          role: "user",
          parts: filteredParts,
        });
      }
    } else if (msg.role === "assistant") {
      const parts: Part[] = [];
      // Check if message is from same provider and model - only then keep thinking blocks
      const isSameProviderAndModel = msg.provider === model.provider && msg.model === model.id;
      for (const block of msg.content) {
        if (block.type === "text") {
          // Skip empty text blocks - they can cause issues with some models (e.g. Claude via Antigravity)
          if (!block.text || block.text.trim() === "") continue;
          const thoughtSignature = resolveThoughtSignature(isSameProviderAndModel, block.textSignature);
          parts.push({
            text: sanitizeSurrogates(block.text),
            ...(thoughtSignature && { thoughtSignature }),
          });
        } else if (block.type === "thinking") {
          // Skip empty thinking blocks
          if (!block.thinking || block.thinking.trim() === "") continue;
          // Only keep as thinking block if same provider AND same model
          // Otherwise convert to plain text (no tags to avoid model mimicking them)
          if (isSameProviderAndModel) {
            const thoughtSignature = resolveThoughtSignature(isSameProviderAndModel, block.thinkingSignature);
            parts.push({
              thought: true,
              text: sanitizeSurrogates(block.thinking),
              ...(thoughtSignature && { thoughtSignature }),
            });
          } else {
            parts.push({
              text: sanitizeSurrogates(block.thinking),
            });
          }
        } else if (block.type === "toolCall") {
          const thoughtSignature = resolveThoughtSignature(isSameProviderAndModel, block.thoughtSignature);
          // Gemini 3 requires thoughtSignature on all function calls when thinking mode is enabled.
          // Use the skip_thought_signature_validator sentinel for unsigned function calls
          // (e.g. replayed from providers without thought signatures like Claude via Antigravity).
          const isGemini3 = model.id.toLowerCase().includes("gemini-3");
          const effectiveSignature = thoughtSignature || (isGemini3 ? SKIP_THOUGHT_SIGNATURE : undefined);
          const part: Part = {
            functionCall: {
              name: block.name,
              args: block.arguments ?? {},
              ...(requiresToolCallId(model.id) ? { id: block.id } : {}),
            },
            ...(effectiveSignature && { thoughtSignature: effectiveSignature }),
          };
          parts.push(part);
        }
      }
      // Skip assistant turns that became empty after filtering.
      if (parts.length === 0) continue;
      contents.push({
        role: "model",
        parts,
      });
    } else if (msg.role === "toolResult") {
      // Extract text and image content
      const textContent = msg.content.filter((c): c is TextContent => c.type === "text");
      const textResult = textContent.map((c) => c.text).join("\n");
      const imageContent = model.input.includes("image")
        ? msg.content.filter((c): c is ImageContent => c.type === "image")
        : [];
      const hasText = textResult.length > 0;
      const hasImages = imageContent.length > 0;
      // Gemini 3 supports multimodal function responses with images nested inside functionResponse.parts
      // See: https://ai.google.dev/gemini-api/docs/function-calling#multimodal
      // Older models don't support this, so we put images in a separate user message.
      const supportsMultimodalFunctionResponse = model.id.includes("gemini-3");
      // Use "output" key for success, "error" key for errors as per SDK documentation
      const responseValue = hasText ? sanitizeSurrogates(textResult) : hasImages ? "(see attached image)" : "";
      const imageParts: Part[] = imageContent.map((imageBlock) => ({
        inlineData: {
          mimeType: imageBlock.mimeType,
          data: imageBlock.data,
        },
      }));
      const includeId = requiresToolCallId(model.id);
      const functionResponsePart: Part = {
        functionResponse: {
          name: msg.toolName,
          response: msg.isError ? { error: responseValue } : { output: responseValue },
          // Nest images inside functionResponse.parts for Gemini 3
          ...(hasImages && supportsMultimodalFunctionResponse && { parts: imageParts }),
          ...(includeId ? { id: msg.toolCallId } : {}),
        },
      };
      // Cloud Code Assist API requires all function responses to be in a single user turn.
      // Check if the last content is already a user turn with function responses and merge.
      const lastContent = contents[contents.length - 1];
      if (lastContent?.role === "user" && lastContent.parts?.some((p) => p.functionResponse)) {
        lastContent.parts.push(functionResponsePart);
      } else {
        contents.push({
          role: "user",
          parts: [functionResponsePart],
        });
      }
      // For older models, add images in a separate user message
      if (hasImages && !supportsMultimodalFunctionResponse) {
        contents.push({
          role: "user",
          parts: [{ text: "Tool result image:" }, ...imageParts],
        });
      }
    }
  }
  return contents;
}
/**
 * Convert tools into Gemini function-declaration format.
 *
 * By default emits `parametersJsonSchema`, which supports full JSON Schema
 * (anyOf, oneOf, const, ...). With `useParameters` set, the legacy
 * `parameters` field (OpenAPI 3.03 Schema) is used instead — required for
 * Cloud Code Assist with Claude models, where the API translates `parameters`
 * into Anthropic's `input_schema`.
 *
 * @returns undefined for an empty tool list.
 */
export function convertTools(
  tools: Tool[],
  useParameters = false,
): { functionDeclarations: Record<string, unknown>[] }[] | undefined {
  if (tools.length === 0) return undefined;
  const functionDeclarations = tools.map((tool) => {
    const schemaField = useParameters
      ? { parameters: tool.parameters }
      : { parametersJsonSchema: tool.parameters };
    return { name: tool.name, description: tool.description, ...schemaField };
  });
  return [{ functionDeclarations }];
}
/**
 * Translate our tool-choice string into the SDK's FunctionCallingConfigMode.
 * Unknown values fall back to AUTO.
 */
export function mapToolChoice(choice: string): FunctionCallingConfigMode {
  if (choice === "none") return FunctionCallingConfigMode.NONE;
  if (choice === "any") return FunctionCallingConfigMode.ANY;
  // "auto" and anything unrecognized both mean AUTO.
  return FunctionCallingConfigMode.AUTO;
}
/**
 * Map Gemini FinishReason to our StopReason.
 *
 * Only STOP and MAX_TOKENS are normal terminations; every safety / recitation /
 * malformed-call variant collapses to "error". The `never`-typed default keeps
 * the switch exhaustive: a new enum member fails compilation here instead of
 * silently falling through.
 */
export function mapStopReason(reason: FinishReason): StopReason {
  switch (reason) {
    case FinishReason.STOP:
      return "stop";
    case FinishReason.MAX_TOKENS:
      return "length";
    case FinishReason.BLOCKLIST:
    case FinishReason.PROHIBITED_CONTENT:
    case FinishReason.SPII:
    case FinishReason.SAFETY:
    case FinishReason.IMAGE_SAFETY:
    case FinishReason.IMAGE_PROHIBITED_CONTENT:
    case FinishReason.IMAGE_RECITATION:
    case FinishReason.IMAGE_OTHER:
    case FinishReason.RECITATION:
    case FinishReason.FINISH_REASON_UNSPECIFIED:
    case FinishReason.OTHER:
    case FinishReason.LANGUAGE:
    case FinishReason.MALFORMED_FUNCTION_CALL:
    case FinishReason.UNEXPECTED_TOOL_CALL:
    case FinishReason.NO_IMAGE:
      return "error";
    default: {
      // Compile-time exhaustiveness check; runtime guard for out-of-enum values.
      const _exhaustive: never = reason;
      throw new Error(`Unhandled stop reason: ${_exhaustive}`);
    }
  }
}
/**
 * Map a raw string finish reason (from untyped API responses) to our StopReason.
 * Anything other than STOP / MAX_TOKENS is reported as an error.
 */
export function mapStopReasonString(reason: string): StopReason {
  const known: Record<string, StopReason> = {
    STOP: "stop",
    MAX_TOKENS: "length",
  };
  return known[reason] ?? "error";
}

View file

@ -0,0 +1,485 @@
import {
type GenerateContentConfig,
type GenerateContentParameters,
GoogleGenAI,
type ThinkingConfig,
ThinkingLevel,
} from "@google/genai";
import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
Context,
Model,
ThinkingLevel as PiThinkingLevel,
SimpleStreamOptions,
StreamFunction,
StreamOptions,
TextContent,
ThinkingBudgets,
ThinkingContent,
ToolCall,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import type { GoogleThinkingLevel } from "./google-gemini-cli.js";
import {
convertMessages,
convertTools,
isThinkingPart,
mapStopReason,
mapToolChoice,
retainThoughtSignature,
} from "./google-shared.js";
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
/** Options accepted by the Vertex AI stream functions. */
export interface GoogleVertexOptions extends StreamOptions {
  toolChoice?: "auto" | "none" | "any";
  thinking?: {
    enabled: boolean;
    budgetTokens?: number; // -1 for dynamic, 0 to disable
    level?: GoogleThinkingLevel;
  };
  // GCP project/location; when omitted they are resolved from environment
  // variables (see resolveProject / resolveLocation).
  project?: string;
  location?: string;
}
// API version pinned for the GoogleGenAI Vertex client.
const API_VERSION = "v1";
// Bridge our GoogleThinkingLevel string union onto the SDK's ThinkingLevel enum.
const THINKING_LEVEL_MAP: Record<GoogleThinkingLevel, ThinkingLevel> = {
  THINKING_LEVEL_UNSPECIFIED: ThinkingLevel.THINKING_LEVEL_UNSPECIFIED,
  MINIMAL: ThinkingLevel.MINIMAL,
  LOW: ThinkingLevel.LOW,
  MEDIUM: ThinkingLevel.MEDIUM,
  HIGH: ThinkingLevel.HIGH,
};
// Counter for generating unique tool call IDs
// (module-level so IDs stay unique across streams in this process).
let toolCallCounter = 0;
/**
 * Stream a completion from Vertex AI.
 *
 * Returns an AssistantMessageEventStream immediately; an async IIFE drives the
 * Google SDK stream and pushes start/delta/end events onto it. On any failure
 * (including abort) the partially built message is emitted as a single
 * "error" event instead of rejecting.
 */
export const streamGoogleVertex: StreamFunction<"google-vertex", GoogleVertexOptions> = (
  model: Model<"google-vertex">,
  context: Context,
  options?: GoogleVertexOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();
  (async () => {
    // Accumulated assistant message; also attached as `partial` to every event.
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: "google-vertex" as Api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };
    try {
      const project = resolveProject(options);
      const location = resolveLocation(options);
      const client = createClient(model, project, location, options?.headers);
      let params = buildParams(model, context, options);
      // Give the caller a chance to inspect or replace the outgoing payload.
      const nextParams = await options?.onPayload?.(params, model);
      if (nextParams !== undefined) {
        params = nextParams as GenerateContentParameters;
      }
      const googleStream = await client.models.generateContentStream(params);
      stream.push({ type: "start", partial: output });
      let currentBlock: TextContent | ThinkingContent | null = null;
      const blocks = output.content;
      const blockIndex = () => blocks.length - 1;
      for await (const chunk of googleStream) {
        const candidate = chunk.candidates?.[0];
        if (candidate?.content?.parts) {
          for (const part of candidate.content.parts) {
            if (part.text !== undefined) {
              const isThinking = isThinkingPart(part);
              // Close the open block when the part kind flips between text and thinking.
              if (
                !currentBlock ||
                (isThinking && currentBlock.type !== "thinking") ||
                (!isThinking && currentBlock.type !== "text")
              ) {
                if (currentBlock) {
                  if (currentBlock.type === "text") {
                    stream.push({
                      type: "text_end",
                      contentIndex: blocks.length - 1,
                      content: currentBlock.text,
                      partial: output,
                    });
                  } else {
                    stream.push({
                      type: "thinking_end",
                      contentIndex: blockIndex(),
                      content: currentBlock.thinking,
                      partial: output,
                    });
                  }
                }
                if (isThinking) {
                  currentBlock = { type: "thinking", thinking: "", thinkingSignature: undefined };
                  output.content.push(currentBlock);
                  stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output });
                } else {
                  currentBlock = { type: "text", text: "" };
                  output.content.push(currentBlock);
                  stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
                }
              }
              if (currentBlock.type === "thinking") {
                currentBlock.thinking += part.text;
                // Signatures may only arrive on the first delta; keep the last non-empty one.
                currentBlock.thinkingSignature = retainThoughtSignature(
                  currentBlock.thinkingSignature,
                  part.thoughtSignature,
                );
                stream.push({
                  type: "thinking_delta",
                  contentIndex: blockIndex(),
                  delta: part.text,
                  partial: output,
                });
              } else {
                currentBlock.text += part.text;
                currentBlock.textSignature = retainThoughtSignature(
                  currentBlock.textSignature,
                  part.thoughtSignature,
                );
                stream.push({
                  type: "text_delta",
                  contentIndex: blockIndex(),
                  delta: part.text,
                  partial: output,
                });
              }
            }
            if (part.functionCall) {
              // A function call terminates any open text/thinking block.
              if (currentBlock) {
                if (currentBlock.type === "text") {
                  stream.push({
                    type: "text_end",
                    contentIndex: blockIndex(),
                    content: currentBlock.text,
                    partial: output,
                  });
                } else {
                  stream.push({
                    type: "thinking_end",
                    contentIndex: blockIndex(),
                    content: currentBlock.thinking,
                    partial: output,
                  });
                }
                currentBlock = null;
              }
              // Generate a unique ID when missing or already used in this message.
              const providedId = part.functionCall.id;
              const needsNewId =
                !providedId || output.content.some((b) => b.type === "toolCall" && b.id === providedId);
              const toolCallId = needsNewId
                ? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
                : providedId;
              const toolCall: ToolCall = {
                type: "toolCall",
                id: toolCallId,
                name: part.functionCall.name || "",
                arguments: (part.functionCall.args as Record<string, any>) ?? {},
                ...(part.thoughtSignature && { thoughtSignature: part.thoughtSignature }),
              };
              output.content.push(toolCall);
              stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output });
              stream.push({
                type: "toolcall_delta",
                contentIndex: blockIndex(),
                delta: JSON.stringify(toolCall.arguments),
                partial: output,
              });
              stream.push({ type: "toolcall_end", contentIndex: blockIndex(), toolCall, partial: output });
            }
          }
        }
        if (candidate?.finishReason) {
          output.stopReason = mapStopReason(candidate.finishReason);
          // Any emitted tool call overrides the reported finish reason.
          if (output.content.some((b) => b.type === "toolCall")) {
            output.stopReason = "toolUse";
          }
        }
        if (chunk.usageMetadata) {
          // Overwrite usage with the chunk's latest snapshot and recompute cost.
          output.usage = {
            input: chunk.usageMetadata.promptTokenCount || 0,
            output:
              (chunk.usageMetadata.candidatesTokenCount || 0) + (chunk.usageMetadata.thoughtsTokenCount || 0),
            cacheRead: chunk.usageMetadata.cachedContentTokenCount || 0,
            cacheWrite: 0,
            totalTokens: chunk.usageMetadata.totalTokenCount || 0,
            cost: {
              input: 0,
              output: 0,
              cacheRead: 0,
              cacheWrite: 0,
              total: 0,
            },
          };
          calculateCost(model, output.usage);
        }
      }
      // Flush a block left open when the stream ended.
      if (currentBlock) {
        if (currentBlock.type === "text") {
          stream.push({
            type: "text_end",
            contentIndex: blockIndex(),
            content: currentBlock.text,
            partial: output,
          });
        } else {
          stream.push({
            type: "thinking_end",
            contentIndex: blockIndex(),
            content: currentBlock.thinking,
            partial: output,
          });
        }
      }
      // Promote aborts / silent error stop-reasons to exceptions so the catch
      // below emits exactly one "error" event.
      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }
      if (output.stopReason === "aborted" || output.stopReason === "error") {
        throw new Error("An unknown error occurred");
      }
      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      // Remove internal index property used during streaming
      for (const block of output.content) {
        if ("index" in block) {
          delete (block as { index?: number }).index;
        }
      }
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();
  return stream;
};
/**
 * Simple-options entry point for Vertex AI: translates a generic reasoning
 * effort into the provider-specific thinking configuration.
 */
export const streamSimpleGoogleVertex: StreamFunction<"google-vertex", SimpleStreamOptions> = (
  model: Model<"google-vertex">,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
  const base = buildBaseOptions(model, options, undefined);

  // No reasoning requested: stream with thinking explicitly disabled.
  if (!options?.reasoning) {
    return streamGoogleVertex(model, context, {
      ...base,
      thinking: { enabled: false },
    } satisfies GoogleVertexOptions);
  }

  const effort = clampReasoning(options.reasoning)!;
  // The helpers below are shared with the Gemini API path and typed against it.
  const geminiModel = model as unknown as Model<"google-generative-ai">;

  // Gemini 3 takes a discrete thinking level; everything else takes a token budget.
  if (isGemini3ProModel(geminiModel) || isGemini3FlashModel(geminiModel)) {
    return streamGoogleVertex(model, context, {
      ...base,
      thinking: {
        enabled: true,
        level: getGemini3ThinkingLevel(effort, geminiModel),
      },
    } satisfies GoogleVertexOptions);
  }
  return streamGoogleVertex(model, context, {
    ...base,
    thinking: {
      enabled: true,
      budgetTokens: getGoogleBudget(geminiModel, effort, options.thinkingBudgets),
    },
  } satisfies GoogleVertexOptions);
};
/**
 * Construct a GoogleGenAI client pinned to Vertex AI.
 * Model-level headers are merged with per-call headers (per-call wins).
 */
function createClient(
  model: Model<"google-vertex">,
  project: string,
  location: string,
  optionsHeaders?: Record<string, string>,
): GoogleGenAI {
  const mergedHeaders =
    model.headers || optionsHeaders ? { ...model.headers, ...optionsHeaders } : undefined;
  return new GoogleGenAI({
    vertexai: true,
    project,
    location,
    apiVersion: API_VERSION,
    // Only pass httpOptions through when headers were actually provided.
    httpOptions: mergedHeaders ? { headers: mergedHeaders } : undefined,
  });
}
/** Resolve the GCP project ID from options or environment; throws when absent. */
function resolveProject(options?: GoogleVertexOptions): string {
  const project = options?.project || process.env.GOOGLE_CLOUD_PROJECT || process.env.GCLOUD_PROJECT;
  if (project) return project;
  throw new Error(
    "Vertex AI requires a project ID. Set GOOGLE_CLOUD_PROJECT/GCLOUD_PROJECT or pass project in options.",
  );
}

/** Resolve the GCP location from options or environment; throws when absent. */
function resolveLocation(options?: GoogleVertexOptions): string {
  const location = options?.location || process.env.GOOGLE_CLOUD_LOCATION;
  if (location) return location;
  throw new Error("Vertex AI requires a location. Set GOOGLE_CLOUD_LOCATION or pass location in options.");
}
/**
 * Assemble GenerateContentParameters for a Vertex AI generateContentStream call.
 *
 * @param model - Target Vertex model; its id is sent as the model name.
 * @param context - System prompt, message history and tool definitions.
 * @param options - Temperature / maxTokens / toolChoice / thinking / abort settings.
 * @throws Error when the abort signal is already aborted.
 */
function buildParams(
  model: Model<"google-vertex">,
  context: Context,
  options: GoogleVertexOptions = {},
): GenerateContentParameters {
  const contents = convertMessages(model, context);
  const generationConfig: GenerateContentConfig = {};
  if (options.temperature !== undefined) {
    generationConfig.temperature = options.temperature;
  }
  if (options.maxTokens !== undefined) {
    generationConfig.maxOutputTokens = options.maxTokens;
  }
  // Each spread contributes only when its condition holds (spreading `false` is a no-op).
  const config: GenerateContentConfig = {
    ...(Object.keys(generationConfig).length > 0 && generationConfig),
    ...(context.systemPrompt && { systemInstruction: sanitizeSurrogates(context.systemPrompt) }),
    ...(context.tools && context.tools.length > 0 && { tools: convertTools(context.tools) }),
  };
  if (context.tools && context.tools.length > 0 && options.toolChoice) {
    config.toolConfig = {
      functionCallingConfig: {
        mode: mapToolChoice(options.toolChoice),
      },
    };
  } else {
    // NOTE(review): stores an explicit `undefined` rather than omitting the key;
    // presumably equivalent once serialized — confirm before simplifying.
    config.toolConfig = undefined;
  }
  if (options.thinking?.enabled && model.reasoning) {
    const thinkingConfig: ThinkingConfig = { includeThoughts: true };
    // Gemini 3 models take a discrete level; older models take a token budget.
    if (options.thinking.level !== undefined) {
      thinkingConfig.thinkingLevel = THINKING_LEVEL_MAP[options.thinking.level];
    } else if (options.thinking.budgetTokens !== undefined) {
      thinkingConfig.thinkingBudget = options.thinking.budgetTokens;
    }
    config.thinkingConfig = thinkingConfig;
  }
  if (options.signal) {
    // Fail fast on an already-aborted signal; otherwise hand it to the SDK.
    if (options.signal.aborted) {
      throw new Error("Request aborted");
    }
    config.abortSignal = options.signal;
  }
  const params: GenerateContentParameters = {
    model: model.id,
    contents,
    config,
  };
  return params;
}
type ClampedThinkingLevel = Exclude<PiThinkingLevel, "xhigh">;

/** Matches gemini-3-pro and dotted variants (gemini-3.x-pro), case-insensitively. */
function isGemini3ProModel(model: Model<"google-generative-ai">): boolean {
  return /gemini-3(?:\.\d+)?-pro/.test(model.id.toLowerCase());
}

/** Matches gemini-3-flash and dotted variants (gemini-3.x-flash), case-insensitively. */
function isGemini3FlashModel(model: Model<"google-generative-ai">): boolean {
  return /gemini-3(?:\.\d+)?-flash/.test(model.id.toLowerCase());
}

/**
 * Map a clamped reasoning effort to a Gemini 3 thinking level.
 * Pro variants are collapsed onto LOW/HIGH only; other models get the 1:1 mapping.
 */
function getGemini3ThinkingLevel(
  effort: ClampedThinkingLevel,
  model: Model<"google-generative-ai">,
): GoogleThinkingLevel {
  const proLevels: Record<ClampedThinkingLevel, GoogleThinkingLevel> = {
    minimal: "LOW",
    low: "LOW",
    medium: "HIGH",
    high: "HIGH",
  };
  const standardLevels: Record<ClampedThinkingLevel, GoogleThinkingLevel> = {
    minimal: "MINIMAL",
    low: "LOW",
    medium: "MEDIUM",
    high: "HIGH",
  };
  return isGemini3ProModel(model) ? proLevels[effort] : standardLevels[effort];
}
/**
 * Resolve the thinking-token budget for a given effort level.
 * Caller-supplied budgets win; otherwise known 2.5 models get tuned defaults,
 * and anything else returns -1 (dynamic budget).
 */
function getGoogleBudget(
  model: Model<"google-generative-ai">,
  effort: ClampedThinkingLevel,
  customBudgets?: ThinkingBudgets,
): number {
  const custom = customBudgets?.[effort];
  if (custom !== undefined) {
    return custom;
  }
  const knownDefaults: Array<[string, Record<ClampedThinkingLevel, number>]> = [
    ["2.5-pro", { minimal: 128, low: 2048, medium: 8192, high: 32768 }],
    ["2.5-flash", { minimal: 128, low: 2048, medium: 8192, high: 24576 }],
  ];
  for (const [marker, budgets] of knownDefaults) {
    if (model.id.includes(marker)) {
      return budgets[effort];
    }
  }
  return -1; // dynamic budget
}

View file

@ -0,0 +1,455 @@
import {
type GenerateContentConfig,
type GenerateContentParameters,
GoogleGenAI,
type ThinkingConfig,
} from "@google/genai";
import { getEnvApiKey } from "../env-api-keys.js";
import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
Context,
Model,
SimpleStreamOptions,
StreamFunction,
StreamOptions,
TextContent,
ThinkingBudgets,
ThinkingContent,
ThinkingLevel,
ToolCall,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import type { GoogleThinkingLevel } from "./google-gemini-cli.js";
import {
convertMessages,
convertTools,
isThinkingPart,
mapStopReason,
mapToolChoice,
retainThoughtSignature,
} from "./google-shared.js";
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
/** Options accepted by the Google Generative AI stream functions. */
export interface GoogleOptions extends StreamOptions {
  toolChoice?: "auto" | "none" | "any";
  thinking?: {
    enabled: boolean;
    budgetTokens?: number; // -1 for dynamic, 0 to disable
    level?: GoogleThinkingLevel;
  };
}
// Counter for generating unique tool call IDs
// (module-level so IDs stay unique across streams in this process).
let toolCallCounter = 0;
/**
 * Stream a completion from the Google Generative AI API.
 *
 * Returns an AssistantMessageEventStream immediately; an async IIFE drives the
 * Google SDK stream and pushes start/delta/end events onto it. On any failure
 * (including abort) the partially built message is emitted as a single
 * "error" event instead of rejecting.
 */
export const streamGoogle: StreamFunction<"google-generative-ai", GoogleOptions> = (
  model: Model<"google-generative-ai">,
  context: Context,
  options?: GoogleOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();
  (async () => {
    // Accumulated assistant message; also attached as `partial` to every event.
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: "google-generative-ai" as Api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };
    try {
      const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
      const client = createClient(model, apiKey, options?.headers);
      let params = buildParams(model, context, options);
      // Give the caller a chance to inspect or replace the outgoing payload.
      const nextParams = await options?.onPayload?.(params, model);
      if (nextParams !== undefined) {
        params = nextParams as GenerateContentParameters;
      }
      const googleStream = await client.models.generateContentStream(params);
      stream.push({ type: "start", partial: output });
      let currentBlock: TextContent | ThinkingContent | null = null;
      const blocks = output.content;
      const blockIndex = () => blocks.length - 1;
      for await (const chunk of googleStream) {
        const candidate = chunk.candidates?.[0];
        if (candidate?.content?.parts) {
          for (const part of candidate.content.parts) {
            if (part.text !== undefined) {
              const isThinking = isThinkingPart(part);
              // Close the open block when the part kind flips between text and thinking.
              if (
                !currentBlock ||
                (isThinking && currentBlock.type !== "thinking") ||
                (!isThinking && currentBlock.type !== "text")
              ) {
                if (currentBlock) {
                  if (currentBlock.type === "text") {
                    stream.push({
                      type: "text_end",
                      contentIndex: blocks.length - 1,
                      content: currentBlock.text,
                      partial: output,
                    });
                  } else {
                    stream.push({
                      type: "thinking_end",
                      contentIndex: blockIndex(),
                      content: currentBlock.thinking,
                      partial: output,
                    });
                  }
                }
                if (isThinking) {
                  currentBlock = { type: "thinking", thinking: "", thinkingSignature: undefined };
                  output.content.push(currentBlock);
                  stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output });
                } else {
                  currentBlock = { type: "text", text: "" };
                  output.content.push(currentBlock);
                  stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
                }
              }
              if (currentBlock.type === "thinking") {
                currentBlock.thinking += part.text;
                // Signatures may only arrive on the first delta; keep the last non-empty one.
                currentBlock.thinkingSignature = retainThoughtSignature(
                  currentBlock.thinkingSignature,
                  part.thoughtSignature,
                );
                stream.push({
                  type: "thinking_delta",
                  contentIndex: blockIndex(),
                  delta: part.text,
                  partial: output,
                });
              } else {
                currentBlock.text += part.text;
                currentBlock.textSignature = retainThoughtSignature(
                  currentBlock.textSignature,
                  part.thoughtSignature,
                );
                stream.push({
                  type: "text_delta",
                  contentIndex: blockIndex(),
                  delta: part.text,
                  partial: output,
                });
              }
            }
            if (part.functionCall) {
              // A function call terminates any open text/thinking block.
              if (currentBlock) {
                if (currentBlock.type === "text") {
                  stream.push({
                    type: "text_end",
                    contentIndex: blockIndex(),
                    content: currentBlock.text,
                    partial: output,
                  });
                } else {
                  stream.push({
                    type: "thinking_end",
                    contentIndex: blockIndex(),
                    content: currentBlock.thinking,
                    partial: output,
                  });
                }
                currentBlock = null;
              }
              // Generate unique ID if not provided or if it's a duplicate
              const providedId = part.functionCall.id;
              const needsNewId =
                !providedId || output.content.some((b) => b.type === "toolCall" && b.id === providedId);
              const toolCallId = needsNewId
                ? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
                : providedId;
              const toolCall: ToolCall = {
                type: "toolCall",
                id: toolCallId,
                name: part.functionCall.name || "",
                arguments: (part.functionCall.args as Record<string, any>) ?? {},
                ...(part.thoughtSignature && { thoughtSignature: part.thoughtSignature }),
              };
              output.content.push(toolCall);
              stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output });
              stream.push({
                type: "toolcall_delta",
                contentIndex: blockIndex(),
                delta: JSON.stringify(toolCall.arguments),
                partial: output,
              });
              stream.push({ type: "toolcall_end", contentIndex: blockIndex(), toolCall, partial: output });
            }
          }
        }
        if (candidate?.finishReason) {
          output.stopReason = mapStopReason(candidate.finishReason);
          // Any emitted tool call overrides the reported finish reason.
          if (output.content.some((b) => b.type === "toolCall")) {
            output.stopReason = "toolUse";
          }
        }
        if (chunk.usageMetadata) {
          // Overwrite usage with the chunk's latest snapshot and recompute cost.
          output.usage = {
            input: chunk.usageMetadata.promptTokenCount || 0,
            output:
              (chunk.usageMetadata.candidatesTokenCount || 0) + (chunk.usageMetadata.thoughtsTokenCount || 0),
            cacheRead: chunk.usageMetadata.cachedContentTokenCount || 0,
            cacheWrite: 0,
            totalTokens: chunk.usageMetadata.totalTokenCount || 0,
            cost: {
              input: 0,
              output: 0,
              cacheRead: 0,
              cacheWrite: 0,
              total: 0,
            },
          };
          calculateCost(model, output.usage);
        }
      }
      // Flush a block left open when the stream ended.
      if (currentBlock) {
        if (currentBlock.type === "text") {
          stream.push({
            type: "text_end",
            contentIndex: blockIndex(),
            content: currentBlock.text,
            partial: output,
          });
        } else {
          stream.push({
            type: "thinking_end",
            contentIndex: blockIndex(),
            content: currentBlock.thinking,
            partial: output,
          });
        }
      }
      // Promote aborts / silent error stop-reasons to exceptions so the catch
      // below emits exactly one "error" event.
      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }
      if (output.stopReason === "aborted" || output.stopReason === "error") {
        throw new Error("An unknown error occurred");
      }
      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      // Remove internal index property used during streaming
      for (const block of output.content) {
        if ("index" in block) {
          delete (block as { index?: number }).index;
        }
      }
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();
  return stream;
};
export const streamSimpleGoogle: StreamFunction<"google-generative-ai", SimpleStreamOptions> = (
model: Model<"google-generative-ai">,
context: Context,
options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
if (!apiKey) {
throw new Error(`No API key for provider: ${model.provider}`);
}
const base = buildBaseOptions(model, options, apiKey);
if (!options?.reasoning) {
return streamGoogle(model, context, { ...base, thinking: { enabled: false } } satisfies GoogleOptions);
}
const effort = clampReasoning(options.reasoning)!;
const googleModel = model as Model<"google-generative-ai">;
if (isGemini3ProModel(googleModel) || isGemini3FlashModel(googleModel)) {
return streamGoogle(model, context, {
...base,
thinking: {
enabled: true,
level: getGemini3ThinkingLevel(effort, googleModel),
},
} satisfies GoogleOptions);
}
return streamGoogle(model, context, {
...base,
thinking: {
enabled: true,
budgetTokens: getGoogleBudget(googleModel, effort, options.thinkingBudgets),
},
} satisfies GoogleOptions);
};
/**
 * Constructs a GoogleGenAI client for the given model.
 *
 * When the model carries a custom baseUrl, the SDK's apiVersion is cleared
 * because the baseUrl is expected to already contain the version path.
 * Model-level headers are merged with (and overridden by) caller headers.
 */
function createClient(
  model: Model<"google-generative-ai">,
  apiKey?: string,
  optionsHeaders?: Record<string, string>,
): GoogleGenAI {
  const http: { baseUrl?: string; apiVersion?: string; headers?: Record<string, string> } = {};
  if (model.baseUrl) {
    http.baseUrl = model.baseUrl;
    // baseUrl already includes the version path; stop the SDK appending one.
    http.apiVersion = "";
  }
  if (model.headers || optionsHeaders) {
    http.headers = { ...model.headers, ...optionsHeaders };
  }
  const hasHttpOptions = Object.keys(http).length > 0;
  return new GoogleGenAI({
    apiKey,
    httpOptions: hasHttpOptions ? http : undefined,
  });
}
/**
 * Builds GenerateContentParameters for a Google request: converted messages,
 * sampling config, system instruction, tools, tool choice, thinking config,
 * and abort signal.
 *
 * Throws immediately when the provided signal is already aborted.
 */
function buildParams(
  model: Model<"google-generative-ai">,
  context: Context,
  options: GoogleOptions = {},
): GenerateContentParameters {
  const contents = convertMessages(model, context);
  const generation: GenerateContentConfig = {};
  if (options.temperature !== undefined) generation.temperature = options.temperature;
  if (options.maxTokens !== undefined) generation.maxOutputTokens = options.maxTokens;
  const toolList = context.tools ?? [];
  const config: GenerateContentConfig = {
    ...(Object.keys(generation).length > 0 && generation),
    ...(context.systemPrompt && { systemInstruction: sanitizeSurrogates(context.systemPrompt) }),
    ...(toolList.length > 0 && { tools: convertTools(toolList) }),
  };
  // Tool choice only applies when tools are present; otherwise clear it explicitly.
  config.toolConfig =
    toolList.length > 0 && options.toolChoice
      ? { functionCallingConfig: { mode: mapToolChoice(options.toolChoice) } }
      : undefined;
  if (options.thinking?.enabled && model.reasoning) {
    const thinkingConfig: ThinkingConfig = { includeThoughts: true };
    if (options.thinking.level !== undefined) {
      // Cast: our GoogleThinkingLevel mirrors Google's ThinkingLevel enum values.
      thinkingConfig.thinkingLevel = options.thinking.level as any;
    } else if (options.thinking.budgetTokens !== undefined) {
      thinkingConfig.thinkingBudget = options.thinking.budgetTokens;
    }
    config.thinkingConfig = thinkingConfig;
  }
  if (options.signal) {
    if (options.signal.aborted) {
      throw new Error("Request aborted");
    }
    config.abortSignal = options.signal;
  }
  return {
    model: model.id,
    contents,
    config,
  };
}
// Reasoning efforts accepted by the budget/level mappers; "xhigh" is clamped away upstream.
type ClampedThinkingLevel = Exclude<ThinkingLevel, "xhigh">;
/** True for Gemini 3 Pro family ids, e.g. "gemini-3-pro" or "gemini-3.1-pro". */
function isGemini3ProModel(model: Model<"google-generative-ai">): boolean {
  const id = model.id.toLowerCase();
  return /gemini-3(?:\.\d+)?-pro/.test(id);
}
/** True for Gemini 3 Flash family ids, e.g. "gemini-3-flash" or "gemini-3.1-flash". */
function isGemini3FlashModel(model: Model<"google-generative-ai">): boolean {
  const id = model.id.toLowerCase();
  return /gemini-3(?:\.\d+)?-flash/.test(id);
}
/**
 * Maps a clamped reasoning effort onto a Gemini 3 thinking level.
 * Pro models only support LOW/HIGH, so the four efforts collapse pairwise;
 * other Gemini 3 models get the full four-level mapping.
 */
function getGemini3ThinkingLevel(
  effort: ClampedThinkingLevel,
  model: Model<"google-generative-ai">,
): GoogleThinkingLevel {
  if (isGemini3ProModel(model)) {
    // Pro exposes only the LOW and HIGH levels.
    return effort === "minimal" || effort === "low" ? "LOW" : "HIGH";
  }
  const levels: Record<ClampedThinkingLevel, GoogleThinkingLevel> = {
    minimal: "MINIMAL",
    low: "LOW",
    medium: "MEDIUM",
    high: "HIGH",
  };
  return levels[effort];
}
/**
 * Resolves the thinking-token budget for a Google model at a given effort.
 * Caller-supplied budgets take precedence; otherwise per-family defaults are
 * used for 2.5 Pro / 2.5 Flash, and -1 (provider-chosen budget) otherwise.
 */
function getGoogleBudget(
  model: Model<"google-generative-ai">,
  effort: ClampedThinkingLevel,
  customBudgets?: ThinkingBudgets,
): number {
  const custom = customBudgets?.[effort];
  if (custom !== undefined) return custom;
  const families: Array<[marker: string, budgets: Record<ClampedThinkingLevel, number>]> = [
    ["2.5-pro", { minimal: 128, low: 2048, medium: 8192, high: 32768 }],
    ["2.5-flash", { minimal: 128, low: 2048, medium: 8192, high: 24576 }],
  ];
  for (const [marker, budgets] of families) {
    if (model.id.includes(marker)) return budgets[effort];
  }
  // Unknown family: let the provider pick the budget.
  return -1;
}

View file

@ -0,0 +1,582 @@
import { Mistral } from "@mistralai/mistralai";
import type { RequestOptions } from "@mistralai/mistralai/lib/sdks.js";
import type {
ChatCompletionStreamRequest,
ChatCompletionStreamRequestMessages,
CompletionEvent,
ContentChunk,
FunctionTool,
} from "@mistralai/mistralai/models/components/index.js";
import { getEnvApiKey } from "../env-api-keys.js";
import { calculateCost } from "../models.js";
import type {
AssistantMessage,
Context,
Message,
Model,
SimpleStreamOptions,
StopReason,
StreamFunction,
StreamOptions,
TextContent,
ThinkingContent,
Tool,
ToolCall,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { shortHash } from "../utils/hash.js";
import { parseStreamingJson } from "../utils/json-parse.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
import { transformMessages } from "./transform-messages.js";
// Tool-call ids are derived/truncated to this many alphanumeric characters
// (the id format Mistral accepts — see deriveMistralToolCallId).
const MISTRAL_TOOL_CALL_ID_LENGTH = 9;
// Cap on how much of an HTTP error body is surfaced in error messages.
const MAX_MISTRAL_ERROR_BODY_CHARS = 4000;
/**
 * Provider-specific options for the Mistral API.
 *
 * - `toolChoice` mirrors Mistral's function-calling modes, or pins a single
 *   function by name.
 * - `promptMode: "reasoning"` enables the reasoning prompt mode on
 *   reasoning-capable models.
 */
export interface MistralOptions extends StreamOptions {
  toolChoice?: "auto" | "none" | "any" | "required" | { type: "function"; function: { name: string } };
  promptMode?: "reasoning";
}
/**
 * Stream responses from Mistral using `chat.stream`.
 *
 * Returns an `AssistantMessageEventStream` immediately; the network call runs
 * in a detached async IIFE that pushes "start", incremental content/tool-call
 * events (via `consumeChatStream`), then "done" — or a single "error" event
 * if anything throws or the request is aborted.
 */
export const streamMistral: StreamFunction<"mistral-conversations", MistralOptions> = (
  model: Model<"mistral-conversations">,
  context: Context,
  options?: MistralOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();
  (async () => {
    const output = createOutput(model);
    try {
      const apiKey = options?.apiKey || getEnvApiKey(model.provider);
      if (!apiKey) {
        throw new Error(`No API key for provider: ${model.provider}`);
      }
      // Intentionally per-request: avoids shared SDK mutable state across concurrent consumers.
      const mistral = new Mistral({
        apiKey,
        serverURL: model.baseUrl,
      });
      // Rewrite historical tool-call ids into Mistral's id format (stable per input id).
      const normalizeMistralToolCallId = createMistralToolCallIdNormalizer();
      const transformedMessages = transformMessages(context.messages, model, (id) => normalizeMistralToolCallId(id));
      let payload = buildChatPayload(model, context, transformedMessages, options);
      // Give the caller a chance to inspect or replace the outgoing payload.
      const nextPayload = await options?.onPayload?.(payload, model);
      if (nextPayload !== undefined) {
        payload = nextPayload as ChatCompletionStreamRequest;
      }
      const mistralStream = await mistral.chat.stream(payload, buildRequestOptions(model, options));
      stream.push({ type: "start", partial: output });
      await consumeChatStream(model, output, stream, mistralStream);
      // An abort can surface as a normally-ended stream; convert it to an error.
      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }
      if (output.stopReason === "aborted" || output.stopReason === "error") {
        throw new Error("An unknown error occurred");
      }
      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage = formatMistralError(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();
  return stream;
};
/**
 * Maps provider-agnostic `SimpleStreamOptions` to Mistral options.
 * Reasoning prompt mode is enabled only for reasoning-capable models when a
 * (clamped) reasoning effort was requested.
 */
export const streamSimpleMistral: StreamFunction<"mistral-conversations", SimpleStreamOptions> = (
  model: Model<"mistral-conversations">,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
  const key = options?.apiKey || getEnvApiKey(model.provider);
  if (!key) {
    throw new Error(`No API key for provider: ${model.provider}`);
  }
  const effort = clampReasoning(options?.reasoning);
  const wantsReasoning = Boolean(model.reasoning && effort);
  return streamMistral(model, context, {
    ...buildBaseOptions(model, options, key),
    promptMode: wantsReasoning ? "reasoning" : undefined,
  } satisfies MistralOptions);
};
/** Creates an empty assistant message for this model with zeroed usage/cost. */
function createOutput(model: Model<"mistral-conversations">): AssistantMessage {
  const zeroCost = { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 };
  return {
    role: "assistant",
    content: [],
    api: model.api,
    provider: model.provider,
    model: model.id,
    usage: {
      input: 0,
      output: 0,
      cacheRead: 0,
      cacheWrite: 0,
      totalTokens: 0,
      cost: zeroCost,
    },
    stopReason: "stop",
    timestamp: Date.now(),
  };
}
/**
 * Returns a memoizing function that maps arbitrary tool-call ids onto
 * Mistral-format ids. Mappings are stable (same input → same output) and
 * collision-free: if two distinct inputs would derive the same candidate,
 * the later one is re-derived with an incremented attempt counter.
 */
function createMistralToolCallIdNormalizer(): (id: string) => string {
  const forward = new Map<string, string>();
  const claimed = new Map<string, string>();
  return (originalId: string): string => {
    const cached = forward.get(originalId);
    if (cached !== undefined) return cached;
    for (let attempt = 0; ; attempt++) {
      const candidate = deriveMistralToolCallId(originalId, attempt);
      const claimant = claimed.get(candidate);
      // Candidate taken by a different id: bump the attempt and re-derive.
      if (claimant !== undefined && claimant !== originalId) continue;
      forward.set(originalId, candidate);
      claimed.set(candidate, originalId);
      return candidate;
    }
  };
}
/**
 * Derives a Mistral-compatible tool-call id (alphanumeric, fixed length) from
 * an arbitrary id. If stripping non-alphanumerics already yields an id of the
 * required length on the first attempt, it is used as-is; otherwise a short
 * hash of the id (salted with the attempt number for collision resolution)
 * is truncated to the required length.
 */
function deriveMistralToolCallId(id: string, attempt: number): string {
  const alnum = id.replace(/[^a-zA-Z0-9]/g, "");
  if (attempt === 0 && alnum.length === MISTRAL_TOOL_CALL_ID_LENGTH) {
    return alnum;
  }
  const base = alnum || id;
  const seed = attempt === 0 ? base : `${base}:${attempt}`;
  const hashed = shortHash(seed).replace(/[^a-zA-Z0-9]/g, "");
  return hashed.slice(0, MISTRAL_TOOL_CALL_ID_LENGTH);
}
/**
 * Produces a human-readable message for a Mistral SDK error.
 * HTTP errors include the status code and a (truncated) response body when
 * available; non-Error values are JSON-stringified as a last resort.
 */
function formatMistralError(error: unknown): string {
  if (!(error instanceof Error)) return safeJsonStringify(error);
  const sdkError = error as Error & { statusCode?: unknown; body?: unknown };
  const status = typeof sdkError.statusCode === "number" ? sdkError.statusCode : undefined;
  const bodyText = typeof sdkError.body === "string" ? sdkError.body.trim() : undefined;
  if (status === undefined) return error.message;
  if (!bodyText) return `Mistral API error (${status}): ${error.message}`;
  return `Mistral API error (${status}): ${truncateErrorText(bodyText, MAX_MISTRAL_ERROR_BODY_CHARS)}`;
}
/** Truncates text to `maxChars`, appending a note with how much was omitted. */
function truncateErrorText(text: string, maxChars: number): string {
  if (text.length <= maxChars) return text;
  const omitted = text.length - maxChars;
  return `${text.slice(0, maxChars)}... [truncated ${omitted} chars]`;
}
/**
 * JSON.stringify that never throws: circular/unserializable values fall back
 * to String(value), as does anything that serializes to undefined.
 */
function safeJsonStringify(value: unknown): string {
  try {
    const json = JSON.stringify(value);
    if (json !== undefined) return json;
  } catch {
    // fall through to String() below
  }
  return String(value);
}
/**
 * Builds per-request SDK options: abort signal, disabled SDK retries, and
 * merged headers (caller headers override model headers).
 *
 * Mistral infrastructure uses `x-affinity` for KV-cache reuse (prefix
 * caching), so when a sessionId is present and the caller did not set the
 * header explicitly, it is added here.
 */
function buildRequestOptions(model: Model<"mistral-conversations">, options?: MistralOptions): RequestOptions {
  const requestOptions: RequestOptions = { retries: { strategy: "none" } };
  if (options?.signal) requestOptions.signal = options.signal;
  const headers: Record<string, string> = { ...model.headers, ...options?.headers };
  if (options?.sessionId && !headers["x-affinity"]) {
    headers["x-affinity"] = options.sessionId;
  }
  if (Object.keys(headers).length > 0) {
    requestOptions.headers = headers;
  }
  return requestOptions;
}
function buildChatPayload(
model: Model<"mistral-conversations">,
context: Context,
messages: Message[],
options?: MistralOptions,
): ChatCompletionStreamRequest {
const payload: ChatCompletionStreamRequest = {
model: model.id,
stream: true,
messages: toChatMessages(messages, model.input.includes("image")),
};
if (context.tools?.length) payload.tools = toFunctionTools(context.tools);
if (options?.temperature !== undefined) payload.temperature = options.temperature;
if (options?.maxTokens !== undefined) payload.maxTokens = options.maxTokens;
if (options?.toolChoice) payload.toolChoice = mapToolChoice(options.toolChoice);
if (options?.promptMode) payload.promptMode = options.promptMode as any;
if (context.systemPrompt) {
payload.messages.unshift({
role: "system",
content: sanitizeSurrogates(context.systemPrompt),
});
}
return payload;
}
/**
 * Consumes a Mistral completion stream, incrementally building `output` and
 * emitting granular events (text/thinking/tool-call start, delta, end) on
 * `stream`.
 *
 * State machine notes:
 * - `currentBlock` tracks the open text/thinking block; switching content
 *   kinds (or hitting a tool call) closes it via `finishCurrentBlock`.
 * - Tool-call deltas are accumulated per `id:index` key in `toolBlocksByKey`;
 *   partial JSON arguments are re-parsed on each delta and finalized (with
 *   the scratch `partialArgs` field removed) after the stream ends.
 */
async function consumeChatStream(
  model: Model<"mistral-conversations">,
  output: AssistantMessage,
  stream: AssistantMessageEventStream,
  mistralStream: AsyncIterable<CompletionEvent>,
): Promise<void> {
  let currentBlock: TextContent | ThinkingContent | null = null;
  const blocks = output.content;
  // Events always reference the most recently appended content block.
  const blockIndex = () => blocks.length - 1;
  const toolBlocksByKey = new Map<string, number>();
  // Emits the matching *_end event for the currently open text/thinking block.
  const finishCurrentBlock = (block?: typeof currentBlock) => {
    if (!block) return;
    if (block.type === "text") {
      stream.push({
        type: "text_end",
        contentIndex: blockIndex(),
        content: block.text,
        partial: output,
      });
      return;
    }
    if (block.type === "thinking") {
      stream.push({
        type: "thinking_end",
        contentIndex: blockIndex(),
        content: block.thinking,
        partial: output,
      });
    }
  };
  for await (const event of mistralStream) {
    const chunk = event.data;
    if (chunk.usage) {
      // Mistral reports no cache metrics; totalTokens falls back to input+output.
      output.usage.input = chunk.usage.promptTokens || 0;
      output.usage.output = chunk.usage.completionTokens || 0;
      output.usage.cacheRead = 0;
      output.usage.cacheWrite = 0;
      output.usage.totalTokens = chunk.usage.totalTokens || output.usage.input + output.usage.output;
      calculateCost(model, output.usage);
    }
    const choice = chunk.choices[0];
    if (!choice) continue;
    if (choice.finishReason) {
      output.stopReason = mapChatStopReason(choice.finishReason);
    }
    const delta = choice.delta;
    if (delta.content !== null && delta.content !== undefined) {
      // Content may be a bare string or a list of typed chunks; normalize to a list.
      const contentItems = typeof delta.content === "string" ? [delta.content] : delta.content;
      for (const item of contentItems) {
        if (typeof item === "string") {
          const textDelta = sanitizeSurrogates(item);
          // Open a new text block if none is open or the open block is thinking.
          if (!currentBlock || currentBlock.type !== "text") {
            finishCurrentBlock(currentBlock);
            currentBlock = { type: "text", text: "" };
            output.content.push(currentBlock);
            stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
          }
          currentBlock.text += textDelta;
          stream.push({
            type: "text_delta",
            contentIndex: blockIndex(),
            delta: textDelta,
            partial: output,
          });
          continue;
        }
        if (item.type === "thinking") {
          // Thinking chunks nest their text; join all text parts into one delta.
          const deltaText = item.thinking
            .map((part) => ("text" in part ? part.text : ""))
            .filter((text) => text.length > 0)
            .join("");
          const thinkingDelta = sanitizeSurrogates(deltaText);
          if (!thinkingDelta) continue;
          if (!currentBlock || currentBlock.type !== "thinking") {
            finishCurrentBlock(currentBlock);
            currentBlock = { type: "thinking", thinking: "" };
            output.content.push(currentBlock);
            stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output });
          }
          currentBlock.thinking += thinkingDelta;
          stream.push({
            type: "thinking_delta",
            contentIndex: blockIndex(),
            delta: thinkingDelta,
            partial: output,
          });
          continue;
        }
        if (item.type === "text") {
          const textDelta = sanitizeSurrogates(item.text);
          if (!currentBlock || currentBlock.type !== "text") {
            finishCurrentBlock(currentBlock);
            currentBlock = { type: "text", text: "" };
            output.content.push(currentBlock);
            stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
          }
          currentBlock.text += textDelta;
          stream.push({
            type: "text_delta",
            contentIndex: blockIndex(),
            delta: textDelta,
            partial: output,
          });
        }
      }
    }
    const toolCalls = delta.toolCalls || [];
    for (const toolCall of toolCalls) {
      // A tool call ends any open text/thinking block.
      if (currentBlock) {
        finishCurrentBlock(currentBlock);
        currentBlock = null;
      }
      // Some chunks carry a missing or literal "null" id; derive a stable one
      // from the tool-call index instead.
      const callId =
        toolCall.id && toolCall.id !== "null"
          ? toolCall.id
          : deriveMistralToolCallId(`toolcall:${toolCall.index ?? 0}`, 0);
      const key = `${callId}:${toolCall.index || 0}`;
      const existingIndex = toolBlocksByKey.get(key);
      let block: (ToolCall & { partialArgs?: string }) | undefined;
      if (existingIndex !== undefined) {
        const existing = output.content[existingIndex];
        if (existing?.type === "toolCall") {
          block = existing as ToolCall & { partialArgs?: string };
        }
      }
      if (!block) {
        block = {
          type: "toolCall",
          id: callId,
          name: toolCall.function.name,
          arguments: {},
          partialArgs: "",
        };
        output.content.push(block);
        toolBlocksByKey.set(key, output.content.length - 1);
        stream.push({ type: "toolcall_start", contentIndex: output.content.length - 1, partial: output });
      }
      const argsDelta =
        typeof toolCall.function.arguments === "string"
          ? toolCall.function.arguments
          : JSON.stringify(toolCall.function.arguments || {});
      block.partialArgs = (block.partialArgs || "") + argsDelta;
      // Best-effort parse of the accumulated (possibly incomplete) JSON.
      block.arguments = parseStreamingJson<Record<string, unknown>>(block.partialArgs);
      stream.push({
        type: "toolcall_delta",
        contentIndex: toolBlocksByKey.get(key)!,
        delta: argsDelta,
        partial: output,
      });
    }
  }
  finishCurrentBlock(currentBlock);
  // Finalize each tool call: parse the complete argument JSON and drop the
  // internal partialArgs scratch field before emitting toolcall_end.
  for (const index of toolBlocksByKey.values()) {
    const block = output.content[index];
    if (block.type !== "toolCall") continue;
    const toolBlock = block as ToolCall & { partialArgs?: string };
    toolBlock.arguments = parseStreamingJson<Record<string, unknown>>(toolBlock.partialArgs);
    delete toolBlock.partialArgs;
    stream.push({
      type: "toolcall_end",
      contentIndex: index,
      toolCall: toolBlock,
      partial: output,
    });
  }
}
/** Converts provider-agnostic tools into Mistral function-tool declarations. */
function toFunctionTools(tools: Tool[]): Array<FunctionTool & { type: "function" }> {
  return tools.map((tool) => {
    return {
      type: "function" as const,
      function: {
        name: tool.name,
        description: tool.description,
        parameters: tool.parameters as unknown as Record<string, unknown>,
        strict: false,
      },
    };
  });
}
/**
 * Converts provider-agnostic messages into Mistral chat messages.
 *
 * - User images are dropped when the model lacks image support; an image-only
 *   user message becomes a placeholder string so the turn isn't silently lost.
 * - Assistant text/thinking blocks map to content chunks and tool calls to
 *   `toolCalls`; assistant messages that end up empty are skipped entirely.
 * - Any other role is treated as a tool result: text parts are joined and
 *   rendered via buildToolResultText, with images attached when supported.
 */
function toChatMessages(messages: Message[], supportsImages: boolean): ChatCompletionStreamRequestMessages[] {
  const result: ChatCompletionStreamRequestMessages[] = [];
  for (const msg of messages) {
    if (msg.role === "user") {
      if (typeof msg.content === "string") {
        result.push({ role: "user", content: sanitizeSurrogates(msg.content) });
        continue;
      }
      const hadImages = msg.content.some((item) => item.type === "image");
      // Keep text always; keep images only for vision-capable models.
      const content: ContentChunk[] = msg.content
        .filter((item) => item.type === "text" || supportsImages)
        .map((item) => {
          if (item.type === "text") return { type: "text", text: sanitizeSurrogates(item.text) };
          return { type: "image_url", imageUrl: `data:${item.mimeType};base64,${item.data}` };
        });
      if (content.length > 0) {
        result.push({ role: "user", content });
        continue;
      }
      // Everything was filtered out: leave a marker instead of an empty turn.
      if (hadImages && !supportsImages) {
        result.push({ role: "user", content: "(image omitted: model does not support images)" });
      }
      continue;
    }
    if (msg.role === "assistant") {
      const contentParts: ContentChunk[] = [];
      const toolCalls: Array<{ id: string; type: "function"; function: { name: string; arguments: string } }> = [];
      for (const block of msg.content) {
        if (block.type === "text") {
          // Whitespace-only text blocks are dropped.
          if (block.text.trim().length > 0) {
            contentParts.push({ type: "text", text: sanitizeSurrogates(block.text) });
          }
          continue;
        }
        if (block.type === "thinking") {
          if (block.thinking.trim().length > 0) {
            contentParts.push({
              type: "thinking",
              thinking: [{ type: "text", text: sanitizeSurrogates(block.thinking) }],
            });
          }
          continue;
        }
        // Remaining block kind is a tool call.
        toolCalls.push({
          id: block.id,
          type: "function",
          function: { name: block.name, arguments: JSON.stringify(block.arguments || {}) },
        });
      }
      const assistantMessage: ChatCompletionStreamRequestMessages = { role: "assistant" };
      if (contentParts.length > 0) assistantMessage.content = contentParts;
      if (toolCalls.length > 0) assistantMessage.toolCalls = toolCalls;
      // Skip assistant turns with no usable content at all.
      if (contentParts.length > 0 || toolCalls.length > 0) result.push(assistantMessage);
      continue;
    }
    // Tool-result message: collapse text parts into one chunk, then attach
    // any images separately when the model supports them.
    const toolContent: ContentChunk[] = [];
    const textResult = msg.content
      .filter((part) => part.type === "text")
      .map((part) => (part.type === "text" ? sanitizeSurrogates(part.text) : ""))
      .join("\n");
    const hasImages = msg.content.some((part) => part.type === "image");
    const toolText = buildToolResultText(textResult, hasImages, supportsImages, msg.isError);
    toolContent.push({ type: "text", text: toolText });
    for (const part of msg.content) {
      if (!supportsImages) continue;
      if (part.type !== "image") continue;
      toolContent.push({
        type: "image_url",
        imageUrl: `data:${part.mimeType};base64,${part.data}`,
      });
    }
    result.push({
      role: "tool",
      toolCallId: msg.toolCallId,
      name: msg.toolName,
      content: toolContent,
    });
  }
  return result;
}
/**
 * Renders the text body of a tool-result message. Errors get a "[tool error]"
 * prefix; image-only or empty results get explanatory placeholders, with a
 * note when images had to be dropped for non-vision models.
 */
function buildToolResultText(text: string, hasImages: boolean, supportsImages: boolean, isError: boolean): string {
  const prefix = isError ? "[tool error] " : "";
  const trimmed = text.trim();
  if (trimmed.length > 0) {
    const droppedNote = hasImages && !supportsImages ? "\n[tool image omitted: model does not support images]" : "";
    return prefix + trimmed + droppedNote;
  }
  if (!hasImages) {
    return isError ? "[tool error] (no tool output)" : "(no tool output)";
  }
  if (supportsImages) {
    return isError ? "[tool error] (see attached image)" : "(see attached image)";
  }
  return isError
    ? "[tool error] (image omitted: model does not support images)"
    : "(image omitted: model does not support images)";
}
/**
 * Passes string tool-choice modes through unchanged and reduces a function
 * selector to the shape Mistral expects (name only).
 */
function mapToolChoice(
  choice: MistralOptions["toolChoice"],
): "auto" | "none" | "any" | "required" | { type: "function"; function: { name: string } } | undefined {
  if (!choice) return undefined;
  const passthrough = choice === "auto" || choice === "none" || choice === "any" || choice === "required";
  if (passthrough) return choice as any;
  return {
    type: "function",
    function: { name: choice.function.name },
  };
}
/** Maps Mistral finish reasons onto provider-agnostic stop reasons ("stop" by default). */
function mapChatStopReason(reason: string | null): StopReason {
  if (reason === null) return "stop";
  const mapping: Record<string, StopReason> = {
    stop: "stop",
    length: "length",
    model_length: "length",
    tool_calls: "toolUse",
    error: "error",
  };
  return mapping[reason] ?? "stop";
}

View file

@ -0,0 +1,875 @@
import type * as NodeOs from "node:os";
import type { Tool as OpenAITool, ResponseInput, ResponseStreamEvent } from "openai/resources/responses/responses.js";
// NEVER convert to top-level runtime imports - breaks browser/Vite builds (web-ui)
// Lazily-loaded node:os module; remains null outside Node/Bun (e.g. browser builds).
let _os: typeof NodeOs | null = null;
type DynamicImport = (specifier: string) => Promise<unknown>;
const dynamicImport: DynamicImport = (specifier) => import(specifier);
// Specifier is built by concatenation, presumably so bundlers can't statically
// resolve it (keeps browser/Vite builds from trying to bundle node:os).
const NODE_OS_SPECIFIER = "node:" + "os";
// Only attempt the import on Node or Bun runtimes. The load is fire-and-forget:
// _os may still be null for a short window after module load.
if (typeof process !== "undefined" && (process.versions?.node || process.versions?.bun)) {
  dynamicImport(NODE_OS_SPECIFIER).then((m) => {
    _os = m as typeof NodeOs;
  });
}
import { getEnvApiKey } from "../env-api-keys.js";
import { supportsXhigh } from "../models.js";
import type {
Api,
AssistantMessage,
Context,
Model,
SimpleStreamOptions,
StreamFunction,
StreamOptions,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { convertResponsesMessages, convertResponsesTools, processResponsesStream } from "./openai-responses-shared.js";
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
// ============================================================================
// Configuration
// ============================================================================
// ChatGPT backend used when the model does not carry its own baseUrl.
const DEFAULT_CODEX_BASE_URL = "https://chatgpt.com/backend-api";
// JWT claim key presumably carrying auth info (account id) in ChatGPT access
// tokens — TODO confirm against extractAccountId.
const JWT_CLAIM_PATH = "https://api.openai.com/auth" as const;
// Fetch retry policy: up to 3 retries with exponential backoff starting at 1s.
const MAX_RETRIES = 3;
const BASE_DELAY_MS = 1000;
// Providers whose historical tool calls are forwarded when converting messages.
const CODEX_TOOL_CALL_PROVIDERS = new Set(["openai", "openai-codex", "opencode"]);
// Response statuses the Codex backend is known to emit; consumed by
// normalizeCodexStatus when normalizing streamed events.
const CODEX_RESPONSE_STATUSES = new Set<CodexResponseStatus>([
  "completed",
  "incomplete",
  "failed",
  "cancelled",
  "queued",
  "in_progress",
]);
// ============================================================================
// Types
// ============================================================================
/** Codex-specific stream options layered on top of the common StreamOptions. */
export interface OpenAICodexResponsesOptions extends StreamOptions {
  reasoningEffort?: "none" | "minimal" | "low" | "medium" | "high" | "xhigh";
  reasoningSummary?: "auto" | "concise" | "detailed" | "off" | "on" | null;
  textVerbosity?: "low" | "medium" | "high";
}
type CodexResponseStatus = "completed" | "incomplete" | "failed" | "cancelled" | "queued" | "in_progress";
// Shape of the JSON body POSTed to /codex/responses. The index signature lets
// caller payload hooks (`onPayload`) attach extra fields.
interface RequestBody {
  model: string;
  store?: boolean;
  stream?: boolean;
  instructions?: string;
  input?: ResponseInput;
  tools?: OpenAITool[];
  tool_choice?: "auto";
  parallel_tool_calls?: boolean;
  temperature?: number;
  reasoning?: { effort?: string; summary?: string };
  text?: { verbosity?: string };
  include?: string[];
  prompt_cache_key?: string;
  [key: string]: unknown;
}
// ============================================================================
// Retry Helpers
// ============================================================================
/**
 * Whether a failed HTTP attempt should be retried: standard transient status
 * codes, or an error body mentioning rate limiting / upstream trouble.
 */
function isRetryableError(status: number, errorText: string): boolean {
  const transientStatuses = [429, 500, 502, 503, 504];
  if (transientStatuses.includes(status)) {
    return true;
  }
  const transientPattern = /rate.?limit|overloaded|service.?unavailable|upstream.?connect|connection.?refused/i;
  return transientPattern.test(errorText);
}
/**
 * Promise-based delay that honors an AbortSignal.
 *
 * Rejects immediately when the signal is already aborted; otherwise resolves
 * after `ms` or rejects on abort (clearing the timer).
 *
 * Fix: the abort listener is now registered with `{ once: true }` and removed
 * when the timer fires. The original left a listener attached per call, so
 * repeated retry sleeps against one long-lived signal accumulated listeners
 * (leak, plus a late reject() on an already-settled promise after abort).
 */
function sleep(ms: number, signal?: AbortSignal): Promise<void> {
  return new Promise((resolve, reject) => {
    if (signal?.aborted) {
      reject(new Error("Request was aborted"));
      return;
    }
    const onAbort = () => {
      clearTimeout(timeout);
      reject(new Error("Request was aborted"));
    };
    const timeout = setTimeout(() => {
      // Detach the abort listener so it cannot pile up on a shared signal.
      signal?.removeEventListener("abort", onAbort);
      resolve();
    }, ms);
    signal?.addEventListener("abort", onAbort, { once: true });
  });
}
// ============================================================================
// Main Stream Function
// ============================================================================
/**
 * Streams a request against the Codex responses endpoint.
 *
 * Transport: when `options.transport` is not "sse", a websocket attempt is
 * made first; if it fails before any websocket traffic started (and the
 * caller did not force "websocket"), the code falls back to the SSE path.
 * The SSE path retries transient HTTP failures with exponential backoff.
 * Progress is pushed onto the returned event stream; any failure ends the
 * stream with a single "error" event.
 */
export const streamOpenAICodexResponses: StreamFunction<"openai-codex-responses", OpenAICodexResponsesOptions> = (
  model: Model<"openai-codex-responses">,
  context: Context,
  options?: OpenAICodexResponsesOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();
  (async () => {
    // Accumulates the final assistant message as events arrive.
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: "openai-codex-responses" as Api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };
    try {
      const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
      if (!apiKey) {
        throw new Error(`No API key for provider: ${model.provider}`);
      }
      // Account id is extracted from the api key (extractAccountId is defined elsewhere in this file).
      const accountId = extractAccountId(apiKey);
      let body = buildRequestBody(model, context, options);
      // Caller may inspect/replace the outgoing body via the onPayload hook.
      const nextBody = await options?.onPayload?.(body, model);
      if (nextBody !== undefined) {
        body = nextBody as RequestBody;
      }
      const headers = buildHeaders(model.headers, options?.headers, accountId, apiKey, options?.sessionId);
      const bodyJson = JSON.stringify(body);
      const transport = options?.transport || "sse";
      if (transport !== "sse") {
        // Try websocket first; fall back to SSE only if the socket never got
        // going AND websocket wasn't explicitly requested.
        let websocketStarted = false;
        try {
          await processWebSocketStream(
            resolveCodexWebSocketUrl(model.baseUrl),
            body,
            headers,
            output,
            stream,
            model,
            () => {
              websocketStarted = true;
            },
            options,
          );
          if (options?.signal?.aborted) {
            throw new Error("Request was aborted");
          }
          stream.push({
            type: "done",
            reason: output.stopReason as "stop" | "length" | "toolUse",
            message: output,
          });
          stream.end();
          return;
        } catch (error) {
          if (transport === "websocket" || websocketStarted) {
            throw error;
          }
        }
      }
      // Fetch with retry logic for rate limits and transient errors
      let response: Response | undefined;
      let lastError: Error | undefined;
      for (let attempt = 0; attempt <= MAX_RETRIES; attempt++) {
        if (options?.signal?.aborted) {
          throw new Error("Request was aborted");
        }
        try {
          response = await fetch(resolveCodexUrl(model.baseUrl), {
            method: "POST",
            headers,
            body: bodyJson,
            signal: options?.signal,
          });
          if (response.ok) {
            break;
          }
          const errorText = await response.text();
          if (attempt < MAX_RETRIES && isRetryableError(response.status, errorText)) {
            // Exponential backoff: 1s, 2s, 4s, ...
            const delayMs = BASE_DELAY_MS * 2 ** attempt;
            await sleep(delayMs, options?.signal);
            continue;
          }
          // Parse error for friendly message on final attempt or non-retryable error
          const fakeResponse = new Response(errorText, {
            status: response.status,
            statusText: response.statusText,
          });
          const info = await parseErrorResponse(fakeResponse);
          throw new Error(info.friendlyMessage || info.message);
        } catch (error) {
          // Aborts are never retried.
          if (error instanceof Error) {
            if (error.name === "AbortError" || error.message === "Request was aborted") {
              throw new Error("Request was aborted");
            }
          }
          lastError = error instanceof Error ? error : new Error(String(error));
          // Network errors are retryable
          if (attempt < MAX_RETRIES && !lastError.message.includes("usage limit")) {
            const delayMs = BASE_DELAY_MS * 2 ** attempt;
            await sleep(delayMs, options?.signal);
            continue;
          }
          throw lastError;
        }
      }
      if (!response?.ok) {
        throw lastError ?? new Error("Failed after retries");
      }
      if (!response.body) {
        throw new Error("No response body");
      }
      stream.push({ type: "start", partial: output });
      await processStream(response, output, stream, model);
      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }
      stream.push({ type: "done", reason: output.stopReason as "stop" | "length" | "toolUse", message: output });
      stream.end();
    } catch (error) {
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage = error instanceof Error ? error.message : String(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();
  return stream;
};
export const streamSimpleOpenAICodexResponses: StreamFunction<"openai-codex-responses", SimpleStreamOptions> = (
model: Model<"openai-codex-responses">,
context: Context,
options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
if (!apiKey) {
throw new Error(`No API key for provider: ${model.provider}`);
}
const base = buildBaseOptions(model, options, apiKey);
const reasoningEffort = supportsXhigh(model) ? options?.reasoning : clampReasoning(options?.reasoning);
return streamOpenAICodexResponses(model, context, {
...base,
reasoningEffort,
} satisfies OpenAICodexResponsesOptions);
};
// ============================================================================
// Request Building
// ============================================================================
/**
 * Builds the Codex /responses request body. The system prompt travels in
 * `instructions` (not in the converted input messages); encrypted reasoning
 * content is always requested via `include`.
 */
function buildRequestBody(
  model: Model<"openai-codex-responses">,
  context: Context,
  options?: OpenAICodexResponsesOptions,
): RequestBody {
  const input = convertResponsesMessages(model, context, CODEX_TOOL_CALL_PROVIDERS, {
    includeSystemPrompt: false,
  });
  const body: RequestBody = {
    model: model.id,
    store: false,
    stream: true,
    instructions: context.systemPrompt,
    input,
    text: { verbosity: options?.textVerbosity || "medium" },
    include: ["reasoning.encrypted_content"],
    prompt_cache_key: options?.sessionId,
    tool_choice: "auto",
    parallel_tool_calls: true,
  };
  if (options?.temperature !== undefined) {
    body.temperature = options.temperature;
  }
  if (context.tools) {
    body.tools = convertResponsesTools(context.tools, { strict: null });
  }
  if (options?.reasoningEffort !== undefined) {
    body.reasoning = {
      // Effort is clamped per model family before it goes on the wire.
      effort: clampReasoningEffort(model.id, options.reasoningEffort),
      summary: options.reasoningSummary ?? "auto",
    };
  }
  return body;
}
/**
 * Per-model reasoning-effort clamping for Codex model ids (path prefixes such
 * as "openai/" are stripped first):
 * - gpt-5.1-codex-mini supports only "medium" and "high".
 * - gpt-5.1 caps "xhigh" at "high".
 * - gpt-5.2 / gpt-5.3 / gpt-5.4 bump "minimal" to "low".
 */
function clampReasoningEffort(modelId: string, effort: string): string {
  const bare = modelId.includes("/") ? modelId.split("/").pop()! : modelId;
  if (bare === "gpt-5.1-codex-mini") {
    return effort === "high" || effort === "xhigh" ? "high" : "medium";
  }
  if (bare === "gpt-5.1" && effort === "xhigh") {
    return "high";
  }
  const bumpsMinimal = ["gpt-5.2", "gpt-5.3", "gpt-5.4"].some((prefix) => bare.startsWith(prefix));
  if (bumpsMinimal && effort === "minimal") {
    return "low";
  }
  return effort;
}
/**
 * Normalizes a base URL into the full Codex responses endpoint.
 * Falls back to the default ChatGPT backend when no base URL is configured,
 * and tolerates base URLs already ending in "/codex" or "/codex/responses".
 */
function resolveCodexUrl(baseUrl?: string): string {
  const effective = baseUrl && baseUrl.trim().length > 0 ? baseUrl : DEFAULT_CODEX_BASE_URL;
  const trimmed = effective.replace(/\/+$/, "");
  if (trimmed.endsWith("/codex/responses")) return trimmed;
  return trimmed.endsWith("/codex") ? `${trimmed}/responses` : `${trimmed}/codex/responses`;
}
/** Resolve the Codex endpoint and swap the scheme to its WebSocket twin. */
function resolveCodexWebSocketUrl(baseUrl?: string): string {
	const wsUrl = new URL(resolveCodexUrl(baseUrl));
	switch (wsUrl.protocol) {
		case "https:":
			wsUrl.protocol = "wss:";
			break;
		case "http:":
			wsUrl.protocol = "ws:";
			break;
	}
	return wsUrl.toString();
}
// ============================================================================
// Response Processing
// ============================================================================
/**
 * Pipe a Codex SSE HTTP response through the shared Responses-stream
 * processor, translating Codex-specific events along the way
 * (see mapCodexEvents). Mutates `output` and pushes events onto `stream`.
 */
async function processStream(
	response: Response,
	output: AssistantMessage,
	stream: AssistantMessageEventStream,
	model: Model<"openai-codex-responses">,
): Promise<void> {
	await processResponsesStream(mapCodexEvents(parseSSE(response)), output, stream, model);
}
/**
 * Adapt raw Codex events to the standard Responses event shape.
 *
 * - `error` and `response.failed` events are converted into thrown Errors.
 * - `response.done` is re-emitted as `response.completed` with the status
 *   normalized to a known value (see normalizeCodexStatus).
 * - Events without a string `type` are dropped; everything else passes through.
 */
async function* mapCodexEvents(events: AsyncIterable<Record<string, unknown>>): AsyncGenerator<ResponseStreamEvent> {
	for await (const raw of events) {
		const type = typeof raw.type === "string" ? raw.type : undefined;
		if (!type) continue;
		switch (type) {
			case "error": {
				const code = (raw as { code?: string }).code || "";
				const message = (raw as { message?: string }).message || "";
				throw new Error(`Codex error: ${message || code || JSON.stringify(raw)}`);
			}
			case "response.failed": {
				const failure = (raw as { response?: { error?: { message?: string } } }).response?.error?.message;
				throw new Error(failure || "Codex response failed");
			}
			case "response.done":
			case "response.completed": {
				const response = (raw as { response?: { status?: unknown } }).response;
				const normalized = response ? { ...response, status: normalizeCodexStatus(response.status) } : response;
				yield { ...raw, type: "response.completed", response: normalized } as ResponseStreamEvent;
				break;
			}
			default:
				yield raw as unknown as ResponseStreamEvent;
		}
	}
}
/** Collapse an unknown status value to a known CodexResponseStatus, else undefined. */
function normalizeCodexStatus(status: unknown): CodexResponseStatus | undefined {
	if (typeof status !== "string") return undefined;
	const candidate = status as CodexResponseStatus;
	return CODEX_RESPONSE_STATUSES.has(candidate) ? candidate : undefined;
}
// ============================================================================
// SSE Parsing
// ============================================================================
/**
 * Incrementally parse a Server-Sent-Events body into JSON payloads.
 *
 * Events are delimited by a blank line ("\n\n"); only `data:` fields are
 * honored, multi-line data is rejoined with "\n", the `[DONE]` sentinel and
 * unparsable payloads are silently dropped.
 */
async function* parseSSE(response: Response): AsyncGenerator<Record<string, unknown>> {
	const body = response.body;
	if (!body) return;
	const reader = body.getReader();
	const decoder = new TextDecoder();
	let pending = "";
	for (;;) {
		const { done, value } = await reader.read();
		if (done) break;
		pending += decoder.decode(value, { stream: true });
		for (let boundary = pending.indexOf("\n\n"); boundary !== -1; boundary = pending.indexOf("\n\n")) {
			const rawEvent = pending.slice(0, boundary);
			pending = pending.slice(boundary + 2);
			const payload = rawEvent
				.split("\n")
				.filter((line) => line.startsWith("data:"))
				.map((line) => line.slice(5).trim())
				.join("\n")
				.trim();
			if (!payload || payload === "[DONE]") continue;
			try {
				yield JSON.parse(payload);
			} catch {
				// Malformed JSON frame: skip it and keep reading the stream.
			}
		}
	}
}
// ============================================================================
// WebSocket Parsing
// ============================================================================
// Header value opting in to the Responses-over-WebSocket beta protocol.
const OPENAI_BETA_RESPONSES_WEBSOCKETS = "responses_websockets=2026-02-06";
// How long an idle per-session socket is kept before being closed (5 minutes).
const SESSION_WEBSOCKET_CACHE_TTL_MS = 5 * 60 * 1000;
type WebSocketEventType = "open" | "message" | "error" | "close";
type WebSocketListener = (event: unknown) => void;
// Minimal structural view of a WebSocket so browser and Node sockets both fit.
interface WebSocketLike {
	close(code?: number, reason?: string): void;
	send(data: string): void;
	addEventListener(type: WebSocketEventType, listener: WebSocketListener): void;
	removeEventListener(type: WebSocketEventType, listener: WebSocketListener): void;
}
// Cache entry for a session-scoped socket; `busy` guards against concurrent
// reuse, `idleTimer` drives TTL expiry (see scheduleSessionWebSocketExpiry).
interface CachedWebSocketConnection {
	socket: WebSocketLike;
	busy: boolean;
	idleTimer?: ReturnType<typeof setTimeout>;
}
// sessionId -> pooled connection, populated/cleared by acquireWebSocket.
const websocketSessionCache = new Map<string, CachedWebSocketConnection>();
// Constructor shape; the options-object form carries custom headers
// (non-standard extension accepted by some runtimes).
type WebSocketConstructor = new (
	url: string,
	protocols?: string | string[] | { headers?: Record<string, string> },
) => WebSocketLike;
/** Return the runtime's global WebSocket constructor, or null when absent. */
function getWebSocketConstructor(): WebSocketConstructor | null {
	const candidate = (globalThis as { WebSocket?: unknown }).WebSocket;
	return typeof candidate === "function" ? (candidate as unknown as WebSocketConstructor) : null;
}
/** Flatten a Headers iterable into a plain record (later keys overwrite earlier ones). */
function headersToRecord(headers: Headers): Record<string, string> {
	return Object.fromEntries(headers.entries());
}
/** Read the socket's numeric readyState; undefined when the runtime doesn't expose one. */
function getWebSocketReadyState(socket: WebSocketLike): number | undefined {
	const state = (socket as { readyState?: unknown }).readyState;
	if (typeof state === "number") return state;
	return undefined;
}
/**
 * A socket is reusable when its readyState is OPEN (1). Runtimes that don't
 * expose a numeric readyState are treated optimistically as open.
 */
function isWebSocketReusable(socket: WebSocketLike): boolean {
	const state = (socket as { readyState?: unknown }).readyState;
	return typeof state !== "number" || state === 1;
}
/** Best-effort close: a socket that is already closed/closing may throw here. */
function closeWebSocketSilently(socket: WebSocketLike, code = 1000, reason = "done"): void {
	try {
		socket.close(code, reason);
	} catch {
		/* intentionally swallowed */
	}
}
/**
 * (Re)start the idle-expiry countdown for a cached session socket. When the
 * timer fires and the entry is not busy, the socket is closed and evicted;
 * a busy entry is left untouched (its release path reschedules expiry).
 */
function scheduleSessionWebSocketExpiry(sessionId: string, entry: CachedWebSocketConnection): void {
	if (entry.idleTimer !== undefined) {
		clearTimeout(entry.idleTimer);
	}
	const expire = () => {
		if (entry.busy) return;
		closeWebSocketSilently(entry.socket, 1000, "idle_timeout");
		websocketSessionCache.delete(sessionId);
	};
	entry.idleTimer = setTimeout(expire, SESSION_WEBSOCKET_CACHE_TTL_MS);
}
/**
 * Open a WebSocket to `url` with the given headers and resolve once the
 * connection is established.
 *
 * Settlement is single-shot: whichever of open/error/close/abort fires first
 * wins, and all listeners are detached at that point. An abort additionally
 * closes the socket before rejecting.
 *
 * Throws when the runtime exposes no global WebSocket constructor.
 */
async function connectWebSocket(url: string, headers: Headers, signal?: AbortSignal): Promise<WebSocketLike> {
	const WebSocketCtor = getWebSocketConstructor();
	if (!WebSocketCtor) {
		throw new Error("WebSocket transport is not available in this runtime");
	}
	const wsHeaders = headersToRecord(headers);
	// Opt in to the Responses-over-WebSocket beta protocol.
	wsHeaders["OpenAI-Beta"] = OPENAI_BETA_RESPONSES_WEBSOCKETS;
	return new Promise<WebSocketLike>((resolve, reject) => {
		let settled = false;
		let socket: WebSocketLike;
		try {
			// Options-object form carries custom headers (see WebSocketConstructor).
			socket = new WebSocketCtor(url, { headers: wsHeaders });
		} catch (error) {
			reject(error instanceof Error ? error : new Error(String(error)));
			return;
		}
		const onOpen: WebSocketListener = () => {
			if (settled) return;
			settled = true;
			cleanup();
			resolve(socket);
		};
		const onError: WebSocketListener = (event) => {
			if (settled) return;
			settled = true;
			cleanup();
			reject(extractWebSocketError(event));
		};
		const onClose: WebSocketListener = (event) => {
			if (settled) return;
			settled = true;
			cleanup();
			reject(extractWebSocketCloseError(event));
		};
		const onAbort = () => {
			if (settled) return;
			settled = true;
			cleanup();
			socket.close(1000, "aborted");
			reject(new Error("Request was aborted"));
		};
		// Detach everything once settled so the socket can live on (or be GC'd).
		const cleanup = () => {
			socket.removeEventListener("open", onOpen);
			socket.removeEventListener("error", onError);
			socket.removeEventListener("close", onClose);
			signal?.removeEventListener("abort", onAbort);
		};
		socket.addEventListener("open", onOpen);
		socket.addEventListener("error", onError);
		socket.addEventListener("close", onClose);
		signal?.addEventListener("abort", onAbort);
	});
}
/**
 * Get a WebSocket for a request, reusing a cached per-session connection when
 * possible.
 *
 * The returned `release({ keep })` behaves as follows:
 * - keep === true and the socket is still open: the connection goes back into
 *   the session cache and its idle-expiry timer is (re)started.
 * - otherwise: the socket is closed and evicted from the cache.
 * Without a sessionId the socket is never cached and release always closes.
 */
async function acquireWebSocket(
	url: string,
	headers: Headers,
	sessionId: string | undefined,
	signal?: AbortSignal,
): Promise<{ socket: WebSocketLike; release: (options?: { keep?: boolean }) => void }> {
	if (!sessionId) {
		const socket = await connectWebSocket(url, headers, signal);
		return {
			socket,
			// No session to cache under: `keep` is irrelevant, always close.
			release: ({ keep } = {}) => {
				if (keep === false) {
					closeWebSocketSilently(socket);
					return;
				}
				closeWebSocketSilently(socket);
			},
		};
	}
	const cached = websocketSessionCache.get(sessionId);
	if (cached) {
		// Pause expiry while we decide whether the cached socket is usable.
		if (cached.idleTimer) {
			clearTimeout(cached.idleTimer);
			cached.idleTimer = undefined;
		}
		if (!cached.busy && isWebSocketReusable(cached.socket)) {
			cached.busy = true;
			return {
				socket: cached.socket,
				release: ({ keep } = {}) => {
					if (!keep || !isWebSocketReusable(cached.socket)) {
						closeWebSocketSilently(cached.socket);
						websocketSessionCache.delete(sessionId);
						return;
					}
					cached.busy = false;
					scheduleSessionWebSocketExpiry(sessionId, cached);
				},
			};
		}
		if (cached.busy) {
			// Session socket is in use: open a throwaway connection instead.
			const socket = await connectWebSocket(url, headers, signal);
			return {
				socket,
				release: () => {
					closeWebSocketSilently(socket);
				},
			};
		}
		// Cached socket is idle but dead: drop it and fall through to reconnect.
		if (!isWebSocketReusable(cached.socket)) {
			closeWebSocketSilently(cached.socket);
			websocketSessionCache.delete(sessionId);
		}
	}
	// No usable cached socket: open a fresh one and cache it as busy.
	const socket = await connectWebSocket(url, headers, signal);
	const entry: CachedWebSocketConnection = { socket, busy: true };
	websocketSessionCache.set(sessionId, entry);
	return {
		socket,
		release: ({ keep } = {}) => {
			if (!keep || !isWebSocketReusable(entry.socket)) {
				closeWebSocketSilently(entry.socket);
				if (entry.idleTimer) clearTimeout(entry.idleTimer);
				// Guard against a newer entry having replaced this one.
				if (websocketSessionCache.get(sessionId) === entry) {
					websocketSessionCache.delete(sessionId);
				}
				return;
			}
			entry.busy = false;
			scheduleSessionWebSocketExpiry(sessionId, entry);
		},
	};
}
/** Derive an Error from an untyped WebSocket error event, with a generic fallback. */
function extractWebSocketError(event: unknown): Error {
	const message =
		event && typeof event === "object" && "message" in event
			? (event as { message?: unknown }).message
			: undefined;
	return typeof message === "string" && message !== "" ? new Error(message) : new Error("WebSocket error");
}
/** Derive an Error from a close event, including numeric code and reason when present. */
function extractWebSocketCloseError(event: unknown): Error {
	if (!event || typeof event !== "object") return new Error("WebSocket closed");
	const { code, reason } = event as { code?: unknown; reason?: unknown };
	const parts = ["WebSocket closed"];
	if (typeof code === "number") parts.push(String(code));
	if (typeof reason === "string" && reason.length > 0) parts.push(reason);
	return new Error(parts.join(" ").trim());
}
/**
 * Decode a WebSocket payload to text. Handles strings, ArrayBuffers, typed
 * array views, and Blob-like objects exposing `arrayBuffer()`. Returns null
 * for anything else.
 */
async function decodeWebSocketData(data: unknown): Promise<string | null> {
	if (typeof data === "string") return data;
	const decoder = new TextDecoder();
	if (data instanceof ArrayBuffer) {
		return decoder.decode(new Uint8Array(data));
	}
	if (ArrayBuffer.isView(data)) {
		// Respect the view's window into its backing buffer.
		return decoder.decode(new Uint8Array(data.buffer, data.byteOffset, data.byteLength));
	}
	if (data && typeof data === "object" && "arrayBuffer" in data) {
		const bytes = await (data as { arrayBuffer: () => Promise<ArrayBuffer> }).arrayBuffer();
		return decoder.decode(new Uint8Array(bytes));
	}
	return null;
}
/**
 * Turn WebSocket messages into an async stream of parsed JSON events.
 *
 * Incoming messages are buffered in `queue`; the generator drains it, parking
 * on a promise (`pending`) when empty, and `wake` resumes it whenever new
 * data, an error, a close, or an abort arrives. The stream ends successfully
 * only after a `response.completed` / `response.done` event was seen; an
 * error event, an abort, or a close before completion surfaces as a thrown
 * Error once the queue is drained. Listeners are always detached on exit.
 */
async function* parseWebSocket(socket: WebSocketLike, signal?: AbortSignal): AsyncGenerator<Record<string, unknown>> {
	const queue: Record<string, unknown>[] = [];
	let pending: (() => void) | null = null;
	let done = false;
	let failed: Error | null = null;
	let sawCompletion = false;
	// Resolve the parked consumer (if any) exactly once per wake-up.
	const wake = () => {
		if (!pending) return;
		const resolve = pending;
		pending = null;
		resolve();
	};
	const onMessage: WebSocketListener = (event) => {
		// Decoding may be async (Blob-like payloads), hence the inner IIFE.
		void (async () => {
			if (!event || typeof event !== "object" || !("data" in event)) return;
			const text = await decodeWebSocketData((event as { data?: unknown }).data);
			if (!text) return;
			try {
				const parsed = JSON.parse(text) as Record<string, unknown>;
				const type = typeof parsed.type === "string" ? parsed.type : "";
				if (type === "response.completed" || type === "response.done") {
					sawCompletion = true;
					done = true;
				}
				queue.push(parsed);
				wake();
			} catch {}
		})();
	};
	const onError: WebSocketListener = (event) => {
		failed = extractWebSocketError(event);
		done = true;
		wake();
	};
	const onClose: WebSocketListener = (event) => {
		// A close after completion is the normal end of stream, not a failure.
		if (sawCompletion) {
			done = true;
			wake();
			return;
		}
		if (!failed) {
			failed = extractWebSocketCloseError(event);
		}
		done = true;
		wake();
	};
	const onAbort = () => {
		failed = new Error("Request was aborted");
		done = true;
		wake();
	};
	socket.addEventListener("message", onMessage);
	socket.addEventListener("error", onError);
	socket.addEventListener("close", onClose);
	signal?.addEventListener("abort", onAbort);
	try {
		while (true) {
			if (signal?.aborted) {
				throw new Error("Request was aborted");
			}
			// Drain buffered events before checking for termination.
			if (queue.length > 0) {
				yield queue.shift()!;
				continue;
			}
			if (done) break;
			await new Promise<void>((resolve) => {
				pending = resolve;
			});
		}
		if (failed) {
			throw failed;
		}
		if (!sawCompletion) {
			throw new Error("WebSocket stream closed before response.completed");
		}
	} finally {
		socket.removeEventListener("message", onMessage);
		socket.removeEventListener("error", onError);
		socket.removeEventListener("close", onClose);
		signal?.removeEventListener("abort", onAbort);
	}
}
/**
 * Execute one Responses request over a (possibly session-cached) WebSocket
 * and feed the resulting events through the shared Responses-stream
 * processor. The connection is returned to the cache on success; on error or
 * abort it is released with keep=false so it gets closed instead.
 */
async function processWebSocketStream(
	url: string,
	body: RequestBody,
	headers: Headers,
	output: AssistantMessage,
	stream: AssistantMessageEventStream,
	model: Model<"openai-codex-responses">,
	onStart: () => void,
	options?: OpenAICodexResponsesOptions,
): Promise<void> {
	const { socket, release } = await acquireWebSocket(url, headers, options?.sessionId, options?.signal);
	let keepConnection = true;
	try {
		socket.send(JSON.stringify({ type: "response.create", ...body }));
		onStart();
		stream.push({ type: "start", partial: output });
		await processResponsesStream(mapCodexEvents(parseWebSocket(socket, options?.signal)), output, stream, model);
		// An aborted request leaves the socket mid-stream; don't reuse it.
		if (options?.signal?.aborted) {
			keepConnection = false;
		}
	} catch (error) {
		keepConnection = false;
		throw error;
	} finally {
		release({ keep: keepConnection });
	}
}
// ============================================================================
// Error Handling
// ============================================================================
/**
 * Turn a failed HTTP response into an error message, plus an optional
 * human-friendly variant for ChatGPT usage-limit / rate-limit errors
 * (including plan name and minutes until the limit resets when available).
 */
async function parseErrorResponse(response: Response): Promise<{ message: string; friendlyMessage?: string }> {
	const raw = await response.text();
	let message = raw || response.statusText || "Request failed";
	let friendlyMessage: string | undefined;
	try {
		const parsed = JSON.parse(raw) as {
			error?: { code?: string; type?: string; message?: string; plan_type?: string; resets_at?: number };
		};
		const err = parsed?.error;
		if (err) {
			const code = err.code || err.type || "";
			const isLimit =
				response.status === 429 || /usage_limit_reached|usage_not_included|rate_limit_exceeded/i.test(code);
			if (isLimit) {
				const plan = err.plan_type ? ` (${err.plan_type.toLowerCase()} plan)` : "";
				let when = "";
				if (err.resets_at) {
					const mins = Math.max(0, Math.round((err.resets_at * 1000 - Date.now()) / 60000));
					when = ` Try again in ~${mins} min.`;
				}
				friendlyMessage = `You have hit your ChatGPT usage limit${plan}.${when}`.trim();
			}
			message = err.message || friendlyMessage || message;
		}
	} catch {
		// Body was not JSON: keep the raw text / statusText fallback.
	}
	return { message, friendlyMessage };
}
// ============================================================================
// Auth & Headers
// ============================================================================
/**
 * Pull the ChatGPT account id out of a JWT access token's claims.
 * Any malformed token or missing claim collapses into a single generic error.
 */
function extractAccountId(token: string): string {
	try {
		const segments = token.split(".");
		if (segments.length !== 3) throw new Error("Invalid token");
		// Decode the base64 payload segment; no signature verification here.
		const claims = JSON.parse(atob(segments[1]));
		const accountId = claims?.[JWT_CLAIM_PATH]?.chatgpt_account_id;
		if (accountId) return accountId;
		throw new Error("No account ID in token");
	} catch {
		throw new Error("Failed to extract accountId from token");
	}
}
/**
 * Assemble request headers for Codex calls: bearer auth, account routing,
 * protocol opt-ins, and a platform-identifying User-Agent.
 * Caller-supplied `additionalHeaders` override the defaults; `session_id`
 * (when given) is applied last.
 */
function buildHeaders(
	initHeaders: Record<string, string> | undefined,
	additionalHeaders: Record<string, string> | undefined,
	accountId: string,
	token: string,
	sessionId?: string,
): Headers {
	const headers = new Headers(initHeaders);
	const userAgent = _os ? `pi (${_os.platform()} ${_os.release()}; ${_os.arch()})` : "pi (browser)";
	const defaults: Record<string, string> = {
		Authorization: `Bearer ${token}`,
		"chatgpt-account-id": accountId,
		"OpenAI-Beta": "responses=experimental",
		originator: "pi",
		"User-Agent": userAgent,
		accept: "text/event-stream",
		"content-type": "application/json",
	};
	for (const [name, value] of Object.entries(defaults)) {
		headers.set(name, value);
	}
	for (const [name, value] of Object.entries(additionalHeaders || {})) {
		headers.set(name, value);
	}
	if (sessionId) {
		headers.set("session_id", sessionId);
	}
	return headers;
}

View file

@ -0,0 +1,820 @@
import OpenAI from "openai";
import type {
ChatCompletionAssistantMessageParam,
ChatCompletionChunk,
ChatCompletionContentPart,
ChatCompletionContentPartImage,
ChatCompletionContentPartText,
ChatCompletionMessageParam,
ChatCompletionToolMessageParam,
} from "openai/resources/chat/completions.js";
import { getEnvApiKey } from "../env-api-keys.js";
import { calculateCost, supportsXhigh } from "../models.js";
import type {
AssistantMessage,
Context,
Message,
Model,
OpenAICompletionsCompat,
SimpleStreamOptions,
StopReason,
StreamFunction,
StreamOptions,
TextContent,
ThinkingContent,
Tool,
ToolCall,
ToolResultMessage,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { parseStreamingJson } from "../utils/json-parse.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import { buildCopilotDynamicHeaders, hasCopilotVisionInput } from "./github-copilot-headers.js";
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
import { transformMessages } from "./transform-messages.js";
/**
 * Does the conversation already contain tool activity (tool calls or tool
 * results)? Anthropic behind a proxy requires the `tools` param whenever the
 * history includes tool_calls or tool-role messages, so callers use this to
 * decide whether to send an (empty) tools list.
 */
function hasToolHistory(messages: Message[]): boolean {
	return messages.some(
		(msg) =>
			msg.role === "toolResult" ||
			(msg.role === "assistant" && msg.content.some((block) => block.type === "toolCall")),
	);
}
/** Options accepted by the OpenAI chat-completions stream functions. */
export interface OpenAICompletionsOptions extends StreamOptions {
	// Forwarded as `tool_choice`: auto/none/required, or pin a specific function.
	toolChoice?: "auto" | "none" | "required" | { type: "function"; function: { name: string } };
	// Requested reasoning effort; mapped/clamped per provider compat settings.
	reasoningEffort?: "minimal" | "low" | "medium" | "high" | "xhigh";
}
/**
 * Stream a chat completion from an OpenAI-compatible endpoint.
 *
 * Returns an event stream that emits start/delta/end events for text,
 * thinking, and tool-call content blocks while accumulating the complete
 * assistant message (content, usage, and cost) in `partial`. On failure the
 * stream emits a single `error` event with stopReason "aborted" or "error"
 * and the error message attached to the message.
 */
export const streamOpenAICompletions: StreamFunction<"openai-completions", OpenAICompletionsOptions> = (
	model: Model<"openai-completions">,
	context: Context,
	options?: OpenAICompletionsOptions,
): AssistantMessageEventStream => {
	const stream = new AssistantMessageEventStream();
	(async () => {
		// Accumulator for the final assistant message; mutated as deltas arrive.
		const output: AssistantMessage = {
			role: "assistant",
			content: [],
			api: model.api,
			provider: model.provider,
			model: model.id,
			usage: {
				input: 0,
				output: 0,
				cacheRead: 0,
				cacheWrite: 0,
				totalTokens: 0,
				cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
			},
			stopReason: "stop",
			timestamp: Date.now(),
		};
		try {
			const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
			const client = createClient(model, context, apiKey, options?.headers);
			let params = buildParams(model, context, options);
			// Give the caller a chance to rewrite the outgoing payload.
			const nextParams = await options?.onPayload?.(params, model);
			if (nextParams !== undefined) {
				params = nextParams as OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming;
			}
			const openaiStream = await client.chat.completions.create(params, { signal: options?.signal });
			stream.push({ type: "start", partial: output });
			// Exactly one block is "open" at a time; switching kinds closes it.
			let currentBlock: TextContent | ThinkingContent | (ToolCall & { partialArgs?: string }) | null = null;
			const blocks = output.content;
			const blockIndex = () => blocks.length - 1;
			// Close out the in-progress block and emit its *_end event.
			const finishCurrentBlock = (block?: typeof currentBlock) => {
				if (block) {
					if (block.type === "text") {
						stream.push({
							type: "text_end",
							contentIndex: blockIndex(),
							content: block.text,
							partial: output,
						});
					} else if (block.type === "thinking") {
						stream.push({
							type: "thinking_end",
							contentIndex: blockIndex(),
							content: block.thinking,
							partial: output,
						});
					} else if (block.type === "toolCall") {
						block.arguments = parseStreamingJson(block.partialArgs);
						delete block.partialArgs;
						stream.push({
							type: "toolcall_end",
							contentIndex: blockIndex(),
							toolCall: block,
							partial: output,
						});
					}
				}
			};
			// Main delta loop: each chunk may carry usage, a finish reason, text,
			// reasoning, and/or tool-call fragments.
			for await (const chunk of openaiStream) {
				if (chunk.usage) {
					const cachedTokens = chunk.usage.prompt_tokens_details?.cached_tokens || 0;
					const reasoningTokens = chunk.usage.completion_tokens_details?.reasoning_tokens || 0;
					const input = (chunk.usage.prompt_tokens || 0) - cachedTokens;
					const outputTokens = (chunk.usage.completion_tokens || 0) + reasoningTokens;
					output.usage = {
						// OpenAI includes cached tokens in prompt_tokens, so subtract to get non-cached input
						input,
						output: outputTokens,
						cacheRead: cachedTokens,
						cacheWrite: 0,
						// Compute totalTokens ourselves since we add reasoning_tokens to output
						// and some providers (e.g., Groq) don't include them in total_tokens
						totalTokens: input + outputTokens + cachedTokens,
						cost: {
							input: 0,
							output: 0,
							cacheRead: 0,
							cacheWrite: 0,
							total: 0,
						},
					};
					calculateCost(model, output.usage);
				}
				const choice = chunk.choices?.[0];
				if (!choice) continue;
				if (choice.finish_reason) {
					output.stopReason = mapStopReason(choice.finish_reason);
				}
				if (choice.delta) {
					if (
						choice.delta.content !== null &&
						choice.delta.content !== undefined &&
						choice.delta.content.length > 0
					) {
						if (!currentBlock || currentBlock.type !== "text") {
							finishCurrentBlock(currentBlock);
							currentBlock = { type: "text", text: "" };
							output.content.push(currentBlock);
							stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
						}
						if (currentBlock.type === "text") {
							currentBlock.text += choice.delta.content;
							stream.push({
								type: "text_delta",
								contentIndex: blockIndex(),
								delta: choice.delta.content,
								partial: output,
							});
						}
					}
					// Some endpoints return reasoning in reasoning_content (llama.cpp),
					// or reasoning (other openai compatible endpoints)
					// Use the first non-empty reasoning field to avoid duplication
					// (e.g., chutes.ai returns both reasoning_content and reasoning with same content)
					const reasoningFields = ["reasoning_content", "reasoning", "reasoning_text"];
					let foundReasoningField: string | null = null;
					for (const field of reasoningFields) {
						if (
							(choice.delta as any)[field] !== null &&
							(choice.delta as any)[field] !== undefined &&
							(choice.delta as any)[field].length > 0
						) {
							if (!foundReasoningField) {
								foundReasoningField = field;
								break;
							}
						}
					}
					if (foundReasoningField) {
						if (!currentBlock || currentBlock.type !== "thinking") {
							finishCurrentBlock(currentBlock);
							currentBlock = {
								type: "thinking",
								thinking: "",
								thinkingSignature: foundReasoningField,
							};
							output.content.push(currentBlock);
							stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output });
						}
						if (currentBlock.type === "thinking") {
							const delta = (choice.delta as any)[foundReasoningField];
							currentBlock.thinking += delta;
							stream.push({
								type: "thinking_delta",
								contentIndex: blockIndex(),
								delta,
								partial: output,
							});
						}
					}
					if (choice?.delta?.tool_calls) {
						for (const toolCall of choice.delta.tool_calls) {
							// A new id (or a non-toolCall current block) starts a new call;
							// otherwise the fragment extends the in-progress call.
							if (
								!currentBlock ||
								currentBlock.type !== "toolCall" ||
								(toolCall.id && currentBlock.id !== toolCall.id)
							) {
								finishCurrentBlock(currentBlock);
								currentBlock = {
									type: "toolCall",
									id: toolCall.id || "",
									name: toolCall.function?.name || "",
									arguments: {},
									partialArgs: "",
								};
								output.content.push(currentBlock);
								stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output });
							}
							if (currentBlock.type === "toolCall") {
								if (toolCall.id) currentBlock.id = toolCall.id;
								if (toolCall.function?.name) currentBlock.name = toolCall.function.name;
								let delta = "";
								if (toolCall.function?.arguments) {
									delta = toolCall.function.arguments;
									currentBlock.partialArgs += toolCall.function.arguments;
									currentBlock.arguments = parseStreamingJson(currentBlock.partialArgs);
								}
								stream.push({
									type: "toolcall_delta",
									contentIndex: blockIndex(),
									delta,
									partial: output,
								});
							}
						}
					}
					const reasoningDetails = (choice.delta as any).reasoning_details;
					if (reasoningDetails && Array.isArray(reasoningDetails)) {
						for (const detail of reasoningDetails) {
							if (detail.type === "reasoning.encrypted" && detail.id && detail.data) {
								const matchingToolCall = output.content.find(
									(b) => b.type === "toolCall" && b.id === detail.id,
								) as ToolCall | undefined;
								if (matchingToolCall) {
									matchingToolCall.thoughtSignature = JSON.stringify(detail);
								}
							}
						}
					}
				}
			}
			finishCurrentBlock(currentBlock);
			if (options?.signal?.aborted) {
				throw new Error("Request was aborted");
			}
			if (output.stopReason === "aborted" || output.stopReason === "error") {
				throw new Error("An unknown error occurred");
			}
			stream.push({ type: "done", reason: output.stopReason, message: output });
			stream.end();
		} catch (error) {
			for (const block of output.content) delete (block as any).index;
			output.stopReason = options?.signal?.aborted ? "aborted" : "error";
			output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
			// Some providers via OpenRouter give additional information in this field.
			const rawMetadata = (error as any)?.error?.metadata?.raw;
			if (rawMetadata) output.errorMessage += `\n${rawMetadata}`;
			stream.push({ type: "error", reason: output.stopReason, error: output });
			stream.end();
		}
	})();
	return stream;
};
/**
 * Simplified entry point: resolves the API key from options or environment,
 * derives the base streaming options, clamps reasoning effort for models
 * without xhigh support, and delegates to streamOpenAICompletions.
 */
export const streamSimpleOpenAICompletions: StreamFunction<"openai-completions", SimpleStreamOptions> = (
	model: Model<"openai-completions">,
	context: Context,
	options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
	const apiKey = options?.apiKey || getEnvApiKey(model.provider);
	if (!apiKey) {
		throw new Error(`No API key for provider: ${model.provider}`);
	}
	const base = buildBaseOptions(model, options, apiKey);
	const reasoningEffort = supportsXhigh(model) ? options?.reasoning : clampReasoning(options?.reasoning);
	return streamOpenAICompletions(model, context, {
		...base,
		reasoningEffort,
		toolChoice: (options as OpenAICompletionsOptions | undefined)?.toolChoice,
	} satisfies OpenAICompletionsOptions);
};
/**
 * Construct the OpenAI SDK client for a completions call.
 *
 * Falls back to OPENAI_API_KEY from the environment when no key is given.
 * Headers are layered: model defaults, then GitHub Copilot dynamic headers
 * (when applicable), then caller-supplied headers, which win.
 */
function createClient(
	model: Model<"openai-completions">,
	context: Context,
	apiKey?: string,
	optionsHeaders?: Record<string, string>,
) {
	if (!apiKey) {
		if (!process.env.OPENAI_API_KEY) {
			throw new Error(
				"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
			);
		}
		apiKey = process.env.OPENAI_API_KEY;
	}
	const headers: Record<string, string> = { ...model.headers };
	if (model.provider === "github-copilot") {
		Object.assign(
			headers,
			buildCopilotDynamicHeaders({
				messages: context.messages,
				hasImages: hasCopilotVisionInput(context.messages),
			}),
		);
	}
	// Caller-supplied headers are merged last so they override the defaults.
	Object.assign(headers, optionsHeaders || {});
	return new OpenAI({
		apiKey,
		baseURL: model.baseUrl,
		dangerouslyAllowBrowser: true,
		defaultHeaders: headers,
	});
}
/**
 * Translate model/context/options into OpenAI chat.completions request
 * params, applying per-provider compatibility quirks: max-tokens field name,
 * reasoning-effort style, streaming-usage support, `store` opt-out, and
 * OpenRouter / Vercel AI Gateway routing preferences.
 */
function buildParams(model: Model<"openai-completions">, context: Context, options?: OpenAICompletionsOptions) {
	const compat = getCompat(model);
	const messages = convertMessages(model, context, compat);
	maybeAddOpenRouterAnthropicCacheControl(model, messages);
	const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
		model: model.id,
		messages,
		stream: true,
	};
	// Ask for usage in the final stream chunk when the provider supports it.
	if (compat.supportsUsageInStreaming !== false) {
		(params as any).stream_options = { include_usage: true };
	}
	// Opt out of server-side storage where the API accepts `store`.
	if (compat.supportsStore) {
		params.store = false;
	}
	// Providers disagree on the max-tokens field name.
	if (options?.maxTokens) {
		if (compat.maxTokensField === "max_tokens") {
			(params as any).max_tokens = options.maxTokens;
		} else {
			params.max_completion_tokens = options.maxTokens;
		}
	}
	if (options?.temperature !== undefined) {
		params.temperature = options.temperature;
	}
	if (context.tools) {
		params.tools = convertTools(context.tools, compat);
	} else if (hasToolHistory(context.messages)) {
		// Anthropic (via LiteLLM/proxy) requires tools param when conversation has tool_calls/tool_results
		params.tools = [];
	}
	if (options?.toolChoice) {
		params.tool_choice = options.toolChoice;
	}
	if ((compat.thinkingFormat === "zai" || compat.thinkingFormat === "qwen") && model.reasoning) {
		// Both Z.ai and Qwen use enable_thinking: boolean
		(params as any).enable_thinking = !!options?.reasoningEffort;
	} else if (options?.reasoningEffort && model.reasoning && compat.supportsReasoningEffort) {
		// OpenAI-style reasoning_effort
		(params as any).reasoning_effort = mapReasoningEffort(options.reasoningEffort, compat.reasoningEffortMap);
	}
	// OpenRouter provider routing preferences
	if (model.baseUrl.includes("openrouter.ai") && model.compat?.openRouterRouting) {
		(params as any).provider = model.compat.openRouterRouting;
	}
	// Vercel AI Gateway provider routing preferences
	if (model.baseUrl.includes("ai-gateway.vercel.sh") && model.compat?.vercelGatewayRouting) {
		const routing = model.compat.vercelGatewayRouting;
		if (routing.only || routing.order) {
			const gatewayOptions: Record<string, string[]> = {};
			if (routing.only) gatewayOptions.only = routing.only;
			if (routing.order) gatewayOptions.order = routing.order;
			(params as any).providerOptions = { gateway: gatewayOptions };
		}
	}
	return params;
}
/** Apply a provider-specific effort override when one exists; otherwise pass through. */
function mapReasoningEffort(
	effort: NonNullable<OpenAICompletionsOptions["reasoningEffort"]>,
	reasoningEffortMap: Partial<Record<NonNullable<OpenAICompletionsOptions["reasoningEffort"]>, string>>,
): string {
	const mapped = reasoningEffortMap[effort];
	return mapped == null ? effort : mapped;
}
/**
 * For Anthropic models routed through OpenRouter, Anthropic-style prompt
 * caching requires a `cache_control` marker on a text part. Walk the
 * messages from the end and tag the last user/assistant text found; a string
 * body is promoted to a single tagged text part. Mutates `messages` in place
 * and stops after the first breakpoint.
 */
function maybeAddOpenRouterAnthropicCacheControl(
	model: Model<"openai-completions">,
	messages: ChatCompletionMessageParam[],
): void {
	const applies = model.provider === "openrouter" && model.id.startsWith("anthropic/");
	if (!applies) return;
	for (let i = messages.length - 1; i >= 0; i--) {
		const candidate = messages[i];
		if (candidate.role !== "user" && candidate.role !== "assistant") continue;
		const body = candidate.content;
		if (typeof body === "string") {
			candidate.content = [
				Object.assign({ type: "text" as const, text: body }, { cache_control: { type: "ephemeral" } }),
			];
			return;
		}
		if (!Array.isArray(body)) continue;
		// Tag the last text part of this message, if it has one.
		for (let j = body.length - 1; j >= 0; j--) {
			const part = body[j];
			if (part?.type === "text") {
				Object.assign(part, { cache_control: { type: "ephemeral" } });
				return;
			}
		}
	}
}
export function convertMessages(
model: Model<"openai-completions">,
context: Context,
compat: Required<OpenAICompletionsCompat>,
): ChatCompletionMessageParam[] {
const params: ChatCompletionMessageParam[] = [];
const normalizeToolCallId = (id: string): string => {
// Handle pipe-separated IDs from OpenAI Responses API
// Format: {call_id}|{id} where {id} can be 400+ chars with special chars (+, /, =)
// These come from providers like github-copilot, openai-codex, opencode
// Extract just the call_id part and normalize it
if (id.includes("|")) {
const [callId] = id.split("|");
// Sanitize to allowed chars and truncate to 40 chars (OpenAI limit)
return callId.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 40);
}
if (model.provider === "openai") return id.length > 40 ? id.slice(0, 40) : id;
return id;
};
const transformedMessages = transformMessages(context.messages, model, (id) => normalizeToolCallId(id));
if (context.systemPrompt) {
const useDeveloperRole = model.reasoning && compat.supportsDeveloperRole;
const role = useDeveloperRole ? "developer" : "system";
params.push({ role: role, content: sanitizeSurrogates(context.systemPrompt) });
}
let lastRole: string | null = null;
for (let i = 0; i < transformedMessages.length; i++) {
const msg = transformedMessages[i];
// Some providers don't allow user messages directly after tool results
// Insert a synthetic assistant message to bridge the gap
if (compat.requiresAssistantAfterToolResult && lastRole === "toolResult" && msg.role === "user") {
params.push({
role: "assistant",
content: "I have processed the tool results.",
});
}
if (msg.role === "user") {
if (typeof msg.content === "string") {
params.push({
role: "user",
content: sanitizeSurrogates(msg.content),
});
} else {
const content: ChatCompletionContentPart[] = msg.content.map((item): ChatCompletionContentPart => {
if (item.type === "text") {
return {
type: "text",
text: sanitizeSurrogates(item.text),
} satisfies ChatCompletionContentPartText;
} else {
return {
type: "image_url",
image_url: {
url: `data:${item.mimeType};base64,${item.data}`,
},
} satisfies ChatCompletionContentPartImage;
}
});
const filteredContent = !model.input.includes("image")
? content.filter((c) => c.type !== "image_url")
: content;
if (filteredContent.length === 0) continue;
params.push({
role: "user",
content: filteredContent,
});
}
} else if (msg.role === "assistant") {
// Some providers don't accept null content, use empty string instead
const assistantMsg: ChatCompletionAssistantMessageParam = {
role: "assistant",
content: compat.requiresAssistantAfterToolResult ? "" : null,
};
const textBlocks = msg.content.filter((b) => b.type === "text") as TextContent[];
// Filter out empty text blocks to avoid API validation errors
const nonEmptyTextBlocks = textBlocks.filter((b) => b.text && b.text.trim().length > 0);
if (nonEmptyTextBlocks.length > 0) {
// GitHub Copilot requires assistant content as a string, not an array.
// Sending as array causes Claude models to re-answer all previous prompts.
if (model.provider === "github-copilot") {
assistantMsg.content = nonEmptyTextBlocks.map((b) => sanitizeSurrogates(b.text)).join("");
} else {
assistantMsg.content = nonEmptyTextBlocks.map((b) => {
return { type: "text", text: sanitizeSurrogates(b.text) };
});
}
}
// Handle thinking blocks
const thinkingBlocks = msg.content.filter((b) => b.type === "thinking") as ThinkingContent[];
// Filter out empty thinking blocks to avoid API validation errors
const nonEmptyThinkingBlocks = thinkingBlocks.filter((b) => b.thinking && b.thinking.trim().length > 0);
if (nonEmptyThinkingBlocks.length > 0) {
if (compat.requiresThinkingAsText) {
// Convert thinking blocks to plain text (no tags to avoid model mimicking them)
const thinkingText = nonEmptyThinkingBlocks.map((b) => b.thinking).join("\n\n");
const textContent = assistantMsg.content as Array<{ type: "text"; text: string }> | null;
if (textContent) {
textContent.unshift({ type: "text", text: thinkingText });
} else {
assistantMsg.content = [{ type: "text", text: thinkingText }];
}
} else {
// Use the signature from the first thinking block if available (for llama.cpp server + gpt-oss)
const signature = nonEmptyThinkingBlocks[0].thinkingSignature;
if (signature && signature.length > 0) {
(assistantMsg as any)[signature] = nonEmptyThinkingBlocks.map((b) => b.thinking).join("\n");
}
}
}
const toolCalls = msg.content.filter((b) => b.type === "toolCall") as ToolCall[];
if (toolCalls.length > 0) {
assistantMsg.tool_calls = toolCalls.map((tc) => ({
id: tc.id,
type: "function" as const,
function: {
name: tc.name,
arguments: JSON.stringify(tc.arguments),
},
}));
const reasoningDetails = toolCalls
.filter((tc) => tc.thoughtSignature)
.map((tc) => {
try {
return JSON.parse(tc.thoughtSignature!);
} catch {
return null;
}
})
.filter(Boolean);
if (reasoningDetails.length > 0) {
(assistantMsg as any).reasoning_details = reasoningDetails;
}
}
// Skip assistant messages that have no content and no tool calls.
// Some providers require "either content or tool_calls, but not none".
// Other providers also don't accept empty assistant messages.
// This handles aborted assistant responses that got no content.
const content = assistantMsg.content;
const hasContent =
content !== null &&
content !== undefined &&
(typeof content === "string" ? content.length > 0 : content.length > 0);
if (!hasContent && !assistantMsg.tool_calls) {
continue;
}
params.push(assistantMsg);
} else if (msg.role === "toolResult") {
const imageBlocks: Array<{ type: "image_url"; image_url: { url: string } }> = [];
let j = i;
for (; j < transformedMessages.length && transformedMessages[j].role === "toolResult"; j++) {
const toolMsg = transformedMessages[j] as ToolResultMessage;
// Extract text and image content
const textResult = toolMsg.content
.filter((c) => c.type === "text")
.map((c) => (c as any).text)
.join("\n");
const hasImages = toolMsg.content.some((c) => c.type === "image");
// Always send tool result with text (or placeholder if only images)
const hasText = textResult.length > 0;
// Some providers require the 'name' field in tool results
const toolResultMsg: ChatCompletionToolMessageParam = {
role: "tool",
content: sanitizeSurrogates(hasText ? textResult : "(see attached image)"),
tool_call_id: toolMsg.toolCallId,
};
if (compat.requiresToolResultName && toolMsg.toolName) {
(toolResultMsg as any).name = toolMsg.toolName;
}
params.push(toolResultMsg);
if (hasImages && model.input.includes("image")) {
for (const block of toolMsg.content) {
if (block.type === "image") {
imageBlocks.push({
type: "image_url",
image_url: {
url: `data:${(block as any).mimeType};base64,${(block as any).data}`,
},
});
}
}
}
}
i = j - 1;
if (imageBlocks.length > 0) {
if (compat.requiresAssistantAfterToolResult) {
params.push({
role: "assistant",
content: "I have processed the tool results.",
});
}
params.push({
role: "user",
content: [
{
type: "text",
text: "Attached image(s) from tool result:",
},
...imageBlocks,
],
});
lastRole = "user";
} else {
lastRole = "toolResult";
}
continue;
}
lastRole = msg.role;
}
return params;
}
/**
 * Convert internal tool definitions into the OpenAI Chat Completions tool format.
 * When the provider supports strict mode, `strict: false` is sent explicitly;
 * providers that reject unknown fields get no `strict` key at all.
 */
function convertTools(
  tools: Tool[],
  compat: Required<OpenAICompletionsCompat>,
): OpenAI.Chat.Completions.ChatCompletionTool[] {
  const converted: OpenAI.Chat.Completions.ChatCompletionTool[] = [];
  for (const tool of tools) {
    const fn: Record<string, unknown> = {
      name: tool.name,
      description: tool.description,
      parameters: tool.parameters as any, // TypeBox already generates JSON Schema
    };
    // Only include strict if provider supports it. Some reject unknown fields.
    if (compat.supportsStrictMode !== false) {
      fn.strict = false;
    }
    converted.push({ type: "function", function: fn } as OpenAI.Chat.Completions.ChatCompletionTool);
  }
  return converted;
}
/**
 * Map an OpenAI Chat Completions finish_reason onto the internal StopReason union.
 * A null finish_reason (stream still in flight) is treated as a normal stop.
 */
function mapStopReason(reason: ChatCompletionChunk.Choice["finish_reason"]): StopReason {
  if (reason === null) return "stop";
  if (reason === "stop") return "stop";
  if (reason === "length") return "length";
  if (reason === "function_call" || reason === "tool_calls") return "toolUse";
  if (reason === "content_filter") return "error";
  // Exhaustiveness guard: fails to compile if OpenAI adds a new finish_reason.
  const _exhaustive: never = reason;
  throw new Error(`Unhandled stop reason: ${_exhaustive}`);
}
/**
 * Detect compatibility settings from provider and baseUrl for known providers.
 * Provider takes precedence over URL-based detection since it's explicitly configured.
 * Returns a fully resolved OpenAICompletionsCompat object with all fields set.
 */
function detectCompat(model: Model<"openai-completions">): Required<OpenAICompletionsCompat> {
  const { provider, baseUrl } = model;
  const urlHas = (fragment: string) => baseUrl.includes(fragment);
  const isZai = provider === "zai" || urlHas("api.z.ai");
  const isGrok = provider === "xai" || urlHas("api.x.ai");
  const isGroq = provider === "groq" || urlHas("groq.com");
  // Providers that deviate from the vanilla OpenAI Completions contract
  // (no `store`, no developer role).
  const isNonStandard =
    isZai ||
    isGrok ||
    provider === "cerebras" ||
    urlHas("cerebras.ai") ||
    urlHas("chutes.ai") ||
    urlHas("deepseek.com") ||
    provider === "opencode" ||
    urlHas("opencode.ai");
  // Chutes still expects the legacy max_tokens field name.
  const useMaxTokens = urlHas("chutes.ai");
  // Groq's qwen3-32b only accepts "default" as a reasoning effort value.
  const reasoningEffortMap =
    isGroq && model.id === "qwen/qwen3-32b"
      ? {
          minimal: "default",
          low: "default",
          medium: "default",
          high: "default",
          xhigh: "default",
        }
      : {};
  return {
    supportsStore: !isNonStandard,
    supportsDeveloperRole: !isNonStandard,
    supportsReasoningEffort: !isGrok && !isZai,
    reasoningEffortMap,
    supportsUsageInStreaming: true,
    maxTokensField: useMaxTokens ? "max_tokens" : "max_completion_tokens",
    requiresToolResultName: false,
    requiresAssistantAfterToolResult: false,
    requiresThinkingAsText: false,
    thinkingFormat: isZai ? "zai" : "openai",
    openRouterRouting: {},
    vercelGatewayRouting: {},
    supportsStrictMode: true,
  };
}
/**
 * Get resolved compatibility settings for a model.
 * Uses explicit model.compat overrides where provided, field by field, and
 * falls back to the values auto-detected from provider/URL by detectCompat().
 */
function getCompat(model: Model<"openai-completions">): Required<OpenAICompletionsCompat> {
  const detected = detectCompat(model);
  if (!model.compat) return detected;
  return {
    supportsStore: model.compat.supportsStore ?? detected.supportsStore,
    supportsDeveloperRole: model.compat.supportsDeveloperRole ?? detected.supportsDeveloperRole,
    supportsReasoningEffort: model.compat.supportsReasoningEffort ?? detected.supportsReasoningEffort,
    reasoningEffortMap: model.compat.reasoningEffortMap ?? detected.reasoningEffortMap,
    supportsUsageInStreaming: model.compat.supportsUsageInStreaming ?? detected.supportsUsageInStreaming,
    maxTokensField: model.compat.maxTokensField ?? detected.maxTokensField,
    requiresToolResultName: model.compat.requiresToolResultName ?? detected.requiresToolResultName,
    requiresAssistantAfterToolResult:
      model.compat.requiresAssistantAfterToolResult ?? detected.requiresAssistantAfterToolResult,
    requiresThinkingAsText: model.compat.requiresThinkingAsText ?? detected.requiresThinkingAsText,
    thinkingFormat: model.compat.thinkingFormat ?? detected.thinkingFormat,
    // Fall back to the detected value like every other field (previously this
    // fell back to a literal {}, which only worked because detection happens
    // to return an empty object today).
    openRouterRouting: model.compat.openRouterRouting ?? detected.openRouterRouting,
    vercelGatewayRouting: model.compat.vercelGatewayRouting ?? detected.vercelGatewayRouting,
    supportsStrictMode: model.compat.supportsStrictMode ?? detected.supportsStrictMode,
  };
}

View file

@ -0,0 +1,496 @@
import type OpenAI from "openai";
import type {
Tool as OpenAITool,
ResponseCreateParamsStreaming,
ResponseFunctionToolCall,
ResponseInput,
ResponseInputContent,
ResponseInputImage,
ResponseInputText,
ResponseOutputMessage,
ResponseReasoningItem,
ResponseStreamEvent,
} from "openai/resources/responses/responses.js";
import { calculateCost } from "../models.js";
import type {
Api,
AssistantMessage,
Context,
ImageContent,
Model,
StopReason,
TextContent,
TextSignatureV1,
ThinkingContent,
Tool,
ToolCall,
Usage,
} from "../types.js";
import type { AssistantMessageEventStream } from "../utils/event-stream.js";
import { shortHash } from "../utils/hash.js";
import { parseStreamingJson } from "../utils/json-parse.js";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
import { transformMessages } from "./transform-messages.js";
// =============================================================================
// Utilities
// =============================================================================
/**
 * Serialize a v1 text signature (Responses message id + optional phase) to its
 * JSON wire form; the phase key is omitted entirely when not provided.
 */
function encodeTextSignatureV1(id: string, phase?: TextSignatureV1["phase"]): string {
  const payload: TextSignatureV1 = phase ? { v: 1, id, phase } : { v: 1, id };
  return JSON.stringify(payload);
}
/**
 * Decode a text signature into { id, phase }.
 * Accepts the JSON v1 format produced by encodeTextSignatureV1 as well as
 * legacy signatures where the whole string is the id.
 * Returns undefined when no signature is present.
 */
function parseTextSignature(
  signature: string | undefined,
): { id: string; phase?: TextSignatureV1["phase"] } | undefined {
  if (!signature) return undefined;
  if (signature.startsWith("{")) {
    let parsed: Partial<TextSignatureV1> | undefined;
    try {
      parsed = JSON.parse(signature) as Partial<TextSignatureV1>;
    } catch {
      // Legacy plain string that merely starts with "{" — handled below.
      parsed = undefined;
    }
    if (parsed && parsed.v === 1 && typeof parsed.id === "string") {
      const validPhase =
        parsed.phase === "commentary" || parsed.phase === "final_answer" ? parsed.phase : undefined;
      return validPhase ? { id: parsed.id, phase: validPhase } : { id: parsed.id };
    }
  }
  // Legacy format: the signature itself is the id.
  return { id: signature };
}
/** Options for processResponsesStream (service-tier aware cost adjustment). */
export interface OpenAIResponsesStreamOptions {
  /** Requested OpenAI service tier; used as a fallback when the response itself reports none. */
  serviceTier?: ResponseCreateParamsStreaming["service_tier"];
  /** Callback that rescales `usage` costs in place for the effective service tier. */
  applyServiceTierPricing?: (
    usage: Usage,
    serviceTier: ResponseCreateParamsStreaming["service_tier"] | undefined,
  ) => void;
}
/** Options for convertResponsesMessages. */
export interface ConvertResponsesMessagesOptions {
  /** When false, context.systemPrompt is not emitted as the leading system/developer message. Defaults to true. */
  includeSystemPrompt?: boolean;
}
/** Options for convertResponsesTools. */
export interface ConvertResponsesToolsOptions {
  /** Value for each tool's `strict` field; an explicit null is passed through, undefined becomes false. */
  strict?: boolean | null;
}
// =============================================================================
// Message conversion
// =============================================================================
/**
 * Convert internal Context messages into the OpenAI Responses API input format.
 *
 * Handles:
 * - the optional system prompt (sent under the "developer" role for reasoning models),
 * - user text/image content (images are dropped when the model lacks image input),
 * - assistant reasoning/text/tool-call blocks, replayed as Responses output items,
 * - tool results as function_call_output entries, with images re-sent in a
 *   follow-up user message when the model supports image input.
 *
 * Composite tool-call ids of the form "call_id|item_id" are normalized
 * (sanitized, "fc"-prefixed, truncated) only for providers listed in
 * allowedToolCallProviders; all other providers keep their ids untouched.
 */
export function convertResponsesMessages<TApi extends Api>(
  model: Model<TApi>,
  context: Context,
  allowedToolCallProviders: ReadonlySet<string>,
  options?: ConvertResponsesMessagesOptions,
): ResponseInput {
  const messages: ResponseInput = [];
  const normalizeToolCallId = (id: string): string => {
    if (!allowedToolCallProviders.has(model.provider)) return id;
    if (!id.includes("|")) return id;
    const [callId, itemId] = id.split("|");
    // Restrict both halves to [a-zA-Z0-9_-]; anything else becomes "_".
    const sanitizedCallId = callId.replace(/[^a-zA-Z0-9_-]/g, "_");
    let sanitizedItemId = itemId.replace(/[^a-zA-Z0-9_-]/g, "_");
    // OpenAI Responses API requires item id to start with "fc"
    if (!sanitizedItemId.startsWith("fc")) {
      sanitizedItemId = `fc_${sanitizedItemId}`;
    }
    // Truncate to 64 chars and strip trailing underscores (OpenAI Codex rejects them)
    let normalizedCallId = sanitizedCallId.length > 64 ? sanitizedCallId.slice(0, 64) : sanitizedCallId;
    let normalizedItemId = sanitizedItemId.length > 64 ? sanitizedItemId.slice(0, 64) : sanitizedItemId;
    normalizedCallId = normalizedCallId.replace(/_+$/, "");
    normalizedItemId = normalizedItemId.replace(/_+$/, "");
    return `${normalizedCallId}|${normalizedItemId}`;
  };
  const transformedMessages = transformMessages(context.messages, model, normalizeToolCallId);
  const includeSystemPrompt = options?.includeSystemPrompt ?? true;
  if (includeSystemPrompt && context.systemPrompt) {
    // Reasoning models take the system prompt under the "developer" role.
    const role = model.reasoning ? "developer" : "system";
    messages.push({
      role,
      content: sanitizeSurrogates(context.systemPrompt),
    });
  }
  // msgIndex feeds the fallback "msg_N" id for text blocks without a signature.
  let msgIndex = 0;
  for (const msg of transformedMessages) {
    if (msg.role === "user") {
      if (typeof msg.content === "string") {
        messages.push({
          role: "user",
          content: [{ type: "input_text", text: sanitizeSurrogates(msg.content) }],
        });
      } else {
        const content: ResponseInputContent[] = msg.content.map((item): ResponseInputContent => {
          if (item.type === "text") {
            return {
              type: "input_text",
              text: sanitizeSurrogates(item.text),
            } satisfies ResponseInputText;
          }
          return {
            type: "input_image",
            detail: "auto",
            image_url: `data:${item.mimeType};base64,${item.data}`,
          } satisfies ResponseInputImage;
        });
        // Drop image parts entirely when the model cannot accept image input.
        const filteredContent = !model.input.includes("image")
          ? content.filter((c) => c.type !== "input_image")
          : content;
        if (filteredContent.length === 0) continue;
        messages.push({
          role: "user",
          content: filteredContent,
        });
      }
    } else if (msg.role === "assistant") {
      const output: ResponseInput = [];
      const assistantMsg = msg as AssistantMessage;
      // True when this history message came from a different model id of the
      // same provider/api (e.g. a mid-session model switch).
      const isDifferentModel =
        assistantMsg.model !== model.id &&
        assistantMsg.provider === model.provider &&
        assistantMsg.api === model.api;
      for (const block of msg.content) {
        if (block.type === "thinking") {
          if (block.thinkingSignature) {
            // thinkingSignature holds the serialized ResponseReasoningItem
            // captured in processResponsesStream; replay it verbatim.
            const reasoningItem = JSON.parse(block.thinkingSignature) as ResponseReasoningItem;
            output.push(reasoningItem);
          }
        } else if (block.type === "text") {
          const textBlock = block as TextContent;
          const parsedSignature = parseTextSignature(textBlock.textSignature);
          // OpenAI requires id to be max 64 characters
          let msgId = parsedSignature?.id;
          if (!msgId) {
            msgId = `msg_${msgIndex}`;
          } else if (msgId.length > 64) {
            msgId = `msg_${shortHash(msgId)}`;
          }
          output.push({
            type: "message",
            role: "assistant",
            content: [{ type: "output_text", text: sanitizeSurrogates(textBlock.text), annotations: [] }],
            status: "completed",
            id: msgId,
            phase: parsedSignature?.phase,
          } satisfies ResponseOutputMessage);
        } else if (block.type === "toolCall") {
          const toolCall = block as ToolCall;
          // Split the composite "call_id|item_id" id back into its halves.
          const [callId, itemIdRaw] = toolCall.id.split("|");
          let itemId: string | undefined = itemIdRaw;
          // For different-model messages, set id to undefined to avoid pairing validation.
          // OpenAI tracks which fc_xxx IDs were paired with rs_xxx reasoning items.
          // By omitting the id, we avoid triggering that validation (like cross-provider does).
          if (isDifferentModel && itemId?.startsWith("fc_")) {
            itemId = undefined;
          }
          output.push({
            type: "function_call",
            id: itemId,
            call_id: callId,
            name: toolCall.name,
            arguments: JSON.stringify(toolCall.arguments),
          });
        }
      }
      if (output.length === 0) continue;
      messages.push(...output);
    } else if (msg.role === "toolResult") {
      // Extract text and image content
      const textResult = msg.content
        .filter((c): c is TextContent => c.type === "text")
        .map((c) => c.text)
        .join("\n");
      const hasImages = msg.content.some((c): c is ImageContent => c.type === "image");
      // Always send function_call_output with text (or placeholder if only images)
      const hasText = textResult.length > 0;
      const [callId] = msg.toolCallId.split("|");
      messages.push({
        type: "function_call_output",
        call_id: callId,
        output: sanitizeSurrogates(hasText ? textResult : "(see attached image)"),
      });
      // If there are images and model supports them, send a follow-up user message with images
      if (hasImages && model.input.includes("image")) {
        const contentParts: ResponseInputContent[] = [];
        // Add text prefix
        contentParts.push({
          type: "input_text",
          text: "Attached image(s) from tool result:",
        } satisfies ResponseInputText);
        // Add images
        for (const block of msg.content) {
          if (block.type === "image") {
            contentParts.push({
              type: "input_image",
              detail: "auto",
              image_url: `data:${block.mimeType};base64,${block.data}`,
            } satisfies ResponseInputImage);
          }
        }
        messages.push({
          role: "user",
          content: contentParts,
        });
      }
    }
    msgIndex++;
  }
  return messages;
}
// =============================================================================
// Tool conversion
// =============================================================================
/**
 * Convert internal tool definitions into the OpenAI Responses API tool format.
 * `strict` defaults to false when unspecified; an explicit null is passed through.
 */
export function convertResponsesTools(tools: Tool[], options?: ConvertResponsesToolsOptions): OpenAITool[] {
  let strict: boolean | null = false;
  if (options && options.strict !== undefined) {
    strict = options.strict;
  }
  const converted: OpenAITool[] = [];
  for (const tool of tools) {
    converted.push({
      type: "function",
      name: tool.name,
      description: tool.description,
      parameters: tool.parameters as any, // TypeBox already generates JSON Schema
      strict,
    });
  }
  return converted;
}
// =============================================================================
// Stream processing
// =============================================================================
/**
 * Consume an OpenAI Responses API event stream and incrementally build the
 * AssistantMessage `output`, forwarding granular events (thinking/text/toolcall
 * start/delta/end, usage) to `stream`.
 *
 * Mutates `output` in place: content blocks are appended as output items
 * arrive, and usage/stopReason are filled in from the `response.completed`
 * event (with optional service-tier pricing applied via `options`).
 *
 * @throws Error when the stream emits an `error` or `response.failed` event.
 */
export async function processResponsesStream<TApi extends Api>(
  openaiStream: AsyncIterable<ResponseStreamEvent>,
  output: AssistantMessage,
  stream: AssistantMessageEventStream,
  model: Model<TApi>,
  options?: OpenAIResponsesStreamOptions,
): Promise<void> {
  // The item/block currently being assembled; each output_item.added starts a
  // new pair and output_item.done finalizes it.
  let currentItem: ResponseReasoningItem | ResponseOutputMessage | ResponseFunctionToolCall | null = null;
  let currentBlock: ThinkingContent | TextContent | (ToolCall & { partialJson: string }) | null = null;
  const blocks = output.content;
  // Index of the block currently being written (always the last one appended).
  const blockIndex = () => blocks.length - 1;
  for await (const event of openaiStream) {
    if (event.type === "response.output_item.added") {
      const item = event.item;
      if (item.type === "reasoning") {
        currentItem = item;
        currentBlock = { type: "thinking", thinking: "" };
        output.content.push(currentBlock);
        stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output });
      } else if (item.type === "message") {
        currentItem = item;
        currentBlock = { type: "text", text: "" };
        output.content.push(currentBlock);
        stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
      } else if (item.type === "function_call") {
        currentItem = item;
        // Tool-call ids are stored as the composite "call_id|item_id".
        currentBlock = {
          type: "toolCall",
          id: `${item.call_id}|${item.id}`,
          name: item.name,
          arguments: {},
          partialJson: item.arguments || "",
        };
        output.content.push(currentBlock);
        stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output });
      }
    } else if (event.type === "response.reasoning_summary_part.added") {
      if (currentItem && currentItem.type === "reasoning") {
        currentItem.summary = currentItem.summary || [];
        currentItem.summary.push(event.part);
      }
    } else if (event.type === "response.reasoning_summary_text.delta") {
      // Append delta to both the content block and the raw item (the item is
      // later serialized into thinkingSignature).
      if (currentItem?.type === "reasoning" && currentBlock?.type === "thinking") {
        currentItem.summary = currentItem.summary || [];
        const lastPart = currentItem.summary[currentItem.summary.length - 1];
        if (lastPart) {
          currentBlock.thinking += event.delta;
          lastPart.text += event.delta;
          stream.push({
            type: "thinking_delta",
            contentIndex: blockIndex(),
            delta: event.delta,
            partial: output,
          });
        }
      }
    } else if (event.type === "response.reasoning_summary_part.done") {
      // Separate finished summary parts with a blank line.
      if (currentItem?.type === "reasoning" && currentBlock?.type === "thinking") {
        currentItem.summary = currentItem.summary || [];
        const lastPart = currentItem.summary[currentItem.summary.length - 1];
        if (lastPart) {
          currentBlock.thinking += "\n\n";
          lastPart.text += "\n\n";
          stream.push({
            type: "thinking_delta",
            contentIndex: blockIndex(),
            delta: "\n\n",
            partial: output,
          });
        }
      }
    } else if (event.type === "response.content_part.added") {
      if (currentItem?.type === "message") {
        currentItem.content = currentItem.content || [];
        // Filter out ReasoningText, only accept output_text and refusal
        if (event.part.type === "output_text" || event.part.type === "refusal") {
          currentItem.content.push(event.part);
        }
      }
    } else if (event.type === "response.output_text.delta") {
      if (currentItem?.type === "message" && currentBlock?.type === "text") {
        if (!currentItem.content || currentItem.content.length === 0) {
          continue;
        }
        const lastPart = currentItem.content[currentItem.content.length - 1];
        if (lastPart?.type === "output_text") {
          currentBlock.text += event.delta;
          lastPart.text += event.delta;
          stream.push({
            type: "text_delta",
            contentIndex: blockIndex(),
            delta: event.delta,
            partial: output,
          });
        }
      }
    } else if (event.type === "response.refusal.delta") {
      // Refusal text is surfaced through the same text block as normal output.
      if (currentItem?.type === "message" && currentBlock?.type === "text") {
        if (!currentItem.content || currentItem.content.length === 0) {
          continue;
        }
        const lastPart = currentItem.content[currentItem.content.length - 1];
        if (lastPart?.type === "refusal") {
          currentBlock.text += event.delta;
          lastPart.refusal += event.delta;
          stream.push({
            type: "text_delta",
            contentIndex: blockIndex(),
            delta: event.delta,
            partial: output,
          });
        }
      }
    } else if (event.type === "response.function_call_arguments.delta") {
      // Arguments stream as JSON fragments; parse best-effort on every delta.
      if (currentItem?.type === "function_call" && currentBlock?.type === "toolCall") {
        currentBlock.partialJson += event.delta;
        currentBlock.arguments = parseStreamingJson(currentBlock.partialJson);
        stream.push({
          type: "toolcall_delta",
          contentIndex: blockIndex(),
          delta: event.delta,
          partial: output,
        });
      }
    } else if (event.type === "response.function_call_arguments.done") {
      if (currentItem?.type === "function_call" && currentBlock?.type === "toolCall") {
        currentBlock.partialJson = event.arguments;
        currentBlock.arguments = parseStreamingJson(currentBlock.partialJson);
      }
    } else if (event.type === "response.output_item.done") {
      const item = event.item;
      if (item.type === "reasoning" && currentBlock?.type === "thinking") {
        // Final text comes from the completed item; the raw item is kept as
        // the signature so it can be replayed on later turns.
        currentBlock.thinking = item.summary?.map((s) => s.text).join("\n\n") || "";
        currentBlock.thinkingSignature = JSON.stringify(item);
        stream.push({
          type: "thinking_end",
          contentIndex: blockIndex(),
          content: currentBlock.thinking,
          partial: output,
        });
        currentBlock = null;
      } else if (item.type === "message" && currentBlock?.type === "text") {
        currentBlock.text = item.content.map((c) => (c.type === "output_text" ? c.text : c.refusal)).join("");
        currentBlock.textSignature = encodeTextSignatureV1(item.id, item.phase ?? undefined);
        stream.push({
          type: "text_end",
          contentIndex: blockIndex(),
          content: currentBlock.text,
          partial: output,
        });
        currentBlock = null;
      } else if (item.type === "function_call") {
        // Prefer the streamed partial JSON; fall back to the item's arguments.
        const args =
          currentBlock?.type === "toolCall" && currentBlock.partialJson
            ? parseStreamingJson(currentBlock.partialJson)
            : parseStreamingJson(item.arguments || "{}");
        const toolCall: ToolCall = {
          type: "toolCall",
          id: `${item.call_id}|${item.id}`,
          name: item.name,
          arguments: args,
        };
        currentBlock = null;
        stream.push({ type: "toolcall_end", contentIndex: blockIndex(), toolCall, partial: output });
      }
    } else if (event.type === "response.completed") {
      const response = event.response;
      if (response?.usage) {
        const cachedTokens = response.usage.input_tokens_details?.cached_tokens || 0;
        output.usage = {
          // OpenAI includes cached tokens in input_tokens, so subtract to get non-cached input
          input: (response.usage.input_tokens || 0) - cachedTokens,
          output: response.usage.output_tokens || 0,
          cacheRead: cachedTokens,
          cacheWrite: 0,
          totalTokens: response.usage.total_tokens || 0,
          cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
        };
      }
      calculateCost(model, output.usage);
      if (options?.applyServiceTierPricing) {
        // The tier the response actually ran under wins over the requested one.
        const serviceTier = response?.service_tier ?? options.serviceTier;
        options.applyServiceTierPricing(output.usage, serviceTier);
      }
      // Map status to stop reason
      output.stopReason = mapStopReason(response?.status);
      if (output.content.some((b) => b.type === "toolCall") && output.stopReason === "stop") {
        output.stopReason = "toolUse";
      }
    } else if (event.type === "error") {
      // event.message may be empty — fall back to a generic message then.
      // (Previously written as `` `...` || "Unknown error" ``, where the
      // template literal is always truthy, making the fallback dead code.)
      throw new Error(event.message ? `Error Code ${event.code}: ${event.message}` : "Unknown error");
    } else if (event.type === "response.failed") {
      throw new Error("Unknown error");
    }
  }
}
/** Map an OpenAI Responses API status onto the internal StopReason union. */
function mapStopReason(status: OpenAI.Responses.ResponseStatus | undefined): StopReason {
  if (!status) return "stop";
  // "in_progress" and "queued" are wonky terminal statuses — treat as stop.
  if (status === "completed" || status === "in_progress" || status === "queued") {
    return "stop";
  }
  if (status === "incomplete") {
    return "length";
  }
  if (status === "failed" || status === "cancelled") {
    return "error";
  }
  // Exhaustiveness guard: fails to compile if OpenAI adds a new status.
  const _exhaustive: never = status;
  throw new Error(`Unhandled stop reason: ${_exhaustive}`);
}

View file

@ -0,0 +1,262 @@
import OpenAI from "openai";
import type { ResponseCreateParamsStreaming } from "openai/resources/responses/responses.js";
import { getEnvApiKey } from "../env-api-keys.js";
import { supportsXhigh } from "../models.js";
import type {
Api,
AssistantMessage,
CacheRetention,
Context,
Model,
SimpleStreamOptions,
StreamFunction,
StreamOptions,
Usage,
} from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import { buildCopilotDynamicHeaders, hasCopilotVisionInput } from "./github-copilot-headers.js";
import { convertResponsesMessages, convertResponsesTools, processResponsesStream } from "./openai-responses-shared.js";
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
const OPENAI_TOOL_CALL_PROVIDERS = new Set(["openai", "openai-codex", "opencode"]);
/**
 * Resolve the cache retention preference.
 * Defaults to "short"; the PI_CACHE_RETENTION=long environment variable is
 * honored for backward compatibility when no explicit value is given.
 */
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
  if (cacheRetention) {
    return cacheRetention;
  }
  const envWantsLong = typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long";
  return envWantsLong ? "long" : "short";
}
/**
 * Get the prompt cache retention value for a request.
 * "24h" applies only when long retention is requested AND the call goes
 * directly to the OpenAI API (api.openai.com); otherwise no explicit
 * retention is sent.
 */
function getPromptCacheRetention(baseUrl: string, cacheRetention: CacheRetention): "24h" | undefined {
  const isDirectOpenAI = baseUrl.includes("api.openai.com");
  return cacheRetention === "long" && isDirectOpenAI ? "24h" : undefined;
}
/** Stream options specific to the OpenAI Responses API. */
export interface OpenAIResponsesOptions extends StreamOptions {
  /** Reasoning effort level; defaults to "medium" when only a summary is requested. */
  reasoningEffort?: "minimal" | "low" | "medium" | "high" | "xhigh";
  /** Reasoning summary verbosity; defaults to "auto" when only an effort is requested. */
  reasoningSummary?: "auto" | "detailed" | "concise" | null;
  /** OpenAI service tier; affects request routing and pricing (see applyServiceTierPricing). */
  serviceTier?: ResponseCreateParamsStreaming["service_tier"];
}
/**
 * Streaming entry point for the OpenAI Responses API.
 *
 * Returns an AssistantMessageEventStream immediately and performs the request
 * asynchronously: builds request params (optionally rewritten by
 * options.onPayload), opens the SDK stream, and forwards events via
 * processResponsesStream. On any failure — including abort — the stream ends
 * with a single "error" event whose message has stopReason "aborted"/"error".
 */
export const streamOpenAIResponses: StreamFunction<"openai-responses", OpenAIResponsesOptions> = (
  model: Model<"openai-responses">,
  context: Context,
  options?: OpenAIResponsesOptions,
): AssistantMessageEventStream => {
  const stream = new AssistantMessageEventStream();
  // Start async processing
  (async () => {
    // Skeleton message, mutated in place as stream events arrive.
    const output: AssistantMessage = {
      role: "assistant",
      content: [],
      api: model.api as Api,
      provider: model.provider,
      model: model.id,
      usage: {
        input: 0,
        output: 0,
        cacheRead: 0,
        cacheWrite: 0,
        totalTokens: 0,
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
      },
      stopReason: "stop",
      timestamp: Date.now(),
    };
    try {
      // Create OpenAI client
      const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
      const client = createClient(model, context, apiKey, options?.headers);
      let params = buildParams(model, context, options);
      // Give the caller a chance to inspect/replace the payload before sending.
      const nextParams = await options?.onPayload?.(params, model);
      if (nextParams !== undefined) {
        params = nextParams as ResponseCreateParamsStreaming;
      }
      const openaiStream = await client.responses.create(
        params,
        options?.signal ? { signal: options.signal } : undefined,
      );
      stream.push({ type: "start", partial: output });
      await processResponsesStream(openaiStream, output, stream, model, {
        serviceTier: options?.serviceTier,
        applyServiceTierPricing,
      });
      // An abort that slipped through stream processing is surfaced as an error.
      if (options?.signal?.aborted) {
        throw new Error("Request was aborted");
      }
      if (output.stopReason === "aborted" || output.stopReason === "error") {
        throw new Error("An unknown error occurred");
      }
      stream.push({ type: "done", reason: output.stopReason, message: output });
      stream.end();
    } catch (error) {
      // NOTE(review): blocks may carry a transient `index` property from stream
      // assembly; it is stripped before the message is reported — confirm origin.
      for (const block of output.content) delete (block as { index?: number }).index;
      output.stopReason = options?.signal?.aborted ? "aborted" : "error";
      output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
      stream.push({ type: "error", reason: output.stopReason, error: output });
      stream.end();
    }
  })();
  return stream;
};
export const streamSimpleOpenAIResponses: StreamFunction<"openai-responses", SimpleStreamOptions> = (
model: Model<"openai-responses">,
context: Context,
options?: SimpleStreamOptions,
): AssistantMessageEventStream => {
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
if (!apiKey) {
throw new Error(`No API key for provider: ${model.provider}`);
}
const base = buildBaseOptions(model, options, apiKey);
const reasoningEffort = supportsXhigh(model) ? options?.reasoning : clampReasoning(options?.reasoning);
return streamOpenAIResponses(model, context, {
...base,
reasoningEffort,
} satisfies OpenAIResponsesOptions);
};
/**
 * Build an OpenAI SDK client for the Responses API.
 * Header precedence (lowest to highest): model.headers, GitHub Copilot
 * dynamic headers, caller-supplied optionsHeaders.
 * Falls back to OPENAI_API_KEY when no key is passed in.
 * @throws Error when neither an explicit key nor OPENAI_API_KEY is available.
 */
function createClient(
  model: Model<"openai-responses">,
  context: Context,
  apiKey?: string,
  optionsHeaders?: Record<string, string>,
) {
  let resolvedKey = apiKey;
  if (!resolvedKey) {
    resolvedKey = process.env.OPENAI_API_KEY;
    if (!resolvedKey) {
      throw new Error(
        "OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
      );
    }
  }
  const headers: Record<string, string> = { ...model.headers };
  if (model.provider === "github-copilot") {
    // Copilot needs per-request headers derived from the conversation content.
    const hasImages = hasCopilotVisionInput(context.messages);
    Object.assign(
      headers,
      buildCopilotDynamicHeaders({
        messages: context.messages,
        hasImages,
      }),
    );
  }
  // Merge options headers last so they can override defaults
  if (optionsHeaders) {
    Object.assign(headers, optionsHeaders);
  }
  return new OpenAI({
    apiKey: resolvedKey,
    baseURL: model.baseUrl,
    dangerouslyAllowBrowser: true,
    defaultHeaders: headers,
  });
}
/**
 * Build the ResponseCreateParamsStreaming payload for a Responses API request:
 * converted input messages, prompt-cache settings, token/temperature limits,
 * tools, and reasoning configuration.
 */
function buildParams(model: Model<"openai-responses">, context: Context, options?: OpenAIResponsesOptions) {
  const messages = convertResponsesMessages(model, context, OPENAI_TOOL_CALL_PROVIDERS);
  const cacheRetention = resolveCacheRetention(options?.cacheRetention);
  const params: ResponseCreateParamsStreaming = {
    model: model.id,
    input: messages,
    stream: true,
    // Session-scoped cache key, unless caching is disabled entirely.
    prompt_cache_key: cacheRetention === "none" ? undefined : options?.sessionId,
    prompt_cache_retention: getPromptCacheRetention(model.baseUrl, cacheRetention),
    store: false,
  };
  if (options?.maxTokens) {
    params.max_output_tokens = options?.maxTokens;
  }
  if (options?.temperature !== undefined) {
    params.temperature = options?.temperature;
  }
  if (options?.serviceTier !== undefined) {
    params.service_tier = options.serviceTier;
  }
  if (context.tools) {
    params.tools = convertResponsesTools(context.tools);
  }
  if (model.reasoning) {
    if (options?.reasoningEffort || options?.reasoningSummary) {
      params.reasoning = {
        effort: options?.reasoningEffort || "medium",
        summary: options?.reasoningSummary || "auto",
      };
      // NOTE(review): presumably requested so reasoning items can be replayed
      // on later turns (see convertResponsesMessages) — confirm.
      params.include = ["reasoning.encrypted_content"];
    } else {
      if (model.name.startsWith("gpt-5")) {
        // Jesus Christ, see https://community.openai.com/t/need-reasoning-false-option-for-gpt-5/1351588/7
        messages.push({
          role: "developer",
          content: [
            {
              type: "input_text",
              text: "# Juice: 0 !important",
            },
          ],
        });
      }
    }
  }
  return params;
}
/** Cost multiplier per OpenAI service tier: flex is half price, priority double, else 1. */
function getServiceTierCostMultiplier(serviceTier: ResponseCreateParamsStreaming["service_tier"] | undefined): number {
  if (serviceTier === "flex") return 0.5;
  if (serviceTier === "priority") return 2;
  return 1;
}
/**
 * Scale usage costs in place for the effective OpenAI service tier
 * (flex halves costs, priority doubles them) and recompute the total.
 * No-op for tiers with a 1x multiplier.
 */
function applyServiceTierPricing(usage: Usage, serviceTier: ResponseCreateParamsStreaming["service_tier"] | undefined) {
  const multiplier = getServiceTierCostMultiplier(serviceTier);
  if (multiplier === 1) return;
  const { cost } = usage;
  cost.input *= multiplier;
  cost.output *= multiplier;
  cost.cacheRead *= multiplier;
  cost.cacheWrite *= multiplier;
  cost.total = cost.input + cost.output + cost.cacheRead + cost.cacheWrite;
}

View file

@ -0,0 +1,186 @@
import { clearApiProviders, registerApiProvider } from "../api-registry.js";
import type { AssistantMessage, AssistantMessageEvent, Context, Model, SimpleStreamOptions } from "../types.js";
import { AssistantMessageEventStream } from "../utils/event-stream.js";
import type { BedrockOptions } from "./amazon-bedrock.js";
import { streamAnthropic, streamSimpleAnthropic } from "./anthropic.js";
import { streamAzureOpenAIResponses, streamSimpleAzureOpenAIResponses } from "./azure-openai-responses.js";
import { streamGoogle, streamSimpleGoogle } from "./google.js";
import { streamGoogleGeminiCli, streamSimpleGoogleGeminiCli } from "./google-gemini-cli.js";
import { streamGoogleVertex, streamSimpleGoogleVertex } from "./google-vertex.js";
import { streamMistral, streamSimpleMistral } from "./mistral.js";
import { streamOpenAICodexResponses, streamSimpleOpenAICodexResponses } from "./openai-codex-responses.js";
import { streamOpenAICompletions, streamSimpleOpenAICompletions } from "./openai-completions.js";
import { streamOpenAIResponses, streamSimpleOpenAIResponses } from "./openai-responses.js";
/**
 * Shape of the lazily-loaded Bedrock provider module (./amazon-bedrock.js).
 * Declared here so the module can be dynamically imported (or stubbed via
 * setBedrockProviderModule) without a static dependency on the AWS SDK.
 */
interface BedrockProviderModule {
	streamBedrock: (
		model: Model<"bedrock-converse-stream">,
		context: Context,
		options?: BedrockOptions,
	) => AsyncIterable<AssistantMessageEvent>;
	streamSimpleBedrock: (
		model: Model<"bedrock-converse-stream">,
		context: Context,
		options?: SimpleStreamOptions,
	) => AsyncIterable<AssistantMessageEvent>;
}
type DynamicImport = (specifier: string) => Promise<unknown>;
// The import is wrapped in a helper, and the specifier is assembled at runtime,
// so bundlers/compilers cannot statically resolve the optional Bedrock module.
const dynamicImport: DynamicImport = (specifier) => import(specifier);
const BEDROCK_PROVIDER_SPECIFIER = ["./amazon-", "bedrock.js"].join("");
let bedrockProviderModuleOverride: BedrockProviderModule | undefined;
/** Inject a replacement Bedrock module (e.g. a stub in tests); takes effect on the next load. */
export function setBedrockProviderModule(module: BedrockProviderModule): void {
	bedrockProviderModuleOverride = module;
}
/** Resolve the Bedrock provider module, preferring an injected override over the dynamic import. */
async function loadBedrockProviderModule(): Promise<BedrockProviderModule> {
	const override = bedrockProviderModuleOverride;
	if (override !== undefined) {
		return override;
	}
	const loaded = await dynamicImport(BEDROCK_PROVIDER_SPECIFIER);
	return loaded as BedrockProviderModule;
}
/**
 * Pump every event from `source` into `target`, then close the target stream.
 *
 * Runs as a detached async task. `target.end()` is guaranteed (via `finally`)
 * even if iterating `source` throws, so consumers of `target` never hang on a
 * stream that silently stopped. Provider streams report failures as "error"
 * events rather than by throwing, so an actual throw here is unexpected; it is
 * swallowed after the stream is closed to avoid an unhandled rejection taking
 * down the process.
 */
function forwardStream(target: AssistantMessageEventStream, source: AsyncIterable<AssistantMessageEvent>): void {
	void (async () => {
		try {
			for await (const event of source) {
				target.push(event);
			}
		} finally {
			// Always close the target, even when iteration throws mid-stream.
			target.end();
		}
	})().catch(() => {
		// Intentionally ignored: the stream has already been ended above.
	});
}
/**
 * Build a synthetic assistant message describing a failure to load the Bedrock
 * provider module, so callers receive a normal "error" result (zero usage,
 * empty content) instead of an exception.
 */
function createLazyLoadErrorMessage(model: Model<"bedrock-converse-stream">, error: unknown): AssistantMessage {
	const errorMessage = error instanceof Error ? error.message : String(error);
	return {
		role: "assistant",
		content: [],
		api: "bedrock-converse-stream",
		provider: model.provider,
		model: model.id,
		usage: {
			input: 0,
			output: 0,
			cacheRead: 0,
			cacheWrite: 0,
			totalTokens: 0,
			cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
		},
		stopReason: "error",
		errorMessage,
		timestamp: Date.now(),
	};
}
/**
 * Stream from Bedrock without statically importing the AWS SDK: the provider
 * module is loaded on demand, and load failures surface as a terminal "error"
 * event on the returned stream rather than a thrown exception.
 */
function streamBedrockLazy(
	model: Model<"bedrock-converse-stream">,
	context: Context,
	options?: BedrockOptions,
): AssistantMessageEventStream {
	const events = new AssistantMessageEventStream();
	void (async () => {
		try {
			const module = await loadBedrockProviderModule();
			forwardStream(events, module.streamBedrock(model, context, options));
		} catch (error) {
			const failure = createLazyLoadErrorMessage(model, error);
			events.push({ type: "error", reason: "error", error: failure });
			events.end(failure);
		}
	})();
	return events;
}
/**
 * Simple-options variant of streamBedrockLazy: loads the Bedrock module on
 * demand and reports load failures as a terminal "error" event on the stream.
 */
function streamSimpleBedrockLazy(
	model: Model<"bedrock-converse-stream">,
	context: Context,
	options?: SimpleStreamOptions,
): AssistantMessageEventStream {
	const events = new AssistantMessageEventStream();
	void (async () => {
		try {
			const module = await loadBedrockProviderModule();
			forwardStream(events, module.streamSimpleBedrock(model, context, options));
		} catch (error) {
			const failure = createLazyLoadErrorMessage(model, error);
			events.push({ type: "error", reason: "error", error: failure });
			events.end(failure);
		}
	})();
	return events;
}
/**
 * Register the built-in streaming implementation for every known API wire
 * protocol. Safe to call repeatedly; later registrations for the same api id
 * go through registerApiProvider's own semantics.
 */
export function registerBuiltInApiProviders(): void {
	registerApiProvider({
		api: "anthropic-messages",
		stream: streamAnthropic,
		streamSimple: streamSimpleAnthropic,
	});
	registerApiProvider({
		api: "openai-completions",
		stream: streamOpenAICompletions,
		streamSimple: streamSimpleOpenAICompletions,
	});
	registerApiProvider({
		api: "mistral-conversations",
		stream: streamMistral,
		streamSimple: streamSimpleMistral,
	});
	registerApiProvider({
		api: "openai-responses",
		stream: streamOpenAIResponses,
		streamSimple: streamSimpleOpenAIResponses,
	});
	registerApiProvider({
		api: "azure-openai-responses",
		stream: streamAzureOpenAIResponses,
		streamSimple: streamSimpleAzureOpenAIResponses,
	});
	registerApiProvider({
		api: "openai-codex-responses",
		stream: streamOpenAICodexResponses,
		streamSimple: streamSimpleOpenAICodexResponses,
	});
	registerApiProvider({
		api: "google-generative-ai",
		stream: streamGoogle,
		streamSimple: streamSimpleGoogle,
	});
	registerApiProvider({
		api: "google-gemini-cli",
		stream: streamGoogleGeminiCli,
		streamSimple: streamSimpleGoogleGeminiCli,
	});
	registerApiProvider({
		api: "google-vertex",
		stream: streamGoogleVertex,
		streamSimple: streamSimpleGoogleVertex,
	});
	// Bedrock uses lazy wrappers so the AWS SDK is only loaded when a
	// Bedrock model is actually streamed.
	registerApiProvider({
		api: "bedrock-converse-stream",
		stream: streamBedrockLazy,
		streamSimple: streamSimpleBedrockLazy,
	});
}
/** Drop every registered provider (including externally registered ones) and re-register the built-ins. */
export function resetApiProviders(): void {
	clearApiProviders();
	registerBuiltInApiProviders();
}
// Register the built-ins at module load, so importing this module is all the setup needed.
registerBuiltInApiProviders();

View file

@ -0,0 +1,46 @@
import type { Api, Model, SimpleStreamOptions, StreamOptions, ThinkingBudgets, ThinkingLevel } from "../types.js";
/**
 * Translate provider-agnostic SimpleStreamOptions into the common StreamOptions
 * shape shared by every provider. When maxTokens is unset (or 0), it falls back
 * to the model's limit capped at 32k; an explicitly passed apiKey takes
 * precedence over the one inside options.
 */
export function buildBaseOptions(model: Model<Api>, options?: SimpleStreamOptions, apiKey?: string): StreamOptions {
	const fallbackMaxTokens = Math.min(model.maxTokens, 32000);
	const base: StreamOptions = {
		temperature: options?.temperature,
		maxTokens: options?.maxTokens || fallbackMaxTokens,
		signal: options?.signal,
		apiKey: apiKey || options?.apiKey,
		cacheRetention: options?.cacheRetention,
		sessionId: options?.sessionId,
		headers: options?.headers,
		onPayload: options?.onPayload,
		maxRetryDelayMs: options?.maxRetryDelayMs,
		metadata: options?.metadata,
	};
	return base;
}
/** Map "xhigh" down to "high" for providers that only support four thinking levels. */
export function clampReasoning(effort: ThinkingLevel | undefined): Exclude<ThinkingLevel, "xhigh"> | undefined {
	return effort === "xhigh" ? "high" : effort;
}
/**
 * Compute the (maxTokens, thinkingBudget) pair for a thinking-enabled request.
 *
 * The thinking budget for the (clamped) level is added on top of the base
 * output allowance, capped at the model's hard limit. If the cap leaves no
 * headroom beyond the budget, the budget is shrunk so at least 1024 tokens
 * remain for the visible answer.
 *
 * @param baseMaxTokens  desired output tokens before adding the thinking budget
 * @param modelMaxTokens hard per-request limit of the model
 * @param reasoningLevel requested thinking level ("xhigh" is treated as "high")
 * @param customBudgets  optional per-level budget overrides
 */
export function adjustMaxTokensForThinking(
	baseMaxTokens: number,
	modelMaxTokens: number,
	reasoningLevel: ThinkingLevel,
	customBudgets?: ThinkingBudgets,
): { maxTokens: number; thinkingBudget: number } {
	const defaultBudgets: ThinkingBudgets = {
		minimal: 1024,
		low: 2048,
		medium: 8192,
		high: 16384,
	};
	const budgets = { ...defaultBudgets, ...customBudgets };
	const minOutputTokens = 1024;
	// reasoningLevel is always defined here, so clampReasoning cannot return undefined.
	const level = clampReasoning(reasoningLevel)!;
	// `?? default` guards against a custom budget explicitly set to undefined
	// (e.g. { medium: undefined }), which survives the spread and would
	// otherwise poison the arithmetic below with NaN.
	let thinkingBudget = budgets[level] ?? defaultBudgets[level]!;
	const maxTokens = Math.min(baseMaxTokens + thinkingBudget, modelMaxTokens);
	if (maxTokens <= thinkingBudget) {
		// Not enough headroom: shrink the budget to leave room for visible output.
		thinkingBudget = Math.max(0, maxTokens - minOutputTokens);
	}
	return { maxTokens, thinkingBudget };
}

View file

@ -0,0 +1,172 @@
import type { Api, AssistantMessage, Message, Model, ToolCall, ToolResultMessage } from "../types.js";
/**
 * Prepare a stored conversation for replay against `model`.
 *
 * Pass 1 rewrites each message for the target model:
 * - Thinking blocks from a *different* model are converted to plain text (or
 *   dropped when empty/redacted), since signatures and encrypted payloads are
 *   only valid on the model that produced them.
 * - Cross-model tool-call IDs are rewritten via `normalizeToolCallId`, and
 *   matching toolResult messages are remapped. This exists because the OpenAI
 *   Responses API generates IDs that are 450+ chars with special characters
 *   like `|`, while Anthropic APIs require IDs matching ^[a-zA-Z0-9_-]+$
 *   (max 64 chars).
 *
 * Pass 2 repairs the tool-call/tool-result pairing:
 * - Errored/aborted assistant turns are dropped entirely.
 * - Synthetic error tool-results are inserted for tool calls that never
 *   received a result.
 *
 * @param messages conversation history to transform (not mutated)
 * @param model the model the transformed history will be sent to
 * @param normalizeToolCallId optional ID rewriter, applied only cross-model
 */
export function transformMessages<TApi extends Api>(
	messages: Message[],
	model: Model<TApi>,
	normalizeToolCallId?: (id: string, model: Model<TApi>, source: AssistantMessage) => string,
): Message[] {
	// Build a map of original tool call IDs to normalized IDs.
	// Messages are processed in order, so a toolResult is always visited after
	// the assistant message that created its mapping.
	const toolCallIdMap = new Map<string, string>();
	// First pass: transform messages (thinking blocks, tool call ID normalization)
	const transformed = messages.map((msg) => {
		// User messages pass through unchanged
		if (msg.role === "user") {
			return msg;
		}
		// Handle toolResult messages - normalize toolCallId if we have a mapping
		if (msg.role === "toolResult") {
			const normalizedId = toolCallIdMap.get(msg.toolCallId);
			if (normalizedId && normalizedId !== msg.toolCallId) {
				return { ...msg, toolCallId: normalizedId };
			}
			return msg;
		}
		// Assistant messages need transformation check
		if (msg.role === "assistant") {
			const assistantMsg = msg as AssistantMessage;
			// "Same model" means same provider + api + model id; anything model-bound
			// (signatures, encrypted payloads) is only kept in that case.
			const isSameModel =
				assistantMsg.provider === model.provider &&
				assistantMsg.api === model.api &&
				assistantMsg.model === model.id;
			const transformedContent = assistantMsg.content.flatMap((block) => {
				if (block.type === "thinking") {
					// Redacted thinking is opaque encrypted content, only valid for the same model.
					// Drop it for cross-model to avoid API errors.
					if (block.redacted) {
						return isSameModel ? block : [];
					}
					// For same model: keep thinking blocks with signatures (needed for replay)
					// even if the thinking text is empty (OpenAI encrypted reasoning)
					if (isSameModel && block.thinkingSignature) return block;
					// Skip empty thinking blocks, convert others to plain text
					if (!block.thinking || block.thinking.trim() === "") return [];
					if (isSameModel) return block;
					return {
						type: "text" as const,
						text: block.thinking,
					};
				}
				if (block.type === "text") {
					if (isSameModel) return block;
					// Rebuild the block so cross-model extras (e.g. textSignature) are dropped.
					return {
						type: "text" as const,
						text: block.text,
					};
				}
				if (block.type === "toolCall") {
					const toolCall = block as ToolCall;
					let normalizedToolCall: ToolCall = toolCall;
					// thoughtSignature is model-bound (Google); strip it cross-model.
					if (!isSameModel && toolCall.thoughtSignature) {
						normalizedToolCall = { ...toolCall };
						delete (normalizedToolCall as { thoughtSignature?: string }).thoughtSignature;
					}
					if (!isSameModel && normalizeToolCallId) {
						const normalizedId = normalizeToolCallId(toolCall.id, model, assistantMsg);
						if (normalizedId !== toolCall.id) {
							toolCallIdMap.set(toolCall.id, normalizedId);
							normalizedToolCall = { ...normalizedToolCall, id: normalizedId };
						}
					}
					return normalizedToolCall;
				}
				return block;
			});
			return {
				...assistantMsg,
				content: transformedContent,
			};
		}
		return msg;
	});
	// Second pass: insert synthetic empty tool results for orphaned tool calls
	// This preserves thinking signatures and satisfies API requirements
	const result: Message[] = [];
	// Tool calls from the most recent assistant message that still await results.
	let pendingToolCalls: ToolCall[] = [];
	// IDs of toolResult messages seen since that assistant message.
	let existingToolResultIds = new Set<string>();
	for (let i = 0; i < transformed.length; i++) {
		const msg = transformed[i];
		if (msg.role === "assistant") {
			// If we have pending orphaned tool calls from a previous assistant, insert synthetic results now
			if (pendingToolCalls.length > 0) {
				for (const tc of pendingToolCalls) {
					if (!existingToolResultIds.has(tc.id)) {
						result.push({
							role: "toolResult",
							toolCallId: tc.id,
							toolName: tc.name,
							content: [{ type: "text", text: "No result provided" }],
							isError: true,
							timestamp: Date.now(),
						} as ToolResultMessage);
					}
				}
				pendingToolCalls = [];
				existingToolResultIds = new Set();
			}
			// Skip errored/aborted assistant messages entirely.
			// These are incomplete turns that shouldn't be replayed:
			// - May have partial content (reasoning without message, incomplete tool calls)
			// - Replaying them can cause API errors (e.g., OpenAI "reasoning without following item")
			// - The model should retry from the last valid state
			const assistantMsg = msg as AssistantMessage;
			if (assistantMsg.stopReason === "error" || assistantMsg.stopReason === "aborted") {
				continue;
			}
			// Track tool calls from this assistant message
			const toolCalls = assistantMsg.content.filter((b) => b.type === "toolCall") as ToolCall[];
			if (toolCalls.length > 0) {
				pendingToolCalls = toolCalls;
				existingToolResultIds = new Set();
			}
			result.push(msg);
		} else if (msg.role === "toolResult") {
			existingToolResultIds.add(msg.toolCallId);
			result.push(msg);
		} else if (msg.role === "user") {
			// User message interrupts tool flow - insert synthetic results for orphaned calls
			if (pendingToolCalls.length > 0) {
				for (const tc of pendingToolCalls) {
					if (!existingToolResultIds.has(tc.id)) {
						result.push({
							role: "toolResult",
							toolCallId: tc.id,
							toolName: tc.name,
							content: [{ type: "text", text: "No result provided" }],
							isError: true,
							timestamp: Date.now(),
						} as ToolResultMessage);
					}
				}
				pendingToolCalls = [];
				existingToolResultIds = new Set();
			}
			result.push(msg);
		} else {
			result.push(msg);
		}
	}
	return result;
}

View file

@ -0,0 +1,59 @@
import "./providers/register-builtins.js";
import { getApiProvider } from "./api-registry.js";
import type {
Api,
AssistantMessage,
AssistantMessageEventStream,
Context,
Model,
ProviderStreamOptions,
SimpleStreamOptions,
StreamOptions,
} from "./types.js";
export { getEnvApiKey } from "./env-api-keys.js";
/** Look up the registered provider for `api`, throwing a descriptive error when none exists. */
function resolveApiProvider(api: Api) {
	const provider = getApiProvider(api);
	if (provider) {
		return provider;
	}
	throw new Error(`No API provider registered for api: ${api}`);
}
/**
 * Start a streaming request against `model` using its registered API provider.
 * Throws synchronously if no provider is registered for the model's api.
 */
export function stream<TApi extends Api>(
	model: Model<TApi>,
	context: Context,
	options?: ProviderStreamOptions,
): AssistantMessageEventStream {
	return resolveApiProvider(model.api).stream(model, context, options as StreamOptions);
}
/** Run a streaming request to completion and resolve with the final assistant message. */
export async function complete<TApi extends Api>(
	model: Model<TApi>,
	context: Context,
	options?: ProviderStreamOptions,
): Promise<AssistantMessage> {
	return stream(model, context, options).result();
}
/** Start a streaming request using the provider's simplified (reasoning-aware) entry point. */
export function streamSimple<TApi extends Api>(
	model: Model<TApi>,
	context: Context,
	options?: SimpleStreamOptions,
): AssistantMessageEventStream {
	return resolveApiProvider(model.api).streamSimple(model, context, options);
}
/** Run a simplified streaming request to completion and resolve with the final message. */
export async function completeSimple<TApi extends Api>(
	model: Model<TApi>,
	context: Context,
	options?: SimpleStreamOptions,
): Promise<AssistantMessage> {
	return streamSimple(model, context, options).result();
}

321
packages/pi-ai/src/types.ts Normal file
View file

@ -0,0 +1,321 @@
import type { AssistantMessageEventStream } from "./utils/event-stream.js";
export type { AssistantMessageEventStream } from "./utils/event-stream.js";
/** Wire protocols that ship with built-in streaming implementations. */
export type KnownApi =
	| "openai-completions"
	| "mistral-conversations"
	| "openai-responses"
	| "azure-openai-responses"
	| "openai-codex-responses"
	| "anthropic-messages"
	| "bedrock-converse-stream"
	| "google-generative-ai"
	| "google-gemini-cli"
	| "google-vertex";
/** Any API id: `(string & {})` preserves autocomplete for KnownApi while
 * still accepting custom ids registered by extensions. */
export type Api = KnownApi | (string & {});
/** Provider ids known to the built-in catalog. */
export type KnownProvider =
	| "amazon-bedrock"
	| "anthropic"
	| "google"
	| "google-gemini-cli"
	| "google-antigravity"
	| "google-vertex"
	| "openai"
	| "azure-openai-responses"
	| "openai-codex"
	| "github-copilot"
	| "xai"
	| "groq"
	| "cerebras"
	| "openrouter"
	| "vercel-ai-gateway"
	| "zai"
	| "mistral"
	| "minimax"
	| "minimax-cn"
	| "huggingface"
	| "opencode"
	| "opencode-go"
	| "kimi-coding";
/** Any provider id; deliberately not restricted to the known set. */
export type Provider = KnownProvider | string;
/** Requested reasoning intensity; providers without an "xhigh" tier clamp it to "high". */
export type ThinkingLevel = "minimal" | "low" | "medium" | "high" | "xhigh";
/** Token budgets for each thinking level (token-based providers only) */
export interface ThinkingBudgets {
	minimal?: number;
	low?: number;
	medium?: number;
	high?: number;
}
// Base options all providers share
/** How long prompt-cache entries should be kept, for providers that support caching. */
export type CacheRetention = "none" | "short" | "long";
/** Network transport preference for providers offering more than one. */
export type Transport = "sse" | "websocket" | "auto";
/** Options accepted by every provider's low-level stream()/complete() implementation. */
export interface StreamOptions {
	/** Sampling temperature, forwarded to the provider when set. */
	temperature?: number;
	/** Cap on generated output tokens; when omitted, callers derive a model-based default. */
	maxTokens?: number;
	/** Abort signal for cancelling the in-flight request. */
	signal?: AbortSignal;
	/** Explicit API key; when omitted, providers resolve credentials by other means. */
	apiKey?: string;
	/**
	 * Preferred transport for providers that support multiple transports.
	 * Providers that do not support this option ignore it.
	 */
	transport?: Transport;
	/**
	 * Prompt cache retention preference. Providers map this to their supported values.
	 * Default: "short".
	 */
	cacheRetention?: CacheRetention;
	/**
	 * Optional session identifier for providers that support session-based caching.
	 * Providers can use this to enable prompt caching, request routing, or other
	 * session-aware features. Ignored by providers that don't support it.
	 */
	sessionId?: string;
	/**
	 * Optional callback for inspecting or replacing provider payloads before sending.
	 * Return undefined to keep the payload unchanged.
	 */
	onPayload?: (payload: unknown, model: Model<Api>) => unknown | undefined | Promise<unknown | undefined>;
	/**
	 * Optional custom HTTP headers to include in API requests.
	 * Merged with provider defaults; can override default headers.
	 * Not supported by all providers (e.g., AWS Bedrock uses SDK auth).
	 */
	headers?: Record<string, string>;
	/**
	 * Maximum delay in milliseconds to wait for a retry when the server requests a long wait.
	 * If the server's requested delay exceeds this value, the request fails immediately
	 * with an error containing the requested delay, allowing higher-level retry logic
	 * to handle it with user visibility.
	 * Default: 60000 (60 seconds). Set to 0 to disable the cap.
	 */
	maxRetryDelayMs?: number;
	/**
	 * Optional metadata to include in API requests.
	 * Providers extract the fields they understand and ignore the rest.
	 * For example, Anthropic uses `user_id` for abuse tracking and rate limiting.
	 */
	metadata?: Record<string, unknown>;
}
/** StreamOptions plus arbitrary provider-specific extras, passed through untyped. */
export type ProviderStreamOptions = StreamOptions & Record<string, unknown>;
// Unified options with reasoning passed to streamSimple() and completeSimple()
export interface SimpleStreamOptions extends StreamOptions {
	/** Desired thinking level for this request. */
	reasoning?: ThinkingLevel;
	/** Custom token budgets for thinking levels (token-based providers only) */
	thinkingBudgets?: ThinkingBudgets;
}
// Generic StreamFunction with typed options
export type StreamFunction<TApi extends Api = Api, TOptions extends StreamOptions = StreamOptions> = (
	model: Model<TApi>,
	context: Context,
	options?: TOptions,
) => AssistantMessageEventStream;
/** Structured form of TextContent.textSignature (version 1). */
export interface TextSignatureV1 {
	v: 1;
	id: string;
	phase?: "commentary" | "final_answer";
}
/** Plain text block within a message. */
export interface TextContent {
	type: "text";
	text: string;
	textSignature?: string; // e.g., for OpenAI responses, message metadata (legacy id string or TextSignatureV1 JSON)
}
/** Model reasoning block; signatures bind it to the model that produced it. */
export interface ThinkingContent {
	type: "thinking";
	thinking: string;
	thinkingSignature?: string; // e.g., for OpenAI responses, the reasoning item ID
	/** When true, the thinking content was redacted by safety filters. The opaque
	 * encrypted payload is stored in `thinkingSignature` so it can be passed back
	 * to the API for multi-turn continuity. */
	redacted?: boolean;
}
/** Inline image payload. */
export interface ImageContent {
	type: "image";
	data: string; // base64 encoded image data
	mimeType: string; // e.g., "image/jpeg", "image/png"
}
/** A tool invocation requested by the model. */
export interface ToolCall {
	type: "toolCall";
	id: string;
	name: string;
	arguments: Record<string, any>;
	thoughtSignature?: string; // Google-specific: opaque signature for reusing thought context
}
/** Token counts and dollar costs for one assistant turn. */
export interface Usage {
	input: number;
	output: number;
	cacheRead: number;
	cacheWrite: number;
	totalTokens: number;
	cost: {
		input: number;
		output: number;
		cacheRead: number;
		cacheWrite: number;
		total: number;
	};
}
/** Why generation stopped; "toolUse" means the model requested tool calls. */
export type StopReason = "stop" | "length" | "toolUse" | "error" | "aborted";
export interface UserMessage {
	role: "user";
	content: string | (TextContent | ImageContent)[];
	timestamp: number; // Unix timestamp in milliseconds
}
/** One completed (or failed) assistant turn, including usage accounting. */
export interface AssistantMessage {
	role: "assistant";
	content: (TextContent | ThinkingContent | ToolCall)[];
	api: Api;
	provider: Provider;
	model: string;
	usage: Usage;
	stopReason: StopReason;
	errorMessage?: string; // populated when stopReason is "error" — presumably also "aborted"; confirm
	timestamp: number; // Unix timestamp in milliseconds
}
/** Result of executing one tool call, keyed back to it by toolCallId. */
export interface ToolResultMessage<TDetails = any> {
	role: "toolResult";
	toolCallId: string;
	toolName: string;
	content: (TextContent | ImageContent)[]; // Supports text and images
	details?: TDetails;
	isError: boolean;
	timestamp: number; // Unix timestamp in milliseconds
}
/** Any message that can appear in a conversation history. */
export type Message = UserMessage | AssistantMessage | ToolResultMessage;
// NOTE(review): mid-file import — imports are hoisted so this works, but
// consider moving it to the top of the file with the other imports.
import type { TSchema } from "@sinclair/typebox";
/** A tool the model may call; parameters are described as a TypeBox schema. */
export interface Tool<TParameters extends TSchema = TSchema> {
	name: string;
	description: string;
	parameters: TParameters;
}
/** Full prompt state for a request: system prompt, history, and callable tools. */
export interface Context {
	systemPrompt?: string;
	messages: Message[];
	tools?: Tool[];
}
/**
 * Streaming lifecycle events emitted while an assistant message is produced.
 * Every event carries the accumulating partial message; "done" and "error"
 * are terminal and carry the final message.
 */
export type AssistantMessageEvent =
	| { type: "start"; partial: AssistantMessage }
	| { type: "text_start"; contentIndex: number; partial: AssistantMessage }
	| { type: "text_delta"; contentIndex: number; delta: string; partial: AssistantMessage }
	| { type: "text_end"; contentIndex: number; content: string; partial: AssistantMessage }
	| { type: "thinking_start"; contentIndex: number; partial: AssistantMessage }
	| { type: "thinking_delta"; contentIndex: number; delta: string; partial: AssistantMessage }
	| { type: "thinking_end"; contentIndex: number; content: string; partial: AssistantMessage }
	| { type: "toolcall_start"; contentIndex: number; partial: AssistantMessage }
	| { type: "toolcall_delta"; contentIndex: number; delta: string; partial: AssistantMessage }
	| { type: "toolcall_end"; contentIndex: number; toolCall: ToolCall; partial: AssistantMessage }
	| { type: "done"; reason: Extract<StopReason, "stop" | "length" | "toolUse">; message: AssistantMessage }
	| { type: "error"; reason: Extract<StopReason, "aborted" | "error">; error: AssistantMessage };
/**
 * Compatibility settings for OpenAI-compatible completions APIs.
 * Use this to override URL-based auto-detection for custom providers.
 */
export interface OpenAICompletionsCompat {
	/** Whether the provider supports the `store` field. Default: auto-detected from URL. */
	supportsStore?: boolean;
	/** Whether the provider supports the `developer` role (vs `system`). Default: auto-detected from URL. */
	supportsDeveloperRole?: boolean;
	/** Whether the provider supports `reasoning_effort`. Default: auto-detected from URL. */
	supportsReasoningEffort?: boolean;
	/** Optional mapping from pi-ai reasoning levels to provider/model-specific `reasoning_effort` values. */
	reasoningEffortMap?: Partial<Record<ThinkingLevel, string>>;
	/** Whether the provider supports `stream_options: { include_usage: true }` for token usage in streaming responses. Default: true. */
	supportsUsageInStreaming?: boolean;
	/** Which field to use for max tokens. Default: auto-detected from URL. */
	maxTokensField?: "max_completion_tokens" | "max_tokens";
	/** Whether tool results require the `name` field. Default: auto-detected from URL. */
	requiresToolResultName?: boolean;
	/** Whether a user message after tool results requires an assistant message in between. Default: auto-detected from URL. */
	requiresAssistantAfterToolResult?: boolean;
	/** Whether thinking blocks must be converted to text blocks with <thinking> delimiters. Default: auto-detected from URL. */
	requiresThinkingAsText?: boolean;
	/** Format for reasoning/thinking parameter. "openai" uses reasoning_effort, "zai" uses thinking: { type: "enabled" }, "qwen" uses enable_thinking: boolean. Default: "openai". */
	thinkingFormat?: "openai" | "zai" | "qwen";
	/** OpenRouter-specific routing preferences. Only used when baseUrl points to OpenRouter. */
	openRouterRouting?: OpenRouterRouting;
	/** Vercel AI Gateway routing preferences. Only used when baseUrl points to Vercel AI Gateway. */
	vercelGatewayRouting?: VercelGatewayRouting;
	/** Whether the provider supports the `strict` field in tool definitions. Default: true. */
	supportsStrictMode?: boolean;
}
/** Compatibility settings for OpenAI Responses APIs. */
export interface OpenAIResponsesCompat {
	// Reserved for future use
}
/**
 * OpenRouter provider routing preferences.
 * Controls which upstream providers OpenRouter routes requests to.
 * @see https://openrouter.ai/docs/provider-routing
 */
export interface OpenRouterRouting {
	/** List of provider slugs to exclusively use for this request (e.g., ["amazon-bedrock", "anthropic"]). */
	only?: string[];
	/** List of provider slugs to try in order (e.g., ["anthropic", "openai"]). */
	order?: string[];
}
/**
 * Vercel AI Gateway routing preferences.
 * Controls which upstream providers the gateway routes requests to.
 * @see https://vercel.com/docs/ai-gateway/models-and-providers/provider-options
 */
export interface VercelGatewayRouting {
	/** List of provider slugs to exclusively use for this request (e.g., ["bedrock", "anthropic"]). */
	only?: string[];
	/** List of provider slugs to try in order (e.g., ["anthropic", "openai"]). */
	order?: string[];
}
// Model interface for the unified model system
/** Catalog entry describing a model: identity, endpoint, capabilities, pricing, and limits. */
export interface Model<TApi extends Api> {
	id: string;
	name: string;
	api: TApi;
	provider: Provider;
	baseUrl: string;
	reasoning: boolean; // whether the model supports thinking/reasoning output
	input: ("text" | "image")[];
	cost: {
		input: number; // $/million tokens
		output: number; // $/million tokens
		cacheRead: number; // $/million tokens
		cacheWrite: number; // $/million tokens
	};
	contextWindow: number;
	maxTokens: number;
	headers?: Record<string, string>;
	/** Compatibility overrides for OpenAI-compatible APIs. If not set, auto-detected from baseUrl. */
	compat?: TApi extends "openai-completions"
		? OpenAICompletionsCompat
		: TApi extends "openai-responses"
			? OpenAIResponsesCompat
			: never;
}

View file

@ -0,0 +1,87 @@
import type { AssistantMessage, AssistantMessageEvent } from "../types.js";
// Generic event stream class for async iteration
/**
 * Push-based async event stream with a separately awaitable final result.
 *
 * Producers call push() per event and end() when no more will come; consumers
 * use `for await` and/or await result(). `isComplete` identifies the terminal
 * event, whose `extractResult` value resolves result().
 *
 * NOTE(review): each pushed event is handed to at most one waiting consumer
 * (one waiter is shifted per push), so this is effectively single-consumer.
 * Also, end() called without a result and without a prior terminal event
 * leaves result() pending forever — producers must push a terminal event or
 * pass a result to end().
 */
export class EventStream<T, R = T> implements AsyncIterable<T> {
	private queue: T[] = []; // events buffered until a consumer asks for them
	private waiting: ((value: IteratorResult<T>) => void)[] = []; // parked consumers
	private done = false;
	private finalResultPromise: Promise<R>;
	private resolveFinalResult!: (result: R) => void;
	constructor(
		private isComplete: (event: T) => boolean,
		private extractResult: (event: T) => R,
	) {
		this.finalResultPromise = new Promise((resolve) => {
			this.resolveFinalResult = resolve;
		});
	}
	/** Deliver an event; a terminal event also resolves result(). No-op once the stream is done. */
	push(event: T): void {
		if (this.done) return;
		if (this.isComplete(event)) {
			this.done = true;
			this.resolveFinalResult(this.extractResult(event));
		}
		// Deliver to waiting consumer or queue it
		const waiter = this.waiting.shift();
		if (waiter) {
			waiter({ value: event, done: false });
		} else {
			this.queue.push(event);
		}
	}
	/** Close the stream, optionally resolving result(), and wake all parked consumers. */
	end(result?: R): void {
		this.done = true;
		if (result !== undefined) {
			this.resolveFinalResult(result);
		}
		// Notify all waiting consumers that we're done
		while (this.waiting.length > 0) {
			const waiter = this.waiting.shift()!;
			waiter({ value: undefined as any, done: true });
		}
	}
	async *[Symbol.asyncIterator](): AsyncIterator<T> {
		while (true) {
			if (this.queue.length > 0) {
				yield this.queue.shift()!;
			} else if (this.done) {
				return;
			} else {
				const result = await new Promise<IteratorResult<T>>((resolve) => this.waiting.push(resolve));
				if (result.done) return;
				yield result.value;
			}
		}
	}
	/** Resolves with the terminal event's extracted value (or the value passed to end()). */
	result(): Promise<R> {
		return this.finalResultPromise;
	}
}
/**
 * EventStream specialized for assistant message events: the stream completes
 * on a "done" or "error" event, and result() resolves with that event's
 * final message.
 */
export class AssistantMessageEventStream extends EventStream<AssistantMessageEvent, AssistantMessage> {
	constructor() {
		const isTerminal = (event: AssistantMessageEvent) => event.type === "done" || event.type === "error";
		const toResult = (event: AssistantMessageEvent): AssistantMessage => {
			switch (event.type) {
				case "done":
					return event.message;
				case "error":
					return event.error;
				default:
					throw new Error("Unexpected event type for final result");
			}
		};
		super(isTerminal, toResult);
	}
}
/** Factory function for AssistantMessageEventStream (for use in extensions).
 * @returns a fresh, unstarted event stream. */
export function createAssistantMessageEventStream(): AssistantMessageEventStream {
	return new AssistantMessageEventStream();
}

View file

@ -0,0 +1,13 @@
/**
 * Fast deterministic non-cryptographic hash used to shorten long strings into
 * a compact base-36 token. Iterates UTF-16 code units, so the result is stable
 * across runs and platforms.
 */
export function shortHash(str: string): string {
	let lo = 0xdeadbeef;
	let hi = 0x41c6ce57;
	for (let i = 0; i < str.length; i++) {
		const code = str.charCodeAt(i);
		lo = Math.imul(lo ^ code, 2654435761);
		hi = Math.imul(hi ^ code, 1597334677);
	}
	// Final avalanche: note the second line deliberately mixes in the already
	// updated `lo`.
	lo = Math.imul(lo ^ (lo >>> 16), 2246822507) ^ Math.imul(hi ^ (hi >>> 13), 3266489909);
	hi = Math.imul(hi ^ (hi >>> 16), 2246822507) ^ Math.imul(lo ^ (lo >>> 13), 3266489909);
	return (hi >>> 0).toString(36) + (lo >>> 0).toString(36);
}

View file

@ -0,0 +1,28 @@
import { parse as partialParse } from "partial-json";
/**
 * Parse JSON that may be truncated mid-stream.
 * Always yields a valid value: complete JSON parses normally, incomplete JSON
 * is recovered via partial-json, and anything unparseable becomes `{}`.
 *
 * @param partialJson The partial JSON string from streaming
 * @returns Parsed object, or an empty object when nothing can be recovered
 */
export function parseStreamingJson<T = any>(partialJson: string | undefined): T {
	if (!partialJson || partialJson.trim().length === 0) {
		return {} as T;
	}
	try {
		// Fast path: the buffer already holds complete, valid JSON.
		return JSON.parse(partialJson) as T;
	} catch {
		try {
			// Slow path: salvage whatever prefix structure partial-json can recover.
			return (partialParse(partialJson) ?? {}) as T;
		} catch {
			return {} as T;
		}
	}
}

View file

@ -0,0 +1,138 @@
/**
* Anthropic OAuth flow (Claude Pro/Max)
*/
import { generatePKCE } from "./pkce.js";
import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.js";
// Base64 decode shorthand for the obfuscated client ID below.
const decode = (s: string) => atob(s);
// NOTE(review): decoded at runtime, presumably to keep the ID out of plain-text
// scans; the client ID of a public OAuth client is not a secret.
const CLIENT_ID = decode("OWQxYzI1MGEtZTYxYi00NGQ5LTg4ZWQtNTk0NGQxOTYyZjVl");
const AUTHORIZE_URL = "https://claude.ai/oauth/authorize";
const TOKEN_URL = "https://console.anthropic.com/v1/oauth/token";
const REDIRECT_URI = "https://console.anthropic.com/oauth/code/callback";
const SCOPES = "org:create_api_key user:profile user:inference";
/**
 * Login with Anthropic OAuth (authorization code + PKCE, manual code paste).
 *
 * The user opens the authorization URL, approves access, and pastes back a
 * "code#state" string which is exchanged for access/refresh tokens.
 *
 * @param onAuthUrl - Callback to handle the authorization URL (e.g., open browser)
 * @param onPromptCode - Callback to prompt user for the authorization code
 * @returns OAuth credentials with the expiry stamped 5 minutes early
 * @throws Error when the token exchange responds non-2xx
 */
export async function loginAnthropic(
	onAuthUrl: (url: string) => void,
	onPromptCode: () => Promise<string>,
): Promise<OAuthCredentials> {
	const { verifier, challenge } = await generatePKCE();
	// Build authorization URL
	const authParams = new URLSearchParams({
		code: "true",
		client_id: CLIENT_ID,
		response_type: "code",
		redirect_uri: REDIRECT_URI,
		scope: SCOPES,
		code_challenge: challenge,
		code_challenge_method: "S256",
		state: verifier, // the PKCE verifier doubles as the state parameter
	});
	const authUrl = `${AUTHORIZE_URL}?${authParams.toString()}`;
	// Notify caller with URL to open
	onAuthUrl(authUrl);
	// Wait for user to paste authorization code (format: code#state)
	// NOTE(review): a paste without "#" leaves `state` undefined; presumably the
	// server rejects that exchange — confirm and consider a friendlier error.
	const authCode = await onPromptCode();
	const splits = authCode.split("#");
	const code = splits[0];
	const state = splits[1];
	// Exchange code for tokens
	const tokenResponse = await fetch(TOKEN_URL, {
		method: "POST",
		headers: {
			"Content-Type": "application/json",
		},
		body: JSON.stringify({
			grant_type: "authorization_code",
			client_id: CLIENT_ID,
			code: code,
			state: state,
			redirect_uri: REDIRECT_URI,
			code_verifier: verifier,
		}),
	});
	if (!tokenResponse.ok) {
		const error = await tokenResponse.text();
		throw new Error(`Token exchange failed: ${error}`);
	}
	const tokenData = (await tokenResponse.json()) as {
		access_token: string;
		refresh_token: string;
		expires_in: number;
	};
	// Calculate expiry time (current time + expires_in seconds - 5 min buffer)
	const expiresAt = Date.now() + tokenData.expires_in * 1000 - 5 * 60 * 1000;
	// Save credentials
	return {
		refresh: tokenData.refresh_token,
		access: tokenData.access_token,
		expires: expiresAt,
	};
}
/**
 * Exchange an Anthropic refresh token for fresh OAuth credentials.
 * @throws Error with the response body when the refresh endpoint returns non-2xx
 */
export async function refreshAnthropicToken(refreshToken: string): Promise<OAuthCredentials> {
	const payload = {
		grant_type: "refresh_token",
		client_id: CLIENT_ID,
		refresh_token: refreshToken,
	};
	const response = await fetch(TOKEN_URL, {
		method: "POST",
		headers: { "Content-Type": "application/json" },
		body: JSON.stringify(payload),
	});
	if (!response.ok) {
		const error = await response.text();
		throw new Error(`Anthropic token refresh failed: ${error}`);
	}
	const data = (await response.json()) as {
		access_token: string;
		refresh_token: string;
		expires_in: number;
	};
	// Stamp the expiry 5 minutes early so callers refresh before it actually lapses.
	const expires = Date.now() + data.expires_in * 1000 - 5 * 60 * 1000;
	return {
		refresh: data.refresh_token,
		access: data.access_token,
		expires,
	};
}
/**
 * OAuthProviderInterface implementation for Anthropic (Claude Pro/Max):
 * wires the manual-paste login flow and token refresh into the generic OAuth
 * provider interface; the raw access token is used directly as the API key.
 */
export const anthropicOAuthProvider: OAuthProviderInterface = {
	id: "anthropic",
	name: "Anthropic (Claude Pro/Max)",
	async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
		return loginAnthropic(
			(url) => callbacks.onAuth({ url }),
			() => callbacks.onPrompt({ message: "Paste the authorization code:" }),
		);
	},
	async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
		return refreshAnthropicToken(credentials.refresh);
	},
	getApiKey(credentials: OAuthCredentials): string {
		return credentials.access;
	},
};

View file

@ -0,0 +1,381 @@
/**
* GitHub Copilot OAuth flow
*/
import { getModels } from "../../models.js";
import type { Api, Model } from "../../types.js";
import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.js";
/** OAuth credentials extended with the GitHub Enterprise URL when one was used. */
type CopilotCredentials = OAuthCredentials & {
	enterpriseUrl?: string;
};
// Base64 decode shorthand for the obfuscated client ID below.
const decode = (s: string) => atob(s);
// NOTE(review): decoded at runtime, presumably to keep the ID out of plain-text
// scans; a device-flow OAuth client ID is not a secret.
const CLIENT_ID = decode("SXYxLmI1MDdhMDhjODdlY2ZlOTg=");
// Headers identifying the client as the VS Code Copilot Chat extension.
const COPILOT_HEADERS = {
	"User-Agent": "GitHubCopilotChat/0.35.0",
	"Editor-Version": "vscode/1.107.0",
	"Editor-Plugin-Version": "copilot-chat/0.35.0",
	"Copilot-Integration-Id": "vscode-chat",
} as const;
/** Response from GitHub's device-code endpoint. */
type DeviceCodeResponse = {
	device_code: string;
	user_code: string;
	verification_uri: string;
	interval: number;
	expires_in: number;
};
/** Successful device-token poll response. */
type DeviceTokenSuccessResponse = {
	access_token: string;
	token_type?: string;
	scope?: string;
};
/** Error-shaped device-token poll response; `interval` may request slower polling. */
type DeviceTokenErrorResponse = {
	error: string;
	error_description?: string;
	interval?: number;
};
/**
 * Extract a bare hostname from user input such as "company.ghe.com" or
 * "https://company.ghe.com/path". Returns null for blank or unparseable input.
 */
export function normalizeDomain(input: string): string | null {
  const value = input.trim();
  if (value.length === 0) {
    return null;
  }
  try {
    const withScheme = value.includes("://") ? value : `https://${value}`;
    return new URL(withScheme).hostname;
  } catch {
    return null;
  }
}
/** Build the three GitHub endpoints (device code, access token, Copilot token) for a domain. */
function getUrls(domain: string): {
  deviceCodeUrl: string;
  accessTokenUrl: string;
  copilotTokenUrl: string;
} {
  const base = `https://${domain}`;
  return {
    deviceCodeUrl: `${base}/login/device/code`,
    accessTokenUrl: `${base}/login/oauth/access_token`,
    copilotTokenUrl: `https://api.${domain}/copilot_internal/v2/token`,
  };
}
/**
 * Parse the proxy-ep field from a Copilot token and turn it into an API base URL.
 * Token format: tid=...;exp=...;proxy-ep=proxy.individual.githubcopilot.com;...
 * The proxy host maps to the API host by swapping the "proxy." prefix for "api.".
 * Returns null when the token carries no proxy-ep field.
 */
function getBaseUrlFromToken(token: string): string | null {
  const found = /proxy-ep=([^;]+)/.exec(token);
  if (found === null) {
    return null;
  }
  const apiHost = found[1].replace(/^proxy\./, "api.");
  return `https://${apiHost}`;
}
/**
 * Resolve the Copilot API base URL: prefer the proxy-ep embedded in the token,
 * then the enterprise-specific host, then the public individual endpoint.
 */
export function getGitHubCopilotBaseUrl(token?: string, enterpriseDomain?: string): string {
  // A token's proxy-ep is the most specific source of truth.
  const fromToken = token ? getBaseUrlFromToken(token) : null;
  if (fromToken) {
    return fromToken;
  }
  return enterpriseDomain
    ? `https://copilot-api.${enterpriseDomain}`
    : "https://api.individual.githubcopilot.com";
}
/**
 * fetch() wrapper that throws a descriptive Error (status + body) on any
 * non-2xx response and otherwise resolves with the parsed JSON as unknown.
 */
async function fetchJson(url: string, init: RequestInit): Promise<unknown> {
  const res = await fetch(url, init);
  if (res.ok) {
    return res.json();
  }
  const body = await res.text();
  throw new Error(`${res.status} ${res.statusText}: ${body}`);
}
/**
 * Kick off GitHub's device-code flow and validate the response shape.
 *
 * @throws Error("Invalid device code response") when the payload is not an object.
 * @throws Error("Invalid device code response fields") when required fields are missing/mistyped.
 */
async function startDeviceFlow(domain: string): Promise<DeviceCodeResponse> {
  const { deviceCodeUrl } = getUrls(domain);
  const payload = await fetchJson(deviceCodeUrl, {
    method: "POST",
    headers: {
      Accept: "application/json",
      "Content-Type": "application/json",
      "User-Agent": "GitHubCopilotChat/0.35.0",
    },
    body: JSON.stringify({
      client_id: CLIENT_ID,
      scope: "read:user",
    }),
  });
  if (!payload || typeof payload !== "object") {
    throw new Error("Invalid device code response");
  }
  const record = payload as Record<string, unknown>;
  const device_code = record.device_code;
  const user_code = record.user_code;
  const verification_uri = record.verification_uri;
  const interval = record.interval;
  const expires_in = record.expires_in;
  if (
    typeof device_code !== "string" ||
    typeof user_code !== "string" ||
    typeof verification_uri !== "string" ||
    typeof interval !== "number" ||
    typeof expires_in !== "number"
  ) {
    throw new Error("Invalid device code response fields");
  }
  return { device_code, user_code, verification_uri, interval, expires_in };
}
/**
 * Sleep for `ms` milliseconds, rejecting with "Login cancelled" if the given
 * AbortSignal fires first (or was already aborted).
 *
 * The abort listener is detached once the timer completes. Without that, each
 * call leaves a live listener on the signal, and the device-flow poller —
 * which reuses one signal across many sleeps — accumulates listeners until
 * Node emits MaxListenersExceededWarning.
 */
function abortableSleep(ms: number, signal?: AbortSignal): Promise<void> {
  return new Promise((resolve, reject) => {
    if (signal?.aborted) {
      reject(new Error("Login cancelled"));
      return;
    }
    const onAbort = () => {
      clearTimeout(timer);
      reject(new Error("Login cancelled"));
    };
    const timer = setTimeout(() => {
      // Clean up so reused signals don't accumulate dead listeners.
      signal?.removeEventListener("abort", onAbort);
      resolve();
    }, ms);
    signal?.addEventListener("abort", onAbort, { once: true });
  });
}
/**
 * Poll GitHub's token endpoint until the user authorizes the device code,
 * the code expires, or the signal aborts. Honors "authorization_pending"
 * (keep waiting) and "slow_down" (stretch the polling interval by 5s).
 *
 * @returns The GitHub OAuth access token.
 * @throws Error("Login cancelled") on abort, Error("Device flow timed out")
 *   past the expiry deadline, or Error("Device flow failed: ...") on any
 *   other error code from the endpoint.
 */
async function pollForGitHubAccessToken(
  domain: string,
  deviceCode: string,
  intervalSeconds: number,
  expiresIn: number,
  signal?: AbortSignal,
) {
  const { accessTokenUrl } = getUrls(domain);
  const deadline = Date.now() + expiresIn * 1000;
  // Never poll faster than once per second, whatever the server suggested.
  let waitMs = Math.max(1000, Math.floor(intervalSeconds * 1000));
  while (Date.now() < deadline) {
    if (signal?.aborted) {
      throw new Error("Login cancelled");
    }
    const payload = await fetchJson(accessTokenUrl, {
      method: "POST",
      headers: {
        Accept: "application/json",
        "Content-Type": "application/json",
        "User-Agent": "GitHubCopilotChat/0.35.0",
      },
      body: JSON.stringify({
        client_id: CLIENT_ID,
        device_code: deviceCode,
        grant_type: "urn:ietf:params:oauth:grant-type:device_code",
      }),
    });
    const isObject = Boolean(payload) && typeof payload === "object";
    if (isObject && typeof (payload as DeviceTokenSuccessResponse).access_token === "string") {
      return (payload as DeviceTokenSuccessResponse).access_token;
    }
    if (isObject && typeof (payload as DeviceTokenErrorResponse).error === "string") {
      const code = (payload as DeviceTokenErrorResponse).error;
      if (code === "slow_down") {
        waitMs += 5000;
      } else if (code !== "authorization_pending") {
        throw new Error(`Device flow failed: ${code}`);
      }
      await abortableSleep(waitMs, signal);
      continue;
    }
    // Unrecognized payload: wait and retry until the deadline.
    await abortableSleep(waitMs, signal);
  }
  throw new Error("Device flow timed out");
}
/**
 * Exchange the long-lived GitHub OAuth token (stored in `refresh`) for a
 * short-lived Copilot token (stored in `access`). The GitHub token is kept
 * unchanged for future refreshes; expiry keeps a 5-minute safety buffer.
 */
export async function refreshGitHubCopilotToken(
  refreshToken: string,
  enterpriseDomain?: string,
): Promise<OAuthCredentials> {
  const { copilotTokenUrl } = getUrls(enterpriseDomain || "github.com");
  const payload = await fetchJson(copilotTokenUrl, {
    headers: {
      Accept: "application/json",
      Authorization: `Bearer ${refreshToken}`,
      ...COPILOT_HEADERS,
    },
  });
  if (!payload || typeof payload !== "object") {
    throw new Error("Invalid Copilot token response");
  }
  const { token, expires_at: expiresAt } = payload as Record<string, unknown>;
  if (typeof token !== "string" || typeof expiresAt !== "number") {
    throw new Error("Invalid Copilot token response fields");
  }
  return {
    refresh: refreshToken,
    access: token,
    // expires_at is in epoch seconds; convert to ms and pad by 5 minutes.
    expires: expiresAt * 1000 - 5 * 60 * 1000,
    enterpriseUrl: enterpriseDomain,
  };
}
/**
 * Opt the account into a model via the per-model policy endpoint. Some models
 * (like Claude, Grok) must be enabled once before they can be used.
 * Network errors are swallowed; returns whether the request succeeded.
 */
async function enableGitHubCopilotModel(token: string, modelId: string, enterpriseDomain?: string): Promise<boolean> {
  const endpoint = `${getGitHubCopilotBaseUrl(token, enterpriseDomain)}/models/${modelId}/policy`;
  try {
    const res = await fetch(endpoint, {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        Authorization: `Bearer ${token}`,
        ...COPILOT_HEADERS,
        "openai-intent": "chat-policy",
        "x-interaction-type": "chat-policy",
      },
      body: JSON.stringify({ state: "enabled" }),
    });
    return res.ok;
  } catch {
    // Best effort: a failed enable just means the model may not work later.
    return false;
  }
}
/**
* Enable all known GitHub Copilot models that may require policy acceptance.
* Called after successful login to ensure all models are available.
*/
async function enableAllGitHubCopilotModels(
token: string,
enterpriseDomain?: string,
onProgress?: (model: string, success: boolean) => void,
): Promise<void> {
const models = getModels("github-copilot");
await Promise.all(
models.map(async (model) => {
const success = await enableGitHubCopilotModel(token, model.id, enterpriseDomain);
onProgress?.(model.id, success);
}),
);
}
/**
* Login with GitHub Copilot OAuth (device code flow)
*
* @param options.onAuth - Callback with URL and optional instructions (user code)
* @param options.onPrompt - Callback to prompt user for input
* @param options.onProgress - Optional progress callback
* @param options.signal - Optional AbortSignal for cancellation
*/
export async function loginGitHubCopilot(options: {
onAuth: (url: string, instructions?: string) => void;
onPrompt: (prompt: { message: string; placeholder?: string; allowEmpty?: boolean }) => Promise<string>;
onProgress?: (message: string) => void;
signal?: AbortSignal;
}): Promise<OAuthCredentials> {
const input = await options.onPrompt({
message: "GitHub Enterprise URL/domain (blank for github.com)",
placeholder: "company.ghe.com",
allowEmpty: true,
});
if (options.signal?.aborted) {
throw new Error("Login cancelled");
}
const trimmed = input.trim();
const enterpriseDomain = normalizeDomain(input);
if (trimmed && !enterpriseDomain) {
throw new Error("Invalid GitHub Enterprise URL/domain");
}
const domain = enterpriseDomain || "github.com";
const device = await startDeviceFlow(domain);
options.onAuth(device.verification_uri, `Enter code: ${device.user_code}`);
const githubAccessToken = await pollForGitHubAccessToken(
domain,
device.device_code,
device.interval,
device.expires_in,
options.signal,
);
const credentials = await refreshGitHubCopilotToken(githubAccessToken, enterpriseDomain ?? undefined);
// Enable all models after successful login
options.onProgress?.("Enabling models...");
await enableAllGitHubCopilotModels(credentials.access, enterpriseDomain ?? undefined);
return credentials;
}
// OAuth provider descriptor for GitHub Copilot (device-code flow).
export const githubCopilotOAuthProvider: OAuthProviderInterface = {
  id: "github-copilot",
  name: "GitHub Copilot",
  async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
    return loginGitHubCopilot({
      onAuth: (url, instructions) => callbacks.onAuth({ url, instructions }),
      onPrompt: callbacks.onPrompt,
      onProgress: callbacks.onProgress,
      signal: callbacks.signal,
    });
  },
  async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
    // `refresh` holds the long-lived GitHub token; `enterpriseUrl` (if set)
    // selects the GHE endpoints.
    const creds = credentials as CopilotCredentials;
    return refreshGitHubCopilotToken(creds.refresh, creds.enterpriseUrl);
  },
  getApiKey(credentials: OAuthCredentials): string {
    return credentials.access;
  },
  // Points Copilot model entries at the account-specific base URL derived
  // from the token's proxy-ep (or the enterprise domain).
  modifyModels(models: Model<Api>[], credentials: OAuthCredentials): Model<Api>[] {
    const creds = credentials as CopilotCredentials;
    const domain = creds.enterpriseUrl ? (normalizeDomain(creds.enterpriseUrl) ?? undefined) : undefined;
    const baseUrl = getGitHubCopilotBaseUrl(creds.access, domain);
    return models.map((m) => (m.provider === "github-copilot" ? { ...m, baseUrl } : m));
  },
};

View file

@ -0,0 +1,457 @@
/**
* Antigravity OAuth flow (Gemini 3, Claude, GPT-OSS via Google Cloud)
* Uses different OAuth credentials than google-gemini-cli for access to additional models.
*
* NOTE: This module uses Node.js http.createServer for the OAuth callback.
* It is only intended for CLI use, not browser environments.
*/
import type { Server } from "node:http";
import { generatePKCE } from "./pkce.js";
import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.js";
// Antigravity credentials additionally pin the Google Cloud project that was
// discovered or provisioned at login time.
type AntigravityCredentials = OAuthCredentials & {
  projectId: string;
};
// node:http is imported lazily and only under Node/Bun, so importing this
// module elsewhere never pulls in Node builtins (see file header note).
let _createServer: typeof import("node:http").createServer | null = null;
let _httpImportPromise: Promise<void> | null = null;
if (typeof process !== "undefined" && (process.versions?.node || process.versions?.bun)) {
  _httpImportPromise = import("node:http").then((m) => {
    _createServer = m.createServer;
  });
}
// Antigravity OAuth credentials (different from Gemini CLI).
// base64 is light obfuscation only, not secrecy.
const decode = (s: string) => atob(s);
const CLIENT_ID = decode(
  "MTA3MTAwNjA2MDU5MS10bWhzc2luMmgyMWxjcmUyMzV2dG9sb2poNGc0MDNlcC5hcHBzLmdvb2dsZXVzZXJjb250ZW50LmNvbQ==",
);
const CLIENT_SECRET = decode("R09DU1BYLUs1OEZXUjQ4NkxkTEoxbUxCOHNYQzR6NnFEQWY=");
// Must match the port used by the loopback callback server below.
const REDIRECT_URI = "http://localhost:51121/oauth-callback";
// Antigravity requires additional scopes beyond the Gemini CLI set.
const SCOPES = [
  "https://www.googleapis.com/auth/cloud-platform",
  "https://www.googleapis.com/auth/userinfo.email",
  "https://www.googleapis.com/auth/userinfo.profile",
  "https://www.googleapis.com/auth/cclog",
  "https://www.googleapis.com/auth/experimentsandconfigs",
];
const AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth";
const TOKEN_URL = "https://oauth2.googleapis.com/token";
// Fallback project ID when discovery fails.
const DEFAULT_PROJECT_ID = "rising-fact-p41fc";
// Handle to the local callback server plus helpers to await / cancel the
// delivery of the authorization code.
type CallbackServerInfo = {
  server: Server;
  cancelWait: () => void;
  waitForCode: () => Promise<{ code: string; state: string } | null>;
};
/**
 * Resolve node:http's createServer, awaiting the lazy import started at
 * module load. Throws outside Node/Bun, where that import never ran.
 */
async function getNodeCreateServer(): Promise<typeof import("node:http").createServer> {
  if (!_createServer && _httpImportPromise) {
    await _httpImportPromise;
  }
  if (!_createServer) {
    throw new Error("Antigravity OAuth is only available in Node.js environments");
  }
  return _createServer;
}
/**
 * Start a loopback HTTP server on 127.0.0.1:51121 that captures the OAuth
 * redirect. Resolves once the server is listening; rejects on listen errors
 * (e.g. the port is already taken). The returned waitForCode() polls every
 * 100ms until a code arrives or cancelWait() is called, so callers can race
 * it against manual code entry.
 */
async function startCallbackServer(): Promise<CallbackServerInfo> {
  const createServer = await getNodeCreateServer();
  return new Promise((resolve, reject) => {
    let result: { code: string; state: string } | null = null;
    let cancelled = false;
    const server = createServer((req, res) => {
      const url = new URL(req.url || "", `http://localhost:51121`);
      if (url.pathname === "/oauth-callback") {
        const code = url.searchParams.get("code");
        const state = url.searchParams.get("state");
        const error = url.searchParams.get("error");
        // Provider signalled an error (e.g. the user denied consent).
        if (error) {
          res.writeHead(400, { "Content-Type": "text/html" });
          res.end(
            `<html><body><h1>Authentication Failed</h1><p>Error: ${error}</p><p>You can close this window.</p></body></html>`,
          );
          return;
        }
        if (code && state) {
          res.writeHead(200, { "Content-Type": "text/html" });
          res.end(
            `<html><body><h1>Authentication Successful</h1><p>You can close this window and return to the terminal.</p></body></html>`,
          );
          // Stash the result; waitForCode() picks it up on its next poll.
          result = { code, state };
        } else {
          res.writeHead(400, { "Content-Type": "text/html" });
          res.end(
            `<html><body><h1>Authentication Failed</h1><p>Missing code or state parameter.</p></body></html>`,
          );
        }
      } else {
        // Anything else (favicon requests etc.) gets a plain 404.
        res.writeHead(404);
        res.end();
      }
    });
    server.on("error", (err) => {
      reject(err);
    });
    server.listen(51121, "127.0.0.1", () => {
      resolve({
        server,
        cancelWait: () => {
          cancelled = true;
        },
        // Poll in 100ms steps; returns null when cancelled before a code arrived.
        waitForCode: async () => {
          const sleep = () => new Promise((r) => setTimeout(r, 100));
          while (!result && !cancelled) {
            await sleep();
          }
          return result;
        },
      });
    });
  });
}
/**
 * Pull the `code` and `state` query parameters out of a pasted redirect URL.
 * Blank or non-URL input yields an empty object.
 */
function parseRedirectUrl(input: string): { code?: string; state?: string } {
  const trimmed = input.trim();
  if (!trimmed) return {};
  let url: URL;
  try {
    url = new URL(trimmed);
  } catch {
    // Not a URL, return empty
    return {};
  }
  const code = url.searchParams.get("code");
  const state = url.searchParams.get("state");
  return {
    code: code ?? undefined,
    state: state ?? undefined,
  };
}
// Subset of the loadCodeAssist response this module inspects. The companion
// project arrives either as a bare string or as an object carrying an id.
interface LoadCodeAssistPayload {
  cloudaicompanionProject?: string | { id?: string };
  currentTier?: { id?: string };
  allowedTiers?: Array<{ id?: string; isDefault?: boolean }>;
}
/**
 * Discover or provision a project for the user.
 *
 * Tries the production cloudcode endpoint first, then the sandbox endpoint;
 * each may report the companion project as a string or as an object with an
 * id. When neither endpoint yields a project, falls back to a shared default.
 *
 * @param accessToken - OAuth access token for the Google APIs.
 * @param onProgress - Optional status callback for the UI.
 */
async function discoverProject(accessToken: string, onProgress?: (message: string) => void): Promise<string> {
  const headers = {
    Authorization: `Bearer ${accessToken}`,
    "Content-Type": "application/json",
    // Mimics an official client — NOTE(review): presumably required by the
    // private cloudcode API; confirm before changing.
    "User-Agent": "google-api-nodejs-client/9.15.1",
    "X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1",
    "Client-Metadata": JSON.stringify({
      ideType: "IDE_UNSPECIFIED",
      platform: "PLATFORM_UNSPECIFIED",
      pluginType: "GEMINI",
    }),
  };
  // Try endpoints in order: prod first, then sandbox
  const endpoints = ["https://cloudcode-pa.googleapis.com", "https://daily-cloudcode-pa.sandbox.googleapis.com"];
  onProgress?.("Checking for existing project...");
  for (const endpoint of endpoints) {
    try {
      const loadResponse = await fetch(`${endpoint}/v1internal:loadCodeAssist`, {
        method: "POST",
        headers,
        body: JSON.stringify({
          metadata: {
            ideType: "IDE_UNSPECIFIED",
            platform: "PLATFORM_UNSPECIFIED",
            pluginType: "GEMINI",
          },
        }),
      });
      if (loadResponse.ok) {
        const data = (await loadResponse.json()) as LoadCodeAssistPayload;
        // Handle both string and object formats
        if (typeof data.cloudaicompanionProject === "string" && data.cloudaicompanionProject) {
          return data.cloudaicompanionProject;
        }
        if (
          data.cloudaicompanionProject &&
          typeof data.cloudaicompanionProject === "object" &&
          data.cloudaicompanionProject.id
        ) {
          return data.cloudaicompanionProject.id;
        }
      }
    } catch {
      // Try next endpoint
    }
  }
  // Use fallback project ID
  onProgress?.("Using default project...");
  return DEFAULT_PROJECT_ID;
}
/**
 * Best-effort lookup of the signed-in user's email via the userinfo endpoint.
 * Any failure (network or non-OK status) yields undefined — email is optional.
 */
async function getUserEmail(accessToken: string): Promise<string | undefined> {
  try {
    const res = await fetch("https://www.googleapis.com/oauth2/v1/userinfo?alt=json", {
      headers: { Authorization: `Bearer ${accessToken}` },
    });
    if (!res.ok) {
      return undefined;
    }
    const payload = (await res.json()) as { email?: string };
    return payload.email;
  } catch {
    return undefined;
  }
}
/**
 * Refresh the Antigravity access token using the stored refresh token.
 * Google may or may not rotate the refresh token; the old one is kept when
 * no new one is returned. Expiry keeps a 5-minute safety buffer.
 */
export async function refreshAntigravityToken(refreshToken: string, projectId: string): Promise<OAuthCredentials> {
  const form = new URLSearchParams({
    client_id: CLIENT_ID,
    client_secret: CLIENT_SECRET,
    refresh_token: refreshToken,
    grant_type: "refresh_token",
  });
  const response = await fetch(TOKEN_URL, {
    method: "POST",
    headers: { "Content-Type": "application/x-www-form-urlencoded" },
    body: form,
  });
  if (!response.ok) {
    throw new Error(`Antigravity token refresh failed: ${await response.text()}`);
  }
  const payload = (await response.json()) as {
    access_token: string;
    expires_in: number;
    refresh_token?: string;
  };
  return {
    refresh: payload.refresh_token || refreshToken,
    access: payload.access_token,
    expires: Date.now() + payload.expires_in * 1000 - 5 * 60 * 1000,
    projectId,
  };
}
/**
 * Login with Antigravity OAuth.
 *
 * Opens the Google consent page, then waits for the authorization code to
 * arrive either via the loopback redirect (port 51121) or via a manually
 * pasted redirect URL — whichever settles first wins. After the code
 * exchange, looks up the user's email and discovers/provisions the Google
 * Cloud project that the credentials are bound to.
 *
 * @param onAuth - Callback with URL and optional instructions
 * @param onProgress - Optional progress callback
 * @param onManualCodeInput - Optional promise that resolves with user-pasted redirect URL.
 *   Races with browser callback - whichever completes first wins.
 * @throws On cancellation, state mismatch, failed token exchange, or a
 *   missing authorization code / refresh token.
 */
export async function loginAntigravity(
  onAuth: (info: { url: string; instructions?: string }) => void,
  onProgress?: (message: string) => void,
  onManualCodeInput?: () => Promise<string>,
): Promise<OAuthCredentials> {
  const { verifier, challenge } = await generatePKCE();
  // Start local server for callback
  onProgress?.("Starting local server for OAuth callback...");
  const server = await startCallbackServer();
  let code: string | undefined;
  try {
    // Build authorization URL
    const authParams = new URLSearchParams({
      client_id: CLIENT_ID,
      response_type: "code",
      redirect_uri: REDIRECT_URI,
      scope: SCOPES.join(" "),
      code_challenge: challenge,
      code_challenge_method: "S256",
      // The PKCE verifier doubles as the OAuth `state` CSRF value checked below.
      state: verifier,
      access_type: "offline",
      prompt: "consent",
    });
    const authUrl = `${AUTH_URL}?${authParams.toString()}`;
    // Notify caller with URL to open
    onAuth({
      url: authUrl,
      instructions: "Complete the sign-in in your browser.",
    });
    // Wait for the callback, racing with manual input if provided
    onProgress?.("Waiting for OAuth callback...");
    if (onManualCodeInput) {
      // Race between browser callback and manual input
      let manualInput: string | undefined;
      let manualError: Error | undefined;
      // When the manual prompt settles (either way) it cancels the server
      // wait so waitForCode() below returns promptly.
      const manualPromise = onManualCodeInput()
        .then((input) => {
          manualInput = input;
          server.cancelWait();
        })
        .catch((err) => {
          manualError = err instanceof Error ? err : new Error(String(err));
          server.cancelWait();
        });
      const result = await server.waitForCode();
      // If manual input was cancelled, throw that error
      if (manualError) {
        throw manualError;
      }
      if (result?.code) {
        // Browser callback won - verify state
        if (result.state !== verifier) {
          throw new Error("OAuth state mismatch - possible CSRF attack");
        }
        code = result.code;
      } else if (manualInput) {
        // Manual input won
        const parsed = parseRedirectUrl(manualInput);
        if (parsed.state && parsed.state !== verifier) {
          throw new Error("OAuth state mismatch - possible CSRF attack");
        }
        code = parsed.code;
      }
      // If still no code, wait for manual promise and try that
      if (!code) {
        await manualPromise;
        if (manualError) {
          throw manualError;
        }
        if (manualInput) {
          const parsed = parseRedirectUrl(manualInput);
          if (parsed.state && parsed.state !== verifier) {
            throw new Error("OAuth state mismatch - possible CSRF attack");
          }
          code = parsed.code;
        }
      }
    } else {
      // Original flow: just wait for callback
      const result = await server.waitForCode();
      if (result?.code) {
        if (result.state !== verifier) {
          throw new Error("OAuth state mismatch - possible CSRF attack");
        }
        code = result.code;
      }
    }
    if (!code) {
      throw new Error("No authorization code received");
    }
    // Exchange code for tokens
    onProgress?.("Exchanging authorization code for tokens...");
    const tokenResponse = await fetch(TOKEN_URL, {
      method: "POST",
      headers: {
        "Content-Type": "application/x-www-form-urlencoded",
      },
      body: new URLSearchParams({
        client_id: CLIENT_ID,
        client_secret: CLIENT_SECRET,
        code,
        grant_type: "authorization_code",
        redirect_uri: REDIRECT_URI,
        code_verifier: verifier,
      }),
    });
    if (!tokenResponse.ok) {
      const error = await tokenResponse.text();
      throw new Error(`Token exchange failed: ${error}`);
    }
    const tokenData = (await tokenResponse.json()) as {
      access_token: string;
      refresh_token: string;
      expires_in: number;
    };
    if (!tokenData.refresh_token) {
      throw new Error("No refresh token received. Please try again.");
    }
    // Get user email
    onProgress?.("Getting user info...");
    const email = await getUserEmail(tokenData.access_token);
    // Discover project
    const projectId = await discoverProject(tokenData.access_token, onProgress);
    // Calculate expiry time (current time + expires_in seconds - 5 min buffer)
    const expiresAt = Date.now() + tokenData.expires_in * 1000 - 5 * 60 * 1000;
    const credentials: OAuthCredentials = {
      refresh: tokenData.refresh_token,
      access: tokenData.access_token,
      expires: expiresAt,
      projectId,
      email,
    };
    return credentials;
  } finally {
    // Always tear down the loopback server, success or failure.
    server.server.close();
  }
}
// OAuth provider descriptor for Antigravity (Gemini 3 / Claude / GPT-OSS via
// Google Cloud).
export const antigravityOAuthProvider: OAuthProviderInterface = {
  id: "google-antigravity",
  name: "Antigravity (Gemini 3, Claude, GPT-OSS)",
  // Login spins up a loopback HTTP server to catch the OAuth redirect.
  usesCallbackServer: true,
  async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
    return loginAntigravity(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput);
  },
  async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
    const creds = credentials as AntigravityCredentials;
    // projectId is captured at login; refreshed credentials would be unusable
    // without it, so fail loudly.
    if (!creds.projectId) {
      throw new Error("Antigravity credentials missing projectId");
    }
    return refreshAntigravityToken(creds.refresh, creds.projectId);
  },
  // The "API key" is a JSON envelope: downstream consumers need both the
  // access token and the Google Cloud project id.
  getApiKey(credentials: OAuthCredentials): string {
    const creds = credentials as AntigravityCredentials;
    return JSON.stringify({ token: creds.access, projectId: creds.projectId });
  },
};

View file

@ -0,0 +1,599 @@
/**
* Gemini CLI OAuth flow (Google Cloud Code Assist)
* Standard Gemini models only (gemini-2.0-flash, gemini-2.5-*)
*
* NOTE: This module uses Node.js http.createServer for the OAuth callback.
* It is only intended for CLI use, not browser environments.
*/
import type { Server } from "node:http";
import { generatePKCE } from "./pkce.js";
import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.js";
// Gemini CLI credentials additionally pin the Google Cloud project that was
// discovered or provisioned at login time.
type GeminiCredentials = OAuthCredentials & {
  projectId: string;
};
// node:http is imported lazily and only under Node/Bun, so importing this
// module elsewhere never pulls in Node builtins (see file header note).
let _createServer: typeof import("node:http").createServer | null = null;
let _httpImportPromise: Promise<void> | null = null;
if (typeof process !== "undefined" && (process.versions?.node || process.versions?.bun)) {
  _httpImportPromise = import("node:http").then((m) => {
    _createServer = m.createServer;
  });
}
// OAuth client values; base64 is light obfuscation only, not secrecy.
const decode = (s: string) => atob(s);
const CLIENT_ID = decode(
  "NjgxMjU1ODA5Mzk1LW9vOGZ0Mm9wcmRybnA5ZTNhcWY2YXYzaG1kaWIxMzVqLmFwcHMuZ29vZ2xldXNlcmNvbnRlbnQuY29t",
);
const CLIENT_SECRET = decode("R09DU1BYLTR1SGdNUG0tMW83U2stZ2VWNkN1NWNsWEZzeGw=");
// Must match the port/path used by the loopback callback server below.
const REDIRECT_URI = "http://localhost:8085/oauth2callback";
const SCOPES = [
  "https://www.googleapis.com/auth/cloud-platform",
  "https://www.googleapis.com/auth/userinfo.email",
  "https://www.googleapis.com/auth/userinfo.profile",
];
const AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth";
const TOKEN_URL = "https://oauth2.googleapis.com/token";
// Private Cloud Code Assist API used for project discovery/onboarding.
const CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com";
// Handle to the local callback server plus helpers to await / cancel the
// delivery of the authorization code.
type CallbackServerInfo = {
  server: Server;
  cancelWait: () => void;
  waitForCode: () => Promise<{ code: string; state: string } | null>;
};
/**
 * Resolve node:http's createServer, awaiting the lazy import started at
 * module load. Throws outside Node/Bun, where that import never ran.
 */
async function getNodeCreateServer(): Promise<typeof import("node:http").createServer> {
  if (_createServer) {
    return _createServer;
  }
  await (_httpImportPromise ?? Promise.resolve());
  if (!_createServer) {
    throw new Error("Gemini CLI OAuth is only available in Node.js environments");
  }
  return _createServer;
}
/**
 * Start a loopback HTTP server on 127.0.0.1:8085 that captures the OAuth
 * redirect. Resolves once the server is listening; rejects on listen errors
 * (e.g. the port is already taken). The returned waitForCode() polls every
 * 100ms until a code arrives or cancelWait() is called, so callers can race
 * it against manual code entry.
 */
async function startCallbackServer(): Promise<CallbackServerInfo> {
  const createServer = await getNodeCreateServer();
  return new Promise((resolve, reject) => {
    let result: { code: string; state: string } | null = null;
    let cancelled = false;
    const server = createServer((req, res) => {
      const url = new URL(req.url || "", `http://localhost:8085`);
      if (url.pathname === "/oauth2callback") {
        const code = url.searchParams.get("code");
        const state = url.searchParams.get("state");
        const error = url.searchParams.get("error");
        // Provider signalled an error (e.g. the user denied consent).
        if (error) {
          res.writeHead(400, { "Content-Type": "text/html" });
          res.end(
            `<html><body><h1>Authentication Failed</h1><p>Error: ${error}</p><p>You can close this window.</p></body></html>`,
          );
          return;
        }
        if (code && state) {
          res.writeHead(200, { "Content-Type": "text/html" });
          res.end(
            `<html><body><h1>Authentication Successful</h1><p>You can close this window and return to the terminal.</p></body></html>`,
          );
          // Stash the result; waitForCode() picks it up on its next poll.
          result = { code, state };
        } else {
          res.writeHead(400, { "Content-Type": "text/html" });
          res.end(
            `<html><body><h1>Authentication Failed</h1><p>Missing code or state parameter.</p></body></html>`,
          );
        }
      } else {
        // Anything else (favicon requests etc.) gets a plain 404.
        res.writeHead(404);
        res.end();
      }
    });
    server.on("error", (err) => {
      reject(err);
    });
    server.listen(8085, "127.0.0.1", () => {
      resolve({
        server,
        cancelWait: () => {
          cancelled = true;
        },
        // Poll in 100ms steps; returns null when cancelled before a code arrived.
        waitForCode: async () => {
          const sleep = () => new Promise((r) => setTimeout(r, 100));
          while (!result && !cancelled) {
            await sleep();
          }
          return result;
        },
      });
    });
  });
}
/**
 * Extract the `code` and `state` query parameters from a pasted redirect URL;
 * blank or unparseable input yields an empty result.
 */
function parseRedirectUrl(input: string): { code?: string; state?: string } {
  const candidate = input.trim();
  if (candidate.length === 0) {
    return {};
  }
  try {
    const { searchParams } = new URL(candidate);
    return {
      code: searchParams.get("code") ?? undefined,
      state: searchParams.get("state") ?? undefined,
    };
  } catch {
    // Input was not a valid URL.
    return {};
  }
}
// Subset of the loadCodeAssist response this module inspects.
interface LoadCodeAssistPayload {
  cloudaicompanionProject?: string;
  currentTier?: { id?: string };
  allowedTiers?: Array<{ id?: string; isDefault?: boolean }>;
}
/**
 * Long-running operation response from onboardUser
 */
interface LongRunningOperationResponse {
  name?: string; // operation resource name, used for polling
  done?: boolean; // true once provisioning finished
  response?: {
    cloudaicompanionProject?: { id?: string };
  };
}
// Tier IDs as used by the Cloud Code API
const TIER_FREE = "free-tier";
const TIER_LEGACY = "legacy-tier";
const TIER_STANDARD = "standard-tier";
// Shape of a google.rpc error body (only the fields we inspect).
interface GoogleRpcErrorResponse {
  error?: {
    details?: Array<{ reason?: string }>;
  };
}
/** Promise-based delay used between onboarding polls. */
function wait(ms: number): Promise<void> {
  return new Promise<void>((resolve) => {
    setTimeout(resolve, ms);
  });
}
/**
 * Pick the tier flagged as default from the allowed list, falling back to the
 * legacy tier when the list is missing, empty, or has no default entry.
 */
function getDefaultTier(allowedTiers?: Array<{ id?: string; isDefault?: boolean }>): { id?: string } {
  const fallback = { id: TIER_LEGACY };
  if (!allowedTiers?.length) {
    return fallback;
  }
  return allowedTiers.find((tier) => tier.isDefault) ?? fallback;
}
// Detects Google's VPC Service Controls rejection: an RPC error whose details
// include reason SECURITY_POLICY_VIOLATED.
function isVpcScAffectedUser(payload: unknown): boolean {
  if (!payload || typeof payload !== "object" || !("error" in payload)) {
    return false;
  }
  const details = (payload as { error?: { details?: Array<{ reason?: string }> } }).error?.details;
  if (!Array.isArray(details)) {
    return false;
  }
  return details.some((detail) => detail.reason === "SECURITY_POLICY_VIOLATED");
}
/**
 * Poll a long-running onboardUser operation every 5 seconds until it reports
 * done. Throws on any non-2xx poll response; otherwise loops indefinitely
 * until completion.
 */
async function pollOperation(
  operationName: string,
  headers: Record<string, string>,
  onProgress?: (message: string) => void,
): Promise<LongRunningOperationResponse> {
  for (let attempt = 0; ; attempt++) {
    // First request goes out immediately; subsequent ones wait 5s.
    if (attempt > 0) {
      onProgress?.(`Waiting for project provisioning (attempt ${attempt + 1})...`);
      await wait(5000);
    }
    const response = await fetch(`${CODE_ASSIST_ENDPOINT}/v1internal/${operationName}`, {
      method: "GET",
      headers,
    });
    if (!response.ok) {
      throw new Error(`Failed to poll operation: ${response.status} ${response.statusText}`);
    }
    const operation = (await response.json()) as LongRunningOperationResponse;
    if (operation.done) {
      return operation;
    }
  }
}
/**
 * Discover or provision a Google Cloud project for the user.
 *
 * Resolution order:
 *  1. loadCodeAssist — an existing managed project (or the env-var project).
 *  2. onboardUser — provisions one via a long-running operation (free tier
 *     gets a Google-managed project; other tiers require the user's own).
 *  3. GOOGLE_CLOUD_PROJECT / GOOGLE_CLOUD_PROJECT_ID env vars as fallback.
 * Throws with a pointer to the Workspace docs when none of these apply.
 */
async function discoverProject(accessToken: string, onProgress?: (message: string) => void): Promise<string> {
  // Check for user-provided project ID via environment variable
  const envProjectId = process.env.GOOGLE_CLOUD_PROJECT || process.env.GOOGLE_CLOUD_PROJECT_ID;
  const headers = {
    Authorization: `Bearer ${accessToken}`,
    "Content-Type": "application/json",
    // Mimics an official client — NOTE(review): presumably required by the
    // private cloudcode API; confirm before changing.
    "User-Agent": "google-api-nodejs-client/9.15.1",
    "X-Goog-Api-Client": "gl-node/22.17.0",
  };
  // Try to load existing project via loadCodeAssist
  onProgress?.("Checking for existing Cloud Code Assist project...");
  const loadResponse = await fetch(`${CODE_ASSIST_ENDPOINT}/v1internal:loadCodeAssist`, {
    method: "POST",
    headers,
    body: JSON.stringify({
      cloudaicompanionProject: envProjectId,
      metadata: {
        ideType: "IDE_UNSPECIFIED",
        platform: "PLATFORM_UNSPECIFIED",
        pluginType: "GEMINI",
        duetProject: envProjectId,
      },
    }),
  });
  let data: LoadCodeAssistPayload;
  if (!loadResponse.ok) {
    let errorPayload: unknown;
    try {
      // clone() so the body can still be read as text in the error path below.
      errorPayload = await loadResponse.clone().json();
    } catch {
      errorPayload = undefined;
    }
    // VPC-SC-restricted accounts can't call loadCodeAssist: treat them as
    // standard tier and continue with the env-var project.
    if (isVpcScAffectedUser(errorPayload)) {
      data = { currentTier: { id: TIER_STANDARD } };
    } else {
      const errorText = await loadResponse.text();
      throw new Error(`loadCodeAssist failed: ${loadResponse.status} ${loadResponse.statusText}: ${errorText}`);
    }
  } else {
    data = (await loadResponse.json()) as LoadCodeAssistPayload;
  }
  // If user already has a current tier and project, use it
  if (data.currentTier) {
    if (data.cloudaicompanionProject) {
      return data.cloudaicompanionProject;
    }
    // User has a tier but no managed project - they need to provide one via env var
    if (envProjectId) {
      return envProjectId;
    }
    throw new Error(
      "This account requires setting the GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_PROJECT_ID environment variable. " +
        "See https://goo.gle/gemini-cli-auth-docs#workspace-gca",
    );
  }
  // User needs to be onboarded - get the default tier
  const tier = getDefaultTier(data.allowedTiers);
  const tierId = tier?.id ?? TIER_FREE;
  // Non-free tiers need the user's own project id up front.
  if (tierId !== TIER_FREE && !envProjectId) {
    throw new Error(
      "This account requires setting the GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_PROJECT_ID environment variable. " +
        "See https://goo.gle/gemini-cli-auth-docs#workspace-gca",
    );
  }
  onProgress?.("Provisioning Cloud Code Assist project (this may take a moment)...");
  // Build onboard request - for free tier, don't include project ID (Google provisions one)
  // For other tiers, include the user's project ID if available
  const onboardBody: Record<string, unknown> = {
    tierId,
    metadata: {
      ideType: "IDE_UNSPECIFIED",
      platform: "PLATFORM_UNSPECIFIED",
      pluginType: "GEMINI",
    },
  };
  if (tierId !== TIER_FREE && envProjectId) {
    onboardBody.cloudaicompanionProject = envProjectId;
    (onboardBody.metadata as Record<string, unknown>).duetProject = envProjectId;
  }
  // Start onboarding - this returns a long-running operation
  const onboardResponse = await fetch(`${CODE_ASSIST_ENDPOINT}/v1internal:onboardUser`, {
    method: "POST",
    headers,
    body: JSON.stringify(onboardBody),
  });
  if (!onboardResponse.ok) {
    const errorText = await onboardResponse.text();
    throw new Error(`onboardUser failed: ${onboardResponse.status} ${onboardResponse.statusText}: ${errorText}`);
  }
  let lroData = (await onboardResponse.json()) as LongRunningOperationResponse;
  // If the operation isn't done yet, poll until completion
  if (!lroData.done && lroData.name) {
    lroData = await pollOperation(lroData.name, headers, onProgress);
  }
  // Try to get project ID from the response
  const projectId = lroData.response?.cloudaicompanionProject?.id;
  if (projectId) {
    return projectId;
  }
  // If no project ID from onboarding, fall back to env var
  if (envProjectId) {
    return envProjectId;
  }
  throw new Error(
    "Could not discover or provision a Google Cloud project. " +
      "Try setting the GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_PROJECT_ID environment variable. " +
      "See https://goo.gle/gemini-cli-auth-docs#workspace-gca",
  );
}
/**
 * Fetch the account email from Google's userinfo endpoint.
 * Errors and non-OK responses are ignored; the email is optional metadata.
 */
async function getUserEmail(accessToken: string): Promise<string | undefined> {
  try {
    const response = await fetch("https://www.googleapis.com/oauth2/v1/userinfo?alt=json", {
      headers: { Authorization: `Bearer ${accessToken}` },
    });
    if (response.ok) {
      const { email } = (await response.json()) as { email?: string };
      return email;
    }
  } catch {
    // Swallow — callers treat a missing email as fine.
  }
  return undefined;
}
/**
 * Exchange a Google OAuth refresh token for a fresh access token.
 *
 * @param refreshToken - Long-lived refresh token from the initial login
 * @param projectId - Cloud project ID carried forward into the credentials
 * @returns Updated credentials; expiry is shortened by a 5-minute safety buffer
 * @throws Error when the token endpoint responds with a non-OK status
 */
export async function refreshGoogleCloudToken(refreshToken: string, projectId: string): Promise<OAuthCredentials> {
  const form = new URLSearchParams({
    client_id: CLIENT_ID,
    client_secret: CLIENT_SECRET,
    refresh_token: refreshToken,
    grant_type: "refresh_token",
  });
  const response = await fetch(TOKEN_URL, {
    method: "POST",
    headers: { "Content-Type": "application/x-www-form-urlencoded" },
    body: form,
  });
  if (!response.ok) {
    throw new Error(`Google Cloud token refresh failed: ${await response.text()}`);
  }
  const payload = (await response.json()) as {
    access_token: string;
    expires_in: number;
    refresh_token?: string;
  };
  // Google may rotate the refresh token; keep the old one when it doesn't.
  const FIVE_MINUTES_MS = 5 * 60 * 1000;
  return {
    refresh: payload.refresh_token || refreshToken,
    access: payload.access_token,
    expires: Date.now() + payload.expires_in * 1000 - FIVE_MINUTES_MS,
    projectId,
  };
}
/**
* Login with Gemini CLI (Google Cloud Code Assist) OAuth
*
* @param onAuth - Callback with URL and optional instructions
* @param onProgress - Optional progress callback
* @param onManualCodeInput - Optional promise that resolves with user-pasted redirect URL.
* Races with browser callback - whichever completes first wins.
*/
export async function loginGeminiCli(
  onAuth: (info: { url: string; instructions?: string }) => void,
  onProgress?: (message: string) => void,
  onManualCodeInput?: () => Promise<string>,
): Promise<OAuthCredentials> {
  // PKCE pair: the verifier stays local; the challenge goes into the auth URL.
  const { verifier, challenge } = await generatePKCE();
  // Start local server for callback
  onProgress?.("Starting local server for OAuth callback...");
  const server = await startCallbackServer();
  let code: string | undefined;
  try {
    // Build authorization URL
    const authParams = new URLSearchParams({
      client_id: CLIENT_ID,
      response_type: "code",
      redirect_uri: REDIRECT_URI,
      scope: SCOPES.join(" "),
      code_challenge: challenge,
      code_challenge_method: "S256",
      // NOTE(review): the PKCE verifier doubles as the OAuth `state` value;
      // the state checks below compare against it for CSRF protection.
      state: verifier,
      access_type: "offline",
      prompt: "consent",
    });
    const authUrl = `${AUTH_URL}?${authParams.toString()}`;
    // Notify caller with URL to open
    onAuth({
      url: authUrl,
      instructions: "Complete the sign-in in your browser.",
    });
    // Wait for the callback, racing with manual input if provided
    onProgress?.("Waiting for OAuth callback...");
    if (onManualCodeInput) {
      // Race between browser callback and manual input
      let manualInput: string | undefined;
      let manualError: Error | undefined;
      // Manual entry settles in the background and cancels the server's
      // polling wait so the race below resolves promptly either way.
      const manualPromise = onManualCodeInput()
        .then((input) => {
          manualInput = input;
          server.cancelWait();
        })
        .catch((err) => {
          manualError = err instanceof Error ? err : new Error(String(err));
          server.cancelWait();
        });
      const result = await server.waitForCode();
      // If manual input was cancelled, throw that error
      if (manualError) {
        throw manualError;
      }
      if (result?.code) {
        // Browser callback won - verify state
        if (result.state !== verifier) {
          throw new Error("OAuth state mismatch - possible CSRF attack");
        }
        code = result.code;
      } else if (manualInput) {
        // Manual input won
        const parsed = parseRedirectUrl(manualInput);
        if (parsed.state && parsed.state !== verifier) {
          throw new Error("OAuth state mismatch - possible CSRF attack");
        }
        code = parsed.code;
      }
      // If still no code, wait for manual promise and try that
      if (!code) {
        await manualPromise;
        if (manualError) {
          throw manualError;
        }
        if (manualInput) {
          const parsed = parseRedirectUrl(manualInput);
          if (parsed.state && parsed.state !== verifier) {
            throw new Error("OAuth state mismatch - possible CSRF attack");
          }
          code = parsed.code;
        }
      }
    } else {
      // Original flow: just wait for callback
      const result = await server.waitForCode();
      if (result?.code) {
        if (result.state !== verifier) {
          throw new Error("OAuth state mismatch - possible CSRF attack");
        }
        code = result.code;
      }
    }
    if (!code) {
      throw new Error("No authorization code received");
    }
    // Exchange code for tokens
    onProgress?.("Exchanging authorization code for tokens...");
    const tokenResponse = await fetch(TOKEN_URL, {
      method: "POST",
      headers: {
        "Content-Type": "application/x-www-form-urlencoded",
      },
      body: new URLSearchParams({
        client_id: CLIENT_ID,
        client_secret: CLIENT_SECRET,
        code,
        grant_type: "authorization_code",
        redirect_uri: REDIRECT_URI,
        code_verifier: verifier,
      }),
    });
    if (!tokenResponse.ok) {
      const error = await tokenResponse.text();
      throw new Error(`Token exchange failed: ${error}`);
    }
    const tokenData = (await tokenResponse.json()) as {
      access_token: string;
      refresh_token: string;
      expires_in: number;
    };
    // Without a refresh token the credentials cannot be renewed later.
    if (!tokenData.refresh_token) {
      throw new Error("No refresh token received. Please try again.");
    }
    // Get user email
    onProgress?.("Getting user info...");
    const email = await getUserEmail(tokenData.access_token);
    // Discover project
    const projectId = await discoverProject(tokenData.access_token, onProgress);
    // Calculate expiry time (current time + expires_in seconds - 5 min buffer)
    const expiresAt = Date.now() + tokenData.expires_in * 1000 - 5 * 60 * 1000;
    const credentials: OAuthCredentials = {
      refresh: tokenData.refresh_token,
      access: tokenData.access_token,
      expires: expiresAt,
      projectId,
      email,
    };
    return credentials;
  } finally {
    // Always shut down the local callback server, success or failure.
    server.server.close();
  }
}
// Provider adapter wiring the Gemini CLI login/refresh flows into the
// generic OAuth provider registry.
export const geminiCliOAuthProvider: OAuthProviderInterface = {
  id: "google-gemini-cli",
  name: "Google Cloud Code Assist (Gemini CLI)",
  // Login runs a local callback server, so manual code paste is supported.
  usesCallbackServer: true,
  async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
    return loginGeminiCli(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput);
  },
  async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
    const creds = credentials as GeminiCredentials;
    // projectId is required to rebuild credentials after a refresh.
    if (!creds.projectId) {
      throw new Error("Google Cloud credentials missing projectId");
    }
    return refreshGoogleCloudToken(creds.refresh, creds.projectId);
  },
  // Cloud Code Assist needs both the bearer token and the project, so the
  // "API key" is a JSON envelope carrying the pair.
  getApiKey(credentials: OAuthCredentials): string {
    const creds = credentials as GeminiCredentials;
    return JSON.stringify({ token: creds.access, projectId: creds.projectId });
  },
};

View file

@ -0,0 +1,162 @@
/**
* OAuth credential management for AI providers.
*
* This module handles login, token refresh, and credential storage
* for OAuth-based providers:
* - Anthropic (Claude Pro/Max)
* - GitHub Copilot
* - Google Cloud Code Assist (Gemini CLI)
* - Antigravity (Gemini 3, Claude, GPT-OSS via Google Cloud)
*/
// Anthropic
export { anthropicOAuthProvider, loginAnthropic, refreshAnthropicToken } from "./anthropic.js";
// GitHub Copilot
export {
getGitHubCopilotBaseUrl,
githubCopilotOAuthProvider,
loginGitHubCopilot,
normalizeDomain,
refreshGitHubCopilotToken,
} from "./github-copilot.js";
// Google Antigravity
export { antigravityOAuthProvider, loginAntigravity, refreshAntigravityToken } from "./google-antigravity.js";
// Google Gemini CLI
export { geminiCliOAuthProvider, loginGeminiCli, refreshGoogleCloudToken } from "./google-gemini-cli.js";
// OpenAI Codex (ChatGPT OAuth)
export { loginOpenAICodex, openaiCodexOAuthProvider, refreshOpenAICodexToken } from "./openai-codex.js";
export * from "./types.js";
// ============================================================================
// Provider Registry
// ============================================================================
import { anthropicOAuthProvider } from "./anthropic.js";
import { githubCopilotOAuthProvider } from "./github-copilot.js";
import { antigravityOAuthProvider } from "./google-antigravity.js";
import { geminiCliOAuthProvider } from "./google-gemini-cli.js";
import { openaiCodexOAuthProvider } from "./openai-codex.js";
import type { OAuthCredentials, OAuthProviderId, OAuthProviderInfo, OAuthProviderInterface } from "./types.js";
// Providers shipped with the SDK; listing order is preserved by the registry.
const BUILT_IN_OAUTH_PROVIDERS: OAuthProviderInterface[] = [
  anthropicOAuthProvider,
  githubCopilotOAuthProvider,
  geminiCliOAuthProvider,
  antigravityOAuthProvider,
  openaiCodexOAuthProvider,
];
// Mutable registry keyed by provider ID. Custom providers can be added via
// registerOAuthProvider() and removed via unregisterOAuthProvider().
const oauthProviderRegistry = new Map<string, OAuthProviderInterface>(
  BUILT_IN_OAUTH_PROVIDERS.map((provider) => [provider.id, provider]),
);
/**
 * Look up a registered OAuth provider by its ID.
 * @returns The registered implementation, or undefined when unknown.
 */
export function getOAuthProvider(id: OAuthProviderId): OAuthProviderInterface | undefined {
  const provider = oauthProviderRegistry.get(id);
  return provider;
}
/**
 * Register a custom OAuth provider, replacing any provider (built-in or
 * custom) previously registered under the same ID.
 */
export function registerOAuthProvider(provider: OAuthProviderInterface): void {
  const { id } = provider;
  oauthProviderRegistry.set(id, provider);
}
/**
 * Unregister an OAuth provider.
 *
 * Built-in IDs are restored to the stock implementation (so built-ins can
 * never be permanently removed); custom providers are deleted outright.
 */
export function unregisterOAuthProvider(id: string): void {
  for (const builtIn of BUILT_IN_OAUTH_PROVIDERS) {
    if (builtIn.id === id) {
      // Built-in: put the original implementation back instead of removing.
      oauthProviderRegistry.set(id, builtIn);
      return;
    }
  }
  oauthProviderRegistry.delete(id);
}
/**
 * Reset the registry to exactly the built-in providers, discarding any
 * custom registrations and overrides.
 */
export function resetOAuthProviders(): void {
  oauthProviderRegistry.clear();
  BUILT_IN_OAUTH_PROVIDERS.forEach((provider) => {
    oauthProviderRegistry.set(provider.id, provider);
  });
}
/**
 * Get all registered OAuth providers as a fresh array (mutating the result
 * does not affect the registry).
 */
export function getOAuthProviders(): OAuthProviderInterface[] {
  return [...oauthProviderRegistry.values()];
}
/**
 * Snapshot of registered providers as plain info records.
 * @deprecated Use getOAuthProviders() which returns OAuthProviderInterface[]
 */
export function getOAuthProviderInfoList(): OAuthProviderInfo[] {
  const providers = getOAuthProviders();
  return providers.map((provider) => ({
    id: provider.id,
    name: provider.name,
    available: true,
  }));
}
// ============================================================================
// High-level API (uses provider registry)
// ============================================================================
/**
 * Refresh token for any OAuth provider registered under the given ID.
 * @throws Error when no provider is registered for the ID.
 * @deprecated Use getOAuthProvider(id).refreshToken() instead
 */
export async function refreshOAuthToken(
  providerId: OAuthProviderId,
  credentials: OAuthCredentials,
): Promise<OAuthCredentials> {
  const provider = getOAuthProvider(providerId);
  if (provider === undefined) {
    throw new Error(`Unknown OAuth provider: ${providerId}`);
  }
  return provider.refreshToken(credentials);
}
/**
 * Get API key for a provider from OAuth credentials.
 * Automatically refreshes expired tokens.
 *
 * @param providerId - ID of the registered OAuth provider
 * @param credentials - Map of stored credentials keyed by provider ID
 * @returns API key string and updated credentials, or null if no credentials
 * @throws Error if the provider is unknown or the refresh fails
 */
export async function getOAuthApiKey(
  providerId: OAuthProviderId,
  credentials: Record<string, OAuthCredentials>,
): Promise<{ newCredentials: OAuthCredentials; apiKey: string } | null> {
  const provider = getOAuthProvider(providerId);
  if (!provider) {
    throw new Error(`Unknown OAuth provider: ${providerId}`);
  }
  let creds = credentials[providerId];
  if (!creds) {
    return null;
  }
  // Refresh if expired
  if (Date.now() >= creds.expires) {
    try {
      creds = await provider.refreshToken(creds);
    } catch (error) {
      // Preserve the underlying failure reason instead of swallowing it:
      // refresh failures (revoked token, network error, …) need diagnosis.
      const reason = error instanceof Error ? error.message : String(error);
      throw new Error(`Failed to refresh OAuth token for ${providerId}: ${reason}`);
    }
  }
  const apiKey = provider.getApiKey(creds);
  return { newCredentials: creds, apiKey };
}

View file

@ -0,0 +1,455 @@
/**
* OpenAI Codex (ChatGPT OAuth) flow
*
* NOTE: This module uses Node.js crypto and http for the OAuth callback.
* It is only intended for CLI use, not browser environments.
*/
// NEVER convert to top-level imports - breaks browser/Vite builds (web-ui)
let _randomBytes: typeof import("node:crypto").randomBytes | null = null;
let _http: typeof import("node:http") | null = null;
if (typeof process !== "undefined" && (process.versions?.node || process.versions?.bun)) {
import("node:crypto").then((m) => {
_randomBytes = m.randomBytes;
});
import("node:http").then((m) => {
_http = m;
});
}
import { generatePKCE } from "./pkce.js";
import type { OAuthCredentials, OAuthLoginCallbacks, OAuthPrompt, OAuthProviderInterface } from "./types.js";
const CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann";
const AUTHORIZE_URL = "https://auth.openai.com/oauth/authorize";
const TOKEN_URL = "https://auth.openai.com/oauth/token";
const REDIRECT_URI = "http://localhost:1455/auth/callback";
const SCOPE = "openid profile email offline_access";
const JWT_CLAIM_PATH = "https://api.openai.com/auth";
const SUCCESS_HTML = `<!doctype html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Authentication successful</title>
</head>
<body>
<p>Authentication successful. Return to your terminal to continue.</p>
</body>
</html>`;
type TokenSuccess = { type: "success"; access: string; refresh: string; expires: number };
type TokenFailure = { type: "failed" };
type TokenResult = TokenSuccess | TokenFailure;
type JwtPayload = {
[JWT_CLAIM_PATH]?: {
chatgpt_account_id?: string;
};
[key: string]: unknown;
};
function createState(): string {
if (!_randomBytes) {
throw new Error("OpenAI Codex OAuth is only available in Node.js environments");
}
return _randomBytes(16).toString("hex");
}
// Interpret user-pasted authorization input, which may be:
//   1. a full redirect URL (code/state read from the query string),
//   2. a "code#state" shorthand,
//   3. a raw query string containing "code=",
//   4. or a bare authorization code.
function parseAuthorizationInput(input: string): { code?: string; state?: string } {
  const value = input.trim();
  if (!value) return {};
  // Case 1: full redirect URL.
  try {
    const parsed = new URL(value);
    return {
      code: parsed.searchParams.get("code") ?? undefined,
      state: parsed.searchParams.get("state") ?? undefined,
    };
  } catch {
    // Not a URL; fall through to the other formats.
  }
  // Case 2: "code#state" shorthand.
  if (value.includes("#")) {
    const pieces = value.split("#", 2);
    return { code: pieces[0], state: pieces[1] };
  }
  // Case 3: raw query string.
  if (value.includes("code=")) {
    const params = new URLSearchParams(value);
    return {
      code: params.get("code") ?? undefined,
      state: params.get("state") ?? undefined,
    };
  }
  // Case 4: treat the whole input as the code.
  return { code: value };
}
/**
 * Decode the payload segment of a JWT without verifying its signature.
 *
 * JWT segments are base64url-encoded (RFC 7515) with padding stripped, but
 * `atob` only accepts the standard base64 alphabet — so the previous direct
 * `atob(payload)` call failed (and returned null) for any payload containing
 * '-' or '_'. Normalize to standard base64 with padding first, then decode
 * the bytes as UTF-8 so non-ASCII claims round-trip correctly.
 *
 * @param token - Raw JWT ("header.payload.signature")
 * @returns Parsed payload claims, or null when the token is malformed
 */
function decodeJwt(token: string): JwtPayload | null {
  try {
    const parts = token.split(".");
    if (parts.length !== 3) return null;
    const payload = parts[1] ?? "";
    // base64url -> base64: restore '+'/'/' and re-add stripped padding.
    const base64 = payload.replace(/-/g, "+").replace(/_/g, "/");
    const padded = base64 + "=".repeat((4 - (base64.length % 4)) % 4);
    const bytes = Uint8Array.from(atob(padded), (c) => c.charCodeAt(0));
    const json = new TextDecoder().decode(bytes);
    return JSON.parse(json) as JwtPayload;
  } catch {
    return null;
  }
}
/**
 * Exchange an authorization code (plus PKCE verifier) for tokens at the
 * OpenAI token endpoint. Failures are logged and reported as
 * { type: "failed" } rather than thrown.
 */
async function exchangeAuthorizationCode(
  code: string,
  verifier: string,
  redirectUri: string = REDIRECT_URI,
): Promise<TokenResult> {
  const form = new URLSearchParams({
    grant_type: "authorization_code",
    client_id: CLIENT_ID,
    code,
    code_verifier: verifier,
    redirect_uri: redirectUri,
  });
  const response = await fetch(TOKEN_URL, {
    method: "POST",
    headers: { "Content-Type": "application/x-www-form-urlencoded" },
    body: form,
  });
  if (!response.ok) {
    const text = await response.text().catch(() => "");
    console.error("[openai-codex] code->token failed:", response.status, text);
    return { type: "failed" };
  }
  const payload = (await response.json()) as {
    access_token?: string;
    refresh_token?: string;
    expires_in?: number;
  };
  const { access_token, refresh_token, expires_in } = payload;
  if (!access_token || !refresh_token || typeof expires_in !== "number") {
    console.error("[openai-codex] token response missing fields:", payload);
    return { type: "failed" };
  }
  return {
    type: "success",
    access: access_token,
    refresh: refresh_token,
    expires: Date.now() + expires_in * 1000,
  };
}
/**
 * Trade a refresh token for a new access/refresh token pair. Network and
 * protocol failures are logged and reported as { type: "failed" }.
 */
async function refreshAccessToken(refreshToken: string): Promise<TokenResult> {
  try {
    const form = new URLSearchParams({
      grant_type: "refresh_token",
      refresh_token: refreshToken,
      client_id: CLIENT_ID,
    });
    const response = await fetch(TOKEN_URL, {
      method: "POST",
      headers: { "Content-Type": "application/x-www-form-urlencoded" },
      body: form,
    });
    if (!response.ok) {
      const text = await response.text().catch(() => "");
      console.error("[openai-codex] Token refresh failed:", response.status, text);
      return { type: "failed" };
    }
    const payload = (await response.json()) as {
      access_token?: string;
      refresh_token?: string;
      expires_in?: number;
    };
    const { access_token, refresh_token, expires_in } = payload;
    if (!access_token || !refresh_token || typeof expires_in !== "number") {
      console.error("[openai-codex] Token refresh response missing fields:", payload);
      return { type: "failed" };
    }
    return {
      type: "success",
      access: access_token,
      refresh: refresh_token,
      expires: Date.now() + expires_in * 1000,
    };
  } catch (error) {
    console.error("[openai-codex] Token refresh error:", error);
    return { type: "failed" };
  }
}
/**
 * Build the OpenAI authorization URL together with the PKCE verifier and
 * random state value the rest of the flow validates against.
 */
async function createAuthorizationFlow(
  originator: string = "pi",
): Promise<{ verifier: string; state: string; url: string }> {
  const { verifier, challenge } = await generatePKCE();
  const state = createState();
  const query = new URLSearchParams({
    response_type: "code",
    client_id: CLIENT_ID,
    redirect_uri: REDIRECT_URI,
    scope: SCOPE,
    code_challenge: challenge,
    code_challenge_method: "S256",
    state,
    // Codex-specific flags mirroring the official CLI's simplified flow.
    id_token_add_organizations: "true",
    codex_cli_simplified_flow: "true",
    originator,
  });
  return { verifier, state, url: `${AUTHORIZE_URL}?${query.toString()}` };
}
type OAuthServerInfo = {
close: () => void;
cancelWait: () => void;
waitForCode: () => Promise<{ code: string } | null>;
};
// Start an HTTP server on 127.0.0.1:1455 that receives the OAuth redirect.
// Resolves with handles to close the server, cancel the wait, and poll for
// the received code. If the port cannot be bound, resolves with a stub whose
// waitForCode() returns null immediately so the caller falls back to manual
// code entry instead of failing the whole login.
function startLocalOAuthServer(state: string): Promise<OAuthServerInfo> {
  if (!_http) {
    throw new Error("OpenAI Codex OAuth is only available in Node.js environments");
  }
  // Shared between the request handler and waitForCode's polling loop.
  let lastCode: string | null = null;
  let cancelled = false;
  const server = _http.createServer((req, res) => {
    try {
      const url = new URL(req.url || "", "http://localhost");
      if (url.pathname !== "/auth/callback") {
        res.statusCode = 404;
        res.end("Not found");
        return;
      }
      // Reject callbacks whose state doesn't match ours (CSRF protection).
      if (url.searchParams.get("state") !== state) {
        res.statusCode = 400;
        res.end("State mismatch");
        return;
      }
      const code = url.searchParams.get("code");
      if (!code) {
        res.statusCode = 400;
        res.end("Missing authorization code");
        return;
      }
      res.statusCode = 200;
      res.setHeader("Content-Type", "text/html; charset=utf-8");
      res.end(SUCCESS_HTML);
      // Stash the code; waitForCode's poll loop picks it up.
      lastCode = code;
    } catch {
      res.statusCode = 500;
      res.end("Internal error");
    }
  });
  return new Promise((resolve) => {
    server
      .listen(1455, "127.0.0.1", () => {
        resolve({
          close: () => server.close(),
          cancelWait: () => {
            cancelled = true;
          },
          // Polls every 100ms for up to 60s (600 iterations). Returns null
          // on timeout or after cancelWait() is called.
          waitForCode: async () => {
            const sleep = () => new Promise((r) => setTimeout(r, 100));
            for (let i = 0; i < 600; i += 1) {
              if (lastCode) return { code: lastCode };
              if (cancelled) return null;
              await sleep();
            }
            return null;
          },
        });
      })
      .on("error", (err: NodeJS.ErrnoException) => {
        // Port 1455 unavailable (already in use, permissions, …): degrade to
        // a stub so the login flow can proceed via manual paste.
        console.error(
          "[openai-codex] Failed to bind http://127.0.0.1:1455 (",
          err.code,
          ") Falling back to manual paste.",
        );
        resolve({
          close: () => {
            try {
              server.close();
            } catch {
              // ignore
            }
          },
          cancelWait: () => {},
          waitForCode: async () => null,
        });
      });
  });
}
// Extract the ChatGPT account ID embedded in the access token's JWT claims
// (under the "https://api.openai.com/auth" claim). Returns null when the
// token cannot be decoded or the claim is absent/empty.
function getAccountId(accessToken: string): string | null {
  const claims = decodeJwt(accessToken)?.[JWT_CLAIM_PATH];
  const accountId = claims?.chatgpt_account_id;
  if (typeof accountId === "string" && accountId.length > 0) {
    return accountId;
  }
  return null;
}
/**
* Login with OpenAI Codex OAuth
*
* @param options.onAuth - Called with URL and instructions when auth starts
* @param options.onPrompt - Called to prompt user for manual code paste (fallback if no onManualCodeInput)
* @param options.onProgress - Optional progress messages
* @param options.onManualCodeInput - Optional promise that resolves with user-pasted code.
* Races with browser callback - whichever completes first wins.
* Useful for showing paste input immediately alongside browser flow.
* @param options.originator - OAuth originator parameter (defaults to "pi")
*/
export async function loginOpenAICodex(options: {
  onAuth: (info: { url: string; instructions?: string }) => void;
  onPrompt: (prompt: OAuthPrompt) => Promise<string>;
  onProgress?: (message: string) => void;
  onManualCodeInput?: () => Promise<string>;
  originator?: string;
}): Promise<OAuthCredentials> {
  const { verifier, state, url } = await createAuthorizationFlow(options.originator);
  const server = await startLocalOAuthServer(state);
  options.onAuth({ url, instructions: "A browser window should open. Complete login to finish." });
  let code: string | undefined;
  try {
    if (options.onManualCodeInput) {
      // Race between browser callback and manual input
      let manualCode: string | undefined;
      let manualError: Error | undefined;
      // Manual entry settles in the background and cancels the server's
      // polling wait so the race below resolves promptly either way.
      const manualPromise = options
        .onManualCodeInput()
        .then((input) => {
          manualCode = input;
          server.cancelWait();
        })
        .catch((err) => {
          manualError = err instanceof Error ? err : new Error(String(err));
          server.cancelWait();
        });
      const result = await server.waitForCode();
      // If manual input was cancelled, throw that error
      if (manualError) {
        throw manualError;
      }
      if (result?.code) {
        // Browser callback won
        code = result.code;
      } else if (manualCode) {
        // Manual input won (or callback timed out and user had entered code)
        const parsed = parseAuthorizationInput(manualCode);
        if (parsed.state && parsed.state !== state) {
          throw new Error("State mismatch");
        }
        code = parsed.code;
      }
      // If still no code, wait for manual promise to complete and try that
      if (!code) {
        await manualPromise;
        if (manualError) {
          throw manualError;
        }
        if (manualCode) {
          const parsed = parseAuthorizationInput(manualCode);
          if (parsed.state && parsed.state !== state) {
            throw new Error("State mismatch");
          }
          code = parsed.code;
        }
      }
    } else {
      // Original flow: wait for callback, then prompt if needed
      const result = await server.waitForCode();
      if (result?.code) {
        code = result.code;
      }
    }
    // Fallback to onPrompt if still no code
    if (!code) {
      const input = await options.onPrompt({
        message: "Paste the authorization code (or full redirect URL):",
      });
      const parsed = parseAuthorizationInput(input);
      if (parsed.state && parsed.state !== state) {
        throw new Error("State mismatch");
      }
      code = parsed.code;
    }
    if (!code) {
      throw new Error("Missing authorization code");
    }
    const tokenResult = await exchangeAuthorizationCode(code, verifier);
    if (tokenResult.type !== "success") {
      throw new Error("Token exchange failed");
    }
    // The ChatGPT account ID is embedded in the access token's JWT claims
    // and is required by the Codex API alongside the bearer token.
    const accountId = getAccountId(tokenResult.access);
    if (!accountId) {
      throw new Error("Failed to extract accountId from token");
    }
    return {
      access: tokenResult.access,
      refresh: tokenResult.refresh,
      expires: tokenResult.expires,
      accountId,
    };
  } finally {
    // Always release port 1455, success or failure.
    server.close();
  }
}
/**
 * Refresh OpenAI Codex OAuth credentials.
 *
 * @throws Error when the refresh fails or the new access token lacks a
 *         ChatGPT account ID claim.
 */
export async function refreshOpenAICodexToken(refreshToken: string): Promise<OAuthCredentials> {
  const refreshed = await refreshAccessToken(refreshToken);
  if (refreshed.type === "failed") {
    throw new Error("Failed to refresh OpenAI Codex token");
  }
  const accountId = getAccountId(refreshed.access);
  if (accountId === null) {
    throw new Error("Failed to extract accountId from token");
  }
  return {
    access: refreshed.access,
    refresh: refreshed.refresh,
    expires: refreshed.expires,
    accountId,
  };
}
// Provider adapter wiring the OpenAI Codex login/refresh flows into the
// generic OAuth provider registry.
export const openaiCodexOAuthProvider: OAuthProviderInterface = {
  id: "openai-codex",
  name: "ChatGPT Plus/Pro (Codex Subscription)",
  // Login binds a local callback server (port 1455); manual paste supported.
  usesCallbackServer: true,
  async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
    return loginOpenAICodex({
      onAuth: callbacks.onAuth,
      onPrompt: callbacks.onPrompt,
      onProgress: callbacks.onProgress,
      onManualCodeInput: callbacks.onManualCodeInput,
    });
  },
  async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
    return refreshOpenAICodexToken(credentials.refresh);
  },
  // The access token itself is the bearer credential for API calls.
  getApiKey(credentials: OAuthCredentials): string {
    return credentials.access;
  },
};

View file

@ -0,0 +1,34 @@
/**
* PKCE utilities using Web Crypto API.
* Works in both Node.js 20+ and browsers.
*/
/**
 * Encode bytes as a base64url string (RFC 4648 §5): standard base64 with
 * '+'/'/' swapped for '-'/'_' and padding stripped.
 */
function base64urlEncode(bytes: Uint8Array): string {
  const binary = Array.from(bytes, (byte) => String.fromCharCode(byte)).join("");
  return btoa(binary)
    .replace(/\+/g, "-")
    .replace(/\//g, "_")
    .replace(/=/g, "");
}
/**
 * Generate a PKCE code verifier and its S256 challenge.
 * Uses the Web Crypto API so it works in both Node.js 20+ and browsers.
 *
 * @returns verifier (43-char base64url of 32 random bytes) and challenge
 *          (base64url of the verifier's SHA-256 digest)
 */
export async function generatePKCE(): Promise<{ verifier: string; challenge: string }> {
  // Local base64url encoder: standard base64 with the URL-safe alphabet
  // and padding stripped.
  const toBase64Url = (bytes: Uint8Array): string => {
    let binary = "";
    for (const byte of bytes) {
      binary += String.fromCharCode(byte);
    }
    return btoa(binary).replace(/\+/g, "-").replace(/\//g, "_").replace(/=/g, "");
  };
  const randomness = new Uint8Array(32);
  crypto.getRandomValues(randomness);
  const verifier = toBase64Url(randomness);
  const digest = await crypto.subtle.digest("SHA-256", new TextEncoder().encode(verifier));
  const challenge = toBase64Url(new Uint8Array(digest));
  return { verifier, challenge };
}

View file

@ -0,0 +1,59 @@
import type { Api, Model } from "../../types.js";
/** Persisted OAuth token bundle for a provider. */
export type OAuthCredentials = {
  /** Long-lived refresh token used to obtain new access tokens. */
  refresh: string;
  /** Short-lived bearer token presented to the provider's API. */
  access: string;
  /** Access-token expiry as epoch milliseconds. */
  expires: number;
  // Providers may persist extra fields (e.g. projectId, accountId, email).
  [key: string]: unknown;
};
/** Identifier of an OAuth provider (e.g. "openai-codex"). */
export type OAuthProviderId = string;
/** @deprecated Use OAuthProviderId instead */
export type OAuthProvider = OAuthProviderId;
/** Prompt shown when asking the user to paste a code or redirect URL. */
export type OAuthPrompt = {
  message: string;
  placeholder?: string;
  allowEmpty?: boolean;
};
/** Information handed to the caller once the authorization URL is ready. */
export type OAuthAuthInfo = {
  url: string;
  instructions?: string;
};
/** Callbacks supplied by the host application to drive a login flow. */
export interface OAuthLoginCallbacks {
  /** Invoked with the authorization URL the user must open. */
  onAuth: (info: OAuthAuthInfo) => void;
  /** Prompts the user for input (e.g. pasting a code or redirect URL). */
  onPrompt: (prompt: OAuthPrompt) => Promise<string>;
  /** Optional human-readable progress updates. */
  onProgress?: (message: string) => void;
  /** Optional manual code entry racing the local callback server. */
  onManualCodeInput?: () => Promise<string>;
  /** Optional cancellation signal. */
  signal?: AbortSignal;
}
/** Contract every OAuth provider implementation must satisfy. */
export interface OAuthProviderInterface {
  readonly id: OAuthProviderId;
  readonly name: string;
  /** Run the login flow, return credentials to persist */
  login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials>;
  /** Whether login uses a local callback server and supports manual code input. */
  usesCallbackServer?: boolean;
  /** Refresh expired credentials, return updated credentials to persist */
  refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials>;
  /** Convert credentials to API key string for the provider */
  getApiKey(credentials: OAuthCredentials): string;
  /** Optional: modify models for this provider (e.g., update baseUrl) */
  modifyModels?(models: Model<Api>[], credentials: OAuthCredentials): Model<Api>[];
}
/** @deprecated Use OAuthProviderInterface instead */
export interface OAuthProviderInfo {
  id: OAuthProviderId;
  name: string;
  available: boolean;
}

View file

@ -0,0 +1,123 @@
import type { AssistantMessage } from "../types.js";
/**
* Regex patterns to detect context overflow errors from different providers.
*
* These patterns match error messages returned when the input exceeds
* the model's context window.
*
* Provider-specific patterns (with example error messages):
*
* - Anthropic: "prompt is too long: 213462 tokens > 200000 maximum"
* - OpenAI: "Your input exceeds the context window of this model"
* - Google: "The input token count (1196265) exceeds the maximum number of tokens allowed (1048575)"
* - xAI: "This model's maximum prompt length is 131072 but the request contains 537812 tokens"
* - Groq: "Please reduce the length of the messages or completion"
* - OpenRouter: "This endpoint's maximum context length is X tokens. However, you requested about Y tokens"
* - llama.cpp: "the request exceeds the available context size, try increasing it"
* - LM Studio: "tokens to keep from the initial prompt is greater than the context length"
* - GitHub Copilot: "prompt token count of X exceeds the limit of Y"
* - MiniMax: "invalid params, context window exceeds limit"
* - Kimi For Coding: "Your request exceeded model token limit: X (requested: Y)"
* - Cerebras: Returns "400/413 status code (no body)" - handled separately below
* - Mistral: "Prompt contains X tokens ... too large for model with Y maximum context length"
* - z.ai: Does NOT error, accepts overflow silently - handled via usage.input > contextWindow
* - Ollama: Silently truncates input - not detectable via error message
*/
const OVERFLOW_PATTERNS = [
  /prompt is too long/i, // Anthropic
  /input is too long for requested model/i, // Amazon Bedrock
  /exceeds the context window/i, // OpenAI (Completions & Responses API)
  /input token count.*exceeds the maximum/i, // Google (Gemini)
  /maximum prompt length is \d+/i, // xAI (Grok)
  /reduce the length of the messages/i, // Groq
  /maximum context length is \d+ tokens/i, // OpenRouter (all backends)
  /exceeds the limit of \d+/i, // GitHub Copilot
  /exceeds the available context size/i, // llama.cpp server
  /greater than the context length/i, // LM Studio
  /context window exceeds limit/i, // MiniMax
  /exceeded model token limit/i, // Kimi For Coding
  /too large for model with \d+ maximum context length/i, // Mistral
  /model_context_window_exceeded/i, // z.ai non-standard finish_reason surfaced as error text
  /context[_ ]length[_ ]exceeded/i, // Generic fallback
  /too many tokens/i, // Generic fallback
  /token limit exceeded/i, // Generic fallback
];
/**
 * Check whether an assistant message represents a context-window overflow.
 *
 * Two detection paths:
 * 1. Error-based: stopReason "error" with an errorMessage matching a known
 *    provider pattern above, or Cerebras' bare "400/413 (no body)" responses
 *    (429 is rate limiting, not overflow, and is deliberately excluded).
 * 2. Silent overflow (e.g. z.ai): the request succeeded (stopReason "stop")
 *    but usage.input + usage.cacheRead exceeds the supplied contextWindow.
 *    Ollama silently truncates instead and cannot be detected here.
 *
 * Custom providers added via settings.json may use error messages not covered
 * by OVERFLOW_PATTERNS; either extend the list or pre-check errorMessage
 * before calling this function.
 *
 * @param message - The assistant message to check
 * @param contextWindow - Optional context window size for detecting silent overflow
 * @returns true if the message indicates a context overflow
 */
export function isContextOverflow(message: AssistantMessage, contextWindow?: number): boolean {
  const { stopReason, errorMessage } = message;
  // Path 1: provider surfaced an explicit error we can pattern-match.
  if (stopReason === "error" && errorMessage) {
    for (const pattern of OVERFLOW_PATTERNS) {
      if (pattern.test(errorMessage)) {
        return true;
      }
    }
    // Cerebras: bare 400/413 with an empty body.
    if (/^4(00|13)\s*(status code)?\s*\(no body\)/i.test(errorMessage)) {
      return true;
    }
    return false;
  }
  // Path 2: silent overflow — success, but input tokens exceed the window.
  if (contextWindow && stopReason === "stop") {
    return message.usage.input + message.usage.cacheRead > contextWindow;
  }
  return false;
}
/**
 * Get a defensive copy of the overflow patterns (for testing purposes).
 */
export function getOverflowPatterns(): RegExp[] {
  return OVERFLOW_PATTERNS.slice();
}

View file

@ -0,0 +1,25 @@
/**
 * Remove unpaired Unicode surrogate code units from a string.
 *
 * Unpaired surrogates (a high surrogate 0xD800-0xDBFF not followed by a low
 * surrogate 0xDC00-0xDFFF, or a lone low surrogate) break JSON serialization
 * in many API providers. Properly paired surrogates — i.e. valid emoji and
 * other characters outside the Basic Multilingual Plane — are preserved.
 *
 * @param text - The text to sanitize
 * @returns The text with unpaired surrogate code units dropped
 *
 * @example
 * sanitizeSurrogates("Hello 🙈 World") // => "Hello 🙈 World" (pair kept)
 * sanitizeSurrogates("a" + String.fromCharCode(0xD83D) + "b") // => "ab"
 */
export function sanitizeSurrogates(text: string): string {
  let result = "";
  for (let i = 0; i < text.length; i++) {
    const unit = text.charCodeAt(i);
    const isHigh = unit >= 0xd800 && unit <= 0xdbff;
    const isLow = unit >= 0xdc00 && unit <= 0xdfff;
    if (isHigh) {
      const next = i + 1 < text.length ? text.charCodeAt(i + 1) : -1;
      if (next >= 0xdc00 && next <= 0xdfff) {
        // Valid surrogate pair: keep both units and skip the low half.
        result += text.charAt(i) + text.charAt(i + 1);
        i++;
      }
      // Unpaired high surrogate: drop it.
    } else if (!isLow) {
      result += text.charAt(i);
    }
    // Unpaired low surrogate (paired ones were consumed above): drop it.
  }
  return result;
}

View file

@ -0,0 +1,24 @@
import { type TUnsafe, Type } from "@sinclair/typebox";
/**
 * Creates a string enum schema compatible with Google's API and other providers
 * that don't support anyOf/const patterns.
 *
 * @param values - The allowed string literals
 * @param options - Optional description and default value for the schema
 * @returns A TypeBox schema whose static type is the union of `values`
 *
 * @example
 * const OperationSchema = StringEnum(["add", "subtract", "multiply", "divide"], {
 *   description: "The operation to perform"
 * });
 *
 * type Operation = Static<typeof OperationSchema>; // "add" | "subtract" | "multiply" | "divide"
 */
export function StringEnum<T extends readonly string[]>(
  values: T,
  options?: { description?: string; default?: T[number] },
): TUnsafe<T[number]> {
  return Type.Unsafe<T[number]>({
    type: "string",
    enum: values as any,
    // Use explicit undefined checks: a truthiness guard would silently drop a
    // valid empty-string description or default value.
    ...(options?.description !== undefined && { description: options.description }),
    ...(options?.default !== undefined && { default: options.default }),
  });
}

View file

@ -0,0 +1,84 @@
import AjvModule from "ajv";
import addFormatsModule from "ajv-formats";
// Handle both default and named exports (AJV ships CJS; ESM interop differs by loader)
const Ajv = (AjvModule as any).default || AjvModule;
const addFormats = (addFormatsModule as any).default || addFormatsModule;
import type { Tool, ToolCall } from "../types.js";
// Detect if we're in a browser extension environment with strict CSP.
// Chrome extensions with Manifest V3 don't allow eval/Function constructor,
// which AJV's schema compilation relies on.
const isBrowserExtension = typeof globalThis !== "undefined" && (globalThis as any).chrome?.runtime?.id !== undefined;
// Create a singleton AJV instance with formats (only if not in browser extension).
// AJV requires 'unsafe-eval' CSP which is not allowed in Manifest V3.
// `ajv` stays null when initialization is skipped or fails; callers must check.
let ajv: any = null;
if (!isBrowserExtension) {
  try {
    ajv = new Ajv({
      allErrors: true, // collect every violation, not just the first
      strict: false, // tolerate non-strict schemas coming from tool definitions
      coerceTypes: true, // e.g. coerce "42" -> 42 when the schema expects a number
    });
    addFormats(ajv);
  } catch (_e) {
    // AJV initialization failed (likely CSP restriction) — fall back to no validation
    console.warn("AJV validation disabled due to CSP restrictions");
  }
}
/**
 * Looks up a tool by name and validates the tool call's arguments against its
 * TypeBox schema.
 * @param tools Array of tool definitions
 * @param toolCall The tool call from the LLM
 * @returns The validated arguments
 * @throws Error if the tool is not found or validation fails
 */
export function validateToolCall(tools: Tool[], toolCall: ToolCall): any {
  const match = tools.find((candidate) => candidate.name === toolCall.name);
  if (match === undefined) {
    throw new Error(`Tool "${toolCall.name}" not found`);
  }
  return validateToolArguments(match, toolCall);
}
/**
 * Validates tool call arguments against the tool's TypeBox schema.
 * @param tool The tool definition with TypeBox schema
 * @param toolCall The tool call from the LLM
 * @returns The validated (and potentially type-coerced) arguments
 * @throws Error with a formatted per-field message if validation fails
 */
export function validateToolArguments(tool: Tool, toolCall: ToolCall): any {
  // Without a working AJV instance (browser-extension CSP forbids eval) we
  // cannot validate, so the LLM-provided arguments pass through untouched.
  if (!ajv || isBrowserExtension) {
    return toolCall.arguments;
  }
  const validator = ajv.compile(tool.parameters);
  // Validate a deep copy: AJV mutates the object in place while coercing types.
  const candidate = structuredClone(toolCall.arguments);
  if (validator(candidate)) {
    return candidate;
  }
  // Build a readable " - path: message" line per validation error.
  const details = (validator.errors ?? [])
    .map((err: any) => {
      const where = err.instancePath ? err.instancePath.substring(1) : err.params.missingProperty || "root";
      return ` - ${where}: ${err.message}`;
    })
    .join("\n");
  const formatted = details || "Unknown validation error";
  throw new Error(
    `Validation failed for tool "${toolCall.name}":\n${formatted}\n\nReceived arguments:\n${JSON.stringify(toolCall.arguments, null, 2)}`,
  );
}

View file

@ -0,0 +1,27 @@
{
"compilerOptions": {
"target": "ES2024",
"module": "Node16",
"lib": ["ES2024"],
"strict": true,
"esModuleInterop": true,
"skipLibCheck": true,
"forceConsistentCasingInFileNames": true,
"declaration": true,
"declarationMap": true,
"sourceMap": true,
"inlineSources": true,
"inlineSourceMap": false,
"moduleResolution": "Node16",
"resolveJsonModule": true,
"allowImportingTsExtensions": false,
"experimentalDecorators": true,
"emitDecoratorMetadata": true,
"useDefineForClassFields": false,
"types": ["node"],
"outDir": "./dist",
"rootDir": "./src"
},
"include": ["src/**/*.ts"],
"exclude": ["node_modules", "dist", "**/*.d.ts", "src/**/*.d.ts"]
}

View file

@ -0,0 +1,55 @@
{
"name": "@gsd/pi-coding-agent",
"version": "0.57.1",
"description": "Coding agent CLI (vendored from pi-mono)",
"type": "module",
"piConfig": {
"name": "pi",
"configDir": ".pi"
},
"main": "./dist/index.js",
"types": "./dist/index.d.ts",
"exports": {
".": {
"types": "./dist/index.d.ts",
"import": "./dist/index.js"
},
"./hooks": {
"types": "./dist/core/hooks/index.d.ts",
"import": "./dist/core/hooks/index.js"
}
},
"scripts": {
"build": "tsc -p tsconfig.json && npm run copy-assets",
"copy-assets": "node -e \"const{mkdirSync,cpSync}=require('fs');mkdirSync('dist/modes/interactive/theme',{recursive:true});cpSync('src/modes/interactive/theme','dist/modes/interactive/theme',{recursive:true,filter:(s)=>!s.endsWith('.ts')});mkdirSync('dist/core/export-html/vendor',{recursive:true});cpSync('src/core/export-html/template.html','dist/core/export-html/template.html');cpSync('src/core/export-html/template.css','dist/core/export-html/template.css');cpSync('src/core/export-html/template.js','dist/core/export-html/template.js');cpSync('src/core/export-html/vendor','dist/core/export-html/vendor',{recursive:true,filter:(s)=>!s.endsWith('.ts')})\""
},
"dependencies": {
"@gsd/pi-agent-core": "*",
"@gsd/pi-ai": "*",
"@gsd/pi-tui": "*",
"@mariozechner/jiti": "^2.6.2",
"@silvia-odwyer/photon-node": "^0.3.4",
"chalk": "^5.5.0",
"cli-highlight": "^2.1.11",
"diff": "^8.0.2",
"extract-zip": "^2.0.1",
"file-type": "^21.1.1",
"glob": "^13.0.1",
"hosted-git-info": "^9.0.2",
"ignore": "^7.0.5",
"marked": "^15.0.12",
"minimatch": "^10.2.3",
"proper-lockfile": "^4.1.2",
"strip-ansi": "^7.1.0",
"undici": "^7.19.1",
"yaml": "^2.8.2"
},
"optionalDependencies": {
"@mariozechner/clipboard": "^0.3.2"
},
"devDependencies": {
"@types/diff": "^7.0.2",
"@types/hosted-git-info": "^3.0.5",
"@types/proper-lockfile": "^4.1.4"
}
}

View file

@ -0,0 +1,18 @@
#!/usr/bin/env node
/**
 * CLI entry point for the refactored coding agent.
 * Uses main.ts with AgentSession and new mode modules.
 *
 * Test with: npx tsx src/cli-new.ts [args...]
 */
process.title = "pi";
import { setBedrockProviderModule } from "@gsd/pi-ai";
import { bedrockProviderModule } from "@gsd/pi-ai/bedrock-provider";
import { EnvHttpProxyAgent, setGlobalDispatcher } from "undici";
import { main } from "./main.js";
// Route all undici HTTP(S) traffic through proxy configuration taken from the
// environment (undici's EnvHttpProxyAgent reads the standard proxy env vars).
setGlobalDispatcher(new EnvHttpProxyAgent());
// Register the Bedrock provider implementation with @gsd/pi-ai before startup.
setBedrockProviderModule(bedrockProviderModule);
main(process.argv.slice(2));

View file

@ -0,0 +1,316 @@
/**
* CLI argument parsing and help display
*/
import type { ThinkingLevel } from "@gsd/pi-agent-core";
import chalk from "chalk";
import { APP_NAME, CONFIG_DIR_NAME, ENV_AGENT_DIR } from "../config.js";
import { allTools, type ToolName } from "../core/tools/index.js";
/** CLI output mode: interactive text (default), JSON, or RPC server. */
export type Mode = "text" | "json" | "rpc";
/** Parsed command-line arguments (produced by parseArgs). Absent flags stay undefined. */
export interface Args {
  provider?: string;
  model?: string;
  apiKey?: string;
  systemPrompt?: string;
  appendSystemPrompt?: string;
  thinking?: ThinkingLevel;
  continue?: boolean;
  resume?: boolean;
  help?: boolean;
  version?: boolean;
  mode?: Mode;
  noSession?: boolean;
  session?: string;
  sessionDir?: string;
  models?: string[];
  tools?: ToolName[];
  noTools?: boolean;
  extensions?: string[];
  noExtensions?: boolean;
  print?: boolean;
  export?: string;
  noSkills?: boolean;
  skills?: string[];
  promptTemplates?: string[];
  noPromptTemplates?: boolean;
  themes?: string[];
  noThemes?: boolean;
  /** true = list all models; a string = fuzzy search pattern */
  listModels?: string | true;
  offline?: boolean;
  verbose?: boolean;
  /** Bare (non-flag) argv tokens, treated as user messages */
  messages: string[];
  /** Paths taken from `@file` arguments (leading @ stripped) */
  fileArgs: string[];
  /** Unknown flags (potentially extension flags) - map of flag name to value */
  unknownFlags: Map<string, boolean | string>;
}
/** Thinking levels accepted by --thinking, from lowest to highest effort. */
const VALID_THINKING_LEVELS = ["off", "minimal", "low", "medium", "high", "xhigh"] as const;
/** Type guard: true when `level` is one of the supported thinking levels. */
export function isValidThinkingLevel(level: string): level is ThinkingLevel {
  const levels: readonly string[] = VALID_THINKING_LEVELS;
  return levels.includes(level);
}
/**
 * Parse raw argv tokens into a structured Args object.
 *
 * Flags are matched in a single left-to-right pass. Value-taking flags consume
 * the following token; a value flag that appears last (with no value left) is
 * silently ignored. `@path` tokens become fileArgs, bare words become
 * messages, and unrecognized `--flags` are only captured when they appear in
 * `extensionFlags` (the first parsing pass runs without it — see the comment
 * near the bottom of the loop).
 *
 * @param args - argv tokens (without the node/script prefix)
 * @param extensionFlags - flags registered by extensions (name -> value type)
 * @returns the populated Args structure
 */
export function parseArgs(args: string[], extensionFlags?: Map<string, { type: "boolean" | "string" }>): Args {
  const result: Args = {
    messages: [],
    fileArgs: [],
    unknownFlags: new Map(),
  };
  for (let i = 0; i < args.length; i++) {
    const arg = args[i];
    if (arg === "--help" || arg === "-h") {
      result.help = true;
    } else if (arg === "--version" || arg === "-v") {
      result.version = true;
    } else if (arg === "--mode" && i + 1 < args.length) {
      // Invalid mode values are dropped silently (result.mode stays undefined).
      const mode = args[++i];
      if (mode === "text" || mode === "json" || mode === "rpc") {
        result.mode = mode;
      }
    } else if (arg === "--continue" || arg === "-c") {
      result.continue = true;
    } else if (arg === "--resume" || arg === "-r") {
      result.resume = true;
    } else if (arg === "--provider" && i + 1 < args.length) {
      result.provider = args[++i];
    } else if (arg === "--model" && i + 1 < args.length) {
      result.model = args[++i];
    } else if (arg === "--api-key" && i + 1 < args.length) {
      result.apiKey = args[++i];
    } else if (arg === "--system-prompt" && i + 1 < args.length) {
      result.systemPrompt = args[++i];
    } else if (arg === "--append-system-prompt" && i + 1 < args.length) {
      result.appendSystemPrompt = args[++i];
    } else if (arg === "--no-session") {
      result.noSession = true;
    } else if (arg === "--session" && i + 1 < args.length) {
      result.session = args[++i];
    } else if (arg === "--session-dir" && i + 1 < args.length) {
      result.sessionDir = args[++i];
    } else if (arg === "--models" && i + 1 < args.length) {
      result.models = args[++i].split(",").map((s) => s.trim());
    } else if (arg === "--no-tools") {
      result.noTools = true;
    } else if (arg === "--tools" && i + 1 < args.length) {
      // Keep only recognized tool names; warn (but do not fail) on unknown ones.
      const toolNames = args[++i].split(",").map((s) => s.trim());
      const validTools: ToolName[] = [];
      for (const name of toolNames) {
        if (name in allTools) {
          validTools.push(name as ToolName);
        } else {
          console.error(
            chalk.yellow(`Warning: Unknown tool "${name}". Valid tools: ${Object.keys(allTools).join(", ")}`),
          );
        }
      }
      result.tools = validTools;
    } else if (arg === "--thinking" && i + 1 < args.length) {
      // Invalid levels produce a warning and leave result.thinking unset.
      const level = args[++i];
      if (isValidThinkingLevel(level)) {
        result.thinking = level;
      } else {
        console.error(
          chalk.yellow(
            `Warning: Invalid thinking level "${level}". Valid values: ${VALID_THINKING_LEVELS.join(", ")}`,
          ),
        );
      }
    } else if (arg === "--print" || arg === "-p") {
      result.print = true;
    } else if (arg === "--export" && i + 1 < args.length) {
      result.export = args[++i];
    } else if ((arg === "--extension" || arg === "-e") && i + 1 < args.length) {
      // Repeatable flags accumulate into arrays (extensions/skills/templates/themes).
      result.extensions = result.extensions ?? [];
      result.extensions.push(args[++i]);
    } else if (arg === "--no-extensions" || arg === "-ne") {
      result.noExtensions = true;
    } else if (arg === "--skill" && i + 1 < args.length) {
      result.skills = result.skills ?? [];
      result.skills.push(args[++i]);
    } else if (arg === "--prompt-template" && i + 1 < args.length) {
      result.promptTemplates = result.promptTemplates ?? [];
      result.promptTemplates.push(args[++i]);
    } else if (arg === "--theme" && i + 1 < args.length) {
      result.themes = result.themes ?? [];
      result.themes.push(args[++i]);
    } else if (arg === "--no-skills" || arg === "-ns") {
      result.noSkills = true;
    } else if (arg === "--no-prompt-templates" || arg === "-np") {
      result.noPromptTemplates = true;
    } else if (arg === "--no-themes") {
      result.noThemes = true;
    } else if (arg === "--list-models") {
      // Check if next arg is a search pattern (not a flag or file arg)
      if (i + 1 < args.length && !args[i + 1].startsWith("-") && !args[i + 1].startsWith("@")) {
        result.listModels = args[++i];
      } else {
        result.listModels = true;
      }
    } else if (arg === "--verbose") {
      result.verbose = true;
    } else if (arg === "--offline") {
      result.offline = true;
    } else if (arg.startsWith("@")) {
      result.fileArgs.push(arg.slice(1)); // Remove @ prefix
    } else if (arg.startsWith("--") && extensionFlags) {
      // Check if it's an extension-registered flag
      const flagName = arg.slice(2);
      const extFlag = extensionFlags.get(flagName);
      if (extFlag) {
        if (extFlag.type === "boolean") {
          result.unknownFlags.set(flagName, true);
        } else if (extFlag.type === "string" && i + 1 < args.length) {
          result.unknownFlags.set(flagName, args[++i]);
        }
      }
      // Unknown flags without extensionFlags are silently ignored (first pass)
    } else if (!arg.startsWith("-")) {
      result.messages.push(arg);
    }
  }
  return result;
}
/**
 * Print the full CLI help text (usage, commands, options, examples, env vars).
 * The body is a single template literal; the flag list must be kept in sync
 * with parseArgs by hand.
 */
export function printHelp(): void {
  console.log(`${chalk.bold(APP_NAME)} - AI coding assistant with read, bash, edit, write tools
${chalk.bold("Usage:")}
  ${APP_NAME} [options] [@files...] [messages...]
${chalk.bold("Commands:")}
  ${APP_NAME} install <source> [-l]    Install extension source and add to settings
  ${APP_NAME} remove <source> [-l]     Remove extension source from settings
  ${APP_NAME} update [source]          Update installed extensions (skips pinned sources)
  ${APP_NAME} list                     List installed extensions from settings
  ${APP_NAME} config                   Open TUI to enable/disable package resources
  ${APP_NAME} <command> --help         Show help for install/remove/update/list
${chalk.bold("Options:")}
  --provider <name>        Provider name (default: google)
  --model <pattern>        Model pattern or ID (supports "provider/id" and optional ":<thinking>")
  --api-key <key>          API key (defaults to env vars)
  --system-prompt <text>   System prompt (default: coding assistant prompt)
  --append-system-prompt <text>  Append text or file contents to the system prompt
  --mode <mode>            Output mode: text (default), json, or rpc
  --print, -p              Non-interactive mode: process prompt and exit
  --continue, -c           Continue previous session
  --resume, -r             Select a session to resume
  --session <path>         Use specific session file
  --session-dir <dir>      Directory for session storage and lookup
  --no-session             Don't save session (ephemeral)
  --models <patterns>      Comma-separated model patterns for Ctrl+P cycling
                           Supports globs (anthropic/*, *sonnet*) and fuzzy matching
  --no-tools               Disable all built-in tools
  --tools <tools>          Comma-separated list of tools to enable (default: read,bash,edit,write)
                           Available: read, bash, edit, write, grep, find, ls
  --thinking <level>       Set thinking level: off, minimal, low, medium, high, xhigh
  --extension, -e <path>   Load an extension file (can be used multiple times)
  --no-extensions, -ne     Disable extension discovery (explicit -e paths still work)
  --skill <path>           Load a skill file or directory (can be used multiple times)
  --no-skills, -ns         Disable skills discovery and loading
  --prompt-template <path> Load a prompt template file or directory (can be used multiple times)
  --no-prompt-templates, -np  Disable prompt template discovery and loading
  --theme <path>           Load a theme file or directory (can be used multiple times)
  --no-themes              Disable theme discovery and loading
  --export <file>          Export session file to HTML and exit
  --list-models [search]   List available models (with optional fuzzy search)
  --verbose                Force verbose startup (overrides quietStartup setting)
  --offline                Disable startup network operations (same as PI_OFFLINE=1)
  --help, -h               Show this help
  --version, -v            Show version number
Extensions can register additional flags (e.g., --plan from plan-mode extension).
${chalk.bold("Examples:")}
  # Interactive mode
  ${APP_NAME}
  # Interactive mode with initial prompt
  ${APP_NAME} "List all .ts files in src/"
  # Include files in initial message
  ${APP_NAME} @prompt.md @image.png "What color is the sky?"
  # Non-interactive mode (process and exit)
  ${APP_NAME} -p "List all .ts files in src/"
  # Multiple messages (interactive)
  ${APP_NAME} "Read package.json" "What dependencies do we have?"
  # Continue previous session
  ${APP_NAME} --continue "What did we discuss?"
  # Use different model
  ${APP_NAME} --provider openai --model gpt-4o-mini "Help me refactor this code"
  # Use model with provider prefix (no --provider needed)
  ${APP_NAME} --model openai/gpt-4o "Help me refactor this code"
  # Use model with thinking level shorthand
  ${APP_NAME} --model sonnet:high "Solve this complex problem"
  # Limit model cycling to specific models
  ${APP_NAME} --models claude-sonnet,claude-haiku,gpt-4o
  # Limit to a specific provider with glob pattern
  ${APP_NAME} --models "github-copilot/*"
  # Cycle models with fixed thinking levels
  ${APP_NAME} --models sonnet:high,haiku:low
  # Start with a specific thinking level
  ${APP_NAME} --thinking high "Solve this complex problem"
  # Read-only mode (no file modifications possible)
  ${APP_NAME} --tools read,grep,find,ls -p "Review the code in src/"
  # Export a session file to HTML
  ${APP_NAME} --export ~/${CONFIG_DIR_NAME}/agent/sessions/--path--/session.jsonl
  ${APP_NAME} --export session.jsonl output.html
${chalk.bold("Environment Variables:")}
  ANTHROPIC_API_KEY        - Anthropic Claude API key
  ANTHROPIC_OAUTH_TOKEN    - Anthropic OAuth token (alternative to API key)
  OPENAI_API_KEY           - OpenAI GPT API key
  AZURE_OPENAI_API_KEY     - Azure OpenAI API key
  AZURE_OPENAI_BASE_URL    - Azure OpenAI base URL (https://{resource}.openai.azure.com/openai/v1)
  AZURE_OPENAI_RESOURCE_NAME - Azure OpenAI resource name (alternative to base URL)
  AZURE_OPENAI_API_VERSION - Azure OpenAI API version (default: v1)
  AZURE_OPENAI_DEPLOYMENT_NAME_MAP - Azure OpenAI model=deployment map (comma-separated)
  GEMINI_API_KEY           - Google Gemini API key
  GROQ_API_KEY             - Groq API key
  CEREBRAS_API_KEY         - Cerebras API key
  XAI_API_KEY              - xAI Grok API key
  OPENROUTER_API_KEY       - OpenRouter API key
  AI_GATEWAY_API_KEY       - Vercel AI Gateway API key
  ZAI_API_KEY              - ZAI API key
  MISTRAL_API_KEY          - Mistral API key
  MINIMAX_API_KEY          - MiniMax API key
  OPENCODE_API_KEY         - OpenCode Zen/OpenCode Go API key
  KIMI_API_KEY             - Kimi For Coding API key
  AWS_PROFILE              - AWS profile for Amazon Bedrock
  AWS_ACCESS_KEY_ID        - AWS access key for Amazon Bedrock
  AWS_SECRET_ACCESS_KEY    - AWS secret key for Amazon Bedrock
  AWS_BEARER_TOKEN_BEDROCK - Bedrock API key (bearer token)
  AWS_REGION               - AWS region for Amazon Bedrock (e.g., us-east-1)
  ${ENV_AGENT_DIR.padEnd(32)} - Session storage directory (default: ~/${CONFIG_DIR_NAME}/agent)
  PI_PACKAGE_DIR           - Override package directory (for Nix/Guix store paths)
  PI_OFFLINE               - Disable startup network operations when set to 1/true/yes
  PI_SHARE_VIEWER_URL      - Base URL for /share command (default: https://pi.dev/session/)
  PI_AI_ANTIGRAVITY_VERSION - Override Antigravity User-Agent version (e.g., 1.23.0)
${chalk.bold("Available Tools (default: read, bash, edit, write):")}
  read   - Read file contents
  bash   - Execute bash commands
  edit   - Edit files with find/replace
  write  - Write files (creates/overwrites)
  grep   - Search file contents (read-only, off by default)
  find   - Find files by glob pattern (read-only, off by default)
  ls     - List directory contents (read-only, off by default)
`);
}

View file

@ -0,0 +1,52 @@
/**
* TUI config selector for `pi config` command
*/
import { ProcessTerminal, TUI } from "@gsd/pi-tui";
import type { ResolvedPaths } from "../core/package-manager.js";
import type { SettingsManager } from "../core/settings-manager.js";
import { ConfigSelectorComponent } from "../modes/interactive/components/config-selector.js";
import { initTheme, stopThemeWatcher } from "../modes/interactive/theme/theme.js";
/** Dependencies required to run the `pi config` TUI selector. */
export interface ConfigSelectorOptions {
  /** Resolved package/resource paths to present in the selector. */
  resolvedPaths: ResolvedPaths;
  /** Settings store; also supplies the theme used to render the TUI. */
  settingsManager: SettingsManager;
  /** Working directory, passed through to the selector component. */
  cwd: string;
  /** Agent config directory, passed through to the selector component. */
  agentDir: string;
}
/** Show TUI config selector; resolves when the selector is closed. */
export async function selectConfig(options: ConfigSelectorOptions): Promise<void> {
  // Theme must be initialized before the first TUI render.
  initTheme(options.settingsManager.getTheme(), true);
  return new Promise((resolvePromise) => {
    const ui = new TUI(new ProcessTerminal());
    let done = false;
    // Tear down the TUI and resolve exactly once, even if called repeatedly.
    const finish = () => {
      if (done) return;
      done = true;
      ui.stop();
      stopThemeWatcher();
      resolvePromise();
    };
    // Hard-exit path: tear down the terminal and terminate the process.
    const abort = () => {
      ui.stop();
      stopThemeWatcher();
      process.exit(0);
    };
    const selector = new ConfigSelectorComponent(
      options.resolvedPaths,
      options.settingsManager,
      options.cwd,
      options.agentDir,
      finish,
      abort,
      () => ui.requestRender(),
    );
    ui.addChild(selector);
    ui.setFocus(selector.getResourceList());
    ui.start();
  });
}

View file

@ -0,0 +1,96 @@
/**
* Process @file CLI arguments into text content and image attachments
*/
import { access, readFile, stat } from "node:fs/promises";
import type { ImageContent } from "@gsd/pi-ai";
import chalk from "chalk";
import { resolve } from "path";
import { resolveReadPath } from "../core/tools/path-utils.js";
import { formatDimensionNote, resizeImage } from "../utils/image-resize.js";
import { detectSupportedImageMimeTypeFromFile } from "../utils/mime.js";
/** Result of processing @file arguments: concatenated <file> text plus image attachments. */
export interface ProcessedFiles {
  /** Concatenated `<file name="...">...</file>` entries, one per processed file. */
  text: string;
  /** Image attachments for files detected as supported image types. */
  images: ImageContent[];
}
export interface ProcessFileOptions {
  /** Whether to auto-resize images to 2000x2000 max. Default: true */
  autoResizeImages?: boolean;
}
/**
 * Process @file CLI arguments into text content and image attachments.
 * Missing or unreadable files terminate the process with an error message;
 * empty files are skipped.
 */
export async function processFileArguments(fileArgs: string[], options?: ProcessFileOptions): Promise<ProcessedFiles> {
  const shouldResize = options?.autoResizeImages ?? true;
  const images: ImageContent[] = [];
  const textParts: string[] = [];
  for (const rawPath of fileArgs) {
    // Expand and resolve path (handles ~ expansion and macOS screenshot Unicode spaces)
    const absolutePath = resolve(resolveReadPath(rawPath, process.cwd()));
    // Bail out hard if the file doesn't exist.
    try {
      await access(absolutePath);
    } catch {
      console.error(chalk.red(`Error: File not found: ${absolutePath}`));
      process.exit(1);
    }
    // Skip empty files entirely.
    const stats = await stat(absolutePath);
    if (stats.size === 0) {
      continue;
    }
    const mimeType = await detectSupportedImageMimeTypeFromFile(absolutePath);
    if (mimeType) {
      // Image file: attach as base64, optionally resized.
      const base64Content = (await readFile(absolutePath)).toString("base64");
      let attachment: ImageContent;
      let dimensionNote: string | undefined;
      if (shouldResize) {
        const resized = await resizeImage({ type: "image", data: base64Content, mimeType });
        dimensionNote = formatDimensionNote(resized);
        attachment = {
          type: "image",
          mimeType: resized.mimeType,
          data: resized.data,
        };
      } else {
        attachment = {
          type: "image",
          mimeType,
          data: base64Content,
        };
      }
      images.push(attachment);
      // Text reference to the image, with the dimension note when available.
      textParts.push(`<file name="${absolutePath}">${dimensionNote ? dimensionNote : ""}</file>\n`);
    } else {
      // Text file: inline its contents.
      try {
        const content = await readFile(absolutePath, "utf-8");
        textParts.push(`<file name="${absolutePath}">\n${content}\n</file>\n`);
      } catch (error: unknown) {
        const message = error instanceof Error ? error.message : String(error);
        console.error(chalk.red(`Error: Could not read file ${absolutePath}: ${message}`));
        process.exit(1);
      }
    }
  }
  return { text: textParts.join(""), images };
}

View file

@ -0,0 +1,104 @@
/**
* List available models with optional fuzzy search
*/
import type { Api, Model } from "@gsd/pi-ai";
import { fuzzyFilter } from "@gsd/pi-tui";
import type { ModelRegistry } from "../core/model-registry.js";
/**
 * Format a token count as a human-readable string
 * (e.g. 200000 -> "200K", 1000000 -> "1M", 1500 -> "1.5K").
 */
function formatTokenCount(count: number): string {
  // Whole multiples print without a decimal; otherwise one decimal place.
  const scaled = (value: number, suffix: string): string =>
    Number.isInteger(value) ? `${value}${suffix}` : `${value.toFixed(1)}${suffix}`;
  if (count >= 1_000_000) {
    return scaled(count / 1_000_000, "M");
  }
  if (count >= 1_000) {
    return scaled(count / 1_000, "K");
  }
  return count.toString();
}
/**
 * Print available models as an aligned table, optionally filtered by a fuzzy
 * search pattern. Prints a short notice instead when no models are available
 * or nothing matches.
 */
export async function listModels(modelRegistry: ModelRegistry, searchPattern?: string): Promise<void> {
  const available = modelRegistry.getAvailable();
  if (available.length === 0) {
    console.log("No models available. Set API keys in environment variables.");
    return;
  }
  // Narrow via fuzzy matching on "provider id" when a pattern was given.
  const matched: Model<Api>[] = searchPattern
    ? fuzzyFilter(available, searchPattern, (m) => `${m.provider} ${m.id}`)
    : available;
  if (matched.length === 0) {
    console.log(`No models matching "${searchPattern}"`);
    return;
  }
  // Stable display order: provider first, then model id.
  matched.sort((a, b) => a.provider.localeCompare(b.provider) || a.id.localeCompare(b.id));
  // Column definitions: header text plus how to derive each cell from a model.
  const columns: Array<{ header: string; cell: (m: Model<Api>) => string }> = [
    { header: "provider", cell: (m) => m.provider },
    { header: "model", cell: (m) => m.id },
    { header: "context", cell: (m) => formatTokenCount(m.contextWindow) },
    { header: "max-out", cell: (m) => formatTokenCount(m.maxTokens) },
    { header: "thinking", cell: (m) => (m.reasoning ? "yes" : "no") },
    { header: "images", cell: (m) => (m.input.includes("image") ? "yes" : "no") },
  ];
  const body = matched.map((m) => columns.map((c) => c.cell(m)));
  // Pad each column to its widest cell (or its header, whichever is wider).
  const widths = columns.map((c, idx) => Math.max(c.header.length, ...body.map((row) => row[idx].length)));
  const renderRow = (cells: string[]) => cells.map((cell, idx) => cell.padEnd(widths[idx])).join("  ");
  console.log(renderRow(columns.map((c) => c.header)));
  for (const row of body) {
    console.log(renderRow(row));
  }
}

View file

@ -0,0 +1,51 @@
/**
* TUI session selector for --resume flag
*/
import { ProcessTerminal, TUI } from "@gsd/pi-tui";
import { KeybindingsManager } from "../core/keybindings.js";
import type { SessionInfo, SessionListProgress } from "../core/session-manager.js";
import { SessionSelectorComponent } from "../modes/interactive/components/session-selector.js";
type SessionsLoader = (onProgress?: SessionListProgress) => Promise<SessionInfo[]>;
/** Show TUI session selector; resolves with the selected session path, or null if cancelled. */
export async function selectSession(
  currentSessionsLoader: SessionsLoader,
  allSessionsLoader: SessionsLoader,
): Promise<string | null> {
  return new Promise((resolvePromise) => {
    const ui = new TUI(new ProcessTerminal());
    const keybindings = KeybindingsManager.create();
    let settled = false;
    // Stop the TUI and resolve exactly once, whichever callback fires first.
    const settle = (value: string | null) => {
      if (settled) return;
      settled = true;
      ui.stop();
      resolvePromise(value);
    };
    const selector = new SessionSelectorComponent(
      currentSessionsLoader,
      allSessionsLoader,
      (path: string) => settle(path),
      () => settle(null),
      () => {
        // Hard-exit path: tear down the terminal and terminate the process.
        ui.stop();
        process.exit(0);
      },
      () => ui.requestRender(),
      { showRenameHint: false, keybindings },
    );
    ui.addChild(selector);
    ui.setFocus(selector.getSessionList());
    ui.start();
  });
}

View file

@ -0,0 +1,241 @@
import { existsSync, readFileSync } from "fs";
import { homedir } from "os";
import { dirname, join, resolve } from "path";
import { fileURLToPath } from "url";
// =============================================================================
// Package Detection
// =============================================================================
// ESM has no __filename/__dirname; reconstruct them from import.meta.url.
const __filename = fileURLToPath(import.meta.url);
const __dirname = dirname(__filename);
/**
 * Detect if we're running as a Bun compiled binary.
 * Bun binaries have import.meta.url containing "$bunfs", "~BUN", or "%7EBUN" (Bun's virtual filesystem path)
 */
export const isBunBinary =
  import.meta.url.includes("$bunfs") || import.meta.url.includes("~BUN") || import.meta.url.includes("%7EBUN");
/** Detect if Bun is the runtime (compiled binary or bun run) */
export const isBunRuntime = !!process.versions.bun;
// =============================================================================
// Install Method Detection
// =============================================================================
export type InstallMethod = "bun-binary" | "npm" | "pnpm" | "yarn" | "bun" | "unknown";
/**
 * Guess how this package was installed by inspecting the module path and the
 * runtime executable path for package-manager-specific directory names.
 */
export function detectInstallMethod(): InstallMethod {
  if (isBunBinary) {
    return "bun-binary";
  }
  // Search both this file's directory and the runtime executable path,
  // NUL-separated so substrings cannot bleed across the boundary.
  const haystack = `${__dirname}\0${process.execPath || ""}`.toLowerCase();
  const containsAny = (...needles: string[]) => needles.some((needle) => haystack.includes(needle));
  if (containsAny("/pnpm/", "/.pnpm/", "\\pnpm\\")) {
    return "pnpm";
  }
  if (containsAny("/yarn/", "/.yarn/", "\\yarn\\")) {
    return "yarn";
  }
  if (isBunRuntime) {
    return "bun";
  }
  if (containsAny("/npm/", "/node_modules/", "\\npm\\")) {
    return "npm";
  }
  return "unknown";
}
export function getUpdateInstruction(packageName: string): string {
const method = detectInstallMethod();
switch (method) {
case "bun-binary":
return `Download from: https://github.com/badlogic/pi-mono/releases/latest`;
case "pnpm":
return `Run: pnpm install -g ${packageName}`;
case "yarn":
return `Run: yarn global add ${packageName}`;
case "bun":
return `Run: bun install -g ${packageName}`;
case "npm":
return `Run: npm install -g ${packageName}`;
default:
return `Run: npm install -g ${packageName}`;
}
}
// =============================================================================
// Package Asset Paths (shipped with executable)
// =============================================================================
/**
 * Get the base directory for resolving package assets (themes, package.json, README.md, CHANGELOG.md).
 * - For Bun binary: returns the directory containing the executable
 * - For Node.js (dist/): returns __dirname (the dist/ directory)
 * - For tsx (src/): returns parent directory (the package root)
 */
export function getPackageDir(): string {
  // PI_PACKAGE_DIR overrides detection (useful for Nix/Guix where store paths tokenize poorly).
  const override = process.env.PI_PACKAGE_DIR;
  if (override) {
    if (override === "~") return homedir();
    return override.startsWith("~/") ? homedir() + override.slice(1) : override;
  }
  if (isBunBinary) {
    // Bun binary: process.execPath points to the compiled executable.
    return dirname(process.execPath);
  }
  // Node.js: walk upward from this file until a package.json marks the package root.
  for (let dir = __dirname; dir !== dirname(dir); dir = dirname(dir)) {
    if (existsSync(join(dir, "package.json"))) {
      return dir;
    }
  }
  // Fallback (shouldn't happen).
  return __dirname;
}
/**
* Get path to built-in themes directory (shipped with package)
* - For Bun binary: theme/ next to executable
* - For Node.js (dist/): dist/modes/interactive/theme/
* - For tsx (src/): src/modes/interactive/theme/
*/
export function getThemesDir(): string {
if (isBunBinary) {
return join(dirname(process.execPath), "theme");
}
// Theme is in modes/interactive/theme/ relative to src/ or dist/
const packageDir = getPackageDir();
const srcOrDist = existsSync(join(packageDir, "src")) ? "src" : "dist";
return join(packageDir, srcOrDist, "modes", "interactive", "theme");
}
/**
* Get path to HTML export template directory (shipped with package)
* - For Bun binary: export-html/ next to executable
* - For Node.js (dist/): dist/core/export-html/
* - For tsx (src/): src/core/export-html/
*/
export function getExportTemplateDir(): string {
if (isBunBinary) {
return join(dirname(process.execPath), "export-html");
}
const packageDir = getPackageDir();
const srcOrDist = existsSync(join(packageDir, "src")) ? "src" : "dist";
return join(packageDir, srcOrDist, "core", "export-html");
}
/** Get path to package.json (resolved against the detected package root). */
export function getPackageJsonPath(): string {
  return join(getPackageDir(), "package.json");
}
/** Get path to README.md shipped with the package. */
export function getReadmePath(): string {
  return resolve(join(getPackageDir(), "README.md"));
}
/** Get path to the package's docs directory. */
export function getDocsPath(): string {
  return resolve(join(getPackageDir(), "docs"));
}
/** Get path to the package's examples directory. */
export function getExamplesPath(): string {
  return resolve(join(getPackageDir(), "examples"));
}
/** Get path to CHANGELOG.md shipped with the package. */
export function getChangelogPath(): string {
  return resolve(join(getPackageDir(), "CHANGELOG.md"));
}
// =============================================================================
// App Config (from package.json piConfig)
// =============================================================================
// Read once at module load; an unreadable/invalid package.json throws here.
const pkg = JSON.parse(readFileSync(getPackageJsonPath(), "utf-8"));
/** Application name; overridable via package.json piConfig.name (defaults to "pi"). */
export const APP_NAME: string = pkg.piConfig?.name || "pi";
/** User config directory name; overridable via piConfig.configDir (defaults to ".pi"). */
export const CONFIG_DIR_NAME: string = pkg.piConfig?.configDir || ".pi";
/** Package version, taken verbatim from package.json. */
export const VERSION: string = pkg.version;
// Env var that relocates the agent config dir, e.g., PI_CODING_AGENT_DIR or TAU_CODING_AGENT_DIR
export const ENV_AGENT_DIR = `${APP_NAME.toUpperCase()}_CODING_AGENT_DIR`;
const DEFAULT_SHARE_VIEWER_URL = "https://pi.dev/session/";
/** Get the share viewer URL for a gist ID */
export function getShareViewerUrl(gistId: string): string {
  // PI_SHARE_VIEWER_URL overrides the default viewer; the gist ID rides in the URL fragment.
  const base = process.env.PI_SHARE_VIEWER_URL || DEFAULT_SHARE_VIEWER_URL;
  return `${base}#${gistId}`;
}
// =============================================================================
// User Config Paths (~/.pi/agent/*)
// =============================================================================
/** Get the agent config directory (e.g., ~/.pi/agent/) */
export function getAgentDir(): string {
  const override = process.env[ENV_AGENT_DIR];
  // Default location: <home>/<configDir>/agent
  if (!override) return join(homedir(), CONFIG_DIR_NAME, "agent");
  // Expand a leading tilde to the user's home directory.
  if (override === "~") return homedir();
  return override.startsWith("~/") ? homedir() + override.slice(1) : override;
}
/** Get path to user's custom themes directory (under the agent config dir). */
export function getCustomThemesDir(): string {
  return join(getAgentDir(), "themes");
}
/** Get path to models.json (custom model/provider definitions). */
export function getModelsPath(): string {
  return join(getAgentDir(), "models.json");
}
/** Get path to auth.json (API keys and OAuth tokens). */
export function getAuthPath(): string {
  return join(getAgentDir(), "auth.json");
}
/** Get path to settings.json (user preferences). */
export function getSettingsPath(): string {
  return join(getAgentDir(), "settings.json");
}
/** Get path to the user tools directory. */
export function getToolsDir(): string {
  return join(getAgentDir(), "tools");
}
/** Get path to managed binaries directory (fd, rg) */
export function getBinDir(): string {
  return join(getAgentDir(), "bin");
}
/** Get path to prompt templates directory */
export function getPromptsDir(): string {
  return join(getAgentDir(), "prompts");
}
/** Get path to sessions directory */
export function getSessionsDir(): string {
  return join(getAgentDir(), "sessions");
}
/** Get path to the debug log file, named after the app (e.g., pi-debug.log). */
export function getDebugLogPath(): string {
  return join(getAgentDir(), `${APP_NAME}-debug.log`);
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,489 @@
/**
* Credential storage for API keys and OAuth tokens.
* Handles loading, saving, and refreshing credentials from auth.json.
*
* Uses file locking to prevent race conditions when multiple pi instances
* try to refresh tokens simultaneously.
*/
import {
getEnvApiKey,
type OAuthCredentials,
type OAuthLoginCallbacks,
type OAuthProviderId,
} from "@gsd/pi-ai";
import { getOAuthApiKey, getOAuthProvider, getOAuthProviders } from "@gsd/pi-ai/oauth";
import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
import { dirname, join } from "path";
import lockfile from "proper-lockfile";
import { getAgentDir } from "../config.js";
import { resolveConfigValue } from "./resolve-config-value.js";
/** A plain API key credential, stored verbatim in auth.json. */
export type ApiKeyCredential = {
  type: "api_key";
  key: string;
};
/** OAuth tokens plus expiry, as produced by a provider login/refresh. */
export type OAuthCredential = {
  type: "oauth";
} & OAuthCredentials;
/** Discriminated union over the supported credential kinds (tag: `type`). */
export type AuthCredential = ApiKeyCredential | OAuthCredential;
/** Map of provider ID -> credential; mirrors the auth.json file structure. */
export type AuthStorageData = Record<string, AuthCredential>;
/** Outcome of a locked read-modify-write: when `next` is set, it is written back. */
type LockResult<T> = {
  result: T;
  next?: string;
};
/** Serialized credential storage offering atomic read-modify-write operations. */
export interface AuthStorageBackend {
  withLock<T>(fn: (current: string | undefined) => LockResult<T>): T;
  withLockAsync<T>(fn: (current: string | undefined) => Promise<LockResult<T>>): Promise<T>;
}
/**
 * File-backed credential storage (auth.json) guarded by proper-lockfile.
 * The parent directory is created mode 0700 and the file kept at mode 0600,
 * since it holds API keys and OAuth tokens.
 */
export class FileAuthStorageBackend implements AuthStorageBackend {
  constructor(private authPath: string = join(getAgentDir(), "auth.json")) {}
  // Create the containing directory (mode 0700) if missing.
  private ensureParentDir(): void {
    const dir = dirname(this.authPath);
    if (!existsSync(dir)) {
      mkdirSync(dir, { recursive: true, mode: 0o700 });
    }
  }
  // Create an empty auth.json (mode 0600) so the lockfile has a target to lock.
  private ensureFileExists(): void {
    if (!existsSync(this.authPath)) {
      writeFileSync(this.authPath, "{}", "utf-8");
      chmodSync(this.authPath, 0o600);
    }
  }
  // Synchronous lock acquisition with bounded retry when another process holds
  // the lock (ELOCKED). Any other error, or the final failed attempt, propagates.
  // NOTE(review): the inter-attempt delay is a busy-wait (~20ms of CPU spin per
  // attempt) to keep callers synchronous — confirm this trade-off is acceptable.
  private acquireLockSyncWithRetry(path: string): () => void {
    const maxAttempts = 10;
    const delayMs = 20;
    let lastError: unknown;
    for (let attempt = 1; attempt <= maxAttempts; attempt++) {
      try {
        return lockfile.lockSync(path, { realpath: false });
      } catch (error) {
        const code =
          typeof error === "object" && error !== null && "code" in error
            ? String((error as { code?: unknown }).code)
            : undefined;
        if (code !== "ELOCKED" || attempt === maxAttempts) {
          throw error;
        }
        lastError = error;
        const start = Date.now();
        while (Date.now() - start < delayMs) {
          // Sleep synchronously to avoid changing callers to async.
        }
      }
    }
    throw (lastError as Error) ?? new Error("Failed to acquire auth storage lock");
  }
  // Run fn with the file locked; if fn returns `next`, write it back (mode 0600).
  withLock<T>(fn: (current: string | undefined) => LockResult<T>): T {
    this.ensureParentDir();
    this.ensureFileExists();
    let release: (() => void) | undefined;
    try {
      release = this.acquireLockSyncWithRetry(this.authPath);
      const current = existsSync(this.authPath) ? readFileSync(this.authPath, "utf-8") : undefined;
      const { result, next } = fn(current);
      if (next !== undefined) {
        writeFileSync(this.authPath, next, "utf-8");
        chmodSync(this.authPath, 0o600);
      }
      return result;
    } finally {
      if (release) {
        release();
      }
    }
  }
  // Async variant with exponential-backoff retries and stale-lock takeover.
  // If proper-lockfile reports the lock compromised at any checkpoint, we throw
  // rather than risk reading or writing without actually holding the lock.
  async withLockAsync<T>(fn: (current: string | undefined) => Promise<LockResult<T>>): Promise<T> {
    this.ensureParentDir();
    this.ensureFileExists();
    let release: (() => Promise<void>) | undefined;
    let lockCompromised = false;
    let lockCompromisedError: Error | undefined;
    const throwIfCompromised = () => {
      if (lockCompromised) {
        throw lockCompromisedError ?? new Error("Auth storage lock was compromised");
      }
    };
    try {
      release = await lockfile.lock(this.authPath, {
        retries: {
          retries: 10,
          factor: 2,
          minTimeout: 100,
          maxTimeout: 10000,
          randomize: true,
        },
        stale: 30000,
        onCompromised: (err) => {
          lockCompromised = true;
          lockCompromisedError = err;
        },
      });
      throwIfCompromised();
      const current = existsSync(this.authPath) ? readFileSync(this.authPath, "utf-8") : undefined;
      const { result, next } = await fn(current);
      throwIfCompromised();
      if (next !== undefined) {
        writeFileSync(this.authPath, next, "utf-8");
        chmodSync(this.authPath, 0o600);
      }
      throwIfCompromised();
      return result;
    } finally {
      if (release) {
        try {
          await release();
        } catch {
          // Ignore unlock errors when lock is compromised.
        }
      }
    }
  }
}
/**
 * Volatile in-memory backend for tests and ephemeral sessions.
 * Single-process only, so no real locking is required — withLock/withLockAsync
 * simply apply the read-modify-write against a private string.
 */
export class InMemoryAuthStorageBackend implements AuthStorageBackend {
  private value: string | undefined;
  withLock<T>(fn: (current: string | undefined) => LockResult<T>): T {
    const outcome = fn(this.value);
    if (outcome.next !== undefined) {
      this.value = outcome.next;
    }
    return outcome.result;
  }
  async withLockAsync<T>(fn: (current: string | undefined) => Promise<LockResult<T>>): Promise<T> {
    const outcome = await fn(this.value);
    if (outcome.next !== undefined) {
      this.value = outcome.next;
    }
    return outcome.result;
  }
}
/**
 * Credential storage backed by a JSON file.
 *
 * Wraps an AuthStorageBackend (file with locking, or in-memory) and layers on:
 * runtime overrides (CLI --api-key), environment-variable lookup, a fallback
 * resolver for custom providers, and OAuth token refresh performed under the
 * backend lock so concurrent instances don't clobber each other's tokens.
 */
export class AuthStorage {
  // In-memory snapshot of the backend contents (provider -> credential).
  private data: AuthStorageData = {};
  // Runtime-only API keys (never persisted); highest lookup priority.
  private runtimeOverrides: Map<string, string> = new Map();
  // Last-resort key lookup, e.g., custom providers from models.json.
  private fallbackResolver?: (provider: string) => string | undefined;
  // Set when the last reload() failed; suppresses writes so a file we couldn't
  // read is never overwritten with partial state.
  private loadError: Error | null = null;
  // Non-fatal errors accumulated for later reporting via drainErrors().
  private errors: Error[] = [];
  private constructor(private storage: AuthStorageBackend) {
    this.reload();
  }
  /** Create storage backed by auth.json (default path under the agent dir). */
  static create(authPath?: string): AuthStorage {
    return new AuthStorage(new FileAuthStorageBackend(authPath ?? join(getAgentDir(), "auth.json")));
  }
  /** Create storage on top of an arbitrary backend (e.g., for tests). */
  static fromStorage(storage: AuthStorageBackend): AuthStorage {
    return new AuthStorage(storage);
  }
  /** Create in-memory storage seeded with the given credentials. */
  static inMemory(data: AuthStorageData = {}): AuthStorage {
    const storage = new InMemoryAuthStorageBackend();
    storage.withLock(() => ({ result: undefined, next: JSON.stringify(data, null, 2) }));
    return AuthStorage.fromStorage(storage);
  }
  /**
   * Set a runtime API key override (not persisted to disk).
   * Used for CLI --api-key flag.
   */
  setRuntimeApiKey(provider: string, apiKey: string): void {
    this.runtimeOverrides.set(provider, apiKey);
  }
  /**
   * Remove a runtime API key override.
   */
  removeRuntimeApiKey(provider: string): void {
    this.runtimeOverrides.delete(provider);
  }
  /**
   * Set a fallback resolver for API keys not found in auth.json or env vars.
   * Used for custom provider keys from models.json.
   */
  setFallbackResolver(resolver: (provider: string) => string | undefined): void {
    this.fallbackResolver = resolver;
  }
  // Normalize any thrown value to an Error and queue it for drainErrors().
  private recordError(error: unknown): void {
    const normalizedError = error instanceof Error ? error : new Error(String(error));
    this.errors.push(normalizedError);
  }
  // Parse serialized storage; empty/missing content means "no credentials yet".
  private parseStorageData(content: string | undefined): AuthStorageData {
    if (!content) {
      return {};
    }
    return JSON.parse(content) as AuthStorageData;
  }
  /**
   * Reload credentials from storage.
   * On failure, keeps the previous in-memory data and records the error.
   */
  reload(): void {
    let content: string | undefined;
    try {
      this.storage.withLock((current) => {
        content = current;
        return { result: undefined };
      });
      this.data = this.parseStorageData(content);
      this.loadError = null;
    } catch (error) {
      this.loadError = error as Error;
      this.recordError(error);
    }
  }
  // Merge a single-provider change into the backend under lock. Skipped when
  // the last load failed, to avoid clobbering a file we never read correctly.
  private persistProviderChange(provider: string, credential: AuthCredential | undefined): void {
    if (this.loadError) {
      return;
    }
    try {
      this.storage.withLock((current) => {
        const currentData = this.parseStorageData(current);
        const merged: AuthStorageData = { ...currentData };
        if (credential) {
          merged[provider] = credential;
        } else {
          delete merged[provider];
        }
        return { result: undefined, next: JSON.stringify(merged, null, 2) };
      });
    } catch (error) {
      this.recordError(error);
    }
  }
  /**
   * Get credential for a provider.
   */
  get(provider: string): AuthCredential | undefined {
    return this.data[provider] ?? undefined;
  }
  /**
   * Set credential for a provider (updates memory and persists under lock).
   */
  set(provider: string, credential: AuthCredential): void {
    this.data[provider] = credential;
    this.persistProviderChange(provider, credential);
  }
  /**
   * Remove credential for a provider (updates memory and persists under lock).
   */
  remove(provider: string): void {
    delete this.data[provider];
    this.persistProviderChange(provider, undefined);
  }
  /**
   * List all providers with credentials.
   */
  list(): string[] {
    return Object.keys(this.data);
  }
  /**
   * Check if credentials exist for a provider in auth.json.
   */
  has(provider: string): boolean {
    return provider in this.data;
  }
  /**
   * Check if any form of auth is configured for a provider.
   * Unlike getApiKey(), this doesn't refresh OAuth tokens.
   * Checks, in order: runtime override, auth.json, env var, fallback resolver.
   */
  hasAuth(provider: string): boolean {
    if (this.runtimeOverrides.has(provider)) return true;
    if (this.data[provider]) return true;
    if (getEnvApiKey(provider)) return true;
    if (this.fallbackResolver?.(provider)) return true;
    return false;
  }
  /**
   * Get all credentials (for passing to getOAuthApiKey).
   * Returns a shallow copy so callers can't mutate internal state.
   */
  getAll(): AuthStorageData {
    return { ...this.data };
  }
  /** Return and clear all recorded non-fatal errors. */
  drainErrors(): Error[] {
    const drained = [...this.errors];
    this.errors = [];
    return drained;
  }
  /**
   * Login to an OAuth provider and persist the resulting credentials.
   */
  async login(providerId: OAuthProviderId, callbacks: OAuthLoginCallbacks): Promise<void> {
    const provider = getOAuthProvider(providerId);
    if (!provider) {
      throw new Error(`Unknown OAuth provider: ${providerId}`);
    }
    const credentials = await provider.login(callbacks);
    this.set(providerId, { type: "oauth", ...credentials });
  }
  /**
   * Logout from a provider.
   */
  logout(provider: string): void {
    this.remove(provider);
  }
  /**
   * Refresh OAuth token with backend locking to prevent race conditions.
   * Multiple pi instances may try to refresh simultaneously when tokens expire.
   * Inside the lock we re-read the file first — another instance may already
   * have refreshed, in which case the still-valid token is returned as-is.
   */
  private async refreshOAuthTokenWithLock(
    providerId: OAuthProviderId,
  ): Promise<{ apiKey: string; newCredentials: OAuthCredentials } | null> {
    const provider = getOAuthProvider(providerId);
    if (!provider) {
      return null;
    }
    const result = await this.storage.withLockAsync(async (current) => {
      // Sync in-memory state with the freshest on-disk data while locked.
      const currentData = this.parseStorageData(current);
      this.data = currentData;
      this.loadError = null;
      const cred = currentData[providerId];
      if (cred?.type !== "oauth") {
        return { result: null };
      }
      // Token still valid (possibly refreshed by another instance): no write.
      if (Date.now() < cred.expires) {
        return { result: { apiKey: provider.getApiKey(cred), newCredentials: cred } };
      }
      const oauthCreds: Record<string, OAuthCredentials> = {};
      for (const [key, value] of Object.entries(currentData)) {
        if (value.type === "oauth") {
          oauthCreds[key] = value;
        }
      }
      const refreshed = await getOAuthApiKey(providerId, oauthCreds);
      if (!refreshed) {
        return { result: null };
      }
      // Persist the refreshed credentials while still holding the lock.
      const merged: AuthStorageData = {
        ...currentData,
        [providerId]: { type: "oauth", ...refreshed.newCredentials },
      };
      this.data = merged;
      this.loadError = null;
      return { result: refreshed, next: JSON.stringify(merged, null, 2) };
    });
    return result;
  }
  /**
   * Get API key for a provider.
   * Priority:
   * 1. Runtime override (CLI --api-key)
   * 2. API key from auth.json
   * 3. OAuth token from auth.json (auto-refreshed with locking)
   * 4. Environment variable
   * 5. Fallback resolver (models.json custom providers)
   */
  async getApiKey(providerId: string): Promise<string | undefined> {
    // Runtime override takes highest priority
    const runtimeKey = this.runtimeOverrides.get(providerId);
    if (runtimeKey) {
      return runtimeKey;
    }
    const cred = this.data[providerId];
    if (cred?.type === "api_key") {
      return resolveConfigValue(cred.key);
    }
    if (cred?.type === "oauth") {
      const provider = getOAuthProvider(providerId);
      if (!provider) {
        // Unknown OAuth provider, can't get API key
        return undefined;
      }
      // Check if token needs refresh
      const needsRefresh = Date.now() >= cred.expires;
      if (needsRefresh) {
        // Use locked refresh to prevent race conditions
        try {
          const result = await this.refreshOAuthTokenWithLock(providerId);
          if (result) {
            return result.apiKey;
          }
        } catch (error) {
          this.recordError(error);
          // Refresh failed - re-read file to check if another instance succeeded
          this.reload();
          const updatedCred = this.data[providerId];
          if (updatedCred?.type === "oauth" && Date.now() < updatedCred.expires) {
            // Another instance refreshed successfully, use those credentials
            return provider.getApiKey(updatedCred);
          }
          // Refresh truly failed - return undefined so model discovery skips this provider
          // User can /login to re-authenticate (credentials preserved for retry)
          return undefined;
        }
      } else {
        // Token not expired, use current access token
        return provider.getApiKey(cred);
      }
    }
    // Fall back to environment variable
    const envKey = getEnvApiKey(providerId);
    if (envKey) return envKey;
    // Fall back to custom resolver (e.g., models.json custom providers)
    return this.fallbackResolver?.(providerId) ?? undefined;
  }
  /**
   * Get all registered OAuth providers
   */
  getOAuthProviders() {
    return getOAuthProviders();
  }
}

View file

@ -0,0 +1,278 @@
/**
* Bash command execution with streaming support and cancellation.
*
* This module provides a unified bash execution implementation used by:
* - AgentSession.executeBash() for interactive and RPC modes
* - Direct calls from modes that need bash execution
*/
import { randomBytes } from "node:crypto";
import { createWriteStream, type WriteStream } from "node:fs";
import { tmpdir } from "node:os";
import { join } from "node:path";
import { type ChildProcess, spawn } from "child_process";
import stripAnsi from "strip-ansi";
import { getShellConfig, getShellEnv, killProcessTree, sanitizeBinaryOutput } from "../utils/shell.js";
import type { BashOperations } from "./tools/bash.js";
import { DEFAULT_MAX_BYTES, truncateTail } from "./tools/truncate.js";
// ============================================================================
// Types
// ============================================================================
/** Options accepted by executeBash / executeBashWithOperations. */
export interface BashExecutorOptions {
  /** Callback for streaming output chunks (already sanitized) */
  onChunk?: (chunk: string) => void;
  /** AbortSignal for cancellation */
  signal?: AbortSignal;
}
/** Result of a bash execution. */
export interface BashResult {
  /** Combined stdout + stderr output (sanitized, possibly truncated) */
  output: string;
  /** Process exit code (undefined if killed/cancelled) */
  exitCode: number | undefined;
  /** Whether the command was cancelled via signal */
  cancelled: boolean;
  /** Whether the output was truncated */
  truncated: boolean;
  /** Path to temp file containing full output (if output exceeded truncation threshold) */
  fullOutputPath?: string;
}
// ============================================================================
// Implementation
// ============================================================================
/**
 * Execute a bash command with optional streaming and cancellation support.
 *
 * Features:
 * - Streams sanitized output via onChunk callback
 * - Writes large output to temp file for later retrieval
 * - Supports cancellation via AbortSignal (checked BEFORE spawning)
 * - Sanitizes output (strips ANSI, removes binary garbage, normalizes newlines)
 * - Truncates output if it exceeds the default max bytes
 *
 * @param command - The bash command to execute
 * @param options - Optional streaming callback and abort signal
 * @returns Promise resolving to execution result
 */
export function executeBash(command: string, options?: BashExecutorOptions): Promise<BashResult> {
  return new Promise((resolve, reject) => {
    // Fast path: if the caller has already aborted, don't spawn at all.
    // (Previously the child was spawned first and killed with child.kill(),
    // which bypassed killProcessTree and could leave shell descendants running.)
    if (options?.signal?.aborted) {
      resolve({
        output: "",
        exitCode: undefined,
        cancelled: true,
        truncated: false,
      });
      return;
    }
    const { shell, args } = getShellConfig();
    // detached: true puts the child in its own process group so the whole
    // tree can be killed on abort.
    const child: ChildProcess = spawn(shell, [...args, command], {
      detached: true,
      env: getShellEnv(),
      stdio: ["ignore", "pipe", "pipe"],
    });
    // Rolling buffer of sanitized output, used for final truncation.
    const outputChunks: string[] = [];
    let outputBytes = 0;
    const maxOutputBytes = DEFAULT_MAX_BYTES * 2;
    // Temp file for large output
    let tempFilePath: string | undefined;
    let tempFileStream: WriteStream | undefined;
    let totalBytes = 0;
    // Kill the entire process tree when the caller aborts mid-run.
    const abortHandler = () => {
      if (child.pid) {
        killProcessTree(child.pid);
      }
    };
    if (options?.signal) {
      options.signal.addEventListener("abort", abortHandler, { once: true });
    }
    // NOTE(review): the decoder is never flushed with a final decode() after
    // the streams end, so a trailing partial multi-byte character could be
    // dropped — confirm whether that matters for callers.
    const decoder = new TextDecoder();
    const handleData = (data: Buffer) => {
      totalBytes += data.length;
      // Sanitize once at the source: strip ANSI, replace binary garbage, normalize newlines
      const text = sanitizeBinaryOutput(stripAnsi(decoder.decode(data, { stream: true }))).replace(/\r/g, "");
      // Start writing to temp file if exceeds threshold
      if (totalBytes > DEFAULT_MAX_BYTES && !tempFilePath) {
        const id = randomBytes(8).toString("hex");
        tempFilePath = join(tmpdir(), `pi-bash-${id}.log`);
        tempFileStream = createWriteStream(tempFilePath);
        // Write already-buffered chunks to temp file
        for (const chunk of outputChunks) {
          tempFileStream.write(chunk);
        }
      }
      if (tempFileStream) {
        tempFileStream.write(text);
      }
      // Keep rolling buffer of sanitized text
      outputChunks.push(text);
      outputBytes += text.length;
      while (outputBytes > maxOutputBytes && outputChunks.length > 1) {
        const removed = outputChunks.shift()!;
        outputBytes -= removed.length;
      }
      // Stream to callback if provided
      if (options?.onChunk) {
        options.onChunk(text);
      }
    };
    child.stdout?.on("data", handleData);
    child.stderr?.on("data", handleData);
    child.on("close", (code) => {
      // Clean up abort listener
      if (options?.signal) {
        options.signal.removeEventListener("abort", abortHandler);
      }
      if (tempFileStream) {
        tempFileStream.end();
      }
      // Combine buffered chunks for truncation (already sanitized)
      const fullOutput = outputChunks.join("");
      const truncationResult = truncateTail(fullOutput);
      // code === null means killed (cancelled)
      const cancelled = code === null;
      resolve({
        output: truncationResult.truncated ? truncationResult.content : fullOutput,
        exitCode: cancelled ? undefined : code,
        cancelled,
        truncated: truncationResult.truncated,
        fullOutputPath: tempFilePath,
      });
    });
    child.on("error", (err) => {
      // Clean up abort listener
      if (options?.signal) {
        options.signal.removeEventListener("abort", abortHandler);
      }
      if (tempFileStream) {
        tempFileStream.end();
      }
      reject(err);
    });
  });
}
/**
 * Execute a bash command using custom BashOperations.
 * Used for remote execution (SSH, containers, etc.).
 *
 * Mirrors executeBash's buffering behavior: sanitized output is streamed via
 * options.onChunk, kept in a rolling buffer capped at 2x DEFAULT_MAX_BYTES,
 * and spilled to a temp file once total output exceeds DEFAULT_MAX_BYTES.
 * Cancellation is delegated to operations.exec via options.signal; if the
 * signal aborted (whether exec returned or threw), the partial output is
 * returned with cancelled=true.
 */
export async function executeBashWithOperations(
  command: string,
  cwd: string,
  operations: BashOperations,
  options?: BashExecutorOptions,
): Promise<BashResult> {
  const outputChunks: string[] = [];
  let outputBytes = 0;
  const maxOutputBytes = DEFAULT_MAX_BYTES * 2;
  let tempFilePath: string | undefined;
  let tempFileStream: WriteStream | undefined;
  let totalBytes = 0;
  // NOTE(review): like executeBash, the decoder is never flushed after the
  // stream ends — a trailing partial multi-byte character could be dropped.
  const decoder = new TextDecoder();
  const onData = (data: Buffer) => {
    totalBytes += data.length;
    // Sanitize: strip ANSI, replace binary garbage, normalize newlines
    const text = sanitizeBinaryOutput(stripAnsi(decoder.decode(data, { stream: true }))).replace(/\r/g, "");
    // Start writing to temp file if exceeds threshold
    if (totalBytes > DEFAULT_MAX_BYTES && !tempFilePath) {
      const id = randomBytes(8).toString("hex");
      tempFilePath = join(tmpdir(), `pi-bash-${id}.log`);
      tempFileStream = createWriteStream(tempFilePath);
      // Backfill chunks received before the threshold was crossed.
      for (const chunk of outputChunks) {
        tempFileStream.write(chunk);
      }
    }
    if (tempFileStream) {
      tempFileStream.write(text);
    }
    // Keep rolling buffer
    outputChunks.push(text);
    outputBytes += text.length;
    while (outputBytes > maxOutputBytes && outputChunks.length > 1) {
      const removed = outputChunks.shift()!;
      outputBytes -= removed.length;
    }
    // Stream to callback
    if (options?.onChunk) {
      options.onChunk(text);
    }
  };
  try {
    const result = await operations.exec(command, cwd, {
      onData,
      signal: options?.signal,
    });
    if (tempFileStream) {
      tempFileStream.end();
    }
    const fullOutput = outputChunks.join("");
    const truncationResult = truncateTail(fullOutput);
    const cancelled = options?.signal?.aborted ?? false;
    return {
      output: truncationResult.truncated ? truncationResult.content : fullOutput,
      exitCode: cancelled ? undefined : (result.exitCode ?? undefined),
      cancelled,
      truncated: truncationResult.truncated,
      fullOutputPath: tempFilePath,
    };
  } catch (err) {
    if (tempFileStream) {
      tempFileStream.end();
    }
    // Check if it was an abort: return partial output instead of throwing.
    if (options?.signal?.aborted) {
      const fullOutput = outputChunks.join("");
      const truncationResult = truncateTail(fullOutput);
      return {
        output: truncationResult.truncated ? truncationResult.content : fullOutput,
        exitCode: undefined,
        cancelled: true,
        truncated: truncationResult.truncated,
        fullOutputPath: tempFilePath,
      };
    }
    throw err;
  }
}

View file

@ -0,0 +1,352 @@
/**
* Branch summarization for tree navigation.
*
* When navigating to a different point in the session tree, this generates
* a summary of the branch being left so context isn't lost.
*/
import type { AgentMessage } from "@gsd/pi-agent-core";
import type { Model } from "@gsd/pi-ai";
import { completeSimple } from "@gsd/pi-ai";
import {
convertToLlm,
createBranchSummaryMessage,
createCompactionSummaryMessage,
createCustomMessage,
} from "../messages.js";
import type { ReadonlySessionManager, SessionEntry } from "../session-manager.js";
import { estimateTokens } from "./compaction.js";
import {
computeFileLists,
createFileOps,
extractFileOpsFromMessage,
type FileOperations,
formatFileOperations,
SUMMARIZATION_SYSTEM_PROMPT,
serializeConversation,
} from "./utils.js";
// ============================================================================
// Types
// ============================================================================
/** Outcome of generateBranchSummary: a summary with file lists, or aborted/error. */
export interface BranchSummaryResult {
  summary?: string;
  readFiles?: string[];
  modifiedFiles?: string[];
  aborted?: boolean;
  error?: string;
}
/** Details stored in BranchSummaryEntry.details for file tracking */
export interface BranchSummaryDetails {
  readFiles: string[];
  modifiedFiles: string[];
}
export type { FileOperations } from "./utils.js";
/** Result of prepareBranchEntries: budget-limited messages plus aggregates. */
export interface BranchPreparation {
  /** Messages extracted for summarization, in chronological order */
  messages: AgentMessage[];
  /** File operations extracted from tool calls */
  fileOps: FileOperations;
  /** Total estimated tokens in messages */
  totalTokens: number;
}
/** Result of collectEntriesForBranchSummary. */
export interface CollectEntriesResult {
  /** Entries to summarize, in chronological order */
  entries: SessionEntry[];
  /** Common ancestor between old and new position, if any */
  commonAncestorId: string | null;
}
export interface GenerateBranchSummaryOptions {
  /** Model to use for summarization */
  model: Model<any>;
  /** API key for the model */
  apiKey: string;
  /** Abort signal for cancellation */
  signal: AbortSignal;
  /** Optional custom instructions for summarization */
  customInstructions?: string;
  /** If true, customInstructions replaces the default prompt instead of being appended */
  replaceInstructions?: boolean;
  /** Tokens reserved for prompt + LLM response (default 16384) */
  reserveTokens?: number;
}
// ============================================================================
// Entry Collection
// ============================================================================
/**
 * Collect entries that should be summarized when navigating from one position to another.
 *
 * Walks from oldLeafId back to the common ancestor with targetId, collecting entries
 * along the way. Does NOT stop at compaction boundaries - those are included and their
 * summaries become context.
 *
 * @param session - Session manager (read-only access)
 * @param oldLeafId - Current position (where we're navigating from)
 * @param targetId - Target position (where we're navigating to)
 * @returns Entries to summarize and the common ancestor
 */
export function collectEntriesForBranchSummary(
  session: ReadonlySessionManager,
  oldLeafId: string | null,
  targetId: string,
): CollectEntriesResult {
  // No old position means there is nothing to summarize.
  if (!oldLeafId) {
    return { entries: [], commonAncestorId: null };
  }
  const oldPathIds = new Set(session.getBranch(oldLeafId).map((entry) => entry.id));
  const targetPath = session.getBranch(targetId);
  // getBranch returns root-first, so scan a reversed copy to find the DEEPEST
  // node shared by both paths.
  let commonAncestorId: string | null = null;
  for (const candidate of [...targetPath].reverse()) {
    if (oldPathIds.has(candidate.id)) {
      commonAncestorId = candidate.id;
      break;
    }
  }
  // Walk parent links from the old leaf up to (but excluding) the common
  // ancestor, prepending as we go so the result is already chronological.
  const abandoned: SessionEntry[] = [];
  let cursor: string | null = oldLeafId;
  while (cursor && cursor !== commonAncestorId) {
    const entry = session.getEntry(cursor);
    if (!entry) break;
    abandoned.unshift(entry);
    cursor = entry.parentId;
  }
  return { entries: abandoned, commonAncestorId };
}
// ============================================================================
// Entry to Message Conversion
// ============================================================================
/**
 * Extract the AgentMessage representation of a session entry, or undefined for
 * entries carrying no conversational content.
 * Similar to getMessageFromEntry in compaction.ts but also handles compaction entries.
 */
function getMessageFromEntry(entry: SessionEntry): AgentMessage | undefined {
  if (entry.type === "message") {
    // Tool results are skipped - their context lives in the assistant's tool call.
    return entry.message.role === "toolResult" ? undefined : entry.message;
  }
  if (entry.type === "custom_message") {
    return createCustomMessage(entry.customType, entry.content, entry.display, entry.details, entry.timestamp);
  }
  if (entry.type === "branch_summary") {
    return createBranchSummaryMessage(entry.summary, entry.fromId, entry.timestamp);
  }
  if (entry.type === "compaction") {
    return createCompactionSummaryMessage(entry.summary, entry.tokensBefore, entry.timestamp);
  }
  // thinking_level_change / model_change / custom / label: no conversation content.
  return undefined;
}
/**
 * Prepare entries for summarization with token budget.
 *
 * Walks entries from NEWEST to OLDEST, adding messages until we hit the token budget.
 * This ensures we keep the most recent context when the branch is too long.
 *
 * Also collects file operations from:
 * - Tool calls in assistant messages
 * - Existing branch_summary entries' details (for cumulative tracking)
 *
 * @param entries - Entries in chronological order
 * @param tokenBudget - Maximum tokens to include (0 = no limit)
 * @returns Messages within budget (chronological), file ops, and token total
 */
export function prepareBranchEntries(entries: SessionEntry[], tokenBudget: number = 0): BranchPreparation {
  const messages: AgentMessage[] = [];
  const fileOps = createFileOps();
  let totalTokens = 0;
  // First pass: collect file ops from ALL entries (even if they don't fit in token budget)
  // This ensures we capture cumulative file tracking from nested branch summaries
  // Only extract from pi-generated summaries (fromHook !== true), not extension-generated ones
  for (const entry of entries) {
    if (entry.type === "branch_summary" && !entry.fromHook && entry.details) {
      const details = entry.details as BranchSummaryDetails;
      if (Array.isArray(details.readFiles)) {
        for (const f of details.readFiles) fileOps.read.add(f);
      }
      if (Array.isArray(details.modifiedFiles)) {
        // Modified files go into both edited and written for proper deduplication
        for (const f of details.modifiedFiles) {
          fileOps.edited.add(f);
        }
      }
    }
  }
  // Second pass: walk from newest to oldest, adding messages until token budget
  for (let i = entries.length - 1; i >= 0; i--) {
    const entry = entries[i];
    const message = getMessageFromEntry(entry);
    if (!message) continue;
    // Extract file ops from assistant messages (tool calls)
    extractFileOpsFromMessage(message, fileOps);
    const tokens = estimateTokens(message);
    // Check budget before adding
    if (tokenBudget > 0 && totalTokens + tokens > tokenBudget) {
      // If this is a summary entry, try to fit it anyway as it's important context
      // (but only if we're below 90% of budget, to bound the overshoot).
      if (entry.type === "compaction" || entry.type === "branch_summary") {
        if (totalTokens < tokenBudget * 0.9) {
          messages.unshift(message);
          totalTokens += tokens;
        }
      }
      // Stop - we've hit the budget
      break;
    }
    // unshift keeps the returned messages in chronological order.
    messages.unshift(message);
    totalTokens += tokens;
  }
  return { messages, fileOps, totalTokens };
}
// ============================================================================
// Summary Generation
// ============================================================================
// Prepended to every generated branch summary so the model reading it later
// knows the text describes an abandoned exploration, not the current thread.
const BRANCH_SUMMARY_PREAMBLE = `The user explored a different conversation branch before returning here.
Summary of that exploration:
`;
// Instructions sent to the summarization model; the rigid section structure
// keeps summaries machine-scannable and comparable across branches.
const BRANCH_SUMMARY_PROMPT = `Create a structured summary of this conversation branch for context when returning later.
Use this EXACT format:
## Goal
[What was the user trying to accomplish in this branch?]
## Constraints & Preferences
- [Any constraints, preferences, or requirements mentioned]
- [Or "(none)" if none were mentioned]
## Progress
### Done
- [x] [Completed tasks/changes]
### In Progress
- [ ] [Work that was started but not finished]
### Blocked
- [Issues preventing progress, if any]
## Key Decisions
- **[Decision]**: [Brief rationale]
## Next Steps
1. [What should happen next to continue this work]
Keep each section concise. Preserve exact file paths, function names, and error messages.`;
/**
 * Generate a summary of abandoned branch entries.
 *
 * Selects as many recent messages as fit the model's context window (minus
 * reserveTokens), serializes them as text, and asks the model for a structured
 * summary. Returns { aborted } or { error } instead of throwing when the LLM
 * call is cancelled or fails.
 *
 * @param entries - Session entries to summarize (chronological order)
 * @param options - Generation options
 */
export async function generateBranchSummary(
  entries: SessionEntry[],
  options: GenerateBranchSummaryOptions,
): Promise<BranchSummaryResult> {
  const { model, apiKey, signal, customInstructions, replaceInstructions, reserveTokens = 16384 } = options;
  // Token budget = context window minus reserved space for prompt + response
  // NOTE(review): 128000 is assumed as a default context window when the model
  // doesn't declare one — confirm this matches the model registry's defaults.
  const contextWindow = model.contextWindow || 128000;
  const tokenBudget = contextWindow - reserveTokens;
  const { messages, fileOps } = prepareBranchEntries(entries, tokenBudget);
  if (messages.length === 0) {
    return { summary: "No content to summarize" };
  }
  // Transform to LLM-compatible messages, then serialize to text
  // Serialization prevents the model from treating it as a conversation to continue
  const llmMessages = convertToLlm(messages);
  const conversationText = serializeConversation(llmMessages);
  // Build prompt: custom instructions either replace or extend the default.
  let instructions: string;
  if (replaceInstructions && customInstructions) {
    instructions = customInstructions;
  } else if (customInstructions) {
    instructions = `${BRANCH_SUMMARY_PROMPT}\n\nAdditional focus: ${customInstructions}`;
  } else {
    instructions = BRANCH_SUMMARY_PROMPT;
  }
  const promptText = `<conversation>\n${conversationText}\n</conversation>\n\n${instructions}`;
  const summarizationMessages = [
    {
      role: "user" as const,
      content: [{ type: "text" as const, text: promptText }],
      timestamp: Date.now(),
    },
  ];
  // Call LLM for summarization (response capped at 2048 tokens)
  const response = await completeSimple(
    model,
    { systemPrompt: SUMMARIZATION_SYSTEM_PROMPT, messages: summarizationMessages },
    { apiKey, signal, maxTokens: 2048 },
  );
  // Check if aborted or errored
  if (response.stopReason === "aborted") {
    return { aborted: true };
  }
  if (response.stopReason === "error") {
    return { error: response.errorMessage || "Summarization failed" };
  }
  // Keep only text content blocks from the response.
  let summary = response.content
    .filter((c): c is { type: "text"; text: string } => c.type === "text")
    .map((c) => c.text)
    .join("\n");
  // Prepend preamble to provide context about the branch summary
  summary = BRANCH_SUMMARY_PREAMBLE + summary;
  // Compute file lists and append to summary
  const { readFiles, modifiedFiles } = computeFileLists(fileOps);
  summary += formatFileOperations(readFiles, modifiedFiles);
  return {
    summary: summary || "No summary generated",
    readFiles,
    modifiedFiles,
  };
}

View file

@ -0,0 +1,813 @@
/**
* Context compaction for long sessions.
*
* Pure functions for compaction logic. The session manager handles I/O,
* and after compaction the session is reloaded.
*/
import type { AgentMessage } from "@gsd/pi-agent-core";
import type { AssistantMessage, Model, Usage } from "@gsd/pi-ai";
import { completeSimple } from "@gsd/pi-ai";
import {
convertToLlm,
createBranchSummaryMessage,
createCompactionSummaryMessage,
createCustomMessage,
} from "../messages.js";
import type { CompactionEntry, SessionEntry } from "../session-manager.js";
import {
computeFileLists,
createFileOps,
extractFileOpsFromMessage,
type FileOperations,
formatFileOperations,
SUMMARIZATION_SYSTEM_PROMPT,
serializeConversation,
} from "./utils.js";
// ============================================================================
// File Operation Tracking
// ============================================================================
/** Details stored in CompactionEntry.details for file tracking */
export interface CompactionDetails {
  /** Files that were only read (not modified) before the compaction cut. */
  readFiles: string[];
  /** Files that were written or edited before the compaction cut. */
  modifiedFiles: string[];
}
/**
 * Extract file operations from tool calls in `messages`, seeded with the
 * file lists recorded by the previous pi-generated compaction entry (if any).
 */
function extractFileOperations(
  messages: AgentMessage[],
  entries: SessionEntry[],
  prevCompactionIndex: number,
): FileOperations {
  const ops = createFileOps();
  // Seed from the previous compaction's recorded details (pi-generated only;
  // the fromHook field is kept for session file compatibility).
  if (prevCompactionIndex >= 0) {
    const previous = entries[prevCompactionIndex] as CompactionEntry;
    if (!previous.fromHook && previous.details) {
      const { readFiles, modifiedFiles } = previous.details as CompactionDetails;
      if (Array.isArray(readFiles)) {
        readFiles.forEach((f) => ops.read.add(f));
      }
      if (Array.isArray(modifiedFiles)) {
        modifiedFiles.forEach((f) => ops.edited.add(f));
      }
    }
  }
  // Fold in tool-call operations from the messages being summarized away.
  messages.forEach((msg) => extractFileOpsFromMessage(msg, ops));
  return ops;
}
// ============================================================================
// Message Extraction
// ============================================================================
/**
 * Extract AgentMessage from an entry if it produces one.
 * Returns undefined for entries that don't contribute to LLM context
 * (labels, settings changes, etc.).
 */
function getMessageFromEntry(entry: SessionEntry): AgentMessage | undefined {
  switch (entry.type) {
    case "message":
      return entry.message;
    case "custom_message":
      return createCustomMessage(entry.customType, entry.content, entry.display, entry.details, entry.timestamp);
    case "branch_summary":
      return createBranchSummaryMessage(entry.summary, entry.fromId, entry.timestamp);
    case "compaction":
      return createCompactionSummaryMessage(entry.summary, entry.tokensBefore, entry.timestamp);
    default:
      return undefined;
  }
}
/** Result from compact() - SessionManager adds uuid/parentUuid when saving */
export interface CompactionResult<T = unknown> {
  /** Merged summary text (history + optional split-turn context + file lists). */
  summary: string;
  /** UUID of the first session entry retained after the cut. */
  firstKeptEntryId: string;
  /** Estimated context tokens before compaction ran. */
  tokensBefore: number;
  /** Extension-specific data (e.g., ArtifactIndex, version markers for structured compaction) */
  details?: T;
}
// ============================================================================
// Types
// ============================================================================
/** Tunables controlling when compaction triggers and how much is kept. */
export interface CompactionSettings {
  /** Master switch; when false, shouldCompact() never triggers. */
  enabled: boolean;
  /** Tokens reserved as headroom (also budgets the summarization response). */
  reserveTokens: number;
  /** Approximate tokens of recent messages kept verbatim after a cut. */
  keepRecentTokens: number;
}
/** Defaults used when no persisted settings are available. */
export const DEFAULT_COMPACTION_SETTINGS: CompactionSettings = {
  enabled: true,
  reserveTokens: 16384,
  keepRecentTokens: 20000,
};
// ============================================================================
// Token calculation
// ============================================================================
/**
 * Calculate total context tokens from usage.
 * Uses the native totalTokens field when available (non-zero), otherwise
 * falls back to summing the individual components.
 */
export function calculateContextTokens(usage: Usage): number {
  if (usage.totalTokens) {
    return usage.totalTokens;
  }
  return usage.input + usage.output + usage.cacheRead + usage.cacheWrite;
}
/**
 * Get usage from an assistant message if available.
 * Skips aborted and error messages as they don't have valid usage data.
 */
function getAssistantUsage(msg: AgentMessage): Usage | undefined {
  if (msg.role !== "assistant" || !("usage" in msg)) {
    return undefined;
  }
  const assistantMsg = msg as AssistantMessage;
  const invalid = assistantMsg.stopReason === "aborted" || assistantMsg.stopReason === "error";
  return !invalid && assistantMsg.usage ? assistantMsg.usage : undefined;
}
/**
 * Find the last non-aborted assistant message usage from session entries.
 * Walks newest-to-oldest; the most recent valid usage wins.
 */
export function getLastAssistantUsage(entries: SessionEntry[]): Usage | undefined {
  for (let i = entries.length - 1; i >= 0; i--) {
    const entry = entries[i];
    if (entry.type !== "message") continue;
    const usage = getAssistantUsage(entry.message);
    if (usage !== undefined) {
      return usage;
    }
  }
  return undefined;
}
/** Breakdown of an estimated context-token count. */
export interface ContextUsageEstimate {
  /** Total: usage-reported tokens plus estimated trailing tokens. */
  tokens: number;
  /** Tokens reported by the last valid assistant usage (0 if none found). */
  usageTokens: number;
  /** chars/4 estimate for messages after the last usage (all messages if none). */
  trailingTokens: number;
  /** Index of the message providing the usage, or null when fully estimated. */
  lastUsageIndex: number | null;
}
/** Locate the newest message carrying valid assistant usage, with its index. */
function getLastAssistantUsageInfo(messages: AgentMessage[]): { usage: Usage; index: number } | undefined {
  let i = messages.length - 1;
  while (i >= 0) {
    const usage = getAssistantUsage(messages[i]);
    if (usage) {
      return { usage, index: i };
    }
    i--;
  }
  return undefined;
}
/**
 * Estimate context tokens from messages, using the last assistant usage when available.
 * If there are messages after the last usage, estimate their tokens with estimateTokens.
 */
export function estimateContextTokens(messages: AgentMessage[]): ContextUsageEstimate {
  const usageInfo = getLastAssistantUsageInfo(messages);
  if (usageInfo === undefined) {
    // No usage anywhere: fall back to a pure chars/4 estimate of everything.
    const estimated = messages.reduce((sum, message) => sum + estimateTokens(message), 0);
    return { tokens: estimated, usageTokens: 0, trailingTokens: estimated, lastUsageIndex: null };
  }
  const usageTokens = calculateContextTokens(usageInfo.usage);
  // Messages newer than the usage report aren't covered by it; estimate them.
  let trailingTokens = 0;
  for (const message of messages.slice(usageInfo.index + 1)) {
    trailingTokens += estimateTokens(message);
  }
  return {
    tokens: usageTokens + trailingTokens,
    usageTokens,
    trailingTokens,
    lastUsageIndex: usageInfo.index,
  };
}
/**
 * Check if compaction should trigger based on context usage.
 * Triggers once usage eats into the reserved headroom.
 */
export function shouldCompact(contextTokens: number, contextWindow: number, settings: CompactionSettings): boolean {
  const threshold = contextWindow - settings.reserveTokens;
  return settings.enabled && contextTokens > threshold;
}
// ============================================================================
// Cut point detection
// ============================================================================
/**
 * Estimate token count for a message using chars/4 heuristic.
 * This is conservative (overestimates tokens).
 */
export function estimateTokens(message: AgentMessage): number {
  const toTokens = (chars: number): number => Math.ceil(chars / 4);
  // Sum the lengths of all text blocks in a content array.
  const textChars = (blocks: Array<{ type: string; text?: string }>): number =>
    blocks.reduce((n, b) => (b.type === "text" && b.text ? n + b.text.length : n), 0);

  if (message.role === "user") {
    const content = (message as { content: string | Array<{ type: string; text?: string }> }).content;
    if (typeof content === "string") {
      return toTokens(content.length);
    }
    return Array.isArray(content) ? toTokens(textChars(content)) : 0;
  }
  if (message.role === "assistant") {
    const assistant = message as AssistantMessage;
    let chars = 0;
    for (const block of assistant.content) {
      if (block.type === "text") {
        chars += block.text.length;
      } else if (block.type === "thinking") {
        chars += block.thinking.length;
      } else if (block.type === "toolCall") {
        chars += block.name.length + JSON.stringify(block.arguments).length;
      }
    }
    return toTokens(chars);
  }
  if (message.role === "custom" || message.role === "toolResult") {
    if (typeof message.content === "string") {
      return toTokens(message.content.length);
    }
    let chars = textChars(message.content);
    for (const block of message.content) {
      if (block.type === "image") {
        chars += 4800; // Estimate images as 4800 chars, i.e. 1200 tokens
      }
    }
    return toTokens(chars);
  }
  if (message.role === "bashExecution") {
    return toTokens(message.command.length + message.output.length);
  }
  if (message.role === "branchSummary" || message.role === "compactionSummary") {
    return toTokens(message.summary.length);
  }
  return 0;
}
/**
 * Find valid cut points: indices of user, assistant, custom, or bashExecution messages.
 * Never cut at tool results (they must follow their tool call).
 * When we cut at an assistant message with tool calls, its tool results follow it
 * and will be kept.
 * BashExecutionMessage is treated like a user message (user-initiated context).
 */
function findValidCutPoints(entries: SessionEntry[], startIndex: number, endIndex: number): number[] {
  const cutPoints: number[] = [];
  for (let i = startIndex; i < endIndex; i++) {
    const entry = entries[i];
    if (entry.type === "message") {
      // Every message role is a valid boundary except tool results, which
      // must stay attached to the tool call that precedes them.
      if (entry.message.role !== "toolResult") {
        cutPoints.push(i);
      }
    } else if (entry.type === "branch_summary" || entry.type === "custom_message") {
      // These render as user-role messages, so they are valid cut points too.
      cutPoints.push(i);
    }
    // All other entry types (compaction, label, model_change, ...) are
    // bookkeeping records and never valid boundaries.
  }
  return cutPoints;
}
/**
 * Find the user message (or bashExecution) that starts the turn containing the given entry index.
 * Returns -1 if no turn start found before the index.
 * BashExecutionMessage is treated like a user message for turn boundaries.
 */
export function findTurnStartIndex(entries: SessionEntry[], entryIndex: number, startIndex: number): number {
  const startsTurn = (entry: SessionEntry): boolean => {
    // branch_summary and custom_message are user-role messages, can start a turn
    if (entry.type === "branch_summary" || entry.type === "custom_message") {
      return true;
    }
    return entry.type === "message" && (entry.message.role === "user" || entry.message.role === "bashExecution");
  };
  for (let i = entryIndex; i >= startIndex; i--) {
    if (startsTurn(entries[i])) {
      return i;
    }
  }
  return -1;
}
/** Outcome of findCutPoint(): where to cut and whether a turn is being split. */
export interface CutPointResult {
  /** Index of first entry to keep */
  firstKeptEntryIndex: number;
  /** Index of user message that starts the turn being split, or -1 if not splitting */
  turnStartIndex: number;
  /** Whether this cut splits a turn (cut point is not a user message) */
  isSplitTurn: boolean;
}
/**
 * Find the cut point in session entries that keeps approximately `keepRecentTokens`.
 *
 * Algorithm: Walk backwards from newest, accumulating estimated message sizes.
 * Stop when we've accumulated >= keepRecentTokens. Cut at that point.
 *
 * Can cut at user OR assistant messages (never tool results). When cutting at an
 * assistant message with tool calls, its tool results come after and will be kept.
 *
 * Returns CutPointResult with:
 * - firstKeptEntryIndex: the entry index to start keeping from
 * - turnStartIndex: if cutting mid-turn, the user message that started that turn
 * - isSplitTurn: whether we're cutting in the middle of a turn
 *
 * Only considers entries between `startIndex` and `endIndex` (exclusive).
 */
export function findCutPoint(
  entries: SessionEntry[],
  startIndex: number,
  endIndex: number,
  keepRecentTokens: number,
): CutPointResult {
  const cutPoints = findValidCutPoints(entries, startIndex, endIndex);
  if (cutPoints.length === 0) {
    // Nothing cuttable in range: keep everything from startIndex.
    return { firstKeptEntryIndex: startIndex, turnStartIndex: -1, isSplitTurn: false };
  }
  // Walk backwards from newest, accumulating estimated message sizes
  let accumulatedTokens = 0;
  let cutIndex = cutPoints[0]; // Default: keep from first message (not header)
  for (let i = endIndex - 1; i >= startIndex; i--) {
    const entry = entries[i];
    if (entry.type !== "message") continue;
    // Estimate this message's size
    const messageTokens = estimateTokens(entry.message);
    accumulatedTokens += messageTokens;
    // Check if we've exceeded the budget
    if (accumulatedTokens >= keepRecentTokens) {
      // Find the closest valid cut point at or after this entry
      for (let c = 0; c < cutPoints.length; c++) {
        if (cutPoints[c] >= i) {
          cutIndex = cutPoints[c];
          break;
        }
      }
      break;
    }
  }
  // Scan backwards from cutIndex to include any non-message entries (bash, settings, etc.)
  // so they stay with the message that follows them.
  while (cutIndex > startIndex) {
    const prevEntry = entries[cutIndex - 1];
    // Stop at session header or compaction boundaries
    if (prevEntry.type === "compaction") {
      break;
    }
    if (prevEntry.type === "message") {
      // Stop if we hit any message
      break;
    }
    // Include this non-message entry (bash, settings change, etc.)
    cutIndex--;
  }
  // Determine if this is a split turn: cutting at anything other than a user
  // message means the turn that contains the cut loses its beginning.
  const cutEntry = entries[cutIndex];
  const isUserMessage = cutEntry.type === "message" && cutEntry.message.role === "user";
  const turnStartIndex = isUserMessage ? -1 : findTurnStartIndex(entries, cutIndex, startIndex);
  return {
    firstKeptEntryIndex: cutIndex,
    turnStartIndex,
    isSplitTurn: !isUserMessage && turnStartIndex !== -1,
  };
}
// ============================================================================
// Summarization
// ============================================================================
/** Prompt for the initial (non-iterative) compaction summary; the serialized conversation precedes it. */
const SUMMARIZATION_PROMPT = `The messages above are a conversation to summarize. Create a structured context checkpoint summary that another LLM will use to continue the work.
Use this EXACT format:
## Goal
[What is the user trying to accomplish? Can be multiple items if the session covers different tasks.]
## Constraints & Preferences
- [Any constraints, preferences, or requirements mentioned by user]
- [Or "(none)" if none were mentioned]
## Progress
### Done
- [x] [Completed tasks/changes]
### In Progress
- [ ] [Current work]
### Blocked
- [Issues preventing progress, if any]
## Key Decisions
- **[Decision]**: [Brief rationale]
## Next Steps
1. [Ordered list of what should happen next]
## Critical Context
- [Any data, examples, or references needed to continue]
- [Or "(none)" if not applicable]
Keep each section concise. Preserve exact file paths, function names, and error messages.`;
/** Prompt for iterative compaction: merges new messages into the previous summary. */
const UPDATE_SUMMARIZATION_PROMPT = `The messages above are NEW conversation messages to incorporate into the existing summary provided in <previous-summary> tags.
Update the existing structured summary with new information. RULES:
- PRESERVE all existing information from the previous summary
- ADD new progress, decisions, and context from the new messages
- UPDATE the Progress section: move items from "In Progress" to "Done" when completed
- UPDATE "Next Steps" based on what was accomplished
- PRESERVE exact file paths, function names, and error messages
- If something is no longer relevant, you may remove it
Use this EXACT format:
## Goal
[Preserve existing goals, add new ones if the task expanded]
## Constraints & Preferences
- [Preserve existing, add new ones discovered]
## Progress
### Done
- [x] [Include previously done items AND newly completed items]
### In Progress
- [ ] [Current work - update based on progress]
### Blocked
- [Current blockers - remove if resolved]
## Key Decisions
- **[Decision]**: [Brief rationale] (preserve all previous, add new)
## Next Steps
1. [Update based on current state]
## Critical Context
- [Preserve important context, add new if needed]
Keep each section concise. Preserve exact file paths, function names, and error messages.`;
/**
 * Generate a summary of the conversation using the LLM.
 * If previousSummary is provided, uses the update prompt to merge.
 *
 * @throws Error when the completion ends with an error stop reason.
 */
export async function generateSummary(
  currentMessages: AgentMessage[],
  model: Model<any>,
  reserveTokens: number,
  apiKey: string,
  signal?: AbortSignal,
  customInstructions?: string,
  previousSummary?: string,
): Promise<string> {
  // Leave 20% of the reserve as headroom; the rest budgets the response.
  const maxTokens = Math.floor(0.8 * reserveTokens);
  // Initial prompt for a fresh summary; update prompt when merging into one.
  let basePrompt = previousSummary ? UPDATE_SUMMARIZATION_PROMPT : SUMMARIZATION_PROMPT;
  if (customInstructions) {
    basePrompt = `${basePrompt}\n\nAdditional focus: ${customInstructions}`;
  }
  // Serialize the conversation to plain text so the model summarizes it
  // instead of continuing it. convertToLlm() normalizes custom message types.
  const conversationText = serializeConversation(convertToLlm(currentMessages));
  const promptSections = [`<conversation>\n${conversationText}\n</conversation>\n\n`];
  if (previousSummary) {
    promptSections.push(`<previous-summary>\n${previousSummary}\n</previous-summary>\n\n`);
  }
  promptSections.push(basePrompt);
  const summarizationMessages = [
    {
      role: "user" as const,
      content: [{ type: "text" as const, text: promptSections.join("") }],
      timestamp: Date.now(),
    },
  ];
  // Reasoning-capable models get a high reasoning budget for better summaries.
  const completionOptions = model.reasoning
    ? { maxTokens, signal, apiKey, reasoning: "high" as const }
    : { maxTokens, signal, apiKey };
  const response = await completeSimple(
    model,
    { systemPrompt: SUMMARIZATION_SYSTEM_PROMPT, messages: summarizationMessages },
    completionOptions,
  );
  if (response.stopReason === "error") {
    throw new Error(`Summarization failed: ${response.errorMessage || "Unknown error"}`);
  }
  // Keep only the text blocks of the response.
  return response.content
    .filter((c): c is { type: "text"; text: string } => c.type === "text")
    .map((c) => c.text)
    .join("\n");
}
// ============================================================================
// Compaction Preparation (for extensions)
// ============================================================================
/** Pre-computed inputs for compact(), produced by prepareCompaction(). */
export interface CompactionPreparation {
  /** UUID of first entry to keep */
  firstKeptEntryId: string;
  /** Messages that will be summarized and discarded */
  messagesToSummarize: AgentMessage[];
  /** Messages that will be turned into turn prefix summary (if splitting) */
  turnPrefixMessages: AgentMessage[];
  /** Whether this is a split turn (cut point in middle of turn) */
  isSplitTurn: boolean;
  /** Estimated context tokens before compaction. */
  tokensBefore: number;
  /** Summary from previous compaction, for iterative update */
  previousSummary?: string;
  /** File operations extracted from messagesToSummarize */
  fileOps: FileOperations;
  /** Compaction settings from settings.jsonl */
  settings: CompactionSettings;
}
/**
 * Compute everything compact() needs from the current session path.
 *
 * Determines the boundary since the previous compaction, estimates context
 * tokens, finds the cut point, and partitions entries into "to summarize"
 * vs "turn prefix" (when the cut splits a turn mid-way).
 *
 * @returns undefined when compaction is not applicable: the last entry is
 *   already a compaction, or the first kept entry lacks an id (session
 *   needs migration).
 */
export function prepareCompaction(
  pathEntries: SessionEntry[],
  settings: CompactionSettings,
): CompactionPreparation | undefined {
  // The tail is already a compaction entry - nothing new to compact.
  if (pathEntries.length > 0 && pathEntries[pathEntries.length - 1].type === "compaction") {
    return undefined;
  }
  // Locate the most recent previous compaction (scan newest-to-oldest).
  let prevCompactionIndex = -1;
  for (let i = pathEntries.length - 1; i >= 0; i--) {
    if (pathEntries[i].type === "compaction") {
      prevCompactionIndex = i;
      break;
    }
  }
  // Entries eligible for this compaction start right after the previous one.
  const boundaryStart = prevCompactionIndex + 1;
  const boundaryEnd = pathEntries.length;
  // For token accounting, include the previous compaction entry itself: its
  // summary message (via getMessageFromEntry) is part of the live context.
  const usageStart = prevCompactionIndex >= 0 ? prevCompactionIndex : 0;
  const usageMessages: AgentMessage[] = [];
  for (let i = usageStart; i < boundaryEnd; i++) {
    const msg = getMessageFromEntry(pathEntries[i]);
    if (msg) usageMessages.push(msg);
  }
  const tokensBefore = estimateContextTokens(usageMessages).tokens;
  const cutPoint = findCutPoint(pathEntries, boundaryStart, boundaryEnd, settings.keepRecentTokens);
  // Get UUID of first kept entry
  const firstKeptEntry = pathEntries[cutPoint.firstKeptEntryIndex];
  if (!firstKeptEntry?.id) {
    return undefined; // Session needs migration
  }
  const firstKeptEntryId = firstKeptEntry.id;
  // When splitting a turn, the history summary stops at the turn start; the
  // split turn's prefix is collected separately below.
  const historyEnd = cutPoint.isSplitTurn ? cutPoint.turnStartIndex : cutPoint.firstKeptEntryIndex;
  // Messages to summarize (will be discarded after summary)
  const messagesToSummarize: AgentMessage[] = [];
  for (let i = boundaryStart; i < historyEnd; i++) {
    const msg = getMessageFromEntry(pathEntries[i]);
    if (msg) messagesToSummarize.push(msg);
  }
  // Messages for turn prefix summary (if splitting a turn)
  const turnPrefixMessages: AgentMessage[] = [];
  if (cutPoint.isSplitTurn) {
    for (let i = cutPoint.turnStartIndex; i < cutPoint.firstKeptEntryIndex; i++) {
      const msg = getMessageFromEntry(pathEntries[i]);
      if (msg) turnPrefixMessages.push(msg);
    }
  }
  // Get previous summary for iterative update
  let previousSummary: string | undefined;
  if (prevCompactionIndex >= 0) {
    const prevCompaction = pathEntries[prevCompactionIndex] as CompactionEntry;
    previousSummary = prevCompaction.summary;
  }
  // Extract file operations from messages and previous compaction
  const fileOps = extractFileOperations(messagesToSummarize, pathEntries, prevCompactionIndex);
  // Also extract file ops from turn prefix if splitting
  if (cutPoint.isSplitTurn) {
    for (const msg of turnPrefixMessages) {
      extractFileOpsFromMessage(msg, fileOps);
    }
  }
  return {
    firstKeptEntryId,
    messagesToSummarize,
    turnPrefixMessages,
    isSplitTurn: cutPoint.isSplitTurn,
    tokensBefore,
    previousSummary,
    fileOps,
    settings,
  };
}
// ============================================================================
// Main compaction function
// ============================================================================
/** Prompt for summarizing the discarded beginning of a turn that was cut mid-way. */
const TURN_PREFIX_SUMMARIZATION_PROMPT = `This is the PREFIX of a turn that was too large to keep. The SUFFIX (recent work) is retained.
Summarize the prefix to provide context for the retained suffix:
## Original Request
[What did the user ask for in this turn?]
## Early Progress
- [Key decisions and work done in the prefix]
## Context for Suffix
- [Information needed to understand the retained recent work]
Be concise. Focus on what's needed to understand the kept suffix.`;
/**
 * Generate summaries for compaction using prepared data.
 * Returns CompactionResult - SessionManager adds uuid/parentUuid when saving.
 *
 * @param preparation - Pre-calculated preparation from prepareCompaction()
 * @param model - Model used for the summarization calls
 * @param apiKey - API key for the model provider
 * @param customInstructions - Optional custom focus for the summary
 * @param signal - Optional abort signal propagated to the LLM calls
 * @throws Error when preparation lacks a first-kept-entry UUID, or a
 *   summarization call fails
 */
export async function compact(
  preparation: CompactionPreparation,
  model: Model<any>,
  apiKey: string,
  customInstructions?: string,
  signal?: AbortSignal,
): Promise<CompactionResult> {
  const {
    firstKeptEntryId,
    messagesToSummarize,
    turnPrefixMessages,
    isSplitTurn,
    tokensBefore,
    previousSummary,
    fileOps,
    settings,
  } = preparation;
  // FIX: validate before doing any LLM work. Previously this check ran only
  // after the (expensive, possibly parallel) summarization calls, wasting
  // them whenever the session still needed migration.
  if (!firstKeptEntryId) {
    throw new Error("First kept entry has no UUID - session may need migration");
  }
  // Generate summaries (in parallel when splitting a turn) and merge into one.
  let summary: string;
  if (isSplitTurn && turnPrefixMessages.length > 0) {
    const [historyResult, turnPrefixResult] = await Promise.all([
      messagesToSummarize.length > 0
        ? generateSummary(
            messagesToSummarize,
            model,
            settings.reserveTokens,
            apiKey,
            signal,
            customInstructions,
            previousSummary,
          )
        : Promise.resolve("No prior history."),
      generateTurnPrefixSummary(turnPrefixMessages, model, settings.reserveTokens, apiKey, signal),
    ]);
    // Merge into single summary
    summary = `${historyResult}\n\n---\n\n**Turn Context (split turn):**\n\n${turnPrefixResult}`;
  } else {
    // Just generate history summary
    summary = await generateSummary(
      messagesToSummarize,
      model,
      settings.reserveTokens,
      apiKey,
      signal,
      customInstructions,
      previousSummary,
    );
  }
  // Compute file lists and append to summary
  const { readFiles, modifiedFiles } = computeFileLists(fileOps);
  summary += formatFileOperations(readFiles, modifiedFiles);
  return {
    summary,
    firstKeptEntryId,
    tokensBefore,
    details: { readFiles, modifiedFiles } as CompactionDetails,
  };
}
/**
 * Generate a summary for a turn prefix (when splitting a turn).
 * Uses a smaller token budget than the main history summary.
 *
 * @throws Error when the completion ends with an error stop reason.
 */
async function generateTurnPrefixSummary(
  messages: AgentMessage[],
  model: Model<any>,
  reserveTokens: number,
  apiKey: string,
  signal?: AbortSignal,
): Promise<string> {
  // Half the reserve: the prefix summary is deliberately brief.
  const maxTokens = Math.floor(0.5 * reserveTokens);
  const conversationText = serializeConversation(convertToLlm(messages));
  const summarizationMessages = [
    {
      role: "user" as const,
      content: [
        {
          type: "text" as const,
          text: `<conversation>\n${conversationText}\n</conversation>\n\n${TURN_PREFIX_SUMMARIZATION_PROMPT}`,
        },
      ],
      timestamp: Date.now(),
    },
  ];
  const response = await completeSimple(
    model,
    { systemPrompt: SUMMARIZATION_SYSTEM_PROMPT, messages: summarizationMessages },
    { maxTokens, signal, apiKey },
  );
  if (response.stopReason === "error") {
    throw new Error(`Turn prefix summarization failed: ${response.errorMessage || "Unknown error"}`);
  }
  const textBlocks = response.content.filter((c): c is { type: "text"; text: string } => c.type === "text");
  return textBlocks.map((c) => c.text).join("\n");
}

View file

@ -0,0 +1,7 @@
/**
* Compaction and summarization utilities.
*/
export * from "./branch-summarization.js";
export * from "./compaction.js";
export * from "./utils.js";

View file

@ -0,0 +1,170 @@
/**
* Shared utilities for compaction and branch summarization.
*/
import type { AgentMessage } from "@gsd/pi-agent-core";
import type { Message } from "@gsd/pi-ai";
// ============================================================================
// File Operation Tracking
// ============================================================================
/** Sets of file paths touched by tool calls, grouped by operation. */
export interface FileOperations {
  /** Paths passed to the `read` tool. */
  read: Set<string>;
  /** Paths passed to the `write` tool. */
  written: Set<string>;
  /** Paths passed to the `edit` tool. */
  edited: Set<string>;
}
/** Create an empty FileOperations accumulator. */
export function createFileOps(): FileOperations {
  return { read: new Set(), written: new Set(), edited: new Set() };
}
/**
 * Extract file operations from tool calls in an assistant message.
 * Non-assistant messages, malformed blocks, and unknown tools are ignored.
 */
export function extractFileOpsFromMessage(message: AgentMessage, fileOps: FileOperations): void {
  if (message.role !== "assistant") return;
  if (!("content" in message) || !Array.isArray(message.content)) return;
  // Map tool name -> destination set (Map avoids prototype-key collisions).
  const targets = new Map<string, Set<string>>([
    ["read", fileOps.read],
    ["write", fileOps.written],
    ["edit", fileOps.edited],
  ]);
  for (const block of message.content) {
    if (typeof block !== "object" || block === null) continue;
    if (!("type" in block) || block.type !== "toolCall") continue;
    if (!("arguments" in block) || !("name" in block)) continue;
    const args = block.arguments as Record<string, unknown> | undefined;
    const path = typeof args?.path === "string" ? args.path : undefined;
    if (!path) continue;
    targets.get(block.name as string)?.add(path);
  }
}
/**
 * Compute final file lists from file operations.
 * Returns readFiles (files only read, not modified) and modifiedFiles.
 */
export function computeFileLists(fileOps: FileOperations): { readFiles: string[]; modifiedFiles: string[] } {
  const modified = new Set<string>([...fileOps.edited, ...fileOps.written]);
  return {
    // A file that was both read and modified counts only as modified.
    readFiles: [...fileOps.read].filter((f) => !modified.has(f)).sort(),
    modifiedFiles: [...modified].sort(),
  };
}
/**
 * Format file operations as XML tags for summary.
 * Returns an empty string when there is nothing to report.
 */
export function formatFileOperations(readFiles: string[], modifiedFiles: string[]): string {
  const tag = (name: string, files: string[]): string => `<${name}>\n${files.join("\n")}\n</${name}>`;
  const sections: string[] = [];
  if (readFiles.length > 0) sections.push(tag("read-files", readFiles));
  if (modifiedFiles.length > 0) sections.push(tag("modified-files", modifiedFiles));
  return sections.length > 0 ? `\n\n${sections.join("\n\n")}` : "";
}
// ============================================================================
// Message Serialization
// ============================================================================
/** Maximum characters for a tool result in serialized summaries. */
const TOOL_RESULT_MAX_CHARS = 2000;
/**
 * Truncate text to a maximum character length for summarization.
 * Keeps the beginning and appends a truncation marker.
 */
function truncateForSummary(text: string, maxChars: number): string {
  if (text.length <= maxChars) {
    return text;
  }
  const kept = text.slice(0, maxChars);
  const dropped = text.length - maxChars;
  return `${kept}\n\n[... ${dropped} more characters truncated]`;
}
/**
 * Serialize LLM messages to text for summarization.
 * This prevents the model from treating it as a conversation to continue.
 * Call convertToLlm() first to handle custom message types.
 *
 * Tool results are truncated to keep the summarization request within
 * reasonable token budgets. Full content is not needed for summarization.
 */
export function serializeConversation(messages: Message[]): string {
  // Join all text blocks of a content array into one string.
  const joinText = (blocks: Array<{ type: string; text?: string }>): string =>
    blocks
      .filter((c): c is { type: "text"; text: string } => c.type === "text")
      .map((c) => c.text)
      .join("");
  const parts: string[] = [];
  for (const msg of messages) {
    switch (msg.role) {
      case "user": {
        const content = typeof msg.content === "string" ? msg.content : joinText(msg.content);
        if (content) parts.push(`[User]: ${content}`);
        break;
      }
      case "assistant": {
        const textParts: string[] = [];
        const thinkingParts: string[] = [];
        const toolCalls: string[] = [];
        for (const block of msg.content) {
          if (block.type === "text") {
            textParts.push(block.text);
          } else if (block.type === "thinking") {
            thinkingParts.push(block.thinking);
          } else if (block.type === "toolCall") {
            const args = block.arguments as Record<string, unknown>;
            const argsStr = Object.entries(args)
              .map(([k, v]) => `${k}=${JSON.stringify(v)}`)
              .join(", ");
            toolCalls.push(`${block.name}(${argsStr})`);
          }
        }
        // Emission order matters: thinking, then text, then tool calls.
        if (thinkingParts.length > 0) parts.push(`[Assistant thinking]: ${thinkingParts.join("\n")}`);
        if (textParts.length > 0) parts.push(`[Assistant]: ${textParts.join("\n")}`);
        if (toolCalls.length > 0) parts.push(`[Assistant tool calls]: ${toolCalls.join("; ")}`);
        break;
      }
      case "toolResult": {
        const content = joinText(msg.content);
        if (content) {
          parts.push(`[Tool result]: ${truncateForSummary(content, TOOL_RESULT_MAX_CHARS)}`);
        }
        break;
      }
    }
  }
  return parts.join("\n\n");
}
// ============================================================================
// Summarization System Prompt
// ============================================================================
/** System prompt shared by all summarization requests (compaction + branch summaries). */
export const SUMMARIZATION_SYSTEM_PROMPT = `You are a context summarization assistant. Your task is to read a conversation between a user and an AI coding assistant, then produce a structured summary following the exact format specified.
Do NOT continue the conversation. Do NOT respond to any questions in the conversation. ONLY output the structured summary.`;

View file

@ -0,0 +1,3 @@
import type { ThinkingLevel } from "@gsd/pi-agent-core";
/** Default thinking level applied when none is configured for a session. */
export const DEFAULT_THINKING_LEVEL: ThinkingLevel = "medium";

View file

@ -0,0 +1,15 @@
/**
 * Two resources of the same type competed for the same name; the "winner"
 * shadowed the "loser".
 */
export interface ResourceCollision {
  /** Category of resource that collided. */
  resourceType: "extension" | "skill" | "prompt" | "theme";
  name: string; // skill name, command/tool/flag name, prompt name, theme name
  /** Path of the resource that takes effect. */
  winnerPath: string;
  /** Path of the resource that was shadowed. */
  loserPath: string;
  winnerSource?: string; // e.g., "npm:foo", "git:...", "local"
  loserSource?: string;
}
/** A diagnostic produced while loading resources (warning, error, or collision). */
export interface ResourceDiagnostic {
  type: "warning" | "error" | "collision";
  /** Human-readable description of the problem. */
  message: string;
  /** File path the diagnostic refers to, when known. */
  path?: string;
  // Presumably populated when type === "collision" — confirm at emit sites.
  collision?: ResourceCollision;
}

View file

@ -0,0 +1,33 @@
import { EventEmitter } from "node:events";
/** Minimal publish/subscribe interface over string-named channels. */
export interface EventBus {
  /** Publish data to every handler registered on the channel. */
  emit(channel: string, data: unknown): void;
  /** Register a handler; returns a function that unsubscribes it. */
  on(channel: string, handler: (data: unknown) => void): () => void;
}
/** EventBus plus the ability to drop every registered handler at once. */
export interface EventBusController extends EventBus {
  clear(): void;
}
/**
 * Create an in-process event bus backed by a Node EventEmitter.
 *
 * Each handler is wrapped so that a thrown error (synchronous or from a
 * rejected promise) is logged instead of propagating to the emitter. `on`
 * returns an unsubscribe function bound to the wrapped handler.
 */
export function createEventBus(): EventBusController {
  const emitter = new EventEmitter();

  const emit = (channel: string, data: unknown): void => {
    emitter.emit(channel, data);
  };

  const on = (channel: string, handler: (data: unknown) => void): (() => void) => {
    const guarded = async (data: unknown) => {
      try {
        await handler(data);
      } catch (err) {
        console.error(`Event handler error (${channel}):`, err);
      }
    };
    emitter.on(channel, guarded);
    return () => {
      emitter.off(channel, guarded);
    };
  };

  const clear = (): void => {
    emitter.removeAllListeners();
  };

  return { emit, on, clear };
}

View file

@ -0,0 +1,104 @@
/**
* Shared command execution utilities for extensions and custom tools.
*/
import { spawn } from "node:child_process";
/**
 * Options for executing shell commands.
 */
export interface ExecOptions {
  /** AbortSignal to cancel the command */
  signal?: AbortSignal;
  /** Timeout in milliseconds */
  timeout?: number;
  /** Working directory */
  // NOTE(review): execCommand takes cwd as a positional argument — confirm
  // whether this option is actually honored by callers/implementation.
  cwd?: string;
}
/**
 * Result of executing a shell command.
 */
export interface ExecResult {
  /** Captured standard output. */
  stdout: string;
  /** Captured standard error. */
  stderr: string;
  /** Process exit code (0 is substituted when the code was unavailable). */
  code: number;
  /** True when the process was killed due to timeout or abort. */
  killed: boolean;
}
/**
 * Execute a shell command and return stdout/stderr/code.
 * Supports timeout and abort signal.
 *
 * Never rejects: spawn failures resolve with code 1 and whatever output was
 * captured. On timeout/abort the process receives SIGTERM, escalating to
 * SIGKILL after 5 seconds if it has not exited.
 */
export async function execCommand(
  command: string,
  args: string[],
  cwd: string,
  options?: ExecOptions,
): Promise<ExecResult> {
  return new Promise((resolve) => {
    const proc = spawn(command, args, {
      // Honor an explicit options.cwd (previously declared but ignored),
      // falling back to the positional argument.
      cwd: options?.cwd ?? cwd,
      shell: false,
      stdio: ["ignore", "pipe", "pipe"],
    });
    let stdout = "";
    let stderr = "";
    let killed = false;
    let settled = false;
    let timeoutId: NodeJS.Timeout | undefined;
    let forceKillId: NodeJS.Timeout | undefined;
    const killProcess = () => {
      if (!killed) {
        killed = true;
        proc.kill("SIGTERM");
        // Escalate to SIGKILL after 5 seconds if SIGTERM didn't work.
        // Note: proc.killed only records that a signal was *sent*, so check
        // exitCode/signalCode to see whether the process actually exited.
        forceKillId = setTimeout(() => {
          if (proc.exitCode === null && proc.signalCode === null) {
            proc.kill("SIGKILL");
          }
        }, 5000);
        // Don't let the escalation timer keep the event loop alive.
        forceKillId.unref();
      }
    };
    // Handle abort signal
    if (options?.signal) {
      if (options.signal.aborted) {
        killProcess();
      } else {
        options.signal.addEventListener("abort", killProcess, { once: true });
      }
    }
    // Handle timeout
    if (options?.timeout && options.timeout > 0) {
      timeoutId = setTimeout(() => {
        killProcess();
      }, options.timeout);
    }
    proc.stdout?.on("data", (data) => {
      stdout += data.toString();
    });
    proc.stderr?.on("data", (data) => {
      stderr += data.toString();
    });
    // Shared cleanup + single resolution ("error" and "close" can both fire).
    const finish = (code: number) => {
      if (settled) return;
      settled = true;
      if (timeoutId) clearTimeout(timeoutId);
      if (forceKillId) clearTimeout(forceKillId);
      if (options?.signal) {
        options.signal.removeEventListener("abort", killProcess);
      }
      resolve({ stdout, stderr, code, killed });
    };
    proc.on("close", (code) => finish(code ?? 0));
    proc.on("error", () => finish(1));
  });
}

View file

@ -0,0 +1,258 @@
/**
* ANSI escape code to HTML converter.
*
* Converts terminal ANSI color/style codes to HTML with inline styles.
* Supports:
* - Standard foreground colors (30-37) and bright variants (90-97)
* - Standard background colors (40-47) and bright variants (100-107)
* - 256-color palette (38;5;N and 48;5;N)
* - RGB true color (38;2;R;G;B and 48;2;R;G;B)
* - Text styles: bold (1), dim (2), italic (3), underline (4)
* - Reset (0)
*/
// Standard ANSI color palette (0-15): the 8 normal colors followed by their
// bright variants, indexed directly by SGR color number offset.
const ANSI_COLORS = [
  "#000000", // 0: black
  "#800000", // 1: red
  "#008000", // 2: green
  "#808000", // 3: yellow
  "#000080", // 4: blue
  "#800080", // 5: magenta
  "#008080", // 6: cyan
  "#c0c0c0", // 7: white
  "#808080", // 8: bright black
  "#ff0000", // 9: bright red
  "#00ff00", // 10: bright green
  "#ffff00", // 11: bright yellow
  "#0000ff", // 12: bright blue
  "#ff00ff", // 13: bright magenta
  "#00ffff", // 14: bright cyan
  "#ffffff", // 15: bright white
];
/**
 * Convert a 256-color palette index to a hex color string.
 *
 * Index layout: 0-15 standard palette, 16-231 a 6x6x6 color cube,
 * 232-255 a 24-step grayscale ramp.
 */
function color256ToHex(index: number): string {
  if (index < 16) {
    // Standard palette entries map straight into ANSI_COLORS
    return ANSI_COLORS[index];
  }
  if (index < 232) {
    // 6x6x6 color cube: each channel level is 0 or 55 + 40 * level
    const cube = index - 16;
    const levels = [Math.floor(cube / 36), Math.floor((cube % 36) / 6), cube % 6];
    const hex = levels
      .map((level) => (level === 0 ? 0 : 55 + level * 40).toString(16).padStart(2, "0"))
      .join("");
    return `#${hex}`;
  }
  // Grayscale ramp: 8, 18, ..., 238
  const gray = (8 + (index - 232) * 10).toString(16).padStart(2, "0");
  return `#${gray}${gray}${gray}`;
}
/**
 * Escape the five HTML-significant characters so text renders literally.
 * Ampersand is replaced first so entity prefixes are not double-escaped.
 */
function escapeHtml(text: string): string {
  const replacements: Array<[RegExp, string]> = [
    [/&/g, "&amp;"],
    [/</g, "&lt;"],
    [/>/g, "&gt;"],
    [/"/g, "&quot;"],
    [/'/g, "&#039;"],
  ];
  let escaped = text;
  for (const [pattern, replacement] of replacements) {
    escaped = escaped.replace(pattern, replacement);
  }
  return escaped;
}
/** Mutable accumulator of the SGR text attributes currently in effect. */
interface TextStyle {
  /** Foreground color (CSS color string) or null for the terminal default. */
  fg: string | null;
  /** Background color (CSS color string) or null for the terminal default. */
  bg: string | null;
  bold: boolean;
  dim: boolean;
  italic: boolean;
  underline: boolean;
}
/** Return a TextStyle with every attribute reset to its default. */
function createEmptyStyle(): TextStyle {
  return { fg: null, bg: null, bold: false, dim: false, italic: false, underline: false };
}
/** Serialize a TextStyle into a semicolon-joined inline CSS declaration list. */
function styleToInlineCSS(style: TextStyle): string {
  const candidates: Array<[unknown, string]> = [
    [style.fg, `color:${style.fg}`],
    [style.bg, `background-color:${style.bg}`],
    [style.bold, "font-weight:bold"],
    [style.dim, "opacity:0.6"],
    [style.italic, "font-style:italic"],
    [style.underline, "text-decoration:underline"],
  ];
  return candidates
    .filter(([enabled]) => Boolean(enabled))
    .map(([, declaration]) => declaration)
    .join(";");
}
/** True when at least one attribute differs from the default (unstyled) state. */
function hasStyle(style: TextStyle): boolean {
  const { fg, bg, bold, dim, italic, underline } = style;
  return fg !== null || bg !== null || bold || dim || italic || underline;
}
/**
 * Parse ANSI SGR (Select Graphic Rendition) codes and update style.
 *
 * Mutates `style` in place. Unrecognized codes are skipped. Extended color
 * specs (38/48 followed by 5;N or 2;R;G;B) consume their extra parameters.
 */
function applySgrCode(params: number[], style: TextStyle): void {
  // Parse an extended color spec beginning at position i (code 38 or 48).
  // Returns [color, extra params consumed], or null when truncated/malformed.
  const parseExtendedColor = (i: number): [string, number] | null => {
    if (params[i + 1] === 5 && params.length > i + 2) {
      // 256-color: ...;5;N
      return [color256ToHex(params[i + 2]), 2];
    }
    if (params[i + 1] === 2 && params.length > i + 4) {
      // RGB: ...;2;R;G;B
      return [`rgb(${params[i + 2]},${params[i + 3]},${params[i + 4]})`, 4];
    }
    return null;
  };
  let i = 0;
  while (i < params.length) {
    const code = params[i];
    if (code === 0) {
      // Reset every attribute
      style.fg = null;
      style.bg = null;
      style.bold = false;
      style.dim = false;
      style.italic = false;
      style.underline = false;
    } else if (code === 1) {
      style.bold = true;
    } else if (code === 2) {
      style.dim = true;
    } else if (code === 3) {
      style.italic = true;
    } else if (code === 4) {
      style.underline = true;
    } else if (code === 22) {
      // Reset bold/dim
      style.bold = false;
      style.dim = false;
    } else if (code === 23) {
      style.italic = false;
    } else if (code === 24) {
      style.underline = false;
    } else if (code >= 30 && code <= 37) {
      // Standard foreground colors
      style.fg = ANSI_COLORS[code - 30];
    } else if (code === 38) {
      const parsed = parseExtendedColor(i);
      if (parsed) {
        style.fg = parsed[0];
        i += parsed[1];
      }
    } else if (code === 39) {
      // Default foreground
      style.fg = null;
    } else if (code >= 40 && code <= 47) {
      // Standard background colors
      style.bg = ANSI_COLORS[code - 40];
    } else if (code === 48) {
      const parsed = parseExtendedColor(i);
      if (parsed) {
        style.bg = parsed[0];
        i += parsed[1];
      }
    } else if (code === 49) {
      // Default background
      style.bg = null;
    } else if (code >= 90 && code <= 97) {
      // Bright foreground colors
      style.fg = ANSI_COLORS[code - 90 + 8];
    } else if (code >= 100 && code <= 107) {
      // Bright background colors
      style.bg = ANSI_COLORS[code - 100 + 8];
    }
    // Unrecognized codes fall through and are ignored
    i++;
  }
}
// Match ANSI escape sequences: ESC[ followed by params and ending with 'm'
// (SGR sequences only; cursor-movement and other CSI sequences are not matched).
// The 'g' flag makes lastIndex shared mutable state; ansiToHtml resets it
// before each scan.
const ANSI_REGEX = /\x1b\[([\d;]*)m/g;
/**
 * Convert ANSI-escaped text to HTML with inline styles.
 *
 * Walks the SGR escape sequences in order, tracking the active style and
 * wrapping each styled run of text in a <span> with inline CSS. Literal text
 * is HTML-escaped; unstyled runs are emitted without a wrapper.
 */
export function ansiToHtml(text: string): string {
  const style = createEmptyStyle();
  const chunks: string[] = [];
  let cursor = 0;
  let spanOpen = false;
  // Reset shared regex state before scanning
  ANSI_REGEX.lastIndex = 0;
  for (const match of text.matchAll(ANSI_REGEX)) {
    const index = match.index ?? 0;
    // Emit the literal text preceding this escape sequence
    if (index > cursor) {
      chunks.push(escapeHtml(text.slice(cursor, index)));
    }
    // An escape sequence always terminates the current styled run
    if (spanOpen) {
      chunks.push("</span>");
      spanOpen = false;
    }
    // Parse SGR parameters (an empty parameter list means reset)
    const raw = match[1];
    const codes = raw ? raw.split(";").map((p) => parseInt(p, 10) || 0) : [0];
    applySgrCode(codes, style);
    // Open a new span only when some styling is active
    if (hasStyle(style)) {
      chunks.push(`<span style="${styleToInlineCSS(style)}">`);
      spanOpen = true;
    }
    cursor = index + match[0].length;
  }
  // Emit any trailing text after the final escape sequence
  if (cursor < text.length) {
    chunks.push(escapeHtml(text.slice(cursor)));
  }
  if (spanOpen) {
    chunks.push("</span>");
  }
  return chunks.join("");
}
/**
 * Convert array of ANSI-escaped lines to HTML, one <div class="ansi-line">
 * per input line; empty lines become a non-breaking space so they keep height.
 */
export function ansiLinesToHtml(lines: string[]): string {
  const rendered: string[] = [];
  for (const line of lines) {
    const html = ansiToHtml(line);
    rendered.push(`<div class="ansi-line">${html || "&nbsp;"}</div>`);
  }
  return rendered.join("\n");
}

View file

@ -0,0 +1,306 @@
import type { AgentState } from "@gsd/pi-agent-core";
import { existsSync, readFileSync, writeFileSync } from "fs";
import { basename, join } from "path";
import { APP_NAME, getExportTemplateDir } from "../../config.js";
import { getResolvedThemeColors, getThemeExportColors } from "../../modes/interactive/theme/theme.js";
import type { ToolInfo } from "../extensions/types.js";
import type { SessionEntry } from "../session-manager.js";
import { SessionManager } from "../session-manager.js";
/**
 * Interface for rendering custom tools to HTML.
 * Used by agent-session to pre-render extension tool output.
 */
export interface ToolHtmlRenderer {
  /** Render a tool call to HTML. Returns undefined if tool has no custom renderer. */
  renderCall(toolName: string, args: unknown): string | undefined;
  /** Render a tool result to HTML. Returns collapsed/expanded or undefined if tool has no custom renderer. */
  renderResult(
    toolName: string,
    result: Array<{ type: string; text?: string; data?: string; mimeType?: string }>,
    details: unknown,
    isError: boolean,
  ): { collapsed?: string; expanded?: string } | undefined;
}
/** Pre-rendered HTML for a custom tool call and result */
interface RenderedToolHtml {
  /** HTML for the tool invocation itself. */
  callHtml?: string;
  /** HTML for the result in its collapsed (preview) form. */
  resultHtmlCollapsed?: string;
  /** HTML for the result in its fully expanded form. */
  resultHtmlExpanded?: string;
}
export interface ExportOptions {
  /** Destination file path; a name derived from the session is used if omitted. */
  outputPath?: string;
  /** Theme to resolve colors from; falls back to the default theme if omitted. */
  themeName?: string;
  /** Optional tool renderer for custom tools */
  toolRenderer?: ToolHtmlRenderer;
}
/** Parse a color string to RGB values. Supports hex (#RRGGBB) and rgb(r,g,b) formats. */
function parseColor(color: string): { r: number; g: number; b: number } | undefined {
  const hex = /^#([0-9a-fA-F]{2})([0-9a-fA-F]{2})([0-9a-fA-F]{2})$/.exec(color);
  if (hex) {
    const [r, g, b] = hex.slice(1).map((part) => Number.parseInt(part, 16));
    return { r, g, b };
  }
  const rgb = /^rgb\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)$/.exec(color);
  if (rgb) {
    const [r, g, b] = rgb.slice(1).map((part) => Number.parseInt(part, 10));
    return { r, g, b };
  }
  // Anything else (named colors, shorthand hex, hsl, ...) is unsupported
  return undefined;
}
/** Calculate relative luminance of a color (0-1, higher = lighter). */
function getLuminance(r: number, g: number, b: number): number {
  // sRGB channel value -> linear-light value (WCAG gamma formula)
  const linear = (channel: number): number => {
    const normalized = channel / 255;
    if (normalized <= 0.03928) {
      return normalized / 12.92;
    }
    return ((normalized + 0.055) / 1.055) ** 2.4;
  };
  // Weighted sum using the Rec. 709 luma coefficients
  const weights: Array<[number, number]> = [
    [0.2126, r],
    [0.7152, g],
    [0.0722, b],
  ];
  return weights.reduce((sum, [weight, channel]) => sum + weight * linear(channel), 0);
}
/** Adjust color brightness. Factor > 1 lightens, < 1 darkens. */
function adjustBrightness(color: string, factor: number): string {
  const rgb = parseColor(color);
  if (rgb === undefined) {
    // Unparseable input is passed through untouched
    return color;
  }
  const scale = (channel: number): number => {
    const value = Math.round(channel * factor);
    return Math.min(255, Math.max(0, value));
  };
  return `rgb(${scale(rgb.r)}, ${scale(rgb.g)}, ${scale(rgb.b)})`;
}
/** Derive export background colors from a base color (e.g., userMessageBg). */
function deriveExportColors(baseColor: string): { pageBg: string; cardBg: string; infoBg: string } {
const parsed = parseColor(baseColor);
if (!parsed) {
return {
pageBg: "rgb(24, 24, 30)",
cardBg: "rgb(30, 30, 36)",
infoBg: "rgb(60, 55, 40)",
};
}
const luminance = getLuminance(parsed.r, parsed.g, parsed.b);
const isLight = luminance > 0.5;
if (isLight) {
return {
pageBg: adjustBrightness(baseColor, 0.96),
cardBg: baseColor,
infoBg: `rgb(${Math.min(255, parsed.r + 10)}, ${Math.min(255, parsed.g + 5)}, ${Math.max(0, parsed.b - 20)})`,
};
}
return {
pageBg: adjustBrightness(baseColor, 0.7),
cardBg: adjustBrightness(baseColor, 0.85),
infoBg: `rgb(${Math.min(255, parsed.r + 20)}, ${Math.min(255, parsed.g + 15)}, ${parsed.b})`,
};
}
/**
* Generate CSS custom property declarations from theme colors.
*/
function generateThemeVars(themeName?: string): string {
const colors = getResolvedThemeColors(themeName);
const lines: string[] = [];
for (const [key, value] of Object.entries(colors)) {
lines.push(`--${key}: ${value};`);
}
// Use explicit theme export colors if available, otherwise derive from userMessageBg
const themeExport = getThemeExportColors(themeName);
const userMessageBg = colors.userMessageBg || "#343541";
const derivedColors = deriveExportColors(userMessageBg);
lines.push(`--exportPageBg: ${themeExport.pageBg ?? derivedColors.pageBg};`);
lines.push(`--exportCardBg: ${themeExport.cardBg ?? derivedColors.cardBg};`);
lines.push(`--exportInfoBg: ${themeExport.infoBg ?? derivedColors.infoBg};`);
return lines.join("\n ");
}
/** Everything embedded into the exported HTML page as base64 JSON. */
interface SessionData {
  header: ReturnType<SessionManager["getHeader"]>;
  entries: ReturnType<SessionManager["getEntries"]>;
  /** Id of the current leaf entry, or null. */
  leafId: string | null;
  /** System prompt, when exporting from a live AgentState. */
  systemPrompt?: string;
  /** Tool metadata, when exporting from a live AgentState. */
  tools?: ToolInfo[];
  /** Pre-rendered HTML for custom tool calls/results, keyed by tool call ID */
  renderedTools?: Record<string, RenderedToolHtml>;
}
/**
 * Core HTML generation logic shared by both export functions.
 *
 * Reads the HTML/CSS/JS template files plus vendored libraries from the
 * export template directory, injects theme variables and the base64-encoded
 * session data, and returns the final self-contained HTML document.
 */
function generateHtml(sessionData: SessionData, themeName?: string): string {
  const templateDir = getExportTemplateDir();
  const template = readFileSync(join(templateDir, "template.html"), "utf-8");
  const templateCss = readFileSync(join(templateDir, "template.css"), "utf-8");
  const templateJs = readFileSync(join(templateDir, "template.js"), "utf-8");
  const markedJs = readFileSync(join(templateDir, "vendor", "marked.min.js"), "utf-8");
  const hljsJs = readFileSync(join(templateDir, "vendor", "highlight.min.js"), "utf-8");
  const themeVars = generateThemeVars(themeName);
  const colors = getResolvedThemeColors(themeName);
  const exportColors = deriveExportColors(colors.userMessageBg || "#343541");
  // Base64 encode session data to avoid escaping issues
  const sessionDataBase64 = Buffer.from(JSON.stringify(sessionData)).toString("base64");
  // Substitute a placeholder with literal content. A function replacement is
  // used so that "$&", "$'", "$$" etc. inside the content (common in minified
  // vendor JS) are NOT interpreted as special replacement patterns, which
  // would silently corrupt the output.
  const fill = (haystack: string, placeholder: string, content: string): string =>
    haystack.replace(placeholder, () => content);
  // Build the CSS with theme variables injected
  let css = fill(templateCss, "{{THEME_VARS}}", themeVars);
  css = fill(css, "{{BODY_BG}}", exportColors.pageBg);
  css = fill(css, "{{CONTAINER_BG}}", exportColors.cardBg);
  css = fill(css, "{{INFO_BG}}", exportColors.infoBg);
  let html = fill(template, "{{CSS}}", css);
  html = fill(html, "{{JS}}", templateJs);
  html = fill(html, "{{SESSION_DATA}}", sessionDataBase64);
  html = fill(html, "{{MARKED_JS}}", markedJs);
  html = fill(html, "{{HIGHLIGHT_JS}}", hljsJs);
  return html;
}
/** Built-in tool names that have custom rendering in template.js */
// Tools in this set are skipped by preRenderCustomTools; the HTML template
// renders them itself.
const BUILTIN_TOOLS = new Set(["bash", "read", "write", "edit", "ls", "find", "grep"]);
/**
 * Pre-render custom tools to HTML using their TUI renderers.
 *
 * Walks the session entries: tool calls found in assistant messages are
 * rendered for non-built-in tools, and tool results are rendered when a call
 * was pre-rendered or the tool is not built in. Keyed by tool call id.
 */
function preRenderCustomTools(
  entries: SessionEntry[],
  toolRenderer: ToolHtmlRenderer,
): Record<string, RenderedToolHtml> {
  const rendered: Record<string, RenderedToolHtml> = {};
  for (const entry of entries) {
    if (entry.type !== "message") continue;
    const message = entry.message;
    if (message.role === "assistant" && Array.isArray(message.content)) {
      // Pre-render calls for non-built-in tools
      for (const block of message.content) {
        if (block.type !== "toolCall" || BUILTIN_TOOLS.has(block.name)) continue;
        const callHtml = toolRenderer.renderCall(block.name, block.arguments);
        if (callHtml) {
          rendered[block.id] = { callHtml };
        }
      }
      continue;
    }
    if (message.role === "toolResult" && message.toolCallId) {
      const toolName = message.toolName || "";
      const existing = rendered[message.toolCallId];
      // Only render if we have a pre-rendered call OR it's not a built-in tool
      if (!existing && BUILTIN_TOOLS.has(toolName)) continue;
      const result = toolRenderer.renderResult(toolName, message.content, message.details, message.isError || false);
      if (result) {
        rendered[message.toolCallId] = {
          ...existing,
          resultHtmlCollapsed: result.collapsed,
          resultHtmlExpanded: result.expanded,
        };
      }
    }
  }
  return rendered;
}
/**
 * Export session to HTML using SessionManager and AgentState.
 * Used by TUI's /export command.
 *
 * @throws Error when the session is in-memory only or has no file on disk yet.
 */
export async function exportSessionToHtml(
  sm: SessionManager,
  state?: AgentState,
  options?: ExportOptions | string,
): Promise<string> {
  // Accept either an options object or a bare output path (legacy form)
  const opts: ExportOptions = typeof options === "string" ? { outputPath: options } : options || {};
  const sessionFile = sm.getSessionFile();
  if (!sessionFile) {
    throw new Error("Cannot export in-memory session to HTML");
  }
  if (!existsSync(sessionFile)) {
    throw new Error("Nothing to export yet - start a conversation first");
  }
  const entries = sm.getEntries();
  // Pre-render custom tools if a tool renderer is provided; omit the map
  // entirely when nothing was rendered.
  let renderedTools: Record<string, RenderedToolHtml> | undefined;
  if (opts.toolRenderer) {
    const prerendered = preRenderCustomTools(entries, opts.toolRenderer);
    if (Object.keys(prerendered).length > 0) {
      renderedTools = prerendered;
    }
  }
  const sessionData: SessionData = {
    header: sm.getHeader(),
    entries,
    leafId: sm.getLeafId(),
    systemPrompt: state?.systemPrompt,
    tools: state?.tools?.map((t) => ({ name: t.name, description: t.description, parameters: t.parameters })),
    renderedTools,
  };
  const html = generateHtml(sessionData, opts.themeName);
  const fallbackName = `${APP_NAME}-session-${basename(sessionFile, ".jsonl")}.html`;
  const outputPath = opts.outputPath || fallbackName;
  writeFileSync(outputPath, html, "utf8");
  return outputPath;
}
/**
 * Export session file to HTML (standalone, without AgentState).
 * Used by CLI for exporting arbitrary session files.
 *
 * @throws Error when the input session file does not exist.
 */
export async function exportFromFile(inputPath: string, options?: ExportOptions | string): Promise<string> {
  // Accept either an options object or a bare output path (legacy form)
  const opts: ExportOptions = typeof options === "string" ? { outputPath: options } : options || {};
  if (!existsSync(inputPath)) {
    throw new Error(`File not found: ${inputPath}`);
  }
  const sm = SessionManager.open(inputPath);
  // No live AgentState here, so the system prompt and tool list are unknown
  const sessionData: SessionData = {
    header: sm.getHeader(),
    entries: sm.getEntries(),
    leafId: sm.getLeafId(),
    systemPrompt: undefined,
    tools: undefined,
  };
  const html = generateHtml(sessionData, opts.themeName);
  const fallbackName = `${APP_NAME}-session-${basename(inputPath, ".jsonl")}.html`;
  const outputPath = opts.outputPath || fallbackName;
  writeFileSync(outputPath, html, "utf8");
  return outputPath;
}

View file

@ -0,0 +1,971 @@
:root {
{{THEME_VARS}}
--body-bg: {{BODY_BG}};
--container-bg: {{CONTAINER_BG}};
--info-bg: {{INFO_BG}};
}
* { margin: 0; padding: 0; box-sizing: border-box; }
:root {
--line-height: 18px; /* 12px font * 1.5 */
}
body {
font-family: ui-monospace, 'Cascadia Code', 'Source Code Pro', Menlo, Consolas, 'DejaVu Sans Mono', monospace;
font-size: 12px;
line-height: var(--line-height);
color: var(--text);
background: var(--body-bg);
}
#app {
display: flex;
min-height: 100vh;
}
/* Sidebar */
#sidebar {
width: 400px;
background: var(--container-bg);
flex-shrink: 0;
display: flex;
flex-direction: column;
position: sticky;
top: 0;
height: 100vh;
border-right: 1px solid var(--dim);
}
.sidebar-header {
padding: 8px 12px;
flex-shrink: 0;
}
.sidebar-controls {
padding: 8px 8px 4px 8px;
}
.sidebar-search {
width: 100%;
box-sizing: border-box;
padding: 4px 8px;
font-size: 11px;
font-family: inherit;
background: var(--body-bg);
color: var(--text);
border: 1px solid var(--dim);
border-radius: 3px;
}
.sidebar-filters {
display: flex;
padding: 4px 8px 8px 8px;
gap: 4px;
align-items: center;
flex-wrap: wrap;
}
.sidebar-search:focus {
outline: none;
border-color: var(--accent);
}
.sidebar-search::placeholder {
color: var(--muted);
}
.filter-btn {
padding: 3px 8px;
font-size: 10px;
font-family: inherit;
background: transparent;
color: var(--muted);
border: 1px solid var(--dim);
border-radius: 3px;
cursor: pointer;
}
.filter-btn:hover {
color: var(--text);
border-color: var(--text);
}
.filter-btn.active {
background: var(--accent);
color: var(--body-bg);
border-color: var(--accent);
}
.sidebar-close {
display: none;
padding: 3px 8px;
font-size: 12px;
font-family: inherit;
background: transparent;
color: var(--muted);
border: 1px solid var(--dim);
border-radius: 3px;
cursor: pointer;
margin-left: auto;
}
.sidebar-close:hover {
color: var(--text);
border-color: var(--text);
}
.tree-container {
flex: 1;
overflow: auto;
padding: 4px 0;
}
.tree-node {
padding: 0 8px;
cursor: pointer;
display: flex;
align-items: baseline;
font-size: 11px;
line-height: 13px;
white-space: nowrap;
}
.tree-node:hover {
background: var(--selectedBg);
}
.tree-node.active {
background: var(--selectedBg);
}
.tree-node.active .tree-content {
font-weight: bold;
}
.tree-node.in-path {
background: color-mix(in srgb, var(--accent) 10%, transparent);
}
.tree-node:not(.in-path) {
opacity: 0.5;
}
.tree-node:not(.in-path):hover {
opacity: 1;
}
.tree-prefix {
color: var(--muted);
flex-shrink: 0;
font-family: monospace;
white-space: pre;
}
.tree-marker {
color: var(--accent);
flex-shrink: 0;
}
.tree-content {
color: var(--text);
}
.tree-role-user {
color: var(--accent);
}
.tree-role-assistant {
color: var(--success);
}
.tree-role-tool {
color: var(--muted);
}
.tree-muted {
color: var(--muted);
}
.tree-error {
color: var(--error);
}
.tree-compaction {
color: var(--borderAccent);
}
.tree-branch-summary {
color: var(--warning);
}
.tree-custom-message {
color: var(--customMessageLabel);
}
.tree-status {
padding: 4px 12px;
font-size: 10px;
color: var(--muted);
flex-shrink: 0;
}
/* Main content */
#content {
flex: 1;
overflow-y: auto;
padding: var(--line-height) calc(var(--line-height) * 2);
display: flex;
flex-direction: column;
align-items: center;
}
#content > * {
width: 100%;
max-width: 800px;
}
/* Help bar */
.help-bar {
font-size: 11px;
color: var(--warning);
margin-bottom: var(--line-height);
display: flex;
align-items: center;
gap: 12px;
}
.download-json-btn {
font-size: 10px;
padding: 2px 8px;
background: var(--container-bg);
border: 1px solid var(--border);
border-radius: 3px;
color: var(--text);
cursor: pointer;
font-family: inherit;
}
.download-json-btn:hover {
background: var(--hover);
border-color: var(--borderAccent);
}
/* Header */
.header {
background: var(--container-bg);
border-radius: 4px;
padding: var(--line-height);
margin-bottom: var(--line-height);
}
.header h1 {
font-size: 12px;
font-weight: bold;
color: var(--borderAccent);
margin-bottom: var(--line-height);
}
.header-info {
display: flex;
flex-direction: column;
gap: 0;
font-size: 11px;
}
.info-item {
color: var(--dim);
display: flex;
align-items: baseline;
}
.info-label {
font-weight: 600;
margin-right: 8px;
min-width: 100px;
}
.info-value {
color: var(--text);
flex: 1;
}
/* Messages */
#messages {
display: flex;
flex-direction: column;
gap: var(--line-height);
}
.message-timestamp {
font-size: 10px;
color: var(--dim);
opacity: 0.8;
}
.user-message {
background: var(--userMessageBg);
color: var(--userMessageText);
padding: var(--line-height);
border-radius: 4px;
position: relative;
}
.assistant-message {
padding: 0;
position: relative;
}
/* Copy link button - appears on hover */
.copy-link-btn {
position: absolute;
top: 8px;
right: 8px;
width: 28px;
height: 28px;
padding: 6px;
background: var(--container-bg);
border: 1px solid var(--dim);
border-radius: 4px;
color: var(--muted);
cursor: pointer;
opacity: 0;
transition: opacity 0.15s, background 0.15s, color 0.15s;
display: flex;
align-items: center;
justify-content: center;
z-index: 10;
}
.user-message:hover .copy-link-btn,
.assistant-message:hover .copy-link-btn {
opacity: 1;
}
.copy-link-btn:hover {
background: var(--accent);
color: var(--body-bg);
border-color: var(--accent);
}
.copy-link-btn.copied {
background: var(--success, #22c55e);
color: white;
border-color: var(--success, #22c55e);
}
/* Highlight effect for deep-linked messages */
.user-message.highlight,
.assistant-message.highlight {
animation: highlight-pulse 2s ease-out;
}
@keyframes highlight-pulse {
0% {
box-shadow: 0 0 0 3px var(--accent);
}
100% {
box-shadow: 0 0 0 0 transparent;
}
}
.assistant-message > .message-timestamp {
padding-left: var(--line-height);
}
.assistant-text {
padding: var(--line-height);
padding-bottom: 0;
}
.message-timestamp + .assistant-text,
.message-timestamp + .thinking-block {
padding-top: 0;
}
.thinking-block + .assistant-text {
padding-top: 0;
}
.thinking-text {
padding: var(--line-height);
color: var(--thinkingText);
font-style: italic;
white-space: pre-wrap;
}
.message-timestamp + .thinking-block .thinking-text,
.message-timestamp + .thinking-block .thinking-collapsed {
padding-top: 0;
}
.thinking-collapsed {
display: none;
padding: var(--line-height);
color: var(--thinkingText);
font-style: italic;
}
/* Tool execution */
.tool-execution {
padding: var(--line-height);
border-radius: 4px;
}
.tool-execution + .tool-execution {
margin-top: var(--line-height);
}
.assistant-text + .tool-execution {
margin-top: var(--line-height);
}
.tool-execution.pending { background: var(--toolPendingBg); }
.tool-execution.success { background: var(--toolSuccessBg); }
.tool-execution.error { background: var(--toolErrorBg); }
.tool-header, .tool-name {
font-weight: bold;
}
.tool-path {
color: var(--accent);
word-break: break-all;
}
.line-numbers {
color: var(--warning);
}
.line-count {
color: var(--dim);
}
.tool-command {
font-weight: bold;
white-space: pre-wrap;
word-wrap: break-word;
overflow-wrap: break-word;
word-break: break-word;
}
.tool-output {
margin-top: var(--line-height);
color: var(--toolOutput);
word-wrap: break-word;
overflow-wrap: break-word;
word-break: break-word;
font-family: inherit;
overflow-x: auto;
}
.tool-output > div,
.output-preview,
.output-full {
margin: 0;
padding: 0;
line-height: var(--line-height);
}
.tool-output pre {
margin: 0;
padding: 0;
font-family: inherit;
color: inherit;
white-space: pre-wrap;
word-wrap: break-word;
overflow-wrap: break-word;
}
.tool-output code {
padding: 0;
background: none;
color: var(--text);
}
.tool-output.expandable {
cursor: pointer;
}
.tool-output.expandable:hover {
opacity: 0.9;
}
.tool-output.expandable .output-full {
display: none;
}
.tool-output.expandable.expanded .output-preview {
display: none;
}
.tool-output.expandable.expanded .output-full {
display: block;
}
.ansi-line {
white-space: pre-wrap;
}
.tool-images {
}
.tool-image {
max-width: 100%;
max-height: 500px;
border-radius: 4px;
margin: var(--line-height) 0;
}
.expand-hint {
color: var(--toolOutput);
}
/* Diff */
.tool-diff {
font-size: 11px;
overflow-x: auto;
white-space: pre;
}
.diff-added { color: var(--toolDiffAdded); }
.diff-removed { color: var(--toolDiffRemoved); }
.diff-context { color: var(--toolDiffContext); }
/* Model change */
.model-change {
padding: 0 var(--line-height);
color: var(--dim);
font-size: 11px;
}
.model-name {
color: var(--borderAccent);
font-weight: bold;
}
/* Compaction / Branch Summary - matches customMessage colors from TUI */
.compaction {
background: var(--customMessageBg);
border-radius: 4px;
padding: var(--line-height);
cursor: pointer;
}
.compaction-label {
color: var(--customMessageLabel);
font-weight: bold;
}
.compaction-collapsed {
color: var(--customMessageText);
}
.compaction-content {
display: none;
color: var(--customMessageText);
white-space: pre-wrap;
margin-top: var(--line-height);
}
.compaction.expanded .compaction-collapsed {
display: none;
}
.compaction.expanded .compaction-content {
display: block;
}
/* System prompt */
.system-prompt {
background: var(--customMessageBg);
padding: var(--line-height);
border-radius: 4px;
margin-bottom: var(--line-height);
}
.system-prompt.expandable {
cursor: pointer;
}
.system-prompt-header {
font-weight: bold;
color: var(--customMessageLabel);
}
.system-prompt-preview {
color: var(--customMessageText);
white-space: pre-wrap;
word-wrap: break-word;
font-size: 11px;
margin-top: var(--line-height);
}
.system-prompt-expand-hint {
color: var(--muted);
font-style: italic;
margin-top: 4px;
}
.system-prompt-full {
display: none;
color: var(--customMessageText);
white-space: pre-wrap;
word-wrap: break-word;
font-size: 11px;
margin-top: var(--line-height);
}
.system-prompt.expanded .system-prompt-preview,
.system-prompt.expanded .system-prompt-expand-hint {
display: none;
}
.system-prompt.expanded .system-prompt-full {
display: block;
}
.system-prompt.provider-prompt {
border-left: 3px solid var(--warning);
}
.system-prompt-note {
font-size: 10px;
font-style: italic;
color: var(--muted);
margin-top: 4px;
}
/* Tools list */
.tools-list {
background: var(--customMessageBg);
padding: var(--line-height);
border-radius: 4px;
margin-bottom: var(--line-height);
}
.tools-header {
font-weight: bold;
color: var(--customMessageLabel);
margin-bottom: var(--line-height);
}
.tool-item {
font-size: 11px;
}
.tool-item-name {
font-weight: bold;
color: var(--text);
}
.tool-item-desc {
color: var(--dim);
}
.tool-params-hint {
color: var(--muted);
font-style: italic;
}
.tool-item:has(.tool-params-hint) {
cursor: pointer;
}
.tool-params-hint::after {
content: '[click to show parameters]';
}
.tool-item.params-expanded .tool-params-hint::after {
content: '[hide parameters]';
}
.tool-params-content {
display: none;
margin-top: 4px;
margin-left: 12px;
padding-left: 8px;
border-left: 1px solid var(--dim);
}
.tool-item.params-expanded .tool-params-content {
display: block;
}
.tool-param {
margin-bottom: 4px;
font-size: 11px;
}
.tool-param-name {
font-weight: bold;
color: var(--text);
}
.tool-param-type {
color: var(--dim);
font-style: italic;
}
.tool-param-required {
color: var(--warning, #e8a838);
font-size: 10px;
}
.tool-param-optional {
color: var(--dim);
font-size: 10px;
}
.tool-param-desc {
color: var(--dim);
margin-left: 8px;
}
/* Hook/custom messages */
.hook-message {
background: var(--customMessageBg);
color: var(--customMessageText);
padding: var(--line-height);
border-radius: 4px;
}
.hook-type {
color: var(--customMessageLabel);
font-weight: bold;
}
/* Branch summary */
.branch-summary {
background: var(--customMessageBg);
padding: var(--line-height);
border-radius: 4px;
}
.branch-summary-header {
font-weight: bold;
color: var(--borderAccent);
}
/* Error */
.error-text {
color: var(--error);
padding: 0 var(--line-height);
}
.tool-error {
color: var(--error);
}
/* Images */
.message-images {
margin-bottom: 12px;
}
.message-image {
max-width: 100%;
max-height: 400px;
border-radius: 4px;
margin: var(--line-height) 0;
}
/* Markdown content */
.markdown-content h1,
.markdown-content h2,
.markdown-content h3,
.markdown-content h4,
.markdown-content h5,
.markdown-content h6 {
color: var(--mdHeading);
margin: var(--line-height) 0 0 0;
font-weight: bold;
}
.markdown-content h1 { font-size: 1em; }
.markdown-content h2 { font-size: 1em; }
.markdown-content h3 { font-size: 1em; }
.markdown-content h4 { font-size: 1em; }
.markdown-content h5 { font-size: 1em; }
.markdown-content h6 { font-size: 1em; }
.markdown-content p { margin: 0; }
.markdown-content p + p { margin-top: var(--line-height); }
.markdown-content a {
color: var(--mdLink);
text-decoration: underline;
}
.markdown-content code {
background: rgba(128, 128, 128, 0.2);
color: var(--mdCode);
padding: 0 4px;
border-radius: 3px;
font-family: inherit;
}
.markdown-content pre {
background: transparent;
margin: var(--line-height) 0;
overflow-x: auto;
}
.markdown-content pre code {
display: block;
background: none;
color: var(--text);
}
/* ── Markdown content: block elements ── */
/* Blockquote: colored accent bar, muted italic text */
.markdown-content blockquote {
  border-left: 3px solid var(--mdQuoteBorder);
  padding-left: var(--line-height);
  margin: var(--line-height) 0;
  color: var(--mdQuote);
  font-style: italic;
}
/* Lists: indent by two line-heights to keep markers on the text grid */
.markdown-content ul,
.markdown-content ol {
  margin: var(--line-height) 0;
  padding-left: calc(var(--line-height) * 2);
}
.markdown-content li { margin: 0; }
.markdown-content li::marker { color: var(--mdListBullet); }
/* Horizontal rule: single themed top border */
.markdown-content hr {
  border: none;
  border-top: 1px solid var(--mdHr);
  margin: var(--line-height) 0;
}
/* Tables: collapsed borders, full available width */
.markdown-content table {
  border-collapse: collapse;
  margin: 0.5em 0;
  width: 100%;
}
.markdown-content th,
.markdown-content td {
  border: 1px solid var(--mdCodeBlockBorder);
  padding: 6px 10px;
  text-align: left;
}
.markdown-content th {
  /* Neutral translucent tint so the header works on light and dark themes */
  background: rgba(128, 128, 128, 0.1);
  font-weight: bold;
}
/* Images: never overflow the content column */
.markdown-content img {
  max-width: 100%;
  border-radius: 4px;
}
/* Syntax highlighting */
/* Maps highlight.js token classes onto the theme's syntax color variables */
.hljs { background: transparent; color: var(--text); }
.hljs-comment, .hljs-quote { color: var(--syntaxComment); }
.hljs-keyword, .hljs-selector-tag { color: var(--syntaxKeyword); }
.hljs-number, .hljs-literal { color: var(--syntaxNumber); }
.hljs-string, .hljs-doctag { color: var(--syntaxString); }
/* Function names: hljs v11 uses .hljs-title.function_ compound class */
.hljs-function, .hljs-title, .hljs-title.function_, .hljs-section, .hljs-name { color: var(--syntaxFunction); }
/* Types: hljs v11 uses .hljs-title.class_ for class names */
.hljs-type, .hljs-class, .hljs-title.class_, .hljs-built_in { color: var(--syntaxType); }
.hljs-attr, .hljs-variable, .hljs-variable.language_, .hljs-params, .hljs-property { color: var(--syntaxVariable); }
.hljs-meta, .hljs-meta .hljs-keyword, .hljs-meta .hljs-string { color: var(--syntaxKeyword); }
.hljs-operator { color: var(--syntaxOperator); }
.hljs-punctuation { color: var(--syntaxPunctuation); }
.hljs-subst { color: var(--text); }
/* Footer */
.footer {
  margin-top: 48px;
  padding: 20px;
  text-align: center;
  color: var(--dim);
  font-size: 10px;
}
/* Mobile */
/* Hamburger toggle: hidden on desktop; shown by the 900px media query below */
#hamburger {
  display: none;
  position: fixed;
  top: 10px;
  left: 10px;
  z-index: 100;
  padding: 3px 8px;
  font-size: 12px;
  font-family: inherit;
  background: transparent;
  color: var(--muted);
  border: 1px solid var(--dim);
  border-radius: 3px;
  cursor: pointer;
}
#hamburger:hover {
  color: var(--text);
  border-color: var(--text);
}
/* Backdrop shown while the mobile drawer is open; sits under the sidebar (z 98 < 99) */
#sidebar-overlay {
  display: none;
  position: fixed;
  top: 0;
  left: 0;
  right: 0;
  bottom: 0;
  background: rgba(0, 0, 0, 0.5);
  z-index: 98;
}
/* Tablet/phone: sidebar becomes an off-canvas drawer */
@media (max-width: 900px) {
  #sidebar {
    position: fixed;
    left: -400px; /* parked off-screen; .open slides it in */
    width: 400px;
    top: 0;
    bottom: 0;
    height: 100vh;
    z-index: 99;
    transition: left 0.3s;
  }
  #sidebar.open {
    left: 0;
  }
  #sidebar-overlay.open {
    display: block;
  }
  #hamburger {
    display: block;
  }
  .sidebar-close {
    display: block;
  }
  #content {
    padding: var(--line-height) 16px;
  }
  #content > * {
    max-width: 100%;
  }
}
/* Narrow phones: drawer takes the full viewport width */
@media (max-width: 500px) {
  #sidebar {
    width: 100vw;
    left: -100vw;
  }
}
/* Print: hide navigation chrome, force light colors, use the full page width */
@media print {
  #sidebar, #sidebar-toggle { display: none !important; }
  body { background: white; color: black; }
  #content { max-width: none; }
}

View file

@ -0,0 +1,54 @@
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Session Export</title>
  <!-- All {{...}} placeholders are substituted by the exporter when the file is generated -->
  <style>
{{CSS}}
  </style>
</head>
<body>
  <!-- Mobile-only sidebar toggle; hidden on desktop via CSS media query -->
  <button id="hamburger" title="Open sidebar"><svg width="14" height="14" viewBox="0 0 24 24" fill="currentColor" stroke="none"><circle cx="6" cy="6" r="2.5"/><circle cx="6" cy="18" r="2.5"/><circle cx="18" cy="12" r="2.5"/><rect x="5" y="6" width="2" height="12"/><path d="M6 12h10c1 0 2 0 2-2V8"/></svg></button>
  <!-- Click-to-close backdrop shown while the mobile sidebar is open -->
  <div id="sidebar-overlay"></div>
  <div id="app">
    <!-- Navigation: searchable, filterable tree of session entries -->
    <aside id="sidebar">
      <div class="sidebar-header">
        <div class="sidebar-controls">
          <input type="text" class="sidebar-search" id="tree-search" placeholder="Search...">
        </div>
        <!-- Visibility filters; one is active at a time (see data-filter values) -->
        <div class="sidebar-filters">
          <button class="filter-btn active" data-filter="default" title="Hide settings entries">Default</button>
          <button class="filter-btn" data-filter="no-tools" title="Default minus tool results">No-tools</button>
          <button class="filter-btn" data-filter="user-only" title="Only user messages">User</button>
          <button class="filter-btn" data-filter="labeled-only" title="Only labeled entries">Labeled</button>
          <button class="filter-btn" data-filter="all" title="Show everything">All</button>
          <button class="sidebar-close" id="sidebar-close" title="Close"></button>
        </div>
      </div>
      <div class="tree-container" id="tree-container"></div>
      <div class="tree-status" id="tree-status"></div>
    </aside>
    <!-- Main column: session header plus rendered messages -->
    <main id="content">
      <div id="header-container"></div>
      <div id="messages"></div>
    </main>
    <!-- Full-size image viewer, toggled by the application script -->
    <div id="image-modal" class="image-modal">
      <img id="modal-image" src="" alt="">
    </div>
  </div>
  <!-- Exported session entries as JSON; read by the main application code -->
  <script id="session-data" type="application/json">{{SESSION_DATA}}</script>
  <!-- Vendored libraries -->
  <script>{{MARKED_JS}}</script>
  <!-- highlight.js -->
  <script>{{HIGHLIGHT_JS}}</script>
  <!-- Main application code -->
  <script>
{{JS}}
  </script>
</body>
</html>

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,114 @@
/**
* Tool HTML renderer for custom tools in HTML export.
*
* Renders custom tool calls and results to HTML by invoking their TUI renderers
* and converting the ANSI output to HTML.
*/
import type { ImageContent, TextContent } from "@gsd/pi-ai";
import type { Theme } from "../../modes/interactive/theme/theme.js";
import type { ToolDefinition } from "../extensions/types.js";
import { ansiLinesToHtml } from "./ansi-to-html.js";
/** Dependencies needed to construct a {@link ToolHtmlRenderer}. */
export interface ToolHtmlRendererDeps {
  /** Function to look up tool definition by name */
  getToolDefinition: (name: string) => ToolDefinition | undefined;
  /** Theme for styling */
  theme: Theme;
  /** Terminal width for rendering (default: 100) */
  width?: number;
}

/** Converts custom tool calls/results to HTML via the tools' own TUI renderers. */
export interface ToolHtmlRenderer {
  /** Render a tool call to HTML. Returns undefined if tool has no custom renderer. */
  renderCall(toolName: string, args: unknown): string | undefined;
  /** Render a tool result to collapsed/expanded HTML. Returns undefined if tool has no custom renderer. */
  renderResult(
    toolName: string,
    result: Array<{ type: string; text?: string; data?: string; mimeType?: string }>,
    details: unknown,
    isError: boolean,
  ): { collapsed?: string; expanded?: string } | undefined;
}
/**
* Create a tool HTML renderer.
*
* The renderer looks up tool definitions and invokes their renderCall/renderResult
* methods, converting the resulting TUI Component output (ANSI) to HTML.
*/
export function createToolHtmlRenderer(deps: ToolHtmlRendererDeps): ToolHtmlRenderer {
const { getToolDefinition, theme, width = 100 } = deps;
return {
renderCall(toolName: string, args: unknown): string | undefined {
try {
const toolDef = getToolDefinition(toolName);
if (!toolDef?.renderCall) {
return undefined;
}
const component = toolDef.renderCall(args, theme);
if (!component) {
return undefined;
}
const lines = component.render(width);
return ansiLinesToHtml(lines);
} catch {
// On error, return undefined to trigger JSON fallback
return undefined;
}
},
renderResult(
toolName: string,
result: Array<{ type: string; text?: string; data?: string; mimeType?: string }>,
details: unknown,
isError: boolean,
): { collapsed?: string; expanded?: string } | undefined {
try {
const toolDef = getToolDefinition(toolName);
if (!toolDef?.renderResult) {
return undefined;
}
// Build AgentToolResult from content array
// Cast content since session storage uses generic object types
const agentToolResult = {
content: result as (TextContent | ImageContent)[],
details,
isError,
};
// Render collapsed
const collapsedComponent = toolDef.renderResult(
agentToolResult,
{ expanded: false, isPartial: false },
theme,
);
const collapsed = collapsedComponent ? ansiLinesToHtml(collapsedComponent.render(width)) : undefined;
// Render expanded
const expandedComponent = toolDef.renderResult(
agentToolResult,
{ expanded: true, isPartial: false },
theme,
);
const expanded = expandedComponent ? ansiLinesToHtml(expandedComponent.render(width)) : undefined;
// Return collapsed only if it exists and differs from expanded
if (!expanded) {
return undefined;
}
return {
...(collapsed && collapsed !== expanded ? { collapsed } : {}),
expanded,
};
} catch {
// On error, return undefined to trigger JSON fallback
return undefined;
}
},
};
}

View file

@ -0,0 +1,171 @@
/**
* Extension system for lifecycle events and custom tools.
*/
export type { SlashCommandInfo, SlashCommandLocation, SlashCommandSource } from "../slash-commands.js";
export {
createExtensionRuntime,
discoverAndLoadExtensions,
loadExtensionFromFactory,
loadExtensions,
} from "./loader.js";
export type {
ExtensionErrorListener,
ForkHandler,
NavigateTreeHandler,
NewSessionHandler,
ShutdownHandler,
SwitchSessionHandler,
} from "./runner.js";
export { ExtensionRunner } from "./runner.js";
export type {
AgentEndEvent,
AgentStartEvent,
// Re-exports
AgentToolResult,
AgentToolUpdateCallback,
// App keybindings (for custom editors)
AppAction,
AppendEntryHandler,
// Events - Tool (ToolCallEvent types)
BashToolCallEvent,
BashToolResultEvent,
BeforeAgentStartEvent,
BeforeAgentStartEventResult,
BeforeProviderRequestEvent,
BeforeProviderRequestEventResult,
// Context
CompactOptions,
// Events - Agent
ContextEvent,
// Event Results
ContextEventResult,
ContextUsage,
CustomToolCallEvent,
CustomToolResultEvent,
EditToolCallEvent,
EditToolResultEvent,
ExecOptions,
ExecResult,
Extension,
ExtensionActions,
// API
ExtensionAPI,
ExtensionCommandContext,
ExtensionCommandContextActions,
ExtensionContext,
ExtensionContextActions,
// Errors
ExtensionError,
ExtensionEvent,
ExtensionFactory,
ExtensionFlag,
ExtensionHandler,
// Runtime
ExtensionRuntime,
ExtensionShortcut,
ExtensionUIContext,
ExtensionUIDialogOptions,
ExtensionWidgetOptions,
FindToolCallEvent,
FindToolResultEvent,
GetActiveToolsHandler,
GetAllToolsHandler,
GetCommandsHandler,
GetThinkingLevelHandler,
GrepToolCallEvent,
GrepToolResultEvent,
// Events - Input
InputEvent,
InputEventResult,
InputSource,
KeybindingsManager,
LoadExtensionsResult,
LsToolCallEvent,
LsToolResultEvent,
// Events - Message
MessageEndEvent,
// Message Rendering
MessageRenderer,
MessageRenderOptions,
MessageStartEvent,
MessageUpdateEvent,
ModelSelectEvent,
ModelSelectSource,
// Provider Registration
ProviderConfig,
ProviderModelConfig,
ReadToolCallEvent,
ReadToolResultEvent,
// Commands
RegisteredCommand,
RegisteredTool,
// Events - Resources
ResourcesDiscoverEvent,
ResourcesDiscoverResult,
SendMessageHandler,
SendUserMessageHandler,
SessionBeforeCompactEvent,
SessionBeforeCompactResult,
SessionBeforeForkEvent,
SessionBeforeForkResult,
SessionBeforeSwitchEvent,
SessionBeforeSwitchResult,
SessionBeforeTreeEvent,
SessionBeforeTreeResult,
SessionCompactEvent,
SessionDirectoryEvent,
SessionDirectoryHandler,
SessionDirectoryResult,
SessionEvent,
SessionForkEvent,
SessionShutdownEvent,
// Events - Session
SessionStartEvent,
SessionSwitchEvent,
SessionTreeEvent,
SetActiveToolsHandler,
SetLabelHandler,
SetModelHandler,
SetThinkingLevelHandler,
TerminalInputHandler,
// Events - Tool
ToolCallEvent,
ToolCallEventResult,
// Tools
ToolDefinition,
// Events - Tool Execution
ToolExecutionEndEvent,
ToolExecutionStartEvent,
ToolExecutionUpdateEvent,
ToolInfo,
ToolRenderResultOptions,
ToolResultEvent,
ToolResultEventResult,
TreePreparation,
TurnEndEvent,
TurnStartEvent,
// Events - User Bash
UserBashEvent,
UserBashEventResult,
WidgetPlacement,
WriteToolCallEvent,
WriteToolResultEvent,
} from "./types.js";
// Type guards
export {
isBashToolResult,
isEditToolResult,
isFindToolResult,
isGrepToolResult,
isLsToolResult,
isReadToolResult,
isToolCallEventType,
isWriteToolResult,
} from "./types.js";
export {
wrapRegisteredTool,
wrapRegisteredTools,
wrapToolsWithExtensions,
wrapToolWithExtensions,
} from "./wrapper.js";

View file

@ -0,0 +1,545 @@
/**
* Extension loader - loads TypeScript extension modules using jiti.
*
* Uses @mariozechner/jiti fork with virtualModules support for compiled Bun binaries.
*/
import * as fs from "node:fs";
import { createRequire } from "node:module";
import * as os from "node:os";
import * as path from "node:path";
import { fileURLToPath } from "node:url";
import { createJiti } from "@mariozechner/jiti";
import * as _bundledPiAgentCore from "@gsd/pi-agent-core";
import * as _bundledPiAi from "@gsd/pi-ai";
import * as _bundledPiAiOauth from "@gsd/pi-ai/oauth";
import type { KeyId } from "@gsd/pi-tui";
import * as _bundledPiTui from "@gsd/pi-tui";
// Static imports of packages that extensions may use.
// These MUST be static so Bun bundles them into the compiled binary.
// The virtualModules option then makes them available to extensions.
import * as _bundledTypebox from "@sinclair/typebox";
import { getAgentDir, isBunBinary } from "../../config.js";
// NOTE: This import works because loader.ts exports are NOT re-exported from index.ts,
// avoiding a circular dependency. Extensions can import from @gsd/pi-coding-agent.
import * as _bundledPiCodingAgent from "../../index.js";
import { createEventBus, type EventBus } from "../event-bus.js";
import type { ExecOptions } from "../exec.js";
import { execCommand } from "../exec.js";
import type {
Extension,
ExtensionAPI,
ExtensionFactory,
ExtensionRuntime,
LoadExtensionsResult,
MessageRenderer,
ProviderConfig,
RegisteredCommand,
ToolDefinition,
} from "./types.js";
/**
 * Modules available to extensions via virtualModules (for compiled Bun binary).
 * Keys are the bare import specifiers extensions write; values are the module
 * namespaces statically bundled by the imports above.
 */
const VIRTUAL_MODULES: Record<string, unknown> = {
  "@sinclair/typebox": _bundledTypebox,
  "@gsd/pi-agent-core": _bundledPiAgentCore,
  "@gsd/pi-tui": _bundledPiTui,
  "@gsd/pi-ai": _bundledPiAi,
  "@gsd/pi-ai/oauth": _bundledPiAiOauth,
  "@gsd/pi-coding-agent": _bundledPiCodingAgent,
};

// CommonJS-style require, used to locate installed package paths in getAliases().
const require = createRequire(import.meta.url);
/**
 * Get aliases for jiti (used in Node.js/development mode).
 * In Bun binary mode, virtualModules is used instead.
 */
let _aliases: Record<string, string> | null = null; // lazy cache, computed once per process

function getAliases(): Record<string, string> {
  if (_aliases) return _aliases;
  const __dirname = path.dirname(fileURLToPath(import.meta.url));
  // This package's own entry point, so extensions importing @gsd/pi-coding-agent
  // resolve back into the running code instead of a second copy.
  const packageIndex = path.resolve(__dirname, "../..", "index.js");
  const typeboxEntry = require.resolve("@sinclair/typebox");
  // Strip the build/cjs/index.js suffix (either path separator) to get the package root.
  const typeboxRoot = typeboxEntry.replace(/[\\/]build[\\/]cjs[\\/]index\.js$/, "");
  const packagesRoot = path.resolve(__dirname, "../../../../");
  // Prefer the sibling workspace build when present (monorepo dev);
  // otherwise fall back to normal package resolution.
  const resolveWorkspaceOrImport = (workspaceRelativePath: string, specifier: string): string => {
    const workspacePath = path.join(packagesRoot, workspaceRelativePath);
    if (fs.existsSync(workspacePath)) {
      return workspacePath;
    }
    return fileURLToPath(import.meta.resolve(specifier));
  };
  _aliases = {
    "@gsd/pi-coding-agent": packageIndex,
    "@gsd/pi-agent-core": resolveWorkspaceOrImport("agent/dist/index.js", "@gsd/pi-agent-core"),
    "@gsd/pi-tui": resolveWorkspaceOrImport("tui/dist/index.js", "@gsd/pi-tui"),
    "@gsd/pi-ai": resolveWorkspaceOrImport("ai/dist/index.js", "@gsd/pi-ai"),
    "@gsd/pi-ai/oauth": resolveWorkspaceOrImport("ai/dist/oauth.js", "@gsd/pi-ai/oauth"),
    "@sinclair/typebox": typeboxRoot,
  };
  return _aliases;
}
/** Exotic Unicode spaces (NBSP, en/em spaces, narrow NBSP, ideographic space, …). */
const UNICODE_SPACES = /[\u00A0\u2000-\u200A\u202F\u205F\u3000]/g;

/** Replace exotic Unicode space characters with ordinary ASCII spaces. */
function normalizeUnicodeSpaces(str: string): string {
  return str.replace(UNICODE_SPACES, " ");
}

/** Expand a leading tilde to the user's home directory (after space normalization). */
function expandPath(p: string): string {
  const cleaned = normalizeUnicodeSpaces(p);
  if (!cleaned.startsWith("~")) {
    return cleaned;
  }
  // "~/rest" drops two characters, any other "~..." drops one; both anchor at $HOME.
  const rest = cleaned.startsWith("~/") ? cleaned.slice(2) : cleaned.slice(1);
  return path.join(os.homedir(), rest);
}

/** Resolve a user-supplied extension path: expand ~, then anchor relative paths at cwd. */
function resolvePath(extPath: string, cwd: string): string {
  const expanded = expandPath(extPath);
  return path.isAbsolute(expanded) ? expanded : path.resolve(cwd, expanded);
}
type HandlerFn = (...args: unknown[]) => Promise<unknown>;

/**
 * Build the shared extension runtime in its pre-bind state.
 *
 * Action methods start out as throwing stubs; Runner.bindCore() swaps in the
 * real implementations once the core is available. Provider registrations made
 * before binding are queued and flushed by bindCore().
 */
export function createExtensionRuntime(): ExtensionRuntime {
  const fail = (): never => {
    throw new Error("Extension runtime not initialized. Action methods cannot be called during extension loading.");
  };

  const runtime: ExtensionRuntime = {
    sendMessage: fail,
    sendUserMessage: fail,
    appendEntry: fail,
    setSessionName: fail,
    getSessionName: fail,
    setLabel: fail,
    getActiveTools: fail,
    getAllTools: fail,
    setActiveTools: fail,
    // registerTool() is valid during extension load; refresh is only needed post-bind.
    refreshTools: () => {},
    getCommands: fail,
    setModel: () => Promise.reject(new Error("Extension runtime not initialized")),
    getThinkingLevel: fail,
    setThinkingLevel: fail,
    flagValues: new Map(),
    pendingProviderRegistrations: [],
    // Pre-bind: queue registrations so bindCore() can flush them once the
    // model registry is available. bindCore() replaces both with direct calls.
    registerProvider: (name, config) => {
      runtime.pendingProviderRegistrations.push({ name, config });
    },
    unregisterProvider: (name) => {
      runtime.pendingProviderRegistrations = runtime.pendingProviderRegistrations.filter((r) => r.name !== name);
    },
  };
  return runtime;
}
/**
 * Create the ExtensionAPI for an extension.
 * Registration methods write to the extension object.
 * Action methods delegate to the shared runtime.
 */
function createExtensionAPI(
  extension: Extension,
  runtime: ExtensionRuntime,
  cwd: string,
  eventBus: EventBus,
): ExtensionAPI {
  const api = {
    // Registration methods - write to extension
    on(event: string, handler: HandlerFn): void {
      // Append to this event's handler list, creating it on first use.
      const list = extension.handlers.get(event) ?? [];
      list.push(handler);
      extension.handlers.set(event, list);
    },
    registerTool(tool: ToolDefinition): void {
      extension.tools.set(tool.name, {
        definition: tool,
        extensionPath: extension.path,
      });
      // No-op before bindCore(); afterwards makes the new tool visible immediately.
      runtime.refreshTools();
    },
    registerCommand(name: string, options: Omit<RegisteredCommand, "name">): void {
      extension.commands.set(name, { name, ...options });
    },
    registerShortcut(
      shortcut: KeyId,
      options: {
        description?: string;
        handler: (ctx: import("./types.js").ExtensionContext) => Promise<void> | void;
      },
    ): void {
      extension.shortcuts.set(shortcut, { shortcut, extensionPath: extension.path, ...options });
    },
    registerFlag(
      name: string,
      options: { description?: string; type: "boolean" | "string"; default?: boolean | string },
    ): void {
      extension.flags.set(name, { name, extensionPath: extension.path, ...options });
      // Seed the default only if no value has been set already (e.g. from the CLI).
      if (options.default !== undefined && !runtime.flagValues.has(name)) {
        runtime.flagValues.set(name, options.default);
      }
    },
    registerMessageRenderer<T>(customType: string, renderer: MessageRenderer<T>): void {
      extension.messageRenderers.set(customType, renderer as MessageRenderer);
    },
    // Flag access - checks extension registered it, reads from runtime
    getFlag(name: string): boolean | string | undefined {
      if (!extension.flags.has(name)) return undefined;
      return runtime.flagValues.get(name);
    },
    // Action methods - delegate to shared runtime
    sendMessage(message, options): void {
      runtime.sendMessage(message, options);
    },
    sendUserMessage(content, options): void {
      runtime.sendUserMessage(content, options);
    },
    appendEntry(customType: string, data?: unknown): void {
      runtime.appendEntry(customType, data);
    },
    setSessionName(name: string): void {
      runtime.setSessionName(name);
    },
    getSessionName(): string | undefined {
      return runtime.getSessionName();
    },
    setLabel(entryId: string, label: string | undefined): void {
      runtime.setLabel(entryId, label);
    },
    exec(command: string, args: string[], options?: ExecOptions) {
      // Commands run relative to the session cwd unless the caller overrides it.
      return execCommand(command, args, options?.cwd ?? cwd, options);
    },
    getActiveTools(): string[] {
      return runtime.getActiveTools();
    },
    getAllTools() {
      return runtime.getAllTools();
    },
    setActiveTools(toolNames: string[]): void {
      runtime.setActiveTools(toolNames);
    },
    getCommands() {
      return runtime.getCommands();
    },
    setModel(model) {
      return runtime.setModel(model);
    },
    getThinkingLevel() {
      return runtime.getThinkingLevel();
    },
    setThinkingLevel(level) {
      runtime.setThinkingLevel(level);
    },
    registerProvider(name: string, config: ProviderConfig) {
      runtime.registerProvider(name, config);
    },
    unregisterProvider(name: string) {
      runtime.unregisterProvider(name);
    },
    events: eventBus,
  } as ExtensionAPI;
  return api;
}
/**
 * Import an extension module via jiti and return its default export,
 * or undefined when the default export is not a factory function.
 */
async function loadExtensionModule(extensionPath: string) {
  const jiti = createJiti(import.meta.url, {
    moduleCache: false,
    // In Bun binary: use virtualModules for bundled packages (no filesystem resolution)
    // and disable tryNative so jiti handles ALL imports (not just the entry point).
    // In Node.js/dev: use aliases to resolve to node_modules paths.
    ...(isBunBinary ? { virtualModules: VIRTUAL_MODULES, tryNative: false } : { alias: getAliases() }),
  });
  const loaded = await jiti.import(extensionPath, { default: true });
  if (typeof loaded !== "function") {
    return undefined;
  }
  return loaded as ExtensionFactory;
}
/**
 * Create an Extension object with empty collections.
 * Extension factories populate these registries via the ExtensionAPI.
 */
function createExtension(extensionPath: string, resolvedPath: string): Extension {
  const extension: Extension = {
    path: extensionPath,
    resolvedPath,
    handlers: new Map(),
    tools: new Map(),
    messageRenderers: new Map(),
    commands: new Map(),
    flags: new Map(),
    shortcuts: new Map(),
  };
  return extension;
}
/**
 * Load one extension from disk and run its factory against a fresh API.
 * Returns either the populated extension or a human-readable error string
 * (never both).
 */
async function loadExtension(
  extensionPath: string,
  cwd: string,
  eventBus: EventBus,
  runtime: ExtensionRuntime,
): Promise<{ extension: Extension | null; error: string | null }> {
  const resolvedPath = resolvePath(extensionPath, cwd);
  try {
    const factory = await loadExtensionModule(resolvedPath);
    if (!factory) {
      return { extension: null, error: `Extension does not export a valid factory function: ${extensionPath}` };
    }
    const extension = createExtension(extensionPath, resolvedPath);
    await factory(createExtensionAPI(extension, runtime, cwd, eventBus));
    return { extension, error: null };
  } catch (err) {
    const message = err instanceof Error ? err.message : String(err);
    return { extension: null, error: `Failed to load extension: ${message}` };
  }
}
/**
 * Create an Extension from an inline factory function.
 * Inline factories have no file on disk, so path and resolvedPath share the label.
 */
export async function loadExtensionFromFactory(
  factory: ExtensionFactory,
  cwd: string,
  eventBus: EventBus,
  runtime: ExtensionRuntime,
  extensionPath = "<inline>",
): Promise<Extension> {
  const extension = createExtension(extensionPath, extensionPath);
  await factory(createExtensionAPI(extension, runtime, cwd, eventBus));
  return extension;
}
/**
 * Load extensions from paths.
 *
 * Failures are collected per path instead of aborting, so one broken
 * extension cannot prevent the others from loading.
 */
export async function loadExtensions(paths: string[], cwd: string, eventBus?: EventBus): Promise<LoadExtensionsResult> {
  const bus = eventBus ?? createEventBus();
  const runtime = createExtensionRuntime();
  const extensions: Extension[] = [];
  const errors: Array<{ path: string; error: string }> = [];

  for (const extPath of paths) {
    const { extension, error } = await loadExtension(extPath, cwd, bus, runtime);
    if (error) {
      errors.push({ path: extPath, error });
    } else if (extension) {
      extensions.push(extension);
    }
  }

  return { extensions, errors, runtime };
}
/** Shape of the optional "pi" section in an extension package's package.json. */
interface PiManifest {
  extensions?: string[];
  themes?: string[];
  skills?: string[];
  prompts?: string[];
}

/**
 * Read the "pi" manifest section from a package.json file.
 * Returns null when the file is missing, unparsable, or has no object-valued "pi" field.
 */
function readPiManifest(packageJsonPath: string): PiManifest | null {
  try {
    const pkg = JSON.parse(fs.readFileSync(packageJsonPath, "utf-8"));
    const pi = pkg.pi;
    return pi && typeof pi === "object" ? (pi as PiManifest) : null;
  } catch {
    return null;
  }
}
/** True when the file name looks like a loadable extension module (.ts or .js). */
function isExtensionFile(name: string): boolean {
  return [".ts", ".js"].some((ext) => name.endsWith(ext));
}
/**
 * Resolve extension entry points from a directory.
 *
 * Checks for:
 * 1. package.json with "pi.extensions" field -> returns declared paths
 * 2. index.ts or index.js -> returns the index file
 *
 * Returns resolved paths or null if no entry points found.
 */
function resolveExtensionEntries(dir: string): string[] | null {
  // A package.json "pi" manifest takes precedence over conventional index files.
  const packageJsonPath = path.join(dir, "package.json");
  if (fs.existsSync(packageJsonPath)) {
    const manifest = readPiManifest(packageJsonPath);
    if (manifest?.extensions?.length) {
      // Keep only the declared entries that actually exist on disk.
      const entries = manifest.extensions
        .map((extPath) => path.resolve(dir, extPath))
        .filter((resolved) => fs.existsSync(resolved));
      if (entries.length > 0) {
        return entries;
      }
    }
  }
  // Fall back to a conventional index file.
  for (const indexName of ["index.ts", "index.js"]) {
    const candidate = path.join(dir, indexName);
    if (fs.existsSync(candidate)) {
      return [candidate];
    }
  }
  return null;
}
/**
 * Discover extensions in a directory.
 *
 * Discovery rules:
 * 1. Direct files: `extensions/*.ts` or `*.js` load
 * 2. Subdirectory with index: `extensions/* /index.ts` or `index.js` load
 * 3. Subdirectory with package.json: `extensions/* /package.json` with "pi" field load what it declares
 *
 * No recursion beyond one level. Complex packages must use package.json manifest.
 *
 * Returns [] when the directory is missing or unreadable.
 */
function discoverExtensionsInDir(dir: string): string[] {
  if (!fs.existsSync(dir)) {
    return [];
  }
  const discovered: string[] = [];
  try {
    // Named dirEntries (not `entries`) so the per-subdirectory lookup below
    // does not shadow this listing.
    const dirEntries = fs.readdirSync(dir, { withFileTypes: true });
    for (const entry of dirEntries) {
      const entryPath = path.join(dir, entry.name);
      // 1. Direct files: *.ts or *.js (symlinks to files included)
      if ((entry.isFile() || entry.isSymbolicLink()) && isExtensionFile(entry.name)) {
        discovered.push(entryPath);
        continue;
      }
      // 2 & 3. Subdirectories (or symlinks to them): index file or package.json manifest
      if (entry.isDirectory() || entry.isSymbolicLink()) {
        const subEntries = resolveExtensionEntries(entryPath);
        if (subEntries) {
          discovered.push(...subEntries);
        }
      }
    }
  } catch {
    // Treat an unreadable directory as empty rather than failing discovery.
    return [];
  }
  return discovered;
}
/**
 * Discover and load extensions from standard locations.
 *
 * Order: project-local (.pi/extensions), then global (agentDir/extensions),
 * then explicitly configured paths. Duplicates are dropped by resolved path,
 * first occurrence wins.
 */
export async function discoverAndLoadExtensions(
  configuredPaths: string[],
  cwd: string,
  agentDir: string = getAgentDir(),
  eventBus?: EventBus,
): Promise<LoadExtensionsResult> {
  const allPaths: string[] = [];
  const seen = new Set<string>();
  const addPaths = (paths: string[]) => {
    for (const candidate of paths) {
      const key = path.resolve(candidate);
      if (seen.has(key)) continue;
      seen.add(key);
      allPaths.push(candidate);
    }
  };

  // 1. Project-local extensions: cwd/.pi/extensions/
  addPaths(discoverExtensionsInDir(path.join(cwd, ".pi", "extensions")));

  // 2. Global extensions: agentDir/extensions/
  addPaths(discoverExtensionsInDir(path.join(agentDir, "extensions")));

  // 3. Explicitly configured paths (files, manifest dirs, or plain dirs)
  for (const configured of configuredPaths) {
    const resolved = resolvePath(configured, cwd);
    if (fs.existsSync(resolved) && fs.statSync(resolved).isDirectory()) {
      // Prefer declared entry points: package.json "pi" manifest or index file.
      const entries = resolveExtensionEntries(resolved);
      if (entries) {
        addPaths(entries);
      } else {
        // No explicit entries - discover individual files in the directory.
        addPaths(discoverExtensionsInDir(resolved));
      }
      continue;
    }
    addPaths([resolved]);
  }

  return loadExtensions(allPaths, cwd, eventBus);
}

View file

@ -0,0 +1,884 @@
/**
* Extension runner - executes extensions and manages their lifecycle.
*/
import type { AgentMessage } from "@gsd/pi-agent-core";
import type { ImageContent, Model } from "@gsd/pi-ai";
import type { KeyId } from "@gsd/pi-tui";
import { type Theme, theme } from "../../modes/interactive/theme/theme.js";
import type { ResourceDiagnostic } from "../diagnostics.js";
import type { KeyAction, KeybindingsConfig } from "../keybindings.js";
import type { ModelRegistry } from "../model-registry.js";
import type { SessionManager } from "../session-manager.js";
import type {
BeforeAgentStartEvent,
BeforeAgentStartEventResult,
BeforeProviderRequestEvent,
CompactOptions,
ContextEvent,
ContextEventResult,
ContextUsage,
Extension,
ExtensionActions,
ExtensionCommandContext,
ExtensionCommandContextActions,
ExtensionContext,
ExtensionContextActions,
ExtensionError,
ExtensionEvent,
ExtensionFlag,
ExtensionRuntime,
ExtensionShortcut,
ExtensionUIContext,
InputEvent,
InputEventResult,
InputSource,
MessageRenderer,
RegisteredCommand,
RegisteredTool,
ResourcesDiscoverEvent,
ResourcesDiscoverResult,
SessionBeforeCompactResult,
SessionBeforeForkResult,
SessionBeforeSwitchResult,
SessionBeforeTreeResult,
ToolCallEvent,
ToolCallEventResult,
ToolResultEvent,
ToolResultEventResult,
UserBashEvent,
UserBashEventResult,
} from "./types.js";
// Keybindings for these actions cannot be overridden by extensions
const RESERVED_ACTIONS_FOR_EXTENSION_CONFLICTS: ReadonlyArray<KeyAction> = [
  "interrupt",
  "clear",
  "exit",
  "suspend",
  "cycleThinkingLevel",
  "cycleModelForward",
  "cycleModelBackward",
  "selectModel",
  "expandTools",
  "toggleThinking",
  "externalEditor",
  "followUp",
  "submit",
  "selectConfirm",
  "selectCancel",
  "copy",
  "deleteToLineEnd",
];

type BuiltInKeyBindings = Partial<Record<KeyId, { action: KeyAction; restrictOverride: boolean }>>;

/**
 * Flatten the effective keybindings config into a lower-cased key -> action map,
 * marking keys bound to reserved actions so extensions cannot override them.
 */
const buildBuiltinKeybindings = (effectiveKeybindings: Required<KeybindingsConfig>): BuiltInKeyBindings => {
  const bindings = {} as BuiltInKeyBindings;
  for (const [action, keys] of Object.entries(effectiveKeybindings)) {
    const keyAction = action as KeyAction;
    const reserved = RESERVED_ACTIONS_FOR_EXTENSION_CONFLICTS.includes(keyAction);
    const keyList = Array.isArray(keys) ? keys : [keys];
    for (const key of keyList) {
      bindings[key.toLowerCase() as KeyId] = { action: keyAction, restrictOverride: reserved };
    }
  }
  return bindings;
};
/** Combined result from all before_agent_start handlers */
interface BeforeAgentStartCombinedResult {
messages?: NonNullable<BeforeAgentStartEventResult["message"]>[];
systemPrompt?: string;
}
/**
* Events handled by the generic emit() method.
* Events with dedicated emitXxx() methods are excluded for stronger type safety.
*/
type RunnerEmitEvent = Exclude<
ExtensionEvent,
| ToolCallEvent
| ToolResultEvent
| UserBashEvent
| ContextEvent
| BeforeProviderRequestEvent
| BeforeAgentStartEvent
| ResourcesDiscoverEvent
| InputEvent
>;
type SessionBeforeEvent = Extract<
RunnerEmitEvent,
{ type: "session_before_switch" | "session_before_fork" | "session_before_compact" | "session_before_tree" }
>;
type SessionBeforeEventResult =
| SessionBeforeSwitchResult
| SessionBeforeForkResult
| SessionBeforeCompactResult
| SessionBeforeTreeResult;
type RunnerEmitResult<TEvent extends RunnerEmitEvent> = TEvent extends { type: "session_before_switch" }
? SessionBeforeSwitchResult | undefined
: TEvent extends { type: "session_before_fork" }
? SessionBeforeForkResult | undefined
: TEvent extends { type: "session_before_compact" }
? SessionBeforeCompactResult | undefined
: TEvent extends { type: "session_before_tree" }
? SessionBeforeTreeResult | undefined
: undefined;
export type ExtensionErrorListener = (error: ExtensionError) => void;
export type NewSessionHandler = (options?: {
parentSession?: string;
setup?: (sessionManager: SessionManager) => Promise<void>;
}) => Promise<{ cancelled: boolean }>;
export type ForkHandler = (entryId: string) => Promise<{ cancelled: boolean }>;
export type NavigateTreeHandler = (
targetId: string,
options?: { summarize?: boolean; customInstructions?: string; replaceInstructions?: boolean; label?: string },
) => Promise<{ cancelled: boolean }>;
export type SwitchSessionHandler = (sessionPath: string) => Promise<{ cancelled: boolean }>;
export type ReloadHandler = () => Promise<void>;
export type ShutdownHandler = () => void;
/**
 * Helper function to emit session_shutdown event to extensions.
 * Returns true if the event was emitted, false if there were no handlers
 * (or no runner at all).
 */
export async function emitSessionShutdownEvent(extensionRunner: ExtensionRunner | undefined): Promise<boolean> {
  if (!extensionRunner?.hasHandlers("session_shutdown")) {
    return false;
  }
  await extensionRunner.emit({
    type: "session_shutdown",
  });
  return true;
}
/**
 * UI context used while no real UI is bound (e.g. headless runs, or before
 * the interactive mode attaches). Queries resolve to "nothing selected /
 * not confirmed"; all mutating calls are silent no-ops.
 */
const noOpUIContext: ExtensionUIContext = {
  select: async () => undefined,
  confirm: async () => false,
  input: async () => undefined,
  notify: () => {},
  onTerminalInput: () => () => {}, // returns a no-op unsubscribe function
  setStatus: () => {},
  setWorkingMessage: () => {},
  setWidget: () => {},
  setFooter: () => {},
  setHeader: () => {},
  setTitle: () => {},
  custom: async () => undefined as never,
  pasteToEditor: () => {},
  setEditorText: () => {},
  getEditorText: () => "",
  editor: async () => undefined,
  setEditorComponent: () => {},
  get theme() {
    // Getter so callers always observe the current global theme.
    return theme;
  },
  getAllThemes: () => [],
  getTheme: () => undefined,
  setTheme: (_theme: string | Theme) => ({ success: false, error: "UI not available" }),
  getToolsExpanded: () => false,
  setToolsExpanded: () => {},
};
/**
 * Central dispatcher between the host application and loaded extensions.
 *
 * Aggregates extension-registered resources (tools, commands, shortcuts,
 * flags, message renderers), dispatches lifecycle events to extension
 * handlers, and builds the ExtensionContext objects that handlers and tools
 * receive.
 *
 * Host capabilities are injected in stages: the constructor wires static
 * collaborators, bindCore() wires the agent's runtime actions,
 * bindCommandContext() wires command-only actions, and setUIContext()
 * attaches/detaches the UI. Until a binding occurs, no-op defaults keep
 * every call safe.
 */
export class ExtensionRunner {
  private extensions: Extension[];
  private runtime: ExtensionRuntime;
  private uiContext: ExtensionUIContext;
  private cwd: string;
  private sessionManager: SessionManager;
  private modelRegistry: ModelRegistry;
  private errorListeners: Set<ExtensionErrorListener> = new Set();
  // The fields below hold no-op defaults until bindCore()/bindCommandContext()
  // replace them with real implementations.
  private getModel: () => Model<any> | undefined = () => undefined;
  private isIdleFn: () => boolean = () => true;
  private waitForIdleFn: () => Promise<void> = async () => {};
  private abortFn: () => void = () => {};
  private hasPendingMessagesFn: () => boolean = () => false;
  private getContextUsageFn: () => ContextUsage | undefined = () => undefined;
  private compactFn: (options?: CompactOptions) => void = () => {};
  private getSystemPromptFn: () => string = () => "";
  private newSessionHandler: NewSessionHandler = async () => ({ cancelled: false });
  private forkHandler: ForkHandler = async () => ({ cancelled: false });
  private navigateTreeHandler: NavigateTreeHandler = async () => ({ cancelled: false });
  private switchSessionHandler: SwitchSessionHandler = async () => ({ cancelled: false });
  private reloadHandler: ReloadHandler = async () => {};
  private shutdownHandler: ShutdownHandler = () => {};
  // Warnings collected during the most recent getShortcuts() /
  // getRegisteredCommands() pass; reset at the start of each pass.
  private shortcutDiagnostics: ResourceDiagnostic[] = [];
  private commandDiagnostics: ResourceDiagnostic[] = [];

  constructor(
    extensions: Extension[],
    runtime: ExtensionRuntime,
    cwd: string,
    sessionManager: SessionManager,
    modelRegistry: ModelRegistry,
  ) {
    this.extensions = extensions;
    this.runtime = runtime;
    this.uiContext = noOpUIContext;
    this.cwd = cwd;
    this.sessionManager = sessionManager;
    this.modelRegistry = modelRegistry;
  }

  /**
   * Wire the agent's runtime actions into the shared runtime object and this
   * runner, then flush any provider registrations that were queued while
   * extensions were loading.
   */
  bindCore(actions: ExtensionActions, contextActions: ExtensionContextActions): void {
    // Copy actions into the shared runtime (all extension APIs reference this)
    this.runtime.sendMessage = actions.sendMessage;
    this.runtime.sendUserMessage = actions.sendUserMessage;
    this.runtime.appendEntry = actions.appendEntry;
    this.runtime.setSessionName = actions.setSessionName;
    this.runtime.getSessionName = actions.getSessionName;
    this.runtime.setLabel = actions.setLabel;
    this.runtime.getActiveTools = actions.getActiveTools;
    this.runtime.getAllTools = actions.getAllTools;
    this.runtime.setActiveTools = actions.setActiveTools;
    this.runtime.refreshTools = actions.refreshTools;
    this.runtime.getCommands = actions.getCommands;
    this.runtime.setModel = actions.setModel;
    this.runtime.getThinkingLevel = actions.getThinkingLevel;
    this.runtime.setThinkingLevel = actions.setThinkingLevel;
    // Context actions (required)
    this.getModel = contextActions.getModel;
    this.isIdleFn = contextActions.isIdle;
    this.abortFn = contextActions.abort;
    this.hasPendingMessagesFn = contextActions.hasPendingMessages;
    this.shutdownHandler = contextActions.shutdown;
    this.getContextUsageFn = contextActions.getContextUsage;
    this.compactFn = contextActions.compact;
    this.getSystemPromptFn = contextActions.getSystemPrompt;
    // Flush provider registrations queued during extension loading
    for (const { name, config } of this.runtime.pendingProviderRegistrations) {
      this.modelRegistry.registerProvider(name, config);
    }
    this.runtime.pendingProviderRegistrations = [];
    // From this point on, provider registration/unregistration takes effect immediately
    // without requiring a /reload.
    this.runtime.registerProvider = (name, config) => this.modelRegistry.registerProvider(name, config);
    this.runtime.unregisterProvider = (name) => this.modelRegistry.unregisterProvider(name);
  }

  /**
   * Wire (or, when called without arguments, reset to no-ops) the actions that
   * are only meaningful while a command is executing.
   */
  bindCommandContext(actions?: ExtensionCommandContextActions): void {
    if (actions) {
      this.waitForIdleFn = actions.waitForIdle;
      this.newSessionHandler = actions.newSession;
      this.forkHandler = actions.fork;
      this.navigateTreeHandler = actions.navigateTree;
      this.switchSessionHandler = actions.switchSession;
      this.reloadHandler = actions.reload;
      return;
    }
    this.waitForIdleFn = async () => {};
    this.newSessionHandler = async () => ({ cancelled: false });
    this.forkHandler = async () => ({ cancelled: false });
    this.navigateTreeHandler = async () => ({ cancelled: false });
    this.switchSessionHandler = async () => ({ cancelled: false });
    this.reloadHandler = async () => {};
  }

  /** Attach a UI context, or detach by passing undefined (reverts to no-op). */
  setUIContext(uiContext?: ExtensionUIContext): void {
    this.uiContext = uiContext ?? noOpUIContext;
  }

  getUIContext(): ExtensionUIContext {
    return this.uiContext;
  }

  /** True when a real UI context is attached (identity check against the no-op). */
  hasUI(): boolean {
    return this.uiContext !== noOpUIContext;
  }

  getExtensionPaths(): string[] {
    return this.extensions.map((e) => e.path);
  }

  /** Get all registered tools from all extensions (first registration per name wins). */
  getAllRegisteredTools(): RegisteredTool[] {
    const toolsByName = new Map<string, RegisteredTool>();
    for (const ext of this.extensions) {
      for (const tool of ext.tools.values()) {
        if (!toolsByName.has(tool.definition.name)) {
          toolsByName.set(tool.definition.name, tool);
        }
      }
    }
    return Array.from(toolsByName.values());
  }

  /** Get a tool definition by name. Returns undefined if not found. */
  getToolDefinition(toolName: string): RegisteredTool["definition"] | undefined {
    for (const ext of this.extensions) {
      const tool = ext.tools.get(toolName);
      if (tool) {
        return tool.definition;
      }
    }
    return undefined;
  }

  /** All flags from all extensions; first registration per flag name wins. */
  getFlags(): Map<string, ExtensionFlag> {
    const allFlags = new Map<string, ExtensionFlag>();
    for (const ext of this.extensions) {
      for (const [name, flag] of ext.flags) {
        if (!allFlags.has(name)) {
          allFlags.set(name, flag);
        }
      }
    }
    return allFlags;
  }

  setFlagValue(name: string, value: boolean | string): void {
    this.runtime.flagValues.set(name, value);
  }

  /** Snapshot of current flag values (copy; mutations do not affect the runtime). */
  getFlagValues(): Map<string, boolean | string> {
    return new Map(this.runtime.flagValues);
  }

  /**
   * Collect extension shortcuts, resolving conflicts against built-in
   * keybindings and between extensions (later extension wins unless the
   * built-in is marked restrictOverride). Conflicts are recorded as
   * diagnostics; when no UI is attached they are also written to the console.
   */
  getShortcuts(effectiveKeybindings: Required<KeybindingsConfig>): Map<KeyId, ExtensionShortcut> {
    this.shortcutDiagnostics = [];
    const builtinKeybindings = buildBuiltinKeybindings(effectiveKeybindings);
    const extensionShortcuts = new Map<KeyId, ExtensionShortcut>();
    const addDiagnostic = (message: string, extensionPath: string) => {
      this.shortcutDiagnostics.push({ type: "warning", message, path: extensionPath });
      if (!this.hasUI()) {
        console.warn(message);
      }
    };
    for (const ext of this.extensions) {
      for (const [key, shortcut] of ext.shortcuts) {
        const normalizedKey = key.toLowerCase() as KeyId;
        const builtInKeybinding = builtinKeybindings[normalizedKey];
        if (builtInKeybinding?.restrictOverride === true) {
          addDiagnostic(
            `Extension shortcut '${key}' from ${shortcut.extensionPath} conflicts with built-in shortcut. Skipping.`,
            shortcut.extensionPath,
          );
          continue;
        }
        if (builtInKeybinding?.restrictOverride === false) {
          addDiagnostic(
            `Extension shortcut conflict: '${key}' is built-in shortcut for ${builtInKeybinding.action} and ${shortcut.extensionPath}. Using ${shortcut.extensionPath}.`,
            shortcut.extensionPath,
          );
        }
        const existingExtensionShortcut = extensionShortcuts.get(normalizedKey);
        if (existingExtensionShortcut) {
          addDiagnostic(
            `Extension shortcut conflict: '${key}' registered by both ${existingExtensionShortcut.extensionPath} and ${shortcut.extensionPath}. Using ${shortcut.extensionPath}.`,
            shortcut.extensionPath,
          );
        }
        extensionShortcuts.set(normalizedKey, shortcut);
      }
    }
    return extensionShortcuts;
  }

  /** Diagnostics from the most recent getShortcuts() call. */
  getShortcutDiagnostics(): ResourceDiagnostic[] {
    return this.shortcutDiagnostics;
  }

  /** Subscribe to extension errors. Returns an unsubscribe function. */
  onError(listener: ExtensionErrorListener): () => void {
    this.errorListeners.add(listener);
    return () => this.errorListeners.delete(listener);
  }

  /** Broadcast an extension error to all registered listeners. */
  emitError(error: ExtensionError): void {
    for (const listener of this.errorListeners) {
      listener(error);
    }
  }

  /** True if any extension registered at least one handler for eventType. */
  hasHandlers(eventType: string): boolean {
    for (const ext of this.extensions) {
      const handlers = ext.handlers.get(eventType);
      if (handlers && handlers.length > 0) {
        return true;
      }
    }
    return false;
  }

  /** First renderer registered for the given custom message type, if any. */
  getMessageRenderer(customType: string): MessageRenderer | undefined {
    for (const ext of this.extensions) {
      const renderer = ext.messageRenderers.get(customType);
      if (renderer) {
        return renderer;
      }
    }
    return undefined;
  }

  /**
   * Collect extension commands, skipping those that clash with built-in
   * (reserved) names or with a command already claimed by another extension
   * (first registration wins). Conflicts are recorded as diagnostics; when no
   * UI is attached they are also written to the console.
   */
  getRegisteredCommands(reserved?: Set<string>): RegisteredCommand[] {
    this.commandDiagnostics = [];
    const commands: RegisteredCommand[] = [];
    const commandOwners = new Map<string, string>();
    for (const ext of this.extensions) {
      for (const command of ext.commands.values()) {
        if (reserved?.has(command.name)) {
          const message = `Extension command '${command.name}' from ${ext.path} conflicts with built-in commands. Skipping.`;
          this.commandDiagnostics.push({ type: "warning", message, path: ext.path });
          if (!this.hasUI()) {
            console.warn(message);
          }
          continue;
        }
        const existingOwner = commandOwners.get(command.name);
        if (existingOwner) {
          const message = `Extension command '${command.name}' from ${ext.path} conflicts with ${existingOwner}. Skipping.`;
          this.commandDiagnostics.push({ type: "warning", message, path: ext.path });
          if (!this.hasUI()) {
            console.warn(message);
          }
          continue;
        }
        commandOwners.set(command.name, ext.path);
        commands.push(command);
      }
    }
    return commands;
  }

  /** Diagnostics from the most recent getRegisteredCommands() call. */
  getCommandDiagnostics(): ResourceDiagnostic[] {
    return this.commandDiagnostics;
  }

  /** All commands with their owning extension path, without conflict filtering. */
  getRegisteredCommandsWithPaths(): Array<{ command: RegisteredCommand; extensionPath: string }> {
    const result: Array<{ command: RegisteredCommand; extensionPath: string }> = [];
    for (const ext of this.extensions) {
      for (const command of ext.commands.values()) {
        result.push({ command, extensionPath: ext.path });
      }
    }
    return result;
  }

  /** First command registered under `name` across all extensions, if any. */
  getCommand(name: string): RegisteredCommand | undefined {
    for (const ext of this.extensions) {
      const command = ext.commands.get(name);
      if (command) {
        return command;
      }
    }
    return undefined;
  }

  /**
   * Request a graceful shutdown. Called by extension tools and event handlers.
   * The actual shutdown behavior is provided by the mode via bindExtensions().
   */
  shutdown(): void {
    this.shutdownHandler();
  }

  /**
   * Create an ExtensionContext for use in event handlers and tool execution.
   * Context values are resolved at call time, so changes via bindCore/bindUI are reflected.
   */
  createContext(): ExtensionContext {
    // Captured locally so the `model` getter below does not depend on `this`
    // binding of the returned object.
    const getModel = this.getModel;
    return {
      ui: this.uiContext,
      hasUI: this.hasUI(),
      cwd: this.cwd,
      sessionManager: this.sessionManager,
      modelRegistry: this.modelRegistry,
      get model() {
        return getModel();
      },
      isIdle: () => this.isIdleFn(),
      abort: () => this.abortFn(),
      hasPendingMessages: () => this.hasPendingMessagesFn(),
      shutdown: () => this.shutdownHandler(),
      getContextUsage: () => this.getContextUsageFn(),
      compact: (options) => this.compactFn(options),
      getSystemPrompt: () => this.getSystemPromptFn(),
    };
  }

  /** Like createContext(), extended with command-only actions. */
  createCommandContext(): ExtensionCommandContext {
    return {
      ...this.createContext(),
      waitForIdle: () => this.waitForIdleFn(),
      newSession: (options) => this.newSessionHandler(options),
      fork: (entryId) => this.forkHandler(entryId),
      navigateTree: (targetId, options) => this.navigateTreeHandler(targetId, options),
      switchSession: (sessionPath) => this.switchSessionHandler(sessionPath),
      reload: () => this.reloadHandler(),
    };
  }

  // Type guard: the session_before_* events support cancellation via their
  // handler results.
  private isSessionBeforeEvent(event: RunnerEmitEvent): event is SessionBeforeEvent {
    return (
      event.type === "session_before_switch" ||
      event.type === "session_before_fork" ||
      event.type === "session_before_compact" ||
      event.type === "session_before_tree"
    );
  }

  /**
   * Dispatch an event to every handler of every extension.
   * For session_before_* events the last non-undefined handler result wins,
   * except a result with cancel=true short-circuits immediately. Handler
   * exceptions are funneled through emitError and do not stop dispatch.
   */
  async emit<TEvent extends RunnerEmitEvent>(event: TEvent): Promise<RunnerEmitResult<TEvent>> {
    const ctx = this.createContext();
    let result: SessionBeforeEventResult | undefined;
    for (const ext of this.extensions) {
      const handlers = ext.handlers.get(event.type);
      if (!handlers || handlers.length === 0) continue;
      for (const handler of handlers) {
        try {
          const handlerResult = await handler(event, ctx);
          if (this.isSessionBeforeEvent(event) && handlerResult) {
            result = handlerResult as SessionBeforeEventResult;
            if (result.cancel) {
              return result as RunnerEmitResult<TEvent>;
            }
          }
        } catch (err) {
          const message = err instanceof Error ? err.message : String(err);
          const stack = err instanceof Error ? err.stack : undefined;
          this.emitError({
            extensionPath: ext.path,
            event: event.type,
            error: message,
            stack,
          });
        }
      }
    }
    return result as RunnerEmitResult<TEvent>;
  }

  /**
   * Dispatch tool_result, letting handlers rewrite content/details/isError.
   * Handlers chain: each sees the previous handler's modifications. Returns
   * undefined when no handler modified anything.
   */
  async emitToolResult(event: ToolResultEvent): Promise<ToolResultEventResult | undefined> {
    const ctx = this.createContext();
    const currentEvent: ToolResultEvent = { ...event };
    let modified = false;
    for (const ext of this.extensions) {
      const handlers = ext.handlers.get("tool_result");
      if (!handlers || handlers.length === 0) continue;
      for (const handler of handlers) {
        try {
          const handlerResult = (await handler(currentEvent, ctx)) as ToolResultEventResult | undefined;
          if (!handlerResult) continue;
          if (handlerResult.content !== undefined) {
            currentEvent.content = handlerResult.content;
            modified = true;
          }
          if (handlerResult.details !== undefined) {
            currentEvent.details = handlerResult.details;
            modified = true;
          }
          if (handlerResult.isError !== undefined) {
            currentEvent.isError = handlerResult.isError;
            modified = true;
          }
        } catch (err) {
          const message = err instanceof Error ? err.message : String(err);
          const stack = err instanceof Error ? err.stack : undefined;
          this.emitError({
            extensionPath: ext.path,
            event: "tool_result",
            error: message,
            stack,
          });
        }
      }
    }
    if (!modified) {
      return undefined;
    }
    return {
      content: currentEvent.content,
      details: currentEvent.details,
      isError: currentEvent.isError,
    };
  }

  /**
   * Dispatch tool_call. The last handler result wins; a result with
   * block=true short-circuits immediately.
   * NOTE(review): unlike every other emit* method, handler exceptions are NOT
   * caught here — they propagate to the caller (wrapToolWithExtensions treats
   * them as blocking the tool). Confirm this fail-closed behavior is intended.
   */
  async emitToolCall(event: ToolCallEvent): Promise<ToolCallEventResult | undefined> {
    const ctx = this.createContext();
    let result: ToolCallEventResult | undefined;
    for (const ext of this.extensions) {
      const handlers = ext.handlers.get("tool_call");
      if (!handlers || handlers.length === 0) continue;
      for (const handler of handlers) {
        const handlerResult = await handler(event, ctx);
        if (handlerResult) {
          result = handlerResult as ToolCallEventResult;
          if (result.block) {
            return result;
          }
        }
      }
    }
    return result;
  }

  /**
   * Dispatch user_bash. The first handler that returns a result wins;
   * handler exceptions are reported via emitError and skipped.
   */
  async emitUserBash(event: UserBashEvent): Promise<UserBashEventResult | undefined> {
    const ctx = this.createContext();
    for (const ext of this.extensions) {
      const handlers = ext.handlers.get("user_bash");
      if (!handlers || handlers.length === 0) continue;
      for (const handler of handlers) {
        try {
          const handlerResult = await handler(event, ctx);
          if (handlerResult) {
            return handlerResult as UserBashEventResult;
          }
        } catch (err) {
          const message = err instanceof Error ? err.message : String(err);
          const stack = err instanceof Error ? err.stack : undefined;
          this.emitError({
            extensionPath: ext.path,
            event: "user_bash",
            error: message,
            stack,
          });
        }
      }
    }
    return undefined;
  }

  /**
   * Let extensions transform the LLM context. Handlers chain on a deep clone
   * of the input, so callers' message arrays are never mutated.
   */
  async emitContext(messages: AgentMessage[]): Promise<AgentMessage[]> {
    const ctx = this.createContext();
    let currentMessages = structuredClone(messages);
    for (const ext of this.extensions) {
      const handlers = ext.handlers.get("context");
      if (!handlers || handlers.length === 0) continue;
      for (const handler of handlers) {
        try {
          const event: ContextEvent = { type: "context", messages: currentMessages };
          const handlerResult = await handler(event, ctx);
          if (handlerResult && (handlerResult as ContextEventResult).messages) {
            currentMessages = (handlerResult as ContextEventResult).messages!;
          }
        } catch (err) {
          const message = err instanceof Error ? err.message : String(err);
          const stack = err instanceof Error ? err.stack : undefined;
          this.emitError({
            extensionPath: ext.path,
            event: "context",
            error: message,
            stack,
          });
        }
      }
    }
    return currentMessages;
  }

  /**
   * Let extensions rewrite the raw provider request payload. Handlers chain;
   * any non-undefined return replaces the payload for subsequent handlers.
   */
  async emitBeforeProviderRequest(payload: unknown): Promise<unknown> {
    const ctx = this.createContext();
    let currentPayload = payload;
    for (const ext of this.extensions) {
      const handlers = ext.handlers.get("before_provider_request");
      if (!handlers || handlers.length === 0) continue;
      for (const handler of handlers) {
        try {
          const event: BeforeProviderRequestEvent = {
            type: "before_provider_request",
            payload: currentPayload,
          };
          const handlerResult = await handler(event, ctx);
          if (handlerResult !== undefined) {
            currentPayload = handlerResult;
          }
        } catch (err) {
          const message = err instanceof Error ? err.message : String(err);
          const stack = err instanceof Error ? err.stack : undefined;
          this.emitError({
            extensionPath: ext.path,
            event: "before_provider_request",
            error: message,
            stack,
          });
        }
      }
    }
    return currentPayload;
  }

  /**
   * Dispatch before_agent_start. Collects every handler-provided message and
   * chains system-prompt rewrites. Returns undefined when no handler
   * contributed anything.
   */
  async emitBeforeAgentStart(
    prompt: string,
    images: ImageContent[] | undefined,
    systemPrompt: string,
  ): Promise<BeforeAgentStartCombinedResult | undefined> {
    const ctx = this.createContext();
    const messages: NonNullable<BeforeAgentStartEventResult["message"]>[] = [];
    let currentSystemPrompt = systemPrompt;
    let systemPromptModified = false;
    for (const ext of this.extensions) {
      const handlers = ext.handlers.get("before_agent_start");
      if (!handlers || handlers.length === 0) continue;
      for (const handler of handlers) {
        try {
          const event: BeforeAgentStartEvent = {
            type: "before_agent_start",
            prompt,
            images,
            systemPrompt: currentSystemPrompt,
          };
          const handlerResult = await handler(event, ctx);
          if (handlerResult) {
            const result = handlerResult as BeforeAgentStartEventResult;
            if (result.message) {
              messages.push(result.message);
            }
            if (result.systemPrompt !== undefined) {
              currentSystemPrompt = result.systemPrompt;
              systemPromptModified = true;
            }
          }
        } catch (err) {
          const message = err instanceof Error ? err.message : String(err);
          const stack = err instanceof Error ? err.stack : undefined;
          this.emitError({
            extensionPath: ext.path,
            event: "before_agent_start",
            error: message,
            stack,
          });
        }
      }
    }
    if (messages.length > 0 || systemPromptModified) {
      return {
        messages: messages.length > 0 ? messages : undefined,
        systemPrompt: systemPromptModified ? currentSystemPrompt : undefined,
      };
    }
    return undefined;
  }

  /**
   * Ask extensions for additional skill/prompt/theme paths. All contributions
   * are accumulated and tagged with the contributing extension's path.
   */
  async emitResourcesDiscover(
    cwd: string,
    reason: ResourcesDiscoverEvent["reason"],
  ): Promise<{
    skillPaths: Array<{ path: string; extensionPath: string }>;
    promptPaths: Array<{ path: string; extensionPath: string }>;
    themePaths: Array<{ path: string; extensionPath: string }>;
  }> {
    const ctx = this.createContext();
    const skillPaths: Array<{ path: string; extensionPath: string }> = [];
    const promptPaths: Array<{ path: string; extensionPath: string }> = [];
    const themePaths: Array<{ path: string; extensionPath: string }> = [];
    for (const ext of this.extensions) {
      const handlers = ext.handlers.get("resources_discover");
      if (!handlers || handlers.length === 0) continue;
      for (const handler of handlers) {
        try {
          const event: ResourcesDiscoverEvent = { type: "resources_discover", cwd, reason };
          const handlerResult = await handler(event, ctx);
          const result = handlerResult as ResourcesDiscoverResult | undefined;
          if (result?.skillPaths?.length) {
            skillPaths.push(...result.skillPaths.map((path) => ({ path, extensionPath: ext.path })));
          }
          if (result?.promptPaths?.length) {
            promptPaths.push(...result.promptPaths.map((path) => ({ path, extensionPath: ext.path })));
          }
          if (result?.themePaths?.length) {
            themePaths.push(...result.themePaths.map((path) => ({ path, extensionPath: ext.path })));
          }
        } catch (err) {
          const message = err instanceof Error ? err.message : String(err);
          const stack = err instanceof Error ? err.stack : undefined;
          this.emitError({
            extensionPath: ext.path,
            event: "resources_discover",
            error: message,
            stack,
          });
        }
      }
    }
    return { skillPaths, promptPaths, themePaths };
  }

  /** Emit input event. Transforms chain, "handled" short-circuits. */
  async emitInput(text: string, images: ImageContent[] | undefined, source: InputSource): Promise<InputEventResult> {
    const ctx = this.createContext();
    let currentText = text;
    let currentImages = images;
    for (const ext of this.extensions) {
      for (const handler of ext.handlers.get("input") ?? []) {
        try {
          const event: InputEvent = { type: "input", text: currentText, images: currentImages, source };
          const result = (await handler(event, ctx)) as InputEventResult | undefined;
          if (result?.action === "handled") return result;
          if (result?.action === "transform") {
            currentText = result.text;
            currentImages = result.images ?? currentImages;
          }
        } catch (err) {
          this.emitError({
            extensionPath: ext.path,
            event: "input",
            error: err instanceof Error ? err.message : String(err),
            stack: err instanceof Error ? err.stack : undefined,
          });
        }
      }
    }
    // Report a transform only if something actually changed (reference
    // comparison for images: untouched means same array instance).
    return currentText !== text || currentImages !== images
      ? { action: "transform", text: currentText, images: currentImages }
      : { action: "continue" };
  }
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,118 @@
/**
* Tool wrappers for extensions.
*/
import type { AgentTool, AgentToolUpdateCallback } from "@gsd/pi-agent-core";
import type { ExtensionRunner } from "./runner.js";
import type { RegisteredTool, ToolCallEventResult } from "./types.js";
/**
 * Adapt a RegisteredTool to the AgentTool interface.
 * The forwarded execute call receives a freshly built extension context from
 * the runner, so tools observe the same context as event handlers.
 */
export function wrapRegisteredTool(registeredTool: RegisteredTool, runner: ExtensionRunner): AgentTool {
  const { name, label, description, parameters, execute } = registeredTool.definition;
  return {
    name,
    label,
    description,
    parameters,
    execute: (toolCallId, params, signal, onUpdate) =>
      execute(toolCallId, params, signal, onUpdate, runner.createContext()),
  };
}
/**
 * Adapt every registered tool to an AgentTool.
 * Each wrapper resolves its context through the runner at call time, keeping
 * tool execution consistent with event-handler contexts.
 */
export function wrapRegisteredTools(registeredTools: RegisteredTool[], runner: ExtensionRunner): AgentTool[] {
  const wrapped: AgentTool[] = [];
  for (const registered of registeredTools) {
    wrapped.push(wrapRegisteredTool(registered, runner));
  }
  return wrapped;
}
/**
 * Wrap a tool with extension callbacks for interception.
 * - Emits tool_call event before execution (can block)
 * - Emits tool_result event after execution (can modify result)
 */
export function wrapToolWithExtensions<T>(tool: AgentTool<any, T>, runner: ExtensionRunner): AgentTool<any, T> {
  return {
    ...tool,
    execute: async (
      toolCallId: string,
      params: Record<string, unknown>,
      signal?: AbortSignal,
      onUpdate?: AgentToolUpdateCallback<T>,
    ) => {
      // Emit tool_call event - extensions can block execution
      if (runner.hasHandlers("tool_call")) {
        try {
          const callResult = (await runner.emitToolCall({
            type: "tool_call",
            toolName: tool.name,
            toolCallId,
            input: params,
          })) as ToolCallEventResult | undefined;
          if (callResult?.block) {
            const reason = callResult.reason || "Tool execution was blocked by an extension";
            throw new Error(reason);
          }
        } catch (err) {
          // NOTE(review): the block-Error thrown just above is rethrown here
          // unchanged, so a deliberate block and an extension crash inside
          // emitToolCall are indistinguishable to the caller — both fail
          // closed (the tool never runs). Confirm that is intentional.
          if (err instanceof Error) {
            throw err;
          }
          throw new Error(`Extension failed, blocking execution: ${String(err)}`);
        }
      }
      // Execute the actual tool
      try {
        const result = await tool.execute(toolCallId, params, signal, onUpdate);
        // Emit tool_result event - extensions can modify the result
        if (runner.hasHandlers("tool_result")) {
          const resultResult = await runner.emitToolResult({
            type: "tool_result",
            toolName: tool.name,
            toolCallId,
            input: params,
            content: result.content,
            details: result.details,
            isError: false,
          });
          if (resultResult) {
            // NOTE(review): resultResult.isError is dropped here — the merged
            // result keeps only content/details, so an extension cannot flip
            // a success into an error at this point. Verify against the
            // AgentToolResult contract.
            return {
              content: resultResult.content ?? result.content,
              details: (resultResult.details ?? result.details) as T,
            };
          }
        }
        return result;
      } catch (err) {
        // Emit tool_result event for errors
        if (runner.hasHandlers("tool_result")) {
          // Error path is report-only: handler modifications are ignored and
          // the original error is rethrown below.
          await runner.emitToolResult({
            type: "tool_result",
            toolName: tool.name,
            toolCallId,
            input: params,
            content: [{ type: "text", text: err instanceof Error ? err.message : String(err) }],
            details: undefined,
            isError: true,
          });
        }
        throw err;
      }
    },
  };
}
/**
 * Apply wrapToolWithExtensions to every tool in the list.
 */
export function wrapToolsWithExtensions<T>(tools: AgentTool<any, T>[], runner: ExtensionRunner): AgentTool<any, T>[] {
  const intercepted: AgentTool<any, T>[] = [];
  for (const tool of tools) {
    intercepted.push(wrapToolWithExtensions(tool, runner));
  }
  return intercepted;
}

View file

@ -0,0 +1,144 @@
import { existsSync, type FSWatcher, readFileSync, statSync, watch } from "fs";
import { dirname, join, resolve } from "path";
/**
 * Find the git HEAD path by walking up from cwd.
 * Handles both regular git repos (.git is a directory) and worktrees
 * (.git is a file containing a "gitdir: <path>" pointer).
 * Returns null when no repository is found or an fs error occurs.
 */
function findGitHeadPath(): string | null {
  for (let dir = process.cwd(); ; ) {
    const gitPath = join(dir, ".git");
    if (existsSync(gitPath)) {
      try {
        const info = statSync(gitPath);
        if (info.isDirectory()) {
          // Regular repository: HEAD lives directly inside .git/
          const headPath = join(gitPath, "HEAD");
          if (existsSync(headPath)) return headPath;
        } else if (info.isFile()) {
          // Worktree: .git is a pointer file to the real git dir.
          const pointer = readFileSync(gitPath, "utf8").trim();
          if (pointer.startsWith("gitdir: ")) {
            const headPath = resolve(dir, pointer.slice("gitdir: ".length), "HEAD");
            if (existsSync(headPath)) return headPath;
          }
        }
      } catch {
        return null;
      }
    }
    const parent = dirname(dir);
    if (parent === dir) return null; // reached filesystem root
    dir = parent;
  }
}
/**
* Provides git branch and extension statuses - data not otherwise accessible to extensions.
* Token stats, model info available via ctx.sessionManager and ctx.model.
*/
export class FooterDataProvider {
private extensionStatuses = new Map<string, string>();
private cachedBranch: string | null | undefined = undefined;
private gitWatcher: FSWatcher | null = null;
private branchChangeCallbacks = new Set<() => void>();
private availableProviderCount = 0;
constructor() {
this.setupGitWatcher();
}
/** Current git branch, null if not in repo, "detached" if detached HEAD */
getGitBranch(): string | null {
if (this.cachedBranch !== undefined) return this.cachedBranch;
try {
const gitHeadPath = findGitHeadPath();
if (!gitHeadPath) {
this.cachedBranch = null;
return null;
}
const content = readFileSync(gitHeadPath, "utf8").trim();
this.cachedBranch = content.startsWith("ref: refs/heads/") ? content.slice(16) : "detached";
} catch {
this.cachedBranch = null;
}
return this.cachedBranch;
}
/** Extension status texts set via ctx.ui.setStatus() */
getExtensionStatuses(): ReadonlyMap<string, string> {
return this.extensionStatuses;
}
/** Subscribe to git branch changes. Returns unsubscribe function. */
onBranchChange(callback: () => void): () => void {
this.branchChangeCallbacks.add(callback);
return () => this.branchChangeCallbacks.delete(callback);
}
/** Internal: set extension status */
setExtensionStatus(key: string, text: string | undefined): void {
if (text === undefined) {
this.extensionStatuses.delete(key);
} else {
this.extensionStatuses.set(key, text);
}
}
/** Internal: clear extension statuses */
clearExtensionStatuses(): void {
this.extensionStatuses.clear();
}
/** Number of unique providers with available models (for footer display) */
getAvailableProviderCount(): number {
return this.availableProviderCount;
}
/** Internal: update available provider count */
setAvailableProviderCount(count: number): void {
this.availableProviderCount = count;
}
/** Internal: cleanup */
dispose(): void {
if (this.gitWatcher) {
this.gitWatcher.close();
this.gitWatcher = null;
}
this.branchChangeCallbacks.clear();
}
private setupGitWatcher(): void {
if (this.gitWatcher) {
this.gitWatcher.close();
this.gitWatcher = null;
}
const gitHeadPath = findGitHeadPath();
if (!gitHeadPath) return;
// Watch the directory containing HEAD, not HEAD itself.
// Git uses atomic writes (write temp, rename over HEAD), which changes the inode.
// fs.watch on a file stops working after the inode changes.
const gitDir = dirname(gitHeadPath);
try {
this.gitWatcher = watch(gitDir, (_eventType, filename) => {
if (filename === "HEAD") {
this.cachedBranch = undefined;
for (const cb of this.branchChangeCallbacks) cb();
}
});
} catch {
// Silently fail if we can't watch
}
}
}
/**
 * Read-only view for extensions - excludes setExtensionStatus,
 * setAvailableProviderCount and dispose, so extension code can read footer
 * data but not mutate host state.
 */
export type ReadonlyFooterDataProvider = Pick<
  FooterDataProvider,
  "getGitBranch" | "getExtensionStatuses" | "getAvailableProviderCount" | "onBranchChange"
>;

View file

@ -0,0 +1,61 @@
/**
* Core modules shared between all run modes.
*/
export {
AgentSession,
type AgentSessionConfig,
type AgentSessionEvent,
type AgentSessionEventListener,
type ModelCycleResult,
type PromptOptions,
type SessionStats,
} from "./agent-session.js";
export { type BashExecutorOptions, type BashResult, executeBash, executeBashWithOperations } from "./bash-executor.js";
export type { CompactionResult } from "./compaction/index.js";
export { createEventBus, type EventBus, type EventBusController } from "./event-bus.js";
// Extensions system
export {
type AgentEndEvent,
type AgentStartEvent,
type AgentToolResult,
type AgentToolUpdateCallback,
type BeforeAgentStartEvent,
type ContextEvent,
discoverAndLoadExtensions,
type ExecOptions,
type ExecResult,
type Extension,
type ExtensionAPI,
type ExtensionCommandContext,
type ExtensionContext,
type ExtensionError,
type ExtensionEvent,
type ExtensionFactory,
type ExtensionFlag,
type ExtensionHandler,
ExtensionRunner,
type ExtensionShortcut,
type ExtensionUIContext,
type LoadExtensionsResult,
type MessageRenderer,
type RegisteredCommand,
type SessionBeforeCompactEvent,
type SessionBeforeForkEvent,
type SessionBeforeSwitchEvent,
type SessionBeforeTreeEvent,
type SessionCompactEvent,
type SessionForkEvent,
type SessionShutdownEvent,
type SessionStartEvent,
type SessionSwitchEvent,
type SessionTreeEvent,
type ToolCallEvent,
type ToolDefinition,
type ToolRenderResultOptions,
type ToolResultEvent,
type TurnEndEvent,
type TurnStartEvent,
wrapToolsWithExtensions,
} from "./extensions/index.js";

View file

@ -0,0 +1,211 @@
import {
DEFAULT_EDITOR_KEYBINDINGS,
type EditorAction,
type EditorKeybindingsConfig,
EditorKeybindingsManager,
type KeyId,
matchesKey,
setEditorKeybindings,
} from "@gsd/pi-tui";
import { existsSync, readFileSync } from "fs";
import { join } from "path";
import { getAgentDir } from "../config.js";
/**
 * Application-level actions (coding agent specific).
 * Editor-level actions are defined separately as EditorAction in @gsd/pi-tui.
 */
export type AppAction =
  | "interrupt"
  | "clear"
  | "exit"
  | "suspend"
  | "cycleThinkingLevel"
  | "cycleModelForward"
  | "cycleModelBackward"
  | "selectModel"
  | "expandTools"
  | "toggleThinking"
  | "toggleSessionNamedFilter"
  | "externalEditor"
  | "followUp"
  | "dequeue"
  | "pasteImage"
  | "newSession"
  | "tree"
  | "fork"
  | "resume";
/**
 * All configurable actions.
 */
export type KeyAction = AppAction | EditorAction;
/**
 * Full keybindings configuration (app + editor actions).
 * Each action maps to a single key id or a list of alternatives; omitted
 * actions fall back to the defaults.
 */
export type KeybindingsConfig = {
  [K in KeyAction]?: KeyId | KeyId[];
};
/**
 * Default application keybindings.
 * An empty array means the action has no default binding and must be
 * configured by the user to be usable.
 */
export const DEFAULT_APP_KEYBINDINGS: Record<AppAction, KeyId | KeyId[]> = {
  interrupt: "escape",
  clear: "ctrl+c",
  exit: "ctrl+d",
  suspend: "ctrl+z",
  cycleThinkingLevel: "shift+tab",
  cycleModelForward: "ctrl+p",
  cycleModelBackward: "shift+ctrl+p",
  selectModel: "ctrl+l",
  expandTools: "ctrl+o",
  toggleThinking: "ctrl+t",
  toggleSessionNamedFilter: "ctrl+n",
  externalEditor: "ctrl+g",
  followUp: "alt+enter",
  dequeue: "alt+up",
  // alt+v on Windows — presumably to avoid the native ctrl+v paste chord
  // there; confirm before changing.
  pasteImage: process.platform === "win32" ? "alt+v" : "ctrl+v",
  newSession: [],
  tree: [],
  fork: [],
  resume: [],
};
/**
 * All default keybindings (app + editor).
 */
export const DEFAULT_KEYBINDINGS: Required<KeybindingsConfig> = {
  ...DEFAULT_EDITOR_KEYBINDINGS,
  ...DEFAULT_APP_KEYBINDINGS,
};
// App actions list for type checking
const APP_ACTIONS: AppAction[] = [
  "interrupt",
  "clear",
  "exit",
  "suspend",
  "cycleThinkingLevel",
  "cycleModelForward",
  "cycleModelBackward",
  "selectModel",
  "expandTools",
  "toggleThinking",
  "toggleSessionNamedFilter",
  "externalEditor",
  "followUp",
  "dequeue",
  "pasteImage",
  "newSession",
  "tree",
  "fork",
  "resume",
];
// Set view of APP_ACTIONS: isAppAction is called per config entry in hot
// config-processing loops, so prefer O(1) membership over Array.includes.
const APP_ACTION_SET: ReadonlySet<string> = new Set(APP_ACTIONS);

/**
 * Type guard: true when `action` names an app-level (non-editor) action.
 */
function isAppAction(action: string): action is AppAction {
  return APP_ACTION_SET.has(action);
}
/**
 * Manages all keybindings (app + editor).
 *
 * App-action bindings are resolved through this class; editor-action bindings
 * are installed globally into the editor via setEditorKeybindings() during
 * create().
 */
export class KeybindingsManager {
  private config: KeybindingsConfig;
  // Effective app-action bindings: defaults overlaid with user config.
  private appActionToKeys: Map<AppAction, KeyId[]>;

  private constructor(config: KeybindingsConfig) {
    this.config = config;
    this.appActionToKeys = new Map();
    this.buildMaps();
  }

  /**
   * Create from config file and set up editor keybindings.
   */
  static create(agentDir: string = getAgentDir()): KeybindingsManager {
    const configPath = join(agentDir, "keybindings.json");
    const config = KeybindingsManager.loadFromFile(configPath);
    const manager = new KeybindingsManager(config);
    // Set up editor keybindings globally
    // Include both editor actions and expandTools (shared between app and editor)
    const editorConfig: EditorKeybindingsConfig = {};
    for (const [action, keys] of Object.entries(config)) {
      if (!isAppAction(action) || action === "expandTools") {
        editorConfig[action as EditorAction] = keys;
      }
    }
    setEditorKeybindings(new EditorKeybindingsManager(editorConfig));
    return manager;
  }

  /**
   * Create in-memory (no file access, no editor keybinding setup).
   */
  static inMemory(config: KeybindingsConfig = {}): KeybindingsManager {
    return new KeybindingsManager(config);
  }

  /**
   * Load and validate the keybindings config file.
   * Returns {} when the file is missing, unparseable, or not a JSON object.
   */
  private static loadFromFile(path: string): KeybindingsConfig {
    if (!existsSync(path)) return {};
    try {
      const parsed: unknown = JSON.parse(readFileSync(path, "utf-8"));
      // JSON.parse accepts non-object documents (null, numbers, strings,
      // arrays). Passing those through would crash Object.entries() callers
      // downstream (e.g. Object.entries(null) throws), so accept plain
      // objects only.
      if (parsed !== null && typeof parsed === "object" && !Array.isArray(parsed)) {
        return parsed as KeybindingsConfig;
      }
      return {};
    } catch {
      return {};
    }
  }

  /** Rebuild the app-action map: defaults first, then user overrides. */
  private buildMaps(): void {
    this.appActionToKeys.clear();
    // Set defaults for app actions
    for (const [action, keys] of Object.entries(DEFAULT_APP_KEYBINDINGS)) {
      const keyArray = Array.isArray(keys) ? keys : [keys];
      this.appActionToKeys.set(action as AppAction, [...keyArray]);
    }
    // Override with user config (app actions only)
    for (const [action, keys] of Object.entries(this.config)) {
      if (keys === undefined || !isAppAction(action)) continue;
      const keyArray = Array.isArray(keys) ? keys : [keys];
      this.appActionToKeys.set(action, keyArray);
    }
  }

  /**
   * Check if raw terminal input matches any key bound to an app action.
   */
  matches(data: string, action: AppAction): boolean {
    const keys = this.appActionToKeys.get(action);
    if (!keys) return false;
    for (const key of keys) {
      if (matchesKey(data, key)) return true;
    }
    return false;
  }

  /**
   * Get keys bound to an app action (empty array when unbound).
   */
  getKeys(action: AppAction): KeyId[] {
    return this.appActionToKeys.get(action) ?? [];
  }

  /**
   * Get the full effective config (defaults overlaid with user config,
   * app and editor actions alike).
   */
  getEffectiveConfig(): Required<KeybindingsConfig> {
    const result = { ...DEFAULT_KEYBINDINGS };
    for (const [action, keys] of Object.entries(this.config)) {
      if (keys !== undefined) {
        (result as KeybindingsConfig)[action as KeyAction] = keys;
      }
    }
    return result;
  }
}
// Re-export for convenience
export type { EditorAction, KeyId };

View file

@ -0,0 +1,195 @@
/**
* Custom message types and transformers for the coding agent.
*
* Extends the base AgentMessage type with coding-agent specific message types,
* and provides a transformer to convert them to LLM-compatible messages.
*/
import type { AgentMessage } from "@gsd/pi-agent-core";
import type { ImageContent, Message, TextContent } from "@gsd/pi-ai";
/** Wrapper text placed before a compaction summary when rendered for the LLM. */
export const COMPACTION_SUMMARY_PREFIX = `The conversation history before this point was compacted into the following summary:
<summary>
`;
/** Closing wrapper paired with COMPACTION_SUMMARY_PREFIX. */
export const COMPACTION_SUMMARY_SUFFIX = `
</summary>`;
/** Wrapper text placed before a branch summary when rendered for the LLM. */
export const BRANCH_SUMMARY_PREFIX = `The following is a summary of a branch that this conversation came back from:
<summary>
`;
/** Closing wrapper paired with BRANCH_SUMMARY_PREFIX. */
export const BRANCH_SUMMARY_SUFFIX = `</summary>`;
/**
 * Message type for bash executions via the ! command.
 */
export interface BashExecutionMessage {
  role: "bashExecution";
  /** The shell command that was executed. */
  command: string;
  /** Captured command output (may be cut short — see `truncated`). */
  output: string;
  /** Process exit code; undefined when no exit code was captured. */
  exitCode: number | undefined;
  /** True if the command was cancelled before completing. */
  cancelled: boolean;
  /** True if `output` was truncated; full output may live at `fullOutputPath`. */
  truncated: boolean;
  /** Path to the complete output on disk, set when truncation occurred. */
  fullOutputPath?: string;
  /** Creation time in epoch milliseconds. */
  timestamp: number;
  /** If true, this message is excluded from LLM context (!! prefix) */
  excludeFromContext?: boolean;
}
/**
 * Message type for extension-injected messages via sendMessage().
 * These are custom messages that extensions can inject into the conversation.
 */
export interface CustomMessage<T = unknown> {
  role: "custom";
  /** Extension-chosen discriminator for this custom message. */
  customType: string;
  /** Plain text, or mixed text/image content blocks. */
  content: string | (TextContent | ImageContent)[];
  /** Whether the message should be rendered in the UI. */
  display: boolean;
  /** Optional structured payload attached by the extension. */
  details?: T;
  /** Creation time in epoch milliseconds. */
  timestamp: number;
}
/** Summary of a conversation branch that the session returned from. */
export interface BranchSummaryMessage {
  role: "branchSummary";
  /** Summary text; wrapped with BRANCH_SUMMARY_PREFIX/SUFFIX for the LLM. */
  summary: string;
  /** Identifier of the branch the conversation came back from. */
  fromId: string;
  /** Creation time in epoch milliseconds. */
  timestamp: number;
}
/** Summary that replaced compacted (pruned) conversation history. */
export interface CompactionSummaryMessage {
  role: "compactionSummary";
  /** Summary text; wrapped with COMPACTION_SUMMARY_PREFIX/SUFFIX for the LLM. */
  summary: string;
  /** Token count of the history that was compacted away. */
  tokensBefore: number;
  /** Creation time in epoch milliseconds. */
  timestamp: number;
}
// Extend CustomAgentMessages via declaration merging so the agent core's
// AgentMessage union includes the four coding-agent message roles above.
declare module "@gsd/pi-agent-core" {
  interface CustomAgentMessages {
    bashExecution: BashExecutionMessage;
    custom: CustomMessage;
    branchSummary: BranchSummaryMessage;
    compactionSummary: CompactionSummaryMessage;
  }
}
/**
 * Convert a BashExecutionMessage to user message text for LLM context.
 *
 * Renders the command, its (fenced) output or a "(no output)" placeholder,
 * then appends cancellation / non-zero-exit / truncation notes as applicable.
 */
export function bashExecutionToText(msg: BashExecutionMessage): string {
  const parts: string[] = [`Ran \`${msg.command}\`\n`];
  parts.push(msg.output ? `\`\`\`\n${msg.output}\n\`\`\`` : "(no output)");
  if (msg.cancelled) {
    parts.push("\n\n(command cancelled)");
  } else if (msg.exitCode !== null && msg.exitCode !== undefined && msg.exitCode !== 0) {
    // Exit code only mentioned for completed, failing commands.
    parts.push(`\n\nCommand exited with code ${msg.exitCode}`);
  }
  if (msg.truncated && msg.fullOutputPath) {
    parts.push(`\n\n[Output truncated. Full output: ${msg.fullOutputPath}]`);
  }
  return parts.join("");
}
/** Build a BranchSummaryMessage, converting the ISO timestamp to epoch millis. */
export function createBranchSummaryMessage(summary: string, fromId: string, timestamp: string): BranchSummaryMessage {
  const epochMs = new Date(timestamp).getTime();
  return { role: "branchSummary", summary, fromId, timestamp: epochMs };
}
/** Build a CompactionSummaryMessage, converting the ISO timestamp to epoch millis. */
export function createCompactionSummaryMessage(
  summary: string,
  tokensBefore: number,
  timestamp: string,
): CompactionSummaryMessage {
  const epochMs = new Date(timestamp).getTime();
  return { role: "compactionSummary", summary, tokensBefore, timestamp: epochMs };
}
/** Convert CustomMessageEntry to AgentMessage format */
export function createCustomMessage(
  customType: string,
  content: string | (TextContent | ImageContent)[],
  display: boolean,
  details: unknown | undefined,
  timestamp: string,
): CustomMessage {
  // The ISO timestamp string is normalized to epoch milliseconds.
  const parsedTimestamp = new Date(timestamp).getTime();
  const message: CustomMessage = {
    role: "custom",
    customType,
    content,
    display,
    details,
    timestamp: parsedTimestamp,
  };
  return message;
}
/**
 * Transform AgentMessages (including custom types) to LLM-compatible Messages.
 *
 * This is used by:
 * - Agent's transformToLlm option (for prompt calls and queued messages)
 * - Compaction's generateSummary (for summarization)
 * - Custom extensions and tools
 *
 * Custom roles are rendered as user messages; plain user/assistant/toolResult
 * messages pass through untouched. Excluded bash executions are dropped.
 */
export function convertToLlm(messages: AgentMessage[]): Message[] {
  return messages
    .map((m): Message | undefined => {
      switch (m.role) {
        case "bashExecution":
          // Skip messages excluded from context (!! prefix)
          if (m.excludeFromContext) {
            return undefined;
          }
          return {
            role: "user",
            content: [{ type: "text", text: bashExecutionToText(m) }],
            timestamp: m.timestamp,
          };
        case "custom": {
          // String content is wrapped in a single text block; block arrays pass through.
          const content = typeof m.content === "string" ? [{ type: "text" as const, text: m.content }] : m.content;
          return {
            role: "user",
            content,
            timestamp: m.timestamp,
          };
        }
        case "branchSummary":
          return {
            role: "user",
            content: [{ type: "text" as const, text: BRANCH_SUMMARY_PREFIX + m.summary + BRANCH_SUMMARY_SUFFIX }],
            timestamp: m.timestamp,
          };
        case "compactionSummary":
          return {
            role: "user",
            content: [
              { type: "text" as const, text: COMPACTION_SUMMARY_PREFIX + m.summary + COMPACTION_SUMMARY_SUFFIX },
            ],
            timestamp: m.timestamp,
          };
        case "user":
        case "assistant":
        case "toolResult":
          // Already LLM-compatible; pass through unchanged.
          return m;
        default: {
          // Exhaustiveness guard: adding a new custom role via declaration
          // merging without handling it above fails to compile here.
          const _exhaustiveCheck: never = m;
          return undefined;
        }
      }
    })
    // Explicit type predicate so the narrowed Message[] type is produced on
    // all supported compilers (pre-5.5 TS does not infer filter predicates).
    .filter((m): m is Message => m !== undefined);
}

View file

@ -0,0 +1,694 @@
/**
* Model registry - manages built-in and custom models, provides API key resolution.
*/
import {
type Api,
type AssistantMessageEventStream,
type Context,
getModels,
getProviders,
type KnownProvider,
type Model,
type OAuthProviderInterface,
type OpenAICompletionsCompat,
type OpenAIResponsesCompat,
registerApiProvider,
resetApiProviders,
type SimpleStreamOptions,
} from "@gsd/pi-ai";
import { registerOAuthProvider, resetOAuthProviders } from "@gsd/pi-ai/oauth";
import { type Static, Type } from "@sinclair/typebox";
import AjvModule from "ajv";
import { existsSync, readFileSync } from "fs";
import { join } from "path";
import { getAgentDir } from "../config.js";
import type { AuthStorage } from "./auth-storage.js";
import { clearConfigValueCache, resolveConfigValue, resolveHeaders } from "./resolve-config-value.js";
// ajv ships as CommonJS; under ESM interop the constructor may be exposed on
// `.default`, so fall back to the module object itself when it is not.
const Ajv = (AjvModule as any).default || AjvModule;
const ajv = new Ajv();
// Schema for OpenRouter routing preferences
const OpenRouterRoutingSchema = Type.Object({
  only: Type.Optional(Type.Array(Type.String())),
  order: Type.Optional(Type.Array(Type.String())),
});
// Schema for Vercel AI Gateway routing preferences
const VercelGatewayRoutingSchema = Type.Object({
  only: Type.Optional(Type.Array(Type.String())),
  order: Type.Optional(Type.Array(Type.String())),
});
// Schema for OpenAI compatibility settings
const OpenAICompletionsCompatSchema = Type.Object({
  supportsStore: Type.Optional(Type.Boolean()),
  supportsDeveloperRole: Type.Optional(Type.Boolean()),
  supportsReasoningEffort: Type.Optional(Type.Boolean()),
  supportsUsageInStreaming: Type.Optional(Type.Boolean()),
  maxTokensField: Type.Optional(Type.Union([Type.Literal("max_completion_tokens"), Type.Literal("max_tokens")])),
  requiresToolResultName: Type.Optional(Type.Boolean()),
  requiresAssistantAfterToolResult: Type.Optional(Type.Boolean()),
  requiresThinkingAsText: Type.Optional(Type.Boolean()),
  requiresMistralToolIds: Type.Optional(Type.Boolean()),
  thinkingFormat: Type.Optional(Type.Union([Type.Literal("openai"), Type.Literal("zai"), Type.Literal("qwen")])),
  openRouterRouting: Type.Optional(OpenRouterRoutingSchema),
  vercelGatewayRouting: Type.Optional(VercelGatewayRoutingSchema),
});
// Intentionally empty object schema — the Responses API currently needs no
// compat switches.
const OpenAIResponsesCompatSchema = Type.Object({
  // Reserved for future use
});
const OpenAICompatSchema = Type.Union([OpenAICompletionsCompatSchema, OpenAIResponsesCompatSchema]);
// Schema for custom model definition
// Most fields are optional with sensible defaults for local models (Ollama, LM Studio, etc.)
const ModelDefinitionSchema = Type.Object({
  id: Type.String({ minLength: 1 }),
  name: Type.Optional(Type.String({ minLength: 1 })),
  api: Type.Optional(Type.String({ minLength: 1 })),
  baseUrl: Type.Optional(Type.String({ minLength: 1 })),
  reasoning: Type.Optional(Type.Boolean()),
  input: Type.Optional(Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")]))),
  cost: Type.Optional(
    Type.Object({
      input: Type.Number(),
      output: Type.Number(),
      cacheRead: Type.Number(),
      cacheWrite: Type.Number(),
    }),
  ),
  contextWindow: Type.Optional(Type.Number()),
  maxTokens: Type.Optional(Type.Number()),
  headers: Type.Optional(Type.Record(Type.String(), Type.String())),
  compat: Type.Optional(OpenAICompatSchema),
});
// Schema for per-model overrides (all fields optional, merged with built-in model)
const ModelOverrideSchema = Type.Object({
  name: Type.Optional(Type.String({ minLength: 1 })),
  reasoning: Type.Optional(Type.Boolean()),
  input: Type.Optional(Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")]))),
  cost: Type.Optional(
    Type.Object({
      input: Type.Optional(Type.Number()),
      output: Type.Optional(Type.Number()),
      cacheRead: Type.Optional(Type.Number()),
      cacheWrite: Type.Optional(Type.Number()),
    }),
  ),
  contextWindow: Type.Optional(Type.Number()),
  maxTokens: Type.Optional(Type.Number()),
  headers: Type.Optional(Type.Record(Type.String(), Type.String())),
  compat: Type.Optional(OpenAICompatSchema),
});
type ModelOverride = Static<typeof ModelOverrideSchema>;
// Per-provider configuration block in models.json.
const ProviderConfigSchema = Type.Object({
  baseUrl: Type.Optional(Type.String({ minLength: 1 })),
  apiKey: Type.Optional(Type.String({ minLength: 1 })),
  api: Type.Optional(Type.String({ minLength: 1 })),
  headers: Type.Optional(Type.Record(Type.String(), Type.String())),
  authHeader: Type.Optional(Type.Boolean()),
  models: Type.Optional(Type.Array(ModelDefinitionSchema)),
  modelOverrides: Type.Optional(Type.Record(Type.String(), ModelOverrideSchema)),
});
// Top-level shape of models.json.
const ModelsConfigSchema = Type.Object({
  providers: Type.Record(Type.String(), ProviderConfigSchema),
});
// Registered under this key so validation can be retrieved via
// ajv.getSchema("ModelsConfig") when loading models.json.
ajv.addSchema(ModelsConfigSchema, "ModelsConfig");
type ModelsConfig = Static<typeof ModelsConfigSchema>;
/** Provider override config (baseUrl, headers, apiKey) without custom models */
interface ProviderOverride {
  /** Replacement base URL applied to the provider's built-in models. */
  baseUrl?: string;
  /** Extra headers merged over the built-in models' headers. */
  headers?: Record<string, string>;
  /** API key config value, used by the auth fallback resolver. */
  apiKey?: string;
}
/** Result of loading custom models from models.json */
interface CustomModelsResult {
  /** Fully-specified custom models parsed from the file. */
  models: Model<Api>[];
  /** Providers with baseUrl/headers/apiKey overrides for built-in models */
  overrides: Map<string, ProviderOverride>;
  /** Per-model overrides: provider -> modelId -> override */
  modelOverrides: Map<string, Map<string, ModelOverride>>;
  /** Human-readable load/validation error, or undefined on success. */
  error: string | undefined;
}
/** Build an empty CustomModelsResult, optionally carrying a load error. */
function emptyCustomModelsResult(error?: string): CustomModelsResult {
  return {
    models: [],
    overrides: new Map(),
    modelOverrides: new Map(),
    error,
  };
}
/**
 * Merge an override compat object into a base compat object.
 *
 * Top-level fields are shallow-merged (override wins); the openRouterRouting
 * and vercelGatewayRouting sub-objects are merged one level deeper so that a
 * partial routing override does not discard the base routing fields.
 */
function mergeCompat(
  baseCompat: Model<Api>["compat"],
  overrideCompat: ModelOverride["compat"],
): Model<Api>["compat"] | undefined {
  if (!overrideCompat) return baseCompat;
  const base = baseCompat as OpenAICompletionsCompat | undefined;
  const override = overrideCompat as OpenAICompletionsCompat;
  const merged = { ...base, ...override } as OpenAICompletionsCompat;
  if (base?.openRouterRouting || override.openRouterRouting) {
    merged.openRouterRouting = { ...base?.openRouterRouting, ...override.openRouterRouting };
  }
  if (base?.vercelGatewayRouting || override.vercelGatewayRouting) {
    merged.vercelGatewayRouting = { ...base?.vercelGatewayRouting, ...override.vercelGatewayRouting };
  }
  return merged as Model<Api>["compat"];
}
/**
* Deep merge a model override into a model.
* Handles nested objects (cost, compat) by merging rather than replacing.
*/
function applyModelOverride(model: Model<Api>, override: ModelOverride): Model<Api> {
const result = { ...model };
// Simple field overrides
if (override.name !== undefined) result.name = override.name;
if (override.reasoning !== undefined) result.reasoning = override.reasoning;
if (override.input !== undefined) result.input = override.input as ("text" | "image")[];
if (override.contextWindow !== undefined) result.contextWindow = override.contextWindow;
if (override.maxTokens !== undefined) result.maxTokens = override.maxTokens;
// Merge cost (partial override)
if (override.cost) {
result.cost = {
input: override.cost.input ?? model.cost.input,
output: override.cost.output ?? model.cost.output,
cacheRead: override.cost.cacheRead ?? model.cost.cacheRead,
cacheWrite: override.cost.cacheWrite ?? model.cost.cacheWrite,
};
}
// Merge headers
if (override.headers) {
const resolvedHeaders = resolveHeaders(override.headers);
result.headers = resolvedHeaders ? { ...model.headers, ...resolvedHeaders } : model.headers;
}
// Deep merge compat
result.compat = mergeCompat(model.compat, override.compat);
return result;
}
/** Clear the config value command cache. Exported for testing. */
// Alias of clearConfigValueCache so tests don't depend on the helper module.
export const clearApiKeyCache = clearConfigValueCache;
/**
 * Model registry - loads and manages models, resolves API keys via AuthStorage.
 *
 * Combines three model sources: built-in models from @gsd/pi-ai, user-defined
 * custom models/overrides from models.json, and providers registered
 * dynamically at runtime (e.g. by extensions).
 */
export class ModelRegistry {
  /** Effective model list: built-ins (with overrides applied) plus custom models. */
  private models: Model<Api>[] = [];
  /** provider -> raw apiKey config string, consumed by the auth fallback resolver. */
  private customProviderApiKeys: Map<string, string> = new Map();
  /** Dynamically registered providers, replayed after every refresh(). */
  private registeredProviders: Map<string, ProviderConfigInput> = new Map();
  /** Error from the most recent models.json load attempt, if any. */
  private loadError: string | undefined = undefined;
  constructor(
    readonly authStorage: AuthStorage,
    private modelsJsonPath: string | undefined = join(getAgentDir(), "models.json"),
  ) {
    // Set up fallback resolver for custom provider API keys
    this.authStorage.setFallbackResolver((provider) => {
      const keyConfig = this.customProviderApiKeys.get(provider);
      if (keyConfig) {
        return resolveConfigValue(keyConfig);
      }
      return undefined;
    });
    // Load models
    this.loadModels();
  }
  /**
   * Reload models from disk (built-in + custom from models.json).
   */
  refresh(): void {
    this.customProviderApiKeys.clear();
    this.loadError = undefined;
    // Ensure dynamic API/OAuth registrations are rebuilt from current provider state.
    resetApiProviders();
    resetOAuthProviders();
    this.loadModels();
    // Replay dynamically registered providers on top of the fresh disk state.
    for (const [providerName, config] of this.registeredProviders.entries()) {
      this.applyProviderConfig(providerName, config);
    }
  }
  /**
   * Get any error from loading models.json (undefined if no error).
   */
  getError(): string | undefined {
    return this.loadError;
  }
  // Rebuild this.models from built-ins + models.json; a models.json failure
  // records loadError but keeps the built-in models usable.
  private loadModels(): void {
    // Load custom models and overrides from models.json
    const {
      models: customModels,
      overrides,
      modelOverrides,
      error,
    } = this.modelsJsonPath ? this.loadCustomModels(this.modelsJsonPath) : emptyCustomModelsResult();
    if (error) {
      this.loadError = error;
      // Keep built-in models even if custom models failed to load
    }
    const builtInModels = this.loadBuiltInModels(overrides, modelOverrides);
    let combined = this.mergeCustomModels(builtInModels, customModels);
    // Let OAuth providers modify their models (e.g., update baseUrl)
    for (const oauthProvider of this.authStorage.getOAuthProviders()) {
      const cred = this.authStorage.get(oauthProvider.id);
      if (cred?.type === "oauth" && oauthProvider.modifyModels) {
        combined = oauthProvider.modifyModels(combined, cred);
      }
    }
    this.models = combined;
  }
  /** Load built-in models and apply provider/model overrides */
  private loadBuiltInModels(
    overrides: Map<string, ProviderOverride>,
    modelOverrides: Map<string, Map<string, ModelOverride>>,
  ): Model<Api>[] {
    return getProviders().flatMap((provider) => {
      const models = getModels(provider as KnownProvider) as Model<Api>[];
      const providerOverride = overrides.get(provider);
      const perModelOverrides = modelOverrides.get(provider);
      return models.map((m) => {
        let model = m;
        // Apply provider-level baseUrl/headers override
        if (providerOverride) {
          const resolvedHeaders = resolveHeaders(providerOverride.headers);
          model = {
            ...model,
            baseUrl: providerOverride.baseUrl ?? model.baseUrl,
            headers: resolvedHeaders ? { ...model.headers, ...resolvedHeaders } : model.headers,
          };
        }
        // Apply per-model override
        const modelOverride = perModelOverrides?.get(m.id);
        if (modelOverride) {
          model = applyModelOverride(model, modelOverride);
        }
        return model;
      });
    });
  }
  /** Merge custom models into built-in list by provider+id (custom wins on conflicts). */
  private mergeCustomModels(builtInModels: Model<Api>[], customModels: Model<Api>[]): Model<Api>[] {
    const merged = [...builtInModels];
    for (const customModel of customModels) {
      const existingIndex = merged.findIndex((m) => m.provider === customModel.provider && m.id === customModel.id);
      if (existingIndex >= 0) {
        merged[existingIndex] = customModel;
      } else {
        merged.push(customModel);
      }
    }
    return merged;
  }
  // Parse models.json: schema-validate, extract provider/model overrides, and
  // build custom Model entries. All failures are reported via the result's
  // `error` field rather than thrown.
  private loadCustomModels(modelsJsonPath: string): CustomModelsResult {
    if (!existsSync(modelsJsonPath)) {
      return emptyCustomModelsResult();
    }
    try {
      const content = readFileSync(modelsJsonPath, "utf-8");
      const config: ModelsConfig = JSON.parse(content);
      // Validate schema
      const validate = ajv.getSchema("ModelsConfig")!;
      if (!validate(config)) {
        const errors =
          validate.errors?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`).join("\n") ||
          "Unknown schema error";
        return emptyCustomModelsResult(`Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`);
      }
      // Additional validation
      this.validateConfig(config);
      const overrides = new Map<string, ProviderOverride>();
      const modelOverrides = new Map<string, Map<string, ModelOverride>>();
      for (const [providerName, providerConfig] of Object.entries(config.providers)) {
        // Apply provider-level baseUrl/headers/apiKey override to built-in models when configured.
        if (providerConfig.baseUrl || providerConfig.headers || providerConfig.apiKey) {
          overrides.set(providerName, {
            baseUrl: providerConfig.baseUrl,
            headers: providerConfig.headers,
            apiKey: providerConfig.apiKey,
          });
        }
        // Store API key for fallback resolver.
        if (providerConfig.apiKey) {
          this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
        }
        if (providerConfig.modelOverrides) {
          modelOverrides.set(providerName, new Map(Object.entries(providerConfig.modelOverrides)));
        }
      }
      return { models: this.parseModels(config), overrides, modelOverrides, error: undefined };
    } catch (error) {
      if (error instanceof SyntaxError) {
        return emptyCustomModelsResult(`Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`);
      }
      return emptyCustomModelsResult(
        `Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`,
      );
    }
  }
  // Semantic validation beyond the JSON schema; throws (caught by
  // loadCustomModels) with a provider/model-specific message on failure.
  private validateConfig(config: ModelsConfig): void {
    for (const [providerName, providerConfig] of Object.entries(config.providers)) {
      const hasProviderApi = !!providerConfig.api;
      const models = providerConfig.models ?? [];
      const hasModelOverrides =
        providerConfig.modelOverrides && Object.keys(providerConfig.modelOverrides).length > 0;
      if (models.length === 0) {
        // Override-only config: needs baseUrl OR modelOverrides (or both)
        if (!providerConfig.baseUrl && !hasModelOverrides) {
          throw new Error(`Provider ${providerName}: must specify "baseUrl", "modelOverrides", or "models".`);
        }
      } else {
        // Custom models are merged into provider models and require endpoint + auth.
        if (!providerConfig.baseUrl) {
          throw new Error(`Provider ${providerName}: "baseUrl" is required when defining custom models.`);
        }
        if (!providerConfig.apiKey) {
          throw new Error(`Provider ${providerName}: "apiKey" is required when defining custom models.`);
        }
      }
      for (const modelDef of models) {
        const hasModelApi = !!modelDef.api;
        if (!hasProviderApi && !hasModelApi) {
          throw new Error(
            `Provider ${providerName}, model ${modelDef.id}: no "api" specified. Set at provider or model level.`,
          );
        }
        if (!modelDef.id) throw new Error(`Provider ${providerName}: model missing "id"`);
        // Validate contextWindow/maxTokens only if provided (they have defaults)
        if (modelDef.contextWindow !== undefined && modelDef.contextWindow <= 0)
          throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`);
        if (modelDef.maxTokens !== undefined && modelDef.maxTokens <= 0)
          throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`);
      }
    }
  }
  // Turn validated models.json provider entries into Model objects, applying
  // defaults suitable for local/self-hosted endpoints.
  private parseModels(config: ModelsConfig): Model<Api>[] {
    const models: Model<Api>[] = [];
    for (const [providerName, providerConfig] of Object.entries(config.providers)) {
      const modelDefs = providerConfig.models ?? [];
      if (modelDefs.length === 0) continue; // Override-only, no custom models
      // Store API key config for fallback resolver
      // NOTE(review): loadCustomModels also registers this key — harmless
      // duplicate; confirm before removing either.
      if (providerConfig.apiKey) {
        this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
      }
      for (const modelDef of modelDefs) {
        const api = modelDef.api || providerConfig.api;
        if (!api) continue;
        // Merge headers: provider headers are base, model headers override
        // Resolve env vars and shell commands in header values
        const providerHeaders = resolveHeaders(providerConfig.headers);
        const modelHeaders = resolveHeaders(modelDef.headers);
        let headers = providerHeaders || modelHeaders ? { ...providerHeaders, ...modelHeaders } : undefined;
        // If authHeader is true, add Authorization header with resolved API key
        if (providerConfig.authHeader && providerConfig.apiKey) {
          const resolvedKey = resolveConfigValue(providerConfig.apiKey);
          if (resolvedKey) {
            headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
          }
        }
        // Provider baseUrl is required when custom models are defined.
        // Individual models can override it with modelDef.baseUrl.
        const defaultCost = { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 };
        models.push({
          id: modelDef.id,
          name: modelDef.name ?? modelDef.id,
          api: api as Api,
          provider: providerName,
          baseUrl: modelDef.baseUrl ?? providerConfig.baseUrl!,
          reasoning: modelDef.reasoning ?? false,
          input: (modelDef.input ?? ["text"]) as ("text" | "image")[],
          cost: modelDef.cost ?? defaultCost,
          contextWindow: modelDef.contextWindow ?? 128000,
          maxTokens: modelDef.maxTokens ?? 16384,
          headers,
          compat: modelDef.compat,
        } as Model<Api>);
      }
    }
    return models;
  }
  /**
   * Get all models (built-in + custom).
   * If models.json had errors, returns only built-in models.
   */
  getAll(): Model<Api>[] {
    return this.models;
  }
  /**
   * Get only models that have auth configured.
   * This is a fast check that doesn't refresh OAuth tokens.
   */
  getAvailable(): Model<Api>[] {
    return this.models.filter((m) => this.authStorage.hasAuth(m.provider));
  }
  /**
   * Find a model by provider and ID.
   */
  find(provider: string, modelId: string): Model<Api> | undefined {
    return this.models.find((m) => m.provider === provider && m.id === modelId);
  }
  /**
   * Get API key for a model.
   */
  async getApiKey(model: Model<Api>): Promise<string | undefined> {
    return this.authStorage.getApiKey(model.provider);
  }
  /**
   * Get API key for a provider.
   */
  async getApiKeyForProvider(provider: string): Promise<string | undefined> {
    return this.authStorage.getApiKey(provider);
  }
  /**
   * Check if a model is using OAuth credentials (subscription).
   */
  isUsingOAuth(model: Model<Api>): boolean {
    const cred = this.authStorage.get(model.provider);
    return cred?.type === "oauth";
  }
  /**
   * Register a provider dynamically (from extensions).
   *
   * If provider has models: replaces all existing models for this provider.
   * If provider has only baseUrl/headers: overrides existing models' URLs.
   * If provider has oauth: registers OAuth provider for /login support.
   */
  registerProvider(providerName: string, config: ProviderConfigInput): void {
    this.registeredProviders.set(providerName, config);
    this.applyProviderConfig(providerName, config);
  }
  /**
   * Unregister a previously registered provider.
   *
   * Removes the provider from the registry and reloads models from disk so that
   * built-in models overridden by this provider are restored to their original state.
   * Also resets dynamic OAuth and API stream registrations before reapplying
   * remaining dynamic providers.
   * Has no effect if the provider was never registered.
   */
  unregisterProvider(providerName: string): void {
    if (!this.registeredProviders.has(providerName)) return;
    this.registeredProviders.delete(providerName);
    this.customProviderApiKeys.delete(providerName);
    this.refresh();
  }
  // Apply one dynamic provider's config: OAuth/stream registration, API key
  // storage, then either full model replacement or baseUrl/headers overrides.
  private applyProviderConfig(providerName: string, config: ProviderConfigInput): void {
    // Register OAuth provider if provided
    if (config.oauth) {
      // Ensure the OAuth provider ID matches the provider name
      const oauthProvider: OAuthProviderInterface = {
        ...config.oauth,
        id: providerName,
      };
      registerOAuthProvider(oauthProvider);
    }
    if (config.streamSimple) {
      if (!config.api) {
        throw new Error(`Provider ${providerName}: "api" is required when registering streamSimple.`);
      }
      const streamSimple = config.streamSimple;
      registerApiProvider(
        {
          api: config.api,
          stream: (model, context, options) => streamSimple(model, context, options as SimpleStreamOptions),
          streamSimple,
        },
        `provider:${providerName}`,
      );
    }
    // Store API key for auth resolution
    if (config.apiKey) {
      this.customProviderApiKeys.set(providerName, config.apiKey);
    }
    if (config.models && config.models.length > 0) {
      // Full replacement: remove existing models for this provider
      this.models = this.models.filter((m) => m.provider !== providerName);
      // Validate required fields
      if (!config.baseUrl) {
        throw new Error(`Provider ${providerName}: "baseUrl" is required when defining models.`);
      }
      if (!config.apiKey && !config.oauth) {
        throw new Error(`Provider ${providerName}: "apiKey" or "oauth" is required when defining models.`);
      }
      // Parse and add new models
      for (const modelDef of config.models) {
        const api = modelDef.api || config.api;
        if (!api) {
          throw new Error(`Provider ${providerName}, model ${modelDef.id}: no "api" specified.`);
        }
        // Merge headers
        const providerHeaders = resolveHeaders(config.headers);
        const modelHeaders = resolveHeaders(modelDef.headers);
        let headers = providerHeaders || modelHeaders ? { ...providerHeaders, ...modelHeaders } : undefined;
        // If authHeader is true, add Authorization header
        if (config.authHeader && config.apiKey) {
          const resolvedKey = resolveConfigValue(config.apiKey);
          if (resolvedKey) {
            headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
          }
        }
        this.models.push({
          id: modelDef.id,
          name: modelDef.name,
          api: api as Api,
          provider: providerName,
          baseUrl: config.baseUrl,
          reasoning: modelDef.reasoning,
          input: modelDef.input as ("text" | "image")[],
          cost: modelDef.cost,
          contextWindow: modelDef.contextWindow,
          maxTokens: modelDef.maxTokens,
          headers,
          compat: modelDef.compat,
        } as Model<Api>);
      }
      // Apply OAuth modifyModels if credentials exist (e.g., to update baseUrl)
      if (config.oauth?.modifyModels) {
        const cred = this.authStorage.get(providerName);
        if (cred?.type === "oauth") {
          this.models = config.oauth.modifyModels(this.models, cred);
        }
      }
    } else if (config.baseUrl) {
      // Override-only: update baseUrl/headers for existing models
      const resolvedHeaders = resolveHeaders(config.headers);
      this.models = this.models.map((m) => {
        if (m.provider !== providerName) return m;
        return {
          ...m,
          baseUrl: config.baseUrl ?? m.baseUrl,
          headers: resolvedHeaders ? { ...m.headers, ...resolvedHeaders } : m.headers,
        };
      });
    }
  }
}
/**
 * Input type for registerProvider API.
 */
export interface ProviderConfigInput {
  /** Base URL for the provider's models (required when `models` is given). */
  baseUrl?: string;
  /** API key config value; resolved lazily via resolveConfigValue. */
  apiKey?: string;
  /** Default API for all models (required when `streamSimple` is given). */
  api?: Api;
  /** Custom streaming implementation registered for `api`. */
  streamSimple?: (model: Model<Api>, context: Context, options?: SimpleStreamOptions) => AssistantMessageEventStream;
  /** Extra headers applied to every model of this provider. */
  headers?: Record<string, string>;
  /** When true (with `apiKey`), adds an `Authorization: Bearer <key>` header. */
  authHeader?: boolean;
  /** OAuth provider for /login support */
  oauth?: Omit<OAuthProviderInterface, "id">;
  /** Fully-specified models; replaces all existing models for this provider. */
  models?: Array<{
    id: string;
    name: string;
    api?: Api;
    baseUrl?: string;
    reasoning: boolean;
    input: ("text" | "image")[];
    cost: { input: number; output: number; cacheRead: number; cacheWrite: number };
    contextWindow: number;
    maxTokens: number;
    headers?: Record<string, string>;
    compat?: Model<Api>["compat"];
  }>;
}

View file

@ -0,0 +1,594 @@
/**
* Model resolution, scoping, and initial selection
*/
import type { ThinkingLevel } from "@gsd/pi-agent-core";
import { type Api, type KnownProvider, type Model, modelsAreEqual } from "@gsd/pi-ai";
import chalk from "chalk";
import { minimatch } from "minimatch";
import { isValidThinkingLevel } from "../cli/args.js";
import { DEFAULT_THINKING_LEVEL } from "./defaults.js";
import type { ModelRegistry } from "./model-registry.js";
/** Default model IDs for each known provider */
// Keyed by KnownProvider, so the compiler enforces an entry for every
// provider. Used by buildFallbackModel to pick a template model.
export const defaultModelPerProvider: Record<KnownProvider, string> = {
  "amazon-bedrock": "us.anthropic.claude-opus-4-6-v1",
  anthropic: "claude-opus-4-6",
  openai: "gpt-5.4",
  "azure-openai-responses": "gpt-5.2",
  "openai-codex": "gpt-5.4",
  google: "gemini-2.5-pro",
  "google-gemini-cli": "gemini-2.5-pro",
  "google-antigravity": "gemini-3.1-pro-high",
  "google-vertex": "gemini-3-pro-preview",
  "github-copilot": "gpt-4o",
  openrouter: "openai/gpt-5.1-codex",
  "vercel-ai-gateway": "anthropic/claude-opus-4-6",
  xai: "grok-4-fast-non-reasoning",
  groq: "openai/gpt-oss-120b",
  cerebras: "zai-glm-4.6",
  zai: "glm-4.6",
  mistral: "devstral-medium-latest",
  minimax: "MiniMax-M2.1",
  "minimax-cn": "MiniMax-M2.1",
  huggingface: "moonshotai/Kimi-K2.5",
  opencode: "claude-opus-4-6",
  "opencode-go": "kimi-k2.5",
  "kimi-coding": "kimi-k2-thinking",
};
/** A resolved model plus an optional thinking level parsed from its pattern. */
export interface ScopedModel {
  /** The concrete model the pattern resolved to. */
  model: Model<Api>;
  /** Thinking level if explicitly specified in pattern (e.g., "model:high"), undefined otherwise */
  thinkingLevel?: ThinkingLevel;
}
/**
 * Helper to check if a model ID looks like an alias (no date suffix).
 * `-latest` IDs are always aliases; IDs pinned to a snapshot date
 * (ending in -YYYYMMDD, e.g. -20241022) are not.
 */
function isAlias(id: string): boolean {
  return id.endsWith("-latest") || !/-\d{8}$/.test(id);
}
/**
 * Try to match a pattern to a model from the available models list.
 * Returns the matched model or undefined if no match found.
 *
 * Matching order: explicit "provider/modelId" → exact ID (case-insensitive)
 * → substring match on ID or name, preferring alias IDs over dated snapshots
 * and picking the lexicographically greatest ID within the preferred group.
 */
function tryMatchModel(modelPattern: string, availableModels: Model<Api>[]): Model<Api> | undefined {
  const needle = modelPattern.toLowerCase();
  // "provider/modelId" form: provider is everything before the first slash.
  const slashIndex = modelPattern.indexOf("/");
  if (slashIndex !== -1) {
    const provider = needle.substring(0, slashIndex);
    const modelId = needle.substring(slashIndex + 1);
    const providerMatch = availableModels.find(
      (m) => m.provider.toLowerCase() === provider && m.id.toLowerCase() === modelId,
    );
    if (providerMatch) {
      return providerMatch;
    }
    // No exact provider/model match — fall through to generic matching.
  }
  // Exact ID match (case-insensitive).
  const exact = availableModels.find((m) => m.id.toLowerCase() === needle);
  if (exact) {
    return exact;
  }
  // Substring match on ID or display name.
  const candidates = availableModels.filter(
    (m) => m.id.toLowerCase().includes(needle) || m.name?.toLowerCase().includes(needle),
  );
  if (candidates.length === 0) {
    return undefined;
  }
  // Prefer aliases (e.g. "-latest" / undated IDs); fall back to dated
  // snapshots only when no alias matched.
  const aliases = candidates.filter((m) => isAlias(m.id));
  const pool = aliases.length > 0 ? aliases : candidates.filter((m) => !isAlias(m.id));
  pool.sort((a, b) => b.id.localeCompare(a.id));
  return pool[0];
}
/** Outcome of parsing a "model[:thinkingLevel]" pattern. */
export interface ParsedModelResult {
  /** The resolved model, or undefined when the pattern matched nothing. */
  model: Model<Api> | undefined;
  /** Thinking level if explicitly specified in pattern, undefined otherwise */
  thinkingLevel?: ThinkingLevel;
  /** Diagnostic emitted when an invalid thinking-level suffix was ignored. */
  warning: string | undefined;
}
/**
 * Build a synthetic model entry for an ID that is not in the catalog but
 * belongs to a known provider: clone the provider's default model (or its
 * first model) and substitute the requested ID.
 */
function buildFallbackModel(provider: string, modelId: string, availableModels: Model<Api>[]): Model<Api> | undefined {
  const fromProvider = availableModels.filter((m) => m.provider === provider);
  if (fromProvider.length === 0) return undefined;
  const preferredId = defaultModelPerProvider[provider as KnownProvider];
  const template = (preferredId ? fromProvider.find((m) => m.id === preferredId) : undefined) ?? fromProvider[0];
  return { ...template, id: modelId, name: modelId };
}
/**
 * Parse a pattern to extract model and thinking level.
 * Handles models with colons in their IDs (e.g., OpenRouter's :exacto suffix).
 *
 * The whole pattern is tried as a model first; only when that fails is the
 * text after the last colon interpreted as a thinking level, recursing on the
 * prefix. An invalid suffix either fails outright (strict mode) or is dropped
 * with a warning.
 *
 * @internal Exported for testing
 */
export function parseModelPattern(
  pattern: string,
  availableModels: Model<Api>[],
  options?: { allowInvalidThinkingLevelFallback?: boolean },
): ParsedModelResult {
  // 1. The entire pattern may itself name a model (IDs can contain colons).
  const whole = tryMatchModel(pattern, availableModels);
  if (whole) {
    return { model: whole, thinkingLevel: undefined, warning: undefined };
  }
  // 2. No match — peel a ":suffix" off the end, if there is one.
  const colonAt = pattern.lastIndexOf(":");
  if (colonAt === -1) {
    // No colons, pattern simply doesn't match any model.
    return { model: undefined, thinkingLevel: undefined, warning: undefined };
  }
  const prefix = pattern.slice(0, colonAt);
  const suffix = pattern.slice(colonAt + 1);
  if (isValidThinkingLevel(suffix)) {
    // Valid thinking level — resolve the prefix and attach the level.
    const inner = parseModelPattern(prefix, availableModels, options);
    if (!inner.model) return inner;
    // A warning from the inner parse means a bogus level was already seen;
    // in that case do not apply this suffix on top of it.
    return {
      model: inner.model,
      thinkingLevel: inner.warning ? undefined : suffix,
      warning: inner.warning,
    };
  }
  // Invalid suffix.
  if (!(options?.allowInvalidThinkingLevelFallback ?? true)) {
    // In strict mode (CLI --model parsing), treat it as part of the model id and fail.
    // This avoids accidentally resolving to a different model.
    return { model: undefined, thinkingLevel: undefined, warning: undefined };
  }
  // Scope mode: resolve the prefix and surface a warning about the suffix.
  const inner = parseModelPattern(prefix, availableModels, options);
  if (!inner.model) return inner;
  return {
    model: inner.model,
    thinkingLevel: undefined,
    warning: `Invalid thinking level "${suffix}" in pattern "${pattern}". Using default instead.`,
  };
}
/**
 * Resolve model patterns to actual Model objects with optional thinking levels
 * Format: "pattern:level" where :level is optional
 * For each pattern, finds all matching models and picks the best version:
 * 1. Prefer alias (e.g., claude-sonnet-4-5) over dated versions (claude-sonnet-4-5-20250929)
 * 2. If no alias, pick the latest dated version
 *
 * Supports models with colons in their IDs (e.g., OpenRouter's model:exacto).
 * The algorithm tries to match the full pattern first, then progressively
 * strips colon-suffixes to find a match.
 *
 * Glob patterns (*, ?, [) are matched case-insensitively against both
 * "provider/modelId" and the bare model id. Non-matching patterns emit a
 * console warning and are skipped; duplicate models are de-duplicated.
 */
export async function resolveModelScope(patterns: string[], modelRegistry: ModelRegistry): Promise<ScopedModel[]> {
  const availableModels = await modelRegistry.getAvailable();
  const scopedModels: ScopedModel[] = [];
  for (const pattern of patterns) {
    // Check if pattern contains glob characters
    if (pattern.includes("*") || pattern.includes("?") || pattern.includes("[")) {
      // Extract optional thinking level suffix (e.g., "provider/*:high")
      const colonIdx = pattern.lastIndexOf(":");
      let globPattern = pattern;
      let thinkingLevel: ThinkingLevel | undefined;
      if (colonIdx !== -1) {
        const suffix = pattern.substring(colonIdx + 1);
        // Only strip the suffix if it is a real thinking level; otherwise the
        // colon is treated as part of the glob itself.
        if (isValidThinkingLevel(suffix)) {
          thinkingLevel = suffix;
          globPattern = pattern.substring(0, colonIdx);
        }
      }
      // Match against "provider/modelId" format OR just model ID
      // This allows "*sonnet*" to match without requiring "anthropic/*sonnet*"
      const matchingModels = availableModels.filter((m) => {
        const fullId = `${m.provider}/${m.id}`;
        return minimatch(fullId, globPattern, { nocase: true }) || minimatch(m.id, globPattern, { nocase: true });
      });
      if (matchingModels.length === 0) {
        console.warn(chalk.yellow(`Warning: No models match pattern "${pattern}"`));
        continue;
      }
      // All glob matches share the same (optional) thinking level.
      for (const model of matchingModels) {
        if (!scopedModels.find((sm) => modelsAreEqual(sm.model, model))) {
          scopedModels.push({ model, thinkingLevel });
        }
      }
      continue;
    }
    // Non-glob: fuzzy pattern match with optional ":level" suffix.
    const { model, thinkingLevel, warning } = parseModelPattern(pattern, availableModels);
    if (warning) {
      console.warn(chalk.yellow(`Warning: ${warning}`));
    }
    if (!model) {
      console.warn(chalk.yellow(`Warning: No models match pattern "${pattern}"`));
      continue;
    }
    // Avoid duplicates
    if (!scopedModels.find((sm) => modelsAreEqual(sm.model, model))) {
      scopedModels.push({ model, thinkingLevel });
    }
  }
  return scopedModels;
}
/** Result of resolving --provider/--model CLI flags via resolveCliModel. */
export interface ResolveCliModelResult {
  /** Resolved model; undefined when resolution failed (see error). */
  model: Model<Api> | undefined;
  /** Thinking level parsed from "<pattern>:<thinking>", for the caller to apply. */
  thinkingLevel?: ThinkingLevel;
  /** Non-fatal notice (e.g. custom model id synthesized), if any. */
  warning: string | undefined;
  /**
   * Error message suitable for CLI display.
   * When set, model will be undefined.
   */
  error: string | undefined;
}
/**
 * Resolve a single model from CLI flags.
 *
 * Supports:
 * - --provider <provider> --model <pattern>
 * - --model <provider>/<pattern>
 * - Fuzzy matching (same rules as model scoping: exact id, then partial id/name)
 *
 * Resolution order: exact id match → provider-scoped strict pattern match →
 * full-input fallback for slash-containing model ids → synthesized custom
 * model id for a known provider → error.
 *
 * Note: This does not apply the thinking level by itself, but it may *parse* and
 * return a thinking level from "<pattern>:<thinking>" so the caller can apply it.
 */
export function resolveCliModel(options: {
  cliProvider?: string;
  cliModel?: string;
  modelRegistry: ModelRegistry;
}): ResolveCliModelResult {
  const { cliProvider, cliModel, modelRegistry } = options;
  // No --model flag: nothing to resolve; not an error.
  if (!cliModel) {
    return { model: undefined, warning: undefined, error: undefined };
  }
  // Important: use *all* models here, not just models with pre-configured auth.
  // This allows "--api-key" to be used for first-time setup.
  const availableModels = modelRegistry.getAll();
  if (availableModels.length === 0) {
    return {
      model: undefined,
      warning: undefined,
      error: "No models available. Check your installation or add models to models.json.",
    };
  }
  // Build canonical provider lookup (case-insensitive)
  const providerMap = new Map<string, string>();
  for (const m of availableModels) {
    providerMap.set(m.provider.toLowerCase(), m.provider);
  }
  let provider = cliProvider ? providerMap.get(cliProvider.toLowerCase()) : undefined;
  // An explicit --provider that matches nothing is a hard error.
  if (cliProvider && !provider) {
    return {
      model: undefined,
      warning: undefined,
      error: `Unknown provider "${cliProvider}". Use --list-models to see available providers/models.`,
    };
  }
  // If no explicit --provider, try to interpret "provider/model" format first.
  // When the prefix before the first slash matches a known provider, prefer that
  // interpretation over matching models whose IDs literally contain slashes
  // (e.g. "zai/glm-5" should resolve to provider=zai, model=glm-5, not to a
  // vercel-ai-gateway model with id "zai/glm-5").
  let pattern = cliModel;
  let inferredProvider = false;
  if (!provider) {
    const slashIndex = cliModel.indexOf("/");
    if (slashIndex !== -1) {
      const maybeProvider = cliModel.substring(0, slashIndex);
      const canonical = providerMap.get(maybeProvider.toLowerCase());
      if (canonical) {
        provider = canonical;
        pattern = cliModel.substring(slashIndex + 1);
        inferredProvider = true;
      }
    }
  }
  // If no provider was inferred from the slash, try exact matches without provider inference.
  // This handles models whose IDs naturally contain slashes (e.g. OpenRouter-style IDs).
  if (!provider) {
    const lower = cliModel.toLowerCase();
    const exact = availableModels.find(
      (m) => m.id.toLowerCase() === lower || `${m.provider}/${m.id}`.toLowerCase() === lower,
    );
    if (exact) {
      return { model: exact, warning: undefined, thinkingLevel: undefined, error: undefined };
    }
  }
  if (cliProvider && provider) {
    // If both were provided, tolerate --model <provider>/<pattern> by stripping the provider prefix
    const prefix = `${provider}/`;
    if (cliModel.toLowerCase().startsWith(prefix.toLowerCase())) {
      pattern = cliModel.substring(prefix.length);
    }
  }
  // Strict mode: an unrecognized ":suffix" must NOT be silently dropped here,
  // so a typo can't resolve to a different model than the user asked for.
  const candidates = provider ? availableModels.filter((m) => m.provider === provider) : availableModels;
  const { model, thinkingLevel, warning } = parseModelPattern(pattern, candidates, {
    allowInvalidThinkingLevelFallback: false,
  });
  if (model) {
    return { model, thinkingLevel, warning, error: undefined };
  }
  // If we inferred a provider from the slash but found no match within that provider,
  // fall back to matching the full input as a raw model id across all models.
  // This handles OpenRouter-style IDs like "openai/gpt-4o:extended" where "openai"
  // looks like a provider but the full string is actually a model id on openrouter.
  if (inferredProvider) {
    const lower = cliModel.toLowerCase();
    const exact = availableModels.find(
      (m) => m.id.toLowerCase() === lower || `${m.provider}/${m.id}`.toLowerCase() === lower,
    );
    if (exact) {
      return { model: exact, warning: undefined, thinkingLevel: undefined, error: undefined };
    }
    // Also try parseModelPattern on the full input against all models
    const fallback = parseModelPattern(cliModel, availableModels, {
      allowInvalidThinkingLevelFallback: false,
    });
    if (fallback.model) {
      return {
        model: fallback.model,
        thinkingLevel: fallback.thinkingLevel,
        warning: fallback.warning,
        error: undefined,
      };
    }
  }
  // Last resort for a known provider: accept the pattern as a custom model id
  // (clone of a known model of that provider) and warn the user.
  if (provider) {
    const fallbackModel = buildFallbackModel(provider, pattern, availableModels);
    if (fallbackModel) {
      const fallbackWarning = warning
        ? `${warning} Model "${pattern}" not found for provider "${provider}". Using custom model id.`
        : `Model "${pattern}" not found for provider "${provider}". Using custom model id.`;
      return { model: fallbackModel, thinkingLevel: undefined, warning: fallbackWarning, error: undefined };
    }
  }
  const display = provider ? `${provider}/${pattern}` : cliModel;
  return {
    model: undefined,
    thinkingLevel: undefined,
    warning,
    error: `Model "${display}" not found. Use --list-models to see available models.`,
  };
}
/** Result of findInitialModel: chosen model (if any), its thinking level, and an optional user-facing fallback notice. */
export interface InitialModelResult {
  /** Selected model, or undefined when no model could be found at all. */
  model: Model<Api> | undefined;
  /** Thinking level to start with (defaults to DEFAULT_THINKING_LEVEL). */
  thinkingLevel: ThinkingLevel;
  /** Message for the UI when a non-preferred fallback was used. */
  fallbackMessage: string | undefined;
}
/**
 * Find the initial model to use based on priority:
 * 1. CLI args (provider + model)
 * 2. First model from scoped models (skipped when continuing/resuming; the
 *    session's own model is restored separately via restoreModelFromSession)
 * 3. Saved default from settings
 * 4. First available model with a valid API key (preferring known-provider defaults)
 *
 * Exits the process when CLI model resolution reports a hard error.
 */
export async function findInitialModel(options: {
  cliProvider?: string;
  cliModel?: string;
  scopedModels: ScopedModel[];
  isContinuing: boolean;
  defaultProvider?: string;
  defaultModelId?: string;
  defaultThinkingLevel?: ThinkingLevel;
  modelRegistry: ModelRegistry;
}): Promise<InitialModelResult> {
  const {
    cliProvider,
    cliModel,
    scopedModels,
    isContinuing,
    defaultProvider,
    defaultModelId,
    defaultThinkingLevel,
    modelRegistry,
  } = options;
  let model: Model<Api> | undefined;
  let thinkingLevel: ThinkingLevel = DEFAULT_THINKING_LEVEL;
  // 1. CLI args take priority
  // NOTE(review): this branch requires BOTH --provider and --model even though
  // resolveCliModel also supports "--model provider/pattern" alone — confirm
  // the model-only case is handled by callers before tightening this condition.
  if (cliProvider && cliModel) {
    const resolved = resolveCliModel({
      cliProvider,
      cliModel,
      modelRegistry,
    });
    if (resolved.error) {
      console.error(chalk.red(resolved.error));
      process.exit(1);
    }
    // Surface non-fatal resolution notices (previously silently discarded).
    if (resolved.warning) {
      console.warn(chalk.yellow(`Warning: ${resolved.warning}`));
    }
    if (resolved.model) {
      // Apply the thinking level parsed from "<pattern>:<level>" — resolveCliModel's
      // contract says the caller applies it; it was previously dropped here.
      return {
        model: resolved.model,
        thinkingLevel: resolved.thinkingLevel ?? DEFAULT_THINKING_LEVEL,
        fallbackMessage: undefined,
      };
    }
  }
  // 2. Use first model from scoped models (skip if continuing/resuming)
  if (scopedModels.length > 0 && !isContinuing) {
    return {
      model: scopedModels[0].model,
      thinkingLevel: scopedModels[0].thinkingLevel ?? defaultThinkingLevel ?? DEFAULT_THINKING_LEVEL,
      fallbackMessage: undefined,
    };
  }
  // 3. Try saved default from settings
  if (defaultProvider && defaultModelId) {
    const found = modelRegistry.find(defaultProvider, defaultModelId);
    if (found) {
      model = found;
      if (defaultThinkingLevel) {
        thinkingLevel = defaultThinkingLevel;
      }
      return { model, thinkingLevel, fallbackMessage: undefined };
    }
  }
  // 4. Try first available model with valid API key
  const availableModels = await modelRegistry.getAvailable();
  if (availableModels.length > 0) {
    // Try to find a default model from known providers
    for (const provider of Object.keys(defaultModelPerProvider) as KnownProvider[]) {
      const defaultId = defaultModelPerProvider[provider];
      const match = availableModels.find((m) => m.provider === provider && m.id === defaultId);
      if (match) {
        return { model: match, thinkingLevel: DEFAULT_THINKING_LEVEL, fallbackMessage: undefined };
      }
    }
    // If no default found, use first available
    return { model: availableModels[0], thinkingLevel: DEFAULT_THINKING_LEVEL, fallbackMessage: undefined };
  }
  // 5. No model found
  return { model: undefined, thinkingLevel: DEFAULT_THINKING_LEVEL, fallbackMessage: undefined };
}
/**
 * Restore model from session, with fallback to available models.
 *
 * Fallback order when the saved model is missing or lacks an API key:
 * current model → known-provider default among available models → first
 * available model → undefined. A fallbackMessage for the UI is returned
 * whenever a fallback model (rather than the saved one) is used.
 */
export async function restoreModelFromSession(
  savedProvider: string,
  savedModelId: string,
  currentModel: Model<Api> | undefined,
  shouldPrintMessages: boolean,
  modelRegistry: ModelRegistry,
): Promise<{ model: Model<Api> | undefined; fallbackMessage: string | undefined }> {
  const restoredModel = modelRegistry.find(savedProvider, savedModelId);
  // Check if restored model exists and has a valid API key
  const hasApiKey = restoredModel ? !!(await modelRegistry.getApiKey(restoredModel)) : false;
  if (restoredModel && hasApiKey) {
    if (shouldPrintMessages) {
      console.log(chalk.dim(`Restored model: ${savedProvider}/${savedModelId}`));
    }
    return { model: restoredModel, fallbackMessage: undefined };
  }
  // Model not found or no API key - fall back
  const reason = !restoredModel ? "model no longer exists" : "no API key available";
  if (shouldPrintMessages) {
    console.error(chalk.yellow(`Warning: Could not restore model ${savedProvider}/${savedModelId} (${reason}).`));
  }
  // If we already have a model, use it as fallback
  if (currentModel) {
    if (shouldPrintMessages) {
      console.log(chalk.dim(`Falling back to: ${currentModel.provider}/${currentModel.id}`));
    }
    return {
      model: currentModel,
      fallbackMessage: `Could not restore model ${savedProvider}/${savedModelId} (${reason}). Using ${currentModel.provider}/${currentModel.id}.`,
    };
  }
  // Try to find any available model
  const availableModels = await modelRegistry.getAvailable();
  if (availableModels.length > 0) {
    // Try to find a default model from known providers
    let fallbackModel: Model<Api> | undefined;
    for (const provider of Object.keys(defaultModelPerProvider) as KnownProvider[]) {
      const defaultId = defaultModelPerProvider[provider];
      const match = availableModels.find((m) => m.provider === provider && m.id === defaultId);
      if (match) {
        fallbackModel = match;
        break;
      }
    }
    // If no default found, use first available
    if (!fallbackModel) {
      fallbackModel = availableModels[0];
    }
    if (shouldPrintMessages) {
      console.log(chalk.dim(`Falling back to: ${fallbackModel.provider}/${fallbackModel.id}`));
    }
    return {
      model: fallbackModel,
      fallbackMessage: `Could not restore model ${savedProvider}/${savedModelId} (${reason}). Using ${fallbackModel.provider}/${fallbackModel.id}.`,
    };
  }
  // No models available
  return { model: undefined, fallbackMessage: undefined };
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,299 @@
import { existsSync, readdirSync, readFileSync, statSync } from "fs";
import { homedir } from "os";
import { basename, isAbsolute, join, resolve, sep } from "path";
import { CONFIG_DIR_NAME, getPromptsDir } from "../config.js";
import { parseFrontmatter } from "../utils/frontmatter.js";
/**
 * Represents a prompt template loaded from a markdown file
 */
export interface PromptTemplate {
  // Template name: the markdown file's basename without the .md extension.
  name: string;
  // Frontmatter `description`, or the first non-empty body line (truncated to 60 chars), with a source label appended.
  description: string;
  // Markdown body with frontmatter stripped.
  content: string;
  source: string; // "user", "project", or "path"
  filePath: string; // Absolute path to the template file
}
/**
 * Parse command arguments respecting quoted strings (bash-style).
 * Single and double quotes group words; quotes themselves are stripped.
 * An explicitly quoted empty string ("" or '') produces an empty argument,
 * matching bash. Backslash escapes are NOT supported.
 *
 * @param argsString Raw argument text after the command name.
 * @returns Array of parsed arguments.
 */
export function parseCommandArgs(argsString: string): string[] {
  const args: string[] = [];
  let current = "";
  // True once the current token has started (covers quoted empty strings,
  // which previously were silently dropped).
  let hasToken = false;
  let inQuote: string | null = null;
  for (const char of argsString) {
    if (inQuote) {
      if (char === inQuote) {
        inQuote = null; // closing quote
      } else {
        current += char;
      }
    } else if (char === '"' || char === "'") {
      inQuote = char;
      hasToken = true; // even "" counts as an argument
    } else if (char === " " || char === "\t") {
      if (hasToken) {
        args.push(current);
        current = "";
        hasToken = false;
      }
    } else {
      current += char;
      hasToken = true;
    }
  }
  // Flush the trailing token (also handles an unterminated quote's content).
  if (hasToken) {
    args.push(current);
  }
  return args;
}
/**
 * Substitute argument placeholders in template content
 * Supports:
 * - $1, $2, ... for positional args
 * - $@ and $ARGUMENTS for all args
 * - ${@:N} for args from Nth onwards (bash-style slicing)
 * - ${@:N:L} for L args starting from Nth
 *
 * All placeholders are replaced in a SINGLE pass over the template, so
 * argument values containing patterns like $1, $@, or $ARGUMENTS are never
 * recursively substituted (the previous multi-pass implementation could
 * re-expand placeholder-like text injected by an earlier pass).
 *
 * @param content Template text containing placeholders.
 * @param args    Positional argument values (1-indexed in the template).
 * @returns Template with placeholders expanded.
 */
export function substituteArgs(content: string, args: string[]): string {
  const allArgs = args.join(" ");
  // One combined alternation: slice forms first (so "${@:2}" isn't split),
  // then positionals, then the two "all args" spellings.
  return content.replace(
    /\$\{@:(\d+)(?::(\d+))?\}|\$(\d+)|\$ARGUMENTS|\$@/g,
    (_match, startStr, lengthStr, posStr) => {
      if (posStr !== undefined) {
        // $N — positional, 1-indexed; missing args expand to "".
        const index = parseInt(posStr, 10) - 1;
        return args[index] ?? "";
      }
      if (startStr !== undefined) {
        // ${@:start} or ${@:start:length}, bash-style slicing.
        let start = parseInt(startStr, 10) - 1; // user is 1-indexed
        if (start < 0) start = 0; // bash convention: treat 0 as 1
        if (lengthStr !== undefined) {
          const length = parseInt(lengthStr, 10);
          return args.slice(start, start + length).join(" ");
        }
        return args.slice(start).join(" ");
      }
      // $ARGUMENTS or $@ — all args joined.
      return allArgs;
    },
  );
}
/**
 * Load a single markdown file as a PromptTemplate.
 * The name is the file's basename without ".md"; the description comes from
 * frontmatter, or falls back to the first non-empty body line (truncated to
 * 60 chars), and always gets the source label appended.
 * Returns null when the file cannot be read or parsed.
 */
function loadTemplateFromFile(filePath: string, source: string, sourceLabel: string): PromptTemplate | null {
  try {
    const raw = readFileSync(filePath, "utf-8");
    const { frontmatter, body } = parseFrontmatter<Record<string, string>>(raw);
    const name = basename(filePath).replace(/\.md$/, "");
    let description = frontmatter.description || "";
    if (!description) {
      // Fallback: summarize with the first non-blank line of the body.
      const summaryLine = body.split("\n").find((line) => line.trim());
      if (summaryLine) {
        description = summaryLine.length > 60 ? `${summaryLine.slice(0, 60)}...` : summaryLine;
      }
    }
    description = description ? `${description} ${sourceLabel}` : sourceLabel;
    return { name, description, content: body, source, filePath };
  } catch {
    return null;
  }
}
/**
 * Scan a directory for .md files (non-recursive) and load them as prompt templates.
 * Symlinks are followed (broken symlinks are skipped); read failures return
 * whatever was collected so far.
 */
function loadTemplatesFromDir(dir: string, source: string, sourceLabel: string): PromptTemplate[] {
  const found: PromptTemplate[] = [];
  if (!existsSync(dir)) {
    return found;
  }
  try {
    for (const entry of readdirSync(dir, { withFileTypes: true })) {
      if (!entry.name.endsWith(".md")) {
        continue;
      }
      const fullPath = join(dir, entry.name);
      // Resolve symlinks to check they actually point at a file.
      let isRegularFile = entry.isFile();
      if (entry.isSymbolicLink()) {
        try {
          isRegularFile = statSync(fullPath).isFile();
        } catch {
          continue; // broken symlink
        }
      }
      if (!isRegularFile) {
        continue;
      }
      const template = loadTemplateFromFile(fullPath, source, sourceLabel);
      if (template) {
        found.push(template);
      }
    }
  } catch {
    return found;
  }
  return found;
}
/** Options controlling where loadPromptTemplates searches for templates. */
export interface LoadPromptTemplatesOptions {
  /** Working directory for project-local templates. Default: process.cwd() */
  cwd?: string;
  /** Agent config directory for global templates. Default: from getPromptsDir() */
  agentDir?: string;
  /** Explicit prompt template paths (files or directories) */
  promptPaths?: string[];
  /** Include default prompt directories. Default: true */
  includeDefaults?: boolean;
}
/**
 * Trim whitespace and expand a leading "~" to the user's home directory.
 * Note: "~name" expands to <home>/name (NOT user "name"'s home, as bash would).
 */
function normalizePath(input: string): string {
  const value = input.trim();
  if (value === "~") {
    return homedir();
  }
  if (value.startsWith("~/")) {
    return join(homedir(), value.slice(2));
  }
  if (value.startsWith("~")) {
    return join(homedir(), value.slice(1));
  }
  return value;
}
function resolvePromptPath(p: string, cwd: string): string {
const normalized = normalizePath(p);
return isAbsolute(normalized) ? normalized : resolve(cwd, normalized);
}
/** Build a "(path:<stem>)" source label from the path's basename (sans .md). */
function buildPathSourceLabel(p: string): string {
  const stem = basename(p).replace(/\.md$/, "");
  return `(path:${stem || "path"})`;
}
/**
 * Load all prompt templates from:
 * 1. Global: agentDir/prompts/
 * 2. Project: cwd/{CONFIG_DIR_NAME}/prompts/
 * 3. Explicit prompt paths
 *
 * NOTE(review): duplicate template names are NOT de-duplicated here;
 * expandPromptTemplate's find() picks the first match, so earlier sources
 * (user before project before explicit paths) effectively win — confirm this
 * precedence is intended.
 */
export function loadPromptTemplates(options: LoadPromptTemplatesOptions = {}): PromptTemplate[] {
  const resolvedCwd = options.cwd ?? process.cwd();
  const resolvedAgentDir = options.agentDir ?? getPromptsDir();
  const promptPaths = options.promptPaths ?? [];
  const includeDefaults = options.includeDefaults ?? true;
  const templates: PromptTemplate[] = [];
  if (includeDefaults) {
    // 1. Load global templates from agentDir/prompts/
    // Note: if agentDir is provided, it should be the agent dir, not the prompts dir
    const globalPromptsDir = options.agentDir ? join(options.agentDir, "prompts") : resolvedAgentDir;
    templates.push(...loadTemplatesFromDir(globalPromptsDir, "user", "(user)"));
    // 2. Load project templates from cwd/{CONFIG_DIR_NAME}/prompts/
    const projectPromptsDir = resolve(resolvedCwd, CONFIG_DIR_NAME, "prompts");
    templates.push(...loadTemplatesFromDir(projectPromptsDir, "project", "(project)"));
  }
  // Recomputed (same as above) so explicit paths can be classified by location
  // even when includeDefaults is false.
  const userPromptsDir = options.agentDir ? join(options.agentDir, "prompts") : resolvedAgentDir;
  const projectPromptsDir = resolve(resolvedCwd, CONFIG_DIR_NAME, "prompts");
  // True when `target` equals `root` or lies underneath it (path-separator safe).
  const isUnderPath = (target: string, root: string): boolean => {
    const normalizedRoot = resolve(root);
    if (target === normalizedRoot) {
      return true;
    }
    const prefix = normalizedRoot.endsWith(sep) ? normalizedRoot : `${normalizedRoot}${sep}`;
    return target.startsWith(prefix);
  };
  // Classify an explicit path: when defaults are disabled, paths inside the
  // default dirs keep their "user"/"project" provenance; otherwise "path".
  const getSourceInfo = (resolvedPath: string): { source: string; label: string } => {
    if (!includeDefaults) {
      if (isUnderPath(resolvedPath, userPromptsDir)) {
        return { source: "user", label: "(user)" };
      }
      if (isUnderPath(resolvedPath, projectPromptsDir)) {
        return { source: "project", label: "(project)" };
      }
    }
    return { source: "path", label: buildPathSourceLabel(resolvedPath) };
  };
  // 3. Load explicit prompt paths
  for (const rawPath of promptPaths) {
    const resolvedPath = resolvePromptPath(rawPath, resolvedCwd);
    if (!existsSync(resolvedPath)) {
      continue;
    }
    try {
      const stats = statSync(resolvedPath);
      const { source, label } = getSourceInfo(resolvedPath);
      if (stats.isDirectory()) {
        templates.push(...loadTemplatesFromDir(resolvedPath, source, label));
      } else if (stats.isFile() && resolvedPath.endsWith(".md")) {
        const template = loadTemplateFromFile(resolvedPath, source, label);
        if (template) {
          templates.push(template);
        }
      }
    } catch {
      // Ignore read failures
    }
  }
  return templates;
}
/**
 * Expand a prompt template if it matches a template name.
 * Input of the form "/name arg1 arg2" expands the template called "name" with
 * the parsed arguments substituted. Returns the original text when it does
 * not start with "/" or when no template matches.
 */
export function expandPromptTemplate(text: string, templates: PromptTemplate[]): string {
  if (!text.startsWith("/")) return text;
  const firstSpace = text.indexOf(" ");
  const name = firstSpace < 0 ? text.slice(1) : text.slice(1, firstSpace);
  const rawArgs = firstSpace < 0 ? "" : text.slice(firstSpace + 1);
  const match = templates.find((t) => t.name === name);
  if (!match) {
    return text;
  }
  return substituteArgs(match.content, parseCommandArgs(rawArgs));
}

View file

@ -0,0 +1,64 @@
/**
* Resolve configuration values that may be shell commands, environment variables, or literals.
* Used by auth-storage.ts and model-registry.ts.
*/
import { execSync } from "child_process";
// Cache for shell command results (persists for process lifetime)
const commandResultCache = new Map<string, string | undefined>();
/**
 * Resolve a config value (API key, header value, etc.) to an actual value.
 * - If starts with "!", executes the rest as a shell command and uses stdout (cached)
 * - Otherwise checks environment variable first, then treats as literal (not cached)
 */
export function resolveConfigValue(config: string): string | undefined {
  if (config.startsWith("!")) {
    return executeCommand(config);
  }
  // Env var lookup wins; an unset/empty var falls back to the literal string.
  return process.env[config] || config;
}
/**
 * Run the shell command embedded in a "!command" config string and return its
 * trimmed stdout (undefined on failure or empty output). Results — including
 * failures — are cached for the process lifetime.
 */
function executeCommand(commandConfig: string): string | undefined {
  if (commandResultCache.has(commandConfig)) {
    return commandResultCache.get(commandConfig);
  }
  const shellCommand = commandConfig.slice(1); // drop the leading "!"
  let value: string | undefined;
  try {
    const stdout = execSync(shellCommand, {
      encoding: "utf-8",
      timeout: 10000,
      stdio: ["ignore", "pipe", "ignore"],
    });
    value = stdout.trim() || undefined;
  } catch {
    value = undefined;
  }
  commandResultCache.set(commandConfig, value);
  return value;
}
/**
* Resolve all header values using the same resolution logic as API keys.
*/
export function resolveHeaders(headers: Record<string, string> | undefined): Record<string, string> | undefined {
if (!headers) return undefined;
const resolved: Record<string, string> = {};
for (const [key, value] of Object.entries(headers)) {
const resolvedValue = resolveConfigValue(value);
if (resolvedValue) {
resolved[key] = resolvedValue;
}
}
return Object.keys(resolved).length > 0 ? resolved : undefined;
}
/** Clear the config value command cache. Exported for testing. */
export function clearConfigValueCache(): void {
  // Forces subsequent "!command" configs to re-execute their shell commands.
  commandResultCache.clear();
}

View file

@ -0,0 +1,868 @@
import { existsSync, readdirSync, readFileSync, statSync } from "node:fs";
import { homedir } from "node:os";
import { join, resolve, sep } from "node:path";
import chalk from "chalk";
import { CONFIG_DIR_NAME, getAgentDir } from "../config.js";
import { loadThemeFromPath, type Theme } from "../modes/interactive/theme/theme.js";
import type { ResourceDiagnostic } from "./diagnostics.js";
export type { ResourceCollision, ResourceDiagnostic } from "./diagnostics.js";
import { createEventBus, type EventBus } from "./event-bus.js";
import { createExtensionRuntime, loadExtensionFromFactory, loadExtensions } from "./extensions/loader.js";
import type { Extension, ExtensionFactory, ExtensionRuntime, LoadExtensionsResult } from "./extensions/types.js";
import { DefaultPackageManager, type PathMetadata } from "./package-manager.js";
import type { PromptTemplate } from "./prompt-templates.js";
import { loadPromptTemplates } from "./prompt-templates.js";
import { SettingsManager } from "./settings-manager.js";
import type { Skill } from "./skills.js";
import { loadSkills } from "./skills.js";
/** Extra resource paths (with provenance metadata) contributed at runtime via extendResources. */
export interface ResourceExtensionPaths {
  skillPaths?: Array<{ path: string; metadata: PathMetadata }>;
  promptPaths?: Array<{ path: string; metadata: PathMetadata }>;
  themePaths?: Array<{ path: string; metadata: PathMetadata }>;
}
/** Read-only access to loaded resources (extensions, skills, prompts, themes, context files) plus extend/reload hooks. */
export interface ResourceLoader {
  getExtensions(): LoadExtensionsResult;
  getSkills(): { skills: Skill[]; diagnostics: ResourceDiagnostic[] };
  getPrompts(): { prompts: PromptTemplate[]; diagnostics: ResourceDiagnostic[] };
  getThemes(): { themes: Theme[]; diagnostics: ResourceDiagnostic[] };
  getAgentsFiles(): { agentsFiles: Array<{ path: string; content: string }> };
  getSystemPrompt(): string | undefined;
  getAppendSystemPrompt(): string[];
  getPathMetadata(): Map<string, PathMetadata>;
  /** Register additional skill/prompt/theme paths contributed at runtime. */
  extendResources(paths: ResourceExtensionPaths): void;
  /** Re-scan all resource locations. */
  reload(): Promise<void>;
}
/**
 * Interpret a prompt option as either a file path or literal text.
 * When `input` names an existing file its contents are returned; an unreadable
 * file warns and falls back to the literal string. Undefined/empty input
 * yields undefined.
 */
function resolvePromptInput(input: string | undefined, description: string): string | undefined {
  if (!input) {
    return undefined;
  }
  if (!existsSync(input)) {
    return input; // treat as literal prompt text
  }
  try {
    return readFileSync(input, "utf-8");
  } catch (error) {
    console.error(chalk.yellow(`Warning: Could not read ${description} file ${input}: ${error}`));
    return input;
  }
}
/**
 * Read the first context file found in `dir`, preferring AGENTS.md over
 * CLAUDE.md. Unreadable candidates warn and the next candidate is tried.
 * Returns null when neither exists (or neither is readable).
 */
function loadContextFileFromDir(dir: string): { path: string; content: string } | null {
  for (const filename of ["AGENTS.md", "CLAUDE.md"]) {
    const filePath = join(dir, filename);
    if (!existsSync(filePath)) {
      continue;
    }
    try {
      return { path: filePath, content: readFileSync(filePath, "utf-8") };
    } catch (error) {
      console.error(chalk.yellow(`Warning: Could not read ${filePath}: ${error}`));
    }
  }
  return null;
}
function loadProjectContextFiles(
options: { cwd?: string; agentDir?: string } = {},
): Array<{ path: string; content: string }> {
const resolvedCwd = options.cwd ?? process.cwd();
const resolvedAgentDir = options.agentDir ?? getAgentDir();
const contextFiles: Array<{ path: string; content: string }> = [];
const seenPaths = new Set<string>();
const globalContext = loadContextFileFromDir(resolvedAgentDir);
if (globalContext) {
contextFiles.push(globalContext);
seenPaths.add(globalContext.path);
}
const ancestorContextFiles: Array<{ path: string; content: string }> = [];
let currentDir = resolvedCwd;
const root = resolve("/");
while (true) {
const contextFile = loadContextFileFromDir(currentDir);
if (contextFile && !seenPaths.has(contextFile.path)) {
ancestorContextFiles.unshift(contextFile);
seenPaths.add(contextFile.path);
}
if (currentDir === root) break;
const parentDir = resolve(currentDir, "..");
if (parentDir === currentDir) break;
currentDir = parentDir;
}
contextFiles.push(...ancestorContextFiles);
return contextFiles;
}
/**
 * Options for DefaultResourceLoader.
 *
 * The no* flags disable whole resource categories; the *Override hooks let
 * callers post-process each loaded resource set. systemPrompt/appendSystemPrompt
 * presumably accept either literal text or a file path (see resolvePromptInput)
 * — confirm at the call sites.
 */
export interface DefaultResourceLoaderOptions {
  // Base directories (defaults: process.cwd() / getAgentDir()).
  cwd?: string;
  agentDir?: string;
  settingsManager?: SettingsManager;
  eventBus?: EventBus;
  // Extra search paths merged with the defaults.
  additionalExtensionPaths?: string[];
  additionalSkillPaths?: string[];
  additionalPromptTemplatePaths?: string[];
  additionalThemePaths?: string[];
  extensionFactories?: ExtensionFactory[];
  // Category kill-switches.
  noExtensions?: boolean;
  noSkills?: boolean;
  noPromptTemplates?: boolean;
  noThemes?: boolean;
  systemPrompt?: string;
  appendSystemPrompt?: string;
  // Post-processing hooks applied after each resource set is loaded.
  extensionsOverride?: (base: LoadExtensionsResult) => LoadExtensionsResult;
  skillsOverride?: (base: { skills: Skill[]; diagnostics: ResourceDiagnostic[] }) => {
    skills: Skill[];
    diagnostics: ResourceDiagnostic[];
  };
  promptsOverride?: (base: { prompts: PromptTemplate[]; diagnostics: ResourceDiagnostic[] }) => {
    prompts: PromptTemplate[];
    diagnostics: ResourceDiagnostic[];
  };
  themesOverride?: (base: { themes: Theme[]; diagnostics: ResourceDiagnostic[] }) => {
    themes: Theme[];
    diagnostics: ResourceDiagnostic[];
  };
  agentsFilesOverride?: (base: { agentsFiles: Array<{ path: string; content: string }> }) => {
    agentsFiles: Array<{ path: string; content: string }>;
  };
  systemPromptOverride?: (base: string | undefined) => string | undefined;
  appendSystemPromptOverride?: (base: string[]) => string[];
}
/**
 * Default implementation of ResourceLoader.
 *
 * Discovers and caches extensions, skills, prompt templates, themes,
 * project context (AGENTS) files, and system-prompt content from the
 * package manager, CLI-supplied paths, and well-known config directories.
 * `reload()` (re)populates every cached result; the `get*()` accessors
 * return the state from the most recent load. `extendResources()` lets
 * extensions append additional resource paths incrementally after a load.
 */
export class DefaultResourceLoader implements ResourceLoader {
  // --- collaborators / configuration (fixed at construction) ---
  private cwd: string;
  private agentDir: string;
  private settingsManager: SettingsManager;
  private eventBus: EventBus;
  private packageManager: DefaultPackageManager;
  // Extra lookup paths supplied by the caller (e.g. CLI flags).
  private additionalExtensionPaths: string[];
  private additionalSkillPaths: string[];
  private additionalPromptTemplatePaths: string[];
  private additionalThemePaths: string[];
  private extensionFactories: ExtensionFactory[];
  // Opt-out switches: suppress discovery of the corresponding resource kind
  // (CLI/additional paths are still honored — see reload()).
  private noExtensions: boolean;
  private noSkills: boolean;
  private noPromptTemplates: boolean;
  private noThemes: boolean;
  private systemPromptSource?: string;
  private appendSystemPromptSource?: string;
  // Optional caller hooks that can transform each loaded result set.
  private extensionsOverride?: (base: LoadExtensionsResult) => LoadExtensionsResult;
  private skillsOverride?: (base: { skills: Skill[]; diagnostics: ResourceDiagnostic[] }) => {
    skills: Skill[];
    diagnostics: ResourceDiagnostic[];
  };
  private promptsOverride?: (base: { prompts: PromptTemplate[]; diagnostics: ResourceDiagnostic[] }) => {
    prompts: PromptTemplate[];
    diagnostics: ResourceDiagnostic[];
  };
  private themesOverride?: (base: { themes: Theme[]; diagnostics: ResourceDiagnostic[] }) => {
    themes: Theme[];
    diagnostics: ResourceDiagnostic[];
  };
  private agentsFilesOverride?: (base: { agentsFiles: Array<{ path: string; content: string }> }) => {
    agentsFiles: Array<{ path: string; content: string }>;
  };
  private systemPromptOverride?: (base: string | undefined) => string | undefined;
  private appendSystemPromptOverride?: (base: string[]) => string[];
  // --- cached results of the most recent reload()/extendResources() ---
  private extensionsResult: LoadExtensionsResult;
  private skills: Skill[];
  private skillDiagnostics: ResourceDiagnostic[];
  private prompts: PromptTemplate[];
  private promptDiagnostics: ResourceDiagnostic[];
  private themes: Theme[];
  private themeDiagnostics: ResourceDiagnostic[];
  private agentsFiles: Array<{ path: string; content: string }>;
  private systemPrompt?: string;
  private appendSystemPrompt: string[];
  // Provenance (source/scope/origin) per resolved resource path. First
  // writer wins throughout this class: existing entries are never replaced.
  private pathMetadata: Map<string, PathMetadata>;
  // Path lists from the last load, kept so extendResources() can merge into them.
  private lastSkillPaths: string[];
  private lastPromptPaths: string[];
  private lastThemePaths: string[];
  constructor(options: DefaultResourceLoaderOptions) {
    this.cwd = options.cwd ?? process.cwd();
    this.agentDir = options.agentDir ?? getAgentDir();
    this.settingsManager = options.settingsManager ?? SettingsManager.create(this.cwd, this.agentDir);
    this.eventBus = options.eventBus ?? createEventBus();
    this.packageManager = new DefaultPackageManager({
      cwd: this.cwd,
      agentDir: this.agentDir,
      settingsManager: this.settingsManager,
    });
    this.additionalExtensionPaths = options.additionalExtensionPaths ?? [];
    this.additionalSkillPaths = options.additionalSkillPaths ?? [];
    this.additionalPromptTemplatePaths = options.additionalPromptTemplatePaths ?? [];
    this.additionalThemePaths = options.additionalThemePaths ?? [];
    this.extensionFactories = options.extensionFactories ?? [];
    this.noExtensions = options.noExtensions ?? false;
    this.noSkills = options.noSkills ?? false;
    this.noPromptTemplates = options.noPromptTemplates ?? false;
    this.noThemes = options.noThemes ?? false;
    this.systemPromptSource = options.systemPrompt;
    this.appendSystemPromptSource = options.appendSystemPrompt;
    this.extensionsOverride = options.extensionsOverride;
    this.skillsOverride = options.skillsOverride;
    this.promptsOverride = options.promptsOverride;
    this.themesOverride = options.themesOverride;
    this.agentsFilesOverride = options.agentsFilesOverride;
    this.systemPromptOverride = options.systemPromptOverride;
    this.appendSystemPromptOverride = options.appendSystemPromptOverride;
    // Empty caches until reload() runs.
    this.extensionsResult = { extensions: [], errors: [], runtime: createExtensionRuntime() };
    this.skills = [];
    this.skillDiagnostics = [];
    this.prompts = [];
    this.promptDiagnostics = [];
    this.themes = [];
    this.themeDiagnostics = [];
    this.agentsFiles = [];
    this.appendSystemPrompt = [];
    this.pathMetadata = new Map();
    this.lastSkillPaths = [];
    this.lastPromptPaths = [];
    this.lastThemePaths = [];
  }
  /** Extensions (plus load errors and runtime) from the last reload. */
  getExtensions(): LoadExtensionsResult {
    return this.extensionsResult;
  }
  /** Skills and their load diagnostics from the last reload. */
  getSkills(): { skills: Skill[]; diagnostics: ResourceDiagnostic[] } {
    return { skills: this.skills, diagnostics: this.skillDiagnostics };
  }
  /** Prompt templates and their load diagnostics from the last reload. */
  getPrompts(): { prompts: PromptTemplate[]; diagnostics: ResourceDiagnostic[] } {
    return { prompts: this.prompts, diagnostics: this.promptDiagnostics };
  }
  /** Themes and their load diagnostics from the last reload. */
  getThemes(): { themes: Theme[]; diagnostics: ResourceDiagnostic[] } {
    return { themes: this.themes, diagnostics: this.themeDiagnostics };
  }
  /** Project context (AGENTS) files discovered by the last reload. */
  getAgentsFiles(): { agentsFiles: Array<{ path: string; content: string }> } {
    return { agentsFiles: this.agentsFiles };
  }
  /** Resolved system prompt text, or undefined when none was found. */
  getSystemPrompt(): string | undefined {
    return this.systemPrompt;
  }
  /** Resolved append-system-prompt segments (possibly empty). */
  getAppendSystemPrompt(): string[] {
    return this.appendSystemPrompt;
  }
  /** Provenance metadata keyed by resolved resource path. */
  getPathMetadata(): Map<string, PathMetadata> {
    return this.pathMetadata;
  }
  /**
   * Append additional skill/prompt/theme paths after a load, re-running
   * the corresponding loader for any kind that received new entries.
   */
  extendResources(paths: ResourceExtensionPaths): void {
    const skillPaths = this.normalizeExtensionPaths(paths.skillPaths ?? []);
    const promptPaths = this.normalizeExtensionPaths(paths.promptPaths ?? []);
    const themePaths = this.normalizeExtensionPaths(paths.themePaths ?? []);
    if (skillPaths.length > 0) {
      this.lastSkillPaths = this.mergePaths(
        this.lastSkillPaths,
        skillPaths.map((entry) => entry.path),
      );
      this.updateSkillsFromPaths(this.lastSkillPaths, skillPaths);
    }
    if (promptPaths.length > 0) {
      this.lastPromptPaths = this.mergePaths(
        this.lastPromptPaths,
        promptPaths.map((entry) => entry.path),
      );
      this.updatePromptsFromPaths(this.lastPromptPaths, promptPaths);
    }
    if (themePaths.length > 0) {
      this.lastThemePaths = this.mergePaths(
        this.lastThemePaths,
        themePaths.map((entry) => entry.path),
      );
      this.updateThemesFromPaths(this.lastThemePaths, themePaths);
    }
  }
  /**
   * Full (re)load: resolve paths via the package manager and CLI sources,
   * rebuild path metadata from scratch, load extensions (including inline
   * factories), skills, prompts, themes, AGENTS files, and system prompts.
   */
  async reload(): Promise<void> {
    const resolvedPaths = await this.packageManager.resolve();
    const cliExtensionPaths = await this.packageManager.resolveExtensionSources(this.additionalExtensionPaths, {
      temporary: true,
    });
    // Helper to extract enabled paths and store metadata
    const getEnabledResources = (
      resources: Array<{ path: string; enabled: boolean; metadata: PathMetadata }>,
    ): Array<{ path: string; enabled: boolean; metadata: PathMetadata }> => {
      for (const r of resources) {
        if (!this.pathMetadata.has(r.path)) {
          this.pathMetadata.set(r.path, r.metadata);
        }
      }
      return resources.filter((r) => r.enabled);
    };
    const getEnabledPaths = (
      resources: Array<{ path: string; enabled: boolean; metadata: PathMetadata }>,
    ): string[] => getEnabledResources(resources).map((r) => r.path);
    // Store metadata and get enabled paths
    // (metadata map is reset here, before the helpers above are invoked,
    // so each reload starts with fresh provenance)
    this.pathMetadata = new Map();
    const enabledExtensions = getEnabledPaths(resolvedPaths.extensions);
    const enabledSkillResources = getEnabledResources(resolvedPaths.skills);
    const enabledPrompts = getEnabledPaths(resolvedPaths.prompts);
    const enabledThemes = getEnabledPaths(resolvedPaths.themes);
    // For auto-discovered/package skill directories, point at the contained
    // SKILL.md file when present (carrying over the directory's metadata).
    const mapSkillPath = (resource: { path: string; metadata: PathMetadata }): string => {
      if (resource.metadata.source !== "auto" && resource.metadata.origin !== "package") {
        return resource.path;
      }
      try {
        const stats = statSync(resource.path);
        if (!stats.isDirectory()) {
          return resource.path;
        }
      } catch {
        return resource.path;
      }
      const skillFile = join(resource.path, "SKILL.md");
      if (existsSync(skillFile)) {
        if (!this.pathMetadata.has(skillFile)) {
          this.pathMetadata.set(skillFile, resource.metadata);
        }
        return skillFile;
      }
      return resource.path;
    };
    const enabledSkills = enabledSkillResources.map(mapSkillPath);
    // Add CLI paths metadata
    for (const r of cliExtensionPaths.extensions) {
      if (!this.pathMetadata.has(r.path)) {
        this.pathMetadata.set(r.path, { source: "cli", scope: "temporary", origin: "top-level" });
      }
    }
    for (const r of cliExtensionPaths.skills) {
      if (!this.pathMetadata.has(r.path)) {
        this.pathMetadata.set(r.path, { source: "cli", scope: "temporary", origin: "top-level" });
      }
    }
    const cliEnabledExtensions = getEnabledPaths(cliExtensionPaths.extensions);
    const cliEnabledSkills = getEnabledPaths(cliExtensionPaths.skills);
    const cliEnabledPrompts = getEnabledPaths(cliExtensionPaths.prompts);
    const cliEnabledThemes = getEnabledPaths(cliExtensionPaths.themes);
    // CLI paths are honored even under the no* switches; discovered paths
    // are only included when the switch is off. CLI paths come first so
    // they win mergePaths dedupe.
    const extensionPaths = this.noExtensions
      ? cliEnabledExtensions
      : this.mergePaths(cliEnabledExtensions, enabledExtensions);
    const extensionsResult = await loadExtensions(extensionPaths, this.cwd, this.eventBus);
    const inlineExtensions = await this.loadExtensionFactories(extensionsResult.runtime);
    extensionsResult.extensions.push(...inlineExtensions.extensions);
    extensionsResult.errors.push(...inlineExtensions.errors);
    // Detect extension conflicts (tools, commands, flags with same names from different extensions)
    // Keep all extensions loaded. Conflicts are reported as diagnostics, and precedence is handled by load order.
    const conflicts = this.detectExtensionConflicts(extensionsResult.extensions);
    for (const conflict of conflicts) {
      extensionsResult.errors.push({ path: conflict.path, error: conflict.message });
    }
    this.extensionsResult = this.extensionsOverride ? this.extensionsOverride(extensionsResult) : extensionsResult;
    const skillPaths = this.noSkills
      ? this.mergePaths(cliEnabledSkills, this.additionalSkillPaths)
      : this.mergePaths([...enabledSkills, ...cliEnabledSkills], this.additionalSkillPaths);
    this.lastSkillPaths = skillPaths;
    this.updateSkillsFromPaths(skillPaths);
    const promptPaths = this.noPromptTemplates
      ? this.mergePaths(cliEnabledPrompts, this.additionalPromptTemplatePaths)
      : this.mergePaths([...enabledPrompts, ...cliEnabledPrompts], this.additionalPromptTemplatePaths);
    this.lastPromptPaths = promptPaths;
    this.updatePromptsFromPaths(promptPaths);
    const themePaths = this.noThemes
      ? this.mergePaths(cliEnabledThemes, this.additionalThemePaths)
      : this.mergePaths([...enabledThemes, ...cliEnabledThemes], this.additionalThemePaths);
    this.lastThemePaths = themePaths;
    this.updateThemesFromPaths(themePaths);
    for (const extension of this.extensionsResult.extensions) {
      this.addDefaultMetadataForPath(extension.path);
    }
    const agentsFiles = { agentsFiles: loadProjectContextFiles({ cwd: this.cwd, agentDir: this.agentDir }) };
    const resolvedAgentsFiles = this.agentsFilesOverride ? this.agentsFilesOverride(agentsFiles) : agentsFiles;
    this.agentsFiles = resolvedAgentsFiles.agentsFiles;
    // Explicit source wins over file discovery for both prompt channels.
    const baseSystemPrompt = resolvePromptInput(
      this.systemPromptSource ?? this.discoverSystemPromptFile(),
      "system prompt",
    );
    this.systemPrompt = this.systemPromptOverride ? this.systemPromptOverride(baseSystemPrompt) : baseSystemPrompt;
    const appendSource = this.appendSystemPromptSource ?? this.discoverAppendSystemPromptFile();
    const resolvedAppend = resolvePromptInput(appendSource, "append system prompt");
    const baseAppend = resolvedAppend ? [resolvedAppend] : [];
    this.appendSystemPrompt = this.appendSystemPromptOverride
      ? this.appendSystemPromptOverride(baseAppend)
      : baseAppend;
  }
  /** Resolve each entry's path (tilde expansion + cwd-relative), keeping its metadata. */
  private normalizeExtensionPaths(
    entries: Array<{ path: string; metadata: PathMetadata }>,
  ): Array<{ path: string; metadata: PathMetadata }> {
    return entries.map((entry) => ({
      path: this.resolveResourcePath(entry.path),
      metadata: entry.metadata,
    }));
  }
  /**
   * Load skills from the given paths, apply the caller override, and record
   * metadata. `extensionPaths` carries provenance for extension-supplied paths.
   */
  private updateSkillsFromPaths(
    skillPaths: string[],
    extensionPaths: Array<{ path: string; metadata: PathMetadata }> = [],
  ): void {
    let skillsResult: { skills: Skill[]; diagnostics: ResourceDiagnostic[] };
    if (this.noSkills && skillPaths.length === 0) {
      skillsResult = { skills: [], diagnostics: [] };
    } else {
      skillsResult = loadSkills({
        cwd: this.cwd,
        agentDir: this.agentDir,
        skillPaths,
        includeDefaults: false,
      });
    }
    const resolvedSkills = this.skillsOverride ? this.skillsOverride(skillsResult) : skillsResult;
    this.skills = resolvedSkills.skills;
    this.skillDiagnostics = resolvedSkills.diagnostics;
    this.applyExtensionMetadata(
      extensionPaths,
      this.skills.map((skill) => skill.filePath),
    );
    for (const skill of this.skills) {
      this.addDefaultMetadataForPath(skill.filePath);
    }
  }
  /** Same shape as updateSkillsFromPaths, for prompt templates (with name dedupe). */
  private updatePromptsFromPaths(
    promptPaths: string[],
    extensionPaths: Array<{ path: string; metadata: PathMetadata }> = [],
  ): void {
    let promptsResult: { prompts: PromptTemplate[]; diagnostics: ResourceDiagnostic[] };
    if (this.noPromptTemplates && promptPaths.length === 0) {
      promptsResult = { prompts: [], diagnostics: [] };
    } else {
      const allPrompts = loadPromptTemplates({
        cwd: this.cwd,
        agentDir: this.agentDir,
        promptPaths,
        includeDefaults: false,
      });
      promptsResult = this.dedupePrompts(allPrompts);
    }
    const resolvedPrompts = this.promptsOverride ? this.promptsOverride(promptsResult) : promptsResult;
    this.prompts = resolvedPrompts.prompts;
    this.promptDiagnostics = resolvedPrompts.diagnostics;
    this.applyExtensionMetadata(
      extensionPaths,
      this.prompts.map((prompt) => prompt.filePath),
    );
    for (const prompt of this.prompts) {
      this.addDefaultMetadataForPath(prompt.filePath);
    }
  }
  /** Same shape as updateSkillsFromPaths, for themes (with name dedupe). */
  private updateThemesFromPaths(
    themePaths: string[],
    extensionPaths: Array<{ path: string; metadata: PathMetadata }> = [],
  ): void {
    let themesResult: { themes: Theme[]; diagnostics: ResourceDiagnostic[] };
    if (this.noThemes && themePaths.length === 0) {
      themesResult = { themes: [], diagnostics: [] };
    } else {
      const loaded = this.loadThemes(themePaths, false);
      const deduped = this.dedupeThemes(loaded.themes);
      themesResult = { themes: deduped.themes, diagnostics: [...loaded.diagnostics, ...deduped.diagnostics] };
    }
    const resolvedThemes = this.themesOverride ? this.themesOverride(themesResult) : themesResult;
    this.themes = resolvedThemes.themes;
    this.themeDiagnostics = resolvedThemes.diagnostics;
    // Built-in themes have no sourcePath and get no metadata entry.
    const themePathsWithSource = this.themes.flatMap((theme) => (theme.sourcePath ? [theme.sourcePath] : []));
    this.applyExtensionMetadata(extensionPaths, themePathsWithSource);
    for (const theme of this.themes) {
      if (theme.sourcePath) {
        this.addDefaultMetadataForPath(theme.sourcePath);
      }
    }
  }
  /**
   * Record metadata for extension-supplied roots, then propagate each root's
   * metadata to loaded resources that live at or under that root
   * (first-writer-wins; already-recorded paths are left untouched).
   */
  private applyExtensionMetadata(
    extensionPaths: Array<{ path: string; metadata: PathMetadata }>,
    resourcePaths: string[],
  ): void {
    if (extensionPaths.length === 0) {
      return;
    }
    const normalized = extensionPaths.map((entry) => ({
      path: resolve(entry.path),
      metadata: entry.metadata,
    }));
    for (const entry of normalized) {
      if (!this.pathMetadata.has(entry.path)) {
        this.pathMetadata.set(entry.path, entry.metadata);
      }
    }
    for (const resourcePath of resourcePaths) {
      const normalizedResourcePath = resolve(resourcePath);
      if (this.pathMetadata.has(normalizedResourcePath) || this.pathMetadata.has(resourcePath)) {
        continue;
      }
      const match = normalized.find(
        (entry) =>
          normalizedResourcePath === entry.path || normalizedResourcePath.startsWith(`${entry.path}${sep}`),
      );
      if (match) {
        this.pathMetadata.set(normalizedResourcePath, match.metadata);
      }
    }
  }
  /** Concatenate and dedupe path lists; entries are resolved before comparison, primary wins. */
  private mergePaths(primary: string[], additional: string[]): string[] {
    const merged: string[] = [];
    const seen = new Set<string>();
    for (const p of [...primary, ...additional]) {
      const resolved = this.resolveResourcePath(p);
      if (seen.has(resolved)) continue;
      seen.add(resolved);
      merged.push(resolved);
    }
    return merged;
  }
  /**
   * Expand a leading tilde and resolve against cwd.
   * NOTE(review): `~name` expands to `join(homedir(), "name")`, not to user
   * `name`'s home as a shell would — presumably intentional; confirm.
   */
  private resolveResourcePath(p: string): string {
    const trimmed = p.trim();
    let expanded = trimmed;
    if (trimmed === "~") {
      expanded = homedir();
    } else if (trimmed.startsWith("~/")) {
      expanded = join(homedir(), trimmed.slice(2));
    } else if (trimmed.startsWith("~")) {
      expanded = join(homedir(), trimmed.slice(1));
    }
    return resolve(this.cwd, expanded);
  }
  /**
   * Load themes from the given paths (directories are scanned for .json
   * files; files must end in .json). Problems become warning diagnostics
   * rather than errors.
   */
  private loadThemes(
    paths: string[],
    includeDefaults: boolean = true,
  ): {
    themes: Theme[];
    diagnostics: ResourceDiagnostic[];
  } {
    const themes: Theme[] = [];
    const diagnostics: ResourceDiagnostic[] = [];
    if (includeDefaults) {
      const defaultDirs = [join(this.agentDir, "themes"), join(this.cwd, CONFIG_DIR_NAME, "themes")];
      for (const dir of defaultDirs) {
        this.loadThemesFromDir(dir, themes, diagnostics);
      }
    }
    for (const p of paths) {
      const resolved = resolve(this.cwd, p);
      if (!existsSync(resolved)) {
        diagnostics.push({ type: "warning", message: "theme path does not exist", path: resolved });
        continue;
      }
      try {
        const stats = statSync(resolved);
        if (stats.isDirectory()) {
          this.loadThemesFromDir(resolved, themes, diagnostics);
        } else if (stats.isFile() && resolved.endsWith(".json")) {
          this.loadThemeFromFile(resolved, themes, diagnostics);
        } else {
          diagnostics.push({ type: "warning", message: "theme path is not a json file", path: resolved });
        }
      } catch (error) {
        const message = error instanceof Error ? error.message : "failed to read theme path";
        diagnostics.push({ type: "warning", message, path: resolved });
      }
    }
    return { themes, diagnostics };
  }
  /** Scan one directory (non-recursively) for .json theme files, following symlinks to files. */
  private loadThemesFromDir(dir: string, themes: Theme[], diagnostics: ResourceDiagnostic[]): void {
    if (!existsSync(dir)) {
      return;
    }
    try {
      const entries = readdirSync(dir, { withFileTypes: true });
      for (const entry of entries) {
        let isFile = entry.isFile();
        if (entry.isSymbolicLink()) {
          try {
            isFile = statSync(join(dir, entry.name)).isFile();
          } catch {
            // Broken symlink — skip silently.
            continue;
          }
        }
        if (!isFile) {
          continue;
        }
        if (!entry.name.endsWith(".json")) {
          continue;
        }
        this.loadThemeFromFile(join(dir, entry.name), themes, diagnostics);
      }
    } catch (error) {
      const message = error instanceof Error ? error.message : "failed to read theme directory";
      diagnostics.push({ type: "warning", message, path: dir });
    }
  }
  /** Parse a single theme file; failures become warning diagnostics. */
  private loadThemeFromFile(filePath: string, themes: Theme[], diagnostics: ResourceDiagnostic[]): void {
    try {
      themes.push(loadThemeFromPath(filePath));
    } catch (error) {
      const message = error instanceof Error ? error.message : "failed to load theme";
      diagnostics.push({ type: "warning", message, path: filePath });
    }
  }
  /**
   * Instantiate extensions from in-process factories. Each factory gets a
   * synthetic `<inline:N>` path (1-based) used for diagnostics.
   */
  private async loadExtensionFactories(runtime: ExtensionRuntime): Promise<{
    extensions: Extension[];
    errors: Array<{ path: string; error: string }>;
  }> {
    const extensions: Extension[] = [];
    const errors: Array<{ path: string; error: string }> = [];
    for (const [index, factory] of this.extensionFactories.entries()) {
      const extensionPath = `<inline:${index + 1}>`;
      try {
        const extension = await loadExtensionFromFactory(factory, this.cwd, this.eventBus, runtime, extensionPath);
        extensions.push(extension);
      } catch (error) {
        const message = error instanceof Error ? error.message : "failed to load extension";
        errors.push({ path: extensionPath, error: message });
      }
    }
    return { extensions, errors };
  }
  /** Keep the first prompt per name; later duplicates become collision diagnostics. */
  private dedupePrompts(prompts: PromptTemplate[]): { prompts: PromptTemplate[]; diagnostics: ResourceDiagnostic[] } {
    const seen = new Map<string, PromptTemplate>();
    const diagnostics: ResourceDiagnostic[] = [];
    for (const prompt of prompts) {
      const existing = seen.get(prompt.name);
      if (existing) {
        diagnostics.push({
          type: "collision",
          message: `name "/${prompt.name}" collision`,
          path: prompt.filePath,
          collision: {
            resourceType: "prompt",
            name: prompt.name,
            winnerPath: existing.filePath,
            loserPath: prompt.filePath,
          },
        });
      } else {
        seen.set(prompt.name, prompt);
      }
    }
    return { prompts: Array.from(seen.values()), diagnostics };
  }
  /** Keep the first theme per name; later duplicates become collision diagnostics. */
  private dedupeThemes(themes: Theme[]): { themes: Theme[]; diagnostics: ResourceDiagnostic[] } {
    const seen = new Map<string, Theme>();
    const diagnostics: ResourceDiagnostic[] = [];
    for (const t of themes) {
      const name = t.name ?? "unnamed";
      const existing = seen.get(name);
      if (existing) {
        diagnostics.push({
          type: "collision",
          message: `name "${name}" collision`,
          path: t.sourcePath,
          collision: {
            resourceType: "theme",
            name,
            winnerPath: existing.sourcePath ?? "<builtin>",
            loserPath: t.sourcePath ?? "<builtin>",
          },
        });
      } else {
        seen.set(name, t);
      }
    }
    return { themes: Array.from(seen.values()), diagnostics };
  }
  /** Find SYSTEM.md — project config dir first, then the global agent dir. */
  private discoverSystemPromptFile(): string | undefined {
    const projectPath = join(this.cwd, CONFIG_DIR_NAME, "SYSTEM.md");
    if (existsSync(projectPath)) {
      return projectPath;
    }
    const globalPath = join(this.agentDir, "SYSTEM.md");
    if (existsSync(globalPath)) {
      return globalPath;
    }
    return undefined;
  }
  /** Find APPEND_SYSTEM.md — project config dir first, then the global agent dir. */
  private discoverAppendSystemPromptFile(): string | undefined {
    const projectPath = join(this.cwd, CONFIG_DIR_NAME, "APPEND_SYSTEM.md");
    if (existsSync(projectPath)) {
      return projectPath;
    }
    const globalPath = join(this.agentDir, "APPEND_SYSTEM.md");
    if (existsSync(globalPath)) {
      return globalPath;
    }
    return undefined;
  }
  /**
   * Assign default local-source metadata for a path under one of the known
   * agent-dir or project-config resource roots. Synthetic paths (e.g.
   * `<inline:1>`) and already-recorded paths are skipped.
   */
  private addDefaultMetadataForPath(filePath: string): void {
    if (!filePath || filePath.startsWith("<")) {
      return;
    }
    const normalizedPath = resolve(filePath);
    if (this.pathMetadata.has(normalizedPath) || this.pathMetadata.has(filePath)) {
      return;
    }
    const agentRoots = [
      join(this.agentDir, "skills"),
      join(this.agentDir, "prompts"),
      join(this.agentDir, "themes"),
      join(this.agentDir, "extensions"),
    ];
    const projectRoots = [
      join(this.cwd, CONFIG_DIR_NAME, "skills"),
      join(this.cwd, CONFIG_DIR_NAME, "prompts"),
      join(this.cwd, CONFIG_DIR_NAME, "themes"),
      join(this.cwd, CONFIG_DIR_NAME, "extensions"),
    ];
    for (const root of agentRoots) {
      if (this.isUnderPath(normalizedPath, root)) {
        this.pathMetadata.set(normalizedPath, { source: "local", scope: "user", origin: "top-level" });
        return;
      }
    }
    for (const root of projectRoots) {
      if (this.isUnderPath(normalizedPath, root)) {
        this.pathMetadata.set(normalizedPath, { source: "local", scope: "project", origin: "top-level" });
        return;
      }
    }
  }
  /** True when target equals root or is nested anywhere beneath it (separator-safe prefix check). */
  private isUnderPath(target: string, root: string): boolean {
    const normalizedRoot = resolve(root);
    if (target === normalizedRoot) {
      return true;
    }
    const prefix = normalizedRoot.endsWith(sep) ? normalizedRoot : `${normalizedRoot}${sep}`;
    return target.startsWith(prefix);
  }
  /**
   * Report tool/command/flag name collisions across extensions. The first
   * extension to register a name owns it; each later extension with the same
   * name yields one conflict entry pointing at the original owner.
   */
  private detectExtensionConflicts(extensions: Extension[]): Array<{ path: string; message: string }> {
    const conflicts: Array<{ path: string; message: string }> = [];
    // Track which extension registered each tool, command, and flag
    const toolOwners = new Map<string, string>();
    const commandOwners = new Map<string, string>();
    const flagOwners = new Map<string, string>();
    for (const ext of extensions) {
      // Check tools
      for (const toolName of ext.tools.keys()) {
        const existingOwner = toolOwners.get(toolName);
        if (existingOwner && existingOwner !== ext.path) {
          conflicts.push({
            path: ext.path,
            message: `Tool "${toolName}" conflicts with ${existingOwner}`,
          });
        } else {
          toolOwners.set(toolName, ext.path);
        }
      }
      // Check commands
      for (const commandName of ext.commands.keys()) {
        const existingOwner = commandOwners.get(commandName);
        if (existingOwner && existingOwner !== ext.path) {
          conflicts.push({
            path: ext.path,
            message: `Command "/${commandName}" conflicts with ${existingOwner}`,
          });
        } else {
          commandOwners.set(commandName, ext.path);
        }
      }
      // Check flags
      for (const flagName of ext.flags.keys()) {
        const existingOwner = flagOwners.get(flagName);
        if (existingOwner && existingOwner !== ext.path) {
          conflicts.push({
            path: ext.path,
            message: `Flag "--${flagName}" conflicts with ${existingOwner}`,
          });
        } else {
          flagOwners.set(flagName, ext.path);
        }
      }
    }
    return conflicts;
  }
}

View file

@ -0,0 +1,373 @@
import { join } from "node:path";
import { Agent, type AgentMessage, type ThinkingLevel } from "@gsd/pi-agent-core";
import type { Message, Model } from "@gsd/pi-ai";
import { getAgentDir, getDocsPath } from "../config.js";
import { AgentSession } from "./agent-session.js";
import { AuthStorage } from "./auth-storage.js";
import { DEFAULT_THINKING_LEVEL } from "./defaults.js";
import type { ExtensionRunner, LoadExtensionsResult, ToolDefinition } from "./extensions/index.js";
import { convertToLlm } from "./messages.js";
import { ModelRegistry } from "./model-registry.js";
import { findInitialModel } from "./model-resolver.js";
import type { ResourceLoader } from "./resource-loader.js";
import { DefaultResourceLoader } from "./resource-loader.js";
import { SessionManager } from "./session-manager.js";
import { SettingsManager } from "./settings-manager.js";
import { time } from "./timings.js";
import {
allTools,
bashTool,
codingTools,
createBashTool,
createCodingTools,
createEditTool,
createFindTool,
createGrepTool,
createLsTool,
createReadOnlyTools,
createReadTool,
createWriteTool,
editTool,
findTool,
grepTool,
lsTool,
readOnlyTools,
readTool,
type Tool,
type ToolName,
writeTool,
} from "./tools/index.js";
/** Options for {@link createAgentSession}. All fields are optional; sensible defaults are derived from cwd/agentDir. */
export interface CreateAgentSessionOptions {
  /** Working directory for project-local discovery. Default: process.cwd() */
  cwd?: string;
  /** Global config directory. Default: ~/.pi/agent */
  agentDir?: string;
  /** Auth storage for credentials. Default: AuthStorage.create(agentDir/auth.json) */
  authStorage?: AuthStorage;
  /** Model registry. Default: new ModelRegistry(authStorage, agentDir/models.json) */
  modelRegistry?: ModelRegistry;
  /** Model to use. Default: restored from the session if present, else from settings, else first available */
  model?: Model<any>;
  /** Thinking level. Default: from session/settings, else 'medium' (clamped to model capabilities) */
  thinkingLevel?: ThinkingLevel;
  /** Models available for cycling (Ctrl+P in interactive mode) */
  scopedModels?: Array<{ model: Model<any>; thinkingLevel?: ThinkingLevel }>;
  /** Built-in tools to use. Default: codingTools [read, bash, edit, write] */
  tools?: Tool[];
  /** Custom tools to register (in addition to built-in tools). */
  customTools?: ToolDefinition[];
  /** Resource loader. When omitted, DefaultResourceLoader is created and reloaded. */
  resourceLoader?: ResourceLoader;
  /** Session manager. Default: SessionManager.create(cwd). Existing entries are restored. */
  sessionManager?: SessionManager;
  /** Settings manager. Default: SettingsManager.create(cwd, agentDir) */
  settingsManager?: SettingsManager;
}
/** Result from createAgentSession */
export interface CreateAgentSessionResult {
  /** The created session */
  session: AgentSession;
  /** Extensions result (for UI context setup in interactive mode) */
  extensionsResult: LoadExtensionsResult;
  /** Warning if session was restored with a different model than saved, or if no model is available */
  modelFallbackMessage?: string;
}
// Re-exports
export type {
ExtensionAPI,
ExtensionCommandContext,
ExtensionContext,
ExtensionFactory,
SlashCommandInfo,
SlashCommandLocation,
SlashCommandSource,
ToolDefinition,
} from "./extensions/index.js";
export type { PromptTemplate } from "./prompt-templates.js";
export type { Skill } from "./skills.js";
export type { Tool } from "./tools/index.js";
export {
// Pre-built tools (use process.cwd())
readTool,
bashTool,
editTool,
writeTool,
grepTool,
findTool,
lsTool,
codingTools,
readOnlyTools,
allTools as allBuiltInTools,
// Tool factories (for custom cwd)
createCodingTools,
createReadOnlyTools,
createReadTool,
createBashTool,
createEditTool,
createWriteTool,
createGrepTool,
createFindTool,
createLsTool,
};
// Helper Functions
function getDefaultAgentDir(): string {
return getAgentDir();
}
/**
* Create an AgentSession with the specified options.
*
* @example
* ```typescript
* // Minimal - uses defaults
* const { session } = await createAgentSession();
*
* // With explicit model
* import { getModel } from '@gsd/pi-ai';
* const { session } = await createAgentSession({
* model: getModel('anthropic', 'claude-opus-4-5'),
* thinkingLevel: 'high',
* });
*
 * // Resume a previous session: pass a sessionManager that already holds
 * // entries — messages, model, and thinking level are restored from it
 * const { session, modelFallbackMessage } = await createAgentSession({
 *   sessionManager: SessionManager.create(process.cwd()),
 * });
 *
* // Full control
* const loader = new DefaultResourceLoader({
* cwd: process.cwd(),
* agentDir: getAgentDir(),
* settingsManager: SettingsManager.create(),
* });
* await loader.reload();
* const { session } = await createAgentSession({
* model: myModel,
* tools: [readTool, bashTool],
* resourceLoader: loader,
* sessionManager: SessionManager.inMemory(),
* });
* ```
*/
export async function createAgentSession(options: CreateAgentSessionOptions = {}): Promise<CreateAgentSessionResult> {
  const cwd = options.cwd ?? process.cwd();
  const agentDir = options.agentDir ?? getDefaultAgentDir();
  let resourceLoader = options.resourceLoader;
  // Use provided or create AuthStorage and ModelRegistry
  // (explicit paths only when the caller supplied agentDir; otherwise the
  // classes fall back to their own defaults)
  const authPath = options.agentDir ? join(agentDir, "auth.json") : undefined;
  const modelsPath = options.agentDir ? join(agentDir, "models.json") : undefined;
  const authStorage = options.authStorage ?? AuthStorage.create(authPath);
  const modelRegistry = options.modelRegistry ?? new ModelRegistry(authStorage, modelsPath);
  const settingsManager = options.settingsManager ?? SettingsManager.create(cwd, agentDir);
  const sessionManager = options.sessionManager ?? SessionManager.create(cwd);
  if (!resourceLoader) {
    resourceLoader = new DefaultResourceLoader({ cwd, agentDir, settingsManager });
    await resourceLoader.reload();
    time("resourceLoader.reload");
  }
  // Check if session has existing data to restore
  const existingSession = sessionManager.buildSessionContext();
  const hasExistingSession = existingSession.messages.length > 0;
  const hasThinkingEntry = sessionManager.getBranch().some((entry) => entry.type === "thinking_level_change");
  let model = options.model;
  let modelFallbackMessage: string | undefined;
  // If session has data, try to restore model from it
  // (restored model is only used when its API key is still resolvable)
  if (!model && hasExistingSession && existingSession.model) {
    const restoredModel = modelRegistry.find(existingSession.model.provider, existingSession.model.modelId);
    if (restoredModel && (await modelRegistry.getApiKey(restoredModel))) {
      model = restoredModel;
    }
    if (!model) {
      modelFallbackMessage = `Could not restore model ${existingSession.model.provider}/${existingSession.model.modelId}`;
    }
  }
  // If still no model, use findInitialModel (checks settings default, then provider defaults)
  if (!model) {
    const result = await findInitialModel({
      scopedModels: [],
      isContinuing: hasExistingSession,
      defaultProvider: settingsManager.getDefaultProvider(),
      defaultModelId: settingsManager.getDefaultModel(),
      defaultThinkingLevel: settingsManager.getDefaultThinkingLevel(),
      modelRegistry,
    });
    model = result.model;
    if (!model) {
      modelFallbackMessage = `No models available. Use /login or set an API key environment variable. See ${join(getDocsPath(), "providers.md")}. Then use /model to select a model.`;
    } else if (modelFallbackMessage) {
      // Append the fallback choice to the restore-failure message set above.
      modelFallbackMessage += `. Using ${model.provider}/${model.id}`;
    }
  }
  let thinkingLevel = options.thinkingLevel;
  // If session has data, restore thinking level from it
  if (thinkingLevel === undefined && hasExistingSession) {
    thinkingLevel = hasThinkingEntry
      ? (existingSession.thinkingLevel as ThinkingLevel)
      : (settingsManager.getDefaultThinkingLevel() ?? DEFAULT_THINKING_LEVEL);
  }
  // Fall back to settings default
  if (thinkingLevel === undefined) {
    thinkingLevel = settingsManager.getDefaultThinkingLevel() ?? DEFAULT_THINKING_LEVEL;
  }
  // Clamp to model capabilities
  if (!model || !model.reasoning) {
    thinkingLevel = "off";
  }
  // Active tool names: caller-supplied tools filtered to known built-ins,
  // else the default coding set.
  const defaultActiveToolNames: ToolName[] = ["read", "bash", "edit", "write"];
  const initialActiveToolNames: ToolName[] = options.tools
    ? options.tools.map((t) => t.name).filter((n): n is ToolName => n in allTools)
    : defaultActiveToolNames;
  let agent: Agent;
  // Create convertToLlm wrapper that filters images if blockImages is enabled (defense-in-depth)
  const convertToLlmWithBlockImages = (messages: AgentMessage[]): Message[] => {
    const converted = convertToLlm(messages);
    // Check setting dynamically so mid-session changes take effect
    if (!settingsManager.getBlockImages()) {
      return converted;
    }
    // Filter out ImageContent from all messages, replacing with text placeholder
    return converted.map((msg) => {
      if (msg.role === "user" || msg.role === "toolResult") {
        const content = msg.content;
        if (Array.isArray(content)) {
          const hasImages = content.some((c) => c.type === "image");
          if (hasImages) {
            const filteredContent = content
              .map((c) =>
                c.type === "image" ? { type: "text" as const, text: "Image reading is disabled." } : c,
              )
              .filter(
                (c, i, arr) =>
                  // Dedupe consecutive "Image reading is disabled." texts
                  !(
                    c.type === "text" &&
                    c.text === "Image reading is disabled." &&
                    i > 0 &&
                    arr[i - 1].type === "text" &&
                    (arr[i - 1] as { type: "text"; text: string }).text === "Image reading is disabled."
                  ),
              );
            return { ...msg, content: filteredContent };
          }
        }
      }
      return msg;
    });
  };
  // Mutable ref so Agent callbacks can reach the ExtensionRunner that
  // AgentSession installs after construction.
  const extensionRunnerRef: { current?: ExtensionRunner } = {};
  agent = new Agent({
    initialState: {
      systemPrompt: "",
      model,
      thinkingLevel,
      tools: [],
    },
    convertToLlm: convertToLlmWithBlockImages,
    onPayload: async (payload, _model) => {
      // Give extensions a chance to rewrite the provider request payload.
      const runner = extensionRunnerRef.current;
      if (!runner?.hasHandlers("before_provider_request")) {
        return payload;
      }
      return runner.emitBeforeProviderRequest(payload);
    },
    sessionId: sessionManager.getSessionId(),
    transformContext: async (messages) => {
      const runner = extensionRunnerRef.current;
      if (!runner) return messages;
      return runner.emitContext(messages);
    },
    steeringMode: settingsManager.getSteeringMode(),
    followUpMode: settingsManager.getFollowUpMode(),
    transport: settingsManager.getTransport(),
    thinkingBudgets: settingsManager.getThinkingBudgets(),
    maxRetryDelayMs: settingsManager.getRetrySettings().maxDelayMs,
    getApiKey: async (provider) => {
      // Use the provider argument from the in-flight request;
      // agent.state.model may already be switched mid-turn.
      const resolvedProvider = provider || agent.state.model?.provider;
      if (!resolvedProvider) {
        throw new Error("No model selected");
      }
      const key = await modelRegistry.getApiKeyForProvider(resolvedProvider);
      if (!key) {
        // Tailor the error to the auth mode so the user gets the right remedy.
        const model = agent.state.model;
        const isOAuth = model && modelRegistry.isUsingOAuth(model);
        if (isOAuth) {
          throw new Error(
            `Authentication failed for "${resolvedProvider}". ` +
              `Credentials may have expired or network is unavailable. ` +
              `Run '/login ${resolvedProvider}' to re-authenticate.`,
          );
        }
        throw new Error(
          `No API key found for "${resolvedProvider}". ` +
            `Set an API key environment variable or run '/login ${resolvedProvider}'.`,
        );
      }
      return key;
    },
  });
  // Restore messages if session has existing data
  if (hasExistingSession) {
    agent.replaceMessages(existingSession.messages);
    if (!hasThinkingEntry) {
      // Backfill a thinking-level entry for legacy sessions that lack one.
      sessionManager.appendThinkingLevelChange(thinkingLevel);
    }
  } else {
    // Save initial model and thinking level for new sessions so they can be restored on resume
    if (model) {
      sessionManager.appendModelChange(model.provider, model.id);
    }
    sessionManager.appendThinkingLevelChange(thinkingLevel);
  }
  const session = new AgentSession({
    agent,
    sessionManager,
    settingsManager,
    cwd,
    scopedModels: options.scopedModels,
    resourceLoader,
    customTools: options.customTools,
    modelRegistry,
    initialActiveToolNames,
    extensionRunnerRef,
  });
  const extensionsResult = resourceLoader.getExtensions();
  return {
    session,
    extensionsResult,
    modelFallbackMessage,
  };
}

Some files were not shown because too many files have changed in this diff Show more