feat(pi-agent-core): parallel tool calling with before/after hooks (#427)
* Initial plan * feat(pi-agent-core): add parallel tool calling support with beforeToolCall/afterToolCall hooks Co-authored-by: glittercowboy <186001655+glittercowboy@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: glittercowboy <186001655+glittercowboy@users.noreply.github.com>
This commit is contained in:
parent
3c931b2e19
commit
3fed189e00
3 changed files with 380 additions and 72 deletions
4
package-lock.json
generated
4
package-lock.json
generated
|
|
@ -1,12 +1,12 @@
|
|||
{
|
||||
"name": "gsd-pi",
|
||||
"version": "2.10.12",
|
||||
"version": "2.11.0",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"name": "gsd-pi",
|
||||
"version": "2.10.12",
|
||||
"version": "2.11.0",
|
||||
"hasInstallScript": true,
|
||||
"license": "MIT",
|
||||
"workspaces": [
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ import type {
|
|||
AgentLoopConfig,
|
||||
AgentMessage,
|
||||
AgentTool,
|
||||
AgentToolCall,
|
||||
AgentToolResult,
|
||||
StreamFn,
|
||||
} from "./types.js";
|
||||
|
|
@ -230,11 +231,11 @@ async function runLoop(
|
|||
const toolResults: ToolResultMessage[] = [];
|
||||
if (hasMoreToolCalls) {
|
||||
const toolExecution = await executeToolCalls(
|
||||
currentContext.tools,
|
||||
currentContext,
|
||||
message,
|
||||
config,
|
||||
signal,
|
||||
stream,
|
||||
config.getSteeringMessages,
|
||||
);
|
||||
toolResults.push(...toolExecution.toolResults);
|
||||
steeringAfterTools = toolExecution.steeringMessages ?? null;
|
||||
|
|
@ -367,20 +368,32 @@ async function streamAssistantResponse(
|
|||
* Execute tool calls from an assistant message.
|
||||
*/
|
||||
async function executeToolCalls(
|
||||
tools: AgentTool<any>[] | undefined,
|
||||
currentContext: AgentContext,
|
||||
assistantMessage: AssistantMessage,
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
): Promise<{ toolResults: ToolResultMessage[]; steeringMessages?: AgentMessage[] }> {
|
||||
const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall") as AgentToolCall[];
|
||||
if (config.toolExecution === "sequential") {
|
||||
return executeToolCallsSequential(currentContext, assistantMessage, toolCalls, config, signal, stream);
|
||||
}
|
||||
return executeToolCallsParallel(currentContext, assistantMessage, toolCalls, config, signal, stream);
|
||||
}
|
||||
|
||||
async function executeToolCallsSequential(
|
||||
currentContext: AgentContext,
|
||||
assistantMessage: AssistantMessage,
|
||||
toolCalls: AgentToolCall[],
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
getSteeringMessages?: AgentLoopConfig["getSteeringMessages"],
|
||||
): Promise<{ toolResults: ToolResultMessage[]; steeringMessages?: AgentMessage[] }> {
|
||||
const toolCalls = assistantMessage.content.filter((c) => c.type === "toolCall");
|
||||
const results: ToolResultMessage[] = [];
|
||||
let steeringMessages: AgentMessage[] | undefined;
|
||||
|
||||
for (let index = 0; index < toolCalls.length; index++) {
|
||||
const toolCall = toolCalls[index];
|
||||
const tool = tools?.find((t) => t.name === toolCall.name);
|
||||
|
||||
stream.push({
|
||||
type: "tool_execution_start",
|
||||
toolCallId: toolCall.id,
|
||||
|
|
@ -388,56 +401,26 @@ async function executeToolCalls(
|
|||
args: toolCall.arguments,
|
||||
});
|
||||
|
||||
let result: AgentToolResult<any>;
|
||||
let isError = false;
|
||||
|
||||
try {
|
||||
if (!tool) throw new Error(`Tool ${toolCall.name} not found`);
|
||||
|
||||
const validatedArgs = validateToolArguments(tool, toolCall);
|
||||
|
||||
result = await tool.execute(toolCall.id, validatedArgs, signal, (partialResult) => {
|
||||
stream.push({
|
||||
type: "tool_execution_update",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
args: toolCall.arguments,
|
||||
partialResult,
|
||||
});
|
||||
});
|
||||
} catch (e) {
|
||||
result = {
|
||||
content: [{ type: "text", text: e instanceof Error ? e.message : String(e) }],
|
||||
details: {},
|
||||
};
|
||||
isError = true;
|
||||
const preparation = await prepareToolCall(currentContext, assistantMessage, toolCall, config, signal);
|
||||
if (preparation.kind === "immediate") {
|
||||
results.push(emitToolCallOutcome(toolCall, preparation.result, preparation.isError, stream));
|
||||
} else {
|
||||
const executed = await executePreparedToolCall(preparation, signal, stream);
|
||||
results.push(
|
||||
await finalizeExecutedToolCall(
|
||||
currentContext,
|
||||
assistantMessage,
|
||||
preparation,
|
||||
executed,
|
||||
config,
|
||||
signal,
|
||||
stream,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
stream.push({
|
||||
type: "tool_execution_end",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
result,
|
||||
isError,
|
||||
});
|
||||
|
||||
const toolResultMessage: ToolResultMessage = {
|
||||
role: "toolResult",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
content: result.content,
|
||||
details: result.details,
|
||||
isError,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
results.push(toolResultMessage);
|
||||
stream.push({ type: "message_start", message: toolResultMessage });
|
||||
stream.push({ type: "message_end", message: toolResultMessage });
|
||||
|
||||
// Check for steering messages - skip remaining tools if user interrupted
|
||||
if (getSteeringMessages) {
|
||||
const steering = await getSteeringMessages();
|
||||
if (config.getSteeringMessages) {
|
||||
const steering = await config.getSteeringMessages();
|
||||
if (steering.length > 0) {
|
||||
steeringMessages = steering;
|
||||
const remainingCalls = toolCalls.slice(index + 1);
|
||||
|
|
@ -452,27 +435,233 @@ async function executeToolCalls(
|
|||
return { toolResults: results, steeringMessages };
|
||||
}
|
||||
|
||||
function skipToolCall(
|
||||
toolCall: Extract<AssistantMessage["content"][number], { type: "toolCall" }>,
|
||||
async function executeToolCallsParallel(
|
||||
currentContext: AgentContext,
|
||||
assistantMessage: AssistantMessage,
|
||||
toolCalls: AgentToolCall[],
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
): ToolResultMessage {
|
||||
const result: AgentToolResult<any> = {
|
||||
content: [{ type: "text", text: "Skipped due to queued user message." }],
|
||||
): Promise<{ toolResults: ToolResultMessage[]; steeringMessages?: AgentMessage[] }> {
|
||||
const results: ToolResultMessage[] = [];
|
||||
const runnableCalls: PreparedToolCall[] = [];
|
||||
let steeringMessages: AgentMessage[] | undefined;
|
||||
|
||||
for (let index = 0; index < toolCalls.length; index++) {
|
||||
const toolCall = toolCalls[index];
|
||||
stream.push({
|
||||
type: "tool_execution_start",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
args: toolCall.arguments,
|
||||
});
|
||||
|
||||
const preparation = await prepareToolCall(currentContext, assistantMessage, toolCall, config, signal);
|
||||
if (preparation.kind === "immediate") {
|
||||
results.push(emitToolCallOutcome(toolCall, preparation.result, preparation.isError, stream));
|
||||
} else {
|
||||
runnableCalls.push(preparation);
|
||||
}
|
||||
|
||||
if (config.getSteeringMessages) {
|
||||
const steering = await config.getSteeringMessages();
|
||||
if (steering.length > 0) {
|
||||
steeringMessages = steering;
|
||||
for (const runnable of runnableCalls) {
|
||||
results.push(skipToolCall(runnable.toolCall, stream, { emitStart: false }));
|
||||
}
|
||||
const remainingCalls = toolCalls.slice(index + 1);
|
||||
for (const skipped of remainingCalls) {
|
||||
results.push(skipToolCall(skipped, stream));
|
||||
}
|
||||
return { toolResults: results, steeringMessages };
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const runningCalls = runnableCalls.map((prepared) => ({
|
||||
prepared,
|
||||
execution: executePreparedToolCall(prepared, signal, stream),
|
||||
}));
|
||||
|
||||
for (const running of runningCalls) {
|
||||
const executed = await running.execution;
|
||||
results.push(
|
||||
await finalizeExecutedToolCall(
|
||||
currentContext,
|
||||
assistantMessage,
|
||||
running.prepared,
|
||||
executed,
|
||||
config,
|
||||
signal,
|
||||
stream,
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
if (!steeringMessages && config.getSteeringMessages) {
|
||||
const steering = await config.getSteeringMessages();
|
||||
if (steering.length > 0) {
|
||||
steeringMessages = steering;
|
||||
}
|
||||
}
|
||||
|
||||
return { toolResults: results, steeringMessages };
|
||||
}
|
||||
|
||||
type PreparedToolCall = {
|
||||
kind: "prepared";
|
||||
toolCall: AgentToolCall;
|
||||
tool: AgentTool<any>;
|
||||
args: unknown;
|
||||
};
|
||||
|
||||
type ImmediateToolCallOutcome = {
|
||||
kind: "immediate";
|
||||
result: AgentToolResult<any>;
|
||||
isError: boolean;
|
||||
};
|
||||
|
||||
type ExecutedToolCallOutcome = {
|
||||
result: AgentToolResult<any>;
|
||||
isError: boolean;
|
||||
};
|
||||
|
||||
async function prepareToolCall(
|
||||
currentContext: AgentContext,
|
||||
assistantMessage: AssistantMessage,
|
||||
toolCall: AgentToolCall,
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
): Promise<PreparedToolCall | ImmediateToolCallOutcome> {
|
||||
const tool = currentContext.tools?.find((t) => t.name === toolCall.name);
|
||||
if (!tool) {
|
||||
return {
|
||||
kind: "immediate",
|
||||
result: createErrorToolResult(`Tool ${toolCall.name} not found`),
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
|
||||
try {
|
||||
const validatedArgs = validateToolArguments(tool, toolCall);
|
||||
if (config.beforeToolCall) {
|
||||
const beforeResult = await config.beforeToolCall(
|
||||
{
|
||||
assistantMessage,
|
||||
toolCall,
|
||||
args: validatedArgs,
|
||||
context: currentContext,
|
||||
},
|
||||
signal,
|
||||
);
|
||||
if (beforeResult?.block) {
|
||||
return {
|
||||
kind: "immediate",
|
||||
result: createErrorToolResult(beforeResult.reason || "Tool execution was blocked"),
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
}
|
||||
return {
|
||||
kind: "prepared",
|
||||
toolCall,
|
||||
tool,
|
||||
args: validatedArgs,
|
||||
};
|
||||
} catch (error) {
|
||||
return {
|
||||
kind: "immediate",
|
||||
result: createErrorToolResult(error instanceof Error ? error.message : String(error)),
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
async function executePreparedToolCall(
|
||||
prepared: PreparedToolCall,
|
||||
signal: AbortSignal | undefined,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
): Promise<ExecutedToolCallOutcome> {
|
||||
try {
|
||||
const result = await prepared.tool.execute(
|
||||
prepared.toolCall.id,
|
||||
prepared.args as never,
|
||||
signal,
|
||||
(partialResult) => {
|
||||
stream.push({
|
||||
type: "tool_execution_update",
|
||||
toolCallId: prepared.toolCall.id,
|
||||
toolName: prepared.toolCall.name,
|
||||
args: prepared.toolCall.arguments,
|
||||
partialResult,
|
||||
});
|
||||
},
|
||||
);
|
||||
return { result, isError: false };
|
||||
} catch (error) {
|
||||
return {
|
||||
result: createErrorToolResult(error instanceof Error ? error.message : String(error)),
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
async function finalizeExecutedToolCall(
|
||||
currentContext: AgentContext,
|
||||
assistantMessage: AssistantMessage,
|
||||
prepared: PreparedToolCall,
|
||||
executed: ExecutedToolCallOutcome,
|
||||
config: AgentLoopConfig,
|
||||
signal: AbortSignal | undefined,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
): Promise<ToolResultMessage> {
|
||||
let result = executed.result;
|
||||
let isError = executed.isError;
|
||||
|
||||
if (config.afterToolCall) {
|
||||
const afterResult = await config.afterToolCall(
|
||||
{
|
||||
assistantMessage,
|
||||
toolCall: prepared.toolCall,
|
||||
args: prepared.args,
|
||||
result,
|
||||
isError,
|
||||
context: currentContext,
|
||||
},
|
||||
signal,
|
||||
);
|
||||
if (afterResult) {
|
||||
result = {
|
||||
content: afterResult.content !== undefined ? afterResult.content : result.content,
|
||||
details: afterResult.details !== undefined ? afterResult.details : result.details,
|
||||
};
|
||||
isError = afterResult.isError !== undefined ? afterResult.isError : isError;
|
||||
}
|
||||
}
|
||||
|
||||
return emitToolCallOutcome(prepared.toolCall, result, isError, stream);
|
||||
}
|
||||
|
||||
function createErrorToolResult(message: string): AgentToolResult<any> {
|
||||
return {
|
||||
content: [{ type: "text", text: message }],
|
||||
details: {},
|
||||
};
|
||||
}
|
||||
|
||||
stream.push({
|
||||
type: "tool_execution_start",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
args: toolCall.arguments,
|
||||
});
|
||||
function emitToolCallOutcome(
|
||||
toolCall: AgentToolCall,
|
||||
result: AgentToolResult<any>,
|
||||
isError: boolean,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
): ToolResultMessage {
|
||||
stream.push({
|
||||
type: "tool_execution_end",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
result,
|
||||
isError: true,
|
||||
isError,
|
||||
});
|
||||
|
||||
const toolResultMessage: ToolResultMessage = {
|
||||
|
|
@ -480,13 +669,34 @@ function skipToolCall(
|
|||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
content: result.content,
|
||||
details: {},
|
||||
isError: true,
|
||||
details: result.details,
|
||||
isError,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
|
||||
stream.push({ type: "message_start", message: toolResultMessage });
|
||||
stream.push({ type: "message_end", message: toolResultMessage });
|
||||
|
||||
return toolResultMessage;
|
||||
}
|
||||
|
||||
function skipToolCall(
|
||||
toolCall: AgentToolCall,
|
||||
stream: EventStream<AgentEvent, AgentMessage[]>,
|
||||
options?: { emitStart?: boolean },
|
||||
): ToolResultMessage {
|
||||
const result: AgentToolResult<any> = {
|
||||
content: [{ type: "text", text: "Skipped due to queued user message." }],
|
||||
details: {},
|
||||
};
|
||||
|
||||
if (options?.emitStart !== false) {
|
||||
stream.push({
|
||||
type: "tool_execution_start",
|
||||
toolCallId: toolCall.id,
|
||||
toolName: toolCall.name,
|
||||
args: toolCall.arguments,
|
||||
});
|
||||
}
|
||||
|
||||
return emitToolCallOutcome(toolCall, result, true, stream);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import type {
|
||||
AssistantMessage,
|
||||
AssistantMessageEvent,
|
||||
ImageContent,
|
||||
Message,
|
||||
|
|
@ -16,6 +17,73 @@ export type StreamFn = (
|
|||
...args: Parameters<typeof streamSimple>
|
||||
) => ReturnType<typeof streamSimple> | Promise<ReturnType<typeof streamSimple>>;
|
||||
|
||||
/**
|
||||
* Configuration for how tool calls from a single assistant message are executed.
|
||||
*
|
||||
* - "sequential": each tool call is prepared, executed, and finalized before the next one starts.
|
||||
* - "parallel": tool calls are prepared sequentially, then allowed tools execute concurrently.
|
||||
* Final tool results are still emitted in assistant source order.
|
||||
*/
|
||||
export type ToolExecutionMode = "sequential" | "parallel";
|
||||
|
||||
/** A single tool call content block emitted by an assistant message. */
|
||||
export type AgentToolCall = Extract<AssistantMessage["content"][number], { type: "toolCall" }>;
|
||||
|
||||
/**
|
||||
* Result returned from `beforeToolCall`.
|
||||
*
|
||||
* Returning `{ block: true }` prevents the tool from executing. The loop emits an error tool result instead.
|
||||
* `reason` becomes the text shown in that error result. If omitted, a default blocked message is used.
|
||||
*/
|
||||
export interface BeforeToolCallResult {
|
||||
block?: boolean;
|
||||
reason?: string;
|
||||
}
|
||||
|
||||
/**
|
||||
* Partial override returned from `afterToolCall`.
|
||||
*
|
||||
* Merge semantics are field-by-field:
|
||||
* - `content`: if provided, replaces the tool result content array in full
|
||||
* - `details`: if provided, replaces the tool result details value in full
|
||||
* - `isError`: if provided, replaces the tool result error flag
|
||||
*
|
||||
* Omitted fields keep the original executed tool result values.
|
||||
*/
|
||||
export interface AfterToolCallResult {
|
||||
content?: (TextContent | ImageContent)[];
|
||||
details?: unknown;
|
||||
isError?: boolean;
|
||||
}
|
||||
|
||||
/** Context passed to `beforeToolCall`. */
|
||||
export interface BeforeToolCallContext {
|
||||
/** The assistant message that requested the tool call. */
|
||||
assistantMessage: AssistantMessage;
|
||||
/** The raw tool call block from `assistantMessage.content`. */
|
||||
toolCall: AgentToolCall;
|
||||
/** Validated tool arguments for the target tool schema. */
|
||||
args: unknown;
|
||||
/** Current agent context at the time the tool call is prepared. */
|
||||
context: AgentContext;
|
||||
}
|
||||
|
||||
/** Context passed to `afterToolCall`. */
|
||||
export interface AfterToolCallContext {
|
||||
/** The assistant message that requested the tool call. */
|
||||
assistantMessage: AssistantMessage;
|
||||
/** The raw tool call block from `assistantMessage.content`. */
|
||||
toolCall: AgentToolCall;
|
||||
/** Validated tool arguments for the target tool schema. */
|
||||
args: unknown;
|
||||
/** The executed tool result before any `afterToolCall` overrides are applied. */
|
||||
result: AgentToolResult<any>;
|
||||
/** Whether the executed tool result is currently treated as an error. */
|
||||
isError: boolean;
|
||||
/** Current agent context at the time the tool call is finalized. */
|
||||
context: AgentContext;
|
||||
}
|
||||
|
||||
/**
|
||||
* Configuration for the agent loop.
|
||||
*/
|
||||
|
|
@ -95,6 +163,36 @@ export interface AgentLoopConfig extends SimpleStreamOptions {
|
|||
* Use this for follow-up messages that should wait until the agent finishes.
|
||||
*/
|
||||
getFollowUpMessages?: () => Promise<AgentMessage[]>;
|
||||
|
||||
/**
|
||||
* Tool execution mode.
|
||||
* - "sequential": execute tool calls one by one
|
||||
* - "parallel": preflight tool calls sequentially, then execute allowed tools concurrently
|
||||
*
|
||||
* Default: "parallel"
|
||||
*/
|
||||
toolExecution?: ToolExecutionMode;
|
||||
|
||||
/**
|
||||
* Called before a tool is executed, after arguments have been validated.
|
||||
*
|
||||
* Return `{ block: true }` to prevent execution. The loop emits an error tool result instead.
|
||||
* The hook receives the agent abort signal and is responsible for honoring it.
|
||||
*/
|
||||
beforeToolCall?: (context: BeforeToolCallContext, signal?: AbortSignal) => Promise<BeforeToolCallResult | undefined>;
|
||||
|
||||
/**
|
||||
* Called after a tool finishes executing, before final tool events are emitted.
|
||||
*
|
||||
* Return an `AfterToolCallResult` to override parts of the executed tool result:
|
||||
* - `content` replaces the full content array
|
||||
* - `details` replaces the full details payload
|
||||
* - `isError` replaces the error flag
|
||||
*
|
||||
* Any omitted fields keep their original values. No deep merge is performed.
|
||||
* The hook receives the agent abort signal and is responsible for honoring it.
|
||||
*/
|
||||
afterToolCall?: (context: AfterToolCallContext, signal?: AbortSignal) => Promise<AfterToolCallResult | undefined>;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue