fix claude code mcp elicitation bridge

This commit is contained in:
Jeremy 2026-04-10 19:24:51 -05:00
parent 5a940856c1
commit d64056f833
5 changed files with 501 additions and 2 deletions

View file

@ -8,6 +8,8 @@ import assert from "node:assert/strict";
import { readFileSync } from "node:fs";
import { join, dirname } from "node:path";
import { fileURLToPath } from "node:url";
import { Agent } from "./agent.ts";
import { getModel, type AssistantMessageEventStream } from "@gsd/pi-ai";
const __dirname = dirname(fileURLToPath(import.meta.url));
@ -50,4 +52,84 @@ describe("Agent — activeInferenceModel (#1844 Bug 2)", () => {
assert.ok(setLine < abortLine,
"activeInferenceModel must be set before streaming infrastructure is created");
});
it("getProviderOptions are forwarded into the provider stream call", async () => {
let capturedOptions: Record<string, unknown> | undefined;
const agent = new Agent({
initialState: {
model: getModel("anthropic", "claude-3-5-sonnet-20241022"),
systemPrompt: "test",
tools: [],
},
getProviderOptions: async () => ({ customRuntimeOption: "present" }),
streamFn: (_model, _context, options): AssistantMessageEventStream => {
capturedOptions = options as Record<string, unknown> | undefined;
return {
async *[Symbol.asyncIterator]() {
yield {
type: "start",
partial: {
role: "assistant",
content: [],
api: "anthropic-messages",
provider: "anthropic",
model: "claude-3-5-sonnet-20241022",
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop",
timestamp: Date.now(),
},
};
yield {
type: "done",
message: {
role: "assistant",
content: [{ type: "text", text: "ok" }],
api: "anthropic-messages",
provider: "anthropic",
model: "claude-3-5-sonnet-20241022",
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop",
timestamp: Date.now(),
},
};
},
result: async () => ({
role: "assistant",
content: [{ type: "text", text: "ok" }],
api: "anthropic-messages",
provider: "anthropic",
model: "claude-3-5-sonnet-20241022",
usage: {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
},
stopReason: "stop",
timestamp: Date.now(),
}),
[Symbol.asyncDispose]: async () => {},
} as AssistantMessageEventStream;
},
});
await agent.prompt("hello");
assert.equal(capturedOptions?.customRuntimeOption, "present");
});
});

View file

@ -108,6 +108,14 @@ export interface AgentOptions {
* switches mid-session are handled correctly.
*/
externalToolExecution?: (model: Model<any>) => boolean;
/**
* Optional provider-specific options to merge into the next stream call.
*
* Use this for runtime-only callbacks or handles that should not live in
* shared agent state, such as UI bridges for external CLI providers.
*/
getProviderOptions?: (model: Model<any>) => Record<string, unknown> | undefined | Promise<Record<string, unknown> | undefined>;
}
/**
@ -152,6 +160,7 @@ export class Agent {
private _beforeToolCall?: AgentLoopConfig["beforeToolCall"];
private _afterToolCall?: AgentLoopConfig["afterToolCall"];
private _externalToolExecution?: (model: Model<any>) => boolean;
private _getProviderOptions?: AgentOptions["getProviderOptions"];
constructor(opts: AgentOptions = {}) {
this._state = { ...this._state, ...opts.initialState };
@ -167,6 +176,7 @@ export class Agent {
this._transport = opts.transport ?? "sse";
this._maxRetryDelayMs = opts.maxRetryDelayMs;
this._externalToolExecution = opts.externalToolExecution;
this._getProviderOptions = opts.getProviderOptions;
}
/**
@ -486,8 +496,10 @@ export class Agent {
};
let skipInitialSteeringPoll = options?.skipInitialSteeringPoll === true;
const providerOptions = await this._getProviderOptions?.(model);
const config: AgentLoopConfig = {
...(providerOptions ?? {}),
model,
reasoning,
sessionId: this._sessionId,

View file

@ -341,6 +341,14 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
thinkingBudgets: settingsManager.getThinkingBudgets(),
maxRetryDelayMs: settingsManager.getRetrySettings().maxDelayMs,
externalToolExecution: (m) => modelRegistry.getProviderAuthMode(m.provider) === "externalCli",
getProviderOptions: async (currentModel) => {
if (currentModel.provider !== "claude-code") return undefined;
const runner = extensionRunnerRef.current;
if (!runner?.hasUI()) return undefined;
return {
extensionUIContext: runner.getUIContext(),
};
},
getApiKey: async (provider) => {
// Use the provider argument from the in-flight request;
// agent.state.model may already be switched mid-turn.

View file

@ -16,10 +16,12 @@ import type {
SimpleStreamOptions,
ToolCall,
} from "@gsd/pi-ai";
import type { ExtensionUIContext } from "@gsd/pi-coding-agent";
import { EventStream } from "@gsd/pi-ai";
import { execSync } from "node:child_process";
import { PartialMessageBuilder, ZERO_USAGE, mapUsage } from "./partial-builder.js";
import { buildWorkflowMcpServers } from "../gsd/workflow-mcp.js";
import { showInterviewRound, type Question, type RoundResult } from "../shared/tui.js";
import type {
SDKAssistantMessage,
SDKMessage,
@ -45,6 +47,46 @@ type ToolCallWithExternalResult = ToolCall & {
externalResult?: ExternalToolResultPayload;
};
interface ClaudeCodeStreamOptions extends SimpleStreamOptions {
extensionUIContext?: ExtensionUIContext;
}
interface SdkElicitationRequestOption {
const?: string;
title?: string;
}
interface SdkElicitationFieldSchema {
type?: string;
title?: string;
description?: string;
oneOf?: SdkElicitationRequestOption[];
items?: {
anyOf?: SdkElicitationRequestOption[];
};
}
interface SdkElicitationRequest {
serverName: string;
message: string;
mode?: "form" | "url";
requestedSchema?: {
type?: string;
properties?: Record<string, SdkElicitationFieldSchema>;
};
}
interface SdkElicitationResult {
action: "accept" | "decline" | "cancel";
content?: Record<string, string | string[]>;
}
interface ParsedElicitationQuestion extends Question {
noteFieldId?: string;
}
const OTHER_OPTION_LABEL = "None of the above";
// ---------------------------------------------------------------------------
// Stream factory
// ---------------------------------------------------------------------------
@ -172,6 +214,174 @@ export function makeStreamExhaustedErrorMessage(model: string, lastTextContent:
return message;
}
function readElicitationChoices(options: SdkElicitationRequestOption[] | undefined): string[] {
if (!Array.isArray(options)) return [];
return options
.map((option) => (typeof option?.const === "string" ? option.const : typeof option?.title === "string" ? option.title : ""))
.filter((option): option is string => option.length > 0);
}
export function parseAskUserQuestionsElicitation(
request: Pick<SdkElicitationRequest, "mode" | "requestedSchema">,
): ParsedElicitationQuestion[] | null {
if (request.mode && request.mode !== "form") return null;
const properties = request.requestedSchema?.properties;
if (!properties || typeof properties !== "object") return null;
const questions: ParsedElicitationQuestion[] = [];
for (const [fieldId, rawField] of Object.entries(properties)) {
if (fieldId.endsWith("__note")) continue;
if (!rawField || typeof rawField !== "object") return null;
const header = typeof rawField.title === "string" && rawField.title.length > 0 ? rawField.title : fieldId;
const question = typeof rawField.description === "string" ? rawField.description : "";
if (rawField.type === "array") {
const options = readElicitationChoices(rawField.items?.anyOf).map((label) => ({ label, description: "" }));
if (options.length === 0) return null;
questions.push({
id: fieldId,
header,
question,
options,
allowMultiple: true,
});
continue;
}
if (rawField.type === "string") {
const noteFieldId = Object.prototype.hasOwnProperty.call(properties, `${fieldId}__note`)
? `${fieldId}__note`
: undefined;
const options = readElicitationChoices(rawField.oneOf)
.filter((label) => label !== OTHER_OPTION_LABEL)
.map((label) => ({ label, description: "" }));
if (options.length === 0) return null;
questions.push({
id: fieldId,
header,
question,
options,
noteFieldId,
});
continue;
}
return null;
}
return questions.length > 0 ? questions : null;
}
export function roundResultToElicitationContent(
questions: ParsedElicitationQuestion[],
result: RoundResult,
): Record<string, string | string[]> {
const content: Record<string, string | string[]> = {};
for (const question of questions) {
const answer = result.answers[question.id];
if (!answer) continue;
if (question.allowMultiple) {
const selected = Array.isArray(answer.selected) ? answer.selected : [answer.selected];
content[question.id] = selected;
continue;
}
const selected = Array.isArray(answer.selected) ? answer.selected[0] ?? "" : answer.selected;
content[question.id] = selected;
if (question.noteFieldId && selected === OTHER_OPTION_LABEL && answer.notes.trim().length > 0) {
content[question.noteFieldId] = answer.notes.trim();
}
}
return content;
}
function buildElicitationPromptTitle(request: SdkElicitationRequest, question: ParsedElicitationQuestion): string {
const parts = [
request.serverName ? `[${request.serverName}]` : "",
question.header,
question.question,
].filter((part) => part && part.trim().length > 0);
return parts.join("\n\n");
}
async function promptElicitationWithDialogs(
request: SdkElicitationRequest,
questions: ParsedElicitationQuestion[],
ui: ExtensionUIContext,
signal: AbortSignal,
): Promise<SdkElicitationResult> {
const content: Record<string, string | string[]> = {};
for (const question of questions) {
const title = buildElicitationPromptTitle(request, question);
if (question.allowMultiple) {
const selected = await ui.select(title, question.options.map((option) => option.label), {
allowMultiple: true,
signal,
});
if (Array.isArray(selected)) {
if (selected.length === 0) return { action: "cancel" };
content[question.id] = selected;
continue;
}
if (typeof selected === "string" && selected.length > 0) {
content[question.id] = [selected];
continue;
}
return { action: "cancel" };
}
const selected = await ui.select(title, [...question.options.map((option) => option.label), OTHER_OPTION_LABEL], { signal });
if (typeof selected !== "string" || selected.length === 0) {
return { action: "cancel" };
}
content[question.id] = selected;
if (question.noteFieldId && selected === OTHER_OPTION_LABEL) {
const note = await ui.input(`${question.header} note`, "Explain your answer", { signal });
if (note === undefined) return { action: "cancel" };
if (note.trim().length > 0) {
content[question.noteFieldId] = note.trim();
}
}
}
return { action: "accept", content };
}
export function createClaudeCodeElicitationHandler(
ui: ExtensionUIContext | undefined,
): ((request: SdkElicitationRequest, options: { signal: AbortSignal }) => Promise<SdkElicitationResult>) | undefined {
if (!ui) return undefined;
return async (request, { signal }) => {
if (request.mode === "url") {
return { action: "decline" };
}
const questions = parseAskUserQuestionsElicitation(request);
if (!questions) {
return { action: "decline" };
}
const interviewResult = await showInterviewRound(questions, { signal }, { ui } as any).catch(() => undefined);
if (interviewResult && Object.keys(interviewResult.answers).length > 0) {
return {
action: "accept",
content: roundResultToElicitationContent(questions, interviewResult),
};
}
return promptElicitationWithDialogs(request, questions, ui, signal);
};
}
// ---------------------------------------------------------------------------
// SDK options builder
// ---------------------------------------------------------------------------
@ -182,7 +392,11 @@ export function makeStreamExhaustedErrorMessage(model: string, lastTextContent:
* Extracted for testability callers can verify session persistence,
* beta flags, and other configuration without mocking the full SDK.
*/
export function buildSdkOptions(modelId: string, prompt: string): Record<string, unknown> {
export function buildSdkOptions(
modelId: string,
prompt: string,
extraOptions: Record<string, unknown> = {},
): Record<string, unknown> {
const mcpServers = buildWorkflowMcpServers();
return {
pathToClaudeCodeExecutable: getClaudePath(),
@ -196,6 +410,7 @@ export function buildSdkOptions(modelId: string, prompt: string): Record<string,
systemPrompt: { type: "preset", preset: "claude_code" },
...(mcpServers ? { mcpServers } : {}),
betas: modelId.includes("sonnet") ? ["context-1m-2025-08-07"] : [],
...extraOptions,
};
}
@ -359,7 +574,17 @@ async function pumpSdkMessages(
}
const prompt = buildPromptFromContext(context);
const sdkOpts = buildSdkOptions(modelId, prompt);
const sdkOpts = buildSdkOptions(
modelId,
prompt,
typeof (options as ClaudeCodeStreamOptions | undefined)?.extensionUIContext === "object"
? {
onElicitation: createClaudeCodeElicitationHandler(
(options as ClaudeCodeStreamOptions | undefined)?.extensionUIContext,
),
}
: {},
);
const queryResult = sdk.query({
prompt,

View file

@ -7,9 +7,12 @@ import {
makeStreamExhaustedErrorMessage,
buildPromptFromContext,
buildSdkOptions,
createClaudeCodeElicitationHandler,
extractToolResultsFromSdkUserMessage,
getClaudeLookupCommand,
parseAskUserQuestionsElicitation,
parseClaudeLookupOutput,
roundResultToElicitationContent,
} from "../stream-adapter.ts";
import type { Context, Message } from "@gsd/pi-ai";
import type { SDKUserMessage } from "../sdk-types.ts";
@ -309,6 +312,175 @@ describe("stream-adapter — session persistence (#2859)", () => {
process.env.GSD_CLI_PATH = prev.GSD_CLI_PATH;
}
});
test("buildSdkOptions preserves runtime callbacks such as onElicitation", () => {
const prev = {
GSD_WORKFLOW_MCP_COMMAND: process.env.GSD_WORKFLOW_MCP_COMMAND,
GSD_WORKFLOW_MCP_NAME: process.env.GSD_WORKFLOW_MCP_NAME,
GSD_WORKFLOW_MCP_ARGS: process.env.GSD_WORKFLOW_MCP_ARGS,
GSD_WORKFLOW_MCP_ENV: process.env.GSD_WORKFLOW_MCP_ENV,
GSD_WORKFLOW_MCP_CWD: process.env.GSD_WORKFLOW_MCP_CWD,
};
const onElicitation = async () => ({ action: "decline" as const });
try {
delete process.env.GSD_WORKFLOW_MCP_COMMAND;
delete process.env.GSD_WORKFLOW_MCP_NAME;
delete process.env.GSD_WORKFLOW_MCP_ARGS;
delete process.env.GSD_WORKFLOW_MCP_ENV;
delete process.env.GSD_WORKFLOW_MCP_CWD;
const options = buildSdkOptions("claude-sonnet-4-20250514", "test", { onElicitation });
assert.equal(options.onElicitation, onElicitation);
} finally {
process.env.GSD_WORKFLOW_MCP_COMMAND = prev.GSD_WORKFLOW_MCP_COMMAND;
process.env.GSD_WORKFLOW_MCP_NAME = prev.GSD_WORKFLOW_MCP_NAME;
process.env.GSD_WORKFLOW_MCP_ARGS = prev.GSD_WORKFLOW_MCP_ARGS;
process.env.GSD_WORKFLOW_MCP_ENV = prev.GSD_WORKFLOW_MCP_ENV;
process.env.GSD_WORKFLOW_MCP_CWD = prev.GSD_WORKFLOW_MCP_CWD;
}
});
});
describe("stream-adapter — MCP elicitation bridge", () => {
const askUserQuestionsRequest = {
serverName: "gsd-workflow",
message: "Please answer the following question(s).",
mode: "form" as const,
requestedSchema: {
type: "object" as const,
properties: {
storage_scope: {
type: "string",
title: "Storage",
description: "Does this app need to sync across devices?",
oneOf: [
{ const: "Local-only (Recommended)", title: "Local-only (Recommended)" },
{ const: "Cloud-synced", title: "Cloud-synced" },
{ const: "None of the above", title: "None of the above" },
],
},
storage_scope__note: {
type: "string",
title: "Storage Note",
description: "Optional note for None of the above.",
},
platform: {
type: "array",
title: "Platform",
description: "Where should it run?",
items: {
anyOf: [
{ const: "Web", title: "Web" },
{ const: "Desktop", title: "Desktop" },
{ const: "Mobile", title: "Mobile" },
],
},
},
},
},
};
test("parseAskUserQuestionsElicitation rebuilds interview questions from the MCP schema", () => {
const questions = parseAskUserQuestionsElicitation(askUserQuestionsRequest);
assert.deepEqual(questions, [
{
id: "storage_scope",
header: "Storage",
question: "Does this app need to sync across devices?",
options: [
{ label: "Local-only (Recommended)", description: "" },
{ label: "Cloud-synced", description: "" },
],
noteFieldId: "storage_scope__note",
},
{
id: "platform",
header: "Platform",
question: "Where should it run?",
options: [
{ label: "Web", description: "" },
{ label: "Desktop", description: "" },
{ label: "Mobile", description: "" },
],
allowMultiple: true,
},
]);
});
test("roundResultToElicitationContent preserves notes for None of the above", () => {
const questions = parseAskUserQuestionsElicitation(askUserQuestionsRequest);
assert.ok(questions);
const content = roundResultToElicitationContent(questions, {
endInterview: false,
answers: {
storage_scope: {
selected: "None of the above",
notes: "Needs selective sync later",
},
platform: {
selected: ["Web", "Desktop"],
notes: "",
},
},
});
assert.deepEqual(content, {
storage_scope: "None of the above",
storage_scope__note: "Needs selective sync later",
platform: ["Web", "Desktop"],
});
});
test("createClaudeCodeElicitationHandler accepts interview-style answers from custom UI", async () => {
const handler = createClaudeCodeElicitationHandler({
custom: async (_factory: any) => ({
endInterview: false,
answers: {
storage_scope: {
selected: "Cloud-synced",
notes: "",
},
platform: {
selected: ["Web", "Mobile"],
notes: "",
},
},
}),
} as any);
assert.ok(handler);
const result = await handler!(askUserQuestionsRequest, { signal: new AbortController().signal });
assert.deepEqual(result, {
action: "accept",
content: {
storage_scope: "Cloud-synced",
platform: ["Web", "Mobile"],
},
});
});
test("createClaudeCodeElicitationHandler falls back to dialog prompts when custom UI is unavailable", async () => {
const ui = {
custom: async () => undefined,
select: async (_title: string, options: string[], opts?: { allowMultiple?: boolean }) => {
if (opts?.allowMultiple) return ["Desktop", "Mobile"];
return options.includes("None of the above") ? "None of the above" : options[0];
},
input: async () => "CLI-only deployment target",
};
const handler = createClaudeCodeElicitationHandler(ui as any);
assert.ok(handler);
const result = await handler!(askUserQuestionsRequest, { signal: new AbortController().signal });
assert.deepEqual(result, {
action: "accept",
content: {
storage_scope: "None of the above",
storage_scope__note: "CLI-only deployment target",
platform: ["Desktop", "Mobile"],
},
});
});
});
describe("stream-adapter — Windows Claude path lookup (#3770)", () => {