diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 00000000..d6bf9ba0 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,158 @@ +# ============================================================================ +# OpenARC From Scratch - Ubuntu Base + Manual Intel Setup +# ============================================================================ +FROM ubuntu:24.04 + +ENV DEBIAN_FRONTEND=noninteractive + +# ============================================================================ +# System Dependencies +# ============================================================================ +RUN apt-get update && apt-get install -y \ + ca-certificates \ + curl \ + git \ + gpg \ + gpg-agent \ + wget \ + python3 \ + python3-venv \ + python3-dev \ + python3-pip && \ + update-alternatives --install /usr/bin/python python /usr/bin/python3 1 && \ + rm -rf /var/lib/apt/lists/* + +# ============================================================================ +# Intel GPU Drivers +# ============================================================================ +RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | \ + gpg --dearmor --output /usr/share/keyrings/intel-graphics.gpg && \ + echo "deb [arch=amd64 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/gpu/ubuntu noble client" | \ + tee /etc/apt/sources.list.d/intel-gpu-noble.list && \ + apt-get update && apt-get install -y \ + intel-opencl-icd \ + intel-level-zero-gpu \ + level-zero \ + level-zero-dev && \ + rm -rf /var/lib/apt/lists/* + +# ============================================================================ +# Intel NPU Driver +# ============================================================================ +RUN apt-get update && apt-get install -y \ + cmake \ + build-essential \ + libudev-dev && \ + git clone https://github.com/intel/linux-npu-driver.git /tmp/npu-driver && \ + cd /tmp/npu-driver && \ + git submodule update --init --recursive && \ + mkdir build && cd 
build && \ + cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/local .. && \ + make -j$(nproc) && \ + make install && \ + ldconfig && \ + cd / && rm -rf /tmp/npu-driver /var/lib/apt/lists/* + +# ============================================================================ +# Install uv package manager +# ============================================================================ +RUN curl -LsSf https://astral.sh/uv/install.sh | sh +ENV PATH="/root/.local/bin:$PATH" + +# ============================================================================ +# Clone and setup OpenArc +# ============================================================================ +WORKDIR /app +RUN git clone https://github.com/SearchSavior/OpenArc.git . && \ + echo "OpenARC version: $(git describe --tags --always)" + +# ============================================================================ +# Install Python dependencies with uv +# ============================================================================ +RUN uv sync && \ + uv pip install "optimum-intel[openvino] @ git+https://github.com/huggingface/optimum-intel" && \ + uv pip install --pre -U openvino-genai openvino-tokenizers \ + --extra-index-url https://storage.openvinotoolkit.org/simple/wheels/nightly + +# Add venv to PATH so openarc command works +ENV PATH="/app/.venv/bin:$PATH" + +# ============================================================================ +# Runtime Configuration +# ============================================================================ +ENV NEOReadDebugKeys=1 \ + OverrideGpuAddressSpace=48 \ + EnableImplicitScaling=1 \ + OPENARC_API_KEY=key \ + OPENARC_AUTOLOAD_MODEL="" + +# Create persistent config directory and symlink +RUN mkdir -p /persist && \ + ln -sf /persist/openarc_config.json /app/openarc_config.json + +# ============================================================================ +# Build Info Logging +# ============================================================================ 
+RUN echo "=== Build Information ===" > /app/BUILD_INFO.txt && \ + echo "Build Date: $(date -u +"%Y-%m-%d %H:%M:%S UTC")" >> /app/BUILD_INFO.txt && \ + echo "OpenARC Version: $(git describe --tags --always)" >> /app/BUILD_INFO.txt && \ + echo "" >> /app/BUILD_INFO.txt && \ + echo "=== Intel Package Versions ===" >> /app/BUILD_INFO.txt && \ + uv pip list | grep -E "(openvino|optimum|torch)" >> /app/BUILD_INFO.txt || true && \ + echo "" >> /app/BUILD_INFO.txt && \ + echo "=== System Package Versions ===" >> /app/BUILD_INFO.txt && \ + dpkg -l | grep -E "intel-opencl|level-zero" | awk '{print $2 " " $3}' >> /app/BUILD_INFO.txt || true + +# ============================================================================ +# Startup Script +# ============================================================================ +RUN cat > /usr/local/bin/start-openarc.sh <<'SCRIPT' +#!/bin/bash +set -e + +echo "================================================" +echo "=== Starting OpenArc Server ===" +echo "================================================" + +if [ -f /app/BUILD_INFO.txt ]; then + cat /app/BUILD_INFO.txt + echo "" +fi + +echo "=== Runtime Configuration ===" +echo "Port: 8000" +echo "API Key: ${OPENARC_API_KEY:0:10}..." +echo "Auto-load Model: ${OPENARC_AUTOLOAD_MODEL:-none}" +echo "" +echo "================================================" + +# Start server in background +openarc serve start --host 0.0.0.0 --openarc-port 8000 & +SERVER_PID=$! + +# Auto-load model if specified +if [ -n "$OPENARC_AUTOLOAD_MODEL" ]; then + echo "Waiting for server to start..." 
+ for i in {1..30}; do + if curl -s -f -H "Authorization: Bearer ${OPENARC_API_KEY}" http://localhost:8000/v1/models >/dev/null 2>&1; then + echo "Server ready after $i seconds" + echo "Auto-loading model: $OPENARC_AUTOLOAD_MODEL" + openarc load "$OPENARC_AUTOLOAD_MODEL" || echo "Failed to auto-load model" + break + fi + sleep 1 + done +fi + +# Wait for server +wait $SERVER_PID +SCRIPT + +RUN chmod +x /usr/local/bin/start-openarc.sh + +EXPOSE 8000 + +HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \ + CMD curl -f -H "Authorization: Bearer ${OPENARC_API_KEY}" http://localhost:8000/v1/models || exit 1 + +CMD ["/usr/local/bin/start-openarc.sh"] diff --git a/README.md b/README.md index b4e860ea..53098aa4 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,7 @@ Thanks to everyone on Discord for their continued support! - [Quickstart](#quickstart) - [Linux](#linux) - [Windows](#windows) + - [Docker](#docker) - [OpenArc CLI](#openarc-cli) - [openarc add](#openarc-add) - [openarc list](#openarc-list) @@ -45,7 +46,8 @@ Thanks to everyone on Discord for their continued support! - [Codebase Documentation](./docs/index.md) ## Features - + - NEW! Containerization with Docker #60 by @meatposes + - NEW! Speculative decoding support for LLMs #57 by @meatposes - Multi GPU Pipeline Paralell - CPU offload/Hybrid device - NPU device support @@ -178,6 +180,41 @@ openarc --help +
+Docker + +
+ +Instead of fighting with Intel's own Docker images, we built our own, which is as close to boilerplate as possible. If you need a primer on Docker, [check out this video](https://www.youtube.com/watch?v=DQdB7wFEygo). + + +**Build and run the container:** +```bash +docker-compose up --build -d +``` + +**Run the container:** +```bash +docker run -d -p 8000:8000 openarc:latest +``` +**Enter the container:** +```bash +docker exec -it openarc /bin/bash +``` + +## Environment Variables + +```bash +export OPENARC_API_KEY="openarc-api-key" # default, set it to whatever you want +export OPENARC_AUTOLOAD_MODEL="model_name" # model_name to load on startup +export MODEL_PATH="/path/to/your/models" # mount your models to `/models` inside the container +docker-compose up --build -d +``` + + +Take a look at the [Dockerfile](Dockerfile) and [docker-compose](docker-compose.yaml) for more details. + +
> [!NOTE] > Need help installing drivers? [Join our Discord](https://discord.gg/Bzz9hax9Jq) or open an issue. diff --git a/src/engine/ov_genai/llm.py b/src/engine/ov_genai/llm.py index 5b00e0e7..80a406aa 100755 --- a/src/engine/ov_genai/llm.py +++ b/src/engine/ov_genai/llm.py @@ -13,7 +13,7 @@ from src.server.models.ov_genai import OVGenAI_GenConfig from src.server.model_registry import ModelRegistry from src.server.models.registration import ModelLoadConfig -from src.engine.ov_genai.streamers import ChunkStreamer +from src.engine.ov_genai.streamers import ChunkStreamer, BlockStreamer from src.server.utils.chat import flatten_messages logging.basicConfig( @@ -28,6 +28,9 @@ def __init__(self, load_config: ModelLoadConfig): self.model_path = None self.encoder_tokenizer = None self.load_config = load_config + # Track active streaming requests for cancellation + self._active_streamer: Optional[ChunkStreamer] = None + self._active_request_id: Optional[str] = None def prepare_inputs(self, messages: List[Dict[str, Any]], @@ -54,28 +57,36 @@ def prepare_inputs(self, return_tensors="np" ) return ov.Tensor(prompt_token_ids) - - def generate_type(self, gen_config: OVGenAI_GenConfig): + + async def cancel(self, request_id: str) -> bool: """ - Unified text generation method that routes to streaming or non-streaming - based on the stream flag in gen_config. Both paths return an async iterator. - + Cancel an ongoing streaming generation by request_id. + Args: - gen_config: Configuration containing the stream flag and other parameters - + request_id: The request ID to cancel + Returns: - - Non-streaming: async iterator yielding [metrics: dict, new_text: str] - - Streaming: async iterator yielding token chunks (str)... 
then [metrics: dict, new_text: str] + True if cancellation was triggered, False if request_id didn't match """ - if gen_config.stream: - return self.generate_stream(gen_config) - else: - return self.generate_text(gen_config) - - async def generate_text(self, gen_config: OVGenAI_GenConfig) -> AsyncIterator[Union[Dict[str, Any], str]]: + if self._active_request_id == request_id and self._active_streamer is not None: + self._active_streamer.cancel() + logger.info(f"[{self.load_config.model_name}] Cancellation triggered for request {request_id}") + return True + return False + + async def arc_infer(self, gen_config: OVGenAI_GenConfig, request_id: Optional[str] = None) -> AsyncIterator[Union[str, Dict[str, Any]]]: """ - Async non-streaming text generation. - Yields in order: metrics (dict), new_text (str). + Unified inference method that uses appropriate streamer based on stream flag. + - stream=True: Uses ChunkStreamer for incremental token streaming + - stream=False: Uses BlockStreamer for single-block output + + Args: + gen_config: Configuration containing generation parameters including stream flag + request_id: Optional request ID for tracking cancellation + + Yields: + Token chunks (str) as they arrive, then metrics (dict) at the end. + For non-streaming (stream=False), yields a single chunk with all tokens. 
""" generation_kwargs = GenerationConfig( max_new_tokens=gen_config.max_tokens, @@ -85,6 +96,23 @@ async def generate_text(self, gen_config: OVGenAI_GenConfig) -> AsyncIterator[Un repetition_penalty=gen_config.repetition_penalty, ) + decoder_tokenizer = self.model.get_tokenizer() + + # Select appropriate streamer based on stream flag + if gen_config.stream: + # Streaming mode: use ChunkStreamer with configured chunk size + from copy import deepcopy + streamer_config = deepcopy(gen_config) + streamer_config.stream_chunk_tokens = gen_config.stream_chunk_tokens + streamer = ChunkStreamer(decoder_tokenizer, streamer_config) + else: + # Non-streaming mode: use BlockStreamer for single-block output + streamer = BlockStreamer(decoder_tokenizer) + + # Track active streamer for cancellation + self._active_streamer = streamer + self._active_request_id = request_id + # Support pre-encoded input_ids, raw prompts, and chat messages if gen_config.input_ids: # Pre-encoded input IDs (used by /openarc/bench endpoint for benchmarking) @@ -96,40 +124,6 @@ async def generate_text(self, gen_config: OVGenAI_GenConfig) -> AsyncIterator[Un else: # Chat template tokenization for messages (used by /v1/chat/completions endpoint) prompt_token_ids = self.prepare_inputs(gen_config.messages, gen_config.tools) - - result = await asyncio.to_thread(self.model.generate, prompt_token_ids, generation_kwargs) - - perf_metrics = result.perf_metrics - decoder_tokenizer = self.model.get_tokenizer() - text = decoder_tokenizer.decode(result.tokens)[0] if getattr(result, "tokens", None) else "" - - metrics_dict = self.collect_metrics(gen_config, perf_metrics) - yield metrics_dict - yield text - - async def generate_stream(self, gen_config: OVGenAI_GenConfig) -> AsyncIterator[Union[str, Dict[str, Any]]]: - """ - Async streaming text generation. - Yields token chunks (str) as they arrive, then metrics (dict), then final new_text (str). 
- """ - generation_kwargs = GenerationConfig( - max_new_tokens=gen_config.max_tokens, - temperature=gen_config.temperature, - top_k=gen_config.top_k, - top_p=gen_config.top_p, - repetition_penalty=gen_config.repetition_penalty - ) - - decoder_tokenizer = self.model.get_tokenizer() - streamer = ChunkStreamer(decoder_tokenizer, gen_config) - - # Support both chat messages and raw prompts - if gen_config.prompt: - # Direct tokenization for raw text (used by /v1/completions endpoint) - prompt_token_ids = ov.Tensor(self.encoder_tokenizer.encode(gen_config.prompt, return_tensors="np")) - else: - # Chat template tokenization for messages (used by /v1/chat/completions endpoint) - prompt_token_ids = self.prepare_inputs(gen_config.messages, gen_config.tools) async def _run_generation(): return await asyncio.to_thread( @@ -147,14 +141,20 @@ async def _run_generation(): if chunk is None: break yield chunk - finally: - result = await gen_task - perf_metrics = result.perf_metrics - metrics = self.collect_metrics(gen_config, perf_metrics) - - yield metrics - + # Clear active streamer tracking + self._active_streamer = None + self._active_request_id = None + # Wait for generation task to complete (may be cancelled) + try: + result = await gen_task + perf_metrics = result.perf_metrics + metrics = self.collect_metrics(gen_config, perf_metrics) + yield metrics + except Exception: + # Generation was cancelled or failed, don't yield metrics + pass + def collect_metrics(self, gen_config: OVGenAI_GenConfig, perf_metrics) -> Dict[str, Any]: """ Collect and format performance metrics into a dictionary. 
@@ -232,5 +232,3 @@ async def unload_model(self, registry: ModelRegistry, model_name: str) -> bool: gc.collect() logging.info(f"[{self.load_config.model_name}] unloaded successfully") return removed - - diff --git a/src/engine/ov_genai/streamers.py b/src/engine/ov_genai/streamers.py index a2f67208..43fede18 100644 --- a/src/engine/ov_genai/streamers.py +++ b/src/engine/ov_genai/streamers.py @@ -21,8 +21,15 @@ def __init__(self, decoder_tokenizer, gen_config: OVGenAI_GenConfig): self.since_last_emit: int = 0 # tokens collected since last emit self.last_print_len: int = 0 # length of decoded text we've already emitted self.text_queue: "asyncio.Queue[Optional[str]]" = asyncio.Queue() + self._cancelled = asyncio.Event() # cancellation flag for thread-safe signaling def write(self, token: Union[int, List[int]]) -> openvino_genai.StreamingStatus: + # Check for cancellation first + if self._cancelled.is_set(): + # Signal completion to the queue so the consumer can exit + self.text_queue.put_nowait(None) + return openvino_genai.StreamingStatus.CANCEL + # Normalize input to a list of ints if isinstance(token, list): self.tokens_cache.extend(token) @@ -44,6 +51,14 @@ def write(self, token: Union[int, List[int]]) -> openvino_genai.StreamingStatus: return openvino_genai.StreamingStatus.RUNNING + def cancel(self) -> None: + """Signal cancellation of the streaming generation.""" + self._cancelled.set() + + def is_cancelled(self) -> bool: + """Check if cancellation has been signaled.""" + return self._cancelled.is_set() + def end(self) -> None: # Flush any remaining tokens at the end text = self.decoder_tokenizer.decode(self.tokens_cache) @@ -53,3 +68,50 @@ def end(self) -> None: self.text_queue.put_nowait(chunk) # Signal completion self.text_queue.put_nowait(None) + + +class BlockStreamer(StreamerBase): + """ + Non-streaming (block) mode streamer. + Collects all tokens during generation and emits the complete text as a single block + when generation ends. 
Used for stream=False mode. + + Unlike ChunkStreamer, this does not emit partial results during generation - + the entire response is yielded at once. + """ + def __init__(self, decoder_tokenizer): + super().__init__() + self.decoder_tokenizer = decoder_tokenizer + self.tokens_cache: List[int] = [] + self.text_queue: "asyncio.Queue[Optional[str]]" = asyncio.Queue() + self._cancelled = asyncio.Event() + + def write(self, token: Union[int, List[int]]) -> openvino_genai.StreamingStatus: + # Check for cancellation first + if self._cancelled.is_set(): + self.text_queue.put_nowait(None) + return openvino_genai.StreamingStatus.CANCEL + + # Collect tokens without emitting + if isinstance(token, list): + self.tokens_cache.extend(token) + else: + self.tokens_cache.append(token) + + return openvino_genai.StreamingStatus.RUNNING + + def cancel(self) -> None: + """Signal cancellation of the generation.""" + self._cancelled.set() + + def is_cancelled(self) -> bool: + """Check if cancellation has been signaled.""" + return self._cancelled.is_set() + + def end(self) -> None: + # Decode and emit all tokens as a single block + text = self.decoder_tokenizer.decode(self.tokens_cache) + if text: + self.text_queue.put_nowait(text) + # Signal completion + self.text_queue.put_nowait(None) diff --git a/src/engine/ov_genai/vlm.py b/src/engine/ov_genai/vlm.py index 381bcb76..6a1beb3c 100644 --- a/src/engine/ov_genai/vlm.py +++ b/src/engine/ov_genai/vlm.py @@ -31,6 +31,9 @@ def __init__(self, load_config: ModelLoadConfig): self.tokenizer = None self.vision_token = None self.load_config = load_config + # Track active streaming requests for cancellation + self._active_streamer: Optional[ChunkStreamer] = None + self._active_request_id: Optional[str] = None def _vision_token_for_index(self, index: int) -> str: """ @@ -123,56 +126,34 @@ def prepare_inputs(self, return tokenized_messages, ov_images - def generate_type(self, gen_config: OVGenAI_GenConfig): + async def cancel(self, request_id: str) 
-> bool: """ - Unified generation method that routes to streaming or non-streaming - based on the stream flag in gen_config. Both paths return an async iterator. - """ - if gen_config.stream: - return self.generate_stream(gen_config) - else: - return self.generate_text(gen_config) - - async def generate_text(self, gen_config: OVGenAI_GenConfig) -> AsyncIterator[Union[Dict[str, Any], str]]: - """ - Async non-streaming generation for VLM. - Yields in order: metrics (dict), new_text (str). - """ - try: - generation_kwargs = GenerationConfig( - max_new_tokens=gen_config.max_tokens, - temperature=gen_config.temperature, - top_k=gen_config.top_k, - top_p=gen_config.top_p, - repetition_penalty=gen_config.repetition_penalty, - ) + Cancel an ongoing streaming generation by request_id. - prompt, ov_images = self.prepare_inputs(gen_config.messages, gen_config.tools) - - result = await asyncio.to_thread( - self.model_path.generate, - prompt=prompt, - **({'images': ov_images} if len(ov_images) > 0 else {}), - generation_config=generation_kwargs, - ) + Args: + request_id: The request ID to cancel - perf_metrics = result.perf_metrics + Returns: + True if cancellation was triggered, False if request_id didn't match + """ + if self._active_request_id == request_id and self._active_streamer is not None: + self._active_streamer.cancel() + logger.info(f"[{self.load_config.model_name}] Cancellation triggered for request {request_id}") + return True + return False - text = result.texts[0] if getattr(result, "texts", None) else "" - logger.info(f"[{self.load_config.model_name}] Generation completed, generated {len(text)} characters") + async def arc_infer(self, gen_config: OVGenAI_GenConfig, request_id: Optional[str] = None) -> AsyncIterator[Union[str, Dict[str, Any]]]: + """ + Unified inference method that uses ChunkStreamer for both streaming and non-streaming modes. + Dynamically adjusts stream_chunk_tokens based on the stream flag. 
- metrics_dict = self.collect_metrics(gen_config, perf_metrics) - yield metrics_dict - yield text - except Exception as e: - logger.error(f"[{self.load_config.model_name}] Error during non-streaming generation: {e}", exc_info=True) - raise + Args: + gen_config: Configuration containing generation parameters including stream flag + request_id: Optional request ID for tracking cancellation - async def generate_stream(self, - gen_config: OVGenAI_GenConfig) -> AsyncIterator[Union[str, Dict[str, Any]]]: - """ - Async streaming generation for VLM. - Yields token chunks (str) as they arrive, then metrics (dict). + Yields: + Token chunks (str) as they arrive, then metrics (dict) at the end. + For non-streaming (stream=False), yields a single chunk with all tokens. """ generation_kwargs = GenerationConfig( max_new_tokens=gen_config.max_tokens, @@ -183,7 +164,26 @@ async def generate_stream(self, ) decoder_tokenizer = self.model_path.get_tokenizer() - streamer = ChunkStreamer(decoder_tokenizer, gen_config) + + # Dynamically set stream_chunk_tokens based on stream flag + # Non-streaming: emit all tokens at once by setting to max_tokens + # Streaming: use configured stream_chunk_tokens value + if gen_config.stream: + chunk_tokens = gen_config.stream_chunk_tokens + else: + chunk_tokens = gen_config.max_tokens + + # Create a modified gen_config for ChunkStreamer with the adjusted chunk size + from copy import deepcopy + streamer_config = deepcopy(gen_config) + streamer_config.stream_chunk_tokens = chunk_tokens + + streamer = ChunkStreamer(decoder_tokenizer, streamer_config) + + # Track active streamer for cancellation + self._active_streamer = streamer + self._active_request_id = request_id + prompt, ov_images = self.prepare_inputs(gen_config.messages, gen_config.tools) async def _run_generation(): @@ -204,10 +204,18 @@ async def _run_generation(): break yield chunk finally: - result = await gen_task - perf_metrics = result.perf_metrics - metrics = 
self.collect_metrics(gen_config, perf_metrics) - yield metrics + # Clear active streamer tracking + self._active_streamer = None + self._active_request_id = None + # Wait for generation task to complete (may be cancelled) + try: + result = await gen_task + perf_metrics = result.perf_metrics + metrics = self.collect_metrics(gen_config, perf_metrics) + yield metrics + except Exception: + # Generation was cancelled or failed, don't yield metrics + pass def collect_metrics(self, gen_config: OVGenAI_GenConfig, perf_metrics) -> Dict[str, Any]: """ @@ -279,4 +287,3 @@ async def unload_model(self, registry: ModelRegistry, model_name: str) -> bool: gc.collect() logger.info(f"[{self.load_config.model_name}] unloaded successfully") return removed - diff --git a/src/server/main.py b/src/server/main.py index da67026a..84eec0b9 100644 --- a/src/server/main.py +++ b/src/server/main.py @@ -2,6 +2,7 @@ # They are one hero among many future heroes working to make OpenArc better. import datetime +import asyncio import json import logging import os @@ -227,10 +228,16 @@ async def benchmark(request: OpenArcBenchRequest): config_kwargs = {k: v for k, v in config_kwargs.items() if v is not None} generation_config = OVGenAI_GenConfig(**config_kwargs) - - result = await _workers.generate(request.model, generation_config) - metrics = result.get("metrics", {}) or {} - + + # Collect results from arc_generate + text = "" + metrics = {} + async for item in _workers.arc_generate(request.model, generation_config): + if isinstance(item, dict): + metrics = item + else: + text = item + logger.info(f"[bench] model={request.model} input_ids_len={len(request.input_ids)} metrics={metrics}") return {"metrics": metrics} @@ -297,105 +304,118 @@ async def event_stream() -> AsyncIterator[bytes]: accumulated_text = "" metrics_data = None tool_call_sent = False - - async for item in _workers.stream_generate(model_name, generation_config): - if isinstance(item, dict): - metrics_data = item.get("metrics", item) 
- continue - - accumulated_text += item - tool_calls = parse_tool_calls(accumulated_text) - - # If tool call detected and not yet sent, stream tool call deltas - if tool_calls and not tool_call_sent: - tool_call_sent = True - # Send tool call structure - for idx, tc in enumerate(tool_calls): - # Initial tool call with id, type, name - tool_call_start = { - 'id': request_id, - 'object': 'chat.completion.chunk', - 'created': created_ts, - 'model': model_name, - 'choices': [{ - 'index': 0, - 'delta': { - 'tool_calls': [{ - 'index': idx, - 'id': tc['id'], - 'type': tc['type'], - 'function': {'name': tc['function']['name'], 'arguments': ''} - }] - }, - 'finish_reason': None - }] - } - yield (f"data: {json.dumps(tool_call_start)}\n\n").encode() - - # Stream arguments - tool_call_args = { - 'id': request_id, - 'object': 'chat.completion.chunk', - 'created': created_ts, - 'model': model_name, - 'choices': [{ - 'index': 0, - 'delta': { - 'tool_calls': [{ - 'index': idx, - 'function': {'arguments': tc['function']['arguments']} - }] - }, - 'finish_reason': None - }] + + try: + async for item in _workers.arc_generate(model_name, generation_config): + if isinstance(item, dict): + metrics_data = item.get("metrics", item) + continue + + accumulated_text += item + tool_calls = parse_tool_calls(accumulated_text) + + # If tool call detected and not yet sent, stream tool call deltas + if tool_calls and not tool_call_sent: + tool_call_sent = True + # Send tool call structure + for idx, tc in enumerate(tool_calls): + # Initial tool call with id, type, name + tool_call_start = { + 'id': request_id, + 'object': 'chat.completion.chunk', + 'created': created_ts, + 'model': model_name, + 'choices': [{ + 'index': 0, + 'delta': { + 'tool_calls': [{ + 'index': idx, + 'id': tc['id'], + 'type': tc['type'], + 'function': {'name': tc['function']['name'], 'arguments': ''} + }] + }, + 'finish_reason': None + }] + } + yield (f"data: {json.dumps(tool_call_start)}\n\n").encode() + + # Stream arguments 
+ tool_call_args = { + 'id': request_id, + 'object': 'chat.completion.chunk', + 'created': created_ts, + 'model': model_name, + 'choices': [{ + 'index': 0, + 'delta': { + 'tool_calls': [{ + 'index': idx, + 'function': {'arguments': tc['function']['arguments']} + }] + }, + 'finish_reason': None + }] + } + yield (f"data: {json.dumps(tool_call_args)}\n\n").encode() + elif not tool_calls: + # Regular content streaming + chunk_payload = { + "id": request_id, + "object": "chat.completion.chunk", + "created": created_ts, + "model": model_name, + "choices": [{ + "index": 0, + "delta": {"content": item}, + "finish_reason": None, + }], } - yield (f"data: {json.dumps(tool_call_args)}\n\n").encode() - elif not tool_calls: - # Regular content streaming - chunk_payload = { - "id": request_id, - "object": "chat.completion.chunk", - "created": created_ts, - "model": model_name, - "choices": [{ - "index": 0, - "delta": {"content": item}, - "finish_reason": None, - }], - } - yield (f"data: {json.dumps(chunk_payload)}\n\n").encode() + yield (f"data: {json.dumps(chunk_payload)}\n\n").encode() - # Final chunk - prompt_tokens = (metrics_data or {}).get("input_token", 0) - completion_tokens = (metrics_data or {}).get("new_token", 0) - total_tokens = (metrics_data or {}).get("total_token", prompt_tokens + completion_tokens) - - finish_reason = "tool_calls" if tool_call_sent else "stop" - - final_payload = { - "id": request_id, - "object": "chat.completion.chunk", - "created": created_ts, - "model": model_name, - "choices": [{ - "index": 0, - "delta": {}, - "finish_reason": finish_reason, - }], - "usage": { - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": total_tokens, - }, - } - yield (f"data: {json.dumps(final_payload)}\n\n").encode() - yield b"data: [DONE]\n\n" + # Final chunk + prompt_tokens = (metrics_data or {}).get("input_token", 0) + completion_tokens = (metrics_data or {}).get("new_token", 0) + total_tokens = (metrics_data or 
{}).get("total_token", prompt_tokens + completion_tokens) + + finish_reason = "tool_calls" if tool_call_sent else "stop" + + final_payload = { + "id": request_id, + "object": "chat.completion.chunk", + "created": created_ts, + "model": model_name, + "choices": [{ + "index": 0, + "delta": {}, + "finish_reason": finish_reason, + }], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + }, + } + yield (f"data: {json.dumps(final_payload)}\n\n").encode() + yield b"data: [DONE]\n\n" + + except asyncio.CancelledError: + # Client disconnected - trigger cancellation for KV cache cleanup + logger.info(f"[chat/completions] Client disconnected, cancelling request {request_id}") + await _workers.cancel(request_id) + # Re-raise to let StreamingResponse handle cleanup + raise return StreamingResponse(event_stream(), media_type="text/event-stream") else: - result = await _workers.generate(model_name, generation_config) - text = result.get("text", "") - metrics = result.get("metrics", {}) or {} + # Non-streaming: collect from arc_generate + text = "" + metrics = {} + async for item in _workers.arc_generate(model_name, generation_config): + if isinstance(item, dict): + metrics = item + else: + text = item prompt_tokens = metrics.get("input_token", 0) completion_tokens = metrics.get("new_token", 0) @@ -467,12 +487,35 @@ async def openai_completions(request: OpenAICompletionRequest): async def event_stream() -> AsyncIterator[bytes]: # Stream OpenAI-compatible chunks metrics_data = None - async for item in _workers.stream_generate(model_name, generation_config): - if isinstance(item, dict): - # Capture metrics for final usage payload - metrics_data = item.get("metrics", item) - continue - chunk_payload = { + try: + async for item in _workers.stream_generate(model_name, generation_config, request_id): + if isinstance(item, dict): + # Capture metrics for final usage payload + metrics_data = item.get("metrics", item) + 
continue + chunk_payload = { + "id": request_id, + "object": "text_completion.chunk", + "created": created_ts, + "model": model_name, + "choices": [ + { + "index": 0, + "text": item, + "finish_reason": None, + } + ], + } + yield (f"data: {json.dumps(chunk_payload)}\n\n").encode() + + # Final stop signal per OpenAI SSE with usage + prompt_tokens = (metrics_data or {}).get("input_token", 0) + completion_tokens = (metrics_data or {}).get("new_token", 0) + total_tokens = (metrics_data or {}).get("total_token", prompt_tokens + completion_tokens) + + logger.info(f"[completions] stream=true model={model_name} metrics={metrics_data}") + + final_payload = { "id": request_id, "object": "text_completion.chunk", "created": created_ts, @@ -480,40 +523,24 @@ async def event_stream() -> AsyncIterator[bytes]: "choices": [ { "index": 0, - "text": item, - "finish_reason": None, + "text": "", + "finish_reason": "stop", } ], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens, + }, } - yield (f"data: {json.dumps(chunk_payload)}\n\n").encode() - - # Final stop signal per OpenAI SSE with usage - prompt_tokens = (metrics_data or {}).get("input_token", 0) - completion_tokens = (metrics_data or {}).get("new_token", 0) - total_tokens = (metrics_data or {}).get("total_token", prompt_tokens + completion_tokens) - - logger.info(f"[completions] stream=true model={model_name} metrics={metrics_data}") - - final_payload = { - "id": request_id, - "object": "text_completion.chunk", - "created": created_ts, - "model": model_name, - "choices": [ - { - "index": 0, - "text": "", - "finish_reason": "stop", - } - ], - "usage": { - "prompt_tokens": prompt_tokens, - "completion_tokens": completion_tokens, - "total_tokens": total_tokens, - }, - } - yield (f"data: {json.dumps(final_payload)}\n\n").encode() - yield b"data: [DONE]\n\n" + yield (f"data: {json.dumps(final_payload)}\n\n").encode() + yield b"data: [DONE]\n\n" + + except 
asyncio.CancelledError: + # Client disconnected - trigger cancellation for KV cache cleanup + logger.info(f"[completions] Client disconnected, cancelling request {request_id}") + await _workers.cancel(request_id) + raise return StreamingResponse(event_stream(), media_type="text/event-stream") else: diff --git a/src/server/worker_registry.py b/src/server/worker_registry.py index 19627534..51a1e1fe 100644 --- a/src/server/worker_registry.py +++ b/src/server/worker_registry.py @@ -6,7 +6,7 @@ import torch import soundfile as sf from dataclasses import dataclass -from typing import Any, AsyncIterator, Dict, Optional, Union +from typing import Any, AsyncIterator, Dict, Optional, Tuple, Union from src.engine.ov_genai.llm import OVGenAI_LLM from src.engine.ov_genai.vlm import OVGenAI_VLM @@ -67,7 +67,6 @@ class InferWorker: Responsibilities: - Execute generation requests using pipelines - Methods: - infer_llm: Process text-to-text generation requests - infer_vlm: Process image-to-text generation requests @@ -78,13 +77,20 @@ class InferWorker: """ @staticmethod - async def infer_llm(packet: WorkerPacket, llm_instance: OVGenAI_LLM) -> WorkerPacket: + async def infer_llm(packet: WorkerPacket, llm_instance: OVGenAI_LLM, registry: 'WorkerRegistry' = None) -> WorkerPacket: """Generate text for a single packet using the OVGenAI_LLM pipeline""" metrics = None final_text = "" try: - async for item in llm_instance.generate_type(packet.gen_config): + # Register model instance for cancellation tracking + if registry is not None: + async with registry._lock: + if packet.request_id in registry._active_requests: + model_name, _ = registry._active_requests[packet.request_id] + registry._active_requests[packet.request_id] = (model_name, llm_instance) + + async for item in llm_instance.arc_infer(packet.gen_config, packet.request_id): if isinstance(item, dict): metrics = item else: @@ -110,17 +116,29 @@ async def infer_llm(packet: WorkerPacket, llm_instance: OVGenAI_LLM) -> WorkerPa # 
Signal error to stream if streaming if packet.gen_config.stream and packet.stream_queue is not None: await packet.stream_queue.put(None) - + + # Clean up active request tracking + if registry is not None: + async with registry._lock: + registry._active_requests.pop(packet.request_id, None) + return packet @staticmethod - async def infer_vlm(packet: WorkerPacket, vlm_model: OVGenAI_VLM) -> WorkerPacket: + async def infer_vlm(packet: WorkerPacket, vlm_model: OVGenAI_VLM, registry: 'WorkerRegistry' = None) -> WorkerPacket: """Generate text from image for a single packet using the OVGenAI_VLM pipeline""" metrics = None final_text = "" try: - async for item in vlm_model.generate_type(packet.gen_config): + # Register model instance for cancellation tracking + if registry is not None: + async with registry._lock: + if packet.request_id in registry._active_requests: + model_name, _ = registry._active_requests[packet.request_id] + registry._active_requests[packet.request_id] = (model_name, vlm_model) + + async for item in vlm_model.arc_infer(packet.gen_config, packet.request_id): if isinstance(item, dict): metrics = item else: @@ -146,7 +164,12 @@ async def infer_vlm(packet: WorkerPacket, vlm_model: OVGenAI_VLM) -> WorkerPacke # Signal error to stream if streaming if packet.gen_config.stream and packet.stream_queue is not None: await packet.stream_queue.put(None) - + + # Clean up active request tracking + if registry is not None: + async with registry._lock: + registry._active_requests.pop(packet.request_id, None) + return packet @staticmethod @@ -281,81 +304,40 @@ async def infer_rerank(packet: WorkerPacket, rerank_instance: Optimum_RR) -> Wor class QueueWorker: """ Manages inference worker loops for consuming and processing packets from model queues. 
- - """ - - @staticmethod - async def queue_worker_llm(model_name: str, model_queue: asyncio.Queue, llm_model: OVGenAI_LLM, registry: ModelRegistry): - """Text model inference worker that processes packets from queue""" - logger.info(f"[{model_name} LLM Worker] Started, waiting for packets...") - while True: - packet = await model_queue.get() - if packet is None: - logger.info(f"[{model_name} LLM Worker] Shutdown signal received.") - break - - completed_packet = await InferWorker.infer_llm(packet, llm_model) - # Check if inference failed and trigger model unload - if completed_packet.response and completed_packet.response.startswith("Error:"): - logger.error(f"[{model_name} LLM Worker] Inference failed, triggering model unload...") - asyncio.create_task(registry.register_unload(model_name)) - break - - if completed_packet.metrics: - logger.info(f"[{model_name} LLM Worker] Metrics: {completed_packet.metrics}") - - if packet.result_future is not None and not packet.result_future.done(): - packet.result_future.set_result(completed_packet) - - model_queue.task_done() - - @staticmethod - async def queue_worker_vlm(model_name: str, model_queue: asyncio.Queue, vlm_model: OVGenAI_VLM, registry: ModelRegistry): - """Image model inference worker that processes packets from queue""" - logger.info(f"[{model_name} VLM Worker] Started, waiting for packets...") - while True: - packet = await model_queue.get() - if packet is None: - logger.info(f"[{model_name} VLM Worker] Shutdown signal received.") - break - - completed_packet = await InferWorker.infer_vlm(packet, vlm_model) - - # Check if inference failed and trigger model unload - if completed_packet.response and completed_packet.response.startswith("Error:"): - logger.error(f"[{model_name} VLM Worker] Inference failed, triggering model unload...") - asyncio.create_task(registry.register_unload(model_name)) - break - - if completed_packet.metrics: - logger.info(f"[{model_name} VLM Worker] Metrics: {completed_packet.metrics}") - 
@staticmethod
async def _generic_worker(
    model_name: str,
    model_queue: asyncio.Queue,
    model_instance: Any,
    registry: ModelRegistry,
    worker_type: str,
    infer_method: callable,
    error_check_fn: callable,
    worker_registry: 'WorkerRegistry' = None,
) -> None:
    """Consume packets from ``model_queue`` and run ``infer_method`` on each one.

    The loop ends when a ``None`` sentinel is dequeued or when
    ``error_check_fn`` reports that an inference failed; in the latter case
    an unload of ``model_name`` is scheduled on ``registry``.

    Args:
        model_name: Name used in log messages and for the unload request.
        model_queue: Queue of packets; ``None`` is the shutdown sentinel.
        model_instance: Loaded model object handed to ``infer_method``.
        registry: Model registry whose ``register_unload`` is scheduled on failure.
        worker_type: Human-readable label used in log messages.
        infer_method: Coroutine function called as
            ``infer_method(packet, model_instance)`` or, when
            ``worker_registry`` is given,
            ``infer_method(packet, model_instance, worker_registry)``.
        error_check_fn: Predicate returning True when the completed packet
            represents a failed inference.
        worker_registry: Optional registry forwarded to ``infer_method`` for
            cancellation tracking. Supply it only for infer methods that
            accept a third argument.
    """
    tag = f"[{model_name} {worker_type} Worker]"
    logger.info(f"{tag} Started, waiting for packets...")

    while True:
        packet = await model_queue.get()
        if packet is None:
            logger.info(f"{tag} Shutdown signal received.")
            # Match the get() above so a pending queue.join() is not blocked by the sentinel.
            model_queue.task_done()
            break

        # Forward the registry only when one was supplied: infer methods that
        # take just (packet, instance), e.g. infer_rerank, would raise
        # TypeError if handed a third positional argument.
        if worker_registry is None:
            completed_packet = await infer_method(packet, model_instance)
        else:
            completed_packet = await infer_method(packet, model_instance, worker_registry)

        failed = error_check_fn(completed_packet)
        if failed:
            logger.error(f"{tag} Inference failed, triggering model unload...")
        elif completed_packet.metrics:
            logger.info(f"{tag} Metrics: {completed_packet.metrics}")

        # Resolve the caller's future on failure as well; otherwise a
        # non-streaming caller awaiting it would never be woken up.
        if packet.result_future is not None and not packet.result_future.done():
            packet.result_future.set_result(completed_packet)

        model_queue.task_done()

        if failed:
            # Unload runs as its own task so this worker can exit immediately.
            asyncio.create_task(registry.register_unload(model_name))
            break
@staticmethod
def create_worker_queue(
    model_type: ModelType,
    model_name: str,
    model_queue: asyncio.Queue,
    model_instance: Any,
    registry: ModelRegistry,
    worker_registry: 'WorkerRegistry' = None,
) -> asyncio.Task:
    """Start the queue-consuming worker task that matches ``model_type``.

    Looks up the log label, inference coroutine and failure predicate for
    the given model type and schedules ``QueueWorker._generic_worker`` with
    them via ``asyncio.create_task``.

    Raises:
        ValueError: If ``model_type`` has no entry in the dispatch table.
    """
    def _response_is_error_text(pkt: WorkerPacket) -> bool:
        # Text-producing workers report failure through an "Error:" prefix.
        response = pkt.response
        return bool(response and response.startswith("Error:"))

    def _response_is_empty(pkt: WorkerPacket) -> bool:
        # Embedding / rerank workers report failure through a falsy response.
        return not pkt.response

    # model type -> (log label, inference coroutine, failure predicate)
    dispatch = {
        ModelType.LLM: ("LLM", InferWorker.infer_llm, _response_is_error_text),
        ModelType.VLM: ("VLM", InferWorker.infer_vlm, _response_is_error_text),
        ModelType.WHISPER: ("Whisper", InferWorker.infer_whisper, _response_is_error_text),
        ModelType.KOKORO: ("Kokoro", InferWorker.infer_kokoro, _response_is_error_text),
        ModelType.EMB: ("EMB", InferWorker.infer_emb, _response_is_empty),
        ModelType.RERANK: ("Reranker", InferWorker.infer_rerank, _response_is_empty),
    }

    if model_type not in dispatch:
        raise ValueError(f"Unsupported model type: {model_type}")
    label, infer, is_failure = dispatch[model_type]

    worker_coro = QueueWorker._generic_worker(
        model_name=model_name,
        model_queue=model_queue,
        model_instance=model_instance,
        registry=registry,
        worker_type=label,
        infer_method=infer,
        error_check_fn=is_failure,
        worker_registry=worker_registry,
    )
    return asyncio.create_task(worker_coro)
-> None: if record.model_name not in self._model_queues_llm: q: asyncio.Queue = asyncio.Queue() self._model_queues_llm[record.model_name] = q - task = asyncio.create_task(QueueWorker.queue_worker_llm(record.model_name, q, instance, self._model_registry)) + task = QueueWorker.create_worker_queue(mt, record.model_name, q, instance, self._model_registry, self) self._model_tasks_llm[record.model_name] = task elif mt == ModelType.VLM and isinstance(instance, OVGenAI_VLM): if record.model_name not in self._model_queues_vlm: q: asyncio.Queue = asyncio.Queue() self._model_queues_vlm[record.model_name] = q - task = asyncio.create_task(QueueWorker.queue_worker_vlm(record.model_name, q, instance, self._model_registry)) + task = QueueWorker.create_worker_queue(mt, record.model_name, q, instance, self._model_registry, self) self._model_tasks_vlm[record.model_name] = task elif mt == ModelType.WHISPER and isinstance(instance, OVGenAI_Whisper): if record.model_name not in self._model_queues_whisper: q: asyncio.Queue = asyncio.Queue() self._model_queues_whisper[record.model_name] = q - task = asyncio.create_task(QueueWorker.queue_worker_whisper(record.model_name, q, instance, self._model_registry)) + task = QueueWorker.create_worker_queue(mt, record.model_name, q, instance, self._model_registry) self._model_tasks_whisper[record.model_name] = task elif mt == ModelType.KOKORO and isinstance(instance, OV_Kokoro): if record.model_name not in self._model_queues_kokoro: q: asyncio.Queue = asyncio.Queue() self._model_queues_kokoro[record.model_name] = q - task = asyncio.create_task(QueueWorker.queue_worker_kokoro(record.model_name, q, instance, self._model_registry)) + task = QueueWorker.create_worker_queue(mt, record.model_name, q, instance, self._model_registry) self._model_tasks_kokoro[record.model_name] = task elif mt == ModelType.EMB and isinstance(instance, Optimum_EMB): if record.model_name not in self._model_queues_emb: q: asyncio.Queue = asyncio.Queue() 
async def arc_generate(
    self,
    model_name: str,
    gen_config: OVGenAI_GenConfig,
    request_id: Optional[str] = None,
) -> AsyncIterator[Union[str, Dict[str, Any]]]:
    """Generate text through the model's worker queue, streaming or not.

    The mode is chosen by ``gen_config.stream``.

    Args:
        model_name: Target model name.
        gen_config: Generation configuration carrying the ``stream`` flag.
        request_id: Identifier under which the request is tracked in
            ``_active_requests``. Pass the same id that will later be given
            to ``cancel``; a random one is generated when omitted.

    Yields:
        Streaming: every item the worker puts on the stream queue, until the
        ``None`` end marker.
        Non-streaming: the full response text as a single chunk, then the
        metrics dict when one is present.
    """
    if request_id is None:
        request_id = uuid.uuid4().hex

    # Resolve the queue before registering the request: if the lookup
    # raises, nothing has been added to _active_requests yet.
    q = self._get_model_queue(model_name)

    streaming = bool(gen_config.stream)
    if streaming:
        stream_queue: asyncio.Queue = asyncio.Queue()
        packet = WorkerPacket(
            request_id=request_id,
            id_model=model_name,
            gen_config=gen_config,
            stream_queue=stream_queue,
        )
    else:
        result_future: asyncio.Future = asyncio.get_running_loop().create_future()
        packet = WorkerPacket(
            request_id=request_id,
            id_model=model_name,
            gen_config=gen_config,
            result_future=result_future,
        )

    # Register the request so cancel() can find it; the model instance slot
    # stays None until a worker picks the packet up.
    async with self._lock:
        self._active_requests[request_id] = (model_name, None)

    await q.put(packet)

    if streaming:
        while True:
            item = await stream_queue.get()
            if item is None:  # end-of-stream marker
                break
            yield item
        async with self._lock:
            self._active_requests.pop(request_id, None)
    else:
        completed = await result_future
        async with self._lock:
            self._active_requests.pop(request_id, None)

        # Full response as one chunk, followed by metrics if there are any.
        yield completed.response or ""
        if completed.metrics:
            yield completed.metrics

def stream_generate(
    self,
    model_name: str,
    gen_config: OVGenAI_GenConfig,
    request_id: Optional[str] = None,
) -> AsyncIterator[Union[str, Dict[str, Any]]]:
    """Backward-compatible streaming entry point.

    Kept because the completions endpoint still iterates
    ``stream_generate(model_name, generation_config, request_id)``.
    Returns the ``arc_generate`` async generator unchanged.
    """
    return self.arc_generate(model_name, gen_config, request_id)

async def cancel(self, request_id: str) -> bool:
    """Cancel an ongoing generation by request id.

    Args:
        request_id: The id the request was registered under.

    Returns:
        True if the model's ``cancel`` was invoked; False if the request is
        not tracked, no loaded instance was found for its model, or the
        instance has no ``cancel`` attribute.
    """
    entry = self._active_requests.get(request_id)
    if entry is None:
        return False
    model_name, model_instance = entry

    if model_instance is None:
        # No worker has recorded its instance yet: fall back to the
        # ModelRegistry, taking the first loaded instance that can cancel.
        async with self._model_registry._lock:
            for record in self._model_registry._models.values():
                if (
                    record.model_name == model_name
                    and record.model_instance is not None
                    and hasattr(record.model_instance, 'cancel')
                ):
                    model_instance = record.model_instance
                    break

    if model_instance is None or not hasattr(model_instance, 'cancel'):
        return False

    # Awaited outside the registry lock so a slow cancel cannot block
    # other users of the ModelRegistry.
    await model_instance.cancel(request_id)
    logger.info(f"[WorkerRegistry] Cancelled request {request_id} on model {model_name}")
    return True