diff --git a/src/check-usage.ts b/src/check-usage.ts index 1236ebc69..04142fe20 100644 --- a/src/check-usage.ts +++ b/src/check-usage.ts @@ -2,6 +2,7 @@ import { defineCommand } from "citty" import consola from "consola" import { ensurePaths } from "./lib/paths" +import { state } from "./lib/state" import { setupGitHubToken } from "./lib/token" import { getCopilotUsage, @@ -15,9 +16,20 @@ export const checkUsage = defineCommand({ }, async run() { await ensurePaths() - await setupGitHubToken() + const githubToken = await setupGitHubToken() try { - const usage = await getCopilotUsage() + const usage = await getCopilotUsage({ + account: { + name: "_check-usage", + accountType: state.accountType, + githubToken, + copilotTokenRefreshAt: 0, + inFlight: 0, + lastUsedAt: 0, + failureCount: 0, + }, + vsCodeVersion: state.vsCodeVersion, + }) const premium = usage.quota_snapshots.premium_interactions const premiumTotal = premium.entitlement const premiumUsed = premiumTotal - premium.remaining diff --git a/src/lib/api-config.ts b/src/lib/api-config.ts index 83bce92ad..24895dbaa 100644 --- a/src/lib/api-config.ts +++ b/src/lib/api-config.ts @@ -1,6 +1,6 @@ import { randomUUID } from "node:crypto" -import type { State } from "./state" +import type { Account } from "./account-pool" export const standardHeaders = () => ({ "content-type": "application/json", @@ -13,16 +13,22 @@ const USER_AGENT = `GitHubCopilotChat/${COPILOT_VERSION}` const API_VERSION = "2025-04-01" -export const copilotBaseUrl = (state: State) => - state.accountType === "individual" ? +export interface ApiContext { + account: Account + vsCodeVersion?: string +} + +export const copilotBaseUrl = (ctx: ApiContext) => + ctx.account.accountType === "individual" ? "https://api.githubcopilot.com" - : `https://api.${state.accountType}.githubcopilot.com` -export const copilotHeaders = (state: State, vision: boolean = false) => { + : `https://api.${ctx.account.accountType}.githubcopilot.com` + +export const copilotHeaders = (ctx: ApiContext, vision: boolean = false) => { const headers: Record = { - Authorization: `Bearer ${state.copilotToken}`, + Authorization: `Bearer ${ctx.account.copilotToken}`, "content-type": standardHeaders()["content-type"], "copilot-integration-id": "vscode-chat", - "editor-version": `vscode/${state.vsCodeVersion}`, + "editor-version": `vscode/${ctx.vsCodeVersion}`, "editor-plugin-version": EDITOR_PLUGIN_VERSION, "user-agent": USER_AGENT, "openai-intent": "conversation-panel", @@ -37,10 +43,10 @@ export const copilotHeaders = (state: State, vision: boolean = false) => { } export const GITHUB_API_BASE_URL = "https://api.github.com" -export const githubHeaders = (state: State) => ({ +export const githubHeaders = (ctx: ApiContext) => ({ ...standardHeaders(), - authorization: `token ${state.githubToken}`, - "editor-version": `vscode/${state.vsCodeVersion}`, + authorization: `token ${ctx.account.githubToken}`, + "editor-version": `vscode/${ctx.vsCodeVersion}`, "editor-plugin-version": EDITOR_PLUGIN_VERSION, "user-agent": USER_AGENT, "x-github-api-version": API_VERSION, diff --git a/src/lib/state.ts b/src/lib/state.ts index 7c9c45537..72bb388de 100644 --- a/src/lib/state.ts +++ b/src/lib/state.ts @@ -1,18 +1,11 @@ import type { ModelsResponse } from "~/services/copilot/get-models" -import type { Account, Strategy } from "./account-pool" -import type { AccountPool } from "./account-pool" +import type { Account, AccountPool, Strategy } from "./account-pool" export interface State { - // Multi-account pool. Until task 03 wires service code through it, - // legacy fields below mirror the "default" account. pool?: AccountPool strategy: Strategy - // Legacy fields (deprecated; will be removed in task 03): - githubToken?: string - copilotToken?: string - accountType: string models?: ModelsResponse vsCodeVersion?: string @@ -34,7 +27,7 @@ export const state: State = { showToken: false, } -/** Convenience: the first usable account, used by legacy single-account paths. */ +/** Convenience: the first usable account. */ export function defaultAccount(): Account | undefined { return state.pool?.accounts[0] } diff --git a/src/lib/token.ts b/src/lib/token.ts index cbebdc5df..4e3d8e9f2 100644 --- a/src/lib/token.ts +++ b/src/lib/token.ts @@ -11,31 +11,22 @@ import { pollAccessToken } from "~/services/github/poll-access-token" import { HTTPError } from "./error" import { state } from "./state" +import { makeApiContext } from "./utils" const readGithubToken = () => fs.readFile(PATHS.GITHUB_TOKEN_PATH, "utf8") const writeGithubToken = (token: string) => fs.writeFile(PATHS.GITHUB_TOKEN_PATH, token) -/** - * Set up the Copilot token for a single account, including auto-refresh. - * The previous global helper `setupCopilotToken` is replaced by per-account - * setup; legacy `state.copilotToken` is mirrored for not-yet-migrated callers. - */ +/** Per-account Copilot token setup with auto-refresh. */ export const setupCopilotTokenFor = async (account: Account) => { - // Temporarily expose this account's GitHub token for the legacy - // api-config helper which still reads `state.githubToken`. - state.githubToken = account.githubToken - const { token, refresh_in } = await getCopilotToken() + const ctx = makeApiContext(account) + const { token, refresh_in } = await getCopilotToken(ctx) /* eslint-disable require-atomic-updates */ account.copilotToken = token account.copilotTokenRefreshAt = Date.now() + refresh_in * 1000 /* eslint-enable require-atomic-updates */ - // Mirror the first account's token into legacy state for callers - // not yet migrated to the pool (removed in task 03). - state.copilotToken = token - consola.debug(`[${account.name}] Copilot token fetched successfully`) if (state.showToken) { consola.info(`[${account.name}] Copilot token:`, token) @@ -45,13 +36,11 @@ export const setupCopilotTokenFor = async (account: Account) => { account.refreshTimer = setInterval(async () => { consola.debug(`[${account.name}] Refreshing Copilot token`) try { - state.githubToken = account.githubToken - const refreshed = await getCopilotToken() + const refreshed = await getCopilotToken(makeApiContext(account)) /* eslint-disable require-atomic-updates */ account.copilotToken = refreshed.token account.copilotTokenRefreshAt = Date.now() + refreshed.refresh_in * 1000 /* eslint-enable require-atomic-updates */ - state.copilotToken = refreshed.token consola.debug(`[${account.name}] Copilot token refreshed`) if (state.showToken) { consola.info( @@ -70,20 +59,23 @@ interface SetupGitHubTokenOptions { force?: boolean } +/** + * Reads or fetches a single GitHub token file at PATHS.GITHUB_TOKEN_PATH. + * Returns the token; the caller is responsible for putting it into the + * account pool. + */ export async function setupGitHubToken( options?: SetupGitHubTokenOptions, -): Promise { +): Promise { try { const githubToken = await readGithubToken() if (githubToken && !options?.force) { - state.githubToken = githubToken if (state.showToken) { consola.info("GitHub token:", githubToken) } - await logUser() - - return + await logUser(githubToken) + return githubToken } consola.info("Not logged in, getting new access token") @@ -96,12 +88,12 @@ export async function setupGitHubToken( const token = await pollAccessToken(response) await writeGithubToken(token) - state.githubToken = token if (state.showToken) { consola.info("GitHub token:", token) } - await logUser() + await logUser(token) + return token } catch (error) { if (error instanceof HTTPError) { consola.error("Failed to get GitHub token:", await error.response.json()) @@ -113,16 +105,21 @@ export async function setupGitHubToken( } } -/** Backwards-compat wrapper: sets up Copilot token for the default account. */ -export const setupCopilotToken = async () => { - if (state.pool && state.pool.accounts.length > 0) { - await setupCopilotTokenFor(state.pool.accounts[0]) - return +async function logUser(githubToken: string) { + // Build a temporary "anonymous" account with just the GitHub token, + // so we can call /user without going through the pool. + const tempAccount: Account = { + name: "_setup", + accountType: state.accountType, + githubToken, + copilotTokenRefreshAt: 0, + inFlight: 0, + lastUsedAt: 0, + failureCount: 0, } - // No pool yet (very early callers) — do nothing. -} - -async function logUser() { - const user = await getGitHubUser() + const user = await getGitHubUser({ + account: tempAccount, + vsCodeVersion: state.vsCodeVersion, + }) consola.info(`Logged in as ${user.login}`) } diff --git a/src/lib/utils.ts b/src/lib/utils.ts index cc80be667..1ce123380 100644 --- a/src/lib/utils.ts +++ b/src/lib/utils.ts @@ -1,7 +1,12 @@ import consola from "consola" +import type { Context } from "hono" + +import type { Account } from "~/lib/account-pool" +import type { ApiContext } from "~/lib/api-config" import { getModels } from "~/services/copilot/get-models" import { getVSCodeVersion } from "~/services/get-vscode-version" +import type { Model } from "~/services/copilot/get-models" import { state } from "./state" @@ -13,8 +18,166 @@ export const sleep = (ms: number) => export const isNullish = (value: unknown): value is null | undefined => value === null || value === undefined +export function normalizeClaudeModelVersion(model: string): string { + if (!model.startsWith("claude-")) { + return model + } + + // Convert numeric segments from hyphen to dot, e.g. claude-opus-4-6 -> claude-opus-4.6. + // Only replace when the next numeric token ends at '-' or end, so suffixes like '-1m' stay unchanged. + return model.replace(/(\d)-(?=\d(?:-|$))/g, "$1.") +} + +/** + * Resolve model ID by checking the anthropic-beta header for context window variants. + */ +export function resolveModelId(model: string, c?: Context): string { + const normalized = normalizeClaudeModelVersion(model) + + if (!c) { + return normalized + } + + const betaHeader = c.req.header("anthropic-beta") + if ( + normalized.startsWith("claude-") + && betaHeader + && /\bcontext-1m\b/.test(betaHeader) + ) { + if (normalized.endsWith("-1m")) { + return normalized + } + return `${normalized}-1m` + } + + return normalized +} + +/** + * Calculate Jaccard similarity between two strings based on character bigrams. + */ +export function jaccardSimilarity(str1: string, str2: string): number { + const getBigrams = (str: string): Set => { + const bigrams = new Set() + const normalized = str.toLowerCase().replace(/[^a-z0-9]/g, "") + for (let i = 0; i < normalized.length - 1; i++) { + bigrams.add(normalized.substring(i, i + 2)) + } + return bigrams + } + + const bigrams1 = getBigrams(str1) + const bigrams2 = getBigrams(str2) + + if (bigrams1.size === 0 && bigrams2.size === 0) { + return 1 + } + + let intersection = 0 + for (const bigram of bigrams1) { + if (bigrams2.has(bigram)) { + intersection++ + } + } + + const union = bigrams1.size + bigrams2.size - intersection + return union === 0 ? 0 : intersection / union +} + +function findBestModelMatch( + modelId: string, + models: Array, + minSimilarity = 0.3, +): Model | null { + if (models.length === 0) { + return null + } + + let bestMatch: Model | null = null + let bestScore = 0 + + for (const model of models) { + const score = jaccardSimilarity(modelId, model.id) + if (score > bestScore) { + bestScore = score + bestMatch = model + } + } + + if (bestScore >= minSimilarity && bestMatch) { + consola.info( + `Fuzzy matched model "${modelId}" to "${bestMatch.id}" (similarity: ${bestScore.toFixed(2)})`, + ) + return bestMatch + } + + return null +} + +/** + * Resolve a requested model ID against available Copilot models. + * Order: exact -> fuzzy -> auto-version fallback -> first available. + */ +export function mapModelIdToAvailableModels( + requestedModelId: string, + models: Array, +): string { + if (models.length === 0) { + return requestedModelId + } + + const exact = models.find((m) => m.id === requestedModelId) + if (exact) { + return exact.id + } + + const fuzzy = findBestModelMatch(requestedModelId, models) + if (fuzzy) { + return fuzzy.id + } + + const autoModel = models.find((m) => m.id === "auto") + const autoVersionModel = models.find((m) => m.version === autoModel?.version) + if (autoVersionModel) { + consola.info( + `Model "${requestedModelId}" not found, using ${autoVersionModel.id} model`, + ) + return autoVersionModel.id + } + + const fallback = models[0] + consola.info( + `Model "${requestedModelId}" not found, using first available model: ${fallback.id}`, + ) + return fallback.id +} + +/** + * Resolve model ID from request metadata, then map to an available server model. + */ +export function resolveAndMapModelId( + model: string, + c?: Context, + models: Array = state.models?.data ?? [], +): string { + const resolved = resolveModelId(model, c) + return mapModelIdToAvailableModels(resolved, models) +} + +export function makeApiContext(account: Account): ApiContext { + return { account, vsCodeVersion: state.vsCodeVersion } +} + +/** Returns an ApiContext for the first available pool account. */ +export function defaultApiContext(): ApiContext { + if (!state.pool || state.pool.accounts.length === 0) { + throw new Error("Account pool is empty; cannot build ApiContext") + } + return makeApiContext(state.pool.accounts[0]) +} + export async function cacheModels(): Promise { - const models = await getModels() + const models = await getModels(defaultApiContext()) state.models = models } diff --git a/src/routes/chat-completions/handler.ts b/src/routes/chat-completions/handler.ts index 04a5ae9ed..e9efa686f 100644 --- a/src/routes/chat-completions/handler.ts +++ b/src/routes/chat-completions/handler.ts @@ -7,7 +7,11 @@ import { awaitApproval } from "~/lib/approval" import { checkRateLimit } from "~/lib/rate-limit" import { state } from "~/lib/state" import { getTokenCount } from "~/lib/tokenizer" -import { isNullish } from "~/lib/utils" +import { + isNullish, + makeApiContext, + resolveAndMapModelId, +} from "~/lib/utils" import { createChatCompletions, type ChatCompletionResponse, @@ -18,6 +22,10 @@ export async function handleCompletion(c: Context) { await checkRateLimit(state) let payload = await c.req.json() + payload = { + ...payload, + model: resolveAndMapModelId(payload.model, c, state.models?.data ?? []), + } consola.debug("Request payload:", JSON.stringify(payload).slice(-400)) // Find the selected model @@ -47,7 +55,14 @@ export async function handleCompletion(c: Context) { consola.debug("Set max_tokens to:", JSON.stringify(payload.max_tokens)) } - const response = await createChatCompletions(payload) + if (!state.pool) throw new Error("Account pool not initialized") + const account = state.pool.acquire() + let response: Awaited> + try { + response = await createChatCompletions(makeApiContext(account), payload) + } finally { + state.pool.release(account) + } if (isNonStreaming(response)) { consola.debug("Non-streaming response:", JSON.stringify(response)) diff --git a/src/routes/embeddings/route.ts b/src/routes/embeddings/route.ts index 4c4fc7b8a..478bea493 100644 --- a/src/routes/embeddings/route.ts +++ b/src/routes/embeddings/route.ts @@ -1,6 +1,8 @@ import { Hono } from "hono" import { forwardError } from "~/lib/error" +import { state } from "~/lib/state" +import { makeApiContext } from "~/lib/utils" import { createEmbeddings, type EmbeddingRequest, @@ -11,9 +13,14 @@ export const embeddingRoutes = new Hono() embeddingRoutes.post("/", async (c) => { try { const paylod = await c.req.json() - const response = await createEmbeddings(paylod) - - return c.json(response) + if (!state.pool) throw new Error("Account pool not initialized") + const account = state.pool.acquire() + try { + const response = await createEmbeddings(makeApiContext(account), paylod) + return c.json(response) + } finally { + state.pool.release(account) + } } catch (error) { return await forwardError(c, error) } diff --git a/src/routes/messages/count-tokens-handler.ts b/src/routes/messages/count-tokens-handler.ts index 2ec849cb8..34588782f 100644 --- a/src/routes/messages/count-tokens-handler.ts +++ b/src/routes/messages/count-tokens-handler.ts @@ -4,6 +4,7 @@ import consola from "consola" import { state } from "~/lib/state" import { getTokenCount } from "~/lib/tokenizer" +import { resolveAndMapModelId } from "~/lib/utils" import { type AnthropicMessagesPayload } from "./anthropic-types" import { translateToOpenAI } from "./non-stream-translation" @@ -17,10 +18,18 @@ export async function handleCountTokens(c: Context) { const anthropicPayload = await c.req.json() - const openAIPayload = translateToOpenAI(anthropicPayload) + let openAIPayload = translateToOpenAI(anthropicPayload, c) + openAIPayload = { + ...openAIPayload, + model: resolveAndMapModelId( + openAIPayload.model, + undefined, + state.models?.data ?? [], + ), + } const selectedModel = state.models?.data.find( - (model) => model.id === anthropicPayload.model, + (model) => model.id === openAIPayload.model, ) if (!selectedModel) { diff --git a/src/routes/messages/handler.ts b/src/routes/messages/handler.ts index 85dbf6243..8ddf1955e 100644 --- a/src/routes/messages/handler.ts +++ b/src/routes/messages/handler.ts @@ -6,6 +6,7 @@ import { streamSSE } from "hono/streaming" import { awaitApproval } from "~/lib/approval" import { checkRateLimit } from "~/lib/rate-limit" import { state } from "~/lib/state" +import { makeApiContext, resolveAndMapModelId } from "~/lib/utils" import { createChatCompletions, type ChatCompletionChunk, @@ -28,7 +29,15 @@ export async function handleCompletion(c: Context) { const anthropicPayload = await c.req.json() consola.debug("Anthropic request payload:", JSON.stringify(anthropicPayload)) - const openAIPayload = translateToOpenAI(anthropicPayload) + let openAIPayload = translateToOpenAI(anthropicPayload, c) + openAIPayload = { + ...openAIPayload, + model: resolveAndMapModelId( + openAIPayload.model, + undefined, + state.models?.data ?? [], + ), + } consola.debug( "Translated OpenAI request payload:", JSON.stringify(openAIPayload), @@ -38,7 +47,14 @@ export async function handleCompletion(c: Context) { await awaitApproval() } - const response = await createChatCompletions(openAIPayload) + if (!state.pool) throw new Error("Account pool not initialized") + const account = state.pool.acquire() + let response: Awaited> + try { + response = await createChatCompletions(makeApiContext(account), openAIPayload) + } finally { + state.pool.release(account) + } if (isNonStreaming(response)) { consola.debug( diff --git a/src/routes/messages/non-stream-translation.ts b/src/routes/messages/non-stream-translation.ts index dc41e6382..0c64b7c96 100644 --- a/src/routes/messages/non-stream-translation.ts +++ b/src/routes/messages/non-stream-translation.ts @@ -1,3 +1,6 @@ +import type { Context } from "hono" + +import { resolveModelId } from "~/lib/utils" import { type ChatCompletionResponse, type ChatCompletionsPayload, @@ -28,9 +31,10 @@ import { mapOpenAIStopReasonToAnthropic } from "./utils" export function translateToOpenAI( payload: AnthropicMessagesPayload, + c?: Context, ): ChatCompletionsPayload { return { - model: translateModelName(payload.model), + model: resolveModelId(payload.model, c), messages: translateAnthropicMessagesToOpenAI( payload.messages, payload.system, @@ -46,16 +50,6 @@ export function translateToOpenAI( } } -function translateModelName(model: string): string { - // Subagent requests use a specific model number which Copilot doesn't support - if (model.startsWith("claude-sonnet-4-")) { - return model.replace(/^claude-sonnet-4-.*/, "claude-sonnet-4") - } else if (model.startsWith("claude-opus-")) { - return model.replace(/^claude-opus-4-.*/, "claude-opus-4") - } - return model -} - function translateAnthropicMessagesToOpenAI( anthropicMessages: Array, system: string | Array | undefined, diff --git a/src/routes/token/route.ts b/src/routes/token/route.ts index dd0456d9a..5e1acfd8f 100644 --- a/src/routes/token/route.ts +++ b/src/routes/token/route.ts @@ -1,13 +1,14 @@ import { Hono } from "hono" -import { state } from "~/lib/state" +import { defaultAccount } from "~/lib/state" export const tokenRoute = new Hono() tokenRoute.get("/", (c) => { try { + const account = defaultAccount() return c.json({ - token: state.copilotToken, + token: account?.copilotToken ?? null, }) } catch (error) { console.error("Error fetching token:", error) diff --git a/src/routes/usage/route.ts b/src/routes/usage/route.ts index 3e9473236..847a2f94e 100644 --- a/src/routes/usage/route.ts +++ b/src/routes/usage/route.ts @@ -1,12 +1,13 @@ import { Hono } from "hono" +import { defaultApiContext } from "~/lib/utils" import { getCopilotUsage } from "~/services/github/get-copilot-usage" export const usageRoute = new Hono() usageRoute.get("/", async (c) => { try { - const usage = await getCopilotUsage() + const usage = await getCopilotUsage(defaultApiContext()) return c.json(usage) } catch (error) { console.error("Error fetching Copilot usage:", error) diff --git a/src/services/copilot/create-chat-completions.ts b/src/services/copilot/create-chat-completions.ts index 8534151da..c3e031b9c 100644 --- a/src/services/copilot/create-chat-completions.ts +++ b/src/services/copilot/create-chat-completions.ts @@ -1,14 +1,16 @@ import consola from "consola" import { events } from "fetch-event-stream" +import type { ApiContext } from "~/lib/api-config" + import { copilotHeaders, copilotBaseUrl } from "~/lib/api-config" import { HTTPError } from "~/lib/error" -import { state } from "~/lib/state" export const createChatCompletions = async ( + ctx: ApiContext, payload: ChatCompletionsPayload, ) => { - if (!state.copilotToken) throw new Error("Copilot token not found") + if (!ctx.account.copilotToken) throw new Error("Copilot token not found") const enableVision = payload.messages.some( (x) => @@ -17,18 +19,16 @@ export const createChatCompletions = async ( ) // Agent/user check for X-Initiator header - // Determine if any message is from an agent ("assistant" or "tool") const isAgentCall = payload.messages.some((msg) => ["assistant", "tool"].includes(msg.role), ) - // Build headers and add X-Initiator const headers: Record = { - ...copilotHeaders(state, enableVision), + ...copilotHeaders(ctx, enableVision), "X-Initiator": isAgentCall ? "agent" : "user", } - const response = await fetch(`${copilotBaseUrl(state)}/chat/completions`, { + const response = await fetch(`${copilotBaseUrl(ctx)}/chat/completions`, { method: "POST", headers, body: JSON.stringify(payload), diff --git a/src/services/copilot/create-embeddings.ts b/src/services/copilot/create-embeddings.ts index f2ad5c233..5c29a804f 100644 --- a/src/services/copilot/create-embeddings.ts +++ b/src/services/copilot/create-embeddings.ts @@ -1,13 +1,17 @@ +import type { ApiContext } from "~/lib/api-config" + import { copilotHeaders, copilotBaseUrl } from "~/lib/api-config" import { HTTPError } from "~/lib/error" -import { state } from "~/lib/state" -export const createEmbeddings = async (payload: EmbeddingRequest) => { - if (!state.copilotToken) throw new Error("Copilot token not found") +export const createEmbeddings = async ( + ctx: ApiContext, + payload: EmbeddingRequest, +) => { + if (!ctx.account.copilotToken) throw new Error("Copilot token not found") - const response = await fetch(`${copilotBaseUrl(state)}/embeddings`, { + const response = await fetch(`${copilotBaseUrl(ctx)}/embeddings`, { method: "POST", - headers: copilotHeaders(state), + headers: copilotHeaders(ctx), body: JSON.stringify(payload), }) diff --git a/src/services/copilot/get-models.ts b/src/services/copilot/get-models.ts index 3cfa30af0..450e3ac35 100644 --- a/src/services/copilot/get-models.ts +++ b/src/services/copilot/get-models.ts @@ -1,10 +1,11 @@ +import type { ApiContext } from "~/lib/api-config" + import { copilotBaseUrl, copilotHeaders } from "~/lib/api-config" import { HTTPError } from "~/lib/error" -import { state } from "~/lib/state" -export const getModels = async () => { - const response = await fetch(`${copilotBaseUrl(state)}/models`, { - headers: copilotHeaders(state), +export const getModels = async (ctx: ApiContext) => { + const response = await fetch(`${copilotBaseUrl(ctx)}/models`, { + headers: copilotHeaders(ctx), }) if (!response.ok) throw new HTTPError("Failed to get models", response) diff --git a/src/services/github/get-copilot-token.ts b/src/services/github/get-copilot-token.ts index 98744bab1..423e4827b 100644 --- a/src/services/github/get-copilot-token.ts +++ b/src/services/github/get-copilot-token.ts @@ -1,12 +1,13 @@ +import type { ApiContext } from "~/lib/api-config" + import { GITHUB_API_BASE_URL, githubHeaders } from "~/lib/api-config" import { HTTPError } from "~/lib/error" -import { state } from "~/lib/state" -export const getCopilotToken = async () => { +export const getCopilotToken = async (ctx: ApiContext) => { const response = await fetch( `${GITHUB_API_BASE_URL}/copilot_internal/v2/token`, { - headers: githubHeaders(state), + headers: githubHeaders(ctx), }, ) @@ -15,7 +16,6 @@ export const getCopilotToken = async () => { return (await response.json()) as GetCopilotTokenResponse } -// Trimmed for the sake of simplicity interface GetCopilotTokenResponse { expires_at: number refresh_in: number diff --git a/src/services/github/get-copilot-usage.ts b/src/services/github/get-copilot-usage.ts index 6cdd8bc10..5c8e0bc30 100644 --- a/src/services/github/get-copilot-usage.ts +++ b/src/services/github/get-copilot-usage.ts @@ -1,10 +1,13 @@ +import type { ApiContext } from "~/lib/api-config" + import { GITHUB_API_BASE_URL, githubHeaders } from "~/lib/api-config" import { HTTPError } from "~/lib/error" -import { state } from "~/lib/state" -export const getCopilotUsage = async (): Promise => { +export const getCopilotUsage = async ( + ctx: ApiContext, +): Promise => { const response = await fetch(`${GITHUB_API_BASE_URL}/copilot_internal/user`, { - headers: githubHeaders(state), + headers: githubHeaders(ctx), }) if (!response.ok) { diff --git a/src/services/github/get-user.ts b/src/services/github/get-user.ts index 23e1b1c1c..534e8e325 100644 --- a/src/services/github/get-user.ts +++ b/src/services/github/get-user.ts @@ -1,11 +1,12 @@ +import type { ApiContext } from "~/lib/api-config" + import { GITHUB_API_BASE_URL, standardHeaders } from "~/lib/api-config" import { HTTPError } from "~/lib/error" -import { state } from "~/lib/state" -export async function getGitHubUser() { +export async function getGitHubUser(ctx: ApiContext) { const response = await fetch(`${GITHUB_API_BASE_URL}/user`, { headers: { - authorization: `token ${state.githubToken}`, + authorization: `token ${ctx.account.githubToken}`, ...standardHeaders(), }, }) @@ -15,7 +16,6 @@ export async function getGitHubUser() { return (await response.json()) as GithubUserResponse } -// Trimmed for the sake of simplicity interface GithubUserResponse { login: string } diff --git a/src/start.ts b/src/start.ts index ead59acd4..423f51337 100644 --- a/src/start.ts +++ b/src/start.ts @@ -61,8 +61,7 @@ export async function runServer(options: RunServerOptions): Promise { // Resolve legacy single token if no accounts file is provided. let legacyToken = options.githubToken if (!options.accountsFile && !legacyToken) { - await setupGitHubToken() - legacyToken = state.githubToken + legacyToken = await setupGitHubToken() } else if (legacyToken) { consola.info("Using provided GitHub token") } @@ -80,7 +79,7 @@ export async function runServer(options: RunServerOptions): Promise { } const pool = new AccountPool(loaded, options.strategy) - // eslint-disable-next-line require-atomic-updates + state.pool = pool persistAccounts(loaded) diff --git a/tests/create-chat-completions.test.ts b/tests/create-chat-completions.test.ts index d18e741aa..59d62b976 100644 --- a/tests/create-chat-completions.test.ts +++ b/tests/create-chat-completions.test.ts @@ -1,16 +1,22 @@ import { test, expect, mock } from "bun:test" +import type { Account } from "../src/lib/account-pool" import type { ChatCompletionsPayload } from "../src/services/copilot/create-chat-completions" -import { state } from "../src/lib/state" import { createChatCompletions } from "../src/services/copilot/create-chat-completions" -// Mock state -state.copilotToken = "test-token" -state.vsCodeVersion = "1.0.0" -state.accountType = "individual" +const account: Account = { + name: "test", + accountType: "individual", + githubToken: "ghu_test", + copilotToken: "test-token", + copilotTokenRefreshAt: 0, + inFlight: 0, + lastUsedAt: 0, + failureCount: 0, +} +const ctx = { account, vsCodeVersion: "1.0.0" } -// Helper to mock fetch const fetchMock = mock( (_url: string, opts: { headers: Record }) => { return { @@ -31,7 +37,7 @@ test("sets X-Initiator to agent if tool/assistant present", async () => { ], model: "gpt-test", } - await createChatCompletions(payload) + await createChatCompletions(ctx, payload) expect(fetchMock).toHaveBeenCalled() const headers = ( fetchMock.mock.calls[0][1] as { headers: Record } @@ -47,7 +53,7 @@ test("sets X-Initiator to user if only user present", async () => { ], model: "gpt-test", } - await createChatCompletions(payload) + await createChatCompletions(ctx, payload) expect(fetchMock).toHaveBeenCalled() const headers = ( fetchMock.mock.calls[1][1] as { headers: Record } diff --git a/tests/model-mapping.test.ts b/tests/model-mapping.test.ts new file mode 100644 index 000000000..94bc9993c --- /dev/null +++ b/tests/model-mapping.test.ts @@ -0,0 +1,122 @@ +import { describe, expect, test } from "bun:test" +import type { Context } from "hono" +import type { Model } from "~/services/copilot/get-models" + +import { + jaccardSimilarity, + mapModelIdToAvailableModels, + normalizeClaudeModelVersion, + resolveModelId, +} from "../src/lib/utils" + +function makeContext(anthropicBeta?: string): Context { + return { + req: { + header: (name: string) => + name.toLowerCase() === "anthropic-beta" ? anthropicBeta : undefined, + }, + } as unknown as Context +} + +describe("model mapping", () => { + test("normalizes Claude numeric segments from hyphen to dot", () => { + expect(normalizeClaudeModelVersion("claude-opus-4-6")).toBe( + "claude-opus-4.6", + ) + expect(normalizeClaudeModelVersion("claude-3-5-sonnet-20241022")).toBe( + "claude-3.5-sonnet-20241022", + ) + }) + + test("does not change non-Claude models", () => { + expect(normalizeClaudeModelVersion("gpt-4.1")).toBe("gpt-4.1") + }) + + test("appends -1m when anthropic-beta has context-1m", () => { + const c = makeContext("foo,context-1m-2025-08-07,bar") + expect(resolveModelId("claude-opus-4-6", c)).toBe("claude-opus-4.6-1m") + }) + + test("does not append -1m twice", () => { + const c = makeContext("context-1m-2025-08-07") + expect(resolveModelId("claude-opus-4.6-1m", c)).toBe("claude-opus-4.6-1m") + }) + + test("keeps normalized model when context-1m is absent", () => { + const c = makeContext("claude-code-2025-02-19") + expect(resolveModelId("claude-opus-4-6", c)).toBe("claude-opus-4.6") + }) + + test("calculates Jaccard similarity for fuzzy matching", () => { + expect(jaccardSimilarity("claude-opus-4.6", "claude-opus-4.6")).toBe(1) + expect(jaccardSimilarity("claude-opus-4.6", "gpt-4o")).toBeLessThan(0.3) + }) + + test("uses exact match before fuzzy matching", () => { + const models = makeModels([ + "claude-opus-4.6", + "claude-sonnet-4.5", + "auto", + ]) + expect(mapModelIdToAvailableModels("claude-opus-4.6", models)).toBe( + "claude-opus-4.6", + ) + }) + + test("uses fuzzy match when exact model is missing", () => { + const models = makeModels([ + "claude-opus-4.6", + "claude-sonnet-4.5", + "auto", + ]) + expect(mapModelIdToAvailableModels("claude-opus-4-6", models)).toBe( + "claude-opus-4.6", + ) + }) + + test("falls back to auto-version model when no fuzzy match", () => { + const models = makeModels([ + "claude-opus-4.6", + "auto", + "gpt-4o", + ]) + expect(mapModelIdToAvailableModels("nonexistent-model", models)).toBe( + "auto", + ) + }) + + test("falls back to first model when auto is unavailable", () => { + const models = makeModels(["claude-opus-4.6", "gpt-4o"]) + expect(mapModelIdToAvailableModels("unknown-model", models)).toBe( + "claude-opus-4.6", + ) + }) +}) + +function makeModel(id: string, version = "v1"): Model { + return { + id, + version, + name: id, + vendor: "copilot", + object: "model", + preview: false, + model_picker_enabled: true, + capabilities: { + family: id.includes("claude") ? "claude" : "other", + limits: {}, + object: "model_capabilities", + supports: {}, + tokenizer: "o200k_base", + type: "chat", + }, + } +} + +function makeModels(ids: Array): Array { + const versions: Record = { + auto: "v-auto", + "gpt-4o": "v-auto", + } + return ids.map((id) => makeModel(id, versions[id] ?? "v1")) +}