fix(pi-ai): correct Copilot context window and output token limits (#2118)
* fix(gsd extension): detect initialized projects in health widget Use .gsd presence plus project-state detection for the health widget so bootstrapped projects no longer appear as unloaded before metrics exist. * fix(gsd extension): detect initialized projects in health widget Use .gsd presence plus project-state detection for the health widget so bootstrapped projects no longer appear as unloaded before metrics exist. * fix(pi-ai): correct Copilot context window and output token limits - Remove github-copilot from 1M contextWindow override in generate-models.ts - Add runtime fetching of model limits from Copilot /models API - Apply fetched limits in modifyModels and refreshToken flows - Regenerate models.generated.ts with corrected values - Fix models.ts type constraints for providers not in MODELS Fixes #2115 * fix(pi-ai): address QA round 1 - Use strict type/bounds checks for API limit values (QA-R1-001/005) - Add caller-level try/catch in refreshToken for defense-in-depth (QA-R1-009) * fix(pi-coding-agent): refresh model registry after OAuth token refresh ModelRegistry.modifyModels() only ran at load time, so model limits fetched during token refresh were persisted to auth.json but never applied to the in-memory model objects. Users saw stale contextWindow values (e.g., 144K from models.dev instead of 200K from the Copilot API). Add credential change notification to AuthStorage: after a successful OAuth token refresh, listeners are notified via queueMicrotask. The ModelRegistry now registers a listener at construction that triggers a full model reload, picking up the new limits from modifyModels().
This commit is contained in:
parent
d97d0ad03c
commit
5ecf047553
6 changed files with 2757 additions and 1564 deletions
1543
packages/pi-ai/scripts/generate-models.ts
Normal file
1543
packages/pi-ai/scripts/generate-models.ts
Normal file
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
|
@ -12,12 +12,15 @@ for (const [provider, models] of Object.entries(MODELS)) {
|
|||
modelRegistry.set(provider, providerModels);
|
||||
}
|
||||
|
||||
/** Providers that have entries in the generated MODELS constant */
|
||||
type GeneratedProvider = keyof typeof MODELS & KnownProvider;
|
||||
|
||||
type ModelApi<
|
||||
TProvider extends KnownProvider,
|
||||
TProvider extends GeneratedProvider,
|
||||
TModelId extends keyof (typeof MODELS)[TProvider],
|
||||
> = (typeof MODELS)[TProvider][TModelId] extends { api: infer TApi } ? (TApi extends Api ? TApi : never) : never;
|
||||
|
||||
export function getModel<TProvider extends KnownProvider, TModelId extends keyof (typeof MODELS)[TProvider]>(
|
||||
export function getModel<TProvider extends GeneratedProvider, TModelId extends keyof (typeof MODELS)[TProvider]>(
|
||||
provider: TProvider,
|
||||
modelId: TModelId,
|
||||
): Model<ModelApi<TProvider, TModelId>> {
|
||||
|
|
@ -31,9 +34,9 @@ export function getProviders(): KnownProvider[] {
|
|||
|
||||
export function getModels<TProvider extends KnownProvider>(
|
||||
provider: TProvider,
|
||||
): Model<ModelApi<TProvider, keyof (typeof MODELS)[TProvider]>>[] {
|
||||
): Model<Api>[] {
|
||||
const models = modelRegistry.get(provider);
|
||||
return models ? (Array.from(models.values()) as Model<ModelApi<TProvider, keyof (typeof MODELS)[TProvider]>>[]) : [];
|
||||
return models ? (Array.from(models.values()) as Model<Api>[]) : [];
|
||||
}
|
||||
|
||||
export function calculateCost<TApi extends Api>(model: Model<TApi>, usage: Usage): Usage["cost"] {
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } fr
|
|||
|
||||
type CopilotCredentials = OAuthCredentials & {
|
||||
enterpriseUrl?: string;
|
||||
/** Model limits from the /models API, keyed by model ID */
|
||||
modelLimits?: Record<string, { contextWindow: number; maxTokens: number }>;
|
||||
};
|
||||
|
||||
const decode = (s: string) => atob(s);
|
||||
|
|
@ -305,6 +307,47 @@ async function enableAllGitHubCopilotModels(
|
|||
);
|
||||
}
|
||||
|
||||
async function fetchCopilotModelLimits(
|
||||
token: string,
|
||||
enterpriseDomain?: string,
|
||||
): Promise<Record<string, { contextWindow: number; maxTokens: number }>> {
|
||||
const baseUrl = getGitHubCopilotBaseUrl(token, enterpriseDomain);
|
||||
try {
|
||||
const response = await fetch(`${baseUrl}/models`, {
|
||||
headers: {
|
||||
Accept: "application/json",
|
||||
Authorization: `Bearer ${token}`,
|
||||
"X-GitHub-Api-Version": "2025-05-01",
|
||||
...COPILOT_HEADERS,
|
||||
},
|
||||
signal: AbortSignal.timeout(30_000),
|
||||
});
|
||||
if (!response.ok) return {};
|
||||
const data = (await response.json()) as {
|
||||
data?: Array<{
|
||||
id: string;
|
||||
capabilities?: {
|
||||
limits?: {
|
||||
max_context_window_tokens?: number;
|
||||
max_output_tokens?: number;
|
||||
};
|
||||
};
|
||||
}>;
|
||||
};
|
||||
const limits: Record<string, { contextWindow: number; maxTokens: number }> = {};
|
||||
for (const m of data.data || []) {
|
||||
const ctx = m.capabilities?.limits?.max_context_window_tokens;
|
||||
const out = m.capabilities?.limits?.max_output_tokens;
|
||||
if (typeof ctx === "number" && typeof out === "number" && ctx > 0 && out > 0 && Number.isFinite(ctx) && Number.isFinite(out)) {
|
||||
limits[m.id] = { contextWindow: ctx, maxTokens: out };
|
||||
}
|
||||
}
|
||||
return limits;
|
||||
} catch {
|
||||
return {};
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Login with GitHub Copilot OAuth (device code flow)
|
||||
*
|
||||
|
|
@ -351,6 +394,14 @@ export async function loginGitHubCopilot(options: {
|
|||
// Enable all models after successful login
|
||||
options.onProgress?.("Enabling models...");
|
||||
await enableAllGitHubCopilotModels(credentials.access, enterpriseDomain ?? undefined);
|
||||
|
||||
// Fetch real model limits from the Copilot API
|
||||
options.onProgress?.("Fetching model limits...");
|
||||
const modelLimits = await fetchCopilotModelLimits(credentials.access, enterpriseDomain ?? undefined);
|
||||
if (Object.keys(modelLimits).length > 0) {
|
||||
(credentials as CopilotCredentials).modelLimits = modelLimits;
|
||||
}
|
||||
|
||||
return credentials;
|
||||
}
|
||||
|
||||
|
|
@ -369,7 +420,16 @@ export const githubCopilotOAuthProvider: OAuthProviderInterface = {
|
|||
|
||||
async refreshToken(credentials: OAuthCredentials): Promise<OAuthCredentials> {
|
||||
const creds = credentials as CopilotCredentials;
|
||||
return refreshGitHubCopilotToken(creds.refresh, creds.enterpriseUrl);
|
||||
const refreshed = await refreshGitHubCopilotToken(creds.refresh, creds.enterpriseUrl);
|
||||
try {
|
||||
const modelLimits = await fetchCopilotModelLimits(refreshed.access, creds.enterpriseUrl);
|
||||
if (Object.keys(modelLimits).length > 0) {
|
||||
(refreshed as CopilotCredentials).modelLimits = modelLimits;
|
||||
}
|
||||
} catch {
|
||||
// Model limits fetch is best-effort; don't block token refresh
|
||||
}
|
||||
return refreshed;
|
||||
},
|
||||
|
||||
getApiKey(credentials: OAuthCredentials): string {
|
||||
|
|
@ -380,6 +440,18 @@ export const githubCopilotOAuthProvider: OAuthProviderInterface = {
|
|||
const creds = credentials as CopilotCredentials;
|
||||
const domain = creds.enterpriseUrl ? (normalizeDomain(creds.enterpriseUrl) ?? undefined) : undefined;
|
||||
const baseUrl = getGitHubCopilotBaseUrl(creds.access, domain);
|
||||
return models.map((m) => (m.provider === "github-copilot" ? { ...m, baseUrl } : m));
|
||||
const limits = creds.modelLimits;
|
||||
return models.map((m) => {
|
||||
if (m.provider !== "github-copilot") return m;
|
||||
const modelLimits = limits?.[m.id];
|
||||
return {
|
||||
...m,
|
||||
baseUrl,
|
||||
...(modelLimits && {
|
||||
contextWindow: modelLimits.contextWindow,
|
||||
maxTokens: modelLimits.maxTokens,
|
||||
}),
|
||||
};
|
||||
});
|
||||
},
|
||||
};
|
||||
|
|
|
|||
|
|
@ -202,6 +202,7 @@ export class AuthStorage {
|
|||
private fallbackResolver?: (provider: string) => string | undefined;
|
||||
private loadError: Error | null = null;
|
||||
private errors: Error[] = [];
|
||||
private credentialChangeListeners: Set<() => void> = new Set();
|
||||
|
||||
/**
|
||||
* Round-robin index per provider. Incremented on each call to getApiKey
|
||||
|
|
@ -263,6 +264,25 @@ export class AuthStorage {
|
|||
this.fallbackResolver = resolver;
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a callback to be notified when credentials change (e.g., after OAuth token refresh).
|
||||
* Returns a function to unregister the listener.
|
||||
*/
|
||||
onCredentialChange(listener: () => void): () => void {
|
||||
this.credentialChangeListeners.add(listener);
|
||||
return () => this.credentialChangeListeners.delete(listener);
|
||||
}
|
||||
|
||||
private notifyCredentialChange(): void {
|
||||
for (const listener of this.credentialChangeListeners) {
|
||||
try {
|
||||
listener();
|
||||
} catch {
|
||||
// Don't let listener errors break the refresh flow
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private recordError(error: unknown): void {
|
||||
const normalizedError = error instanceof Error ? error : new Error(String(error));
|
||||
this.errors.push(normalizedError);
|
||||
|
|
@ -667,6 +687,11 @@ export class AuthStorage {
|
|||
return { result: refreshed, next: JSON.stringify(merged, null, 2) };
|
||||
});
|
||||
|
||||
// Notify listeners after credential change (e.g., model registry refresh)
|
||||
if (result) {
|
||||
queueMicrotask(() => this.notifyCredentialChange());
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -243,6 +243,9 @@ export class ModelRegistry {
|
|||
return undefined;
|
||||
});
|
||||
|
||||
// Refresh models when credentials change (e.g., OAuth token refresh with new model limits)
|
||||
this.authStorage.onCredentialChange(() => this.refresh());
|
||||
|
||||
// Load models
|
||||
this.loadModels();
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue