Merge branch 'gsd-build:main' into main
This commit is contained in:
commit
2140a4de07
310 changed files with 89146 additions and 2652 deletions
30
.github/workflows/publish.yml
vendored
30
.github/workflows/publish.yml
vendored
|
|
@ -1,30 +0,0 @@
|
|||
name: Publish to npm
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- 'v*'
|
||||
|
||||
jobs:
|
||||
publish:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: read
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: '22'
|
||||
registry-url: 'https://registry.npmjs.org'
|
||||
|
||||
- name: Install dependencies
|
||||
run: npm ci
|
||||
|
||||
- name: Build
|
||||
run: npm run build
|
||||
|
||||
- name: Publish to npm
|
||||
run: npm publish --access public
|
||||
env:
|
||||
NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }}
|
||||
13
.gitignore
vendored
13
.gitignore
vendored
|
|
@ -28,6 +28,10 @@ coverage/
|
|||
.cache/
|
||||
tmp/
|
||||
|
||||
# ── Workspace packages ──
|
||||
packages/*/dist/
|
||||
packages/*/node_modules/
|
||||
|
||||
# ── GSD baseline (auto-generated) ──
|
||||
dist/
|
||||
.bg_shell
|
||||
|
|
@ -36,6 +40,15 @@ dist/
|
|||
AGENTS.md
|
||||
.bg-shell/
|
||||
TODOS.md
|
||||
.planning/
|
||||
|
||||
# ── GSD baseline (auto-generated) ──
|
||||
.gsd/
|
||||
|
||||
# ── GSD baseline (auto-generated) ──
|
||||
.gsd/activity/
|
||||
.gsd/runtime/
|
||||
.gsd/worktrees/
|
||||
.gsd/auto.lock
|
||||
.gsd/metrics.json
|
||||
.gsd/STATE.md
|
||||
|
|
|
|||
117
CHANGELOG.md
117
CHANGELOG.md
|
|
@ -6,6 +6,113 @@ Format based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
|
|||
|
||||
## [Unreleased]
|
||||
|
||||
## [2.7.0] - 2026-03-12
|
||||
|
||||
### Changed
|
||||
- Vendor Pi SDK source (tui, ai, agent-core, coding-agent) into workspace monorepo under `packages/`, replacing the compiled npm dependency and patch-package workflow. Pi internals are now directly modifiable as TypeScript source.
|
||||
- Existing patches (setModel persist option, Windows VT input caching) applied as source edits.
|
||||
- Build pipeline runs workspace packages in dependency order before GSD compilation.
|
||||
- Removed `patch-package` from devDependencies and postinstall.
|
||||
|
||||
## [2.6.0] - 2026-03-12
|
||||
|
||||
### Added
|
||||
- Proactive secret management — planning phase forecasts required API keys into a manifest; auto-mode collects pending secrets before dispatching the first slice
|
||||
- `--continue` / `-c` CLI flag to resume the most recent session
|
||||
|
||||
### Fixed
|
||||
- Doctor post-hook no longer preempts `complete-slice` dispatch
|
||||
- `main_branch` preference restored; `runPreMergeCheck` implemented for merge safety
|
||||
- Recovery/retry prompt injection capped to prevent V8 OOM on large sessions
|
||||
- `.gsd/` excluded from pre-switch auto-commits to prevent squash merge conflicts
|
||||
|
||||
## [2.5.1] - 2026-03-12
|
||||
|
||||
### Added
|
||||
- `secure_env_collect` now auto-detects existing keys, destination files, and provides guidance field for better onboarding UX
|
||||
|
||||
### Changed
|
||||
- Right-sized pipeline for simple work — single-slice milestones skip redundant research/plan sessions, reducing 9-10 sessions to 5-6
|
||||
- Heavyweight plan sections (Proof Level, Integration Closure, Observability) are now conditional, omitted for simple slices
|
||||
|
||||
### Fixed
|
||||
- Squash-merge now aborts cleanly on conflict and stops auto-mode instead of looping with corrupted state
|
||||
- Resolved baked-in merge conflict markers in loader.ts, logo.ts, and postinstall.js
|
||||
|
||||
## [2.5.0] - 2026-03-12
|
||||
|
||||
### Added
|
||||
- Native Anthropic web search — Claude models get server-side web search automatically, no Brave API key required
|
||||
- GitService fully wired into codebase — programmatic git operations replace shell-based git commands in prompts
|
||||
- Merge guards prevent slice completion when uncommitted changes or conflicts exist
|
||||
- Snapshot support for saving and restoring `.gsd/` state
|
||||
- Auto-push after slice squash-merge to main
|
||||
- Rich commit messages with structured metadata
|
||||
|
||||
### Fixed
|
||||
- State machine deadlock when units fail to produce expected artifacts — retry and cross-validation now gate completion
|
||||
- Duplicate Brave search tools when toggling providers repeatedly
|
||||
- Windows test glob patterns (single quotes → unquoted for shell expansion)
|
||||
- Conversation replay error caused by thinking blocks in stored history
|
||||
- Brave search tools removed from API payload when no `BRAVE_API_KEY` is set
|
||||
- Restore notifications suppressed on session resume to reduce UX noise
|
||||
|
||||
## [2.4.0] - 2026-03-12
|
||||
|
||||
### Added
|
||||
- Automatic migration of provider credentials from existing Pi installations — skip re-authentication when switching to GSD
|
||||
- Pi extensions from `~/.pi/agent/extensions/` discoverable in interactive mode
|
||||
- GitService core implementation for programmatic git operations
|
||||
|
||||
### Changed
|
||||
- System prompt compressed by 48% (360 → 187 lines) for better context efficiency
|
||||
- Refined agent character and communication style prompts
|
||||
- Added craft standards, self-debugging awareness, and work narration to agent prompts
|
||||
|
||||
### Fixed
|
||||
- RPC mode crash when `ctx.ui.theme` is undefined (#121)
|
||||
|
||||
## [2.3.11] - 2026-03-12
|
||||
|
||||
### Added
|
||||
- Branded clack-based onboarding wizard on first launch — LLM provider selection (OAuth + API key), optional tool API keys, and setup summary (#118)
|
||||
- `gsd config` subcommand to re-run the setup wizard anytime
|
||||
- Shared `src/logo.ts` module as single source of truth for ASCII banner
|
||||
|
||||
### Fixed
|
||||
- Parallel subagent results no longer truncated at 200 characters
|
||||
|
||||
### Changed
|
||||
- `wizard.ts` trimmed to env hydration only — onboarding logic moved to `onboarding.ts`
|
||||
- First-launch banner removed from `loader.ts` (onboarding wizard handles branding)
|
||||
|
||||
## [2.3.10] - 2026-03-12
|
||||
|
||||
### Added
|
||||
- Branded postinstall experience with animated spinners, progress indicators, and clean summary (#115)
|
||||
|
||||
### Fixed
|
||||
- Ctrl+Alt shortcuts (dashboard, bg manager, voice) now show slash-command fallback in terminals that lack Kitty keyboard protocol support — macOS Terminal.app, JetBrains IDEs (#100, #104)
|
||||
|
||||
## [2.3.9] - 2026-03-12
|
||||
|
||||
### Added
|
||||
- Tavily as alternative web search provider alongside Brave Search (#102)
|
||||
- Auto-mode progress widget now shows all stats; footer hidden during auto-mode (#75)
|
||||
|
||||
### Fixed
|
||||
- Auto-mode infinite loop and closeout instability — idempotent unit dispatch, retry caps, and atomic closeout (#96, #109)
|
||||
- Migration no longer requires ROADMAP.md — milestones inferred from phases/ directory when missing (#93, #90)
|
||||
- Worktree branch safety — proper namespacing and slice branch base selection (#92)
|
||||
- Windows: use `execFile` to avoid single-quote shell issues (#103)
|
||||
- Broken `read @GSD-WORKFLOW.md` references replaced with `/gsd` command (#88)
|
||||
- Google Search extension updated to use `gemini-2.5-flash` (#83)
|
||||
- Duplicate `getCurrentBranch` import in auto.ts (#87)
|
||||
- `formatCost` crash on non-number cost values (#74)
|
||||
- Avoid `sudo` prompts in postinstall script (#73)
|
||||
- `.gsd/` folder removed from git tracking; consolidated `.gitignore` (#78)
|
||||
- Multiple community-reported bugs across CLI, auto-mode, and extensions
|
||||
|
||||
## [2.3.8] - 2026-03-11
|
||||
|
||||
### Fixed
|
||||
|
|
@ -156,7 +263,15 @@ Format based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
|
|||
### Changed
|
||||
- License updated to MIT
|
||||
|
||||
[Unreleased]: https://github.com/gsd-build/gsd-2/compare/v2.3.8...HEAD
|
||||
[Unreleased]: https://github.com/gsd-build/gsd-2/compare/v2.7.0...HEAD
|
||||
[2.7.0]: https://github.com/gsd-build/gsd-2/compare/v2.6.0...v2.7.0
|
||||
[2.6.0]: https://github.com/gsd-build/gsd-2/compare/v2.5.1...v2.6.0
|
||||
[2.5.1]: https://github.com/gsd-build/gsd-2/compare/v2.5.0...v2.5.1
|
||||
[2.5.0]: https://github.com/gsd-build/gsd-2/compare/v2.4.0...v2.5.0
|
||||
[2.4.0]: https://github.com/gsd-build/gsd-2/compare/v2.3.11...v2.4.0
|
||||
[2.3.11]: https://github.com/gsd-build/gsd-2/compare/v2.3.10...v2.3.11
|
||||
[2.3.10]: https://github.com/gsd-build/gsd-2/compare/v2.3.9...v2.3.10
|
||||
[2.3.9]: https://github.com/gsd-build/gsd-2/compare/v2.3.8...v2.3.9
|
||||
[2.3.8]: https://github.com/gsd-build/gsd-2/compare/v2.3.7...v2.3.8
|
||||
[2.3.7]: https://github.com/gsd-build/gsd-2/compare/v2.3.6...v2.3.7
|
||||
[2.3.6]: https://github.com/gsd-build/gsd-2/compare/v2.3.5...v2.3.6
|
||||
|
|
|
|||
213
ISSUE-120-INVESTIGATION.md
Normal file
213
ISSUE-120-INVESTIGATION.md
Normal file
|
|
@ -0,0 +1,213 @@
|
|||
# Issue #120 — GSD Auto Secret Collection Improvements
|
||||
|
||||
## Problem Statement
|
||||
|
||||
Users report three failures in auto-mode secret handling:
|
||||
|
||||
1. **Late discovery** — Secrets aren't gathered until well into execution (e.g., first slice), blocking progress for hours while the user is away
|
||||
2. **Re-asking across slices** — The same secrets are requested again at the start of later slices
|
||||
3. **Re-asking within slices** — The same secrets are requested again mid-slice
|
||||
|
||||
All three stem from the same architectural gap: GSD has no proactive secret identification, and no reliable persistence of project-specific secrets across fresh sessions.
|
||||
|
||||
---
|
||||
|
||||
## Current Architecture
|
||||
|
||||
### Secret Collection Tool
|
||||
|
||||
`src/resources/extensions/get-secrets-from-user.ts`
|
||||
|
||||
- `secure_env_collect` — paged, masked-input TUI for collecting env vars
|
||||
- Writes to three destinations: `.env` (local), Vercel (`vercel env add`), Convex (`npx convex env set`)
|
||||
- Values are masked in UI and never echoed in tool output
|
||||
- Well-built tool — the problem isn't collection UX, it's when and how often collection happens
|
||||
|
||||
### Secret Persistence (GSD-owned keys only)
|
||||
|
||||
`src/wizard.ts` — `loadStoredEnvKeys()`
|
||||
|
||||
Runs at CLI startup. Loads a hardcoded list of GSD's own keys from `~/.gsd/agent/auth.json` into `process.env`:
|
||||
|
||||
- `BRAVE_API_KEY`, `BRAVE_ANSWERS_KEY`
|
||||
- `CONTEXT7_API_KEY`, `JINA_API_KEY`, `TAVILY_API_KEY`
|
||||
- `SLACK_BOT_TOKEN`, `DISCORD_BOT_TOKEN`
|
||||
|
||||
**Project-specific secrets** (GitHub tokens, database URLs, OpenAI keys, etc.) collected via `secure_env_collect` to `.env` are NOT loaded by this mechanism.
|
||||
|
||||
### Fresh Session Model
|
||||
|
||||
`src/resources/extensions/gsd/auto.ts`
|
||||
|
||||
Each unit of work (plan slice, execute task, complete slice) gets a fresh session via `ctx.newSession()`. This means:
|
||||
|
||||
- Clean context window
|
||||
- State rebuilt from `.gsd/` artifacts on disk
|
||||
- No memory of what happened in the previous session
|
||||
- `process.env` does not include project `.env` contents unless something explicitly loads them
|
||||
|
||||
### Prompt Guidance
|
||||
|
||||
| File | What it says about secrets |
|
||||
|------|--------------------------|
|
||||
| `system.md:26-27` | Never log secrets; use `secure_env_collect` instead of manual `.env` editing |
|
||||
| `system.md:131` | Routes "Secrets" to `secure_env_collect` |
|
||||
| `system.md:197` | After applying secrets, rerun the blocked workflow |
|
||||
| `execute-task.md:30` | Never log secrets/tokens unnecessarily |
|
||||
| `secure_env_collect` promptGuidelines | Proactively call before first command needing secrets; call when commands fail due to missing env vars |
|
||||
|
||||
All guidance is **reactive** — "when you hit an error, collect the secret." Nothing says "identify all secrets upfront before execution begins."
|
||||
|
||||
### What's Missing
|
||||
|
||||
| Gap | Impact |
|
||||
|-----|--------|
|
||||
| No secret identification during research/planning | Secrets discovered reactively during execution, often hours in |
|
||||
| No `.env` loading across fresh sessions | Previously-collected project secrets invisible to new sessions |
|
||||
| No "secrets already collected" carry-forward | Agent in fresh session doesn't know what was already gathered |
|
||||
| No `Required Credentials` section in requirements | No structured place to track what the project needs |
|
||||
| No deduplication or "already have this" check | Agent re-asks for secrets it already wrote to `.env` |
|
||||
|
||||
---
|
||||
|
||||
## Root Cause Analysis
|
||||
|
||||
### Problem 1: Late Discovery
|
||||
|
||||
The research phase (`research-milestone.md`) focuses on codebase exploration, technology assessment, and strategic questions. The planning phase (`plan-milestone.md`, `plan-slice.md`) focuses on task decomposition and verification. Neither phase includes a step to identify required credentials.
|
||||
|
||||
The `secure_env_collect` promptGuidelines say "when starting a new project or running setup steps that require secrets, proactively call secure_env_collect before the first command that needs them" — but this fires during task execution, not during planning. By then, the user may be asleep.
|
||||
|
||||
### Problem 2: Re-asking Across Slices
|
||||
|
||||
When `secure_env_collect` writes a secret to `.env`, that file persists on disk. But when auto-mode spawns a fresh session for the next slice, the new session's `process.env` doesn't include the `.env` contents. The agent in the new session encounters the same "missing env var" error and calls `secure_env_collect` again.
|
||||
|
||||
The `loadStoredEnvKeys()` function only loads GSD's own keys from AuthStorage, not project-specific keys from `.env`.
|
||||
|
||||
### Problem 3: Re-asking Within Slices
|
||||
|
||||
Within a single session, if `secure_env_collect` writes to `.env` but the calling code reads from `process.env` (not the file), the secret appears missing. Additionally, if a task uses a tool that checks `process.env` independently, it won't see the `.env` file contents unless something loads them.
|
||||
|
||||
---
|
||||
|
||||
## Proposed Solutions
|
||||
|
||||
### Solution 1: Proactive Secret Identification During Planning
|
||||
|
||||
**Where**: `src/resources/extensions/gsd/prompts/plan-milestone.md`
|
||||
|
||||
Add a step after research is consumed and before slice decomposition:
|
||||
|
||||
> Identify all secrets, API keys, tokens, credentials, and external service configurations this milestone will require. Consider:
|
||||
> - APIs being integrated (keys, tokens, OAuth credentials)
|
||||
> - Databases (connection strings, passwords)
|
||||
> - Third-party services (webhook secrets, API keys)
|
||||
> - Deployment targets (platform tokens)
|
||||
>
|
||||
> If any secrets are needed, call `secure_env_collect` now to gather them before execution begins. This prevents blocking during unattended execution.
|
||||
|
||||
**Also update**: `src/resources/extensions/gsd/templates/requirements.md` — add a `## Required Credentials` section:
|
||||
|
||||
```markdown
|
||||
## Required Credentials
|
||||
|
||||
| Key | Purpose | Source | Status |
|
||||
|-----|---------|--------|--------|
|
||||
| GITHUB_TOKEN | GitHub API access | User | collected |
|
||||
| DATABASE_URL | PostgreSQL connection | User | pending |
|
||||
```
|
||||
|
||||
### Solution 2: Load Project `.env` on Fresh Session Start
|
||||
|
||||
**Where**: `src/resources/extensions/gsd/auto.ts` — before spawning each fresh session
|
||||
|
||||
Before `ctx.newSession()`, read the project's `.env` file and inject its contents into the session's environment. This ensures previously-collected secrets carry forward without re-asking.
|
||||
|
||||
Implementation approach:
|
||||
|
||||
```typescript
|
||||
import { readFile } from "node:fs/promises";
|
||||
import { resolve } from "node:path";
|
||||
|
||||
async function loadProjectEnv(cwd: string): Promise<void> {
|
||||
try {
|
||||
const envPath = resolve(cwd, ".env");
|
||||
const content = await readFile(envPath, "utf8");
|
||||
for (const line of content.split("\n")) {
|
||||
const trimmed = line.trim();
|
||||
if (!trimmed || trimmed.startsWith("#")) continue;
|
||||
const eqIndex = trimmed.indexOf("=");
|
||||
if (eqIndex === -1) continue;
|
||||
const key = trimmed.slice(0, eqIndex).trim();
|
||||
const value = trimmed.slice(eqIndex + 1).trim();
|
||||
// Don't override explicitly-set env vars
|
||||
if (!process.env[key]) {
|
||||
process.env[key] = value;
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// No .env file — that's fine
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Call this before each fresh session spawn in auto-mode.
|
||||
|
||||
**Alternative**: Persist project secrets to AuthStorage alongside GSD's own keys, so `loadStoredEnvKeys()` picks them up. This is cleaner but requires changes to `secure_env_collect` to write to both `.env` and AuthStorage.
|
||||
|
||||
### Solution 3: Carry-Forward Context for Collected Secrets
|
||||
|
||||
**Where**: `src/resources/extensions/gsd/auto.ts` — in the context/prompt assembly for fresh sessions
|
||||
|
||||
Add a section to the injected prompt that lists secrets already collected:
|
||||
|
||||
> ## Previously Collected Secrets
|
||||
> The following env vars have already been collected and are available in `.env`:
|
||||
> - `GITHUB_TOKEN` ✓
|
||||
> - `DATABASE_URL` ✓
|
||||
>
|
||||
> Do NOT re-ask the user for these. If a command fails due to a missing env var not on this list, use `secure_env_collect`.
|
||||
|
||||
This requires scanning `.env` for key names (not values) and including them in the carry-forward context.
|
||||
|
||||
### Solution 4: Update Execute-Task Prompt
|
||||
|
||||
**Where**: `src/resources/extensions/gsd/prompts/execute-task.md`
|
||||
|
||||
Add an early step:
|
||||
|
||||
> Before starting work, check if the task requires env vars or secrets. If so, verify they exist in `.env` or `process.env`. If missing, call `secure_env_collect` immediately rather than discovering the need mid-task.
|
||||
|
||||
---
|
||||
|
||||
## Implementation Priority
|
||||
|
||||
| Priority | Solution | Effort | Impact |
|
||||
|----------|----------|--------|--------|
|
||||
| 1 | Solution 2: Load `.env` on fresh session start | Small | Eliminates re-asking (Problems 2 & 3) |
|
||||
| 2 | Solution 3: Carry-forward collected secret names | Small | Prevents agent confusion about what's available |
|
||||
| 3 | Solution 1: Proactive identification during planning | Medium | Eliminates late discovery (Problem 1) |
|
||||
| 4 | Solution 4: Execute-task prompt update | Small | Defense-in-depth for Problem 1 |
|
||||
|
||||
Solutions 1-3 together fully address the issue. Solution 4 is defense-in-depth.
|
||||
|
||||
---
|
||||
|
||||
## Files to Modify
|
||||
|
||||
| File | Change |
|
||||
|------|--------|
|
||||
| `src/resources/extensions/gsd/auto.ts` | Load `.env` before fresh sessions; include collected secret names in carry-forward context |
|
||||
| `src/resources/extensions/gsd/prompts/plan-milestone.md` | Add proactive secret identification step |
|
||||
| `src/resources/extensions/gsd/prompts/execute-task.md` | Add early secret verification step |
|
||||
| `src/resources/extensions/gsd/templates/requirements.md` | Add Required Credentials section |
|
||||
| `src/resources/extensions/get-secrets-from-user.ts` | (Optional) Dual-write to AuthStorage for cross-project persistence |
|
||||
|
||||
---
|
||||
|
||||
## Edge Cases to Consider
|
||||
|
||||
- **Non-dotenv destinations**: If secrets were sent to Vercel or Convex, the `.env` loading approach won't help. May need to track "collected secrets" in a `.gsd/secrets-manifest.json` (key names only, no values).
|
||||
- **Multiple `.env` files**: Some projects use `.env.local`, `.env.development`, etc. The loader should check common variants.
|
||||
- **Secrets that change**: If a user needs to rotate a key, the "don't re-ask" logic should have an escape hatch.
|
||||
- **Workspace vs global secrets**: Some secrets (like `GITHUB_TOKEN`) are user-global; others (like `DATABASE_URL`) are project-specific. Consider whether global secrets should go to AuthStorage while project secrets stay in `.env`.
|
||||
19
README.md
19
README.md
|
|
@ -48,6 +48,8 @@ GSD v2 solves all of these because it's not a prompt framework anymore — it's
|
|||
|
||||
### Migrating from v1
|
||||
|
||||
> **Note:** Migration works best with a `ROADMAP.md` file for milestone structure. Without one, milestones are inferred from the `phases/` directory.
|
||||
|
||||
If you have projects with `.planning` directories from the original Get Shit Done, you can migrate them to GSD-2's `.gsd` format:
|
||||
|
||||
```bash
|
||||
|
|
@ -198,7 +200,7 @@ Both terminals read and write the same `.gsd/` files on disk. Your decisions in
|
|||
|
||||
### First launch
|
||||
|
||||
On first run, GSD prompts for optional API keys (Brave Search, Google Gemini, Context7, Jina) for web research and documentation tools. All optional — press Enter to skip any.
|
||||
On first run, GSD launches a branded setup wizard that walks you through LLM provider selection (OAuth or API key), then optional tool API keys (Brave Search, Context7, Jina, Slack, Discord). Every step is skippable — press Enter to skip any. If you have an existing Pi installation, your provider credentials (LLM and tool keys) are imported automatically. Run `gsd config` anytime to re-run the wizard.
|
||||
|
||||
### Commands
|
||||
|
||||
|
|
@ -221,6 +223,8 @@ On first run, GSD prompts for optional API keys (Brave Search, Google Gemini, Co
|
|||
| `Ctrl+Alt+G` | Toggle dashboard overlay |
|
||||
| `Ctrl+Alt+V` | Toggle voice transcription |
|
||||
| `Ctrl+Alt+B` | Show background shell processes |
|
||||
| `gsd config` | Re-run the setup wizard (LLM provider + tool keys) |
|
||||
| `gsd --continue` (`-c`) | Resume the most recent session for the current directory |
|
||||
|
||||
---
|
||||
|
||||
|
|
@ -249,17 +253,18 @@ Branch-per-slice with squash merge. Fully automated.
|
|||
|
||||
```
|
||||
main:
|
||||
feat(M001/S03): auth and session management
|
||||
docs(M001/S04): workflow documentation and examples
|
||||
fix(M001/S03): bug fixes and doc corrections
|
||||
feat(M001/S02): API endpoints and middleware
|
||||
feat(M001/S01): data model and type system
|
||||
|
||||
gsd/M001/S01 (preserved):
|
||||
gsd/M001/S01 (deleted after merge):
|
||||
feat(S01/T03): file writer with round-trip fidelity
|
||||
feat(S01/T02): markdown parser for plan files
|
||||
feat(S01/T01): core types and interfaces
|
||||
```
|
||||
|
||||
One commit per slice on main. Per-task history preserved on branches. Git bisect works. Individual slices are revertable.
|
||||
One commit per slice on main. Squash commits are the permanent record — branches are deleted after merge. Git bisect works. Individual slices are revertable.
|
||||
|
||||
### Verification
|
||||
|
||||
|
|
@ -326,7 +331,7 @@ GSD ships with 13 extensions, all loaded automatically:
|
|||
|-----------|-----------------|
|
||||
| **GSD** | Core workflow engine, auto mode, commands, dashboard |
|
||||
| **Browser Tools** | Playwright-based browser for UI verification |
|
||||
| **Search the Web** | Brave Search + Jina page extraction |
|
||||
| **Search the Web** | Brave Search, Tavily, or Jina page extraction |
|
||||
| **Google Search** | Gemini-powered web search with AI-synthesized answers |
|
||||
| **Context7** | Up-to-date library/framework documentation |
|
||||
| **Background Shell** | Long-running process management with readiness detection |
|
||||
|
|
@ -358,7 +363,8 @@ GSD is a TypeScript application that embeds the Pi coding agent SDK.
|
|||
gsd (CLI binary)
|
||||
└─ loader.ts Sets PI_PACKAGE_DIR, GSD env vars, dynamic-imports cli.ts
|
||||
└─ cli.ts Wires SDK managers, loads extensions, starts InteractiveMode
|
||||
├─ wizard.ts First-run API key collection (Brave/Gemini/Context7/Jina)
|
||||
├─ onboarding.ts First-run setup wizard (LLM provider + tool keys)
|
||||
├─ wizard.ts Env hydration from stored auth.json credentials
|
||||
├─ app-paths.ts ~/.gsd/agent/, ~/.gsd/sessions/, auth.json
|
||||
├─ resource-loader.ts Syncs bundled extensions + agents to ~/.gsd/agent/
|
||||
└─ src/resources/
|
||||
|
|
@ -386,6 +392,7 @@ gsd (CLI binary)
|
|||
|
||||
Optional:
|
||||
- Brave Search API key (web research)
|
||||
- Tavily API key (web research — alternative to Brave)
|
||||
- Google Gemini API key (web research via Gemini Search grounding)
|
||||
- Context7 API key (library docs)
|
||||
- Jina API key (page extraction)
|
||||
|
|
|
|||
1435
package-lock.json
generated
1435
package-lock.json
generated
File diff suppressed because it is too large
Load diff
22
package.json
22
package.json
|
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"name": "gsd-pi",
|
||||
"version": "2.3.8",
|
||||
"version": "2.7.0",
|
||||
"description": "GSD — Get Shit Done coding agent",
|
||||
"license": "MIT",
|
||||
"repository": {
|
||||
|
|
@ -12,13 +12,16 @@
|
|||
"url": "https://github.com/glittercowboy/gsd-pi/issues"
|
||||
},
|
||||
"type": "module",
|
||||
"workspaces": [
|
||||
"packages/*"
|
||||
],
|
||||
"bin": {
|
||||
"gsd": "dist/loader.js",
|
||||
"gsd-cli": "dist/loader.js"
|
||||
},
|
||||
"files": [
|
||||
"dist",
|
||||
"patches",
|
||||
"packages",
|
||||
"pkg",
|
||||
"src/resources",
|
||||
"scripts/postinstall.js",
|
||||
|
|
@ -33,9 +36,14 @@
|
|||
"node": ">=20.6.0"
|
||||
},
|
||||
"scripts": {
|
||||
"build": "tsc && npm run copy-themes",
|
||||
"copy-themes": "node -e \"const{mkdirSync,cpSync}=require('fs');const{resolve}=require('path');const src=resolve(__dirname,'node_modules/@mariozechner/pi-coding-agent/dist/modes/interactive/theme');mkdirSync('pkg/dist/modes/interactive/theme',{recursive:true});cpSync(src,'pkg/dist/modes/interactive/theme',{recursive:true})\"",
|
||||
"test": "node --import ./src/resources/extensions/gsd/tests/resolve-ts.mjs --experimental-strip-types --test 'src/resources/extensions/gsd/tests/*.test.ts' 'src/resources/extensions/gsd/tests/*.test.mjs' 'src/tests/*.test.ts'",
|
||||
"build:pi-tui": "npm run build -w @gsd/pi-tui",
|
||||
"build:pi-ai": "npm run build -w @gsd/pi-ai",
|
||||
"build:pi-agent-core": "npm run build -w @gsd/pi-agent-core",
|
||||
"build:pi-coding-agent": "npm run build -w @gsd/pi-coding-agent",
|
||||
"build:pi": "npm run build:pi-tui && npm run build:pi-ai && npm run build:pi-agent-core && npm run build:pi-coding-agent",
|
||||
"build": "npm run build:pi && tsc && npm run copy-themes",
|
||||
"copy-themes": "node -e \"const{mkdirSync,cpSync}=require('fs');const{resolve}=require('path');const src=resolve(__dirname,'packages/pi-coding-agent/dist/modes/interactive/theme');mkdirSync('pkg/dist/modes/interactive/theme',{recursive:true});cpSync(src,'pkg/dist/modes/interactive/theme',{recursive:true})\"",
|
||||
"test": "node --import ./src/resources/extensions/gsd/tests/resolve-ts.mjs --experimental-strip-types --test src/resources/extensions/gsd/tests/*.test.ts src/resources/extensions/gsd/tests/*.test.mjs src/tests/*.test.ts",
|
||||
"dev": "tsc --watch",
|
||||
"postinstall": "node scripts/postinstall.js",
|
||||
"pi:install-global": "node scripts/install-pi-global.js",
|
||||
|
|
@ -44,12 +52,12 @@
|
|||
"prepublishOnly": "npm run sync-pkg-version && npm run build"
|
||||
},
|
||||
"dependencies": {
|
||||
"@mariozechner/pi-coding-agent": "^0.57.1",
|
||||
"@clack/prompts": "^1.1.0",
|
||||
"picocolors": "^1.1.1",
|
||||
"playwright": "^1.58.2"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/node": "^22.0.0",
|
||||
"patch-package": "^8.0.1",
|
||||
"typescript": "^5.4.0"
|
||||
},
|
||||
"overrides": {
|
||||
|
|
|
|||
14
packages/pi-agent-core/package.json
Normal file
14
packages/pi-agent-core/package.json
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
{
|
||||
"name": "@gsd/pi-agent-core",
|
||||
"version": "0.57.1",
|
||||
"description": "General-purpose agent core (vendored from pi-mono)",
|
||||
"type": "module",
|
||||
"main": "./dist/index.js",
|
||||
"types": "./dist/index.d.ts",
|
||||
"scripts": {
|
||||
"build": "tsc -p tsconfig.json"
|
||||
},
|
||||
"dependencies": {
|
||||
"@gsd/pi-ai": "*"
|
||||
}
|
||||
}
|
||||
417
packages/pi-agent-core/src/agent-loop.ts
Normal file
417
packages/pi-agent-core/src/agent-loop.ts
Normal file
|
|
@ -0,0 +1,417 @@
|
|||
/**
|
||||
* Agent loop that works with AgentMessage throughout.
|
||||
* Transforms to Message[] only at the LLM call boundary.
|
||||
*/
|
||||
|
||||
import {
|
||||
type AssistantMessage,
|
||||
type Context,
|
||||
EventStream,
|
||||
streamSimple,
|
||||
type ToolResultMessage,
|
||||
validateToolArguments,
|
||||
} from "@gsd/pi-ai";
|
||||
import type {
|
||||
AgentContext,
|
||||
AgentEvent,
|
||||
AgentLoopConfig,
|
||||
AgentMessage,
|
||||
AgentTool,
|
||||
AgentToolResult,
|
||||
StreamFn,
|
||||
} from "./types.js";
|
||||
|
||||
/**
|
||||
* Start an agent loop with a new prompt message.
|
||||
* The prompt is added to the context and events are emitted for it.
|
||||
*/
|
||||
export function agentLoop(
|
||||
prompts: AgentMessage[],
|
||||
context: AgentContext,
|
||||
config: AgentLoopConfig,
|
||||
signal?: AbortSignal,
|
||||
streamFn?: StreamFn,
|
||||
): EventStream<AgentEvent, AgentMessage[]> {
|
||||
const stream = createAgentStream();
|
||||
|
||||
(async () => {
|
||||
const newMessages: AgentMessage[] = [...prompts];
|
||||
const currentContext: AgentContext = {
|
||||
...context,
|
||||
messages: [...context.messages, ...prompts],
|
||||
};
|
||||
|
||||
stream.push({ type: "agent_start" });
|
||||
stream.push({ type: "turn_start" });
|
||||
for (const prompt of prompts) {
|
||||
stream.push({ type: "message_start", message: prompt });
|
||||
stream.push({ type: "message_end", message: prompt });
|
||||
}
|
||||
|
||||
await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
|
||||
})();
|
||||
|
||||
return stream;
|
||||
}
|
||||
|
||||
/**
|
||||
* Continue an agent loop from the current context without adding a new message.
|
||||
* Used for retries - context already has user message or tool results.
|
||||
*
|
||||
* **Important:** The last message in context must convert to a `user` or `toolResult` message
|
||||
* via `convertToLlm`. If it doesn't, the LLM provider will reject the request.
|
||||
* This cannot be validated here since `convertToLlm` is only called once per turn.
|
||||
*/
|
||||
export function agentLoopContinue(
|
||||
context: AgentContext,
|
||||
config: AgentLoopConfig,
|
||||
signal?: AbortSignal,
|
||||
streamFn?: StreamFn,
|
||||
): EventStream<AgentEvent, AgentMessage[]> {
|
||||
if (context.messages.length === 0) {
|
||||
throw new Error("Cannot continue: no messages in context");
|
||||
}
|
||||
|
||||
if (context.messages[context.messages.length - 1].role === "assistant") {
|
||||
throw new Error("Cannot continue from message role: assistant");
|
||||
}
|
||||
|
||||
const stream = createAgentStream();
|
||||
|
||||
(async () => {
|
||||
const newMessages: AgentMessage[] = [];
|
||||
const currentContext: AgentContext = { ...context };
|
||||
|
||||
stream.push({ type: "agent_start" });
|
||||
stream.push({ type: "turn_start" });
|
||||
|
||||
await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
|
||||
})();
|
||||
|
||||
return stream;
|
||||
}
|
||||
|
||||
function createAgentStream(): EventStream<AgentEvent, AgentMessage[]> {
|
||||
return new EventStream<AgentEvent, AgentMessage[]>(
|
||||
(event: AgentEvent) => event.type === "agent_end",
|
||||
(event: AgentEvent) => (event.type === "agent_end" ? event.messages : []),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Main loop logic shared by agentLoop and agentLoopContinue.
|
||||
*/
|
||||
async function runLoop(
|
||||
currentContext: AgentContext,
|
||||
newMessages: AgentMessage[],
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
streamFn?: StreamFn,
|
||||
): Promise<void> {
|
||||
let firstTurn = true;
|
||||
// Check for steering messages at start (user may have typed while waiting)
|
||||
let pendingMessages: AgentMessage[] = (await config.getSteeringMessages?.()) || [];
|
||||
|
||||
// Outer loop: continues when queued follow-up messages arrive after agent would stop
|
||||
while (true) {
|
||||
let hasMoreToolCalls = true;
|
||||
let steeringAfterTools: AgentMessage[] | null = null;
|
||||
|
||||
// Inner loop: process tool calls and steering messages
|
||||
while (hasMoreToolCalls || pendingMessages.length > 0) {
|
||||
if (!firstTurn) {
|
||||
stream.push({ type: "turn_start" });
|
||||
} else {
|
||||
firstTurn = false;
|
||||
}
|
||||
|
||||
// Process pending messages (inject before next assistant response)
|
||||
if (pendingMessages.length > 0) {
|
||||
for (const message of pendingMessages) {
|
||||
stream.push({ type: "message_start", message });
|
||||
stream.push({ type: "message_end", message });
|
||||
currentContext.messages.push(message);
|
||||
newMessages.push(message);
|
||||
}
|
||||
pendingMessages = [];
|
||||
}
|
||||
|
||||
// Stream assistant response
|
||||
const message = await streamAssistantResponse(currentContext, config, signal, stream, streamFn);
|
||||
newMessages.push(message);
|
||||
|
||||
if (message.stopReason === "error" || message.stopReason === "aborted") {
|
||||
stream.push({ type: "turn_end", message, toolResults: [] });
|
||||
stream.push({ type: "agent_end", messages: newMessages });
|
||||
stream.end(newMessages);
|
||||
return;
|
||||
}
|
||||
|
||||
// Check for tool calls
|
||||
const toolCalls = message.content.filter((c) => c.type === "toolCall");
|
||||
hasMoreToolCalls = toolCalls.length > 0;
|
||||
|
||||
const toolResults: ToolResultMessage[] = [];
|
||||
if (hasMoreToolCalls) {
|
||||
const toolExecution = await executeToolCalls(
|
||||
currentContext.tools,
|
||||
message,
|
||||
signal,
|
||||
stream,
|
||||
config.getSteeringMessages,
|
||||
);
|
||||
toolResults.push(...toolExecution.toolResults);
|
||||
steeringAfterTools = toolExecution.steeringMessages ?? null;
|
||||
|
||||
for (const result of toolResults) {
|
||||
currentContext.messages.push(result);
|
||||
newMessages.push(result);
|
||||
}
|
||||
}
|
||||
|
||||
stream.push({ type: "turn_end", message, toolResults });
|
||||
|
||||
// Get steering messages after turn completes
|
||||
if (steeringAfterTools && steeringAfterTools.length > 0) {
|
||||
pendingMessages = steeringAfterTools;
|
||||
steeringAfterTools = null;
|
||||
} else {
|
||||
pendingMessages = (await config.getSteeringMessages?.()) || [];
|
||||
}
|
||||
}
|
||||
|
||||
// Agent would stop here. Check for follow-up messages.
|
||||
const followUpMessages = (await config.getFollowUpMessages?.()) || [];
|
||||
if (followUpMessages.length > 0) {
|
||||
// Set as pending so inner loop processes them
|
||||
pendingMessages = followUpMessages;
|
||||
continue;
|
||||
}
|
||||
|
||||
// No more messages, exit
|
||||
break;
|
||||
}
|
||||
|
||||
stream.push({ type: "agent_end", messages: newMessages });
|
||||
stream.end(newMessages);
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream an assistant response from the LLM.
|
||||
* This is where AgentMessage[] gets transformed to Message[] for the LLM.
|
||||
*/
|
||||
async function streamAssistantResponse(
|
||||
context: AgentContext,
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
streamFn?: StreamFn,
|
||||
): Promise<AssistantMessage> {
|
||||
// Apply context transform if configured (AgentMessage[] → AgentMessage[])
|
||||
let messages = context.messages;
|
||||
if (config.transformContext) {
|
||||
messages = await config.transformContext(messages, signal);
|
||||
}
|
||||
|
||||
// Convert to LLM-compatible messages (AgentMessage[] → Message[])
|
||||
const llmMessages = await config.convertToLlm(messages);
|
||||
|
||||
// Build LLM context
|
||||
const llmContext: Context = {
|
||||
systemPrompt: context.systemPrompt,
|
||||
messages: llmMessages,
|
||||
tools: context.tools,
|
||||
};
|
||||
|
||||
const streamFunction = streamFn || streamSimple;
|
||||
|
||||
// Resolve API key (important for expiring tokens)
|
||||
const resolvedApiKey =
|
||||
(config.getApiKey ? await config.getApiKey(config.model.provider) : undefined) || config.apiKey;
|
||||
|
||||
const response = await streamFunction(config.model, llmContext, {
|
||||
...config,
|
||||
apiKey: resolvedApiKey,
|
||||
signal,
|
||||
});
|
||||
|
||||
let partialMessage: AssistantMessage | null = null;
|
||||
let addedPartial = false;
|
||||
|
||||
for await (const event of response) {
|
||||
switch (event.type) {
|
||||
case "start":
|
||||
partialMessage = event.partial;
|
||||
context.messages.push(partialMessage);
|
||||
addedPartial = true;
|
||||
stream.push({ type: "message_start", message: { ...partialMessage } });
|
||||
break;
|
||||
|
||||
case "text_start":
|
||||
case "text_delta":
|
||||
case "text_end":
|
||||
case "thinking_start":
|
||||
case "thinking_delta":
|
||||
case "thinking_end":
|
||||
case "toolcall_start":
|
||||
case "toolcall_delta":
|
||||
case "toolcall_end":
|
||||
if (partialMessage) {
|
||||
partialMessage = event.partial;
|
||||
context.messages[context.messages.length - 1] = partialMessage;
|
||||
stream.push({
|
||||
type: "message_update",
|
||||
assistantMessageEvent: event,
|
||||
message: { ...partialMessage },
|
||||
});
|
||||
}
|
||||
break;
|
||||
|
||||
case "done":
|
||||
case "error": {
|
||||
const finalMessage = await response.result();
|
||||
if (addedPartial) {
|
||||
context.messages[context.messages.length - 1] = finalMessage;
|
||||
} else {
|
||||
context.messages.push(finalMessage);
|
||||
}
|
||||
if (!addedPartial) {
|
||||
stream.push({ type: "message_start", message: { ...finalMessage } });
|
||||
}
|
||||
stream.push({ type: "message_end", message: finalMessage });
|
||||
return finalMessage;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return await response.result();
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute tool calls from an assistant message.
|
||||
*/
|
||||
async function executeToolCalls(
|
||||
tools: AgentTool<any>[] | undefined,
|
||||
assistantMessage: AssistantMessage,
|
||||
signal: AbortSignal | undefined,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
getSteeringMessages?: AgentLoopConfig["getSteeringMessages"],
|
||||
): Promise<{ toolResults: ToolResultMessage[]; steeringMessages?: AgentMessage[] }> {
|
||||
const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall");
|
||||
const results: ToolResultMessage[] = [];
|
||||
let steeringMessages: AgentMessage[] | undefined;
|
||||
|
||||
for (let index = 0; index < toolCalls.length; index++) {
|
||||
const toolCall = toolCalls[index];
|
||||
const tool = tools?.find((t) => t.name === toolCall.name);
|
||||
|
||||
stream.push({
|
||||
type: "tool_execution_start",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
args: toolCall.arguments,
|
||||
});
|
||||
|
||||
let result: AgentToolResult<any>;
|
||||
let isError = false;
|
||||
|
||||
try {
|
||||
if (!tool) throw new Error(`Tool ${toolCall.name} not found`);
|
||||
|
||||
const validatedArgs = validateToolArguments(tool, toolCall);
|
||||
|
||||
result = await tool.execute(toolCall.id, validatedArgs, signal, (partialResult) => {
|
||||
stream.push({
|
||||
type: "tool_execution_update",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
args: toolCall.arguments,
|
||||
partialResult,
|
||||
});
|
||||
});
|
||||
} catch (e) {
|
||||
result = {
|
||||
content: [{ type: "text", text: e instanceof Error ? e.message : String(e) }],
|
||||
details: {},
|
||||
};
|
||||
isError = true;
|
||||
}
|
||||
|
||||
stream.push({
|
||||
type: "tool_execution_end",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
result,
|
||||
isError,
|
||||
});
|
||||
|
||||
const toolResultMessage: ToolResultMessage = {
|
||||
role: "toolResult",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
content: result.content,
|
||||
details: result.details,
|
||||
isError,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
results.push(toolResultMessage);
|
||||
stream.push({ type: "message_start", message: toolResultMessage });
|
||||
stream.push({ type: "message_end", message: toolResultMessage });
|
||||
|
||||
// Check for steering messages - skip remaining tools if user interrupted
|
||||
if (getSteeringMessages) {
|
||||
const steering = await getSteeringMessages();
|
||||
if (steering.length > 0) {
|
||||
steeringMessages = steering;
|
||||
const remainingCalls = toolCalls.slice(index + 1);
|
||||
for (const skipped of remainingCalls) {
|
||||
results.push(skipToolCall(skipped, stream));
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { toolResults: results, steeringMessages };
|
||||
}
|
||||
|
||||
function skipToolCall(
|
||||
toolCall: Extract<AssistantMessage["content"][number], { type: "toolCall" }>,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
): ToolResultMessage {
|
||||
const result: AgentToolResult<any> = {
|
||||
content: [{ type: "text", text: "Skipped due to queued user message." }],
|
||||
details: {},
|
||||
};
|
||||
|
||||
stream.push({
|
||||
type: "tool_execution_start",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
args: toolCall.arguments,
|
||||
});
|
||||
stream.push({
|
||||
type: "tool_execution_end",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
result,
|
||||
isError: true,
|
||||
});
|
||||
|
||||
const toolResultMessage: ToolResultMessage = {
|
||||
role: "toolResult",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
content: result.content,
|
||||
details: {},
|
||||
isError: true,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
stream.push({ type: "message_start", message: toolResultMessage });
|
||||
stream.push({ type: "message_end", message: toolResultMessage });
|
||||
|
||||
return toolResultMessage;
|
||||
}
|
||||
568
packages/pi-agent-core/src/agent.ts
Normal file
568
packages/pi-agent-core/src/agent.ts
Normal file
|
|
@ -0,0 +1,568 @@
|
|||
/**
|
||||
* Agent class that uses the agent-loop directly.
|
||||
* No transport abstraction - calls streamSimple via the loop.
|
||||
*/
|
||||
|
||||
import {
|
||||
getModel,
|
||||
type ImageContent,
|
||||
type Message,
|
||||
type Model,
|
||||
type SimpleStreamOptions,
|
||||
streamSimple,
|
||||
type TextContent,
|
||||
type ThinkingBudgets,
|
||||
type Transport,
|
||||
} from "@gsd/pi-ai";
|
||||
import { agentLoop, agentLoopContinue } from "./agent-loop.js";
|
||||
import type {
|
||||
AgentContext,
|
||||
AgentEvent,
|
||||
AgentLoopConfig,
|
||||
AgentMessage,
|
||||
AgentState,
|
||||
AgentTool,
|
||||
StreamFn,
|
||||
ThinkingLevel,
|
||||
} from "./types.js";
|
||||
|
||||
/**
|
||||
* Default convertToLlm: Keep only LLM-compatible messages, convert attachments.
|
||||
*/
|
||||
function defaultConvertToLlm(messages: AgentMessage[]): Message[] {
|
||||
return messages.filter((m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult");
|
||||
}
|
||||
|
||||
export interface AgentOptions {
|
||||
initialState?: Partial<AgentState>;
|
||||
|
||||
/**
|
||||
* Converts AgentMessage[] to LLM-compatible Message[] before each LLM call.
|
||||
* Default filters to user/assistant/toolResult and converts attachments.
|
||||
*/
|
||||
convertToLlm?: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
|
||||
|
||||
/**
|
||||
* Optional transform applied to context before convertToLlm.
|
||||
* Use for context pruning, injecting external context, etc.
|
||||
*/
|
||||
transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
|
||||
|
||||
/**
|
||||
* Steering mode: "all" = send all steering messages at once, "one-at-a-time" = one per turn
|
||||
*/
|
||||
steeringMode?: "all" | "one-at-a-time";
|
||||
|
||||
/**
|
||||
* Follow-up mode: "all" = send all follow-up messages at once, "one-at-a-time" = one per turn
|
||||
*/
|
||||
followUpMode?: "all" | "one-at-a-time";
|
||||
|
||||
/**
|
||||
* Custom stream function (for proxy backends, etc.). Default uses streamSimple.
|
||||
*/
|
||||
streamFn?: StreamFn;
|
||||
|
||||
/**
|
||||
* Optional session identifier forwarded to LLM providers.
|
||||
* Used by providers that support session-based caching (e.g., OpenAI Codex).
|
||||
*/
|
||||
sessionId?: string;
|
||||
|
||||
/**
|
||||
* Resolves an API key dynamically for each LLM call.
|
||||
* Useful for expiring tokens (e.g., GitHub Copilot OAuth).
|
||||
*/
|
||||
getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
|
||||
|
||||
/**
|
||||
* Inspect or replace provider payloads before they are sent.
|
||||
*/
|
||||
onPayload?: SimpleStreamOptions["onPayload"];
|
||||
|
||||
/**
|
||||
* Custom token budgets for thinking levels (token-based providers only).
|
||||
*/
|
||||
thinkingBudgets?: ThinkingBudgets;
|
||||
|
||||
/**
|
||||
* Preferred transport for providers that support multiple transports.
|
||||
*/
|
||||
transport?: Transport;
|
||||
|
||||
/**
|
||||
* Maximum delay in milliseconds to wait for a retry when the server requests a long wait.
|
||||
* If the server's requested delay exceeds this value, the request fails immediately,
|
||||
* allowing higher-level retry logic to handle it with user visibility.
|
||||
* Default: 60000 (60 seconds). Set to 0 to disable the cap.
|
||||
*/
|
||||
maxRetryDelayMs?: number;
|
||||
}
|
||||
|
||||
export class Agent {
|
||||
private _state: AgentState = {
|
||||
systemPrompt: "",
|
||||
model: getModel("google", "gemini-2.5-flash-lite-preview-06-17"),
|
||||
thinkingLevel: "off",
|
||||
tools: [],
|
||||
messages: [],
|
||||
isStreaming: false,
|
||||
streamMessage: null,
|
||||
pendingToolCalls: new Set<string>(),
|
||||
error: undefined,
|
||||
};
|
||||
|
||||
private listeners = new Set<(e: AgentEvent) => void>();
|
||||
private abortController?: AbortController;
|
||||
private convertToLlm: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
|
||||
private transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
|
||||
private steeringQueue: AgentMessage[] = [];
|
||||
private followUpQueue: AgentMessage[] = [];
|
||||
private steeringMode: "all" | "one-at-a-time";
|
||||
private followUpMode: "all" | "one-at-a-time";
|
||||
public streamFn: StreamFn;
|
||||
private _sessionId?: string;
|
||||
public getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
|
||||
private _onPayload?: SimpleStreamOptions["onPayload"];
|
||||
private runningPrompt?: Promise<void>;
|
||||
private resolveRunningPrompt?: () => void;
|
||||
private _thinkingBudgets?: ThinkingBudgets;
|
||||
private _transport: Transport;
|
||||
private _maxRetryDelayMs?: number;
|
||||
|
||||
constructor(opts: AgentOptions = {}) {
|
||||
this._state = { ...this._state, ...opts.initialState };
|
||||
this.convertToLlm = opts.convertToLlm || defaultConvertToLlm;
|
||||
this.transformContext = opts.transformContext;
|
||||
this.steeringMode = opts.steeringMode || "one-at-a-time";
|
||||
this.followUpMode = opts.followUpMode || "one-at-a-time";
|
||||
this.streamFn = opts.streamFn || streamSimple;
|
||||
this._sessionId = opts.sessionId;
|
||||
this.getApiKey = opts.getApiKey;
|
||||
this._onPayload = opts.onPayload;
|
||||
this._thinkingBudgets = opts.thinkingBudgets;
|
||||
this._transport = opts.transport ?? "sse";
|
||||
this._maxRetryDelayMs = opts.maxRetryDelayMs;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current session ID used for provider caching.
|
||||
*/
|
||||
get sessionId(): string | undefined {
|
||||
return this._sessionId;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the session ID for provider caching.
|
||||
* Call this when switching sessions (new session, branch, resume).
|
||||
*/
|
||||
set sessionId(value: string | undefined) {
|
||||
this._sessionId = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current thinking budgets.
|
||||
*/
|
||||
get thinkingBudgets(): ThinkingBudgets | undefined {
|
||||
return this._thinkingBudgets;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set custom thinking budgets for token-based providers.
|
||||
*/
|
||||
set thinkingBudgets(value: ThinkingBudgets | undefined) {
|
||||
this._thinkingBudgets = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current preferred transport.
|
||||
*/
|
||||
get transport(): Transport {
|
||||
return this._transport;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the preferred transport.
|
||||
*/
|
||||
setTransport(value: Transport) {
|
||||
this._transport = value;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current max retry delay in milliseconds.
|
||||
*/
|
||||
get maxRetryDelayMs(): number | undefined {
|
||||
return this._maxRetryDelayMs;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the maximum delay to wait for server-requested retries.
|
||||
* Set to 0 to disable the cap.
|
||||
*/
|
||||
set maxRetryDelayMs(value: number | undefined) {
|
||||
this._maxRetryDelayMs = value;
|
||||
}
|
||||
|
||||
get state(): AgentState {
|
||||
return this._state;
|
||||
}
|
||||
|
||||
subscribe(fn: (e: AgentEvent) => void): () => void {
|
||||
this.listeners.add(fn);
|
||||
return () => this.listeners.delete(fn);
|
||||
}
|
||||
|
||||
// State mutators
|
||||
setSystemPrompt(v: string) {
|
||||
this._state.systemPrompt = v;
|
||||
}
|
||||
|
||||
setModel(m: Model<any>) {
|
||||
this._state.model = m;
|
||||
}
|
||||
|
||||
setThinkingLevel(l: ThinkingLevel) {
|
||||
this._state.thinkingLevel = l;
|
||||
}
|
||||
|
||||
setSteeringMode(mode: "all" | "one-at-a-time") {
|
||||
this.steeringMode = mode;
|
||||
}
|
||||
|
||||
getSteeringMode(): "all" | "one-at-a-time" {
|
||||
return this.steeringMode;
|
||||
}
|
||||
|
||||
setFollowUpMode(mode: "all" | "one-at-a-time") {
|
||||
this.followUpMode = mode;
|
||||
}
|
||||
|
||||
getFollowUpMode(): "all" | "one-at-a-time" {
|
||||
return this.followUpMode;
|
||||
}
|
||||
|
||||
setTools(t: AgentTool<any>[]) {
|
||||
this._state.tools = t;
|
||||
}
|
||||
|
||||
replaceMessages(ms: AgentMessage[]) {
|
||||
this._state.messages = ms.slice();
|
||||
}
|
||||
|
||||
appendMessage(m: AgentMessage) {
|
||||
this._state.messages = [...this._state.messages, m];
|
||||
}
|
||||
|
||||
/**
|
||||
* Queue a steering message to interrupt the agent mid-run.
|
||||
* Delivered after current tool execution, skips remaining tools.
|
||||
*/
|
||||
steer(m: AgentMessage) {
|
||||
this.steeringQueue.push(m);
|
||||
}
|
||||
|
||||
/**
|
||||
* Queue a follow-up message to be processed after the agent finishes.
|
||||
* Delivered only when agent has no more tool calls or steering messages.
|
||||
*/
|
||||
followUp(m: AgentMessage) {
|
||||
this.followUpQueue.push(m);
|
||||
}
|
||||
|
||||
clearSteeringQueue() {
|
||||
this.steeringQueue = [];
|
||||
}
|
||||
|
||||
clearFollowUpQueue() {
|
||||
this.followUpQueue = [];
|
||||
}
|
||||
|
||||
clearAllQueues() {
|
||||
this.steeringQueue = [];
|
||||
this.followUpQueue = [];
|
||||
}
|
||||
|
||||
hasQueuedMessages(): boolean {
|
||||
return this.steeringQueue.length > 0 || this.followUpQueue.length > 0;
|
||||
}
|
||||
|
||||
private dequeueSteeringMessages(): AgentMessage[] {
|
||||
if (this.steeringMode === "one-at-a-time") {
|
||||
if (this.steeringQueue.length > 0) {
|
||||
const first = this.steeringQueue[0];
|
||||
this.steeringQueue = this.steeringQueue.slice(1);
|
||||
return [first];
|
||||
}
|
||||
return [];
|
||||
}
|
||||
|
||||
const steering = this.steeringQueue.slice();
|
||||
this.steeringQueue = [];
|
||||
return steering;
|
||||
}
|
||||
|
||||
private dequeueFollowUpMessages(): AgentMessage[] {
|
||||
if (this.followUpMode === "one-at-a-time") {
|
||||
if (this.followUpQueue.length > 0) {
|
||||
const first = this.followUpQueue[0];
|
||||
this.followUpQueue = this.followUpQueue.slice(1);
|
||||
return [first];
|
||||
}
|
||||
return [];
|
||||
}
|
||||
|
||||
const followUp = this.followUpQueue.slice();
|
||||
this.followUpQueue = [];
|
||||
return followUp;
|
||||
}
|
||||
|
||||
clearMessages() {
|
||||
this._state.messages = [];
|
||||
}
|
||||
|
||||
abort() {
|
||||
this.abortController?.abort();
|
||||
}
|
||||
|
||||
waitForIdle(): Promise<void> {
|
||||
return this.runningPrompt ?? Promise.resolve();
|
||||
}
|
||||
|
||||
reset() {
|
||||
this._state.messages = [];
|
||||
this._state.isStreaming = false;
|
||||
this._state.streamMessage = null;
|
||||
this._state.pendingToolCalls = new Set<string>();
|
||||
this._state.error = undefined;
|
||||
this.steeringQueue = [];
|
||||
this.followUpQueue = [];
|
||||
}
|
||||
|
||||
/** Send a prompt with an AgentMessage */
|
||||
async prompt(message: AgentMessage | AgentMessage[]): Promise<void>;
|
||||
async prompt(input: string, images?: ImageContent[]): Promise<void>;
|
||||
async prompt(input: string | AgentMessage | AgentMessage[], images?: ImageContent[]) {
|
||||
if (this._state.isStreaming) {
|
||||
throw new Error(
|
||||
"Agent is already processing a prompt. Use steer() or followUp() to queue messages, or wait for completion.",
|
||||
);
|
||||
}
|
||||
|
||||
const model = this._state.model;
|
||||
if (!model) throw new Error("No model configured");
|
||||
|
||||
let msgs: AgentMessage[];
|
||||
|
||||
if (Array.isArray(input)) {
|
||||
msgs = input;
|
||||
} else if (typeof input === "string") {
|
||||
const content: Array<TextContent | ImageContent> = [{ type: "text", text: input }];
|
||||
if (images && images.length > 0) {
|
||||
content.push(...images);
|
||||
}
|
||||
msgs = [
|
||||
{
|
||||
role: "user",
|
||||
content,
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
];
|
||||
} else {
|
||||
msgs = [input];
|
||||
}
|
||||
|
||||
await this._runLoop(msgs);
|
||||
}
|
||||
|
||||
/**
|
||||
* Continue from current context (used for retries and resuming queued messages).
|
||||
*/
|
||||
async continue() {
|
||||
if (this._state.isStreaming) {
|
||||
throw new Error("Agent is already processing. Wait for completion before continuing.");
|
||||
}
|
||||
|
||||
const messages = this._state.messages;
|
||||
if (messages.length === 0) {
|
||||
throw new Error("No messages to continue from");
|
||||
}
|
||||
if (messages[messages.length - 1].role === "assistant") {
|
||||
const queuedSteering = this.dequeueSteeringMessages();
|
||||
if (queuedSteering.length > 0) {
|
||||
await this._runLoop(queuedSteering, { skipInitialSteeringPoll: true });
|
||||
return;
|
||||
}
|
||||
|
||||
const queuedFollowUp = this.dequeueFollowUpMessages();
|
||||
if (queuedFollowUp.length > 0) {
|
||||
await this._runLoop(queuedFollowUp);
|
||||
return;
|
||||
}
|
||||
|
||||
throw new Error("Cannot continue from message role: assistant");
|
||||
}
|
||||
|
||||
await this._runLoop(undefined);
|
||||
}
|
||||
|
||||
/**
|
||||
* Run the agent loop.
|
||||
* If messages are provided, starts a new conversation turn with those messages.
|
||||
* Otherwise, continues from existing context.
|
||||
*/
|
||||
private async _runLoop(messages?: AgentMessage[], options?: { skipInitialSteeringPoll?: boolean }) {
|
||||
const model = this._state.model;
|
||||
if (!model) throw new Error("No model configured");
|
||||
|
||||
this.runningPrompt = new Promise<void>((resolve) => {
|
||||
this.resolveRunningPrompt = resolve;
|
||||
});
|
||||
|
||||
this.abortController = new AbortController();
|
||||
this._state.isStreaming = true;
|
||||
this._state.streamMessage = null;
|
||||
this._state.error = undefined;
|
||||
|
||||
const reasoning = this._state.thinkingLevel === "off" ? undefined : this._state.thinkingLevel;
|
||||
|
||||
const context: AgentContext = {
|
||||
systemPrompt: this._state.systemPrompt,
|
||||
messages: this._state.messages.slice(),
|
||||
tools: this._state.tools,
|
||||
};
|
||||
|
||||
let skipInitialSteeringPoll = options?.skipInitialSteeringPoll === true;
|
||||
|
||||
const config: AgentLoopConfig = {
|
||||
model,
|
||||
reasoning,
|
||||
sessionId: this._sessionId,
|
||||
onPayload: this._onPayload,
|
||||
transport: this._transport,
|
||||
thinkingBudgets: this._thinkingBudgets,
|
||||
maxRetryDelayMs: this._maxRetryDelayMs,
|
||||
convertToLlm: this.convertToLlm,
|
||||
transformContext: this.transformContext,
|
||||
getApiKey: this.getApiKey,
|
||||
getSteeringMessages: async () => {
|
||||
if (skipInitialSteeringPoll) {
|
||||
skipInitialSteeringPoll = false;
|
||||
return [];
|
||||
}
|
||||
return this.dequeueSteeringMessages();
|
||||
},
|
||||
getFollowUpMessages: async () => this.dequeueFollowUpMessages(),
|
||||
};
|
||||
|
||||
let partial: AgentMessage | null = null;
|
||||
|
||||
try {
|
||||
const stream = messages
|
||||
? agentLoop(messages, context, config, this.abortController.signal, this.streamFn)
|
||||
: agentLoopContinue(context, config, this.abortController.signal, this.streamFn);
|
||||
|
||||
for await (const event of stream) {
|
||||
// Update internal state based on events
|
||||
switch (event.type) {
|
||||
case "message_start":
|
||||
partial = event.message;
|
||||
this._state.streamMessage = event.message;
|
||||
break;
|
||||
|
||||
case "message_update":
|
||||
partial = event.message;
|
||||
this._state.streamMessage = event.message;
|
||||
break;
|
||||
|
||||
case "message_end":
|
||||
partial = null;
|
||||
this._state.streamMessage = null;
|
||||
this.appendMessage(event.message);
|
||||
break;
|
||||
|
||||
case "tool_execution_start": {
|
||||
const s = new Set(this._state.pendingToolCalls);
|
||||
s.add(event.toolCallId);
|
||||
this._state.pendingToolCalls = s;
|
||||
break;
|
||||
}
|
||||
|
||||
case "tool_execution_end": {
|
||||
const s = new Set(this._state.pendingToolCalls);
|
||||
s.delete(event.toolCallId);
|
||||
this._state.pendingToolCalls = s;
|
||||
break;
|
||||
}
|
||||
|
||||
case "turn_end":
|
||||
if (event.message.role === "assistant" && (event.message as any).errorMessage) {
|
||||
this._state.error = (event.message as any).errorMessage;
|
||||
}
|
||||
break;
|
||||
|
||||
case "agent_end":
|
||||
this._state.isStreaming = false;
|
||||
this._state.streamMessage = null;
|
||||
break;
|
||||
}
|
||||
|
||||
// Emit to listeners
|
||||
this.emit(event);
|
||||
}
|
||||
|
||||
// Handle any remaining partial message
|
||||
if (partial && partial.role === "assistant" && partial.content.length > 0) {
|
||||
const onlyEmpty = !partial.content.some(
|
||||
(c) =>
|
||||
(c.type === "thinking" && c.thinking.trim().length > 0) ||
|
||||
(c.type === "text" && c.text.trim().length > 0) ||
|
||||
(c.type === "toolCall" && c.name.trim().length > 0),
|
||||
);
|
||||
if (!onlyEmpty) {
|
||||
this.appendMessage(partial);
|
||||
} else {
|
||||
if (this.abortController?.signal.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (err: any) {
|
||||
const errorMsg: AgentMessage = {
|
||||
role: "assistant",
|
||||
content: [{ type: "text", text: "" }],
|
||||
api: model.api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: this.abortController?.signal.aborted ? "aborted" : "error",
|
||||
errorMessage: err?.message || String(err),
|
||||
timestamp: Date.now(),
|
||||
} as AgentMessage;
|
||||
|
||||
this.appendMessage(errorMsg);
|
||||
this._state.error = err?.message || String(err);
|
||||
this.emit({ type: "agent_end", messages: [errorMsg] });
|
||||
} finally {
|
||||
this._state.isStreaming = false;
|
||||
this._state.streamMessage = null;
|
||||
this._state.pendingToolCalls = new Set<string>();
|
||||
this.abortController = undefined;
|
||||
this.resolveRunningPrompt?.();
|
||||
this.runningPrompt = undefined;
|
||||
this.resolveRunningPrompt = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
private emit(e: AgentEvent) {
|
||||
for (const listener of this.listeners) {
|
||||
listener(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
8
packages/pi-agent-core/src/index.ts
Normal file
8
packages/pi-agent-core/src/index.ts
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
// Core Agent
|
||||
export * from "./agent.js";
|
||||
// Loop functions
|
||||
export * from "./agent-loop.js";
|
||||
// Proxy utilities
|
||||
export * from "./proxy.js";
|
||||
// Types
|
||||
export * from "./types.js";
|
||||
340
packages/pi-agent-core/src/proxy.ts
Normal file
340
packages/pi-agent-core/src/proxy.ts
Normal file
|
|
@ -0,0 +1,340 @@
|
|||
/**
|
||||
* Proxy stream function for apps that route LLM calls through a server.
|
||||
* The server manages auth and proxies requests to LLM providers.
|
||||
*/
|
||||
|
||||
// Internal import for JSON parsing utility
|
||||
import {
|
||||
type AssistantMessage,
|
||||
type AssistantMessageEvent,
|
||||
type Context,
|
||||
EventStream,
|
||||
type Model,
|
||||
parseStreamingJson,
|
||||
type SimpleStreamOptions,
|
||||
type StopReason,
|
||||
type ToolCall,
|
||||
} from "@gsd/pi-ai";
|
||||
|
||||
// Create stream class matching ProxyMessageEventStream
|
||||
class ProxyMessageEventStream extends EventStream<AssistantMessageEvent, AssistantMessage> {
|
||||
constructor() {
|
||||
super(
|
||||
(event) => event.type === "done" || event.type === "error",
|
||||
(event) => {
|
||||
if (event.type === "done") return event.message;
|
||||
if (event.type === "error") return event.error;
|
||||
throw new Error("Unexpected event type");
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Proxy event types - server sends these with partial field stripped to reduce bandwidth.
|
||||
*/
|
||||
export type ProxyAssistantMessageEvent =
|
||||
| { type: "start" }
|
||||
| { type: "text_start"; contentIndex: number }
|
||||
| { type: "text_delta"; contentIndex: number; delta: string }
|
||||
| { type: "text_end"; contentIndex: number; contentSignature?: string }
|
||||
| { type: "thinking_start"; contentIndex: number }
|
||||
| { type: "thinking_delta"; contentIndex: number; delta: string }
|
||||
| { type: "thinking_end"; contentIndex: number; contentSignature?: string }
|
||||
| { type: "toolcall_start"; contentIndex: number; id: string; toolName: string }
|
||||
| { type: "toolcall_delta"; contentIndex: number; delta: string }
|
||||
| { type: "toolcall_end"; contentIndex: number }
|
||||
| {
|
||||
type: "done";
|
||||
reason: Extract<StopReason, "stop" | "length" | "toolUse">;
|
||||
usage: AssistantMessage["usage"];
|
||||
}
|
||||
| {
|
||||
type: "error";
|
||||
reason: Extract<StopReason, "aborted" | "error">;
|
||||
errorMessage?: string;
|
||||
usage: AssistantMessage["usage"];
|
||||
};
|
||||
|
||||
export interface ProxyStreamOptions extends SimpleStreamOptions {
|
||||
/** Auth token for the proxy server */
|
||||
authToken: string;
|
||||
/** Proxy server URL (e.g., "https://genai.example.com") */
|
||||
proxyUrl: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream function that proxies through a server instead of calling LLM providers directly.
|
||||
* The server strips the partial field from delta events to reduce bandwidth.
|
||||
* We reconstruct the partial message client-side.
|
||||
*
|
||||
* Use this as the `streamFn` option when creating an Agent that needs to go through a proxy.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* const agent = new Agent({
|
||||
* streamFn: (model, context, options) =>
|
||||
* streamProxy(model, context, {
|
||||
* ...options,
|
||||
* authToken: await getAuthToken(),
|
||||
* proxyUrl: "https://genai.example.com",
|
||||
* }),
|
||||
* });
|
||||
* ```
|
||||
*/
|
||||
export function streamProxy(model: Model<any>, context: Context, options: ProxyStreamOptions): ProxyMessageEventStream {
|
||||
const stream = new ProxyMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
// Initialize the partial message that we'll build up from events
|
||||
const partial: AssistantMessage = {
|
||||
role: "assistant",
|
||||
stopReason: "stop",
|
||||
content: [],
|
||||
api: model.api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
let reader: ReadableStreamDefaultReader<Uint8Array> | undefined;
|
||||
|
||||
const abortHandler = () => {
|
||||
if (reader) {
|
||||
reader.cancel("Request aborted by user").catch(() => {});
|
||||
}
|
||||
};
|
||||
|
||||
if (options.signal) {
|
||||
options.signal.addEventListener("abort", abortHandler);
|
||||
}
|
||||
|
||||
try {
|
||||
const response = await fetch(`${options.proxyUrl}/api/stream`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
Authorization: `Bearer ${options.authToken}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model,
|
||||
context,
|
||||
options: {
|
||||
temperature: options.temperature,
|
||||
maxTokens: options.maxTokens,
|
||||
reasoning: options.reasoning,
|
||||
},
|
||||
}),
|
||||
signal: options.signal,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
let errorMessage = `Proxy error: ${response.status} ${response.statusText}`;
|
||||
try {
|
||||
const errorData = (await response.json()) as { error?: string };
|
||||
if (errorData.error) {
|
||||
errorMessage = `Proxy error: ${errorData.error}`;
|
||||
}
|
||||
} catch {
|
||||
// Couldn't parse error response
|
||||
}
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
|
||||
reader = response.body!.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
if (options.signal?.aborted) {
|
||||
throw new Error("Request aborted by user");
|
||||
}
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const lines = buffer.split("\n");
|
||||
buffer = lines.pop() || "";
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith("data: ")) {
|
||||
const data = line.slice(6).trim();
|
||||
if (data) {
|
||||
const proxyEvent = JSON.parse(data) as ProxyAssistantMessageEvent;
|
||||
const event = processProxyEvent(proxyEvent, partial);
|
||||
if (event) {
|
||||
stream.push(event);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (options.signal?.aborted) {
|
||||
throw new Error("Request aborted by user");
|
||||
}
|
||||
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error);
|
||||
const reason = options.signal?.aborted ? "aborted" : "error";
|
||||
partial.stopReason = reason;
|
||||
partial.errorMessage = errorMessage;
|
||||
stream.push({
|
||||
type: "error",
|
||||
reason,
|
||||
error: partial,
|
||||
});
|
||||
stream.end();
|
||||
} finally {
|
||||
if (options.signal) {
|
||||
options.signal.removeEventListener("abort", abortHandler);
|
||||
}
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
}
|
||||
|
||||
/**
|
||||
* Process a proxy event and update the partial message.
|
||||
*/
|
||||
function processProxyEvent(
|
||||
proxyEvent: ProxyAssistantMessageEvent,
|
||||
partial: AssistantMessage,
|
||||
): AssistantMessageEvent | undefined {
|
||||
switch (proxyEvent.type) {
|
||||
case "start":
|
||||
return { type: "start", partial };
|
||||
|
||||
case "text_start":
|
||||
partial.content[proxyEvent.contentIndex] = { type: "text", text: "" };
|
||||
return { type: "text_start", contentIndex: proxyEvent.contentIndex, partial };
|
||||
|
||||
case "text_delta": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "text") {
|
||||
content.text += proxyEvent.delta;
|
||||
return {
|
||||
type: "text_delta",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
delta: proxyEvent.delta,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received text_delta for non-text content");
|
||||
}
|
||||
|
||||
case "text_end": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "text") {
|
||||
content.textSignature = proxyEvent.contentSignature;
|
||||
return {
|
||||
type: "text_end",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
content: content.text,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received text_end for non-text content");
|
||||
}
|
||||
|
||||
case "thinking_start":
|
||||
partial.content[proxyEvent.contentIndex] = { type: "thinking", thinking: "" };
|
||||
return { type: "thinking_start", contentIndex: proxyEvent.contentIndex, partial };
|
||||
|
||||
case "thinking_delta": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "thinking") {
|
||||
content.thinking += proxyEvent.delta;
|
||||
return {
|
||||
type: "thinking_delta",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
delta: proxyEvent.delta,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received thinking_delta for non-thinking content");
|
||||
}
|
||||
|
||||
case "thinking_end": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "thinking") {
|
||||
content.thinkingSignature = proxyEvent.contentSignature;
|
||||
return {
|
||||
type: "thinking_end",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
content: content.thinking,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received thinking_end for non-thinking content");
|
||||
}
|
||||
|
||||
case "toolcall_start":
|
||||
partial.content[proxyEvent.contentIndex] = {
|
||||
type: "toolCall",
|
||||
id: proxyEvent.id,
|
||||
name: proxyEvent.toolName,
|
||||
arguments: {},
|
||||
partialJson: "",
|
||||
} satisfies ToolCall & { partialJson: string } as ToolCall;
|
||||
return { type: "toolcall_start", contentIndex: proxyEvent.contentIndex, partial };
|
||||
|
||||
case "toolcall_delta": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "toolCall") {
|
||||
(content as any).partialJson += proxyEvent.delta;
|
||||
content.arguments = parseStreamingJson((content as any).partialJson) || {};
|
||||
partial.content[proxyEvent.contentIndex] = { ...content }; // Trigger reactivity
|
||||
return {
|
||||
type: "toolcall_delta",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
delta: proxyEvent.delta,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
throw new Error("Received toolcall_delta for non-toolCall content");
|
||||
}
|
||||
|
||||
case "toolcall_end": {
|
||||
const content = partial.content[proxyEvent.contentIndex];
|
||||
if (content?.type === "toolCall") {
|
||||
delete (content as any).partialJson;
|
||||
return {
|
||||
type: "toolcall_end",
|
||||
contentIndex: proxyEvent.contentIndex,
|
||||
toolCall: content,
|
||||
partial,
|
||||
};
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
case "done":
|
||||
partial.stopReason = proxyEvent.reason;
|
||||
partial.usage = proxyEvent.usage;
|
||||
return { type: "done", reason: proxyEvent.reason, message: partial };
|
||||
|
||||
case "error":
|
||||
partial.stopReason = proxyEvent.reason;
|
||||
partial.errorMessage = proxyEvent.errorMessage;
|
||||
partial.usage = proxyEvent.usage;
|
||||
return { type: "error", reason: proxyEvent.reason, error: partial };
|
||||
|
||||
default: {
|
||||
const _exhaustiveCheck: never = proxyEvent;
|
||||
console.warn(`Unhandled proxy event type: ${(proxyEvent as any).type}`);
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
}
|
||||
194
packages/pi-agent-core/src/types.ts
Normal file
194
packages/pi-agent-core/src/types.ts
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
import type {
|
||||
AssistantMessageEvent,
|
||||
ImageContent,
|
||||
Message,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
streamSimple,
|
||||
TextContent,
|
||||
Tool,
|
||||
ToolResultMessage,
|
||||
} from "@gsd/pi-ai";
|
||||
import type { Static, TSchema } from "@sinclair/typebox";
|
||||
|
||||
/** Stream function - can return sync or Promise for async config lookup */
|
||||
export type StreamFn = (
|
||||
...args: Parameters<typeof streamSimple>
|
||||
) => ReturnType<typeof streamSimple> | Promise<ReturnType<typeof streamSimple>>;
|
||||
|
||||
/**
|
||||
* Configuration for the agent loop.
|
||||
*/
|
||||
export interface AgentLoopConfig extends SimpleStreamOptions {
|
||||
model: Model<any>;
|
||||
|
||||
/**
|
||||
* Converts AgentMessage[] to LLM-compatible Message[] before each LLM call.
|
||||
*
|
||||
* Each AgentMessage must be converted to a UserMessage, AssistantMessage, or ToolResultMessage
|
||||
* that the LLM can understand. AgentMessages that cannot be converted (e.g., UI-only notifications,
|
||||
* status messages) should be filtered out.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* convertToLlm: (messages) => messages.flatMap(m => {
|
||||
* if (m.role === "custom") {
|
||||
* // Convert custom message to user message
|
||||
* return [{ role: "user", content: m.content, timestamp: m.timestamp }];
|
||||
* }
|
||||
* if (m.role === "notification") {
|
||||
* // Filter out UI-only messages
|
||||
* return [];
|
||||
* }
|
||||
* // Pass through standard LLM messages
|
||||
* return [m];
|
||||
* })
|
||||
* ```
|
||||
*/
|
||||
convertToLlm: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
|
||||
|
||||
/**
|
||||
* Optional transform applied to the context before `convertToLlm`.
|
||||
*
|
||||
* Use this for operations that work at the AgentMessage level:
|
||||
* - Context window management (pruning old messages)
|
||||
* - Injecting context from external sources
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* transformContext: async (messages) => {
|
||||
* if (estimateTokens(messages) > MAX_TOKENS) {
|
||||
* return pruneOldMessages(messages);
|
||||
* }
|
||||
* return messages;
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
|
||||
|
||||
/**
|
||||
* Resolves an API key dynamically for each LLM call.
|
||||
*
|
||||
* Useful for short-lived OAuth tokens (e.g., GitHub Copilot) that may expire
|
||||
* during long-running tool execution phases.
|
||||
*/
|
||||
getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
|
||||
|
||||
/**
|
||||
* Returns steering messages to inject into the conversation mid-run.
|
||||
*
|
||||
* Called after each tool execution to check for user interruptions.
|
||||
* If messages are returned, remaining tool calls are skipped and
|
||||
* these messages are added to the context before the next LLM call.
|
||||
*
|
||||
* Use this for "steering" the agent while it's working.
|
||||
*/
|
||||
getSteeringMessages?: () => Promise<AgentMessage[]>;
|
||||
|
||||
/**
|
||||
* Returns follow-up messages to process after the agent would otherwise stop.
|
||||
*
|
||||
* Called when the agent has no more tool calls and no steering messages.
|
||||
* If messages are returned, they're added to the context and the agent
|
||||
* continues with another turn.
|
||||
*
|
||||
* Use this for follow-up messages that should wait until the agent finishes.
|
||||
*/
|
||||
getFollowUpMessages?: () => Promise<AgentMessage[]>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Thinking/reasoning level for models that support it.
|
||||
* Note: "xhigh" is only supported by OpenAI gpt-5.1-codex-max, gpt-5.2, gpt-5.2-codex, gpt-5.3, and gpt-5.3-codex models.
|
||||
*/
|
||||
export type ThinkingLevel = "off" | "minimal" | "low" | "medium" | "high" | "xhigh";
|
||||
|
||||
/**
|
||||
* Extensible interface for custom app messages.
|
||||
* Apps can extend via declaration merging:
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* declare module "@mariozechner/agent" {
|
||||
* interface CustomAgentMessages {
|
||||
* artifact: ArtifactMessage;
|
||||
* notification: NotificationMessage;
|
||||
* }
|
||||
* }
|
||||
* ```
|
||||
*/
|
||||
export interface CustomAgentMessages {
|
||||
// Empty by default - apps extend via declaration merging
|
||||
}
|
||||
|
||||
/**
|
||||
* AgentMessage: Union of LLM messages + custom messages.
|
||||
* This abstraction allows apps to add custom message types while maintaining
|
||||
* type safety and compatibility with the base LLM messages.
|
||||
*/
|
||||
export type AgentMessage = Message | CustomAgentMessages[keyof CustomAgentMessages];
|
||||
|
||||
/**
|
||||
* Agent state containing all configuration and conversation data.
|
||||
*/
|
||||
export interface AgentState {
|
||||
systemPrompt: string;
|
||||
model: Model<any>;
|
||||
thinkingLevel: ThinkingLevel;
|
||||
tools: AgentTool<any>[];
|
||||
messages: AgentMessage[]; // Can include attachments + custom message types
|
||||
isStreaming: boolean;
|
||||
streamMessage: AgentMessage | null;
|
||||
pendingToolCalls: Set<string>;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export interface AgentToolResult<T> {
|
||||
// Content blocks supporting text and images
|
||||
content: (TextContent | ImageContent)[];
|
||||
// Details to be displayed in a UI or logged
|
||||
details: T;
|
||||
}
|
||||
|
||||
// Callback for streaming tool execution updates
|
||||
export type AgentToolUpdateCallback<T = any> = (partialResult: AgentToolResult<T>) => void;
|
||||
|
||||
// AgentTool extends Tool but adds the execute function
|
||||
export interface AgentTool<TParameters extends TSchema = TSchema, TDetails = any> extends Tool<TParameters> {
|
||||
// A human-readable label for the tool to be displayed in UI
|
||||
label: string;
|
||||
execute: (
|
||||
toolCallId: string,
|
||||
params: Static<TParameters>,
|
||||
signal?: AbortSignal,
|
||||
onUpdate?: AgentToolUpdateCallback<TDetails>,
|
||||
) => Promise<AgentToolResult<TDetails>>;
|
||||
}
|
||||
|
||||
// AgentContext is like Context but uses AgentTool
|
||||
export interface AgentContext {
|
||||
systemPrompt: string;
|
||||
messages: AgentMessage[];
|
||||
tools?: AgentTool<any>[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Events emitted by the Agent for UI updates.
|
||||
* These events provide fine-grained lifecycle information for messages, turns, and tool executions.
|
||||
*/
|
||||
export type AgentEvent =
|
||||
// Agent lifecycle
|
||||
| { type: "agent_start" }
|
||||
| { type: "agent_end"; messages: AgentMessage[] }
|
||||
// Turn lifecycle - a turn is one assistant response + any tool calls/results
|
||||
| { type: "turn_start" }
|
||||
| { type: "turn_end"; message: AgentMessage; toolResults: ToolResultMessage[] }
|
||||
// Message lifecycle - emitted for user, assistant, and toolResult messages
|
||||
| { type: "message_start"; message: AgentMessage }
|
||||
// Only emitted for assistant messages during streaming
|
||||
| { type: "message_update"; message: AgentMessage; assistantMessageEvent: AssistantMessageEvent }
|
||||
| { type: "message_end"; message: AgentMessage }
|
||||
// Tool execution lifecycle
|
||||
| { type: "tool_execution_start"; toolCallId: string; toolName: string; args: any }
|
||||
| { type: "tool_execution_update"; toolCallId: string; toolName: string; args: any; partialResult: any }
|
||||
| { type: "tool_execution_end"; toolCallId: string; toolName: string; result: any; isError: boolean };
|
||||
27
packages/pi-agent-core/tsconfig.json
Normal file
27
packages/pi-agent-core/tsconfig.json
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2024",
|
||||
"module": "Node16",
|
||||
"lib": ["ES2024"],
|
||||
"strict": true,
|
||||
"esModuleInterop": true,
|
||||
"skipLibCheck": true,
|
||||
"forceConsistentCasingInFileNames": true,
|
||||
"declaration": true,
|
||||
"declarationMap": true,
|
||||
"sourceMap": true,
|
||||
"inlineSources": true,
|
||||
"inlineSourceMap": false,
|
||||
"moduleResolution": "Node16",
|
||||
"resolveJsonModule": true,
|
||||
"allowImportingTsExtensions": false,
|
||||
"experimentalDecorators": true,
|
||||
"emitDecoratorMetadata": true,
|
||||
"useDefineForClassFields": false,
|
||||
"types": ["node"],
|
||||
"outDir": "./dist",
|
||||
"rootDir": "./src"
|
||||
},
|
||||
"include": ["src/**/*.ts"],
|
||||
"exclude": ["node_modules", "dist", "**/*.d.ts", "src/**/*.d.ts"]
|
||||
}
|
||||
1
packages/pi-ai/bedrock-provider.d.ts
vendored
Normal file
1
packages/pi-ai/bedrock-provider.d.ts
vendored
Normal file
|
|
@ -0,0 +1 @@
|
|||
export * from "./dist/bedrock-provider.js";
|
||||
1
packages/pi-ai/bedrock-provider.js
Normal file
1
packages/pi-ai/bedrock-provider.js
Normal file
|
|
@ -0,0 +1 @@
|
|||
export * from "./dist/bedrock-provider.js";
|
||||
40
packages/pi-ai/package.json
Normal file
40
packages/pi-ai/package.json
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
{
|
||||
"name": "@gsd/pi-ai",
|
||||
"version": "0.57.1",
|
||||
"description": "Unified LLM API (vendored from pi-mono)",
|
||||
"type": "module",
|
||||
"main": "./dist/index.js",
|
||||
"types": "./dist/index.d.ts",
|
||||
"exports": {
|
||||
".": {
|
||||
"types": "./dist/index.d.ts",
|
||||
"import": "./dist/index.js"
|
||||
},
|
||||
"./oauth": {
|
||||
"types": "./dist/oauth.d.ts",
|
||||
"import": "./dist/oauth.js"
|
||||
},
|
||||
"./bedrock-provider": {
|
||||
"types": "./bedrock-provider.d.ts",
|
||||
"import": "./bedrock-provider.js"
|
||||
}
|
||||
},
|
||||
"scripts": {
|
||||
"build": "tsc -p tsconfig.json"
|
||||
},
|
||||
"dependencies": {
|
||||
"@anthropic-ai/sdk": "^0.73.0",
|
||||
"@aws-sdk/client-bedrock-runtime": "^3.983.0",
|
||||
"@google/genai": "^1.40.0",
|
||||
"@mistralai/mistralai": "1.14.1",
|
||||
"@sinclair/typebox": "^0.34.41",
|
||||
"ajv": "^8.17.1",
|
||||
"ajv-formats": "^3.0.1",
|
||||
"chalk": "^5.6.2",
|
||||
"openai": "6.26.0",
|
||||
"partial-json": "^0.1.7",
|
||||
"proxy-agent": "^6.5.0",
|
||||
"undici": "^7.19.1",
|
||||
"zod-to-json-schema": "^3.24.6"
|
||||
}
|
||||
}
|
||||
98
packages/pi-ai/src/api-registry.ts
Normal file
98
packages/pi-ai/src/api-registry.ts
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
import type {
|
||||
Api,
|
||||
AssistantMessageEventStream,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
} from "./types.js";
|
||||
|
||||
export type ApiStreamFunction = (
|
||||
model: Model<Api>,
|
||||
context: Context,
|
||||
options?: StreamOptions,
|
||||
) => AssistantMessageEventStream;
|
||||
|
||||
export type ApiStreamSimpleFunction = (
|
||||
model: Model<Api>,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
) => AssistantMessageEventStream;
|
||||
|
||||
export interface ApiProvider<TApi extends Api = Api, TOptions extends StreamOptions = StreamOptions> {
|
||||
api: TApi;
|
||||
stream: StreamFunction<TApi, TOptions>;
|
||||
streamSimple: StreamFunction<TApi, SimpleStreamOptions>;
|
||||
}
|
||||
|
||||
interface ApiProviderInternal {
|
||||
api: Api;
|
||||
stream: ApiStreamFunction;
|
||||
streamSimple: ApiStreamSimpleFunction;
|
||||
}
|
||||
|
||||
type RegisteredApiProvider = {
|
||||
provider: ApiProviderInternal;
|
||||
sourceId?: string;
|
||||
};
|
||||
|
||||
const apiProviderRegistry = new Map<string, RegisteredApiProvider>();
|
||||
|
||||
function wrapStream<TApi extends Api, TOptions extends StreamOptions>(
|
||||
api: TApi,
|
||||
stream: StreamFunction<TApi, TOptions>,
|
||||
): ApiStreamFunction {
|
||||
return (model, context, options) => {
|
||||
if (model.api !== api) {
|
||||
throw new Error(`Mismatched api: ${model.api} expected ${api}`);
|
||||
}
|
||||
return stream(model as Model<TApi>, context, options as TOptions);
|
||||
};
|
||||
}
|
||||
|
||||
function wrapStreamSimple<TApi extends Api>(
|
||||
api: TApi,
|
||||
streamSimple: StreamFunction<TApi, SimpleStreamOptions>,
|
||||
): ApiStreamSimpleFunction {
|
||||
return (model, context, options) => {
|
||||
if (model.api !== api) {
|
||||
throw new Error(`Mismatched api: ${model.api} expected ${api}`);
|
||||
}
|
||||
return streamSimple(model as Model<TApi>, context, options);
|
||||
};
|
||||
}
|
||||
|
||||
export function registerApiProvider<TApi extends Api, TOptions extends StreamOptions>(
|
||||
provider: ApiProvider<TApi, TOptions>,
|
||||
sourceId?: string,
|
||||
): void {
|
||||
apiProviderRegistry.set(provider.api, {
|
||||
provider: {
|
||||
api: provider.api,
|
||||
stream: wrapStream(provider.api, provider.stream),
|
||||
streamSimple: wrapStreamSimple(provider.api, provider.streamSimple),
|
||||
},
|
||||
sourceId,
|
||||
});
|
||||
}
|
||||
|
||||
export function getApiProvider(api: Api): ApiProviderInternal | undefined {
|
||||
return apiProviderRegistry.get(api)?.provider;
|
||||
}
|
||||
|
||||
export function getApiProviders(): ApiProviderInternal[] {
|
||||
return Array.from(apiProviderRegistry.values(), (entry) => entry.provider);
|
||||
}
|
||||
|
||||
export function unregisterApiProviders(sourceId: string): void {
|
||||
for (const [api, entry] of apiProviderRegistry.entries()) {
|
||||
if (entry.sourceId === sourceId) {
|
||||
apiProviderRegistry.delete(api);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function clearApiProviders(): void {
|
||||
apiProviderRegistry.clear();
|
||||
}
|
||||
6
packages/pi-ai/src/bedrock-provider.ts
Normal file
6
packages/pi-ai/src/bedrock-provider.ts
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
import { streamBedrock, streamSimpleBedrock } from "./providers/amazon-bedrock.js";
|
||||
|
||||
export const bedrockProviderModule = {
|
||||
streamBedrock,
|
||||
streamSimpleBedrock,
|
||||
};
|
||||
133
packages/pi-ai/src/cli.ts
Normal file
133
packages/pi-ai/src/cli.ts
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
#!/usr/bin/env node
|
||||
|
||||
import { existsSync, readFileSync, writeFileSync } from "fs";
|
||||
import { createInterface } from "readline";
|
||||
import { getOAuthProvider, getOAuthProviders } from "./utils/oauth/index.js";
|
||||
import type { OAuthCredentials, OAuthProviderId } from "./utils/oauth/types.js";
|
||||
|
||||
const AUTH_FILE = "auth.json";
|
||||
const PROVIDERS = getOAuthProviders();
|
||||
|
||||
function prompt(rl: ReturnType<typeof createInterface>, question: string): Promise<string> {
|
||||
return new Promise((resolve) => rl.question(question, resolve));
|
||||
}
|
||||
|
||||
function loadAuth(): Record<string, { type: "oauth" } & OAuthCredentials> {
|
||||
if (!existsSync(AUTH_FILE)) return {};
|
||||
try {
|
||||
return JSON.parse(readFileSync(AUTH_FILE, "utf-8"));
|
||||
} catch {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
function saveAuth(auth: Record<string, { type: "oauth" } & OAuthCredentials>): void {
|
||||
writeFileSync(AUTH_FILE, JSON.stringify(auth, null, 2), "utf-8");
|
||||
}
|
||||
|
||||
async function login(providerId: OAuthProviderId): Promise<void> {
|
||||
const provider = getOAuthProvider(providerId);
|
||||
if (!provider) {
|
||||
console.error(`Unknown provider: ${providerId}`);
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
const rl = createInterface({ input: process.stdin, output: process.stdout });
|
||||
const promptFn = (msg: string) => prompt(rl, `${msg} `);
|
||||
|
||||
try {
|
||||
const credentials = await provider.login({
|
||||
onAuth: (info) => {
|
||||
console.log(`\nOpen this URL in your browser:\n${info.url}`);
|
||||
if (info.instructions) console.log(info.instructions);
|
||||
console.log();
|
||||
},
|
||||
onPrompt: async (p) => {
|
||||
return await promptFn(`${p.message}${p.placeholder ? ` (${p.placeholder})` : ""}:`);
|
||||
},
|
||||
onProgress: (msg) => console.log(msg),
|
||||
});
|
||||
|
||||
const auth = loadAuth();
|
||||
auth[providerId] = { type: "oauth", ...credentials };
|
||||
saveAuth(auth);
|
||||
|
||||
console.log(`\nCredentials saved to ${AUTH_FILE}`);
|
||||
} finally {
|
||||
rl.close();
|
||||
}
|
||||
}
|
||||
|
||||
async function main(): Promise<void> {
|
||||
const args = process.argv.slice(2);
|
||||
const command = args[0];
|
||||
|
||||
if (!command || command === "help" || command === "--help" || command === "-h") {
|
||||
const providerList = PROVIDERS.map((p) => ` ${p.id.padEnd(20)} ${p.name}`).join("\n");
|
||||
console.log(`Usage: npx @gsd/pi-ai <command> [provider]
|
||||
|
||||
Commands:
|
||||
login [provider] Login to an OAuth provider
|
||||
list List available providers
|
||||
|
||||
Providers:
|
||||
${providerList}
|
||||
|
||||
Examples:
|
||||
npx @gsd/pi-ai login # interactive provider selection
|
||||
npx @gsd/pi-ai login anthropic # login to specific provider
|
||||
npx @gsd/pi-ai list # list providers
|
||||
`);
|
||||
return;
|
||||
}
|
||||
|
||||
if (command === "list") {
|
||||
console.log("Available OAuth providers:\n");
|
||||
for (const p of PROVIDERS) {
|
||||
console.log(` ${p.id.padEnd(20)} ${p.name}`);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (command === "login") {
|
||||
let provider = args[1] as OAuthProviderId | undefined;
|
||||
|
||||
if (!provider) {
|
||||
const rl = createInterface({ input: process.stdin, output: process.stdout });
|
||||
console.log("Select a provider:\n");
|
||||
for (let i = 0; i < PROVIDERS.length; i++) {
|
||||
console.log(` ${i + 1}. ${PROVIDERS[i].name}`);
|
||||
}
|
||||
console.log();
|
||||
|
||||
const choice = await prompt(rl, `Enter number (1-${PROVIDERS.length}): `);
|
||||
rl.close();
|
||||
|
||||
const index = parseInt(choice, 10) - 1;
|
||||
if (index < 0 || index >= PROVIDERS.length) {
|
||||
console.error("Invalid selection");
|
||||
process.exit(1);
|
||||
}
|
||||
provider = PROVIDERS[index].id;
|
||||
}
|
||||
|
||||
if (!PROVIDERS.some((p) => p.id === provider)) {
|
||||
console.error(`Unknown provider: ${provider}`);
|
||||
console.error(`Use 'npx @gsd/pi-ai list' to see available providers`);
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
console.log(`Logging in to ${provider}...`);
|
||||
await login(provider);
|
||||
return;
|
||||
}
|
||||
|
||||
console.error(`Unknown command: ${command}`);
|
||||
console.error(`Use 'npx @gsd/pi-ai --help' for usage`);
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
main().catch((err) => {
|
||||
console.error("Error:", err.message);
|
||||
process.exit(1);
|
||||
});
|
||||
129
packages/pi-ai/src/env-api-keys.ts
Normal file
129
packages/pi-ai/src/env-api-keys.ts
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
// NEVER convert to top-level imports - breaks browser/Vite builds (web-ui)
|
||||
let _existsSync: typeof import("node:fs").existsSync | null = null;
|
||||
let _homedir: typeof import("node:os").homedir | null = null;
|
||||
let _join: typeof import("node:path").join | null = null;
|
||||
|
||||
type DynamicImport = (specifier: string) => Promise<unknown>;
|
||||
|
||||
const dynamicImport: DynamicImport = (specifier) => import(specifier);
|
||||
const NODE_FS_SPECIFIER = "node:" + "fs";
|
||||
const NODE_OS_SPECIFIER = "node:" + "os";
|
||||
const NODE_PATH_SPECIFIER = "node:" + "path";
|
||||
|
||||
// Eagerly load in Node.js/Bun environment only
|
||||
if (typeof process !== "undefined" && (process.versions?.node || process.versions?.bun)) {
|
||||
dynamicImport(NODE_FS_SPECIFIER).then((m) => {
|
||||
_existsSync = (m as typeof import("node:fs")).existsSync;
|
||||
});
|
||||
dynamicImport(NODE_OS_SPECIFIER).then((m) => {
|
||||
_homedir = (m as typeof import("node:os")).homedir;
|
||||
});
|
||||
dynamicImport(NODE_PATH_SPECIFIER).then((m) => {
|
||||
_join = (m as typeof import("node:path")).join;
|
||||
});
|
||||
}
|
||||
|
||||
import type { KnownProvider } from "./types.js";
|
||||
|
||||
let cachedVertexAdcCredentialsExists: boolean | null = null;
|
||||
|
||||
function hasVertexAdcCredentials(): boolean {
|
||||
if (cachedVertexAdcCredentialsExists === null) {
|
||||
// If node modules haven't loaded yet (async import race at startup),
|
||||
// return false WITHOUT caching so the next call retries once they're ready.
|
||||
// Only cache false permanently in a browser environment where fs is never available.
|
||||
if (!_existsSync || !_homedir || !_join) {
|
||||
const isNode = typeof process !== "undefined" && (process.versions?.node || process.versions?.bun);
|
||||
if (!isNode) {
|
||||
// Definitively in a browser — safe to cache false permanently
|
||||
cachedVertexAdcCredentialsExists = false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// Check GOOGLE_APPLICATION_CREDENTIALS env var first (standard way)
|
||||
const gacPath = process.env.GOOGLE_APPLICATION_CREDENTIALS;
|
||||
if (gacPath) {
|
||||
cachedVertexAdcCredentialsExists = _existsSync(gacPath);
|
||||
} else {
|
||||
// Fall back to default ADC path (lazy evaluation)
|
||||
cachedVertexAdcCredentialsExists = _existsSync(
|
||||
_join(_homedir(), ".config", "gcloud", "application_default_credentials.json"),
|
||||
);
|
||||
}
|
||||
}
|
||||
return cachedVertexAdcCredentialsExists;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for provider from known environment variables, e.g. OPENAI_API_KEY.
|
||||
*
|
||||
* Will not return API keys for providers that require OAuth tokens.
|
||||
*/
|
||||
export function getEnvApiKey(provider: KnownProvider): string | undefined;
|
||||
export function getEnvApiKey(provider: string): string | undefined;
|
||||
export function getEnvApiKey(provider: any): string | undefined {
|
||||
// Fall back to environment variables
|
||||
if (provider === "github-copilot") {
|
||||
return process.env.COPILOT_GITHUB_TOKEN || process.env.GH_TOKEN || process.env.GITHUB_TOKEN;
|
||||
}
|
||||
|
||||
// ANTHROPIC_OAUTH_TOKEN takes precedence over ANTHROPIC_API_KEY
|
||||
if (provider === "anthropic") {
|
||||
return process.env.ANTHROPIC_OAUTH_TOKEN || process.env.ANTHROPIC_API_KEY;
|
||||
}
|
||||
|
||||
// Vertex AI uses Application Default Credentials, not API keys.
|
||||
// Auth is configured via `gcloud auth application-default login`.
|
||||
if (provider === "google-vertex") {
|
||||
const hasCredentials = hasVertexAdcCredentials();
|
||||
const hasProject = !!(process.env.GOOGLE_CLOUD_PROJECT || process.env.GCLOUD_PROJECT);
|
||||
const hasLocation = !!process.env.GOOGLE_CLOUD_LOCATION;
|
||||
|
||||
if (hasCredentials && hasProject && hasLocation) {
|
||||
return "<authenticated>";
|
||||
}
|
||||
}
|
||||
|
||||
if (provider === "amazon-bedrock") {
|
||||
// Amazon Bedrock supports multiple credential sources:
|
||||
// 1. AWS_PROFILE - named profile from ~/.aws/credentials
|
||||
// 2. AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY - standard IAM keys
|
||||
// 3. AWS_BEARER_TOKEN_BEDROCK - Bedrock API keys (bearer token)
|
||||
// 4. AWS_CONTAINER_CREDENTIALS_RELATIVE_URI - ECS task roles
|
||||
// 5. AWS_CONTAINER_CREDENTIALS_FULL_URI - ECS task roles (full URI)
|
||||
// 6. AWS_WEB_IDENTITY_TOKEN_FILE - IRSA (IAM Roles for Service Accounts)
|
||||
if (
|
||||
process.env.AWS_PROFILE ||
|
||||
(process.env.AWS_ACCESS_KEY_ID && process.env.AWS_SECRET_ACCESS_KEY) ||
|
||||
process.env.AWS_BEARER_TOKEN_BEDROCK ||
|
||||
process.env.AWS_CONTAINER_CREDENTIALS_RELATIVE_URI ||
|
||||
process.env.AWS_CONTAINER_CREDENTIALS_FULL_URI ||
|
||||
process.env.AWS_WEB_IDENTITY_TOKEN_FILE
|
||||
) {
|
||||
return "<authenticated>";
|
||||
}
|
||||
}
|
||||
|
||||
const envMap: Record<string, string> = {
|
||||
openai: "OPENAI_API_KEY",
|
||||
"azure-openai-responses": "AZURE_OPENAI_API_KEY",
|
||||
google: "GEMINI_API_KEY",
|
||||
groq: "GROQ_API_KEY",
|
||||
cerebras: "CEREBRAS_API_KEY",
|
||||
xai: "XAI_API_KEY",
|
||||
openrouter: "OPENROUTER_API_KEY",
|
||||
"vercel-ai-gateway": "AI_GATEWAY_API_KEY",
|
||||
zai: "ZAI_API_KEY",
|
||||
mistral: "MISTRAL_API_KEY",
|
||||
minimax: "MINIMAX_API_KEY",
|
||||
"minimax-cn": "MINIMAX_CN_API_KEY",
|
||||
huggingface: "HF_TOKEN",
|
||||
opencode: "OPENCODE_API_KEY",
|
||||
"opencode-go": "OPENCODE_API_KEY",
|
||||
"kimi-coding": "KIMI_API_KEY",
|
||||
};
|
||||
|
||||
const envVar = envMap[provider];
|
||||
return envVar ? process.env[envVar] : undefined;
|
||||
}
|
||||
32
packages/pi-ai/src/index.ts
Normal file
32
packages/pi-ai/src/index.ts
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
export type { Static, TSchema } from "@sinclair/typebox";
|
||||
export { Type } from "@sinclair/typebox";
|
||||
|
||||
export * from "./api-registry.js";
|
||||
export * from "./env-api-keys.js";
|
||||
export * from "./models.js";
|
||||
export * from "./providers/anthropic.js";
|
||||
export * from "./providers/azure-openai-responses.js";
|
||||
export * from "./providers/google.js";
|
||||
export * from "./providers/google-gemini-cli.js";
|
||||
export * from "./providers/google-vertex.js";
|
||||
export * from "./providers/mistral.js";
|
||||
export * from "./providers/openai-completions.js";
|
||||
export * from "./providers/openai-responses.js";
|
||||
export * from "./providers/register-builtins.js";
|
||||
export * from "./stream.js";
|
||||
export * from "./types.js";
|
||||
export * from "./utils/event-stream.js";
|
||||
export * from "./utils/json-parse.js";
|
||||
export type {
|
||||
OAuthAuthInfo,
|
||||
OAuthCredentials,
|
||||
OAuthLoginCallbacks,
|
||||
OAuthPrompt,
|
||||
OAuthProvider,
|
||||
OAuthProviderId,
|
||||
OAuthProviderInfo,
|
||||
OAuthProviderInterface,
|
||||
} from "./utils/oauth/types.js";
|
||||
export * from "./utils/overflow.js";
|
||||
export * from "./utils/typebox-helpers.js";
|
||||
export * from "./utils/validation.js";
|
||||
13370
packages/pi-ai/src/models.generated.ts
Normal file
13370
packages/pi-ai/src/models.generated.ts
Normal file
File diff suppressed because it is too large
Load diff
77
packages/pi-ai/src/models.ts
Normal file
77
packages/pi-ai/src/models.ts
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
import { MODELS } from "./models.generated.js";
|
||||
import type { Api, KnownProvider, Model, Usage } from "./types.js";
|
||||
|
||||
const modelRegistry: Map<string, Map<string, Model<Api>>> = new Map();
|
||||
|
||||
// Initialize registry from MODELS on module load
|
||||
for (const [provider, models] of Object.entries(MODELS)) {
|
||||
const providerModels = new Map<string, Model<Api>>();
|
||||
for (const [id, model] of Object.entries(models)) {
|
||||
providerModels.set(id, model as Model<Api>);
|
||||
}
|
||||
modelRegistry.set(provider, providerModels);
|
||||
}
|
||||
|
||||
type ModelApi<
|
||||
TProvider extends KnownProvider,
|
||||
TModelId extends keyof (typeof MODELS)[TProvider],
|
||||
> = (typeof MODELS)[TProvider][TModelId] extends { api: infer TApi } ? (TApi extends Api ? TApi : never) : never;
|
||||
|
||||
export function getModel<TProvider extends KnownProvider, TModelId extends keyof (typeof MODELS)[TProvider]>(
|
||||
provider: TProvider,
|
||||
modelId: TModelId,
|
||||
): Model<ModelApi<TProvider, TModelId>> {
|
||||
const providerModels = modelRegistry.get(provider);
|
||||
return providerModels?.get(modelId as string) as Model<ModelApi<TProvider, TModelId>>;
|
||||
}
|
||||
|
||||
export function getProviders(): KnownProvider[] {
|
||||
return Array.from(modelRegistry.keys()) as KnownProvider[];
|
||||
}
|
||||
|
||||
export function getModels<TProvider extends KnownProvider>(
|
||||
provider: TProvider,
|
||||
): Model<ModelApi<TProvider, keyof (typeof MODELS)[TProvider]>>[] {
|
||||
const models = modelRegistry.get(provider);
|
||||
return models ? (Array.from(models.values()) as Model<ModelApi<TProvider, keyof (typeof MODELS)[TProvider]>>[]) : [];
|
||||
}
|
||||
|
||||
export function calculateCost<TApi extends Api>(model: Model<TApi>, usage: Usage): Usage["cost"] {
|
||||
usage.cost.input = (model.cost.input / 1000000) * usage.input;
|
||||
usage.cost.output = (model.cost.output / 1000000) * usage.output;
|
||||
usage.cost.cacheRead = (model.cost.cacheRead / 1000000) * usage.cacheRead;
|
||||
usage.cost.cacheWrite = (model.cost.cacheWrite / 1000000) * usage.cacheWrite;
|
||||
usage.cost.total = usage.cost.input + usage.cost.output + usage.cost.cacheRead + usage.cost.cacheWrite;
|
||||
return usage.cost;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a model supports xhigh thinking level.
|
||||
*
|
||||
* Supported today:
|
||||
* - GPT-5.2 / GPT-5.3 / GPT-5.4 model families
|
||||
* - Anthropic Messages API Opus 4.6 models (xhigh maps to adaptive effort "max")
|
||||
*/
|
||||
export function supportsXhigh<TApi extends Api>(model: Model<TApi>): boolean {
|
||||
if (model.id.includes("gpt-5.2") || model.id.includes("gpt-5.3") || model.id.includes("gpt-5.4")) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (model.api === "anthropic-messages") {
|
||||
return model.id.includes("opus-4-6") || model.id.includes("opus-4.6");
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if two models are equal by comparing both their id and provider.
|
||||
* Returns false if either model is null or undefined.
|
||||
*/
|
||||
export function modelsAreEqual<TApi extends Api>(
|
||||
a: Model<TApi> | null | undefined,
|
||||
b: Model<TApi> | null | undefined,
|
||||
): boolean {
|
||||
if (!a || !b) return false;
|
||||
return a.id === b.id && a.provider === b.provider;
|
||||
}
|
||||
1
packages/pi-ai/src/oauth.ts
Normal file
1
packages/pi-ai/src/oauth.ts
Normal file
|
|
@ -0,0 +1 @@
|
|||
export * from "./utils/oauth/index.js";
|
||||
751
packages/pi-ai/src/providers/amazon-bedrock.ts
Normal file
751
packages/pi-ai/src/providers/amazon-bedrock.ts
Normal file
|
|
@ -0,0 +1,751 @@
|
|||
import {
|
||||
BedrockRuntimeClient,
|
||||
type BedrockRuntimeClientConfig,
|
||||
StopReason as BedrockStopReason,
|
||||
type Tool as BedrockTool,
|
||||
CachePointType,
|
||||
CacheTTL,
|
||||
type ContentBlock,
|
||||
type ContentBlockDeltaEvent,
|
||||
type ContentBlockStartEvent,
|
||||
type ContentBlockStopEvent,
|
||||
ConversationRole,
|
||||
ConverseStreamCommand,
|
||||
type ConverseStreamMetadataEvent,
|
||||
ImageFormat,
|
||||
type Message,
|
||||
type SystemContentBlock,
|
||||
type ToolChoice,
|
||||
type ToolConfiguration,
|
||||
ToolResultStatus,
|
||||
} from "@aws-sdk/client-bedrock-runtime";
|
||||
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
CacheRetention,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StopReason,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingBudgets,
|
||||
ThinkingContent,
|
||||
ThinkingLevel,
|
||||
Tool,
|
||||
ToolCall,
|
||||
ToolResultMessage,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import { adjustMaxTokensForThinking, buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
import { transformMessages } from "./transform-messages.js";
|
||||
|
||||
export interface BedrockOptions extends StreamOptions {
|
||||
region?: string;
|
||||
profile?: string;
|
||||
toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
|
||||
/* See https://docs.aws.amazon.com/bedrock/latest/userguide/inference-reasoning.html for supported models. */
|
||||
reasoning?: ThinkingLevel;
|
||||
/* Custom token budgets per thinking level. Overrides default budgets. */
|
||||
thinkingBudgets?: ThinkingBudgets;
|
||||
/* Only supported by Claude 4.x models, see https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html#claude-messages-extended-thinking-tool-use-interleaved */
|
||||
interleavedThinking?: boolean;
|
||||
}
|
||||
|
||||
type Block = (TextContent | ThinkingContent | ToolCall) & { index?: number; partialJson?: string };
|
||||
|
||||
export const streamBedrock: StreamFunction<"bedrock-converse-stream", BedrockOptions> = (
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
context: Context,
|
||||
options: BedrockOptions = {},
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: "bedrock-converse-stream" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
const blocks = output.content as Block[];
|
||||
|
||||
const config: BedrockRuntimeClientConfig = {
|
||||
profile: options.profile,
|
||||
};
|
||||
|
||||
// in Node.js/Bun environment only
|
||||
if (typeof process !== "undefined" && (process.versions?.node || process.versions?.bun)) {
|
||||
// Region resolution: explicit option > env vars > SDK default chain.
|
||||
// When AWS_PROFILE is set, we leave region undefined so the SDK can
|
||||
// resovle it from aws profile configs. Otherwise fall back to us-east-1.
|
||||
const explicitRegion = options.region || process.env.AWS_REGION || process.env.AWS_DEFAULT_REGION;
|
||||
if (explicitRegion) {
|
||||
config.region = explicitRegion;
|
||||
} else if (!process.env.AWS_PROFILE) {
|
||||
config.region = "us-east-1";
|
||||
}
|
||||
|
||||
// Support proxies that don't need authentication
|
||||
if (process.env.AWS_BEDROCK_SKIP_AUTH === "1") {
|
||||
config.credentials = {
|
||||
accessKeyId: "dummy-access-key",
|
||||
secretAccessKey: "dummy-secret-key",
|
||||
};
|
||||
}
|
||||
|
||||
if (
|
||||
process.env.HTTP_PROXY ||
|
||||
process.env.HTTPS_PROXY ||
|
||||
process.env.NO_PROXY ||
|
||||
process.env.http_proxy ||
|
||||
process.env.https_proxy ||
|
||||
process.env.no_proxy
|
||||
) {
|
||||
const nodeHttpHandler = await import("@smithy/node-http-handler");
|
||||
const proxyAgent = await import("proxy-agent");
|
||||
|
||||
const agent = new proxyAgent.ProxyAgent();
|
||||
|
||||
// Bedrock runtime uses NodeHttp2Handler by default since v3.798.0, which is based
|
||||
// on `http2` module and has no support for http agent.
|
||||
// Use NodeHttpHandler to support http agent.
|
||||
config.requestHandler = new nodeHttpHandler.NodeHttpHandler({
|
||||
httpAgent: agent,
|
||||
httpsAgent: agent,
|
||||
});
|
||||
} else if (process.env.AWS_BEDROCK_FORCE_HTTP1 === "1") {
|
||||
// Some custom endpoints require HTTP/1.1 instead of HTTP/2
|
||||
const nodeHttpHandler = await import("@smithy/node-http-handler");
|
||||
config.requestHandler = new nodeHttpHandler.NodeHttpHandler();
|
||||
}
|
||||
} else {
|
||||
// Non-Node environment (browser): fall back to us-east-1 since
|
||||
// there's no config file resolution available.
|
||||
config.region = options.region || "us-east-1";
|
||||
}
|
||||
|
||||
try {
|
||||
const client = new BedrockRuntimeClient(config);
|
||||
|
||||
const cacheRetention = resolveCacheRetention(options.cacheRetention);
|
||||
let commandInput = {
|
||||
modelId: model.id,
|
||||
messages: convertMessages(context, model, cacheRetention),
|
||||
system: buildSystemPrompt(context.systemPrompt, model, cacheRetention),
|
||||
inferenceConfig: { maxTokens: options.maxTokens, temperature: options.temperature },
|
||||
toolConfig: convertToolConfig(context.tools, options.toolChoice),
|
||||
additionalModelRequestFields: buildAdditionalModelRequestFields(model, options),
|
||||
};
|
||||
const nextCommandInput = await options?.onPayload?.(commandInput, model);
|
||||
if (nextCommandInput !== undefined) {
|
||||
commandInput = nextCommandInput as typeof commandInput;
|
||||
}
|
||||
const command = new ConverseStreamCommand(commandInput);
|
||||
|
||||
const response = await client.send(command, { abortSignal: options.signal });
|
||||
|
||||
for await (const item of response.stream!) {
|
||||
if (item.messageStart) {
|
||||
if (item.messageStart.role !== ConversationRole.ASSISTANT) {
|
||||
throw new Error("Unexpected assistant message start but got user message start instead");
|
||||
}
|
||||
stream.push({ type: "start", partial: output });
|
||||
} else if (item.contentBlockStart) {
|
||||
handleContentBlockStart(item.contentBlockStart, blocks, output, stream);
|
||||
} else if (item.contentBlockDelta) {
|
||||
handleContentBlockDelta(item.contentBlockDelta, blocks, output, stream);
|
||||
} else if (item.contentBlockStop) {
|
||||
handleContentBlockStop(item.contentBlockStop, blocks, output, stream);
|
||||
} else if (item.messageStop) {
|
||||
output.stopReason = mapStopReason(item.messageStop.stopReason);
|
||||
} else if (item.metadata) {
|
||||
handleMetadata(item.metadata, model, output);
|
||||
} else if (item.internalServerException) {
|
||||
throw new Error(`Internal server error: ${item.internalServerException.message}`);
|
||||
} else if (item.modelStreamErrorException) {
|
||||
throw new Error(`Model stream error: ${item.modelStreamErrorException.message}`);
|
||||
} else if (item.validationException) {
|
||||
throw new Error(`Validation error: ${item.validationException.message}`);
|
||||
} else if (item.throttlingException) {
|
||||
throw new Error(`Throttling error: ${item.throttlingException.message}`);
|
||||
} else if (item.serviceUnavailableException) {
|
||||
throw new Error(`Service unavailable: ${item.serviceUnavailableException.message}`);
|
||||
}
|
||||
}
|
||||
|
||||
if (options.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (output.stopReason === "error" || output.stopReason === "aborted") {
|
||||
throw new Error("An unknown error occurred");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
for (const block of output.content) {
|
||||
delete (block as Block).index;
|
||||
delete (block as Block).partialJson;
|
||||
}
|
||||
output.stopReason = options.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
export const streamSimpleBedrock: StreamFunction<"bedrock-converse-stream", SimpleStreamOptions> = (
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const base = buildBaseOptions(model, options, undefined);
|
||||
if (!options?.reasoning) {
|
||||
return streamBedrock(model, context, { ...base, reasoning: undefined } satisfies BedrockOptions);
|
||||
}
|
||||
|
||||
if (model.id.includes("anthropic.claude") || model.id.includes("anthropic/claude")) {
|
||||
if (supportsAdaptiveThinking(model.id)) {
|
||||
return streamBedrock(model, context, {
|
||||
...base,
|
||||
reasoning: options.reasoning,
|
||||
thinkingBudgets: options.thinkingBudgets,
|
||||
} satisfies BedrockOptions);
|
||||
}
|
||||
|
||||
const adjusted = adjustMaxTokensForThinking(
|
||||
base.maxTokens || 0,
|
||||
model.maxTokens,
|
||||
options.reasoning,
|
||||
options.thinkingBudgets,
|
||||
);
|
||||
|
||||
return streamBedrock(model, context, {
|
||||
...base,
|
||||
maxTokens: adjusted.maxTokens,
|
||||
reasoning: options.reasoning,
|
||||
thinkingBudgets: {
|
||||
...(options.thinkingBudgets || {}),
|
||||
[clampReasoning(options.reasoning)!]: adjusted.thinkingBudget,
|
||||
},
|
||||
} satisfies BedrockOptions);
|
||||
}
|
||||
|
||||
return streamBedrock(model, context, {
|
||||
...base,
|
||||
reasoning: options.reasoning,
|
||||
thinkingBudgets: options.thinkingBudgets,
|
||||
} satisfies BedrockOptions);
|
||||
};
|
||||
|
||||
function handleContentBlockStart(
|
||||
event: ContentBlockStartEvent,
|
||||
blocks: Block[],
|
||||
output: AssistantMessage,
|
||||
stream: AssistantMessageEventStream,
|
||||
): void {
|
||||
const index = event.contentBlockIndex!;
|
||||
const start = event.start;
|
||||
|
||||
if (start?.toolUse) {
|
||||
const block: Block = {
|
||||
type: "toolCall",
|
||||
id: start.toolUse.toolUseId || "",
|
||||
name: start.toolUse.name || "",
|
||||
arguments: {},
|
||||
partialJson: "",
|
||||
index,
|
||||
};
|
||||
output.content.push(block);
|
||||
stream.push({ type: "toolcall_start", contentIndex: blocks.length - 1, partial: output });
|
||||
}
|
||||
}
|
||||
|
||||
function handleContentBlockDelta(
|
||||
event: ContentBlockDeltaEvent,
|
||||
blocks: Block[],
|
||||
output: AssistantMessage,
|
||||
stream: AssistantMessageEventStream,
|
||||
): void {
|
||||
const contentBlockIndex = event.contentBlockIndex!;
|
||||
const delta = event.delta;
|
||||
let index = blocks.findIndex((b) => b.index === contentBlockIndex);
|
||||
let block = blocks[index];
|
||||
|
||||
if (delta?.text !== undefined) {
|
||||
// If no text block exists yet, create one, as `handleContentBlockStart` is not sent for text blocks
|
||||
if (!block) {
|
||||
const newBlock: Block = { type: "text", text: "", index: contentBlockIndex };
|
||||
output.content.push(newBlock);
|
||||
index = blocks.length - 1;
|
||||
block = blocks[index];
|
||||
stream.push({ type: "text_start", contentIndex: index, partial: output });
|
||||
}
|
||||
if (block.type === "text") {
|
||||
block.text += delta.text;
|
||||
stream.push({ type: "text_delta", contentIndex: index, delta: delta.text, partial: output });
|
||||
}
|
||||
} else if (delta?.toolUse && block?.type === "toolCall") {
|
||||
block.partialJson = (block.partialJson || "") + (delta.toolUse.input || "");
|
||||
block.arguments = parseStreamingJson(block.partialJson);
|
||||
stream.push({ type: "toolcall_delta", contentIndex: index, delta: delta.toolUse.input || "", partial: output });
|
||||
} else if (delta?.reasoningContent) {
|
||||
let thinkingBlock = block;
|
||||
let thinkingIndex = index;
|
||||
|
||||
if (!thinkingBlock) {
|
||||
const newBlock: Block = { type: "thinking", thinking: "", thinkingSignature: "", index: contentBlockIndex };
|
||||
output.content.push(newBlock);
|
||||
thinkingIndex = blocks.length - 1;
|
||||
thinkingBlock = blocks[thinkingIndex];
|
||||
stream.push({ type: "thinking_start", contentIndex: thinkingIndex, partial: output });
|
||||
}
|
||||
|
||||
if (thinkingBlock?.type === "thinking") {
|
||||
if (delta.reasoningContent.text) {
|
||||
thinkingBlock.thinking += delta.reasoningContent.text;
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: thinkingIndex,
|
||||
delta: delta.reasoningContent.text,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
if (delta.reasoningContent.signature) {
|
||||
thinkingBlock.thinkingSignature =
|
||||
(thinkingBlock.thinkingSignature || "") + delta.reasoningContent.signature;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function handleMetadata(
|
||||
event: ConverseStreamMetadataEvent,
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
output: AssistantMessage,
|
||||
): void {
|
||||
if (event.usage) {
|
||||
output.usage.input = event.usage.inputTokens || 0;
|
||||
output.usage.output = event.usage.outputTokens || 0;
|
||||
output.usage.cacheRead = event.usage.cacheReadInputTokens || 0;
|
||||
output.usage.cacheWrite = event.usage.cacheWriteInputTokens || 0;
|
||||
output.usage.totalTokens = event.usage.totalTokens || output.usage.input + output.usage.output;
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
}
|
||||
|
||||
function handleContentBlockStop(
|
||||
event: ContentBlockStopEvent,
|
||||
blocks: Block[],
|
||||
output: AssistantMessage,
|
||||
stream: AssistantMessageEventStream,
|
||||
): void {
|
||||
const index = blocks.findIndex((b) => b.index === event.contentBlockIndex);
|
||||
const block = blocks[index];
|
||||
if (!block) return;
|
||||
delete (block as Block).index;
|
||||
|
||||
switch (block.type) {
|
||||
case "text":
|
||||
stream.push({ type: "text_end", contentIndex: index, content: block.text, partial: output });
|
||||
break;
|
||||
case "thinking":
|
||||
stream.push({ type: "thinking_end", contentIndex: index, content: block.thinking, partial: output });
|
||||
break;
|
||||
case "toolCall":
|
||||
block.arguments = parseStreamingJson(block.partialJson);
|
||||
delete (block as Block).partialJson;
|
||||
stream.push({ type: "toolcall_end", contentIndex: index, toolCall: block, partial: output });
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the model supports adaptive thinking (Opus 4.6 and Sonnet 4.6).
|
||||
*/
|
||||
function supportsAdaptiveThinking(modelId: string): boolean {
|
||||
return (
|
||||
modelId.includes("opus-4-6") ||
|
||||
modelId.includes("opus-4.6") ||
|
||||
modelId.includes("sonnet-4-6") ||
|
||||
modelId.includes("sonnet-4.6")
|
||||
);
|
||||
}
|
||||
|
||||
function mapThinkingLevelToEffort(
|
||||
level: SimpleStreamOptions["reasoning"],
|
||||
modelId: string,
|
||||
): "low" | "medium" | "high" | "max" {
|
||||
switch (level) {
|
||||
case "minimal":
|
||||
case "low":
|
||||
return "low";
|
||||
case "medium":
|
||||
return "medium";
|
||||
case "high":
|
||||
return "high";
|
||||
case "xhigh":
|
||||
return modelId.includes("opus-4-6") || modelId.includes("opus-4.6") ? "max" : "high";
|
||||
default:
|
||||
return "high";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve cache retention preference.
|
||||
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
|
||||
*/
|
||||
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
|
||||
if (cacheRetention) {
|
||||
return cacheRetention;
|
||||
}
|
||||
if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") {
|
||||
return "long";
|
||||
}
|
||||
return "short";
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the model supports prompt caching.
|
||||
* Supported: Claude 3.5 Haiku, Claude 3.7 Sonnet, Claude 4.x models
|
||||
*/
|
||||
function supportsPromptCaching(model: Model<"bedrock-converse-stream">): boolean {
|
||||
if (model.cost.cacheRead || model.cost.cacheWrite) {
|
||||
return true;
|
||||
}
|
||||
|
||||
const id = model.id.toLowerCase();
|
||||
// Claude 4.x models (opus-4, sonnet-4, haiku-4)
|
||||
if (id.includes("claude") && (id.includes("-4-") || id.includes("-4."))) return true;
|
||||
// Claude 3.7 Sonnet
|
||||
if (id.includes("claude-3-7-sonnet")) return true;
|
||||
// Claude 3.5 Haiku
|
||||
if (id.includes("claude-3-5-haiku")) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the model supports thinking signatures in reasoningContent.
|
||||
* Only Anthropic Claude models support the signature field.
|
||||
* Other models (OpenAI, Qwen, Minimax, Moonshot, etc.) reject it with:
|
||||
* "This model doesn't support the reasoningContent.reasoningText.signature field"
|
||||
*/
|
||||
function supportsThinkingSignature(model: Model<"bedrock-converse-stream">): boolean {
|
||||
const id = model.id.toLowerCase();
|
||||
return id.includes("anthropic.claude") || id.includes("anthropic/claude");
|
||||
}
|
||||
|
||||
function buildSystemPrompt(
|
||||
systemPrompt: string | undefined,
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
cacheRetention: CacheRetention,
|
||||
): SystemContentBlock[] | undefined {
|
||||
if (!systemPrompt) return undefined;
|
||||
|
||||
const blocks: SystemContentBlock[] = [{ text: sanitizeSurrogates(systemPrompt) }];
|
||||
|
||||
// Add cache point for supported Claude models when caching is enabled
|
||||
if (cacheRetention !== "none" && supportsPromptCaching(model)) {
|
||||
blocks.push({
|
||||
cachePoint: { type: CachePointType.DEFAULT, ...(cacheRetention === "long" ? { ttl: CacheTTL.ONE_HOUR } : {}) },
|
||||
});
|
||||
}
|
||||
|
||||
return blocks;
|
||||
}
|
||||
|
||||
function normalizeToolCallId(id: string): string {
|
||||
const sanitized = id.replace(/[^a-zA-Z0-9_-]/g, "_");
|
||||
return sanitized.length > 64 ? sanitized.slice(0, 64) : sanitized;
|
||||
}
|
||||
|
||||
function convertMessages(
|
||||
context: Context,
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
cacheRetention: CacheRetention,
|
||||
): Message[] {
|
||||
const result: Message[] = [];
|
||||
const transformedMessages = transformMessages(context.messages, model, normalizeToolCallId);
|
||||
|
||||
for (let i = 0; i < transformedMessages.length; i++) {
|
||||
const m = transformedMessages[i];
|
||||
|
||||
switch (m.role) {
|
||||
case "user":
|
||||
result.push({
|
||||
role: ConversationRole.USER,
|
||||
content:
|
||||
typeof m.content === "string"
|
||||
? [{ text: sanitizeSurrogates(m.content) }]
|
||||
: m.content.map((c) => {
|
||||
switch (c.type) {
|
||||
case "text":
|
||||
return { text: sanitizeSurrogates(c.text) };
|
||||
case "image":
|
||||
return { image: createImageBlock(c.mimeType, c.data) };
|
||||
default:
|
||||
throw new Error("Unknown user content type");
|
||||
}
|
||||
}),
|
||||
});
|
||||
break;
|
||||
case "assistant": {
|
||||
// Skip assistant messages with empty content (e.g., from aborted requests)
|
||||
// Bedrock rejects messages with empty content arrays
|
||||
if (m.content.length === 0) {
|
||||
continue;
|
||||
}
|
||||
const contentBlocks: ContentBlock[] = [];
|
||||
for (const c of m.content) {
|
||||
switch (c.type) {
|
||||
case "text":
|
||||
// Skip empty text blocks
|
||||
if (c.text.trim().length === 0) continue;
|
||||
contentBlocks.push({ text: sanitizeSurrogates(c.text) });
|
||||
break;
|
||||
case "toolCall":
|
||||
contentBlocks.push({
|
||||
toolUse: { toolUseId: c.id, name: c.name, input: c.arguments },
|
||||
});
|
||||
break;
|
||||
case "thinking":
|
||||
// Skip empty thinking blocks
|
||||
if (c.thinking.trim().length === 0) continue;
|
||||
// Only Anthropic models support the signature field in reasoningText.
|
||||
// For other models, we omit the signature to avoid errors like:
|
||||
// "This model doesn't support the reasoningContent.reasoningText.signature field"
|
||||
if (supportsThinkingSignature(model)) {
|
||||
contentBlocks.push({
|
||||
reasoningContent: {
|
||||
reasoningText: { text: sanitizeSurrogates(c.thinking), signature: c.thinkingSignature },
|
||||
},
|
||||
});
|
||||
} else {
|
||||
contentBlocks.push({
|
||||
reasoningContent: {
|
||||
reasoningText: { text: sanitizeSurrogates(c.thinking) },
|
||||
},
|
||||
});
|
||||
}
|
||||
break;
|
||||
default:
|
||||
throw new Error("Unknown assistant content type");
|
||||
}
|
||||
}
|
||||
// Skip if all content blocks were filtered out
|
||||
if (contentBlocks.length === 0) {
|
||||
continue;
|
||||
}
|
||||
result.push({
|
||||
role: ConversationRole.ASSISTANT,
|
||||
content: contentBlocks,
|
||||
});
|
||||
break;
|
||||
}
|
||||
case "toolResult": {
|
||||
// Collect all consecutive toolResult messages into a single user message
|
||||
// Bedrock requires all tool results to be in one message
|
||||
const toolResults: ContentBlock.ToolResultMember[] = [];
|
||||
|
||||
// Add current tool result with all content blocks combined
|
||||
toolResults.push({
|
||||
toolResult: {
|
||||
toolUseId: m.toolCallId,
|
||||
content: m.content.map((c) =>
|
||||
c.type === "image"
|
||||
? { image: createImageBlock(c.mimeType, c.data) }
|
||||
: { text: sanitizeSurrogates(c.text) },
|
||||
),
|
||||
status: m.isError ? ToolResultStatus.ERROR : ToolResultStatus.SUCCESS,
|
||||
},
|
||||
});
|
||||
|
||||
// Look ahead for consecutive toolResult messages
|
||||
let j = i + 1;
|
||||
while (j < transformedMessages.length && transformedMessages[j].role === "toolResult") {
|
||||
const nextMsg = transformedMessages[j] as ToolResultMessage;
|
||||
toolResults.push({
|
||||
toolResult: {
|
||||
toolUseId: nextMsg.toolCallId,
|
||||
content: nextMsg.content.map((c) =>
|
||||
c.type === "image"
|
||||
? { image: createImageBlock(c.mimeType, c.data) }
|
||||
: { text: sanitizeSurrogates(c.text) },
|
||||
),
|
||||
status: nextMsg.isError ? ToolResultStatus.ERROR : ToolResultStatus.SUCCESS,
|
||||
},
|
||||
});
|
||||
j++;
|
||||
}
|
||||
|
||||
// Skip the messages we've already processed
|
||||
i = j - 1;
|
||||
|
||||
result.push({
|
||||
role: ConversationRole.USER,
|
||||
content: toolResults,
|
||||
});
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw new Error("Unknown message role");
|
||||
}
|
||||
}
|
||||
|
||||
// Add cache point to the last user message for supported Claude models when caching is enabled
|
||||
if (cacheRetention !== "none" && supportsPromptCaching(model) && result.length > 0) {
|
||||
const lastMessage = result[result.length - 1];
|
||||
if (lastMessage.role === ConversationRole.USER && lastMessage.content) {
|
||||
(lastMessage.content as ContentBlock[]).push({
|
||||
cachePoint: {
|
||||
type: CachePointType.DEFAULT,
|
||||
...(cacheRetention === "long" ? { ttl: CacheTTL.ONE_HOUR } : {}),
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
function convertToolConfig(
|
||||
tools: Tool[] | undefined,
|
||||
toolChoice: BedrockOptions["toolChoice"],
|
||||
): ToolConfiguration | undefined {
|
||||
if (!tools?.length || toolChoice === "none") return undefined;
|
||||
|
||||
const bedrockTools: BedrockTool[] = tools.map((tool) => ({
|
||||
toolSpec: {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
inputSchema: { json: tool.parameters },
|
||||
},
|
||||
}));
|
||||
|
||||
let bedrockToolChoice: ToolChoice | undefined;
|
||||
switch (toolChoice) {
|
||||
case "auto":
|
||||
bedrockToolChoice = { auto: {} };
|
||||
break;
|
||||
case "any":
|
||||
bedrockToolChoice = { any: {} };
|
||||
break;
|
||||
default:
|
||||
if (toolChoice?.type === "tool") {
|
||||
bedrockToolChoice = { tool: { name: toolChoice.name } };
|
||||
}
|
||||
}
|
||||
|
||||
return { tools: bedrockTools, toolChoice: bedrockToolChoice };
|
||||
}
|
||||
|
||||
function mapStopReason(reason: string | undefined): StopReason {
|
||||
switch (reason) {
|
||||
case BedrockStopReason.END_TURN:
|
||||
case BedrockStopReason.STOP_SEQUENCE:
|
||||
return "stop";
|
||||
case BedrockStopReason.MAX_TOKENS:
|
||||
case BedrockStopReason.MODEL_CONTEXT_WINDOW_EXCEEDED:
|
||||
return "length";
|
||||
case BedrockStopReason.TOOL_USE:
|
||||
return "toolUse";
|
||||
default:
|
||||
return "error";
|
||||
}
|
||||
}
|
||||
|
||||
function buildAdditionalModelRequestFields(
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
options: BedrockOptions,
|
||||
): Record<string, any> | undefined {
|
||||
if (!options.reasoning || !model.reasoning) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (model.id.includes("anthropic.claude") || model.id.includes("anthropic/claude")) {
|
||||
const result: Record<string, any> = supportsAdaptiveThinking(model.id)
|
||||
? {
|
||||
thinking: { type: "adaptive" },
|
||||
output_config: { effort: mapThinkingLevelToEffort(options.reasoning, model.id) },
|
||||
}
|
||||
: (() => {
|
||||
const defaultBudgets: Record<ThinkingLevel, number> = {
|
||||
minimal: 1024,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 16384,
|
||||
xhigh: 16384, // Claude doesn't support xhigh, clamp to high
|
||||
};
|
||||
|
||||
// Custom budgets override defaults (xhigh not in ThinkingBudgets, use high)
|
||||
const level = options.reasoning === "xhigh" ? "high" : options.reasoning;
|
||||
const budget = options.thinkingBudgets?.[level] ?? defaultBudgets[options.reasoning];
|
||||
|
||||
return {
|
||||
thinking: {
|
||||
type: "enabled",
|
||||
budget_tokens: budget,
|
||||
},
|
||||
};
|
||||
})();
|
||||
|
||||
if (!supportsAdaptiveThinking(model.id) && (options.interleavedThinking ?? true)) {
|
||||
result.anthropic_beta = ["interleaved-thinking-2025-05-14"];
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
function createImageBlock(mimeType: string, data: string) {
|
||||
let format: ImageFormat;
|
||||
switch (mimeType) {
|
||||
case "image/jpeg":
|
||||
case "image/jpg":
|
||||
format = ImageFormat.JPEG;
|
||||
break;
|
||||
case "image/png":
|
||||
format = ImageFormat.PNG;
|
||||
break;
|
||||
case "image/gif":
|
||||
format = ImageFormat.GIF;
|
||||
break;
|
||||
case "image/webp":
|
||||
format = ImageFormat.WEBP;
|
||||
break;
|
||||
default:
|
||||
throw new Error(`Unknown image type: ${mimeType}`);
|
||||
}
|
||||
|
||||
const binaryString = atob(data);
|
||||
const bytes = new Uint8Array(binaryString.length);
|
||||
for (let i = 0; i < binaryString.length; i++) {
|
||||
bytes[i] = binaryString.charCodeAt(i);
|
||||
}
|
||||
|
||||
return { source: { bytes }, format };
|
||||
}
|
||||
883
packages/pi-ai/src/providers/anthropic.ts
Normal file
883
packages/pi-ai/src/providers/anthropic.ts
Normal file
|
|
@ -0,0 +1,883 @@
|
|||
import Anthropic from "@anthropic-ai/sdk";
|
||||
import type {
|
||||
ContentBlockParam,
|
||||
MessageCreateParamsStreaming,
|
||||
MessageParam,
|
||||
} from "@anthropic-ai/sdk/resources/messages.js";
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
CacheRetention,
|
||||
Context,
|
||||
ImageContent,
|
||||
Message,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StopReason,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
Tool,
|
||||
ToolCall,
|
||||
ToolResultMessage,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
|
||||
import { buildCopilotDynamicHeaders, hasCopilotVisionInput } from "./github-copilot-headers.js";
|
||||
import { adjustMaxTokensForThinking, buildBaseOptions } from "./simple-options.js";
|
||||
import { transformMessages } from "./transform-messages.js";
|
||||
|
||||
/**
|
||||
* Resolve cache retention preference.
|
||||
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
|
||||
*/
|
||||
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
|
||||
if (cacheRetention) {
|
||||
return cacheRetention;
|
||||
}
|
||||
if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") {
|
||||
return "long";
|
||||
}
|
||||
return "short";
|
||||
}
|
||||
|
||||
function getCacheControl(
|
||||
baseUrl: string,
|
||||
cacheRetention?: CacheRetention,
|
||||
): { retention: CacheRetention; cacheControl?: { type: "ephemeral"; ttl?: "1h" } } {
|
||||
const retention = resolveCacheRetention(cacheRetention);
|
||||
if (retention === "none") {
|
||||
return { retention };
|
||||
}
|
||||
const ttl = retention === "long" && baseUrl.includes("api.anthropic.com") ? "1h" : undefined;
|
||||
return {
|
||||
retention,
|
||||
cacheControl: { type: "ephemeral", ...(ttl && { ttl }) },
|
||||
};
|
||||
}
|
||||
|
||||
// Stealth mode: Mimic Claude Code's tool naming exactly
|
||||
const claudeCodeVersion = "2.1.62";
|
||||
|
||||
// Claude Code 2.x tool names (canonical casing)
|
||||
// Source: https://cchistory.mariozechner.at/data/prompts-2.1.11.md
|
||||
// To update: https://github.com/badlogic/cchistory
|
||||
const claudeCodeTools = [
|
||||
"Read",
|
||||
"Write",
|
||||
"Edit",
|
||||
"Bash",
|
||||
"Grep",
|
||||
"Glob",
|
||||
"AskUserQuestion",
|
||||
"EnterPlanMode",
|
||||
"ExitPlanMode",
|
||||
"KillShell",
|
||||
"NotebookEdit",
|
||||
"Skill",
|
||||
"Task",
|
||||
"TaskOutput",
|
||||
"TodoWrite",
|
||||
"WebFetch",
|
||||
"WebSearch",
|
||||
];
|
||||
|
||||
const ccToolLookup = new Map(claudeCodeTools.map((t) => [t.toLowerCase(), t]));
|
||||
|
||||
// Convert tool name to CC canonical casing if it matches (case-insensitive)
|
||||
const toClaudeCodeName = (name: string) => ccToolLookup.get(name.toLowerCase()) ?? name;
|
||||
const fromClaudeCodeName = (name: string, tools?: Tool[]) => {
|
||||
if (tools && tools.length > 0) {
|
||||
const lowerName = name.toLowerCase();
|
||||
const matchedTool = tools.find((tool) => tool.name.toLowerCase() === lowerName);
|
||||
if (matchedTool) return matchedTool.name;
|
||||
}
|
||||
return name;
|
||||
};
|
||||
|
||||
/**
|
||||
* Convert content blocks to Anthropic API format
|
||||
*/
|
||||
function convertContentBlocks(content: (TextContent | ImageContent)[]):
|
||||
| string
|
||||
| Array<
|
||||
| { type: "text"; text: string }
|
||||
| {
|
||||
type: "image";
|
||||
source: {
|
||||
type: "base64";
|
||||
media_type: "image/jpeg" | "image/png" | "image/gif" | "image/webp";
|
||||
data: string;
|
||||
};
|
||||
}
|
||||
> {
|
||||
// If only text blocks, return as concatenated string for simplicity
|
||||
const hasImages = content.some((c) => c.type === "image");
|
||||
if (!hasImages) {
|
||||
return sanitizeSurrogates(content.map((c) => (c as TextContent).text).join("\n"));
|
||||
}
|
||||
|
||||
// If we have images, convert to content block array
|
||||
const blocks = content.map((block) => {
|
||||
if (block.type === "text") {
|
||||
return {
|
||||
type: "text" as const,
|
||||
text: sanitizeSurrogates(block.text),
|
||||
};
|
||||
}
|
||||
return {
|
||||
type: "image" as const,
|
||||
source: {
|
||||
type: "base64" as const,
|
||||
media_type: block.mimeType as "image/jpeg" | "image/png" | "image/gif" | "image/webp",
|
||||
data: block.data,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
// If only images (no text), add placeholder text block
|
||||
const hasText = blocks.some((b) => b.type === "text");
|
||||
if (!hasText) {
|
||||
blocks.unshift({
|
||||
type: "text" as const,
|
||||
text: "(see attached image)",
|
||||
});
|
||||
}
|
||||
|
||||
return blocks;
|
||||
}
|
||||
|
||||
export type AnthropicEffort = "low" | "medium" | "high" | "max";
|
||||
|
||||
export interface AnthropicOptions extends StreamOptions {
|
||||
/**
|
||||
* Enable extended thinking.
|
||||
* For Opus 4.6 and Sonnet 4.6: uses adaptive thinking (model decides when/how much to think).
|
||||
* For older models: uses budget-based thinking with thinkingBudgetTokens.
|
||||
*/
|
||||
thinkingEnabled?: boolean;
|
||||
/**
|
||||
* Token budget for extended thinking (older models only).
|
||||
* Ignored for Opus 4.6 and Sonnet 4.6, which use adaptive thinking.
|
||||
*/
|
||||
thinkingBudgetTokens?: number;
|
||||
/**
|
||||
* Effort level for adaptive thinking (Opus 4.6 and Sonnet 4.6).
|
||||
* Controls how much thinking Claude allocates:
|
||||
* - "max": Always thinks with no constraints (Opus 4.6 only)
|
||||
* - "high": Always thinks, deep reasoning (default)
|
||||
* - "medium": Moderate thinking, may skip for simple queries
|
||||
* - "low": Minimal thinking, skips for simple tasks
|
||||
* Ignored for older models.
|
||||
*/
|
||||
effort?: AnthropicEffort;
|
||||
interleavedThinking?: boolean;
|
||||
toolChoice?: "auto" | "any" | "none" | { type: "tool"; name: string };
|
||||
}
|
||||
|
||||
function mergeHeaders(...headerSources: (Record<string, string> | undefined)[]): Record<string, string> {
|
||||
const merged: Record<string, string> = {};
|
||||
for (const headers of headerSources) {
|
||||
if (headers) {
|
||||
Object.assign(merged, headers);
|
||||
}
|
||||
}
|
||||
return merged;
|
||||
}
|
||||
|
||||
export const streamAnthropic: StreamFunction<"anthropic-messages", AnthropicOptions> = (
|
||||
model: Model<"anthropic-messages">,
|
||||
context: Context,
|
||||
options?: AnthropicOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: model.api as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
try {
|
||||
const apiKey = options?.apiKey ?? getEnvApiKey(model.provider) ?? "";
|
||||
|
||||
let copilotDynamicHeaders: Record<string, string> | undefined;
|
||||
if (model.provider === "github-copilot") {
|
||||
const hasImages = hasCopilotVisionInput(context.messages);
|
||||
copilotDynamicHeaders = buildCopilotDynamicHeaders({
|
||||
messages: context.messages,
|
||||
hasImages,
|
||||
});
|
||||
}
|
||||
|
||||
const { client, isOAuthToken } = createClient(
|
||||
model,
|
||||
apiKey,
|
||||
options?.interleavedThinking ?? true,
|
||||
options?.headers,
|
||||
copilotDynamicHeaders,
|
||||
);
|
||||
let params = buildParams(model, context, isOAuthToken, options);
|
||||
const nextParams = await options?.onPayload?.(params, model);
|
||||
if (nextParams !== undefined) {
|
||||
params = nextParams as MessageCreateParamsStreaming;
|
||||
}
|
||||
const anthropicStream = client.messages.stream({ ...params, stream: true }, { signal: options?.signal });
|
||||
stream.push({ type: "start", partial: output });
|
||||
|
||||
type Block = (ThinkingContent | TextContent | (ToolCall & { partialJson: string })) & { index: number };
|
||||
const blocks = output.content as Block[];
|
||||
|
||||
for await (const event of anthropicStream) {
|
||||
if (event.type === "message_start") {
|
||||
// Capture initial token usage from message_start event
|
||||
// This ensures we have input token counts even if the stream is aborted early
|
||||
output.usage.input = event.message.usage.input_tokens || 0;
|
||||
output.usage.output = event.message.usage.output_tokens || 0;
|
||||
output.usage.cacheRead = event.message.usage.cache_read_input_tokens || 0;
|
||||
output.usage.cacheWrite = event.message.usage.cache_creation_input_tokens || 0;
|
||||
// Anthropic doesn't provide total_tokens, compute from components
|
||||
output.usage.totalTokens =
|
||||
output.usage.input + output.usage.output + output.usage.cacheRead + output.usage.cacheWrite;
|
||||
calculateCost(model, output.usage);
|
||||
} else if (event.type === "content_block_start") {
|
||||
if (event.content_block.type === "text") {
|
||||
const block: Block = {
|
||||
type: "text",
|
||||
text: "",
|
||||
index: event.index,
|
||||
};
|
||||
output.content.push(block);
|
||||
stream.push({ type: "text_start", contentIndex: output.content.length - 1, partial: output });
|
||||
} else if (event.content_block.type === "thinking") {
|
||||
const block: Block = {
|
||||
type: "thinking",
|
||||
thinking: "",
|
||||
thinkingSignature: "",
|
||||
index: event.index,
|
||||
};
|
||||
output.content.push(block);
|
||||
stream.push({ type: "thinking_start", contentIndex: output.content.length - 1, partial: output });
|
||||
} else if (event.content_block.type === "redacted_thinking") {
|
||||
const block: Block = {
|
||||
type: "thinking",
|
||||
thinking: "[Reasoning redacted]",
|
||||
thinkingSignature: event.content_block.data,
|
||||
redacted: true,
|
||||
index: event.index,
|
||||
};
|
||||
output.content.push(block);
|
||||
stream.push({ type: "thinking_start", contentIndex: output.content.length - 1, partial: output });
|
||||
} else if (event.content_block.type === "tool_use") {
|
||||
const block: Block = {
|
||||
type: "toolCall",
|
||||
id: event.content_block.id,
|
||||
name: isOAuthToken
|
||||
? fromClaudeCodeName(event.content_block.name, context.tools)
|
||||
: event.content_block.name,
|
||||
arguments: (event.content_block.input as Record<string, any>) ?? {},
|
||||
partialJson: "",
|
||||
index: event.index,
|
||||
};
|
||||
output.content.push(block);
|
||||
stream.push({ type: "toolcall_start", contentIndex: output.content.length - 1, partial: output });
|
||||
}
|
||||
} else if (event.type === "content_block_delta") {
|
||||
if (event.delta.type === "text_delta") {
|
||||
const index = blocks.findIndex((b) => b.index === event.index);
|
||||
const block = blocks[index];
|
||||
if (block && block.type === "text") {
|
||||
block.text += event.delta.text;
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: index,
|
||||
delta: event.delta.text,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.delta.type === "thinking_delta") {
|
||||
const index = blocks.findIndex((b) => b.index === event.index);
|
||||
const block = blocks[index];
|
||||
if (block && block.type === "thinking") {
|
||||
block.thinking += event.delta.thinking;
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: index,
|
||||
delta: event.delta.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.delta.type === "input_json_delta") {
|
||||
const index = blocks.findIndex((b) => b.index === event.index);
|
||||
const block = blocks[index];
|
||||
if (block && block.type === "toolCall") {
|
||||
block.partialJson += event.delta.partial_json;
|
||||
block.arguments = parseStreamingJson(block.partialJson);
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex: index,
|
||||
delta: event.delta.partial_json,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.delta.type === "signature_delta") {
|
||||
const index = blocks.findIndex((b) => b.index === event.index);
|
||||
const block = blocks[index];
|
||||
if (block && block.type === "thinking") {
|
||||
block.thinkingSignature = block.thinkingSignature || "";
|
||||
block.thinkingSignature += event.delta.signature;
|
||||
}
|
||||
}
|
||||
} else if (event.type === "content_block_stop") {
|
||||
const index = blocks.findIndex((b) => b.index === event.index);
|
||||
const block = blocks[index];
|
||||
if (block) {
|
||||
delete (block as any).index;
|
||||
if (block.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: index,
|
||||
content: block.text,
|
||||
partial: output,
|
||||
});
|
||||
} else if (block.type === "thinking") {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: index,
|
||||
content: block.thinking,
|
||||
partial: output,
|
||||
});
|
||||
} else if (block.type === "toolCall") {
|
||||
block.arguments = parseStreamingJson(block.partialJson);
|
||||
delete (block as any).partialJson;
|
||||
stream.push({
|
||||
type: "toolcall_end",
|
||||
contentIndex: index,
|
||||
toolCall: block,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (event.type === "message_delta") {
|
||||
if (event.delta.stop_reason) {
|
||||
output.stopReason = mapStopReason(event.delta.stop_reason);
|
||||
}
|
||||
// Only update usage fields if present (not null).
|
||||
// Preserves input_tokens from message_start when proxies omit it in message_delta.
|
||||
if (event.usage.input_tokens != null) {
|
||||
output.usage.input = event.usage.input_tokens;
|
||||
}
|
||||
if (event.usage.output_tokens != null) {
|
||||
output.usage.output = event.usage.output_tokens;
|
||||
}
|
||||
if (event.usage.cache_read_input_tokens != null) {
|
||||
output.usage.cacheRead = event.usage.cache_read_input_tokens;
|
||||
}
|
||||
if (event.usage.cache_creation_input_tokens != null) {
|
||||
output.usage.cacheWrite = event.usage.cache_creation_input_tokens;
|
||||
}
|
||||
// Anthropic doesn't provide total_tokens, compute from components
|
||||
output.usage.totalTokens =
|
||||
output.usage.input + output.usage.output + output.usage.cacheRead + output.usage.cacheWrite;
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
}
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (output.stopReason === "aborted" || output.stopReason === "error") {
|
||||
throw new Error("An unknown error occurred");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
for (const block of output.content) delete (block as any).index;
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
/**
|
||||
* Check if a model supports adaptive thinking (Opus 4.6 and Sonnet 4.6)
|
||||
*/
|
||||
function supportsAdaptiveThinking(modelId: string): boolean {
|
||||
// Opus 4.6 and Sonnet 4.6 model IDs (with or without date suffix)
|
||||
return (
|
||||
modelId.includes("opus-4-6") ||
|
||||
modelId.includes("opus-4.6") ||
|
||||
modelId.includes("sonnet-4-6") ||
|
||||
modelId.includes("sonnet-4.6")
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Map ThinkingLevel to Anthropic effort levels for adaptive thinking.
|
||||
* Note: effort "max" is only valid on Opus 4.6.
|
||||
*/
|
||||
function mapThinkingLevelToEffort(level: SimpleStreamOptions["reasoning"], modelId: string): AnthropicEffort {
|
||||
switch (level) {
|
||||
case "minimal":
|
||||
return "low";
|
||||
case "low":
|
||||
return "low";
|
||||
case "medium":
|
||||
return "medium";
|
||||
case "high":
|
||||
return "high";
|
||||
case "xhigh":
|
||||
return modelId.includes("opus-4-6") || modelId.includes("opus-4.6") ? "max" : "high";
|
||||
default:
|
||||
return "high";
|
||||
}
|
||||
}
|
||||
|
||||
export const streamSimpleAnthropic: StreamFunction<"anthropic-messages", SimpleStreamOptions> = (
|
||||
model: Model<"anthropic-messages">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
if (!options?.reasoning) {
|
||||
return streamAnthropic(model, context, { ...base, thinkingEnabled: false } satisfies AnthropicOptions);
|
||||
}
|
||||
|
||||
// For Opus 4.6 and Sonnet 4.6: use adaptive thinking with effort level
|
||||
// For older models: use budget-based thinking
|
||||
if (supportsAdaptiveThinking(model.id)) {
|
||||
const effort = mapThinkingLevelToEffort(options.reasoning, model.id);
|
||||
return streamAnthropic(model, context, {
|
||||
...base,
|
||||
thinkingEnabled: true,
|
||||
effort,
|
||||
} satisfies AnthropicOptions);
|
||||
}
|
||||
|
||||
const adjusted = adjustMaxTokensForThinking(
|
||||
base.maxTokens || 0,
|
||||
model.maxTokens,
|
||||
options.reasoning,
|
||||
options.thinkingBudgets,
|
||||
);
|
||||
|
||||
return streamAnthropic(model, context, {
|
||||
...base,
|
||||
maxTokens: adjusted.maxTokens,
|
||||
thinkingEnabled: true,
|
||||
thinkingBudgetTokens: adjusted.thinkingBudget,
|
||||
} satisfies AnthropicOptions);
|
||||
};
|
||||
|
||||
function isOAuthToken(apiKey: string): boolean {
|
||||
return apiKey.includes("sk-ant-oat");
|
||||
}
|
||||
|
||||
function createClient(
|
||||
model: Model<"anthropic-messages">,
|
||||
apiKey: string,
|
||||
interleavedThinking: boolean,
|
||||
optionsHeaders?: Record<string, string>,
|
||||
dynamicHeaders?: Record<string, string>,
|
||||
): { client: Anthropic; isOAuthToken: boolean } {
|
||||
// Adaptive thinking models (Opus 4.6, Sonnet 4.6) have interleaved thinking built-in.
|
||||
// The beta header is deprecated on Opus 4.6 and redundant on Sonnet 4.6, so skip it.
|
||||
const needsInterleavedBeta = interleavedThinking && !supportsAdaptiveThinking(model.id);
|
||||
|
||||
// Copilot: Bearer auth, selective betas (no fine-grained-tool-streaming)
|
||||
if (model.provider === "github-copilot") {
|
||||
const betaFeatures: string[] = [];
|
||||
if (needsInterleavedBeta) {
|
||||
betaFeatures.push("interleaved-thinking-2025-05-14");
|
||||
}
|
||||
|
||||
const client = new Anthropic({
|
||||
apiKey: null,
|
||||
authToken: apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: mergeHeaders(
|
||||
{
|
||||
accept: "application/json",
|
||||
"anthropic-dangerous-direct-browser-access": "true",
|
||||
...(betaFeatures.length > 0 ? { "anthropic-beta": betaFeatures.join(",") } : {}),
|
||||
},
|
||||
model.headers,
|
||||
dynamicHeaders,
|
||||
optionsHeaders,
|
||||
),
|
||||
});
|
||||
|
||||
return { client, isOAuthToken: false };
|
||||
}
|
||||
|
||||
const betaFeatures = ["fine-grained-tool-streaming-2025-05-14"];
|
||||
if (needsInterleavedBeta) {
|
||||
betaFeatures.push("interleaved-thinking-2025-05-14");
|
||||
}
|
||||
|
||||
// OAuth: Bearer auth, Claude Code identity headers
|
||||
if (isOAuthToken(apiKey)) {
|
||||
const client = new Anthropic({
|
||||
apiKey: null,
|
||||
authToken: apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: mergeHeaders(
|
||||
{
|
||||
accept: "application/json",
|
||||
"anthropic-dangerous-direct-browser-access": "true",
|
||||
"anthropic-beta": `claude-code-20250219,oauth-2025-04-20,${betaFeatures.join(",")}`,
|
||||
"user-agent": `claude-cli/${claudeCodeVersion}`,
|
||||
"x-app": "cli",
|
||||
},
|
||||
model.headers,
|
||||
optionsHeaders,
|
||||
),
|
||||
});
|
||||
|
||||
return { client, isOAuthToken: true };
|
||||
}
|
||||
|
||||
// API key auth
|
||||
const client = new Anthropic({
|
||||
apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: mergeHeaders(
|
||||
{
|
||||
accept: "application/json",
|
||||
"anthropic-dangerous-direct-browser-access": "true",
|
||||
"anthropic-beta": betaFeatures.join(","),
|
||||
},
|
||||
model.headers,
|
||||
optionsHeaders,
|
||||
),
|
||||
});
|
||||
|
||||
return { client, isOAuthToken: false };
|
||||
}
|
||||
|
||||
function buildParams(
|
||||
model: Model<"anthropic-messages">,
|
||||
context: Context,
|
||||
isOAuthToken: boolean,
|
||||
options?: AnthropicOptions,
|
||||
): MessageCreateParamsStreaming {
|
||||
const { cacheControl } = getCacheControl(model.baseUrl, options?.cacheRetention);
|
||||
const params: MessageCreateParamsStreaming = {
|
||||
model: model.id,
|
||||
messages: convertMessages(context.messages, model, isOAuthToken, cacheControl),
|
||||
max_tokens: options?.maxTokens || (model.maxTokens / 3) | 0,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
// For OAuth tokens, we MUST include Claude Code identity
|
||||
if (isOAuthToken) {
|
||||
params.system = [
|
||||
{
|
||||
type: "text",
|
||||
text: "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
...(cacheControl ? { cache_control: cacheControl } : {}),
|
||||
},
|
||||
];
|
||||
if (context.systemPrompt) {
|
||||
params.system.push({
|
||||
type: "text",
|
||||
text: sanitizeSurrogates(context.systemPrompt),
|
||||
...(cacheControl ? { cache_control: cacheControl } : {}),
|
||||
});
|
||||
}
|
||||
} else if (context.systemPrompt) {
|
||||
// Add cache control to system prompt for non-OAuth tokens
|
||||
params.system = [
|
||||
{
|
||||
type: "text",
|
||||
text: sanitizeSurrogates(context.systemPrompt),
|
||||
...(cacheControl ? { cache_control: cacheControl } : {}),
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
// Temperature is incompatible with extended thinking (adaptive or budget-based).
|
||||
if (options?.temperature !== undefined && !options?.thinkingEnabled) {
|
||||
params.temperature = options.temperature;
|
||||
}
|
||||
|
||||
if (context.tools) {
|
||||
params.tools = convertTools(context.tools, isOAuthToken);
|
||||
}
|
||||
|
||||
// Configure thinking mode: adaptive (Opus 4.6 and Sonnet 4.6) or budget-based (older models)
|
||||
if (options?.thinkingEnabled && model.reasoning) {
|
||||
if (supportsAdaptiveThinking(model.id)) {
|
||||
// Adaptive thinking: Claude decides when and how much to think
|
||||
params.thinking = { type: "adaptive" };
|
||||
if (options.effort) {
|
||||
params.output_config = { effort: options.effort };
|
||||
}
|
||||
} else {
|
||||
// Budget-based thinking for older models
|
||||
params.thinking = {
|
||||
type: "enabled",
|
||||
budget_tokens: options.thinkingBudgetTokens || 1024,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (options?.metadata) {
|
||||
const userId = options.metadata.user_id;
|
||||
if (typeof userId === "string") {
|
||||
params.metadata = { user_id: userId };
|
||||
}
|
||||
}
|
||||
|
||||
if (options?.toolChoice) {
|
||||
if (typeof options.toolChoice === "string") {
|
||||
params.tool_choice = { type: options.toolChoice };
|
||||
} else {
|
||||
params.tool_choice = options.toolChoice;
|
||||
}
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
// Normalize tool call IDs to match Anthropic's required pattern and length
|
||||
function normalizeToolCallId(id: string): string {
|
||||
return id.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 64);
|
||||
}
|
||||
|
||||
function convertMessages(
|
||||
messages: Message[],
|
||||
model: Model<"anthropic-messages">,
|
||||
isOAuthToken: boolean,
|
||||
cacheControl?: { type: "ephemeral"; ttl?: "1h" },
|
||||
): MessageParam[] {
|
||||
const params: MessageParam[] = [];
|
||||
|
||||
// Transform messages for cross-provider compatibility
|
||||
const transformedMessages = transformMessages(messages, model, normalizeToolCallId);
|
||||
|
||||
for (let i = 0; i < transformedMessages.length; i++) {
|
||||
const msg = transformedMessages[i];
|
||||
|
||||
if (msg.role === "user") {
|
||||
if (typeof msg.content === "string") {
|
||||
if (msg.content.trim().length > 0) {
|
||||
params.push({
|
||||
role: "user",
|
||||
content: sanitizeSurrogates(msg.content),
|
||||
});
|
||||
}
|
||||
} else {
|
||||
const blocks: ContentBlockParam[] = msg.content.map((item) => {
|
||||
if (item.type === "text") {
|
||||
return {
|
||||
type: "text",
|
||||
text: sanitizeSurrogates(item.text),
|
||||
};
|
||||
} else {
|
||||
return {
|
||||
type: "image",
|
||||
source: {
|
||||
type: "base64",
|
||||
media_type: item.mimeType as "image/jpeg" | "image/png" | "image/gif" | "image/webp",
|
||||
data: item.data,
|
||||
},
|
||||
};
|
||||
}
|
||||
});
|
||||
let filteredBlocks = !model?.input.includes("image") ? blocks.filter((b) => b.type !== "image") : blocks;
|
||||
filteredBlocks = filteredBlocks.filter((b) => {
|
||||
if (b.type === "text") {
|
||||
return b.text.trim().length > 0;
|
||||
}
|
||||
return true;
|
||||
});
|
||||
if (filteredBlocks.length === 0) continue;
|
||||
params.push({
|
||||
role: "user",
|
||||
content: filteredBlocks,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
const blocks: ContentBlockParam[] = [];
|
||||
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "text") {
|
||||
if (block.text.trim().length === 0) continue;
|
||||
blocks.push({
|
||||
type: "text",
|
||||
text: sanitizeSurrogates(block.text),
|
||||
});
|
||||
} else if (block.type === "thinking") {
|
||||
// Redacted thinking: pass the opaque payload back as redacted_thinking
|
||||
if (block.redacted) {
|
||||
blocks.push({
|
||||
type: "redacted_thinking",
|
||||
data: block.thinkingSignature!,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
if (block.thinking.trim().length === 0) continue;
|
||||
// If thinking signature is missing/empty (e.g., from aborted stream),
|
||||
// convert to plain text block without <thinking> tags to avoid API rejection
|
||||
// and prevent Claude from mimicking the tags in responses
|
||||
if (!block.thinkingSignature || block.thinkingSignature.trim().length === 0) {
|
||||
blocks.push({
|
||||
type: "text",
|
||||
text: sanitizeSurrogates(block.thinking),
|
||||
});
|
||||
} else {
|
||||
blocks.push({
|
||||
type: "thinking",
|
||||
thinking: sanitizeSurrogates(block.thinking),
|
||||
signature: block.thinkingSignature,
|
||||
});
|
||||
}
|
||||
} else if (block.type === "toolCall") {
|
||||
blocks.push({
|
||||
type: "tool_use",
|
||||
id: block.id,
|
||||
name: isOAuthToken ? toClaudeCodeName(block.name) : block.name,
|
||||
input: block.arguments ?? {},
|
||||
});
|
||||
}
|
||||
}
|
||||
if (blocks.length === 0) continue;
|
||||
params.push({
|
||||
role: "assistant",
|
||||
content: blocks,
|
||||
});
|
||||
} else if (msg.role === "toolResult") {
|
||||
// Collect all consecutive toolResult messages, needed for z.ai Anthropic endpoint
|
||||
const toolResults: ContentBlockParam[] = [];
|
||||
|
||||
// Add the current tool result
|
||||
toolResults.push({
|
||||
type: "tool_result",
|
||||
tool_use_id: msg.toolCallId,
|
||||
content: convertContentBlocks(msg.content),
|
||||
is_error: msg.isError,
|
||||
});
|
||||
|
||||
// Look ahead for consecutive toolResult messages
|
||||
let j = i + 1;
|
||||
while (j < transformedMessages.length && transformedMessages[j].role === "toolResult") {
|
||||
const nextMsg = transformedMessages[j] as ToolResultMessage; // We know it's a toolResult
|
||||
toolResults.push({
|
||||
type: "tool_result",
|
||||
tool_use_id: nextMsg.toolCallId,
|
||||
content: convertContentBlocks(nextMsg.content),
|
||||
is_error: nextMsg.isError,
|
||||
});
|
||||
j++;
|
||||
}
|
||||
|
||||
// Skip the messages we've already processed
|
||||
i = j - 1;
|
||||
|
||||
// Add a single user message with all tool results
|
||||
params.push({
|
||||
role: "user",
|
||||
content: toolResults,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Add cache_control to the last user message to cache conversation history
|
||||
if (cacheControl && params.length > 0) {
|
||||
const lastMessage = params[params.length - 1];
|
||||
if (lastMessage.role === "user") {
|
||||
if (Array.isArray(lastMessage.content)) {
|
||||
const lastBlock = lastMessage.content[lastMessage.content.length - 1];
|
||||
if (
|
||||
lastBlock &&
|
||||
(lastBlock.type === "text" || lastBlock.type === "image" || lastBlock.type === "tool_result")
|
||||
) {
|
||||
(lastBlock as any).cache_control = cacheControl;
|
||||
}
|
||||
} else if (typeof lastMessage.content === "string") {
|
||||
lastMessage.content = [
|
||||
{
|
||||
type: "text",
|
||||
text: lastMessage.content,
|
||||
cache_control: cacheControl,
|
||||
},
|
||||
] as any;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
function convertTools(tools: Tool[], isOAuthToken: boolean): Anthropic.Messages.Tool[] {
|
||||
if (!tools) return [];
|
||||
|
||||
return tools.map((tool) => {
|
||||
const jsonSchema = tool.parameters as any; // TypeBox already generates JSON Schema
|
||||
|
||||
return {
|
||||
name: isOAuthToken ? toClaudeCodeName(tool.name) : tool.name,
|
||||
description: tool.description,
|
||||
input_schema: {
|
||||
type: "object" as const,
|
||||
properties: jsonSchema.properties || {},
|
||||
required: jsonSchema.required || [],
|
||||
},
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
function mapStopReason(reason: Anthropic.Messages.StopReason | string): StopReason {
|
||||
switch (reason) {
|
||||
case "end_turn":
|
||||
return "stop";
|
||||
case "max_tokens":
|
||||
return "length";
|
||||
case "tool_use":
|
||||
return "toolUse";
|
||||
case "refusal":
|
||||
return "error";
|
||||
case "pause_turn": // Stop is good enough -> resubmit
|
||||
return "stop";
|
||||
case "stop_sequence":
|
||||
return "stop"; // We don't supply stop sequences, so this should never happen
|
||||
case "sensitive": // Content flagged by safety filters (not yet in SDK types)
|
||||
return "error";
|
||||
default:
|
||||
// Handle unknown stop reasons gracefully (API may add new values)
|
||||
throw new Error(`Unhandled stop reason: ${reason}`);
|
||||
}
|
||||
}
|
||||
259
packages/pi-ai/src/providers/azure-openai-responses.ts
Normal file
259
packages/pi-ai/src/providers/azure-openai-responses.ts
Normal file
|
|
@ -0,0 +1,259 @@
|
|||
import { AzureOpenAI } from "openai";
|
||||
import type { ResponseCreateParamsStreaming } from "openai/resources/responses/responses.js";
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { supportsXhigh } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { convertResponsesMessages, convertResponsesTools, processResponsesStream } from "./openai-responses-shared.js";
|
||||
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
|
||||
const DEFAULT_AZURE_API_VERSION = "v1";
|
||||
const AZURE_TOOL_CALL_PROVIDERS = new Set(["openai", "openai-codex", "opencode", "azure-openai-responses"]);
|
||||
|
||||
function parseDeploymentNameMap(value: string | undefined): Map<string, string> {
|
||||
const map = new Map<string, string>();
|
||||
if (!value) return map;
|
||||
for (const entry of value.split(",")) {
|
||||
const trimmed = entry.trim();
|
||||
if (!trimmed) continue;
|
||||
const [modelId, deploymentName] = trimmed.split("=", 2);
|
||||
if (!modelId || !deploymentName) continue;
|
||||
map.set(modelId.trim(), deploymentName.trim());
|
||||
}
|
||||
return map;
|
||||
}
|
||||
|
||||
function resolveDeploymentName(model: Model<"azure-openai-responses">, options?: AzureOpenAIResponsesOptions): string {
|
||||
if (options?.azureDeploymentName) {
|
||||
return options.azureDeploymentName;
|
||||
}
|
||||
const mappedDeployment = parseDeploymentNameMap(process.env.AZURE_OPENAI_DEPLOYMENT_NAME_MAP).get(model.id);
|
||||
return mappedDeployment || model.id;
|
||||
}
|
||||
|
||||
// Azure OpenAI Responses-specific options
|
||||
export interface AzureOpenAIResponsesOptions extends StreamOptions {
|
||||
reasoningEffort?: "minimal" | "low" | "medium" | "high" | "xhigh";
|
||||
reasoningSummary?: "auto" | "detailed" | "concise" | null;
|
||||
azureApiVersion?: string;
|
||||
azureResourceName?: string;
|
||||
azureBaseUrl?: string;
|
||||
azureDeploymentName?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate function for Azure OpenAI Responses API
|
||||
*/
|
||||
export const streamAzureOpenAIResponses: StreamFunction<"azure-openai-responses", AzureOpenAIResponsesOptions> = (
|
||||
model: Model<"azure-openai-responses">,
|
||||
context: Context,
|
||||
options?: AzureOpenAIResponsesOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
// Start async processing
|
||||
(async () => {
|
||||
const deploymentName = resolveDeploymentName(model, options);
|
||||
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: "azure-openai-responses" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
try {
|
||||
// Create Azure OpenAI client
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
|
||||
const client = createClient(model, apiKey, options);
|
||||
let params = buildParams(model, context, options, deploymentName);
|
||||
const nextParams = await options?.onPayload?.(params, model);
|
||||
if (nextParams !== undefined) {
|
||||
params = nextParams as ResponseCreateParamsStreaming;
|
||||
}
|
||||
const openaiStream = await client.responses.create(
|
||||
params,
|
||||
options?.signal ? { signal: options.signal } : undefined,
|
||||
);
|
||||
stream.push({ type: "start", partial: output });
|
||||
|
||||
await processResponsesStream(openaiStream, output, stream, model);
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (output.stopReason === "aborted" || output.stopReason === "error") {
|
||||
throw new Error("An unknown error occurred");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
for (const block of output.content) delete (block as { index?: number }).index;
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
export const streamSimpleAzureOpenAIResponses: StreamFunction<"azure-openai-responses", SimpleStreamOptions> = (
|
||||
model: Model<"azure-openai-responses">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
const reasoningEffort = supportsXhigh(model) ? options?.reasoning : clampReasoning(options?.reasoning);
|
||||
|
||||
return streamAzureOpenAIResponses(model, context, {
|
||||
...base,
|
||||
reasoningEffort,
|
||||
} satisfies AzureOpenAIResponsesOptions);
|
||||
};
|
||||
|
||||
function normalizeAzureBaseUrl(baseUrl: string): string {
|
||||
return baseUrl.replace(/\/+$/, "");
|
||||
}
|
||||
|
||||
function buildDefaultBaseUrl(resourceName: string): string {
|
||||
return `https://${resourceName}.openai.azure.com/openai/v1`;
|
||||
}
|
||||
|
||||
function resolveAzureConfig(
|
||||
model: Model<"azure-openai-responses">,
|
||||
options?: AzureOpenAIResponsesOptions,
|
||||
): { baseUrl: string; apiVersion: string } {
|
||||
const apiVersion = options?.azureApiVersion || process.env.AZURE_OPENAI_API_VERSION || DEFAULT_AZURE_API_VERSION;
|
||||
|
||||
const baseUrl = options?.azureBaseUrl?.trim() || process.env.AZURE_OPENAI_BASE_URL?.trim() || undefined;
|
||||
const resourceName = options?.azureResourceName || process.env.AZURE_OPENAI_RESOURCE_NAME;
|
||||
|
||||
let resolvedBaseUrl = baseUrl;
|
||||
|
||||
if (!resolvedBaseUrl && resourceName) {
|
||||
resolvedBaseUrl = buildDefaultBaseUrl(resourceName);
|
||||
}
|
||||
|
||||
if (!resolvedBaseUrl && model.baseUrl) {
|
||||
resolvedBaseUrl = model.baseUrl;
|
||||
}
|
||||
|
||||
if (!resolvedBaseUrl) {
|
||||
throw new Error(
|
||||
"Azure OpenAI base URL is required. Set AZURE_OPENAI_BASE_URL or AZURE_OPENAI_RESOURCE_NAME, or pass azureBaseUrl, azureResourceName, or model.baseUrl.",
|
||||
);
|
||||
}
|
||||
|
||||
return {
|
||||
baseUrl: normalizeAzureBaseUrl(resolvedBaseUrl),
|
||||
apiVersion,
|
||||
};
|
||||
}
|
||||
|
||||
function createClient(model: Model<"azure-openai-responses">, apiKey: string, options?: AzureOpenAIResponsesOptions) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.AZURE_OPENAI_API_KEY) {
|
||||
throw new Error(
|
||||
"Azure OpenAI API key is required. Set AZURE_OPENAI_API_KEY environment variable or pass it as an argument.",
|
||||
);
|
||||
}
|
||||
apiKey = process.env.AZURE_OPENAI_API_KEY;
|
||||
}
|
||||
|
||||
const headers = { ...model.headers };
|
||||
|
||||
if (options?.headers) {
|
||||
Object.assign(headers, options.headers);
|
||||
}
|
||||
|
||||
const { baseUrl, apiVersion } = resolveAzureConfig(model, options);
|
||||
|
||||
return new AzureOpenAI({
|
||||
apiKey,
|
||||
apiVersion,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: headers,
|
||||
baseURL: baseUrl,
|
||||
});
|
||||
}
|
||||
|
||||
function buildParams(
|
||||
model: Model<"azure-openai-responses">,
|
||||
context: Context,
|
||||
options: AzureOpenAIResponsesOptions | undefined,
|
||||
deploymentName: string,
|
||||
) {
|
||||
const messages = convertResponsesMessages(model, context, AZURE_TOOL_CALL_PROVIDERS);
|
||||
|
||||
const params: ResponseCreateParamsStreaming = {
|
||||
model: deploymentName,
|
||||
input: messages,
|
||||
stream: true,
|
||||
prompt_cache_key: options?.sessionId,
|
||||
};
|
||||
|
||||
if (options?.maxTokens) {
|
||||
params.max_output_tokens = options?.maxTokens;
|
||||
}
|
||||
|
||||
if (options?.temperature !== undefined) {
|
||||
params.temperature = options?.temperature;
|
||||
}
|
||||
|
||||
if (context.tools) {
|
||||
params.tools = convertResponsesTools(context.tools);
|
||||
}
|
||||
|
||||
if (model.reasoning) {
|
||||
if (options?.reasoningEffort || options?.reasoningSummary) {
|
||||
params.reasoning = {
|
||||
effort: options?.reasoningEffort || "medium",
|
||||
summary: options?.reasoningSummary || "auto",
|
||||
};
|
||||
params.include = ["reasoning.encrypted_content"];
|
||||
} else {
|
||||
if (model.name.toLowerCase().startsWith("gpt-5")) {
|
||||
// Jesus Christ, see https://community.openai.com/t/need-reasoning-false-option-for-gpt-5/1351588/7
|
||||
messages.push({
|
||||
role: "developer",
|
||||
content: [
|
||||
{
|
||||
type: "input_text",
|
||||
text: "# Juice: 0 !important",
|
||||
},
|
||||
],
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
37
packages/pi-ai/src/providers/github-copilot-headers.ts
Normal file
37
packages/pi-ai/src/providers/github-copilot-headers.ts
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import type { Message } from "../types.js";
|
||||
|
||||
// Copilot expects X-Initiator to indicate whether the request is user-initiated
|
||||
// or agent-initiated (e.g. follow-up after assistant/tool messages).
|
||||
export function inferCopilotInitiator(messages: Message[]): "user" | "agent" {
|
||||
const last = messages[messages.length - 1];
|
||||
return last && last.role !== "user" ? "agent" : "user";
|
||||
}
|
||||
|
||||
// Copilot requires Copilot-Vision-Request header when sending images
|
||||
export function hasCopilotVisionInput(messages: Message[]): boolean {
|
||||
return messages.some((msg) => {
|
||||
if (msg.role === "user" && Array.isArray(msg.content)) {
|
||||
return msg.content.some((c) => c.type === "image");
|
||||
}
|
||||
if (msg.role === "toolResult" && Array.isArray(msg.content)) {
|
||||
return msg.content.some((c) => c.type === "image");
|
||||
}
|
||||
return false;
|
||||
});
|
||||
}
|
||||
|
||||
export function buildCopilotDynamicHeaders(params: {
|
||||
messages: Message[];
|
||||
hasImages: boolean;
|
||||
}): Record<string, string> {
|
||||
const headers: Record<string, string> = {
|
||||
"X-Initiator": inferCopilotInitiator(params.messages),
|
||||
"Openai-Intent": "conversation-edits",
|
||||
};
|
||||
|
||||
if (params.hasImages) {
|
||||
headers["Copilot-Vision-Request"] = "true";
|
||||
}
|
||||
|
||||
return headers;
|
||||
}
|
||||
967
packages/pi-ai/src/providers/google-gemini-cli.ts
Normal file
967
packages/pi-ai/src/providers/google-gemini-cli.ts
Normal file
|
|
@ -0,0 +1,967 @@
|
|||
/**
|
||||
* Google Gemini CLI / Antigravity provider.
|
||||
* Shared implementation for both google-gemini-cli and google-antigravity providers.
|
||||
* Uses the Cloud Code Assist API endpoint to access Gemini and Claude models.
|
||||
*/
|
||||
|
||||
import type { Content, ThinkingConfig } from "@google/genai";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingBudgets,
|
||||
ThinkingContent,
|
||||
ThinkingLevel,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import {
|
||||
convertMessages,
|
||||
convertTools,
|
||||
isThinkingPart,
|
||||
mapStopReasonString,
|
||||
mapToolChoice,
|
||||
retainThoughtSignature,
|
||||
} from "./google-shared.js";
|
||||
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
|
||||
/**
|
||||
* Thinking level for Gemini 3 models.
|
||||
* Mirrors Google's ThinkingLevel enum values.
|
||||
*/
|
||||
export type GoogleThinkingLevel = "THINKING_LEVEL_UNSPECIFIED" | "MINIMAL" | "LOW" | "MEDIUM" | "HIGH";
|
||||
|
||||
export interface GoogleGeminiCliOptions extends StreamOptions {
|
||||
toolChoice?: "auto" | "none" | "any";
|
||||
/**
|
||||
* Thinking/reasoning configuration.
|
||||
* - Gemini 2.x models: use `budgetTokens` to set the thinking budget
|
||||
* - Gemini 3 models (gemini-3-pro-*, gemini-3-flash-*): use `level` instead
|
||||
*
|
||||
* When using `streamSimple`, this is handled automatically based on the model.
|
||||
*/
|
||||
thinking?: {
|
||||
enabled: boolean;
|
||||
/** Thinking budget in tokens. Use for Gemini 2.x models. */
|
||||
budgetTokens?: number;
|
||||
/** Thinking level. Use for Gemini 3 models (LOW/HIGH for Pro, MINIMAL/LOW/MEDIUM/HIGH for Flash). */
|
||||
level?: GoogleThinkingLevel;
|
||||
};
|
||||
projectId?: string;
|
||||
}
|
||||
|
||||
const DEFAULT_ENDPOINT = "https://cloudcode-pa.googleapis.com";
|
||||
const ANTIGRAVITY_DAILY_ENDPOINT = "https://daily-cloudcode-pa.sandbox.googleapis.com";
|
||||
const ANTIGRAVITY_AUTOPUSH_ENDPOINT = "https://autopush-cloudcode-pa.sandbox.googleapis.com";
|
||||
const ANTIGRAVITY_ENDPOINT_FALLBACKS = [
|
||||
ANTIGRAVITY_DAILY_ENDPOINT,
|
||||
ANTIGRAVITY_AUTOPUSH_ENDPOINT,
|
||||
DEFAULT_ENDPOINT,
|
||||
] as const;
|
||||
// Headers for Gemini CLI (prod endpoint)
|
||||
const GEMINI_CLI_HEADERS = {
|
||||
"User-Agent": "google-cloud-sdk vscode_cloudshelleditor/0.1",
|
||||
"X-Goog-Api-Client": "gl-node/22.17.0",
|
||||
"Client-Metadata": JSON.stringify({
|
||||
ideType: "IDE_UNSPECIFIED",
|
||||
platform: "PLATFORM_UNSPECIFIED",
|
||||
pluginType: "GEMINI",
|
||||
}),
|
||||
};
|
||||
|
||||
// Headers for Antigravity (sandbox endpoint) - requires specific User-Agent
|
||||
const DEFAULT_ANTIGRAVITY_VERSION = "1.18.4";
|
||||
|
||||
function getAntigravityHeaders() {
|
||||
const version = process.env.PI_AI_ANTIGRAVITY_VERSION || DEFAULT_ANTIGRAVITY_VERSION;
|
||||
return {
|
||||
"User-Agent": `antigravity/${version} darwin/arm64`,
|
||||
};
|
||||
}
|
||||
|
||||
// Antigravity system instruction (compact version from CLIProxyAPI).
|
||||
const ANTIGRAVITY_SYSTEM_INSTRUCTION =
|
||||
"You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding." +
|
||||
"You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question." +
|
||||
"**Absolute paths only**" +
|
||||
"**Proactiveness**";
|
||||
|
||||
// Counter for generating unique tool call IDs
|
||||
let toolCallCounter = 0;
|
||||
|
||||
// Retry configuration
|
||||
const MAX_RETRIES = 3;
|
||||
const BASE_DELAY_MS = 1000;
|
||||
const MAX_EMPTY_STREAM_RETRIES = 2;
|
||||
const EMPTY_STREAM_BASE_DELAY_MS = 500;
|
||||
const CLAUDE_THINKING_BETA_HEADER = "interleaved-thinking-2025-05-14";
|
||||
|
||||
/**
|
||||
* Extract retry delay from Gemini error response (in milliseconds).
|
||||
* Checks headers first (Retry-After, x-ratelimit-reset, x-ratelimit-reset-after),
|
||||
* then parses body patterns like:
|
||||
* - "Your quota will reset after 39s"
|
||||
* - "Your quota will reset after 18h31m10s"
|
||||
* - "Please retry in Xs" or "Please retry in Xms"
|
||||
* - "retryDelay": "34.074824224s" (JSON field)
|
||||
*/
|
||||
export function extractRetryDelay(errorText: string, response?: Response | Headers): number | undefined {
|
||||
const normalizeDelay = (ms: number): number | undefined => (ms > 0 ? Math.ceil(ms + 1000) : undefined);
|
||||
|
||||
const headers = response instanceof Headers ? response : response?.headers;
|
||||
if (headers) {
|
||||
const retryAfter = headers.get("retry-after");
|
||||
if (retryAfter) {
|
||||
const retryAfterSeconds = Number(retryAfter);
|
||||
if (Number.isFinite(retryAfterSeconds)) {
|
||||
const delay = normalizeDelay(retryAfterSeconds * 1000);
|
||||
if (delay !== undefined) {
|
||||
return delay;
|
||||
}
|
||||
}
|
||||
const retryAfterDate = new Date(retryAfter);
|
||||
const retryAfterMs = retryAfterDate.getTime();
|
||||
if (!Number.isNaN(retryAfterMs)) {
|
||||
const delay = normalizeDelay(retryAfterMs - Date.now());
|
||||
if (delay !== undefined) {
|
||||
return delay;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const rateLimitReset = headers.get("x-ratelimit-reset");
|
||||
if (rateLimitReset) {
|
||||
const resetSeconds = Number.parseInt(rateLimitReset, 10);
|
||||
if (!Number.isNaN(resetSeconds)) {
|
||||
const delay = normalizeDelay(resetSeconds * 1000 - Date.now());
|
||||
if (delay !== undefined) {
|
||||
return delay;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const rateLimitResetAfter = headers.get("x-ratelimit-reset-after");
|
||||
if (rateLimitResetAfter) {
|
||||
const resetAfterSeconds = Number(rateLimitResetAfter);
|
||||
if (Number.isFinite(resetAfterSeconds)) {
|
||||
const delay = normalizeDelay(resetAfterSeconds * 1000);
|
||||
if (delay !== undefined) {
|
||||
return delay;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern 1: "Your quota will reset after ..." (formats: "18h31m10s", "10m15s", "6s", "39s")
|
||||
const durationMatch = errorText.match(/reset after (?:(\d+)h)?(?:(\d+)m)?(\d+(?:\.\d+)?)s/i);
|
||||
if (durationMatch) {
|
||||
const hours = durationMatch[1] ? parseInt(durationMatch[1], 10) : 0;
|
||||
const minutes = durationMatch[2] ? parseInt(durationMatch[2], 10) : 0;
|
||||
const seconds = parseFloat(durationMatch[3]);
|
||||
if (!Number.isNaN(seconds)) {
|
||||
const totalMs = ((hours * 60 + minutes) * 60 + seconds) * 1000;
|
||||
const delay = normalizeDelay(totalMs);
|
||||
if (delay !== undefined) {
|
||||
return delay;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern 2: "Please retry in X[ms|s]"
|
||||
const retryInMatch = errorText.match(/Please retry in ([0-9.]+)(ms|s)/i);
|
||||
if (retryInMatch?.[1]) {
|
||||
const value = parseFloat(retryInMatch[1]);
|
||||
if (!Number.isNaN(value) && value > 0) {
|
||||
const ms = retryInMatch[2].toLowerCase() === "ms" ? value : value * 1000;
|
||||
const delay = normalizeDelay(ms);
|
||||
if (delay !== undefined) {
|
||||
return delay;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern 3: "retryDelay": "34.074824224s" (JSON field in error details)
|
||||
const retryDelayMatch = errorText.match(/"retryDelay":\s*"([0-9.]+)(ms|s)"/i);
|
||||
if (retryDelayMatch?.[1]) {
|
||||
const value = parseFloat(retryDelayMatch[1]);
|
||||
if (!Number.isNaN(value) && value > 0) {
|
||||
const ms = retryDelayMatch[2].toLowerCase() === "ms" ? value : value * 1000;
|
||||
const delay = normalizeDelay(ms);
|
||||
if (delay !== undefined) {
|
||||
return delay;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
function needsClaudeThinkingBetaHeader(model: Model<"google-gemini-cli">): boolean {
|
||||
return model.provider === "google-antigravity" && model.id.startsWith("claude-") && model.reasoning;
|
||||
}
|
||||
|
||||
function isGemini3ProModel(modelId: string): boolean {
|
||||
return /gemini-3(?:\.1)?-pro/.test(modelId.toLowerCase());
|
||||
}
|
||||
|
||||
function isGemini3FlashModel(modelId: string): boolean {
|
||||
return /gemini-3(?:\.1)?-flash/.test(modelId.toLowerCase());
|
||||
}
|
||||
|
||||
function isGemini3Model(modelId: string): boolean {
|
||||
return isGemini3ProModel(modelId) || isGemini3FlashModel(modelId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if an error is retryable (rate limit, server error, network error, etc.)
|
||||
*/
|
||||
function isRetryableError(status: number, errorText: string): boolean {
|
||||
if (status === 429 || status === 500 || status === 502 || status === 503 || status === 504) {
|
||||
return true;
|
||||
}
|
||||
return /resource.?exhausted|rate.?limit|overloaded|service.?unavailable|other.?side.?closed/i.test(errorText);
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract a clean, user-friendly error message from Google API error response.
|
||||
* Parses JSON error responses and returns just the message field.
|
||||
*/
|
||||
function extractErrorMessage(errorText: string): string {
|
||||
try {
|
||||
const parsed = JSON.parse(errorText) as { error?: { message?: string } };
|
||||
if (parsed.error?.message) {
|
||||
return parsed.error.message;
|
||||
}
|
||||
} catch {
|
||||
// Not JSON, return as-is
|
||||
}
|
||||
return errorText;
|
||||
}
|
||||
|
||||
/**
|
||||
* Sleep for a given number of milliseconds, respecting abort signal.
|
||||
*/
|
||||
function sleep(ms: number, signal?: AbortSignal): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (signal?.aborted) {
|
||||
reject(new Error("Request was aborted"));
|
||||
return;
|
||||
}
|
||||
const timeout = setTimeout(resolve, ms);
|
||||
signal?.addEventListener("abort", () => {
|
||||
clearTimeout(timeout);
|
||||
reject(new Error("Request was aborted"));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
interface CloudCodeAssistRequest {
|
||||
project: string;
|
||||
model: string;
|
||||
request: {
|
||||
contents: Content[];
|
||||
sessionId?: string;
|
||||
systemInstruction?: { role?: string; parts: { text: string }[] };
|
||||
generationConfig?: {
|
||||
maxOutputTokens?: number;
|
||||
temperature?: number;
|
||||
thinkingConfig?: ThinkingConfig;
|
||||
};
|
||||
tools?: ReturnType<typeof convertTools>;
|
||||
toolConfig?: {
|
||||
functionCallingConfig: {
|
||||
mode: ReturnType<typeof mapToolChoice>;
|
||||
};
|
||||
};
|
||||
};
|
||||
requestType?: string;
|
||||
userAgent?: string;
|
||||
requestId?: string;
|
||||
}
|
||||
|
||||
interface CloudCodeAssistResponseChunk {
|
||||
response?: {
|
||||
candidates?: Array<{
|
||||
content?: {
|
||||
role: string;
|
||||
parts?: Array<{
|
||||
text?: string;
|
||||
thought?: boolean;
|
||||
thoughtSignature?: string;
|
||||
functionCall?: {
|
||||
name: string;
|
||||
args: Record<string, unknown>;
|
||||
id?: string;
|
||||
};
|
||||
}>;
|
||||
};
|
||||
finishReason?: string;
|
||||
}>;
|
||||
usageMetadata?: {
|
||||
promptTokenCount?: number;
|
||||
candidatesTokenCount?: number;
|
||||
thoughtsTokenCount?: number;
|
||||
totalTokenCount?: number;
|
||||
cachedContentTokenCount?: number;
|
||||
};
|
||||
modelVersion?: string;
|
||||
responseId?: string;
|
||||
};
|
||||
traceId?: string;
|
||||
}
|
||||
|
||||
export const streamGoogleGeminiCli: StreamFunction<"google-gemini-cli", GoogleGeminiCliOptions> = (
|
||||
model: Model<"google-gemini-cli">,
|
||||
context: Context,
|
||||
options?: GoogleGeminiCliOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: "google-gemini-cli" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
try {
|
||||
// apiKey is JSON-encoded: { token, projectId }
|
||||
const apiKeyRaw = options?.apiKey;
|
||||
if (!apiKeyRaw) {
|
||||
throw new Error("Google Cloud Code Assist requires OAuth authentication. Use /login to authenticate.");
|
||||
}
|
||||
|
||||
let accessToken: string;
|
||||
let projectId: string;
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(apiKeyRaw) as { token: string; projectId: string };
|
||||
accessToken = parsed.token;
|
||||
projectId = parsed.projectId;
|
||||
} catch {
|
||||
throw new Error("Invalid Google Cloud Code Assist credentials. Use /login to re-authenticate.");
|
||||
}
|
||||
|
||||
if (!accessToken || !projectId) {
|
||||
throw new Error("Missing token or projectId in Google Cloud credentials. Use /login to re-authenticate.");
|
||||
}
|
||||
|
||||
const isAntigravity = model.provider === "google-antigravity";
|
||||
const baseUrl = model.baseUrl?.trim();
|
||||
const endpoints = baseUrl ? [baseUrl] : isAntigravity ? ANTIGRAVITY_ENDPOINT_FALLBACKS : [DEFAULT_ENDPOINT];
|
||||
|
||||
let requestBody = buildRequest(model, context, projectId, options, isAntigravity);
|
||||
const nextRequestBody = await options?.onPayload?.(requestBody, model);
|
||||
if (nextRequestBody !== undefined) {
|
||||
requestBody = nextRequestBody as CloudCodeAssistRequest;
|
||||
}
|
||||
const headers = isAntigravity ? getAntigravityHeaders() : GEMINI_CLI_HEADERS;
|
||||
|
||||
const requestHeaders = {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
"Content-Type": "application/json",
|
||||
Accept: "text/event-stream",
|
||||
...headers,
|
||||
...(needsClaudeThinkingBetaHeader(model) ? { "anthropic-beta": CLAUDE_THINKING_BETA_HEADER } : {}),
|
||||
...options?.headers,
|
||||
};
|
||||
const requestBodyJson = JSON.stringify(requestBody);
|
||||
|
||||
// Fetch with retry logic for rate limits, transient errors, and endpoint fallbacks.
|
||||
// On 403/404, immediately try the next endpoint (no delay).
|
||||
// On 429/5xx, retry with backoff on the same or next endpoint.
|
||||
let response: Response | undefined;
|
||||
let lastError: Error | undefined;
|
||||
let requestUrl: string | undefined;
|
||||
let endpointIndex = 0;
|
||||
|
||||
for (let attempt = 0; attempt <= MAX_RETRIES; attempt++) {
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
try {
|
||||
const endpoint = endpoints[endpointIndex];
|
||||
requestUrl = `${endpoint}/v1internal:streamGenerateContent?alt=sse`;
|
||||
response = await fetch(requestUrl, {
|
||||
method: "POST",
|
||||
headers: requestHeaders,
|
||||
body: requestBodyJson,
|
||||
signal: options?.signal,
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
break; // Success, exit retry loop
|
||||
}
|
||||
|
||||
const errorText = await response.text();
|
||||
|
||||
// On 403/404, cascade to the next endpoint immediately (no delay)
|
||||
if ((response.status === 403 || response.status === 404) && endpointIndex < endpoints.length - 1) {
|
||||
endpointIndex++;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if retryable (429, 5xx, network patterns)
|
||||
if (attempt < MAX_RETRIES && isRetryableError(response.status, errorText)) {
|
||||
// Advance endpoint if possible
|
||||
if (endpointIndex < endpoints.length - 1) {
|
||||
endpointIndex++;
|
||||
}
|
||||
|
||||
// Use server-provided delay or exponential backoff
|
||||
const serverDelay = extractRetryDelay(errorText, response);
|
||||
const delayMs = serverDelay ?? BASE_DELAY_MS * 2 ** attempt;
|
||||
|
||||
// Check if server delay exceeds max allowed (default: 60s)
|
||||
const maxDelayMs = options?.maxRetryDelayMs ?? 60000;
|
||||
if (maxDelayMs > 0 && serverDelay && serverDelay > maxDelayMs) {
|
||||
const delaySeconds = Math.ceil(serverDelay / 1000);
|
||||
throw new Error(
|
||||
`Server requested ${delaySeconds}s retry delay (max: ${Math.ceil(maxDelayMs / 1000)}s). ${extractErrorMessage(errorText)}`,
|
||||
);
|
||||
}
|
||||
|
||||
await sleep(delayMs, options?.signal);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Not retryable or max retries exceeded
|
||||
throw new Error(`Cloud Code Assist API error (${response.status}): ${extractErrorMessage(errorText)}`);
|
||||
} catch (error) {
|
||||
// Check for abort - fetch throws AbortError, our code throws "Request was aborted"
|
||||
if (error instanceof Error) {
|
||||
if (error.name === "AbortError" || error.message === "Request was aborted") {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
}
|
||||
// Extract detailed error message from fetch errors (Node includes cause)
|
||||
lastError = error instanceof Error ? error : new Error(String(error));
|
||||
if (lastError.message === "fetch failed" && lastError.cause instanceof Error) {
|
||||
lastError = new Error(`Network error: ${lastError.cause.message}`);
|
||||
}
|
||||
// Network errors are retryable
|
||||
if (attempt < MAX_RETRIES) {
|
||||
const delayMs = BASE_DELAY_MS * 2 ** attempt;
|
||||
await sleep(delayMs, options?.signal);
|
||||
continue;
|
||||
}
|
||||
throw lastError;
|
||||
}
|
||||
}
|
||||
|
||||
if (!response || !response.ok) {
|
||||
throw lastError ?? new Error("Failed to get response after retries");
|
||||
}
|
||||
|
||||
let started = false;
|
||||
const ensureStarted = () => {
|
||||
if (!started) {
|
||||
stream.push({ type: "start", partial: output });
|
||||
started = true;
|
||||
}
|
||||
};
|
||||
|
||||
const resetOutput = () => {
|
||||
output.content = [];
|
||||
output.usage = {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
};
|
||||
output.stopReason = "stop";
|
||||
output.errorMessage = undefined;
|
||||
output.timestamp = Date.now();
|
||||
started = false;
|
||||
};
|
||||
|
||||
const streamResponse = async (activeResponse: Response): Promise<boolean> => {
|
||||
if (!activeResponse.body) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
|
||||
let hasContent = false;
|
||||
let currentBlock: TextContent | ThinkingContent | null = null;
|
||||
const blocks = output.content;
|
||||
const blockIndex = () => blocks.length - 1;
|
||||
|
||||
// Read SSE stream
|
||||
const reader = activeResponse.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
|
||||
// Set up abort handler to cancel reader when signal fires
|
||||
const abortHandler = () => {
|
||||
void reader.cancel().catch(() => {});
|
||||
};
|
||||
options?.signal?.addEventListener("abort", abortHandler);
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
// Check abort signal before each read
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
const lines = buffer.split("\n");
|
||||
buffer = lines.pop() || "";
|
||||
|
||||
for (const line of lines) {
|
||||
if (!line.startsWith("data:")) continue;
|
||||
|
||||
const jsonStr = line.slice(5).trim();
|
||||
if (!jsonStr) continue;
|
||||
|
||||
let chunk: CloudCodeAssistResponseChunk;
|
||||
try {
|
||||
chunk = JSON.parse(jsonStr);
|
||||
} catch {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Unwrap the response
|
||||
const responseData = chunk.response;
|
||||
if (!responseData) continue;
|
||||
|
||||
const candidate = responseData.candidates?.[0];
|
||||
if (candidate?.content?.parts) {
|
||||
for (const part of candidate.content.parts) {
|
||||
if (part.text !== undefined) {
|
||||
hasContent = true;
|
||||
const isThinking = isThinkingPart(part);
|
||||
if (
|
||||
!currentBlock ||
|
||||
(isThinking && currentBlock.type !== "thinking") ||
|
||||
(!isThinking && currentBlock.type !== "text")
|
||||
) {
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blocks.length - 1,
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
if (isThinking) {
|
||||
currentBlock = { type: "thinking", thinking: "", thinkingSignature: undefined };
|
||||
output.content.push(currentBlock);
|
||||
ensureStarted();
|
||||
stream.push({
|
||||
type: "thinking_start",
|
||||
contentIndex: blockIndex(),
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
ensureStarted();
|
||||
stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
|
||||
}
|
||||
}
|
||||
if (currentBlock.type === "thinking") {
|
||||
currentBlock.thinking += part.text;
|
||||
currentBlock.thinkingSignature = retainThoughtSignature(
|
||||
currentBlock.thinkingSignature,
|
||||
part.thoughtSignature,
|
||||
);
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: part.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
currentBlock.text += part.text;
|
||||
currentBlock.textSignature = retainThoughtSignature(
|
||||
currentBlock.textSignature,
|
||||
part.thoughtSignature,
|
||||
);
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: part.text,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (part.functionCall) {
|
||||
hasContent = true;
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
currentBlock = null;
|
||||
}
|
||||
|
||||
const providedId = part.functionCall.id;
|
||||
const needsNewId =
|
||||
!providedId ||
|
||||
output.content.some((b) => b.type === "toolCall" && b.id === providedId);
|
||||
const toolCallId = needsNewId
|
||||
? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
|
||||
: providedId;
|
||||
|
||||
const toolCall: ToolCall = {
|
||||
type: "toolCall",
|
||||
id: toolCallId,
|
||||
name: part.functionCall.name || "",
|
||||
arguments: (part.functionCall.args as Record<string, unknown>) ?? {},
|
||||
...(part.thoughtSignature && { thoughtSignature: part.thoughtSignature }),
|
||||
};
|
||||
|
||||
output.content.push(toolCall);
|
||||
ensureStarted();
|
||||
stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output });
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: JSON.stringify(toolCall.arguments),
|
||||
partial: output,
|
||||
});
|
||||
stream.push({
|
||||
type: "toolcall_end",
|
||||
contentIndex: blockIndex(),
|
||||
toolCall,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (candidate?.finishReason) {
|
||||
output.stopReason = mapStopReasonString(candidate.finishReason);
|
||||
if (output.content.some((b) => b.type === "toolCall")) {
|
||||
output.stopReason = "toolUse";
|
||||
}
|
||||
}
|
||||
|
||||
if (responseData.usageMetadata) {
|
||||
// promptTokenCount includes cachedContentTokenCount, so subtract to get fresh input
|
||||
const promptTokens = responseData.usageMetadata.promptTokenCount || 0;
|
||||
const cacheReadTokens = responseData.usageMetadata.cachedContentTokenCount || 0;
|
||||
output.usage = {
|
||||
input: promptTokens - cacheReadTokens,
|
||||
output:
|
||||
(responseData.usageMetadata.candidatesTokenCount || 0) +
|
||||
(responseData.usageMetadata.thoughtsTokenCount || 0),
|
||||
cacheRead: cacheReadTokens,
|
||||
cacheWrite: 0,
|
||||
totalTokens: responseData.usageMetadata.totalTokenCount || 0,
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
total: 0,
|
||||
},
|
||||
};
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
options?.signal?.removeEventListener("abort", abortHandler);
|
||||
}
|
||||
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
return hasContent;
|
||||
};
|
||||
|
||||
let receivedContent = false;
|
||||
let currentResponse = response;
|
||||
|
||||
for (let emptyAttempt = 0; emptyAttempt <= MAX_EMPTY_STREAM_RETRIES; emptyAttempt++) {
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (emptyAttempt > 0) {
|
||||
const backoffMs = EMPTY_STREAM_BASE_DELAY_MS * 2 ** (emptyAttempt - 1);
|
||||
await sleep(backoffMs, options?.signal);
|
||||
|
||||
if (!requestUrl) {
|
||||
throw new Error("Missing request URL");
|
||||
}
|
||||
|
||||
currentResponse = await fetch(requestUrl, {
|
||||
method: "POST",
|
||||
headers: requestHeaders,
|
||||
body: requestBodyJson,
|
||||
signal: options?.signal,
|
||||
});
|
||||
|
||||
if (!currentResponse.ok) {
|
||||
const retryErrorText = await currentResponse.text();
|
||||
throw new Error(`Cloud Code Assist API error (${currentResponse.status}): ${retryErrorText}`);
|
||||
}
|
||||
}
|
||||
|
||||
const streamed = await streamResponse(currentResponse);
|
||||
if (streamed) {
|
||||
receivedContent = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if (emptyAttempt < MAX_EMPTY_STREAM_RETRIES) {
|
||||
resetOutput();
|
||||
}
|
||||
}
|
||||
|
||||
if (!receivedContent) {
|
||||
throw new Error("Cloud Code Assist API returned an empty response");
|
||||
}
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (output.stopReason === "aborted" || output.stopReason === "error") {
|
||||
throw new Error("An unknown error occurred");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
for (const block of output.content) {
|
||||
if ("index" in block) {
|
||||
delete (block as { index?: number }).index;
|
||||
}
|
||||
}
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
export const streamSimpleGoogleGeminiCli: StreamFunction<"google-gemini-cli", SimpleStreamOptions> = (
|
||||
model: Model<"google-gemini-cli">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey;
|
||||
if (!apiKey) {
|
||||
throw new Error("Google Cloud Code Assist requires OAuth authentication. Use /login to authenticate.");
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
if (!options?.reasoning) {
|
||||
return streamGoogleGeminiCli(model, context, {
|
||||
...base,
|
||||
thinking: { enabled: false },
|
||||
} satisfies GoogleGeminiCliOptions);
|
||||
}
|
||||
|
||||
const effort = clampReasoning(options.reasoning)!;
|
||||
if (isGemini3Model(model.id)) {
|
||||
return streamGoogleGeminiCli(model, context, {
|
||||
...base,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
level: getGeminiCliThinkingLevel(effort, model.id),
|
||||
},
|
||||
} satisfies GoogleGeminiCliOptions);
|
||||
}
|
||||
|
||||
const defaultBudgets: ThinkingBudgets = {
|
||||
minimal: 1024,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 16384,
|
||||
};
|
||||
const budgets = { ...defaultBudgets, ...options.thinkingBudgets };
|
||||
|
||||
const minOutputTokens = 1024;
|
||||
let thinkingBudget = budgets[effort]!;
|
||||
const maxTokens = Math.min((base.maxTokens || 0) + thinkingBudget, model.maxTokens);
|
||||
|
||||
if (maxTokens <= thinkingBudget) {
|
||||
thinkingBudget = Math.max(0, maxTokens - minOutputTokens);
|
||||
}
|
||||
|
||||
return streamGoogleGeminiCli(model, context, {
|
||||
...base,
|
||||
maxTokens,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
budgetTokens: thinkingBudget,
|
||||
},
|
||||
} satisfies GoogleGeminiCliOptions);
|
||||
};
|
||||
|
||||
export function buildRequest(
|
||||
model: Model<"google-gemini-cli">,
|
||||
context: Context,
|
||||
projectId: string,
|
||||
options: GoogleGeminiCliOptions = {},
|
||||
isAntigravity = false,
|
||||
): CloudCodeAssistRequest {
|
||||
const contents = convertMessages(model, context);
|
||||
|
||||
const generationConfig: CloudCodeAssistRequest["request"]["generationConfig"] = {};
|
||||
if (options.temperature !== undefined) {
|
||||
generationConfig.temperature = options.temperature;
|
||||
}
|
||||
if (options.maxTokens !== undefined) {
|
||||
generationConfig.maxOutputTokens = options.maxTokens;
|
||||
}
|
||||
|
||||
// Thinking config
|
||||
if (options.thinking?.enabled && model.reasoning) {
|
||||
generationConfig.thinkingConfig = {
|
||||
includeThoughts: true,
|
||||
};
|
||||
// Gemini 3 models use thinkingLevel, older models use thinkingBudget
|
||||
if (options.thinking.level !== undefined) {
|
||||
// Cast to any since our GoogleThinkingLevel mirrors Google's ThinkingLevel enum values
|
||||
generationConfig.thinkingConfig.thinkingLevel = options.thinking.level as any;
|
||||
} else if (options.thinking.budgetTokens !== undefined) {
|
||||
generationConfig.thinkingConfig.thinkingBudget = options.thinking.budgetTokens;
|
||||
}
|
||||
}
|
||||
|
||||
const request: CloudCodeAssistRequest["request"] = {
|
||||
contents,
|
||||
};
|
||||
|
||||
request.sessionId = options.sessionId;
|
||||
|
||||
// System instruction must be object with parts, not plain string
|
||||
if (context.systemPrompt) {
|
||||
request.systemInstruction = {
|
||||
parts: [{ text: sanitizeSurrogates(context.systemPrompt) }],
|
||||
};
|
||||
}
|
||||
|
||||
if (Object.keys(generationConfig).length > 0) {
|
||||
request.generationConfig = generationConfig;
|
||||
}
|
||||
|
||||
if (context.tools && context.tools.length > 0) {
|
||||
// Claude models on Cloud Code Assist need the legacy `parameters` field;
|
||||
// the API translates it into Anthropic's `input_schema`.
|
||||
const useParameters = model.id.startsWith("claude-");
|
||||
request.tools = convertTools(context.tools, useParameters);
|
||||
if (options.toolChoice) {
|
||||
request.toolConfig = {
|
||||
functionCallingConfig: {
|
||||
mode: mapToolChoice(options.toolChoice),
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (isAntigravity) {
|
||||
const existingParts = request.systemInstruction?.parts ?? [];
|
||||
request.systemInstruction = {
|
||||
role: "user",
|
||||
parts: [
|
||||
{ text: ANTIGRAVITY_SYSTEM_INSTRUCTION },
|
||||
{ text: `Please ignore following [ignore]${ANTIGRAVITY_SYSTEM_INSTRUCTION}[/ignore]` },
|
||||
...existingParts,
|
||||
],
|
||||
};
|
||||
}
|
||||
|
||||
return {
|
||||
project: projectId,
|
||||
model: model.id,
|
||||
request,
|
||||
...(isAntigravity ? { requestType: "agent" } : {}),
|
||||
userAgent: isAntigravity ? "antigravity" : "pi-coding-agent",
|
||||
requestId: `${isAntigravity ? "agent" : "pi"}-${Date.now()}-${Math.random().toString(36).slice(2, 11)}`,
|
||||
};
|
||||
}
|
||||
|
||||
type ClampedThinkingLevel = Exclude<ThinkingLevel, "xhigh">;
|
||||
|
||||
function getGeminiCliThinkingLevel(effort: ClampedThinkingLevel, modelId: string): GoogleThinkingLevel {
|
||||
if (isGemini3ProModel(modelId)) {
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
case "high":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
return "MINIMAL";
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
return "MEDIUM";
|
||||
case "high":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
313
packages/pi-ai/src/providers/google-shared.ts
Normal file
313
packages/pi-ai/src/providers/google-shared.ts
Normal file
|
|
@ -0,0 +1,313 @@
|
|||
/**
|
||||
* Shared utilities for Google Generative AI and Google Cloud Code Assist providers.
|
||||
*/
|
||||
|
||||
import { type Content, FinishReason, FunctionCallingConfigMode, type Part } from "@google/genai";
|
||||
import type { Context, ImageContent, Model, StopReason, TextContent, Tool } from "../types.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import { transformMessages } from "./transform-messages.js";
|
||||
|
||||
type GoogleApiType = "google-generative-ai" | "google-gemini-cli" | "google-vertex";
|
||||
|
||||
/**
|
||||
* Determines whether a streamed Gemini `Part` should be treated as "thinking".
|
||||
*
|
||||
* Protocol note (Gemini / Vertex AI thought signatures):
|
||||
* - `thought: true` is the definitive marker for thinking content (thought summaries).
|
||||
* - `thoughtSignature` is an encrypted representation of the model's internal thought process
|
||||
* used to preserve reasoning context across multi-turn interactions.
|
||||
* - `thoughtSignature` can appear on ANY part type (text, functionCall, etc.) - it does NOT
|
||||
* indicate the part itself is thinking content.
|
||||
* - For non-functionCall responses, the signature appears on the last part for context replay.
|
||||
* - When persisting/replaying model outputs, signature-bearing parts must be preserved as-is;
|
||||
* do not merge/move signatures across parts.
|
||||
*
|
||||
* See: https://ai.google.dev/gemini-api/docs/thought-signatures
|
||||
*/
|
||||
export function isThinkingPart(part: Pick<Part, "thought" | "thoughtSignature">): boolean {
|
||||
return part.thought === true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Retain thought signatures during streaming.
|
||||
*
|
||||
* Some backends only send `thoughtSignature` on the first delta for a given part/block; later deltas may omit it.
|
||||
* This helper preserves the last non-empty signature for the current block.
|
||||
*
|
||||
* Note: this does NOT merge or move signatures across distinct response parts. It only prevents
|
||||
* a signature from being overwritten with `undefined` within the same streamed block.
|
||||
*/
|
||||
export function retainThoughtSignature(existing: string | undefined, incoming: string | undefined): string | undefined {
|
||||
if (typeof incoming === "string" && incoming.length > 0) return incoming;
|
||||
return existing;
|
||||
}
|
||||
|
||||
// Thought signatures must be base64 for Google APIs (TYPE_BYTES).
|
||||
const base64SignaturePattern = /^[A-Za-z0-9+/]+={0,2}$/;
|
||||
|
||||
// Sentinel value that tells the Gemini API to skip thought signature validation.
|
||||
// Used for unsigned function call parts (e.g. replayed from providers without thought signatures).
|
||||
// See: https://ai.google.dev/gemini-api/docs/thought-signatures
|
||||
const SKIP_THOUGHT_SIGNATURE = "skip_thought_signature_validator";
|
||||
|
||||
function isValidThoughtSignature(signature: string | undefined): boolean {
|
||||
if (!signature) return false;
|
||||
if (signature.length % 4 !== 0) return false;
|
||||
return base64SignaturePattern.test(signature);
|
||||
}
|
||||
|
||||
/**
|
||||
* Only keep signatures from the same provider/model and with valid base64.
|
||||
*/
|
||||
function resolveThoughtSignature(isSameProviderAndModel: boolean, signature: string | undefined): string | undefined {
|
||||
return isSameProviderAndModel && isValidThoughtSignature(signature) ? signature : undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Models via Google APIs that require explicit tool call IDs in function calls/responses.
|
||||
*/
|
||||
export function requiresToolCallId(modelId: string): boolean {
|
||||
return modelId.startsWith("claude-") || modelId.startsWith("gpt-oss-");
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert internal messages to Gemini Content[] format.
|
||||
*/
|
||||
export function convertMessages<T extends GoogleApiType>(model: Model<T>, context: Context): Content[] {
|
||||
const contents: Content[] = [];
|
||||
const normalizeToolCallId = (id: string): string => {
|
||||
if (!requiresToolCallId(model.id)) return id;
|
||||
return id.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 64);
|
||||
};
|
||||
|
||||
const transformedMessages = transformMessages(context.messages, model, normalizeToolCallId);
|
||||
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
if (typeof msg.content === "string") {
|
||||
contents.push({
|
||||
role: "user",
|
||||
parts: [{ text: sanitizeSurrogates(msg.content) }],
|
||||
});
|
||||
} else {
|
||||
const parts: Part[] = msg.content.map((item) => {
|
||||
if (item.type === "text") {
|
||||
return { text: sanitizeSurrogates(item.text) };
|
||||
} else {
|
||||
return {
|
||||
inlineData: {
|
||||
mimeType: item.mimeType,
|
||||
data: item.data,
|
||||
},
|
||||
};
|
||||
}
|
||||
});
|
||||
const filteredParts = !model.input.includes("image") ? parts.filter((p) => p.text !== undefined) : parts;
|
||||
if (filteredParts.length === 0) continue;
|
||||
contents.push({
|
||||
role: "user",
|
||||
parts: filteredParts,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
const parts: Part[] = [];
|
||||
// Check if message is from same provider and model - only then keep thinking blocks
|
||||
const isSameProviderAndModel = msg.provider === model.provider && msg.model === model.id;
|
||||
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "text") {
|
||||
// Skip empty text blocks - they can cause issues with some models (e.g. Claude via Antigravity)
|
||||
if (!block.text || block.text.trim() === "") continue;
|
||||
const thoughtSignature = resolveThoughtSignature(isSameProviderAndModel, block.textSignature);
|
||||
parts.push({
|
||||
text: sanitizeSurrogates(block.text),
|
||||
...(thoughtSignature && { thoughtSignature }),
|
||||
});
|
||||
} else if (block.type === "thinking") {
|
||||
// Skip empty thinking blocks
|
||||
if (!block.thinking || block.thinking.trim() === "") continue;
|
||||
// Only keep as thinking block if same provider AND same model
|
||||
// Otherwise convert to plain text (no tags to avoid model mimicking them)
|
||||
if (isSameProviderAndModel) {
|
||||
const thoughtSignature = resolveThoughtSignature(isSameProviderAndModel, block.thinkingSignature);
|
||||
parts.push({
|
||||
thought: true,
|
||||
text: sanitizeSurrogates(block.thinking),
|
||||
...(thoughtSignature && { thoughtSignature }),
|
||||
});
|
||||
} else {
|
||||
parts.push({
|
||||
text: sanitizeSurrogates(block.thinking),
|
||||
});
|
||||
}
|
||||
} else if (block.type === "toolCall") {
|
||||
const thoughtSignature = resolveThoughtSignature(isSameProviderAndModel, block.thoughtSignature);
|
||||
// Gemini 3 requires thoughtSignature on all function calls when thinking mode is enabled.
|
||||
// Use the skip_thought_signature_validator sentinel for unsigned function calls
|
||||
// (e.g. replayed from providers without thought signatures like Claude via Antigravity).
|
||||
const isGemini3 = model.id.toLowerCase().includes("gemini-3");
|
||||
const effectiveSignature = thoughtSignature || (isGemini3 ? SKIP_THOUGHT_SIGNATURE : undefined);
|
||||
const part: Part = {
|
||||
functionCall: {
|
||||
name: block.name,
|
||||
args: block.arguments ?? {},
|
||||
...(requiresToolCallId(model.id) ? { id: block.id } : {}),
|
||||
},
|
||||
...(effectiveSignature && { thoughtSignature: effectiveSignature }),
|
||||
};
|
||||
parts.push(part);
|
||||
}
|
||||
}
|
||||
|
||||
if (parts.length === 0) continue;
|
||||
contents.push({
|
||||
role: "model",
|
||||
parts,
|
||||
});
|
||||
} else if (msg.role === "toolResult") {
|
||||
// Extract text and image content
|
||||
const textContent = msg.content.filter((c): c is TextContent => c.type === "text");
|
||||
const textResult = textContent.map((c) => c.text).join("\n");
|
||||
const imageContent = model.input.includes("image")
|
||||
? msg.content.filter((c): c is ImageContent => c.type === "image")
|
||||
: [];
|
||||
|
||||
const hasText = textResult.length > 0;
|
||||
const hasImages = imageContent.length > 0;
|
||||
|
||||
// Gemini 3 supports multimodal function responses with images nested inside functionResponse.parts
|
||||
// See: https://ai.google.dev/gemini-api/docs/function-calling#multimodal
|
||||
// Older models don't support this, so we put images in a separate user message.
|
||||
const supportsMultimodalFunctionResponse = model.id.includes("gemini-3");
|
||||
|
||||
// Use "output" key for success, "error" key for errors as per SDK documentation
|
||||
const responseValue = hasText ? sanitizeSurrogates(textResult) : hasImages ? "(see attached image)" : "";
|
||||
|
||||
const imageParts: Part[] = imageContent.map((imageBlock) => ({
|
||||
inlineData: {
|
||||
mimeType: imageBlock.mimeType,
|
||||
data: imageBlock.data,
|
||||
},
|
||||
}));
|
||||
|
||||
const includeId = requiresToolCallId(model.id);
|
||||
const functionResponsePart: Part = {
|
||||
functionResponse: {
|
||||
name: msg.toolName,
|
||||
response: msg.isError ? { error: responseValue } : { output: responseValue },
|
||||
// Nest images inside functionResponse.parts for Gemini 3
|
||||
...(hasImages && supportsMultimodalFunctionResponse && { parts: imageParts }),
|
||||
...(includeId ? { id: msg.toolCallId } : {}),
|
||||
},
|
||||
};
|
||||
|
||||
// Cloud Code Assist API requires all function responses to be in a single user turn.
|
||||
// Check if the last content is already a user turn with function responses and merge.
|
||||
const lastContent = contents[contents.length - 1];
|
||||
if (lastContent?.role === "user" && lastContent.parts?.some((p) => p.functionResponse)) {
|
||||
lastContent.parts.push(functionResponsePart);
|
||||
} else {
|
||||
contents.push({
|
||||
role: "user",
|
||||
parts: [functionResponsePart],
|
||||
});
|
||||
}
|
||||
|
||||
// For older models, add images in a separate user message
|
||||
if (hasImages && !supportsMultimodalFunctionResponse) {
|
||||
contents.push({
|
||||
role: "user",
|
||||
parts: [{ text: "Tool result image:" }, ...imageParts],
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return contents;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert tools to Gemini function declarations format.
|
||||
*
|
||||
* By default uses `parametersJsonSchema` which supports full JSON Schema (including
|
||||
* anyOf, oneOf, const, etc.). Set `useParameters` to true to use the legacy `parameters`
|
||||
* field instead (OpenAPI 3.03 Schema). This is needed for Cloud Code Assist with Claude
|
||||
* models, where the API translates `parameters` into Anthropic's `input_schema`.
|
||||
*/
|
||||
export function convertTools(
|
||||
tools: Tool[],
|
||||
useParameters = false,
|
||||
): { functionDeclarations: Record<string, unknown>[] }[] | undefined {
|
||||
if (tools.length === 0) return undefined;
|
||||
return [
|
||||
{
|
||||
functionDeclarations: tools.map((tool) => ({
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
...(useParameters ? { parameters: tool.parameters } : { parametersJsonSchema: tool.parameters }),
|
||||
})),
|
||||
},
|
||||
];
|
||||
}
|
||||
|
||||
/**
|
||||
* Map tool choice string to Gemini FunctionCallingConfigMode.
|
||||
*/
|
||||
export function mapToolChoice(choice: string): FunctionCallingConfigMode {
|
||||
switch (choice) {
|
||||
case "auto":
|
||||
return FunctionCallingConfigMode.AUTO;
|
||||
case "none":
|
||||
return FunctionCallingConfigMode.NONE;
|
||||
case "any":
|
||||
return FunctionCallingConfigMode.ANY;
|
||||
default:
|
||||
return FunctionCallingConfigMode.AUTO;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Map Gemini FinishReason to our StopReason.
|
||||
*/
|
||||
export function mapStopReason(reason: FinishReason): StopReason {
|
||||
switch (reason) {
|
||||
case FinishReason.STOP:
|
||||
return "stop";
|
||||
case FinishReason.MAX_TOKENS:
|
||||
return "length";
|
||||
case FinishReason.BLOCKLIST:
|
||||
case FinishReason.PROHIBITED_CONTENT:
|
||||
case FinishReason.SPII:
|
||||
case FinishReason.SAFETY:
|
||||
case FinishReason.IMAGE_SAFETY:
|
||||
case FinishReason.IMAGE_PROHIBITED_CONTENT:
|
||||
case FinishReason.IMAGE_RECITATION:
|
||||
case FinishReason.IMAGE_OTHER:
|
||||
case FinishReason.RECITATION:
|
||||
case FinishReason.FINISH_REASON_UNSPECIFIED:
|
||||
case FinishReason.OTHER:
|
||||
case FinishReason.LANGUAGE:
|
||||
case FinishReason.MALFORMED_FUNCTION_CALL:
|
||||
case FinishReason.UNEXPECTED_TOOL_CALL:
|
||||
case FinishReason.NO_IMAGE:
|
||||
return "error";
|
||||
default: {
|
||||
const _exhaustive: never = reason;
|
||||
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Map string finish reason to our StopReason (for raw API responses).
|
||||
*/
|
||||
export function mapStopReasonString(reason: string): StopReason {
|
||||
switch (reason) {
|
||||
case "STOP":
|
||||
return "stop";
|
||||
case "MAX_TOKENS":
|
||||
return "length";
|
||||
default:
|
||||
return "error";
|
||||
}
|
||||
}
|
||||
485
packages/pi-ai/src/providers/google-vertex.ts
Normal file
485
packages/pi-ai/src/providers/google-vertex.ts
Normal file
|
|
@ -0,0 +1,485 @@
|
|||
import {
|
||||
type GenerateContentConfig,
|
||||
type GenerateContentParameters,
|
||||
GoogleGenAI,
|
||||
type ThinkingConfig,
|
||||
ThinkingLevel,
|
||||
} from "@google/genai";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Model,
|
||||
ThinkingLevel as PiThinkingLevel,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingBudgets,
|
||||
ThinkingContent,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import type { GoogleThinkingLevel } from "./google-gemini-cli.js";
|
||||
import {
|
||||
convertMessages,
|
||||
convertTools,
|
||||
isThinkingPart,
|
||||
mapStopReason,
|
||||
mapToolChoice,
|
||||
retainThoughtSignature,
|
||||
} from "./google-shared.js";
|
||||
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
|
||||
export interface GoogleVertexOptions extends StreamOptions {
|
||||
toolChoice?: "auto" | "none" | "any";
|
||||
thinking?: {
|
||||
enabled: boolean;
|
||||
budgetTokens?: number; // -1 for dynamic, 0 to disable
|
||||
level?: GoogleThinkingLevel;
|
||||
};
|
||||
project?: string;
|
||||
location?: string;
|
||||
}
|
||||
|
||||
const API_VERSION = "v1";
|
||||
|
||||
const THINKING_LEVEL_MAP: Record<GoogleThinkingLevel, ThinkingLevel> = {
|
||||
THINKING_LEVEL_UNSPECIFIED: ThinkingLevel.THINKING_LEVEL_UNSPECIFIED,
|
||||
MINIMAL: ThinkingLevel.MINIMAL,
|
||||
LOW: ThinkingLevel.LOW,
|
||||
MEDIUM: ThinkingLevel.MEDIUM,
|
||||
HIGH: ThinkingLevel.HIGH,
|
||||
};
|
||||
|
||||
// Counter for generating unique tool call IDs
|
||||
let toolCallCounter = 0;
|
||||
|
||||
export const streamGoogleVertex: StreamFunction<"google-vertex", GoogleVertexOptions> = (
|
||||
model: Model<"google-vertex">,
|
||||
context: Context,
|
||||
options?: GoogleVertexOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: "google-vertex" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
try {
|
||||
const project = resolveProject(options);
|
||||
const location = resolveLocation(options);
|
||||
const client = createClient(model, project, location, options?.headers);
|
||||
let params = buildParams(model, context, options);
|
||||
const nextParams = await options?.onPayload?.(params, model);
|
||||
if (nextParams !== undefined) {
|
||||
params = nextParams as GenerateContentParameters;
|
||||
}
|
||||
const googleStream = await client.models.generateContentStream(params);
|
||||
|
||||
stream.push({ type: "start", partial: output });
|
||||
let currentBlock: TextContent | ThinkingContent | null = null;
|
||||
const blocks = output.content;
|
||||
const blockIndex = () => blocks.length - 1;
|
||||
for await (const chunk of googleStream) {
|
||||
const candidate = chunk.candidates?.[0];
|
||||
if (candidate?.content?.parts) {
|
||||
for (const part of candidate.content.parts) {
|
||||
if (part.text !== undefined) {
|
||||
const isThinking = isThinkingPart(part);
|
||||
if (
|
||||
!currentBlock ||
|
||||
(isThinking && currentBlock.type !== "thinking") ||
|
||||
(!isThinking && currentBlock.type !== "text")
|
||||
) {
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blocks.length - 1,
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
if (isThinking) {
|
||||
currentBlock = { type: "thinking", thinking: "", thinkingSignature: undefined };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output });
|
||||
} else {
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
|
||||
}
|
||||
}
|
||||
if (currentBlock.type === "thinking") {
|
||||
currentBlock.thinking += part.text;
|
||||
currentBlock.thinkingSignature = retainThoughtSignature(
|
||||
currentBlock.thinkingSignature,
|
||||
part.thoughtSignature,
|
||||
);
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: part.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
currentBlock.text += part.text;
|
||||
currentBlock.textSignature = retainThoughtSignature(
|
||||
currentBlock.textSignature,
|
||||
part.thoughtSignature,
|
||||
);
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: part.text,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (part.functionCall) {
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
currentBlock = null;
|
||||
}
|
||||
|
||||
const providedId = part.functionCall.id;
|
||||
const needsNewId =
|
||||
!providedId || output.content.some((b) => b.type === "toolCall" && b.id === providedId);
|
||||
const toolCallId = needsNewId
|
||||
? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
|
||||
: providedId;
|
||||
|
||||
const toolCall: ToolCall = {
|
||||
type: "toolCall",
|
||||
id: toolCallId,
|
||||
name: part.functionCall.name || "",
|
||||
arguments: (part.functionCall.args as Record<string, any>) ?? {},
|
||||
...(part.thoughtSignature && { thoughtSignature: part.thoughtSignature }),
|
||||
};
|
||||
|
||||
output.content.push(toolCall);
|
||||
stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output });
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: JSON.stringify(toolCall.arguments),
|
||||
partial: output,
|
||||
});
|
||||
stream.push({ type: "toolcall_end", contentIndex: blockIndex(), toolCall, partial: output });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (candidate?.finishReason) {
|
||||
output.stopReason = mapStopReason(candidate.finishReason);
|
||||
if (output.content.some((b) => b.type === "toolCall")) {
|
||||
output.stopReason = "toolUse";
|
||||
}
|
||||
}
|
||||
|
||||
if (chunk.usageMetadata) {
|
||||
output.usage = {
|
||||
input: chunk.usageMetadata.promptTokenCount || 0,
|
||||
output:
|
||||
(chunk.usageMetadata.candidatesTokenCount || 0) + (chunk.usageMetadata.thoughtsTokenCount || 0),
|
||||
cacheRead: chunk.usageMetadata.cachedContentTokenCount || 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: chunk.usageMetadata.totalTokenCount || 0,
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
total: 0,
|
||||
},
|
||||
};
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
}
|
||||
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (output.stopReason === "aborted" || output.stopReason === "error") {
|
||||
throw new Error("An unknown error occurred");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
// Remove internal index property used during streaming
|
||||
for (const block of output.content) {
|
||||
if ("index" in block) {
|
||||
delete (block as { index?: number }).index;
|
||||
}
|
||||
}
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
export const streamSimpleGoogleVertex: StreamFunction<"google-vertex", SimpleStreamOptions> = (
|
||||
model: Model<"google-vertex">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const base = buildBaseOptions(model, options, undefined);
|
||||
if (!options?.reasoning) {
|
||||
return streamGoogleVertex(model, context, {
|
||||
...base,
|
||||
thinking: { enabled: false },
|
||||
} satisfies GoogleVertexOptions);
|
||||
}
|
||||
|
||||
const effort = clampReasoning(options.reasoning)!;
|
||||
const geminiModel = model as unknown as Model<"google-generative-ai">;
|
||||
|
||||
if (isGemini3ProModel(geminiModel) || isGemini3FlashModel(geminiModel)) {
|
||||
return streamGoogleVertex(model, context, {
|
||||
...base,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
level: getGemini3ThinkingLevel(effort, geminiModel),
|
||||
},
|
||||
} satisfies GoogleVertexOptions);
|
||||
}
|
||||
|
||||
return streamGoogleVertex(model, context, {
|
||||
...base,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
budgetTokens: getGoogleBudget(geminiModel, effort, options.thinkingBudgets),
|
||||
},
|
||||
} satisfies GoogleVertexOptions);
|
||||
};
|
||||
|
||||
function createClient(
|
||||
model: Model<"google-vertex">,
|
||||
project: string,
|
||||
location: string,
|
||||
optionsHeaders?: Record<string, string>,
|
||||
): GoogleGenAI {
|
||||
const httpOptions: { headers?: Record<string, string> } = {};
|
||||
|
||||
if (model.headers || optionsHeaders) {
|
||||
httpOptions.headers = { ...model.headers, ...optionsHeaders };
|
||||
}
|
||||
|
||||
const hasHttpOptions = Object.values(httpOptions).some(Boolean);
|
||||
|
||||
return new GoogleGenAI({
|
||||
vertexai: true,
|
||||
project,
|
||||
location,
|
||||
apiVersion: API_VERSION,
|
||||
httpOptions: hasHttpOptions ? httpOptions : undefined,
|
||||
});
|
||||
}
|
||||
|
||||
function resolveProject(options?: GoogleVertexOptions): string {
|
||||
const project = options?.project || process.env.GOOGLE_CLOUD_PROJECT || process.env.GCLOUD_PROJECT;
|
||||
if (!project) {
|
||||
throw new Error(
|
||||
"Vertex AI requires a project ID. Set GOOGLE_CLOUD_PROJECT/GCLOUD_PROJECT or pass project in options.",
|
||||
);
|
||||
}
|
||||
return project;
|
||||
}
|
||||
|
||||
function resolveLocation(options?: GoogleVertexOptions): string {
|
||||
const location = options?.location || process.env.GOOGLE_CLOUD_LOCATION;
|
||||
if (!location) {
|
||||
throw new Error("Vertex AI requires a location. Set GOOGLE_CLOUD_LOCATION or pass location in options.");
|
||||
}
|
||||
return location;
|
||||
}
|
||||
|
||||
function buildParams(
|
||||
model: Model<"google-vertex">,
|
||||
context: Context,
|
||||
options: GoogleVertexOptions = {},
|
||||
): GenerateContentParameters {
|
||||
const contents = convertMessages(model, context);
|
||||
|
||||
const generationConfig: GenerateContentConfig = {};
|
||||
if (options.temperature !== undefined) {
|
||||
generationConfig.temperature = options.temperature;
|
||||
}
|
||||
if (options.maxTokens !== undefined) {
|
||||
generationConfig.maxOutputTokens = options.maxTokens;
|
||||
}
|
||||
|
||||
const config: GenerateContentConfig = {
|
||||
...(Object.keys(generationConfig).length > 0 && generationConfig),
|
||||
...(context.systemPrompt && { systemInstruction: sanitizeSurrogates(context.systemPrompt) }),
|
||||
...(context.tools && context.tools.length > 0 && { tools: convertTools(context.tools) }),
|
||||
};
|
||||
|
||||
if (context.tools && context.tools.length > 0 && options.toolChoice) {
|
||||
config.toolConfig = {
|
||||
functionCallingConfig: {
|
||||
mode: mapToolChoice(options.toolChoice),
|
||||
},
|
||||
};
|
||||
} else {
|
||||
config.toolConfig = undefined;
|
||||
}
|
||||
|
||||
if (options.thinking?.enabled && model.reasoning) {
|
||||
const thinkingConfig: ThinkingConfig = { includeThoughts: true };
|
||||
if (options.thinking.level !== undefined) {
|
||||
thinkingConfig.thinkingLevel = THINKING_LEVEL_MAP[options.thinking.level];
|
||||
} else if (options.thinking.budgetTokens !== undefined) {
|
||||
thinkingConfig.thinkingBudget = options.thinking.budgetTokens;
|
||||
}
|
||||
config.thinkingConfig = thinkingConfig;
|
||||
}
|
||||
|
||||
if (options.signal) {
|
||||
if (options.signal.aborted) {
|
||||
throw new Error("Request aborted");
|
||||
}
|
||||
config.abortSignal = options.signal;
|
||||
}
|
||||
|
||||
const params: GenerateContentParameters = {
|
||||
model: model.id,
|
||||
contents,
|
||||
config,
|
||||
};
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
type ClampedThinkingLevel = Exclude<PiThinkingLevel, "xhigh">;
|
||||
|
||||
function isGemini3ProModel(model: Model<"google-generative-ai">): boolean {
|
||||
return /gemini-3(?:\.\d+)?-pro/.test(model.id.toLowerCase());
|
||||
}
|
||||
|
||||
function isGemini3FlashModel(model: Model<"google-generative-ai">): boolean {
|
||||
return /gemini-3(?:\.\d+)?-flash/.test(model.id.toLowerCase());
|
||||
}
|
||||
|
||||
function getGemini3ThinkingLevel(
|
||||
effort: ClampedThinkingLevel,
|
||||
model: Model<"google-generative-ai">,
|
||||
): GoogleThinkingLevel {
|
||||
if (isGemini3ProModel(model)) {
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
case "high":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
return "MINIMAL";
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
return "MEDIUM";
|
||||
case "high":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
|
||||
function getGoogleBudget(
|
||||
model: Model<"google-generative-ai">,
|
||||
effort: ClampedThinkingLevel,
|
||||
customBudgets?: ThinkingBudgets,
|
||||
): number {
|
||||
if (customBudgets?.[effort] !== undefined) {
|
||||
return customBudgets[effort]!;
|
||||
}
|
||||
|
||||
if (model.id.includes("2.5-pro")) {
|
||||
const budgets: Record<ClampedThinkingLevel, number> = {
|
||||
minimal: 128,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 32768,
|
||||
};
|
||||
return budgets[effort];
|
||||
}
|
||||
|
||||
if (model.id.includes("2.5-flash")) {
|
||||
const budgets: Record<ClampedThinkingLevel, number> = {
|
||||
minimal: 128,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 24576,
|
||||
};
|
||||
return budgets[effort];
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
455
packages/pi-ai/src/providers/google.ts
Normal file
455
packages/pi-ai/src/providers/google.ts
Normal file
|
|
@ -0,0 +1,455 @@
|
|||
import {
|
||||
type GenerateContentConfig,
|
||||
type GenerateContentParameters,
|
||||
GoogleGenAI,
|
||||
type ThinkingConfig,
|
||||
} from "@google/genai";
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingBudgets,
|
||||
ThinkingContent,
|
||||
ThinkingLevel,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import type { GoogleThinkingLevel } from "./google-gemini-cli.js";
|
||||
import {
|
||||
convertMessages,
|
||||
convertTools,
|
||||
isThinkingPart,
|
||||
mapStopReason,
|
||||
mapToolChoice,
|
||||
retainThoughtSignature,
|
||||
} from "./google-shared.js";
|
||||
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
|
||||
export interface GoogleOptions extends StreamOptions {
|
||||
toolChoice?: "auto" | "none" | "any";
|
||||
thinking?: {
|
||||
enabled: boolean;
|
||||
budgetTokens?: number; // -1 for dynamic, 0 to disable
|
||||
level?: GoogleThinkingLevel;
|
||||
};
|
||||
}
|
||||
|
||||
// Counter for generating unique tool call IDs
|
||||
let toolCallCounter = 0;
|
||||
|
||||
export const streamGoogle: StreamFunction<"google-generative-ai", GoogleOptions> = (
|
||||
model: Model<"google-generative-ai">,
|
||||
context: Context,
|
||||
options?: GoogleOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: "google-generative-ai" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
try {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
|
||||
const client = createClient(model, apiKey, options?.headers);
|
||||
let params = buildParams(model, context, options);
|
||||
const nextParams = await options?.onPayload?.(params, model);
|
||||
if (nextParams !== undefined) {
|
||||
params = nextParams as GenerateContentParameters;
|
||||
}
|
||||
const googleStream = await client.models.generateContentStream(params);
|
||||
|
||||
stream.push({ type: "start", partial: output });
|
||||
let currentBlock: TextContent | ThinkingContent | null = null;
|
||||
const blocks = output.content;
|
||||
const blockIndex = () => blocks.length - 1;
|
||||
for await (const chunk of googleStream) {
|
||||
const candidate = chunk.candidates?.[0];
|
||||
if (candidate?.content?.parts) {
|
||||
for (const part of candidate.content.parts) {
|
||||
if (part.text !== undefined) {
|
||||
const isThinking = isThinkingPart(part);
|
||||
if (
|
||||
!currentBlock ||
|
||||
(isThinking && currentBlock.type !== "thinking") ||
|
||||
(!isThinking && currentBlock.type !== "text")
|
||||
) {
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blocks.length - 1,
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
if (isThinking) {
|
||||
currentBlock = { type: "thinking", thinking: "", thinkingSignature: undefined };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output });
|
||||
} else {
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
|
||||
}
|
||||
}
|
||||
if (currentBlock.type === "thinking") {
|
||||
currentBlock.thinking += part.text;
|
||||
currentBlock.thinkingSignature = retainThoughtSignature(
|
||||
currentBlock.thinkingSignature,
|
||||
part.thoughtSignature,
|
||||
);
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: part.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
currentBlock.text += part.text;
|
||||
currentBlock.textSignature = retainThoughtSignature(
|
||||
currentBlock.textSignature,
|
||||
part.thoughtSignature,
|
||||
);
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: part.text,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (part.functionCall) {
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
currentBlock = null;
|
||||
}
|
||||
|
||||
// Generate unique ID if not provided or if it's a duplicate
|
||||
const providedId = part.functionCall.id;
|
||||
const needsNewId =
|
||||
!providedId || output.content.some((b) => b.type === "toolCall" && b.id === providedId);
|
||||
const toolCallId = needsNewId
|
||||
? `${part.functionCall.name}_${Date.now()}_${++toolCallCounter}`
|
||||
: providedId;
|
||||
|
||||
const toolCall: ToolCall = {
|
||||
type: "toolCall",
|
||||
id: toolCallId,
|
||||
name: part.functionCall.name || "",
|
||||
arguments: (part.functionCall.args as Record<string, any>) ?? {},
|
||||
...(part.thoughtSignature && { thoughtSignature: part.thoughtSignature }),
|
||||
};
|
||||
|
||||
output.content.push(toolCall);
|
||||
stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output });
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: JSON.stringify(toolCall.arguments),
|
||||
partial: output,
|
||||
});
|
||||
stream.push({ type: "toolcall_end", contentIndex: blockIndex(), toolCall, partial: output });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (candidate?.finishReason) {
|
||||
output.stopReason = mapStopReason(candidate.finishReason);
|
||||
if (output.content.some((b) => b.type === "toolCall")) {
|
||||
output.stopReason = "toolUse";
|
||||
}
|
||||
}
|
||||
|
||||
if (chunk.usageMetadata) {
|
||||
output.usage = {
|
||||
input: chunk.usageMetadata.promptTokenCount || 0,
|
||||
output:
|
||||
(chunk.usageMetadata.candidatesTokenCount || 0) + (chunk.usageMetadata.thoughtsTokenCount || 0),
|
||||
cacheRead: chunk.usageMetadata.cachedContentTokenCount || 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: chunk.usageMetadata.totalTokenCount || 0,
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
total: 0,
|
||||
},
|
||||
};
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
}
|
||||
|
||||
if (currentBlock) {
|
||||
if (currentBlock.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
} else {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (output.stopReason === "aborted" || output.stopReason === "error") {
|
||||
throw new Error("An unknown error occurred");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
// Remove internal index property used during streaming
|
||||
for (const block of output.content) {
|
||||
if ("index" in block) {
|
||||
delete (block as { index?: number }).index;
|
||||
}
|
||||
}
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
export const streamSimpleGoogle: StreamFunction<"google-generative-ai", SimpleStreamOptions> = (
|
||||
model: Model<"google-generative-ai">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
if (!options?.reasoning) {
|
||||
return streamGoogle(model, context, { ...base, thinking: { enabled: false } } satisfies GoogleOptions);
|
||||
}
|
||||
|
||||
const effort = clampReasoning(options.reasoning)!;
|
||||
const googleModel = model as Model<"google-generative-ai">;
|
||||
|
||||
if (isGemini3ProModel(googleModel) || isGemini3FlashModel(googleModel)) {
|
||||
return streamGoogle(model, context, {
|
||||
...base,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
level: getGemini3ThinkingLevel(effort, googleModel),
|
||||
},
|
||||
} satisfies GoogleOptions);
|
||||
}
|
||||
|
||||
return streamGoogle(model, context, {
|
||||
...base,
|
||||
thinking: {
|
||||
enabled: true,
|
||||
budgetTokens: getGoogleBudget(googleModel, effort, options.thinkingBudgets),
|
||||
},
|
||||
} satisfies GoogleOptions);
|
||||
};
|
||||
|
||||
function createClient(
|
||||
model: Model<"google-generative-ai">,
|
||||
apiKey?: string,
|
||||
optionsHeaders?: Record<string, string>,
|
||||
): GoogleGenAI {
|
||||
const httpOptions: { baseUrl?: string; apiVersion?: string; headers?: Record<string, string> } = {};
|
||||
if (model.baseUrl) {
|
||||
httpOptions.baseUrl = model.baseUrl;
|
||||
httpOptions.apiVersion = ""; // baseUrl already includes version path, don't append
|
||||
}
|
||||
if (model.headers || optionsHeaders) {
|
||||
httpOptions.headers = { ...model.headers, ...optionsHeaders };
|
||||
}
|
||||
|
||||
return new GoogleGenAI({
|
||||
apiKey,
|
||||
httpOptions: Object.keys(httpOptions).length > 0 ? httpOptions : undefined,
|
||||
});
|
||||
}
|
||||
|
||||
function buildParams(
|
||||
model: Model<"google-generative-ai">,
|
||||
context: Context,
|
||||
options: GoogleOptions = {},
|
||||
): GenerateContentParameters {
|
||||
const contents = convertMessages(model, context);
|
||||
|
||||
const generationConfig: GenerateContentConfig = {};
|
||||
if (options.temperature !== undefined) {
|
||||
generationConfig.temperature = options.temperature;
|
||||
}
|
||||
if (options.maxTokens !== undefined) {
|
||||
generationConfig.maxOutputTokens = options.maxTokens;
|
||||
}
|
||||
|
||||
const config: GenerateContentConfig = {
|
||||
...(Object.keys(generationConfig).length > 0 && generationConfig),
|
||||
...(context.systemPrompt && { systemInstruction: sanitizeSurrogates(context.systemPrompt) }),
|
||||
...(context.tools && context.tools.length > 0 && { tools: convertTools(context.tools) }),
|
||||
};
|
||||
|
||||
if (context.tools && context.tools.length > 0 && options.toolChoice) {
|
||||
config.toolConfig = {
|
||||
functionCallingConfig: {
|
||||
mode: mapToolChoice(options.toolChoice),
|
||||
},
|
||||
};
|
||||
} else {
|
||||
config.toolConfig = undefined;
|
||||
}
|
||||
|
||||
if (options.thinking?.enabled && model.reasoning) {
|
||||
const thinkingConfig: ThinkingConfig = { includeThoughts: true };
|
||||
if (options.thinking.level !== undefined) {
|
||||
// Cast to any since our GoogleThinkingLevel mirrors Google's ThinkingLevel enum values
|
||||
thinkingConfig.thinkingLevel = options.thinking.level as any;
|
||||
} else if (options.thinking.budgetTokens !== undefined) {
|
||||
thinkingConfig.thinkingBudget = options.thinking.budgetTokens;
|
||||
}
|
||||
config.thinkingConfig = thinkingConfig;
|
||||
}
|
||||
|
||||
if (options.signal) {
|
||||
if (options.signal.aborted) {
|
||||
throw new Error("Request aborted");
|
||||
}
|
||||
config.abortSignal = options.signal;
|
||||
}
|
||||
|
||||
const params: GenerateContentParameters = {
|
||||
model: model.id,
|
||||
contents,
|
||||
config,
|
||||
};
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
type ClampedThinkingLevel = Exclude<ThinkingLevel, "xhigh">;
|
||||
|
||||
function isGemini3ProModel(model: Model<"google-generative-ai">): boolean {
|
||||
return /gemini-3(?:\.\d+)?-pro/.test(model.id.toLowerCase());
|
||||
}
|
||||
|
||||
function isGemini3FlashModel(model: Model<"google-generative-ai">): boolean {
|
||||
return /gemini-3(?:\.\d+)?-flash/.test(model.id.toLowerCase());
|
||||
}
|
||||
|
||||
function getGemini3ThinkingLevel(
|
||||
effort: ClampedThinkingLevel,
|
||||
model: Model<"google-generative-ai">,
|
||||
): GoogleThinkingLevel {
|
||||
if (isGemini3ProModel(model)) {
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
case "high":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
switch (effort) {
|
||||
case "minimal":
|
||||
return "MINIMAL";
|
||||
case "low":
|
||||
return "LOW";
|
||||
case "medium":
|
||||
return "MEDIUM";
|
||||
case "high":
|
||||
return "HIGH";
|
||||
}
|
||||
}
|
||||
|
||||
function getGoogleBudget(
|
||||
model: Model<"google-generative-ai">,
|
||||
effort: ClampedThinkingLevel,
|
||||
customBudgets?: ThinkingBudgets,
|
||||
): number {
|
||||
if (customBudgets?.[effort] !== undefined) {
|
||||
return customBudgets[effort]!;
|
||||
}
|
||||
|
||||
if (model.id.includes("2.5-pro")) {
|
||||
const budgets: Record<ClampedThinkingLevel, number> = {
|
||||
minimal: 128,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 32768,
|
||||
};
|
||||
return budgets[effort];
|
||||
}
|
||||
|
||||
if (model.id.includes("2.5-flash")) {
|
||||
const budgets: Record<ClampedThinkingLevel, number> = {
|
||||
minimal: 128,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 24576,
|
||||
};
|
||||
return budgets[effort];
|
||||
}
|
||||
|
||||
return -1;
|
||||
}
|
||||
582
packages/pi-ai/src/providers/mistral.ts
Normal file
582
packages/pi-ai/src/providers/mistral.ts
Normal file
|
|
@ -0,0 +1,582 @@
|
|||
import { Mistral } from "@mistralai/mistralai";
|
||||
import type { RequestOptions } from "@mistralai/mistralai/lib/sdks.js";
|
||||
import type {
|
||||
ChatCompletionStreamRequest,
|
||||
ChatCompletionStreamRequestMessages,
|
||||
CompletionEvent,
|
||||
ContentChunk,
|
||||
FunctionTool,
|
||||
} from "@mistralai/mistralai/models/components/index.js";
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Message,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StopReason,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
Tool,
|
||||
ToolCall,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { shortHash } from "../utils/hash.js";
|
||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
import { transformMessages } from "./transform-messages.js";
|
||||
|
||||
const MISTRAL_TOOL_CALL_ID_LENGTH = 9;
|
||||
const MAX_MISTRAL_ERROR_BODY_CHARS = 4000;
|
||||
|
||||
/**
|
||||
* Provider-specific options for the Mistral API.
|
||||
*/
|
||||
export interface MistralOptions extends StreamOptions {
|
||||
toolChoice?: "auto" | "none" | "any" | "required" | { type: "function"; function: { name: string } };
|
||||
promptMode?: "reasoning";
|
||||
}
|
||||
|
||||
/**
|
||||
* Stream responses from Mistral using `chat.stream`.
|
||||
*/
|
||||
export const streamMistral: StreamFunction<"mistral-conversations", MistralOptions> = (
|
||||
model: Model<"mistral-conversations">,
|
||||
context: Context,
|
||||
options?: MistralOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
const output = createOutput(model);
|
||||
|
||||
try {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
// Intentionally per-request: avoids shared SDK mutable state across concurrent consumers.
|
||||
const mistral = new Mistral({
|
||||
apiKey,
|
||||
serverURL: model.baseUrl,
|
||||
});
|
||||
|
||||
const normalizeMistralToolCallId = createMistralToolCallIdNormalizer();
|
||||
const transformedMessages = transformMessages(context.messages, model, (id) => normalizeMistralToolCallId(id));
|
||||
|
||||
let payload = buildChatPayload(model, context, transformedMessages, options);
|
||||
const nextPayload = await options?.onPayload?.(payload, model);
|
||||
if (nextPayload !== undefined) {
|
||||
payload = nextPayload as ChatCompletionStreamRequest;
|
||||
}
|
||||
const mistralStream = await mistral.chat.stream(payload, buildRequestOptions(model, options));
|
||||
stream.push({ type: "start", partial: output });
|
||||
await consumeChatStream(model, output, stream, mistralStream);
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (output.stopReason === "aborted" || output.stopReason === "error") {
|
||||
throw new Error("An unknown error occurred");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage = formatMistralError(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
/**
|
||||
* Maps provider-agnostic `SimpleStreamOptions` to Mistral options.
|
||||
*/
|
||||
export const streamSimpleMistral: StreamFunction<"mistral-conversations", SimpleStreamOptions> = (
|
||||
model: Model<"mistral-conversations">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
const reasoning = clampReasoning(options?.reasoning);
|
||||
|
||||
return streamMistral(model, context, {
|
||||
...base,
|
||||
promptMode: model.reasoning && reasoning ? "reasoning" : undefined,
|
||||
} satisfies MistralOptions);
|
||||
};
|
||||
|
||||
function createOutput(model: Model<"mistral-conversations">): AssistantMessage {
|
||||
return {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: model.api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
function createMistralToolCallIdNormalizer(): (id: string) => string {
|
||||
const idMap = new Map<string, string>();
|
||||
const reverseMap = new Map<string, string>();
|
||||
|
||||
return (id: string): string => {
|
||||
const existing = idMap.get(id);
|
||||
if (existing) return existing;
|
||||
|
||||
let attempt = 0;
|
||||
while (true) {
|
||||
const candidate = deriveMistralToolCallId(id, attempt);
|
||||
const owner = reverseMap.get(candidate);
|
||||
if (!owner || owner === id) {
|
||||
idMap.set(id, candidate);
|
||||
reverseMap.set(candidate, id);
|
||||
return candidate;
|
||||
}
|
||||
attempt++;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
function deriveMistralToolCallId(id: string, attempt: number): string {
|
||||
const normalized = id.replace(/[^a-zA-Z0-9]/g, "");
|
||||
if (attempt === 0 && normalized.length === MISTRAL_TOOL_CALL_ID_LENGTH) return normalized;
|
||||
const seedBase = normalized || id;
|
||||
const seed = attempt === 0 ? seedBase : `${seedBase}:${attempt}`;
|
||||
return shortHash(seed)
|
||||
.replace(/[^a-zA-Z0-9]/g, "")
|
||||
.slice(0, MISTRAL_TOOL_CALL_ID_LENGTH);
|
||||
}
|
||||
|
||||
function formatMistralError(error: unknown): string {
|
||||
if (error instanceof Error) {
|
||||
const sdkError = error as Error & { statusCode?: unknown; body?: unknown };
|
||||
const statusCode = typeof sdkError.statusCode === "number" ? sdkError.statusCode : undefined;
|
||||
const bodyText = typeof sdkError.body === "string" ? sdkError.body.trim() : undefined;
|
||||
if (statusCode !== undefined && bodyText) {
|
||||
return `Mistral API error (${statusCode}): ${truncateErrorText(bodyText, MAX_MISTRAL_ERROR_BODY_CHARS)}`;
|
||||
}
|
||||
if (statusCode !== undefined) return `Mistral API error (${statusCode}): ${error.message}`;
|
||||
return error.message;
|
||||
}
|
||||
return safeJsonStringify(error);
|
||||
}
|
||||
|
||||
function truncateErrorText(text: string, maxChars: number): string {
|
||||
if (text.length <= maxChars) return text;
|
||||
return `${text.slice(0, maxChars)}... [truncated ${text.length - maxChars} chars]`;
|
||||
}
|
||||
|
||||
function safeJsonStringify(value: unknown): string {
|
||||
try {
|
||||
const serialized = JSON.stringify(value);
|
||||
return serialized === undefined ? String(value) : serialized;
|
||||
} catch {
|
||||
return String(value);
|
||||
}
|
||||
}
|
||||
|
||||
function buildRequestOptions(model: Model<"mistral-conversations">, options?: MistralOptions): RequestOptions {
|
||||
const requestOptions: RequestOptions = {};
|
||||
if (options?.signal) requestOptions.signal = options.signal;
|
||||
requestOptions.retries = { strategy: "none" };
|
||||
|
||||
const headers: Record<string, string> = {};
|
||||
if (model.headers) Object.assign(headers, model.headers);
|
||||
if (options?.headers) Object.assign(headers, options.headers);
|
||||
|
||||
// Mistral infrastructure uses `x-affinity` for KV-cache reuse (prefix caching).
|
||||
// Respect explicit caller-provided header values.
|
||||
if (options?.sessionId && !headers["x-affinity"]) {
|
||||
headers["x-affinity"] = options.sessionId;
|
||||
}
|
||||
|
||||
if (Object.keys(headers).length > 0) {
|
||||
requestOptions.headers = headers;
|
||||
}
|
||||
|
||||
return requestOptions;
|
||||
}
|
||||
|
||||
function buildChatPayload(
|
||||
model: Model<"mistral-conversations">,
|
||||
context: Context,
|
||||
messages: Message[],
|
||||
options?: MistralOptions,
|
||||
): ChatCompletionStreamRequest {
|
||||
const payload: ChatCompletionStreamRequest = {
|
||||
model: model.id,
|
||||
stream: true,
|
||||
messages: toChatMessages(messages, model.input.includes("image")),
|
||||
};
|
||||
|
||||
if (context.tools?.length) payload.tools = toFunctionTools(context.tools);
|
||||
if (options?.temperature !== undefined) payload.temperature = options.temperature;
|
||||
if (options?.maxTokens !== undefined) payload.maxTokens = options.maxTokens;
|
||||
if (options?.toolChoice) payload.toolChoice = mapToolChoice(options.toolChoice);
|
||||
if (options?.promptMode) payload.promptMode = options.promptMode as any;
|
||||
|
||||
if (context.systemPrompt) {
|
||||
payload.messages.unshift({
|
||||
role: "system",
|
||||
content: sanitizeSurrogates(context.systemPrompt),
|
||||
});
|
||||
}
|
||||
|
||||
return payload;
|
||||
}
|
||||
|
||||
async function consumeChatStream(
|
||||
model: Model<"mistral-conversations">,
|
||||
output: AssistantMessage,
|
||||
stream: AssistantMessageEventStream,
|
||||
mistralStream: AsyncIterable<CompletionEvent>,
|
||||
): Promise<void> {
|
||||
let currentBlock: TextContent | ThinkingContent | null = null;
|
||||
const blocks = output.content;
|
||||
const blockIndex = () => blocks.length - 1;
|
||||
const toolBlocksByKey = new Map<string, number>();
|
||||
|
||||
const finishCurrentBlock = (block?: typeof currentBlock) => {
|
||||
if (!block) return;
|
||||
if (block.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: block.text,
|
||||
partial: output,
|
||||
});
|
||||
return;
|
||||
}
|
||||
if (block.type === "thinking") {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: block.thinking,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
for await (const event of mistralStream) {
|
||||
const chunk = event.data;
|
||||
|
||||
if (chunk.usage) {
|
||||
output.usage.input = chunk.usage.promptTokens || 0;
|
||||
output.usage.output = chunk.usage.completionTokens || 0;
|
||||
output.usage.cacheRead = 0;
|
||||
output.usage.cacheWrite = 0;
|
||||
output.usage.totalTokens = chunk.usage.totalTokens || output.usage.input + output.usage.output;
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
|
||||
const choice = chunk.choices[0];
|
||||
if (!choice) continue;
|
||||
|
||||
if (choice.finishReason) {
|
||||
output.stopReason = mapChatStopReason(choice.finishReason);
|
||||
}
|
||||
|
||||
const delta = choice.delta;
|
||||
if (delta.content !== null && delta.content !== undefined) {
|
||||
const contentItems = typeof delta.content === "string" ? [delta.content] : delta.content;
|
||||
for (const item of contentItems) {
|
||||
if (typeof item === "string") {
|
||||
const textDelta = sanitizeSurrogates(item);
|
||||
if (!currentBlock || currentBlock.type !== "text") {
|
||||
finishCurrentBlock(currentBlock);
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
|
||||
}
|
||||
currentBlock.text += textDelta;
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: textDelta,
|
||||
partial: output,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if (item.type === "thinking") {
|
||||
const deltaText = item.thinking
|
||||
.map((part) => ("text" in part ? part.text : ""))
|
||||
.filter((text) => text.length > 0)
|
||||
.join("");
|
||||
const thinkingDelta = sanitizeSurrogates(deltaText);
|
||||
if (!thinkingDelta) continue;
|
||||
if (!currentBlock || currentBlock.type !== "thinking") {
|
||||
finishCurrentBlock(currentBlock);
|
||||
currentBlock = { type: "thinking", thinking: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output });
|
||||
}
|
||||
currentBlock.thinking += thinkingDelta;
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: thinkingDelta,
|
||||
partial: output,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
if (item.type === "text") {
|
||||
const textDelta = sanitizeSurrogates(item.text);
|
||||
if (!currentBlock || currentBlock.type !== "text") {
|
||||
finishCurrentBlock(currentBlock);
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
|
||||
}
|
||||
currentBlock.text += textDelta;
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: textDelta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const toolCalls = delta.toolCalls || [];
|
||||
for (const toolCall of toolCalls) {
|
||||
if (currentBlock) {
|
||||
finishCurrentBlock(currentBlock);
|
||||
currentBlock = null;
|
||||
}
|
||||
const callId =
|
||||
toolCall.id && toolCall.id !== "null"
|
||||
? toolCall.id
|
||||
: deriveMistralToolCallId(`toolcall:${toolCall.index ?? 0}`, 0);
|
||||
const key = `${callId}:${toolCall.index || 0}`;
|
||||
const existingIndex = toolBlocksByKey.get(key);
|
||||
let block: (ToolCall & { partialArgs?: string }) | undefined;
|
||||
|
||||
if (existingIndex !== undefined) {
|
||||
const existing = output.content[existingIndex];
|
||||
if (existing?.type === "toolCall") {
|
||||
block = existing as ToolCall & { partialArgs?: string };
|
||||
}
|
||||
}
|
||||
|
||||
if (!block) {
|
||||
block = {
|
||||
type: "toolCall",
|
||||
id: callId,
|
||||
name: toolCall.function.name,
|
||||
arguments: {},
|
||||
partialArgs: "",
|
||||
};
|
||||
output.content.push(block);
|
||||
toolBlocksByKey.set(key, output.content.length - 1);
|
||||
stream.push({ type: "toolcall_start", contentIndex: output.content.length - 1, partial: output });
|
||||
}
|
||||
|
||||
const argsDelta =
|
||||
typeof toolCall.function.arguments === "string"
|
||||
? toolCall.function.arguments
|
||||
: JSON.stringify(toolCall.function.arguments || {});
|
||||
block.partialArgs = (block.partialArgs || "") + argsDelta;
|
||||
block.arguments = parseStreamingJson<Record<string, unknown>>(block.partialArgs);
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex: toolBlocksByKey.get(key)!,
|
||||
delta: argsDelta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
finishCurrentBlock(currentBlock);
|
||||
for (const index of toolBlocksByKey.values()) {
|
||||
const block = output.content[index];
|
||||
if (block.type !== "toolCall") continue;
|
||||
const toolBlock = block as ToolCall & { partialArgs?: string };
|
||||
toolBlock.arguments = parseStreamingJson<Record<string, unknown>>(toolBlock.partialArgs);
|
||||
delete toolBlock.partialArgs;
|
||||
stream.push({
|
||||
type: "toolcall_end",
|
||||
contentIndex: index,
|
||||
toolCall: toolBlock,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
function toFunctionTools(tools: Tool[]): Array<FunctionTool & { type: "function" }> {
|
||||
return tools.map((tool) => ({
|
||||
type: "function",
|
||||
function: {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters as unknown as Record<string, unknown>,
|
||||
strict: false,
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
function toChatMessages(messages: Message[], supportsImages: boolean): ChatCompletionStreamRequestMessages[] {
|
||||
const result: ChatCompletionStreamRequestMessages[] = [];
|
||||
|
||||
for (const msg of messages) {
|
||||
if (msg.role === "user") {
|
||||
if (typeof msg.content === "string") {
|
||||
result.push({ role: "user", content: sanitizeSurrogates(msg.content) });
|
||||
continue;
|
||||
}
|
||||
const hadImages = msg.content.some((item) => item.type === "image");
|
||||
const content: ContentChunk[] = msg.content
|
||||
.filter((item) => item.type === "text" || supportsImages)
|
||||
.map((item) => {
|
||||
if (item.type === "text") return { type: "text", text: sanitizeSurrogates(item.text) };
|
||||
return { type: "image_url", imageUrl: `data:${item.mimeType};base64,${item.data}` };
|
||||
});
|
||||
if (content.length > 0) {
|
||||
result.push({ role: "user", content });
|
||||
continue;
|
||||
}
|
||||
if (hadImages && !supportsImages) {
|
||||
result.push({ role: "user", content: "(image omitted: model does not support images)" });
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
if (msg.role === "assistant") {
|
||||
const contentParts: ContentChunk[] = [];
|
||||
const toolCalls: Array<{ id: string; type: "function"; function: { name: string; arguments: string } }> = [];
|
||||
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "text") {
|
||||
if (block.text.trim().length > 0) {
|
||||
contentParts.push({ type: "text", text: sanitizeSurrogates(block.text) });
|
||||
}
|
||||
continue;
|
||||
}
|
||||
if (block.type === "thinking") {
|
||||
if (block.thinking.trim().length > 0) {
|
||||
contentParts.push({
|
||||
type: "thinking",
|
||||
thinking: [{ type: "text", text: sanitizeSurrogates(block.thinking) }],
|
||||
});
|
||||
}
|
||||
continue;
|
||||
}
|
||||
toolCalls.push({
|
||||
id: block.id,
|
||||
type: "function",
|
||||
function: { name: block.name, arguments: JSON.stringify(block.arguments || {}) },
|
||||
});
|
||||
}
|
||||
|
||||
const assistantMessage: ChatCompletionStreamRequestMessages = { role: "assistant" };
|
||||
if (contentParts.length > 0) assistantMessage.content = contentParts;
|
||||
if (toolCalls.length > 0) assistantMessage.toolCalls = toolCalls;
|
||||
if (contentParts.length > 0 || toolCalls.length > 0) result.push(assistantMessage);
|
||||
continue;
|
||||
}
|
||||
|
||||
const toolContent: ContentChunk[] = [];
|
||||
const textResult = msg.content
|
||||
.filter((part) => part.type === "text")
|
||||
.map((part) => (part.type === "text" ? sanitizeSurrogates(part.text) : ""))
|
||||
.join("\n");
|
||||
const hasImages = msg.content.some((part) => part.type === "image");
|
||||
const toolText = buildToolResultText(textResult, hasImages, supportsImages, msg.isError);
|
||||
toolContent.push({ type: "text", text: toolText });
|
||||
for (const part of msg.content) {
|
||||
if (!supportsImages) continue;
|
||||
if (part.type !== "image") continue;
|
||||
toolContent.push({
|
||||
type: "image_url",
|
||||
imageUrl: `data:${part.mimeType};base64,${part.data}`,
|
||||
});
|
||||
}
|
||||
result.push({
|
||||
role: "tool",
|
||||
toolCallId: msg.toolCallId,
|
||||
name: msg.toolName,
|
||||
content: toolContent,
|
||||
});
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
function buildToolResultText(text: string, hasImages: boolean, supportsImages: boolean, isError: boolean): string {
|
||||
const trimmed = text.trim();
|
||||
const errorPrefix = isError ? "[tool error] " : "";
|
||||
|
||||
if (trimmed.length > 0) {
|
||||
const imageSuffix = hasImages && !supportsImages ? "\n[tool image omitted: model does not support images]" : "";
|
||||
return `${errorPrefix}${trimmed}${imageSuffix}`;
|
||||
}
|
||||
|
||||
if (hasImages) {
|
||||
if (supportsImages) {
|
||||
return isError ? "[tool error] (see attached image)" : "(see attached image)";
|
||||
}
|
||||
return isError
|
||||
? "[tool error] (image omitted: model does not support images)"
|
||||
: "(image omitted: model does not support images)";
|
||||
}
|
||||
|
||||
return isError ? "[tool error] (no tool output)" : "(no tool output)";
|
||||
}
|
||||
|
||||
function mapToolChoice(
|
||||
choice: MistralOptions["toolChoice"],
|
||||
): "auto" | "none" | "any" | "required" | { type: "function"; function: { name: string } } | undefined {
|
||||
if (!choice) return undefined;
|
||||
if (choice === "auto" || choice === "none" || choice === "any" || choice === "required") {
|
||||
return choice as any;
|
||||
}
|
||||
return {
|
||||
type: "function",
|
||||
function: { name: choice.function.name },
|
||||
};
|
||||
}
|
||||
|
||||
function mapChatStopReason(reason: string | null): StopReason {
|
||||
if (reason === null) return "stop";
|
||||
switch (reason) {
|
||||
case "stop":
|
||||
return "stop";
|
||||
case "length":
|
||||
case "model_length":
|
||||
return "length";
|
||||
case "tool_calls":
|
||||
return "toolUse";
|
||||
case "error":
|
||||
return "error";
|
||||
default:
|
||||
return "stop";
|
||||
}
|
||||
}
|
||||
875
packages/pi-ai/src/providers/openai-codex-responses.ts
Normal file
875
packages/pi-ai/src/providers/openai-codex-responses.ts
Normal file
|
|
@ -0,0 +1,875 @@
|
|||
import type * as NodeOs from "node:os";
|
||||
import type { Tool as OpenAITool, ResponseInput, ResponseStreamEvent } from "openai/resources/responses/responses.js";
|
||||
|
||||
// NEVER convert to top-level runtime imports - breaks browser/Vite builds (web-ui)
|
||||
let _os: typeof NodeOs | null = null;
|
||||
|
||||
type DynamicImport = (specifier: string) => Promise<unknown>;
|
||||
|
||||
const dynamicImport: DynamicImport = (specifier) => import(specifier);
|
||||
const NODE_OS_SPECIFIER = "node:" + "os";
|
||||
|
||||
if (typeof process !== "undefined" && (process.versions?.node || process.versions?.bun)) {
|
||||
dynamicImport(NODE_OS_SPECIFIER).then((m) => {
|
||||
_os = m as typeof NodeOs;
|
||||
});
|
||||
}
|
||||
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { supportsXhigh } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { convertResponsesMessages, convertResponsesTools, processResponsesStream } from "./openai-responses-shared.js";
|
||||
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
|
||||
// ============================================================================
|
||||
// Configuration
|
||||
// ============================================================================
|
||||
|
||||
const DEFAULT_CODEX_BASE_URL = "https://chatgpt.com/backend-api";
|
||||
const JWT_CLAIM_PATH = "https://api.openai.com/auth" as const;
|
||||
const MAX_RETRIES = 3;
|
||||
const BASE_DELAY_MS = 1000;
|
||||
const CODEX_TOOL_CALL_PROVIDERS = new Set(["openai", "openai-codex", "opencode"]);
|
||||
|
||||
const CODEX_RESPONSE_STATUSES = new Set<CodexResponseStatus>([
|
||||
"completed",
|
||||
"incomplete",
|
||||
"failed",
|
||||
"cancelled",
|
||||
"queued",
|
||||
"in_progress",
|
||||
]);
|
||||
|
||||
// ============================================================================
|
||||
// Types
|
||||
// ============================================================================
|
||||
|
||||
export interface OpenAICodexResponsesOptions extends StreamOptions {
|
||||
reasoningEffort?: "none" | "minimal" | "low" | "medium" | "high" | "xhigh";
|
||||
reasoningSummary?: "auto" | "concise" | "detailed" | "off" | "on" | null;
|
||||
textVerbosity?: "low" | "medium" | "high";
|
||||
}
|
||||
|
||||
type CodexResponseStatus = "completed" | "incomplete" | "failed" | "cancelled" | "queued" | "in_progress";
|
||||
|
||||
interface RequestBody {
|
||||
model: string;
|
||||
store?: boolean;
|
||||
stream?: boolean;
|
||||
instructions?: string;
|
||||
input?: ResponseInput;
|
||||
tools?: OpenAITool[];
|
||||
tool_choice?: "auto";
|
||||
parallel_tool_calls?: boolean;
|
||||
temperature?: number;
|
||||
reasoning?: { effort?: string; summary?: string };
|
||||
text?: { verbosity?: string };
|
||||
include?: string[];
|
||||
prompt_cache_key?: string;
|
||||
[key: string]: unknown;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Retry Helpers
|
||||
// ============================================================================
|
||||
|
||||
function isRetryableError(status: number, errorText: string): boolean {
|
||||
if (status === 429 || status === 500 || status === 502 || status === 503 || status === 504) {
|
||||
return true;
|
||||
}
|
||||
return /rate.?limit|overloaded|service.?unavailable|upstream.?connect|connection.?refused/i.test(errorText);
|
||||
}
|
||||
|
||||
function sleep(ms: number, signal?: AbortSignal): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (signal?.aborted) {
|
||||
reject(new Error("Request was aborted"));
|
||||
return;
|
||||
}
|
||||
const timeout = setTimeout(resolve, ms);
|
||||
signal?.addEventListener("abort", () => {
|
||||
clearTimeout(timeout);
|
||||
reject(new Error("Request was aborted"));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main Stream Function
|
||||
// ============================================================================
|
||||
|
||||
export const streamOpenAICodexResponses: StreamFunction<"openai-codex-responses", OpenAICodexResponsesOptions> = (
|
||||
model: Model<"openai-codex-responses">,
|
||||
context: Context,
|
||||
options?: OpenAICodexResponsesOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: "openai-codex-responses" as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
try {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const accountId = extractAccountId(apiKey);
|
||||
let body = buildRequestBody(model, context, options);
|
||||
const nextBody = await options?.onPayload?.(body, model);
|
||||
if (nextBody !== undefined) {
|
||||
body = nextBody as RequestBody;
|
||||
}
|
||||
const headers = buildHeaders(model.headers, options?.headers, accountId, apiKey, options?.sessionId);
|
||||
const bodyJson = JSON.stringify(body);
|
||||
const transport = options?.transport || "sse";
|
||||
|
||||
if (transport !== "sse") {
|
||||
let websocketStarted = false;
|
||||
try {
|
||||
await processWebSocketStream(
|
||||
resolveCodexWebSocketUrl(model.baseUrl),
|
||||
body,
|
||||
headers,
|
||||
output,
|
||||
stream,
|
||||
model,
|
||||
() => {
|
||||
websocketStarted = true;
|
||||
},
|
||||
options,
|
||||
);
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
stream.push({
|
||||
type: "done",
|
||||
reason: output.stopReason as "stop" | "length" | "toolUse",
|
||||
message: output,
|
||||
});
|
||||
stream.end();
|
||||
return;
|
||||
} catch (error) {
|
||||
if (transport === "websocket" || websocketStarted) {
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch with retry logic for rate limits and transient errors
|
||||
let response: Response | undefined;
|
||||
let lastError: Error | undefined;
|
||||
|
||||
for (let attempt = 0; attempt <= MAX_RETRIES; attempt++) {
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
try {
|
||||
response = await fetch(resolveCodexUrl(model.baseUrl), {
|
||||
method: "POST",
|
||||
headers,
|
||||
body: bodyJson,
|
||||
signal: options?.signal,
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
break;
|
||||
}
|
||||
|
||||
const errorText = await response.text();
|
||||
if (attempt < MAX_RETRIES && isRetryableError(response.status, errorText)) {
|
||||
const delayMs = BASE_DELAY_MS * 2 ** attempt;
|
||||
await sleep(delayMs, options?.signal);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Parse error for friendly message on final attempt or non-retryable error
|
||||
const fakeResponse = new Response(errorText, {
|
||||
status: response.status,
|
||||
statusText: response.statusText,
|
||||
});
|
||||
const info = await parseErrorResponse(fakeResponse);
|
||||
throw new Error(info.friendlyMessage || info.message);
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
if (error.name === "AbortError" || error.message === "Request was aborted") {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
}
|
||||
lastError = error instanceof Error ? error : new Error(String(error));
|
||||
// Network errors are retryable
|
||||
if (attempt < MAX_RETRIES && !lastError.message.includes("usage limit")) {
|
||||
const delayMs = BASE_DELAY_MS * 2 ** attempt;
|
||||
await sleep(delayMs, options?.signal);
|
||||
continue;
|
||||
}
|
||||
throw lastError;
|
||||
}
|
||||
}
|
||||
|
||||
if (!response?.ok) {
|
||||
throw lastError ?? new Error("Failed after retries");
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
throw new Error("No response body");
|
||||
}
|
||||
|
||||
stream.push({ type: "start", partial: output });
|
||||
await processStream(response, output, stream, model);
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason as "stop" | "length" | "toolUse", message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage = error instanceof Error ? error.message : String(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
export const streamSimpleOpenAICodexResponses: StreamFunction<"openai-codex-responses", SimpleStreamOptions> = (
|
||||
model: Model<"openai-codex-responses">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
const reasoningEffort = supportsXhigh(model) ? options?.reasoning : clampReasoning(options?.reasoning);
|
||||
|
||||
return streamOpenAICodexResponses(model, context, {
|
||||
...base,
|
||||
reasoningEffort,
|
||||
} satisfies OpenAICodexResponsesOptions);
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Request Building
|
||||
// ============================================================================
|
||||
|
||||
function buildRequestBody(
|
||||
model: Model<"openai-codex-responses">,
|
||||
context: Context,
|
||||
options?: OpenAICodexResponsesOptions,
|
||||
): RequestBody {
|
||||
const messages = convertResponsesMessages(model, context, CODEX_TOOL_CALL_PROVIDERS, {
|
||||
includeSystemPrompt: false,
|
||||
});
|
||||
|
||||
const body: RequestBody = {
|
||||
model: model.id,
|
||||
store: false,
|
||||
stream: true,
|
||||
instructions: context.systemPrompt,
|
||||
input: messages,
|
||||
text: { verbosity: options?.textVerbosity || "medium" },
|
||||
include: ["reasoning.encrypted_content"],
|
||||
prompt_cache_key: options?.sessionId,
|
||||
tool_choice: "auto",
|
||||
parallel_tool_calls: true,
|
||||
};
|
||||
|
||||
if (options?.temperature !== undefined) {
|
||||
body.temperature = options.temperature;
|
||||
}
|
||||
|
||||
if (context.tools) {
|
||||
body.tools = convertResponsesTools(context.tools, { strict: null });
|
||||
}
|
||||
|
||||
if (options?.reasoningEffort !== undefined) {
|
||||
body.reasoning = {
|
||||
effort: clampReasoningEffort(model.id, options.reasoningEffort),
|
||||
summary: options.reasoningSummary ?? "auto",
|
||||
};
|
||||
}
|
||||
|
||||
return body;
|
||||
}
|
||||
|
||||
function clampReasoningEffort(modelId: string, effort: string): string {
|
||||
const id = modelId.includes("/") ? modelId.split("/").pop()! : modelId;
|
||||
if ((id.startsWith("gpt-5.2") || id.startsWith("gpt-5.3") || id.startsWith("gpt-5.4")) && effort === "minimal")
|
||||
return "low";
|
||||
if (id === "gpt-5.1" && effort === "xhigh") return "high";
|
||||
if (id === "gpt-5.1-codex-mini") return effort === "high" || effort === "xhigh" ? "high" : "medium";
|
||||
return effort;
|
||||
}
|
||||
|
||||
function resolveCodexUrl(baseUrl?: string): string {
|
||||
const raw = baseUrl && baseUrl.trim().length > 0 ? baseUrl : DEFAULT_CODEX_BASE_URL;
|
||||
const normalized = raw.replace(/\/+$/, "");
|
||||
if (normalized.endsWith("/codex/responses")) return normalized;
|
||||
if (normalized.endsWith("/codex")) return `${normalized}/responses`;
|
||||
return `${normalized}/codex/responses`;
|
||||
}
|
||||
|
||||
function resolveCodexWebSocketUrl(baseUrl?: string): string {
|
||||
const url = new URL(resolveCodexUrl(baseUrl));
|
||||
if (url.protocol === "https:") url.protocol = "wss:";
|
||||
if (url.protocol === "http:") url.protocol = "ws:";
|
||||
return url.toString();
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Response Processing
|
||||
// ============================================================================
|
||||
|
||||
async function processStream(
|
||||
response: Response,
|
||||
output: AssistantMessage,
|
||||
stream: AssistantMessageEventStream,
|
||||
model: Model<"openai-codex-responses">,
|
||||
): Promise<void> {
|
||||
await processResponsesStream(mapCodexEvents(parseSSE(response)), output, stream, model);
|
||||
}
|
||||
|
||||
async function* mapCodexEvents(events: AsyncIterable<Record<string, unknown>>): AsyncGenerator<ResponseStreamEvent> {
|
||||
for await (const event of events) {
|
||||
const type = typeof event.type === "string" ? event.type : undefined;
|
||||
if (!type) continue;
|
||||
|
||||
if (type === "error") {
|
||||
const code = (event as { code?: string }).code || "";
|
||||
const message = (event as { message?: string }).message || "";
|
||||
throw new Error(`Codex error: ${message || code || JSON.stringify(event)}`);
|
||||
}
|
||||
|
||||
if (type === "response.failed") {
|
||||
const msg = (event as { response?: { error?: { message?: string } } }).response?.error?.message;
|
||||
throw new Error(msg || "Codex response failed");
|
||||
}
|
||||
|
||||
if (type === "response.done" || type === "response.completed") {
|
||||
const response = (event as { response?: { status?: unknown } }).response;
|
||||
const normalizedResponse = response
|
||||
? { ...response, status: normalizeCodexStatus(response.status) }
|
||||
: response;
|
||||
yield { ...event, type: "response.completed", response: normalizedResponse } as ResponseStreamEvent;
|
||||
continue;
|
||||
}
|
||||
|
||||
yield event as unknown as ResponseStreamEvent;
|
||||
}
|
||||
}
|
||||
|
||||
function normalizeCodexStatus(status: unknown): CodexResponseStatus | undefined {
|
||||
if (typeof status !== "string") return undefined;
|
||||
return CODEX_RESPONSE_STATUSES.has(status as CodexResponseStatus) ? (status as CodexResponseStatus) : undefined;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SSE Parsing
|
||||
// ============================================================================
|
||||
|
||||
async function* parseSSE(response: Response): AsyncGenerator<Record<string, unknown>> {
|
||||
if (!response.body) return;
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
|
||||
let idx = buffer.indexOf("\n\n");
|
||||
while (idx !== -1) {
|
||||
const chunk = buffer.slice(0, idx);
|
||||
buffer = buffer.slice(idx + 2);
|
||||
|
||||
const dataLines = chunk
|
||||
.split("\n")
|
||||
.filter((l) => l.startsWith("data:"))
|
||||
.map((l) => l.slice(5).trim());
|
||||
if (dataLines.length > 0) {
|
||||
const data = dataLines.join("\n").trim();
|
||||
if (data && data !== "[DONE]") {
|
||||
try {
|
||||
yield JSON.parse(data);
|
||||
} catch {}
|
||||
}
|
||||
}
|
||||
idx = buffer.indexOf("\n\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// WebSocket Parsing
|
||||
// ============================================================================
|
||||
|
||||
const OPENAI_BETA_RESPONSES_WEBSOCKETS = "responses_websockets=2026-02-06";
|
||||
const SESSION_WEBSOCKET_CACHE_TTL_MS = 5 * 60 * 1000;
|
||||
|
||||
type WebSocketEventType = "open" | "message" | "error" | "close";
|
||||
type WebSocketListener = (event: unknown) => void;
|
||||
|
||||
interface WebSocketLike {
|
||||
close(code?: number, reason?: string): void;
|
||||
send(data: string): void;
|
||||
addEventListener(type: WebSocketEventType, listener: WebSocketListener): void;
|
||||
removeEventListener(type: WebSocketEventType, listener: WebSocketListener): void;
|
||||
}
|
||||
|
||||
interface CachedWebSocketConnection {
|
||||
socket: WebSocketLike;
|
||||
busy: boolean;
|
||||
idleTimer?: ReturnType<typeof setTimeout>;
|
||||
}
|
||||
|
||||
const websocketSessionCache = new Map<string, CachedWebSocketConnection>();
|
||||
|
||||
type WebSocketConstructor = new (
|
||||
url: string,
|
||||
protocols?: string | string[] | { headers?: Record<string, string> },
|
||||
) => WebSocketLike;
|
||||
|
||||
function getWebSocketConstructor(): WebSocketConstructor | null {
|
||||
const ctor = (globalThis as { WebSocket?: unknown }).WebSocket;
|
||||
if (typeof ctor !== "function") return null;
|
||||
return ctor as unknown as WebSocketConstructor;
|
||||
}
|
||||
|
||||
function headersToRecord(headers: Headers): Record<string, string> {
|
||||
const out: Record<string, string> = {};
|
||||
for (const [key, value] of headers.entries()) {
|
||||
out[key] = value;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
function getWebSocketReadyState(socket: WebSocketLike): number | undefined {
|
||||
const readyState = (socket as { readyState?: unknown }).readyState;
|
||||
return typeof readyState === "number" ? readyState : undefined;
|
||||
}
|
||||
|
||||
function isWebSocketReusable(socket: WebSocketLike): boolean {
|
||||
const readyState = getWebSocketReadyState(socket);
|
||||
// If readyState is unavailable, assume the runtime keeps it open/reusable.
|
||||
return readyState === undefined || readyState === 1;
|
||||
}
|
||||
|
||||
function closeWebSocketSilently(socket: WebSocketLike, code = 1000, reason = "done"): void {
|
||||
try {
|
||||
socket.close(code, reason);
|
||||
} catch {}
|
||||
}
|
||||
|
||||
function scheduleSessionWebSocketExpiry(sessionId: string, entry: CachedWebSocketConnection): void {
|
||||
if (entry.idleTimer) {
|
||||
clearTimeout(entry.idleTimer);
|
||||
}
|
||||
entry.idleTimer = setTimeout(() => {
|
||||
if (entry.busy) return;
|
||||
closeWebSocketSilently(entry.socket, 1000, "idle_timeout");
|
||||
websocketSessionCache.delete(sessionId);
|
||||
}, SESSION_WEBSOCKET_CACHE_TTL_MS);
|
||||
}
|
||||
|
||||
async function connectWebSocket(url: string, headers: Headers, signal?: AbortSignal): Promise<WebSocketLike> {
|
||||
const WebSocketCtor = getWebSocketConstructor();
|
||||
if (!WebSocketCtor) {
|
||||
throw new Error("WebSocket transport is not available in this runtime");
|
||||
}
|
||||
|
||||
const wsHeaders = headersToRecord(headers);
|
||||
wsHeaders["OpenAI-Beta"] = OPENAI_BETA_RESPONSES_WEBSOCKETS;
|
||||
|
||||
return new Promise<WebSocketLike>((resolve, reject) => {
|
||||
let settled = false;
|
||||
let socket: WebSocketLike;
|
||||
|
||||
try {
|
||||
socket = new WebSocketCtor(url, { headers: wsHeaders });
|
||||
} catch (error) {
|
||||
reject(error instanceof Error ? error : new Error(String(error)));
|
||||
return;
|
||||
}
|
||||
|
||||
const onOpen: WebSocketListener = () => {
|
||||
if (settled) return;
|
||||
settled = true;
|
||||
cleanup();
|
||||
resolve(socket);
|
||||
};
|
||||
const onError: WebSocketListener = (event) => {
|
||||
if (settled) return;
|
||||
settled = true;
|
||||
cleanup();
|
||||
reject(extractWebSocketError(event));
|
||||
};
|
||||
const onClose: WebSocketListener = (event) => {
|
||||
if (settled) return;
|
||||
settled = true;
|
||||
cleanup();
|
||||
reject(extractWebSocketCloseError(event));
|
||||
};
|
||||
const onAbort = () => {
|
||||
if (settled) return;
|
||||
settled = true;
|
||||
cleanup();
|
||||
socket.close(1000, "aborted");
|
||||
reject(new Error("Request was aborted"));
|
||||
};
|
||||
|
||||
const cleanup = () => {
|
||||
socket.removeEventListener("open", onOpen);
|
||||
socket.removeEventListener("error", onError);
|
||||
socket.removeEventListener("close", onClose);
|
||||
signal?.removeEventListener("abort", onAbort);
|
||||
};
|
||||
|
||||
socket.addEventListener("open", onOpen);
|
||||
socket.addEventListener("error", onError);
|
||||
socket.addEventListener("close", onClose);
|
||||
signal?.addEventListener("abort", onAbort);
|
||||
});
|
||||
}
|
||||
|
||||
async function acquireWebSocket(
|
||||
url: string,
|
||||
headers: Headers,
|
||||
sessionId: string | undefined,
|
||||
signal?: AbortSignal,
|
||||
): Promise<{ socket: WebSocketLike; release: (options?: { keep?: boolean }) => void }> {
|
||||
if (!sessionId) {
|
||||
const socket = await connectWebSocket(url, headers, signal);
|
||||
return {
|
||||
socket,
|
||||
release: ({ keep } = {}) => {
|
||||
if (keep === false) {
|
||||
closeWebSocketSilently(socket);
|
||||
return;
|
||||
}
|
||||
closeWebSocketSilently(socket);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const cached = websocketSessionCache.get(sessionId);
|
||||
if (cached) {
|
||||
if (cached.idleTimer) {
|
||||
clearTimeout(cached.idleTimer);
|
||||
cached.idleTimer = undefined;
|
||||
}
|
||||
if (!cached.busy && isWebSocketReusable(cached.socket)) {
|
||||
cached.busy = true;
|
||||
return {
|
||||
socket: cached.socket,
|
||||
release: ({ keep } = {}) => {
|
||||
if (!keep || !isWebSocketReusable(cached.socket)) {
|
||||
closeWebSocketSilently(cached.socket);
|
||||
websocketSessionCache.delete(sessionId);
|
||||
return;
|
||||
}
|
||||
cached.busy = false;
|
||||
scheduleSessionWebSocketExpiry(sessionId, cached);
|
||||
},
|
||||
};
|
||||
}
|
||||
if (cached.busy) {
|
||||
const socket = await connectWebSocket(url, headers, signal);
|
||||
return {
|
||||
socket,
|
||||
release: () => {
|
||||
closeWebSocketSilently(socket);
|
||||
},
|
||||
};
|
||||
}
|
||||
if (!isWebSocketReusable(cached.socket)) {
|
||||
closeWebSocketSilently(cached.socket);
|
||||
websocketSessionCache.delete(sessionId);
|
||||
}
|
||||
}
|
||||
|
||||
const socket = await connectWebSocket(url, headers, signal);
|
||||
const entry: CachedWebSocketConnection = { socket, busy: true };
|
||||
websocketSessionCache.set(sessionId, entry);
|
||||
return {
|
||||
socket,
|
||||
release: ({ keep } = {}) => {
|
||||
if (!keep || !isWebSocketReusable(entry.socket)) {
|
||||
closeWebSocketSilently(entry.socket);
|
||||
if (entry.idleTimer) clearTimeout(entry.idleTimer);
|
||||
if (websocketSessionCache.get(sessionId) === entry) {
|
||||
websocketSessionCache.delete(sessionId);
|
||||
}
|
||||
return;
|
||||
}
|
||||
entry.busy = false;
|
||||
scheduleSessionWebSocketExpiry(sessionId, entry);
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function extractWebSocketError(event: unknown): Error {
|
||||
if (event && typeof event === "object" && "message" in event) {
|
||||
const message = (event as { message?: unknown }).message;
|
||||
if (typeof message === "string" && message.length > 0) {
|
||||
return new Error(message);
|
||||
}
|
||||
}
|
||||
return new Error("WebSocket error");
|
||||
}
|
||||
|
||||
function extractWebSocketCloseError(event: unknown): Error {
|
||||
if (event && typeof event === "object") {
|
||||
const code = "code" in event ? (event as { code?: unknown }).code : undefined;
|
||||
const reason = "reason" in event ? (event as { reason?: unknown }).reason : undefined;
|
||||
const codeText = typeof code === "number" ? ` ${code}` : "";
|
||||
const reasonText = typeof reason === "string" && reason.length > 0 ? ` ${reason}` : "";
|
||||
return new Error(`WebSocket closed${codeText}${reasonText}`.trim());
|
||||
}
|
||||
return new Error("WebSocket closed");
|
||||
}
|
||||
|
||||
async function decodeWebSocketData(data: unknown): Promise<string | null> {
|
||||
if (typeof data === "string") return data;
|
||||
if (data instanceof ArrayBuffer) {
|
||||
return new TextDecoder().decode(new Uint8Array(data));
|
||||
}
|
||||
if (ArrayBuffer.isView(data)) {
|
||||
const view = data as ArrayBufferView;
|
||||
return new TextDecoder().decode(new Uint8Array(view.buffer, view.byteOffset, view.byteLength));
|
||||
}
|
||||
if (data && typeof data === "object" && "arrayBuffer" in data) {
|
||||
const blobLike = data as { arrayBuffer: () => Promise<ArrayBuffer> };
|
||||
const arrayBuffer = await blobLike.arrayBuffer();
|
||||
return new TextDecoder().decode(new Uint8Array(arrayBuffer));
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
async function* parseWebSocket(socket: WebSocketLike, signal?: AbortSignal): AsyncGenerator<Record<string, unknown>> {
|
||||
const queue: Record<string, unknown>[] = [];
|
||||
let pending: (() => void) | null = null;
|
||||
let done = false;
|
||||
let failed: Error | null = null;
|
||||
let sawCompletion = false;
|
||||
|
||||
const wake = () => {
|
||||
if (!pending) return;
|
||||
const resolve = pending;
|
||||
pending = null;
|
||||
resolve();
|
||||
};
|
||||
|
||||
const onMessage: WebSocketListener = (event) => {
|
||||
void (async () => {
|
||||
if (!event || typeof event !== "object" || !("data" in event)) return;
|
||||
const text = await decodeWebSocketData((event as { data?: unknown }).data);
|
||||
if (!text) return;
|
||||
try {
|
||||
const parsed = JSON.parse(text) as Record<string, unknown>;
|
||||
const type = typeof parsed.type === "string" ? parsed.type : "";
|
||||
if (type === "response.completed" || type === "response.done") {
|
||||
sawCompletion = true;
|
||||
done = true;
|
||||
}
|
||||
queue.push(parsed);
|
||||
wake();
|
||||
} catch {}
|
||||
})();
|
||||
};
|
||||
|
||||
const onError: WebSocketListener = (event) => {
|
||||
failed = extractWebSocketError(event);
|
||||
done = true;
|
||||
wake();
|
||||
};
|
||||
|
||||
const onClose: WebSocketListener = (event) => {
|
||||
if (sawCompletion) {
|
||||
done = true;
|
||||
wake();
|
||||
return;
|
||||
}
|
||||
if (!failed) {
|
||||
failed = extractWebSocketCloseError(event);
|
||||
}
|
||||
done = true;
|
||||
wake();
|
||||
};
|
||||
|
||||
const onAbort = () => {
|
||||
failed = new Error("Request was aborted");
|
||||
done = true;
|
||||
wake();
|
||||
};
|
||||
|
||||
socket.addEventListener("message", onMessage);
|
||||
socket.addEventListener("error", onError);
|
||||
socket.addEventListener("close", onClose);
|
||||
signal?.addEventListener("abort", onAbort);
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
if (signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
if (queue.length > 0) {
|
||||
yield queue.shift()!;
|
||||
continue;
|
||||
}
|
||||
if (done) break;
|
||||
await new Promise<void>((resolve) => {
|
||||
pending = resolve;
|
||||
});
|
||||
}
|
||||
|
||||
if (failed) {
|
||||
throw failed;
|
||||
}
|
||||
if (!sawCompletion) {
|
||||
throw new Error("WebSocket stream closed before response.completed");
|
||||
}
|
||||
} finally {
|
||||
socket.removeEventListener("message", onMessage);
|
||||
socket.removeEventListener("error", onError);
|
||||
socket.removeEventListener("close", onClose);
|
||||
signal?.removeEventListener("abort", onAbort);
|
||||
}
|
||||
}
|
||||
|
||||
async function processWebSocketStream(
|
||||
url: string,
|
||||
body: RequestBody,
|
||||
headers: Headers,
|
||||
output: AssistantMessage,
|
||||
stream: AssistantMessageEventStream,
|
||||
model: Model<"openai-codex-responses">,
|
||||
onStart: () => void,
|
||||
options?: OpenAICodexResponsesOptions,
|
||||
): Promise<void> {
|
||||
const { socket, release } = await acquireWebSocket(url, headers, options?.sessionId, options?.signal);
|
||||
let keepConnection = true;
|
||||
try {
|
||||
socket.send(JSON.stringify({ type: "response.create", ...body }));
|
||||
onStart();
|
||||
stream.push({ type: "start", partial: output });
|
||||
await processResponsesStream(mapCodexEvents(parseWebSocket(socket, options?.signal)), output, stream, model);
|
||||
if (options?.signal?.aborted) {
|
||||
keepConnection = false;
|
||||
}
|
||||
} catch (error) {
|
||||
keepConnection = false;
|
||||
throw error;
|
||||
} finally {
|
||||
release({ keep: keepConnection });
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Error Handling
|
||||
// ============================================================================
|
||||
|
||||
async function parseErrorResponse(response: Response): Promise<{ message: string; friendlyMessage?: string }> {
|
||||
const raw = await response.text();
|
||||
let message = raw || response.statusText || "Request failed";
|
||||
let friendlyMessage: string | undefined;
|
||||
|
||||
try {
|
||||
const parsed = JSON.parse(raw) as {
|
||||
error?: { code?: string; type?: string; message?: string; plan_type?: string; resets_at?: number };
|
||||
};
|
||||
const err = parsed?.error;
|
||||
if (err) {
|
||||
const code = err.code || err.type || "";
|
||||
if (/usage_limit_reached|usage_not_included|rate_limit_exceeded/i.test(code) || response.status === 429) {
|
||||
const plan = err.plan_type ? ` (${err.plan_type.toLowerCase()} plan)` : "";
|
||||
const mins = err.resets_at
|
||||
? Math.max(0, Math.round((err.resets_at * 1000 - Date.now()) / 60000))
|
||||
: undefined;
|
||||
const when = mins !== undefined ? ` Try again in ~${mins} min.` : "";
|
||||
friendlyMessage = `You have hit your ChatGPT usage limit${plan}.${when}`.trim();
|
||||
}
|
||||
message = err.message || friendlyMessage || message;
|
||||
}
|
||||
} catch {}
|
||||
|
||||
return { message, friendlyMessage };
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Auth & Headers
|
||||
// ============================================================================
|
||||
|
||||
function extractAccountId(token: string): string {
|
||||
try {
|
||||
const parts = token.split(".");
|
||||
if (parts.length !== 3) throw new Error("Invalid token");
|
||||
const payload = JSON.parse(atob(parts[1]));
|
||||
const accountId = payload?.[JWT_CLAIM_PATH]?.chatgpt_account_id;
|
||||
if (!accountId) throw new Error("No account ID in token");
|
||||
return accountId;
|
||||
} catch {
|
||||
throw new Error("Failed to extract accountId from token");
|
||||
}
|
||||
}
|
||||
|
||||
function buildHeaders(
|
||||
initHeaders: Record<string, string> | undefined,
|
||||
additionalHeaders: Record<string, string> | undefined,
|
||||
accountId: string,
|
||||
token: string,
|
||||
sessionId?: string,
|
||||
): Headers {
|
||||
const headers = new Headers(initHeaders);
|
||||
headers.set("Authorization", `Bearer ${token}`);
|
||||
headers.set("chatgpt-account-id", accountId);
|
||||
headers.set("OpenAI-Beta", "responses=experimental");
|
||||
headers.set("originator", "pi");
|
||||
const userAgent = _os ? `pi (${_os.platform()} ${_os.release()}; ${_os.arch()})` : "pi (browser)";
|
||||
headers.set("User-Agent", userAgent);
|
||||
headers.set("accept", "text/event-stream");
|
||||
headers.set("content-type", "application/json");
|
||||
for (const [key, value] of Object.entries(additionalHeaders || {})) {
|
||||
headers.set(key, value);
|
||||
}
|
||||
|
||||
if (sessionId) {
|
||||
headers.set("session_id", sessionId);
|
||||
}
|
||||
|
||||
return headers;
|
||||
}
|
||||
820
packages/pi-ai/src/providers/openai-completions.ts
Normal file
820
packages/pi-ai/src/providers/openai-completions.ts
Normal file
|
|
@ -0,0 +1,820 @@
|
|||
import OpenAI from "openai";
|
||||
import type {
|
||||
ChatCompletionAssistantMessageParam,
|
||||
ChatCompletionChunk,
|
||||
ChatCompletionContentPart,
|
||||
ChatCompletionContentPartImage,
|
||||
ChatCompletionContentPartText,
|
||||
ChatCompletionMessageParam,
|
||||
ChatCompletionToolMessageParam,
|
||||
} from "openai/resources/chat/completions.js";
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { calculateCost, supportsXhigh } from "../models.js";
|
||||
import type {
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Message,
|
||||
Model,
|
||||
OpenAICompletionsCompat,
|
||||
SimpleStreamOptions,
|
||||
StopReason,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
Tool,
|
||||
ToolCall,
|
||||
ToolResultMessage,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import { buildCopilotDynamicHeaders, hasCopilotVisionInput } from "./github-copilot-headers.js";
|
||||
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
import { transformMessages } from "./transform-messages.js";
|
||||
|
||||
/**
|
||||
* Check if conversation messages contain tool calls or tool results.
|
||||
* This is needed because Anthropic (via proxy) requires the tools param
|
||||
* to be present when messages include tool_calls or tool role messages.
|
||||
*/
|
||||
function hasToolHistory(messages: Message[]): boolean {
|
||||
for (const msg of messages) {
|
||||
if (msg.role === "toolResult") {
|
||||
return true;
|
||||
}
|
||||
if (msg.role === "assistant") {
|
||||
if (msg.content.some((block) => block.type === "toolCall")) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
export interface OpenAICompletionsOptions extends StreamOptions {
|
||||
toolChoice?: "auto" | "none" | "required" | { type: "function"; function: { name: string } };
|
||||
reasoningEffort?: "minimal" | "low" | "medium" | "high" | "xhigh";
|
||||
}
|
||||
|
||||
export const streamOpenAICompletions: StreamFunction<"openai-completions", OpenAICompletionsOptions> = (
|
||||
model: Model<"openai-completions">,
|
||||
context: Context,
|
||||
options?: OpenAICompletionsOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: model.api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
try {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
|
||||
const client = createClient(model, context, apiKey, options?.headers);
|
||||
let params = buildParams(model, context, options);
|
||||
const nextParams = await options?.onPayload?.(params, model);
|
||||
if (nextParams !== undefined) {
|
||||
params = nextParams as OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming;
|
||||
}
|
||||
const openaiStream = await client.chat.completions.create(params, { signal: options?.signal });
|
||||
stream.push({ type: "start", partial: output });
|
||||
|
||||
let currentBlock: TextContent | ThinkingContent | (ToolCall & { partialArgs?: string }) | null = null;
|
||||
const blocks = output.content;
|
||||
const blockIndex = () => blocks.length - 1;
|
||||
const finishCurrentBlock = (block?: typeof currentBlock) => {
|
||||
if (block) {
|
||||
if (block.type === "text") {
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: block.text,
|
||||
partial: output,
|
||||
});
|
||||
} else if (block.type === "thinking") {
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: block.thinking,
|
||||
partial: output,
|
||||
});
|
||||
} else if (block.type === "toolCall") {
|
||||
block.arguments = parseStreamingJson(block.partialArgs);
|
||||
delete block.partialArgs;
|
||||
stream.push({
|
||||
type: "toolcall_end",
|
||||
contentIndex: blockIndex(),
|
||||
toolCall: block,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
for await (const chunk of openaiStream) {
|
||||
if (chunk.usage) {
|
||||
const cachedTokens = chunk.usage.prompt_tokens_details?.cached_tokens || 0;
|
||||
const reasoningTokens = chunk.usage.completion_tokens_details?.reasoning_tokens || 0;
|
||||
const input = (chunk.usage.prompt_tokens || 0) - cachedTokens;
|
||||
const outputTokens = (chunk.usage.completion_tokens || 0) + reasoningTokens;
|
||||
output.usage = {
|
||||
// OpenAI includes cached tokens in prompt_tokens, so subtract to get non-cached input
|
||||
input,
|
||||
output: outputTokens,
|
||||
cacheRead: cachedTokens,
|
||||
cacheWrite: 0,
|
||||
// Compute totalTokens ourselves since we add reasoning_tokens to output
|
||||
// and some providers (e.g., Groq) don't include them in total_tokens
|
||||
totalTokens: input + outputTokens + cachedTokens,
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
total: 0,
|
||||
},
|
||||
};
|
||||
calculateCost(model, output.usage);
|
||||
}
|
||||
|
||||
const choice = chunk.choices?.[0];
|
||||
if (!choice) continue;
|
||||
|
||||
if (choice.finish_reason) {
|
||||
output.stopReason = mapStopReason(choice.finish_reason);
|
||||
}
|
||||
|
||||
if (choice.delta) {
|
||||
if (
|
||||
choice.delta.content !== null &&
|
||||
choice.delta.content !== undefined &&
|
||||
choice.delta.content.length > 0
|
||||
) {
|
||||
if (!currentBlock || currentBlock.type !== "text") {
|
||||
finishCurrentBlock(currentBlock);
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
|
||||
}
|
||||
|
||||
if (currentBlock.type === "text") {
|
||||
currentBlock.text += choice.delta.content;
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: choice.delta.content,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Some endpoints return reasoning in reasoning_content (llama.cpp),
|
||||
// or reasoning (other openai compatible endpoints)
|
||||
// Use the first non-empty reasoning field to avoid duplication
|
||||
// (e.g., chutes.ai returns both reasoning_content and reasoning with same content)
|
||||
const reasoningFields = ["reasoning_content", "reasoning", "reasoning_text"];
|
||||
let foundReasoningField: string | null = null;
|
||||
for (const field of reasoningFields) {
|
||||
if (
|
||||
(choice.delta as any)[field] !== null &&
|
||||
(choice.delta as any)[field] !== undefined &&
|
||||
(choice.delta as any)[field].length > 0
|
||||
) {
|
||||
if (!foundReasoningField) {
|
||||
foundReasoningField = field;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (foundReasoningField) {
|
||||
if (!currentBlock || currentBlock.type !== "thinking") {
|
||||
finishCurrentBlock(currentBlock);
|
||||
currentBlock = {
|
||||
type: "thinking",
|
||||
thinking: "",
|
||||
thinkingSignature: foundReasoningField,
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output });
|
||||
}
|
||||
|
||||
if (currentBlock.type === "thinking") {
|
||||
const delta = (choice.delta as any)[foundReasoningField];
|
||||
currentBlock.thinking += delta;
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (choice?.delta?.tool_calls) {
|
||||
for (const toolCall of choice.delta.tool_calls) {
|
||||
if (
|
||||
!currentBlock ||
|
||||
currentBlock.type !== "toolCall" ||
|
||||
(toolCall.id && currentBlock.id !== toolCall.id)
|
||||
) {
|
||||
finishCurrentBlock(currentBlock);
|
||||
currentBlock = {
|
||||
type: "toolCall",
|
||||
id: toolCall.id || "",
|
||||
name: toolCall.function?.name || "",
|
||||
arguments: {},
|
||||
partialArgs: "",
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output });
|
||||
}
|
||||
|
||||
if (currentBlock.type === "toolCall") {
|
||||
if (toolCall.id) currentBlock.id = toolCall.id;
|
||||
if (toolCall.function?.name) currentBlock.name = toolCall.function.name;
|
||||
let delta = "";
|
||||
if (toolCall.function?.arguments) {
|
||||
delta = toolCall.function.arguments;
|
||||
currentBlock.partialArgs += toolCall.function.arguments;
|
||||
currentBlock.arguments = parseStreamingJson(currentBlock.partialArgs);
|
||||
}
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const reasoningDetails = (choice.delta as any).reasoning_details;
|
||||
if (reasoningDetails && Array.isArray(reasoningDetails)) {
|
||||
for (const detail of reasoningDetails) {
|
||||
if (detail.type === "reasoning.encrypted" && detail.id && detail.data) {
|
||||
const matchingToolCall = output.content.find(
|
||||
(b) => b.type === "toolCall" && b.id === detail.id,
|
||||
) as ToolCall | undefined;
|
||||
if (matchingToolCall) {
|
||||
matchingToolCall.thoughtSignature = JSON.stringify(detail);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
finishCurrentBlock(currentBlock);
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (output.stopReason === "aborted" || output.stopReason === "error") {
|
||||
throw new Error("An unknown error occurred");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
for (const block of output.content) delete (block as any).index;
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
// Some providers via OpenRouter give additional information in this field.
|
||||
const rawMetadata = (error as any)?.error?.metadata?.raw;
|
||||
if (rawMetadata) output.errorMessage += `\n${rawMetadata}`;
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
export const streamSimpleOpenAICompletions: StreamFunction<"openai-completions", SimpleStreamOptions> = (
|
||||
model: Model<"openai-completions">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
const reasoningEffort = supportsXhigh(model) ? options?.reasoning : clampReasoning(options?.reasoning);
|
||||
const toolChoice = (options as OpenAICompletionsOptions | undefined)?.toolChoice;
|
||||
|
||||
return streamOpenAICompletions(model, context, {
|
||||
...base,
|
||||
reasoningEffort,
|
||||
toolChoice,
|
||||
} satisfies OpenAICompletionsOptions);
|
||||
};
|
||||
|
||||
function createClient(
|
||||
model: Model<"openai-completions">,
|
||||
context: Context,
|
||||
apiKey?: string,
|
||||
optionsHeaders?: Record<string, string>,
|
||||
) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.OPENAI_API_KEY) {
|
||||
throw new Error(
|
||||
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
|
||||
);
|
||||
}
|
||||
apiKey = process.env.OPENAI_API_KEY;
|
||||
}
|
||||
|
||||
const headers = { ...model.headers };
|
||||
if (model.provider === "github-copilot") {
|
||||
const hasImages = hasCopilotVisionInput(context.messages);
|
||||
const copilotHeaders = buildCopilotDynamicHeaders({
|
||||
messages: context.messages,
|
||||
hasImages,
|
||||
});
|
||||
Object.assign(headers, copilotHeaders);
|
||||
}
|
||||
|
||||
// Merge options headers last so they can override defaults
|
||||
if (optionsHeaders) {
|
||||
Object.assign(headers, optionsHeaders);
|
||||
}
|
||||
|
||||
return new OpenAI({
|
||||
apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: headers,
|
||||
});
|
||||
}
|
||||
|
||||
function buildParams(model: Model<"openai-completions">, context: Context, options?: OpenAICompletionsOptions) {
|
||||
const compat = getCompat(model);
|
||||
const messages = convertMessages(model, context, compat);
|
||||
maybeAddOpenRouterAnthropicCacheControl(model, messages);
|
||||
|
||||
const params: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
|
||||
model: model.id,
|
||||
messages,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
if (compat.supportsUsageInStreaming !== false) {
|
||||
(params as any).stream_options = { include_usage: true };
|
||||
}
|
||||
|
||||
if (compat.supportsStore) {
|
||||
params.store = false;
|
||||
}
|
||||
|
||||
if (options?.maxTokens) {
|
||||
if (compat.maxTokensField === "max_tokens") {
|
||||
(params as any).max_tokens = options.maxTokens;
|
||||
} else {
|
||||
params.max_completion_tokens = options.maxTokens;
|
||||
}
|
||||
}
|
||||
|
||||
if (options?.temperature !== undefined) {
|
||||
params.temperature = options.temperature;
|
||||
}
|
||||
|
||||
if (context.tools) {
|
||||
params.tools = convertTools(context.tools, compat);
|
||||
} else if (hasToolHistory(context.messages)) {
|
||||
// Anthropic (via LiteLLM/proxy) requires tools param when conversation has tool_calls/tool_results
|
||||
params.tools = [];
|
||||
}
|
||||
|
||||
if (options?.toolChoice) {
|
||||
params.tool_choice = options.toolChoice;
|
||||
}
|
||||
|
||||
if ((compat.thinkingFormat === "zai" || compat.thinkingFormat === "qwen") && model.reasoning) {
|
||||
// Both Z.ai and Qwen use enable_thinking: boolean
|
||||
(params as any).enable_thinking = !!options?.reasoningEffort;
|
||||
} else if (options?.reasoningEffort && model.reasoning && compat.supportsReasoningEffort) {
|
||||
// OpenAI-style reasoning_effort
|
||||
(params as any).reasoning_effort = mapReasoningEffort(options.reasoningEffort, compat.reasoningEffortMap);
|
||||
}
|
||||
|
||||
// OpenRouter provider routing preferences
|
||||
if (model.baseUrl.includes("openrouter.ai") && model.compat?.openRouterRouting) {
|
||||
(params as any).provider = model.compat.openRouterRouting;
|
||||
}
|
||||
|
||||
// Vercel AI Gateway provider routing preferences
|
||||
if (model.baseUrl.includes("ai-gateway.vercel.sh") && model.compat?.vercelGatewayRouting) {
|
||||
const routing = model.compat.vercelGatewayRouting;
|
||||
if (routing.only || routing.order) {
|
||||
const gatewayOptions: Record<string, string[]> = {};
|
||||
if (routing.only) gatewayOptions.only = routing.only;
|
||||
if (routing.order) gatewayOptions.order = routing.order;
|
||||
(params as any).providerOptions = { gateway: gatewayOptions };
|
||||
}
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
function mapReasoningEffort(
|
||||
effort: NonNullable<OpenAICompletionsOptions["reasoningEffort"]>,
|
||||
reasoningEffortMap: Partial<Record<NonNullable<OpenAICompletionsOptions["reasoningEffort"]>, string>>,
|
||||
): string {
|
||||
return reasoningEffortMap[effort] ?? effort;
|
||||
}
|
||||
|
||||
function maybeAddOpenRouterAnthropicCacheControl(
|
||||
model: Model<"openai-completions">,
|
||||
messages: ChatCompletionMessageParam[],
|
||||
): void {
|
||||
if (model.provider !== "openrouter" || !model.id.startsWith("anthropic/")) return;
|
||||
|
||||
// Anthropic-style caching requires cache_control on a text part. Add a breakpoint
|
||||
// on the last user/assistant message (walking backwards until we find text content).
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
const msg = messages[i];
|
||||
if (msg.role !== "user" && msg.role !== "assistant") continue;
|
||||
|
||||
const content = msg.content;
|
||||
if (typeof content === "string") {
|
||||
msg.content = [
|
||||
Object.assign({ type: "text" as const, text: content }, { cache_control: { type: "ephemeral" } }),
|
||||
];
|
||||
return;
|
||||
}
|
||||
|
||||
if (!Array.isArray(content)) continue;
|
||||
|
||||
// Find last text part and add cache_control
|
||||
for (let j = content.length - 1; j >= 0; j--) {
|
||||
const part = content[j];
|
||||
if (part?.type === "text") {
|
||||
Object.assign(part, { cache_control: { type: "ephemeral" } });
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function convertMessages(
|
||||
model: Model<"openai-completions">,
|
||||
context: Context,
|
||||
compat: Required<OpenAICompletionsCompat>,
|
||||
): ChatCompletionMessageParam[] {
|
||||
const params: ChatCompletionMessageParam[] = [];
|
||||
|
||||
const normalizeToolCallId = (id: string): string => {
|
||||
// Handle pipe-separated IDs from OpenAI Responses API
|
||||
// Format: {call_id}|{id} where {id} can be 400+ chars with special chars (+, /, =)
|
||||
// These come from providers like github-copilot, openai-codex, opencode
|
||||
// Extract just the call_id part and normalize it
|
||||
if (id.includes("|")) {
|
||||
const [callId] = id.split("|");
|
||||
// Sanitize to allowed chars and truncate to 40 chars (OpenAI limit)
|
||||
return callId.replace(/[^a-zA-Z0-9_-]/g, "_").slice(0, 40);
|
||||
}
|
||||
|
||||
if (model.provider === "openai") return id.length > 40 ? id.slice(0, 40) : id;
|
||||
return id;
|
||||
};
|
||||
|
||||
const transformedMessages = transformMessages(context.messages, model, (id) => normalizeToolCallId(id));
|
||||
|
||||
if (context.systemPrompt) {
|
||||
const useDeveloperRole = model.reasoning && compat.supportsDeveloperRole;
|
||||
const role = useDeveloperRole ? "developer" : "system";
|
||||
params.push({ role: role, content: sanitizeSurrogates(context.systemPrompt) });
|
||||
}
|
||||
|
||||
let lastRole: string | null = null;
|
||||
|
||||
for (let i = 0; i < transformedMessages.length; i++) {
|
||||
const msg = transformedMessages[i];
|
||||
// Some providers don't allow user messages directly after tool results
|
||||
// Insert a synthetic assistant message to bridge the gap
|
||||
if (compat.requiresAssistantAfterToolResult && lastRole === "toolResult" && msg.role === "user") {
|
||||
params.push({
|
||||
role: "assistant",
|
||||
content: "I have processed the tool results.",
|
||||
});
|
||||
}
|
||||
|
||||
if (msg.role === "user") {
|
||||
if (typeof msg.content === "string") {
|
||||
params.push({
|
||||
role: "user",
|
||||
content: sanitizeSurrogates(msg.content),
|
||||
});
|
||||
} else {
|
||||
const content: ChatCompletionContentPart[] = msg.content.map((item): ChatCompletionContentPart => {
|
||||
if (item.type === "text") {
|
||||
return {
|
||||
type: "text",
|
||||
text: sanitizeSurrogates(item.text),
|
||||
} satisfies ChatCompletionContentPartText;
|
||||
} else {
|
||||
return {
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
url: `data:${item.mimeType};base64,${item.data}`,
|
||||
},
|
||||
} satisfies ChatCompletionContentPartImage;
|
||||
}
|
||||
});
|
||||
const filteredContent = !model.input.includes("image")
|
||||
? content.filter((c) => c.type !== "image_url")
|
||||
: content;
|
||||
if (filteredContent.length === 0) continue;
|
||||
params.push({
|
||||
role: "user",
|
||||
content: filteredContent,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
// Some providers don't accept null content, use empty string instead
|
||||
const assistantMsg: ChatCompletionAssistantMessageParam = {
|
||||
role: "assistant",
|
||||
content: compat.requiresAssistantAfterToolResult ? "" : null,
|
||||
};
|
||||
|
||||
const textBlocks = msg.content.filter((b) => b.type === "text") as TextContent[];
|
||||
// Filter out empty text blocks to avoid API validation errors
|
||||
const nonEmptyTextBlocks = textBlocks.filter((b) => b.text && b.text.trim().length > 0);
|
||||
if (nonEmptyTextBlocks.length > 0) {
|
||||
// GitHub Copilot requires assistant content as a string, not an array.
|
||||
// Sending as array causes Claude models to re-answer all previous prompts.
|
||||
if (model.provider === "github-copilot") {
|
||||
assistantMsg.content = nonEmptyTextBlocks.map((b) => sanitizeSurrogates(b.text)).join("");
|
||||
} else {
|
||||
assistantMsg.content = nonEmptyTextBlocks.map((b) => {
|
||||
return { type: "text", text: sanitizeSurrogates(b.text) };
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Handle thinking blocks
|
||||
const thinkingBlocks = msg.content.filter((b) => b.type === "thinking") as ThinkingContent[];
|
||||
// Filter out empty thinking blocks to avoid API validation errors
|
||||
const nonEmptyThinkingBlocks = thinkingBlocks.filter((b) => b.thinking && b.thinking.trim().length > 0);
|
||||
if (nonEmptyThinkingBlocks.length > 0) {
|
||||
if (compat.requiresThinkingAsText) {
|
||||
// Convert thinking blocks to plain text (no tags to avoid model mimicking them)
|
||||
const thinkingText = nonEmptyThinkingBlocks.map((b) => b.thinking).join("\n\n");
|
||||
const textContent = assistantMsg.content as Array<{ type: "text"; text: string }> | null;
|
||||
if (textContent) {
|
||||
textContent.unshift({ type: "text", text: thinkingText });
|
||||
} else {
|
||||
assistantMsg.content = [{ type: "text", text: thinkingText }];
|
||||
}
|
||||
} else {
|
||||
// Use the signature from the first thinking block if available (for llama.cpp server + gpt-oss)
|
||||
const signature = nonEmptyThinkingBlocks[0].thinkingSignature;
|
||||
if (signature && signature.length > 0) {
|
||||
(assistantMsg as any)[signature] = nonEmptyThinkingBlocks.map((b) => b.thinking).join("\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const toolCalls = msg.content.filter((b) => b.type === "toolCall") as ToolCall[];
|
||||
if (toolCalls.length > 0) {
|
||||
assistantMsg.tool_calls = toolCalls.map((tc) => ({
|
||||
id: tc.id,
|
||||
type: "function" as const,
|
||||
function: {
|
||||
name: tc.name,
|
||||
arguments: JSON.stringify(tc.arguments),
|
||||
},
|
||||
}));
|
||||
const reasoningDetails = toolCalls
|
||||
.filter((tc) => tc.thoughtSignature)
|
||||
.map((tc) => {
|
||||
try {
|
||||
return JSON.parse(tc.thoughtSignature!);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
})
|
||||
.filter(Boolean);
|
||||
if (reasoningDetails.length > 0) {
|
||||
(assistantMsg as any).reasoning_details = reasoningDetails;
|
||||
}
|
||||
}
|
||||
// Skip assistant messages that have no content and no tool calls.
|
||||
// Some providers require "either content or tool_calls, but not none".
|
||||
// Other providers also don't accept empty assistant messages.
|
||||
// This handles aborted assistant responses that got no content.
|
||||
const content = assistantMsg.content;
|
||||
const hasContent =
|
||||
content !== null &&
|
||||
content !== undefined &&
|
||||
(typeof content === "string" ? content.length > 0 : content.length > 0);
|
||||
if (!hasContent && !assistantMsg.tool_calls) {
|
||||
continue;
|
||||
}
|
||||
params.push(assistantMsg);
|
||||
} else if (msg.role === "toolResult") {
|
||||
const imageBlocks: Array<{ type: "image_url"; image_url: { url: string } }> = [];
|
||||
let j = i;
|
||||
|
||||
for (; j < transformedMessages.length && transformedMessages[j].role === "toolResult"; j++) {
|
||||
const toolMsg = transformedMessages[j] as ToolResultMessage;
|
||||
|
||||
// Extract text and image content
|
||||
const textResult = toolMsg.content
|
||||
.filter((c) => c.type === "text")
|
||||
.map((c) => (c as any).text)
|
||||
.join("\n");
|
||||
const hasImages = toolMsg.content.some((c) => c.type === "image");
|
||||
|
||||
// Always send tool result with text (or placeholder if only images)
|
||||
const hasText = textResult.length > 0;
|
||||
// Some providers require the 'name' field in tool results
|
||||
const toolResultMsg: ChatCompletionToolMessageParam = {
|
||||
role: "tool",
|
||||
content: sanitizeSurrogates(hasText ? textResult : "(see attached image)"),
|
||||
tool_call_id: toolMsg.toolCallId,
|
||||
};
|
||||
if (compat.requiresToolResultName && toolMsg.toolName) {
|
||||
(toolResultMsg as any).name = toolMsg.toolName;
|
||||
}
|
||||
params.push(toolResultMsg);
|
||||
|
||||
if (hasImages && model.input.includes("image")) {
|
||||
for (const block of toolMsg.content) {
|
||||
if (block.type === "image") {
|
||||
imageBlocks.push({
|
||||
type: "image_url",
|
||||
image_url: {
|
||||
url: `data:${(block as any).mimeType};base64,${(block as any).data}`,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
i = j - 1;
|
||||
|
||||
if (imageBlocks.length > 0) {
|
||||
if (compat.requiresAssistantAfterToolResult) {
|
||||
params.push({
|
||||
role: "assistant",
|
||||
content: "I have processed the tool results.",
|
||||
});
|
||||
}
|
||||
|
||||
params.push({
|
||||
role: "user",
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: "Attached image(s) from tool result:",
|
||||
},
|
||||
...imageBlocks,
|
||||
],
|
||||
});
|
||||
lastRole = "user";
|
||||
} else {
|
||||
lastRole = "toolResult";
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
lastRole = msg.role;
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
function convertTools(
|
||||
tools: Tool[],
|
||||
compat: Required<OpenAICompletionsCompat>,
|
||||
): OpenAI.Chat.Completions.ChatCompletionTool[] {
|
||||
return tools.map((tool) => ({
|
||||
type: "function",
|
||||
function: {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters as any, // TypeBox already generates JSON Schema
|
||||
// Only include strict if provider supports it. Some reject unknown fields.
|
||||
...(compat.supportsStrictMode !== false && { strict: false }),
|
||||
},
|
||||
}));
|
||||
}
|
||||
|
||||
function mapStopReason(reason: ChatCompletionChunk.Choice["finish_reason"]): StopReason {
|
||||
if (reason === null) return "stop";
|
||||
switch (reason) {
|
||||
case "stop":
|
||||
return "stop";
|
||||
case "length":
|
||||
return "length";
|
||||
case "function_call":
|
||||
case "tool_calls":
|
||||
return "toolUse";
|
||||
case "content_filter":
|
||||
return "error";
|
||||
default: {
|
||||
const _exhaustive: never = reason;
|
||||
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Detect compatibility settings from provider and baseUrl for known providers.
|
||||
* Provider takes precedence over URL-based detection since it's explicitly configured.
|
||||
* Returns a fully resolved OpenAICompletionsCompat object with all fields set.
|
||||
*/
|
||||
function detectCompat(model: Model<"openai-completions">): Required<OpenAICompletionsCompat> {
|
||||
const provider = model.provider;
|
||||
const baseUrl = model.baseUrl;
|
||||
|
||||
const isZai = provider === "zai" || baseUrl.includes("api.z.ai");
|
||||
|
||||
const isNonStandard =
|
||||
provider === "cerebras" ||
|
||||
baseUrl.includes("cerebras.ai") ||
|
||||
provider === "xai" ||
|
||||
baseUrl.includes("api.x.ai") ||
|
||||
baseUrl.includes("chutes.ai") ||
|
||||
baseUrl.includes("deepseek.com") ||
|
||||
isZai ||
|
||||
provider === "opencode" ||
|
||||
baseUrl.includes("opencode.ai");
|
||||
|
||||
const useMaxTokens = baseUrl.includes("chutes.ai");
|
||||
|
||||
const isGrok = provider === "xai" || baseUrl.includes("api.x.ai");
|
||||
const isGroq = provider === "groq" || baseUrl.includes("groq.com");
|
||||
|
||||
const reasoningEffortMap =
|
||||
isGroq && model.id === "qwen/qwen3-32b"
|
||||
? {
|
||||
minimal: "default",
|
||||
low: "default",
|
||||
medium: "default",
|
||||
high: "default",
|
||||
xhigh: "default",
|
||||
}
|
||||
: {};
|
||||
return {
|
||||
supportsStore: !isNonStandard,
|
||||
supportsDeveloperRole: !isNonStandard,
|
||||
supportsReasoningEffort: !isGrok && !isZai,
|
||||
reasoningEffortMap,
|
||||
supportsUsageInStreaming: true,
|
||||
maxTokensField: useMaxTokens ? "max_tokens" : "max_completion_tokens",
|
||||
requiresToolResultName: false,
|
||||
requiresAssistantAfterToolResult: false,
|
||||
requiresThinkingAsText: false,
|
||||
thinkingFormat: isZai ? "zai" : "openai",
|
||||
openRouterRouting: {},
|
||||
vercelGatewayRouting: {},
|
||||
supportsStrictMode: true,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Get resolved compatibility settings for a model.
|
||||
* Uses explicit model.compat if provided, otherwise auto-detects from provider/URL.
|
||||
*/
|
||||
function getCompat(model: Model<"openai-completions">): Required<OpenAICompletionsCompat> {
|
||||
const detected = detectCompat(model);
|
||||
if (!model.compat) return detected;
|
||||
|
||||
return {
|
||||
supportsStore: model.compat.supportsStore ?? detected.supportsStore,
|
||||
supportsDeveloperRole: model.compat.supportsDeveloperRole ?? detected.supportsDeveloperRole,
|
||||
supportsReasoningEffort: model.compat.supportsReasoningEffort ?? detected.supportsReasoningEffort,
|
||||
reasoningEffortMap: model.compat.reasoningEffortMap ?? detected.reasoningEffortMap,
|
||||
supportsUsageInStreaming: model.compat.supportsUsageInStreaming ?? detected.supportsUsageInStreaming,
|
||||
maxTokensField: model.compat.maxTokensField ?? detected.maxTokensField,
|
||||
requiresToolResultName: model.compat.requiresToolResultName ?? detected.requiresToolResultName,
|
||||
requiresAssistantAfterToolResult:
|
||||
model.compat.requiresAssistantAfterToolResult ?? detected.requiresAssistantAfterToolResult,
|
||||
requiresThinkingAsText: model.compat.requiresThinkingAsText ?? detected.requiresThinkingAsText,
|
||||
thinkingFormat: model.compat.thinkingFormat ?? detected.thinkingFormat,
|
||||
openRouterRouting: model.compat.openRouterRouting ?? {},
|
||||
vercelGatewayRouting: model.compat.vercelGatewayRouting ?? detected.vercelGatewayRouting,
|
||||
supportsStrictMode: model.compat.supportsStrictMode ?? detected.supportsStrictMode,
|
||||
};
|
||||
}
|
||||
496
packages/pi-ai/src/providers/openai-responses-shared.ts
Normal file
496
packages/pi-ai/src/providers/openai-responses-shared.ts
Normal file
|
|
@ -0,0 +1,496 @@
|
|||
import type OpenAI from "openai";
|
||||
import type {
|
||||
Tool as OpenAITool,
|
||||
ResponseCreateParamsStreaming,
|
||||
ResponseFunctionToolCall,
|
||||
ResponseInput,
|
||||
ResponseInputContent,
|
||||
ResponseInputImage,
|
||||
ResponseInputText,
|
||||
ResponseOutputMessage,
|
||||
ResponseReasoningItem,
|
||||
ResponseStreamEvent,
|
||||
} from "openai/resources/responses/responses.js";
|
||||
import { calculateCost } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
Context,
|
||||
ImageContent,
|
||||
Model,
|
||||
StopReason,
|
||||
TextContent,
|
||||
TextSignatureV1,
|
||||
ThinkingContent,
|
||||
Tool,
|
||||
ToolCall,
|
||||
Usage,
|
||||
} from "../types.js";
|
||||
import type { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { shortHash } from "../utils/hash.js";
|
||||
import { parseStreamingJson } from "../utils/json-parse.js";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.js";
|
||||
import { transformMessages } from "./transform-messages.js";
|
||||
|
||||
// =============================================================================
|
||||
// Utilities
|
||||
// =============================================================================
|
||||
|
||||
function encodeTextSignatureV1(id: string, phase?: TextSignatureV1["phase"]): string {
|
||||
const payload: TextSignatureV1 = { v: 1, id };
|
||||
if (phase) payload.phase = phase;
|
||||
return JSON.stringify(payload);
|
||||
}
|
||||
|
||||
function parseTextSignature(
|
||||
signature: string | undefined,
|
||||
): { id: string; phase?: TextSignatureV1["phase"] } | undefined {
|
||||
if (!signature) return undefined;
|
||||
if (signature.startsWith("{")) {
|
||||
try {
|
||||
const parsed = JSON.parse(signature) as Partial<TextSignatureV1>;
|
||||
if (parsed.v === 1 && typeof parsed.id === "string") {
|
||||
if (parsed.phase === "commentary" || parsed.phase === "final_answer") {
|
||||
return { id: parsed.id, phase: parsed.phase };
|
||||
}
|
||||
return { id: parsed.id };
|
||||
}
|
||||
} catch {
|
||||
// Fall through to legacy plain-string handling.
|
||||
}
|
||||
}
|
||||
return { id: signature };
|
||||
}
|
||||
|
||||
export interface OpenAIResponsesStreamOptions {
|
||||
serviceTier?: ResponseCreateParamsStreaming["service_tier"];
|
||||
applyServiceTierPricing?: (
|
||||
usage: Usage,
|
||||
serviceTier: ResponseCreateParamsStreaming["service_tier"] | undefined,
|
||||
) => void;
|
||||
}
|
||||
|
||||
export interface ConvertResponsesMessagesOptions {
|
||||
includeSystemPrompt?: boolean;
|
||||
}
|
||||
|
||||
export interface ConvertResponsesToolsOptions {
|
||||
strict?: boolean | null;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Message conversion
|
||||
// =============================================================================
|
||||
|
||||
export function convertResponsesMessages<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
allowedToolCallProviders: ReadonlySet<string>,
|
||||
options?: ConvertResponsesMessagesOptions,
|
||||
): ResponseInput {
|
||||
const messages: ResponseInput = [];
|
||||
|
||||
const normalizeToolCallId = (id: string): string => {
|
||||
if (!allowedToolCallProviders.has(model.provider)) return id;
|
||||
if (!id.includes("|")) return id;
|
||||
const [callId, itemId] = id.split("|");
|
||||
const sanitizedCallId = callId.replace(/[^a-zA-Z0-9_-]/g, "_");
|
||||
let sanitizedItemId = itemId.replace(/[^a-zA-Z0-9_-]/g, "_");
|
||||
// OpenAI Responses API requires item id to start with "fc"
|
||||
if (!sanitizedItemId.startsWith("fc")) {
|
||||
sanitizedItemId = `fc_${sanitizedItemId}`;
|
||||
}
|
||||
// Truncate to 64 chars and strip trailing underscores (OpenAI Codex rejects them)
|
||||
let normalizedCallId = sanitizedCallId.length > 64 ? sanitizedCallId.slice(0, 64) : sanitizedCallId;
|
||||
let normalizedItemId = sanitizedItemId.length > 64 ? sanitizedItemId.slice(0, 64) : sanitizedItemId;
|
||||
normalizedCallId = normalizedCallId.replace(/_+$/, "");
|
||||
normalizedItemId = normalizedItemId.replace(/_+$/, "");
|
||||
return `${normalizedCallId}|${normalizedItemId}`;
|
||||
};
|
||||
|
||||
const transformedMessages = transformMessages(context.messages, model, normalizeToolCallId);
|
||||
|
||||
const includeSystemPrompt = options?.includeSystemPrompt ?? true;
|
||||
if (includeSystemPrompt && context.systemPrompt) {
|
||||
const role = model.reasoning ? "developer" : "system";
|
||||
messages.push({
|
||||
role,
|
||||
content: sanitizeSurrogates(context.systemPrompt),
|
||||
});
|
||||
}
|
||||
|
||||
let msgIndex = 0;
|
||||
for (const msg of transformedMessages) {
|
||||
if (msg.role === "user") {
|
||||
if (typeof msg.content === "string") {
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: [{ type: "input_text", text: sanitizeSurrogates(msg.content) }],
|
||||
});
|
||||
} else {
|
||||
const content: ResponseInputContent[] = msg.content.map((item): ResponseInputContent => {
|
||||
if (item.type === "text") {
|
||||
return {
|
||||
type: "input_text",
|
||||
text: sanitizeSurrogates(item.text),
|
||||
} satisfies ResponseInputText;
|
||||
}
|
||||
return {
|
||||
type: "input_image",
|
||||
detail: "auto",
|
||||
image_url: `data:${item.mimeType};base64,${item.data}`,
|
||||
} satisfies ResponseInputImage;
|
||||
});
|
||||
const filteredContent = !model.input.includes("image")
|
||||
? content.filter((c) => c.type !== "input_image")
|
||||
: content;
|
||||
if (filteredContent.length === 0) continue;
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: filteredContent,
|
||||
});
|
||||
}
|
||||
} else if (msg.role === "assistant") {
|
||||
const output: ResponseInput = [];
|
||||
const assistantMsg = msg as AssistantMessage;
|
||||
const isDifferentModel =
|
||||
assistantMsg.model !== model.id &&
|
||||
assistantMsg.provider === model.provider &&
|
||||
assistantMsg.api === model.api;
|
||||
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "thinking") {
|
||||
if (block.thinkingSignature) {
|
||||
const reasoningItem = JSON.parse(block.thinkingSignature) as ResponseReasoningItem;
|
||||
output.push(reasoningItem);
|
||||
}
|
||||
} else if (block.type === "text") {
|
||||
const textBlock = block as TextContent;
|
||||
const parsedSignature = parseTextSignature(textBlock.textSignature);
|
||||
// OpenAI requires id to be max 64 characters
|
||||
let msgId = parsedSignature?.id;
|
||||
if (!msgId) {
|
||||
msgId = `msg_${msgIndex}`;
|
||||
} else if (msgId.length > 64) {
|
||||
msgId = `msg_${shortHash(msgId)}`;
|
||||
}
|
||||
output.push({
|
||||
type: "message",
|
||||
role: "assistant",
|
||||
content: [{ type: "output_text", text: sanitizeSurrogates(textBlock.text), annotations: [] }],
|
||||
status: "completed",
|
||||
id: msgId,
|
||||
phase: parsedSignature?.phase,
|
||||
} satisfies ResponseOutputMessage);
|
||||
} else if (block.type === "toolCall") {
|
||||
const toolCall = block as ToolCall;
|
||||
const [callId, itemIdRaw] = toolCall.id.split("|");
|
||||
let itemId: string | undefined = itemIdRaw;
|
||||
|
||||
// For different-model messages, set id to undefined to avoid pairing validation.
|
||||
// OpenAI tracks which fc_xxx IDs were paired with rs_xxx reasoning items.
|
||||
// By omitting the id, we avoid triggering that validation (like cross-provider does).
|
||||
if (isDifferentModel && itemId?.startsWith("fc_")) {
|
||||
itemId = undefined;
|
||||
}
|
||||
|
||||
output.push({
|
||||
type: "function_call",
|
||||
id: itemId,
|
||||
call_id: callId,
|
||||
name: toolCall.name,
|
||||
arguments: JSON.stringify(toolCall.arguments),
|
||||
});
|
||||
}
|
||||
}
|
||||
if (output.length === 0) continue;
|
||||
messages.push(...output);
|
||||
} else if (msg.role === "toolResult") {
|
||||
// Extract text and image content
|
||||
const textResult = msg.content
|
||||
.filter((c): c is TextContent => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("\n");
|
||||
const hasImages = msg.content.some((c): c is ImageContent => c.type === "image");
|
||||
|
||||
// Always send function_call_output with text (or placeholder if only images)
|
||||
const hasText = textResult.length > 0;
|
||||
const [callId] = msg.toolCallId.split("|");
|
||||
messages.push({
|
||||
type: "function_call_output",
|
||||
call_id: callId,
|
||||
output: sanitizeSurrogates(hasText ? textResult : "(see attached image)"),
|
||||
});
|
||||
|
||||
// If there are images and model supports them, send a follow-up user message with images
|
||||
if (hasImages && model.input.includes("image")) {
|
||||
const contentParts: ResponseInputContent[] = [];
|
||||
|
||||
// Add text prefix
|
||||
contentParts.push({
|
||||
type: "input_text",
|
||||
text: "Attached image(s) from tool result:",
|
||||
} satisfies ResponseInputText);
|
||||
|
||||
// Add images
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "image") {
|
||||
contentParts.push({
|
||||
type: "input_image",
|
||||
detail: "auto",
|
||||
image_url: `data:${block.mimeType};base64,${block.data}`,
|
||||
} satisfies ResponseInputImage);
|
||||
}
|
||||
}
|
||||
|
||||
messages.push({
|
||||
role: "user",
|
||||
content: contentParts,
|
||||
});
|
||||
}
|
||||
}
|
||||
msgIndex++;
|
||||
}
|
||||
|
||||
return messages;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Tool conversion
|
||||
// =============================================================================
|
||||
|
||||
export function convertResponsesTools(tools: Tool[], options?: ConvertResponsesToolsOptions): OpenAITool[] {
|
||||
const strict = options?.strict === undefined ? false : options.strict;
|
||||
return tools.map((tool) => ({
|
||||
type: "function",
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: tool.parameters as any, // TypeBox already generates JSON Schema
|
||||
strict,
|
||||
}));
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Stream processing
|
||||
// =============================================================================
|
||||
|
||||
export async function processResponsesStream<TApi extends Api>(
|
||||
openaiStream: AsyncIterable<ResponseStreamEvent>,
|
||||
output: AssistantMessage,
|
||||
stream: AssistantMessageEventStream,
|
||||
model: Model<TApi>,
|
||||
options?: OpenAIResponsesStreamOptions,
|
||||
): Promise<void> {
|
||||
let currentItem: ResponseReasoningItem | ResponseOutputMessage | ResponseFunctionToolCall | null = null;
|
||||
let currentBlock: ThinkingContent | TextContent | (ToolCall & { partialJson: string }) | null = null;
|
||||
const blocks = output.content;
|
||||
const blockIndex = () => blocks.length - 1;
|
||||
|
||||
for await (const event of openaiStream) {
|
||||
if (event.type === "response.output_item.added") {
|
||||
const item = event.item;
|
||||
if (item.type === "reasoning") {
|
||||
currentItem = item;
|
||||
currentBlock = { type: "thinking", thinking: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "thinking_start", contentIndex: blockIndex(), partial: output });
|
||||
} else if (item.type === "message") {
|
||||
currentItem = item;
|
||||
currentBlock = { type: "text", text: "" };
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "text_start", contentIndex: blockIndex(), partial: output });
|
||||
} else if (item.type === "function_call") {
|
||||
currentItem = item;
|
||||
currentBlock = {
|
||||
type: "toolCall",
|
||||
id: `${item.call_id}|${item.id}`,
|
||||
name: item.name,
|
||||
arguments: {},
|
||||
partialJson: item.arguments || "",
|
||||
};
|
||||
output.content.push(currentBlock);
|
||||
stream.push({ type: "toolcall_start", contentIndex: blockIndex(), partial: output });
|
||||
}
|
||||
} else if (event.type === "response.reasoning_summary_part.added") {
|
||||
if (currentItem && currentItem.type === "reasoning") {
|
||||
currentItem.summary = currentItem.summary || [];
|
||||
currentItem.summary.push(event.part);
|
||||
}
|
||||
} else if (event.type === "response.reasoning_summary_text.delta") {
|
||||
if (currentItem?.type === "reasoning" && currentBlock?.type === "thinking") {
|
||||
currentItem.summary = currentItem.summary || [];
|
||||
const lastPart = currentItem.summary[currentItem.summary.length - 1];
|
||||
if (lastPart) {
|
||||
currentBlock.thinking += event.delta;
|
||||
lastPart.text += event.delta;
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: event.delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (event.type === "response.reasoning_summary_part.done") {
|
||||
if (currentItem?.type === "reasoning" && currentBlock?.type === "thinking") {
|
||||
currentItem.summary = currentItem.summary || [];
|
||||
const lastPart = currentItem.summary[currentItem.summary.length - 1];
|
||||
if (lastPart) {
|
||||
currentBlock.thinking += "\n\n";
|
||||
lastPart.text += "\n\n";
|
||||
stream.push({
|
||||
type: "thinking_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: "\n\n",
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (event.type === "response.content_part.added") {
|
||||
if (currentItem?.type === "message") {
|
||||
currentItem.content = currentItem.content || [];
|
||||
// Filter out ReasoningText, only accept output_text and refusal
|
||||
if (event.part.type === "output_text" || event.part.type === "refusal") {
|
||||
currentItem.content.push(event.part);
|
||||
}
|
||||
}
|
||||
} else if (event.type === "response.output_text.delta") {
|
||||
if (currentItem?.type === "message" && currentBlock?.type === "text") {
|
||||
if (!currentItem.content || currentItem.content.length === 0) {
|
||||
continue;
|
||||
}
|
||||
const lastPart = currentItem.content[currentItem.content.length - 1];
|
||||
if (lastPart?.type === "output_text") {
|
||||
currentBlock.text += event.delta;
|
||||
lastPart.text += event.delta;
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: event.delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (event.type === "response.refusal.delta") {
|
||||
if (currentItem?.type === "message" && currentBlock?.type === "text") {
|
||||
if (!currentItem.content || currentItem.content.length === 0) {
|
||||
continue;
|
||||
}
|
||||
const lastPart = currentItem.content[currentItem.content.length - 1];
|
||||
if (lastPart?.type === "refusal") {
|
||||
currentBlock.text += event.delta;
|
||||
lastPart.refusal += event.delta;
|
||||
stream.push({
|
||||
type: "text_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: event.delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
}
|
||||
} else if (event.type === "response.function_call_arguments.delta") {
|
||||
if (currentItem?.type === "function_call" && currentBlock?.type === "toolCall") {
|
||||
currentBlock.partialJson += event.delta;
|
||||
currentBlock.arguments = parseStreamingJson(currentBlock.partialJson);
|
||||
stream.push({
|
||||
type: "toolcall_delta",
|
||||
contentIndex: blockIndex(),
|
||||
delta: event.delta,
|
||||
partial: output,
|
||||
});
|
||||
}
|
||||
} else if (event.type === "response.function_call_arguments.done") {
|
||||
if (currentItem?.type === "function_call" && currentBlock?.type === "toolCall") {
|
||||
currentBlock.partialJson = event.arguments;
|
||||
currentBlock.arguments = parseStreamingJson(currentBlock.partialJson);
|
||||
}
|
||||
} else if (event.type === "response.output_item.done") {
|
||||
const item = event.item;
|
||||
|
||||
if (item.type === "reasoning" && currentBlock?.type === "thinking") {
|
||||
currentBlock.thinking = item.summary?.map((s) => s.text).join("\n\n") || "";
|
||||
currentBlock.thinkingSignature = JSON.stringify(item);
|
||||
stream.push({
|
||||
type: "thinking_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.thinking,
|
||||
partial: output,
|
||||
});
|
||||
currentBlock = null;
|
||||
} else if (item.type === "message" && currentBlock?.type === "text") {
|
||||
currentBlock.text = item.content.map((c) => (c.type === "output_text" ? c.text : c.refusal)).join("");
|
||||
currentBlock.textSignature = encodeTextSignatureV1(item.id, item.phase ?? undefined);
|
||||
stream.push({
|
||||
type: "text_end",
|
||||
contentIndex: blockIndex(),
|
||||
content: currentBlock.text,
|
||||
partial: output,
|
||||
});
|
||||
currentBlock = null;
|
||||
} else if (item.type === "function_call") {
|
||||
const args =
|
||||
currentBlock?.type === "toolCall" && currentBlock.partialJson
|
||||
? parseStreamingJson(currentBlock.partialJson)
|
||||
: parseStreamingJson(item.arguments || "{}");
|
||||
const toolCall: ToolCall = {
|
||||
type: "toolCall",
|
||||
id: `${item.call_id}|${item.id}`,
|
||||
name: item.name,
|
||||
arguments: args,
|
||||
};
|
||||
|
||||
currentBlock = null;
|
||||
stream.push({ type: "toolcall_end", contentIndex: blockIndex(), toolCall, partial: output });
|
||||
}
|
||||
} else if (event.type === "response.completed") {
|
||||
const response = event.response;
|
||||
if (response?.usage) {
|
||||
const cachedTokens = response.usage.input_tokens_details?.cached_tokens || 0;
|
||||
output.usage = {
|
||||
// OpenAI includes cached tokens in input_tokens, so subtract to get non-cached input
|
||||
input: (response.usage.input_tokens || 0) - cachedTokens,
|
||||
output: response.usage.output_tokens || 0,
|
||||
cacheRead: cachedTokens,
|
||||
cacheWrite: 0,
|
||||
totalTokens: response.usage.total_tokens || 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
};
|
||||
}
|
||||
calculateCost(model, output.usage);
|
||||
if (options?.applyServiceTierPricing) {
|
||||
const serviceTier = response?.service_tier ?? options.serviceTier;
|
||||
options.applyServiceTierPricing(output.usage, serviceTier);
|
||||
}
|
||||
// Map status to stop reason
|
||||
output.stopReason = mapStopReason(response?.status);
|
||||
if (output.content.some((b) => b.type === "toolCall") && output.stopReason === "stop") {
|
||||
output.stopReason = "toolUse";
|
||||
}
|
||||
} else if (event.type === "error") {
|
||||
throw new Error(`Error Code ${event.code}: ${event.message}` || "Unknown error");
|
||||
} else if (event.type === "response.failed") {
|
||||
throw new Error("Unknown error");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function mapStopReason(status: OpenAI.Responses.ResponseStatus | undefined): StopReason {
|
||||
if (!status) return "stop";
|
||||
switch (status) {
|
||||
case "completed":
|
||||
return "stop";
|
||||
case "incomplete":
|
||||
return "length";
|
||||
case "failed":
|
||||
case "cancelled":
|
||||
return "error";
|
||||
// These two are wonky ...
|
||||
case "in_progress":
|
||||
case "queued":
|
||||
return "stop";
|
||||
default: {
|
||||
const _exhaustive: never = status;
|
||||
throw new Error(`Unhandled stop reason: ${_exhaustive}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
262
packages/pi-ai/src/providers/openai-responses.ts
Normal file
262
packages/pi-ai/src/providers/openai-responses.ts
Normal file
|
|
@ -0,0 +1,262 @@
|
|||
import OpenAI from "openai";
|
||||
import type { ResponseCreateParamsStreaming } from "openai/resources/responses/responses.js";
|
||||
import { getEnvApiKey } from "../env-api-keys.js";
|
||||
import { supportsXhigh } from "../models.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
CacheRetention,
|
||||
Context,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
Usage,
|
||||
} from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import { buildCopilotDynamicHeaders, hasCopilotVisionInput } from "./github-copilot-headers.js";
|
||||
import { convertResponsesMessages, convertResponsesTools, processResponsesStream } from "./openai-responses-shared.js";
|
||||
import { buildBaseOptions, clampReasoning } from "./simple-options.js";
|
||||
|
||||
const OPENAI_TOOL_CALL_PROVIDERS = new Set(["openai", "openai-codex", "opencode"]);
|
||||
|
||||
/**
|
||||
* Resolve cache retention preference.
|
||||
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
|
||||
*/
|
||||
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
|
||||
if (cacheRetention) {
|
||||
return cacheRetention;
|
||||
}
|
||||
if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") {
|
||||
return "long";
|
||||
}
|
||||
return "short";
|
||||
}
|
||||
|
||||
/**
|
||||
* Get prompt cache retention based on cacheRetention and base URL.
|
||||
* Only applies to direct OpenAI API calls (api.openai.com).
|
||||
*/
|
||||
function getPromptCacheRetention(baseUrl: string, cacheRetention: CacheRetention): "24h" | undefined {
|
||||
if (cacheRetention !== "long") {
|
||||
return undefined;
|
||||
}
|
||||
if (baseUrl.includes("api.openai.com")) {
|
||||
return "24h";
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// OpenAI Responses-specific options
|
||||
export interface OpenAIResponsesOptions extends StreamOptions {
|
||||
reasoningEffort?: "minimal" | "low" | "medium" | "high" | "xhigh";
|
||||
reasoningSummary?: "auto" | "detailed" | "concise" | null;
|
||||
serviceTier?: ResponseCreateParamsStreaming["service_tier"];
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate function for OpenAI Responses API
|
||||
*/
|
||||
export const streamOpenAIResponses: StreamFunction<"openai-responses", OpenAIResponsesOptions> = (
|
||||
model: Model<"openai-responses">,
|
||||
context: Context,
|
||||
options?: OpenAIResponsesOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const stream = new AssistantMessageEventStream();
|
||||
|
||||
// Start async processing
|
||||
(async () => {
|
||||
const output: AssistantMessage = {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: model.api as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
try {
|
||||
// Create OpenAI client
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider) || "";
|
||||
const client = createClient(model, context, apiKey, options?.headers);
|
||||
let params = buildParams(model, context, options);
|
||||
const nextParams = await options?.onPayload?.(params, model);
|
||||
if (nextParams !== undefined) {
|
||||
params = nextParams as ResponseCreateParamsStreaming;
|
||||
}
|
||||
const openaiStream = await client.responses.create(
|
||||
params,
|
||||
options?.signal ? { signal: options.signal } : undefined,
|
||||
);
|
||||
stream.push({ type: "start", partial: output });
|
||||
|
||||
await processResponsesStream(openaiStream, output, stream, model, {
|
||||
serviceTier: options?.serviceTier,
|
||||
applyServiceTierPricing,
|
||||
});
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
|
||||
if (output.stopReason === "aborted" || output.stopReason === "error") {
|
||||
throw new Error("An unknown error occurred");
|
||||
}
|
||||
|
||||
stream.push({ type: "done", reason: output.stopReason, message: output });
|
||||
stream.end();
|
||||
} catch (error) {
|
||||
for (const block of output.content) delete (block as { index?: number }).index;
|
||||
output.stopReason = options?.signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage = error instanceof Error ? error.message : JSON.stringify(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
})();
|
||||
|
||||
return stream;
|
||||
};
|
||||
|
||||
export const streamSimpleOpenAIResponses: StreamFunction<"openai-responses", SimpleStreamOptions> = (
|
||||
model: Model<"openai-responses">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream => {
|
||||
const apiKey = options?.apiKey || getEnvApiKey(model.provider);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
const base = buildBaseOptions(model, options, apiKey);
|
||||
const reasoningEffort = supportsXhigh(model) ? options?.reasoning : clampReasoning(options?.reasoning);
|
||||
|
||||
return streamOpenAIResponses(model, context, {
|
||||
...base,
|
||||
reasoningEffort,
|
||||
} satisfies OpenAIResponsesOptions);
|
||||
};
|
||||
|
||||
function createClient(
|
||||
model: Model<"openai-responses">,
|
||||
context: Context,
|
||||
apiKey?: string,
|
||||
optionsHeaders?: Record<string, string>,
|
||||
) {
|
||||
if (!apiKey) {
|
||||
if (!process.env.OPENAI_API_KEY) {
|
||||
throw new Error(
|
||||
"OpenAI API key is required. Set OPENAI_API_KEY environment variable or pass it as an argument.",
|
||||
);
|
||||
}
|
||||
apiKey = process.env.OPENAI_API_KEY;
|
||||
}
|
||||
|
||||
const headers = { ...model.headers };
|
||||
if (model.provider === "github-copilot") {
|
||||
const hasImages = hasCopilotVisionInput(context.messages);
|
||||
const copilotHeaders = buildCopilotDynamicHeaders({
|
||||
messages: context.messages,
|
||||
hasImages,
|
||||
});
|
||||
Object.assign(headers, copilotHeaders);
|
||||
}
|
||||
|
||||
// Merge options headers last so they can override defaults
|
||||
if (optionsHeaders) {
|
||||
Object.assign(headers, optionsHeaders);
|
||||
}
|
||||
|
||||
return new OpenAI({
|
||||
apiKey,
|
||||
baseURL: model.baseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: headers,
|
||||
});
|
||||
}
|
||||
|
||||
function buildParams(model: Model<"openai-responses">, context: Context, options?: OpenAIResponsesOptions) {
|
||||
const messages = convertResponsesMessages(model, context, OPENAI_TOOL_CALL_PROVIDERS);
|
||||
|
||||
const cacheRetention = resolveCacheRetention(options?.cacheRetention);
|
||||
const params: ResponseCreateParamsStreaming = {
|
||||
model: model.id,
|
||||
input: messages,
|
||||
stream: true,
|
||||
prompt_cache_key: cacheRetention === "none" ? undefined : options?.sessionId,
|
||||
prompt_cache_retention: getPromptCacheRetention(model.baseUrl, cacheRetention),
|
||||
store: false,
|
||||
};
|
||||
|
||||
if (options?.maxTokens) {
|
||||
params.max_output_tokens = options?.maxTokens;
|
||||
}
|
||||
|
||||
if (options?.temperature !== undefined) {
|
||||
params.temperature = options?.temperature;
|
||||
}
|
||||
|
||||
if (options?.serviceTier !== undefined) {
|
||||
params.service_tier = options.serviceTier;
|
||||
}
|
||||
|
||||
if (context.tools) {
|
||||
params.tools = convertResponsesTools(context.tools);
|
||||
}
|
||||
|
||||
if (model.reasoning) {
|
||||
if (options?.reasoningEffort || options?.reasoningSummary) {
|
||||
params.reasoning = {
|
||||
effort: options?.reasoningEffort || "medium",
|
||||
summary: options?.reasoningSummary || "auto",
|
||||
};
|
||||
params.include = ["reasoning.encrypted_content"];
|
||||
} else {
|
||||
if (model.name.startsWith("gpt-5")) {
|
||||
// Jesus Christ, see https://community.openai.com/t/need-reasoning-false-option-for-gpt-5/1351588/7
|
||||
messages.push({
|
||||
role: "developer",
|
||||
content: [
|
||||
{
|
||||
type: "input_text",
|
||||
text: "# Juice: 0 !important",
|
||||
},
|
||||
],
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return params;
|
||||
}
|
||||
|
||||
function getServiceTierCostMultiplier(serviceTier: ResponseCreateParamsStreaming["service_tier"] | undefined): number {
|
||||
switch (serviceTier) {
|
||||
case "flex":
|
||||
return 0.5;
|
||||
case "priority":
|
||||
return 2;
|
||||
default:
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
function applyServiceTierPricing(usage: Usage, serviceTier: ResponseCreateParamsStreaming["service_tier"] | undefined) {
|
||||
const multiplier = getServiceTierCostMultiplier(serviceTier);
|
||||
if (multiplier === 1) return;
|
||||
|
||||
usage.cost.input *= multiplier;
|
||||
usage.cost.output *= multiplier;
|
||||
usage.cost.cacheRead *= multiplier;
|
||||
usage.cost.cacheWrite *= multiplier;
|
||||
usage.cost.total = usage.cost.input + usage.cost.output + usage.cost.cacheRead + usage.cost.cacheWrite;
|
||||
}
|
||||
186
packages/pi-ai/src/providers/register-builtins.ts
Normal file
186
packages/pi-ai/src/providers/register-builtins.ts
Normal file
|
|
@ -0,0 +1,186 @@
|
|||
import { clearApiProviders, registerApiProvider } from "../api-registry.js";
|
||||
import type { AssistantMessage, AssistantMessageEvent, Context, Model, SimpleStreamOptions } from "../types.js";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
import type { BedrockOptions } from "./amazon-bedrock.js";
|
||||
import { streamAnthropic, streamSimpleAnthropic } from "./anthropic.js";
|
||||
import { streamAzureOpenAIResponses, streamSimpleAzureOpenAIResponses } from "./azure-openai-responses.js";
|
||||
import { streamGoogle, streamSimpleGoogle } from "./google.js";
|
||||
import { streamGoogleGeminiCli, streamSimpleGoogleGeminiCli } from "./google-gemini-cli.js";
|
||||
import { streamGoogleVertex, streamSimpleGoogleVertex } from "./google-vertex.js";
|
||||
import { streamMistral, streamSimpleMistral } from "./mistral.js";
|
||||
import { streamOpenAICodexResponses, streamSimpleOpenAICodexResponses } from "./openai-codex-responses.js";
|
||||
import { streamOpenAICompletions, streamSimpleOpenAICompletions } from "./openai-completions.js";
|
||||
import { streamOpenAIResponses, streamSimpleOpenAIResponses } from "./openai-responses.js";
|
||||
|
||||
interface BedrockProviderModule {
|
||||
streamBedrock: (
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
context: Context,
|
||||
options?: BedrockOptions,
|
||||
) => AsyncIterable<AssistantMessageEvent>;
|
||||
streamSimpleBedrock: (
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
) => AsyncIterable<AssistantMessageEvent>;
|
||||
}
|
||||
|
||||
type DynamicImport = (specifier: string) => Promise<unknown>;
|
||||
|
||||
const dynamicImport: DynamicImport = (specifier) => import(specifier);
|
||||
const BEDROCK_PROVIDER_SPECIFIER = "./amazon-" + "bedrock.js";
|
||||
|
||||
let bedrockProviderModuleOverride: BedrockProviderModule | undefined;
|
||||
|
||||
export function setBedrockProviderModule(module: BedrockProviderModule): void {
|
||||
bedrockProviderModuleOverride = module;
|
||||
}
|
||||
|
||||
async function loadBedrockProviderModule(): Promise<BedrockProviderModule> {
|
||||
if (bedrockProviderModuleOverride) {
|
||||
return bedrockProviderModuleOverride;
|
||||
}
|
||||
const module = await dynamicImport(BEDROCK_PROVIDER_SPECIFIER);
|
||||
return module as BedrockProviderModule;
|
||||
}
|
||||
|
||||
function forwardStream(target: AssistantMessageEventStream, source: AsyncIterable<AssistantMessageEvent>): void {
|
||||
(async () => {
|
||||
for await (const event of source) {
|
||||
target.push(event);
|
||||
}
|
||||
target.end();
|
||||
})();
|
||||
}
|
||||
|
||||
function createLazyLoadErrorMessage(model: Model<"bedrock-converse-stream">, error: unknown): AssistantMessage {
|
||||
return {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: "bedrock-converse-stream",
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "error",
|
||||
errorMessage: error instanceof Error ? error.message : String(error),
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
function streamBedrockLazy(
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
context: Context,
|
||||
options?: BedrockOptions,
|
||||
): AssistantMessageEventStream {
|
||||
const outer = new AssistantMessageEventStream();
|
||||
|
||||
loadBedrockProviderModule()
|
||||
.then((module) => {
|
||||
const inner = module.streamBedrock(model, context, options);
|
||||
forwardStream(outer, inner);
|
||||
})
|
||||
.catch((error) => {
|
||||
const message = createLazyLoadErrorMessage(model, error);
|
||||
outer.push({ type: "error", reason: "error", error: message });
|
||||
outer.end(message);
|
||||
});
|
||||
|
||||
return outer;
|
||||
}
|
||||
|
||||
function streamSimpleBedrockLazy(
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream {
|
||||
const outer = new AssistantMessageEventStream();
|
||||
|
||||
loadBedrockProviderModule()
|
||||
.then((module) => {
|
||||
const inner = module.streamSimpleBedrock(model, context, options);
|
||||
forwardStream(outer, inner);
|
||||
})
|
||||
.catch((error) => {
|
||||
const message = createLazyLoadErrorMessage(model, error);
|
||||
outer.push({ type: "error", reason: "error", error: message });
|
||||
outer.end(message);
|
||||
});
|
||||
|
||||
return outer;
|
||||
}
|
||||
|
||||
export function registerBuiltInApiProviders(): void {
|
||||
registerApiProvider({
|
||||
api: "anthropic-messages",
|
||||
stream: streamAnthropic,
|
||||
streamSimple: streamSimpleAnthropic,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "openai-completions",
|
||||
stream: streamOpenAICompletions,
|
||||
streamSimple: streamSimpleOpenAICompletions,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "mistral-conversations",
|
||||
stream: streamMistral,
|
||||
streamSimple: streamSimpleMistral,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "openai-responses",
|
||||
stream: streamOpenAIResponses,
|
||||
streamSimple: streamSimpleOpenAIResponses,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "azure-openai-responses",
|
||||
stream: streamAzureOpenAIResponses,
|
||||
streamSimple: streamSimpleAzureOpenAIResponses,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "openai-codex-responses",
|
||||
stream: streamOpenAICodexResponses,
|
||||
streamSimple: streamSimpleOpenAICodexResponses,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "google-generative-ai",
|
||||
stream: streamGoogle,
|
||||
streamSimple: streamSimpleGoogle,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "google-gemini-cli",
|
||||
stream: streamGoogleGeminiCli,
|
||||
streamSimple: streamSimpleGoogleGeminiCli,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "google-vertex",
|
||||
stream: streamGoogleVertex,
|
||||
streamSimple: streamSimpleGoogleVertex,
|
||||
});
|
||||
|
||||
registerApiProvider({
|
||||
api: "bedrock-converse-stream",
|
||||
stream: streamBedrockLazy,
|
||||
streamSimple: streamSimpleBedrockLazy,
|
||||
});
|
||||
}
|
||||
|
||||
export function resetApiProviders(): void {
|
||||
clearApiProviders();
|
||||
registerBuiltInApiProviders();
|
||||
}
|
||||
|
||||
registerBuiltInApiProviders();
|
||||
46
packages/pi-ai/src/providers/simple-options.ts
Normal file
46
packages/pi-ai/src/providers/simple-options.ts
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
import type { Api, Model, SimpleStreamOptions, StreamOptions, ThinkingBudgets, ThinkingLevel } from "../types.js";
|
||||
|
||||
export function buildBaseOptions(model: Model<Api>, options?: SimpleStreamOptions, apiKey?: string): StreamOptions {
|
||||
return {
|
||||
temperature: options?.temperature,
|
||||
maxTokens: options?.maxTokens || Math.min(model.maxTokens, 32000),
|
||||
signal: options?.signal,
|
||||
apiKey: apiKey || options?.apiKey,
|
||||
cacheRetention: options?.cacheRetention,
|
||||
sessionId: options?.sessionId,
|
||||
headers: options?.headers,
|
||||
onPayload: options?.onPayload,
|
||||
maxRetryDelayMs: options?.maxRetryDelayMs,
|
||||
metadata: options?.metadata,
|
||||
};
|
||||
}
|
||||
|
||||
export function clampReasoning(effort: ThinkingLevel | undefined): Exclude<ThinkingLevel, "xhigh"> | undefined {
|
||||
return effort === "xhigh" ? "high" : effort;
|
||||
}
|
||||
|
||||
export function adjustMaxTokensForThinking(
|
||||
baseMaxTokens: number,
|
||||
modelMaxTokens: number,
|
||||
reasoningLevel: ThinkingLevel,
|
||||
customBudgets?: ThinkingBudgets,
|
||||
): { maxTokens: number; thinkingBudget: number } {
|
||||
const defaultBudgets: ThinkingBudgets = {
|
||||
minimal: 1024,
|
||||
low: 2048,
|
||||
medium: 8192,
|
||||
high: 16384,
|
||||
};
|
||||
const budgets = { ...defaultBudgets, ...customBudgets };
|
||||
|
||||
const minOutputTokens = 1024;
|
||||
const level = clampReasoning(reasoningLevel)!;
|
||||
let thinkingBudget = budgets[level]!;
|
||||
const maxTokens = Math.min(baseMaxTokens + thinkingBudget, modelMaxTokens);
|
||||
|
||||
if (maxTokens <= thinkingBudget) {
|
||||
thinkingBudget = Math.max(0, maxTokens - minOutputTokens);
|
||||
}
|
||||
|
||||
return { maxTokens, thinkingBudget };
|
||||
}
|
||||
172
packages/pi-ai/src/providers/transform-messages.ts
Normal file
172
packages/pi-ai/src/providers/transform-messages.ts
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
import type { Api, AssistantMessage, Message, Model, ToolCall, ToolResultMessage } from "../types.js";
|
||||
|
||||
/**
|
||||
* Normalize tool call ID for cross-provider compatibility.
|
||||
* OpenAI Responses API generates IDs that are 450+ chars with special characters like `|`.
|
||||
* Anthropic APIs require IDs matching ^[a-zA-Z0-9_-]+$ (max 64 chars).
|
||||
*/
|
||||
export function transformMessages<TApi extends Api>(
|
||||
messages: Message[],
|
||||
model: Model<TApi>,
|
||||
normalizeToolCallId?: (id: string, model: Model<TApi>, source: AssistantMessage) => string,
|
||||
): Message[] {
|
||||
// Build a map of original tool call IDs to normalized IDs
|
||||
const toolCallIdMap = new Map<string, string>();
|
||||
|
||||
// First pass: transform messages (thinking blocks, tool call ID normalization)
|
||||
const transformed = messages.map((msg) => {
|
||||
// User messages pass through unchanged
|
||||
if (msg.role === "user") {
|
||||
return msg;
|
||||
}
|
||||
|
||||
// Handle toolResult messages - normalize toolCallId if we have a mapping
|
||||
if (msg.role === "toolResult") {
|
||||
const normalizedId = toolCallIdMap.get(msg.toolCallId);
|
||||
if (normalizedId && normalizedId !== msg.toolCallId) {
|
||||
return { ...msg, toolCallId: normalizedId };
|
||||
}
|
||||
return msg;
|
||||
}
|
||||
|
||||
// Assistant messages need transformation check
|
||||
if (msg.role === "assistant") {
|
||||
const assistantMsg = msg as AssistantMessage;
|
||||
const isSameModel =
|
||||
assistantMsg.provider === model.provider &&
|
||||
assistantMsg.api === model.api &&
|
||||
assistantMsg.model === model.id;
|
||||
|
||||
const transformedContent = assistantMsg.content.flatMap((block) => {
|
||||
if (block.type === "thinking") {
|
||||
// Redacted thinking is opaque encrypted content, only valid for the same model.
|
||||
// Drop it for cross-model to avoid API errors.
|
||||
if (block.redacted) {
|
||||
return isSameModel ? block : [];
|
||||
}
|
||||
// For same model: keep thinking blocks with signatures (needed for replay)
|
||||
// even if the thinking text is empty (OpenAI encrypted reasoning)
|
||||
if (isSameModel && block.thinkingSignature) return block;
|
||||
// Skip empty thinking blocks, convert others to plain text
|
||||
if (!block.thinking || block.thinking.trim() === "") return [];
|
||||
if (isSameModel) return block;
|
||||
return {
|
||||
type: "text" as const,
|
||||
text: block.thinking,
|
||||
};
|
||||
}
|
||||
|
||||
if (block.type === "text") {
|
||||
if (isSameModel) return block;
|
||||
return {
|
||||
type: "text" as const,
|
||||
text: block.text,
|
||||
};
|
||||
}
|
||||
|
||||
if (block.type === "toolCall") {
|
||||
const toolCall = block as ToolCall;
|
||||
let normalizedToolCall: ToolCall = toolCall;
|
||||
|
||||
if (!isSameModel && toolCall.thoughtSignature) {
|
||||
normalizedToolCall = { ...toolCall };
|
||||
delete (normalizedToolCall as { thoughtSignature?: string }).thoughtSignature;
|
||||
}
|
||||
|
||||
if (!isSameModel && normalizeToolCallId) {
|
||||
const normalizedId = normalizeToolCallId(toolCall.id, model, assistantMsg);
|
||||
if (normalizedId !== toolCall.id) {
|
||||
toolCallIdMap.set(toolCall.id, normalizedId);
|
||||
normalizedToolCall = { ...normalizedToolCall, id: normalizedId };
|
||||
}
|
||||
}
|
||||
|
||||
return normalizedToolCall;
|
||||
}
|
||||
|
||||
return block;
|
||||
});
|
||||
|
||||
return {
|
||||
...assistantMsg,
|
||||
content: transformedContent,
|
||||
};
|
||||
}
|
||||
return msg;
|
||||
});
|
||||
|
||||
// Second pass: insert synthetic empty tool results for orphaned tool calls
|
||||
// This preserves thinking signatures and satisfies API requirements
|
||||
const result: Message[] = [];
|
||||
let pendingToolCalls: ToolCall[] = [];
|
||||
let existingToolResultIds = new Set<string>();
|
||||
|
||||
for (let i = 0; i < transformed.length; i++) {
|
||||
const msg = transformed[i];
|
||||
|
||||
if (msg.role === "assistant") {
|
||||
// If we have pending orphaned tool calls from a previous assistant, insert synthetic results now
|
||||
if (pendingToolCalls.length > 0) {
|
||||
for (const tc of pendingToolCalls) {
|
||||
if (!existingToolResultIds.has(tc.id)) {
|
||||
result.push({
|
||||
role: "toolResult",
|
||||
toolCallId: tc.id,
|
||||
toolName: tc.name,
|
||||
content: [{ type: "text", text: "No result provided" }],
|
||||
isError: true,
|
||||
timestamp: Date.now(),
|
||||
} as ToolResultMessage);
|
||||
}
|
||||
}
|
||||
pendingToolCalls = [];
|
||||
existingToolResultIds = new Set();
|
||||
}
|
||||
|
||||
// Skip errored/aborted assistant messages entirely.
|
||||
// These are incomplete turns that shouldn't be replayed:
|
||||
// - May have partial content (reasoning without message, incomplete tool calls)
|
||||
// - Replaying them can cause API errors (e.g., OpenAI "reasoning without following item")
|
||||
// - The model should retry from the last valid state
|
||||
const assistantMsg = msg as AssistantMessage;
|
||||
if (assistantMsg.stopReason === "error" || assistantMsg.stopReason === "aborted") {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Track tool calls from this assistant message
|
||||
const toolCalls = assistantMsg.content.filter((b) => b.type === "toolCall") as ToolCall[];
|
||||
if (toolCalls.length > 0) {
|
||||
pendingToolCalls = toolCalls;
|
||||
existingToolResultIds = new Set();
|
||||
}
|
||||
|
||||
result.push(msg);
|
||||
} else if (msg.role === "toolResult") {
|
||||
existingToolResultIds.add(msg.toolCallId);
|
||||
result.push(msg);
|
||||
} else if (msg.role === "user") {
|
||||
// User message interrupts tool flow - insert synthetic results for orphaned calls
|
||||
if (pendingToolCalls.length > 0) {
|
||||
for (const tc of pendingToolCalls) {
|
||||
if (!existingToolResultIds.has(tc.id)) {
|
||||
result.push({
|
||||
role: "toolResult",
|
||||
toolCallId: tc.id,
|
||||
toolName: tc.name,
|
||||
content: [{ type: "text", text: "No result provided" }],
|
||||
isError: true,
|
||||
timestamp: Date.now(),
|
||||
} as ToolResultMessage);
|
||||
}
|
||||
}
|
||||
pendingToolCalls = [];
|
||||
existingToolResultIds = new Set();
|
||||
}
|
||||
result.push(msg);
|
||||
} else {
|
||||
result.push(msg);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
59
packages/pi-ai/src/stream.ts
Normal file
59
packages/pi-ai/src/stream.ts
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
import "./providers/register-builtins.js";
|
||||
|
||||
import { getApiProvider } from "./api-registry.js";
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessage,
|
||||
AssistantMessageEventStream,
|
||||
Context,
|
||||
Model,
|
||||
ProviderStreamOptions,
|
||||
SimpleStreamOptions,
|
||||
StreamOptions,
|
||||
} from "./types.js";
|
||||
|
||||
export { getEnvApiKey } from "./env-api-keys.js";
|
||||
|
||||
function resolveApiProvider(api: Api) {
|
||||
const provider = getApiProvider(api);
|
||||
if (!provider) {
|
||||
throw new Error(`No API provider registered for api: ${api}`);
|
||||
}
|
||||
return provider;
|
||||
}
|
||||
|
||||
export function stream<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: ProviderStreamOptions,
|
||||
): AssistantMessageEventStream {
|
||||
const provider = resolveApiProvider(model.api);
|
||||
return provider.stream(model, context, options as StreamOptions);
|
||||
}
|
||||
|
||||
export async function complete<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: ProviderStreamOptions,
|
||||
): Promise<AssistantMessage> {
|
||||
const s = stream(model, context, options);
|
||||
return s.result();
|
||||
}
|
||||
|
||||
export function streamSimple<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): AssistantMessageEventStream {
|
||||
const provider = resolveApiProvider(model.api);
|
||||
return provider.streamSimple(model, context, options);
|
||||
}
|
||||
|
||||
export async function completeSimple<TApi extends Api>(
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): Promise<AssistantMessage> {
|
||||
const s = streamSimple(model, context, options);
|
||||
return s.result();
|
||||
}
|
||||
321
packages/pi-ai/src/types.ts
Normal file
321
packages/pi-ai/src/types.ts
Normal file
|
|
@ -0,0 +1,321 @@
|
|||
import type { AssistantMessageEventStream } from "./utils/event-stream.js";
|
||||
|
||||
export type { AssistantMessageEventStream } from "./utils/event-stream.js";
|
||||
|
||||
export type KnownApi =
|
||||
| "openai-completions"
|
||||
| "mistral-conversations"
|
||||
| "openai-responses"
|
||||
| "azure-openai-responses"
|
||||
| "openai-codex-responses"
|
||||
| "anthropic-messages"
|
||||
| "bedrock-converse-stream"
|
||||
| "google-generative-ai"
|
||||
| "google-gemini-cli"
|
||||
| "google-vertex";
|
||||
|
||||
export type Api = KnownApi | (string & {});
|
||||
|
||||
export type KnownProvider =
|
||||
| "amazon-bedrock"
|
||||
| "anthropic"
|
||||
| "google"
|
||||
| "google-gemini-cli"
|
||||
| "google-antigravity"
|
||||
| "google-vertex"
|
||||
| "openai"
|
||||
| "azure-openai-responses"
|
||||
| "openai-codex"
|
||||
| "github-copilot"
|
||||
| "xai"
|
||||
| "groq"
|
||||
| "cerebras"
|
||||
| "openrouter"
|
||||
| "vercel-ai-gateway"
|
||||
| "zai"
|
||||
| "mistral"
|
||||
| "minimax"
|
||||
| "minimax-cn"
|
||||
| "huggingface"
|
||||
| "opencode"
|
||||
| "opencode-go"
|
||||
| "kimi-coding";
|
||||
export type Provider = KnownProvider | string;
|
||||
|
||||
export type ThinkingLevel = "minimal" | "low" | "medium" | "high" | "xhigh";
|
||||
|
||||
/** Token budgets for each thinking level (token-based providers only) */
|
||||
export interface ThinkingBudgets {
|
||||
minimal?: number;
|
||||
low?: number;
|
||||
medium?: number;
|
||||
high?: number;
|
||||
}
|
||||
|
||||
// Base options all providers share
|
||||
export type CacheRetention = "none" | "short" | "long";
|
||||
|
||||
export type Transport = "sse" | "websocket" | "auto";
|
||||
|
||||
export interface StreamOptions {
|
||||
temperature?: number;
|
||||
maxTokens?: number;
|
||||
signal?: AbortSignal;
|
||||
apiKey?: string;
|
||||
/**
|
||||
* Preferred transport for providers that support multiple transports.
|
||||
* Providers that do not support this option ignore it.
|
||||
*/
|
||||
transport?: Transport;
|
||||
/**
|
||||
* Prompt cache retention preference. Providers map this to their supported values.
|
||||
* Default: "short".
|
||||
*/
|
||||
cacheRetention?: CacheRetention;
|
||||
/**
|
||||
* Optional session identifier for providers that support session-based caching.
|
||||
* Providers can use this to enable prompt caching, request routing, or other
|
||||
* session-aware features. Ignored by providers that don't support it.
|
||||
*/
|
||||
sessionId?: string;
|
||||
/**
|
||||
* Optional callback for inspecting or replacing provider payloads before sending.
|
||||
* Return undefined to keep the payload unchanged.
|
||||
*/
|
||||
onPayload?: (payload: unknown, model: Model<Api>) => unknown | undefined | Promise<unknown | undefined>;
|
||||
/**
|
||||
* Optional custom HTTP headers to include in API requests.
|
||||
* Merged with provider defaults; can override default headers.
|
||||
* Not supported by all providers (e.g., AWS Bedrock uses SDK auth).
|
||||
*/
|
||||
headers?: Record<string, string>;
|
||||
/**
|
||||
* Maximum delay in milliseconds to wait for a retry when the server requests a long wait.
|
||||
* If the server's requested delay exceeds this value, the request fails immediately
|
||||
* with an error containing the requested delay, allowing higher-level retry logic
|
||||
* to handle it with user visibility.
|
||||
* Default: 60000 (60 seconds). Set to 0 to disable the cap.
|
||||
*/
|
||||
maxRetryDelayMs?: number;
|
||||
/**
|
||||
* Optional metadata to include in API requests.
|
||||
* Providers extract the fields they understand and ignore the rest.
|
||||
* For example, Anthropic uses `user_id` for abuse tracking and rate limiting.
|
||||
*/
|
||||
metadata?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export type ProviderStreamOptions = StreamOptions & Record<string, unknown>;
|
||||
|
||||
// Unified options with reasoning passed to streamSimple() and completeSimple()
|
||||
export interface SimpleStreamOptions extends StreamOptions {
|
||||
reasoning?: ThinkingLevel;
|
||||
/** Custom token budgets for thinking levels (token-based providers only) */
|
||||
thinkingBudgets?: ThinkingBudgets;
|
||||
}
|
||||
|
||||
// Generic StreamFunction with typed options
|
||||
export type StreamFunction<TApi extends Api = Api, TOptions extends StreamOptions = StreamOptions> = (
|
||||
model: Model<TApi>,
|
||||
context: Context,
|
||||
options?: TOptions,
|
||||
) => AssistantMessageEventStream;
|
||||
|
||||
export interface TextSignatureV1 {
|
||||
v: 1;
|
||||
id: string;
|
||||
phase?: "commentary" | "final_answer";
|
||||
}
|
||||
|
||||
export interface TextContent {
|
||||
type: "text";
|
||||
text: string;
|
||||
textSignature?: string; // e.g., for OpenAI responses, message metadata (legacy id string or TextSignatureV1 JSON)
|
||||
}
|
||||
|
||||
export interface ThinkingContent {
|
||||
type: "thinking";
|
||||
thinking: string;
|
||||
thinkingSignature?: string; // e.g., for OpenAI responses, the reasoning item ID
|
||||
/** When true, the thinking content was redacted by safety filters. The opaque
|
||||
* encrypted payload is stored in `thinkingSignature` so it can be passed back
|
||||
* to the API for multi-turn continuity. */
|
||||
redacted?: boolean;
|
||||
}
|
||||
|
||||
export interface ImageContent {
|
||||
type: "image";
|
||||
data: string; // base64 encoded image data
|
||||
mimeType: string; // e.g., "image/jpeg", "image/png"
|
||||
}
|
||||
|
||||
export interface ToolCall {
|
||||
type: "toolCall";
|
||||
id: string;
|
||||
name: string;
|
||||
arguments: Record<string, any>;
|
||||
thoughtSignature?: string; // Google-specific: opaque signature for reusing thought context
|
||||
}
|
||||
|
||||
export interface Usage {
|
||||
input: number;
|
||||
output: number;
|
||||
cacheRead: number;
|
||||
cacheWrite: number;
|
||||
totalTokens: number;
|
||||
cost: {
|
||||
input: number;
|
||||
output: number;
|
||||
cacheRead: number;
|
||||
cacheWrite: number;
|
||||
total: number;
|
||||
};
|
||||
}
|
||||
|
||||
export type StopReason = "stop" | "length" | "toolUse" | "error" | "aborted";
|
||||
|
||||
export interface UserMessage {
|
||||
role: "user";
|
||||
content: string | (TextContent | ImageContent)[];
|
||||
timestamp: number; // Unix timestamp in milliseconds
|
||||
}
|
||||
|
||||
export interface AssistantMessage {
|
||||
role: "assistant";
|
||||
content: (TextContent | ThinkingContent | ToolCall)[];
|
||||
api: Api;
|
||||
provider: Provider;
|
||||
model: string;
|
||||
usage: Usage;
|
||||
stopReason: StopReason;
|
||||
errorMessage?: string;
|
||||
timestamp: number; // Unix timestamp in milliseconds
|
||||
}
|
||||
|
||||
export interface ToolResultMessage<TDetails = any> {
|
||||
role: "toolResult";
|
||||
toolCallId: string;
|
||||
toolName: string;
|
||||
content: (TextContent | ImageContent)[]; // Supports text and images
|
||||
details?: TDetails;
|
||||
isError: boolean;
|
||||
timestamp: number; // Unix timestamp in milliseconds
|
||||
}
|
||||
|
||||
export type Message = UserMessage | AssistantMessage | ToolResultMessage;
|
||||
|
||||
import type { TSchema } from "@sinclair/typebox";
|
||||
|
||||
export interface Tool<TParameters extends TSchema = TSchema> {
|
||||
name: string;
|
||||
description: string;
|
||||
parameters: TParameters;
|
||||
}
|
||||
|
||||
export interface Context {
|
||||
systemPrompt?: string;
|
||||
messages: Message[];
|
||||
tools?: Tool[];
|
||||
}
|
||||
|
||||
export type AssistantMessageEvent =
|
||||
| { type: "start"; partial: AssistantMessage }
|
||||
| { type: "text_start"; contentIndex: number; partial: AssistantMessage }
|
||||
| { type: "text_delta"; contentIndex: number; delta: string; partial: AssistantMessage }
|
||||
| { type: "text_end"; contentIndex: number; content: string; partial: AssistantMessage }
|
||||
| { type: "thinking_start"; contentIndex: number; partial: AssistantMessage }
|
||||
| { type: "thinking_delta"; contentIndex: number; delta: string; partial: AssistantMessage }
|
||||
| { type: "thinking_end"; contentIndex: number; content: string; partial: AssistantMessage }
|
||||
| { type: "toolcall_start"; contentIndex: number; partial: AssistantMessage }
|
||||
| { type: "toolcall_delta"; contentIndex: number; delta: string; partial: AssistantMessage }
|
||||
| { type: "toolcall_end"; contentIndex: number; toolCall: ToolCall; partial: AssistantMessage }
|
||||
| { type: "done"; reason: Extract<StopReason, "stop" | "length" | "toolUse">; message: AssistantMessage }
|
||||
| { type: "error"; reason: Extract<StopReason, "aborted" | "error">; error: AssistantMessage };
|
||||
|
||||
/**
|
||||
* Compatibility settings for OpenAI-compatible completions APIs.
|
||||
* Use this to override URL-based auto-detection for custom providers.
|
||||
*/
|
||||
export interface OpenAICompletionsCompat {
|
||||
/** Whether the provider supports the `store` field. Default: auto-detected from URL. */
|
||||
supportsStore?: boolean;
|
||||
/** Whether the provider supports the `developer` role (vs `system`). Default: auto-detected from URL. */
|
||||
supportsDeveloperRole?: boolean;
|
||||
/** Whether the provider supports `reasoning_effort`. Default: auto-detected from URL. */
|
||||
supportsReasoningEffort?: boolean;
|
||||
/** Optional mapping from pi-ai reasoning levels to provider/model-specific `reasoning_effort` values. */
|
||||
reasoningEffortMap?: Partial<Record<ThinkingLevel, string>>;
|
||||
/** Whether the provider supports `stream_options: { include_usage: true }` for token usage in streaming responses. Default: true. */
|
||||
supportsUsageInStreaming?: boolean;
|
||||
/** Which field to use for max tokens. Default: auto-detected from URL. */
|
||||
maxTokensField?: "max_completion_tokens" | "max_tokens";
|
||||
/** Whether tool results require the `name` field. Default: auto-detected from URL. */
|
||||
requiresToolResultName?: boolean;
|
||||
/** Whether a user message after tool results requires an assistant message in between. Default: auto-detected from URL. */
|
||||
requiresAssistantAfterToolResult?: boolean;
|
||||
/** Whether thinking blocks must be converted to text blocks with <thinking> delimiters. Default: auto-detected from URL. */
|
||||
requiresThinkingAsText?: boolean;
|
||||
/** Format for reasoning/thinking parameter. "openai" uses reasoning_effort, "zai" uses thinking: { type: "enabled" }, "qwen" uses enable_thinking: boolean. Default: "openai". */
|
||||
thinkingFormat?: "openai" | "zai" | "qwen";
|
||||
/** OpenRouter-specific routing preferences. Only used when baseUrl points to OpenRouter. */
|
||||
openRouterRouting?: OpenRouterRouting;
|
||||
/** Vercel AI Gateway routing preferences. Only used when baseUrl points to Vercel AI Gateway. */
|
||||
vercelGatewayRouting?: VercelGatewayRouting;
|
||||
/** Whether the provider supports the `strict` field in tool definitions. Default: true. */
|
||||
supportsStrictMode?: boolean;
|
||||
}
|
||||
|
||||
/** Compatibility settings for OpenAI Responses APIs. */
|
||||
export interface OpenAIResponsesCompat {
|
||||
// Reserved for future use
|
||||
}
|
||||
|
||||
/**
|
||||
* OpenRouter provider routing preferences.
|
||||
* Controls which upstream providers OpenRouter routes requests to.
|
||||
* @see https://openrouter.ai/docs/provider-routing
|
||||
*/
|
||||
export interface OpenRouterRouting {
|
||||
/** List of provider slugs to exclusively use for this request (e.g., ["amazon-bedrock", "anthropic"]). */
|
||||
only?: string[];
|
||||
/** List of provider slugs to try in order (e.g., ["anthropic", "openai"]). */
|
||||
order?: string[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Vercel AI Gateway routing preferences.
|
||||
* Controls which upstream providers the gateway routes requests to.
|
||||
* @see https://vercel.com/docs/ai-gateway/models-and-providers/provider-options
|
||||
*/
|
||||
export interface VercelGatewayRouting {
|
||||
/** List of provider slugs to exclusively use for this request (e.g., ["bedrock", "anthropic"]). */
|
||||
only?: string[];
|
||||
/** List of provider slugs to try in order (e.g., ["anthropic", "openai"]). */
|
||||
order?: string[];
|
||||
}
|
||||
|
||||
// Model interface for the unified model system
|
||||
export interface Model<TApi extends Api> {
|
||||
id: string;
|
||||
name: string;
|
||||
api: TApi;
|
||||
provider: Provider;
|
||||
baseUrl: string;
|
||||
reasoning: boolean;
|
||||
input: ("text" | "image")[];
|
||||
cost: {
|
||||
input: number; // $/million tokens
|
||||
output: number; // $/million tokens
|
||||
cacheRead: number; // $/million tokens
|
||||
cacheWrite: number; // $/million tokens
|
||||
};
|
||||
contextWindow: number;
|
||||
maxTokens: number;
|
||||
headers?: Record<string, string>;
|
||||
/** Compatibility overrides for OpenAI-compatible APIs. If not set, auto-detected from baseUrl. */
|
||||
compat?: TApi extends "openai-completions"
|
||||
? OpenAICompletionsCompat
|
||||
: TApi extends "openai-responses"
|
||||
? OpenAIResponsesCompat
|
||||
: never;
|
||||
}
|
||||
87
packages/pi-ai/src/utils/event-stream.ts
Normal file
87
packages/pi-ai/src/utils/event-stream.ts
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
import type { AssistantMessage, AssistantMessageEvent } from "../types.js";
|
||||
|
||||
// Generic event stream class for async iteration
|
||||
export class EventStream<T, R = T> implements AsyncIterable<T> {
|
||||
private queue: T[] = [];
|
||||
private waiting: ((value: IteratorResult<T>) => void)[] = [];
|
||||
private done = false;
|
||||
private finalResultPromise: Promise<R>;
|
||||
private resolveFinalResult!: (result: R) => void;
|
||||
|
||||
constructor(
|
||||
private isComplete: (event: T) => boolean,
|
||||
private extractResult: (event: T) => R,
|
||||
) {
|
||||
this.finalResultPromise = new Promise((resolve) => {
|
||||
this.resolveFinalResult = resolve;
|
||||
});
|
||||
}
|
||||
|
||||
push(event: T): void {
|
||||
if (this.done) return;
|
||||
|
||||
if (this.isComplete(event)) {
|
||||
this.done = true;
|
||||
this.resolveFinalResult(this.extractResult(event));
|
||||
}
|
||||
|
||||
// Deliver to waiting consumer or queue it
|
||||
const waiter = this.waiting.shift();
|
||||
if (waiter) {
|
||||
waiter({ value: event, done: false });
|
||||
} else {
|
||||
this.queue.push(event);
|
||||
}
|
||||
}
|
||||
|
||||
end(result?: R): void {
|
||||
this.done = true;
|
||||
if (result !== undefined) {
|
||||
this.resolveFinalResult(result);
|
||||
}
|
||||
// Notify all waiting consumers that we're done
|
||||
while (this.waiting.length > 0) {
|
||||
const waiter = this.waiting.shift()!;
|
||||
waiter({ value: undefined as any, done: true });
|
||||
}
|
||||
}
|
||||
|
||||
async *[Symbol.asyncIterator](): AsyncIterator<T> {
|
||||
while (true) {
|
||||
if (this.queue.length > 0) {
|
||||
yield this.queue.shift()!;
|
||||
} else if (this.done) {
|
||||
return;
|
||||
} else {
|
||||
const result = await new Promise<IteratorResult<T>>((resolve) => this.waiting.push(resolve));
|
||||
if (result.done) return;
|
||||
yield result.value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result(): Promise<R> {
|
||||
return this.finalResultPromise;
|
||||
}
|
||||
}
|
||||
|
||||
export class AssistantMessageEventStream extends EventStream<AssistantMessageEvent, AssistantMessage> {
|
||||
constructor() {
|
||||
super(
|
||||
(event) => event.type === "done" || event.type === "error",
|
||||
(event) => {
|
||||
if (event.type === "done") {
|
||||
return event.message;
|
||||
} else if (event.type === "error") {
|
||||
return event.error;
|
||||
}
|
||||
throw new Error("Unexpected event type for final result");
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/** Factory function for AssistantMessageEventStream (for use in extensions) */
|
||||
export function createAssistantMessageEventStream(): AssistantMessageEventStream {
|
||||
return new AssistantMessageEventStream();
|
||||
}
|
||||
13
packages/pi-ai/src/utils/hash.ts
Normal file
13
packages/pi-ai/src/utils/hash.ts
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
/** Fast deterministic hash to shorten long strings */
|
||||
export function shortHash(str: string): string {
|
||||
let h1 = 0xdeadbeef;
|
||||
let h2 = 0x41c6ce57;
|
||||
for (let i = 0; i < str.length; i++) {
|
||||
const ch = str.charCodeAt(i);
|
||||
h1 = Math.imul(h1 ^ ch, 2654435761);
|
||||
h2 = Math.imul(h2 ^ ch, 1597334677);
|
||||
}
|
||||
h1 = Math.imul(h1 ^ (h1 >>> 16), 2246822507) ^ Math.imul(h2 ^ (h2 >>> 13), 3266489909);
|
||||
h2 = Math.imul(h2 ^ (h2 >>> 16), 2246822507) ^ Math.imul(h1 ^ (h1 >>> 13), 3266489909);
|
||||
return (h2 >>> 0).toString(36) + (h1 >>> 0).toString(36);
|
||||
}
|
||||
28
packages/pi-ai/src/utils/json-parse.ts
Normal file
28
packages/pi-ai/src/utils/json-parse.ts
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
import { parse as partialParse } from "partial-json";
|
||||
|
||||
/**
|
||||
* Attempts to parse potentially incomplete JSON during streaming.
|
||||
* Always returns a valid object, even if the JSON is incomplete.
|
||||
*
|
||||
* @param partialJson The partial JSON string from streaming
|
||||
* @returns Parsed object or empty object if parsing fails
|
||||
*/
|
||||
export function parseStreamingJson<T = any>(partialJson: string | undefined): T {
|
||||
if (!partialJson || partialJson.trim() === "") {
|
||||
return {} as T;
|
||||
}
|
||||
|
||||
// Try standard parsing first (fastest for complete JSON)
|
||||
try {
|
||||
return JSON.parse(partialJson) as T;
|
||||
} catch {
|
||||
// Try partial-json for incomplete JSON
|
||||
try {
|
||||
const result = partialParse(partialJson);
|
||||
return (result ?? {}) as T;
|
||||
} catch {
|
||||
// If all parsing fails, return empty object
|
||||
return {} as T;
|
||||
}
|
||||
}
|
||||
}
|
||||
138
packages/pi-ai/src/utils/oauth/anthropic.ts
Normal file
138
packages/pi-ai/src/utils/oauth/anthropic.ts
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
/**
|
||||
* Anthropic OAuth flow (Claude Pro/Max)
|
||||
*/
|
||||
|
||||
import { generatePKCE } from "./pkce.js";
|
||||
import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.js";
|
||||
|
||||
const decode = (s: string) => atob(s);
|
||||
const CLIENT_ID = decode("OWQxYzI1MGEtZTYxYi00NGQ5LTg4ZWQtNTk0NGQxOTYyZjVl");
|
||||
const AUTHORIZE_URL = "https://claude.ai/oauth/authorize";
|
||||
const TOKEN_URL = "https://console.anthropic.com/v1/oauth/token";
|
||||
const REDIRECT_URI = "https://console.anthropic.com/oauth/code/callback";
|
||||
const SCOPES = "org:create_api_key user:profile user:inference";
|
||||
|
||||
/**
|
||||
* Login with Anthropic OAuth (device code flow)
|
||||
*
|
||||
* @param onAuthUrl - Callback to handle the authorization URL (e.g., open browser)
|
||||
* @param onPromptCode - Callback to prompt user for the authorization code
|
||||
*/
|
||||
export async function loginAnthropic(
|
||||
onAuthUrl: (url: string) => void,
|
||||
onPromptCode: () => Promise<string>,
|
||||
): Promise<OAuthCredentials> {
|
||||
const { verifier, challenge } = await generatePKCE();
|
||||
|
||||
// Build authorization URL
|
||||
const authParams = new URLSearchParams({
|
||||
code: "true",
|
||||
client_id: CLIENT_ID,
|
||||
response_type: "code",
|
||||
redirect_uri: REDIRECT_URI,
|
||||
scope: SCOPES,
|
||||
code_challenge: challenge,
|
||||
code_challenge_method: "S256",
|
||||
state: verifier,
|
||||
});
|
||||
|
||||
const authUrl = `${AUTHORIZE_URL}?${authParams.toString()}`;
|
||||
|
||||
// Notify caller with URL to open
|
||||
onAuthUrl(authUrl);
|
||||
|
||||
// Wait for user to paste authorization code (format: code#state)
|
||||
const authCode = await onPromptCode();
|
||||
const splits = authCode.split("#");
|
||||
const code = splits[0];
|
||||
const state = splits[1];
|
||||
|
||||
// Exchange code for tokens
|
||||
const tokenResponse = await fetch(TOKEN_URL, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
grant_type: "authorization_code",
|
||||
client_id: CLIENT_ID,
|
||||
code: code,
|
||||
state: state,
|
||||
redirect_uri: REDIRECT_URI,
|
||||
code_verifier: verifier,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!tokenResponse.ok) {
|
||||
const error = await tokenResponse.text();
|
||||
throw new Error(`Token exchange failed: ${error}`);
|
||||
}
|
||||
|
||||
const tokenData = (await tokenResponse.json()) as {
|
||||
access_token: string;
|
||||
refresh_token: string;
|
||||
expires_in: number;
|
||||
};
|
||||
|
||||
// Calculate expiry time (current time + expires_in seconds - 5 min buffer)
|
||||
const expiresAt = Date.now() + tokenData.expires_in * 1000 - 5 * 60 * 1000;
|
||||
|
||||
// Save credentials
|
||||
return {
|
||||
refresh: tokenData.refresh_token,
|
||||
access: tokenData.access_token,
|
||||
expires: expiresAt,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh Anthropic OAuth token
|
||||
*/
|
||||
export async function refreshAnthropicToken(refreshToken: string): Promise<OAuthCredentials> {
|
||||
const response = await fetch(TOKEN_URL, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
grant_type: "refresh_token",
|
||||
client_id: CLIENT_ID,
|
||||
refresh_token: refreshToken,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text();
|
||||
throw new Error(`Anthropic token refresh failed: ${error}`);
|
||||
}
|
||||
|
||||
const data = (await response.json()) as {
|
||||
access_token: string;
|
||||
refresh_token: string;
|
||||
expires_in: number;
|
||||
};
|
||||
|
||||
return {
|
||||
refresh: data.refresh_token,
|
||||
access: data.access_token,
|
||||
expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000,
|
||||
};
|
||||
}
|
||||
|
||||
export const anthropicOAuthProvider: OAuthProviderInterface = {
|
||||
id: "anthropic",
|
||||
name: "Anthropic (Claude Pro/Max)",
|
||||
|
||||
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
|
||||
return loginAnthropic(
|
||||
(url) => callbacks.onAuth({ url }),
|
||||
() => callbacks.onPrompt({ message: "Paste the authorization code:" }),
|
||||
);
|
||||
},
|
||||
|
||||
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
|
||||
return refreshAnthropicToken(credentials.refresh);
|
||||
},
|
||||
|
||||
getApiKey(credentials: OAuthCredentials): string {
|
||||
return credentials.access;
|
||||
},
|
||||
};
|
||||
381
packages/pi-ai/src/utils/oauth/github-copilot.ts
Normal file
381
packages/pi-ai/src/utils/oauth/github-copilot.ts
Normal file
|
|
@ -0,0 +1,381 @@
|
|||
/**
|
||||
* GitHub Copilot OAuth flow
|
||||
*/
|
||||
|
||||
import { getModels } from "../../models.js";
|
||||
import type { Api, Model } from "../../types.js";
|
||||
import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.js";
|
||||
|
||||
type CopilotCredentials = OAuthCredentials & {
|
||||
enterpriseUrl?: string;
|
||||
};
|
||||
|
||||
const decode = (s: string) => atob(s);
|
||||
const CLIENT_ID = decode("SXYxLmI1MDdhMDhjODdlY2ZlOTg=");
|
||||
|
||||
const COPILOT_HEADERS = {
|
||||
"User-Agent": "GitHubCopilotChat/0.35.0",
|
||||
"Editor-Version": "vscode/1.107.0",
|
||||
"Editor-Plugin-Version": "copilot-chat/0.35.0",
|
||||
"Copilot-Integration-Id": "vscode-chat",
|
||||
} as const;
|
||||
|
||||
type DeviceCodeResponse = {
|
||||
device_code: string;
|
||||
user_code: string;
|
||||
verification_uri: string;
|
||||
interval: number;
|
||||
expires_in: number;
|
||||
};
|
||||
|
||||
type DeviceTokenSuccessResponse = {
|
||||
access_token: string;
|
||||
token_type?: string;
|
||||
scope?: string;
|
||||
};
|
||||
|
||||
type DeviceTokenErrorResponse = {
|
||||
error: string;
|
||||
error_description?: string;
|
||||
interval?: number;
|
||||
};
|
||||
|
||||
export function normalizeDomain(input: string): string | null {
|
||||
const trimmed = input.trim();
|
||||
if (!trimmed) return null;
|
||||
try {
|
||||
const url = trimmed.includes("://") ? new URL(trimmed) : new URL(`https://${trimmed}`);
|
||||
return url.hostname;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function getUrls(domain: string): {
|
||||
deviceCodeUrl: string;
|
||||
accessTokenUrl: string;
|
||||
copilotTokenUrl: string;
|
||||
} {
|
||||
return {
|
||||
deviceCodeUrl: `https://${domain}/login/device/code`,
|
||||
accessTokenUrl: `https://${domain}/login/oauth/access_token`,
|
||||
copilotTokenUrl: `https://api.${domain}/copilot_internal/v2/token`,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse the proxy-ep from a Copilot token and convert to API base URL.
|
||||
* Token format: tid=...;exp=...;proxy-ep=proxy.individual.githubcopilot.com;...
|
||||
* Returns API URL like https://api.individual.githubcopilot.com
|
||||
*/
|
||||
function getBaseUrlFromToken(token: string): string | null {
|
||||
const match = token.match(/proxy-ep=([^;]+)/);
|
||||
if (!match) return null;
|
||||
const proxyHost = match[1];
|
||||
// Convert proxy.xxx to api.xxx
|
||||
const apiHost = proxyHost.replace(/^proxy\./, "api.");
|
||||
return `https://${apiHost}`;
|
||||
}
|
||||
|
||||
export function getGitHubCopilotBaseUrl(token?: string, enterpriseDomain?: string): string {
|
||||
// If we have a token, extract the base URL from proxy-ep
|
||||
if (token) {
|
||||
const urlFromToken = getBaseUrlFromToken(token);
|
||||
if (urlFromToken) return urlFromToken;
|
||||
}
|
||||
// Fallback for enterprise or if token parsing fails
|
||||
if (enterpriseDomain) return `https://copilot-api.${enterpriseDomain}`;
|
||||
return "https://api.individual.githubcopilot.com";
|
||||
}
|
||||
|
||||
async function fetchJson(url: string, init: RequestInit): Promise<unknown> {
|
||||
const response = await fetch(url, init);
|
||||
if (!response.ok) {
|
||||
const text = await response.text();
|
||||
throw new Error(`${response.status} ${response.statusText}: ${text}`);
|
||||
}
|
||||
return response.json();
|
||||
}
|
||||
|
||||
async function startDeviceFlow(domain: string): Promise<DeviceCodeResponse> {
|
||||
const urls = getUrls(domain);
|
||||
const data = await fetchJson(urls.deviceCodeUrl, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
Accept: "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "GitHubCopilotChat/0.35.0",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
client_id: CLIENT_ID,
|
||||
scope: "read:user",
|
||||
}),
|
||||
});
|
||||
|
||||
if (!data || typeof data !== "object") {
|
||||
throw new Error("Invalid device code response");
|
||||
}
|
||||
|
||||
const deviceCode = (data as Record<string, unknown>).device_code;
|
||||
const userCode = (data as Record<string, unknown>).user_code;
|
||||
const verificationUri = (data as Record<string, unknown>).verification_uri;
|
||||
const interval = (data as Record<string, unknown>).interval;
|
||||
const expiresIn = (data as Record<string, unknown>).expires_in;
|
||||
|
||||
if (
|
||||
typeof deviceCode !== "string" ||
|
||||
typeof userCode !== "string" ||
|
||||
typeof verificationUri !== "string" ||
|
||||
typeof interval !== "number" ||
|
||||
typeof expiresIn !== "number"
|
||||
) {
|
||||
throw new Error("Invalid device code response fields");
|
||||
}
|
||||
|
||||
return {
|
||||
device_code: deviceCode,
|
||||
user_code: userCode,
|
||||
verification_uri: verificationUri,
|
||||
interval,
|
||||
expires_in: expiresIn,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Sleep that can be interrupted by an AbortSignal
|
||||
*/
|
||||
function abortableSleep(ms: number, signal?: AbortSignal): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
if (signal?.aborted) {
|
||||
reject(new Error("Login cancelled"));
|
||||
return;
|
||||
}
|
||||
|
||||
const timeout = setTimeout(resolve, ms);
|
||||
|
||||
signal?.addEventListener(
|
||||
"abort",
|
||||
() => {
|
||||
clearTimeout(timeout);
|
||||
reject(new Error("Login cancelled"));
|
||||
},
|
||||
{ once: true },
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
async function pollForGitHubAccessToken(
|
||||
domain: string,
|
||||
deviceCode: string,
|
||||
intervalSeconds: number,
|
||||
expiresIn: number,
|
||||
signal?: AbortSignal,
|
||||
) {
|
||||
const urls = getUrls(domain);
|
||||
const deadline = Date.now() + expiresIn * 1000;
|
||||
let intervalMs = Math.max(1000, Math.floor(intervalSeconds * 1000));
|
||||
|
||||
while (Date.now() < deadline) {
|
||||
if (signal?.aborted) {
|
||||
throw new Error("Login cancelled");
|
||||
}
|
||||
|
||||
const raw = await fetchJson(urls.accessTokenUrl, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
Accept: "application/json",
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "GitHubCopilotChat/0.35.0",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
client_id: CLIENT_ID,
|
||||
device_code: deviceCode,
|
||||
grant_type: "urn:ietf:params:oauth:grant-type:device_code",
|
||||
}),
|
||||
});
|
||||
|
||||
if (raw && typeof raw === "object" && typeof (raw as DeviceTokenSuccessResponse).access_token === "string") {
|
||||
return (raw as DeviceTokenSuccessResponse).access_token;
|
||||
}
|
||||
|
||||
if (raw && typeof raw === "object" && typeof (raw as DeviceTokenErrorResponse).error === "string") {
|
||||
const err = (raw as DeviceTokenErrorResponse).error;
|
||||
if (err === "authorization_pending") {
|
||||
await abortableSleep(intervalMs, signal);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (err === "slow_down") {
|
||||
intervalMs += 5000;
|
||||
await abortableSleep(intervalMs, signal);
|
||||
continue;
|
||||
}
|
||||
|
||||
throw new Error(`Device flow failed: ${err}`);
|
||||
}
|
||||
|
||||
await abortableSleep(intervalMs, signal);
|
||||
}
|
||||
|
||||
throw new Error("Device flow timed out");
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh GitHub Copilot token
|
||||
*/
|
||||
export async function refreshGitHubCopilotToken(
|
||||
refreshToken: string,
|
||||
enterpriseDomain?: string,
|
||||
): Promise<OAuthCredentials> {
|
||||
const domain = enterpriseDomain || "github.com";
|
||||
const urls = getUrls(domain);
|
||||
|
||||
const raw = await fetchJson(urls.copilotTokenUrl, {
|
||||
headers: {
|
||||
Accept: "application/json",
|
||||
Authorization: `Bearer ${refreshToken}`,
|
||||
...COPILOT_HEADERS,
|
||||
},
|
||||
});
|
||||
|
||||
if (!raw || typeof raw !== "object") {
|
||||
throw new Error("Invalid Copilot token response");
|
||||
}
|
||||
|
||||
const token = (raw as Record<string, unknown>).token;
|
||||
const expiresAt = (raw as Record<string, unknown>).expires_at;
|
||||
|
||||
if (typeof token !== "string" || typeof expiresAt !== "number") {
|
||||
throw new Error("Invalid Copilot token response fields");
|
||||
}
|
||||
|
||||
return {
|
||||
refresh: refreshToken,
|
||||
access: token,
|
||||
expires: expiresAt * 1000 - 5 * 60 * 1000,
|
||||
enterpriseUrl: enterpriseDomain,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable a model for the user's GitHub Copilot account.
|
||||
* This is required for some models (like Claude, Grok) before they can be used.
|
||||
*/
|
||||
async function enableGitHubCopilotModel(token: string, modelId: string, enterpriseDomain?: string): Promise<boolean> {
|
||||
const baseUrl = getGitHubCopilotBaseUrl(token, enterpriseDomain);
|
||||
const url = `${baseUrl}/models/${modelId}/policy`;
|
||||
|
||||
try {
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
Authorization: `Bearer ${token}`,
|
||||
...COPILOT_HEADERS,
|
||||
"openai-intent": "chat-policy",
|
||||
"x-interaction-type": "chat-policy",
|
||||
},
|
||||
body: JSON.stringify({ state: "enabled" }),
|
||||
});
|
||||
return response.ok;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Enable all known GitHub Copilot models that may require policy acceptance.
|
||||
* Called after successful login to ensure all models are available.
|
||||
*/
|
||||
async function enableAllGitHubCopilotModels(
|
||||
token: string,
|
||||
enterpriseDomain?: string,
|
||||
onProgress?: (model: string, success: boolean) => void,
|
||||
): Promise<void> {
|
||||
const models = getModels("github-copilot");
|
||||
await Promise.all(
|
||||
models.map(async (model) => {
|
||||
const success = await enableGitHubCopilotModel(token, model.id, enterpriseDomain);
|
||||
onProgress?.(model.id, success);
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Login with GitHub Copilot OAuth (device code flow)
|
||||
*
|
||||
* @param options.onAuth - Callback with URL and optional instructions (user code)
|
||||
* @param options.onPrompt - Callback to prompt user for input
|
||||
* @param options.onProgress - Optional progress callback
|
||||
* @param options.signal - Optional AbortSignal for cancellation
|
||||
*/
|
||||
export async function loginGitHubCopilot(options: {
|
||||
onAuth: (url: string, instructions?: string) => void;
|
||||
onPrompt: (prompt: { message: string; placeholder?: string; allowEmpty?: boolean }) => Promise<string>;
|
||||
onProgress?: (message: string) => void;
|
||||
signal?: AbortSignal;
|
||||
}): Promise<OAuthCredentials> {
|
||||
const input = await options.onPrompt({
|
||||
message: "GitHub Enterprise URL/domain (blank for github.com)",
|
||||
placeholder: "company.ghe.com",
|
||||
allowEmpty: true,
|
||||
});
|
||||
|
||||
if (options.signal?.aborted) {
|
||||
throw new Error("Login cancelled");
|
||||
}
|
||||
|
||||
const trimmed = input.trim();
|
||||
const enterpriseDomain = normalizeDomain(input);
|
||||
if (trimmed && !enterpriseDomain) {
|
||||
throw new Error("Invalid GitHub Enterprise URL/domain");
|
||||
}
|
||||
const domain = enterpriseDomain || "github.com";
|
||||
|
||||
const device = await startDeviceFlow(domain);
|
||||
options.onAuth(device.verification_uri, `Enter code: ${device.user_code}`);
|
||||
|
||||
const githubAccessToken = await pollForGitHubAccessToken(
|
||||
domain,
|
||||
device.device_code,
|
||||
device.interval,
|
||||
device.expires_in,
|
||||
options.signal,
|
||||
);
|
||||
const credentials = await refreshGitHubCopilotToken(githubAccessToken, enterpriseDomain ?? undefined);
|
||||
|
||||
// Enable all models after successful login
|
||||
options.onProgress?.("Enabling models...");
|
||||
await enableAllGitHubCopilotModels(credentials.access, enterpriseDomain ?? undefined);
|
||||
return credentials;
|
||||
}
|
||||
|
||||
export const githubCopilotOAuthProvider: OAuthProviderInterface = {
|
||||
id: "github-copilot",
|
||||
name: "GitHub Copilot",
|
||||
|
||||
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
|
||||
return loginGitHubCopilot({
|
||||
onAuth: (url, instructions) => callbacks.onAuth({ url, instructions }),
|
||||
onPrompt: callbacks.onPrompt,
|
||||
onProgress: callbacks.onProgress,
|
||||
signal: callbacks.signal,
|
||||
});
|
||||
},
|
||||
|
||||
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
|
||||
const creds = credentials as CopilotCredentials;
|
||||
return refreshGitHubCopilotToken(creds.refresh, creds.enterpriseUrl);
|
||||
},
|
||||
|
||||
getApiKey(credentials: OAuthCredentials): string {
|
||||
return credentials.access;
|
||||
},
|
||||
|
||||
modifyModels(models: Model<Api>[], credentials: OAuthCredentials): Model<Api>[] {
|
||||
const creds = credentials as CopilotCredentials;
|
||||
const domain = creds.enterpriseUrl ? (normalizeDomain(creds.enterpriseUrl) ?? undefined) : undefined;
|
||||
const baseUrl = getGitHubCopilotBaseUrl(creds.access, domain);
|
||||
return models.map((m) => (m.provider === "github-copilot" ? { ...m, baseUrl } : m));
|
||||
},
|
||||
};
|
||||
457
packages/pi-ai/src/utils/oauth/google-antigravity.ts
Normal file
457
packages/pi-ai/src/utils/oauth/google-antigravity.ts
Normal file
|
|
@ -0,0 +1,457 @@
|
|||
/**
|
||||
* Antigravity OAuth flow (Gemini 3, Claude, GPT-OSS via Google Cloud)
|
||||
* Uses different OAuth credentials than google-gemini-cli for access to additional models.
|
||||
*
|
||||
* NOTE: This module uses Node.js http.createServer for the OAuth callback.
|
||||
* It is only intended for CLI use, not browser environments.
|
||||
*/
|
||||
|
||||
import type { Server } from "node:http";
|
||||
import { generatePKCE } from "./pkce.js";
|
||||
import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.js";
|
||||
|
||||
type AntigravityCredentials = OAuthCredentials & {
|
||||
projectId: string;
|
||||
};
|
||||
|
||||
let _createServer: typeof import("node:http").createServer | null = null;
|
||||
let _httpImportPromise: Promise<void> | null = null;
|
||||
if (typeof process !== "undefined" && (process.versions?.node || process.versions?.bun)) {
|
||||
_httpImportPromise = import("node:http").then((m) => {
|
||||
_createServer = m.createServer;
|
||||
});
|
||||
}
|
||||
|
||||
// Antigravity OAuth credentials (different from Gemini CLI)
|
||||
const decode = (s: string) => atob(s);
|
||||
const CLIENT_ID = decode(
|
||||
"MTA3MTAwNjA2MDU5MS10bWhzc2luMmgyMWxjcmUyMzV2dG9sb2poNGc0MDNlcC5hcHBzLmdvb2dsZXVzZXJjb250ZW50LmNvbQ==",
|
||||
);
|
||||
const CLIENT_SECRET = decode("R09DU1BYLUs1OEZXUjQ4NkxkTEoxbUxCOHNYQzR6NnFEQWY=");
|
||||
const REDIRECT_URI = "http://localhost:51121/oauth-callback";
|
||||
|
||||
// Antigravity requires additional scopes
|
||||
const SCOPES = [
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
"https://www.googleapis.com/auth/cclog",
|
||||
"https://www.googleapis.com/auth/experimentsandconfigs",
|
||||
];
|
||||
|
||||
const AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth";
|
||||
const TOKEN_URL = "https://oauth2.googleapis.com/token";
|
||||
|
||||
// Fallback project ID when discovery fails
|
||||
const DEFAULT_PROJECT_ID = "rising-fact-p41fc";
|
||||
|
||||
type CallbackServerInfo = {
|
||||
server: Server;
|
||||
cancelWait: () => void;
|
||||
waitForCode: () => Promise<{ code: string; state: string } | null>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Start a local HTTP server to receive the OAuth callback
|
||||
*/
|
||||
async function getNodeCreateServer(): Promise<typeof import("node:http").createServer> {
|
||||
if (_createServer) return _createServer;
|
||||
if (_httpImportPromise) {
|
||||
await _httpImportPromise;
|
||||
}
|
||||
if (_createServer) return _createServer;
|
||||
throw new Error("Antigravity OAuth is only available in Node.js environments");
|
||||
}
|
||||
|
||||
async function startCallbackServer(): Promise<CallbackServerInfo> {
|
||||
const createServer = await getNodeCreateServer();
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
let result: { code: string; state: string } | null = null;
|
||||
let cancelled = false;
|
||||
|
||||
const server = createServer((req, res) => {
|
||||
const url = new URL(req.url || "", `http://localhost:51121`);
|
||||
|
||||
if (url.pathname === "/oauth-callback") {
|
||||
const code = url.searchParams.get("code");
|
||||
const state = url.searchParams.get("state");
|
||||
const error = url.searchParams.get("error");
|
||||
|
||||
if (error) {
|
||||
res.writeHead(400, { "Content-Type": "text/html" });
|
||||
res.end(
|
||||
`<html><body><h1>Authentication Failed</h1><p>Error: ${error}</p><p>You can close this window.</p></body></html>`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (code && state) {
|
||||
res.writeHead(200, { "Content-Type": "text/html" });
|
||||
res.end(
|
||||
`<html><body><h1>Authentication Successful</h1><p>You can close this window and return to the terminal.</p></body></html>`,
|
||||
);
|
||||
result = { code, state };
|
||||
} else {
|
||||
res.writeHead(400, { "Content-Type": "text/html" });
|
||||
res.end(
|
||||
`<html><body><h1>Authentication Failed</h1><p>Missing code or state parameter.</p></body></html>`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
res.writeHead(404);
|
||||
res.end();
|
||||
}
|
||||
});
|
||||
|
||||
server.on("error", (err) => {
|
||||
reject(err);
|
||||
});
|
||||
|
||||
server.listen(51121, "127.0.0.1", () => {
|
||||
resolve({
|
||||
server,
|
||||
cancelWait: () => {
|
||||
cancelled = true;
|
||||
},
|
||||
waitForCode: async () => {
|
||||
const sleep = () => new Promise((r) => setTimeout(r, 100));
|
||||
while (!result && !cancelled) {
|
||||
await sleep();
|
||||
}
|
||||
return result;
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse redirect URL to extract code and state
|
||||
*/
|
||||
function parseRedirectUrl(input: string): { code?: string; state?: string } {
|
||||
const value = input.trim();
|
||||
if (!value) return {};
|
||||
|
||||
try {
|
||||
const url = new URL(value);
|
||||
return {
|
||||
code: url.searchParams.get("code") ?? undefined,
|
||||
state: url.searchParams.get("state") ?? undefined,
|
||||
};
|
||||
} catch {
|
||||
// Not a URL, return empty
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
interface LoadCodeAssistPayload {
|
||||
cloudaicompanionProject?: string | { id?: string };
|
||||
currentTier?: { id?: string };
|
||||
allowedTiers?: Array<{ id?: string; isDefault?: boolean }>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Discover or provision a project for the user
|
||||
*/
|
||||
async function discoverProject(accessToken: string, onProgress?: (message: string) => void): Promise<string> {
|
||||
const headers = {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "google-api-nodejs-client/9.15.1",
|
||||
"X-Goog-Api-Client": "google-cloud-sdk vscode_cloudshelleditor/0.1",
|
||||
"Client-Metadata": JSON.stringify({
|
||||
ideType: "IDE_UNSPECIFIED",
|
||||
platform: "PLATFORM_UNSPECIFIED",
|
||||
pluginType: "GEMINI",
|
||||
}),
|
||||
};
|
||||
|
||||
// Try endpoints in order: prod first, then sandbox
|
||||
const endpoints = ["https://cloudcode-pa.googleapis.com", "https://daily-cloudcode-pa.sandbox.googleapis.com"];
|
||||
|
||||
onProgress?.("Checking for existing project...");
|
||||
|
||||
for (const endpoint of endpoints) {
|
||||
try {
|
||||
const loadResponse = await fetch(`${endpoint}/v1internal:loadCodeAssist`, {
|
||||
method: "POST",
|
||||
headers,
|
||||
body: JSON.stringify({
|
||||
metadata: {
|
||||
ideType: "IDE_UNSPECIFIED",
|
||||
platform: "PLATFORM_UNSPECIFIED",
|
||||
pluginType: "GEMINI",
|
||||
},
|
||||
}),
|
||||
});
|
||||
|
||||
if (loadResponse.ok) {
|
||||
const data = (await loadResponse.json()) as LoadCodeAssistPayload;
|
||||
|
||||
// Handle both string and object formats
|
||||
if (typeof data.cloudaicompanionProject === "string" && data.cloudaicompanionProject) {
|
||||
return data.cloudaicompanionProject;
|
||||
}
|
||||
if (
|
||||
data.cloudaicompanionProject &&
|
||||
typeof data.cloudaicompanionProject === "object" &&
|
||||
data.cloudaicompanionProject.id
|
||||
) {
|
||||
return data.cloudaicompanionProject.id;
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Try next endpoint
|
||||
}
|
||||
}
|
||||
|
||||
// Use fallback project ID
|
||||
onProgress?.("Using default project...");
|
||||
return DEFAULT_PROJECT_ID;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get user email from the access token
|
||||
*/
|
||||
async function getUserEmail(accessToken: string): Promise<string | undefined> {
|
||||
try {
|
||||
const response = await fetch("https://www.googleapis.com/oauth2/v1/userinfo?alt=json", {
|
||||
headers: {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const data = (await response.json()) as { email?: string };
|
||||
return data.email;
|
||||
}
|
||||
} catch {
|
||||
// Ignore errors, email is optional
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh Antigravity token
|
||||
*/
|
||||
export async function refreshAntigravityToken(refreshToken: string, projectId: string): Promise<OAuthCredentials> {
|
||||
const response = await fetch(TOKEN_URL, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/x-www-form-urlencoded" },
|
||||
body: new URLSearchParams({
|
||||
client_id: CLIENT_ID,
|
||||
client_secret: CLIENT_SECRET,
|
||||
refresh_token: refreshToken,
|
||||
grant_type: "refresh_token",
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text();
|
||||
throw new Error(`Antigravity token refresh failed: ${error}`);
|
||||
}
|
||||
|
||||
const data = (await response.json()) as {
|
||||
access_token: string;
|
||||
expires_in: number;
|
||||
refresh_token?: string;
|
||||
};
|
||||
|
||||
return {
|
||||
refresh: data.refresh_token || refreshToken,
|
||||
access: data.access_token,
|
||||
expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000,
|
||||
projectId,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Login with Antigravity OAuth
|
||||
*
|
||||
* @param onAuth - Callback with URL and optional instructions
|
||||
* @param onProgress - Optional progress callback
|
||||
* @param onManualCodeInput - Optional promise that resolves with user-pasted redirect URL.
|
||||
* Races with browser callback - whichever completes first wins.
|
||||
*/
|
||||
export async function loginAntigravity(
|
||||
onAuth: (info: { url: string; instructions?: string }) => void,
|
||||
onProgress?: (message: string) => void,
|
||||
onManualCodeInput?: () => Promise<string>,
|
||||
): Promise<OAuthCredentials> {
|
||||
const { verifier, challenge } = await generatePKCE();
|
||||
|
||||
// Start local server for callback
|
||||
onProgress?.("Starting local server for OAuth callback...");
|
||||
const server = await startCallbackServer();
|
||||
|
||||
let code: string | undefined;
|
||||
|
||||
try {
|
||||
// Build authorization URL
|
||||
const authParams = new URLSearchParams({
|
||||
client_id: CLIENT_ID,
|
||||
response_type: "code",
|
||||
redirect_uri: REDIRECT_URI,
|
||||
scope: SCOPES.join(" "),
|
||||
code_challenge: challenge,
|
||||
code_challenge_method: "S256",
|
||||
state: verifier,
|
||||
access_type: "offline",
|
||||
prompt: "consent",
|
||||
});
|
||||
|
||||
const authUrl = `${AUTH_URL}?${authParams.toString()}`;
|
||||
|
||||
// Notify caller with URL to open
|
||||
onAuth({
|
||||
url: authUrl,
|
||||
instructions: "Complete the sign-in in your browser.",
|
||||
});
|
||||
|
||||
// Wait for the callback, racing with manual input if provided
|
||||
onProgress?.("Waiting for OAuth callback...");
|
||||
|
||||
if (onManualCodeInput) {
|
||||
// Race between browser callback and manual input
|
||||
let manualInput: string | undefined;
|
||||
let manualError: Error | undefined;
|
||||
const manualPromise = onManualCodeInput()
|
||||
.then((input) => {
|
||||
manualInput = input;
|
||||
server.cancelWait();
|
||||
})
|
||||
.catch((err) => {
|
||||
manualError = err instanceof Error ? err : new Error(String(err));
|
||||
server.cancelWait();
|
||||
});
|
||||
|
||||
const result = await server.waitForCode();
|
||||
|
||||
// If manual input was cancelled, throw that error
|
||||
if (manualError) {
|
||||
throw manualError;
|
||||
}
|
||||
|
||||
if (result?.code) {
|
||||
// Browser callback won - verify state
|
||||
if (result.state !== verifier) {
|
||||
throw new Error("OAuth state mismatch - possible CSRF attack");
|
||||
}
|
||||
code = result.code;
|
||||
} else if (manualInput) {
|
||||
// Manual input won
|
||||
const parsed = parseRedirectUrl(manualInput);
|
||||
if (parsed.state && parsed.state !== verifier) {
|
||||
throw new Error("OAuth state mismatch - possible CSRF attack");
|
||||
}
|
||||
code = parsed.code;
|
||||
}
|
||||
|
||||
// If still no code, wait for manual promise and try that
|
||||
if (!code) {
|
||||
await manualPromise;
|
||||
if (manualError) {
|
||||
throw manualError;
|
||||
}
|
||||
if (manualInput) {
|
||||
const parsed = parseRedirectUrl(manualInput);
|
||||
if (parsed.state && parsed.state !== verifier) {
|
||||
throw new Error("OAuth state mismatch - possible CSRF attack");
|
||||
}
|
||||
code = parsed.code;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Original flow: just wait for callback
|
||||
const result = await server.waitForCode();
|
||||
if (result?.code) {
|
||||
if (result.state !== verifier) {
|
||||
throw new Error("OAuth state mismatch - possible CSRF attack");
|
||||
}
|
||||
code = result.code;
|
||||
}
|
||||
}
|
||||
|
||||
if (!code) {
|
||||
throw new Error("No authorization code received");
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
onProgress?.("Exchanging authorization code for tokens...");
|
||||
const tokenResponse = await fetch(TOKEN_URL, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
},
|
||||
body: new URLSearchParams({
|
||||
client_id: CLIENT_ID,
|
||||
client_secret: CLIENT_SECRET,
|
||||
code,
|
||||
grant_type: "authorization_code",
|
||||
redirect_uri: REDIRECT_URI,
|
||||
code_verifier: verifier,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!tokenResponse.ok) {
|
||||
const error = await tokenResponse.text();
|
||||
throw new Error(`Token exchange failed: ${error}`);
|
||||
}
|
||||
|
||||
const tokenData = (await tokenResponse.json()) as {
|
||||
access_token: string;
|
||||
refresh_token: string;
|
||||
expires_in: number;
|
||||
};
|
||||
|
||||
if (!tokenData.refresh_token) {
|
||||
throw new Error("No refresh token received. Please try again.");
|
||||
}
|
||||
|
||||
// Get user email
|
||||
onProgress?.("Getting user info...");
|
||||
const email = await getUserEmail(tokenData.access_token);
|
||||
|
||||
// Discover project
|
||||
const projectId = await discoverProject(tokenData.access_token, onProgress);
|
||||
|
||||
// Calculate expiry time (current time + expires_in seconds - 5 min buffer)
|
||||
const expiresAt = Date.now() + tokenData.expires_in * 1000 - 5 * 60 * 1000;
|
||||
|
||||
const credentials: OAuthCredentials = {
|
||||
refresh: tokenData.refresh_token,
|
||||
access: tokenData.access_token,
|
||||
expires: expiresAt,
|
||||
projectId,
|
||||
email,
|
||||
};
|
||||
|
||||
return credentials;
|
||||
} finally {
|
||||
server.server.close();
|
||||
}
|
||||
}
|
||||
|
||||
export const antigravityOAuthProvider: OAuthProviderInterface = {
|
||||
id: "google-antigravity",
|
||||
name: "Antigravity (Gemini 3, Claude, GPT-OSS)",
|
||||
usesCallbackServer: true,
|
||||
|
||||
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
|
||||
return loginAntigravity(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput);
|
||||
},
|
||||
|
||||
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
|
||||
const creds = credentials as AntigravityCredentials;
|
||||
if (!creds.projectId) {
|
||||
throw new Error("Antigravity credentials missing projectId");
|
||||
}
|
||||
return refreshAntigravityToken(creds.refresh, creds.projectId);
|
||||
},
|
||||
|
||||
getApiKey(credentials: OAuthCredentials): string {
|
||||
const creds = credentials as AntigravityCredentials;
|
||||
return JSON.stringify({ token: creds.access, projectId: creds.projectId });
|
||||
},
|
||||
};
|
||||
599
packages/pi-ai/src/utils/oauth/google-gemini-cli.ts
Normal file
599
packages/pi-ai/src/utils/oauth/google-gemini-cli.ts
Normal file
|
|
@ -0,0 +1,599 @@
|
|||
/**
|
||||
* Gemini CLI OAuth flow (Google Cloud Code Assist)
|
||||
* Standard Gemini models only (gemini-2.0-flash, gemini-2.5-*)
|
||||
*
|
||||
* NOTE: This module uses Node.js http.createServer for the OAuth callback.
|
||||
* It is only intended for CLI use, not browser environments.
|
||||
*/
|
||||
|
||||
import type { Server } from "node:http";
|
||||
import { generatePKCE } from "./pkce.js";
|
||||
import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.js";
|
||||
|
||||
type GeminiCredentials = OAuthCredentials & {
|
||||
projectId: string;
|
||||
};
|
||||
|
||||
let _createServer: typeof import("node:http").createServer | null = null;
|
||||
let _httpImportPromise: Promise<void> | null = null;
|
||||
if (typeof process !== "undefined" && (process.versions?.node || process.versions?.bun)) {
|
||||
_httpImportPromise = import("node:http").then((m) => {
|
||||
_createServer = m.createServer;
|
||||
});
|
||||
}
|
||||
|
||||
const decode = (s: string) => atob(s);
|
||||
const CLIENT_ID = decode(
|
||||
"NjgxMjU1ODA5Mzk1LW9vOGZ0Mm9wcmRybnA5ZTNhcWY2YXYzaG1kaWIxMzVqLmFwcHMuZ29vZ2xldXNlcmNvbnRlbnQuY29t",
|
||||
);
|
||||
const CLIENT_SECRET = decode("R09DU1BYLTR1SGdNUG0tMW83U2stZ2VWNkN1NWNsWEZzeGw=");
|
||||
const REDIRECT_URI = "http://localhost:8085/oauth2callback";
|
||||
const SCOPES = [
|
||||
"https://www.googleapis.com/auth/cloud-platform",
|
||||
"https://www.googleapis.com/auth/userinfo.email",
|
||||
"https://www.googleapis.com/auth/userinfo.profile",
|
||||
];
|
||||
const AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth";
|
||||
const TOKEN_URL = "https://oauth2.googleapis.com/token";
|
||||
const CODE_ASSIST_ENDPOINT = "https://cloudcode-pa.googleapis.com";
|
||||
|
||||
type CallbackServerInfo = {
|
||||
server: Server;
|
||||
cancelWait: () => void;
|
||||
waitForCode: () => Promise<{ code: string; state: string } | null>;
|
||||
};
|
||||
|
||||
/**
|
||||
* Start a local HTTP server to receive the OAuth callback
|
||||
*/
|
||||
async function getNodeCreateServer(): Promise<typeof import("node:http").createServer> {
|
||||
if (_createServer) return _createServer;
|
||||
if (_httpImportPromise) {
|
||||
await _httpImportPromise;
|
||||
}
|
||||
if (_createServer) return _createServer;
|
||||
throw new Error("Gemini CLI OAuth is only available in Node.js environments");
|
||||
}
|
||||
|
||||
async function startCallbackServer(): Promise<CallbackServerInfo> {
|
||||
const createServer = await getNodeCreateServer();
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
let result: { code: string; state: string } | null = null;
|
||||
let cancelled = false;
|
||||
|
||||
const server = createServer((req, res) => {
|
||||
const url = new URL(req.url || "", `http://localhost:8085`);
|
||||
|
||||
if (url.pathname === "/oauth2callback") {
|
||||
const code = url.searchParams.get("code");
|
||||
const state = url.searchParams.get("state");
|
||||
const error = url.searchParams.get("error");
|
||||
|
||||
if (error) {
|
||||
res.writeHead(400, { "Content-Type": "text/html" });
|
||||
res.end(
|
||||
`<html><body><h1>Authentication Failed</h1><p>Error: ${error}</p><p>You can close this window.</p></body></html>`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (code && state) {
|
||||
res.writeHead(200, { "Content-Type": "text/html" });
|
||||
res.end(
|
||||
`<html><body><h1>Authentication Successful</h1><p>You can close this window and return to the terminal.</p></body></html>`,
|
||||
);
|
||||
result = { code, state };
|
||||
} else {
|
||||
res.writeHead(400, { "Content-Type": "text/html" });
|
||||
res.end(
|
||||
`<html><body><h1>Authentication Failed</h1><p>Missing code or state parameter.</p></body></html>`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
res.writeHead(404);
|
||||
res.end();
|
||||
}
|
||||
});
|
||||
|
||||
server.on("error", (err) => {
|
||||
reject(err);
|
||||
});
|
||||
|
||||
server.listen(8085, "127.0.0.1", () => {
|
||||
resolve({
|
||||
server,
|
||||
cancelWait: () => {
|
||||
cancelled = true;
|
||||
},
|
||||
waitForCode: async () => {
|
||||
const sleep = () => new Promise((r) => setTimeout(r, 100));
|
||||
while (!result && !cancelled) {
|
||||
await sleep();
|
||||
}
|
||||
return result;
|
||||
},
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse redirect URL to extract code and state
|
||||
*/
|
||||
function parseRedirectUrl(input: string): { code?: string; state?: string } {
|
||||
const value = input.trim();
|
||||
if (!value) return {};
|
||||
|
||||
try {
|
||||
const url = new URL(value);
|
||||
return {
|
||||
code: url.searchParams.get("code") ?? undefined,
|
||||
state: url.searchParams.get("state") ?? undefined,
|
||||
};
|
||||
} catch {
|
||||
// Not a URL, return empty
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
interface LoadCodeAssistPayload {
|
||||
cloudaicompanionProject?: string;
|
||||
currentTier?: { id?: string };
|
||||
allowedTiers?: Array<{ id?: string; isDefault?: boolean }>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Long-running operation response from onboardUser
|
||||
*/
|
||||
interface LongRunningOperationResponse {
|
||||
name?: string;
|
||||
done?: boolean;
|
||||
response?: {
|
||||
cloudaicompanionProject?: { id?: string };
|
||||
};
|
||||
}
|
||||
|
||||
// Tier IDs as used by the Cloud Code API
|
||||
const TIER_FREE = "free-tier";
|
||||
const TIER_LEGACY = "legacy-tier";
|
||||
const TIER_STANDARD = "standard-tier";
|
||||
|
||||
interface GoogleRpcErrorResponse {
|
||||
error?: {
|
||||
details?: Array<{ reason?: string }>;
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Wait helper for onboarding retries
|
||||
*/
|
||||
function wait(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
}
|
||||
|
||||
/**
|
||||
* Get default tier from allowed tiers
|
||||
*/
|
||||
function getDefaultTier(allowedTiers?: Array<{ id?: string; isDefault?: boolean }>): { id?: string } {
|
||||
if (!allowedTiers || allowedTiers.length === 0) return { id: TIER_LEGACY };
|
||||
const defaultTier = allowedTiers.find((t) => t.isDefault);
|
||||
return defaultTier ?? { id: TIER_LEGACY };
|
||||
}
|
||||
|
||||
function isVpcScAffectedUser(payload: unknown): boolean {
|
||||
if (!payload || typeof payload !== "object") return false;
|
||||
if (!("error" in payload)) return false;
|
||||
const error = (payload as GoogleRpcErrorResponse).error;
|
||||
if (!error?.details || !Array.isArray(error.details)) return false;
|
||||
return error.details.some((detail) => detail.reason === "SECURITY_POLICY_VIOLATED");
|
||||
}
|
||||
|
||||
/**
|
||||
* Poll a long-running operation until completion
|
||||
*/
|
||||
async function pollOperation(
|
||||
operationName: string,
|
||||
headers: Record<string, string>,
|
||||
onProgress?: (message: string) => void,
|
||||
): Promise<LongRunningOperationResponse> {
|
||||
let attempt = 0;
|
||||
while (true) {
|
||||
if (attempt > 0) {
|
||||
onProgress?.(`Waiting for project provisioning (attempt ${attempt + 1})...`);
|
||||
await wait(5000);
|
||||
}
|
||||
|
||||
const response = await fetch(`${CODE_ASSIST_ENDPOINT}/v1internal/${operationName}`, {
|
||||
method: "GET",
|
||||
headers,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to poll operation: ${response.status} ${response.statusText}`);
|
||||
}
|
||||
|
||||
const data = (await response.json()) as LongRunningOperationResponse;
|
||||
if (data.done) {
|
||||
return data;
|
||||
}
|
||||
|
||||
attempt += 1;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Discover or provision a Google Cloud project for the user
|
||||
*/
|
||||
async function discoverProject(accessToken: string, onProgress?: (message: string) => void): Promise<string> {
|
||||
// Check for user-provided project ID via environment variable
|
||||
const envProjectId = process.env.GOOGLE_CLOUD_PROJECT || process.env.GOOGLE_CLOUD_PROJECT_ID;
|
||||
|
||||
const headers = {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "google-api-nodejs-client/9.15.1",
|
||||
"X-Goog-Api-Client": "gl-node/22.17.0",
|
||||
};
|
||||
|
||||
// Try to load existing project via loadCodeAssist
|
||||
onProgress?.("Checking for existing Cloud Code Assist project...");
|
||||
const loadResponse = await fetch(`${CODE_ASSIST_ENDPOINT}/v1internal:loadCodeAssist`, {
|
||||
method: "POST",
|
||||
headers,
|
||||
body: JSON.stringify({
|
||||
cloudaicompanionProject: envProjectId,
|
||||
metadata: {
|
||||
ideType: "IDE_UNSPECIFIED",
|
||||
platform: "PLATFORM_UNSPECIFIED",
|
||||
pluginType: "GEMINI",
|
||||
duetProject: envProjectId,
|
||||
},
|
||||
}),
|
||||
});
|
||||
|
||||
let data: LoadCodeAssistPayload;
|
||||
|
||||
if (!loadResponse.ok) {
|
||||
let errorPayload: unknown;
|
||||
try {
|
||||
errorPayload = await loadResponse.clone().json();
|
||||
} catch {
|
||||
errorPayload = undefined;
|
||||
}
|
||||
|
||||
if (isVpcScAffectedUser(errorPayload)) {
|
||||
data = { currentTier: { id: TIER_STANDARD } };
|
||||
} else {
|
||||
const errorText = await loadResponse.text();
|
||||
throw new Error(`loadCodeAssist failed: ${loadResponse.status} ${loadResponse.statusText}: ${errorText}`);
|
||||
}
|
||||
} else {
|
||||
data = (await loadResponse.json()) as LoadCodeAssistPayload;
|
||||
}
|
||||
|
||||
// If user already has a current tier and project, use it
|
||||
if (data.currentTier) {
|
||||
if (data.cloudaicompanionProject) {
|
||||
return data.cloudaicompanionProject;
|
||||
}
|
||||
// User has a tier but no managed project - they need to provide one via env var
|
||||
if (envProjectId) {
|
||||
return envProjectId;
|
||||
}
|
||||
throw new Error(
|
||||
"This account requires setting the GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_PROJECT_ID environment variable. " +
|
||||
"See https://goo.gle/gemini-cli-auth-docs#workspace-gca",
|
||||
);
|
||||
}
|
||||
|
||||
// User needs to be onboarded - get the default tier
|
||||
const tier = getDefaultTier(data.allowedTiers);
|
||||
const tierId = tier?.id ?? TIER_FREE;
|
||||
|
||||
if (tierId !== TIER_FREE && !envProjectId) {
|
||||
throw new Error(
|
||||
"This account requires setting the GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_PROJECT_ID environment variable. " +
|
||||
"See https://goo.gle/gemini-cli-auth-docs#workspace-gca",
|
||||
);
|
||||
}
|
||||
|
||||
onProgress?.("Provisioning Cloud Code Assist project (this may take a moment)...");
|
||||
|
||||
// Build onboard request - for free tier, don't include project ID (Google provisions one)
|
||||
// For other tiers, include the user's project ID if available
|
||||
const onboardBody: Record<string, unknown> = {
|
||||
tierId,
|
||||
metadata: {
|
||||
ideType: "IDE_UNSPECIFIED",
|
||||
platform: "PLATFORM_UNSPECIFIED",
|
||||
pluginType: "GEMINI",
|
||||
},
|
||||
};
|
||||
|
||||
if (tierId !== TIER_FREE && envProjectId) {
|
||||
onboardBody.cloudaicompanionProject = envProjectId;
|
||||
(onboardBody.metadata as Record<string, unknown>).duetProject = envProjectId;
|
||||
}
|
||||
|
||||
// Start onboarding - this returns a long-running operation
|
||||
const onboardResponse = await fetch(`${CODE_ASSIST_ENDPOINT}/v1internal:onboardUser`, {
|
||||
method: "POST",
|
||||
headers,
|
||||
body: JSON.stringify(onboardBody),
|
||||
});
|
||||
|
||||
if (!onboardResponse.ok) {
|
||||
const errorText = await onboardResponse.text();
|
||||
throw new Error(`onboardUser failed: ${onboardResponse.status} ${onboardResponse.statusText}: ${errorText}`);
|
||||
}
|
||||
|
||||
let lroData = (await onboardResponse.json()) as LongRunningOperationResponse;
|
||||
|
||||
// If the operation isn't done yet, poll until completion
|
||||
if (!lroData.done && lroData.name) {
|
||||
lroData = await pollOperation(lroData.name, headers, onProgress);
|
||||
}
|
||||
|
||||
// Try to get project ID from the response
|
||||
const projectId = lroData.response?.cloudaicompanionProject?.id;
|
||||
if (projectId) {
|
||||
return projectId;
|
||||
}
|
||||
|
||||
// If no project ID from onboarding, fall back to env var
|
||||
if (envProjectId) {
|
||||
return envProjectId;
|
||||
}
|
||||
|
||||
throw new Error(
|
||||
"Could not discover or provision a Google Cloud project. " +
|
||||
"Try setting the GOOGLE_CLOUD_PROJECT or GOOGLE_CLOUD_PROJECT_ID environment variable. " +
|
||||
"See https://goo.gle/gemini-cli-auth-docs#workspace-gca",
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get user email from the access token
|
||||
*/
|
||||
async function getUserEmail(accessToken: string): Promise<string | undefined> {
|
||||
try {
|
||||
const response = await fetch("https://www.googleapis.com/oauth2/v1/userinfo?alt=json", {
|
||||
headers: {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
},
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
const data = (await response.json()) as { email?: string };
|
||||
return data.email;
|
||||
}
|
||||
} catch {
|
||||
// Ignore errors, email is optional
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh Google Cloud Code Assist token
|
||||
*/
|
||||
export async function refreshGoogleCloudToken(refreshToken: string, projectId: string): Promise<OAuthCredentials> {
|
||||
const response = await fetch(TOKEN_URL, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/x-www-form-urlencoded" },
|
||||
body: new URLSearchParams({
|
||||
client_id: CLIENT_ID,
|
||||
client_secret: CLIENT_SECRET,
|
||||
refresh_token: refreshToken,
|
||||
grant_type: "refresh_token",
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text();
|
||||
throw new Error(`Google Cloud token refresh failed: ${error}`);
|
||||
}
|
||||
|
||||
const data = (await response.json()) as {
|
||||
access_token: string;
|
||||
expires_in: number;
|
||||
refresh_token?: string;
|
||||
};
|
||||
|
||||
return {
|
||||
refresh: data.refresh_token || refreshToken,
|
||||
access: data.access_token,
|
||||
expires: Date.now() + data.expires_in * 1000 - 5 * 60 * 1000,
|
||||
projectId,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Login with Gemini CLI (Google Cloud Code Assist) OAuth
|
||||
*
|
||||
* @param onAuth - Callback with URL and optional instructions
|
||||
* @param onProgress - Optional progress callback
|
||||
* @param onManualCodeInput - Optional promise that resolves with user-pasted redirect URL.
|
||||
* Races with browser callback - whichever completes first wins.
|
||||
*/
|
||||
export async function loginGeminiCli(
|
||||
onAuth: (info: { url: string; instructions?: string }) => void,
|
||||
onProgress?: (message: string) => void,
|
||||
onManualCodeInput?: () => Promise<string>,
|
||||
): Promise<OAuthCredentials> {
|
||||
const { verifier, challenge } = await generatePKCE();
|
||||
|
||||
// Start local server for callback
|
||||
onProgress?.("Starting local server for OAuth callback...");
|
||||
const server = await startCallbackServer();
|
||||
|
||||
let code: string | undefined;
|
||||
|
||||
try {
|
||||
// Build authorization URL
|
||||
const authParams = new URLSearchParams({
|
||||
client_id: CLIENT_ID,
|
||||
response_type: "code",
|
||||
redirect_uri: REDIRECT_URI,
|
||||
scope: SCOPES.join(" "),
|
||||
code_challenge: challenge,
|
||||
code_challenge_method: "S256",
|
||||
state: verifier,
|
||||
access_type: "offline",
|
||||
prompt: "consent",
|
||||
});
|
||||
|
||||
const authUrl = `${AUTH_URL}?${authParams.toString()}`;
|
||||
|
||||
// Notify caller with URL to open
|
||||
onAuth({
|
||||
url: authUrl,
|
||||
instructions: "Complete the sign-in in your browser.",
|
||||
});
|
||||
|
||||
// Wait for the callback, racing with manual input if provided
|
||||
onProgress?.("Waiting for OAuth callback...");
|
||||
|
||||
if (onManualCodeInput) {
|
||||
// Race between browser callback and manual input
|
||||
let manualInput: string | undefined;
|
||||
let manualError: Error | undefined;
|
||||
const manualPromise = onManualCodeInput()
|
||||
.then((input) => {
|
||||
manualInput = input;
|
||||
server.cancelWait();
|
||||
})
|
||||
.catch((err) => {
|
||||
manualError = err instanceof Error ? err : new Error(String(err));
|
||||
server.cancelWait();
|
||||
});
|
||||
|
||||
const result = await server.waitForCode();
|
||||
|
||||
// If manual input was cancelled, throw that error
|
||||
if (manualError) {
|
||||
throw manualError;
|
||||
}
|
||||
|
||||
if (result?.code) {
|
||||
// Browser callback won - verify state
|
||||
if (result.state !== verifier) {
|
||||
throw new Error("OAuth state mismatch - possible CSRF attack");
|
||||
}
|
||||
code = result.code;
|
||||
} else if (manualInput) {
|
||||
// Manual input won
|
||||
const parsed = parseRedirectUrl(manualInput);
|
||||
if (parsed.state && parsed.state !== verifier) {
|
||||
throw new Error("OAuth state mismatch - possible CSRF attack");
|
||||
}
|
||||
code = parsed.code;
|
||||
}
|
||||
|
||||
// If still no code, wait for manual promise and try that
|
||||
if (!code) {
|
||||
await manualPromise;
|
||||
if (manualError) {
|
||||
throw manualError;
|
||||
}
|
||||
if (manualInput) {
|
||||
const parsed = parseRedirectUrl(manualInput);
|
||||
if (parsed.state && parsed.state !== verifier) {
|
||||
throw new Error("OAuth state mismatch - possible CSRF attack");
|
||||
}
|
||||
code = parsed.code;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Original flow: just wait for callback
|
||||
const result = await server.waitForCode();
|
||||
if (result?.code) {
|
||||
if (result.state !== verifier) {
|
||||
throw new Error("OAuth state mismatch - possible CSRF attack");
|
||||
}
|
||||
code = result.code;
|
||||
}
|
||||
}
|
||||
|
||||
if (!code) {
|
||||
throw new Error("No authorization code received");
|
||||
}
|
||||
|
||||
// Exchange code for tokens
|
||||
onProgress?.("Exchanging authorization code for tokens...");
|
||||
const tokenResponse = await fetch(TOKEN_URL, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
},
|
||||
body: new URLSearchParams({
|
||||
client_id: CLIENT_ID,
|
||||
client_secret: CLIENT_SECRET,
|
||||
code,
|
||||
grant_type: "authorization_code",
|
||||
redirect_uri: REDIRECT_URI,
|
||||
code_verifier: verifier,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!tokenResponse.ok) {
|
||||
const error = await tokenResponse.text();
|
||||
throw new Error(`Token exchange failed: ${error}`);
|
||||
}
|
||||
|
||||
const tokenData = (await tokenResponse.json()) as {
|
||||
access_token: string;
|
||||
refresh_token: string;
|
||||
expires_in: number;
|
||||
};
|
||||
|
||||
if (!tokenData.refresh_token) {
|
||||
throw new Error("No refresh token received. Please try again.");
|
||||
}
|
||||
|
||||
// Get user email
|
||||
onProgress?.("Getting user info...");
|
||||
const email = await getUserEmail(tokenData.access_token);
|
||||
|
||||
// Discover project
|
||||
const projectId = await discoverProject(tokenData.access_token, onProgress);
|
||||
|
||||
// Calculate expiry time (current time + expires_in seconds - 5 min buffer)
|
||||
const expiresAt = Date.now() + tokenData.expires_in * 1000 - 5 * 60 * 1000;
|
||||
|
||||
const credentials: OAuthCredentials = {
|
||||
refresh: tokenData.refresh_token,
|
||||
access: tokenData.access_token,
|
||||
expires: expiresAt,
|
||||
projectId,
|
||||
email,
|
||||
};
|
||||
|
||||
return credentials;
|
||||
} finally {
|
||||
server.server.close();
|
||||
}
|
||||
}
|
||||
|
||||
export const geminiCliOAuthProvider: OAuthProviderInterface = {
|
||||
id: "google-gemini-cli",
|
||||
name: "Google Cloud Code Assist (Gemini CLI)",
|
||||
usesCallbackServer: true,
|
||||
|
||||
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
|
||||
return loginGeminiCli(callbacks.onAuth, callbacks.onProgress, callbacks.onManualCodeInput);
|
||||
},
|
||||
|
||||
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
|
||||
const creds = credentials as GeminiCredentials;
|
||||
if (!creds.projectId) {
|
||||
throw new Error("Google Cloud credentials missing projectId");
|
||||
}
|
||||
return refreshGoogleCloudToken(creds.refresh, creds.projectId);
|
||||
},
|
||||
|
||||
getApiKey(credentials: OAuthCredentials): string {
|
||||
const creds = credentials as GeminiCredentials;
|
||||
return JSON.stringify({ token: creds.access, projectId: creds.projectId });
|
||||
},
|
||||
};
|
||||
162
packages/pi-ai/src/utils/oauth/index.ts
Normal file
162
packages/pi-ai/src/utils/oauth/index.ts
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
/**
|
||||
* OAuth credential management for AI providers.
|
||||
*
|
||||
* This module handles login, token refresh, and credential storage
|
||||
* for OAuth-based providers:
|
||||
* - Anthropic (Claude Pro/Max)
|
||||
* - GitHub Copilot
|
||||
* - Google Cloud Code Assist (Gemini CLI)
|
||||
* - Antigravity (Gemini 3, Claude, GPT-OSS via Google Cloud)
|
||||
*/
|
||||
|
||||
// Anthropic
|
||||
export { anthropicOAuthProvider, loginAnthropic, refreshAnthropicToken } from "./anthropic.js";
|
||||
// GitHub Copilot
|
||||
export {
|
||||
getGitHubCopilotBaseUrl,
|
||||
githubCopilotOAuthProvider,
|
||||
loginGitHubCopilot,
|
||||
normalizeDomain,
|
||||
refreshGitHubCopilotToken,
|
||||
} from "./github-copilot.js";
|
||||
// Google Antigravity
|
||||
export { antigravityOAuthProvider, loginAntigravity, refreshAntigravityToken } from "./google-antigravity.js";
|
||||
// Google Gemini CLI
|
||||
export { geminiCliOAuthProvider, loginGeminiCli, refreshGoogleCloudToken } from "./google-gemini-cli.js";
|
||||
// OpenAI Codex (ChatGPT OAuth)
|
||||
export { loginOpenAICodex, openaiCodexOAuthProvider, refreshOpenAICodexToken } from "./openai-codex.js";
|
||||
|
||||
export * from "./types.js";
|
||||
|
||||
// ============================================================================
|
||||
// Provider Registry
|
||||
// ============================================================================
|
||||
|
||||
import { anthropicOAuthProvider } from "./anthropic.js";
|
||||
import { githubCopilotOAuthProvider } from "./github-copilot.js";
|
||||
import { antigravityOAuthProvider } from "./google-antigravity.js";
|
||||
import { geminiCliOAuthProvider } from "./google-gemini-cli.js";
|
||||
import { openaiCodexOAuthProvider } from "./openai-codex.js";
|
||||
import type { OAuthCredentials, OAuthProviderId, OAuthProviderInfo, OAuthProviderInterface } from "./types.js";
|
||||
|
||||
const BUILT_IN_OAUTH_PROVIDERS: OAuthProviderInterface[] = [
|
||||
anthropicOAuthProvider,
|
||||
githubCopilotOAuthProvider,
|
||||
geminiCliOAuthProvider,
|
||||
antigravityOAuthProvider,
|
||||
openaiCodexOAuthProvider,
|
||||
];
|
||||
|
||||
const oauthProviderRegistry = new Map<string, OAuthProviderInterface>(
|
||||
BUILT_IN_OAUTH_PROVIDERS.map((provider) => [provider.id, provider]),
|
||||
);
|
||||
|
||||
/**
|
||||
* Get an OAuth provider by ID
|
||||
*/
|
||||
export function getOAuthProvider(id: OAuthProviderId): OAuthProviderInterface | undefined {
|
||||
return oauthProviderRegistry.get(id);
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a custom OAuth provider
|
||||
*/
|
||||
export function registerOAuthProvider(provider: OAuthProviderInterface): void {
|
||||
oauthProviderRegistry.set(provider.id, provider);
|
||||
}
|
||||
|
||||
/**
|
||||
* Unregister an OAuth provider.
|
||||
*
|
||||
* If the provider is built-in, restores the built-in implementation.
|
||||
* Custom providers are removed completely.
|
||||
*/
|
||||
export function unregisterOAuthProvider(id: string): void {
|
||||
const builtInProvider = BUILT_IN_OAUTH_PROVIDERS.find((provider) => provider.id === id);
|
||||
if (builtInProvider) {
|
||||
oauthProviderRegistry.set(id, builtInProvider);
|
||||
return;
|
||||
}
|
||||
oauthProviderRegistry.delete(id);
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset OAuth providers to built-ins.
|
||||
*/
|
||||
export function resetOAuthProviders(): void {
|
||||
oauthProviderRegistry.clear();
|
||||
for (const provider of BUILT_IN_OAUTH_PROVIDERS) {
|
||||
oauthProviderRegistry.set(provider.id, provider);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all registered OAuth providers
|
||||
*/
|
||||
export function getOAuthProviders(): OAuthProviderInterface[] {
|
||||
return Array.from(oauthProviderRegistry.values());
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use getOAuthProviders() which returns OAuthProviderInterface[]
|
||||
*/
|
||||
export function getOAuthProviderInfoList(): OAuthProviderInfo[] {
|
||||
return getOAuthProviders().map((p) => ({
|
||||
id: p.id,
|
||||
name: p.name,
|
||||
available: true,
|
||||
}));
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// High-level API (uses provider registry)
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Refresh token for any OAuth provider.
|
||||
* @deprecated Use getOAuthProvider(id).refreshToken() instead
|
||||
*/
|
||||
export async function refreshOAuthToken(
|
||||
providerId: OAuthProviderId,
|
||||
credentials: OAuthCredentials,
|
||||
): Promise<OAuthCredentials> {
|
||||
const provider = getOAuthProvider(providerId);
|
||||
if (!provider) {
|
||||
throw new Error(`Unknown OAuth provider: ${providerId}`);
|
||||
}
|
||||
return provider.refreshToken(credentials);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for a provider from OAuth credentials.
|
||||
* Automatically refreshes expired tokens.
|
||||
*
|
||||
* @returns API key string and updated credentials, or null if no credentials
|
||||
* @throws Error if refresh fails
|
||||
*/
|
||||
export async function getOAuthApiKey(
|
||||
providerId: OAuthProviderId,
|
||||
credentials: Record<string, OAuthCredentials>,
|
||||
): Promise<{ newCredentials: OAuthCredentials; apiKey: string } | null> {
|
||||
const provider = getOAuthProvider(providerId);
|
||||
if (!provider) {
|
||||
throw new Error(`Unknown OAuth provider: ${providerId}`);
|
||||
}
|
||||
|
||||
let creds = credentials[providerId];
|
||||
if (!creds) {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Refresh if expired
|
||||
if (Date.now() >= creds.expires) {
|
||||
try {
|
||||
creds = await provider.refreshToken(creds);
|
||||
} catch (_error) {
|
||||
throw new Error(`Failed to refresh OAuth token for ${providerId}`);
|
||||
}
|
||||
}
|
||||
|
||||
const apiKey = provider.getApiKey(creds);
|
||||
return { newCredentials: creds, apiKey };
|
||||
}
|
||||
455
packages/pi-ai/src/utils/oauth/openai-codex.ts
Normal file
455
packages/pi-ai/src/utils/oauth/openai-codex.ts
Normal file
|
|
@ -0,0 +1,455 @@
|
|||
/**
|
||||
* OpenAI Codex (ChatGPT OAuth) flow
|
||||
*
|
||||
* NOTE: This module uses Node.js crypto and http for the OAuth callback.
|
||||
* It is only intended for CLI use, not browser environments.
|
||||
*/
|
||||
|
||||
// NEVER convert to top-level imports - breaks browser/Vite builds (web-ui)
|
||||
let _randomBytes: typeof import("node:crypto").randomBytes | null = null;
|
||||
let _http: typeof import("node:http") | null = null;
|
||||
if (typeof process !== "undefined" && (process.versions?.node || process.versions?.bun)) {
|
||||
import("node:crypto").then((m) => {
|
||||
_randomBytes = m.randomBytes;
|
||||
});
|
||||
import("node:http").then((m) => {
|
||||
_http = m;
|
||||
});
|
||||
}
|
||||
|
||||
import { generatePKCE } from "./pkce.js";
|
||||
import type { OAuthCredentials, OAuthLoginCallbacks, OAuthPrompt, OAuthProviderInterface } from "./types.js";
|
||||
|
||||
const CLIENT_ID = "app_EMoamEEZ73f0CkXaXp7hrann";
|
||||
const AUTHORIZE_URL = "https://auth.openai.com/oauth/authorize";
|
||||
const TOKEN_URL = "https://auth.openai.com/oauth/token";
|
||||
const REDIRECT_URI = "http://localhost:1455/auth/callback";
|
||||
const SCOPE = "openid profile email offline_access";
|
||||
const JWT_CLAIM_PATH = "https://api.openai.com/auth";
|
||||
|
||||
const SUCCESS_HTML = `<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="utf-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1" />
|
||||
<title>Authentication successful</title>
|
||||
</head>
|
||||
<body>
|
||||
<p>Authentication successful. Return to your terminal to continue.</p>
|
||||
</body>
|
||||
</html>`;
|
||||
|
||||
type TokenSuccess = { type: "success"; access: string; refresh: string; expires: number };
|
||||
type TokenFailure = { type: "failed" };
|
||||
type TokenResult = TokenSuccess | TokenFailure;
|
||||
|
||||
type JwtPayload = {
|
||||
[JWT_CLAIM_PATH]?: {
|
||||
chatgpt_account_id?: string;
|
||||
};
|
||||
[key: string]: unknown;
|
||||
};
|
||||
|
||||
function createState(): string {
|
||||
if (!_randomBytes) {
|
||||
throw new Error("OpenAI Codex OAuth is only available in Node.js environments");
|
||||
}
|
||||
return _randomBytes(16).toString("hex");
|
||||
}
|
||||
|
||||
function parseAuthorizationInput(input: string): { code?: string; state?: string } {
|
||||
const value = input.trim();
|
||||
if (!value) return {};
|
||||
|
||||
try {
|
||||
const url = new URL(value);
|
||||
return {
|
||||
code: url.searchParams.get("code") ?? undefined,
|
||||
state: url.searchParams.get("state") ?? undefined,
|
||||
};
|
||||
} catch {
|
||||
// not a URL
|
||||
}
|
||||
|
||||
if (value.includes("#")) {
|
||||
const [code, state] = value.split("#", 2);
|
||||
return { code, state };
|
||||
}
|
||||
|
||||
if (value.includes("code=")) {
|
||||
const params = new URLSearchParams(value);
|
||||
return {
|
||||
code: params.get("code") ?? undefined,
|
||||
state: params.get("state") ?? undefined,
|
||||
};
|
||||
}
|
||||
|
||||
return { code: value };
|
||||
}
|
||||
|
||||
function decodeJwt(token: string): JwtPayload | null {
|
||||
try {
|
||||
const parts = token.split(".");
|
||||
if (parts.length !== 3) return null;
|
||||
const payload = parts[1] ?? "";
|
||||
const decoded = atob(payload);
|
||||
return JSON.parse(decoded) as JwtPayload;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
async function exchangeAuthorizationCode(
|
||||
code: string,
|
||||
verifier: string,
|
||||
redirectUri: string = REDIRECT_URI,
|
||||
): Promise<TokenResult> {
|
||||
const response = await fetch(TOKEN_URL, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/x-www-form-urlencoded" },
|
||||
body: new URLSearchParams({
|
||||
grant_type: "authorization_code",
|
||||
client_id: CLIENT_ID,
|
||||
code,
|
||||
code_verifier: verifier,
|
||||
redirect_uri: redirectUri,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const text = await response.text().catch(() => "");
|
||||
console.error("[openai-codex] code->token failed:", response.status, text);
|
||||
return { type: "failed" };
|
||||
}
|
||||
|
||||
const json = (await response.json()) as {
|
||||
access_token?: string;
|
||||
refresh_token?: string;
|
||||
expires_in?: number;
|
||||
};
|
||||
|
||||
if (!json.access_token || !json.refresh_token || typeof json.expires_in !== "number") {
|
||||
console.error("[openai-codex] token response missing fields:", json);
|
||||
return { type: "failed" };
|
||||
}
|
||||
|
||||
return {
|
||||
type: "success",
|
||||
access: json.access_token,
|
||||
refresh: json.refresh_token,
|
||||
expires: Date.now() + json.expires_in * 1000,
|
||||
};
|
||||
}
|
||||
|
||||
async function refreshAccessToken(refreshToken: string): Promise<TokenResult> {
|
||||
try {
|
||||
const response = await fetch(TOKEN_URL, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/x-www-form-urlencoded" },
|
||||
body: new URLSearchParams({
|
||||
grant_type: "refresh_token",
|
||||
refresh_token: refreshToken,
|
||||
client_id: CLIENT_ID,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const text = await response.text().catch(() => "");
|
||||
console.error("[openai-codex] Token refresh failed:", response.status, text);
|
||||
return { type: "failed" };
|
||||
}
|
||||
|
||||
const json = (await response.json()) as {
|
||||
access_token?: string;
|
||||
refresh_token?: string;
|
||||
expires_in?: number;
|
||||
};
|
||||
|
||||
if (!json.access_token || !json.refresh_token || typeof json.expires_in !== "number") {
|
||||
console.error("[openai-codex] Token refresh response missing fields:", json);
|
||||
return { type: "failed" };
|
||||
}
|
||||
|
||||
return {
|
||||
type: "success",
|
||||
access: json.access_token,
|
||||
refresh: json.refresh_token,
|
||||
expires: Date.now() + json.expires_in * 1000,
|
||||
};
|
||||
} catch (error) {
|
||||
console.error("[openai-codex] Token refresh error:", error);
|
||||
return { type: "failed" };
|
||||
}
|
||||
}
|
||||
|
||||
async function createAuthorizationFlow(
|
||||
originator: string = "pi",
|
||||
): Promise<{ verifier: string; state: string; url: string }> {
|
||||
const { verifier, challenge } = await generatePKCE();
|
||||
const state = createState();
|
||||
|
||||
const url = new URL(AUTHORIZE_URL);
|
||||
url.searchParams.set("response_type", "code");
|
||||
url.searchParams.set("client_id", CLIENT_ID);
|
||||
url.searchParams.set("redirect_uri", REDIRECT_URI);
|
||||
url.searchParams.set("scope", SCOPE);
|
||||
url.searchParams.set("code_challenge", challenge);
|
||||
url.searchParams.set("code_challenge_method", "S256");
|
||||
url.searchParams.set("state", state);
|
||||
url.searchParams.set("id_token_add_organizations", "true");
|
||||
url.searchParams.set("codex_cli_simplified_flow", "true");
|
||||
url.searchParams.set("originator", originator);
|
||||
|
||||
return { verifier, state, url: url.toString() };
|
||||
}
|
||||
|
||||
type OAuthServerInfo = {
|
||||
close: () => void;
|
||||
cancelWait: () => void;
|
||||
waitForCode: () => Promise<{ code: string } | null>;
|
||||
};
|
||||
|
||||
function startLocalOAuthServer(state: string): Promise<OAuthServerInfo> {
|
||||
if (!_http) {
|
||||
throw new Error("OpenAI Codex OAuth is only available in Node.js environments");
|
||||
}
|
||||
let lastCode: string | null = null;
|
||||
let cancelled = false;
|
||||
const server = _http.createServer((req, res) => {
|
||||
try {
|
||||
const url = new URL(req.url || "", "http://localhost");
|
||||
if (url.pathname !== "/auth/callback") {
|
||||
res.statusCode = 404;
|
||||
res.end("Not found");
|
||||
return;
|
||||
}
|
||||
if (url.searchParams.get("state") !== state) {
|
||||
res.statusCode = 400;
|
||||
res.end("State mismatch");
|
||||
return;
|
||||
}
|
||||
const code = url.searchParams.get("code");
|
||||
if (!code) {
|
||||
res.statusCode = 400;
|
||||
res.end("Missing authorization code");
|
||||
return;
|
||||
}
|
||||
res.statusCode = 200;
|
||||
res.setHeader("Content-Type", "text/html; charset=utf-8");
|
||||
res.end(SUCCESS_HTML);
|
||||
lastCode = code;
|
||||
} catch {
|
||||
res.statusCode = 500;
|
||||
res.end("Internal error");
|
||||
}
|
||||
});
|
||||
|
||||
return new Promise((resolve) => {
|
||||
server
|
||||
.listen(1455, "127.0.0.1", () => {
|
||||
resolve({
|
||||
close: () => server.close(),
|
||||
cancelWait: () => {
|
||||
cancelled = true;
|
||||
},
|
||||
waitForCode: async () => {
|
||||
const sleep = () => new Promise((r) => setTimeout(r, 100));
|
||||
for (let i = 0; i < 600; i += 1) {
|
||||
if (lastCode) return { code: lastCode };
|
||||
if (cancelled) return null;
|
||||
await sleep();
|
||||
}
|
||||
return null;
|
||||
},
|
||||
});
|
||||
})
|
||||
.on("error", (err: NodeJS.ErrnoException) => {
|
||||
console.error(
|
||||
"[openai-codex] Failed to bind http://127.0.0.1:1455 (",
|
||||
err.code,
|
||||
") Falling back to manual paste.",
|
||||
);
|
||||
resolve({
|
||||
close: () => {
|
||||
try {
|
||||
server.close();
|
||||
} catch {
|
||||
// ignore
|
||||
}
|
||||
},
|
||||
cancelWait: () => {},
|
||||
waitForCode: async () => null,
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
function getAccountId(accessToken: string): string | null {
|
||||
const payload = decodeJwt(accessToken);
|
||||
const auth = payload?.[JWT_CLAIM_PATH];
|
||||
const accountId = auth?.chatgpt_account_id;
|
||||
return typeof accountId === "string" && accountId.length > 0 ? accountId : null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Login with OpenAI Codex OAuth
|
||||
*
|
||||
* @param options.onAuth - Called with URL and instructions when auth starts
|
||||
* @param options.onPrompt - Called to prompt user for manual code paste (fallback if no onManualCodeInput)
|
||||
* @param options.onProgress - Optional progress messages
|
||||
* @param options.onManualCodeInput - Optional promise that resolves with user-pasted code.
|
||||
* Races with browser callback - whichever completes first wins.
|
||||
* Useful for showing paste input immediately alongside browser flow.
|
||||
* @param options.originator - OAuth originator parameter (defaults to "pi")
|
||||
*/
|
||||
export async function loginOpenAICodex(options: {
|
||||
onAuth: (info: { url: string; instructions?: string }) => void;
|
||||
onPrompt: (prompt: OAuthPrompt) => Promise<string>;
|
||||
onProgress?: (message: string) => void;
|
||||
onManualCodeInput?: () => Promise<string>;
|
||||
originator?: string;
|
||||
}): Promise<OAuthCredentials> {
|
||||
const { verifier, state, url } = await createAuthorizationFlow(options.originator);
|
||||
const server = await startLocalOAuthServer(state);
|
||||
|
||||
options.onAuth({ url, instructions: "A browser window should open. Complete login to finish." });
|
||||
|
||||
let code: string | undefined;
|
||||
try {
|
||||
if (options.onManualCodeInput) {
|
||||
// Race between browser callback and manual input
|
||||
let manualCode: string | undefined;
|
||||
let manualError: Error | undefined;
|
||||
const manualPromise = options
|
||||
.onManualCodeInput()
|
||||
.then((input) => {
|
||||
manualCode = input;
|
||||
server.cancelWait();
|
||||
})
|
||||
.catch((err) => {
|
||||
manualError = err instanceof Error ? err : new Error(String(err));
|
||||
server.cancelWait();
|
||||
});
|
||||
|
||||
const result = await server.waitForCode();
|
||||
|
||||
// If manual input was cancelled, throw that error
|
||||
if (manualError) {
|
||||
throw manualError;
|
||||
}
|
||||
|
||||
if (result?.code) {
|
||||
// Browser callback won
|
||||
code = result.code;
|
||||
} else if (manualCode) {
|
||||
// Manual input won (or callback timed out and user had entered code)
|
||||
const parsed = parseAuthorizationInput(manualCode);
|
||||
if (parsed.state && parsed.state !== state) {
|
||||
throw new Error("State mismatch");
|
||||
}
|
||||
code = parsed.code;
|
||||
}
|
||||
|
||||
// If still no code, wait for manual promise to complete and try that
|
||||
if (!code) {
|
||||
await manualPromise;
|
||||
if (manualError) {
|
||||
throw manualError;
|
||||
}
|
||||
if (manualCode) {
|
||||
const parsed = parseAuthorizationInput(manualCode);
|
||||
if (parsed.state && parsed.state !== state) {
|
||||
throw new Error("State mismatch");
|
||||
}
|
||||
code = parsed.code;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Original flow: wait for callback, then prompt if needed
|
||||
const result = await server.waitForCode();
|
||||
if (result?.code) {
|
||||
code = result.code;
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to onPrompt if still no code
|
||||
if (!code) {
|
||||
const input = await options.onPrompt({
|
||||
message: "Paste the authorization code (or full redirect URL):",
|
||||
});
|
||||
const parsed = parseAuthorizationInput(input);
|
||||
if (parsed.state && parsed.state !== state) {
|
||||
throw new Error("State mismatch");
|
||||
}
|
||||
code = parsed.code;
|
||||
}
|
||||
|
||||
if (!code) {
|
||||
throw new Error("Missing authorization code");
|
||||
}
|
||||
|
||||
const tokenResult = await exchangeAuthorizationCode(code, verifier);
|
||||
if (tokenResult.type !== "success") {
|
||||
throw new Error("Token exchange failed");
|
||||
}
|
||||
|
||||
const accountId = getAccountId(tokenResult.access);
|
||||
if (!accountId) {
|
||||
throw new Error("Failed to extract accountId from token");
|
||||
}
|
||||
|
||||
return {
|
||||
access: tokenResult.access,
|
||||
refresh: tokenResult.refresh,
|
||||
expires: tokenResult.expires,
|
||||
accountId,
|
||||
};
|
||||
} finally {
|
||||
server.close();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh OpenAI Codex OAuth token
|
||||
*/
|
||||
export async function refreshOpenAICodexToken(refreshToken: string): Promise<OAuthCredentials> {
|
||||
const result = await refreshAccessToken(refreshToken);
|
||||
if (result.type !== "success") {
|
||||
throw new Error("Failed to refresh OpenAI Codex token");
|
||||
}
|
||||
|
||||
const accountId = getAccountId(result.access);
|
||||
if (!accountId) {
|
||||
throw new Error("Failed to extract accountId from token");
|
||||
}
|
||||
|
||||
return {
|
||||
access: result.access,
|
||||
refresh: result.refresh,
|
||||
expires: result.expires,
|
||||
accountId,
|
||||
};
|
||||
}
|
||||
|
||||
export const openaiCodexOAuthProvider: OAuthProviderInterface = {
|
||||
id: "openai-codex",
|
||||
name: "ChatGPT Plus/Pro (Codex Subscription)",
|
||||
usesCallbackServer: true,
|
||||
|
||||
async login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials> {
|
||||
return loginOpenAICodex({
|
||||
onAuth: callbacks.onAuth,
|
||||
onPrompt: callbacks.onPrompt,
|
||||
onProgress: callbacks.onProgress,
|
||||
onManualCodeInput: callbacks.onManualCodeInput,
|
||||
});
|
||||
},
|
||||
|
||||
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
|
||||
return refreshOpenAICodexToken(credentials.refresh);
|
||||
},
|
||||
|
||||
getApiKey(credentials: OAuthCredentials): string {
|
||||
return credentials.access;
|
||||
},
|
||||
};
|
||||
34
packages/pi-ai/src/utils/oauth/pkce.ts
Normal file
34
packages/pi-ai/src/utils/oauth/pkce.ts
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
/**
|
||||
* PKCE utilities using Web Crypto API.
|
||||
* Works in both Node.js 20+ and browsers.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Encode bytes as base64url string.
|
||||
*/
|
||||
function base64urlEncode(bytes: Uint8Array): string {
|
||||
let binary = "";
|
||||
for (const byte of bytes) {
|
||||
binary += String.fromCharCode(byte);
|
||||
}
|
||||
return btoa(binary).replace(/\+/g, "-").replace(/\//g, "_").replace(/=/g, "");
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate PKCE code verifier and challenge.
|
||||
* Uses Web Crypto API for cross-platform compatibility.
|
||||
*/
|
||||
export async function generatePKCE(): Promise<{ verifier: string; challenge: string }> {
|
||||
// Generate random verifier
|
||||
const verifierBytes = new Uint8Array(32);
|
||||
crypto.getRandomValues(verifierBytes);
|
||||
const verifier = base64urlEncode(verifierBytes);
|
||||
|
||||
// Compute SHA-256 challenge
|
||||
const encoder = new TextEncoder();
|
||||
const data = encoder.encode(verifier);
|
||||
const hashBuffer = await crypto.subtle.digest("SHA-256", data);
|
||||
const challenge = base64urlEncode(new Uint8Array(hashBuffer));
|
||||
|
||||
return { verifier, challenge };
|
||||
}
|
||||
59
packages/pi-ai/src/utils/oauth/types.ts
Normal file
59
packages/pi-ai/src/utils/oauth/types.ts
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
import type { Api, Model } from "../../types.js";
|
||||
|
||||
export type OAuthCredentials = {
|
||||
refresh: string;
|
||||
access: string;
|
||||
expires: number;
|
||||
[key: string]: unknown;
|
||||
};
|
||||
|
||||
export type OAuthProviderId = string;
|
||||
|
||||
/** @deprecated Use OAuthProviderId instead */
|
||||
export type OAuthProvider = OAuthProviderId;
|
||||
|
||||
export type OAuthPrompt = {
|
||||
message: string;
|
||||
placeholder?: string;
|
||||
allowEmpty?: boolean;
|
||||
};
|
||||
|
||||
export type OAuthAuthInfo = {
|
||||
url: string;
|
||||
instructions?: string;
|
||||
};
|
||||
|
||||
export interface OAuthLoginCallbacks {
|
||||
onAuth: (info: OAuthAuthInfo) => void;
|
||||
onPrompt: (prompt: OAuthPrompt) => Promise<string>;
|
||||
onProgress?: (message: string) => void;
|
||||
onManualCodeInput?: () => Promise<string>;
|
||||
signal?: AbortSignal;
|
||||
}
|
||||
|
||||
export interface OAuthProviderInterface {
|
||||
readonly id: OAuthProviderId;
|
||||
readonly name: string;
|
||||
|
||||
/** Run the login flow, return credentials to persist */
|
||||
login(callbacks: OAuthLoginCallbacks): Promise<OAuthCredentials>;
|
||||
|
||||
/** Whether login uses a local callback server and supports manual code input. */
|
||||
usesCallbackServer?: boolean;
|
||||
|
||||
/** Refresh expired credentials, return updated credentials to persist */
|
||||
refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials>;
|
||||
|
||||
/** Convert credentials to API key string for the provider */
|
||||
getApiKey(credentials: OAuthCredentials): string;
|
||||
|
||||
/** Optional: modify models for this provider (e.g., update baseUrl) */
|
||||
modifyModels?(models: Model<Api>[], credentials: OAuthCredentials): Model<Api>[];
|
||||
}
|
||||
|
||||
/** @deprecated Use OAuthProviderInterface instead */
|
||||
export interface OAuthProviderInfo {
|
||||
id: OAuthProviderId;
|
||||
name: string;
|
||||
available: boolean;
|
||||
}
|
||||
123
packages/pi-ai/src/utils/overflow.ts
Normal file
123
packages/pi-ai/src/utils/overflow.ts
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
import type { AssistantMessage } from "../types.js";
|
||||
|
||||
/**
|
||||
* Regex patterns to detect context overflow errors from different providers.
|
||||
*
|
||||
* These patterns match error messages returned when the input exceeds
|
||||
* the model's context window.
|
||||
*
|
||||
* Provider-specific patterns (with example error messages):
|
||||
*
|
||||
* - Anthropic: "prompt is too long: 213462 tokens > 200000 maximum"
|
||||
* - OpenAI: "Your input exceeds the context window of this model"
|
||||
* - Google: "The input token count (1196265) exceeds the maximum number of tokens allowed (1048575)"
|
||||
* - xAI: "This model's maximum prompt length is 131072 but the request contains 537812 tokens"
|
||||
* - Groq: "Please reduce the length of the messages or completion"
|
||||
* - OpenRouter: "This endpoint's maximum context length is X tokens. However, you requested about Y tokens"
|
||||
* - llama.cpp: "the request exceeds the available context size, try increasing it"
|
||||
* - LM Studio: "tokens to keep from the initial prompt is greater than the context length"
|
||||
* - GitHub Copilot: "prompt token count of X exceeds the limit of Y"
|
||||
* - MiniMax: "invalid params, context window exceeds limit"
|
||||
* - Kimi For Coding: "Your request exceeded model token limit: X (requested: Y)"
|
||||
* - Cerebras: Returns "400/413 status code (no body)" - handled separately below
|
||||
* - Mistral: "Prompt contains X tokens ... too large for model with Y maximum context length"
|
||||
* - z.ai: Does NOT error, accepts overflow silently - handled via usage.input > contextWindow
|
||||
* - Ollama: Silently truncates input - not detectable via error message
|
||||
*/
|
||||
const OVERFLOW_PATTERNS = [
|
||||
/prompt is too long/i, // Anthropic
|
||||
/input is too long for requested model/i, // Amazon Bedrock
|
||||
/exceeds the context window/i, // OpenAI (Completions & Responses API)
|
||||
/input token count.*exceeds the maximum/i, // Google (Gemini)
|
||||
/maximum prompt length is \d+/i, // xAI (Grok)
|
||||
/reduce the length of the messages/i, // Groq
|
||||
/maximum context length is \d+ tokens/i, // OpenRouter (all backends)
|
||||
/exceeds the limit of \d+/i, // GitHub Copilot
|
||||
/exceeds the available context size/i, // llama.cpp server
|
||||
/greater than the context length/i, // LM Studio
|
||||
/context window exceeds limit/i, // MiniMax
|
||||
/exceeded model token limit/i, // Kimi For Coding
|
||||
/too large for model with \d+ maximum context length/i, // Mistral
|
||||
/model_context_window_exceeded/i, // z.ai non-standard finish_reason surfaced as error text
|
||||
/context[_ ]length[_ ]exceeded/i, // Generic fallback
|
||||
/too many tokens/i, // Generic fallback
|
||||
/token limit exceeded/i, // Generic fallback
|
||||
];
|
||||
|
||||
/**
|
||||
* Check if an assistant message represents a context overflow error.
|
||||
*
|
||||
* This handles two cases:
|
||||
* 1. Error-based overflow: Most providers return stopReason "error" with a
|
||||
* specific error message pattern.
|
||||
* 2. Silent overflow: Some providers accept overflow requests and return
|
||||
* successfully. For these, we check if usage.input exceeds the context window.
|
||||
*
|
||||
* ## Reliability by Provider
|
||||
*
|
||||
* **Reliable detection (returns error with detectable message):**
|
||||
* - Anthropic: "prompt is too long: X tokens > Y maximum"
|
||||
* - OpenAI (Completions & Responses): "exceeds the context window"
|
||||
* - Google Gemini: "input token count exceeds the maximum"
|
||||
* - xAI (Grok): "maximum prompt length is X but request contains Y"
|
||||
* - Groq: "reduce the length of the messages"
|
||||
* - Cerebras: 400/413 status code (no body)
|
||||
* - Mistral: "Prompt contains X tokens ... too large for model with Y maximum context length"
|
||||
* - OpenRouter (all backends): "maximum context length is X tokens"
|
||||
* - llama.cpp: "exceeds the available context size"
|
||||
* - LM Studio: "greater than the context length"
|
||||
* - Kimi For Coding: "exceeded model token limit: X (requested: Y)"
|
||||
*
|
||||
* **Unreliable detection:**
|
||||
* - z.ai: Sometimes accepts overflow silently (detectable via usage.input > contextWindow),
|
||||
* sometimes returns rate limit errors. Pass contextWindow param to detect silent overflow.
|
||||
* - Ollama: Silently truncates input without error. Cannot be detected via this function.
|
||||
* The response will have usage.input < expected, but we don't know the expected value.
|
||||
*
|
||||
* ## Custom Providers
|
||||
*
|
||||
* If you've added custom models via settings.json, this function may not detect
|
||||
* overflow errors from those providers. To add support:
|
||||
*
|
||||
* 1. Send a request that exceeds the model's context window
|
||||
* 2. Check the errorMessage in the response
|
||||
* 3. Create a regex pattern that matches the error
|
||||
* 4. The pattern should be added to OVERFLOW_PATTERNS in this file, or
|
||||
* check the errorMessage yourself before calling this function
|
||||
*
|
||||
* @param message - The assistant message to check
|
||||
* @param contextWindow - Optional context window size for detecting silent overflow (z.ai)
|
||||
* @returns true if the message indicates a context overflow
|
||||
*/
|
||||
export function isContextOverflow(message: AssistantMessage, contextWindow?: number): boolean {
|
||||
// Case 1: Check error message patterns
|
||||
if (message.stopReason === "error" && message.errorMessage) {
|
||||
// Check known patterns
|
||||
if (OVERFLOW_PATTERNS.some((p) => p.test(message.errorMessage!))) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Cerebras returns 400/413 with no body for context overflow
|
||||
// Note: 429 is rate limiting (requests/tokens per time), NOT context overflow
|
||||
if (/^4(00|13)\s*(status code)?\s*\(no body\)/i.test(message.errorMessage)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Case 2: Silent overflow (z.ai style) - successful but usage exceeds context
|
||||
if (contextWindow && message.stopReason === "stop") {
|
||||
const inputTokens = message.usage.input + message.usage.cacheRead;
|
||||
if (inputTokens > contextWindow) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the overflow patterns for testing purposes.
|
||||
*/
|
||||
export function getOverflowPatterns(): RegExp[] {
|
||||
return [...OVERFLOW_PATTERNS];
|
||||
}
|
||||
25
packages/pi-ai/src/utils/sanitize-unicode.ts
Normal file
25
packages/pi-ai/src/utils/sanitize-unicode.ts
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
/**
|
||||
* Removes unpaired Unicode surrogate characters from a string.
|
||||
*
|
||||
* Unpaired surrogates (high surrogates 0xD800-0xDBFF without matching low surrogates 0xDC00-0xDFFF,
|
||||
* or vice versa) cause JSON serialization errors in many API providers.
|
||||
*
|
||||
* Valid emoji and other characters outside the Basic Multilingual Plane use properly paired
|
||||
* surrogates and will NOT be affected by this function.
|
||||
*
|
||||
* @param text - The text to sanitize
|
||||
* @returns The sanitized text with unpaired surrogates removed
|
||||
*
|
||||
* @example
|
||||
* // Valid emoji (properly paired surrogates) are preserved
|
||||
* sanitizeSurrogates("Hello 🙈 World") // => "Hello 🙈 World"
|
||||
*
|
||||
* // Unpaired high surrogate is removed
|
||||
* const unpaired = String.fromCharCode(0xD83D); // high surrogate without low
|
||||
* sanitizeSurrogates(`Text ${unpaired} here`) // => "Text here"
|
||||
*/
|
||||
export function sanitizeSurrogates(text: string): string {
|
||||
// Replace unpaired high surrogates (0xD800-0xDBFF not followed by low surrogate)
|
||||
// Replace unpaired low surrogates (0xDC00-0xDFFF not preceded by high surrogate)
|
||||
return text.replace(/[\uD800-\uDBFF](?![\uDC00-\uDFFF])|(?<![\uD800-\uDBFF])[\uDC00-\uDFFF]/g, "");
|
||||
}
|
||||
24
packages/pi-ai/src/utils/typebox-helpers.ts
Normal file
24
packages/pi-ai/src/utils/typebox-helpers.ts
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
import { type TUnsafe, Type } from "@sinclair/typebox";
|
||||
|
||||
/**
|
||||
* Creates a string enum schema compatible with Google's API and other providers
|
||||
* that don't support anyOf/const patterns.
|
||||
*
|
||||
* @example
|
||||
* const OperationSchema = StringEnum(["add", "subtract", "multiply", "divide"], {
|
||||
* description: "The operation to perform"
|
||||
* });
|
||||
*
|
||||
* type Operation = Static<typeof OperationSchema>; // "add" | "subtract" | "multiply" | "divide"
|
||||
*/
|
||||
export function StringEnum<T extends readonly string[]>(
|
||||
values: T,
|
||||
options?: { description?: string; default?: T[number] },
|
||||
): TUnsafe<T[number]> {
|
||||
return Type.Unsafe<T[number]>({
|
||||
type: "string",
|
||||
enum: values as any,
|
||||
...(options?.description && { description: options.description }),
|
||||
...(options?.default && { default: options.default }),
|
||||
});
|
||||
}
|
||||
84
packages/pi-ai/src/utils/validation.ts
Normal file
84
packages/pi-ai/src/utils/validation.ts
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
import AjvModule from "ajv";
|
||||
import addFormatsModule from "ajv-formats";
|
||||
|
||||
// Handle both default and named exports
|
||||
const Ajv = (AjvModule as any).default || AjvModule;
|
||||
const addFormats = (addFormatsModule as any).default || addFormatsModule;
|
||||
|
||||
import type { Tool, ToolCall } from "../types.js";
|
||||
|
||||
// Detect if we're in a browser extension environment with strict CSP
|
||||
// Chrome extensions with Manifest V3 don't allow eval/Function constructor
|
||||
const isBrowserExtension = typeof globalThis !== "undefined" && (globalThis as any).chrome?.runtime?.id !== undefined;
|
||||
|
||||
// Create a singleton AJV instance with formats (only if not in browser extension)
|
||||
// AJV requires 'unsafe-eval' CSP which is not allowed in Manifest V3
|
||||
let ajv: any = null;
|
||||
if (!isBrowserExtension) {
|
||||
try {
|
||||
ajv = new Ajv({
|
||||
allErrors: true,
|
||||
strict: false,
|
||||
coerceTypes: true,
|
||||
});
|
||||
addFormats(ajv);
|
||||
} catch (_e) {
|
||||
// AJV initialization failed (likely CSP restriction)
|
||||
console.warn("AJV validation disabled due to CSP restrictions");
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Finds a tool by name and validates the tool call arguments against its TypeBox schema
|
||||
* @param tools Array of tool definitions
|
||||
* @param toolCall The tool call from the LLM
|
||||
* @returns The validated arguments
|
||||
* @throws Error if tool is not found or validation fails
|
||||
*/
|
||||
export function validateToolCall(tools: Tool[], toolCall: ToolCall): any {
|
||||
const tool = tools.find((t) => t.name === toolCall.name);
|
||||
if (!tool) {
|
||||
throw new Error(`Tool "${toolCall.name}" not found`);
|
||||
}
|
||||
return validateToolArguments(tool, toolCall);
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates tool call arguments against the tool's TypeBox schema
|
||||
* @param tool The tool definition with TypeBox schema
|
||||
* @param toolCall The tool call from the LLM
|
||||
* @returns The validated (and potentially coerced) arguments
|
||||
* @throws Error with formatted message if validation fails
|
||||
*/
|
||||
export function validateToolArguments(tool: Tool, toolCall: ToolCall): any {
|
||||
// Skip validation in browser extension environment (CSP restrictions prevent AJV from working)
|
||||
if (!ajv || isBrowserExtension) {
|
||||
// Trust the LLM's output without validation
|
||||
// Browser extensions can't use AJV due to Manifest V3 CSP restrictions
|
||||
return toolCall.arguments;
|
||||
}
|
||||
|
||||
// Compile the schema
|
||||
const validate = ajv.compile(tool.parameters);
|
||||
|
||||
// Clone arguments so AJV can safely mutate for type coercion
|
||||
const args = structuredClone(toolCall.arguments);
|
||||
|
||||
// Validate the arguments (AJV mutates args in-place for type coercion)
|
||||
if (validate(args)) {
|
||||
return args;
|
||||
}
|
||||
|
||||
// Format validation errors nicely
|
||||
const errors =
|
||||
validate.errors
|
||||
?.map((err: any) => {
|
||||
const path = err.instancePath ? err.instancePath.substring(1) : err.params.missingProperty || "root";
|
||||
return ` - ${path}: ${err.message}`;
|
||||
})
|
||||
.join("\n") || "Unknown validation error";
|
||||
|
||||
const errorMessage = `Validation failed for tool "${toolCall.name}":\n${errors}\n\nReceived arguments:\n${JSON.stringify(toolCall.arguments, null, 2)}`;
|
||||
|
||||
throw new Error(errorMessage);
|
||||
}
|
||||
27
packages/pi-ai/tsconfig.json
Normal file
27
packages/pi-ai/tsconfig.json
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2024",
|
||||
"module": "Node16",
|
||||
"lib": ["ES2024"],
|
||||
"strict": true,
|
||||
"esModuleInterop": true,
|
||||
"skipLibCheck": true,
|
||||
"forceConsistentCasingInFileNames": true,
|
||||
"declaration": true,
|
||||
"declarationMap": true,
|
||||
"sourceMap": true,
|
||||
"inlineSources": true,
|
||||
"inlineSourceMap": false,
|
||||
"moduleResolution": "Node16",
|
||||
"resolveJsonModule": true,
|
||||
"allowImportingTsExtensions": false,
|
||||
"experimentalDecorators": true,
|
||||
"emitDecoratorMetadata": true,
|
||||
"useDefineForClassFields": false,
|
||||
"types": ["node"],
|
||||
"outDir": "./dist",
|
||||
"rootDir": "./src"
|
||||
},
|
||||
"include": ["src/**/*.ts"],
|
||||
"exclude": ["node_modules", "dist", "**/*.d.ts", "src/**/*.d.ts"]
|
||||
}
|
||||
55
packages/pi-coding-agent/package.json
Normal file
55
packages/pi-coding-agent/package.json
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
{
|
||||
"name": "@gsd/pi-coding-agent",
|
||||
"version": "0.57.1",
|
||||
"description": "Coding agent CLI (vendored from pi-mono)",
|
||||
"type": "module",
|
||||
"piConfig": {
|
||||
"name": "pi",
|
||||
"configDir": ".pi"
|
||||
},
|
||||
"main": "./dist/index.js",
|
||||
"types": "./dist/index.d.ts",
|
||||
"exports": {
|
||||
".": {
|
||||
"types": "./dist/index.d.ts",
|
||||
"import": "./dist/index.js"
|
||||
},
|
||||
"./hooks": {
|
||||
"types": "./dist/core/hooks/index.d.ts",
|
||||
"import": "./dist/core/hooks/index.js"
|
||||
}
|
||||
},
|
||||
"scripts": {
|
||||
"build": "tsc -p tsconfig.json && npm run copy-assets",
|
||||
"copy-assets": "node -e \"const{mkdirSync,cpSync}=require('fs');mkdirSync('dist/modes/interactive/theme',{recursive:true});cpSync('src/modes/interactive/theme','dist/modes/interactive/theme',{recursive:true,filter:(s)=>!s.endsWith('.ts')});mkdirSync('dist/core/export-html/vendor',{recursive:true});cpSync('src/core/export-html/template.html','dist/core/export-html/template.html');cpSync('src/core/export-html/template.css','dist/core/export-html/template.css');cpSync('src/core/export-html/template.js','dist/core/export-html/template.js');cpSync('src/core/export-html/vendor','dist/core/export-html/vendor',{recursive:true,filter:(s)=>!s.endsWith('.ts')})\""
|
||||
},
|
||||
"dependencies": {
|
||||
"@gsd/pi-agent-core": "*",
|
||||
"@gsd/pi-ai": "*",
|
||||
"@gsd/pi-tui": "*",
|
||||
"@mariozechner/jiti": "^2.6.2",
|
||||
"@silvia-odwyer/photon-node": "^0.3.4",
|
||||
"chalk": "^5.5.0",
|
||||
"cli-highlight": "^2.1.11",
|
||||
"diff": "^8.0.2",
|
||||
"extract-zip": "^2.0.1",
|
||||
"file-type": "^21.1.1",
|
||||
"glob": "^13.0.1",
|
||||
"hosted-git-info": "^9.0.2",
|
||||
"ignore": "^7.0.5",
|
||||
"marked": "^15.0.12",
|
||||
"minimatch": "^10.2.3",
|
||||
"proper-lockfile": "^4.1.2",
|
||||
"strip-ansi": "^7.1.0",
|
||||
"undici": "^7.19.1",
|
||||
"yaml": "^2.8.2"
|
||||
},
|
||||
"optionalDependencies": {
|
||||
"@mariozechner/clipboard": "^0.3.2"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/diff": "^7.0.2",
|
||||
"@types/hosted-git-info": "^3.0.5",
|
||||
"@types/proper-lockfile": "^4.1.4"
|
||||
}
|
||||
}
|
||||
18
packages/pi-coding-agent/src/cli.ts
Normal file
18
packages/pi-coding-agent/src/cli.ts
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
#!/usr/bin/env node
|
||||
/**
|
||||
* CLI entry point for the refactored coding agent.
|
||||
* Uses main.ts with AgentSession and new mode modules.
|
||||
*
|
||||
* Test with: npx tsx src/cli-new.ts [args...]
|
||||
*/
|
||||
process.title = "pi";
|
||||
|
||||
import { setBedrockProviderModule } from "@gsd/pi-ai";
|
||||
import { bedrockProviderModule } from "@gsd/pi-ai/bedrock-provider";
|
||||
import { EnvHttpProxyAgent, setGlobalDispatcher } from "undici";
|
||||
import { main } from "./main.js";
|
||||
|
||||
setGlobalDispatcher(new EnvHttpProxyAgent());
|
||||
setBedrockProviderModule(bedrockProviderModule);
|
||||
|
||||
main(process.argv.slice(2));
|
||||
316
packages/pi-coding-agent/src/cli/args.ts
Normal file
316
packages/pi-coding-agent/src/cli/args.ts
Normal file
|
|
@ -0,0 +1,316 @@
|
|||
/**
|
||||
* CLI argument parsing and help display
|
||||
*/
|
||||
|
||||
import type { ThinkingLevel } from "@gsd/pi-agent-core";
|
||||
import chalk from "chalk";
|
||||
import { APP_NAME, CONFIG_DIR_NAME, ENV_AGENT_DIR } from "../config.js";
|
||||
import { allTools, type ToolName } from "../core/tools/index.js";
|
||||
|
||||
export type Mode = "text" | "json" | "rpc";
|
||||
|
||||
export interface Args {
|
||||
provider?: string;
|
||||
model?: string;
|
||||
apiKey?: string;
|
||||
systemPrompt?: string;
|
||||
appendSystemPrompt?: string;
|
||||
thinking?: ThinkingLevel;
|
||||
continue?: boolean;
|
||||
resume?: boolean;
|
||||
help?: boolean;
|
||||
version?: boolean;
|
||||
mode?: Mode;
|
||||
noSession?: boolean;
|
||||
session?: string;
|
||||
sessionDir?: string;
|
||||
models?: string[];
|
||||
tools?: ToolName[];
|
||||
noTools?: boolean;
|
||||
extensions?: string[];
|
||||
noExtensions?: boolean;
|
||||
print?: boolean;
|
||||
export?: string;
|
||||
noSkills?: boolean;
|
||||
skills?: string[];
|
||||
promptTemplates?: string[];
|
||||
noPromptTemplates?: boolean;
|
||||
themes?: string[];
|
||||
noThemes?: boolean;
|
||||
listModels?: string | true;
|
||||
offline?: boolean;
|
||||
verbose?: boolean;
|
||||
messages: string[];
|
||||
fileArgs: string[];
|
||||
/** Unknown flags (potentially extension flags) - map of flag name to value */
|
||||
unknownFlags: Map<string, boolean | string>;
|
||||
}
|
||||
|
||||
const VALID_THINKING_LEVELS = ["off", "minimal", "low", "medium", "high", "xhigh"] as const;
|
||||
|
||||
export function isValidThinkingLevel(level: string): level is ThinkingLevel {
|
||||
return VALID_THINKING_LEVELS.includes(level as ThinkingLevel);
|
||||
}
|
||||
|
||||
export function parseArgs(args: string[], extensionFlags?: Map<string, { type: "boolean" | "string" }>): Args {
|
||||
const result: Args = {
|
||||
messages: [],
|
||||
fileArgs: [],
|
||||
unknownFlags: new Map(),
|
||||
};
|
||||
|
||||
for (let i = 0; i < args.length; i++) {
|
||||
const arg = args[i];
|
||||
|
||||
if (arg === "--help" || arg === "-h") {
|
||||
result.help = true;
|
||||
} else if (arg === "--version" || arg === "-v") {
|
||||
result.version = true;
|
||||
} else if (arg === "--mode" && i + 1 < args.length) {
|
||||
const mode = args[++i];
|
||||
if (mode === "text" || mode === "json" || mode === "rpc") {
|
||||
result.mode = mode;
|
||||
}
|
||||
} else if (arg === "--continue" || arg === "-c") {
|
||||
result.continue = true;
|
||||
} else if (arg === "--resume" || arg === "-r") {
|
||||
result.resume = true;
|
||||
} else if (arg === "--provider" && i + 1 < args.length) {
|
||||
result.provider = args[++i];
|
||||
} else if (arg === "--model" && i + 1 < args.length) {
|
||||
result.model = args[++i];
|
||||
} else if (arg === "--api-key" && i + 1 < args.length) {
|
||||
result.apiKey = args[++i];
|
||||
} else if (arg === "--system-prompt" && i + 1 < args.length) {
|
||||
result.systemPrompt = args[++i];
|
||||
} else if (arg === "--append-system-prompt" && i + 1 < args.length) {
|
||||
result.appendSystemPrompt = args[++i];
|
||||
} else if (arg === "--no-session") {
|
||||
result.noSession = true;
|
||||
} else if (arg === "--session" && i + 1 < args.length) {
|
||||
result.session = args[++i];
|
||||
} else if (arg === "--session-dir" && i + 1 < args.length) {
|
||||
result.sessionDir = args[++i];
|
||||
} else if (arg === "--models" && i + 1 < args.length) {
|
||||
result.models = args[++i].split(",").map((s) => s.trim());
|
||||
} else if (arg === "--no-tools") {
|
||||
result.noTools = true;
|
||||
} else if (arg === "--tools" && i + 1 < args.length) {
|
||||
const toolNames = args[++i].split(",").map((s) => s.trim());
|
||||
const validTools: ToolName[] = [];
|
||||
for (const name of toolNames) {
|
||||
if (name in allTools) {
|
||||
validTools.push(name as ToolName);
|
||||
} else {
|
||||
console.error(
|
||||
chalk.yellow(`Warning: Unknown tool "${name}". Valid tools: ${Object.keys(allTools).join(", ")}`),
|
||||
);
|
||||
}
|
||||
}
|
||||
result.tools = validTools;
|
||||
} else if (arg === "--thinking" && i + 1 < args.length) {
|
||||
const level = args[++i];
|
||||
if (isValidThinkingLevel(level)) {
|
||||
result.thinking = level;
|
||||
} else {
|
||||
console.error(
|
||||
chalk.yellow(
|
||||
`Warning: Invalid thinking level "${level}". Valid values: ${VALID_THINKING_LEVELS.join(", ")}`,
|
||||
),
|
||||
);
|
||||
}
|
||||
} else if (arg === "--print" || arg === "-p") {
|
||||
result.print = true;
|
||||
} else if (arg === "--export" && i + 1 < args.length) {
|
||||
result.export = args[++i];
|
||||
} else if ((arg === "--extension" || arg === "-e") && i + 1 < args.length) {
|
||||
result.extensions = result.extensions ?? [];
|
||||
result.extensions.push(args[++i]);
|
||||
} else if (arg === "--no-extensions" || arg === "-ne") {
|
||||
result.noExtensions = true;
|
||||
} else if (arg === "--skill" && i + 1 < args.length) {
|
||||
result.skills = result.skills ?? [];
|
||||
result.skills.push(args[++i]);
|
||||
} else if (arg === "--prompt-template" && i + 1 < args.length) {
|
||||
result.promptTemplates = result.promptTemplates ?? [];
|
||||
result.promptTemplates.push(args[++i]);
|
||||
} else if (arg === "--theme" && i + 1 < args.length) {
|
||||
result.themes = result.themes ?? [];
|
||||
result.themes.push(args[++i]);
|
||||
} else if (arg === "--no-skills" || arg === "-ns") {
|
||||
result.noSkills = true;
|
||||
} else if (arg === "--no-prompt-templates" || arg === "-np") {
|
||||
result.noPromptTemplates = true;
|
||||
} else if (arg === "--no-themes") {
|
||||
result.noThemes = true;
|
||||
} else if (arg === "--list-models") {
|
||||
// Check if next arg is a search pattern (not a flag or file arg)
|
||||
if (i + 1 < args.length && !args[i + 1].startsWith("-") && !args[i + 1].startsWith("@")) {
|
||||
result.listModels = args[++i];
|
||||
} else {
|
||||
result.listModels = true;
|
||||
}
|
||||
} else if (arg === "--verbose") {
|
||||
result.verbose = true;
|
||||
} else if (arg === "--offline") {
|
||||
result.offline = true;
|
||||
} else if (arg.startsWith("@")) {
|
||||
result.fileArgs.push(arg.slice(1)); // Remove @ prefix
|
||||
} else if (arg.startsWith("--") && extensionFlags) {
|
||||
// Check if it's an extension-registered flag
|
||||
const flagName = arg.slice(2);
|
||||
const extFlag = extensionFlags.get(flagName);
|
||||
if (extFlag) {
|
||||
if (extFlag.type === "boolean") {
|
||||
result.unknownFlags.set(flagName, true);
|
||||
} else if (extFlag.type === "string" && i + 1 < args.length) {
|
||||
result.unknownFlags.set(flagName, args[++i]);
|
||||
}
|
||||
}
|
||||
// Unknown flags without extensionFlags are silently ignored (first pass)
|
||||
} else if (!arg.startsWith("-")) {
|
||||
result.messages.push(arg);
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
export function printHelp(): void {
|
||||
console.log(`${chalk.bold(APP_NAME)} - AI coding assistant with read, bash, edit, write tools
|
||||
|
||||
${chalk.bold("Usage:")}
|
||||
${APP_NAME} [options] [@files...] [messages...]
|
||||
|
||||
${chalk.bold("Commands:")}
|
||||
${APP_NAME} install <source> [-l] Install extension source and add to settings
|
||||
${APP_NAME} remove <source> [-l] Remove extension source from settings
|
||||
${APP_NAME} update [source] Update installed extensions (skips pinned sources)
|
||||
${APP_NAME} list List installed extensions from settings
|
||||
${APP_NAME} config Open TUI to enable/disable package resources
|
||||
${APP_NAME} <command> --help Show help for install/remove/update/list
|
||||
|
||||
${chalk.bold("Options:")}
|
||||
--provider <name> Provider name (default: google)
|
||||
--model <pattern> Model pattern or ID (supports "provider/id" and optional ":<thinking>")
|
||||
--api-key <key> API key (defaults to env vars)
|
||||
--system-prompt <text> System prompt (default: coding assistant prompt)
|
||||
--append-system-prompt <text> Append text or file contents to the system prompt
|
||||
--mode <mode> Output mode: text (default), json, or rpc
|
||||
--print, -p Non-interactive mode: process prompt and exit
|
||||
--continue, -c Continue previous session
|
||||
--resume, -r Select a session to resume
|
||||
--session <path> Use specific session file
|
||||
--session-dir <dir> Directory for session storage and lookup
|
||||
--no-session Don't save session (ephemeral)
|
||||
--models <patterns> Comma-separated model patterns for Ctrl+P cycling
|
||||
Supports globs (anthropic/*, *sonnet*) and fuzzy matching
|
||||
--no-tools Disable all built-in tools
|
||||
--tools <tools> Comma-separated list of tools to enable (default: read,bash,edit,write)
|
||||
Available: read, bash, edit, write, grep, find, ls
|
||||
--thinking <level> Set thinking level: off, minimal, low, medium, high, xhigh
|
||||
--extension, -e <path> Load an extension file (can be used multiple times)
|
||||
--no-extensions, -ne Disable extension discovery (explicit -e paths still work)
|
||||
--skill <path> Load a skill file or directory (can be used multiple times)
|
||||
--no-skills, -ns Disable skills discovery and loading
|
||||
--prompt-template <path> Load a prompt template file or directory (can be used multiple times)
|
||||
--no-prompt-templates, -np Disable prompt template discovery and loading
|
||||
--theme <path> Load a theme file or directory (can be used multiple times)
|
||||
--no-themes Disable theme discovery and loading
|
||||
--export <file> Export session file to HTML and exit
|
||||
--list-models [search] List available models (with optional fuzzy search)
|
||||
--verbose Force verbose startup (overrides quietStartup setting)
|
||||
--offline Disable startup network operations (same as PI_OFFLINE=1)
|
||||
--help, -h Show this help
|
||||
--version, -v Show version number
|
||||
|
||||
Extensions can register additional flags (e.g., --plan from plan-mode extension).
|
||||
|
||||
${chalk.bold("Examples:")}
|
||||
# Interactive mode
|
||||
${APP_NAME}
|
||||
|
||||
# Interactive mode with initial prompt
|
||||
${APP_NAME} "List all .ts files in src/"
|
||||
|
||||
# Include files in initial message
|
||||
${APP_NAME} @prompt.md @image.png "What color is the sky?"
|
||||
|
||||
# Non-interactive mode (process and exit)
|
||||
${APP_NAME} -p "List all .ts files in src/"
|
||||
|
||||
# Multiple messages (interactive)
|
||||
${APP_NAME} "Read package.json" "What dependencies do we have?"
|
||||
|
||||
# Continue previous session
|
||||
${APP_NAME} --continue "What did we discuss?"
|
||||
|
||||
# Use different model
|
||||
${APP_NAME} --provider openai --model gpt-4o-mini "Help me refactor this code"
|
||||
|
||||
# Use model with provider prefix (no --provider needed)
|
||||
${APP_NAME} --model openai/gpt-4o "Help me refactor this code"
|
||||
|
||||
# Use model with thinking level shorthand
|
||||
${APP_NAME} --model sonnet:high "Solve this complex problem"
|
||||
|
||||
# Limit model cycling to specific models
|
||||
${APP_NAME} --models claude-sonnet,claude-haiku,gpt-4o
|
||||
|
||||
# Limit to a specific provider with glob pattern
|
||||
${APP_NAME} --models "github-copilot/*"
|
||||
|
||||
# Cycle models with fixed thinking levels
|
||||
${APP_NAME} --models sonnet:high,haiku:low
|
||||
|
||||
# Start with a specific thinking level
|
||||
${APP_NAME} --thinking high "Solve this complex problem"
|
||||
|
||||
# Read-only mode (no file modifications possible)
|
||||
${APP_NAME} --tools read,grep,find,ls -p "Review the code in src/"
|
||||
|
||||
# Export a session file to HTML
|
||||
${APP_NAME} --export ~/${CONFIG_DIR_NAME}/agent/sessions/--path--/session.jsonl
|
||||
${APP_NAME} --export session.jsonl output.html
|
||||
|
||||
${chalk.bold("Environment Variables:")}
|
||||
ANTHROPIC_API_KEY - Anthropic Claude API key
|
||||
ANTHROPIC_OAUTH_TOKEN - Anthropic OAuth token (alternative to API key)
|
||||
OPENAI_API_KEY - OpenAI GPT API key
|
||||
AZURE_OPENAI_API_KEY - Azure OpenAI API key
|
||||
AZURE_OPENAI_BASE_URL - Azure OpenAI base URL (https://{resource}.openai.azure.com/openai/v1)
|
||||
AZURE_OPENAI_RESOURCE_NAME - Azure OpenAI resource name (alternative to base URL)
|
||||
AZURE_OPENAI_API_VERSION - Azure OpenAI API version (default: v1)
|
||||
AZURE_OPENAI_DEPLOYMENT_NAME_MAP - Azure OpenAI model=deployment map (comma-separated)
|
||||
GEMINI_API_KEY - Google Gemini API key
|
||||
GROQ_API_KEY - Groq API key
|
||||
CEREBRAS_API_KEY - Cerebras API key
|
||||
XAI_API_KEY - xAI Grok API key
|
||||
OPENROUTER_API_KEY - OpenRouter API key
|
||||
AI_GATEWAY_API_KEY - Vercel AI Gateway API key
|
||||
ZAI_API_KEY - ZAI API key
|
||||
MISTRAL_API_KEY - Mistral API key
|
||||
MINIMAX_API_KEY - MiniMax API key
|
||||
OPENCODE_API_KEY - OpenCode Zen/OpenCode Go API key
|
||||
KIMI_API_KEY - Kimi For Coding API key
|
||||
AWS_PROFILE - AWS profile for Amazon Bedrock
|
||||
AWS_ACCESS_KEY_ID - AWS access key for Amazon Bedrock
|
||||
AWS_SECRET_ACCESS_KEY - AWS secret key for Amazon Bedrock
|
||||
AWS_BEARER_TOKEN_BEDROCK - Bedrock API key (bearer token)
|
||||
AWS_REGION - AWS region for Amazon Bedrock (e.g., us-east-1)
|
||||
${ENV_AGENT_DIR.padEnd(32)} - Session storage directory (default: ~/${CONFIG_DIR_NAME}/agent)
|
||||
PI_PACKAGE_DIR - Override package directory (for Nix/Guix store paths)
|
||||
PI_OFFLINE - Disable startup network operations when set to 1/true/yes
|
||||
PI_SHARE_VIEWER_URL - Base URL for /share command (default: https://pi.dev/session/)
|
||||
PI_AI_ANTIGRAVITY_VERSION - Override Antigravity User-Agent version (e.g., 1.23.0)
|
||||
|
||||
${chalk.bold("Available Tools (default: read, bash, edit, write):")}
|
||||
read - Read file contents
|
||||
bash - Execute bash commands
|
||||
edit - Edit files with find/replace
|
||||
write - Write files (creates/overwrites)
|
||||
grep - Search file contents (read-only, off by default)
|
||||
find - Find files by glob pattern (read-only, off by default)
|
||||
ls - List directory contents (read-only, off by default)
|
||||
`);
|
||||
}
|
||||
52
packages/pi-coding-agent/src/cli/config-selector.ts
Normal file
52
packages/pi-coding-agent/src/cli/config-selector.ts
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* TUI config selector for `pi config` command
|
||||
*/
|
||||
|
||||
import { ProcessTerminal, TUI } from "@gsd/pi-tui";
|
||||
import type { ResolvedPaths } from "../core/package-manager.js";
|
||||
import type { SettingsManager } from "../core/settings-manager.js";
|
||||
import { ConfigSelectorComponent } from "../modes/interactive/components/config-selector.js";
|
||||
import { initTheme, stopThemeWatcher } from "../modes/interactive/theme/theme.js";
|
||||
|
||||
export interface ConfigSelectorOptions {
|
||||
resolvedPaths: ResolvedPaths;
|
||||
settingsManager: SettingsManager;
|
||||
cwd: string;
|
||||
agentDir: string;
|
||||
}
|
||||
|
||||
/** Show TUI config selector and return when closed */
|
||||
export async function selectConfig(options: ConfigSelectorOptions): Promise<void> {
|
||||
// Initialize theme before showing TUI
|
||||
initTheme(options.settingsManager.getTheme(), true);
|
||||
|
||||
return new Promise((resolve) => {
|
||||
const ui = new TUI(new ProcessTerminal());
|
||||
let resolved = false;
|
||||
|
||||
const selector = new ConfigSelectorComponent(
|
||||
options.resolvedPaths,
|
||||
options.settingsManager,
|
||||
options.cwd,
|
||||
options.agentDir,
|
||||
() => {
|
||||
if (!resolved) {
|
||||
resolved = true;
|
||||
ui.stop();
|
||||
stopThemeWatcher();
|
||||
resolve();
|
||||
}
|
||||
},
|
||||
() => {
|
||||
ui.stop();
|
||||
stopThemeWatcher();
|
||||
process.exit(0);
|
||||
},
|
||||
() => ui.requestRender(),
|
||||
);
|
||||
|
||||
ui.addChild(selector);
|
||||
ui.setFocus(selector.getResourceList());
|
||||
ui.start();
|
||||
});
|
||||
}
|
||||
96
packages/pi-coding-agent/src/cli/file-processor.ts
Normal file
96
packages/pi-coding-agent/src/cli/file-processor.ts
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
/**
|
||||
* Process @file CLI arguments into text content and image attachments
|
||||
*/
|
||||
|
||||
import { access, readFile, stat } from "node:fs/promises";
|
||||
import type { ImageContent } from "@gsd/pi-ai";
|
||||
import chalk from "chalk";
|
||||
import { resolve } from "path";
|
||||
import { resolveReadPath } from "../core/tools/path-utils.js";
|
||||
import { formatDimensionNote, resizeImage } from "../utils/image-resize.js";
|
||||
import { detectSupportedImageMimeTypeFromFile } from "../utils/mime.js";
|
||||
|
||||
export interface ProcessedFiles {
|
||||
text: string;
|
||||
images: ImageContent[];
|
||||
}
|
||||
|
||||
export interface ProcessFileOptions {
|
||||
/** Whether to auto-resize images to 2000x2000 max. Default: true */
|
||||
autoResizeImages?: boolean;
|
||||
}
|
||||
|
||||
/** Process @file arguments into text content and image attachments */
|
||||
export async function processFileArguments(fileArgs: string[], options?: ProcessFileOptions): Promise<ProcessedFiles> {
|
||||
const autoResizeImages = options?.autoResizeImages ?? true;
|
||||
let text = "";
|
||||
const images: ImageContent[] = [];
|
||||
|
||||
for (const fileArg of fileArgs) {
|
||||
// Expand and resolve path (handles ~ expansion and macOS screenshot Unicode spaces)
|
||||
const absolutePath = resolve(resolveReadPath(fileArg, process.cwd()));
|
||||
|
||||
// Check if file exists
|
||||
try {
|
||||
await access(absolutePath);
|
||||
} catch {
|
||||
console.error(chalk.red(`Error: File not found: ${absolutePath}`));
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
// Check if file is empty
|
||||
const stats = await stat(absolutePath);
|
||||
if (stats.size === 0) {
|
||||
// Skip empty files
|
||||
continue;
|
||||
}
|
||||
|
||||
const mimeType = await detectSupportedImageMimeTypeFromFile(absolutePath);
|
||||
|
||||
if (mimeType) {
|
||||
// Handle image file
|
||||
const content = await readFile(absolutePath);
|
||||
const base64Content = content.toString("base64");
|
||||
|
||||
let attachment: ImageContent;
|
||||
let dimensionNote: string | undefined;
|
||||
|
||||
if (autoResizeImages) {
|
||||
const resized = await resizeImage({ type: "image", data: base64Content, mimeType });
|
||||
dimensionNote = formatDimensionNote(resized);
|
||||
attachment = {
|
||||
type: "image",
|
||||
mimeType: resized.mimeType,
|
||||
data: resized.data,
|
||||
};
|
||||
} else {
|
||||
attachment = {
|
||||
type: "image",
|
||||
mimeType,
|
||||
data: base64Content,
|
||||
};
|
||||
}
|
||||
|
||||
images.push(attachment);
|
||||
|
||||
// Add text reference to image with optional dimension note
|
||||
if (dimensionNote) {
|
||||
text += `<file name="${absolutePath}">${dimensionNote}</file>\n`;
|
||||
} else {
|
||||
text += `<file name="${absolutePath}"></file>\n`;
|
||||
}
|
||||
} else {
|
||||
// Handle text file
|
||||
try {
|
||||
const content = await readFile(absolutePath, "utf-8");
|
||||
text += `<file name="${absolutePath}">\n${content}\n</file>\n`;
|
||||
} catch (error: unknown) {
|
||||
const message = error instanceof Error ? error.message : String(error);
|
||||
console.error(chalk.red(`Error: Could not read file ${absolutePath}: ${message}`));
|
||||
process.exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { text, images };
|
||||
}
|
||||
104
packages/pi-coding-agent/src/cli/list-models.ts
Normal file
104
packages/pi-coding-agent/src/cli/list-models.ts
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
/**
|
||||
* List available models with optional fuzzy search
|
||||
*/
|
||||
|
||||
import type { Api, Model } from "@gsd/pi-ai";
|
||||
import { fuzzyFilter } from "@gsd/pi-tui";
|
||||
import type { ModelRegistry } from "../core/model-registry.js";
|
||||
|
||||
/**
|
||||
* Format a number as human-readable (e.g., 200000 -> "200K", 1000000 -> "1M")
|
||||
*/
|
||||
function formatTokenCount(count: number): string {
|
||||
if (count >= 1_000_000) {
|
||||
const millions = count / 1_000_000;
|
||||
return millions % 1 === 0 ? `${millions}M` : `${millions.toFixed(1)}M`;
|
||||
}
|
||||
if (count >= 1_000) {
|
||||
const thousands = count / 1_000;
|
||||
return thousands % 1 === 0 ? `${thousands}K` : `${thousands.toFixed(1)}K`;
|
||||
}
|
||||
return count.toString();
|
||||
}
|
||||
|
||||
/**
|
||||
* List available models, optionally filtered by search pattern
|
||||
*/
|
||||
export async function listModels(modelRegistry: ModelRegistry, searchPattern?: string): Promise<void> {
|
||||
const models = modelRegistry.getAvailable();
|
||||
|
||||
if (models.length === 0) {
|
||||
console.log("No models available. Set API keys in environment variables.");
|
||||
return;
|
||||
}
|
||||
|
||||
// Apply fuzzy filter if search pattern provided
|
||||
let filteredModels: Model<Api>[] = models;
|
||||
if (searchPattern) {
|
||||
filteredModels = fuzzyFilter(models, searchPattern, (m) => `${m.provider} ${m.id}`);
|
||||
}
|
||||
|
||||
if (filteredModels.length === 0) {
|
||||
console.log(`No models matching "${searchPattern}"`);
|
||||
return;
|
||||
}
|
||||
|
||||
// Sort by provider, then by model id
|
||||
filteredModels.sort((a, b) => {
|
||||
const providerCmp = a.provider.localeCompare(b.provider);
|
||||
if (providerCmp !== 0) return providerCmp;
|
||||
return a.id.localeCompare(b.id);
|
||||
});
|
||||
|
||||
// Calculate column widths
|
||||
const rows = filteredModels.map((m) => ({
|
||||
provider: m.provider,
|
||||
model: m.id,
|
||||
context: formatTokenCount(m.contextWindow),
|
||||
maxOut: formatTokenCount(m.maxTokens),
|
||||
thinking: m.reasoning ? "yes" : "no",
|
||||
images: m.input.includes("image") ? "yes" : "no",
|
||||
}));
|
||||
|
||||
const headers = {
|
||||
provider: "provider",
|
||||
model: "model",
|
||||
context: "context",
|
||||
maxOut: "max-out",
|
||||
thinking: "thinking",
|
||||
images: "images",
|
||||
};
|
||||
|
||||
const widths = {
|
||||
provider: Math.max(headers.provider.length, ...rows.map((r) => r.provider.length)),
|
||||
model: Math.max(headers.model.length, ...rows.map((r) => r.model.length)),
|
||||
context: Math.max(headers.context.length, ...rows.map((r) => r.context.length)),
|
||||
maxOut: Math.max(headers.maxOut.length, ...rows.map((r) => r.maxOut.length)),
|
||||
thinking: Math.max(headers.thinking.length, ...rows.map((r) => r.thinking.length)),
|
||||
images: Math.max(headers.images.length, ...rows.map((r) => r.images.length)),
|
||||
};
|
||||
|
||||
// Print header
|
||||
const headerLine = [
|
||||
headers.provider.padEnd(widths.provider),
|
||||
headers.model.padEnd(widths.model),
|
||||
headers.context.padEnd(widths.context),
|
||||
headers.maxOut.padEnd(widths.maxOut),
|
||||
headers.thinking.padEnd(widths.thinking),
|
||||
headers.images.padEnd(widths.images),
|
||||
].join(" ");
|
||||
console.log(headerLine);
|
||||
|
||||
// Print rows
|
||||
for (const row of rows) {
|
||||
const line = [
|
||||
row.provider.padEnd(widths.provider),
|
||||
row.model.padEnd(widths.model),
|
||||
row.context.padEnd(widths.context),
|
||||
row.maxOut.padEnd(widths.maxOut),
|
||||
row.thinking.padEnd(widths.thinking),
|
||||
row.images.padEnd(widths.images),
|
||||
].join(" ");
|
||||
console.log(line);
|
||||
}
|
||||
}
|
||||
51
packages/pi-coding-agent/src/cli/session-picker.ts
Normal file
51
packages/pi-coding-agent/src/cli/session-picker.ts
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
/**
|
||||
* TUI session selector for --resume flag
|
||||
*/
|
||||
|
||||
import { ProcessTerminal, TUI } from "@gsd/pi-tui";
|
||||
import { KeybindingsManager } from "../core/keybindings.js";
|
||||
import type { SessionInfo, SessionListProgress } from "../core/session-manager.js";
|
||||
import { SessionSelectorComponent } from "../modes/interactive/components/session-selector.js";
|
||||
|
||||
type SessionsLoader = (onProgress?: SessionListProgress) => Promise<SessionInfo[]>;
|
||||
|
||||
/** Show TUI session selector and return selected session path or null if cancelled */
|
||||
export async function selectSession(
|
||||
currentSessionsLoader: SessionsLoader,
|
||||
allSessionsLoader: SessionsLoader,
|
||||
): Promise<string | null> {
|
||||
return new Promise((resolve) => {
|
||||
const ui = new TUI(new ProcessTerminal());
|
||||
const keybindings = KeybindingsManager.create();
|
||||
let resolved = false;
|
||||
|
||||
const selector = new SessionSelectorComponent(
|
||||
currentSessionsLoader,
|
||||
allSessionsLoader,
|
||||
(path: string) => {
|
||||
if (!resolved) {
|
||||
resolved = true;
|
||||
ui.stop();
|
||||
resolve(path);
|
||||
}
|
||||
},
|
||||
() => {
|
||||
if (!resolved) {
|
||||
resolved = true;
|
||||
ui.stop();
|
||||
resolve(null);
|
||||
}
|
||||
},
|
||||
() => {
|
||||
ui.stop();
|
||||
process.exit(0);
|
||||
},
|
||||
() => ui.requestRender(),
|
||||
{ showRenameHint: false, keybindings },
|
||||
);
|
||||
|
||||
ui.addChild(selector);
|
||||
ui.setFocus(selector.getSessionList());
|
||||
ui.start();
|
||||
});
|
||||
}
|
||||
241
packages/pi-coding-agent/src/config.ts
Normal file
241
packages/pi-coding-agent/src/config.ts
Normal file
|
|
@ -0,0 +1,241 @@
|
|||
import { existsSync, readFileSync } from "fs";
|
||||
import { homedir } from "os";
|
||||
import { dirname, join, resolve } from "path";
|
||||
import { fileURLToPath } from "url";
|
||||
|
||||
// =============================================================================
|
||||
// Package Detection
|
||||
// =============================================================================
|
||||
|
||||
const __filename = fileURLToPath(import.meta.url);
|
||||
const __dirname = dirname(__filename);
|
||||
|
||||
/**
|
||||
* Detect if we're running as a Bun compiled binary.
|
||||
* Bun binaries have import.meta.url containing "$bunfs", "~BUN", or "%7EBUN" (Bun's virtual filesystem path)
|
||||
*/
|
||||
export const isBunBinary =
|
||||
import.meta.url.includes("$bunfs") || import.meta.url.includes("~BUN") || import.meta.url.includes("%7EBUN");
|
||||
|
||||
/** Detect if Bun is the runtime (compiled binary or bun run) */
|
||||
export const isBunRuntime = !!process.versions.bun;
|
||||
|
||||
// =============================================================================
|
||||
// Install Method Detection
|
||||
// =============================================================================
|
||||
|
||||
export type InstallMethod = "bun-binary" | "npm" | "pnpm" | "yarn" | "bun" | "unknown";
|
||||
|
||||
export function detectInstallMethod(): InstallMethod {
|
||||
if (isBunBinary) {
|
||||
return "bun-binary";
|
||||
}
|
||||
|
||||
const resolvedPath = `${__dirname}\0${process.execPath || ""}`.toLowerCase();
|
||||
|
||||
if (resolvedPath.includes("/pnpm/") || resolvedPath.includes("/.pnpm/") || resolvedPath.includes("\\pnpm\\")) {
|
||||
return "pnpm";
|
||||
}
|
||||
if (resolvedPath.includes("/yarn/") || resolvedPath.includes("/.yarn/") || resolvedPath.includes("\\yarn\\")) {
|
||||
return "yarn";
|
||||
}
|
||||
if (isBunRuntime) {
|
||||
return "bun";
|
||||
}
|
||||
if (resolvedPath.includes("/npm/") || resolvedPath.includes("/node_modules/") || resolvedPath.includes("\\npm\\")) {
|
||||
return "npm";
|
||||
}
|
||||
|
||||
return "unknown";
|
||||
}
|
||||
|
||||
export function getUpdateInstruction(packageName: string): string {
|
||||
const method = detectInstallMethod();
|
||||
switch (method) {
|
||||
case "bun-binary":
|
||||
return `Download from: https://github.com/badlogic/pi-mono/releases/latest`;
|
||||
case "pnpm":
|
||||
return `Run: pnpm install -g ${packageName}`;
|
||||
case "yarn":
|
||||
return `Run: yarn global add ${packageName}`;
|
||||
case "bun":
|
||||
return `Run: bun install -g ${packageName}`;
|
||||
case "npm":
|
||||
return `Run: npm install -g ${packageName}`;
|
||||
default:
|
||||
return `Run: npm install -g ${packageName}`;
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Package Asset Paths (shipped with executable)
|
||||
// =============================================================================
|
||||
|
||||
/**
|
||||
* Get the base directory for resolving package assets (themes, package.json, README.md, CHANGELOG.md).
|
||||
* - For Bun binary: returns the directory containing the executable
|
||||
* - For Node.js (dist/): returns __dirname (the dist/ directory)
|
||||
* - For tsx (src/): returns parent directory (the package root)
|
||||
*/
|
||||
export function getPackageDir(): string {
|
||||
// Allow override via environment variable (useful for Nix/Guix where store paths tokenize poorly)
|
||||
const envDir = process.env.PI_PACKAGE_DIR;
|
||||
if (envDir) {
|
||||
if (envDir === "~") return homedir();
|
||||
if (envDir.startsWith("~/")) return homedir() + envDir.slice(1);
|
||||
return envDir;
|
||||
}
|
||||
|
||||
if (isBunBinary) {
|
||||
// Bun binary: process.execPath points to the compiled executable
|
||||
return dirname(process.execPath);
|
||||
}
|
||||
// Node.js: walk up from __dirname until we find package.json
|
||||
let dir = __dirname;
|
||||
while (dir !== dirname(dir)) {
|
||||
if (existsSync(join(dir, "package.json"))) {
|
||||
return dir;
|
||||
}
|
||||
dir = dirname(dir);
|
||||
}
|
||||
// Fallback (shouldn't happen)
|
||||
return __dirname;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get path to built-in themes directory (shipped with package)
|
||||
* - For Bun binary: theme/ next to executable
|
||||
* - For Node.js (dist/): dist/modes/interactive/theme/
|
||||
* - For tsx (src/): src/modes/interactive/theme/
|
||||
*/
|
||||
export function getThemesDir(): string {
|
||||
if (isBunBinary) {
|
||||
return join(dirname(process.execPath), "theme");
|
||||
}
|
||||
// Theme is in modes/interactive/theme/ relative to src/ or dist/
|
||||
const packageDir = getPackageDir();
|
||||
const srcOrDist = existsSync(join(packageDir, "src")) ? "src" : "dist";
|
||||
return join(packageDir, srcOrDist, "modes", "interactive", "theme");
|
||||
}
|
||||
|
||||
/**
|
||||
* Get path to HTML export template directory (shipped with package)
|
||||
* - For Bun binary: export-html/ next to executable
|
||||
* - For Node.js (dist/): dist/core/export-html/
|
||||
* - For tsx (src/): src/core/export-html/
|
||||
*/
|
||||
export function getExportTemplateDir(): string {
|
||||
if (isBunBinary) {
|
||||
return join(dirname(process.execPath), "export-html");
|
||||
}
|
||||
const packageDir = getPackageDir();
|
||||
const srcOrDist = existsSync(join(packageDir, "src")) ? "src" : "dist";
|
||||
return join(packageDir, srcOrDist, "core", "export-html");
|
||||
}
|
||||
|
||||
/** Get path to package.json */
|
||||
export function getPackageJsonPath(): string {
|
||||
return join(getPackageDir(), "package.json");
|
||||
}
|
||||
|
||||
/** Get path to README.md */
|
||||
export function getReadmePath(): string {
|
||||
return resolve(join(getPackageDir(), "README.md"));
|
||||
}
|
||||
|
||||
/** Get path to docs directory */
|
||||
export function getDocsPath(): string {
|
||||
return resolve(join(getPackageDir(), "docs"));
|
||||
}
|
||||
|
||||
/** Get path to examples directory */
|
||||
export function getExamplesPath(): string {
|
||||
return resolve(join(getPackageDir(), "examples"));
|
||||
}
|
||||
|
||||
/** Get path to CHANGELOG.md */
|
||||
export function getChangelogPath(): string {
|
||||
return resolve(join(getPackageDir(), "CHANGELOG.md"));
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// App Config (from package.json piConfig)
|
||||
// =============================================================================
|
||||
|
||||
const pkg = JSON.parse(readFileSync(getPackageJsonPath(), "utf-8"));
|
||||
|
||||
export const APP_NAME: string = pkg.piConfig?.name || "pi";
|
||||
export const CONFIG_DIR_NAME: string = pkg.piConfig?.configDir || ".pi";
|
||||
export const VERSION: string = pkg.version;
|
||||
|
||||
// e.g., PI_CODING_AGENT_DIR or TAU_CODING_AGENT_DIR
|
||||
export const ENV_AGENT_DIR = `${APP_NAME.toUpperCase()}_CODING_AGENT_DIR`;
|
||||
|
||||
const DEFAULT_SHARE_VIEWER_URL = "https://pi.dev/session/";
|
||||
|
||||
/** Get the share viewer URL for a gist ID */
|
||||
export function getShareViewerUrl(gistId: string): string {
|
||||
const baseUrl = process.env.PI_SHARE_VIEWER_URL || DEFAULT_SHARE_VIEWER_URL;
|
||||
return `${baseUrl}#${gistId}`;
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// User Config Paths (~/.pi/agent/*)
|
||||
// =============================================================================
|
||||
|
||||
/** Get the agent config directory (e.g., ~/.pi/agent/) */
|
||||
export function getAgentDir(): string {
|
||||
const envDir = process.env[ENV_AGENT_DIR];
|
||||
if (envDir) {
|
||||
// Expand tilde to home directory
|
||||
if (envDir === "~") return homedir();
|
||||
if (envDir.startsWith("~/")) return homedir() + envDir.slice(1);
|
||||
return envDir;
|
||||
}
|
||||
return join(homedir(), CONFIG_DIR_NAME, "agent");
|
||||
}
|
||||
|
||||
/** Get path to user's custom themes directory */
|
||||
export function getCustomThemesDir(): string {
|
||||
return join(getAgentDir(), "themes");
|
||||
}
|
||||
|
||||
/** Get path to models.json */
|
||||
export function getModelsPath(): string {
|
||||
return join(getAgentDir(), "models.json");
|
||||
}
|
||||
|
||||
/** Get path to auth.json */
|
||||
export function getAuthPath(): string {
|
||||
return join(getAgentDir(), "auth.json");
|
||||
}
|
||||
|
||||
/** Get path to settings.json */
|
||||
export function getSettingsPath(): string {
|
||||
return join(getAgentDir(), "settings.json");
|
||||
}
|
||||
|
||||
/** Get path to tools directory */
|
||||
export function getToolsDir(): string {
|
||||
return join(getAgentDir(), "tools");
|
||||
}
|
||||
|
||||
/** Get path to managed binaries directory (fd, rg) */
|
||||
export function getBinDir(): string {
|
||||
return join(getAgentDir(), "bin");
|
||||
}
|
||||
|
||||
/** Get path to prompt templates directory */
|
||||
export function getPromptsDir(): string {
|
||||
return join(getAgentDir(), "prompts");
|
||||
}
|
||||
|
||||
/** Get path to sessions directory */
|
||||
export function getSessionsDir(): string {
|
||||
return join(getAgentDir(), "sessions");
|
||||
}
|
||||
|
||||
/** Get path to debug log file */
|
||||
export function getDebugLogPath(): string {
|
||||
return join(getAgentDir(), `${APP_NAME}-debug.log`);
|
||||
}
|
||||
3050
packages/pi-coding-agent/src/core/agent-session.ts
Normal file
3050
packages/pi-coding-agent/src/core/agent-session.ts
Normal file
File diff suppressed because it is too large
Load diff
489
packages/pi-coding-agent/src/core/auth-storage.ts
Normal file
489
packages/pi-coding-agent/src/core/auth-storage.ts
Normal file
|
|
@ -0,0 +1,489 @@
|
|||
/**
|
||||
* Credential storage for API keys and OAuth tokens.
|
||||
* Handles loading, saving, and refreshing credentials from auth.json.
|
||||
*
|
||||
* Uses file locking to prevent race conditions when multiple pi instances
|
||||
* try to refresh tokens simultaneously.
|
||||
*/
|
||||
|
||||
import {
|
||||
getEnvApiKey,
|
||||
type OAuthCredentials,
|
||||
type OAuthLoginCallbacks,
|
||||
type OAuthProviderId,
|
||||
} from "@gsd/pi-ai";
|
||||
import { getOAuthApiKey, getOAuthProvider, getOAuthProviders } from "@gsd/pi-ai/oauth";
|
||||
import { chmodSync, existsSync, mkdirSync, readFileSync, writeFileSync } from "fs";
|
||||
import { dirname, join } from "path";
|
||||
import lockfile from "proper-lockfile";
|
||||
import { getAgentDir } from "../config.js";
|
||||
import { resolveConfigValue } from "./resolve-config-value.js";
|
||||
|
||||
export type ApiKeyCredential = {
|
||||
type: "api_key";
|
||||
key: string;
|
||||
};
|
||||
|
||||
export type OAuthCredential = {
|
||||
type: "oauth";
|
||||
} & OAuthCredentials;
|
||||
|
||||
export type AuthCredential = ApiKeyCredential | OAuthCredential;
|
||||
|
||||
export type AuthStorageData = Record<string, AuthCredential>;
|
||||
|
||||
type LockResult<T> = {
|
||||
result: T;
|
||||
next?: string;
|
||||
};
|
||||
|
||||
export interface AuthStorageBackend {
|
||||
withLock<T>(fn: (current: string | undefined) => LockResult<T>): T;
|
||||
withLockAsync<T>(fn: (current: string | undefined) => Promise<LockResult<T>>): Promise<T>;
|
||||
}
|
||||
|
||||
export class FileAuthStorageBackend implements AuthStorageBackend {
|
||||
constructor(private authPath: string = join(getAgentDir(), "auth.json")) {}
|
||||
|
||||
private ensureParentDir(): void {
|
||||
const dir = dirname(this.authPath);
|
||||
if (!existsSync(dir)) {
|
||||
mkdirSync(dir, { recursive: true, mode: 0o700 });
|
||||
}
|
||||
}
|
||||
|
||||
private ensureFileExists(): void {
|
||||
if (!existsSync(this.authPath)) {
|
||||
writeFileSync(this.authPath, "{}", "utf-8");
|
||||
chmodSync(this.authPath, 0o600);
|
||||
}
|
||||
}
|
||||
|
||||
private acquireLockSyncWithRetry(path: string): () => void {
|
||||
const maxAttempts = 10;
|
||||
const delayMs = 20;
|
||||
let lastError: unknown;
|
||||
|
||||
for (let attempt = 1; attempt <= maxAttempts; attempt++) {
|
||||
try {
|
||||
return lockfile.lockSync(path, { realpath: false });
|
||||
} catch (error) {
|
||||
const code =
|
||||
typeof error === "object" && error !== null && "code" in error
|
||||
? String((error as { code?: unknown }).code)
|
||||
: undefined;
|
||||
if (code !== "ELOCKED" || attempt === maxAttempts) {
|
||||
throw error;
|
||||
}
|
||||
lastError = error;
|
||||
const start = Date.now();
|
||||
while (Date.now() - start < delayMs) {
|
||||
// Sleep synchronously to avoid changing callers to async.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
throw (lastError as Error) ?? new Error("Failed to acquire auth storage lock");
|
||||
}
|
||||
|
||||
withLock<T>(fn: (current: string | undefined) => LockResult<T>): T {
|
||||
this.ensureParentDir();
|
||||
this.ensureFileExists();
|
||||
|
||||
let release: (() => void) | undefined;
|
||||
try {
|
||||
release = this.acquireLockSyncWithRetry(this.authPath);
|
||||
const current = existsSync(this.authPath) ? readFileSync(this.authPath, "utf-8") : undefined;
|
||||
const { result, next } = fn(current);
|
||||
if (next !== undefined) {
|
||||
writeFileSync(this.authPath, next, "utf-8");
|
||||
chmodSync(this.authPath, 0o600);
|
||||
}
|
||||
return result;
|
||||
} finally {
|
||||
if (release) {
|
||||
release();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async withLockAsync<T>(fn: (current: string | undefined) => Promise<LockResult<T>>): Promise<T> {
|
||||
this.ensureParentDir();
|
||||
this.ensureFileExists();
|
||||
|
||||
let release: (() => Promise<void>) | undefined;
|
||||
let lockCompromised = false;
|
||||
let lockCompromisedError: Error | undefined;
|
||||
const throwIfCompromised = () => {
|
||||
if (lockCompromised) {
|
||||
throw lockCompromisedError ?? new Error("Auth storage lock was compromised");
|
||||
}
|
||||
};
|
||||
|
||||
try {
|
||||
release = await lockfile.lock(this.authPath, {
|
||||
retries: {
|
||||
retries: 10,
|
||||
factor: 2,
|
||||
minTimeout: 100,
|
||||
maxTimeout: 10000,
|
||||
randomize: true,
|
||||
},
|
||||
stale: 30000,
|
||||
onCompromised: (err) => {
|
||||
lockCompromised = true;
|
||||
lockCompromisedError = err;
|
||||
},
|
||||
});
|
||||
|
||||
throwIfCompromised();
|
||||
const current = existsSync(this.authPath) ? readFileSync(this.authPath, "utf-8") : undefined;
|
||||
const { result, next } = await fn(current);
|
||||
throwIfCompromised();
|
||||
if (next !== undefined) {
|
||||
writeFileSync(this.authPath, next, "utf-8");
|
||||
chmodSync(this.authPath, 0o600);
|
||||
}
|
||||
throwIfCompromised();
|
||||
return result;
|
||||
} finally {
|
||||
if (release) {
|
||||
try {
|
||||
await release();
|
||||
} catch {
|
||||
// Ignore unlock errors when lock is compromised.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export class InMemoryAuthStorageBackend implements AuthStorageBackend {
|
||||
private value: string | undefined;
|
||||
|
||||
withLock<T>(fn: (current: string | undefined) => LockResult<T>): T {
|
||||
const { result, next } = fn(this.value);
|
||||
if (next !== undefined) {
|
||||
this.value = next;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
async withLockAsync<T>(fn: (current: string | undefined) => Promise<LockResult<T>>): Promise<T> {
|
||||
const { result, next } = await fn(this.value);
|
||||
if (next !== undefined) {
|
||||
this.value = next;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Credential storage backed by a JSON file.
|
||||
*/
|
||||
export class AuthStorage {
|
||||
private data: AuthStorageData = {};
|
||||
private runtimeOverrides: Map<string, string> = new Map();
|
||||
private fallbackResolver?: (provider: string) => string | undefined;
|
||||
private loadError: Error | null = null;
|
||||
private errors: Error[] = [];
|
||||
|
||||
private constructor(private storage: AuthStorageBackend) {
|
||||
this.reload();
|
||||
}
|
||||
|
||||
static create(authPath?: string): AuthStorage {
|
||||
return new AuthStorage(new FileAuthStorageBackend(authPath ?? join(getAgentDir(), "auth.json")));
|
||||
}
|
||||
|
||||
static fromStorage(storage: AuthStorageBackend): AuthStorage {
|
||||
return new AuthStorage(storage);
|
||||
}
|
||||
|
||||
static inMemory(data: AuthStorageData = {}): AuthStorage {
|
||||
const storage = new InMemoryAuthStorageBackend();
|
||||
storage.withLock(() => ({ result: undefined, next: JSON.stringify(data, null, 2) }));
|
||||
return AuthStorage.fromStorage(storage);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set a runtime API key override (not persisted to disk).
|
||||
* Used for CLI --api-key flag.
|
||||
*/
|
||||
setRuntimeApiKey(provider: string, apiKey: string): void {
|
||||
this.runtimeOverrides.set(provider, apiKey);
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a runtime API key override.
|
||||
*/
|
||||
removeRuntimeApiKey(provider: string): void {
|
||||
this.runtimeOverrides.delete(provider);
|
||||
}
|
||||
|
||||
/**
|
||||
* Set a fallback resolver for API keys not found in auth.json or env vars.
|
||||
* Used for custom provider keys from models.json.
|
||||
*/
|
||||
setFallbackResolver(resolver: (provider: string) => string | undefined): void {
|
||||
this.fallbackResolver = resolver;
|
||||
}
|
||||
|
||||
private recordError(error: unknown): void {
|
||||
const normalizedError = error instanceof Error ? error : new Error(String(error));
|
||||
this.errors.push(normalizedError);
|
||||
}
|
||||
|
||||
private parseStorageData(content: string | undefined): AuthStorageData {
|
||||
if (!content) {
|
||||
return {};
|
||||
}
|
||||
return JSON.parse(content) as AuthStorageData;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reload credentials from storage.
|
||||
*/
|
||||
reload(): void {
|
||||
let content: string | undefined;
|
||||
try {
|
||||
this.storage.withLock((current) => {
|
||||
content = current;
|
||||
return { result: undefined };
|
||||
});
|
||||
this.data = this.parseStorageData(content);
|
||||
this.loadError = null;
|
||||
} catch (error) {
|
||||
this.loadError = error as Error;
|
||||
this.recordError(error);
|
||||
}
|
||||
}
|
||||
|
||||
private persistProviderChange(provider: string, credential: AuthCredential | undefined): void {
|
||||
if (this.loadError) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
this.storage.withLock((current) => {
|
||||
const currentData = this.parseStorageData(current);
|
||||
const merged: AuthStorageData = { ...currentData };
|
||||
if (credential) {
|
||||
merged[provider] = credential;
|
||||
} else {
|
||||
delete merged[provider];
|
||||
}
|
||||
return { result: undefined, next: JSON.stringify(merged, null, 2) };
|
||||
});
|
||||
} catch (error) {
|
||||
this.recordError(error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get credential for a provider.
|
||||
*/
|
||||
get(provider: string): AuthCredential | undefined {
|
||||
return this.data[provider] ?? undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set credential for a provider.
|
||||
*/
|
||||
set(provider: string, credential: AuthCredential): void {
|
||||
this.data[provider] = credential;
|
||||
this.persistProviderChange(provider, credential);
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove credential for a provider.
|
||||
*/
|
||||
remove(provider: string): void {
|
||||
delete this.data[provider];
|
||||
this.persistProviderChange(provider, undefined);
|
||||
}
|
||||
|
||||
/**
|
||||
* List all providers with credentials.
|
||||
*/
|
||||
list(): string[] {
|
||||
return Object.keys(this.data);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if credentials exist for a provider in auth.json.
|
||||
*/
|
||||
has(provider: string): boolean {
|
||||
return provider in this.data;
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if any form of auth is configured for a provider.
|
||||
* Unlike getApiKey(), this doesn't refresh OAuth tokens.
|
||||
*/
|
||||
hasAuth(provider: string): boolean {
|
||||
if (this.runtimeOverrides.has(provider)) return true;
|
||||
if (this.data[provider]) return true;
|
||||
if (getEnvApiKey(provider)) return true;
|
||||
if (this.fallbackResolver?.(provider)) return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all credentials (for passing to getOAuthApiKey).
|
||||
*/
|
||||
getAll(): AuthStorageData {
|
||||
return { ...this.data };
|
||||
}
|
||||
|
||||
drainErrors(): Error[] {
|
||||
const drained = [...this.errors];
|
||||
this.errors = [];
|
||||
return drained;
|
||||
}
|
||||
|
||||
/**
|
||||
* Login to an OAuth provider.
|
||||
*/
|
||||
async login(providerId: OAuthProviderId, callbacks: OAuthLoginCallbacks): Promise<void> {
|
||||
const provider = getOAuthProvider(providerId);
|
||||
if (!provider) {
|
||||
throw new Error(`Unknown OAuth provider: ${providerId}`);
|
||||
}
|
||||
|
||||
const credentials = await provider.login(callbacks);
|
||||
this.set(providerId, { type: "oauth", ...credentials });
|
||||
}
|
||||
|
||||
/**
|
||||
* Logout from a provider.
|
||||
*/
|
||||
logout(provider: string): void {
|
||||
this.remove(provider);
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh OAuth token with backend locking to prevent race conditions.
|
||||
* Multiple pi instances may try to refresh simultaneously when tokens expire.
|
||||
*/
|
||||
private async refreshOAuthTokenWithLock(
|
||||
providerId: OAuthProviderId,
|
||||
): Promise<{ apiKey: string; newCredentials: OAuthCredentials } | null> {
|
||||
const provider = getOAuthProvider(providerId);
|
||||
if (!provider) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const result = await this.storage.withLockAsync(async (current) => {
|
||||
const currentData = this.parseStorageData(current);
|
||||
this.data = currentData;
|
||||
this.loadError = null;
|
||||
|
||||
const cred = currentData[providerId];
|
||||
if (cred?.type !== "oauth") {
|
||||
return { result: null };
|
||||
}
|
||||
|
||||
if (Date.now() < cred.expires) {
|
||||
return { result: { apiKey: provider.getApiKey(cred), newCredentials: cred } };
|
||||
}
|
||||
|
||||
const oauthCreds: Record<string, OAuthCredentials> = {};
|
||||
for (const [key, value] of Object.entries(currentData)) {
|
||||
if (value.type === "oauth") {
|
||||
oauthCreds[key] = value;
|
||||
}
|
||||
}
|
||||
|
||||
const refreshed = await getOAuthApiKey(providerId, oauthCreds);
|
||||
if (!refreshed) {
|
||||
return { result: null };
|
||||
}
|
||||
|
||||
const merged: AuthStorageData = {
|
||||
...currentData,
|
||||
[providerId]: { type: "oauth", ...refreshed.newCredentials },
|
||||
};
|
||||
this.data = merged;
|
||||
this.loadError = null;
|
||||
return { result: refreshed, next: JSON.stringify(merged, null, 2) };
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for a provider.
|
||||
* Priority:
|
||||
* 1. Runtime override (CLI --api-key)
|
||||
* 2. API key from auth.json
|
||||
* 3. OAuth token from auth.json (auto-refreshed with locking)
|
||||
* 4. Environment variable
|
||||
* 5. Fallback resolver (models.json custom providers)
|
||||
*/
|
||||
async getApiKey(providerId: string): Promise<string | undefined> {
|
||||
// Runtime override takes highest priority
|
||||
const runtimeKey = this.runtimeOverrides.get(providerId);
|
||||
if (runtimeKey) {
|
||||
return runtimeKey;
|
||||
}
|
||||
|
||||
const cred = this.data[providerId];
|
||||
|
||||
if (cred?.type === "api_key") {
|
||||
return resolveConfigValue(cred.key);
|
||||
}
|
||||
|
||||
if (cred?.type === "oauth") {
|
||||
const provider = getOAuthProvider(providerId);
|
||||
if (!provider) {
|
||||
// Unknown OAuth provider, can't get API key
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Check if token needs refresh
|
||||
const needsRefresh = Date.now() >= cred.expires;
|
||||
|
||||
if (needsRefresh) {
|
||||
// Use locked refresh to prevent race conditions
|
||||
try {
|
||||
const result = await this.refreshOAuthTokenWithLock(providerId);
|
||||
if (result) {
|
||||
return result.apiKey;
|
||||
}
|
||||
} catch (error) {
|
||||
this.recordError(error);
|
||||
// Refresh failed - re-read file to check if another instance succeeded
|
||||
this.reload();
|
||||
const updatedCred = this.data[providerId];
|
||||
|
||||
if (updatedCred?.type === "oauth" && Date.now() < updatedCred.expires) {
|
||||
// Another instance refreshed successfully, use those credentials
|
||||
return provider.getApiKey(updatedCred);
|
||||
}
|
||||
|
||||
// Refresh truly failed - return undefined so model discovery skips this provider
|
||||
// User can /login to re-authenticate (credentials preserved for retry)
|
||||
return undefined;
|
||||
}
|
||||
} else {
|
||||
// Token not expired, use current access token
|
||||
return provider.getApiKey(cred);
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to environment variable
|
||||
const envKey = getEnvApiKey(providerId);
|
||||
if (envKey) return envKey;
|
||||
|
||||
// Fall back to custom resolver (e.g., models.json custom providers)
|
||||
return this.fallbackResolver?.(providerId) ?? undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all registered OAuth providers
|
||||
*/
|
||||
getOAuthProviders() {
|
||||
return getOAuthProviders();
|
||||
}
|
||||
}
|
||||
278
packages/pi-coding-agent/src/core/bash-executor.ts
Normal file
278
packages/pi-coding-agent/src/core/bash-executor.ts
Normal file
|
|
@ -0,0 +1,278 @@
|
|||
/**
|
||||
* Bash command execution with streaming support and cancellation.
|
||||
*
|
||||
* This module provides a unified bash execution implementation used by:
|
||||
* - AgentSession.executeBash() for interactive and RPC modes
|
||||
* - Direct calls from modes that need bash execution
|
||||
*/
|
||||
|
||||
import { randomBytes } from "node:crypto";
|
||||
import { createWriteStream, type WriteStream } from "node:fs";
|
||||
import { tmpdir } from "node:os";
|
||||
import { join } from "node:path";
|
||||
import { type ChildProcess, spawn } from "child_process";
|
||||
import stripAnsi from "strip-ansi";
|
||||
import { getShellConfig, getShellEnv, killProcessTree, sanitizeBinaryOutput } from "../utils/shell.js";
|
||||
import type { BashOperations } from "./tools/bash.js";
|
||||
import { DEFAULT_MAX_BYTES, truncateTail } from "./tools/truncate.js";
|
||||
|
||||
// ============================================================================
|
||||
// Types
|
||||
// ============================================================================
|
||||
|
||||
export interface BashExecutorOptions {
|
||||
/** Callback for streaming output chunks (already sanitized) */
|
||||
onChunk?: (chunk: string) => void;
|
||||
/** AbortSignal for cancellation */
|
||||
signal?: AbortSignal;
|
||||
}
|
||||
|
||||
export interface BashResult {
|
||||
/** Combined stdout + stderr output (sanitized, possibly truncated) */
|
||||
output: string;
|
||||
/** Process exit code (undefined if killed/cancelled) */
|
||||
exitCode: number | undefined;
|
||||
/** Whether the command was cancelled via signal */
|
||||
cancelled: boolean;
|
||||
/** Whether the output was truncated */
|
||||
truncated: boolean;
|
||||
/** Path to temp file containing full output (if output exceeded truncation threshold) */
|
||||
fullOutputPath?: string;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Implementation
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Execute a bash command with optional streaming and cancellation support.
|
||||
*
|
||||
* Features:
|
||||
* - Streams sanitized output via onChunk callback
|
||||
* - Writes large output to temp file for later retrieval
|
||||
* - Supports cancellation via AbortSignal
|
||||
* - Sanitizes output (strips ANSI, removes binary garbage, normalizes newlines)
|
||||
* - Truncates output if it exceeds the default max bytes
|
||||
*
|
||||
* @param command - The bash command to execute
|
||||
* @param options - Optional streaming callback and abort signal
|
||||
* @returns Promise resolving to execution result
|
||||
*/
|
||||
export function executeBash(command: string, options?: BashExecutorOptions): Promise<BashResult> {
|
||||
return new Promise((resolve, reject) => {
|
||||
const { shell, args } = getShellConfig();
|
||||
const child: ChildProcess = spawn(shell, [...args, command], {
|
||||
detached: true,
|
||||
env: getShellEnv(),
|
||||
stdio: ["ignore", "pipe", "pipe"],
|
||||
});
|
||||
|
||||
// Track sanitized output for truncation
|
||||
const outputChunks: string[] = [];
|
||||
let outputBytes = 0;
|
||||
const maxOutputBytes = DEFAULT_MAX_BYTES * 2;
|
||||
|
||||
// Temp file for large output
|
||||
let tempFilePath: string | undefined;
|
||||
let tempFileStream: WriteStream | undefined;
|
||||
let totalBytes = 0;
|
||||
|
||||
// Handle abort signal
|
||||
const abortHandler = () => {
|
||||
if (child.pid) {
|
||||
killProcessTree(child.pid);
|
||||
}
|
||||
};
|
||||
|
||||
if (options?.signal) {
|
||||
if (options.signal.aborted) {
|
||||
// Already aborted, don't even start
|
||||
child.kill();
|
||||
resolve({
|
||||
output: "",
|
||||
exitCode: undefined,
|
||||
cancelled: true,
|
||||
truncated: false,
|
||||
});
|
||||
return;
|
||||
}
|
||||
options.signal.addEventListener("abort", abortHandler, { once: true });
|
||||
}
|
||||
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
const handleData = (data: Buffer) => {
|
||||
totalBytes += data.length;
|
||||
|
||||
// Sanitize once at the source: strip ANSI, replace binary garbage, normalize newlines
|
||||
const text = sanitizeBinaryOutput(stripAnsi(decoder.decode(data, { stream: true }))).replace(/\r/g, "");
|
||||
|
||||
// Start writing to temp file if exceeds threshold
|
||||
if (totalBytes > DEFAULT_MAX_BYTES && !tempFilePath) {
|
||||
const id = randomBytes(8).toString("hex");
|
||||
tempFilePath = join(tmpdir(), `pi-bash-${id}.log`);
|
||||
tempFileStream = createWriteStream(tempFilePath);
|
||||
// Write already-buffered chunks to temp file
|
||||
for (const chunk of outputChunks) {
|
||||
tempFileStream.write(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
if (tempFileStream) {
|
||||
tempFileStream.write(text);
|
||||
}
|
||||
|
||||
// Keep rolling buffer of sanitized text
|
||||
outputChunks.push(text);
|
||||
outputBytes += text.length;
|
||||
while (outputBytes > maxOutputBytes && outputChunks.length > 1) {
|
||||
const removed = outputChunks.shift()!;
|
||||
outputBytes -= removed.length;
|
||||
}
|
||||
|
||||
// Stream to callback if provided
|
||||
if (options?.onChunk) {
|
||||
options.onChunk(text);
|
||||
}
|
||||
};
|
||||
|
||||
child.stdout?.on("data", handleData);
|
||||
child.stderr?.on("data", handleData);
|
||||
|
||||
child.on("close", (code) => {
|
||||
// Clean up abort listener
|
||||
if (options?.signal) {
|
||||
options.signal.removeEventListener("abort", abortHandler);
|
||||
}
|
||||
|
||||
if (tempFileStream) {
|
||||
tempFileStream.end();
|
||||
}
|
||||
|
||||
// Combine buffered chunks for truncation (already sanitized)
|
||||
const fullOutput = outputChunks.join("");
|
||||
const truncationResult = truncateTail(fullOutput);
|
||||
|
||||
// code === null means killed (cancelled)
|
||||
const cancelled = code === null;
|
||||
|
||||
resolve({
|
||||
output: truncationResult.truncated ? truncationResult.content : fullOutput,
|
||||
exitCode: cancelled ? undefined : code,
|
||||
cancelled,
|
||||
truncated: truncationResult.truncated,
|
||||
fullOutputPath: tempFilePath,
|
||||
});
|
||||
});
|
||||
|
||||
child.on("error", (err) => {
|
||||
// Clean up abort listener
|
||||
if (options?.signal) {
|
||||
options.signal.removeEventListener("abort", abortHandler);
|
||||
}
|
||||
|
||||
if (tempFileStream) {
|
||||
tempFileStream.end();
|
||||
}
|
||||
|
||||
reject(err);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a bash command using custom BashOperations.
|
||||
* Used for remote execution (SSH, containers, etc.).
|
||||
*/
|
||||
export async function executeBashWithOperations(
|
||||
command: string,
|
||||
cwd: string,
|
||||
operations: BashOperations,
|
||||
options?: BashExecutorOptions,
|
||||
): Promise<BashResult> {
|
||||
const outputChunks: string[] = [];
|
||||
let outputBytes = 0;
|
||||
const maxOutputBytes = DEFAULT_MAX_BYTES * 2;
|
||||
|
||||
let tempFilePath: string | undefined;
|
||||
let tempFileStream: WriteStream | undefined;
|
||||
let totalBytes = 0;
|
||||
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
const onData = (data: Buffer) => {
|
||||
totalBytes += data.length;
|
||||
|
||||
// Sanitize: strip ANSI, replace binary garbage, normalize newlines
|
||||
const text = sanitizeBinaryOutput(stripAnsi(decoder.decode(data, { stream: true }))).replace(/\r/g, "");
|
||||
|
||||
// Start writing to temp file if exceeds threshold
|
||||
if (totalBytes > DEFAULT_MAX_BYTES && !tempFilePath) {
|
||||
const id = randomBytes(8).toString("hex");
|
||||
tempFilePath = join(tmpdir(), `pi-bash-${id}.log`);
|
||||
tempFileStream = createWriteStream(tempFilePath);
|
||||
for (const chunk of outputChunks) {
|
||||
tempFileStream.write(chunk);
|
||||
}
|
||||
}
|
||||
|
||||
if (tempFileStream) {
|
||||
tempFileStream.write(text);
|
||||
}
|
||||
|
||||
// Keep rolling buffer
|
||||
outputChunks.push(text);
|
||||
outputBytes += text.length;
|
||||
while (outputBytes > maxOutputBytes && outputChunks.length > 1) {
|
||||
const removed = outputChunks.shift()!;
|
||||
outputBytes -= removed.length;
|
||||
}
|
||||
|
||||
// Stream to callback
|
||||
if (options?.onChunk) {
|
||||
options.onChunk(text);
|
||||
}
|
||||
};
|
||||
|
||||
try {
|
||||
const result = await operations.exec(command, cwd, {
|
||||
onData,
|
||||
signal: options?.signal,
|
||||
});
|
||||
|
||||
if (tempFileStream) {
|
||||
tempFileStream.end();
|
||||
}
|
||||
|
||||
const fullOutput = outputChunks.join("");
|
||||
const truncationResult = truncateTail(fullOutput);
|
||||
const cancelled = options?.signal?.aborted ?? false;
|
||||
|
||||
return {
|
||||
output: truncationResult.truncated ? truncationResult.content : fullOutput,
|
||||
exitCode: cancelled ? undefined : (result.exitCode ?? undefined),
|
||||
cancelled,
|
||||
truncated: truncationResult.truncated,
|
||||
fullOutputPath: tempFilePath,
|
||||
};
|
||||
} catch (err) {
|
||||
if (tempFileStream) {
|
||||
tempFileStream.end();
|
||||
}
|
||||
|
||||
// Check if it was an abort
|
||||
if (options?.signal?.aborted) {
|
||||
const fullOutput = outputChunks.join("");
|
||||
const truncationResult = truncateTail(fullOutput);
|
||||
return {
|
||||
output: truncationResult.truncated ? truncationResult.content : fullOutput,
|
||||
exitCode: undefined,
|
||||
cancelled: true,
|
||||
truncated: truncationResult.truncated,
|
||||
fullOutputPath: tempFilePath,
|
||||
};
|
||||
}
|
||||
|
||||
throw err;
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,352 @@
|
|||
/**
|
||||
* Branch summarization for tree navigation.
|
||||
*
|
||||
* When navigating to a different point in the session tree, this generates
|
||||
* a summary of the branch being left so context isn't lost.
|
||||
*/
|
||||
|
||||
import type { AgentMessage } from "@gsd/pi-agent-core";
|
||||
import type { Model } from "@gsd/pi-ai";
|
||||
import { completeSimple } from "@gsd/pi-ai";
|
||||
import {
|
||||
convertToLlm,
|
||||
createBranchSummaryMessage,
|
||||
createCompactionSummaryMessage,
|
||||
createCustomMessage,
|
||||
} from "../messages.js";
|
||||
import type { ReadonlySessionManager, SessionEntry } from "../session-manager.js";
|
||||
import { estimateTokens } from "./compaction.js";
|
||||
import {
|
||||
computeFileLists,
|
||||
createFileOps,
|
||||
extractFileOpsFromMessage,
|
||||
type FileOperations,
|
||||
formatFileOperations,
|
||||
SUMMARIZATION_SYSTEM_PROMPT,
|
||||
serializeConversation,
|
||||
} from "./utils.js";
|
||||
|
||||
// ============================================================================
|
||||
// Types
|
||||
// ============================================================================
|
||||
|
||||
export interface BranchSummaryResult {
|
||||
summary?: string;
|
||||
readFiles?: string[];
|
||||
modifiedFiles?: string[];
|
||||
aborted?: boolean;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
/** Details stored in BranchSummaryEntry.details for file tracking */
|
||||
export interface BranchSummaryDetails {
|
||||
readFiles: string[];
|
||||
modifiedFiles: string[];
|
||||
}
|
||||
|
||||
export type { FileOperations } from "./utils.js";
|
||||
|
||||
export interface BranchPreparation {
|
||||
/** Messages extracted for summarization, in chronological order */
|
||||
messages: AgentMessage[];
|
||||
/** File operations extracted from tool calls */
|
||||
fileOps: FileOperations;
|
||||
/** Total estimated tokens in messages */
|
||||
totalTokens: number;
|
||||
}
|
||||
|
||||
export interface CollectEntriesResult {
|
||||
/** Entries to summarize, in chronological order */
|
||||
entries: SessionEntry[];
|
||||
/** Common ancestor between old and new position, if any */
|
||||
commonAncestorId: string | null;
|
||||
}
|
||||
|
||||
export interface GenerateBranchSummaryOptions {
|
||||
/** Model to use for summarization */
|
||||
model: Model<any>;
|
||||
/** API key for the model */
|
||||
apiKey: string;
|
||||
/** Abort signal for cancellation */
|
||||
signal: AbortSignal;
|
||||
/** Optional custom instructions for summarization */
|
||||
customInstructions?: string;
|
||||
/** If true, customInstructions replaces the default prompt instead of being appended */
|
||||
replaceInstructions?: boolean;
|
||||
/** Tokens reserved for prompt + LLM response (default 16384) */
|
||||
reserveTokens?: number;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Entry Collection
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Collect entries that should be summarized when navigating from one position to another.
|
||||
*
|
||||
* Walks from oldLeafId back to the common ancestor with targetId, collecting entries
|
||||
* along the way. Does NOT stop at compaction boundaries - those are included and their
|
||||
* summaries become context.
|
||||
*
|
||||
* @param session - Session manager (read-only access)
|
||||
* @param oldLeafId - Current position (where we're navigating from)
|
||||
* @param targetId - Target position (where we're navigating to)
|
||||
* @returns Entries to summarize and the common ancestor
|
||||
*/
|
||||
export function collectEntriesForBranchSummary(
|
||||
session: ReadonlySessionManager,
|
||||
oldLeafId: string | null,
|
||||
targetId: string,
|
||||
): CollectEntriesResult {
|
||||
// If no old position, nothing to summarize
|
||||
if (!oldLeafId) {
|
||||
return { entries: [], commonAncestorId: null };
|
||||
}
|
||||
|
||||
// Find common ancestor (deepest node that's on both paths)
|
||||
const oldPath = new Set(session.getBranch(oldLeafId).map((e) => e.id));
|
||||
const targetPath = session.getBranch(targetId);
|
||||
|
||||
// targetPath is root-first, so iterate backwards to find deepest common ancestor
|
||||
let commonAncestorId: string | null = null;
|
||||
for (let i = targetPath.length - 1; i >= 0; i--) {
|
||||
if (oldPath.has(targetPath[i].id)) {
|
||||
commonAncestorId = targetPath[i].id;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Collect entries from old leaf back to common ancestor
|
||||
const entries: SessionEntry[] = [];
|
||||
let current: string | null = oldLeafId;
|
||||
|
||||
while (current && current !== commonAncestorId) {
|
||||
const entry = session.getEntry(current);
|
||||
if (!entry) break;
|
||||
entries.push(entry);
|
||||
current = entry.parentId;
|
||||
}
|
||||
|
||||
// Reverse to get chronological order
|
||||
entries.reverse();
|
||||
|
||||
return { entries, commonAncestorId };
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Entry to Message Conversion
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Extract AgentMessage from a session entry.
|
||||
* Similar to getMessageFromEntry in compaction.ts but also handles compaction entries.
|
||||
*/
|
||||
function getMessageFromEntry(entry: SessionEntry): AgentMessage | undefined {
|
||||
switch (entry.type) {
|
||||
case "message":
|
||||
// Skip tool results - context is in assistant's tool call
|
||||
if (entry.message.role === "toolResult") return undefined;
|
||||
return entry.message;
|
||||
|
||||
case "custom_message":
|
||||
return createCustomMessage(entry.customType, entry.content, entry.display, entry.details, entry.timestamp);
|
||||
|
||||
case "branch_summary":
|
||||
return createBranchSummaryMessage(entry.summary, entry.fromId, entry.timestamp);
|
||||
|
||||
case "compaction":
|
||||
return createCompactionSummaryMessage(entry.summary, entry.tokensBefore, entry.timestamp);
|
||||
|
||||
// These don't contribute to conversation content
|
||||
case "thinking_level_change":
|
||||
case "model_change":
|
||||
case "custom":
|
||||
case "label":
|
||||
return undefined;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Prepare entries for summarization with token budget.
|
||||
*
|
||||
* Walks entries from NEWEST to OLDEST, adding messages until we hit the token budget.
|
||||
* This ensures we keep the most recent context when the branch is too long.
|
||||
*
|
||||
* Also collects file operations from:
|
||||
* - Tool calls in assistant messages
|
||||
* - Existing branch_summary entries' details (for cumulative tracking)
|
||||
*
|
||||
* @param entries - Entries in chronological order
|
||||
* @param tokenBudget - Maximum tokens to include (0 = no limit)
|
||||
*/
|
||||
export function prepareBranchEntries(entries: SessionEntry[], tokenBudget: number = 0): BranchPreparation {
|
||||
const messages: AgentMessage[] = [];
|
||||
const fileOps = createFileOps();
|
||||
let totalTokens = 0;
|
||||
|
||||
// First pass: collect file ops from ALL entries (even if they don't fit in token budget)
|
||||
// This ensures we capture cumulative file tracking from nested branch summaries
|
||||
// Only extract from pi-generated summaries (fromHook !== true), not extension-generated ones
|
||||
for (const entry of entries) {
|
||||
if (entry.type === "branch_summary" && !entry.fromHook && entry.details) {
|
||||
const details = entry.details as BranchSummaryDetails;
|
||||
if (Array.isArray(details.readFiles)) {
|
||||
for (const f of details.readFiles) fileOps.read.add(f);
|
||||
}
|
||||
if (Array.isArray(details.modifiedFiles)) {
|
||||
// Modified files go into both edited and written for proper deduplication
|
||||
for (const f of details.modifiedFiles) {
|
||||
fileOps.edited.add(f);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Second pass: walk from newest to oldest, adding messages until token budget
|
||||
for (let i = entries.length - 1; i >= 0; i--) {
|
||||
const entry = entries[i];
|
||||
const message = getMessageFromEntry(entry);
|
||||
if (!message) continue;
|
||||
|
||||
// Extract file ops from assistant messages (tool calls)
|
||||
extractFileOpsFromMessage(message, fileOps);
|
||||
|
||||
const tokens = estimateTokens(message);
|
||||
|
||||
// Check budget before adding
|
||||
if (tokenBudget > 0 && totalTokens + tokens > tokenBudget) {
|
||||
// If this is a summary entry, try to fit it anyway as it's important context
|
||||
if (entry.type === "compaction" || entry.type === "branch_summary") {
|
||||
if (totalTokens < tokenBudget * 0.9) {
|
||||
messages.unshift(message);
|
||||
totalTokens += tokens;
|
||||
}
|
||||
}
|
||||
// Stop - we've hit the budget
|
||||
break;
|
||||
}
|
||||
|
||||
messages.unshift(message);
|
||||
totalTokens += tokens;
|
||||
}
|
||||
|
||||
return { messages, fileOps, totalTokens };
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Summary Generation
|
||||
// ============================================================================
|
||||
|
||||
const BRANCH_SUMMARY_PREAMBLE = `The user explored a different conversation branch before returning here.
|
||||
Summary of that exploration:
|
||||
|
||||
`;
|
||||
|
||||
const BRANCH_SUMMARY_PROMPT = `Create a structured summary of this conversation branch for context when returning later.
|
||||
|
||||
Use this EXACT format:
|
||||
|
||||
## Goal
|
||||
[What was the user trying to accomplish in this branch?]
|
||||
|
||||
## Constraints & Preferences
|
||||
- [Any constraints, preferences, or requirements mentioned]
|
||||
- [Or "(none)" if none were mentioned]
|
||||
|
||||
## Progress
|
||||
### Done
|
||||
- [x] [Completed tasks/changes]
|
||||
|
||||
### In Progress
|
||||
- [ ] [Work that was started but not finished]
|
||||
|
||||
### Blocked
|
||||
- [Issues preventing progress, if any]
|
||||
|
||||
## Key Decisions
|
||||
- **[Decision]**: [Brief rationale]
|
||||
|
||||
## Next Steps
|
||||
1. [What should happen next to continue this work]
|
||||
|
||||
Keep each section concise. Preserve exact file paths, function names, and error messages.`;
|
||||
|
||||
/**
|
||||
* Generate a summary of abandoned branch entries.
|
||||
*
|
||||
* @param entries - Session entries to summarize (chronological order)
|
||||
* @param options - Generation options
|
||||
*/
|
||||
export async function generateBranchSummary(
|
||||
entries: SessionEntry[],
|
||||
options: GenerateBranchSummaryOptions,
|
||||
): Promise<BranchSummaryResult> {
|
||||
const { model, apiKey, signal, customInstructions, replaceInstructions, reserveTokens = 16384 } = options;
|
||||
|
||||
// Token budget = context window minus reserved space for prompt + response
|
||||
const contextWindow = model.contextWindow || 128000;
|
||||
const tokenBudget = contextWindow - reserveTokens;
|
||||
|
||||
const { messages, fileOps } = prepareBranchEntries(entries, tokenBudget);
|
||||
|
||||
if (messages.length === 0) {
|
||||
return { summary: "No content to summarize" };
|
||||
}
|
||||
|
||||
// Transform to LLM-compatible messages, then serialize to text
|
||||
// Serialization prevents the model from treating it as a conversation to continue
|
||||
const llmMessages = convertToLlm(messages);
|
||||
const conversationText = serializeConversation(llmMessages);
|
||||
|
||||
// Build prompt
|
||||
let instructions: string;
|
||||
if (replaceInstructions && customInstructions) {
|
||||
instructions = customInstructions;
|
||||
} else if (customInstructions) {
|
||||
instructions = `${BRANCH_SUMMARY_PROMPT}\n\nAdditional focus: ${customInstructions}`;
|
||||
} else {
|
||||
instructions = BRANCH_SUMMARY_PROMPT;
|
||||
}
|
||||
const promptText = `<conversation>\n${conversationText}\n</conversation>\n\n${instructions}`;
|
||||
|
||||
const summarizationMessages = [
|
||||
{
|
||||
role: "user" as const,
|
||||
content: [{ type: "text" as const, text: promptText }],
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
];
|
||||
|
||||
// Call LLM for summarization
|
||||
const response = await completeSimple(
|
||||
model,
|
||||
{ systemPrompt: SUMMARIZATION_SYSTEM_PROMPT, messages: summarizationMessages },
|
||||
{ apiKey, signal, maxTokens: 2048 },
|
||||
);
|
||||
|
||||
// Check if aborted or errored
|
||||
if (response.stopReason === "aborted") {
|
||||
return { aborted: true };
|
||||
}
|
||||
if (response.stopReason === "error") {
|
||||
return { error: response.errorMessage || "Summarization failed" };
|
||||
}
|
||||
|
||||
let summary = response.content
|
||||
.filter((c): c is { type: "text"; text: string } => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("\n");
|
||||
|
||||
// Prepend preamble to provide context about the branch summary
|
||||
summary = BRANCH_SUMMARY_PREAMBLE + summary;
|
||||
|
||||
// Compute file lists and append to summary
|
||||
const { readFiles, modifiedFiles } = computeFileLists(fileOps);
|
||||
summary += formatFileOperations(readFiles, modifiedFiles);
|
||||
|
||||
return {
|
||||
summary: summary || "No summary generated",
|
||||
readFiles,
|
||||
modifiedFiles,
|
||||
};
|
||||
}
|
||||
813
packages/pi-coding-agent/src/core/compaction/compaction.ts
Normal file
813
packages/pi-coding-agent/src/core/compaction/compaction.ts
Normal file
|
|
@ -0,0 +1,813 @@
|
|||
/**
|
||||
* Context compaction for long sessions.
|
||||
*
|
||||
* Pure functions for compaction logic. The session manager handles I/O,
|
||||
* and after compaction the session is reloaded.
|
||||
*/
|
||||
|
||||
import type { AgentMessage } from "@gsd/pi-agent-core";
|
||||
import type { AssistantMessage, Model, Usage } from "@gsd/pi-ai";
|
||||
import { completeSimple } from "@gsd/pi-ai";
|
||||
import {
|
||||
convertToLlm,
|
||||
createBranchSummaryMessage,
|
||||
createCompactionSummaryMessage,
|
||||
createCustomMessage,
|
||||
} from "../messages.js";
|
||||
import type { CompactionEntry, SessionEntry } from "../session-manager.js";
|
||||
import {
|
||||
computeFileLists,
|
||||
createFileOps,
|
||||
extractFileOpsFromMessage,
|
||||
type FileOperations,
|
||||
formatFileOperations,
|
||||
SUMMARIZATION_SYSTEM_PROMPT,
|
||||
serializeConversation,
|
||||
} from "./utils.js";
|
||||
|
||||
// ============================================================================
|
||||
// File Operation Tracking
|
||||
// ============================================================================
|
||||
|
||||
/** Details stored in CompactionEntry.details for file tracking */
|
||||
export interface CompactionDetails {
|
||||
readFiles: string[];
|
||||
modifiedFiles: string[];
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract file operations from messages and previous compaction entries.
|
||||
*/
|
||||
function extractFileOperations(
|
||||
messages: AgentMessage[],
|
||||
entries: SessionEntry[],
|
||||
prevCompactionIndex: number,
|
||||
): FileOperations {
|
||||
const fileOps = createFileOps();
|
||||
|
||||
// Collect from previous compaction's details (if pi-generated)
|
||||
if (prevCompactionIndex >= 0) {
|
||||
const prevCompaction = entries[prevCompactionIndex] as CompactionEntry;
|
||||
if (!prevCompaction.fromHook && prevCompaction.details) {
|
||||
// fromHook field kept for session file compatibility
|
||||
const details = prevCompaction.details as CompactionDetails;
|
||||
if (Array.isArray(details.readFiles)) {
|
||||
for (const f of details.readFiles) fileOps.read.add(f);
|
||||
}
|
||||
if (Array.isArray(details.modifiedFiles)) {
|
||||
for (const f of details.modifiedFiles) fileOps.edited.add(f);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract from tool calls in messages
|
||||
for (const msg of messages) {
|
||||
extractFileOpsFromMessage(msg, fileOps);
|
||||
}
|
||||
|
||||
return fileOps;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Message Extraction
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Extract AgentMessage from an entry if it produces one.
|
||||
* Returns undefined for entries that don't contribute to LLM context.
|
||||
*/
|
||||
function getMessageFromEntry(entry: SessionEntry): AgentMessage | undefined {
|
||||
if (entry.type === "message") {
|
||||
return entry.message;
|
||||
}
|
||||
if (entry.type === "custom_message") {
|
||||
return createCustomMessage(entry.customType, entry.content, entry.display, entry.details, entry.timestamp);
|
||||
}
|
||||
if (entry.type === "branch_summary") {
|
||||
return createBranchSummaryMessage(entry.summary, entry.fromId, entry.timestamp);
|
||||
}
|
||||
if (entry.type === "compaction") {
|
||||
return createCompactionSummaryMessage(entry.summary, entry.tokensBefore, entry.timestamp);
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/** Result from compact() - SessionManager adds uuid/parentUuid when saving */
|
||||
export interface CompactionResult<T = unknown> {
|
||||
summary: string;
|
||||
firstKeptEntryId: string;
|
||||
tokensBefore: number;
|
||||
/** Extension-specific data (e.g., ArtifactIndex, version markers for structured compaction) */
|
||||
details?: T;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Types
|
||||
// ============================================================================
|
||||
|
||||
export interface CompactionSettings {
|
||||
enabled: boolean;
|
||||
reserveTokens: number;
|
||||
keepRecentTokens: number;
|
||||
}
|
||||
|
||||
export const DEFAULT_COMPACTION_SETTINGS: CompactionSettings = {
|
||||
enabled: true,
|
||||
reserveTokens: 16384,
|
||||
keepRecentTokens: 20000,
|
||||
};
|
||||
|
||||
// ============================================================================
|
||||
// Token calculation
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Calculate total context tokens from usage.
|
||||
* Uses the native totalTokens field when available, falls back to computing from components.
|
||||
*/
|
||||
export function calculateContextTokens(usage: Usage): number {
|
||||
return usage.totalTokens || usage.input + usage.output + usage.cacheRead + usage.cacheWrite;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get usage from an assistant message if available.
|
||||
* Skips aborted and error messages as they don't have valid usage data.
|
||||
*/
|
||||
function getAssistantUsage(msg: AgentMessage): Usage | undefined {
|
||||
if (msg.role === "assistant" && "usage" in msg) {
|
||||
const assistantMsg = msg as AssistantMessage;
|
||||
if (assistantMsg.stopReason !== "aborted" && assistantMsg.stopReason !== "error" && assistantMsg.usage) {
|
||||
return assistantMsg.usage;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the last non-aborted assistant message usage from session entries.
|
||||
*/
|
||||
export function getLastAssistantUsage(entries: SessionEntry[]): Usage | undefined {
|
||||
for (let i = entries.length - 1; i >= 0; i--) {
|
||||
const entry = entries[i];
|
||||
if (entry.type === "message") {
|
||||
const usage = getAssistantUsage(entry.message);
|
||||
if (usage) return usage;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
export interface ContextUsageEstimate {
|
||||
tokens: number;
|
||||
usageTokens: number;
|
||||
trailingTokens: number;
|
||||
lastUsageIndex: number | null;
|
||||
}
|
||||
|
||||
function getLastAssistantUsageInfo(messages: AgentMessage[]): { usage: Usage; index: number } | undefined {
|
||||
for (let i = messages.length - 1; i >= 0; i--) {
|
||||
const usage = getAssistantUsage(messages[i]);
|
||||
if (usage) return { usage, index: i };
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Estimate context tokens from messages, using the last assistant usage when available.
|
||||
* If there are messages after the last usage, estimate their tokens with estimateTokens.
|
||||
*/
|
||||
export function estimateContextTokens(messages: AgentMessage[]): ContextUsageEstimate {
|
||||
const usageInfo = getLastAssistantUsageInfo(messages);
|
||||
|
||||
if (!usageInfo) {
|
||||
let estimated = 0;
|
||||
for (const message of messages) {
|
||||
estimated += estimateTokens(message);
|
||||
}
|
||||
return {
|
||||
tokens: estimated,
|
||||
usageTokens: 0,
|
||||
trailingTokens: estimated,
|
||||
lastUsageIndex: null,
|
||||
};
|
||||
}
|
||||
|
||||
const usageTokens = calculateContextTokens(usageInfo.usage);
|
||||
let trailingTokens = 0;
|
||||
for (let i = usageInfo.index + 1; i < messages.length; i++) {
|
||||
trailingTokens += estimateTokens(messages[i]);
|
||||
}
|
||||
|
||||
return {
|
||||
tokens: usageTokens + trailingTokens,
|
||||
usageTokens,
|
||||
trailingTokens,
|
||||
lastUsageIndex: usageInfo.index,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if compaction should trigger based on context usage.
|
||||
*/
|
||||
export function shouldCompact(contextTokens: number, contextWindow: number, settings: CompactionSettings): boolean {
|
||||
if (!settings.enabled) return false;
|
||||
return contextTokens > contextWindow - settings.reserveTokens;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Cut point detection
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* Estimate token count for a message using chars/4 heuristic.
|
||||
* This is conservative (overestimates tokens).
|
||||
*/
|
||||
export function estimateTokens(message: AgentMessage): number {
|
||||
let chars = 0;
|
||||
|
||||
switch (message.role) {
|
||||
case "user": {
|
||||
const content = (message as { content: string | Array<{ type: string; text?: string }> }).content;
|
||||
if (typeof content === "string") {
|
||||
chars = content.length;
|
||||
} else if (Array.isArray(content)) {
|
||||
for (const block of content) {
|
||||
if (block.type === "text" && block.text) {
|
||||
chars += block.text.length;
|
||||
}
|
||||
}
|
||||
}
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
case "assistant": {
|
||||
const assistant = message as AssistantMessage;
|
||||
for (const block of assistant.content) {
|
||||
if (block.type === "text") {
|
||||
chars += block.text.length;
|
||||
} else if (block.type === "thinking") {
|
||||
chars += block.thinking.length;
|
||||
} else if (block.type === "toolCall") {
|
||||
chars += block.name.length + JSON.stringify(block.arguments).length;
|
||||
}
|
||||
}
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
case "custom":
|
||||
case "toolResult": {
|
||||
if (typeof message.content === "string") {
|
||||
chars = message.content.length;
|
||||
} else {
|
||||
for (const block of message.content) {
|
||||
if (block.type === "text" && block.text) {
|
||||
chars += block.text.length;
|
||||
}
|
||||
if (block.type === "image") {
|
||||
chars += 4800; // Estimate images as 4000 chars, or 1200 tokens
|
||||
}
|
||||
}
|
||||
}
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
case "bashExecution": {
|
||||
chars = message.command.length + message.output.length;
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
case "branchSummary":
|
||||
case "compactionSummary": {
|
||||
chars = message.summary.length;
|
||||
return Math.ceil(chars / 4);
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find valid cut points: indices of user, assistant, custom, or bashExecution messages.
|
||||
* Never cut at tool results (they must follow their tool call).
|
||||
* When we cut at an assistant message with tool calls, its tool results follow it
|
||||
* and will be kept.
|
||||
* BashExecutionMessage is treated like a user message (user-initiated context).
|
||||
*/
|
||||
function findValidCutPoints(entries: SessionEntry[], startIndex: number, endIndex: number): number[] {
|
||||
const cutPoints: number[] = [];
|
||||
for (let i = startIndex; i < endIndex; i++) {
|
||||
const entry = entries[i];
|
||||
switch (entry.type) {
|
||||
case "message": {
|
||||
const role = entry.message.role;
|
||||
switch (role) {
|
||||
case "bashExecution":
|
||||
case "custom":
|
||||
case "branchSummary":
|
||||
case "compactionSummary":
|
||||
case "user":
|
||||
case "assistant":
|
||||
cutPoints.push(i);
|
||||
break;
|
||||
case "toolResult":
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case "thinking_level_change":
|
||||
case "model_change":
|
||||
case "compaction":
|
||||
case "branch_summary":
|
||||
case "custom":
|
||||
case "custom_message":
|
||||
case "label":
|
||||
}
|
||||
// branch_summary and custom_message are user-role messages, valid cut points
|
||||
if (entry.type === "branch_summary" || entry.type === "custom_message") {
|
||||
cutPoints.push(i);
|
||||
}
|
||||
}
|
||||
return cutPoints;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the user message (or bashExecution) that starts the turn containing the given entry index.
|
||||
* Returns -1 if no turn start found before the index.
|
||||
* BashExecutionMessage is treated like a user message for turn boundaries.
|
||||
*/
|
||||
export function findTurnStartIndex(entries: SessionEntry[], entryIndex: number, startIndex: number): number {
|
||||
for (let i = entryIndex; i >= startIndex; i--) {
|
||||
const entry = entries[i];
|
||||
// branch_summary and custom_message are user-role messages, can start a turn
|
||||
if (entry.type === "branch_summary" || entry.type === "custom_message") {
|
||||
return i;
|
||||
}
|
||||
if (entry.type === "message") {
|
||||
const role = entry.message.role;
|
||||
if (role === "user" || role === "bashExecution") {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
export interface CutPointResult {
|
||||
/** Index of first entry to keep */
|
||||
firstKeptEntryIndex: number;
|
||||
/** Index of user message that starts the turn being split, or -1 if not splitting */
|
||||
turnStartIndex: number;
|
||||
/** Whether this cut splits a turn (cut point is not a user message) */
|
||||
isSplitTurn: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the cut point in session entries that keeps approximately `keepRecentTokens`.
|
||||
*
|
||||
* Algorithm: Walk backwards from newest, accumulating estimated message sizes.
|
||||
* Stop when we've accumulated >= keepRecentTokens. Cut at that point.
|
||||
*
|
||||
* Can cut at user OR assistant messages (never tool results). When cutting at an
|
||||
* assistant message with tool calls, its tool results come after and will be kept.
|
||||
*
|
||||
* Returns CutPointResult with:
|
||||
* - firstKeptEntryIndex: the entry index to start keeping from
|
||||
* - turnStartIndex: if cutting mid-turn, the user message that started that turn
|
||||
* - isSplitTurn: whether we're cutting in the middle of a turn
|
||||
*
|
||||
* Only considers entries between `startIndex` and `endIndex` (exclusive).
|
||||
*/
|
||||
export function findCutPoint(
|
||||
entries: SessionEntry[],
|
||||
startIndex: number,
|
||||
endIndex: number,
|
||||
keepRecentTokens: number,
|
||||
): CutPointResult {
|
||||
const cutPoints = findValidCutPoints(entries, startIndex, endIndex);
|
||||
|
||||
if (cutPoints.length === 0) {
|
||||
return { firstKeptEntryIndex: startIndex, turnStartIndex: -1, isSplitTurn: false };
|
||||
}
|
||||
|
||||
// Walk backwards from newest, accumulating estimated message sizes
|
||||
let accumulatedTokens = 0;
|
||||
let cutIndex = cutPoints[0]; // Default: keep from first message (not header)
|
||||
|
||||
for (let i = endIndex - 1; i >= startIndex; i--) {
|
||||
const entry = entries[i];
|
||||
if (entry.type !== "message") continue;
|
||||
|
||||
// Estimate this message's size
|
||||
const messageTokens = estimateTokens(entry.message);
|
||||
accumulatedTokens += messageTokens;
|
||||
|
||||
// Check if we've exceeded the budget
|
||||
if (accumulatedTokens >= keepRecentTokens) {
|
||||
// Find the closest valid cut point at or after this entry
|
||||
for (let c = 0; c < cutPoints.length; c++) {
|
||||
if (cutPoints[c] >= i) {
|
||||
cutIndex = cutPoints[c];
|
||||
break;
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Scan backwards from cutIndex to include any non-message entries (bash, settings, etc.)
|
||||
while (cutIndex > startIndex) {
|
||||
const prevEntry = entries[cutIndex - 1];
|
||||
// Stop at session header or compaction boundaries
|
||||
if (prevEntry.type === "compaction") {
|
||||
break;
|
||||
}
|
||||
if (prevEntry.type === "message") {
|
||||
// Stop if we hit any message
|
||||
break;
|
||||
}
|
||||
// Include this non-message entry (bash, settings change, etc.)
|
||||
cutIndex--;
|
||||
}
|
||||
|
||||
// Determine if this is a split turn
|
||||
const cutEntry = entries[cutIndex];
|
||||
const isUserMessage = cutEntry.type === "message" && cutEntry.message.role === "user";
|
||||
const turnStartIndex = isUserMessage ? -1 : findTurnStartIndex(entries, cutIndex, startIndex);
|
||||
|
||||
return {
|
||||
firstKeptEntryIndex: cutIndex,
|
||||
turnStartIndex,
|
||||
isSplitTurn: !isUserMessage && turnStartIndex !== -1,
|
||||
};
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Summarization
|
||||
// ============================================================================
|
||||
|
||||
const SUMMARIZATION_PROMPT = `The messages above are a conversation to summarize. Create a structured context checkpoint summary that another LLM will use to continue the work.
|
||||
|
||||
Use this EXACT format:
|
||||
|
||||
## Goal
|
||||
[What is the user trying to accomplish? Can be multiple items if the session covers different tasks.]
|
||||
|
||||
## Constraints & Preferences
|
||||
- [Any constraints, preferences, or requirements mentioned by user]
|
||||
- [Or "(none)" if none were mentioned]
|
||||
|
||||
## Progress
|
||||
### Done
|
||||
- [x] [Completed tasks/changes]
|
||||
|
||||
### In Progress
|
||||
- [ ] [Current work]
|
||||
|
||||
### Blocked
|
||||
- [Issues preventing progress, if any]
|
||||
|
||||
## Key Decisions
|
||||
- **[Decision]**: [Brief rationale]
|
||||
|
||||
## Next Steps
|
||||
1. [Ordered list of what should happen next]
|
||||
|
||||
## Critical Context
|
||||
- [Any data, examples, or references needed to continue]
|
||||
- [Or "(none)" if not applicable]
|
||||
|
||||
Keep each section concise. Preserve exact file paths, function names, and error messages.`;
|
||||
|
||||
const UPDATE_SUMMARIZATION_PROMPT = `The messages above are NEW conversation messages to incorporate into the existing summary provided in <previous-summary> tags.
|
||||
|
||||
Update the existing structured summary with new information. RULES:
|
||||
- PRESERVE all existing information from the previous summary
|
||||
- ADD new progress, decisions, and context from the new messages
|
||||
- UPDATE the Progress section: move items from "In Progress" to "Done" when completed
|
||||
- UPDATE "Next Steps" based on what was accomplished
|
||||
- PRESERVE exact file paths, function names, and error messages
|
||||
- If something is no longer relevant, you may remove it
|
||||
|
||||
Use this EXACT format:
|
||||
|
||||
## Goal
|
||||
[Preserve existing goals, add new ones if the task expanded]
|
||||
|
||||
## Constraints & Preferences
|
||||
- [Preserve existing, add new ones discovered]
|
||||
|
||||
## Progress
|
||||
### Done
|
||||
- [x] [Include previously done items AND newly completed items]
|
||||
|
||||
### In Progress
|
||||
- [ ] [Current work - update based on progress]
|
||||
|
||||
### Blocked
|
||||
- [Current blockers - remove if resolved]
|
||||
|
||||
## Key Decisions
|
||||
- **[Decision]**: [Brief rationale] (preserve all previous, add new)
|
||||
|
||||
## Next Steps
|
||||
1. [Update based on current state]
|
||||
|
||||
## Critical Context
|
||||
- [Preserve important context, add new if needed]
|
||||
|
||||
Keep each section concise. Preserve exact file paths, function names, and error messages.`;
|
||||
|
||||
/**
|
||||
* Generate a summary of the conversation using the LLM.
|
||||
* If previousSummary is provided, uses the update prompt to merge.
|
||||
*/
|
||||
export async function generateSummary(
|
||||
currentMessages: AgentMessage[],
|
||||
model: Model<any>,
|
||||
reserveTokens: number,
|
||||
apiKey: string,
|
||||
signal?: AbortSignal,
|
||||
customInstructions?: string,
|
||||
previousSummary?: string,
|
||||
): Promise<string> {
|
||||
const maxTokens = Math.floor(0.8 * reserveTokens);
|
||||
|
||||
// Use update prompt if we have a previous summary, otherwise initial prompt
|
||||
let basePrompt = previousSummary ? UPDATE_SUMMARIZATION_PROMPT : SUMMARIZATION_PROMPT;
|
||||
if (customInstructions) {
|
||||
basePrompt = `${basePrompt}\n\nAdditional focus: ${customInstructions}`;
|
||||
}
|
||||
|
||||
// Serialize conversation to text so model doesn't try to continue it
|
||||
// Convert to LLM messages first (handles custom types like bashExecution, custom, etc.)
|
||||
const llmMessages = convertToLlm(currentMessages);
|
||||
const conversationText = serializeConversation(llmMessages);
|
||||
|
||||
// Build the prompt with conversation wrapped in tags
|
||||
let promptText = `<conversation>\n${conversationText}\n</conversation>\n\n`;
|
||||
if (previousSummary) {
|
||||
promptText += `<previous-summary>\n${previousSummary}\n</previous-summary>\n\n`;
|
||||
}
|
||||
promptText += basePrompt;
|
||||
|
||||
const summarizationMessages = [
|
||||
{
|
||||
role: "user" as const,
|
||||
content: [{ type: "text" as const, text: promptText }],
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
];
|
||||
|
||||
const completionOptions = model.reasoning
|
||||
? { maxTokens, signal, apiKey, reasoning: "high" as const }
|
||||
: { maxTokens, signal, apiKey };
|
||||
|
||||
const response = await completeSimple(
|
||||
model,
|
||||
{ systemPrompt: SUMMARIZATION_SYSTEM_PROMPT, messages: summarizationMessages },
|
||||
completionOptions,
|
||||
);
|
||||
|
||||
if (response.stopReason === "error") {
|
||||
throw new Error(`Summarization failed: ${response.errorMessage || "Unknown error"}`);
|
||||
}
|
||||
|
||||
const textContent = response.content
|
||||
.filter((c): c is { type: "text"; text: string } => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("\n");
|
||||
|
||||
return textContent;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Compaction Preparation (for extensions)
|
||||
// ============================================================================
|
||||
|
||||
export interface CompactionPreparation {
|
||||
/** UUID of first entry to keep */
|
||||
firstKeptEntryId: string;
|
||||
/** Messages that will be summarized and discarded */
|
||||
messagesToSummarize: AgentMessage[];
|
||||
/** Messages that will be turned into turn prefix summary (if splitting) */
|
||||
turnPrefixMessages: AgentMessage[];
|
||||
/** Whether this is a split turn (cut point in middle of turn) */
|
||||
isSplitTurn: boolean;
|
||||
tokensBefore: number;
|
||||
/** Summary from previous compaction, for iterative update */
|
||||
previousSummary?: string;
|
||||
/** File operations extracted from messagesToSummarize */
|
||||
fileOps: FileOperations;
|
||||
/** Compaction settions from settings.jsonl */
|
||||
settings: CompactionSettings;
|
||||
}
|
||||
|
||||
export function prepareCompaction(
|
||||
pathEntries: SessionEntry[],
|
||||
settings: CompactionSettings,
|
||||
): CompactionPreparation | undefined {
|
||||
if (pathEntries.length > 0 && pathEntries[pathEntries.length - 1].type === "compaction") {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
let prevCompactionIndex = -1;
|
||||
for (let i = pathEntries.length - 1; i >= 0; i--) {
|
||||
if (pathEntries[i].type === "compaction") {
|
||||
prevCompactionIndex = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
const boundaryStart = prevCompactionIndex + 1;
|
||||
const boundaryEnd = pathEntries.length;
|
||||
|
||||
const usageStart = prevCompactionIndex >= 0 ? prevCompactionIndex : 0;
|
||||
const usageMessages: AgentMessage[] = [];
|
||||
for (let i = usageStart; i < boundaryEnd; i++) {
|
||||
const msg = getMessageFromEntry(pathEntries[i]);
|
||||
if (msg) usageMessages.push(msg);
|
||||
}
|
||||
const tokensBefore = estimateContextTokens(usageMessages).tokens;
|
||||
|
||||
const cutPoint = findCutPoint(pathEntries, boundaryStart, boundaryEnd, settings.keepRecentTokens);
|
||||
|
||||
// Get UUID of first kept entry
|
||||
const firstKeptEntry = pathEntries[cutPoint.firstKeptEntryIndex];
|
||||
if (!firstKeptEntry?.id) {
|
||||
return undefined; // Session needs migration
|
||||
}
|
||||
const firstKeptEntryId = firstKeptEntry.id;
|
||||
|
||||
const historyEnd = cutPoint.isSplitTurn ? cutPoint.turnStartIndex : cutPoint.firstKeptEntryIndex;
|
||||
|
||||
// Messages to summarize (will be discarded after summary)
|
||||
const messagesToSummarize: AgentMessage[] = [];
|
||||
for (let i = boundaryStart; i < historyEnd; i++) {
|
||||
const msg = getMessageFromEntry(pathEntries[i]);
|
||||
if (msg) messagesToSummarize.push(msg);
|
||||
}
|
||||
|
||||
// Messages for turn prefix summary (if splitting a turn)
|
||||
const turnPrefixMessages: AgentMessage[] = [];
|
||||
if (cutPoint.isSplitTurn) {
|
||||
for (let i = cutPoint.turnStartIndex; i < cutPoint.firstKeptEntryIndex; i++) {
|
||||
const msg = getMessageFromEntry(pathEntries[i]);
|
||||
if (msg) turnPrefixMessages.push(msg);
|
||||
}
|
||||
}
|
||||
|
||||
// Get previous summary for iterative update
|
||||
let previousSummary: string | undefined;
|
||||
if (prevCompactionIndex >= 0) {
|
||||
const prevCompaction = pathEntries[prevCompactionIndex] as CompactionEntry;
|
||||
previousSummary = prevCompaction.summary;
|
||||
}
|
||||
|
||||
// Extract file operations from messages and previous compaction
|
||||
const fileOps = extractFileOperations(messagesToSummarize, pathEntries, prevCompactionIndex);
|
||||
|
||||
// Also extract file ops from turn prefix if splitting
|
||||
if (cutPoint.isSplitTurn) {
|
||||
for (const msg of turnPrefixMessages) {
|
||||
extractFileOpsFromMessage(msg, fileOps);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
firstKeptEntryId,
|
||||
messagesToSummarize,
|
||||
turnPrefixMessages,
|
||||
isSplitTurn: cutPoint.isSplitTurn,
|
||||
tokensBefore,
|
||||
previousSummary,
|
||||
fileOps,
|
||||
settings,
|
||||
};
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Main compaction function
|
||||
// ============================================================================
|
||||
|
||||
const TURN_PREFIX_SUMMARIZATION_PROMPT = `This is the PREFIX of a turn that was too large to keep. The SUFFIX (recent work) is retained.
|
||||
|
||||
Summarize the prefix to provide context for the retained suffix:
|
||||
|
||||
## Original Request
|
||||
[What did the user ask for in this turn?]
|
||||
|
||||
## Early Progress
|
||||
- [Key decisions and work done in the prefix]
|
||||
|
||||
## Context for Suffix
|
||||
- [Information needed to understand the retained recent work]
|
||||
|
||||
Be concise. Focus on what's needed to understand the kept suffix.`;
|
||||
|
||||
/**
|
||||
* Generate summaries for compaction using prepared data.
|
||||
* Returns CompactionResult - SessionManager adds uuid/parentUuid when saving.
|
||||
*
|
||||
* @param preparation - Pre-calculated preparation from prepareCompaction()
|
||||
* @param customInstructions - Optional custom focus for the summary
|
||||
*/
|
||||
export async function compact(
|
||||
preparation: CompactionPreparation,
|
||||
model: Model<any>,
|
||||
apiKey: string,
|
||||
customInstructions?: string,
|
||||
signal?: AbortSignal,
|
||||
): Promise<CompactionResult> {
|
||||
const {
|
||||
firstKeptEntryId,
|
||||
messagesToSummarize,
|
||||
turnPrefixMessages,
|
||||
isSplitTurn,
|
||||
tokensBefore,
|
||||
previousSummary,
|
||||
fileOps,
|
||||
settings,
|
||||
} = preparation;
|
||||
|
||||
// Generate summaries (can be parallel if both needed) and merge into one
|
||||
let summary: string;
|
||||
|
||||
if (isSplitTurn && turnPrefixMessages.length > 0) {
|
||||
// Generate both summaries in parallel
|
||||
const [historyResult, turnPrefixResult] = await Promise.all([
|
||||
messagesToSummarize.length > 0
|
||||
? generateSummary(
|
||||
messagesToSummarize,
|
||||
model,
|
||||
settings.reserveTokens,
|
||||
apiKey,
|
||||
signal,
|
||||
customInstructions,
|
||||
previousSummary,
|
||||
)
|
||||
: Promise.resolve("No prior history."),
|
||||
generateTurnPrefixSummary(turnPrefixMessages, model, settings.reserveTokens, apiKey, signal),
|
||||
]);
|
||||
// Merge into single summary
|
||||
summary = `${historyResult}\n\n---\n\n**Turn Context (split turn):**\n\n${turnPrefixResult}`;
|
||||
} else {
|
||||
// Just generate history summary
|
||||
summary = await generateSummary(
|
||||
messagesToSummarize,
|
||||
model,
|
||||
settings.reserveTokens,
|
||||
apiKey,
|
||||
signal,
|
||||
customInstructions,
|
||||
previousSummary,
|
||||
);
|
||||
}
|
||||
|
||||
// Compute file lists and append to summary
|
||||
const { readFiles, modifiedFiles } = computeFileLists(fileOps);
|
||||
summary += formatFileOperations(readFiles, modifiedFiles);
|
||||
|
||||
if (!firstKeptEntryId) {
|
||||
throw new Error("First kept entry has no UUID - session may need migration");
|
||||
}
|
||||
|
||||
return {
|
||||
summary,
|
||||
firstKeptEntryId,
|
||||
tokensBefore,
|
||||
details: { readFiles, modifiedFiles } as CompactionDetails,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate a summary for a turn prefix (when splitting a turn).
|
||||
*/
|
||||
async function generateTurnPrefixSummary(
|
||||
messages: AgentMessage[],
|
||||
model: Model<any>,
|
||||
reserveTokens: number,
|
||||
apiKey: string,
|
||||
signal?: AbortSignal,
|
||||
): Promise<string> {
|
||||
const maxTokens = Math.floor(0.5 * reserveTokens); // Smaller budget for turn prefix
|
||||
const llmMessages = convertToLlm(messages);
|
||||
const conversationText = serializeConversation(llmMessages);
|
||||
const promptText = `<conversation>\n${conversationText}\n</conversation>\n\n${TURN_PREFIX_SUMMARIZATION_PROMPT}`;
|
||||
const summarizationMessages = [
|
||||
{
|
||||
role: "user" as const,
|
||||
content: [{ type: "text" as const, text: promptText }],
|
||||
timestamp: Date.now(),
|
||||
},
|
||||
];
|
||||
|
||||
const response = await completeSimple(
|
||||
model,
|
||||
{ systemPrompt: SUMMARIZATION_SYSTEM_PROMPT, messages: summarizationMessages },
|
||||
{ maxTokens, signal, apiKey },
|
||||
);
|
||||
|
||||
if (response.stopReason === "error") {
|
||||
throw new Error(`Turn prefix summarization failed: ${response.errorMessage || "Unknown error"}`);
|
||||
}
|
||||
|
||||
return response.content
|
||||
.filter((c): c is { type: "text"; text: string } => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("\n");
|
||||
}
|
||||
7
packages/pi-coding-agent/src/core/compaction/index.ts
Normal file
7
packages/pi-coding-agent/src/core/compaction/index.ts
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
/**
|
||||
* Compaction and summarization utilities.
|
||||
*/
|
||||
|
||||
export * from "./branch-summarization.js";
|
||||
export * from "./compaction.js";
|
||||
export * from "./utils.js";
|
||||
170
packages/pi-coding-agent/src/core/compaction/utils.ts
Normal file
170
packages/pi-coding-agent/src/core/compaction/utils.ts
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
/**
|
||||
* Shared utilities for compaction and branch summarization.
|
||||
*/
|
||||
|
||||
import type { AgentMessage } from "@gsd/pi-agent-core";
|
||||
import type { Message } from "@gsd/pi-ai";
|
||||
|
||||
// ============================================================================
|
||||
// File Operation Tracking
|
||||
// ============================================================================
|
||||
|
||||
export interface FileOperations {
|
||||
read: Set<string>;
|
||||
written: Set<string>;
|
||||
edited: Set<string>;
|
||||
}
|
||||
|
||||
export function createFileOps(): FileOperations {
|
||||
return {
|
||||
read: new Set(),
|
||||
written: new Set(),
|
||||
edited: new Set(),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract file operations from tool calls in an assistant message.
|
||||
*/
|
||||
export function extractFileOpsFromMessage(message: AgentMessage, fileOps: FileOperations): void {
|
||||
if (message.role !== "assistant") return;
|
||||
if (!("content" in message) || !Array.isArray(message.content)) return;
|
||||
|
||||
for (const block of message.content) {
|
||||
if (typeof block !== "object" || block === null) continue;
|
||||
if (!("type" in block) || block.type !== "toolCall") continue;
|
||||
if (!("arguments" in block) || !("name" in block)) continue;
|
||||
|
||||
const args = block.arguments as Record<string, unknown> | undefined;
|
||||
if (!args) continue;
|
||||
|
||||
const path = typeof args.path === "string" ? args.path : undefined;
|
||||
if (!path) continue;
|
||||
|
||||
switch (block.name) {
|
||||
case "read":
|
||||
fileOps.read.add(path);
|
||||
break;
|
||||
case "write":
|
||||
fileOps.written.add(path);
|
||||
break;
|
||||
case "edit":
|
||||
fileOps.edited.add(path);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute final file lists from file operations.
|
||||
* Returns readFiles (files only read, not modified) and modifiedFiles.
|
||||
*/
|
||||
export function computeFileLists(fileOps: FileOperations): { readFiles: string[]; modifiedFiles: string[] } {
|
||||
const modified = new Set([...fileOps.edited, ...fileOps.written]);
|
||||
const readOnly = [...fileOps.read].filter((f) => !modified.has(f)).sort();
|
||||
const modifiedFiles = [...modified].sort();
|
||||
return { readFiles: readOnly, modifiedFiles };
|
||||
}
|
||||
|
||||
/**
|
||||
* Format file operations as XML tags for summary.
|
||||
*/
|
||||
export function formatFileOperations(readFiles: string[], modifiedFiles: string[]): string {
|
||||
const sections: string[] = [];
|
||||
if (readFiles.length > 0) {
|
||||
sections.push(`<read-files>\n${readFiles.join("\n")}\n</read-files>`);
|
||||
}
|
||||
if (modifiedFiles.length > 0) {
|
||||
sections.push(`<modified-files>\n${modifiedFiles.join("\n")}\n</modified-files>`);
|
||||
}
|
||||
if (sections.length === 0) return "";
|
||||
return `\n\n${sections.join("\n\n")}`;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Message Serialization
|
||||
// ============================================================================
|
||||
|
||||
/** Maximum characters for a tool result in serialized summaries. */
|
||||
const TOOL_RESULT_MAX_CHARS = 2000;
|
||||
|
||||
/**
|
||||
* Truncate text to a maximum character length for summarization.
|
||||
* Keeps the beginning and appends a truncation marker.
|
||||
*/
|
||||
function truncateForSummary(text: string, maxChars: number): string {
|
||||
if (text.length <= maxChars) return text;
|
||||
const truncatedChars = text.length - maxChars;
|
||||
return `${text.slice(0, maxChars)}\n\n[... ${truncatedChars} more characters truncated]`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Serialize LLM messages to text for summarization.
|
||||
* This prevents the model from treating it as a conversation to continue.
|
||||
* Call convertToLlm() first to handle custom message types.
|
||||
*
|
||||
* Tool results are truncated to keep the summarization request within
|
||||
* reasonable token budgets. Full content is not needed for summarization.
|
||||
*/
|
||||
export function serializeConversation(messages: Message[]): string {
|
||||
const parts: string[] = [];
|
||||
|
||||
for (const msg of messages) {
|
||||
if (msg.role === "user") {
|
||||
const content =
|
||||
typeof msg.content === "string"
|
||||
? msg.content
|
||||
: msg.content
|
||||
.filter((c): c is { type: "text"; text: string } => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("");
|
||||
if (content) parts.push(`[User]: ${content}`);
|
||||
} else if (msg.role === "assistant") {
|
||||
const textParts: string[] = [];
|
||||
const thinkingParts: string[] = [];
|
||||
const toolCalls: string[] = [];
|
||||
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "text") {
|
||||
textParts.push(block.text);
|
||||
} else if (block.type === "thinking") {
|
||||
thinkingParts.push(block.thinking);
|
||||
} else if (block.type === "toolCall") {
|
||||
const args = block.arguments as Record<string, unknown>;
|
||||
const argsStr = Object.entries(args)
|
||||
.map(([k, v]) => `${k}=${JSON.stringify(v)}`)
|
||||
.join(", ");
|
||||
toolCalls.push(`${block.name}(${argsStr})`);
|
||||
}
|
||||
}
|
||||
|
||||
if (thinkingParts.length > 0) {
|
||||
parts.push(`[Assistant thinking]: ${thinkingParts.join("\n")}`);
|
||||
}
|
||||
if (textParts.length > 0) {
|
||||
parts.push(`[Assistant]: ${textParts.join("\n")}`);
|
||||
}
|
||||
if (toolCalls.length > 0) {
|
||||
parts.push(`[Assistant tool calls]: ${toolCalls.join("; ")}`);
|
||||
}
|
||||
} else if (msg.role === "toolResult") {
|
||||
const content = msg.content
|
||||
.filter((c): c is { type: "text"; text: string } => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("");
|
||||
if (content) {
|
||||
parts.push(`[Tool result]: ${truncateForSummary(content, TOOL_RESULT_MAX_CHARS)}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return parts.join("\n\n");
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Summarization System Prompt
|
||||
// ============================================================================
|
||||
|
||||
export const SUMMARIZATION_SYSTEM_PROMPT = `You are a context summarization assistant. Your task is to read a conversation between a user and an AI coding assistant, then produce a structured summary following the exact format specified.
|
||||
|
||||
Do NOT continue the conversation. Do NOT respond to any questions in the conversation. ONLY output the structured summary.`;
|
||||
3
packages/pi-coding-agent/src/core/defaults.ts
Normal file
3
packages/pi-coding-agent/src/core/defaults.ts
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
import type { ThinkingLevel } from "@gsd/pi-agent-core";
|
||||
|
||||
export const DEFAULT_THINKING_LEVEL: ThinkingLevel = "medium";
|
||||
15
packages/pi-coding-agent/src/core/diagnostics.ts
Normal file
15
packages/pi-coding-agent/src/core/diagnostics.ts
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
export interface ResourceCollision {
|
||||
resourceType: "extension" | "skill" | "prompt" | "theme";
|
||||
name: string; // skill name, command/tool/flag name, prompt name, theme name
|
||||
winnerPath: string;
|
||||
loserPath: string;
|
||||
winnerSource?: string; // e.g., "npm:foo", "git:...", "local"
|
||||
loserSource?: string;
|
||||
}
|
||||
|
||||
export interface ResourceDiagnostic {
|
||||
type: "warning" | "error" | "collision";
|
||||
message: string;
|
||||
path?: string;
|
||||
collision?: ResourceCollision;
|
||||
}
|
||||
33
packages/pi-coding-agent/src/core/event-bus.ts
Normal file
33
packages/pi-coding-agent/src/core/event-bus.ts
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
import { EventEmitter } from "node:events";
|
||||
|
||||
export interface EventBus {
|
||||
emit(channel: string, data: unknown): void;
|
||||
on(channel: string, handler: (data: unknown) => void): () => void;
|
||||
}
|
||||
|
||||
export interface EventBusController extends EventBus {
|
||||
clear(): void;
|
||||
}
|
||||
|
||||
export function createEventBus(): EventBusController {
|
||||
const emitter = new EventEmitter();
|
||||
return {
|
||||
emit: (channel, data) => {
|
||||
emitter.emit(channel, data);
|
||||
},
|
||||
on: (channel, handler) => {
|
||||
const safeHandler = async (data: unknown) => {
|
||||
try {
|
||||
await handler(data);
|
||||
} catch (err) {
|
||||
console.error(`Event handler error (${channel}):`, err);
|
||||
}
|
||||
};
|
||||
emitter.on(channel, safeHandler);
|
||||
return () => emitter.off(channel, safeHandler);
|
||||
},
|
||||
clear: () => {
|
||||
emitter.removeAllListeners();
|
||||
},
|
||||
};
|
||||
}
|
||||
104
packages/pi-coding-agent/src/core/exec.ts
Normal file
104
packages/pi-coding-agent/src/core/exec.ts
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
/**
|
||||
* Shared command execution utilities for extensions and custom tools.
|
||||
*/
|
||||
|
||||
import { spawn } from "node:child_process";
|
||||
|
||||
/**
|
||||
* Options for executing shell commands.
|
||||
*/
|
||||
export interface ExecOptions {
|
||||
/** AbortSignal to cancel the command */
|
||||
signal?: AbortSignal;
|
||||
/** Timeout in milliseconds */
|
||||
timeout?: number;
|
||||
/** Working directory */
|
||||
cwd?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Result of executing a shell command.
|
||||
*/
|
||||
export interface ExecResult {
|
||||
stdout: string;
|
||||
stderr: string;
|
||||
code: number;
|
||||
killed: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a shell command and return stdout/stderr/code.
|
||||
* Supports timeout and abort signal.
|
||||
*/
|
||||
export async function execCommand(
|
||||
command: string,
|
||||
args: string[],
|
||||
cwd: string,
|
||||
options?: ExecOptions,
|
||||
): Promise<ExecResult> {
|
||||
return new Promise((resolve) => {
|
||||
const proc = spawn(command, args, {
|
||||
cwd,
|
||||
shell: false,
|
||||
stdio: ["ignore", "pipe", "pipe"],
|
||||
});
|
||||
|
||||
let stdout = "";
|
||||
let stderr = "";
|
||||
let killed = false;
|
||||
let timeoutId: NodeJS.Timeout | undefined;
|
||||
|
||||
const killProcess = () => {
|
||||
if (!killed) {
|
||||
killed = true;
|
||||
proc.kill("SIGTERM");
|
||||
// Force kill after 5 seconds if SIGTERM doesn't work
|
||||
setTimeout(() => {
|
||||
if (!proc.killed) {
|
||||
proc.kill("SIGKILL");
|
||||
}
|
||||
}, 5000);
|
||||
}
|
||||
};
|
||||
|
||||
// Handle abort signal
|
||||
if (options?.signal) {
|
||||
if (options.signal.aborted) {
|
||||
killProcess();
|
||||
} else {
|
||||
options.signal.addEventListener("abort", killProcess, { once: true });
|
||||
}
|
||||
}
|
||||
|
||||
// Handle timeout
|
||||
if (options?.timeout && options.timeout > 0) {
|
||||
timeoutId = setTimeout(() => {
|
||||
killProcess();
|
||||
}, options.timeout);
|
||||
}
|
||||
|
||||
proc.stdout?.on("data", (data) => {
|
||||
stdout += data.toString();
|
||||
});
|
||||
|
||||
proc.stderr?.on("data", (data) => {
|
||||
stderr += data.toString();
|
||||
});
|
||||
|
||||
proc.on("close", (code) => {
|
||||
if (timeoutId) clearTimeout(timeoutId);
|
||||
if (options?.signal) {
|
||||
options.signal.removeEventListener("abort", killProcess);
|
||||
}
|
||||
resolve({ stdout, stderr, code: code ?? 0, killed });
|
||||
});
|
||||
|
||||
proc.on("error", (_err) => {
|
||||
if (timeoutId) clearTimeout(timeoutId);
|
||||
if (options?.signal) {
|
||||
options.signal.removeEventListener("abort", killProcess);
|
||||
}
|
||||
resolve({ stdout, stderr, code: 1, killed });
|
||||
});
|
||||
});
|
||||
}
|
||||
258
packages/pi-coding-agent/src/core/export-html/ansi-to-html.ts
Normal file
258
packages/pi-coding-agent/src/core/export-html/ansi-to-html.ts
Normal file
|
|
@ -0,0 +1,258 @@
|
|||
/**
|
||||
* ANSI escape code to HTML converter.
|
||||
*
|
||||
* Converts terminal ANSI color/style codes to HTML with inline styles.
|
||||
* Supports:
|
||||
* - Standard foreground colors (30-37) and bright variants (90-97)
|
||||
* - Standard background colors (40-47) and bright variants (100-107)
|
||||
* - 256-color palette (38;5;N and 48;5;N)
|
||||
* - RGB true color (38;2;R;G;B and 48;2;R;G;B)
|
||||
* - Text styles: bold (1), dim (2), italic (3), underline (4)
|
||||
* - Reset (0)
|
||||
*/
|
||||
|
||||
// Standard ANSI color palette (0-15)
|
||||
const ANSI_COLORS = [
|
||||
"#000000", // 0: black
|
||||
"#800000", // 1: red
|
||||
"#008000", // 2: green
|
||||
"#808000", // 3: yellow
|
||||
"#000080", // 4: blue
|
||||
"#800080", // 5: magenta
|
||||
"#008080", // 6: cyan
|
||||
"#c0c0c0", // 7: white
|
||||
"#808080", // 8: bright black
|
||||
"#ff0000", // 9: bright red
|
||||
"#00ff00", // 10: bright green
|
||||
"#ffff00", // 11: bright yellow
|
||||
"#0000ff", // 12: bright blue
|
||||
"#ff00ff", // 13: bright magenta
|
||||
"#00ffff", // 14: bright cyan
|
||||
"#ffffff", // 15: bright white
|
||||
];
|
||||
|
||||
/**
|
||||
* Convert 256-color index to hex.
|
||||
*/
|
||||
function color256ToHex(index: number): string {
|
||||
// Standard colors (0-15)
|
||||
if (index < 16) {
|
||||
return ANSI_COLORS[index];
|
||||
}
|
||||
|
||||
// Color cube (16-231): 6x6x6 = 216 colors
|
||||
if (index < 232) {
|
||||
const cubeIndex = index - 16;
|
||||
const r = Math.floor(cubeIndex / 36);
|
||||
const g = Math.floor((cubeIndex % 36) / 6);
|
||||
const b = cubeIndex % 6;
|
||||
const toComponent = (n: number) => (n === 0 ? 0 : 55 + n * 40);
|
||||
const toHex = (n: number) => toComponent(n).toString(16).padStart(2, "0");
|
||||
return `#${toHex(r)}${toHex(g)}${toHex(b)}`;
|
||||
}
|
||||
|
||||
// Grayscale (232-255): 24 shades
|
||||
const gray = 8 + (index - 232) * 10;
|
||||
const grayHex = gray.toString(16).padStart(2, "0");
|
||||
return `#${grayHex}${grayHex}${grayHex}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Escape HTML special characters.
|
||||
*/
|
||||
function escapeHtml(text: string): string {
|
||||
return text
|
||||
.replace(/&/g, "&")
|
||||
.replace(/</g, "<")
|
||||
.replace(/>/g, ">")
|
||||
.replace(/"/g, """)
|
||||
.replace(/'/g, "'");
|
||||
}
|
||||
|
||||
interface TextStyle {
|
||||
fg: string | null;
|
||||
bg: string | null;
|
||||
bold: boolean;
|
||||
dim: boolean;
|
||||
italic: boolean;
|
||||
underline: boolean;
|
||||
}
|
||||
|
||||
function createEmptyStyle(): TextStyle {
|
||||
return {
|
||||
fg: null,
|
||||
bg: null,
|
||||
bold: false,
|
||||
dim: false,
|
||||
italic: false,
|
||||
underline: false,
|
||||
};
|
||||
}
|
||||
|
||||
function styleToInlineCSS(style: TextStyle): string {
|
||||
const parts: string[] = [];
|
||||
if (style.fg) parts.push(`color:${style.fg}`);
|
||||
if (style.bg) parts.push(`background-color:${style.bg}`);
|
||||
if (style.bold) parts.push("font-weight:bold");
|
||||
if (style.dim) parts.push("opacity:0.6");
|
||||
if (style.italic) parts.push("font-style:italic");
|
||||
if (style.underline) parts.push("text-decoration:underline");
|
||||
return parts.join(";");
|
||||
}
|
||||
|
||||
function hasStyle(style: TextStyle): boolean {
|
||||
return style.fg !== null || style.bg !== null || style.bold || style.dim || style.italic || style.underline;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse ANSI SGR (Select Graphic Rendition) codes and update style.
|
||||
*/
|
||||
function applySgrCode(params: number[], style: TextStyle): void {
|
||||
let i = 0;
|
||||
while (i < params.length) {
|
||||
const code = params[i];
|
||||
|
||||
if (code === 0) {
|
||||
// Reset all
|
||||
style.fg = null;
|
||||
style.bg = null;
|
||||
style.bold = false;
|
||||
style.dim = false;
|
||||
style.italic = false;
|
||||
style.underline = false;
|
||||
} else if (code === 1) {
|
||||
style.bold = true;
|
||||
} else if (code === 2) {
|
||||
style.dim = true;
|
||||
} else if (code === 3) {
|
||||
style.italic = true;
|
||||
} else if (code === 4) {
|
||||
style.underline = true;
|
||||
} else if (code === 22) {
|
||||
// Reset bold/dim
|
||||
style.bold = false;
|
||||
style.dim = false;
|
||||
} else if (code === 23) {
|
||||
style.italic = false;
|
||||
} else if (code === 24) {
|
||||
style.underline = false;
|
||||
} else if (code >= 30 && code <= 37) {
|
||||
// Standard foreground colors
|
||||
style.fg = ANSI_COLORS[code - 30];
|
||||
} else if (code === 38) {
|
||||
// Extended foreground color
|
||||
if (params[i + 1] === 5 && params.length > i + 2) {
|
||||
// 256-color: 38;5;N
|
||||
style.fg = color256ToHex(params[i + 2]);
|
||||
i += 2;
|
||||
} else if (params[i + 1] === 2 && params.length > i + 4) {
|
||||
// RGB: 38;2;R;G;B
|
||||
const r = params[i + 2];
|
||||
const g = params[i + 3];
|
||||
const b = params[i + 4];
|
||||
style.fg = `rgb(${r},${g},${b})`;
|
||||
i += 4;
|
||||
}
|
||||
} else if (code === 39) {
|
||||
// Default foreground
|
||||
style.fg = null;
|
||||
} else if (code >= 40 && code <= 47) {
|
||||
// Standard background colors
|
||||
style.bg = ANSI_COLORS[code - 40];
|
||||
} else if (code === 48) {
|
||||
// Extended background color
|
||||
if (params[i + 1] === 5 && params.length > i + 2) {
|
||||
// 256-color: 48;5;N
|
||||
style.bg = color256ToHex(params[i + 2]);
|
||||
i += 2;
|
||||
} else if (params[i + 1] === 2 && params.length > i + 4) {
|
||||
// RGB: 48;2;R;G;B
|
||||
const r = params[i + 2];
|
||||
const g = params[i + 3];
|
||||
const b = params[i + 4];
|
||||
style.bg = `rgb(${r},${g},${b})`;
|
||||
i += 4;
|
||||
}
|
||||
} else if (code === 49) {
|
||||
// Default background
|
||||
style.bg = null;
|
||||
} else if (code >= 90 && code <= 97) {
|
||||
// Bright foreground colors
|
||||
style.fg = ANSI_COLORS[code - 90 + 8];
|
||||
} else if (code >= 100 && code <= 107) {
|
||||
// Bright background colors
|
||||
style.bg = ANSI_COLORS[code - 100 + 8];
|
||||
}
|
||||
// Ignore unrecognized codes
|
||||
|
||||
i++;
|
||||
}
|
||||
}
|
||||
|
||||
// Match ANSI escape sequences: ESC[ followed by params and ending with 'm'
|
||||
const ANSI_REGEX = /\x1b\[([\d;]*)m/g;
|
||||
|
||||
/**
|
||||
* Convert ANSI-escaped text to HTML with inline styles.
|
||||
*/
|
||||
export function ansiToHtml(text: string): string {
|
||||
const style = createEmptyStyle();
|
||||
let result = "";
|
||||
let lastIndex = 0;
|
||||
let inSpan = false;
|
||||
|
||||
// Reset regex state
|
||||
ANSI_REGEX.lastIndex = 0;
|
||||
|
||||
let match = ANSI_REGEX.exec(text);
|
||||
while (match !== null) {
|
||||
// Add text before this escape sequence
|
||||
const beforeText = text.slice(lastIndex, match.index);
|
||||
if (beforeText) {
|
||||
result += escapeHtml(beforeText);
|
||||
}
|
||||
|
||||
// Parse SGR parameters
|
||||
const paramStr = match[1];
|
||||
const params = paramStr ? paramStr.split(";").map((p) => parseInt(p, 10) || 0) : [0];
|
||||
|
||||
// Close existing span if we have one
|
||||
if (inSpan) {
|
||||
result += "</span>";
|
||||
inSpan = false;
|
||||
}
|
||||
|
||||
// Apply the codes
|
||||
applySgrCode(params, style);
|
||||
|
||||
// Open new span if we have any styling
|
||||
if (hasStyle(style)) {
|
||||
result += `<span style="${styleToInlineCSS(style)}">`;
|
||||
inSpan = true;
|
||||
}
|
||||
|
||||
lastIndex = match.index + match[0].length;
|
||||
match = ANSI_REGEX.exec(text);
|
||||
}
|
||||
|
||||
// Add remaining text
|
||||
const remainingText = text.slice(lastIndex);
|
||||
if (remainingText) {
|
||||
result += escapeHtml(remainingText);
|
||||
}
|
||||
|
||||
// Close any open span
|
||||
if (inSpan) {
|
||||
result += "</span>";
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert array of ANSI-escaped lines to HTML.
|
||||
* Each line is wrapped in a div element.
|
||||
*/
|
||||
export function ansiLinesToHtml(lines: string[]): string {
|
||||
return lines.map((line) => `<div class="ansi-line">${ansiToHtml(line) || " "}</div>`).join("\n");
|
||||
}
|
||||
306
packages/pi-coding-agent/src/core/export-html/index.ts
Normal file
306
packages/pi-coding-agent/src/core/export-html/index.ts
Normal file
|
|
@ -0,0 +1,306 @@
|
|||
import type { AgentState } from "@gsd/pi-agent-core";
|
||||
import { existsSync, readFileSync, writeFileSync } from "fs";
|
||||
import { basename, join } from "path";
|
||||
import { APP_NAME, getExportTemplateDir } from "../../config.js";
|
||||
import { getResolvedThemeColors, getThemeExportColors } from "../../modes/interactive/theme/theme.js";
|
||||
import type { ToolInfo } from "../extensions/types.js";
|
||||
import type { SessionEntry } from "../session-manager.js";
|
||||
import { SessionManager } from "../session-manager.js";
|
||||
|
||||
/**
|
||||
* Interface for rendering custom tools to HTML.
|
||||
* Used by agent-session to pre-render extension tool output.
|
||||
*/
|
||||
export interface ToolHtmlRenderer {
|
||||
/** Render a tool call to HTML. Returns undefined if tool has no custom renderer. */
|
||||
renderCall(toolName: string, args: unknown): string | undefined;
|
||||
/** Render a tool result to HTML. Returns collapsed/expanded or undefined if tool has no custom renderer. */
|
||||
renderResult(
|
||||
toolName: string,
|
||||
result: Array<{ type: string; text?: string; data?: string; mimeType?: string }>,
|
||||
details: unknown,
|
||||
isError: boolean,
|
||||
): { collapsed?: string; expanded?: string } | undefined;
|
||||
}
|
||||
|
||||
/** Pre-rendered HTML for a custom tool call and result */
|
||||
interface RenderedToolHtml {
|
||||
callHtml?: string;
|
||||
resultHtmlCollapsed?: string;
|
||||
resultHtmlExpanded?: string;
|
||||
}
|
||||
|
||||
export interface ExportOptions {
|
||||
outputPath?: string;
|
||||
themeName?: string;
|
||||
/** Optional tool renderer for custom tools */
|
||||
toolRenderer?: ToolHtmlRenderer;
|
||||
}
|
||||
|
||||
/** Parse a color string to RGB values. Supports hex (#RRGGBB) and rgb(r,g,b) formats. */
|
||||
function parseColor(color: string): { r: number; g: number; b: number } | undefined {
|
||||
const hexMatch = color.match(/^#([0-9a-fA-F]{2})([0-9a-fA-F]{2})([0-9a-fA-F]{2})$/);
|
||||
if (hexMatch) {
|
||||
return {
|
||||
r: Number.parseInt(hexMatch[1], 16),
|
||||
g: Number.parseInt(hexMatch[2], 16),
|
||||
b: Number.parseInt(hexMatch[3], 16),
|
||||
};
|
||||
}
|
||||
const rgbMatch = color.match(/^rgb\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\)$/);
|
||||
if (rgbMatch) {
|
||||
return {
|
||||
r: Number.parseInt(rgbMatch[1], 10),
|
||||
g: Number.parseInt(rgbMatch[2], 10),
|
||||
b: Number.parseInt(rgbMatch[3], 10),
|
||||
};
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/** Calculate relative luminance of a color (0-1, higher = lighter). */
|
||||
function getLuminance(r: number, g: number, b: number): number {
|
||||
const toLinear = (c: number) => {
|
||||
const s = c / 255;
|
||||
return s <= 0.03928 ? s / 12.92 : ((s + 0.055) / 1.055) ** 2.4;
|
||||
};
|
||||
return 0.2126 * toLinear(r) + 0.7152 * toLinear(g) + 0.0722 * toLinear(b);
|
||||
}
|
||||
|
||||
/** Adjust color brightness. Factor > 1 lightens, < 1 darkens. */
|
||||
function adjustBrightness(color: string, factor: number): string {
|
||||
const parsed = parseColor(color);
|
||||
if (!parsed) return color;
|
||||
const adjust = (c: number) => Math.min(255, Math.max(0, Math.round(c * factor)));
|
||||
return `rgb(${adjust(parsed.r)}, ${adjust(parsed.g)}, ${adjust(parsed.b)})`;
|
||||
}
|
||||
|
||||
/** Derive export background colors from a base color (e.g., userMessageBg). */
|
||||
function deriveExportColors(baseColor: string): { pageBg: string; cardBg: string; infoBg: string } {
|
||||
const parsed = parseColor(baseColor);
|
||||
if (!parsed) {
|
||||
return {
|
||||
pageBg: "rgb(24, 24, 30)",
|
||||
cardBg: "rgb(30, 30, 36)",
|
||||
infoBg: "rgb(60, 55, 40)",
|
||||
};
|
||||
}
|
||||
|
||||
const luminance = getLuminance(parsed.r, parsed.g, parsed.b);
|
||||
const isLight = luminance > 0.5;
|
||||
|
||||
if (isLight) {
|
||||
return {
|
||||
pageBg: adjustBrightness(baseColor, 0.96),
|
||||
cardBg: baseColor,
|
||||
infoBg: `rgb(${Math.min(255, parsed.r + 10)}, ${Math.min(255, parsed.g + 5)}, ${Math.max(0, parsed.b - 20)})`,
|
||||
};
|
||||
}
|
||||
return {
|
||||
pageBg: adjustBrightness(baseColor, 0.7),
|
||||
cardBg: adjustBrightness(baseColor, 0.85),
|
||||
infoBg: `rgb(${Math.min(255, parsed.r + 20)}, ${Math.min(255, parsed.g + 15)}, ${parsed.b})`,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate CSS custom property declarations from theme colors.
|
||||
*/
|
||||
function generateThemeVars(themeName?: string): string {
|
||||
const colors = getResolvedThemeColors(themeName);
|
||||
const lines: string[] = [];
|
||||
for (const [key, value] of Object.entries(colors)) {
|
||||
lines.push(`--${key}: ${value};`);
|
||||
}
|
||||
|
||||
// Use explicit theme export colors if available, otherwise derive from userMessageBg
|
||||
const themeExport = getThemeExportColors(themeName);
|
||||
const userMessageBg = colors.userMessageBg || "#343541";
|
||||
const derivedColors = deriveExportColors(userMessageBg);
|
||||
|
||||
lines.push(`--exportPageBg: ${themeExport.pageBg ?? derivedColors.pageBg};`);
|
||||
lines.push(`--exportCardBg: ${themeExport.cardBg ?? derivedColors.cardBg};`);
|
||||
lines.push(`--exportInfoBg: ${themeExport.infoBg ?? derivedColors.infoBg};`);
|
||||
|
||||
return lines.join("\n ");
|
||||
}
|
||||
|
||||
interface SessionData {
|
||||
header: ReturnType<SessionManager["getHeader"]>;
|
||||
entries: ReturnType<SessionManager["getEntries"]>;
|
||||
leafId: string | null;
|
||||
systemPrompt?: string;
|
||||
tools?: ToolInfo[];
|
||||
/** Pre-rendered HTML for custom tool calls/results, keyed by tool call ID */
|
||||
renderedTools?: Record<string, RenderedToolHtml>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Core HTML generation logic shared by both export functions.
|
||||
*/
|
||||
function generateHtml(sessionData: SessionData, themeName?: string): string {
|
||||
const templateDir = getExportTemplateDir();
|
||||
const template = readFileSync(join(templateDir, "template.html"), "utf-8");
|
||||
const templateCss = readFileSync(join(templateDir, "template.css"), "utf-8");
|
||||
const templateJs = readFileSync(join(templateDir, "template.js"), "utf-8");
|
||||
const markedJs = readFileSync(join(templateDir, "vendor", "marked.min.js"), "utf-8");
|
||||
const hljsJs = readFileSync(join(templateDir, "vendor", "highlight.min.js"), "utf-8");
|
||||
|
||||
const themeVars = generateThemeVars(themeName);
|
||||
const colors = getResolvedThemeColors(themeName);
|
||||
const exportColors = deriveExportColors(colors.userMessageBg || "#343541");
|
||||
const bodyBg = exportColors.pageBg;
|
||||
const containerBg = exportColors.cardBg;
|
||||
const infoBg = exportColors.infoBg;
|
||||
|
||||
// Base64 encode session data to avoid escaping issues
|
||||
const sessionDataBase64 = Buffer.from(JSON.stringify(sessionData)).toString("base64");
|
||||
|
||||
// Build the CSS with theme variables injected
|
||||
const css = templateCss
|
||||
.replace("{{THEME_VARS}}", themeVars)
|
||||
.replace("{{BODY_BG}}", bodyBg)
|
||||
.replace("{{CONTAINER_BG}}", containerBg)
|
||||
.replace("{{INFO_BG}}", infoBg);
|
||||
|
||||
return template
|
||||
.replace("{{CSS}}", css)
|
||||
.replace("{{JS}}", templateJs)
|
||||
.replace("{{SESSION_DATA}}", sessionDataBase64)
|
||||
.replace("{{MARKED_JS}}", markedJs)
|
||||
.replace("{{HIGHLIGHT_JS}}", hljsJs);
|
||||
}
|
||||
|
||||
/** Built-in tool names that have custom rendering in template.js */
|
||||
const BUILTIN_TOOLS = new Set(["bash", "read", "write", "edit", "ls", "find", "grep"]);
|
||||
|
||||
/**
|
||||
* Pre-render custom tools to HTML using their TUI renderers.
|
||||
*/
|
||||
function preRenderCustomTools(
|
||||
entries: SessionEntry[],
|
||||
toolRenderer: ToolHtmlRenderer,
|
||||
): Record<string, RenderedToolHtml> {
|
||||
const renderedTools: Record<string, RenderedToolHtml> = {};
|
||||
|
||||
for (const entry of entries) {
|
||||
if (entry.type !== "message") continue;
|
||||
const msg = entry.message;
|
||||
|
||||
// Find tool calls in assistant messages
|
||||
if (msg.role === "assistant" && Array.isArray(msg.content)) {
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "toolCall" && !BUILTIN_TOOLS.has(block.name)) {
|
||||
const callHtml = toolRenderer.renderCall(block.name, block.arguments);
|
||||
if (callHtml) {
|
||||
renderedTools[block.id] = { callHtml };
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Find tool results
|
||||
if (msg.role === "toolResult" && msg.toolCallId) {
|
||||
const toolName = msg.toolName || "";
|
||||
// Only render if we have a pre-rendered call OR it's not a built-in tool
|
||||
const existing = renderedTools[msg.toolCallId];
|
||||
if (existing || !BUILTIN_TOOLS.has(toolName)) {
|
||||
const rendered = toolRenderer.renderResult(toolName, msg.content, msg.details, msg.isError || false);
|
||||
if (rendered) {
|
||||
renderedTools[msg.toolCallId] = {
|
||||
...existing,
|
||||
resultHtmlCollapsed: rendered.collapsed,
|
||||
resultHtmlExpanded: rendered.expanded,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return renderedTools;
|
||||
}
|
||||
|
||||
/**
|
||||
* Export session to HTML using SessionManager and AgentState.
|
||||
* Used by TUI's /export command.
|
||||
*/
|
||||
export async function exportSessionToHtml(
|
||||
sm: SessionManager,
|
||||
state?: AgentState,
|
||||
options?: ExportOptions | string,
|
||||
): Promise<string> {
|
||||
const opts: ExportOptions = typeof options === "string" ? { outputPath: options } : options || {};
|
||||
|
||||
const sessionFile = sm.getSessionFile();
|
||||
if (!sessionFile) {
|
||||
throw new Error("Cannot export in-memory session to HTML");
|
||||
}
|
||||
if (!existsSync(sessionFile)) {
|
||||
throw new Error("Nothing to export yet - start a conversation first");
|
||||
}
|
||||
|
||||
const entries = sm.getEntries();
|
||||
|
||||
// Pre-render custom tools if a tool renderer is provided
|
||||
let renderedTools: Record<string, RenderedToolHtml> | undefined;
|
||||
if (opts.toolRenderer) {
|
||||
renderedTools = preRenderCustomTools(entries, opts.toolRenderer);
|
||||
// Only include if we actually rendered something
|
||||
if (Object.keys(renderedTools).length === 0) {
|
||||
renderedTools = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
const sessionData: SessionData = {
|
||||
header: sm.getHeader(),
|
||||
entries,
|
||||
leafId: sm.getLeafId(),
|
||||
systemPrompt: state?.systemPrompt,
|
||||
tools: state?.tools?.map((t) => ({ name: t.name, description: t.description, parameters: t.parameters })),
|
||||
renderedTools,
|
||||
};
|
||||
|
||||
const html = generateHtml(sessionData, opts.themeName);
|
||||
|
||||
let outputPath = opts.outputPath;
|
||||
if (!outputPath) {
|
||||
const sessionBasename = basename(sessionFile, ".jsonl");
|
||||
outputPath = `${APP_NAME}-session-${sessionBasename}.html`;
|
||||
}
|
||||
|
||||
writeFileSync(outputPath, html, "utf8");
|
||||
return outputPath;
|
||||
}
|
||||
|
||||
/**
|
||||
* Export session file to HTML (standalone, without AgentState).
|
||||
* Used by CLI for exporting arbitrary session files.
|
||||
*/
|
||||
export async function exportFromFile(inputPath: string, options?: ExportOptions | string): Promise<string> {
|
||||
const opts: ExportOptions = typeof options === "string" ? { outputPath: options } : options || {};
|
||||
|
||||
if (!existsSync(inputPath)) {
|
||||
throw new Error(`File not found: ${inputPath}`);
|
||||
}
|
||||
|
||||
const sm = SessionManager.open(inputPath);
|
||||
|
||||
const sessionData: SessionData = {
|
||||
header: sm.getHeader(),
|
||||
entries: sm.getEntries(),
|
||||
leafId: sm.getLeafId(),
|
||||
systemPrompt: undefined,
|
||||
tools: undefined,
|
||||
};
|
||||
|
||||
const html = generateHtml(sessionData, opts.themeName);
|
||||
|
||||
let outputPath = opts.outputPath;
|
||||
if (!outputPath) {
|
||||
const inputBasename = basename(inputPath, ".jsonl");
|
||||
outputPath = `${APP_NAME}-session-${inputBasename}.html`;
|
||||
}
|
||||
|
||||
writeFileSync(outputPath, html, "utf8");
|
||||
return outputPath;
|
||||
}
|
||||
971
packages/pi-coding-agent/src/core/export-html/template.css
Normal file
971
packages/pi-coding-agent/src/core/export-html/template.css
Normal file
|
|
@ -0,0 +1,971 @@
|
|||
:root {
|
||||
{{THEME_VARS}}
|
||||
--body-bg: {{BODY_BG}};
|
||||
--container-bg: {{CONTAINER_BG}};
|
||||
--info-bg: {{INFO_BG}};
|
||||
}
|
||||
|
||||
* { margin: 0; padding: 0; box-sizing: border-box; }
|
||||
|
||||
:root {
|
||||
--line-height: 18px; /* 12px font * 1.5 */
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: ui-monospace, 'Cascadia Code', 'Source Code Pro', Menlo, Consolas, 'DejaVu Sans Mono', monospace;
|
||||
font-size: 12px;
|
||||
line-height: var(--line-height);
|
||||
color: var(--text);
|
||||
background: var(--body-bg);
|
||||
}
|
||||
|
||||
#app {
|
||||
display: flex;
|
||||
min-height: 100vh;
|
||||
}
|
||||
|
||||
/* Sidebar */
|
||||
#sidebar {
|
||||
width: 400px;
|
||||
background: var(--container-bg);
|
||||
flex-shrink: 0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
position: sticky;
|
||||
top: 0;
|
||||
height: 100vh;
|
||||
border-right: 1px solid var(--dim);
|
||||
}
|
||||
|
||||
.sidebar-header {
|
||||
padding: 8px 12px;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.sidebar-controls {
|
||||
padding: 8px 8px 4px 8px;
|
||||
}
|
||||
|
||||
.sidebar-search {
|
||||
width: 100%;
|
||||
box-sizing: border-box;
|
||||
padding: 4px 8px;
|
||||
font-size: 11px;
|
||||
font-family: inherit;
|
||||
background: var(--body-bg);
|
||||
color: var(--text);
|
||||
border: 1px solid var(--dim);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.sidebar-filters {
|
||||
display: flex;
|
||||
padding: 4px 8px 8px 8px;
|
||||
gap: 4px;
|
||||
align-items: center;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.sidebar-search:focus {
|
||||
outline: none;
|
||||
border-color: var(--accent);
|
||||
}
|
||||
|
||||
.sidebar-search::placeholder {
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
.filter-btn {
|
||||
padding: 3px 8px;
|
||||
font-size: 10px;
|
||||
font-family: inherit;
|
||||
background: transparent;
|
||||
color: var(--muted);
|
||||
border: 1px solid var(--dim);
|
||||
border-radius: 3px;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.filter-btn:hover {
|
||||
color: var(--text);
|
||||
border-color: var(--text);
|
||||
}
|
||||
|
||||
.filter-btn.active {
|
||||
background: var(--accent);
|
||||
color: var(--body-bg);
|
||||
border-color: var(--accent);
|
||||
}
|
||||
|
||||
.sidebar-close {
|
||||
display: none;
|
||||
padding: 3px 8px;
|
||||
font-size: 12px;
|
||||
font-family: inherit;
|
||||
background: transparent;
|
||||
color: var(--muted);
|
||||
border: 1px solid var(--dim);
|
||||
border-radius: 3px;
|
||||
cursor: pointer;
|
||||
margin-left: auto;
|
||||
}
|
||||
|
||||
.sidebar-close:hover {
|
||||
color: var(--text);
|
||||
border-color: var(--text);
|
||||
}
|
||||
|
||||
.tree-container {
|
||||
flex: 1;
|
||||
overflow: auto;
|
||||
padding: 4px 0;
|
||||
}
|
||||
|
||||
.tree-node {
|
||||
padding: 0 8px;
|
||||
cursor: pointer;
|
||||
display: flex;
|
||||
align-items: baseline;
|
||||
font-size: 11px;
|
||||
line-height: 13px;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.tree-node:hover {
|
||||
background: var(--selectedBg);
|
||||
}
|
||||
|
||||
.tree-node.active {
|
||||
background: var(--selectedBg);
|
||||
}
|
||||
|
||||
.tree-node.active .tree-content {
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.tree-node.in-path {
|
||||
background: color-mix(in srgb, var(--accent) 10%, transparent);
|
||||
}
|
||||
|
||||
.tree-node:not(.in-path) {
|
||||
opacity: 0.5;
|
||||
}
|
||||
|
||||
.tree-node:not(.in-path):hover {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.tree-prefix {
|
||||
color: var(--muted);
|
||||
flex-shrink: 0;
|
||||
font-family: monospace;
|
||||
white-space: pre;
|
||||
}
|
||||
|
||||
.tree-marker {
|
||||
color: var(--accent);
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.tree-content {
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
.tree-role-user {
|
||||
color: var(--accent);
|
||||
}
|
||||
|
||||
.tree-role-assistant {
|
||||
color: var(--success);
|
||||
}
|
||||
|
||||
.tree-role-tool {
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
.tree-muted {
|
||||
color: var(--muted);
|
||||
}
|
||||
|
||||
.tree-error {
|
||||
color: var(--error);
|
||||
}
|
||||
|
||||
.tree-compaction {
|
||||
color: var(--borderAccent);
|
||||
}
|
||||
|
||||
.tree-branch-summary {
|
||||
color: var(--warning);
|
||||
}
|
||||
|
||||
.tree-custom-message {
|
||||
color: var(--customMessageLabel);
|
||||
}
|
||||
|
||||
.tree-status {
|
||||
padding: 4px 12px;
|
||||
font-size: 10px;
|
||||
color: var(--muted);
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
/* Main content */
|
||||
#content {
|
||||
flex: 1;
|
||||
overflow-y: auto;
|
||||
padding: var(--line-height) calc(var(--line-height) * 2);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
}
|
||||
|
||||
#content > * {
|
||||
width: 100%;
|
||||
max-width: 800px;
|
||||
}
|
||||
|
||||
/* Help bar */
|
||||
.help-bar {
|
||||
font-size: 11px;
|
||||
color: var(--warning);
|
||||
margin-bottom: var(--line-height);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
}
|
||||
|
||||
.download-json-btn {
|
||||
font-size: 10px;
|
||||
padding: 2px 8px;
|
||||
background: var(--container-bg);
|
||||
border: 1px solid var(--border);
|
||||
border-radius: 3px;
|
||||
color: var(--text);
|
||||
cursor: pointer;
|
||||
font-family: inherit;
|
||||
}
|
||||
|
||||
.download-json-btn:hover {
|
||||
background: var(--hover);
|
||||
border-color: var(--borderAccent);
|
||||
}
|
||||
|
||||
/* Header */
|
||||
.header {
|
||||
background: var(--container-bg);
|
||||
border-radius: 4px;
|
||||
padding: var(--line-height);
|
||||
margin-bottom: var(--line-height);
|
||||
}
|
||||
|
||||
.header h1 {
|
||||
font-size: 12px;
|
||||
font-weight: bold;
|
||||
color: var(--borderAccent);
|
||||
margin-bottom: var(--line-height);
|
||||
}
|
||||
|
||||
.header-info {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0;
|
||||
font-size: 11px;
|
||||
}
|
||||
|
||||
.info-item {
|
||||
color: var(--dim);
|
||||
display: flex;
|
||||
align-items: baseline;
|
||||
}
|
||||
|
||||
.info-label {
|
||||
font-weight: 600;
|
||||
margin-right: 8px;
|
||||
min-width: 100px;
|
||||
}
|
||||
|
||||
.info-value {
|
||||
color: var(--text);
|
||||
flex: 1;
|
||||
}
|
||||
|
||||
/* Messages */
|
||||
#messages {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: var(--line-height);
|
||||
}
|
||||
|
||||
.message-timestamp {
|
||||
font-size: 10px;
|
||||
color: var(--dim);
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.user-message {
|
||||
background: var(--userMessageBg);
|
||||
color: var(--userMessageText);
|
||||
padding: var(--line-height);
|
||||
border-radius: 4px;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.assistant-message {
|
||||
padding: 0;
|
||||
position: relative;
|
||||
}
|
||||
|
||||
/* Copy link button - appears on hover */
|
||||
.copy-link-btn {
|
||||
position: absolute;
|
||||
top: 8px;
|
||||
right: 8px;
|
||||
width: 28px;
|
||||
height: 28px;
|
||||
padding: 6px;
|
||||
background: var(--container-bg);
|
||||
border: 1px solid var(--dim);
|
||||
border-radius: 4px;
|
||||
color: var(--muted);
|
||||
cursor: pointer;
|
||||
opacity: 0;
|
||||
transition: opacity 0.15s, background 0.15s, color 0.15s;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
z-index: 10;
|
||||
}
|
||||
|
||||
.user-message:hover .copy-link-btn,
|
||||
.assistant-message:hover .copy-link-btn {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.copy-link-btn:hover {
|
||||
background: var(--accent);
|
||||
color: var(--body-bg);
|
||||
border-color: var(--accent);
|
||||
}
|
||||
|
||||
.copy-link-btn.copied {
|
||||
background: var(--success, #22c55e);
|
||||
color: white;
|
||||
border-color: var(--success, #22c55e);
|
||||
}
|
||||
|
||||
/* Highlight effect for deep-linked messages */
|
||||
.user-message.highlight,
|
||||
.assistant-message.highlight {
|
||||
animation: highlight-pulse 2s ease-out;
|
||||
}
|
||||
|
||||
@keyframes highlight-pulse {
|
||||
0% {
|
||||
box-shadow: 0 0 0 3px var(--accent);
|
||||
}
|
||||
100% {
|
||||
box-shadow: 0 0 0 0 transparent;
|
||||
}
|
||||
}
|
||||
|
||||
.assistant-message > .message-timestamp {
|
||||
padding-left: var(--line-height);
|
||||
}
|
||||
|
||||
.assistant-text {
|
||||
padding: var(--line-height);
|
||||
padding-bottom: 0;
|
||||
}
|
||||
|
||||
.message-timestamp + .assistant-text,
|
||||
.message-timestamp + .thinking-block {
|
||||
padding-top: 0;
|
||||
}
|
||||
|
||||
.thinking-block + .assistant-text {
|
||||
padding-top: 0;
|
||||
}
|
||||
|
||||
.thinking-text {
|
||||
padding: var(--line-height);
|
||||
color: var(--thinkingText);
|
||||
font-style: italic;
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
.message-timestamp + .thinking-block .thinking-text,
|
||||
.message-timestamp + .thinking-block .thinking-collapsed {
|
||||
padding-top: 0;
|
||||
}
|
||||
|
||||
.thinking-collapsed {
|
||||
display: none;
|
||||
padding: var(--line-height);
|
||||
color: var(--thinkingText);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
/* Tool execution */
|
||||
.tool-execution {
|
||||
padding: var(--line-height);
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.tool-execution + .tool-execution {
|
||||
margin-top: var(--line-height);
|
||||
}
|
||||
|
||||
.assistant-text + .tool-execution {
|
||||
margin-top: var(--line-height);
|
||||
}
|
||||
|
||||
.tool-execution.pending { background: var(--toolPendingBg); }
|
||||
.tool-execution.success { background: var(--toolSuccessBg); }
|
||||
.tool-execution.error { background: var(--toolErrorBg); }
|
||||
|
||||
.tool-header, .tool-name {
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.tool-path {
|
||||
color: var(--accent);
|
||||
word-break: break-all;
|
||||
}
|
||||
|
||||
.line-numbers {
|
||||
color: var(--warning);
|
||||
}
|
||||
|
||||
.line-count {
|
||||
color: var(--dim);
|
||||
}
|
||||
|
||||
.tool-command {
|
||||
font-weight: bold;
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
overflow-wrap: break-word;
|
||||
word-break: break-word;
|
||||
}
|
||||
|
||||
.tool-output {
|
||||
margin-top: var(--line-height);
|
||||
color: var(--toolOutput);
|
||||
word-wrap: break-word;
|
||||
overflow-wrap: break-word;
|
||||
word-break: break-word;
|
||||
font-family: inherit;
|
||||
overflow-x: auto;
|
||||
}
|
||||
|
||||
.tool-output > div,
|
||||
.output-preview,
|
||||
.output-full {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
line-height: var(--line-height);
|
||||
}
|
||||
|
||||
.tool-output pre {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
font-family: inherit;
|
||||
color: inherit;
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
overflow-wrap: break-word;
|
||||
}
|
||||
|
||||
.tool-output code {
|
||||
padding: 0;
|
||||
background: none;
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
.tool-output.expandable {
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.tool-output.expandable:hover {
|
||||
opacity: 0.9;
|
||||
}
|
||||
|
||||
.tool-output.expandable .output-full {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.tool-output.expandable.expanded .output-preview {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.tool-output.expandable.expanded .output-full {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.ansi-line {
|
||||
white-space: pre-wrap;
|
||||
}
|
||||
|
||||
.tool-images {
|
||||
}
|
||||
|
||||
.tool-image {
|
||||
max-width: 100%;
|
||||
max-height: 500px;
|
||||
border-radius: 4px;
|
||||
margin: var(--line-height) 0;
|
||||
}
|
||||
|
||||
.expand-hint {
|
||||
color: var(--toolOutput);
|
||||
}
|
||||
|
||||
/* Diff */
|
||||
.tool-diff {
|
||||
font-size: 11px;
|
||||
overflow-x: auto;
|
||||
white-space: pre;
|
||||
}
|
||||
|
||||
.diff-added { color: var(--toolDiffAdded); }
|
||||
.diff-removed { color: var(--toolDiffRemoved); }
|
||||
.diff-context { color: var(--toolDiffContext); }
|
||||
|
||||
/* Model change */
|
||||
.model-change {
|
||||
padding: 0 var(--line-height);
|
||||
color: var(--dim);
|
||||
font-size: 11px;
|
||||
}
|
||||
|
||||
.model-name {
|
||||
color: var(--borderAccent);
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
/* Compaction / Branch Summary - matches customMessage colors from TUI */
|
||||
.compaction {
|
||||
background: var(--customMessageBg);
|
||||
border-radius: 4px;
|
||||
padding: var(--line-height);
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.compaction-label {
|
||||
color: var(--customMessageLabel);
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.compaction-collapsed {
|
||||
color: var(--customMessageText);
|
||||
}
|
||||
|
||||
.compaction-content {
|
||||
display: none;
|
||||
color: var(--customMessageText);
|
||||
white-space: pre-wrap;
|
||||
margin-top: var(--line-height);
|
||||
}
|
||||
|
||||
.compaction.expanded .compaction-collapsed {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.compaction.expanded .compaction-content {
|
||||
display: block;
|
||||
}
|
||||
|
||||
/* System prompt */
|
||||
.system-prompt {
|
||||
background: var(--customMessageBg);
|
||||
padding: var(--line-height);
|
||||
border-radius: 4px;
|
||||
margin-bottom: var(--line-height);
|
||||
}
|
||||
|
||||
.system-prompt.expandable {
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.system-prompt-header {
|
||||
font-weight: bold;
|
||||
color: var(--customMessageLabel);
|
||||
}
|
||||
|
||||
.system-prompt-preview {
|
||||
color: var(--customMessageText);
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
font-size: 11px;
|
||||
margin-top: var(--line-height);
|
||||
}
|
||||
|
||||
.system-prompt-expand-hint {
|
||||
color: var(--muted);
|
||||
font-style: italic;
|
||||
margin-top: 4px;
|
||||
}
|
||||
|
||||
.system-prompt-full {
|
||||
display: none;
|
||||
color: var(--customMessageText);
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
font-size: 11px;
|
||||
margin-top: var(--line-height);
|
||||
}
|
||||
|
||||
.system-prompt.expanded .system-prompt-preview,
|
||||
.system-prompt.expanded .system-prompt-expand-hint {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.system-prompt.expanded .system-prompt-full {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.system-prompt.provider-prompt {
|
||||
border-left: 3px solid var(--warning);
|
||||
}
|
||||
|
||||
.system-prompt-note {
|
||||
font-size: 10px;
|
||||
font-style: italic;
|
||||
color: var(--muted);
|
||||
margin-top: 4px;
|
||||
}
|
||||
|
||||
/* Tools list */
|
||||
.tools-list {
|
||||
background: var(--customMessageBg);
|
||||
padding: var(--line-height);
|
||||
border-radius: 4px;
|
||||
margin-bottom: var(--line-height);
|
||||
}
|
||||
|
||||
.tools-header {
|
||||
font-weight: bold;
|
||||
color: var(--customMessageLabel);
|
||||
margin-bottom: var(--line-height);
|
||||
}
|
||||
|
||||
.tool-item {
|
||||
font-size: 11px;
|
||||
}
|
||||
|
||||
.tool-item-name {
|
||||
font-weight: bold;
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
.tool-item-desc {
|
||||
color: var(--dim);
|
||||
}
|
||||
|
||||
.tool-params-hint {
|
||||
color: var(--muted);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.tool-item:has(.tool-params-hint) {
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.tool-params-hint::after {
|
||||
content: '[click to show parameters]';
|
||||
}
|
||||
|
||||
.tool-item.params-expanded .tool-params-hint::after {
|
||||
content: '[hide parameters]';
|
||||
}
|
||||
|
||||
.tool-params-content {
|
||||
display: none;
|
||||
margin-top: 4px;
|
||||
margin-left: 12px;
|
||||
padding-left: 8px;
|
||||
border-left: 1px solid var(--dim);
|
||||
}
|
||||
|
||||
.tool-item.params-expanded .tool-params-content {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.tool-param {
|
||||
margin-bottom: 4px;
|
||||
font-size: 11px;
|
||||
}
|
||||
|
||||
.tool-param-name {
|
||||
font-weight: bold;
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
.tool-param-type {
|
||||
color: var(--dim);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.tool-param-required {
|
||||
color: var(--warning, #e8a838);
|
||||
font-size: 10px;
|
||||
}
|
||||
|
||||
.tool-param-optional {
|
||||
color: var(--dim);
|
||||
font-size: 10px;
|
||||
}
|
||||
|
||||
.tool-param-desc {
|
||||
color: var(--dim);
|
||||
margin-left: 8px;
|
||||
}
|
||||
|
||||
/* Hook/custom messages */
|
||||
.hook-message {
|
||||
background: var(--customMessageBg);
|
||||
color: var(--customMessageText);
|
||||
padding: var(--line-height);
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.hook-type {
|
||||
color: var(--customMessageLabel);
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
/* Branch summary */
|
||||
.branch-summary {
|
||||
background: var(--customMessageBg);
|
||||
padding: var(--line-height);
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
.branch-summary-header {
|
||||
font-weight: bold;
|
||||
color: var(--borderAccent);
|
||||
}
|
||||
|
||||
/* Error */
|
||||
.error-text {
|
||||
color: var(--error);
|
||||
padding: 0 var(--line-height);
|
||||
}
|
||||
.tool-error {
|
||||
color: var(--error);
|
||||
}
|
||||
|
||||
/* Images */
|
||||
.message-images {
|
||||
margin-bottom: 12px;
|
||||
}
|
||||
|
||||
.message-image {
|
||||
max-width: 100%;
|
||||
max-height: 400px;
|
||||
border-radius: 4px;
|
||||
margin: var(--line-height) 0;
|
||||
}
|
||||
|
||||
/* Markdown content */
|
||||
.markdown-content h1,
|
||||
.markdown-content h2,
|
||||
.markdown-content h3,
|
||||
.markdown-content h4,
|
||||
.markdown-content h5,
|
||||
.markdown-content h6 {
|
||||
color: var(--mdHeading);
|
||||
margin: var(--line-height) 0 0 0;
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.markdown-content h1 { font-size: 1em; }
|
||||
.markdown-content h2 { font-size: 1em; }
|
||||
.markdown-content h3 { font-size: 1em; }
|
||||
.markdown-content h4 { font-size: 1em; }
|
||||
.markdown-content h5 { font-size: 1em; }
|
||||
.markdown-content h6 { font-size: 1em; }
|
||||
.markdown-content p { margin: 0; }
|
||||
.markdown-content p + p { margin-top: var(--line-height); }
|
||||
|
||||
.markdown-content a {
|
||||
color: var(--mdLink);
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
.markdown-content code {
|
||||
background: rgba(128, 128, 128, 0.2);
|
||||
color: var(--mdCode);
|
||||
padding: 0 4px;
|
||||
border-radius: 3px;
|
||||
font-family: inherit;
|
||||
}
|
||||
|
||||
.markdown-content pre {
|
||||
background: transparent;
|
||||
margin: var(--line-height) 0;
|
||||
overflow-x: auto;
|
||||
}
|
||||
|
||||
.markdown-content pre code {
|
||||
display: block;
|
||||
background: none;
|
||||
color: var(--text);
|
||||
}
|
||||
|
||||
.markdown-content blockquote {
|
||||
border-left: 3px solid var(--mdQuoteBorder);
|
||||
padding-left: var(--line-height);
|
||||
margin: var(--line-height) 0;
|
||||
color: var(--mdQuote);
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.markdown-content ul,
|
||||
.markdown-content ol {
|
||||
margin: var(--line-height) 0;
|
||||
padding-left: calc(var(--line-height) * 2);
|
||||
}
|
||||
|
||||
.markdown-content li { margin: 0; }
|
||||
.markdown-content li::marker { color: var(--mdListBullet); }
|
||||
|
||||
.markdown-content hr {
|
||||
border: none;
|
||||
border-top: 1px solid var(--mdHr);
|
||||
margin: var(--line-height) 0;
|
||||
}
|
||||
|
||||
.markdown-content table {
|
||||
border-collapse: collapse;
|
||||
margin: 0.5em 0;
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.markdown-content th,
|
||||
.markdown-content td {
|
||||
border: 1px solid var(--mdCodeBlockBorder);
|
||||
padding: 6px 10px;
|
||||
text-align: left;
|
||||
}
|
||||
|
||||
.markdown-content th {
|
||||
background: rgba(128, 128, 128, 0.1);
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.markdown-content img {
|
||||
max-width: 100%;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
/* Syntax highlighting */
|
||||
.hljs { background: transparent; color: var(--text); }
|
||||
.hljs-comment, .hljs-quote { color: var(--syntaxComment); }
|
||||
.hljs-keyword, .hljs-selector-tag { color: var(--syntaxKeyword); }
|
||||
.hljs-number, .hljs-literal { color: var(--syntaxNumber); }
|
||||
.hljs-string, .hljs-doctag { color: var(--syntaxString); }
|
||||
/* Function names: hljs v11 uses .hljs-title.function_ compound class */
|
||||
.hljs-function, .hljs-title, .hljs-title.function_, .hljs-section, .hljs-name { color: var(--syntaxFunction); }
|
||||
/* Types: hljs v11 uses .hljs-title.class_ for class names */
|
||||
.hljs-type, .hljs-class, .hljs-title.class_, .hljs-built_in { color: var(--syntaxType); }
|
||||
.hljs-attr, .hljs-variable, .hljs-variable.language_, .hljs-params, .hljs-property { color: var(--syntaxVariable); }
|
||||
.hljs-meta, .hljs-meta .hljs-keyword, .hljs-meta .hljs-string { color: var(--syntaxKeyword); }
|
||||
.hljs-operator { color: var(--syntaxOperator); }
|
||||
.hljs-punctuation { color: var(--syntaxPunctuation); }
|
||||
.hljs-subst { color: var(--text); }
|
||||
|
||||
/* Footer */
|
||||
.footer {
|
||||
margin-top: 48px;
|
||||
padding: 20px;
|
||||
text-align: center;
|
||||
color: var(--dim);
|
||||
font-size: 10px;
|
||||
}
|
||||
|
||||
/* Mobile */
|
||||
#hamburger {
|
||||
display: none;
|
||||
position: fixed;
|
||||
top: 10px;
|
||||
left: 10px;
|
||||
z-index: 100;
|
||||
padding: 3px 8px;
|
||||
font-size: 12px;
|
||||
font-family: inherit;
|
||||
background: transparent;
|
||||
color: var(--muted);
|
||||
border: 1px solid var(--dim);
|
||||
border-radius: 3px;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
#hamburger:hover {
|
||||
color: var(--text);
|
||||
border-color: var(--text);
|
||||
}
|
||||
|
||||
|
||||
|
||||
#sidebar-overlay {
|
||||
display: none;
|
||||
position: fixed;
|
||||
top: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
bottom: 0;
|
||||
background: rgba(0, 0, 0, 0.5);
|
||||
z-index: 98;
|
||||
}
|
||||
|
||||
@media (max-width: 900px) {
|
||||
#sidebar {
|
||||
position: fixed;
|
||||
left: -400px;
|
||||
width: 400px;
|
||||
top: 0;
|
||||
bottom: 0;
|
||||
height: 100vh;
|
||||
z-index: 99;
|
||||
transition: left 0.3s;
|
||||
}
|
||||
|
||||
#sidebar.open {
|
||||
left: 0;
|
||||
}
|
||||
|
||||
#sidebar-overlay.open {
|
||||
display: block;
|
||||
}
|
||||
|
||||
#hamburger {
|
||||
display: block;
|
||||
}
|
||||
|
||||
.sidebar-close {
|
||||
display: block;
|
||||
}
|
||||
|
||||
#content {
|
||||
padding: var(--line-height) 16px;
|
||||
}
|
||||
|
||||
#content > * {
|
||||
max-width: 100%;
|
||||
}
|
||||
}
|
||||
|
||||
@media (max-width: 500px) {
|
||||
#sidebar {
|
||||
width: 100vw;
|
||||
left: -100vw;
|
||||
}
|
||||
}
|
||||
|
||||
@media print {
|
||||
#sidebar, #sidebar-toggle { display: none !important; }
|
||||
body { background: white; color: black; }
|
||||
#content { max-width: none; }
|
||||
}
|
||||
54
packages/pi-coding-agent/src/core/export-html/template.html
Normal file
54
packages/pi-coding-agent/src/core/export-html/template.html
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Session Export</title>
|
||||
<style>
|
||||
{{CSS}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<button id="hamburger" title="Open sidebar"><svg width="14" height="14" viewBox="0 0 24 24" fill="currentColor" stroke="none"><circle cx="6" cy="6" r="2.5"/><circle cx="6" cy="18" r="2.5"/><circle cx="18" cy="12" r="2.5"/><rect x="5" y="6" width="2" height="12"/><path d="M6 12h10c1 0 2 0 2-2V8"/></svg></button>
|
||||
<div id="sidebar-overlay"></div>
|
||||
<div id="app">
|
||||
<aside id="sidebar">
|
||||
<div class="sidebar-header">
|
||||
<div class="sidebar-controls">
|
||||
<input type="text" class="sidebar-search" id="tree-search" placeholder="Search...">
|
||||
</div>
|
||||
<div class="sidebar-filters">
|
||||
<button class="filter-btn active" data-filter="default" title="Hide settings entries">Default</button>
|
||||
<button class="filter-btn" data-filter="no-tools" title="Default minus tool results">No-tools</button>
|
||||
<button class="filter-btn" data-filter="user-only" title="Only user messages">User</button>
|
||||
<button class="filter-btn" data-filter="labeled-only" title="Only labeled entries">Labeled</button>
|
||||
<button class="filter-btn" data-filter="all" title="Show everything">All</button>
|
||||
<button class="sidebar-close" id="sidebar-close" title="Close">✕</button>
|
||||
</div>
|
||||
</div>
|
||||
<div class="tree-container" id="tree-container"></div>
|
||||
<div class="tree-status" id="tree-status"></div>
|
||||
</aside>
|
||||
<main id="content">
|
||||
<div id="header-container"></div>
|
||||
<div id="messages"></div>
|
||||
</main>
|
||||
<div id="image-modal" class="image-modal">
|
||||
<img id="modal-image" src="" alt="">
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script id="session-data" type="application/json">{{SESSION_DATA}}</script>
|
||||
|
||||
<!-- Vendored libraries -->
|
||||
<script>{{MARKED_JS}}</script>
|
||||
|
||||
<!-- highlight.js -->
|
||||
<script>{{HIGHLIGHT_JS}}</script>
|
||||
|
||||
<!-- Main application code -->
|
||||
<script>
|
||||
{{JS}}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
1583
packages/pi-coding-agent/src/core/export-html/template.js
Normal file
1583
packages/pi-coding-agent/src/core/export-html/template.js
Normal file
File diff suppressed because it is too large
Load diff
114
packages/pi-coding-agent/src/core/export-html/tool-renderer.ts
Normal file
114
packages/pi-coding-agent/src/core/export-html/tool-renderer.ts
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
/**
|
||||
* Tool HTML renderer for custom tools in HTML export.
|
||||
*
|
||||
* Renders custom tool calls and results to HTML by invoking their TUI renderers
|
||||
* and converting the ANSI output to HTML.
|
||||
*/
|
||||
|
||||
import type { ImageContent, TextContent } from "@gsd/pi-ai";
|
||||
import type { Theme } from "../../modes/interactive/theme/theme.js";
|
||||
import type { ToolDefinition } from "../extensions/types.js";
|
||||
import { ansiLinesToHtml } from "./ansi-to-html.js";
|
||||
|
||||
export interface ToolHtmlRendererDeps {
|
||||
/** Function to look up tool definition by name */
|
||||
getToolDefinition: (name: string) => ToolDefinition | undefined;
|
||||
/** Theme for styling */
|
||||
theme: Theme;
|
||||
/** Terminal width for rendering (default: 100) */
|
||||
width?: number;
|
||||
}
|
||||
|
||||
export interface ToolHtmlRenderer {
|
||||
/** Render a tool call to HTML. Returns undefined if tool has no custom renderer. */
|
||||
renderCall(toolName: string, args: unknown): string | undefined;
|
||||
/** Render a tool result to collapsed/expanded HTML. Returns undefined if tool has no custom renderer. */
|
||||
renderResult(
|
||||
toolName: string,
|
||||
result: Array<{ type: string; text?: string; data?: string; mimeType?: string }>,
|
||||
details: unknown,
|
||||
isError: boolean,
|
||||
): { collapsed?: string; expanded?: string } | undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a tool HTML renderer.
|
||||
*
|
||||
* The renderer looks up tool definitions and invokes their renderCall/renderResult
|
||||
* methods, converting the resulting TUI Component output (ANSI) to HTML.
|
||||
*/
|
||||
export function createToolHtmlRenderer(deps: ToolHtmlRendererDeps): ToolHtmlRenderer {
|
||||
const { getToolDefinition, theme, width = 100 } = deps;
|
||||
|
||||
return {
|
||||
renderCall(toolName: string, args: unknown): string | undefined {
|
||||
try {
|
||||
const toolDef = getToolDefinition(toolName);
|
||||
if (!toolDef?.renderCall) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const component = toolDef.renderCall(args, theme);
|
||||
if (!component) {
|
||||
return undefined;
|
||||
}
|
||||
const lines = component.render(width);
|
||||
return ansiLinesToHtml(lines);
|
||||
} catch {
|
||||
// On error, return undefined to trigger JSON fallback
|
||||
return undefined;
|
||||
}
|
||||
},
|
||||
|
||||
renderResult(
|
||||
toolName: string,
|
||||
result: Array<{ type: string; text?: string; data?: string; mimeType?: string }>,
|
||||
details: unknown,
|
||||
isError: boolean,
|
||||
): { collapsed?: string; expanded?: string } | undefined {
|
||||
try {
|
||||
const toolDef = getToolDefinition(toolName);
|
||||
if (!toolDef?.renderResult) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Build AgentToolResult from content array
|
||||
// Cast content since session storage uses generic object types
|
||||
const agentToolResult = {
|
||||
content: result as (TextContent | ImageContent)[],
|
||||
details,
|
||||
isError,
|
||||
};
|
||||
|
||||
// Render collapsed
|
||||
const collapsedComponent = toolDef.renderResult(
|
||||
agentToolResult,
|
||||
{ expanded: false, isPartial: false },
|
||||
theme,
|
||||
);
|
||||
const collapsed = collapsedComponent ? ansiLinesToHtml(collapsedComponent.render(width)) : undefined;
|
||||
|
||||
// Render expanded
|
||||
const expandedComponent = toolDef.renderResult(
|
||||
agentToolResult,
|
||||
{ expanded: true, isPartial: false },
|
||||
theme,
|
||||
);
|
||||
const expanded = expandedComponent ? ansiLinesToHtml(expandedComponent.render(width)) : undefined;
|
||||
|
||||
// Return collapsed only if it exists and differs from expanded
|
||||
if (!expanded) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return {
|
||||
...(collapsed && collapsed !== expanded ? { collapsed } : {}),
|
||||
expanded,
|
||||
};
|
||||
} catch {
|
||||
// On error, return undefined to trigger JSON fallback
|
||||
return undefined;
|
||||
}
|
||||
},
|
||||
};
|
||||
}
|
||||
171
packages/pi-coding-agent/src/core/extensions/index.ts
Normal file
171
packages/pi-coding-agent/src/core/extensions/index.ts
Normal file
|
|
@ -0,0 +1,171 @@
|
|||
/**
|
||||
* Extension system for lifecycle events and custom tools.
|
||||
*/
|
||||
|
||||
export type { SlashCommandInfo, SlashCommandLocation, SlashCommandSource } from "../slash-commands.js";
|
||||
export {
|
||||
createExtensionRuntime,
|
||||
discoverAndLoadExtensions,
|
||||
loadExtensionFromFactory,
|
||||
loadExtensions,
|
||||
} from "./loader.js";
|
||||
export type {
|
||||
ExtensionErrorListener,
|
||||
ForkHandler,
|
||||
NavigateTreeHandler,
|
||||
NewSessionHandler,
|
||||
ShutdownHandler,
|
||||
SwitchSessionHandler,
|
||||
} from "./runner.js";
|
||||
export { ExtensionRunner } from "./runner.js";
|
||||
export type {
|
||||
AgentEndEvent,
|
||||
AgentStartEvent,
|
||||
// Re-exports
|
||||
AgentToolResult,
|
||||
AgentToolUpdateCallback,
|
||||
// App keybindings (for custom editors)
|
||||
AppAction,
|
||||
AppendEntryHandler,
|
||||
// Events - Tool (ToolCallEvent types)
|
||||
BashToolCallEvent,
|
||||
BashToolResultEvent,
|
||||
BeforeAgentStartEvent,
|
||||
BeforeAgentStartEventResult,
|
||||
BeforeProviderRequestEvent,
|
||||
BeforeProviderRequestEventResult,
|
||||
// Context
|
||||
CompactOptions,
|
||||
// Events - Agent
|
||||
ContextEvent,
|
||||
// Event Results
|
||||
ContextEventResult,
|
||||
ContextUsage,
|
||||
CustomToolCallEvent,
|
||||
CustomToolResultEvent,
|
||||
EditToolCallEvent,
|
||||
EditToolResultEvent,
|
||||
ExecOptions,
|
||||
ExecResult,
|
||||
Extension,
|
||||
ExtensionActions,
|
||||
// API
|
||||
ExtensionAPI,
|
||||
ExtensionCommandContext,
|
||||
ExtensionCommandContextActions,
|
||||
ExtensionContext,
|
||||
ExtensionContextActions,
|
||||
// Errors
|
||||
ExtensionError,
|
||||
ExtensionEvent,
|
||||
ExtensionFactory,
|
||||
ExtensionFlag,
|
||||
ExtensionHandler,
|
||||
// Runtime
|
||||
ExtensionRuntime,
|
||||
ExtensionShortcut,
|
||||
ExtensionUIContext,
|
||||
ExtensionUIDialogOptions,
|
||||
ExtensionWidgetOptions,
|
||||
FindToolCallEvent,
|
||||
FindToolResultEvent,
|
||||
GetActiveToolsHandler,
|
||||
GetAllToolsHandler,
|
||||
GetCommandsHandler,
|
||||
GetThinkingLevelHandler,
|
||||
GrepToolCallEvent,
|
||||
GrepToolResultEvent,
|
||||
// Events - Input
|
||||
InputEvent,
|
||||
InputEventResult,
|
||||
InputSource,
|
||||
KeybindingsManager,
|
||||
LoadExtensionsResult,
|
||||
LsToolCallEvent,
|
||||
LsToolResultEvent,
|
||||
// Events - Message
|
||||
MessageEndEvent,
|
||||
// Message Rendering
|
||||
MessageRenderer,
|
||||
MessageRenderOptions,
|
||||
MessageStartEvent,
|
||||
MessageUpdateEvent,
|
||||
ModelSelectEvent,
|
||||
ModelSelectSource,
|
||||
// Provider Registration
|
||||
ProviderConfig,
|
||||
ProviderModelConfig,
|
||||
ReadToolCallEvent,
|
||||
ReadToolResultEvent,
|
||||
// Commands
|
||||
RegisteredCommand,
|
||||
RegisteredTool,
|
||||
// Events - Resources
|
||||
ResourcesDiscoverEvent,
|
||||
ResourcesDiscoverResult,
|
||||
SendMessageHandler,
|
||||
SendUserMessageHandler,
|
||||
SessionBeforeCompactEvent,
|
||||
SessionBeforeCompactResult,
|
||||
SessionBeforeForkEvent,
|
||||
SessionBeforeForkResult,
|
||||
SessionBeforeSwitchEvent,
|
||||
SessionBeforeSwitchResult,
|
||||
SessionBeforeTreeEvent,
|
||||
SessionBeforeTreeResult,
|
||||
SessionCompactEvent,
|
||||
SessionDirectoryEvent,
|
||||
SessionDirectoryHandler,
|
||||
SessionDirectoryResult,
|
||||
SessionEvent,
|
||||
SessionForkEvent,
|
||||
SessionShutdownEvent,
|
||||
// Events - Session
|
||||
SessionStartEvent,
|
||||
SessionSwitchEvent,
|
||||
SessionTreeEvent,
|
||||
SetActiveToolsHandler,
|
||||
SetLabelHandler,
|
||||
SetModelHandler,
|
||||
SetThinkingLevelHandler,
|
||||
TerminalInputHandler,
|
||||
// Events - Tool
|
||||
ToolCallEvent,
|
||||
ToolCallEventResult,
|
||||
// Tools
|
||||
ToolDefinition,
|
||||
// Events - Tool Execution
|
||||
ToolExecutionEndEvent,
|
||||
ToolExecutionStartEvent,
|
||||
ToolExecutionUpdateEvent,
|
||||
ToolInfo,
|
||||
ToolRenderResultOptions,
|
||||
ToolResultEvent,
|
||||
ToolResultEventResult,
|
||||
TreePreparation,
|
||||
TurnEndEvent,
|
||||
TurnStartEvent,
|
||||
// Events - User Bash
|
||||
UserBashEvent,
|
||||
UserBashEventResult,
|
||||
WidgetPlacement,
|
||||
WriteToolCallEvent,
|
||||
WriteToolResultEvent,
|
||||
} from "./types.js";
|
||||
// Type guards
|
||||
export {
|
||||
isBashToolResult,
|
||||
isEditToolResult,
|
||||
isFindToolResult,
|
||||
isGrepToolResult,
|
||||
isLsToolResult,
|
||||
isReadToolResult,
|
||||
isToolCallEventType,
|
||||
isWriteToolResult,
|
||||
} from "./types.js";
|
||||
export {
|
||||
wrapRegisteredTool,
|
||||
wrapRegisteredTools,
|
||||
wrapToolsWithExtensions,
|
||||
wrapToolWithExtensions,
|
||||
} from "./wrapper.js";
|
||||
545
packages/pi-coding-agent/src/core/extensions/loader.ts
Normal file
545
packages/pi-coding-agent/src/core/extensions/loader.ts
Normal file
|
|
@ -0,0 +1,545 @@
|
|||
/**
|
||||
* Extension loader - loads TypeScript extension modules using jiti.
|
||||
*
|
||||
* Uses @mariozechner/jiti fork with virtualModules support for compiled Bun binaries.
|
||||
*/
|
||||
|
||||
import * as fs from "node:fs";
|
||||
import { createRequire } from "node:module";
|
||||
import * as os from "node:os";
|
||||
import * as path from "node:path";
|
||||
import { fileURLToPath } from "node:url";
|
||||
import { createJiti } from "@mariozechner/jiti";
|
||||
import * as _bundledPiAgentCore from "@gsd/pi-agent-core";
|
||||
import * as _bundledPiAi from "@gsd/pi-ai";
|
||||
import * as _bundledPiAiOauth from "@gsd/pi-ai/oauth";
|
||||
import type { KeyId } from "@gsd/pi-tui";
|
||||
import * as _bundledPiTui from "@gsd/pi-tui";
|
||||
// Static imports of packages that extensions may use.
|
||||
// These MUST be static so Bun bundles them into the compiled binary.
|
||||
// The virtualModules option then makes them available to extensions.
|
||||
import * as _bundledTypebox from "@sinclair/typebox";
|
||||
import { getAgentDir, isBunBinary } from "../../config.js";
|
||||
// NOTE: This import works because loader.ts exports are NOT re-exported from index.ts,
|
||||
// avoiding a circular dependency. Extensions can import from @gsd/pi-coding-agent.
|
||||
import * as _bundledPiCodingAgent from "../../index.js";
|
||||
import { createEventBus, type EventBus } from "../event-bus.js";
|
||||
import type { ExecOptions } from "../exec.js";
|
||||
import { execCommand } from "../exec.js";
|
||||
import type {
|
||||
Extension,
|
||||
ExtensionAPI,
|
||||
ExtensionFactory,
|
||||
ExtensionRuntime,
|
||||
LoadExtensionsResult,
|
||||
MessageRenderer,
|
||||
ProviderConfig,
|
||||
RegisteredCommand,
|
||||
ToolDefinition,
|
||||
} from "./types.js";
|
||||
|
||||
/** Modules available to extensions via virtualModules (for compiled Bun binary) */
|
||||
const VIRTUAL_MODULES: Record<string, unknown> = {
|
||||
"@sinclair/typebox": _bundledTypebox,
|
||||
"@gsd/pi-agent-core": _bundledPiAgentCore,
|
||||
"@gsd/pi-tui": _bundledPiTui,
|
||||
"@gsd/pi-ai": _bundledPiAi,
|
||||
"@gsd/pi-ai/oauth": _bundledPiAiOauth,
|
||||
"@gsd/pi-coding-agent": _bundledPiCodingAgent,
|
||||
};
|
||||
|
||||
const require = createRequire(import.meta.url);
|
||||
|
||||
/**
|
||||
* Get aliases for jiti (used in Node.js/development mode).
|
||||
* In Bun binary mode, virtualModules is used instead.
|
||||
*/
|
||||
let _aliases: Record<string, string> | null = null;
|
||||
function getAliases(): Record<string, string> {
|
||||
if (_aliases) return _aliases;
|
||||
|
||||
const __dirname = path.dirname(fileURLToPath(import.meta.url));
|
||||
const packageIndex = path.resolve(__dirname, "../..", "index.js");
|
||||
|
||||
const typeboxEntry = require.resolve("@sinclair/typebox");
|
||||
const typeboxRoot = typeboxEntry.replace(/[\\/]build[\\/]cjs[\\/]index\.js$/, "");
|
||||
|
||||
const packagesRoot = path.resolve(__dirname, "../../../../");
|
||||
const resolveWorkspaceOrImport = (workspaceRelativePath: string, specifier: string): string => {
|
||||
const workspacePath = path.join(packagesRoot, workspaceRelativePath);
|
||||
if (fs.existsSync(workspacePath)) {
|
||||
return workspacePath;
|
||||
}
|
||||
return fileURLToPath(import.meta.resolve(specifier));
|
||||
};
|
||||
|
||||
_aliases = {
|
||||
"@gsd/pi-coding-agent": packageIndex,
|
||||
"@gsd/pi-agent-core": resolveWorkspaceOrImport("agent/dist/index.js", "@gsd/pi-agent-core"),
|
||||
"@gsd/pi-tui": resolveWorkspaceOrImport("tui/dist/index.js", "@gsd/pi-tui"),
|
||||
"@gsd/pi-ai": resolveWorkspaceOrImport("ai/dist/index.js", "@gsd/pi-ai"),
|
||||
"@gsd/pi-ai/oauth": resolveWorkspaceOrImport("ai/dist/oauth.js", "@gsd/pi-ai/oauth"),
|
||||
"@sinclair/typebox": typeboxRoot,
|
||||
};
|
||||
|
||||
return _aliases;
|
||||
}
|
||||
|
||||
const UNICODE_SPACES = /[\u00A0\u2000-\u200A\u202F\u205F\u3000]/g;
|
||||
|
||||
function normalizeUnicodeSpaces(str: string): string {
|
||||
return str.replace(UNICODE_SPACES, " ");
|
||||
}
|
||||
|
||||
function expandPath(p: string): string {
|
||||
const normalized = normalizeUnicodeSpaces(p);
|
||||
if (normalized.startsWith("~/")) {
|
||||
return path.join(os.homedir(), normalized.slice(2));
|
||||
}
|
||||
if (normalized.startsWith("~")) {
|
||||
return path.join(os.homedir(), normalized.slice(1));
|
||||
}
|
||||
return normalized;
|
||||
}
|
||||
|
||||
function resolvePath(extPath: string, cwd: string): string {
|
||||
const expanded = expandPath(extPath);
|
||||
if (path.isAbsolute(expanded)) {
|
||||
return expanded;
|
||||
}
|
||||
return path.resolve(cwd, expanded);
|
||||
}
|
||||
|
||||
type HandlerFn = (...args: unknown[]) => Promise<unknown>;
|
||||
|
||||
/**
|
||||
* Create a runtime with throwing stubs for action methods.
|
||||
* Runner.bindCore() replaces these with real implementations.
|
||||
*/
|
||||
export function createExtensionRuntime(): ExtensionRuntime {
|
||||
const notInitialized = () => {
|
||||
throw new Error("Extension runtime not initialized. Action methods cannot be called during extension loading.");
|
||||
};
|
||||
|
||||
const runtime: ExtensionRuntime = {
|
||||
sendMessage: notInitialized,
|
||||
sendUserMessage: notInitialized,
|
||||
appendEntry: notInitialized,
|
||||
setSessionName: notInitialized,
|
||||
getSessionName: notInitialized,
|
||||
setLabel: notInitialized,
|
||||
getActiveTools: notInitialized,
|
||||
getAllTools: notInitialized,
|
||||
setActiveTools: notInitialized,
|
||||
// registerTool() is valid during extension load; refresh is only needed post-bind.
|
||||
refreshTools: () => {},
|
||||
getCommands: notInitialized,
|
||||
setModel: () => Promise.reject(new Error("Extension runtime not initialized")),
|
||||
getThinkingLevel: notInitialized,
|
||||
setThinkingLevel: notInitialized,
|
||||
flagValues: new Map(),
|
||||
pendingProviderRegistrations: [],
|
||||
// Pre-bind: queue registrations so bindCore() can flush them once the
|
||||
// model registry is available. bindCore() replaces both with direct calls.
|
||||
registerProvider: (name, config) => {
|
||||
runtime.pendingProviderRegistrations.push({ name, config });
|
||||
},
|
||||
unregisterProvider: (name) => {
|
||||
runtime.pendingProviderRegistrations = runtime.pendingProviderRegistrations.filter((r) => r.name !== name);
|
||||
},
|
||||
};
|
||||
|
||||
return runtime;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create the ExtensionAPI for an extension.
|
||||
* Registration methods write to the extension object.
|
||||
* Action methods delegate to the shared runtime.
|
||||
*/
|
||||
function createExtensionAPI(
|
||||
extension: Extension,
|
||||
runtime: ExtensionRuntime,
|
||||
cwd: string,
|
||||
eventBus: EventBus,
|
||||
): ExtensionAPI {
|
||||
const api = {
|
||||
// Registration methods - write to extension
|
||||
on(event: string, handler: HandlerFn): void {
|
||||
const list = extension.handlers.get(event) ?? [];
|
||||
list.push(handler);
|
||||
extension.handlers.set(event, list);
|
||||
},
|
||||
|
||||
registerTool(tool: ToolDefinition): void {
|
||||
extension.tools.set(tool.name, {
|
||||
definition: tool,
|
||||
extensionPath: extension.path,
|
||||
});
|
||||
runtime.refreshTools();
|
||||
},
|
||||
|
||||
registerCommand(name: string, options: Omit<RegisteredCommand, "name">): void {
|
||||
extension.commands.set(name, { name, ...options });
|
||||
},
|
||||
|
||||
registerShortcut(
|
||||
shortcut: KeyId,
|
||||
options: {
|
||||
description?: string;
|
||||
handler: (ctx: import("./types.js").ExtensionContext) => Promise<void> | void;
|
||||
},
|
||||
): void {
|
||||
extension.shortcuts.set(shortcut, { shortcut, extensionPath: extension.path, ...options });
|
||||
},
|
||||
|
||||
registerFlag(
|
||||
name: string,
|
||||
options: { description?: string; type: "boolean" | "string"; default?: boolean | string },
|
||||
): void {
|
||||
extension.flags.set(name, { name, extensionPath: extension.path, ...options });
|
||||
if (options.default !== undefined && !runtime.flagValues.has(name)) {
|
||||
runtime.flagValues.set(name, options.default);
|
||||
}
|
||||
},
|
||||
|
||||
registerMessageRenderer<T>(customType: string, renderer: MessageRenderer<T>): void {
|
||||
extension.messageRenderers.set(customType, renderer as MessageRenderer);
|
||||
},
|
||||
|
||||
// Flag access - checks extension registered it, reads from runtime
|
||||
getFlag(name: string): boolean | string | undefined {
|
||||
if (!extension.flags.has(name)) return undefined;
|
||||
return runtime.flagValues.get(name);
|
||||
},
|
||||
|
||||
// Action methods - delegate to shared runtime
|
||||
sendMessage(message, options): void {
|
||||
runtime.sendMessage(message, options);
|
||||
},
|
||||
|
||||
sendUserMessage(content, options): void {
|
||||
runtime.sendUserMessage(content, options);
|
||||
},
|
||||
|
||||
appendEntry(customType: string, data?: unknown): void {
|
||||
runtime.appendEntry(customType, data);
|
||||
},
|
||||
|
||||
setSessionName(name: string): void {
|
||||
runtime.setSessionName(name);
|
||||
},
|
||||
|
||||
getSessionName(): string | undefined {
|
||||
return runtime.getSessionName();
|
||||
},
|
||||
|
||||
setLabel(entryId: string, label: string | undefined): void {
|
||||
runtime.setLabel(entryId, label);
|
||||
},
|
||||
|
||||
exec(command: string, args: string[], options?: ExecOptions) {
|
||||
return execCommand(command, args, options?.cwd ?? cwd, options);
|
||||
},
|
||||
|
||||
getActiveTools(): string[] {
|
||||
return runtime.getActiveTools();
|
||||
},
|
||||
|
||||
getAllTools() {
|
||||
return runtime.getAllTools();
|
||||
},
|
||||
|
||||
setActiveTools(toolNames: string[]): void {
|
||||
runtime.setActiveTools(toolNames);
|
||||
},
|
||||
|
||||
getCommands() {
|
||||
return runtime.getCommands();
|
||||
},
|
||||
|
||||
setModel(model) {
|
||||
return runtime.setModel(model);
|
||||
},
|
||||
|
||||
getThinkingLevel() {
|
||||
return runtime.getThinkingLevel();
|
||||
},
|
||||
|
||||
setThinkingLevel(level) {
|
||||
runtime.setThinkingLevel(level);
|
||||
},
|
||||
|
||||
registerProvider(name: string, config: ProviderConfig) {
|
||||
runtime.registerProvider(name, config);
|
||||
},
|
||||
|
||||
unregisterProvider(name: string) {
|
||||
runtime.unregisterProvider(name);
|
||||
},
|
||||
|
||||
events: eventBus,
|
||||
} as ExtensionAPI;
|
||||
|
||||
return api;
|
||||
}
|
||||
|
||||
async function loadExtensionModule(extensionPath: string) {
|
||||
const jiti = createJiti(import.meta.url, {
|
||||
moduleCache: false,
|
||||
// In Bun binary: use virtualModules for bundled packages (no filesystem resolution)
|
||||
// Also disable tryNative so jiti handles ALL imports (not just the entry point)
|
||||
// In Node.js/dev: use aliases to resolve to node_modules paths
|
||||
...(isBunBinary ? { virtualModules: VIRTUAL_MODULES, tryNative: false } : { alias: getAliases() }),
|
||||
});
|
||||
|
||||
const module = await jiti.import(extensionPath, { default: true });
|
||||
const factory = module as ExtensionFactory;
|
||||
return typeof factory !== "function" ? undefined : factory;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an Extension object with empty collections.
|
||||
*/
|
||||
function createExtension(extensionPath: string, resolvedPath: string): Extension {
|
||||
return {
|
||||
path: extensionPath,
|
||||
resolvedPath,
|
||||
handlers: new Map(),
|
||||
tools: new Map(),
|
||||
messageRenderers: new Map(),
|
||||
commands: new Map(),
|
||||
flags: new Map(),
|
||||
shortcuts: new Map(),
|
||||
};
|
||||
}
|
||||
|
||||
async function loadExtension(
|
||||
extensionPath: string,
|
||||
cwd: string,
|
||||
eventBus: EventBus,
|
||||
runtime: ExtensionRuntime,
|
||||
): Promise<{ extension: Extension | null; error: string | null }> {
|
||||
const resolvedPath = resolvePath(extensionPath, cwd);
|
||||
|
||||
try {
|
||||
const factory = await loadExtensionModule(resolvedPath);
|
||||
if (!factory) {
|
||||
return { extension: null, error: `Extension does not export a valid factory function: ${extensionPath}` };
|
||||
}
|
||||
|
||||
const extension = createExtension(extensionPath, resolvedPath);
|
||||
const api = createExtensionAPI(extension, runtime, cwd, eventBus);
|
||||
await factory(api);
|
||||
|
||||
return { extension, error: null };
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
return { extension: null, error: `Failed to load extension: ${message}` };
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an Extension from an inline factory function.
|
||||
*/
|
||||
export async function loadExtensionFromFactory(
|
||||
factory: ExtensionFactory,
|
||||
cwd: string,
|
||||
eventBus: EventBus,
|
||||
runtime: ExtensionRuntime,
|
||||
extensionPath = "<inline>",
|
||||
): Promise<Extension> {
|
||||
const extension = createExtension(extensionPath, extensionPath);
|
||||
const api = createExtensionAPI(extension, runtime, cwd, eventBus);
|
||||
await factory(api);
|
||||
return extension;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load extensions from paths.
|
||||
*/
|
||||
export async function loadExtensions(paths: string[], cwd: string, eventBus?: EventBus): Promise<LoadExtensionsResult> {
|
||||
const extensions: Extension[] = [];
|
||||
const errors: Array<{ path: string; error: string }> = [];
|
||||
const resolvedEventBus = eventBus ?? createEventBus();
|
||||
const runtime = createExtensionRuntime();
|
||||
|
||||
for (const extPath of paths) {
|
||||
const { extension, error } = await loadExtension(extPath, cwd, resolvedEventBus, runtime);
|
||||
|
||||
if (error) {
|
||||
errors.push({ path: extPath, error });
|
||||
continue;
|
||||
}
|
||||
|
||||
if (extension) {
|
||||
extensions.push(extension);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
extensions,
|
||||
errors,
|
||||
runtime,
|
||||
};
|
||||
}
|
||||
|
||||
interface PiManifest {
|
||||
extensions?: string[];
|
||||
themes?: string[];
|
||||
skills?: string[];
|
||||
prompts?: string[];
|
||||
}
|
||||
|
||||
function readPiManifest(packageJsonPath: string): PiManifest | null {
|
||||
try {
|
||||
const content = fs.readFileSync(packageJsonPath, "utf-8");
|
||||
const pkg = JSON.parse(content);
|
||||
if (pkg.pi && typeof pkg.pi === "object") {
|
||||
return pkg.pi as PiManifest;
|
||||
}
|
||||
return null;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function isExtensionFile(name: string): boolean {
|
||||
return name.endsWith(".ts") || name.endsWith(".js");
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve extension entry points from a directory.
|
||||
*
|
||||
* Checks for:
|
||||
* 1. package.json with "pi.extensions" field -> returns declared paths
|
||||
* 2. index.ts or index.js -> returns the index file
|
||||
*
|
||||
* Returns resolved paths or null if no entry points found.
|
||||
*/
|
||||
function resolveExtensionEntries(dir: string): string[] | null {
|
||||
// Check for package.json with "pi" field first
|
||||
const packageJsonPath = path.join(dir, "package.json");
|
||||
if (fs.existsSync(packageJsonPath)) {
|
||||
const manifest = readPiManifest(packageJsonPath);
|
||||
if (manifest?.extensions?.length) {
|
||||
const entries: string[] = [];
|
||||
for (const extPath of manifest.extensions) {
|
||||
const resolvedExtPath = path.resolve(dir, extPath);
|
||||
if (fs.existsSync(resolvedExtPath)) {
|
||||
entries.push(resolvedExtPath);
|
||||
}
|
||||
}
|
||||
if (entries.length > 0) {
|
||||
return entries;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check for index.ts or index.js
|
||||
const indexTs = path.join(dir, "index.ts");
|
||||
const indexJs = path.join(dir, "index.js");
|
||||
if (fs.existsSync(indexTs)) {
|
||||
return [indexTs];
|
||||
}
|
||||
if (fs.existsSync(indexJs)) {
|
||||
return [indexJs];
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Discover extensions in a directory.
|
||||
*
|
||||
* Discovery rules:
|
||||
* 1. Direct files: `extensions/*.ts` or `*.js` → load
|
||||
* 2. Subdirectory with index: `extensions/* /index.ts` or `index.js` → load
|
||||
* 3. Subdirectory with package.json: `extensions/* /package.json` with "pi" field → load what it declares
|
||||
*
|
||||
* No recursion beyond one level. Complex packages must use package.json manifest.
|
||||
*/
|
||||
function discoverExtensionsInDir(dir: string): string[] {
|
||||
if (!fs.existsSync(dir)) {
|
||||
return [];
|
||||
}
|
||||
|
||||
const discovered: string[] = [];
|
||||
|
||||
try {
|
||||
const entries = fs.readdirSync(dir, { withFileTypes: true });
|
||||
|
||||
for (const entry of entries) {
|
||||
const entryPath = path.join(dir, entry.name);
|
||||
|
||||
// 1. Direct files: *.ts or *.js
|
||||
if ((entry.isFile() || entry.isSymbolicLink()) && isExtensionFile(entry.name)) {
|
||||
discovered.push(entryPath);
|
||||
continue;
|
||||
}
|
||||
|
||||
// 2 & 3. Subdirectories
|
||||
if (entry.isDirectory() || entry.isSymbolicLink()) {
|
||||
const entries = resolveExtensionEntries(entryPath);
|
||||
if (entries) {
|
||||
discovered.push(...entries);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
return [];
|
||||
}
|
||||
|
||||
return discovered;
|
||||
}
|
||||
|
||||
/**
|
||||
* Discover and load extensions from standard locations.
|
||||
*/
|
||||
export async function discoverAndLoadExtensions(
|
||||
configuredPaths: string[],
|
||||
cwd: string,
|
||||
agentDir: string = getAgentDir(),
|
||||
eventBus?: EventBus,
|
||||
): Promise<LoadExtensionsResult> {
|
||||
const allPaths: string[] = [];
|
||||
const seen = new Set<string>();
|
||||
|
||||
const addPaths = (paths: string[]) => {
|
||||
for (const p of paths) {
|
||||
const resolved = path.resolve(p);
|
||||
if (!seen.has(resolved)) {
|
||||
seen.add(resolved);
|
||||
allPaths.push(p);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// 1. Project-local extensions: cwd/.pi/extensions/
|
||||
const localExtDir = path.join(cwd, ".pi", "extensions");
|
||||
addPaths(discoverExtensionsInDir(localExtDir));
|
||||
|
||||
// 2. Global extensions: agentDir/extensions/
|
||||
const globalExtDir = path.join(agentDir, "extensions");
|
||||
addPaths(discoverExtensionsInDir(globalExtDir));
|
||||
|
||||
// 3. Explicitly configured paths
|
||||
for (const p of configuredPaths) {
|
||||
const resolved = resolvePath(p, cwd);
|
||||
if (fs.existsSync(resolved) && fs.statSync(resolved).isDirectory()) {
|
||||
// Check for package.json with pi manifest or index.ts
|
||||
const entries = resolveExtensionEntries(resolved);
|
||||
if (entries) {
|
||||
addPaths(entries);
|
||||
continue;
|
||||
}
|
||||
// No explicit entries - discover individual files in directory
|
||||
addPaths(discoverExtensionsInDir(resolved));
|
||||
continue;
|
||||
}
|
||||
|
||||
addPaths([resolved]);
|
||||
}
|
||||
|
||||
return loadExtensions(allPaths, cwd, eventBus);
|
||||
}
|
||||
884
packages/pi-coding-agent/src/core/extensions/runner.ts
Normal file
884
packages/pi-coding-agent/src/core/extensions/runner.ts
Normal file
|
|
@ -0,0 +1,884 @@
|
|||
/**
|
||||
* Extension runner - executes extensions and manages their lifecycle.
|
||||
*/
|
||||
|
||||
import type { AgentMessage } from "@gsd/pi-agent-core";
|
||||
import type { ImageContent, Model } from "@gsd/pi-ai";
|
||||
import type { KeyId } from "@gsd/pi-tui";
|
||||
import { type Theme, theme } from "../../modes/interactive/theme/theme.js";
|
||||
import type { ResourceDiagnostic } from "../diagnostics.js";
|
||||
import type { KeyAction, KeybindingsConfig } from "../keybindings.js";
|
||||
import type { ModelRegistry } from "../model-registry.js";
|
||||
import type { SessionManager } from "../session-manager.js";
|
||||
import type {
|
||||
BeforeAgentStartEvent,
|
||||
BeforeAgentStartEventResult,
|
||||
BeforeProviderRequestEvent,
|
||||
CompactOptions,
|
||||
ContextEvent,
|
||||
ContextEventResult,
|
||||
ContextUsage,
|
||||
Extension,
|
||||
ExtensionActions,
|
||||
ExtensionCommandContext,
|
||||
ExtensionCommandContextActions,
|
||||
ExtensionContext,
|
||||
ExtensionContextActions,
|
||||
ExtensionError,
|
||||
ExtensionEvent,
|
||||
ExtensionFlag,
|
||||
ExtensionRuntime,
|
||||
ExtensionShortcut,
|
||||
ExtensionUIContext,
|
||||
InputEvent,
|
||||
InputEventResult,
|
||||
InputSource,
|
||||
MessageRenderer,
|
||||
RegisteredCommand,
|
||||
RegisteredTool,
|
||||
ResourcesDiscoverEvent,
|
||||
ResourcesDiscoverResult,
|
||||
SessionBeforeCompactResult,
|
||||
SessionBeforeForkResult,
|
||||
SessionBeforeSwitchResult,
|
||||
SessionBeforeTreeResult,
|
||||
ToolCallEvent,
|
||||
ToolCallEventResult,
|
||||
ToolResultEvent,
|
||||
ToolResultEventResult,
|
||||
UserBashEvent,
|
||||
UserBashEventResult,
|
||||
} from "./types.js";
|
||||
|
||||
// Keybindings for these actions cannot be overridden by extensions
|
||||
const RESERVED_ACTIONS_FOR_EXTENSION_CONFLICTS: ReadonlyArray<KeyAction> = [
|
||||
"interrupt",
|
||||
"clear",
|
||||
"exit",
|
||||
"suspend",
|
||||
"cycleThinkingLevel",
|
||||
"cycleModelForward",
|
||||
"cycleModelBackward",
|
||||
"selectModel",
|
||||
"expandTools",
|
||||
"toggleThinking",
|
||||
"externalEditor",
|
||||
"followUp",
|
||||
"submit",
|
||||
"selectConfirm",
|
||||
"selectCancel",
|
||||
"copy",
|
||||
"deleteToLineEnd",
|
||||
];
|
||||
|
||||
type BuiltInKeyBindings = Partial<Record<KeyId, { action: KeyAction; restrictOverride: boolean }>>;
|
||||
|
||||
const buildBuiltinKeybindings = (effectiveKeybindings: Required<KeybindingsConfig>): BuiltInKeyBindings => {
|
||||
const builtinKeybindings = {} as BuiltInKeyBindings;
|
||||
for (const [action, keys] of Object.entries(effectiveKeybindings)) {
|
||||
const keyAction = action as KeyAction;
|
||||
const keyList = Array.isArray(keys) ? keys : [keys];
|
||||
const restrictOverride = RESERVED_ACTIONS_FOR_EXTENSION_CONFLICTS.includes(keyAction);
|
||||
for (const key of keyList) {
|
||||
const normalizedKey = key.toLowerCase() as KeyId;
|
||||
builtinKeybindings[normalizedKey] = {
|
||||
action: keyAction,
|
||||
restrictOverride: restrictOverride,
|
||||
};
|
||||
}
|
||||
}
|
||||
return builtinKeybindings;
|
||||
};
|
||||
|
||||
/** Combined result from all before_agent_start handlers */
|
||||
interface BeforeAgentStartCombinedResult {
|
||||
messages?: NonNullable<BeforeAgentStartEventResult["message"]>[];
|
||||
systemPrompt?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Events handled by the generic emit() method.
|
||||
* Events with dedicated emitXxx() methods are excluded for stronger type safety.
|
||||
*/
|
||||
type RunnerEmitEvent = Exclude<
|
||||
ExtensionEvent,
|
||||
| ToolCallEvent
|
||||
| ToolResultEvent
|
||||
| UserBashEvent
|
||||
| ContextEvent
|
||||
| BeforeProviderRequestEvent
|
||||
| BeforeAgentStartEvent
|
||||
| ResourcesDiscoverEvent
|
||||
| InputEvent
|
||||
>;
|
||||
|
||||
type SessionBeforeEvent = Extract<
|
||||
RunnerEmitEvent,
|
||||
{ type: "session_before_switch" | "session_before_fork" | "session_before_compact" | "session_before_tree" }
|
||||
>;
|
||||
|
||||
type SessionBeforeEventResult =
|
||||
| SessionBeforeSwitchResult
|
||||
| SessionBeforeForkResult
|
||||
| SessionBeforeCompactResult
|
||||
| SessionBeforeTreeResult;
|
||||
|
||||
type RunnerEmitResult<TEvent extends RunnerEmitEvent> = TEvent extends { type: "session_before_switch" }
|
||||
? SessionBeforeSwitchResult | undefined
|
||||
: TEvent extends { type: "session_before_fork" }
|
||||
? SessionBeforeForkResult | undefined
|
||||
: TEvent extends { type: "session_before_compact" }
|
||||
? SessionBeforeCompactResult | undefined
|
||||
: TEvent extends { type: "session_before_tree" }
|
||||
? SessionBeforeTreeResult | undefined
|
||||
: undefined;
|
||||
|
||||
export type ExtensionErrorListener = (error: ExtensionError) => void;
|
||||
|
||||
export type NewSessionHandler = (options?: {
|
||||
parentSession?: string;
|
||||
setup?: (sessionManager: SessionManager) => Promise<void>;
|
||||
}) => Promise<{ cancelled: boolean }>;
|
||||
|
||||
export type ForkHandler = (entryId: string) => Promise<{ cancelled: boolean }>;
|
||||
|
||||
export type NavigateTreeHandler = (
|
||||
targetId: string,
|
||||
options?: { summarize?: boolean; customInstructions?: string; replaceInstructions?: boolean; label?: string },
|
||||
) => Promise<{ cancelled: boolean }>;
|
||||
|
||||
export type SwitchSessionHandler = (sessionPath: string) => Promise<{ cancelled: boolean }>;
|
||||
|
||||
export type ReloadHandler = () => Promise<void>;
|
||||
|
||||
export type ShutdownHandler = () => void;
|
||||
|
||||
/**
|
||||
* Helper function to emit session_shutdown event to extensions.
|
||||
* Returns true if the event was emitted, false if there were no handlers.
|
||||
*/
|
||||
export async function emitSessionShutdownEvent(extensionRunner: ExtensionRunner | undefined): Promise<boolean> {
|
||||
if (extensionRunner?.hasHandlers("session_shutdown")) {
|
||||
await extensionRunner.emit({
|
||||
type: "session_shutdown",
|
||||
});
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const noOpUIContext: ExtensionUIContext = {
|
||||
select: async () => undefined,
|
||||
confirm: async () => false,
|
||||
input: async () => undefined,
|
||||
notify: () => {},
|
||||
onTerminalInput: () => () => {},
|
||||
setStatus: () => {},
|
||||
setWorkingMessage: () => {},
|
||||
setWidget: () => {},
|
||||
setFooter: () => {},
|
||||
setHeader: () => {},
|
||||
setTitle: () => {},
|
||||
custom: async () => undefined as never,
|
||||
pasteToEditor: () => {},
|
||||
setEditorText: () => {},
|
||||
getEditorText: () => "",
|
||||
editor: async () => undefined,
|
||||
setEditorComponent: () => {},
|
||||
get theme() {
|
||||
return theme;
|
||||
},
|
||||
getAllThemes: () => [],
|
||||
getTheme: () => undefined,
|
||||
setTheme: (_theme: string | Theme) => ({ success: false, error: "UI not available" }),
|
||||
getToolsExpanded: () => false,
|
||||
setToolsExpanded: () => {},
|
||||
};
|
||||
|
||||
export class ExtensionRunner {
|
||||
private extensions: Extension[];
|
||||
private runtime: ExtensionRuntime;
|
||||
private uiContext: ExtensionUIContext;
|
||||
private cwd: string;
|
||||
private sessionManager: SessionManager;
|
||||
private modelRegistry: ModelRegistry;
|
||||
private errorListeners: Set<ExtensionErrorListener> = new Set();
|
||||
private getModel: () => Model<any> | undefined = () => undefined;
|
||||
private isIdleFn: () => boolean = () => true;
|
||||
private waitForIdleFn: () => Promise<void> = async () => {};
|
||||
private abortFn: () => void = () => {};
|
||||
private hasPendingMessagesFn: () => boolean = () => false;
|
||||
private getContextUsageFn: () => ContextUsage | undefined = () => undefined;
|
||||
private compactFn: (options?: CompactOptions) => void = () => {};
|
||||
private getSystemPromptFn: () => string = () => "";
|
||||
private newSessionHandler: NewSessionHandler = async () => ({ cancelled: false });
|
||||
private forkHandler: ForkHandler = async () => ({ cancelled: false });
|
||||
private navigateTreeHandler: NavigateTreeHandler = async () => ({ cancelled: false });
|
||||
private switchSessionHandler: SwitchSessionHandler = async () => ({ cancelled: false });
|
||||
private reloadHandler: ReloadHandler = async () => {};
|
||||
private shutdownHandler: ShutdownHandler = () => {};
|
||||
private shortcutDiagnostics: ResourceDiagnostic[] = [];
|
||||
private commandDiagnostics: ResourceDiagnostic[] = [];
|
||||
|
||||
constructor(
|
||||
extensions: Extension[],
|
||||
runtime: ExtensionRuntime,
|
||||
cwd: string,
|
||||
sessionManager: SessionManager,
|
||||
modelRegistry: ModelRegistry,
|
||||
) {
|
||||
this.extensions = extensions;
|
||||
this.runtime = runtime;
|
||||
this.uiContext = noOpUIContext;
|
||||
this.cwd = cwd;
|
||||
this.sessionManager = sessionManager;
|
||||
this.modelRegistry = modelRegistry;
|
||||
}
|
||||
|
||||
bindCore(actions: ExtensionActions, contextActions: ExtensionContextActions): void {
|
||||
// Copy actions into the shared runtime (all extension APIs reference this)
|
||||
this.runtime.sendMessage = actions.sendMessage;
|
||||
this.runtime.sendUserMessage = actions.sendUserMessage;
|
||||
this.runtime.appendEntry = actions.appendEntry;
|
||||
this.runtime.setSessionName = actions.setSessionName;
|
||||
this.runtime.getSessionName = actions.getSessionName;
|
||||
this.runtime.setLabel = actions.setLabel;
|
||||
this.runtime.getActiveTools = actions.getActiveTools;
|
||||
this.runtime.getAllTools = actions.getAllTools;
|
||||
this.runtime.setActiveTools = actions.setActiveTools;
|
||||
this.runtime.refreshTools = actions.refreshTools;
|
||||
this.runtime.getCommands = actions.getCommands;
|
||||
this.runtime.setModel = actions.setModel;
|
||||
this.runtime.getThinkingLevel = actions.getThinkingLevel;
|
||||
this.runtime.setThinkingLevel = actions.setThinkingLevel;
|
||||
|
||||
// Context actions (required)
|
||||
this.getModel = contextActions.getModel;
|
||||
this.isIdleFn = contextActions.isIdle;
|
||||
this.abortFn = contextActions.abort;
|
||||
this.hasPendingMessagesFn = contextActions.hasPendingMessages;
|
||||
this.shutdownHandler = contextActions.shutdown;
|
||||
this.getContextUsageFn = contextActions.getContextUsage;
|
||||
this.compactFn = contextActions.compact;
|
||||
this.getSystemPromptFn = contextActions.getSystemPrompt;
|
||||
|
||||
// Flush provider registrations queued during extension loading
|
||||
for (const { name, config } of this.runtime.pendingProviderRegistrations) {
|
||||
this.modelRegistry.registerProvider(name, config);
|
||||
}
|
||||
this.runtime.pendingProviderRegistrations = [];
|
||||
|
||||
// From this point on, provider registration/unregistration takes effect immediately
|
||||
// without requiring a /reload.
|
||||
this.runtime.registerProvider = (name, config) => this.modelRegistry.registerProvider(name, config);
|
||||
this.runtime.unregisterProvider = (name) => this.modelRegistry.unregisterProvider(name);
|
||||
}
|
||||
|
||||
bindCommandContext(actions?: ExtensionCommandContextActions): void {
|
||||
if (actions) {
|
||||
this.waitForIdleFn = actions.waitForIdle;
|
||||
this.newSessionHandler = actions.newSession;
|
||||
this.forkHandler = actions.fork;
|
||||
this.navigateTreeHandler = actions.navigateTree;
|
||||
this.switchSessionHandler = actions.switchSession;
|
||||
this.reloadHandler = actions.reload;
|
||||
return;
|
||||
}
|
||||
|
||||
this.waitForIdleFn = async () => {};
|
||||
this.newSessionHandler = async () => ({ cancelled: false });
|
||||
this.forkHandler = async () => ({ cancelled: false });
|
||||
this.navigateTreeHandler = async () => ({ cancelled: false });
|
||||
this.switchSessionHandler = async () => ({ cancelled: false });
|
||||
this.reloadHandler = async () => {};
|
||||
}
|
||||
|
||||
setUIContext(uiContext?: ExtensionUIContext): void {
|
||||
this.uiContext = uiContext ?? noOpUIContext;
|
||||
}
|
||||
|
||||
getUIContext(): ExtensionUIContext {
|
||||
return this.uiContext;
|
||||
}
|
||||
|
||||
hasUI(): boolean {
|
||||
return this.uiContext !== noOpUIContext;
|
||||
}
|
||||
|
||||
getExtensionPaths(): string[] {
|
||||
return this.extensions.map((e) => e.path);
|
||||
}
|
||||
|
||||
/** Get all registered tools from all extensions (first registration per name wins). */
|
||||
getAllRegisteredTools(): RegisteredTool[] {
|
||||
const toolsByName = new Map<string, RegisteredTool>();
|
||||
for (const ext of this.extensions) {
|
||||
for (const tool of ext.tools.values()) {
|
||||
if (!toolsByName.has(tool.definition.name)) {
|
||||
toolsByName.set(tool.definition.name, tool);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Array.from(toolsByName.values());
|
||||
}
|
||||
|
||||
/** Get a tool definition by name. Returns undefined if not found. */
|
||||
getToolDefinition(toolName: string): RegisteredTool["definition"] | undefined {
|
||||
for (const ext of this.extensions) {
|
||||
const tool = ext.tools.get(toolName);
|
||||
if (tool) {
|
||||
return tool.definition;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
getFlags(): Map<string, ExtensionFlag> {
|
||||
const allFlags = new Map<string, ExtensionFlag>();
|
||||
for (const ext of this.extensions) {
|
||||
for (const [name, flag] of ext.flags) {
|
||||
if (!allFlags.has(name)) {
|
||||
allFlags.set(name, flag);
|
||||
}
|
||||
}
|
||||
}
|
||||
return allFlags;
|
||||
}
|
||||
|
||||
setFlagValue(name: string, value: boolean | string): void {
|
||||
this.runtime.flagValues.set(name, value);
|
||||
}
|
||||
|
||||
getFlagValues(): Map<string, boolean | string> {
|
||||
return new Map(this.runtime.flagValues);
|
||||
}
|
||||
|
||||
getShortcuts(effectiveKeybindings: Required<KeybindingsConfig>): Map<KeyId, ExtensionShortcut> {
|
||||
this.shortcutDiagnostics = [];
|
||||
const builtinKeybindings = buildBuiltinKeybindings(effectiveKeybindings);
|
||||
const extensionShortcuts = new Map<KeyId, ExtensionShortcut>();
|
||||
|
||||
const addDiagnostic = (message: string, extensionPath: string) => {
|
||||
this.shortcutDiagnostics.push({ type: "warning", message, path: extensionPath });
|
||||
if (!this.hasUI()) {
|
||||
console.warn(message);
|
||||
}
|
||||
};
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
for (const [key, shortcut] of ext.shortcuts) {
|
||||
const normalizedKey = key.toLowerCase() as KeyId;
|
||||
|
||||
const builtInKeybinding = builtinKeybindings[normalizedKey];
|
||||
if (builtInKeybinding?.restrictOverride === true) {
|
||||
addDiagnostic(
|
||||
`Extension shortcut '${key}' from ${shortcut.extensionPath} conflicts with built-in shortcut. Skipping.`,
|
||||
shortcut.extensionPath,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (builtInKeybinding?.restrictOverride === false) {
|
||||
addDiagnostic(
|
||||
`Extension shortcut conflict: '${key}' is built-in shortcut for ${builtInKeybinding.action} and ${shortcut.extensionPath}. Using ${shortcut.extensionPath}.`,
|
||||
shortcut.extensionPath,
|
||||
);
|
||||
}
|
||||
|
||||
const existingExtensionShortcut = extensionShortcuts.get(normalizedKey);
|
||||
if (existingExtensionShortcut) {
|
||||
addDiagnostic(
|
||||
`Extension shortcut conflict: '${key}' registered by both ${existingExtensionShortcut.extensionPath} and ${shortcut.extensionPath}. Using ${shortcut.extensionPath}.`,
|
||||
shortcut.extensionPath,
|
||||
);
|
||||
}
|
||||
extensionShortcuts.set(normalizedKey, shortcut);
|
||||
}
|
||||
}
|
||||
return extensionShortcuts;
|
||||
}
|
||||
|
||||
getShortcutDiagnostics(): ResourceDiagnostic[] {
|
||||
return this.shortcutDiagnostics;
|
||||
}
|
||||
|
||||
onError(listener: ExtensionErrorListener): () => void {
|
||||
this.errorListeners.add(listener);
|
||||
return () => this.errorListeners.delete(listener);
|
||||
}
|
||||
|
||||
emitError(error: ExtensionError): void {
|
||||
for (const listener of this.errorListeners) {
|
||||
listener(error);
|
||||
}
|
||||
}
|
||||
|
||||
hasHandlers(eventType: string): boolean {
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get(eventType);
|
||||
if (handlers && handlers.length > 0) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
getMessageRenderer(customType: string): MessageRenderer | undefined {
|
||||
for (const ext of this.extensions) {
|
||||
const renderer = ext.messageRenderers.get(customType);
|
||||
if (renderer) {
|
||||
return renderer;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
getRegisteredCommands(reserved?: Set<string>): RegisteredCommand[] {
|
||||
this.commandDiagnostics = [];
|
||||
|
||||
const commands: RegisteredCommand[] = [];
|
||||
const commandOwners = new Map<string, string>();
|
||||
for (const ext of this.extensions) {
|
||||
for (const command of ext.commands.values()) {
|
||||
if (reserved?.has(command.name)) {
|
||||
const message = `Extension command '${command.name}' from ${ext.path} conflicts with built-in commands. Skipping.`;
|
||||
this.commandDiagnostics.push({ type: "warning", message, path: ext.path });
|
||||
if (!this.hasUI()) {
|
||||
console.warn(message);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const existingOwner = commandOwners.get(command.name);
|
||||
if (existingOwner) {
|
||||
const message = `Extension command '${command.name}' from ${ext.path} conflicts with ${existingOwner}. Skipping.`;
|
||||
this.commandDiagnostics.push({ type: "warning", message, path: ext.path });
|
||||
if (!this.hasUI()) {
|
||||
console.warn(message);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
commandOwners.set(command.name, ext.path);
|
||||
commands.push(command);
|
||||
}
|
||||
}
|
||||
return commands;
|
||||
}
|
||||
|
||||
getCommandDiagnostics(): ResourceDiagnostic[] {
|
||||
return this.commandDiagnostics;
|
||||
}
|
||||
|
||||
getRegisteredCommandsWithPaths(): Array<{ command: RegisteredCommand; extensionPath: string }> {
|
||||
const result: Array<{ command: RegisteredCommand; extensionPath: string }> = [];
|
||||
for (const ext of this.extensions) {
|
||||
for (const command of ext.commands.values()) {
|
||||
result.push({ command, extensionPath: ext.path });
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
getCommand(name: string): RegisteredCommand | undefined {
|
||||
for (const ext of this.extensions) {
|
||||
const command = ext.commands.get(name);
|
||||
if (command) {
|
||||
return command;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Request a graceful shutdown. Called by extension tools and event handlers.
|
||||
* The actual shutdown behavior is provided by the mode via bindExtensions().
|
||||
*/
|
||||
shutdown(): void {
|
||||
this.shutdownHandler();
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an ExtensionContext for use in event handlers and tool execution.
|
||||
* Context values are resolved at call time, so changes via bindCore/bindUI are reflected.
|
||||
*/
|
||||
createContext(): ExtensionContext {
|
||||
const getModel = this.getModel;
|
||||
return {
|
||||
ui: this.uiContext,
|
||||
hasUI: this.hasUI(),
|
||||
cwd: this.cwd,
|
||||
sessionManager: this.sessionManager,
|
||||
modelRegistry: this.modelRegistry,
|
||||
get model() {
|
||||
return getModel();
|
||||
},
|
||||
isIdle: () => this.isIdleFn(),
|
||||
abort: () => this.abortFn(),
|
||||
hasPendingMessages: () => this.hasPendingMessagesFn(),
|
||||
shutdown: () => this.shutdownHandler(),
|
||||
getContextUsage: () => this.getContextUsageFn(),
|
||||
compact: (options) => this.compactFn(options),
|
||||
getSystemPrompt: () => this.getSystemPromptFn(),
|
||||
};
|
||||
}
|
||||
|
||||
createCommandContext(): ExtensionCommandContext {
|
||||
return {
|
||||
...this.createContext(),
|
||||
waitForIdle: () => this.waitForIdleFn(),
|
||||
newSession: (options) => this.newSessionHandler(options),
|
||||
fork: (entryId) => this.forkHandler(entryId),
|
||||
navigateTree: (targetId, options) => this.navigateTreeHandler(targetId, options),
|
||||
switchSession: (sessionPath) => this.switchSessionHandler(sessionPath),
|
||||
reload: () => this.reloadHandler(),
|
||||
};
|
||||
}
|
||||
|
||||
private isSessionBeforeEvent(event: RunnerEmitEvent): event is SessionBeforeEvent {
|
||||
return (
|
||||
event.type === "session_before_switch" ||
|
||||
event.type === "session_before_fork" ||
|
||||
event.type === "session_before_compact" ||
|
||||
event.type === "session_before_tree"
|
||||
);
|
||||
}
|
||||
|
||||
async emit<TEvent extends RunnerEmitEvent>(event: TEvent): Promise<RunnerEmitResult<TEvent>> {
|
||||
const ctx = this.createContext();
|
||||
let result: SessionBeforeEventResult | undefined;
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get(event.type);
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
try {
|
||||
const handlerResult = await handler(event, ctx);
|
||||
|
||||
if (this.isSessionBeforeEvent(event) && handlerResult) {
|
||||
result = handlerResult as SessionBeforeEventResult;
|
||||
if (result.cancel) {
|
||||
return result as RunnerEmitResult<TEvent>;
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
const stack = err instanceof Error ? err.stack : undefined;
|
||||
this.emitError({
|
||||
extensionPath: ext.path,
|
||||
event: event.type,
|
||||
error: message,
|
||||
stack,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result as RunnerEmitResult<TEvent>;
|
||||
}
|
||||
|
||||
async emitToolResult(event: ToolResultEvent): Promise<ToolResultEventResult | undefined> {
|
||||
const ctx = this.createContext();
|
||||
const currentEvent: ToolResultEvent = { ...event };
|
||||
let modified = false;
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get("tool_result");
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
try {
|
||||
const handlerResult = (await handler(currentEvent, ctx)) as ToolResultEventResult | undefined;
|
||||
if (!handlerResult) continue;
|
||||
|
||||
if (handlerResult.content !== undefined) {
|
||||
currentEvent.content = handlerResult.content;
|
||||
modified = true;
|
||||
}
|
||||
if (handlerResult.details !== undefined) {
|
||||
currentEvent.details = handlerResult.details;
|
||||
modified = true;
|
||||
}
|
||||
if (handlerResult.isError !== undefined) {
|
||||
currentEvent.isError = handlerResult.isError;
|
||||
modified = true;
|
||||
}
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
const stack = err instanceof Error ? err.stack : undefined;
|
||||
this.emitError({
|
||||
extensionPath: ext.path,
|
||||
event: "tool_result",
|
||||
error: message,
|
||||
stack,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!modified) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return {
|
||||
content: currentEvent.content,
|
||||
details: currentEvent.details,
|
||||
isError: currentEvent.isError,
|
||||
};
|
||||
}
|
||||
|
||||
async emitToolCall(event: ToolCallEvent): Promise<ToolCallEventResult | undefined> {
|
||||
const ctx = this.createContext();
|
||||
let result: ToolCallEventResult | undefined;
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get("tool_call");
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
const handlerResult = await handler(event, ctx);
|
||||
|
||||
if (handlerResult) {
|
||||
result = handlerResult as ToolCallEventResult;
|
||||
if (result.block) {
|
||||
return result;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
async emitUserBash(event: UserBashEvent): Promise<UserBashEventResult | undefined> {
|
||||
const ctx = this.createContext();
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get("user_bash");
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
try {
|
||||
const handlerResult = await handler(event, ctx);
|
||||
if (handlerResult) {
|
||||
return handlerResult as UserBashEventResult;
|
||||
}
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
const stack = err instanceof Error ? err.stack : undefined;
|
||||
this.emitError({
|
||||
extensionPath: ext.path,
|
||||
event: "user_bash",
|
||||
error: message,
|
||||
stack,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
async emitContext(messages: AgentMessage[]): Promise<AgentMessage[]> {
|
||||
const ctx = this.createContext();
|
||||
let currentMessages = structuredClone(messages);
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get("context");
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
try {
|
||||
const event: ContextEvent = { type: "context", messages: currentMessages };
|
||||
const handlerResult = await handler(event, ctx);
|
||||
|
||||
if (handlerResult && (handlerResult as ContextEventResult).messages) {
|
||||
currentMessages = (handlerResult as ContextEventResult).messages!;
|
||||
}
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
const stack = err instanceof Error ? err.stack : undefined;
|
||||
this.emitError({
|
||||
extensionPath: ext.path,
|
||||
event: "context",
|
||||
error: message,
|
||||
stack,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return currentMessages;
|
||||
}
|
||||
|
||||
async emitBeforeProviderRequest(payload: unknown): Promise<unknown> {
|
||||
const ctx = this.createContext();
|
||||
let currentPayload = payload;
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get("before_provider_request");
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
try {
|
||||
const event: BeforeProviderRequestEvent = {
|
||||
type: "before_provider_request",
|
||||
payload: currentPayload,
|
||||
};
|
||||
const handlerResult = await handler(event, ctx);
|
||||
if (handlerResult !== undefined) {
|
||||
currentPayload = handlerResult;
|
||||
}
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
const stack = err instanceof Error ? err.stack : undefined;
|
||||
this.emitError({
|
||||
extensionPath: ext.path,
|
||||
event: "before_provider_request",
|
||||
error: message,
|
||||
stack,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return currentPayload;
|
||||
}
|
||||
|
||||
async emitBeforeAgentStart(
|
||||
prompt: string,
|
||||
images: ImageContent[] | undefined,
|
||||
systemPrompt: string,
|
||||
): Promise<BeforeAgentStartCombinedResult | undefined> {
|
||||
const ctx = this.createContext();
|
||||
const messages: NonNullable<BeforeAgentStartEventResult["message"]>[] = [];
|
||||
let currentSystemPrompt = systemPrompt;
|
||||
let systemPromptModified = false;
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get("before_agent_start");
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
try {
|
||||
const event: BeforeAgentStartEvent = {
|
||||
type: "before_agent_start",
|
||||
prompt,
|
||||
images,
|
||||
systemPrompt: currentSystemPrompt,
|
||||
};
|
||||
const handlerResult = await handler(event, ctx);
|
||||
|
||||
if (handlerResult) {
|
||||
const result = handlerResult as BeforeAgentStartEventResult;
|
||||
if (result.message) {
|
||||
messages.push(result.message);
|
||||
}
|
||||
if (result.systemPrompt !== undefined) {
|
||||
currentSystemPrompt = result.systemPrompt;
|
||||
systemPromptModified = true;
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
const stack = err instanceof Error ? err.stack : undefined;
|
||||
this.emitError({
|
||||
extensionPath: ext.path,
|
||||
event: "before_agent_start",
|
||||
error: message,
|
||||
stack,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (messages.length > 0 || systemPromptModified) {
|
||||
return {
|
||||
messages: messages.length > 0 ? messages : undefined,
|
||||
systemPrompt: systemPromptModified ? currentSystemPrompt : undefined,
|
||||
};
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
async emitResourcesDiscover(
|
||||
cwd: string,
|
||||
reason: ResourcesDiscoverEvent["reason"],
|
||||
): Promise<{
|
||||
skillPaths: Array<{ path: string; extensionPath: string }>;
|
||||
promptPaths: Array<{ path: string; extensionPath: string }>;
|
||||
themePaths: Array<{ path: string; extensionPath: string }>;
|
||||
}> {
|
||||
const ctx = this.createContext();
|
||||
const skillPaths: Array<{ path: string; extensionPath: string }> = [];
|
||||
const promptPaths: Array<{ path: string; extensionPath: string }> = [];
|
||||
const themePaths: Array<{ path: string; extensionPath: string }> = [];
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
const handlers = ext.handlers.get("resources_discover");
|
||||
if (!handlers || handlers.length === 0) continue;
|
||||
|
||||
for (const handler of handlers) {
|
||||
try {
|
||||
const event: ResourcesDiscoverEvent = { type: "resources_discover", cwd, reason };
|
||||
const handlerResult = await handler(event, ctx);
|
||||
const result = handlerResult as ResourcesDiscoverResult | undefined;
|
||||
|
||||
if (result?.skillPaths?.length) {
|
||||
skillPaths.push(...result.skillPaths.map((path) => ({ path, extensionPath: ext.path })));
|
||||
}
|
||||
if (result?.promptPaths?.length) {
|
||||
promptPaths.push(...result.promptPaths.map((path) => ({ path, extensionPath: ext.path })));
|
||||
}
|
||||
if (result?.themePaths?.length) {
|
||||
themePaths.push(...result.themePaths.map((path) => ({ path, extensionPath: ext.path })));
|
||||
}
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : String(err);
|
||||
const stack = err instanceof Error ? err.stack : undefined;
|
||||
this.emitError({
|
||||
extensionPath: ext.path,
|
||||
event: "resources_discover",
|
||||
error: message,
|
||||
stack,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { skillPaths, promptPaths, themePaths };
|
||||
}
|
||||
|
||||
/** Emit input event. Transforms chain, "handled" short-circuits. */
|
||||
async emitInput(text: string, images: ImageContent[] | undefined, source: InputSource): Promise<InputEventResult> {
|
||||
const ctx = this.createContext();
|
||||
let currentText = text;
|
||||
let currentImages = images;
|
||||
|
||||
for (const ext of this.extensions) {
|
||||
for (const handler of ext.handlers.get("input") ?? []) {
|
||||
try {
|
||||
const event: InputEvent = { type: "input", text: currentText, images: currentImages, source };
|
||||
const result = (await handler(event, ctx)) as InputEventResult | undefined;
|
||||
if (result?.action === "handled") return result;
|
||||
if (result?.action === "transform") {
|
||||
currentText = result.text;
|
||||
currentImages = result.images ?? currentImages;
|
||||
}
|
||||
} catch (err) {
|
||||
this.emitError({
|
||||
extensionPath: ext.path,
|
||||
event: "input",
|
||||
error: err instanceof Error ? err.message : String(err),
|
||||
stack: err instanceof Error ? err.stack : undefined,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
return currentText !== text || currentImages !== images
|
||||
? { action: "transform", text: currentText, images: currentImages }
|
||||
: { action: "continue" };
|
||||
}
|
||||
}
|
||||
1411
packages/pi-coding-agent/src/core/extensions/types.ts
Normal file
1411
packages/pi-coding-agent/src/core/extensions/types.ts
Normal file
File diff suppressed because it is too large
Load diff
118
packages/pi-coding-agent/src/core/extensions/wrapper.ts
Normal file
118
packages/pi-coding-agent/src/core/extensions/wrapper.ts
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
/**
|
||||
* Tool wrappers for extensions.
|
||||
*/
|
||||
|
||||
import type { AgentTool, AgentToolUpdateCallback } from "@gsd/pi-agent-core";
|
||||
import type { ExtensionRunner } from "./runner.js";
|
||||
import type { RegisteredTool, ToolCallEventResult } from "./types.js";
|
||||
|
||||
/**
|
||||
* Wrap a RegisteredTool into an AgentTool.
|
||||
* Uses the runner's createContext() for consistent context across tools and event handlers.
|
||||
*/
|
||||
export function wrapRegisteredTool(registeredTool: RegisteredTool, runner: ExtensionRunner): AgentTool {
|
||||
const { definition } = registeredTool;
|
||||
return {
|
||||
name: definition.name,
|
||||
label: definition.label,
|
||||
description: definition.description,
|
||||
parameters: definition.parameters,
|
||||
execute: (toolCallId, params, signal, onUpdate) =>
|
||||
definition.execute(toolCallId, params, signal, onUpdate, runner.createContext()),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Wrap all registered tools into AgentTools.
|
||||
* Uses the runner's createContext() for consistent context across tools and event handlers.
|
||||
*/
|
||||
export function wrapRegisteredTools(registeredTools: RegisteredTool[], runner: ExtensionRunner): AgentTool[] {
|
||||
return registeredTools.map((rt) => wrapRegisteredTool(rt, runner));
|
||||
}
|
||||
|
||||
/**
|
||||
* Wrap a tool with extension callbacks for interception.
|
||||
* - Emits tool_call event before execution (can block)
|
||||
* - Emits tool_result event after execution (can modify result)
|
||||
*/
|
||||
export function wrapToolWithExtensions<T>(tool: AgentTool<any, T>, runner: ExtensionRunner): AgentTool<any, T> {
|
||||
return {
|
||||
...tool,
|
||||
execute: async (
|
||||
toolCallId: string,
|
||||
params: Record<string, unknown>,
|
||||
signal?: AbortSignal,
|
||||
onUpdate?: AgentToolUpdateCallback<T>,
|
||||
) => {
|
||||
// Emit tool_call event - extensions can block execution
|
||||
if (runner.hasHandlers("tool_call")) {
|
||||
try {
|
||||
const callResult = (await runner.emitToolCall({
|
||||
type: "tool_call",
|
||||
toolName: tool.name,
|
||||
toolCallId,
|
||||
input: params,
|
||||
})) as ToolCallEventResult | undefined;
|
||||
|
||||
if (callResult?.block) {
|
||||
const reason = callResult.reason || "Tool execution was blocked by an extension";
|
||||
throw new Error(reason);
|
||||
}
|
||||
} catch (err) {
|
||||
if (err instanceof Error) {
|
||||
throw err;
|
||||
}
|
||||
throw new Error(`Extension failed, blocking execution: ${String(err)}`);
|
||||
}
|
||||
}
|
||||
|
||||
// Execute the actual tool
|
||||
try {
|
||||
const result = await tool.execute(toolCallId, params, signal, onUpdate);
|
||||
|
||||
// Emit tool_result event - extensions can modify the result
|
||||
if (runner.hasHandlers("tool_result")) {
|
||||
const resultResult = await runner.emitToolResult({
|
||||
type: "tool_result",
|
||||
toolName: tool.name,
|
||||
toolCallId,
|
||||
input: params,
|
||||
content: result.content,
|
||||
details: result.details,
|
||||
isError: false,
|
||||
});
|
||||
|
||||
if (resultResult) {
|
||||
return {
|
||||
content: resultResult.content ?? result.content,
|
||||
details: (resultResult.details ?? result.details) as T,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return result;
|
||||
} catch (err) {
|
||||
// Emit tool_result event for errors
|
||||
if (runner.hasHandlers("tool_result")) {
|
||||
await runner.emitToolResult({
|
||||
type: "tool_result",
|
||||
toolName: tool.name,
|
||||
toolCallId,
|
||||
input: params,
|
||||
content: [{ type: "text", text: err instanceof Error ? err.message : String(err) }],
|
||||
details: undefined,
|
||||
isError: true,
|
||||
});
|
||||
}
|
||||
throw err;
|
||||
}
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Wrap all tools with extension callbacks.
|
||||
*/
|
||||
export function wrapToolsWithExtensions<T>(tools: AgentTool<any, T>[], runner: ExtensionRunner): AgentTool<any, T>[] {
|
||||
return tools.map((tool) => wrapToolWithExtensions(tool, runner));
|
||||
}
|
||||
144
packages/pi-coding-agent/src/core/footer-data-provider.ts
Normal file
144
packages/pi-coding-agent/src/core/footer-data-provider.ts
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
import { existsSync, type FSWatcher, readFileSync, statSync, watch } from "fs";
|
||||
import { dirname, join, resolve } from "path";
|
||||
|
||||
/**
|
||||
* Find the git HEAD path by walking up from cwd.
|
||||
* Handles both regular git repos (.git is a directory) and worktrees (.git is a file).
|
||||
*/
|
||||
function findGitHeadPath(): string | null {
|
||||
let dir = process.cwd();
|
||||
while (true) {
|
||||
const gitPath = join(dir, ".git");
|
||||
if (existsSync(gitPath)) {
|
||||
try {
|
||||
const stat = statSync(gitPath);
|
||||
if (stat.isFile()) {
|
||||
const content = readFileSync(gitPath, "utf8").trim();
|
||||
if (content.startsWith("gitdir: ")) {
|
||||
const gitDir = content.slice(8);
|
||||
const headPath = resolve(dir, gitDir, "HEAD");
|
||||
if (existsSync(headPath)) return headPath;
|
||||
}
|
||||
} else if (stat.isDirectory()) {
|
||||
const headPath = join(gitPath, "HEAD");
|
||||
if (existsSync(headPath)) return headPath;
|
||||
}
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
const parent = dirname(dir);
|
||||
if (parent === dir) return null;
|
||||
dir = parent;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Provides git branch and extension statuses - data not otherwise accessible to extensions.
|
||||
* Token stats, model info available via ctx.sessionManager and ctx.model.
|
||||
*/
|
||||
export class FooterDataProvider {
|
||||
private extensionStatuses = new Map<string, string>();
|
||||
private cachedBranch: string | null | undefined = undefined;
|
||||
private gitWatcher: FSWatcher | null = null;
|
||||
private branchChangeCallbacks = new Set<() => void>();
|
||||
private availableProviderCount = 0;
|
||||
|
||||
constructor() {
|
||||
this.setupGitWatcher();
|
||||
}
|
||||
|
||||
/** Current git branch, null if not in repo, "detached" if detached HEAD */
|
||||
getGitBranch(): string | null {
|
||||
if (this.cachedBranch !== undefined) return this.cachedBranch;
|
||||
|
||||
try {
|
||||
const gitHeadPath = findGitHeadPath();
|
||||
if (!gitHeadPath) {
|
||||
this.cachedBranch = null;
|
||||
return null;
|
||||
}
|
||||
const content = readFileSync(gitHeadPath, "utf8").trim();
|
||||
this.cachedBranch = content.startsWith("ref: refs/heads/") ? content.slice(16) : "detached";
|
||||
} catch {
|
||||
this.cachedBranch = null;
|
||||
}
|
||||
return this.cachedBranch;
|
||||
}
|
||||
|
||||
/** Extension status texts set via ctx.ui.setStatus() */
|
||||
getExtensionStatuses(): ReadonlyMap<string, string> {
|
||||
return this.extensionStatuses;
|
||||
}
|
||||
|
||||
/** Subscribe to git branch changes. Returns unsubscribe function. */
|
||||
onBranchChange(callback: () => void): () => void {
|
||||
this.branchChangeCallbacks.add(callback);
|
||||
return () => this.branchChangeCallbacks.delete(callback);
|
||||
}
|
||||
|
||||
/** Internal: set extension status */
|
||||
setExtensionStatus(key: string, text: string | undefined): void {
|
||||
if (text === undefined) {
|
||||
this.extensionStatuses.delete(key);
|
||||
} else {
|
||||
this.extensionStatuses.set(key, text);
|
||||
}
|
||||
}
|
||||
|
||||
/** Internal: clear extension statuses */
|
||||
clearExtensionStatuses(): void {
|
||||
this.extensionStatuses.clear();
|
||||
}
|
||||
|
||||
/** Number of unique providers with available models (for footer display) */
|
||||
getAvailableProviderCount(): number {
|
||||
return this.availableProviderCount;
|
||||
}
|
||||
|
||||
/** Internal: update available provider count */
|
||||
setAvailableProviderCount(count: number): void {
|
||||
this.availableProviderCount = count;
|
||||
}
|
||||
|
||||
/** Internal: cleanup */
|
||||
dispose(): void {
|
||||
if (this.gitWatcher) {
|
||||
this.gitWatcher.close();
|
||||
this.gitWatcher = null;
|
||||
}
|
||||
this.branchChangeCallbacks.clear();
|
||||
}
|
||||
|
||||
private setupGitWatcher(): void {
|
||||
if (this.gitWatcher) {
|
||||
this.gitWatcher.close();
|
||||
this.gitWatcher = null;
|
||||
}
|
||||
|
||||
const gitHeadPath = findGitHeadPath();
|
||||
if (!gitHeadPath) return;
|
||||
|
||||
// Watch the directory containing HEAD, not HEAD itself.
|
||||
// Git uses atomic writes (write temp, rename over HEAD), which changes the inode.
|
||||
// fs.watch on a file stops working after the inode changes.
|
||||
const gitDir = dirname(gitHeadPath);
|
||||
|
||||
try {
|
||||
this.gitWatcher = watch(gitDir, (_eventType, filename) => {
|
||||
if (filename === "HEAD") {
|
||||
this.cachedBranch = undefined;
|
||||
for (const cb of this.branchChangeCallbacks) cb();
|
||||
}
|
||||
});
|
||||
} catch {
|
||||
// Silently fail if we can't watch
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Read-only view for extensions - excludes setExtensionStatus, setAvailableProviderCount and dispose */
|
||||
export type ReadonlyFooterDataProvider = Pick<
|
||||
FooterDataProvider,
|
||||
"getGitBranch" | "getExtensionStatuses" | "getAvailableProviderCount" | "onBranchChange"
|
||||
>;
|
||||
61
packages/pi-coding-agent/src/core/index.ts
Normal file
61
packages/pi-coding-agent/src/core/index.ts
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Core modules shared between all run modes.
|
||||
*/
|
||||
|
||||
export {
|
||||
AgentSession,
|
||||
type AgentSessionConfig,
|
||||
type AgentSessionEvent,
|
||||
type AgentSessionEventListener,
|
||||
type ModelCycleResult,
|
||||
type PromptOptions,
|
||||
type SessionStats,
|
||||
} from "./agent-session.js";
|
||||
export { type BashExecutorOptions, type BashResult, executeBash, executeBashWithOperations } from "./bash-executor.js";
|
||||
export type { CompactionResult } from "./compaction/index.js";
|
||||
export { createEventBus, type EventBus, type EventBusController } from "./event-bus.js";
|
||||
|
||||
// Extensions system
|
||||
export {
|
||||
type AgentEndEvent,
|
||||
type AgentStartEvent,
|
||||
type AgentToolResult,
|
||||
type AgentToolUpdateCallback,
|
||||
type BeforeAgentStartEvent,
|
||||
type ContextEvent,
|
||||
discoverAndLoadExtensions,
|
||||
type ExecOptions,
|
||||
type ExecResult,
|
||||
type Extension,
|
||||
type ExtensionAPI,
|
||||
type ExtensionCommandContext,
|
||||
type ExtensionContext,
|
||||
type ExtensionError,
|
||||
type ExtensionEvent,
|
||||
type ExtensionFactory,
|
||||
type ExtensionFlag,
|
||||
type ExtensionHandler,
|
||||
ExtensionRunner,
|
||||
type ExtensionShortcut,
|
||||
type ExtensionUIContext,
|
||||
type LoadExtensionsResult,
|
||||
type MessageRenderer,
|
||||
type RegisteredCommand,
|
||||
type SessionBeforeCompactEvent,
|
||||
type SessionBeforeForkEvent,
|
||||
type SessionBeforeSwitchEvent,
|
||||
type SessionBeforeTreeEvent,
|
||||
type SessionCompactEvent,
|
||||
type SessionForkEvent,
|
||||
type SessionShutdownEvent,
|
||||
type SessionStartEvent,
|
||||
type SessionSwitchEvent,
|
||||
type SessionTreeEvent,
|
||||
type ToolCallEvent,
|
||||
type ToolDefinition,
|
||||
type ToolRenderResultOptions,
|
||||
type ToolResultEvent,
|
||||
type TurnEndEvent,
|
||||
type TurnStartEvent,
|
||||
wrapToolsWithExtensions,
|
||||
} from "./extensions/index.js";
|
||||
211
packages/pi-coding-agent/src/core/keybindings.ts
Normal file
211
packages/pi-coding-agent/src/core/keybindings.ts
Normal file
|
|
@ -0,0 +1,211 @@
|
|||
import {
|
||||
DEFAULT_EDITOR_KEYBINDINGS,
|
||||
type EditorAction,
|
||||
type EditorKeybindingsConfig,
|
||||
EditorKeybindingsManager,
|
||||
type KeyId,
|
||||
matchesKey,
|
||||
setEditorKeybindings,
|
||||
} from "@gsd/pi-tui";
|
||||
import { existsSync, readFileSync } from "fs";
|
||||
import { join } from "path";
|
||||
import { getAgentDir } from "../config.js";
|
||||
|
||||
/**
|
||||
* Application-level actions (coding agent specific).
|
||||
*/
|
||||
export type AppAction =
|
||||
| "interrupt"
|
||||
| "clear"
|
||||
| "exit"
|
||||
| "suspend"
|
||||
| "cycleThinkingLevel"
|
||||
| "cycleModelForward"
|
||||
| "cycleModelBackward"
|
||||
| "selectModel"
|
||||
| "expandTools"
|
||||
| "toggleThinking"
|
||||
| "toggleSessionNamedFilter"
|
||||
| "externalEditor"
|
||||
| "followUp"
|
||||
| "dequeue"
|
||||
| "pasteImage"
|
||||
| "newSession"
|
||||
| "tree"
|
||||
| "fork"
|
||||
| "resume";
|
||||
|
||||
/**
|
||||
* All configurable actions.
|
||||
*/
|
||||
export type KeyAction = AppAction | EditorAction;
|
||||
|
||||
/**
|
||||
* Full keybindings configuration (app + editor actions).
|
||||
*/
|
||||
export type KeybindingsConfig = {
|
||||
[K in KeyAction]?: KeyId | KeyId[];
|
||||
};
|
||||
|
||||
/**
|
||||
* Default application keybindings.
|
||||
*/
|
||||
export const DEFAULT_APP_KEYBINDINGS: Record<AppAction, KeyId | KeyId[]> = {
|
||||
interrupt: "escape",
|
||||
clear: "ctrl+c",
|
||||
exit: "ctrl+d",
|
||||
suspend: "ctrl+z",
|
||||
cycleThinkingLevel: "shift+tab",
|
||||
cycleModelForward: "ctrl+p",
|
||||
cycleModelBackward: "shift+ctrl+p",
|
||||
selectModel: "ctrl+l",
|
||||
expandTools: "ctrl+o",
|
||||
toggleThinking: "ctrl+t",
|
||||
toggleSessionNamedFilter: "ctrl+n",
|
||||
externalEditor: "ctrl+g",
|
||||
followUp: "alt+enter",
|
||||
dequeue: "alt+up",
|
||||
pasteImage: process.platform === "win32" ? "alt+v" : "ctrl+v",
|
||||
newSession: [],
|
||||
tree: [],
|
||||
fork: [],
|
||||
resume: [],
|
||||
};
|
||||
|
||||
/**
|
||||
* All default keybindings (app + editor).
|
||||
*/
|
||||
export const DEFAULT_KEYBINDINGS: Required<KeybindingsConfig> = {
|
||||
...DEFAULT_EDITOR_KEYBINDINGS,
|
||||
...DEFAULT_APP_KEYBINDINGS,
|
||||
};
|
||||
|
||||
// App actions list for type checking
|
||||
const APP_ACTIONS: AppAction[] = [
|
||||
"interrupt",
|
||||
"clear",
|
||||
"exit",
|
||||
"suspend",
|
||||
"cycleThinkingLevel",
|
||||
"cycleModelForward",
|
||||
"cycleModelBackward",
|
||||
"selectModel",
|
||||
"expandTools",
|
||||
"toggleThinking",
|
||||
"toggleSessionNamedFilter",
|
||||
"externalEditor",
|
||||
"followUp",
|
||||
"dequeue",
|
||||
"pasteImage",
|
||||
"newSession",
|
||||
"tree",
|
||||
"fork",
|
||||
"resume",
|
||||
];
|
||||
|
||||
function isAppAction(action: string): action is AppAction {
|
||||
return APP_ACTIONS.includes(action as AppAction);
|
||||
}
|
||||
|
||||
/**
|
||||
* Manages all keybindings (app + editor).
|
||||
*/
|
||||
export class KeybindingsManager {
|
||||
private config: KeybindingsConfig;
|
||||
private appActionToKeys: Map<AppAction, KeyId[]>;
|
||||
|
||||
private constructor(config: KeybindingsConfig) {
|
||||
this.config = config;
|
||||
this.appActionToKeys = new Map();
|
||||
this.buildMaps();
|
||||
}
|
||||
|
||||
/**
|
||||
* Create from config file and set up editor keybindings.
|
||||
*/
|
||||
static create(agentDir: string = getAgentDir()): KeybindingsManager {
|
||||
const configPath = join(agentDir, "keybindings.json");
|
||||
const config = KeybindingsManager.loadFromFile(configPath);
|
||||
const manager = new KeybindingsManager(config);
|
||||
|
||||
// Set up editor keybindings globally
|
||||
// Include both editor actions and expandTools (shared between app and editor)
|
||||
const editorConfig: EditorKeybindingsConfig = {};
|
||||
for (const [action, keys] of Object.entries(config)) {
|
||||
if (!isAppAction(action) || action === "expandTools") {
|
||||
editorConfig[action as EditorAction] = keys;
|
||||
}
|
||||
}
|
||||
setEditorKeybindings(new EditorKeybindingsManager(editorConfig));
|
||||
|
||||
return manager;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create in-memory.
|
||||
*/
|
||||
static inMemory(config: KeybindingsConfig = {}): KeybindingsManager {
|
||||
return new KeybindingsManager(config);
|
||||
}
|
||||
|
||||
private static loadFromFile(path: string): KeybindingsConfig {
|
||||
if (!existsSync(path)) return {};
|
||||
try {
|
||||
return JSON.parse(readFileSync(path, "utf-8"));
|
||||
} catch {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
private buildMaps(): void {
|
||||
this.appActionToKeys.clear();
|
||||
|
||||
// Set defaults for app actions
|
||||
for (const [action, keys] of Object.entries(DEFAULT_APP_KEYBINDINGS)) {
|
||||
const keyArray = Array.isArray(keys) ? keys : [keys];
|
||||
this.appActionToKeys.set(action as AppAction, [...keyArray]);
|
||||
}
|
||||
|
||||
// Override with user config (app actions only)
|
||||
for (const [action, keys] of Object.entries(this.config)) {
|
||||
if (keys === undefined || !isAppAction(action)) continue;
|
||||
const keyArray = Array.isArray(keys) ? keys : [keys];
|
||||
this.appActionToKeys.set(action, keyArray);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if input matches an app action.
|
||||
*/
|
||||
matches(data: string, action: AppAction): boolean {
|
||||
const keys = this.appActionToKeys.get(action);
|
||||
if (!keys) return false;
|
||||
for (const key of keys) {
|
||||
if (matchesKey(data, key)) return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get keys bound to an app action.
|
||||
*/
|
||||
getKeys(action: AppAction): KeyId[] {
|
||||
return this.appActionToKeys.get(action) ?? [];
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the full effective config.
|
||||
*/
|
||||
getEffectiveConfig(): Required<KeybindingsConfig> {
|
||||
const result = { ...DEFAULT_KEYBINDINGS };
|
||||
for (const [action, keys] of Object.entries(this.config)) {
|
||||
if (keys !== undefined) {
|
||||
(result as KeybindingsConfig)[action as KeyAction] = keys;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// Re-export for convenience
|
||||
export type { EditorAction, KeyId };
|
||||
195
packages/pi-coding-agent/src/core/messages.ts
Normal file
195
packages/pi-coding-agent/src/core/messages.ts
Normal file
|
|
@ -0,0 +1,195 @@
|
|||
/**
|
||||
* Custom message types and transformers for the coding agent.
|
||||
*
|
||||
* Extends the base AgentMessage type with coding-agent specific message types,
|
||||
* and provides a transformer to convert them to LLM-compatible messages.
|
||||
*/
|
||||
|
||||
import type { AgentMessage } from "@gsd/pi-agent-core";
|
||||
import type { ImageContent, Message, TextContent } from "@gsd/pi-ai";
|
||||
|
||||
export const COMPACTION_SUMMARY_PREFIX = `The conversation history before this point was compacted into the following summary:
|
||||
|
||||
<summary>
|
||||
`;
|
||||
|
||||
export const COMPACTION_SUMMARY_SUFFIX = `
|
||||
</summary>`;
|
||||
|
||||
export const BRANCH_SUMMARY_PREFIX = `The following is a summary of a branch that this conversation came back from:
|
||||
|
||||
<summary>
|
||||
`;
|
||||
|
||||
export const BRANCH_SUMMARY_SUFFIX = `</summary>`;
|
||||
|
||||
/**
|
||||
* Message type for bash executions via the ! command.
|
||||
*/
|
||||
export interface BashExecutionMessage {
|
||||
role: "bashExecution";
|
||||
command: string;
|
||||
output: string;
|
||||
exitCode: number | undefined;
|
||||
cancelled: boolean;
|
||||
truncated: boolean;
|
||||
fullOutputPath?: string;
|
||||
timestamp: number;
|
||||
/** If true, this message is excluded from LLM context (!! prefix) */
|
||||
excludeFromContext?: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
* Message type for extension-injected messages via sendMessage().
|
||||
* These are custom messages that extensions can inject into the conversation.
|
||||
*/
|
||||
export interface CustomMessage<T = unknown> {
|
||||
role: "custom";
|
||||
customType: string;
|
||||
content: string | (TextContent | ImageContent)[];
|
||||
display: boolean;
|
||||
details?: T;
|
||||
timestamp: number;
|
||||
}
|
||||
|
||||
export interface BranchSummaryMessage {
|
||||
role: "branchSummary";
|
||||
summary: string;
|
||||
fromId: string;
|
||||
timestamp: number;
|
||||
}
|
||||
|
||||
export interface CompactionSummaryMessage {
|
||||
role: "compactionSummary";
|
||||
summary: string;
|
||||
tokensBefore: number;
|
||||
timestamp: number;
|
||||
}
|
||||
|
||||
// Extend CustomAgentMessages via declaration merging
|
||||
declare module "@gsd/pi-agent-core" {
|
||||
interface CustomAgentMessages {
|
||||
bashExecution: BashExecutionMessage;
|
||||
custom: CustomMessage;
|
||||
branchSummary: BranchSummaryMessage;
|
||||
compactionSummary: CompactionSummaryMessage;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert a BashExecutionMessage to user message text for LLM context.
|
||||
*/
|
||||
export function bashExecutionToText(msg: BashExecutionMessage): string {
|
||||
let text = `Ran \`${msg.command}\`\n`;
|
||||
if (msg.output) {
|
||||
text += `\`\`\`\n${msg.output}\n\`\`\``;
|
||||
} else {
|
||||
text += "(no output)";
|
||||
}
|
||||
if (msg.cancelled) {
|
||||
text += "\n\n(command cancelled)";
|
||||
} else if (msg.exitCode !== null && msg.exitCode !== undefined && msg.exitCode !== 0) {
|
||||
text += `\n\nCommand exited with code ${msg.exitCode}`;
|
||||
}
|
||||
if (msg.truncated && msg.fullOutputPath) {
|
||||
text += `\n\n[Output truncated. Full output: ${msg.fullOutputPath}]`;
|
||||
}
|
||||
return text;
|
||||
}
|
||||
|
||||
export function createBranchSummaryMessage(summary: string, fromId: string, timestamp: string): BranchSummaryMessage {
|
||||
return {
|
||||
role: "branchSummary",
|
||||
summary,
|
||||
fromId,
|
||||
timestamp: new Date(timestamp).getTime(),
|
||||
};
|
||||
}
|
||||
|
||||
export function createCompactionSummaryMessage(
|
||||
summary: string,
|
||||
tokensBefore: number,
|
||||
timestamp: string,
|
||||
): CompactionSummaryMessage {
|
||||
return {
|
||||
role: "compactionSummary",
|
||||
summary: summary,
|
||||
tokensBefore,
|
||||
timestamp: new Date(timestamp).getTime(),
|
||||
};
|
||||
}
|
||||
|
||||
/** Convert CustomMessageEntry to AgentMessage format */
|
||||
export function createCustomMessage(
|
||||
customType: string,
|
||||
content: string | (TextContent | ImageContent)[],
|
||||
display: boolean,
|
||||
details: unknown | undefined,
|
||||
timestamp: string,
|
||||
): CustomMessage {
|
||||
return {
|
||||
role: "custom",
|
||||
customType,
|
||||
content,
|
||||
display,
|
||||
details,
|
||||
timestamp: new Date(timestamp).getTime(),
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Transform AgentMessages (including custom types) to LLM-compatible Messages.
|
||||
*
|
||||
* This is used by:
|
||||
* - Agent's transormToLlm option (for prompt calls and queued messages)
|
||||
* - Compaction's generateSummary (for summarization)
|
||||
* - Custom extensions and tools
|
||||
*/
|
||||
export function convertToLlm(messages: AgentMessage[]): Message[] {
|
||||
return messages
|
||||
.map((m): Message | undefined => {
|
||||
switch (m.role) {
|
||||
case "bashExecution":
|
||||
// Skip messages excluded from context (!! prefix)
|
||||
if (m.excludeFromContext) {
|
||||
return undefined;
|
||||
}
|
||||
return {
|
||||
role: "user",
|
||||
content: [{ type: "text", text: bashExecutionToText(m) }],
|
||||
timestamp: m.timestamp,
|
||||
};
|
||||
case "custom": {
|
||||
const content = typeof m.content === "string" ? [{ type: "text" as const, text: m.content }] : m.content;
|
||||
return {
|
||||
role: "user",
|
||||
content,
|
||||
timestamp: m.timestamp,
|
||||
};
|
||||
}
|
||||
case "branchSummary":
|
||||
return {
|
||||
role: "user",
|
||||
content: [{ type: "text" as const, text: BRANCH_SUMMARY_PREFIX + m.summary + BRANCH_SUMMARY_SUFFIX }],
|
||||
timestamp: m.timestamp,
|
||||
};
|
||||
case "compactionSummary":
|
||||
return {
|
||||
role: "user",
|
||||
content: [
|
||||
{ type: "text" as const, text: COMPACTION_SUMMARY_PREFIX + m.summary + COMPACTION_SUMMARY_SUFFIX },
|
||||
],
|
||||
timestamp: m.timestamp,
|
||||
};
|
||||
case "user":
|
||||
case "assistant":
|
||||
case "toolResult":
|
||||
return m;
|
||||
default:
|
||||
// biome-ignore lint/correctness/noSwitchDeclarations: fine
|
||||
const _exhaustiveCheck: never = m;
|
||||
return undefined;
|
||||
}
|
||||
})
|
||||
.filter((m) => m !== undefined);
|
||||
}
|
||||
694
packages/pi-coding-agent/src/core/model-registry.ts
Normal file
694
packages/pi-coding-agent/src/core/model-registry.ts
Normal file
|
|
@ -0,0 +1,694 @@
|
|||
/**
|
||||
* Model registry - manages built-in and custom models, provides API key resolution.
|
||||
*/
|
||||
|
||||
import {
|
||||
type Api,
|
||||
type AssistantMessageEventStream,
|
||||
type Context,
|
||||
getModels,
|
||||
getProviders,
|
||||
type KnownProvider,
|
||||
type Model,
|
||||
type OAuthProviderInterface,
|
||||
type OpenAICompletionsCompat,
|
||||
type OpenAIResponsesCompat,
|
||||
registerApiProvider,
|
||||
resetApiProviders,
|
||||
type SimpleStreamOptions,
|
||||
} from "@gsd/pi-ai";
|
||||
import { registerOAuthProvider, resetOAuthProviders } from "@gsd/pi-ai/oauth";
|
||||
import { type Static, Type } from "@sinclair/typebox";
|
||||
import AjvModule from "ajv";
|
||||
import { existsSync, readFileSync } from "fs";
|
||||
import { join } from "path";
|
||||
import { getAgentDir } from "../config.js";
|
||||
import type { AuthStorage } from "./auth-storage.js";
|
||||
import { clearConfigValueCache, resolveConfigValue, resolveHeaders } from "./resolve-config-value.js";
|
||||
|
||||
const Ajv = (AjvModule as any).default || AjvModule;
|
||||
const ajv = new Ajv();
|
||||
|
||||
// Schema for OpenRouter routing preferences
|
||||
const OpenRouterRoutingSchema = Type.Object({
|
||||
only: Type.Optional(Type.Array(Type.String())),
|
||||
order: Type.Optional(Type.Array(Type.String())),
|
||||
});
|
||||
|
||||
// Schema for Vercel AI Gateway routing preferences
|
||||
const VercelGatewayRoutingSchema = Type.Object({
|
||||
only: Type.Optional(Type.Array(Type.String())),
|
||||
order: Type.Optional(Type.Array(Type.String())),
|
||||
});
|
||||
|
||||
// Schema for OpenAI compatibility settings
|
||||
const OpenAICompletionsCompatSchema = Type.Object({
|
||||
supportsStore: Type.Optional(Type.Boolean()),
|
||||
supportsDeveloperRole: Type.Optional(Type.Boolean()),
|
||||
supportsReasoningEffort: Type.Optional(Type.Boolean()),
|
||||
supportsUsageInStreaming: Type.Optional(Type.Boolean()),
|
||||
maxTokensField: Type.Optional(Type.Union([Type.Literal("max_completion_tokens"), Type.Literal("max_tokens")])),
|
||||
requiresToolResultName: Type.Optional(Type.Boolean()),
|
||||
requiresAssistantAfterToolResult: Type.Optional(Type.Boolean()),
|
||||
requiresThinkingAsText: Type.Optional(Type.Boolean()),
|
||||
requiresMistralToolIds: Type.Optional(Type.Boolean()),
|
||||
thinkingFormat: Type.Optional(Type.Union([Type.Literal("openai"), Type.Literal("zai"), Type.Literal("qwen")])),
|
||||
openRouterRouting: Type.Optional(OpenRouterRoutingSchema),
|
||||
vercelGatewayRouting: Type.Optional(VercelGatewayRoutingSchema),
|
||||
});
|
||||
|
||||
const OpenAIResponsesCompatSchema = Type.Object({
|
||||
// Reserved for future use
|
||||
});
|
||||
|
||||
const OpenAICompatSchema = Type.Union([OpenAICompletionsCompatSchema, OpenAIResponsesCompatSchema]);
|
||||
|
||||
// Schema for custom model definition
|
||||
// Most fields are optional with sensible defaults for local models (Ollama, LM Studio, etc.)
|
||||
const ModelDefinitionSchema = Type.Object({
|
||||
id: Type.String({ minLength: 1 }),
|
||||
name: Type.Optional(Type.String({ minLength: 1 })),
|
||||
api: Type.Optional(Type.String({ minLength: 1 })),
|
||||
baseUrl: Type.Optional(Type.String({ minLength: 1 })),
|
||||
reasoning: Type.Optional(Type.Boolean()),
|
||||
input: Type.Optional(Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")]))),
|
||||
cost: Type.Optional(
|
||||
Type.Object({
|
||||
input: Type.Number(),
|
||||
output: Type.Number(),
|
||||
cacheRead: Type.Number(),
|
||||
cacheWrite: Type.Number(),
|
||||
}),
|
||||
),
|
||||
contextWindow: Type.Optional(Type.Number()),
|
||||
maxTokens: Type.Optional(Type.Number()),
|
||||
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
|
||||
compat: Type.Optional(OpenAICompatSchema),
|
||||
});
|
||||
|
||||
// Schema for per-model overrides (all fields optional, merged with built-in model)
|
||||
const ModelOverrideSchema = Type.Object({
|
||||
name: Type.Optional(Type.String({ minLength: 1 })),
|
||||
reasoning: Type.Optional(Type.Boolean()),
|
||||
input: Type.Optional(Type.Array(Type.Union([Type.Literal("text"), Type.Literal("image")]))),
|
||||
cost: Type.Optional(
|
||||
Type.Object({
|
||||
input: Type.Optional(Type.Number()),
|
||||
output: Type.Optional(Type.Number()),
|
||||
cacheRead: Type.Optional(Type.Number()),
|
||||
cacheWrite: Type.Optional(Type.Number()),
|
||||
}),
|
||||
),
|
||||
contextWindow: Type.Optional(Type.Number()),
|
||||
maxTokens: Type.Optional(Type.Number()),
|
||||
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
|
||||
compat: Type.Optional(OpenAICompatSchema),
|
||||
});
|
||||
|
||||
type ModelOverride = Static<typeof ModelOverrideSchema>;
|
||||
|
||||
const ProviderConfigSchema = Type.Object({
|
||||
baseUrl: Type.Optional(Type.String({ minLength: 1 })),
|
||||
apiKey: Type.Optional(Type.String({ minLength: 1 })),
|
||||
api: Type.Optional(Type.String({ minLength: 1 })),
|
||||
headers: Type.Optional(Type.Record(Type.String(), Type.String())),
|
||||
authHeader: Type.Optional(Type.Boolean()),
|
||||
models: Type.Optional(Type.Array(ModelDefinitionSchema)),
|
||||
modelOverrides: Type.Optional(Type.Record(Type.String(), ModelOverrideSchema)),
|
||||
});
|
||||
|
||||
const ModelsConfigSchema = Type.Object({
|
||||
providers: Type.Record(Type.String(), ProviderConfigSchema),
|
||||
});
|
||||
|
||||
ajv.addSchema(ModelsConfigSchema, "ModelsConfig");
|
||||
|
||||
type ModelsConfig = Static<typeof ModelsConfigSchema>;
|
||||
|
||||
/** Provider override config (baseUrl, headers, apiKey) without custom models */
|
||||
interface ProviderOverride {
|
||||
baseUrl?: string;
|
||||
headers?: Record<string, string>;
|
||||
apiKey?: string;
|
||||
}
|
||||
|
||||
/** Result of loading custom models from models.json */
|
||||
interface CustomModelsResult {
|
||||
models: Model<Api>[];
|
||||
/** Providers with baseUrl/headers/apiKey overrides for built-in models */
|
||||
overrides: Map<string, ProviderOverride>;
|
||||
/** Per-model overrides: provider -> modelId -> override */
|
||||
modelOverrides: Map<string, Map<string, ModelOverride>>;
|
||||
error: string | undefined;
|
||||
}
|
||||
|
||||
function emptyCustomModelsResult(error?: string): CustomModelsResult {
|
||||
return { models: [], overrides: new Map(), modelOverrides: new Map(), error };
|
||||
}
|
||||
|
||||
function mergeCompat(
|
||||
baseCompat: Model<Api>["compat"],
|
||||
overrideCompat: ModelOverride["compat"],
|
||||
): Model<Api>["compat"] | undefined {
|
||||
if (!overrideCompat) return baseCompat;
|
||||
|
||||
const base = baseCompat as OpenAICompletionsCompat | OpenAIResponsesCompat | undefined;
|
||||
const override = overrideCompat as OpenAICompletionsCompat | OpenAIResponsesCompat;
|
||||
const merged = { ...base, ...override } as OpenAICompletionsCompat | OpenAIResponsesCompat;
|
||||
|
||||
const baseCompletions = base as OpenAICompletionsCompat | undefined;
|
||||
const overrideCompletions = override as OpenAICompletionsCompat;
|
||||
const mergedCompletions = merged as OpenAICompletionsCompat;
|
||||
|
||||
if (baseCompletions?.openRouterRouting || overrideCompletions.openRouterRouting) {
|
||||
mergedCompletions.openRouterRouting = {
|
||||
...baseCompletions?.openRouterRouting,
|
||||
...overrideCompletions.openRouterRouting,
|
||||
};
|
||||
}
|
||||
|
||||
if (baseCompletions?.vercelGatewayRouting || overrideCompletions.vercelGatewayRouting) {
|
||||
mergedCompletions.vercelGatewayRouting = {
|
||||
...baseCompletions?.vercelGatewayRouting,
|
||||
...overrideCompletions.vercelGatewayRouting,
|
||||
};
|
||||
}
|
||||
|
||||
return merged as Model<Api>["compat"];
|
||||
}
|
||||
|
||||
/**
|
||||
* Deep merge a model override into a model.
|
||||
* Handles nested objects (cost, compat) by merging rather than replacing.
|
||||
*/
|
||||
function applyModelOverride(model: Model<Api>, override: ModelOverride): Model<Api> {
|
||||
const result = { ...model };
|
||||
|
||||
// Simple field overrides
|
||||
if (override.name !== undefined) result.name = override.name;
|
||||
if (override.reasoning !== undefined) result.reasoning = override.reasoning;
|
||||
if (override.input !== undefined) result.input = override.input as ("text" | "image")[];
|
||||
if (override.contextWindow !== undefined) result.contextWindow = override.contextWindow;
|
||||
if (override.maxTokens !== undefined) result.maxTokens = override.maxTokens;
|
||||
|
||||
// Merge cost (partial override)
|
||||
if (override.cost) {
|
||||
result.cost = {
|
||||
input: override.cost.input ?? model.cost.input,
|
||||
output: override.cost.output ?? model.cost.output,
|
||||
cacheRead: override.cost.cacheRead ?? model.cost.cacheRead,
|
||||
cacheWrite: override.cost.cacheWrite ?? model.cost.cacheWrite,
|
||||
};
|
||||
}
|
||||
|
||||
// Merge headers
|
||||
if (override.headers) {
|
||||
const resolvedHeaders = resolveHeaders(override.headers);
|
||||
result.headers = resolvedHeaders ? { ...model.headers, ...resolvedHeaders } : model.headers;
|
||||
}
|
||||
|
||||
// Deep merge compat
|
||||
result.compat = mergeCompat(model.compat, override.compat);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/** Clear the config value command cache. Exported for testing. */
|
||||
export const clearApiKeyCache = clearConfigValueCache;
|
||||
|
||||
/**
|
||||
* Model registry - loads and manages models, resolves API keys via AuthStorage.
|
||||
*/
|
||||
export class ModelRegistry {
|
||||
private models: Model<Api>[] = [];
|
||||
private customProviderApiKeys: Map<string, string> = new Map();
|
||||
private registeredProviders: Map<string, ProviderConfigInput> = new Map();
|
||||
private loadError: string | undefined = undefined;
|
||||
|
||||
constructor(
|
||||
readonly authStorage: AuthStorage,
|
||||
private modelsJsonPath: string | undefined = join(getAgentDir(), "models.json"),
|
||||
) {
|
||||
// Set up fallback resolver for custom provider API keys
|
||||
this.authStorage.setFallbackResolver((provider) => {
|
||||
const keyConfig = this.customProviderApiKeys.get(provider);
|
||||
if (keyConfig) {
|
||||
return resolveConfigValue(keyConfig);
|
||||
}
|
||||
return undefined;
|
||||
});
|
||||
|
||||
// Load models
|
||||
this.loadModels();
|
||||
}
|
||||
|
||||
/**
|
||||
* Reload models from disk (built-in + custom from models.json).
|
||||
*/
|
||||
refresh(): void {
|
||||
this.customProviderApiKeys.clear();
|
||||
this.loadError = undefined;
|
||||
|
||||
// Ensure dynamic API/OAuth registrations are rebuilt from current provider state.
|
||||
resetApiProviders();
|
||||
resetOAuthProviders();
|
||||
|
||||
this.loadModels();
|
||||
|
||||
for (const [providerName, config] of this.registeredProviders.entries()) {
|
||||
this.applyProviderConfig(providerName, config);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get any error from loading models.json (undefined if no error).
|
||||
*/
|
||||
getError(): string | undefined {
|
||||
return this.loadError;
|
||||
}
|
||||
|
||||
private loadModels(): void {
|
||||
// Load custom models and overrides from models.json
|
||||
const {
|
||||
models: customModels,
|
||||
overrides,
|
||||
modelOverrides,
|
||||
error,
|
||||
} = this.modelsJsonPath ? this.loadCustomModels(this.modelsJsonPath) : emptyCustomModelsResult();
|
||||
|
||||
if (error) {
|
||||
this.loadError = error;
|
||||
// Keep built-in models even if custom models failed to load
|
||||
}
|
||||
|
||||
const builtInModels = this.loadBuiltInModels(overrides, modelOverrides);
|
||||
let combined = this.mergeCustomModels(builtInModels, customModels);
|
||||
|
||||
// Let OAuth providers modify their models (e.g., update baseUrl)
|
||||
for (const oauthProvider of this.authStorage.getOAuthProviders()) {
|
||||
const cred = this.authStorage.get(oauthProvider.id);
|
||||
if (cred?.type === "oauth" && oauthProvider.modifyModels) {
|
||||
combined = oauthProvider.modifyModels(combined, cred);
|
||||
}
|
||||
}
|
||||
|
||||
this.models = combined;
|
||||
}
|
||||
|
||||
/** Load built-in models and apply provider/model overrides */
|
||||
private loadBuiltInModels(
|
||||
overrides: Map<string, ProviderOverride>,
|
||||
modelOverrides: Map<string, Map<string, ModelOverride>>,
|
||||
): Model<Api>[] {
|
||||
return getProviders().flatMap((provider) => {
|
||||
const models = getModels(provider as KnownProvider) as Model<Api>[];
|
||||
const providerOverride = overrides.get(provider);
|
||||
const perModelOverrides = modelOverrides.get(provider);
|
||||
|
||||
return models.map((m) => {
|
||||
let model = m;
|
||||
|
||||
// Apply provider-level baseUrl/headers override
|
||||
if (providerOverride) {
|
||||
const resolvedHeaders = resolveHeaders(providerOverride.headers);
|
||||
model = {
|
||||
...model,
|
||||
baseUrl: providerOverride.baseUrl ?? model.baseUrl,
|
||||
headers: resolvedHeaders ? { ...model.headers, ...resolvedHeaders } : model.headers,
|
||||
};
|
||||
}
|
||||
|
||||
// Apply per-model override
|
||||
const modelOverride = perModelOverrides?.get(m.id);
|
||||
if (modelOverride) {
|
||||
model = applyModelOverride(model, modelOverride);
|
||||
}
|
||||
|
||||
return model;
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/** Merge custom models into built-in list by provider+id (custom wins on conflicts). */
|
||||
private mergeCustomModels(builtInModels: Model<Api>[], customModels: Model<Api>[]): Model<Api>[] {
|
||||
const merged = [...builtInModels];
|
||||
for (const customModel of customModels) {
|
||||
const existingIndex = merged.findIndex((m) => m.provider === customModel.provider && m.id === customModel.id);
|
||||
if (existingIndex >= 0) {
|
||||
merged[existingIndex] = customModel;
|
||||
} else {
|
||||
merged.push(customModel);
|
||||
}
|
||||
}
|
||||
return merged;
|
||||
}
|
||||
|
||||
private loadCustomModels(modelsJsonPath: string): CustomModelsResult {
|
||||
if (!existsSync(modelsJsonPath)) {
|
||||
return emptyCustomModelsResult();
|
||||
}
|
||||
|
||||
try {
|
||||
const content = readFileSync(modelsJsonPath, "utf-8");
|
||||
const config: ModelsConfig = JSON.parse(content);
|
||||
|
||||
// Validate schema
|
||||
const validate = ajv.getSchema("ModelsConfig")!;
|
||||
if (!validate(config)) {
|
||||
const errors =
|
||||
validate.errors?.map((e: any) => ` - ${e.instancePath || "root"}: ${e.message}`).join("\n") ||
|
||||
"Unknown schema error";
|
||||
return emptyCustomModelsResult(`Invalid models.json schema:\n${errors}\n\nFile: ${modelsJsonPath}`);
|
||||
}
|
||||
|
||||
// Additional validation
|
||||
this.validateConfig(config);
|
||||
|
||||
const overrides = new Map<string, ProviderOverride>();
|
||||
const modelOverrides = new Map<string, Map<string, ModelOverride>>();
|
||||
|
||||
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
||||
// Apply provider-level baseUrl/headers/apiKey override to built-in models when configured.
|
||||
if (providerConfig.baseUrl || providerConfig.headers || providerConfig.apiKey) {
|
||||
overrides.set(providerName, {
|
||||
baseUrl: providerConfig.baseUrl,
|
||||
headers: providerConfig.headers,
|
||||
apiKey: providerConfig.apiKey,
|
||||
});
|
||||
}
|
||||
|
||||
// Store API key for fallback resolver.
|
||||
if (providerConfig.apiKey) {
|
||||
this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
|
||||
}
|
||||
|
||||
if (providerConfig.modelOverrides) {
|
||||
modelOverrides.set(providerName, new Map(Object.entries(providerConfig.modelOverrides)));
|
||||
}
|
||||
}
|
||||
|
||||
return { models: this.parseModels(config), overrides, modelOverrides, error: undefined };
|
||||
} catch (error) {
|
||||
if (error instanceof SyntaxError) {
|
||||
return emptyCustomModelsResult(`Failed to parse models.json: ${error.message}\n\nFile: ${modelsJsonPath}`);
|
||||
}
|
||||
return emptyCustomModelsResult(
|
||||
`Failed to load models.json: ${error instanceof Error ? error.message : error}\n\nFile: ${modelsJsonPath}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
private validateConfig(config: ModelsConfig): void {
|
||||
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
||||
const hasProviderApi = !!providerConfig.api;
|
||||
const models = providerConfig.models ?? [];
|
||||
const hasModelOverrides =
|
||||
providerConfig.modelOverrides && Object.keys(providerConfig.modelOverrides).length > 0;
|
||||
|
||||
if (models.length === 0) {
|
||||
// Override-only config: needs baseUrl OR modelOverrides (or both)
|
||||
if (!providerConfig.baseUrl && !hasModelOverrides) {
|
||||
throw new Error(`Provider ${providerName}: must specify "baseUrl", "modelOverrides", or "models".`);
|
||||
}
|
||||
} else {
|
||||
// Custom models are merged into provider models and require endpoint + auth.
|
||||
if (!providerConfig.baseUrl) {
|
||||
throw new Error(`Provider ${providerName}: "baseUrl" is required when defining custom models.`);
|
||||
}
|
||||
if (!providerConfig.apiKey) {
|
||||
throw new Error(`Provider ${providerName}: "apiKey" is required when defining custom models.`);
|
||||
}
|
||||
}
|
||||
|
||||
for (const modelDef of models) {
|
||||
const hasModelApi = !!modelDef.api;
|
||||
|
||||
if (!hasProviderApi && !hasModelApi) {
|
||||
throw new Error(
|
||||
`Provider ${providerName}, model ${modelDef.id}: no "api" specified. Set at provider or model level.`,
|
||||
);
|
||||
}
|
||||
|
||||
if (!modelDef.id) throw new Error(`Provider ${providerName}: model missing "id"`);
|
||||
// Validate contextWindow/maxTokens only if provided (they have defaults)
|
||||
if (modelDef.contextWindow !== undefined && modelDef.contextWindow <= 0)
|
||||
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid contextWindow`);
|
||||
if (modelDef.maxTokens !== undefined && modelDef.maxTokens <= 0)
|
||||
throw new Error(`Provider ${providerName}, model ${modelDef.id}: invalid maxTokens`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private parseModels(config: ModelsConfig): Model<Api>[] {
|
||||
const models: Model<Api>[] = [];
|
||||
|
||||
for (const [providerName, providerConfig] of Object.entries(config.providers)) {
|
||||
const modelDefs = providerConfig.models ?? [];
|
||||
if (modelDefs.length === 0) continue; // Override-only, no custom models
|
||||
|
||||
// Store API key config for fallback resolver
|
||||
if (providerConfig.apiKey) {
|
||||
this.customProviderApiKeys.set(providerName, providerConfig.apiKey);
|
||||
}
|
||||
|
||||
for (const modelDef of modelDefs) {
|
||||
const api = modelDef.api || providerConfig.api;
|
||||
if (!api) continue;
|
||||
|
||||
// Merge headers: provider headers are base, model headers override
|
||||
// Resolve env vars and shell commands in header values
|
||||
const providerHeaders = resolveHeaders(providerConfig.headers);
|
||||
const modelHeaders = resolveHeaders(modelDef.headers);
|
||||
let headers = providerHeaders || modelHeaders ? { ...providerHeaders, ...modelHeaders } : undefined;
|
||||
|
||||
// If authHeader is true, add Authorization header with resolved API key
|
||||
if (providerConfig.authHeader && providerConfig.apiKey) {
|
||||
const resolvedKey = resolveConfigValue(providerConfig.apiKey);
|
||||
if (resolvedKey) {
|
||||
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
|
||||
}
|
||||
}
|
||||
|
||||
// Provider baseUrl is required when custom models are defined.
|
||||
// Individual models can override it with modelDef.baseUrl.
|
||||
const defaultCost = { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 };
|
||||
models.push({
|
||||
id: modelDef.id,
|
||||
name: modelDef.name ?? modelDef.id,
|
||||
api: api as Api,
|
||||
provider: providerName,
|
||||
baseUrl: modelDef.baseUrl ?? providerConfig.baseUrl!,
|
||||
reasoning: modelDef.reasoning ?? false,
|
||||
input: (modelDef.input ?? ["text"]) as ("text" | "image")[],
|
||||
cost: modelDef.cost ?? defaultCost,
|
||||
contextWindow: modelDef.contextWindow ?? 128000,
|
||||
maxTokens: modelDef.maxTokens ?? 16384,
|
||||
headers,
|
||||
compat: modelDef.compat,
|
||||
} as Model<Api>);
|
||||
}
|
||||
}
|
||||
|
||||
return models;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get all models (built-in + custom).
|
||||
* If models.json had errors, returns only built-in models.
|
||||
*/
|
||||
getAll(): Model<Api>[] {
|
||||
return this.models;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get only models that have auth configured.
|
||||
* This is a fast check that doesn't refresh OAuth tokens.
|
||||
*/
|
||||
getAvailable(): Model<Api>[] {
|
||||
return this.models.filter((m) => this.authStorage.hasAuth(m.provider));
|
||||
}
|
||||
|
||||
/**
|
||||
* Find a model by provider and ID.
|
||||
*/
|
||||
find(provider: string, modelId: string): Model<Api> | undefined {
|
||||
return this.models.find((m) => m.provider === provider && m.id === modelId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for a model.
|
||||
*/
|
||||
async getApiKey(model: Model<Api>): Promise<string | undefined> {
|
||||
return this.authStorage.getApiKey(model.provider);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for a provider.
|
||||
*/
|
||||
async getApiKeyForProvider(provider: string): Promise<string | undefined> {
|
||||
return this.authStorage.getApiKey(provider);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a model is using OAuth credentials (subscription).
|
||||
*/
|
||||
isUsingOAuth(model: Model<Api>): boolean {
|
||||
const cred = this.authStorage.get(model.provider);
|
||||
return cred?.type === "oauth";
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a provider dynamically (from extensions).
|
||||
*
|
||||
* If provider has models: replaces all existing models for this provider.
|
||||
* If provider has only baseUrl/headers: overrides existing models' URLs.
|
||||
* If provider has oauth: registers OAuth provider for /login support.
|
||||
*/
|
||||
registerProvider(providerName: string, config: ProviderConfigInput): void {
|
||||
this.registeredProviders.set(providerName, config);
|
||||
this.applyProviderConfig(providerName, config);
|
||||
}
|
||||
|
||||
/**
|
||||
* Unregister a previously registered provider.
|
||||
*
|
||||
* Removes the provider from the registry and reloads models from disk so that
|
||||
* built-in models overridden by this provider are restored to their original state.
|
||||
* Also resets dynamic OAuth and API stream registrations before reapplying
|
||||
* remaining dynamic providers.
|
||||
* Has no effect if the provider was never registered.
|
||||
*/
|
||||
unregisterProvider(providerName: string): void {
|
||||
if (!this.registeredProviders.has(providerName)) return;
|
||||
this.registeredProviders.delete(providerName);
|
||||
this.customProviderApiKeys.delete(providerName);
|
||||
this.refresh();
|
||||
}
|
||||
|
||||
private applyProviderConfig(providerName: string, config: ProviderConfigInput): void {
|
||||
// Register OAuth provider if provided
|
||||
if (config.oauth) {
|
||||
// Ensure the OAuth provider ID matches the provider name
|
||||
const oauthProvider: OAuthProviderInterface = {
|
||||
...config.oauth,
|
||||
id: providerName,
|
||||
};
|
||||
registerOAuthProvider(oauthProvider);
|
||||
}
|
||||
|
||||
if (config.streamSimple) {
|
||||
if (!config.api) {
|
||||
throw new Error(`Provider ${providerName}: "api" is required when registering streamSimple.`);
|
||||
}
|
||||
const streamSimple = config.streamSimple;
|
||||
registerApiProvider(
|
||||
{
|
||||
api: config.api,
|
||||
stream: (model, context, options) => streamSimple(model, context, options as SimpleStreamOptions),
|
||||
streamSimple,
|
||||
},
|
||||
`provider:${providerName}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Store API key for auth resolution
|
||||
if (config.apiKey) {
|
||||
this.customProviderApiKeys.set(providerName, config.apiKey);
|
||||
}
|
||||
|
||||
if (config.models && config.models.length > 0) {
|
||||
// Full replacement: remove existing models for this provider
|
||||
this.models = this.models.filter((m) => m.provider !== providerName);
|
||||
|
||||
// Validate required fields
|
||||
if (!config.baseUrl) {
|
||||
throw new Error(`Provider ${providerName}: "baseUrl" is required when defining models.`);
|
||||
}
|
||||
if (!config.apiKey && !config.oauth) {
|
||||
throw new Error(`Provider ${providerName}: "apiKey" or "oauth" is required when defining models.`);
|
||||
}
|
||||
|
||||
// Parse and add new models
|
||||
for (const modelDef of config.models) {
|
||||
const api = modelDef.api || config.api;
|
||||
if (!api) {
|
||||
throw new Error(`Provider ${providerName}, model ${modelDef.id}: no "api" specified.`);
|
||||
}
|
||||
|
||||
// Merge headers
|
||||
const providerHeaders = resolveHeaders(config.headers);
|
||||
const modelHeaders = resolveHeaders(modelDef.headers);
|
||||
let headers = providerHeaders || modelHeaders ? { ...providerHeaders, ...modelHeaders } : undefined;
|
||||
|
||||
// If authHeader is true, add Authorization header
|
||||
if (config.authHeader && config.apiKey) {
|
||||
const resolvedKey = resolveConfigValue(config.apiKey);
|
||||
if (resolvedKey) {
|
||||
headers = { ...headers, Authorization: `Bearer ${resolvedKey}` };
|
||||
}
|
||||
}
|
||||
|
||||
this.models.push({
|
||||
id: modelDef.id,
|
||||
name: modelDef.name,
|
||||
api: api as Api,
|
||||
provider: providerName,
|
||||
baseUrl: config.baseUrl,
|
||||
reasoning: modelDef.reasoning,
|
||||
input: modelDef.input as ("text" | "image")[],
|
||||
cost: modelDef.cost,
|
||||
contextWindow: modelDef.contextWindow,
|
||||
maxTokens: modelDef.maxTokens,
|
||||
headers,
|
||||
compat: modelDef.compat,
|
||||
} as Model<Api>);
|
||||
}
|
||||
|
||||
// Apply OAuth modifyModels if credentials exist (e.g., to update baseUrl)
|
||||
if (config.oauth?.modifyModels) {
|
||||
const cred = this.authStorage.get(providerName);
|
||||
if (cred?.type === "oauth") {
|
||||
this.models = config.oauth.modifyModels(this.models, cred);
|
||||
}
|
||||
}
|
||||
} else if (config.baseUrl) {
|
||||
// Override-only: update baseUrl/headers for existing models
|
||||
const resolvedHeaders = resolveHeaders(config.headers);
|
||||
this.models = this.models.map((m) => {
|
||||
if (m.provider !== providerName) return m;
|
||||
return {
|
||||
...m,
|
||||
baseUrl: config.baseUrl ?? m.baseUrl,
|
||||
headers: resolvedHeaders ? { ...m.headers, ...resolvedHeaders } : m.headers,
|
||||
};
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Input type for registerProvider API.
|
||||
*/
|
||||
export interface ProviderConfigInput {
|
||||
baseUrl?: string;
|
||||
apiKey?: string;
|
||||
api?: Api;
|
||||
streamSimple?: (model: Model<Api>, context: Context, options?: SimpleStreamOptions) => AssistantMessageEventStream;
|
||||
headers?: Record<string, string>;
|
||||
authHeader?: boolean;
|
||||
/** OAuth provider for /login support */
|
||||
oauth?: Omit<OAuthProviderInterface, "id">;
|
||||
models?: Array<{
|
||||
id: string;
|
||||
name: string;
|
||||
api?: Api;
|
||||
baseUrl?: string;
|
||||
reasoning: boolean;
|
||||
input: ("text" | "image")[];
|
||||
cost: { input: number; output: number; cacheRead: number; cacheWrite: number };
|
||||
contextWindow: number;
|
||||
maxTokens: number;
|
||||
headers?: Record<string, string>;
|
||||
compat?: Model<Api>["compat"];
|
||||
}>;
|
||||
}
|
||||
594
packages/pi-coding-agent/src/core/model-resolver.ts
Normal file
594
packages/pi-coding-agent/src/core/model-resolver.ts
Normal file
|
|
@ -0,0 +1,594 @@
|
|||
/**
|
||||
* Model resolution, scoping, and initial selection
|
||||
*/
|
||||
|
||||
import type { ThinkingLevel } from "@gsd/pi-agent-core";
|
||||
import { type Api, type KnownProvider, type Model, modelsAreEqual } from "@gsd/pi-ai";
|
||||
import chalk from "chalk";
|
||||
import { minimatch } from "minimatch";
|
||||
import { isValidThinkingLevel } from "../cli/args.js";
|
||||
import { DEFAULT_THINKING_LEVEL } from "./defaults.js";
|
||||
import type { ModelRegistry } from "./model-registry.js";
|
||||
|
||||
/** Default model IDs for each known provider */
|
||||
export const defaultModelPerProvider: Record<KnownProvider, string> = {
|
||||
"amazon-bedrock": "us.anthropic.claude-opus-4-6-v1",
|
||||
anthropic: "claude-opus-4-6",
|
||||
openai: "gpt-5.4",
|
||||
"azure-openai-responses": "gpt-5.2",
|
||||
"openai-codex": "gpt-5.4",
|
||||
google: "gemini-2.5-pro",
|
||||
"google-gemini-cli": "gemini-2.5-pro",
|
||||
"google-antigravity": "gemini-3.1-pro-high",
|
||||
"google-vertex": "gemini-3-pro-preview",
|
||||
"github-copilot": "gpt-4o",
|
||||
openrouter: "openai/gpt-5.1-codex",
|
||||
"vercel-ai-gateway": "anthropic/claude-opus-4-6",
|
||||
xai: "grok-4-fast-non-reasoning",
|
||||
groq: "openai/gpt-oss-120b",
|
||||
cerebras: "zai-glm-4.6",
|
||||
zai: "glm-4.6",
|
||||
mistral: "devstral-medium-latest",
|
||||
minimax: "MiniMax-M2.1",
|
||||
"minimax-cn": "MiniMax-M2.1",
|
||||
huggingface: "moonshotai/Kimi-K2.5",
|
||||
opencode: "claude-opus-4-6",
|
||||
"opencode-go": "kimi-k2.5",
|
||||
"kimi-coding": "kimi-k2-thinking",
|
||||
};
|
||||
|
||||
export interface ScopedModel {
|
||||
model: Model<Api>;
|
||||
/** Thinking level if explicitly specified in pattern (e.g., "model:high"), undefined otherwise */
|
||||
thinkingLevel?: ThinkingLevel;
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper to check if a model ID looks like an alias (no date suffix)
|
||||
* Dates are typically in format: -20241022 or -20250929
|
||||
*/
|
||||
function isAlias(id: string): boolean {
|
||||
// Check if ID ends with -latest
|
||||
if (id.endsWith("-latest")) return true;
|
||||
|
||||
// Check if ID ends with a date pattern (-YYYYMMDD)
|
||||
const datePattern = /-\d{8}$/;
|
||||
return !datePattern.test(id);
|
||||
}
|
||||
|
||||
/**
|
||||
* Try to match a pattern to a model from the available models list.
|
||||
* Returns the matched model or undefined if no match found.
|
||||
*/
|
||||
function tryMatchModel(modelPattern: string, availableModels: Model<Api>[]): Model<Api> | undefined {
|
||||
// Check for provider/modelId format (provider is everything before the first /)
|
||||
const slashIndex = modelPattern.indexOf("/");
|
||||
if (slashIndex !== -1) {
|
||||
const provider = modelPattern.substring(0, slashIndex);
|
||||
const modelId = modelPattern.substring(slashIndex + 1);
|
||||
const providerMatch = availableModels.find(
|
||||
(m) => m.provider.toLowerCase() === provider.toLowerCase() && m.id.toLowerCase() === modelId.toLowerCase(),
|
||||
);
|
||||
if (providerMatch) {
|
||||
return providerMatch;
|
||||
}
|
||||
// No exact provider/model match - fall through to other matching
|
||||
}
|
||||
|
||||
// Check for exact ID match (case-insensitive)
|
||||
const exactMatch = availableModels.find((m) => m.id.toLowerCase() === modelPattern.toLowerCase());
|
||||
if (exactMatch) {
|
||||
return exactMatch;
|
||||
}
|
||||
|
||||
// No exact match - fall back to partial matching
|
||||
const matches = availableModels.filter(
|
||||
(m) =>
|
||||
m.id.toLowerCase().includes(modelPattern.toLowerCase()) ||
|
||||
m.name?.toLowerCase().includes(modelPattern.toLowerCase()),
|
||||
);
|
||||
|
||||
if (matches.length === 0) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Separate into aliases and dated versions
|
||||
const aliases = matches.filter((m) => isAlias(m.id));
|
||||
const datedVersions = matches.filter((m) => !isAlias(m.id));
|
||||
|
||||
if (aliases.length > 0) {
|
||||
// Prefer alias - if multiple aliases, pick the one that sorts highest
|
||||
aliases.sort((a, b) => b.id.localeCompare(a.id));
|
||||
return aliases[0];
|
||||
} else {
|
||||
// No alias found, pick latest dated version
|
||||
datedVersions.sort((a, b) => b.id.localeCompare(a.id));
|
||||
return datedVersions[0];
|
||||
}
|
||||
}
|
||||
|
||||
export interface ParsedModelResult {
|
||||
model: Model<Api> | undefined;
|
||||
/** Thinking level if explicitly specified in pattern, undefined otherwise */
|
||||
thinkingLevel?: ThinkingLevel;
|
||||
warning: string | undefined;
|
||||
}
|
||||
|
||||
function buildFallbackModel(provider: string, modelId: string, availableModels: Model<Api>[]): Model<Api> | undefined {
|
||||
const providerModels = availableModels.filter((m) => m.provider === provider);
|
||||
if (providerModels.length === 0) return undefined;
|
||||
|
||||
const defaultId = defaultModelPerProvider[provider as KnownProvider];
|
||||
const baseModel = defaultId
|
||||
? (providerModels.find((m) => m.id === defaultId) ?? providerModels[0])
|
||||
: providerModels[0];
|
||||
|
||||
return {
|
||||
...baseModel,
|
||||
id: modelId,
|
||||
name: modelId,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse a pattern to extract model and thinking level.
|
||||
* Handles models with colons in their IDs (e.g., OpenRouter's :exacto suffix).
|
||||
*
|
||||
* Algorithm:
|
||||
* 1. Try to match full pattern as a model
|
||||
* 2. If found, return it with "off" thinking level
|
||||
* 3. If not found and has colons, split on last colon:
|
||||
* - If suffix is valid thinking level, use it and recurse on prefix
|
||||
* - If suffix is invalid, warn and recurse on prefix with "off"
|
||||
*
|
||||
* @internal Exported for testing
|
||||
*/
|
||||
export function parseModelPattern(
|
||||
pattern: string,
|
||||
availableModels: Model<Api>[],
|
||||
options?: { allowInvalidThinkingLevelFallback?: boolean },
|
||||
): ParsedModelResult {
|
||||
// Try exact match first
|
||||
const exactMatch = tryMatchModel(pattern, availableModels);
|
||||
if (exactMatch) {
|
||||
return { model: exactMatch, thinkingLevel: undefined, warning: undefined };
|
||||
}
|
||||
|
||||
// No match - try splitting on last colon if present
|
||||
const lastColonIndex = pattern.lastIndexOf(":");
|
||||
if (lastColonIndex === -1) {
|
||||
// No colons, pattern simply doesn't match any model
|
||||
return { model: undefined, thinkingLevel: undefined, warning: undefined };
|
||||
}
|
||||
|
||||
const prefix = pattern.substring(0, lastColonIndex);
|
||||
const suffix = pattern.substring(lastColonIndex + 1);
|
||||
|
||||
if (isValidThinkingLevel(suffix)) {
|
||||
// Valid thinking level - recurse on prefix and use this level
|
||||
const result = parseModelPattern(prefix, availableModels, options);
|
||||
if (result.model) {
|
||||
// Only use this thinking level if no warning from inner recursion
|
||||
return {
|
||||
model: result.model,
|
||||
thinkingLevel: result.warning ? undefined : suffix,
|
||||
warning: result.warning,
|
||||
};
|
||||
}
|
||||
return result;
|
||||
} else {
|
||||
// Invalid suffix
|
||||
const allowFallback = options?.allowInvalidThinkingLevelFallback ?? true;
|
||||
if (!allowFallback) {
|
||||
// In strict mode (CLI --model parsing), treat it as part of the model id and fail.
|
||||
// This avoids accidentally resolving to a different model.
|
||||
return { model: undefined, thinkingLevel: undefined, warning: undefined };
|
||||
}
|
||||
|
||||
// Scope mode: recurse on prefix and warn
|
||||
const result = parseModelPattern(prefix, availableModels, options);
|
||||
if (result.model) {
|
||||
return {
|
||||
model: result.model,
|
||||
thinkingLevel: undefined,
|
||||
warning: `Invalid thinking level "${suffix}" in pattern "${pattern}". Using default instead.`,
|
||||
};
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve model patterns to actual Model objects with optional thinking levels
|
||||
* Format: "pattern:level" where :level is optional
|
||||
* For each pattern, finds all matching models and picks the best version:
|
||||
* 1. Prefer alias (e.g., claude-sonnet-4-5) over dated versions (claude-sonnet-4-5-20250929)
|
||||
* 2. If no alias, pick the latest dated version
|
||||
*
|
||||
* Supports models with colons in their IDs (e.g., OpenRouter's model:exacto).
|
||||
* The algorithm tries to match the full pattern first, then progressively
|
||||
* strips colon-suffixes to find a match.
|
||||
*/
|
||||
export async function resolveModelScope(patterns: string[], modelRegistry: ModelRegistry): Promise<ScopedModel[]> {
|
||||
const availableModels = await modelRegistry.getAvailable();
|
||||
const scopedModels: ScopedModel[] = [];
|
||||
|
||||
for (const pattern of patterns) {
|
||||
// Check if pattern contains glob characters
|
||||
if (pattern.includes("*") || pattern.includes("?") || pattern.includes("[")) {
|
||||
// Extract optional thinking level suffix (e.g., "provider/*:high")
|
||||
const colonIdx = pattern.lastIndexOf(":");
|
||||
let globPattern = pattern;
|
||||
let thinkingLevel: ThinkingLevel | undefined;
|
||||
|
||||
if (colonIdx !== -1) {
|
||||
const suffix = pattern.substring(colonIdx + 1);
|
||||
if (isValidThinkingLevel(suffix)) {
|
||||
thinkingLevel = suffix;
|
||||
globPattern = pattern.substring(0, colonIdx);
|
||||
}
|
||||
}
|
||||
|
||||
// Match against "provider/modelId" format OR just model ID
|
||||
// This allows "*sonnet*" to match without requiring "anthropic/*sonnet*"
|
||||
const matchingModels = availableModels.filter((m) => {
|
||||
const fullId = `${m.provider}/${m.id}`;
|
||||
return minimatch(fullId, globPattern, { nocase: true }) || minimatch(m.id, globPattern, { nocase: true });
|
||||
});
|
||||
|
||||
if (matchingModels.length === 0) {
|
||||
console.warn(chalk.yellow(`Warning: No models match pattern "${pattern}"`));
|
||||
continue;
|
||||
}
|
||||
|
||||
for (const model of matchingModels) {
|
||||
if (!scopedModels.find((sm) => modelsAreEqual(sm.model, model))) {
|
||||
scopedModels.push({ model, thinkingLevel });
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const { model, thinkingLevel, warning } = parseModelPattern(pattern, availableModels);
|
||||
|
||||
if (warning) {
|
||||
console.warn(chalk.yellow(`Warning: ${warning}`));
|
||||
}
|
||||
|
||||
if (!model) {
|
||||
console.warn(chalk.yellow(`Warning: No models match pattern "${pattern}"`));
|
||||
continue;
|
||||
}
|
||||
|
||||
// Avoid duplicates
|
||||
if (!scopedModels.find((sm) => modelsAreEqual(sm.model, model))) {
|
||||
scopedModels.push({ model, thinkingLevel });
|
||||
}
|
||||
}
|
||||
|
||||
return scopedModels;
|
||||
}
|
||||
|
||||
export interface ResolveCliModelResult {
|
||||
model: Model<Api> | undefined;
|
||||
thinkingLevel?: ThinkingLevel;
|
||||
warning: string | undefined;
|
||||
/**
|
||||
* Error message suitable for CLI display.
|
||||
* When set, model will be undefined.
|
||||
*/
|
||||
error: string | undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve a single model from CLI flags.
|
||||
*
|
||||
* Supports:
|
||||
* - --provider <provider> --model <pattern>
|
||||
* - --model <provider>/<pattern>
|
||||
* - Fuzzy matching (same rules as model scoping: exact id, then partial id/name)
|
||||
*
|
||||
* Note: This does not apply the thinking level by itself, but it may *parse* and
|
||||
* return a thinking level from "<pattern>:<thinking>" so the caller can apply it.
|
||||
*/
|
||||
export function resolveCliModel(options: {
|
||||
cliProvider?: string;
|
||||
cliModel?: string;
|
||||
modelRegistry: ModelRegistry;
|
||||
}): ResolveCliModelResult {
|
||||
const { cliProvider, cliModel, modelRegistry } = options;
|
||||
|
||||
if (!cliModel) {
|
||||
return { model: undefined, warning: undefined, error: undefined };
|
||||
}
|
||||
|
||||
// Important: use *all* models here, not just models with pre-configured auth.
|
||||
// This allows "--api-key" to be used for first-time setup.
|
||||
const availableModels = modelRegistry.getAll();
|
||||
if (availableModels.length === 0) {
|
||||
return {
|
||||
model: undefined,
|
||||
warning: undefined,
|
||||
error: "No models available. Check your installation or add models to models.json.",
|
||||
};
|
||||
}
|
||||
|
||||
// Build canonical provider lookup (case-insensitive)
|
||||
const providerMap = new Map<string, string>();
|
||||
for (const m of availableModels) {
|
||||
providerMap.set(m.provider.toLowerCase(), m.provider);
|
||||
}
|
||||
|
||||
let provider = cliProvider ? providerMap.get(cliProvider.toLowerCase()) : undefined;
|
||||
if (cliProvider && !provider) {
|
||||
return {
|
||||
model: undefined,
|
||||
warning: undefined,
|
||||
error: `Unknown provider "${cliProvider}". Use --list-models to see available providers/models.`,
|
||||
};
|
||||
}
|
||||
|
||||
// If no explicit --provider, try to interpret "provider/model" format first.
|
||||
// When the prefix before the first slash matches a known provider, prefer that
|
||||
// interpretation over matching models whose IDs literally contain slashes
|
||||
// (e.g. "zai/glm-5" should resolve to provider=zai, model=glm-5, not to a
|
||||
// vercel-ai-gateway model with id "zai/glm-5").
|
||||
let pattern = cliModel;
|
||||
let inferredProvider = false;
|
||||
|
||||
if (!provider) {
|
||||
const slashIndex = cliModel.indexOf("/");
|
||||
if (slashIndex !== -1) {
|
||||
const maybeProvider = cliModel.substring(0, slashIndex);
|
||||
const canonical = providerMap.get(maybeProvider.toLowerCase());
|
||||
if (canonical) {
|
||||
provider = canonical;
|
||||
pattern = cliModel.substring(slashIndex + 1);
|
||||
inferredProvider = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no provider was inferred from the slash, try exact matches without provider inference.
|
||||
// This handles models whose IDs naturally contain slashes (e.g. OpenRouter-style IDs).
|
||||
if (!provider) {
|
||||
const lower = cliModel.toLowerCase();
|
||||
const exact = availableModels.find(
|
||||
(m) => m.id.toLowerCase() === lower || `${m.provider}/${m.id}`.toLowerCase() === lower,
|
||||
);
|
||||
if (exact) {
|
||||
return { model: exact, warning: undefined, thinkingLevel: undefined, error: undefined };
|
||||
}
|
||||
}
|
||||
|
||||
if (cliProvider && provider) {
|
||||
// If both were provided, tolerate --model <provider>/<pattern> by stripping the provider prefix
|
||||
const prefix = `${provider}/`;
|
||||
if (cliModel.toLowerCase().startsWith(prefix.toLowerCase())) {
|
||||
pattern = cliModel.substring(prefix.length);
|
||||
}
|
||||
}
|
||||
|
||||
const candidates = provider ? availableModels.filter((m) => m.provider === provider) : availableModels;
|
||||
const { model, thinkingLevel, warning } = parseModelPattern(pattern, candidates, {
|
||||
allowInvalidThinkingLevelFallback: false,
|
||||
});
|
||||
|
||||
if (model) {
|
||||
return { model, thinkingLevel, warning, error: undefined };
|
||||
}
|
||||
|
||||
// If we inferred a provider from the slash but found no match within that provider,
|
||||
// fall back to matching the full input as a raw model id across all models.
|
||||
// This handles OpenRouter-style IDs like "openai/gpt-4o:extended" where "openai"
|
||||
// looks like a provider but the full string is actually a model id on openrouter.
|
||||
if (inferredProvider) {
|
||||
const lower = cliModel.toLowerCase();
|
||||
const exact = availableModels.find(
|
||||
(m) => m.id.toLowerCase() === lower || `${m.provider}/${m.id}`.toLowerCase() === lower,
|
||||
);
|
||||
if (exact) {
|
||||
return { model: exact, warning: undefined, thinkingLevel: undefined, error: undefined };
|
||||
}
|
||||
// Also try parseModelPattern on the full input against all models
|
||||
const fallback = parseModelPattern(cliModel, availableModels, {
|
||||
allowInvalidThinkingLevelFallback: false,
|
||||
});
|
||||
if (fallback.model) {
|
||||
return {
|
||||
model: fallback.model,
|
||||
thinkingLevel: fallback.thinkingLevel,
|
||||
warning: fallback.warning,
|
||||
error: undefined,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
if (provider) {
|
||||
const fallbackModel = buildFallbackModel(provider, pattern, availableModels);
|
||||
if (fallbackModel) {
|
||||
const fallbackWarning = warning
|
||||
? `${warning} Model "${pattern}" not found for provider "${provider}". Using custom model id.`
|
||||
: `Model "${pattern}" not found for provider "${provider}". Using custom model id.`;
|
||||
return { model: fallbackModel, thinkingLevel: undefined, warning: fallbackWarning, error: undefined };
|
||||
}
|
||||
}
|
||||
|
||||
const display = provider ? `${provider}/${pattern}` : cliModel;
|
||||
return {
|
||||
model: undefined,
|
||||
thinkingLevel: undefined,
|
||||
warning,
|
||||
error: `Model "${display}" not found. Use --list-models to see available models.`,
|
||||
};
|
||||
}
|
||||
|
||||
export interface InitialModelResult {
|
||||
model: Model<Api> | undefined;
|
||||
thinkingLevel: ThinkingLevel;
|
||||
fallbackMessage: string | undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Find the initial model to use based on priority:
|
||||
* 1. CLI args (provider + model)
|
||||
* 2. First model from scoped models (if not continuing/resuming)
|
||||
* 3. Restored from session (if continuing/resuming)
|
||||
* 4. Saved default from settings
|
||||
* 5. First available model with valid API key
|
||||
*/
|
||||
export async function findInitialModel(options: {
|
||||
cliProvider?: string;
|
||||
cliModel?: string;
|
||||
scopedModels: ScopedModel[];
|
||||
isContinuing: boolean;
|
||||
defaultProvider?: string;
|
||||
defaultModelId?: string;
|
||||
defaultThinkingLevel?: ThinkingLevel;
|
||||
modelRegistry: ModelRegistry;
|
||||
}): Promise<InitialModelResult> {
|
||||
const {
|
||||
cliProvider,
|
||||
cliModel,
|
||||
scopedModels,
|
||||
isContinuing,
|
||||
defaultProvider,
|
||||
defaultModelId,
|
||||
defaultThinkingLevel,
|
||||
modelRegistry,
|
||||
} = options;
|
||||
|
||||
let model: Model<Api> | undefined;
|
||||
let thinkingLevel: ThinkingLevel = DEFAULT_THINKING_LEVEL;
|
||||
|
||||
// 1. CLI args take priority
|
||||
if (cliProvider && cliModel) {
|
||||
const resolved = resolveCliModel({
|
||||
cliProvider,
|
||||
cliModel,
|
||||
modelRegistry,
|
||||
});
|
||||
if (resolved.error) {
|
||||
console.error(chalk.red(resolved.error));
|
||||
process.exit(1);
|
||||
}
|
||||
if (resolved.model) {
|
||||
return { model: resolved.model, thinkingLevel: DEFAULT_THINKING_LEVEL, fallbackMessage: undefined };
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Use first model from scoped models (skip if continuing/resuming)
|
||||
if (scopedModels.length > 0 && !isContinuing) {
|
||||
return {
|
||||
model: scopedModels[0].model,
|
||||
thinkingLevel: scopedModels[0].thinkingLevel ?? defaultThinkingLevel ?? DEFAULT_THINKING_LEVEL,
|
||||
fallbackMessage: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
// 3. Try saved default from settings
|
||||
if (defaultProvider && defaultModelId) {
|
||||
const found = modelRegistry.find(defaultProvider, defaultModelId);
|
||||
if (found) {
|
||||
model = found;
|
||||
if (defaultThinkingLevel) {
|
||||
thinkingLevel = defaultThinkingLevel;
|
||||
}
|
||||
return { model, thinkingLevel, fallbackMessage: undefined };
|
||||
}
|
||||
}
|
||||
|
||||
// 4. Try first available model with valid API key
|
||||
const availableModels = await modelRegistry.getAvailable();
|
||||
|
||||
if (availableModels.length > 0) {
|
||||
// Try to find a default model from known providers
|
||||
for (const provider of Object.keys(defaultModelPerProvider) as KnownProvider[]) {
|
||||
const defaultId = defaultModelPerProvider[provider];
|
||||
const match = availableModels.find((m) => m.provider === provider && m.id === defaultId);
|
||||
if (match) {
|
||||
return { model: match, thinkingLevel: DEFAULT_THINKING_LEVEL, fallbackMessage: undefined };
|
||||
}
|
||||
}
|
||||
|
||||
// If no default found, use first available
|
||||
return { model: availableModels[0], thinkingLevel: DEFAULT_THINKING_LEVEL, fallbackMessage: undefined };
|
||||
}
|
||||
|
||||
// 5. No model found
|
||||
return { model: undefined, thinkingLevel: DEFAULT_THINKING_LEVEL, fallbackMessage: undefined };
|
||||
}
|
||||
|
||||
/**
|
||||
* Restore model from session, with fallback to available models
|
||||
*/
|
||||
export async function restoreModelFromSession(
|
||||
savedProvider: string,
|
||||
savedModelId: string,
|
||||
currentModel: Model<Api> | undefined,
|
||||
shouldPrintMessages: boolean,
|
||||
modelRegistry: ModelRegistry,
|
||||
): Promise<{ model: Model<Api> | undefined; fallbackMessage: string | undefined }> {
|
||||
const restoredModel = modelRegistry.find(savedProvider, savedModelId);
|
||||
|
||||
// Check if restored model exists and has a valid API key
|
||||
const hasApiKey = restoredModel ? !!(await modelRegistry.getApiKey(restoredModel)) : false;
|
||||
|
||||
if (restoredModel && hasApiKey) {
|
||||
if (shouldPrintMessages) {
|
||||
console.log(chalk.dim(`Restored model: ${savedProvider}/${savedModelId}`));
|
||||
}
|
||||
return { model: restoredModel, fallbackMessage: undefined };
|
||||
}
|
||||
|
||||
// Model not found or no API key - fall back
|
||||
const reason = !restoredModel ? "model no longer exists" : "no API key available";
|
||||
|
||||
if (shouldPrintMessages) {
|
||||
console.error(chalk.yellow(`Warning: Could not restore model ${savedProvider}/${savedModelId} (${reason}).`));
|
||||
}
|
||||
|
||||
// If we already have a model, use it as fallback
|
||||
if (currentModel) {
|
||||
if (shouldPrintMessages) {
|
||||
console.log(chalk.dim(`Falling back to: ${currentModel.provider}/${currentModel.id}`));
|
||||
}
|
||||
return {
|
||||
model: currentModel,
|
||||
fallbackMessage: `Could not restore model ${savedProvider}/${savedModelId} (${reason}). Using ${currentModel.provider}/${currentModel.id}.`,
|
||||
};
|
||||
}
|
||||
|
||||
// Try to find any available model
|
||||
const availableModels = await modelRegistry.getAvailable();
|
||||
|
||||
if (availableModels.length > 0) {
|
||||
// Try to find a default model from known providers
|
||||
let fallbackModel: Model<Api> | undefined;
|
||||
for (const provider of Object.keys(defaultModelPerProvider) as KnownProvider[]) {
|
||||
const defaultId = defaultModelPerProvider[provider];
|
||||
const match = availableModels.find((m) => m.provider === provider && m.id === defaultId);
|
||||
if (match) {
|
||||
fallbackModel = match;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// If no default found, use first available
|
||||
if (!fallbackModel) {
|
||||
fallbackModel = availableModels[0];
|
||||
}
|
||||
|
||||
if (shouldPrintMessages) {
|
||||
console.log(chalk.dim(`Falling back to: ${fallbackModel.provider}/${fallbackModel.id}`));
|
||||
}
|
||||
|
||||
return {
|
||||
model: fallbackModel,
|
||||
fallbackMessage: `Could not restore model ${savedProvider}/${savedModelId} (${reason}). Using ${fallbackModel.provider}/${fallbackModel.id}.`,
|
||||
};
|
||||
}
|
||||
|
||||
// No models available
|
||||
return { model: undefined, fallbackMessage: undefined };
|
||||
}
|
||||
1794
packages/pi-coding-agent/src/core/package-manager.ts
Normal file
1794
packages/pi-coding-agent/src/core/package-manager.ts
Normal file
File diff suppressed because it is too large
Load diff
299
packages/pi-coding-agent/src/core/prompt-templates.ts
Normal file
299
packages/pi-coding-agent/src/core/prompt-templates.ts
Normal file
|
|
@ -0,0 +1,299 @@
|
|||
import { existsSync, readdirSync, readFileSync, statSync } from "fs";
|
||||
import { homedir } from "os";
|
||||
import { basename, isAbsolute, join, resolve, sep } from "path";
|
||||
import { CONFIG_DIR_NAME, getPromptsDir } from "../config.js";
|
||||
import { parseFrontmatter } from "../utils/frontmatter.js";
|
||||
|
||||
/**
|
||||
* Represents a prompt template loaded from a markdown file
|
||||
*/
|
||||
export interface PromptTemplate {
|
||||
name: string;
|
||||
description: string;
|
||||
content: string;
|
||||
source: string; // "user", "project", or "path"
|
||||
filePath: string; // Absolute path to the template file
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse command arguments respecting quoted strings (bash-style)
|
||||
* Returns array of arguments
|
||||
*/
|
||||
export function parseCommandArgs(argsString: string): string[] {
|
||||
const args: string[] = [];
|
||||
let current = "";
|
||||
let inQuote: string | null = null;
|
||||
|
||||
for (let i = 0; i < argsString.length; i++) {
|
||||
const char = argsString[i];
|
||||
|
||||
if (inQuote) {
|
||||
if (char === inQuote) {
|
||||
inQuote = null;
|
||||
} else {
|
||||
current += char;
|
||||
}
|
||||
} else if (char === '"' || char === "'") {
|
||||
inQuote = char;
|
||||
} else if (char === " " || char === "\t") {
|
||||
if (current) {
|
||||
args.push(current);
|
||||
current = "";
|
||||
}
|
||||
} else {
|
||||
current += char;
|
||||
}
|
||||
}
|
||||
|
||||
if (current) {
|
||||
args.push(current);
|
||||
}
|
||||
|
||||
return args;
|
||||
}
|
||||
|
||||
/**
|
||||
* Substitute argument placeholders in template content
|
||||
* Supports:
|
||||
* - $1, $2, ... for positional args
|
||||
* - $@ and $ARGUMENTS for all args
|
||||
* - ${@:N} for args from Nth onwards (bash-style slicing)
|
||||
* - ${@:N:L} for L args starting from Nth
|
||||
*
|
||||
* Note: Replacement happens on the template string only. Argument values
|
||||
* containing patterns like $1, $@, or $ARGUMENTS are NOT recursively substituted.
|
||||
*/
|
||||
export function substituteArgs(content: string, args: string[]): string {
|
||||
let result = content;
|
||||
|
||||
// Replace $1, $2, etc. with positional args FIRST (before wildcards)
|
||||
// This prevents wildcard replacement values containing $<digit> patterns from being re-substituted
|
||||
result = result.replace(/\$(\d+)/g, (_, num) => {
|
||||
const index = parseInt(num, 10) - 1;
|
||||
return args[index] ?? "";
|
||||
});
|
||||
|
||||
// Replace ${@:start} or ${@:start:length} with sliced args (bash-style)
|
||||
// Process BEFORE simple $@ to avoid conflicts
|
||||
result = result.replace(/\$\{@:(\d+)(?::(\d+))?\}/g, (_, startStr, lengthStr) => {
|
||||
let start = parseInt(startStr, 10) - 1; // Convert to 0-indexed (user provides 1-indexed)
|
||||
// Treat 0 as 1 (bash convention: args start at 1)
|
||||
if (start < 0) start = 0;
|
||||
|
||||
if (lengthStr) {
|
||||
const length = parseInt(lengthStr, 10);
|
||||
return args.slice(start, start + length).join(" ");
|
||||
}
|
||||
return args.slice(start).join(" ");
|
||||
});
|
||||
|
||||
// Pre-compute all args joined (optimization)
|
||||
const allArgs = args.join(" ");
|
||||
|
||||
// Replace $ARGUMENTS with all args joined (new syntax, aligns with Claude, Codex, OpenCode)
|
||||
result = result.replace(/\$ARGUMENTS/g, allArgs);
|
||||
|
||||
// Replace $@ with all args joined (existing syntax)
|
||||
result = result.replace(/\$@/g, allArgs);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
function loadTemplateFromFile(filePath: string, source: string, sourceLabel: string): PromptTemplate | null {
|
||||
try {
|
||||
const rawContent = readFileSync(filePath, "utf-8");
|
||||
const { frontmatter, body } = parseFrontmatter<Record<string, string>>(rawContent);
|
||||
|
||||
const name = basename(filePath).replace(/\.md$/, "");
|
||||
|
||||
// Get description from frontmatter or first non-empty line
|
||||
let description = frontmatter.description || "";
|
||||
if (!description) {
|
||||
const firstLine = body.split("\n").find((line) => line.trim());
|
||||
if (firstLine) {
|
||||
// Truncate if too long
|
||||
description = firstLine.slice(0, 60);
|
||||
if (firstLine.length > 60) description += "...";
|
||||
}
|
||||
}
|
||||
|
||||
// Append source to description
|
||||
description = description ? `${description} ${sourceLabel}` : sourceLabel;
|
||||
|
||||
return {
|
||||
name,
|
||||
description,
|
||||
content: body,
|
||||
source,
|
||||
filePath,
|
||||
};
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Scan a directory for .md files (non-recursive) and load them as prompt templates.
|
||||
*/
|
||||
function loadTemplatesFromDir(dir: string, source: string, sourceLabel: string): PromptTemplate[] {
|
||||
const templates: PromptTemplate[] = [];
|
||||
|
||||
if (!existsSync(dir)) {
|
||||
return templates;
|
||||
}
|
||||
|
||||
try {
|
||||
const entries = readdirSync(dir, { withFileTypes: true });
|
||||
|
||||
for (const entry of entries) {
|
||||
const fullPath = join(dir, entry.name);
|
||||
|
||||
// For symlinks, check if they point to a file
|
||||
let isFile = entry.isFile();
|
||||
if (entry.isSymbolicLink()) {
|
||||
try {
|
||||
const stats = statSync(fullPath);
|
||||
isFile = stats.isFile();
|
||||
} catch {
|
||||
// Broken symlink, skip it
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (isFile && entry.name.endsWith(".md")) {
|
||||
const template = loadTemplateFromFile(fullPath, source, sourceLabel);
|
||||
if (template) {
|
||||
templates.push(template);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
return templates;
|
||||
}
|
||||
|
||||
return templates;
|
||||
}
|
||||
|
||||
export interface LoadPromptTemplatesOptions {
|
||||
/** Working directory for project-local templates. Default: process.cwd() */
|
||||
cwd?: string;
|
||||
/** Agent config directory for global templates. Default: from getPromptsDir() */
|
||||
agentDir?: string;
|
||||
/** Explicit prompt template paths (files or directories) */
|
||||
promptPaths?: string[];
|
||||
/** Include default prompt directories. Default: true */
|
||||
includeDefaults?: boolean;
|
||||
}
|
||||
|
||||
function normalizePath(input: string): string {
|
||||
const trimmed = input.trim();
|
||||
if (trimmed === "~") return homedir();
|
||||
if (trimmed.startsWith("~/")) return join(homedir(), trimmed.slice(2));
|
||||
if (trimmed.startsWith("~")) return join(homedir(), trimmed.slice(1));
|
||||
return trimmed;
|
||||
}
|
||||
|
||||
function resolvePromptPath(p: string, cwd: string): string {
|
||||
const normalized = normalizePath(p);
|
||||
return isAbsolute(normalized) ? normalized : resolve(cwd, normalized);
|
||||
}
|
||||
|
||||
function buildPathSourceLabel(p: string): string {
|
||||
const base = basename(p).replace(/\.md$/, "") || "path";
|
||||
return `(path:${base})`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load all prompt templates from:
|
||||
* 1. Global: agentDir/prompts/
|
||||
* 2. Project: cwd/{CONFIG_DIR_NAME}/prompts/
|
||||
* 3. Explicit prompt paths
|
||||
*/
|
||||
export function loadPromptTemplates(options: LoadPromptTemplatesOptions = {}): PromptTemplate[] {
|
||||
const resolvedCwd = options.cwd ?? process.cwd();
|
||||
const resolvedAgentDir = options.agentDir ?? getPromptsDir();
|
||||
const promptPaths = options.promptPaths ?? [];
|
||||
const includeDefaults = options.includeDefaults ?? true;
|
||||
|
||||
const templates: PromptTemplate[] = [];
|
||||
|
||||
if (includeDefaults) {
|
||||
// 1. Load global templates from agentDir/prompts/
|
||||
// Note: if agentDir is provided, it should be the agent dir, not the prompts dir
|
||||
const globalPromptsDir = options.agentDir ? join(options.agentDir, "prompts") : resolvedAgentDir;
|
||||
templates.push(...loadTemplatesFromDir(globalPromptsDir, "user", "(user)"));
|
||||
|
||||
// 2. Load project templates from cwd/{CONFIG_DIR_NAME}/prompts/
|
||||
const projectPromptsDir = resolve(resolvedCwd, CONFIG_DIR_NAME, "prompts");
|
||||
templates.push(...loadTemplatesFromDir(projectPromptsDir, "project", "(project)"));
|
||||
}
|
||||
|
||||
const userPromptsDir = options.agentDir ? join(options.agentDir, "prompts") : resolvedAgentDir;
|
||||
const projectPromptsDir = resolve(resolvedCwd, CONFIG_DIR_NAME, "prompts");
|
||||
|
||||
const isUnderPath = (target: string, root: string): boolean => {
|
||||
const normalizedRoot = resolve(root);
|
||||
if (target === normalizedRoot) {
|
||||
return true;
|
||||
}
|
||||
const prefix = normalizedRoot.endsWith(sep) ? normalizedRoot : `${normalizedRoot}${sep}`;
|
||||
return target.startsWith(prefix);
|
||||
};
|
||||
|
||||
const getSourceInfo = (resolvedPath: string): { source: string; label: string } => {
|
||||
if (!includeDefaults) {
|
||||
if (isUnderPath(resolvedPath, userPromptsDir)) {
|
||||
return { source: "user", label: "(user)" };
|
||||
}
|
||||
if (isUnderPath(resolvedPath, projectPromptsDir)) {
|
||||
return { source: "project", label: "(project)" };
|
||||
}
|
||||
}
|
||||
return { source: "path", label: buildPathSourceLabel(resolvedPath) };
|
||||
};
|
||||
|
||||
// 3. Load explicit prompt paths
|
||||
for (const rawPath of promptPaths) {
|
||||
const resolvedPath = resolvePromptPath(rawPath, resolvedCwd);
|
||||
if (!existsSync(resolvedPath)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
const stats = statSync(resolvedPath);
|
||||
const { source, label } = getSourceInfo(resolvedPath);
|
||||
if (stats.isDirectory()) {
|
||||
templates.push(...loadTemplatesFromDir(resolvedPath, source, label));
|
||||
} else if (stats.isFile() && resolvedPath.endsWith(".md")) {
|
||||
const template = loadTemplateFromFile(resolvedPath, source, label);
|
||||
if (template) {
|
||||
templates.push(template);
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Ignore read failures
|
||||
}
|
||||
}
|
||||
|
||||
return templates;
|
||||
}
|
||||
|
||||
/**
|
||||
* Expand a prompt template if it matches a template name.
|
||||
* Returns the expanded content or the original text if not a template.
|
||||
*/
|
||||
export function expandPromptTemplate(text: string, templates: PromptTemplate[]): string {
|
||||
if (!text.startsWith("/")) return text;
|
||||
|
||||
const spaceIndex = text.indexOf(" ");
|
||||
const templateName = spaceIndex === -1 ? text.slice(1) : text.slice(1, spaceIndex);
|
||||
const argsString = spaceIndex === -1 ? "" : text.slice(spaceIndex + 1);
|
||||
|
||||
const template = templates.find((t) => t.name === templateName);
|
||||
if (template) {
|
||||
const args = parseCommandArgs(argsString);
|
||||
return substituteArgs(template.content, args);
|
||||
}
|
||||
|
||||
return text;
|
||||
}
|
||||
64
packages/pi-coding-agent/src/core/resolve-config-value.ts
Normal file
64
packages/pi-coding-agent/src/core/resolve-config-value.ts
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
/**
|
||||
* Resolve configuration values that may be shell commands, environment variables, or literals.
|
||||
* Used by auth-storage.ts and model-registry.ts.
|
||||
*/
|
||||
|
||||
import { execSync } from "child_process";
|
||||
|
||||
// Cache for shell command results (persists for process lifetime)
|
||||
const commandResultCache = new Map<string, string | undefined>();
|
||||
|
||||
/**
|
||||
* Resolve a config value (API key, header value, etc.) to an actual value.
|
||||
* - If starts with "!", executes the rest as a shell command and uses stdout (cached)
|
||||
* - Otherwise checks environment variable first, then treats as literal (not cached)
|
||||
*/
|
||||
export function resolveConfigValue(config: string): string | undefined {
|
||||
if (config.startsWith("!")) {
|
||||
return executeCommand(config);
|
||||
}
|
||||
const envValue = process.env[config];
|
||||
return envValue || config;
|
||||
}
|
||||
|
||||
function executeCommand(commandConfig: string): string | undefined {
|
||||
if (commandResultCache.has(commandConfig)) {
|
||||
return commandResultCache.get(commandConfig);
|
||||
}
|
||||
|
||||
const command = commandConfig.slice(1);
|
||||
let result: string | undefined;
|
||||
try {
|
||||
const output = execSync(command, {
|
||||
encoding: "utf-8",
|
||||
timeout: 10000,
|
||||
stdio: ["ignore", "pipe", "ignore"],
|
||||
});
|
||||
result = output.trim() || undefined;
|
||||
} catch {
|
||||
result = undefined;
|
||||
}
|
||||
|
||||
commandResultCache.set(commandConfig, result);
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve all header values using the same resolution logic as API keys.
|
||||
*/
|
||||
export function resolveHeaders(headers: Record<string, string> | undefined): Record<string, string> | undefined {
|
||||
if (!headers) return undefined;
|
||||
const resolved: Record<string, string> = {};
|
||||
for (const [key, value] of Object.entries(headers)) {
|
||||
const resolvedValue = resolveConfigValue(value);
|
||||
if (resolvedValue) {
|
||||
resolved[key] = resolvedValue;
|
||||
}
|
||||
}
|
||||
return Object.keys(resolved).length > 0 ? resolved : undefined;
|
||||
}
|
||||
|
||||
/** Clear the config value command cache. Exported for testing. */
|
||||
export function clearConfigValueCache(): void {
|
||||
commandResultCache.clear();
|
||||
}
|
||||
868
packages/pi-coding-agent/src/core/resource-loader.ts
Normal file
868
packages/pi-coding-agent/src/core/resource-loader.ts
Normal file
|
|
@ -0,0 +1,868 @@
|
|||
import { existsSync, readdirSync, readFileSync, statSync } from "node:fs";
|
||||
import { homedir } from "node:os";
|
||||
import { join, resolve, sep } from "node:path";
|
||||
import chalk from "chalk";
|
||||
import { CONFIG_DIR_NAME, getAgentDir } from "../config.js";
|
||||
import { loadThemeFromPath, type Theme } from "../modes/interactive/theme/theme.js";
|
||||
import type { ResourceDiagnostic } from "./diagnostics.js";
|
||||
|
||||
export type { ResourceCollision, ResourceDiagnostic } from "./diagnostics.js";
|
||||
|
||||
import { createEventBus, type EventBus } from "./event-bus.js";
|
||||
import { createExtensionRuntime, loadExtensionFromFactory, loadExtensions } from "./extensions/loader.js";
|
||||
import type { Extension, ExtensionFactory, ExtensionRuntime, LoadExtensionsResult } from "./extensions/types.js";
|
||||
import { DefaultPackageManager, type PathMetadata } from "./package-manager.js";
|
||||
import type { PromptTemplate } from "./prompt-templates.js";
|
||||
import { loadPromptTemplates } from "./prompt-templates.js";
|
||||
import { SettingsManager } from "./settings-manager.js";
|
||||
import type { Skill } from "./skills.js";
|
||||
import { loadSkills } from "./skills.js";
|
||||
|
||||
export interface ResourceExtensionPaths {
|
||||
skillPaths?: Array<{ path: string; metadata: PathMetadata }>;
|
||||
promptPaths?: Array<{ path: string; metadata: PathMetadata }>;
|
||||
themePaths?: Array<{ path: string; metadata: PathMetadata }>;
|
||||
}
|
||||
|
||||
export interface ResourceLoader {
|
||||
getExtensions(): LoadExtensionsResult;
|
||||
getSkills(): { skills: Skill[]; diagnostics: ResourceDiagnostic[] };
|
||||
getPrompts(): { prompts: PromptTemplate[]; diagnostics: ResourceDiagnostic[] };
|
||||
getThemes(): { themes: Theme[]; diagnostics: ResourceDiagnostic[] };
|
||||
getAgentsFiles(): { agentsFiles: Array<{ path: string; content: string }> };
|
||||
getSystemPrompt(): string | undefined;
|
||||
getAppendSystemPrompt(): string[];
|
||||
getPathMetadata(): Map<string, PathMetadata>;
|
||||
extendResources(paths: ResourceExtensionPaths): void;
|
||||
reload(): Promise<void>;
|
||||
}
|
||||
|
||||
function resolvePromptInput(input: string | undefined, description: string): string | undefined {
|
||||
if (!input) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
if (existsSync(input)) {
|
||||
try {
|
||||
return readFileSync(input, "utf-8");
|
||||
} catch (error) {
|
||||
console.error(chalk.yellow(`Warning: Could not read ${description} file ${input}: ${error}`));
|
||||
return input;
|
||||
}
|
||||
}
|
||||
|
||||
return input;
|
||||
}
|
||||
|
||||
function loadContextFileFromDir(dir: string): { path: string; content: string } | null {
|
||||
const candidates = ["AGENTS.md", "CLAUDE.md"];
|
||||
for (const filename of candidates) {
|
||||
const filePath = join(dir, filename);
|
||||
if (existsSync(filePath)) {
|
||||
try {
|
||||
return {
|
||||
path: filePath,
|
||||
content: readFileSync(filePath, "utf-8"),
|
||||
};
|
||||
} catch (error) {
|
||||
console.error(chalk.yellow(`Warning: Could not read ${filePath}: ${error}`));
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
function loadProjectContextFiles(
|
||||
options: { cwd?: string; agentDir?: string } = {},
|
||||
): Array<{ path: string; content: string }> {
|
||||
const resolvedCwd = options.cwd ?? process.cwd();
|
||||
const resolvedAgentDir = options.agentDir ?? getAgentDir();
|
||||
|
||||
const contextFiles: Array<{ path: string; content: string }> = [];
|
||||
const seenPaths = new Set<string>();
|
||||
|
||||
const globalContext = loadContextFileFromDir(resolvedAgentDir);
|
||||
if (globalContext) {
|
||||
contextFiles.push(globalContext);
|
||||
seenPaths.add(globalContext.path);
|
||||
}
|
||||
|
||||
const ancestorContextFiles: Array<{ path: string; content: string }> = [];
|
||||
|
||||
let currentDir = resolvedCwd;
|
||||
const root = resolve("/");
|
||||
|
||||
while (true) {
|
||||
const contextFile = loadContextFileFromDir(currentDir);
|
||||
if (contextFile && !seenPaths.has(contextFile.path)) {
|
||||
ancestorContextFiles.unshift(contextFile);
|
||||
seenPaths.add(contextFile.path);
|
||||
}
|
||||
|
||||
if (currentDir === root) break;
|
||||
|
||||
const parentDir = resolve(currentDir, "..");
|
||||
if (parentDir === currentDir) break;
|
||||
currentDir = parentDir;
|
||||
}
|
||||
|
||||
contextFiles.push(...ancestorContextFiles);
|
||||
|
||||
return contextFiles;
|
||||
}
|
||||
|
||||
export interface DefaultResourceLoaderOptions {
|
||||
cwd?: string;
|
||||
agentDir?: string;
|
||||
settingsManager?: SettingsManager;
|
||||
eventBus?: EventBus;
|
||||
additionalExtensionPaths?: string[];
|
||||
additionalSkillPaths?: string[];
|
||||
additionalPromptTemplatePaths?: string[];
|
||||
additionalThemePaths?: string[];
|
||||
extensionFactories?: ExtensionFactory[];
|
||||
noExtensions?: boolean;
|
||||
noSkills?: boolean;
|
||||
noPromptTemplates?: boolean;
|
||||
noThemes?: boolean;
|
||||
systemPrompt?: string;
|
||||
appendSystemPrompt?: string;
|
||||
extensionsOverride?: (base: LoadExtensionsResult) => LoadExtensionsResult;
|
||||
skillsOverride?: (base: { skills: Skill[]; diagnostics: ResourceDiagnostic[] }) => {
|
||||
skills: Skill[];
|
||||
diagnostics: ResourceDiagnostic[];
|
||||
};
|
||||
promptsOverride?: (base: { prompts: PromptTemplate[]; diagnostics: ResourceDiagnostic[] }) => {
|
||||
prompts: PromptTemplate[];
|
||||
diagnostics: ResourceDiagnostic[];
|
||||
};
|
||||
themesOverride?: (base: { themes: Theme[]; diagnostics: ResourceDiagnostic[] }) => {
|
||||
themes: Theme[];
|
||||
diagnostics: ResourceDiagnostic[];
|
||||
};
|
||||
agentsFilesOverride?: (base: { agentsFiles: Array<{ path: string; content: string }> }) => {
|
||||
agentsFiles: Array<{ path: string; content: string }>;
|
||||
};
|
||||
systemPromptOverride?: (base: string | undefined) => string | undefined;
|
||||
appendSystemPromptOverride?: (base: string[]) => string[];
|
||||
}
|
||||
|
||||
export class DefaultResourceLoader implements ResourceLoader {
|
||||
private cwd: string;
|
||||
private agentDir: string;
|
||||
private settingsManager: SettingsManager;
|
||||
private eventBus: EventBus;
|
||||
private packageManager: DefaultPackageManager;
|
||||
private additionalExtensionPaths: string[];
|
||||
private additionalSkillPaths: string[];
|
||||
private additionalPromptTemplatePaths: string[];
|
||||
private additionalThemePaths: string[];
|
||||
private extensionFactories: ExtensionFactory[];
|
||||
private noExtensions: boolean;
|
||||
private noSkills: boolean;
|
||||
private noPromptTemplates: boolean;
|
||||
private noThemes: boolean;
|
||||
private systemPromptSource?: string;
|
||||
private appendSystemPromptSource?: string;
|
||||
private extensionsOverride?: (base: LoadExtensionsResult) => LoadExtensionsResult;
|
||||
private skillsOverride?: (base: { skills: Skill[]; diagnostics: ResourceDiagnostic[] }) => {
|
||||
skills: Skill[];
|
||||
diagnostics: ResourceDiagnostic[];
|
||||
};
|
||||
private promptsOverride?: (base: { prompts: PromptTemplate[]; diagnostics: ResourceDiagnostic[] }) => {
|
||||
prompts: PromptTemplate[];
|
||||
diagnostics: ResourceDiagnostic[];
|
||||
};
|
||||
private themesOverride?: (base: { themes: Theme[]; diagnostics: ResourceDiagnostic[] }) => {
|
||||
themes: Theme[];
|
||||
diagnostics: ResourceDiagnostic[];
|
||||
};
|
||||
private agentsFilesOverride?: (base: { agentsFiles: Array<{ path: string; content: string }> }) => {
|
||||
agentsFiles: Array<{ path: string; content: string }>;
|
||||
};
|
||||
private systemPromptOverride?: (base: string | undefined) => string | undefined;
|
||||
private appendSystemPromptOverride?: (base: string[]) => string[];
|
||||
|
||||
private extensionsResult: LoadExtensionsResult;
|
||||
private skills: Skill[];
|
||||
private skillDiagnostics: ResourceDiagnostic[];
|
||||
private prompts: PromptTemplate[];
|
||||
private promptDiagnostics: ResourceDiagnostic[];
|
||||
private themes: Theme[];
|
||||
private themeDiagnostics: ResourceDiagnostic[];
|
||||
private agentsFiles: Array<{ path: string; content: string }>;
|
||||
private systemPrompt?: string;
|
||||
private appendSystemPrompt: string[];
|
||||
private pathMetadata: Map<string, PathMetadata>;
|
||||
private lastSkillPaths: string[];
|
||||
private lastPromptPaths: string[];
|
||||
private lastThemePaths: string[];
|
||||
|
||||
constructor(options: DefaultResourceLoaderOptions) {
|
||||
this.cwd = options.cwd ?? process.cwd();
|
||||
this.agentDir = options.agentDir ?? getAgentDir();
|
||||
this.settingsManager = options.settingsManager ?? SettingsManager.create(this.cwd, this.agentDir);
|
||||
this.eventBus = options.eventBus ?? createEventBus();
|
||||
this.packageManager = new DefaultPackageManager({
|
||||
cwd: this.cwd,
|
||||
agentDir: this.agentDir,
|
||||
settingsManager: this.settingsManager,
|
||||
});
|
||||
this.additionalExtensionPaths = options.additionalExtensionPaths ?? [];
|
||||
this.additionalSkillPaths = options.additionalSkillPaths ?? [];
|
||||
this.additionalPromptTemplatePaths = options.additionalPromptTemplatePaths ?? [];
|
||||
this.additionalThemePaths = options.additionalThemePaths ?? [];
|
||||
this.extensionFactories = options.extensionFactories ?? [];
|
||||
this.noExtensions = options.noExtensions ?? false;
|
||||
this.noSkills = options.noSkills ?? false;
|
||||
this.noPromptTemplates = options.noPromptTemplates ?? false;
|
||||
this.noThemes = options.noThemes ?? false;
|
||||
this.systemPromptSource = options.systemPrompt;
|
||||
this.appendSystemPromptSource = options.appendSystemPrompt;
|
||||
this.extensionsOverride = options.extensionsOverride;
|
||||
this.skillsOverride = options.skillsOverride;
|
||||
this.promptsOverride = options.promptsOverride;
|
||||
this.themesOverride = options.themesOverride;
|
||||
this.agentsFilesOverride = options.agentsFilesOverride;
|
||||
this.systemPromptOverride = options.systemPromptOverride;
|
||||
this.appendSystemPromptOverride = options.appendSystemPromptOverride;
|
||||
|
||||
this.extensionsResult = { extensions: [], errors: [], runtime: createExtensionRuntime() };
|
||||
this.skills = [];
|
||||
this.skillDiagnostics = [];
|
||||
this.prompts = [];
|
||||
this.promptDiagnostics = [];
|
||||
this.themes = [];
|
||||
this.themeDiagnostics = [];
|
||||
this.agentsFiles = [];
|
||||
this.appendSystemPrompt = [];
|
||||
this.pathMetadata = new Map();
|
||||
this.lastSkillPaths = [];
|
||||
this.lastPromptPaths = [];
|
||||
this.lastThemePaths = [];
|
||||
}
|
||||
|
||||
getExtensions(): LoadExtensionsResult {
|
||||
return this.extensionsResult;
|
||||
}
|
||||
|
||||
getSkills(): { skills: Skill[]; diagnostics: ResourceDiagnostic[] } {
|
||||
return { skills: this.skills, diagnostics: this.skillDiagnostics };
|
||||
}
|
||||
|
||||
getPrompts(): { prompts: PromptTemplate[]; diagnostics: ResourceDiagnostic[] } {
|
||||
return { prompts: this.prompts, diagnostics: this.promptDiagnostics };
|
||||
}
|
||||
|
||||
getThemes(): { themes: Theme[]; diagnostics: ResourceDiagnostic[] } {
|
||||
return { themes: this.themes, diagnostics: this.themeDiagnostics };
|
||||
}
|
||||
|
||||
getAgentsFiles(): { agentsFiles: Array<{ path: string; content: string }> } {
|
||||
return { agentsFiles: this.agentsFiles };
|
||||
}
|
||||
|
||||
getSystemPrompt(): string | undefined {
|
||||
return this.systemPrompt;
|
||||
}
|
||||
|
||||
getAppendSystemPrompt(): string[] {
|
||||
return this.appendSystemPrompt;
|
||||
}
|
||||
|
||||
getPathMetadata(): Map<string, PathMetadata> {
|
||||
return this.pathMetadata;
|
||||
}
|
||||
|
||||
extendResources(paths: ResourceExtensionPaths): void {
|
||||
const skillPaths = this.normalizeExtensionPaths(paths.skillPaths ?? []);
|
||||
const promptPaths = this.normalizeExtensionPaths(paths.promptPaths ?? []);
|
||||
const themePaths = this.normalizeExtensionPaths(paths.themePaths ?? []);
|
||||
|
||||
if (skillPaths.length > 0) {
|
||||
this.lastSkillPaths = this.mergePaths(
|
||||
this.lastSkillPaths,
|
||||
skillPaths.map((entry) => entry.path),
|
||||
);
|
||||
this.updateSkillsFromPaths(this.lastSkillPaths, skillPaths);
|
||||
}
|
||||
|
||||
if (promptPaths.length > 0) {
|
||||
this.lastPromptPaths = this.mergePaths(
|
||||
this.lastPromptPaths,
|
||||
promptPaths.map((entry) => entry.path),
|
||||
);
|
||||
this.updatePromptsFromPaths(this.lastPromptPaths, promptPaths);
|
||||
}
|
||||
|
||||
if (themePaths.length > 0) {
|
||||
this.lastThemePaths = this.mergePaths(
|
||||
this.lastThemePaths,
|
||||
themePaths.map((entry) => entry.path),
|
||||
);
|
||||
this.updateThemesFromPaths(this.lastThemePaths, themePaths);
|
||||
}
|
||||
}
|
||||
|
||||
async reload(): Promise<void> {
|
||||
const resolvedPaths = await this.packageManager.resolve();
|
||||
const cliExtensionPaths = await this.packageManager.resolveExtensionSources(this.additionalExtensionPaths, {
|
||||
temporary: true,
|
||||
});
|
||||
|
||||
// Helper to extract enabled paths and store metadata
|
||||
const getEnabledResources = (
|
||||
resources: Array<{ path: string; enabled: boolean; metadata: PathMetadata }>,
|
||||
): Array<{ path: string; enabled: boolean; metadata: PathMetadata }> => {
|
||||
for (const r of resources) {
|
||||
if (!this.pathMetadata.has(r.path)) {
|
||||
this.pathMetadata.set(r.path, r.metadata);
|
||||
}
|
||||
}
|
||||
return resources.filter((r) => r.enabled);
|
||||
};
|
||||
|
||||
const getEnabledPaths = (
|
||||
resources: Array<{ path: string; enabled: boolean; metadata: PathMetadata }>,
|
||||
): string[] => getEnabledResources(resources).map((r) => r.path);
|
||||
|
||||
// Store metadata and get enabled paths
|
||||
this.pathMetadata = new Map();
|
||||
const enabledExtensions = getEnabledPaths(resolvedPaths.extensions);
|
||||
const enabledSkillResources = getEnabledResources(resolvedPaths.skills);
|
||||
const enabledPrompts = getEnabledPaths(resolvedPaths.prompts);
|
||||
const enabledThemes = getEnabledPaths(resolvedPaths.themes);
|
||||
|
||||
const mapSkillPath = (resource: { path: string; metadata: PathMetadata }): string => {
|
||||
if (resource.metadata.source !== "auto" && resource.metadata.origin !== "package") {
|
||||
return resource.path;
|
||||
}
|
||||
try {
|
||||
const stats = statSync(resource.path);
|
||||
if (!stats.isDirectory()) {
|
||||
return resource.path;
|
||||
}
|
||||
} catch {
|
||||
return resource.path;
|
||||
}
|
||||
const skillFile = join(resource.path, "SKILL.md");
|
||||
if (existsSync(skillFile)) {
|
||||
if (!this.pathMetadata.has(skillFile)) {
|
||||
this.pathMetadata.set(skillFile, resource.metadata);
|
||||
}
|
||||
return skillFile;
|
||||
}
|
||||
return resource.path;
|
||||
};
|
||||
|
||||
const enabledSkills = enabledSkillResources.map(mapSkillPath);
|
||||
|
||||
// Add CLI paths metadata
|
||||
for (const r of cliExtensionPaths.extensions) {
|
||||
if (!this.pathMetadata.has(r.path)) {
|
||||
this.pathMetadata.set(r.path, { source: "cli", scope: "temporary", origin: "top-level" });
|
||||
}
|
||||
}
|
||||
for (const r of cliExtensionPaths.skills) {
|
||||
if (!this.pathMetadata.has(r.path)) {
|
||||
this.pathMetadata.set(r.path, { source: "cli", scope: "temporary", origin: "top-level" });
|
||||
}
|
||||
}
|
||||
|
||||
const cliEnabledExtensions = getEnabledPaths(cliExtensionPaths.extensions);
|
||||
const cliEnabledSkills = getEnabledPaths(cliExtensionPaths.skills);
|
||||
const cliEnabledPrompts = getEnabledPaths(cliExtensionPaths.prompts);
|
||||
const cliEnabledThemes = getEnabledPaths(cliExtensionPaths.themes);
|
||||
|
||||
const extensionPaths = this.noExtensions
|
||||
? cliEnabledExtensions
|
||||
: this.mergePaths(cliEnabledExtensions, enabledExtensions);
|
||||
|
||||
const extensionsResult = await loadExtensions(extensionPaths, this.cwd, this.eventBus);
|
||||
const inlineExtensions = await this.loadExtensionFactories(extensionsResult.runtime);
|
||||
extensionsResult.extensions.push(...inlineExtensions.extensions);
|
||||
extensionsResult.errors.push(...inlineExtensions.errors);
|
||||
|
||||
// Detect extension conflicts (tools, commands, flags with same names from different extensions)
|
||||
// Keep all extensions loaded. Conflicts are reported as diagnostics, and precedence is handled by load order.
|
||||
const conflicts = this.detectExtensionConflicts(extensionsResult.extensions);
|
||||
for (const conflict of conflicts) {
|
||||
extensionsResult.errors.push({ path: conflict.path, error: conflict.message });
|
||||
}
|
||||
|
||||
this.extensionsResult = this.extensionsOverride ? this.extensionsOverride(extensionsResult) : extensionsResult;
|
||||
|
||||
const skillPaths = this.noSkills
|
||||
? this.mergePaths(cliEnabledSkills, this.additionalSkillPaths)
|
||||
: this.mergePaths([...enabledSkills, ...cliEnabledSkills], this.additionalSkillPaths);
|
||||
|
||||
this.lastSkillPaths = skillPaths;
|
||||
this.updateSkillsFromPaths(skillPaths);
|
||||
|
||||
const promptPaths = this.noPromptTemplates
|
||||
? this.mergePaths(cliEnabledPrompts, this.additionalPromptTemplatePaths)
|
||||
: this.mergePaths([...enabledPrompts, ...cliEnabledPrompts], this.additionalPromptTemplatePaths);
|
||||
|
||||
this.lastPromptPaths = promptPaths;
|
||||
this.updatePromptsFromPaths(promptPaths);
|
||||
|
||||
const themePaths = this.noThemes
|
||||
? this.mergePaths(cliEnabledThemes, this.additionalThemePaths)
|
||||
: this.mergePaths([...enabledThemes, ...cliEnabledThemes], this.additionalThemePaths);
|
||||
|
||||
this.lastThemePaths = themePaths;
|
||||
this.updateThemesFromPaths(themePaths);
|
||||
|
||||
for (const extension of this.extensionsResult.extensions) {
|
||||
this.addDefaultMetadataForPath(extension.path);
|
||||
}
|
||||
|
||||
const agentsFiles = { agentsFiles: loadProjectContextFiles({ cwd: this.cwd, agentDir: this.agentDir }) };
|
||||
const resolvedAgentsFiles = this.agentsFilesOverride ? this.agentsFilesOverride(agentsFiles) : agentsFiles;
|
||||
this.agentsFiles = resolvedAgentsFiles.agentsFiles;
|
||||
|
||||
const baseSystemPrompt = resolvePromptInput(
|
||||
this.systemPromptSource ?? this.discoverSystemPromptFile(),
|
||||
"system prompt",
|
||||
);
|
||||
this.systemPrompt = this.systemPromptOverride ? this.systemPromptOverride(baseSystemPrompt) : baseSystemPrompt;
|
||||
|
||||
const appendSource = this.appendSystemPromptSource ?? this.discoverAppendSystemPromptFile();
|
||||
const resolvedAppend = resolvePromptInput(appendSource, "append system prompt");
|
||||
const baseAppend = resolvedAppend ? [resolvedAppend] : [];
|
||||
this.appendSystemPrompt = this.appendSystemPromptOverride
|
||||
? this.appendSystemPromptOverride(baseAppend)
|
||||
: baseAppend;
|
||||
}
|
||||
|
||||
private normalizeExtensionPaths(
|
||||
entries: Array<{ path: string; metadata: PathMetadata }>,
|
||||
): Array<{ path: string; metadata: PathMetadata }> {
|
||||
return entries.map((entry) => ({
|
||||
path: this.resolveResourcePath(entry.path),
|
||||
metadata: entry.metadata,
|
||||
}));
|
||||
}
|
||||
|
||||
private updateSkillsFromPaths(
|
||||
skillPaths: string[],
|
||||
extensionPaths: Array<{ path: string; metadata: PathMetadata }> = [],
|
||||
): void {
|
||||
let skillsResult: { skills: Skill[]; diagnostics: ResourceDiagnostic[] };
|
||||
if (this.noSkills && skillPaths.length === 0) {
|
||||
skillsResult = { skills: [], diagnostics: [] };
|
||||
} else {
|
||||
skillsResult = loadSkills({
|
||||
cwd: this.cwd,
|
||||
agentDir: this.agentDir,
|
||||
skillPaths,
|
||||
includeDefaults: false,
|
||||
});
|
||||
}
|
||||
const resolvedSkills = this.skillsOverride ? this.skillsOverride(skillsResult) : skillsResult;
|
||||
this.skills = resolvedSkills.skills;
|
||||
this.skillDiagnostics = resolvedSkills.diagnostics;
|
||||
this.applyExtensionMetadata(
|
||||
extensionPaths,
|
||||
this.skills.map((skill) => skill.filePath),
|
||||
);
|
||||
for (const skill of this.skills) {
|
||||
this.addDefaultMetadataForPath(skill.filePath);
|
||||
}
|
||||
}
|
||||
|
||||
private updatePromptsFromPaths(
|
||||
promptPaths: string[],
|
||||
extensionPaths: Array<{ path: string; metadata: PathMetadata }> = [],
|
||||
): void {
|
||||
let promptsResult: { prompts: PromptTemplate[]; diagnostics: ResourceDiagnostic[] };
|
||||
if (this.noPromptTemplates && promptPaths.length === 0) {
|
||||
promptsResult = { prompts: [], diagnostics: [] };
|
||||
} else {
|
||||
const allPrompts = loadPromptTemplates({
|
||||
cwd: this.cwd,
|
||||
agentDir: this.agentDir,
|
||||
promptPaths,
|
||||
includeDefaults: false,
|
||||
});
|
||||
promptsResult = this.dedupePrompts(allPrompts);
|
||||
}
|
||||
const resolvedPrompts = this.promptsOverride ? this.promptsOverride(promptsResult) : promptsResult;
|
||||
this.prompts = resolvedPrompts.prompts;
|
||||
this.promptDiagnostics = resolvedPrompts.diagnostics;
|
||||
this.applyExtensionMetadata(
|
||||
extensionPaths,
|
||||
this.prompts.map((prompt) => prompt.filePath),
|
||||
);
|
||||
for (const prompt of this.prompts) {
|
||||
this.addDefaultMetadataForPath(prompt.filePath);
|
||||
}
|
||||
}
|
||||
|
||||
private updateThemesFromPaths(
|
||||
themePaths: string[],
|
||||
extensionPaths: Array<{ path: string; metadata: PathMetadata }> = [],
|
||||
): void {
|
||||
let themesResult: { themes: Theme[]; diagnostics: ResourceDiagnostic[] };
|
||||
if (this.noThemes && themePaths.length === 0) {
|
||||
themesResult = { themes: [], diagnostics: [] };
|
||||
} else {
|
||||
const loaded = this.loadThemes(themePaths, false);
|
||||
const deduped = this.dedupeThemes(loaded.themes);
|
||||
themesResult = { themes: deduped.themes, diagnostics: [...loaded.diagnostics, ...deduped.diagnostics] };
|
||||
}
|
||||
const resolvedThemes = this.themesOverride ? this.themesOverride(themesResult) : themesResult;
|
||||
this.themes = resolvedThemes.themes;
|
||||
this.themeDiagnostics = resolvedThemes.diagnostics;
|
||||
const themePathsWithSource = this.themes.flatMap((theme) => (theme.sourcePath ? [theme.sourcePath] : []));
|
||||
this.applyExtensionMetadata(extensionPaths, themePathsWithSource);
|
||||
for (const theme of this.themes) {
|
||||
if (theme.sourcePath) {
|
||||
this.addDefaultMetadataForPath(theme.sourcePath);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private applyExtensionMetadata(
|
||||
extensionPaths: Array<{ path: string; metadata: PathMetadata }>,
|
||||
resourcePaths: string[],
|
||||
): void {
|
||||
if (extensionPaths.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
const normalized = extensionPaths.map((entry) => ({
|
||||
path: resolve(entry.path),
|
||||
metadata: entry.metadata,
|
||||
}));
|
||||
|
||||
for (const entry of normalized) {
|
||||
if (!this.pathMetadata.has(entry.path)) {
|
||||
this.pathMetadata.set(entry.path, entry.metadata);
|
||||
}
|
||||
}
|
||||
|
||||
for (const resourcePath of resourcePaths) {
|
||||
const normalizedResourcePath = resolve(resourcePath);
|
||||
if (this.pathMetadata.has(normalizedResourcePath) || this.pathMetadata.has(resourcePath)) {
|
||||
continue;
|
||||
}
|
||||
const match = normalized.find(
|
||||
(entry) =>
|
||||
normalizedResourcePath === entry.path || normalizedResourcePath.startsWith(`${entry.path}${sep}`),
|
||||
);
|
||||
if (match) {
|
||||
this.pathMetadata.set(normalizedResourcePath, match.metadata);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private mergePaths(primary: string[], additional: string[]): string[] {
|
||||
const merged: string[] = [];
|
||||
const seen = new Set<string>();
|
||||
|
||||
for (const p of [...primary, ...additional]) {
|
||||
const resolved = this.resolveResourcePath(p);
|
||||
if (seen.has(resolved)) continue;
|
||||
seen.add(resolved);
|
||||
merged.push(resolved);
|
||||
}
|
||||
|
||||
return merged;
|
||||
}
|
||||
|
||||
private resolveResourcePath(p: string): string {
|
||||
const trimmed = p.trim();
|
||||
let expanded = trimmed;
|
||||
if (trimmed === "~") {
|
||||
expanded = homedir();
|
||||
} else if (trimmed.startsWith("~/")) {
|
||||
expanded = join(homedir(), trimmed.slice(2));
|
||||
} else if (trimmed.startsWith("~")) {
|
||||
expanded = join(homedir(), trimmed.slice(1));
|
||||
}
|
||||
return resolve(this.cwd, expanded);
|
||||
}
|
||||
|
||||
private loadThemes(
|
||||
paths: string[],
|
||||
includeDefaults: boolean = true,
|
||||
): {
|
||||
themes: Theme[];
|
||||
diagnostics: ResourceDiagnostic[];
|
||||
} {
|
||||
const themes: Theme[] = [];
|
||||
const diagnostics: ResourceDiagnostic[] = [];
|
||||
if (includeDefaults) {
|
||||
const defaultDirs = [join(this.agentDir, "themes"), join(this.cwd, CONFIG_DIR_NAME, "themes")];
|
||||
|
||||
for (const dir of defaultDirs) {
|
||||
this.loadThemesFromDir(dir, themes, diagnostics);
|
||||
}
|
||||
}
|
||||
|
||||
for (const p of paths) {
|
||||
const resolved = resolve(this.cwd, p);
|
||||
if (!existsSync(resolved)) {
|
||||
diagnostics.push({ type: "warning", message: "theme path does not exist", path: resolved });
|
||||
continue;
|
||||
}
|
||||
|
||||
try {
|
||||
const stats = statSync(resolved);
|
||||
if (stats.isDirectory()) {
|
||||
this.loadThemesFromDir(resolved, themes, diagnostics);
|
||||
} else if (stats.isFile() && resolved.endsWith(".json")) {
|
||||
this.loadThemeFromFile(resolved, themes, diagnostics);
|
||||
} else {
|
||||
diagnostics.push({ type: "warning", message: "theme path is not a json file", path: resolved });
|
||||
}
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : "failed to read theme path";
|
||||
diagnostics.push({ type: "warning", message, path: resolved });
|
||||
}
|
||||
}
|
||||
|
||||
return { themes, diagnostics };
|
||||
}
|
||||
|
||||
private loadThemesFromDir(dir: string, themes: Theme[], diagnostics: ResourceDiagnostic[]): void {
|
||||
if (!existsSync(dir)) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const entries = readdirSync(dir, { withFileTypes: true });
|
||||
for (const entry of entries) {
|
||||
let isFile = entry.isFile();
|
||||
if (entry.isSymbolicLink()) {
|
||||
try {
|
||||
isFile = statSync(join(dir, entry.name)).isFile();
|
||||
} catch {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if (!isFile) {
|
||||
continue;
|
||||
}
|
||||
if (!entry.name.endsWith(".json")) {
|
||||
continue;
|
||||
}
|
||||
this.loadThemeFromFile(join(dir, entry.name), themes, diagnostics);
|
||||
}
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : "failed to read theme directory";
|
||||
diagnostics.push({ type: "warning", message, path: dir });
|
||||
}
|
||||
}
|
||||
|
||||
private loadThemeFromFile(filePath: string, themes: Theme[], diagnostics: ResourceDiagnostic[]): void {
|
||||
try {
|
||||
themes.push(loadThemeFromPath(filePath));
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : "failed to load theme";
|
||||
diagnostics.push({ type: "warning", message, path: filePath });
|
||||
}
|
||||
}
|
||||
|
||||
private async loadExtensionFactories(runtime: ExtensionRuntime): Promise<{
|
||||
extensions: Extension[];
|
||||
errors: Array<{ path: string; error: string }>;
|
||||
}> {
|
||||
const extensions: Extension[] = [];
|
||||
const errors: Array<{ path: string; error: string }> = [];
|
||||
|
||||
for (const [index, factory] of this.extensionFactories.entries()) {
|
||||
const extensionPath = `<inline:${index + 1}>`;
|
||||
try {
|
||||
const extension = await loadExtensionFromFactory(factory, this.cwd, this.eventBus, runtime, extensionPath);
|
||||
extensions.push(extension);
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : "failed to load extension";
|
||||
errors.push({ path: extensionPath, error: message });
|
||||
}
|
||||
}
|
||||
|
||||
return { extensions, errors };
|
||||
}
|
||||
|
||||
private dedupePrompts(prompts: PromptTemplate[]): { prompts: PromptTemplate[]; diagnostics: ResourceDiagnostic[] } {
|
||||
const seen = new Map<string, PromptTemplate>();
|
||||
const diagnostics: ResourceDiagnostic[] = [];
|
||||
|
||||
for (const prompt of prompts) {
|
||||
const existing = seen.get(prompt.name);
|
||||
if (existing) {
|
||||
diagnostics.push({
|
||||
type: "collision",
|
||||
message: `name "/${prompt.name}" collision`,
|
||||
path: prompt.filePath,
|
||||
collision: {
|
||||
resourceType: "prompt",
|
||||
name: prompt.name,
|
||||
winnerPath: existing.filePath,
|
||||
loserPath: prompt.filePath,
|
||||
},
|
||||
});
|
||||
} else {
|
||||
seen.set(prompt.name, prompt);
|
||||
}
|
||||
}
|
||||
|
||||
return { prompts: Array.from(seen.values()), diagnostics };
|
||||
}
|
||||
|
||||
private dedupeThemes(themes: Theme[]): { themes: Theme[]; diagnostics: ResourceDiagnostic[] } {
|
||||
const seen = new Map<string, Theme>();
|
||||
const diagnostics: ResourceDiagnostic[] = [];
|
||||
|
||||
for (const t of themes) {
|
||||
const name = t.name ?? "unnamed";
|
||||
const existing = seen.get(name);
|
||||
if (existing) {
|
||||
diagnostics.push({
|
||||
type: "collision",
|
||||
message: `name "${name}" collision`,
|
||||
path: t.sourcePath,
|
||||
collision: {
|
||||
resourceType: "theme",
|
||||
name,
|
||||
winnerPath: existing.sourcePath ?? "<builtin>",
|
||||
loserPath: t.sourcePath ?? "<builtin>",
|
||||
},
|
||||
});
|
||||
} else {
|
||||
seen.set(name, t);
|
||||
}
|
||||
}
|
||||
|
||||
return { themes: Array.from(seen.values()), diagnostics };
|
||||
}
|
||||
|
||||
private discoverSystemPromptFile(): string | undefined {
|
||||
const projectPath = join(this.cwd, CONFIG_DIR_NAME, "SYSTEM.md");
|
||||
if (existsSync(projectPath)) {
|
||||
return projectPath;
|
||||
}
|
||||
|
||||
const globalPath = join(this.agentDir, "SYSTEM.md");
|
||||
if (existsSync(globalPath)) {
|
||||
return globalPath;
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
private discoverAppendSystemPromptFile(): string | undefined {
|
||||
const projectPath = join(this.cwd, CONFIG_DIR_NAME, "APPEND_SYSTEM.md");
|
||||
if (existsSync(projectPath)) {
|
||||
return projectPath;
|
||||
}
|
||||
|
||||
const globalPath = join(this.agentDir, "APPEND_SYSTEM.md");
|
||||
if (existsSync(globalPath)) {
|
||||
return globalPath;
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
private addDefaultMetadataForPath(filePath: string): void {
|
||||
if (!filePath || filePath.startsWith("<")) {
|
||||
return;
|
||||
}
|
||||
|
||||
const normalizedPath = resolve(filePath);
|
||||
if (this.pathMetadata.has(normalizedPath) || this.pathMetadata.has(filePath)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const agentRoots = [
|
||||
join(this.agentDir, "skills"),
|
||||
join(this.agentDir, "prompts"),
|
||||
join(this.agentDir, "themes"),
|
||||
join(this.agentDir, "extensions"),
|
||||
];
|
||||
const projectRoots = [
|
||||
join(this.cwd, CONFIG_DIR_NAME, "skills"),
|
||||
join(this.cwd, CONFIG_DIR_NAME, "prompts"),
|
||||
join(this.cwd, CONFIG_DIR_NAME, "themes"),
|
||||
join(this.cwd, CONFIG_DIR_NAME, "extensions"),
|
||||
];
|
||||
|
||||
for (const root of agentRoots) {
|
||||
if (this.isUnderPath(normalizedPath, root)) {
|
||||
this.pathMetadata.set(normalizedPath, { source: "local", scope: "user", origin: "top-level" });
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
for (const root of projectRoots) {
|
||||
if (this.isUnderPath(normalizedPath, root)) {
|
||||
this.pathMetadata.set(normalizedPath, { source: "local", scope: "project", origin: "top-level" });
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private isUnderPath(target: string, root: string): boolean {
|
||||
const normalizedRoot = resolve(root);
|
||||
if (target === normalizedRoot) {
|
||||
return true;
|
||||
}
|
||||
const prefix = normalizedRoot.endsWith(sep) ? normalizedRoot : `${normalizedRoot}${sep}`;
|
||||
return target.startsWith(prefix);
|
||||
}
|
||||
|
||||
private detectExtensionConflicts(extensions: Extension[]): Array<{ path: string; message: string }> {
|
||||
const conflicts: Array<{ path: string; message: string }> = [];
|
||||
|
||||
// Track which extension registered each tool, command, and flag
|
||||
const toolOwners = new Map<string, string>();
|
||||
const commandOwners = new Map<string, string>();
|
||||
const flagOwners = new Map<string, string>();
|
||||
|
||||
for (const ext of extensions) {
|
||||
// Check tools
|
||||
for (const toolName of ext.tools.keys()) {
|
||||
const existingOwner = toolOwners.get(toolName);
|
||||
if (existingOwner && existingOwner !== ext.path) {
|
||||
conflicts.push({
|
||||
path: ext.path,
|
||||
message: `Tool "${toolName}" conflicts with ${existingOwner}`,
|
||||
});
|
||||
} else {
|
||||
toolOwners.set(toolName, ext.path);
|
||||
}
|
||||
}
|
||||
|
||||
// Check commands
|
||||
for (const commandName of ext.commands.keys()) {
|
||||
const existingOwner = commandOwners.get(commandName);
|
||||
if (existingOwner && existingOwner !== ext.path) {
|
||||
conflicts.push({
|
||||
path: ext.path,
|
||||
message: `Command "/${commandName}" conflicts with ${existingOwner}`,
|
||||
});
|
||||
} else {
|
||||
commandOwners.set(commandName, ext.path);
|
||||
}
|
||||
}
|
||||
|
||||
// Check flags
|
||||
for (const flagName of ext.flags.keys()) {
|
||||
const existingOwner = flagOwners.get(flagName);
|
||||
if (existingOwner && existingOwner !== ext.path) {
|
||||
conflicts.push({
|
||||
path: ext.path,
|
||||
message: `Flag "--${flagName}" conflicts with ${existingOwner}`,
|
||||
});
|
||||
} else {
|
||||
flagOwners.set(flagName, ext.path);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return conflicts;
|
||||
}
|
||||
}
|
||||
373
packages/pi-coding-agent/src/core/sdk.ts
Normal file
373
packages/pi-coding-agent/src/core/sdk.ts
Normal file
|
|
@ -0,0 +1,373 @@
|
|||
import { join } from "node:path";
|
||||
import { Agent, type AgentMessage, type ThinkingLevel } from "@gsd/pi-agent-core";
|
||||
import type { Message, Model } from "@gsd/pi-ai";
|
||||
import { getAgentDir, getDocsPath } from "../config.js";
|
||||
import { AgentSession } from "./agent-session.js";
|
||||
import { AuthStorage } from "./auth-storage.js";
|
||||
import { DEFAULT_THINKING_LEVEL } from "./defaults.js";
|
||||
import type { ExtensionRunner, LoadExtensionsResult, ToolDefinition } from "./extensions/index.js";
|
||||
import { convertToLlm } from "./messages.js";
|
||||
import { ModelRegistry } from "./model-registry.js";
|
||||
import { findInitialModel } from "./model-resolver.js";
|
||||
import type { ResourceLoader } from "./resource-loader.js";
|
||||
import { DefaultResourceLoader } from "./resource-loader.js";
|
||||
import { SessionManager } from "./session-manager.js";
|
||||
import { SettingsManager } from "./settings-manager.js";
|
||||
import { time } from "./timings.js";
|
||||
import {
|
||||
allTools,
|
||||
bashTool,
|
||||
codingTools,
|
||||
createBashTool,
|
||||
createCodingTools,
|
||||
createEditTool,
|
||||
createFindTool,
|
||||
createGrepTool,
|
||||
createLsTool,
|
||||
createReadOnlyTools,
|
||||
createReadTool,
|
||||
createWriteTool,
|
||||
editTool,
|
||||
findTool,
|
||||
grepTool,
|
||||
lsTool,
|
||||
readOnlyTools,
|
||||
readTool,
|
||||
type Tool,
|
||||
type ToolName,
|
||||
writeTool,
|
||||
} from "./tools/index.js";
|
||||
|
||||
export interface CreateAgentSessionOptions {
|
||||
/** Working directory for project-local discovery. Default: process.cwd() */
|
||||
cwd?: string;
|
||||
/** Global config directory. Default: ~/.pi/agent */
|
||||
agentDir?: string;
|
||||
|
||||
/** Auth storage for credentials. Default: AuthStorage.create(agentDir/auth.json) */
|
||||
authStorage?: AuthStorage;
|
||||
/** Model registry. Default: new ModelRegistry(authStorage, agentDir/models.json) */
|
||||
modelRegistry?: ModelRegistry;
|
||||
|
||||
/** Model to use. Default: from settings, else first available */
|
||||
model?: Model<any>;
|
||||
/** Thinking level. Default: from settings, else 'medium' (clamped to model capabilities) */
|
||||
thinkingLevel?: ThinkingLevel;
|
||||
/** Models available for cycling (Ctrl+P in interactive mode) */
|
||||
scopedModels?: Array<{ model: Model<any>; thinkingLevel?: ThinkingLevel }>;
|
||||
|
||||
/** Built-in tools to use. Default: codingTools [read, bash, edit, write] */
|
||||
tools?: Tool[];
|
||||
/** Custom tools to register (in addition to built-in tools). */
|
||||
customTools?: ToolDefinition[];
|
||||
|
||||
/** Resource loader. When omitted, DefaultResourceLoader is used. */
|
||||
resourceLoader?: ResourceLoader;
|
||||
|
||||
/** Session manager. Default: SessionManager.create(cwd) */
|
||||
sessionManager?: SessionManager;
|
||||
|
||||
/** Settings manager. Default: SettingsManager.create(cwd, agentDir) */
|
||||
settingsManager?: SettingsManager;
|
||||
}
|
||||
|
||||
/** Result from createAgentSession */
|
||||
export interface CreateAgentSessionResult {
|
||||
/** The created session */
|
||||
session: AgentSession;
|
||||
/** Extensions result (for UI context setup in interactive mode) */
|
||||
extensionsResult: LoadExtensionsResult;
|
||||
/** Warning if session was restored with a different model than saved */
|
||||
modelFallbackMessage?: string;
|
||||
}
|
||||
|
||||
// Re-exports
|
||||
|
||||
export type {
|
||||
ExtensionAPI,
|
||||
ExtensionCommandContext,
|
||||
ExtensionContext,
|
||||
ExtensionFactory,
|
||||
SlashCommandInfo,
|
||||
SlashCommandLocation,
|
||||
SlashCommandSource,
|
||||
ToolDefinition,
|
||||
} from "./extensions/index.js";
|
||||
export type { PromptTemplate } from "./prompt-templates.js";
|
||||
export type { Skill } from "./skills.js";
|
||||
export type { Tool } from "./tools/index.js";
|
||||
|
||||
export {
|
||||
// Pre-built tools (use process.cwd())
|
||||
readTool,
|
||||
bashTool,
|
||||
editTool,
|
||||
writeTool,
|
||||
grepTool,
|
||||
findTool,
|
||||
lsTool,
|
||||
codingTools,
|
||||
readOnlyTools,
|
||||
allTools as allBuiltInTools,
|
||||
// Tool factories (for custom cwd)
|
||||
createCodingTools,
|
||||
createReadOnlyTools,
|
||||
createReadTool,
|
||||
createBashTool,
|
||||
createEditTool,
|
||||
createWriteTool,
|
||||
createGrepTool,
|
||||
createFindTool,
|
||||
createLsTool,
|
||||
};
|
||||
|
||||
// Helper Functions
|
||||
|
||||
function getDefaultAgentDir(): string {
|
||||
return getAgentDir();
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an AgentSession with the specified options.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* // Minimal - uses defaults
|
||||
* const { session } = await createAgentSession();
|
||||
*
|
||||
* // With explicit model
|
||||
* import { getModel } from '@gsd/pi-ai';
|
||||
* const { session } = await createAgentSession({
|
||||
* model: getModel('anthropic', 'claude-opus-4-5'),
|
||||
* thinkingLevel: 'high',
|
||||
* });
|
||||
*
|
||||
* // Continue previous session
|
||||
* const { session, modelFallbackMessage } = await createAgentSession({
|
||||
* continueSession: true,
|
||||
* });
|
||||
*
|
||||
* // Full control
|
||||
* const loader = new DefaultResourceLoader({
|
||||
* cwd: process.cwd(),
|
||||
* agentDir: getAgentDir(),
|
||||
* settingsManager: SettingsManager.create(),
|
||||
* });
|
||||
* await loader.reload();
|
||||
* const { session } = await createAgentSession({
|
||||
* model: myModel,
|
||||
* tools: [readTool, bashTool],
|
||||
* resourceLoader: loader,
|
||||
* sessionManager: SessionManager.inMemory(),
|
||||
* });
|
||||
* ```
|
||||
*/
|
||||
export async function createAgentSession(options: CreateAgentSessionOptions = {}): Promise<CreateAgentSessionResult> {
|
||||
const cwd = options.cwd ?? process.cwd();
|
||||
const agentDir = options.agentDir ?? getDefaultAgentDir();
|
||||
let resourceLoader = options.resourceLoader;
|
||||
|
||||
// Use provided or create AuthStorage and ModelRegistry
|
||||
const authPath = options.agentDir ? join(agentDir, "auth.json") : undefined;
|
||||
const modelsPath = options.agentDir ? join(agentDir, "models.json") : undefined;
|
||||
const authStorage = options.authStorage ?? AuthStorage.create(authPath);
|
||||
const modelRegistry = options.modelRegistry ?? new ModelRegistry(authStorage, modelsPath);
|
||||
|
||||
const settingsManager = options.settingsManager ?? SettingsManager.create(cwd, agentDir);
|
||||
const sessionManager = options.sessionManager ?? SessionManager.create(cwd);
|
||||
|
||||
if (!resourceLoader) {
|
||||
resourceLoader = new DefaultResourceLoader({ cwd, agentDir, settingsManager });
|
||||
await resourceLoader.reload();
|
||||
time("resourceLoader.reload");
|
||||
}
|
||||
|
||||
// Check if session has existing data to restore
|
||||
const existingSession = sessionManager.buildSessionContext();
|
||||
const hasExistingSession = existingSession.messages.length > 0;
|
||||
const hasThinkingEntry = sessionManager.getBranch().some((entry) => entry.type === "thinking_level_change");
|
||||
|
||||
let model = options.model;
|
||||
let modelFallbackMessage: string | undefined;
|
||||
|
||||
// If session has data, try to restore model from it
|
||||
if (!model && hasExistingSession && existingSession.model) {
|
||||
const restoredModel = modelRegistry.find(existingSession.model.provider, existingSession.model.modelId);
|
||||
if (restoredModel && (await modelRegistry.getApiKey(restoredModel))) {
|
||||
model = restoredModel;
|
||||
}
|
||||
if (!model) {
|
||||
modelFallbackMessage = `Could not restore model ${existingSession.model.provider}/${existingSession.model.modelId}`;
|
||||
}
|
||||
}
|
||||
|
||||
// If still no model, use findInitialModel (checks settings default, then provider defaults)
|
||||
if (!model) {
|
||||
const result = await findInitialModel({
|
||||
scopedModels: [],
|
||||
isContinuing: hasExistingSession,
|
||||
defaultProvider: settingsManager.getDefaultProvider(),
|
||||
defaultModelId: settingsManager.getDefaultModel(),
|
||||
defaultThinkingLevel: settingsManager.getDefaultThinkingLevel(),
|
||||
modelRegistry,
|
||||
});
|
||||
model = result.model;
|
||||
if (!model) {
|
||||
modelFallbackMessage = `No models available. Use /login or set an API key environment variable. See ${join(getDocsPath(), "providers.md")}. Then use /model to select a model.`;
|
||||
} else if (modelFallbackMessage) {
|
||||
modelFallbackMessage += `. Using ${model.provider}/${model.id}`;
|
||||
}
|
||||
}
|
||||
|
||||
let thinkingLevel = options.thinkingLevel;
|
||||
|
||||
// If session has data, restore thinking level from it
|
||||
if (thinkingLevel === undefined && hasExistingSession) {
|
||||
thinkingLevel = hasThinkingEntry
|
||||
? (existingSession.thinkingLevel as ThinkingLevel)
|
||||
: (settingsManager.getDefaultThinkingLevel() ?? DEFAULT_THINKING_LEVEL);
|
||||
}
|
||||
|
||||
// Fall back to settings default
|
||||
if (thinkingLevel === undefined) {
|
||||
thinkingLevel = settingsManager.getDefaultThinkingLevel() ?? DEFAULT_THINKING_LEVEL;
|
||||
}
|
||||
|
||||
// Clamp to model capabilities
|
||||
if (!model || !model.reasoning) {
|
||||
thinkingLevel = "off";
|
||||
}
|
||||
|
||||
const defaultActiveToolNames: ToolName[] = ["read", "bash", "edit", "write"];
|
||||
const initialActiveToolNames: ToolName[] = options.tools
|
||||
? options.tools.map((t) => t.name).filter((n): n is ToolName => n in allTools)
|
||||
: defaultActiveToolNames;
|
||||
|
||||
let agent: Agent;
|
||||
|
||||
// Create convertToLlm wrapper that filters images if blockImages is enabled (defense-in-depth)
|
||||
const convertToLlmWithBlockImages = (messages: AgentMessage[]): Message[] => {
|
||||
const converted = convertToLlm(messages);
|
||||
// Check setting dynamically so mid-session changes take effect
|
||||
if (!settingsManager.getBlockImages()) {
|
||||
return converted;
|
||||
}
|
||||
// Filter out ImageContent from all messages, replacing with text placeholder
|
||||
return converted.map((msg) => {
|
||||
if (msg.role === "user" || msg.role === "toolResult") {
|
||||
const content = msg.content;
|
||||
if (Array.isArray(content)) {
|
||||
const hasImages = content.some((c) => c.type === "image");
|
||||
if (hasImages) {
|
||||
const filteredContent = content
|
||||
.map((c) =>
|
||||
c.type === "image" ? { type: "text" as const, text: "Image reading is disabled." } : c,
|
||||
)
|
||||
.filter(
|
||||
(c, i, arr) =>
|
||||
// Dedupe consecutive "Image reading is disabled." texts
|
||||
!(
|
||||
c.type === "text" &&
|
||||
c.text === "Image reading is disabled." &&
|
||||
i > 0 &&
|
||||
arr[i - 1].type === "text" &&
|
||||
(arr[i - 1] as { type: "text"; text: string }).text === "Image reading is disabled."
|
||||
),
|
||||
);
|
||||
return { ...msg, content: filteredContent };
|
||||
}
|
||||
}
|
||||
}
|
||||
return msg;
|
||||
});
|
||||
};
|
||||
|
||||
const extensionRunnerRef: { current?: ExtensionRunner } = {};
|
||||
|
||||
agent = new Agent({
|
||||
initialState: {
|
||||
systemPrompt: "",
|
||||
model,
|
||||
thinkingLevel,
|
||||
tools: [],
|
||||
},
|
||||
convertToLlm: convertToLlmWithBlockImages,
|
||||
onPayload: async (payload, _model) => {
|
||||
const runner = extensionRunnerRef.current;
|
||||
if (!runner?.hasHandlers("before_provider_request")) {
|
||||
return payload;
|
||||
}
|
||||
return runner.emitBeforeProviderRequest(payload);
|
||||
},
|
||||
sessionId: sessionManager.getSessionId(),
|
||||
transformContext: async (messages) => {
|
||||
const runner = extensionRunnerRef.current;
|
||||
if (!runner) return messages;
|
||||
return runner.emitContext(messages);
|
||||
},
|
||||
steeringMode: settingsManager.getSteeringMode(),
|
||||
followUpMode: settingsManager.getFollowUpMode(),
|
||||
transport: settingsManager.getTransport(),
|
||||
thinkingBudgets: settingsManager.getThinkingBudgets(),
|
||||
maxRetryDelayMs: settingsManager.getRetrySettings().maxDelayMs,
|
||||
getApiKey: async (provider) => {
|
||||
// Use the provider argument from the in-flight request;
|
||||
// agent.state.model may already be switched mid-turn.
|
||||
const resolvedProvider = provider || agent.state.model?.provider;
|
||||
if (!resolvedProvider) {
|
||||
throw new Error("No model selected");
|
||||
}
|
||||
const key = await modelRegistry.getApiKeyForProvider(resolvedProvider);
|
||||
if (!key) {
|
||||
const model = agent.state.model;
|
||||
const isOAuth = model && modelRegistry.isUsingOAuth(model);
|
||||
if (isOAuth) {
|
||||
throw new Error(
|
||||
`Authentication failed for "${resolvedProvider}". ` +
|
||||
`Credentials may have expired or network is unavailable. ` +
|
||||
`Run '/login ${resolvedProvider}' to re-authenticate.`,
|
||||
);
|
||||
}
|
||||
throw new Error(
|
||||
`No API key found for "${resolvedProvider}". ` +
|
||||
`Set an API key environment variable or run '/login ${resolvedProvider}'.`,
|
||||
);
|
||||
}
|
||||
return key;
|
||||
},
|
||||
});
|
||||
|
||||
// Restore messages if session has existing data
|
||||
if (hasExistingSession) {
|
||||
agent.replaceMessages(existingSession.messages);
|
||||
if (!hasThinkingEntry) {
|
||||
sessionManager.appendThinkingLevelChange(thinkingLevel);
|
||||
}
|
||||
} else {
|
||||
// Save initial model and thinking level for new sessions so they can be restored on resume
|
||||
if (model) {
|
||||
sessionManager.appendModelChange(model.provider, model.id);
|
||||
}
|
||||
sessionManager.appendThinkingLevelChange(thinkingLevel);
|
||||
}
|
||||
|
||||
const session = new AgentSession({
|
||||
agent,
|
||||
sessionManager,
|
||||
settingsManager,
|
||||
cwd,
|
||||
scopedModels: options.scopedModels,
|
||||
resourceLoader,
|
||||
customTools: options.customTools,
|
||||
modelRegistry,
|
||||
initialActiveToolNames,
|
||||
extensionRunnerRef,
|
||||
});
|
||||
const extensionsResult = resourceLoader.getExtensions();
|
||||
|
||||
return {
|
||||
session,
|
||||
extensionsResult,
|
||||
modelFallbackMessage,
|
||||
};
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue