From 3fed189e006edd92352daa3ab4d2150694b96c6d Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Sat, 14 Mar 2026 21:02:43 -0600 Subject: [PATCH] feat(pi-agent-core): parallel tool calling with before/after hooks (#427) * Initial plan * feat(pi-agent-core): add parallel tool calling support with beforeToolCall/afterToolCall hooks Co-authored-by: glittercowboy <186001655+glittercowboy@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: glittercowboy <186001655+glittercowboy@users.noreply.github.com> --- package-lock.json | 4 +- packages/pi-agent-core/src/agent-loop.ts | 350 ++++++++++++++++++----- packages/pi-agent-core/src/types.ts | 98 +++++++ 3 files changed, 380 insertions(+), 72 deletions(-) diff --git a/package-lock.json b/package-lock.json index c24c8743b..54d47ea10 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "gsd-pi", - "version": "2.10.12", + "version": "2.11.0", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "gsd-pi", - "version": "2.10.12", + "version": "2.11.0", "hasInstallScript": true, "license": "MIT", "workspaces": [ diff --git a/packages/pi-agent-core/src/agent-loop.ts b/packages/pi-agent-core/src/agent-loop.ts index 8dee70a08..b7ade645b 100644 --- a/packages/pi-agent-core/src/agent-loop.ts +++ b/packages/pi-agent-core/src/agent-loop.ts @@ -17,6 +17,7 @@ import type { AgentLoopConfig, AgentMessage, AgentTool, + AgentToolCall, AgentToolResult, StreamFn, } from "./types.js"; @@ -230,11 +231,11 @@ async function runLoop( const toolResults: ToolResultMessage[] = []; if (hasMoreToolCalls) { const toolExecution = await executeToolCalls( - currentContext.tools, + currentContext, message, + config, signal, stream, - config.getSteeringMessages, ); toolResults.push(...toolExecution.toolResults); steeringAfterTools = toolExecution.steeringMessages ?? null; @@ -367,20 +368,32 @@ async function streamAssistantResponse( * Execute tool calls from an assistant message. */ async function executeToolCalls( - tools: AgentTool[] | undefined, + currentContext: AgentContext, assistantMessage: AssistantMessage, + config: AgentLoopConfig, + signal: AbortSignal | undefined, + stream: EventStream, +): Promise<{ toolResults: ToolResultMessage[]; steeringMessages?: AgentMessage[] }> { + const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall") as AgentToolCall[]; + if (config.toolExecution === "sequential") { + return executeToolCallsSequential(currentContext, assistantMessage, toolCalls, config, signal, stream); + } + return executeToolCallsParallel(currentContext, assistantMessage, toolCalls, config, signal, stream); +} + +async function executeToolCallsSequential( + currentContext: AgentContext, + assistantMessage: AssistantMessage, + toolCalls: AgentToolCall[], + config: AgentLoopConfig, signal: AbortSignal | undefined, stream: EventStream, - getSteeringMessages?: AgentLoopConfig["getSteeringMessages"], ): Promise<{ toolResults: ToolResultMessage[]; steeringMessages?: AgentMessage[] }> { - const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall"); const results: ToolResultMessage[] = []; let steeringMessages: AgentMessage[] | undefined; for (let index = 0; index < toolCalls.length; index++) { const toolCall = toolCalls[index]; - const tool = tools?.find((t) => t.name === toolCall.name); - stream.push({ type: "tool_execution_start", toolCallId: toolCall.id, @@ -388,56 +401,26 @@ async function executeToolCalls( args: toolCall.arguments, }); - let result: AgentToolResult; - let isError = false; - - try { - if (!tool) throw new Error(`Tool ${toolCall.name} not found`); - - const validatedArgs = validateToolArguments(tool, toolCall); - - result = await tool.execute(toolCall.id, validatedArgs, signal, (partialResult) => { - stream.push({ - type: "tool_execution_update", - toolCallId: toolCall.id, - toolName: toolCall.name, - args: toolCall.arguments, - partialResult, - }); - }); - } catch (e) { - result = { - content: [{ type: "text", text: e instanceof Error ? e.message : String(e) }], - details: {}, - }; - isError = true; + const preparation = await prepareToolCall(currentContext, assistantMessage, toolCall, config, signal); + if (preparation.kind === "immediate") { + results.push(emitToolCallOutcome(toolCall, preparation.result, preparation.isError, stream)); + } else { + const executed = await executePreparedToolCall(preparation, signal, stream); + results.push( + await finalizeExecutedToolCall( + currentContext, + assistantMessage, + preparation, + executed, + config, + signal, + stream, + ), + ); } - stream.push({ - type: "tool_execution_end", - toolCallId: toolCall.id, - toolName: toolCall.name, - result, - isError, - }); - - const toolResultMessage: ToolResultMessage = { - role: "toolResult", - toolCallId: toolCall.id, - toolName: toolCall.name, - content: result.content, - details: result.details, - isError, - timestamp: Date.now(), - }; - - results.push(toolResultMessage); - stream.push({ type: "message_start", message: toolResultMessage }); - stream.push({ type: "message_end", message: toolResultMessage }); - - // Check for steering messages - skip remaining tools if user interrupted - if (getSteeringMessages) { - const steering = await getSteeringMessages(); + if (config.getSteeringMessages) { + const steering = await config.getSteeringMessages(); if (steering.length > 0) { steeringMessages = steering; const remainingCalls = toolCalls.slice(index + 1); @@ -452,27 +435,233 @@ async function executeToolCalls( return { toolResults: results, steeringMessages }; } -function skipToolCall( - toolCall: Extract, +async function executeToolCallsParallel( + currentContext: AgentContext, + assistantMessage: AssistantMessage, + toolCalls: AgentToolCall[], + config: AgentLoopConfig, + signal: AbortSignal | undefined, stream: EventStream, -): ToolResultMessage { - const result: AgentToolResult = { - content: [{ type: "text", text: "Skipped due to queued user message." }], +): Promise<{ toolResults: ToolResultMessage[]; steeringMessages?: AgentMessage[] }> { + const results: ToolResultMessage[] = []; + const runnableCalls: PreparedToolCall[] = []; + let steeringMessages: AgentMessage[] | undefined; + + for (let index = 0; index < toolCalls.length; index++) { + const toolCall = toolCalls[index]; + stream.push({ + type: "tool_execution_start", + toolCallId: toolCall.id, + toolName: toolCall.name, + args: toolCall.arguments, + }); + + const preparation = await prepareToolCall(currentContext, assistantMessage, toolCall, config, signal); + if (preparation.kind === "immediate") { + results.push(emitToolCallOutcome(toolCall, preparation.result, preparation.isError, stream)); + } else { + runnableCalls.push(preparation); + } + + if (config.getSteeringMessages) { + const steering = await config.getSteeringMessages(); + if (steering.length > 0) { + steeringMessages = steering; + for (const runnable of runnableCalls) { + results.push(skipToolCall(runnable.toolCall, stream, { emitStart: false })); + } + const remainingCalls = toolCalls.slice(index + 1); + for (const skipped of remainingCalls) { + results.push(skipToolCall(skipped, stream)); + } + return { toolResults: results, steeringMessages }; + } + } + } + + const runningCalls = runnableCalls.map((prepared) => ({ + prepared, + execution: executePreparedToolCall(prepared, signal, stream), + })); + + for (const running of runningCalls) { + const executed = await running.execution; + results.push( + await finalizeExecutedToolCall( + currentContext, + assistantMessage, + running.prepared, + executed, + config, + signal, + stream, + ), + ); + } + + if (!steeringMessages && config.getSteeringMessages) { + const steering = await config.getSteeringMessages(); + if (steering.length > 0) { + steeringMessages = steering; + } + } + + return { toolResults: results, steeringMessages }; +} + +type PreparedToolCall = { + kind: "prepared"; + toolCall: AgentToolCall; + tool: AgentTool; + args: unknown; +}; + +type ImmediateToolCallOutcome = { + kind: "immediate"; + result: AgentToolResult; + isError: boolean; +}; + +type ExecutedToolCallOutcome = { + result: AgentToolResult; + isError: boolean; +}; + +async function prepareToolCall( + currentContext: AgentContext, + assistantMessage: AssistantMessage, + toolCall: AgentToolCall, + config: AgentLoopConfig, + signal: AbortSignal | undefined, +): Promise { + const tool = currentContext.tools?.find((t) => t.name === toolCall.name); + if (!tool) { + return { + kind: "immediate", + result: createErrorToolResult(`Tool ${toolCall.name} not found`), + isError: true, + }; + } + + try { + const validatedArgs = validateToolArguments(tool, toolCall); + if (config.beforeToolCall) { + const beforeResult = await config.beforeToolCall( + { + assistantMessage, + toolCall, + args: validatedArgs, + context: currentContext, + }, + signal, + ); + if (beforeResult?.block) { + return { + kind: "immediate", + result: createErrorToolResult(beforeResult.reason || "Tool execution was blocked"), + isError: true, + }; + } + } + return { + kind: "prepared", + toolCall, + tool, + args: validatedArgs, + }; + } catch (error) { + return { + kind: "immediate", + result: createErrorToolResult(error instanceof Error ? error.message : String(error)), + isError: true, + }; + } +} + +async function executePreparedToolCall( + prepared: PreparedToolCall, + signal: AbortSignal | undefined, + stream: EventStream, +): Promise { + try { + const result = await prepared.tool.execute( + prepared.toolCall.id, + prepared.args as never, + signal, + (partialResult) => { + stream.push({ + type: "tool_execution_update", + toolCallId: prepared.toolCall.id, + toolName: prepared.toolCall.name, + args: prepared.toolCall.arguments, + partialResult, + }); + }, + ); + return { result, isError: false }; + } catch (error) { + return { + result: createErrorToolResult(error instanceof Error ? error.message : String(error)), + isError: true, + }; + } +} + +async function finalizeExecutedToolCall( + currentContext: AgentContext, + assistantMessage: AssistantMessage, + prepared: PreparedToolCall, + executed: ExecutedToolCallOutcome, + config: AgentLoopConfig, + signal: AbortSignal | undefined, + stream: EventStream, +): Promise { + let result = executed.result; + let isError = executed.isError; + + if (config.afterToolCall) { + const afterResult = await config.afterToolCall( + { + assistantMessage, + toolCall: prepared.toolCall, + args: prepared.args, + result, + isError, + context: currentContext, + }, + signal, + ); + if (afterResult) { + result = { + content: afterResult.content !== undefined ? afterResult.content : result.content, + details: afterResult.details !== undefined ? afterResult.details : result.details, + }; + isError = afterResult.isError !== undefined ? afterResult.isError : isError; + } + } + + return emitToolCallOutcome(prepared.toolCall, result, isError, stream); +} + +function createErrorToolResult(message: string): AgentToolResult { + return { + content: [{ type: "text", text: message }], details: {}, }; +} - stream.push({ - type: "tool_execution_start", - toolCallId: toolCall.id, - toolName: toolCall.name, - args: toolCall.arguments, - }); +function emitToolCallOutcome( + toolCall: AgentToolCall, + result: AgentToolResult, + isError: boolean, + stream: EventStream, +): ToolResultMessage { stream.push({ type: "tool_execution_end", toolCallId: toolCall.id, toolName: toolCall.name, result, - isError: true, + isError, }); const toolResultMessage: ToolResultMessage = { @@ -480,13 +669,34 @@ function skipToolCall( toolCallId: toolCall.id, toolName: toolCall.name, content: result.content, - details: {}, - isError: true, + details: result.details, + isError, timestamp: Date.now(), }; stream.push({ type: "message_start", message: toolResultMessage }); stream.push({ type: "message_end", message: toolResultMessage }); - return toolResultMessage; } + +function skipToolCall( + toolCall: AgentToolCall, + stream: EventStream, + options?: { emitStart?: boolean }, +): ToolResultMessage { + const result: AgentToolResult = { + content: [{ type: "text", text: "Skipped due to queued user message." }], + details: {}, + }; + + if (options?.emitStart !== false) { + stream.push({ + type: "tool_execution_start", + toolCallId: toolCall.id, + toolName: toolCall.name, + args: toolCall.arguments, + }); + } + + return emitToolCallOutcome(toolCall, result, true, stream); +} diff --git a/packages/pi-agent-core/src/types.ts b/packages/pi-agent-core/src/types.ts index a1d5a0d4b..cfeba8895 100644 --- a/packages/pi-agent-core/src/types.ts +++ b/packages/pi-agent-core/src/types.ts @@ -1,4 +1,5 @@ import type { + AssistantMessage, AssistantMessageEvent, ImageContent, Message, @@ -16,6 +17,73 @@ export type StreamFn = ( ...args: Parameters ) => ReturnType | Promise>; +/** + * Configuration for how tool calls from a single assistant message are executed. + * + * - "sequential": each tool call is prepared, executed, and finalized before the next one starts. + * - "parallel": tool calls are prepared sequentially, then allowed tools execute concurrently. + * Final tool results are still emitted in assistant source order. + */ +export type ToolExecutionMode = "sequential" | "parallel"; + +/** A single tool call content block emitted by an assistant message. */ +export type AgentToolCall = Extract; + +/** + * Result returned from `beforeToolCall`. + * + * Returning `{ block: true }` prevents the tool from executing. The loop emits an error tool result instead. + * `reason` becomes the text shown in that error result. If omitted, a default blocked message is used. + */ +export interface BeforeToolCallResult { + block?: boolean; + reason?: string; +} + +/** + * Partial override returned from `afterToolCall`. + * + * Merge semantics are field-by-field: + * - `content`: if provided, replaces the tool result content array in full + * - `details`: if provided, replaces the tool result details value in full + * - `isError`: if provided, replaces the tool result error flag + * + * Omitted fields keep the original executed tool result values. + */ +export interface AfterToolCallResult { + content?: (TextContent | ImageContent)[]; + details?: unknown; + isError?: boolean; +} + +/** Context passed to `beforeToolCall`. */ +export interface BeforeToolCallContext { + /** The assistant message that requested the tool call. */ + assistantMessage: AssistantMessage; + /** The raw tool call block from `assistantMessage.content`. */ + toolCall: AgentToolCall; + /** Validated tool arguments for the target tool schema. */ + args: unknown; + /** Current agent context at the time the tool call is prepared. */ + context: AgentContext; +} + +/** Context passed to `afterToolCall`. */ +export interface AfterToolCallContext { + /** The assistant message that requested the tool call. */ + assistantMessage: AssistantMessage; + /** The raw tool call block from `assistantMessage.content`. */ + toolCall: AgentToolCall; + /** Validated tool arguments for the target tool schema. */ + args: unknown; + /** The executed tool result before any `afterToolCall` overrides are applied. */ + result: AgentToolResult; + /** Whether the executed tool result is currently treated as an error. */ + isError: boolean; + /** Current agent context at the time the tool call is finalized. */ + context: AgentContext; +} + /** * Configuration for the agent loop. */ @@ -95,6 +163,36 @@ export interface AgentLoopConfig extends SimpleStreamOptions { * Use this for follow-up messages that should wait until the agent finishes. */ getFollowUpMessages?: () => Promise; + + /** + * Tool execution mode. + * - "sequential": execute tool calls one by one + * - "parallel": preflight tool calls sequentially, then execute allowed tools concurrently + * + * Default: "parallel" + */ + toolExecution?: ToolExecutionMode; + + /** + * Called before a tool is executed, after arguments have been validated. + * + * Return `{ block: true }` to prevent execution. The loop emits an error tool result instead. + * The hook receives the agent abort signal and is responsible for honoring it. + */ + beforeToolCall?: (context: BeforeToolCallContext, signal?: AbortSignal) => Promise; + + /** + * Called after a tool finishes executing, before final tool events are emitted. + * + * Return an `AfterToolCallResult` to override parts of the executed tool result: + * - `content` replaces the full content array + * - `details` replaces the full details payload + * - `isError` replaces the error flag + * + * Any omitted fields keep their original values. No deep merge is performed. + * The hook receives the agent abort signal and is responsible for honoring it. + */ + afterToolCall?: (context: AfterToolCallContext, signal?: AbortSignal) => Promise; } /**