Coverage for mcpgateway/services/llm_proxy_service.py: 100%
280 statements
coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

# -*- coding: utf-8 -*-
"""Location: ./mcpgateway/services/llm_proxy_service.py
Copyright 2025
SPDX-License-Identifier: Apache-2.0

LLM Proxy Service

This module implements the internal proxy for routing LLM requests
to configured providers. It translates requests to provider-specific
formats and handles streaming responses.
"""

# Standard
import time
from typing import Any, AsyncGenerator, Dict, Optional, Tuple
import uuid

# Third-Party
import httpx
import orjson
from sqlalchemy import select
from sqlalchemy.orm import Session

# First-Party
from mcpgateway.config import settings
from mcpgateway.db import LLMModel, LLMProvider, LLMProviderType
from mcpgateway.llm_schemas import (
    ChatChoice,
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatMessage,
    UsageStats,
)
from mcpgateway.services.llm_provider_service import (
    LLMModelNotFoundError,
    LLMProviderNotFoundError,
)
from mcpgateway.services.logging_service import LoggingService
from mcpgateway.utils.services_auth import decode_auth

# Initialize logging
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)


class LLMProxyError(Exception):
    """Base class for LLM proxy errors."""


class LLMProxyAuthError(LLMProxyError):
    """Raised when authentication fails."""


class LLMProxyRequestError(LLMProxyError):
    """Raised when request to provider fails."""


class LLMProxyService:
    """Service for proxying LLM requests to configured providers.

    Handles request translation, streaming, and response formatting
    for the internal /v1/chat/completions endpoint.
    """

    def __init__(self) -> None:
        """Initialize the LLM proxy service."""
        self._initialized = False
        self._client: Optional[httpx.AsyncClient] = None

    async def initialize(self) -> None:
        """Initialize the proxy service and HTTP client."""
        if not self._initialized:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(settings.llm_request_timeout, connect=30.0),
                limits=httpx.Limits(
                    max_connections=settings.httpx_max_connections,
                    max_keepalive_connections=settings.httpx_max_keepalive_connections,
                    keepalive_expiry=settings.httpx_keepalive_expiry,
                ),
                verify=not settings.skip_ssl_verify,
            )
            logger.info("Initialized LLM Proxy Service")
            self._initialized = True

    async def shutdown(self) -> None:
        """Shutdown the proxy service and close connections."""
        if self._initialized and self._client:
            await self._client.aclose()
            self._client = None
            logger.info("Shutdown LLM Proxy Service")
            self._initialized = False

    def _resolve_model(
        self,
        db: Session,
        model_id: str,
    ) -> Tuple[LLMProvider, LLMModel]:
        """Resolve a model ID to provider and model.

        Args:
            db: Database session.
            model_id: Model ID (can be model.id, model.model_id, or model.model_alias).

        Returns:
            Tuple of (LLMProvider, LLMModel).

        Raises:
            LLMModelNotFoundError: If model not found.
            LLMProviderNotFoundError: If provider not found or disabled.
        """
        # Try to find by model.id first
        model = db.execute(select(LLMModel).where(LLMModel.id == model_id)).scalar_one_or_none()

        # Try by model_id
        if not model:
            model = db.execute(select(LLMModel).where(LLMModel.model_id == model_id)).scalar_one_or_none()

        # Try by model_alias
        if not model:
            model = db.execute(select(LLMModel).where(LLMModel.model_alias == model_id)).scalar_one_or_none()

        if not model:
            raise LLMModelNotFoundError(f"Model not found: {model_id}")

        if not model.enabled:
            raise LLMModelNotFoundError(f"Model is disabled: {model_id}")

        # Get provider
        provider = db.execute(select(LLMProvider).where(LLMProvider.id == model.provider_id)).scalar_one_or_none()

        if not provider:
            raise LLMProviderNotFoundError(f"Provider not found for model: {model_id}")

        if not provider.enabled:
            raise LLMProviderNotFoundError(f"Provider is disabled: {provider.name}")

        return provider, model
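
    # Illustrative note: a request whose "model" field is, say, "llama3" is matched
    # against LLMModel.id, then LLMModel.model_id, then LLMModel.model_alias; the first
    # row found is used, and a disabled model or provider raises an error rather than
    # falling through to the next lookup. "llama3" is a placeholder, not a value
    # defined in this module.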

    def _get_api_key(self, provider: LLMProvider) -> Optional[str]:
        """Extract API key from provider.

        Args:
            provider: LLM provider instance.

        Returns:
            Decrypted API key or None.
        """
        if not provider.api_key:
            return None

        try:
            auth_data = decode_auth(provider.api_key)
            return auth_data.get("api_key")
        except Exception as e:
            logger.error(f"Failed to decode API key for provider {provider.name}: {e}")
            return None

    def _build_openai_request(
        self,
        request: ChatCompletionRequest,
        provider: LLMProvider,
        model: LLMModel,
    ) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
        """Build request for OpenAI-compatible providers.

        Args:
            request: Chat completion request.
            provider: LLM provider.
            model: LLM model.

        Returns:
            Tuple of (url, headers, body).
        """
        api_key = self._get_api_key(provider)
        base_url = provider.api_base or "https://api.openai.com/v1"

        url = f"{base_url.rstrip('/')}/chat/completions"

        headers = {
            "Content-Type": "application/json",
        }
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        # Build request body
        body: Dict[str, Any] = {
            "model": model.model_id,
            "messages": [msg.model_dump(exclude_none=True) for msg in request.messages],
        }

        # Add optional parameters
        if request.temperature is not None:
            body["temperature"] = request.temperature
        elif provider.default_temperature:
            body["temperature"] = provider.default_temperature

        if request.max_tokens is not None:
            body["max_tokens"] = request.max_tokens
        elif provider.default_max_tokens:
            body["max_tokens"] = provider.default_max_tokens

        if request.stream:
            body["stream"] = True

        if request.tools:
            body["tools"] = [t.model_dump() for t in request.tools]
        if request.tool_choice:
            body["tool_choice"] = request.tool_choice
        if request.top_p is not None:
            body["top_p"] = request.top_p
        if request.frequency_penalty is not None:
            body["frequency_penalty"] = request.frequency_penalty
        if request.presence_penalty is not None:
            body["presence_penalty"] = request.presence_penalty
        if request.stop:
            body["stop"] = request.stop

        return url, headers, body
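
    # Illustrative example (not executed): for a provider whose api_base is
    # "https://api.openai.com/v1" and a model whose model_id is "gpt-4o", the method
    # above would return roughly:
    #   url     = "https://api.openai.com/v1/chat/completions"
    #   headers = {"Content-Type": "application/json", "Authorization": "Bearer <decrypted key>"}
    #   body    = {"model": "gpt-4o", "messages": [...], "stream": True, ...}
    # "gpt-4o" is a placeholder model name, not a value defined in this module.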

    def _build_azure_request(
        self,
        request: ChatCompletionRequest,
        provider: LLMProvider,
        model: LLMModel,
    ) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
        """Build request for Azure OpenAI.

        Args:
            request: Chat completion request.
            provider: LLM provider.
            model: LLM model.

        Returns:
            Tuple of (url, headers, body).
        """
        api_key = self._get_api_key(provider)

        # Get Azure-specific config
        deployment_name = provider.config.get("deployment_name") or provider.config.get("deployment") or model.model_id
        resource_name = provider.config.get("resource_name", "")
        api_version = provider.config.get("api_version") or provider.api_version or "2024-02-15-preview"

        # Build base URL from resource name if not provided
        if not provider.api_base and resource_name:
            base_url = f"https://{resource_name}.openai.azure.com"
        else:
            base_url = provider.api_base or ""

        url = f"{base_url.rstrip('/')}/openai/deployments/{deployment_name}/chat/completions?api-version={api_version}"

        headers = {
            "Content-Type": "application/json",
            "api-key": api_key or "",
        }

        # Build request body (similar to OpenAI)
        body: Dict[str, Any] = {
            "messages": [msg.model_dump(exclude_none=True) for msg in request.messages],
        }

        if request.temperature is not None:
            body["temperature"] = request.temperature
        elif provider.default_temperature:
            body["temperature"] = provider.default_temperature

        if request.max_tokens is not None:
            body["max_tokens"] = request.max_tokens
        elif provider.default_max_tokens:
            body["max_tokens"] = provider.default_max_tokens

        if request.stream:
            body["stream"] = True

        return url, headers, body
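
    # Illustrative example (not executed): with config
    # {"resource_name": "my-resource", "deployment_name": "my-deployment"} and no
    # api_base, the method above would target a URL of the form
    #   https://my-resource.openai.azure.com/openai/deployments/my-deployment/chat/completions?api-version=2024-02-15-preview
    # "my-resource" and "my-deployment" are placeholder names. Note that Azure
    # authenticates via the "api-key" header rather than a Bearer token.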

    def _build_anthropic_request(
        self,
        request: ChatCompletionRequest,
        provider: LLMProvider,
        model: LLMModel,
    ) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
        """Build request for Anthropic Claude.

        Args:
            request: Chat completion request.
            provider: LLM provider.
            model: LLM model.

        Returns:
            Tuple of (url, headers, body).
        """
        api_key = self._get_api_key(provider)
        base_url = provider.api_base or "https://api.anthropic.com"

        url = f"{base_url.rstrip('/')}/v1/messages"

        # Get Anthropic-specific config
        anthropic_version = provider.config.get("anthropic_version") or provider.api_version or "2023-06-01"

        headers = {
            "Content-Type": "application/json",
            "x-api-key": api_key or "",
            "anthropic-version": anthropic_version,
        }

        # Convert messages to Anthropic format
        system_message = None
        messages = []
        for msg in request.messages:
            if msg.role == "system":
                system_message = msg.content
            else:
                messages.append(
                    {
                        "role": msg.role,
                        "content": msg.content or "",
                    }
                )

        body: Dict[str, Any] = {
            "model": model.model_id,
            "messages": messages,
            "max_tokens": request.max_tokens or provider.default_max_tokens or 4096,
        }

        if system_message:
            body["system"] = system_message

        if request.temperature is not None:
            body["temperature"] = request.temperature
        elif provider.default_temperature:
            body["temperature"] = provider.default_temperature

        if request.stream:
            body["stream"] = True

        return url, headers, body
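
    # Illustrative example (not executed): an OpenAI-style message list such as
    #   [{"role": "system", "content": "Be brief."}, {"role": "user", "content": "Hi"}]
    # is translated by the method above into an Anthropic body along the lines of
    #   {"model": "<model_id>", "system": "Be brief.",
    #    "messages": [{"role": "user", "content": "Hi"}], "max_tokens": 4096}
    # The system prompt moves to the top-level "system" field, and max_tokens is
    # always set because the Messages API expects it.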

    def _build_ollama_request(
        self,
        request: ChatCompletionRequest,
        provider: LLMProvider,
        model: LLMModel,
    ) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
        """Build request for Ollama.

        Args:
            request: Chat completion request.
            provider: LLM provider.
            model: LLM model.

        Returns:
            Tuple of (url, headers, body).
        """
        base_url = provider.api_base or "http://localhost:11434"
        base_url = base_url.rstrip("/")

        # Check if using OpenAI-compatible endpoint
        if base_url.endswith("/v1"):
            # Use OpenAI-compatible API
            url = f"{base_url}/chat/completions"
            headers = {"Content-Type": "application/json"}
            body: Dict[str, Any] = {
                "model": model.model_id,
                "messages": [{"role": msg.role, "content": msg.content or ""} for msg in request.messages],
                "stream": request.stream,
            }
            if request.temperature is not None:
                body["temperature"] = request.temperature
            elif provider.default_temperature:
                body["temperature"] = provider.default_temperature
            if request.max_tokens:
                body["max_tokens"] = request.max_tokens
            elif provider.default_max_tokens:
                body["max_tokens"] = provider.default_max_tokens
        else:
            # Use native Ollama API
            url = f"{base_url}/api/chat"
            headers = {"Content-Type": "application/json"}
            body = {
                "model": model.model_id,
                "messages": [{"role": msg.role, "content": msg.content or ""} for msg in request.messages],
                "stream": request.stream,
            }
            options = {}
            if request.temperature is not None:
                options["temperature"] = request.temperature
            elif provider.default_temperature:
                options["temperature"] = provider.default_temperature
            if options:
                body["options"] = options

        return url, headers, body
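
    # Illustrative note: the endpoint choice depends only on the configured api_base.
    #   api_base = "http://localhost:11434"    -> POST .../api/chat (native API; sampling params go under "options")
    #   api_base = "http://localhost:11434/v1" -> POST .../v1/chat/completions (OpenAI-compatible API; params at top level)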

    async def chat_completion(
        self,
        db: Session,
        request: ChatCompletionRequest,
    ) -> ChatCompletionResponse:
        """Process a chat completion request (non-streaming).

        Args:
            db: Database session.
            request: Chat completion request.

        Returns:
            ChatCompletionResponse.

        Raises:
            LLMProxyRequestError: If request fails.
        """
        if not self._client:
            await self.initialize()

        provider, model = self._resolve_model(db, request.model)

        # Build request based on provider type
        if provider.provider_type == LLMProviderType.AZURE_OPENAI:
            url, headers, body = self._build_azure_request(request, provider, model)
        elif provider.provider_type == LLMProviderType.ANTHROPIC:
            url, headers, body = self._build_anthropic_request(request, provider, model)
        elif provider.provider_type == LLMProviderType.OLLAMA:
            url, headers, body = self._build_ollama_request(request, provider, model)
        else:
            # Default to OpenAI-compatible
            url, headers, body = self._build_openai_request(request, provider, model)

        # Ensure non-streaming
        body["stream"] = False

        try:
            response = await self._client.post(url, headers=headers, json=body)
            response.raise_for_status()
            data = response.json()

            # Transform response based on provider
            if provider.provider_type == LLMProviderType.ANTHROPIC:
                return self._transform_anthropic_response(data, model.model_id)
            if provider.provider_type == LLMProviderType.OLLAMA:
                # Check if using OpenAI-compatible endpoint
                base_url = (provider.api_base or "").rstrip("/")
                if base_url.endswith("/v1"):
                    return self._transform_openai_response(data)
                return self._transform_ollama_response(data, model.model_id)
            return self._transform_openai_response(data)

        except httpx.HTTPStatusError as e:
            logger.error(f"LLM request failed: {e.response.status_code} - {e.response.text}")
            raise LLMProxyRequestError(f"Request failed: {e.response.status_code}")
        except httpx.RequestError as e:
            logger.error(f"LLM request error: {e}")
            raise LLMProxyRequestError(f"Connection error: {str(e)}")
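
    # Usage sketch (illustrative; assumes an async caller holding a SQLAlchemy Session
    # named `db`, and that ChatCompletionRequest/ChatMessage accept the keyword fields
    # shown; the model name and prompt are placeholders):
    #
    #   proxy = LLMProxyService()
    #   await proxy.initialize()
    #   response = await proxy.chat_completion(
    #       db,
    #       ChatCompletionRequest(model="gpt-4o", messages=[ChatMessage(role="user", content="Hello")]),
    #   )
    #   print(response.choices[0].message.content)
    #   await proxy.shutdown()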

    async def chat_completion_stream(
        self,
        db: Session,
        request: ChatCompletionRequest,
    ) -> AsyncGenerator[str, None]:
        """Process a streaming chat completion request.

        Args:
            db: Database session.
            request: Chat completion request.

        Yields:
            SSE-formatted string chunks.
        """
        if not self._client:
            await self.initialize()

        provider, model = self._resolve_model(db, request.model)

        # Build request based on provider type
        if provider.provider_type == LLMProviderType.AZURE_OPENAI:
            url, headers, body = self._build_azure_request(request, provider, model)
        elif provider.provider_type == LLMProviderType.ANTHROPIC:
            url, headers, body = self._build_anthropic_request(request, provider, model)
        elif provider.provider_type == LLMProviderType.OLLAMA:
            url, headers, body = self._build_ollama_request(request, provider, model)
        else:
            url, headers, body = self._build_openai_request(request, provider, model)

        # Ensure streaming
        body["stream"] = True

        response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
        created = int(time.time())

        try:
            async with self._client.stream("POST", url, headers=headers, json=body) as response:
                response.raise_for_status()

                async for line in response.aiter_lines():
                    if not line:
                        continue

                    # Handle SSE format
                    if line.startswith("data: "):
                        data_str = line[6:]
                        if data_str.strip() == "[DONE]":
                            yield "data: [DONE]\n\n"
                            break

                        try:
                            data = orjson.loads(data_str)

                            # Transform based on provider
                            if provider.provider_type == LLMProviderType.ANTHROPIC:
                                chunk = self._transform_anthropic_stream_chunk(data, response_id, created, model.model_id)
                            elif provider.provider_type == LLMProviderType.OLLAMA:
                                # Check if using OpenAI-compatible endpoint
                                base_url = (provider.api_base or "").rstrip("/")
                                if base_url.endswith("/v1"):
                                    chunk = data_str  # Already OpenAI format
                                else:
                                    chunk = self._transform_ollama_stream_chunk(data, response_id, created, model.model_id)
                            else:
                                chunk = data_str

                            if chunk:
                                yield f"data: {chunk}\n\n"

                        except orjson.JSONDecodeError:
                            continue

                    # Handle Ollama's newline-delimited JSON (native API only)
                    elif provider.provider_type == LLMProviderType.OLLAMA:
                        base_url = (provider.api_base or "").rstrip("/")
                        if not base_url.endswith("/v1"):
                            try:
                                data = orjson.loads(line)
                                chunk = self._transform_ollama_stream_chunk(data, response_id, created, model.model_id)
                                if chunk:
                                    yield f"data: {chunk}\n\n"
                            except orjson.JSONDecodeError:
                                continue

        except httpx.HTTPStatusError as e:
            error_chunk = {
                "error": {
                    "message": f"Request failed: {e.response.status_code}",
                    "type": "proxy_error",
                }
            }
            yield f"data: {orjson.dumps(error_chunk).decode()}\n\n"
        except httpx.RequestError as e:
            error_chunk = {
                "error": {
                    "message": f"Connection error: {str(e)}",
                    "type": "proxy_error",
                }
            }
            yield f"data: {orjson.dumps(error_chunk).decode()}\n\n"
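
    # Consumption sketch (illustrative; the FastAPI wiring is an assumption about how
    # a caller might expose this generator, not something defined in this module):
    #
    #   from fastapi.responses import StreamingResponse
    #
    #   return StreamingResponse(
    #       proxy.chat_completion_stream(db, request),
    #       media_type="text/event-stream",
    #   )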

    def _transform_openai_response(self, data: Dict[str, Any]) -> ChatCompletionResponse:
        """Transform OpenAI response to standard format.

        Args:
            data: Raw OpenAI API response data.

        Returns:
            ChatCompletionResponse in standard format.
        """
        choices = []
        for choice in data.get("choices", []):
            message_data = choice.get("message", {})
            choices.append(
                ChatChoice(
                    index=choice.get("index", 0),
                    message=ChatMessage(
                        role=message_data.get("role", "assistant"),
                        content=message_data.get("content"),
                        tool_calls=message_data.get("tool_calls"),
                    ),
                    finish_reason=choice.get("finish_reason"),
                )
            )

        usage_data = data.get("usage", {})
        usage = UsageStats(
            prompt_tokens=usage_data.get("prompt_tokens", 0),
            completion_tokens=usage_data.get("completion_tokens", 0),
            total_tokens=usage_data.get("total_tokens", 0),
        )

        return ChatCompletionResponse(
            id=data.get("id", f"chatcmpl-{uuid.uuid4().hex[:24]}"),
            created=data.get("created", int(time.time())),
            model=data.get("model", "unknown"),
            choices=choices,
            usage=usage,
        )

    def _transform_anthropic_response(
        self,
        data: Dict[str, Any],
        model_id: str,
    ) -> ChatCompletionResponse:
        """Transform Anthropic response to OpenAI format.

        Args:
            data: Raw Anthropic API response data.
            model_id: Model ID to include in response.

        Returns:
            ChatCompletionResponse in OpenAI format.
        """
        content = ""
        for block in data.get("content", []):
            if block.get("type") == "text":
                content += block.get("text", "")

        usage_data = data.get("usage", {})

        return ChatCompletionResponse(
            id=data.get("id", f"chatcmpl-{uuid.uuid4().hex[:24]}"),
            created=int(time.time()),
            model=model_id,
            choices=[
                ChatChoice(
                    index=0,
                    message=ChatMessage(role="assistant", content=content),
                    finish_reason=data.get("stop_reason", "stop"),
                )
            ],
            usage=UsageStats(
                prompt_tokens=usage_data.get("input_tokens", 0),
                completion_tokens=usage_data.get("output_tokens", 0),
                total_tokens=usage_data.get("input_tokens", 0) + usage_data.get("output_tokens", 0),
            ),
        )
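
    # Illustrative field mapping (Anthropic usage -> OpenAI-style usage):
    #   input_tokens  -> prompt_tokens
    #   output_tokens -> completion_tokens
    #   total_tokens   = input_tokens + output_tokens (computed here; Anthropic does not return a total)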

    def _transform_ollama_response(
        self,
        data: Dict[str, Any],
        model_id: str,
    ) -> ChatCompletionResponse:
        """Transform Ollama response to OpenAI format.

        Args:
            data: Raw Ollama API response data.
            model_id: Model ID to include in response.

        Returns:
            ChatCompletionResponse in OpenAI format.
        """
        message = data.get("message", {})

        return ChatCompletionResponse(
            id=f"chatcmpl-{uuid.uuid4().hex[:24]}",
            created=int(time.time()),
            model=model_id,
            choices=[
                ChatChoice(
                    index=0,
                    message=ChatMessage(
                        role=message.get("role", "assistant"),
                        content=message.get("content", ""),
                    ),
                    finish_reason="stop" if data.get("done") else None,
                )
            ],
            usage=UsageStats(
                prompt_tokens=data.get("prompt_eval_count", 0),
                completion_tokens=data.get("eval_count", 0),
                total_tokens=data.get("prompt_eval_count", 0) + data.get("eval_count", 0),
            ),
        )
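
    # Illustrative field mapping (Ollama native response -> OpenAI-style usage):
    #   prompt_eval_count -> prompt_tokens
    #   eval_count        -> completion_tokens
    #   total_tokens       = prompt_eval_count + eval_count (computed here)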

    def _transform_anthropic_stream_chunk(
        self,
        data: Dict[str, Any],
        response_id: str,
        created: int,
        model_id: str,
    ) -> Optional[str]:
        """Transform Anthropic streaming chunk to OpenAI format.

        Args:
            data: Raw Anthropic streaming event data.
            response_id: Response ID for the chunk.
            created: Timestamp for the response.
            model_id: Model ID to include in response.

        Returns:
            JSON string chunk in OpenAI format, or None if not applicable.
        """
        event_type = data.get("type")

        if event_type == "content_block_delta":
            delta = data.get("delta", {})
            if delta.get("type") == "text_delta":
                chunk = {
                    "id": response_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model_id,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {"content": delta.get("text", "")},
                            "finish_reason": None,
                        }
                    ],
                }
                return orjson.dumps(chunk).decode()

        elif event_type == "message_stop":
            chunk = {
                "id": response_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model_id,
                "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
            }
            return orjson.dumps(chunk).decode()

        return None
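
    # Illustrative example (not executed): an Anthropic SSE event such as
    #   {"type": "content_block_delta", "delta": {"type": "text_delta", "text": "Hel"}}
    # becomes an OpenAI-style chunk along the lines of
    #   {"id": "chatcmpl-...", "object": "chat.completion.chunk", "created": ...,
    #    "model": "<model_id>", "choices": [{"index": 0, "delta": {"content": "Hel"}, "finish_reason": null}]}
    # Other event types (e.g. "message_start", "ping") yield None and are dropped upstream.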

    def _transform_ollama_stream_chunk(
        self,
        data: Dict[str, Any],
        response_id: str,
        created: int,
        model_id: str,
    ) -> Optional[str]:
        """Transform Ollama streaming chunk to OpenAI format.

        Args:
            data: Raw Ollama streaming event data.
            response_id: Response ID for the chunk.
            created: Timestamp for the response.
            model_id: Model ID to include in response.

        Returns:
            JSON string chunk in OpenAI format, or None if not applicable.
        """
        message = data.get("message", {})
        content = message.get("content", "")

        if data.get("done"):
            chunk = {
                "id": response_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model_id,
                "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
            }
        else:
            chunk = {
                "id": response_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model_id,
                "choices": [
                    {
                        "index": 0,
                        "delta": {"content": content} if content else {},
                        "finish_reason": None,
                    }
                ],
            }

        return orjson.dumps(chunk).decode()