feat(01-03): insert STEP 2 capability scoring into resolveModelForComplexity

- Add unitType and taskMetadata optional params to resolveModelForComplexity
- Replace findModelForTier with getEligibleModels for multi-model eligible set
- Insert STEP 2 scoring block: activates when capability_routing enabled, eligible.length > 1, unitType provided
- Add buildFallbackChain helper to deduplicate fallback assembly logic
- Scoring returns capability-scored selectionMethod with capabilityScores and taskRequirements
- Single-model and zero-model paths fall through to tier-only behavior
- All 42 existing tests pass unchanged (backward compat via optional params)
This commit is contained in:
Jeremy 2026-03-26 17:19:55 -05:00
parent bf918d30d5
commit accee43563

View file

@ -278,16 +278,34 @@ export function getEligibleModels(
});
}
/**
 * Build the ordered list of models to try after the selected model:
 * the configured fallbacks followed by the configured primary.
 *
 * The selected model itself is never part of the result, and every remaining
 * model ID appears at most once, at the position of its first occurrence.
 * This matters when the primary is also listed among the fallbacks, or when a
 * fallback is listed twice: the chain would otherwise retry the same model.
 *
 * @param selectedModelId The model that was chosen and must not be repeated
 * @param phaseConfig Supplies the `fallbacks` list and the `primary` model
 * @returns Deduplicated fallback model IDs, in first-seen order
 */
function buildFallbackChain(selectedModelId: string, phaseConfig: ResolvedModelConfig): string[] {
  // A Set iterates in insertion order, so it dedupes while preserving order.
  const chain = new Set<string>([...phaseConfig.fallbacks, phaseConfig.primary]);
  chain.delete(selectedModelId);
  return [...chain];
}
/**
* Resolve the model to use for a given complexity tier.
*
* Downgrade-only: the returned model is always equal to or cheaper than
* the user's configured primary model. Never upgrades beyond configuration.
*
* @param classification The complexity classification result
* @param phaseConfig The user's configured model for this phase (ceiling)
* @param routingConfig Dynamic routing configuration
* @param availableModelIds List of available model IDs (from registry)
* STEP 1: Filter to eligible models for the requested tier.
* STEP 2: Capability scoring ranks eligible models by task-capability match
* when capability_routing is enabled and multiple eligible models exist.
* STEP 3: Fallback chain assembly.
*
* @param classification The complexity classification result
* @param phaseConfig The user's configured model for this phase (ceiling)
* @param routingConfig Dynamic routing configuration
* @param availableModelIds List of available model IDs (from registry)
* @param unitType The unit type for capability requirement computation (optional)
* @param taskMetadata Task metadata for refined requirement vectors (optional)
*/
export function resolveModelForComplexity(
classification: ClassificationResult,
@ -295,7 +313,7 @@ export function resolveModelForComplexity(
routingConfig: DynamicRoutingConfig,
availableModelIds: string[],
unitType?: string,
metadata?: { tags?: string[]; complexityKeywords?: string[]; fileCount?: number; estimatedLines?: number },
taskMetadata?: TaskMetadata,
): RoutingDecision {
// If no phase config or routing disabled, pass through
if (!phaseConfig || !routingConfig.enabled) {
@ -341,45 +359,48 @@ export function resolveModelForComplexity(
};
}
// Find the best model for the requested tier
const useCapabilityScoring = routingConfig.capability_routing && unitType;
// STEP 1: Get all eligible models for the requested tier
const eligible = getEligibleModels(requestedTier, availableModelIds, routingConfig);
let targetModelId: string | null;
let capabilityScores: Record<string, number> | undefined;
let taskRequirements: Partial<Record<string, number>> | undefined;
let selectionMethod: "tier-only" | "capability-scored" = "tier-only";
if (useCapabilityScoring) {
const result = findModelForTierWithCapability(
requestedTier, routingConfig, availableModelIds,
routingConfig.cross_provider !== false, unitType, metadata,
);
targetModelId = result.modelId;
capabilityScores = Object.keys(result.scores).length > 0 ? result.scores : undefined;
taskRequirements = Object.keys(result.requirements).length > 0 ? result.requirements : undefined;
selectionMethod = capabilityScores ? "capability-scored" : "tier-only";
} else {
targetModelId = findModelForTier(
requestedTier, routingConfig, availableModelIds,
routingConfig.cross_provider !== false,
);
}
if (!targetModelId) {
if (eligible.length === 0) {
// No suitable model found — use configured primary
return {
modelId: configuredPrimary,
fallbacks: phaseConfig.fallbacks,
tier: requestedTier,
wasDowngraded: false,
reason: `no ${requestedTier}-tier model available`,
selectionMethod,
selectionMethod: "tier-only",
};
}
const fallbacks = [
...phaseConfig.fallbacks.filter(f => f !== targetModelId),
configuredPrimary,
].filter(f => f !== targetModelId);
// STEP 2: Capability scoring (when enabled and multiple eligible models exist)
if (routingConfig.capability_routing !== false && eligible.length > 1 && unitType) {
const requirements = computeTaskRequirements(unitType, taskMetadata);
const scored = scoreEligibleModels(eligible, requirements);
const winner = scored[0];
if (winner) {
const capScores: Record<string, number> = {};
for (const s of scored) capScores[s.modelId] = s.score;
const fallbacks = buildFallbackChain(winner.modelId, phaseConfig);
return {
modelId: winner.modelId,
fallbacks,
tier: requestedTier,
wasDowngraded: true,
reason: `capability-scored: ${winner.modelId} (${winner.score.toFixed(1)}) for ${unitType}`,
capabilityScores: capScores,
taskRequirements: requirements,
selectionMethod: "capability-scored",
};
}
}
// STEP 3: Fallback — use first eligible model (cheapest in tier, or single eligible)
const targetModelId = eligible[0];
// Build fallback chain: [downgraded_model, ...configured_fallbacks, configured_primary]
const fallbacks = buildFallbackChain(targetModelId, phaseConfig);
return {
modelId: targetModelId,
@ -387,9 +408,7 @@ export function resolveModelForComplexity(
tier: requestedTier,
wasDowngraded: true,
reason: classification.reason,
selectionMethod,
capabilityScores,
taskRequirements,
selectionMethod: "tier-only",
};
}
@ -447,93 +466,6 @@ function isKnownModel(modelId: string): boolean {
return false;
}
/**
 * Pick a model ID for the given tier from the available models.
 *
 * Resolution order:
 *  1. The model pinned in `config.tier_models[tier]`, when that exact ID is available.
 *  2. The first available model whose ID matches the pinned model once any
 *     `provider/` prefix is stripped from both sides.
 *  3. The first available model whose `getModelTier` equals `tier`; when
 *     `crossProvider` is true the candidates are first ordered by ascending
 *     `getModelCost`, otherwise the order of `availableModelIds` is kept.
 *
 * @param tier The tier a model is wanted for
 * @param config Routing configuration; only `tier_models` is read here
 * @param availableModelIds IDs to choose from
 * @param crossProvider Whether tier candidates are ordered by cost
 * @returns The chosen model ID, or `null` when nothing matches
 */
function findModelForTier(
  tier: ComplexityTier,
  config: DynamicRoutingConfig,
  availableModelIds: string[],
  crossProvider: boolean,
): string | null {
  // Everything after the last "/" — the ID without its provider prefix.
  const stripProvider = (id: string): string => (id.includes("/") ? id.split("/").pop()! : id);

  const pinned = config.tier_models?.[tier];
  if (pinned) {
    // An exact ID match wins over a prefix-insensitive one.
    if (availableModelIds.includes(pinned)) {
      return pinned;
    }
    const pinnedBare = stripProvider(pinned);
    const sameBareName = availableModelIds.find(id => stripProvider(id) === pinnedBare);
    if (sameBareName) {
      return sameBareName;
    }
  }

  // No usable pin: fall back to whatever is available in the requested tier.
  const inTier = availableModelIds.filter(id => getModelTier(id) === tier);
  if (crossProvider) {
    inTier.sort((left, right) => getModelCost(left) - getModelCost(right));
  }
  return inTier[0] ?? null;
}
/**
 * Pick a model ID for the given tier, ranking tier candidates by how well
 * their capability profile scores against the task's requirements.
 *
 * Resolution order:
 *  1. The model pinned in `config.tier_models[tier]`: the exact ID when it is
 *     available, otherwise the first available ID with the same name once any
 *     `provider/` prefix is stripped. A pinned match is returned with empty
 *     `scores` and `requirements` because no scoring took place.
 *  2. Otherwise every available model whose `getModelTier` equals `tier` is
 *     scored with `scoreModel` against `computeTaskRequirements(unitType, metadata)`
 *     and the best-ranked candidate is returned.
 *
 * @param tier The tier a model is wanted for
 * @param config Routing configuration; only `tier_models` is read here
 * @param availableModelIds IDs to choose from
 * @param crossProvider Whether cost is used to break near-ties in score
 * @param unitType Passed to `computeTaskRequirements`
 * @param metadata Optional task metadata passed to `computeTaskRequirements`
 * @returns The chosen model ID (or `null` when no tier candidate exists),
 *          the per-candidate scores, and the requirement vector used
 */
function findModelForTierWithCapability(
  tier: ComplexityTier,
  config: DynamicRoutingConfig,
  availableModelIds: string[],
  crossProvider: boolean,
  unitType: string,
  metadata?: { tags?: string[]; complexityKeywords?: string[]; fileCount?: number; estimatedLines?: number },
): { modelId: string | null; scores: Record<string, number>; requirements: Partial<Record<string, number>> } {
  // Everything after the last "/" — the ID without its provider prefix.
  const stripProvider = (id: string): string => (id.includes("/") ? id.split("/").pop()! : id);

  const explicitModel = config.tier_models?.[tier];
  if (explicitModel) {
    // Prefer the exact ID. A single combined lookup would let an earlier
    // same-named model from a different provider shadow the pinned one.
    const bareExplicit = stripProvider(explicitModel);
    const match = availableModelIds.includes(explicitModel)
      ? explicitModel
      : availableModelIds.find(id => stripProvider(id) === bareExplicit);
    if (match) return { modelId: match, scores: {}, requirements: {} };
  }

  const requirements = computeTaskRequirements(unitType, metadata);
  const candidates = availableModelIds.filter(id => getModelTier(id) === tier);
  if (candidates.length === 0) return { modelId: null, scores: {}, requirements };

  const scores: Record<string, number> = {};
  for (const id of candidates) {
    // Profiles are looked up by the ID without its provider prefix.
    scores[id] = scoreModel(getModelProfile(stripProvider(id)), requirements);
  }

  candidates.sort((a, b) => {
    const scoreDiff = scores[b] - scores[a];
    // Score gaps of 2 or less count as a tie and fall through to the tie-breakers.
    if (Math.abs(scoreDiff) > 2) return scoreDiff;
    if (crossProvider) {
      const costDiff = getModelCost(a) - getModelCost(b);
      if (costDiff !== 0) return costDiff;
    }
    // Final tie-breaker: order by model ID.
    return a.localeCompare(b);
  });

  return { modelId: candidates[0], scores, requirements };
}
/**
 * Look up the capability profile for a model ID in `MODEL_CAPABILITY_PROFILES`.
 *
 * An entry keyed by exactly `bareId` is used when present. Otherwise the first
 * entry whose key contains `bareId`, or is contained in it, is used. When
 * nothing matches, a profile with every capability set to 50 is returned.
 *
 * @param bareId Model ID to look up (callers pass it without a provider prefix)
 * @returns The matching profile, or the all-50 default
 */
function getModelProfile(bareId: string): ModelCapabilities {
  const exactProfile = MODEL_CAPABILITY_PROFILES[bareId];
  if (exactProfile) {
    return exactProfile;
  }

  // Substring match in either direction, first hit in entry order wins.
  const partialHit = Object.entries(MODEL_CAPABILITY_PROFILES).find(
    ([knownId]) => bareId.includes(knownId) || knownId.includes(bareId),
  );
  if (partialHit) {
    return partialHit[1];
  }

  return {
    coding: 50,
    debugging: 50,
    research: 50,
    reasoning: 50,
    speed: 50,
    longContext: 50,
    instruction: 50,
  };
}
function getModelCost(modelId: string): number {
const bareId = modelId.includes("/") ? modelId.split("/").pop()! : modelId;