Coverage for mcpgateway / services / llm_proxy_service.py: 99%
325 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/llm_proxy_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
6LLM Proxy Service
8This module implements the internal proxy for routing LLM requests
9to configured providers. It translates requests to provider-specific
10formats and handles streaming responses.
11"""
13# Standard
14import time
15from typing import Any, AsyncGenerator, Dict, Optional, Tuple
16import uuid
18# Third-Party
19import httpx
20import orjson
21from sqlalchemy import select
22from sqlalchemy.orm import Session
24# First-Party
25from mcpgateway.common.validators import SecurityValidator
26from mcpgateway.config import settings
27from mcpgateway.db import LLMModel, LLMProvider, LLMProviderType
28from mcpgateway.llm_schemas import (
29 ChatChoice,
30 ChatCompletionRequest,
31 ChatCompletionResponse,
32 ChatMessage,
33 UsageStats,
34)
35from mcpgateway.observability import create_span, set_span_attribute
36from mcpgateway.services.llm_provider_service import (
37 decrypt_provider_config_for_runtime,
38 LLMModelNotFoundError,
39 LLMProviderNotFoundError,
40)
41from mcpgateway.services.logging_service import LoggingService
42from mcpgateway.utils.services_auth import decode_auth
43from mcpgateway.utils.trace_redaction import is_input_capture_enabled, is_output_capture_enabled, serialize_trace_payload
# Initialize logging: a LoggingService instance is created at import time and
# the module logger is obtained from it (used for error/info messages below).
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)
50def _provider_trace_system(provider: LLMProvider) -> str:
51 """Map provider type to a stable ``gen_ai.system`` label.
53 Args:
54 provider: Provider model used to resolve the tracing system label.
56 Returns:
57 Lowercase provider type string for trace attributes.
58 """
59 provider_type = str(provider.provider_type.value if hasattr(provider.provider_type, "value") else provider.provider_type)
60 return provider_type.lower()
def _request_trace_input(request: ChatCompletionRequest) -> str:
    """Serialize the proxied request for the trace ``input`` field.

    The request is dumped in JSON mode without ``None`` fields and then passed
    through ``serialize_trace_payload``, which applies the trace redaction.

    Args:
        request: Chat completion request being forwarded to a provider.

    Returns:
        The redacted, serialized request payload.
    """
    payload = request.model_dump(mode="json", exclude_none=True)
    return serialize_trace_payload(payload)
75def _usage_trace_attrs(response: ChatCompletionResponse) -> Dict[str, int]:
76 """Extract token usage attributes from a chat completion response.
78 Args:
79 response: Provider response carrying token usage metadata.
81 Returns:
82 Trace attribute mapping for prompt, completion, and total token counts.
83 """
84 return {
85 "gen_ai.usage.prompt_tokens": response.usage.prompt_tokens,
86 "gen_ai.usage.completion_tokens": response.usage.completion_tokens,
87 "gen_ai.usage.total_tokens": response.usage.total_tokens,
88 }
class LLMProxyError(Exception):
    """Root of the exceptions raised by the LLM proxy service."""


class LLMProxyAuthError(LLMProxyError):
    """Signals that authentication with the upstream provider failed."""


class LLMProxyRequestError(LLMProxyError):
    """Signals that a request to the upstream provider could not be completed."""
class LLMProxyService:
    """Service for proxying LLM requests to configured providers.

    Handles request translation, streaming, and response formatting for the
    internal ``/v1/chat/completions`` endpoint. Requests arrive in OpenAI
    chat-completion shape. They are translated for Azure OpenAI, Anthropic,
    Ollama (native ``/api/chat`` or its OpenAI-compatible ``/v1`` API) or any
    other OpenAI-compatible provider, and replies are translated back into the
    OpenAI shape.
    """

    # Cap (in characters) on streamed output recorded on the trace span.
    _MAX_CAPTURED_STREAM_CHARS = 65536

    # Anthropic ``stop_reason`` -> OpenAI ``finish_reason``; unknown values pass through unchanged.
    _ANTHROPIC_STOP_REASONS: Dict[str, str] = {
        "end_turn": "stop",
        "stop_sequence": "stop",
        "max_tokens": "length",
        "tool_use": "tool_calls",
    }

    def __init__(self) -> None:
        """Initialize the LLM proxy service; the HTTP client is created by ``initialize``."""
        self._initialized = False
        self._client: Optional[httpx.AsyncClient] = None

    async def initialize(self) -> None:
        """Initialize the proxy service and its shared HTTP client (idempotent)."""
        if not self._initialized:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(settings.llm_request_timeout, connect=30.0),
                limits=httpx.Limits(
                    max_connections=settings.httpx_max_connections,
                    max_keepalive_connections=settings.httpx_max_keepalive_connections,
                    keepalive_expiry=settings.httpx_keepalive_expiry,
                ),
                verify=not settings.skip_ssl_verify,
            )
            logger.info("Initialized LLM Proxy Service")
            self._initialized = True

    async def shutdown(self) -> None:
        """Shutdown the proxy service and close the HTTP client's connections."""
        if self._initialized and self._client:
            await self._client.aclose()
            self._client = None
            logger.info("Shutdown LLM Proxy Service")
            self._initialized = False

    # ------------------------------------------------------------------
    # Model / credential resolution
    # ------------------------------------------------------------------

    def _resolve_model(
        self,
        db: Session,
        model_id: str,
    ) -> Tuple[LLMProvider, LLMModel]:
        """Resolve a model ID to its provider and model.

        Lookup order is ``LLMModel.id``, then ``LLMModel.model_id``, then
        ``LLMModel.model_alias``; the first match wins.

        Args:
            db: Database session.
            model_id: Model ID (can be model.id, model.model_id, or model.model_alias).

        Returns:
            Tuple of (LLMProvider, LLMModel).

        Raises:
            LLMModelNotFoundError: If model not found or disabled.
            LLMProviderNotFoundError: If provider not found or disabled.
        """
        model = None
        for column in (LLMModel.id, LLMModel.model_id, LLMModel.model_alias):
            # NOTE(review): scalar_one_or_none raises if several rows match; this
            # presumes model_id / model_alias are unique across providers — confirm.
            model = db.execute(select(LLMModel).where(column == model_id)).scalar_one_or_none()
            if model:
                break

        if not model:
            raise LLMModelNotFoundError(f"Model not found: {model_id}")

        if not model.enabled:
            raise LLMModelNotFoundError(f"Model is disabled: {model_id}")

        provider = db.execute(select(LLMProvider).where(LLMProvider.id == model.provider_id)).scalar_one_or_none()

        if not provider:
            raise LLMProviderNotFoundError(f"Provider not found for model: {model_id}")

        if not provider.enabled:
            raise LLMProviderNotFoundError(f"Provider is disabled: {provider.name}")

        return provider, model

    def _get_api_key(self, provider: LLMProvider) -> Optional[str]:
        """Extract the API key from the provider's encoded auth blob.

        Args:
            provider: LLM provider instance.

        Returns:
            Decoded API key, or None when absent or undecodable (failure is logged).
        """
        if not provider.api_key:
            return None

        try:
            auth_data = decode_auth(provider.api_key)
            return auth_data.get("api_key")
        except Exception as e:
            # Best effort: a bad key must not crash request building.
            logger.error(f"Failed to decode API key for provider {provider.name}: {e}")
            return None

    # ------------------------------------------------------------------
    # Shared request-building helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _resolve_temperature(request: ChatCompletionRequest, provider: LLMProvider) -> Optional[float]:
        """Return the request temperature, falling back to the provider default.

        ``None`` checks (not truthiness) are used so that an explicit ``0.0``
        provider default, which requests deterministic sampling, is honored.

        Args:
            request: Incoming chat completion request.
            provider: Provider supplying the fallback default.

        Returns:
            The temperature to send, or None when neither side sets one.
        """
        if request.temperature is not None:
            return request.temperature
        return provider.default_temperature

    @classmethod
    def _apply_sampling(
        cls,
        target: Dict[str, Any],
        request: ChatCompletionRequest,
        provider: LLMProvider,
        max_tokens_key: str = "max_tokens",
    ) -> None:
        """Write temperature and token-limit settings into *target* in place.

        Args:
            target: Request body (or Ollama ``options`` dict) to update.
            request: Incoming chat completion request.
            provider: Provider supplying default values.
            max_tokens_key: Key used for the token limit (``num_predict`` for native Ollama).
        """
        temperature = cls._resolve_temperature(request, provider)
        if temperature is not None:
            target["temperature"] = temperature

        if request.max_tokens is not None:
            target[max_tokens_key] = request.max_tokens
        elif provider.default_max_tokens:
            # A provider default of 0 keeps its original meaning of "not configured".
            target[max_tokens_key] = provider.default_max_tokens

    @staticmethod
    def _apply_optional_openai_params(body: Dict[str, Any], request: ChatCompletionRequest) -> None:
        """Copy the optional OpenAI tooling and sampling parameters set on *request* into *body*.

        Args:
            body: OpenAI-format request body to update in place.
            request: Incoming chat completion request.
        """
        if request.tools:
            body["tools"] = [tool.model_dump() for tool in request.tools]
        if request.tool_choice:
            body["tool_choice"] = request.tool_choice
        if request.top_p is not None:
            body["top_p"] = request.top_p
        if request.frequency_penalty is not None:
            body["frequency_penalty"] = request.frequency_penalty
        if request.presence_penalty is not None:
            body["presence_penalty"] = request.presence_penalty
        if request.stop:
            body["stop"] = request.stop

    @staticmethod
    def _uses_ollama_openai_api(provider: LLMProvider) -> bool:
        """Return True when an Ollama provider's base URL targets the OpenAI-compatible ``/v1`` API.

        Args:
            provider: Ollama provider.

        Returns:
            Whether ``api_base`` (trailing slashes ignored) ends with ``/v1``.
        """
        return (provider.api_base or "").rstrip("/").endswith("/v1")

    @staticmethod
    def _split_anthropic_messages(messages: Any) -> Tuple[Optional[Any], list]:
        """Separate system prompts from conversation turns for the Anthropic Messages API.

        Anthropic takes the system prompt as a top-level ``system`` field. All
        non-empty system messages are kept and joined with blank lines,
        instead of letting later ones overwrite earlier ones.

        Args:
            messages: Sequence of chat messages exposing ``role`` and ``content``.

        Returns:
            Tuple of (system prompt or None, list of ``{"role", "content"}`` dicts).
        """
        system_parts = []
        converted = []
        for msg in messages:
            if msg.role == "system":
                if msg.content:
                    system_parts.append(msg.content)
            else:
                converted.append({"role": msg.role, "content": msg.content or ""})

        if not system_parts:
            system = None
        elif len(system_parts) == 1 or not all(isinstance(part, str) for part in system_parts):
            # Single prompt, or non-text content that cannot be concatenated safely:
            # keep the last one, as the original implementation did.
            system = system_parts[-1]
        else:
            system = "\n\n".join(system_parts)
        return system, converted

    @staticmethod
    def _validate_provider_url(url: str) -> None:
        """Validate the constructed provider URL to prevent SSRF attacks.

        Args:
            url: Fully built provider endpoint URL.

        Raises:
            LLMProxyRequestError: If the URL fails validation.
        """
        try:
            SecurityValidator.validate_url(url, "LLM provider URL")
        except ValueError as url_err:
            raise LLMProxyRequestError(f"Invalid LLM provider URL: {url_err}") from url_err

    @staticmethod
    def _base_span_attributes(provider: LLMProvider, model: LLMModel) -> Dict[str, Any]:
        """Build the tracing attributes shared by streaming and non-streaming calls.

        Args:
            provider: Resolved provider.
            model: Resolved model.

        Returns:
            A fresh attribute dict that callers may extend.
        """
        system = _provider_trace_system(provider)
        return {
            "langfuse.observation.type": "generation",
            "gen_ai.system": system,
            "gen_ai.request.model": model.model_id,
            "llm.provider.id": str(provider.id),
            "llm.provider.type": system,
            "llm.model.id": str(model.id),
        }

    # ------------------------------------------------------------------
    # Provider-specific request builders
    # ------------------------------------------------------------------

    def _build_request(
        self,
        request: ChatCompletionRequest,
        provider: LLMProvider,
        model: LLMModel,
    ) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
        """Dispatch to the request builder matching the provider type.

        ``==`` comparisons are used deliberately, not a dict lookup:
        ``provider_type`` may be a plain string, and enum members do not
        hash like their values.

        Args:
            request: Chat completion request.
            provider: LLM provider.
            model: LLM model.

        Returns:
            Tuple of (url, headers, body).
        """
        if provider.provider_type == LLMProviderType.AZURE_OPENAI:
            return self._build_azure_request(request, provider, model)
        if provider.provider_type == LLMProviderType.ANTHROPIC:
            return self._build_anthropic_request(request, provider, model)
        if provider.provider_type == LLMProviderType.OLLAMA:
            return self._build_ollama_request(request, provider, model)
        # Default to OpenAI-compatible
        return self._build_openai_request(request, provider, model)

    def _build_openai_request(
        self,
        request: ChatCompletionRequest,
        provider: LLMProvider,
        model: LLMModel,
    ) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
        """Build request for OpenAI-compatible providers.

        Args:
            request: Chat completion request.
            provider: LLM provider.
            model: LLM model.

        Returns:
            Tuple of (url, headers, body).
        """
        api_key = self._get_api_key(provider)
        base_url = provider.api_base or "https://api.openai.com/v1"
        url = f"{base_url.rstrip('/')}/chat/completions"

        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        body: Dict[str, Any] = {
            "model": model.model_id,
            "messages": [msg.model_dump(exclude_none=True) for msg in request.messages],
        }
        self._apply_sampling(body, request, provider)
        if request.stream:
            body["stream"] = True
        self._apply_optional_openai_params(body, request)

        return url, headers, body

    def _build_azure_request(
        self,
        request: ChatCompletionRequest,
        provider: LLMProvider,
        model: LLMModel,
    ) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
        """Build request for Azure OpenAI.

        The deployment name comes from the decrypted provider config
        (``deployment_name`` or ``deployment``) and falls back to the model ID.
        The base URL is derived from ``resource_name`` when ``api_base`` is unset.
        The body accepts the same optional parameters as the OpenAI builder.

        Args:
            request: Chat completion request.
            provider: LLM provider.
            model: LLM model.

        Returns:
            Tuple of (url, headers, body).
        """
        api_key = self._get_api_key(provider)
        provider_config = decrypt_provider_config_for_runtime(provider.config)

        deployment_name = provider_config.get("deployment_name") or provider_config.get("deployment") or model.model_id
        resource_name = provider_config.get("resource_name", "")
        api_version = provider_config.get("api_version") or provider.api_version or "2024-02-15-preview"

        if not provider.api_base and resource_name:
            base_url = f"https://{resource_name}.openai.azure.com"
        else:
            base_url = provider.api_base or ""

        url = f"{base_url.rstrip('/')}/openai/deployments/{deployment_name}/chat/completions?api-version={api_version}"

        headers = {
            "Content-Type": "application/json",
            "api-key": api_key or "",
        }

        # Azure selects the model via the deployment in the URL, so no "model" key.
        body: Dict[str, Any] = {
            "messages": [msg.model_dump(exclude_none=True) for msg in request.messages],
        }
        self._apply_sampling(body, request, provider)
        if request.stream:
            body["stream"] = True
        self._apply_optional_openai_params(body, request)

        return url, headers, body

    def _build_anthropic_request(
        self,
        request: ChatCompletionRequest,
        provider: LLMProvider,
        model: LLMModel,
    ) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
        """Build request for Anthropic Claude (Messages API).

        Args:
            request: Chat completion request.
            provider: LLM provider.
            model: LLM model.

        Returns:
            Tuple of (url, headers, body).
        """
        api_key = self._get_api_key(provider)
        base_url = provider.api_base or "https://api.anthropic.com"
        provider_config = decrypt_provider_config_for_runtime(provider.config)

        url = f"{base_url.rstrip('/')}/v1/messages"

        anthropic_version = provider_config.get("anthropic_version") or provider.api_version or "2023-06-01"
        headers = {
            "Content-Type": "application/json",
            "x-api-key": api_key or "",
            "anthropic-version": anthropic_version,
        }

        system_message, messages = self._split_anthropic_messages(request.messages)

        body: Dict[str, Any] = {
            "model": model.model_id,
            "messages": messages,
            # Anthropic requires max_tokens, hence the hard fallback.
            "max_tokens": request.max_tokens or provider.default_max_tokens or 4096,
        }

        if system_message:
            body["system"] = system_message

        temperature = self._resolve_temperature(request, provider)
        if temperature is not None:
            body["temperature"] = temperature

        if request.stop:
            # Anthropic expects a list under ``stop_sequences``; OpenAI allows a bare string.
            body["stop_sequences"] = [request.stop] if isinstance(request.stop, str) else list(request.stop)

        if request.stream:
            body["stream"] = True

        return url, headers, body

    def _build_ollama_request(
        self,
        request: ChatCompletionRequest,
        provider: LLMProvider,
        model: LLMModel,
    ) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
        """Build request for Ollama (native API or OpenAI-compatible ``/v1`` API).

        Args:
            request: Chat completion request.
            provider: LLM provider.
            model: LLM model.

        Returns:
            Tuple of (url, headers, body).
        """
        base_url = (provider.api_base or "http://localhost:11434").rstrip("/")
        headers = {"Content-Type": "application/json"}
        body: Dict[str, Any] = {
            "model": model.model_id,
            "messages": [{"role": msg.role, "content": msg.content or ""} for msg in request.messages],
            "stream": request.stream,
        }

        if self._uses_ollama_openai_api(provider):
            # OpenAI-compatible API: sampling settings sit at the top level.
            url = f"{base_url}/chat/completions"
            self._apply_sampling(body, request, provider)
        else:
            # Native API: sampling settings go in "options"; the token limit is "num_predict".
            url = f"{base_url}/api/chat"
            options: Dict[str, Any] = {}
            self._apply_sampling(options, request, provider, max_tokens_key="num_predict")
            if options:
                body["options"] = options

        return url, headers, body

    # ------------------------------------------------------------------
    # Public entry points
    # ------------------------------------------------------------------

    async def chat_completion(
        self,
        db: Session,
        request: ChatCompletionRequest,
    ) -> ChatCompletionResponse:
        """Process a chat completion request (non-streaming).

        Args:
            db: Database session.
            request: Chat completion request.

        Returns:
            ChatCompletionResponse in OpenAI format.

        Raises:
            LLMProxyRequestError: If the URL is invalid, the provider returns an
                HTTP error or non-JSON body, or the connection fails.
        """
        if not self._client:
            await self.initialize()

        provider, model = self._resolve_model(db, request.model)
        url, headers, body = self._build_request(request, provider, model)

        # Ensure non-streaming
        body["stream"] = False

        self._validate_provider_url(url)

        span_attributes = self._base_span_attributes(provider, model)
        if is_input_capture_enabled("llm.proxy"):
            span_attributes["langfuse.observation.input"] = _request_trace_input(request)

        with create_span("llm.proxy", span_attributes) as span:
            try:
                response = await self._client.post(url, headers=headers, json=body)
                response.raise_for_status()
            except httpx.HTTPStatusError as e:
                logger.error(f"LLM request failed: {e.response.status_code} - {e.response.text}")
                raise LLMProxyRequestError(f"Request failed: {e.response.status_code}") from e
            except httpx.RequestError as e:
                logger.error(f"LLM request error: {e}")
                raise LLMProxyRequestError(f"Connection error: {str(e)}") from e

            try:
                data = response.json()
            except ValueError as e:
                # A 2xx response with a non-JSON body (e.g. an HTML error page from a proxy).
                logger.error(f"LLM provider returned a non-JSON response: {e}")
                raise LLMProxyRequestError("Invalid response from provider: body is not valid JSON") from e

            result = self._transform_response(provider, model, data)

            if span:
                set_span_attribute(span, "gen_ai.response.model", result.model)
                for key, value in _usage_trace_attrs(result).items():
                    set_span_attribute(span, key, value)
                if is_output_capture_enabled("llm.proxy"):
                    set_span_attribute(span, "langfuse.observation.output", serialize_trace_payload(result))

            return result

    async def chat_completion_stream(
        self,
        db: Session,
        request: ChatCompletionRequest,
    ) -> AsyncGenerator[str, None]:
        """Process a streaming chat completion request.

        Provider events are converted into OpenAI ``chat.completion.chunk`` SSE
        frames. The stream always ends with ``data: [DONE]`` after a successful
        upstream stream, including for providers (Anthropic, native Ollama)
        that do not send that sentinel themselves. Upstream failures are
        reported as a single ``proxy_error`` frame instead of an exception.

        Args:
            db: Database session.
            request: Chat completion request.

        Yields:
            SSE-formatted string chunks.

        Raises:
            LLMProxyRequestError: If the constructed provider URL is invalid.
        """
        if not self._client:
            await self.initialize()

        provider, model = self._resolve_model(db, request.model)
        url, headers, body = self._build_request(request, provider, model)

        # Ensure streaming
        body["stream"] = True

        self._validate_provider_url(url)

        response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
        created = int(time.time())
        # Loop-invariant provider checks, evaluated once instead of per streamed line.
        is_anthropic = provider.provider_type == LLMProviderType.ANTHROPIC
        is_native_ollama = provider.provider_type == LLMProviderType.OLLAMA and not self._uses_ollama_openai_api(provider)

        capture_output = is_output_capture_enabled("llm.proxy")
        captured_output = ""
        max_captured = self._MAX_CAPTURED_STREAM_CHARS
        done_sent = False

        span_attributes = self._base_span_attributes(provider, model)
        span_attributes["llm.stream"] = True
        if is_input_capture_enabled("llm.proxy"):
            span_attributes["langfuse.observation.input"] = _request_trace_input(request)

        with create_span("llm.proxy", span_attributes) as span:
            try:
                async with self._client.stream("POST", url, headers=headers, json=body) as response:
                    response.raise_for_status()

                    async for line in response.aiter_lines():
                        if not line:
                            continue

                        chunk: Optional[str] = None
                        if line.startswith("data:"):
                            # SSE frame: strip "data:" and one optional leading space.
                            data_str = line[5:]
                            if data_str.startswith(" "):
                                data_str = data_str[1:]
                            if data_str.strip() == "[DONE]":
                                done_sent = True
                                yield "data: [DONE]\n\n"
                                break

                            try:
                                data = orjson.loads(data_str)
                            except orjson.JSONDecodeError:
                                continue

                            if is_anthropic:
                                chunk = self._transform_anthropic_stream_chunk(data, response_id, created, model.model_id)
                            elif is_native_ollama:
                                chunk = self._transform_ollama_stream_chunk(data, response_id, created, model.model_id)
                            else:
                                # OpenAI-compatible payloads are already chunks; parsing only
                                # serves to skip malformed frames.
                                chunk = data_str
                        elif is_native_ollama:
                            # Native Ollama streams bare NDJSON lines rather than SSE frames.
                            try:
                                data = orjson.loads(line)
                            except orjson.JSONDecodeError:
                                continue
                            chunk = self._transform_ollama_stream_chunk(data, response_id, created, model.model_id)

                        if chunk:
                            if capture_output and len(captured_output) < max_captured:
                                captured_output += chunk[: max_captured - len(captured_output)]
                            yield f"data: {chunk}\n\n"

                if not done_sent:
                    # Anthropic and native Ollama never send [DONE]; OpenAI clients expect it.
                    yield "data: [DONE]\n\n"
            except httpx.HTTPStatusError as e:
                logger.error(f"LLM streaming request failed: {e.response.status_code}")
                error_chunk = {
                    "error": {
                        "message": f"Request failed: {e.response.status_code}",
                        "type": "proxy_error",
                    }
                }
                yield f"data: {orjson.dumps(error_chunk).decode()}\n\n"
            except httpx.RequestError as e:
                logger.error(f"LLM streaming request error: {e}")
                error_chunk = {
                    "error": {
                        "message": f"Connection error: {str(e)}",
                        "type": "proxy_error",
                    }
                }
                yield f"data: {orjson.dumps(error_chunk).decode()}\n\n"
            finally:
                if span and capture_output and captured_output:
                    set_span_attribute(span, "langfuse.observation.output", serialize_trace_payload({"stream": captured_output}))

    # ------------------------------------------------------------------
    # Response transforms
    # ------------------------------------------------------------------

    def _transform_response(
        self,
        provider: LLMProvider,
        model: LLMModel,
        data: Dict[str, Any],
    ) -> ChatCompletionResponse:
        """Convert a non-streaming provider response into OpenAI format.

        Args:
            provider: Provider that produced *data*.
            model: Resolved model, used where the provider omits a model name.
            data: Decoded JSON response body.

        Returns:
            ChatCompletionResponse in OpenAI format.
        """
        if provider.provider_type == LLMProviderType.ANTHROPIC:
            return self._transform_anthropic_response(data, model.model_id)
        if provider.provider_type == LLMProviderType.OLLAMA and not self._uses_ollama_openai_api(provider):
            return self._transform_ollama_response(data, model.model_id)
        return self._transform_openai_response(data)

    @classmethod
    def _map_anthropic_stop_reason(cls, stop_reason: Optional[str]) -> Optional[str]:
        """Map an Anthropic ``stop_reason`` to the OpenAI ``finish_reason`` vocabulary.

        Args:
            stop_reason: Anthropic stop reason, or None.

        Returns:
            The OpenAI equivalent; unknown values and None pass through unchanged.
        """
        return cls._ANTHROPIC_STOP_REASONS.get(stop_reason, stop_reason)

    def _transform_openai_response(self, data: Dict[str, Any]) -> ChatCompletionResponse:
        """Transform an OpenAI(-compatible) response into the standard format.

        Keys that are missing or explicitly null fall back to defaults.

        Args:
            data: Raw OpenAI API response data.

        Returns:
            ChatCompletionResponse in standard format.
        """
        choices = []
        for choice in data.get("choices") or []:
            message_data = choice.get("message") or {}
            choices.append(
                ChatChoice(
                    index=choice.get("index") or 0,
                    message=ChatMessage(
                        role=message_data.get("role") or "assistant",
                        content=message_data.get("content"),
                        tool_calls=message_data.get("tool_calls"),
                    ),
                    finish_reason=choice.get("finish_reason"),
                )
            )

        usage_data = data.get("usage") or {}
        usage = UsageStats(
            prompt_tokens=usage_data.get("prompt_tokens") or 0,
            completion_tokens=usage_data.get("completion_tokens") or 0,
            total_tokens=usage_data.get("total_tokens") or 0,
        )

        return ChatCompletionResponse(
            id=data.get("id") or f"chatcmpl-{uuid.uuid4().hex[:24]}",
            created=data.get("created") or int(time.time()),
            model=data.get("model") or "unknown",
            choices=choices,
            usage=usage,
        )

    def _transform_anthropic_response(
        self,
        data: Dict[str, Any],
        model_id: str,
    ) -> ChatCompletionResponse:
        """Transform an Anthropic Messages response into OpenAI format.

        Only ``text`` content blocks are concatenated into the assistant message.

        Args:
            data: Raw Anthropic API response data.
            model_id: Model ID to include in response.

        Returns:
            ChatCompletionResponse in OpenAI format.
        """
        content = "".join(block.get("text") or "" for block in data.get("content") or [] if block.get("type") == "text")

        usage_data = data.get("usage") or {}
        input_tokens = usage_data.get("input_tokens") or 0
        output_tokens = usage_data.get("output_tokens") or 0

        return ChatCompletionResponse(
            id=data.get("id") or f"chatcmpl-{uuid.uuid4().hex[:24]}",
            created=int(time.time()),
            model=model_id,
            choices=[
                ChatChoice(
                    index=0,
                    message=ChatMessage(role="assistant", content=content),
                    finish_reason=self._map_anthropic_stop_reason(data.get("stop_reason", "stop")),
                )
            ],
            usage=UsageStats(
                prompt_tokens=input_tokens,
                completion_tokens=output_tokens,
                total_tokens=input_tokens + output_tokens,
            ),
        )

    def _transform_ollama_response(
        self,
        data: Dict[str, Any],
        model_id: str,
    ) -> ChatCompletionResponse:
        """Transform a native Ollama ``/api/chat`` response into OpenAI format.

        Args:
            data: Raw Ollama API response data.
            model_id: Model ID to include in response.

        Returns:
            ChatCompletionResponse in OpenAI format.
        """
        message = data.get("message") or {}
        prompt_tokens = data.get("prompt_eval_count") or 0
        completion_tokens = data.get("eval_count") or 0

        return ChatCompletionResponse(
            id=f"chatcmpl-{uuid.uuid4().hex[:24]}",
            created=int(time.time()),
            model=model_id,
            choices=[
                ChatChoice(
                    index=0,
                    message=ChatMessage(
                        role=message.get("role") or "assistant",
                        content=message.get("content", ""),
                    ),
                    finish_reason="stop" if data.get("done") else None,
                )
            ],
            usage=UsageStats(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            ),
        )

    @staticmethod
    def _format_stream_chunk(
        response_id: str,
        created: int,
        model_id: str,
        delta: Dict[str, Any],
        finish_reason: Optional[str],
    ) -> str:
        """Serialize one OpenAI ``chat.completion.chunk`` with a single choice.

        Args:
            response_id: Response ID shared by every chunk of the stream.
            created: Creation timestamp shared by every chunk of the stream.
            model_id: Model ID to include in the chunk.
            delta: Incremental message delta.
            finish_reason: Finish reason, or None for intermediate chunks.

        Returns:
            JSON string for the chunk.
        """
        chunk = {
            "id": response_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model_id,
            "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
        }
        return orjson.dumps(chunk).decode()

    def _transform_anthropic_stream_chunk(
        self,
        data: Dict[str, Any],
        response_id: str,
        created: int,
        model_id: str,
    ) -> Optional[str]:
        """Transform an Anthropic streaming event into an OpenAI chunk.

        Only ``content_block_delta`` events with ``text_delta`` payloads and
        ``message_stop`` events produce output; other event types are dropped.

        Args:
            data: Raw Anthropic streaming event data.
            response_id: Response ID for the chunk.
            created: Timestamp for the response.
            model_id: Model ID to include in response.

        Returns:
            JSON string chunk in OpenAI format, or None if not applicable.
        """
        event_type = data.get("type")

        if event_type == "content_block_delta":
            delta = data.get("delta") or {}
            if delta.get("type") == "text_delta":
                return self._format_stream_chunk(response_id, created, model_id, {"content": delta.get("text", "")}, None)
        elif event_type == "message_stop":
            return self._format_stream_chunk(response_id, created, model_id, {}, "stop")

        return None

    def _transform_ollama_stream_chunk(
        self,
        data: Dict[str, Any],
        response_id: str,
        created: int,
        model_id: str,
    ) -> Optional[str]:
        """Transform a native Ollama streaming object into an OpenAI chunk.

        Args:
            data: Raw Ollama streaming event data.
            response_id: Response ID for the chunk.
            created: Timestamp for the response.
            model_id: Model ID to include in response.

        Returns:
            JSON string chunk in OpenAI format. The final ``done`` object becomes a
            ``finish_reason="stop"`` chunk.
        """
        if data.get("done"):
            return self._format_stream_chunk(response_id, created, model_id, {}, "stop")

        message = data.get("message") or {}
        content = message.get("content") or ""
        return self._format_stream_chunk(response_id, created, model_id, {"content": content} if content else {}, None)