From aa60821ec36cea41f9125af2ff5b45b3aaaaad32 Mon Sep 17 00:00:00 2001 From: Mikael Hugo Date: Sat, 2 May 2026 23:20:29 +0200 Subject: [PATCH] feat(sf): wire rerank pass into getRelevantMemoriesRanked MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The gateway rerank surface was shipped dormant in 56ee89a94 — the function existed but no consumer called it, so setting SF_LLM_GATEWAY_RERANK_MODEL did nothing functional. Now: after the cosine-rank top-K is computed, optionally call rerankCandidates(query, top-K) when a rerank model is configured. Re-sort by relevance_score; gracefully fall back to cosine order in every sad path (no model, no worker, network error, malformed response). Strictly additive precision boost — the cosine-only ranking path is unchanged when rerank isn't enabled OR returns null. Two new tests: rerank actively reorders the top-K when scores are returned, and the no-worker-online soft-degrade path preserves cosine order. 12 tests in the file passing. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/resources/extensions/sf/memory-store.ts | 36 +++++++- .../sf/tests/memory-query-ranking.test.ts | 85 +++++++++++++++++++ 2 files changed, 117 insertions(+), 4 deletions(-) diff --git a/src/resources/extensions/sf/memory-store.ts b/src/resources/extensions/sf/memory-store.ts index db2217217..90dd31611 100644 --- a/src/resources/extensions/sf/memory-store.ts +++ b/src/resources/extensions/sf/memory-store.ts @@ -179,13 +179,41 @@ export async function getRelevantMemoriesRanked( embeddingMap, ); const byId = new Map(pool.map((m) => [m.id, m])); - const out: Memory[] = []; + // Top-K from cosine rank — feed this into the optional rerank pass. 
+ const topK: Memory[] = []; for (const r of ranked) { const mem = byId.get(r.id); - if (mem) out.push(mem); - if (out.length >= limit) break; + if (mem) topK.push(mem); + if (topK.length >= limit) break; } - return out; + // Optional rerank refinement: when SF_LLM_GATEWAY_RERANK_MODEL is set + // AND the gateway has a rerank worker, re-score the cosine top-K with + // the cross-encoder rerank model. Returns null in every other case + // (no model configured, no worker online, network error) and we keep + // the cosine-ranked order as-is — strictly additive precision boost. + try { + const { loadGatewayConfigFromEnv, rerankCandidates } = await import( + "./memory-embeddings-llm-gateway.js" + ); + const cfg = loadGatewayConfigFromEnv(); + if (cfg?.rerankModel && topK.length > 1) { + const scores = await rerankCandidates( + cfg, + query, + topK.map((m) => ({ id: m.id, text: m.content })), + ); + if (scores && scores.length > 0) { + const scoreById = new Map(scores.map((s) => [s.id, s.score])); + return [...topK].sort( + (a, b) => + (scoreById.get(b.id) ?? 0) - (scoreById.get(a.id) ?? 0), + ); + } + } + } catch { + // Rerank is best-effort; cosine order is already a fine answer. 
+ } + return topK; } catch { return pool.slice(0, limit); } diff --git a/src/resources/extensions/sf/tests/memory-query-ranking.test.ts b/src/resources/extensions/sf/tests/memory-query-ranking.test.ts index 1eddf89b3..71808fa5b 100644 --- a/src/resources/extensions/sf/tests/memory-query-ranking.test.ts +++ b/src/resources/extensions/sf/tests/memory-query-ranking.test.ts @@ -234,4 +234,89 @@ describe("getRelevantMemoriesRanked (async, mocked gateway)", () => { const out = await getRelevantMemoriesRanked("anything", 10); assert.equal(out[0].id, a, "high-confidence memory ranks first by static score"); }); + + test("rerank pass re-orders cosine top-K when worker is online", async () => { + process.env.SF_LLM_GATEWAY_KEY = "x"; + process.env.SF_LLM_GATEWAY_URL = "https://gateway.test/v1"; + process.env.SF_LLM_GATEWAY_RERANK_MODEL = "bge-reranker"; + const a = createMemory({ category: "architecture", content: "alpha alpha" }); + const b = createMemory({ category: "architecture", content: "beta beta" }); + assert.ok(a && b); + // Cosine order (with both vectors aligned to query) → tie broken by + // pool order, so a comes first. Rerank flips it: b gets higher score. 
+ saveEmbedding(a, Float32Array.from([1, 0, 0]), "test-model"); + saveEmbedding(b, Float32Array.from([1, 0, 0]), "test-model"); + + vi.stubGlobal( + "fetch", + vi.fn(async (url: string) => { + if (url.endsWith("/embeddings")) { + return new Response( + JSON.stringify({ + object: "list", + data: [ + { object: "embedding", index: 0, embedding: [1, 0, 0] }, + ], + }), + { status: 200 }, + ); + } + if (url.endsWith("/rerank")) { + return new Response( + JSON.stringify({ + results: [ + { index: 1, relevance_score: 0.95 }, // b ranks higher + { index: 0, relevance_score: 0.10 }, + ], + }), + { status: 200 }, + ); + } + return new Response("404", { status: 404 }); + }), + ); + + const out = await getRelevantMemoriesRanked("query", 10); + assert.equal(out.length, 2); + assert.equal(out[0].id, b, "rerank promoted b above a"); + assert.equal(out[1].id, a); + }); + + test("rerank no-op when no worker is online", async () => { + process.env.SF_LLM_GATEWAY_KEY = "x"; + process.env.SF_LLM_GATEWAY_URL = "https://gateway.test/v1"; + process.env.SF_LLM_GATEWAY_RERANK_MODEL = "bge-reranker"; + const a = createMemory({ category: "architecture", content: "alpha" }); + const b = createMemory({ category: "architecture", content: "beta" }); + assert.ok(a && b); + saveEmbedding(a, Float32Array.from([0, 1, 0]), "test-model"); + saveEmbedding(b, Float32Array.from([1, 0, 0]), "test-model"); + + vi.stubGlobal( + "fetch", + vi.fn(async (url: string) => { + if (url.endsWith("/embeddings")) { + return new Response( + JSON.stringify({ + object: "list", + data: [{ object: "embedding", index: 0, embedding: [1, 0, 0] }], + }), + { status: 200 }, + ); + } + if (url.endsWith("/rerank")) { + // No rerank worker — gateway returns a soft-degrade body. 
+ return new Response("no worker with rerank capability is available", { + status: 503, + }); + } + return new Response("404", { status: 404 }); + }), + ); + + const out = await getRelevantMemoriesRanked("query", 10); + // Cosine order survives: b (cosine 1.0) above a (cosine 0). + assert.equal(out[0].id, b); + assert.equal(out[1].id, a); + }); });