feat(01-03): insert STEP 2 capability scoring into resolveModelForComplexity
- Add unitType and taskMetadata optional params to resolveModelForComplexity
- Replace findModelForTier with getEligibleModels for multi-model eligible set
- Insert STEP 2 scoring block: activates when capability_routing is enabled, eligible.length > 1, and unitType is provided
- Add buildFallbackChain helper to deduplicate fallback assembly logic
- Scoring returns the capability-scored selectionMethod with capabilityScores and taskRequirements
- Single-model and zero-model paths fall through to tier-only behavior
- All 42 existing tests pass unchanged (backward compat via optional params)
This commit is contained in:
parent
bf918d30d5
commit
accee43563
1 changed files with 56 additions and 124 deletions
|
|
@ -278,16 +278,34 @@ export function getEligibleModels(
|
|||
});
|
||||
}
|
||||
|
||||
/**
 * Build the ordered list of fallback models to try after `selectedModelId`.
 *
 * The chain is the configured fallbacks followed by the configured primary,
 * i.e. `[...phaseConfig.fallbacks, phaseConfig.primary]`, with two adjustments:
 *  - the selected model itself is removed (callers report it separately as
 *    the decision's `modelId`), and
 *  - repeated IDs are collapsed to their first occurrence, so a primary that
 *    is also listed as a fallback is not tried twice.
 *
 * @param selectedModelId The model chosen for this unit; never appears in the result
 * @param phaseConfig The user's configured primary model and fallbacks for the phase
 * @returns Unique fallback model IDs in priority order (may be empty)
 */
function buildFallbackChain(selectedModelId: string, phaseConfig: ResolvedModelConfig): string[] {
  const candidates = [...phaseConfig.fallbacks, phaseConfig.primary];
  // A Set iterates in first-insertion order, which gives order-preserving dedup.
  return [...new Set(candidates)].filter(id => id !== selectedModelId);
}
|
||||
|
||||
/**
|
||||
* Resolve the model to use for a given complexity tier.
|
||||
*
|
||||
* Downgrade-only: the returned model is always equal to or cheaper than
|
||||
* the user's configured primary model. Never upgrades beyond configuration.
|
||||
*
|
||||
* @param classification The complexity classification result
|
||||
* @param phaseConfig The user's configured model for this phase (ceiling)
|
||||
* @param routingConfig Dynamic routing configuration
|
||||
* @param availableModelIds List of available model IDs (from registry)
|
||||
* STEP 1: Filter to eligible models for the requested tier.
|
||||
* STEP 2: Capability scoring — ranks eligible models by task-capability match
|
||||
* when capability_routing is enabled and multiple eligible models exist.
|
||||
* STEP 3: Fallback chain assembly.
|
||||
*
|
||||
* @param classification The complexity classification result
|
||||
* @param phaseConfig The user's configured model for this phase (ceiling)
|
||||
* @param routingConfig Dynamic routing configuration
|
||||
* @param availableModelIds List of available model IDs (from registry)
|
||||
* @param unitType The unit type for capability requirement computation (optional)
|
||||
* @param taskMetadata Task metadata for refined requirement vectors (optional)
|
||||
*/
|
||||
export function resolveModelForComplexity(
|
||||
classification: ClassificationResult,
|
||||
|
|
@ -295,7 +313,7 @@ export function resolveModelForComplexity(
|
|||
routingConfig: DynamicRoutingConfig,
|
||||
availableModelIds: string[],
|
||||
unitType?: string,
|
||||
metadata?: { tags?: string[]; complexityKeywords?: string[]; fileCount?: number; estimatedLines?: number },
|
||||
taskMetadata?: TaskMetadata,
|
||||
): RoutingDecision {
|
||||
// If no phase config or routing disabled, pass through
|
||||
if (!phaseConfig || !routingConfig.enabled) {
|
||||
|
|
@ -341,45 +359,48 @@ export function resolveModelForComplexity(
|
|||
};
|
||||
}
|
||||
|
||||
// Find the best model for the requested tier
|
||||
const useCapabilityScoring = routingConfig.capability_routing && unitType;
|
||||
// STEP 1: Get all eligible models for the requested tier
|
||||
const eligible = getEligibleModels(requestedTier, availableModelIds, routingConfig);
|
||||
|
||||
let targetModelId: string | null;
|
||||
let capabilityScores: Record<string, number> | undefined;
|
||||
let taskRequirements: Partial<Record<string, number>> | undefined;
|
||||
let selectionMethod: "tier-only" | "capability-scored" = "tier-only";
|
||||
|
||||
if (useCapabilityScoring) {
|
||||
const result = findModelForTierWithCapability(
|
||||
requestedTier, routingConfig, availableModelIds,
|
||||
routingConfig.cross_provider !== false, unitType, metadata,
|
||||
);
|
||||
targetModelId = result.modelId;
|
||||
capabilityScores = Object.keys(result.scores).length > 0 ? result.scores : undefined;
|
||||
taskRequirements = Object.keys(result.requirements).length > 0 ? result.requirements : undefined;
|
||||
selectionMethod = capabilityScores ? "capability-scored" : "tier-only";
|
||||
} else {
|
||||
targetModelId = findModelForTier(
|
||||
requestedTier, routingConfig, availableModelIds,
|
||||
routingConfig.cross_provider !== false,
|
||||
);
|
||||
}
|
||||
|
||||
if (!targetModelId) {
|
||||
if (eligible.length === 0) {
|
||||
// No suitable model found — use configured primary
|
||||
return {
|
||||
modelId: configuredPrimary,
|
||||
fallbacks: phaseConfig.fallbacks,
|
||||
tier: requestedTier,
|
||||
wasDowngraded: false,
|
||||
reason: `no ${requestedTier}-tier model available`,
|
||||
selectionMethod,
|
||||
selectionMethod: "tier-only",
|
||||
};
|
||||
}
|
||||
|
||||
const fallbacks = [
|
||||
...phaseConfig.fallbacks.filter(f => f !== targetModelId),
|
||||
configuredPrimary,
|
||||
].filter(f => f !== targetModelId);
|
||||
// STEP 2: Capability scoring (when enabled and multiple eligible models exist)
|
||||
if (routingConfig.capability_routing !== false && eligible.length > 1 && unitType) {
|
||||
const requirements = computeTaskRequirements(unitType, taskMetadata);
|
||||
const scored = scoreEligibleModels(eligible, requirements);
|
||||
const winner = scored[0];
|
||||
if (winner) {
|
||||
const capScores: Record<string, number> = {};
|
||||
for (const s of scored) capScores[s.modelId] = s.score;
|
||||
const fallbacks = buildFallbackChain(winner.modelId, phaseConfig);
|
||||
return {
|
||||
modelId: winner.modelId,
|
||||
fallbacks,
|
||||
tier: requestedTier,
|
||||
wasDowngraded: true,
|
||||
reason: `capability-scored: ${winner.modelId} (${winner.score.toFixed(1)}) for ${unitType}`,
|
||||
capabilityScores: capScores,
|
||||
taskRequirements: requirements,
|
||||
selectionMethod: "capability-scored",
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// STEP 3: Fallback — use first eligible model (cheapest in tier, or single eligible)
|
||||
const targetModelId = eligible[0];
|
||||
|
||||
// Build fallback chain: [downgraded_model, ...configured_fallbacks, configured_primary]
|
||||
const fallbacks = buildFallbackChain(targetModelId, phaseConfig);
|
||||
|
||||
return {
|
||||
modelId: targetModelId,
|
||||
|
|
@ -387,9 +408,7 @@ export function resolveModelForComplexity(
|
|||
tier: requestedTier,
|
||||
wasDowngraded: true,
|
||||
reason: classification.reason,
|
||||
selectionMethod,
|
||||
capabilityScores,
|
||||
taskRequirements,
|
||||
selectionMethod: "tier-only",
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -447,93 +466,6 @@ function isKnownModel(modelId: string): boolean {
|
|||
return false;
|
||||
}
|
||||
|
||||
/**
 * Pick a single model ID for the given tier, or `null` when none is available.
 *
 * Resolution order:
 *  1. The model configured in `config.tier_models[tier]`, matched against
 *     `availableModelIds` first by exact ID and then by bare ID (the text
 *     after the last "/").
 *  2. Otherwise, the available models whose `getModelTier` equals `tier`:
 *     the one with the lowest `getModelCost` when `crossProvider` is true,
 *     or the first in availability order when it is false.
 *
 * @param tier The complexity tier to find a model for
 * @param config Routing configuration; only `tier_models` is read here
 * @param availableModelIds Candidate model IDs, possibly provider-prefixed
 * @param crossProvider Whether to rank auto-detected candidates by cost
 * @returns The chosen model ID as it appears in `availableModelIds`, or `null`
 */
function findModelForTier(
  tier: ComplexityTier,
  config: DynamicRoutingConfig,
  availableModelIds: string[],
  crossProvider: boolean,
): string | null {
  // Drop a leading "provider/" (or any path-like prefix) from a model ID.
  const stripProvider = (id: string): string => (id.includes("/") ? id.split("/").pop()! : id);

  // 1. Explicit tier_models entry: exact match wins, then bare-ID match.
  const pinned = config.tier_models?.[tier];
  if (pinned) {
    if (availableModelIds.includes(pinned)) return pinned;

    const pinnedBare = stripProvider(pinned);
    const bareMatch = availableModelIds.find(id => stripProvider(id) === pinnedBare);
    if (bareMatch) return bareMatch;
  }

  // 2. Auto-detect among available models classified into the requested tier.
  const inTier = availableModelIds.filter(id => getModelTier(id) === tier);
  if (crossProvider) {
    // Only cost-rank when cross-provider selection is allowed; otherwise the
    // availability order is kept as-is.
    inTier.sort((left, right) => getModelCost(left) - getModelCost(right));
  }

  return inTier[0] ?? null;
}
|
||||
|
||||
/**
 * Pick a model for the given tier, ranking candidates by capability score.
 *
 * If `config.tier_models[tier]` names a model that is available (by exact or
 * bare ID), it is returned immediately with empty `scores` and `requirements`;
 * no scoring is performed for an explicitly configured model.
 *
 * Otherwise every available model whose `getModelTier` equals `tier` is scored
 * with `scoreModel` against the requirements from `computeTaskRequirements`.
 * Candidates are ordered by score (higher first) when scores differ by more
 * than 2; within that tolerance the lower `getModelCost` wins when
 * `crossProvider` is true, and remaining ties are broken by `localeCompare`
 * on the model ID.
 *
 * NOTE(review): the 2-point tolerance means the comparator is not guaranteed
 * to be a consistent ordering (A~B and B~C do not imply A~C) — confirm the
 * resulting order is acceptable for all score distributions.
 *
 * @param tier The complexity tier to find a model for
 * @param config Routing configuration; only `tier_models` is read here
 * @param availableModelIds Candidate model IDs, possibly provider-prefixed
 * @param crossProvider Whether cost may be used as a tie-breaker
 * @param unitType Unit type passed to `computeTaskRequirements`
 * @param metadata Optional task metadata passed to `computeTaskRequirements`
 * @returns The chosen model (or `null` if the tier has no candidates), the
 *          per-candidate scores keyed by available model ID, and the
 *          requirements used for scoring
 */
function findModelForTierWithCapability(
  tier: ComplexityTier,
  config: DynamicRoutingConfig,
  availableModelIds: string[],
  crossProvider: boolean,
  unitType: string,
  metadata?: { tags?: string[]; complexityKeywords?: string[]; fileCount?: number; estimatedLines?: number },
): { modelId: string | null; scores: Record<string, number>; requirements: Partial<Record<string, number>> } {
  // Drop a leading "provider/" (or any path-like prefix) from a model ID.
  const bareOf = (id: string): string => (id.includes("/") ? id.split("/").pop()! : id);

  // Explicit configuration short-circuits scoring entirely.
  const pinned = config.tier_models?.[tier];
  if (pinned) {
    const pinnedBare = bareOf(pinned);
    const hit = availableModelIds.find(id => bareOf(id) === pinnedBare || id === pinned);
    if (hit) return { modelId: hit, scores: {}, requirements: {} };
  }

  const requirements = computeTaskRequirements(unitType, metadata);
  const inTier = availableModelIds.filter(id => getModelTier(id) === tier);
  if (inTier.length === 0) return { modelId: null, scores: {}, requirements };

  // Profiles are looked up by bare ID, but scores are keyed by the full ID.
  const scores: Record<string, number> = {};
  for (const id of inTier) {
    scores[id] = scoreModel(getModelProfile(bareOf(id)), requirements);
  }

  inTier.sort((a, b) => {
    const gap = scores[b] - scores[a];
    // A clear score lead (more than 2 points) decides the order outright.
    if (Math.abs(gap) > 2) return gap;

    // Near-ties: prefer the cheaper model when cost comparison is allowed.
    if (crossProvider) {
      const costGap = getModelCost(a) - getModelCost(b);
      if (costGap !== 0) return costGap;
    }

    // Final tie-break on the model ID.
    return a.localeCompare(b);
  });

  return { modelId: inTier[0], scores, requirements };
}
|
||||
|
||||
/**
 * Look up the capability profile for a bare (provider-less) model ID.
 *
 * Lookup order:
 *  1. An exact own-key match in `MODEL_CAPABILITY_PROFILES`.
 *  2. The first known profile (in `Object.entries` order) whose ID contains
 *     `bareId` or is contained in it, so versioned or abbreviated IDs still
 *     resolve.
 *  3. A neutral profile with every capability set to 50.
 *
 * @param bareId Model ID without provider prefix
 * @returns The matching profile, or the neutral all-50 profile when the ID is
 *          empty or unknown
 */
function getModelProfile(bareId: string): ModelCapabilities {
  // Own-property check: IDs such as "constructor" or "toString" must not
  // resolve to members inherited from Object.prototype.
  const exact = Object.prototype.hasOwnProperty.call(MODEL_CAPABILITY_PROFILES, bareId)
    ? MODEL_CAPABILITY_PROFILES[bareId]
    : undefined;
  if (exact) return exact;

  // Every string includes "", so an empty ID would otherwise match the first
  // known profile; skip fuzzy matching for it.
  if (bareId !== "") {
    for (const [knownId, profile] of Object.entries(MODEL_CAPABILITY_PROFILES)) {
      if (bareId.includes(knownId) || knownId.includes(bareId)) return profile;
    }
  }

  // Unknown model: neutral profile.
  return { coding: 50, debugging: 50, research: 50, reasoning: 50, speed: 50, longContext: 50, instruction: 50 };
}
|
||||
|
||||
function getModelCost(modelId: string): number {
|
||||
const bareId = modelId.includes("/") ? modelId.split("/").pop()! : modelId;
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue