diff --git a/.env.example b/.env.example
new file mode 100644
index 000000000..6e237e5f5
--- /dev/null
+++ b/.env.example
@@ -0,0 +1 @@
+GH_TOKEN=
diff --git a/.gitignore b/.gitignore
index 577a4f199..73ebed921 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,4 +11,6 @@ node_modules/
 .eslintcache
 
 # build output
-dist/
\ No newline at end of file
+dist/
+
+.env
diff --git a/Dockerfile b/Dockerfile
index 1265220ef..39be38106 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -21,5 +21,7 @@ HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
   CMD wget --spider -q http://localhost:4141/ || exit 1
 
 COPY entrypoint.sh /entrypoint.sh
-RUN chmod +x /entrypoint.sh
+COPY refresh-token /usr/local/bin/refresh-token
+RUN chmod +x /entrypoint.sh /usr/local/bin/refresh-token
+
 ENTRYPOINT ["/entrypoint.sh"]
diff --git a/bun.lock b/bun.lock
index 20e895e7f..9ece87578 100644
--- a/bun.lock
+++ b/bun.lock
@@ -1,5 +1,6 @@
 {
   "lockfileVersion": 1,
+  "configVersion": 0,
   "workspaces": {
     "": {
       "name": "copilot-api",
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 000000000..61d744e2d
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,12 @@
+services:
+  copilot-api:
+    container_name: copilot-api
+    build: .
+    ports:
+      - "4141:4141"
+    volumes:
+      - copilot-data:/root/.local/share/copilot-api
+    restart: unless-stopped
+
+volumes:
+  copilot-data:
diff --git a/entrypoint.sh b/entrypoint.sh
index dfe63c902..cc1c741ce 100644
--- a/entrypoint.sh
+++ b/entrypoint.sh
@@ -1,9 +1,13 @@
 #!/bin/sh
 
-if [ "$1" = "--auth" ]; then
-  # Run auth command
-  exec bun run dist/main.js auth
-else
-  # Default command
-  exec bun run dist/main.js start -g "$GH_TOKEN" "$@"
-fi
+case "$1" in
+  --auth)
+    exec bun run dist/main.js auth
+    ;;
+  refresh-token)
+    exec bun run dist/main.js refresh-token
+    ;;
+  *)
+    exec bun run dist/main.js start -g "$GH_TOKEN" "$@"
+    ;;
+esac
diff --git a/refresh-token b/refresh-token
new file mode 100644
index 000000000..20d14123f
--- /dev/null
+++ b/refresh-token
@@ -0,0 +1,5 @@
+#!/bin/sh
+set -e
+
+echo "Refreshing Copilot token..."
+bun run /app/dist/main.js refresh-token "$@"
diff --git a/src/lib/model-normalization.ts b/src/lib/model-normalization.ts
new file mode 100644
index 000000000..e3bf67d3f
--- /dev/null
+++ b/src/lib/model-normalization.ts
@@ -0,0 +1,44 @@
+// When configuring claude-opus-4-6[1M] in Claude Code, the [1M] suffix only
+// activates the client-side 1M context window. The actual model name sent in
+// requests is still claude-opus-4-6. So to enable 1M context support, we must
+// map claude-opus-4-6 to claude-opus-4.6-1m (not claude-opus-4.6).
+const modelAliases: Record<string, string> = {
+  "claude-opus-4-6[1M]": "claude-opus-4.6-1m",
+  "claude-opus-4-6": "claude-opus-4.6-1m",
+  "claude-sonnet-4-6": "claude-sonnet-4.6",
+  "claude-haiku-4-5": "claude-haiku-4.5",
+}
+
+const reverseModelAliases = new Map<string, Array<string>>()
+for (const [alias, canonical] of Object.entries(modelAliases)) {
+  const aliases = reverseModelAliases.get(canonical) ?? []
+  aliases.push(alias)
+  reverseModelAliases.set(canonical, aliases)
+}
+
+export function normalizeModelName(modelId: string): string {
+  return modelAliases[modelId] ?? modelId
+}
+
+export function getModelAliases(modelId: string): Array<string> {
+  return reverseModelAliases.get(modelId) ?? []
+}
+
+export function expandModelIdsWithAliases(
+  modelIds: Array<string>,
+): Array<string> {
+  const expandedModelIds: Array<string> = []
+  const seenModelIds = new Set<string>()
+
+  for (const modelId of modelIds) {
+    for (const variant of [modelId, ...getModelAliases(modelId)]) {
+      if (seenModelIds.has(variant)) {
+        continue
+      }
+      seenModelIds.add(variant)
+      expandedModelIds.push(variant)
+    }
+  }
+
+  return expandedModelIds
+}
diff --git a/src/lib/token.ts b/src/lib/token.ts
index fc8d2785f..4b6b49549 100644
--- a/src/lib/token.ts
+++ b/src/lib/token.ts
@@ -12,6 +12,35 @@ import { state } from "./state"
 
 const readGithubToken = () => fs.readFile(PATHS.GITHUB_TOKEN_PATH, "utf8")
 
+const REFRESH_COOLDOWN_MS = 60 * 60 * 1000 // 1 hour
+let lastRefreshAttempt = 0
+
+export async function refreshCopilotTokenOnError(): Promise<boolean> {
+  const now = Date.now()
+  if (now - lastRefreshAttempt < REFRESH_COOLDOWN_MS) {
+    consola.warn(
+      "Token refresh on error skipped: cooldown not elapsed (1 hour limit)",
+    )
+    return false
+  }
+
+  lastRefreshAttempt = now
+  consola.info("Attempting to refresh Copilot token due to request error")
+
+  try {
+    const { token } = await getCopilotToken()
+    state.copilotToken = token
+    consola.info("Copilot token refreshed successfully after error")
+    if (state.showToken) {
+      consola.info("Refreshed Copilot token:", token)
+    }
+    return true
+  } catch (error) {
+    consola.error("Failed to refresh Copilot token on error:", error)
+    return false
+  }
+}
+
 const writeGithubToken = (token: string) =>
   fs.writeFile(PATHS.GITHUB_TOKEN_PATH, token)
 
diff --git a/src/main.ts b/src/main.ts
index 4f6ca784b..afe174d54 100644
--- a/src/main.ts
+++ b/src/main.ts
@@ -5,6 +5,7 @@ import { defineCommand, runMain } from "citty"
 import { auth } from "./auth"
 import { checkUsage } from "./check-usage"
 import { debug } from "./debug"
+import { refreshToken } from "./refresh-token"
 import { start } from "./start"
 
 const main = defineCommand({
@@ -13,7 +14,13 @@ const main = defineCommand({
     description:
       "A wrapper around GitHub Copilot API to make it OpenAI compatible, making it usable for other tools.",
   },
-  subCommands: { auth, start, "check-usage": checkUsage, debug },
+  subCommands: {
+    auth,
+    start,
+    "check-usage": checkUsage,
+    debug,
+    "refresh-token": refreshToken,
+  },
 })
 
 await runMain(main)
diff --git a/src/refresh-token.ts b/src/refresh-token.ts
new file mode 100644
index 000000000..1a8a956ab
--- /dev/null
+++ b/src/refresh-token.ts
@@ -0,0 +1,37 @@
+import { defineCommand } from "citty"
+import consola from "consola"
+
+export const refreshToken = defineCommand({
+  meta: {
+    name: "refresh-token",
+    description: "Manually refresh the Copilot token via API",
+  },
+  args: {
+    port: {
+      type: "string",
+      alias: "p",
+      default: "4141",
+      description: "The port the server is running on",
+    },
+  },
+  async run({ args }) {
+    const port = args.port
+    const url = `http://localhost:${port}/token/refresh`
+
+    try {
+      const response = await fetch(url, { method: "POST" })
+      const data = await response.json()
+
+      if (response.ok && data.success) {
+        consola.success("Token refreshed successfully")
+      } else {
+        consola.error("Failed to refresh token:", data.error || "Unknown error")
+        process.exit(1)
+      }
+    } catch (error) {
+      consola.error("Failed to connect to server:", error)
+      consola.info("Make sure the server is running on port", port)
+      process.exit(1)
+    }
+  },
+})
diff --git a/src/routes/chat-completions/handler.ts b/src/routes/chat-completions/handler.ts
index 04a5ae9ed..265bd692b 100644
--- a/src/routes/chat-completions/handler.ts
+++ b/src/routes/chat-completions/handler.ts
@@ -1,3 +1,4 @@
+import type { ServerSentEventMessage } from "fetch-event-stream"
 import type { Context } from "hono"
 
 import consola from "consola"
@@ -9,7 +10,8 @@ import { state } from "~/lib/state"
 import { getTokenCount } from "~/lib/tokenizer"
 import { isNullish } from "~/lib/utils"
 import {
-  createChatCompletions,
+  createChatCompletionsStream,
+  type ChatCompletionChunk,
   type ChatCompletionResponse,
   type ChatCompletionsPayload,
 } from "~/services/copilot/create-chat-completions"
@@ -47,22 +49,177 @@ export async function handleCompletion(c: Context) {
     consola.debug("Set max_tokens to:", JSON.stringify(payload.max_tokens))
   }
 
-  const response = await createChatCompletions(payload)
+  // Record whether the client requested a streaming response
+  const clientWantsStream = payload.stream === true
 
-  if (isNonStreaming(response)) {
-    consola.debug("Non-streaming response:", JSON.stringify(response))
-    return c.json(response)
+  // Always stream internally, to avoid ECONNRESET when long requests time out
+  const response = await createChatCompletionsStream(payload)
+
+  // If the client asked for a streaming response, pass chunks straight through
+  if (clientWantsStream) {
+    consola.debug("Streaming response")
+    return streamSSE(c, async (stream) => {
+      for await (const chunk of response) {
+        consola.debug("Streaming chunk:", JSON.stringify(chunk))
+        await stream.writeSSE(chunk as SSEMessage)
+      }
+    })
   }
 
-  consola.debug("Streaming response")
-  return streamSSE(c, async (stream) => {
-    for await (const chunk of response) {
-      consola.debug("Streaming chunk:", JSON.stringify(chunk))
-      await stream.writeSSE(chunk as SSEMessage)
+  // Client asked for a non-streaming response: collect and merge the chunks
+  consola.debug("Collecting stream chunks for non-streaming response")
+  const nonStreamResponse = await collectStreamToResponse(response)
+  consola.debug("Non-streaming response:", JSON.stringify(nonStreamResponse))
+  return c.json(nonStreamResponse)
+}
+
+type FinishReason = "stop" | "length" | "tool_calls" | "content_filter"
+
+type ToolCallAccumulator = {
+  id: string
+  type: "function"
+  function: { name: string; arguments: string }
+}
+
+type StreamAccumulator = {
+  id: string
+  model: string
+  created: number
+  systemFingerprint?: string
+  finishReason: FinishReason
+  content: string
+  toolCalls: Map<number, ToolCallAccumulator>
+  usage?: ChatCompletionResponse["usage"]
+}
+
+/**
+ * Merge the chunks of a streaming response into a non-streaming response.
+ */
+async function collectStreamToResponse(
+  stream: AsyncGenerator<ServerSentEventMessage>,
+): Promise<ChatCompletionResponse> {
+  const accumulator = createAccumulator()
+
+  for await (const chunk of stream) {
+    if (!chunk.data) continue
+    if (chunk.data === "[DONE]") break
+
+    const parsed = parseChunkDataOrLog(chunk.data)
+    if (!parsed) continue
+    applyChunkToAccumulator(parsed, accumulator)
+  }
+
+  return buildResponse(accumulator)
+}
+
+function createAccumulator(): StreamAccumulator {
+  return {
+    id: "",
+    model: "",
+    created: 0,
+    finishReason: "stop",
+    content: "",
+    toolCalls: new Map(),
+  }
+}
+
+function parseChunkDataOrLog(data: unknown): ChatCompletionChunk | null {
+  if (typeof data !== "string") return null
+
+  try {
+    return JSON.parse(data) as ChatCompletionChunk
+  } catch (error) {
+    consola.debug("Failed to parse SSE chunk data", {
+      dataPreview: data.slice(0, 500),
+      error,
+    })
+    return null
+  }
+}
+
+function applyChunkToAccumulator(
+  parsed: ChatCompletionChunk,
+  accumulator: StreamAccumulator,
+) {
+  if (!accumulator.id && parsed.id) accumulator.id = parsed.id
+  if (!accumulator.model && parsed.model) accumulator.model = parsed.model
+  if (!accumulator.created && parsed.created)
+    accumulator.created = parsed.created
+  if (!accumulator.systemFingerprint && parsed.system_fingerprint) {
+    accumulator.systemFingerprint = parsed.system_fingerprint
+  }
+  if (parsed.usage) accumulator.usage = parsed.usage
+
+  const choice = parsed.choices.at(0)
+  if (!choice) return
+
+  if (choice.finish_reason) accumulator.finishReason = choice.finish_reason
+
+  if (typeof choice.delta.content === "string") {
+    accumulator.content += choice.delta.content
+  }
+
+  if (choice.delta.tool_calls) {
+    mergeToolCalls(accumulator.toolCalls, choice.delta.tool_calls)
+  }
+}
+
+function mergeToolCalls(
+  toolCalls: Map<number, ToolCallAccumulator>,
+  deltas: NonNullable<
+    ChatCompletionChunk["choices"][number]["delta"]["tool_calls"]
+  >,
+) {
+  for (const delta of deltas) {
+    const existing = toolCalls.get(delta.index)
+    if (!existing) {
+      toolCalls.set(delta.index, {
+        id: delta.id ?? "",
+        type: "function",
+        function: {
+          name: delta.function?.name ?? "",
+          arguments: delta.function?.arguments ?? "",
+        },
+      })
+      continue
+    }
+
+    if (!existing.id && delta.id) existing.id = delta.id
+    if (!existing.function.name && delta.function?.name) {
+      existing.function.name = delta.function.name
     }
-  })
+    if (delta.function?.arguments) {
+      existing.function.arguments += delta.function.arguments
+    }
+  }
 }
 
-const isNonStreaming = (
-  response: Awaited<ReturnType<typeof createChatCompletions>>,
-): response is ChatCompletionResponse => Object.hasOwn(response, "choices")
+function buildResponse(accumulator: StreamAccumulator): ChatCompletionResponse {
+  const response: ChatCompletionResponse = {
+    id: accumulator.id,
+    object: "chat.completion",
+    created: accumulator.created,
+    model: accumulator.model,
+    choices: [
+      {
+        index: 0,
+        message: {
+          role: "assistant",
+          content: accumulator.content || null,
+          ...(accumulator.toolCalls.size > 0 && {
+            tool_calls: Array.from(accumulator.toolCalls.values()),
+          }),
+        },
+        logprobs: null,
+        finish_reason: accumulator.finishReason,
+      },
+    ],
+  }
+
+  if (accumulator.systemFingerprint) {
+    response.system_fingerprint = accumulator.systemFingerprint
+  }
+  if (accumulator.usage) response.usage = accumulator.usage
+
+  return response
+}
diff --git a/src/routes/messages/count-tokens-handler.ts b/src/routes/messages/count-tokens-handler.ts
index 2ec849cb8..939cee317 100644
--- a/src/routes/messages/count-tokens-handler.ts
+++ b/src/routes/messages/count-tokens-handler.ts
@@ -20,7 +20,7 @@ export async function handleCountTokens(c: Context) {
     const openAIPayload = translateToOpenAI(anthropicPayload)
 
     const selectedModel = state.models?.data.find(
-      (model) => model.id === anthropicPayload.model,
+      (model) => model.id === openAIPayload.model,
     )
 
     if (!selectedModel) {
diff --git a/src/routes/messages/non-stream-translation.ts b/src/routes/messages/non-stream-translation.ts
index dc41e6382..dd22038f6 100644
--- a/src/routes/messages/non-stream-translation.ts
+++ b/src/routes/messages/non-stream-translation.ts
@@ -1,3 +1,4 @@
+import { normalizeModelName } from "~/lib/model-normalization"
 import {
   type ChatCompletionResponse,
   type ChatCompletionsPayload,
@@ -30,7 +31,7 @@ export function translateToOpenAI(
   payload: AnthropicMessagesPayload,
 ): ChatCompletionsPayload {
   return {
-    model: translateModelName(payload.model),
+    model: normalizeModelName(payload.model),
     messages: translateAnthropicMessagesToOpenAI(
       payload.messages,
       payload.system,
@@ -46,16 +47,6 @@ export function translateToOpenAI(
   }
 }
 
-function translateModelName(model: string): string {
-  // Subagent requests use a specific model number which Copilot doesn't support
-  if (model.startsWith("claude-sonnet-4-")) {
-    return model.replace(/^claude-sonnet-4-.*/, "claude-sonnet-4")
-  } else if (model.startsWith("claude-opus-")) {
-    return model.replace(/^claude-opus-4-.*/, "claude-opus-4")
-  }
-  return model
-}
-
 function translateAnthropicMessagesToOpenAI(
   anthropicMessages: Array<AnthropicMessage>,
   system: string | Array<AnthropicTextBlock> | undefined,
diff --git a/src/routes/models/route.ts b/src/routes/models/route.ts
index 5254e2af7..1a3b6e42e 100644
--- a/src/routes/models/route.ts
+++ b/src/routes/models/route.ts
@@ -1,6 +1,10 @@
 import { Hono } from "hono"
 
 import { forwardError } from "~/lib/error"
+import {
+  expandModelIdsWithAliases,
+  normalizeModelName,
+} from "~/lib/model-normalization"
 import { state } from "~/lib/state"
 import { cacheModels } from "~/lib/utils"
 
@@ -13,15 +17,26 @@ modelRoutes.get("/", async (c) => {
       await cacheModels()
     }
 
-    const models = state.models?.data.map((model) => ({
-      id: model.id,
-      object: "model",
-      type: "model",
-      created: 0, // No date available from source
-      created_at: new Date(0).toISOString(), // No date available from source
-      owned_by: model.vendor,
-      display_name: model.name,
-    }))
+    const modelById = new Map(
+      state.models?.data.map((model) => [model.id, model]),
+    )
+    const modelIds = expandModelIdsWithAliases(
+      state.models?.data.map((model) => model.id) ?? [],
+    )
+    const models = modelIds.flatMap((modelId) => {
+      const sourceModel = modelById.get(normalizeModelName(modelId))
+      if (!sourceModel) return []
+
+      return {
+        id: modelId,
+        object: "model",
+        type: "model",
+        created: 0, // No date available from source
+        created_at: new Date(0).toISOString(), // No date available from source
+        owned_by: sourceModel.vendor,
+        display_name: sourceModel.name,
+      }
+    })
 
     return c.json({
       object: "list",
diff --git a/src/routes/token/route.ts b/src/routes/token/route.ts
index dd0456d9a..1b1ae41f1 100644
--- a/src/routes/token/route.ts
+++ b/src/routes/token/route.ts
@@ -1,6 +1,7 @@
 import { Hono } from "hono"
 
 import { state } from "~/lib/state"
+import { getCopilotToken } from "~/services/github/get-copilot-token"
 
 export const tokenRoute = new Hono()
 
@@ -14,3 +15,18 @@ tokenRoute.get("/", (c) => {
     return c.json({ error: "Failed to fetch token", token: null }, 500)
   }
 })
+
+tokenRoute.post("/refresh", async (c) => {
+  try {
+    const { token } = await getCopilotToken()
+    state.copilotToken = token
+    console.log("Copilot token manually refreshed")
+    return c.json({
+      success: true,
+      message: "Token refreshed successfully",
+    })
+  } catch (error) {
+    console.error("Error refreshing token:", error)
+    return c.json({ error: "Failed to refresh token", success: false }, 500)
+  }
+})
diff --git a/src/services/copilot/create-chat-completions.ts b/src/services/copilot/create-chat-completions.ts
index 8534151da..e7cecf3a3 100644
--- a/src/services/copilot/create-chat-completions.ts
+++ b/src/services/copilot/create-chat-completions.ts
@@ -1,38 +1,62 @@
 import consola from "consola"
-import { events } from "fetch-event-stream"
+import { events, type ServerSentEventMessage } from "fetch-event-stream"
 
 import { copilotHeaders, copilotBaseUrl } from "~/lib/api-config"
 import { HTTPError } from "~/lib/error"
+import { normalizeModelName } from "~/lib/model-normalization"
 import { state } from "~/lib/state"
+import { refreshCopilotTokenOnError } from "~/lib/token"
 
-export const createChatCompletions = async (
+async function doFetch(
   payload: ChatCompletionsPayload,
-) => {
+  streamOverride?: boolean,
+) {
   if (!state.copilotToken) throw new Error("Copilot token not found")
 
-  const enableVision = payload.messages.some(
+  const normalizedPayload: ChatCompletionsPayload = {
+    ...payload,
+    model: normalizeModelName(payload.model),
+  }
+
+  const enableVision = normalizedPayload.messages.some(
    (x) =>
      typeof x.content !== "string"
      && x.content?.some((x) => x.type === "image_url"),
   )
 
-  // Agent/user check for X-Initiator header
-  // Determine if any message is from an agent ("assistant" or "tool")
-  const isAgentCall = payload.messages.some((msg) =>
+  const isAgentCall = normalizedPayload.messages.some((msg) =>
     ["assistant", "tool"].includes(msg.role),
   )
 
-  // Build headers and add X-Initiator
   const headers: Record<string, string> = {
     ...copilotHeaders(state, enableVision),
     "X-Initiator": isAgentCall ? "agent" : "user",
   }
 
-  const response = await fetch(`${copilotBaseUrl(state)}/chat/completions`, {
+  const body =
+    streamOverride !== undefined ?
+      { ...normalizedPayload, stream: streamOverride }
+    : normalizedPayload
+
+  return fetch(`${copilotBaseUrl(state)}/chat/completions`, {
     method: "POST",
     headers,
-    body: JSON.stringify(payload),
+    body: JSON.stringify(body),
   })
+}
+
+export const createChatCompletions = async (
+  payload: ChatCompletionsPayload,
+) => {
+  let response = await doFetch(payload)
+
+  if (response.status === 401) {
+    consola.warn("Got 401, attempting token refresh")
+    const refreshed = await refreshCopilotTokenOnError()
+    if (refreshed) {
+      response = await doFetch(payload)
+    }
+  }
 
   if (!response.ok) {
     consola.error("Failed to create chat completions", response)
@@ -46,6 +70,31 @@ export const createChatCompletions = async (
   return (await response.json()) as ChatCompletionResponse
 }
 
+/**
+ * Forced-streaming variant: the return type is always an AsyncGenerator.
+ * Used to avoid timeout problems with non-streaming requests.
+ */
+export const createChatCompletionsStream = async (
+  payload: Omit<ChatCompletionsPayload, "stream">,
+): Promise<AsyncGenerator<ServerSentEventMessage>> => {
+  let response = await doFetch(payload as ChatCompletionsPayload, true)
+
+  if (response.status === 401) {
+    consola.warn("Got 401, attempting token refresh")
+    const refreshed = await refreshCopilotTokenOnError()
+    if (refreshed) {
+      response = await doFetch(payload as ChatCompletionsPayload, true)
+    }
+  }
+
+  if (!response.ok) {
+    consola.error("Failed to create chat completions", response)
+    throw new HTTPError("Failed to create chat completions", response)
+  }
+
+  return events(response)
+}
+
 // Streaming types
 export interface ChatCompletionChunk {
diff --git a/src/start.ts b/src/start.ts
index 14abbbdff..9fe1a8caf 100644
--- a/src/start.ts
+++ b/src/start.ts
@@ -6,6 +6,7 @@ import consola from "consola"
 import { serve, type ServerHandler } from "srvx"
 import invariant from "tiny-invariant"
 
+import { expandModelIdsWithAliases } from "./lib/model-normalization"
 import { ensurePaths } from "./lib/paths"
 import { initProxyFromEnv } from "./lib/proxy"
 import { generateEnvScript } from "./lib/shell"
@@ -59,9 +60,12 @@ export async function runServer(options: RunServerOptions): Promise<void> {
   await setupCopilotToken()
   await cacheModels()
 
+  const availableModelIds = expandModelIdsWithAliases(
+    state.models?.data.map((model) => model.id) ?? [],
+  )
   consola.info(
-    `Available models: \n${state.models?.data.map((model) => `- ${model.id}`).join("\n")}`,
+    `Available models: \n${availableModelIds.map((modelId) => `- ${modelId}`).join("\n")}`,
   )
 
   const serverUrl = `http://localhost:${options.port}`
@@ -73,7 +77,7 @@ export async function runServer(options: RunServerOptions): Promise<void> {
     "Select a model to use with Claude Code",
     {
       type: "select",
-      options: state.models.data.map((model) => model.id),
+      options: availableModelIds,
     },
   )
 
@@ -81,7 +85,7 @@ export async function runServer(options: RunServerOptions): Promise<void> {
     "Select a small model to use with Claude Code",
     {
       type: "select",
-      options: state.models.data.map((model) => model.id),
+      options: availableModelIds,
     },
   )