fix(gemini): keep cli tools in pi harness

This commit is contained in:
Mikael Hugo 2026-05-02 13:32:05 +02:00
parent 98fe3b605d
commit 7053938f7d
4 changed files with 535 additions and 133 deletions

View file

@ -1,16 +1,28 @@
// agent-loop tests
// Covers: pauseTurn handling (#2869), schema overload retry cap (#2783)
import { describe, it } from 'vitest';
import assert from "node:assert/strict";
import { readFileSync } from "node:fs";
import { join, dirname } from "node:path";
import { dirname, join } from "node:path";
import { fileURLToPath } from "node:url";
import { Type } from "@sinclair/typebox";
import { agentLoop, MAX_CONSECUTIVE_VALIDATION_FAILURES } from "./agent-loop.js";
import type { AgentContext, AgentLoopConfig, AgentTool, AgentEvent, AgentMessage } from "./types.js";
import { AssistantMessageEventStream, EventStream } from "@singularity-forge/pi-ai";
import type { AssistantMessage, AssistantMessageEvent, Model } from "@singularity-forge/pi-ai";
import type { AssistantMessage, Model } from "@singularity-forge/pi-ai";
import {
AssistantMessageEventStream,
type EventStream,
} from "@singularity-forge/pi-ai";
import { describe, it } from "vitest";
import {
agentLoop,
MAX_CONSECUTIVE_VALIDATION_FAILURES,
} from "./agent-loop.js";
import type {
AgentContext,
AgentEvent,
AgentLoopConfig,
AgentMessage,
AgentTool,
} from "./types.js";
const __dirname = dirname(fileURLToPath(import.meta.url));
@ -69,7 +81,13 @@ describe("agent-loop — pauseTurn handling (#2869)", () => {
const context: AgentContext = {
systemPrompt: "You are a test agent.",
messages: [{ role: "user", content: [{ type: "text", text: "Run the command" }], timestamp: Date.now() }],
messages: [
{
role: "user",
content: [{ type: "text", text: "Run the command" }],
timestamp: Date.now(),
},
],
tools: [],
};
@ -81,7 +99,13 @@ describe("agent-loop — pauseTurn handling (#2869)", () => {
};
const stream = agentLoop(
[{ role: "user", content: [{ type: "text", text: "Run the command" }], timestamp: Date.now() }],
[
{
role: "user",
content: [{ type: "text", text: "Run the command" }],
timestamp: Date.now(),
},
],
context,
config,
undefined,
@ -90,7 +114,8 @@ describe("agent-loop — pauseTurn handling (#2869)", () => {
const events = await collectEvents(stream);
const toolEnd = events.find(
(event): event is Extract<AgentEvent, { type: "tool_execution_end" }> => event.type === "tool_execution_end",
(event): event is Extract<AgentEvent, { type: "tool_execution_end" }> =>
event.type === "tool_execution_end",
);
assert.ok(toolEnd, "expected tool_execution_end event");
@ -98,6 +123,71 @@ describe("agent-loop — pauseTurn handling (#2869)", () => {
assert.deepEqual(toolEnd.result.details, { source: "claude-code" });
assert.equal(toolEnd.isError, false);
});
it("uses a neutral provider-executed fallback when no external result is attached", async () => {
const externalMessage = makeAssistantMessage({
content: [
{
type: "toolCall",
id: "tc-external-fallback",
name: "read",
arguments: { filePath: ".sf/BACKLOG.md" },
},
],
stopReason: "toolUse",
provider: "claude-code",
});
const mockStream = createMockStreamFn([externalMessage]);
const context: AgentContext = {
systemPrompt: "You are a test agent.",
messages: [
{
role: "user",
content: [{ type: "text", text: "Read backlog" }],
timestamp: Date.now(),
},
],
tools: [],
};
const config: AgentLoopConfig = {
model: { ...TEST_MODEL, provider: "claude-code" },
convertToLlm: (msgs) => msgs.filter((m): m is any => m.role !== "custom"),
toolExecution: "sequential",
externalToolExecution: true,
};
const stream = agentLoop(
[
{
role: "user",
content: [{ type: "text", text: "Read backlog" }],
timestamp: Date.now(),
},
],
context,
config,
undefined,
mockStream as any,
);
const events = await collectEvents(stream);
const toolEnd = events.find(
(event): event is Extract<AgentEvent, { type: "tool_execution_end" }> =>
event.type === "tool_execution_end",
);
assert.ok(toolEnd, "expected tool_execution_end event");
assert.deepEqual(toolEnd.result.content, [
{ type: "text", text: "(executed by provider)" },
]);
assert.equal(
JSON.stringify(toolEnd.result.content).includes("Claude Code"),
false,
);
});
});
describe("agent-loop — steering during tool batches", () => {
@ -197,8 +287,7 @@ describe("agent-loop — steering during tool batches", () => {
assert.equal(skipped.length, 0);
assert.ok(
events.some(
(event) =>
event.type === "message_start" && event.message === steering,
(event) => event.type === "message_start" && event.message === steering,
),
"system steering should still be delivered after the tool batch",
);
@ -299,8 +388,7 @@ describe("agent-loop — steering during tool batches", () => {
assert.equal(skipped.length, 0);
assert.ok(
events.some(
(event) =>
event.type === "message_start" && event.message === steering,
(event) => event.type === "message_start" && event.message === steering,
),
"queued steering should still be delivered after the tool batch",
);
@ -375,21 +463,32 @@ function createMockStreamFn(responses: AssistantMessage[]) {
};
}
function makeAssistantMessage(overrides: Partial<AssistantMessage> = {}): AssistantMessage {
function makeAssistantMessage(
overrides: Partial<AssistantMessage> = {},
): AssistantMessage {
return {
role: "assistant",
content: [],
api: "anthropic-messages",
provider: "anthropic",
model: "claude-test",
usage: { input: 100, output: 50, cacheRead: 0, cacheWrite: 0, totalTokens: 150, cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 } },
usage: {
input: 100,
output: 50,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 150,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop",
timestamp: Date.now(),
...overrides,
};
}
function makeToolCallMessage(toolCallArgs: Record<string, unknown>): AssistantMessage {
function makeToolCallMessage(
toolCallArgs: Record<string, unknown>,
): AssistantMessage {
return makeAssistantMessage({
content: [
{
@ -403,26 +502,32 @@ function makeToolCallMessage(toolCallArgs: Record<string, unknown>): AssistantMe
});
}
function collectEvents(stream: EventStream<AgentEvent, AgentMessage[]>): Promise<AgentEvent[]> {
return new Promise(async (resolve) => {
function collectEvents(
stream: EventStream<AgentEvent, AgentMessage[]>,
): Promise<AgentEvent[]> {
return new Promise((resolve) => {
const events: AgentEvent[] = [];
for await (const event of stream) {
events.push(event);
}
resolve(events);
void (async () => {
for await (const event of stream) {
events.push(event);
}
resolve(events);
})();
});
}
// ─── Tests ────────────────────────────────────────────────────────────────────
describe("agent-loop — schema overload retry cap (#2783)", () => {
it("terminates after MAX_CONSECUTIVE_VALIDATION_FAILURES consecutive schema failures", async () => {
const tool = makeToolWithSchema();
// LLM keeps sending tool calls with invalid args (missing required 'content' field)
const badToolCall = makeToolCallMessage({ path: "/tmp/test" }); // missing 'content'
const finalStop = makeAssistantMessage({ content: [{ type: "text", text: "I give up." }], stopReason: "stop" });
const finalStop = makeAssistantMessage({
content: [{ type: "text", text: "I give up." }],
stopReason: "stop",
});
// Create enough bad responses to exceed the cap, plus a final stop
const responses: AssistantMessage[] = [];
@ -435,7 +540,13 @@ describe("agent-loop — schema overload retry cap (#2783)", () => {
const context: AgentContext = {
systemPrompt: "You are a test agent.",
messages: [{ role: "user", content: [{ type: "text", text: "Write a file" }], timestamp: Date.now() }],
messages: [
{
role: "user",
content: [{ type: "text", text: "Write a file" }],
timestamp: Date.now(),
},
],
tools: [tool],
};
@ -446,7 +557,13 @@ describe("agent-loop — schema overload retry cap (#2783)", () => {
};
const stream = agentLoop(
[{ role: "user", content: [{ type: "text", text: "Write a file" }], timestamp: Date.now() }],
[
{
role: "user",
content: [{ type: "text", text: "Write a file" }],
timestamp: Date.now(),
},
],
context,
config,
undefined,
@ -457,7 +574,10 @@ describe("agent-loop — schema overload retry cap (#2783)", () => {
// Must have terminated (agent_end event present)
const agentEnd = events.find((e) => e.type === "agent_end");
assert.ok(agentEnd, "agent loop must emit agent_end after hitting retry cap");
assert.ok(
agentEnd,
"agent loop must emit agent_end after hitting retry cap",
);
// Count how many turns had validation errors (tool_execution_end with isError: true)
const toolErrors = events.filter(
@ -476,15 +596,35 @@ describe("agent-loop — schema overload retry cap (#2783)", () => {
// Pattern: 2 failures, 1 success, 2 failures, 1 success, then stop
const badCall = makeToolCallMessage({ path: "/tmp/test" }); // missing 'content'
const goodCall = makeToolCallMessage({ path: "/tmp/test", content: "hello" });
const finalStop = makeAssistantMessage({ content: [{ type: "text", text: "Done." }], stopReason: "stop" });
const goodCall = makeToolCallMessage({
path: "/tmp/test",
content: "hello",
});
const finalStop = makeAssistantMessage({
content: [{ type: "text", text: "Done." }],
stopReason: "stop",
});
const responses = [badCall, badCall, goodCall, badCall, badCall, goodCall, finalStop];
const responses = [
badCall,
badCall,
goodCall,
badCall,
badCall,
goodCall,
finalStop,
];
const mockStream = createMockStreamFn(responses);
const context: AgentContext = {
systemPrompt: "You are a test agent.",
messages: [{ role: "user", content: [{ type: "text", text: "Write a file" }], timestamp: Date.now() }],
messages: [
{
role: "user",
content: [{ type: "text", text: "Write a file" }],
timestamp: Date.now(),
},
],
tools: [tool],
};
@ -495,7 +635,13 @@ describe("agent-loop — schema overload retry cap (#2783)", () => {
};
const stream = agentLoop(
[{ role: "user", content: [{ type: "text", text: "Write a file" }], timestamp: Date.now() }],
[
{
role: "user",
content: [{ type: "text", text: "Write a file" }],
timestamp: Date.now(),
},
],
context,
config,
undefined,
@ -506,17 +652,29 @@ describe("agent-loop — schema overload retry cap (#2783)", () => {
// Must complete successfully since failures never reached cap consecutively
const agentEnd = events.find((e) => e.type === "agent_end");
assert.ok(agentEnd, "agent loop must complete normally when failures are interspersed with successes");
assert.ok(
agentEnd,
"agent loop must complete normally when failures are interspersed with successes",
);
// Should have processed all 6 tool-bearing turns
const toolExecEnds = events.filter((e) => e.type === "tool_execution_end");
assert.ok(toolExecEnds.length >= 4, `Expected at least 4 tool executions (2 bad + 1 good + 2 bad + 1 good), got ${toolExecEnds.length}`);
assert.ok(
toolExecEnds.length >= 4,
`Expected at least 4 tool executions (2 bad + 1 good + 2 bad + 1 good), got ${toolExecEnds.length}`,
);
});
it("exports MAX_CONSECUTIVE_VALIDATION_FAILURES as a configurable constant", () => {
assert.equal(typeof MAX_CONSECUTIVE_VALIDATION_FAILURES, "number");
assert.ok(MAX_CONSECUTIVE_VALIDATION_FAILURES >= 2, "Cap must be at least 2 to allow one retry");
assert.ok(MAX_CONSECUTIVE_VALIDATION_FAILURES <= 10, "Cap must not be unreasonably high");
assert.ok(
MAX_CONSECUTIVE_VALIDATION_FAILURES >= 2,
"Cap must be at least 2 to allow one retry",
);
assert.ok(
MAX_CONSECUTIVE_VALIDATION_FAILURES <= 10,
"Cap must not be unreasonably high",
);
});
it("does NOT trip schema overload cap on tool execution errors like bash exit code 1 (#3618)", async () => {
@ -566,7 +724,13 @@ describe("agent-loop — schema overload retry cap (#2783)", () => {
const context: AgentContext = {
systemPrompt: "You are a test agent.",
messages: [{ role: "user", content: [{ type: "text", text: "Search for references" }], timestamp: Date.now() }],
messages: [
{
role: "user",
content: [{ type: "text", text: "Search for references" }],
timestamp: Date.now(),
},
],
tools: [bashTool],
};
@ -577,7 +741,13 @@ describe("agent-loop — schema overload retry cap (#2783)", () => {
};
const stream = agentLoop(
[{ role: "user", content: [{ type: "text", text: "Search for references" }], timestamp: Date.now() }],
[
{
role: "user",
content: [{ type: "text", text: "Search for references" }],
timestamp: Date.now(),
},
],
context,
config,
undefined,
@ -604,12 +774,17 @@ describe("agent-loop — schema overload retry cap (#2783)", () => {
// The stop message should NOT contain the schema overload text
const allMessages = (agentEnd as any).messages as AgentMessage[];
const lastMessage = allMessages[allMessages.length - 1];
const lastText = lastMessage.role === "assistant"
? (lastMessage as AssistantMessage).content.find((c) => c.type === "text")
: undefined;
const lastText =
lastMessage.role === "assistant"
? (lastMessage as AssistantMessage).content.find(
(c) => c.type === "text",
)
: undefined;
if (lastText && lastText.type === "text") {
assert.ok(
!lastText.text.includes("consecutive turns with all tool calls failing"),
!lastText.text.includes(
"consecutive turns with all tool calls failing",
),
"Final message must NOT contain schema overload stop text for execution-only errors",
);
}

View file

@ -44,7 +44,10 @@ export const ZERO_USAGE = {
* Build an AssistantMessage for an unhandled error caught outside runLoop.
* Uses the model from config so the message satisfies the full interface.
*/
function createErrorMessage(error: unknown, config: AgentLoopConfig): AssistantMessage {
function createErrorMessage(
error: unknown,
config: AgentLoopConfig,
): AssistantMessage {
const msg = error instanceof Error ? error.message : String(error);
return {
role: "assistant",
@ -62,7 +65,10 @@ function createErrorMessage(error: unknown, config: AgentLoopConfig): AssistantM
/**
* Emit a message_start + message_end pair for a single message.
*/
function emitMessagePair(stream: EventStream<AgentEvent, AgentMessage[]>, message: AgentMessage): void {
function emitMessagePair(
stream: EventStream<AgentEvent, AgentMessage[]>,
message: AgentMessage,
): void {
stream.push({ type: "message_start", message });
stream.push({ type: "message_end", message });
}
@ -109,7 +115,14 @@ export function agentLoop(
}
try {
await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
await runLoop(
currentContext,
newMessages,
config,
signal,
stream,
streamFn,
);
} catch (error) {
emitErrorSequence(stream, createErrorMessage(error, config), newMessages);
}
@ -153,7 +166,14 @@ export function agentLoopContinue(
stream.push({ type: "turn_start" });
try {
await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
await runLoop(
currentContext,
newMessages,
config,
signal,
stream,
streamFn,
);
} catch (error) {
emitErrorSequence(stream, createErrorMessage(error, config), newMessages);
}
@ -182,7 +202,8 @@ async function runLoop(
): Promise<void> {
let firstTurn = true;
// Check for steering messages at start (user may have typed while waiting)
let pendingMessages: AgentMessage[] = (await config.getSteeringMessages?.()) || [];
let pendingMessages: AgentMessage[] =
(await config.getSteeringMessages?.()) || [];
// Track consecutive turns where ALL tool calls fail validation.
// When the LLM repeatedly emits tool calls with schema-overloaded or malformed
@ -216,12 +237,19 @@ async function runLoop(
// Stream assistant response
let message: AssistantMessage;
try {
message = await streamAssistantResponse(currentContext, config, signal, stream, streamFn);
message = await streamAssistantResponse(
currentContext,
config,
signal,
stream,
streamFn,
);
} catch (error) {
// Critical failure before stream started (e.g. getApiKey threw, credentials in
// backoff, network unavailable). Convert to a graceful error message so the
// agent loop can end cleanly instead of crashing with an unhandled rejection.
const errorText = error instanceof Error ? error.message : String(error);
const errorText =
error instanceof Error ? error.message : String(error);
message = {
role: "assistant",
content: [],
@ -259,13 +287,20 @@ async function runLoop(
// to the tool call so the UI can show the real stdout/stderr
// instead of a generic placeholder.
for (const tc of toolCalls as AgentToolCall[]) {
const externalResult = (tc as AgentToolCall & {
externalResult?: {
content?: Array<{ type: string; text?: string; data?: string; mimeType?: string }>;
details?: Record<string, unknown>;
isError?: boolean;
};
}).externalResult;
const externalResult = (
tc as AgentToolCall & {
externalResult?: {
content?: Array<{
type: string;
text?: string;
data?: string;
mimeType?: string;
}>;
details?: Record<string, unknown>;
isError?: boolean;
};
}
).externalResult;
stream.push({
type: "tool_execution_start",
toolCallId: tc.id,
@ -278,11 +313,13 @@ async function runLoop(
toolName: tc.name,
result: externalResult
? {
content: externalResult.content ?? [{ type: "text", text: "" }],
content: externalResult.content ?? [
{ type: "text", text: "" },
],
details: externalResult.details ?? {},
}
: {
content: [{ type: "text", text: "(executed by Claude Code)" }],
content: [{ type: "text", text: "(executed by provider)" }],
details: {},
},
isError: externalResult?.isError ?? false,
@ -327,7 +364,9 @@ async function runLoop(
consecutiveAllToolErrorTurns = 0;
}
if (consecutiveAllToolErrorTurns >= MAX_CONSECUTIVE_VALIDATION_FAILURES) {
if (
consecutiveAllToolErrorTurns >= MAX_CONSECUTIVE_VALIDATION_FAILURES
) {
// Force-stop: the LLM is stuck retrying broken tool calls.
// Emit the turn_end and terminate the agent loop cleanly.
stream.push({ type: "turn_end", message, toolResults });
@ -344,12 +383,17 @@ async function runLoop(
model: config.model.id,
usage: ZERO_USAGE,
stopReason: "error",
errorMessage: "Schema overload: consecutive tool validation failures exceeded cap",
errorMessage:
"Schema overload: consecutive tool validation failures exceeded cap",
timestamp: Date.now(),
};
emitMessagePair(stream, stopMessage);
newMessages.push(stopMessage);
stream.push({ type: "turn_end", message: stopMessage, toolResults: [] });
stream.push({
type: "turn_end",
message: stopMessage,
toolResults: [],
});
stream.push({ type: "agent_end", messages: newMessages });
stream.end(newMessages);
return;
@ -414,7 +458,9 @@ async function streamAssistantResponse(
// Resolve API key (important for expiring tokens)
const resolvedApiKey =
(config.getApiKey ? await config.getApiKey(config.model.provider) : undefined) || config.apiKey;
(config.getApiKey
? await config.getApiKey(config.model.provider)
: undefined) || config.apiKey;
const response = await streamFunction(config.model, llmContext, {
...config,
@ -503,11 +549,27 @@ async function executeToolCalls(
signal: AbortSignal | undefined,
stream: EventStream<AgentEvent, AgentMessage[]>,
): Promise<ToolExecutionResult> {
const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall") as AgentToolCall[];
const toolCalls = assistantMessage.content.filter(
(c) => c.type === "toolCall",
) as AgentToolCall[];
if (config.toolExecution === "sequential") {
return executeToolCallsSequential(currentContext, assistantMessage, toolCalls, config, signal, stream);
return executeToolCallsSequential(
currentContext,
assistantMessage,
toolCalls,
config,
signal,
stream,
);
}
return executeToolCallsParallel(currentContext, assistantMessage, toolCalls, config, signal, stream);
return executeToolCallsParallel(
currentContext,
assistantMessage,
toolCalls,
config,
signal,
stream,
);
}
async function executeToolCallsSequential(
@ -532,14 +594,31 @@ async function executeToolCallsSequential(
args: toolCall.arguments,
});
const preparation = await prepareToolCall(currentContext, assistantMessage, toolCall, config, signal);
const preparation = await prepareToolCall(
currentContext,
assistantMessage,
toolCall,
config,
signal,
);
if (preparation.kind === "immediate") {
if (preparation.isError) {
preparationErrorCount++;
}
results.push(emitToolCallOutcome(toolCall, preparation.result, preparation.isError, stream));
results.push(
emitToolCallOutcome(
toolCall,
preparation.result,
preparation.isError,
stream,
),
);
} else {
const executed = await executePreparedToolCall(preparation, signal, stream);
const executed = await executePreparedToolCall(
preparation,
signal,
stream,
);
results.push(
await finalizeExecutedToolCall(
currentContext,
@ -594,12 +673,25 @@ async function executeToolCallsParallel(
args: toolCall.arguments,
});
const preparation = await prepareToolCall(currentContext, assistantMessage, toolCall, config, signal);
const preparation = await prepareToolCall(
currentContext,
assistantMessage,
toolCall,
config,
signal,
);
if (preparation.kind === "immediate") {
if (preparation.isError) {
preparationErrorCount++;
}
results.push(emitToolCallOutcome(toolCall, preparation.result, preparation.isError, stream));
results.push(
emitToolCallOutcome(
toolCall,
preparation.result,
preparation.isError,
stream,
),
);
} else {
runnableCalls.push(preparation);
}
@ -610,13 +702,19 @@ async function executeToolCallsParallel(
steeringMessages = [...(steeringMessages ?? []), ...steering];
if (interruptOnSteering && hasUserSteeringMessage(steering)) {
for (const runnable of runnableCalls) {
results.push(skipToolCall(runnable.toolCall, stream, { emitStart: false }));
results.push(
skipToolCall(runnable.toolCall, stream, { emitStart: false }),
);
}
const remainingCalls = toolCalls.slice(index + 1);
for (const skipped of remainingCalls) {
results.push(skipToolCall(skipped, stream));
}
return { toolResults: results, steeringMessages, preparationErrorCount };
return {
toolResults: results,
steeringMessages,
preparationErrorCount,
};
}
}
}
@ -701,7 +799,9 @@ async function prepareToolCall(
if (beforeResult?.block) {
return {
kind: "immediate",
result: createErrorToolResult(beforeResult.reason || "Tool execution was blocked"),
result: createErrorToolResult(
beforeResult.reason || "Tool execution was blocked",
),
isError: true,
};
}
@ -715,7 +815,9 @@ async function prepareToolCall(
} catch (error) {
return {
kind: "immediate",
result: createErrorToolResult(error instanceof Error ? error.message : String(error)),
result: createErrorToolResult(
error instanceof Error ? error.message : String(error),
),
isError: true,
};
}
@ -744,7 +846,9 @@ async function executePreparedToolCall(
return { result, isError: false };
} catch (error) {
return {
result: createErrorToolResult(error instanceof Error ? error.message : String(error)),
result: createErrorToolResult(
error instanceof Error ? error.message : String(error),
),
isError: true,
};
}
@ -777,13 +881,22 @@ async function finalizeExecutedToolCall(
);
if (afterResult) {
result = {
content: afterResult.content !== undefined ? afterResult.content : result.content,
details: afterResult.details !== undefined ? afterResult.details : result.details,
content:
afterResult.content !== undefined
? afterResult.content
: result.content,
details:
afterResult.details !== undefined
? afterResult.details
: result.details,
};
isError = afterResult.isError !== undefined ? afterResult.isError : isError;
isError =
afterResult.isError !== undefined ? afterResult.isError : isError;
}
} catch (error) {
result = createErrorToolResult(error instanceof Error ? error.message : String(error));
result = createErrorToolResult(
error instanceof Error ? error.message : String(error),
);
isError = true;
}
}

View file

@ -1,9 +1,12 @@
// pi-coding-agent / CredentialCooldownError unit tests
// Copyright (c) 2026 Jeremy McSpadden <jeremy@fluxlabs.net>
import { describe, it } from 'vitest';
import assert from "node:assert/strict";
import { CredentialCooldownError } from "./sdk.js";
import { describe, it } from "vitest";
import {
CredentialCooldownError,
shouldUseExternalToolExecution,
} from "./sdk.js";
// ─── CredentialCooldownError ──────────────────────────────────────────────────
@ -57,7 +60,11 @@ describe("CredentialCooldownError", () => {
it("code property is readonly and always AUTH_COOLDOWN regardless of provider", () => {
for (const provider of ["anthropic", "openai", "google", "openrouter"]) {
const err = new CredentialCooldownError(provider);
assert.equal(err.code, "AUTH_COOLDOWN", `code should be AUTH_COOLDOWN for provider "${provider}"`);
assert.equal(
err.code,
"AUTH_COOLDOWN",
`code should be AUTH_COOLDOWN for provider "${provider}"`,
);
}
});
@ -82,8 +89,37 @@ describe("CredentialCooldownError", () => {
it("code property is detectable via plain object check (cross-process pattern)", () => {
const err = new CredentialCooldownError("anthropic", 15_000);
// Simulate cross-process serialization: only plain properties survive JSON round-trip
const plain = { code: err.code, retryAfterMs: err.retryAfterMs, message: err.message };
const plain = {
code: err.code,
retryAfterMs: err.retryAfterMs,
message: err.message,
};
assert.equal(plain.code, "AUTH_COOLDOWN");
assert.equal(plain.retryAfterMs, 15_000);
});
});
// ─── External Tool Execution Ownership ───────────────────────────────────────
describe("shouldUseExternalToolExecution", () => {
it("returns true for claude-code because its adapter can execute tools", () => {
assert.equal(
shouldUseExternalToolExecution({ provider: "claude-code" }),
true,
);
});
it("returns false for google-gemini-cli so Gemini tool calls stay in the Pi harness", () => {
assert.equal(
shouldUseExternalToolExecution({ provider: "google-gemini-cli" }),
false,
);
});
it("returns false for other CLI-style providers unless their adapter owns tools", () => {
assert.equal(
shouldUseExternalToolExecution({ provider: "openai-codex" }),
false,
);
});
});

View file

@ -1,5 +1,6 @@
import { existsSync } from "node:fs";
import { join } from "node:path";
import type { Model } from "@singularity-forge/pi-ai";
/**
* Lightweight PATH scan for the `claude` binary — no subprocess, no network.
@ -31,13 +32,38 @@ export class CredentialCooldownError extends Error {
this.retryAfterMs = retryAfterMs;
}
}
import { Agent, type AgentMessage, type ThinkingLevel } from "@singularity-forge/pi-agent-core";
import type { Message, Model } from "@singularity-forge/pi-ai";
/**
* Returns whether a provider executes tool calls inside its own adapter rather than
* returning them for the Pi harness to dispatch locally.
*
* Purpose: keep credential/auth transport modes separate from tool execution
* ownership so CLI-auth providers such as google-gemini-cli still run tools
* through Pi's local harness.
*
* Consumer: createAgentSession when configuring the core Agent loop.
*/
export function shouldUseExternalToolExecution(
model: Pick<Model<any>, "provider">,
): boolean {
// Only the claude-code adapter owns tool execution (see the external-result
// handling in runLoop); all other providers — including CLI-auth ones like
// google-gemini-cli — return tool calls for Pi's local harness to dispatch.
return model.provider === "claude-code";
}
import {
Agent,
type AgentMessage,
type ThinkingLevel,
} from "@singularity-forge/pi-agent-core";
import type { Message } from "@singularity-forge/pi-ai";
import { getAgentDir, getDocsPath } from "../config.js";
import { AgentSession } from "./agent-session.js";
import { AuthStorage } from "./auth-storage.js";
import { DEFAULT_THINKING_LEVEL } from "./defaults.js";
import type { ExtensionRunner, LoadExtensionsResult, ToolDefinition } from "./extensions/index.js";
import type {
ExtensionRunner,
LoadExtensionsResult,
ToolDefinition,
} from "./extensions/index.js";
import { convertToLlm } from "./messages.js";
import { ModelRegistry } from "./model-registry.js";
import { findInitialModel } from "./model-resolver.js";
@ -55,6 +81,9 @@ import {
createEditTool,
createFindTool,
createGrepTool,
createHashlineCodingTools,
createHashlineEditTool,
createHashlineReadTool,
createLsTool,
createReadOnlyTools,
createReadTool,
@ -65,9 +94,6 @@ import {
hashlineCodingTools,
hashlineEditTool,
hashlineReadTool,
createHashlineCodingTools,
createHashlineEditTool,
createHashlineReadTool,
lsTool,
readOnlyTools,
readTool,
@ -144,34 +170,34 @@ export type { Skill } from "./skills.js";
export type { Tool } from "./tools/index.js";
export {
// Pre-built tools (use process.cwd())
readTool,
bashTool,
editTool,
writeTool,
grepTool,
findTool,
lsTool,
codingTools,
readOnlyTools,
allTools as allBuiltInTools,
bashTool,
codingTools,
createBashTool,
// Tool factories (for custom cwd)
createCodingTools,
createEditTool,
createFindTool,
createGrepTool,
createHashlineCodingTools,
createHashlineEditTool,
createHashlineReadTool,
createLsTool,
createReadOnlyTools,
createReadTool,
createBashTool,
createEditTool,
createWriteTool,
createGrepTool,
createFindTool,
createLsTool,
editTool,
findTool,
grepTool,
// Hashline edit mode
hashlineCodingTools,
hashlineEditTool,
hashlineReadTool,
createHashlineCodingTools,
createHashlineEditTool,
createHashlineReadTool,
lsTool,
readOnlyTools,
// Pre-built tools (use process.cwd())
readTool,
writeTool,
};
// Helper Functions
@ -215,22 +241,32 @@ function getDefaultAgentDir(): string {
* });
* ```
*/
export async function createAgentSession(options: CreateAgentSessionOptions = {}): Promise<CreateAgentSessionResult> {
export async function createAgentSession(
options: CreateAgentSessionOptions = {},
): Promise<CreateAgentSessionResult> {
const cwd = options.cwd ?? process.cwd();
const agentDir = options.agentDir ?? getDefaultAgentDir();
let resourceLoader = options.resourceLoader;
// Use provided or create AuthStorage and ModelRegistry
const authPath = options.agentDir ? join(agentDir, "auth.json") : undefined;
const modelsPath = options.agentDir ? join(agentDir, "models.json") : undefined;
const modelsPath = options.agentDir
? join(agentDir, "models.json")
: undefined;
const authStorage = options.authStorage ?? AuthStorage.create(authPath);
const modelRegistry = options.modelRegistry ?? new ModelRegistry(authStorage, modelsPath);
const modelRegistry =
options.modelRegistry ?? new ModelRegistry(authStorage, modelsPath);
const settingsManager = options.settingsManager ?? SettingsManager.create(cwd, agentDir);
const settingsManager =
options.settingsManager ?? SettingsManager.create(cwd, agentDir);
const sessionManager = options.sessionManager ?? SessionManager.create(cwd);
if (!resourceLoader) {
resourceLoader = new DefaultResourceLoader({ cwd, agentDir, settingsManager });
resourceLoader = new DefaultResourceLoader({
cwd,
agentDir,
settingsManager,
});
await resourceLoader.reload();
time("resourceLoader.reload");
}
@ -240,7 +276,10 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
// findInitialModel() runs. bindCore() repeats this flush as a safety net
// for any late-arriving registrations.
const { runtime: extensionRuntime } = resourceLoader.getExtensions();
for (const { name, config } of extensionRuntime.pendingProviderRegistrations) {
for (const {
name,
config,
} of extensionRuntime.pendingProviderRegistrations) {
modelRegistry.registerProvider(name, config);
}
extensionRuntime.pendingProviderRegistrations = [];
@ -248,14 +287,19 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
// Check if session has existing data to restore
const existingSession = sessionManager.buildSessionContext();
const hasExistingSession = existingSession.messages.length > 0;
const hasThinkingEntry = sessionManager.getBranch().some((entry) => entry.type === "thinking_level_change");
const hasThinkingEntry = sessionManager
.getBranch()
.some((entry) => entry.type === "thinking_level_change");
let model = options.model;
let modelFallbackMessage: string | undefined;
// If session has data, try to restore model from it
if (!model && hasExistingSession && existingSession.model) {
const restoredModel = modelRegistry.find(existingSession.model.provider, existingSession.model.modelId);
const restoredModel = modelRegistry.find(
existingSession.model.provider,
existingSession.model.modelId,
);
if (restoredModel && (await modelRegistry.getApiKey(restoredModel))) {
model = restoredModel;
}
@ -268,7 +312,8 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
// are available in the registry before model resolution. Without this, findInitialModel()
// cannot find extension models and falls back to built-in providers (#3534).
const extensionsForModelResolution = resourceLoader.getExtensions();
for (const { name, config } of extensionsForModelResolution.runtime.pendingProviderRegistrations) {
for (const { name, config } of extensionsForModelResolution.runtime
.pendingProviderRegistrations) {
modelRegistry.registerProvider(name, config);
}
// Clear the queue so bindCore() doesn't re-register the same providers.
@ -303,7 +348,8 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
// Fall back to settings default
if (thinkingLevel === undefined) {
thinkingLevel = settingsManager.getDefaultThinkingLevel() ?? DEFAULT_THINKING_LEVEL;
thinkingLevel =
settingsManager.getDefaultThinkingLevel() ?? DEFAULT_THINKING_LEVEL;
}
// Clamp to model capabilities
@ -312,11 +358,23 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
}
const editMode = settingsManager.getEditMode();
const defaultActiveToolNames: ToolName[] = editMode === "hashline"
? ["hashline_read", "grep", "find", "ls", "bash", "hashline_edit", "write", "lsp"]
: ["read", "grep", "find", "ls", "bash", "edit", "write", "lsp"];
const defaultActiveToolNames: ToolName[] =
editMode === "hashline"
? [
"hashline_read",
"grep",
"find",
"ls",
"bash",
"hashline_edit",
"write",
"lsp",
]
: ["read", "grep", "find", "ls", "bash", "edit", "write", "lsp"];
const initialActiveToolNames: ToolName[] = options.tools
? options.tools.map((t) => t.name).filter((n): n is ToolName => n in allTools)
? options.tools
.map((t) => t.name)
.filter((n): n is ToolName => n in allTools)
: defaultActiveToolNames;
let agent: Agent;
@ -337,7 +395,12 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
if (hasImages) {
const filteredContent = content
.map((c) =>
c.type === "image" ? { type: "text" as const, text: "Image reading is disabled." } : c,
c.type === "image"
? {
type: "text" as const,
text: "Image reading is disabled.",
}
: c,
)
.filter(
(c, i, arr) =>
@ -347,7 +410,8 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
c.text === "Image reading is disabled." &&
i > 0 &&
arr[i - 1].type === "text" &&
(arr[i - 1] as { type: "text"; text: string }).text === "Image reading is disabled."
(arr[i - 1] as { type: "text"; text: string }).text ===
"Image reading is disabled."
),
);
return { ...msg, content: filteredContent };
@ -387,7 +451,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
transport: settingsManager.getTransport(),
thinkingBudgets: settingsManager.getThinkingBudgets(),
maxRetryDelayMs: settingsManager.getRetrySettings().maxDelayMs,
externalToolExecution: (m) => modelRegistry.getProviderAuthMode(m.provider) === "externalCli",
externalToolExecution: shouldUseExternalToolExecution,
getProviderOptions: async (currentModel) => {
if (currentModel.provider !== "claude-code") return undefined;
const runner = extensionRunnerRef.current;
@ -432,11 +496,12 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
// If credentials are in a cooldown window, wait for the earliest
// one to expire rather than using a fixed delay that's too short.
const backoffExpiry = modelRegistry.authStorage.getEarliestBackoffExpiry(resolvedProvider);
const backoffExpiry =
modelRegistry.authStorage.getEarliestBackoffExpiry(resolvedProvider);
if (backoffExpiry !== undefined) {
const waitMs = backoffExpiry - Date.now() + 500; // 500ms buffer
if (waitMs > 0 && waitMs <= maxCooldownWaitMs) {
await new Promise(resolve => setTimeout(resolve, waitMs));
await new Promise((resolve) => setTimeout(resolve, waitMs));
continue; // Retry immediately after cooldown clears
}
if (waitMs > maxCooldownWaitMs) {
@ -445,7 +510,9 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
}
// Standard exponential backoff for non-cooldown transient failures
await new Promise(resolve => setTimeout(resolve, baseDelayMs * attempt));
await new Promise((resolve) =>
setTimeout(resolve, baseDelayMs * attempt),
);
}
// All retries exhausted — throw descriptive error.
@ -465,7 +532,10 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
// Self-heal: strip the stale oauth entry so hasAuth() stops lying
// about anthropic being configured. This preserves any api_key
// credentials alongside it.
const removed = modelRegistry.authStorage.removeLegacyOAuthCredential(resolvedProvider);
const removed =
modelRegistry.authStorage.removeLegacyOAuthCredential(
resolvedProvider,
);
if (removed) {
console.warn(
`[auth] Removed unsupported Anthropic OAuth credential from auth.json (#3952).`,
@ -484,8 +554,10 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
`Set ANTHROPIC_API_KEY, run '/login' and paste an API key, or switch to a different provider.`,
);
}
const expiry = modelRegistry.authStorage.getEarliestBackoffExpiry(resolvedProvider);
const retryAfterMs = expiry !== undefined ? Math.max(0, expiry - Date.now()) : undefined;
const expiry =
modelRegistry.authStorage.getEarliestBackoffExpiry(resolvedProvider);
const retryAfterMs =
expiry !== undefined ? Math.max(0, expiry - Date.now()) : undefined;
throw new CredentialCooldownError(resolvedProvider, retryAfterMs);
}
const model = agent.state.model;
@ -493,9 +565,15 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
if (isOAuth) {
// If credentials exist but are all in a backoff window (quota / rate-limit),
// surface a specific message instead of the misleading "Authentication failed".
if (modelRegistry.authStorage.areAllCredentialsBackedOff(resolvedProvider)) {
const expiry = modelRegistry.authStorage.getEarliestBackoffExpiry(resolvedProvider);
const retryAfterMs = expiry !== undefined ? Math.max(0, expiry - Date.now()) : undefined;
if (
modelRegistry.authStorage.areAllCredentialsBackedOff(resolvedProvider)
) {
const expiry =
modelRegistry.authStorage.getEarliestBackoffExpiry(
resolvedProvider,
);
const retryAfterMs =
expiry !== undefined ? Math.max(0, expiry - Date.now()) : undefined;
throw new CredentialCooldownError(resolvedProvider, retryAfterMs);
}
throw new Error(