diff --git a/packages/pi-agent-core/src/agent-loop.ts b/packages/pi-agent-core/src/agent-loop.ts index fa05a0eff..436f7b291 100644 --- a/packages/pi-agent-core/src/agent-loop.ts +++ b/packages/pi-agent-core/src/agent-loop.ts @@ -22,7 +22,7 @@ import type { StreamFn, } from "./types.js"; -const ZERO_USAGE = { +export const ZERO_USAGE = { input: 0, output: 0, cacheRead: 0, @@ -50,6 +50,29 @@ function createErrorMessage(error: unknown, config: AgentLoopConfig): AssistantM }; } +/** + * Emit a message_start + message_end pair for a single message. + */ +function emitMessagePair(stream: EventStream, message: AgentMessage): void { + stream.push({ type: "message_start", message }); + stream.push({ type: "message_end", message }); +} + +/** + * Emit the standard error sequence when the outer agent loop catches an error. + * Pushes message_start/end, turn_end, agent_end, then closes the stream. + */ +function emitErrorSequence( + stream: EventStream, + errMsg: AssistantMessage, + newMessages: AgentMessage[], +): void { + emitMessagePair(stream, errMsg); + stream.push({ type: "turn_end", message: errMsg, toolResults: [] }); + stream.push({ type: "agent_end", messages: [...newMessages, errMsg] }); + stream.end([...newMessages, errMsg]); +} + /** * Start an agent loop with a new prompt message. * The prompt is added to the context and events are emitted for it. @@ -73,19 +96,13 @@ export function agentLoop( stream.push({ type: "agent_start" }); stream.push({ type: "turn_start" }); for (const prompt of prompts) { - stream.push({ type: "message_start", message: prompt }); - stream.push({ type: "message_end", message: prompt }); + emitMessagePair(stream, prompt); } try { await runLoop(currentContext, newMessages, config, signal, stream, streamFn); } catch (error) { - const errMsg = createErrorMessage(error, config); - stream.push({ type: "message_start", message: errMsg }); - stream.push({ type: "message_end", message: errMsg }); - stream.push({ type: "turn_end", message: errMsg, toolResults: [] }); - stream.push({ type: "agent_end", messages: [...newMessages, errMsg] }); - stream.end([...newMessages, errMsg]); + emitErrorSequence(stream, createErrorMessage(error, config), newMessages); } })(); @@ -126,12 +143,7 @@ export function agentLoopContinue( try { await runLoop(currentContext, newMessages, config, signal, stream, streamFn); } catch (error) { - const errMsg = createErrorMessage(error, config); - stream.push({ type: "message_start", message: errMsg }); - stream.push({ type: "message_end", message: errMsg }); - stream.push({ type: "turn_end", message: errMsg, toolResults: [] }); - stream.push({ type: "agent_end", messages: [...newMessages, errMsg] }); - stream.end([...newMessages, errMsg]); + emitErrorSequence(stream, createErrorMessage(error, config), newMessages); } })(); @@ -176,8 +188,7 @@ async function runLoop( // Process pending messages (inject before next assistant response) if (pendingMessages.length > 0) { for (const message of pendingMessages) { - stream.push({ type: "message_start", message }); - stream.push({ type: "message_end", message }); + emitMessagePair(stream, message); currentContext.messages.push(message); newMessages.push(message); } @@ -199,14 +210,7 @@ async function runLoop( api: config.model.api, provider: config.model.provider, model: config.model.id, - usage: { - input: 0, - output: 0, - cacheRead: 0, - cacheWrite: 0, - totalTokens: 0, - cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, - }, + usage: ZERO_USAGE, stopReason: signal?.aborted ? "aborted" : "error", errorMessage: errorText, timestamp: Date.now(), @@ -676,8 +680,7 @@ function emitToolCallOutcome( timestamp: Date.now(), }; - stream.push({ type: "message_start", message: toolResultMessage }); - stream.push({ type: "message_end", message: toolResultMessage }); + emitMessagePair(stream, toolResultMessage); return toolResultMessage; } diff --git a/packages/pi-agent-core/src/agent.ts b/packages/pi-agent-core/src/agent.ts index 4b9711be9..14a0a33ac 100644 --- a/packages/pi-agent-core/src/agent.ts +++ b/packages/pi-agent-core/src/agent.ts @@ -14,7 +14,7 @@ import { type ThinkingBudgets, type Transport, } from "@gsd/pi-ai"; -import { agentLoop, agentLoopContinue } from "./agent-loop.js"; +import { agentLoop, agentLoopContinue, ZERO_USAGE } from "./agent-loop.js"; import type { AgentContext, AgentEvent, @@ -489,10 +489,6 @@ export class Agent { // Update internal state based on events switch (event.type) { case "message_start": - partial = event.message; - this._state.streamMessage = event.message; - break; - case "message_update": partial = event.message; this._state.streamMessage = event.message; @@ -504,19 +500,13 @@ export class Agent { this.appendMessage(event.message); break; - case "tool_execution_start": { - const s = new Set(this._state.pendingToolCalls); - s.add(event.toolCallId); - this._state.pendingToolCalls = s; + case "tool_execution_start": + this._updatePendingToolCalls("add", event.toolCallId); break; - } - case "tool_execution_end": { - const s = new Set(this._state.pendingToolCalls); - s.delete(event.toolCallId); - this._state.pendingToolCalls = s; + case "tool_execution_end": + this._updatePendingToolCalls("delete", event.toolCallId); break; - } case "turn_end": if (event.message.role === "assistant" && (event.message as any).errorMessage) { @@ -557,14 +547,7 @@ export class Agent { api: model.api, provider: model.provider, model: model.id, - usage: { - input: 0, - output: 0, - cacheRead: 0, - cacheWrite: 0, - totalTokens: 0, - cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, - }, + usage: ZERO_USAGE, stopReason: this.abortController?.signal.aborted ? "aborted" : "error", errorMessage: err?.message || String(err), timestamp: Date.now(), @@ -584,6 +567,12 @@ export class Agent { } } + private _updatePendingToolCalls(action: "add" | "delete", id: string): void { + const s = new Set(this._state.pendingToolCalls); + s[action](id); + this._state.pendingToolCalls = s; + } + private emit(e: AgentEvent) { for (const listener of this.listeners) { listener(e);