diff --git a/src/resources/extensions/sf/memory-store.ts b/src/resources/extensions/sf/memory-store.ts index db2217217..90dd31611 100644 --- a/src/resources/extensions/sf/memory-store.ts +++ b/src/resources/extensions/sf/memory-store.ts @@ -179,13 +179,41 @@ export async function getRelevantMemoriesRanked( embeddingMap, ); const byId = new Map(pool.map((m) => [m.id, m])); - const out: Memory[] = []; + // Top-K from cosine rank — feed this into the optional rerank pass. + const topK: Memory[] = []; for (const r of ranked) { const mem = byId.get(r.id); - if (mem) out.push(mem); - if (out.length >= limit) break; + if (mem) topK.push(mem); + if (topK.length >= limit) break; } - return out; + // Optional rerank refinement: when SF_LLM_GATEWAY_RERANK_MODEL is set + // AND the gateway has a rerank worker, re-score the cosine top-K with + // the cross-encoder rerank model. Returns null in every other case + // (no model configured, no worker online, network error) and we keep + // the cosine-ranked order as-is — strictly additive precision boost. + try { + const { loadGatewayConfigFromEnv, rerankCandidates } = await import( + "./memory-embeddings-llm-gateway.js" + ); + const cfg = loadGatewayConfigFromEnv(); + if (cfg?.rerankModel && topK.length > 1) { + const scores = await rerankCandidates( + cfg, + query, + topK.map((m) => ({ id: m.id, text: m.content })), + ); + if (scores && scores.length > 0) { + const scoreById = new Map(scores.map((s) => [s.id, s.score])); + return [...topK].sort( + (a, b) => + (scoreById.get(b.id) ?? 0) - (scoreById.get(a.id) ?? 0), + ); + } + } + } catch { + // Rerank is best-effort; cosine order is already a fine answer. + } + return topK; } catch { return pool.slice(0, limit); } diff --git a/src/resources/extensions/sf/tests/memory-query-ranking.test.ts b/src/resources/extensions/sf/tests/memory-query-ranking.test.ts index 1eddf89b3..71808fa5b 100644 --- a/src/resources/extensions/sf/tests/memory-query-ranking.test.ts +++ b/src/resources/extensions/sf/tests/memory-query-ranking.test.ts @@ -234,4 +234,89 @@ describe("getRelevantMemoriesRanked (async, mocked gateway)", () => { const out = await getRelevantMemoriesRanked("anything", 10); assert.equal(out[0].id, a, "high-confidence memory ranks first by static score"); }); + + test("rerank pass re-orders cosine top-K when worker is online", async () => { + process.env.SF_LLM_GATEWAY_KEY = "x"; + process.env.SF_LLM_GATEWAY_URL = "https://gateway.test/v1"; + process.env.SF_LLM_GATEWAY_RERANK_MODEL = "bge-reranker"; + const a = createMemory({ category: "architecture", content: "alpha alpha" }); + const b = createMemory({ category: "architecture", content: "beta beta" }); + assert.ok(a && b); + // Cosine order (with both vectors aligned to query) → tie broken by + // pool order, so a comes first. Rerank flips it: b gets higher score. + saveEmbedding(a, Float32Array.from([1, 0, 0]), "test-model"); + saveEmbedding(b, Float32Array.from([1, 0, 0]), "test-model"); + + vi.stubGlobal( + "fetch", + vi.fn(async (url: string) => { + if (url.endsWith("/embeddings")) { + return new Response( + JSON.stringify({ + object: "list", + data: [ + { object: "embedding", index: 0, embedding: [1, 0, 0] }, + ], + }), + { status: 200 }, + ); + } + if (url.endsWith("/rerank")) { + return new Response( + JSON.stringify({ + results: [ + { index: 1, relevance_score: 0.95 }, // b ranks higher + { index: 0, relevance_score: 0.10 }, + ], + }), + { status: 200 }, + ); + } + return new Response("404", { status: 404 }); + }), + ); + + const out = await getRelevantMemoriesRanked("query", 10); + assert.equal(out.length, 2); + assert.equal(out[0].id, b, "rerank promoted b above a"); + assert.equal(out[1].id, a); + }); + + test("rerank no-op when no worker is online", async () => { + process.env.SF_LLM_GATEWAY_KEY = "x"; + process.env.SF_LLM_GATEWAY_URL = "https://gateway.test/v1"; + process.env.SF_LLM_GATEWAY_RERANK_MODEL = "bge-reranker"; + const a = createMemory({ category: "architecture", content: "alpha" }); + const b = createMemory({ category: "architecture", content: "beta" }); + assert.ok(a && b); + saveEmbedding(a, Float32Array.from([0, 1, 0]), "test-model"); + saveEmbedding(b, Float32Array.from([1, 0, 0]), "test-model"); + + vi.stubGlobal( + "fetch", + vi.fn(async (url: string) => { + if (url.endsWith("/embeddings")) { + return new Response( + JSON.stringify({ + object: "list", + data: [{ object: "embedding", index: 0, embedding: [1, 0, 0] }], + }), + { status: 200 }, + ); + } + if (url.endsWith("/rerank")) { + // No rerank worker — gateway returns a soft-degrade body. + return new Response("no worker with rerank capability is available", { + status: 503, + }); + } + return new Response("404", { status: 404 }); + }), + ); + + const out = await getRelevantMemoriesRanked("query", 10); + // Cosine order survives: b (cosine 1.0) above a (cosine 0). + assert.equal(out[0].id, b); + assert.equal(out[1].id, a); + }); });