sf snapshot: pre-dispatch, uncommitted changes after 1497m inactivity
This commit is contained in:
parent
6384c5b44c
commit
f712c339b3
1252 changed files with 177 additions and 425192 deletions
|
|
@ -60,21 +60,59 @@ function copyNonTsFiles(srcDir, destDir) {
|
|||
|
||||
rmSync(distResources, { recursive: true, force: true });
|
||||
|
||||
const tscBin = require.resolve("typescript/bin/tsc");
|
||||
const compile = spawnSync(
|
||||
process.execPath,
|
||||
[tscBin, "--project", resourcesTsconfig],
|
||||
{
|
||||
cwd: root,
|
||||
stdio: "inherit",
|
||||
},
|
||||
);
|
||||
// Check if there are any .ts files to compile
|
||||
function hasTsFilesRecursive(dir) {
|
||||
for (const entry of readdirSync(dir, { withFileTypes: true })) {
|
||||
const fullPath = join(dir, entry.name);
|
||||
if (entry.isDirectory()) {
|
||||
if (hasTsFilesRecursive(fullPath)) return true;
|
||||
} else if (entry.name.endsWith(".ts") && !entry.name.endsWith(".d.ts")) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
const hasTsFiles = hasTsFilesRecursive(srcResources);
|
||||
|
||||
if (compile.status !== 0) {
|
||||
process.exit(compile.status ?? 1);
|
||||
if (hasTsFiles) {
|
||||
const tscBin = require.resolve("typescript/bin/tsc");
|
||||
const compile = spawnSync(
|
||||
process.execPath,
|
||||
[tscBin, "--project", resourcesTsconfig],
|
||||
{
|
||||
cwd: root,
|
||||
stdio: "inherit",
|
||||
},
|
||||
);
|
||||
|
||||
if (compile.status !== 0) {
|
||||
process.exit(compile.status ?? 1);
|
||||
}
|
||||
} else {
|
||||
// No .ts files — just create the dist/resources directory and copy .js files
|
||||
mkdirSync(distResources, { recursive: true });
|
||||
}
|
||||
|
||||
copyNonTsFiles(srcResources, distResources);
|
||||
|
||||
// Also copy .js files since they're not compiled from .ts
|
||||
function copyJsFiles(srcDir, destDir) {
|
||||
for (const entry of readdirSync(srcDir, { withFileTypes: true })) {
|
||||
const srcPath = join(srcDir, entry.name);
|
||||
const destPath = join(destDir, entry.name);
|
||||
|
||||
if (entry.isDirectory()) {
|
||||
copyJsFiles(srcPath, destPath);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (entry.name.endsWith(".js")) {
|
||||
mkdirSync(dirname(destPath), { recursive: true });
|
||||
copyFileSync(srcPath, destPath);
|
||||
}
|
||||
}
|
||||
}
|
||||
copyJsFiles(srcResources, distResources);
|
||||
writeFileSync(
|
||||
join(distResources, ".sf-resource-build-stamp"),
|
||||
`${new Date().toISOString()}\n`,
|
||||
|
|
|
|||
|
|
@ -1,637 +0,0 @@
|
|||
/**
|
||||
* Request User Input — LLM tool for asking the user questions
|
||||
*
|
||||
* Thin wrapper around the shared interview-ui. The LLM presents 1-3
|
||||
* questions with 2-3 options each. Each question can be single-select (default)
|
||||
* or multi-select (allowMultiple: true). A free-form "None of the above" option
|
||||
* is added automatically to single-select questions.
|
||||
*
|
||||
* Based on: https://github.com/openai/codex (codex-rs/core/src/tools/handlers/ask_user_questions.rs)
|
||||
*/
|
||||
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import {
|
||||
formatRoundResultForTool,
|
||||
type RoundResult,
|
||||
} from "@singularity-forge/pi-agent-core";
|
||||
import type {
|
||||
ExtensionAPI,
|
||||
ExtensionCommandContext,
|
||||
} from "@singularity-forge/pi-coding-agent";
|
||||
import { Text } from "@singularity-forge/pi-tui";
|
||||
import { sanitizeError } from "./shared/sanitize.js";
|
||||
import {
|
||||
type Question,
|
||||
type QuestionOption,
|
||||
showInterviewRound,
|
||||
} from "./shared/tui.js";
|
||||
|
||||
// ─── Types ────────────────────────────────────────────────────────────────────
|
||||
|
||||
interface LocalResultDetails {
|
||||
remote?: false;
|
||||
questions: Question[];
|
||||
response: RoundResult | null;
|
||||
cancelled: boolean;
|
||||
}
|
||||
|
||||
interface RemoteResultDetails {
|
||||
remote: true;
|
||||
channel: string;
|
||||
timed_out: boolean;
|
||||
promptId?: string;
|
||||
threadUrl?: string;
|
||||
status?: string;
|
||||
autoResolved?: boolean;
|
||||
autoResolveStrategy?: string;
|
||||
questions?: Question[];
|
||||
response?: RoundResult;
|
||||
error?: boolean;
|
||||
}
|
||||
|
||||
type AskUserQuestionsDetails = LocalResultDetails | RemoteResultDetails;
|
||||
|
||||
// ─── Schema ───────────────────────────────────────────────────────────────────
|
||||
|
||||
const OptionSchema = Type.Object({
|
||||
label: Type.String({ description: "User-facing label (1-5 words)" }),
|
||||
description: Type.String({
|
||||
description: "One short sentence explaining impact/tradeoff if selected",
|
||||
}),
|
||||
});
|
||||
|
||||
const QuestionSchema = Type.Object({
|
||||
id: Type.String({
|
||||
description: "Stable identifier for mapping answers (snake_case)",
|
||||
}),
|
||||
header: Type.String({
|
||||
description: "Short header label shown in the UI (12 or fewer chars)",
|
||||
}),
|
||||
question: Type.String({
|
||||
description: "Single-sentence prompt shown to the user",
|
||||
}),
|
||||
options: Type.Array(OptionSchema, {
|
||||
description:
|
||||
'Provide 2-3 mutually exclusive choices for single-select, or any number for multi-select. Put the recommended option first and suffix its label with "(Recommended)". Do not include an "Other" option for single-select; the client adds a free-form "None of the above" option automatically.',
|
||||
}),
|
||||
allowMultiple: Type.Optional(
|
||||
Type.Boolean({
|
||||
description:
|
||||
"If true, the user can select multiple options using SPACE to toggle and ENTER to confirm. No 'None of the above' option is added. Default: false.",
|
||||
}),
|
||||
),
|
||||
});
|
||||
|
||||
const AskUserQuestionsParams = Type.Object({
|
||||
questions: Type.Array(QuestionSchema, {
|
||||
description: "Questions to show the user. Prefer 1 and do not exceed 3.",
|
||||
}),
|
||||
});
|
||||
|
||||
// ─── Per-turn deduplication ──────────────────────────────────────────────────
|
||||
// Prevents duplicate question dispatches (especially to remote channels like
|
||||
// Discord) when the LLM calls ask_user_questions multiple times with the same
|
||||
// questions in a single turn. Keyed by full canonicalized payload (id, header,
|
||||
// question, options, allowMultiple) — not just IDs — so that calls with the
|
||||
// same IDs but different text/options are treated as distinct.
|
||||
|
||||
import { createHash } from "node:crypto";
|
||||
|
||||
interface CachedResult {
|
||||
content: { type: "text"; text: string }[];
|
||||
details: AskUserQuestionsDetails;
|
||||
}
|
||||
|
||||
const turnCache = new Map<string, CachedResult>();
|
||||
|
||||
/** @internal Exported for testing only. */
|
||||
export function questionSignature(questions: Question[]): string {
|
||||
const canonical = questions
|
||||
.map((q) => ({
|
||||
id: q.id,
|
||||
header: q.header,
|
||||
question: q.question,
|
||||
options: (q.options || []).map((o) => ({
|
||||
label: o.label,
|
||||
description: o.description,
|
||||
})),
|
||||
allowMultiple: !!q.allowMultiple,
|
||||
}))
|
||||
.sort((a, b) => a.id.localeCompare(b.id));
|
||||
return createHash("sha256")
|
||||
.update(JSON.stringify(canonical))
|
||||
.digest("hex")
|
||||
.slice(0, 16);
|
||||
}
|
||||
|
||||
/** Reset the dedup cache. Called on session boundaries. */
|
||||
export function resetAskUserQuestionsCache(): void {
|
||||
turnCache.clear();
|
||||
}
|
||||
|
||||
// ─── Race helper ─────────────────────────────────────────────────────────────
|
||||
|
||||
interface RaceableResult {
|
||||
content: { type: "text"; text: string }[];
|
||||
details?: unknown;
|
||||
}
|
||||
|
||||
/** @internal Exported for tests. */
|
||||
export function isUsableRemoteQuestionResult(
|
||||
details: Record<string, unknown> | undefined,
|
||||
): boolean {
|
||||
if (details?.error || details?.cancelled) return false;
|
||||
if (details?.timed_out && details.autoResolved !== true) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Race a remote channel dispatch against the local TUI. The first to produce
|
||||
* a valid (non-error, non-timeout) result wins. The loser is cancelled via
|
||||
* the shared AbortController.
|
||||
*
|
||||
* If the local TUI responds first, the remote poll is aborted (the message
|
||||
* stays in Discord/Slack but polling stops). If remote responds first, the
|
||||
* local TUI prompt is cancelled.
|
||||
*
|
||||
* Returns null only when both sides fail or are cancelled.
|
||||
*/
|
||||
async function raceRemoteAndLocal(
|
||||
startRemote: () => Promise<RaceableResult | null>,
|
||||
startLocal: () => Promise<RoundResult | null | undefined>,
|
||||
controller: AbortController,
|
||||
questions: Question[],
|
||||
): Promise<RaceableResult | null> {
|
||||
// Wrap local TUI result into the same shape as remote results
|
||||
const localPromise = startLocal()
|
||||
.then((result): RaceableResult | null => {
|
||||
if (!result || Object.keys(result.answers).length === 0) return null;
|
||||
return {
|
||||
content: [{ type: "text" as const, text: formatForLLM(result) }],
|
||||
details: {
|
||||
questions,
|
||||
response: result,
|
||||
cancelled: false,
|
||||
} satisfies LocalResultDetails,
|
||||
};
|
||||
})
|
||||
.catch(() => null);
|
||||
|
||||
const remotePromise = startRemote()
|
||||
.then((result): RaceableResult | null => {
|
||||
if (!result) return null;
|
||||
const details = result.details as Record<string, unknown> | undefined;
|
||||
// Plain timeouts/errors are non-wins, but timeout auto-resolution is a
|
||||
// real answer and must win in headless/supervised flows.
|
||||
if (!isUsableRemoteQuestionResult(details)) return null;
|
||||
return result;
|
||||
})
|
||||
.catch(() => null);
|
||||
|
||||
// Race: first non-null result wins
|
||||
const winner = await Promise.race([
|
||||
localPromise.then((r) =>
|
||||
r ? { source: "local" as const, result: r } : null,
|
||||
),
|
||||
remotePromise.then((r) =>
|
||||
r ? { source: "remote" as const, result: r } : null,
|
||||
),
|
||||
]);
|
||||
|
||||
if (winner) {
|
||||
// Cancel the loser
|
||||
controller.abort();
|
||||
return winner.result;
|
||||
}
|
||||
|
||||
// First to resolve was null — wait for the other
|
||||
const [localResult, remoteResult] = await Promise.all([
|
||||
localPromise,
|
||||
remotePromise,
|
||||
]);
|
||||
controller.abort();
|
||||
return localResult ?? remoteResult;
|
||||
}
|
||||
|
||||
// ─── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
const OTHER_OPTION_LABEL = "None of the above";
|
||||
|
||||
async function askLocalQuestionRound(
|
||||
questions: Question[],
|
||||
signal: AbortSignal | undefined,
|
||||
ctx: Pick<ExtensionCommandContext, "ui">,
|
||||
): Promise<RoundResult | null | undefined> {
|
||||
const result = (await showInterviewRound(
|
||||
questions,
|
||||
{ signal },
|
||||
ctx as ExtensionCommandContext,
|
||||
)) as RoundResult | undefined;
|
||||
if (result !== undefined) return result;
|
||||
if (signal?.aborted) return null;
|
||||
|
||||
const answers: Record<
|
||||
string,
|
||||
{ selected: string | string[]; notes: string }
|
||||
> = {};
|
||||
for (const q of questions) {
|
||||
const options = q.options.map((o) => o.label);
|
||||
if (!q.allowMultiple) {
|
||||
options.push(OTHER_OPTION_LABEL);
|
||||
}
|
||||
const selected = await ctx.ui.select(
|
||||
`${q.header}: ${q.question}`,
|
||||
options,
|
||||
{ signal, ...(q.allowMultiple ? { allowMultiple: true } : {}) },
|
||||
);
|
||||
if (selected === undefined) return null;
|
||||
|
||||
let freeTextNote = "";
|
||||
const selectedStr = Array.isArray(selected) ? selected[0] : selected;
|
||||
if (!q.allowMultiple && selectedStr === OTHER_OPTION_LABEL) {
|
||||
const note = await ctx.ui.input(
|
||||
`${q.header}: Please explain in your own words`,
|
||||
"Type your answer here…",
|
||||
{ signal },
|
||||
);
|
||||
if (note) {
|
||||
freeTextNote = note;
|
||||
}
|
||||
}
|
||||
|
||||
answers[q.id] = {
|
||||
selected,
|
||||
notes: freeTextNote,
|
||||
};
|
||||
}
|
||||
|
||||
return { endInterview: false, answers };
|
||||
}
|
||||
|
||||
function errorResult(
|
||||
message: string,
|
||||
questions: Question[] = [],
|
||||
): {
|
||||
content: { type: "text"; text: string }[];
|
||||
details: AskUserQuestionsDetails;
|
||||
} {
|
||||
return {
|
||||
content: [{ type: "text", text: sanitizeError(message) }],
|
||||
details: { questions, response: null, cancelled: true },
|
||||
};
|
||||
}
|
||||
|
||||
function cleanRecommendedLabel(label: string): string {
|
||||
return label.replace(/\s*\(Recommended\)\s*/g, "").trim();
|
||||
}
|
||||
|
||||
function gateLogId(questionId: string): string {
|
||||
if (questionId.includes("depth_verification")) return "depth_verification";
|
||||
return questionId;
|
||||
}
|
||||
|
||||
function logHeadlessLocalAutoResolve(result: RaceableResult): void {
|
||||
const details = result.details as Record<string, unknown> | undefined;
|
||||
if (
|
||||
!details?.localFallback ||
|
||||
!details.response ||
|
||||
!Array.isArray(details.questions)
|
||||
)
|
||||
return;
|
||||
const questions = details.questions as Question[];
|
||||
const response = details.response as RoundResult;
|
||||
const firstQuestion = questions[0];
|
||||
if (!firstQuestion) return;
|
||||
const selected = response.answers[firstQuestion.id]?.selected;
|
||||
const firstAnswer = Array.isArray(selected) ? selected[0] : selected;
|
||||
if (!firstAnswer) return;
|
||||
process.stderr.write(
|
||||
`[gate] auto-resolved ${gateLogId(firstQuestion.id)} → "${cleanRecommendedLabel(firstAnswer)}" (timeout, headless, no telegram)\n`,
|
||||
);
|
||||
}
|
||||
|
||||
/** Convert the shared RoundResult into the JSON the LLM expects. */
|
||||
const formatForLLM = formatRoundResultForTool;
|
||||
|
||||
// ─── Extension ────────────────────────────────────────────────────────────────
|
||||
|
||||
export default function AskUserQuestions(pi: ExtensionAPI) {
|
||||
pi.registerTool({
|
||||
name: "ask_user_questions",
|
||||
label: "Request User Input",
|
||||
description:
|
||||
"Request user input for one to three short questions and wait for the response. Single-select questions have 2-3 mutually exclusive options with a free-form 'None of the above' added automatically. Multi-select questions (allowMultiple: true) let the user toggle multiple options with SPACE and confirm with ENTER.",
|
||||
promptGuidelines: [
|
||||
"Use ask_user_questions when you need the user to choose between concrete alternatives before proceeding.",
|
||||
"Keep questions to 1 when possible; never exceed 3.",
|
||||
"For single-select: each question must have 2-3 options. Put the recommended option first with '(Recommended)' suffix. Do not include an 'Other' or 'None of the above' option - the client adds one automatically.",
|
||||
"For multi-select: set allowMultiple: true. The user can pick any number of options. No 'None of the above' is added.",
|
||||
],
|
||||
parameters: AskUserQuestionsParams,
|
||||
|
||||
async execute(_toolCallId, params, signal, _onUpdate, ctx) {
|
||||
// ── Per-turn dedup: return cached result for identical question sets ──
|
||||
const sig = questionSignature(params.questions);
|
||||
const cached = turnCache.get(sig);
|
||||
if (cached) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text" as const,
|
||||
text:
|
||||
cached.content[0].text +
|
||||
"\n(Returned cached answer — this question set was already asked this turn.)",
|
||||
},
|
||||
],
|
||||
details: cached.details,
|
||||
};
|
||||
}
|
||||
|
||||
// Validation
|
||||
if (params.questions.length === 0 || params.questions.length > 3) {
|
||||
return errorResult(
|
||||
"Error: questions must contain 1-3 items",
|
||||
params.questions,
|
||||
);
|
||||
}
|
||||
|
||||
for (const q of params.questions) {
|
||||
if (!q.options || q.options.length === 0) {
|
||||
return errorResult(
|
||||
`Error: ask_user_questions requires non-empty options for every question (question "${q.id}" has none)`,
|
||||
params.questions,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Routing: race remote + local, remote-only, or local-only ────────
|
||||
const {
|
||||
tryRemoteQuestions,
|
||||
isRemoteConfigured,
|
||||
tryHeadlessLocalAutoResolveQuestions,
|
||||
} = await import("./remote-questions/manager.js");
|
||||
const hasRemote = isRemoteConfigured();
|
||||
|
||||
// Case 1: Both remote and local UI available — race them.
|
||||
// The first response wins; the loser is cancelled via AbortController.
|
||||
if (hasRemote && ctx.hasUI) {
|
||||
const raceController = new AbortController();
|
||||
// Merge the parent signal so external cancellation propagates.
|
||||
const onParentAbort = () => raceController.abort();
|
||||
signal?.addEventListener("abort", onParentAbort, { once: true });
|
||||
const raceSignal = raceController.signal;
|
||||
|
||||
const raceResult = await raceRemoteAndLocal(
|
||||
() => tryRemoteQuestions(params.questions, raceSignal),
|
||||
() => askLocalQuestionRound(params.questions, raceSignal, ctx as any),
|
||||
raceController,
|
||||
params.questions,
|
||||
);
|
||||
|
||||
signal?.removeEventListener("abort", onParentAbort);
|
||||
|
||||
if (raceResult) {
|
||||
const details = raceResult.details as
|
||||
| Record<string, unknown>
|
||||
| undefined;
|
||||
if (details && isUsableRemoteQuestionResult(details)) {
|
||||
turnCache.set(sig, raceResult as unknown as CachedResult);
|
||||
}
|
||||
return { ...raceResult, details: raceResult.details as unknown };
|
||||
}
|
||||
// Both sides failed/cancelled — fall through to error
|
||||
return errorResult(
|
||||
"ask_user_questions: no response received from local UI or remote channel",
|
||||
params.questions,
|
||||
);
|
||||
}
|
||||
|
||||
// Case 2: Remote configured but no local UI (headless) — remote only.
|
||||
if (hasRemote && !ctx.hasUI) {
|
||||
const remoteResult = await tryRemoteQuestions(params.questions, signal);
|
||||
let failedRemoteResult: RaceableResult | null = null;
|
||||
if (remoteResult) {
|
||||
const remoteDetails = remoteResult.details as
|
||||
| Record<string, unknown>
|
||||
| undefined;
|
||||
if (remoteDetails && isUsableRemoteQuestionResult(remoteDetails)) {
|
||||
turnCache.set(sig, remoteResult as unknown as CachedResult);
|
||||
if (remoteDetails.localFallback)
|
||||
logHeadlessLocalAutoResolve(remoteResult);
|
||||
return {
|
||||
...remoteResult,
|
||||
details: remoteResult.details as unknown,
|
||||
};
|
||||
}
|
||||
failedRemoteResult = remoteResult;
|
||||
}
|
||||
const fallbackResult = await tryHeadlessLocalAutoResolveQuestions(
|
||||
params.questions,
|
||||
{
|
||||
hasUI: ctx.hasUI,
|
||||
telegramUnavailable: true,
|
||||
unavailableReason: "telegram-poller-error",
|
||||
signal,
|
||||
},
|
||||
);
|
||||
if (fallbackResult) {
|
||||
turnCache.set(sig, fallbackResult as unknown as CachedResult);
|
||||
logHeadlessLocalAutoResolve(fallbackResult);
|
||||
return {
|
||||
...fallbackResult,
|
||||
details: fallbackResult.details as unknown,
|
||||
};
|
||||
}
|
||||
if (failedRemoteResult)
|
||||
return {
|
||||
...failedRemoteResult,
|
||||
details: failedRemoteResult.details as unknown,
|
||||
};
|
||||
return errorResult(
|
||||
"Error: remote channel configured but returned no result",
|
||||
params.questions,
|
||||
);
|
||||
}
|
||||
|
||||
// Case 3: No remote — local UI only.
|
||||
if (!ctx.hasUI) {
|
||||
const fallbackResult = await tryHeadlessLocalAutoResolveQuestions(
|
||||
params.questions,
|
||||
{
|
||||
hasUI: ctx.hasUI,
|
||||
telegramUnavailable: true,
|
||||
unavailableReason: "no-telegram",
|
||||
signal,
|
||||
},
|
||||
);
|
||||
if (fallbackResult) {
|
||||
turnCache.set(sig, fallbackResult as unknown as CachedResult);
|
||||
logHeadlessLocalAutoResolve(fallbackResult);
|
||||
return {
|
||||
...fallbackResult,
|
||||
details: fallbackResult.details as unknown,
|
||||
};
|
||||
}
|
||||
return errorResult(
|
||||
"Error: UI not available (non-interactive mode)",
|
||||
params.questions,
|
||||
);
|
||||
}
|
||||
|
||||
// Delegate to shared interview UI
|
||||
const result = await askLocalQuestionRound(
|
||||
params.questions,
|
||||
signal,
|
||||
ctx as any,
|
||||
);
|
||||
if (!result) {
|
||||
return errorResult(
|
||||
"ask_user_questions was cancelled",
|
||||
params.questions,
|
||||
);
|
||||
}
|
||||
|
||||
// Check if cancelled (empty answers = user exited)
|
||||
const hasAnswers = Object.keys(result.answers).length > 0;
|
||||
if (!hasAnswers) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: "ask_user_questions was cancelled before receiving a response",
|
||||
},
|
||||
],
|
||||
details: {
|
||||
questions: params.questions,
|
||||
response: null,
|
||||
cancelled: true,
|
||||
} satisfies LocalResultDetails,
|
||||
};
|
||||
}
|
||||
|
||||
const successResult = {
|
||||
content: [{ type: "text" as const, text: formatForLLM(result) }],
|
||||
details: {
|
||||
questions: params.questions,
|
||||
response: result,
|
||||
cancelled: false,
|
||||
} satisfies LocalResultDetails,
|
||||
};
|
||||
turnCache.set(sig, successResult);
|
||||
return successResult;
|
||||
},
|
||||
|
||||
// ─── Rendering ────────────────────────────────────────────────────────
|
||||
|
||||
renderCall(args, theme) {
|
||||
const qs = (args.questions as Question[]) || [];
|
||||
let text = theme.fg("toolTitle", theme.bold("ask_user_questions "));
|
||||
text += theme.fg(
|
||||
"muted",
|
||||
`${qs.length} question${qs.length !== 1 ? "s" : ""}`,
|
||||
);
|
||||
if (qs.length > 0) {
|
||||
const headers = qs.map((q) => q.header).join(", ");
|
||||
text += theme.fg("dim", ` (${headers})`);
|
||||
}
|
||||
for (const q of qs) {
|
||||
const multiSel = !!q.allowMultiple;
|
||||
text += `\n ${theme.fg("text", q.question)}`;
|
||||
const optLabels = multiSel
|
||||
? (q.options || []).map((o: QuestionOption) => o.label)
|
||||
: [
|
||||
...(q.options || []).map((o: QuestionOption) => o.label),
|
||||
OTHER_OPTION_LABEL,
|
||||
];
|
||||
const prefix = multiSel ? "☐" : "";
|
||||
const numbered = optLabels
|
||||
.map((l, i) => `${prefix}${i + 1}. ${l}`)
|
||||
.join(", ");
|
||||
text += `\n ${theme.fg("dim", numbered)}`;
|
||||
}
|
||||
return new Text(text, 0, 0);
|
||||
},
|
||||
|
||||
renderResult(result, _options, theme) {
|
||||
const details = result.details as AskUserQuestionsDetails | undefined;
|
||||
if (!details) {
|
||||
const text = result.content[0];
|
||||
return new Text(text?.type === "text" ? text.text : "", 0, 0);
|
||||
}
|
||||
|
||||
// Remote channel result (discriminated on details.remote === true)
|
||||
if (details.remote) {
|
||||
if (details.timed_out && !details.autoResolved) {
|
||||
return new Text(
|
||||
`${theme.fg("warning", `${details.channel} — timed out`)}${details.threadUrl ? theme.fg("dim", ` ${details.threadUrl}`) : ""}`,
|
||||
0,
|
||||
0,
|
||||
);
|
||||
}
|
||||
|
||||
const questions = (details.questions ?? []) as Question[];
|
||||
const lines: string[] = [];
|
||||
lines.push(
|
||||
theme.fg(
|
||||
"dim",
|
||||
details.autoResolved
|
||||
? `${details.channel} — auto-resolved on timeout`
|
||||
: details.channel,
|
||||
),
|
||||
);
|
||||
if (details.response) {
|
||||
for (const q of questions) {
|
||||
const answer = details.response.answers[q.id];
|
||||
if (!answer) {
|
||||
lines.push(
|
||||
`${theme.fg("accent", q.header)}: ${theme.fg("dim", "(no answer)")}`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
const selected = answer.selected;
|
||||
const answerText = Array.isArray(selected)
|
||||
? selected.join(", ")
|
||||
: selected || "(custom)";
|
||||
let line = `${theme.fg("success", "✓ ")}${theme.fg("accent", q.header)}: ${answerText}`;
|
||||
if (answer.notes) {
|
||||
line += ` ${theme.fg("muted", `[note: ${answer.notes}]`)}`;
|
||||
}
|
||||
lines.push(line);
|
||||
}
|
||||
}
|
||||
return new Text(lines.join("\n"), 0, 0);
|
||||
}
|
||||
|
||||
// After the remote branch, details is LocalResultDetails
|
||||
const local = details as LocalResultDetails;
|
||||
if (local.cancelled || !local.response) {
|
||||
return new Text(theme.fg("warning", "Cancelled"), 0, 0);
|
||||
}
|
||||
|
||||
const lines: string[] = [];
|
||||
for (const q of details.questions) {
|
||||
const answer = (details.response as RoundResult).answers[q.id];
|
||||
if (!answer) {
|
||||
lines.push(
|
||||
`${theme.fg("accent", q.header)}: ${theme.fg("dim", "(no answer)")}`,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
const selected = answer.selected;
|
||||
const notes = answer.notes;
|
||||
const multiSel = !!q.allowMultiple;
|
||||
const answerText =
|
||||
multiSel && Array.isArray(selected)
|
||||
? selected.join(", ")
|
||||
: ((Array.isArray(selected) ? selected[0] : selected) ??
|
||||
"(no answer)");
|
||||
let line = `${theme.fg("success", "✓ ")}${theme.fg("accent", q.header)}: ${answerText}`;
|
||||
if (notes) {
|
||||
line += ` ${theme.fg("muted", `[note: ${notes}]`)}`;
|
||||
}
|
||||
lines.push(line);
|
||||
}
|
||||
return new Text(lines.join("\n"), 0, 0);
|
||||
},
|
||||
});
|
||||
}
|
||||
|
|
@ -1,143 +0,0 @@
|
|||
/**
|
||||
* async-bash-timeout.test.ts — Tests for async_bash timeout behavior.
|
||||
*
|
||||
* Reproduces issue #2186: when an async bash job exceeds its timeout and
|
||||
* the child process ignores SIGTERM, the promise hangs indefinitely.
|
||||
* The fix adds a SIGKILL fallback and a hard deadline that force-resolves
|
||||
* the promise so execution can continue.
|
||||
*/
|
||||
|
||||
import assert from "node:assert/strict";
|
||||
import { test } from 'vitest';
|
||||
import { createAsyncBashTool } from "./async-bash-tool.ts";
|
||||
import { AsyncJobManager } from "./job-manager.ts";
|
||||
|
||||
function getTextFromResult(result: {
|
||||
content: Array<{ type: string; text?: string }>;
|
||||
}): string {
|
||||
return result.content.map((c) => c.text ?? "").join("\n");
|
||||
}
|
||||
|
||||
const noopSignal = new AbortController().signal;
|
||||
|
||||
test("async_bash with timeout resolves even if process ignores SIGTERM", async () => {
|
||||
const manager = new AsyncJobManager();
|
||||
const tool = createAsyncBashTool(
|
||||
() => manager,
|
||||
() => process.cwd(),
|
||||
);
|
||||
|
||||
// Start a job that traps SIGTERM (ignores it), with a 2s timeout.
|
||||
// The process installs a SIGTERM trap and sleeps for 60s.
|
||||
// Before the fix, this would hang forever because SIGTERM is ignored
|
||||
// and the close event never fires.
|
||||
const result = await tool.execute(
|
||||
"tc-timeout",
|
||||
{
|
||||
command: "trap '' TERM; sleep 60",
|
||||
timeout: 2,
|
||||
label: "sigterm-resistant",
|
||||
},
|
||||
noopSignal,
|
||||
() => {},
|
||||
undefined as never,
|
||||
);
|
||||
|
||||
const text = getTextFromResult(result);
|
||||
assert.match(text, /sigterm-resistant/);
|
||||
|
||||
const jobId = text.match(/\*\*(bg_[a-f0-9]+)\*\*/)?.[1];
|
||||
assert.ok(jobId, "Should have returned a job ID");
|
||||
|
||||
// Now await the job — it should resolve within a reasonable time
|
||||
// (timeout 2s + SIGKILL grace 5s + buffer = well under 15s)
|
||||
const start = Date.now();
|
||||
const job = manager.getJob(jobId)!;
|
||||
assert.ok(job, "Job should exist");
|
||||
|
||||
await Promise.race([
|
||||
job.promise,
|
||||
new Promise<never>((_, reject) => {
|
||||
const t = setTimeout(
|
||||
() =>
|
||||
reject(
|
||||
new Error(
|
||||
`Job promise hung for ${Date.now() - start}ms — ` +
|
||||
`this is the bug from issue #2186: timeout hangs indefinitely`,
|
||||
),
|
||||
),
|
||||
15_000,
|
||||
);
|
||||
if (typeof t === "object" && "unref" in t) t.unref();
|
||||
}),
|
||||
]);
|
||||
|
||||
const elapsed = Date.now() - start;
|
||||
// Should have resolved well within 15s (timeout 2s + kill grace ~5s)
|
||||
assert.ok(elapsed < 15_000, `Job took ${elapsed}ms — expected <15s`);
|
||||
|
||||
// Job should have completed (resolved, not rejected) with timeout message
|
||||
assert.ok(
|
||||
job.status === "completed" || job.status === "failed",
|
||||
`Job status should be completed or failed, got: ${job.status}`,
|
||||
);
|
||||
|
||||
if (job.status === "completed") {
|
||||
assert.ok(
|
||||
job.resultText?.includes("timed out") ||
|
||||
job.resultText?.includes("Timed out"),
|
||||
`Result should mention timeout, got: ${job.resultText}`,
|
||||
);
|
||||
}
|
||||
|
||||
manager.shutdown();
|
||||
});
|
||||
|
||||
test("async_bash with timeout resolves normally when process exits on SIGTERM", async () => {
|
||||
const manager = new AsyncJobManager();
|
||||
const tool = createAsyncBashTool(
|
||||
() => manager,
|
||||
() => process.cwd(),
|
||||
);
|
||||
|
||||
// Start a normal sleep that will die on SIGTERM, with a 1s timeout
|
||||
const result = await tool.execute(
|
||||
"tc-normal-timeout",
|
||||
{
|
||||
command: "sleep 60",
|
||||
timeout: 1,
|
||||
label: "normal-timeout",
|
||||
},
|
||||
noopSignal,
|
||||
() => {},
|
||||
undefined as never,
|
||||
);
|
||||
|
||||
const text = getTextFromResult(result);
|
||||
const jobId = text.match(/\*\*(bg_[a-f0-9]+)\*\*/)?.[1];
|
||||
assert.ok(jobId, "Should have returned a job ID");
|
||||
|
||||
const job = manager.getJob(jobId)!;
|
||||
const start = Date.now();
|
||||
|
||||
await Promise.race([
|
||||
job.promise,
|
||||
new Promise<never>((_, reject) => {
|
||||
const t = setTimeout(() => reject(new Error("Job hung")), 10_000);
|
||||
if (typeof t === "object" && "unref" in t) t.unref();
|
||||
}),
|
||||
]);
|
||||
|
||||
const elapsed = Date.now() - start;
|
||||
assert.ok(
|
||||
elapsed < 5_000,
|
||||
`Expected quick resolution after SIGTERM, took ${elapsed}ms`,
|
||||
);
|
||||
assert.equal(job.status, "completed");
|
||||
assert.ok(
|
||||
job.resultText?.includes("timed out"),
|
||||
`Should mention timeout: ${job.resultText}`,
|
||||
);
|
||||
|
||||
manager.shutdown();
|
||||
});
|
||||
|
|
@ -1,301 +0,0 @@
|
|||
/**
|
||||
* async_bash tool — run a bash command in the background.
|
||||
*
|
||||
* Registers the command with the AsyncJobManager and returns a job ID
|
||||
* immediately. The LLM can continue working and check results later
|
||||
* with await_job.
|
||||
*/
|
||||
|
||||
import { spawn, spawnSync } from "node:child_process";
|
||||
import { randomBytes } from "node:crypto";
|
||||
import { createWriteStream } from "node:fs";
|
||||
import { tmpdir } from "node:os";
|
||||
import { join } from "node:path";
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import type { ToolDefinition } from "@singularity-forge/pi-coding-agent";
|
||||
import {
|
||||
DEFAULT_MAX_BYTES,
|
||||
DEFAULT_MAX_LINES,
|
||||
getShellConfig,
|
||||
sanitizeCommand,
|
||||
} from "@singularity-forge/pi-coding-agent";
|
||||
import { rewriteCommandWithRtk } from "../shared/rtk.js";
|
||||
import type { AsyncJobManager } from "./job-manager.js";
|
||||
|
||||
const schema = Type.Object({
|
||||
command: Type.String({
|
||||
description: "Bash command to execute in the background",
|
||||
}),
|
||||
timeout: Type.Optional(
|
||||
Type.Number({ description: "Timeout in seconds (optional)" }),
|
||||
),
|
||||
label: Type.Optional(
|
||||
Type.String({
|
||||
description:
|
||||
"Short label for the job (shown in /jobs). Defaults to a truncated version of the command.",
|
||||
}),
|
||||
),
|
||||
});
|
||||
|
||||
function getTempFilePath(): string {
|
||||
const id = randomBytes(8).toString("hex");
|
||||
return join(tmpdir(), `pi-async-bash-${id}.log`);
|
||||
}
|
||||
|
||||
/**
|
||||
* Kill a process and its children (cross-platform).
|
||||
* Uses process group kill on Unix; taskkill /F /T on Windows.
|
||||
*/
|
||||
function killTree(pid: number): void {
|
||||
if (process.platform === "win32") {
|
||||
try {
|
||||
spawnSync("taskkill", ["/F", "/T", "/PID", String(pid)], {
|
||||
timeout: 5_000,
|
||||
stdio: "ignore",
|
||||
});
|
||||
} catch {
|
||||
try {
|
||||
process.kill(pid, "SIGTERM");
|
||||
} catch {
|
||||
/* already exited */
|
||||
}
|
||||
}
|
||||
} else {
|
||||
try {
|
||||
process.kill(-pid, "SIGTERM");
|
||||
} catch {
|
||||
try {
|
||||
process.kill(pid, "SIGTERM");
|
||||
} catch {
|
||||
/* already exited */
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function createAsyncBashTool(
|
||||
getManager: () => AsyncJobManager,
|
||||
getCwd: () => string,
|
||||
): ToolDefinition<typeof schema> {
|
||||
return {
|
||||
name: "async_bash",
|
||||
label: "Background Bash",
|
||||
description:
|
||||
`Run a bash command in the background. Returns a job ID immediately so you can continue working. ` +
|
||||
`Use await_job to get results or cancel_job to stop. Ideal for long-running builds, tests, or installs. ` +
|
||||
`Output is truncated to the last ${DEFAULT_MAX_LINES} lines or ${DEFAULT_MAX_BYTES / 1024}KB.`,
|
||||
promptSnippet:
|
||||
"Run a bash command in the background, returning a job ID immediately.",
|
||||
promptGuidelines: [
|
||||
"Use async_bash for commands that take more than a few seconds (builds, tests, installs, large git operations).",
|
||||
"After starting async jobs, continue with other work and use await_job when you need the results.",
|
||||
"await_job has a configurable timeout (default 120s) to prevent indefinite blocking — if it times out, jobs keep running and you can check again later.",
|
||||
"For long-running processes (SSH, deploys, training) that may take minutes+, prefer async_bash with periodic await_job polling over a single long await.",
|
||||
"Use cancel_job to stop a running background job.",
|
||||
"Check /jobs to see all running and recent background jobs.",
|
||||
],
|
||||
parameters: schema,
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
const manager = getManager();
|
||||
const cwd = getCwd();
|
||||
const { command, timeout, label } = params;
|
||||
const shortCmd =
|
||||
label ?? (command.length > 60 ? command.slice(0, 57) + "..." : command);
|
||||
|
||||
const jobId = manager.register("bash", shortCmd, (signal) => {
|
||||
return executeBashInBackground(command, cwd, signal, timeout);
|
||||
});
|
||||
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: [
|
||||
`Background job started: **${jobId}**`,
|
||||
`Command: \`${shortCmd}\``,
|
||||
"",
|
||||
"Use `await_job` to get results when ready, or `cancel_job` to stop.",
|
||||
].join("\n"),
|
||||
},
|
||||
],
|
||||
details: undefined,
|
||||
};
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a bash command, collecting output. Returns the text result.
|
||||
*/
|
||||
function executeBashInBackground(
|
||||
command: string,
|
||||
cwd: string,
|
||||
signal: AbortSignal,
|
||||
timeout?: number,
|
||||
): Promise<string> {
|
||||
return new Promise<string>((resolve, reject) => {
|
||||
let settled = false;
|
||||
const safeResolve = (value: string) => {
|
||||
if (!settled) {
|
||||
settled = true;
|
||||
resolve(value);
|
||||
}
|
||||
};
|
||||
const safeReject = (err: unknown) => {
|
||||
if (!settled) {
|
||||
settled = true;
|
||||
reject(err);
|
||||
}
|
||||
};
|
||||
|
||||
const { shell, args } = getShellConfig();
|
||||
const rewrittenCommand = rewriteCommandWithRtk(command);
|
||||
const resolvedCommand = sanitizeCommand(rewrittenCommand);
|
||||
|
||||
// On Windows, detached: true sets CREATE_NEW_PROCESS_GROUP which can
|
||||
// cause EINVAL in VSCode/ConPTY terminal contexts. The bg-shell
|
||||
// extension already guards this (process-manager.ts); align here.
|
||||
// Process-tree cleanup uses taskkill /F /T on Windows regardless.
|
||||
const child = spawn(shell, [...args, resolvedCommand], {
|
||||
cwd,
|
||||
detached: process.platform !== "win32",
|
||||
env: { ...process.env },
|
||||
stdio: ["ignore", "pipe", "pipe"],
|
||||
});
|
||||
|
||||
let timedOut = false;
|
||||
let timeoutHandle: ReturnType<typeof setTimeout> | undefined;
|
||||
let sigkillHandle: ReturnType<typeof setTimeout> | undefined;
|
||||
let hardDeadlineHandle: ReturnType<typeof setTimeout> | undefined;
|
||||
|
||||
/** Grace period (ms) between SIGTERM and SIGKILL. */
|
||||
const SIGKILL_GRACE_MS = 5_000;
|
||||
/** Hard deadline (ms) after SIGKILL to force-resolve the promise. */
|
||||
const HARD_DEADLINE_MS = 3_000;
|
||||
|
||||
if (timeout !== undefined && timeout > 0) {
|
||||
timeoutHandle = setTimeout(() => {
|
||||
timedOut = true;
|
||||
if (child.pid) killTree(child.pid);
|
||||
|
||||
// If the process ignores SIGTERM, escalate to SIGKILL
|
||||
sigkillHandle = setTimeout(() => {
|
||||
if (child.pid) {
|
||||
// killTree already uses taskkill /F /T on Windows
|
||||
killTree(child.pid);
|
||||
}
|
||||
|
||||
// Hard deadline: if even SIGKILL doesn't trigger 'close',
|
||||
// force-resolve so the job doesn't hang forever (#2186).
|
||||
hardDeadlineHandle = setTimeout(() => {
|
||||
const output = Buffer.concat(chunks).toString("utf-8");
|
||||
safeResolve(
|
||||
output
|
||||
? `${output}\n\nCommand timed out after ${timeout} seconds (force-killed)`
|
||||
: `Command timed out after ${timeout} seconds (force-killed)`,
|
||||
);
|
||||
}, HARD_DEADLINE_MS);
|
||||
if (
|
||||
typeof hardDeadlineHandle === "object" &&
|
||||
"unref" in hardDeadlineHandle
|
||||
)
|
||||
hardDeadlineHandle.unref();
|
||||
}, SIGKILL_GRACE_MS);
|
||||
if (typeof sigkillHandle === "object" && "unref" in sigkillHandle)
|
||||
sigkillHandle.unref();
|
||||
}, timeout * 1000);
|
||||
}
|
||||
|
||||
const chunks: Buffer[] = [];
|
||||
let totalBytes = 0;
|
||||
let spillFilePath: string | undefined;
|
||||
let spillStream: ReturnType<typeof createWriteStream> | undefined;
|
||||
const MAX_BUFFER = DEFAULT_MAX_BYTES * 2;
|
||||
|
||||
const onData = (data: Buffer) => {
|
||||
totalBytes += data.length;
|
||||
|
||||
if (totalBytes > DEFAULT_MAX_BYTES && !spillFilePath) {
|
||||
spillFilePath = getTempFilePath();
|
||||
spillStream = createWriteStream(spillFilePath);
|
||||
for (const chunk of chunks) spillStream.write(chunk);
|
||||
}
|
||||
if (spillStream) spillStream.write(data);
|
||||
|
||||
chunks.push(data);
|
||||
let chunksBytes = chunks.reduce((s, c) => s + c.length, 0);
|
||||
while (chunksBytes > MAX_BUFFER && chunks.length > 1) {
|
||||
const removed = chunks.shift()!;
|
||||
chunksBytes -= removed.length;
|
||||
}
|
||||
};
|
||||
|
||||
if (child.stdout) child.stdout.on("data", onData);
|
||||
if (child.stderr) child.stderr.on("data", onData);
|
||||
|
||||
const onAbort = () => {
|
||||
if (child.pid) killTree(child.pid);
|
||||
};
|
||||
|
||||
if (signal.aborted) {
|
||||
onAbort();
|
||||
} else {
|
||||
signal.addEventListener("abort", onAbort, { once: true });
|
||||
}
|
||||
|
||||
child.on("error", (err) => {
|
||||
if (timeoutHandle) clearTimeout(timeoutHandle);
|
||||
if (sigkillHandle) clearTimeout(sigkillHandle);
|
||||
if (hardDeadlineHandle) clearTimeout(hardDeadlineHandle);
|
||||
signal.removeEventListener("abort", onAbort);
|
||||
safeReject(err);
|
||||
});
|
||||
|
||||
child.on("close", (code) => {
|
||||
if (timeoutHandle) clearTimeout(timeoutHandle);
|
||||
if (sigkillHandle) clearTimeout(sigkillHandle);
|
||||
if (hardDeadlineHandle) clearTimeout(hardDeadlineHandle);
|
||||
signal.removeEventListener("abort", onAbort);
|
||||
if (spillStream) spillStream.end();
|
||||
|
||||
if (signal.aborted) {
|
||||
const output = Buffer.concat(chunks).toString("utf-8");
|
||||
safeResolve(
|
||||
output ? `${output}\n\nCommand aborted` : "Command aborted",
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (timedOut) {
|
||||
const output = Buffer.concat(chunks).toString("utf-8");
|
||||
safeResolve(
|
||||
output
|
||||
? `${output}\n\nCommand timed out after ${timeout} seconds`
|
||||
: `Command timed out after ${timeout} seconds`,
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const fullOutput = Buffer.concat(chunks).toString("utf-8");
|
||||
|
||||
const lines = fullOutput.split("\n");
|
||||
let text: string;
|
||||
if (lines.length > DEFAULT_MAX_LINES) {
|
||||
text = lines.slice(-DEFAULT_MAX_LINES).join("\n");
|
||||
if (spillFilePath) {
|
||||
text += `\n\n[Showing last ${DEFAULT_MAX_LINES} of ${lines.length} lines. Full output: ${spillFilePath}]`;
|
||||
} else {
|
||||
text += `\n\n[Showing last ${DEFAULT_MAX_LINES} of ${lines.length} lines]`;
|
||||
}
|
||||
} else {
|
||||
text = fullOutput || "(no output)";
|
||||
}
|
||||
|
||||
if (code !== 0 && code !== null) {
|
||||
text += `\n\nCommand exited with code ${code}`;
|
||||
}
|
||||
|
||||
safeResolve(text);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
|
@ -1,266 +0,0 @@
|
|||
/**
|
||||
* await-tool.test.ts — Tests for await_job timeout behavior.
|
||||
*/
|
||||
|
||||
import assert from "node:assert/strict";
|
||||
import { test } from 'vitest';
|
||||
import { createAwaitTool } from "./await-tool.ts";
|
||||
import { AsyncJobManager } from "./job-manager.ts";
|
||||
|
||||
function getTextFromResult(result: {
|
||||
content: Array<{ type: string; text?: string }>;
|
||||
}): string {
|
||||
return result.content.map((c) => c.text ?? "").join("\n");
|
||||
}
|
||||
|
||||
// Shared never-aborted signal for tool.execute calls in these tests.
const noopSignal = new AbortController().signal;

test("await_job returns immediately when no running jobs exist", async () => {
  const manager = new AsyncJobManager();
  const tool = createAwaitTool(() => manager);

  const result = await tool.execute(
    "tc1",
    {},
    noopSignal,
    () => {},
    undefined as never,
  );
  const text = getTextFromResult(result);
  assert.match(text, /No running background jobs/);
});

test("await_job returns immediately when all watched jobs are already completed", async () => {
  const manager = new AsyncJobManager();
  const tool = createAwaitTool(() => manager);

  // Register a job that completes instantly
  const jobId = manager.register("bash", "fast-job", async () => "done");
  // Wait for the job to settle
  const job = manager.getJob(jobId)!;
  await job.promise;

  const result = await tool.execute(
    "tc2",
    { jobs: [jobId] },
    noopSignal,
    () => {},
    undefined as never,
  );
  const text = getTextFromResult(result);
  assert.match(text, /fast-job/);
  assert.match(text, /completed/);
});

test("await_job returns on timeout when jobs are still running", async () => {
  const manager = new AsyncJobManager();
  const tool = createAwaitTool(() => manager);

  // Register a job that takes a long time
  const jobId = manager.register("bash", "slow-job", async (_signal) => {
    return new Promise<string>((resolve) => {
      // unref so a pending 60s timer doesn't keep the test process alive
      const timer = setTimeout(() => resolve("finally done"), 60_000);
      if (typeof timer === "object" && "unref" in timer) timer.unref();
    });
  });

  const start = Date.now();
  const result = await tool.execute(
    "tc3",
    { jobs: [jobId], timeout: 1 },
    noopSignal,
    () => {},
    undefined as never,
  );
  const elapsed = Date.now() - start;
  const text = getTextFromResult(result);

  // Should have timed out within ~1-2 seconds, not 60
  assert.ok(elapsed < 5_000, `Expected timeout in ~1s but took ${elapsed}ms`);
  assert.match(text, /Timed out/);
  assert.match(text, /Still running/);
  assert.match(text, /slow-job/);

  // Cleanup
  manager.cancel(jobId);
  manager.shutdown();
});

test("await_job completes before timeout when job finishes quickly", async () => {
  const manager = new AsyncJobManager();
  const tool = createAwaitTool(() => manager);

  // Register a job that completes in 100ms
  const jobId = manager.register("bash", "quick-job", async () => {
    return new Promise<string>((resolve) =>
      setTimeout(() => resolve("quick result"), 100),
    );
  });

  const start = Date.now();
  const result = await tool.execute(
    "tc4",
    { jobs: [jobId], timeout: 30 },
    noopSignal,
    () => {},
    undefined as never,
  );
  const elapsed = Date.now() - start;
  const text = getTextFromResult(result);

  // Should complete in ~100ms, well before the 30s timeout
  assert.ok(elapsed < 5_000, `Expected quick completion but took ${elapsed}ms`);
  assert.ok(!text.includes("Timed out"), "Should not have timed out");
  assert.match(text, /quick-job/);
  assert.match(text, /completed/);

  manager.shutdown();
});

test("await_job uses default timeout of 120s when not specified", async () => {
  const manager = new AsyncJobManager();
  const tool = createAwaitTool(() => manager);

  // Register a job that completes immediately
  const jobId = manager.register("bash", "instant-job", async () => "instant");
  const job = manager.getJob(jobId)!;
  await job.promise;

  // Call without timeout param — should work fine for already-done jobs
  const result = await tool.execute(
    "tc5",
    { jobs: [jobId] },
    noopSignal,
    () => {},
    undefined as never,
  );
  const text = getTextFromResult(result);
  assert.match(text, /instant-job/);
  assert.match(text, /completed/);

  manager.shutdown();
});

test("await_job returns not-found message for invalid job IDs", async () => {
  const manager = new AsyncJobManager();
  const tool = createAwaitTool(() => manager);

  const result = await tool.execute(
    "tc6",
    { jobs: ["bg_nonexistent"] },
    noopSignal,
    () => {},
    undefined as never,
  );
  const text = getTextFromResult(result);
  assert.match(text, /No jobs found/);
  assert.match(text, /bg_nonexistent/);

  manager.shutdown();
});

test("await_job suppresses follow-up for jobs that complete while awaiting (#2248)", async () => {
  // Collect follow-up deliveries so we can assert none fired.
  const followUps: string[] = [];
  const manager = new AsyncJobManager({
    onJobComplete: (job) => followUps.push(job.id),
  });
  const tool = createAwaitTool(() => manager);

  // Register a job that completes in 50ms
  const jobId = manager.register("bash", "awaited-job", async () => {
    return new Promise<string>((resolve) =>
      setTimeout(() => resolve("result"), 50),
    );
  });

  // await_job consumes the result — suppressFollowUp() should cancel delivery timer
  await tool.execute(
    "tc7",
    { jobs: [jobId] },
    noopSignal,
    () => {},
    undefined as never,
  );

  // Give the onJobComplete callback a tick to fire (if suppression failed)
  await new Promise((r) => setTimeout(r, 50));

  assert.equal(
    followUps.length,
    0,
    "onJobComplete should not fire for jobs consumed by await_job",
  );

  manager.shutdown();
});

test("await_job suppresses follow-up for already-completed jobs (cross-turn case) (#3787)", async () => {
  // This is the key regression: job completes in a prior LLM turn, then
  // await_job is called in a later turn. The delivery timer must still be
  // cancellable at that point.
  const followUps: string[] = [];
  const manager = new AsyncJobManager({
    onJobComplete: (job) => followUps.push(job.id),
  });
  const tool = createAwaitTool(() => manager);

  // Register and let the job complete fully before calling await_job
  const jobId = manager.register(
    "bash",
    "pre-completed-job",
    async () => "done",
  );
  const job = manager.getJob(jobId)!;
  await job.promise;

  // Simulate a "later turn" by yielding to the event loop — this lets any
  // queueMicrotask callbacks run, but the setTimeout(0) delivery timer has
  // not yet fired (it's scheduled for the next macrotask).
  await new Promise((r) => setImmediate(r));

  // Now call await_job — suppressFollowUp() should cancel the pending timer
  await tool.execute(
    "tc7b",
    { jobs: [jobId] },
    noopSignal,
    () => {},
    undefined as never,
  );

  // Drain the macrotask queue — the (now-cancelled) timer would have fired here
  await new Promise((r) => setTimeout(r, 50));

  assert.equal(
    followUps.length,
    0,
    "onJobComplete should not fire for already-completed jobs consumed by await_job",
  );

  manager.shutdown();
});

test("unawaited jobs still get follow-up delivery (#2248)", async () => {
  const followUps: string[] = [];
  const manager = new AsyncJobManager({
    onJobComplete: (job) => {
      if (!job.awaited) followUps.push(job.id);
    },
  });

  // Register a fire-and-forget job
  const jobId = manager.register("bash", "fire-and-forget", async () => "done");
  const job = manager.getJob(jobId)!;
  await job.promise;

  // Give the callback a tick
  await new Promise((r) => setTimeout(r, 50));

  assert.equal(
    followUps.length,
    1,
    "onJobComplete should deliver follow-up for unawaited jobs",
  );
  assert.equal(followUps[0], jobId);

  manager.shutdown();
});
|
||||
|
|
@ -1,146 +0,0 @@
|
|||
/**
|
||||
* await_job tool — wait for one or more background jobs to complete.
|
||||
*
|
||||
* If specific job IDs are provided, waits for those jobs.
|
||||
* If omitted, waits for any running job to complete.
|
||||
*/
|
||||
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import type { ToolDefinition } from "@singularity-forge/pi-coding-agent";
|
||||
import type { AsyncJobManager, Job } from "./job-manager.js";
|
||||
|
||||
// Default maximum wait before await_job returns control to the model.
const DEFAULT_TIMEOUT_SECONDS = 120;

// Parameter schema for the await_job tool. Both fields are optional:
// omitting `jobs` waits on every currently-running job.
const schema = Type.Object({
  jobs: Type.Optional(
    Type.Array(Type.String(), {
      description: "Job IDs to wait for. Omit to wait for any running job.",
    }),
  ),
  timeout: Type.Optional(
    Type.Number({
      description:
        "Maximum seconds to wait before returning control. Defaults to 120. " +
        "Jobs continue running in the background after timeout.",
    }),
  ),
});
|
||||
|
||||
export function createAwaitTool(
|
||||
getManager: () => AsyncJobManager,
|
||||
): ToolDefinition<typeof schema> {
|
||||
return {
|
||||
name: "await_job",
|
||||
label: "Await Background Job",
|
||||
description:
|
||||
"Wait for background jobs to complete. Provide specific job IDs or omit to wait for the next job that finishes. Returns results of completed jobs.",
|
||||
parameters: schema,
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
const manager = getManager();
|
||||
const { jobs: jobIds, timeout } = params;
|
||||
const timeoutMs = (timeout ?? DEFAULT_TIMEOUT_SECONDS) * 1000;
|
||||
|
||||
let watched: Job[];
|
||||
if (jobIds && jobIds.length > 0) {
|
||||
watched = [];
|
||||
const notFound: string[] = [];
|
||||
for (const id of jobIds) {
|
||||
const job = manager.getJob(id);
|
||||
if (job) {
|
||||
watched.push(job);
|
||||
} else {
|
||||
notFound.push(id);
|
||||
}
|
||||
}
|
||||
if (notFound.length > 0 && watched.length === 0) {
|
||||
return {
|
||||
content: [
|
||||
{ type: "text", text: `No jobs found: ${notFound.join(", ")}` },
|
||||
],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
} else {
|
||||
watched = manager.getRunningJobs();
|
||||
if (watched.length === 0) {
|
||||
return {
|
||||
content: [{ type: "text", text: "No running background jobs." }],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Suppress follow-up notifications for all watched jobs upfront.
|
||||
// suppressFollowUp() cancels the pending delivery timer (if any), which
|
||||
// handles both the within-turn case (job completes while we await) and
|
||||
// the cross-turn case (job already completed before await_job was called).
|
||||
// Previously this only set j.awaited = true, which missed the cross-turn
|
||||
// case because the queueMicrotask had already fired (#3787).
|
||||
for (const j of watched) manager.suppressFollowUp(j.id);
|
||||
|
||||
// If all watched jobs are already done, return immediately
|
||||
const running = watched.filter((j) => j.status === "running");
|
||||
if (running.length === 0) {
|
||||
const result = formatResults(watched);
|
||||
return {
|
||||
content: [{ type: "text", text: result }],
|
||||
details: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
// Wait for at least one to complete, or timeout
|
||||
const TIMEOUT_SENTINEL = Symbol("timeout");
|
||||
const timeoutPromise = new Promise<typeof TIMEOUT_SENTINEL>((resolve) => {
|
||||
const timer = setTimeout(() => resolve(TIMEOUT_SENTINEL), timeoutMs);
|
||||
// Allow the process to exit even if the timer is pending
|
||||
if (typeof timer === "object" && "unref" in timer) timer.unref();
|
||||
});
|
||||
|
||||
const raceResult = await Promise.race([
|
||||
Promise.race(running.map((j) => j.promise)).then(
|
||||
() => "completed" as const,
|
||||
),
|
||||
timeoutPromise,
|
||||
]);
|
||||
|
||||
const timedOut = raceResult === TIMEOUT_SENTINEL;
|
||||
|
||||
// Collect all completed results (more may have finished while waiting)
|
||||
const completed = watched.filter((j) => j.status !== "running");
|
||||
|
||||
const stillRunning = watched.filter((j) => j.status === "running");
|
||||
let result = formatResults(completed);
|
||||
if (stillRunning.length > 0) {
|
||||
result += `\n\n**Still running:** ${stillRunning.map((j) => `${j.id} (${j.label})`).join(", ")}`;
|
||||
}
|
||||
if (timedOut) {
|
||||
result +=
|
||||
`\n\n⏱ **Timed out** after ${timeout ?? DEFAULT_TIMEOUT_SECONDS}s waiting for jobs to finish. ` +
|
||||
`Jobs are still running in the background. ` +
|
||||
`Use \`await_job\` again later or \`async_bash\` + \`await_job\` for shorter polling intervals.`;
|
||||
}
|
||||
|
||||
return { content: [{ type: "text", text: result }], details: undefined };
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
function formatResults(jobs: Job[]): string {
|
||||
if (jobs.length === 0) return "No completed jobs.";
|
||||
|
||||
const parts: string[] = [];
|
||||
for (const job of jobs) {
|
||||
const elapsed = ((Date.now() - job.startTime) / 1000).toFixed(1);
|
||||
const header = `### ${job.id} — ${job.label} (${job.status}, ${elapsed}s)`;
|
||||
|
||||
if (job.status === "completed") {
|
||||
parts.push(`${header}\n\n${job.resultText ?? "(no output)"}`);
|
||||
} else if (job.status === "failed") {
|
||||
parts.push(`${header}\n\nError: ${job.errorText ?? "unknown error"}`);
|
||||
} else if (job.status === "cancelled") {
|
||||
parts.push(`${header}\n\nCancelled.`);
|
||||
}
|
||||
}
|
||||
|
||||
return parts.join("\n\n---\n\n");
|
||||
}
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
/**
|
||||
* cancel_job tool — cancel a running background job.
|
||||
*/
|
||||
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import type { ToolDefinition } from "@singularity-forge/pi-coding-agent";
|
||||
import type { AsyncJobManager } from "./job-manager.js";
|
||||
|
||||
// Parameter schema for the cancel_job tool: a single required job ID.
const schema = Type.Object({
  job_id: Type.String({
    description: "The background job ID to cancel (e.g. bg_a1b2c3d4)",
  }),
});
|
||||
|
||||
export function createCancelJobTool(
|
||||
getManager: () => AsyncJobManager,
|
||||
): ToolDefinition<typeof schema> {
|
||||
return {
|
||||
name: "cancel_job",
|
||||
label: "Cancel Background Job",
|
||||
description: "Cancel a running background job by its ID.",
|
||||
parameters: schema,
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
const manager = getManager();
|
||||
const result = manager.cancel(params.job_id);
|
||||
|
||||
const messages: Record<string, string> = {
|
||||
cancelled: `Job ${params.job_id} has been cancelled.`,
|
||||
not_found: `Job ${params.job_id} not found.`,
|
||||
already_completed: `Job ${params.job_id} has already completed (or failed/cancelled).`,
|
||||
};
|
||||
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: messages[result] ?? `Unknown result: ${result}`,
|
||||
},
|
||||
],
|
||||
details: undefined,
|
||||
};
|
||||
},
|
||||
};
|
||||
}
|
||||
|
|
@ -1,163 +0,0 @@
|
|||
/**
|
||||
* Async Jobs Extension
|
||||
*
|
||||
* Allows bash commands to run in the background. The agent gets a job ID
|
||||
* immediately and can continue working. Results are delivered via follow-up
|
||||
* messages when jobs complete.
|
||||
*
|
||||
* Tools:
|
||||
* async_bash — run a command in the background, get a job ID
|
||||
* await_job — wait for background jobs to complete, get results
|
||||
* cancel_job — cancel a running background job
|
||||
*
|
||||
* Commands:
|
||||
* /jobs — show running and recent background jobs
|
||||
*/
|
||||
|
||||
import type {
|
||||
ExtensionAPI,
|
||||
ExtensionCommandContext,
|
||||
} from "@singularity-forge/pi-coding-agent";
|
||||
import { createAsyncBashTool } from "./async-bash-tool.js";
|
||||
import { createAwaitTool } from "./await-tool.js";
|
||||
import { createCancelJobTool } from "./cancel-job-tool.js";
|
||||
import { AsyncJobManager } from "./job-manager.js";
|
||||
|
||||
/**
 * Extension entry point: wires the async-job tools, the /jobs command, and
 * the session lifecycle hooks into the extension API.
 */
export default function AsyncJobs(pi: ExtensionAPI) {
  // Created on session_start, torn down on session_shutdown.
  let manager: AsyncJobManager | null = null;
  // Tracks the session cwd so async_bash spawns in the right directory.
  let latestCwd: string = process.cwd();

  // Tools capture this accessor rather than the manager itself, because
  // the manager does not exist until session_start fires.
  function getManager(): AsyncJobManager {
    if (!manager) {
      throw new Error(
        "AsyncJobManager not initialized. Wait for session_start.",
      );
    }
    return manager;
  }

  function getCwd(): string {
    return latestCwd;
  }

  // ── Session lifecycle ──────────────────────────────────────────────────

  pi.on("session_start", async (_event, ctx) => {
    latestCwd = ctx.cwd;

    manager = new AsyncJobManager({
      onJobComplete: (job) => {
        // Jobs consumed via await_job are excluded from follow-up delivery.
        if (job.awaited) return;
        const statusEmoji = job.status === "completed" ? "done" : "error";
        const elapsed = ((Date.now() - job.startTime) / 1000).toFixed(1);
        const output =
          job.status === "completed"
            ? (job.resultText ?? "(no output)")
            : `Error: ${job.errorText ?? "unknown error"}`;

        // Truncate output for the follow-up message
        const maxLen = 2000;
        const truncatedOutput =
          output.length > maxLen
            ? output.slice(0, maxLen) +
              "\n\n[... truncated, use await_job for full output]"
            : output;

        // Deliver as follow-up without triggering a new LLM turn (#875).
        // When the agent is streaming: the message is queued and picked up
        // by the agent loop's getFollowUpMessages() after the current turn.
        // When the agent is idle: the message is appended to context so it's
        // visible on the next user-initiated prompt. Previously triggerTurn:true
        // caused spurious autonomous turns — the model would interpret completed
        // job output as requiring action and cascade into unbounded self-reinforcing
        // loops (running more commands, spawning more jobs, burning context).
        pi.sendMessage(
          {
            customType: "async_job_result",
            content: [
              `**Background job ${statusEmoji}: ${job.id}** (${job.label}, ${elapsed}s)`,
              "",
              truncatedOutput,
            ].join("\n"),
            display: true,
          },
          { deliverAs: "followUp" },
        );
      },
    });
  });

  pi.on("session_before_switch", async () => {
    if (manager) {
      // Cancel all running background jobs — their results are no longer
      // relevant to the new session and would produce wasteful follow-up
      // notifications that trigger empty LLM turns (#1642).
      for (const job of manager.getRunningJobs()) {
        manager.cancel(job.id);
      }
    }
  });

  pi.on("session_shutdown", async () => {
    if (manager) {
      manager.shutdown();
      manager = null;
    }
  });

  // ── Tools ──────────────────────────────────────────────────────────────

  pi.registerTool(createAsyncBashTool(getManager, getCwd));
  pi.registerTool(createAwaitTool(getManager));
  pi.registerTool(createCancelJobTool(getManager));

  // ── /jobs command ──────────────────────────────────────────────────────

  pi.registerCommand("jobs", {
    description: "Show running and recent background jobs",
    handler: async (_args: string, _ctx: ExtensionCommandContext) => {
      if (!manager) {
        pi.sendMessage({
          customType: "async_jobs_list",
          content: "No async job manager active.",
          display: true,
        });
        return;
      }

      const running = manager.getRunningJobs();
      const recent = manager.getRecentJobs(10);
      const completed = recent.filter((j) => j.status !== "running");

      // Build a markdown listing: "Running" section, then "Recent".
      const lines: string[] = ["## Background Jobs"];

      if (running.length === 0 && completed.length === 0) {
        lines.push("", "No background jobs.");
      } else {
        if (running.length > 0) {
          lines.push("", "### Running");
          for (const job of running) {
            const elapsed = ((Date.now() - job.startTime) / 1000).toFixed(0);
            lines.push(`- **${job.id}** — ${job.label} (${elapsed}s)`);
          }
        }

        if (completed.length > 0) {
          lines.push("", "### Recent");
          for (const job of completed) {
            const elapsed = ((Date.now() - job.startTime) / 1000).toFixed(1);
            lines.push(
              `- **${job.id}** — ${job.label} (${job.status}, ${elapsed}s)`,
            );
          }
        }
      }

      pi.sendMessage({
        customType: "async_jobs_list",
        content: lines.join("\n"),
        display: true,
      });
    },
  });
}
|
||||
|
|
@ -1,239 +0,0 @@
|
|||
/**
|
||||
* AsyncJobManager — manages background tool call jobs.
|
||||
*
|
||||
* Each job runs asynchronously and delivers its result via a callback
|
||||
* when complete. Jobs are evicted after a configurable TTL.
|
||||
*/
|
||||
|
||||
import { randomUUID } from "node:crypto";
|
||||
|
||||
// ── Types ──────────────────────────────────────────────────────────────────
|
||||
|
||||
/** Lifecycle state of a background job. */
export type JobStatus = "running" | "completed" | "failed" | "cancelled";
/** Kind of background job; currently only bash commands. */
export type JobType = "bash";

export interface Job {
  /** Unique ID, prefixed with `bg_`. */
  id: string;
  type: JobType;
  status: JobStatus;
  /** Start timestamp in epoch ms (callers compute elapsed via Date.now()). */
  startTime: number;
  /** Human-readable label shown in /jobs and result messages. */
  label: string;
  /** Controller used by cancel_job to abort the underlying work. */
  abortController: AbortController;
  /** Settles when the job's runner finishes (result stored on the Job). */
  promise: Promise<void>;
  /** Output text when status is "completed". */
  resultText?: string;
  /** Error text when status is "failed". */
  errorText?: string;
  /** Set by await_job when results are consumed. Suppresses follow-up delivery. */
  awaited?: boolean;
  /**
   * Handle for the pending follow-up delivery timer (set by deliverResult).
   * Stored so suppressFollowUp() can cancel it before the notification fires,
   * even when await_job is called after the job has already completed (#3787).
   */
  deliveryTimer?: ReturnType<typeof setTimeout>;
}

export interface JobManagerOptions {
  maxRunning?: number; // default 15
  maxTotal?: number; // default 100
  evictionMs?: number; // default 5 minutes
  /** Invoked when a job settles; used to deliver follow-up messages. */
  onJobComplete?: (job: Job) => void;
}
|
||||
|
||||
// ── Manager ────────────────────────────────────────────────────────────────
|
||||
|
||||
export class AsyncJobManager {
|
||||
private jobs = new Map<string, Job>();
|
||||
private evictionTimers = new Map<string, ReturnType<typeof setTimeout>>();
|
||||
|
||||
private maxRunning: number;
|
||||
private maxTotal: number;
|
||||
private evictionMs: number;
|
||||
private onJobComplete?: (job: Job) => void;
|
||||
|
||||
constructor(options: JobManagerOptions = {}) {
|
||||
this.maxRunning = options.maxRunning ?? 15;
|
||||
this.maxTotal = options.maxTotal ?? 100;
|
||||
this.evictionMs = options.evictionMs ?? 5 * 60 * 1000;
|
||||
this.onJobComplete = options.onJobComplete;
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a new background job.
|
||||
* @returns job ID (prefixed with `bg_`)
|
||||
*/
|
||||
register(
|
||||
type: JobType,
|
||||
label: string,
|
||||
runFn: (signal: AbortSignal) => Promise<string>,
|
||||
): string {
|
||||
// Enforce limits
|
||||
const running = this.getRunningJobs();
|
||||
if (running.length >= this.maxRunning) {
|
||||
throw new Error(
|
||||
`Maximum concurrent background jobs reached (${this.maxRunning}). ` +
|
||||
`Use await_job or cancel_job to free a slot.`,
|
||||
);
|
||||
}
|
||||
if (this.jobs.size >= this.maxTotal) {
|
||||
// Evict oldest completed job
|
||||
this.evictOldest();
|
||||
if (this.jobs.size >= this.maxTotal) {
|
||||
throw new Error(
|
||||
`Maximum total background jobs reached (${this.maxTotal}). ` +
|
||||
`Use cancel_job to remove jobs.`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
const id = `bg_${randomUUID().slice(0, 8)}`;
|
||||
const abortController = new AbortController();
|
||||
|
||||
// Declare job first so the promise callbacks can close over it safely.
|
||||
const job: Job = {
|
||||
id,
|
||||
type,
|
||||
status: "running",
|
||||
startTime: Date.now(),
|
||||
label,
|
||||
abortController,
|
||||
// promise assigned below
|
||||
promise: undefined as unknown as Promise<void>,
|
||||
};
|
||||
|
||||
job.promise = runFn(abortController.signal)
|
||||
.then((resultText) => {
|
||||
job.status = "completed";
|
||||
job.resultText = resultText;
|
||||
this.scheduleEviction(id);
|
||||
this.deliverResult(job);
|
||||
})
|
||||
.catch((err) => {
|
||||
if (job.status === "cancelled") {
|
||||
// Already cancelled — don't overwrite
|
||||
this.scheduleEviction(id);
|
||||
return;
|
||||
}
|
||||
job.status = "failed";
|
||||
job.errorText = err instanceof Error ? err.message : String(err);
|
||||
this.scheduleEviction(id);
|
||||
this.deliverResult(job);
|
||||
});
|
||||
|
||||
this.jobs.set(id, job);
|
||||
return id;
|
||||
}
|
||||
|
||||
/**
|
||||
* Cancel a running job.
|
||||
*/
|
||||
cancel(id: string): "cancelled" | "not_found" | "already_completed" {
|
||||
const job = this.jobs.get(id);
|
||||
if (!job) return "not_found";
|
||||
if (job.status !== "running") return "already_completed";
|
||||
|
||||
job.status = "cancelled";
|
||||
job.errorText = "Cancelled by user";
|
||||
job.abortController.abort();
|
||||
this.scheduleEviction(id);
|
||||
return "cancelled";
|
||||
}
|
||||
|
||||
getJob(id: string): Job | undefined {
|
||||
return this.jobs.get(id);
|
||||
}
|
||||
|
||||
getRunningJobs(): Job[] {
|
||||
return [...this.jobs.values()].filter((j) => j.status === "running");
|
||||
}
|
||||
|
||||
getRecentJobs(limit = 10): Job[] {
|
||||
return [...this.jobs.values()]
|
||||
.sort((a, b) => b.startTime - a.startTime)
|
||||
.slice(0, limit);
|
||||
}
|
||||
|
||||
getAllJobs(): Job[] {
|
||||
return [...this.jobs.values()];
|
||||
}
|
||||
|
||||
/**
|
||||
* Cleanup all timers and resources.
|
||||
*/
|
||||
shutdown(): void {
|
||||
for (const timer of this.evictionTimers.values()) {
|
||||
clearTimeout(timer);
|
||||
}
|
||||
this.evictionTimers.clear();
|
||||
|
||||
// Abort all running jobs
|
||||
for (const job of this.jobs.values()) {
|
||||
if (job.status === "running") {
|
||||
job.status = "cancelled";
|
||||
job.abortController.abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Private ────────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Suppress follow-up notification for a job — cancels any pending delivery
|
||||
* timer and marks the job as awaited. Safe to call at any time, including
|
||||
* before or after the job completes (#3787).
|
||||
*/
|
||||
suppressFollowUp(id: string): void {
|
||||
const job = this.jobs.get(id);
|
||||
if (!job) return;
|
||||
job.awaited = true;
|
||||
if (job.deliveryTimer !== undefined) {
|
||||
clearTimeout(job.deliveryTimer);
|
||||
job.deliveryTimer = undefined;
|
||||
}
|
||||
}
|
||||
|
||||
private deliverResult(job: Job): void {
|
||||
if (!this.onJobComplete) return;
|
||||
// Use setTimeout(0) instead of queueMicrotask so the handle is cancellable.
|
||||
// suppressFollowUp() can clear this timer even when await_job is called in
|
||||
// a later LLM turn (after the job already completed). queueMicrotask ran
|
||||
// immediately and could not be cancelled (#2762, #3787).
|
||||
const cb = this.onJobComplete;
|
||||
job.deliveryTimer = setTimeout(() => {
|
||||
job.deliveryTimer = undefined;
|
||||
if (!job.awaited) cb(job);
|
||||
}, 0);
|
||||
// Allow process to exit even if timer is pending
|
||||
if (typeof job.deliveryTimer === "object" && "unref" in job.deliveryTimer) {
|
||||
(job.deliveryTimer as NodeJS.Timeout).unref();
|
||||
}
|
||||
}
|
||||
|
||||
private scheduleEviction(id: string): void {
|
||||
const existing = this.evictionTimers.get(id);
|
||||
if (existing) clearTimeout(existing);
|
||||
|
||||
const timer = setTimeout(() => {
|
||||
this.evictionTimers.delete(id);
|
||||
this.jobs.delete(id);
|
||||
}, this.evictionMs);
|
||||
|
||||
this.evictionTimers.set(id, timer);
|
||||
}
|
||||
|
||||
private evictOldest(): void {
|
||||
let oldest: Job | undefined;
|
||||
for (const job of this.jobs.values()) {
|
||||
if (job.status !== "running") {
|
||||
if (!oldest || job.startTime < oldest.startTime) {
|
||||
oldest = job;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (oldest) {
|
||||
const timer = this.evictionTimers.get(oldest.id);
|
||||
if (timer) clearTimeout(timer);
|
||||
this.evictionTimers.delete(oldest.id);
|
||||
this.jobs.delete(oldest.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,155 +0,0 @@
|
|||
/**
|
||||
* AWS Auth Refresh Extension
|
||||
*
|
||||
* Automatically refreshes AWS credentials when Bedrock API requests fail
|
||||
* with authentication/token errors, then retries the user's message.
|
||||
*
|
||||
* ## How it works
|
||||
*
|
||||
* Hooks into `agent_end` to check if the last assistant message failed with
|
||||
* an AWS auth error (expired SSO token, missing credentials, etc.). If so:
|
||||
*
|
||||
* 1. Runs the configured `awsAuthRefresh` command (e.g. `aws sso login`)
|
||||
* 2. Streams the SSO auth URL and verification code to the TUI so users
|
||||
* can copy/paste if the browser doesn't auto-open
|
||||
* 3. Calls `retryLastTurn()` which removes the failed assistant response
|
||||
* and re-runs the agent from the user's original message
|
||||
*
|
||||
* ## Activation
|
||||
*
|
||||
* This extension is completely inert unless BOTH conditions are met:
|
||||
* 1. A Bedrock API request fails with a recognized AWS auth error
|
||||
* 2. `awsAuthRefresh` is configured in settings.json
|
||||
*
|
||||
* Non-Bedrock users and Bedrock users without `awsAuthRefresh` configured
|
||||
* are not affected in any way.
|
||||
*
|
||||
* ## Setup
|
||||
*
|
||||
* Add to ~/.sf/agent/settings.json (or project-level .sf/settings.json):
|
||||
*
|
||||
* { "awsAuthRefresh": "aws sso login --profile my-profile" }
|
||||
*
|
||||
* ## Matched error patterns
|
||||
*
|
||||
* The extension recognizes errors from the AWS SDK, Bedrock, and SSO
|
||||
* credential providers including:
|
||||
* - ExpiredTokenException / ExpiredToken
|
||||
* - The security token included in the request is expired
|
||||
* - The SSO session associated with this profile has expired or is invalid
|
||||
* - Unable to locate credentials / Could not load credentials
|
||||
* - UnrecognizedClientException
|
||||
* - Error loading SSO Token / Token does not exist
|
||||
* - SSOTokenProviderFailure
|
||||
*/
|
||||
|
||||
import { exec } from "node:child_process";
|
||||
import { existsSync, readFileSync } from "node:fs";
|
||||
import { homedir } from "node:os";
|
||||
import { join } from "node:path";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
|
||||
/** Matches AWS SDK / Bedrock / SSO credential and token errors. */
|
||||
const AWS_AUTH_ERROR_RE =
|
||||
/ExpiredToken|security token.*expired|unable to locate credentials|SSO.*(?:session|token).*(?:expired|not found|invalid)|UnrecognizedClient|Could not load credentials|Invalid identity token|token is expired|credentials.*(?:could not|cannot|failed to).*(?:load|resolve|find)|The.*token.*is.*not.*valid|token has expired|SSOTokenProviderFailure|Error loading SSO Token|Token.*does not exist/i;
|
||||
|
||||
/**
|
||||
* Reads the `awsAuthRefresh` command from settings.json.
|
||||
* Checks project-level first, then global (~/.sf/agent/settings.json).
|
||||
*/
|
||||
function getAwsAuthRefreshCommand(): string | undefined {
|
||||
const configDir = process.env.PI_CONFIG_DIR || ".sf";
|
||||
const paths = [
|
||||
join(process.cwd(), configDir, "settings.json"),
|
||||
join(homedir(), configDir, "agent", "settings.json"),
|
||||
];
|
||||
for (const settingsPath of paths) {
|
||||
if (!existsSync(settingsPath)) continue;
|
||||
try {
|
||||
const settings = JSON.parse(readFileSync(settingsPath, "utf-8"));
|
||||
if (settings.awsAuthRefresh) return settings.awsAuthRefresh;
|
||||
} catch {} // file missing or corrupt → skip, try next location
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs the refresh command with a 2-minute timeout (for SSO browser flows).
|
||||
* Streams stdout/stderr to capture and display the SSO auth URL and
|
||||
* verification code in real-time via TUI notifications.
|
||||
*/
|
||||
async function runRefresh(
|
||||
command: string,
|
||||
notify: (msg: string, level: "info" | "warning" | "error") => void,
|
||||
): Promise<boolean> {
|
||||
notify("Refreshing AWS credentials...", "info");
|
||||
try {
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
const child = exec(command, {
|
||||
timeout: 120_000,
|
||||
env: { ...process.env },
|
||||
});
|
||||
const onData = (data: Buffer | string) => {
|
||||
const text = data.toString();
|
||||
const urlMatch = text.match(/https?:\/\/\S+/);
|
||||
if (urlMatch) {
|
||||
notify(
|
||||
`Open this URL if the browser didn't launch: ${urlMatch[0]}`,
|
||||
"warning",
|
||||
);
|
||||
}
|
||||
const codeMatch = text.match(/code[:\s]+([A-Z]{4}-[A-Z]{4})/i);
|
||||
if (codeMatch) {
|
||||
notify(`Verification code: ${codeMatch[1]}`, "info");
|
||||
}
|
||||
};
|
||||
child.stdout?.on("data", onData);
|
||||
child.stderr?.on("data", onData);
|
||||
child.on("close", (code) => {
|
||||
if (code === 0) resolve();
|
||||
else reject(new Error(`Refresh command exited with code ${code}`));
|
||||
});
|
||||
child.on("error", reject);
|
||||
});
|
||||
notify("AWS credentials refreshed successfully ✓", "info");
|
||||
return true;
|
||||
} catch (error) {
|
||||
const msg = error instanceof Error ? error.message : String(error);
|
||||
const isTimeout = /timed out|ETIMEDOUT|killed/i.test(msg);
|
||||
if (isTimeout) {
|
||||
notify(
|
||||
"AWS credential refresh timed out. The SSO login may have been cancelled or the browser window was closed.",
|
||||
"error",
|
||||
);
|
||||
} else {
|
||||
notify(`AWS credential refresh failed: ${msg}`, "error");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export default function (pi: ExtensionAPI) {
|
||||
pi.on("agent_end", async (event, ctx) => {
|
||||
const refreshCommand = getAwsAuthRefreshCommand();
|
||||
if (!refreshCommand) return;
|
||||
|
||||
const messages = event.messages;
|
||||
const lastAssistant = messages[messages.length - 1];
|
||||
if (
|
||||
!lastAssistant ||
|
||||
lastAssistant.role !== "assistant" ||
|
||||
!("errorMessage" in lastAssistant) ||
|
||||
!lastAssistant.errorMessage ||
|
||||
!AWS_AUTH_ERROR_RE.test(lastAssistant.errorMessage)
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
const refreshed = await runRefresh(refreshCommand, (m, level) =>
|
||||
ctx.ui.notify(m, level),
|
||||
);
|
||||
if (!refreshed) return;
|
||||
|
||||
pi.retryLastTurn();
|
||||
});
|
||||
}
|
||||
|
|
@ -1,241 +0,0 @@
|
|||
/**
|
||||
* /bg slash command registration — interactive process manager overlay and CLI subcommands.
|
||||
*/
|
||||
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import { Key } from "@singularity-forge/pi-tui";
|
||||
import { shortcutDesc } from "../shared/terminal.js";
|
||||
import type { BgShellSharedState } from "./index.js";
|
||||
import {
|
||||
formatDigestText,
|
||||
generateDigest,
|
||||
getOutput,
|
||||
} from "./output-formatter.js";
|
||||
import { BgManagerOverlay } from "./overlay.js";
|
||||
import {
|
||||
cleanupAll,
|
||||
getGroupStatus,
|
||||
killProcess,
|
||||
processes,
|
||||
} from "./process-manager.js";
|
||||
import { formatUptime } from "./utilities.js";
|
||||
|
||||
export function registerBgShellCommand(
|
||||
pi: ExtensionAPI,
|
||||
state: BgShellSharedState,
|
||||
): void {
|
||||
pi.registerCommand("bg", {
|
||||
description:
|
||||
"Manage background processes: /bg [list|output|kill|killall|groups] [id]",
|
||||
|
||||
getArgumentCompletions: (prefix: string) => {
|
||||
const subcommands = [
|
||||
"list",
|
||||
"output",
|
||||
"kill",
|
||||
"killall",
|
||||
"groups",
|
||||
"digest",
|
||||
];
|
||||
const parts = prefix.trim().split(/\s+/);
|
||||
|
||||
if (parts.length <= 1) {
|
||||
return subcommands
|
||||
.filter((cmd) => cmd.startsWith(parts[0] ?? ""))
|
||||
.map((cmd) => ({ value: cmd, label: cmd }));
|
||||
}
|
||||
|
||||
if (
|
||||
parts[0] === "output" ||
|
||||
parts[0] === "kill" ||
|
||||
parts[0] === "digest"
|
||||
) {
|
||||
const idPrefix = parts[1] ?? "";
|
||||
return Array.from(processes.values())
|
||||
.filter((p) => p.id.startsWith(idPrefix))
|
||||
.map((p) => ({
|
||||
value: `${parts[0]} ${p.id}`,
|
||||
label: `${p.id} — ${p.label}`,
|
||||
}));
|
||||
}
|
||||
|
||||
return [];
|
||||
},
|
||||
|
||||
handler: async (args, ctx) => {
|
||||
const parts = args.trim().split(/\s+/);
|
||||
const sub = parts[0] || "list";
|
||||
|
||||
if (sub === "list" || sub === "") {
|
||||
if (processes.size === 0) {
|
||||
ctx.ui.notify("No background processes.", "info");
|
||||
return;
|
||||
}
|
||||
|
||||
if (!ctx.hasUI) {
|
||||
const lines = Array.from(processes.values()).map((p) => {
|
||||
const statusIcon = p.alive
|
||||
? p.status === "ready"
|
||||
? "✓"
|
||||
: p.status === "error"
|
||||
? "✗"
|
||||
: "⋯"
|
||||
: "○";
|
||||
const uptime = formatUptime(Date.now() - p.startedAt);
|
||||
const portInfo = p.ports.length > 0 ? ` :${p.ports.join(",")}` : "";
|
||||
return `${p.id} ${statusIcon} ${p.status} ${uptime} ${p.label} [${p.processType}]${portInfo}`;
|
||||
});
|
||||
ctx.ui.notify(lines.join("\n"), "info");
|
||||
return;
|
||||
}
|
||||
|
||||
await ctx.ui.custom<void>(
|
||||
(tui, theme, _kb, done) => {
|
||||
return new BgManagerOverlay(tui, theme, () => {
|
||||
done();
|
||||
state.refreshWidget();
|
||||
});
|
||||
},
|
||||
{
|
||||
overlay: true,
|
||||
overlayOptions: {
|
||||
width: "60%",
|
||||
minWidth: 50,
|
||||
maxHeight: "70%",
|
||||
anchor: "center",
|
||||
},
|
||||
},
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (sub === "output" || sub === "digest") {
|
||||
const id = parts[1];
|
||||
if (!id) {
|
||||
ctx.ui.notify(`Usage: /bg ${sub} <id>`, "error");
|
||||
return;
|
||||
}
|
||||
const bg = processes.get(id);
|
||||
if (!bg) {
|
||||
ctx.ui.notify(`No process with id '${id}'`, "error");
|
||||
return;
|
||||
}
|
||||
|
||||
if (!ctx.hasUI) {
|
||||
if (sub === "digest") {
|
||||
const digest = generateDigest(bg);
|
||||
ctx.ui.notify(formatDigestText(bg, digest), "info");
|
||||
} else {
|
||||
const output = getOutput(bg, { stream: "both", tail: 50 });
|
||||
ctx.ui.notify(output || "(no output)", "info");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
await ctx.ui.custom<void>(
|
||||
(tui, theme, _kb, done) => {
|
||||
const overlay = new BgManagerOverlay(tui, theme, () => {
|
||||
done();
|
||||
state.refreshWidget();
|
||||
});
|
||||
const procs = Array.from(processes.values());
|
||||
const idx = procs.findIndex((p) => p.id === id);
|
||||
if (idx >= 0) overlay.selectAndView(idx);
|
||||
return overlay;
|
||||
},
|
||||
{
|
||||
overlay: true,
|
||||
overlayOptions: {
|
||||
width: "60%",
|
||||
minWidth: 50,
|
||||
maxHeight: "70%",
|
||||
anchor: "center",
|
||||
},
|
||||
},
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
if (sub === "kill") {
|
||||
const id = parts[1];
|
||||
if (!id) {
|
||||
ctx.ui.notify("Usage: /bg kill <id>", "error");
|
||||
return;
|
||||
}
|
||||
const bg = processes.get(id);
|
||||
if (!bg) {
|
||||
ctx.ui.notify(`No process with id '${id}'`, "error");
|
||||
return;
|
||||
}
|
||||
killProcess(id, "SIGTERM");
|
||||
await new Promise((r) => setTimeout(r, 300));
|
||||
if (bg.alive) {
|
||||
killProcess(id, "SIGKILL");
|
||||
await new Promise((r) => setTimeout(r, 200));
|
||||
}
|
||||
if (!bg.alive) processes.delete(id);
|
||||
ctx.ui.notify(`Killed process ${id} (${bg.label})`, "info");
|
||||
return;
|
||||
}
|
||||
|
||||
if (sub === "killall") {
|
||||
const count = processes.size;
|
||||
cleanupAll();
|
||||
ctx.ui.notify(`Killed ${count} background process(es)`, "info");
|
||||
return;
|
||||
}
|
||||
|
||||
if (sub === "groups") {
|
||||
const groups = new Set<string>();
|
||||
for (const p of processes.values()) {
|
||||
if (p.group) groups.add(p.group);
|
||||
}
|
||||
if (groups.size === 0) {
|
||||
ctx.ui.notify("No process groups defined.", "info");
|
||||
return;
|
||||
}
|
||||
const lines = Array.from(groups).map((g) => {
|
||||
const gs = getGroupStatus(g);
|
||||
const icon = gs.healthy ? "✓" : "✗";
|
||||
const procs = gs.processes
|
||||
.map((p) => `${p.id}(${p.status})`)
|
||||
.join(", ");
|
||||
return `${icon} ${g}: ${procs}`;
|
||||
});
|
||||
ctx.ui.notify(lines.join("\n"), "info");
|
||||
return;
|
||||
}
|
||||
|
||||
ctx.ui.notify(
|
||||
"Usage: /bg [list|output|digest|kill|killall|groups] [id]",
|
||||
"info",
|
||||
);
|
||||
},
|
||||
});
|
||||
|
||||
// ── Ctrl+Alt+B shortcut ──────────────────────────────────────────────
|
||||
|
||||
pi.registerShortcut(Key.ctrlAlt("b"), {
|
||||
description: shortcutDesc("Open background process manager", "/bg"),
|
||||
handler: async (ctx) => {
|
||||
state.latestCtx = ctx;
|
||||
await ctx.ui.custom<void>(
|
||||
(tui, theme, _kb, done) => {
|
||||
return new BgManagerOverlay(tui, theme, () => {
|
||||
done();
|
||||
state.refreshWidget();
|
||||
});
|
||||
},
|
||||
{
|
||||
overlay: true,
|
||||
overlayOptions: {
|
||||
width: "60%",
|
||||
minWidth: 50,
|
||||
maxHeight: "70%",
|
||||
anchor: "center",
|
||||
},
|
||||
},
|
||||
);
|
||||
},
|
||||
});
|
||||
}
|
||||
|
|
@ -1,480 +0,0 @@
|
|||
/**
|
||||
* bg_shell lifecycle hook registration — session events, compaction awareness,
|
||||
* context injection, process discovery, footer widget, and periodic maintenance.
|
||||
*/
|
||||
|
||||
import type {
|
||||
ExtensionAPI,
|
||||
ExtensionContext,
|
||||
Theme,
|
||||
} from "@singularity-forge/pi-coding-agent";
|
||||
import { truncateToWidth, visibleWidth } from "@singularity-forge/pi-tui";
|
||||
import { formatTokenCount } from "../shared/format-utils.js";
|
||||
import type { BgShellSharedState } from "./index.js";
|
||||
import {
|
||||
cleanupAll,
|
||||
cleanupSessionProcesses,
|
||||
loadManifest,
|
||||
pendingAlerts,
|
||||
persistManifest,
|
||||
processes,
|
||||
pruneDeadProcesses,
|
||||
pushAlert,
|
||||
} from "./process-manager.js";
|
||||
import {
|
||||
formatUptime,
|
||||
getBgShellLiveCwd,
|
||||
resolveBgShellPersistenceCwd,
|
||||
} from "./utilities.js";
|
||||
|
||||
export function registerBgShellLifecycle(
|
||||
pi: ExtensionAPI,
|
||||
state: BgShellSharedState,
|
||||
): void {
|
||||
function syncLatestCtxCwd(): void {
|
||||
if (!state.latestCtx) return;
|
||||
const syncedCwd = resolveBgShellPersistenceCwd(state.latestCtx.cwd);
|
||||
if (syncedCwd !== state.latestCtx.cwd) {
|
||||
state.latestCtx = { ...state.latestCtx, cwd: syncedCwd };
|
||||
}
|
||||
}
|
||||
|
||||
// Register signal handlers to clean up bg processes on unexpected exit (fixes #428)
|
||||
const signalCleanup = () => {
|
||||
cleanupAll();
|
||||
// Also kill bash-tool spawned children that bg-shell doesn't track
|
||||
try {
|
||||
const { listDescendants } =
|
||||
require("@singularity-forge/native") as typeof import("@singularity-forge/native");
|
||||
const descendants = listDescendants(process.pid);
|
||||
for (const childPid of descendants) {
|
||||
try {
|
||||
process.kill(childPid, "SIGKILL");
|
||||
} catch {} // child already dead → harmless
|
||||
}
|
||||
} catch {} // native not available → can't track descendants, continue
|
||||
};
|
||||
process.on("SIGTERM", signalCleanup);
|
||||
process.on("SIGINT", signalCleanup);
|
||||
process.on("beforeExit", signalCleanup);
|
||||
|
||||
// Clean up on session shutdown — remove signal handlers to prevent accumulation
|
||||
pi.on("session_shutdown", async () => {
|
||||
process.off("SIGTERM", signalCleanup);
|
||||
process.off("SIGINT", signalCleanup);
|
||||
process.off("beforeExit", signalCleanup);
|
||||
cleanupAll();
|
||||
});
|
||||
|
||||
// ── Compaction Awareness: Survive Context Resets ───────────────
|
||||
|
||||
/** Build a compact state summary of all alive processes for context re-injection */
|
||||
function buildProcessStateAlert(reason: string): void {
|
||||
const alive = Array.from(processes.values()).filter((p) => p.alive);
|
||||
if (alive.length === 0) return;
|
||||
|
||||
const processSummaries = alive
|
||||
.map((p) => {
|
||||
const portInfo = p.ports.length > 0 ? ` :${p.ports.join(",")}` : "";
|
||||
const urlInfo = p.urls.length > 0 ? ` ${p.urls[0]}` : "";
|
||||
const errInfo =
|
||||
p.recentErrors.length > 0 ? ` (${p.recentErrors.length} errors)` : "";
|
||||
const groupInfo = p.group ? ` [${p.group}]` : "";
|
||||
return ` - id:${p.id} "${p.label}" [${p.processType}] status:${p.status} uptime:${formatUptime(Date.now() - p.startedAt)}${portInfo}${urlInfo}${errInfo}${groupInfo}`;
|
||||
})
|
||||
.join("\n");
|
||||
|
||||
pushAlert(
|
||||
null,
|
||||
`${reason} ${alive.length} background process(es) are still running:\n${processSummaries}\nUse bg_shell digest/output/kill with these IDs.`,
|
||||
);
|
||||
}
|
||||
|
||||
// After compaction, the LLM loses all memory of running processes.
|
||||
// Queue a detailed alert so the next before_agent_start injects full state.
|
||||
pi.on("session_compact", async () => {
|
||||
buildProcessStateAlert("Context was compacted.");
|
||||
});
|
||||
|
||||
// Tree navigation also resets the agent's context.
|
||||
pi.on("session_tree", async () => {
|
||||
buildProcessStateAlert("Session tree was navigated.");
|
||||
});
|
||||
|
||||
// Session switch resets the agent's context.
|
||||
pi.on("session_switch", async (event, ctx) => {
|
||||
state.latestCtx = ctx;
|
||||
if (event.reason === "new" && event.previousSessionFile) {
|
||||
await cleanupSessionProcesses(event.previousSessionFile);
|
||||
syncLatestCtxCwd();
|
||||
if (state.latestCtx) persistManifest(state.latestCtx.cwd);
|
||||
}
|
||||
buildProcessStateAlert("Session was switched.");
|
||||
});
|
||||
|
||||
// ── Context Injection: Proactive Alerts ────────────────────────────
|
||||
|
||||
pi.on("before_agent_start", async (_event, _ctx) => {
|
||||
// Inject process status overview and any pending alerts
|
||||
const alerts = pendingAlerts.splice(0);
|
||||
const alive = Array.from(processes.values()).filter((p) => p.alive);
|
||||
|
||||
if (alerts.length === 0 && alive.length === 0) return;
|
||||
|
||||
const parts: string[] = [];
|
||||
|
||||
if (alerts.length > 0) {
|
||||
parts.push(
|
||||
`Background process alerts:\n${alerts.map((a) => ` ${a}`).join("\n")}`,
|
||||
);
|
||||
}
|
||||
|
||||
if (alive.length > 0) {
|
||||
const summary = alive
|
||||
.map((p) => {
|
||||
const status =
|
||||
p.status === "ready"
|
||||
? "✓"
|
||||
: p.status === "error"
|
||||
? "✗"
|
||||
: p.status === "starting"
|
||||
? "⋯"
|
||||
: "?";
|
||||
const portInfo = p.ports.length > 0 ? ` :${p.ports.join(",")}` : "";
|
||||
const errInfo =
|
||||
p.recentErrors.length > 0
|
||||
? ` (${p.recentErrors.length} errors)`
|
||||
: "";
|
||||
return ` ${status} ${p.id} ${p.label}${portInfo}${errInfo}`;
|
||||
})
|
||||
.join("\n");
|
||||
parts.push(`Background processes:\n${summary}`);
|
||||
}
|
||||
|
||||
return {
|
||||
message: {
|
||||
customType: "bg-shell-status",
|
||||
content: parts.join("\n\n"),
|
||||
display: false,
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
// ── Session Start: Discover Surviving Processes ────────────────────
|
||||
|
||||
pi.on("session_start", async (_event, ctx) => {
|
||||
state.latestCtx = ctx;
|
||||
|
||||
// Check for surviving processes from previous session
|
||||
const manifest = loadManifest(ctx.cwd);
|
||||
if (manifest.length > 0) {
|
||||
// Check which PIDs are still alive
|
||||
const surviving: typeof manifest = [];
|
||||
for (const entry of manifest) {
|
||||
if (entry.pid) {
|
||||
try {
|
||||
process.kill(entry.pid, 0); // Check if process exists
|
||||
surviving.push(entry);
|
||||
} catch {
|
||||
/* process is dead */
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (surviving.length > 0) {
|
||||
const summary = surviving
|
||||
.map(
|
||||
(s) =>
|
||||
` - ${s.id}: ${s.label} (pid ${s.pid}, type: ${s.processType}${s.group ? `, group: ${s.group}` : ""})`,
|
||||
)
|
||||
.join("\n");
|
||||
|
||||
pushAlert(
|
||||
null,
|
||||
`${surviving.length} background process(es) from previous session still running:\n${summary}\n Note: These processes are outside bg_shell's control. Kill them manually if needed.`,
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// ── Live Footer ──────────────────────────────────────────────────────
|
||||
|
||||
/** Whether we currently own the footer via setFooter */
|
||||
let footerActive = false;
|
||||
|
||||
function buildBgStatusText(th: Theme): string {
|
||||
const alive = Array.from(processes.values()).filter((p) => p.alive);
|
||||
if (alive.length === 0) return "";
|
||||
|
||||
const sep = th.fg("dim", " · ");
|
||||
const items: string[] = [];
|
||||
for (const p of alive) {
|
||||
const statusIcon =
|
||||
p.status === "ready"
|
||||
? th.fg("success", "●")
|
||||
: p.status === "error"
|
||||
? th.fg("error", "●")
|
||||
: th.fg("warning", "●");
|
||||
const name = p.label.length > 14 ? p.label.slice(0, 12) + "…" : p.label;
|
||||
const portInfo = p.ports.length > 0 ? th.fg("dim", `:${p.ports[0]}`) : "";
|
||||
const errBadge =
|
||||
p.recentErrors.length > 0
|
||||
? th.fg("error", ` err:${p.recentErrors.length}`)
|
||||
: "";
|
||||
items.push(`${statusIcon} ${th.fg("muted", name)}${portInfo}${errBadge}`);
|
||||
}
|
||||
return items.join(sep);
|
||||
}
|
||||
|
||||
/** Reference to tui for triggering re-renders when footer is active */
|
||||
let footerTui: { requestRender: () => void } | null = null;
|
||||
|
||||
function refreshWidget() {
|
||||
if (!state.latestCtx?.hasUI) return;
|
||||
const alive = Array.from(processes.values()).filter((p) => p.alive);
|
||||
|
||||
if (alive.length === 0) {
|
||||
if (footerActive) {
|
||||
state.latestCtx.ui.setFooter(undefined);
|
||||
footerActive = false;
|
||||
footerTui = null;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (footerActive) {
|
||||
// Footer already installed — just trigger a re-render
|
||||
footerTui?.requestRender();
|
||||
return;
|
||||
}
|
||||
|
||||
// Install custom footer that puts bg process info right-aligned on line 1
|
||||
footerActive = true;
|
||||
state.latestCtx.ui.setFooter((tui, th, footerData) => {
|
||||
footerTui = tui;
|
||||
const branchUnsub = footerData.onBranchChange(() => tui.requestRender());
|
||||
|
||||
return {
|
||||
render(width: number): string[] {
|
||||
// ── Line 1: pwd (branch) [session] ... bg status ──
|
||||
let pwd = getBgShellLiveCwd(state.latestCtx?.cwd);
|
||||
const home = process.env.HOME || process.env.USERPROFILE;
|
||||
if (home && pwd.startsWith(home)) {
|
||||
pwd = `~${pwd.slice(home.length)}`;
|
||||
}
|
||||
const branch = footerData.getGitBranch();
|
||||
if (branch) pwd = `${pwd} (${branch})`;
|
||||
|
||||
const sessionName =
|
||||
state.latestCtx?.sessionManager?.getSessionName?.();
|
||||
if (sessionName) pwd = `${pwd} • ${sessionName}`;
|
||||
|
||||
const bgStatus = buildBgStatusText(th);
|
||||
const leftPwd = th.fg("dim", pwd);
|
||||
const leftWidth = visibleWidth(leftPwd);
|
||||
const rightWidth = visibleWidth(bgStatus);
|
||||
|
||||
let pwdLine: string;
|
||||
const minGap = 2;
|
||||
if (bgStatus && leftWidth + minGap + rightWidth <= width) {
|
||||
const pad = " ".repeat(width - leftWidth - rightWidth);
|
||||
pwdLine = leftPwd + pad + bgStatus;
|
||||
} else if (bgStatus) {
|
||||
// Truncate pwd to make room for bg status
|
||||
const availForPwd = width - rightWidth - minGap;
|
||||
if (availForPwd > 10) {
|
||||
const truncPwd = truncateToWidth(
|
||||
leftPwd,
|
||||
availForPwd,
|
||||
th.fg("dim", "…"),
|
||||
);
|
||||
const truncWidth = visibleWidth(truncPwd);
|
||||
const pad = " ".repeat(
|
||||
Math.max(0, width - truncWidth - rightWidth),
|
||||
);
|
||||
pwdLine = truncPwd + pad + bgStatus;
|
||||
} else {
|
||||
pwdLine = truncateToWidth(leftPwd, width, th.fg("dim", "…"));
|
||||
}
|
||||
} else {
|
||||
pwdLine = truncateToWidth(leftPwd, width, th.fg("dim", "…"));
|
||||
}
|
||||
|
||||
// ── Line 2: token stats (left) ... model (right) ──
|
||||
const ctx = state.latestCtx;
|
||||
const sm = ctx?.sessionManager;
|
||||
let totalInput = 0,
|
||||
totalOutput = 0;
|
||||
let totalCacheRead = 0,
|
||||
totalCacheWrite = 0,
|
||||
totalCost = 0;
|
||||
if (sm) {
|
||||
for (const entry of sm.getEntries()) {
|
||||
if (
|
||||
entry.type === "message" &&
|
||||
(entry as any).message?.role === "assistant"
|
||||
) {
|
||||
const u = (entry as any).message.usage;
|
||||
if (u) {
|
||||
totalInput += u.input || 0;
|
||||
totalOutput += u.output || 0;
|
||||
totalCacheRead += u.cacheRead || 0;
|
||||
totalCacheWrite += u.cacheWrite || 0;
|
||||
totalCost += u.cost?.total || 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const contextUsage = ctx?.getContextUsage?.();
|
||||
const contextWindow =
|
||||
contextUsage?.contextWindow ?? ctx?.model?.contextWindow ?? 0;
|
||||
const contextPercentValue = contextUsage?.percent ?? 0;
|
||||
const contextPercent =
|
||||
contextUsage?.percent !== null
|
||||
? contextPercentValue.toFixed(1)
|
||||
: "?";
|
||||
|
||||
const statsParts: string[] = [];
|
||||
if (totalInput) statsParts.push(`↑${formatTokenCount(totalInput)}`);
|
||||
if (totalOutput) statsParts.push(`↓${formatTokenCount(totalOutput)}`);
|
||||
if (totalCacheRead)
|
||||
statsParts.push(`R${formatTokenCount(totalCacheRead)}`);
|
||||
if (totalCacheWrite)
|
||||
statsParts.push(`W${formatTokenCount(totalCacheWrite)}`);
|
||||
if (totalCost) statsParts.push(`$${totalCost.toFixed(3)}`);
|
||||
|
||||
const contextDisplay =
|
||||
contextPercent === "?"
|
||||
? `?/${formatTokenCount(contextWindow)}`
|
||||
: `${contextPercent}%/${formatTokenCount(contextWindow)}`;
|
||||
let contextStr: string;
|
||||
if (contextPercentValue > 90) {
|
||||
contextStr = th.fg("error", contextDisplay);
|
||||
} else if (contextPercentValue > 70) {
|
||||
contextStr = th.fg("warning", contextDisplay);
|
||||
} else {
|
||||
contextStr = contextDisplay;
|
||||
}
|
||||
statsParts.push(contextStr);
|
||||
|
||||
let statsLeft = statsParts.join(" ");
|
||||
let statsLeftWidth = visibleWidth(statsLeft);
|
||||
if (statsLeftWidth > width) {
|
||||
statsLeft = truncateToWidth(statsLeft, width, "...");
|
||||
statsLeftWidth = visibleWidth(statsLeft);
|
||||
}
|
||||
|
||||
const modelName = ctx?.model?.id || "no-model";
|
||||
let rightSide = modelName;
|
||||
if (ctx?.model?.reasoning) {
|
||||
const thinkingLevel = (ctx as any).getThinkingLevel?.() || "off";
|
||||
rightSide =
|
||||
thinkingLevel === "off"
|
||||
? `${modelName} • thinking off`
|
||||
: `${modelName} • ${thinkingLevel}`;
|
||||
}
|
||||
if (footerData.getAvailableProviderCount() > 1 && ctx?.model) {
|
||||
const withProvider = `(${ctx.model.provider}) ${rightSide}`;
|
||||
if (statsLeftWidth + 2 + visibleWidth(withProvider) <= width) {
|
||||
rightSide = withProvider;
|
||||
}
|
||||
}
|
||||
|
||||
const rightSideWidth = visibleWidth(rightSide);
|
||||
let statsLine: string;
|
||||
if (statsLeftWidth + 2 + rightSideWidth <= width) {
|
||||
const pad = " ".repeat(width - statsLeftWidth - rightSideWidth);
|
||||
statsLine = statsLeft + pad + rightSide;
|
||||
} else {
|
||||
const avail = width - statsLeftWidth - 2;
|
||||
if (avail > 0) {
|
||||
const truncRight = truncateToWidth(rightSide, avail, "");
|
||||
const truncRightWidth = visibleWidth(truncRight);
|
||||
const pad = " ".repeat(
|
||||
Math.max(0, width - statsLeftWidth - truncRightWidth),
|
||||
);
|
||||
statsLine = statsLeft + pad + truncRight;
|
||||
} else {
|
||||
statsLine = statsLeft;
|
||||
}
|
||||
}
|
||||
|
||||
const dimStatsLeft = th.fg("dim", statsLeft);
|
||||
const remainder = statsLine.slice(statsLeft.length);
|
||||
const dimRemainder = th.fg("dim", remainder);
|
||||
|
||||
const lines = [pwdLine, dimStatsLeft + dimRemainder];
|
||||
|
||||
// ── Line 3 (optional): other extension statuses ──
|
||||
const extensionStatuses = footerData.getExtensionStatuses();
|
||||
// Filter out our own bg-shell status since it's already on line 1
|
||||
const otherStatuses = Array.from(extensionStatuses.entries())
|
||||
.filter(([key]) => key !== "bg-shell")
|
||||
.sort(([a], [b]) => a.localeCompare(b))
|
||||
.map(([, text]) =>
|
||||
text
|
||||
.replace(/[\r\n\t]/g, " ")
|
||||
.replace(/ +/g, " ")
|
||||
.trim(),
|
||||
);
|
||||
if (otherStatuses.length > 0) {
|
||||
lines.push(
|
||||
truncateToWidth(
|
||||
otherStatuses.join(" "),
|
||||
width,
|
||||
th.fg("dim", "..."),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
return lines;
|
||||
},
|
||||
invalidate() {},
|
||||
dispose() {
|
||||
branchUnsub();
|
||||
footerTui = null;
|
||||
},
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
// Expose refreshWidget via shared state so the command module can use it
|
||||
state.refreshWidget = refreshWidget;
|
||||
|
||||
// Periodic maintenance
|
||||
const maintenanceInterval = setInterval(() => {
|
||||
pruneDeadProcesses();
|
||||
refreshWidget();
|
||||
// Persist manifest periodically
|
||||
if (state.latestCtx) {
|
||||
syncLatestCtxCwd();
|
||||
persistManifest(state.latestCtx.cwd);
|
||||
}
|
||||
}, 2000);
|
||||
|
||||
// Refresh widget after agent actions and session events
|
||||
const refreshHandler = async (_event: unknown, ctx: ExtensionContext) => {
|
||||
state.latestCtx = ctx;
|
||||
refreshWidget();
|
||||
};
|
||||
pi.on("turn_end", refreshHandler as any);
|
||||
pi.on("agent_end", refreshHandler as any);
|
||||
pi.on("session_start", refreshHandler as any);
|
||||
pi.on("session_switch", refreshHandler as any);
|
||||
|
||||
pi.on("tool_execution_end", async (_event, ctx) => {
|
||||
state.latestCtx = ctx;
|
||||
refreshWidget();
|
||||
});
|
||||
|
||||
// Clean up on shutdown
|
||||
pi.on("session_shutdown", async () => {
|
||||
clearInterval(maintenanceInterval);
|
||||
if (state.latestCtx) {
|
||||
syncLatestCtxCwd();
|
||||
persistManifest(state.latestCtx.cwd);
|
||||
}
|
||||
cleanupAll();
|
||||
});
|
||||
}
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,71 +0,0 @@
|
|||
/**
|
||||
* Background Shell Extension v2
|
||||
*
|
||||
* Command/tool registration is deferred in interactive mode so startup does not
|
||||
* block on the full background-process stack before the TUI paints.
|
||||
*/
|
||||
|
||||
import {
|
||||
type ExtensionAPI,
|
||||
type ExtensionContext,
|
||||
importExtensionModule,
|
||||
} from "@singularity-forge/pi-coding-agent";
|
||||
import { registerBgShellLifecycle } from "./bg-shell-lifecycle.js";
|
||||
|
||||
export interface BgShellSharedState {
|
||||
latestCtx: ExtensionContext | null;
|
||||
refreshWidget: () => void;
|
||||
}
|
||||
|
||||
let featuresPromise: Promise<void> | null = null;
|
||||
|
||||
async function registerBgShellFeatures(
|
||||
pi: ExtensionAPI,
|
||||
state: BgShellSharedState,
|
||||
): Promise<void> {
|
||||
if (!featuresPromise) {
|
||||
featuresPromise = (async () => {
|
||||
const [{ registerBgShellTool }, { registerBgShellCommand }] =
|
||||
await Promise.all([
|
||||
importExtensionModule<typeof import("./bg-shell-tool.js")>(
|
||||
import.meta.url,
|
||||
"./bg-shell-tool.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./bg-shell-command.js")>(
|
||||
import.meta.url,
|
||||
"./bg-shell-command.js",
|
||||
),
|
||||
]);
|
||||
registerBgShellTool(pi, state);
|
||||
registerBgShellCommand(pi, state);
|
||||
})().catch((error) => {
|
||||
featuresPromise = null;
|
||||
throw error;
|
||||
});
|
||||
}
|
||||
|
||||
return featuresPromise;
|
||||
}
|
||||
|
||||
export default function (pi: ExtensionAPI) {
|
||||
const state: BgShellSharedState = {
|
||||
latestCtx: null,
|
||||
refreshWidget: () => {},
|
||||
};
|
||||
|
||||
registerBgShellLifecycle(pi, state);
|
||||
|
||||
pi.on("session_start", async (_event, ctx) => {
|
||||
if (ctx.hasUI) {
|
||||
void registerBgShellFeatures(pi, state).catch((error) => {
|
||||
ctx.ui.notify(
|
||||
`bg-shell failed to load: ${error instanceof Error ? error.message : String(error)}`,
|
||||
"warning",
|
||||
);
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
await registerBgShellFeatures(pi, state);
|
||||
});
|
||||
}
|
||||
|
|
@ -1,225 +0,0 @@
|
|||
/**
|
||||
* Expect-style interactions: send_and_wait, run on session, query shell environment.
|
||||
*/
|
||||
|
||||
import { randomUUID } from "node:crypto";
|
||||
import { rewriteCommandWithRtk } from "../shared/rtk.js";
|
||||
import type { BgProcess } from "./types.js";
|
||||
|
||||
// ── Query Shell Environment ────────────────────────────────────────────────
|
||||
|
||||
export async function queryShellEnv(
|
||||
bg: BgProcess,
|
||||
timeout: number,
|
||||
signal?: AbortSignal,
|
||||
): Promise<{ cwd: string; env: Record<string, string>; shell: string } | null> {
|
||||
const sentinel = `__SF_ENV_${randomUUID().slice(0, 8)}__`;
|
||||
const startIndex = bg.output.length;
|
||||
|
||||
const cmd = [
|
||||
`echo "${sentinel}_START"`,
|
||||
`echo "CWD=$(pwd)"`,
|
||||
`echo "SHELL=$SHELL"`,
|
||||
`echo "PATH=$PATH"`,
|
||||
`echo "VIRTUAL_ENV=$VIRTUAL_ENV"`,
|
||||
`echo "NODE_ENV=$NODE_ENV"`,
|
||||
`echo "HOME=$HOME"`,
|
||||
`echo "USER=$USER"`,
|
||||
`echo "NVM_DIR=$NVM_DIR"`,
|
||||
`echo "GOPATH=$GOPATH"`,
|
||||
`echo "CARGO_HOME=$CARGO_HOME"`,
|
||||
`echo "PYTHONPATH=$PYTHONPATH"`,
|
||||
`echo "${sentinel}_END"`,
|
||||
].join(" && ");
|
||||
|
||||
bg.proc.stdin?.write(cmd + "\n");
|
||||
|
||||
const start = Date.now();
|
||||
while (Date.now() - start < timeout) {
|
||||
if (signal?.aborted) return null;
|
||||
if (!bg.alive) return null;
|
||||
|
||||
const newEntries = bg.output.slice(startIndex);
|
||||
const endIdx = newEntries.findIndex((e) =>
|
||||
e.line.includes(`${sentinel}_END`),
|
||||
);
|
||||
if (endIdx >= 0) {
|
||||
const startIdx = newEntries.findIndex((e) =>
|
||||
e.line.includes(`${sentinel}_START`),
|
||||
);
|
||||
if (startIdx >= 0) {
|
||||
const envLines = newEntries.slice(startIdx + 1, endIdx);
|
||||
const env: Record<string, string> = {};
|
||||
let cwd = "";
|
||||
let shell = "";
|
||||
|
||||
for (const entry of envLines) {
|
||||
const match = entry.line.match(/^([A-Z_]+)=(.*)$/);
|
||||
if (match) {
|
||||
const [, key, value] = match;
|
||||
if (key === "CWD") {
|
||||
cwd = value;
|
||||
} else if (key === "SHELL") {
|
||||
shell = value;
|
||||
} else if (value) {
|
||||
env[key] = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { cwd, env, shell };
|
||||
}
|
||||
}
|
||||
|
||||
await new Promise((r) => setTimeout(r, 100));
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
// ── Send and Wait ──────────────────────────────────────────────────────────
|
||||
|
||||
export async function sendAndWait(
|
||||
bg: BgProcess,
|
||||
input: string,
|
||||
waitPattern: string,
|
||||
timeout: number,
|
||||
signal?: AbortSignal,
|
||||
): Promise<{ matched: boolean; output: string }> {
|
||||
// Snapshot the current position in the unified buffer before sending
|
||||
const startIndex = bg.output.length;
|
||||
bg.proc.stdin?.write(input + "\n");
|
||||
|
||||
let re: RegExp;
|
||||
try {
|
||||
re = new RegExp(waitPattern, "i");
|
||||
} catch {
|
||||
return { matched: false, output: "Invalid wait pattern regex" };
|
||||
}
|
||||
|
||||
const start = Date.now();
|
||||
while (Date.now() - start < timeout) {
|
||||
if (signal?.aborted) {
|
||||
const newEntries = bg.output.slice(startIndex);
|
||||
return {
|
||||
matched: false,
|
||||
output: newEntries.map((e) => e.line).join("\n") || "(cancelled)",
|
||||
};
|
||||
}
|
||||
const newEntries = bg.output.slice(startIndex);
|
||||
for (const entry of newEntries) {
|
||||
if (re.test(entry.line)) {
|
||||
return {
|
||||
matched: true,
|
||||
output: newEntries.map((e) => e.line).join("\n"),
|
||||
};
|
||||
}
|
||||
}
|
||||
await new Promise((r) => setTimeout(r, 100));
|
||||
}
|
||||
|
||||
const newEntries = bg.output.slice(startIndex);
|
||||
return {
|
||||
matched: false,
|
||||
output: newEntries.map((e) => e.line).join("\n") || "(no output)",
|
||||
};
|
||||
}
|
||||
|
||||
// ── Run on Session ─────────────────────────────────────────────────────────
|
||||
|
||||
export async function runOnSession(
|
||||
bg: BgProcess,
|
||||
command: string,
|
||||
timeout: number,
|
||||
signal?: AbortSignal,
|
||||
): Promise<{ exitCode: number; output: string; timedOut: boolean }> {
|
||||
const sentinel = randomUUID().slice(0, 8);
|
||||
const startMarker = `__SF_SENTINEL_${sentinel}_START__`;
|
||||
const endMarker = `__SF_SENTINEL_${sentinel}_END__`;
|
||||
const exitVar = `__SF_EXIT_${sentinel}__`;
|
||||
|
||||
// Snapshot current output buffer position
|
||||
const startIndex = bg.output.length;
|
||||
|
||||
// Write the sentinel-wrapped command to stdin
|
||||
const rewrittenCommand = rewriteCommandWithRtk(command);
|
||||
const wrappedCommand = [
|
||||
`echo ${startMarker}`,
|
||||
rewrittenCommand,
|
||||
`${exitVar}=$?`,
|
||||
`echo ${endMarker} $${exitVar}`,
|
||||
].join("\n");
|
||||
bg.proc.stdin?.write(wrappedCommand + "\n");
|
||||
|
||||
const start = Date.now();
|
||||
while (Date.now() - start < timeout) {
|
||||
if (signal?.aborted) {
|
||||
const newEntries = bg.output.slice(startIndex);
|
||||
return {
|
||||
exitCode: -1,
|
||||
output: newEntries.map((e) => e.line).join("\n") || "(cancelled)",
|
||||
timedOut: false,
|
||||
};
|
||||
}
|
||||
|
||||
// Process died while waiting
|
||||
if (!bg.alive) {
|
||||
const newEntries = bg.output.slice(startIndex);
|
||||
const lines = newEntries.map((e) => e.line);
|
||||
return {
|
||||
exitCode: bg.proc.exitCode ?? -1,
|
||||
output: lines.join("\n") || "(process exited)",
|
||||
timedOut: false,
|
||||
};
|
||||
}
|
||||
|
||||
const newEntries = bg.output.slice(startIndex);
|
||||
for (let i = 0; i < newEntries.length; i++) {
|
||||
if (newEntries[i].line.includes(endMarker)) {
|
||||
// Parse exit code from the END sentinel line
|
||||
const endLine = newEntries[i].line;
|
||||
const exitMatch = endLine.match(new RegExp(`${endMarker}\\s+(\\d+)`));
|
||||
const exitCode = exitMatch ? parseInt(exitMatch[1], 10) : -1;
|
||||
|
||||
// Extract output between START and END sentinels
|
||||
const outputLines: string[] = [];
|
||||
let capturing = false;
|
||||
for (let j = 0; j < newEntries.length; j++) {
|
||||
if (newEntries[j].line.includes(startMarker)) {
|
||||
capturing = true;
|
||||
continue;
|
||||
}
|
||||
if (newEntries[j].line.includes(endMarker)) {
|
||||
break;
|
||||
}
|
||||
if (capturing) {
|
||||
outputLines.push(newEntries[j].line);
|
||||
}
|
||||
}
|
||||
|
||||
return { exitCode, output: outputLines.join("\n"), timedOut: false };
|
||||
}
|
||||
}
|
||||
|
||||
await new Promise((r) => setTimeout(r, 100));
|
||||
}
|
||||
|
||||
// Timed out
|
||||
const newEntries = bg.output.slice(startIndex);
|
||||
const outputLines: string[] = [];
|
||||
let capturing = false;
|
||||
for (const entry of newEntries) {
|
||||
if (entry.line.includes(startMarker)) {
|
||||
capturing = true;
|
||||
continue;
|
||||
}
|
||||
if (capturing) {
|
||||
outputLines.push(entry.line);
|
||||
}
|
||||
}
|
||||
return {
|
||||
exitCode: -1,
|
||||
output: outputLines.join("\n") || "(no output)",
|
||||
timedOut: true,
|
||||
};
|
||||
}
|
||||
|
|
@ -1,291 +0,0 @@
|
|||
/**
|
||||
* Output analysis, digest generation, highlights extraction, and output retrieval.
|
||||
*/
|
||||
|
||||
import {
|
||||
DEFAULT_MAX_BYTES,
|
||||
DEFAULT_MAX_LINES,
|
||||
truncateHead,
|
||||
} from "@singularity-forge/pi-coding-agent";
|
||||
import { addEvent, pushAlert } from "./process-manager.js";
|
||||
import { transitionToReady } from "./readiness-detector.js";
|
||||
import type {
|
||||
BgProcess,
|
||||
GetOutputOptions,
|
||||
OutputDigest,
|
||||
OutputLine,
|
||||
} from "./types.js";
|
||||
import {
|
||||
BUILD_COMPLETE_PATTERN_UNION,
|
||||
ERROR_PATTERN_UNION,
|
||||
PORT_PATTERN_SOURCE,
|
||||
READINESS_PATTERN_UNION,
|
||||
TEST_RESULT_PATTERN_UNION,
|
||||
URL_PATTERN,
|
||||
WARNING_PATTERN_UNION,
|
||||
} from "./types.js";
|
||||
import { formatTimeAgo, formatUptime } from "./utilities.js";
|
||||
|
||||
// ── Output Analysis ────────────────────────────────────────────────────────
|
||||
|
||||
export function analyzeLine(
|
||||
bg: BgProcess,
|
||||
line: string,
|
||||
_stream: "stdout" | "stderr",
|
||||
): void {
|
||||
// Error detection — single union regex instead of .some(p => p.test(line))
|
||||
if (ERROR_PATTERN_UNION.test(line)) {
|
||||
bg.recentErrors.push(line.trim().slice(0, 200)); // Cap line length
|
||||
if (bg.recentErrors.length > 50)
|
||||
bg.recentErrors.splice(0, bg.recentErrors.length - 50);
|
||||
|
||||
if (bg.status === "ready") {
|
||||
bg.status = "error";
|
||||
addEvent(bg, {
|
||||
type: "error_detected",
|
||||
detail: line.trim().slice(0, 200),
|
||||
data: { errorCount: bg.recentErrors.length },
|
||||
});
|
||||
pushAlert(bg, `error_detected: ${line.trim().slice(0, 120)}`);
|
||||
}
|
||||
}
|
||||
|
||||
// Warning detection — single union regex
|
||||
if (WARNING_PATTERN_UNION.test(line)) {
|
||||
bg.recentWarnings.push(line.trim().slice(0, 200));
|
||||
if (bg.recentWarnings.length > 50)
|
||||
bg.recentWarnings.splice(0, bg.recentWarnings.length - 50);
|
||||
}
|
||||
|
||||
// URL extraction
|
||||
const urlMatches = line.match(URL_PATTERN);
|
||||
if (urlMatches) {
|
||||
for (const url of urlMatches) {
|
||||
if (!bg.urls.includes(url)) {
|
||||
bg.urls.push(url);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Port extraction — PORT_PATTERN has /g flag so must be re-created per call
|
||||
// Use PORT_PATTERN_SOURCE (string) to avoid re-parsing the literal each time
|
||||
const portRe = new RegExp(PORT_PATTERN_SOURCE, "gi");
|
||||
let portMatch: RegExpExecArray | null;
|
||||
// biome-ignore lint/suspicious/noAssignInExpressions: intentional read loop
|
||||
while ((portMatch = portRe.exec(line)) !== null) {
|
||||
const port = parseInt(portMatch[1], 10);
|
||||
if (port > 0 && port <= 65535 && !bg.ports.includes(port)) {
|
||||
bg.ports.push(port);
|
||||
addEvent(bg, {
|
||||
type: "port_open",
|
||||
detail: `Port ${port} detected`,
|
||||
data: { port },
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Readiness detection — single union regex
|
||||
if (bg.status === "starting") {
|
||||
// Check custom ready pattern first
|
||||
if (bg.readyPattern) {
|
||||
try {
|
||||
if (new RegExp(bg.readyPattern, "i").test(line)) {
|
||||
transitionToReady(
|
||||
bg,
|
||||
`Custom pattern matched: ${line.trim().slice(0, 100)}`,
|
||||
);
|
||||
}
|
||||
} catch {
|
||||
/* invalid regex, skip */
|
||||
}
|
||||
}
|
||||
|
||||
// Check built-in readiness patterns
|
||||
if (bg.status === "starting" && READINESS_PATTERN_UNION.test(line)) {
|
||||
transitionToReady(
|
||||
bg,
|
||||
`Readiness pattern matched: ${line.trim().slice(0, 100)}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Recovery detection: if we were in error and see a success pattern
|
||||
if (bg.status === "error") {
|
||||
if (
|
||||
READINESS_PATTERN_UNION.test(line) ||
|
||||
BUILD_COMPLETE_PATTERN_UNION.test(line)
|
||||
) {
|
||||
bg.status = "ready";
|
||||
bg.recentErrors = [];
|
||||
addEvent(bg, {
|
||||
type: "recovered",
|
||||
detail: "Process recovered from error state",
|
||||
});
|
||||
pushAlert(bg, "recovered — errors cleared");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Digest Generation ──────────────────────────────────────────────────────
|
||||
|
||||
export function generateDigest(
|
||||
bg: BgProcess,
|
||||
mutate: boolean = false,
|
||||
): OutputDigest {
|
||||
// Change summary: what's different since last read
|
||||
const newErrors = bg.recentErrors.length - bg.lastErrorCount;
|
||||
const newWarnings = bg.recentWarnings.length - bg.lastWarningCount;
|
||||
const newLines = bg.output.length - bg.lastReadIndex;
|
||||
|
||||
let changeSummary: string;
|
||||
if (newLines === 0) {
|
||||
changeSummary = "no new output";
|
||||
} else {
|
||||
const parts: string[] = [];
|
||||
parts.push(`${newLines} new lines`);
|
||||
if (newErrors > 0) parts.push(`${newErrors} new errors`);
|
||||
if (newWarnings > 0) parts.push(`${newWarnings} new warnings`);
|
||||
changeSummary = parts.join(", ");
|
||||
}
|
||||
|
||||
// Only mutate snapshot counters when explicitly requested (e.g. from tool calls)
|
||||
if (mutate) {
|
||||
bg.lastErrorCount = bg.recentErrors.length;
|
||||
bg.lastWarningCount = bg.recentWarnings.length;
|
||||
}
|
||||
|
||||
return {
|
||||
status: bg.status,
|
||||
uptime: formatUptime(Date.now() - bg.startedAt),
|
||||
errors: bg.recentErrors.slice(-5), // Last 5 errors
|
||||
warnings: bg.recentWarnings.slice(-3), // Last 3 warnings
|
||||
urls: bg.urls,
|
||||
ports: bg.ports,
|
||||
lastActivity:
|
||||
bg.events.length > 0
|
||||
? formatTimeAgo(bg.events[bg.events.length - 1].timestamp)
|
||||
: "none",
|
||||
outputLines: bg.output.length,
|
||||
changeSummary,
|
||||
};
|
||||
}
|
||||
|
||||
// ── Highlight Extraction ───────────────────────────────────────────────────
|
||||
|
||||
export function getHighlights(bg: BgProcess, maxLines: number = 15): string[] {
|
||||
const lines: string[] = [];
|
||||
|
||||
// Collect significant lines
|
||||
const significant: { line: string; score: number; idx: number }[] = [];
|
||||
for (let i = 0; i < bg.output.length; i++) {
|
||||
const entry = bg.output[i];
|
||||
let score = 0;
|
||||
if (ERROR_PATTERN_UNION.test(entry.line)) score += 10;
|
||||
if (WARNING_PATTERN_UNION.test(entry.line)) score += 5;
|
||||
if (URL_PATTERN.test(entry.line)) score += 3;
|
||||
if (READINESS_PATTERN_UNION.test(entry.line)) score += 8;
|
||||
if (TEST_RESULT_PATTERN_UNION.test(entry.line)) score += 7;
|
||||
if (BUILD_COMPLETE_PATTERN_UNION.test(entry.line)) score += 6;
|
||||
// Boost recent lines so highlights favor fresh output over stale
|
||||
if (i >= bg.output.length - 50) score += 2;
|
||||
if (score > 0) {
|
||||
significant.push({
|
||||
line: entry.line.trim().slice(0, 300),
|
||||
score,
|
||||
idx: i,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by significance (tie-break by recency)
|
||||
significant.sort((a, b) => b.score - a.score || b.idx - a.idx);
|
||||
const top = significant.slice(0, maxLines);
|
||||
|
||||
if (top.length === 0) {
|
||||
// If nothing significant, show last few lines
|
||||
const tail = bg.output.slice(-5);
|
||||
for (const l of tail) lines.push(l.line.trim().slice(0, 300));
|
||||
} else {
|
||||
for (const entry of top) lines.push(entry.line);
|
||||
}
|
||||
|
||||
return lines;
|
||||
}
|
||||
|
||||
// ── Output Retrieval (multi-tier) ──────────────────────────────────────────
|
||||
|
||||
export function getOutput(bg: BgProcess, opts: GetOutputOptions): string {
|
||||
const { stream, tail, filter, incremental } = opts;
|
||||
|
||||
// Get the relevant slice of the unified buffer (already in chronological order)
|
||||
let entries: OutputLine[];
|
||||
if (incremental) {
|
||||
entries = bg.output.slice(bg.lastReadIndex);
|
||||
bg.lastReadIndex = bg.output.length;
|
||||
} else {
|
||||
entries = [...bg.output];
|
||||
}
|
||||
|
||||
// Filter by stream if requested
|
||||
if (stream !== "both") {
|
||||
entries = entries.filter((e) => e.stream === stream);
|
||||
}
|
||||
|
||||
// Apply regex filter
|
||||
if (filter) {
|
||||
try {
|
||||
const re = new RegExp(filter, "i");
|
||||
entries = entries.filter((e) => re.test(e.line));
|
||||
} catch {
|
||||
/* invalid regex */
|
||||
}
|
||||
}
|
||||
|
||||
// Tail
|
||||
if (tail && tail > 0 && entries.length > tail) {
|
||||
entries = entries.slice(-tail);
|
||||
}
|
||||
|
||||
const lines = entries.map((e) => e.line);
|
||||
const raw = lines.join("\n");
|
||||
const truncation = truncateHead(raw, {
|
||||
maxLines: DEFAULT_MAX_LINES,
|
||||
maxBytes: DEFAULT_MAX_BYTES,
|
||||
});
|
||||
|
||||
let result = truncation.content;
|
||||
if (truncation.truncated) {
|
||||
result += `\n\n[Output truncated: showing ${truncation.outputLines}/${truncation.totalLines} lines]`;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// ── Format Digest for LLM ──────────────────────────────────────────────────
|
||||
|
||||
export function formatDigestText(bg: BgProcess, digest: OutputDigest): string {
|
||||
let text = `Process ${bg.id} (${bg.label}):\n`;
|
||||
text += ` status: ${digest.status}\n`;
|
||||
text += ` type: ${bg.processType}\n`;
|
||||
text += ` uptime: ${digest.uptime}\n`;
|
||||
|
||||
if (digest.ports.length > 0) text += ` ports: ${digest.ports.join(", ")}\n`;
|
||||
if (digest.urls.length > 0) text += ` urls: ${digest.urls.join(", ")}\n`;
|
||||
|
||||
text += ` output: ${digest.outputLines} lines\n`;
|
||||
text += ` changes: ${digest.changeSummary}`;
|
||||
|
||||
if (digest.errors.length > 0) {
|
||||
text += `\n errors (${digest.errors.length}):`;
|
||||
for (const err of digest.errors) {
|
||||
text += `\n - ${err}`;
|
||||
}
|
||||
}
|
||||
if (digest.warnings.length > 0) {
|
||||
text += `\n warnings (${digest.warnings.length}):`;
|
||||
for (const w of digest.warnings) {
|
||||
text += `\n - ${w}`;
|
||||
}
|
||||
}
|
||||
|
||||
return text;
|
||||
}
|
||||
|
|
@ -1,496 +0,0 @@
|
|||
/**
|
||||
* TUI: Background Process Manager Overlay.
|
||||
*/
|
||||
|
||||
import type { Theme } from "@singularity-forge/pi-coding-agent";
|
||||
import {
|
||||
Key,
|
||||
matchesKey,
|
||||
truncateToWidth,
|
||||
visibleWidth,
|
||||
} from "@singularity-forge/pi-tui";
|
||||
import {
|
||||
cleanupAll,
|
||||
killProcess,
|
||||
processes,
|
||||
restartProcess,
|
||||
} from "./process-manager.js";
|
||||
import type { BgProcess } from "./types.js";
|
||||
import { ERROR_PATTERNS, WARNING_PATTERNS } from "./types.js";
|
||||
import { formatTimeAgo, formatUptime } from "./utilities.js";
|
||||
|
||||
export class BgManagerOverlay {
|
||||
private tui: { requestRender: () => void };
|
||||
private theme: Theme;
|
||||
private onClose: () => void;
|
||||
private selected = 0;
|
||||
private mode: "list" | "output" | "events" = "list";
|
||||
private viewingProcess: BgProcess | null = null;
|
||||
private scrollOffset = 0;
|
||||
private cachedWidth?: number;
|
||||
private cachedLines?: string[];
|
||||
private refreshTimer: ReturnType<typeof setInterval>;
|
||||
|
||||
constructor(
|
||||
tui: { requestRender: () => void },
|
||||
theme: Theme,
|
||||
onClose: () => void,
|
||||
) {
|
||||
this.tui = tui;
|
||||
this.theme = theme;
|
||||
this.onClose = onClose;
|
||||
this.refreshTimer = setInterval(() => {
|
||||
this.invalidate();
|
||||
this.tui.requestRender();
|
||||
}, 1000);
|
||||
}
|
||||
|
||||
private getProcessList(): BgProcess[] {
|
||||
return Array.from(processes.values());
|
||||
}
|
||||
|
||||
selectAndView(index: number): void {
|
||||
const procs = this.getProcessList();
|
||||
if (index >= 0 && index < procs.length) {
|
||||
this.selected = index;
|
||||
this.viewingProcess = procs[index];
|
||||
this.mode = "output";
|
||||
this.scrollOffset = Math.max(0, procs[index].output.length - 20);
|
||||
}
|
||||
}
|
||||
|
||||
handleInput(data: string): void {
|
||||
if (this.mode === "output") {
|
||||
this.handleOutputInput(data);
|
||||
return;
|
||||
}
|
||||
if (this.mode === "events") {
|
||||
this.handleEventsInput(data);
|
||||
return;
|
||||
}
|
||||
this.handleListInput(data);
|
||||
}
|
||||
|
||||
private handleListInput(data: string): void {
|
||||
const procs = this.getProcessList();
|
||||
|
||||
if (
|
||||
matchesKey(data, Key.escape) ||
|
||||
matchesKey(data, Key.ctrl("c")) ||
|
||||
matchesKey(data, Key.ctrlAlt("b"))
|
||||
) {
|
||||
clearInterval(this.refreshTimer);
|
||||
this.onClose();
|
||||
return;
|
||||
}
|
||||
|
||||
if (matchesKey(data, Key.up) || matchesKey(data, "k")) {
|
||||
if (this.selected > 0) {
|
||||
this.selected--;
|
||||
this.invalidate();
|
||||
this.tui.requestRender();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (matchesKey(data, Key.down) || matchesKey(data, "j")) {
|
||||
if (this.selected < procs.length - 1) {
|
||||
this.selected++;
|
||||
this.invalidate();
|
||||
this.tui.requestRender();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (matchesKey(data, Key.enter)) {
|
||||
const proc = procs[this.selected];
|
||||
if (proc) {
|
||||
this.viewingProcess = proc;
|
||||
this.mode = "output";
|
||||
this.scrollOffset = Math.max(0, proc.output.length - 20);
|
||||
this.invalidate();
|
||||
this.tui.requestRender();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// e = view events
|
||||
if (data === "e") {
|
||||
const proc = procs[this.selected];
|
||||
if (proc) {
|
||||
this.viewingProcess = proc;
|
||||
this.mode = "events";
|
||||
this.scrollOffset = Math.max(0, proc.events.length - 15);
|
||||
this.invalidate();
|
||||
this.tui.requestRender();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// r = restart
|
||||
if (data === "r") {
|
||||
const proc = procs[this.selected];
|
||||
if (proc) {
|
||||
restartProcess(proc.id)
|
||||
.then(() => {
|
||||
this.invalidate();
|
||||
this.tui.requestRender();
|
||||
})
|
||||
.catch((err) => {
|
||||
if (process.env.SF_DEBUG)
|
||||
console.error("[bg-shell] restart failed:", err);
|
||||
this.invalidate();
|
||||
this.tui.requestRender();
|
||||
});
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// x or d = kill selected
|
||||
if (data === "x" || data === "d") {
|
||||
const proc = procs[this.selected];
|
||||
if (proc && proc.alive) {
|
||||
killProcess(proc.id, "SIGTERM");
|
||||
setTimeout(() => {
|
||||
if (proc.alive) killProcess(proc.id, "SIGKILL");
|
||||
this.invalidate();
|
||||
this.tui.requestRender();
|
||||
}, 300);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// X or D = kill all
|
||||
if (data === "X" || data === "D") {
|
||||
cleanupAll();
|
||||
this.selected = 0;
|
||||
this.invalidate();
|
||||
this.tui.requestRender();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
private handleOutputInput(data: string): void {
|
||||
if (matchesKey(data, Key.escape) || matchesKey(data, "q")) {
|
||||
this.mode = "list";
|
||||
this.viewingProcess = null;
|
||||
this.scrollOffset = 0;
|
||||
this.invalidate();
|
||||
this.tui.requestRender();
|
||||
return;
|
||||
}
|
||||
|
||||
// Tab to switch to events view
|
||||
if (matchesKey(data, Key.tab)) {
|
||||
this.mode = "events";
|
||||
if (this.viewingProcess) {
|
||||
this.scrollOffset = Math.max(0, this.viewingProcess.events.length - 15);
|
||||
}
|
||||
this.invalidate();
|
||||
this.tui.requestRender();
|
||||
return;
|
||||
}
|
||||
|
||||
if (matchesKey(data, Key.down) || matchesKey(data, "j")) {
|
||||
if (this.viewingProcess) {
|
||||
const total = this.viewingProcess.output.length;
|
||||
this.scrollOffset = Math.min(
|
||||
this.scrollOffset + 5,
|
||||
Math.max(0, total - 20),
|
||||
);
|
||||
}
|
||||
this.invalidate();
|
||||
this.tui.requestRender();
|
||||
return;
|
||||
}
|
||||
|
||||
if (matchesKey(data, Key.up) || matchesKey(data, "k")) {
|
||||
this.scrollOffset = Math.max(0, this.scrollOffset - 5);
|
||||
this.invalidate();
|
||||
this.tui.requestRender();
|
||||
return;
|
||||
}
|
||||
|
||||
if (data === "G") {
|
||||
if (this.viewingProcess) {
|
||||
const total = this.viewingProcess.output.length;
|
||||
this.scrollOffset = Math.max(0, total - 20);
|
||||
}
|
||||
this.invalidate();
|
||||
this.tui.requestRender();
|
||||
return;
|
||||
}
|
||||
|
||||
if (data === "g") {
|
||||
this.scrollOffset = 0;
|
||||
this.invalidate();
|
||||
this.tui.requestRender();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
private handleEventsInput(data: string): void {
|
||||
if (matchesKey(data, Key.escape) || matchesKey(data, "q")) {
|
||||
this.mode = "list";
|
||||
this.viewingProcess = null;
|
||||
this.scrollOffset = 0;
|
||||
this.invalidate();
|
||||
this.tui.requestRender();
|
||||
return;
|
||||
}
|
||||
|
||||
// Tab to switch back to output view
|
||||
if (matchesKey(data, Key.tab)) {
|
||||
this.mode = "output";
|
||||
if (this.viewingProcess) {
|
||||
this.scrollOffset = Math.max(0, this.viewingProcess.output.length - 20);
|
||||
}
|
||||
this.invalidate();
|
||||
this.tui.requestRender();
|
||||
return;
|
||||
}
|
||||
|
||||
if (matchesKey(data, Key.down) || matchesKey(data, "j")) {
|
||||
if (this.viewingProcess) {
|
||||
this.scrollOffset = Math.min(
|
||||
this.scrollOffset + 3,
|
||||
Math.max(0, this.viewingProcess.events.length - 10),
|
||||
);
|
||||
}
|
||||
this.invalidate();
|
||||
this.tui.requestRender();
|
||||
return;
|
||||
}
|
||||
|
||||
if (matchesKey(data, Key.up) || matchesKey(data, "k")) {
|
||||
this.scrollOffset = Math.max(0, this.scrollOffset - 3);
|
||||
this.invalidate();
|
||||
this.tui.requestRender();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
render(width: number): string[] {
|
||||
if (this.cachedLines && this.cachedWidth === width) {
|
||||
return this.cachedLines;
|
||||
}
|
||||
|
||||
let lines: string[];
|
||||
if (this.mode === "events") {
|
||||
lines = this.renderEvents(width);
|
||||
} else if (this.mode === "output") {
|
||||
lines = this.renderOutput(width);
|
||||
} else {
|
||||
lines = this.renderList(width);
|
||||
}
|
||||
|
||||
this.cachedWidth = width;
|
||||
this.cachedLines = lines;
|
||||
return lines;
|
||||
}
|
||||
|
||||
/**
 * Wrap the given lines in a rounded single-line border, truncating and
 * padding each row so the box is exactly `width` columns wide.
 */
private box(inner: string[], width: number): string[] {
  const border = (s: string) => this.theme.fg("borderMuted", s);
  // 2 border columns + 2 padding spaces flank every row.
  const innerWidth = width - 4;
  const horizontal = "─".repeat(width - 2);

  const body = inner.map((line) => {
    const clipped = truncateToWidth(line, innerWidth);
    const fill = " ".repeat(Math.max(0, innerWidth - visibleWidth(clipped)));
    return border("│") + " " + clipped + fill + " " + border("│");
  });

  return [
    border("╭" + horizontal + "╮"),
    ...body,
    border("╰" + horizontal + "╯"),
  ];
}
|
||||
|
||||
/**
 * Render the process-list view: one row per background process with status
 * dot, label, type tag, uptime, and optional port/error/group/restart badges,
 * followed by a key-binding hint line. Returns boxed lines for the overlay.
 */
private renderList(width: number): string[] {
  const th = this.theme;
  const procs = this.getProcessList();
  const inner: string[] = [];

  // Empty state: just a message and the close hint.
  if (procs.length === 0) {
    inner.push(th.fg("dim", "No background processes."));
    inner.push("");
    inner.push(th.fg("dim", "esc close"));
    return this.box(inner, width);
  }

  inner.push(th.fg("dim", "Background Processes"));
  inner.push("");

  for (let i = 0; i < procs.length; i++) {
    const p = procs[i];
    const sel = i === this.selected;
    // Selection pointer; the selected row also gets a brighter name below.
    const pointer = sel ? th.fg("accent", "▸ ") : " ";

    // Filled dot while alive (green ready / red error / yellow otherwise),
    // hollow dim dot once the process is dead.
    const statusIcon = p.alive
      ? p.status === "ready"
        ? th.fg("success", "●")
        : p.status === "error"
          ? th.fg("error", "●")
          : th.fg("warning", "●")
      : th.fg("dim", "○");

    const uptime = th.fg("dim", formatUptime(Date.now() - p.startedAt));
    const name = sel ? th.fg("text", p.label) : th.fg("muted", p.label);
    const typeTag = th.fg("dim", `[${p.processType}]`);
    // Optional badges — each renders as "" when not applicable.
    const portInfo =
      p.ports.length > 0 ? th.fg("dim", ` :${p.ports.join(",")}`) : "";
    const errBadge =
      p.recentErrors.length > 0
        ? th.fg("error", ` ⚠${p.recentErrors.length}`)
        : "";
    const groupTag = p.group ? th.fg("dim", ` {${p.group}}`) : "";
    const restartBadge =
      p.restartCount > 0 ? th.fg("warning", ` ↻${p.restartCount}`) : "";

    // Exit code suffix for dead processes only.
    const status = p.alive ? "" : " " + th.fg("dim", `exit ${p.exitCode}`);

    inner.push(
      `${pointer}${statusIcon} ${name} ${typeTag} ${uptime}${portInfo}${errBadge}${groupTag}${restartBadge}${status}`,
    );
  }

  inner.push("");
  inner.push(
    th.fg(
      "dim",
      "↑↓ select · enter output · e events · r restart · x kill · esc close",
    ),
  );

  return this.box(inner, width);
}
|
||||
|
||||
/**
 * Build the shared header shown atop the output and events views: status
 * dot, label, type tag, uptime, optional port list, and a two-tab indicator
 * with the active tab highlighted. Returns empty strings when no process
 * is being viewed.
 */
private processStatusHeader(
  p: typeof this.viewingProcess,
  activeTab: "output" | "events",
): { statusIcon: string; headerLine: string } {
  const th = this.theme;
  if (!p) return { statusIcon: "", headerLine: "" };

  // Pick the status dot: filled while alive (colored by health), hollow once dead.
  let statusIcon: string;
  if (!p.alive) {
    statusIcon = th.fg("dim", "○");
  } else if (p.status === "ready") {
    statusIcon = th.fg("success", "●");
  } else if (p.status === "error") {
    statusIcon = th.fg("error", "●");
  } else {
    statusIcon = th.fg("warning", "●");
  }

  const name = th.fg("muted", p.label);
  const uptime = th.fg("dim", formatUptime(Date.now() - p.startedAt));
  const typeTag = th.fg("dim", `[${p.processType}]`);
  const portInfo =
    p.ports.length > 0 ? th.fg("dim", ` :${p.ports.join(",")}`) : "";
  // Highlight whichever tab is active; the other stays dimmed.
  const tabIndicator =
    activeTab === "output"
      ? th.fg("accent", "[Output]") + " " + th.fg("dim", "Events")
      : th.fg("dim", "Output") + " " + th.fg("accent", "[Events]");

  const headerLine = `${statusIcon} ${name} ${typeTag} ${uptime}${portInfo} ${tabIndicator}`;
  return { statusIcon, headerLine };
}
|
||||
|
||||
/**
 * Render the output view for the currently-viewed process: header, a
 * scrollable window of up to 18 output lines (stderr lines get a ⚠ prefix,
 * lines matching error/warning patterns are colorized), an optional scroll
 * position indicator, and the key-binding hints.
 */
private renderOutput(width: number): string[] {
  const th = this.theme;
  const p = this.viewingProcess;
  if (!p) return [""];
  const inner: string[] = [];

  const { headerLine } = this.processStatusHeader(p, "output");
  inner.push(headerLine);
  inner.push("");

  // Unified buffer is already chronologically interleaved
  const allOutput = p.output;

  // Window of lines currently in view, controlled by scrollOffset.
  const maxVisible = 18;
  const visible = allOutput.slice(
    this.scrollOffset,
    this.scrollOffset + maxVisible,
  );

  if (allOutput.length === 0) {
    inner.push(th.fg("dim", "(no output)"));
  } else {
    for (const entry of visible) {
      // Error classification wins over warning; default rendering is dim.
      const isError = ERROR_PATTERNS.some((pat) => pat.test(entry.line));
      const isWarning =
        !isError && WARNING_PATTERNS.some((pat) => pat.test(entry.line));
      const prefix = entry.stream === "stderr" ? th.fg("error", "⚠ ") : "";
      const color = isError ? "error" : isWarning ? "warning" : "dim";
      inner.push(prefix + th.fg(color, entry.line));
    }

    // "N–M of T" scroll indicator, only when there is more than one page.
    if (allOutput.length > maxVisible) {
      inner.push("");
      const pos = `${this.scrollOffset + 1}–${Math.min(this.scrollOffset + maxVisible, allOutput.length)} of ${allOutput.length}`;
      inner.push(th.fg("dim", pos));
    }
  }

  inner.push("");
  inner.push(th.fg("dim", "↑↓ scroll · g/G top/end · tab events · q back"));

  return this.box(inner, width);
}
|
||||
|
||||
/**
 * Render the lifecycle-events view for the currently-viewed process: header,
 * a scrollable window of up to 15 events (two lines each: colorized type and
 * time-ago, then a truncated detail line), an optional scroll indicator, and
 * the key-binding hints.
 */
private renderEvents(width: number): string[] {
  const th = this.theme;
  const p = this.viewingProcess;
  if (!p) return [""];
  const inner: string[] = [];

  const { headerLine } = this.processStatusHeader(p, "events");
  inner.push(headerLine);
  inner.push("");

  if (p.events.length === 0) {
    inner.push(th.fg("dim", "(no events)"));
  } else {
    const maxVisible = 15;
    const visible = p.events.slice(
      this.scrollOffset,
      this.scrollOffset + maxVisible,
    );

    for (const ev of visible) {
      const time = th.fg("dim", formatTimeAgo(ev.timestamp));
      // Failures red, recoveries green, port opens accented, the rest dim.
      const typeColor =
        ev.type === "crashed" || ev.type === "error_detected"
          ? "error"
          : ev.type === "ready" || ev.type === "recovered"
            ? "success"
            : ev.type === "port_open"
              ? "accent"
              : "dim";
      const typeLabel = th.fg(typeColor, ev.type);
      inner.push(`${time} ${typeLabel}`);
      // Detail is clamped to 80 chars so one event stays two lines tall.
      inner.push(` ${th.fg("dim", ev.detail.slice(0, 80))}`);
    }

    if (p.events.length > maxVisible) {
      inner.push("");
      inner.push(
        th.fg(
          "dim",
          `${this.scrollOffset + 1}–${Math.min(this.scrollOffset + maxVisible, p.events.length)} of ${p.events.length} events`,
        ),
      );
    }
  }

  inner.push("");
  inner.push(th.fg("dim", "↑↓ scroll · tab output · q back"));

  return this.box(inner, width);
}
|
||||
|
||||
dispose(): void {
|
||||
clearInterval(this.refreshTimer);
|
||||
}
|
||||
|
||||
invalidate(): void {
|
||||
this.cachedWidth = undefined;
|
||||
this.cachedLines = undefined;
|
||||
}
|
||||
}
|
||||
|
|
@ -1,525 +0,0 @@
|
|||
/**
|
||||
* Process lifecycle management: start, stop, restart, signal, state tracking,
|
||||
* process registry, and persistence.
|
||||
*/
|
||||
|
||||
import { spawn, spawnSync } from "node:child_process";
|
||||
import { randomUUID } from "node:crypto";
|
||||
import { existsSync, mkdirSync, readFileSync, writeFileSync } from "node:fs";
|
||||
import { join } from "node:path";
|
||||
import {
|
||||
getShellConfig,
|
||||
sanitizeCommand,
|
||||
} from "@singularity-forge/pi-coding-agent";
|
||||
import { rewriteCommandWithRtk } from "../shared/rtk.js";
|
||||
import { analyzeLine } from "./output-formatter.js";
|
||||
import { startPortProbing, transitionToReady } from "./readiness-detector.js";
|
||||
import type {
|
||||
BgProcess,
|
||||
BgProcessInfo,
|
||||
ProcessEvent,
|
||||
ProcessManifest,
|
||||
ProcessType,
|
||||
StartOptions,
|
||||
} from "./types.js";
|
||||
import { DEAD_PROCESS_TTL, MAX_BUFFER_LINES, MAX_EVENTS } from "./types.js";
|
||||
import { formatUptime, restoreWindowsVTInput } from "./utilities.js";
|
||||
|
||||
// ── Process Registry ───────────────────────────────────────────────────────

// Single in-memory registry of all managed background processes, keyed by
// the short id assigned in startProcess().
export const processes = new Map<string, BgProcess>();

/** Pending alerts to inject into the next agent context */
// NOTE: exported *mutable* binding — it is reassigned by setPendingAlerts(),
// so consumers must re-read it rather than caching a reference.
export let pendingAlerts: string[] = [];

// Cap so a crash-looping process cannot grow the alert queue without bound
// (enforced in pushAlert).
const MAX_PENDING_ALERTS = 50;

/** Replace the pendingAlerts array (used by the extension entry point) */
export function setPendingAlerts(alerts: string[]): void {
  pendingAlerts = alerts;
}
|
||||
|
||||
export function addOutputLine(
|
||||
bg: BgProcess,
|
||||
stream: "stdout" | "stderr",
|
||||
line: string,
|
||||
): void {
|
||||
bg.output.push({ stream, line, ts: Date.now() });
|
||||
if (stream === "stdout") bg.stdoutLineCount++;
|
||||
else bg.stderrLineCount++;
|
||||
if (bg.output.length > MAX_BUFFER_LINES) {
|
||||
const excess = bg.output.length - MAX_BUFFER_LINES;
|
||||
bg.output.splice(0, excess);
|
||||
// Adjust the read cursor so incremental delivery stays correct
|
||||
bg.lastReadIndex = Math.max(0, bg.lastReadIndex - excess);
|
||||
}
|
||||
}
|
||||
|
||||
export function addEvent(
|
||||
bg: BgProcess,
|
||||
event: Omit<ProcessEvent, "timestamp">,
|
||||
): void {
|
||||
const ev: ProcessEvent = { ...event, timestamp: Date.now() };
|
||||
bg.events.push(ev);
|
||||
if (bg.events.length > MAX_EVENTS) {
|
||||
bg.events.splice(0, bg.events.length - MAX_EVENTS);
|
||||
}
|
||||
}
|
||||
|
||||
export function pushAlert(bg: BgProcess | null, message: string): void {
|
||||
const prefix = bg ? `[bg:${bg.id} ${bg.label}] ` : "";
|
||||
pendingAlerts.push(`${prefix}${message}`);
|
||||
if (pendingAlerts.length > MAX_PENDING_ALERTS) {
|
||||
pendingAlerts.splice(0, pendingAlerts.length - MAX_PENDING_ALERTS);
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Build a serializable snapshot of a BgProcess for consumers (UI/tooling).
 * Intentionally exposes counts and summary fields instead of the live
 * ChildProcess handle and raw buffers.
 */
export function getInfo(p: BgProcess): BgProcessInfo {
  return {
    id: p.id,
    label: p.label,
    command: p.command,
    cwd: p.cwd,
    ownerSessionFile: p.ownerSessionFile,
    persistAcrossSessions: p.persistAcrossSessions,
    startedAt: p.startedAt,
    alive: p.alive,
    exitCode: p.exitCode,
    signal: p.signal,
    // Buffer length, not lifetime total — old lines may have been trimmed.
    outputLines: p.output.length,
    stdoutLines: p.stdoutLineCount,
    stderrLines: p.stderrLineCount,
    status: p.status,
    processType: p.processType,
    ports: p.ports,
    urls: p.urls,
    group: p.group,
    restartCount: p.restartCount,
    uptime: formatUptime(Date.now() - p.startedAt),
    recentErrorCount: p.recentErrors.length,
    recentWarningCount: p.recentWarnings.length,
    eventCount: p.events.length,
  };
}
|
||||
|
||||
// ── Process Type Detection ─────────────────────────────────────────────────
|
||||
|
||||
export function detectProcessType(command: string): ProcessType {
|
||||
const cmd = command.toLowerCase();
|
||||
|
||||
// Server patterns
|
||||
if (
|
||||
/\b(serve|server|dev|start)\b/.test(cmd) &&
|
||||
/\b(npm|yarn|pnpm|bun|node|next|vite|nuxt|astro|remix|gatsby|uvicorn|flask|django|rails|cargo)\b/.test(
|
||||
cmd,
|
||||
)
|
||||
)
|
||||
return "server";
|
||||
if (
|
||||
/\b(uvicorn|gunicorn|flask\s+run|manage\.py\s+runserver|rails\s+s)\b/.test(
|
||||
cmd,
|
||||
)
|
||||
)
|
||||
return "server";
|
||||
if (/\b(http-server|live-server|serve)\b/.test(cmd)) return "server";
|
||||
|
||||
// Build patterns
|
||||
if (/\b(build|compile|make|tsc|webpack|rollup|esbuild|swc)\b/.test(cmd)) {
|
||||
if (/\b(watch|--watch|-w)\b/.test(cmd)) return "watcher";
|
||||
return "build";
|
||||
}
|
||||
|
||||
// Test patterns
|
||||
if (
|
||||
/\b(test|jest|vitest|mocha|pytest|cargo\s+test|go\s+test|rspec)\b/.test(cmd)
|
||||
)
|
||||
return "test";
|
||||
|
||||
// Watcher patterns
|
||||
if (/\b(watch|nodemon|chokidar|fswatch|inotifywait)\b/.test(cmd))
|
||||
return "watcher";
|
||||
|
||||
return "generic";
|
||||
}
|
||||
|
||||
// ── Process Start ──────────────────────────────────────────────────────────
|
||||
|
||||
/**
 * Spawn a shell-wrapped background process and register it in `processes`.
 *
 * The command runs through the user's configured shell (so pipes and globs
 * work) and, on non-Windows platforms, in its own process group so the whole
 * tree can later be signalled by killProcess(). Output is captured line-wise
 * into a unified buffer, lifecycle events are recorded, readiness tracking
 * (port probe / shell grace period) is started, and crashes push alerts for
 * the agent. Returns the live BgProcess record immediately — readiness is
 * resolved asynchronously.
 */
export function startProcess(opts: StartOptions): BgProcess {
  // Short id: first 8 chars of a UUID — friendlier in UI/alerts than a full UUID.
  const id = randomUUID().slice(0, 8);
  const processType = opts.type || detectProcessType(opts.command);

  // Caller-provided variables layer over the inherited environment.
  const env = { ...process.env, ...(opts.env || {}) };

  const { shell, args: shellArgs } = getShellConfig();
  // Shell sessions default to the user's shell if no command specified
  const command =
    processType === "shell" && !opts.command
      ? shell
      : rewriteCommandWithRtk(opts.command);
  const proc = spawn(shell, [...shellArgs, sanitizeCommand(command)], {
    cwd: opts.cwd,
    stdio: ["pipe", "pipe", "pipe"],
    env,
    // detached gives the child its own process group on Unix (enables
    // group-wide kill via negative PID); not applicable on Windows.
    detached: process.platform !== "win32",
  });

  const bg: BgProcess = {
    id,
    label: opts.label || command.slice(0, 60),
    command,
    cwd: opts.cwd,
    ownerSessionFile: opts.ownerSessionFile ?? null,
    persistAcrossSessions: opts.persistAcrossSessions ?? false,
    startedAt: Date.now(),
    proc,
    output: [],
    exitCode: null,
    signal: null,
    alive: true,
    lastReadIndex: 0,
    processType,
    status: "starting",
    ports: [],
    urls: [],
    recentErrors: [],
    recentWarnings: [],
    events: [],
    readyPattern: opts.readyPattern || null,
    readyPort: opts.readyPort || null,
    wasReady: false,
    group: opts.group || null,
    lastErrorCount: 0,
    lastWarningCount: 0,
    stdoutLineCount: 0,
    stderrLineCount: 0,
    restartCount: 0,
    // Snapshot of the effective inputs so restartProcess() can re-create
    // this process with identical settings.
    startConfig: {
      command,
      cwd: opts.cwd,
      label: opts.label || command.slice(0, 60),
      processType,
      ownerSessionFile: opts.ownerSessionFile ?? null,
      persistAcrossSessions: opts.persistAcrossSessions ?? false,
      readyPattern: opts.readyPattern || null,
      readyPort: opts.readyPort || null,
      group: opts.group || null,
    },
  };

  addEvent(bg, {
    type: "started",
    detail: `Process started: ${command.slice(0, 100)}`,
  });

  // NOTE(review): chunks are split on "\n" with no carry-over buffering, so a
  // line that spans two data chunks is recorded as two separate lines, and
  // empty lines are dropped — confirm this is intended.
  proc.stdout?.on("data", (chunk: Buffer) => {
    const lines = chunk.toString().split("\n");
    for (const line of lines) {
      if (line.length > 0) {
        addOutputLine(bg, "stdout", line);
        analyzeLine(bg, line, "stdout");
      }
    }
  });

  proc.stderr?.on("data", (chunk: Buffer) => {
    const lines = chunk.toString().split("\n");
    for (const line of lines) {
      if (line.length > 0) {
        addOutputLine(bg, "stderr", line);
        analyzeLine(bg, line, "stderr");
      }
    }
  });

  proc.on("exit", (code, sig) => {
    restoreWindowsVTInput();
    bg.alive = false;
    bg.exitCode = code;
    bg.signal = sig ?? null;

    if (code === 0) {
      bg.status = "exited";
      addEvent(bg, { type: "exited", detail: `Exited cleanly (code 0)` });
    } else {
      // Non-zero exit (or killed by signal): mark crashed and surface the
      // most recent errors in both the event log and the agent alert.
      bg.status = "crashed";
      const lastErrors = bg.recentErrors.slice(-3).join("; ");
      const detail = `Crashed with code ${code}${sig ? ` (signal ${sig})` : ""}${lastErrors ? ` — ${lastErrors}` : ""}`;
      addEvent(bg, {
        type: "crashed",
        detail,
        data: {
          exitCode: code,
          signal: sig,
          lastErrors: bg.recentErrors.slice(-5),
        },
      });
      pushAlert(
        bg,
        `CRASHED (code ${code})${lastErrors ? `: ${lastErrors.slice(0, 120)}` : ""}`,
      );
    }
  });

  // Spawn failure (e.g. shell not found): the child never ran.
  proc.on("error", (err) => {
    bg.alive = false;
    bg.status = "crashed";
    addOutputLine(bg, "stderr", `[spawn error] ${err.message}`);
    addEvent(bg, { type: "crashed", detail: `Spawn error: ${err.message}` });
    pushAlert(bg, `spawn error: ${err.message}`);
  });

  // Port probing for server-type processes
  if (bg.readyPort) {
    startPortProbing(bg, bg.readyPort, opts.readyTimeout);
  }

  // Shell sessions are ready immediately after spawn
  if (bg.processType === "shell") {
    setTimeout(() => {
      if (bg.alive && bg.status === "starting") {
        transitionToReady(bg, "Shell session initialized");
      }
    }, 200);
  }

  processes.set(id, bg);
  return bg;
}
|
||||
|
||||
// ── Process Kill ───────────────────────────────────────────────────────────
|
||||
|
||||
/**
 * Signal a registered process, taking its whole child tree down where the
 * platform allows (taskkill /T on Windows, process-group kill on Unix).
 *
 * Returns false when the id is unknown or signalling threw; returns true
 * when the process was already dead (nothing to do) or the signal was sent.
 * Note: a true return does not guarantee the process has exited yet.
 */
export function killProcess(
  id: string,
  sig: NodeJS.Signals = "SIGTERM",
): boolean {
  const bg = processes.get(id);
  if (!bg) return false;
  if (!bg.alive) return true;
  try {
    if (process.platform === "win32") {
      // Windows: use taskkill /F /T to force-kill the entire process tree.
      // process.kill(-pid) (Unix process groups) does not work on Windows.
      if (bg.proc.pid) {
        const result = spawnSync(
          "taskkill",
          ["/F", "/T", "/PID", String(bg.proc.pid)],
          {
            timeout: 5000,
            encoding: "utf-8",
          },
        );
        // Status 128 is tolerated — presumably "process not found", i.e. the
        // target already exited. TODO(review): confirm against taskkill docs.
        if (result.status !== 0 && result.status !== 128) {
          // taskkill failed — try the direct kill as fallback
          bg.proc.kill(sig);
        }
      } else {
        // No PID recorded (spawn may have failed) — best-effort direct kill.
        bg.proc.kill(sig);
      }
    } else {
      // Unix/macOS: kill the process group via negative PID
      if (bg.proc.pid) {
        try {
          process.kill(-bg.proc.pid, sig);
        } catch {
          // Group signal failed (group gone or child not detached) — fall
          // back to signalling the child directly.
          bg.proc.kill(sig);
        }
      } else {
        bg.proc.kill(sig);
      }
    }
    return true;
  } catch {
    return false;
  }
}
|
||||
|
||||
// ── Process Restart ────────────────────────────────────────────────────────
|
||||
|
||||
/**
 * Stop a process and start a replacement from its saved startConfig.
 *
 * Sends SIGTERM first, escalating to SIGKILL after a short grace period if
 * the process is still alive. The old registry entry is removed; the new
 * process gets a fresh id and an incremented restartCount. Returns the new
 * BgProcess, or null when the id is unknown.
 */
export async function restartProcess(id: string): Promise<BgProcess | null> {
  const old = processes.get(id);
  if (!old) return null;

  const config = old.startConfig;
  const restartCount = old.restartCount + 1;

  // Kill old process
  if (old.alive) {
    killProcess(id, "SIGTERM");
    // Grace period for a clean shutdown before escalating.
    await new Promise((r) => setTimeout(r, 300));
    if (old.alive) {
      killProcess(id, "SIGKILL");
      await new Promise((r) => setTimeout(r, 200));
    }
  }
  processes.delete(id);

  // Start new one
  const newBg = startProcess({
    command: config.command,
    cwd: config.cwd,
    label: config.label,
    type: config.processType,
    ownerSessionFile: config.ownerSessionFile,
    persistAcrossSessions: config.persistAcrossSessions,
    // startConfig stores nulls; StartOptions expects undefined for "unset".
    readyPattern: config.readyPattern || undefined,
    readyPort: config.readyPort || undefined,
    group: config.group || undefined,
  });
  newBg.restartCount = restartCount;

  return newBg;
}
|
||||
|
||||
// ── Group Operations ───────────────────────────────────────────────────────
|
||||
|
||||
export function getGroupProcesses(group: string): BgProcess[] {
|
||||
return Array.from(processes.values()).filter((p) => p.group === group);
|
||||
}
|
||||
|
||||
export function getGroupStatus(group: string): {
|
||||
group: string;
|
||||
healthy: boolean;
|
||||
processes: {
|
||||
id: string;
|
||||
label: string;
|
||||
status: import("./types.js").ProcessStatus;
|
||||
alive: boolean;
|
||||
}[];
|
||||
} {
|
||||
const procs = getGroupProcesses(group);
|
||||
const healthy =
|
||||
procs.length > 0 &&
|
||||
procs.every(
|
||||
(p) => p.alive && (p.status === "ready" || p.status === "starting"),
|
||||
);
|
||||
return {
|
||||
group,
|
||||
healthy,
|
||||
processes: procs.map((p) => ({
|
||||
id: p.id,
|
||||
label: p.label,
|
||||
status: p.status,
|
||||
alive: p.alive,
|
||||
})),
|
||||
};
|
||||
}
|
||||
|
||||
// ── Cleanup ────────────────────────────────────────────────────────────────
|
||||
|
||||
export function pruneDeadProcesses(): void {
|
||||
const now = Date.now();
|
||||
for (const [id, bg] of processes) {
|
||||
if (!bg.alive) {
|
||||
const ttl =
|
||||
bg.processType === "shell" ? DEAD_PROCESS_TTL * 6 : DEAD_PROCESS_TTL;
|
||||
if (now - bg.startedAt > ttl) {
|
||||
processes.delete(id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function cleanupAll(): void {
|
||||
for (const [id, bg] of processes) {
|
||||
if (bg.alive) killProcess(id, "SIGKILL");
|
||||
}
|
||||
processes.clear();
|
||||
}
|
||||
|
||||
/**
|
||||
* Kill all alive, non-persistent bg processes.
|
||||
* Called between auto-mode units to prevent orphaned servers from
|
||||
* keeping ports bound across task boundaries (#1209).
|
||||
*/
|
||||
export function killSessionProcesses(): void {
|
||||
for (const [id, bg] of processes) {
|
||||
if (bg.alive && !bg.persistAcrossSessions) {
|
||||
killProcess(id, "SIGTERM");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async function waitForProcessExit(
|
||||
bg: BgProcess,
|
||||
timeoutMs: number,
|
||||
): Promise<boolean> {
|
||||
if (!bg.alive) return true;
|
||||
await new Promise<void>((resolve) => {
|
||||
const done = () => resolve();
|
||||
const timer = setTimeout(done, timeoutMs);
|
||||
bg.proc.once("exit", () => {
|
||||
clearTimeout(timer);
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
return !bg.alive;
|
||||
}
|
||||
|
||||
/**
 * Terminate every live, non-persistent process owned by `sessionFile`.
 *
 * Sends SIGTERM to all matches, waits up to `options.graceMs` (default
 * 300ms, clamped to >= 0) for them to exit, then SIGKILLs any survivors.
 * Returns the ids of the processes that were signalled (empty when nothing
 * matched).
 */
export async function cleanupSessionProcesses(
  sessionFile: string,
  options?: { graceMs?: number },
): Promise<string[]> {
  const graceMs = Math.max(0, options?.graceMs ?? 300);
  const matches = Array.from(processes.values()).filter(
    (bg) =>
      bg.alive &&
      !bg.persistAcrossSessions &&
      bg.ownerSessionFile === sessionFile,
  );
  if (matches.length === 0) return [];

  // Phase 1: polite termination request.
  for (const bg of matches) {
    killProcess(bg.id, "SIGTERM");
  }
  // Phase 2: wait (in parallel) for exits, bounded by the grace period.
  if (graceMs > 0) {
    await Promise.all(matches.map((bg) => waitForProcessExit(bg, graceMs)));
  }
  // Phase 3: force-kill anything that ignored SIGTERM.
  for (const bg of matches) {
    if (bg.alive) killProcess(bg.id, "SIGKILL");
  }
  return matches.map((bg) => bg.id);
}
|
||||
|
||||
// ── Persistence ────────────────────────────────────────────────────────────
|
||||
|
||||
export function getManifestPath(cwd: string): string {
|
||||
const dir = join(cwd, ".bg-shell");
|
||||
if (!existsSync(dir)) mkdirSync(dir, { recursive: true });
|
||||
return join(dir, "manifest.json");
|
||||
}
|
||||
|
||||
/**
 * Write a JSON manifest of all currently-live processes to
 * `<cwd>/.bg-shell/manifest.json`.
 *
 * Best-effort: any filesystem/serialization error is swallowed, since
 * persistence is a convenience rather than a correctness requirement.
 */
export function persistManifest(cwd: string): void {
  try {
    const manifest: ProcessManifest[] = Array.from(processes.values())
      .filter((p) => p.alive)
      .map((p) => ({
        id: p.id,
        label: p.label,
        command: p.command,
        cwd: p.cwd,
        ownerSessionFile: p.ownerSessionFile,
        persistAcrossSessions: p.persistAcrossSessions,
        startedAt: p.startedAt,
        processType: p.processType,
        group: p.group,
        readyPattern: p.readyPattern,
        readyPort: p.readyPort,
        // PID of the spawned child at persist time (may be undefined if
        // the spawn failed before a PID was assigned).
        pid: p.proc.pid,
      }));
    writeFileSync(getManifestPath(cwd), JSON.stringify(manifest, null, 2));
  } catch {
    /* best effort */
  }
}
|
||||
|
||||
export function loadManifest(cwd: string): ProcessManifest[] {
|
||||
try {
|
||||
const path = getManifestPath(cwd);
|
||||
if (existsSync(path)) {
|
||||
return JSON.parse(readFileSync(path, "utf-8"));
|
||||
}
|
||||
} catch {
|
||||
/* best effort */
|
||||
}
|
||||
return [];
|
||||
}
|
||||
|
|
@ -1,180 +0,0 @@
|
|||
/**
|
||||
* Readiness detection: port probing, pattern matching, wait-for-ready.
|
||||
*/
|
||||
|
||||
import { createConnection } from "node:net";
|
||||
import { addEvent, pushAlert } from "./process-manager.js";
|
||||
import type { BgProcess } from "./types.js";
|
||||
import {
|
||||
DEFAULT_READY_TIMEOUT,
|
||||
PORT_PROBE_TIMEOUT,
|
||||
READY_POLL_INTERVAL,
|
||||
} from "./types.js";
|
||||
|
||||
// ── Readiness Transition ───────────────────────────────────────────────────
|
||||
|
||||
export function transitionToReady(bg: BgProcess, detail: string): void {
|
||||
bg.status = "ready";
|
||||
bg.wasReady = true;
|
||||
addEvent(bg, { type: "ready", detail });
|
||||
}
|
||||
|
||||
// ── Port Probing ───────────────────────────────────────────────────────────
|
||||
|
||||
export function probePort(
|
||||
port: number,
|
||||
host: string = "127.0.0.1",
|
||||
): Promise<boolean> {
|
||||
return new Promise((resolve) => {
|
||||
const socket = createConnection(
|
||||
{ port, host, timeout: PORT_PROBE_TIMEOUT },
|
||||
() => {
|
||||
socket.destroy();
|
||||
resolve(true);
|
||||
},
|
||||
);
|
||||
socket.on("error", () => {
|
||||
socket.destroy();
|
||||
resolve(false);
|
||||
});
|
||||
socket.on("timeout", () => {
|
||||
socket.destroy();
|
||||
resolve(false);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// ── Port Probing Loop ──────────────────────────────────────────────────────
|
||||
|
||||
/**
 * Probe `port` on an interval until it opens, the process leaves "starting",
 * the process dies, or `timeout` (default DEFAULT_READY_TIMEOUT) elapses.
 *
 * On success the process transitions to "ready" and a port_open event is
 * recorded; on timeout the status is forced to "error" so a process never
 * hangs in "starting" forever. Fully asynchronous — the caller does not need
 * to clean anything up.
 */
export function startPortProbing(
  bg: BgProcess,
  port: number,
  customTimeout?: number,
): void {
  const timeout = customTimeout || DEFAULT_READY_TIMEOUT;
  const interval = setInterval(async () => {
    if (!bg.alive) {
      // Process died before the port ever opened — record why (with a tail
      // of stderr for context) and stop probing.
      clearInterval(interval);
      const stderrLines = bg.output
        .filter((l) => l.stream === "stderr")
        .slice(-10)
        .map((l) => l.line);
      const detail = `Process exited (code ${bg.exitCode}) before port ${port} opened${stderrLines.length > 0 ? ` — ${stderrLines.join("; ").slice(0, 200)}` : ""}`;
      addEvent(bg, {
        type: "port_timeout",
        detail,
        data: { port, exitCode: bg.exitCode },
      });
      return;
    }
    if (bg.status !== "starting") {
      // Status already moved past "starting" — nothing left to decide here.
      clearInterval(interval);
      return;
    }
    const open = await probePort(port);
    if (open) {
      clearInterval(interval);
      if (!bg.ports.includes(port)) bg.ports.push(port);
      transitionToReady(bg, `Port ${port} is open`);
      addEvent(bg, {
        type: "port_open",
        detail: `Port ${port} is open`,
        data: { port },
      });
    }
  }, READY_POLL_INTERVAL);

  // Stop probing after timeout — transition to error state so the process
  // doesn't stay in "starting" forever (fixes #428)
  setTimeout(() => {
    clearInterval(interval);
    if (bg.alive && bg.status === "starting") {
      const stderrLines = bg.output
        .filter((l) => l.stream === "stderr")
        .slice(-10)
        .map((l) => l.line);
      const detail = `Port ${port} not open after ${timeout}ms${stderrLines.length > 0 ? ` — ${stderrLines.join("; ").slice(0, 200)}` : ""}`;
      bg.status = "error";
      addEvent(bg, { type: "port_timeout", detail, data: { port, timeout } });
      pushAlert(bg, `Port ${port} readiness timeout after ${timeout / 1000}s`);
    }
  }, timeout);
}
|
||||
|
||||
// ── Wait for Ready ─────────────────────────────────────────────────────────
|
||||
|
||||
export async function waitForReady(
|
||||
bg: BgProcess,
|
||||
timeout: number,
|
||||
signal?: AbortSignal,
|
||||
): Promise<{ ready: boolean; detail: string }> {
|
||||
const start = Date.now();
|
||||
|
||||
while (Date.now() - start < timeout) {
|
||||
if (signal?.aborted) {
|
||||
return { ready: false, detail: "Cancelled" };
|
||||
}
|
||||
if (!bg.alive) {
|
||||
const stderrLines = bg.output
|
||||
.filter((l) => l.stream === "stderr")
|
||||
.slice(-5)
|
||||
.map((l) => l.line);
|
||||
const stderrContext =
|
||||
stderrLines.length > 0
|
||||
? `\nstderr:\n${stderrLines.join("\n").slice(0, 500)}`
|
||||
: "";
|
||||
return {
|
||||
ready: false,
|
||||
detail: `Process exited before becoming ready (code ${bg.exitCode})${bg.recentErrors.length > 0 ? ` — ${bg.recentErrors.slice(-1)[0]}` : ""}${stderrContext}`,
|
||||
};
|
||||
}
|
||||
if (bg.status === "error") {
|
||||
const stderrLines = bg.output
|
||||
.filter((l) => l.stream === "stderr")
|
||||
.slice(-5)
|
||||
.map((l) => l.line);
|
||||
const stderrContext =
|
||||
stderrLines.length > 0
|
||||
? `\nstderr:\n${stderrLines.join("\n").slice(0, 500)}`
|
||||
: "";
|
||||
return {
|
||||
ready: false,
|
||||
detail: `Process entered error state${bg.readyPort ? ` (port ${bg.readyPort} never opened)` : ""}${stderrContext}`,
|
||||
};
|
||||
}
|
||||
if (bg.status === "ready") {
|
||||
return {
|
||||
ready: true,
|
||||
detail:
|
||||
bg.events.find((e) => e.type === "ready")?.detail ||
|
||||
"Process is ready",
|
||||
};
|
||||
}
|
||||
await new Promise((r) => setTimeout(r, READY_POLL_INTERVAL));
|
||||
}
|
||||
|
||||
// Timeout — try port probe as last resort
|
||||
if (bg.readyPort) {
|
||||
const open = await probePort(bg.readyPort);
|
||||
if (open) {
|
||||
transitionToReady(
|
||||
bg,
|
||||
`Port ${bg.readyPort} is open (detected at timeout)`,
|
||||
);
|
||||
return { ready: true, detail: `Port ${bg.readyPort} is open` };
|
||||
}
|
||||
}
|
||||
|
||||
const stderrLines = bg.output
|
||||
.filter((l) => l.stream === "stderr")
|
||||
.slice(-5)
|
||||
.map((l) => l.line);
|
||||
const stderrContext =
|
||||
stderrLines.length > 0
|
||||
? `\nstderr:\n${stderrLines.join("\n").slice(0, 500)}`
|
||||
: "";
|
||||
return {
|
||||
ready: false,
|
||||
detail: `Timed out after ${timeout}ms waiting for ready signal${stderrContext}`,
|
||||
};
|
||||
}
|
||||
|
|
@ -1,297 +0,0 @@
|
|||
/**
|
||||
* Shared types, constants, and pattern databases for the bg-shell extension.
|
||||
*/
|
||||
|
||||
// ── Types ──────────────────────────────────────────────────────────────────
|
||||
|
||||
/** Lifecycle states a background process moves through. */
export type ProcessStatus =
  | "starting" // spawned; readiness not yet confirmed
  | "ready" // readiness observed (ready transition / port open)
  | "error" // still alive, but readiness timed out
  | "exited" // terminated cleanly (exit code 0)
  | "crashed"; // terminated with non-zero code, or spawn failed
|
||||
|
||||
/**
 * Classification of what a command is — detected heuristically from the
 * command string (detectProcessType) or supplied explicitly via
 * StartOptions.type.
 */
export type ProcessType =
  | "server" // long-running network server (dev server, WSGI, static)
  | "build" // one-shot compile/build step
  | "test" // test runner invocation
  | "watcher" // file-watching rebuild/restart loop
  | "generic" // anything unrecognized
  | "shell"; // interactive shell session (never auto-detected)
|
||||
|
||||
/** One entry in a process's lifecycle log (capped at MAX_EVENTS). */
export interface ProcessEvent {
  // What happened; also drives per-type coloring in the TUI event view.
  type:
    | "started"
    | "ready"
    | "error_detected"
    | "recovered"
    | "exited"
    | "crashed"
    | "port_open"
    | "port_timeout";
  // Epoch milliseconds when the event was recorded.
  timestamp: number;
  // Human-readable one-line description.
  detail: string;
  // Optional structured payload (exit codes, ports, error tails, …).
  data?: Record<string, unknown>;
}
|
||||
|
||||
/**
 * Condensed summary of a process's state and recent output. The producer is
 * not in this module; field semantics presumably mirror the identically
 * named BgProcess fields — confirm against the digest builder.
 */
export interface OutputDigest {
  status: ProcessStatus;
  // Pre-formatted human-readable uptime string.
  uptime: string;
  errors: string[];
  warnings: string[];
  urls: string[];
  ports: number[];
  lastActivity: string;
  outputLines: number;
  changeSummary: string;
}
|
||||
|
||||
/** One captured line of child output, tagged with its stream of origin. */
export interface OutputLine {
  stream: "stdout" | "stderr";
  line: string;
  // Epoch milliseconds when the line was captured.
  ts: number;
}
|
||||
|
||||
/** Live, in-memory record of one managed background process. */
export interface BgProcess {
  /** Short unique id (first 8 chars of a UUID); registry key. */
  id: string;
  /** Display name; defaults to a truncated command. */
  label: string;
  /** The command as actually executed (after rewriting). */
  command: string;
  /** Working directory the process was spawned in. */
  cwd: string;
  /** Session file that created this process (used for per-session cleanup) */
  ownerSessionFile: string | null;
  /** Whether this process should survive a new-session boundary */
  persistAcrossSessions: boolean;
  /** Epoch milliseconds at spawn time. */
  startedAt: number;
  /** Live handle to the spawned child. */
  proc: import("node:child_process").ChildProcess;
  /** Unified chronologically-interleaved output buffer */
  output: OutputLine[];
  /** Exit code once the process has exited; null while alive. */
  exitCode: number | null;
  /** Terminating signal name, if any; null otherwise. */
  signal: string | null;
  /** False once the exit or spawn-error handler has fired. */
  alive: boolean;
  /** Tracks how many lines in the unified output buffer the LLM has already seen */
  lastReadIndex: number;
  /** Process classification */
  processType: ProcessType;
  /** Current lifecycle status */
  status: ProcessStatus;
  /** Detected ports */
  ports: number[];
  /** Detected URLs */
  urls: string[];
  /** Accumulated errors since last read */
  recentErrors: string[];
  /** Accumulated warnings since last read */
  recentWarnings: string[];
  /** Lifecycle events log */
  events: ProcessEvent[];
  /** Ready pattern (regex string) */
  readyPattern: string | null;
  /** Ready port to probe */
  readyPort: number | null;
  /** Whether readiness was ever achieved */
  wasReady: boolean;
  /** Group membership */
  group: string | null;
  /** Last error count snapshot for diff detection */
  lastErrorCount: number;
  /** Last warning count snapshot for diff detection */
  lastWarningCount: number;
  /** Tracked stdout line count (incremented in addOutputLine, avoids O(n) filter) */
  stdoutLineCount: number;
  /** Tracked stderr line count (incremented in addOutputLine, avoids O(n) filter) */
  stderrLineCount: number;
  /** Restart count */
  restartCount: number;
  /** Original start config for restart */
  startConfig: {
    command: string;
    cwd: string;
    label: string;
    processType: ProcessType;
    ownerSessionFile: string | null;
    persistAcrossSessions: boolean;
    readyPattern: string | null;
    readyPort: number | null;
    group: string | null;
  };
}
|
||||
|
||||
export interface BgProcessInfo {
|
||||
id: string;
|
||||
label: string;
|
||||
command: string;
|
||||
cwd: string;
|
||||
ownerSessionFile: string | null;
|
||||
persistAcrossSessions: boolean;
|
||||
startedAt: number;
|
||||
alive: boolean;
|
||||
exitCode: number | null;
|
||||
signal: string | null;
|
||||
outputLines: number;
|
||||
stdoutLines: number;
|
||||
stderrLines: number;
|
||||
status: ProcessStatus;
|
||||
processType: ProcessType;
|
||||
ports: number[];
|
||||
urls: string[];
|
||||
group: string | null;
|
||||
restartCount: number;
|
||||
uptime: string;
|
||||
recentErrorCount: number;
|
||||
recentWarningCount: number;
|
||||
eventCount: number;
|
||||
}
|
||||
|
||||
export interface StartOptions {
|
||||
command: string;
|
||||
cwd: string;
|
||||
ownerSessionFile?: string | null;
|
||||
persistAcrossSessions?: boolean;
|
||||
label?: string;
|
||||
type?: ProcessType;
|
||||
readyPattern?: string;
|
||||
readyPort?: number;
|
||||
readyTimeout?: number;
|
||||
group?: string;
|
||||
env?: Record<string, string>;
|
||||
}
|
||||
|
||||
export interface GetOutputOptions {
|
||||
stream: "stdout" | "stderr" | "both";
|
||||
tail?: number;
|
||||
filter?: string;
|
||||
incremental?: boolean;
|
||||
}
|
||||
|
||||
export interface ProcessManifest {
|
||||
id: string;
|
||||
label: string;
|
||||
command: string;
|
||||
cwd: string;
|
||||
ownerSessionFile: string | null;
|
||||
persistAcrossSessions: boolean;
|
||||
startedAt: number;
|
||||
processType: ProcessType;
|
||||
group: string | null;
|
||||
readyPattern: string | null;
|
||||
readyPort: number | null;
|
||||
pid: number | undefined;
|
||||
}
|
||||
|
||||
// ── Constants ──────────────────────────────────────────────────────────────
|
||||
|
||||
export const MAX_BUFFER_LINES = 5000;
|
||||
export const MAX_EVENTS = 200;
|
||||
export const DEAD_PROCESS_TTL = 10 * 60 * 1000;
|
||||
export const PORT_PROBE_TIMEOUT = 500;
|
||||
export const READY_POLL_INTERVAL = 250;
|
||||
export const DEFAULT_READY_TIMEOUT = 30000;
|
||||
|
||||
// ── Pattern Databases ──────────────────────────────────────────────────────
|
||||
|
||||
/** Patterns that indicate a process is ready/listening */
|
||||
export const READINESS_PATTERNS: RegExp[] = [
|
||||
// Node/JS servers
|
||||
/listening\s+on\s+(?:port\s+)?(\d+)/i,
|
||||
/server\s+(?:is\s+)?(?:running|started|listening)\s+(?:at|on)\s+/i,
|
||||
/ready\s+(?:in|on|at)\s+/i,
|
||||
/started\s+(?:server\s+)?on\s+/i,
|
||||
// Next.js / Vite / etc
|
||||
/Local:\s*https?:\/\//i,
|
||||
/➜\s+Local:\s*/i,
|
||||
/compiled\s+(?:successfully|client\s+and\s+server)/i,
|
||||
// Python
|
||||
/running\s+on\s+https?:\/\//i,
|
||||
/Uvicorn\s+running/i,
|
||||
/Development\s+server\s+is\s+running/i,
|
||||
// Generic
|
||||
/press\s+ctrl[-+]c\s+to\s+(?:quit|stop)/i,
|
||||
/watching\s+for\s+(?:file\s+)?changes/i,
|
||||
/build\s+(?:completed|succeeded|finished)/i,
|
||||
];
|
||||
|
||||
/** Patterns that indicate errors */
|
||||
export const ERROR_PATTERNS: RegExp[] = [
|
||||
/\berror\b[\s:[\](]/i,
|
||||
/\bERROR\b/,
|
||||
/\bfailed\b/i,
|
||||
/\bFAILED\b/,
|
||||
/\bfatal\b/i,
|
||||
/\bFATAL\b/,
|
||||
/\bexception\b/i,
|
||||
/\bpanic\b/i,
|
||||
/\bsegmentation\s+fault\b/i,
|
||||
/\bsyntax\s*error\b/i,
|
||||
/\btype\s*error\b/i,
|
||||
/\breference\s*error\b/i,
|
||||
/Cannot\s+find\s+module/i,
|
||||
/Module\s+not\s+found/i,
|
||||
/ENOENT/,
|
||||
/EACCES/,
|
||||
/EADDRINUSE/,
|
||||
/TS\d{4,5}:/, // TypeScript errors
|
||||
/E\d{4,5}:/, // Rust errors
|
||||
/\[ERROR\]/,
|
||||
/✖|✗|❌/, // Common error symbols
|
||||
];
|
||||
|
||||
/** Patterns that indicate warnings */
|
||||
export const WARNING_PATTERNS: RegExp[] = [
|
||||
/\bwarning\b[\s:[\](]/i,
|
||||
/\bWARN(?:ING)?\b/,
|
||||
/\bdeprecated\b/i,
|
||||
/\bDEPRECATED\b/,
|
||||
/⚠️?/,
|
||||
/\[WARN\]/,
|
||||
];
|
||||
|
||||
/** Patterns to extract URLs */
|
||||
export const URL_PATTERN = /https?:\/\/[^\s"'<>)\]]+/gi;
|
||||
|
||||
/** Patterns to extract port numbers from "listening" messages */
|
||||
export const PORT_PATTERN = /(?:port|listening\s+on|:)\s*(\d{2,5})\b/gi;
|
||||
|
||||
/** Patterns indicating test results */
|
||||
export const TEST_RESULT_PATTERNS: RegExp[] = [
|
||||
/(\d+)\s+(?:tests?\s+)?passed/i,
|
||||
/(\d+)\s+(?:tests?\s+)?failed/i,
|
||||
/Tests?:\s+(\d+)\s+passed/i,
|
||||
/(\d+)\s+passing/i,
|
||||
/(\d+)\s+failing/i,
|
||||
/PASS|FAIL/,
|
||||
];
|
||||
|
||||
/** Patterns indicating build completion */
|
||||
export const BUILD_COMPLETE_PATTERNS: RegExp[] = [
|
||||
/build\s+(?:completed|succeeded|finished|done)/i,
|
||||
/compiled\s+(?:successfully|with\s+\d+\s+(?:error|warning))/i,
|
||||
/✓\s+Built/i,
|
||||
/webpack\s+\d+\.\d+/i,
|
||||
/bundle\s+(?:is\s+)?ready/i,
|
||||
];
|
||||
|
||||
// ── Compiled union regexes (single-pass alternatives to .some(p => p.test(line))) ──
|
||||
// Built once at module load — eliminates per-line RegExp construction overhead.
|
||||
|
||||
export const ERROR_PATTERN_UNION = new RegExp(
|
||||
ERROR_PATTERNS.map((p) => p.source).join("|"),
|
||||
"i",
|
||||
);
|
||||
export const WARNING_PATTERN_UNION = new RegExp(
|
||||
WARNING_PATTERNS.map((p) => p.source).join("|"),
|
||||
"i",
|
||||
);
|
||||
export const READINESS_PATTERN_UNION = new RegExp(
|
||||
READINESS_PATTERNS.map((p) => p.source).join("|"),
|
||||
"i",
|
||||
);
|
||||
export const BUILD_COMPLETE_PATTERN_UNION = new RegExp(
|
||||
BUILD_COMPLETE_PATTERNS.map((p) => p.source).join("|"),
|
||||
"i",
|
||||
);
|
||||
export const TEST_RESULT_PATTERN_UNION = new RegExp(
|
||||
TEST_RESULT_PATTERNS.map((p) => p.source).join("|"),
|
||||
"i",
|
||||
);
|
||||
/** PORT_PATTERN compiled once for reuse in analyzeLine (needs exec, so must be re-created per call with /g) */
|
||||
export const PORT_PATTERN_SOURCE = PORT_PATTERN.source;
|
||||
|
|
@ -1,111 +0,0 @@
|
|||
/**
|
||||
* Utility functions for the bg-shell extension.
|
||||
*/
|
||||
|
||||
import { existsSync } from "node:fs";
|
||||
import { createRequire } from "node:module";
|
||||
|
||||
// ── Windows VT Input Restoration ────────────────────────────────────────────
|
||||
// Child processes (esp. Git Bash / MSYS2) can strip the ENABLE_VIRTUAL_TERMINAL_INPUT
|
||||
// flag from the shared stdin console handle. Re-enable it after each child exits.
|
||||
|
||||
let _vtHandles: {
|
||||
GetConsoleMode: (...args: unknown[]) => unknown;
|
||||
SetConsoleMode: (...args: unknown[]) => unknown;
|
||||
handle: unknown;
|
||||
} | null = null;
|
||||
export function restoreWindowsVTInput(): void {
|
||||
if (process.platform !== "win32") return;
|
||||
try {
|
||||
if (!_vtHandles) {
|
||||
const cjsRequire = createRequire(import.meta.url);
|
||||
const koffi = cjsRequire("koffi");
|
||||
const k32 = koffi.load("kernel32.dll");
|
||||
const GetStdHandle = k32.func("void* __stdcall GetStdHandle(int)");
|
||||
const GetConsoleMode = k32.func(
|
||||
"bool __stdcall GetConsoleMode(void*, _Out_ uint32_t*)",
|
||||
);
|
||||
const SetConsoleMode = k32.func(
|
||||
"bool __stdcall SetConsoleMode(void*, uint32_t)",
|
||||
);
|
||||
const handle = GetStdHandle(-10);
|
||||
_vtHandles = { GetConsoleMode, SetConsoleMode, handle };
|
||||
}
|
||||
const ENABLE_VIRTUAL_TERMINAL_INPUT = 0x0200;
|
||||
const mode = new Uint32Array(1);
|
||||
_vtHandles.GetConsoleMode(_vtHandles.handle, mode);
|
||||
if (!(mode[0] & ENABLE_VIRTUAL_TERMINAL_INPUT)) {
|
||||
_vtHandles.SetConsoleMode(
|
||||
_vtHandles.handle,
|
||||
mode[0] | ENABLE_VIRTUAL_TERMINAL_INPUT,
|
||||
);
|
||||
}
|
||||
} catch {
|
||||
/* koffi not available on non-Windows */
|
||||
}
|
||||
}
|
||||
|
||||
// ── Time Formatting ────────────────────────────────────────────────────────
|
||||
|
||||
import { formatDuration } from "../shared/mod.js";
|
||||
|
||||
export const formatUptime = formatDuration;
|
||||
|
||||
export function formatTimeAgo(timestamp: number): string {
|
||||
return formatDuration(Date.now() - timestamp) + " ago";
|
||||
}
|
||||
|
||||
function deriveProjectRootFromAutoWorktree(
|
||||
cachedCwd?: string,
|
||||
): string | undefined {
|
||||
if (!cachedCwd) return undefined;
|
||||
const match = cachedCwd.match(
|
||||
/^(.*?)[\\/]\.sf[\\/]worktrees[\\/][^\\/]+(?:[\\/].*)?$/,
|
||||
);
|
||||
return match?.[1];
|
||||
}
|
||||
|
||||
export function getBgShellLiveCwd(
|
||||
cachedCwd?: string,
|
||||
pathExists: (path: string) => boolean = existsSync,
|
||||
getCwd: () => string = () => process.cwd(),
|
||||
chdir: (path: string) => void = (path) => process.chdir(path),
|
||||
): string {
|
||||
try {
|
||||
return getCwd();
|
||||
} catch {
|
||||
const projectRoot = deriveProjectRootFromAutoWorktree(cachedCwd);
|
||||
const home = process.env.HOME || process.env.USERPROFILE;
|
||||
const fallbacks = [projectRoot, cachedCwd, home, "/"].filter(
|
||||
(candidate): candidate is string => Boolean(candidate),
|
||||
);
|
||||
|
||||
for (const candidate of fallbacks) {
|
||||
if (candidate !== "/" && !pathExists(candidate)) continue;
|
||||
try {
|
||||
chdir(candidate);
|
||||
} catch {
|
||||
// Best-effort only. Returning a known-good fallback is enough to avoid crashes.
|
||||
}
|
||||
return candidate;
|
||||
}
|
||||
|
||||
return "/";
|
||||
}
|
||||
}
|
||||
|
||||
export function resolveBgShellPersistenceCwd(
|
||||
cachedCwd: string,
|
||||
liveCwd: string | undefined = undefined,
|
||||
pathExists: (path: string) => boolean = existsSync,
|
||||
): string {
|
||||
const resolvedLiveCwd = liveCwd ?? getBgShellLiveCwd(cachedCwd, pathExists);
|
||||
const cachedIsAutoWorktree = /(?:^|[\\/])\.sf[\\/]worktrees[\\/]/.test(
|
||||
cachedCwd,
|
||||
);
|
||||
if (!cachedIsAutoWorktree) return cachedCwd;
|
||||
if (cachedCwd === resolvedLiveCwd && pathExists(cachedCwd)) return cachedCwd;
|
||||
if (!pathExists(cachedCwd)) return resolvedLiveCwd;
|
||||
if (resolvedLiveCwd !== cachedCwd) return resolvedLiveCwd;
|
||||
return cachedCwd;
|
||||
}
|
||||
|
|
@ -1,280 +0,0 @@
|
|||
/**
|
||||
* browser-tools — page state capture
|
||||
*
|
||||
* Functions for capturing compact page state, screenshots, and summaries.
|
||||
* Used by tool implementations for post-action feedback.
|
||||
*/
|
||||
|
||||
import type { Frame, Page } from "playwright";
|
||||
|
||||
// sharp is an optional native dependency. Load it lazily so that the extension
|
||||
// can still be loaded on platforms where sharp is unavailable (e.g. bunx on
|
||||
// Raspberry Pi). constrainScreenshot falls back to returning the raw buffer
|
||||
// when sharp is not installed, which means screenshots won't be resized but
|
||||
// the tool remains functional.
|
||||
let _sharp: typeof import("sharp") | null | undefined;
|
||||
async function getSharp(): Promise<typeof import("sharp") | null> {
|
||||
if (_sharp !== undefined) return _sharp;
|
||||
try {
|
||||
_sharp = (await import("sharp")).default;
|
||||
} catch {
|
||||
_sharp = null;
|
||||
}
|
||||
return _sharp;
|
||||
}
|
||||
|
||||
import type { CompactPageState } from "./state.js";
|
||||
import { formatCompactStateSummary } from "./utils.js";
|
||||
|
||||
// Anthropic vision: 1568px is the recommended optimal width. Height is capped
|
||||
// generously at 8000px so tall full-page screenshots remain readable rather
|
||||
// than being squished into a square constraint.
|
||||
//
|
||||
// Override via environment variables:
|
||||
// SCREENSHOT_MAX_WIDTH=0 → uncap width (use raw resolution)
|
||||
// SCREENSHOT_MAX_HEIGHT=0 → uncap height
|
||||
// SCREENSHOT_FORMAT=png → lossless PNG for all viewport/fullpage screenshots
|
||||
// SCREENSHOT_QUALITY=100 → max JPEG quality (1-100, default 80)
|
||||
const MAX_SCREENSHOT_WIDTH = parseScreenshotDimension(
|
||||
process.env.SCREENSHOT_MAX_WIDTH,
|
||||
1568,
|
||||
);
|
||||
const MAX_SCREENSHOT_HEIGHT = parseScreenshotDimension(
|
||||
process.env.SCREENSHOT_MAX_HEIGHT,
|
||||
8000,
|
||||
);
|
||||
|
||||
/** Parse a dimension env var: positive int = that value, 0 = Infinity (uncapped), absent/invalid = default. */
|
||||
function parseScreenshotDimension(
|
||||
value: string | undefined,
|
||||
fallback: number,
|
||||
): number {
|
||||
if (value === undefined || value === "") return fallback;
|
||||
const n = parseInt(value, 10);
|
||||
if (Number.isNaN(n) || n < 0) return fallback;
|
||||
if (n === 0) return Infinity;
|
||||
return n;
|
||||
}
|
||||
|
||||
/** Return the user-configured screenshot format override, or null for default behavior. */
|
||||
export function getScreenshotFormatOverride(): "png" | "jpeg" | null {
|
||||
const fmt = process.env.SCREENSHOT_FORMAT?.toLowerCase();
|
||||
if (fmt === "png") return "png";
|
||||
if (fmt === "jpeg" || fmt === "jpg") return "jpeg";
|
||||
return null;
|
||||
}
|
||||
|
||||
/** Return the user-configured default JPEG quality, or the provided fallback. */
|
||||
export function getScreenshotQualityDefault(fallback: number): number {
|
||||
const q = process.env.SCREENSHOT_QUALITY;
|
||||
if (q === undefined || q === "") return fallback;
|
||||
const n = parseInt(q, 10);
|
||||
if (Number.isNaN(n) || n < 1 || n > 100) return fallback;
|
||||
return n;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Compact page state capture
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export async function captureCompactPageState(
|
||||
p: Page,
|
||||
options: {
|
||||
selectors?: string[];
|
||||
includeBodyText?: boolean;
|
||||
target?: Page | Frame;
|
||||
} = {},
|
||||
): Promise<CompactPageState> {
|
||||
const selectors = Array.from(
|
||||
new Set((options.selectors ?? []).filter(Boolean)),
|
||||
);
|
||||
const target = options.target ?? p;
|
||||
const domState = await target.evaluate(
|
||||
({ selectors, includeBodyText }) => {
|
||||
const selectorStates: Record<
|
||||
string,
|
||||
{
|
||||
exists: boolean;
|
||||
visible: boolean;
|
||||
value: string;
|
||||
checked: boolean | null;
|
||||
text: string;
|
||||
}
|
||||
> = {};
|
||||
for (const selector of selectors) {
|
||||
let el: Element | null = null;
|
||||
try {
|
||||
el = document.querySelector(selector);
|
||||
} catch {
|
||||
el = null;
|
||||
}
|
||||
if (!el) {
|
||||
selectorStates[selector] = {
|
||||
exists: false,
|
||||
visible: false,
|
||||
value: "",
|
||||
checked: null,
|
||||
text: "",
|
||||
};
|
||||
continue;
|
||||
}
|
||||
const htmlEl = el as HTMLElement;
|
||||
const style = window.getComputedStyle(htmlEl);
|
||||
const rect = htmlEl.getBoundingClientRect();
|
||||
const visible =
|
||||
style.display !== "none" &&
|
||||
style.visibility !== "hidden" &&
|
||||
rect.width > 0 &&
|
||||
rect.height > 0;
|
||||
const input = el as HTMLInputElement;
|
||||
selectorStates[selector] = {
|
||||
exists: true,
|
||||
visible,
|
||||
value:
|
||||
el instanceof HTMLInputElement ||
|
||||
el instanceof HTMLTextAreaElement ||
|
||||
el instanceof HTMLSelectElement
|
||||
? el.value
|
||||
: htmlEl.getAttribute("value") || "",
|
||||
checked:
|
||||
el instanceof HTMLInputElement &&
|
||||
["checkbox", "radio"].includes(input.type)
|
||||
? input.checked
|
||||
: null,
|
||||
text: (htmlEl.innerText || htmlEl.textContent || "")
|
||||
.trim()
|
||||
.replace(/\s+/g, " ")
|
||||
.slice(0, 160),
|
||||
};
|
||||
}
|
||||
|
||||
const focused = document.activeElement as HTMLElement | null;
|
||||
const focusedDesc =
|
||||
focused &&
|
||||
focused !== document.body &&
|
||||
focused !== document.documentElement
|
||||
? `${focused.tagName.toLowerCase()}${focused.id ? "#" + focused.id : ""}${focused.getAttribute("aria-label") ? ' "' + focused.getAttribute("aria-label") + '"' : ""}`
|
||||
: "";
|
||||
const headings = Array.from(document.querySelectorAll("h1,h2,h3"))
|
||||
.slice(0, 5)
|
||||
.map((h) =>
|
||||
(h.textContent || "").trim().replace(/\s+/g, " ").slice(0, 80),
|
||||
);
|
||||
const dialog = document.querySelector(
|
||||
'[role="dialog"]:not([hidden]),dialog[open]',
|
||||
);
|
||||
const dialogTitle =
|
||||
dialog
|
||||
?.querySelector('[role="heading"],[aria-label]')
|
||||
?.textContent?.trim()
|
||||
.slice(0, 80) ?? "";
|
||||
const bodyText = includeBodyText
|
||||
? (document.body?.innerText || document.body?.textContent || "")
|
||||
.trim()
|
||||
.replace(/\s+/g, " ")
|
||||
.slice(0, 4000)
|
||||
: "";
|
||||
return {
|
||||
url: window.location.href,
|
||||
title: document.title,
|
||||
focus: focusedDesc,
|
||||
headings,
|
||||
bodyText,
|
||||
counts: {
|
||||
landmarks: document.querySelectorAll(
|
||||
'[role="main"],[role="banner"],[role="navigation"],[role="contentinfo"],[role="complementary"],[role="search"],[role="form"],[role="dialog"],[role="alert"],main,header,nav,footer,aside,section,form,dialog',
|
||||
).length,
|
||||
buttons: document.querySelectorAll('button,[role="button"]').length,
|
||||
links: document.querySelectorAll("a[href]").length,
|
||||
inputs: document.querySelectorAll("input,textarea,select").length,
|
||||
},
|
||||
dialog: {
|
||||
count: document.querySelectorAll(
|
||||
'[role="dialog"]:not([hidden]),dialog[open]',
|
||||
).length,
|
||||
title: dialogTitle,
|
||||
},
|
||||
selectorStates,
|
||||
};
|
||||
},
|
||||
{ selectors, includeBodyText: options.includeBodyText === true },
|
||||
);
|
||||
// URL and title always come from the Page, not the frame
|
||||
return { ...domState, url: p.url(), title: await p.title() };
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Post-action summary
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/** Lightweight page summary after an action. Returns ~50-150 tokens instead of full tree. */
|
||||
export async function postActionSummary(
|
||||
p: Page,
|
||||
target?: Page | Frame,
|
||||
): Promise<string> {
|
||||
try {
|
||||
const state = await captureCompactPageState(p, { target });
|
||||
return formatCompactStateSummary(state);
|
||||
} catch {
|
||||
return "[summary unavailable]";
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Screenshot helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Constrain screenshot dimensions for the Anthropic vision API.
|
||||
* Width is capped at 1568px (optimal) and height at 8000px, each
|
||||
* independently, using `fit: "inside"` so aspect ratio is preserved.
|
||||
* Small images are never upscaled.
|
||||
*
|
||||
* `page` parameter is retained for ToolDeps signature stability (D008)
|
||||
* but is no longer used — all processing is server-side via sharp.
|
||||
*/
|
||||
export async function constrainScreenshot(
|
||||
_page: Page,
|
||||
buffer: Buffer,
|
||||
mimeType: string,
|
||||
quality: number,
|
||||
): Promise<Buffer> {
|
||||
const sharp = await getSharp();
|
||||
if (!sharp) return buffer;
|
||||
|
||||
const meta = await sharp(buffer).metadata();
|
||||
const width = meta.width;
|
||||
const height = meta.height;
|
||||
|
||||
if (width === undefined || height === undefined) return buffer;
|
||||
if (width <= MAX_SCREENSHOT_WIDTH && height <= MAX_SCREENSHOT_HEIGHT)
|
||||
return buffer;
|
||||
|
||||
const resizer = sharp(buffer).resize(
|
||||
MAX_SCREENSHOT_WIDTH,
|
||||
MAX_SCREENSHOT_HEIGHT,
|
||||
{
|
||||
fit: "inside",
|
||||
withoutEnlargement: true,
|
||||
},
|
||||
);
|
||||
|
||||
if (mimeType === "image/png") {
|
||||
return Buffer.from(await resizer.png().toBuffer());
|
||||
}
|
||||
return Buffer.from(await resizer.jpeg({ quality }).toBuffer());
|
||||
}
|
||||
|
||||
/** Capture a JPEG screenshot for error debugging. Returns base64 or null. */
|
||||
export async function captureErrorScreenshot(
|
||||
p: Page | null,
|
||||
): Promise<{ data: string; mimeType: string } | null> {
|
||||
if (!p) return null;
|
||||
try {
|
||||
let buf = await p.screenshot({ type: "jpeg", quality: 60, scale: "css" });
|
||||
buf = await constrainScreenshot(p, buf, "image/jpeg", 60);
|
||||
return { data: buf.toString("base64"), mimeType: "image/jpeg" };
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,184 +0,0 @@
|
|||
/**
|
||||
* browser-tools — browser-side evaluate helpers
|
||||
*
|
||||
* Exports a single string constant `EVALUATE_HELPERS_SOURCE` containing an IIFE
|
||||
* that attaches utility functions to `window.__pi`. This is injected into every
|
||||
* new BrowserContext via `context.addInitScript()` so that `page.evaluate()`
|
||||
* callbacks can reference `window.__pi.cssPath(el)` etc. instead of redeclaring
|
||||
* the same functions inline.
|
||||
*
|
||||
* The `simpleHash` function uses the djb2 algorithm identical to
|
||||
* `computeContentHash` / `computeStructuralSignature` in `core.js`.
|
||||
*
|
||||
* Functions provided (9):
|
||||
* cssPath, simpleHash, isVisible, isEnabled, inferRole,
|
||||
* accessibleName, isInteractiveEl, domPath, selectorHints
|
||||
*/
|
||||
|
||||
export const EVALUATE_HELPERS_SOURCE = `(function() {
|
||||
var pi = window.__pi = window.__pi || {};
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// 1. simpleHash — djb2 hash matching core.js computeContentHash
|
||||
// -----------------------------------------------------------------------
|
||||
pi.simpleHash = function simpleHash(str) {
|
||||
if (!str) return "0";
|
||||
var h = 5381;
|
||||
for (var i = 0; i < str.length; i++) {
|
||||
h = ((h << 5) - h + str.charCodeAt(i)) | 0;
|
||||
}
|
||||
return (h >>> 0).toString(16);
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// 2. isVisible
|
||||
// -----------------------------------------------------------------------
|
||||
pi.isVisible = function isVisible(el) {
|
||||
var style = window.getComputedStyle(el);
|
||||
if (style.display === "none" || style.visibility === "hidden") return false;
|
||||
var rect = el.getBoundingClientRect();
|
||||
return rect.width > 0 && rect.height > 0;
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// 3. isEnabled
|
||||
// -----------------------------------------------------------------------
|
||||
pi.isEnabled = function isEnabled(el) {
|
||||
var disabledAttr = el.getAttribute("disabled") !== null;
|
||||
var ariaDisabled = (el.getAttribute("aria-disabled") || "").toLowerCase() === "true";
|
||||
return !disabledAttr && !ariaDisabled;
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// 4. inferRole
|
||||
// -----------------------------------------------------------------------
|
||||
pi.inferRole = function inferRole(el) {
|
||||
var explicit = (el.getAttribute("role") || "").trim();
|
||||
if (explicit) return explicit;
|
||||
var tag = el.tagName.toLowerCase();
|
||||
if (tag === "a" && el.getAttribute("href")) return "link";
|
||||
if (tag === "button") return "button";
|
||||
if (tag === "select") return "combobox";
|
||||
if (tag === "textarea") return "textbox";
|
||||
if (tag === "input") {
|
||||
var type = (el.getAttribute("type") || "text").toLowerCase();
|
||||
if (["button", "submit", "reset"].indexOf(type) !== -1) return "button";
|
||||
if (type === "checkbox") return "checkbox";
|
||||
if (type === "radio") return "radio";
|
||||
if (type === "search") return "searchbox";
|
||||
return "textbox";
|
||||
}
|
||||
return "";
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// 5. accessibleName
|
||||
// -----------------------------------------------------------------------
|
||||
pi.accessibleName = function accessibleName(el) {
|
||||
var ariaLabel = el.getAttribute("aria-label");
|
||||
if (ariaLabel && ariaLabel.trim()) return ariaLabel.trim();
|
||||
var labelledBy = el.getAttribute("aria-labelledby");
|
||||
if (labelledBy && labelledBy.trim()) {
|
||||
var text = labelledBy.trim().split(/\\s+/).map(function(id) {
|
||||
var ref = document.getElementById(id);
|
||||
return ref ? (ref.textContent || "").trim() : "";
|
||||
}).join(" ").trim();
|
||||
if (text) return text;
|
||||
}
|
||||
var placeholder = el.getAttribute("placeholder");
|
||||
if (placeholder && placeholder.trim()) return placeholder.trim();
|
||||
var alt = el.getAttribute("alt");
|
||||
if (alt && alt.trim()) return alt.trim();
|
||||
var value = el.value;
|
||||
if (value && typeof value === "string" && value.trim()) return value.trim().slice(0, 80);
|
||||
return (el.textContent || "").trim().replace(/\\s+/g, " ").slice(0, 80);
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// 6. isInteractiveEl
|
||||
// -----------------------------------------------------------------------
|
||||
var interactiveRoles = {
|
||||
button: 1, link: 1, textbox: 1, searchbox: 1, combobox: 1,
|
||||
checkbox: 1, radio: 1, "switch": 1, menuitem: 1,
|
||||
menuitemcheckbox: 1, menuitemradio: 1, tab: 1, option: 1,
|
||||
slider: 1, spinbutton: 1
|
||||
};
|
||||
pi.isInteractiveEl = function isInteractiveEl(el) {
|
||||
var tag = el.tagName.toLowerCase();
|
||||
var role = pi.inferRole(el);
|
||||
if (["button", "input", "select", "textarea", "summary", "option"].indexOf(tag) !== -1) return true;
|
||||
if (tag === "a" && !!el.getAttribute("href")) return true;
|
||||
if (interactiveRoles[role]) return true;
|
||||
if (el.tabIndex >= 0) return true;
|
||||
if (el.isContentEditable) return true;
|
||||
return false;
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// 7. cssPath
|
||||
// -----------------------------------------------------------------------
|
||||
pi.cssPath = function cssPath(el) {
|
||||
if (el.id) return "#" + CSS.escape(el.id);
|
||||
var parts = [];
|
||||
var current = el;
|
||||
while (current && current.nodeType === Node.ELEMENT_NODE && current !== document.body) {
|
||||
var tag = current.tagName.toLowerCase();
|
||||
var part = tag;
|
||||
var parent = current.parentElement;
|
||||
if (parent) {
|
||||
var siblings = Array.from(parent.children).filter(function(c) {
|
||||
return c.tagName === current.tagName;
|
||||
});
|
||||
if (siblings.length > 1) {
|
||||
var idx = siblings.indexOf(current) + 1;
|
||||
part += ":nth-of-type(" + idx + ")";
|
||||
}
|
||||
}
|
||||
parts.unshift(part);
|
||||
current = current.parentElement;
|
||||
}
|
||||
return "body > " + parts.join(" > ");
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// 8. domPath
|
||||
// -----------------------------------------------------------------------
|
||||
pi.domPath = function domPath(el) {
|
||||
var path = [];
|
||||
var current = el;
|
||||
while (current && current !== document.documentElement) {
|
||||
var parent = current.parentElement;
|
||||
if (!parent) break;
|
||||
var idx = Array.from(parent.children).indexOf(current);
|
||||
path.unshift(idx);
|
||||
current = parent;
|
||||
}
|
||||
return path;
|
||||
};
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// 9. selectorHints
|
||||
// -----------------------------------------------------------------------
|
||||
pi.selectorHints = function selectorHints(el) {
|
||||
var hints = [];
|
||||
if (el.id) hints.push("#" + CSS.escape(el.id));
|
||||
var nameAttr = el.getAttribute("name");
|
||||
if (nameAttr) hints.push(el.tagName.toLowerCase() + '[name="' + CSS.escape(nameAttr) + '"]');
|
||||
var aria = el.getAttribute("aria-label");
|
||||
if (aria) hints.push(el.tagName.toLowerCase() + '[aria-label="' + CSS.escape(aria) + '"]');
|
||||
var placeholder = el.getAttribute("placeholder");
|
||||
if (placeholder) hints.push(el.tagName.toLowerCase() + '[placeholder="' + CSS.escape(placeholder) + '"]');
|
||||
var cls = Array.from(el.classList).slice(0, 2);
|
||||
if (cls.length > 0) hints.push(el.tagName.toLowerCase() + "." + cls.map(function(c) { return CSS.escape(c); }).join("."));
|
||||
hints.push(pi.cssPath(el));
|
||||
var seen = {};
|
||||
var unique = [];
|
||||
for (var i = 0; i < hints.length; i++) {
|
||||
if (!seen[hints[i]]) {
|
||||
seen[hints[i]] = true;
|
||||
unique.push(hints[i]);
|
||||
}
|
||||
}
|
||||
return unique.slice(0, 6);
|
||||
};
|
||||
})();`;
|
||||
|
|
@ -1,262 +0,0 @@
|
|||
/** browser-tools — pi extension: full browser interaction via Playwright. */
|
||||
import {
|
||||
type ExtensionAPI,
|
||||
importExtensionModule,
|
||||
} from "@singularity-forge/pi-coding-agent";
|
||||
|
||||
let registrationPromise: Promise<void> | null = null;
|
||||
|
||||
async function registerBrowserTools(pi: ExtensionAPI): Promise<void> {
|
||||
if (!registrationPromise) {
|
||||
registrationPromise = (async () => {
|
||||
const [
|
||||
lifecycle,
|
||||
capture,
|
||||
settle,
|
||||
refs,
|
||||
utils,
|
||||
navigation,
|
||||
screenshot,
|
||||
interaction,
|
||||
inspection,
|
||||
session,
|
||||
assertions,
|
||||
refTools,
|
||||
wait,
|
||||
pages,
|
||||
forms,
|
||||
intent,
|
||||
pdf,
|
||||
statePersistence,
|
||||
networkMock,
|
||||
device,
|
||||
extract,
|
||||
visualDiff,
|
||||
zoom,
|
||||
codegen,
|
||||
actionCache,
|
||||
injectionDetection,
|
||||
verify,
|
||||
] = await Promise.all([
|
||||
importExtensionModule<typeof import("./lifecycle.js")>(
|
||||
import.meta.url,
|
||||
"./lifecycle.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./capture.js")>(
|
||||
import.meta.url,
|
||||
"./capture.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./settle.js")>(
|
||||
import.meta.url,
|
||||
"./settle.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./refs.js")>(
|
||||
import.meta.url,
|
||||
"./refs.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./utils.js")>(
|
||||
import.meta.url,
|
||||
"./utils.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./tools/navigation.js")>(
|
||||
import.meta.url,
|
||||
"./tools/navigation.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./tools/screenshot.js")>(
|
||||
import.meta.url,
|
||||
"./tools/screenshot.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./tools/interaction.js")>(
|
||||
import.meta.url,
|
||||
"./tools/interaction.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./tools/inspection.js")>(
|
||||
import.meta.url,
|
||||
"./tools/inspection.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./tools/session.js")>(
|
||||
import.meta.url,
|
||||
"./tools/session.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./tools/assertions.js")>(
|
||||
import.meta.url,
|
||||
"./tools/assertions.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./tools/refs.js")>(
|
||||
import.meta.url,
|
||||
"./tools/refs.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./tools/wait.js")>(
|
||||
import.meta.url,
|
||||
"./tools/wait.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./tools/pages.js")>(
|
||||
import.meta.url,
|
||||
"./tools/pages.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./tools/forms.js")>(
|
||||
import.meta.url,
|
||||
"./tools/forms.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./tools/intent.js")>(
|
||||
import.meta.url,
|
||||
"./tools/intent.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./tools/pdf.js")>(
|
||||
import.meta.url,
|
||||
"./tools/pdf.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./tools/state-persistence.js")>(
|
||||
import.meta.url,
|
||||
"./tools/state-persistence.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./tools/network-mock.js")>(
|
||||
import.meta.url,
|
||||
"./tools/network-mock.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./tools/device.js")>(
|
||||
import.meta.url,
|
||||
"./tools/device.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./tools/extract.js")>(
|
||||
import.meta.url,
|
||||
"./tools/extract.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./tools/visual-diff.js")>(
|
||||
import.meta.url,
|
||||
"./tools/visual-diff.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./tools/zoom.js")>(
|
||||
import.meta.url,
|
||||
"./tools/zoom.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./tools/codegen.js")>(
|
||||
import.meta.url,
|
||||
"./tools/codegen.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./tools/action-cache.js")>(
|
||||
import.meta.url,
|
||||
"./tools/action-cache.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./tools/injection-detect.js")>(
|
||||
import.meta.url,
|
||||
"./tools/injection-detect.js",
|
||||
),
|
||||
importExtensionModule<typeof import("./tools/verify.js")>(
|
||||
import.meta.url,
|
||||
"./tools/verify.js",
|
||||
),
|
||||
]);
|
||||
|
||||
const deps = {
|
||||
ensureBrowser: lifecycle.ensureBrowser,
|
||||
closeBrowser: lifecycle.closeBrowser,
|
||||
getActivePage: lifecycle.getActivePage,
|
||||
getActiveTarget: lifecycle.getActiveTarget,
|
||||
getActivePageOrNull: lifecycle.getActivePageOrNull,
|
||||
attachPageListeners: lifecycle.attachPageListeners,
|
||||
captureCompactPageState: capture.captureCompactPageState,
|
||||
postActionSummary: capture.postActionSummary,
|
||||
constrainScreenshot: capture.constrainScreenshot,
|
||||
captureErrorScreenshot: capture.captureErrorScreenshot,
|
||||
formatCompactStateSummary: utils.formatCompactStateSummary,
|
||||
getRecentErrors: utils.getRecentErrors,
|
||||
settleAfterActionAdaptive: settle.settleAfterActionAdaptive,
|
||||
ensureMutationCounter: settle.ensureMutationCounter,
|
||||
buildRefSnapshot: refs.buildRefSnapshot,
|
||||
resolveRefTarget: refs.resolveRefTarget,
|
||||
parseRef: utils.parseRef,
|
||||
formatVersionedRef: utils.formatVersionedRef,
|
||||
staleRefGuidance: utils.staleRefGuidance,
|
||||
beginTrackedAction: utils.beginTrackedAction,
|
||||
finishTrackedAction: utils.finishTrackedAction,
|
||||
truncateText: utils.truncateText,
|
||||
verificationFromChecks: utils.verificationFromChecks,
|
||||
verificationLine: utils.verificationLine,
|
||||
collectAssertionState: (page: any, checks: any, target?: any) =>
|
||||
utils.collectAssertionState(
|
||||
page,
|
||||
checks,
|
||||
capture.captureCompactPageState,
|
||||
target,
|
||||
),
|
||||
formatAssertionText: utils.formatAssertionText,
|
||||
formatDiffText: utils.formatDiffText,
|
||||
getUrlHash: utils.getUrlHash,
|
||||
captureClickTargetState: utils.captureClickTargetState,
|
||||
readInputLikeValue: utils.readInputLikeValue,
|
||||
firstErrorLine: utils.firstErrorLine,
|
||||
captureAccessibilityMarkdown: (selector?: string) =>
|
||||
utils.captureAccessibilityMarkdown(
|
||||
lifecycle.getActiveTarget(),
|
||||
selector,
|
||||
),
|
||||
resolveAccessibilityScope: utils.resolveAccessibilityScope,
|
||||
getLivePagesSnapshot: utils.createGetLivePagesSnapshot(
|
||||
lifecycle.ensureBrowser,
|
||||
),
|
||||
getSinceTimestamp: utils.getSinceTimestamp,
|
||||
getConsoleEntriesSince: utils.getConsoleEntriesSince,
|
||||
getNetworkEntriesSince: utils.getNetworkEntriesSince,
|
||||
writeArtifactFile: utils.writeArtifactFile,
|
||||
copyArtifactFile: utils.copyArtifactFile,
|
||||
ensureSessionArtifactDir: utils.ensureSessionArtifactDir,
|
||||
buildSessionArtifactPath: utils.buildSessionArtifactPath,
|
||||
getSessionArtifactMetadata: utils.getSessionArtifactMetadata,
|
||||
sanitizeArtifactName: utils.sanitizeArtifactName,
|
||||
formatArtifactTimestamp: utils.formatArtifactTimestamp,
|
||||
};
|
||||
|
||||
navigation.registerNavigationTools(pi, deps);
|
||||
screenshot.registerScreenshotTools(pi, deps);
|
||||
interaction.registerInteractionTools(pi, deps);
|
||||
inspection.registerInspectionTools(pi, deps);
|
||||
session.registerSessionTools(pi, deps);
|
||||
assertions.registerAssertionTools(pi, deps);
|
||||
refTools.registerRefTools(pi, deps);
|
||||
wait.registerWaitTools(pi, deps);
|
||||
pages.registerPageTools(pi, deps);
|
||||
forms.registerFormTools(pi, deps);
|
||||
intent.registerIntentTools(pi, deps);
|
||||
pdf.registerPdfTools(pi, deps);
|
||||
statePersistence.registerStatePersistenceTools(pi, deps);
|
||||
networkMock.registerNetworkMockTools(pi, deps);
|
||||
device.registerDeviceTools(pi, deps);
|
||||
extract.registerExtractTools(pi, deps);
|
||||
visualDiff.registerVisualDiffTools(pi, deps);
|
||||
zoom.registerZoomTools(pi, deps);
|
||||
codegen.registerCodegenTools(pi, deps);
|
||||
actionCache.registerActionCacheTools(pi, deps);
|
||||
injectionDetection.registerInjectionDetectionTools(pi, deps);
|
||||
verify.registerVerifyTools(pi, deps);
|
||||
})().catch((error) => {
|
||||
registrationPromise = null;
|
||||
throw error;
|
||||
});
|
||||
}
|
||||
|
||||
return registrationPromise;
|
||||
}
|
||||
|
||||
export default function (pi: ExtensionAPI) {
|
||||
pi.on("session_start", async (_event, ctx) => {
|
||||
if (ctx.hasUI) {
|
||||
void registerBrowserTools(pi).catch((error) => {
|
||||
ctx.ui.notify(
|
||||
`browser-tools failed to load: ${error instanceof Error ? error.message : String(error)}`,
|
||||
"warning",
|
||||
);
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
await registerBrowserTools(pi);
|
||||
});
|
||||
|
||||
pi.on("session_shutdown", async () => {
|
||||
const { closeBrowser } = await importExtensionModule<
|
||||
typeof import("./lifecycle.js")
|
||||
>(import.meta.url, "./lifecycle.js");
|
||||
await closeBrowser();
|
||||
});
|
||||
}
|
||||
|
|
@ -1,292 +0,0 @@
|
|||
/**
|
||||
* browser-tools — browser lifecycle management
|
||||
*
|
||||
* Manages the shared Browser + BrowserContext + Page singleton.
|
||||
* Injects EVALUATE_HELPERS_SOURCE via context.addInitScript() so that
|
||||
* page.evaluate() callbacks can reference window.__pi.* utilities.
|
||||
*/
|
||||
|
||||
import path from "node:path";
|
||||
import type { Browser, BrowserContext, Frame, Page } from "playwright";
|
||||
import {
|
||||
registryAddPage,
|
||||
registryGetActive,
|
||||
registryRemovePage,
|
||||
registrySetActive,
|
||||
} from "./core.js";
|
||||
import { EVALUATE_HELPERS_SOURCE } from "./evaluate-helpers.js";
|
||||
import {
|
||||
getActiveFrame,
|
||||
getBrowser,
|
||||
getConsoleLogs,
|
||||
getContext,
|
||||
getDialogLogs,
|
||||
getNetworkLogs,
|
||||
getPendingCriticalRequestsByPage,
|
||||
HAR_FILENAME,
|
||||
logPusher,
|
||||
type NetworkEntry,
|
||||
pageRegistry,
|
||||
resetAllState,
|
||||
setActiveFrame,
|
||||
setBrowser,
|
||||
setContext,
|
||||
setHarState,
|
||||
} from "./state.js";
|
||||
import {
|
||||
ensureSessionArtifactDir,
|
||||
ensureSessionStartedAt,
|
||||
isCriticalResourceType,
|
||||
updatePendingCriticalRequests,
|
||||
} from "./utils.js";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Page event wiring
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/** Attach all event listeners to a page. Called on initial page and new tabs. */
|
||||
export function attachPageListeners(p: Page, pageId: number): void {
|
||||
const pendingMap = getPendingCriticalRequestsByPage();
|
||||
pendingMap.set(p, 0);
|
||||
|
||||
const consoleLogs = getConsoleLogs();
|
||||
const networkLogs = getNetworkLogs();
|
||||
const dialogLogs = getDialogLogs();
|
||||
|
||||
// Console messages
|
||||
p.on("console", (msg) => {
|
||||
logPusher(consoleLogs, {
|
||||
type: msg.type(),
|
||||
text: msg.text(),
|
||||
timestamp: Date.now(),
|
||||
url: p.url(),
|
||||
pageId,
|
||||
});
|
||||
});
|
||||
|
||||
// Uncaught JS errors
|
||||
p.on("pageerror", (err) => {
|
||||
logPusher(consoleLogs, {
|
||||
type: "pageerror",
|
||||
text: err.message,
|
||||
timestamp: Date.now(),
|
||||
url: p.url(),
|
||||
pageId,
|
||||
});
|
||||
});
|
||||
|
||||
// Network requests — start/completed/failed
|
||||
p.on("request", (request) => {
|
||||
if (isCriticalResourceType(request.resourceType())) {
|
||||
updatePendingCriticalRequests(p, 1);
|
||||
}
|
||||
});
|
||||
|
||||
p.on("requestfinished", async (request) => {
|
||||
if (isCriticalResourceType(request.resourceType())) {
|
||||
updatePendingCriticalRequests(p, -1);
|
||||
}
|
||||
try {
|
||||
const response = await request.response();
|
||||
const status = response?.status() ?? null;
|
||||
const entry: NetworkEntry = {
|
||||
method: request.method(),
|
||||
url: request.url(),
|
||||
status,
|
||||
resourceType: request.resourceType(),
|
||||
timestamp: Date.now(),
|
||||
failed: false,
|
||||
pageId,
|
||||
};
|
||||
if (response && status !== null && status >= 400) {
|
||||
try {
|
||||
const body = await response.text();
|
||||
entry.responseBody = body.slice(0, 2000);
|
||||
} catch {
|
||||
/* non-fatal — response body may be unavailable or already consumed */
|
||||
}
|
||||
}
|
||||
logPusher(networkLogs, entry);
|
||||
} catch {
|
||||
/* non-fatal — request may have been aborted or page closed */
|
||||
}
|
||||
});
|
||||
|
||||
p.on("requestfailed", (request) => {
|
||||
if (isCriticalResourceType(request.resourceType())) {
|
||||
updatePendingCriticalRequests(p, -1);
|
||||
}
|
||||
logPusher(networkLogs, {
|
||||
method: request.method(),
|
||||
url: request.url(),
|
||||
status: null,
|
||||
resourceType: request.resourceType(),
|
||||
timestamp: Date.now(),
|
||||
failed: true,
|
||||
failureText: request.failure()?.errorText ?? "Unknown failure",
|
||||
pageId,
|
||||
});
|
||||
});
|
||||
|
||||
// Auto-handle JS dialogs (alert, confirm, prompt, beforeunload)
|
||||
p.on("dialog", async (dialog) => {
|
||||
logPusher(dialogLogs, {
|
||||
type: dialog.type(),
|
||||
message: dialog.message(),
|
||||
timestamp: Date.now(),
|
||||
url: p.url(),
|
||||
defaultValue: dialog.defaultValue() || undefined,
|
||||
accepted: true,
|
||||
pageId,
|
||||
});
|
||||
// Auto-accept all dialogs to prevent page freezes
|
||||
await dialog.accept().catch(() => {
|
||||
/* cleanup — dialog may already be dismissed */
|
||||
});
|
||||
});
|
||||
|
||||
// Frame detach handler — clears activeFrame if the selected frame detaches
|
||||
p.on("framedetached", (frame) => {
|
||||
if (getActiveFrame() === frame) setActiveFrame(null);
|
||||
});
|
||||
|
||||
// Page close handler — removes page from registry and handles active fallback
|
||||
p.on("close", () => {
|
||||
try {
|
||||
registryRemovePage(pageRegistry, pageId);
|
||||
} catch {
|
||||
// Page already removed (e.g. during closeBrowser)
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Browser lifecycle
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export async function ensureBrowser(): Promise<{
|
||||
browser: Browser;
|
||||
context: BrowserContext;
|
||||
page: Page;
|
||||
}> {
|
||||
const existingBrowser = getBrowser();
|
||||
const existingContext = getContext();
|
||||
if (existingBrowser && existingContext) {
|
||||
return {
|
||||
browser: existingBrowser,
|
||||
context: existingContext,
|
||||
page: getActivePage(),
|
||||
};
|
||||
}
|
||||
|
||||
const _startedAt = ensureSessionStartedAt();
|
||||
const artifactDir = await ensureSessionArtifactDir();
|
||||
const sessionHarPath = path.join(artifactDir, HAR_FILENAME);
|
||||
setHarState({
|
||||
enabled: true,
|
||||
configuredAtContextCreation: true,
|
||||
path: sessionHarPath,
|
||||
exportCount: 0,
|
||||
lastExportedPath: null,
|
||||
lastExportedAt: null,
|
||||
});
|
||||
|
||||
// Lazy import so playwright is only loaded when actually needed
|
||||
const { chromium } = await import("playwright");
|
||||
|
||||
// Auto-detect headless environments: Linux without $DISPLAY has no GUI.
|
||||
// All browser tool operations (navigation, screenshots, DOM) work in headless mode.
|
||||
const needsHeadless = process.platform === "linux" && !process.env.DISPLAY;
|
||||
const launchOptions: Record<string, unknown> = {
|
||||
headless: needsHeadless || process.env.FORCE_HEADLESS === "true",
|
||||
};
|
||||
const customPath = process.env.BROWSER_PATH;
|
||||
if (customPath) launchOptions.executablePath = customPath;
|
||||
const browser = await chromium.launch(launchOptions);
|
||||
const context = await browser.newContext({
|
||||
deviceScaleFactor: 2,
|
||||
viewport: { width: 1280, height: 800 },
|
||||
recordHar: {
|
||||
path: sessionHarPath,
|
||||
mode: "minimal",
|
||||
content: "omit",
|
||||
},
|
||||
});
|
||||
|
||||
// Inject shared browser-side utilities into every new page/frame
|
||||
await context.addInitScript(EVALUATE_HELPERS_SOURCE);
|
||||
|
||||
setBrowser(browser);
|
||||
setContext(context);
|
||||
|
||||
const initialPage = await context.newPage();
|
||||
const pageEntry = registryAddPage(pageRegistry, {
|
||||
page: initialPage,
|
||||
title: await initialPage.title().catch(() => ""),
|
||||
url: initialPage.url(),
|
||||
opener: null,
|
||||
});
|
||||
registrySetActive(pageRegistry, pageEntry.id);
|
||||
attachPageListeners(initialPage, pageEntry.id);
|
||||
|
||||
// Register new pages (popups, target="_blank", window.open) but do NOT auto-switch
|
||||
context.on("page", (newPage) => {
|
||||
// Determine opener page ID — find which registry page opened this one
|
||||
const openerPage = newPage.opener();
|
||||
let openerId: number | null = null;
|
||||
if (openerPage) {
|
||||
const openerEntry = pageRegistry.pages.find(
|
||||
(e: any) => e.page === openerPage,
|
||||
);
|
||||
if (openerEntry) openerId = openerEntry.id;
|
||||
}
|
||||
const entry = registryAddPage(pageRegistry, {
|
||||
page: newPage,
|
||||
title: "",
|
||||
url: newPage.url(),
|
||||
opener: openerId,
|
||||
});
|
||||
attachPageListeners(newPage, entry.id);
|
||||
// Update title once loaded
|
||||
newPage
|
||||
.waitForLoadState("domcontentloaded", { timeout: 5000 })
|
||||
.then(() => newPage.title())
|
||||
.then((title) => {
|
||||
entry.title = title;
|
||||
})
|
||||
.catch(() => {
|
||||
/* best-effort title fetch — page may have closed or navigated away */
|
||||
});
|
||||
});
|
||||
|
||||
return { browser, context, page: getActivePage() };
|
||||
}
|
||||
|
||||
/** Get the currently active page from the registry. */
|
||||
export function getActivePage(): Page {
|
||||
return registryGetActive(pageRegistry).page;
|
||||
}
|
||||
|
||||
/** Get the active target — returns the selected frame if one is active, otherwise the active page. */
|
||||
export function getActiveTarget(): Page | Frame {
|
||||
return getActiveFrame() ?? getActivePage();
|
||||
}
|
||||
|
||||
/** Safe accessor for error handling — returns the active page or null if unavailable. */
|
||||
export function getActivePageOrNull(): Page | null {
|
||||
try {
|
||||
return getActivePage();
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export async function closeBrowser(): Promise<void> {
|
||||
const browser = getBrowser();
|
||||
if (browser) {
|
||||
await browser.close().catch(() => {
|
||||
/* cleanup — browser may already be closed */
|
||||
});
|
||||
}
|
||||
resetAllState();
|
||||
}
|
||||
|
|
@ -1,29 +1,29 @@
|
|||
{
|
||||
"name": "pi-browser-tools",
|
||||
"private": true,
|
||||
"version": "1.0.0",
|
||||
"type": "module",
|
||||
"engines": {
|
||||
"node": ">=24.15.0"
|
||||
},
|
||||
"scripts": {
|
||||
"test": "node --test tests/*.test.mjs"
|
||||
},
|
||||
"pi": {
|
||||
"extensions": [
|
||||
"./index.ts"
|
||||
]
|
||||
},
|
||||
"peerDependencies": {
|
||||
"playwright": ">=1.40.0",
|
||||
"sharp": ">=0.33.0"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"playwright": {
|
||||
"optional": true
|
||||
},
|
||||
"sharp": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
"name": "pi-browser-tools",
|
||||
"private": true,
|
||||
"version": "1.0.0",
|
||||
"type": "module",
|
||||
"engines": {
|
||||
"node": ">=24.15.0"
|
||||
},
|
||||
"scripts": {
|
||||
"test": "node --test tests/*.test.mjs"
|
||||
},
|
||||
"pi": {
|
||||
"extensions": [
|
||||
"./index.js"
|
||||
]
|
||||
},
|
||||
"peerDependencies": {
|
||||
"playwright": ">=1.40.0",
|
||||
"sharp": ">=0.33.0"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"playwright": {
|
||||
"optional": true
|
||||
},
|
||||
"sharp": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,322 +0,0 @@
|
|||
/**
|
||||
* browser-tools — ref snapshot and resolution
|
||||
*
|
||||
* Builds deterministic element snapshots and resolves ref targets.
|
||||
* Uses window.__pi.* utilities injected via addInitScript (from
|
||||
* evaluate-helpers.ts) instead of redeclaring functions inline.
|
||||
*
|
||||
* Functions kept inline (not shared/duplicated):
|
||||
* - matchesMode, computeNearestHeading, computeFormOwnership
|
||||
*/
|
||||
|
||||
import type { Frame, Page } from "playwright";
|
||||
import { getSnapshotModeConfig } from "./core.js";
|
||||
import type { RefNode } from "./state.js";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// buildRefSnapshot
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
 * Build a deterministic snapshot of elements under `target` (a Page or a
 * selected Frame), returning ref nodes without their `ref` ids assigned.
 *
 * Element selection, in priority order inside the evaluate callback:
 *   1. `options.mode` — filter by the resolved snapshot-mode config
 *      (visibleOnly / interactive / container-expanding / tag+role+selector).
 *   2. otherwise `interactiveOnly` — keep only interactive elements.
 *   3. otherwise — all elements, with the scope root prepended.
 *
 * Each surviving element (capped at `options.limit`) is fingerprinted with
 * selector hints, DOM path, content hash, structural signature, nearest
 * heading, and owning form — enough for resolveRefTarget to re-find it later.
 *
 * @param target  Page or Frame to snapshot.
 * @param options selector scope, interactive filter, result cap, and mode name.
 * @returns Snapshot nodes (everything in RefNode except `ref`).
 * @throws If `options.selector` matches nothing in the document.
 */
export async function buildRefSnapshot(
  target: Page | Frame,
  options: {
    selector?: string;
    interactiveOnly: boolean;
    limit: number;
    mode?: string;
  },
): Promise<Array<Omit<RefNode, "ref">>> {
  // Resolve mode config in Node context and serialize it as plain data for the evaluate callback
  const modeConfig = options.mode ? getSnapshotModeConfig(options.mode) : null;
  return await target.evaluate(
    ({ selector, interactiveOnly, limit, modeConfig: mc }) => {
      const root = selector ? document.querySelector(selector) : document.body;
      if (!root) {
        throw new Error(`Selector scope not found: ${selector}`);
      }

      // Use injected window.__pi utilities
      // (installed via context.addInitScript — assumed present on every page)
      const pi = (window as any).__pi;
      const simpleHash = pi.simpleHash;
      const isVisible = pi.isVisible;
      const isEnabled = pi.isEnabled;
      const inferRole = pi.inferRole;
      const accessibleName = pi.accessibleName;
      const isInteractiveEl = pi.isInteractiveEl;
      const cssPath = pi.cssPath;
      const domPath = pi.domPath;
      const selectorHints = pi.selectorHints;

      // Mode-based element matching — used when a snapshot mode config is provided
      const matchesMode = (
        el: Element,
        cfg: {
          tags: string[];
          roles: string[];
          selectors: string[];
          ariaAttributes: string[];
        },
      ): boolean => {
        const tag = el.tagName.toLowerCase();
        if (cfg.tags.length > 0 && cfg.tags.includes(tag)) return true;
        const role = inferRole(el);
        if (cfg.roles.length > 0 && cfg.roles.includes(role)) return true;
        for (const sel of cfg.selectors) {
          try {
            if (el.matches(sel)) return true;
          } catch {
            /* invalid selector, skip */
          }
        }
        for (const attr of cfg.ariaAttributes) {
          if (el.hasAttribute(attr)) return true;
        }
        return false;
      };

      let elements = Array.from(root.querySelectorAll("*"));

      if (mc) {
        // Mode takes precedence over interactiveOnly
        if (mc.visibleOnly) {
          // visible_only mode: include all elements that are visible
          elements = elements.filter((el) => isVisible(el));
        } else if (mc.useInteractiveFilter) {
          // interactive mode: reuse existing isInteractiveEl
          elements = elements.filter((el) => isInteractiveEl(el));
        } else if (mc.containerExpand) {
          // Container-expanding modes (dialog, errors): match containers, then include
          // all interactive children of those containers, plus the containers themselves
          const containers: Element[] = [];
          const directMatches: Element[] = [];
          for (const el of elements) {
            if (matchesMode(el, mc)) {
              // Check if this is a container element (has children)
              const childEls = el.querySelectorAll("*");
              if (childEls.length > 0) {
                containers.push(el);
              } else {
                directMatches.push(el);
              }
            }
          }
          // Collect container elements + all interactive children inside containers
          const result = new Set<Element>(directMatches);
          for (const container of containers) {
            result.add(container);
            const children = Array.from(container.querySelectorAll("*"));
            for (const child of children) {
              if (isInteractiveEl(child)) result.add(child);
            }
          }
          elements = Array.from(result);
        } else {
          // Standard mode filtering by tag/role/selector/ariaAttribute
          elements = elements.filter((el) => matchesMode(el, mc));
        }
      } else if (!interactiveOnly) {
        if (root instanceof Element) elements.unshift(root);
      } else {
        elements = elements.filter((el) => isInteractiveEl(el));
      }

      // Dedupe while preserving document order (Set preserves insertion order).
      const seen = new Set<Element>();
      const unique = elements.filter((el) => {
        if (seen.has(el)) return false;
        seen.add(el);
        return true;
      });

      // Fingerprint helpers — computed for each element in the snapshot
      const computeNearestHeading = (el: Element): string => {
        const headingTags = new Set(["H1", "H2", "H3", "H4", "H5", "H6"]);
        // Walk up ancestors looking for heading or preceding-sibling heading
        let current: Element | null = el;
        while (current && current !== document.body) {
          // Check preceding siblings of current
          let sib: Element | null = current.previousElementSibling;
          while (sib) {
            if (
              headingTags.has(sib.tagName) ||
              sib.getAttribute("role") === "heading"
            ) {
              return (sib.textContent || "")
                .trim()
                .replace(/\s+/g, " ")
                .slice(0, 80);
            }
            sib = sib.previousElementSibling;
          }
          // Check if the parent itself is a heading (unlikely but possible)
          const parent: Element | null = current.parentElement;
          if (
            parent &&
            (headingTags.has(parent.tagName) ||
              parent.getAttribute("role") === "heading")
          ) {
            return (parent.textContent || "")
              .trim()
              .replace(/\s+/g, " ")
              .slice(0, 80);
          }
          current = parent;
        }
        return "";
      };

      const computeFormOwnership = (el: Element): string => {
        // Check form attribute (explicit form association)
        const formAttr = el.getAttribute("form");
        if (formAttr) return formAttr;
        // Walk up ancestors looking for <form>
        let current: Element | null = el.parentElement;
        while (current && current !== document.body) {
          if (current.tagName === "FORM") {
            return (
              (current as HTMLFormElement).id ||
              (current as HTMLFormElement).name ||
              "form"
            );
          }
          current = current.parentElement;
        }
        return "";
      };

      return unique.slice(0, limit).map((el) => {
        const tag = el.tagName.toLowerCase();
        const role = inferRole(el);
        const textContent = (el.textContent || "")
          .trim()
          .replace(/\s+/g, " ")
          .slice(0, 200);
        const childTags = Array.from(el.children).map((c) =>
          c.tagName.toLowerCase(),
        );

        return {
          tag,
          role,
          name: accessibleName(el),
          selectorHints: selectorHints(el),
          isVisible: isVisible(el),
          isEnabled: isEnabled(el),
          xpathOrPath: cssPath(el),
          href: el.getAttribute("href") || undefined,
          type: el.getAttribute("type") || undefined,
          path: domPath(el),
          contentHash: simpleHash(textContent),
          structuralSignature: simpleHash(
            `${tag}|${role}|${childTags.join(",")}`,
          ),
          nearestHeading: computeNearestHeading(el),
          formOwnership: computeFormOwnership(el),
        };
      });
    },
    { ...options, modeConfig },
  );
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// resolveRefTarget
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
 * Resolve a previously-snapshotted ref node back to a live element in the
 * current DOM, returning a CSS selector for it.
 *
 * Resolution tiers, tried in order (first hit wins):
 *   1. Stored child-index path, if the tag at that path still matches.
 *   2. Stored selector hints, filtered by tag.
 *   3. role + accessible-name match across all same-tag elements.
 *   4. content-hash + structural-signature fingerprint (fails as ambiguous
 *      when more than one element matches).
 *
 * @param target Page or Frame to resolve against.
 * @param node   Snapshot node produced by buildRefSnapshot.
 * @returns `{ ok: true, selector }` on success, `{ ok: false, reason }` when
 *          the element cannot be (unambiguously) re-found.
 */
export async function resolveRefTarget(
  target: Page | Frame,
  node: RefNode,
): Promise<{ ok: true; selector: string } | { ok: false; reason: string }> {
  return await target.evaluate((refNode) => {
    // Use injected window.__pi utilities
    const pi = (window as any).__pi;
    const cssPath = pi.cssPath;
    const simpleHash = pi.simpleHash;

    // Walk the stored child-index path from the document root.
    const byPath = (): Element | null => {
      let current: Element | null = document.documentElement;
      for (const idx of refNode.path || []) {
        if (!current || idx < 0 || idx >= current.children.length) return null;
        current = current.children[idx] as Element;
      }
      return current;
    };

    // Best-effort accessible name, mirroring the priority used at snapshot time.
    const nodeName = (el: Element): string => {
      return (
        el.getAttribute("aria-label")?.trim() ||
        (el as HTMLInputElement).value?.trim() ||
        el.getAttribute("placeholder")?.trim() ||
        (el.textContent || "").trim().replace(/\s+/g, " ").slice(0, 80)
      );
    };

    // Tier 1: path-based resolution
    const pathEl = byPath();
    if (pathEl && pathEl.tagName.toLowerCase() === refNode.tag) {
      return { ok: true as const, selector: cssPath(pathEl) };
    }

    // Tier 2: selector hints
    for (const hint of refNode.selectorHints || []) {
      try {
        const el = document.querySelector(hint);
        if (!el) continue;
        if (el.tagName.toLowerCase() !== refNode.tag) continue;
        return { ok: true as const, selector: cssPath(el) };
      } catch {
        // ignore malformed selector hint
      }
    }

    // Tier 3: role + name match
    // (name is required for a match; role only constrains when recorded)
    const candidates = Array.from(document.querySelectorAll(refNode.tag));
    const matchTarget = candidates.find((el) => {
      const role = el.getAttribute("role") || "";
      const name = nodeName(el);
      const roleMatch = !refNode.role || role === refNode.role;
      const nameMatch =
        !!refNode.name && name.toLowerCase() === refNode.name.toLowerCase();
      return roleMatch && nameMatch;
    });
    if (matchTarget) {
      return { ok: true as const, selector: cssPath(matchTarget) };
    }

    // Tier 4: structural signature + content hash fingerprint matching
    if (refNode.contentHash && refNode.structuralSignature) {
      const fpMatches: Element[] = [];
      for (const candidate of candidates) {
        const tag = candidate.tagName.toLowerCase();
        const role = candidate.getAttribute("role") || "";
        const textContent = (candidate.textContent || "")
          .trim()
          .replace(/\s+/g, " ")
          .slice(0, 200);
        const childTags = Array.from(candidate.children).map((c) =>
          c.tagName.toLowerCase(),
        );
        const candidateContentHash = simpleHash(textContent);
        const candidateStructSig = simpleHash(
          `${tag}|${role}|${childTags.join(",")}`,
        );
        if (
          candidateContentHash === refNode.contentHash &&
          candidateStructSig === refNode.structuralSignature
        ) {
          fpMatches.push(candidate);
        }
      }
      if (fpMatches.length === 1) {
        return { ok: true as const, selector: cssPath(fpMatches[0]) };
      }
      if (fpMatches.length > 1) {
        return {
          ok: false as const,
          reason: "multiple fingerprint matches — ambiguous",
        };
      }
    }

    return { ok: false as const, reason: "element not found in current DOM" };
  }, node);
}
|
||||
|
|
@ -1,219 +0,0 @@
|
|||
/**
|
||||
* browser-tools — DOM settle logic
|
||||
*
|
||||
* Adaptive settling after browser actions. Polls for DOM quiet (mutation
|
||||
* counter stable, no pending critical requests, optional focus stability)
|
||||
* before returning control.
|
||||
*/
|
||||
|
||||
import type { Frame, Page } from "playwright";
|
||||
import type { AdaptiveSettleDetails, AdaptiveSettleOptions } from "./state.js";
|
||||
import { getPendingCriticalRequests } from "./utils.js";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Mutation counter (installed in-page via evaluate)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export async function ensureMutationCounter(p: Page): Promise<void> {
|
||||
await p.evaluate(() => {
|
||||
const key = "__piMutationCounter" as const;
|
||||
const installedKey = "__piMutationCounterInstalled" as const;
|
||||
const w = window as unknown as Record<string, unknown>;
|
||||
if (typeof w[key] !== "number") w[key] = 0;
|
||||
if (w[installedKey]) return;
|
||||
const observer = new MutationObserver(() => {
|
||||
const current = typeof w[key] === "number" ? (w[key] as number) : 0;
|
||||
w[key] = current + 1;
|
||||
});
|
||||
observer.observe(document.documentElement || document.body, {
|
||||
subtree: true,
|
||||
childList: true,
|
||||
attributes: true,
|
||||
characterData: true,
|
||||
});
|
||||
w[installedKey] = true;
|
||||
});
|
||||
}
|
||||
|
||||
export async function readMutationCounter(p: Page): Promise<number> {
|
||||
try {
|
||||
return await p.evaluate(() => {
|
||||
const w = window as unknown as Record<string, unknown>;
|
||||
const value = w.__piMutationCounter;
|
||||
return typeof value === "number" ? value : 0;
|
||||
});
|
||||
} catch {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Focus descriptor (for focus-stability checks)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export async function readFocusedDescriptor(
|
||||
target: Page | Frame,
|
||||
): Promise<string> {
|
||||
try {
|
||||
return await target.evaluate(() => {
|
||||
const el = document.activeElement as HTMLElement | null;
|
||||
if (!el || el === document.body || el === document.documentElement)
|
||||
return "";
|
||||
const id = el.id ? `#${el.id}` : "";
|
||||
const role = el.getAttribute("role") || "";
|
||||
const name = (
|
||||
el.getAttribute("aria-label") ||
|
||||
el.getAttribute("name") ||
|
||||
""
|
||||
).trim();
|
||||
return `${el.tagName.toLowerCase()}${id}|${role}|${name}`;
|
||||
});
|
||||
} catch {
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Combined settle-state reader (mutation counter + focus in one evaluate)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Reads the mutation counter and optionally the focused element descriptor
|
||||
* in a single `evaluate()` call, saving one round-trip per poll iteration.
|
||||
*/
|
||||
async function readSettleState(
|
||||
target: Page | Frame,
|
||||
checkFocus: boolean,
|
||||
): Promise<{ mutationCount: number; focusDescriptor: string }> {
|
||||
try {
|
||||
return await target.evaluate((wantFocus: boolean) => {
|
||||
const w = window as unknown as Record<string, unknown>;
|
||||
const mutationCount =
|
||||
typeof w.__piMutationCounter === "number"
|
||||
? (w.__piMutationCounter as number)
|
||||
: 0;
|
||||
if (!wantFocus) return { mutationCount, focusDescriptor: "" };
|
||||
const el = document.activeElement as HTMLElement | null;
|
||||
if (!el || el === document.body || el === document.documentElement) {
|
||||
return { mutationCount, focusDescriptor: "" };
|
||||
}
|
||||
const id = el.id ? `#${el.id}` : "";
|
||||
const role = el.getAttribute("role") || "";
|
||||
const name = (
|
||||
el.getAttribute("aria-label") ||
|
||||
el.getAttribute("name") ||
|
||||
""
|
||||
).trim();
|
||||
return {
|
||||
mutationCount,
|
||||
focusDescriptor: `${el.tagName.toLowerCase()}${id}|${role}|${name}`,
|
||||
};
|
||||
}, checkFocus);
|
||||
} catch {
|
||||
return { mutationCount: 0, focusDescriptor: "" };
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Adaptive settle
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// Adaptive-settle tuning constants (consumers are below this point in the file).
/** Threshold (ms) after which zero mutations triggers a shortened quiet window. */
const ZERO_MUTATION_THRESHOLD_MS = 60;
/** Shortened quiet window when no mutations have been observed. */
const ZERO_MUTATION_QUIET_MS = 30;
|
||||
|
||||
export async function settleAfterActionAdaptive(
|
||||
p: Page,
|
||||
opts: AdaptiveSettleOptions = {},
|
||||
): Promise<AdaptiveSettleDetails> {
|
||||
const timeoutMs = Math.max(150, opts.timeoutMs ?? 500);
|
||||
const pollMs = Math.min(100, Math.max(20, opts.pollMs ?? 40));
|
||||
const baseQuietWindowMs = Math.max(60, opts.quietWindowMs ?? 100);
|
||||
const checkFocus = opts.checkFocusStability ?? false;
|
||||
|
||||
const startedAt = Date.now();
|
||||
let polls = 0;
|
||||
let sawUrlChange = false;
|
||||
let lastActivityAt = startedAt;
|
||||
let previousUrl = p.url();
|
||||
let totalMutationsSeen = 0;
|
||||
let activeQuietWindowMs = baseQuietWindowMs;
|
||||
|
||||
// Install mutation counter + read initial state in one evaluate sequence.
|
||||
// ensureMutationCounter must run first (installs the observer), then we
|
||||
// read the baseline via the combined reader.
|
||||
await ensureMutationCounter(p).catch((e) => {
|
||||
if (process.env.SF_DEBUG)
|
||||
console.error("[browser-tools] ensureMutationCounter failed:", e.message);
|
||||
});
|
||||
const initial = await readSettleState(p, checkFocus);
|
||||
let previousMutationCount = initial.mutationCount;
|
||||
let previousFocus = initial.focusDescriptor;
|
||||
|
||||
while (Date.now() - startedAt < timeoutMs) {
|
||||
await new Promise((resolve) => setTimeout(resolve, pollMs));
|
||||
polls += 1;
|
||||
const now = Date.now();
|
||||
|
||||
const currentUrl = p.url();
|
||||
if (currentUrl !== previousUrl) {
|
||||
sawUrlChange = true;
|
||||
previousUrl = currentUrl;
|
||||
lastActivityAt = now;
|
||||
}
|
||||
|
||||
// Single combined evaluate for mutation count + focus descriptor.
|
||||
const state = await readSettleState(p, checkFocus);
|
||||
|
||||
if (state.mutationCount > previousMutationCount) {
|
||||
totalMutationsSeen += state.mutationCount - previousMutationCount;
|
||||
previousMutationCount = state.mutationCount;
|
||||
lastActivityAt = now;
|
||||
}
|
||||
|
||||
if (checkFocus && state.focusDescriptor !== previousFocus) {
|
||||
previousFocus = state.focusDescriptor;
|
||||
lastActivityAt = now;
|
||||
}
|
||||
|
||||
const pendingCritical = getPendingCriticalRequests(p);
|
||||
if (pendingCritical > 0) {
|
||||
lastActivityAt = now;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Zero-mutation short-circuit: after ZERO_MUTATION_THRESHOLD_MS with
|
||||
// no mutations observed at all, reduce the quiet window to settle faster.
|
||||
if (
|
||||
totalMutationsSeen === 0 &&
|
||||
now - startedAt >= ZERO_MUTATION_THRESHOLD_MS &&
|
||||
activeQuietWindowMs !== ZERO_MUTATION_QUIET_MS
|
||||
) {
|
||||
activeQuietWindowMs = ZERO_MUTATION_QUIET_MS;
|
||||
}
|
||||
|
||||
if (now - lastActivityAt >= activeQuietWindowMs) {
|
||||
const usedShortcut =
|
||||
activeQuietWindowMs === ZERO_MUTATION_QUIET_MS &&
|
||||
totalMutationsSeen === 0;
|
||||
return {
|
||||
settleMode: "adaptive",
|
||||
settleMs: now - startedAt,
|
||||
settleReason: usedShortcut
|
||||
? "zero_mutation_shortcut"
|
||||
: sawUrlChange
|
||||
? "url_changed_then_quiet"
|
||||
: "dom_quiet",
|
||||
settlePolls: polls,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
settleMode: "adaptive",
|
||||
settleMs: Date.now() - startedAt,
|
||||
settleReason: "timeout_fallback",
|
||||
settlePolls: polls,
|
||||
};
|
||||
}
|
||||
|
|
@ -1,535 +0,0 @@
|
|||
/**
|
||||
* browser-tools — shared mutable state
|
||||
*
|
||||
* All mutable state lives behind accessor functions (get/set) so that
|
||||
* jiti-transpiled modules see updates reliably. ES module live bindings
|
||||
* (`export let`) are not guaranteed to work under jiti's CJS shim layer.
|
||||
*
|
||||
* State is initialized to sensible defaults and can be bulk-reset via
|
||||
* `resetAllState()` (called by closeBrowser).
|
||||
*/
|
||||
|
||||
import path from "node:path";
|
||||
import type { Browser, BrowserContext, Frame, Page } from "playwright";
|
||||
import {
|
||||
createActionTimeline,
|
||||
createBoundedLogPusher,
|
||||
createPageRegistry,
|
||||
} from "./core.js";
|
||||
|
||||
// ---------------------------------------------------------------------------
// Constants
// ---------------------------------------------------------------------------

// Root directory for browser-session artifacts, resolved against the process
// working directory (.artifacts/browser).
export const ARTIFACT_ROOT = path.resolve(
  process.cwd(),
  ".artifacts",
  "browser",
);
// Fixed filename used when recording a HAR capture for the session.
export const HAR_FILENAME = "session.har";
|
||||
|
||||
// ---------------------------------------------------------------------------
// Type / interface definitions
// ---------------------------------------------------------------------------

/** One captured console message, tagged with the page it came from. */
export interface ConsoleEntry {
  type: string; // console method name (e.g. "log", "error") — presumably from the page listener
  text: string;
  timestamp: number; // capture time in ms (Date.now()-style, per usage elsewhere in the module)
  url: string; // page URL at capture time
  pageId: number; // id of the page in pageRegistry
}

/** One captured network request (response status or failure recorded). */
export interface NetworkEntry {
  method: string;
  url: string;
  status: number | null; // null when no response was received
  resourceType: string;
  timestamp: number;
  failed: boolean;
  failureText?: string; // present only for failed requests
  responseBody?: string; // optionally captured response body text
  pageId: number;
}

/** One captured JS dialog (alert/confirm/prompt) and how it was handled. */
export interface DialogEntry {
  type: string;
  message: string;
  timestamp: number;
  url: string;
  defaultValue?: string; // prompt default value, when applicable
  accepted: boolean; // whether the dialog was accepted or dismissed
  pageId: number;
}

/** One element in a ref snapshot: a stable handle plus re-resolution hints. */
export interface RefNode {
  ref: string; // versioned ref key handed out to callers
  tag: string;
  role: string;
  name: string; // accessible/display name — TODO confirm exact source
  selectorHints: string[];
  isVisible: boolean;
  isEnabled: boolean;
  xpathOrPath: string;
  href?: string; // links only
  type?: string; // inputs only
  path: number[]; // structural child-index path — presumably used for re-resolution
  contentHash?: string;
  structuralSignature?: string;
  nearestHeading?: string; // nearby heading text for context
  formOwnership?: string; // owning form, when the element belongs to one
}

/** Metadata describing when and how the current ref snapshot was built. */
export interface RefMetadata {
  url: string;
  timestamp: number;
  selectorScope?: string; // scoping selector, when the snapshot was scoped
  interactiveOnly: boolean;
  limit: number;
  version: number; // matches the module's refVersion counter
  frameContext?: string;
  mode?: string;
}

/** Per-selector state captured for compact page-state comparison. */
export interface CompactSelectorState {
  exists: boolean;
  visible: boolean;
  value: string;
  checked: boolean | null; // null for non-checkable elements
  text: string;
}

/** A compact, diffable summary of a page for before/after comparison. */
export interface CompactPageState {
  url: string;
  title: string;
  focus: string;
  headings: string[];
  bodyText: string;
  counts: {
    landmarks: number;
    buttons: number;
    links: number;
    inputs: number;
  };
  dialog: {
    count: number;
    title: string;
  };
  selectorStates: Record<string, CompactSelectorState>;
}

/** An in-progress tracing session. */
export interface TraceSessionState {
  startedAt: number;
  name: string;
  title?: string;
  path?: string; // output path, once known
}

/** HAR-recording configuration and export bookkeeping for the session. */
export interface HarState {
  enabled: boolean;
  configuredAtContextCreation: boolean; // presumably HAR must be enabled at context creation — verify against lifecycle code
  path: string | null;
  exportCount: number;
  lastExportedPath: string | null;
  lastExportedAt: number | null;
}

/** Attribute snapshot of a click target, used to detect state toggles. */
export interface ClickTargetStateSnapshot {
  exists: boolean;
  ariaExpanded: string | null;
  ariaPressed: string | null;
  ariaSelected: string | null;
  open: boolean | null; // open state (e.g. <details>/<dialog>) — null when not applicable
}

/** A single named verification check with its outcome. */
export interface BrowserVerificationCheck {
  name: string;
  passed: boolean;
  value?: unknown; // observed value
  expected?: unknown; // expected value, when the check has one
}

/** Aggregate result of a set of verification checks. */
export interface BrowserVerificationResult {
  verified: boolean;
  checks: BrowserVerificationCheck[];
  verificationSummary: string;
  retryHint?: string; // guidance for retrying when verification failed
}

/** Tuning knobs for settleAfterActionAdaptive; the implementation clamps all values. */
export interface AdaptiveSettleOptions {
  timeoutMs?: number; // overall budget; clamped to >= 150
  pollMs?: number; // poll interval; clamped to [20, 100]
  quietWindowMs?: number; // required quiet period; clamped to >= 60
  checkFocusStability?: boolean; // also treat focus changes as activity
}

/** Outcome details returned by settleAfterActionAdaptive. */
export interface AdaptiveSettleDetails {
  settleMode: "adaptive";
  settleMs: number; // elapsed ms until the settle loop ended
  settleReason:
    | "dom_quiet"
    | "url_changed_then_quiet"
    | "timeout_fallback"
    | "zero_mutation_shortcut";
  settlePolls: number; // number of poll iterations performed
}

/** A parsed ref argument: bare key plus optional snapshot version. */
export interface ParsedRefSpec {
  key: string;
  version: number | null; // null when the input carried no version
  display: string; // human-readable form for messages
}

/** One assertion requested via the browser_assert tool. */
export interface BrowserAssertionCheckInput {
  kind: string; // e.g. url_contains, text_visible, selector_visible, value_equals, …
  selector?: string;
  text?: string;
  value?: string;
  checked?: boolean;
  sinceActionId?: number; // restrict log-based checks to entries after this action
}
|
||||
|
||||
// ---------------------------------------------------------------------------
// Mutable state variables — accessed only via get/set functions
// ---------------------------------------------------------------------------
//
// Pattern: each piece of state is a module-private `_x` variable plus
// exported getX()/setX() accessors (see the module header — `export let`
// live bindings are unreliable under jiti's CJS shim). Stateful helper
// objects (pageRegistry, actionTimeline, logPusher) are exported directly
// because their internals are mutated in place rather than reassigned.

// 1. browser
let _browser: Browser | null = null;
/** Current Playwright Browser, or null when no session is open. */
export function getBrowser(): Browser | null {
  return _browser;
}
export function setBrowser(b: Browser | null): void {
  _browser = b;
}

// 2. context
let _context: BrowserContext | null = null;
/** Current BrowserContext, or null when no session is open. */
export function getContext(): BrowserContext | null {
  return _context;
}
export function setContext(c: BrowserContext | null): void {
  _context = c;
}

// 3. pageRegistry (object with internal state — export the instance directly + getter)
export const pageRegistry = createPageRegistry();
export function getPageRegistry() {
  return pageRegistry;
}

// 4. activeFrame
let _activeFrame: Frame | null = null;
/** Frame currently targeted by frame-scoped tools, or null for the page. */
export function getActiveFrame(): Frame | null {
  return _activeFrame;
}
export function setActiveFrame(f: Frame | null): void {
  _activeFrame = f;
}

// 5. logPusher (bounded log push function — stateless utility, export directly)
export const logPusher = createBoundedLogPusher(1000);

// 6. consoleLogs
let _consoleLogs: ConsoleEntry[] = [];
export function getConsoleLogs(): ConsoleEntry[] {
  return _consoleLogs;
}
export function setConsoleLogs(logs: ConsoleEntry[]): void {
  _consoleLogs = logs;
}

// 7. networkLogs
let _networkLogs: NetworkEntry[] = [];
export function getNetworkLogs(): NetworkEntry[] {
  return _networkLogs;
}
export function setNetworkLogs(logs: NetworkEntry[]): void {
  _networkLogs = logs;
}

// 8. dialogLogs
let _dialogLogs: DialogEntry[] = [];
export function getDialogLogs(): DialogEntry[] {
  return _dialogLogs;
}
export function setDialogLogs(logs: DialogEntry[]): void {
  _dialogLogs = logs;
}

// 9. pendingCriticalRequestsByPage (WeakMap — can't be reassigned, just cleared by replacing)
let _pendingCriticalRequestsByPage = new WeakMap<Page, number>();
export function getPendingCriticalRequestsByPage(): WeakMap<Page, number> {
  return _pendingCriticalRequestsByPage;
}
// WeakMap has no clear(); "reset" means swapping in a fresh instance.
export function resetPendingCriticalRequestsByPage(): void {
  _pendingCriticalRequestsByPage = new WeakMap();
}

// 10. currentRefMap
let _currentRefMap: Record<string, RefNode> = {};
export function getCurrentRefMap(): Record<string, RefNode> {
  return _currentRefMap;
}
export function setCurrentRefMap(m: Record<string, RefNode>): void {
  _currentRefMap = m;
}

// 11. refVersion
let _refVersion = 0;
export function getRefVersion(): number {
  return _refVersion;
}
export function setRefVersion(v: number): void {
  _refVersion = v;
}

// 12. refMetadata
let _refMetadata: RefMetadata | null = null;
export function getRefMetadata(): RefMetadata | null {
  return _refMetadata;
}
export function setRefMetadata(m: RefMetadata | null): void {
  _refMetadata = m;
}

// 13. actionTimeline (object with internal state)
export const actionTimeline = createActionTimeline(60);
export function getActionTimeline() {
  return actionTimeline;
}

// 14. lastActionBeforeState
let _lastActionBeforeState: CompactPageState | null = null;
export function getLastActionBeforeState(): CompactPageState | null {
  return _lastActionBeforeState;
}
export function setLastActionBeforeState(s: CompactPageState | null): void {
  _lastActionBeforeState = s;
}

// 15. lastActionAfterState
let _lastActionAfterState: CompactPageState | null = null;
export function getLastActionAfterState(): CompactPageState | null {
  return _lastActionAfterState;
}
export function setLastActionAfterState(s: CompactPageState | null): void {
  _lastActionAfterState = s;
}

// 16. sessionStartedAt
let _sessionStartedAt: number | null = null;
export function getSessionStartedAt(): number | null {
  return _sessionStartedAt;
}
export function setSessionStartedAt(t: number | null): void {
  _sessionStartedAt = t;
}

// 17. sessionArtifactDir
let _sessionArtifactDir: string | null = null;
export function getSessionArtifactDir(): string | null {
  return _sessionArtifactDir;
}
export function setSessionArtifactDir(d: string | null): void {
  _sessionArtifactDir = d;
}

// 18a. activeTraceSession
let _activeTraceSession: TraceSessionState | null = null;
export function getActiveTraceSession(): TraceSessionState | null {
  return _activeTraceSession;
}
export function setActiveTraceSession(t: TraceSessionState | null): void {
  _activeTraceSession = t;
}

// 18b. harState
// Default is spread-copied on init and reset so DEFAULT_HAR_STATE itself is
// never mutated.
const DEFAULT_HAR_STATE: HarState = {
  enabled: false,
  configuredAtContextCreation: false,
  path: null,
  exportCount: 0,
  lastExportedPath: null,
  lastExportedAt: null,
};
let _harState: HarState = { ...DEFAULT_HAR_STATE };
export function getHarState(): HarState {
  return _harState;
}
export function setHarState(h: HarState): void {
  _harState = h;
}
|
||||
|
||||
// ---------------------------------------------------------------------------
// resetAllState — mirrors closeBrowser()'s reset logic
// ---------------------------------------------------------------------------

/**
 * Restores every piece of module state to its initial value so a later
 * session starts clean. Keep this in sync with the state variables above:
 * every `_x` declared in this module must be reset here.
 */
export function resetAllState(): void {
  _browser = null;
  _context = null;
  // pageRegistry / actionTimeline are shared exported instances: reset their
  // fields in place rather than reassigning the exported binding.
  pageRegistry.pages = [];
  pageRegistry.activePageId = null;
  pageRegistry.nextId = 1;
  _activeFrame = null;
  _consoleLogs = [];
  _networkLogs = [];
  _dialogLogs = [];
  // WeakMaps cannot be cleared; swap in a fresh instance.
  _pendingCriticalRequestsByPage = new WeakMap();
  _currentRefMap = {};
  _refVersion = 0;
  _refMetadata = null;
  _lastActionBeforeState = null;
  _lastActionAfterState = null;
  actionTimeline.entries = [];
  actionTimeline.nextId = 1;
  _sessionStartedAt = null;
  _sessionArtifactDir = null;
  _activeTraceSession = null;
  // Copy so later mutation of _harState can't corrupt the default.
  _harState = { ...DEFAULT_HAR_STATE };
}
|
||||
|
||||
// ---------------------------------------------------------------------------
// ToolDeps — interface that tool registration functions consume
// ---------------------------------------------------------------------------

/**
 * Bundles the infrastructure functions that tool registration files need.
 * Built once in the index.ts orchestrator and passed to each register* function.
 *
 * Several members derive their types with `ReturnType<typeof import("./core.js").…>`
 * so the interface stays in sync with core.js without restating its shapes.
 */
export interface ToolDeps {
  // Lifecycle
  ensureBrowser: () => Promise<{
    browser: Browser;
    context: BrowserContext;
    page: Page;
  }>;
  closeBrowser: () => Promise<void>;
  getActivePage: () => Page;
  // Active frame when one is selected, otherwise the active page.
  getActiveTarget: () => Page | Frame;
  getActivePageOrNull: () => Page | null;

  // Page event wiring
  attachPageListeners: (p: Page, pageId: number) => void;

  // Capture & summary
  captureCompactPageState: (
    p: Page,
    options?: {
      selectors?: string[];
      includeBodyText?: boolean;
      target?: Page | Frame;
    },
  ) => Promise<CompactPageState>;
  postActionSummary: (p: Page, target?: Page | Frame) => Promise<string>;
  formatCompactStateSummary: (state: CompactPageState) => string;
  constrainScreenshot: (
    page: Page,
    buffer: Buffer,
    mimeType: string,
    quality: number,
  ) => Promise<Buffer>;
  // Best-effort screenshot for error reporting; null when capture fails.
  captureErrorScreenshot: (
    p: Page | null,
  ) => Promise<{ data: string; mimeType: string } | null>;
  getRecentErrors: (pageUrl: string) => string;

  // Settle
  settleAfterActionAdaptive: (
    p: Page,
    opts?: AdaptiveSettleOptions,
  ) => Promise<AdaptiveSettleDetails>;
  ensureMutationCounter: (p: Page) => Promise<void>;

  // Refs
  // Returns nodes without their "ref" key; keys are assigned by the caller.
  buildRefSnapshot: (
    target: Page | Frame,
    options: {
      selector?: string;
      interactiveOnly: boolean;
      limit: number;
      mode?: string;
    },
  ) => Promise<Array<Omit<RefNode, "ref">>>;
  resolveRefTarget: (
    target: Page | Frame,
    node: RefNode,
  ) => Promise<{ ok: true; selector: string } | { ok: false; reason: string }>;
  parseRef: (input: string) => ParsedRefSpec;
  formatVersionedRef: (version: number, key: string) => string;
  staleRefGuidance: (refDisplay: string, reason: string) => string;

  // Action tracking
  beginTrackedAction: (
    tool: string,
    params: unknown,
    beforeUrl: string,
  ) => ReturnType<typeof import("./core.js").beginAction>;
  finishTrackedAction: (
    actionId: number,
    updates: {
      status: "success" | "error";
      afterUrl?: string;
      verificationSummary?: string;
      warningSummary?: string;
      diffSummary?: string;
      changed?: boolean;
      error?: string;
      beforeState?: CompactPageState;
      afterState?: CompactPageState;
    },
  ) => ReturnType<typeof import("./core.js").finishAction>;

  // Utilities (forwarded from utils.ts)
  truncateText: (text: string) => string;
  verificationFromChecks: (
    checks: BrowserVerificationCheck[],
    retryHint?: string,
  ) => BrowserVerificationResult;
  verificationLine: (verification: BrowserVerificationResult) => string;
  collectAssertionState: (
    p: Page,
    checks: BrowserAssertionCheckInput[],
    target?: Page | Frame,
  ) => Promise<Record<string, unknown>>;
  formatAssertionText: (
    result: ReturnType<typeof import("./core.js").evaluateAssertionChecks>,
  ) => string;
  formatDiffText: (
    diff: ReturnType<typeof import("./core.js").diffCompactStates>,
  ) => string;
  getUrlHash: (url: string) => string;
  captureClickTargetState: (
    target: Page | Frame,
    selector: string,
  ) => Promise<ClickTargetStateSnapshot>;
  readInputLikeValue: (
    target: Page | Frame,
    selector?: string,
  ) => Promise<string | null>;
  firstErrorLine: (err: unknown) => string;
  captureAccessibilityMarkdown: (
    selector?: string,
  ) => Promise<{ snapshot: string; scope: string; source: string }>;
  resolveAccessibilityScope: (
    selector?: string,
  ) => Promise<{ selector?: string; scope: string; source: string }>;
  getLivePagesSnapshot: () => Promise<
    ReturnType<typeof import("./core.js").registryListPages>
  >;
  getSinceTimestamp: (sinceActionId?: number) => number;
  getConsoleEntriesSince: (sinceActionId?: number) => ConsoleEntry[];
  getNetworkEntriesSince: (sinceActionId?: number) => NetworkEntry[];
  writeArtifactFile: (
    filePath: string,
    content: string | Uint8Array,
  ) => Promise<{ path: string; bytes: number }>;
  copyArtifactFile: (
    sourcePath: string,
    destinationPath: string,
  ) => Promise<{ path: string; bytes: number }>;
  ensureSessionArtifactDir: () => Promise<string>;
  buildSessionArtifactPath: (filename: string) => string;
  getSessionArtifactMetadata: () => Record<string, unknown>;
  sanitizeArtifactName: (value: string, fallback: string) => string;
  formatArtifactTimestamp: (timestamp: number) => string;
}
|
||||
|
|
@ -1,270 +0,0 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import type { ToolDeps } from "../state.js";
|
||||
|
||||
/**
 * Action caching — cache semantic intent → selector mappings to skip LLM inference on repeat visits.
 * Internal optimization that hooks into browser_find_best / browser_act.
 */

/** A cached intent → selector mapping plus bookkeeping for eviction/stats. */
interface CacheEntry {
  selector: string; // resolved CSS selector
  score: number; // confidence supplied at store time (defaults to 1.0 in 'put')
  url: string; // page URL at store time (un-normalized)
  domHash: string; // structural DOM hash at store time (see computeDomHash)
  timestamp: number; // store time; oldest entry is evicted first
  hitCount: number; // successful lookups; incremented by the 'get' action
}

// In-memory cache keyed by `normalizedUrl|domHash|intent` (see buildCacheKey),
// bounded to MAX_CACHE_SIZE entries via oldest-by-timestamp eviction on insert.
const cache = new Map<string, CacheEntry>();
const MAX_CACHE_SIZE = 200;
|
||||
|
||||
/**
 * Registers the browser_action_cache tool, which exposes the module-level
 * intent → selector cache to the model: 'stats' reports metrics, 'get' looks
 * up (and validates) a cached selector, 'put' stores one with LRU-by-age
 * eviction, 'clear' flushes the cache.
 */
export function registerActionCacheTools(
  pi: ExtensionAPI,
  deps: ToolDeps,
): void {
  // -------------------------------------------------------------------------
  // browser_action_cache
  // -------------------------------------------------------------------------
  pi.registerTool({
    name: "browser_action_cache",
    label: "Browser Action Cache",
    description:
      "Manage the action cache that maps page structure + intent → resolved selectors. " +
      "Cache reduces token cost on repeat visits to same pages. " +
      "Actions: 'stats' (show cache metrics), 'get' (lookup cached selector), " +
      "'put' (store a selector mapping), 'clear' (flush cache).",
    parameters: Type.Object({
      action: Type.String({
        description: "Cache action: 'stats', 'get', 'put', or 'clear'.",
      }),
      intent: Type.Optional(
        Type.String({
          description:
            "Semantic intent key (for get/put). E.g., 'submit_form', 'close_dialog'.",
        }),
      ),
      selector: Type.Optional(
        Type.String({ description: "CSS selector to cache (for put)." }),
      ),
      score: Type.Optional(
        Type.Number({
          description:
            "Confidence score 0–1 for the cached selector (for put).",
        }),
      ),
    }),

    async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
      try {
        const { page: p } = await deps.ensureBrowser();
        const url = p.url();

        switch (params.action) {
          // Report cache size, hit totals, and a per-entry summary.
          case "stats": {
            const entries = [...cache.values()];
            const totalHits = entries.reduce((sum, e) => sum + e.hitCount, 0);
            return {
              content: [
                {
                  type: "text",
                  text: `Action cache: ${cache.size} entries, ${totalHits} total hits\nMax size: ${MAX_CACHE_SIZE}`,
                },
              ],
              details: {
                size: cache.size,
                maxSize: MAX_CACHE_SIZE,
                totalHits,
                entries: entries.map((e) => ({
                  url: e.url,
                  selector: e.selector,
                  hitCount: e.hitCount,
                  score: e.score,
                })),
              },
            };
          }

          // Look up a cached selector for (current URL, current DOM hash,
          // intent); stale entries are validated and evicted on miss.
          case "get": {
            if (!params.intent) {
              return {
                content: [
                  {
                    type: "text",
                    text: "Intent parameter required for 'get' action.",
                  },
                ],
                details: { error: "missing_intent" },
                isError: true,
              };
            }

            const domHash = await computeDomHash(p);
            const key = buildCacheKey(url, domHash, params.intent);
            const entry = cache.get(key);

            if (!entry) {
              return {
                content: [
                  {
                    type: "text",
                    text: `Cache miss for intent "${params.intent}" on ${url}`,
                  },
                ],
                details: { hit: false, intent: params.intent, url },
              };
            }

            // Validate the cached selector still exists
            const exists = await p
              .locator(entry.selector)
              .first()
              .isVisible()
              .catch(() => false);
            if (!exists) {
              // Selector vanished (or is hidden) — drop the entry so it
              // can't produce a false hit again.
              cache.delete(key);
              return {
                content: [
                  {
                    type: "text",
                    text: `Cache entry stale (selector no longer visible): ${entry.selector}`,
                  },
                ],
                details: { hit: false, stale: true, selector: entry.selector },
              };
            }

            entry.hitCount++;
            return {
              content: [
                {
                  type: "text",
                  text: `Cache hit: "${params.intent}" → ${entry.selector} (score: ${entry.score}, hits: ${entry.hitCount})`,
                },
              ],
              details: { hit: true, ...entry },
            };
          }

          // Store a mapping, evicting the oldest entry when at capacity.
          case "put": {
            if (!params.intent || !params.selector) {
              return {
                content: [
                  {
                    type: "text",
                    text: "Intent and selector parameters required for 'put' action.",
                  },
                ],
                details: { error: "missing_params" },
                isError: true,
              };
            }

            const domHash = await computeDomHash(p);
            const key = buildCacheKey(url, domHash, params.intent);

            // Evict oldest entries if at capacity
            // (only when inserting a genuinely new key — overwrites don't grow the map)
            if (cache.size >= MAX_CACHE_SIZE && !cache.has(key)) {
              const oldestKey = [...cache.entries()].sort(
                ([, a], [, b]) => a.timestamp - b.timestamp,
              )[0]?.[0];
              if (oldestKey) cache.delete(oldestKey);
            }

            const entry: CacheEntry = {
              selector: params.selector,
              score: params.score ?? 1.0,
              url,
              domHash,
              timestamp: Date.now(),
              hitCount: 0,
            };
            cache.set(key, entry);

            return {
              content: [
                {
                  type: "text",
                  text: `Cached: "${params.intent}" → ${params.selector} (cache size: ${cache.size})`,
                },
              ],
              details: { stored: true, key, ...entry, cacheSize: cache.size },
            };
          }

          // Flush everything, reporting how many entries were removed.
          case "clear": {
            const size = cache.size;
            cache.clear();
            return {
              content: [
                {
                  type: "text",
                  text: `Action cache cleared (${size} entries removed).`,
                },
              ],
              details: { cleared: size },
            };
          }

          default:
            return {
              content: [
                {
                  type: "text",
                  text: `Unknown action: ${params.action}. Use 'stats', 'get', 'put', or 'clear'.`,
                },
              ],
              details: { error: "unknown_action" },
              isError: true,
            };
        }
      } catch (err: any) {
        return {
          content: [
            { type: "text", text: `Action cache error: ${err.message}` },
          ],
          details: { error: err.message },
          isError: true,
        };
      }
    },
  });
}
|
||||
|
||||
function buildCacheKey(url: string, domHash: string, intent: string): string {
|
||||
// Normalize URL — strip hash and query params for broader matching
|
||||
let normalized: string;
|
||||
try {
|
||||
const u = new URL(url);
|
||||
normalized = `${u.origin}${u.pathname}`;
|
||||
} catch {
|
||||
normalized = url;
|
||||
}
|
||||
return `${normalized}|${domHash}|${intent}`;
|
||||
}
|
||||
|
||||
async function computeDomHash(page: any): Promise<string> {
|
||||
try {
|
||||
return await page.evaluate(() => {
|
||||
// Structural hash based on element count + tag distribution
|
||||
const tags = new Map<string, number>();
|
||||
const all = document.querySelectorAll("*");
|
||||
for (const el of all) {
|
||||
const tag = el.tagName;
|
||||
tags.set(tag, (tags.get(tag) ?? 0) + 1);
|
||||
}
|
||||
const entries = [...tags.entries()].sort((a, b) =>
|
||||
a[0].localeCompare(b[0]),
|
||||
);
|
||||
const str = entries.map(([t, c]) => `${t}:${c}`).join("|");
|
||||
// Simple hash
|
||||
let h = 5381;
|
||||
for (let i = 0; i < str.length; i++) {
|
||||
h = ((h << 5) - h + str.charCodeAt(i)) | 0;
|
||||
}
|
||||
return (h >>> 0).toString(16);
|
||||
});
|
||||
} catch {
|
||||
return "unknown";
|
||||
}
|
||||
}
|
||||
|
|
@ -1,548 +0,0 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import { StringEnum } from "@singularity-forge/pi-ai";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import {
|
||||
createRegionStableScript,
|
||||
diffCompactStates,
|
||||
evaluateAssertionChecks,
|
||||
findAction,
|
||||
includesNeedle,
|
||||
parseThreshold,
|
||||
runBatchSteps,
|
||||
validateWaitParams,
|
||||
} from "../core.js";
|
||||
import type { CompactPageState, ToolDeps } from "../state.js";
|
||||
import {
|
||||
getActionTimeline,
|
||||
getConsoleLogs,
|
||||
getCurrentRefMap,
|
||||
getLastActionAfterState,
|
||||
getLastActionBeforeState,
|
||||
setLastActionAfterState,
|
||||
setLastActionBeforeState,
|
||||
} from "../state.js";
|
||||
|
||||
export function registerAssertionTools(pi: ExtensionAPI, deps: ToolDeps): void {
|
||||
// -------------------------------------------------------------------------
|
||||
// browser_assert
|
||||
// -------------------------------------------------------------------------
|
||||
pi.registerTool({
|
||||
name: "browser_assert",
|
||||
label: "Browser Assert",
|
||||
description:
|
||||
"Run one or more explicit browser assertions and return structured PASS/FAIL results. Prefer this for verification instead of inferring success from prose summaries.",
|
||||
promptGuidelines: [
|
||||
"Prefer browser_assert for browser verification instead of inferring success from summaries.",
|
||||
"When finishing UI work, explicit browser assertions should usually be the final verification step.",
|
||||
"Use checks for URL, text, selector state, value, and browser diagnostics whenever those signals are available.",
|
||||
],
|
||||
parameters: Type.Object({
|
||||
checks: Type.Array(
|
||||
Type.Object({
|
||||
kind: Type.String({
|
||||
description:
|
||||
"Assertion kind, e.g. url_contains, text_visible, selector_visible, value_equals, no_console_errors, no_failed_requests, request_url_seen, response_status, console_message_matches, network_count, console_count, no_console_errors_since, no_failed_requests_since",
|
||||
}),
|
||||
selector: Type.Optional(Type.String()),
|
||||
text: Type.Optional(Type.String()),
|
||||
value: Type.Optional(Type.String()),
|
||||
checked: Type.Optional(Type.Boolean()),
|
||||
sinceActionId: Type.Optional(Type.Number()),
|
||||
}),
|
||||
),
|
||||
}),
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
try {
|
||||
const { page: p } = await deps.ensureBrowser();
|
||||
const target = deps.getActiveTarget();
|
||||
const state = await deps.collectAssertionState(
|
||||
p,
|
||||
params.checks,
|
||||
target,
|
||||
);
|
||||
const result = evaluateAssertionChecks({
|
||||
checks: params.checks,
|
||||
state,
|
||||
});
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Browser assert\n\n${deps.formatAssertionText(result)}`,
|
||||
},
|
||||
],
|
||||
details: { ...result, url: state.url, title: state.title },
|
||||
isError: !result.verified,
|
||||
};
|
||||
} catch (err: any) {
|
||||
return {
|
||||
content: [
|
||||
{ type: "text", text: `Browser assert failed: ${err.message}` },
|
||||
],
|
||||
details: { error: err.message },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// browser_diff
|
||||
// -------------------------------------------------------------------------
|
||||
// Registers the browser_diff tool: compares the current page state to a
// previously tracked baseline and reports what changed.
pi.registerTool({
  name: "browser_diff",
  label: "Browser Diff",
  description:
    "Report meaningful browser-state changes. By default compares the current page to the most recent tracked action state. Use this to understand what changed after a click, submit, or navigation.",
  promptGuidelines: [
    "Use browser_diff after ambiguous or high-impact actions when you need to know what changed.",
    "Prefer browser_diff over requesting a broad new page inspection when the question is change detection.",
  ],
  parameters: Type.Object({
    sinceActionId: Type.Optional(
      Type.Number({
        description:
          "Optional action id to diff against. Uses that action's stored after-state when available.",
      }),
    ),
  }),
  async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
    try {
      const { page: p } = await deps.ensureBrowser();
      const target = deps.getActiveTarget();
      // Snapshot the current page (including body text) as the "after" side
      // of the diff.
      const current = await deps.captureCompactPageState(p, {
        includeBodyText: true,
        target,
      });
      // Baseline resolution order:
      //   1. explicit sinceActionId -> that action's stored after-state,
      //   2. the most recent action's after-state,
      //   3. the most recent action's before-state.
      let baseline: CompactPageState | null = null;
      if (params.sinceActionId) {
        const actionTimeline = getActionTimeline();
        const action = findAction(actionTimeline, params.sinceActionId) as {
          afterState?: CompactPageState;
        } | null;
        baseline = action?.afterState ?? null;
      }
      if (!baseline) {
        baseline = getLastActionAfterState() ?? getLastActionBeforeState();
      }
      if (!baseline) {
        // Nothing tracked yet — diffing is impossible; report as an error
        // with an empty change set so callers get a consistent shape.
        return {
          content: [
            {
              type: "text",
              text: "Browser diff unavailable: no prior tracked browser state exists yet.",
            },
          ],
          details: {
            changed: false,
            changes: [],
            summary: "No prior tracked state",
          },
          isError: true,
        };
      }
      const diff = diffCompactStates(baseline, current);
      return {
        content: [
          {
            type: "text",
            text: `Browser diff\n\n${deps.formatDiffText(diff)}`,
          },
        ],
        details: diff,
      };
    } catch (err: any) {
      return {
        content: [
          { type: "text", text: `Browser diff failed: ${err.message}` },
        ],
        details: { error: err.message },
        isError: true,
      };
    }
  },
});
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// browser_batch
|
||||
// -------------------------------------------------------------------------
|
||||
// Registers the browser_batch tool: runs an explicit sequence of browser
// steps (navigate / click / type / key_press / wait_for / assert /
// click_ref / fill_ref) in one tool call, tracking before/after page state
// and reporting a per-step result list plus a state diff.
pi.registerTool({
  name: "browser_batch",
  label: "Browser Batch",
  description:
    "Execute multiple explicit browser steps in one call. Prefer this for obvious action sequences like click → type → wait → assert to reduce round trips and token usage.",
  promptGuidelines: [
    "If the next 2-5 browser actions are obvious and low-risk, prefer browser_batch over multiple tiny browser calls.",
    "Use browser_batch for explicit sequences like click → type → submit → wait → assert.",
    "Keep browser_batch steps explicit; do not use it as a speculative planner.",
  ],
  parameters: Type.Object({
    steps: Type.Array(
      Type.Object({
        action: StringEnum([
          "navigate",
          "click",
          "type",
          "key_press",
          "wait_for",
          "assert",
          "click_ref",
          "fill_ref",
        ] as const),
        selector: Type.Optional(Type.String()),
        text: Type.Optional(Type.String()),
        url: Type.Optional(Type.String()),
        key: Type.Optional(Type.String()),
        condition: Type.Optional(Type.String()),
        value: Type.Optional(Type.String()),
        threshold: Type.Optional(Type.String()),
        timeout: Type.Optional(Type.Number()),
        clearFirst: Type.Optional(Type.Boolean()),
        submit: Type.Optional(Type.Boolean()),
        ref: Type.Optional(Type.String()),
        checks: Type.Optional(
          Type.Array(
            Type.Object({
              kind: Type.String({
                description:
                  "Assertion kind, e.g. url_contains, text_visible, selector_visible, value_equals, no_console_errors, no_failed_requests, request_url_seen, response_status, console_message_matches, network_count, console_count, no_console_errors_since, no_failed_requests_since",
              }),
              selector: Type.Optional(Type.String()),
              text: Type.Optional(Type.String()),
              value: Type.Optional(Type.String()),
              checked: Type.Optional(Type.Boolean()),
              sinceActionId: Type.Optional(Type.Number()),
            }),
          ),
        ),
      }),
    ),
    stopOnFailure: Type.Optional(
      Type.Boolean({
        description: "Stop after the first failing step (default: true).",
      }),
    ),
    finalSummaryOnly: Type.Optional(
      Type.Boolean({
        description:
          "Return only the compact final batch summary in content while keeping step results in details.",
      }),
    ),
  }),
  async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
    // Declared outside the try so the catch block can still finish the
    // tracked action if setup succeeded before a later failure.
    let actionId: number | null = null;
    let beforeState: CompactPageState | null = null;
    try {
      const { page: p } = await deps.ensureBrowser();
      const target = deps.getActiveTarget();
      // Capture the pre-batch state so a diff can be computed afterwards.
      beforeState = await deps.captureCompactPageState(p, {
        includeBodyText: true,
        target,
      });
      actionId = deps.beginTrackedAction(
        "browser_batch",
        params,
        beforeState.url,
      ).id;
      // Executes one step; returns {ok, ...} and never throws — step errors
      // are converted to {ok: false, message} so the batch can continue or
      // stop depending on stopOnFailure.
      const executeStep = async (step: any, index: number) => {
        // Re-resolve the target each step: a prior step may have changed
        // the active frame/page.
        const stepTarget = deps.getActiveTarget();
        try {
          switch (step.action) {
            case "navigate": {
              await p.goto(step.url, {
                waitUntil: "domcontentloaded",
                timeout: 30000,
              });
              await p
                .waitForLoadState("networkidle", { timeout: 5000 })
                .catch(() => {
                  /* networkidle timeout — non-fatal, page may still be usable */
                });
              return { ok: true, action: step.action, url: p.url() };
            }
            case "click": {
              await stepTarget
                .locator(step.selector)
                .first()
                .click({ timeout: step.timeout ?? 8000 });
              await deps.settleAfterActionAdaptive(p);
              return {
                ok: true,
                action: step.action,
                selector: step.selector,
                url: p.url(),
              };
            }
            case "type": {
              if (step.clearFirst) {
                await stepTarget.locator(step.selector).first().fill("");
              }
              await stepTarget
                .locator(step.selector)
                .first()
                .fill(step.text ?? "", { timeout: step.timeout ?? 8000 });
              if (step.submit) await p.keyboard.press("Enter");
              await deps.settleAfterActionAdaptive(p);
              return {
                ok: true,
                action: step.action,
                selector: step.selector,
                text: step.text,
              };
            }
            case "key_press": {
              await p.keyboard.press(step.key);
              await deps.settleAfterActionAdaptive(p, {
                checkFocusStability: true,
              });
              return { ok: true, action: step.action, key: step.key };
            }
            case "wait_for": {
              const timeout = step.timeout ?? 10000;
              // Validate condition/value/threshold combination up front so
              // malformed waits fail fast with a clear message.
              const waitValidation = validateWaitParams({
                condition: step.condition,
                value: step.value,
                threshold: step.threshold,
              });
              if (waitValidation) throw new Error(waitValidation.error);

              if (step.condition === "selector_visible")
                await stepTarget.waitForSelector(step.value, {
                  state: "visible",
                  timeout,
                });
              else if (step.condition === "selector_hidden")
                await stepTarget.waitForSelector(step.value, {
                  state: "hidden",
                  timeout,
                });
              else if (step.condition === "url_contains")
                await p.waitForURL(
                  (url) => url.toString().includes(step.value),
                  { timeout },
                );
              else if (step.condition === "network_idle")
                await p.waitForLoadState("networkidle", { timeout });
              else if (step.condition === "delay")
                await new Promise((resolve) =>
                  setTimeout(resolve, parseInt(step.value ?? "1000", 10)),
                );
              else if (step.condition === "text_visible") {
                // Runs in the page context — case-insensitive substring
                // match against body.innerText.
                await stepTarget.waitForFunction(
                  (needle: string) =>
                    (document.body?.innerText ?? "")
                      .toLowerCase()
                      .includes(needle.toLowerCase()),
                  step.value!,
                  { timeout },
                );
              } else if (step.condition === "text_hidden") {
                await stepTarget.waitForFunction(
                  (needle: string) =>
                    !(document.body?.innerText ?? "")
                      .toLowerCase()
                      .includes(needle.toLowerCase()),
                  step.value!,
                  { timeout },
                );
              } else if (step.condition === "request_completed") {
                await deps
                  .getActivePage()
                  .waitForResponse(
                    (resp: any) => resp.url().includes(step.value!),
                    { timeout },
                  );
              } else if (step.condition === "console_message") {
                // Poll the captured console log every 100ms until a match
                // appears or the timeout elapses.
                const needle = step.value!;
                const startTime = Date.now();
                let found = false;
                while (Date.now() - startTime < timeout) {
                  if (
                    getConsoleLogs().find((entry) =>
                      includesNeedle(entry.text, needle),
                    )
                  ) {
                    found = true;
                    break;
                  }
                  await new Promise((resolve) => setTimeout(resolve, 100));
                }
                if (!found)
                  throw new Error(
                    `Timed out waiting for console message matching "${needle}" (${timeout}ms)`,
                  );
              } else if (step.condition === "element_count") {
                const threshold = parseThreshold(step.threshold ?? ">=1");
                if (!threshold)
                  throw new Error(
                    `element_count threshold is malformed: "${step.threshold}"`,
                  );
                const selector = step.value!;
                const op = threshold.op;
                const n = threshold.n;
                // The comparison runs inside the page; args are passed as a
                // single serializable object.
                await stepTarget.waitForFunction(
                  ({
                    selector,
                    op,
                    n,
                  }: {
                    selector: string;
                    op: string;
                    n: number;
                  }) => {
                    const count = document.querySelectorAll(selector).length;
                    switch (op) {
                      case ">=":
                        return count >= n;
                      case "<=":
                        return count <= n;
                      case "==":
                        return count === n;
                      case ">":
                        return count > n;
                      case "<":
                        return count < n;
                      default:
                        return false;
                    }
                  },
                  { selector, op, n },
                  { timeout },
                );
              } else if (step.condition === "region_stable") {
                const script = createRegionStableScript(step.value!);
                await stepTarget.waitForFunction(script, undefined, {
                  timeout,
                  polling: 200,
                });
              } else
                throw new Error(
                  `Unsupported wait condition: ${step.condition}`,
                );
              return {
                ok: true,
                action: step.action,
                condition: step.condition,
                value: step.value,
              };
            }
            case "assert": {
              const state = await deps.collectAssertionState(
                p,
                step.checks ?? [],
                stepTarget,
              );
              const assertion = evaluateAssertionChecks({
                checks: step.checks ?? [],
                state,
              });
              return {
                ok: assertion.verified,
                action: step.action,
                summary: assertion.summary,
                assertion,
              };
            }
            case "click_ref": {
              // Resolve a session-scoped element ref to a concrete selector
              // before clicking.
              const parsedRef = deps.parseRef(step.ref);
              const currentRefMap = getCurrentRefMap();
              const node = currentRefMap[parsedRef.key];
              if (!node) throw new Error(`Unknown ref: ${step.ref}`);
              const resolved = await deps.resolveRefTarget(stepTarget, node);
              if (!resolved.ok) throw new Error(resolved.reason);
              await stepTarget
                .locator(resolved.selector)
                .first()
                .click({ timeout: step.timeout ?? 8000 });
              await deps.settleAfterActionAdaptive(p);
              return { ok: true, action: step.action, ref: step.ref };
            }
            case "fill_ref": {
              const parsedRef = deps.parseRef(step.ref);
              const currentRefMap = getCurrentRefMap();
              const node = currentRefMap[parsedRef.key];
              if (!node) throw new Error(`Unknown ref: ${step.ref}`);
              const resolved = await deps.resolveRefTarget(stepTarget, node);
              if (!resolved.ok) throw new Error(resolved.reason);
              if (step.clearFirst)
                await stepTarget.locator(resolved.selector).first().fill("");
              await stepTarget
                .locator(resolved.selector)
                .first()
                .fill(step.text ?? "", { timeout: step.timeout ?? 8000 });
              if (step.submit) await p.keyboard.press("Enter");
              await deps.settleAfterActionAdaptive(p);
              return {
                ok: true,
                action: step.action,
                ref: step.ref,
                text: step.text,
              };
            }
            default:
              throw new Error(`Unsupported batch action: ${step.action}`);
          }
        } catch (err: any) {
          return {
            ok: false,
            action: step.action,
            index,
            message: err.message,
          };
        }
      };
      const run = await runBatchSteps({
        steps: params.steps,
        executeStep,
        stopOnFailure: params.stopOnFailure !== false,
      });
      // Capture the post-batch state and record before/after for future
      // browser_diff calls.
      const batchEndTarget = deps.getActiveTarget();
      const afterState = await deps.captureCompactPageState(p, {
        includeBodyText: true,
        target: batchEndTarget,
      });
      const diff = diffCompactStates(beforeState!, afterState);
      setLastActionBeforeState(beforeState!);
      setLastActionAfterState(afterState);
      deps.finishTrackedAction(actionId!, {
        status: run.ok ? "success" : "error",
        afterUrl: afterState.url,
        diffSummary: diff.summary,
        changed: diff.changed,
        error: run.ok ? undefined : run.summary,
        beforeState: beforeState!,
        afterState,
      });
      // One PASS/FAIL line per executed step, 1-indexed for readability.
      const summary = `${run.summary}\n${run.stepResults.map((step: any, index: number) => `- ${index + 1}. ${step.action}: ${step.ok ? "PASS" : "FAIL"}${step.message ? ` (${step.message})` : ""}`).join("\n")}`;
      return {
        content: [
          {
            type: "text",
            text: params.finalSummaryOnly
              ? run.summary
              : `Browser batch\nAction: ${actionId}\n\n${summary}\n\nDiff:\n${deps.formatDiffText(diff)}`,
          },
        ],
        details: { actionId, diff, ...run },
        isError: !run.ok,
      };
    } catch (err: any) {
      // If tracking started, close the action record as an error so the
      // timeline is never left with a dangling open action.
      if (actionId !== null) {
        deps.finishTrackedAction(actionId, {
          status: "error",
          afterUrl: deps.getActivePageOrNull()?.url() ?? "",
          error: err.message,
          beforeState: beforeState ?? undefined,
        });
      }
      return {
        content: [
          { type: "text", text: `Browser batch failed: ${err.message}` },
        ],
        details: { error: err.message, actionId },
        isError: true,
      };
    }
  },
});
|
||||
}
|
||||
|
|
@ -1,323 +0,0 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import type { ToolDeps } from "../state.js";
|
||||
import { getActionTimeline } from "../state.js";
|
||||
|
||||
/**
|
||||
* Test code generation — transform recorded browser session into a Playwright test script.
|
||||
*/
|
||||
|
||||
/**
 * Registers the browser_generate_test tool: walks the recorded action
 * timeline and emits an equivalent Playwright test file.
 *
 * @param pi   Extension API used to register the tool.
 * @param deps Shared browser-tool dependencies (artifact paths, file I/O).
 */
export function registerCodegenTools(pi: ExtensionAPI, deps: ToolDeps): void {
  pi.registerTool({
    name: "browser_generate_test",
    label: "Browser Generate Test",
    description:
      "Generate a runnable Playwright test script from the recorded action timeline. " +
      "Transforms navigation, click, type, and assertion actions into standard Playwright test syntax. " +
      "Uses stable selectors (role-based preferred). Writes the test file to a configurable path.",
    parameters: Type.Object({
      name: Type.Optional(
        Type.String({
          description:
            "Test name (used for describe/test block and filename). Default: 'recorded-session'.",
        }),
      ),
      outputPath: Type.Optional(
        Type.String({
          description:
            "Output file path for the generated test. Default: writes to session artifacts directory. " +
            "Use a path ending in .spec.ts for standard Playwright test convention.",
        }),
      ),
      includeAssertions: Type.Optional(
        Type.Boolean({
          description:
            "Include assertion steps from the timeline (default: true).",
        }),
      ),
    }),

    async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
      try {
        await deps.ensureBrowser();
        const timeline = getActionTimeline();

        if (timeline.entries.length === 0) {
          return {
            content: [
              {
                type: "text",
                text: "No actions recorded in the current session. Interact with pages first, then generate a test.",
              },
            ],
            details: { error: "no_actions" },
            isError: true,
          };
        }

        const testName = params.name ?? "recorded-session";
        const includeAssertions = params.includeAssertions ?? true;

        // Transform timeline entries into Playwright test code
        const testLines: string[] = [];
        const imports = new Set<string>();
        imports.add("test");
        imports.add("expect");

        testLines.push(`test.describe('${escapeString(testName)}', () => {`);
        testLines.push(`  test('recorded session', async ({ page }) => {`);

        // lastUrl dedupes consecutive navigations to the same URL;
        // actionCount tallies emitted test lines for the report.
        let lastUrl = "";
        let actionCount = 0;

        for (const entry of timeline.entries) {
          // Skip failed actions, except failed assertions (those are still
          // worth surfacing as comments below).
          if (entry.status === "error" && entry.tool !== "browser_assert")
            continue;

          // NOTE(review): this loop-local `params` shadows the tool-call
          // `params` from execute(); confirm the shadowing is intended.
          const params = parseParamsSummary(entry.paramsSummary);

          switch (entry.tool) {
            case "browser_navigate": {
              const url = params.url;
              if (url && url !== lastUrl) {
                testLines.push(`    await page.goto(${quote(url)});`);
                lastUrl = url;
                actionCount++;
              }
              break;
            }

            case "browser_click": {
              const selector = params.selector;
              if (selector) {
                testLines.push(
                  `    await page.locator(${quote(selector)}).click();`,
                );
                actionCount++;
              }
              break;
            }

            case "browser_click_ref": {
              // Refs are session-specific — add comment
              testLines.push(
                `    // browser_click_ref: ${entry.paramsSummary} — replace with stable selector`,
              );
              actionCount++;
              break;
            }

            case "browser_type": {
              const selector = params.selector;
              const text = params.text;
              if (selector && text) {
                testLines.push(
                  `    await page.locator(${quote(selector)}).fill(${quote(text)});`,
                );
                actionCount++;
              }
              break;
            }

            case "browser_fill_ref": {
              testLines.push(
                `    // browser_fill_ref: ${entry.paramsSummary} — replace with stable selector`,
              );
              actionCount++;
              break;
            }

            case "browser_key_press": {
              const key = params.key;
              if (key) {
                testLines.push(`    await page.keyboard.press(${quote(key)});`);
                actionCount++;
              }
              break;
            }

            case "browser_select_option": {
              const selector = params.selector;
              const option = params.option;
              if (selector && option) {
                testLines.push(
                  `    await page.locator(${quote(selector)}).selectOption(${quote(option)});`,
                );
                actionCount++;
              }
              break;
            }

            case "browser_set_checked": {
              const selector = params.selector;
              const checked = params.checked;
              if (selector) {
                // paramsSummary values are strings, so compare to "true".
                testLines.push(
                  `    await page.locator(${quote(selector)}).setChecked(${checked === "true"});`,
                );
                actionCount++;
              }
              break;
            }

            case "browser_hover": {
              const selector = params.selector;
              if (selector) {
                testLines.push(
                  `    await page.locator(${quote(selector)}).hover();`,
                );
                actionCount++;
              }
              break;
            }

            case "browser_wait_for": {
              // Map each wait condition to its closest Playwright idiom;
              // unmapped conditions are simply dropped.
              const condition = params.condition;
              const value = params.value;
              if (condition === "selector_visible" && value) {
                testLines.push(
                  `    await expect(page.locator(${quote(value)})).toBeVisible();`,
                );
                actionCount++;
              } else if (condition === "text_visible" && value) {
                testLines.push(
                  `    await expect(page.locator('body')).toContainText(${quote(value)});`,
                );
                actionCount++;
              } else if (condition === "url_contains" && value) {
                testLines.push(
                  `    await page.waitForURL(${quote(`**/*${value}*`)});`,
                );
                actionCount++;
              } else if (condition === "network_idle") {
                testLines.push(
                  `    await page.waitForLoadState('networkidle');`,
                );
                actionCount++;
              } else if (condition === "delay" && value) {
                testLines.push(`    await page.waitForTimeout(${value});`);
                actionCount++;
              }
              break;
            }

            case "browser_assert": {
              if (!includeAssertions) break;
              // The assertion details are in verificationSummary
              if (entry.verificationSummary) {
                testLines.push(
                  `    // Assertion: ${entry.verificationSummary}`,
                );
              }
              actionCount++;
              break;
            }

            case "browser_scroll": {
              const direction = params.direction;
              const amount = params.amount ?? "300";
              const delta = direction === "up" ? `-${amount}` : amount;
              testLines.push(`    await page.mouse.wheel(0, ${delta});`);
              actionCount++;
              break;
            }

            case "browser_set_viewport": {
              const width = params.width;
              const height = params.height;
              if (width && height) {
                testLines.push(
                  `    await page.setViewportSize({ width: ${width}, height: ${height} });`,
                );
                actionCount++;
              }
              break;
            }

            default:
              // Skip tools that don't map to Playwright test actions
              break;
          }
        }

        testLines.push(`  });`);
        testLines.push(`});`);

        const importLine = `import { ${[...imports].join(", ")} } from '@playwright/test';`;
        const fullTest = `${importLine}\n\n${testLines.join("\n")}\n`;

        // Write to file
        let outputPath: string;
        if (params.outputPath) {
          outputPath = params.outputPath;
        } else {
          const safeName = deps.sanitizeArtifactName(
            testName,
            "recorded-session",
          );
          outputPath = deps.buildSessionArtifactPath(`${safeName}.spec.ts`);
        }

        await deps.ensureSessionArtifactDir();
        const { path: writtenPath, bytes } = await deps.writeArtifactFile(
          outputPath,
          fullTest,
        );

        return {
          content: [
            {
              type: "text",
              text: `Test generated: ${writtenPath}\nActions: ${actionCount}\nTimeline entries processed: ${timeline.entries.length}\n\n${fullTest}`,
            },
          ],
          details: {
            path: writtenPath,
            bytes,
            actionCount,
            timelineEntries: timeline.entries.length,
            testCode: fullTest,
          },
        };
      } catch (err: any) {
        return {
          content: [
            { type: "text", text: `Test generation failed: ${err.message}` },
          ],
          details: { error: err.message },
          isError: true,
        };
      }
    },
  });
}
|
||||
|
||||
function escapeString(s: string): string {
|
||||
return s.replace(/'/g, "\\'").replace(/\\/g, "\\\\");
|
||||
}
|
||||
|
||||
function quote(s: string): string {
|
||||
// Use single quotes for simple strings, backtick for those with quotes
|
||||
if (!s.includes("'")) return `'${s}'`;
|
||||
if (!s.includes("`")) return `\`${s}\``;
|
||||
return `'${s.replace(/'/g, "\\'")}'`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Parse the paramsSummary string back into key-value pairs.
|
||||
* Format: key="value", key=value, key=[N], key={...}
|
||||
*/
|
||||
function parseParamsSummary(summary: string): Record<string, string> {
|
||||
const result: Record<string, string> = {};
|
||||
if (!summary) return result;
|
||||
|
||||
const regex = /(\w+)=(?:"([^"]*(?:\\"[^"]*)*)"|([^,\s]+))/g;
|
||||
let match: RegExpExecArray | null;
|
||||
// biome-ignore lint/suspicious/noAssignInExpressions: intentional read loop
|
||||
while ((match = regex.exec(summary)) !== null) {
|
||||
const key = match[1];
|
||||
const value = match[2] ?? match[3];
|
||||
result[key] = value;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
|
@ -1,223 +0,0 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import type { ToolDeps } from "../state.js";
|
||||
|
||||
/**
|
||||
* Device emulation tool — full device simulation using Playwright's built-in device descriptors.
|
||||
*/
|
||||
|
||||
/**
 * Registers the browser_emulate_device tool: relaunches the browser context
 * with one of Playwright's built-in device descriptors applied.
 *
 * @param pi   Extension API used to register the tool.
 * @param deps Shared browser-tool dependencies (browser lifecycle hooks).
 */
export function registerDeviceTools(pi: ExtensionAPI, deps: ToolDeps): void {
  pi.registerTool({
    name: "browser_emulate_device",
    label: "Browser Emulate Device",
    description:
      "Simulate a specific device by setting viewport, user agent, device scale factor, touch, and mobile flag. " +
      "Uses Playwright's built-in device descriptors (~143 devices). Accepts fuzzy matching on device name. " +
      "Note: Full emulation (user agent, isMobile) requires a context restart — the current page state will be lost. " +
      "The tool recreates the context with the device profile applied.",
    parameters: Type.Object({
      device: Type.String({
        description:
          "Device name (e.g., 'iPhone 15', 'Pixel 7', 'iPad Pro 11'). " +
          "Case-insensitive fuzzy matching. Use 'list' to see all available devices.",
      }),
    }),

    async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
      try {
        const { chromium, devices } = await import("playwright");
        const allDeviceNames = Object.keys(devices);

        // Handle 'list' request
        if (params.device.toLowerCase() === "list") {
          // Group by base device name (remove landscape variants for cleaner display)
          const baseNames = allDeviceNames.filter(
            (n) => !n.endsWith(" landscape"),
          );
          return {
            content: [
              {
                type: "text",
                text: `Available devices (${allDeviceNames.length} total, ${baseNames.length} base):\n${baseNames.join("\n")}`,
              },
            ],
            details: { devices: baseNames, total: allDeviceNames.length },
          };
        }

        // Fuzzy match device name: exact (case-insensitive) first, then
        // substring match, then fuzzyScore-ranked suggestions on failure.
        const needle = params.device.toLowerCase();
        let exactMatch = allDeviceNames.find((n) => n.toLowerCase() === needle);
        if (!exactMatch) {
          // Try contains match
          const containsMatches = allDeviceNames.filter((n) =>
            n.toLowerCase().includes(needle),
          );
          if (containsMatches.length === 1) {
            exactMatch = containsMatches[0];
          } else if (containsMatches.length > 1) {
            // Pick the shortest match (most specific)
            containsMatches.sort((a, b) => a.length - b.length);
            exactMatch = containsMatches[0];
            // NOTE(review): _suggestions is computed but never surfaced to
            // the caller — the alternatives are silently dropped.
            const _suggestions = containsMatches.slice(0, 5).join(", ");
            // Continue with best match but mention alternatives
          } else {
            // No match at all — suggest closest
            const suggestions = allDeviceNames
              .map((n) => ({
                name: n,
                score: fuzzyScore(needle, n.toLowerCase()),
              }))
              .sort((a, b) => b.score - a.score)
              .slice(0, 5)
              .map((s) => s.name);

            return {
              content: [
                {
                  type: "text",
                  text: `No device matching "${params.device}". Did you mean:\n${suggestions.map((s) => ` - ${s}`).join("\n")}`,
                },
              ],
              details: { error: "no_match", suggestions },
              isError: true,
            };
          }
        }

        const deviceDescriptor = devices[exactMatch!];
        if (!deviceDescriptor) {
          return {
            content: [
              {
                type: "text",
                text: `Device descriptor not found for "${exactMatch}"`,
              },
            ],
            details: { error: "descriptor_not_found" },
            isError: true,
          };
        }

        // Context restart required for full emulation.
        // Save current URL to navigate back after restart.
        const { page: currentPage, context: _currentCtx } =
          await deps.ensureBrowser();
        const currentUrl = currentPage.url();

        // Close existing browser and relaunch with device profile
        await deps.closeBrowser();

        // Re-launch — ensureBrowser doesn't accept device params, so we do it manually.
        // This is a one-off context creation with device emulation.
        const needsHeadless =
          process.platform === "linux" && !process.env.DISPLAY;
        const launchOptions: Record<string, unknown> = {
          headless: needsHeadless || process.env.FORCE_HEADLESS === "true",
        };
        const customPath = process.env.BROWSER_PATH;
        if (customPath) launchOptions.executablePath = customPath;

        const browser = await chromium.launch(launchOptions);
        const context = await browser.newContext({
          ...deviceDescriptor,
        });

        // Inject evaluate helpers
        const { EVALUATE_HELPERS_SOURCE } = await import(
          "../evaluate-helpers.js"
        );
        await context.addInitScript(EVALUATE_HELPERS_SOURCE);

        // Wire up state
        const {
          setBrowser,
          setContext,
          pageRegistry,
          setSessionStartedAt,
          setSessionArtifactDir: _setSessionArtifactDir,
          resetAllState,
        } = await import("../state.js");
        const { registryAddPage, registrySetActive } = await import(
          "../core.js"
        );

        // Reset state for new session
        resetAllState();
        setBrowser(browser);
        setContext(context);
        setSessionStartedAt(Date.now());

        const page = await context.newPage();
        const entry = registryAddPage(pageRegistry, {
          page,
          title: "",
          url: "about:blank",
          opener: null,
        });
        registrySetActive(pageRegistry, entry.id);
        deps.attachPageListeners(page, entry.id);

        // Navigate back to previous URL if it wasn't about:blank
        if (currentUrl && currentUrl !== "about:blank") {
          await page
            .goto(currentUrl, { waitUntil: "domcontentloaded", timeout: 15000 })
            .catch((e) => {
              // Best-effort restore: failure is only logged in debug mode.
              if (process.env.SF_DEBUG)
                console.error(
                  "[browser-tools] device goto restore failed:",
                  e.message,
                );
            });
        }

        const viewport = deviceDescriptor.viewport;
        const vpText = viewport
          ? `${viewport.width}x${viewport.height}`
          : "unknown";

        return {
          content: [
            {
              type: "text",
              text: `Device emulation active: ${exactMatch}\nViewport: ${vpText}\nUser Agent: ${deviceDescriptor.userAgent?.slice(0, 80) ?? "default"}...\nMobile: ${deviceDescriptor.isMobile ?? false}\nTouch: ${deviceDescriptor.hasTouch ?? false}\nScale Factor: ${deviceDescriptor.deviceScaleFactor ?? 1}\n\nContext was restarted for full emulation. Page state was reset.`,
            },
          ],
          details: {
            device: exactMatch,
            viewport: vpText,
            isMobile: deviceDescriptor.isMobile ?? false,
            hasTouch: deviceDescriptor.hasTouch ?? false,
            deviceScaleFactor: deviceDescriptor.deviceScaleFactor ?? 1,
            userAgent: deviceDescriptor.userAgent,
            restoredUrl: currentUrl,
          },
        };
      } catch (err: any) {
        return {
          content: [
            { type: "text", text: `Device emulation failed: ${err.message}` },
          ],
          details: { error: err.message },
          isError: true,
        };
      }
    },
  });
}
|
||||
|
||||
/**
|
||||
* Simple fuzzy scoring — counts matching characters in order.
|
||||
*/
|
||||
function fuzzyScore(needle: string, haystack: string): number {
|
||||
let score = 0;
|
||||
let hi = 0;
|
||||
for (let ni = 0; ni < needle.length && hi < haystack.length; ni++) {
|
||||
const idx = haystack.indexOf(needle[ni], hi);
|
||||
if (idx >= 0) {
|
||||
score++;
|
||||
hi = idx + 1;
|
||||
}
|
||||
}
|
||||
return score / Math.max(needle.length, 1);
|
||||
}
|
||||
|
|
@ -1,286 +0,0 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import type { ToolDeps } from "../state.js";
|
||||
|
||||
/**
|
||||
* Structured data extraction with JSON Schema validation.
|
||||
*/
|
||||
|
||||
/**
 * Register the `browser_extract` tool on the extension API.
 *
 * The tool pulls structured data out of the active page: a loose JSON-Schema
 * object (with `_selector`/`_attribute` hints) is turned into an extraction
 * plan by `buildExtractionPlan`, executed inside the page via
 * `page.evaluate`, and the result is validated with ajv via `validateData`.
 *
 * @param pi   Extension API used to register the tool.
 * @param deps Shared browser/tooling dependencies (provides `ensureBrowser`).
 */
export function registerExtractTools(pi: ExtensionAPI, deps: ToolDeps): void {
  pi.registerTool({
    name: "browser_extract",
    label: "Browser Extract",
    description:
      "Extract structured data from the current page using CSS selectors and validate against a JSON Schema. " +
      "Provide a schema describing the shape of data you want. The tool extracts data by evaluating " +
      "CSS selectors in the page context, then validates the result against your schema. " +
      "Supports extracting single objects or arrays of items. Waits for network idle before extraction.",
    parameters: Type.Object({
      schema: Type.Record(Type.String(), Type.Unknown(), {
        description:
          "JSON Schema describing the data shape to extract. Properties should include " +
          "'_selector' (CSS selector) and '_attribute' (attribute to read, default: 'textContent') hints. " +
          "Example: { type: 'object', properties: { title: { _selector: 'h1', _attribute: 'textContent' }, price: { _selector: '.price', _attribute: 'textContent' } } }",
      }),
      selector: Type.Optional(
        Type.String({
          description:
            "CSS selector to scope extraction to a specific container element.",
        }),
      ),
      multiple: Type.Optional(
        Type.Boolean({
          description:
            "If true, extract an array of items. The 'selector' parameter becomes the item container selector, " +
            "and schema properties are extracted relative to each matched container.",
        }),
      ),
    }),

    async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
      try {
        const { page: p } = await deps.ensureBrowser();

        // Wait for network idle before extraction
        await p
          .waitForLoadState("networkidle", { timeout: 10000 })
          .catch(() => {
            /* networkidle timeout — non-fatal, page may still be usable */
          });

        const schema = params.schema as any;
        const scopeSelector = params.selector;
        const multiple = params.multiple ?? false;

        // Build extraction plan from schema (flat list of name/selector/attribute/type)
        const extractionPlan = buildExtractionPlan(schema);

        // Execute extraction in page context. The callback must be fully
        // self-contained: it is serialized into the page and cannot close
        // over anything except the single serializable argument.
        const rawData = await p.evaluate(
          ({
            plan,
            scope,
            multi,
          }: {
            plan: ExtractionField[];
            scope: string | undefined;
            multi: boolean;
          }) => {
            // Resolve every planned field relative to one container element.
            // Missing elements yield null rather than failing the whole run.
            function extractFromContainer(
              container: Element,
              fields: typeof plan,
            ): Record<string, unknown> {
              const result: Record<string, unknown> = {};
              for (const field of fields) {
                const el = container.querySelector(field.selector);
                if (!el) {
                  result[field.name] = null;
                  continue;
                }
                let value: unknown;
                switch (field.attribute) {
                  case "textContent":
                    value = (el.textContent ?? "").trim();
                    break;
                  case "innerText":
                    value = ((el as HTMLElement).innerText ?? "").trim();
                    break;
                  case "innerHTML":
                    value = el.innerHTML;
                    break;
                  case "href":
                    // Property first (absolute URL), attribute as fallback.
                    value =
                      (el as HTMLAnchorElement).href ?? el.getAttribute("href");
                    break;
                  case "src":
                    value =
                      (el as HTMLImageElement).src ?? el.getAttribute("src");
                    break;
                  case "value":
                    value = (el as HTMLInputElement).value;
                    break;
                  default:
                    // Arbitrary attribute name; fall back to text content.
                    value =
                      el.getAttribute(field.attribute) ??
                      (el.textContent ?? "").trim();
                }
                // Type coercion — best effort; keeps the original string when
                // a numeric parse fails.
                if (field.type === "number" && typeof value === "string") {
                  const num = parseFloat(value.replace(/[^0-9.-]/g, ""));
                  value = Number.isNaN(num) ? value : num;
                } else if (
                  field.type === "boolean" &&
                  typeof value === "string"
                ) {
                  value = value.toLowerCase() === "true" || value === "1";
                }
                result[field.name] = value;
              }
              return result;
            }

            const root = scope ? document.querySelector(scope) : document.body;
            if (!root)
              return {
                data: null,
                error: `Scope selector "${scope}" not found`,
              };

            if (multi) {
              // For multiple items, scope is the item selector
              const containers = scope
                ? document.querySelectorAll(scope)
                : [document.body];
              const items = Array.from(containers).map((container) =>
                extractFromContainer(container, plan),
              );
              return { data: items, error: null };
            } else {
              return { data: extractFromContainer(root, plan), error: null };
            }
          },
          { plan: extractionPlan, scope: scopeSelector, multi: multiple },
        );

        if (rawData.error) {
          return {
            content: [
              { type: "text", text: `Extraction failed: ${rawData.error}` },
            ],
            details: { error: rawData.error },
            isError: true,
          };
        }

        // Validate against schema using ajv (never throws — returns messages)
        const validationErrors = await validateData(
          rawData.data,
          schema,
          multiple,
        );

        // Pretty-print and cap output so large extractions don't flood the chat.
        const resultText = JSON.stringify(rawData.data, null, 2);
        const truncated =
          resultText.length > 4000
            ? resultText.slice(0, 4000) + "\n...(truncated)"
            : resultText;

        return {
          content: [
            {
              type: "text",
              text:
                validationErrors.length > 0
                  ? `Extracted data (with ${validationErrors.length} validation warning(s)):\n${truncated}\n\nValidation warnings:\n${validationErrors.join("\n")}`
                  : `Extracted data:\n${truncated}`,
            },
          ],
          details: {
            data: rawData.data,
            validationErrors:
              validationErrors.length > 0 ? validationErrors : undefined,
            fieldCount: extractionPlan.length,
            itemCount: multiple ? ((rawData.data as any[])?.length ?? 0) : 1,
          },
        };
      } catch (err: any) {
        return {
          content: [
            { type: "text", text: `Extraction failed: ${err.message}` },
          ],
          details: { error: err.message },
          isError: true,
        };
      }
    },
  });
}
|
||||
|
||||
interface ExtractionField {
|
||||
name: string;
|
||||
selector: string;
|
||||
attribute: string;
|
||||
type: string;
|
||||
}
|
||||
|
||||
function buildExtractionPlan(schema: any): ExtractionField[] {
|
||||
const fields: ExtractionField[] = [];
|
||||
|
||||
if (!schema || typeof schema !== "object") return fields;
|
||||
|
||||
const properties = schema.properties ?? schema;
|
||||
|
||||
for (const [name, propSchema] of Object.entries(properties)) {
|
||||
const prop = propSchema as any;
|
||||
if (!prop || typeof prop !== "object") continue;
|
||||
|
||||
// Skip meta fields
|
||||
if (
|
||||
name === "type" ||
|
||||
name === "required" ||
|
||||
name === "properties" ||
|
||||
name === "$schema"
|
||||
)
|
||||
continue;
|
||||
|
||||
const selector =
|
||||
prop._selector ??
|
||||
prop.selector ??
|
||||
`[data-field="${name}"], .${name}, #${name}`;
|
||||
const attribute = prop._attribute ?? prop.attribute ?? "textContent";
|
||||
const type = prop.type ?? "string";
|
||||
|
||||
fields.push({ name, selector, attribute, type });
|
||||
}
|
||||
|
||||
return fields;
|
||||
}
|
||||
|
||||
async function validateData(
|
||||
data: unknown,
|
||||
schema: any,
|
||||
isArray: boolean,
|
||||
): Promise<string[]> {
|
||||
const errors: string[] = [];
|
||||
|
||||
try {
|
||||
const ajvModule = await import("ajv");
|
||||
const Ajv = ajvModule.default ?? ajvModule;
|
||||
const ajv = new (Ajv as any)({ allErrors: true, strict: false });
|
||||
|
||||
// Clean schema — remove our custom _selector/_attribute hints before validation
|
||||
const cleanSchema = cleanSchemaForValidation(schema);
|
||||
|
||||
// Wrap in array schema if multiple
|
||||
const validationSchema = isArray
|
||||
? { type: "array", items: cleanSchema }
|
||||
: cleanSchema;
|
||||
|
||||
const validate = ajv.compile(validationSchema);
|
||||
const valid = validate(data);
|
||||
|
||||
if (!valid && validate.errors) {
|
||||
for (const err of validate.errors) {
|
||||
errors.push(`${err.instancePath || "/"}: ${err.message}`);
|
||||
}
|
||||
}
|
||||
} catch (err: any) {
|
||||
errors.push(`Schema validation setup failed: ${err.message}`);
|
||||
}
|
||||
|
||||
return errors;
|
||||
}
|
||||
|
||||
function cleanSchemaForValidation(schema: any): any {
|
||||
if (!schema || typeof schema !== "object") return schema;
|
||||
if (Array.isArray(schema)) return schema.map(cleanSchemaForValidation);
|
||||
|
||||
const cleaned: any = {};
|
||||
for (const [key, value] of Object.entries(schema)) {
|
||||
if (key.startsWith("_")) continue; // Remove our custom hints
|
||||
if (key === "selector" && typeof value === "string") continue; // Also remove plain 'selector'
|
||||
if (key === "attribute" && typeof value === "string") continue; // Also remove plain 'attribute'
|
||||
cleaned[key] = cleanSchemaForValidation(value);
|
||||
}
|
||||
return cleaned;
|
||||
}
|
||||
|
|
@ -1,918 +0,0 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import type { CompactPageState, ToolDeps } from "../state.js";
|
||||
import { setLastActionAfterState, setLastActionBeforeState } from "../state.js";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Form analysis evaluate callback — runs in the browser context.
|
||||
// Self-contained: no external deps, no window.__pi calls.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
 * Snapshot of a single form control, as produced in-page by the analysis
 * script built in `buildFormAnalysisScript`.
 */
interface FormFieldInfo {
  type: string; // input type, or 'select' / 'textarea'
  name: string; // name attribute ('' when absent)
  id: string; // element id ('' when absent)
  label: string; // label resolved via the 7-level chain ('' when none found)
  required: boolean; // required property or aria-required="true"
  value: string; // current value (selected option's value for selects)
  checked?: boolean; // present only for checkbox/radio inputs
  options?: Array<{ value: string; label: string; selected: boolean }>; // selects only
  validation: { valid: boolean; message: string }; // constraint-validation state
  hidden: boolean; // type=hidden, or not visible per computed style/offsets
  disabled: boolean;
  group?: string; // enclosing <fieldset>'s <legend> text, when present
}

/** A submit-capable button found inside the analyzed form. */
interface FormSubmitButton {
  tag: string; // 'button' or 'input'
  type: string; // normalized; untyped <button> is reported as 'submit'
  text: string; // visible text (the value attribute for <input type="submit">)
  name: string; // name attribute ('' when absent)
  disabled: boolean;
}

/** Top-level result returned by the form-analysis page script. */
interface FormAnalysisResult {
  formSelector: string; // CSS selector that re-locates the analyzed form
  fields: FormFieldInfo[];
  submitButtons: FormSubmitButton[];
  fieldCount: number; // fields.length
  visibleFieldCount: number; // count of fields with hidden === false
}
|
||||
|
||||
/**
 * Build the in-page form-analysis script.
 *
 * Runs inside page.evaluate(). Finds the target form, inventories all fields
 * with full label resolution (aria-labelledby → aria-label → label[for] →
 * wrapping label → placeholder → title → humanized name), and returns a
 * structured result.
 *
 * @param selector Optional CSS selector for the form. When omitted, the
 *   script auto-detects: the single <form> on the page, else the <form> with
 *   the most visible inputs, else document.body.
 * @returns JS source for an IIFE evaluating to a FormAnalysisResult-shaped
 *   object, or `{ error: string }` when `selector` matches nothing.
 */
function buildFormAnalysisScript(selector?: string): string {
  // We return a string that will be evaluated in the page context.
  // This avoids serialization issues with passing functions.
  // NOTE: `selector` is embedded via JSON.stringify so the generated source
  // stays valid for any string (quotes, backslashes) or for undefined (null).
  return `(() => {
  // --- helpers ---
  function isVisible(el) {
    if (!el) return false;
    const style = window.getComputedStyle(el);
    if (style.display === 'none' || style.visibility === 'hidden') return false;
    if (el.offsetWidth === 0 && el.offsetHeight === 0) return false;
    return true;
  }

  function humanizeName(name) {
    if (!name) return '';
    return name
      .replace(/([a-z])([A-Z])/g, '$1 $2')
      .replace(/[_\\-]+/g, ' ')
      .replace(/\\bid\\b/i, 'ID')
      .trim()
      .replace(/^./, c => c.toUpperCase());
  }

  function getTextContent(el) {
    if (!el) return '';
    return (el.textContent || '').trim().replace(/\\s+/g, ' ');
  }

  // --- label resolution (7-level priority chain) ---
  function resolveLabel(field) {
    // 1. aria-labelledby
    const labelledBy = field.getAttribute('aria-labelledby');
    if (labelledBy) {
      const parts = labelledBy.split(/\\s+/).map(id => {
        const el = document.getElementById(id);
        return el ? getTextContent(el) : '';
      }).filter(Boolean);
      if (parts.length) return parts.join(' ');
    }

    // 2. aria-label
    const ariaLabel = field.getAttribute('aria-label');
    if (ariaLabel && ariaLabel.trim()) return ariaLabel.trim();

    // 3. label[for="id"]
    const fieldId = field.id;
    if (fieldId) {
      const labelFor = document.querySelector('label[for="' + CSS.escape(fieldId) + '"]');
      if (labelFor) {
        const text = getTextContent(labelFor);
        if (text) return text;
      }
    }

    // 4. wrapping label
    const wrappingLabel = field.closest('label');
    if (wrappingLabel) {
      // Clone and remove the field itself to get just the label text
      const clone = wrappingLabel.cloneNode(true);
      const inputs = clone.querySelectorAll('input, select, textarea');
      inputs.forEach(inp => inp.remove());
      const text = (clone.textContent || '').trim().replace(/\\s+/g, ' ');
      if (text) return text;
    }

    // 5. placeholder
    const placeholder = field.getAttribute('placeholder');
    if (placeholder && placeholder.trim()) return placeholder.trim();

    // 6. title
    const title = field.getAttribute('title');
    if (title && title.trim()) return title.trim();

    // 7. humanized name
    const name = field.getAttribute('name');
    if (name) return humanizeName(name);

    return '';
  }

  // --- form detection ---
  let form;
  const selectorArg = ${JSON.stringify(selector ?? null)};

  if (selectorArg) {
    form = document.querySelector(selectorArg);
    if (!form) return { error: 'Form not found for selector: ' + selectorArg };
  } else {
    const forms = Array.from(document.querySelectorAll('form'));
    if (forms.length === 1) {
      form = forms[0];
    } else if (forms.length > 1) {
      // Pick form with most visible inputs
      let best = null;
      let bestCount = -1;
      for (const f of forms) {
        const inputs = f.querySelectorAll('input, select, textarea');
        let visCount = 0;
        inputs.forEach(inp => { if (isVisible(inp)) visCount++; });
        if (visCount > bestCount) {
          bestCount = visCount;
          best = f;
        }
      }
      form = best;
    } else {
      form = document.body;
    }
  }

  // Build a useful selector for the form
  let formSelector = 'body';
  if (form !== document.body) {
    if (form.id) {
      formSelector = '#' + CSS.escape(form.id);
    } else if (form.getAttribute('name')) {
      formSelector = 'form[name="' + form.getAttribute('name') + '"]';
    } else if (form.getAttribute('action')) {
      formSelector = 'form[action="' + form.getAttribute('action') + '"]';
    } else {
      // nth-of-type fallback
      const allForms = Array.from(document.querySelectorAll('form'));
      const idx = allForms.indexOf(form);
      formSelector = idx >= 0 ? 'form:nth-of-type(' + (idx + 1) + ')' : 'form';
    }
  }

  // --- field inventory ---
  const fieldElements = form.querySelectorAll('input, select, textarea');
  const fields = [];

  fieldElements.forEach(field => {
    const tag = field.tagName.toLowerCase();
    const type = tag === 'select' ? 'select'
      : tag === 'textarea' ? 'textarea'
      : (field.getAttribute('type') || 'text').toLowerCase();

    // Skip submit/button/reset/image inputs — they're not data fields
    if (tag === 'input' && ['submit', 'button', 'reset', 'image'].includes(type)) return;

    const label = resolveLabel(field);
    const name = field.getAttribute('name') || '';
    const id = field.id || '';
    const required = field.required || field.getAttribute('aria-required') === 'true';
    const hidden = type === 'hidden' || !isVisible(field);
    const disabled = field.disabled;

    // Value
    let value = '';
    if (tag === 'select') {
      const selected = field.querySelector('option:checked');
      value = selected ? selected.value : '';
    } else {
      value = field.value || '';
    }

    const info = {
      type,
      name,
      id,
      label,
      required,
      value,
      hidden,
      disabled,
      validation: {
        valid: field.validity ? field.validity.valid : true,
        message: field.validationMessage || '',
      },
    };

    // Checked state for checkboxes/radios
    if (type === 'checkbox' || type === 'radio') {
      info.checked = field.checked;
    }

    // Options for select elements
    if (tag === 'select') {
      info.options = Array.from(field.querySelectorAll('option')).map(opt => ({
        value: opt.value,
        label: opt.textContent.trim(),
        selected: opt.selected,
      }));
    }

    // Fieldset/legend group
    const fieldset = field.closest('fieldset');
    if (fieldset) {
      const legend = fieldset.querySelector('legend');
      if (legend) {
        info.group = getTextContent(legend);
      }
    }

    fields.push(info);
  });

  // --- submit buttons ---
  const submitButtons = [];
  const buttonCandidates = form.querySelectorAll('button, input[type="submit"]');
  buttonCandidates.forEach(btn => {
    const tag = btn.tagName.toLowerCase();
    const type = (btn.getAttribute('type') || (tag === 'button' ? 'submit' : '')).toLowerCase();
    // Include: explicit submit, or button without explicit type (defaults to submit)
    if (type === 'submit' || (tag === 'button' && !btn.getAttribute('type'))) {
      submitButtons.push({
        tag,
        type: type || 'submit',
        text: tag === 'input' ? (btn.value || '') : getTextContent(btn),
        name: btn.getAttribute('name') || '',
        disabled: btn.disabled,
      });
    }
  });

  const visibleFieldCount = fields.filter(f => !f.hidden).length;

  return {
    formSelector,
    fields,
    submitButtons,
    fieldCount: fields.length,
    visibleFieldCount,
  };
})()`;
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Post-fill validation collection — runs in browser context.
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function buildPostFillValidationScript(formSelector: string): string {
|
||||
return `(() => {
|
||||
const form = ${JSON.stringify(formSelector)} === 'body'
|
||||
? document.body
|
||||
: document.querySelector(${JSON.stringify(formSelector)});
|
||||
if (!form) return { valid: false, invalidCount: 0, fields: [] };
|
||||
|
||||
const fieldEls = form.querySelectorAll('input, select, textarea');
|
||||
let validCount = 0;
|
||||
let invalidCount = 0;
|
||||
const invalidFields = [];
|
||||
|
||||
fieldEls.forEach(f => {
|
||||
const tag = f.tagName.toLowerCase();
|
||||
const type = tag === 'select' ? 'select'
|
||||
: tag === 'textarea' ? 'textarea'
|
||||
: (f.getAttribute('type') || 'text').toLowerCase();
|
||||
if (['submit', 'button', 'reset', 'image', 'hidden'].includes(type)) return;
|
||||
|
||||
if (f.validity && !f.validity.valid) {
|
||||
invalidCount++;
|
||||
invalidFields.push({
|
||||
name: f.getAttribute('name') || f.id || type,
|
||||
message: f.validationMessage || 'Invalid',
|
||||
});
|
||||
} else {
|
||||
validCount++;
|
||||
}
|
||||
});
|
||||
|
||||
return {
|
||||
valid: invalidCount === 0,
|
||||
validCount,
|
||||
invalidCount,
|
||||
invalidFields,
|
||||
};
|
||||
})()`;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Registration
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export function registerFormTools(pi: ExtensionAPI, deps: ToolDeps): void {
|
||||
// -----------------------------------------------------------------------
|
||||
// browser_analyze_form
|
||||
// -----------------------------------------------------------------------
|
||||
pi.registerTool({
|
||||
name: "browser_analyze_form",
|
||||
label: "Analyze Form",
|
||||
description:
|
||||
"Analyze a form on the current page and return a structured field inventory. Auto-detects the form if no selector is provided (picks the single <form>, or the form with most visible inputs, or falls back to document.body). Returns field types, labels (resolved via aria-labelledby → aria-label → label[for] → wrapping label → placeholder → title → name), values, validation state, and submit buttons.",
|
||||
parameters: Type.Object({
|
||||
selector: Type.Optional(
|
||||
Type.String({
|
||||
description:
|
||||
"CSS selector targeting the form element to analyze. If omitted, auto-detects the primary form on the page.",
|
||||
}),
|
||||
),
|
||||
}),
|
||||
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
let actionId: number | null = null;
|
||||
let beforeState: CompactPageState | null = null;
|
||||
try {
|
||||
const { page: p } = await deps.ensureBrowser();
|
||||
const target = deps.getActiveTarget();
|
||||
beforeState = await deps.captureCompactPageState(p, {
|
||||
selectors: params.selector ? [params.selector] : [],
|
||||
includeBodyText: false,
|
||||
target,
|
||||
});
|
||||
actionId = deps.beginTrackedAction(
|
||||
"browser_analyze_form",
|
||||
params,
|
||||
beforeState.url,
|
||||
).id;
|
||||
|
||||
const script = buildFormAnalysisScript(params.selector);
|
||||
const result = (await target.evaluate(script)) as FormAnalysisResult & {
|
||||
error?: string;
|
||||
};
|
||||
|
||||
if (result.error) {
|
||||
deps.finishTrackedAction(actionId!, {
|
||||
status: "error",
|
||||
error: result.error,
|
||||
beforeState,
|
||||
});
|
||||
return {
|
||||
content: [{ type: "text" as const, text: result.error }],
|
||||
details: {},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
|
||||
const afterState = await deps.captureCompactPageState(p, {
|
||||
selectors: params.selector ? [params.selector] : [],
|
||||
includeBodyText: false,
|
||||
target,
|
||||
});
|
||||
setLastActionBeforeState(beforeState);
|
||||
setLastActionAfterState(afterState);
|
||||
|
||||
deps.finishTrackedAction(actionId!, {
|
||||
status: "success",
|
||||
afterUrl: afterState.url,
|
||||
beforeState,
|
||||
afterState,
|
||||
});
|
||||
|
||||
// Format output
|
||||
const lines: string[] = [];
|
||||
lines.push(`Form: ${result.formSelector}`);
|
||||
lines.push(
|
||||
`Fields: ${result.fieldCount} total, ${result.visibleFieldCount} visible`,
|
||||
);
|
||||
lines.push(`Submit buttons: ${result.submitButtons.length}`);
|
||||
lines.push("");
|
||||
|
||||
if (result.fields.length > 0) {
|
||||
lines.push("## Fields");
|
||||
for (const f of result.fields) {
|
||||
const flags: string[] = [];
|
||||
if (f.required) flags.push("required");
|
||||
if (f.hidden) flags.push("hidden");
|
||||
if (f.disabled) flags.push("disabled");
|
||||
if (f.checked !== undefined)
|
||||
flags.push(f.checked ? "checked" : "unchecked");
|
||||
if (!f.validation.valid)
|
||||
flags.push(`invalid: ${f.validation.message}`);
|
||||
|
||||
const flagStr = flags.length ? ` [${flags.join(", ")}]` : "";
|
||||
const valueStr = f.value ? ` = "${f.value}"` : "";
|
||||
const labelStr = f.label || "(no label)";
|
||||
const selectorHint = f.id
|
||||
? `#${f.id}`
|
||||
: f.name
|
||||
? `[name="${f.name}"]`
|
||||
: f.type;
|
||||
const groupStr = f.group ? ` (group: ${f.group})` : "";
|
||||
|
||||
lines.push(
|
||||
`- **${labelStr}** \`${f.type}\` \`${selectorHint}\`${valueStr}${flagStr}${groupStr}`,
|
||||
);
|
||||
|
||||
if (f.options && f.options.length > 0) {
|
||||
for (const opt of f.options) {
|
||||
const sel = opt.selected ? " ✓" : "";
|
||||
lines.push(` - ${opt.label} (${opt.value})${sel}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
lines.push("");
|
||||
}
|
||||
|
||||
if (result.submitButtons.length > 0) {
|
||||
lines.push("## Submit Buttons");
|
||||
for (const btn of result.submitButtons) {
|
||||
const disStr = btn.disabled ? " [disabled]" : "";
|
||||
lines.push(
|
||||
`- "${btn.text}" \`<${btn.tag} type="${btn.type}">\`${btn.name ? ` name="${btn.name}"` : ""}${disStr}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
content: [{ type: "text" as const, text: lines.join("\n") }],
|
||||
details: { formAnalysis: result },
|
||||
};
|
||||
} catch (err: unknown) {
|
||||
const screenshot = await deps.captureErrorScreenshot(
|
||||
(() => {
|
||||
try {
|
||||
return deps.getActivePage();
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
})(),
|
||||
);
|
||||
const errMsg = deps.firstErrorLine(err);
|
||||
|
||||
if (actionId !== null) {
|
||||
deps.finishTrackedAction(actionId, {
|
||||
status: "error",
|
||||
error: errMsg,
|
||||
beforeState: beforeState ?? undefined,
|
||||
});
|
||||
}
|
||||
|
||||
const content: Array<
|
||||
| { type: "text"; text: string }
|
||||
| { type: "image"; data: string; mimeType: string }
|
||||
> = [{ type: "text", text: `browser_analyze_form failed: ${errMsg}` }];
|
||||
if (screenshot) {
|
||||
content.push({
|
||||
type: "image",
|
||||
data: screenshot.data,
|
||||
mimeType: screenshot.mimeType,
|
||||
});
|
||||
}
|
||||
|
||||
return { content, details: {}, isError: true };
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
// -----------------------------------------------------------------------
|
||||
// browser_fill_form
|
||||
// -----------------------------------------------------------------------
|
||||
pi.registerTool({
|
||||
name: "browser_fill_form",
|
||||
label: "Fill Form",
|
||||
description:
|
||||
"Fill a form on the current page using a values mapping. Keys are field identifiers (label text, name attribute, placeholder, or aria-label). Resolves fields by label → name → placeholder → aria-label (exact first, then case-insensitive). Uses fill() for text inputs, selectOption() for selects, setChecked() for checkboxes/radios. Skips file and hidden inputs. Optionally submits the form.",
|
||||
parameters: Type.Object({
|
||||
selector: Type.Optional(
|
||||
Type.String({
|
||||
description:
|
||||
"CSS selector targeting the form element. If omitted, auto-detects the primary form.",
|
||||
}),
|
||||
),
|
||||
values: Type.Record(Type.String(), Type.String(), {
|
||||
description:
|
||||
"Mapping of field identifiers to values. Keys can be label text, name, placeholder, or aria-label. Values are strings — for checkboxes use 'true'/'false' or 'on'/'off', for selects use the option label or value.",
|
||||
}),
|
||||
submit: Type.Optional(
|
||||
Type.Boolean({
|
||||
description:
|
||||
"If true, clicks the form's submit button after filling all fields.",
|
||||
}),
|
||||
),
|
||||
}),
|
||||
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
let actionId: number | null = null;
|
||||
let beforeState: CompactPageState | null = null;
|
||||
try {
|
||||
const { page: p } = await deps.ensureBrowser();
|
||||
const target = deps.getActiveTarget();
|
||||
beforeState = await deps.captureCompactPageState(p, {
|
||||
selectors: params.selector ? [params.selector] : [],
|
||||
includeBodyText: false,
|
||||
target,
|
||||
});
|
||||
actionId = deps.beginTrackedAction(
|
||||
"browser_fill_form",
|
||||
params,
|
||||
beforeState.url,
|
||||
).id;
|
||||
|
||||
// --- Detect form selector ---
|
||||
// Reuse the same detection logic as analyze_form via a lightweight evaluate
|
||||
const formSelector: string =
|
||||
params.selector ??
|
||||
((await target.evaluate(`(() => {
|
||||
const forms = Array.from(document.querySelectorAll('form'));
|
||||
if (forms.length === 1) {
|
||||
const f = forms[0];
|
||||
if (f.id) return '#' + CSS.escape(f.id);
|
||||
if (f.getAttribute('name')) return 'form[name="' + f.getAttribute('name') + '"]';
|
||||
return 'form';
|
||||
} else if (forms.length > 1) {
|
||||
let best = null;
|
||||
let bestCount = -1;
|
||||
let bestIdx = 0;
|
||||
for (let i = 0; i < forms.length; i++) {
|
||||
const inputs = forms[i].querySelectorAll('input, select, textarea');
|
||||
let vis = 0;
|
||||
inputs.forEach(inp => {
|
||||
const s = window.getComputedStyle(inp);
|
||||
if (s.display !== 'none' && s.visibility !== 'hidden') vis++;
|
||||
});
|
||||
if (vis > bestCount) { bestCount = vis; best = forms[i]; bestIdx = i; }
|
||||
}
|
||||
if (best.id) return '#' + CSS.escape(best.id);
|
||||
if (best.getAttribute('name')) return 'form[name="' + best.getAttribute('name') + '"]';
|
||||
return 'form:nth-of-type(' + (bestIdx + 1) + ')';
|
||||
}
|
||||
return 'body';
|
||||
})()`)) as string);
|
||||
|
||||
const formLocator =
|
||||
formSelector === "body"
|
||||
? target.locator("body")
|
||||
: target.locator(formSelector);
|
||||
|
||||
// --- Resolve and fill each field ---
|
||||
interface MatchedField {
|
||||
key: string;
|
||||
resolvedBy: string;
|
||||
value: string;
|
||||
fieldType: string;
|
||||
}
|
||||
interface UnmatchedField {
|
||||
key: string;
|
||||
reason: string;
|
||||
}
|
||||
interface SkippedField {
|
||||
key: string;
|
||||
reason: string;
|
||||
}
|
||||
|
||||
const matched: MatchedField[] = [];
|
||||
const unmatched: UnmatchedField[] = [];
|
||||
const skipped: SkippedField[] = [];
|
||||
|
||||
for (const [key, value] of Object.entries(params.values)) {
|
||||
// Try to resolve the field in priority order
|
||||
let resolvedLocator: ReturnType<typeof formLocator.locator> | null =
|
||||
null;
|
||||
let resolvedBy = "";
|
||||
|
||||
// 1. Exact label match
|
||||
try {
|
||||
const loc = formLocator.getByLabel(key, { exact: true });
|
||||
const count = await loc.count();
|
||||
if (count === 1) {
|
||||
resolvedLocator = loc;
|
||||
resolvedBy = "label (exact)";
|
||||
} else if (count > 1) {
|
||||
skipped.push({
|
||||
key,
|
||||
reason: `Ambiguous: ${count} fields match label "${key}"`,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
} catch {
|
||||
/* not found, try next */
|
||||
}
|
||||
|
||||
// 2. Case-insensitive label match
|
||||
if (!resolvedLocator) {
|
||||
try {
|
||||
const loc = formLocator.getByLabel(key);
|
||||
const count = await loc.count();
|
||||
if (count === 1) {
|
||||
resolvedLocator = loc;
|
||||
resolvedBy = "label";
|
||||
} else if (count > 1) {
|
||||
skipped.push({
|
||||
key,
|
||||
reason: `Ambiguous: ${count} fields match label "${key}" (case-insensitive)`,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
} catch {
|
||||
/* not found, try next */
|
||||
}
|
||||
}
|
||||
|
||||
// 3. name attribute
|
||||
if (!resolvedLocator) {
|
||||
try {
|
||||
const loc = formLocator.locator(`[name="${CSS.escape(key)}"]`);
|
||||
const count = await loc.count();
|
||||
if (count === 1) {
|
||||
resolvedLocator = loc;
|
||||
resolvedBy = "name";
|
||||
} else if (count > 1) {
|
||||
skipped.push({
|
||||
key,
|
||||
reason: `Ambiguous: ${count} fields match name="${key}"`,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
} catch {
|
||||
/* not found, try next */
|
||||
}
|
||||
}
|
||||
|
||||
// 4. placeholder attribute (case-insensitive)
|
||||
if (!resolvedLocator) {
|
||||
try {
|
||||
const loc = formLocator.locator(`[placeholder="${key}" i]`);
|
||||
const count = await loc.count();
|
||||
if (count === 1) {
|
||||
resolvedLocator = loc;
|
||||
resolvedBy = "placeholder";
|
||||
} else if (count > 1) {
|
||||
skipped.push({
|
||||
key,
|
||||
reason: `Ambiguous: ${count} fields match placeholder="${key}"`,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
} catch {
|
||||
/* not found, try next */
|
||||
}
|
||||
}
|
||||
|
||||
// 5. aria-label attribute (case-insensitive)
|
||||
if (!resolvedLocator) {
|
||||
try {
|
||||
const loc = formLocator.locator(`[aria-label="${key}" i]`);
|
||||
const count = await loc.count();
|
||||
if (count === 1) {
|
||||
resolvedLocator = loc;
|
||||
resolvedBy = "aria-label";
|
||||
} else if (count > 1) {
|
||||
skipped.push({
|
||||
key,
|
||||
reason: `Ambiguous: ${count} fields match aria-label="${key}"`,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
} catch {
|
||||
/* not found, try next */
|
||||
}
|
||||
}
|
||||
|
||||
if (!resolvedLocator) {
|
||||
unmatched.push({ key, reason: "No matching field found" });
|
||||
continue;
|
||||
}
|
||||
|
||||
// Determine field type
|
||||
const fieldInfo = await resolvedLocator
|
||||
.first()
|
||||
.evaluate((el: Element) => {
|
||||
const tag = el.tagName.toLowerCase();
|
||||
const type =
|
||||
tag === "select"
|
||||
? "select"
|
||||
: tag === "textarea"
|
||||
? "textarea"
|
||||
: ((el as HTMLInputElement).type || "text").toLowerCase();
|
||||
const hidden =
|
||||
type === "hidden" ||
|
||||
window.getComputedStyle(el).display === "none" ||
|
||||
window.getComputedStyle(el).visibility === "hidden";
|
||||
return { tag, type, hidden };
|
||||
});
|
||||
|
||||
// Skip file inputs
|
||||
if (fieldInfo.type === "file") {
|
||||
skipped.push({
|
||||
key,
|
||||
reason: "File input — use browser_upload_file instead",
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip hidden inputs
|
||||
if (fieldInfo.hidden) {
|
||||
skipped.push({ key, reason: "Hidden input" });
|
||||
continue;
|
||||
}
|
||||
|
||||
// Fill based on type
|
||||
try {
|
||||
if (fieldInfo.type === "checkbox" || fieldInfo.type === "radio") {
|
||||
const checked = value === "true" || value === "on";
|
||||
await resolvedLocator
|
||||
.first()
|
||||
.setChecked(checked, { timeout: 5000 });
|
||||
matched.push({
|
||||
key,
|
||||
resolvedBy,
|
||||
value: checked ? "checked" : "unchecked",
|
||||
fieldType: fieldInfo.type,
|
||||
});
|
||||
} else if (fieldInfo.tag === "select") {
|
||||
// Try label first, then value
|
||||
try {
|
||||
await resolvedLocator
|
||||
.first()
|
||||
.selectOption({ label: value }, { timeout: 5000 });
|
||||
} catch {
|
||||
await resolvedLocator
|
||||
.first()
|
||||
.selectOption({ value }, { timeout: 5000 });
|
||||
}
|
||||
matched.push({ key, resolvedBy, value, fieldType: "select" });
|
||||
} else {
|
||||
// Text-like inputs and textarea
|
||||
await resolvedLocator.first().fill(value, { timeout: 5000 });
|
||||
matched.push({
|
||||
key,
|
||||
resolvedBy,
|
||||
value,
|
||||
fieldType: fieldInfo.type,
|
||||
});
|
||||
}
|
||||
} catch (fillErr: unknown) {
|
||||
const msg =
|
||||
fillErr instanceof Error ? fillErr.message : String(fillErr);
|
||||
skipped.push({ key, reason: `Fill failed: ${msg.split("\n")[0]}` });
|
||||
}
|
||||
}
|
||||
|
||||
// --- Settle after all fills ---
|
||||
await deps.settleAfterActionAdaptive(p);
|
||||
|
||||
// --- Submit if requested ---
|
||||
let submitted = false;
|
||||
if (params.submit) {
|
||||
try {
|
||||
// Find submit button in form
|
||||
const submitLoc = formLocator
|
||||
.locator('[type="submit"], button:not([type])')
|
||||
.first();
|
||||
const submitExists = await submitLoc.count();
|
||||
if (submitExists > 0) {
|
||||
await submitLoc.click({ timeout: 5000 });
|
||||
await deps.settleAfterActionAdaptive(p);
|
||||
submitted = true;
|
||||
} else {
|
||||
skipped.push({
|
||||
key: "_submit",
|
||||
reason: "No submit button found in form",
|
||||
});
|
||||
}
|
||||
} catch (submitErr: unknown) {
|
||||
const msg =
|
||||
submitErr instanceof Error
|
||||
? submitErr.message
|
||||
: String(submitErr);
|
||||
skipped.push({
|
||||
key: "_submit",
|
||||
reason: `Submit failed: ${msg.split("\n")[0]}`,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// --- Post-fill validation state ---
|
||||
const validationSummary = (await target.evaluate(
|
||||
buildPostFillValidationScript(formSelector),
|
||||
)) as {
|
||||
valid: boolean;
|
||||
validCount: number;
|
||||
invalidCount: number;
|
||||
invalidFields: Array<{ name: string; message: string }>;
|
||||
};
|
||||
|
||||
const afterState = await deps.captureCompactPageState(p, {
|
||||
selectors: params.selector ? [params.selector] : [],
|
||||
includeBodyText: false,
|
||||
target,
|
||||
});
|
||||
setLastActionBeforeState(beforeState);
|
||||
setLastActionAfterState(afterState);
|
||||
|
||||
deps.finishTrackedAction(actionId!, {
|
||||
status: "success",
|
||||
afterUrl: afterState.url,
|
||||
beforeState,
|
||||
afterState,
|
||||
});
|
||||
|
||||
// --- Format output ---
|
||||
const lines: string[] = [];
|
||||
lines.push(`Form: ${formSelector}`);
|
||||
lines.push(
|
||||
`Filled: ${matched.length} | Unmatched: ${unmatched.length} | Skipped: ${skipped.length}${submitted ? " | Submitted: yes" : ""}`,
|
||||
);
|
||||
lines.push("");
|
||||
|
||||
if (matched.length > 0) {
|
||||
lines.push("## Matched");
|
||||
for (const m of matched) {
|
||||
lines.push(
|
||||
`- ✓ **${m.key}** → "${m.value}" (${m.fieldType}, resolved by ${m.resolvedBy})`,
|
||||
);
|
||||
}
|
||||
lines.push("");
|
||||
}
|
||||
|
||||
if (unmatched.length > 0) {
|
||||
lines.push("## Unmatched");
|
||||
for (const u of unmatched) {
|
||||
lines.push(`- ✗ **${u.key}** — ${u.reason}`);
|
||||
}
|
||||
lines.push("");
|
||||
}
|
||||
|
||||
if (skipped.length > 0) {
|
||||
lines.push("## Skipped");
|
||||
for (const s of skipped) {
|
||||
lines.push(`- ⊘ **${s.key}** — ${s.reason}`);
|
||||
}
|
||||
lines.push("");
|
||||
}
|
||||
|
||||
if (!validationSummary.valid) {
|
||||
lines.push("## Validation Issues");
|
||||
for (const inv of validationSummary.invalidFields) {
|
||||
lines.push(`- ${inv.name}: ${inv.message}`);
|
||||
}
|
||||
} else {
|
||||
lines.push("Validation: all fields valid ✓");
|
||||
}
|
||||
|
||||
const fillResult = {
|
||||
matched,
|
||||
unmatched,
|
||||
skipped,
|
||||
submitted,
|
||||
validationSummary,
|
||||
};
|
||||
|
||||
return {
|
||||
content: [{ type: "text" as const, text: lines.join("\n") }],
|
||||
details: { fillResult },
|
||||
};
|
||||
} catch (err: unknown) {
|
||||
const screenshot = await deps.captureErrorScreenshot(
|
||||
(() => {
|
||||
try {
|
||||
return deps.getActivePage();
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
})(),
|
||||
);
|
||||
const errMsg = deps.firstErrorLine(err);
|
||||
|
||||
if (actionId !== null) {
|
||||
deps.finishTrackedAction(actionId, {
|
||||
status: "error",
|
||||
error: errMsg,
|
||||
beforeState: beforeState ?? undefined,
|
||||
});
|
||||
}
|
||||
|
||||
const content: Array<
|
||||
| { type: "text"; text: string }
|
||||
| { type: "image"; data: string; mimeType: string }
|
||||
> = [{ type: "text", text: `browser_fill_form failed: ${errMsg}` }];
|
||||
if (screenshot) {
|
||||
content.push({
|
||||
type: "image",
|
||||
data: screenshot.data,
|
||||
mimeType: screenshot.mimeType,
|
||||
});
|
||||
}
|
||||
|
||||
return { content, details: {}, isError: true };
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
|
|
@ -1,337 +0,0 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import type { ToolDeps } from "../state.js";
|
||||
|
||||
/**
|
||||
* Prompt injection detection — scan page content for text attempting to hijack the agent.
|
||||
*/
|
||||
|
||||
// Known injection patterns — regex patterns that match common prompt injection attempts
|
||||
// Each entry pairs a detection regex with a category label (reported in
// findings) and a severity used for the summary counts. Patterns are matched
// with `String.prototype.match` against each extracted text region; order
// only affects the order findings are reported in, not whether they match.
const INJECTION_PATTERNS: Array<{
  pattern: RegExp;
  category: string;
  severity: "high" | "medium" | "low";
}> = [
  // Direct instruction override attempts
  {
    pattern: /ignore\s+(all\s+)?previous\s+(instructions?|prompts?)/i,
    category: "instruction_override",
    severity: "high",
  },
  {
    pattern: /disregard\s+(all\s+)?previous\s+(instructions?|prompts?)/i,
    category: "instruction_override",
    severity: "high",
  },
  {
    pattern: /forget\s+(all\s+)?previous\s+(instructions?|prompts?)/i,
    category: "instruction_override",
    severity: "high",
  },
  {
    pattern: /override\s+(all\s+)?previous\s+(instructions?|prompts?)/i,
    category: "instruction_override",
    severity: "high",
  },

  // System prompt extraction
  {
    pattern:
      /(?:what|show|reveal|display|repeat|tell)\s+(?:me\s+)?(?:your|the)\s+system\s+prompt/i,
    category: "prompt_extraction",
    severity: "high",
  },
  {
    pattern: /print\s+(?:your|the)\s+(?:system\s+)?(?:prompt|instructions)/i,
    category: "prompt_extraction",
    severity: "high",
  },

  // Role reassignment
  {
    pattern:
      /you\s+are\s+now\s+(?:a\s+)?(?:DAN|jailbreak|unrestricted|unfiltered)/i,
    category: "role_reassignment",
    severity: "high",
  },
  {
    pattern: /act\s+as\s+(?:a\s+)?(?:DAN|jailbreak|unrestricted|unfiltered)/i,
    category: "role_reassignment",
    severity: "high",
  },
  {
    pattern: /entering\s+(?:a\s+)?(?:developer|admin|root|sudo)\s+mode/i,
    category: "role_reassignment",
    severity: "high",
  },

  // Hidden instruction injection — text masquerading as chat-protocol framing
  {
    pattern: /\[SYSTEM\]\s*:/i,
    category: "fake_system_message",
    severity: "high",
  },
  {
    pattern: /\[INST\]\s*:/i,
    category: "fake_system_message",
    severity: "medium",
  },
  {
    pattern: /<\/?system>/i,
    category: "fake_system_message",
    severity: "high",
  },

  // Tool/action manipulation
  {
    pattern: /execute\s+(?:the\s+following\s+)?(?:command|code|script)/i,
    category: "command_injection",
    severity: "medium",
  },
  {
    pattern: /run\s+(?:this|the\s+following)\s+(?:command|code|script)/i,
    category: "command_injection",
    severity: "medium",
  },

  // Invisible text / social engineering
  {
    pattern:
      /do\s+not\s+(?:read|process|show)\s+(?:the\s+)?(?:following|rest)/i,
    category: "social_engineering",
    severity: "low",
  },
  {
    pattern:
      /(?:this|the\s+following)\s+(?:is|are)\s+(?:your\s+)?new\s+instructions/i,
    category: "instruction_override",
    severity: "high",
  },

  // Base64/encoded content markers — a long base64 run after a "base64:" label
  {
    pattern: /base64\s*:\s*[A-Za-z0-9+/=]{50,}/i,
    category: "encoded_payload",
    severity: "medium",
  },
];
|
||||
|
||||
export function registerInjectionDetectionTools(
|
||||
pi: ExtensionAPI,
|
||||
deps: ToolDeps,
|
||||
): void {
|
||||
pi.registerTool({
|
||||
name: "browser_check_injection",
|
||||
label: "Browser Check Injection",
|
||||
description:
|
||||
"Scan current page content for potential prompt injection attempts. " +
|
||||
"Checks visible text and hidden elements for patterns that might hijack the agent. " +
|
||||
"Returns findings with severity levels. Use after navigating to untrusted pages.",
|
||||
parameters: Type.Object({
|
||||
includeHidden: Type.Optional(
|
||||
Type.Boolean({
|
||||
description:
|
||||
"Also scan hidden/invisible text (default: true). " +
|
||||
"Hidden text is a common vector for injection attacks.",
|
||||
}),
|
||||
),
|
||||
}),
|
||||
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
try {
|
||||
const { page: p } = await deps.ensureBrowser();
|
||||
const includeHidden = params.includeHidden ?? true;
|
||||
|
||||
// Extract text content from the page
|
||||
const pageContent = await p.evaluate((scanHidden: boolean) => {
|
||||
const results: Array<{
|
||||
text: string;
|
||||
source: string;
|
||||
visible: boolean;
|
||||
}> = [];
|
||||
|
||||
// 1. Visible text content
|
||||
const bodyText = document.body?.innerText ?? "";
|
||||
results.push({
|
||||
text: bodyText,
|
||||
source: "body_visible_text",
|
||||
visible: true,
|
||||
});
|
||||
|
||||
// 2. Title and meta
|
||||
results.push({
|
||||
text: document.title,
|
||||
source: "page_title",
|
||||
visible: true,
|
||||
});
|
||||
|
||||
// Meta descriptions and keywords
|
||||
const metas = document.querySelectorAll("meta[name], meta[property]");
|
||||
for (const meta of metas) {
|
||||
const content = meta.getAttribute("content");
|
||||
if (content) {
|
||||
results.push({
|
||||
text: content,
|
||||
source: `meta:${meta.getAttribute("name") || meta.getAttribute("property")}`,
|
||||
visible: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if (scanHidden) {
|
||||
// 3. Hidden elements (display:none, visibility:hidden, opacity:0, off-screen, aria-hidden)
|
||||
const allElements = document.querySelectorAll("*");
|
||||
for (const el of allElements) {
|
||||
const htmlEl = el as HTMLElement;
|
||||
const style = window.getComputedStyle(htmlEl);
|
||||
const isHidden =
|
||||
style.display === "none" ||
|
||||
style.visibility === "hidden" ||
|
||||
style.opacity === "0" ||
|
||||
htmlEl.getAttribute("aria-hidden") === "true" ||
|
||||
(htmlEl.offsetWidth === 0 && htmlEl.offsetHeight === 0);
|
||||
|
||||
if (isHidden && htmlEl.textContent?.trim()) {
|
||||
const text = htmlEl.textContent.trim();
|
||||
if (text.length > 5 && text.length < 5000) {
|
||||
results.push({
|
||||
text,
|
||||
source: "hidden_element",
|
||||
visible: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 4. HTML comments
|
||||
const walker = document.createTreeWalker(
|
||||
document.documentElement,
|
||||
NodeFilter.SHOW_COMMENT,
|
||||
);
|
||||
let node: Node | null;
|
||||
// biome-ignore lint/suspicious/noAssignInExpressions: read-loop pattern
|
||||
while ((node = walker.nextNode())) {
|
||||
const text = (node as Comment).textContent?.trim() ?? "";
|
||||
if (text.length > 10) {
|
||||
results.push({ text, source: "html_comment", visible: false });
|
||||
}
|
||||
}
|
||||
|
||||
// 5. Data attributes with text content
|
||||
const dataElements = document.querySelectorAll(
|
||||
"[data-prompt], [data-instruction], [data-system]",
|
||||
);
|
||||
for (const el of dataElements) {
|
||||
for (const attr of el.attributes) {
|
||||
if (attr.name.startsWith("data-") && attr.value.length > 10) {
|
||||
results.push({
|
||||
text: attr.value,
|
||||
source: `data_attribute:${attr.name}`,
|
||||
visible: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return results;
|
||||
}, includeHidden);
|
||||
|
||||
// Scan all extracted text against injection patterns
|
||||
const findings: Array<{
|
||||
pattern: string;
|
||||
category: string;
|
||||
severity: string;
|
||||
source: string;
|
||||
visible: boolean;
|
||||
matchedText: string;
|
||||
}> = [];
|
||||
|
||||
for (const { text, source, visible } of pageContent) {
|
||||
for (const { pattern, category, severity } of INJECTION_PATTERNS) {
|
||||
const match = text.match(pattern);
|
||||
if (match) {
|
||||
findings.push({
|
||||
pattern: pattern.source.slice(0, 60),
|
||||
category,
|
||||
severity,
|
||||
source,
|
||||
visible,
|
||||
matchedText: match[0].slice(0, 100),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Deduplicate findings by category + source
|
||||
const seen = new Set<string>();
|
||||
const uniqueFindings = findings.filter((f) => {
|
||||
const key = `${f.category}|${f.source}|${f.matchedText}`;
|
||||
if (seen.has(key)) return false;
|
||||
seen.add(key);
|
||||
return true;
|
||||
});
|
||||
|
||||
const highCount = uniqueFindings.filter(
|
||||
(f) => f.severity === "high",
|
||||
).length;
|
||||
const medCount = uniqueFindings.filter(
|
||||
(f) => f.severity === "medium",
|
||||
).length;
|
||||
const lowCount = uniqueFindings.filter(
|
||||
(f) => f.severity === "low",
|
||||
).length;
|
||||
|
||||
if (uniqueFindings.length === 0) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `No prompt injection patterns detected.\nScanned: ${pageContent.length} text regions (hidden: ${includeHidden})`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
clean: true,
|
||||
scannedRegions: pageContent.length,
|
||||
includeHidden,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const findingLines = uniqueFindings.map(
|
||||
(f) =>
|
||||
` [${f.severity.toUpperCase()}] ${f.category} in ${f.source}${!f.visible ? " (HIDDEN)" : ""}: "${f.matchedText}"`,
|
||||
);
|
||||
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `⚠️ Prompt injection patterns detected: ${uniqueFindings.length} finding(s)\nHigh: ${highCount} | Medium: ${medCount} | Low: ${lowCount}\n\n${findingLines.join("\n")}\n\n⚠️ This page may be attempting to manipulate the agent. Proceed with caution.`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
clean: false,
|
||||
findings: uniqueFindings,
|
||||
counts: {
|
||||
high: highCount,
|
||||
medium: medCount,
|
||||
low: lowCount,
|
||||
total: uniqueFindings.length,
|
||||
},
|
||||
scannedRegions: pageContent.length,
|
||||
includeHidden,
|
||||
},
|
||||
};
|
||||
} catch (err: any) {
|
||||
return {
|
||||
content: [
|
||||
{ type: "text", text: `Injection check failed: ${err.message}` },
|
||||
],
|
||||
details: { error: err.message },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
|
|
@ -1,549 +0,0 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import { StringEnum } from "@singularity-forge/pi-ai";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import type { ToolDeps } from "../state.js";
|
||||
import {
|
||||
getConsoleLogs,
|
||||
getDialogLogs,
|
||||
getNetworkLogs,
|
||||
setConsoleLogs,
|
||||
setDialogLogs,
|
||||
setNetworkLogs,
|
||||
} from "../state.js";
|
||||
|
||||
export function registerInspectionTools(
|
||||
pi: ExtensionAPI,
|
||||
deps: ToolDeps,
|
||||
): void {
|
||||
// -------------------------------------------------------------------------
|
||||
// browser_get_console_logs
|
||||
// -------------------------------------------------------------------------
|
||||
pi.registerTool({
|
||||
name: "browser_get_console_logs",
|
||||
label: "Browser Console Logs",
|
||||
description:
|
||||
"Get all buffered browser console logs and JavaScript errors captured since the last clear. Each entry includes timestamp and page URL. Note: JS errors are also auto-surfaced in interaction tool responses — use this for the full log.",
|
||||
parameters: Type.Object({
|
||||
clear: Type.Optional(
|
||||
Type.Boolean({
|
||||
description: "Clear the buffer after returning logs (default: true)",
|
||||
}),
|
||||
),
|
||||
}),
|
||||
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
const shouldClear = params.clear !== false;
|
||||
const logs = [...getConsoleLogs()];
|
||||
|
||||
if (shouldClear) {
|
||||
setConsoleLogs([]);
|
||||
}
|
||||
|
||||
if (logs.length === 0) {
|
||||
return {
|
||||
content: [{ type: "text", text: "No console logs captured." }],
|
||||
details: { logs: [], count: 0 },
|
||||
};
|
||||
}
|
||||
|
||||
const formatted = logs
|
||||
.map((entry) => {
|
||||
const time = new Date(entry.timestamp).toISOString().slice(11, 23);
|
||||
return `[${time}] [${entry.type.toUpperCase()}] ${entry.text}`;
|
||||
})
|
||||
.join("\n");
|
||||
|
||||
const truncated = deps.truncateText(formatted);
|
||||
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `${logs.length} console log(s):\n\n${truncated}`,
|
||||
},
|
||||
],
|
||||
details: { logs, count: logs.length },
|
||||
};
|
||||
},
|
||||
});
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// browser_get_network_logs
|
||||
// -------------------------------------------------------------------------
|
||||
pi.registerTool({
|
||||
name: "browser_get_network_logs",
|
||||
label: "Browser Network Logs",
|
||||
description:
|
||||
"Get buffered network requests and responses. Shows method, URL, status code, and resource type for all requests. Includes response body for failed requests (4xx/5xx). Use to debug API failures, CORS issues, missing resources, and auth problems.",
|
||||
parameters: Type.Object({
|
||||
clear: Type.Optional(
|
||||
Type.Boolean({
|
||||
description: "Clear the buffer after returning logs (default: true)",
|
||||
}),
|
||||
),
|
||||
filter: Type.Optional(
|
||||
StringEnum(["all", "errors", "fetch-xhr"] as const),
|
||||
),
|
||||
}),
|
||||
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
const shouldClear = params.clear !== false;
|
||||
let logs = [...getNetworkLogs()];
|
||||
|
||||
if (shouldClear) {
|
||||
setNetworkLogs([]);
|
||||
}
|
||||
|
||||
if (params.filter === "errors") {
|
||||
logs = logs.filter(
|
||||
(e) => e.failed || (e.status !== null && e.status >= 400),
|
||||
);
|
||||
} else if (params.filter === "fetch-xhr") {
|
||||
logs = logs.filter(
|
||||
(e) => e.resourceType === "fetch" || e.resourceType === "xhr",
|
||||
);
|
||||
}
|
||||
|
||||
if (logs.length === 0) {
|
||||
return {
|
||||
content: [{ type: "text", text: "No network requests captured." }],
|
||||
details: { logs: [], count: 0 },
|
||||
};
|
||||
}
|
||||
|
||||
const formatted = logs
|
||||
.map((entry) => {
|
||||
const time = new Date(entry.timestamp).toISOString().slice(11, 23);
|
||||
const status = entry.failed
|
||||
? `FAILED (${entry.failureText})`
|
||||
: `${entry.status}`;
|
||||
let line = `[${time}] ${entry.method} ${entry.url} → ${status} (${entry.resourceType})`;
|
||||
if (entry.responseBody) {
|
||||
line += `\n Response: ${entry.responseBody}`;
|
||||
}
|
||||
return line;
|
||||
})
|
||||
.join("\n");
|
||||
|
||||
const truncated = deps.truncateText(formatted);
|
||||
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `${logs.length} network request(s):\n\n${truncated}`,
|
||||
},
|
||||
],
|
||||
details: { count: logs.length },
|
||||
};
|
||||
},
|
||||
});
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// browser_get_dialog_logs
|
||||
// -------------------------------------------------------------------------
|
||||
pi.registerTool({
|
||||
name: "browser_get_dialog_logs",
|
||||
label: "Browser Dialog Logs",
|
||||
description:
|
||||
"Get buffered JavaScript dialog events (alert, confirm, prompt, beforeunload). Dialogs are auto-accepted to prevent page freezes. Use this to see what dialogs appeared and their messages.",
|
||||
parameters: Type.Object({
|
||||
clear: Type.Optional(
|
||||
Type.Boolean({
|
||||
description: "Clear the buffer after returning logs (default: true)",
|
||||
}),
|
||||
),
|
||||
}),
|
||||
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
const shouldClear = params.clear !== false;
|
||||
const logs = [...getDialogLogs()];
|
||||
|
||||
if (shouldClear) {
|
||||
setDialogLogs([]);
|
||||
}
|
||||
|
||||
if (logs.length === 0) {
|
||||
return {
|
||||
content: [{ type: "text", text: "No dialog events captured." }],
|
||||
details: { logs: [], count: 0 },
|
||||
};
|
||||
}
|
||||
|
||||
const formatted = logs
|
||||
.map((entry) => {
|
||||
const time = new Date(entry.timestamp).toISOString().slice(11, 23);
|
||||
let line = `[${time}] ${entry.type}: "${entry.message}"`;
|
||||
if (entry.defaultValue) {
|
||||
line += ` (default: "${entry.defaultValue}")`;
|
||||
}
|
||||
line += ` → auto-accepted`;
|
||||
return line;
|
||||
})
|
||||
.join("\n");
|
||||
|
||||
const truncated = deps.truncateText(formatted);
|
||||
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `${logs.length} dialog(s):\n\n${truncated}`,
|
||||
},
|
||||
],
|
||||
details: { logs, count: logs.length },
|
||||
};
|
||||
},
|
||||
});
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// browser_evaluate
|
||||
// -------------------------------------------------------------------------
|
||||
pi.registerTool({
|
||||
name: "browser_evaluate",
|
||||
label: "Browser Evaluate",
|
||||
description:
|
||||
"Execute a JavaScript expression in the browser context and return the result. Useful for reading DOM state, checking values, etc.",
|
||||
parameters: Type.Object({
|
||||
expression: Type.String({
|
||||
description: "JavaScript expression to evaluate in the page context",
|
||||
}),
|
||||
}),
|
||||
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
try {
|
||||
await deps.ensureBrowser();
|
||||
const target = deps.getActiveTarget();
|
||||
const result = await target.evaluate(params.expression);
|
||||
|
||||
let serialized: string;
|
||||
if (result === undefined) {
|
||||
serialized = "undefined";
|
||||
} else {
|
||||
try {
|
||||
serialized = JSON.stringify(result, null, 2) ?? "undefined";
|
||||
} catch {
|
||||
serialized = `[non-serializable: ${typeof result}]`;
|
||||
}
|
||||
}
|
||||
|
||||
const truncated = deps.truncateText(serialized);
|
||||
return {
|
||||
content: [{ type: "text", text: truncated }],
|
||||
details: { expression: params.expression },
|
||||
};
|
||||
} catch (err: any) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Evaluation failed: ${err.message}`,
|
||||
},
|
||||
],
|
||||
details: { error: err.message },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// browser_get_accessibility_tree
|
||||
// -------------------------------------------------------------------------
|
||||
pi.registerTool({
|
||||
name: "browser_get_accessibility_tree",
|
||||
label: "Browser Accessibility Tree",
|
||||
description:
|
||||
"Get the accessibility tree of the current page as structured text. Shows roles, names, labels, values, and states of all interactive elements. Use this to understand page structure before clicking — it reveals buttons, inputs, links, and their labels without needing to guess CSS selectors or coordinates. Much more reliable than inspecting the DOM directly.",
|
||||
parameters: Type.Object({
|
||||
selector: Type.Optional(
|
||||
Type.String({
|
||||
description:
|
||||
"Scope the accessibility tree to a specific element by CSS selector (e.g. 'main', 'form', '#modal'). If omitted, returns the full page tree.",
|
||||
}),
|
||||
),
|
||||
}),
|
||||
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
try {
|
||||
const { page: p } = await deps.ensureBrowser();
|
||||
const target = deps.getActiveTarget();
|
||||
|
||||
let snapshot: string;
|
||||
if (params.selector) {
|
||||
const locator = target.locator(params.selector).first();
|
||||
snapshot = await locator.ariaSnapshot();
|
||||
} else {
|
||||
snapshot = await target.locator("body").ariaSnapshot();
|
||||
}
|
||||
|
||||
const truncated = deps.truncateText(snapshot);
|
||||
const scope = params.selector
|
||||
? `element "${params.selector}"`
|
||||
: "full page";
|
||||
const viewport = p.viewportSize();
|
||||
const vpText = viewport
|
||||
? `${viewport.width}x${viewport.height}`
|
||||
: "unknown";
|
||||
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Accessibility tree for ${scope} (viewport: ${vpText}):\n\n${truncated}`,
|
||||
},
|
||||
],
|
||||
details: { scope, snapshot, viewport: vpText },
|
||||
};
|
||||
} catch (err: any) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Accessibility tree failed: ${err.message}`,
|
||||
},
|
||||
],
|
||||
details: { error: err.message },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// browser_find
|
||||
// -------------------------------------------------------------------------
|
||||
pi.registerTool({
|
||||
name: "browser_find",
|
||||
label: "Browser Find",
|
||||
description:
|
||||
"Find elements on the page by text content, ARIA role, or CSS selector. Returns only the matched nodes as a compact accessibility snapshot — far cheaper than browser_get_accessibility_tree. Use this after any action to locate a specific button, input, heading, or link before clicking it.",
|
||||
promptGuidelines: [
|
||||
"Use browser_find for cheap targeted discovery before requesting the full accessibility tree.",
|
||||
"Prefer browser_find when you need one button, input, heading, dialog, or alert rather than a full-page structure dump.",
|
||||
],
|
||||
parameters: Type.Object({
|
||||
text: Type.Optional(
|
||||
Type.String({
|
||||
description:
|
||||
"Find elements whose visible text contains this string (case-insensitive).",
|
||||
}),
|
||||
),
|
||||
role: Type.Optional(
|
||||
Type.String({
|
||||
description:
|
||||
"ARIA role to filter by, e.g. 'button', 'link', 'heading', 'textbox', 'dialog', 'alert'.",
|
||||
}),
|
||||
),
|
||||
selector: Type.Optional(
|
||||
Type.String({
|
||||
description:
|
||||
"CSS selector to scope the search. If omitted, searches the full page.",
|
||||
}),
|
||||
),
|
||||
limit: Type.Optional(
|
||||
Type.Number({
|
||||
description: "Maximum number of results to return (default: 20).",
|
||||
}),
|
||||
),
|
||||
}),
|
||||
|
||||
// Finds elements on the active page by role, text, and/or CSS scope, and
// returns a compact textual listing of up to `limit` matches.
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
  try {
    // Make sure a browser session exists, then grab the active page/frame.
    await deps.ensureBrowser();
    const target = deps.getActiveTarget();
    // Cap on returned matches; defaults to 20 when the caller omits it.
    const limit = params.limit ?? 20;

    // All DOM inspection runs in-page via evaluate(); only plain data
    // (the mapped candidate records) crosses back over the wire.
    const results = await target.evaluate(
      ({ text, role, selector, limit }) => {
        // Scope the search to the given selector, or the whole body.
        const root = selector
          ? document.querySelector(selector)
          : document.body;
        if (!root) return [];

        let candidates: Element[];
        if (role) {
          // Map common ARIA role names to CSS that also matches the
          // equivalent native elements (e.g. "button" matches <button>
          // as well as [role="button"]).
          const roleMap: Record<string, string> = {
            button: 'button,[role="button"]',
            link: 'a[href],[role="link"]',
            heading: 'h1,h2,h3,h4,h5,h6,[role="heading"]',
            textbox:
              'input:not([type="hidden"]):not([type="checkbox"]):not([type="radio"]):not([type="submit"]):not([type="button"]),textarea,[role="textbox"]',
            checkbox: 'input[type="checkbox"],[role="checkbox"]',
            radio: 'input[type="radio"],[role="radio"]',
            combobox: 'select,[role="combobox"]',
            dialog: 'dialog,[role="dialog"]',
            alert: '[role="alert"]',
            navigation: 'nav,[role="navigation"]',
            listitem: 'li,[role="listitem"]',
          };
          // Unknown roles fall back to a literal [role="…"] match.
          const cssForRole =
            roleMap[role.toLowerCase()] ?? `[role="${role}"]`;
          candidates = Array.from(root.querySelectorAll(cssForRole));
        } else {
          // No role filter: start from every element under the root.
          candidates = Array.from(root.querySelectorAll("*"));
        }

        if (text) {
          // Case-insensitive substring match against visible text and
          // the common labeling attributes.
          const lower = text.toLowerCase();
          candidates = candidates.filter(
            (el) =>
              (el.textContent ?? "").toLowerCase().includes(lower) ||
              (el.getAttribute("aria-label") ?? "")
                .toLowerCase()
                .includes(lower) ||
              (el.getAttribute("placeholder") ?? "")
                .toLowerCase()
                .includes(lower) ||
              (el.getAttribute("value") ?? "")
                .toLowerCase()
                .includes(lower),
          );
        }

        // Serialize the first `limit` candidates to plain records.
        return candidates.slice(0, limit).map((el) => {
          const tag = el.tagName.toLowerCase();
          const id = el.id ? `#${el.id}` : "";
          // At most two classes, as a compact .a.b suffix.
          const classes = Array.from(el.classList)
            .slice(0, 2)
            .map((c) => `.${c}`)
            .join("");
          const ariaLabel = el.getAttribute("aria-label") ?? "";
          const placeholder = el.getAttribute("placeholder") ?? "";
          // Trimmed text, truncated to 80 chars for readability.
          const textContent = (el.textContent ?? "").trim().slice(0, 80);
          const role = el.getAttribute("role") ?? "";
          const type = el.getAttribute("type") ?? "";
          const href = el.getAttribute("href") ?? "";
          // Live input value (property, not attribute); undefined for
          // non-input elements, coalesced to "".
          const value = (el as HTMLInputElement).value ?? "";

          return {
            tag,
            id,
            classes,
            ariaLabel,
            placeholder,
            textContent,
            role,
            type,
            href,
            value,
          };
        });
      },
      {
        text: params.text,
        role: params.role,
        selector: params.selector,
        limit,
      },
    );

    if (results.length === 0) {
      return {
        content: [
          {
            type: "text",
            text: "No elements found matching the criteria.",
          },
        ],
        details: { count: 0 },
      };
    }

    // One summary line per match: "tag#id.cls attr=… "text"".
    const lines = results.map((r: any) => {
      const parts: string[] = [`${r.tag}${r.id}${r.classes}`];
      if (r.role) parts.push(`role="${r.role}"`);
      if (r.type) parts.push(`type="${r.type}"`);
      if (r.ariaLabel) parts.push(`aria-label="${r.ariaLabel}"`);
      if (r.placeholder) parts.push(`placeholder="${r.placeholder}"`);
      if (r.href) parts.push(`href="${r.href.slice(0, 60)}"`);
      if (r.value) parts.push(`value="${r.value.slice(0, 40)}"`);
      // Text content is redundant when an aria-label is already shown.
      if (r.textContent && !r.ariaLabel) parts.push(`"${r.textContent}"`);
      return " " + parts.join(" ");
    });

    // Echo the search criteria back in the header line.
    const criteria: string[] = [];
    if (params.role) criteria.push(`role="${params.role}"`);
    if (params.text) criteria.push(`text="${params.text}"`);
    if (params.selector) criteria.push(`within="${params.selector}"`);

    return {
      content: [
        {
          type: "text",
          text: `Found ${results.length} element(s) [${criteria.join(", ")}]:\n${lines.join("\n")}`,
        },
      ],
      details: { count: results.length, results },
    };
  } catch (err: any) {
    // Any failure (no browser, bad selector, evaluate error) is reported
    // as a tool error rather than thrown.
    return {
      content: [{ type: "text", text: `Find failed: ${err.message}` }],
      details: { error: err.message },
      isError: true,
    };
  }
},
|
||||
});
|
||||
|
||||
// -------------------------------------------------------------------------
// browser_get_page_source
//
// Returns the (truncated) HTML of the full page, or of the first element
// matching an optional CSS selector.
// -------------------------------------------------------------------------
pi.registerTool({
  name: "browser_get_page_source",
  label: "Browser Page Source",
  description:
    "Get the current HTML source of the page (or a specific element). Use when you need to inspect the actual DOM structure — verify semantic HTML, check that elements rendered correctly, debug why a selector isn't matching, or audit accessibility markup. Output is truncated for large pages.",
  parameters: Type.Object({
    selector: Type.Optional(
      Type.String({
        description:
          "CSS selector to scope the output to a specific element (e.g. 'main', 'form', '#app'). If omitted, returns the full page HTML.",
      }),
    ),
  }),

  async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
    try {
      await deps.ensureBrowser();
      const target = deps.getActiveTarget();

      let html: string;
      if (params.selector) {
        // Scoped: outerHTML of the first element matching the selector.
        html = await target
          .locator(params.selector)
          .first()
          .evaluate((el: Element) => el.outerHTML);
      } else {
        // Unscoped: full serialized page content.
        html = await target.content();
      }

      // Large pages are truncated by the shared helper before returning.
      const truncated = deps.truncateText(html);
      const scope = params.selector
        ? `element "${params.selector}"`
        : "full page";

      return {
        content: [
          {
            type: "text",
            text: `HTML source of ${scope}:\n\n${truncated}`,
          },
        ],
        details: { scope },
      };
    } catch (err: any) {
      // Selector not found / browser unavailable → tool-level error.
      return {
        content: [
          {
            type: "text",
            text: `Get page source failed: ${err.message}`,
          },
        ],
        details: { error: err.message },
        isError: true,
      };
    }
  },
});
|
||||
}
|
||||
|
|
@ -1,671 +0,0 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import { StringEnum } from "@singularity-forge/pi-ai";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import { diffCompactStates } from "../core.js";
|
||||
import type { CompactPageState, ToolDeps } from "../state.js";
|
||||
import { setLastActionAfterState, setLastActionBeforeState } from "../state.js";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Intent definitions
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// The closed set of semantic intents the scoring script understands.
// Each maps to a dedicated candidate-scoring branch in
// buildIntentScoringScript; the list is also surfaced verbatim in tool
// parameter enums and error messages.
const INTENTS = [
  "submit_form",
  "close_dialog",
  "primary_cta",
  "search_field",
  "next_step",
  "dismiss",
  "auth_action",
  "back_navigation",
] as const;

// Union of the literal intent strings (currently unused directly, hence
// the underscore prefix; kept for documentation/type-derivation purposes).
type _Intent = (typeof INTENTS)[number];
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Scoring evaluate script — runs entirely in-browser via page.evaluate()
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
 * Builds a self-contained IIFE string that scores candidate elements for a
 * given intent. Returns top 5 candidates sorted by score descending, each
 * with { score, selector, tag, role, name, text, reason }.
 *
 * Uses window.__pi utilities (injected via addInitScript) for element
 * metadata — no inline redeclarations.
 *
 * @param intent - One of the INTENTS values (underscores/spaces/dashes are
 *   normalized away inside the script, so "submit_form" → "submitform").
 * @param scope - Optional CSS selector; when given, only elements under the
 *   first match are considered. Both arguments are embedded with
 *   JSON.stringify, so arbitrary strings are safe to inject.
 * @returns JavaScript source for page.evaluate(); evaluates to either
 *   { intent, normalized, count, candidates } or { error } when __pi is
 *   missing, the scope selector matches nothing, or the intent is unknown.
 *
 * NOTE(review): each scoring branch sums four weighted signals d1–d4 that
 * total at most ~1.0 before clamp01; weights are heuristic constants.
 * NOTE(review): leading whitespace inside the template literal was
 * reconstructed from a whitespace-mangled view — it only affects the
 * cosmetic layout of the generated script, not its evaluation.
 */
function buildIntentScoringScript(intent: string, scope?: string): string {
  // Pre-serialize the scope so `null` (no scope) embeds literally.
  const scopeSelector = JSON.stringify(scope ?? null);

  return `(() => {
  var pi = window.__pi;
  if (!pi) return { error: "window.__pi not available — browser helpers not injected" };

  var intentRaw = ${JSON.stringify(intent)};
  var normalized = intentRaw.toLowerCase().replace(/[\\s_\\-]+/g, "");
  var scopeSel = ${scopeSelector};
  var root = scopeSel ? document.querySelector(scopeSel) : document.body;
  if (!root) return { error: "Scope selector not found: " + scopeSel };

  // --- Shared helpers ---
  function textOf(el) {
    return (el.textContent || "").trim().replace(/\\s+/g, " ").slice(0, 120).toLowerCase();
  }

  function clamp01(v) { return Math.max(0, Math.min(1, v)); }

  function makeCandidate(el, score, reason) {
    return {
      score: Math.round(clamp01(score) * 100) / 100,
      selector: pi.cssPath(el),
      tag: el.tagName.toLowerCase(),
      role: pi.inferRole(el) || "",
      name: pi.accessibleName(el) || "",
      text: textOf(el).slice(0, 80),
      reason: reason,
    };
  }

  function qsa(sel) { return Array.from(root.querySelectorAll(sel)); }

  function visibleEnabled(el) {
    return pi.isVisible(el) && pi.isEnabled(el);
  }

  function textMatches(el, patterns) {
    var t = textOf(el);
    var n = (pi.accessibleName(el) || "").toLowerCase();
    var combined = t + " " + n;
    for (var i = 0; i < patterns.length; i++) {
      if (combined.indexOf(patterns[i]) !== -1) return true;
    }
    return false;
  }

  function textMatchStrength(el, patterns) {
    var t = textOf(el);
    var n = (pi.accessibleName(el) || "").toLowerCase();
    var combined = t + " " + n;
    var count = 0;
    for (var i = 0; i < patterns.length; i++) {
      if (combined.indexOf(patterns[i]) !== -1) count++;
    }
    return Math.min(count / Math.max(patterns.length, 1), 1);
  }

  // --- Intent-specific scoring ---
  var candidates = [];

  if (normalized === "submitform") {
    var els = qsa('button[type="submit"], input[type="submit"], button:not([type]), button[type="button"]');
    for (var i = 0; i < els.length; i++) {
      var el = els[i];
      if (!visibleEnabled(el)) continue;
      var d1 = el.type === "submit" || el.getAttribute("type") === "submit" ? 0.35 : 0;
      var d2 = el.closest("form") ? 0.3 : 0;
      var d3 = textMatches(el, ["submit", "send", "save", "create", "add", "post", "confirm", "ok", "done", "register", "sign up", "log in"]) ? 0.2 : 0;
      var d4 = 0.15;
      var score = d1 + d2 + d3 + d4;
      var reasons = [];
      if (d1 > 0) reasons.push("submit-type");
      if (d2 > 0) reasons.push("inside-form");
      if (d3 > 0) reasons.push("text-suggests-submit");
      reasons.push("visible+enabled");
      candidates.push(makeCandidate(el, score, reasons.join(", ")));
    }
  }

  else if (normalized === "closedialog") {
    var containers = qsa('[role="dialog"], dialog, [aria-modal="true"], [role="alertdialog"]');
    for (var ci = 0; ci < containers.length; ci++) {
      var btns = containers[ci].querySelectorAll("button, a, [role='button']");
      for (var bi = 0; bi < btns.length; bi++) {
        var el = btns[bi];
        if (!visibleEnabled(el)) continue;
        var d1 = textMatches(el, ["close", "cancel", "dismiss", "×", "✕", "x", "got it", "ok", "done"]) ? 0.35 : 0;
        var ariaLbl = (el.getAttribute("aria-label") || "").toLowerCase();
        var d2 = (ariaLbl.indexOf("close") !== -1 || ariaLbl.indexOf("dismiss") !== -1) ? 0.25 : 0;
        var d3 = 0.2;
        var rect = el.getBoundingClientRect();
        var parentRect = containers[ci].getBoundingClientRect();
        var isTopRight = rect.top - parentRect.top < 60 && parentRect.right - rect.right < 60;
        var d4 = isTopRight ? 0.2 : 0;
        var score = d1 + d2 + d3 + d4;
        var reasons = [];
        if (d1 > 0) reasons.push("text-matches-close");
        if (d2 > 0) reasons.push("aria-label-close");
        reasons.push("inside-dialog");
        if (d4 > 0) reasons.push("top-right-position");
        candidates.push(makeCandidate(el, score, reasons.join(", ")));
      }
    }
  }

  else if (normalized === "primarycta") {
    var els = qsa("button, a, [role='button'], input[type='submit'], input[type='button']");
    for (var i = 0; i < els.length; i++) {
      var el = els[i];
      if (!visibleEnabled(el)) continue;
      var rect = el.getBoundingClientRect();
      var area = rect.width * rect.height;
      var d1 = clamp01(area / 12000);
      var role = pi.inferRole(el);
      var d2 = role === "button" ? 0.25 : (role === "link" ? 0.1 : 0.15);
      var isNegative = textMatches(el, ["cancel", "dismiss", "close", "skip", "no thanks", "no, thanks", "maybe later"]);
      var d3 = isNegative ? 0 : 0.2;
      var inMain = !!el.closest("main, [role='main'], article, section, .hero, .content");
      var d4 = inMain ? 0.15 : 0;
      var score = d1 + d2 + d3 + d4;
      var reasons = [];
      reasons.push("size:" + Math.round(area));
      if (d2 >= 0.25) reasons.push("button-role");
      if (d3 > 0) reasons.push("non-dismissive");
      if (d4 > 0) reasons.push("in-main-content");
      candidates.push(makeCandidate(el, score, reasons.join(", ")));
    }
  }

  else if (normalized === "searchfield") {
    var els = qsa("input, textarea, [role='searchbox'], [role='combobox'], [contenteditable='true']");
    for (var i = 0; i < els.length; i++) {
      var el = els[i];
      if (!pi.isVisible(el)) continue;
      var type = (el.getAttribute("type") || "text").toLowerCase();
      if (["hidden", "submit", "button", "reset", "image", "checkbox", "radio", "file"].indexOf(type) !== -1 && el.tagName.toLowerCase() === "input") continue;
      var d1 = type === "search" || pi.inferRole(el) === "searchbox" ? 0.4 : 0;
      var ph = (el.getAttribute("placeholder") || "").toLowerCase();
      var nm = (el.getAttribute("name") || "").toLowerCase();
      var ariaLbl = (el.getAttribute("aria-label") || "").toLowerCase();
      var combined = ph + " " + nm + " " + ariaLbl;
      var d2 = combined.indexOf("search") !== -1 || combined.indexOf("query") !== -1 || combined.indexOf("find") !== -1 ? 0.3 : 0;
      var d3 = pi.isEnabled(el) ? 0.15 : 0;
      var inHeader = !!el.closest("header, nav, [role='banner'], [role='navigation'], [role='search']");
      var d4 = inHeader ? 0.15 : 0;
      var score = d1 + d2 + d3 + d4;
      if (score < 0.1) continue;
      var reasons = [];
      if (d1 > 0) reasons.push("search-type/role");
      if (d2 > 0) reasons.push("name/placeholder-match");
      if (d3 > 0) reasons.push("enabled");
      if (d4 > 0) reasons.push("in-header/nav");
      candidates.push(makeCandidate(el, score, reasons.join(", ")));
    }
  }

  else if (normalized === "nextstep") {
    var els = qsa("button, a, [role='button'], input[type='submit'], input[type='button']");
    var patterns = ["next", "continue", "proceed", "forward", "go", "step"];
    for (var i = 0; i < els.length; i++) {
      var el = els[i];
      if (!visibleEnabled(el)) continue;
      var d1 = textMatchStrength(el, patterns) * 0.4;
      if (d1 === 0) continue;
      var role = pi.inferRole(el);
      var d2 = role === "button" ? 0.25 : 0.1;
      var d3 = 0.2;
      var isDisabled = !pi.isEnabled(el);
      var d4 = isDisabled ? 0 : 0.15;
      var score = d1 + d2 + d3 + d4;
      var reasons = [];
      reasons.push("text-match");
      if (d2 >= 0.25) reasons.push("button-role");
      reasons.push("visible");
      if (d4 > 0) reasons.push("enabled");
      candidates.push(makeCandidate(el, score, reasons.join(", ")));
    }
  }

  else if (normalized === "dismiss") {
    var els = qsa("button, a, [role='button'], [role='link']");
    var patterns = ["close", "cancel", "dismiss", "skip", "no thanks", "no, thanks", "maybe later", "not now", "×", "✕"];
    for (var i = 0; i < els.length; i++) {
      var el = els[i];
      if (!visibleEnabled(el)) continue;
      var d1 = textMatchStrength(el, patterns) * 0.35;
      if (d1 === 0) continue;
      var inOverlay = !!el.closest('[role="dialog"], dialog, [aria-modal="true"], [role="alertdialog"], .modal, .overlay, .popup, .popover, .toast, .banner');
      var d2 = inOverlay ? 0.3 : 0.05;
      var rect = el.getBoundingClientRect();
      var isEdge = rect.top < 80 || rect.right > window.innerWidth - 80;
      var d3 = isEdge ? 0.15 : 0;
      var d4 = 0.15;
      var score = d1 + d2 + d3 + d4;
      var reasons = [];
      reasons.push("text-match");
      if (d2 >= 0.3) reasons.push("inside-overlay");
      if (d3 > 0) reasons.push("edge-position");
      reasons.push("visible+enabled");
      candidates.push(makeCandidate(el, score, reasons.join(", ")));
    }
  }

  else if (normalized === "authaction") {
    var els = qsa("button, a, [role='button'], [role='link'], input[type='submit']");
    var patterns = ["log in", "login", "sign in", "signin", "sign up", "signup", "register", "create account", "join", "get started"];
    for (var i = 0; i < els.length; i++) {
      var el = els[i];
      if (!visibleEnabled(el)) continue;
      var d1 = textMatchStrength(el, patterns) * 0.4;
      if (d1 === 0) continue;
      var role = pi.inferRole(el);
      var d2 = (role === "button" || role === "link") ? 0.25 : 0.1;
      var rect = el.getBoundingClientRect();
      var inHeader = !!el.closest("header, nav, [role='banner'], [role='navigation']");
      var isProminent = inHeader || rect.top < 200;
      var d3 = isProminent ? 0.2 : 0.05;
      var d4 = 0.15;
      var score = d1 + d2 + d3 + d4;
      var reasons = [];
      reasons.push("text-match");
      if (d2 >= 0.25) reasons.push("button-or-link");
      if (d3 >= 0.2) reasons.push("prominent-position");
      reasons.push("visible+enabled");
      candidates.push(makeCandidate(el, score, reasons.join(", ")));
    }
  }

  else if (normalized === "backnavigation") {
    var els = qsa("button, a, [role='button'], [role='link']");
    var patterns = ["back", "previous", "prev", "return", "go back"];
    for (var i = 0; i < els.length; i++) {
      var el = els[i];
      if (!visibleEnabled(el)) continue;
      var d1 = textMatchStrength(el, patterns) * 0.35;
      if (d1 === 0) continue;
      var innerHtml = el.innerHTML.toLowerCase();
      var hasArrow = innerHtml.indexOf("←") !== -1 || innerHtml.indexOf("&larr") !== -1 || innerHtml.indexOf("arrow") !== -1 || innerHtml.indexOf("chevron-left") !== -1 || innerHtml.indexOf("back") !== -1;
      var d2 = hasArrow ? 0.25 : 0;
      var inNav = !!el.closest("header, nav, [role='banner'], [role='navigation'], .breadcrumb, .toolbar");
      var d3 = inNav ? 0.25 : 0.05;
      var d4 = 0.15;
      var score = d1 + d2 + d3 + d4;
      var reasons = [];
      reasons.push("text-match");
      if (d2 > 0) reasons.push("has-back-arrow/icon");
      if (d3 >= 0.25) reasons.push("in-nav/header");
      reasons.push("visible+enabled");
      candidates.push(makeCandidate(el, score, reasons.join(", ")));
    }
  }

  else {
    return { error: "Unknown intent: " + intentRaw + ". Valid: submit_form, close_dialog, primary_cta, search_field, next_step, dismiss, auth_action, back_navigation" };
  }

  // Sort by score descending, cap at 5
  candidates.sort(function(a, b) { return b.score - a.score; });
  candidates = candidates.slice(0, 5);

  return { intent: intentRaw, normalized: normalized, count: candidates.length, candidates: candidates };
})()`;
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Result types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// One scored element produced by the in-page scoring script
// (see makeCandidate in buildIntentScoringScript).
interface IntentCandidate {
  // Heuristic score in [0, 1], rounded to 2 decimal places.
  score: number;
  // CSS path to the element (pi.cssPath), usable with browser_click.
  selector: string;
  // Lowercased tag name.
  tag: string;
  // Inferred ARIA role, or "" when none could be determined.
  role: string;
  // Accessible name, or "" when none.
  name: string;
  // Normalized visible text, truncated to 80 chars.
  text: string;
  // Comma-joined list of the scoring signals that fired.
  reason: string;
}

// Overall result of evaluating the scoring script: either a ranked
// candidate list, or `error` set when the script could not run
// (missing __pi helpers, bad scope selector, or unknown intent).
interface IntentScoringResult {
  // The intent string exactly as passed in by the caller.
  intent: string;
  // Lowercased intent with separators stripped (e.g. "submitform").
  normalized: string;
  // Number of candidates returned (max 5).
  count: number;
  candidates: IntentCandidate[];
  error?: string;
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Registration
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
 * Registers the intent-based browser tools on the extension API:
 *
 * - browser_find_best: scores and lists candidate elements for a semantic
 *   intent without interacting with them.
 * - browser_act: resolves the top candidate for an intent and performs the
 *   matching action (focus for search fields, click otherwise), then
 *   reports a before/after page diff.
 *
 * Both tools record tracked actions (beginTrackedAction /
 * finishTrackedAction) and stash before/after compact page states for
 * later inspection.
 *
 * @param pi - Extension API used to register the tools.
 * @param deps - Shared browser/tool dependencies (browser lifecycle, state
 *   capture, action tracking, formatting helpers).
 */
export function registerIntentTools(pi: ExtensionAPI, deps: ToolDeps): void {
  // -----------------------------------------------------------------------
  // browser_find_best
  // -----------------------------------------------------------------------
  pi.registerTool({
    name: "browser_find_best",
    label: "Find Best",
    description:
      'Find the best-matching element for a semantic intent. Returns up to 5 scored candidates (0-1) ranked by structural position, role, text signals, and visibility. Use this to discover which element the agent should interact with for a given goal — e.g. intent="submit_form" finds submit buttons, intent="close_dialog" finds close/dismiss buttons inside dialogs. Each candidate includes a CSS selector usable with browser_click.',
    parameters: Type.Object({
      intent: StringEnum(INTENTS, {
        description:
          "Semantic intent: submit_form, close_dialog, primary_cta, search_field, next_step, dismiss, auth_action, back_navigation",
      }),
      scope: Type.Optional(
        Type.String({
          description:
            "CSS selector to narrow the search area. If omitted, searches the full page.",
        }),
      ),
    }),

    async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
      // Tracked-action id and pre-action snapshot; both stay null until
      // the browser is up so the catch block can tell how far we got.
      let actionId: number | null = null;
      let beforeState: CompactPageState | null = null;
      try {
        const { page: p } = await deps.ensureBrowser();
        const target = deps.getActiveTarget();
        beforeState = await deps.captureCompactPageState(p, {
          selectors: params.scope ? [params.scope] : [],
          includeBodyText: false,
          target,
        });
        actionId = deps.beginTrackedAction(
          "browser_find_best",
          params,
          beforeState.url,
        ).id;

        // Run the intent-scoring IIFE inside the page.
        const script = buildIntentScoringScript(params.intent, params.scope);
        const result = (await target.evaluate(script)) as IntentScoringResult;

        if (result.error) {
          // Script-level failure (unknown intent, missing scope/helpers).
          deps.finishTrackedAction(actionId, {
            status: "error",
            error: result.error,
            beforeState,
          });
          return {
            content: [{ type: "text" as const, text: result.error }],
            details: {},
            isError: true,
          };
        }

        // Scoring is read-only, but capture an after-state anyway so the
        // tracked action has a complete before/after pair.
        const afterState = await deps.captureCompactPageState(p, {
          selectors: params.scope ? [params.scope] : [],
          includeBodyText: false,
          target,
        });
        setLastActionBeforeState(beforeState);
        setLastActionAfterState(afterState);

        deps.finishTrackedAction(actionId, {
          status: "success",
          afterUrl: afterState.url,
          beforeState,
          afterState,
        });

        // Format output
        const lines: string[] = [];
        lines.push(`Intent: ${params.intent} → ${result.count} candidate(s)`);
        if (params.scope) lines.push(`Scope: ${params.scope}`);
        lines.push("");

        if (result.candidates.length === 0) {
          lines.push(
            "No candidates found for this intent on the current page.",
          );
        } else {
          // One three-line entry per candidate: score+selector, element
          // summary, and the scoring reasons.
          for (let i = 0; i < result.candidates.length; i++) {
            const c = result.candidates[i];
            lines.push(`${i + 1}. **${c.score}** \`${c.selector}\``);
            lines.push(
              ` ${c.tag}${c.role ? ` [${c.role}]` : ""} — "${c.name || c.text}"`,
            );
            lines.push(` Reason: ${c.reason}`);
          }
        }

        return {
          content: [{ type: "text" as const, text: lines.join("\n") }],
          details: { intentResult: result },
        };
      } catch (err: unknown) {
        // Best-effort screenshot for debugging; getActivePage may itself
        // throw when no page exists, hence the guarded IIFE.
        const screenshot = await deps.captureErrorScreenshot(
          (() => {
            try {
              return deps.getActivePage();
            } catch {
              return null;
            }
          })(),
        );
        const errMsg = deps.firstErrorLine(err);

        // Only close out the tracked action if one was opened.
        if (actionId !== null) {
          deps.finishTrackedAction(actionId, {
            status: "error",
            error: errMsg,
            beforeState: beforeState ?? undefined,
          });
        }

        const content: Array<
          | { type: "text"; text: string }
          | { type: "image"; data: string; mimeType: string }
        > = [{ type: "text", text: `browser_find_best failed: ${errMsg}` }];
        if (screenshot) {
          content.push({
            type: "image",
            data: screenshot.data,
            mimeType: screenshot.mimeType,
          });
        }
        return { content, details: {}, isError: true };
      }
    },
  });

  // -----------------------------------------------------------------------
  // browser_act
  // -----------------------------------------------------------------------
  pi.registerTool({
    name: "browser_act",
    label: "Browser Act",
    description:
      'Execute a semantic action in one call. Resolves the top candidate for the given intent (same scoring as browser_find_best), performs the action (click for buttons/links, focus for search fields), settles the page, and returns a before/after diff. Use when you know what you want to accomplish semantically — e.g. intent="submit_form" finds and clicks the submit button, intent="close_dialog" dismisses the dialog.',
    parameters: Type.Object({
      intent: StringEnum(INTENTS, {
        description:
          "Semantic intent: submit_form, close_dialog, primary_cta, search_field, next_step, dismiss, auth_action, back_navigation",
      }),
      scope: Type.Optional(
        Type.String({
          description:
            "CSS selector to narrow the search area. If omitted, searches the full page.",
        }),
      ),
    }),

    async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
      let actionId: number | null = null;
      let beforeState: CompactPageState | null = null;
      try {
        const { page: p } = await deps.ensureBrowser();
        const target = deps.getActiveTarget();
        // Body text is included here (unlike browser_find_best) so the
        // post-action diff can report text changes.
        beforeState = await deps.captureCompactPageState(p, {
          selectors: params.scope ? [params.scope] : [],
          includeBodyText: true,
          target,
        });
        actionId = deps.beginTrackedAction(
          "browser_act",
          params,
          beforeState.url,
        ).id;

        // Score candidates
        const script = buildIntentScoringScript(params.intent, params.scope);
        const result = (await target.evaluate(script)) as IntentScoringResult;

        if (result.error) {
          deps.finishTrackedAction(actionId, {
            status: "error",
            error: result.error,
            beforeState,
          });
          return {
            content: [
              {
                type: "text" as const,
                text: `browser_act failed: ${result.error}`,
              },
            ],
            details: {},
            isError: true,
          };
        }

        if (result.candidates.length === 0) {
          deps.finishTrackedAction(actionId, {
            status: "error",
            error: `No candidates found for intent "${params.intent}"`,
            beforeState,
          });
          return {
            content: [
              {
                type: "text" as const,
                text: `browser_act: No candidates found for intent "${params.intent}" on the current page. The page may not have the expected elements (e.g. no dialog for close_dialog, no form for submit_form).`,
              },
            ],
            details: { intentResult: result },
            isError: true,
          };
        }

        // Take top candidate and execute action
        const top = result.candidates[0];
        // Same normalization the in-page script applies to the intent.
        const normalizedIntent = params.intent
          .toLowerCase()
          .replace(/[\s_-]+/g, "");

        if (normalizedIntent === "searchfield") {
          // Focus instead of click for search fields
          try {
            await target.locator(top.selector).first().focus({ timeout: 5000 });
          } catch {
            // Fallback: click to focus
            await target.locator(top.selector).first().click({ timeout: 5000 });
          }
        } else {
          // Click via Playwright locator (D021)
          try {
            await target.locator(top.selector).first().click({ timeout: 5000 });
          } catch {
            // getByRole fallback from interaction.ts pattern: derive an
            // accessible-name regex from the selector (or the candidate's
            // name, regex-escaped) and try the common interactive roles.
            const nameMatch = top.selector.match(
              /\[(?:aria-label|name|placeholder)="([^"]+)"\]/i,
            );
            const roleName = nameMatch?.[1];
            let clicked = false;
            for (const role of [
              "button",
              "link",
              "combobox",
              "textbox",
            ] as const) {
              try {
                const loc = roleName
                  ? target.getByRole(role, { name: new RegExp(roleName, "i") })
                  : target.getByRole(role, {
                      name: new RegExp(
                        top.name.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"),
                        "i",
                      ),
                    });
                await loc.first().click({ timeout: 3000 });
                clicked = true;
                break;
              } catch {
                /* try next role */
              }
            }
            if (!clicked) {
              throw new Error(
                `Could not click top candidate "${top.selector}" for intent "${params.intent}"`,
              );
            }
          }
        }

        // Settle after action
        await deps.settleAfterActionAdaptive(p);

        // Capture after state and diff
        const afterState = await deps.captureCompactPageState(p, {
          selectors: params.scope ? [params.scope] : [],
          includeBodyText: true,
          target,
        });
        const diff = diffCompactStates(beforeState, afterState);
        const summary = deps.formatCompactStateSummary(afterState);
        const jsErrors = deps.getRecentErrors(p.url());

        setLastActionBeforeState(beforeState);
        setLastActionAfterState(afterState);

        deps.finishTrackedAction(actionId, {
          status: "success",
          afterUrl: afterState.url,
          diffSummary: diff.summary,
          beforeState,
          afterState,
        });

        // Format output
        const lines: string[] = [];
        lines.push(`Intent: ${params.intent}`);
        lines.push(
          `Action: ${normalizedIntent === "searchfield" ? "focused" : "clicked"} top candidate (score: ${top.score})`,
        );
        lines.push(`Target: \`${top.selector}\` — "${top.name || top.text}"`);
        lines.push(`Reason: ${top.reason}`);
        lines.push("");
        lines.push(`Diff:\n${deps.formatDiffText(diff)}`);
        if (jsErrors.trim()) {
          lines.push(`\nJS Errors:\n${jsErrors}`);
        }
        lines.push(`\nPage summary:\n${summary}`);

        return {
          content: [{ type: "text" as const, text: lines.join("\n") }],
          details: { intentResult: result, topCandidate: top, diff },
        };
      } catch (err: unknown) {
        // Same guarded screenshot + tracked-action cleanup as
        // browser_find_best above.
        const screenshot = await deps.captureErrorScreenshot(
          (() => {
            try {
              return deps.getActivePage();
            } catch {
              return null;
            }
          })(),
        );
        const errMsg = deps.firstErrorLine(err);

        if (actionId !== null) {
          deps.finishTrackedAction(actionId, {
            status: "error",
            error: errMsg,
            beforeState: beforeState ?? undefined,
          });
        }

        const content: Array<
          | { type: "text"; text: string }
          | { type: "image"; data: string; mimeType: string }
        > = [{ type: "text", text: `browser_act failed: ${errMsg}` }];
        if (screenshot) {
          content.push({
            type: "image",
            data: screenshot.data,
            mimeType: screenshot.mimeType,
          });
        }
        return { content, details: {}, isError: true };
      }
    },
  });
}
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,346 +0,0 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import { diffCompactStates } from "../core.js";
|
||||
import type { CompactPageState, ToolDeps } from "../state.js";
|
||||
import { setLastActionAfterState, setLastActionBeforeState } from "../state.js";
|
||||
|
||||
/**
 * Registers the browser navigation tools (browser_navigate, browser_go_back,
 * browser_go_forward, browser_reload) on the extension API.
 *
 * All tools resolve the live page through `deps.ensureBrowser()` and report
 * results as text content plus optional screenshot images. Only
 * browser_navigate participates in action tracking / state diffing; the
 * history and reload tools return a simpler post-action summary.
 *
 * @param pi   Extension API used to register each tool.
 * @param deps Facade providing browser access, page-state capture,
 *             action tracking, screenshot constraints, and error reporting.
 */
export function registerNavigationTools(
  pi: ExtensionAPI,
  deps: ToolDeps,
): void {
  // -------------------------------------------------------------------------
  // browser_navigate
  // -------------------------------------------------------------------------
  pi.registerTool({
    name: "browser_navigate",
    label: "Browser Navigate",
    description:
      "Open the browser (if not already open) and navigate to a URL. Waits for network idle. Returns page title and current URL. Use ONLY for visually verifying locally-running web apps (e.g. http://localhost:3000). Do NOT use for documentation sites, GitHub, search results, or any external URL — use web_search instead. Screenshots are only captured when the `screenshot` parameter is set to true.",
    parameters: Type.Object({
      url: Type.String({
        description: "URL to navigate to, e.g. http://localhost:3000",
      }),
      screenshot: Type.Optional(
        Type.Boolean({
          description: "Capture and return a screenshot (default: false)",
          default: false,
        }),
      ),
    }),

    async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
      // actionId stays null until the tracked action is begun, so the catch
      // block knows whether there is anything to finish on failure.
      let actionId: number | null = null;
      let beforeState: CompactPageState | null = null;
      try {
        const { page: p } = await deps.ensureBrowser();
        // Capture the "before" state first so the tracked action can later be
        // diffed against the post-navigation state.
        beforeState = await deps.captureCompactPageState(p, {
          includeBodyText: true,
        });
        actionId = deps.beginTrackedAction(
          "browser_navigate",
          params,
          beforeState.url,
        ).id;
        await p.goto(params.url, {
          waitUntil: "domcontentloaded",
          timeout: 30000,
        });
        // Best-effort wait for network idle; swallowed on timeout.
        await p.waitForLoadState("networkidle", { timeout: 5000 }).catch(() => {
          /* networkidle timeout — non-fatal, page may still be usable */
        });
        // Fixed settle delay — presumably to let late client-side rendering
        // finish before the page state is captured; confirm if tunable.
        await new Promise((resolve) => setTimeout(resolve, 300));

        const title = await p.title();
        const url = p.url();
        const viewport = p.viewportSize();
        const vpText = viewport
          ? `${viewport.width}x${viewport.height}`
          : "unknown";
        const afterState = await deps.captureCompactPageState(p, {
          includeBodyText: true,
        });
        const summary = deps.formatCompactStateSummary(afterState);
        const jsErrors = deps.getRecentErrors(p.url());
        const diff = diffCompactStates(beforeState, afterState);
        // Persist both states in module-level slots for later tools to read.
        setLastActionBeforeState(beforeState);
        setLastActionAfterState(afterState);
        deps.finishTrackedAction(actionId, {
          status: "success",
          afterUrl: afterState.url,
          // Empty/whitespace-only error text is normalized to undefined.
          warningSummary: jsErrors.trim() || undefined,
          diffSummary: diff.summary,
          changed: diff.changed,
          beforeState,
          afterState,
        });

        // Screenshot is opt-in and best-effort: failure never fails the tool.
        let screenshotContent: any[] = [];
        if (params.screenshot) {
          try {
            let buf = await p.screenshot({
              type: "jpeg",
              quality: 80,
              scale: "css",
            });
            buf = await deps.constrainScreenshot(p, buf, "image/jpeg", 80);
            screenshotContent = [
              {
                type: "image",
                data: buf.toString("base64"),
                mimeType: "image/jpeg",
              },
            ];
          } catch {
            /* non-fatal — screenshot is optional, navigation result is still valid */
          }
        }

        return {
          content: [
            {
              type: "text",
              text: `Navigated to: ${url}\nTitle: ${title}\nViewport: ${vpText}\nAction: ${actionId}${jsErrors}\n\nDiff:\n${deps.formatDiffText(diff)}\n\nPage summary:\n${summary}`,
            },
            ...screenshotContent,
          ],
          details: {
            title,
            url,
            status: "loaded",
            viewport: vpText,
            actionId,
            diff,
          },
        };
      } catch (err: any) {
        // Only finish the tracked action if one was actually started.
        if (actionId !== null) {
          deps.finishTrackedAction(actionId, {
            status: "error",
            afterUrl: deps.getActivePageOrNull()?.url() ?? "",
            error: err.message,
            beforeState: beforeState ?? undefined,
          });
        }
        // Error screenshot is also best-effort (page may be null/closed).
        const errorShot = await deps.captureErrorScreenshot(
          deps.getActivePageOrNull(),
        );
        const content: any[] = [
          { type: "text", text: `Navigation failed: ${err.message}` },
        ];
        if (errorShot) {
          content.push({
            type: "image",
            data: errorShot.data,
            mimeType: errorShot.mimeType,
          });
        }
        return {
          content,
          details: { status: "error", error: err.message, actionId },
          isError: true,
        };
      }
    },
  });

  // -------------------------------------------------------------------------
  // browser_go_back
  // -------------------------------------------------------------------------
  pi.registerTool({
    name: "browser_go_back",
    label: "Browser Go Back",
    description:
      "Navigate back in browser history. Returns a compact page summary after navigation.",
    parameters: Type.Object({}),

    async execute(_toolCallId, _params, _signal, _onUpdate, _ctx) {
      try {
        const { page: p } = await deps.ensureBrowser();
        const response = await p.goBack({
          waitUntil: "domcontentloaded",
          timeout: 10000,
        });

        // Playwright returns null when there is no history entry to go to.
        if (!response) {
          return {
            content: [{ type: "text", text: "No previous page in history." }],
            details: {},
            isError: true,
          };
        }

        await p.waitForLoadState("networkidle", { timeout: 5000 }).catch(() => {
          /* networkidle timeout — non-fatal, page may still be usable */
        });

        const title = await p.title();
        const url = p.url();
        const summary = await deps.postActionSummary(p);
        const jsErrors = deps.getRecentErrors(p.url());

        return {
          content: [
            {
              type: "text",
              text: `Navigated back to: ${url}\nTitle: ${title}${jsErrors}\n\nPage summary:\n${summary}`,
            },
          ],
          details: { title, url },
        };
      } catch (err: any) {
        const errorShot = await deps.captureErrorScreenshot(
          deps.getActivePageOrNull(),
        );
        const content: any[] = [
          { type: "text", text: `Go back failed: ${err.message}` },
        ];
        if (errorShot) {
          content.push({
            type: "image",
            data: errorShot.data,
            mimeType: errorShot.mimeType,
          });
        }
        return { content, details: { error: err.message }, isError: true };
      }
    },
  });

  // -------------------------------------------------------------------------
  // browser_go_forward
  // -------------------------------------------------------------------------
  pi.registerTool({
    name: "browser_go_forward",
    label: "Browser Go Forward",
    description:
      "Navigate forward in browser history. Returns a compact page summary after navigation.",
    parameters: Type.Object({}),

    async execute(_toolCallId, _params, _signal, _onUpdate, _ctx) {
      try {
        const { page: p } = await deps.ensureBrowser();
        const response = await p.goForward({
          waitUntil: "domcontentloaded",
          timeout: 10000,
        });

        // Null response means there is no forward entry in history.
        if (!response) {
          return {
            content: [{ type: "text", text: "No forward page in history." }],
            details: {},
            isError: true,
          };
        }

        await p.waitForLoadState("networkidle", { timeout: 5000 }).catch(() => {
          /* networkidle timeout — non-fatal, page may still be usable */
        });

        const title = await p.title();
        const url = p.url();
        const summary = await deps.postActionSummary(p);
        const jsErrors = deps.getRecentErrors(p.url());

        return {
          content: [
            {
              type: "text",
              text: `Navigated forward to: ${url}\nTitle: ${title}${jsErrors}\n\nPage summary:\n${summary}`,
            },
          ],
          details: { title, url },
        };
      } catch (err: any) {
        const errorShot = await deps.captureErrorScreenshot(
          deps.getActivePageOrNull(),
        );
        const content: any[] = [
          { type: "text", text: `Go forward failed: ${err.message}` },
        ];
        if (errorShot) {
          content.push({
            type: "image",
            data: errorShot.data,
            mimeType: errorShot.mimeType,
          });
        }
        return { content, details: { error: err.message }, isError: true };
      }
    },
  });

  // -------------------------------------------------------------------------
  // browser_reload
  // -------------------------------------------------------------------------
  pi.registerTool({
    name: "browser_reload",
    label: "Browser Reload",
    description:
      "Reload the current page. Returns a screenshot, compact page summary, and page metadata (same shape as browser_navigate).",
    parameters: Type.Object({}),

    async execute(_toolCallId, _params, _signal, _onUpdate, _ctx) {
      try {
        const { page: p } = await deps.ensureBrowser();
        await p.reload({ waitUntil: "domcontentloaded", timeout: 30000 });
        await p.waitForLoadState("networkidle", { timeout: 5000 }).catch(() => {
          /* networkidle timeout — non-fatal, page may still be usable */
        });

        const title = await p.title();
        const url = p.url();
        const viewport = p.viewportSize();
        const vpText = viewport
          ? `${viewport.width}x${viewport.height}`
          : "unknown";
        const summary = await deps.postActionSummary(p);
        const jsErrors = deps.getRecentErrors(p.url());

        // Unlike browser_navigate, the reload screenshot is always attempted
        // (no opt-in parameter) but remains best-effort.
        let screenshotContent: any[] = [];
        try {
          let buf = await p.screenshot({
            type: "jpeg",
            quality: 80,
            scale: "css",
          });
          buf = await deps.constrainScreenshot(p, buf, "image/jpeg", 80);
          screenshotContent = [
            {
              type: "image",
              data: buf.toString("base64"),
              mimeType: "image/jpeg",
            },
          ];
        } catch {
          /* non-fatal — screenshot is optional, reload result is still valid */
        }

        return {
          content: [
            {
              type: "text",
              text: `Reloaded: ${url}\nTitle: ${title}\nViewport: ${vpText}${jsErrors}\n\nPage summary:\n${summary}`,
            },
            ...screenshotContent,
          ],
          details: { title, url, viewport: vpText },
        };
      } catch (err: any) {
        const errorShot = await deps.captureErrorScreenshot(
          deps.getActivePageOrNull(),
        );
        const content: any[] = [
          { type: "text", text: `Reload failed: ${err.message}` },
        ];
        if (errorShot) {
          content.push({
            type: "image",
            data: errorShot.data,
            mimeType: errorShot.mimeType,
          });
        }
        return { content, details: { error: err.message }, isError: true };
      }
    },
  });
}
|
||||
|
|
@ -1,278 +0,0 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import type { ToolDeps } from "../state.js";
|
||||
|
||||
/**
 * Network interception & mocking tools — mock API responses, block URLs, simulate errors.
 */

// A single registered interception rule (either a mock response or a block).
interface ActiveRoute {
  id: number;
  // URL glob/exact pattern passed to Playwright's page.route().
  pattern: string;
  type: "mock" | "block";
  // HTTP status for mock routes; unset for blocks.
  status?: number;
  // Artificial response delay in ms; only set when > 0.
  delay?: number;
  // Human-readable summary shown in tool output.
  description: string;
}

// Module-level registry shared by all three tools below.
// NOTE(review): state is per-module, not per-page — routes registered on a
// closed page linger here until browser_clear_routes is called.
let nextRouteId = 1;
const activeRoutes: ActiveRoute[] = [];
// Maps route id → async unroute callback used by browser_clear_routes.
const routeCleanups: Map<number, () => Promise<void>> = new Map();

/**
 * Registers network interception tools (browser_mock_route,
 * browser_block_urls, browser_clear_routes) on the extension API.
 *
 * @param pi   Extension API used to register each tool.
 * @param deps Facade providing browser/page access.
 */
export function registerNetworkMockTools(
  pi: ExtensionAPI,
  deps: ToolDeps,
): void {
  // -------------------------------------------------------------------------
  // browser_mock_route
  // -------------------------------------------------------------------------
  pi.registerTool({
    name: "browser_mock_route",
    label: "Browser Mock Route",
    description:
      "Intercept network requests matching a URL pattern and respond with custom status, body, and headers. " +
      "Supports simulating slow responses via delay parameter. " +
      "Routes survive page navigation within the same context. Use browser_clear_routes to remove all mocks.",
    parameters: Type.Object({
      url: Type.String({
        description:
          "URL pattern to intercept. Supports glob patterns (e.g., '**/api/users*') or exact URLs.",
      }),
      status: Type.Optional(
        Type.Number({
          description: "HTTP status code for the mock response (default: 200).",
        }),
      ),
      body: Type.Optional(
        Type.String({
          description:
            "Response body string. For JSON responses, pass a JSON string.",
        }),
      ),
      contentType: Type.Optional(
        Type.String({
          description:
            "Content-Type header (default: 'application/json' if body looks like JSON, else 'text/plain').",
        }),
      ),
      headers: Type.Optional(
        Type.Record(Type.String(), Type.String(), {
          description: "Additional response headers as key-value pairs.",
        }),
      ),
      delay: Type.Optional(
        Type.Number({
          description:
            "Delay in milliseconds before sending the response. Simulates slow responses.",
        }),
      ),
    }),

    async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
      try {
        const { page: p } = await deps.ensureBrowser();
        const routeId = nextRouteId++;

        const status = params.status ?? 200;
        const body = params.body ?? "";
        const delay = params.delay ?? 0;

        // Auto-detect content type
        // (a parseable body — including "" which throws — picks JSON,
        // anything else falls back to text/plain)
        let contentType = params.contentType;
        if (!contentType) {
          try {
            JSON.parse(body);
            contentType = "application/json";
          } catch {
            contentType = "text/plain";
          }
        }

        // Caller-supplied headers are spread last so they can override the
        // defaults, including the auto-detected content-type.
        const headers: Record<string, string> = {
          "content-type": contentType,
          "access-control-allow-origin": "*",
          ...(params.headers ?? {}),
        };

        const handler = async (route: any) => {
          if (delay > 0) {
            await new Promise((resolve) => setTimeout(resolve, delay));
          }
          await route.fulfill({
            status,
            body,
            headers,
          });
        };

        await p.route(params.url, handler);

        // Cleanup closure captures the exact (pattern, handler) pair so only
        // this route is removed by browser_clear_routes.
        const cleanup = async () => {
          try {
            await p.unroute(params.url, handler);
          } catch {
            // Page may be closed
          }
        };

        const routeInfo: ActiveRoute = {
          id: routeId,
          pattern: params.url,
          type: "mock",
          status,
          delay: delay > 0 ? delay : undefined,
          description: `Mock ${params.url} → ${status}${delay > 0 ? ` (${delay}ms delay)` : ""}`,
        };

        activeRoutes.push(routeInfo);
        routeCleanups.set(routeId, cleanup);

        return {
          content: [
            {
              type: "text",
              text: `Route mocked: ${routeInfo.description}\nRoute ID: ${routeId}\nActive routes: ${activeRoutes.length}`,
            },
          ],
          details: {
            routeId,
            ...routeInfo,
            activeRouteCount: activeRoutes.length,
          },
        };
      } catch (err: any) {
        return {
          content: [
            { type: "text", text: `Mock route failed: ${err.message}` },
          ],
          details: { error: err.message },
          isError: true,
        };
      }
    },
  });

  // -------------------------------------------------------------------------
  // browser_block_urls
  // -------------------------------------------------------------------------
  pi.registerTool({
    name: "browser_block_urls",
    label: "Browser Block URLs",
    description:
      "Block network requests matching URL patterns. Useful for blocking analytics, ads, or third-party scripts. " +
      "Accepts glob patterns. Routes survive page navigation.",
    parameters: Type.Object({
      patterns: Type.Array(Type.String(), {
        description:
          "URL patterns to block (glob syntax, e.g., ['**/analytics*', '**/ads*']).",
      }),
    }),

    async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
      try {
        const { page: p } = await deps.ensureBrowser();
        const results: ActiveRoute[] = [];

        // One Playwright route (and one registry entry) per pattern.
        for (const pattern of params.patterns) {
          const routeId = nextRouteId++;

          const handler = async (route: any) => {
            await route.abort("blockedbyclient");
          };

          await p.route(pattern, handler);

          const cleanup = async () => {
            try {
              await p.unroute(pattern, handler);
            } catch {
              /* cleanup — route may already be removed or page closed */
            }
          };

          const routeInfo: ActiveRoute = {
            id: routeId,
            pattern,
            type: "block",
            description: `Block ${pattern}`,
          };

          activeRoutes.push(routeInfo);
          routeCleanups.set(routeId, cleanup);
          results.push(routeInfo);
        }

        return {
          content: [
            {
              type: "text",
              text: `Blocked ${results.length} URL pattern(s):\n${results.map((r) => ` - ${r.description} (ID: ${r.id})`).join("\n")}\nActive routes: ${activeRoutes.length}`,
            },
          ],
          details: { blocked: results, activeRouteCount: activeRoutes.length },
        };
      } catch (err: any) {
        return {
          content: [
            { type: "text", text: `Block URLs failed: ${err.message}` },
          ],
          details: { error: err.message },
          isError: true,
        };
      }
    },
  });

  // -------------------------------------------------------------------------
  // browser_clear_routes
  // -------------------------------------------------------------------------
  pi.registerTool({
    name: "browser_clear_routes",
    label: "Browser Clear Routes",
    description:
      "Remove all active route mocks and URL blocks. Also lists currently active routes if called with no routes active.",
    parameters: Type.Object({}),

    async execute(_toolCallId, _params, _signal, _onUpdate, _ctx) {
      try {
        await deps.ensureBrowser();
        const count = activeRoutes.length;

        if (count === 0) {
          return {
            content: [{ type: "text", text: "No active routes to clear." }],
            details: { cleared: 0 },
          };
        }

        // Snapshot descriptions before the registry is emptied below.
        const routeDescriptions = activeRoutes.map((r) => r.description);

        // Clean up all routes
        // (each cleanup is internally try/caught, so one failure cannot
        // abort the loop)
        for (const [_id, cleanup] of routeCleanups) {
          await cleanup();
        }

        // Reset shared registry state in place.
        activeRoutes.length = 0;
        routeCleanups.clear();

        return {
          content: [
            {
              type: "text",
              text: `Cleared ${count} route(s):\n${routeDescriptions.map((d) => ` - ${d}`).join("\n")}`,
            },
          ],
          details: { cleared: count, routes: routeDescriptions },
        };
      } catch (err: any) {
        return {
          content: [
            { type: "text", text: `Clear routes failed: ${err.message}` },
          ],
          details: { error: err.message },
          isError: true,
        };
      }
    },
  });
}
|
||||
|
|
@ -1,421 +0,0 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import {
|
||||
registryGetActive,
|
||||
registryListPages,
|
||||
registrySetActive,
|
||||
} from "../core.js";
|
||||
import type { ToolDeps } from "../state.js";
|
||||
import { getActiveFrame, getPageRegistry, setActiveFrame } from "../state.js";
|
||||
|
||||
/**
 * Registers page/tab and frame management tools (browser_list_pages,
 * browser_switch_page, browser_close_page, browser_list_frames,
 * browser_select_frame) on the extension API.
 *
 * Page identity flows through the module-level page registry
 * (getPageRegistry / registry* helpers); frame selection flows through
 * getActiveFrame / setActiveFrame.
 *
 * @param pi   Extension API used to register each tool.
 * @param deps Facade providing browser/page access.
 */
export function registerPageTools(pi: ExtensionAPI, deps: ToolDeps): void {
  // -------------------------------------------------------------------------
  // browser_list_pages
  // -------------------------------------------------------------------------
  pi.registerTool({
    name: "browser_list_pages",
    label: "Browser List Pages",
    description:
      "List all open browser pages/tabs with their IDs, titles, URLs, and active status. Use to see what pages are available before switching.",
    parameters: Type.Object({}),

    async execute(_toolCallId, _params, _signal, _onUpdate, _ctx) {
      try {
        await deps.ensureBrowser();
        const pageRegistry = getPageRegistry();
        // Refresh cached title/url on every entry before listing; stale
        // entries for closed pages are silently skipped.
        for (const entry of pageRegistry.pages) {
          try {
            entry.title = await entry.page.title();
            entry.url = entry.page.url();
          } catch {
            // Page may have been closed
          }
        }
        const pages = registryListPages(pageRegistry);
        if (pages.length === 0) {
          return {
            content: [{ type: "text", text: "No pages open." }],
            details: { pages: [], count: 0 },
          };
        }
        const lines = pages.map((p: any) => {
          const active = p.isActive ? " ← active" : "";
          const opener = p.opener !== null ? ` (opener: ${p.opener})` : "";
          return ` [${p.id}] ${p.title || "(untitled)"} — ${p.url}${opener}${active}`;
        });
        return {
          content: [
            {
              type: "text",
              text: `${pages.length} page(s):\n${lines.join("\n")}`,
            },
          ],
          details: { pages, count: pages.length },
        };
      } catch (err: any) {
        return {
          content: [
            { type: "text", text: `List pages failed: ${err.message}` },
          ],
          details: { error: err.message },
          isError: true,
        };
      }
    },
  });

  // -------------------------------------------------------------------------
  // browser_switch_page
  // -------------------------------------------------------------------------
  pi.registerTool({
    name: "browser_switch_page",
    label: "Browser Switch Page",
    description:
      "Switch the active browser page/tab by page ID. Use browser_list_pages to see available IDs. Clears any active frame selection.",
    parameters: Type.Object({
      id: Type.Number({
        description: "Page ID to switch to (from browser_list_pages)",
      }),
    }),

    async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
      try {
        await deps.ensureBrowser();
        const pageRegistry = getPageRegistry();
        registrySetActive(pageRegistry, params.id);
        // Frame selection belongs to the previous page — always reset it.
        setActiveFrame(null);
        const entry = registryGetActive(pageRegistry);
        await entry.page.bringToFront();
        // Title fetch is best-effort; url() is synchronous and safe.
        const title = await entry.page.title().catch(() => "");
        const url = entry.page.url();
        // Keep the registry entry's cached metadata current.
        entry.title = title;
        entry.url = url;
        return {
          content: [
            {
              type: "text",
              text: `Switched to page ${params.id}: ${title || "(untitled)"} — ${url}`,
            },
          ],
          details: { id: params.id, title, url },
        };
      } catch (err: any) {
        return {
          content: [
            { type: "text", text: `Switch page failed: ${err.message}` },
          ],
          details: { error: err.message },
          isError: true,
        };
      }
    },
  });

  // -------------------------------------------------------------------------
  // browser_close_page
  // -------------------------------------------------------------------------
  pi.registerTool({
    name: "browser_close_page",
    label: "Browser Close Page",
    description:
      "Close a specific browser page/tab by ID. Cannot close the last remaining page. The page's close event triggers automatic registry cleanup and active-page fallback.",
    parameters: Type.Object({
      id: Type.Number({
        description: "Page ID to close (from browser_list_pages)",
      }),
    }),

    async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
      try {
        await deps.ensureBrowser();
        const pageRegistry = getPageRegistry();
        // Guard: refuse to close the final page — that is browser_close's job.
        if (pageRegistry.pages.length <= 1) {
          return {
            content: [
              {
                type: "text",
                text: `Cannot close the last remaining page. Use browser_close to close the entire browser.`,
              },
            ],
            details: {
              error: "last_page",
              pageCount: pageRegistry.pages.length,
            },
            isError: true,
          };
        }
        const entry = pageRegistry.pages.find((e: any) => e.id === params.id);
        if (!entry) {
          const available = pageRegistry.pages.map((e: any) => e.id);
          return {
            content: [
              {
                type: "text",
                text: `Page ${params.id} not found. Available page IDs: [${available.join(", ")}].`,
              },
            ],
            details: { error: "not_found", available },
            isError: true,
          };
        }
        // Per the tool description, registry removal and active-page fallback
        // happen via the page's close event, not here.
        await entry.page.close();
        setActiveFrame(null);
        // Refresh cached metadata on the surviving pages before re-listing.
        for (const remaining of pageRegistry.pages) {
          try {
            remaining.title = await remaining.page.title();
            remaining.url = remaining.page.url();
          } catch {
            /* non-fatal — page may have been closed or navigated away */
          }
        }
        const pages = registryListPages(pageRegistry);
        const lines = pages.map((p: any) => {
          const active = p.isActive ? " ← active" : "";
          return ` [${p.id}] ${p.title || "(untitled)"} — ${p.url}${active}`;
        });
        return {
          content: [
            {
              type: "text",
              text: `Closed page ${params.id}. ${pages.length} page(s) remaining:\n${lines.join("\n")}`,
            },
          ],
          details: { closedId: params.id, pages, count: pages.length },
        };
      } catch (err: any) {
        return {
          content: [
            { type: "text", text: `Close page failed: ${err.message}` },
          ],
          details: { error: err.message },
          isError: true,
        };
      }
    },
  });

  // -------------------------------------------------------------------------
  // browser_list_frames
  // -------------------------------------------------------------------------
  pi.registerTool({
    name: "browser_list_frames",
    label: "Browser List Frames",
    description:
      "List all frames in the active page, including the main frame and any iframes. Shows frame name, URL, and parent frame name. Use before browser_select_frame to identify available frames.",
    parameters: Type.Object({}),

    async execute(_toolCallId, _params, _signal, _onUpdate, _ctx) {
      try {
        await deps.ensureBrowser();
        const p = deps.getActivePage();
        const frames = p.frames();
        const mainFrame = p.mainFrame();
        const activeFrame = getActiveFrame();
        const frameList = frames.map((f, index) => {
          const isMain = f === mainFrame;
          // Prefer the parent's own name; unnamed children of the main frame
          // report "main" as their parent.
          const parentName =
            f.parentFrame()?.name() ||
            (f.parentFrame() === mainFrame ? "main" : "");
          return {
            index,
            name: f.name() || (isMain ? "main" : `(unnamed-${index})`),
            url: f.url(),
            isMain,
            parentName: isMain ? null : parentName || "main",
            isActive: f === activeFrame,
          };
        });
        const lines = frameList.map((f) => {
          const main = f.isMain ? " [main]" : "";
          const active = f.isActive ? " ← selected" : "";
          const parent = f.parentName ? ` (parent: ${f.parentName})` : "";
          return ` [${f.index}] "${f.name}" — ${f.url}${main}${parent}${active}`;
        });
        const activeInfo = activeFrame
          ? `Active frame: "${activeFrame.name() || "(unnamed)"}"`
          : "No frame selected (operating on main page)";
        return {
          content: [
            {
              type: "text",
              text: `${frameList.length} frame(s) in active page:\n${lines.join("\n")}\n\n${activeInfo}`,
            },
          ],
          details: {
            frames: frameList,
            count: frameList.length,
            activeFrame: activeFrame?.name() ?? null,
          },
        };
      } catch (err: any) {
        return {
          content: [
            { type: "text", text: `List frames failed: ${err.message}` },
          ],
          details: { error: err.message },
          isError: true,
        };
      }
    },
  });

  // -------------------------------------------------------------------------
  // browser_select_frame
  // -------------------------------------------------------------------------
  pi.registerTool({
    name: "browser_select_frame",
    label: "Browser Select Frame",
    description:
      'Select a frame within the active page to operate on. Find frames by name, URL pattern, or index. Pass null or "main" to reset back to the main page frame. Once a frame is selected, tools like browser_evaluate, browser_find, and browser_click will operate within that frame (after T03 migration).',
    parameters: Type.Object({
      name: Type.Optional(
        Type.String({
          description:
            "Frame name to select. Use 'main' or 'null' to reset to main frame.",
        }),
      ),
      urlPattern: Type.Optional(
        Type.String({
          description: "URL substring to match against frame URLs.",
        }),
      ),
      index: Type.Optional(
        Type.Number({ description: "Frame index from browser_list_frames." }),
      ),
    }),

    // Selection criteria are tried in priority order: reset sentinel, then
    // name, then urlPattern, then index; with none given, an error is returned.
    async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
      try {
        await deps.ensureBrowser();
        const p = deps.getActivePage();
        const frames = p.frames();

        // Reset sentinels: literal "main", the string "null", or a null value.
        if (
          params.name === "main" ||
          params.name === "null" ||
          params.name === null
        ) {
          setActiveFrame(null);
          return {
            content: [
              {
                type: "text",
                text: "Reset to main page frame. Tools will operate on the main page.",
              },
            ],
            details: { activeFrame: null },
          };
        }

        if (params.name) {
          const frame = frames.find((f) => f.name() === params.name);
          if (!frame) {
            const available = frames.map(
              (f, i) => `[${i}] "${f.name() || "(unnamed)"}" — ${f.url()}`,
            );
            return {
              content: [
                {
                  type: "text",
                  text: `Frame with name "${params.name}" not found.\nAvailable frames:\n ${available.join("\n ")}`,
                },
              ],
              details: { error: "frame_not_found", available },
              isError: true,
            };
          }
          setActiveFrame(frame);
          return {
            content: [
              {
                type: "text",
                text: `Selected frame "${frame.name()}" — ${frame.url()}`,
              },
            ],
            details: { name: frame.name(), url: frame.url() },
          };
        }

        if (params.urlPattern) {
          // Substring match (not glob) against each frame's URL.
          const frame = frames.find((f) =>
            f.url().includes(params.urlPattern!),
          );
          if (!frame) {
            const available = frames.map(
              (f, i) => `[${i}] "${f.name() || "(unnamed)"}" — ${f.url()}`,
            );
            return {
              content: [
                {
                  type: "text",
                  text: `No frame URL matches "${params.urlPattern}".\nAvailable frames:\n ${available.join("\n ")}`,
                },
              ],
              details: { error: "frame_not_found", available },
              isError: true,
            };
          }
          setActiveFrame(frame);
          return {
            content: [
              {
                type: "text",
                text: `Selected frame "${frame.name() || "(unnamed)"}" — ${frame.url()}`,
              },
            ],
            details: { name: frame.name(), url: frame.url() },
          };
        }

        if (params.index !== undefined) {
          if (params.index < 0 || params.index >= frames.length) {
            return {
              content: [
                {
                  type: "text",
                  text: `Frame index ${params.index} out of range. ${frames.length} frame(s) available (0-${frames.length - 1}).`,
                },
              ],
              details: { error: "index_out_of_range", count: frames.length },
              isError: true,
            };
          }
          const frame = frames[params.index];
          setActiveFrame(frame);
          return {
            content: [
              {
                type: "text",
                text: `Selected frame [${params.index}] "${frame.name() || "(unnamed)"}" — ${frame.url()}`,
              },
            ],
            details: {
              index: params.index,
              name: frame.name(),
              url: frame.url(),
            },
          };
        }

        // No selection criterion was supplied at all.
        return {
          content: [
            {
              type: "text",
              text: "Provide name, urlPattern, or index to select a frame. Use name='main' to reset to main frame.",
            },
          ],
          details: { error: "no_criteria" },
          isError: true,
        };
      } catch (err: any) {
        return {
          content: [
            { type: "text", text: `Select frame failed: ${err.message}` },
          ],
          details: { error: err.message },
          isError: true,
        };
      }
    },
  });
}
|
||||
|
|
@ -1,122 +0,0 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import type { ToolDeps } from "../state.js";
|
||||
|
||||
export function registerPdfTools(pi: ExtensionAPI, deps: ToolDeps): void {
|
||||
pi.registerTool({
|
||||
name: "browser_save_pdf",
|
||||
label: "Browser Save PDF",
|
||||
description:
|
||||
"Render current page as PDF artifact via Playwright's page.pdf(). " +
|
||||
"Supports A4/Letter/custom page formats and optional background graphics. " +
|
||||
"Writes to session artifacts directory. Chromium only.",
|
||||
parameters: Type.Object({
|
||||
filename: Type.Optional(
|
||||
Type.String({
|
||||
description:
|
||||
"Output filename (default: auto-generated from page title + timestamp).",
|
||||
}),
|
||||
),
|
||||
format: Type.Optional(
|
||||
Type.String({
|
||||
description:
|
||||
"Page format: 'A4' (default), 'Letter', 'Legal', 'Tabloid', or custom like '8.5in x 11in'. " +
|
||||
"Custom format uses CSS dimension syntax for width x height.",
|
||||
}),
|
||||
),
|
||||
printBackground: Type.Optional(
|
||||
Type.Boolean({
|
||||
description: "Include background graphics (default: true).",
|
||||
}),
|
||||
),
|
||||
}),
|
||||
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
try {
|
||||
const { page: p } = await deps.ensureBrowser();
|
||||
|
||||
const url = p.url();
|
||||
const title = await p.title().catch(() => "untitled");
|
||||
|
||||
// Resolve filename
|
||||
const timestamp = deps.formatArtifactTimestamp(Date.now());
|
||||
const safeName = deps.sanitizeArtifactName(
|
||||
params.filename || `${title}-${timestamp}`,
|
||||
`pdf-${timestamp}`,
|
||||
);
|
||||
const filename = safeName.endsWith(".pdf")
|
||||
? safeName
|
||||
: `${safeName}.pdf`;
|
||||
|
||||
// Resolve format
|
||||
const knownFormats = new Set([
|
||||
"A4",
|
||||
"Letter",
|
||||
"Legal",
|
||||
"Tabloid",
|
||||
"Ledger",
|
||||
"A0",
|
||||
"A1",
|
||||
"A2",
|
||||
"A3",
|
||||
"A5",
|
||||
"A6",
|
||||
]);
|
||||
const formatInput = params.format ?? "A4";
|
||||
const pdfOptions: Record<string, unknown> = {};
|
||||
|
||||
if (knownFormats.has(formatInput)) {
|
||||
pdfOptions.format = formatInput;
|
||||
} else {
|
||||
// Custom format: parse "WIDTHin x HEIGHTin" or "WIDTHcm x HEIGHTcm" etc.
|
||||
const customMatch = formatInput.match(/^(.+?)\s*[xX×]\s*(.+)$/);
|
||||
if (customMatch) {
|
||||
pdfOptions.width = customMatch[1]!.trim();
|
||||
pdfOptions.height = customMatch[2]!.trim();
|
||||
} else {
|
||||
pdfOptions.format = "A4"; // fallback
|
||||
}
|
||||
}
|
||||
|
||||
pdfOptions.printBackground = params.printBackground ?? true;
|
||||
|
||||
// Generate PDF
|
||||
await deps.ensureSessionArtifactDir();
|
||||
const outputPath = deps.buildSessionArtifactPath(filename);
|
||||
pdfOptions.path = outputPath;
|
||||
|
||||
await p.pdf(pdfOptions as any);
|
||||
|
||||
// Read file size
|
||||
const { stat } = await import("node:fs/promises");
|
||||
const fileStat = await stat(outputPath);
|
||||
const sizeBytes = fileStat.size;
|
||||
const sizeKB = (sizeBytes / 1024).toFixed(1);
|
||||
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `PDF saved: ${outputPath}\nSize: ${sizeKB} KB\nFormat: ${formatInput}\nPage: ${title}\nURL: ${url}`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
path: outputPath,
|
||||
sizeBytes,
|
||||
format: formatInput,
|
||||
pageUrl: url,
|
||||
pageTitle: title,
|
||||
},
|
||||
};
|
||||
} catch (err: any) {
|
||||
return {
|
||||
content: [
|
||||
{ type: "text", text: `PDF generation failed: ${err.message}` },
|
||||
],
|
||||
details: { error: err.message },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
|
|
@ -1,900 +0,0 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import { getSnapshotModeConfig, SNAPSHOT_MODES } from "../core.js";
|
||||
import type { RefNode, ToolDeps } from "../state.js";
|
||||
import {
|
||||
getActiveFrame,
|
||||
getCurrentRefMap,
|
||||
getRefMetadata,
|
||||
getRefVersion,
|
||||
setCurrentRefMap,
|
||||
setRefMetadata,
|
||||
setRefVersion,
|
||||
} from "../state.js";
|
||||
|
||||
export function registerRefTools(pi: ExtensionAPI, deps: ToolDeps): void {
|
||||
// -------------------------------------------------------------------------
|
||||
// browser_snapshot_refs
|
||||
// -------------------------------------------------------------------------
|
||||
pi.registerTool({
|
||||
name: "browser_snapshot_refs",
|
||||
label: "Browser Snapshot Refs",
|
||||
description:
|
||||
"Capture a compact inventory of interactive elements and assign deterministic versioned refs (@vN:e1, @vN:e2, ...). Use these refs with browser_click_ref, browser_fill_ref, and browser_hover_ref.",
|
||||
parameters: Type.Object({
|
||||
selector: Type.Optional(
|
||||
Type.String({
|
||||
description:
|
||||
"Optional CSS selector scope for the snapshot (e.g. 'main', 'form', '#modal').",
|
||||
}),
|
||||
),
|
||||
interactiveOnly: Type.Optional(
|
||||
Type.Boolean({
|
||||
description: "Include only interactive elements (default: true).",
|
||||
}),
|
||||
),
|
||||
limit: Type.Optional(
|
||||
Type.Number({
|
||||
description: "Maximum number of elements to include (default: 40).",
|
||||
}),
|
||||
),
|
||||
mode: Type.Optional(
|
||||
Type.String({
|
||||
description:
|
||||
"Semantic snapshot mode that pre-filters elements by category. When set, overrides interactiveOnly. Modes: interactive, form, dialog, navigation, errors, headings, visible_only.",
|
||||
}),
|
||||
),
|
||||
}),
|
||||
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
try {
|
||||
const { page: p } = await deps.ensureBrowser();
|
||||
const target = deps.getActiveTarget();
|
||||
|
||||
const mode = params.mode;
|
||||
if (mode !== undefined) {
|
||||
const modeConfig = getSnapshotModeConfig(mode);
|
||||
if (!modeConfig) {
|
||||
const validModes = Object.keys(SNAPSHOT_MODES).join(", ");
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Unknown snapshot mode: "${mode}". Valid modes: ${validModes}`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
error: `Unknown mode: ${mode}`,
|
||||
validModes: Object.keys(SNAPSHOT_MODES),
|
||||
},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
const interactiveOnly = params.interactiveOnly !== false;
|
||||
const limit = Math.max(
|
||||
1,
|
||||
Math.min(200, Math.floor(params.limit ?? 40)),
|
||||
);
|
||||
const rawNodes = await deps.buildRefSnapshot(target, {
|
||||
selector: params.selector,
|
||||
interactiveOnly,
|
||||
limit,
|
||||
mode,
|
||||
});
|
||||
|
||||
const newVersion = getRefVersion() + 1;
|
||||
setRefVersion(newVersion);
|
||||
const nextMap: Record<string, RefNode> = {};
|
||||
for (let i = 0; i < rawNodes.length; i += 1) {
|
||||
const ref = `e${i + 1}`;
|
||||
nextMap[ref] = { ref, ...rawNodes[i] };
|
||||
}
|
||||
setCurrentRefMap(nextMap);
|
||||
const activeFrame = getActiveFrame();
|
||||
const frameCtx = activeFrame
|
||||
? activeFrame.name() || activeFrame.url()
|
||||
: undefined;
|
||||
setRefMetadata({
|
||||
url: p.url(),
|
||||
timestamp: Date.now(),
|
||||
selectorScope: params.selector,
|
||||
interactiveOnly,
|
||||
limit,
|
||||
version: newVersion,
|
||||
frameContext: frameCtx,
|
||||
mode,
|
||||
});
|
||||
|
||||
if (rawNodes.length === 0) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: "No elements found for ref snapshot (try interactiveOnly=false or a wider selector scope).",
|
||||
},
|
||||
],
|
||||
details: {
|
||||
count: 0,
|
||||
version: newVersion,
|
||||
metadata: getRefMetadata(),
|
||||
refs: {},
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const versionedRefs: Record<string, RefNode> = {};
|
||||
const lines = Object.values(nextMap).map((node) => {
|
||||
const versionedRef = deps.formatVersionedRef(newVersion, node.ref);
|
||||
versionedRefs[versionedRef] = node;
|
||||
const parts: string[] = [versionedRef, node.role || node.tag];
|
||||
if (node.name) parts.push(`"${node.name}"`);
|
||||
if (node.href) parts.push(`href="${node.href.slice(0, 80)}"`);
|
||||
if (!node.isVisible) parts.push("(hidden)");
|
||||
if (!node.isEnabled) parts.push("(disabled)");
|
||||
return parts.join(" ");
|
||||
});
|
||||
|
||||
const modeLabel = mode ? `Mode: ${mode}\n` : "";
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text:
|
||||
`Ref snapshot v${newVersion} (${rawNodes.length} element(s))\n` +
|
||||
`URL: ${p.url()}\n` +
|
||||
`Scope: ${params.selector ?? "body"}\n` +
|
||||
modeLabel +
|
||||
`Use versioned refs exactly as shown (e.g. @v${newVersion}:e1).\n\n` +
|
||||
lines.join("\n"),
|
||||
},
|
||||
],
|
||||
details: {
|
||||
count: rawNodes.length,
|
||||
version: newVersion,
|
||||
metadata: getRefMetadata(),
|
||||
refs: nextMap,
|
||||
versionedRefs,
|
||||
},
|
||||
};
|
||||
} catch (err: any) {
|
||||
return {
|
||||
content: [
|
||||
{ type: "text", text: `Snapshot refs failed: ${err.message}` },
|
||||
],
|
||||
details: { error: err.message },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// browser_get_ref
|
||||
// -------------------------------------------------------------------------
|
||||
pi.registerTool({
|
||||
name: "browser_get_ref",
|
||||
label: "Browser Get Ref",
|
||||
description:
|
||||
"Inspect stored metadata for one deterministic element ref (prefer versioned format, e.g. @v3:e1).",
|
||||
parameters: Type.Object({
|
||||
ref: Type.String({
|
||||
description: "Reference id, preferably versioned (e.g. '@v3:e1').",
|
||||
}),
|
||||
}),
|
||||
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
const parsedRef = deps.parseRef(params.ref);
|
||||
const refMetadata = getRefMetadata();
|
||||
const refVersion = getRefVersion();
|
||||
if (
|
||||
parsedRef.version !== null &&
|
||||
refMetadata &&
|
||||
parsedRef.version !== refMetadata.version
|
||||
) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: deps.staleRefGuidance(
|
||||
parsedRef.display,
|
||||
`snapshot version mismatch (have v${refMetadata.version})`,
|
||||
),
|
||||
},
|
||||
],
|
||||
details: {
|
||||
error: "ref_stale",
|
||||
ref: parsedRef.display,
|
||||
expectedVersion: refMetadata.version,
|
||||
receivedVersion: parsedRef.version,
|
||||
},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
|
||||
const currentRefMap = getCurrentRefMap();
|
||||
const node = currentRefMap[parsedRef.key];
|
||||
if (!node) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: deps.staleRefGuidance(parsedRef.display, "ref not found"),
|
||||
},
|
||||
],
|
||||
details: {
|
||||
error: "ref_not_found",
|
||||
ref: parsedRef.display,
|
||||
metadata: refMetadata,
|
||||
},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
|
||||
const versionedRef = deps.formatVersionedRef(
|
||||
refMetadata?.version ?? refVersion,
|
||||
node.ref,
|
||||
);
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `${versionedRef}: ${node.role || node.tag}${node.name ? ` "${node.name}"` : ""}\nVisible: ${node.isVisible}\nEnabled: ${node.isEnabled}\nPath: ${node.xpathOrPath}`,
|
||||
},
|
||||
],
|
||||
details: { ref: versionedRef, node, metadata: refMetadata },
|
||||
};
|
||||
},
|
||||
});
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// browser_click_ref
|
||||
// -------------------------------------------------------------------------
|
||||
pi.registerTool({
|
||||
name: "browser_click_ref",
|
||||
label: "Browser Click Ref",
|
||||
description:
|
||||
"Click a previously snapshotted element by deterministic versioned ref (e.g. @v3:e2).",
|
||||
parameters: Type.Object({
|
||||
ref: Type.String({
|
||||
description: "Reference id in versioned format, e.g. '@v3:e2'.",
|
||||
}),
|
||||
}),
|
||||
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
const parsedRef = deps.parseRef(params.ref);
|
||||
const requestedRef = parsedRef.display;
|
||||
try {
|
||||
const { page: p } = await deps.ensureBrowser();
|
||||
const target = deps.getActiveTarget();
|
||||
const refMetadata = getRefMetadata();
|
||||
const refVersion = getRefVersion();
|
||||
if (parsedRef.version === null) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Unversioned ref ${requestedRef} is ambiguous. Use a versioned ref (e.g. @v${refMetadata?.version ?? refVersion}:e1) from browser_snapshot_refs.`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
error: "ref_unversioned",
|
||||
ref: requestedRef,
|
||||
metadata: refMetadata,
|
||||
},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
if (refMetadata && parsedRef.version !== refMetadata.version) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: deps.staleRefGuidance(
|
||||
requestedRef,
|
||||
`snapshot version mismatch (have v${refMetadata.version})`,
|
||||
),
|
||||
},
|
||||
],
|
||||
details: {
|
||||
error: "ref_stale",
|
||||
ref: requestedRef,
|
||||
expectedVersion: refMetadata.version,
|
||||
receivedVersion: parsedRef.version,
|
||||
},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
const currentRefMap = getCurrentRefMap();
|
||||
const ref = parsedRef.key;
|
||||
const node = currentRefMap[ref];
|
||||
if (!node) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: deps.staleRefGuidance(requestedRef, "ref not found"),
|
||||
},
|
||||
],
|
||||
details: {
|
||||
error: "ref_not_found",
|
||||
ref: requestedRef,
|
||||
metadata: refMetadata,
|
||||
},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
if (refMetadata?.url && refMetadata.url !== p.url()) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: deps.staleRefGuidance(
|
||||
requestedRef,
|
||||
"URL changed since snapshot",
|
||||
),
|
||||
},
|
||||
],
|
||||
details: {
|
||||
error: "ref_stale",
|
||||
ref: requestedRef,
|
||||
snapshotUrl: refMetadata.url,
|
||||
currentUrl: p.url(),
|
||||
},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
|
||||
const resolved = await deps.resolveRefTarget(target, node);
|
||||
if (!resolved.ok) {
|
||||
const reason = (resolved as { ok: false; reason: string }).reason;
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: deps.staleRefGuidance(requestedRef, reason),
|
||||
},
|
||||
],
|
||||
details: { error: "ref_stale", ref: requestedRef, reason },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
|
||||
const beforeState = await deps.captureCompactPageState(p, {
|
||||
includeBodyText: true,
|
||||
target,
|
||||
});
|
||||
const beforeUrl = beforeState.url;
|
||||
const beforeHash = deps.getUrlHash(beforeUrl);
|
||||
const beforeTargetState = await deps.captureClickTargetState(
|
||||
target,
|
||||
resolved.selector,
|
||||
);
|
||||
await target
|
||||
.locator(resolved.selector)
|
||||
.first()
|
||||
.click({ timeout: 8000 });
|
||||
const settle = await deps.settleAfterActionAdaptive(p);
|
||||
|
||||
const afterState = await deps.captureCompactPageState(p, {
|
||||
includeBodyText: true,
|
||||
target,
|
||||
});
|
||||
const afterUrl = afterState.url;
|
||||
const afterHash = deps.getUrlHash(afterUrl);
|
||||
const afterTargetState = await deps.captureClickTargetState(
|
||||
target,
|
||||
resolved.selector,
|
||||
);
|
||||
const targetStateChanged =
|
||||
beforeTargetState.exists !== afterTargetState.exists ||
|
||||
beforeTargetState.ariaExpanded !== afterTargetState.ariaExpanded ||
|
||||
beforeTargetState.ariaPressed !== afterTargetState.ariaPressed ||
|
||||
beforeTargetState.ariaSelected !== afterTargetState.ariaSelected ||
|
||||
beforeTargetState.open !== afterTargetState.open;
|
||||
const verification = deps.verificationFromChecks(
|
||||
[
|
||||
{
|
||||
name: "url_changed",
|
||||
passed: afterUrl !== beforeUrl,
|
||||
value: afterUrl,
|
||||
expected: `!= ${beforeUrl}`,
|
||||
},
|
||||
{
|
||||
name: "hash_changed",
|
||||
passed: afterHash !== beforeHash,
|
||||
value: afterHash,
|
||||
expected: `!= ${beforeHash}`,
|
||||
},
|
||||
{
|
||||
name: "target_state_changed",
|
||||
passed: targetStateChanged,
|
||||
value: afterTargetState,
|
||||
expected: beforeTargetState,
|
||||
},
|
||||
{
|
||||
name: "dialog_open",
|
||||
passed: afterState.dialog.count > beforeState.dialog.count,
|
||||
value: afterState.dialog.count,
|
||||
expected: `> ${beforeState.dialog.count}`,
|
||||
},
|
||||
],
|
||||
"Ref may now point to an inert element. Refresh refs with browser_snapshot_refs and retry.",
|
||||
);
|
||||
|
||||
const summary = deps.formatCompactStateSummary(afterState);
|
||||
const jsErrors = deps.getRecentErrors(p.url());
|
||||
const versionedRef = deps.formatVersionedRef(
|
||||
refMetadata?.version ?? refVersion,
|
||||
node.ref,
|
||||
);
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Clicked ${versionedRef} (${node.role || node.tag}${node.name ? ` "${node.name}"` : ""})\n${deps.verificationLine(verification)}${jsErrors}\n\nPage summary:\n${summary}`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
ref: versionedRef,
|
||||
selector: resolved.selector,
|
||||
url: p.url(),
|
||||
...settle,
|
||||
...verification,
|
||||
},
|
||||
};
|
||||
} catch (err: any) {
|
||||
const errorShot = await deps.captureErrorScreenshot(
|
||||
deps.getActivePageOrNull(),
|
||||
);
|
||||
const reason = deps.firstErrorLine(err);
|
||||
const content: any[] = [
|
||||
{
|
||||
type: "text",
|
||||
text: deps.staleRefGuidance(
|
||||
requestedRef,
|
||||
`action failed: ${reason}`,
|
||||
),
|
||||
},
|
||||
{ type: "text", text: `Click ref failed: ${err.message}` },
|
||||
];
|
||||
if (errorShot) {
|
||||
content.push({
|
||||
type: "image",
|
||||
data: errorShot.data,
|
||||
mimeType: errorShot.mimeType,
|
||||
});
|
||||
}
|
||||
return {
|
||||
content,
|
||||
details: {
|
||||
error: err.message,
|
||||
ref: requestedRef,
|
||||
hint: "Run browser_snapshot_refs to refresh refs.",
|
||||
},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// browser_hover_ref
|
||||
// -------------------------------------------------------------------------
|
||||
pi.registerTool({
|
||||
name: "browser_hover_ref",
|
||||
label: "Browser Hover Ref",
|
||||
description:
|
||||
"Hover a previously snapshotted element by deterministic versioned ref (e.g. @v3:e4).",
|
||||
parameters: Type.Object({
|
||||
ref: Type.String({
|
||||
description: "Reference id in versioned format, e.g. '@v3:e4'.",
|
||||
}),
|
||||
}),
|
||||
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
const parsedRef = deps.parseRef(params.ref);
|
||||
const requestedRef = parsedRef.display;
|
||||
try {
|
||||
const { page: p } = await deps.ensureBrowser();
|
||||
const target = deps.getActiveTarget();
|
||||
const refMetadata = getRefMetadata();
|
||||
const refVersion = getRefVersion();
|
||||
if (parsedRef.version === null) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Unversioned ref ${requestedRef} is ambiguous. Use a versioned ref (e.g. @v${refMetadata?.version ?? refVersion}:e1) from browser_snapshot_refs.`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
error: "ref_unversioned",
|
||||
ref: requestedRef,
|
||||
metadata: refMetadata,
|
||||
},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
if (refMetadata && parsedRef.version !== refMetadata.version) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: deps.staleRefGuidance(
|
||||
requestedRef,
|
||||
`snapshot version mismatch (have v${refMetadata.version})`,
|
||||
),
|
||||
},
|
||||
],
|
||||
details: {
|
||||
error: "ref_stale",
|
||||
ref: requestedRef,
|
||||
expectedVersion: refMetadata.version,
|
||||
receivedVersion: parsedRef.version,
|
||||
},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
const currentRefMap = getCurrentRefMap();
|
||||
const ref = parsedRef.key;
|
||||
const node = currentRefMap[ref];
|
||||
if (!node) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: deps.staleRefGuidance(requestedRef, "ref not found"),
|
||||
},
|
||||
],
|
||||
details: {
|
||||
error: "ref_not_found",
|
||||
ref: requestedRef,
|
||||
metadata: refMetadata,
|
||||
},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
if (refMetadata?.url && refMetadata.url !== p.url()) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: deps.staleRefGuidance(
|
||||
requestedRef,
|
||||
"URL changed since snapshot",
|
||||
),
|
||||
},
|
||||
],
|
||||
details: {
|
||||
error: "ref_stale",
|
||||
ref: requestedRef,
|
||||
snapshotUrl: refMetadata.url,
|
||||
currentUrl: p.url(),
|
||||
},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
|
||||
const resolved = await deps.resolveRefTarget(target, node);
|
||||
if (!resolved.ok) {
|
||||
const reason = (resolved as { ok: false; reason: string }).reason;
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: deps.staleRefGuidance(requestedRef, reason),
|
||||
},
|
||||
],
|
||||
details: { error: "ref_stale", ref: requestedRef, reason },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
|
||||
await target
|
||||
.locator(resolved.selector)
|
||||
.first()
|
||||
.hover({ timeout: 8000 });
|
||||
const settle = await deps.settleAfterActionAdaptive(p);
|
||||
|
||||
const afterState = await deps.captureCompactPageState(p, {
|
||||
includeBodyText: false,
|
||||
target,
|
||||
});
|
||||
const summary = deps.formatCompactStateSummary(afterState);
|
||||
const jsErrors = deps.getRecentErrors(p.url());
|
||||
const versionedRef = deps.formatVersionedRef(
|
||||
refMetadata?.version ?? refVersion,
|
||||
node.ref,
|
||||
);
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Hovered ${versionedRef} (${node.role || node.tag}${node.name ? ` "${node.name}"` : ""})${jsErrors}\n\nPage summary:\n${summary}`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
ref: versionedRef,
|
||||
selector: resolved.selector,
|
||||
url: p.url(),
|
||||
...settle,
|
||||
},
|
||||
};
|
||||
} catch (err: any) {
|
||||
const errorShot = await deps.captureErrorScreenshot(
|
||||
deps.getActivePageOrNull(),
|
||||
);
|
||||
const reason = deps.firstErrorLine(err);
|
||||
const content: any[] = [
|
||||
{
|
||||
type: "text",
|
||||
text: deps.staleRefGuidance(
|
||||
requestedRef,
|
||||
`action failed: ${reason}`,
|
||||
),
|
||||
},
|
||||
{ type: "text", text: `Hover ref failed: ${err.message}` },
|
||||
];
|
||||
if (errorShot) {
|
||||
content.push({
|
||||
type: "image",
|
||||
data: errorShot.data,
|
||||
mimeType: errorShot.mimeType,
|
||||
});
|
||||
}
|
||||
return {
|
||||
content,
|
||||
details: {
|
||||
error: err.message,
|
||||
ref: requestedRef,
|
||||
hint: "Run browser_snapshot_refs to refresh refs.",
|
||||
},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// browser_fill_ref
|
||||
// -------------------------------------------------------------------------
|
||||
pi.registerTool({
|
||||
name: "browser_fill_ref",
|
||||
label: "Browser Fill Ref",
|
||||
description:
|
||||
"Fill/type text into an input-like element by deterministic versioned ref (e.g. @v3:e1).",
|
||||
parameters: Type.Object({
|
||||
ref: Type.String({
|
||||
description: "Reference id in versioned format, e.g. '@v3:e1'.",
|
||||
}),
|
||||
text: Type.String({ description: "Text to enter." }),
|
||||
clearFirst: Type.Optional(
|
||||
Type.Boolean({
|
||||
description: "Clear existing value first (default: false).",
|
||||
}),
|
||||
),
|
||||
submit: Type.Optional(
|
||||
Type.Boolean({
|
||||
description: "Press Enter after typing (default: false).",
|
||||
}),
|
||||
),
|
||||
slowly: Type.Optional(
|
||||
Type.Boolean({
|
||||
description: "Type character-by-character (default: false).",
|
||||
}),
|
||||
),
|
||||
}),
|
||||
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
const parsedRef = deps.parseRef(params.ref);
|
||||
const requestedRef = parsedRef.display;
|
||||
try {
|
||||
const { page: p } = await deps.ensureBrowser();
|
||||
const target = deps.getActiveTarget();
|
||||
const refMetadata = getRefMetadata();
|
||||
const refVersion = getRefVersion();
|
||||
if (parsedRef.version === null) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Unversioned ref ${requestedRef} is ambiguous. Use a versioned ref (e.g. @v${refMetadata?.version ?? refVersion}:e1) from browser_snapshot_refs.`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
error: "ref_unversioned",
|
||||
ref: requestedRef,
|
||||
metadata: refMetadata,
|
||||
},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
if (refMetadata && parsedRef.version !== refMetadata.version) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: deps.staleRefGuidance(
|
||||
requestedRef,
|
||||
`snapshot version mismatch (have v${refMetadata.version})`,
|
||||
),
|
||||
},
|
||||
],
|
||||
details: {
|
||||
error: "ref_stale",
|
||||
ref: requestedRef,
|
||||
expectedVersion: refMetadata.version,
|
||||
receivedVersion: parsedRef.version,
|
||||
},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
const currentRefMap = getCurrentRefMap();
|
||||
const ref = parsedRef.key;
|
||||
const node = currentRefMap[ref];
|
||||
if (!node) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: deps.staleRefGuidance(requestedRef, "ref not found"),
|
||||
},
|
||||
],
|
||||
details: {
|
||||
error: "ref_not_found",
|
||||
ref: requestedRef,
|
||||
metadata: refMetadata,
|
||||
},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
if (refMetadata?.url && refMetadata.url !== p.url()) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: deps.staleRefGuidance(
|
||||
requestedRef,
|
||||
"URL changed since snapshot",
|
||||
),
|
||||
},
|
||||
],
|
||||
details: {
|
||||
error: "ref_stale",
|
||||
ref: requestedRef,
|
||||
snapshotUrl: refMetadata.url,
|
||||
currentUrl: p.url(),
|
||||
},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
|
||||
const resolved = await deps.resolveRefTarget(target, node);
|
||||
if (!resolved.ok) {
|
||||
const reason = (resolved as { ok: false; reason: string }).reason;
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: deps.staleRefGuidance(requestedRef, reason),
|
||||
},
|
||||
],
|
||||
details: { error: "ref_stale", ref: requestedRef, reason },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
|
||||
const locator = target.locator(resolved.selector).first();
|
||||
const beforeUrl = p.url();
|
||||
if (params.slowly) {
|
||||
await locator.click({ timeout: 8000 });
|
||||
if (params.clearFirst) {
|
||||
await p.keyboard.press("Control+A");
|
||||
await p.keyboard.press("Delete");
|
||||
}
|
||||
await p.keyboard.type(params.text);
|
||||
} else {
|
||||
if (params.clearFirst) {
|
||||
await locator.fill("");
|
||||
}
|
||||
await locator.fill(params.text, { timeout: 8000 });
|
||||
}
|
||||
if (params.submit) {
|
||||
await p.keyboard.press("Enter");
|
||||
}
|
||||
const settle = await deps.settleAfterActionAdaptive(p);
|
||||
|
||||
const filledValue = await deps.readInputLikeValue(
|
||||
target,
|
||||
resolved.selector,
|
||||
);
|
||||
const afterUrl = p.url();
|
||||
const verification = deps.verificationFromChecks(
|
||||
[
|
||||
{
|
||||
name: "value_equals_expected",
|
||||
passed: filledValue === params.text,
|
||||
value: filledValue,
|
||||
expected: params.text,
|
||||
},
|
||||
{
|
||||
name: "value_contains_expected",
|
||||
passed:
|
||||
typeof filledValue === "string" &&
|
||||
filledValue.includes(params.text),
|
||||
value: filledValue,
|
||||
expected: params.text,
|
||||
},
|
||||
{
|
||||
name: "url_changed_after_submit",
|
||||
passed: !!params.submit && afterUrl !== beforeUrl,
|
||||
value: afterUrl,
|
||||
expected: `!= ${beforeUrl}`,
|
||||
},
|
||||
],
|
||||
"Try refreshing refs and confirm this ref still targets an input-like element.",
|
||||
);
|
||||
|
||||
const afterState = await deps.captureCompactPageState(p, {
|
||||
includeBodyText: true,
|
||||
target,
|
||||
});
|
||||
const summary = deps.formatCompactStateSummary(afterState);
|
||||
const jsErrors = deps.getRecentErrors(p.url());
|
||||
const versionedRef = deps.formatVersionedRef(
|
||||
refMetadata?.version ?? refVersion,
|
||||
node.ref,
|
||||
);
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Filled ${versionedRef} (${node.role || node.tag}${node.name ? ` "${node.name}"` : ""}) with "${params.text}"\n${deps.verificationLine(verification)}${jsErrors}\n\nPage summary:\n${summary}`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
ref: versionedRef,
|
||||
selector: resolved.selector,
|
||||
url: p.url(),
|
||||
filledValue,
|
||||
...settle,
|
||||
...verification,
|
||||
},
|
||||
};
|
||||
} catch (err: any) {
|
||||
const errorShot = await deps.captureErrorScreenshot(
|
||||
deps.getActivePageOrNull(),
|
||||
);
|
||||
const reason = deps.firstErrorLine(err);
|
||||
const content: any[] = [
|
||||
{
|
||||
type: "text",
|
||||
text: deps.staleRefGuidance(
|
||||
requestedRef,
|
||||
`action failed: ${reason}`,
|
||||
),
|
||||
},
|
||||
{ type: "text", text: `Fill ref failed: ${err.message}` },
|
||||
];
|
||||
if (errorShot) {
|
||||
content.push({
|
||||
type: "image",
|
||||
data: errorShot.data,
|
||||
mimeType: errorShot.mimeType,
|
||||
});
|
||||
}
|
||||
return {
|
||||
content,
|
||||
details: {
|
||||
error: err.message,
|
||||
ref: requestedRef,
|
||||
hint: "Run browser_snapshot_refs to refresh refs.",
|
||||
},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
|
|
@ -1,129 +0,0 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import {
|
||||
getScreenshotFormatOverride,
|
||||
getScreenshotQualityDefault,
|
||||
} from "../capture.js";
|
||||
import type { ToolDeps } from "../state.js";
|
||||
|
||||
export function registerScreenshotTools(
|
||||
pi: ExtensionAPI,
|
||||
deps: ToolDeps,
|
||||
): void {
|
||||
pi.registerTool({
|
||||
name: "browser_screenshot",
|
||||
label: "Browser Screenshot",
|
||||
description:
|
||||
"Take a screenshot of the current browser page and return it as an inline image. Uses JPEG for viewport/fullpage (smaller, configurable quality) and PNG for element crops (preserves transparency). Optionally crop to a specific element by CSS selector.",
|
||||
parameters: Type.Object({
|
||||
fullPage: Type.Optional(
|
||||
Type.Boolean({
|
||||
description: "Capture the full scrollable page (default: false)",
|
||||
}),
|
||||
),
|
||||
selector: Type.Optional(
|
||||
Type.String({
|
||||
description:
|
||||
"CSS selector of a specific element to screenshot (crops to that element's bounding box). If omitted, screenshots the entire viewport.",
|
||||
}),
|
||||
),
|
||||
quality: Type.Optional(
|
||||
Type.Number({
|
||||
description:
|
||||
"JPEG quality 1-100 (default: 80). Only applies to viewport/fullpage screenshots, not element crops. Lower = smaller image.",
|
||||
}),
|
||||
),
|
||||
}),
|
||||
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
try {
|
||||
const { page: p } = await deps.ensureBrowser();
|
||||
|
||||
let screenshotBuffer: Buffer;
|
||||
let mimeType: string;
|
||||
const formatOverride = getScreenshotFormatOverride();
|
||||
const quality = params.quality ?? getScreenshotQualityDefault(80);
|
||||
|
||||
if (params.selector) {
|
||||
const fmt = formatOverride ?? "png";
|
||||
const locator = p.locator(params.selector).first();
|
||||
if (fmt === "jpeg") {
|
||||
screenshotBuffer = await locator.screenshot({
|
||||
type: "jpeg",
|
||||
quality,
|
||||
scale: "css",
|
||||
});
|
||||
mimeType = "image/jpeg";
|
||||
} else {
|
||||
screenshotBuffer = await locator.screenshot({
|
||||
type: "png",
|
||||
scale: "css",
|
||||
});
|
||||
mimeType = "image/png";
|
||||
}
|
||||
} else {
|
||||
const fmt = formatOverride ?? "jpeg";
|
||||
if (fmt === "png") {
|
||||
screenshotBuffer = await p.screenshot({
|
||||
fullPage: params.fullPage ?? false,
|
||||
type: "png",
|
||||
scale: "css",
|
||||
});
|
||||
mimeType = "image/png";
|
||||
} else {
|
||||
screenshotBuffer = await p.screenshot({
|
||||
fullPage: params.fullPage ?? false,
|
||||
type: "jpeg",
|
||||
quality,
|
||||
scale: "css",
|
||||
});
|
||||
mimeType = "image/jpeg";
|
||||
}
|
||||
}
|
||||
|
||||
screenshotBuffer = await deps.constrainScreenshot(
|
||||
p,
|
||||
screenshotBuffer,
|
||||
mimeType,
|
||||
quality,
|
||||
);
|
||||
|
||||
const base64Data = screenshotBuffer.toString("base64");
|
||||
const title = await p.title();
|
||||
const url = p.url();
|
||||
const viewport = p.viewportSize();
|
||||
const vpText = viewport
|
||||
? `${viewport.width}x${viewport.height}`
|
||||
: "unknown";
|
||||
const scope = params.selector
|
||||
? `element "${params.selector}"`
|
||||
: params.fullPage
|
||||
? "full page"
|
||||
: "viewport";
|
||||
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Screenshot of ${scope}.\nPage: ${title}\nURL: ${url}\nViewport: ${vpText}`,
|
||||
},
|
||||
{
|
||||
type: "image",
|
||||
data: base64Data,
|
||||
mimeType,
|
||||
},
|
||||
],
|
||||
details: { title, url, scope, viewport: vpText },
|
||||
};
|
||||
} catch (err: any) {
|
||||
return {
|
||||
content: [
|
||||
{ type: "text", text: `Screenshot failed: ${err.message}` },
|
||||
],
|
||||
details: { error: err.message },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
|
|
@ -1,572 +0,0 @@
|
|||
import { stat } from "node:fs/promises";
|
||||
import path from "node:path";
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import {
|
||||
buildFailureHypothesis,
|
||||
formatTimelineEntries,
|
||||
summarizeBrowserSession,
|
||||
} from "../core.js";
|
||||
import type { ToolDeps } from "../state.js";
|
||||
import {
|
||||
ARTIFACT_ROOT,
|
||||
getActionTimeline,
|
||||
getActiveTraceSession,
|
||||
getConsoleLogs,
|
||||
getDialogLogs,
|
||||
getHarState,
|
||||
getNetworkLogs,
|
||||
getPageRegistry,
|
||||
getSessionArtifactDir,
|
||||
getSessionStartedAt,
|
||||
HAR_FILENAME,
|
||||
setActiveTraceSession,
|
||||
setHarState,
|
||||
} from "../state.js";
|
||||
import { ensureDir, getActiveFrameMetadata } from "../utils.js";
|
||||
|
||||
export function registerSessionTools(pi: ExtensionAPI, deps: ToolDeps): void {
|
||||
// -------------------------------------------------------------------------
|
||||
// browser_close
|
||||
// -------------------------------------------------------------------------
|
||||
pi.registerTool({
|
||||
name: "browser_close",
|
||||
label: "Browser Close",
|
||||
description: "Close the browser and clean up all resources.",
|
||||
parameters: Type.Object({}),
|
||||
|
||||
async execute(_toolCallId, _params, _signal, _onUpdate, _ctx) {
|
||||
try {
|
||||
await deps.closeBrowser();
|
||||
return {
|
||||
content: [{ type: "text", text: "Browser closed." }],
|
||||
details: {},
|
||||
};
|
||||
} catch (err: any) {
|
||||
return {
|
||||
content: [{ type: "text", text: `Close failed: ${err.message}` }],
|
||||
details: { error: err.message },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// browser_trace_start
|
||||
// -------------------------------------------------------------------------
|
||||
pi.registerTool({
|
||||
name: "browser_trace_start",
|
||||
label: "Browser Trace Start",
|
||||
description:
|
||||
"Start a Playwright trace for the current browser session and persist trace metadata under the session artifact directory.",
|
||||
parameters: Type.Object({
|
||||
name: Type.Optional(
|
||||
Type.String({
|
||||
description:
|
||||
"Optional short trace session name for artifact filenames.",
|
||||
}),
|
||||
),
|
||||
title: Type.Optional(
|
||||
Type.String({
|
||||
description: "Optional trace title recorded in metadata.",
|
||||
}),
|
||||
),
|
||||
}),
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
try {
|
||||
const { context: browserContext } = await deps.ensureBrowser();
|
||||
const activeTrace = getActiveTraceSession();
|
||||
if (activeTrace) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Trace already active: ${activeTrace.name}`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
error: "trace_already_active",
|
||||
activeTraceSession: activeTrace,
|
||||
...deps.getSessionArtifactMetadata(),
|
||||
},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
const startedAt = Date.now();
|
||||
const name = (
|
||||
params.name?.trim() ||
|
||||
`trace-${deps.formatArtifactTimestamp(startedAt)}`
|
||||
).replace(/[^a-zA-Z0-9._-]+/g, "-");
|
||||
await browserContext.tracing.start({
|
||||
screenshots: true,
|
||||
snapshots: true,
|
||||
sources: true,
|
||||
title: params.title ?? name,
|
||||
});
|
||||
setActiveTraceSession({ startedAt, name, title: params.title ?? name });
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Trace started: ${name}\nSession dir: ${getSessionArtifactDir()}`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
activeTraceSession: getActiveTraceSession(),
|
||||
...deps.getSessionArtifactMetadata(),
|
||||
},
|
||||
};
|
||||
} catch (err: any) {
|
||||
return {
|
||||
content: [
|
||||
{ type: "text", text: `Trace start failed: ${err.message}` },
|
||||
],
|
||||
details: { error: err.message, ...deps.getSessionArtifactMetadata() },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// browser_trace_stop
|
||||
// -------------------------------------------------------------------------
|
||||
pi.registerTool({
|
||||
name: "browser_trace_stop",
|
||||
label: "Browser Trace Stop",
|
||||
description:
|
||||
"Stop the active Playwright trace and write the trace zip to disk under the session artifact directory.",
|
||||
parameters: Type.Object({
|
||||
name: Type.Optional(
|
||||
Type.String({
|
||||
description: "Optional artifact basename override for the trace zip.",
|
||||
}),
|
||||
),
|
||||
}),
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
try {
|
||||
const { context: browserContext } = await deps.ensureBrowser();
|
||||
const activeTrace = getActiveTraceSession();
|
||||
if (!activeTrace) {
|
||||
return {
|
||||
content: [
|
||||
{ type: "text", text: "No active trace session to stop." },
|
||||
],
|
||||
details: {
|
||||
error: "trace_not_active",
|
||||
...deps.getSessionArtifactMetadata(),
|
||||
},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
const traceSession = activeTrace;
|
||||
const traceName = (params.name?.trim() || traceSession.name).replace(
|
||||
/[^a-zA-Z0-9._-]+/g,
|
||||
"-",
|
||||
);
|
||||
const tracePath = deps.buildSessionArtifactPath(
|
||||
`${traceName}.trace.zip`,
|
||||
);
|
||||
await browserContext.tracing.stop({ path: tracePath });
|
||||
const fileStat = await stat(tracePath);
|
||||
setActiveTraceSession(null);
|
||||
return {
|
||||
content: [{ type: "text", text: `Trace stopped: ${tracePath}` }],
|
||||
details: {
|
||||
path: tracePath,
|
||||
bytes: fileStat.size,
|
||||
elapsedMs: Date.now() - traceSession.startedAt,
|
||||
traceName,
|
||||
...deps.getSessionArtifactMetadata(),
|
||||
},
|
||||
};
|
||||
} catch (err: any) {
|
||||
return {
|
||||
content: [
|
||||
{ type: "text", text: `Trace stop failed: ${err.message}` },
|
||||
],
|
||||
details: { error: err.message, ...deps.getSessionArtifactMetadata() },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// browser_export_har
|
||||
// -------------------------------------------------------------------------
|
||||
pi.registerTool({
|
||||
name: "browser_export_har",
|
||||
label: "Browser Export HAR",
|
||||
description:
|
||||
"Export the truthfully recorded session HAR from disk to a stable artifact path and return compact metadata.",
|
||||
parameters: Type.Object({
|
||||
filename: Type.Optional(
|
||||
Type.String({
|
||||
description:
|
||||
"Optional destination filename within the session artifact directory.",
|
||||
}),
|
||||
),
|
||||
}),
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
try {
|
||||
await deps.ensureBrowser();
|
||||
const harState = getHarState();
|
||||
if (
|
||||
!harState.enabled ||
|
||||
!harState.configuredAtContextCreation ||
|
||||
!harState.path
|
||||
) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: "HAR export unavailable: HAR recording was not enabled at browser context creation.",
|
||||
},
|
||||
],
|
||||
details: {
|
||||
error: "har_not_enabled",
|
||||
...deps.getSessionArtifactMetadata(),
|
||||
},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
const sourcePath = harState.path;
|
||||
const destinationName = (
|
||||
params.filename?.trim() || `export-${HAR_FILENAME}`
|
||||
).replace(/[^a-zA-Z0-9._-]+/g, "-");
|
||||
const destinationPath = deps.buildSessionArtifactPath(destinationName);
|
||||
const exportResult =
|
||||
sourcePath === destinationPath
|
||||
? { path: sourcePath, bytes: (await stat(sourcePath)).size }
|
||||
: await deps.copyArtifactFile(sourcePath, destinationPath);
|
||||
setHarState({
|
||||
...harState,
|
||||
exportCount: harState.exportCount + 1,
|
||||
lastExportedPath: exportResult.path,
|
||||
lastExportedAt: Date.now(),
|
||||
});
|
||||
return {
|
||||
content: [
|
||||
{ type: "text", text: `HAR exported: ${exportResult.path}` },
|
||||
],
|
||||
details: {
|
||||
path: exportResult.path,
|
||||
bytes: exportResult.bytes,
|
||||
...deps.getSessionArtifactMetadata(),
|
||||
},
|
||||
};
|
||||
} catch (err: any) {
|
||||
return {
|
||||
content: [
|
||||
{ type: "text", text: `HAR export failed: ${err.message}` },
|
||||
],
|
||||
details: { error: err.message, ...deps.getSessionArtifactMetadata() },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// browser_timeline
|
||||
// -------------------------------------------------------------------------
|
||||
pi.registerTool({
|
||||
name: "browser_timeline",
|
||||
label: "Browser Timeline",
|
||||
description:
|
||||
"Return a compact structured summary of the tracked browser action timeline and optional on-disk export path.",
|
||||
parameters: Type.Object({
|
||||
writeToDisk: Type.Optional(
|
||||
Type.Boolean({
|
||||
description:
|
||||
"Write the timeline JSON to disk under the session artifact directory.",
|
||||
}),
|
||||
),
|
||||
filename: Type.Optional(
|
||||
Type.String({
|
||||
description: "Optional JSON filename when writeToDisk is true.",
|
||||
}),
|
||||
),
|
||||
}),
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
try {
|
||||
await deps.ensureBrowser();
|
||||
const actionTimeline = getActionTimeline();
|
||||
const timeline = formatTimelineEntries(actionTimeline.entries, {
|
||||
limit: actionTimeline.limit,
|
||||
totalActions: actionTimeline.nextId - 1,
|
||||
});
|
||||
let artifact: { path: string; bytes: number } | null = null;
|
||||
if (params.writeToDisk) {
|
||||
const filename = (params.filename?.trim() || "timeline.json").replace(
|
||||
/[^a-zA-Z0-9._-]+/g,
|
||||
"-",
|
||||
);
|
||||
artifact = await deps.writeArtifactFile(
|
||||
deps.buildSessionArtifactPath(filename),
|
||||
JSON.stringify(timeline, null, 2),
|
||||
);
|
||||
}
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: artifact
|
||||
? `${timeline.summary}\nArtifact: ${artifact.path}`
|
||||
: timeline.summary,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
...timeline,
|
||||
artifact,
|
||||
...deps.getSessionArtifactMetadata(),
|
||||
},
|
||||
};
|
||||
} catch (err: any) {
|
||||
return {
|
||||
content: [{ type: "text", text: `Timeline failed: ${err.message}` }],
|
||||
details: { error: err.message, ...deps.getSessionArtifactMetadata() },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// browser_session_summary
|
||||
// -------------------------------------------------------------------------
|
||||
pi.registerTool({
|
||||
name: "browser_session_summary",
|
||||
label: "Browser Session Summary",
|
||||
description:
|
||||
"Return a compact structured summary of the current browser session, including pages, actions, waits/assertions, bounded-history caveats, and trace/HAR state.",
|
||||
parameters: Type.Object({}),
|
||||
async execute(_toolCallId, _params, _signal, _onUpdate, _ctx) {
|
||||
try {
|
||||
await deps.ensureBrowser();
|
||||
const pages = await deps.getLivePagesSnapshot();
|
||||
const actionTimeline = getActionTimeline();
|
||||
const pageRegistry = getPageRegistry();
|
||||
const consoleLogs = getConsoleLogs();
|
||||
const networkLogs = getNetworkLogs();
|
||||
const dialogLogs = getDialogLogs();
|
||||
const baseSummary = summarizeBrowserSession({
|
||||
timeline: actionTimeline,
|
||||
totalActions: actionTimeline.nextId - 1,
|
||||
pages,
|
||||
activePageId: pageRegistry.activePageId,
|
||||
activeFrame: getActiveFrameMetadata(),
|
||||
consoleEntries: consoleLogs,
|
||||
networkEntries: networkLogs,
|
||||
dialogEntries: dialogLogs,
|
||||
consoleLimit: 1000,
|
||||
networkLimit: 1000,
|
||||
dialogLimit: 1000,
|
||||
sessionStartedAt: getSessionStartedAt(),
|
||||
now: Date.now(),
|
||||
});
|
||||
const failureHypothesis = buildFailureHypothesis({
|
||||
timeline: actionTimeline,
|
||||
consoleEntries: consoleLogs,
|
||||
networkEntries: networkLogs,
|
||||
dialogEntries: dialogLogs,
|
||||
});
|
||||
const activeTrace = getActiveTraceSession();
|
||||
const traceState = activeTrace
|
||||
? { status: "active", ...activeTrace }
|
||||
: {
|
||||
status: "inactive",
|
||||
lastTracePath: getSessionArtifactDir()
|
||||
? deps.buildSessionArtifactPath("*.trace.zip")
|
||||
: null,
|
||||
};
|
||||
const harState = getHarState();
|
||||
const harSummary = {
|
||||
enabled: harState.enabled,
|
||||
configuredAtContextCreation: harState.configuredAtContextCreation,
|
||||
path: harState.path,
|
||||
exportCount: harState.exportCount,
|
||||
lastExportedPath: harState.lastExportedPath,
|
||||
lastExportedAt: harState.lastExportedAt,
|
||||
};
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `${baseSummary.summary}\nFailure hypothesis: ${failureHypothesis}`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
...baseSummary,
|
||||
failureHypothesis,
|
||||
trace: traceState,
|
||||
har: harSummary,
|
||||
...deps.getSessionArtifactMetadata(),
|
||||
},
|
||||
};
|
||||
} catch (err: any) {
|
||||
return {
|
||||
content: [
|
||||
{ type: "text", text: `Session summary failed: ${err.message}` },
|
||||
],
|
||||
details: { error: err.message, ...deps.getSessionArtifactMetadata() },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// browser_debug_bundle
|
||||
// -------------------------------------------------------------------------
|
||||
pi.registerTool({
|
||||
name: "browser_debug_bundle",
|
||||
label: "Browser Debug Bundle",
|
||||
description:
|
||||
"Write a timestamped debug bundle to disk with screenshot, logs, timeline, pages, session summary, and accessibility output, then return compact paths and counts.",
|
||||
parameters: Type.Object({
|
||||
selector: Type.Optional(
|
||||
Type.String({
|
||||
description:
|
||||
"Optional CSS selector to scope the accessibility snapshot before fallback behavior applies.",
|
||||
}),
|
||||
),
|
||||
name: Type.Optional(
|
||||
Type.String({
|
||||
description:
|
||||
"Optional short bundle name suffix for the output directory.",
|
||||
}),
|
||||
),
|
||||
}),
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
try {
|
||||
const { page: p } = await deps.ensureBrowser();
|
||||
const startedAt = Date.now();
|
||||
const sessionDir = await deps.ensureSessionArtifactDir();
|
||||
const bundleDir = path.join(
|
||||
ARTIFACT_ROOT,
|
||||
`${deps.formatArtifactTimestamp(startedAt)}-${deps.sanitizeArtifactName(params.name ?? "debug-bundle", "debug-bundle")}`,
|
||||
);
|
||||
await ensureDir(bundleDir);
|
||||
const pages = await deps.getLivePagesSnapshot();
|
||||
const actionTimeline = getActionTimeline();
|
||||
const pageRegistry = getPageRegistry();
|
||||
const consoleLogs = getConsoleLogs();
|
||||
const networkLogs = getNetworkLogs();
|
||||
const dialogLogs = getDialogLogs();
|
||||
const timeline = formatTimelineEntries(actionTimeline.entries, {
|
||||
limit: actionTimeline.limit,
|
||||
totalActions: actionTimeline.nextId - 1,
|
||||
});
|
||||
const sessionSummary = summarizeBrowserSession({
|
||||
timeline: actionTimeline,
|
||||
totalActions: actionTimeline.nextId - 1,
|
||||
pages,
|
||||
activePageId: pageRegistry.activePageId,
|
||||
activeFrame: getActiveFrameMetadata(),
|
||||
consoleEntries: consoleLogs,
|
||||
networkEntries: networkLogs,
|
||||
dialogEntries: dialogLogs,
|
||||
consoleLimit: 1000,
|
||||
networkLimit: 1000,
|
||||
dialogLimit: 1000,
|
||||
sessionStartedAt: getSessionStartedAt(),
|
||||
now: Date.now(),
|
||||
});
|
||||
const failureHypothesis = buildFailureHypothesis({
|
||||
timeline: actionTimeline,
|
||||
consoleEntries: consoleLogs,
|
||||
networkEntries: networkLogs,
|
||||
dialogEntries: dialogLogs,
|
||||
});
|
||||
const accessibility = await deps.captureAccessibilityMarkdown(
|
||||
params.selector,
|
||||
);
|
||||
const screenshotPath = path.join(bundleDir, "screenshot.jpg");
|
||||
await p.screenshot({
|
||||
path: screenshotPath,
|
||||
type: "jpeg",
|
||||
quality: 80,
|
||||
fullPage: false,
|
||||
});
|
||||
const screenshotStat = await stat(screenshotPath);
|
||||
const artifacts = {
|
||||
screenshot: { path: screenshotPath, bytes: screenshotStat.size },
|
||||
console: await deps.writeArtifactFile(
|
||||
path.join(bundleDir, "console.json"),
|
||||
JSON.stringify(consoleLogs, null, 2),
|
||||
),
|
||||
network: await deps.writeArtifactFile(
|
||||
path.join(bundleDir, "network.json"),
|
||||
JSON.stringify(networkLogs, null, 2),
|
||||
),
|
||||
dialog: await deps.writeArtifactFile(
|
||||
path.join(bundleDir, "dialog.json"),
|
||||
JSON.stringify(dialogLogs, null, 2),
|
||||
),
|
||||
timeline: await deps.writeArtifactFile(
|
||||
path.join(bundleDir, "timeline.json"),
|
||||
JSON.stringify(timeline, null, 2),
|
||||
),
|
||||
summary: await deps.writeArtifactFile(
|
||||
path.join(bundleDir, "summary.json"),
|
||||
JSON.stringify(
|
||||
{
|
||||
...sessionSummary,
|
||||
failureHypothesis,
|
||||
trace: getActiveTraceSession(),
|
||||
har: getHarState(),
|
||||
sessionArtifactDir: sessionDir,
|
||||
},
|
||||
null,
|
||||
2,
|
||||
),
|
||||
),
|
||||
pages: await deps.writeArtifactFile(
|
||||
path.join(bundleDir, "pages.json"),
|
||||
JSON.stringify(pages, null, 2),
|
||||
),
|
||||
accessibility: await deps.writeArtifactFile(
|
||||
path.join(bundleDir, "accessibility.md"),
|
||||
accessibility.snapshot,
|
||||
),
|
||||
};
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Debug bundle written: ${bundleDir}\n${sessionSummary.summary}\nFailure hypothesis: ${failureHypothesis}`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
bundleDir,
|
||||
artifacts,
|
||||
accessibilityScope: accessibility.scope,
|
||||
accessibilitySource: accessibility.source,
|
||||
counts: {
|
||||
console: consoleLogs.length,
|
||||
network: networkLogs.length,
|
||||
dialog: dialogLogs.length,
|
||||
actions: timeline.retained,
|
||||
pages: pages.length,
|
||||
},
|
||||
elapsedMs: Date.now() - startedAt,
|
||||
summary: sessionSummary,
|
||||
failureHypothesis,
|
||||
...deps.getSessionArtifactMetadata(),
|
||||
},
|
||||
};
|
||||
} catch (err: any) {
|
||||
return {
|
||||
content: [
|
||||
{ type: "text", text: `Debug bundle failed: ${err.message}` },
|
||||
],
|
||||
details: { error: err.message, ...deps.getSessionArtifactMetadata() },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
|
|
@ -1,239 +0,0 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import type { ToolDeps } from "../state.js";
|
||||
|
||||
/**
|
||||
* State persistence tools — save/restore cookies, localStorage, sessionStorage.
|
||||
*/
|
||||
|
||||
const STATE_DIR = ".sf/browser-state";
|
||||
|
||||
export function registerStatePersistenceTools(
|
||||
pi: ExtensionAPI,
|
||||
deps: ToolDeps,
|
||||
): void {
|
||||
// -------------------------------------------------------------------------
|
||||
// browser_save_state
|
||||
// -------------------------------------------------------------------------
|
||||
pi.registerTool({
|
||||
name: "browser_save_state",
|
||||
label: "Browser Save State",
|
||||
description:
|
||||
"Save cookies, localStorage, and sessionStorage to disk so authenticated sessions survive browser restarts. " +
|
||||
"State files are written to .sf/browser-state/ and should be gitignored (may contain auth tokens). " +
|
||||
"Never displays secret values in output.",
|
||||
parameters: Type.Object({
|
||||
name: Type.Optional(
|
||||
Type.String({
|
||||
description:
|
||||
"Name for the state file (default: 'default'). Used as the filename stem.",
|
||||
}),
|
||||
),
|
||||
}),
|
||||
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
try {
|
||||
const { context: ctx, page: p } = await deps.ensureBrowser();
|
||||
const name = deps.sanitizeArtifactName(
|
||||
params.name ?? "default",
|
||||
"default",
|
||||
);
|
||||
|
||||
const { mkdir, writeFile } = await import("node:fs/promises");
|
||||
const path = await import("node:path");
|
||||
const stateDir = path.resolve(process.cwd(), STATE_DIR);
|
||||
await mkdir(stateDir, { recursive: true });
|
||||
|
||||
// 1. Playwright storageState: cookies + localStorage
|
||||
const storageState = await ctx.storageState();
|
||||
|
||||
// 2. sessionStorage: must be extracted per-origin via page.evaluate
|
||||
const sessionStorageData: Record<string, Record<string, string>> = {};
|
||||
try {
|
||||
const origin = new URL(p.url()).origin;
|
||||
const ssData = await p.evaluate(() => {
|
||||
const data: Record<string, string> = {};
|
||||
for (let i = 0; i < sessionStorage.length; i++) {
|
||||
const key = sessionStorage.key(i);
|
||||
if (key) data[key] = sessionStorage.getItem(key) ?? "";
|
||||
}
|
||||
return data;
|
||||
});
|
||||
if (Object.keys(ssData).length > 0) {
|
||||
sessionStorageData[origin] = ssData;
|
||||
}
|
||||
} catch {
|
||||
// Page may not have a valid origin (about:blank, etc.)
|
||||
}
|
||||
|
||||
const combined = {
|
||||
storageState,
|
||||
sessionStorage: sessionStorageData,
|
||||
savedAt: new Date().toISOString(),
|
||||
url: p.url(),
|
||||
};
|
||||
|
||||
const filePath = path.join(stateDir, `${name}.json`);
|
||||
await writeFile(filePath, JSON.stringify(combined, null, 2));
|
||||
|
||||
// Ensure .gitignore covers the state dir
|
||||
const gitignorePath = path.resolve(
|
||||
process.cwd(),
|
||||
STATE_DIR,
|
||||
".gitignore",
|
||||
);
|
||||
await writeFile(gitignorePath, "*\n!.gitignore\n").catch(() => {
|
||||
/* best-effort — .gitignore may already exist or dir may be read-only */
|
||||
});
|
||||
|
||||
const cookieCount = storageState.cookies?.length ?? 0;
|
||||
const localStorageOrigins = storageState.origins?.length ?? 0;
|
||||
const sessionStorageOrigins = Object.keys(sessionStorageData).length;
|
||||
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `State saved: ${filePath}\nCookies: ${cookieCount}\nlocalStorage origins: ${localStorageOrigins}\nsessionStorage origins: ${sessionStorageOrigins}`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
path: filePath,
|
||||
cookieCount,
|
||||
localStorageOrigins,
|
||||
sessionStorageOrigins,
|
||||
},
|
||||
};
|
||||
} catch (err: any) {
|
||||
return {
|
||||
content: [
|
||||
{ type: "text", text: `Save state failed: ${err.message}` },
|
||||
],
|
||||
details: { error: err.message },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
// -------------------------------------------------------------------------
|
||||
// browser_restore_state
|
||||
// -------------------------------------------------------------------------
|
||||
pi.registerTool({
|
||||
name: "browser_restore_state",
|
||||
label: "Browser Restore State",
|
||||
description:
|
||||
"Restore cookies, localStorage, and sessionStorage from a previously saved state file. " +
|
||||
"Injects cookies via context.addCookies() and storage via page.evaluate(). " +
|
||||
"For full fidelity, restore before navigating to the target site.",
|
||||
parameters: Type.Object({
|
||||
name: Type.Optional(
|
||||
Type.String({
|
||||
description:
|
||||
"Name of the state file to restore (default: 'default').",
|
||||
}),
|
||||
),
|
||||
}),
|
||||
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
try {
|
||||
const { context: ctx, page: p } = await deps.ensureBrowser();
|
||||
const name = deps.sanitizeArtifactName(
|
||||
params.name ?? "default",
|
||||
"default",
|
||||
);
|
||||
|
||||
const { readFile } = await import("node:fs/promises");
|
||||
const path = await import("node:path");
|
||||
const filePath = path.join(process.cwd(), STATE_DIR, `${name}.json`);
|
||||
|
||||
let raw: string;
|
||||
try {
|
||||
raw = await readFile(filePath, "utf-8");
|
||||
} catch {
|
||||
return {
|
||||
content: [
|
||||
{ type: "text", text: `State file not found: ${filePath}` },
|
||||
],
|
||||
details: { error: "file_not_found", path: filePath },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
|
||||
const combined = JSON.parse(raw);
|
||||
const storageState = combined.storageState;
|
||||
const sessionStorageData: Record<
|
||||
string,
|
||||
Record<string, string>
|
||||
> = combined.sessionStorage ?? {};
|
||||
|
||||
// 1. Restore cookies
|
||||
let cookieCount = 0;
|
||||
if (storageState?.cookies?.length) {
|
||||
await ctx.addCookies(storageState.cookies);
|
||||
cookieCount = storageState.cookies.length;
|
||||
}
|
||||
|
||||
// 2. Restore localStorage via page.evaluate
|
||||
let localStorageOrigins = 0;
|
||||
if (storageState?.origins?.length) {
|
||||
for (const origin of storageState.origins) {
|
||||
try {
|
||||
await p.evaluate(
|
||||
(items: Array<{ name: string; value: string }>) => {
|
||||
for (const { name, value } of items) {
|
||||
localStorage.setItem(name, value);
|
||||
}
|
||||
},
|
||||
origin.localStorage ?? [],
|
||||
);
|
||||
localStorageOrigins++;
|
||||
} catch {
|
||||
// Origin mismatch — localStorage can only be set on matching origin
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Restore sessionStorage via page.evaluate
|
||||
let sessionStorageOrigins = 0;
|
||||
for (const [_origin, data] of Object.entries(sessionStorageData)) {
|
||||
try {
|
||||
await p.evaluate((items: Record<string, string>) => {
|
||||
for (const [key, value] of Object.entries(items)) {
|
||||
sessionStorage.setItem(key, value);
|
||||
}
|
||||
}, data);
|
||||
sessionStorageOrigins++;
|
||||
} catch {
|
||||
// Origin mismatch
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `State restored from: ${filePath}\nCookies: ${cookieCount}\nlocalStorage origins: ${localStorageOrigins}\nsessionStorage origins: ${sessionStorageOrigins}\nSaved at: ${combined.savedAt ?? "unknown"}`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
path: filePath,
|
||||
cookieCount,
|
||||
localStorageOrigins,
|
||||
sessionStorageOrigins,
|
||||
savedAt: combined.savedAt,
|
||||
savedUrl: combined.url,
|
||||
},
|
||||
};
|
||||
} catch (err: any) {
|
||||
return {
|
||||
content: [
|
||||
{ type: "text", text: `Restore state failed: ${err.message}` },
|
||||
],
|
||||
details: { error: err.message },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
|
|
@ -1,155 +0,0 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import type { ToolDeps } from "../state.js";
|
||||
|
||||
export function registerVerifyTools(pi: ExtensionAPI, deps: ToolDeps): void {
|
||||
pi.registerTool({
|
||||
name: "browser_verify",
|
||||
label: "Browser Verify",
|
||||
description:
|
||||
"Run a structured browser verification flow: navigate to a URL, run checks (element visibility, text content), capture screenshots as evidence, and return structured pass/fail results.",
|
||||
promptGuidelines: [
|
||||
"Use browser_verify for UAT verification flows that need structured evidence.",
|
||||
"Each check produces a pass/fail result with captured evidence.",
|
||||
"Prefer this over manual navigation + assertion sequences for verification tasks.",
|
||||
],
|
||||
parameters: Type.Object({
|
||||
url: Type.String({ description: "URL to navigate to" }),
|
||||
checks: Type.Array(
|
||||
Type.Object({
|
||||
description: Type.String({ description: "What this check verifies" }),
|
||||
selector: Type.Optional(
|
||||
Type.String({ description: "CSS selector to check" }),
|
||||
),
|
||||
expectedText: Type.Optional(
|
||||
Type.String({ description: "Expected text content" }),
|
||||
),
|
||||
expectedVisible: Type.Optional(
|
||||
Type.Boolean({ description: "Whether element should be visible" }),
|
||||
),
|
||||
screenshot: Type.Optional(
|
||||
Type.Boolean({ description: "Capture screenshot as evidence" }),
|
||||
),
|
||||
}),
|
||||
{ description: "Verification checks to run" },
|
||||
),
|
||||
timeout: Type.Optional(
|
||||
Type.Number({
|
||||
description: "Navigation timeout in ms",
|
||||
default: 10000,
|
||||
}),
|
||||
),
|
||||
}),
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
const startTime = Date.now();
|
||||
const { page } = await deps.ensureBrowser();
|
||||
const timeout = params.timeout ?? 10000;
|
||||
|
||||
try {
|
||||
await page.goto(params.url, { waitUntil: "domcontentloaded", timeout });
|
||||
} catch (navErr) {
|
||||
const msg = navErr instanceof Error ? navErr.message : String(navErr);
|
||||
return {
|
||||
content: [
|
||||
{ type: "text" as const, text: `Navigation failed: ${msg}` },
|
||||
],
|
||||
details: {
|
||||
url: params.url,
|
||||
passed: false,
|
||||
checks: params.checks.map((c) => ({
|
||||
description: c.description,
|
||||
passed: false,
|
||||
error: msg,
|
||||
})),
|
||||
duration: Date.now() - startTime,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
const results: Array<{
|
||||
description: string;
|
||||
passed: boolean;
|
||||
actual?: string;
|
||||
evidence?: string;
|
||||
error?: string;
|
||||
}> = [];
|
||||
|
||||
for (const check of params.checks) {
|
||||
try {
|
||||
let passed = true;
|
||||
let actual: string | undefined;
|
||||
let evidence: string | undefined;
|
||||
|
||||
if (check.selector) {
|
||||
const element = await page.$(check.selector);
|
||||
|
||||
if (check.expectedVisible !== undefined) {
|
||||
const isVisible = element ? await element.isVisible() : false;
|
||||
passed = isVisible === check.expectedVisible;
|
||||
actual = `visible=${isVisible}`;
|
||||
}
|
||||
|
||||
if (check.expectedText !== undefined && element) {
|
||||
const text = await element.textContent();
|
||||
passed = passed && (text?.includes(check.expectedText) ?? false);
|
||||
actual = `text="${text?.slice(0, 200)}"`;
|
||||
}
|
||||
|
||||
if (
|
||||
!element &&
|
||||
(check.expectedVisible === true || check.expectedText)
|
||||
) {
|
||||
passed = false;
|
||||
actual = "element not found";
|
||||
}
|
||||
}
|
||||
|
||||
if (check.screenshot) {
|
||||
try {
|
||||
const buf = await page.screenshot({ type: "png" });
|
||||
evidence = `screenshot captured (${buf.length} bytes)`;
|
||||
} catch {
|
||||
evidence = "screenshot failed";
|
||||
}
|
||||
}
|
||||
|
||||
results.push({
|
||||
description: check.description,
|
||||
passed,
|
||||
actual,
|
||||
evidence,
|
||||
});
|
||||
} catch (checkErr) {
|
||||
results.push({
|
||||
description: check.description,
|
||||
passed: false,
|
||||
error:
|
||||
checkErr instanceof Error ? checkErr.message : String(checkErr),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const allPassed = results.every((r) => r.passed);
|
||||
const summary = results
|
||||
.map(
|
||||
(r) =>
|
||||
`${r.passed ? "PASS" : "FAIL"}: ${r.description}${r.actual ? ` (${r.actual})` : ""}${r.error ? ` — ${r.error}` : ""}`,
|
||||
)
|
||||
.join("\n");
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text" as const,
|
||||
text: `Verification ${allPassed ? "PASSED" : "FAILED"} (${results.filter((r) => r.passed).length}/${results.length})\n\n${summary}`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
url: params.url,
|
||||
passed: allPassed,
|
||||
checks: results,
|
||||
duration: Date.now() - startTime,
|
||||
},
|
||||
};
|
||||
},
|
||||
});
|
||||
}
|
||||
|
|
@ -1,235 +0,0 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import type { ToolDeps } from "../state.js";
|
||||
|
||||
/**
|
||||
* Visual regression diffing — compare current page screenshot against a stored baseline.
|
||||
*/
|
||||
|
||||
const BASELINE_DIR = ".sf/browser-baselines";
|
||||
|
||||
export function registerVisualDiffTools(
|
||||
pi: ExtensionAPI,
|
||||
deps: ToolDeps,
|
||||
): void {
|
||||
pi.registerTool({
|
||||
name: "browser_visual_diff",
|
||||
label: "Browser Visual Diff",
|
||||
description:
|
||||
"Compare current page screenshot against a stored baseline pixel-by-pixel. " +
|
||||
"Returns similarity score (0–1), diff pixel count, and optionally generates a diff image highlighting changes. " +
|
||||
"On first run with no baseline, saves the current screenshot as the baseline. " +
|
||||
"Baselines are stored in .sf/browser-baselines/ (gitignored, environment-specific).",
|
||||
parameters: Type.Object({
|
||||
name: Type.Optional(
|
||||
Type.String({
|
||||
description:
|
||||
"Baseline name (default: auto-generated from URL + viewport). " +
|
||||
"Use consistent names to compare the same view across runs.",
|
||||
}),
|
||||
),
|
||||
selector: Type.Optional(
|
||||
Type.String({
|
||||
description:
|
||||
"CSS selector to scope comparison to a specific element instead of full viewport.",
|
||||
}),
|
||||
),
|
||||
threshold: Type.Optional(
|
||||
Type.Number({
|
||||
description:
|
||||
"Pixel matching threshold 0–1 (default: 0.1). " +
|
||||
"Higher values are more tolerant of anti-aliasing and rendering differences.",
|
||||
}),
|
||||
),
|
||||
updateBaseline: Type.Optional(
|
||||
Type.Boolean({
|
||||
description:
|
||||
"If true, overwrite the existing baseline with the current screenshot (default: false).",
|
||||
}),
|
||||
),
|
||||
}),
|
||||
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
try {
|
||||
const { page: p } = await deps.ensureBrowser();
|
||||
const { mkdir, readFile, writeFile } = await import("node:fs/promises");
|
||||
const pathMod = await import("node:path");
|
||||
|
||||
const baselineDir = pathMod.resolve(process.cwd(), BASELINE_DIR);
|
||||
await mkdir(baselineDir, { recursive: true });
|
||||
|
||||
// Ensure .gitignore
|
||||
const gitignorePath = pathMod.join(baselineDir, ".gitignore");
|
||||
await writeFile(gitignorePath, "*\n!.gitignore\n").catch(() => {
|
||||
/* best-effort — .gitignore may already exist or dir may be read-only */
|
||||
});
|
||||
|
||||
// Generate baseline name
|
||||
const url = p.url();
|
||||
const viewport = p.viewportSize();
|
||||
const vpSuffix = viewport
|
||||
? `${viewport.width}x${viewport.height}`
|
||||
: "unknown";
|
||||
const autoName = deps.sanitizeArtifactName(
|
||||
`${new URL(url).pathname.replace(/\//g, "-")}-${vpSuffix}`,
|
||||
`baseline-${vpSuffix}`,
|
||||
);
|
||||
const name = deps.sanitizeArtifactName(
|
||||
params.name ?? autoName,
|
||||
autoName,
|
||||
);
|
||||
|
||||
const baselinePath = pathMod.join(baselineDir, `${name}.png`);
|
||||
const diffPath = pathMod.join(baselineDir, `${name}-diff.png`);
|
||||
|
||||
// Capture current screenshot as PNG (needed for pixel comparison)
|
||||
let currentBuffer: Buffer;
|
||||
if (params.selector) {
|
||||
const locator = p.locator(params.selector).first();
|
||||
currentBuffer = await locator.screenshot({ type: "png" });
|
||||
} else {
|
||||
currentBuffer = await p.screenshot({ type: "png", fullPage: false });
|
||||
}
|
||||
|
||||
// Check if baseline exists
|
||||
let baselineBuffer: Buffer | null = null;
|
||||
try {
|
||||
baselineBuffer = (await readFile(baselinePath)) as Buffer;
|
||||
} catch {
|
||||
// No baseline yet
|
||||
}
|
||||
|
||||
if (!baselineBuffer || params.updateBaseline) {
|
||||
// Save as new baseline
|
||||
await writeFile(baselinePath, currentBuffer);
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: baselineBuffer
|
||||
? `Baseline updated: ${baselinePath}\nSize: ${(currentBuffer.length / 1024).toFixed(1)} KB`
|
||||
: `Baseline created (first run): ${baselinePath}\nSize: ${(currentBuffer.length / 1024).toFixed(1)} KB\nRe-run to compare against this baseline.`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
baselinePath,
|
||||
baselineCreated: !baselineBuffer,
|
||||
baselineUpdated: !!baselineBuffer,
|
||||
sizeBytes: currentBuffer.length,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Perform pixel comparison using sharp for PNG decoding
|
||||
const sharp = (await import("sharp")).default;
|
||||
|
||||
const baselineMeta = await sharp(baselineBuffer).metadata();
|
||||
const currentMeta = await sharp(currentBuffer).metadata();
|
||||
|
||||
const bWidth = baselineMeta.width ?? 0;
|
||||
const bHeight = baselineMeta.height ?? 0;
|
||||
const cWidth = currentMeta.width ?? 0;
|
||||
const cHeight = currentMeta.height ?? 0;
|
||||
|
||||
// If dimensions differ, report mismatch
|
||||
if (bWidth !== cWidth || bHeight !== cHeight) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Dimension mismatch: baseline is ${bWidth}x${bHeight}, current is ${cWidth}x${cHeight}. Cannot compare.\nUse updateBaseline: true to reset.`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
match: false,
|
||||
dimensionMismatch: true,
|
||||
baselineDimensions: { width: bWidth, height: bHeight },
|
||||
currentDimensions: { width: cWidth, height: cHeight },
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// Extract raw RGBA pixel data
|
||||
const baselineRaw = await sharp(baselineBuffer)
|
||||
.ensureAlpha()
|
||||
.raw()
|
||||
.toBuffer();
|
||||
const currentRaw = await sharp(currentBuffer)
|
||||
.ensureAlpha()
|
||||
.raw()
|
||||
.toBuffer();
|
||||
|
||||
const width = bWidth;
|
||||
const height = bHeight;
|
||||
const totalPixels = width * height;
|
||||
const threshold = params.threshold ?? 0.1;
|
||||
|
||||
// Simple pixel-by-pixel comparison (avoiding pixelmatch dependency)
|
||||
const diffData = Buffer.alloc(width * height * 4);
|
||||
let diffPixels = 0;
|
||||
const thresholdSq = threshold * threshold * 255 * 255 * 3;
|
||||
|
||||
for (let i = 0; i < totalPixels; i++) {
|
||||
const offset = i * 4;
|
||||
const dr = baselineRaw[offset] - currentRaw[offset];
|
||||
const dg = baselineRaw[offset + 1] - currentRaw[offset + 1];
|
||||
const db = baselineRaw[offset + 2] - currentRaw[offset + 2];
|
||||
const distSq = dr * dr + dg * dg + db * db;
|
||||
|
||||
if (distSq > thresholdSq) {
|
||||
diffPixels++;
|
||||
// Mark diff pixels as red
|
||||
diffData[offset] = 255; // R
|
||||
diffData[offset + 1] = 0; // G
|
||||
diffData[offset + 2] = 0; // B
|
||||
diffData[offset + 3] = 255; // A
|
||||
} else {
|
||||
// Dim unchanged pixels
|
||||
diffData[offset] = currentRaw[offset] >> 1;
|
||||
diffData[offset + 1] = currentRaw[offset + 1] >> 1;
|
||||
diffData[offset + 2] = currentRaw[offset + 2] >> 1;
|
||||
diffData[offset + 3] = 255;
|
||||
}
|
||||
}
|
||||
|
||||
const similarity = 1 - diffPixels / totalPixels;
|
||||
const match = diffPixels === 0;
|
||||
|
||||
// Save diff image
|
||||
await sharp(diffData, { raw: { width, height, channels: 4 } })
|
||||
.png()
|
||||
.toFile(diffPath);
|
||||
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: match
|
||||
? `Visual diff: MATCH (100% similar)\nBaseline: ${baselinePath}`
|
||||
: `Visual diff: ${(similarity * 100).toFixed(2)}% similar\nDiff pixels: ${diffPixels} of ${totalPixels} (${((diffPixels / totalPixels) * 100).toFixed(2)}%)\nDiff image: ${diffPath}\nBaseline: ${baselinePath}`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
match,
|
||||
similarity,
|
||||
diffPixels,
|
||||
totalPixels,
|
||||
diffPercentage: (diffPixels / totalPixels) * 100,
|
||||
dimensions: { width, height },
|
||||
baselinePath,
|
||||
diffImagePath: match ? undefined : diffPath,
|
||||
threshold,
|
||||
},
|
||||
};
|
||||
} catch (err: any) {
|
||||
return {
|
||||
content: [
|
||||
{ type: "text", text: `Visual diff failed: ${err.message}` },
|
||||
],
|
||||
details: { error: err.message },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
|
|
@ -1,378 +0,0 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import { StringEnum } from "@singularity-forge/pi-ai";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import {
|
||||
createRegionStableScript,
|
||||
includesNeedle,
|
||||
parseThreshold,
|
||||
validateWaitParams,
|
||||
} from "../core.js";
|
||||
import type { ToolDeps } from "../state.js";
|
||||
import { getConsoleLogs } from "../state.js";
|
||||
|
||||
export function registerWaitTools(pi: ExtensionAPI, deps: ToolDeps): void {
|
||||
pi.registerTool({
|
||||
name: "browser_wait_for",
|
||||
label: "Browser Wait For",
|
||||
description:
|
||||
"Wait for a condition before continuing. Use after actions that trigger async updates — data fetches, route changes, animations, loading spinners. Choose the appropriate condition: 'selector_visible' waits for an element to appear, 'selector_hidden' waits for it to disappear, 'url_contains' waits for the URL to match, 'network_idle' waits for all network requests to finish, 'delay' waits a fixed number of milliseconds, 'text_visible' waits for text to appear in the page body, 'text_hidden' waits for text to disappear from the page body, 'request_completed' waits for a network response whose URL contains the given substring, 'console_message' waits for a console log message containing the given substring, 'element_count' waits for the number of elements matching the CSS selector in 'value' to satisfy the 'threshold' expression (e.g. '>=3', '==0', '<5'), 'region_stable' waits for the DOM region matching the CSS selector in 'value' to stop changing.",
|
||||
parameters: Type.Object({
|
||||
condition: StringEnum([
|
||||
"selector_visible",
|
||||
"selector_hidden",
|
||||
"url_contains",
|
||||
"network_idle",
|
||||
"delay",
|
||||
"text_visible",
|
||||
"text_hidden",
|
||||
"request_completed",
|
||||
"console_message",
|
||||
"element_count",
|
||||
"region_stable",
|
||||
] as const),
|
||||
value: Type.Optional(
|
||||
Type.String({
|
||||
description:
|
||||
"For selector_visible/selector_hidden/element_count/region_stable: CSS selector. For url_contains/request_completed: URL substring. For text_visible/text_hidden/console_message: text substring. For delay: milliseconds as a string (e.g. '1000'). Not used for network_idle.",
|
||||
}),
|
||||
),
|
||||
threshold: Type.Optional(
|
||||
Type.String({
|
||||
description:
|
||||
"Threshold expression for element_count (e.g. '>=3', '==0', '<5', or bare '3' which defaults to >=). Only used with element_count condition.",
|
||||
}),
|
||||
),
|
||||
timeout: Type.Optional(
|
||||
Type.Number({
|
||||
description:
|
||||
"Maximum milliseconds to wait before failing (default: 10000)",
|
||||
}),
|
||||
),
|
||||
}),
|
||||
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
try {
|
||||
const { page: p } = await deps.ensureBrowser();
|
||||
const target = deps.getActiveTarget();
|
||||
const timeout = params.timeout ?? 10000;
|
||||
|
||||
const validation = validateWaitParams({
|
||||
condition: params.condition,
|
||||
value: params.value,
|
||||
threshold: (params as any).threshold,
|
||||
});
|
||||
if (validation) {
|
||||
return {
|
||||
content: [{ type: "text", text: validation.error }],
|
||||
details: { error: validation.error, condition: params.condition },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
|
||||
switch (params.condition) {
|
||||
case "selector_visible": {
|
||||
if (!params.value) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: "selector_visible requires a value (CSS selector)",
|
||||
},
|
||||
],
|
||||
details: {},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
await target.waitForSelector(params.value, {
|
||||
state: "visible",
|
||||
timeout,
|
||||
});
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Element "${params.value}" is now visible`,
|
||||
},
|
||||
],
|
||||
details: { condition: params.condition, value: params.value },
|
||||
};
|
||||
}
|
||||
|
||||
case "selector_hidden": {
|
||||
if (!params.value) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: "selector_hidden requires a value (CSS selector)",
|
||||
},
|
||||
],
|
||||
details: {},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
await target.waitForSelector(params.value, {
|
||||
state: "hidden",
|
||||
timeout,
|
||||
});
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Element "${params.value}" is now hidden`,
|
||||
},
|
||||
],
|
||||
details: { condition: params.condition, value: params.value },
|
||||
};
|
||||
}
|
||||
|
||||
case "url_contains": {
|
||||
if (!params.value) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: "url_contains requires a value (URL substring)",
|
||||
},
|
||||
],
|
||||
details: {},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
await p.waitForURL(
|
||||
(url) => url.toString().includes(params.value!),
|
||||
{ timeout },
|
||||
);
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `URL now contains "${params.value}". Current URL: ${p.url()}`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
condition: params.condition,
|
||||
value: params.value,
|
||||
url: p.url(),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
case "network_idle": {
|
||||
await p.waitForLoadState("networkidle", { timeout });
|
||||
return {
|
||||
content: [{ type: "text", text: "Network is idle" }],
|
||||
details: { condition: params.condition },
|
||||
};
|
||||
}
|
||||
|
||||
case "delay": {
|
||||
const ms = parseInt(params.value ?? "1000", 10);
|
||||
if (Number.isNaN(ms)) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: "delay requires a numeric value (milliseconds)",
|
||||
},
|
||||
],
|
||||
details: {},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
await new Promise((resolve) => setTimeout(resolve, ms));
|
||||
return {
|
||||
content: [{ type: "text", text: `Waited ${ms}ms` }],
|
||||
details: { condition: params.condition, ms },
|
||||
};
|
||||
}
|
||||
|
||||
case "text_visible": {
|
||||
await target.waitForFunction(
|
||||
(needle: string) => {
|
||||
const body = document.body?.innerText ?? "";
|
||||
return body.toLowerCase().includes(needle.toLowerCase());
|
||||
},
|
||||
params.value!,
|
||||
{ timeout },
|
||||
);
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Text "${params.value}" is now visible on the page`,
|
||||
},
|
||||
],
|
||||
details: { condition: params.condition, value: params.value },
|
||||
};
|
||||
}
|
||||
|
||||
case "text_hidden": {
|
||||
await target.waitForFunction(
|
||||
(needle: string) => {
|
||||
const body = document.body?.innerText ?? "";
|
||||
return !body.toLowerCase().includes(needle.toLowerCase());
|
||||
},
|
||||
params.value!,
|
||||
{ timeout },
|
||||
);
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Text "${params.value}" is no longer visible on the page`,
|
||||
},
|
||||
],
|
||||
details: { condition: params.condition, value: params.value },
|
||||
};
|
||||
}
|
||||
|
||||
case "request_completed": {
|
||||
const response = await deps
|
||||
.getActivePage()
|
||||
.waitForResponse((resp) => resp.url().includes(params.value!), {
|
||||
timeout,
|
||||
});
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Request completed: ${response.url()} (status ${response.status()})`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
condition: params.condition,
|
||||
value: params.value,
|
||||
url: response.url(),
|
||||
status: response.status(),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
case "console_message": {
|
||||
const needle = params.value!;
|
||||
const startTime = Date.now();
|
||||
while (Date.now() - startTime < timeout) {
|
||||
const match = getConsoleLogs().find((entry) =>
|
||||
includesNeedle(entry.text, needle),
|
||||
);
|
||||
if (match) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Console message matching "${needle}" found: "${match.text}"`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
condition: params.condition,
|
||||
value: needle,
|
||||
matchedText: match.text,
|
||||
matchedType: match.type,
|
||||
},
|
||||
};
|
||||
}
|
||||
await new Promise((resolve) => setTimeout(resolve, 100));
|
||||
}
|
||||
throw new Error(
|
||||
`Timed out waiting for console message matching "${needle}" (${timeout}ms)`,
|
||||
);
|
||||
}
|
||||
|
||||
case "element_count": {
|
||||
const threshold = parseThreshold(
|
||||
(params as any).threshold ?? ">=1",
|
||||
);
|
||||
if (!threshold) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `element_count threshold is malformed: "${(params as any).threshold}"`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
error: "malformed threshold",
|
||||
condition: params.condition,
|
||||
},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
const selector = params.value!;
|
||||
const op = threshold.op;
|
||||
const n = threshold.n;
|
||||
await target.waitForFunction(
|
||||
({
|
||||
selector,
|
||||
op,
|
||||
n,
|
||||
}: {
|
||||
selector: string;
|
||||
op: string;
|
||||
n: number;
|
||||
}) => {
|
||||
const count = document.querySelectorAll(selector).length;
|
||||
switch (op) {
|
||||
case ">=":
|
||||
return count >= n;
|
||||
case "<=":
|
||||
return count <= n;
|
||||
case "==":
|
||||
return count === n;
|
||||
case ">":
|
||||
return count > n;
|
||||
case "<":
|
||||
return count < n;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
},
|
||||
{ selector, op, n },
|
||||
{ timeout },
|
||||
);
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Element count for "${selector}" satisfies ${op}${n}`,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
condition: params.condition,
|
||||
value: selector,
|
||||
threshold: `${op}${n}`,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
case "region_stable": {
|
||||
const script = createRegionStableScript(params.value!);
|
||||
await target.waitForFunction(script, undefined, {
|
||||
timeout,
|
||||
polling: 200,
|
||||
});
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Region "${params.value}" is now stable`,
|
||||
},
|
||||
],
|
||||
details: { condition: params.condition, value: params.value },
|
||||
};
|
||||
}
|
||||
}
|
||||
} catch (err: any) {
|
||||
return {
|
||||
content: [{ type: "text", text: `Wait failed: ${err.message}` }],
|
||||
details: {
|
||||
error: err.message,
|
||||
condition: params.condition,
|
||||
value: params.value,
|
||||
},
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
|
|
@ -1,115 +0,0 @@
|
|||
import { Type } from "@sinclair/typebox";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import type { ToolDeps } from "../state.js";
|
||||
|
||||
/**
|
||||
* Region zoom / high-res capture — capture and upscale specific page regions.
|
||||
*/
|
||||
|
||||
export function registerZoomTools(pi: ExtensionAPI, deps: ToolDeps): void {
|
||||
pi.registerTool({
|
||||
name: "browser_zoom_region",
|
||||
label: "Browser Zoom Region",
|
||||
description:
|
||||
"Capture and optionally upscale a specific rectangular region of the page for detailed inspection. " +
|
||||
"Useful for dense UIs where full-page screenshots have text too small to read. " +
|
||||
"Returns the region as an inline image, same as browser_screenshot.",
|
||||
parameters: Type.Object({
|
||||
x: Type.Number({
|
||||
description: "Left coordinate of the region in CSS pixels.",
|
||||
}),
|
||||
y: Type.Number({
|
||||
description: "Top coordinate of the region in CSS pixels.",
|
||||
}),
|
||||
width: Type.Number({ description: "Width of the region in CSS pixels." }),
|
||||
height: Type.Number({
|
||||
description: "Height of the region in CSS pixels.",
|
||||
}),
|
||||
scale: Type.Optional(
|
||||
Type.Number({
|
||||
description:
|
||||
"Upscale factor (default: 2). Use 1 for native resolution, 2-4 for zoomed detail.",
|
||||
}),
|
||||
),
|
||||
}),
|
||||
|
||||
async execute(_toolCallId, params, _signal, _onUpdate, _ctx) {
|
||||
try {
|
||||
const { page: p } = await deps.ensureBrowser();
|
||||
const { x, y, width, height } = params;
|
||||
const scale = params.scale ?? 2;
|
||||
|
||||
// Validate dimensions
|
||||
if (width <= 0 || height <= 0) {
|
||||
return {
|
||||
content: [
|
||||
{ type: "text", text: "Width and height must be positive." },
|
||||
],
|
||||
details: { error: "invalid_dimensions" },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
|
||||
// Capture the region using Playwright's clip option
|
||||
const regionBuffer = await p.screenshot({
|
||||
type: "png",
|
||||
clip: { x, y, width, height },
|
||||
});
|
||||
|
||||
let outputBuffer: Buffer = regionBuffer;
|
||||
const outputMime = "image/png";
|
||||
|
||||
// Upscale if scale > 1
|
||||
if (scale > 1) {
|
||||
const sharp = (await import("sharp")).default;
|
||||
const targetWidth = Math.round(width * scale);
|
||||
const targetHeight = Math.round(height * scale);
|
||||
|
||||
outputBuffer = await sharp(regionBuffer)
|
||||
.resize(targetWidth, targetHeight, {
|
||||
kernel: "lanczos3",
|
||||
fit: "fill",
|
||||
})
|
||||
.png()
|
||||
.toBuffer();
|
||||
}
|
||||
|
||||
const base64Data = outputBuffer.toString("base64");
|
||||
const title = await p.title();
|
||||
const url = p.url();
|
||||
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Region capture: ${width}x${height} at (${x},${y})${scale > 1 ? ` upscaled ${scale}x to ${Math.round(width * scale)}x${Math.round(height * scale)}` : ""}\nPage: ${title}\nURL: ${url}`,
|
||||
},
|
||||
{
|
||||
type: "image",
|
||||
data: base64Data,
|
||||
mimeType: outputMime,
|
||||
},
|
||||
],
|
||||
details: {
|
||||
region: { x, y, width, height },
|
||||
scale,
|
||||
outputDimensions: {
|
||||
width: Math.round(width * scale),
|
||||
height: Math.round(height * scale),
|
||||
},
|
||||
title,
|
||||
url,
|
||||
},
|
||||
};
|
||||
} catch (err: any) {
|
||||
return {
|
||||
content: [
|
||||
{ type: "text", text: `Region zoom failed: ${err.message}` },
|
||||
],
|
||||
details: { error: err.message },
|
||||
isError: true,
|
||||
};
|
||||
}
|
||||
},
|
||||
});
|
||||
}
|
||||
|
|
@ -1,667 +0,0 @@
|
|||
/**
|
||||
* browser-tools — Node-side utility functions
|
||||
*
|
||||
* All functions that were helpers in index.ts but run in Node (not browser).
|
||||
* They import state accessors from ./state.ts — never raw module-level variables.
|
||||
*/
|
||||
|
||||
import { copyFile, mkdir, stat, writeFile } from "node:fs/promises";
|
||||
import path from "node:path";
|
||||
import {
|
||||
DEFAULT_MAX_BYTES,
|
||||
DEFAULT_MAX_LINES,
|
||||
truncateHead,
|
||||
} from "@singularity-forge/pi-coding-agent";
|
||||
import type { Frame, Page } from "playwright";
|
||||
import {
|
||||
beginAction,
|
||||
findAction,
|
||||
finishAction,
|
||||
registryListPages,
|
||||
toActionParamsSummary,
|
||||
} from "./core.js";
|
||||
import {
|
||||
ARTIFACT_ROOT,
|
||||
actionTimeline,
|
||||
type BrowserAssertionCheckInput,
|
||||
type BrowserVerificationCheck,
|
||||
type BrowserVerificationResult,
|
||||
type ClickTargetStateSnapshot,
|
||||
type CompactPageState,
|
||||
type CompactSelectorState,
|
||||
type ConsoleEntry,
|
||||
getActiveFrame,
|
||||
getActiveTraceSession,
|
||||
getConsoleLogs,
|
||||
getDialogLogs,
|
||||
getHarState,
|
||||
getNetworkLogs,
|
||||
getPendingCriticalRequestsByPage,
|
||||
getSessionArtifactDir,
|
||||
getSessionStartedAt,
|
||||
type NetworkEntry,
|
||||
type ParsedRefSpec,
|
||||
pageRegistry,
|
||||
setSessionArtifactDir,
|
||||
setSessionStartedAt,
|
||||
} from "./state.js";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Text truncation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export function truncateText(text: string): string {
|
||||
const result = truncateHead(text, {
|
||||
maxLines: DEFAULT_MAX_LINES,
|
||||
maxBytes: DEFAULT_MAX_BYTES,
|
||||
});
|
||||
if (result.truncated) {
|
||||
return (
|
||||
result.content +
|
||||
`\n\n[Output truncated: ${result.outputLines}/${result.totalLines} lines shown]`
|
||||
);
|
||||
}
|
||||
return result.content;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Artifact helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export function formatArtifactTimestamp(timestamp: number): string {
|
||||
return new Date(timestamp).toISOString().replace(/[:.]/g, "-");
|
||||
}
|
||||
|
||||
export async function ensureDir(dirPath: string): Promise<string> {
|
||||
await mkdir(dirPath, { recursive: true });
|
||||
return dirPath;
|
||||
}
|
||||
|
||||
export async function writeArtifactFile(
|
||||
filePath: string,
|
||||
content: string | Uint8Array,
|
||||
): Promise<{ path: string; bytes: number }> {
|
||||
await ensureDir(path.dirname(filePath));
|
||||
await writeFile(filePath, content);
|
||||
const fileStat = await stat(filePath);
|
||||
return { path: filePath, bytes: fileStat.size };
|
||||
}
|
||||
|
||||
export async function copyArtifactFile(
|
||||
sourcePath: string,
|
||||
destinationPath: string,
|
||||
): Promise<{ path: string; bytes: number }> {
|
||||
await ensureDir(path.dirname(destinationPath));
|
||||
await copyFile(sourcePath, destinationPath);
|
||||
const fileStat = await stat(destinationPath);
|
||||
return { path: destinationPath, bytes: fileStat.size };
|
||||
}
|
||||
|
||||
export function ensureSessionStartedAt(): number {
|
||||
let t = getSessionStartedAt();
|
||||
if (!t) {
|
||||
t = Date.now();
|
||||
setSessionStartedAt(t);
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
||||
export async function ensureSessionArtifactDir(): Promise<string> {
|
||||
const existing = getSessionArtifactDir();
|
||||
if (existing) {
|
||||
await ensureDir(existing);
|
||||
return existing;
|
||||
}
|
||||
const startedAt = ensureSessionStartedAt();
|
||||
const dir = path.join(
|
||||
ARTIFACT_ROOT,
|
||||
`${formatArtifactTimestamp(startedAt)}-session`,
|
||||
);
|
||||
setSessionArtifactDir(dir);
|
||||
await ensureDir(dir);
|
||||
return dir;
|
||||
}
|
||||
|
||||
export function buildSessionArtifactPath(filename: string): string {
|
||||
const dir = getSessionArtifactDir();
|
||||
if (!dir) {
|
||||
throw new Error("browser session artifact directory is not initialized");
|
||||
}
|
||||
return path.join(dir, filename);
|
||||
}
|
||||
|
||||
export function getActivePageMetadata() {
|
||||
const registry = pageRegistry;
|
||||
const activeEntry =
|
||||
registry.activePageId !== null
|
||||
? (registry.pages.find(
|
||||
(entry: any) => entry.id === registry.activePageId,
|
||||
) ?? null)
|
||||
: null;
|
||||
return {
|
||||
id: activeEntry?.id ?? null,
|
||||
title: activeEntry?.title ?? "",
|
||||
url: activeEntry?.url ?? "",
|
||||
};
|
||||
}
|
||||
|
||||
export function getActiveFrameMetadata() {
|
||||
const frame = getActiveFrame();
|
||||
if (!frame) {
|
||||
return { name: null, url: null };
|
||||
}
|
||||
return {
|
||||
name: frame.name() || null,
|
||||
url: frame.url() || null,
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Aggregate the current artifact/session state into one snapshot object:
 * artifact root and per-session directory, session start time, active trace
 * session, a shallow copy of the HAR state, and active page/frame summaries.
 */
export function getSessionArtifactMetadata() {
  return {
    artifactRoot: ARTIFACT_ROOT,
    sessionStartedAt: getSessionStartedAt(),
    sessionArtifactDir: getSessionArtifactDir(),
    activeTraceSession: getActiveTraceSession(),
    // Shallow copy so callers cannot mutate the live HAR state object.
    harState: { ...getHarState() },
    activePage: getActivePageMetadata(),
    activeFrame: getActiveFrameMetadata(),
  };
}
|
||||
|
||||
export function sanitizeArtifactName(value: string, fallback: string): string {
|
||||
const sanitized = value
|
||||
.trim()
|
||||
.replace(/[^a-zA-Z0-9._-]+/g, "-")
|
||||
.replace(/^-+|-+$/g, "");
|
||||
return sanitized || fallback;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Page helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* getLivePagesSnapshot requires ensureBrowser (circular) — it will be
|
||||
* wired in via ToolDeps. This is a factory that takes ensureBrowser.
|
||||
*/
|
||||
export function createGetLivePagesSnapshot(
|
||||
ensureBrowser: () => Promise<{ page: Page }>,
|
||||
) {
|
||||
return async function getLivePagesSnapshot() {
|
||||
await ensureBrowser();
|
||||
for (const entry of pageRegistry.pages) {
|
||||
try {
|
||||
entry.title = await entry.page.title();
|
||||
entry.url = entry.page.url();
|
||||
} catch {
|
||||
// Page may have been closed between snapshots.
|
||||
}
|
||||
}
|
||||
return registryListPages(pageRegistry);
|
||||
};
|
||||
}
|
||||
|
||||
export async function resolveAccessibilityScope(
|
||||
selector?: string,
|
||||
): Promise<{ selector?: string; scope: string; source: string }> {
|
||||
if (selector?.trim()) {
|
||||
return {
|
||||
selector: selector.trim(),
|
||||
scope: `selector:${selector.trim()}`,
|
||||
source: "explicit_selector",
|
||||
};
|
||||
}
|
||||
const frame = getActiveFrame();
|
||||
// We need getActiveTarget for dialog check, but that requires page access.
|
||||
// For non-frame scoping, the caller must handle dialog detection separately
|
||||
// if needed. Here we handle the frame case and fall through to full_page.
|
||||
if (frame) {
|
||||
return {
|
||||
selector: "body",
|
||||
scope: frame.name() ? `active frame:${frame.name()}` : "active frame",
|
||||
source: "active_frame",
|
||||
};
|
||||
}
|
||||
return { selector: "body", scope: "full page", source: "full_page" };
|
||||
}
|
||||
|
||||
/**
|
||||
* captureAccessibilityMarkdown — needs access to the active target.
|
||||
* Accepts the target (Page | Frame) so it doesn't need to pull from state.
|
||||
*/
|
||||
export async function captureAccessibilityMarkdown(
|
||||
target: Page | Frame,
|
||||
selector?: string,
|
||||
): Promise<{ snapshot: string; scope: string; source: string }> {
|
||||
const scopeInfo = await resolveAccessibilityScope(selector);
|
||||
const locator = target.locator(scopeInfo.selector ?? "body").first();
|
||||
const snapshot = await locator.ariaSnapshot();
|
||||
return { snapshot, scope: scopeInfo.scope, source: scopeInfo.source };
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Critical request tracking
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export function isCriticalResourceType(resourceType: string): boolean {
|
||||
return (
|
||||
resourceType === "document" ||
|
||||
resourceType === "fetch" ||
|
||||
resourceType === "xhr"
|
||||
);
|
||||
}
|
||||
|
||||
export function updatePendingCriticalRequests(p: Page, delta: number): void {
|
||||
const map = getPendingCriticalRequestsByPage();
|
||||
const current = map.get(p) ?? 0;
|
||||
map.set(p, Math.max(0, current + delta));
|
||||
}
|
||||
|
||||
export function getPendingCriticalRequests(p: Page): number {
|
||||
return getPendingCriticalRequestsByPage().get(p) ?? 0;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Verification helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export function verificationFromChecks(
|
||||
checks: BrowserVerificationCheck[],
|
||||
retryHint?: string,
|
||||
): BrowserVerificationResult {
|
||||
const passedChecks = checks
|
||||
.filter((check) => check.passed)
|
||||
.map((check) => check.name);
|
||||
const verified = passedChecks.length > 0;
|
||||
return {
|
||||
verified,
|
||||
checks,
|
||||
verificationSummary: verified
|
||||
? `PASS (${passedChecks.join(", ")})`
|
||||
: "SOFT-FAIL (no observable state change)",
|
||||
retryHint: verified ? undefined : retryHint,
|
||||
};
|
||||
}
|
||||
|
||||
export function verificationLine(
|
||||
verification: BrowserVerificationResult,
|
||||
): string {
|
||||
return `Verification: ${verification.verificationSummary}`;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Assertion helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
 * Gather everything assertion evaluation needs in one pass: a compact page
 * snapshot (scoped to the selectors the checks reference) plus console,
 * network, and action-timeline context.
 *
 * `captureCompactPageState` is injected by the caller (avoids a circular
 * import with the snapshot module); `target` optionally narrows capture to a
 * specific frame.
 */
export async function collectAssertionState(
  p: Page,
  checks: BrowserAssertionCheckInput[],
  captureCompactPageState: (
    p: Page,
    options?: {
      selectors?: string[];
      includeBodyText?: boolean;
      target?: Page | Frame;
    },
  ) => Promise<CompactPageState>,
  target?: Page | Frame,
): Promise<{
  url: string;
  title: string;
  bodyText: string;
  focus: string;
  selectorStates: Record<string, CompactSelectorState>;
  consoleEntries: ConsoleEntry[];
  networkEntries: NetworkEntry[];
  allConsoleEntries: ConsoleEntry[];
  allNetworkEntries: NetworkEntry[];
  actionTimeline: typeof actionTimeline;
}> {
  // Only selectors explicitly named by checks are captured.
  const selectors = checks
    .map((check) => check.selector)
    .filter((value): value is string => !!value);
  const compactState = await captureCompactPageState(p, {
    selectors,
    includeBodyText: true,
    target,
  });
  // Use the most recent sinceActionId across all checks as the log cutoff;
  // undefined means "no cutoff" (full logs are returned below anyway).
  const sinceActionId = checks.reduce<number | undefined>((max, check) => {
    if (check.sinceActionId === undefined) return max;
    if (max === undefined) return check.sinceActionId;
    return Math.max(max, check.sinceActionId);
  }, undefined);
  return {
    url: compactState.url,
    title: compactState.title,
    bodyText: compactState.bodyText,
    focus: compactState.focus,
    selectorStates: compactState.selectorStates,
    // Filtered views relative to the newest sinceActionId…
    consoleEntries: getConsoleEntriesSince(sinceActionId),
    networkEntries: getNetworkEntriesSince(sinceActionId),
    // …plus the unfiltered logs for checks that need full history.
    allConsoleEntries: getConsoleLogs(),
    allNetworkEntries: getNetworkLogs(),
    actionTimeline,
  };
}
|
||||
|
||||
export function formatAssertionText(
|
||||
result: ReturnType<typeof import("./core.js").evaluateAssertionChecks>,
|
||||
): string {
|
||||
const lines = [result.summary];
|
||||
for (const check of result.checks.slice(0, 8)) {
|
||||
lines.push(
|
||||
`- ${check.passed ? "PASS" : "FAIL"} ${check.name}: expected ${JSON.stringify(check.expected)}, got ${JSON.stringify(check.actual)}`,
|
||||
);
|
||||
}
|
||||
lines.push(`Hint: ${result.agentHint}`);
|
||||
return lines.join("\n");
|
||||
}
|
||||
|
||||
export function formatDiffText(
|
||||
diff: ReturnType<typeof import("./core.js").diffCompactStates>,
|
||||
): string {
|
||||
const lines = [diff.summary];
|
||||
for (const change of diff.changes.slice(0, 8)) {
|
||||
lines.push(
|
||||
`- ${change.type}: ${JSON.stringify(change.before ?? null)} → ${JSON.stringify(change.after ?? null)}`,
|
||||
);
|
||||
}
|
||||
return lines.join("\n");
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// URL / dialog helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export function getUrlHash(url: string): string {
|
||||
try {
|
||||
return new URL(url).hash || "";
|
||||
} catch {
|
||||
return "";
|
||||
}
|
||||
}
|
||||
|
||||
export async function countOpenDialogs(target: Page | Frame): Promise<number> {
|
||||
try {
|
||||
return await target.evaluate(
|
||||
() =>
|
||||
document.querySelectorAll('[role="dialog"]:not([hidden]),dialog[open]')
|
||||
.length,
|
||||
);
|
||||
} catch {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Click / input helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export async function captureClickTargetState(
|
||||
target: Page | Frame,
|
||||
selector: string,
|
||||
): Promise<ClickTargetStateSnapshot> {
|
||||
try {
|
||||
return await target.evaluate((sel) => {
|
||||
const el = document.querySelector(sel) as HTMLElement | null;
|
||||
if (!el) {
|
||||
return {
|
||||
exists: false,
|
||||
ariaExpanded: null,
|
||||
ariaPressed: null,
|
||||
ariaSelected: null,
|
||||
open: null,
|
||||
};
|
||||
}
|
||||
return {
|
||||
exists: true,
|
||||
ariaExpanded: el.getAttribute("aria-expanded"),
|
||||
ariaPressed: el.getAttribute("aria-pressed"),
|
||||
ariaSelected: el.getAttribute("aria-selected"),
|
||||
open:
|
||||
el instanceof HTMLDialogElement
|
||||
? el.open
|
||||
: el.getAttribute("open") !== null,
|
||||
};
|
||||
}, selector);
|
||||
} catch {
|
||||
return {
|
||||
exists: false,
|
||||
ariaExpanded: null,
|
||||
ariaPressed: null,
|
||||
ariaSelected: null,
|
||||
open: null,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
export async function readInputLikeValue(
|
||||
target: Page | Frame,
|
||||
selector?: string,
|
||||
): Promise<string | null> {
|
||||
try {
|
||||
return await target.evaluate((sel) => {
|
||||
const resolveTarget = (): Element | null => {
|
||||
if (sel) return document.querySelector(sel);
|
||||
const active = document.activeElement;
|
||||
if (
|
||||
!active ||
|
||||
active === document.body ||
|
||||
active === document.documentElement
|
||||
)
|
||||
return null;
|
||||
return active;
|
||||
};
|
||||
|
||||
const target = resolveTarget();
|
||||
if (!target) return null;
|
||||
if (
|
||||
target instanceof HTMLInputElement ||
|
||||
target instanceof HTMLTextAreaElement
|
||||
) {
|
||||
return target.value;
|
||||
}
|
||||
if (target instanceof HTMLSelectElement) {
|
||||
return target.value;
|
||||
}
|
||||
if ((target as HTMLElement).isContentEditable) {
|
||||
return (target.textContent ?? "").trim();
|
||||
}
|
||||
return (target as HTMLElement).getAttribute("value");
|
||||
}, selector);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export function firstErrorLine(err: unknown): string {
|
||||
const message =
|
||||
typeof err === "object" && err && "message" in err
|
||||
? String((err as { message?: unknown }).message ?? "")
|
||||
: String(err ?? "unknown error");
|
||||
return message.split("\n")[0] || "unknown error";
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Action tracking
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export function beginTrackedAction(
|
||||
tool: string,
|
||||
params: unknown,
|
||||
beforeUrl: string,
|
||||
) {
|
||||
return beginAction(actionTimeline, {
|
||||
tool,
|
||||
paramsSummary: toActionParamsSummary(params),
|
||||
beforeUrl,
|
||||
});
|
||||
}
|
||||
|
||||
/**
 * Record the outcome of a previously begun tracked action on the shared
 * action timeline. Thin pass-through to finishAction bound to the
 * module-level actionTimeline.
 */
export function finishTrackedAction(
  actionId: number,
  updates: {
    status: "success" | "error";
    afterUrl?: string;
    verificationSummary?: string;
    warningSummary?: string;
    diffSummary?: string;
    changed?: boolean;
    error?: string;
    beforeState?: CompactPageState;
    afterState?: CompactPageState;
  },
) {
  return finishAction(actionTimeline, actionId, updates);
}
|
||||
|
||||
export function getSinceTimestamp(sinceActionId?: number): number {
|
||||
if (!sinceActionId) return 0;
|
||||
const action = findAction(actionTimeline, sinceActionId);
|
||||
if (!action) return 0;
|
||||
return action.startedAt ?? 0;
|
||||
}
|
||||
|
||||
export function getConsoleEntriesSince(sinceActionId?: number): ConsoleEntry[] {
|
||||
const since = getSinceTimestamp(sinceActionId);
|
||||
return getConsoleLogs().filter((entry) => entry.timestamp >= since);
|
||||
}
|
||||
|
||||
export function getNetworkEntriesSince(sinceActionId?: number): NetworkEntry[] {
|
||||
const since = getSinceTimestamp(sinceActionId);
|
||||
return getNetworkLogs().filter((entry) => entry.timestamp >= since);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Error summary
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
 * Build a compact "Warnings:" suffix summarizing recent (last 12s)
 * same-origin JS errors, actionable network failures, and dialogs.
 * Returns "" when there is nothing to report, otherwise a string meant to
 * be appended to a tool result.
 */
export function getRecentErrors(pageUrl: string): string {
  const parts: string[] = [];
  const now = Date.now();
  // Only look at the last 12 seconds of logs.
  const since = now - 12_000;

  const toOrigin = (url: string): string | null => {
    try {
      return new URL(url).origin;
    } catch {
      return null;
    }
  };
  const pageOrigin = toOrigin(pageUrl);
  // If the page URL itself has no origin, treat everything as same-origin.
  const sameOrigin = (url: string): boolean =>
    !pageOrigin || toOrigin(url) === pageOrigin;

  // Dedupe repeated messages, keeping first-seen order, and annotate
  // duplicates with a (xN) count; at most `max` distinct items.
  const summarize = (items: string[], max: number): string[] => {
    const counts = new Map<string, number>();
    const order: string[] = [];
    for (const item of items) {
      if (!counts.has(item)) order.push(item);
      counts.set(item, (counts.get(item) ?? 0) + 1);
    }
    return order.slice(0, max).map((item) => {
      const count = counts.get(item) ?? 1;
      return count > 1 ? `${item} (x${count})` : item;
    });
  };

  // JS console errors / uncaught page errors (truncated to 120 chars each).
  const consoleLogs = getConsoleLogs();
  const jsWarnings = consoleLogs
    .filter(
      (e) =>
        (e.type === "error" || e.type === "pageerror") &&
        e.timestamp >= since &&
        sameOrigin(e.url),
    )
    .map((e) => e.text.slice(0, 120));
  if (jsWarnings.length > 0) {
    parts.push("JS: " + summarize(jsWarnings, 2).join(" | "));
  }

  // Network: any 5xx, or selected 4xx statuses / outright failures on
  // request types an agent can usually act on.
  const actionableStatus = new Set([401, 403, 404, 408, 409, 422, 429]);
  const actionableTypes = new Set(["document", "fetch", "xhr", "script"]);
  const networkLogs = getNetworkLogs();
  const netWarnings = networkLogs
    .filter((e) => e.timestamp >= since && sameOrigin(e.url))
    .filter((e) => {
      if (e.failed) return actionableTypes.has(e.resourceType);
      if (e.status === null) return false;
      if (e.status >= 500) return true;
      return (
        actionableStatus.has(e.status) && actionableTypes.has(e.resourceType)
      );
    })
    .map((e) => {
      if (e.failed) return `${e.method} ${e.resourceType} FAILED`;
      return `${e.method} ${e.resourceType} ${e.status}`;
    });
  if (netWarnings.length > 0) {
    parts.push("Network: " + summarize(netWarnings, 2).join(" | "));
  }

  // Browser dialogs (alert/confirm/…), message truncated to 80 chars.
  const dialogLogs = getDialogLogs();
  const dialogWarnings = dialogLogs
    .filter((e) => e.timestamp >= since && sameOrigin(e.url))
    .map((e) => `${e.type}: ${e.message.slice(0, 80)}`);
  if (dialogWarnings.length > 0) {
    parts.push("Dialogs: " + summarize(dialogWarnings, 1).join(" | "));
  }

  if (parts.length === 0) return "";
  return `\n\nWarnings: ${parts.join("; ")}\nUse browser_get_console_logs/browser_get_network_logs for full diagnostics.`;
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Ref helpers (parsing / formatting — no browser evaluate)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export function parseRef(input: string): ParsedRefSpec {
|
||||
const trimmed = input.trim().toLowerCase();
|
||||
const token = trimmed.startsWith("@") ? trimmed.slice(1) : trimmed;
|
||||
const versioned = token.match(/^v(\d+):(e\d+)$/);
|
||||
if (versioned) {
|
||||
const version = parseInt(versioned[1], 10);
|
||||
const key = versioned[2];
|
||||
return { key, version, display: `@v${version}:${key}` };
|
||||
}
|
||||
return { key: token, version: null, display: `@${token}` };
|
||||
}
|
||||
|
||||
export function formatVersionedRef(version: number, key: string): string {
|
||||
return `@v${version}:${key}`;
|
||||
}
|
||||
|
||||
export function staleRefGuidance(refDisplay: string, reason: string): string {
|
||||
return `Ref ${refDisplay} could not be resolved (${reason}). The ref is likely stale after DOM/navigation changes. Call browser_snapshot_refs again to refresh refs.`;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Compact state summary formatting
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export function formatCompactStateSummary(state: CompactPageState): string {
|
||||
const lines: string[] = [];
|
||||
lines.push(`Title: ${state.title}`);
|
||||
lines.push(`URL: ${state.url}`);
|
||||
lines.push(
|
||||
`Elements: ${state.counts.landmarks} landmarks, ${state.counts.buttons} buttons, ${state.counts.links} links, ${state.counts.inputs} inputs`,
|
||||
);
|
||||
if (state.headings.length > 0) {
|
||||
lines.push(
|
||||
"Headings: " +
|
||||
state.headings
|
||||
.map((text, index) => `H${index + 1} "${text}"`)
|
||||
.join(", "),
|
||||
);
|
||||
}
|
||||
if (state.focus) {
|
||||
lines.push(`Focused: ${state.focus}`);
|
||||
}
|
||||
if (state.dialog.title) {
|
||||
lines.push(`Active dialog: "${state.dialog.title}"`);
|
||||
}
|
||||
lines.push(
|
||||
"Use browser_find for targeted discovery, browser_assert for verification, or browser_get_accessibility_tree for full detail.",
|
||||
);
|
||||
return lines.join("\n");
|
||||
}
|
||||
|
|
@ -1,28 +0,0 @@
|
|||
/**
|
||||
* Claude Code CLI Provider Extension
|
||||
*
|
||||
* Registers a model provider that delegates inference to the user's
|
||||
* locally-installed Claude Code CLI via the official Agent SDK.
|
||||
*
|
||||
* Users with a Claude Code subscription (Pro/Max/Team) get access to
|
||||
* subsidized inference through SF's UI — no API key required.
|
||||
*
|
||||
* TOS-compliant: uses Anthropic's official `@anthropic-ai/claude-agent-sdk`,
|
||||
* never touches credentials, never offers a login flow.
|
||||
*/
|
||||
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import { CLAUDE_CODE_MODELS } from "./models.js";
|
||||
import { isClaudeCodeReady } from "./readiness.js";
|
||||
import { streamViaClaudeCode } from "./stream-adapter.js";
|
||||
|
||||
export default function claudeCodeCli(pi: ExtensionAPI) {
|
||||
pi.registerProvider("claude-code", {
|
||||
authMode: "externalCli",
|
||||
api: "anthropic-messages",
|
||||
baseUrl: "local://claude-code",
|
||||
isReady: isClaudeCodeReady,
|
||||
streamSimple: streamViaClaudeCode,
|
||||
models: CLAUDE_CODE_MODELS,
|
||||
});
|
||||
}
|
||||
|
|
@ -1,42 +0,0 @@
|
|||
/**
|
||||
* Model definitions for the Claude Code CLI provider.
|
||||
*
|
||||
* Costs are zero because inference is covered by the user's Claude Code
|
||||
* subscription. The SDK's `result` message still provides token counts
|
||||
* for display in the TUI.
|
||||
*
|
||||
* Context windows and max tokens match the Anthropic API definitions
|
||||
* in models.generated.ts.
|
||||
*/
|
||||
|
||||
const ZERO_COST = { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 };
|
||||
|
||||
export const CLAUDE_CODE_MODELS = [
|
||||
{
|
||||
id: "claude-opus-4-6",
|
||||
name: "Claude Opus 4.6 (via Claude Code)",
|
||||
reasoning: true,
|
||||
input: ["text", "image"] as ("text" | "image")[],
|
||||
cost: ZERO_COST,
|
||||
contextWindow: 1_000_000,
|
||||
maxTokens: 128_000,
|
||||
},
|
||||
{
|
||||
id: "claude-sonnet-4-6",
|
||||
name: "Claude Sonnet 4.6 (via Claude Code)",
|
||||
reasoning: true,
|
||||
input: ["text", "image"] as ("text" | "image")[],
|
||||
cost: ZERO_COST,
|
||||
contextWindow: 1_000_000,
|
||||
maxTokens: 64_000,
|
||||
},
|
||||
{
|
||||
id: "claude-haiku-4-5",
|
||||
name: "Claude Haiku 4.5 (via Claude Code)",
|
||||
reasoning: true,
|
||||
input: ["text", "image"] as ("text" | "image")[],
|
||||
cost: ZERO_COST,
|
||||
contextWindow: 200_000,
|
||||
maxTokens: 64_000,
|
||||
},
|
||||
];
|
||||
|
|
@ -1,14 +1,14 @@
|
|||
{
|
||||
"name": "@singularity-forge/claude-code-cli",
|
||||
"private": true,
|
||||
"version": "1.0.0",
|
||||
"type": "module",
|
||||
"engines": {
|
||||
"node": ">=24.15.0"
|
||||
},
|
||||
"pi": {
|
||||
"extensions": [
|
||||
"./index.ts"
|
||||
]
|
||||
}
|
||||
"name": "@singularity-forge/claude-code-cli",
|
||||
"private": true,
|
||||
"version": "1.0.0",
|
||||
"type": "module",
|
||||
"engines": {
|
||||
"node": ">=24.15.0"
|
||||
},
|
||||
"pi": {
|
||||
"extensions": [
|
||||
"./index.js"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,382 +0,0 @@
|
|||
/**
|
||||
* Content-block mapping helpers and streaming state tracker.
|
||||
*
|
||||
* Translates the Claude Agent SDK's `BetaRawMessageStreamEvent` sequence
|
||||
* into SF's `AssistantMessageEvent` deltas for incremental TUI rendering.
|
||||
*/
|
||||
|
||||
import type {
|
||||
AssistantMessage,
|
||||
AssistantMessageEvent,
|
||||
ServerToolUseContent,
|
||||
StopReason,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
ToolCall,
|
||||
Usage,
|
||||
WebSearchResultContent,
|
||||
} from "@singularity-forge/pi-ai";
|
||||
import { hasXmlParameterTags, repairToolJson } from "@singularity-forge/pi-ai";
|
||||
import type {
|
||||
BetaContentBlock,
|
||||
BetaRawMessageStreamEvent,
|
||||
NonNullableUsage,
|
||||
} from "./sdk-types.js";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// MCP tool name parsing
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Split a Claude Code MCP tool name (`mcp__<server>__<tool>`) into its parts.
|
||||
* Returns null for non-prefixed names so callers can fall through unchanged.
|
||||
*
|
||||
* Server names may contain hyphens (`sf-workflow`); the SDK uses the literal
|
||||
* `__` delimiter between the server name and the tool name.
|
||||
*/
|
||||
export function parseMcpToolName(
|
||||
name: string,
|
||||
): { server: string; tool: string } | null {
|
||||
if (!name.startsWith("mcp__")) return null;
|
||||
const rest = name.slice("mcp__".length);
|
||||
const delim = rest.indexOf("__");
|
||||
if (delim <= 0 || delim === rest.length - 2) return null;
|
||||
return { server: rest.slice(0, delim), tool: rest.slice(delim + 2) };
|
||||
}
|
||||
|
||||
/**
|
||||
* Build a SF ToolCall block from a Claude Code SDK tool_use block, stripping
|
||||
* the `mcp__<server>__` prefix from the name so registered extension renderers
|
||||
* (which use the unprefixed canonical names) can match. The original server
|
||||
* name is preserved on the block for diagnostics and rendering.
|
||||
*/
|
||||
function toolCallFromBlock(
|
||||
id: string,
|
||||
rawName: string,
|
||||
input: Record<string, unknown>,
|
||||
): ToolCall {
|
||||
const parsed = parseMcpToolName(rawName);
|
||||
const toolCall: ToolCall = {
|
||||
type: "toolCall",
|
||||
id,
|
||||
name: parsed ? parsed.tool : rawName,
|
||||
arguments: input,
|
||||
};
|
||||
if (parsed) {
|
||||
(toolCall as ToolCall & { mcpServer?: string }).mcpServer = parsed.server;
|
||||
}
|
||||
return toolCall;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Content-block mapping helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Convert a single BetaContentBlock to the corresponding SF content type.
|
||||
*/
|
||||
export function mapContentBlock(
|
||||
block: BetaContentBlock,
|
||||
):
|
||||
| TextContent
|
||||
| ThinkingContent
|
||||
| ToolCall
|
||||
| ServerToolUseContent
|
||||
| WebSearchResultContent {
|
||||
switch (block.type) {
|
||||
case "text":
|
||||
return { type: "text", text: block.text } satisfies TextContent;
|
||||
|
||||
case "thinking":
|
||||
return {
|
||||
type: "thinking",
|
||||
thinking: block.thinking,
|
||||
...(block.signature ? { thinkingSignature: block.signature } : {}),
|
||||
} satisfies ThinkingContent;
|
||||
|
||||
case "tool_use":
|
||||
return toolCallFromBlock(block.id, block.name, block.input);
|
||||
|
||||
case "server_tool_use":
|
||||
return {
|
||||
type: "serverToolUse",
|
||||
id: block.id,
|
||||
name: block.name,
|
||||
input: block.input,
|
||||
} satisfies ServerToolUseContent;
|
||||
|
||||
case "web_search_tool_result":
|
||||
return {
|
||||
type: "webSearchResult",
|
||||
toolUseId: block.tool_use_id,
|
||||
content: block.content,
|
||||
} satisfies WebSearchResultContent;
|
||||
|
||||
default: {
|
||||
const unknown = block as Record<string, unknown>;
|
||||
return {
|
||||
type: "text",
|
||||
text: `[unknown content block: ${JSON.stringify(unknown)}]`,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function mapStopReason(reason: string | null): StopReason {
|
||||
switch (reason) {
|
||||
case "end_turn":
|
||||
case "stop_sequence":
|
||||
return "stop";
|
||||
case "max_tokens":
|
||||
return "length";
|
||||
case "tool_use":
|
||||
return "toolUse";
|
||||
default:
|
||||
return "stop";
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert SDK usage + total_cost_usd into SF's Usage shape.
|
||||
*
|
||||
* The SDK does not break cost down per-bucket, so all cost is
|
||||
* attributed to `cost.total`.
|
||||
*/
|
||||
export function mapUsage(
|
||||
sdkUsage: NonNullableUsage,
|
||||
totalCostUsd: number,
|
||||
): Usage {
|
||||
return {
|
||||
input: sdkUsage.input_tokens,
|
||||
output: sdkUsage.output_tokens,
|
||||
cacheRead: sdkUsage.cache_read_input_tokens,
|
||||
cacheWrite: sdkUsage.cache_creation_input_tokens,
|
||||
totalTokens:
|
||||
sdkUsage.input_tokens +
|
||||
sdkUsage.output_tokens +
|
||||
sdkUsage.cache_read_input_tokens +
|
||||
sdkUsage.cache_creation_input_tokens,
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
total: totalCostUsd,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Zero-cost usage constant
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export const ZERO_USAGE: Usage = {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Streaming partial-message state tracker
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Mutable accumulator that tracks the partial AssistantMessage being built
|
||||
* from a sequence of stream_event messages. Produces AssistantMessageEvent
|
||||
* deltas that the TUI can render incrementally.
|
||||
*/
|
||||
export class PartialMessageBuilder {
  // The in-progress assistant message; mutated in place as events arrive
  // and shared (by reference) with every emitted event's `partial`.
  private partial: AssistantMessage;
  /** Map from stream-event `index` to our content array index. */
  private indexMap = new Map<number, number>();
  /** Accumulated JSON input string per tool_use block (keyed by stream index). */
  private toolJsonAccum = new Map<number, string>();

  constructor(model: string) {
    this.partial = {
      role: "assistant",
      content: [],
      api: "anthropic-messages",
      provider: "claude-code",
      model,
      // Copy so later usage mutation doesn't alter the shared constant.
      usage: { ...ZERO_USAGE },
      stopReason: "stop",
      timestamp: Date.now(),
    };
  }

  // Current partial message (live reference, not a copy).
  get message(): AssistantMessage {
    return this.partial;
  }

  /**
   * Feed a BetaRawMessageStreamEvent and return the corresponding
   * AssistantMessageEvent (or null if the event is not mapped).
   */
  handleEvent(event: BetaRawMessageStreamEvent): AssistantMessageEvent | null {
    const streamIndex = event.index ?? 0;

    switch (event.type) {
      // ---- Block start ----
      case "content_block_start": {
        const block = event.content_block;
        if (!block) return null;

        // New blocks always append; remember which slot this stream index owns.
        const contentIndex = this.partial.content.length;
        this.indexMap.set(streamIndex, contentIndex);

        if (block.type === "text") {
          this.partial.content.push({ type: "text", text: "" });
          return { type: "text_start", contentIndex, partial: this.partial };
        }
        if (block.type === "thinking") {
          this.partial.content.push({ type: "thinking", thinking: "" });
          return {
            type: "thinking_start",
            contentIndex,
            partial: this.partial,
          };
        }
        if (block.type === "tool_use") {
          // Arguments arrive as JSON fragments; start an empty accumulator.
          this.toolJsonAccum.set(streamIndex, "");
          this.partial.content.push(
            toolCallFromBlock(block.id, block.name, {}),
          );
          return {
            type: "toolcall_start",
            contentIndex,
            partial: this.partial,
          };
        }
        if (block.type === "server_tool_use") {
          this.partial.content.push({
            type: "serverToolUse",
            id: block.id,
            name: block.name,
            input: block.input,
          });
          return {
            type: "server_tool_use",
            contentIndex,
            partial: this.partial,
          };
        }
        return null;
      }

      // ---- Block delta ----
      case "content_block_delta": {
        const contentIndex = this.indexMap.get(streamIndex);
        if (contentIndex === undefined) return null;
        const delta = event.delta;
        if (!delta) return null;

        if (delta.type === "text_delta" && typeof delta.text === "string") {
          const existing = this.partial.content[contentIndex] as TextContent;
          existing.text += delta.text;
          return {
            type: "text_delta",
            contentIndex,
            delta: delta.text,
            partial: this.partial,
          };
        }
        if (
          delta.type === "thinking_delta" &&
          typeof delta.thinking === "string"
        ) {
          const existing = this.partial.content[
            contentIndex
          ] as ThinkingContent;
          existing.thinking += delta.thinking;
          return {
            type: "thinking_delta",
            contentIndex,
            delta: delta.thinking,
            partial: this.partial,
          };
        }
        if (
          delta.type === "input_json_delta" &&
          typeof delta.partial_json === "string"
        ) {
          // Tool arguments are accumulated as a raw string and parsed only
          // at content_block_stop.
          const accum =
            (this.toolJsonAccum.get(streamIndex) ?? "") + delta.partial_json;
          this.toolJsonAccum.set(streamIndex, accum);
          return {
            type: "toolcall_delta",
            contentIndex,
            delta: delta.partial_json,
            partial: this.partial,
          };
        }
        return null;
      }

      // ---- Block stop ----
      case "content_block_stop": {
        const contentIndex = this.indexMap.get(streamIndex);
        if (contentIndex === undefined) return null;
        const block = this.partial.content[contentIndex];

        if (block.type === "text") {
          return {
            type: "text_end",
            contentIndex,
            content: block.text,
            partial: this.partial,
          };
        }
        if (block.type === "thinking") {
          return {
            type: "thinking_end",
            contentIndex,
            content: block.thinking,
            partial: this.partial,
          };
        }
        if (block.type === "toolCall") {
          const jsonStr = this.toolJsonAccum.get(streamIndex) ?? "{}";
          // Pre-repair only when XML parameter tags are detected; otherwise
          // try the raw accumulated JSON first.
          const jsonForParse = hasXmlParameterTags(jsonStr)
            ? repairToolJson(jsonStr)
            : jsonStr;
          try {
            block.arguments = JSON.parse(jsonForParse);
          } catch {
            // JSON.parse failed — attempt repair for YAML-style bullet
            // lists that LLMs copy from template formatting (#2660).
            try {
              block.arguments = JSON.parse(repairToolJson(jsonForParse));
            } catch {
              // Repair also failed — stream was truncated or garbage.
              // Preserve the raw string for diagnostics but signal the
              // malformation explicitly so downstream consumers can
              // distinguish this from a healthy tool completion (#2574).
              block.arguments = { _raw: jsonStr };
              return {
                type: "toolcall_end",
                contentIndex,
                toolCall: block,
                partial: this.partial,
                malformedArguments: true,
              };
            }
          }
          return {
            type: "toolcall_end",
            contentIndex,
            toolCall: block,
            partial: this.partial,
          };
        }
        return null;
      }

      default:
        return null;
    }
  }
}
|
||||
|
|
@ -1,91 +0,0 @@
|
|||
/**
|
||||
* Readiness check for the Claude Code CLI provider.
|
||||
*
|
||||
* Verifies the `claude` binary is installed, responsive, AND authenticated.
|
||||
* Results are cached for 30 seconds to avoid shelling out on every
|
||||
* model-availability check.
|
||||
*
|
||||
* Auth verification follows the T3 Code pattern: run `claude auth status`
|
||||
* and check the exit code + output for an authenticated session.
|
||||
*/
|
||||
|
||||
import { execFileSync } from "node:child_process";
|
||||
|
||||
let cachedBinaryPresent: boolean | null = null;
|
||||
let cachedAuthed: boolean | null = null;
|
||||
let lastCheckMs = 0;
|
||||
const CHECK_INTERVAL_MS = 30_000;
|
||||
|
||||
function refreshCache(): void {
|
||||
const now = Date.now();
|
||||
if (cachedBinaryPresent !== null && now - lastCheckMs < CHECK_INTERVAL_MS) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Set timestamp first to prevent re-entrant checks during the same window
|
||||
lastCheckMs = now;
|
||||
|
||||
// Check binary presence
|
||||
try {
|
||||
execFileSync("claude", ["--version"], { timeout: 5_000, stdio: "pipe" });
|
||||
cachedBinaryPresent = true;
|
||||
} catch {
|
||||
cachedBinaryPresent = false;
|
||||
cachedAuthed = false;
|
||||
return;
|
||||
}
|
||||
|
||||
// Check auth status — exit code 0 with non-error output means authenticated
|
||||
try {
|
||||
const output = execFileSync("claude", ["auth", "status"], {
|
||||
timeout: 5_000,
|
||||
stdio: "pipe",
|
||||
})
|
||||
.toString()
|
||||
.toLowerCase();
|
||||
// The CLI outputs "not logged in", "no credentials", or similar when unauthenticated
|
||||
cachedAuthed =
|
||||
!/not logged in|no credentials|unauthenticated|not authenticated/i.test(
|
||||
output,
|
||||
);
|
||||
} catch {
|
||||
// Non-zero exit code means not authenticated
|
||||
cachedAuthed = false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether the `claude` binary is installed (regardless of auth state).
|
||||
*/
|
||||
export function isClaudeBinaryPresent(): boolean {
|
||||
refreshCache();
|
||||
return cachedBinaryPresent ?? false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Whether the `claude` CLI is authenticated with a valid session.
|
||||
* Returns false if the binary is not installed.
|
||||
*/
|
||||
export function isClaudeCodeAuthed(): boolean {
|
||||
refreshCache();
|
||||
return (cachedBinaryPresent ?? false) && (cachedAuthed ?? false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Full readiness check: binary installed AND authenticated.
|
||||
* This is the gating function used by the provider registration.
|
||||
*/
|
||||
export function isClaudeCodeReady(): boolean {
|
||||
refreshCache();
|
||||
return (cachedBinaryPresent ?? false) && (cachedAuthed ?? false);
|
||||
}
|
||||
|
||||
/**
|
||||
* Force-clear the cached readiness state.
|
||||
* Useful after the user completes auth setup so the next check is fresh.
|
||||
*/
|
||||
export function clearReadinessCache(): void {
|
||||
cachedBinaryPresent = null;
|
||||
cachedAuthed = null;
|
||||
lastCheckMs = 0;
|
||||
}
|
||||
|
|
@ -1,154 +0,0 @@
|
|||
/**
|
||||
* Lightweight type mirrors for the Claude Agent SDK.
|
||||
*
|
||||
* These stubs allow the extension to compile without a hard dependency on
|
||||
* `@anthropic-ai/claude-agent-sdk`. The real SDK is imported dynamically
|
||||
* at runtime in stream-adapter.ts.
|
||||
*/
|
||||
|
||||
/** UUID branded string from the SDK. */
export type UUID = string;

/** BetaMessage from the Anthropic SDK, as wrapped by SDKAssistantMessage. */
export interface BetaMessage {
  id: string;
  type: "message";
  role: "assistant";
  // Ordered content blocks (text / thinking / tool_use / ...).
  content: BetaContentBlock[];
  model: string;
  // null while the message is still streaming; set once finalized.
  stop_reason: "end_turn" | "max_tokens" | "stop_sequence" | "tool_use" | null;
  usage: { input_tokens: number; output_tokens: number };
}

/** Discriminated union (on `type`) of the content-block shapes we handle. */
export type BetaContentBlock =
  | { type: "text"; text: string }
  | { type: "thinking"; thinking: string; signature?: string }
  | {
      type: "tool_use";
      id: string;
      name: string;
      input: Record<string, unknown>;
    }
  | { type: "server_tool_use"; id: string; name: string; input: unknown }
  | { type: "web_search_tool_result"; tool_use_id: string; content: unknown };

/** Streaming event emitted when includePartialMessages is true. */
export interface BetaRawMessageStreamEvent {
  // e.g. "content_block_start" | "content_block_delta" | "content_block_stop".
  type: string;
  // Stream index of the content block this event applies to.
  index?: number;
  content_block?: BetaContentBlock;
  delta?: Record<string, unknown>;
}
|
||||
|
||||
/** Assistant turn wrapping a full BetaMessage. */
export interface SDKAssistantMessage {
  type: "assistant";
  uuid: UUID;
  session_id: string;
  message: BetaMessage;
  parent_tool_use_id: string | null;
  error?: { type: string; message: string };
}

/** User turn; the `message` payload shape is opaque to this extension. */
export interface SDKUserMessage {
  type: "user";
  uuid?: UUID;
  session_id: string;
  message: unknown;
  parent_tool_use_id: string | null;
  isSynthetic?: boolean;
  tool_use_result?: unknown;
}

/** System bootstrap message emitted at session init. */
export interface SDKSystemMessage {
  type: "system";
  subtype: "init";
  [key: string]: unknown;
}

/** System status update (e.g. context compaction in progress). */
export interface SDKStatusMessage {
  type: "system";
  subtype: "status";
  status: "compacting" | null;
  uuid: UUID;
  session_id: string;
}

/** Raw streaming event forwarded while a message is still being produced. */
export interface SDKPartialAssistantMessage {
  type: "stream_event";
  event: BetaRawMessageStreamEvent;
  parent_tool_use_id: string | null;
  uuid: UUID;
  session_id: string;
}

/** Progress ping for a long-running tool invocation. */
export interface SDKToolProgressMessage {
  type: "tool_progress";
  tool_use_id: string;
  tool_name: string;
  parent_tool_use_id: string | null;
  elapsed_time_seconds: number;
  task_id?: string;
  uuid: UUID;
  session_id: string;
}
|
||||
|
||||
/** Usage counters with every field guaranteed present (no nulls). */
export interface NonNullableUsage {
  input_tokens: number;
  output_tokens: number;
  cache_read_input_tokens: number;
  cache_creation_input_tokens: number;
}

/** Terminal result message: success variant or one of the error subtypes. */
export type SDKResultMessage =
  | {
      type: "result";
      subtype: "success";
      uuid: UUID;
      session_id: string;
      duration_ms: number;
      duration_api_ms: number;
      is_error: boolean;
      num_turns: number;
      result: string;
      stop_reason: string | null;
      total_cost_usd: number;
      usage: NonNullableUsage;
    }
  | {
      type: "result";
      subtype:
        | "error_max_turns"
        | "error_during_execution"
        | "error_max_budget_usd"
        | "error_max_structured_output_retries";
      uuid: UUID;
      session_id: string;
      duration_ms: number;
      duration_api_ms: number;
      is_error: boolean;
      num_turns: number;
      stop_reason: string | null;
      total_cost_usd: number;
      usage: NonNullableUsage;
      // Error variant carries the error descriptions; success does not.
      errors: string[];
    };

/** Catch-all for SDK message types we don't map. */
export interface SDKOtherMessage {
  type: string;
  [key: string]: unknown;
}

/**
 * Union of all SDK message types this extension handles.
 * Mirrors the real `SDKMessage` from `@anthropic-ai/claude-agent-sdk`.
 */
export type SDKMessage =
  | SDKAssistantMessage
  | SDKUserMessage
  | SDKResultMessage
  | SDKSystemMessage
  | SDKStatusMessage
  | SDKPartialAssistantMessage
  | SDKToolProgressMessage
  | SDKOtherMessage;
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,264 +0,0 @@
|
|||
import assert from "node:assert/strict";
|
||||
import { describe, test } from 'vitest';
|
||||
import {
|
||||
mapContentBlock,
|
||||
PartialMessageBuilder,
|
||||
parseMcpToolName,
|
||||
} from "../partial-builder.ts";
|
||||
import type {
|
||||
BetaContentBlock,
|
||||
BetaRawMessageStreamEvent,
|
||||
} from "../sdk-types.ts";
|
||||
|
||||
describe("PartialMessageBuilder — malformed tool arguments (#2574)", () => {
  /**
   * Helper: feed a tool_use block through the builder lifecycle and return
   * the toolcall_end event. Simulates: content_block_start → N deltas → content_block_stop.
   */
  function feedToolCall(
    builder: PartialMessageBuilder,
    jsonFragments: string[],
  ) {
    // Start the tool_use block at stream index 0
    builder.handleEvent({
      type: "content_block_start",
      index: 0,
      content_block: {
        type: "tool_use",
        id: "tool_1",
        name: "sf_plan_slice",
        input: {},
      },
    } as BetaRawMessageStreamEvent);

    // Feed JSON fragments as input_json_delta
    for (const fragment of jsonFragments) {
      builder.handleEvent({
        type: "content_block_delta",
        index: 0,
        delta: { type: "input_json_delta", partial_json: fragment },
      } as BetaRawMessageStreamEvent);
    }

    // Stop the block — this is where JSON parse happens
    return builder.handleEvent({
      type: "content_block_stop",
      index: 0,
    } as BetaRawMessageStreamEvent);
  }

  test("valid JSON → toolcall_end without malformedArguments", () => {
    const builder = new PartialMessageBuilder("claude-sonnet-4-20250514");
    // Fragments split mid-key to exercise accumulation across deltas.
    const event = feedToolCall(builder, ['{"milestone', 'Id": "M001"}']);

    assert.ok(event, "event should not be null");
    assert.equal(event!.type, "toolcall_end");
    // Valid JSON should NOT have the malformedArguments flag
    assert.equal(
      (event as any).malformedArguments,
      undefined,
      "valid JSON should not set malformedArguments",
    );
    // Arguments should be parsed correctly
    if (event!.type === "toolcall_end") {
      assert.deepEqual(event!.toolCall.arguments, { milestoneId: "M001" });
    }
  });

  test("unrepairable JSON → toolcall_end WITH malformedArguments: true", () => {
    const builder = new PartialMessageBuilder("claude-sonnet-4-20250514");
    // Simulate a stream with unrepairable garbage that repairToolJson cannot fix
    const event = feedToolCall(builder, ['{{{']);

    assert.ok(event, "event should not be null");
    assert.equal(event!.type, "toolcall_end");
    assert.equal(
      (event as any).malformedArguments,
      true,
      "unrepairable JSON should set malformedArguments: true",
    );
    // The _raw field should contain the original broken JSON
    if (event!.type === "toolcall_end") {
      assert.equal(
        event!.toolCall.arguments._raw,
        '{{{',
        "_raw should contain the unrepairable JSON string",
      );
    }
  });

  test("no JSON deltas → malformedArguments: true (empty accumulator is not valid JSON)", () => {
    const builder = new PartialMessageBuilder("claude-sonnet-4-20250514");
    // No deltas — the accumulator is initialized to "" by content_block_start,
    // and "" is not valid JSON, so this correctly signals malformed.
    const event = feedToolCall(builder, []);

    assert.ok(event, "event should not be null");
    assert.equal(event!.type, "toolcall_end");
    assert.equal(
      (event as any).malformedArguments,
      true,
      "empty accumulator (no JSON deltas) is not valid JSON → malformed",
    );
  });

  test("repairable non-JSON → toolcall_end without malformedArguments", () => {
    const builder = new PartialMessageBuilder("claude-sonnet-4-20250514");
    // repairToolJson wraps bare strings in quotes, making them valid JSON
    const event = feedToolCall(builder, ["not json at all <html>"]);

    assert.ok(event, "event should not be null");
    assert.equal(event!.type, "toolcall_end");
    assert.equal(
      (event as any).malformedArguments,
      undefined,
      "repairable bare string should not set malformedArguments",
    );
    assert.equal(
      (event as any).toolCall.arguments,
      "not json at all <html>",
      "repaired bare string should be the parsed argument value",
    );
  });

  test("YAML bullet lists repaired to JSON arrays (#2660)", () => {
    const builder = new PartialMessageBuilder("claude-sonnet-4-20250514");
    // LLMs sometimes emit YAML-style "- item" bullets where JSON arrays belong.
    const malformedJson =
      '{"milestoneId": "M005", "keyDecisions": - Used Web Notification API, "keyFiles": - src/lib.rs, "title": "done"}';
    const event = feedToolCall(builder, [malformedJson]);

    assert.ok(event, "event should not be null");
    assert.equal(event!.type, "toolcall_end");
    // Repaired YAML bullets should NOT set malformedArguments
    assert.equal(
      (event as any).malformedArguments,
      undefined,
      "repaired YAML bullets should not set malformedArguments",
    );
    if (event!.type === "toolcall_end") {
      assert.equal(event!.toolCall.arguments.milestoneId, "M005");
      assert.ok(
        Array.isArray(event!.toolCall.arguments.keyDecisions),
        "keyDecisions should be repaired to an array",
      );
      assert.ok(
        Array.isArray(event!.toolCall.arguments.keyFiles),
        "keyFiles should be repaired to an array",
      );
      assert.equal(event!.toolCall.arguments.title, "done");
    }
  });

  test("XML parameter tags trapped inside valid JSON strings are promoted (#3751)", () => {
    const builder = new PartialMessageBuilder("claude-sonnet-4-20250514");
    // <parameter> tags leaked into a JSON string value should be lifted out
    // into top-level arguments by the repair pass.
    const malformedJson =
      '{"narrative":"text.</narrative>\\n<parameter name=\\"verification\\">all tests pass</parameter>\\n<parameter name=\\"verificationEvidence\\">[\\"npm test\\"]</parameter>","oneLiner":"done"}';
    const event = feedToolCall(builder, [malformedJson]);

    assert.ok(event, "event should not be null");
    assert.equal(event!.type, "toolcall_end");
    assert.equal((event as any).malformedArguments, undefined);
    if (event!.type === "toolcall_end") {
      assert.equal(event.toolCall.arguments.narrative, "text.");
      assert.equal(event.toolCall.arguments.verification, "all tests pass");
      assert.deepEqual(event.toolCall.arguments.verificationEvidence, [
        "npm test",
      ]);
      assert.equal(event.toolCall.arguments.oneLiner, "done");
    }
  });
});
|
||||
|
||||
describe("parseMcpToolName", () => {
  test("splits mcp__<server>__<tool> into parts", () => {
    assert.deepEqual(parseMcpToolName("mcp__sf-workflow__sf_plan_milestone"), {
      server: "sf-workflow",
      tool: "sf_plan_milestone",
    });
  });

  test("preserves server names containing hyphens", () => {
    assert.deepEqual(parseMcpToolName("mcp__my-cool-server__do_thing"), {
      server: "my-cool-server",
      tool: "do_thing",
    });
  });

  test("preserves tool names containing underscores", () => {
    // Only the first "__" after the prefix separates server from tool.
    assert.deepEqual(parseMcpToolName("mcp__srv__a_b_c_d"), {
      server: "srv",
      tool: "a_b_c_d",
    });
  });

  test("returns null for non-prefixed names", () => {
    assert.equal(parseMcpToolName("Bash"), null);
    assert.equal(parseMcpToolName("sf_plan_milestone"), null);
  });

  test("returns null for malformed prefixes", () => {
    // Missing server, missing tool, or empty server segment.
    assert.equal(parseMcpToolName("mcp__"), null);
    assert.equal(parseMcpToolName("mcp__server"), null);
    assert.equal(parseMcpToolName("mcp__server__"), null);
    assert.equal(parseMcpToolName("mcp____tool"), null);
  });
});
|
||||
|
||||
describe("PartialMessageBuilder — MCP tool name normalization", () => {
  test("strips mcp__<server>__ prefix on content_block_start", () => {
    const builder = new PartialMessageBuilder("claude-sonnet-4-20250514");
    const event = builder.handleEvent({
      type: "content_block_start",
      index: 0,
      content_block: {
        type: "tool_use",
        id: "tool_1",
        name: "mcp__sf-workflow__sf_plan_milestone",
        input: {},
      },
    } as BetaRawMessageStreamEvent);

    assert.ok(event, "event should not be null");
    assert.equal(event!.type, "toolcall_start");
    if (event!.type === "toolcall_start") {
      // The prefix moves into a separate mcpServer field.
      const toolCall = event.partial.content[event.contentIndex] as any;
      assert.equal(toolCall.name, "sf_plan_milestone");
      assert.equal(toolCall.mcpServer, "sf-workflow");
    }
  });

  test("leaves non-MCP tool names untouched", () => {
    const builder = new PartialMessageBuilder("claude-sonnet-4-20250514");
    const event = builder.handleEvent({
      type: "content_block_start",
      index: 0,
      content_block: {
        type: "tool_use",
        id: "tool_1",
        name: "Bash",
        input: {},
      },
    } as BetaRawMessageStreamEvent);

    assert.ok(event);
    if (event!.type === "toolcall_start") {
      const toolCall = event.partial.content[event.contentIndex] as any;
      assert.equal(toolCall.name, "Bash");
      assert.equal(toolCall.mcpServer, undefined);
    }
  });

  test("mapContentBlock strips MCP prefix on full tool_use blocks", () => {
    // Non-streaming path: a complete tool_use content block.
    const block: BetaContentBlock = {
      type: "tool_use",
      id: "tool_2",
      name: "mcp__sf-workflow__sf_task_complete",
      input: { taskId: "T001" },
    };
    const mapped = mapContentBlock(block) as any;
    assert.equal(mapped.type, "toolCall");
    assert.equal(mapped.name, "sf_task_complete");
    assert.equal(mapped.mcpServer, "sf-workflow");
    assert.deepEqual(mapped.arguments, { taskId: "T001" });
  });
});
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,508 +0,0 @@
|
|||
import { execFileSync, spawn } from "node:child_process";
|
||||
import { existsSync } from "node:fs";
|
||||
import type { SFPreferences } from "../sf/preferences.js";
|
||||
import type { Phase, SFState } from "../sf/types.js";
|
||||
|
||||
// Fallback IPC socket path when CMUX_SOCKET_PATH is not set.
const DEFAULT_SOCKET_PATH = "/tmp/cmux.sock";
// Sidebar status-entry key used for all sf status updates.
const STATUS_KEY = "sf";
// Last sidebar snapshot per workspace key, used to skip redundant CLI calls.
const lastSidebarSnapshots = new Map<string, string>();
// Ensures the enable-cmux prompt is shown at most once per process.
let cmuxPromptedThisSession = false;
// Memoized result of the `cmux --help` availability probe; null = unchecked.
let cachedCliAvailability: boolean | null = null;
|
||||
|
||||
/** Facts detected from the environment about a surrounding cmux session. */
export interface CmuxEnvironment {
  // True when workspace + surface IDs are present and the socket exists.
  available: boolean;
  // True when the `cmux` CLI binary responds to --help.
  cliAvailable: boolean;
  socketPath: string;
  workspaceId?: string;
  surfaceId?: string;
}

/** Detected environment combined with user preference toggles. */
export interface ResolvedCmuxConfig extends CmuxEnvironment {
  enabled: boolean;
  notifications: boolean;
  sidebar: boolean;
  splits: boolean;
  browser: boolean;
}

/** Progress-bar payload for the cmux sidebar. */
export interface CmuxSidebarProgress {
  // Fraction in [0, 1].
  value: number;
  label: string;
}

/** Severity levels accepted by `cmux log --level`. */
export type CmuxLogLevel =
  | "info"
  | "progress"
  | "success"
  | "warning"
  | "error";
|
||||
|
||||
export function detectCmuxEnvironment(
|
||||
env: NodeJS.ProcessEnv = process.env,
|
||||
socketExists: (path: string) => boolean = existsSync,
|
||||
cliAvailable: () => boolean = isCmuxCliAvailable,
|
||||
): CmuxEnvironment {
|
||||
const socketPath = env.CMUX_SOCKET_PATH ?? DEFAULT_SOCKET_PATH;
|
||||
const workspaceId = env.CMUX_WORKSPACE_ID?.trim() || undefined;
|
||||
const surfaceId = env.CMUX_SURFACE_ID?.trim() || undefined;
|
||||
const available = Boolean(
|
||||
workspaceId && surfaceId && socketExists(socketPath),
|
||||
);
|
||||
return {
|
||||
available,
|
||||
cliAvailable: cliAvailable(),
|
||||
socketPath,
|
||||
workspaceId,
|
||||
surfaceId,
|
||||
};
|
||||
}
|
||||
|
||||
export function resolveCmuxConfig(
|
||||
preferences: SFPreferences | undefined,
|
||||
env: NodeJS.ProcessEnv = process.env,
|
||||
socketExists: (path: string) => boolean = existsSync,
|
||||
cliAvailable: () => boolean = isCmuxCliAvailable,
|
||||
): ResolvedCmuxConfig {
|
||||
const detected = detectCmuxEnvironment(env, socketExists, cliAvailable);
|
||||
const cmux = preferences?.cmux ?? {};
|
||||
const enabled = detected.available && cmux.enabled === true;
|
||||
return {
|
||||
...detected,
|
||||
enabled,
|
||||
notifications: enabled && cmux.notifications !== false,
|
||||
sidebar: enabled && cmux.sidebar !== false,
|
||||
splits: enabled && cmux.splits === true,
|
||||
browser: enabled && cmux.browser === true,
|
||||
};
|
||||
}
|
||||
|
||||
export function shouldPromptToEnableCmux(
|
||||
preferences: SFPreferences | undefined,
|
||||
env: NodeJS.ProcessEnv = process.env,
|
||||
socketExists: (path: string) => boolean = existsSync,
|
||||
cliAvailable: () => boolean = isCmuxCliAvailable,
|
||||
): boolean {
|
||||
if (cmuxPromptedThisSession) return false;
|
||||
const detected = detectCmuxEnvironment(env, socketExists, cliAvailable);
|
||||
if (!detected.available) return false;
|
||||
return preferences?.cmux?.enabled === undefined;
|
||||
}
|
||||
|
||||
/** Record that the enable-cmux prompt has been shown this session. */
export function markCmuxPromptShown(): void {
  cmuxPromptedThisSession = true;
}

/** Reset the per-session prompt guard so the prompt may be shown again. */
export function resetCmuxPromptState(): void {
  cmuxPromptedThisSession = false;
}
|
||||
|
||||
export function isCmuxCliAvailable(): boolean {
|
||||
if (cachedCliAvailability !== null) return cachedCliAvailability;
|
||||
try {
|
||||
execFileSync("cmux", ["--help"], { stdio: "ignore", timeout: 1000 });
|
||||
cachedCliAvailability = true;
|
||||
} catch {
|
||||
cachedCliAvailability = false;
|
||||
}
|
||||
return cachedCliAvailability;
|
||||
}
|
||||
|
||||
export function supportsOsc777Notifications(
|
||||
env: NodeJS.ProcessEnv = process.env,
|
||||
): boolean {
|
||||
const termProgram = env.TERM_PROGRAM?.toLowerCase() ?? "";
|
||||
return (
|
||||
termProgram === "ghostty" ||
|
||||
termProgram === "wezterm" ||
|
||||
termProgram === "iterm.app"
|
||||
);
|
||||
}
|
||||
|
||||
export function emitOsc777Notification(title: string, body: string): void {
|
||||
if (!supportsOsc777Notifications()) return;
|
||||
const safeTitle = normalizeNotificationText(title).replace(/;/g, ",");
|
||||
const safeBody = normalizeNotificationText(body).replace(/;/g, ",");
|
||||
process.stdout.write(`\x1b]777;notify;${safeTitle};${safeBody}\x07`);
|
||||
}
|
||||
|
||||
export function buildCmuxStatusLabel(state: SFState): string {
|
||||
const parts: string[] = [];
|
||||
if (state.activeMilestone) parts.push(state.activeMilestone.id);
|
||||
if (state.activeSlice) parts.push(state.activeSlice.id);
|
||||
if (state.activeTask) {
|
||||
const prev = parts.pop();
|
||||
parts.push(prev ? `${prev}/${state.activeTask.id}` : state.activeTask.id);
|
||||
}
|
||||
if (parts.length === 0) return state.phase;
|
||||
return `${parts.join(" ")} · ${state.phase}`;
|
||||
}
|
||||
|
||||
export function buildCmuxProgress(state: SFState): CmuxSidebarProgress | null {
|
||||
const progress = state.progress;
|
||||
if (!progress) return null;
|
||||
|
||||
const choose = (
|
||||
done: number,
|
||||
total: number,
|
||||
label: string,
|
||||
): CmuxSidebarProgress | null => {
|
||||
if (total <= 0) return null;
|
||||
return {
|
||||
value: Math.max(0, Math.min(1, done / total)),
|
||||
label: `${done}/${total} ${label}`,
|
||||
};
|
||||
};
|
||||
|
||||
return (
|
||||
choose(progress.tasks?.done ?? 0, progress.tasks?.total ?? 0, "tasks") ??
|
||||
choose(progress.slices?.done ?? 0, progress.slices?.total ?? 0, "slices") ??
|
||||
choose(progress.milestones.done, progress.milestones.total, "milestones")
|
||||
);
|
||||
}
|
||||
|
||||
function phaseVisuals(phase: Phase): { icon: string; color: string } {
|
||||
switch (phase) {
|
||||
case "blocked":
|
||||
return { icon: "triangle-alert", color: "#ef4444" };
|
||||
case "paused":
|
||||
return { icon: "pause", color: "#f59e0b" };
|
||||
case "complete":
|
||||
case "completing-milestone":
|
||||
return { icon: "check", color: "#22c55e" };
|
||||
case "planning":
|
||||
case "researching":
|
||||
case "replanning-slice":
|
||||
return { icon: "compass", color: "#3b82f6" };
|
||||
case "validating-milestone":
|
||||
case "verifying":
|
||||
return { icon: "shield-check", color: "#06b6d4" };
|
||||
default:
|
||||
return { icon: "rocket", color: "#4ade80" };
|
||||
}
|
||||
}
|
||||
|
||||
/** Key scoping sidebar snapshots to a workspace ("default" when unknown). */
function sidebarSnapshotKey(config: ResolvedCmuxConfig): string {
  return config.workspaceId ?? "default";
}
|
||||
|
||||
/**
 * Thin wrapper around the `cmux` CLI, gated by a resolved config.
 * All commands are best-effort: failures and timeouts yield null / no-op
 * rather than throwing.
 */
export class CmuxClient {
  private readonly config: ResolvedCmuxConfig;

  constructor(config: ResolvedCmuxConfig) {
    this.config = config;
  }

  /** Build a client from user preferences plus the detected environment. */
  static fromPreferences(preferences: SFPreferences | undefined): CmuxClient {
    return new CmuxClient(resolveCmuxConfig(preferences));
  }

  getConfig(): ResolvedCmuxConfig {
    return this.config;
  }

  // Commands require both a live cmux session and a working CLI binary.
  private canRun(): boolean {
    return this.config.available && this.config.cliAvailable;
  }

  // Scope a command to the detected workspace, when known.
  private appendWorkspace(args: string[]): string[] {
    return this.config.workspaceId
      ? [...args, "--workspace", this.config.workspaceId]
      : args;
  }

  // Scope a command to a specific surface, when provided.
  private appendSurface(args: string[], surfaceId?: string): string[] {
    return surfaceId ? [...args, "--surface", surfaceId] : args;
  }

  // Run a cmux command synchronously; returns stdout, or null on any
  // failure (missing binary, non-zero exit, 3s timeout).
  private runSync(args: string[]): string | null {
    if (!this.canRun()) return null;
    try {
      return execFileSync("cmux", args, {
        encoding: "utf-8",
        timeout: 3000,
        stdio: ["ignore", "pipe", "pipe"],
        env: process.env,
      });
    } catch {
      return null;
    }
  }

  // Run a cmux command asynchronously with a 5s kill timer. Resolves to
  // stdout on exit code 0, otherwise null; never rejects.
  private async runAsync(args: string[]): Promise<string | null> {
    if (!this.canRun()) return null;
    return new Promise<string | null>((resolve) => {
      const child = spawn("cmux", args, {
        stdio: ["ignore", "pipe", "pipe"],
        env: process.env,
      });
      const chunks: Buffer[] = [];
      let settled = false;
      // Resolve exactly once across the timeout / close / error races.
      const done = (result: string | null) => {
        if (!settled) {
          settled = true;
          resolve(result);
        }
      };
      const timer = setTimeout(() => {
        child.kill();
        done(null);
      }, 5000);
      child.stdout!.on("data", (chunk: Buffer) => chunks.push(chunk));
      child.on("close", (code) => {
        clearTimeout(timer);
        done(code === 0 ? Buffer.concat(chunks).toString("utf-8") : null);
      });
      child.on("error", () => {
        clearTimeout(timer);
        done(null);
      });
    });
  }

  /** `cmux capabilities --json`, parsed; null on failure. */
  getCapabilities(): unknown | null {
    const stdout = this.runSync(["capabilities", "--json"]);
    return stdout ? parseJson(stdout) : null;
  }

  /** `cmux identify --json`, parsed; null on failure. */
  identify(): unknown | null {
    const stdout = this.runSync(["identify", "--json"]);
    return stdout ? parseJson(stdout) : null;
  }

  /** Set the sf sidebar status entry with a phase-appropriate icon/color. */
  setStatus(label: string, phase: Phase): void {
    if (!this.config.sidebar) return;
    const visuals = phaseVisuals(phase);
    this.runSync(
      this.appendWorkspace([
        "set-status",
        STATUS_KEY,
        label,
        "--icon",
        visuals.icon,
        "--color",
        visuals.color,
      ]),
    );
  }

  /** Remove the sf sidebar status entry. */
  clearStatus(): void {
    if (!this.config.sidebar) return;
    this.runSync(this.appendWorkspace(["clear-status", STATUS_KEY]));
  }

  /** Set (or clear, when null) the sidebar progress bar. */
  setProgress(progress: CmuxSidebarProgress | null): void {
    if (!this.config.sidebar) return;
    if (!progress) {
      this.runSync(this.appendWorkspace(["clear-progress"]));
      return;
    }
    this.runSync(
      this.appendWorkspace([
        "set-progress",
        progress.value.toFixed(3),
        "--label",
        progress.label,
      ]),
    );
  }

  /** Append a message to the cmux log with the given level and source. */
  log(message: string, level: CmuxLogLevel = "info", source = "sf"): void {
    if (!this.config.sidebar) return;
    this.runSync(
      this.appendWorkspace([
        "log",
        "--level",
        level,
        "--source",
        source,
        "--",
        message,
      ]),
    );
  }

  /** Send a desktop notification; returns whether the command succeeded. */
  notify(title: string, body: string, subtitle?: string): boolean {
    if (!this.config.notifications) return false;
    const args = ["notify", "--title", title, "--body", body];
    if (subtitle) args.push("--subtitle", subtitle);
    return this.runSync(args) !== null;
  }

  /** List surface IDs in the current workspace (empty array on failure). */
  async listSurfaceIds(): Promise<string[]> {
    const stdout = await this.runAsync(
      this.appendWorkspace(["list-surfaces", "--json", "--id-format", "both"]),
    );
    const parsed = stdout ? parseJson(stdout) : null;
    return extractSurfaceIds(parsed);
  }

  /** Split this client's own surface; returns the new surface ID or null. */
  async createSplit(
    direction: "right" | "down" | "left" | "up",
  ): Promise<string | null> {
    return this.createSplitFrom(this.config.surfaceId, direction);
  }

  // Split from a specific surface. The new surface's ID is found by diffing
  // the surface list before and after the split.
  async createSplitFrom(
    sourceSurfaceId: string | undefined,
    direction: "right" | "down" | "left" | "up",
  ): Promise<string | null> {
    if (!this.config.splits) return null;
    const before = new Set(await this.listSurfaceIds());
    const args = ["new-split", direction];
    const scopedArgs = this.appendSurface(
      this.appendWorkspace(args),
      sourceSurfaceId,
    );
    await this.runAsync(scopedArgs);
    const after = await this.listSurfaceIds();
    for (const id of after) {
      if (!before.has(id)) return id;
    }
    return null;
  }

  /**
   * Create a grid of surfaces for parallel agent execution.
   *
   * Layout strategy (sf stays in the original surface):
   *   1 agent:  [sf | A]
   *   2 agents: [sf | A]
   *             [   | B]
   *   3 agents: [sf | A]
   *             [ C | B]
   *   4 agents: [sf | A]
   *             [ C | B]   (D splits from B downward)
   *             [   | D]
   *
   * Returns surface IDs in order, or empty array on failure.
   * Partial failures return whatever surfaces were created so far.
   */
  async createGridLayout(count: number): Promise<string[]> {
    if (!this.config.splits || count <= 0) return [];
    const surfaces: string[] = [];

    // First split: create right column from the sf surface
    const rightCol = await this.createSplitFrom(this.config.surfaceId, "right");
    if (!rightCol) return [];
    surfaces.push(rightCol);
    if (count === 1) return surfaces;

    // Second split: split right column down → bottom-right
    const bottomRight = await this.createSplitFrom(rightCol, "down");
    if (!bottomRight) return surfaces;
    surfaces.push(bottomRight);
    if (count === 2) return surfaces;

    // Third split: split sf surface down → bottom-left
    const bottomLeft = await this.createSplitFrom(
      this.config.surfaceId,
      "down",
    );
    if (!bottomLeft) return surfaces;
    surfaces.push(bottomLeft);
    if (count === 3) return surfaces;

    // Fourth+: split subsequent surfaces down from the last created
    let lastSurface = bottomRight;
    for (let i = 3; i < count; i++) {
      const next = await this.createSplitFrom(lastSurface, "down");
      if (!next) break;
      surfaces.push(next);
      lastSurface = next;
    }

    return surfaces;
  }

  /** Send text to a surface, ensuring a trailing newline; true on success. */
  async sendSurface(surfaceId: string, text: string): Promise<boolean> {
    const payload = text.endsWith("\n") ? text : `${text}\n`;
    const stdout = await this.runAsync([
      "send-surface",
      "--surface",
      surfaceId,
      payload,
    ]);
    return stdout !== null;
  }
}
|
||||
|
||||
export function syncCmuxSidebar(
|
||||
preferences: SFPreferences | undefined,
|
||||
state: SFState,
|
||||
): void {
|
||||
const client = CmuxClient.fromPreferences(preferences);
|
||||
const config = client.getConfig();
|
||||
if (!config.sidebar) return;
|
||||
|
||||
const label = buildCmuxStatusLabel(state);
|
||||
const progress = buildCmuxProgress(state);
|
||||
const snapshot = JSON.stringify({ label, progress, phase: state.phase });
|
||||
const key = sidebarSnapshotKey(config);
|
||||
if (lastSidebarSnapshots.get(key) === snapshot) return;
|
||||
|
||||
client.setStatus(label, state.phase);
|
||||
client.setProgress(progress);
|
||||
lastSidebarSnapshots.set(key, snapshot);
|
||||
}
|
||||
|
||||
export function clearCmuxSidebar(preferences: SFPreferences | undefined): void {
|
||||
const config = resolveCmuxConfig(preferences);
|
||||
if (!config.available || !config.cliAvailable) return;
|
||||
const client = new CmuxClient({ ...config, enabled: true, sidebar: true });
|
||||
const key = sidebarSnapshotKey(config);
|
||||
client.clearStatus();
|
||||
client.setProgress(null);
|
||||
lastSidebarSnapshots.delete(key);
|
||||
}
|
||||
|
||||
export function logCmuxEvent(
|
||||
preferences: SFPreferences | undefined,
|
||||
message: string,
|
||||
level: CmuxLogLevel = "info",
|
||||
): void {
|
||||
CmuxClient.fromPreferences(preferences).log(message, level);
|
||||
}
|
||||
|
||||
export function shellEscape(value: string): string {
|
||||
return `'${value.replace(/'/g, `'\\''`)}'`;
|
||||
}
|
||||
|
||||
function normalizeNotificationText(value: string): string {
|
||||
return value.replace(/\r?\n/g, " ").trim();
|
||||
}
|
||||
|
||||
function parseJson(text: string): unknown {
|
||||
try {
|
||||
return JSON.parse(text);
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
function extractSurfaceIds(value: unknown): string[] {
|
||||
const found = new Set<string>();
|
||||
|
||||
const visit = (node: unknown): void => {
|
||||
if (Array.isArray(node)) {
|
||||
for (const item of node) visit(item);
|
||||
return;
|
||||
}
|
||||
if (!node || typeof node !== "object") return;
|
||||
|
||||
for (const [key, child] of Object.entries(
|
||||
node as Record<string, unknown>,
|
||||
)) {
|
||||
if (
|
||||
typeof child === "string" &&
|
||||
(key === "surface_id" ||
|
||||
key === "surface" ||
|
||||
(key === "id" && child.includes("surface")))
|
||||
) {
|
||||
found.add(child);
|
||||
}
|
||||
visit(child);
|
||||
}
|
||||
};
|
||||
|
||||
visit(value);
|
||||
return Array.from(found);
|
||||
}
|
||||
|
|
@ -1,492 +0,0 @@
|
|||
/**
|
||||
* Context7 Documentation Extension
|
||||
*
|
||||
* Replaces the context7 MCP server with a native pi extension.
|
||||
* Provides two tools for the LLM:
|
||||
*
|
||||
* resolve_library - Search for a library by name, returns candidates with metadata
|
||||
* get_library_docs - Fetch docs for a library ID, scoped to an optional query/topic
|
||||
*
|
||||
* API contract (verified against live API 2026-03-04):
|
||||
* Search: GET /api/v2/libs/search?libraryName=&query= → { results: C7Library[] }
|
||||
* Context: GET /api/v2/context?libraryId=&query=&tokens= → text/plain (markdown)
|
||||
*
|
||||
* Features:
|
||||
* - Bearer auth via CONTEXT7_API_KEY env var (optional, increases rate limits)
|
||||
* - In-session caching of search results and doc pages
|
||||
* - Smart token budgeting (default 5000, configurable per call, max 10000)
|
||||
* - Proper truncation guard so context is never overwhelmed
|
||||
* - Custom TUI rendering for clean display in pi
|
||||
*
|
||||
* Setup:
|
||||
* export CONTEXT7_API_KEY=your_key (get one at context7.com/dashboard)
|
||||
*/
|
||||
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import {
|
||||
DEFAULT_MAX_BYTES,
|
||||
DEFAULT_MAX_LINES,
|
||||
formatSize,
|
||||
truncateHead,
|
||||
} from "@singularity-forge/pi-coding-agent";
|
||||
import { Text } from "@singularity-forge/pi-tui";
|
||||
|
||||
// ─── API types ────────────────────────────────────────────────────────────────
|
||||
|
||||
/** Shape returned by GET /api/v2/libs/search */
interface C7SearchResponse {
  results: C7Library[];
}

/**
 * One catalogue entry from the Context7 search API.
 * Only `id` and `title` are guaranteed; all other fields are optional
 * metadata used for ranking and display.
 */
interface C7Library {
  id: string; // e.g. "/websites/react_dev" — passed to get_library_docs
  title: string;
  description?: string;
  branch?: string;
  lastUpdateDate?: string; // ISO timestamp; date part shown in listings
  state?: string;
  totalTokens?: number; // corpus size in tokens
  totalSnippets?: number; // number of indexed code snippets
  stars?: number;
  trustScore?: number; // 0–10 ranking signal per the tool description
  benchmarkScore?: number;
  versions?: string[];
}
|
||||
|
||||
// ─── In-session cache ─────────────────────────────────────────────────────────
|
||||
|
||||
// In-session cache of search results. Keyed by lowercased query string;
// cleared on session_shutdown.
const searchCache = new Map<string, C7Library[]>();

// In-session cache of rendered doc pages. Keyed by `${libraryId}::${query ?? ""}::${tokens}`;
// cleared on session_shutdown.
const docCache = new Map<string, string>();

// ─── Helpers ─────────────────────────────────────────────────────────────────

// Base endpoint for the Context7 v2 HTTP API.
const BASE_URL = "https://context7.com/api/v2";
|
||||
|
||||
function getApiKey(): string | undefined {
|
||||
return process.env.CONTEXT7_API_KEY;
|
||||
}
|
||||
|
||||
function buildHeaders(): Record<string, string> {
|
||||
const headers: Record<string, string> = {
|
||||
"User-Agent": "pi-coding-agent/context7-extension",
|
||||
};
|
||||
const key = getApiKey();
|
||||
if (key) headers["Authorization"] = `Bearer ${key}`;
|
||||
return headers;
|
||||
}
|
||||
|
||||
/**
 * GET `url` from the Context7 API and parse the JSON response body.
 *
 * Sends the shared auth/User-Agent headers from buildHeaders().
 *
 * @param url fully-built API URL (query params already encoded)
 * @param signal optional AbortSignal to cancel the request
 * @throws Error on any non-2xx response, including the HTTP status and
 *   up to 300 chars of the response body for diagnostics.
 */
async function apiFetchJson(
  url: string,
  signal?: AbortSignal,
): Promise<unknown> {
  const res = await fetch(url, {
    headers: { ...buildHeaders(), Accept: "application/json" },
    signal,
  });
  if (!res.ok) {
    // Reading the error body is best-effort — failures fall back to "".
    const body = await res.text().catch(() => "");
    throw new Error(`Context7 API ${res.status}: ${body.slice(0, 300)}`);
  }
  return res.json();
}
|
||||
|
||||
/**
 * GET `url` from the Context7 API and return the plain-text response body
 * (the /context endpoint serves markdown as text/plain).
 *
 * @param url fully-built API URL (query params already encoded)
 * @param signal optional AbortSignal to cancel the request
 * @throws Error on any non-2xx response, including the HTTP status and
 *   up to 300 chars of the response body for diagnostics.
 */
async function apiFetchText(
  url: string,
  signal?: AbortSignal,
): Promise<string> {
  const res = await fetch(url, {
    headers: { ...buildHeaders(), Accept: "text/plain" },
    signal,
  });
  if (!res.ok) {
    // Reading the error body is best-effort — failures fall back to "".
    const body = await res.text().catch(() => "");
    throw new Error(`Context7 API ${res.status}: ${body.slice(0, 300)}`);
  }
  return res.text();
}
|
||||
|
||||
/**
|
||||
* Format library search results into a compact, LLM-readable string.
|
||||
* Each library gets a block with the key signals for picking the best match.
|
||||
*/
|
||||
function formatLibraryList(libs: C7Library[], query: string): string {
|
||||
if (libs.length === 0) {
|
||||
return `No libraries found for "${query}". Try a different name or spelling.`;
|
||||
}
|
||||
|
||||
const lines: string[] = [
|
||||
`Found ${libs.length} ${libs.length === 1 ? "library" : "libraries"} matching "${query}":\n`,
|
||||
];
|
||||
|
||||
for (const lib of libs) {
|
||||
let line = `• ${lib.title} (ID: ${lib.id})`;
|
||||
if (lib.description) line += `\n ${lib.description}`;
|
||||
|
||||
const meta: string[] = [];
|
||||
if (lib.trustScore !== undefined) meta.push(`trust: ${lib.trustScore}/10`);
|
||||
if (lib.benchmarkScore !== undefined)
|
||||
meta.push(`benchmark: ${lib.benchmarkScore.toFixed(1)}`);
|
||||
if (lib.totalSnippets !== undefined)
|
||||
meta.push(`${lib.totalSnippets.toLocaleString()} snippets`);
|
||||
if (lib.totalTokens !== undefined)
|
||||
meta.push(`${(lib.totalTokens / 1000).toFixed(0)}k tokens`);
|
||||
if (lib.lastUpdateDate)
|
||||
meta.push(`updated: ${lib.lastUpdateDate.split("T")[0]}`);
|
||||
if (meta.length > 0) line += `\n ${meta.join(" · ")}`;
|
||||
|
||||
lines.push(line);
|
||||
}
|
||||
|
||||
lines.push(
|
||||
"\nUse the ID (e.g. /websites/react_dev) with get_library_docs to fetch documentation.",
|
||||
);
|
||||
|
||||
return lines.join("\n");
|
||||
}
|
||||
|
||||
// ─── Tool details types ───────────────────────────────────────────────────────
|
||||
|
||||
/** Details payload attached to resolve_library results (drives renderResult). */
interface ResolveDetails {
  query: string; // the libraryName the caller searched for
  resultCount: number; // number of libraries returned
  cached: boolean; // true when served from the in-session searchCache
  error?: string; // present only when the API call failed
}

/** Details payload attached to get_library_docs results (drives renderResult). */
interface DocsDetails {
  libraryId: string; // normalized library ID (leading '@' stripped)
  query?: string; // optional topic filter, trimmed; undefined when blank
  tokens: number; // effective token budget after clamping to [500, 10000]
  cached: boolean; // true when served from the in-session docCache
  truncated: boolean; // true when the defensive truncation guard fired
  charCount: number; // length of the text returned to the model
  error?: string; // present only when the API call failed
}
|
||||
|
||||
// ─── Extension ───────────────────────────────────────────────────────────────
|
||||
|
||||
/**
 * Extension entry point: registers the two Context7 tools
 * (resolve_library, get_library_docs) plus session lifecycle hooks on the
 * pi extension API.
 */
export default function (pi: ExtensionAPI) {
  // ── resolve_library ──────────────────────────────────────────────────────

  pi.registerTool({
    name: "resolve_library",
    label: "Resolve Library",
    description:
      "Search the Context7 library catalogue by name and return matching libraries with metadata. " +
      "Use this to find the correct library ID before fetching documentation. " +
      "Results are ranked by trustScore (0–10) and benchmarkScore — prefer the highest. " +
      "If you already have a library ID (e.g. /vercel/next.js), skip this and call get_library_docs directly.",
    promptSnippet:
      "Search Context7 for a library by name to get its ID for documentation lookup",
    promptGuidelines: [
      "Call resolve_library first when the user asks about a library, package, or framework you need current docs for.",
      "Choose the result with the highest trustScore and benchmarkScore when multiple matches appear.",
      "Pass the user's question as the query parameter — it improves result ranking.",
    ],
    parameters: Type.Object({
      libraryName: Type.String({
        description:
          "Library or framework name to search for, e.g. 'react', 'next.js', 'tailwindcss', 'prisma', 'langchain'",
      }),
      query: Type.Optional(
        Type.String({
          description:
            "Optional: the user's question or topic. Improves search ranking. E.g. 'how do I use server actions?'",
        }),
      ),
    }),

    async execute(_toolCallId, params, signal, _onUpdate, _ctx) {
      // Cache key normalizes case/whitespace so repeat searches hit the cache.
      const cacheKey = params.libraryName.toLowerCase().trim();

      // Serve from the in-session cache when this name was already searched.
      if (searchCache.has(cacheKey)) {
        const cached = searchCache.get(cacheKey)!;
        return {
          content: [
            {
              type: "text",
              text: formatLibraryList(cached, params.libraryName),
            },
          ],
          details: {
            query: params.libraryName,
            resultCount: cached.length,
            cached: true,
          } as ResolveDetails,
        };
      }

      const url = new URL(`${BASE_URL}/libs/search`);
      url.searchParams.set("libraryName", params.libraryName);
      if (params.query) url.searchParams.set("query", params.query);

      let libs: C7Library[];
      try {
        const data = (await apiFetchJson(
          url.toString(),
          signal,
        )) as C7SearchResponse;
        // Defensive: treat a malformed response as "no results".
        libs = Array.isArray(data?.results) ? data.results : [];
      } catch (err: unknown) {
        const msg = err instanceof Error ? err.message : String(err);
        // Failures are returned as tool errors, not thrown, so the model sees them.
        return {
          content: [{ type: "text", text: `Context7 search failed: ${msg}` }],
          isError: true,
          details: {
            query: params.libraryName,
            resultCount: 0,
            cached: false,
            error: msg,
          } as ResolveDetails,
        };
      }

      searchCache.set(cacheKey, libs);

      return {
        content: [
          { type: "text", text: formatLibraryList(libs, params.libraryName) },
        ],
        details: {
          query: params.libraryName,
          resultCount: libs.length,
          cached: false,
        } as ResolveDetails,
      };
    },

    // One-line TUI summary of the outgoing call.
    renderCall(args, theme) {
      let text = theme.fg("toolTitle", theme.bold("resolve_library "));
      text += theme.fg("accent", `"${args.libraryName}"`);
      if (args.query) text += theme.fg("muted", ` — "${args.query}"`);
      return new Text(text, 0, 0);
    },

    // TUI summary of the result: count, cache status, and the query.
    renderResult(result, { isPartial }, theme) {
      const d = result.details as ResolveDetails | undefined;
      if (isPartial)
        return new Text(theme.fg("warning", "Searching Context7..."), 0, 0);
      if ((result as any).isError || d?.error) {
        return new Text(
          theme.fg("error", `Error: ${d?.error ?? "unknown"}`),
          0,
          0,
        );
      }
      let text = theme.fg(
        "success",
        `${d?.resultCount ?? 0} ${d?.resultCount === 1 ? "library" : "libraries"} found`,
      );
      if (d?.cached) text += theme.fg("dim", " (cached)");
      text += theme.fg("dim", ` for "${d?.query}"`);
      return new Text(text, 0, 0);
    },
  });

  // ── get_library_docs ─────────────────────────────────────────────────────

  pi.registerTool({
    name: "get_library_docs",
    label: "Get Library Docs",
    description:
      "Fetch up-to-date documentation from Context7 for a specific library. " +
      "Pass the library ID from resolve_library (e.g. /websites/react_dev) and a focused topic query " +
      "to get the most relevant snippets. " +
      "The tokens parameter controls how much documentation to retrieve (default 5000, max 10000). " +
      "A specific query (e.g. 'server actions form submission') returns better results than a broad one.",
    promptSnippet:
      "Fetch up-to-date, version-specific documentation for a library from Context7",
    promptGuidelines: [
      "Use a specific topic query for best results — e.g. 'useEffect cleanup' not just 'hooks'.",
      "Start with tokens=5000. Increase to 10000 only if the first response lacks the detail you need.",
      "Results are cached per-session — repeated calls for the same library+query have no API cost.",
    ],
    parameters: Type.Object({
      libraryId: Type.String({
        description:
          "Context7 library ID from resolve_library, e.g. /websites/react_dev or /vercel/next.js",
      }),
      query: Type.Optional(
        Type.String({
          description:
            "Specific topic to focus the docs on, e.g. 'server actions', 'useEffect cleanup', 'authentication middleware'. More specific = better results.",
        }),
      ),
      tokens: Type.Optional(
        Type.Number({
          description:
            "Max tokens of documentation to return (default 5000, max 10000).",
          minimum: 500,
          maximum: 10000,
        }),
      ),
    }),

    async execute(_toolCallId, params, signal, _onUpdate, _ctx) {
      // Clamp the requested budget into [500, 10000]; default 5000.
      const tokens = Math.min(Math.max(params.tokens ?? 5000, 500), 10000);
      // Strip accidental leading @ that some models inject
      const libraryId = params.libraryId.startsWith("@")
        ? params.libraryId.slice(1)
        : params.libraryId;
      // Blank/whitespace-only queries are treated as absent.
      const query = params.query?.trim() || undefined;

      const cacheKey = `${libraryId}::${query ?? ""}::${tokens}`;

      if (docCache.has(cacheKey)) {
        const cached = docCache.get(cacheKey)!;
        return {
          content: [{ type: "text", text: cached }],
          details: {
            libraryId,
            query,
            tokens,
            cached: true,
            truncated: false,
            charCount: cached.length,
          } as DocsDetails,
        };
      }

      const url = new URL(`${BASE_URL}/context`);
      url.searchParams.set("libraryId", libraryId);
      if (query) url.searchParams.set("query", query);
      url.searchParams.set("tokens", String(tokens));

      let rawText: string;
      try {
        rawText = await apiFetchText(url.toString(), signal);
      } catch (err: unknown) {
        const msg = err instanceof Error ? err.message : String(err);
        return {
          content: [
            { type: "text", text: `Context7 doc fetch failed: ${msg}` },
          ],
          isError: true,
          details: {
            libraryId,
            query,
            tokens,
            cached: false,
            truncated: false,
            charCount: 0,
            error: msg,
          } as DocsDetails,
        };
      }

      // An empty body means the ID/query matched nothing — tell the model
      // how to recover rather than returning blank text.
      if (!rawText.trim()) {
        const notFound = query
          ? `No documentation found for "${query}" in ${libraryId}. Try a broader query or different library ID.`
          : `No documentation found for ${libraryId}. Try resolve_library to verify the library ID.`;
        return {
          content: [{ type: "text", text: notFound }],
          details: {
            libraryId,
            query,
            tokens,
            cached: false,
            truncated: false,
            charCount: 0,
          } as DocsDetails,
        };
      }

      // Truncation guard — Context7 already respects the token budget, but be defensive
      const truncation = truncateHead(rawText, {
        maxLines: DEFAULT_MAX_LINES,
        maxBytes: DEFAULT_MAX_BYTES,
      });

      let finalText = truncation.content;
      if (truncation.truncated) {
        finalText +=
          `\n\n[Truncated: showing ${truncation.outputLines}/${truncation.totalLines} lines` +
          ` (${formatSize(truncation.outputBytes)} of ${formatSize(truncation.totalBytes)}).` +
          ` Use a more specific query to reduce output size.]`;
      }

      // Note: the cached value includes the truncation notice, if any.
      docCache.set(cacheKey, finalText);

      return {
        content: [{ type: "text", text: finalText }],
        details: {
          libraryId,
          query,
          tokens,
          cached: false,
          truncated: truncation.truncated,
          charCount: finalText.length,
        } as DocsDetails,
      };
    },

    // One-line TUI summary of the outgoing call.
    renderCall(args, theme) {
      let text = theme.fg("toolTitle", theme.bold("get_library_docs "));
      text += theme.fg("accent", args.libraryId);
      if (args.query) text += theme.fg("muted", ` — "${args.query}"`);
      if (args.tokens && args.tokens !== 5000)
        text += theme.fg("dim", ` (${args.tokens} tokens)`);
      return new Text(text, 0, 0);
    },

    // TUI summary of the result; when expanded, shows a 12-line preview.
    renderResult(result, { isPartial, expanded }, theme) {
      const d = result.details as DocsDetails | undefined;

      if (isPartial)
        return new Text(theme.fg("warning", "Fetching documentation..."), 0, 0);
      if ((result as any).isError || d?.error) {
        return new Text(
          theme.fg("error", `Error: ${d?.error ?? "unknown"}`),
          0,
          0,
        );
      }

      let text = theme.fg(
        "success",
        `${(d?.charCount ?? 0).toLocaleString()} chars`,
      );
      text += theme.fg("dim", ` · ${d?.tokens ?? 5000} token budget`);
      if (d?.cached) text += theme.fg("dim", " · cached");
      if (d?.truncated) text += theme.fg("warning", " · truncated");
      text += theme.fg("dim", ` · ${d?.libraryId}`);
      if (d?.query) text += theme.fg("dim", ` — "${d.query}"`);

      if (expanded) {
        const content = result.content[0];
        if (content?.type === "text") {
          const preview = content.text.split("\n").slice(0, 12).join("\n");
          text += "\n\n" + theme.fg("dim", preview);
          if (content.text.split("\n").length > 12) {
            text += "\n" + theme.fg("muted", "… (Ctrl+O to collapse)");
          }
        }
      }

      return new Text(text, 0, 0);
    },
  });

  // ── Session cleanup ─────────────────────────────────────────────────────

  pi.on("session_shutdown", async () => {
    searchCache.clear();
    docCache.clear();
  });

  // ── Startup notification ─────────────────────────────────────────────────

  pi.on("session_start", async (_event, ctx) => {
    // Warn once per session when running anonymously on the free tier.
    if (!getApiKey()) {
      ctx.ui.notify(
        "Context7: No CONTEXT7_API_KEY set. Using free tier (1000 req/month limit). " +
          "Set CONTEXT7_API_KEY for higher limits.",
        "warning",
      );
    }
  });
}
|
||||
|
|
@ -1,14 +1,14 @@
|
|||
{
|
||||
"name": "pi-extension-context7",
|
||||
"private": true,
|
||||
"version": "1.0.0",
|
||||
"type": "module",
|
||||
"engines": {
|
||||
"node": ">=24.15.0"
|
||||
},
|
||||
"pi": {
|
||||
"extensions": [
|
||||
"./index.ts"
|
||||
]
|
||||
}
|
||||
"name": "pi-extension-context7",
|
||||
"private": true,
|
||||
"version": "1.0.0",
|
||||
"type": "module",
|
||||
"engines": {
|
||||
"node": ">=24.15.0"
|
||||
},
|
||||
"pi": {
|
||||
"extensions": [
|
||||
"./index.js"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,12 +0,0 @@
|
|||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import { installGenaiProxyExtension } from "./proxy-command.js";
|
||||
|
||||
export {
|
||||
installGenaiProxyExtension,
|
||||
resolveProxyPort,
|
||||
} from "./proxy-command.js";
|
||||
export { createProxyServer, ProxyServer } from "./proxy-server.js";
|
||||
|
||||
/**
 * Default extension entry point — thin wrapper so pi can load this module
 * directly; all registration logic lives in installGenaiProxyExtension.
 */
export default function genaiProxyExtension(api: ExtensionAPI): void {
  installGenaiProxyExtension(api);
}
|
||||
|
|
@ -1,14 +1,14 @@
|
|||
{
|
||||
"name": "pi-genai-proxy",
|
||||
"private": true,
|
||||
"version": "1.0.0",
|
||||
"type": "module",
|
||||
"engines": {
|
||||
"node": ">=24.15.0"
|
||||
},
|
||||
"pi": {
|
||||
"extensions": [
|
||||
"./index.ts"
|
||||
]
|
||||
}
|
||||
"name": "pi-genai-proxy",
|
||||
"private": true,
|
||||
"version": "1.0.0",
|
||||
"type": "module",
|
||||
"engines": {
|
||||
"node": ">=24.15.0"
|
||||
},
|
||||
"pi": {
|
||||
"extensions": [
|
||||
"./index.js"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,162 +0,0 @@
|
|||
import type {
|
||||
ExtensionAPI,
|
||||
ExtensionCommandContext,
|
||||
ExtensionStartupContext,
|
||||
} from "@singularity-forge/pi-coding-agent";
|
||||
import { createProxyServer, type ProxyServer } from "./proxy-server.js";
|
||||
|
||||
// Slash-command name: /genai-proxy start [port] | stop | status
const PROXY_COMMAND_NAME = "genai-proxy";
// CLI flag name: --gemini-cli-proxy[=port] starts the proxy at startup.
const PROXY_FLAG_NAME = "gemini-cli-proxy";
// Port used when the flag/command supplies none.
const DEFAULT_PROXY_PORT = 3000;

/** Injectable dependencies, primarily a test seam for the server factory. */
export interface ProxyCommandDependencies {
  createProxyServer?: typeof createProxyServer;
}
|
||||
|
||||
/**
 * Register the genai-proxy flag and slash command on the extension API.
 *
 * A single ProxyServer instance is held in closure state; `ensureProxyServer`
 * returns the existing instance when the requested port matches, throws when
 * a server already exists on a different port, and otherwise constructs one
 * via the (injectable) factory.
 *
 * @param api extension API surface (only registerCommand/registerFlag used)
 * @param dependencies optional overrides, mainly for tests
 */
export function installGenaiProxyExtension(
  api: Pick<ExtensionAPI, "registerCommand" | "registerFlag">,
  dependencies?: ProxyCommandDependencies,
): void {
  // Shared mutable server handle, captured by both the flag and the command.
  let proxyServer: ProxyServer | null = null;
  const buildProxyServer = dependencies?.createProxyServer ?? createProxyServer;

  const ensureProxyServer = (
    context: ExtensionStartupContext | ExtensionCommandContext,
    port: number,
  ): ProxyServer => {
    if (proxyServer && proxyServer.getPort() === port) {
      return proxyServer;
    }
    // A server on a *different* port must be stopped first.
    if (proxyServer) {
      throw new Error(`Proxy already running on port ${proxyServer.getPort()}`)
    }

    proxyServer = buildProxyServer({
      port,
      modelRegistry: context.modelRegistry,
      onLog: (message) => notifyProxyStatus(context, message, "info"),
    });
    return proxyServer;
  };

  // --gemini-cli-proxy[=port] — start the proxy during startup.
  api.registerFlag(PROXY_FLAG_NAME, {
    description: "Start the Gemini CLI proxy server",
    type: "string",
    allowNoValue: true,
    onStartup: async (value, context) => {
      const server = ensureProxyServer(context, resolveProxyPort(value));
      await server.start();
    },
  });

  // /genai-proxy — runtime management (start/stop/status).
  api.registerCommand(PROXY_COMMAND_NAME, {
    description: "Manage the Gemini CLI proxy server",
    handler: async (args, context) => {
      await handleProxyCommand(
        args ?? "",
        context,
        ensureProxyServer,
        () => proxyServer,
        () => {
          proxyServer = null;
        },
      );
    },
  });
}
|
||||
|
||||
export function resolveProxyPort(
|
||||
flagValue: boolean | string | undefined,
|
||||
): number {
|
||||
if (flagValue === true || flagValue === false || flagValue === undefined) {
|
||||
return DEFAULT_PROXY_PORT;
|
||||
}
|
||||
|
||||
const port = Number.parseInt(flagValue, 10);
|
||||
if (!Number.isFinite(port) || port <= 0 || port > 65535) {
|
||||
throw new Error(`Invalid proxy port: ${flagValue}`);
|
||||
}
|
||||
return port;
|
||||
}
|
||||
|
||||
/**
 * Dispatch a /genai-proxy slash command: "start [port]", "stop", or
 * "status" (the default when no subcommand is given).
 *
 * Server lifecycle is owned by the caller via the three injected callbacks;
 * this function only decides what to do and reports outcomes through
 * notifyProxyStatus.
 *
 * @param rawArgs everything after the command name, unparsed
 * @param ensureProxyServer returns (or creates) the server for a port
 * @param getProxyServer current server instance, if any
 * @param clearProxyServer forget the instance after a successful stop
 */
async function handleProxyCommand(
  rawArgs: string,
  context: ExtensionCommandContext,
  ensureProxyServer: (
    context: ExtensionCommandContext,
    port: number,
  ) => ProxyServer,
  getProxyServer: () => ProxyServer | null,
  clearProxyServer: () => void,
): Promise<void> {
  // Split on whitespace; empty input yields subcommand "status".
  const [subcommand = "status", portArg] = rawArgs
    .trim()
    .split(/\s+/)
    .filter((value): value is string => value.length > 0);

  if (subcommand === "start") {
    const existingServer = getProxyServer();
    if (existingServer?.isRunning()) {
      notifyProxyStatus(
        context,
        `Proxy already running on port ${existingServer.getPort()}`,
        "info",
      );
      return;
    }

    // No port argument behaves like a bare flag (true → default port).
    const server = ensureProxyServer(
      context,
      resolveProxyPort(portArg === undefined ? true : portArg),
    );
    await server.start();
    return;
  }

  if (subcommand === "stop") {
    const server = getProxyServer();
    if (!server?.isRunning()) {
      notifyProxyStatus(context, "Proxy is not running", "warning");
      return;
    }

    await server.stop();
    clearProxyServer();
    notifyProxyStatus(context, "Proxy stopped", "success");
    return;
  }

  if (subcommand === "status") {
    const server = getProxyServer();
    if (server?.isRunning()) {
      notifyProxyStatus(
        context,
        `Proxy running on port ${server.getPort()}`,
        "info",
      );
      return;
    }

    notifyProxyStatus(context, "Proxy is not running", "info");
    return;
  }

  // Unknown subcommand — show usage.
  notifyProxyStatus(
    context,
    "Usage: /genai-proxy start [port] | stop | status",
    "warning",
  );
}
|
||||
|
||||
function notifyProxyStatus(
|
||||
context: ExtensionStartupContext | ExtensionCommandContext,
|
||||
message: string,
|
||||
type: Parameters<ExtensionCommandContext["ui"]["notify"]>[1],
|
||||
): void {
|
||||
if ("ui" in context) {
|
||||
context.ui.notify(message, type);
|
||||
return;
|
||||
}
|
||||
|
||||
process.stderr.write(`[genai-proxy] ${message}\n`);
|
||||
}
|
||||
|
|
@ -1,489 +0,0 @@
|
|||
import type { Server } from "node:http";
|
||||
import {
|
||||
type Api,
|
||||
type AssistantMessage,
|
||||
type AssistantMessageEventStream,
|
||||
type Context,
|
||||
type Model,
|
||||
type ProviderStreamOptions,
|
||||
stream,
|
||||
} from "@singularity-forge/pi-ai";
|
||||
import type { ModelRegistry } from "@singularity-forge/pi-coding-agent";
|
||||
import express from "express";
|
||||
|
||||
// Bind to loopback only.
const LISTEN_ADDRESS = "127.0.0.1";
// Fixed "created" timestamp emitted by the OpenAI-compatible /v1/models listing.
const OPENAI_CREATED_TIMESTAMP = 1_677_610_602;
const SSE_CONTENT_TYPE = "text/event-stream";
const NDJSON_CONTENT_TYPE = "application/x-ndjson";

// Signature of pi-ai's `stream`; overridable via ProxyServerOptions.streamModel.
type ProxyStreamFn = (
  model: Model<Api>,
  context: Context,
  options?: ProviderStreamOptions,
) => AssistantMessageEventStream;

/** Construction options for ProxyServer. */
export interface ProxyServerOptions {
  port: number;
  // Narrowed registry surface: lookup, listing, and credential resolution only.
  modelRegistry: Pick<ModelRegistry, "find" | "getAll" | "getApiKey">;
  onLog?: (message: string) => void;
  // Test seam — defaults to pi-ai's `stream` when omitted.
  streamModel?: ProxyStreamFn;
}

/** Minimal shape of an OpenAI chat message as received from clients. */
interface OpenAiMessage {
  role?: string;
  content?: string | Array<{ type?: string; text?: string }>;
}

/** Request body for the OpenAI-compatible /v1/chat/completions route. */
interface OpenAiChatBody {
  model?: string;
  messages?: OpenAiMessage[];
  stream?: boolean;
  temperature?: number;
  max_tokens?: number;
}

/** Request body for the Google-style :streamGenerateContent route. */
interface GoogleStreamBody {
  model?: string;
  contents?: Array<{
    role?: string;
    parts?: Array<{ text?: string }>;
  }>;
  systemInstruction?: {
    parts?: Array<{ text?: string }>;
  };
  stream?: boolean;
  temperature?: number;
  generationConfig?: {
    maxOutputTokens?: number;
  };
}

// Which wire format a request/response pair uses.
type RouteKind = "openai" | "google";
|
||||
|
||||
export class ProxyServer {
|
||||
private server: Server | null = null;
|
||||
private boundPort: number | null = null;
|
||||
private readonly options: ProxyServerOptions;
|
||||
private readonly streamModel: ProxyStreamFn;
|
||||
|
||||
constructor(options: ProxyServerOptions) {
|
||||
this.options = options;
|
||||
this.streamModel = options.streamModel ?? stream;
|
||||
}
|
||||
|
||||
isRunning(): boolean {
|
||||
return this.server !== null;
|
||||
}
|
||||
|
||||
getPort(): number | null {
|
||||
return this.boundPort;
|
||||
}
|
||||
|
||||
async start(): Promise<void> {
|
||||
if (this.server) {
|
||||
return;
|
||||
}
|
||||
|
||||
const app = express();
|
||||
app.use(express.json({ limit: "2mb" }));
|
||||
|
||||
app.get(["/v1/models", "/v1beta/models"], (_req, res) => {
|
||||
const models = this.options.modelRegistry.getAll().map((model) => ({
|
||||
id: model.id,
|
||||
object: "model",
|
||||
created: OPENAI_CREATED_TIMESTAMP,
|
||||
owned_by: model.provider,
|
||||
name: model.name,
|
||||
capabilities: model.capabilities,
|
||||
}));
|
||||
|
||||
if (_req.path.startsWith("/v1beta")) {
|
||||
res.json({ models });
|
||||
return;
|
||||
}
|
||||
|
||||
res.json({ object: "list", data: models });
|
||||
});
|
||||
|
||||
app.post("/v1/chat/completions", async (req, res) => {
|
||||
await this.handleCompletionRequest(req, res, "openai");
|
||||
});
|
||||
|
||||
app.post(
|
||||
"/v1beta/models/:modelId\\:streamGenerateContent",
|
||||
async (req, res) => {
|
||||
await this.handleCompletionRequest(req, res, "google");
|
||||
},
|
||||
);
|
||||
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
const server = app.listen(this.options.port, LISTEN_ADDRESS, () => {
|
||||
this.server = server;
|
||||
const address = server.address();
|
||||
if (typeof address === "object" && address) {
|
||||
this.boundPort = address.port;
|
||||
} else {
|
||||
this.boundPort = this.options.port;
|
||||
}
|
||||
this.options.onLog?.(
|
||||
`Proxy Server running on http://${LISTEN_ADDRESS}:${this.boundPort}`,
|
||||
);
|
||||
resolve();
|
||||
});
|
||||
|
||||
server.once("error", reject);
|
||||
});
|
||||
}
|
||||
|
||||
async stop(): Promise<void> {
|
||||
if (!this.server) {
|
||||
return;
|
||||
}
|
||||
|
||||
const server = this.server;
|
||||
this.server = null;
|
||||
this.boundPort = null;
|
||||
|
||||
await new Promise<void>((resolve, reject) => {
|
||||
server.close((error) => {
|
||||
if (error) {
|
||||
reject(error);
|
||||
return;
|
||||
}
|
||||
resolve();
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
private async handleCompletionRequest(
|
||||
req: express.Request,
|
||||
res: express.Response,
|
||||
routeKind: RouteKind,
|
||||
): Promise<void> {
|
||||
const body = req.body as OpenAiChatBody | GoogleStreamBody;
|
||||
const modelReference = this.resolveModelReference(
|
||||
body.model,
|
||||
req.params.modelId,
|
||||
);
|
||||
|
||||
if (!modelReference) {
|
||||
res.status(400).json({ error: "Model ID is required" });
|
||||
return;
|
||||
}
|
||||
|
||||
const model = this.resolveModel(modelReference);
|
||||
if (!model) {
|
||||
res.status(404).json({ error: `Model ${modelReference} not found` });
|
||||
return;
|
||||
}
|
||||
|
||||
const apiKey = await this.options.modelRegistry.getApiKey(model);
|
||||
if (!apiKey) {
|
||||
res
|
||||
.status(401)
|
||||
.json({ error: `No credentials for provider ${model.provider}` });
|
||||
return;
|
||||
}
|
||||
|
||||
const abortController = new AbortController();
|
||||
req.once("close", () => abortController.abort());
|
||||
|
||||
const maxTokens =
|
||||
routeKind === "openai"
|
||||
? (body as OpenAiChatBody).max_tokens
|
||||
: (body as GoogleStreamBody).generationConfig?.maxOutputTokens;
|
||||
|
||||
const context = this.normalizeContext(body, routeKind);
|
||||
const options: ProviderStreamOptions = {
|
||||
apiKey,
|
||||
temperature: body.temperature,
|
||||
maxTokens,
|
||||
signal: abortController.signal,
|
||||
};
|
||||
|
||||
const eventStream = this.streamModel(model, context, options);
|
||||
const shouldStream =
|
||||
routeKind === "google"
|
||||
? (body as GoogleStreamBody).stream !== false
|
||||
: (body as OpenAiChatBody).stream === true;
|
||||
|
||||
if (shouldStream) {
|
||||
await this.sendStreamingResponse(eventStream, res, routeKind, model);
|
||||
return;
|
||||
}
|
||||
|
||||
await this.sendBufferedResponse(eventStream, res, routeKind, model);
|
||||
}
|
||||
|
||||
private resolveModelReference(
|
||||
bodyModel: string | undefined,
|
||||
pathModelId: string | undefined,
|
||||
): string | undefined {
|
||||
return bodyModel ?? pathModelId;
|
||||
}
|
||||
|
||||
private resolveModel(modelReference: string): Model<Api> | undefined {
|
||||
const normalizedReference = modelReference.toLowerCase();
|
||||
const exact = this.options.modelRegistry
|
||||
.getAll()
|
||||
.find(
|
||||
(model) =>
|
||||
`${model.provider}/${model.id}`.toLowerCase() ===
|
||||
normalizedReference ||
|
||||
model.id.toLowerCase() === normalizedReference,
|
||||
);
|
||||
if (exact) {
|
||||
return exact;
|
||||
}
|
||||
|
||||
const slashIndex = modelReference.indexOf("/");
|
||||
if (slashIndex === -1) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const provider = modelReference.slice(0, slashIndex);
|
||||
const modelId = modelReference.slice(slashIndex + 1);
|
||||
return this.options.modelRegistry.find(provider, modelId);
|
||||
}
|
||||
|
||||
private normalizeContext(
|
||||
body: OpenAiChatBody | GoogleStreamBody,
|
||||
routeKind: RouteKind,
|
||||
): Context {
|
||||
if (routeKind === "google") {
|
||||
return this.normalizeGoogleContext(body as GoogleStreamBody);
|
||||
}
|
||||
|
||||
return this.normalizeOpenAiContext(body as OpenAiChatBody);
|
||||
}
|
||||
|
||||
private normalizeOpenAiContext(body: OpenAiChatBody): Context {
|
||||
const messages = body.messages ?? [];
|
||||
const systemPrompt = messages.find(
|
||||
(message) => message.role === "system",
|
||||
)?.content;
|
||||
const normalizedMessages = messages
|
||||
.filter((message) => message.role !== "system")
|
||||
.map((message) => this.normalizeOpenAiMessage(message));
|
||||
|
||||
return {
|
||||
systemPrompt: typeof systemPrompt === "string" ? systemPrompt : undefined,
|
||||
messages: normalizedMessages,
|
||||
};
|
||||
}
|
||||
|
||||
private normalizeGoogleContext(body: GoogleStreamBody): Context {
|
||||
const systemPrompt =
|
||||
body.systemInstruction?.parts?.map((part) => part.text ?? "").join("") ||
|
||||
undefined;
|
||||
const normalizedMessages = (body.contents ?? [])
|
||||
.map((content) => {
|
||||
const textContent = (content.parts ?? [])
|
||||
.filter((part) => typeof part.text === "string")
|
||||
.map((part) => ({ type: "text" as const, text: part.text ?? "" }));
|
||||
|
||||
if (content.role === "user") {
|
||||
return this.createUserMessage(textContent);
|
||||
}
|
||||
|
||||
return this.createAssistantMessage(textContent);
|
||||
})
|
||||
.filter((message) => message.content.length > 0);
|
||||
|
||||
return {
|
||||
systemPrompt,
|
||||
messages: normalizedMessages,
|
||||
};
|
||||
}
|
||||
|
||||
private normalizeOpenAiMessage(
|
||||
message: OpenAiMessage,
|
||||
): Context["messages"][number] {
|
||||
if (message.role === "assistant") {
|
||||
return this.createAssistantMessage(
|
||||
this.normalizeContent(message.content),
|
||||
);
|
||||
}
|
||||
|
||||
return this.createUserMessage(this.normalizeContent(message.content));
|
||||
}
|
||||
|
||||
private createUserMessage(
|
||||
content: string | { type: "text"; text: string }[],
|
||||
): Context["messages"][number] {
|
||||
return {
|
||||
role: "user",
|
||||
content,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
private createAssistantMessage(
|
||||
content: string | { type: "text"; text: string }[],
|
||||
): AssistantMessage {
|
||||
const normalizedContent =
|
||||
typeof content === "string"
|
||||
? [{ type: "text" as const, text: content }]
|
||||
: content;
|
||||
|
||||
return {
|
||||
role: "assistant",
|
||||
content: normalizedContent,
|
||||
api: "google-gemini-cli" as Api,
|
||||
provider: "google-gemini-cli",
|
||||
model: "proxy",
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
private normalizeContent(
|
||||
content: string | Array<{ type?: string; text?: string }> | undefined,
|
||||
): string | { type: "text"; text: string }[] {
|
||||
if (typeof content === "string") {
|
||||
return content;
|
||||
}
|
||||
|
||||
return (content ?? [])
|
||||
.filter((part) => typeof part.text === "string")
|
||||
.map((part) => ({ type: "text" as const, text: part.text ?? "" }));
|
||||
}
|
||||
|
||||
  /**
   * Stream provider events to the HTTP response as they arrive.
   *
   * OpenAI routes get SSE framing (`data: {...}\n\n` plus a terminal
   * `data: [DONE]`); Google routes get newline-delimited JSON. A "done"
   * event terminates the response; an "error" event sends a 500 JSON body
   * if headers have not been flushed yet, otherwise just ends the stream.
   */
  private async sendStreamingResponse(
    eventStream: AssistantMessageEventStream,
    res: express.Response,
    routeKind: RouteKind,
    model: Model<Api>,
  ): Promise<void> {
    // Commit the status/headers up front so deltas can be written as they
    // arrive; after this point error reporting can only end the stream.
    res.status(200);
    res.setHeader(
      "Content-Type",
      routeKind === "openai" ? SSE_CONTENT_TYPE : NDJSON_CONTENT_TYPE,
    );
    res.setHeader("Cache-Control", "no-cache");
    res.setHeader("Connection", "keep-alive");

    for await (const event of eventStream) {
      if (event.type === "text_delta") {
        if (routeKind === "openai") {
          // SSE frame: "data: <json>" followed by a blank line.
          res.write(
            `data: ${JSON.stringify(this.buildOpenAiChunk(model, event.delta))}\n\n`,
          );
        } else {
          // NDJSON: one JSON object per line.
          res.write(`${JSON.stringify(this.buildGoogleChunk(event.delta))}\n`);
        }
      }

      if (event.type === "done") {
        if (routeKind === "openai") {
          // OpenAI clients expect an explicit sentinel before EOF.
          res.write("data: [DONE]\n\n");
        }
        res.end();
        return;
      }

      if (event.type === "error") {
        if (!res.headersSent) {
          res
            .status(500)
            .json({ error: event.error.errorMessage ?? "Proxy stream failed" });
        } else {
          // Headers already flushed — nothing left but to close the stream.
          res.end();
        }
        return;
      }
    }

    // Stream ended without an explicit "done"/"error" event.
    res.end();
  }
|
||||
|
||||
private async sendBufferedResponse(
|
||||
eventStream: AssistantMessageEventStream,
|
||||
res: express.Response,
|
||||
routeKind: RouteKind,
|
||||
model: Model<Api>,
|
||||
): Promise<void> {
|
||||
const assistantMessage = await eventStream.result();
|
||||
const text = this.extractText(assistantMessage);
|
||||
|
||||
if (routeKind === "openai") {
|
||||
res.json({
|
||||
id: `chatcmpl-${Date.now()}`,
|
||||
object: "chat.completion",
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
model: model.id,
|
||||
choices: [
|
||||
{
|
||||
index: 0,
|
||||
message: { role: "assistant", content: text },
|
||||
finish_reason: "stop",
|
||||
},
|
||||
],
|
||||
usage: assistantMessage.usage,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
res.json({
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [{ text }],
|
||||
},
|
||||
},
|
||||
],
|
||||
usageMetadata: assistantMessage.usage,
|
||||
});
|
||||
}
|
||||
|
||||
private extractText(message: AssistantMessage): string {
|
||||
return message.content
|
||||
.filter(
|
||||
(
|
||||
content,
|
||||
): content is Extract<
|
||||
AssistantMessage["content"][number],
|
||||
{ type: "text" }
|
||||
> => content.type === "text",
|
||||
)
|
||||
.map((content) => content.text)
|
||||
.join("");
|
||||
}
|
||||
|
||||
private buildOpenAiChunk(
|
||||
model: Model<Api>,
|
||||
delta: string,
|
||||
): Record<string, unknown> {
|
||||
return {
|
||||
id: `chatcmpl-${Date.now()}`,
|
||||
object: "chat.completion.chunk",
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
model: model.id,
|
||||
choices: [{ index: 0, delta: { content: delta }, finish_reason: null }],
|
||||
};
|
||||
}
|
||||
|
||||
private buildGoogleChunk(delta: string): Record<string, unknown> {
|
||||
return {
|
||||
candidates: [
|
||||
{
|
||||
content: {
|
||||
parts: [{ text: delta }],
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Factory for a {@link ProxyServer}; thin wrapper over the constructor so
 * callers do not depend on the class directly.
 */
export function createProxyServer(options: ProxyServerOptions): ProxyServer {
  return new ProxyServer(options);
}
|
||||
|
|
@ -1,32 +0,0 @@
|
|||
import assert from "node:assert/strict";
|
||||
import { readFileSync } from "node:fs";
|
||||
import { join } from "node:path";
|
||||
import { describe, it } from 'vitest';
|
||||
|
||||
// Location of the bundled genai-proxy extension sources under test.
const extensionDir = join("src", "resources", "extensions", "genai-proxy");

// Static checks over the extension's package.json / manifest files — read
// straight from disk, no extension runtime involved.
describe("genai-proxy package metadata", () => {
  it("declares the index.ts extension entrypoint", () => {
    const packageJson = JSON.parse(
      readFileSync(join(extensionDir, "package.json"), "utf-8"),
    ) as {
      pi?: { extensions?: string[] };
    };

    assert.deepEqual(packageJson.pi?.extensions, ["./index.ts"]);
  });

  it("declares a bundled extension manifest", () => {
    const manifest = JSON.parse(
      readFileSync(join(extensionDir, "extension-manifest.json"), "utf-8"),
    ) as {
      id: string;
      tier: string;
    };

    // Compare both fields in one deepEqual so a failure reports the pair.
    assert.deepEqual(
      { id: manifest.id, tier: manifest.tier },
      { id: "genai-proxy", tier: "bundled" },
    );
  });
});
|
||||
|
|
@ -1,53 +0,0 @@
|
|||
import assert from "node:assert/strict";
|
||||
import { describe, it } from 'vitest';
|
||||
import {
|
||||
installGenaiProxyExtension,
|
||||
resolveProxyPort,
|
||||
} from "../proxy-command.ts";
|
||||
|
||||
// Boundary tests for the extension's registration surface: port parsing
// and the flag/command hooks it installs at startup.
describe("genai-proxy command boundary", () => {
  it("resolves default and explicit proxy ports from flag values", () => {
    // `true` = flag given with no value (default port); "8080" = explicit.
    const result = [resolveProxyPort(true), resolveProxyPort("8080")];

    assert.deepEqual(result, [3000, 8080]);
  });

  it("registers the startup flag and slash command", () => {
    const registeredFlags: Array<{
      name: string;
      type: string;
      allowNoValue: boolean;
      hasStartupHandler: boolean;
    }> = [];
    const registeredCommands: string[] = [];

    // Drive installation against recording stubs instead of the real host.
    installGenaiProxyExtension({
      registerCommand: (name) => {
        registeredCommands.push(name);
      },
      registerFlag: (name, options) => {
        registeredFlags.push({
          name,
          type: options.type,
          allowNoValue: options.allowNoValue ?? false,
          hasStartupHandler: typeof options.onStartup === "function",
        });
      },
    });

    assert.deepEqual(
      { flags: registeredFlags, commands: registeredCommands },
      {
        flags: [
          {
            name: "gemini-cli-proxy",
            type: "string",
            allowNoValue: true,
            hasStartupHandler: true,
          },
        ],
        commands: ["genai-proxy"],
      },
    );
  });
});
|
||||
|
|
@ -1,248 +0,0 @@
|
|||
import assert from "node:assert/strict";
|
||||
import { afterEach, describe, it } from 'vitest';
|
||||
import type {
|
||||
Api,
|
||||
AssistantMessageEventStream,
|
||||
Model,
|
||||
} from "@singularity-forge/pi-ai";
|
||||
import { AuthStorage, ModelRegistry } from "@singularity-forge/pi-coding-agent";
|
||||
import { createProxyServer, type ProxyServer } from "../proxy-server.ts";
|
||||
|
||||
// Server started by the current test, if any. Stopped in afterEach so a
// failing assertion cannot leak a listening socket between tests.
let serverCleanup: ProxyServer | undefined;

afterEach(async () => {
  if (serverCleanup) {
    await serverCleanup.stop();
    serverCleanup = undefined;
  }
});
|
||||
|
||||
/**
 * Build a fake AssistantMessageEventStream that asynchronously emits
 * start → text_delta("hello") → done, for exercising the proxy without a
 * real provider. Events are queued via queueMicrotask so consumers attach
 * before anything is emitted.
 */
function createFakeStream(): AssistantMessageEventStream {
  const events: Array<
    | { type: "start"; partial: ReturnType<typeof buildAssistantMessage> }
    | {
        type: "text_delta";
        contentIndex: number;
        delta: string;
        partial: ReturnType<typeof buildAssistantMessage>;
      }
    | {
        type: "done";
        reason: "stop";
        message: ReturnType<typeof buildAssistantMessage>;
      }
  > = [];
  // Set once a "done" event arrives (or end() is called); unblocks result().
  let finalResult: ReturnType<typeof buildAssistantMessage> | undefined;
  let completed = false;

  const stream = {
    push(event: (typeof events)[number]) {
      events.push(event);
      if (event.type === "done") {
        completed = true;
        finalResult = event.message;
      }
    },
    end(): void {
      // Terminate without a "done" event — result() resolves to an empty message.
      completed = true;
      finalResult = buildAssistantMessage([]);
    },
    result(): Promise<ReturnType<typeof buildAssistantMessage>> {
      if (finalResult) {
        return Promise.resolve(finalResult);
      }
      // Poll until the final message lands; interval of 0 keeps latency low.
      return new Promise((resolve) => {
        const interval = setInterval(() => {
          if (finalResult) {
            clearInterval(interval);
            resolve(finalResult);
          }
        }, 0);
      });
    },
    async *[Symbol.asyncIterator](): AsyncIterator<(typeof events)[number]> {
      // Replay queued events in order, yielding to the event loop while
      // waiting for more; exits once completed and drained.
      let cursor = 0;
      while (!completed || cursor < events.length) {
        const event = events[cursor];
        if (event) {
          cursor++;
          yield event;
          continue;
        }

        await new Promise((resolve) => setTimeout(resolve, 0));
      }
    },
  } as unknown as AssistantMessageEventStream;

  queueMicrotask(() => {
    stream.push({
      type: "start",
      partial: buildAssistantMessage([]),
    });
    stream.push({
      type: "text_delta",
      contentIndex: 0,
      delta: "hello",
      partial: buildAssistantMessage([]),
    });
    stream.push({
      type: "done",
      reason: "stop",
      message: buildAssistantMessage([{ type: "text", text: "hello" }]),
    });
  });
  return stream;
}
|
||||
|
||||
function buildAssistantMessage(content: { type: "text"; text: string }[]) {
|
||||
return {
|
||||
role: "assistant" as const,
|
||||
content,
|
||||
api: "google-gemini-cli" as Api,
|
||||
provider: "google-gemini-cli",
|
||||
model: "proxy",
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop" as const,
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
// Registry backed by an in-memory auth store holding a fake OpenAI key,
// so getApiKey() succeeds for openai-provider models in tests.
function createRegistry() {
  const authStorage = AuthStorage.inMemory({
    openai: { type: "api_key", key: "sk-test" },
  });
  const modelRegistry = new ModelRegistry(authStorage, undefined);
  return { modelRegistry };
}
|
||||
|
||||
function createRegistryWithoutCredentials() {
|
||||
return {
|
||||
modelRegistry: new ModelRegistry(AuthStorage.inMemory({}), undefined),
|
||||
};
|
||||
}
|
||||
|
||||
// ProxyServer on an ephemeral port (port: 0) whose registry facade
// delegates to the given ModelRegistry and whose provider stream is the
// fake three-event stream above.
function createProxyServerForTests(modelRegistry: ModelRegistry): ProxyServer {
  return createProxyServer({
    port: 0,
    modelRegistry: {
      find: (provider, modelId) => modelRegistry.find(provider, modelId),
      getAll: () => modelRegistry.getAll(),
      getApiKey: (model) => modelRegistry.getApiKey(model),
    },
    streamModel: () => createFakeStream(),
  });
}
|
||||
|
||||
function findOpenAiModel(modelRegistry: ModelRegistry): Model<Api> {
|
||||
const model = modelRegistry
|
||||
.getAll()
|
||||
.find((candidate) => candidate.provider === "openai");
|
||||
if (!model) {
|
||||
throw new Error("Expected at least one openai model in the registry");
|
||||
}
|
||||
return model;
|
||||
}
|
||||
|
||||
// End-to-end HTTP tests: each test boots a real server on an ephemeral
// port (torn down in afterEach via serverCleanup) and talks to it with
// fetch. The provider stream is faked, so responses always carry "hello".
describe("ProxyServer", () => {
  it("serves model listings on /v1/models", async () => {
    const { modelRegistry } = createRegistry();
    const server = createProxyServerForTests(modelRegistry);
    serverCleanup = server;
    await server.start();

    const response = await fetch(
      `http://127.0.0.1:${server.getPort()}/v1/models`,
    );
    const data = (await response.json()) as {
      object: string;
      data: Array<{ object: string }>;
    };

    assert.deepEqual(
      { ok: response.ok, object: data.object, hasModels: data.data.length > 0 },
      { ok: true, object: "list", hasModels: true },
    );
  });

  it("serves OpenAI completions on /v1/chat/completions", async () => {
    const { modelRegistry } = createRegistry();
    const model = findOpenAiModel(modelRegistry);
    const server = createProxyServerForTests(modelRegistry);
    serverCleanup = server;
    await server.start();

    const response = await fetch(
      `http://127.0.0.1:${server.getPort()}/v1/chat/completions`,
      {
        method: "POST",
        headers: { "content-type": "application/json" },
        body: JSON.stringify({
          model: `${model.provider}/${model.id}`,
          messages: [{ role: "user", content: "hello" }],
        }),
      },
    );
    const data = (await response.json()) as {
      choices: Array<{ message: { content: string } }>;
    };

    // The fake stream always resolves to the single text "hello".
    assert.deepEqual(data.choices[0].message.content, "hello");
  });

  it("streams Google content on /v1beta/models/:modelId:streamGenerateContent", async () => {
    const { modelRegistry } = createRegistry();
    const model = findOpenAiModel(modelRegistry);
    const server = createProxyServerForTests(modelRegistry);
    serverCleanup = server;
    await server.start();

    const response = await fetch(
      `http://127.0.0.1:${server.getPort()}/v1beta/models/${encodeURIComponent(`${model.provider}/${model.id}`)}:streamGenerateContent`,
      {
        method: "POST",
        headers: { "content-type": "application/json" },
        body: JSON.stringify({
          contents: [{ role: "user", parts: [{ text: "hello" }] }],
        }),
      },
    );
    const text = await response.text();

    // Streamed NDJSON body should contain the fake delta somewhere.
    assert.deepEqual(text.includes("hello"), true);
  });

  it("returns 401 when credentials are absent", async () => {
    const { modelRegistry } = createRegistryWithoutCredentials();
    const model = modelRegistry
      .getAll()
      .find((candidate) => candidate.provider === "openai");
    if (!model) {
      throw new Error("Expected at least one openai model in the registry");
    }
    const server = createProxyServerForTests(modelRegistry);
    serverCleanup = server;
    await server.start();

    const response = await fetch(
      `http://127.0.0.1:${server.getPort()}/v1/chat/completions`,
      {
        method: "POST",
        headers: { "content-type": "application/json" },
        body: JSON.stringify({
          model: `${model.provider}/${model.id}`,
          messages: [{ role: "user", content: "hello" }],
        }),
      },
    );

    assert.deepEqual(response.status, 401);
  });
});
|
||||
|
|
@ -1,713 +0,0 @@
|
|||
/**
|
||||
* get-secrets-from-user — paged secure env var collection + apply
|
||||
*
|
||||
* Collects secrets one-per-page via masked TUI input, then writes them
|
||||
* to .env (local), Vercel, or Convex. No ctx.callTool, no external deps.
|
||||
* Uses Node fs/promises for file I/O and pi.exec() for CLI sinks.
|
||||
*/
|
||||
|
||||
import { existsSync, statSync } from "node:fs";
|
||||
import { readFile, writeFile } from "node:fs/promises";
|
||||
import { resolve } from "node:path";
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import type { ExtensionAPI, Theme } from "@singularity-forge/pi-coding-agent";
|
||||
import {
|
||||
Editor,
|
||||
type EditorTheme,
|
||||
Key,
|
||||
matchesKey,
|
||||
Text,
|
||||
truncateToWidth,
|
||||
wrapTextWithAnsi,
|
||||
} from "@singularity-forge/pi-tui";
|
||||
import { formatSecretsManifest, parseSecretsManifest } from "./sf/files.js";
|
||||
import { resolveMilestoneFile } from "./sf/paths.js";
|
||||
import type { SecretsManifestEntry } from "./sf/types.js";
|
||||
import { maskEditorLine, type ProgressStatus } from "./shared/mod.js";
|
||||
import { makeUI } from "./shared/tui.js";
|
||||
|
||||
// ─── Types ────────────────────────────────────────────────────────────────────
|
||||
|
||||
/** One secret gathered from the user during paged collection. */
interface CollectedSecret {
  // Environment variable name, e.g. "OPENAI_API_KEY".
  key: string;
  value: string | null; // null = skipped
}
|
||||
|
||||
/** Structured details reported back from the secrets tool invocation. */
interface ToolResultDetails {
  // Write destination — presumably "dotenv" | "vercel" | "convex"
  // (see detectDestination); typed loosely as string here.
  destination: string;
  // Deployment environment used for CLI sinks, when applicable.
  environment?: string;
  // Keys successfully written to the destination.
  applied: string[];
  // Keys the user did not provide a value for.
  skipped: string[];
  // Keys skipped because they were already present in the environment.
  existingSkipped?: string[];
  // Destination inferred automatically, when auto-detection ran.
  detectedDestination?: string;
}
|
||||
|
||||
// ─── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
function maskPreview(value: string): string {
|
||||
if (!value) return "";
|
||||
if (value.length <= 8) return "*".repeat(value.length);
|
||||
return `${value.slice(0, 4)}${"*".repeat(Math.max(4, value.length - 8))}${value.slice(-4)}`;
|
||||
}
|
||||
|
||||
function shellEscapeSingle(value: string): string {
|
||||
return `'${value.replace(/'/g, `'\\''`)}'`;
|
||||
}
|
||||
|
||||
function isSafeEnvVarKey(key: string): boolean {
|
||||
return /^[A-Za-z_][A-Za-z0-9_]*$/.test(key);
|
||||
}
|
||||
|
||||
function isSupportedDeploymentEnvironment(env: string): boolean {
|
||||
return env === "development" || env === "preview" || env === "production";
|
||||
}
|
||||
|
||||
/**
 * Mirror a freshly collected secret into the current process environment.
 */
function hydrateProcessEnv(key: string, value: string): void {
  // Make newly collected secrets immediately visible to the current session.
  // Some extensions read process.env directly and do not reload .env on every call.
  process.env[key] = value;
}
|
||||
|
||||
async function writeEnvKey(
|
||||
filePath: string,
|
||||
key: string,
|
||||
value: string,
|
||||
): Promise<void> {
|
||||
if (typeof value !== "string") {
|
||||
throw new TypeError(
|
||||
`writeEnvKey expects a string value for key "${key}", got ${typeof value}`,
|
||||
);
|
||||
}
|
||||
let content = "";
|
||||
try {
|
||||
content = await readFile(filePath, "utf8");
|
||||
} catch {
|
||||
content = "";
|
||||
}
|
||||
const escaped = value
|
||||
.replace(/\\/g, "\\\\")
|
||||
.replace(/\n/g, "\\n")
|
||||
.replace(/\r/g, "");
|
||||
const line = `${key}=${escaped}`;
|
||||
const regex = new RegExp(
|
||||
`^${key.replace(/[.*+?^${}()|[\]\\]/g, "\\$&")}\\s*=.*$`,
|
||||
"m",
|
||||
);
|
||||
if (regex.test(content)) {
|
||||
content = content.replace(regex, line);
|
||||
} else {
|
||||
if (content.length > 0 && !content.endsWith("\n")) content += "\n";
|
||||
content += `${line}\n`;
|
||||
}
|
||||
await writeFile(filePath, content, "utf8");
|
||||
}
|
||||
|
||||
// ─── Exported utilities ───────────────────────────────────────────────────────
|
||||
|
||||
// Re-export from env-utils.ts so existing consumers still work.
|
||||
// The implementation lives in env-utils.ts to avoid pulling @singularity-forge/pi-tui
|
||||
// into modules that only need env-checking (e.g. files.ts during reports).
|
||||
import { checkExistingEnvKeys } from "./sf/env-utils.js";
|
||||
|
||||
export { checkExistingEnvKeys };
|
||||
|
||||
/**
|
||||
* Detect the write destination based on project files in basePath.
|
||||
* Priority: vercel.json → convex/ dir → fallback "dotenv".
|
||||
*/
|
||||
export function detectDestination(
|
||||
basePath: string,
|
||||
): "dotenv" | "vercel" | "convex" {
|
||||
if (existsSync(resolve(basePath, "vercel.json"))) {
|
||||
return "vercel";
|
||||
}
|
||||
const convexPath = resolve(basePath, "convex");
|
||||
try {
|
||||
if (existsSync(convexPath) && statSync(convexPath).isDirectory()) {
|
||||
return "convex";
|
||||
}
|
||||
} catch {
|
||||
// stat error — treat as not found
|
||||
}
|
||||
return "dotenv";
|
||||
}
|
||||
|
||||
// ─── Paged secure input UI ────────────────────────────────────────────────────
|
||||
|
||||
/**
 * Show a single-key masked input page via ctx.ui.custom().
 * Returns the entered value, or null if skipped/cancelled.
 *
 * Falls back to ctx.ui.input() when the surface does not implement
 * ctx.ui.custom() (it returns undefined — e.g. RPC/web frontends);
 * returns null outright when no UI is available at all.
 */
async function collectOneSecret(
  ctx: { ui: any; hasUI: boolean },
  pageIndex: number,
  totalPages: number,
  keyName: string,
  hint: string | undefined,
  guidance?: string[],
): Promise<string | null> {
  if (!ctx.hasUI) return null;

  const customResult = await ctx.ui.custom(
    (tui: any, theme: any, _kb: any, done: (r: string | null) => void) => {
      let value = "";
      // Render output cache; cleared whenever input changes the editor.
      let cachedLines: string[] | undefined;

      const editorTheme: EditorTheme = {
        borderColor: (s: string) => theme.fg("accent", s),
        selectList: {
          selectedPrefix: (t: string) => theme.fg("accent", t),
          selectedText: (t: string) => theme.fg("accent", t),
          description: (t: string) => theme.fg("muted", t),
          scrollInfo: (t: string) => theme.fg("dim", t),
          noMatch: (t: string) => theme.fg("warning", t),
        },
      };
      const editor = new Editor(tui, editorTheme, { paddingX: 1 });

      // Invalidate the cached render and schedule a redraw.
      function refresh() {
        cachedLines = undefined;
        tui.requestRender();
      }

      // enter = confirm (empty confirms as skip), escape/ctrl+s = skip,
      // anything else goes to the editor.
      function handleInput(data: string) {
        if (matchesKey(data, Key.enter)) {
          value = editor.getText().trim();
          done(value.length > 0 ? value : null);
          return;
        }
        if (matchesKey(data, Key.escape)) {
          done(null);
          return;
        }
        // ctrl+s = skip this key
        if (data === "\x13") {
          done(null);
          return;
        }
        editor.handleInput(data);
        refresh();
      }

      function render(width: number): string[] {
        if (cachedLines) return cachedLines;
        const lines: string[] = [];
        const add = (s: string) => lines.push(truncateToWidth(s, width));

        add(theme.fg("accent", "─".repeat(width)));
        add(
          theme.fg(
            "dim",
            ` Page ${pageIndex + 1}/${totalPages} · Secure Env Setup`,
          ),
        );
        lines.push("");

        // Key name as big header
        add(theme.fg("accent", theme.bold(` ${keyName}`)));
        if (hint) {
          add(theme.fg("muted", ` ${hint}`));
        }

        // Guidance steps (numbered, dim, wrapped for long URLs)
        if (guidance && guidance.length > 0) {
          lines.push("");
          for (let g = 0; g < guidance.length; g++) {
            const prefix = ` ${g + 1}. `;
            const step = guidance[g] as string;
            const wrappedLines = wrapTextWithAnsi(step, width - 4);
            for (let w = 0; w < wrappedLines.length; w++) {
              // Continuation lines align under the numbered prefix.
              const indent = w === 0 ? prefix : " ".repeat(prefix.length);
              lines.push(theme.fg("dim", `${indent}${wrappedLines[w]}`));
            }
          }
        }

        lines.push("");

        // Masked preview
        const raw = editor.getText();
        const preview =
          raw.length > 0
            ? maskPreview(raw)
            : theme.fg("dim", "(empty — press enter to skip)");
        add(theme.fg("text", ` Preview: ${preview}`));
        lines.push("");

        // Editor
        add(theme.fg("muted", " Enter value:"));
        for (const line of editor.render(width - 2)) {
          // Mask the editor content so the secret never hits the screen.
          add(theme.fg("text", maskEditorLine(line)));
        }

        lines.push("");
        add(
          theme.fg(
            "dim",
            ` enter to confirm | ctrl+s or esc to skip | esc cancels`,
          ),
        );
        add(theme.fg("accent", "─".repeat(width)));

        cachedLines = lines;
        return lines;
      }

      return {
        render,
        invalidate: () => {
          cachedLines = undefined;
        },
        handleInput,
      };
    },
  );

  // RPC/web surfaces may not implement ctx.ui.custom(). Fall back to a
  // standard input prompt so users can still provide the secret.
  if (customResult !== undefined) {
    return customResult;
  }

  if (typeof ctx.ui?.input !== "function") {
    return null;
  }

  const inputTitle = `Secure value for ${keyName} (${pageIndex + 1}/${totalPages})`;
  const inputPlaceholder = hint || "Enter secret value";
  const inputResult = await ctx.ui.input(inputTitle, inputPlaceholder, {
    secure: true,
  });
  if (typeof inputResult !== "string") {
    return null;
  }
  // Whitespace-only input counts as a skip, matching the TUI path.
  const trimmed = inputResult.trim();
  return trimmed.length > 0 ? trimmed : null;
}
|
||||
|
||||
/**
 * Exported wrapper around collectOneSecret for testing.
 * Exposes the same interface with guidance parameter for test verification.
 * (An alias, not a copy — behavior is identical to collectOneSecret.)
 */
export const collectOneSecretWithGuidance = collectOneSecret;
|
||||
|
||||
// ─── Summary Screen ───────────────────────────────────────────────────────────
|
||||
|
||||
/**
 * Read-only summary screen showing all manifest entries with status indicators.
 * Follows the confirm-ui.ts pattern: render → any key → done.
 *
 * Status mapping:
 * - collected → done
 * - pending → pending
 * - skipped → skipped
 * - existing keys (in existingKeys) → done with "already set" annotation
 *
 * No-op when the surface has no UI.
 */
export async function showSecretsSummary(
  ctx: { ui: any; hasUI: boolean },
  entries: SecretsManifestEntry[],
  existingKeys: string[],
): Promise<void> {
  if (!ctx.hasUI) return;

  const existingSet = new Set(existingKeys);

  await ctx.ui.custom(
    (_tui: any, theme: Theme, _kb: any, done: (r: null) => void) => {
      // Render cache — the screen is static, so it is built once.
      let cachedLines: string[] | undefined;

      function handleInput(_data: string) {
        // Any key dismisses — pass null to satisfy the typed done() callback
        done(null);
      }

      function render(width: number): string[] {
        if (cachedLines) return cachedLines;

        const ui = makeUI(theme, width);
        const lines: string[] = [];
        // Flatten multi-line UI fragments into the output buffer.
        const push = (...rows: string[][]) => {
          for (const r of rows) lines.push(...r);
        };

        push(ui.bar());
        push(ui.blank());
        push(ui.header(" Secrets Summary"));
        push(ui.blank());

        for (const entry of entries) {
          let status: ProgressStatus;
          let detail: string | undefined;

          // Existing env keys win over the manifest's own status.
          if (existingSet.has(entry.key)) {
            status = "done";
            detail = "already set";
          } else if (entry.status === "collected") {
            status = "done";
          } else if (entry.status === "skipped") {
            status = "skipped";
          } else {
            status = "pending";
          }

          push(ui.progressItem(entry.key, status, { detail }));
        }

        push(ui.blank());
        push(ui.hints(["any key to continue"]));
        push(ui.bar());

        cachedLines = lines;
        return lines;
      }

      return {
        render,
        invalidate: () => {
          cachedLines = undefined;
        },
        handleInput,
      };
    },
  );
}
|
||||
|
||||
// ─── Destination Write Helper ─────────────────────────────────────────────────
|
||||
|
||||
/**
 * Apply collected secrets to the target destination.
 * Dotenv writes are handled directly; vercel/convex require pi.exec.
 *
 * @param provided key/value pairs to write.
 * @param destination where to write: local .env file or a CLI sink.
 * @param opts envFilePath for the dotenv case; environment + exec for the
 *             CLI sinks (when exec is absent, CLI sinks are silently not
 *             applied).
 * @returns applied key names plus per-key error strings; one key failing
 *          does not stop the rest.
 */
async function applySecrets(
  provided: Array<{ key: string; value: string }>,
  destination: "dotenv" | "vercel" | "convex",
  opts: {
    envFilePath: string;
    environment?: string;
    exec?: (
      cmd: string,
      args: string[],
    ) => Promise<{ code: number; stderr: string }>;
  },
): Promise<{ applied: string[]; errors: string[] }> {
  const applied: string[] = [];
  const errors: string[] = [];

  if (destination === "dotenv") {
    for (const { key, value } of provided) {
      try {
        await writeEnvKey(opts.envFilePath, key, value);
        applied.push(key);
        // Also expose the value to the current process immediately.
        hydrateProcessEnv(key, value);
      } catch (err: any) {
        errors.push(`${key}: ${err.message}`);
      }
    }
  }

  if ((destination === "vercel" || destination === "convex") && opts.exec) {
    const env = opts.environment ?? "development";
    if (!isSupportedDeploymentEnvironment(env)) {
      errors.push(`environment: unsupported target environment "${env}"`);
      return { applied, errors };
    }
    for (const { key, value } of provided) {
      // Keys are interpolated into shell commands — validate them first.
      if (!isSafeEnvVarKey(key)) {
        errors.push(`${key}: invalid environment variable name`);
        continue;
      }
      // vercel reads the value from stdin via printf (single-quote
      // escaped); convex takes the value as a plain argv argument.
      const cmd =
        destination === "vercel"
          ? `printf %s ${shellEscapeSingle(value)} | vercel env add ${key} ${env}`
          : "";
      try {
        const result =
          destination === "vercel"
            ? await opts.exec("sh", ["-c", cmd])
            : await opts.exec("npx", ["convex", "env", "set", key, value]);
        if (result.code !== 0) {
          // Truncate stderr so one noisy failure doesn't flood the report.
          errors.push(`${key}: ${result.stderr.slice(0, 200)}`);
        } else {
          applied.push(key);
          hydrateProcessEnv(key, value);
        }
      } catch (err: any) {
        errors.push(`${key}: ${err.message}`);
      }
    }
  }

  return { applied, errors };
}
|
||||
|
||||
// ─── Manifest Orchestrator ────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Full orchestrator: reads manifest, checks env, shows summary, collects
|
||||
* only pending keys (with guidance + hint), updates manifest statuses,
|
||||
* writes back, and applies collected values to the destination.
|
||||
*
|
||||
* Returns a structured result matching the tool result shape.
|
||||
*/
|
||||
export async function collectSecretsFromManifest(
|
||||
base: string,
|
||||
milestoneId: string,
|
||||
ctx: { ui: any; hasUI: boolean; cwd: string },
|
||||
): Promise<{
|
||||
applied: string[];
|
||||
skipped: string[];
|
||||
existingSkipped: string[];
|
||||
}> {
|
||||
// (a) Resolve manifest path
|
||||
const manifestPath = resolveMilestoneFile(base, milestoneId, "SECRETS");
|
||||
if (!manifestPath) {
|
||||
throw new Error(
|
||||
`Secrets manifest not found for milestone ${milestoneId} in ${base}`,
|
||||
);
|
||||
}
|
||||
|
||||
// (b) Read and parse manifest
|
||||
const content = await readFile(manifestPath, "utf8");
|
||||
const manifest = parseSecretsManifest(content);
|
||||
|
||||
// (c) Check existing keys
|
||||
const envPath = resolve(base, ".env");
|
||||
const allKeys = manifest.entries.map((e) => e.key);
|
||||
const existingKeys = await checkExistingEnvKeys(allKeys, envPath);
|
||||
const existingSet = new Set(existingKeys);
|
||||
|
||||
// (d) Build categorization
|
||||
const existingSkipped: string[] = [];
|
||||
const alreadySkipped: string[] = [];
|
||||
const pendingEntries: SecretsManifestEntry[] = [];
|
||||
|
||||
for (const entry of manifest.entries) {
|
||||
if (existingSet.has(entry.key)) {
|
||||
existingSkipped.push(entry.key);
|
||||
} else if (entry.status === "skipped") {
|
||||
alreadySkipped.push(entry.key);
|
||||
} else if (entry.status === "pending") {
|
||||
pendingEntries.push(entry);
|
||||
}
|
||||
// collected entries that are not in env are left as-is
|
||||
}
|
||||
|
||||
// (e) Show summary screen
|
||||
await showSecretsSummary(ctx, manifest.entries, existingKeys);
|
||||
|
||||
// (f) Detect destination
|
||||
const destination = detectDestination(ctx.cwd);
|
||||
|
||||
// (g) Collect only pending keys that are not already existing
|
||||
const collected: CollectedSecret[] = [];
|
||||
for (let i = 0; i < pendingEntries.length; i++) {
|
||||
const entry = pendingEntries[i] as SecretsManifestEntry;
|
||||
const value = await collectOneSecret(
|
||||
ctx,
|
||||
i,
|
||||
pendingEntries.length,
|
||||
entry.key,
|
||||
entry.formatHint || undefined,
|
||||
entry.guidance.length > 0 ? entry.guidance : undefined,
|
||||
);
|
||||
collected.push({ key: entry.key, value });
|
||||
}
|
||||
|
||||
// (h) Update manifest entry statuses
|
||||
for (const { key, value } of collected) {
|
||||
const entry = manifest.entries.find((e) => e.key === key);
|
||||
if (entry) {
|
||||
entry.status = value != null ? "collected" : "skipped";
|
||||
}
|
||||
}
|
||||
|
||||
// (i) Write manifest back to disk
|
||||
await writeFile(manifestPath, formatSecretsManifest(manifest), "utf8");
|
||||
|
||||
// (j) Apply collected values to destination
|
||||
const provided = collected.filter((c) => c.value != null) as Array<{
|
||||
key: string;
|
||||
value: string;
|
||||
}>;
|
||||
const { applied } = await applySecrets(provided, destination, {
|
||||
envFilePath: resolve(ctx.cwd, ".env"),
|
||||
});
|
||||
|
||||
const skipped = [
|
||||
...alreadySkipped,
|
||||
...collected.filter((c) => c.value == null).map((c) => c.key),
|
||||
];
|
||||
|
||||
return { applied, skipped, existingSkipped };
|
||||
}
|
||||
|
||||
// ─── Extension ────────────────────────────────────────────────────────────────
|
||||
|
||||
/**
 * Extension entry point: registers the `secure_env_collect` tool.
 *
 * The tool walks the user through a paged, masked-input flow (one env var
 * per page) and applies the collected values to .env, Vercel, or Convex
 * via applySecrets. Secret values never appear in tool output — only key
 * names and applied/skipped status are reported.
 */
export default function secureEnv(pi: ExtensionAPI) {
  pi.registerTool({
    name: "secure_env_collect",
    label: "Secure Env Collect",
    description:
      "Collect one or more env vars through a paged masked-input UI, then write them to .env, Vercel, or Convex. " +
      "Values are shown masked to the user (e.g. sk-ir***dgdh) and never echoed in tool output.",
    promptSnippet:
      "Collect and apply env vars securely without asking user to edit files manually.",
    // Guidance injected into the agent prompt so the model reaches for
    // this tool instead of asking the user to edit files by hand.
    promptGuidelines: [
      "NEVER ask the user to manually edit .env files, copy-paste into a terminal, or open a dashboard to set env vars. Always use secure_env_collect instead.",
      "When a command fails due to a missing env var (e.g. 'OPENAI_API_KEY is not set', 'Missing required environment variable', 'Invalid API key', 'authentication required'), immediately call secure_env_collect with the missing keys before retrying.",
      "When starting a new project or running setup steps that require secrets (API keys, tokens, database URLs), proactively call secure_env_collect before the first command that needs them.",
      "Detect the right destination: use 'dotenv' for local dev, 'vercel' when deploying to Vercel, 'convex' when using Convex backend.",
      "After secure_env_collect completes, re-run the originally blocked command to verify the fix worked.",
      "Never echo, log, or repeat secret values in your responses. Only report key names and applied/skipped status.",
    ],
    // TypeBox schema for tool-call parameters.
    parameters: Type.Object({
      // Optional; when omitted, execute() auto-detects via detectDestination.
      destination: Type.Optional(
        Type.Union(
          [
            Type.Literal("dotenv"),
            Type.Literal("vercel"),
            Type.Literal("convex"),
          ],
          { description: "Where to write the collected secrets" },
        ),
      ),
      // One entry per env var to collect; at least one is required.
      keys: Type.Array(
        Type.Object({
          key: Type.String({
            description: "Env var name, e.g. OPENAI_API_KEY",
          }),
          hint: Type.Optional(
            Type.String({
              description: "Format hint shown to user, e.g. 'starts with sk-'",
            }),
          ),
          required: Type.Optional(Type.Boolean()),
          guidance: Type.Optional(
            Type.Array(Type.String(), {
              description: "Step-by-step guidance for finding this key",
            }),
          ),
        }),
        { minItems: 1 },
      ),
      envFilePath: Type.Optional(
        Type.String({
          description:
            "Path to .env file (dotenv only). Defaults to .env in cwd.",
        }),
      ),
      environment: Type.Optional(
        Type.Union(
          [
            Type.Literal("development"),
            Type.Literal("preview"),
            Type.Literal("production"),
          ],
          { description: "Target environment (vercel only)" },
        ),
      ),
    }),

    // Runs the interactive collection flow; requires an interactive UI.
    async execute(_toolCallId, params, _signal, _onUpdate, ctx) {
      if (!ctx.hasUI) {
        return {
          content: [
            {
              type: "text",
              text: "Error: UI not available (interactive mode required for secure env collection).",
            },
          ],
          isError: true,
          details: undefined as unknown,
        };
      }

      // Auto-detect destination when not provided
      const destinationAutoDetected = params.destination == null;
      const destination = params.destination ?? detectDestination(ctx.cwd);

      const collected: CollectedSecret[] = [];

      // Collect one key per page
      for (let i = 0; i < params.keys.length; i++) {
        const item = params.keys[i];
        const value = await collectOneSecret(
          ctx,
          i,
          params.keys.length,
          item.key,
          item.hint,
          item.guidance,
        );
        collected.push({ key: item.key, value });
      }

      // Split into provided (user entered a value) and skipped (null value).
      const provided = collected.filter((c) => c.value != null) as Array<{
        key: string;
        value: string;
      }>;
      const skipped = collected
        .filter((c) => c.value == null)
        .map((c) => c.key);

      // Apply to destination via shared helper
      const { applied, errors } = await applySecrets(provided, destination, {
        envFilePath: resolve(ctx.cwd, params.envFilePath ?? ".env"),
        environment: params.environment,
        exec: (cmd, args) => pi.exec(cmd, args),
      });

      const details: ToolResultDetails = {
        destination,
        environment: params.environment,
        applied,
        skipped,
        ...(destinationAutoDetected
          ? { detectedDestination: destination }
          : {}),
      };

      // Text report contains only key names and statuses — never values.
      const lines = [
        `destination: ${destination}${destinationAutoDetected ? " (auto-detected)" : ""}${params.environment ? ` (${params.environment})` : ""}`,
        ...applied.map((k) => `✓ ${k}: applied`),
        ...skipped.map((k) => `• ${k}: skipped`),
        ...errors.map((e) => `✗ ${e}`),
      ];

      return {
        content: [{ type: "text", text: lines.join("\n") }],
        details,
        // Error only when nothing at all was applied.
        isError: errors.length > 0 && applied.length === 0,
      };
    },

    // One-line summary of the pending call for the transcript UI.
    renderCall(args, theme) {
      const count = Array.isArray(args.keys) ? args.keys.length : 0;
      return new Text(
        theme.fg("toolTitle", theme.bold("secure_env_collect ")) +
          theme.fg("muted", `→ ${args.destination ?? "auto"}`) +
          theme.fg("dim", ` ${count} key${count !== 1 ? "s" : ""}`),
        0,
        0,
      );
    },

    // Per-key applied/skipped summary; falls back to raw text when the
    // result carries no structured details (e.g. the no-UI error path).
    renderResult(result, _options, theme) {
      const details = result.details as ToolResultDetails | undefined;
      if (!details) {
        const t = result.content[0];
        return new Text(t?.type === "text" ? t.text : "", 0, 0);
      }
      const lines = [
        `${theme.fg("success", "✓")} ${details.destination}${details.environment ? ` (${details.environment})` : ""}`,
        ...details.applied.map(
          (k) => `  ${theme.fg("success", "✓")} ${k}: applied`,
        ),
        ...details.skipped.map(
          (k) => `  ${theme.fg("warning", "•")} ${k}: skipped`,
        ),
      ];
      return new Text(lines.join("\n"), 0, 0);
    },
  });
}
|
||||
|
|
@ -1,460 +0,0 @@
|
|||
/**
|
||||
* Thin wrapper around the `gh` CLI.
|
||||
*
|
||||
* Every public function returns `GhResult<T>` — never throws.
|
||||
* Uses `execFileSync` (not `execSync`) for safety.
|
||||
*/
|
||||
|
||||
import { execFileSync } from "node:child_process";
|
||||
|
||||
// ─── Result Type ────────────────────────────────────────────────────────────
|
||||
|
||||
export interface GhResult<T> {
|
||||
ok: boolean;
|
||||
data?: T;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
function ok<T>(data: T): GhResult<T> {
|
||||
return { ok: true, data };
|
||||
}
|
||||
|
||||
function fail<T>(error: string): GhResult<T> {
|
||||
return { ok: false, error };
|
||||
}
|
||||
|
||||
// ─── gh Availability ────────────────────────────────────────────────────────
|
||||
|
||||
let _ghAvailable: boolean | null = null;
|
||||
|
||||
export function ghIsAvailable(): boolean {
|
||||
if (_ghAvailable !== null) return _ghAvailable;
|
||||
try {
|
||||
execFileSync("gh", ["--version"], {
|
||||
encoding: "utf-8",
|
||||
stdio: ["ignore", "pipe", "ignore"],
|
||||
timeout: 5_000,
|
||||
});
|
||||
_ghAvailable = true;
|
||||
} catch {
|
||||
_ghAvailable = false;
|
||||
}
|
||||
return _ghAvailable;
|
||||
}
|
||||
|
||||
/** Reset the cached `gh` availability so the next ghIsAvailable() re-probes (for testing). */
export function _resetGhCache(): void {
  _ghAvailable = null;
}
|
||||
|
||||
// ─── Rate Limit Check ───────────────────────────────────────────────────────
|
||||
|
||||
let _rateLimitCheckedAt = 0;
|
||||
let _rateLimitOk = true;
|
||||
const RATE_LIMIT_CHECK_INTERVAL_MS = 300_000; // 5 minutes
|
||||
|
||||
export function ghHasRateLimit(cwd: string): boolean {
|
||||
const now = Date.now();
|
||||
if (now - _rateLimitCheckedAt < RATE_LIMIT_CHECK_INTERVAL_MS)
|
||||
return _rateLimitOk;
|
||||
_rateLimitCheckedAt = now;
|
||||
try {
|
||||
const raw = execFileSync(
|
||||
"gh",
|
||||
["api", "rate_limit", "--jq", ".rate.remaining"],
|
||||
{
|
||||
cwd,
|
||||
encoding: "utf-8",
|
||||
stdio: ["ignore", "pipe", "ignore"],
|
||||
timeout: 10_000,
|
||||
},
|
||||
).trim();
|
||||
const remaining = parseInt(raw, 10);
|
||||
_rateLimitOk = Number.isFinite(remaining) && remaining >= 100;
|
||||
} catch {
|
||||
// Can't check — assume OK so we don't silently disable sync
|
||||
_rateLimitOk = true;
|
||||
}
|
||||
return _rateLimitOk;
|
||||
}
|
||||
|
||||
// ─── Helpers ────────────────────────────────────────────────────────────────
|
||||
|
||||
const GH_TIMEOUT = 15_000;
|
||||
const MAX_BODY_LENGTH = 65_000;
|
||||
|
||||
function truncateBody(body: string): string {
|
||||
if (body.length <= MAX_BODY_LENGTH) return body;
|
||||
return (
|
||||
body.slice(0, MAX_BODY_LENGTH) +
|
||||
"\n\n---\n*Body truncated (exceeded 65K characters)*"
|
||||
);
|
||||
}
|
||||
|
||||
function runGh(args: string[], cwd: string): GhResult<string> {
|
||||
try {
|
||||
const stdout = execFileSync("gh", args, {
|
||||
cwd,
|
||||
encoding: "utf-8",
|
||||
stdio: ["ignore", "pipe", "pipe"],
|
||||
timeout: GH_TIMEOUT,
|
||||
}).trim();
|
||||
return ok(stdout);
|
||||
} catch (err) {
|
||||
const msg = err instanceof Error ? err.message : String(err);
|
||||
return fail(msg);
|
||||
}
|
||||
}
|
||||
|
||||
function runGhJson<T>(args: string[], cwd: string): GhResult<T> {
|
||||
const result = runGh(args, cwd);
|
||||
if (!result.ok) return fail(result.error!);
|
||||
try {
|
||||
return ok(JSON.parse(result.data!) as T);
|
||||
} catch {
|
||||
return fail(`Failed to parse JSON: ${result.data}`);
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Repo Detection ─────────────────────────────────────────────────────────
|
||||
|
||||
export function ghDetectRepo(cwd: string): GhResult<string> {
|
||||
const result = runGh(
|
||||
["repo", "view", "--json", "nameWithOwner", "--jq", ".nameWithOwner"],
|
||||
cwd,
|
||||
);
|
||||
if (!result.ok) return fail(result.error!);
|
||||
const repo = result.data!.trim();
|
||||
if (!repo || !repo.includes("/")) return fail("Could not detect repo");
|
||||
return ok(repo);
|
||||
}
|
||||
|
||||
// ─── Issues ─────────────────────────────────────────────────────────────────
|
||||
|
||||
/** Options for ghCreateIssue. */
export interface CreateIssueOpts {
  repo: string; // "owner/name"
  title: string;
  body: string; // truncated to 65K characters before sending
  labels?: string[]; // joined with commas for --label
  milestone?: number; // GitHub milestone number
  parentIssue?: number; // when set, the new issue is linked as a sub-issue
}
|
||||
|
||||
export function ghCreateIssue(
|
||||
cwd: string,
|
||||
opts: CreateIssueOpts,
|
||||
): GhResult<number> {
|
||||
const args = [
|
||||
"issue",
|
||||
"create",
|
||||
"--repo",
|
||||
opts.repo,
|
||||
"--title",
|
||||
opts.title,
|
||||
"--body",
|
||||
truncateBody(opts.body),
|
||||
];
|
||||
if (opts.labels?.length) {
|
||||
args.push("--label", opts.labels.join(","));
|
||||
}
|
||||
if (opts.milestone) {
|
||||
args.push("--milestone", String(opts.milestone));
|
||||
}
|
||||
|
||||
const result = runGh(args, cwd);
|
||||
if (!result.ok) return fail(result.error!);
|
||||
|
||||
// gh issue create returns the URL; extract issue number
|
||||
const match = result.data!.match(/\/issues\/(\d+)/);
|
||||
if (!match) return fail(`Could not parse issue number from: ${result.data}`);
|
||||
const issueNumber = parseInt(match[1], 10);
|
||||
|
||||
// If parent specified, add as sub-issue via GraphQL
|
||||
if (opts.parentIssue) {
|
||||
ghAddSubIssue(cwd, opts.repo, opts.parentIssue, issueNumber);
|
||||
}
|
||||
|
||||
return ok(issueNumber);
|
||||
}
|
||||
|
||||
export function ghCloseIssue(
|
||||
cwd: string,
|
||||
repo: string,
|
||||
issueNumber: number,
|
||||
comment?: string,
|
||||
): GhResult<void> {
|
||||
if (comment) {
|
||||
ghAddComment(cwd, repo, issueNumber, comment);
|
||||
}
|
||||
const result = runGh(
|
||||
["issue", "close", String(issueNumber), "--repo", repo],
|
||||
cwd,
|
||||
);
|
||||
if (!result.ok) return fail(result.error!);
|
||||
return ok(undefined);
|
||||
}
|
||||
|
||||
export function ghAddComment(
|
||||
cwd: string,
|
||||
repo: string,
|
||||
issueNumber: number,
|
||||
body: string,
|
||||
): GhResult<void> {
|
||||
const result = runGh(
|
||||
[
|
||||
"issue",
|
||||
"comment",
|
||||
String(issueNumber),
|
||||
"--repo",
|
||||
repo,
|
||||
"--body",
|
||||
truncateBody(body),
|
||||
],
|
||||
cwd,
|
||||
);
|
||||
if (!result.ok) return fail(result.error!);
|
||||
return ok(undefined);
|
||||
}
|
||||
|
||||
// ─── Sub-Issues (GraphQL) ───────────────────────────────────────────────────
|
||||
|
||||
function ghAddSubIssue(
|
||||
cwd: string,
|
||||
repo: string,
|
||||
parentNumber: number,
|
||||
childNumber: number,
|
||||
): GhResult<void> {
|
||||
// Get node IDs for both issues
|
||||
const parentResult = runGhJson<{ id: string }>(
|
||||
["api", `repos/${repo}/issues/${parentNumber}`, "--jq", "{id: .node_id}"],
|
||||
cwd,
|
||||
);
|
||||
const childResult = runGhJson<{ id: string }>(
|
||||
["api", `repos/${repo}/issues/${childNumber}`, "--jq", "{id: .node_id}"],
|
||||
cwd,
|
||||
);
|
||||
|
||||
if (!parentResult.ok || !childResult.ok) {
|
||||
return fail("Could not resolve issue node IDs for sub-issue linking");
|
||||
}
|
||||
|
||||
const mutation = `mutation { addSubIssue(input: { issueId: "${parentResult.data!.id}", subIssueId: "${childResult.data!.id}" }) { issue { id } } }`;
|
||||
return runGh(
|
||||
["api", "graphql", "-f", `query=${mutation}`],
|
||||
cwd,
|
||||
) as unknown as GhResult<void>;
|
||||
}
|
||||
|
||||
// ─── Milestones ─────────────────────────────────────────────────────────────
|
||||
|
||||
export function ghCreateMilestone(
|
||||
cwd: string,
|
||||
repo: string,
|
||||
title: string,
|
||||
description: string,
|
||||
): GhResult<number> {
|
||||
const result = runGhJson<{ number: number }>(
|
||||
[
|
||||
"api",
|
||||
`repos/${repo}/milestones`,
|
||||
"-X",
|
||||
"POST",
|
||||
"-f",
|
||||
`title=${title}`,
|
||||
"-f",
|
||||
`description=${truncateBody(description)}`,
|
||||
"-f",
|
||||
"state=open",
|
||||
"--jq",
|
||||
"{number: .number}",
|
||||
],
|
||||
cwd,
|
||||
);
|
||||
if (!result.ok) return fail(result.error!);
|
||||
return ok(result.data!.number);
|
||||
}
|
||||
|
||||
export function ghCloseMilestone(
|
||||
cwd: string,
|
||||
repo: string,
|
||||
milestoneNumber: number,
|
||||
): GhResult<void> {
|
||||
const result = runGh(
|
||||
[
|
||||
"api",
|
||||
`repos/${repo}/milestones/${milestoneNumber}`,
|
||||
"-X",
|
||||
"PATCH",
|
||||
"-f",
|
||||
"state=closed",
|
||||
],
|
||||
cwd,
|
||||
);
|
||||
if (!result.ok) return fail(result.error!);
|
||||
return ok(undefined);
|
||||
}
|
||||
|
||||
// ─── Pull Requests ──────────────────────────────────────────────────────────
|
||||
|
||||
/** Options for ghCreatePR. */
export interface CreatePROpts {
  repo: string; // "owner/name"
  base: string; // branch to merge into
  head: string; // branch containing the changes
  title: string;
  body: string; // truncated to 65K characters before sending
  draft?: boolean; // open as a draft PR
}
|
||||
|
||||
export function ghCreatePR(cwd: string, opts: CreatePROpts): GhResult<number> {
|
||||
const args = [
|
||||
"pr",
|
||||
"create",
|
||||
"--repo",
|
||||
opts.repo,
|
||||
"--base",
|
||||
opts.base,
|
||||
"--head",
|
||||
opts.head,
|
||||
"--title",
|
||||
opts.title,
|
||||
"--body",
|
||||
truncateBody(opts.body),
|
||||
];
|
||||
if (opts.draft) args.push("--draft");
|
||||
|
||||
const result = runGh(args, cwd);
|
||||
if (!result.ok) return fail(result.error!);
|
||||
|
||||
const match = result.data!.match(/\/pull\/(\d+)/);
|
||||
if (!match) return fail(`Could not parse PR number from: ${result.data}`);
|
||||
return ok(parseInt(match[1], 10));
|
||||
}
|
||||
|
||||
export function ghMarkPRReady(
|
||||
cwd: string,
|
||||
repo: string,
|
||||
prNumber: number,
|
||||
): GhResult<void> {
|
||||
const result = runGh(["pr", "ready", String(prNumber), "--repo", repo], cwd);
|
||||
if (!result.ok) return fail(result.error!);
|
||||
return ok(undefined);
|
||||
}
|
||||
|
||||
export function ghMergePR(
|
||||
cwd: string,
|
||||
repo: string,
|
||||
prNumber: number,
|
||||
strategy: "squash" | "merge" = "squash",
|
||||
): GhResult<void> {
|
||||
const args = [
|
||||
"pr",
|
||||
"merge",
|
||||
String(prNumber),
|
||||
"--repo",
|
||||
repo,
|
||||
strategy === "squash" ? "--squash" : "--merge",
|
||||
"--delete-branch",
|
||||
];
|
||||
const result = runGh(args, cwd);
|
||||
if (!result.ok) return fail(result.error!);
|
||||
return ok(undefined);
|
||||
}
|
||||
|
||||
// ─── Projects v2 ────────────────────────────────────────────────────────────
|
||||
|
||||
export function ghAddToProject(
|
||||
cwd: string,
|
||||
repo: string,
|
||||
projectNumber: number,
|
||||
issueNumber: number,
|
||||
): GhResult<void> {
|
||||
// Get the issue's node ID first
|
||||
const issueResult = runGhJson<{ id: string }>(
|
||||
["api", `repos/${repo}/issues/${issueNumber}`, "--jq", "{id: .node_id}"],
|
||||
cwd,
|
||||
);
|
||||
if (!issueResult.ok) return fail(issueResult.error!);
|
||||
|
||||
// Get the project's node ID
|
||||
const [owner] = repo.split("/");
|
||||
const projectResult = runGhJson<{ id: string }>(
|
||||
[
|
||||
"api",
|
||||
"graphql",
|
||||
"-f",
|
||||
`query=query { user(login: "${owner}") { projectV2(number: ${projectNumber}) { id } } }`,
|
||||
"--jq",
|
||||
".data.user.projectV2.id",
|
||||
],
|
||||
cwd,
|
||||
);
|
||||
|
||||
// Try org if user fails
|
||||
let projectId: string | undefined;
|
||||
if (projectResult.ok && projectResult.data?.id) {
|
||||
projectId = projectResult.data.id;
|
||||
} else {
|
||||
const orgResult = runGhJson<{ id: string }>(
|
||||
[
|
||||
"api",
|
||||
"graphql",
|
||||
"-f",
|
||||
`query=query { organization(login: "${owner}") { projectV2(number: ${projectNumber}) { id } } }`,
|
||||
"--jq",
|
||||
".data.organization.projectV2.id",
|
||||
],
|
||||
cwd,
|
||||
);
|
||||
if (orgResult.ok) projectId = orgResult.data?.id;
|
||||
}
|
||||
|
||||
if (!projectId) return fail("Could not find project");
|
||||
|
||||
const mutation = `mutation { addProjectV2ItemById(input: { projectId: "${projectId}", contentId: "${issueResult.data!.id}" }) { item { id } } }`;
|
||||
return runGh(
|
||||
["api", "graphql", "-f", `query=${mutation}`],
|
||||
cwd,
|
||||
) as unknown as GhResult<void>;
|
||||
}
|
||||
|
||||
// ─── Branch Operations ──────────────────────────────────────────────────────
|
||||
|
||||
export function ghPushBranch(
|
||||
cwd: string,
|
||||
branch: string,
|
||||
setUpstream = true,
|
||||
): GhResult<void> {
|
||||
const args = ["git", "push"];
|
||||
if (setUpstream) args.push("-u", "origin", branch);
|
||||
else args.push("origin", branch);
|
||||
|
||||
try {
|
||||
execFileSync(args[0], args.slice(1), {
|
||||
cwd,
|
||||
encoding: "utf-8",
|
||||
stdio: ["ignore", "pipe", "pipe"],
|
||||
timeout: 30_000,
|
||||
});
|
||||
return ok(undefined);
|
||||
} catch (err) {
|
||||
return fail(err instanceof Error ? err.message : String(err));
|
||||
}
|
||||
}
|
||||
|
||||
export function ghCreateBranch(
|
||||
cwd: string,
|
||||
branch: string,
|
||||
from: string,
|
||||
): GhResult<void> {
|
||||
try {
|
||||
execFileSync("git", ["branch", branch, from], {
|
||||
cwd,
|
||||
encoding: "utf-8",
|
||||
stdio: ["ignore", "pipe", "pipe"],
|
||||
timeout: 10_000,
|
||||
});
|
||||
return ok(undefined);
|
||||
} catch (err) {
|
||||
return fail(err instanceof Error ? err.message : String(err));
|
||||
}
|
||||
}
|
||||
|
|
@ -1,112 +0,0 @@
|
|||
/**
|
||||
* GitHub Sync extension for SF.
|
||||
*
|
||||
* Opt-in extension that syncs SF lifecycle events to GitHub:
|
||||
* milestones → GH Milestones + tracking issues, slices → draft PRs,
|
||||
* tasks → sub-issues with auto-close on commit.
|
||||
*
|
||||
* Integration happens via a single dynamic import in auto-post-unit.ts.
|
||||
* This index registers a `/github-sync` command for manual bootstrap
|
||||
* and status display.
|
||||
*/
|
||||
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import { ghIsAvailable } from "./cli.js";
|
||||
import { loadSyncMapping } from "./mapping.js";
|
||||
import { bootstrapSync } from "./sync.js";
|
||||
|
||||
export default function (pi: ExtensionAPI) {
|
||||
pi.registerCommand("github-sync", {
|
||||
description: "Bootstrap GitHub sync or show sync status",
|
||||
handler: async (args: string, ctx) => {
|
||||
const subcommand = args.trim().toLowerCase();
|
||||
|
||||
if (subcommand === "status") {
|
||||
await showStatus(ctx);
|
||||
return;
|
||||
}
|
||||
|
||||
if (subcommand === "bootstrap" || subcommand === "") {
|
||||
await runBootstrap(ctx);
|
||||
return;
|
||||
}
|
||||
|
||||
ctx.ui.notify("Usage: /github-sync [bootstrap|status]", "info");
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
async function showStatus(
|
||||
ctx: import("@singularity-forge/pi-coding-agent").ExtensionCommandContext,
|
||||
) {
|
||||
if (!ghIsAvailable()) {
|
||||
ctx.ui.notify(
|
||||
"GitHub sync: `gh` CLI not installed or not authenticated.",
|
||||
"warning",
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const mapping = loadSyncMapping(ctx.cwd);
|
||||
if (!mapping) {
|
||||
ctx.ui.notify(
|
||||
"GitHub sync: No sync mapping found. Run `/github-sync bootstrap` to initialize.",
|
||||
"info",
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const milestoneCount = Object.keys(mapping.milestones).length;
|
||||
const sliceCount = Object.keys(mapping.slices).length;
|
||||
const taskCount = Object.keys(mapping.tasks).length;
|
||||
const openMilestones = Object.values(mapping.milestones).filter(
|
||||
(m) => m.state === "open",
|
||||
).length;
|
||||
const openSlices = Object.values(mapping.slices).filter(
|
||||
(s) => s.state === "open",
|
||||
).length;
|
||||
const openTasks = Object.values(mapping.tasks).filter(
|
||||
(t) => t.state === "open",
|
||||
).length;
|
||||
|
||||
ctx.ui.notify(
|
||||
[
|
||||
`GitHub sync: repo=${mapping.repo}`,
|
||||
` Milestones: ${milestoneCount} (${openMilestones} open)`,
|
||||
` Slices: ${sliceCount} (${openSlices} open)`,
|
||||
` Tasks: ${taskCount} (${openTasks} open)`,
|
||||
].join("\n"),
|
||||
"info",
|
||||
);
|
||||
}
|
||||
|
||||
async function runBootstrap(
|
||||
ctx: import("@singularity-forge/pi-coding-agent").ExtensionCommandContext,
|
||||
) {
|
||||
if (!ghIsAvailable()) {
|
||||
ctx.ui.notify(
|
||||
"GitHub sync: `gh` CLI not installed or not authenticated.",
|
||||
"warning",
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
ctx.ui.notify("GitHub sync: bootstrapping...", "info");
|
||||
|
||||
try {
|
||||
const counts = await bootstrapSync(ctx.cwd);
|
||||
if (counts.milestones === 0 && counts.slices === 0 && counts.tasks === 0) {
|
||||
ctx.ui.notify(
|
||||
"GitHub sync: everything already synced (or no milestones found).",
|
||||
"info",
|
||||
);
|
||||
} else {
|
||||
ctx.ui.notify(
|
||||
`GitHub sync: created ${counts.milestones} milestone(s), ${counts.slices} slice(s), ${counts.tasks} task(s).`,
|
||||
"info",
|
||||
);
|
||||
}
|
||||
} catch (err) {
|
||||
ctx.ui.notify(`GitHub sync bootstrap failed: ${err}`, "error");
|
||||
}
|
||||
}
|
||||
|
|
@ -1,118 +0,0 @@
|
|||
/**
|
||||
* Persistence layer for the GitHub sync mapping.
|
||||
*
|
||||
* The mapping lives at `.sf/github-sync.json` and tracks which SF
|
||||
* entities have been synced to which GitHub entities (issues, PRs,
|
||||
* milestones) along with their numbers and sync timestamps.
|
||||
*/
|
||||
|
||||
import { existsSync, readFileSync } from "node:fs";
|
||||
import { join } from "node:path";
|
||||
import { atomicWriteSync } from "../sf/atomic-write.js";
|
||||
import type {
|
||||
MilestoneSyncRecord,
|
||||
SliceSyncRecord,
|
||||
SyncEntityRecord,
|
||||
SyncMapping,
|
||||
} from "./types.js";
|
||||
|
||||
const MAPPING_FILENAME = "github-sync.json";
|
||||
|
||||
function mappingPath(basePath: string): string {
|
||||
return join(basePath, ".sf", MAPPING_FILENAME);
|
||||
}
|
||||
|
||||
// ─── Load / Save ────────────────────────────────────────────────────────────
|
||||
|
||||
export function loadSyncMapping(basePath: string): SyncMapping | null {
|
||||
const path = mappingPath(basePath);
|
||||
if (!existsSync(path)) return null;
|
||||
try {
|
||||
const raw = readFileSync(path, "utf-8");
|
||||
const parsed = JSON.parse(raw);
|
||||
if (parsed?.version !== 1) return null;
|
||||
return parsed as SyncMapping;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export function saveSyncMapping(basePath: string, mapping: SyncMapping): void {
|
||||
const path = mappingPath(basePath);
|
||||
atomicWriteSync(path, JSON.stringify(mapping, null, 2) + "\n");
|
||||
}
|
||||
|
||||
export function createEmptyMapping(repo: string): SyncMapping {
|
||||
return {
|
||||
version: 1,
|
||||
repo,
|
||||
milestones: {},
|
||||
slices: {},
|
||||
tasks: {},
|
||||
};
|
||||
}
|
||||
|
||||
// ─── Accessors ──────────────────────────────────────────────────────────────
|
||||
|
||||
export function getMilestoneRecord(
|
||||
mapping: SyncMapping,
|
||||
mid: string,
|
||||
): MilestoneSyncRecord | null {
|
||||
return mapping.milestones[mid] ?? null;
|
||||
}
|
||||
|
||||
export function getSliceRecord(
|
||||
mapping: SyncMapping,
|
||||
mid: string,
|
||||
sid: string,
|
||||
): SliceSyncRecord | null {
|
||||
return mapping.slices[`${mid}/${sid}`] ?? null;
|
||||
}
|
||||
|
||||
export function getTaskRecord(
|
||||
mapping: SyncMapping,
|
||||
mid: string,
|
||||
sid: string,
|
||||
tid: string,
|
||||
): SyncEntityRecord | null {
|
||||
return mapping.tasks[`${mid}/${sid}/${tid}`] ?? null;
|
||||
}
|
||||
|
||||
export function getTaskIssueNumber(
|
||||
mapping: SyncMapping,
|
||||
mid: string,
|
||||
sid: string,
|
||||
tid: string,
|
||||
): number | null {
|
||||
const record = getTaskRecord(mapping, mid, sid, tid);
|
||||
return record?.issueNumber ?? null;
|
||||
}
|
||||
|
||||
// ─── Mutators ───────────────────────────────────────────────────────────────
|
||||
|
||||
export function setMilestoneRecord(
|
||||
mapping: SyncMapping,
|
||||
mid: string,
|
||||
record: MilestoneSyncRecord,
|
||||
): void {
|
||||
mapping.milestones[mid] = record;
|
||||
}
|
||||
|
||||
export function setSliceRecord(
|
||||
mapping: SyncMapping,
|
||||
mid: string,
|
||||
sid: string,
|
||||
record: SliceSyncRecord,
|
||||
): void {
|
||||
mapping.slices[`${mid}/${sid}`] = record;
|
||||
}
|
||||
|
||||
export function setTaskRecord(
|
||||
mapping: SyncMapping,
|
||||
mid: string,
|
||||
sid: string,
|
||||
tid: string,
|
||||
record: SyncEntityRecord,
|
||||
): void {
|
||||
mapping.tasks[`${mid}/${sid}/${tid}`] = record;
|
||||
}
|
||||
|
|
@ -1,602 +0,0 @@
|
|||
/**
|
||||
* Core GitHub sync engine.
|
||||
*
|
||||
* Entry point: `runGitHubSync()` — called from the SF post-unit pipeline.
|
||||
* Routes to per-event sync functions based on the unit type, reads SF
|
||||
* files to build GitHub entities, and persists the sync mapping.
|
||||
*
|
||||
* All errors are caught internally — sync failures never block execution.
|
||||
*/
|
||||
|
||||
import { existsSync, readdirSync } from "node:fs";
|
||||
import { join } from "node:path";
|
||||
import { debugLog } from "../sf/debug-logger.js";
|
||||
import { loadFile, parseSummary } from "../sf/files.js";
|
||||
import { parsePlan, parseRoadmap } from "../sf/parsers.js";
|
||||
import {
|
||||
resolveMilestoneFile,
|
||||
resolveSliceFile,
|
||||
resolveTaskFile,
|
||||
} from "../sf/paths.js";
|
||||
import { loadEffectiveSFPreferences } from "../sf/preferences.js";
|
||||
import {
|
||||
ghAddComment,
|
||||
ghAddToProject,
|
||||
ghCloseIssue,
|
||||
ghCloseMilestone,
|
||||
ghCreateBranch,
|
||||
ghCreateIssue,
|
||||
ghCreateMilestone,
|
||||
ghCreatePR,
|
||||
ghDetectRepo,
|
||||
ghHasRateLimit,
|
||||
ghIsAvailable,
|
||||
ghMarkPRReady,
|
||||
ghMergePR,
|
||||
ghPushBranch,
|
||||
} from "./cli.js";
|
||||
import {
|
||||
createEmptyMapping,
|
||||
getMilestoneRecord,
|
||||
getSliceRecord,
|
||||
getTaskRecord,
|
||||
loadSyncMapping,
|
||||
saveSyncMapping,
|
||||
setMilestoneRecord,
|
||||
setSliceRecord,
|
||||
setTaskRecord,
|
||||
} from "./mapping.js";
|
||||
import {
|
||||
formatMilestoneIssueBody,
|
||||
formatSlicePRBody,
|
||||
formatSummaryComment,
|
||||
formatTaskIssueBody,
|
||||
} from "./templates.js";
|
||||
import type { GitHubSyncConfig, SyncMapping } from "./types.js";
|
||||
|
||||
// ─── Entry Point ────────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Main sync entry point — called from SF post-unit pipeline.
|
||||
* Routes to the appropriate sync function based on unit type.
|
||||
*/
|
||||
export async function runGitHubSync(
|
||||
basePath: string,
|
||||
unitType: string,
|
||||
unitId: string,
|
||||
): Promise<void> {
|
||||
try {
|
||||
const config = loadGitHubSyncConfig(basePath);
|
||||
if (!config?.enabled) return;
|
||||
if (!ghIsAvailable()) {
|
||||
debugLog("github-sync", { skip: "gh CLI not available" });
|
||||
return;
|
||||
}
|
||||
|
||||
// Resolve repo
|
||||
const repo = config.repo ?? resolveRepo(basePath);
|
||||
if (!repo) {
|
||||
debugLog("github-sync", { skip: "could not detect repo" });
|
||||
return;
|
||||
}
|
||||
|
||||
// Rate limit check
|
||||
if (!ghHasRateLimit(basePath)) {
|
||||
debugLog("github-sync", { skip: "rate limit low" });
|
||||
return;
|
||||
}
|
||||
|
||||
// Load or init mapping
|
||||
const mapping = loadSyncMapping(basePath) ?? createEmptyMapping(repo);
|
||||
mapping.repo = repo;
|
||||
|
||||
// Parse unit ID parts
|
||||
const parts = unitId.split("/");
|
||||
const [mid, sid, tid] = parts;
|
||||
|
||||
// Route by unit type
|
||||
switch (unitType) {
|
||||
case "plan-milestone":
|
||||
if (mid) await syncMilestonePlan(basePath, mapping, config, mid);
|
||||
break;
|
||||
case "plan-slice":
|
||||
case "research-slice":
|
||||
if (mid && sid)
|
||||
await syncSlicePlan(basePath, mapping, config, mid, sid);
|
||||
break;
|
||||
case "execute-task":
|
||||
case "reactive-execute":
|
||||
if (mid && sid && tid)
|
||||
await syncTaskComplete(basePath, mapping, config, mid, sid, tid);
|
||||
break;
|
||||
case "complete-slice":
|
||||
if (mid && sid)
|
||||
await syncSliceComplete(basePath, mapping, config, mid, sid);
|
||||
break;
|
||||
case "complete-milestone":
|
||||
if (mid) await syncMilestoneComplete(basePath, mapping, config, mid);
|
||||
break;
|
||||
}
|
||||
|
||||
saveSyncMapping(basePath, mapping);
|
||||
} catch (err) {
|
||||
debugLog("github-sync", { error: String(err) });
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Per-Event Sync Functions ───────────────────────────────────────────────
|
||||
|
||||
async function syncMilestonePlan(
|
||||
basePath: string,
|
||||
mapping: SyncMapping,
|
||||
config: GitHubSyncConfig,
|
||||
mid: string,
|
||||
): Promise<void> {
|
||||
// Skip if already synced
|
||||
if (getMilestoneRecord(mapping, mid)) return;
|
||||
|
||||
// Load roadmap data
|
||||
const roadmapPath = resolveMilestoneFile(basePath, mid, "ROADMAP");
|
||||
if (!roadmapPath) return;
|
||||
const content = await loadFile(roadmapPath);
|
||||
if (!content) return;
|
||||
|
||||
const roadmap = parseRoadmap(content);
|
||||
const title = `${mid}: ${roadmap.title || "Milestone"}`;
|
||||
|
||||
// Create GitHub Milestone
|
||||
const milestoneResult = ghCreateMilestone(
|
||||
basePath,
|
||||
mapping.repo,
|
||||
title,
|
||||
roadmap.vision || "",
|
||||
);
|
||||
if (!milestoneResult.ok) {
|
||||
debugLog("github-sync", {
|
||||
phase: "create-milestone",
|
||||
error: milestoneResult.error,
|
||||
});
|
||||
return;
|
||||
}
|
||||
const ghMilestoneNumber = milestoneResult.data!;
|
||||
|
||||
// Create tracking issue
|
||||
const issueBody = formatMilestoneIssueBody({
|
||||
id: mid,
|
||||
title: roadmap.title || "Milestone",
|
||||
vision: roadmap.vision,
|
||||
successCriteria: roadmap.successCriteria,
|
||||
slices: roadmap.slices?.map((s) => ({
|
||||
id: s.id,
|
||||
title: s.title,
|
||||
})),
|
||||
});
|
||||
|
||||
const issueResult = ghCreateIssue(basePath, {
|
||||
repo: mapping.repo,
|
||||
title: `${mid}: ${roadmap.title || "Milestone"} — Tracking`,
|
||||
body: issueBody,
|
||||
labels: config.labels,
|
||||
milestone: ghMilestoneNumber,
|
||||
});
|
||||
if (!issueResult.ok) {
|
||||
debugLog("github-sync", {
|
||||
phase: "create-tracking-issue",
|
||||
error: issueResult.error,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Add to project if configured
|
||||
if (config.project) {
|
||||
ghAddToProject(basePath, mapping.repo, config.project, issueResult.data!);
|
||||
}
|
||||
|
||||
setMilestoneRecord(mapping, mid, {
|
||||
issueNumber: issueResult.data!,
|
||||
ghMilestoneNumber,
|
||||
lastSyncedAt: new Date().toISOString(),
|
||||
state: "open",
|
||||
});
|
||||
|
||||
debugLog("github-sync", {
|
||||
phase: "milestone-synced",
|
||||
mid,
|
||||
milestone: ghMilestoneNumber,
|
||||
issue: issueResult.data,
|
||||
});
|
||||
}
|
||||
|
||||
async function syncSlicePlan(
|
||||
basePath: string,
|
||||
mapping: SyncMapping,
|
||||
config: GitHubSyncConfig,
|
||||
mid: string,
|
||||
sid: string,
|
||||
): Promise<void> {
|
||||
// Skip if already synced
|
||||
if (getSliceRecord(mapping, mid, sid)) return;
|
||||
|
||||
// Ensure milestone is synced first
|
||||
if (!getMilestoneRecord(mapping, mid)) {
|
||||
await syncMilestonePlan(basePath, mapping, config, mid);
|
||||
}
|
||||
const milestoneRecord = getMilestoneRecord(mapping, mid);
|
||||
|
||||
// Load slice plan
|
||||
const planPath = resolveSliceFile(basePath, mid, sid, "PLAN");
|
||||
if (!planPath) return;
|
||||
const content = await loadFile(planPath);
|
||||
if (!content) return;
|
||||
|
||||
const plan = parsePlan(content);
|
||||
const sliceBranch = `milestone/${mid}/${sid}`;
|
||||
const milestoneBranch = `milestone/${mid}`;
|
||||
|
||||
// Create task sub-issues first (so we can link them in the PR body)
|
||||
const taskIssueNumbers: Array<{
|
||||
id: string;
|
||||
title: string;
|
||||
issueNumber?: number;
|
||||
}> = [];
|
||||
|
||||
if (plan.tasks) {
|
||||
for (const task of plan.tasks) {
|
||||
// Skip if already synced
|
||||
if (getTaskRecord(mapping, mid, sid, task.id)) {
|
||||
const existing = getTaskRecord(mapping, mid, sid, task.id)!;
|
||||
taskIssueNumbers.push({
|
||||
id: task.id,
|
||||
title: task.title,
|
||||
issueNumber: existing.issueNumber,
|
||||
});
|
||||
continue;
|
||||
}
|
||||
|
||||
const taskBody = formatTaskIssueBody({
|
||||
id: task.id,
|
||||
title: task.title,
|
||||
description: task.description,
|
||||
files: task.files,
|
||||
verifyCriteria: task.verify ? [task.verify] : undefined,
|
||||
});
|
||||
|
||||
const taskResult = ghCreateIssue(basePath, {
|
||||
repo: mapping.repo,
|
||||
title: `${mid}/${sid}/${task.id}: ${task.title}`,
|
||||
body: taskBody,
|
||||
labels: config.labels,
|
||||
milestone: milestoneRecord?.ghMilestoneNumber,
|
||||
parentIssue: milestoneRecord?.issueNumber,
|
||||
});
|
||||
|
||||
if (taskResult.ok) {
|
||||
setTaskRecord(mapping, mid, sid, task.id, {
|
||||
issueNumber: taskResult.data!,
|
||||
lastSyncedAt: new Date().toISOString(),
|
||||
state: "open",
|
||||
});
|
||||
taskIssueNumbers.push({
|
||||
id: task.id,
|
||||
title: task.title,
|
||||
issueNumber: taskResult.data!,
|
||||
});
|
||||
|
||||
if (config.project) {
|
||||
ghAddToProject(
|
||||
basePath,
|
||||
mapping.repo,
|
||||
config.project,
|
||||
taskResult.data!,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
taskIssueNumbers.push({ id: task.id, title: task.title });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (config.slice_prs === false) {
|
||||
// Slice PRs disabled — just record without PR
|
||||
setSliceRecord(mapping, mid, sid, {
|
||||
issueNumber: 0,
|
||||
prNumber: 0,
|
||||
branch: sliceBranch,
|
||||
lastSyncedAt: new Date().toISOString(),
|
||||
state: "open",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Create slice branch from milestone branch
|
||||
const branchResult = ghCreateBranch(basePath, sliceBranch, milestoneBranch);
|
||||
if (!branchResult.ok) {
|
||||
debugLog("github-sync", {
|
||||
phase: "create-slice-branch",
|
||||
error: branchResult.error,
|
||||
});
|
||||
// Branch might already exist — continue anyway
|
||||
}
|
||||
|
||||
// Push the slice branch
|
||||
const pushResult = ghPushBranch(basePath, sliceBranch);
|
||||
if (!pushResult.ok) {
|
||||
debugLog("github-sync", {
|
||||
phase: "push-slice-branch",
|
||||
error: pushResult.error,
|
||||
});
|
||||
}
|
||||
|
||||
// Create draft PR
|
||||
const prBody = formatSlicePRBody({
|
||||
id: sid,
|
||||
title: plan.title || sid,
|
||||
goal: plan.goal,
|
||||
mustHaves: plan.mustHaves,
|
||||
demoCriterion: plan.demo,
|
||||
tasks: taskIssueNumbers,
|
||||
});
|
||||
|
||||
const prResult = ghCreatePR(basePath, {
|
||||
repo: mapping.repo,
|
||||
base: milestoneBranch,
|
||||
head: sliceBranch,
|
||||
title: `${sid}: ${plan.title || sid}`,
|
||||
body: prBody,
|
||||
draft: true,
|
||||
});
|
||||
|
||||
const prNumber = prResult.ok ? prResult.data! : 0;
|
||||
if (!prResult.ok) {
|
||||
debugLog("github-sync", {
|
||||
phase: "create-slice-pr",
|
||||
error: prResult.error,
|
||||
});
|
||||
}
|
||||
|
||||
setSliceRecord(mapping, mid, sid, {
|
||||
issueNumber: 0, // Slice doesn't get its own issue — tracked via PR
|
||||
prNumber,
|
||||
branch: sliceBranch,
|
||||
lastSyncedAt: new Date().toISOString(),
|
||||
state: "open",
|
||||
});
|
||||
|
||||
debugLog("github-sync", {
|
||||
phase: "slice-synced",
|
||||
mid,
|
||||
sid,
|
||||
pr: prNumber,
|
||||
taskIssues: taskIssueNumbers.filter((t) => t.issueNumber).length,
|
||||
});
|
||||
}
|
||||
|
||||
async function syncTaskComplete(
|
||||
basePath: string,
|
||||
mapping: SyncMapping,
|
||||
_config: GitHubSyncConfig,
|
||||
mid: string,
|
||||
sid: string,
|
||||
tid: string,
|
||||
): Promise<void> {
|
||||
const taskRecord = getTaskRecord(mapping, mid, sid, tid);
|
||||
if (!taskRecord || taskRecord.state === "closed") return;
|
||||
|
||||
// Load task summary
|
||||
const summaryPath = resolveTaskFile(basePath, mid, sid, tid, "SUMMARY");
|
||||
if (summaryPath) {
|
||||
const content = await loadFile(summaryPath);
|
||||
if (content) {
|
||||
const summary = parseSummary(content);
|
||||
const comment = formatSummaryComment({
|
||||
oneLiner: summary.oneLiner,
|
||||
body: summary.whatHappened,
|
||||
frontmatter: summary.frontmatter as unknown as Record<string, unknown>,
|
||||
});
|
||||
ghAddComment(basePath, mapping.repo, taskRecord.issueNumber, comment);
|
||||
}
|
||||
}
|
||||
|
||||
// Close the task issue
|
||||
ghCloseIssue(basePath, mapping.repo, taskRecord.issueNumber);
|
||||
|
||||
taskRecord.state = "closed";
|
||||
taskRecord.lastSyncedAt = new Date().toISOString();
|
||||
setTaskRecord(mapping, mid, sid, tid, taskRecord);
|
||||
|
||||
debugLog("github-sync", {
|
||||
phase: "task-closed",
|
||||
mid,
|
||||
sid,
|
||||
tid,
|
||||
issue: taskRecord.issueNumber,
|
||||
});
|
||||
}
|
||||
|
||||
async function syncSliceComplete(
|
||||
basePath: string,
|
||||
mapping: SyncMapping,
|
||||
_config: GitHubSyncConfig,
|
||||
mid: string,
|
||||
sid: string,
|
||||
): Promise<void> {
|
||||
const sliceRecord = getSliceRecord(mapping, mid, sid);
|
||||
if (!sliceRecord || sliceRecord.state === "closed") return;
|
||||
|
||||
// Post slice summary as PR comment
|
||||
const summaryPath = resolveSliceFile(basePath, mid, sid, "SUMMARY");
|
||||
if (summaryPath && sliceRecord.prNumber) {
|
||||
const content = await loadFile(summaryPath);
|
||||
if (content) {
|
||||
const summary = parseSummary(content);
|
||||
const comment = formatSummaryComment({
|
||||
oneLiner: summary.oneLiner,
|
||||
body: summary.whatHappened,
|
||||
frontmatter: summary.frontmatter as unknown as Record<string, unknown>,
|
||||
});
|
||||
ghAddComment(basePath, mapping.repo, sliceRecord.prNumber, comment);
|
||||
}
|
||||
}
|
||||
|
||||
// Mark PR ready for review, then merge
|
||||
if (sliceRecord.prNumber) {
|
||||
ghMarkPRReady(basePath, mapping.repo, sliceRecord.prNumber);
|
||||
// Squash-merge into milestone branch
|
||||
ghMergePR(basePath, mapping.repo, sliceRecord.prNumber, "squash");
|
||||
}
|
||||
|
||||
sliceRecord.state = "closed";
|
||||
sliceRecord.lastSyncedAt = new Date().toISOString();
|
||||
setSliceRecord(mapping, mid, sid, sliceRecord);
|
||||
|
||||
debugLog("github-sync", {
|
||||
phase: "slice-completed",
|
||||
mid,
|
||||
sid,
|
||||
pr: sliceRecord.prNumber,
|
||||
});
|
||||
}
|
||||
|
||||
async function syncMilestoneComplete(
|
||||
basePath: string,
|
||||
mapping: SyncMapping,
|
||||
_config: GitHubSyncConfig,
|
||||
mid: string,
|
||||
): Promise<void> {
|
||||
const record = getMilestoneRecord(mapping, mid);
|
||||
if (!record || record.state === "closed") return;
|
||||
|
||||
// Close tracking issue
|
||||
ghCloseIssue(
|
||||
basePath,
|
||||
mapping.repo,
|
||||
record.issueNumber,
|
||||
`Milestone ${mid} completed.`,
|
||||
);
|
||||
|
||||
// Close GitHub milestone
|
||||
ghCloseMilestone(basePath, mapping.repo, record.ghMilestoneNumber);
|
||||
|
||||
record.state = "closed";
|
||||
record.lastSyncedAt = new Date().toISOString();
|
||||
setMilestoneRecord(mapping, mid, record);
|
||||
|
||||
debugLog("github-sync", { phase: "milestone-completed", mid });
|
||||
}
|
||||
|
||||
// ─── Bootstrap ──────────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Walk the `.sf/milestones/` tree and create GitHub entities for any
|
||||
* that are missing from the sync mapping. Safe to run multiple times.
|
||||
*/
|
||||
export async function bootstrapSync(basePath: string): Promise<{
|
||||
milestones: number;
|
||||
slices: number;
|
||||
tasks: number;
|
||||
}> {
|
||||
const config = loadGitHubSyncConfig(basePath);
|
||||
if (!config?.enabled) return { milestones: 0, slices: 0, tasks: 0 };
|
||||
if (!ghIsAvailable()) return { milestones: 0, slices: 0, tasks: 0 };
|
||||
|
||||
const repo = config.repo ?? resolveRepo(basePath);
|
||||
if (!repo) return { milestones: 0, slices: 0, tasks: 0 };
|
||||
|
||||
const mapping = loadSyncMapping(basePath) ?? createEmptyMapping(repo);
|
||||
mapping.repo = repo;
|
||||
|
||||
const taskCountBefore = Object.keys(mapping.tasks).length;
|
||||
const counts = { milestones: 0, slices: 0, tasks: 0 };
|
||||
const milestonesDir = join(basePath, ".sf", "milestones");
|
||||
if (!existsSync(milestonesDir)) return counts;
|
||||
|
||||
const milestoneIds = readdirSync(milestonesDir, { withFileTypes: true })
|
||||
.filter((d) => d.isDirectory())
|
||||
.map((d) => d.name)
|
||||
.sort();
|
||||
|
||||
for (const mid of milestoneIds) {
|
||||
if (!getMilestoneRecord(mapping, mid)) {
|
||||
await syncMilestonePlan(basePath, mapping, config, mid);
|
||||
counts.milestones++;
|
||||
}
|
||||
|
||||
// Find slices
|
||||
const slicesDir = join(milestonesDir, mid, "slices");
|
||||
if (!existsSync(slicesDir)) continue;
|
||||
|
||||
const sliceIds = readdirSync(slicesDir, { withFileTypes: true })
|
||||
.filter((d) => d.isDirectory())
|
||||
.map((d) => d.name)
|
||||
.sort();
|
||||
|
||||
for (const sid of sliceIds) {
|
||||
if (!getSliceRecord(mapping, mid, sid)) {
|
||||
await syncSlicePlan(basePath, mapping, config, mid, sid);
|
||||
counts.slices++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
counts.tasks = Object.keys(mapping.tasks).length - taskCountBefore;
|
||||
saveSyncMapping(basePath, mapping);
|
||||
return counts;
|
||||
}
|
||||
|
||||
// ─── Config Loading ─────────────────────────────────────────────────────────
|
||||
|
||||
let _cachedConfig: GitHubSyncConfig | null | undefined;
|
||||
|
||||
function loadGitHubSyncConfig(_basePath: string): GitHubSyncConfig | null {
|
||||
if (_cachedConfig !== undefined) return _cachedConfig;
|
||||
try {
|
||||
const prefs = loadEffectiveSFPreferences();
|
||||
const github = (prefs?.preferences as Record<string, unknown>)?.github;
|
||||
if (!github || typeof github !== "object") {
|
||||
_cachedConfig = null;
|
||||
return null;
|
||||
}
|
||||
_cachedConfig = github as GitHubSyncConfig;
|
||||
return _cachedConfig;
|
||||
} catch {
|
||||
_cachedConfig = null;
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/** Reset config cache (for testing). */
|
||||
export function _resetConfigCache(): void {
|
||||
_cachedConfig = undefined;
|
||||
}
|
||||
|
||||
function resolveRepo(basePath: string): string | null {
|
||||
const result = ghDetectRepo(basePath);
|
||||
return result.ok ? result.data! : null;
|
||||
}
|
||||
|
||||
// ─── Commit Linking ─────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Look up the GitHub issue number for a task so the commit message
|
||||
* can include `Resolves #N`. Called from git-service commit building.
|
||||
*/
|
||||
export function getTaskIssueNumberForCommit(
|
||||
basePath: string,
|
||||
mid: string,
|
||||
sid: string,
|
||||
tid: string,
|
||||
): number | null {
|
||||
try {
|
||||
const config = loadGitHubSyncConfig(basePath);
|
||||
if (!config?.enabled) return null;
|
||||
if (config.auto_link_commits === false) return null;
|
||||
|
||||
const mapping = loadSyncMapping(basePath);
|
||||
if (!mapping) return null;
|
||||
|
||||
const record = getTaskRecord(mapping, mid, sid, tid);
|
||||
return record?.issueNumber ?? null;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
|
@ -1,185 +0,0 @@
|
|||
/**
|
||||
* Markdown formatters for GitHub issue bodies, PR descriptions,
|
||||
* and summary comments.
|
||||
*
|
||||
* All functions produce GitHub-flavored markdown strings ready
|
||||
* for the `gh` CLI body parameters.
|
||||
*/
|
||||
|
||||
// ─── Milestone Issue Body ───────────────────────────────────────────────────
|
||||
|
||||
export interface MilestoneData {
|
||||
id: string;
|
||||
title: string;
|
||||
vision?: string;
|
||||
successCriteria?: string[];
|
||||
slices?: Array<{ id: string; title: string; taskCount?: number }>;
|
||||
}
|
||||
|
||||
export function formatMilestoneIssueBody(data: MilestoneData): string {
|
||||
const lines: string[] = [];
|
||||
|
||||
lines.push(`# ${data.id}: ${data.title}`);
|
||||
lines.push("");
|
||||
|
||||
if (data.vision) {
|
||||
lines.push("## Vision");
|
||||
lines.push(data.vision);
|
||||
lines.push("");
|
||||
}
|
||||
|
||||
if (data.successCriteria?.length) {
|
||||
lines.push("## Success Criteria");
|
||||
for (const criterion of data.successCriteria) {
|
||||
lines.push(`- [ ] ${criterion}`);
|
||||
}
|
||||
lines.push("");
|
||||
}
|
||||
|
||||
if (data.slices?.length) {
|
||||
lines.push("## Slices");
|
||||
lines.push("");
|
||||
lines.push("| Slice | Title | Tasks |");
|
||||
lines.push("|-------|-------|-------|");
|
||||
for (const slice of data.slices) {
|
||||
lines.push(
|
||||
`| ${slice.id} | ${slice.title} | ${slice.taskCount ?? "—"} |`,
|
||||
);
|
||||
}
|
||||
lines.push("");
|
||||
}
|
||||
|
||||
lines.push("---");
|
||||
lines.push("*Auto-generated by SF GitHub Sync*");
|
||||
|
||||
return lines.join("\n");
|
||||
}
|
||||
|
||||
// ─── Slice PR Body ──────────────────────────────────────────────────────────
|
||||
|
||||
export interface SliceData {
|
||||
id: string;
|
||||
title: string;
|
||||
goal?: string;
|
||||
mustHaves?: string[];
|
||||
demoCriterion?: string;
|
||||
tasks?: Array<{ id: string; title: string; issueNumber?: number }>;
|
||||
}
|
||||
|
||||
export function formatSlicePRBody(data: SliceData): string {
|
||||
const lines: string[] = [];
|
||||
|
||||
lines.push(`## ${data.id}: ${data.title}`);
|
||||
lines.push("");
|
||||
|
||||
if (data.goal) {
|
||||
lines.push("### Goal");
|
||||
lines.push(data.goal);
|
||||
lines.push("");
|
||||
}
|
||||
|
||||
if (data.mustHaves?.length) {
|
||||
lines.push("### Must-Haves");
|
||||
for (const item of data.mustHaves) {
|
||||
lines.push(`- ${item}`);
|
||||
}
|
||||
lines.push("");
|
||||
}
|
||||
|
||||
if (data.demoCriterion) {
|
||||
lines.push("### Demo Criterion");
|
||||
lines.push(data.demoCriterion);
|
||||
lines.push("");
|
||||
}
|
||||
|
||||
if (data.tasks?.length) {
|
||||
lines.push("### Tasks");
|
||||
for (const task of data.tasks) {
|
||||
const ref = task.issueNumber ? ` (#${task.issueNumber})` : "";
|
||||
lines.push(`- [ ] ${task.id}: ${task.title}${ref}`);
|
||||
}
|
||||
lines.push("");
|
||||
}
|
||||
|
||||
lines.push("---");
|
||||
lines.push("*Auto-generated by SF GitHub Sync*");
|
||||
|
||||
return lines.join("\n");
|
||||
}
|
||||
|
||||
// ─── Task Issue Body ────────────────────────────────────────────────────────
|
||||
|
||||
export interface TaskData {
|
||||
id: string;
|
||||
title: string;
|
||||
description?: string;
|
||||
files?: string[];
|
||||
verifyCriteria?: string[];
|
||||
}
|
||||
|
||||
export function formatTaskIssueBody(data: TaskData): string {
|
||||
const lines: string[] = [];
|
||||
|
||||
lines.push(`## ${data.id}: ${data.title}`);
|
||||
lines.push("");
|
||||
|
||||
if (data.description) {
|
||||
lines.push(data.description);
|
||||
lines.push("");
|
||||
}
|
||||
|
||||
if (data.files?.length) {
|
||||
lines.push("### Files");
|
||||
for (const file of data.files) {
|
||||
lines.push(`- \`${file}\``);
|
||||
}
|
||||
lines.push("");
|
||||
}
|
||||
|
||||
if (data.verifyCriteria?.length) {
|
||||
lines.push("### Verification");
|
||||
for (const criterion of data.verifyCriteria) {
|
||||
lines.push(`- [ ] ${criterion}`);
|
||||
}
|
||||
lines.push("");
|
||||
}
|
||||
|
||||
return lines.join("\n");
|
||||
}
|
||||
|
||||
// ─── Summary Comment ────────────────────────────────────────────────────────
|
||||
|
||||
export interface SummaryData {
|
||||
oneLiner?: string;
|
||||
body?: string;
|
||||
frontmatter?: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export function formatSummaryComment(data: SummaryData): string {
|
||||
const lines: string[] = [];
|
||||
|
||||
if (data.oneLiner) {
|
||||
lines.push(`**Summary:** ${data.oneLiner}`);
|
||||
lines.push("");
|
||||
}
|
||||
|
||||
if (data.body) {
|
||||
lines.push(data.body);
|
||||
lines.push("");
|
||||
}
|
||||
|
||||
if (data.frontmatter && Object.keys(data.frontmatter).length > 0) {
|
||||
lines.push("<details>");
|
||||
lines.push("<summary>Metadata</summary>");
|
||||
lines.push("");
|
||||
lines.push("```yaml");
|
||||
for (const [key, value] of Object.entries(data.frontmatter)) {
|
||||
lines.push(`${key}: ${JSON.stringify(value)}`);
|
||||
}
|
||||
lines.push("```");
|
||||
lines.push("");
|
||||
lines.push("</details>");
|
||||
}
|
||||
|
||||
return lines.join("\n");
|
||||
}
|
||||
|
|
@ -1,20 +0,0 @@
|
|||
import assert from "node:assert/strict";
|
||||
import { beforeEach, describe, it } from 'vitest';
|
||||
import { _resetGhCache, ghIsAvailable } from "../cli.ts";
|
||||
|
||||
describe("cli", () => {
|
||||
beforeEach(() => {
|
||||
_resetGhCache();
|
||||
});
|
||||
|
||||
it("ghIsAvailable returns boolean", () => {
|
||||
const result = ghIsAvailable();
|
||||
assert.equal(typeof result, "boolean");
|
||||
});
|
||||
|
||||
it("ghIsAvailable caches result", () => {
|
||||
const first = ghIsAvailable();
|
||||
const second = ghIsAvailable();
|
||||
assert.equal(first, second);
|
||||
});
|
||||
});
|
||||
|
|
@ -1,46 +0,0 @@
|
|||
import assert from "node:assert/strict";
|
||||
import { describe, it } from 'vitest';
|
||||
import { buildTaskCommitMessage } from "../../sf/git-service.ts";
|
||||
|
||||
describe("commit linking", () => {
|
||||
it("appends Resolves #N when issueNumber is set", () => {
|
||||
const msg = buildTaskCommitMessage({
|
||||
taskId: "S01/T02",
|
||||
taskTitle: "implement auth",
|
||||
issueNumber: 43,
|
||||
});
|
||||
assert.ok(msg.includes("Resolves #43"), "should include Resolves trailer");
|
||||
assert.ok(msg.startsWith("feat:"), "subject line has no scope");
|
||||
assert.ok(msg.includes("SF-Task: S01/T02"), "SF-Task trailer present");
|
||||
});
|
||||
|
||||
it("includes both key files and Resolves #N", () => {
|
||||
const msg = buildTaskCommitMessage({
|
||||
taskId: "S01/T02",
|
||||
taskTitle: "implement auth",
|
||||
keyFiles: ["src/auth.ts"],
|
||||
issueNumber: 43,
|
||||
});
|
||||
assert.ok(msg.includes("- src/auth.ts"), "key files present");
|
||||
assert.ok(msg.includes("Resolves #43"), "Resolves trailer present");
|
||||
assert.ok(msg.includes("SF-Task: S01/T02"), "SF-Task trailer present");
|
||||
// SF-Task should come after key files but before Resolves
|
||||
const keyFilesIdx = msg.indexOf("- src/auth.ts");
|
||||
const taskIdx = msg.indexOf("SF-Task: S01/T02");
|
||||
const resolvesIdx = msg.indexOf("Resolves #43");
|
||||
assert.ok(taskIdx > keyFilesIdx, "SF-Task after key files");
|
||||
assert.ok(resolvesIdx > taskIdx, "Resolves after SF-Task");
|
||||
});
|
||||
|
||||
it("no Resolves trailer when issueNumber is not set", () => {
|
||||
const msg = buildTaskCommitMessage({
|
||||
taskId: "S01/T02",
|
||||
taskTitle: "implement auth",
|
||||
});
|
||||
assert.ok(!msg.includes("Resolves"), "no Resolves when no issueNumber");
|
||||
assert.ok(
|
||||
msg.includes("SF-Task: S01/T02"),
|
||||
"SF-Task trailer still present",
|
||||
);
|
||||
});
|
||||
});
|
||||
|
|
@ -1,108 +0,0 @@
|
|||
import assert from "node:assert/strict";
|
||||
import { mkdirSync, mkdtempSync, rmSync } from "node:fs";
|
||||
import { tmpdir } from "node:os";
|
||||
import { join } from "node:path";
|
||||
import { afterEach, beforeEach, describe, it } from 'vitest';
|
||||
import {
|
||||
createEmptyMapping,
|
||||
getMilestoneRecord,
|
||||
getSliceRecord,
|
||||
getTaskIssueNumber,
|
||||
getTaskRecord,
|
||||
loadSyncMapping,
|
||||
saveSyncMapping,
|
||||
setMilestoneRecord,
|
||||
setSliceRecord,
|
||||
setTaskRecord,
|
||||
} from "../mapping.ts";
|
||||
import type {
|
||||
MilestoneSyncRecord,
|
||||
SliceSyncRecord,
|
||||
SyncEntityRecord,
|
||||
} from "../types.ts";
|
||||
|
||||
describe("mapping", () => {
|
||||
let tmpDir: string;
|
||||
|
||||
beforeEach(() => {
|
||||
tmpDir = mkdtempSync(join(tmpdir(), "sf-sync-test-"));
|
||||
mkdirSync(join(tmpDir, ".sf"), { recursive: true });
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
rmSync(tmpDir, { recursive: true, force: true });
|
||||
});
|
||||
|
||||
it("loadSyncMapping returns null when no file exists", () => {
|
||||
const result = loadSyncMapping(tmpDir);
|
||||
assert.equal(result, null);
|
||||
});
|
||||
|
||||
it("round-trips save/load", () => {
|
||||
const mapping = createEmptyMapping("owner/repo");
|
||||
saveSyncMapping(tmpDir, mapping);
|
||||
const loaded = loadSyncMapping(tmpDir);
|
||||
assert.deepEqual(loaded, mapping);
|
||||
});
|
||||
|
||||
it("createEmptyMapping has correct structure", () => {
|
||||
const mapping = createEmptyMapping("owner/repo");
|
||||
assert.equal(mapping.version, 1);
|
||||
assert.equal(mapping.repo, "owner/repo");
|
||||
assert.deepEqual(mapping.milestones, {});
|
||||
assert.deepEqual(mapping.slices, {});
|
||||
assert.deepEqual(mapping.tasks, {});
|
||||
});
|
||||
|
||||
it("milestone record accessors work", () => {
|
||||
const mapping = createEmptyMapping("owner/repo");
|
||||
assert.equal(getMilestoneRecord(mapping, "M001"), null);
|
||||
|
||||
const record: MilestoneSyncRecord = {
|
||||
issueNumber: 42,
|
||||
ghMilestoneNumber: 1,
|
||||
lastSyncedAt: "2025-01-01T00:00:00Z",
|
||||
state: "open",
|
||||
};
|
||||
setMilestoneRecord(mapping, "M001", record);
|
||||
assert.deepEqual(getMilestoneRecord(mapping, "M001"), record);
|
||||
});
|
||||
|
||||
it("slice record accessors work", () => {
|
||||
const mapping = createEmptyMapping("owner/repo");
|
||||
assert.equal(getSliceRecord(mapping, "M001", "S01"), null);
|
||||
|
||||
const record: SliceSyncRecord = {
|
||||
issueNumber: 0,
|
||||
prNumber: 50,
|
||||
branch: "milestone/M001/S01",
|
||||
lastSyncedAt: "2025-01-01T00:00:00Z",
|
||||
state: "open",
|
||||
};
|
||||
setSliceRecord(mapping, "M001", "S01", record);
|
||||
assert.deepEqual(getSliceRecord(mapping, "M001", "S01"), record);
|
||||
});
|
||||
|
||||
it("task record accessors work", () => {
|
||||
const mapping = createEmptyMapping("owner/repo");
|
||||
assert.equal(getTaskRecord(mapping, "M001", "S01", "T01"), null);
|
||||
assert.equal(getTaskIssueNumber(mapping, "M001", "S01", "T01"), null);
|
||||
|
||||
const record: SyncEntityRecord = {
|
||||
issueNumber: 43,
|
||||
lastSyncedAt: "2025-01-01T00:00:00Z",
|
||||
state: "open",
|
||||
};
|
||||
setTaskRecord(mapping, "M001", "S01", "T01", record);
|
||||
assert.deepEqual(getTaskRecord(mapping, "M001", "S01", "T01"), record);
|
||||
assert.equal(getTaskIssueNumber(mapping, "M001", "S01", "T01"), 43);
|
||||
});
|
||||
|
||||
it("rejects mapping with wrong version", () => {
|
||||
const mapping = createEmptyMapping("owner/repo");
|
||||
(mapping as any).version = 2;
|
||||
saveSyncMapping(tmpDir, mapping);
|
||||
const loaded = loadSyncMapping(tmpDir);
|
||||
assert.equal(loaded, null);
|
||||
});
|
||||
});
|
||||
|
|
@ -1,110 +0,0 @@
|
|||
import assert from "node:assert/strict";
|
||||
import { describe, it } from 'vitest';
|
||||
import {
|
||||
formatMilestoneIssueBody,
|
||||
formatSlicePRBody,
|
||||
formatSummaryComment,
|
||||
formatTaskIssueBody,
|
||||
} from "../templates.ts";
|
||||
|
||||
describe("templates", () => {
|
||||
describe("formatMilestoneIssueBody", () => {
|
||||
it("includes title and vision", () => {
|
||||
const body = formatMilestoneIssueBody({
|
||||
id: "M001",
|
||||
title: "Build Auth",
|
||||
vision: "Secure authentication for all users",
|
||||
});
|
||||
assert.ok(body.includes("M001: Build Auth"));
|
||||
assert.ok(body.includes("Secure authentication"));
|
||||
});
|
||||
|
||||
it("renders success criteria as checkboxes", () => {
|
||||
const body = formatMilestoneIssueBody({
|
||||
id: "M001",
|
||||
title: "Auth",
|
||||
successCriteria: ["Users can log in", "OAuth works"],
|
||||
});
|
||||
assert.ok(body.includes("- [ ] Users can log in"));
|
||||
assert.ok(body.includes("- [ ] OAuth works"));
|
||||
});
|
||||
|
||||
it("renders slice table", () => {
|
||||
const body = formatMilestoneIssueBody({
|
||||
id: "M001",
|
||||
title: "Auth",
|
||||
slices: [
|
||||
{ id: "S01", title: "Core types", taskCount: 3 },
|
||||
{ id: "S02", title: "OAuth", taskCount: 5 },
|
||||
],
|
||||
});
|
||||
assert.ok(body.includes("| S01 | Core types | 3 |"));
|
||||
assert.ok(body.includes("| S02 | OAuth | 5 |"));
|
||||
});
|
||||
});
|
||||
|
||||
describe("formatSlicePRBody", () => {
|
||||
it("includes goal and must-haves", () => {
|
||||
const body = formatSlicePRBody({
|
||||
id: "S01",
|
||||
title: "Core Auth Types",
|
||||
goal: "Define all auth types",
|
||||
mustHaves: ["User type", "Session type"],
|
||||
});
|
||||
assert.ok(body.includes("Define all auth types"));
|
||||
assert.ok(body.includes("- User type"));
|
||||
assert.ok(body.includes("- Session type"));
|
||||
});
|
||||
|
||||
it("renders task checklist with issue links", () => {
|
||||
const body = formatSlicePRBody({
|
||||
id: "S01",
|
||||
title: "Auth",
|
||||
tasks: [
|
||||
{ id: "T01", title: "Types", issueNumber: 43 },
|
||||
{ id: "T02", title: "Schema" },
|
||||
],
|
||||
});
|
||||
assert.ok(body.includes("- [ ] T01: Types (#43)"));
|
||||
assert.ok(body.includes("- [ ] T02: Schema"));
|
||||
assert.ok(!body.includes("T02: Schema (#"));
|
||||
});
|
||||
});
|
||||
|
||||
describe("formatTaskIssueBody", () => {
|
||||
it("includes files and verification", () => {
|
||||
const body = formatTaskIssueBody({
|
||||
id: "T01",
|
||||
title: "Add types",
|
||||
files: ["src/types.ts"],
|
||||
verifyCriteria: ["Types compile"],
|
||||
});
|
||||
assert.ok(body.includes("`src/types.ts`"));
|
||||
assert.ok(body.includes("- [ ] Types compile"));
|
||||
});
|
||||
});
|
||||
|
||||
describe("formatSummaryComment", () => {
|
||||
it("includes one-liner and body", () => {
|
||||
const comment = formatSummaryComment({
|
||||
oneLiner: "Added retry logic",
|
||||
body: "Implemented exponential backoff",
|
||||
});
|
||||
assert.ok(comment.includes("**Summary:** Added retry logic"));
|
||||
assert.ok(comment.includes("Implemented exponential backoff"));
|
||||
});
|
||||
|
||||
it("wraps frontmatter in details block", () => {
|
||||
const comment = formatSummaryComment({
|
||||
frontmatter: { duration: "45m", key_files: ["a.ts"] },
|
||||
});
|
||||
assert.ok(comment.includes("<details>"));
|
||||
assert.ok(comment.includes("duration:"));
|
||||
});
|
||||
|
||||
it("handles empty data gracefully", () => {
|
||||
const comment = formatSummaryComment({});
|
||||
assert.equal(typeof comment, "string");
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
@ -1,47 +0,0 @@
|
|||
/**
|
||||
* Type definitions for the GitHub Sync extension.
|
||||
*
|
||||
* Config shape (stored in SF preferences under `github` key) and
|
||||
* sync mapping records (stored in `.sf/github-sync.json`).
|
||||
*/
|
||||
|
||||
// ─── Configuration ──────────────────────────────────────────────────────────
|
||||
|
||||
export interface GitHubSyncConfig {
|
||||
enabled: boolean;
|
||||
/** "owner/repo" — auto-detected from git remote if omitted. */
|
||||
repo?: string;
|
||||
/** GitHub Projects v2 number (optional). */
|
||||
project?: number;
|
||||
/** Labels applied to all created issues. */
|
||||
labels?: string[];
|
||||
/** Append "Resolves #N" to task commits. Default: true. */
|
||||
auto_link_commits?: boolean;
|
||||
/** Create per-slice draft PRs. Default: true. */
|
||||
slice_prs?: boolean;
|
||||
}
|
||||
|
||||
// ─── Sync Mapping ───────────────────────────────────────────────────────────
|
||||
|
||||
export interface SyncEntityRecord {
|
||||
issueNumber: number;
|
||||
lastSyncedAt: string;
|
||||
state: "open" | "closed";
|
||||
}
|
||||
|
||||
export interface MilestoneSyncRecord extends SyncEntityRecord {
|
||||
ghMilestoneNumber: number;
|
||||
}
|
||||
|
||||
export interface SliceSyncRecord extends SyncEntityRecord {
|
||||
prNumber: number;
|
||||
branch: string;
|
||||
}
|
||||
|
||||
export interface SyncMapping {
|
||||
version: 1;
|
||||
repo: string;
|
||||
milestones: Record<string, MilestoneSyncRecord>;
|
||||
slices: Record<string, SliceSyncRecord>;
|
||||
tasks: Record<string, SyncEntityRecord>;
|
||||
}
|
||||
|
|
@ -1,512 +0,0 @@
|
|||
/**
|
||||
* Google Search Extension
|
||||
*
|
||||
* Provides a `google_search` tool that performs web searches via Gemini's
|
||||
* Google Search grounding feature. Uses the user's existing GEMINI_API_KEY or
|
||||
* GOOGLE_GENERATIVE_AI_API_KEY and Google Cloud GenAI credits.
|
||||
*
|
||||
* The tool sends queries to Gemini Flash with `googleSearch: {}` enabled.
|
||||
* Gemini internally performs Google searches, synthesizes an answer, and
|
||||
* returns it with source URLs from grounding metadata.
|
||||
*/
|
||||
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import {
|
||||
DEFAULT_MAX_BYTES,
|
||||
DEFAULT_MAX_LINES,
|
||||
formatSize,
|
||||
truncateHead,
|
||||
} from "@singularity-forge/pi-coding-agent";
|
||||
import { Text } from "@singularity-forge/pi-tui";
|
||||
|
||||
// ── Types ────────────────────────────────────────────────────────────────────
|
||||
|
||||
interface SearchSource {
|
||||
title: string;
|
||||
uri: string;
|
||||
domain: string;
|
||||
}
|
||||
|
||||
interface SearchResult {
|
||||
answer: string;
|
||||
sources: SearchSource[];
|
||||
searchQueries: string[];
|
||||
cached: boolean;
|
||||
}
|
||||
|
||||
interface SearchDetails {
|
||||
query: string;
|
||||
sourceCount: number;
|
||||
cached: boolean;
|
||||
durationMs: number;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
// ── Lazy singleton client ────────────────────────────────────────────────────
|
||||
|
||||
type GoogleGenAIClient = {
|
||||
models: {
|
||||
generateContent: (args: {
|
||||
model: string;
|
||||
contents: string;
|
||||
config?: {
|
||||
tools?: Array<{ googleSearch: Record<string, never> }>;
|
||||
abortSignal?: AbortSignal;
|
||||
};
|
||||
}) => Promise<any>;
|
||||
};
|
||||
};
|
||||
|
||||
let client: GoogleGenAIClient | null = null;
|
||||
|
||||
function getGeminiApiKey(): string | undefined {
|
||||
return process.env.GEMINI_API_KEY || process.env.GOOGLE_GENERATIVE_AI_API_KEY;
|
||||
}
|
||||
|
||||
async function getClient(): Promise<GoogleGenAIClient> {
|
||||
if (!client) {
|
||||
const { GoogleGenAI } = await import("@google/genai");
|
||||
client = new GoogleGenAI({ apiKey: getGeminiApiKey()! });
|
||||
}
|
||||
return client;
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform a search using OAuth credentials via the Cloud Code Assist API.
|
||||
* This is used as a fallback when a Gemini API key env var is not set.
|
||||
*/
|
||||
async function searchWithOAuth(
|
||||
query: string,
|
||||
accessToken: string,
|
||||
projectId: string,
|
||||
signal?: AbortSignal,
|
||||
): Promise<SearchResult> {
|
||||
const model = process.env.GEMINI_SEARCH_MODEL || "gemini-2.5-flash";
|
||||
const url = `https://cloudcode-pa.googleapis.com/v1internal:streamGenerateContent?alt=sse`;
|
||||
|
||||
const GEMINI_CLI_HEADERS = {
|
||||
ideType: "IDE_UNSPECIFIED",
|
||||
platform: "PLATFORM_UNSPECIFIED",
|
||||
pluginType: "GEMINI",
|
||||
};
|
||||
|
||||
const executeFetch = async (retries = 3): Promise<Response> => {
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
Authorization: `Bearer ${accessToken}`,
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "google-cloud-sdk vscode_cloudshelleditor/0.1",
|
||||
"X-Goog-Api-Client": "gl-node/22.17.0",
|
||||
"Client-Metadata": JSON.stringify(GEMINI_CLI_HEADERS),
|
||||
},
|
||||
body: JSON.stringify({
|
||||
project: projectId,
|
||||
model,
|
||||
request: {
|
||||
contents: [{ parts: [{ text: query }] }],
|
||||
tools: [{ googleSearch: {} }],
|
||||
},
|
||||
userAgent: "pi-coding-agent",
|
||||
}),
|
||||
signal,
|
||||
});
|
||||
|
||||
if (
|
||||
!response.ok &&
|
||||
retries > 0 &&
|
||||
(response.status === 429 || response.status >= 500)
|
||||
) {
|
||||
await new Promise((resolve) => setTimeout(resolve, 1000 * (4 - retries)));
|
||||
return executeFetch(retries - 1);
|
||||
}
|
||||
|
||||
return response;
|
||||
};
|
||||
|
||||
const response = await executeFetch();
|
||||
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(
|
||||
`Cloud Code Assist API error (${response.status}): ${errorText}`,
|
||||
);
|
||||
}
|
||||
|
||||
// Note: streamGenerateContent returns SSE; for now, we consume all chunks.
|
||||
// For simplicity and to match the previous structure, we'll read to end.
|
||||
const text = await response.text();
|
||||
const jsonLines = text
|
||||
.split("\n")
|
||||
.filter((l) => l.startsWith("data:"))
|
||||
.map((l) => l.slice(5).trim())
|
||||
.filter((l) => l.length > 0);
|
||||
|
||||
let data: any;
|
||||
if (jsonLines.length > 0) {
|
||||
// Aggregate chunks if needed, but for now we take the last chunk or assume it's one
|
||||
data = JSON.parse(jsonLines[jsonLines.length - 1]);
|
||||
} else {
|
||||
data = JSON.parse(text);
|
||||
}
|
||||
const candidate = data.response?.candidates?.[0];
|
||||
const answer =
|
||||
candidate?.content?.parts?.find((p: any) => p.text)?.text ?? "";
|
||||
const grounding = candidate?.groundingMetadata;
|
||||
|
||||
const sources: SearchSource[] = [];
|
||||
const seenTitles = new Set<string>();
|
||||
if (grounding?.groundingChunks) {
|
||||
for (const chunk of grounding.groundingChunks) {
|
||||
if (chunk.web) {
|
||||
const title = chunk.web.title ?? "Untitled";
|
||||
if (seenTitles.has(title)) continue;
|
||||
seenTitles.add(title);
|
||||
const domain = chunk.web.domain ?? title;
|
||||
sources.push({
|
||||
title,
|
||||
uri: chunk.web.uri ?? "",
|
||||
domain,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const searchQueries = grounding?.webSearchQueries ?? [];
|
||||
return { answer, sources, searchQueries, cached: false };
|
||||
}
|
||||
|
||||
// ── In-session cache ─────────────────────────────────────────────────────────
|
||||
|
||||
const resultCache = new Map<string, SearchResult>();
|
||||
|
||||
function cacheKey(query: string): string {
|
||||
return query.toLowerCase().trim();
|
||||
}
|
||||
|
||||
// ── Extension ────────────────────────────────────────────────────────────────
|
||||
|
||||
export default function (pi: ExtensionAPI) {
|
||||
pi.registerTool({
|
||||
name: "google_search",
|
||||
label: "Google Search",
|
||||
description:
|
||||
"Search the web using Google Search via Gemini. " +
|
||||
"Returns an AI-synthesized answer grounded in Google Search results, plus source URLs. " +
|
||||
"Use this when you need current information from the web: recent events, documentation, " +
|
||||
"product details, technical references, news, etc. " +
|
||||
"Requires GEMINI_API_KEY, GOOGLE_GENERATIVE_AI_API_KEY, or Google login. Alternative to Brave-based search tools.",
|
||||
promptSnippet:
|
||||
"Search the web via Google Search to get current information with sources",
|
||||
promptGuidelines: [
|
||||
"Use google_search when you need up-to-date web information that isn't in your training data.",
|
||||
"Be specific with queries for better results, e.g. 'Next.js 15 app router migration guide' not just 'Next.js'.",
|
||||
"The tool returns both an answer and source URLs. Cite sources when sharing results with the user.",
|
||||
"Results are cached per-session, so repeated identical queries are free.",
|
||||
"You can still use fetch_page to read a specific URL if needed after getting results from google_search.",
|
||||
],
|
||||
parameters: Type.Object({
|
||||
query: Type.String({
|
||||
description:
|
||||
"The search query, e.g. 'latest Node.js LTS version' or 'how to configure Tailwind v4'",
|
||||
}),
|
||||
maxSources: Type.Optional(
|
||||
Type.Number({
|
||||
description:
|
||||
"Maximum number of source URLs to include (default 5, max 10).",
|
||||
minimum: 1,
|
||||
maximum: 10,
|
||||
}),
|
||||
),
|
||||
}),
|
||||
|
||||
async execute(_toolCallId, params, signal, _onUpdate, ctx) {
|
||||
const startTime = Date.now();
|
||||
const maxSources = Math.min(Math.max(params.maxSources ?? 5, 1), 10);
|
||||
|
||||
// Check for credentials
|
||||
let oauthToken: string | undefined;
|
||||
let projectId: string | undefined;
|
||||
|
||||
const geminiApiKey = getGeminiApiKey();
|
||||
|
||||
if (!geminiApiKey) {
|
||||
const oauthRaw =
|
||||
await ctx.modelRegistry.getApiKeyForProvider("google-gemini-cli");
|
||||
if (oauthRaw) {
|
||||
try {
|
||||
const parsed = JSON.parse(oauthRaw);
|
||||
oauthToken = parsed.token;
|
||||
projectId = parsed.projectId;
|
||||
} catch {
|
||||
// Fall through to error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!geminiApiKey && (!oauthToken || !projectId)) {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: "Error: No authentication found for Google Search. Please set GEMINI_API_KEY, GOOGLE_GENERATIVE_AI_API_KEY, or log in via Google.\n\nExample: export GEMINI_API_KEY=your_key or use /login google",
|
||||
},
|
||||
],
|
||||
isError: true,
|
||||
details: {
|
||||
query: params.query,
|
||||
sourceCount: 0,
|
||||
cached: false,
|
||||
durationMs: Date.now() - startTime,
|
||||
error: "auth_error: No credentials set",
|
||||
} as SearchDetails,
|
||||
};
|
||||
}
|
||||
|
||||
// Check cache
|
||||
const key = cacheKey(params.query);
|
||||
if (resultCache.has(key)) {
|
||||
const cached = resultCache.get(key)!;
|
||||
const output = formatOutput(cached, maxSources);
|
||||
return {
|
||||
content: [{ type: "text", text: output }],
|
||||
details: {
|
||||
query: params.query,
|
||||
sourceCount: cached.sources.length,
|
||||
cached: true,
|
||||
durationMs: Date.now() - startTime,
|
||||
} as SearchDetails,
|
||||
};
|
||||
}
|
||||
|
||||
// Call Gemini with Google Search grounding
|
||||
let result: SearchResult;
|
||||
try {
|
||||
if (geminiApiKey) {
|
||||
const ai = await getClient();
|
||||
|
||||
// Add a 30-second timeout to prevent hanging (#1100)
|
||||
const timeoutController = new AbortController();
|
||||
const timeoutId = setTimeout(() => timeoutController.abort(), 30_000);
|
||||
const combinedSignal = signal
|
||||
? AbortSignal.any([signal, timeoutController.signal])
|
||||
: timeoutController.signal;
|
||||
|
||||
let response: Awaited<ReturnType<typeof ai.models.generateContent>>;
|
||||
try {
|
||||
response = await ai.models.generateContent({
|
||||
model: process.env.GEMINI_SEARCH_MODEL || "gemini-2.5-flash",
|
||||
contents: params.query,
|
||||
config: {
|
||||
tools: [{ googleSearch: {} }],
|
||||
abortSignal: combinedSignal,
|
||||
},
|
||||
});
|
||||
} finally {
|
||||
clearTimeout(timeoutId);
|
||||
}
|
||||
|
||||
// Extract answer text
|
||||
const answer = response.text ?? "";
|
||||
|
||||
// Extract grounding metadata
|
||||
const candidate = response.candidates?.[0];
|
||||
const grounding = candidate?.groundingMetadata;
|
||||
|
||||
// Parse sources from grounding chunks
|
||||
const sources: SearchSource[] = [];
|
||||
const seenTitles = new Set<string>();
|
||||
if (grounding?.groundingChunks) {
|
||||
for (const chunk of grounding.groundingChunks) {
|
||||
if (chunk.web) {
|
||||
const title = chunk.web.title ?? "Untitled";
|
||||
// Dedupe by title since URIs are redirect URLs that differ per call
|
||||
if (seenTitles.has(title)) continue;
|
||||
seenTitles.add(title);
|
||||
// domain field is not available via Gemini API, use title as fallback
|
||||
// (title is typically the domain name, e.g. "wikipedia.org")
|
||||
const domain = chunk.web.domain ?? title;
|
||||
sources.push({
|
||||
title,
|
||||
uri: chunk.web.uri ?? "",
|
||||
domain,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract search queries Gemini actually performed
|
||||
const searchQueries = grounding?.webSearchQueries ?? [];
|
||||
result = { answer, sources, searchQueries, cached: false };
|
||||
} else {
|
||||
result = await searchWithOAuth(
|
||||
params.query,
|
||||
oauthToken!,
|
||||
projectId!,
|
||||
signal,
|
||||
);
|
||||
}
|
||||
} catch (err: unknown) {
|
||||
const msg = err instanceof Error ? err.message : String(err);
|
||||
|
||||
let errorType = "api_error";
|
||||
if (msg.includes("401") || msg.includes("UNAUTHENTICATED")) {
|
||||
errorType = "auth_error";
|
||||
} else if (
|
||||
msg.includes("429") ||
|
||||
msg.includes("RESOURCE_EXHAUSTED") ||
|
||||
msg.includes("quota")
|
||||
) {
|
||||
errorType = "rate_limit";
|
||||
}
|
||||
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `Google Search failed (${errorType}): ${msg}`,
|
||||
},
|
||||
],
|
||||
isError: true,
|
||||
details: {
|
||||
query: params.query,
|
||||
sourceCount: 0,
|
||||
cached: false,
|
||||
durationMs: Date.now() - startTime,
|
||||
error: `${errorType}: ${msg}`,
|
||||
} as SearchDetails,
|
||||
};
|
||||
}
|
||||
|
||||
// Cache the result
|
||||
resultCache.set(key, result);
|
||||
|
||||
// Format and truncate output
|
||||
const rawOutput = formatOutput(result, maxSources);
|
||||
const truncation = truncateHead(rawOutput, {
|
||||
maxLines: DEFAULT_MAX_LINES,
|
||||
maxBytes: DEFAULT_MAX_BYTES,
|
||||
});
|
||||
|
||||
let finalText = truncation.content;
|
||||
if (truncation.truncated) {
|
||||
finalText +=
|
||||
`\n\n[Truncated: showing ${truncation.outputLines}/${truncation.totalLines} lines` +
|
||||
` (${formatSize(truncation.outputBytes)} of ${formatSize(truncation.totalBytes)})]`;
|
||||
}
|
||||
|
||||
return {
|
||||
content: [{ type: "text", text: finalText }],
|
||||
details: {
|
||||
query: params.query,
|
||||
sourceCount: result.sources.length,
|
||||
cached: false,
|
||||
durationMs: Date.now() - startTime,
|
||||
} as SearchDetails,
|
||||
};
|
||||
},
|
||||
|
||||
renderCall(args, theme) {
|
||||
let text = theme.fg("toolTitle", theme.bold("google_search "));
|
||||
text += theme.fg("accent", `"${args.query}"`);
|
||||
return new Text(text, 0, 0);
|
||||
},
|
||||
|
||||
renderResult(result, { isPartial, expanded }, theme) {
|
||||
const d = result.details as SearchDetails | undefined;
|
||||
|
||||
if (isPartial)
|
||||
return new Text(theme.fg("warning", "Searching Google..."), 0, 0);
|
||||
if ((result as any).isError || d?.error) {
|
||||
return new Text(
|
||||
theme.fg("error", `Error: ${d?.error ?? "unknown"}`),
|
||||
0,
|
||||
0,
|
||||
);
|
||||
}
|
||||
|
||||
let text = theme.fg("success", `${d?.sourceCount ?? 0} sources`);
|
||||
text += theme.fg("dim", ` (${d?.durationMs ?? 0}ms)`);
|
||||
if (d?.cached) text += theme.fg("dim", " · cached");
|
||||
|
||||
if (expanded) {
|
||||
const content = result.content[0];
|
||||
if (content?.type === "text") {
|
||||
const preview = content.text.split("\n").slice(0, 8).join("\n");
|
||||
text += "\n\n" + theme.fg("dim", preview);
|
||||
if (content.text.split("\n").length > 8) {
|
||||
text += "\n" + theme.fg("muted", "...");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return new Text(text, 0, 0);
|
||||
},
|
||||
});
|
||||
|
||||
// ── Session cleanup ─────────────────────────────────────────────────────
|
||||
|
||||
pi.on("session_shutdown", async () => {
|
||||
resultCache.clear();
|
||||
client = null;
|
||||
});
|
||||
|
||||
// ── Startup notification ─────────────────────────────────────────────────
|
||||
|
||||
pi.on("session_start", async (_event, ctx) => {
|
||||
if (getGeminiApiKey()) return;
|
||||
|
||||
const hasOAuth =
|
||||
await ctx.modelRegistry.authStorage.hasAuth("google-gemini-cli");
|
||||
if (!hasOAuth) {
|
||||
ctx.ui.notify(
|
||||
"Google Search: No authentication set. Log in via Google or set GEMINI_API_KEY / GOOGLE_GENERATIVE_AI_API_KEY to use google_search.",
|
||||
"warning",
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// ── Output formatting ────────────────────────────────────────────────────────
|
||||
|
||||
function formatOutput(result: SearchResult, maxSources: number): string {
|
||||
const lines: string[] = [];
|
||||
|
||||
// Answer
|
||||
if (result.answer) {
|
||||
lines.push(result.answer);
|
||||
} else {
|
||||
lines.push("(No answer text returned from search)");
|
||||
}
|
||||
|
||||
// Sources
|
||||
if (result.sources.length > 0) {
|
||||
lines.push("");
|
||||
lines.push("Sources:");
|
||||
const sourcesToShow = result.sources.slice(0, maxSources);
|
||||
for (let i = 0; i < sourcesToShow.length; i++) {
|
||||
const s = sourcesToShow[i];
|
||||
lines.push(`[${i + 1}] ${s.title} - ${s.domain}`);
|
||||
lines.push(` ${s.uri}`);
|
||||
}
|
||||
if (result.sources.length > maxSources) {
|
||||
lines.push(
|
||||
`(${result.sources.length - maxSources} more sources omitted)`,
|
||||
);
|
||||
}
|
||||
} else {
|
||||
lines.push("");
|
||||
lines.push("(No source URLs found in grounding metadata)");
|
||||
}
|
||||
|
||||
// Search queries
|
||||
if (result.searchQueries.length > 0) {
|
||||
lines.push("");
|
||||
lines.push(
|
||||
`Searches performed: ${result.searchQueries.map((q) => `"${q}"`).join(", ")}`,
|
||||
);
|
||||
}
|
||||
|
||||
return lines.join("\n");
|
||||
}
|
||||
|
|
@ -1,14 +1,14 @@
|
|||
{
|
||||
"name": "pi-extension-google-search",
|
||||
"private": true,
|
||||
"version": "1.0.0",
|
||||
"type": "module",
|
||||
"engines": {
|
||||
"node": ">=24.15.0"
|
||||
},
|
||||
"pi": {
|
||||
"extensions": [
|
||||
"./index.ts"
|
||||
]
|
||||
}
|
||||
"name": "pi-extension-google-search",
|
||||
"private": true,
|
||||
"version": "1.0.0",
|
||||
"type": "module",
|
||||
"engines": {
|
||||
"node": ">=24.15.0"
|
||||
},
|
||||
"pi": {
|
||||
"extensions": [
|
||||
"./index.js"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,732 +0,0 @@
|
|||
/**
|
||||
* Guardrails Extension — Security & Redaction
|
||||
*
|
||||
* Ported from the pi community "agents" extension pack.
|
||||
*
|
||||
* Features:
|
||||
* - Redacts secrets from tool results before the LLM sees them
|
||||
* - Blocks dangerous bash commands (rm -rf, sudo, mkfs, etc.)
|
||||
* - Blocks writes to protected paths (.env, .git, .ssh, etc.)
|
||||
*/
|
||||
|
||||
import * as path from "node:path";
|
||||
import type {
|
||||
ExtensionAPI,
|
||||
ExtensionContext,
|
||||
} from "@singularity-forge/pi-coding-agent";
|
||||
|
||||
// ============================================================================
|
||||
// Secret Redaction
|
||||
// ============================================================================
|
||||
|
||||
interface RedactionRule {
|
||||
pattern: RegExp;
|
||||
replacement: string;
|
||||
}
|
||||
|
||||
const SENSITIVE_PATTERNS: RedactionRule[] = [
|
||||
{
|
||||
pattern: /\b(sk-[a-zA-Z0-9]{20,})\b/g,
|
||||
replacement: "[OPENAI_KEY_REDACTED]",
|
||||
},
|
||||
{
|
||||
pattern: /\b(ghp_[a-zA-Z0-9]{36,})\b/g,
|
||||
replacement: "[GITHUB_TOKEN_REDACTED]",
|
||||
},
|
||||
{
|
||||
pattern: /\b(gho_[a-zA-Z0-9]{36,})\b/g,
|
||||
replacement: "[GITHUB_OAUTH_REDACTED]",
|
||||
},
|
||||
{
|
||||
pattern: /\b(xox[baprs]-[a-zA-Z0-9-]{10,})\b/g,
|
||||
replacement: "[SLACK_TOKEN_REDACTED]",
|
||||
},
|
||||
{ pattern: /\b(AKIA[A-Z0-9]{16})\b/g, replacement: "[AWS_KEY_REDACTED]" },
|
||||
{
|
||||
pattern: /\b(api[_-]?key|apikey)\s*[=:]\s*['"]?([a-zA-Z0-9_-]{20,})['"]?/gi,
|
||||
replacement: "$1=[REDACTED]",
|
||||
},
|
||||
{
|
||||
pattern:
|
||||
/\b(secret|token|password|passwd|pwd)\s*[=:]\s*['"]?([^\s'"]{8,})['"]?/gi,
|
||||
replacement: "$1=[REDACTED]",
|
||||
},
|
||||
{
|
||||
pattern: /\b(bearer)\s+([a-zA-Z0-9._-]{20,})\b/gi,
|
||||
replacement: "Bearer [REDACTED]",
|
||||
},
|
||||
{
|
||||
pattern: /(mongodb(\+srv)?:\/\/[^:]+:)[^@]+(@)/gi,
|
||||
replacement: "$1[REDACTED]$3",
|
||||
},
|
||||
{
|
||||
pattern: /(postgres(ql)?:\/\/[^:]+:)[^@]+(@)/gi,
|
||||
replacement: "$1[REDACTED]$3",
|
||||
},
|
||||
{ pattern: /(mysql:\/\/[^:]+:)[^@]+(@)/gi, replacement: "$1[REDACTED]$3" },
|
||||
{ pattern: /(redis:\/\/[^:]+:)[^@]+(@)/gi, replacement: "$1[REDACTED]$3" },
|
||||
{
|
||||
pattern:
|
||||
/-----BEGIN (RSA |EC |OPENSSH |)PRIVATE KEY-----[\s\S]*?-----END \1PRIVATE KEY-----/g,
|
||||
replacement: "[PRIVATE_KEY_REDACTED]",
|
||||
},
|
||||
];
|
||||
|
||||
const SENSITIVE_FILES: { pattern: RegExp; desc: string }[] = [
|
||||
{ pattern: /\.env$/, desc: ".env" },
|
||||
{ pattern: /\.env\.(?!example$)[^/]+$/, desc: ".env local/override" },
|
||||
{ pattern: /\.dev\.vars($|\.[ˆ/]+$)/, desc: ".dev.vars" },
|
||||
{ pattern: /secrets?\.(json|ya?ml|toml)$/i, desc: "secrets file" },
|
||||
{ pattern: /credentials/i, desc: "credentials file" },
|
||||
];
|
||||
|
||||
function redactToolResult(
|
||||
toolName: string,
|
||||
filePath: string | undefined,
|
||||
text: string,
|
||||
ctx: ExtensionContext,
|
||||
): { content: [{ type: "text"; text: string }] } | undefined {
|
||||
if (toolName === "read" && filePath) {
|
||||
if (/(^|\/)\.env\.example$/i.test(filePath)) {
|
||||
return undefined;
|
||||
}
|
||||
for (const { pattern, desc } of SENSITIVE_FILES) {
|
||||
if (pattern.test(filePath)) {
|
||||
ctx.ui.notify(
|
||||
`🔒 Redacted contents of sensitive file: ${filePath}`,
|
||||
"info",
|
||||
);
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: "text",
|
||||
text: `[Contents of ${desc} (${filePath}) redacted for security]`,
|
||||
},
|
||||
],
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let result = text;
|
||||
let modified = false;
|
||||
for (const { pattern, replacement } of SENSITIVE_PATTERNS) {
|
||||
const next = result.replace(pattern, replacement);
|
||||
if (next !== result) {
|
||||
modified = true;
|
||||
result = next;
|
||||
}
|
||||
}
|
||||
|
||||
if (modified) {
|
||||
ctx.ui.notify("🔒 Sensitive data redacted from output", "info");
|
||||
return { content: [{ type: "text", text: result }] };
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Command & Path Security
|
||||
// ============================================================================
|
||||
|
||||
interface DangerousCommand {
|
||||
pattern: RegExp;
|
||||
desc: string;
|
||||
}
|
||||
|
||||
const DANGEROUS_COMMANDS: DangerousCommand[] = [
|
||||
{ pattern: /\brm\s+(-[^\s]*r|--recursive)/, desc: "recursive delete" },
|
||||
{ pattern: /\bsudo\b/, desc: "sudo command" },
|
||||
{ pattern: /\b(chmod|chown)\b.*777/, desc: "dangerous permissions" },
|
||||
{ pattern: /\bmkfs\b/, desc: "filesystem format" },
|
||||
{ pattern: /\bdd\b.*\bof=\/dev\//, desc: "raw device write" },
|
||||
{ pattern: />\s*\/dev\/sd[a-z]/, desc: "raw device overwrite" },
|
||||
{ pattern: /\bkill\s+-9\s+-1\b/, desc: "kill all processes" },
|
||||
{ pattern: /:\(\)\s*\{\s*:\s*\|\s*:\s*&\s*\}\s*;/, desc: "fork bomb" },
|
||||
];
|
||||
|
||||
const PROTECTED_PATHS: { pattern: RegExp; desc: string }[] = [
|
||||
{ pattern: /\.env($|\.(?!example))/, desc: "environment file" },
|
||||
{ pattern: /\.dev\.vars($|\.[ˆ/]+$)/, desc: "dev vars file" },
|
||||
{ pattern: /node_modules\//, desc: "node_modules" },
|
||||
{ pattern: /^\.git\/|\/\.git\//, desc: "git directory" },
|
||||
{ pattern: /\.pem$|\.key$/, desc: "private key file" },
|
||||
{ pattern: /id_rsa|id_ed25519|id_ecdsa/, desc: "SSH key" },
|
||||
{ pattern: /\.ssh\//, desc: ".ssh directory" },
|
||||
{ pattern: /secrets?\.(json|ya?ml|toml)$/i, desc: "secrets file" },
|
||||
{ pattern: /credentials/i, desc: "credentials file" },
|
||||
];
|
||||
|
||||
const SOFT_PROTECTED_PATHS: { pattern: RegExp; desc: string }[] = [
|
||||
{ pattern: /package-lock\.json$/, desc: "package-lock.json" },
|
||||
{ pattern: /yarn\.lock$/, desc: "yarn.lock" },
|
||||
{ pattern: /pnpm-lock\.yaml$/, desc: "pnpm-lock.yaml" },
|
||||
];
|
||||
|
||||
const DANGEROUS_BASH_WRITES: RegExp[] = [
|
||||
/>\s*\.env(?!\.example)(\b|$)/,
|
||||
/>\s*\.dev\.vars/,
|
||||
/>\s*.*\.pem/,
|
||||
/>\s*.*\.key/,
|
||||
/tee\s+.*\.env(?!\.example)(\b|$)/,
|
||||
/tee\s+.*\.dev\.vars/,
|
||||
/cp\s+.*\s+\.env(?!\.example)(\b|$)/,
|
||||
/mv\s+.*\s+\.env(?!\.example)(\b|$)/,
|
||||
];
|
||||
|
||||
async function checkBashCommand(
|
||||
command: string,
|
||||
ctx: ExtensionContext,
|
||||
): Promise<{ block: true; reason: string } | undefined> {
|
||||
for (const { pattern, desc } of DANGEROUS_COMMANDS) {
|
||||
if (pattern.test(command)) {
|
||||
if (!ctx.hasUI) {
|
||||
return { block: true, reason: `Blocked ${desc} (no UI to confirm)` };
|
||||
}
|
||||
const ok = await ctx.ui.confirm(`⚠️ Dangerous command: ${desc}`, command);
|
||||
if (!ok) {
|
||||
return { block: true, reason: `Blocked ${desc} by user` };
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
for (const pattern of DANGEROUS_BASH_WRITES) {
|
||||
if (pattern.test(command)) {
|
||||
ctx.ui.notify("🛡️ Blocked bash write to protected path", "warning");
|
||||
return { block: true, reason: "Bash command writes to protected path" };
|
||||
}
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
async function checkWritePath(
|
||||
filePath: string,
|
||||
ctx: ExtensionContext,
|
||||
): Promise<{ block: true; reason: string } | undefined> {
|
||||
const normalized = path.normalize(filePath);
|
||||
|
||||
for (const { pattern, desc } of PROTECTED_PATHS) {
|
||||
if (pattern.test(normalized)) {
|
||||
ctx.ui.notify(`🛡️ Blocked write to ${desc}: ${filePath}`, "warning");
|
||||
return { block: true, reason: `Protected path: ${desc}` };
|
||||
}
|
||||
}
|
||||
|
||||
for (const { pattern, desc } of SOFT_PROTECTED_PATHS) {
|
||||
if (pattern.test(normalized)) {
|
||||
if (!ctx.hasUI) {
|
||||
return { block: true, reason: `Protected path (no UI): ${desc}` };
|
||||
}
|
||||
const ok = await ctx.ui.confirm(
|
||||
`⚠️ Modifying ${desc}`,
|
||||
`Are you sure you want to modify ${filePath}?`,
|
||||
);
|
||||
if (!ok) {
|
||||
return { block: true, reason: `User blocked write to ${desc}` };
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Safe Git
|
||||
// ============================================================================
|
||||
|
||||
type PromptLevel = "high" | "medium" | "none";
|
||||
type Severity = "high" | "medium";
|
||||
type GitCommandDecision = { block: true; reason: string } | undefined;
|
||||
|
||||
interface SafeGitConfig {
|
||||
promptLevel?: PromptLevel;
|
||||
enabledByDefault?: boolean;
|
||||
}
|
||||
|
||||
interface SafeGitGateState {
|
||||
pendingDecisions: Map<string, Promise<GitCommandDecision>>;
|
||||
recentOnceApprovals: Map<string, number>;
|
||||
}
|
||||
|
||||
const SAFE_GIT_DEFAULTS: Required<SafeGitConfig> = {
|
||||
promptLevel: "medium",
|
||||
enabledByDefault: true,
|
||||
};
|
||||
|
||||
const RECENT_ONCE_APPROVAL_TTL_MS = 5_000;
|
||||
|
||||
const GIT_PATTERNS: { pattern: RegExp; action: string; severity: Severity }[] =
|
||||
[
|
||||
// High risk
|
||||
{
|
||||
pattern: /\bgit\s+push\s+.*--force(-with-lease)?\b/i,
|
||||
action: "force push",
|
||||
severity: "high",
|
||||
},
|
||||
{
|
||||
pattern: /\bgit\s+reset\s+--hard\b/i,
|
||||
action: "hard reset",
|
||||
severity: "high",
|
||||
},
|
||||
{
|
||||
pattern: /\bgit\s+clean\s+-[a-z]*f/i,
|
||||
action: "clean (remove untracked files)",
|
||||
severity: "high",
|
||||
},
|
||||
{
|
||||
pattern: /\bgit\s+stash\s+(drop|clear)\b/i,
|
||||
action: "drop/clear stash",
|
||||
severity: "high",
|
||||
},
|
||||
{
|
||||
pattern: /\bgit\s+branch\s+-[dD]\b/i,
|
||||
action: "delete branch",
|
||||
severity: "high",
|
||||
},
|
||||
{
|
||||
pattern: /\bgit\s+reflog\s+expire\b/i,
|
||||
action: "expire reflog",
|
||||
severity: "high",
|
||||
},
|
||||
// Medium risk
|
||||
{ pattern: /\bgit\s+push\b/i, action: "push", severity: "medium" },
|
||||
{ pattern: /\bgit\s+commit\b/i, action: "commit", severity: "medium" },
|
||||
{ pattern: /\bgit\s+rebase\b/i, action: "rebase", severity: "medium" },
|
||||
{ pattern: /\bgit\s+merge\b/i, action: "merge", severity: "medium" },
|
||||
{
|
||||
pattern: /\bgit\s+tag\b/i,
|
||||
action: "create/modify tag",
|
||||
severity: "medium",
|
||||
},
|
||||
{
|
||||
pattern: /\bgit\s+cherry-pick\b/i,
|
||||
action: "cherry-pick",
|
||||
severity: "medium",
|
||||
},
|
||||
{ pattern: /\bgit\s+revert\b/i, action: "revert", severity: "medium" },
|
||||
{ pattern: /\bgit\s+am\b/i, action: "apply patches", severity: "medium" },
|
||||
// GitHub CLI
|
||||
{ pattern: /\bgh\s+\S+/i, action: "GitHub CLI", severity: "medium" },
|
||||
];
|
||||
|
||||
const severityIcons: Record<Severity, string> = {
|
||||
high: "🔴",
|
||||
medium: "🟡",
|
||||
};
|
||||
|
||||
function getSafeGitConfig(
|
||||
ctx: ExtensionContext,
|
||||
enabledOverride?: boolean | null,
|
||||
promptLevelOverride?: PromptLevel | null,
|
||||
): { enabled: boolean; promptLevel: PromptLevel } {
|
||||
const settings = (ctx as any).settingsManager?.getSettings() ?? {};
|
||||
const config: Required<SafeGitConfig> = {
|
||||
...SAFE_GIT_DEFAULTS,
|
||||
...(settings.safeGit ?? {}),
|
||||
};
|
||||
return {
|
||||
enabled:
|
||||
enabledOverride !== null && enabledOverride !== undefined
|
||||
? enabledOverride
|
||||
: config.enabledByDefault,
|
||||
promptLevel:
|
||||
promptLevelOverride !== null && promptLevelOverride !== undefined
|
||||
? promptLevelOverride
|
||||
: config.promptLevel,
|
||||
};
|
||||
}
|
||||
|
||||
function shouldPrompt(severity: Severity, promptLevel: PromptLevel): boolean {
|
||||
if (promptLevel === "none") return false;
|
||||
if (promptLevel === "high") return severity === "high";
|
||||
return true;
|
||||
}
|
||||
|
||||
function gitGateKey(action: string, command: string): string {
|
||||
return `${action}\0${command.trim().replace(/\s+/g, " ")}`;
|
||||
}
|
||||
|
||||
function pruneRecentOnceApprovals(
|
||||
state: SafeGitGateState,
|
||||
now = Date.now(),
|
||||
): void {
|
||||
for (const [key, expiresAt] of state.recentOnceApprovals) {
|
||||
if (expiresAt <= now) state.recentOnceApprovals.delete(key);
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Ask the user how to handle a gated git command.
 *
 * Presents a four-way choice (allow once / decline / auto-approve action for
 * the session / auto-block action for the session) and records the outcome
 * in the session sets or the once-approval TTL cache.
 *
 * @param action Human-readable git action name (e.g. "push").
 * @param severity Risk level of the matched pattern; affects the title only.
 * @param gateKey Dedup key from gitGateKey() for the once-approval cache.
 * @param ctx Extension context providing the UI.
 * @param sessionApprovedActions Actions auto-approved this session (mutated).
 * @param sessionBlockedActions Actions auto-blocked this session (mutated).
 * @param gateState Holds the once-approval cache (mutated on "allow once").
 * @returns A block decision, or undefined when the command may proceed.
 */
async function promptForGitCommand(
  action: string,
  severity: Severity,
  gateKey: string,
  ctx: ExtensionContext,
  sessionApprovedActions: Set<string>,
  sessionBlockedActions: Set<string>,
  gateState: SafeGitGateState,
): Promise<GitCommandDecision> {
  const icon = severityIcons[severity];
  const title =
    severity === "high"
      ? `${icon} ⚠️ HIGH RISK: Git ${action} requires approval`
      : `${icon} Git ${action} requires approval`;

  let choice: string | string[] | undefined;
  try {
    choice = await ctx.ui.select(title, [
      "✅ Allow this command once",
      "⏭️ Decline this time (ask again later)",
      `✅✅ Auto-approve all "git ${action}" for this session only`,
      `🚫 Auto-block all "git ${action}" for this session only`,
    ]);
  } catch {
    // Treat a failed/cancelled select (e.g. UI torn down) as "unanswered".
    choice = undefined;
  }

  // No usable answer: fail closed and block the command.
  if (typeof choice !== "string") {
    ctx.ui.notify(
      `Git ${action} approval not answered; command paused`,
      "warning",
    );
    return {
      block: true,
      reason: `Git ${action} approval not answered; command paused`,
    };
  }

  // Decisions are matched on the option's leading emoji; the "✅✅" check
  // must come before any plain "✅" handling (handled by the else branch).
  if (!choice || choice.startsWith("⏭️")) {
    ctx.ui.notify(`Git ${action} declined`, "info");
    return { block: true, reason: `Git ${action} declined by user` };
  }
  if (choice.startsWith("🚫")) {
    sessionBlockedActions.add(action);
    ctx.ui.notify(
      `🚫 All "git ${action}" commands auto-blocked for this session`,
      "warning",
    );
    return {
      block: true,
      reason: `Git ${action} blocked by user (session setting)`,
    };
  }
  if (choice.startsWith("✅✅")) {
    sessionApprovedActions.add(action);
    ctx.ui.notify(
      `✅ All "git ${action}" commands auto-approved for this session`,
      "info",
    );
  } else {
    // "Allow once": remember briefly so duplicate requests for the same
    // normalized command reuse this approval instead of re-prompting.
    gateState.recentOnceApprovals.set(
      gateKey,
      Date.now() + RECENT_ONCE_APPROVAL_TTL_MS,
    );
    ctx.ui.notify(`Git ${action} approved once`, "info");
  }
  // undefined = allow the command to run.
  return undefined;
}
|
||||
|
||||
/**
 * Gate a bash command against the safe-git pattern list.
 *
 * Evaluation order for the first matching pattern: session block list →
 * session approve list → recent "approved once" cache → prompt-level filter →
 * headless fail-closed → interactive prompt. Prompts are deduplicated per
 * gate key so concurrent identical commands share one pending decision.
 *
 * @returns `{ block, reason }` to stop the command, `undefined` to allow it.
 */
async function checkGitCommand(
  command: string,
  ctx: ExtensionContext,
  sessionApprovedActions: Set<string>,
  sessionBlockedActions: Set<string>,
  gateState: SafeGitGateState,
  enabledOverride?: boolean | null,
  promptLevelOverride?: PromptLevel | null,
): Promise<{ block: true; reason: string } | undefined> {
  const { enabled, promptLevel } = getSafeGitConfig(
    ctx,
    enabledOverride,
    promptLevelOverride,
  );
  // Protection off entirely, or prompting disabled: allow everything.
  if (!enabled || promptLevel === "none") return undefined;

  for (const { pattern, action, severity } of GIT_PATTERNS) {
    if (pattern.test(command)) {
      if (sessionBlockedActions.has(action)) {
        ctx.ui.notify(
          `🚫 Git ${action} auto-blocked (session setting)`,
          "warning",
        );
        return {
          block: true,
          reason: `Git ${action} blocked by user (session setting)`,
        };
      }
      if (sessionApprovedActions.has(action)) {
        ctx.ui.notify(
          `✅ Git ${action} auto-approved (session setting)`,
          "info",
        );
        return undefined;
      }
      // A recent "allow once" for the same normalized command still applies.
      const gateKey = gitGateKey(action, command);
      pruneRecentOnceApprovals(gateState);
      if (gateState.recentOnceApprovals.has(gateKey)) {
        ctx.ui.notify(
          `Git ${action} approval reused for duplicate request`,
          "info",
        );
        return undefined;
      }
      if (!shouldPrompt(severity, promptLevel)) {
        return undefined;
      }

      // Fail closed when there is no way to ask the user.
      if (!ctx.hasUI) {
        return {
          block: true,
          reason: `Git ${action} blocked: requires explicit user approval (no UI available)`,
        };
      }

      // Deduplicate concurrent prompts: reuse the in-flight decision.
      const existingDecision = gateState.pendingDecisions.get(gateKey);
      if (existingDecision) return existingDecision;

      const pendingDecision = promptForGitCommand(
        action,
        severity,
        gateKey,
        ctx,
        sessionApprovedActions,
        sessionBlockedActions,
        gateState,
      );
      gateState.pendingDecisions.set(gateKey, pendingDecision);
      const cleanup = () => {
        // Only remove our own entry; a newer prompt may have replaced it.
        if (gateState.pendingDecisions.get(gateKey) === pendingDecision) {
          gateState.pendingDecisions.delete(gateKey);
        }
      };
      pendingDecision.then(cleanup, cleanup);
      return pendingDecision;
    }
  }

  // No pattern matched: not a gated git command.
  return undefined;
}
|
||||
|
||||
/**
 * Register the safe-git slash commands: /safegit (toggle), /safegit-level
 * (set prompt level), /yolo (session prompt bypass), /safegit-status.
 * All overrides are session-scoped boxes mutated in place so the caller's
 * event handlers observe changes.
 */
function registerSafeGitCommands(
  pi: ExtensionAPI,
  sessionEnabledOverride: { value: boolean | null },
  sessionPromptLevelOverride: { value: PromptLevel | null },
  yoloPreviousPromptLevel: { value: PromptLevel | null },
) {
  pi.registerCommand("safegit", {
    description: "Toggle safe-git protection on/off for this session",
    handler: async (_, ctx) => {
      // Toggle relative to the *effective* state (settings + overrides).
      const { enabled } = getSafeGitConfig(
        ctx,
        sessionEnabledOverride.value,
        sessionPromptLevelOverride.value,
      );
      sessionEnabledOverride.value = !enabled;
      ctx.ui.notify(
        sessionEnabledOverride.value
          ? "🔒 Safe-git protection ON"
          : "🔓 Safe-git protection OFF",
        "info",
      );
      ctx.ui.notify("(Temporary for this session)", "info");
    },
  });

  pi.registerCommand("safegit-level", {
    description: "Set prompt level: high, medium, or none",
    handler: async (args, ctx) => {
      // Direct form: `/safegit-level high|medium|none`.
      const arg = typeof args === "string" ? args.trim().toLowerCase() : "";
      if (arg === "high" || arg === "medium" || arg === "none") {
        sessionPromptLevelOverride.value = arg;
        const desc = {
          high: "🔴 Only high-risk operations require approval",
          medium: "🟡 Medium and high-risk operations require approval",
          none: "⚠️ No approval required (protection disabled)",
        };
        ctx.ui.notify(`Prompt level: ${arg}`, "info");
        ctx.ui.notify(desc[arg], "info");
        ctx.ui.notify("(Temporary for this session)", "info");
        return;
      }

      // Interactive form: show a picker with the current level first.
      const { promptLevel } = getSafeGitConfig(
        ctx,
        sessionEnabledOverride.value,
        sessionPromptLevelOverride.value,
      );
      const options = [
        `🔴 high - Only high-risk (force push, hard reset, etc.)`,
        `🟡 medium - Medium and high-risk (push, commit, etc.)`,
        `⚠️ none - No prompts (disable protection)`,
        `❌ Cancel`,
      ];

      ctx.ui.notify(`Current level: ${promptLevel}\n`, "info");
      const choice = await ctx.ui.select("Set prompt level:", options);
      const selectedChoice = typeof choice === "string" ? choice : undefined;
      if (!selectedChoice || selectedChoice.startsWith("❌")) {
        ctx.ui.notify("Cancelled.", "info");
        return;
      }
      // Options are "<emoji> <level> - …", so token [1] is the level name.
      const level = selectedChoice.split(" ")[1] as PromptLevel;
      sessionPromptLevelOverride.value = level;
      ctx.ui.notify(`Prompt level set to: ${selectedChoice}`, "info");
      ctx.ui.notify("(Temporary for this session)", "info");
    },
  });

  pi.registerCommand("yolo", {
    description: "Toggle session-only safe-git prompt bypass",
    handler: async (_, ctx) => {
      const { promptLevel } = getSafeGitConfig(
        ctx,
        sessionEnabledOverride.value,
        sessionPromptLevelOverride.value,
      );

      if (promptLevel === "none") {
        // YOLO currently on: restore the level saved when it was enabled.
        sessionPromptLevelOverride.value =
          yoloPreviousPromptLevel.value ?? SAFE_GIT_DEFAULTS.promptLevel;
        yoloPreviousPromptLevel.value = null;
        ctx.ui.notify(
          `YOLO mode OFF - safe-git prompt level restored to ${sessionPromptLevelOverride.value}`,
          "info",
        );
      } else {
        // YOLO currently off: remember the level, then disable prompts.
        yoloPreviousPromptLevel.value = promptLevel;
        sessionPromptLevelOverride.value = "none";
        ctx.ui.notify(
          "YOLO mode ON - safe-git prompts disabled for this session",
          "info",
        );
      }
      ctx.ui.notify("(Temporary for this session)", "info");
    },
  });

  pi.registerCommand("safegit-status", {
    description: "Show safe-git status and settings",
    handler: async (_, ctx) => {
      // Global defaults come straight from settings, ignoring overrides,
      // so the report can show both the raw and the effective state.
      const settings = (ctx as any).settingsManager?.getSettings() ?? {};
      const globalConfig: Required<SafeGitConfig> = {
        ...SAFE_GIT_DEFAULTS,
        ...(settings.safeGit ?? {}),
      };
      const { enabled, promptLevel } = getSafeGitConfig(
        ctx,
        sessionEnabledOverride.value,
        sessionPromptLevelOverride.value,
      );

      const lines = [
        "─── Safe Git Status ───",
        "",
        "Session State:",
        ` Enabled: ${enabled ? "🔒 ON" : "🔓 OFF"}${sessionEnabledOverride.value !== null ? " (session override)" : ""}`,
        ` Prompt Level: ${promptLevel}${sessionPromptLevelOverride.value !== null ? " (session override)" : ""}`,
        "",
        "Global Defaults:",
        ` Enabled: ${globalConfig.enabledByDefault ? "ON" : "OFF"}`,
        ` Prompt Level: ${globalConfig.promptLevel}`,
        "",
        "Prompt Levels:",
        ` 🔴 high - force push, hard reset, clean, delete branch`,
        ` 🟡 medium - push, commit, rebase, merge, tag, gh CLI`,
        "",
        "Commands: /yolo /safegit /safegit-level /safegit-status",
        "───────────────────────",
      ];
      ctx.ui.notify(lines.join("\n"), "info");
    },
  });
}
|
||||
|
||||
// ============================================================================
|
||||
// Entry Point
|
||||
// ============================================================================
|
||||
|
||||
/**
 * Guardrails entry point: wires safe-git command gating, generic bash/write
 * checks, and tool-result redaction into the extension lifecycle.
 */
export default function guardrails(pi: ExtensionAPI): void {
  // Per-session mutable state; everything below is reset on session_start.
  const sessionApprovedActions = new Set<string>();
  const sessionBlockedActions = new Set<string>();
  const gateState: SafeGitGateState = {
    pendingDecisions: new Map(),
    recentOnceApprovals: new Map(),
  };
  // Boxed so registerSafeGitCommands' handlers can mutate them in place.
  const sessionEnabledOverride: { value: boolean | null } = { value: null };
  const sessionPromptLevelOverride: { value: PromptLevel | null } = {
    value: null,
  };
  const yoloPreviousPromptLevel: { value: PromptLevel | null } = {
    value: null,
  };

  registerSafeGitCommands(
    pi,
    sessionEnabledOverride,
    sessionPromptLevelOverride,
    yoloPreviousPromptLevel,
  );

  pi.on("session_start", async (_, ctx) => {
    // Fresh session: clear every override, approval, and cache.
    sessionEnabledOverride.value = null;
    sessionPromptLevelOverride.value = null;
    yoloPreviousPromptLevel.value = null;
    sessionApprovedActions.clear();
    sessionBlockedActions.clear();
    gateState.pendingDecisions.clear();
    gateState.recentOnceApprovals.clear();

    const { enabled, promptLevel } = getSafeGitConfig(
      ctx,
      sessionEnabledOverride.value,
      sessionPromptLevelOverride.value,
    );
    // Announce active protection once per session (UI sessions only).
    if (ctx.hasUI && enabled && promptLevel !== "none") {
      const promptDesc =
        promptLevel === "high" ? "🔴 high-risk only" : "🟡 medium+high";
      ctx.ui.notify(`Safe-git: Protection ${promptDesc}`, "info");
    }
  });

  pi.on("tool_call", async (event, ctx) => {
    if (event.toolName === "bash") {
      const command = event.input.command as string;
      // Git gating runs first; generic bash checks only run if it passes.
      const gitResult = await checkGitCommand(
        command,
        ctx,
        sessionApprovedActions,
        sessionBlockedActions,
        gateState,
        sessionEnabledOverride.value,
        sessionPromptLevelOverride.value,
      );
      if (gitResult) return gitResult;
      return checkBashCommand(command, ctx);
    }

    if (event.toolName === "write" || event.toolName === "edit") {
      const filePath = event.input.path as string;
      return checkWritePath(filePath, ctx);
    }

    // Other tools are not gated.
    return undefined;
  });

  pi.on("tool_result", async (event, ctx) => {
    if (event.isError) return undefined;

    // Only the first text content item is redacted; others pass through.
    const textContent = event.content.find(
      (c): c is { type: "text"; text: string } => c.type === "text",
    );
    if (!textContent) return undefined;

    return redactToolResult(
      event.toolName,
      event.input.path as string | undefined,
      textContent.text,
      ctx,
    );
  });
}
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,153 +0,0 @@
|
|||
/**
|
||||
* MCP Client OAuth / Auth helpers
|
||||
*
|
||||
* Builds transport options (headers, OAuthClientProvider) from MCP server
|
||||
* config entries so that HTTP transports can authenticate with remote
|
||||
* servers (Sentry, Linear, etc.).
|
||||
*
|
||||
* Fixes #2160 — MCP HTTP transport lacked an OAuth auth provider.
|
||||
*/
|
||||
|
||||
import type { OAuthClientProvider } from "@modelcontextprotocol/sdk/client/auth.js";
|
||||
import type { StreamableHTTPClientTransportOptions } from "@modelcontextprotocol/sdk/client/streamableHttp.js";
|
||||
|
||||
// ─── Types ────────────────────────────────────────────────────────────────────
|
||||
|
||||
/** Header-based auth for an HTTP MCP server. */
export interface McpHttpAuthHeaders {
  /** Static headers to attach to every request, e.g. `{ Authorization: "Bearer ${TOKEN}" }`. */
  headers?: Record<string, string>;
}
|
||||
|
||||
/** OAuth-based auth for an HTTP MCP server. */
export interface McpHttpOAuthConfig {
  /** OAuth configuration for servers that require the full OAuth flow. */
  oauth?: {
    // Pre-registered OAuth client id.
    clientId: string;
    // Optional secret for confidential clients.
    clientSecret?: string;
    // Scopes requested during authorization.
    scopes?: string[];
    // Redirect URL; a localhost placeholder is used when omitted.
    redirectUrl?: string;
  };
}
|
||||
|
||||
/** Union of all auth-related config fields for an HTTP MCP server. */
export type McpHttpAuthConfig = McpHttpAuthHeaders & McpHttpOAuthConfig;
|
||||
|
||||
// ─── Env resolution ───────────────────────────────────────────────────────────
|
||||
|
||||
/** Resolve `${VAR}` references in a string against `process.env`. */
|
||||
function resolveEnvValue(value: string): string {
|
||||
return value.replace(
|
||||
/\$\{([^}]+)\}/g,
|
||||
(_match, varName) => process.env[varName] ?? "",
|
||||
);
|
||||
}
|
||||
|
||||
function resolveHeaders(raw: Record<string, string>): Record<string, string> {
|
||||
const resolved: Record<string, string> = {};
|
||||
for (const [key, value] of Object.entries(raw)) {
|
||||
resolved[key] = typeof value === "string" ? resolveEnvValue(value) : value;
|
||||
}
|
||||
return resolved;
|
||||
}
|
||||
|
||||
// ─── OAuth provider (minimal CLI-friendly implementation) ─────────────────────
|
||||
|
||||
/**
|
||||
* Creates a minimal `OAuthClientProvider` suitable for CLI / headless use.
|
||||
*
|
||||
* This provider supports:
|
||||
* - Pre-configured client credentials (client_id, optional client_secret)
|
||||
* - Token storage in memory (per-session)
|
||||
* - Scopes
|
||||
*
|
||||
* For full interactive OAuth flows (browser redirect), a richer provider would
|
||||
* be needed, but for server-to-server and pre-authed scenarios this is
|
||||
* sufficient.
|
||||
*/
|
||||
function createCliOAuthProvider(
  config: NonNullable<McpHttpOAuthConfig["oauth"]>,
): OAuthClientProvider {
  // In-memory, per-provider token/verifier storage (lost on process exit).
  let storedTokens:
    | { access_token: string; token_type: string; refresh_token?: string }
    | undefined;
  let storedCodeVerifier = "";

  return {
    get redirectUrl() {
      // Port-0 placeholder: no real callback server is started here.
      return config.redirectUrl ?? "http://localhost:0/callback";
    },

    get clientMetadata() {
      return {
        redirect_uris: [config.redirectUrl ?? "http://localhost:0/callback"],
        client_name: "sf",
        // Space-separated scope string, per the OAuth convention.
        ...(config.scopes ? { scope: config.scopes.join(" ") } : {}),
      };
    },

    clientInformation() {
      // client_secret is omitted entirely for public clients.
      return {
        client_id: config.clientId,
        ...(config.clientSecret ? { client_secret: config.clientSecret } : {}),
      };
    },

    tokens() {
      return storedTokens;
    },

    saveTokens(tokens) {
      storedTokens = tokens as typeof storedTokens;
    },

    redirectToAuthorization(authorizationUrl: URL) {
      // In a CLI context we can't open a browser automatically.
      // Log the URL so the user can manually visit it.
      // eslint-disable-next-line no-console
      console.error(
        `[MCP OAuth] Authorization required. Visit:\n ${authorizationUrl.toString()}`,
      );
    },

    saveCodeVerifier(codeVerifier: string) {
      storedCodeVerifier = codeVerifier;
    },

    codeVerifier() {
      return storedCodeVerifier;
    },
  };
}
|
||||
|
||||
// ─── Public API ───────────────────────────────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Build `StreamableHTTPClientTransportOptions` from an MCP server config's
|
||||
* auth-related fields.
|
||||
*
|
||||
* Supports two auth strategies:
|
||||
* 1. **`headers`** — static Authorization (or other) headers, with `${VAR}` env resolution.
|
||||
* 2. **`oauth`** — full OAuthClientProvider for servers that implement MCP OAuth.
|
||||
*
|
||||
* When both are provided, `oauth` takes precedence (the SDK's built-in OAuth
|
||||
* flow handles token refresh automatically).
|
||||
*/
|
||||
export function buildHttpTransportOpts(
|
||||
authConfig: McpHttpAuthConfig,
|
||||
): StreamableHTTPClientTransportOptions {
|
||||
const opts: StreamableHTTPClientTransportOptions = {};
|
||||
|
||||
// OAuth takes precedence
|
||||
if (authConfig.oauth) {
|
||||
opts.authProvider = createCliOAuthProvider(authConfig.oauth);
|
||||
return opts;
|
||||
}
|
||||
|
||||
// Static headers (with env var resolution)
|
||||
if (authConfig.headers && Object.keys(authConfig.headers).length > 0) {
|
||||
opts.requestInit = {
|
||||
headers: resolveHeaders(authConfig.headers),
|
||||
};
|
||||
}
|
||||
|
||||
return opts;
|
||||
}
|
||||
|
|
@ -1,741 +0,0 @@
|
|||
/**
|
||||
* MCP Client Extension — Native MCP server integration for pi
|
||||
*
|
||||
* Provides on-demand access to MCP servers configured in project files
|
||||
* (.mcp.json, .sf/mcp.json) using the @modelcontextprotocol/sdk Client
|
||||
* directly — no external CLI dependency required.
|
||||
*
|
||||
* Three tools:
|
||||
* mcp_servers — List available MCP servers from config files
|
||||
* mcp_discover — Get tool signatures for a specific server (lazy connect)
|
||||
* mcp_call — Call a tool on an MCP server (lazy connect)
|
||||
*/
|
||||
|
||||
import { existsSync, readFileSync } from "node:fs";
|
||||
import { homedir } from "node:os";
|
||||
import { join } from "node:path";
|
||||
import { Client } from "@modelcontextprotocol/sdk/client";
|
||||
import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js";
|
||||
import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js";
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import {
|
||||
DEFAULT_MAX_BYTES,
|
||||
DEFAULT_MAX_LINES,
|
||||
formatSize,
|
||||
truncateHead,
|
||||
} from "@singularity-forge/pi-coding-agent";
|
||||
import { Text } from "@singularity-forge/pi-tui";
|
||||
import type { McpHttpAuthConfig } from "./auth.js";
|
||||
import { buildHttpTransportOpts } from "./auth.js";
|
||||
|
||||
// ─── Types ────────────────────────────────────────────────────────────────────
|
||||
|
||||
/** One MCP server entry parsed from an mcp.json config file. */
interface McpServerConfig {
  name: string;
  /** Inferred from config shape: `command` ⇒ stdio, `url` ⇒ http. */
  transport: "stdio" | "http" | "unknown";
  command?: string;
  args?: string[];
  /** Extra child env; values support `${VAR}` resolution at connect time. */
  env?: Record<string, string>;
  url?: string;
  cwd?: string;
  /** Static headers for HTTP transport (supports ${VAR} env resolution). */
  headers?: Record<string, string>;
  /** OAuth config for HTTP transport. */
  oauth?: McpHttpAuthConfig["oauth"];
}
|
||||
|
||||
/** Tool signature as reported by an MCP server. */
interface McpToolSchema {
  name: string;
  description: string;
  /** Raw JSON Schema for the tool's arguments (may be absent). */
  inputSchema?: Record<string, unknown>;
}
|
||||
|
||||
/** A live client plus the transport it is bound to, kept for later cleanup. */
interface ManagedConnection {
  client: Client;
  transport: StdioClientTransport | StreamableHTTPClientTransport;
}
|
||||
|
||||
// ─── Connection Manager ───────────────────────────────────────────────────────
|
||||
|
||||
// Live connections keyed by canonical server name.
const connections = new Map<string, ManagedConnection>();
// Parsed config files, cached for the process lifetime (null = not loaded).
let configCache: McpServerConfig[] | null = null;
/** Servers whose MCP tools have been auto-registered as first-class pi tools. */
const autoRegisteredServers = new Set<string>();
// Discovered tool lists per server name.
const toolCache = new Map<string, McpToolSchema[]>();
|
||||
|
||||
/**
 * Load MCP server definitions from the known config file locations.
 * Results are cached for the process lifetime (see `configCache`); missing
 * or malformed files are skipped silently.
 */
function readConfigs(): McpServerConfig[] {
  if (configCache) return configCache;

  const servers: McpServerConfig[] = [];
  const seen = new Set<string>();
  // Search order matters: first hit wins (seen-guard below), so put
  // project-local configs first — a project can override or shadow a
  // globally-registered server by re-declaring the same name.
  const sfHome = process.env.SF_HOME || join(homedir(), ".sf");
  const configPaths = [
    join(process.cwd(), ".mcp.json"),
    join(process.cwd(), ".sf", "mcp.json"),
    join(sfHome, "mcp.json"), // global: ~/.sf/mcp.json
    join(sfHome, "agent", "mcp.json"), // global: ~/.sf/agent/mcp.json (legacy alt)
    join(homedir(), ".mcp.json"), // user-global: ~/.mcp.json (Claude Code, npx, etc.)
  ];

  for (const configPath of configPaths) {
    try {
      if (!existsSync(configPath)) continue;
      const raw = readFileSync(configPath, "utf-8");
      const data = JSON.parse(raw) as Record<string, unknown>;
      // Accept both the `mcpServers` key and the plain `servers` key.
      const mcpServers = (data.mcpServers ?? data.servers) as
        | Record<string, Record<string, unknown>>
        | undefined;
      if (!mcpServers || typeof mcpServers !== "object") continue;

      for (const [name, config] of Object.entries(mcpServers)) {
        if (seen.has(name)) continue;
        seen.add(name);

        // Transport is inferred: `command` ⇒ stdio, else `url` ⇒ http.
        const hasCommand = typeof config.command === "string";
        const hasUrl = typeof config.url === "string";
        const transport: McpServerConfig["transport"] = hasCommand
          ? "stdio"
          : hasUrl
            ? "http"
            : "unknown";

        // Auth fields only make sense for HTTP transports.
        const hasHeaders =
          hasUrl && config.headers && typeof config.headers === "object";
        const hasOAuth =
          hasUrl && config.oauth && typeof config.oauth === "object";

        servers.push({
          name,
          transport,
          // Spreading `false` is a no-op, so these conditional spreads only
          // contribute fields for the matching transport kind.
          ...(hasCommand && {
            command: config.command as string,
            args: Array.isArray(config.args)
              ? (config.args as string[])
              : undefined,
            env:
              config.env && typeof config.env === "object"
                ? (config.env as Record<string, string>)
                : undefined,
            cwd: typeof config.cwd === "string" ? config.cwd : undefined,
          }),
          ...(hasUrl && { url: config.url as string }),
          headers: hasHeaders
            ? (config.headers as Record<string, string>)
            : undefined,
          oauth: hasOAuth
            ? (config.oauth as McpHttpAuthConfig["oauth"])
            : undefined,
        });
      }
    } catch {
      // Non-fatal — config file may not exist or be malformed
    }
  }

  configCache = servers;
  return servers;
}
|
||||
|
||||
function getServerConfig(name: string): McpServerConfig | undefined {
|
||||
const trimmed = name.trim();
|
||||
return readConfigs().find(
|
||||
(s) => s.name === trimmed || s.name.toLowerCase() === trimmed.toLowerCase(),
|
||||
);
|
||||
}
|
||||
|
||||
/** Resolve ${VAR} references in env values against process.env. */
|
||||
function resolveEnv(env: Record<string, string>): Record<string, string> {
|
||||
const resolved: Record<string, string> = {};
|
||||
for (const [key, value] of Object.entries(env)) {
|
||||
if (typeof value === "string") {
|
||||
resolved[key] = value.replace(
|
||||
/\$\{([^}]+)\}/g,
|
||||
(_match, varName) => process.env[varName] ?? "",
|
||||
);
|
||||
} else {
|
||||
resolved[key] = value;
|
||||
}
|
||||
}
|
||||
return resolved;
|
||||
}
|
||||
|
||||
// ─── JSON Schema → TypeBox converter ─────────────────────────────────────────
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
function jsonSchemaPropToTypeBox(schema: Record<string, unknown>): any {
|
||||
if (!schema || typeof schema !== "object") return Type.Any();
|
||||
const t = schema.type as string;
|
||||
if (t === "string") return Type.String({ description: schema.description as string | undefined });
|
||||
if (t === "number" || t === "integer") return Type.Number({ description: schema.description as string | undefined });
|
||||
if (t === "boolean") return Type.Boolean({ description: schema.description as string | undefined });
|
||||
if (t === "array") return Type.Array(Type.Any());
|
||||
if (t === "object") {
|
||||
const props = schema.properties as Record<string, Record<string, unknown>> | undefined;
|
||||
if (props) {
|
||||
const entries: Record<string, unknown> = {};
|
||||
for (const [k, v] of Object.entries(props)) {
|
||||
entries[k] = jsonSchemaPropToTypeBox(v);
|
||||
}
|
||||
return Type.Object(entries as Parameters<typeof Type.Object>[0]);
|
||||
}
|
||||
}
|
||||
return Type.Any();
|
||||
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
function jsonSchemaToTypeBox(schema: Record<string, unknown> | undefined): any {
|
||||
if (!schema || typeof schema !== "object") return Type.Object({});
|
||||
const obj = schema as Record<string, unknown>;
|
||||
const props = obj.properties as Record<string, Record<string, unknown>> | undefined;
|
||||
if (!props) return Type.Object({});
|
||||
const entries: Record<string, unknown> = {};
|
||||
for (const [k, v] of Object.entries(props)) {
|
||||
entries[k] = jsonSchemaPropToTypeBox(v);
|
||||
}
|
||||
return Type.Object(entries as Parameters<typeof Type.Object>[0]);
|
||||
}
|
||||
|
||||
// ─── Dynamic MCP tool auto-registration ───────────────────────────────────────
|
||||
|
||||
/**
 * Auto-register every tool of an MCP server as a first-class pi tool named
 * `<server>_<tool>`. Idempotent per server; individual registration failures
 * are swallowed so one unconvertible schema doesn't break the rest.
 */
function registerMcpToolsForServer(pi: ExtensionAPI, serverName: string, tools: McpToolSchema[]) {
  if (autoRegisteredServers.has(serverName)) return;
  autoRegisteredServers.add(serverName);

  for (const tool of tools) {
    const piToolName = `${serverName}_${tool.name}`;
    const description = tool.description || `MCP tool: ${tool.name} on ${serverName}`;
    // Build parameter TypeBox type from MCP inputSchema
    const paramType = tool.inputSchema
      ? jsonSchemaToTypeBox(tool.inputSchema)
      : Type.Object({});

    try {
      pi.registerTool({
        name: piToolName,
        label: `${serverName}:${tool.name}`,
        description,
        parameters: paramType,
        async execute(_id, params) {
          // Delegate to the internal mcp_call logic directly via the client
          const client = await getOrConnect(serverName);
          const result = await client.callTool(
            { name: tool.name, arguments: params as Record<string, unknown> },
            undefined,
            { timeout: 60000 },
          );
          // Flatten content: text items verbatim, others JSON-stringified.
          const contentItems = result.content as Array<{ type: string; text?: string }>;
          const raw = contentItems
            .map((c) => (c.type === "text" ? (c.text ?? "") : JSON.stringify(c)))
            .join("\n");
          // Head-truncate large outputs to the agent's standard limits.
          const truncation = truncateHead(raw, {
            maxLines: DEFAULT_MAX_LINES,
            maxBytes: DEFAULT_MAX_BYTES,
          });
          let finalText = truncation.content;
          if (truncation.truncated) {
            finalText += `\n\n[Output truncated: ${truncation.outputLines}/${truncation.totalLines} lines]`;
          }
          return {
            content: [{ type: "text", text: finalText }],
            details: { server: serverName, tool: tool.name },
          };
        },
      });
    }
    catch {
      // Non-fatal — tool registration can fail if schema is unconvertible
    }
  }
}
|
||||
|
||||
/**
 * Return a connected MCP client for `name`, creating the connection on first
 * use. Connections are cached per canonical server name; stdio and HTTP
 * transports are supported.
 *
 * @param name Server name (case/whitespace tolerant, see getServerConfig).
 * @param signal Optional abort signal for the connect handshake.
 * @throws When the server is unknown or has an unsupported transport.
 */
async function getOrConnect(
  name: string,
  signal?: AbortSignal,
): Promise<Client> {
  const config = getServerConfig(name);
  if (!config)
    throw new Error(
      `Unknown MCP server: "${name}". Use mcp_servers to list available servers.`,
    );

  // Always use config.name as the canonical cache key so that variant
  // casing / whitespace still hits the same connection.
  const existing = connections.get(config.name);
  if (existing) return existing.client;

  const client = new Client({ name: "sf", version: "1.0.0" });
  let transport: StdioClientTransport | StreamableHTTPClientTransport;

  if (config.transport === "stdio" && config.command) {
    transport = new StdioClientTransport({
      command: config.command,
      args: config.args,
      // Child inherits the full parent env plus ${VAR}-resolved extras.
      env: config.env
        ? ({ ...process.env, ...resolveEnv(config.env) } as Record<
            string,
            string
          >)
        : undefined,
      cwd: config.cwd,
      stderr: "pipe",
    });
  } else if (config.transport === "http" && config.url) {
    // URLs may embed ${VAR} references too (e.g. tokens in query strings).
    const resolvedUrl = config.url.replace(
      /\$\{([^}]+)\}/g,
      (_, varName) => process.env[varName] ?? "",
    );
    const httpOpts = buildHttpTransportOpts({
      headers: config.headers,
      oauth: config.oauth,
    });
    transport = new StreamableHTTPClientTransport(
      new URL(resolvedUrl),
      httpOpts,
    );
  } else {
    throw new Error(
      `Server "${config.name}" has unsupported transport: ${config.transport}`,
    );
  }

  // Cache only after a successful connect so failures can be retried.
  await client.connect(transport, { signal, timeout: 30000 });
  connections.set(config.name, { client, transport });
  return client;
}
|
||||
|
||||
async function closeAll(): Promise<void> {
|
||||
const closing = Array.from(connections.entries()).map(
|
||||
async ([name, conn]) => {
|
||||
try {
|
||||
await conn.client.close();
|
||||
} catch {
|
||||
// Best-effort cleanup
|
||||
}
|
||||
connections.delete(name);
|
||||
},
|
||||
);
|
||||
await Promise.allSettled(closing);
|
||||
toolCache.clear();
|
||||
}
|
||||
|
||||
// ─── Formatters ───────────────────────────────────────────────────────────────
|
||||
|
||||
function formatServerList(servers: McpServerConfig[]): string {
|
||||
if (servers.length === 0)
|
||||
return "No MCP servers configured. Add servers to .mcp.json or .sf/mcp.json.";
|
||||
|
||||
const lines: string[] = [`${servers.length} MCP servers configured:\n`];
|
||||
|
||||
for (const s of servers) {
|
||||
const connected = connections.has(s.name) ? "✓" : "○";
|
||||
const cached = toolCache.get(s.name);
|
||||
const toolCount = cached ? ` — ${cached.length} tools` : "";
|
||||
lines.push(`${connected} ${s.name} (${s.transport})${toolCount}`);
|
||||
}
|
||||
|
||||
lines.push(
|
||||
"\nUse mcp_discover to see full tool schemas for a specific server.",
|
||||
);
|
||||
lines.push("Use mcp_call to invoke a tool: mcp_call(server, tool, args).");
|
||||
return lines.join("\n");
|
||||
}
|
||||
|
||||
function formatToolList(serverName: string, tools: McpToolSchema[]): string {
|
||||
const lines: string[] = [`${serverName} — ${tools.length} tools:\n`];
|
||||
|
||||
for (const tool of tools) {
|
||||
lines.push(`## ${tool.name}`);
|
||||
if (tool.description) lines.push(tool.description);
|
||||
if (tool.inputSchema) {
|
||||
lines.push("```json");
|
||||
lines.push(JSON.stringify(tool.inputSchema, null, 2));
|
||||
lines.push("```");
|
||||
}
|
||||
lines.push("");
|
||||
}
|
||||
|
||||
lines.push(
|
||||
`Call with: mcp_call(server="${serverName}", tool="<tool_name>", args={...})`,
|
||||
);
|
||||
return lines.join("\n");
|
||||
}
|
||||
|
||||
// ─── Status helper (consumed by /sf mcp) ─────────────────────────────────────
|
||||
|
||||
/**
|
||||
* Return the live connection status for a named MCP server.
|
||||
* Safe to call even when the server has never been connected.
|
||||
*/
|
||||
export function getConnectionStatus(name: string): {
|
||||
connected: boolean;
|
||||
tools: string[];
|
||||
error?: string;
|
||||
} {
|
||||
const conn = connections.get(name);
|
||||
const cached = toolCache.get(name);
|
||||
return {
|
||||
connected: !!conn,
|
||||
tools: cached ? cached.map((t) => t.name) : [],
|
||||
error: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
// ─── Test-exported helpers ────────────────────────────────────────────────────
|
||||
|
||||
const SAFE_CHILD_ENV_KEYS = new Set([
|
||||
"PATH",
|
||||
"HOME",
|
||||
"USER",
|
||||
"LOGNAME",
|
||||
"SHELL",
|
||||
"LANG",
|
||||
"LC_ALL",
|
||||
"LC_CTYPE",
|
||||
"LC_MESSAGES",
|
||||
"LC_NUMERIC",
|
||||
"LC_TIME",
|
||||
"TMPDIR",
|
||||
"TMP",
|
||||
"TEMP",
|
||||
"TZ",
|
||||
"TERM",
|
||||
"COLORTERM",
|
||||
]);
|
||||
|
||||
export function _buildMcpChildEnvForTest(
|
||||
env: Record<string, string>,
|
||||
): Record<string, string> {
|
||||
const safe: Record<string, string> = {};
|
||||
for (const key of SAFE_CHILD_ENV_KEYS) {
|
||||
if (process.env[key] !== undefined) safe[key] = process.env[key]!;
|
||||
}
|
||||
return { ...safe, ...resolveEnv(env) };
|
||||
}
|
||||
|
||||
export function _buildMcpTrustConfirmOptionsForTest(signal: AbortSignal): {
|
||||
timeout: number;
|
||||
signal: AbortSignal;
|
||||
} {
|
||||
return { timeout: 120_000, signal };
|
||||
}
|
||||
|
||||
// ─── Extension ────────────────────────────────────────────────────────────────
|
||||
|
||||
/**
 * Extension entry point: registers the three generic MCP tools
 * (mcp_servers, mcp_discover, mcp_call) and wires session lifecycle hooks
 * that tear down connections and caches on shutdown/switch.
 */
export default function (pi: ExtensionAPI) {
  // ── mcp_servers ──────────────────────────────────────────────────────────

  pi.registerTool({
    name: "mcp_servers",
    label: "MCP Servers",
    description:
      "List all available MCP servers configured in project files (.mcp.json, .sf/mcp.json). " +
      "Shows server names, transport type, and connection status. After mcp_discover, each server's " +
      "tools are auto-registered as first-class pi tools (e.g. serena_find_symbol).",
    promptSnippet: "List available MCP servers from project configuration",
    promptGuidelines: [
      "Call mcp_servers to see what MCP servers are available before trying to use one.",
      "After mcp_discover(server), the server's tools appear as real pi tools.",
      "MCP servers provide external integrations (Twitter, Linear, Railway, etc.) via the Model Context Protocol.",
      "After listing, use mcp_discover(server) to get tool schemas, then mcp_call(server, tool, args) to invoke.",
    ],
    parameters: Type.Object({
      refresh: Type.Optional(
        Type.Boolean({
          description: "Force refresh the server list (default: use cache)",
        }),
      ),
    }),

    // Lists configured servers; `refresh: true` drops the config cache first.
    async execute(_id, params) {
      if (params.refresh) configCache = null;

      const servers = readConfigs();
      return {
        content: [{ type: "text", text: formatServerList(servers) }],
        details: {
          serverCount: servers.length,
          // NOTE: readConfigs() above may have repopulated configCache, so
          // `cached` reflects the post-read state, not the pre-call state.
          cached: !params.refresh && configCache !== null,
        },
      };
    },

    renderCall(args, theme) {
      let text = theme.fg("toolTitle", theme.bold("mcp_servers"));
      if (args.refresh) text += theme.fg("warning", " (refresh)");
      return new Text(text, 0, 0);
    },

    renderResult(result, { isPartial }, theme) {
      if (isPartial)
        return new Text(theme.fg("warning", "Reading MCP config..."), 0, 0);
      const d = result.details as { serverCount: number } | undefined;
      return new Text(
        theme.fg("success", `${d?.serverCount ?? 0} servers configured`),
        0,
        0,
      );
    },
  });

  // ── mcp_discover ─────────────────────────────────────────────────────────

  pi.registerTool({
    name: "mcp_discover",
    label: "MCP Discover",
    description:
      "Get detailed tool signatures and JSON schemas for a specific MCP server. " +
      "Connects to the server on first call (lazy connection). " +
      "After discovery, each MCP tool is auto-registered as a first-class pi tool " +
      "(e.g. serena_find_symbol) — the LLM can call them directly without mcp_call.",
    promptSnippet:
      "Discover MCP server tools and register them as first-class pi tools",
    promptGuidelines: [
      "Call mcp_discover(server) to connect to an MCP server and surface its tools.",
      "After discovery, the LLM sees each tool by its real name (e.g. serena_search_for_pattern).",
      "Call tools directly by their names instead of going through mcp_call.",
    ],
    parameters: Type.Object({
      server: Type.String({
        description:
          "MCP server name (from mcp_servers output), e.g. 'railway', 'twitter-mcp', 'linear'",
      }),
    }),

    async execute(_id, params, signal) {
      try {
        // Return cached tools if available — skips reconnecting entirely.
        const cached = toolCache.get(params.server);
        if (cached) {
          const text = formatToolList(params.server, cached);
          const truncation = truncateHead(text, {
            maxLines: DEFAULT_MAX_LINES,
            maxBytes: DEFAULT_MAX_BYTES,
          });
          let finalText = truncation.content;
          if (truncation.truncated) {
            finalText += `\n\n[Truncated: ${truncation.outputLines}/${truncation.totalLines} lines (${formatSize(truncation.outputBytes)} of ${formatSize(truncation.totalBytes)})]`;
          }
          return {
            content: [{ type: "text", text: finalText }],
            details: {
              server: params.server,
              toolCount: cached.length,
              cached: true,
            },
          };
        }

        // Lazy connect: the connection is established here, on first discover.
        const client = await getOrConnect(params.server, signal);
        const result = await client.listTools(undefined, {
          signal,
          timeout: 30000,
        });
        // Normalize the SDK's tool shape into our own McpToolSchema records.
        const tools: McpToolSchema[] = (result.tools ?? []).map((t) => ({
          name: t.name,
          description: t.description ?? "",
          inputSchema: t.inputSchema as Record<string, unknown> | undefined,
        }));
        toolCache.set(params.server, tools);

        // Auto-register each MCP tool as a first-class pi tool.
        // After this, the LLM sees e.g. serena_find_symbol directly instead
        // of going through the generic mcp_call indirection.
        registerMcpToolsForServer(pi, params.server, tools);

        const text = formatToolList(params.server, tools);
        const truncation = truncateHead(text, {
          maxLines: DEFAULT_MAX_LINES,
          maxBytes: DEFAULT_MAX_BYTES,
        });
        let finalText = truncation.content;
        if (truncation.truncated) {
          finalText += `\n\n[Truncated: ${truncation.outputLines}/${truncation.totalLines} lines (${formatSize(truncation.outputBytes)} of ${formatSize(truncation.totalBytes)})]`;
        }

        return {
          content: [{ type: "text", text: finalText }],
          details: {
            server: params.server,
            toolCount: tools.length,
            cached: false,
          },
        };
      } catch (err: unknown) {
        // Re-wrap with the server name so the failure is attributable.
        const msg = err instanceof Error ? err.message : String(err);
        throw new Error(
          `Failed to discover tools for "${params.server}": ${msg}`,
        );
      }
    },

    renderCall(args, theme) {
      let text = theme.fg("toolTitle", theme.bold("mcp_discover "));
      text += theme.fg("accent", args.server);
      return new Text(text, 0, 0);
    },

    renderResult(result, { isPartial }, theme) {
      if (isPartial)
        return new Text(theme.fg("warning", "Discovering tools..."), 0, 0);
      const d = result.details as
        | { server: string; toolCount: number }
        | undefined;
      return new Text(
        theme.fg("success", `${d?.toolCount ?? 0} tools`) +
          theme.fg("dim", ` · ${d?.server}`),
        0,
        0,
      );
    },
  });

  // ── mcp_call ─────────────────────────────────────────────────────────────

  pi.registerTool({
    name: "mcp_call",
    label: "MCP Call",
    description:
      "Call a tool on an MCP server. Provide the server name, tool name, and arguments. " +
      "Connects to the server on first call (lazy connection). " +
      "Use mcp_discover first to see available tools and their required arguments.",
    promptSnippet: "Call a tool on an MCP server",
    promptGuidelines: [
      "Always use mcp_discover first to understand the tool's parameters before calling mcp_call.",
      "Arguments are passed as a JSON object matching the tool's input schema.",
    ],
    parameters: Type.Object({
      server: Type.String({
        description: "MCP server name, e.g. 'railway', 'twitter-mcp'",
      }),
      tool: Type.String({
        description: "Tool name on that server, e.g. 'railway_list_projects'",
      }),
      args: Type.Optional(
        Type.Object(
          {},
          {
            additionalProperties: true,
            description:
              "Tool arguments as key-value pairs matching the tool's input schema",
          },
        ),
      ),
    }),

    async execute(_id, params, signal) {
      try {
        const client = await getOrConnect(params.server, signal);
        const result = await client.callTool(
          { name: params.tool, arguments: params.args ?? {} },
          undefined,
          { signal, timeout: 60000 },
        );

        // Serialize result content to text: text items pass through,
        // anything else is JSON-stringified.
        const contentItems = result.content as Array<{
          type: string;
          text?: string;
        }>;
        const raw = contentItems
          .map((c) => (c.type === "text" ? (c.text ?? "") : JSON.stringify(c)))
          .join("\n");

        const truncation = truncateHead(raw, {
          maxLines: DEFAULT_MAX_LINES,
          maxBytes: DEFAULT_MAX_BYTES,
        });
        let finalText = truncation.content;
        if (truncation.truncated) {
          finalText += `\n\n[Output truncated: ${truncation.outputLines}/${truncation.totalLines} lines (${formatSize(truncation.outputBytes)} of ${formatSize(truncation.totalBytes)})]`;
        }

        return {
          content: [{ type: "text", text: finalText }],
          details: {
            server: params.server,
            tool: params.tool,
            charCount: finalText.length,
            truncated: truncation.truncated,
          },
        };
      } catch (err: unknown) {
        const msg = err instanceof Error ? err.message : String(err);
        throw new Error(
          `MCP call failed: ${params.server}.${params.tool}\n${msg}`,
        );
      }
    },

    renderCall(args, theme) {
      let text = theme.fg("toolTitle", theme.bold("mcp_call "));
      text += theme.fg("accent", `${args.server}.${args.tool}`);
      // Preview up to 3 arguments, each value truncated to 30 chars.
      if (args.args && Object.keys(args.args).length > 0) {
        const preview = Object.entries(args.args)
          .slice(0, 3)
          .map(([k, v]) => {
            const val = typeof v === "string" ? v : JSON.stringify(v);
            return `${k}:${val.length > 30 ? val.slice(0, 30) + "…" : val}`;
          })
          .join(" ");
        text += " " + theme.fg("muted", preview);
      }
      return new Text(text, 0, 0);
    },

    renderResult(result, { isPartial, expanded }, theme) {
      if (isPartial)
        return new Text(theme.fg("warning", "Calling MCP tool..."), 0, 0);

      const d = result.details as
        | {
            server: string;
            tool: string;
            charCount: number;
            truncated: boolean;
          }
        | undefined;

      let text = theme.fg("success", `✓ ${d?.server}.${d?.tool}`);
      text += theme.fg(
        "dim",
        ` · ${(d?.charCount ?? 0).toLocaleString()} chars`,
      );
      if (d?.truncated) text += theme.fg("warning", " · truncated");

      // Expanded view shows the first 15 lines of the tool's text output.
      if (expanded) {
        const content = result.content[0];
        if (content?.type === "text") {
          const preview = content.text.split("\n").slice(0, 15).join("\n");
          text += "\n\n" + theme.fg("dim", preview);
        }
      }

      return new Text(text, 0, 0);
    },
  });

  // ── Lifecycle ─────────────────────────────────────────────────────────────

  pi.on("session_start", async (_event, ctx) => {
    const servers = readConfigs();
    if (servers.length > 0) {
      ctx.ui.notify(
        `MCP client ready — ${servers.length} server(s) configured`,
        "info",
      );
    }
  });

  pi.on("session_shutdown", async () => {
    await closeAll();
  });

  // Session switch: drop connections AND the config cache, since the new
  // session may point at a different project with different .mcp.json files.
  pi.on("session_switch", async () => {
    await closeAll();
    configCache = null;
  });
}
|
||||
|
|
@ -1,52 +0,0 @@
|
|||
/**
|
||||
* Regression test for #3029 — mcp_discover fails for server names with spaces.
|
||||
*
|
||||
* The getServerConfig lookup must handle:
|
||||
* 1. Exact match (already works)
|
||||
* 2. Names with leading/trailing whitespace (trimming)
|
||||
* 3. Case-insensitive matching (e.g. "Langgraph code" vs "langgraph Code")
|
||||
*
|
||||
* We test at the source level since getServerConfig is not exported.
|
||||
*/
|
||||
|
||||
import assert from "node:assert/strict";
|
||||
import { readFileSync } from "node:fs";
|
||||
import { dirname, join } from "node:path";
|
||||
import { test } from 'vitest';
|
||||
import { fileURLToPath } from "node:url";
|
||||
|
||||
const __filename = fileURLToPath(import.meta.url);
|
||||
const __dirname = dirname(__filename);
|
||||
|
||||
const source = readFileSync(join(__dirname, "..", "index.ts"), "utf-8");
|
||||
|
||||
test("#3029: getServerConfig trims whitespace from input name", () => {
|
||||
assert.ok(
|
||||
source.includes(".trim()"),
|
||||
"getServerConfig should trim the input name before comparison",
|
||||
);
|
||||
});
|
||||
|
||||
test("#3029: getServerConfig performs case-insensitive matching", () => {
|
||||
assert.ok(
|
||||
source.includes(".toLowerCase()"),
|
||||
"getServerConfig should compare names case-insensitively",
|
||||
);
|
||||
});
|
||||
|
||||
test("#3029: getOrConnect normalizes name for connection cache lookup", () => {
|
||||
// The connections Map key must use the canonical (config) name, not the
|
||||
// raw user input, so that subsequent lookups hit the cache even when the
|
||||
// user's casing differs.
|
||||
const getOrConnectMatch = source.match(
|
||||
/async function getOrConnect\(\s*name:\s*string[\s\S]*?const existing = connections\.get\(/,
|
||||
);
|
||||
assert.ok(getOrConnectMatch, "getOrConnect function should exist");
|
||||
// After the fix, getOrConnect should normalize the name via getServerConfig
|
||||
// or use config.name as the canonical cache key.
|
||||
assert.ok(
|
||||
source.includes("connections.get(config.name") ||
|
||||
source.includes("connections.set(config.name"),
|
||||
"getOrConnect should use config.name (canonical) as the connections cache key",
|
||||
);
|
||||
});
|
||||
|
|
@ -1,168 +0,0 @@
|
|||
// sf — Ollama Extension: First-class local LLM support
|
||||
/**
|
||||
* Ollama Extension
|
||||
*
|
||||
* Auto-detects a running Ollama instance, discovers locally pulled models,
|
||||
* and registers them as a first-class provider. No configuration required —
|
||||
* if Ollama is running, models appear automatically.
|
||||
*
|
||||
* Features:
|
||||
* - Auto-discovery of local models via /api/tags
|
||||
* - Capability detection (vision, reasoning, context window)
|
||||
* - /ollama slash commands for model management
|
||||
* - ollama_manage tool for LLM-driven model operations
|
||||
* - Zero-cost model registration (local inference)
|
||||
*
|
||||
* Respects OLLAMA_HOST env var for non-default endpoints.
|
||||
*/
|
||||
|
||||
import {
|
||||
type ExtensionAPI,
|
||||
importExtensionModule,
|
||||
} from "@singularity-forge/pi-coding-agent";
|
||||
import { streamOllamaChat } from "./ollama-chat-provider.js";
|
||||
import * as client from "./ollama-client.js";
|
||||
import { registerOllamaCommands } from "./ollama-commands.js";
|
||||
import { discoverModels } from "./ollama-discovery.js";
|
||||
|
||||
let toolsPromise: Promise<void> | null = null;
|
||||
|
||||
async function registerOllamaTools(pi: ExtensionAPI): Promise<void> {
|
||||
if (!toolsPromise) {
|
||||
toolsPromise = (async () => {
|
||||
const { registerOllamaTool } = await importExtensionModule<
|
||||
typeof import("./ollama-tool.js")
|
||||
>(import.meta.url, "./ollama-tool.js");
|
||||
registerOllamaTool(pi);
|
||||
})().catch((error) => {
|
||||
toolsPromise = null;
|
||||
throw error;
|
||||
});
|
||||
}
|
||||
return toolsPromise;
|
||||
}
|
||||
|
||||
/**
 * Track whether we've registered models so we can clean up on shutdown.
 * Set by probeAndRegister() on success; cleared there when Ollama goes
 * away and by the session_shutdown handler.
 */
let providerRegistered = false;
|
||||
|
||||
/**
|
||||
* Opt-in check: skip the probe entirely unless OLLAMA_HOST is explicitly set.
|
||||
*
|
||||
* Rationale: the historical behavior was to probe http://localhost:11434 on
|
||||
* every startup, which produced startup cost and a "[phase] ollama" status
|
||||
* indicator even for users who have never run Ollama locally and never will.
|
||||
* Making the probe opt-in means:
|
||||
* - No-op for users who don't use Ollama (the vast majority).
|
||||
* - Works for ollama-cloud: set OLLAMA_HOST=https://ollama.com and
|
||||
* OLLAMA_API_KEY and the existing discovery/register path runs unchanged.
|
||||
* - Works for self-hosted local Ollama: set OLLAMA_HOST=http://localhost:11434
|
||||
* explicitly to re-enable the old behavior.
|
||||
*/
|
||||
function isOllamaConfigured(): boolean {
|
||||
const host = process.env.OLLAMA_HOST;
|
||||
return typeof host === "string" && host.trim().length > 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Probe Ollama and register discovered models.
|
||||
* Safe to call multiple times — re-discovers and re-registers.
|
||||
*/
|
||||
async function probeAndRegister(pi: ExtensionAPI): Promise<boolean> {
|
||||
if (!isOllamaConfigured()) return false;
|
||||
const running = await client.isRunning();
|
||||
if (!running) {
|
||||
if (providerRegistered) {
|
||||
pi.unregisterProvider("ollama");
|
||||
providerRegistered = false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const models = await discoverModels();
|
||||
if (models.length === 0) {
|
||||
// No local models means there's nothing usable to register in SF.
|
||||
// Keep the footer/status clean instead of advertising Ollama availability.
|
||||
if (providerRegistered) {
|
||||
pi.unregisterProvider("ollama");
|
||||
providerRegistered = false;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
const baseUrl = client.getOllamaHost();
|
||||
|
||||
// Use authMode "apiKey" (#3440). Local Ollama ignores the Authorization header,
|
||||
// so the "ollama" fallback is harmless. For cloud endpoints (OLLAMA_HOST pointing
|
||||
// to ollama.com or a remote instance), OLLAMA_API_KEY is picked up here.
|
||||
pi.registerProvider("ollama", {
|
||||
authMode: "apiKey",
|
||||
apiKey: process.env.OLLAMA_API_KEY ?? "ollama",
|
||||
baseUrl,
|
||||
api: "ollama-chat",
|
||||
streamSimple: streamOllamaChat,
|
||||
isReady: () => true,
|
||||
models: models.map((m) => ({
|
||||
id: m.id,
|
||||
name: m.name,
|
||||
reasoning: m.reasoning,
|
||||
input: m.input,
|
||||
cost: m.cost,
|
||||
contextWindow: m.contextWindow,
|
||||
maxTokens: m.maxTokens,
|
||||
providerOptions: (m.ollamaOptions ?? {}) as Record<string, unknown>,
|
||||
})),
|
||||
});
|
||||
|
||||
providerRegistered = true;
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
 * Ollama extension entry point. No-op unless OLLAMA_HOST is set; otherwise
 * registers slash commands immediately and defers tool registration and the
 * server probe to session_start.
 */
export default function ollama(pi: ExtensionAPI) {
  // Opt-in: skip all registration if OLLAMA_HOST is not configured.
  // See isOllamaConfigured() for rationale.
  if (!isOllamaConfigured()) return;

  // Register slash commands immediately (they check Ollama availability themselves)
  registerOllamaCommands(pi);

  pi.on("session_start", async (_event, ctx) => {
    // Register tool (deferred to avoid blocking startup)
    if (ctx.hasUI) {
      // Interactive: fire-and-forget, surface failure as a UI warning.
      void registerOllamaTools(pi).catch((error) => {
        ctx.ui.notify(
          `Ollama tool failed to load: ${error instanceof Error ? error.message : String(error)}`,
          "warning",
        );
      });
    } else {
      // Headless: await so the tool is definitely available.
      await registerOllamaTools(pi);
    }

    // In headless/auto mode, await the probe so the fallback resolver can
    // see Ollama before the first LLM call (#3531 race condition).
    // In interactive mode, keep it async for fast startup.
    if (!ctx.hasUI) {
      try {
        await probeAndRegister(pi);
      } catch {
        /* non-fatal */
      }
    } else {
      probeAndRegister(pi)
        .then((found) => {
          // Show "Ollama" in the status bar only when models were registered.
          ctx.ui.setStatus("ollama", found ? "Ollama" : undefined);
        })
        .catch(() => {
          ctx.ui.setStatus("ollama", undefined);
        });
    }
  });

  pi.on("session_shutdown", async () => {
    // Mirror the cleanup probeAndRegister does when Ollama disappears,
    // and drop the memoized tool registration so a new session re-imports.
    if (providerRegistered) {
      pi.unregisterProvider("ollama");
      providerRegistered = false;
    }
    toolsPromise = null;
  });
}
|
||||
|
|
@ -1,374 +0,0 @@
|
|||
// sf — Known model capability table for Ollama models
|
||||
|
||||
/**
|
||||
* Maps well-known Ollama model families to their capabilities.
|
||||
* Used to enrich auto-discovered models with accurate context windows,
|
||||
* vision support, and reasoning detection.
|
||||
*
|
||||
* Fallback: estimate from parameter count if model isn't in the table.
|
||||
*/
|
||||
|
||||
import type { OllamaChatOptions } from "./types.js";
|
||||
|
||||
/**
 * Capability metadata for a known Ollama model family. All fields are
 * optional; callers fall back to estimates when a field is absent (see
 * estimateContextFromParams).
 */
export interface ModelCapability {
  // Maximum context window in tokens, when known for the family.
  contextWindow?: number;
  // Maximum output tokens, when known.
  maxTokens?: number;
  // Supported input modalities (e.g. ["text", "image"] for vision models).
  input?: ("text" | "image")[];
  // True for reasoning model families (e.g. deepseek-r1, qwq).
  reasoning?: boolean;
  /** Ollama-specific default inference options for this model family. */
  ollamaOptions?: OllamaChatOptions;
}
|
||||
|
||||
/**
|
||||
* Known model family capabilities.
|
||||
* Keys are matched as prefixes against the model name (before the colon/tag).
|
||||
* More specific entries should appear first.
|
||||
*/
|
||||
// Note: ollamaOptions.num_ctx is set for known model families where the context
|
||||
// window is authoritative. For unknown/estimated models, num_ctx is NOT sent
|
||||
// to avoid OOM risk — Ollama uses its own safe default instead.
|
||||
const KNOWN_MODELS: Array<[pattern: string, caps: ModelCapability]> = [
|
||||
// ─── Reasoning models ───────────────────────────────────────────────
|
||||
[
|
||||
"deepseek-r1",
|
||||
{
|
||||
contextWindow: 131072,
|
||||
reasoning: true,
|
||||
ollamaOptions: { num_ctx: 131072 },
|
||||
},
|
||||
],
|
||||
[
|
||||
"qwq",
|
||||
{
|
||||
contextWindow: 131072,
|
||||
reasoning: true,
|
||||
ollamaOptions: { num_ctx: 131072 },
|
||||
},
|
||||
],
|
||||
|
||||
// ─── Vision models ──────────────────────────────────────────────────
|
||||
[
|
||||
"llava",
|
||||
{
|
||||
contextWindow: 4096,
|
||||
input: ["text", "image"],
|
||||
ollamaOptions: { num_ctx: 4096 },
|
||||
},
|
||||
],
|
||||
[
|
||||
"bakllava",
|
||||
{
|
||||
contextWindow: 4096,
|
||||
input: ["text", "image"],
|
||||
ollamaOptions: { num_ctx: 4096 },
|
||||
},
|
||||
],
|
||||
[
|
||||
"moondream",
|
||||
{
|
||||
contextWindow: 8192,
|
||||
input: ["text", "image"],
|
||||
ollamaOptions: { num_ctx: 8192 },
|
||||
},
|
||||
],
|
||||
[
|
||||
"llama3.2-vision",
|
||||
{
|
||||
contextWindow: 131072,
|
||||
input: ["text", "image"],
|
||||
ollamaOptions: { num_ctx: 131072 },
|
||||
},
|
||||
],
|
||||
[
|
||||
"minicpm-v",
|
||||
{
|
||||
contextWindow: 4096,
|
||||
input: ["text", "image"],
|
||||
ollamaOptions: { num_ctx: 4096 },
|
||||
},
|
||||
],
|
||||
|
||||
// ─── Code models ────────────────────────────────────────────────────
|
||||
[
|
||||
"codestral",
|
||||
{
|
||||
contextWindow: 262144,
|
||||
maxTokens: 32768,
|
||||
ollamaOptions: { num_ctx: 262144 },
|
||||
},
|
||||
],
|
||||
[
|
||||
"qwen2.5-coder",
|
||||
{
|
||||
contextWindow: 131072,
|
||||
maxTokens: 32768,
|
||||
ollamaOptions: { num_ctx: 131072 },
|
||||
},
|
||||
],
|
||||
[
|
||||
"deepseek-coder-v2",
|
||||
{
|
||||
contextWindow: 131072,
|
||||
maxTokens: 16384,
|
||||
ollamaOptions: { num_ctx: 131072 },
|
||||
},
|
||||
],
|
||||
[
|
||||
"starcoder2",
|
||||
{
|
||||
contextWindow: 16384,
|
||||
maxTokens: 8192,
|
||||
ollamaOptions: { num_ctx: 16384 },
|
||||
},
|
||||
],
|
||||
[
|
||||
"codegemma",
|
||||
{ contextWindow: 8192, maxTokens: 8192, ollamaOptions: { num_ctx: 8192 } },
|
||||
],
|
||||
[
|
||||
"codellama",
|
||||
{
|
||||
contextWindow: 16384,
|
||||
maxTokens: 8192,
|
||||
ollamaOptions: { num_ctx: 16384 },
|
||||
},
|
||||
],
|
||||
[
|
||||
"devstral",
|
||||
{
|
||||
contextWindow: 131072,
|
||||
maxTokens: 32768,
|
||||
ollamaOptions: { num_ctx: 131072 },
|
||||
},
|
||||
],
|
||||
|
||||
// ─── Llama family ───────────────────────────────────────────────────
|
||||
[
|
||||
"llama3.3",
|
||||
{
|
||||
contextWindow: 131072,
|
||||
maxTokens: 16384,
|
||||
ollamaOptions: { num_ctx: 131072 },
|
||||
},
|
||||
],
|
||||
[
|
||||
"llama3.2",
|
||||
{
|
||||
contextWindow: 131072,
|
||||
maxTokens: 16384,
|
||||
ollamaOptions: { num_ctx: 131072 },
|
||||
},
|
||||
],
|
||||
[
|
||||
"llama3.1",
|
||||
{
|
||||
contextWindow: 131072,
|
||||
maxTokens: 16384,
|
||||
ollamaOptions: { num_ctx: 131072 },
|
||||
},
|
||||
],
|
||||
[
|
||||
"llama3",
|
||||
{ contextWindow: 8192, maxTokens: 8192, ollamaOptions: { num_ctx: 8192 } },
|
||||
],
|
||||
[
|
||||
"llama2",
|
||||
{ contextWindow: 4096, maxTokens: 4096, ollamaOptions: { num_ctx: 4096 } },
|
||||
],
|
||||
|
||||
// ─── Qwen family ────────────────────────────────────────────────────
|
||||
[
|
||||
"qwen3",
|
||||
{
|
||||
contextWindow: 131072,
|
||||
maxTokens: 32768,
|
||||
ollamaOptions: { num_ctx: 131072 },
|
||||
},
|
||||
],
|
||||
[
|
||||
"qwen2.5",
|
||||
{
|
||||
contextWindow: 131072,
|
||||
maxTokens: 32768,
|
||||
ollamaOptions: { num_ctx: 131072 },
|
||||
},
|
||||
],
|
||||
[
|
||||
"qwen2",
|
||||
{
|
||||
contextWindow: 131072,
|
||||
maxTokens: 32768,
|
||||
ollamaOptions: { num_ctx: 131072 },
|
||||
},
|
||||
],
|
||||
|
||||
// ─── Gemma family ───────────────────────────────────────────────────
|
||||
[
|
||||
"gemma3",
|
||||
{
|
||||
contextWindow: 131072,
|
||||
maxTokens: 16384,
|
||||
ollamaOptions: { num_ctx: 131072 },
|
||||
},
|
||||
],
|
||||
[
|
||||
"gemma2",
|
||||
{ contextWindow: 8192, maxTokens: 8192, ollamaOptions: { num_ctx: 8192 } },
|
||||
],
|
||||
|
||||
// ─── Mistral family ─────────────────────────────────────────────────
|
||||
[
|
||||
"mistral-large",
|
||||
{
|
||||
contextWindow: 131072,
|
||||
maxTokens: 16384,
|
||||
ollamaOptions: { num_ctx: 131072 },
|
||||
},
|
||||
],
|
||||
[
|
||||
"mistral-small",
|
||||
{
|
||||
contextWindow: 131072,
|
||||
maxTokens: 16384,
|
||||
ollamaOptions: { num_ctx: 131072 },
|
||||
},
|
||||
],
|
||||
[
|
||||
"mistral-nemo",
|
||||
{
|
||||
contextWindow: 131072,
|
||||
maxTokens: 16384,
|
||||
ollamaOptions: { num_ctx: 131072 },
|
||||
},
|
||||
],
|
||||
[
|
||||
"mistral",
|
||||
{
|
||||
contextWindow: 32768,
|
||||
maxTokens: 8192,
|
||||
ollamaOptions: { num_ctx: 32768 },
|
||||
},
|
||||
],
|
||||
[
|
||||
"mixtral",
|
||||
{
|
||||
contextWindow: 32768,
|
||||
maxTokens: 8192,
|
||||
ollamaOptions: { num_ctx: 32768 },
|
||||
},
|
||||
],
|
||||
|
||||
// ─── Phi family ─────────────────────────────────────────────────────
|
||||
[
|
||||
"phi4",
|
||||
{
|
||||
contextWindow: 16384,
|
||||
maxTokens: 16384,
|
||||
ollamaOptions: { num_ctx: 16384 },
|
||||
},
|
||||
],
|
||||
[
|
||||
"phi3.5",
|
||||
{
|
||||
contextWindow: 131072,
|
||||
maxTokens: 16384,
|
||||
ollamaOptions: { num_ctx: 131072 },
|
||||
},
|
||||
],
|
||||
[
|
||||
"phi3",
|
||||
{
|
||||
contextWindow: 131072,
|
||||
maxTokens: 4096,
|
||||
ollamaOptions: { num_ctx: 131072 },
|
||||
},
|
||||
],
|
||||
|
||||
// ─── Command R ──────────────────────────────────────────────────────
|
||||
[
|
||||
"command-r-plus",
|
||||
{
|
||||
contextWindow: 131072,
|
||||
maxTokens: 16384,
|
||||
ollamaOptions: { num_ctx: 131072 },
|
||||
},
|
||||
],
|
||||
[
|
||||
"command-r",
|
||||
{
|
||||
contextWindow: 131072,
|
||||
maxTokens: 16384,
|
||||
ollamaOptions: { num_ctx: 131072 },
|
||||
},
|
||||
],
|
||||
];
|
||||
|
||||
/**
|
||||
* Look up capabilities for a model by name.
|
||||
* Matches the longest prefix from the known models table.
|
||||
*/
|
||||
export function getModelCapabilities(modelName: string): ModelCapability {
|
||||
// Strip tag (everything after the colon) for matching
|
||||
const baseName = modelName.split(":")[0].toLowerCase();
|
||||
|
||||
for (const [pattern, caps] of KNOWN_MODELS) {
|
||||
if (baseName === pattern || baseName.startsWith(pattern)) {
|
||||
return caps;
|
||||
}
|
||||
}
|
||||
|
||||
return {};
|
||||
}
|
||||
|
||||
/**
|
||||
* Estimate context window from parameter size string (e.g. "7B", "70B", "1.5B").
|
||||
* Used as fallback when model isn't in the known table.
|
||||
*/
|
||||
export function estimateContextFromParams(parameterSize: string): number {
|
||||
const match = parameterSize.match(/([\d.]+)\s*([BbMm])/);
|
||||
if (!match) return 8192;
|
||||
|
||||
const size = parseFloat(match[1]);
|
||||
const unit = match[2].toUpperCase();
|
||||
|
||||
// Convert to billions
|
||||
const billions = unit === "M" ? size / 1000 : size;
|
||||
|
||||
// Rough heuristics: larger models tend to support larger contexts
|
||||
if (billions >= 70) return 131072;
|
||||
if (billions >= 30) return 65536;
|
||||
if (billions >= 13) return 32768;
|
||||
if (billions >= 7) return 16384;
|
||||
return 8192;
|
||||
}
|
||||
|
||||
/**
|
||||
* Humanize a model name for display (e.g. "llama3.1:8b" → "Llama 3.1 8B").
|
||||
*/
|
||||
export function humanizeModelName(modelName: string): string {
|
||||
const [base, tag] = modelName.split(":");
|
||||
|
||||
// Capitalize first letter, add spaces around version numbers
|
||||
let name = base
|
||||
.replace(/([a-z])(\d)/g, "$1 $2")
|
||||
.replace(/(\d)([a-z])/g, "$1 $2")
|
||||
.replace(/^./, (c) => c.toUpperCase());
|
||||
|
||||
// Clean up common patterns
|
||||
name = name.replace(/\s*-\s*/g, " ");
|
||||
|
||||
if (tag && tag !== "latest") {
|
||||
name += ` ${tag.toUpperCase()}`;
|
||||
}
|
||||
|
||||
return name;
|
||||
}
|
||||
|
||||
/**
|
||||
* Format byte size for display (e.g. 4700000000 → "4.7 GB").
|
||||
*/
|
||||
export function formatModelSize(bytes: number): string {
|
||||
if (bytes >= 1e9) return `${(bytes / 1e9).toFixed(1)} GB`;
|
||||
if (bytes >= 1e6) return `${(bytes / 1e6).toFixed(1)} MB`;
|
||||
return `${(bytes / 1e3).toFixed(0)} KB`;
|
||||
}
|
||||
|
|
@ -1,63 +0,0 @@
|
|||
// sf — Ollama Extension: NDJSON streaming parser
|
||||
|
||||
/**
 * Parses a streaming NDJSON (newline-delimited JSON) response body into
 * typed objects. Used for Ollama's /api/chat and /api/pull endpoints.
 *
 * @param body The raw byte stream (e.g. a fetch Response.body).
 * @param signal Optional abort signal; when it fires, reading stops and any
 * partial trailing line is discarded rather than parsed.
 * @param strict When true, malformed JSON lines throw instead of being skipped.
 * Use strict mode for inference streams where silent data loss is unacceptable.
 * Use permissive mode (default) for progress endpoints like /api/pull.
 */

export async function* parseNDJsonStream<T>(
  body: ReadableStream<Uint8Array>,
  signal?: AbortSignal,
  strict = false,
): AsyncGenerator<T> {
  const reader = body.getReader();
  // stream:true keeps multi-byte UTF-8 sequences split across reads intact.
  const decoder = new TextDecoder();
  let buffer = "";

  try {
    while (true) {
      // Checked once per read so an abort stops us between chunks.
      if (signal?.aborted) break;

      const { done, value } = await reader.read();
      if (done) break;

      buffer += decoder.decode(value, { stream: true });
      const lines = buffer.split("\n");
      // The last element is an (often empty) incomplete line — carry it over.
      buffer = lines.pop() ?? "";

      for (const line of lines) {
        const trimmed = line.trim();
        if (!trimmed) continue;
        try {
          yield JSON.parse(trimmed) as T;
        } catch (_err) {
          if (strict) {
            throw new Error(
              `Malformed NDJSON line from Ollama: ${trimmed.slice(0, 200)}`,
            );
          }
          // Permissive mode: skip malformed lines
        }
      }
    }

    // Flush remaining buffer (skip if aborted)
    if (buffer.trim() && !signal?.aborted) {
      try {
        yield JSON.parse(buffer.trim()) as T;
      } catch (_err) {
        if (strict) {
          throw new Error(
            `Malformed NDJSON line from Ollama: ${buffer.trim().slice(0, 200)}`,
          );
        }
      }
    }
  } finally {
    // Only releases our lock; cancelling the underlying stream is left to the
    // fetch abort signal — NOTE(review): confirm callers always pass the same
    // signal to fetch, otherwise the connection may linger after a break.
    reader.releaseLock();
  }
}
|
||||
|
|
@ -1,21 +0,0 @@
|
|||
/**
 * Regression test for #3440: Ollama extension must register with
 * authMode "apiKey" (not "none") to avoid streamSimple requirement.
 */

import assert from "node:assert/strict";
import { readFileSync } from "node:fs";
import { dirname, join } from "node:path";
import { test } from 'vitest';
import { fileURLToPath } from "node:url";

// ESM has no __dirname; reconstruct it from the module URL.
const __dirname = dirname(fileURLToPath(import.meta.url));

// Source-scrape check: reads index.ts as text rather than importing it, so
// the test cannot be fooled by runtime mocking of registerProvider.
test("Ollama registers with authMode apiKey, not none (#3440)", () => {
  const src = readFileSync(join(__dirname, "index.ts"), "utf-8");
  // Find the registerProvider call
  // NOTE(review): if the call site is absent, indexOf returns -1 and slice(-1)
  // yields only the last character — the test still fails, but with the
  // misleading "must specify authMode" message; consider asserting the index.
  const registerBlock = src.slice(src.indexOf('pi.registerProvider("ollama"'));
  // First authMode after the call site; assumes it belongs to this provider.
  const authLine = registerBlock.match(/authMode:\s*"(\w+)"/);
  assert.ok(authLine, "registerProvider must specify authMode");
  assert.equal(authLine![1], "apiKey", "authMode must be apiKey, not none");
});
|
||||
|
|
@ -1,506 +0,0 @@
|
|||
// sf — Ollama Extension: Native /api/chat stream provider
|
||||
|
||||
/**
|
||||
* Implements the "ollama-chat" API provider, streaming responses directly
|
||||
* from Ollama's native /api/chat endpoint instead of the OpenAI compatibility
|
||||
* shim. This exposes Ollama-specific options (num_ctx, keep_alive, num_gpu,
|
||||
* sampling parameters) and surfaces inference performance metrics.
|
||||
*/
|
||||
|
||||
import {
|
||||
type Api,
|
||||
type AssistantMessage,
|
||||
type AssistantMessageEvent,
|
||||
type AssistantMessageEventStream,
|
||||
type Context,
|
||||
EventStream,
|
||||
type ImageContent,
|
||||
type InferenceMetrics,
|
||||
type Message,
|
||||
type Model,
|
||||
type SimpleStreamOptions,
|
||||
type StopReason,
|
||||
type TextContent,
|
||||
type ThinkingContent,
|
||||
type Tool,
|
||||
type ToolCall,
|
||||
type Usage,
|
||||
} from "@singularity-forge/pi-ai";
|
||||
import { chat } from "./ollama-client.js";
|
||||
import { type ParsedChunk, ThinkingTagParser } from "./thinking-parser.js";
|
||||
import type {
|
||||
OllamaChatMessage,
|
||||
OllamaChatOptions,
|
||||
OllamaChatRequest,
|
||||
OllamaChatResponse,
|
||||
OllamaTool,
|
||||
OllamaToolCall,
|
||||
} from "./types.js";
|
||||
|
||||
/** Create an AssistantMessageEventStream using the base EventStream class. */
|
||||
function createStream(): AssistantMessageEventStream {
|
||||
return new EventStream<AssistantMessageEvent, AssistantMessage>(
|
||||
(event) => event.type === "done" || event.type === "error",
|
||||
(event) => {
|
||||
if (event.type === "done") return event.message;
|
||||
if (event.type === "error") return event.error;
|
||||
throw new Error("Unexpected event type for final result");
|
||||
},
|
||||
) as AssistantMessageEventStream;
|
||||
}
|
||||
|
||||
// ─── Stream handler ─────────────────────────────────────────────────────────
|
||||
|
||||
/**
 * Stream an assistant turn from Ollama's native /api/chat endpoint.
 *
 * Returns immediately with an event stream; an async IIFE drives the HTTP
 * stream in the background, pushing start/delta/end events per content block
 * and a terminal done/error event. Text and <think> content are emitted as
 * separate blocks; tool calls arrive complete (not incrementally).
 *
 * @param model Model descriptor (id, reasoning flag, providerOptions).
 * @param context System prompt, message history, and tools to send.
 * @param options Optional per-call overrides (maxTokens, temperature, signal).
 */
export function streamOllamaChat(
  model: Model<Api>,
  context: Context,
  options?: SimpleStreamOptions,
): AssistantMessageEventStream {
  const stream = createStream();

  (async () => {
    // Accumulates everything; every event carries it as `partial`.
    const output = buildInitialOutput(model);

    try {
      const request = buildRequest(model, context, options);
      stream.push({ type: "start", partial: output });

      // Only reasoning models get <think> tag splitting of the text stream.
      const useThinkingParser = model.reasoning;
      const thinkParser = useThinkingParser ? new ThinkingTagParser() : null;

      // Index of the block currently being appended to in output.content.
      let contentIndex = -1;
      // Open block kind, or null when no block is open.
      let currentBlockType: "text" | "thinking" | null = null;

      // Open a new content block and emit its *_start event.
      function startBlock(type: "text" | "thinking") {
        contentIndex++;
        currentBlockType = type;
        if (type === "text") {
          output.content.push({ type: "text", text: "" });
          stream.push({ type: "text_start", contentIndex, partial: output });
        } else {
          output.content.push({ type: "thinking", thinking: "" });
          stream.push({
            type: "thinking_start",
            contentIndex,
            partial: output,
          });
        }
      }

      // Close the open block (if any) and emit its *_end event with the
      // accumulated content. Safe to call when nothing is open.
      function endBlock() {
        if (currentBlockType === null) return;
        if (currentBlockType === "text") {
          const block = output.content[contentIndex] as TextContent;
          stream.push({
            type: "text_end",
            contentIndex,
            content: block.text,
            partial: output,
          });
        } else {
          const block = output.content[contentIndex] as ThinkingContent;
          stream.push({
            type: "thinking_end",
            contentIndex,
            content: block.thinking,
            partial: output,
          });
        }
        currentBlockType = null;
      }

      // Append text to the current block, transparently switching block
      // kind (closing + opening) when the delta type changes.
      function emitDelta(type: "text" | "thinking", text: string) {
        if (!text) return;
        if (currentBlockType !== type) {
          endBlock();
          startBlock(type);
        }
        if (type === "text") {
          (output.content[contentIndex] as TextContent).text += text;
          stream.push({
            type: "text_delta",
            contentIndex,
            delta: text,
            partial: output,
          });
        } else {
          (output.content[contentIndex] as ThinkingContent).thinking += text;
          stream.push({
            type: "thinking_delta",
            contentIndex,
            delta: text,
            partial: output,
          });
        }
      }

      // Forward parsed thinking/text fragments from the tag parser.
      function processChunks(chunks: ParsedChunk[]) {
        for (const chunk of chunks) {
          emitDelta(chunk.type, chunk.text);
        }
      }

      // Tool calls arrive fully-formed, so each one gets the conventional
      // start/delta/end triple in a single pass.
      function processToolCalls(toolCalls: OllamaToolCall[]) {
        endBlock();
        for (const tc of toolCalls) {
          contentIndex++;
          const toolCall: ToolCall = {
            type: "toolCall",
            // Ollama assigns no call ids; synthesize one from the block index.
            id: `ollama_tc_${contentIndex}`,
            name: tc.function.name,
            arguments: tc.function.arguments,
          };
          output.content.push(toolCall);
          stream.push({
            type: "toolcall_start",
            contentIndex,
            partial: output,
          });
          // Emit a delta with the serialized arguments (convention: start/delta/end)
          stream.push({
            type: "toolcall_delta",
            contentIndex,
            delta: JSON.stringify(tc.function.arguments),
            partial: output,
          });
          stream.push({
            type: "toolcall_end",
            contentIndex,
            toolCall,
            partial: output,
          });
        }
        output.stopReason = "toolUse";
      }

      for await (const chunk of chat(request, options?.signal)) {
        // Handle text content — process independently of tool_calls
        // (a chunk may contain both content and tool_calls)
        const content = chunk.message?.content ?? "";
        if (content) {
          if (thinkParser) {
            processChunks(thinkParser.push(content));
          } else {
            emitDelta("text", content);
          }
        }

        // Handle tool calls (Ollama sends them complete, may be on done:true chunk)
        if (chunk.message?.tool_calls?.length) {
          processToolCalls(chunk.message.tool_calls);
        }

        if (chunk.done) {
          // Final chunk — extract metrics and usage
          if (thinkParser) processChunks(thinkParser.flush());
          endBlock();

          output.usage = buildUsage(chunk);
          output.inferenceMetrics = extractMetrics(chunk);
          // Preserve "toolUse" if tool calls were processed
          if (output.stopReason !== "toolUse") {
            output.stopReason = mapStopReason(chunk.done_reason);
          }
          break;
        }
      }

      // Throws on abort/error so the catch path emits the error event.
      assertStreamSuccess(output, options?.signal);
      finalizeStream(stream, output);
    } catch (error) {
      handleStreamError(stream, output, error, options?.signal);
    }
  })();

  return stream;
}
|
||||
|
||||
// ─── Request building ───────────────────────────────────────────────────────
|
||||
|
||||
function buildRequest(
|
||||
model: Model<Api>,
|
||||
context: Context,
|
||||
options?: SimpleStreamOptions,
|
||||
): OllamaChatRequest {
|
||||
const ollamaOpts = (model.providerOptions ?? {}) as OllamaChatOptions;
|
||||
|
||||
const request: OllamaChatRequest = {
|
||||
model: model.id,
|
||||
messages: convertMessages(context),
|
||||
stream: true,
|
||||
};
|
||||
|
||||
// Build options block with all Ollama-specific parameters
|
||||
const reqOptions: NonNullable<OllamaChatRequest["options"]> = {};
|
||||
|
||||
// Context window — only sent when explicitly configured via providerOptions.
|
||||
// Sending inferred/estimated values risks OOM on constrained hosts.
|
||||
// Users can set num_ctx per-model in models.json ollamaOptions or the
|
||||
// capability table can provide it for known model families.
|
||||
if (ollamaOpts.num_ctx !== undefined && ollamaOpts.num_ctx > 0) {
|
||||
reqOptions.num_ctx = ollamaOpts.num_ctx;
|
||||
}
|
||||
|
||||
// Max output tokens
|
||||
const maxTokens = options?.maxTokens ?? model.maxTokens;
|
||||
if (maxTokens > 0) {
|
||||
reqOptions.num_predict = maxTokens;
|
||||
}
|
||||
|
||||
// Temperature
|
||||
if (options?.temperature !== undefined) {
|
||||
reqOptions.temperature = options.temperature;
|
||||
}
|
||||
|
||||
// Per-model sampling options from providerOptions
|
||||
if (ollamaOpts.top_p !== undefined) reqOptions.top_p = ollamaOpts.top_p;
|
||||
if (ollamaOpts.top_k !== undefined) reqOptions.top_k = ollamaOpts.top_k;
|
||||
if (ollamaOpts.repeat_penalty !== undefined)
|
||||
reqOptions.repeat_penalty = ollamaOpts.repeat_penalty;
|
||||
if (ollamaOpts.seed !== undefined) reqOptions.seed = ollamaOpts.seed;
|
||||
if (ollamaOpts.num_gpu !== undefined) reqOptions.num_gpu = ollamaOpts.num_gpu;
|
||||
|
||||
if (Object.keys(reqOptions).length > 0) {
|
||||
request.options = reqOptions;
|
||||
}
|
||||
|
||||
// Keep alive
|
||||
if (ollamaOpts.keep_alive !== undefined) {
|
||||
request.keep_alive = ollamaOpts.keep_alive;
|
||||
}
|
||||
|
||||
// Tools
|
||||
if (context.tools?.length) {
|
||||
request.tools = convertTools(context.tools);
|
||||
}
|
||||
|
||||
return request;
|
||||
}
|
||||
|
||||
// ─── Message conversion ─────────────────────────────────────────────────────
|
||||
|
||||
function convertMessages(context: Context): OllamaChatMessage[] {
|
||||
const messages: OllamaChatMessage[] = [];
|
||||
|
||||
// System prompt
|
||||
if (context.systemPrompt) {
|
||||
messages.push({ role: "system", content: context.systemPrompt });
|
||||
}
|
||||
|
||||
for (const msg of context.messages) {
|
||||
switch (msg.role) {
|
||||
case "user":
|
||||
messages.push(convertUserMessage(msg));
|
||||
break;
|
||||
case "assistant":
|
||||
messages.push(convertAssistantMessage(msg));
|
||||
break;
|
||||
case "toolResult":
|
||||
messages.push({
|
||||
role: "tool",
|
||||
content: msg.content
|
||||
.filter((c): c is TextContent => c.type === "text")
|
||||
.map((c) => c.text)
|
||||
.join("\n"),
|
||||
name: msg.toolName,
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return messages;
|
||||
}
|
||||
|
||||
function convertUserMessage(
|
||||
msg: Message & { role: "user" },
|
||||
): OllamaChatMessage {
|
||||
if (typeof msg.content === "string") {
|
||||
return { role: "user", content: msg.content };
|
||||
}
|
||||
|
||||
const textParts: string[] = [];
|
||||
const images: string[] = [];
|
||||
|
||||
for (const part of msg.content) {
|
||||
if (part.type === "text") {
|
||||
textParts.push(part.text);
|
||||
} else if (part.type === "image") {
|
||||
// Strip data URI prefix if present
|
||||
let data = (part as ImageContent).data;
|
||||
const commaIdx = data.indexOf(",");
|
||||
if (commaIdx !== -1 && data.startsWith("data:")) {
|
||||
data = data.slice(commaIdx + 1);
|
||||
}
|
||||
images.push(data);
|
||||
}
|
||||
}
|
||||
|
||||
const result: OllamaChatMessage = {
|
||||
role: "user",
|
||||
content: textParts.join("\n"),
|
||||
};
|
||||
if (images.length > 0) {
|
||||
result.images = images;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
function convertAssistantMessage(
|
||||
msg: Message & { role: "assistant" },
|
||||
): OllamaChatMessage {
|
||||
let content = "";
|
||||
const toolCalls: OllamaChatMessage["tool_calls"] = [];
|
||||
|
||||
for (const block of msg.content) {
|
||||
if (block.type === "thinking") {
|
||||
// Serialize thinking back inline for round-trip with Ollama
|
||||
content += `<think>${(block as ThinkingContent).thinking}</think>`;
|
||||
} else if (block.type === "text") {
|
||||
content += (block as TextContent).text;
|
||||
} else if (block.type === "toolCall") {
|
||||
const tc = block as ToolCall;
|
||||
toolCalls.push({
|
||||
function: {
|
||||
name: tc.name,
|
||||
arguments: tc.arguments,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
const result: OllamaChatMessage = { role: "assistant", content };
|
||||
if (toolCalls.length > 0) {
|
||||
result.tool_calls = toolCalls;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// ─── Tool conversion ────────────────────────────────────────────────────────
|
||||
|
||||
function convertTools(tools: Tool[]): OllamaTool[] {
|
||||
return tools.map((tool) => {
|
||||
const params = tool.parameters as Record<string, unknown>;
|
||||
return {
|
||||
type: "function" as const,
|
||||
function: {
|
||||
name: tool.name,
|
||||
description: tool.description,
|
||||
parameters: {
|
||||
type: "object" as const,
|
||||
required: params.required as string[] | undefined,
|
||||
properties: (params.properties as Record<string, unknown>) ?? {},
|
||||
},
|
||||
},
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
// ─── Response mapping ───────────────────────────────────────────────────────
|
||||
|
||||
function mapStopReason(doneReason?: string): StopReason {
|
||||
switch (doneReason) {
|
||||
case "stop":
|
||||
return "stop";
|
||||
case "length":
|
||||
return "length";
|
||||
default:
|
||||
return "stop";
|
||||
}
|
||||
}
|
||||
|
||||
function buildUsage(chunk: OllamaChatResponse): Usage {
|
||||
const input = chunk.prompt_eval_count ?? 0;
|
||||
const outputTokens = chunk.eval_count ?? 0;
|
||||
return {
|
||||
input,
|
||||
output: outputTokens,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: input + outputTokens,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
};
|
||||
}
|
||||
|
||||
function extractMetrics(
|
||||
chunk: OllamaChatResponse,
|
||||
): InferenceMetrics | undefined {
|
||||
if (!chunk.eval_duration && !chunk.total_duration) return undefined;
|
||||
|
||||
const evalCount = chunk.eval_count ?? 0;
|
||||
const evalDurationNs = chunk.eval_duration ?? 0;
|
||||
const evalDurationMs = evalDurationNs / 1e6;
|
||||
const tokensPerSecond =
|
||||
evalDurationNs > 0 ? evalCount / (evalDurationNs / 1e9) : 0;
|
||||
|
||||
return {
|
||||
tokensPerSecond,
|
||||
totalDurationMs: (chunk.total_duration ?? 0) / 1e6,
|
||||
evalDurationMs,
|
||||
promptEvalDurationMs: (chunk.prompt_eval_duration ?? 0) / 1e6,
|
||||
};
|
||||
}
|
||||
|
||||
// ─── Stream lifecycle helpers ───────────────────────────────────────────────
|
||||
// Replicated from openai-shared.ts (not exported from "@singularity-forge/pi-ai)
|
||||
|
||||
function buildInitialOutput(model: Model<Api>): AssistantMessage {
|
||||
return {
|
||||
role: "assistant",
|
||||
content: [],
|
||||
api: model.api as Api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
|
||||
},
|
||||
stopReason: "stop",
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
}
|
||||
|
||||
function assertStreamSuccess(
|
||||
output: AssistantMessage,
|
||||
signal?: AbortSignal,
|
||||
): void {
|
||||
if (signal?.aborted) {
|
||||
throw new Error("Request was aborted");
|
||||
}
|
||||
if (output.stopReason === "aborted" || output.stopReason === "error") {
|
||||
throw new Error("An unknown error occurred");
|
||||
}
|
||||
}
|
||||
|
||||
function finalizeStream(
|
||||
stream: AssistantMessageEventStream,
|
||||
output: AssistantMessage,
|
||||
): void {
|
||||
stream.push({
|
||||
type: "done",
|
||||
reason: output.stopReason as Extract<
|
||||
StopReason,
|
||||
"stop" | "length" | "toolUse" | "pauseTurn"
|
||||
>,
|
||||
message: output,
|
||||
});
|
||||
stream.end();
|
||||
}
|
||||
|
||||
function handleStreamError(
|
||||
stream: AssistantMessageEventStream,
|
||||
output: AssistantMessage,
|
||||
error: unknown,
|
||||
signal?: AbortSignal,
|
||||
): void {
|
||||
for (const block of output.content)
|
||||
delete (block as { index?: number }).index;
|
||||
output.stopReason = signal?.aborted ? "aborted" : "error";
|
||||
output.errorMessage =
|
||||
error instanceof Error ? error.message : JSON.stringify(error);
|
||||
stream.push({ type: "error", reason: output.stopReason, error: output });
|
||||
stream.end();
|
||||
}
|
||||
|
|
@ -1,257 +0,0 @@
|
|||
// sf — HTTP client for Ollama REST API
|
||||
|
||||
/**
|
||||
* Low-level HTTP client for the Ollama REST API.
|
||||
* Respects the OLLAMA_HOST environment variable for non-default endpoints.
|
||||
*
|
||||
* Reference: https://github.com/ollama/ollama/blob/main/docs/api.md
|
||||
*/
|
||||
|
||||
import { parseNDJsonStream } from "./ndjson-stream.js";
|
||||
import type {
|
||||
OllamaChatRequest,
|
||||
OllamaChatResponse,
|
||||
OllamaPsResponse,
|
||||
OllamaPullProgress,
|
||||
OllamaShowResponse,
|
||||
OllamaTagsResponse,
|
||||
OllamaVersionResponse,
|
||||
} from "./types.js";
|
||||
|
||||
// Base URL used when OLLAMA_HOST is unset (Ollama's default bind address).
const DEFAULT_HOST = "http://localhost:11434";
// Short timeout for the liveness probe — a local daemon answers quickly.
const PROBE_TIMEOUT_MS = 1500;
// Timeout for ordinary management requests (tags/show/ps/delete/copy).
const REQUEST_TIMEOUT_MS = 10000;
|
||||
|
||||
/**
|
||||
* Get the Ollama host URL from OLLAMA_HOST or default.
|
||||
*/
|
||||
export function getOllamaHost(): string {
|
||||
const host = process.env.OLLAMA_HOST;
|
||||
if (!host) return DEFAULT_HOST;
|
||||
|
||||
// OLLAMA_HOST can be just a host:port without scheme
|
||||
if (host.startsWith("http://") || host.startsWith("https://")) return host;
|
||||
return `http://${host}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get auth headers for Ollama API requests.
|
||||
* For cloud endpoints (OLLAMA_HOST pointing to ollama.com or remote instances),
|
||||
* OLLAMA_API_KEY is used as a Bearer token. Local Ollama ignores the header.
|
||||
*/
|
||||
function getAuthHeaders(): Record<string, string> {
|
||||
const apiKey = process.env.OLLAMA_API_KEY;
|
||||
if (!apiKey) return {};
|
||||
return { Authorization: `Bearer ${apiKey}` };
|
||||
}
|
||||
|
||||
/**
|
||||
* Merge auth headers into request options.
|
||||
*/
|
||||
function withAuth(options: RequestInit = {}): RequestInit {
|
||||
const authHeaders = getAuthHeaders();
|
||||
if (Object.keys(authHeaders).length === 0) return options;
|
||||
return {
|
||||
...options,
|
||||
headers: {
|
||||
...authHeaders,
|
||||
...((options.headers as Record<string, string>) || {}),
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * fetch() with a hard timeout enforced via AbortController, and auth
 * headers merged in. The timer is always cleared, success or failure.
 * NOTE(review): withAuth({ ...options, signal }) means a caller-supplied
 * options.signal would be replaced by the internal timeout signal —
 * confirm no caller relies on passing its own signal through here.
 */
async function fetchWithTimeout(
  url: string,
  options: RequestInit = {},
  timeoutMs = REQUEST_TIMEOUT_MS,
): Promise<Response> {
  const controller = new AbortController();
  const timeout = setTimeout(() => controller.abort(), timeoutMs);
  try {
    return await fetch(
      url,
      withAuth({ ...options, signal: controller.signal }),
    );
  } finally {
    clearTimeout(timeout);
  }
}
|
||||
|
||||
/**
|
||||
* Check if Ollama is running and reachable.
|
||||
* For cloud endpoints (OLLAMA_HOST pointing to ollama.com), uses /api/tags
|
||||
* as the probe since the root endpoint may not be available.
|
||||
*/
|
||||
export async function isRunning(): Promise<boolean> {
|
||||
try {
|
||||
const host = getOllamaHost();
|
||||
const isCloud = host.includes("ollama.com") || host.includes("cloud");
|
||||
const probeUrl = isCloud ? `${host}/api/tags` : `${host}/`;
|
||||
const timeout = isCloud ? REQUEST_TIMEOUT_MS : PROBE_TIMEOUT_MS;
|
||||
const response = await fetchWithTimeout(
|
||||
probeUrl,
|
||||
isCloud ? { method: "GET" } : {},
|
||||
timeout,
|
||||
);
|
||||
return response.ok;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get Ollama version.
|
||||
*/
|
||||
export async function getVersion(): Promise<string | null> {
|
||||
try {
|
||||
const response = await fetchWithTimeout(`${getOllamaHost()}/api/version`);
|
||||
if (!response.ok) return null;
|
||||
const data = (await response.json()) as OllamaVersionResponse;
|
||||
return data.version;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* List all locally available models.
|
||||
*/
|
||||
export async function listModels(): Promise<OllamaTagsResponse> {
|
||||
const response = await fetchWithTimeout(`${getOllamaHost()}/api/tags`);
|
||||
if (!response.ok) {
|
||||
throw new Error(
|
||||
`Ollama /api/tags returned ${response.status}: ${response.statusText}`,
|
||||
);
|
||||
}
|
||||
return (await response.json()) as OllamaTagsResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get detailed information about a specific model.
|
||||
*/
|
||||
export async function showModel(name: string): Promise<OllamaShowResponse> {
|
||||
const response = await fetchWithTimeout(`${getOllamaHost()}/api/show`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ name }),
|
||||
});
|
||||
if (!response.ok) {
|
||||
throw new Error(
|
||||
`Ollama /api/show returned ${response.status}: ${response.statusText}`,
|
||||
);
|
||||
}
|
||||
return (await response.json()) as OllamaShowResponse;
|
||||
}
|
||||
|
||||
/**
|
||||
* List currently loaded/running models.
|
||||
*/
|
||||
export async function getRunningModels(): Promise<OllamaPsResponse> {
|
||||
const response = await fetchWithTimeout(`${getOllamaHost()}/api/ps`);
|
||||
if (!response.ok) {
|
||||
throw new Error(
|
||||
`Ollama /api/ps returned ${response.status}: ${response.statusText}`,
|
||||
);
|
||||
}
|
||||
return (await response.json()) as OllamaPsResponse;
|
||||
}
|
||||
|
||||
/**
 * Pull a model with streaming progress.
 * Calls onProgress for each progress update.
 * Returns when the pull is complete.
 *
 * Uses plain fetch (no management timeout) because a pull can legitimately
 * run for many minutes; cancellation is the caller's signal.
 * Progress lines are parsed permissively — malformed lines are skipped.
 *
 * @throws Error when /api/pull responds non-2xx or with no body.
 */
export async function pullModel(
  name: string,
  onProgress?: (progress: OllamaPullProgress) => void,
  signal?: AbortSignal,
): Promise<void> {
  const response = await fetch(
    `${getOllamaHost()}/api/pull`,
    withAuth({
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ name, stream: true }),
      signal,
    }),
  );

  if (!response.ok) {
    const text = await response.text();
    throw new Error(`Ollama /api/pull returned ${response.status}: ${text}`);
  }

  if (!response.body) {
    throw new Error("Ollama /api/pull returned no body");
  }

  // Permissive NDJSON mode (no strict flag): a dropped progress line is
  // harmless, unlike dropped inference content.
  for await (const progress of parseNDJsonStream<OllamaPullProgress>(
    response.body,
    signal,
  )) {
    onProgress?.(progress);
  }
}
|
||||
|
||||
/**
 * Stream a chat completion via /api/chat.
 * Returns an async generator yielding each NDJSON response chunk.
 *
 * Uses plain fetch (no management timeout) since inference is long-running;
 * the caller's signal both aborts the HTTP request and stops NDJSON parsing.
 * Strict NDJSON mode: a malformed line throws rather than silently dropping
 * inference content.
 *
 * @throws Error when /api/chat responds non-2xx or with no body.
 */
export async function* chat(
  request: OllamaChatRequest,
  signal?: AbortSignal,
): AsyncGenerator<OllamaChatResponse> {
  const response = await fetch(
    `${getOllamaHost()}/api/chat`,
    withAuth({
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify(request),
      signal,
    }),
  );

  if (!response.ok) {
    const text = await response.text();
    throw new Error(`Ollama /api/chat returned ${response.status}: ${text}`);
  }

  if (!response.body) {
    throw new Error("Ollama /api/chat returned no body");
  }

  // Third argument `true` = strict NDJSON parsing (see parseNDJsonStream).
  yield* parseNDJsonStream<OllamaChatResponse>(response.body, signal, true);
}
|
||||
|
||||
/**
|
||||
* Delete a local model.
|
||||
*/
|
||||
export async function deleteModel(name: string): Promise<void> {
|
||||
const response = await fetchWithTimeout(`${getOllamaHost()}/api/delete`, {
|
||||
method: "DELETE",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ name }),
|
||||
});
|
||||
if (!response.ok) {
|
||||
const text = await response.text();
|
||||
throw new Error(`Ollama /api/delete returned ${response.status}: ${text}`);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Copy a model to a new name.
|
||||
*/
|
||||
export async function copyModel(
|
||||
source: string,
|
||||
destination: string,
|
||||
): Promise<void> {
|
||||
const response = await fetchWithTimeout(`${getOllamaHost()}/api/copy`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ source, destination }),
|
||||
});
|
||||
if (!response.ok) {
|
||||
const text = await response.text();
|
||||
throw new Error(`Ollama /api/copy returned ${response.status}: ${text}`);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,272 +0,0 @@
|
|||
// sf — Ollama slash commands
|
||||
|
||||
/**
|
||||
* Registers /ollama slash commands for managing local Ollama models.
|
||||
*
|
||||
* Commands:
|
||||
* /ollama — Show status (running?, version, loaded models)
|
||||
* /ollama list — List all available local models with sizes
|
||||
* /ollama pull — Pull a model with progress
|
||||
* /ollama remove — Delete a local model
|
||||
* /ollama ps — Show running models and resource usage
|
||||
*/
|
||||
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import { Text } from "@singularity-forge/pi-tui";
|
||||
import { formatModelSize } from "./model-capabilities.js";
|
||||
import * as client from "./ollama-client.js";
|
||||
import { discoverModels, formatModelForDisplay } from "./ollama-discovery.js";
|
||||
|
||||
/**
 * Register the /ollama slash command and dispatch its subcommands.
 * Bare "/ollama" maps to the status view; everything after the subcommand
 * is treated as the model name (spaces preserved).
 */
export function registerOllamaCommands(pi: ExtensionAPI): void {
  pi.registerCommand("ollama", {
    description: "Manage local Ollama models — list | pull | remove | ps",
    async handler(args, ctx) {
      // args may be undefined when the command is invoked with no text.
      const parts = (args ?? "").trim().split(/\s+/);
      const subcommand = parts[0] || "status";
      // Rejoin the remainder so model names containing spaces survive.
      const modelArg = parts.slice(1).join(" ");

      switch (subcommand) {
        case "status":
          return await handleStatus(ctx);
        case "list":
        case "ls":
          return await handleList(ctx);
        case "pull":
          return await handlePull(modelArg, ctx);
        case "remove":
        case "rm":
        case "delete":
          return await handleRemove(modelArg, ctx);
        case "ps":
          return await handlePs(ctx);
        default:
          ctx.ui.notify(
            `Unknown subcommand: ${subcommand}. Use: status, list, pull, remove, ps`,
            "warning",
          );
      }
    },
  });
}
|
||||
|
||||
/**
 * /ollama (status): render daemon version/host, currently-loaded models
 * with VRAM usage and keep-alive expiry, and the locally available models.
 * Renders via a one-shot ctx.ui.custom Text view.
 */
async function handleStatus(ctx: any): Promise<void> {
  const running = await client.isRunning();
  if (!running) {
    ctx.ui.notify(
      "Ollama is not running. Install from https://ollama.com and run 'ollama serve'",
      "warning",
    );
    return;
  }

  const version = await client.getVersion();
  const lines: string[] = [];
  lines.push(
    `Ollama${version ? ` v${version}` : ""} — running (${client.getOllamaHost()})`,
  );

  // Show loaded models
  try {
    const ps = await client.getRunningModels();
    if (ps.models && ps.models.length > 0) {
      lines.push("");
      lines.push("Loaded:");
      for (const m of ps.models) {
        // size_vram === 0 is taken to mean the model runs on CPU.
        const vram =
          m.size_vram > 0 ? formatModelSize(m.size_vram) + " VRAM" : "CPU";
        // expires_at reflects the keep-alive deadline; clamp negatives to 0.
        const expiresAt = new Date(m.expires_at);
        const idleMs = expiresAt.getTime() - Date.now();
        const idleMin = Math.max(0, Math.floor(idleMs / 60000));
        lines.push(` ${m.name} ${vram} expires in ${idleMin}m`);
      }
    }
  } catch {
    // ps endpoint may not be available on older versions
  }

  // Show available models
  try {
    const models = await discoverModels();
    if (models.length > 0) {
      lines.push("");
      lines.push("Available:");
      for (const m of models) {
        lines.push(` ${formatModelForDisplay(m)}`);
      }
    } else {
      lines.push("");
      lines.push("No models pulled. Use /ollama pull <model> to get started.");
    }
  } catch (err) {
    lines.push("");
    lines.push(
      `Error listing models: ${err instanceof Error ? err.message : String(err)}`,
    );
  }

  // One-shot custom view: resolves on the next tick so the Text renders once.
  await ctx.ui.custom(
    (_tui: any, theme: any, _kb: any, done: (r: undefined) => void) => {
      const text = new Text(
        lines.map((l) => theme.fg("fg", l)).join("\n"),
        0,
        0,
      );
      setTimeout(() => done(undefined), 0);
      return text;
    },
  );
}
|
||||
|
||||
/**
 * /ollama list: show all locally available models (one line each) in a
 * one-shot custom Text view, or a notify hint when none exist.
 */
async function handleList(ctx: any): Promise<void> {
  const running = await client.isRunning();
  if (!running) {
    ctx.ui.notify("Ollama is not running", "warning");
    return;
  }

  const models = await discoverModels();
  if (models.length === 0) {
    ctx.ui.notify(
      "No models available. Use /ollama pull <model> to download one.",
      "info",
    );
    return;
  }

  const lines = ["Local Ollama models:", ""];
  for (const m of models) {
    lines.push(` ${formatModelForDisplay(m)}`);
  }

  // One-shot custom view: resolves on the next tick so the Text renders once.
  await ctx.ui.custom(
    (_tui: any, theme: any, _kb: any, done: (r: undefined) => void) => {
      const text = new Text(
        lines.map((l) => theme.fg("fg", l)).join("\n"),
        0,
        0,
      );
      setTimeout(() => done(undefined), 0);
      return text;
    },
  );
}
|
||||
|
||||
/**
 * /ollama pull <model>: download a model while reporting progress through
 * the "ollama-pull" widget. The widget is torn down (set to undefined) on
 * both success and failure; the outcome is surfaced via notify.
 */
async function handlePull(modelName: string, ctx: any): Promise<void> {
  if (!modelName) {
    ctx.ui.notify(
      "Usage: /ollama pull <model> (e.g. /ollama pull llama3.1:8b)",
      "warning",
    );
    return;
  }

  const running = await client.isRunning();
  if (!running) {
    ctx.ui.notify("Ollama is not running", "warning");
    return;
  }

  ctx.ui.setWidget("ollama-pull", [`Pulling ${modelName}...`]);

  try {
    // Only repaint the widget on whole-percent changes to limit UI churn.
    let lastPercent = -1;
    await client.pullModel(modelName, (progress) => {
      if (progress.total && progress.completed) {
        const percent = Math.floor((progress.completed / progress.total) * 100);
        if (percent !== lastPercent) {
          lastPercent = percent;
          const completed = formatModelSize(progress.completed);
          const total = formatModelSize(progress.total);
          ctx.ui.setWidget("ollama-pull", [
            `Pulling ${modelName}... ${percent}% (${completed} / ${total})`,
          ]);
        }
      } else if (progress.status) {
        // Phases without byte counts (e.g. manifest/verify) show status text.
        ctx.ui.setWidget("ollama-pull", [`${modelName}: ${progress.status}`]);
      }
    });

    ctx.ui.setWidget("ollama-pull", undefined);
    ctx.ui.notify(`${modelName} pulled successfully`, "success");
  } catch (err) {
    ctx.ui.setWidget("ollama-pull", undefined);
    ctx.ui.notify(
      `Failed to pull ${modelName}: ${err instanceof Error ? err.message : String(err)}`,
      "error",
    );
  }
}
|
||||
|
||||
async function handleRemove(modelName: string, ctx: any): Promise<void> {
|
||||
if (!modelName) {
|
||||
ctx.ui.notify("Usage: /ollama remove <model>", "warning");
|
||||
return;
|
||||
}
|
||||
|
||||
const running = await client.isRunning();
|
||||
if (!running) {
|
||||
ctx.ui.notify("Ollama is not running", "warning");
|
||||
return;
|
||||
}
|
||||
|
||||
const confirmed = await ctx.ui.confirm(
|
||||
"Delete model",
|
||||
`Are you sure you want to delete ${modelName}?`,
|
||||
);
|
||||
|
||||
if (!confirmed) return;
|
||||
|
||||
try {
|
||||
await client.deleteModel(modelName);
|
||||
ctx.ui.notify(`${modelName} deleted`, "success");
|
||||
} catch (err) {
|
||||
ctx.ui.notify(
|
||||
`Failed to delete ${modelName}: ${err instanceof Error ? err.message : String(err)}`,
|
||||
"error",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async function handlePs(ctx: any): Promise<void> {
|
||||
const running = await client.isRunning();
|
||||
if (!running) {
|
||||
ctx.ui.notify("Ollama is not running", "warning");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const ps = await client.getRunningModels();
|
||||
if (!ps.models || ps.models.length === 0) {
|
||||
ctx.ui.notify("No models currently loaded in memory", "info");
|
||||
return;
|
||||
}
|
||||
|
||||
const lines = ["Running models:", ""];
|
||||
for (const m of ps.models) {
|
||||
const vram =
|
||||
m.size_vram > 0 ? formatModelSize(m.size_vram) + " VRAM" : "CPU only";
|
||||
const totalSize = formatModelSize(m.size);
|
||||
const expiresAt = new Date(m.expires_at);
|
||||
const idleMs = expiresAt.getTime() - Date.now();
|
||||
const idleMin = Math.max(0, Math.floor(idleMs / 60000));
|
||||
lines.push(` ${m.name} ${totalSize} ${vram} expires in ${idleMin}m`);
|
||||
}
|
||||
|
||||
await ctx.ui.custom(
|
||||
(_tui: any, theme: any, _kb: any, done: (r: undefined) => void) => {
|
||||
const text = new Text(
|
||||
lines.map((l) => theme.fg("fg", l)).join("\n"),
|
||||
0,
|
||||
0,
|
||||
);
|
||||
setTimeout(() => done(undefined), 0);
|
||||
return text;
|
||||
},
|
||||
);
|
||||
} catch (err) {
|
||||
ctx.ui.notify(
|
||||
`Failed to get running models: ${err instanceof Error ? err.message : String(err)}`,
|
||||
"error",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,154 +0,0 @@
|
|||
// sf — Ollama model discovery and capability detection
|
||||
|
||||
/**
|
||||
* Discovers locally available Ollama models and enriches them with
|
||||
* capability metadata (context window, vision, reasoning) from the
|
||||
* known model table and /api/show responses.
|
||||
*
|
||||
* Returns models in the format expected by pi.registerProvider().
|
||||
*/
|
||||
|
||||
import {
|
||||
estimateContextFromParams,
|
||||
formatModelSize,
|
||||
getModelCapabilities,
|
||||
humanizeModelName,
|
||||
} from "./model-capabilities.js";
|
||||
import { listModels, showModel } from "./ollama-client.js";
|
||||
import type { OllamaChatOptions, OllamaModelInfo } from "./types.js";
|
||||
|
||||
/**
|
||||
* Extract context window from /api/show model_info.
|
||||
* Keys follow the pattern "{architecture}.context_length" (e.g. "llama.context_length").
|
||||
*/
|
||||
function extractContextFromModelInfo(
|
||||
modelInfo: Record<string, unknown>,
|
||||
): number | undefined {
|
||||
for (const [key, value] of Object.entries(modelInfo)) {
|
||||
if (
|
||||
key.endsWith(".context_length") &&
|
||||
typeof value === "number" &&
|
||||
value > 0
|
||||
) {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/** Injectable subset of the Ollama client, so tests can stub network calls. */
type ClientDeps = {
  listModels: typeof listModels;
  showModel: typeof showModel;
};

/**
 * A locally installed model enriched with capability metadata, shaped for
 * pi.registerProvider().
 */
export interface DiscoveredOllamaModel {
  // Ollama model tag (e.g. "llama3.1:8b"), used verbatim as the model id.
  id: string;
  // Display name derived from the tag via humanizeModelName().
  name: string;
  // True when the known-capability table marks the model as a reasoner.
  reasoning: boolean;
  // Accepted input modalities; includes "image" for vision-capable models.
  input: ("text" | "image")[];
  // Per-token pricing — always zero for local models (see ZERO_COST).
  cost: {
    input: number;
    output: number;
    cacheRead: number;
    cacheWrite: number;
  };
  // Context window in tokens (known table > /api/show > estimate > default).
  contextWindow: number;
  // Maximum generation length in tokens.
  maxTokens: number;
  /** Raw size in bytes for display purposes */
  sizeBytes: number;
  /** Parameter size string from Ollama (e.g. "7B") */
  parameterSize: string;
  /** Ollama-specific inference options for this model */
  ollamaOptions?: OllamaChatOptions;
}

// Local models cost nothing to run; shared frozen-shape cost record.
const ZERO_COST = { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 };
|
||||
|
||||
async function enrichModel(
|
||||
info: OllamaModelInfo,
|
||||
deps: ClientDeps,
|
||||
): Promise<DiscoveredOllamaModel> {
|
||||
const caps = getModelCapabilities(info.name);
|
||||
const parameterSize = info.details?.parameter_size ?? "";
|
||||
|
||||
// /api/tags doesn't include context length; /api/show does via "{arch}.context_length" in model_info.
|
||||
let showContextWindow: number | undefined;
|
||||
if (caps.contextWindow === undefined) {
|
||||
try {
|
||||
const showData = await deps.showModel(info.name);
|
||||
showContextWindow = extractContextFromModelInfo(showData.model_info);
|
||||
} catch (err) {
|
||||
// non-fatal: fall through to estimate
|
||||
if (process.env.SF_DEBUG)
|
||||
console.warn(
|
||||
`[ollama] /api/show failed for ${info.name}:`,
|
||||
err instanceof Error ? err.message : String(err),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Determine context window: known table > /api/show > estimate from param size > default
|
||||
const contextWindow =
|
||||
caps.contextWindow ??
|
||||
showContextWindow ??
|
||||
(parameterSize ? estimateContextFromParams(parameterSize) : 8192);
|
||||
|
||||
// Determine max tokens: known table > fraction of context > default
|
||||
const maxTokens =
|
||||
caps.maxTokens ?? Math.min(Math.floor(contextWindow / 4), 16384);
|
||||
|
||||
// Detect vision from families or known table
|
||||
const hasVision =
|
||||
caps.input?.includes("image") ??
|
||||
info.details?.families?.some((f) => f === "clip" || f === "mllama") ??
|
||||
false;
|
||||
|
||||
// Detect reasoning from known table
|
||||
const reasoning = caps.reasoning ?? false;
|
||||
|
||||
return {
|
||||
id: info.name,
|
||||
name: humanizeModelName(info.name),
|
||||
reasoning,
|
||||
input: hasVision ? ["text", "image"] : ["text"],
|
||||
cost: ZERO_COST,
|
||||
contextWindow,
|
||||
maxTokens,
|
||||
sizeBytes: info.size,
|
||||
parameterSize,
|
||||
ollamaOptions: caps.ollamaOptions,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Discover all locally available Ollama models with enriched capabilities.
|
||||
*/
|
||||
export async function discoverModels(
|
||||
deps: ClientDeps = { listModels, showModel },
|
||||
): Promise<DiscoveredOllamaModel[]> {
|
||||
const tags = await deps.listModels();
|
||||
if (!tags.models || tags.models.length === 0) return [];
|
||||
|
||||
return Promise.all(tags.models.map((m) => enrichModel(m, deps)));
|
||||
}
|
||||
|
||||
/**
|
||||
* Format a discovered model for display in model list.
|
||||
*/
|
||||
export function formatModelForDisplay(model: DiscoveredOllamaModel): string {
|
||||
const parts = [model.id];
|
||||
|
||||
if (model.sizeBytes > 0) {
|
||||
parts.push(`(${formatModelSize(model.sizeBytes)})`);
|
||||
}
|
||||
|
||||
const flags: string[] = [];
|
||||
if (model.reasoning) flags.push("reasoning");
|
||||
if (model.input.includes("image")) flags.push("vision");
|
||||
|
||||
if (flags.length > 0) {
|
||||
parts.push(`[${flags.join(", ")}]`);
|
||||
}
|
||||
|
||||
return parts.join(" ");
|
||||
}
|
||||
|
|
@ -1,56 +0,0 @@
|
|||
/**
 * Regression test: don't show an Ollama footer status unless Ollama is
 * actually usable (running with at least one discovered model).
 *
 * NOTE(review): these tests assert on the *source text* of index.ts with
 * regexes rather than executing it — fast and dependency-free, but they
 * will break if the guarded code is reformatted. Keep the patterns loose.
 */

import assert from "node:assert/strict";
import { readFileSync } from "node:fs";
import { dirname, join } from "node:path";
import { test } from 'vitest';
import { fileURLToPath } from "node:url";

// Resolve index.ts relative to this test file's own location.
const __dirname = dirname(fileURLToPath(import.meta.url));
const src = readFileSync(join(__dirname, "index.ts"), "utf-8");

test("probeAndRegister returns false when no Ollama models are discovered", () => {
  assert.match(
    src,
    /if \(models\.length === 0\)[\s\S]*return false;/,
    "running-without-models should not be treated as available",
  );
});

test("interactive session clears ollama footer status when unavailable", () => {
  // `found` is probeAndRegister's result; the ternary must clear the status.
  assert.match(
    src,
    /ctx\.ui\.setStatus\("ollama", found \? "Ollama" : undefined\)/,
    "status should be cleared when probeAndRegister reports unavailable",
  );
});

test("registration is gated on OLLAMA_HOST being set", () => {
  // Top-level guard in the default export — no probe, no commands, no tool
  // registration happens without an explicit OLLAMA_HOST. This makes the
  // extension opt-in for users who never run local Ollama.
  assert.match(
    src,
    /export default function ollama\([\s\S]*?if \(!isOllamaConfigured\(\)\) return;/,
    "default export should short-circuit when OLLAMA_HOST is unset",
  );
});

test("probeAndRegister bails out before hitting the network when unconfigured", () => {
  assert.match(
    src,
    /async function probeAndRegister\([\s\S]*?if \(!isOllamaConfigured\(\)\) return false;/,
    "probeAndRegister should skip client.isRunning() when OLLAMA_HOST is unset",
  );
});

test("isOllamaConfigured keys off OLLAMA_HOST env var", () => {
  assert.match(
    src,
    /function isOllamaConfigured\(\)[\s\S]*?process\.env\.OLLAMA_HOST/,
    "configuration check should read OLLAMA_HOST",
  );
});
|
||||
|
|
@ -1,438 +0,0 @@
|
|||
// sf — LLM-callable Ollama management tool
|
||||
/**
|
||||
* Registers an ollama_manage tool that the LLM can call to interact
|
||||
* with the local Ollama instance — list models, pull new ones, check status.
|
||||
*/
|
||||
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import type { ExtensionAPI } from "@singularity-forge/pi-coding-agent";
|
||||
import { Text } from "@singularity-forge/pi-tui";
|
||||
import { formatModelSize } from "./model-capabilities.js";
|
||||
import * as client from "./ollama-client.js";
|
||||
import { discoverModels, formatModelForDisplay } from "./ollama-discovery.js";
|
||||
|
||||
/** Structured metadata attached to every ollama_manage tool result. */
interface OllamaToolDetails {
  // The action that was executed (list/pull/remove/show/status/ps).
  action: string;
  // Model name involved, when the action targets one.
  model?: string;
  // Number of models reported by list/ps actions.
  modelCount?: number;
  // Wall-clock time the action took, in milliseconds.
  durationMs: number;
  // Short error code (e.g. "not_running") or message when the action failed.
  error?: string;
}
|
||||
|
||||
/**
 * Register the `ollama_manage` tool with the extension API.
 *
 * The tool lets the LLM manage the local Ollama instance: list, pull,
 * remove, and inspect models, check daemon status, and see which models are
 * loaded in memory (`ps`). Every branch returns a `details` payload typed
 * as OllamaToolDetails (action, optional model, duration, optional error).
 *
 * @param pi Extension API used to register the tool.
 */
export function registerOllamaTool(pi: ExtensionAPI): void {
  pi.registerTool({
    name: "ollama_manage",
    label: "Ollama",
    description:
      "Manage local Ollama models. List available models, pull new ones, " +
      "check Ollama status, or see running models and resource usage. " +
      "Use this when you need a specific local model that isn't available yet.",
    promptSnippet: "Manage local Ollama models (list, pull, status, ps)",
    promptGuidelines: [
      "Use 'list' to see what models are available locally before trying to use one.",
      "Use 'pull' to download a model that isn't available yet.",
      "Use 'remove' to delete a local model that is no longer needed.",
      "Use 'show' to get detailed info about a model (parameters, quantization, families).",
      "Use 'status' to check if Ollama is running.",
      "Use 'ps' to see which models are loaded in memory and VRAM usage.",
      "Common models: llama3.1:8b, qwen2.5-coder:7b, deepseek-r1:8b, codestral:22b",
    ],
    // Schema: a closed set of actions plus an optional model name. The
    // pull/remove/show branches re-validate at runtime that `model` was given.
    parameters: Type.Object({
      action: Type.Union(
        [
          Type.Literal("list"),
          Type.Literal("pull"),
          Type.Literal("remove"),
          Type.Literal("show"),
          Type.Literal("status"),
          Type.Literal("ps"),
        ],
        { description: "Action to perform" },
      ),
      model: Type.Optional(
        Type.String({ description: "Model name (required for pull)" }),
      ),
    }),

    // Dispatch on `action`; every outcome carries durationMs in details.
    async execute(_toolCallId, params, signal, onUpdate, _ctx) {
      const startTime = Date.now();
      const { action, model } = params;

      try {
        switch (action) {
          // Liveness probe; reports version and host when the daemon is up.
          case "status": {
            const running = await client.isRunning();
            if (!running) {
              return {
                content: [
                  {
                    type: "text",
                    text: "Ollama is not running. It needs to be started with 'ollama serve'.",
                  },
                ],
                details: {
                  action,
                  durationMs: Date.now() - startTime,
                } as OllamaToolDetails,
              };
            }
            const version = await client.getVersion();
            return {
              content: [
                {
                  type: "text",
                  text: `Ollama${version ? ` v${version}` : ""} is running at ${client.getOllamaHost()}`,
                },
              ],
              details: {
                action,
                durationMs: Date.now() - startTime,
              } as OllamaToolDetails,
            };
          }

          // Enumerate locally installed models via discovery.
          case "list": {
            const running = await client.isRunning();
            if (!running) {
              return {
                content: [{ type: "text", text: "Ollama is not running." }],
                isError: true,
                details: {
                  action,
                  durationMs: Date.now() - startTime,
                  error: "not_running",
                } as OllamaToolDetails,
              };
            }

            const models = await discoverModels();
            if (models.length === 0) {
              return {
                content: [
                  {
                    type: "text",
                    text: "No models available. Pull one with action='pull'.",
                  },
                ],
                details: {
                  action,
                  modelCount: 0,
                  durationMs: Date.now() - startTime,
                } as OllamaToolDetails,
              };
            }

            const lines = models.map((m) => formatModelForDisplay(m));
            return {
              content: [
                {
                  type: "text",
                  text: `Available models:\n${lines.join("\n")}`,
                },
              ],
              details: {
                action,
                modelCount: models.length,
                durationMs: Date.now() - startTime,
              } as OllamaToolDetails,
            };
          }

          // Download a model, streaming progress to the UI via onUpdate.
          case "pull": {
            if (!model) {
              return {
                content: [
                  {
                    type: "text",
                    text: "Error: 'model' parameter is required for pull action.",
                  },
                ],
                isError: true,
                details: {
                  action,
                  durationMs: Date.now() - startTime,
                  error: "missing_model",
                } as OllamaToolDetails,
              };
            }

            const running = await client.isRunning();
            if (!running) {
              return {
                content: [{ type: "text", text: "Ollama is not running." }],
                isError: true,
                details: {
                  action,
                  model,
                  durationMs: Date.now() - startTime,
                  error: "not_running",
                } as OllamaToolDetails,
              };
            }

            // De-duplicate progress callbacks so onUpdate fires only when
            // the human-visible status line actually changes.
            let lastStatus = "";
            await client.pullModel(
              model,
              (progress) => {
                if (progress.total && progress.completed) {
                  const pct = Math.floor(
                    (progress.completed / progress.total) * 100,
                  );
                  const status = `Pulling ${model}... ${pct}%`;
                  if (status !== lastStatus) {
                    lastStatus = status;
                    onUpdate?.({
                      content: [{ type: "text", text: status }],
                      details: {
                        action,
                        model,
                        durationMs: Date.now() - startTime,
                      } as OllamaToolDetails,
                    });
                  }
                } else if (progress.status && progress.status !== lastStatus) {
                  lastStatus = progress.status;
                  onUpdate?.({
                    content: [
                      { type: "text", text: `${model}: ${progress.status}` },
                    ],
                    details: {
                      action,
                      model,
                      durationMs: Date.now() - startTime,
                    } as OllamaToolDetails,
                  });
                }
              },
              signal, // lets the host abort a long download
            );

            return {
              content: [{ type: "text", text: `Successfully pulled ${model}` }],
              details: {
                action,
                model,
                durationMs: Date.now() - startTime,
              } as OllamaToolDetails,
            };
          }

          // Report models currently loaded in memory and their VRAM usage.
          case "ps": {
            const running = await client.isRunning();
            if (!running) {
              return {
                content: [{ type: "text", text: "Ollama is not running." }],
                isError: true,
                details: {
                  action,
                  durationMs: Date.now() - startTime,
                  error: "not_running",
                } as OllamaToolDetails,
              };
            }

            const ps = await client.getRunningModels();
            if (!ps.models || ps.models.length === 0) {
              return {
                content: [
                  {
                    type: "text",
                    text: "No models currently loaded in memory.",
                  },
                ],
                details: {
                  action,
                  modelCount: 0,
                  durationMs: Date.now() - startTime,
                } as OllamaToolDetails,
              };
            }

            const lines = ps.models.map((m) => {
              // size_vram of 0 means the model is running on CPU only.
              const vram =
                m.size_vram > 0
                  ? `${formatModelSize(m.size_vram)} VRAM`
                  : "CPU";
              return `${m.name} — ${formatModelSize(m.size)} total, ${vram}`;
            });

            return {
              content: [
                { type: "text", text: `Loaded models:\n${lines.join("\n")}` },
              ],
              details: {
                action,
                modelCount: ps.models.length,
                durationMs: Date.now() - startTime,
              } as OllamaToolDetails,
            };
          }

          // Delete a local model. NOTE(review): unlike the /ollama remove
          // slash command, no interactive confirmation happens on this path.
          case "remove": {
            if (!model) {
              return {
                content: [
                  {
                    type: "text",
                    text: "Error: 'model' parameter is required for remove action.",
                  },
                ],
                isError: true,
                details: {
                  action,
                  durationMs: Date.now() - startTime,
                  error: "missing_model",
                } as OllamaToolDetails,
              };
            }

            const running = await client.isRunning();
            if (!running) {
              return {
                content: [{ type: "text", text: "Ollama is not running." }],
                isError: true,
                details: {
                  action,
                  model,
                  durationMs: Date.now() - startTime,
                  error: "not_running",
                } as OllamaToolDetails,
              };
            }

            await client.deleteModel(model);
            return {
              content: [
                { type: "text", text: `Successfully removed ${model}` },
              ],
              details: {
                action,
                model,
                durationMs: Date.now() - startTime,
              } as OllamaToolDetails,
            };
          }

          // Detailed metadata for one model from /api/show.
          case "show": {
            if (!model) {
              return {
                content: [
                  {
                    type: "text",
                    text: "Error: 'model' parameter is required for show action.",
                  },
                ],
                isError: true,
                details: {
                  action,
                  durationMs: Date.now() - startTime,
                  error: "missing_model",
                } as OllamaToolDetails,
              };
            }

            const running = await client.isRunning();
            if (!running) {
              return {
                content: [{ type: "text", text: "Ollama is not running." }],
                isError: true,
                details: {
                  action,
                  model,
                  durationMs: Date.now() - startTime,
                  error: "not_running",
                } as OllamaToolDetails,
              };
            }

            const info = await client.showModel(model);
            const details = info.details;
            const infoLines = [
              `Model: ${model}`,
              `Family: ${details.family}`,
              `Parameters: ${details.parameter_size}`,
              `Quantization: ${details.quantization_level}`,
              `Format: ${details.format}`,
            ];
            if (details.families?.length) {
              infoLines.push(`Families: ${details.families.join(", ")}`);
            }
            if (info.parameters) {
              infoLines.push(`\nModelfile parameters:\n${info.parameters}`);
            }

            return {
              content: [{ type: "text", text: infoLines.join("\n") }],
              details: {
                action,
                model,
                durationMs: Date.now() - startTime,
              } as OllamaToolDetails,
            };
          }

          default:
            return {
              content: [{ type: "text", text: `Unknown action: ${action}` }],
              isError: true,
              details: {
                action,
                durationMs: Date.now() - startTime,
                error: "unknown_action",
              } as OllamaToolDetails,
            };
        }
      } catch (err) {
        // Any thrown error (client/network failure) becomes an error result
        // rather than escaping the tool.
        const msg = err instanceof Error ? err.message : String(err);
        return {
          content: [{ type: "text", text: `Ollama error: ${msg}` }],
          isError: true,
          details: {
            action,
            model,
            durationMs: Date.now() - startTime,
            error: msg,
          } as OllamaToolDetails,
        };
      }
    },

    // Collapsed call line: "ollama <action> [model]".
    renderCall(args, theme) {
      let text = theme.fg("toolTitle", theme.bold("ollama "));
      text += theme.fg("accent", args.action);
      if (args.model) {
        text += theme.fg("dim", ` ${args.model}`);
      }
      return new Text(text, 0, 0);
    },

    // Result line: action + optional model count + duration. The expanded
    // view appends the first 10 lines of the text content as a preview.
    renderResult(result, { isPartial, expanded }, theme) {
      const d = result.details as OllamaToolDetails | undefined;

      if (isPartial) return new Text(theme.fg("warning", "Working..."), 0, 0);
      if ((result as any).isError || d?.error) {
        return new Text(
          theme.fg("error", `Error: ${d?.error ?? "unknown"}`),
          0,
          0,
        );
      }

      let text = theme.fg("success", d?.action ?? "done");
      if (d?.modelCount !== undefined) {
        text += theme.fg("dim", ` (${d.modelCount} models)`);
      }
      text += theme.fg("dim", ` ${d?.durationMs ?? 0}ms`);

      if (expanded) {
        const content = result.content[0];
        if (content?.type === "text") {
          const preview = content.text.split("\n").slice(0, 10).join("\n");
          text += "\n\n" + theme.fg("dim", preview);
        }
      }

      return new Text(text, 0, 0);
    },
  });
}
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue