Merge pull request #255 from gsd-build/feat/multi-credential
feat: multi-credential round-robin with rate-limit fallback
This commit is contained in:
commit
2452d34f53
4 changed files with 547 additions and 72 deletions
|
|
@ -869,7 +869,7 @@ export class AgentSession {
|
|||
}
|
||||
|
||||
// Validate API key
|
||||
const apiKey = await this._modelRegistry.getApiKey(this.model);
|
||||
const apiKey = await this._modelRegistry.getApiKey(this.model, this.sessionId);
|
||||
if (!apiKey) {
|
||||
const isOAuth = this._modelRegistry.isUsingOAuth(this.model);
|
||||
if (isOAuth) {
|
||||
|
|
@ -1309,7 +1309,7 @@ export class AgentSession {
|
|||
* @throws Error if no API key available for the model
|
||||
*/
|
||||
async setModel(model: Model<any>, options?: { persist?: boolean }): Promise<void> {
|
||||
const apiKey = await this._modelRegistry.getApiKey(model);
|
||||
const apiKey = await this._modelRegistry.getApiKey(model, this.sessionId);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for ${model.provider}/${model.id}`);
|
||||
}
|
||||
|
|
@ -1351,7 +1351,7 @@ export class AgentSession {
|
|||
if (apiKeysByProvider.has(provider)) {
|
||||
apiKey = apiKeysByProvider.get(provider);
|
||||
} else {
|
||||
apiKey = await this._modelRegistry.getApiKeyForProvider(provider);
|
||||
apiKey = await this._modelRegistry.getApiKeyForProvider(provider, this.sessionId);
|
||||
apiKeysByProvider.set(provider, apiKey);
|
||||
}
|
||||
|
||||
|
|
@ -1406,7 +1406,7 @@ export class AgentSession {
|
|||
const nextIndex = direction === "forward" ? (currentIndex + 1) % len : (currentIndex - 1 + len) % len;
|
||||
const nextModel = availableModels[nextIndex];
|
||||
|
||||
const apiKey = await this._modelRegistry.getApiKey(nextModel);
|
||||
const apiKey = await this._modelRegistry.getApiKey(nextModel, this.sessionId);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for ${nextModel.provider}/${nextModel.id}`);
|
||||
}
|
||||
|
|
@ -1560,7 +1560,7 @@ export class AgentSession {
|
|||
throw new Error("No model selected");
|
||||
}
|
||||
|
||||
const apiKey = await this._modelRegistry.getApiKey(this.model);
|
||||
const apiKey = await this._modelRegistry.getApiKey(this.model, this.sessionId);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for ${this.model.provider}`);
|
||||
}
|
||||
|
|
@ -1780,7 +1780,7 @@ export class AgentSession {
|
|||
return;
|
||||
}
|
||||
|
||||
const apiKey = await this._modelRegistry.getApiKey(this.model);
|
||||
const apiKey = await this._modelRegistry.getApiKey(this.model, this.sessionId);
|
||||
if (!apiKey) {
|
||||
this._emit({ type: "auto_compaction_end", result: undefined, aborted: false, willRetry: false });
|
||||
return;
|
||||
|
|
@ -2082,7 +2082,7 @@ export class AgentSession {
|
|||
refreshTools: () => this._refreshToolRegistry(),
|
||||
getCommands,
|
||||
setModel: async (model, options) => {
|
||||
const key = await this.modelRegistry.getApiKey(model);
|
||||
const key = await this.modelRegistry.getApiKey(model, this.sessionId);
|
||||
if (!key) return false;
|
||||
await this.setModel(model, options);
|
||||
return true;
|
||||
|
|
@ -2282,8 +2282,21 @@ export class AgentSession {
|
|||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Classify an error message into a usage-limit error type for credential backoff.
|
||||
*/
|
||||
private _classifyErrorType(errorMessage: string): import("./auth-storage.js").UsageLimitErrorType {
|
||||
const err = errorMessage.toLowerCase();
|
||||
if (/quota|billing|exceeded.*limit|usage.*limit/i.test(err)) return "quota_exhausted";
|
||||
if (/rate.?limit|too many requests|429/i.test(err)) return "rate_limit";
|
||||
if (/500|502|503|504|server.?error|internal.?error|service.?unavailable/i.test(err)) return "server_error";
|
||||
return "unknown";
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle retryable errors with exponential backoff.
|
||||
* When multiple credentials are available, marks the failing credential
|
||||
* as backed off and retries immediately with the next one.
|
||||
* @returns true if retry was initiated, false if max retries exceeded or disabled
|
||||
*/
|
||||
private async _handleRetryableError(message: AssistantMessage): Promise<boolean> {
|
||||
|
|
@ -2301,6 +2314,42 @@ export class AgentSession {
|
|||
});
|
||||
}
|
||||
|
||||
// Try credential fallback before counting against retry budget.
|
||||
// If another credential is available, switch to it and retry immediately.
|
||||
if (this.model && message.errorMessage) {
|
||||
const errorType = this._classifyErrorType(message.errorMessage);
|
||||
const hasAlternate = this._modelRegistry.authStorage.markUsageLimitReached(
|
||||
this.model.provider,
|
||||
this.sessionId,
|
||||
{ errorType },
|
||||
);
|
||||
|
||||
if (hasAlternate) {
|
||||
// Remove error message from agent state
|
||||
const messages = this.agent.state.messages;
|
||||
if (messages.length > 0 && messages[messages.length - 1].role === "assistant") {
|
||||
this.agent.replaceMessages(messages.slice(0, -1));
|
||||
}
|
||||
|
||||
this._emit({
|
||||
type: "auto_retry_start",
|
||||
attempt: this._retryAttempt + 1,
|
||||
maxAttempts: settings.maxRetries,
|
||||
delayMs: 0,
|
||||
errorMessage: `${message.errorMessage} (switching credential)`,
|
||||
});
|
||||
|
||||
// Retry immediately with the next credential - don't increment _retryAttempt
|
||||
setTimeout(() => {
|
||||
this.agent.continue().catch(() => {
|
||||
// Retry failed - will be caught by next agent_end
|
||||
});
|
||||
}, 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
this._retryAttempt++;
|
||||
|
||||
if (this._retryAttempt > settings.maxRetries) {
|
||||
|
|
@ -2757,7 +2806,7 @@ export class AgentSession {
|
|||
let summaryDetails: unknown;
|
||||
if (options.summarize && entriesToSummarize.length > 0 && !extensionSummary) {
|
||||
const model = this.model!;
|
||||
const apiKey = await this._modelRegistry.getApiKey(model);
|
||||
const apiKey = await this._modelRegistry.getApiKey(model, this.sessionId);
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for ${model.provider}`);
|
||||
}
|
||||
|
|
|
|||
194
packages/pi-coding-agent/src/core/auth-storage.test.ts
Normal file
194
packages/pi-coding-agent/src/core/auth-storage.test.ts
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
import { describe, it } from "node:test";
|
||||
import assert from "node:assert/strict";
|
||||
import { AuthStorage } from "./auth-storage.js";
|
||||
|
||||
// ─── helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
function makeKey(key: string) {
|
||||
return { type: "api_key" as const, key };
|
||||
}
|
||||
|
||||
function inMemory(data: Record<string, unknown> = {}) {
|
||||
return AuthStorage.inMemory(data as any);
|
||||
}
|
||||
|
||||
// ─── single credential (backward compat) ─────────────────────────────────────
|
||||
|
||||
describe("AuthStorage — single credential (backward compat)", () => {
|
||||
it("returns the api key for a provider with one key", async () => {
|
||||
const storage = inMemory({ anthropic: makeKey("sk-abc") });
|
||||
const key = await storage.getApiKey("anthropic");
|
||||
assert.equal(key, "sk-abc");
|
||||
});
|
||||
|
||||
it("returns undefined for unknown provider", async () => {
|
||||
const storage = inMemory({});
|
||||
const key = await storage.getApiKey("unknown");
|
||||
assert.equal(key, undefined);
|
||||
});
|
||||
|
||||
it("runtime override takes precedence over stored key", async () => {
|
||||
const storage = inMemory({ anthropic: makeKey("sk-stored") });
|
||||
storage.setRuntimeApiKey("anthropic", "sk-runtime");
|
||||
const key = await storage.getApiKey("anthropic");
|
||||
assert.equal(key, "sk-runtime");
|
||||
});
|
||||
});
|
||||
|
||||
// ─── multiple credentials ─────────────────────────────────────────────────────
|
||||
|
||||
describe("AuthStorage — multiple credentials", () => {
|
||||
it("round-robins across multiple api keys without sessionId", async () => {
|
||||
const storage = inMemory({
|
||||
anthropic: [makeKey("sk-1"), makeKey("sk-2"), makeKey("sk-3")],
|
||||
});
|
||||
|
||||
const keys = new Set<string>();
|
||||
for (let i = 0; i < 6; i++) {
|
||||
const k = await storage.getApiKey("anthropic");
|
||||
assert.ok(k, `call ${i} should return a key`);
|
||||
keys.add(k);
|
||||
}
|
||||
// All three keys should have been selected across 6 calls
|
||||
assert.deepEqual(keys, new Set(["sk-1", "sk-2", "sk-3"]));
|
||||
});
|
||||
|
||||
it("session-sticky: same sessionId always picks the same key", async () => {
|
||||
const storage = inMemory({
|
||||
anthropic: [makeKey("sk-1"), makeKey("sk-2"), makeKey("sk-3")],
|
||||
});
|
||||
|
||||
const sessionId = "sess-abc";
|
||||
const first = await storage.getApiKey("anthropic", sessionId);
|
||||
for (let i = 0; i < 5; i++) {
|
||||
const k = await storage.getApiKey("anthropic", sessionId);
|
||||
assert.equal(k, first, `call ${i} should be sticky to first selection`);
|
||||
}
|
||||
});
|
||||
|
||||
it("different sessionIds may select different keys", async () => {
|
||||
const storage = inMemory({
|
||||
anthropic: [makeKey("sk-1"), makeKey("sk-2"), makeKey("sk-3")],
|
||||
});
|
||||
|
||||
const results = new Set<string>();
|
||||
for (let i = 0; i < 20; i++) {
|
||||
const k = await storage.getApiKey("anthropic", `sess-${i}`);
|
||||
if (k) results.add(k);
|
||||
}
|
||||
// With 20 different sessions and 3 keys, we should see more than one key
|
||||
assert.ok(results.size > 1, "multiple sessions should hash to different keys");
|
||||
});
|
||||
});
|
||||
|
||||
// ─── login accumulation ───────────────────────────────────────────────────────
|
||||
|
||||
describe("AuthStorage — login accumulation", () => {
|
||||
it("accumulates api keys on repeated set()", () => {
|
||||
const storage = inMemory({});
|
||||
storage.set("anthropic", makeKey("sk-1"));
|
||||
storage.set("anthropic", makeKey("sk-2"));
|
||||
const creds = storage.getCredentialsForProvider("anthropic");
|
||||
assert.equal(creds.length, 2);
|
||||
assert.deepEqual(
|
||||
creds.map((c) => (c.type === "api_key" ? c.key : null)),
|
||||
["sk-1", "sk-2"],
|
||||
);
|
||||
});
|
||||
|
||||
it("deduplicates identical api keys", () => {
|
||||
const storage = inMemory({});
|
||||
storage.set("anthropic", makeKey("sk-1"));
|
||||
storage.set("anthropic", makeKey("sk-1"));
|
||||
const creds = storage.getCredentialsForProvider("anthropic");
|
||||
assert.equal(creds.length, 1);
|
||||
});
|
||||
});
|
||||
|
||||
// ─── backoff / markUsageLimitReached ─────────────────────────────────────────
|
||||
|
||||
describe("AuthStorage — rate-limit backoff", () => {
|
||||
it("returns true when a backed-off credential has an alternate", async () => {
|
||||
const storage = inMemory({
|
||||
anthropic: [makeKey("sk-1"), makeKey("sk-2")],
|
||||
});
|
||||
|
||||
// Use sk-1 via round-robin (first call, index 0)
|
||||
await storage.getApiKey("anthropic");
|
||||
|
||||
// Mark it as rate-limited; sk-2 should still be available
|
||||
const hasAlternate = storage.markUsageLimitReached("anthropic");
|
||||
assert.equal(hasAlternate, true);
|
||||
});
|
||||
|
||||
it("returns false when all credentials are backed off", async () => {
|
||||
const storage = inMemory({
|
||||
anthropic: [makeKey("sk-1"), makeKey("sk-2")],
|
||||
});
|
||||
|
||||
// Back off both keys
|
||||
await storage.getApiKey("anthropic"); // uses index 0
|
||||
storage.markUsageLimitReached("anthropic"); // backs off index 0
|
||||
await storage.getApiKey("anthropic"); // uses index 1
|
||||
const hasAlternate = storage.markUsageLimitReached("anthropic"); // backs off index 1
|
||||
assert.equal(hasAlternate, false);
|
||||
});
|
||||
|
||||
it("backed-off credential is skipped; next available key is returned", async () => {
|
||||
const storage = inMemory({
|
||||
anthropic: [makeKey("sk-1"), makeKey("sk-2")],
|
||||
});
|
||||
|
||||
// First call → sk-1 (round-robin index 0)
|
||||
const first = await storage.getApiKey("anthropic");
|
||||
assert.equal(first, "sk-1");
|
||||
|
||||
// Back off sk-1
|
||||
storage.markUsageLimitReached("anthropic");
|
||||
|
||||
// Next call should skip backed-off sk-1 and return sk-2
|
||||
const second = await storage.getApiKey("anthropic");
|
||||
assert.equal(second, "sk-2");
|
||||
});
|
||||
|
||||
it("single credential: markUsageLimitReached returns false", async () => {
|
||||
const storage = inMemory({ anthropic: makeKey("sk-only") });
|
||||
await storage.getApiKey("anthropic");
|
||||
const hasAlternate = storage.markUsageLimitReached("anthropic");
|
||||
assert.equal(hasAlternate, false);
|
||||
});
|
||||
|
||||
it("session-sticky: marks the correct credential as backed off", async () => {
|
||||
const storage = inMemory({
|
||||
anthropic: [makeKey("sk-1"), makeKey("sk-2")],
|
||||
});
|
||||
|
||||
const sessionId = "sess-xyz";
|
||||
const chosen = await storage.getApiKey("anthropic", sessionId);
|
||||
assert.ok(chosen);
|
||||
|
||||
// Back off the chosen credential for this session
|
||||
const hasAlternate = storage.markUsageLimitReached("anthropic", sessionId);
|
||||
assert.equal(hasAlternate, true);
|
||||
|
||||
// Next call with same session should return the other key
|
||||
const next = await storage.getApiKey("anthropic", sessionId);
|
||||
assert.ok(next);
|
||||
assert.notEqual(next, chosen);
|
||||
});
|
||||
});
|
||||
|
||||
// ─── getAll truncation ────────────────────────────────────────────────────────
|
||||
|
||||
describe("AuthStorage — getAll()", () => {
|
||||
it("returns first credential only for providers with multiple keys", () => {
|
||||
const storage = inMemory({
|
||||
anthropic: [makeKey("sk-1"), makeKey("sk-2")],
|
||||
openai: makeKey("sk-openai"),
|
||||
});
|
||||
const all = storage.getAll();
|
||||
assert.ok(all["anthropic"]?.type === "api_key");
|
||||
assert.equal((all["anthropic"] as any).key, "sk-1");
|
||||
assert.equal((all["openai"] as any).key, "sk-openai");
|
||||
});
|
||||
});
|
||||
|
|
@ -2,6 +2,9 @@
|
|||
* Credential storage for API keys and OAuth tokens.
|
||||
* Handles loading, saving, and refreshing credentials from auth.json.
|
||||
*
|
||||
* Supports multiple credentials per provider with round-robin selection,
|
||||
* session-sticky hashing, and automatic rate-limit fallback.
|
||||
*
|
||||
* Uses file locking to prevent race conditions when multiple pi instances
|
||||
* try to refresh tokens simultaneously.
|
||||
*/
|
||||
|
|
@ -30,7 +33,11 @@ export type OAuthCredential = {
|
|||
|
||||
export type AuthCredential = ApiKeyCredential | OAuthCredential;
|
||||
|
||||
export type AuthStorageData = Record<string, AuthCredential>;
|
||||
/**
|
||||
* On-disk format: each provider maps to a single credential or an array of credentials.
|
||||
* Single credentials are normalized to arrays at load time for internal use.
|
||||
*/
|
||||
export type AuthStorageData = Record<string, AuthCredential | AuthCredential[]>;
|
||||
|
||||
type LockResult<T> = {
|
||||
result: T;
|
||||
|
|
@ -178,8 +185,49 @@ export class InMemoryAuthStorageBackend implements AuthStorageBackend {
|
|||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Backoff durations for different error types (milliseconds)
|
||||
// ============================================================================
|
||||
|
||||
const BACKOFF_RATE_LIMIT_MS = 30_000; // 30s for rate limit / 429
|
||||
const BACKOFF_QUOTA_EXHAUSTED_MS = 30 * 60_000; // 30min for quota exhausted
|
||||
const BACKOFF_SERVER_ERROR_MS = 20_000; // 20s for 5xx server errors
|
||||
const BACKOFF_DEFAULT_MS = 60_000; // 60s fallback
|
||||
|
||||
export type UsageLimitErrorType = "rate_limit" | "quota_exhausted" | "server_error" | "unknown";
|
||||
|
||||
/**
|
||||
* Get backoff duration for an error type.
|
||||
*/
|
||||
function getBackoffDuration(errorType: UsageLimitErrorType): number {
|
||||
switch (errorType) {
|
||||
case "rate_limit":
|
||||
return BACKOFF_RATE_LIMIT_MS;
|
||||
case "quota_exhausted":
|
||||
return BACKOFF_QUOTA_EXHAUSTED_MS;
|
||||
case "server_error":
|
||||
return BACKOFF_SERVER_ERROR_MS;
|
||||
default:
|
||||
return BACKOFF_DEFAULT_MS;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Simple string hash for session-sticky credential selection.
|
||||
* Returns a positive integer.
|
||||
*/
|
||||
function hashString(str: string): number {
|
||||
let hash = 0;
|
||||
for (let i = 0; i < str.length; i++) {
|
||||
const char = str.charCodeAt(i);
|
||||
hash = ((hash << 5) - hash + char) | 0;
|
||||
}
|
||||
return Math.abs(hash);
|
||||
}
|
||||
|
||||
/**
|
||||
* Credential storage backed by a JSON file.
|
||||
* Supports multiple credentials per provider with round-robin rotation and rate-limit fallback.
|
||||
*/
|
||||
export class AuthStorage {
|
||||
private data: AuthStorageData = {};
|
||||
|
|
@ -188,6 +236,18 @@ export class AuthStorage {
|
|||
private loadError: Error | null = null;
|
||||
private errors: Error[] = [];
|
||||
|
||||
/**
|
||||
* Round-robin index per provider. Incremented on each call to getApiKey
|
||||
* when no sessionId is provided.
|
||||
*/
|
||||
private providerRoundRobinIndex: Map<string, number> = new Map();
|
||||
|
||||
/**
|
||||
* Backoff tracking per provider per credential index.
|
||||
* Map<provider, Map<credentialIndex, backoffExpiresAt>>
|
||||
*/
|
||||
private credentialBackoff: Map<string, Map<number, number>> = new Map();
|
||||
|
||||
private constructor(private storage: AuthStorageBackend) {
|
||||
this.reload();
|
||||
}
|
||||
|
|
@ -241,6 +301,17 @@ export class AuthStorage {
|
|||
return JSON.parse(content) as AuthStorageData;
|
||||
}
|
||||
|
||||
/**
|
||||
* Normalize a storage entry to an array of credentials.
|
||||
* Handles both single credential (backward compat) and array formats.
|
||||
*/
|
||||
getCredentialsForProvider(provider: string): AuthCredential[] {
|
||||
const entry = this.data[provider];
|
||||
if (!entry) return [];
|
||||
if (Array.isArray(entry)) return entry;
|
||||
return [entry];
|
||||
}
|
||||
|
||||
/**
|
||||
* Reload credentials from storage.
|
||||
*/
|
||||
|
|
@ -259,7 +330,7 @@ export class AuthStorage {
|
|||
}
|
||||
}
|
||||
|
||||
private persistProviderChange(provider: string, credential: AuthCredential | undefined): void {
|
||||
private persistProviderChange(provider: string, credential: AuthCredential | AuthCredential[] | undefined): void {
|
||||
if (this.loadError) {
|
||||
return;
|
||||
}
|
||||
|
|
@ -281,25 +352,52 @@ export class AuthStorage {
|
|||
}
|
||||
|
||||
/**
|
||||
* Get credential for a provider.
|
||||
* Get the first credential for a provider (backward-compatible).
|
||||
*/
|
||||
get(provider: string): AuthCredential | undefined {
|
||||
return this.data[provider] ?? undefined;
|
||||
const creds = this.getCredentialsForProvider(provider);
|
||||
return creds[0] ?? undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set credential for a provider.
|
||||
* Set credential for a provider. For API key credentials, appends to
|
||||
* existing credentials (accumulation on duplicate login). For OAuth,
|
||||
* replaces (only one OAuth token per provider makes sense).
|
||||
*/
|
||||
set(provider: string, credential: AuthCredential): void {
|
||||
this.data[provider] = credential;
|
||||
this.persistProviderChange(provider, credential);
|
||||
if (credential.type === "api_key") {
|
||||
const existing = this.getCredentialsForProvider(provider);
|
||||
// Deduplicate: don't add if same key already exists
|
||||
const isDuplicate = existing.some(
|
||||
(c) => c.type === "api_key" && c.key === credential.key,
|
||||
);
|
||||
if (isDuplicate) return;
|
||||
|
||||
const updated = [...existing, credential];
|
||||
this.data[provider] = updated.length === 1 ? updated[0] : updated;
|
||||
this.persistProviderChange(provider, updated.length === 1 ? updated[0] : updated);
|
||||
} else {
|
||||
// OAuth: replace any existing OAuth credential, keep API keys
|
||||
const existing = this.getCredentialsForProvider(provider);
|
||||
const apiKeys = existing.filter((c) => c.type === "api_key");
|
||||
if (apiKeys.length === 0) {
|
||||
this.data[provider] = credential;
|
||||
this.persistProviderChange(provider, credential);
|
||||
} else {
|
||||
const updated = [...apiKeys, credential];
|
||||
this.data[provider] = updated;
|
||||
this.persistProviderChange(provider, updated);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove credential for a provider.
|
||||
* Remove all credentials for a provider.
|
||||
*/
|
||||
remove(provider: string): void {
|
||||
delete this.data[provider];
|
||||
this.providerRoundRobinIndex.delete(provider);
|
||||
this.credentialBackoff.delete(provider);
|
||||
this.persistProviderChange(provider, undefined);
|
||||
}
|
||||
|
||||
|
|
@ -331,9 +429,19 @@ export class AuthStorage {
|
|||
|
||||
/**
|
||||
* Get all credentials (for passing to getOAuthApiKey).
|
||||
* Returns normalized format where each provider has a single credential
|
||||
* (the first one) for backward compatibility with OAuth refresh.
|
||||
*
|
||||
* NOTE: For providers with multiple API keys, only the first credential is
|
||||
* returned. This is intentional — callers use this for OAuth refresh only,
|
||||
* which is always single-credential. Do not use for API key enumeration.
|
||||
*/
|
||||
getAll(): AuthStorageData {
|
||||
return { ...this.data };
|
||||
getAll(): Record<string, AuthCredential> {
|
||||
const result: Record<string, AuthCredential> = {};
|
||||
for (const [provider, entry] of Object.entries(this.data)) {
|
||||
result[provider] = Array.isArray(entry) ? entry[0] : entry;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
drainErrors(): Error[] {
|
||||
|
|
@ -362,6 +470,108 @@ export class AuthStorage {
|
|||
this.remove(provider);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a credential index is currently backed off.
|
||||
*/
|
||||
private isCredentialBackedOff(provider: string, index: number): boolean {
|
||||
const providerBackoff = this.credentialBackoff.get(provider);
|
||||
if (!providerBackoff) return false;
|
||||
const expiresAt = providerBackoff.get(index);
|
||||
if (expiresAt === undefined) return false;
|
||||
if (Date.now() >= expiresAt) {
|
||||
providerBackoff.delete(index);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Select the best credential index for a provider.
|
||||
* - If sessionId is provided, uses session-sticky hashing as the starting point.
|
||||
* - Otherwise, uses round-robin as the starting point.
|
||||
* - Skips credentials that are currently backed off.
|
||||
* - Returns -1 if all credentials are backed off.
|
||||
*/
|
||||
private selectCredentialIndex(provider: string, credentials: AuthCredential[], sessionId?: string): number {
|
||||
if (credentials.length === 0) return -1;
|
||||
if (credentials.length === 1) {
|
||||
return this.isCredentialBackedOff(provider, 0) ? -1 : 0;
|
||||
}
|
||||
|
||||
let startIndex: number;
|
||||
if (sessionId) {
|
||||
startIndex = hashString(sessionId) % credentials.length;
|
||||
} else {
|
||||
const current = this.providerRoundRobinIndex.get(provider) ?? 0;
|
||||
startIndex = current % credentials.length;
|
||||
this.providerRoundRobinIndex.set(provider, current + 1);
|
||||
}
|
||||
|
||||
// Try starting from the preferred index, wrapping around
|
||||
for (let offset = 0; offset < credentials.length; offset++) {
|
||||
const index = (startIndex + offset) % credentials.length;
|
||||
if (!this.isCredentialBackedOff(provider, index)) {
|
||||
return index;
|
||||
}
|
||||
}
|
||||
|
||||
// All credentials are backed off
|
||||
return -1;
|
||||
}
|
||||
|
||||
/**
|
||||
* Mark a credential as rate-limited. Finds the credential that was most
|
||||
* recently used for this provider+session and backs it off.
|
||||
*
|
||||
* @returns true if another credential is available (caller should retry),
|
||||
* false if all credentials for this provider are backed off.
|
||||
*/
|
||||
markUsageLimitReached(
|
||||
provider: string,
|
||||
sessionId?: string,
|
||||
options?: { errorType?: UsageLimitErrorType },
|
||||
): boolean {
|
||||
const credentials = this.getCredentialsForProvider(provider);
|
||||
if (credentials.length === 0) return false;
|
||||
|
||||
const errorType = options?.errorType ?? "rate_limit";
|
||||
const backoffMs = getBackoffDuration(errorType);
|
||||
|
||||
// Determine which credential was just used (same logic as selectCredentialIndex
|
||||
// but without incrementing round-robin)
|
||||
let usedIndex: number;
|
||||
if (credentials.length === 1) {
|
||||
usedIndex = 0;
|
||||
} else if (sessionId) {
|
||||
usedIndex = hashString(sessionId) % credentials.length;
|
||||
} else {
|
||||
// Round-robin was already incremented in getApiKey, so the last-used
|
||||
// index is (current - 1). Note: in a concurrent scenario where another
|
||||
// getApiKey call fires between the original request and this backoff call,
|
||||
// we may back off the wrong credential index. This is acceptable because:
|
||||
// (a) pi runs single-threaded event loop, (b) backing off the wrong key
|
||||
// is safe — it self-heals when the backoff expires.
|
||||
const current = this.providerRoundRobinIndex.get(provider) ?? 0;
|
||||
usedIndex = ((current - 1) % credentials.length + credentials.length) % credentials.length;
|
||||
}
|
||||
|
||||
// Set backoff for this credential
|
||||
let providerBackoff = this.credentialBackoff.get(provider);
|
||||
if (!providerBackoff) {
|
||||
providerBackoff = new Map();
|
||||
this.credentialBackoff.set(provider, providerBackoff);
|
||||
}
|
||||
providerBackoff.set(usedIndex, Date.now() + backoffMs);
|
||||
|
||||
// Check if any credential is still available
|
||||
for (let i = 0; i < credentials.length; i++) {
|
||||
if (!this.isCredentialBackedOff(provider, i)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Refresh OAuth token with backend locking to prevent race conditions.
|
||||
* Multiple pi instances may try to refresh simultaneously when tokens expire.
|
||||
|
|
@ -379,8 +589,10 @@ export class AuthStorage {
|
|||
this.data = currentData;
|
||||
this.loadError = null;
|
||||
|
||||
const cred = currentData[providerId];
|
||||
if (cred?.type !== "oauth") {
|
||||
// Find the OAuth credential for this provider
|
||||
const creds = this.getCredentialsForProvider(providerId);
|
||||
const cred = creds.find((c) => c.type === "oauth");
|
||||
if (!cred || cred.type !== "oauth") {
|
||||
return { result: null };
|
||||
}
|
||||
|
||||
|
|
@ -390,8 +602,9 @@ export class AuthStorage {
|
|||
|
||||
const oauthCreds: Record<string, OAuthCredentials> = {};
|
||||
for (const [key, value] of Object.entries(currentData)) {
|
||||
if (value.type === "oauth") {
|
||||
oauthCreds[key] = value;
|
||||
const first = Array.isArray(value) ? value.find((c) => c.type === "oauth") : value;
|
||||
if (first?.type === "oauth") {
|
||||
oauthCreds[key] = first;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -400,9 +613,20 @@ export class AuthStorage {
|
|||
return { result: null };
|
||||
}
|
||||
|
||||
// Update the OAuth credential in-place within the array
|
||||
const existingEntry = currentData[providerId];
|
||||
const newOAuthCred: OAuthCredential = { type: "oauth", ...refreshed.newCredentials };
|
||||
let updatedEntry: AuthCredential | AuthCredential[];
|
||||
|
||||
if (Array.isArray(existingEntry)) {
|
||||
updatedEntry = existingEntry.map((c) => (c.type === "oauth" ? newOAuthCred : c));
|
||||
} else {
|
||||
updatedEntry = newOAuthCred;
|
||||
}
|
||||
|
||||
const merged: AuthStorageData = {
|
||||
...currentData,
|
||||
[providerId]: { type: "oauth", ...refreshed.newCredentials },
|
||||
[providerId]: updatedEntry,
|
||||
};
|
||||
this.data = merged;
|
||||
this.loadError = null;
|
||||
|
|
@ -412,64 +636,70 @@ export class AuthStorage {
|
|||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve an API key from a single credential.
|
||||
*/
|
||||
private async resolveCredentialApiKey(
|
||||
providerId: string,
|
||||
cred: AuthCredential,
|
||||
): Promise<string | undefined> {
|
||||
if (cred.type === "api_key") {
|
||||
return resolveConfigValue(cred.key);
|
||||
}
|
||||
|
||||
if (cred.type === "oauth") {
|
||||
const provider = getOAuthProvider(providerId);
|
||||
if (!provider) return undefined;
|
||||
|
||||
const needsRefresh = Date.now() >= cred.expires;
|
||||
if (needsRefresh) {
|
||||
try {
|
||||
const result = await this.refreshOAuthTokenWithLock(providerId);
|
||||
if (result) return result.apiKey;
|
||||
} catch (error) {
|
||||
this.recordError(error);
|
||||
this.reload();
|
||||
const updatedCreds = this.getCredentialsForProvider(providerId);
|
||||
const updatedOAuth = updatedCreds.find((c) => c.type === "oauth");
|
||||
if (updatedOAuth?.type === "oauth" && Date.now() < updatedOAuth.expires) {
|
||||
return provider.getApiKey(updatedOAuth);
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
} else {
|
||||
return provider.getApiKey(cred);
|
||||
}
|
||||
}
|
||||
|
||||
return undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for a provider.
|
||||
* Priority:
|
||||
* 1. Runtime override (CLI --api-key)
|
||||
* 2. API key from auth.json
|
||||
* 3. OAuth token from auth.json (auto-refreshed with locking)
|
||||
* 4. Environment variable
|
||||
* 5. Fallback resolver (models.json custom providers)
|
||||
* 2. Credential(s) from auth.json (with round-robin / session-sticky selection)
|
||||
* 3. Environment variable
|
||||
* 4. Fallback resolver (models.json custom providers)
|
||||
*
|
||||
* @param providerId - The provider to get an API key for
|
||||
* @param sessionId - Optional session ID for sticky credential selection
|
||||
*/
|
||||
async getApiKey(providerId: string): Promise<string | undefined> {
|
||||
async getApiKey(providerId: string, sessionId?: string): Promise<string | undefined> {
|
||||
// Runtime override takes highest priority
|
||||
const runtimeKey = this.runtimeOverrides.get(providerId);
|
||||
if (runtimeKey) {
|
||||
return runtimeKey;
|
||||
}
|
||||
|
||||
const cred = this.data[providerId];
|
||||
const credentials = this.getCredentialsForProvider(providerId);
|
||||
|
||||
if (cred?.type === "api_key") {
|
||||
return resolveConfigValue(cred.key);
|
||||
}
|
||||
|
||||
if (cred?.type === "oauth") {
|
||||
const provider = getOAuthProvider(providerId);
|
||||
if (!provider) {
|
||||
// Unknown OAuth provider, can't get API key
|
||||
return undefined;
|
||||
}
|
||||
|
||||
// Check if token needs refresh
|
||||
const needsRefresh = Date.now() >= cred.expires;
|
||||
|
||||
if (needsRefresh) {
|
||||
// Use locked refresh to prevent race conditions
|
||||
try {
|
||||
const result = await this.refreshOAuthTokenWithLock(providerId);
|
||||
if (result) {
|
||||
return result.apiKey;
|
||||
}
|
||||
} catch (error) {
|
||||
this.recordError(error);
|
||||
// Refresh failed - re-read file to check if another instance succeeded
|
||||
this.reload();
|
||||
const updatedCred = this.data[providerId];
|
||||
|
||||
if (updatedCred?.type === "oauth" && Date.now() < updatedCred.expires) {
|
||||
// Another instance refreshed successfully, use those credentials
|
||||
return provider.getApiKey(updatedCred);
|
||||
}
|
||||
|
||||
// Refresh truly failed - return undefined so model discovery skips this provider
|
||||
// User can /login to re-authenticate (credentials preserved for retry)
|
||||
return undefined;
|
||||
}
|
||||
} else {
|
||||
// Token not expired, use current access token
|
||||
return provider.getApiKey(cred);
|
||||
if (credentials.length > 0) {
|
||||
const index = this.selectCredentialIndex(providerId, credentials, sessionId);
|
||||
if (index >= 0) {
|
||||
return this.resolveCredentialApiKey(providerId, credentials[index]);
|
||||
}
|
||||
// All credentials backed off - fall through to env/fallback
|
||||
}
|
||||
|
||||
// Fall back to environment variable
|
||||
|
|
|
|||
|
|
@ -517,16 +517,18 @@ export class ModelRegistry {
|
|||
|
||||
/**
|
||||
* Get API key for a model.
|
||||
* @param sessionId - Optional session ID for sticky credential selection
|
||||
*/
|
||||
async getApiKey(model: Model<Api>): Promise<string | undefined> {
|
||||
return this.authStorage.getApiKey(model.provider);
|
||||
async getApiKey(model: Model<Api>, sessionId?: string): Promise<string | undefined> {
|
||||
return this.authStorage.getApiKey(model.provider, sessionId);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get API key for a provider.
|
||||
* @param sessionId - Optional session ID for sticky credential selection
|
||||
*/
|
||||
async getApiKeyForProvider(provider: string): Promise<string | undefined> {
|
||||
return this.authStorage.getApiKey(provider);
|
||||
async getApiKeyForProvider(provider: string, sessionId?: string): Promise<string | undefined> {
|
||||
return this.authStorage.getApiKey(provider, sessionId);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue