fix(ollama): add cloud auth support and resolve real context window via /api/show (#4017)
- Add OLLAMA_API_KEY Bearer token auth to all Ollama HTTP client requests
(fetchWithTimeout, pullModel, chat) via getAuthHeaders/withAuth helpers.
Local Ollama ignores the Authorization header; cloud endpoints require it.
- Fix isRunning() probe for cloud endpoints: use /api/tags instead of root /
since cloud hosts may not serve the root endpoint.
- Resolve real context window for unknown models via /api/show model_info
({arch}.context_length) instead of defaulting to 8192. Priority chain:
known table > /api/show > estimate from parameter_size > 8192.
- Use dependency injection for discoverModels() to allow test mocking
without ESM named export issues.
- Pick up OLLAMA_API_KEY in provider registration (apiKey field).
Closes #3544
Co-authored-by: luannevesb <luannevesb@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
843b72e7c1
commit
00c6442e1a
4 changed files with 130 additions and 17 deletions
|
|
@ -69,13 +69,12 @@ async function probeAndRegister(pi: ExtensionAPI): Promise<boolean> {
|
|||
|
||||
const baseUrl = client.getOllamaHost();
|
||||
|
||||
// Use authMode "apiKey" with a dummy key (#3440).
|
||||
// authMode "none" requires a custom streamSimple handler, but Ollama uses
|
||||
// the standard OpenAI-compatible streaming endpoint. Ollama ignores the
|
||||
// Authorization header so the dummy key is harmless.
|
||||
// Use authMode "apiKey" (#3440). Local Ollama ignores the Authorization header,
|
||||
// so the "ollama" fallback is harmless. For cloud endpoints (OLLAMA_HOST pointing
|
||||
// to ollama.com or a remote instance), OLLAMA_API_KEY is picked up here.
|
||||
pi.registerProvider("ollama", {
|
||||
authMode: "apiKey",
|
||||
apiKey: "ollama",
|
||||
apiKey: process.env.OLLAMA_API_KEY ?? "ollama",
|
||||
baseUrl,
|
||||
api: "ollama-chat",
|
||||
streamSimple: streamOllamaChat,
|
||||
|
|
|
|||
|
|
@ -34,11 +34,34 @@ export function getOllamaHost(): string {
|
|||
return `http://${host}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get auth headers for Ollama API requests.
|
||||
* For cloud endpoints (OLLAMA_HOST pointing to ollama.com or remote instances),
|
||||
* OLLAMA_API_KEY is used as a Bearer token. Local Ollama ignores the header.
|
||||
*/
|
||||
function getAuthHeaders(): Record<string, string> {
|
||||
const apiKey = process.env.OLLAMA_API_KEY;
|
||||
if (!apiKey) return {};
|
||||
return { Authorization: `Bearer ${apiKey}` };
|
||||
}
|
||||
|
||||
/**
|
||||
* Merge auth headers into request options.
|
||||
*/
|
||||
function withAuth(options: RequestInit = {}): RequestInit {
|
||||
const authHeaders = getAuthHeaders();
|
||||
if (Object.keys(authHeaders).length === 0) return options;
|
||||
return {
|
||||
...options,
|
||||
headers: { ...authHeaders, ...(options.headers as Record<string, string> || {}) },
|
||||
};
|
||||
}
|
||||
|
||||
async function fetchWithTimeout(url: string, options: RequestInit = {}, timeoutMs = REQUEST_TIMEOUT_MS): Promise<Response> {
|
||||
const controller = new AbortController();
|
||||
const timeout = setTimeout(() => controller.abort(), timeoutMs);
|
||||
try {
|
||||
return await fetch(url, { ...options, signal: controller.signal });
|
||||
return await fetch(url, withAuth({ ...options, signal: controller.signal }));
|
||||
} finally {
|
||||
clearTimeout(timeout);
|
||||
}
|
||||
|
|
@ -46,10 +69,16 @@ async function fetchWithTimeout(url: string, options: RequestInit = {}, timeoutM
|
|||
|
||||
/**
|
||||
* Check if Ollama is running and reachable.
|
||||
* For cloud endpoints (OLLAMA_HOST pointing to ollama.com), uses /api/tags
|
||||
* as the probe since the root endpoint may not be available.
|
||||
*/
|
||||
export async function isRunning(): Promise<boolean> {
|
||||
try {
|
||||
const response = await fetchWithTimeout(`${getOllamaHost()}/`, {}, PROBE_TIMEOUT_MS);
|
||||
const host = getOllamaHost();
|
||||
const isCloud = host.includes("ollama.com") || host.includes("cloud");
|
||||
const probeUrl = isCloud ? `${host}/api/tags` : `${host}/`;
|
||||
const timeout = isCloud ? REQUEST_TIMEOUT_MS : PROBE_TIMEOUT_MS;
|
||||
const response = await fetchWithTimeout(probeUrl, isCloud ? { method: "GET" } : {}, timeout);
|
||||
return response.ok;
|
||||
} catch {
|
||||
return false;
|
||||
|
|
@ -117,12 +146,12 @@ export async function pullModel(
|
|||
onProgress?: (progress: OllamaPullProgress) => void,
|
||||
signal?: AbortSignal,
|
||||
): Promise<void> {
|
||||
const response = await fetch(`${getOllamaHost()}/api/pull`, {
|
||||
const response = await fetch(`${getOllamaHost()}/api/pull`, withAuth({
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ name, stream: true }),
|
||||
signal,
|
||||
});
|
||||
}));
|
||||
|
||||
if (!response.ok) {
|
||||
const text = await response.text();
|
||||
|
|
@ -146,12 +175,12 @@ export async function* chat(
|
|||
request: OllamaChatRequest,
|
||||
signal?: AbortSignal,
|
||||
): AsyncGenerator<OllamaChatResponse> {
|
||||
const response = await fetch(`${getOllamaHost()}/api/chat`, {
|
||||
const response = await fetch(`${getOllamaHost()}/api/chat`, withAuth({
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(request),
|
||||
signal,
|
||||
});
|
||||
}));
|
||||
|
||||
if (!response.ok) {
|
||||
const text = await response.text();
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@
|
|||
* Returns models in the format expected by pi.registerProvider().
|
||||
*/
|
||||
|
||||
import { listModels } from "./ollama-client.js";
|
||||
import { listModels, showModel } from "./ollama-client.js";
|
||||
import {
|
||||
estimateContextFromParams,
|
||||
formatModelSize,
|
||||
|
|
@ -17,6 +17,24 @@ import {
|
|||
} from "./model-capabilities.js";
|
||||
import type { OllamaChatOptions, OllamaModelInfo } from "./types.js";
|
||||
|
||||
/**
|
||||
* Extract context window from /api/show model_info.
|
||||
* Keys follow the pattern "{architecture}.context_length" (e.g. "llama.context_length").
|
||||
*/
|
||||
function extractContextFromModelInfo(modelInfo: Record<string, unknown>): number | undefined {
|
||||
for (const [key, value] of Object.entries(modelInfo)) {
|
||||
if (key.endsWith(".context_length") && typeof value === "number" && value > 0) {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
 * Injectable Ollama client functions used by discoverModels/enrichModel,
 * so tests can stub network calls without mocking ESM named exports.
 */
type ClientDeps = {
  listModels: typeof listModels;
  showModel: typeof showModel;
};
|
||||
|
||||
export interface DiscoveredOllamaModel {
|
||||
id: string;
|
||||
name: string;
|
||||
|
|
@ -35,13 +53,26 @@ export interface DiscoveredOllamaModel {
|
|||
|
||||
const ZERO_COST = { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 };
|
||||
|
||||
function enrichModel(info: OllamaModelInfo): DiscoveredOllamaModel {
|
||||
async function enrichModel(info: OllamaModelInfo, deps: ClientDeps): Promise<DiscoveredOllamaModel> {
|
||||
const caps = getModelCapabilities(info.name);
|
||||
const parameterSize = info.details?.parameter_size ?? "";
|
||||
|
||||
// Determine context window: known table > estimate from param size > default
|
||||
// /api/tags doesn't include context length; /api/show does via "{arch}.context_length" in model_info.
|
||||
let showContextWindow: number | undefined;
|
||||
if (caps.contextWindow === undefined) {
|
||||
try {
|
||||
const showData = await deps.showModel(info.name);
|
||||
showContextWindow = extractContextFromModelInfo(showData.model_info);
|
||||
} catch (err) {
|
||||
// non-fatal: fall through to estimate
|
||||
if (process.env.GSD_DEBUG) console.warn(`[ollama] /api/show failed for ${info.name}:`, err instanceof Error ? err.message : String(err));
|
||||
}
|
||||
}
|
||||
|
||||
// Determine context window: known table > /api/show > estimate from param size > default
|
||||
const contextWindow =
|
||||
caps.contextWindow ??
|
||||
showContextWindow ??
|
||||
(parameterSize ? estimateContextFromParams(parameterSize) : 8192);
|
||||
|
||||
// Determine max tokens: known table > fraction of context > default
|
||||
|
|
@ -73,11 +104,11 @@ function enrichModel(info: OllamaModelInfo): DiscoveredOllamaModel {
|
|||
/**
|
||||
* Discover all locally available Ollama models with enriched capabilities.
|
||||
*/
|
||||
export async function discoverModels(): Promise<DiscoveredOllamaModel[]> {
|
||||
const tags = await listModels();
|
||||
export async function discoverModels(deps: ClientDeps = { listModels, showModel }): Promise<DiscoveredOllamaModel[]> {
|
||||
const tags = await deps.listModels();
|
||||
if (!tags.models || tags.models.length === 0) return [];
|
||||
|
||||
return tags.models.map(enrichModel);
|
||||
return Promise.all(tags.models.map((m) => enrichModel(m, deps)));
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -1 +1,55 @@
|
|||
// GSD2 — Tests for Ollama model discovery and enrichment
|
||||
import { describe, it } from "node:test";
|
||||
import assert from "node:assert/strict";
|
||||
import { discoverModels } from "../ollama-discovery.js";
|
||||
import type { OllamaTagsResponse, OllamaShowResponse } from "../types.js";
|
||||
|
||||
const EMPTY_DETAILS = { parent_model: "", format: "", family: "", families: null, parameter_size: "", quantization_level: "" };
|
||||
|
||||
function modelStub(name: string, parameterSize = "") {
|
||||
return { name, model: name, modified_at: "", size: 0, digest: "", details: { ...EMPTY_DETAILS, parameter_size: parameterSize } };
|
||||
}
|
||||
|
||||
function tagsStub(name: string, parameterSize = ""): OllamaTagsResponse {
|
||||
return { models: [modelStub(name, parameterSize)] };
|
||||
}
|
||||
|
||||
function showStub(modelInfo: Record<string, unknown>): OllamaShowResponse {
|
||||
return { modelfile: "", parameters: "", template: "", details: EMPTY_DETAILS, model_info: modelInfo };
|
||||
}
|
||||
|
||||
// Exercises the context-window priority chain:
// known capability table > /api/show model_info > estimate from parameter_size > 8192.
describe("discoverModels — context window resolution", () => {
  // Known model: the static capabilities table already has a context window,
  // so discovery must not call /api/show at all.
  it("uses known table context window without calling /api/show", async () => {
    let showCalled = false;
    const models = await discoverModels({
      listModels: async () => tagsStub("llama3.2:latest", "3B"),
      showModel: async () => { showCalled = true; throw new Error("should not be called"); },
    });
    assert.equal(models[0].contextWindow, 131072);
    assert.equal(showCalled, false);
  });

  // Unknown model: fall through to /api/show and read the
  // architecture-prefixed "{arch}.context_length" key from model_info.
  it("uses context_length from /api/show model_info for unknown model", async () => {
    const models = await discoverModels({
      listModels: async () => tagsStub("gemini-3-flash-preview:latest"),
      showModel: async () => showStub({ "gemini.context_length": 1048576 }),
    });
    assert.equal(models[0].contextWindow, 1048576);
  });

  // /api/show answered but carried no usable key; with no parameter_size to
  // estimate from, discovery lands on the 8192 default.
  it("falls back to 8192 when /api/show model_info has no context_length key", async () => {
    const models = await discoverModels({
      listModels: async () => tagsStub("unknown-model:latest"),
      showModel: async () => showStub({}),
    });
    assert.equal(models[0].contextWindow, 8192);
  });

  // /api/show failures are non-fatal: discovery still resolves and defaults.
  it("falls back to 8192 when /api/show throws", async () => {
    const models = await discoverModels({
      listModels: async () => tagsStub("unknown-model:latest"),
      showModel: async () => { throw new Error("network error"); },
    });
    assert.equal(models[0].contextWindow, 8192);
  });
});
|
||||
Loading…
Add table
Reference in a new issue