refactor: deduplicate error emission and message patterns in agent-core (#1444)
- Extract emitMessagePair() to consolidate 6 message_start/message_end push pairs in agent-loop.ts - Extract emitErrorSequence() to deduplicate identical catch blocks in agentLoop and agentLoopContinue - Export ZERO_USAGE constant and reuse it in agent.ts instead of inline object literals - Merge identical message_start/message_update switch cases in Agent._runLoop - Extract Agent._updatePendingToolCalls() to consolidate tool_execution_start/end Set mutation
This commit is contained in:
parent
54b1446fb2
commit
da2af65971
2 changed files with 42 additions and 50 deletions
|
|
@ -22,7 +22,7 @@ import type {
|
|||
StreamFn,
|
||||
} from "./types.js";
|
||||
|
||||
const ZERO_USAGE = {
|
||||
export const ZERO_USAGE = {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
|
|
@ -50,6 +50,29 @@ function createErrorMessage(error: unknown, config: AgentLoopConfig): AssistantM
|
|||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Emit a message_start + message_end pair for a single message.
|
||||
*/
|
||||
function emitMessagePair(stream: EventStream<AgentEvent, AgentMessage[]>, message: AgentMessage): void {
|
||||
stream.push({ type: "message_start", message });
|
||||
stream.push({ type: "message_end", message });
|
||||
}
|
||||
|
||||
/**
|
||||
* Emit the standard error sequence when the outer agent loop catches an error.
|
||||
* Pushes message_start/end, turn_end, agent_end, then closes the stream.
|
||||
*/
|
||||
function emitErrorSequence(
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
errMsg: AssistantMessage,
|
||||
newMessages: AgentMessage[],
|
||||
): void {
|
||||
emitMessagePair(stream, errMsg);
|
||||
stream.push({ type: "turn_end", message: errMsg, toolResults: [] });
|
||||
stream.push({ type: "agent_end", messages: [...newMessages, errMsg] });
|
||||
stream.end([...newMessages, errMsg]);
|
||||
}
|
||||
|
||||
/**
|
||||
* Start an agent loop with a new prompt message.
|
||||
* The prompt is added to the context and events are emitted for it.
|
||||
|
|
@ -73,19 +96,13 @@ export function agentLoop(
|
|||
stream.push({ type: "agent_start" });
|
||||
stream.push({ type: "turn_start" });
|
||||
for (const prompt of prompts) {
|
||||
stream.push({ type: "message_start", message: prompt });
|
||||
stream.push({ type: "message_end", message: prompt });
|
||||
emitMessagePair(stream, prompt);
|
||||
}
|
||||
|
||||
try {
|
||||
await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
|
||||
} catch (error) {
|
||||
const errMsg = createErrorMessage(error, config);
|
||||
stream.push({ type: "message_start", message: errMsg });
|
||||
stream.push({ type: "message_end", message: errMsg });
|
||||
stream.push({ type: "turn_end", message: errMsg, toolResults: [] });
|
||||
stream.push({ type: "agent_end", messages: [...newMessages, errMsg] });
|
||||
stream.end([...newMessages, errMsg]);
|
||||
emitErrorSequence(stream, createErrorMessage(error, config), newMessages);
|
||||
}
|
||||
})();
|
||||
|
||||
|
|
@ -126,12 +143,7 @@ export function agentLoopContinue(
|
|||
try {
|
||||
await runLoop(currentContext, newMessages, config, signal, stream, streamFn);
|
||||
} catch (error) {
|
||||
const errMsg = createErrorMessage(error, config);
|
||||
stream.push({ type: "message_start", message: errMsg });
|
||||
stream.push({ type: "message_end", message: errMsg });
|
||||
stream.push({ type: "turn_end", message: errMsg, toolResults: [] });
|
||||
stream.push({ type: "agent_end", messages: [...newMessages, errMsg] });
|
||||
stream.end([...newMessages, errMsg]);
|
||||
emitErrorSequence(stream, createErrorMessage(error, config), newMessages);
|
||||
}
|
||||
})();
|
||||
|
||||
|
|
@ -176,8 +188,7 @@ async function runLoop(
|
|||
// Process pending messages (inject before next assistant response)
|
||||
if (pendingMessages.length > 0) {
|
||||
for (const message of pendingMessages) {
|
||||
stream.push({ type: "message_start", message });
|
||||
stream.push({ type: "message_end", message });
|
||||
emitMessagePair(stream, message);
|
||||
currentContext.messages.push(message);
|
||||
newMessages.push(message);
|
||||
}
|
||||
|
|
@ -199,14 +210,7 @@ async function runLoop(
|
|||
api: config.model.api,
|
||||
provider: config.model.provider,
|
||||
model: config.model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
usage: ZERO_USAGE,
|
||||
stopReason: signal?.aborted ? "aborted" : "error",
|
||||
errorMessage: errorText,
|
||||
timestamp: Date.now(),
|
||||
|
|
@ -676,8 +680,7 @@ function emitToolCallOutcome(
|
|||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
stream.push({ type: "message_start", message: toolResultMessage });
|
||||
stream.push({ type: "message_end", message: toolResultMessage });
|
||||
emitMessagePair(stream, toolResultMessage);
|
||||
return toolResultMessage;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ import {
|
|||
type ThinkingBudgets,
|
||||
type Transport,
|
||||
} from "@gsd/pi-ai";
|
||||
import { agentLoop, agentLoopContinue } from "./agent-loop.js";
|
||||
import { agentLoop, agentLoopContinue, ZERO_USAGE } from "./agent-loop.js";
|
||||
import type {
|
||||
AgentContext,
|
||||
AgentEvent,
|
||||
|
|
@ -489,10 +489,6 @@ export class Agent {
|
|||
// Update internal state based on events
|
||||
switch (event.type) {
|
||||
case "message_start":
|
||||
partial = event.message;
|
||||
this._state.streamMessage = event.message;
|
||||
break;
|
||||
|
||||
case "message_update":
|
||||
partial = event.message;
|
||||
this._state.streamMessage = event.message;
|
||||
|
|
@ -504,19 +500,13 @@ export class Agent {
|
|||
this.appendMessage(event.message);
|
||||
break;
|
||||
|
||||
case "tool_execution_start": {
|
||||
const s = new Set(this._state.pendingToolCalls);
|
||||
s.add(event.toolCallId);
|
||||
this._state.pendingToolCalls = s;
|
||||
case "tool_execution_start":
|
||||
this._updatePendingToolCalls("add", event.toolCallId);
|
||||
break;
|
||||
}
|
||||
|
||||
case "tool_execution_end": {
|
||||
const s = new Set(this._state.pendingToolCalls);
|
||||
s.delete(event.toolCallId);
|
||||
this._state.pendingToolCalls = s;
|
||||
case "tool_execution_end":
|
||||
this._updatePendingToolCalls("delete", event.toolCallId);
|
||||
break;
|
||||
}
|
||||
|
||||
case "turn_end":
|
||||
if (event.message.role === "assistant" && (event.message as any).errorMessage) {
|
||||
|
|
@ -557,14 +547,7 @@ export class Agent {
|
|||
api: model.api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
usage: ZERO_USAGE,
|
||||
stopReason: this.abortController?.signal.aborted ? "aborted" : "error",
|
||||
errorMessage: err?.message || String(err),
|
||||
timestamp: Date.now(),
|
||||
|
|
@ -584,6 +567,12 @@ export class Agent {
|
|||
}
|
||||
}
|
||||
|
||||
private _updatePendingToolCalls(action: "add" | "delete", id: string): void {
|
||||
const s = new Set(this._state.pendingToolCalls);
|
||||
s[action](id);
|
||||
this._state.pendingToolCalls = s;
|
||||
}
|
||||
|
||||
private emit(e: AgentEvent) {
|
||||
for (const listener of this.listeners) {
|
||||
listener(e);
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue