Coverage for mcpgateway/services/llm_proxy_service.py: 100% (293 statements)
coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
# -*- coding: utf-8 -*-
"""Location: ./mcpgateway/services/llm_proxy_service.py
Copyright 2025
SPDX-License-Identifier: Apache-2.0

LLM Proxy Service

This module implements the internal proxy for routing LLM requests
to configured providers. It translates requests to provider-specific
formats and handles streaming responses.
"""

# Standard
import time
from typing import Any, AsyncGenerator, Dict, Optional, Tuple
import uuid

# Third-Party
import httpx
import orjson
from sqlalchemy import select
from sqlalchemy.orm import Session

# First-Party
from mcpgateway.common.validators import SecurityValidator
from mcpgateway.config import settings
from mcpgateway.db import LLMModel, LLMProvider, LLMProviderType
from mcpgateway.llm_schemas import (
    ChatChoice,
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatMessage,
    UsageStats,
)
from mcpgateway.services.llm_provider_service import (
    decrypt_provider_config_for_runtime,
    LLMModelNotFoundError,
    LLMProviderNotFoundError,
)
from mcpgateway.services.logging_service import LoggingService
from mcpgateway.utils.services_auth import decode_auth

# Initialize logging
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

class LLMProxyError(Exception):
    """Base class for LLM proxy errors."""


class LLMProxyAuthError(LLMProxyError):
    """Raised when authentication fails."""


class LLMProxyRequestError(LLMProxyError):
    """Raised when request to provider fails."""


class LLMProxyService:
    """Service for proxying LLM requests to configured providers.

    Handles request translation, streaming, and response formatting
    for the internal /v1/chat/completions endpoint.
    """
    def __init__(self) -> None:
        """Initialize the LLM proxy service."""
        self._initialized = False
        self._client: Optional[httpx.AsyncClient] = None

    async def initialize(self) -> None:
        """Initialize the proxy service and HTTP client."""
        if not self._initialized:
            self._client = httpx.AsyncClient(
                timeout=httpx.Timeout(settings.llm_request_timeout, connect=30.0),
                limits=httpx.Limits(
                    max_connections=settings.httpx_max_connections,
                    max_keepalive_connections=settings.httpx_max_keepalive_connections,
                    keepalive_expiry=settings.httpx_keepalive_expiry,
                ),
                verify=not settings.skip_ssl_verify,
            )
            logger.info("Initialized LLM Proxy Service")
            self._initialized = True

    async def shutdown(self) -> None:
        """Shut down the proxy service and close connections."""
        if self._initialized and self._client:
            await self._client.aclose()
            self._client = None
            logger.info("Shutdown LLM Proxy Service")
            self._initialized = False
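
    # Lifecycle sketch (illustrative, not part of this module): the service is meant
    # to be initialized once at application startup and shut down on exit, for
    # example from an application lifespan hook. Variable names below are assumed.
    #
    #     llm_proxy = LLMProxyService()
    #     await llm_proxy.initialize()    # creates the shared httpx.AsyncClient pool
    #     try:
    #         ...                         # serve /v1/chat/completions traffic
    #     finally:
    #         await llm_proxy.shutdown()  # closes pooled connections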

    def _resolve_model(
        self,
        db: Session,
        model_id: str,
    ) -> Tuple[LLMProvider, LLMModel]:
        """Resolve a model ID to provider and model.

        Args:
            db: Database session.
            model_id: Model ID (can be model.id, model.model_id, or model.model_alias).

        Returns:
            Tuple of (LLMProvider, LLMModel).

        Raises:
            LLMModelNotFoundError: If model not found.
            LLMProviderNotFoundError: If provider not found or disabled.
        """
        # Try to find by model.id first
        model = db.execute(select(LLMModel).where(LLMModel.id == model_id)).scalar_one_or_none()

        # Try by model_id
        if not model:
            model = db.execute(select(LLMModel).where(LLMModel.model_id == model_id)).scalar_one_or_none()

        # Try by model_alias
        if not model:
            model = db.execute(select(LLMModel).where(LLMModel.model_alias == model_id)).scalar_one_or_none()

        if not model:
            raise LLMModelNotFoundError(f"Model not found: {model_id}")

        if not model.enabled:
            raise LLMModelNotFoundError(f"Model is disabled: {model_id}")

        # Get provider
        provider = db.execute(select(LLMProvider).where(LLMProvider.id == model.provider_id)).scalar_one_or_none()

        if not provider:
            raise LLMProviderNotFoundError(f"Provider not found for model: {model_id}")

        if not provider.enabled:
            raise LLMProviderNotFoundError(f"Provider is disabled: {provider.name}")

        return provider, model
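
    # Lookup order illustrated with a hypothetical row: for an LLMModel stored with
    # id="a1b2c3", model_id="gpt-4o-mini", model_alias="fast", any of
    #
    #     self._resolve_model(db, "a1b2c3")       # matches LLMModel.id
    #     self._resolve_model(db, "gpt-4o-mini")  # matches LLMModel.model_id
    #     self._resolve_model(db, "fast")         # matches LLMModel.model_alias
    #
    # returns the same (provider, model) pair, while a disabled model or provider
    # raises LLMModelNotFoundError / LLMProviderNotFoundError instead.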

    def _get_api_key(self, provider: LLMProvider) -> Optional[str]:
        """Extract API key from provider.

        Args:
            provider: LLM provider instance.

        Returns:
            Decrypted API key or None.
        """
        if not provider.api_key:
            return None

        try:
            auth_data = decode_auth(provider.api_key)
            return auth_data.get("api_key")
        except Exception as e:
            logger.error(f"Failed to decode API key for provider {provider.name}: {e}")
            return None
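
    # Expected shape (illustrative): provider.api_key holds an encrypted auth blob
    # that decode_auth() turns into a dict such as {"api_key": "sk-..."}; only the
    # "api_key" entry is used here, and any decode failure is logged and treated
    # as "no key configured".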

    def _build_openai_request(
        self,
        request: ChatCompletionRequest,
        provider: LLMProvider,
        model: LLMModel,
    ) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
        """Build request for OpenAI-compatible providers.

        Args:
            request: Chat completion request.
            provider: LLM provider.
            model: LLM model.

        Returns:
            Tuple of (url, headers, body).
        """
        api_key = self._get_api_key(provider)
        base_url = provider.api_base or "https://api.openai.com/v1"

        url = f"{base_url.rstrip('/')}/chat/completions"

        headers = {
            "Content-Type": "application/json",
        }
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        # Build request body
        body: Dict[str, Any] = {
            "model": model.model_id,
            "messages": [msg.model_dump(exclude_none=True) for msg in request.messages],
        }

        # Add optional parameters
        if request.temperature is not None:
            body["temperature"] = request.temperature
        elif provider.default_temperature:
            body["temperature"] = provider.default_temperature

        if request.max_tokens is not None:
            body["max_tokens"] = request.max_tokens
        elif provider.default_max_tokens:
            body["max_tokens"] = provider.default_max_tokens

        if request.stream:
            body["stream"] = True

        if request.tools:
            body["tools"] = [t.model_dump() for t in request.tools]
        if request.tool_choice:
            body["tool_choice"] = request.tool_choice
        if request.top_p is not None:
            body["top_p"] = request.top_p
        if request.frequency_penalty is not None:
            body["frequency_penalty"] = request.frequency_penalty
        if request.presence_penalty is not None:
            body["presence_penalty"] = request.presence_penalty
        if request.stop:
            body["stop"] = request.stop

        return url, headers, body
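
    # Example of the request this builder produces (all values illustrative):
    #
    #     url     = "https://api.openai.com/v1/chat/completions"
    #     headers = {"Content-Type": "application/json",
    #                "Authorization": "Bearer sk-..."}
    #     body    = {"model": "gpt-4o-mini",
    #                "messages": [{"role": "user", "content": "Hello"}],
    #                "temperature": 0.7,
    #                "stream": True}
    #
    # Only parameters present on the request (or defaulted on the provider) are
    # forwarded, so simple calls produce a minimal body.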

    def _build_azure_request(
        self,
        request: ChatCompletionRequest,
        provider: LLMProvider,
        model: LLMModel,
    ) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
        """Build request for Azure OpenAI.

        Args:
            request: Chat completion request.
            provider: LLM provider.
            model: LLM model.

        Returns:
            Tuple of (url, headers, body).
        """
        api_key = self._get_api_key(provider)
        provider_config = decrypt_provider_config_for_runtime(provider.config)

        # Get Azure-specific config
        deployment_name = provider_config.get("deployment_name") or provider_config.get("deployment") or model.model_id
        resource_name = provider_config.get("resource_name", "")
        api_version = provider_config.get("api_version") or provider.api_version or "2024-02-15-preview"

        # Build base URL from resource name if not provided
        if not provider.api_base and resource_name:
            base_url = f"https://{resource_name}.openai.azure.com"
        else:
            base_url = provider.api_base or ""

        url = f"{base_url.rstrip('/')}/openai/deployments/{deployment_name}/chat/completions?api-version={api_version}"

        headers = {
            "Content-Type": "application/json",
            "api-key": api_key or "",
        }

        # Build request body (similar to OpenAI)
        body: Dict[str, Any] = {
            "messages": [msg.model_dump(exclude_none=True) for msg in request.messages],
        }

        if request.temperature is not None:
            body["temperature"] = request.temperature
        elif provider.default_temperature:
            body["temperature"] = provider.default_temperature

        if request.max_tokens is not None:
            body["max_tokens"] = request.max_tokens
        elif provider.default_max_tokens:
            body["max_tokens"] = provider.default_max_tokens

        if request.stream:
            body["stream"] = True

        return url, headers, body
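
    # Example URL shape for Azure (resource and deployment names are made up):
    # given config {"resource_name": "my-resource", "deployment_name": "gpt4o"}
    # and no explicit api_base, the request is sent to
    #
    #     https://my-resource.openai.azure.com/openai/deployments/gpt4o/chat/completions?api-version=2024-02-15-preview
    #
    # and authenticates with the "api-key" header rather than a Bearer token.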

    def _build_anthropic_request(
        self,
        request: ChatCompletionRequest,
        provider: LLMProvider,
        model: LLMModel,
    ) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
        """Build request for Anthropic Claude.

        Args:
            request: Chat completion request.
            provider: LLM provider.
            model: LLM model.

        Returns:
            Tuple of (url, headers, body).
        """
        api_key = self._get_api_key(provider)
        base_url = provider.api_base or "https://api.anthropic.com"
        provider_config = decrypt_provider_config_for_runtime(provider.config)

        url = f"{base_url.rstrip('/')}/v1/messages"

        # Get Anthropic-specific config
        anthropic_version = provider_config.get("anthropic_version") or provider.api_version or "2023-06-01"

        headers = {
            "Content-Type": "application/json",
            "x-api-key": api_key or "",
            "anthropic-version": anthropic_version,
        }

        # Convert messages to Anthropic format
        system_message = None
        messages = []
        for msg in request.messages:
            if msg.role == "system":
                system_message = msg.content
            else:
                messages.append(
                    {
                        "role": msg.role,
                        "content": msg.content or "",
                    }
                )

        body: Dict[str, Any] = {
            "model": model.model_id,
            "messages": messages,
            "max_tokens": request.max_tokens or provider.default_max_tokens or 4096,
        }

        if system_message:
            body["system"] = system_message

        if request.temperature is not None:
            body["temperature"] = request.temperature
        elif provider.default_temperature:
            body["temperature"] = provider.default_temperature

        if request.stream:
            body["stream"] = True

        return url, headers, body
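
    # Message conversion example (illustrative): an OpenAI-style message list
    #
    #     [{"role": "system", "content": "Be terse."},
    #      {"role": "user", "content": "Hi"}]
    #
    # becomes an Anthropic Messages API body with the system prompt hoisted out:
    #
    #     {"model": "<model_id>", "system": "Be terse.",
    #      "messages": [{"role": "user", "content": "Hi"}], "max_tokens": 4096}
    #
    # If several system messages are supplied, only the last one is kept.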

    def _build_ollama_request(
        self,
        request: ChatCompletionRequest,
        provider: LLMProvider,
        model: LLMModel,
    ) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
        """Build request for Ollama.

        Args:
            request: Chat completion request.
            provider: LLM provider.
            model: LLM model.

        Returns:
            Tuple of (url, headers, body).
        """
        base_url = provider.api_base or "http://localhost:11434"
        base_url = base_url.rstrip("/")

        # Check if using OpenAI-compatible endpoint
        if base_url.endswith("/v1"):
            # Use OpenAI-compatible API
            url = f"{base_url}/chat/completions"
            headers = {"Content-Type": "application/json"}
            body: Dict[str, Any] = {
                "model": model.model_id,
                "messages": [{"role": msg.role, "content": msg.content or ""} for msg in request.messages],
                "stream": request.stream,
            }
            if request.temperature is not None:
                body["temperature"] = request.temperature
            elif provider.default_temperature:
                body["temperature"] = provider.default_temperature
            if request.max_tokens:
                body["max_tokens"] = request.max_tokens
            elif provider.default_max_tokens:
                body["max_tokens"] = provider.default_max_tokens
        else:
            # Use native Ollama API
            url = f"{base_url}/api/chat"
            headers = {"Content-Type": "application/json"}
            body = {
                "model": model.model_id,
                "messages": [{"role": msg.role, "content": msg.content or ""} for msg in request.messages],
                "stream": request.stream,
            }
            options = {}
            if request.temperature is not None:
                options["temperature"] = request.temperature
            elif provider.default_temperature:
                options["temperature"] = provider.default_temperature
            if options:
                body["options"] = options

        return url, headers, body
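
    # Endpoint selection example (illustrative api_base values):
    #
    #     api_base = "http://localhost:11434/v1" -> POST {base}/chat/completions (OpenAI-compatible dialect)
    #     api_base = "http://localhost:11434"    -> POST {base}/api/chat (native dialect, temperature under "options")
    #
    # The trailing "/v1" on the configured base URL is what switches dialects.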

    async def chat_completion(
        self,
        db: Session,
        request: ChatCompletionRequest,
    ) -> ChatCompletionResponse:
        """Process a chat completion request (non-streaming).

        Args:
            db: Database session.
            request: Chat completion request.

        Returns:
            ChatCompletionResponse.

        Raises:
            LLMProxyRequestError: If request fails.
        """
        if not self._client:
            await self.initialize()

        provider, model = self._resolve_model(db, request.model)

        # Build request based on provider type
        if provider.provider_type == LLMProviderType.AZURE_OPENAI:
            url, headers, body = self._build_azure_request(request, provider, model)
        elif provider.provider_type == LLMProviderType.ANTHROPIC:
            url, headers, body = self._build_anthropic_request(request, provider, model)
        elif provider.provider_type == LLMProviderType.OLLAMA:
            url, headers, body = self._build_ollama_request(request, provider, model)
        else:
            # Default to OpenAI-compatible
            url, headers, body = self._build_openai_request(request, provider, model)

        # Ensure non-streaming
        body["stream"] = False

        # Validate the constructed URL to prevent SSRF attacks
        try:
            SecurityValidator.validate_url(url, "LLM provider URL")
        except ValueError as url_err:
            raise LLMProxyRequestError(f"Invalid LLM provider URL: {url_err}") from url_err

        try:
            response = await self._client.post(url, headers=headers, json=body)
            response.raise_for_status()
            data = response.json()

            # Transform response based on provider
            if provider.provider_type == LLMProviderType.ANTHROPIC:
                return self._transform_anthropic_response(data, model.model_id)
            if provider.provider_type == LLMProviderType.OLLAMA:
                # Check if using OpenAI-compatible endpoint
                base_url = (provider.api_base or "").rstrip("/")
                if base_url.endswith("/v1"):
                    return self._transform_openai_response(data)
                return self._transform_ollama_response(data, model.model_id)
            return self._transform_openai_response(data)

        except httpx.HTTPStatusError as e:
            logger.error(f"LLM request failed: {e.response.status_code} - {e.response.text}")
            raise LLMProxyRequestError(f"Request failed: {e.response.status_code}") from e
        except httpx.RequestError as e:
            logger.error(f"LLM request error: {e}")
            raise LLMProxyRequestError(f"Connection error: {str(e)}") from e
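
    # Minimal non-streaming usage sketch (caller-side names are assumptions):
    #
    #     request = ChatCompletionRequest(
    #         model="gpt-4o-mini",
    #         messages=[ChatMessage(role="user", content="Say hello")],
    #     )
    #     response = await proxy.chat_completion(db, request)
    #     print(response.choices[0].message.content)
    #
    # Whatever the upstream provider, the caller always gets back an OpenAI-style
    # ChatCompletionResponse.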

    async def chat_completion_stream(
        self,
        db: Session,
        request: ChatCompletionRequest,
    ) -> AsyncGenerator[str, None]:
        """Process a streaming chat completion request.

        Args:
            db: Database session.
            request: Chat completion request.

        Yields:
            SSE-formatted string chunks.

        Raises:
            LLMProxyRequestError: If request fails.
        """
        if not self._client:
            await self.initialize()

        provider, model = self._resolve_model(db, request.model)

        # Build request based on provider type
        if provider.provider_type == LLMProviderType.AZURE_OPENAI:
            url, headers, body = self._build_azure_request(request, provider, model)
        elif provider.provider_type == LLMProviderType.ANTHROPIC:
            url, headers, body = self._build_anthropic_request(request, provider, model)
        elif provider.provider_type == LLMProviderType.OLLAMA:
            url, headers, body = self._build_ollama_request(request, provider, model)
        else:
            url, headers, body = self._build_openai_request(request, provider, model)

        # Ensure streaming
        body["stream"] = True

        # Validate the constructed URL to prevent SSRF attacks
        try:
            SecurityValidator.validate_url(url, "LLM provider URL")
        except ValueError as url_err:
            raise LLMProxyRequestError(f"Invalid LLM provider URL: {url_err}") from url_err

        response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
        created = int(time.time())

        try:
            async with self._client.stream("POST", url, headers=headers, json=body) as response:
                response.raise_for_status()

                async for line in response.aiter_lines():
                    if not line:
                        continue

                    # Handle SSE format
                    if line.startswith("data:"):
                        data_str = line[5:]
                        if data_str.startswith(" "):
                            data_str = data_str[1:]
                        if data_str.strip() == "[DONE]":
                            yield "data: [DONE]\n\n"
                            break

                        try:
                            data = orjson.loads(data_str)

                            # Transform based on provider
                            if provider.provider_type == LLMProviderType.ANTHROPIC:
                                chunk = self._transform_anthropic_stream_chunk(data, response_id, created, model.model_id)
                            elif provider.provider_type == LLMProviderType.OLLAMA:
                                # Check if using OpenAI-compatible endpoint
                                base_url = (provider.api_base or "").rstrip("/")
                                if base_url.endswith("/v1"):
                                    chunk = data_str  # Already OpenAI format
                                else:
                                    chunk = self._transform_ollama_stream_chunk(data, response_id, created, model.model_id)
                            else:
                                chunk = data_str

                            if chunk:
                                yield f"data: {chunk}\n\n"

                        except orjson.JSONDecodeError:
                            continue

                    # Handle Ollama's newline-delimited JSON (native API only)
                    elif provider.provider_type == LLMProviderType.OLLAMA:
                        base_url = (provider.api_base or "").rstrip("/")
                        if not base_url.endswith("/v1"):
                            try:
                                data = orjson.loads(line)
                                chunk = self._transform_ollama_stream_chunk(data, response_id, created, model.model_id)
                                if chunk:
                                    yield f"data: {chunk}\n\n"
                            except orjson.JSONDecodeError:
                                continue

        except httpx.HTTPStatusError as e:
            error_chunk = {
                "error": {
                    "message": f"Request failed: {e.response.status_code}",
                    "type": "proxy_error",
                }
            }
            yield f"data: {orjson.dumps(error_chunk).decode()}\n\n"
        except httpx.RequestError as e:
            error_chunk = {
                "error": {
                    "message": f"Connection error: {str(e)}",
                    "type": "proxy_error",
                }
            }
            yield f"data: {orjson.dumps(error_chunk).decode()}\n\n"
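
    # Streaming usage sketch: the generator yields Server-Sent Events strings that
    # can be forwarded verbatim to the HTTP response (for example via a streaming
    # response object). Illustrative only; send_to_client is a placeholder.
    #
    #     async for sse in proxy.chat_completion_stream(db, request):
    #         # each item looks like 'data: {"object": "chat.completion.chunk", ...}\n\n'
    #         # and the stream ends with 'data: [DONE]\n\n'
    #         send_to_client(sse)
    #
    # Upstream failures are not raised mid-stream; they are emitted as a final
    # 'data: {"error": {...}}' event so clients can surface them.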

    def _transform_openai_response(self, data: Dict[str, Any]) -> ChatCompletionResponse:
        """Transform OpenAI response to standard format.

        Args:
            data: Raw OpenAI API response data.

        Returns:
            ChatCompletionResponse in standard format.
        """
        choices = []
        for choice in data.get("choices", []):
            message_data = choice.get("message", {})
            choices.append(
                ChatChoice(
                    index=choice.get("index", 0),
                    message=ChatMessage(
                        role=message_data.get("role", "assistant"),
                        content=message_data.get("content"),
                        tool_calls=message_data.get("tool_calls"),
                    ),
                    finish_reason=choice.get("finish_reason"),
                )
            )

        usage_data = data.get("usage", {})
        usage = UsageStats(
            prompt_tokens=usage_data.get("prompt_tokens", 0),
            completion_tokens=usage_data.get("completion_tokens", 0),
            total_tokens=usage_data.get("total_tokens", 0),
        )

        return ChatCompletionResponse(
            id=data.get("id", f"chatcmpl-{uuid.uuid4().hex[:24]}"),
            created=data.get("created", int(time.time())),
            model=data.get("model", "unknown"),
            choices=choices,
            usage=usage,
        )

    def _transform_anthropic_response(
        self,
        data: Dict[str, Any],
        model_id: str,
    ) -> ChatCompletionResponse:
        """Transform Anthropic response to OpenAI format.

        Args:
            data: Raw Anthropic API response data.
            model_id: Model ID to include in response.

        Returns:
            ChatCompletionResponse in OpenAI format.
        """
        content = ""
        for block in data.get("content", []):
            if block.get("type") == "text":
                content += block.get("text", "")

        usage_data = data.get("usage", {})

        return ChatCompletionResponse(
            id=data.get("id", f"chatcmpl-{uuid.uuid4().hex[:24]}"),
            created=int(time.time()),
            model=model_id,
            choices=[
                ChatChoice(
                    index=0,
                    message=ChatMessage(role="assistant", content=content),
                    finish_reason=data.get("stop_reason", "stop"),
                )
            ],
            usage=UsageStats(
                prompt_tokens=usage_data.get("input_tokens", 0),
                completion_tokens=usage_data.get("output_tokens", 0),
                total_tokens=usage_data.get("input_tokens", 0) + usage_data.get("output_tokens", 0),
            ),
        )
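
    # Field mapping illustrated with a trimmed Anthropic response (example values):
    #
    #     {"content": [{"type": "text", "text": "Hello!"}],
    #      "stop_reason": "end_turn",
    #      "usage": {"input_tokens": 12, "output_tokens": 3}}
    #
    # maps to choices[0].message.content == "Hello!", finish_reason == "end_turn",
    # and usage of 12 prompt + 3 completion = 15 total tokens.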

    def _transform_ollama_response(
        self,
        data: Dict[str, Any],
        model_id: str,
    ) -> ChatCompletionResponse:
        """Transform Ollama response to OpenAI format.

        Args:
            data: Raw Ollama API response data.
            model_id: Model ID to include in response.

        Returns:
            ChatCompletionResponse in OpenAI format.
        """
        message = data.get("message", {})

        return ChatCompletionResponse(
            id=f"chatcmpl-{uuid.uuid4().hex[:24]}",
            created=int(time.time()),
            model=model_id,
            choices=[
                ChatChoice(
                    index=0,
                    message=ChatMessage(
                        role=message.get("role", "assistant"),
                        content=message.get("content", ""),
                    ),
                    finish_reason="stop" if data.get("done") else None,
                )
            ],
            usage=UsageStats(
                prompt_tokens=data.get("prompt_eval_count", 0),
                completion_tokens=data.get("eval_count", 0),
                total_tokens=data.get("prompt_eval_count", 0) + data.get("eval_count", 0),
            ),
        )

    def _transform_anthropic_stream_chunk(
        self,
        data: Dict[str, Any],
        response_id: str,
        created: int,
        model_id: str,
    ) -> Optional[str]:
        """Transform Anthropic streaming chunk to OpenAI format.

        Args:
            data: Raw Anthropic streaming event data.
            response_id: Response ID for the chunk.
            created: Timestamp for the response.
            model_id: Model ID to include in response.

        Returns:
            JSON string chunk in OpenAI format, or None if not applicable.
        """
        event_type = data.get("type")

        if event_type == "content_block_delta":
            delta = data.get("delta", {})
            if delta.get("type") == "text_delta":
                chunk = {
                    "id": response_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": model_id,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {"content": delta.get("text", "")},
                            "finish_reason": None,
                        }
                    ],
                }
                return orjson.dumps(chunk).decode()

        elif event_type == "message_stop":
            chunk = {
                "id": response_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model_id,
                "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
            }
            return orjson.dumps(chunk).decode()

        return None
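
    # Event mapping sketch (example Anthropic SSE payloads): a
    # {"type": "content_block_delta", "delta": {"type": "text_delta", "text": "Hi"}}
    # event becomes an OpenAI-style chunk whose delta is {"content": "Hi"}, while
    # {"type": "message_stop"} becomes the final chunk with finish_reason == "stop".
    # Other event types (message_start, ping, ...) are dropped by returning None.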

    def _transform_ollama_stream_chunk(
        self,
        data: Dict[str, Any],
        response_id: str,
        created: int,
        model_id: str,
    ) -> Optional[str]:
        """Transform Ollama streaming chunk to OpenAI format.

        Args:
            data: Raw Ollama streaming event data.
            response_id: Response ID for the chunk.
            created: Timestamp for the response.
            model_id: Model ID to include in response.

        Returns:
            JSON string chunk in OpenAI format, or None if not applicable.
        """
        message = data.get("message", {})
        content = message.get("content", "")

        if data.get("done"):
            chunk = {
                "id": response_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model_id,
                "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
            }
        else:
            chunk = {
                "id": response_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": model_id,
                "choices": [
                    {
                        "index": 0,
                        "delta": {"content": content} if content else {},
                        "finish_reason": None,
                    }
                ],
            }

        return orjson.dumps(chunk).decode()
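
    # Ollama chunk mapping sketch (example native-API stream lines): an intermediate
    # line
    #
    #     {"message": {"role": "assistant", "content": "Hel"}, "done": false}
    #
    # becomes a chunk whose delta is {"content": "Hel"}, and the closing line with
    # "done": true becomes the final chunk with finish_reason == "stop".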