From 7053938f7d2bcddc54bb1832d73ea9e94a91b4cd Mon Sep 17 00:00:00 2001 From: Mikael Hugo Date: Sat, 2 May 2026 13:32:05 +0200 Subject: [PATCH] fix(gemini): keep cli tools in pi harness --- packages/pi-agent-core/src/agent-loop.test.ts | 259 +++++++++++++++--- packages/pi-agent-core/src/agent-loop.ts | 187 ++++++++++--- packages/pi-coding-agent/src/core/sdk.test.ts | 44 ++- packages/pi-coding-agent/src/core/sdk.ts | 178 ++++++++---- 4 files changed, 535 insertions(+), 133 deletions(-) diff --git a/packages/pi-agent-core/src/agent-loop.test.ts b/packages/pi-agent-core/src/agent-loop.test.ts index 9ead3d2a0..54e62679b 100644 --- a/packages/pi-agent-core/src/agent-loop.test.ts +++ b/packages/pi-agent-core/src/agent-loop.test.ts @@ -1,16 +1,28 @@ // agent-loop tests // Covers: pauseTurn handling (#2869), schema overload retry cap (#2783) -import { describe, it } from 'vitest'; import assert from "node:assert/strict"; import { readFileSync } from "node:fs"; -import { join, dirname } from "node:path"; +import { dirname, join } from "node:path"; import { fileURLToPath } from "node:url"; import { Type } from "@sinclair/typebox"; -import { agentLoop, MAX_CONSECUTIVE_VALIDATION_FAILURES } from "./agent-loop.js"; -import type { AgentContext, AgentLoopConfig, AgentTool, AgentEvent, AgentMessage } from "./types.js"; -import { AssistantMessageEventStream, EventStream } from "@singularity-forge/pi-ai"; -import type { AssistantMessage, AssistantMessageEvent, Model } from "@singularity-forge/pi-ai"; +import type { AssistantMessage, Model } from "@singularity-forge/pi-ai"; +import { + AssistantMessageEventStream, + type EventStream, +} from "@singularity-forge/pi-ai"; +import { describe, it } from "vitest"; +import { + agentLoop, + MAX_CONSECUTIVE_VALIDATION_FAILURES, +} from "./agent-loop.js"; +import type { + AgentContext, + AgentEvent, + AgentLoopConfig, + AgentMessage, + AgentTool, +} from "./types.js"; const __dirname = dirname(fileURLToPath(import.meta.url)); @@ -69,7 +81,13 @@ describe("agent-loop — pauseTurn handling (#2869)", () => { const context: AgentContext = { systemPrompt: "You are a test agent.", - messages: [{ role: "user", content: [{ type: "text", text: "Run the command" }], timestamp: Date.now() }], + messages: [ + { + role: "user", + content: [{ type: "text", text: "Run the command" }], + timestamp: Date.now(), + }, + ], tools: [], }; @@ -81,7 +99,13 @@ describe("agent-loop — pauseTurn handling (#2869)", () => { }; const stream = agentLoop( - [{ role: "user", content: [{ type: "text", text: "Run the command" }], timestamp: Date.now() }], + [ + { + role: "user", + content: [{ type: "text", text: "Run the command" }], + timestamp: Date.now(), + }, + ], context, config, undefined, @@ -90,7 +114,8 @@ describe("agent-loop — pauseTurn handling (#2869)", () => { const events = await collectEvents(stream); const toolEnd = events.find( - (event): event is Extract => event.type === "tool_execution_end", + (event): event is Extract => + event.type === "tool_execution_end", ); assert.ok(toolEnd, "expected tool_execution_end event"); @@ -98,6 +123,71 @@ describe("agent-loop — pauseTurn handling (#2869)", () => { assert.deepEqual(toolEnd.result.details, { source: "claude-code" }); assert.equal(toolEnd.isError, false); }); + + it("uses a neutral provider-executed fallback when no external result is attached", async () => { + const externalMessage = makeAssistantMessage({ + content: [ + { + type: "toolCall", + id: "tc-external-fallback", + name: "read", + arguments: { filePath: ".sf/BACKLOG.md" }, + }, + ], + stopReason: "toolUse", + provider: "claude-code", + }); + + const mockStream = createMockStreamFn([externalMessage]); + + const context: AgentContext = { + systemPrompt: "You are a test agent.", + messages: [ + { + role: "user", + content: [{ type: "text", text: "Read backlog" }], + timestamp: Date.now(), + }, + ], + tools: [], + }; + + const config: AgentLoopConfig = { + model: { ...TEST_MODEL, provider: "claude-code" }, + convertToLlm: (msgs) => msgs.filter((m): m is any => m.role !== "custom"), + toolExecution: "sequential", + externalToolExecution: true, + }; + + const stream = agentLoop( + [ + { + role: "user", + content: [{ type: "text", text: "Read backlog" }], + timestamp: Date.now(), + }, + ], + context, + config, + undefined, + mockStream as any, + ); + + const events = await collectEvents(stream); + const toolEnd = events.find( + (event): event is Extract => + event.type === "tool_execution_end", + ); + + assert.ok(toolEnd, "expected tool_execution_end event"); + assert.deepEqual(toolEnd.result.content, [ + { type: "text", text: "(executed by provider)" }, + ]); + assert.equal( + JSON.stringify(toolEnd.result.content).includes("Claude Code"), + false, + ); + }); }); describe("agent-loop — steering during tool batches", () => { @@ -197,8 +287,7 @@ describe("agent-loop — steering during tool batches", () => { assert.equal(skipped.length, 0); assert.ok( events.some( - (event) => - event.type === "message_start" && event.message === steering, + (event) => event.type === "message_start" && event.message === steering, ), "system steering should still be delivered after the tool batch", ); @@ -299,8 +388,7 @@ describe("agent-loop — steering during tool batches", () => { assert.equal(skipped.length, 0); assert.ok( events.some( - (event) => - event.type === "message_start" && event.message === steering, + (event) => event.type === "message_start" && event.message === steering, ), "queued steering should still be delivered after the tool batch", ); @@ -375,21 +463,32 @@ function createMockStreamFn(responses: AssistantMessage[]) { }; } -function makeAssistantMessage(overrides: Partial = {}): AssistantMessage { +function makeAssistantMessage( + overrides: Partial = {}, +): AssistantMessage { return { role: "assistant", content: [], api: "anthropic-messages", provider: "anthropic", model: "claude-test", - usage: { input: 100, output: 50, cacheRead: 0, cacheWrite: 0, totalTokens: 150, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } }, + usage: { + input: 100, + output: 50, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 150, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, stopReason: "stop", timestamp: Date.now(), ...overrides, }; } -function makeToolCallMessage(toolCallArgs: Record): AssistantMessage { +function makeToolCallMessage( + toolCallArgs: Record, +): AssistantMessage { return makeAssistantMessage({ content: [ { @@ -403,26 +502,32 @@ function makeToolCallMessage(toolCallArgs: Record): AssistantMe }); } -function collectEvents(stream: EventStream): Promise { - return new Promise(async (resolve) => { +function collectEvents( + stream: EventStream, +): Promise { + return new Promise((resolve) => { const events: AgentEvent[] = []; - for await (const event of stream) { - events.push(event); - } - resolve(events); + void (async () => { + for await (const event of stream) { + events.push(event); + } + resolve(events); + })(); }); } // ─── Tests ──────────────────────────────────────────────────────────────────── describe("agent-loop — schema overload retry cap (#2783)", () => { - it("terminates after MAX_CONSECUTIVE_VALIDATION_FAILURES consecutive schema failures", async () => { const tool = makeToolWithSchema(); // LLM keeps sending tool calls with invalid args (missing required 'content' field) const badToolCall = makeToolCallMessage({ path: "/tmp/test" }); // missing 'content' - const finalStop = makeAssistantMessage({ content: [{ type: "text", text: "I give up." }], stopReason: "stop" }); + const finalStop = makeAssistantMessage({ + content: [{ type: "text", text: "I give up." }], + stopReason: "stop", + }); // Create enough bad responses to exceed the cap, plus a final stop const responses: AssistantMessage[] = []; @@ -435,7 +540,13 @@ describe("agent-loop — schema overload retry cap (#2783)", () => { const context: AgentContext = { systemPrompt: "You are a test agent.", - messages: [{ role: "user", content: [{ type: "text", text: "Write a file" }], timestamp: Date.now() }], + messages: [ + { + role: "user", + content: [{ type: "text", text: "Write a file" }], + timestamp: Date.now(), + }, + ], tools: [tool], }; @@ -446,7 +557,13 @@ describe("agent-loop — schema overload retry cap (#2783)", () => { }; const stream = agentLoop( - [{ role: "user", content: [{ type: "text", text: "Write a file" }], timestamp: Date.now() }], + [ + { + role: "user", + content: [{ type: "text", text: "Write a file" }], + timestamp: Date.now(), + }, + ], context, config, undefined, @@ -457,7 +574,10 @@ describe("agent-loop — schema overload retry cap (#2783)", () => { // Must have terminated (agent_end event present) const agentEnd = events.find((e) => e.type === "agent_end"); - assert.ok(agentEnd, "agent loop must emit agent_end after hitting retry cap"); + assert.ok( + agentEnd, + "agent loop must emit agent_end after hitting retry cap", + ); // Count how many turns had validation errors (tool_execution_end with isError: true) const toolErrors = events.filter( @@ -476,15 +596,35 @@ describe("agent-loop — schema overload retry cap (#2783)", () => { // Pattern: 2 failures, 1 success, 2 failures, 1 success, then stop const badCall = makeToolCallMessage({ path: "/tmp/test" }); // missing 'content' - const goodCall = makeToolCallMessage({ path: "/tmp/test", content: "hello" }); - const finalStop = makeAssistantMessage({ content: [{ type: "text", text: "Done." }], stopReason: "stop" }); + const goodCall = makeToolCallMessage({ + path: "/tmp/test", + content: "hello", + }); + const finalStop = makeAssistantMessage({ + content: [{ type: "text", text: "Done." }], + stopReason: "stop", + }); - const responses = [badCall, badCall, goodCall, badCall, badCall, goodCall, finalStop]; + const responses = [ + badCall, + badCall, + goodCall, + badCall, + badCall, + goodCall, + finalStop, + ]; const mockStream = createMockStreamFn(responses); const context: AgentContext = { systemPrompt: "You are a test agent.", - messages: [{ role: "user", content: [{ type: "text", text: "Write a file" }], timestamp: Date.now() }], + messages: [ + { + role: "user", + content: [{ type: "text", text: "Write a file" }], + timestamp: Date.now(), + }, + ], tools: [tool], }; @@ -495,7 +635,13 @@ describe("agent-loop — schema overload retry cap (#2783)", () => { }; const stream = agentLoop( - [{ role: "user", content: [{ type: "text", text: "Write a file" }], timestamp: Date.now() }], + [ + { + role: "user", + content: [{ type: "text", text: "Write a file" }], + timestamp: Date.now(), + }, + ], context, config, undefined, @@ -506,17 +652,29 @@ describe("agent-loop — schema overload retry cap (#2783)", () => { // Must complete successfully since failures never reached cap consecutively const agentEnd = events.find((e) => e.type === "agent_end"); - assert.ok(agentEnd, "agent loop must complete normally when failures are interspersed with successes"); + assert.ok( + agentEnd, + "agent loop must complete normally when failures are interspersed with successes", + ); // Should have processed all 6 tool-bearing turns const toolExecEnds = events.filter((e) => e.type === "tool_execution_end"); - assert.ok(toolExecEnds.length >= 4, `Expected at least 4 tool executions (2 bad + 1 good + 2 bad + 1 good), got ${toolExecEnds.length}`); + assert.ok( + toolExecEnds.length >= 4, + `Expected at least 4 tool executions (2 bad + 1 good + 2 bad + 1 good), got ${toolExecEnds.length}`, + ); }); it("exports MAX_CONSECUTIVE_VALIDATION_FAILURES as a configurable constant", () => { assert.equal(typeof MAX_CONSECUTIVE_VALIDATION_FAILURES, "number"); - assert.ok(MAX_CONSECUTIVE_VALIDATION_FAILURES >= 2, "Cap must be at least 2 to allow one retry"); - assert.ok(MAX_CONSECUTIVE_VALIDATION_FAILURES <= 10, "Cap must not be unreasonably high"); + assert.ok( + MAX_CONSECUTIVE_VALIDATION_FAILURES >= 2, + "Cap must be at least 2 to allow one retry", + ); + assert.ok( + MAX_CONSECUTIVE_VALIDATION_FAILURES <= 10, + "Cap must not be unreasonably high", + ); }); it("does NOT trip schema overload cap on tool execution errors like bash exit code 1 (#3618)", async () => { @@ -566,7 +724,13 @@ describe("agent-loop — schema overload retry cap (#2783)", () => { const context: AgentContext = { systemPrompt: "You are a test agent.", - messages: [{ role: "user", content: [{ type: "text", text: "Search for references" }], timestamp: Date.now() }], + messages: [ + { + role: "user", + content: [{ type: "text", text: "Search for references" }], + timestamp: Date.now(), + }, + ], tools: [bashTool], }; @@ -577,7 +741,13 @@ describe("agent-loop — schema overload retry cap (#2783)", () => { }; const stream = agentLoop( - [{ role: "user", content: [{ type: "text", text: "Search for references" }], timestamp: Date.now() }], + [ + { + role: "user", + content: [{ type: "text", text: "Search for references" }], + timestamp: Date.now(), + }, + ], context, config, undefined, @@ -604,12 +774,17 @@ describe("agent-loop — schema overload retry cap (#2783)", () => { // The stop message should NOT contain the schema overload text const allMessages = (agentEnd as any).messages as AgentMessage[]; const lastMessage = allMessages[allMessages.length - 1]; - const lastText = lastMessage.role === "assistant" - ? (lastMessage as AssistantMessage).content.find((c) => c.type === "text") - : undefined; + const lastText = + lastMessage.role === "assistant" + ? (lastMessage as AssistantMessage).content.find( + (c) => c.type === "text", + ) + : undefined; if (lastText && lastText.type === "text") { assert.ok( - !lastText.text.includes("consecutive turns with all tool calls failing"), + !lastText.text.includes( + "consecutive turns with all tool calls failing", + ), "Final message must NOT contain schema overload stop text for execution-only errors", ); } diff --git a/packages/pi-agent-core/src/agent-loop.ts b/packages/pi-agent-core/src/agent-loop.ts index 9d909c211..86cf7dd8c 100644 --- a/packages/pi-agent-core/src/agent-loop.ts +++ b/packages/pi-agent-core/src/agent-loop.ts @@ -44,7 +44,10 @@ export const ZERO_USAGE = { * Build an AssistantMessage for an unhandled error caught outside runLoop. * Uses the model from config so the message satisfies the full interface. */ -function createErrorMessage(error: unknown, config: AgentLoopConfig): AssistantMessage { +function createErrorMessage( + error: unknown, + config: AgentLoopConfig, +): AssistantMessage { const msg = error instanceof Error ? error.message : String(error); return { role: "assistant", @@ -62,7 +65,10 @@ function createErrorMessage(error: unknown, config: AgentLoopConfig): AssistantM /** * Emit a message_start + message_end pair for a single message. */ -function emitMessagePair(stream: EventStream, message: AgentMessage): void { +function emitMessagePair( + stream: EventStream, + message: AgentMessage, +): void { stream.push({ type: "message_start", message }); stream.push({ type: "message_end", message }); } @@ -109,7 +115,14 @@ export function agentLoop( } try { - await runLoop(currentContext, newMessages, config, signal, stream, streamFn); + await runLoop( + currentContext, + newMessages, + config, + signal, + stream, + streamFn, + ); } catch (error) { emitErrorSequence(stream, createErrorMessage(error, config), newMessages); } @@ -153,7 +166,14 @@ export function agentLoopContinue( stream.push({ type: "turn_start" }); try { - await runLoop(currentContext, newMessages, config, signal, stream, streamFn); + await runLoop( + currentContext, + newMessages, + config, + signal, + stream, + streamFn, + ); } catch (error) { emitErrorSequence(stream, createErrorMessage(error, config), newMessages); } @@ -182,7 +202,8 @@ async function runLoop( ): Promise { let firstTurn = true; // Check for steering messages at start (user may have typed while waiting) - let pendingMessages: AgentMessage[] = (await config.getSteeringMessages?.()) || []; + let pendingMessages: AgentMessage[] = + (await config.getSteeringMessages?.()) || []; // Track consecutive turns where ALL tool calls fail validation. // When the LLM repeatedly emits tool calls with schema-overloaded or malformed @@ -216,12 +237,19 @@ async function runLoop( // Stream assistant response let message: AssistantMessage; try { - message = await streamAssistantResponse(currentContext, config, signal, stream, streamFn); + message = await streamAssistantResponse( + currentContext, + config, + signal, + stream, + streamFn, + ); } catch (error) { // Critical failure before stream started (e.g. getApiKey threw, credentials in // backoff, network unavailable). Convert to a graceful error message so the // agent loop can end cleanly instead of crashing with an unhandled rejection. - const errorText = error instanceof Error ? error.message : String(error); + const errorText = + error instanceof Error ? error.message : String(error); message = { role: "assistant", content: [], @@ -259,13 +287,20 @@ async function runLoop( // to the tool call so the UI can show the real stdout/stderr // instead of a generic placeholder. for (const tc of toolCalls as AgentToolCall[]) { - const externalResult = (tc as AgentToolCall & { - externalResult?: { - content?: Array<{ type: string; text?: string; data?: string; mimeType?: string }>; - details?: Record; - isError?: boolean; - }; - }).externalResult; + const externalResult = ( + tc as AgentToolCall & { + externalResult?: { + content?: Array<{ + type: string; + text?: string; + data?: string; + mimeType?: string; + }>; + details?: Record; + isError?: boolean; + }; + } + ).externalResult; stream.push({ type: "tool_execution_start", toolCallId: tc.id, @@ -278,11 +313,13 @@ async function runLoop( toolName: tc.name, result: externalResult ? { - content: externalResult.content ?? [{ type: "text", text: "" }], + content: externalResult.content ?? [ + { type: "text", text: "" }, + ], details: externalResult.details ?? {}, } : { - content: [{ type: "text", text: "(executed by Claude Code)" }], + content: [{ type: "text", text: "(executed by provider)" }], details: {}, }, isError: externalResult?.isError ?? false, @@ -327,7 +364,9 @@ async function runLoop( consecutiveAllToolErrorTurns = 0; } - if (consecutiveAllToolErrorTurns >= MAX_CONSECUTIVE_VALIDATION_FAILURES) { + if ( + consecutiveAllToolErrorTurns >= MAX_CONSECUTIVE_VALIDATION_FAILURES + ) { // Force-stop: the LLM is stuck retrying broken tool calls. // Emit the turn_end and terminate the agent loop cleanly. stream.push({ type: "turn_end", message, toolResults }); @@ -344,12 +383,17 @@ async function runLoop( model: config.model.id, usage: ZERO_USAGE, stopReason: "error", - errorMessage: "Schema overload: consecutive tool validation failures exceeded cap", + errorMessage: + "Schema overload: consecutive tool validation failures exceeded cap", timestamp: Date.now(), }; emitMessagePair(stream, stopMessage); newMessages.push(stopMessage); - stream.push({ type: "turn_end", message: stopMessage, toolResults: [] }); + stream.push({ + type: "turn_end", + message: stopMessage, + toolResults: [], + }); stream.push({ type: "agent_end", messages: newMessages }); stream.end(newMessages); return; @@ -414,7 +458,9 @@ async function streamAssistantResponse( // Resolve API key (important for expiring tokens) const resolvedApiKey = - (config.getApiKey ? await config.getApiKey(config.model.provider) : undefined) || config.apiKey; + (config.getApiKey + ? await config.getApiKey(config.model.provider) + : undefined) || config.apiKey; const response = await streamFunction(config.model, llmContext, { ...config, @@ -503,11 +549,27 @@ async function executeToolCalls( signal: AbortSignal | undefined, stream: EventStream, ): Promise { - const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall") as AgentToolCall[]; + const toolCalls = assistantMessage.content.filter( + (c) => c.type === "toolCall", + ) as AgentToolCall[]; if (config.toolExecution === "sequential") { - return executeToolCallsSequential(currentContext, assistantMessage, toolCalls, config, signal, stream); + return executeToolCallsSequential( + currentContext, + assistantMessage, + toolCalls, + config, + signal, + stream, + ); } - return executeToolCallsParallel(currentContext, assistantMessage, toolCalls, config, signal, stream); + return executeToolCallsParallel( + currentContext, + assistantMessage, + toolCalls, + config, + signal, + stream, + ); } async function executeToolCallsSequential( @@ -532,14 +594,31 @@ async function executeToolCallsSequential( args: toolCall.arguments, }); - const preparation = await prepareToolCall(currentContext, assistantMessage, toolCall, config, signal); + const preparation = await prepareToolCall( + currentContext, + assistantMessage, + toolCall, + config, + signal, + ); if (preparation.kind === "immediate") { if (preparation.isError) { preparationErrorCount++; } - results.push(emitToolCallOutcome(toolCall, preparation.result, preparation.isError, stream)); + results.push( + emitToolCallOutcome( + toolCall, + preparation.result, + preparation.isError, + stream, + ), + ); } else { - const executed = await executePreparedToolCall(preparation, signal, stream); + const executed = await executePreparedToolCall( + preparation, + signal, + stream, + ); results.push( await finalizeExecutedToolCall( currentContext, @@ -594,12 +673,25 @@ async function executeToolCallsParallel( args: toolCall.arguments, }); - const preparation = await prepareToolCall(currentContext, assistantMessage, toolCall, config, signal); + const preparation = await prepareToolCall( + currentContext, + assistantMessage, + toolCall, + config, + signal, + ); if (preparation.kind === "immediate") { if (preparation.isError) { preparationErrorCount++; } - results.push(emitToolCallOutcome(toolCall, preparation.result, preparation.isError, stream)); + results.push( + emitToolCallOutcome( + toolCall, + preparation.result, + preparation.isError, + stream, + ), + ); } else { runnableCalls.push(preparation); } @@ -610,13 +702,19 @@ async function executeToolCallsParallel( steeringMessages = [...(steeringMessages ?? []), ...steering]; if (interruptOnSteering && hasUserSteeringMessage(steering)) { for (const runnable of runnableCalls) { - results.push(skipToolCall(runnable.toolCall, stream, { emitStart: false })); + results.push( + skipToolCall(runnable.toolCall, stream, { emitStart: false }), + ); } const remainingCalls = toolCalls.slice(index + 1); for (const skipped of remainingCalls) { results.push(skipToolCall(skipped, stream)); } - return { toolResults: results, steeringMessages, preparationErrorCount }; + return { + toolResults: results, + steeringMessages, + preparationErrorCount, + }; } } } @@ -701,7 +799,9 @@ async function prepareToolCall( if (beforeResult?.block) { return { kind: "immediate", - result: createErrorToolResult(beforeResult.reason || "Tool execution was blocked"), + result: createErrorToolResult( + beforeResult.reason || "Tool execution was blocked", + ), isError: true, }; } @@ -715,7 +815,9 @@ async function prepareToolCall( } catch (error) { return { kind: "immediate", - result: createErrorToolResult(error instanceof Error ? error.message : String(error)), + result: createErrorToolResult( + error instanceof Error ? error.message : String(error), + ), isError: true, }; } @@ -744,7 +846,9 @@ async function executePreparedToolCall( return { result, isError: false }; } catch (error) { return { - result: createErrorToolResult(error instanceof Error ? error.message : String(error)), + result: createErrorToolResult( + error instanceof Error ? error.message : String(error), + ), isError: true, }; } @@ -777,13 +881,22 @@ async function finalizeExecutedToolCall( ); if (afterResult) { result = { - content: afterResult.content !== undefined ? afterResult.content : result.content, - details: afterResult.details !== undefined ? afterResult.details : result.details, + content: + afterResult.content !== undefined + ? afterResult.content + : result.content, + details: + afterResult.details !== undefined + ? afterResult.details + : result.details, }; - isError = afterResult.isError !== undefined ? afterResult.isError : isError; + isError = + afterResult.isError !== undefined ? afterResult.isError : isError; } } catch (error) { - result = createErrorToolResult(error instanceof Error ? error.message : String(error)); + result = createErrorToolResult( + error instanceof Error ? error.message : String(error), + ); isError = true; } } diff --git a/packages/pi-coding-agent/src/core/sdk.test.ts b/packages/pi-coding-agent/src/core/sdk.test.ts index 38d4336a7..5287c1238 100644 --- a/packages/pi-coding-agent/src/core/sdk.test.ts +++ b/packages/pi-coding-agent/src/core/sdk.test.ts @@ -1,9 +1,12 @@ // pi-coding-agent / CredentialCooldownError unit tests // Copyright (c) 2026 Jeremy McSpadden -import { describe, it } from 'vitest'; import assert from "node:assert/strict"; -import { CredentialCooldownError } from "./sdk.js"; +import { describe, it } from "vitest"; +import { + CredentialCooldownError, + shouldUseExternalToolExecution, +} from "./sdk.js"; // ─── CredentialCooldownError ────────────────────────────────────────────────── @@ -57,7 +60,11 @@ describe("CredentialCooldownError", () => { it("code property is readonly and always AUTH_COOLDOWN regardless of provider", () => { for (const provider of ["anthropic", "openai", "google", "openrouter"]) { const err = new CredentialCooldownError(provider); - assert.equal(err.code, "AUTH_COOLDOWN", `code should be AUTH_COOLDOWN for provider "${provider}"`); + assert.equal( + err.code, + "AUTH_COOLDOWN", + `code should be AUTH_COOLDOWN for provider "${provider}"`, + ); } }); @@ -82,8 +89,37 @@ describe("CredentialCooldownError", () => { it("code property is detectable via plain object check (cross-process pattern)", () => { const err = new CredentialCooldownError("anthropic", 15_000); // Simulate cross-process serialization: only plain properties survive JSON round-trip - const plain = { code: err.code, retryAfterMs: err.retryAfterMs, message: err.message }; + const plain = { + code: err.code, + retryAfterMs: err.retryAfterMs, + message: err.message, + }; assert.equal(plain.code, "AUTH_COOLDOWN"); assert.equal(plain.retryAfterMs, 15_000); }); }); + +// ─── External Tool Execution Ownership ─────────────────────────────────────── + +describe("shouldUseExternalToolExecution", () => { + it("returns true for claude-code because its adapter can execute tools", () => { + assert.equal( + shouldUseExternalToolExecution({ provider: "claude-code" }), + true, + ); + }); + + it("returns false for google-gemini-cli so Gemini tool calls stay in the Pi harness", () => { + assert.equal( + shouldUseExternalToolExecution({ provider: "google-gemini-cli" }), + false, + ); + }); + + it("returns false for other CLI-style providers unless their adapter owns tools", () => { + assert.equal( + shouldUseExternalToolExecution({ provider: "openai-codex" }), + false, + ); + }); +}); diff --git a/packages/pi-coding-agent/src/core/sdk.ts b/packages/pi-coding-agent/src/core/sdk.ts index 787887e49..ca85567c7 100644 --- a/packages/pi-coding-agent/src/core/sdk.ts +++ b/packages/pi-coding-agent/src/core/sdk.ts @@ -1,5 +1,6 @@ import { existsSync } from "node:fs"; import { join } from "node:path"; +import type { Model } from "@singularity-forge/pi-ai"; /** * Lightweight PATH scan for the `claude` binary — no subprocess, no network. @@ -31,13 +32,38 @@ export class CredentialCooldownError extends Error { this.retryAfterMs = retryAfterMs; } } -import { Agent, type AgentMessage, type ThinkingLevel } from "@singularity-forge/pi-agent-core"; -import type { Message, Model } from "@singularity-forge/pi-ai"; + +/** + * Returns whether a provider executes tool calls inside its own adapter rather than + * returning them for the Pi harness to dispatch locally. + * + * Purpose: keep credential/auth transport modes separate from tool execution + * ownership so CLI-auth providers such as google-gemini-cli still run tools + * through Pi's local harness. + * + * Consumer: createAgentSession when configuring the core Agent loop. + */ +export function shouldUseExternalToolExecution( + model: Pick, "provider">, +): boolean { + return model.provider === "claude-code"; +} + +import { + Agent, + type AgentMessage, + type ThinkingLevel, +} from "@singularity-forge/pi-agent-core"; +import type { Message } from "@singularity-forge/pi-ai"; import { getAgentDir, getDocsPath } from "../config.js"; import { AgentSession } from "./agent-session.js"; import { AuthStorage } from "./auth-storage.js"; import { DEFAULT_THINKING_LEVEL } from "./defaults.js"; -import type { ExtensionRunner, LoadExtensionsResult, ToolDefinition } from "./extensions/index.js"; +import type { + ExtensionRunner, + LoadExtensionsResult, + ToolDefinition, +} from "./extensions/index.js"; import { convertToLlm } from "./messages.js"; import { ModelRegistry } from "./model-registry.js"; import { findInitialModel } from "./model-resolver.js"; @@ -55,6 +81,9 @@ import { createEditTool, createFindTool, createGrepTool, + createHashlineCodingTools, + createHashlineEditTool, + createHashlineReadTool, createLsTool, createReadOnlyTools, createReadTool, @@ -65,9 +94,6 @@ import { hashlineCodingTools, hashlineEditTool, hashlineReadTool, - createHashlineCodingTools, - createHashlineEditTool, - createHashlineReadTool, lsTool, readOnlyTools, readTool, @@ -144,34 +170,34 @@ export type { Skill } from "./skills.js"; export type { Tool } from "./tools/index.js"; export { - // Pre-built tools (use process.cwd()) - readTool, - bashTool, - editTool, - writeTool, - grepTool, - findTool, - lsTool, - codingTools, - readOnlyTools, allTools as allBuiltInTools, + bashTool, + codingTools, + createBashTool, // Tool factories (for custom cwd) createCodingTools, + createEditTool, + createFindTool, + createGrepTool, + createHashlineCodingTools, + createHashlineEditTool, + createHashlineReadTool, + createLsTool, createReadOnlyTools, createReadTool, - createBashTool, - createEditTool, createWriteTool, - createGrepTool, - createFindTool, - createLsTool, + editTool, + findTool, + grepTool, // Hashline edit mode hashlineCodingTools, hashlineEditTool, hashlineReadTool, - createHashlineCodingTools, - createHashlineEditTool, - createHashlineReadTool, + lsTool, + readOnlyTools, + // Pre-built tools (use process.cwd()) + readTool, + writeTool, }; // Helper Functions @@ -215,22 +241,32 @@ function getDefaultAgentDir(): string { * }); * ``` */ -export async function createAgentSession(options: CreateAgentSessionOptions = {}): Promise { +export async function createAgentSession( + options: CreateAgentSessionOptions = {}, +): Promise { const cwd = options.cwd ?? process.cwd(); const agentDir = options.agentDir ?? getDefaultAgentDir(); let resourceLoader = options.resourceLoader; // Use provided or create AuthStorage and ModelRegistry const authPath = options.agentDir ? join(agentDir, "auth.json") : undefined; - const modelsPath = options.agentDir ? join(agentDir, "models.json") : undefined; + const modelsPath = options.agentDir + ? join(agentDir, "models.json") + : undefined; const authStorage = options.authStorage ?? AuthStorage.create(authPath); - const modelRegistry = options.modelRegistry ?? new ModelRegistry(authStorage, modelsPath); + const modelRegistry = + options.modelRegistry ?? new ModelRegistry(authStorage, modelsPath); - const settingsManager = options.settingsManager ?? SettingsManager.create(cwd, agentDir); + const settingsManager = + options.settingsManager ?? SettingsManager.create(cwd, agentDir); const sessionManager = options.sessionManager ?? SessionManager.create(cwd); if (!resourceLoader) { - resourceLoader = new DefaultResourceLoader({ cwd, agentDir, settingsManager }); + resourceLoader = new DefaultResourceLoader({ + cwd, + agentDir, + settingsManager, + }); await resourceLoader.reload(); time("resourceLoader.reload"); } @@ -240,7 +276,10 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {} // findInitialModel() runs. bindCore() repeats this flush as a safety net // for any late-arriving registrations. const { runtime: extensionRuntime } = resourceLoader.getExtensions(); - for (const { name, config } of extensionRuntime.pendingProviderRegistrations) { + for (const { + name, + config, + } of extensionRuntime.pendingProviderRegistrations) { modelRegistry.registerProvider(name, config); } extensionRuntime.pendingProviderRegistrations = []; @@ -248,14 +287,19 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {} // Check if session has existing data to restore const existingSession = sessionManager.buildSessionContext(); const hasExistingSession = existingSession.messages.length > 0; - const hasThinkingEntry = sessionManager.getBranch().some((entry) => entry.type === "thinking_level_change"); + const hasThinkingEntry = sessionManager + .getBranch() + .some((entry) => entry.type === "thinking_level_change"); let model = options.model; let modelFallbackMessage: string | undefined; // If session has data, try to restore model from it if (!model && hasExistingSession && existingSession.model) { - const restoredModel = modelRegistry.find(existingSession.model.provider, existingSession.model.modelId); + const restoredModel = modelRegistry.find( + existingSession.model.provider, + existingSession.model.modelId, + ); if (restoredModel && (await modelRegistry.getApiKey(restoredModel))) { model = restoredModel; } @@ -268,7 +312,8 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {} // are available in the registry before model resolution. Without this, findInitialModel() // cannot find extension models and falls back to built-in providers (#3534). const extensionsForModelResolution = resourceLoader.getExtensions(); - for (const { name, config } of extensionsForModelResolution.runtime.pendingProviderRegistrations) { + for (const { name, config } of extensionsForModelResolution.runtime + .pendingProviderRegistrations) { modelRegistry.registerProvider(name, config); } // Clear the queue so bindCore() doesn't re-register the same providers. @@ -303,7 +348,8 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {} // Fall back to settings default if (thinkingLevel === undefined) { - thinkingLevel = settingsManager.getDefaultThinkingLevel() ?? DEFAULT_THINKING_LEVEL; + thinkingLevel = + settingsManager.getDefaultThinkingLevel() ?? DEFAULT_THINKING_LEVEL; } // Clamp to model capabilities @@ -312,11 +358,23 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {} } const editMode = settingsManager.getEditMode(); - const defaultActiveToolNames: ToolName[] = editMode === "hashline" - ? ["hashline_read", "grep", "find", "ls", "bash", "hashline_edit", "write", "lsp"] - : ["read", "grep", "find", "ls", "bash", "edit", "write", "lsp"]; + const defaultActiveToolNames: ToolName[] = + editMode === "hashline" + ? [ + "hashline_read", + "grep", + "find", + "ls", + "bash", + "hashline_edit", + "write", + "lsp", + ] + : ["read", "grep", "find", "ls", "bash", "edit", "write", "lsp"]; const initialActiveToolNames: ToolName[] = options.tools - ? options.tools.map((t) => t.name).filter((n): n is ToolName => n in allTools) + ? options.tools + .map((t) => t.name) + .filter((n): n is ToolName => n in allTools) : defaultActiveToolNames; let agent: Agent; @@ -337,7 +395,12 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {} if (hasImages) { const filteredContent = content .map((c) => - c.type === "image" ? { type: "text" as const, text: "Image reading is disabled." } : c, + c.type === "image" + ? { + type: "text" as const, + text: "Image reading is disabled.", + } + : c, ) .filter( (c, i, arr) => @@ -347,7 +410,8 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {} c.text === "Image reading is disabled." && i > 0 && arr[i - 1].type === "text" && - (arr[i - 1] as { type: "text"; text: string }).text === "Image reading is disabled." + (arr[i - 1] as { type: "text"; text: string }).text === + "Image reading is disabled." ), ); return { ...msg, content: filteredContent }; @@ -387,7 +451,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {} transport: settingsManager.getTransport(), thinkingBudgets: settingsManager.getThinkingBudgets(), maxRetryDelayMs: settingsManager.getRetrySettings().maxDelayMs, - externalToolExecution: (m) => modelRegistry.getProviderAuthMode(m.provider) === "externalCli", + externalToolExecution: shouldUseExternalToolExecution, getProviderOptions: async (currentModel) => { if (currentModel.provider !== "claude-code") return undefined; const runner = extensionRunnerRef.current; @@ -432,11 +496,12 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {} // If credentials are in a cooldown window, wait for the earliest // one to expire rather than using a fixed delay that's too short. - const backoffExpiry = modelRegistry.authStorage.getEarliestBackoffExpiry(resolvedProvider); + const backoffExpiry = + modelRegistry.authStorage.getEarliestBackoffExpiry(resolvedProvider); if (backoffExpiry !== undefined) { const waitMs = backoffExpiry - Date.now() + 500; // 500ms buffer if (waitMs > 0 && waitMs <= maxCooldownWaitMs) { - await new Promise(resolve => setTimeout(resolve, waitMs)); + await new Promise((resolve) => setTimeout(resolve, waitMs)); continue; // Retry immediately after cooldown clears } if (waitMs > maxCooldownWaitMs) { @@ -445,7 +510,9 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {} } // Standard exponential backoff for non-cooldown transient failures - await new Promise(resolve => setTimeout(resolve, baseDelayMs * attempt)); + await new Promise((resolve) => + setTimeout(resolve, baseDelayMs * attempt), + ); } // All retries exhausted — throw descriptive error. @@ -465,7 +532,10 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {} // Self-heal: strip the stale oauth entry so hasAuth() stops lying // about anthropic being configured. This preserves any api_key // credentials alongside it. - const removed = modelRegistry.authStorage.removeLegacyOAuthCredential(resolvedProvider); + const removed = + modelRegistry.authStorage.removeLegacyOAuthCredential( + resolvedProvider, + ); if (removed) { console.warn( `[auth] Removed unsupported Anthropic OAuth credential from auth.json (#3952).`, @@ -484,8 +554,10 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {} `Set ANTHROPIC_API_KEY, run '/login' and paste an API key, or switch to a different provider.`, ); } - const expiry = modelRegistry.authStorage.getEarliestBackoffExpiry(resolvedProvider); - const retryAfterMs = expiry !== undefined ? Math.max(0, expiry - Date.now()) : undefined; + const expiry = + modelRegistry.authStorage.getEarliestBackoffExpiry(resolvedProvider); + const retryAfterMs = + expiry !== undefined ? Math.max(0, expiry - Date.now()) : undefined; throw new CredentialCooldownError(resolvedProvider, retryAfterMs); } const model = agent.state.model; @@ -493,9 +565,15 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {} if (isOAuth) { // If credentials exist but are all in a backoff window (quota / rate-limit), // surface a specific message instead of the misleading "Authentication failed". - if (modelRegistry.authStorage.areAllCredentialsBackedOff(resolvedProvider)) { - const expiry = modelRegistry.authStorage.getEarliestBackoffExpiry(resolvedProvider); - const retryAfterMs = expiry !== undefined ? Math.max(0, expiry - Date.now()) : undefined; + if ( + modelRegistry.authStorage.areAllCredentialsBackedOff(resolvedProvider) + ) { + const expiry = + modelRegistry.authStorage.getEarliestBackoffExpiry( + resolvedProvider, + ); + const retryAfterMs = + expiry !== undefined ? Math.max(0, expiry - Date.now()) : undefined; throw new CredentialCooldownError(resolvedProvider, retryAfterMs); } throw new Error(