diff --git a/packages/pi-agent-core/src/agent.test.ts b/packages/pi-agent-core/src/agent.test.ts index e0b838cd4..4ecd23af2 100644 --- a/packages/pi-agent-core/src/agent.test.ts +++ b/packages/pi-agent-core/src/agent.test.ts @@ -8,6 +8,8 @@ import assert from "node:assert/strict"; import { readFileSync } from "node:fs"; import { join, dirname } from "node:path"; import { fileURLToPath } from "node:url"; +import { Agent } from "./agent.ts"; +import { getModel, type AssistantMessageEventStream } from "@gsd/pi-ai"; const __dirname = dirname(fileURLToPath(import.meta.url)); @@ -50,4 +52,84 @@ describe("Agent — activeInferenceModel (#1844 Bug 2)", () => { assert.ok(setLine < abortLine, "activeInferenceModel must be set before streaming infrastructure is created"); }); + + it("getProviderOptions are forwarded into the provider stream call", async () => { + let capturedOptions: Record | undefined; + const agent = new Agent({ + initialState: { + model: getModel("anthropic", "claude-3-5-sonnet-20241022"), + systemPrompt: "test", + tools: [], + }, + getProviderOptions: async () => ({ customRuntimeOption: "present" }), + streamFn: (_model, _context, options): AssistantMessageEventStream => { + capturedOptions = options as Record | undefined; + return { + async *[Symbol.asyncIterator]() { + yield { + type: "start", + partial: { + role: "assistant", + content: [], + api: "anthropic-messages", + provider: "anthropic", + model: "claude-3-5-sonnet-20241022", + usage: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + stopReason: "stop", + timestamp: Date.now(), + }, + }; + yield { + type: "done", + message: { + role: "assistant", + content: [{ type: "text", text: "ok" }], + api: "anthropic-messages", + provider: "anthropic", + model: "claude-3-5-sonnet-20241022", + usage: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + stopReason: "stop", + timestamp: Date.now(), + }, + }; + }, + result: async () => ({ + role: "assistant", + content: [{ type: "text", text: "ok" }], + api: "anthropic-messages", + provider: "anthropic", + model: "claude-3-5-sonnet-20241022", + usage: { + input: 0, + output: 0, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 0, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + stopReason: "stop", + timestamp: Date.now(), + }), + [Symbol.asyncDispose]: async () => {}, + } as AssistantMessageEventStream; + }, + }); + + await agent.prompt("hello"); + assert.equal(capturedOptions?.customRuntimeOption, "present"); + }); }); diff --git a/packages/pi-agent-core/src/agent.ts b/packages/pi-agent-core/src/agent.ts index e65ae7a35..924dd8d39 100644 --- a/packages/pi-agent-core/src/agent.ts +++ b/packages/pi-agent-core/src/agent.ts @@ -108,6 +108,14 @@ export interface AgentOptions { * switches mid-session are handled correctly. */ externalToolExecution?: (model: Model) => boolean; + + /** + * Optional provider-specific options to merge into the next stream call. + * + * Use this for runtime-only callbacks or handles that should not live in + * shared agent state, such as UI bridges for external CLI providers. + */ + getProviderOptions?: (model: Model) => Record | undefined | Promise | undefined>; } /** @@ -152,6 +160,7 @@ export class Agent { private _beforeToolCall?: AgentLoopConfig["beforeToolCall"]; private _afterToolCall?: AgentLoopConfig["afterToolCall"]; private _externalToolExecution?: (model: Model) => boolean; + private _getProviderOptions?: AgentOptions["getProviderOptions"]; constructor(opts: AgentOptions = {}) { this._state = { ...this._state, ...opts.initialState }; @@ -167,6 +176,7 @@ export class Agent { this._transport = opts.transport ?? "sse"; this._maxRetryDelayMs = opts.maxRetryDelayMs; this._externalToolExecution = opts.externalToolExecution; + this._getProviderOptions = opts.getProviderOptions; } /** @@ -486,8 +496,10 @@ export class Agent { }; let skipInitialSteeringPoll = options?.skipInitialSteeringPoll === true; + const providerOptions = await this._getProviderOptions?.(model); const config: AgentLoopConfig = { + ...(providerOptions ?? {}), model, reasoning, sessionId: this._sessionId, diff --git a/packages/pi-coding-agent/src/core/sdk.ts b/packages/pi-coding-agent/src/core/sdk.ts index a0c2d943b..07ed24c53 100644 --- a/packages/pi-coding-agent/src/core/sdk.ts +++ b/packages/pi-coding-agent/src/core/sdk.ts @@ -341,6 +341,14 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {} thinkingBudgets: settingsManager.getThinkingBudgets(), maxRetryDelayMs: settingsManager.getRetrySettings().maxDelayMs, externalToolExecution: (m) => modelRegistry.getProviderAuthMode(m.provider) === "externalCli", + getProviderOptions: async (currentModel) => { + if (currentModel.provider !== "claude-code") return undefined; + const runner = extensionRunnerRef.current; + if (!runner?.hasUI()) return undefined; + return { + extensionUIContext: runner.getUIContext(), + }; + }, getApiKey: async (provider) => { // Use the provider argument from the in-flight request; // agent.state.model may already be switched mid-turn. diff --git a/src/resources/extensions/claude-code-cli/stream-adapter.ts b/src/resources/extensions/claude-code-cli/stream-adapter.ts index 465d48759..877456a20 100644 --- a/src/resources/extensions/claude-code-cli/stream-adapter.ts +++ b/src/resources/extensions/claude-code-cli/stream-adapter.ts @@ -16,10 +16,12 @@ import type { SimpleStreamOptions, ToolCall, } from "@gsd/pi-ai"; +import type { ExtensionUIContext } from "@gsd/pi-coding-agent"; import { EventStream } from "@gsd/pi-ai"; import { execSync } from "node:child_process"; import { PartialMessageBuilder, ZERO_USAGE, mapUsage } from "./partial-builder.js"; import { buildWorkflowMcpServers } from "../gsd/workflow-mcp.js"; +import { showInterviewRound, type Question, type RoundResult } from "../shared/tui.js"; import type { SDKAssistantMessage, SDKMessage, @@ -45,6 +47,46 @@ type ToolCallWithExternalResult = ToolCall & { externalResult?: ExternalToolResultPayload; }; +interface ClaudeCodeStreamOptions extends SimpleStreamOptions { + extensionUIContext?: ExtensionUIContext; +} + +interface SdkElicitationRequestOption { + const?: string; + title?: string; +} + +interface SdkElicitationFieldSchema { + type?: string; + title?: string; + description?: string; + oneOf?: SdkElicitationRequestOption[]; + items?: { + anyOf?: SdkElicitationRequestOption[]; + }; +} + +interface SdkElicitationRequest { + serverName: string; + message: string; + mode?: "form" | "url"; + requestedSchema?: { + type?: string; + properties?: Record; + }; +} + +interface SdkElicitationResult { + action: "accept" | "decline" | "cancel"; + content?: Record; +} + +interface ParsedElicitationQuestion extends Question { + noteFieldId?: string; +} + +const OTHER_OPTION_LABEL = "None of the above"; + // --------------------------------------------------------------------------- // Stream factory // --------------------------------------------------------------------------- @@ -172,6 +214,174 @@ export function makeStreamExhaustedErrorMessage(model: string, lastTextContent: return message; } +function readElicitationChoices(options: SdkElicitationRequestOption[] | undefined): string[] { + if (!Array.isArray(options)) return []; + return options + .map((option) => (typeof option?.const === "string" ? option.const : typeof option?.title === "string" ? option.title : "")) + .filter((option): option is string => option.length > 0); +} + +export function parseAskUserQuestionsElicitation( + request: Pick, +): ParsedElicitationQuestion[] | null { + if (request.mode && request.mode !== "form") return null; + const properties = request.requestedSchema?.properties; + if (!properties || typeof properties !== "object") return null; + + const questions: ParsedElicitationQuestion[] = []; + + for (const [fieldId, rawField] of Object.entries(properties)) { + if (fieldId.endsWith("__note")) continue; + if (!rawField || typeof rawField !== "object") return null; + + const header = typeof rawField.title === "string" && rawField.title.length > 0 ? rawField.title : fieldId; + const question = typeof rawField.description === "string" ? rawField.description : ""; + + if (rawField.type === "array") { + const options = readElicitationChoices(rawField.items?.anyOf).map((label) => ({ label, description: "" })); + if (options.length === 0) return null; + questions.push({ + id: fieldId, + header, + question, + options, + allowMultiple: true, + }); + continue; + } + + if (rawField.type === "string") { + const noteFieldId = Object.prototype.hasOwnProperty.call(properties, `${fieldId}__note`) + ? `${fieldId}__note` + : undefined; + const options = readElicitationChoices(rawField.oneOf) + .filter((label) => label !== OTHER_OPTION_LABEL) + .map((label) => ({ label, description: "" })); + if (options.length === 0) return null; + questions.push({ + id: fieldId, + header, + question, + options, + noteFieldId, + }); + continue; + } + + return null; + } + + return questions.length > 0 ? questions : null; +} + +export function roundResultToElicitationContent( + questions: ParsedElicitationQuestion[], + result: RoundResult, +): Record { + const content: Record = {}; + + for (const question of questions) { + const answer = result.answers[question.id]; + if (!answer) continue; + + if (question.allowMultiple) { + const selected = Array.isArray(answer.selected) ? answer.selected : [answer.selected]; + content[question.id] = selected; + continue; + } + + const selected = Array.isArray(answer.selected) ? answer.selected[0] ?? "" : answer.selected; + content[question.id] = selected; + if (question.noteFieldId && selected === OTHER_OPTION_LABEL && answer.notes.trim().length > 0) { + content[question.noteFieldId] = answer.notes.trim(); + } + } + + return content; +} + +function buildElicitationPromptTitle(request: SdkElicitationRequest, question: ParsedElicitationQuestion): string { + const parts = [ + request.serverName ? `[${request.serverName}]` : "", + question.header, + question.question, + ].filter((part) => part && part.trim().length > 0); + return parts.join("\n\n"); +} + +async function promptElicitationWithDialogs( + request: SdkElicitationRequest, + questions: ParsedElicitationQuestion[], + ui: ExtensionUIContext, + signal: AbortSignal, +): Promise { + const content: Record = {}; + + for (const question of questions) { + const title = buildElicitationPromptTitle(request, question); + + if (question.allowMultiple) { + const selected = await ui.select(title, question.options.map((option) => option.label), { + allowMultiple: true, + signal, + }); + if (Array.isArray(selected)) { + if (selected.length === 0) return { action: "cancel" }; + content[question.id] = selected; + continue; + } + if (typeof selected === "string" && selected.length > 0) { + content[question.id] = [selected]; + continue; + } + return { action: "cancel" }; + } + + const selected = await ui.select(title, [...question.options.map((option) => option.label), OTHER_OPTION_LABEL], { signal }); + if (typeof selected !== "string" || selected.length === 0) { + return { action: "cancel" }; + } + + content[question.id] = selected; + if (question.noteFieldId && selected === OTHER_OPTION_LABEL) { + const note = await ui.input(`${question.header} note`, "Explain your answer", { signal }); + if (note === undefined) return { action: "cancel" }; + if (note.trim().length > 0) { + content[question.noteFieldId] = note.trim(); + } + } + } + + return { action: "accept", content }; +} + +export function createClaudeCodeElicitationHandler( + ui: ExtensionUIContext | undefined, +): ((request: SdkElicitationRequest, options: { signal: AbortSignal }) => Promise) | undefined { + if (!ui) return undefined; + + return async (request, { signal }) => { + if (request.mode === "url") { + return { action: "decline" }; + } + + const questions = parseAskUserQuestionsElicitation(request); + if (!questions) { + return { action: "decline" }; + } + + const interviewResult = await showInterviewRound(questions, { signal }, { ui } as any).catch(() => undefined); + if (interviewResult && Object.keys(interviewResult.answers).length > 0) { + return { + action: "accept", + content: roundResultToElicitationContent(questions, interviewResult), + }; + } + + return promptElicitationWithDialogs(request, questions, ui, signal); + }; +} + // --------------------------------------------------------------------------- // SDK options builder // --------------------------------------------------------------------------- @@ -182,7 +392,11 @@ export function makeStreamExhaustedErrorMessage(model: string, lastTextContent: * Extracted for testability — callers can verify session persistence, * beta flags, and other configuration without mocking the full SDK. */ -export function buildSdkOptions(modelId: string, prompt: string): Record { +export function buildSdkOptions( + modelId: string, + prompt: string, + extraOptions: Record = {}, +): Record { const mcpServers = buildWorkflowMcpServers(); return { pathToClaudeCodeExecutable: getClaudePath(), @@ -196,6 +410,7 @@ export function buildSdkOptions(modelId: string, prompt: string): Record { process.env.GSD_CLI_PATH = prev.GSD_CLI_PATH; } }); + + test("buildSdkOptions preserves runtime callbacks such as onElicitation", () => { + const prev = { + GSD_WORKFLOW_MCP_COMMAND: process.env.GSD_WORKFLOW_MCP_COMMAND, + GSD_WORKFLOW_MCP_NAME: process.env.GSD_WORKFLOW_MCP_NAME, + GSD_WORKFLOW_MCP_ARGS: process.env.GSD_WORKFLOW_MCP_ARGS, + GSD_WORKFLOW_MCP_ENV: process.env.GSD_WORKFLOW_MCP_ENV, + GSD_WORKFLOW_MCP_CWD: process.env.GSD_WORKFLOW_MCP_CWD, + }; + const onElicitation = async () => ({ action: "decline" as const }); + try { + delete process.env.GSD_WORKFLOW_MCP_COMMAND; + delete process.env.GSD_WORKFLOW_MCP_NAME; + delete process.env.GSD_WORKFLOW_MCP_ARGS; + delete process.env.GSD_WORKFLOW_MCP_ENV; + delete process.env.GSD_WORKFLOW_MCP_CWD; + const options = buildSdkOptions("claude-sonnet-4-20250514", "test", { onElicitation }); + assert.equal(options.onElicitation, onElicitation); + } finally { + process.env.GSD_WORKFLOW_MCP_COMMAND = prev.GSD_WORKFLOW_MCP_COMMAND; + process.env.GSD_WORKFLOW_MCP_NAME = prev.GSD_WORKFLOW_MCP_NAME; + process.env.GSD_WORKFLOW_MCP_ARGS = prev.GSD_WORKFLOW_MCP_ARGS; + process.env.GSD_WORKFLOW_MCP_ENV = prev.GSD_WORKFLOW_MCP_ENV; + process.env.GSD_WORKFLOW_MCP_CWD = prev.GSD_WORKFLOW_MCP_CWD; + } + }); +}); + +describe("stream-adapter — MCP elicitation bridge", () => { + const askUserQuestionsRequest = { + serverName: "gsd-workflow", + message: "Please answer the following question(s).", + mode: "form" as const, + requestedSchema: { + type: "object" as const, + properties: { + storage_scope: { + type: "string", + title: "Storage", + description: "Does this app need to sync across devices?", + oneOf: [ + { const: "Local-only (Recommended)", title: "Local-only (Recommended)" }, + { const: "Cloud-synced", title: "Cloud-synced" }, + { const: "None of the above", title: "None of the above" }, + ], + }, + storage_scope__note: { + type: "string", + title: "Storage Note", + description: "Optional note for None of the above.", + }, + platform: { + type: "array", + title: "Platform", + description: "Where should it run?", + items: { + anyOf: [ + { const: "Web", title: "Web" }, + { const: "Desktop", title: "Desktop" }, + { const: "Mobile", title: "Mobile" }, + ], + }, + }, + }, + }, + }; + + test("parseAskUserQuestionsElicitation rebuilds interview questions from the MCP schema", () => { + const questions = parseAskUserQuestionsElicitation(askUserQuestionsRequest); + assert.deepEqual(questions, [ + { + id: "storage_scope", + header: "Storage", + question: "Does this app need to sync across devices?", + options: [ + { label: "Local-only (Recommended)", description: "" }, + { label: "Cloud-synced", description: "" }, + ], + noteFieldId: "storage_scope__note", + }, + { + id: "platform", + header: "Platform", + question: "Where should it run?", + options: [ + { label: "Web", description: "" }, + { label: "Desktop", description: "" }, + { label: "Mobile", description: "" }, + ], + allowMultiple: true, + }, + ]); + }); + + test("roundResultToElicitationContent preserves notes for None of the above", () => { + const questions = parseAskUserQuestionsElicitation(askUserQuestionsRequest); + assert.ok(questions); + + const content = roundResultToElicitationContent(questions, { + endInterview: false, + answers: { + storage_scope: { + selected: "None of the above", + notes: "Needs selective sync later", + }, + platform: { + selected: ["Web", "Desktop"], + notes: "", + }, + }, + }); + + assert.deepEqual(content, { + storage_scope: "None of the above", + storage_scope__note: "Needs selective sync later", + platform: ["Web", "Desktop"], + }); + }); + + test("createClaudeCodeElicitationHandler accepts interview-style answers from custom UI", async () => { + const handler = createClaudeCodeElicitationHandler({ + custom: async (_factory: any) => ({ + endInterview: false, + answers: { + storage_scope: { + selected: "Cloud-synced", + notes: "", + }, + platform: { + selected: ["Web", "Mobile"], + notes: "", + }, + }, + }), + } as any); + + assert.ok(handler); + const result = await handler!(askUserQuestionsRequest, { signal: new AbortController().signal }); + assert.deepEqual(result, { + action: "accept", + content: { + storage_scope: "Cloud-synced", + platform: ["Web", "Mobile"], + }, + }); + }); + + test("createClaudeCodeElicitationHandler falls back to dialog prompts when custom UI is unavailable", async () => { + const ui = { + custom: async () => undefined, + select: async (_title: string, options: string[], opts?: { allowMultiple?: boolean }) => { + if (opts?.allowMultiple) return ["Desktop", "Mobile"]; + return options.includes("None of the above") ? "None of the above" : options[0]; + }, + input: async () => "CLI-only deployment target", + }; + const handler = createClaudeCodeElicitationHandler(ui as any); + assert.ok(handler); + + const result = await handler!(askUserQuestionsRequest, { signal: new AbortController().signal }); + assert.deepEqual(result, { + action: "accept", + content: { + storage_scope: "None of the above", + storage_scope__note: "CLI-only deployment target", + platform: ["Desktop", "Mobile"], + }, + }); + }); }); describe("stream-adapter — Windows Claude path lookup (#3770)", () => {