feat(pi-agent-core): parallel tool calling with before/after hooks (#427)

* Initial plan

* feat(pi-agent-core): add parallel tool calling support with beforeToolCall/afterToolCall hooks

Co-authored-by: glittercowboy <186001655+glittercowboy@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: glittercowboy <186001655+glittercowboy@users.noreply.github.com>
This commit is contained in:
Copilot 2026-03-14 21:02:43 -06:00 committed by GitHub
parent 3c931b2e19
commit 3fed189e00
3 changed files with 380 additions and 72 deletions

4
package-lock.json generated
View file

@ -1,12 +1,12 @@
{
"name": "gsd-pi",
"version": "2.10.12",
"version": "2.11.0",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "gsd-pi",
"version": "2.10.12",
"version": "2.11.0",
"hasInstallScript": true,
"license": "MIT",
"workspaces": [

View file

@ -17,6 +17,7 @@ import type {
AgentLoopConfig,
AgentMessage,
AgentTool,
AgentToolCall,
AgentToolResult,
StreamFn,
} from "./types.js";
@ -230,11 +231,11 @@ async function runLoop(
const toolResults: ToolResultMessage[] = [];
if (hasMoreToolCalls) {
const toolExecution = await executeToolCalls(
currentContext.tools,
currentContext,
message,
config,
signal,
stream,
config.getSteeringMessages,
);
toolResults.push(...toolExecution.toolResults);
steeringAfterTools = toolExecution.steeringMessages ?? null;
@ -367,20 +368,32 @@ async function streamAssistantResponse(
* Execute tool calls from an assistant message.
*/
async function executeToolCalls(
tools: AgentTool<any>[] | undefined,
currentContext: AgentContext,
assistantMessage: AssistantMessage,
config: AgentLoopConfig,
signal: AbortSignal | undefined,
stream: EventStream<AgentEvent, AgentMessage[]>,
): Promise<{ toolResults: ToolResultMessage[]; steeringMessages?: AgentMessage[] }> {
const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall") as AgentToolCall[];
if (config.toolExecution === "sequential") {
return executeToolCallsSequential(currentContext, assistantMessage, toolCalls, config, signal, stream);
}
return executeToolCallsParallel(currentContext, assistantMessage, toolCalls, config, signal, stream);
}
async function executeToolCallsSequential(
currentContext: AgentContext,
assistantMessage: AssistantMessage,
toolCalls: AgentToolCall[],
config: AgentLoopConfig,
signal: AbortSignal | undefined,
stream: EventStream<AgentEvent, AgentMessage[]>,
getSteeringMessages?: AgentLoopConfig["getSteeringMessages"],
): Promise<{ toolResults: ToolResultMessage[]; steeringMessages?: AgentMessage[] }> {
const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall");
const results: ToolResultMessage[] = [];
let steeringMessages: AgentMessage[] | undefined;
for (let index = 0; index < toolCalls.length; index++) {
const toolCall = toolCalls[index];
const tool = tools?.find((t) => t.name === toolCall.name);
stream.push({
type: "tool_execution_start",
toolCallId: toolCall.id,
@ -388,56 +401,26 @@ async function executeToolCalls(
args: toolCall.arguments,
});
let result: AgentToolResult<any>;
let isError = false;
try {
if (!tool) throw new Error(`Tool ${toolCall.name} not found`);
const validatedArgs = validateToolArguments(tool, toolCall);
result = await tool.execute(toolCall.id, validatedArgs, signal, (partialResult) => {
stream.push({
type: "tool_execution_update",
toolCallId: toolCall.id,
toolName: toolCall.name,
args: toolCall.arguments,
partialResult,
});
});
} catch (e) {
result = {
content: [{ type: "text", text: e instanceof Error ? e.message : String(e) }],
details: {},
};
isError = true;
const preparation = await prepareToolCall(currentContext, assistantMessage, toolCall, config, signal);
if (preparation.kind === "immediate") {
results.push(emitToolCallOutcome(toolCall, preparation.result, preparation.isError, stream));
} else {
const executed = await executePreparedToolCall(preparation, signal, stream);
results.push(
await finalizeExecutedToolCall(
currentContext,
assistantMessage,
preparation,
executed,
config,
signal,
stream,
),
);
}
stream.push({
type: "tool_execution_end",
toolCallId: toolCall.id,
toolName: toolCall.name,
result,
isError,
});
const toolResultMessage: ToolResultMessage = {
role: "toolResult",
toolCallId: toolCall.id,
toolName: toolCall.name,
content: result.content,
details: result.details,
isError,
timestamp: Date.now(),
};
results.push(toolResultMessage);
stream.push({ type: "message_start", message: toolResultMessage });
stream.push({ type: "message_end", message: toolResultMessage });
// Check for steering messages - skip remaining tools if user interrupted
if (getSteeringMessages) {
const steering = await getSteeringMessages();
if (config.getSteeringMessages) {
const steering = await config.getSteeringMessages();
if (steering.length > 0) {
steeringMessages = steering;
const remainingCalls = toolCalls.slice(index + 1);
@ -452,27 +435,233 @@ async function executeToolCalls(
return { toolResults: results, steeringMessages };
}
function skipToolCall(
toolCall: Extract<AssistantMessage["content"][number], { type: "toolCall" }>,
async function executeToolCallsParallel(
currentContext: AgentContext,
assistantMessage: AssistantMessage,
toolCalls: AgentToolCall[],
config: AgentLoopConfig,
signal: AbortSignal | undefined,
stream: EventStream<AgentEvent, AgentMessage[]>,
): ToolResultMessage {
const result: AgentToolResult<any> = {
content: [{ type: "text", text: "Skipped due to queued user message." }],
): Promise<{ toolResults: ToolResultMessage[]; steeringMessages?: AgentMessage[] }> {
const results: ToolResultMessage[] = [];
const runnableCalls: PreparedToolCall[] = [];
let steeringMessages: AgentMessage[] | undefined;
for (let index = 0; index < toolCalls.length; index++) {
const toolCall = toolCalls[index];
stream.push({
type: "tool_execution_start",
toolCallId: toolCall.id,
toolName: toolCall.name,
args: toolCall.arguments,
});
const preparation = await prepareToolCall(currentContext, assistantMessage, toolCall, config, signal);
if (preparation.kind === "immediate") {
results.push(emitToolCallOutcome(toolCall, preparation.result, preparation.isError, stream));
} else {
runnableCalls.push(preparation);
}
if (config.getSteeringMessages) {
const steering = await config.getSteeringMessages();
if (steering.length > 0) {
steeringMessages = steering;
for (const runnable of runnableCalls) {
results.push(skipToolCall(runnable.toolCall, stream, { emitStart: false }));
}
const remainingCalls = toolCalls.slice(index + 1);
for (const skipped of remainingCalls) {
results.push(skipToolCall(skipped, stream));
}
return { toolResults: results, steeringMessages };
}
}
}
const runningCalls = runnableCalls.map((prepared) => ({
prepared,
execution: executePreparedToolCall(prepared, signal, stream),
}));
for (const running of runningCalls) {
const executed = await running.execution;
results.push(
await finalizeExecutedToolCall(
currentContext,
assistantMessage,
running.prepared,
executed,
config,
signal,
stream,
),
);
}
if (!steeringMessages && config.getSteeringMessages) {
const steering = await config.getSteeringMessages();
if (steering.length > 0) {
steeringMessages = steering;
}
}
return { toolResults: results, steeringMessages };
}
type PreparedToolCall = {
kind: "prepared";
toolCall: AgentToolCall;
tool: AgentTool<any>;
args: unknown;
};
type ImmediateToolCallOutcome = {
kind: "immediate";
result: AgentToolResult<any>;
isError: boolean;
};
type ExecutedToolCallOutcome = {
result: AgentToolResult<any>;
isError: boolean;
};
async function prepareToolCall(
currentContext: AgentContext,
assistantMessage: AssistantMessage,
toolCall: AgentToolCall,
config: AgentLoopConfig,
signal: AbortSignal | undefined,
): Promise<PreparedToolCall | ImmediateToolCallOutcome> {
const tool = currentContext.tools?.find((t) => t.name === toolCall.name);
if (!tool) {
return {
kind: "immediate",
result: createErrorToolResult(`Tool ${toolCall.name} not found`),
isError: true,
};
}
try {
const validatedArgs = validateToolArguments(tool, toolCall);
if (config.beforeToolCall) {
const beforeResult = await config.beforeToolCall(
{
assistantMessage,
toolCall,
args: validatedArgs,
context: currentContext,
},
signal,
);
if (beforeResult?.block) {
return {
kind: "immediate",
result: createErrorToolResult(beforeResult.reason || "Tool execution was blocked"),
isError: true,
};
}
}
return {
kind: "prepared",
toolCall,
tool,
args: validatedArgs,
};
} catch (error) {
return {
kind: "immediate",
result: createErrorToolResult(error instanceof Error ? error.message : String(error)),
isError: true,
};
}
}
async function executePreparedToolCall(
prepared: PreparedToolCall,
signal: AbortSignal | undefined,
stream: EventStream<AgentEvent, AgentMessage[]>,
): Promise<ExecutedToolCallOutcome> {
try {
const result = await prepared.tool.execute(
prepared.toolCall.id,
prepared.args as never,
signal,
(partialResult) => {
stream.push({
type: "tool_execution_update",
toolCallId: prepared.toolCall.id,
toolName: prepared.toolCall.name,
args: prepared.toolCall.arguments,
partialResult,
});
},
);
return { result, isError: false };
} catch (error) {
return {
result: createErrorToolResult(error instanceof Error ? error.message : String(error)),
isError: true,
};
}
}
async function finalizeExecutedToolCall(
currentContext: AgentContext,
assistantMessage: AssistantMessage,
prepared: PreparedToolCall,
executed: ExecutedToolCallOutcome,
config: AgentLoopConfig,
signal: AbortSignal | undefined,
stream: EventStream<AgentEvent, AgentMessage[]>,
): Promise<ToolResultMessage> {
let result = executed.result;
let isError = executed.isError;
if (config.afterToolCall) {
const afterResult = await config.afterToolCall(
{
assistantMessage,
toolCall: prepared.toolCall,
args: prepared.args,
result,
isError,
context: currentContext,
},
signal,
);
if (afterResult) {
result = {
content: afterResult.content !== undefined ? afterResult.content : result.content,
details: afterResult.details !== undefined ? afterResult.details : result.details,
};
isError = afterResult.isError !== undefined ? afterResult.isError : isError;
}
}
return emitToolCallOutcome(prepared.toolCall, result, isError, stream);
}
function createErrorToolResult(message: string): AgentToolResult<any> {
return {
content: [{ type: "text", text: message }],
details: {},
};
}
stream.push({
type: "tool_execution_start",
toolCallId: toolCall.id,
toolName: toolCall.name,
args: toolCall.arguments,
});
function emitToolCallOutcome(
toolCall: AgentToolCall,
result: AgentToolResult<any>,
isError: boolean,
stream: EventStream<AgentEvent, AgentMessage[]>,
): ToolResultMessage {
stream.push({
type: "tool_execution_end",
toolCallId: toolCall.id,
toolName: toolCall.name,
result,
isError: true,
isError,
});
const toolResultMessage: ToolResultMessage = {
@ -480,13 +669,34 @@ function skipToolCall(
toolCallId: toolCall.id,
toolName: toolCall.name,
content: result.content,
details: {},
isError: true,
details: result.details,
isError,
timestamp: Date.now(),
};
stream.push({ type: "message_start", message: toolResultMessage });
stream.push({ type: "message_end", message: toolResultMessage });
return toolResultMessage;
}
function skipToolCall(
toolCall: AgentToolCall,
stream: EventStream<AgentEvent, AgentMessage[]>,
options?: { emitStart?: boolean },
): ToolResultMessage {
const result: AgentToolResult<any> = {
content: [{ type: "text", text: "Skipped due to queued user message." }],
details: {},
};
if (options?.emitStart !== false) {
stream.push({
type: "tool_execution_start",
toolCallId: toolCall.id,
toolName: toolCall.name,
args: toolCall.arguments,
});
}
return emitToolCallOutcome(toolCall, result, true, stream);
}

View file

@ -1,4 +1,5 @@
import type {
AssistantMessage,
AssistantMessageEvent,
ImageContent,
Message,
@ -16,6 +17,73 @@ export type StreamFn = (
...args: Parameters<typeof streamSimple>
) => ReturnType<typeof streamSimple> | Promise<ReturnType<typeof streamSimple>>;
/**
* Configuration for how tool calls from a single assistant message are executed.
*
* - "sequential": each tool call is prepared, executed, and finalized before the next one starts.
* - "parallel": tool calls are prepared sequentially, then allowed tools execute concurrently.
* Final tool results are still emitted in assistant source order.
*/
export type ToolExecutionMode = "sequential" | "parallel";
/** A single tool call content block emitted by an assistant message. */
export type AgentToolCall = Extract<AssistantMessage["content"][number], { type: "toolCall" }>;
/**
* Result returned from `beforeToolCall`.
*
* Returning `{ block: true }` prevents the tool from executing. The loop emits an error tool result instead.
* `reason` becomes the text shown in that error result. If omitted, a default blocked message is used.
*/
export interface BeforeToolCallResult {
block?: boolean;
reason?: string;
}
/**
* Partial override returned from `afterToolCall`.
*
* Merge semantics are field-by-field:
* - `content`: if provided, replaces the tool result content array in full
* - `details`: if provided, replaces the tool result details value in full
* - `isError`: if provided, replaces the tool result error flag
*
* Omitted fields keep the original executed tool result values.
*/
export interface AfterToolCallResult {
content?: (TextContent | ImageContent)[];
details?: unknown;
isError?: boolean;
}
/** Context passed to `beforeToolCall`. */
export interface BeforeToolCallContext {
/** The assistant message that requested the tool call. */
assistantMessage: AssistantMessage;
/** The raw tool call block from `assistantMessage.content`. */
toolCall: AgentToolCall;
/** Validated tool arguments for the target tool schema. */
args: unknown;
/** Current agent context at the time the tool call is prepared. */
context: AgentContext;
}
/** Context passed to `afterToolCall`. */
export interface AfterToolCallContext {
/** The assistant message that requested the tool call. */
assistantMessage: AssistantMessage;
/** The raw tool call block from `assistantMessage.content`. */
toolCall: AgentToolCall;
/** Validated tool arguments for the target tool schema. */
args: unknown;
/** The executed tool result before any `afterToolCall` overrides are applied. */
result: AgentToolResult<any>;
/** Whether the executed tool result is currently treated as an error. */
isError: boolean;
/** Current agent context at the time the tool call is finalized. */
context: AgentContext;
}
/**
* Configuration for the agent loop.
*/
@ -95,6 +163,36 @@ export interface AgentLoopConfig extends SimpleStreamOptions {
* Use this for follow-up messages that should wait until the agent finishes.
*/
getFollowUpMessages?: () => Promise<AgentMessage[]>;
/**
* Tool execution mode.
* - "sequential": execute tool calls one by one
* - "parallel": preflight tool calls sequentially, then execute allowed tools concurrently
*
* Default: "parallel"
*/
toolExecution?: ToolExecutionMode;
/**
* Called before a tool is executed, after arguments have been validated.
*
* Return `{ block: true }` to prevent execution. The loop emits an error tool result instead.
* The hook receives the agent abort signal and is responsible for honoring it.
*/
beforeToolCall?: (context: BeforeToolCallContext, signal?: AbortSignal) => Promise<BeforeToolCallResult | undefined>;
/**
* Called after a tool finishes executing, before final tool events are emitted.
*
* Return an `AfterToolCallResult` to override parts of the executed tool result:
* - `content` replaces the full content array
* - `details` replaces the full details payload
* - `isError` replaces the error flag
*
* Any omitted fields keep their original values. No deep merge is performed.
* The hook receives the agent abort signal and is responsible for honoring it.
*/
afterToolCall?: (context: AfterToolCallContext, signal?: AbortSignal) => Promise<AfterToolCallResult | undefined>;
}
/**