diff --git a/src/resources/extensions/gsd/model-router.ts b/src/resources/extensions/gsd/model-router.ts index 5436bd667..445ce6b73 100644 --- a/src/resources/extensions/gsd/model-router.ts +++ b/src/resources/extensions/gsd/model-router.ts @@ -278,16 +278,34 @@ export function getEligibleModels( }); } +/** + * Build a fallback chain for a selected model: [selectedModel, ...configuredFallbacks, configuredPrimary] + * Deduplicates entries while preserving order. + */ +function buildFallbackChain(selectedModelId: string, phaseConfig: ResolvedModelConfig): string[] { + return [ + ...phaseConfig.fallbacks.filter(f => f !== selectedModelId), + phaseConfig.primary, + ].filter(f => f !== selectedModelId); +} + /** * Resolve the model to use for a given complexity tier. * * Downgrade-only: the returned model is always equal to or cheaper than * the user's configured primary model. Never upgrades beyond configuration. * - * @param classification The complexity classification result - * @param phaseConfig The user's configured model for this phase (ceiling) - * @param routingConfig Dynamic routing configuration - * @param availableModelIds List of available model IDs (from registry) + * STEP 1: Filter to eligible models for the requested tier. + * STEP 2: Capability scoring — ranks eligible models by task-capability match + * when capability_routing is enabled and multiple eligible models exist. + * STEP 3: Fallback chain assembly. + * + * @param classification The complexity classification result + * @param phaseConfig The user's configured model for this phase (ceiling) + * @param routingConfig Dynamic routing configuration + * @param availableModelIds List of available model IDs (from registry) + * @param unitType The unit type for capability requirement computation (optional) + * @param taskMetadata Task metadata for refined requirement vectors (optional) */ export function resolveModelForComplexity( classification: ClassificationResult, @@ -295,7 +313,7 @@ export function resolveModelForComplexity( routingConfig: DynamicRoutingConfig, availableModelIds: string[], unitType?: string, - metadata?: { tags?: string[]; complexityKeywords?: string[]; fileCount?: number; estimatedLines?: number }, + taskMetadata?: TaskMetadata, ): RoutingDecision { // If no phase config or routing disabled, pass through if (!phaseConfig || !routingConfig.enabled) { @@ -341,45 +359,48 @@ export function resolveModelForComplexity( }; } - // Find the best model for the requested tier - const useCapabilityScoring = routingConfig.capability_routing && unitType; + // STEP 1: Get all eligible models for the requested tier + const eligible = getEligibleModels(requestedTier, availableModelIds, routingConfig); - let targetModelId: string | null; - let capabilityScores: Record | undefined; - let taskRequirements: Partial> | undefined; - let selectionMethod: "tier-only" | "capability-scored" = "tier-only"; - - if (useCapabilityScoring) { - const result = findModelForTierWithCapability( - requestedTier, routingConfig, availableModelIds, - routingConfig.cross_provider !== false, unitType, metadata, - ); - targetModelId = result.modelId; - capabilityScores = Object.keys(result.scores).length > 0 ? result.scores : undefined; - taskRequirements = Object.keys(result.requirements).length > 0 ? result.requirements : undefined; - selectionMethod = capabilityScores ? "capability-scored" : "tier-only"; - } else { - targetModelId = findModelForTier( - requestedTier, routingConfig, availableModelIds, - routingConfig.cross_provider !== false, - ); - } - - if (!targetModelId) { + if (eligible.length === 0) { + // No suitable model found — use configured primary return { modelId: configuredPrimary, fallbacks: phaseConfig.fallbacks, tier: requestedTier, wasDowngraded: false, reason: `no ${requestedTier}-tier model available`, - selectionMethod, + selectionMethod: "tier-only", }; } - const fallbacks = [ - ...phaseConfig.fallbacks.filter(f => f !== targetModelId), - configuredPrimary, - ].filter(f => f !== targetModelId); + // STEP 2: Capability scoring (when enabled and multiple eligible models exist) + if (routingConfig.capability_routing !== false && eligible.length > 1 && unitType) { + const requirements = computeTaskRequirements(unitType, taskMetadata); + const scored = scoreEligibleModels(eligible, requirements); + const winner = scored[0]; + if (winner) { + const capScores: Record = {}; + for (const s of scored) capScores[s.modelId] = s.score; + const fallbacks = buildFallbackChain(winner.modelId, phaseConfig); + return { + modelId: winner.modelId, + fallbacks, + tier: requestedTier, + wasDowngraded: true, + reason: `capability-scored: ${winner.modelId} (${winner.score.toFixed(1)}) for ${unitType}`, + capabilityScores: capScores, + taskRequirements: requirements, + selectionMethod: "capability-scored", + }; + } + } + + // STEP 3: Fallback — use first eligible model (cheapest in tier, or single eligible) + const targetModelId = eligible[0]; + + // Build fallback chain: [downgraded_model, ...configured_fallbacks, configured_primary] + const fallbacks = buildFallbackChain(targetModelId, phaseConfig); return { modelId: targetModelId, @@ -387,9 +408,7 @@ export function resolveModelForComplexity( tier: requestedTier, wasDowngraded: true, reason: classification.reason, - selectionMethod, - capabilityScores, - taskRequirements, + selectionMethod: "tier-only", }; } @@ -447,93 +466,6 @@ function isKnownModel(modelId: string): boolean { return false; } -function findModelForTier( - tier: ComplexityTier, - config: DynamicRoutingConfig, - availableModelIds: string[], - crossProvider: boolean, -): string | null { - // 1. Check explicit tier_models config - const explicitModel = config.tier_models?.[tier]; - if (explicitModel && availableModelIds.includes(explicitModel)) { - return explicitModel; - } - // Also check with provider prefix stripped - if (explicitModel) { - const match = availableModelIds.find(id => { - const bareAvail = id.includes("/") ? id.split("/").pop()! : id; - const bareExplicit = explicitModel.includes("/") ? explicitModel.split("/").pop()! : explicitModel; - return bareAvail === bareExplicit; - }); - if (match) return match; - } - - // 2. Auto-detect: find the cheapest available model in the requested tier - const candidates = availableModelIds - .filter(id => { - const modelTier = getModelTier(id); - return modelTier === tier; - }) - .sort((a, b) => { - if (!crossProvider) return 0; - const costA = getModelCost(a); - const costB = getModelCost(b); - return costA - costB; - }); - - return candidates[0] ?? null; -} - -function findModelForTierWithCapability( - tier: ComplexityTier, - config: DynamicRoutingConfig, - availableModelIds: string[], - crossProvider: boolean, - unitType: string, - metadata?: { tags?: string[]; complexityKeywords?: string[]; fileCount?: number; estimatedLines?: number }, -): { modelId: string | null; scores: Record; requirements: Partial> } { - const explicitModel = config.tier_models?.[tier]; - if (explicitModel) { - const match = availableModelIds.find(id => { - const bareAvail = id.includes("/") ? id.split("/").pop()! : id; - const bareExplicit = explicitModel.includes("/") ? explicitModel.split("/").pop()! : explicitModel; - return bareAvail === bareExplicit || id === explicitModel; - }); - if (match) return { modelId: match, scores: {}, requirements: {} }; - } - - const requirements = computeTaskRequirements(unitType, metadata); - const candidates = availableModelIds.filter(id => getModelTier(id) === tier); - if (candidates.length === 0) return { modelId: null, scores: {}, requirements }; - - const scores: Record = {}; - for (const id of candidates) { - const bareId = id.includes("/") ? id.split("/").pop()! : id; - const profile = getModelProfile(bareId); - scores[id] = scoreModel(profile, requirements); - } - - candidates.sort((a, b) => { - const scoreDiff = scores[b] - scores[a]; - if (Math.abs(scoreDiff) > 2) return scoreDiff; - if (crossProvider) { - const costDiff = getModelCost(a) - getModelCost(b); - if (costDiff !== 0) return costDiff; - } - return a.localeCompare(b); - }); - - return { modelId: candidates[0], scores, requirements }; -} - -function getModelProfile(bareId: string): ModelCapabilities { - if (MODEL_CAPABILITY_PROFILES[bareId]) return MODEL_CAPABILITY_PROFILES[bareId]; - for (const [knownId, profile] of Object.entries(MODEL_CAPABILITY_PROFILES)) { - if (bareId.includes(knownId) || knownId.includes(bareId)) return profile; - } - return { coding: 50, debugging: 50, research: 50, reasoning: 50, speed: 50, longContext: 50, instruction: 50 }; -} - function getModelCost(modelId: string): number { const bareId = modelId.includes("/") ? modelId.split("/").pop()! : modelId;