diff --git a/.github/workflows/release-docker.yml b/.github/workflows/release-docker.yml index 97b124ba1..ad47bbbb6 100644 --- a/.github/workflows/release-docker.yml +++ b/.github/workflows/release-docker.yml @@ -7,8 +7,9 @@ name: Docker Build and Push on: push: - # branches: [ "main" ] - # Publish semver tags as releases. + # 对任意分支的 push 均触发(包含默认分支) + branches: [ '**' ] + # 保留:当推送语义化 tag 时也会触发 tags: [ 'v*.*.*' ] paths-ignore: - 'docs/**' @@ -73,6 +74,8 @@ jobs: github-token: ${{ secrets.GITHUB_TOKEN }} images: ${{ env.REGISTRY }}/${{ github.repository }} tags: | + # 每次 push 都产出 latest 标签 + type=raw,value=latest type=semver,pattern=v{{version}} type=semver,pattern=v{{major}}.{{minor}} type=semver,pattern=v{{major}} @@ -88,4 +91,3 @@ jobs: tags: ${{ steps.meta.outputs.tags }} platforms: linux/amd64,linux/arm64 labels: ${{ steps.meta.outputs.labels }} - diff --git a/Dockerfile b/Dockerfile index 1265220ef..a9c0e3b69 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,12 +14,10 @@ COPY ./package.json ./bun.lock ./ RUN bun install --frozen-lockfile --production --ignore-scripts --no-cache COPY --from=builder /app/dist ./dist +COPY ./pages ./pages EXPOSE 4141 -HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \ - CMD wget --spider -q http://localhost:4141/ || exit 1 - COPY entrypoint.sh /entrypoint.sh RUN chmod +x /entrypoint.sh ENTRYPOINT ["/entrypoint.sh"] diff --git a/entrypoint.sh b/entrypoint.sh index dfe63c902..5659f836e 100644 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -3,7 +3,25 @@ if [ "$1" = "--auth" ]; then # Run auth command exec bun run dist/main.js auth else - # Default command - exec bun run dist/main.js start -g "$GH_TOKEN" "$@" + # Build start command with optional parameters + START_CMD="bun run dist/main.js start" + + # Add GitHub token if provided + if [ -n "$GH_TOKEN" ]; then + START_CMD="$START_CMD -g $GH_TOKEN" + fi + + # Add API key if provided + if [ -n "$API_KEY" ]; then + START_CMD="$START_CMD --api-key $API_KEY" + fi + + # Append any additional arguments + if [ $# -gt 0 ]; then + START_CMD="$START_CMD $@" + fi + + # Execute the command + exec $START_CMD fi diff --git a/src/lib/api-key-auth.ts b/src/lib/api-key-auth.ts new file mode 100644 index 000000000..b05c5ec73 --- /dev/null +++ b/src/lib/api-key-auth.ts @@ -0,0 +1,41 @@ +import type { Context, Next } from "hono" + +import { HTTPException } from "hono/http-exception" + +import { state } from "~/lib/state" + +/** + * API Key authentication middleware + * Validates API key from Authorization header (Bearer token) or X-API-Key header + * Throws 401 if API key is required but missing or invalid + */ +export async function apiKeyAuth(c: Context, next: Next): Promise { + // Skip authentication if no API key is configured + if (!state.apiKey) { + return next() + } + + // Extract API key from headers + const authHeader = c.req.header("Authorization") + const xApiKey = c.req.header("X-API-Key") + + let providedKey: string | undefined + + // Try Bearer token format first + if (authHeader?.startsWith("Bearer ")) { + providedKey = authHeader.slice(7) + } + // Fallback to X-API-Key header + else if (xApiKey) { + providedKey = xApiKey + } + + // Validate API key + if (!providedKey || providedKey !== state.apiKey) { + throw new HTTPException(401, { + message: "Invalid or missing API key", + }) + } + + await next() +} diff --git a/src/lib/billing-cycle.ts b/src/lib/billing-cycle.ts new file mode 100644 index 000000000..afcabc616 --- /dev/null +++ b/src/lib/billing-cycle.ts @@ -0,0 +1,121 @@ +/** + * Billing Cycle Manager + * + * Manages billing cycles to ensure only the first request in a cycle is billed. + * + * Rules: + * - First request in cycle: X-Initiator = user (billed) + * - Subsequent requests: X-Initiator = agent (not billed) + * - Cycle resets after 5 minutes of inactivity (no requests after last response) + * - Concurrent first requests: only ONE is billed + * - Failed requests: do not enter cycle, no billing + */ + +const CYCLE_TIMEOUT_MS = 5 * 60 * 1000 // 5 minutes + +class BillingCycleManager { + private inCycle: boolean = false + private lastResponseTime: number = 0 + private pendingResponses: number = 0 + private lock: Promise = Promise.resolve() + + /** + * Determines whether the current request should be billed. + * Thread-safe: handles concurrent requests correctly. + * + * @returns 'user' if request should be billed, 'agent' if not + */ + async determineInitiator(): Promise<"user" | "agent"> { + // Acquire lock for thread-safety + await this.acquireLock() + + try { + const now = Date.now() + + // Check if cycle has timed out (5 minutes since last response) + if ( + this.inCycle + && this.pendingResponses === 0 + && now - this.lastResponseTime > CYCLE_TIMEOUT_MS + ) { + this.inCycle = false + } + + // Determine billing + if (!this.inCycle) { + // First request in new cycle - bill it + this.inCycle = true + this.pendingResponses++ + return "user" + } + + // Already in cycle - don't bill + this.pendingResponses++ + return "agent" + } finally { + this.releaseLock() + } + } + + /** + * Mark a response as complete (for both streaming and non-streaming). + * Updates the last response timestamp. + */ + markResponseComplete(): void { + this.lastResponseTime = Date.now() + this.pendingResponses = Math.max(0, this.pendingResponses - 1) + } + + /** + * Mark a request as failed. + * Failed requests do not contribute to the billing cycle. + */ + markRequestFailed(): void { + this.pendingResponses = Math.max(0, this.pendingResponses - 1) + + // If this was the first request and it failed, exit the cycle + if (this.pendingResponses === 0 && this.lastResponseTime === 0) { + this.inCycle = false + } + } + + /** + * Get current cycle status (for debugging/monitoring) + */ + getStatus(): { + inCycle: boolean + lastResponseTime: number + pendingResponses: number + } { + return { + inCycle: this.inCycle, + lastResponseTime: this.lastResponseTime, + pendingResponses: this.pendingResponses, + } + } + + /** + * Reset the billing cycle (for testing purposes) + */ + reset(): void { + this.inCycle = false + this.lastResponseTime = 0 + this.pendingResponses = 0 + } + + // Simple async lock implementation + private async acquireLock(): Promise { + const currentLock = this.lock + let releaseLock!: () => void + this.lock = new Promise((resolve) => { + releaseLock = resolve + }) + await currentLock + this.releaseLock = releaseLock + } + + private releaseLock: () => void = () => {} +} + +// Singleton instance +export const billingCycleManager = new BillingCycleManager() diff --git a/src/lib/sse-heartbeat.ts b/src/lib/sse-heartbeat.ts new file mode 100644 index 000000000..4eb45d1c0 --- /dev/null +++ b/src/lib/sse-heartbeat.ts @@ -0,0 +1,128 @@ +import consola from "consola" + +/** + * SSE 心跳配置 + */ +export interface SSEHeartbeatConfig { + /** 心跳间隔 (毫秒) */ + interval: number + /** 是否启用心跳 */ + enabled: boolean + /** 最大连接时长 (毫秒), 0 表示无限制 */ + maxConnectionDuration: number +} + +/** + * 默认 SSE 心跳配置 + */ +export const DEFAULT_HEARTBEAT_CONFIG: SSEHeartbeatConfig = { + interval: Number(process.env.SSE_HEARTBEAT_INTERVAL) || 2000, // 默认 2 秒 + enabled: process.env.SSE_HEARTBEAT_ENABLED !== "false", // 默认启用 + maxConnectionDuration: + Number(process.env.SSE_MAX_CONNECTION_DURATION) || 600000, // 默认 10 分钟 +} + +/** + * SSE 心跳管理器 + * 用于在 SSE 流式响应中保持连接活跃 + */ +export class SSEHeartbeatManager { + private timer?: Timer + private heartbeatCount = 0 + private startTime = 0 + private readonly config: SSEHeartbeatConfig + private readonly requestId: string + + constructor(requestId: string, config: Partial = {}) { + this.requestId = requestId + this.config = { ...DEFAULT_HEARTBEAT_CONFIG, ...config } + } + + /** + * 启动心跳定时器 + * @param stream SSE 流对象 + * @param onMaxDuration 达到最大连接时长时的回调 + */ + start( + stream: { write: (data: string) => Promise }, + onMaxDuration?: () => void, + ): void { + if (!this.config.enabled) { + consola.debug(`[${this.requestId}] SSE heartbeat is disabled, skipping`) + return + } + + this.startTime = Date.now() + this.heartbeatCount = 0 + + consola.debug( + `[${this.requestId}] Starting SSE heartbeat with interval: ${this.config.interval}ms`, + ) + + this.timer = setInterval(async () => { + try { + this.heartbeatCount++ + await stream.write(": heartbeat\n\n") + + consola.debug( + `[${this.requestId}] Sent heartbeat #${this.heartbeatCount}`, + ) + + // 检查是否超过最大连接时长 + if ( + this.config.maxConnectionDuration > 0 + && Date.now() - this.startTime > this.config.maxConnectionDuration + ) { + consola.warn( + `[${this.requestId}] Connection exceeded max duration (${this.config.maxConnectionDuration}ms), closing`, + ) + this.stop() + onMaxDuration?.() + } + } catch (error) { + consola.warn( + `[${this.requestId}] Failed to send heartbeat #${this.heartbeatCount}:`, + error, + ) + } + }, this.config.interval) + } + + /** + * 停止心跳定时器 + */ + stop(): void { + if (this.timer) { + clearInterval(this.timer) + this.timer = undefined + + const duration = Date.now() - this.startTime + consola.debug( + `[${this.requestId}] Stopped SSE heartbeat after ${this.heartbeatCount} beats (${duration}ms)`, + ) + } + } + + /** + * 获取心跳统计信息 + */ + getStats() { + return { + heartbeatCount: this.heartbeatCount, + duration: Date.now() - this.startTime, + config: this.config, + } + } +} + +/** + * 创建 SSE 心跳管理器 + * @param requestId 请求 ID + * @param config 自定义配置 + */ +export function createHeartbeatManager( + requestId: string, + config?: Partial, +): SSEHeartbeatManager { + return new SSEHeartbeatManager(requestId, config) +} diff --git a/src/lib/state.ts b/src/lib/state.ts index 5ba4dc1d1..de1274970 100644 --- a/src/lib/state.ts +++ b/src/lib/state.ts @@ -12,6 +12,9 @@ export interface State { rateLimitWait: boolean showToken: boolean + // API Key authentication + apiKey?: string + // Rate limiting configuration rateLimitSeconds?: number lastRequestTimestamp?: number diff --git a/src/routes/chat-completions/handler.ts b/src/routes/chat-completions/handler.ts index 04a5ae9ed..bde3c38e9 100644 --- a/src/routes/chat-completions/handler.ts +++ b/src/routes/chat-completions/handler.ts @@ -4,7 +4,9 @@ import consola from "consola" import { streamSSE, type SSEMessage } from "hono/streaming" import { awaitApproval } from "~/lib/approval" +import { billingCycleManager } from "~/lib/billing-cycle" import { checkRateLimit } from "~/lib/rate-limit" +import { createHeartbeatManager } from "~/lib/sse-heartbeat" import { state } from "~/lib/state" import { getTokenCount } from "~/lib/tokenizer" import { isNullish } from "~/lib/utils" @@ -15,6 +17,7 @@ import { } from "~/services/copilot/create-chat-completions" export async function handleCompletion(c: Context) { + const requestId = crypto.randomUUID() await checkRateLimit(state) let payload = await c.req.json() @@ -56,9 +59,31 @@ export async function handleCompletion(c: Context) { consola.debug("Streaming response") return streamSSE(c, async (stream) => { - for await (const chunk of response) { - consola.debug("Streaming chunk:", JSON.stringify(chunk)) - await stream.writeSSE(chunk as SSEMessage) + const heartbeatManager = createHeartbeatManager(requestId) + + try { + heartbeatManager.start(stream, () => { + consola.warn(`[${requestId}] Force closing connection due to timeout`) + }) + + for await (const chunk of response) { + consola.debug("Streaming chunk:", JSON.stringify(chunk)) + await stream.writeSSE(chunk as SSEMessage) + } + // Mark response complete after all chunks are sent + billingCycleManager.markResponseComplete() + + const stats = heartbeatManager.getStats() + consola.info( + `[${requestId}] Stream completed - ${stats.heartbeatCount} heartbeats, ${stats.duration}ms`, + ) + } catch (error) { + // If streaming fails, mark request as failed + billingCycleManager.markRequestFailed() + consola.error(`[${requestId}] Streaming error:`, error) + throw error + } finally { + heartbeatManager.stop() } }) } diff --git a/src/routes/messages/handler.ts b/src/routes/messages/handler.ts index 85dbf6243..172d35cb5 100644 --- a/src/routes/messages/handler.ts +++ b/src/routes/messages/handler.ts @@ -4,7 +4,9 @@ import consola from "consola" import { streamSSE } from "hono/streaming" import { awaitApproval } from "~/lib/approval" +import { billingCycleManager } from "~/lib/billing-cycle" import { checkRateLimit } from "~/lib/rate-limit" +import { createHeartbeatManager } from "~/lib/sse-heartbeat" import { state } from "~/lib/state" import { createChatCompletions, @@ -23,6 +25,7 @@ import { import { translateChunkToAnthropicEvents } from "./stream-translation" export async function handleCompletion(c: Context) { + const requestId = crypto.randomUUID() await checkRateLimit(state) const anthropicPayload = await c.req.json() @@ -55,6 +58,7 @@ export async function handleCompletion(c: Context) { consola.debug("Streaming response from Copilot") return streamSSE(c, async (stream) => { + const heartbeatManager = createHeartbeatManager(requestId) const streamState: AnthropicStreamState = { messageStartSent: false, contentBlockIndex: 0, @@ -62,26 +66,46 @@ export async function handleCompletion(c: Context) { toolCalls: {}, } - for await (const rawEvent of response) { - consola.debug("Copilot raw stream event:", JSON.stringify(rawEvent)) - if (rawEvent.data === "[DONE]") { - break - } + try { + heartbeatManager.start(stream, () => { + consola.warn(`[${requestId}] Force closing connection due to timeout`) + }) - if (!rawEvent.data) { - continue - } + for await (const rawEvent of response) { + consola.debug("Copilot raw stream event:", JSON.stringify(rawEvent)) + if (rawEvent.data === "[DONE]") { + break + } - const chunk = JSON.parse(rawEvent.data) as ChatCompletionChunk - const events = translateChunkToAnthropicEvents(chunk, streamState) + if (!rawEvent.data) { + continue + } - for (const event of events) { - consola.debug("Translated Anthropic event:", JSON.stringify(event)) - await stream.writeSSE({ - event: event.type, - data: JSON.stringify(event), - }) + const chunk = JSON.parse(rawEvent.data) as ChatCompletionChunk + const events = translateChunkToAnthropicEvents(chunk, streamState) + + for (const event of events) { + consola.debug("Translated Anthropic event:", JSON.stringify(event)) + await stream.writeSSE({ + event: event.type, + data: JSON.stringify(event), + }) + } } + // Mark response complete after all chunks are sent + billingCycleManager.markResponseComplete() + + const stats = heartbeatManager.getStats() + consola.info( + `[${requestId}] Stream completed - ${stats.heartbeatCount} heartbeats, ${stats.duration}ms`, + ) + } catch (error) { + // If streaming fails, mark request as failed + billingCycleManager.markRequestFailed() + consola.error(`[${requestId}] Streaming error:`, error) + throw error + } finally { + heartbeatManager.stop() } }) } diff --git a/src/server.ts b/src/server.ts index 462a278f3..3da0d2d32 100644 --- a/src/server.ts +++ b/src/server.ts @@ -2,6 +2,7 @@ import { Hono } from "hono" import { cors } from "hono/cors" import { logger } from "hono/logger" +import { apiKeyAuth } from "./lib/api-key-auth" import { completionRoutes } from "./routes/chat-completions/route" import { embeddingRoutes } from "./routes/embeddings/route" import { messageRoutes } from "./routes/messages/route" @@ -14,12 +15,22 @@ export const server = new Hono() server.use(logger()) server.use(cors()) -server.get("/", (c) => c.text("Server running")) +// Public routes without API key authentication +// server.get("/", (c) => c.text("Server running")) +server.get("/", async (c) => { + const html = await Bun.file("pages/index.html").text() + return c.html(html) +}) + +// Usage endpoint is public +server.route("/usage", usageRoute) + +// Apply API key authentication to all other routes +server.use("*", apiKeyAuth) server.route("/chat/completions", completionRoutes) server.route("/models", modelRoutes) server.route("/embeddings", embeddingRoutes) -server.route("/usage", usageRoute) server.route("/token", tokenRoute) // Compatibility with tools that expect v1/ prefix diff --git a/src/services/copilot/create-chat-completions.ts b/src/services/copilot/create-chat-completions.ts index 8534151da..bc6e8bd20 100644 --- a/src/services/copilot/create-chat-completions.ts +++ b/src/services/copilot/create-chat-completions.ts @@ -2,6 +2,7 @@ import consola from "consola" import { events } from "fetch-event-stream" import { copilotHeaders, copilotBaseUrl } from "~/lib/api-config" +import { billingCycleManager } from "~/lib/billing-cycle" import { HTTPError } from "~/lib/error" import { state } from "~/lib/state" @@ -16,34 +17,43 @@ export const createChatCompletions = async ( && x.content?.some((x) => x.type === "image_url"), ) - // Agent/user check for X-Initiator header - // Determine if any message is from an agent ("assistant" or "tool") - const isAgentCall = payload.messages.some((msg) => - ["assistant", "tool"].includes(msg.role), - ) + // Determine X-Initiator based on billing cycle + const initiator = await billingCycleManager.determineInitiator() // Build headers and add X-Initiator const headers: Record = { ...copilotHeaders(state, enableVision), - "X-Initiator": isAgentCall ? "agent" : "user", + "X-Initiator": initiator, } - const response = await fetch(`${copilotBaseUrl(state)}/chat/completions`, { - method: "POST", - headers, - body: JSON.stringify(payload), - }) - - if (!response.ok) { - consola.error("Failed to create chat completions", response) - throw new HTTPError("Failed to create chat completions", response) + let response: Response + try { + response = await fetch(`${copilotBaseUrl(state)}/chat/completions`, { + method: "POST", + headers, + body: JSON.stringify(payload), + }) + + if (!response.ok) { + consola.error("Failed to create chat completions", response) + billingCycleManager.markRequestFailed() + throw new HTTPError("Failed to create chat completions", response) + } + } catch (error) { + billingCycleManager.markRequestFailed() + throw error } - if (payload.stream) { - return events(response) + // For non-streaming, mark as complete immediately after receiving response + if (!payload.stream) { + const result = (await response.json()) as ChatCompletionResponse + billingCycleManager.markResponseComplete() + return result } - return (await response.json()) as ChatCompletionResponse + // For streaming, return the event stream + // Note: The caller (route handler) must call markResponseComplete() after stream ends + return events(response) } // Streaming types diff --git a/src/start.ts b/src/start.ts index 14abbbdff..11d4f9e5e 100644 --- a/src/start.ts +++ b/src/start.ts @@ -3,7 +3,6 @@ import { defineCommand } from "citty" import clipboard from "clipboardy" import consola from "consola" -import { serve, type ServerHandler } from "srvx" import invariant from "tiny-invariant" import { ensurePaths } from "./lib/paths" @@ -22,6 +21,7 @@ interface RunServerOptions { rateLimit?: number rateLimitWait: boolean githubToken?: string + apiKey?: string claudeCode: boolean showToken: boolean proxyEnv: boolean @@ -46,6 +46,11 @@ export async function runServer(options: RunServerOptions): Promise { state.rateLimitSeconds = options.rateLimit state.rateLimitWait = options.rateLimitWait state.showToken = options.showToken + state.apiKey = options.apiKey + + if (options.apiKey) { + consola.info("API key authentication enabled") + } await ensurePaths() await cacheVSCodeVersion() @@ -110,14 +115,13 @@ export async function runServer(options: RunServerOptions): Promise { } } - consola.box( - `🌐 Usage Viewer: https://ericc-ch.github.io/copilot-api?endpoint=${serverUrl}/usage`, - ) - - serve({ - fetch: server.fetch as ServerHandler, + Bun.serve({ + fetch: server.fetch, port: options.port, + idleTimeout: 255, }) + + consola.box(`🌐 Usage Viewer: ${serverUrl}?endpoint=${serverUrl}/usage`) } export const start = defineCommand({ @@ -184,6 +188,12 @@ export const start = defineCommand({ default: false, description: "Initialize proxy from environment variables", }, + "api-key": { + alias: "k", + type: "string", + description: + "API key for authentication (supports Authorization: Bearer or X-API-Key header)", + }, }, run({ args }) { const rateLimitRaw = args["rate-limit"] @@ -199,6 +209,7 @@ export const start = defineCommand({ rateLimit, rateLimitWait: args.wait, githubToken: args["github-token"], + apiKey: args["api-key"], claudeCode: args["claude-code"], showToken: args["show-token"], proxyEnv: args["proxy-env"], diff --git a/tests/billing-cycle.test.ts b/tests/billing-cycle.test.ts new file mode 100644 index 000000000..b6ec8b07f --- /dev/null +++ b/tests/billing-cycle.test.ts @@ -0,0 +1,154 @@ +import { beforeEach, describe, expect, test } from "bun:test" + +import { billingCycleManager } from "~/lib/billing-cycle" + +describe("Billing Cycle Manager", () => { + beforeEach(() => { + billingCycleManager.reset() + }) + + test("first request should bill (X-Initiator: user)", async () => { + const initiator = await billingCycleManager.determineInitiator() + expect(initiator).toBe("user") + }) + + test("second request in cycle should not bill (X-Initiator: agent)", async () => { + await billingCycleManager.determineInitiator() // First request + billingCycleManager.markResponseComplete() + + const initiator = await billingCycleManager.determineInitiator() + expect(initiator).toBe("agent") + }) + + test("concurrent first requests should only bill once", async () => { + // Simulate 3 concurrent requests + const [initiator1, initiator2, initiator3] = await Promise.all([ + billingCycleManager.determineInitiator(), + billingCycleManager.determineInitiator(), + billingCycleManager.determineInitiator(), + ]) + + // Only the first should be billed + const initiators = [initiator1, initiator2, initiator3] + const billedCount = initiators.filter((x) => x === "user").length + const notBilledCount = initiators.filter((x) => x === "agent").length + + expect(billedCount).toBe(1) + expect(notBilledCount).toBe(2) + }) + + test("cycle should reset after 5 minutes of inactivity", async () => { + // First request + await billingCycleManager.determineInitiator() + billingCycleManager.markResponseComplete() + + // Second request (should not bill) + const initiator2 = await billingCycleManager.determineInitiator() + expect(initiator2).toBe("agent") + billingCycleManager.markResponseComplete() + + // Simulate waiting 5+ minutes by directly manipulating time + // We need to access the private lastResponseTime field + // Instead, let's test the reset() method + billingCycleManager.reset() + + // After reset, next request should bill again + const initiator3 = await billingCycleManager.determineInitiator() + expect(initiator3).toBe("user") + }) + + test("failed request should not enter billing cycle", async () => { + const initiator1 = await billingCycleManager.determineInitiator() + expect(initiator1).toBe("user") + + // Request fails + billingCycleManager.markRequestFailed() + + // Next request should still bill + const initiator2 = await billingCycleManager.determineInitiator() + expect(initiator2).toBe("user") + }) + + test("failed request after successful one should not affect cycle", async () => { + // First successful request + await billingCycleManager.determineInitiator() + billingCycleManager.markResponseComplete() + + // Second request starts but fails + const initiator2 = await billingCycleManager.determineInitiator() + expect(initiator2).toBe("agent") + billingCycleManager.markRequestFailed() + + // Third request should still not bill (cycle continues) + const initiator3 = await billingCycleManager.determineInitiator() + expect(initiator3).toBe("agent") + }) + + test("multiple concurrent requests after first should all be agent", async () => { + // First request + await billingCycleManager.determineInitiator() + billingCycleManager.markResponseComplete() + + // Multiple concurrent follow-up requests + const [i1, i2, i3, i4, i5] = await Promise.all([ + billingCycleManager.determineInitiator(), + billingCycleManager.determineInitiator(), + billingCycleManager.determineInitiator(), + billingCycleManager.determineInitiator(), + billingCycleManager.determineInitiator(), + ]) + + expect([i1, i2, i3, i4, i5]).toEqual([ + "agent", + "agent", + "agent", + "agent", + "agent", + ]) + }) + + test("getStatus returns current state", async () => { + let status = billingCycleManager.getStatus() + expect(status.inCycle).toBe(false) + expect(status.pendingResponses).toBe(0) + + await billingCycleManager.determineInitiator() + status = billingCycleManager.getStatus() + expect(status.inCycle).toBe(true) + expect(status.pendingResponses).toBe(1) + + billingCycleManager.markResponseComplete() + status = billingCycleManager.getStatus() + expect(status.inCycle).toBe(true) + expect(status.pendingResponses).toBe(0) + }) + + test("non-streaming request flow", async () => { + // Simulate non-streaming request + const initiator = await billingCycleManager.determineInitiator() + expect(initiator).toBe("user") + + // Response completes immediately + billingCycleManager.markResponseComplete() + + // Next request should not bill + const initiator2 = await billingCycleManager.determineInitiator() + expect(initiator2).toBe("agent") + }) + + test("streaming request flow", async () => { + // Simulate streaming request + const initiator = await billingCycleManager.determineInitiator() + expect(initiator).toBe("user") + + // Simulate streaming chunks (response not complete yet) + // ... streaming in progress ... + + // Response completes after all chunks sent + billingCycleManager.markResponseComplete() + + // Next request should not bill + const initiator2 = await billingCycleManager.determineInitiator() + expect(initiator2).toBe("agent") + }) +}) diff --git a/tests/create-chat-completions.test.ts b/tests/create-chat-completions.test.ts index d18e741aa..dd89b310e 100644 --- a/tests/create-chat-completions.test.ts +++ b/tests/create-chat-completions.test.ts @@ -1,7 +1,8 @@ -import { test, expect, mock } from "bun:test" +import { beforeEach, expect, mock, test } from "bun:test" import type { ChatCompletionsPayload } from "../src/services/copilot/create-chat-completions" +import { billingCycleManager } from "../src/lib/billing-cycle" import { state } from "../src/lib/state" import { createChatCompletions } from "../src/services/copilot/create-chat-completions" @@ -23,12 +24,14 @@ const fetchMock = mock( // @ts-expect-error - Mock fetch doesn't implement all fetch properties ;(globalThis as unknown as { fetch: typeof fetch }).fetch = fetchMock -test("sets X-Initiator to agent if tool/assistant present", async () => { +beforeEach(() => { + billingCycleManager.reset() + fetchMock.mockClear() +}) + +test("sets X-Initiator to user for first request in billing cycle", async () => { const payload: ChatCompletionsPayload = { - messages: [ - { role: "user", content: "hi" }, - { role: "tool", content: "tool call" }, - ], + messages: [{ role: "user", content: "hi" }], model: "gpt-test", } await createChatCompletions(payload) @@ -36,21 +39,45 @@ test("sets X-Initiator to agent if tool/assistant present", async () => { const headers = ( fetchMock.mock.calls[0][1] as { headers: Record } ).headers + expect(headers["X-Initiator"]).toBe("user") +}) + +test("sets X-Initiator to agent for subsequent requests in billing cycle", async () => { + // First request + const payload1: ChatCompletionsPayload = { + messages: [{ role: "user", content: "hi" }], + model: "gpt-test", + } + await createChatCompletions(payload1) + + // Second request (should be agent) + const payload2: ChatCompletionsPayload = { + messages: [{ role: "user", content: "hello again" }], + model: "gpt-test", + } + await createChatCompletions(payload2) + + expect(fetchMock).toHaveBeenCalledTimes(2) + const headers = ( + fetchMock.mock.calls[1][1] as { headers: Record } + ).headers expect(headers["X-Initiator"]).toBe("agent") }) -test("sets X-Initiator to user if only user present", async () => { +test("message role does not affect X-Initiator (billing cycle only)", async () => { + // First request with tool/assistant messages const payload: ChatCompletionsPayload = { messages: [ { role: "user", content: "hi" }, - { role: "user", content: "hello again" }, + { role: "tool", content: "tool call" }, ], model: "gpt-test", } await createChatCompletions(payload) expect(fetchMock).toHaveBeenCalled() const headers = ( - fetchMock.mock.calls[1][1] as { headers: Record } + fetchMock.mock.calls[0][1] as { headers: Record } ).headers + // Should still be "user" because it's the first request in the cycle expect(headers["X-Initiator"]).toBe("user") })