Adds `externalToolExecution` flag to AgentLoopConfig. When true, the agent loop emits tool_execution_start/end events for TUI rendering but skips local tool dispatch. Used by providers that handle tool execution internally (e.g., Claude Code CLI via Agent SDK). The flag is dynamically evaluated per-loop via a callback on AgentOptions, so model switches mid-session are handled correctly. Providers with authMode "externalCli" automatically use this mode. Also updates the Claude Code CLI stream adapter to preserve tool call blocks in the final message instead of stripping them. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
615 lines
17 KiB
TypeScript
615 lines
17 KiB
TypeScript
/**
|
|
* Agent class that uses the agent-loop directly.
|
|
* No transport abstraction - calls streamSimple via the loop.
|
|
*/
|
|
|
|
import {
|
|
getModel,
|
|
type ImageContent,
|
|
type Message,
|
|
type Model,
|
|
type SimpleStreamOptions,
|
|
streamSimple,
|
|
type TextContent,
|
|
type ThinkingBudgets,
|
|
type Transport,
|
|
} from "@gsd/pi-ai";
|
|
import { agentLoop, agentLoopContinue, ZERO_USAGE } from "./agent-loop.js";
|
|
import type {
|
|
AgentContext,
|
|
AgentEvent,
|
|
AgentLoopConfig,
|
|
AgentMessage,
|
|
AgentState,
|
|
AgentTool,
|
|
BeforeToolCallContext,
|
|
BeforeToolCallResult,
|
|
AfterToolCallContext,
|
|
AfterToolCallResult,
|
|
StreamFn,
|
|
ThinkingLevel,
|
|
} from "./types.js";
|
|
|
|
/**
|
|
* Default convertToLlm: Keep only LLM-compatible messages, convert attachments.
|
|
*/
|
|
function defaultConvertToLlm(messages: AgentMessage[]): Message[] {
|
|
return messages.filter((m) => m.role === "user" || m.role === "assistant" || m.role === "toolResult");
|
|
}
|
|
|
|
/**
 * Construction-time options for {@link Agent}.
 * All fields are optional; defaults are applied in the Agent constructor.
 */
export interface AgentOptions {
  /** Partial overrides merged into the Agent's default initial state. */
  initialState?: Partial<AgentState>;

  /**
   * Converts AgentMessage[] to LLM-compatible Message[] before each LLM call.
   * Default filters to user/assistant/toolResult and converts attachments.
   */
  convertToLlm?: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;

  /**
   * Optional transform applied to context before convertToLlm.
   * Use for context pruning, injecting external context, etc.
   */
  transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;

  /**
   * Steering mode: "all" = send all steering messages at once, "one-at-a-time" = one per turn.
   */
  steeringMode?: "all" | "one-at-a-time";

  /**
   * Follow-up mode: "all" = send all follow-up messages at once, "one-at-a-time" = one per turn.
   */
  followUpMode?: "all" | "one-at-a-time";

  /**
   * Custom stream function (for proxy backends, etc.). Default uses streamSimple.
   */
  streamFn?: StreamFn;

  /**
   * Optional session identifier forwarded to LLM providers.
   * Used by providers that support session-based caching (e.g., OpenAI Codex).
   */
  sessionId?: string;

  /**
   * Resolves an API key dynamically for each LLM call.
   * Useful for expiring tokens (e.g., GitHub Copilot OAuth).
   */
  getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;

  /**
   * Inspect or replace provider payloads before they are sent.
   */
  onPayload?: SimpleStreamOptions["onPayload"];

  /**
   * Custom token budgets for thinking levels (token-based providers only).
   */
  thinkingBudgets?: ThinkingBudgets;

  /**
   * Preferred transport for providers that support multiple transports.
   */
  transport?: Transport;

  /**
   * Maximum delay in milliseconds to wait for a retry when the server requests a long wait.
   * If the server's requested delay exceeds this value, the request fails immediately,
   * allowing higher-level retry logic to handle it with user visibility.
   * Default: 60000 (60 seconds). Set to 0 to disable the cap.
   */
  maxRetryDelayMs?: number;

  /**
   * Determines whether a model uses external tool execution (tools handled
   * by the provider, not dispatched locally). Evaluated per-loop so model
   * switches mid-session are handled correctly.
   */
  externalToolExecution?: (model: Model<any>) => boolean;
}
|
|
|
|
/**
 * Internal wrapper that tracks message origin for origin-aware queue clearing.
 * "user" = typed by human in TUI; "system" = generated by extensions/background jobs.
 */
interface QueueEntry {
  /** The queued agent message itself. */
  message: AgentMessage;
  /** Who produced the message; drives drainUserMessages() filtering. */
  origin: "user" | "system";
}
|
|
|
|
export class Agent {
  /** Mutable internal state; exposed (non-defensively) via the `state` getter. */
  private _state: AgentState = {
    systemPrompt: "",
    // Default model; typically overridden via initialState or setModel().
    model: getModel("google", "gemini-2.5-flash-lite-preview-06-17"),
    thinkingLevel: "off",
    tools: [],
    messages: [],
    isStreaming: false,
    streamMessage: null,
    pendingToolCalls: new Set<string>(),
    error: undefined,
  };

  /** Subscribers notified of every AgentEvent via emit(). */
  private listeners = new Set<(e: AgentEvent) => void>();
  // Controller for the currently running loop; undefined while idle.
  private abortController?: AbortController;
  private convertToLlm: (messages: AgentMessage[]) => Message[] | Promise<Message[]>;
  private transformContext?: (messages: AgentMessage[], signal?: AbortSignal) => Promise<AgentMessage[]>;
  // Queued interrupts (delivered mid-run) and follow-ups (delivered after the run).
  private steeringQueue: QueueEntry[] = [];
  private followUpQueue: QueueEntry[] = [];
  private steeringMode: "all" | "one-at-a-time";
  private followUpMode: "all" | "one-at-a-time";
  public streamFn: StreamFn;
  private _sessionId?: string;
  public getApiKey?: (provider: string) => Promise<string | undefined> | string | undefined;
  private _onPayload?: SimpleStreamOptions["onPayload"];
  // Promise resolved when the current prompt finishes; backs waitForIdle().
  private runningPrompt?: Promise<void>;
  private resolveRunningPrompt?: () => void;
  private _thinkingBudgets?: ThinkingBudgets;
  private _transport: Transport;
  private _maxRetryDelayMs?: number;
  private _beforeToolCall?: AgentLoopConfig["beforeToolCall"];
  private _afterToolCall?: AgentLoopConfig["afterToolCall"];
  // Per-model predicate: when it returns true for the active model, the loop
  // is configured to skip local tool dispatch (the provider executes tools).
  private _externalToolExecution?: (model: Model<any>) => boolean;

  constructor(opts: AgentOptions = {}) {
    // Merge caller overrides over defaults; unspecified fields keep defaults.
    this._state = { ...this._state, ...opts.initialState };
    this.convertToLlm = opts.convertToLlm || defaultConvertToLlm;
    this.transformContext = opts.transformContext;
    this.steeringMode = opts.steeringMode || "one-at-a-time";
    this.followUpMode = opts.followUpMode || "one-at-a-time";
    this.streamFn = opts.streamFn || streamSimple;
    this._sessionId = opts.sessionId;
    this.getApiKey = opts.getApiKey;
    this._onPayload = opts.onPayload;
    this._thinkingBudgets = opts.thinkingBudgets;
    this._transport = opts.transport ?? "sse";
    this._maxRetryDelayMs = opts.maxRetryDelayMs;
    this._externalToolExecution = opts.externalToolExecution;
  }

  /**
   * Get the current session ID used for provider caching.
   */
  get sessionId(): string | undefined {
    return this._sessionId;
  }

  /**
   * Set the session ID for provider caching.
   * Call this when switching sessions (new session, branch, resume).
   */
  set sessionId(value: string | undefined) {
    this._sessionId = value;
  }

  /**
   * Get the current thinking budgets.
   */
  get thinkingBudgets(): ThinkingBudgets | undefined {
    return this._thinkingBudgets;
  }

  /**
   * Set custom thinking budgets for token-based providers.
   */
  set thinkingBudgets(value: ThinkingBudgets | undefined) {
    this._thinkingBudgets = value;
  }

  /**
   * Get the current preferred transport.
   */
  get transport(): Transport {
    return this._transport;
  }

  /**
   * Set the preferred transport.
   */
  setTransport(value: Transport) {
    this._transport = value;
  }

  /**
   * Get the current max retry delay in milliseconds.
   */
  get maxRetryDelayMs(): number | undefined {
    return this._maxRetryDelayMs;
  }

  /**
   * Set the maximum delay to wait for server-requested retries.
   * Set to 0 to disable the cap.
   */
  set maxRetryDelayMs(value: number | undefined) {
    this._maxRetryDelayMs = value;
  }

  /**
   * Install a hook called before each tool executes, after argument validation.
   * Return `{ block: true }` to prevent execution.
   */
  setBeforeToolCall(fn: AgentLoopConfig["beforeToolCall"]): void {
    this._beforeToolCall = fn;
  }

  /**
   * Install a hook called after each tool executes, before results are emitted.
   * Return field overrides for content/details/isError.
   */
  setAfterToolCall(fn: AgentLoopConfig["afterToolCall"]): void {
    this._afterToolCall = fn;
  }

  /**
   * Current agent state. NOTE: returns the live object, not a copy —
   * callers should treat it as read-only.
   */
  get state(): AgentState {
    return this._state;
  }

  /**
   * Subscribe to agent events. Returns an unsubscribe function.
   */
  subscribe(fn: (e: AgentEvent) => void): () => void {
    this.listeners.add(fn);
    return () => this.listeners.delete(fn);
  }

  // State mutators
  setSystemPrompt(v: string) {
    this._state.systemPrompt = v;
  }

  setModel(m: Model<any>) {
    this._state.model = m;
  }

  setThinkingLevel(l: ThinkingLevel) {
    this._state.thinkingLevel = l;
  }

  setSteeringMode(mode: "all" | "one-at-a-time") {
    this.steeringMode = mode;
  }

  getSteeringMode(): "all" | "one-at-a-time" {
    return this.steeringMode;
  }

  setFollowUpMode(mode: "all" | "one-at-a-time") {
    this.followUpMode = mode;
  }

  getFollowUpMode(): "all" | "one-at-a-time" {
    return this.followUpMode;
  }

  setTools(t: AgentTool<any>[]) {
    this._state.tools = t;
  }

  /** Replace the message history with a copy of `ms`. */
  replaceMessages(ms: AgentMessage[]) {
    this._state.messages = ms.slice();
  }

  /** Append one message, replacing the array so observers see a new reference. */
  appendMessage(m: AgentMessage) {
    this._state.messages = [...this._state.messages, m];
  }

  /**
   * Queue a steering message to interrupt the agent mid-run.
   * Delivered after current tool execution, skips remaining tools.
   */
  steer(m: AgentMessage, origin: "user" | "system" = "system") {
    this.steeringQueue.push({ message: m, origin });
  }

  /**
   * Queue a follow-up message to be processed after the agent finishes.
   * Delivered only when agent has no more tool calls or steering messages.
   */
  followUp(m: AgentMessage, origin: "user" | "system" = "system") {
    this.followUpQueue.push({ message: m, origin });
  }

  clearSteeringQueue() {
    this.steeringQueue = [];
  }

  clearFollowUpQueue() {
    this.followUpQueue = [];
  }

  clearAllQueues() {
    this.steeringQueue = [];
    this.followUpQueue = [];
  }

  /**
   * Drain user-origin messages from queues, leaving system messages in place.
   * Used during abort to preserve messages the user explicitly typed.
   */
  drainUserMessages(): { steering: AgentMessage[]; followUp: AgentMessage[] } {
    const userSteering = this.steeringQueue.filter((e) => e.origin === "user").map((e) => e.message);
    const userFollowUp = this.followUpQueue.filter((e) => e.origin === "user").map((e) => e.message);
    this.steeringQueue = this.steeringQueue.filter((e) => e.origin !== "user");
    this.followUpQueue = this.followUpQueue.filter((e) => e.origin !== "user");
    return { steering: userSteering, followUp: userFollowUp };
  }

  hasQueuedMessages(): boolean {
    return this.steeringQueue.length > 0 || this.followUpQueue.length > 0;
  }

  /**
   * Pop steering messages per the configured mode: one entry in
   * "one-at-a-time" mode, the whole queue otherwise.
   */
  private dequeueSteeringMessages(): AgentMessage[] {
    if (this.steeringMode === "one-at-a-time") {
      if (this.steeringQueue.length > 0) {
        const first = this.steeringQueue[0];
        this.steeringQueue = this.steeringQueue.slice(1);
        return [first.message];
      }
      return [];
    }

    const steering = this.steeringQueue.map((e) => e.message);
    this.steeringQueue = [];
    return steering;
  }

  /**
   * Pop follow-up messages per the configured mode: one entry in
   * "one-at-a-time" mode, the whole queue otherwise.
   */
  private dequeueFollowUpMessages(): AgentMessage[] {
    if (this.followUpMode === "one-at-a-time") {
      if (this.followUpQueue.length > 0) {
        const first = this.followUpQueue[0];
        this.followUpQueue = this.followUpQueue.slice(1);
        return [first.message];
      }
      return [];
    }

    const followUp = this.followUpQueue.map((e) => e.message);
    this.followUpQueue = [];
    return followUp;
  }

  clearMessages() {
    this._state.messages = [];
  }

  /** Abort the in-flight loop, if any. Safe to call while idle. */
  abort() {
    this.abortController?.abort();
  }

  /** Resolves when the current prompt finishes; resolves immediately when idle. */
  waitForIdle(): Promise<void> {
    return this.runningPrompt ?? Promise.resolve();
  }

  /**
   * Clear messages, streaming state, errors, and both queues.
   * NOTE(review): does not abort an in-flight loop — callers should abort()
   * first if one may be running; confirm intended.
   */
  reset() {
    this._state.messages = [];
    this._state.isStreaming = false;
    this._state.streamMessage = null;
    this._state.pendingToolCalls = new Set<string>();
    this._state.error = undefined;
    this.steeringQueue = [];
    this.followUpQueue = [];
  }

  /** Send a prompt with an AgentMessage */
  async prompt(message: AgentMessage | AgentMessage[]): Promise<void>;
  async prompt(input: string, images?: ImageContent[]): Promise<void>;
  async prompt(input: string | AgentMessage | AgentMessage[], images?: ImageContent[]) {
    // Concurrent prompts are rejected; use steer()/followUp() to queue instead.
    if (this._state.isStreaming) {
      throw new Error(
        "Agent is already processing a prompt. Use steer() or followUp() to queue messages, or wait for completion.",
      );
    }

    const model = this._state.model;
    if (!model) throw new Error("No model configured");

    let msgs: AgentMessage[];

    // Normalize the three overload shapes into AgentMessage[].
    if (Array.isArray(input)) {
      msgs = input;
    } else if (typeof input === "string") {
      const content: Array<TextContent | ImageContent> = [{ type: "text", text: input }];
      if (images && images.length > 0) {
        content.push(...images);
      }
      msgs = [
        {
          role: "user",
          content,
          timestamp: Date.now(),
        },
      ];
    } else {
      msgs = [input];
    }

    await this._runLoop(msgs);
  }

  /**
   * Continue from current context (used for retries and resuming queued messages).
   */
  async continue() {
    if (this._state.isStreaming) {
      throw new Error("Agent is already processing. Wait for completion before continuing.");
    }

    const messages = this._state.messages;
    if (messages.length === 0) {
      throw new Error("No messages to continue from");
    }
    // A trailing assistant message means the last turn completed; we can only
    // proceed if something is queued. Steering takes priority over follow-ups.
    if (messages[messages.length - 1].role === "assistant") {
      const queuedSteering = this.dequeueSteeringMessages();
      if (queuedSteering.length > 0) {
        // Already dequeued here, so tell the loop not to poll again immediately.
        await this._runLoop(queuedSteering, { skipInitialSteeringPoll: true });
        return;
      }

      const queuedFollowUp = this.dequeueFollowUpMessages();
      if (queuedFollowUp.length > 0) {
        await this._runLoop(queuedFollowUp);
        return;
      }

      throw new Error("Cannot continue from message role: assistant");
    }

    await this._runLoop(undefined);
  }

  /**
   * Run the agent loop.
   * If messages are provided, starts a new conversation turn with those messages.
   * Otherwise, continues from existing context.
   */
  private async _runLoop(messages?: AgentMessage[], options?: { skipInitialSteeringPoll?: boolean }) {
    const model = this._state.model;
    if (!model) throw new Error("No model configured");

    // Record which model this run uses (cleared in `finally`).
    this._state.activeInferenceModel = model;

    // Arm waitForIdle() for this run.
    this.runningPrompt = new Promise<void>((resolve) => {
      this.resolveRunningPrompt = resolve;
    });

    this.abortController = new AbortController();
    this._state.isStreaming = true;
    this._state.streamMessage = null;
    this._state.error = undefined;

    // "off" maps to no reasoning; any other level is forwarded as-is.
    const reasoning = this._state.thinkingLevel === "off" ? undefined : this._state.thinkingLevel;

    // Snapshot messages so loop-internal mutation can't race with appendMessage.
    const context: AgentContext = {
      systemPrompt: this._state.systemPrompt,
      messages: this._state.messages.slice(),
      tools: this._state.tools,
    };

    // One-shot flag: suppress the first steering poll when the caller
    // (continue()) already dequeued the steering messages itself.
    let skipInitialSteeringPoll = options?.skipInitialSteeringPoll === true;

    const config: AgentLoopConfig = {
      model,
      reasoning,
      sessionId: this._sessionId,
      onPayload: this._onPayload,
      transport: this._transport,
      thinkingBudgets: this._thinkingBudgets,
      maxRetryDelayMs: this._maxRetryDelayMs,
      convertToLlm: this.convertToLlm,
      transformContext: this.transformContext,
      getApiKey: this.getApiKey,
      getSteeringMessages: async () => {
        if (skipInitialSteeringPoll) {
          skipInitialSteeringPoll = false;
          return [];
        }
        return this.dequeueSteeringMessages();
      },
      getFollowUpMessages: async () => this.dequeueFollowUpMessages(),
      beforeToolCall: this._beforeToolCall,
      afterToolCall: this._afterToolCall,
      // Evaluated per-loop so mid-session model switches pick the right mode.
      externalToolExecution: this._externalToolExecution?.(model) ?? false,
    };

    // Last streamed-but-unfinished message, if the stream ends mid-message.
    let partial: AgentMessage | null = null;

    try {
      const stream = messages
        ? agentLoop(messages, context, config, this.abortController.signal, this.streamFn)
        : agentLoopContinue(context, config, this.abortController.signal, this.streamFn);

      for await (const event of stream) {
        // Update internal state based on events
        switch (event.type) {
          case "message_start":
          case "message_update":
            partial = event.message;
            this._state.streamMessage = event.message;
            break;

          case "message_end":
            // Message is complete: move it from stream state to history.
            partial = null;
            this._state.streamMessage = null;
            this.appendMessage(event.message);
            break;

          case "tool_execution_start":
            this._updatePendingToolCalls("add", event.toolCallId);
            break;

          case "tool_execution_end":
            this._updatePendingToolCalls("delete", event.toolCallId);
            break;

          case "turn_end":
            // Surface per-turn assistant errors on agent state.
            if (event.message.role === "assistant" && (event.message as any).errorMessage) {
              this._state.error = (event.message as any).errorMessage;
            }
            break;

          case "agent_end":
            this._state.isStreaming = false;
            this._state.streamMessage = null;
            break;
        }

        // Emit to listeners
        this.emit(event);
      }

      // Handle any remaining partial message: keep it if it has any real
      // content; otherwise, if we were aborted, surface the abort as an error.
      if (partial && partial.role === "assistant" && partial.content.length > 0) {
        const onlyEmpty = !partial.content.some(
          (c) =>
            (c.type === "thinking" && c.thinking.trim().length > 0) ||
            (c.type === "text" && c.text.trim().length > 0) ||
            (c.type === "toolCall" && c.name.trim().length > 0),
        );
        if (!onlyEmpty) {
          this.appendMessage(partial);
        } else {
          if (this.abortController?.signal.aborted) {
            // Thrown inside try so the catch below records an "aborted" message.
            throw new Error("Request was aborted");
          }
        }
      }
    } catch (err: any) {
      // Convert the failure into a synthetic assistant message so the error
      // is visible in history, then emit agent_end for listeners.
      const errorMsg: AgentMessage = {
        role: "assistant",
        content: [{ type: "text", text: "" }],
        api: model.api,
        provider: model.provider,
        model: model.id,
        usage: ZERO_USAGE,
        stopReason: this.abortController?.signal.aborted ? "aborted" : "error",
        errorMessage: err?.message || String(err),
        timestamp: Date.now(),
      } as AgentMessage;

      this.appendMessage(errorMsg);
      this._state.error = err?.message || String(err);
      this.emit({ type: "agent_end", messages: [errorMsg] });
    } finally {
      // Always restore idle state, even on abort/error.
      this._state.isStreaming = false;
      this._state.streamMessage = null;
      this._state.pendingToolCalls = new Set<string>();
      this._state.activeInferenceModel = undefined;
      this.abortController = undefined;
      this.resolveRunningPrompt?.();
      this.runningPrompt = undefined;
      this.resolveRunningPrompt = undefined;
    }
  }

  /**
   * Add or remove a pending tool-call id. Replaces the Set so observers
   * comparing references see the change.
   */
  private _updatePendingToolCalls(action: "add" | "delete", id: string): void {
    const s = new Set(this._state.pendingToolCalls);
    s[action](id);
    this._state.pendingToolCalls = s;
  }

  /** Synchronously notify every subscriber of an event. */
  private emit(e: AgentEvent) {
    for (const listener of this.listeners) {
      listener(e);
    }
  }
}
|