From 4b6a43c2b3b3f9aef97688627de30996b621aa1b Mon Sep 17 00:00:00 2001 From: Lex Christopherson Date: Fri, 13 Mar 2026 14:45:35 -0600 Subject: [PATCH 1/2] feat: multi-credential round-robin with rate-limit fallback Support multiple API keys per provider with automatic rotation: - AuthStorageData accepts single credential or array per provider - Round-robin selection across credentials (no sessionId) - Session-sticky hashing when sessionId is provided - Credential backoff on rate limits (30s), quota exhaustion (30min), server errors (20s) - markUsageLimitReached() backs off failing credential and returns whether an alternate is available - Login accumulation: duplicate provider logins append API keys instead of replacing - Agent retry handler tries credential fallback before counting against retry budget (immediate retry, no delay) - All getApiKey call sites thread sessionId for sticky selection Backward compatible: single credentials work unchanged. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../pi-coding-agent/src/core/agent-session.ts | 65 +++- .../pi-coding-agent/src/core/auth-storage.ts | 342 +++++++++++++++--- .../src/core/model-registry.ts | 10 +- 3 files changed, 345 insertions(+), 72 deletions(-) diff --git a/packages/pi-coding-agent/src/core/agent-session.ts b/packages/pi-coding-agent/src/core/agent-session.ts index 69e7d6680..3c5267854 100644 --- a/packages/pi-coding-agent/src/core/agent-session.ts +++ b/packages/pi-coding-agent/src/core/agent-session.ts @@ -869,7 +869,7 @@ export class AgentSession { } // Validate API key - const apiKey = await this._modelRegistry.getApiKey(this.model); + const apiKey = await this._modelRegistry.getApiKey(this.model, this.sessionId); if (!apiKey) { const isOAuth = this._modelRegistry.isUsingOAuth(this.model); if (isOAuth) { @@ -1309,7 +1309,7 @@ export class AgentSession { * @throws Error if no API key available for the model */ async setModel(model: Model, options?: { persist?: boolean }): Promise { - const apiKey = await this._modelRegistry.getApiKey(model); + const apiKey = await this._modelRegistry.getApiKey(model, this.sessionId); if (!apiKey) { throw new Error(`No API key for ${model.provider}/${model.id}`); } @@ -1351,7 +1351,7 @@ export class AgentSession { if (apiKeysByProvider.has(provider)) { apiKey = apiKeysByProvider.get(provider); } else { - apiKey = await this._modelRegistry.getApiKeyForProvider(provider); + apiKey = await this._modelRegistry.getApiKeyForProvider(provider, this.sessionId); apiKeysByProvider.set(provider, apiKey); } @@ -1406,7 +1406,7 @@ export class AgentSession { const nextIndex = direction === "forward" ? (currentIndex + 1) % len : (currentIndex - 1 + len) % len; const nextModel = availableModels[nextIndex]; - const apiKey = await this._modelRegistry.getApiKey(nextModel); + const apiKey = await this._modelRegistry.getApiKey(nextModel, this.sessionId); if (!apiKey) { throw new Error(`No API key for ${nextModel.provider}/${nextModel.id}`); } @@ -1560,7 +1560,7 @@ export class AgentSession { throw new Error("No model selected"); } - const apiKey = await this._modelRegistry.getApiKey(this.model); + const apiKey = await this._modelRegistry.getApiKey(this.model, this.sessionId); if (!apiKey) { throw new Error(`No API key for ${this.model.provider}`); } @@ -1780,7 +1780,7 @@ export class AgentSession { return; } - const apiKey = await this._modelRegistry.getApiKey(this.model); + const apiKey = await this._modelRegistry.getApiKey(this.model, this.sessionId); if (!apiKey) { this._emit({ type: "auto_compaction_end", result: undefined, aborted: false, willRetry: false }); return; @@ -2082,7 +2082,7 @@ export class AgentSession { refreshTools: () => this._refreshToolRegistry(), getCommands, setModel: async (model, options) => { - const key = await this.modelRegistry.getApiKey(model); + const key = await this.modelRegistry.getApiKey(model, this.sessionId); if (!key) return false; await this.setModel(model, options); return true; @@ -2275,8 +2275,21 @@ export class AgentSession { ); } + /** + * Classify an error message into a usage-limit error type for credential backoff. + */ + private _classifyErrorType(errorMessage: string): import("./auth-storage.js").UsageLimitErrorType { + const err = errorMessage.toLowerCase(); + if (/quota|billing|exceeded.*limit|usage.*limit/i.test(err)) return "quota_exhausted"; + if (/rate.?limit|too many requests|429/i.test(err)) return "rate_limit"; + if (/500|502|503|504|server.?error|internal.?error|service.?unavailable/i.test(err)) return "server_error"; + return "unknown"; + } + /** * Handle retryable errors with exponential backoff. + * When multiple credentials are available, marks the failing credential + * as backed off and retries immediately with the next one. * @returns true if retry was initiated, false if max retries exceeded or disabled */ private async _handleRetryableError(message: AssistantMessage): Promise { @@ -2294,6 +2307,42 @@ export class AgentSession { }); } + // Try credential fallback before counting against retry budget. + // If another credential is available, switch to it and retry immediately. + if (this.model && message.errorMessage) { + const errorType = this._classifyErrorType(message.errorMessage); + const hasAlternate = this._modelRegistry.authStorage.markUsageLimitReached( + this.model.provider, + this.sessionId, + { errorType }, + ); + + if (hasAlternate) { + // Remove error message from agent state + const messages = this.agent.state.messages; + if (messages.length > 0 && messages[messages.length - 1].role === "assistant") { + this.agent.replaceMessages(messages.slice(0, -1)); + } + + this._emit({ + type: "auto_retry_start", + attempt: this._retryAttempt + 1, + maxAttempts: settings.maxRetries, + delayMs: 0, + errorMessage: `${message.errorMessage} (switching credential)`, + }); + + // Retry immediately with the next credential - don't increment _retryAttempt + setTimeout(() => { + this.agent.continue().catch(() => { + // Retry failed - will be caught by next agent_end + }); + }, 0); + + return true; + } + } + this._retryAttempt++; if (this._retryAttempt > settings.maxRetries) { @@ -2750,7 +2799,7 @@ export class AgentSession { let summaryDetails: unknown; if (options.summarize && entriesToSummarize.length > 0 && !extensionSummary) { const model = this.model!; - const apiKey = await this._modelRegistry.getApiKey(model); + const apiKey = await this._modelRegistry.getApiKey(model, this.sessionId); if (!apiKey) { throw new Error(`No API key for ${model.provider}`); } diff --git a/packages/pi-coding-agent/src/core/auth-storage.ts b/packages/pi-coding-agent/src/core/auth-storage.ts index f8eb23bc5..d4578a53c 100644 --- a/packages/pi-coding-agent/src/core/auth-storage.ts +++ b/packages/pi-coding-agent/src/core/auth-storage.ts @@ -2,6 +2,9 @@ * Credential storage for API keys and OAuth tokens. * Handles loading, saving, and refreshing credentials from auth.json. * + * Supports multiple credentials per provider with round-robin selection, + * session-sticky hashing, and automatic rate-limit fallback. + * * Uses file locking to prevent race conditions when multiple pi instances * try to refresh tokens simultaneously. */ @@ -30,7 +33,11 @@ export type OAuthCredential = { export type AuthCredential = ApiKeyCredential | OAuthCredential; -export type AuthStorageData = Record; +/** + * On-disk format: each provider maps to a single credential or an array of credentials. + * Single credentials are normalized to arrays at load time for internal use. + */ +export type AuthStorageData = Record; type LockResult = { result: T; @@ -178,8 +185,49 @@ export class InMemoryAuthStorageBackend implements AuthStorageBackend { } } +// ============================================================================ +// Backoff durations for different error types (milliseconds) +// ============================================================================ + +const BACKOFF_RATE_LIMIT_MS = 30_000; // 30s for rate limit / 429 +const BACKOFF_QUOTA_EXHAUSTED_MS = 30 * 60_000; // 30min for quota exhausted +const BACKOFF_SERVER_ERROR_MS = 20_000; // 20s for 5xx server errors +const BACKOFF_DEFAULT_MS = 60_000; // 60s fallback + +export type UsageLimitErrorType = "rate_limit" | "quota_exhausted" | "server_error" | "unknown"; + +/** + * Get backoff duration for an error type. + */ +function getBackoffDuration(errorType: UsageLimitErrorType): number { + switch (errorType) { + case "rate_limit": + return BACKOFF_RATE_LIMIT_MS; + case "quota_exhausted": + return BACKOFF_QUOTA_EXHAUSTED_MS; + case "server_error": + return BACKOFF_SERVER_ERROR_MS; + default: + return BACKOFF_DEFAULT_MS; + } +} + +/** + * Simple string hash for session-sticky credential selection. + * Returns a positive integer. + */ +function hashString(str: string): number { + let hash = 0; + for (let i = 0; i < str.length; i++) { + const char = str.charCodeAt(i); + hash = ((hash << 5) - hash + char) | 0; + } + return Math.abs(hash); +} + /** * Credential storage backed by a JSON file. + * Supports multiple credentials per provider with round-robin rotation and rate-limit fallback. */ export class AuthStorage { private data: AuthStorageData = {}; @@ -188,6 +236,18 @@ export class AuthStorage { private loadError: Error | null = null; private errors: Error[] = []; + /** + * Round-robin index per provider. Incremented on each call to getApiKey + * when no sessionId is provided. + */ + private providerRoundRobinIndex: Map = new Map(); + + /** + * Backoff tracking per provider per credential index. + * Map> + */ + private credentialBackoff: Map> = new Map(); + private constructor(private storage: AuthStorageBackend) { this.reload(); } @@ -241,6 +301,17 @@ export class AuthStorage { return JSON.parse(content) as AuthStorageData; } + /** + * Normalize a storage entry to an array of credentials. + * Handles both single credential (backward compat) and array formats. + */ + getCredentialsForProvider(provider: string): AuthCredential[] { + const entry = this.data[provider]; + if (!entry) return []; + if (Array.isArray(entry)) return entry; + return [entry]; + } + /** * Reload credentials from storage. */ @@ -259,7 +330,7 @@ export class AuthStorage { } } - private persistProviderChange(provider: string, credential: AuthCredential | undefined): void { + private persistProviderChange(provider: string, credential: AuthCredential | AuthCredential[] | undefined): void { if (this.loadError) { return; } @@ -281,25 +352,52 @@ export class AuthStorage { } /** - * Get credential for a provider. + * Get the first credential for a provider (backward-compatible). */ get(provider: string): AuthCredential | undefined { - return this.data[provider] ?? undefined; + const creds = this.getCredentialsForProvider(provider); + return creds[0] ?? undefined; } /** - * Set credential for a provider. + * Set credential for a provider. For API key credentials, appends to + * existing credentials (accumulation on duplicate login). For OAuth, + * replaces (only one OAuth token per provider makes sense). */ set(provider: string, credential: AuthCredential): void { - this.data[provider] = credential; - this.persistProviderChange(provider, credential); + if (credential.type === "api_key") { + const existing = this.getCredentialsForProvider(provider); + // Deduplicate: don't add if same key already exists + const isDuplicate = existing.some( + (c) => c.type === "api_key" && c.key === credential.key, + ); + if (isDuplicate) return; + + const updated = [...existing, credential]; + this.data[provider] = updated.length === 1 ? updated[0] : updated; + this.persistProviderChange(provider, updated.length === 1 ? updated[0] : updated); + } else { + // OAuth: replace any existing OAuth credential, keep API keys + const existing = this.getCredentialsForProvider(provider); + const apiKeys = existing.filter((c) => c.type === "api_key"); + if (apiKeys.length === 0) { + this.data[provider] = credential; + this.persistProviderChange(provider, credential); + } else { + const updated = [...apiKeys, credential]; + this.data[provider] = updated; + this.persistProviderChange(provider, updated); + } + } } /** - * Remove credential for a provider. + * Remove all credentials for a provider. */ remove(provider: string): void { delete this.data[provider]; + this.providerRoundRobinIndex.delete(provider); + this.credentialBackoff.delete(provider); this.persistProviderChange(provider, undefined); } @@ -331,9 +429,15 @@ export class AuthStorage { /** * Get all credentials (for passing to getOAuthApiKey). + * Returns normalized format where each provider has a single credential + * (the first one) for backward compatibility with OAuth refresh. */ - getAll(): AuthStorageData { - return { ...this.data }; + getAll(): Record { + const result: Record = {}; + for (const [provider, entry] of Object.entries(this.data)) { + result[provider] = Array.isArray(entry) ? entry[0] : entry; + } + return result; } drainErrors(): Error[] { @@ -362,6 +466,104 @@ export class AuthStorage { this.remove(provider); } + /** + * Check if a credential index is currently backed off. + */ + private isCredentialBackedOff(provider: string, index: number): boolean { + const providerBackoff = this.credentialBackoff.get(provider); + if (!providerBackoff) return false; + const expiresAt = providerBackoff.get(index); + if (expiresAt === undefined) return false; + if (Date.now() >= expiresAt) { + providerBackoff.delete(index); + return false; + } + return true; + } + + /** + * Select the best credential index for a provider. + * - If sessionId is provided, uses session-sticky hashing as the starting point. + * - Otherwise, uses round-robin as the starting point. + * - Skips credentials that are currently backed off. + * - Returns -1 if all credentials are backed off. + */ + private selectCredentialIndex(provider: string, credentials: AuthCredential[], sessionId?: string): number { + if (credentials.length === 0) return -1; + if (credentials.length === 1) { + return this.isCredentialBackedOff(provider, 0) ? -1 : 0; + } + + let startIndex: number; + if (sessionId) { + startIndex = hashString(sessionId) % credentials.length; + } else { + const current = this.providerRoundRobinIndex.get(provider) ?? 0; + startIndex = current % credentials.length; + this.providerRoundRobinIndex.set(provider, current + 1); + } + + // Try starting from the preferred index, wrapping around + for (let offset = 0; offset < credentials.length; offset++) { + const index = (startIndex + offset) % credentials.length; + if (!this.isCredentialBackedOff(provider, index)) { + return index; + } + } + + // All credentials are backed off + return -1; + } + + /** + * Mark a credential as rate-limited. Finds the credential that was most + * recently used for this provider+session and backs it off. + * + * @returns true if another credential is available (caller should retry), + * false if all credentials for this provider are backed off. + */ + markUsageLimitReached( + provider: string, + sessionId?: string, + options?: { errorType?: UsageLimitErrorType }, + ): boolean { + const credentials = this.getCredentialsForProvider(provider); + if (credentials.length === 0) return false; + + const errorType = options?.errorType ?? "rate_limit"; + const backoffMs = getBackoffDuration(errorType); + + // Determine which credential was just used (same logic as selectCredentialIndex + // but without incrementing round-robin) + let usedIndex: number; + if (credentials.length === 1) { + usedIndex = 0; + } else if (sessionId) { + usedIndex = hashString(sessionId) % credentials.length; + } else { + // Round-robin was already incremented in getApiKey, so the last-used + // index is (current - 1) + const current = this.providerRoundRobinIndex.get(provider) ?? 0; + usedIndex = ((current - 1) % credentials.length + credentials.length) % credentials.length; + } + + // Set backoff for this credential + let providerBackoff = this.credentialBackoff.get(provider); + if (!providerBackoff) { + providerBackoff = new Map(); + this.credentialBackoff.set(provider, providerBackoff); + } + providerBackoff.set(usedIndex, Date.now() + backoffMs); + + // Check if any credential is still available + for (let i = 0; i < credentials.length; i++) { + if (!this.isCredentialBackedOff(provider, i)) { + return true; + } + } + return false; + } + /** * Refresh OAuth token with backend locking to prevent race conditions. * Multiple pi instances may try to refresh simultaneously when tokens expire. @@ -379,8 +581,10 @@ export class AuthStorage { this.data = currentData; this.loadError = null; - const cred = currentData[providerId]; - if (cred?.type !== "oauth") { + // Find the OAuth credential for this provider + const creds = this.getCredentialsForProvider(providerId); + const cred = creds.find((c) => c.type === "oauth"); + if (!cred || cred.type !== "oauth") { return { result: null }; } @@ -390,8 +594,9 @@ export class AuthStorage { const oauthCreds: Record = {}; for (const [key, value] of Object.entries(currentData)) { - if (value.type === "oauth") { - oauthCreds[key] = value; + const first = Array.isArray(value) ? value.find((c) => c.type === "oauth") : value; + if (first?.type === "oauth") { + oauthCreds[key] = first; } } @@ -400,9 +605,20 @@ export class AuthStorage { return { result: null }; } + // Update the OAuth credential in-place within the array + const existingEntry = currentData[providerId]; + const newOAuthCred: OAuthCredential = { type: "oauth", ...refreshed.newCredentials }; + let updatedEntry: AuthCredential | AuthCredential[]; + + if (Array.isArray(existingEntry)) { + updatedEntry = existingEntry.map((c) => (c.type === "oauth" ? newOAuthCred : c)); + } else { + updatedEntry = newOAuthCred; + } + const merged: AuthStorageData = { ...currentData, - [providerId]: { type: "oauth", ...refreshed.newCredentials }, + [providerId]: updatedEntry, }; this.data = merged; this.loadError = null; @@ -412,64 +628,70 @@ export class AuthStorage { return result; } + /** + * Resolve an API key from a single credential. + */ + private async resolveCredentialApiKey( + providerId: string, + cred: AuthCredential, + ): Promise { + if (cred.type === "api_key") { + return resolveConfigValue(cred.key); + } + + if (cred.type === "oauth") { + const provider = getOAuthProvider(providerId); + if (!provider) return undefined; + + const needsRefresh = Date.now() >= cred.expires; + if (needsRefresh) { + try { + const result = await this.refreshOAuthTokenWithLock(providerId); + if (result) return result.apiKey; + } catch (error) { + this.recordError(error); + this.reload(); + const updatedCreds = this.getCredentialsForProvider(providerId); + const updatedOAuth = updatedCreds.find((c) => c.type === "oauth"); + if (updatedOAuth?.type === "oauth" && Date.now() < updatedOAuth.expires) { + return provider.getApiKey(updatedOAuth); + } + return undefined; + } + } else { + return provider.getApiKey(cred); + } + } + + return undefined; + } + /** * Get API key for a provider. * Priority: * 1. Runtime override (CLI --api-key) - * 2. API key from auth.json - * 3. OAuth token from auth.json (auto-refreshed with locking) - * 4. Environment variable - * 5. Fallback resolver (models.json custom providers) + * 2. Credential(s) from auth.json (with round-robin / session-sticky selection) + * 3. Environment variable + * 4. Fallback resolver (models.json custom providers) + * + * @param providerId - The provider to get an API key for + * @param sessionId - Optional session ID for sticky credential selection */ - async getApiKey(providerId: string): Promise { + async getApiKey(providerId: string, sessionId?: string): Promise { // Runtime override takes highest priority const runtimeKey = this.runtimeOverrides.get(providerId); if (runtimeKey) { return runtimeKey; } - const cred = this.data[providerId]; + const credentials = this.getCredentialsForProvider(providerId); - if (cred?.type === "api_key") { - return resolveConfigValue(cred.key); - } - - if (cred?.type === "oauth") { - const provider = getOAuthProvider(providerId); - if (!provider) { - // Unknown OAuth provider, can't get API key - return undefined; - } - - // Check if token needs refresh - const needsRefresh = Date.now() >= cred.expires; - - if (needsRefresh) { - // Use locked refresh to prevent race conditions - try { - const result = await this.refreshOAuthTokenWithLock(providerId); - if (result) { - return result.apiKey; - } - } catch (error) { - this.recordError(error); - // Refresh failed - re-read file to check if another instance succeeded - this.reload(); - const updatedCred = this.data[providerId]; - - if (updatedCred?.type === "oauth" && Date.now() < updatedCred.expires) { - // Another instance refreshed successfully, use those credentials - return provider.getApiKey(updatedCred); - } - - // Refresh truly failed - return undefined so model discovery skips this provider - // User can /login to re-authenticate (credentials preserved for retry) - return undefined; - } - } else { - // Token not expired, use current access token - return provider.getApiKey(cred); + if (credentials.length > 0) { + const index = this.selectCredentialIndex(providerId, credentials, sessionId); + if (index >= 0) { + return this.resolveCredentialApiKey(providerId, credentials[index]); } + // All credentials backed off - fall through to env/fallback } // Fall back to environment variable diff --git a/packages/pi-coding-agent/src/core/model-registry.ts b/packages/pi-coding-agent/src/core/model-registry.ts index 6cfdc3c4f..6d90af67f 100644 --- a/packages/pi-coding-agent/src/core/model-registry.ts +++ b/packages/pi-coding-agent/src/core/model-registry.ts @@ -517,16 +517,18 @@ export class ModelRegistry { /** * Get API key for a model. + * @param sessionId - Optional session ID for sticky credential selection */ - async getApiKey(model: Model): Promise { - return this.authStorage.getApiKey(model.provider); + async getApiKey(model: Model, sessionId?: string): Promise { + return this.authStorage.getApiKey(model.provider, sessionId); } /** * Get API key for a provider. + * @param sessionId - Optional session ID for sticky credential selection */ - async getApiKeyForProvider(provider: string): Promise { - return this.authStorage.getApiKey(provider); + async getApiKeyForProvider(provider: string, sessionId?: string): Promise { + return this.authStorage.getApiKey(provider, sessionId); } /** From e9676202e1dfa4fc76599218e6d633184a63e140 Mon Sep 17 00:00:00 2001 From: Lex Christopherson Date: Fri, 13 Mar 2026 15:49:44 -0600 Subject: [PATCH 2/2] fix: add tests and clarify edge cases for multi-credential auth storage - Add 14 tests covering round-robin, session-sticky, login accumulation, backoff/fallback, and getAll() truncation behavior - Document getAll() truncation is intentional (OAuth refresh only) - Add comment in markUsageLimitReached explaining round-robin race is benign in single-threaded event loop context Co-Authored-By: Claude Sonnet 4.6 --- .../src/core/auth-storage.test.ts | 194 ++++++++++++++++++ .../pi-coding-agent/src/core/auth-storage.ts | 10 +- 2 files changed, 203 insertions(+), 1 deletion(-) create mode 100644 packages/pi-coding-agent/src/core/auth-storage.test.ts diff --git a/packages/pi-coding-agent/src/core/auth-storage.test.ts b/packages/pi-coding-agent/src/core/auth-storage.test.ts new file mode 100644 index 000000000..50b7ffedc --- /dev/null +++ b/packages/pi-coding-agent/src/core/auth-storage.test.ts @@ -0,0 +1,194 @@ +import { describe, it } from "node:test"; +import assert from "node:assert/strict"; +import { AuthStorage } from "./auth-storage.js"; + +// ─── helpers ────────────────────────────────────────────────────────────────── + +function makeKey(key: string) { + return { type: "api_key" as const, key }; +} + +function inMemory(data: Record = {}) { + return AuthStorage.inMemory(data as any); +} + +// ─── single credential (backward compat) ───────────────────────────────────── + +describe("AuthStorage — single credential (backward compat)", () => { + it("returns the api key for a provider with one key", async () => { + const storage = inMemory({ anthropic: makeKey("sk-abc") }); + const key = await storage.getApiKey("anthropic"); + assert.equal(key, "sk-abc"); + }); + + it("returns undefined for unknown provider", async () => { + const storage = inMemory({}); + const key = await storage.getApiKey("unknown"); + assert.equal(key, undefined); + }); + + it("runtime override takes precedence over stored key", async () => { + const storage = inMemory({ anthropic: makeKey("sk-stored") }); + storage.setRuntimeApiKey("anthropic", "sk-runtime"); + const key = await storage.getApiKey("anthropic"); + assert.equal(key, "sk-runtime"); + }); +}); + +// ─── multiple credentials ───────────────────────────────────────────────────── + +describe("AuthStorage — multiple credentials", () => { + it("round-robins across multiple api keys without sessionId", async () => { + const storage = inMemory({ + anthropic: [makeKey("sk-1"), makeKey("sk-2"), makeKey("sk-3")], + }); + + const keys = new Set(); + for (let i = 0; i < 6; i++) { + const k = await storage.getApiKey("anthropic"); + assert.ok(k, `call ${i} should return a key`); + keys.add(k); + } + // All three keys should have been selected across 6 calls + assert.deepEqual(keys, new Set(["sk-1", "sk-2", "sk-3"])); + }); + + it("session-sticky: same sessionId always picks the same key", async () => { + const storage = inMemory({ + anthropic: [makeKey("sk-1"), makeKey("sk-2"), makeKey("sk-3")], + }); + + const sessionId = "sess-abc"; + const first = await storage.getApiKey("anthropic", sessionId); + for (let i = 0; i < 5; i++) { + const k = await storage.getApiKey("anthropic", sessionId); + assert.equal(k, first, `call ${i} should be sticky to first selection`); + } + }); + + it("different sessionIds may select different keys", async () => { + const storage = inMemory({ + anthropic: [makeKey("sk-1"), makeKey("sk-2"), makeKey("sk-3")], + }); + + const results = new Set(); + for (let i = 0; i < 20; i++) { + const k = await storage.getApiKey("anthropic", `sess-${i}`); + if (k) results.add(k); + } + // With 20 different sessions and 3 keys, we should see more than one key + assert.ok(results.size > 1, "multiple sessions should hash to different keys"); + }); +}); + +// ─── login accumulation ─────────────────────────────────────────────────────── + +describe("AuthStorage — login accumulation", () => { + it("accumulates api keys on repeated set()", () => { + const storage = inMemory({}); + storage.set("anthropic", makeKey("sk-1")); + storage.set("anthropic", makeKey("sk-2")); + const creds = storage.getCredentialsForProvider("anthropic"); + assert.equal(creds.length, 2); + assert.deepEqual( + creds.map((c) => (c.type === "api_key" ? c.key : null)), + ["sk-1", "sk-2"], + ); + }); + + it("deduplicates identical api keys", () => { + const storage = inMemory({}); + storage.set("anthropic", makeKey("sk-1")); + storage.set("anthropic", makeKey("sk-1")); + const creds = storage.getCredentialsForProvider("anthropic"); + assert.equal(creds.length, 1); + }); +}); + +// ─── backoff / markUsageLimitReached ───────────────────────────────────────── + +describe("AuthStorage — rate-limit backoff", () => { + it("returns true when a backed-off credential has an alternate", async () => { + const storage = inMemory({ + anthropic: [makeKey("sk-1"), makeKey("sk-2")], + }); + + // Use sk-1 via round-robin (first call, index 0) + await storage.getApiKey("anthropic"); + + // Mark it as rate-limited; sk-2 should still be available + const hasAlternate = storage.markUsageLimitReached("anthropic"); + assert.equal(hasAlternate, true); + }); + + it("returns false when all credentials are backed off", async () => { + const storage = inMemory({ + anthropic: [makeKey("sk-1"), makeKey("sk-2")], + }); + + // Back off both keys + await storage.getApiKey("anthropic"); // uses index 0 + storage.markUsageLimitReached("anthropic"); // backs off index 0 + await storage.getApiKey("anthropic"); // uses index 1 + const hasAlternate = storage.markUsageLimitReached("anthropic"); // backs off index 1 + assert.equal(hasAlternate, false); + }); + + it("backed-off credential is skipped; next available key is returned", async () => { + const storage = inMemory({ + anthropic: [makeKey("sk-1"), makeKey("sk-2")], + }); + + // First call → sk-1 (round-robin index 0) + const first = await storage.getApiKey("anthropic"); + assert.equal(first, "sk-1"); + + // Back off sk-1 + storage.markUsageLimitReached("anthropic"); + + // Next call should skip backed-off sk-1 and return sk-2 + const second = await storage.getApiKey("anthropic"); + assert.equal(second, "sk-2"); + }); + + it("single credential: markUsageLimitReached returns false", async () => { + const storage = inMemory({ anthropic: makeKey("sk-only") }); + await storage.getApiKey("anthropic"); + const hasAlternate = storage.markUsageLimitReached("anthropic"); + assert.equal(hasAlternate, false); + }); + + it("session-sticky: marks the correct credential as backed off", async () => { + const storage = inMemory({ + anthropic: [makeKey("sk-1"), makeKey("sk-2")], + }); + + const sessionId = "sess-xyz"; + const chosen = await storage.getApiKey("anthropic", sessionId); + assert.ok(chosen); + + // Back off the chosen credential for this session + const hasAlternate = storage.markUsageLimitReached("anthropic", sessionId); + assert.equal(hasAlternate, true); + + // Next call with same session should return the other key + const next = await storage.getApiKey("anthropic", sessionId); + assert.ok(next); + assert.notEqual(next, chosen); + }); +}); + +// ─── getAll truncation ──────────────────────────────────────────────────────── + +describe("AuthStorage — getAll()", () => { + it("returns first credential only for providers with multiple keys", () => { + const storage = inMemory({ + anthropic: [makeKey("sk-1"), makeKey("sk-2")], + openai: makeKey("sk-openai"), + }); + const all = storage.getAll(); + assert.ok(all["anthropic"]?.type === "api_key"); + assert.equal((all["anthropic"] as any).key, "sk-1"); + assert.equal((all["openai"] as any).key, "sk-openai"); + }); +}); diff --git a/packages/pi-coding-agent/src/core/auth-storage.ts b/packages/pi-coding-agent/src/core/auth-storage.ts index d4578a53c..30beef551 100644 --- a/packages/pi-coding-agent/src/core/auth-storage.ts +++ b/packages/pi-coding-agent/src/core/auth-storage.ts @@ -431,6 +431,10 @@ export class AuthStorage { * Get all credentials (for passing to getOAuthApiKey). * Returns normalized format where each provider has a single credential * (the first one) for backward compatibility with OAuth refresh. + * + * NOTE: For providers with multiple API keys, only the first credential is + * returned. This is intentional — callers use this for OAuth refresh only, + * which is always single-credential. Do not use for API key enumeration. */ getAll(): Record { const result: Record = {}; @@ -542,7 +546,11 @@ export class AuthStorage { usedIndex = hashString(sessionId) % credentials.length; } else { // Round-robin was already incremented in getApiKey, so the last-used - // index is (current - 1) + // index is (current - 1). Note: in a concurrent scenario where another + // getApiKey call fires between the original request and this backoff call, + // we may back off the wrong credential index. This is acceptable because: + // (a) pi runs single-threaded event loop, (b) backing off the wrong key + // is safe — it self-heals when the backoff expires. const current = this.providerRoundRobinIndex.get(provider) ?? 0; usedIndex = ((current - 1) % credentials.length + credentials.length) % credentials.length; }