Coverage for mcpgateway / services / llm_provider_service.py: 99%
384 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/llm_provider_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
6LLM Provider Service
8This module implements LLM provider management for ContextForge.
9It handles provider registration, CRUD operations, model management,
10and health checks for the internal LLM Chat feature.
11"""
13# Standard
14from datetime import datetime, timezone
15import re
16from typing import Any, Dict, List, Optional, Tuple
18# Third-Party
19import httpx
20from sqlalchemy import and_, func, select
21from sqlalchemy.exc import IntegrityError
22from sqlalchemy.orm import Session
24# First-Party
25from mcpgateway.common.validators import SecurityValidator
26from mcpgateway.config import settings
27from mcpgateway.db import LLMModel, LLMProvider, LLMProviderType
28from mcpgateway.llm_provider_configs import PROVIDER_CONFIGS
29from mcpgateway.llm_schemas import (
30 GatewayModelInfo,
31 HealthStatus,
32 LLMModelCreate,
33 LLMModelResponse,
34 LLMModelUpdate,
35 LLMProviderCreate,
36 LLMProviderResponse,
37 LLMProviderUpdate,
38 ProviderHealthCheck,
39)
40from mcpgateway.services.logging_service import LoggingService
41from mcpgateway.utils.create_slug import slugify
42from mcpgateway.utils.services_auth import decode_auth, encode_auth
44# Initialize logging
45logging_service = LoggingService()
46logger = logging_service.get_logger(__name__)
48_ENCRYPTED_PROVIDER_CONFIG_KEY = "_mcpgateway_encrypted_value_v1"
49_PROVIDER_CONFIG_DATA_KEY = "data"
50_PROVIDER_CONFIG_LEGACY_VALUE_KEY = "value"
51_BASE_SENSITIVE_PROVIDER_CONFIG_KEYS = {
52 "api_key",
53 "auth_token",
54 "authorization",
55 "access_token",
56 "refresh_token",
57 "client_secret",
58 "secret_access_key",
59 "session_token",
60 "credentials_json",
61 "password",
62 "private_key",
63 "aws_secret_access_key",
64 "aws_session_token",
65}
68def _normalize_provider_config_key(key: str) -> str:
69 """Normalize provider config key names for matching.
71 Args:
72 key: Raw provider config field name.
74 Returns:
75 Canonical lowercase key using underscore separators.
76 """
77 normalized = str(key).strip().lower()
78 normalized = re.sub(r"[^a-z0-9]+", "_", normalized)
79 return normalized.strip("_")
def _build_sensitive_provider_config_keys() -> set[str]:
    """Build normalized sensitive key set from defaults and provider schemas.

    Starts from the static default list, then folds in every field that a
    provider schema declares as a password-type input.

    Returns:
        Set of normalized keys that should be treated as sensitive.
    """
    keys = {_normalize_provider_config_key(base_key) for base_key in _BASE_SENSITIVE_PROVIDER_CONFIG_KEYS}
    for definition in PROVIDER_CONFIGS.values():
        for field in definition.config_fields:
            if field.field_type == "password":
                keys.add(_normalize_provider_config_key(field.name))
    return keys
97_SENSITIVE_PROVIDER_CONFIG_KEYS = frozenset(_build_sensitive_provider_config_keys())
def _is_sensitive_provider_config_key(key: str) -> bool:
    """Return whether a provider config key is sensitive.

    Args:
        key: Candidate provider config key.

    Returns:
        ``True`` when key should be protected; otherwise ``False``.
    """
    canonical = _normalize_provider_config_key(key)
    return canonical in _SENSITIVE_PROVIDER_CONFIG_KEYS
def _is_encrypted_provider_config_value(value: Any) -> bool:
    """Return whether a config fragment is an encrypted envelope.

    An envelope is a dict whose marker key maps to a string ciphertext.

    Args:
        value: Config fragment to inspect.

    Returns:
        ``True`` when fragment matches encrypted envelope structure.
    """
    if not isinstance(value, dict):
        return False
    return isinstance(value.get(_ENCRYPTED_PROVIDER_CONFIG_KEY), str)
def _encrypt_provider_config_secret(value: Any, existing_value: Any = None) -> Any:
    """Encrypt a single sensitive provider config value.

    Args:
        value: Incoming value from create/update payload.
        existing_value: Existing stored value for masked-value merge behavior.

    Returns:
        Encrypted envelope, preserved existing value, or ``None`` for explicit clear.
    """
    # Empty/None pass through untouched (explicit clear or absent value).
    if value is None or value == "":
        return value

    # The UI mask placeholder means "keep whatever secret is stored".
    if value == settings.masked_auth_value:
        if existing_value in (None, ""):
            return None
        if _is_encrypted_provider_config_value(existing_value):
            return existing_value
        # Stored value is plaintext (legacy) — encrypt it on the way through.
        return _encrypt_provider_config_secret(existing_value, None)

    # Already wrapped in an envelope — persist as-is.
    if _is_encrypted_provider_config_value(value):
        return value

    return {_ENCRYPTED_PROVIDER_CONFIG_KEY: encode_auth({_PROVIDER_CONFIG_DATA_KEY: value})}
def _protect_provider_config_fragment(config_fragment: Any, existing_fragment: Any = None) -> Any:
    """Recursively protect sensitive provider config values.

    Walks dicts and lists in parallel with the previously stored fragment so
    masked values can be merged back from the existing secret.

    Args:
        config_fragment: Incoming config fragment to protect.
        existing_fragment: Existing persisted fragment for merge behavior.

    Returns:
        Config fragment with sensitive values protected.
    """
    if isinstance(config_fragment, dict):
        previous = existing_fragment if isinstance(existing_fragment, dict) else {}
        result: Dict[str, Any] = {}
        for field_name, field_value in config_fragment.items():
            prior_value = previous.get(field_name)
            if _is_sensitive_provider_config_key(field_name):
                result[field_name] = _encrypt_provider_config_secret(field_value, prior_value)
            else:
                result[field_name] = _protect_provider_config_fragment(field_value, prior_value)
        return result

    if isinstance(config_fragment, list):
        prior_items = existing_fragment if isinstance(existing_fragment, list) else []
        protected_items = []
        for position, item in enumerate(config_fragment):
            prior_item = prior_items[position] if position < len(prior_items) else None
            protected_items.append(_protect_provider_config_fragment(item, prior_item))
        return protected_items

    # Scalars are returned unchanged; only dict keys decide sensitivity.
    return config_fragment
def protect_provider_config_for_storage(config: Optional[Dict[str, Any]], existing_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Encrypt sensitive provider config fields before database persistence.

    Args:
        config: Incoming provider configuration payload.
        existing_config: Existing stored configuration used for masked-value merges.

    Returns:
        Provider config structure with sensitive fields protected for storage.
    """
    # Anything that isn't a dict (None, malformed payload) stores as empty config.
    if isinstance(config, dict):
        return _protect_provider_config_fragment(config, existing_config)
    return {}
def _decrypt_provider_config_fragment(config_fragment: Any) -> Any:
    """Recursively decrypt provider config fragments for runtime usage.

    Args:
        config_fragment: Stored config fragment, possibly encrypted.

    Returns:
        Runtime-ready fragment with decryptable values restored.
    """
    if _is_encrypted_provider_config_value(config_fragment):
        ciphertext = config_fragment.get(_ENCRYPTED_PROVIDER_CONFIG_KEY)
        try:
            decoded = decode_auth(ciphertext)
        except Exception as exc:
            # Undecryptable payloads are surfaced as-is rather than dropped.
            logger.warning("Failed to decrypt provider config fragment: %s", exc)
        else:
            if isinstance(decoded, dict):
                # Current envelopes use the data key; older ones used "value".
                for candidate in (_PROVIDER_CONFIG_DATA_KEY, _PROVIDER_CONFIG_LEGACY_VALUE_KEY):
                    if candidate in decoded:
                        return decoded[candidate]
        return config_fragment

    if isinstance(config_fragment, dict):
        return {name: _decrypt_provider_config_fragment(item) for name, item in config_fragment.items()}

    if isinstance(config_fragment, list):
        return [_decrypt_provider_config_fragment(item) for item in config_fragment]

    return config_fragment
def decrypt_provider_config_for_runtime(config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Return runtime-ready provider config with encrypted fields decrypted.

    Args:
        config: Stored provider configuration payload.

    Returns:
        Provider config with decryptable sensitive fields restored.
    """
    # Non-dict payloads (None, corrupt data) normalize to an empty config.
    if isinstance(config, dict):
        return _decrypt_provider_config_fragment(config)
    return {}
def _mask_provider_config_fragment(config_fragment: Any) -> Any:
    """Recursively mask sensitive provider config values for API responses.

    Args:
        config_fragment: Runtime config fragment.

    Returns:
        Fragment with sensitive values replaced by mask markers.
    """
    if isinstance(config_fragment, dict):
        redacted: Dict[str, Any] = {}
        for field_name, field_value in config_fragment.items():
            if not _is_sensitive_provider_config_key(field_name):
                redacted[field_name] = _mask_provider_config_fragment(field_value)
            elif field_value in (None, ""):
                # Keep empty markers visible so the UI can tell "unset" apart.
                redacted[field_name] = field_value
            else:
                redacted[field_name] = settings.masked_auth_value
        return redacted

    if isinstance(config_fragment, list):
        return [_mask_provider_config_fragment(item) for item in config_fragment]

    return config_fragment
def sanitize_provider_config_for_response(config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Return API-safe provider config with sensitive fields masked.

    Decrypts first so legacy plaintext and encrypted envelopes mask the
    same way, then replaces secret values with the mask marker.

    Args:
        config: Stored provider configuration payload.

    Returns:
        Provider config suitable for API responses with masked secrets.
    """
    decrypted = decrypt_provider_config_for_runtime(config)
    return _mask_provider_config_fragment(decrypted)
class LLMProviderError(Exception):
    """Root of the LLM provider exception hierarchy; all provider/model errors derive from it."""
class LLMProviderNotFoundError(LLMProviderError):
    """Raised when a lookup by provider ID or slug matches no record."""
class LLMProviderNameConflictError(LLMProviderError):
    """Raised when an LLM provider name conflicts with an existing one."""

    def __init__(self, name: str, provider_id: Optional[str] = None):
        """Initialize the exception.

        Args:
            name: The conflicting provider name.
            provider_id: Optional ID of the existing provider.
        """
        self.name = name
        self.provider_id = provider_id
        suffix = f" (ID: {provider_id})" if provider_id else ""
        super().__init__(f"LLM Provider already exists with name: {name}" + suffix)
class LLMProviderValidationError(LLMProviderError, ValueError):
    """Raised when provider payload validation fails; doubles as ValueError for generic handlers."""
class LLMModelNotFoundError(LLMProviderError):
    """Raised when a lookup by model ID matches no record."""
class LLMModelConflictError(LLMProviderError):
    """Raised when an LLM model duplicates an existing (provider, model_id) pair."""
class LLMProviderService:
    """Service for managing LLM providers and models.

    Provides methods to create, list, retrieve, update, and delete
    provider and model records. Also supports health checks.
    """

    def __init__(self) -> None:
        """Initialize a new LLMProviderService instance."""
        self._initialized = False

    @staticmethod
    def _validate_provider_api_base(api_base: Optional[str]) -> None:
        """Validate provider api_base against core URL + SSRF rules.

        Args:
            api_base: Provider base URL to validate.

        Raises:
            LLMProviderValidationError: If URL fails core or SSRF validation.
        """
        if api_base:
            try:
                SecurityValidator.validate_url(api_base, "Provider API base URL")
            except ValueError as exc:
                raise LLMProviderValidationError(str(exc)) from exc

    async def initialize(self) -> None:
        """Initialize the LLM provider service."""
        if not self._initialized:
            logger.info("Initializing LLM Provider Service")
            self._initialized = True

    async def shutdown(self) -> None:
        """Shutdown the LLM provider service."""
        if self._initialized:
            logger.info("Shutting down LLM Provider Service")
            self._initialized = False

    # ---------------------------------------------------------------------------
    # Provider CRUD Operations
    # ---------------------------------------------------------------------------

    def create_provider(
        self,
        db: Session,
        provider_data: LLMProviderCreate,
        created_by: Optional[str] = None,
    ) -> LLMProvider:
        """Create a new LLM provider.

        Args:
            db: Database session.
            provider_data: Provider data to create.
            created_by: Username of creator.

        Returns:
            Created LLMProvider instance.

        Raises:
            LLMProviderNameConflictError: If provider name already exists.
        """
        # Check for name conflict
        existing = db.execute(select(LLMProvider).where(LLMProvider.name == provider_data.name)).scalar_one_or_none()

        if existing:
            raise LLMProviderNameConflictError(provider_data.name, existing.id)

        # Encrypt API key if provided
        encrypted_api_key = None
        if provider_data.api_key:
            encrypted_api_key = encode_auth({"api_key": provider_data.api_key})

        self._validate_provider_api_base(provider_data.api_base)

        # Create provider
        provider = LLMProvider(
            name=provider_data.name,
            slug=slugify(provider_data.name),
            description=provider_data.description,
            provider_type=provider_data.provider_type.value,
            api_key=encrypted_api_key,
            api_base=provider_data.api_base,
            api_version=provider_data.api_version,
            config=protect_provider_config_for_storage(provider_data.config),
            default_model=provider_data.default_model,
            default_temperature=provider_data.default_temperature,
            default_max_tokens=provider_data.default_max_tokens,
            enabled=provider_data.enabled,
            plugin_ids=provider_data.plugin_ids,
            created_by=created_by,
        )

        try:
            db.add(provider)
            db.commit()
            db.refresh(provider)
            logger.info(f"Created LLM provider: {provider.name} (ID: {provider.id})")
            return provider
        except IntegrityError as e:
            db.rollback()
            logger.error(f"Failed to create LLM provider: {e}")
            # Chain the IntegrityError so the DB root cause survives in the traceback.
            raise LLMProviderNameConflictError(provider_data.name) from e

    def get_provider(self, db: Session, provider_id: str) -> LLMProvider:
        """Get an LLM provider by ID.

        Args:
            db: Database session.
            provider_id: Provider ID to retrieve.

        Returns:
            LLMProvider instance.

        Raises:
            LLMProviderNotFoundError: If provider not found.
        """
        provider = db.execute(select(LLMProvider).where(LLMProvider.id == provider_id)).scalar_one_or_none()

        if not provider:
            raise LLMProviderNotFoundError(f"Provider not found: {provider_id}")

        return provider

    def get_provider_by_slug(self, db: Session, slug: str) -> LLMProvider:
        """Get an LLM provider by slug.

        Args:
            db: Database session.
            slug: Provider slug to retrieve.

        Returns:
            LLMProvider instance.

        Raises:
            LLMProviderNotFoundError: If provider not found.
        """
        provider = db.execute(select(LLMProvider).where(LLMProvider.slug == slug)).scalar_one_or_none()

        if not provider:
            raise LLMProviderNotFoundError(f"Provider not found: {slug}")

        return provider

    def list_providers(
        self,
        db: Session,
        enabled_only: bool = False,
        page: int = 1,
        page_size: int = 50,
    ) -> Tuple[List[LLMProvider], int]:
        """List all LLM providers.

        Args:
            db: Database session.
            enabled_only: Only return enabled providers.
            page: Page number (1-indexed).
            page_size: Items per page.

        Returns:
            Tuple of (providers list, total count).
        """
        query = select(LLMProvider)

        if enabled_only:
            query = query.where(LLMProvider.enabled.is_(True))

        # Get total count efficiently using func.count()
        count_query = select(func.count(LLMProvider.id))  # pylint: disable=not-callable
        if enabled_only:
            count_query = count_query.where(LLMProvider.enabled.is_(True))
        total = db.execute(count_query).scalar() or 0

        # Apply pagination
        offset = (page - 1) * page_size
        query = query.offset(offset).limit(page_size).order_by(LLMProvider.name)

        providers = list(db.execute(query).scalars().all())
        return providers, total

    def update_provider(
        self,
        db: Session,
        provider_id: str,
        provider_data: LLMProviderUpdate,
        modified_by: Optional[str] = None,
    ) -> LLMProvider:
        """Update an LLM provider.

        Args:
            db: Database session.
            provider_id: Provider ID to update.
            provider_data: Updated provider data.
            modified_by: Username of modifier.

        Returns:
            Updated LLMProvider instance.

        Raises:
            LLMProviderNotFoundError: If provider not found.
            LLMProviderNameConflictError: If new name conflicts.
            IntegrityError: If database constraint violation.
        """
        provider = self.get_provider(db, provider_id)

        # Check for name conflict if name is being changed
        if provider_data.name and provider_data.name != provider.name:
            existing = db.execute(
                select(LLMProvider).where(
                    and_(
                        LLMProvider.name == provider_data.name,
                        LLMProvider.id != provider_id,
                    )
                )
            ).scalar_one_or_none()

            if existing:
                raise LLMProviderNameConflictError(provider_data.name, existing.id)

            provider.name = provider_data.name
            provider.slug = slugify(provider_data.name)

        # Update fields if provided (None means "leave unchanged")
        if provider_data.description is not None:
            provider.description = provider_data.description
        if provider_data.provider_type is not None:
            provider.provider_type = provider_data.provider_type.value
        if provider_data.api_key is not None:
            provider.api_key = encode_auth({"api_key": provider_data.api_key})
        if provider_data.api_base is not None:
            self._validate_provider_api_base(provider_data.api_base)
            provider.api_base = provider_data.api_base
        if provider_data.api_version is not None:
            provider.api_version = provider_data.api_version
        if provider_data.config is not None:
            # Pass the stored config so masked secrets merge back in.
            provider.config = protect_provider_config_for_storage(
                provider_data.config,
                existing_config=provider.config if isinstance(provider.config, dict) else None,
            )
        if provider_data.default_model is not None:
            provider.default_model = provider_data.default_model
        if provider_data.default_temperature is not None:
            provider.default_temperature = provider_data.default_temperature
        if provider_data.default_max_tokens is not None:
            provider.default_max_tokens = provider_data.default_max_tokens
        if provider_data.enabled is not None:
            provider.enabled = provider_data.enabled
        if provider_data.plugin_ids is not None:
            provider.plugin_ids = provider_data.plugin_ids

        provider.modified_by = modified_by

        try:
            db.commit()
            db.refresh(provider)
            logger.info(f"Updated LLM provider: {provider.name} (ID: {provider.id})")
            return provider
        except IntegrityError as e:
            db.rollback()
            logger.error(f"Failed to update LLM provider: {e}")
            raise

    def delete_provider(self, db: Session, provider_id: str) -> bool:
        """Delete an LLM provider.

        Args:
            db: Database session.
            provider_id: Provider ID to delete.

        Returns:
            True if deleted successfully.

        Raises:
            LLMProviderNotFoundError: If provider not found.
        """
        provider = self.get_provider(db, provider_id)
        provider_name = provider.name

        db.delete(provider)
        db.commit()
        logger.info(f"Deleted LLM provider: {provider_name} (ID: {provider_id})")
        return True

    def set_provider_state(self, db: Session, provider_id: str, activate: Optional[bool] = None) -> LLMProvider:
        """Set provider enabled state.

        Args:
            db: Database session.
            provider_id: Provider ID to update.
            activate: If provided, sets enabled to this value. If None, inverts current state (legacy behavior).

        Returns:
            Updated LLMProvider instance.
        """
        provider = self.get_provider(db, provider_id)
        if activate is None:
            # Legacy toggle behavior for backward compatibility
            provider.enabled = not provider.enabled
        else:
            provider.enabled = activate
        db.commit()
        db.refresh(provider)
        logger.info(f"Set LLM provider state: {provider.name} enabled={provider.enabled}")
        return provider

    # ---------------------------------------------------------------------------
    # Model CRUD Operations
    # ---------------------------------------------------------------------------

    def create_model(
        self,
        db: Session,
        model_data: LLMModelCreate,
    ) -> LLMModel:
        """Create a new LLM model.

        Args:
            db: Database session.
            model_data: Model data to create.

        Returns:
            Created LLMModel instance.

        Raises:
            LLMProviderNotFoundError: If provider not found.
            LLMModelConflictError: If model already exists for provider.
        """
        # Verify provider exists
        self.get_provider(db, model_data.provider_id)

        # Check for conflict
        existing = db.execute(
            select(LLMModel).where(
                and_(
                    LLMModel.provider_id == model_data.provider_id,
                    LLMModel.model_id == model_data.model_id,
                )
            )
        ).scalar_one_or_none()

        if existing:
            raise LLMModelConflictError(f"Model {model_data.model_id} already exists for provider {model_data.provider_id}")

        model = LLMModel(
            provider_id=model_data.provider_id,
            model_id=model_data.model_id,
            model_name=model_data.model_name,
            model_alias=model_data.model_alias,
            description=model_data.description,
            supports_chat=model_data.supports_chat,
            supports_streaming=model_data.supports_streaming,
            supports_function_calling=model_data.supports_function_calling,
            supports_vision=model_data.supports_vision,
            context_window=model_data.context_window,
            max_output_tokens=model_data.max_output_tokens,
            enabled=model_data.enabled,
            deprecated=model_data.deprecated,
        )

        try:
            db.add(model)
            db.commit()
            db.refresh(model)
            logger.info(f"Created LLM model: {model.model_id} (ID: {model.id})")
            return model
        except IntegrityError as e:
            db.rollback()
            logger.error(f"Failed to create LLM model: {e}")
            # Chain the IntegrityError so the DB root cause survives in the traceback.
            raise LLMModelConflictError(f"Model conflict: {model_data.model_id}") from e

    def get_model(self, db: Session, model_id: str) -> LLMModel:
        """Get an LLM model by ID.

        Args:
            db: Database session.
            model_id: Model ID to retrieve.

        Returns:
            LLMModel instance.

        Raises:
            LLMModelNotFoundError: If model not found.
        """
        model = db.execute(select(LLMModel).where(LLMModel.id == model_id)).scalar_one_or_none()

        if not model:
            raise LLMModelNotFoundError(f"Model not found: {model_id}")

        return model

    def list_models(
        self,
        db: Session,
        provider_id: Optional[str] = None,
        enabled_only: bool = False,
        page: int = 1,
        page_size: int = 50,
    ) -> Tuple[List[LLMModel], int]:
        """List LLM models.

        Args:
            db: Database session.
            provider_id: Filter by provider ID.
            enabled_only: Only return enabled models.
            page: Page number (1-indexed).
            page_size: Items per page.

        Returns:
            Tuple of (models list, total count).
        """
        query = select(LLMModel)

        if provider_id:
            query = query.where(LLMModel.provider_id == provider_id)
        if enabled_only:
            query = query.where(LLMModel.enabled.is_(True))

        # Get total count efficiently using func.count()
        count_query = select(func.count(LLMModel.id))  # pylint: disable=not-callable
        if provider_id:
            count_query = count_query.where(LLMModel.provider_id == provider_id)
        if enabled_only:
            count_query = count_query.where(LLMModel.enabled.is_(True))
        total = db.execute(count_query).scalar() or 0

        # Apply pagination
        offset = (page - 1) * page_size
        query = query.offset(offset).limit(page_size).order_by(LLMModel.model_name)

        models = list(db.execute(query).scalars().all())
        return models, total

    def update_model(
        self,
        db: Session,
        model_id: str,
        model_data: LLMModelUpdate,
    ) -> LLMModel:
        """Update an LLM model.

        Args:
            db: Database session.
            model_id: Model ID to update.
            model_data: Updated model data.

        Returns:
            Updated LLMModel instance.
        """
        model = self.get_model(db, model_id)

        # None means "leave unchanged" for every optional field below.
        if model_data.model_id is not None:
            model.model_id = model_data.model_id
        if model_data.model_name is not None:
            model.model_name = model_data.model_name
        if model_data.model_alias is not None:
            model.model_alias = model_data.model_alias
        if model_data.description is not None:
            model.description = model_data.description
        if model_data.supports_chat is not None:
            model.supports_chat = model_data.supports_chat
        if model_data.supports_streaming is not None:
            model.supports_streaming = model_data.supports_streaming
        if model_data.supports_function_calling is not None:
            model.supports_function_calling = model_data.supports_function_calling
        if model_data.supports_vision is not None:
            model.supports_vision = model_data.supports_vision
        if model_data.context_window is not None:
            model.context_window = model_data.context_window
        if model_data.max_output_tokens is not None:
            model.max_output_tokens = model_data.max_output_tokens
        if model_data.enabled is not None:
            model.enabled = model_data.enabled
        if model_data.deprecated is not None:
            model.deprecated = model_data.deprecated

        db.commit()
        db.refresh(model)
        logger.info(f"Updated LLM model: {model.model_id} (ID: {model.id})")
        return model

    def delete_model(self, db: Session, model_id: str) -> bool:
        """Delete an LLM model.

        Args:
            db: Database session.
            model_id: Model ID to delete.

        Returns:
            True if deleted successfully.
        """
        model = self.get_model(db, model_id)
        model_name = model.model_id

        db.delete(model)
        db.commit()
        logger.info(f"Deleted LLM model: {model_name} (ID: {model_id})")
        return True

    def set_model_state(self, db: Session, model_id: str, activate: Optional[bool] = None) -> LLMModel:
        """Set model enabled state.

        Args:
            db: Database session.
            model_id: Model ID to update.
            activate: If provided, sets enabled to this value. If None, inverts current state (legacy behavior).

        Returns:
            Updated LLMModel instance.
        """
        model = self.get_model(db, model_id)
        if activate is None:
            # Legacy toggle behavior for backward compatibility
            model.enabled = not model.enabled
        else:
            model.enabled = activate
        db.commit()
        db.refresh(model)
        logger.info(f"Set LLM model state: {model.model_id} enabled={model.enabled}")
        return model

    # ---------------------------------------------------------------------------
    # Gateway Models (for LLM Chat dropdown)
    # ---------------------------------------------------------------------------

    def get_gateway_models(self, db: Session) -> List[GatewayModelInfo]:
        """Get enabled models for the LLM Chat dropdown.

        Args:
            db: Database session.

        Returns:
            List of GatewayModelInfo for enabled models.
        """
        # Get enabled, chat-capable models from enabled providers
        query = (
            select(LLMModel, LLMProvider)
            .join(LLMProvider, LLMModel.provider_id == LLMProvider.id)
            .where(
                and_(
                    LLMModel.enabled.is_(True),
                    LLMProvider.enabled.is_(True),
                    LLMModel.supports_chat.is_(True),
                )
            )
            .order_by(LLMProvider.name, LLMModel.model_name)
        )

        results = db.execute(query).all()

        models = []
        for model, provider in results:
            models.append(
                GatewayModelInfo(
                    id=model.id,
                    model_id=model.model_id,
                    model_name=model.model_name,
                    provider_id=provider.id,
                    provider_name=provider.name,
                    provider_type=provider.provider_type,
                    supports_streaming=model.supports_streaming,
                    supports_function_calling=model.supports_function_calling,
                    supports_vision=model.supports_vision,
                )
            )

        return models

    # ---------------------------------------------------------------------------
    # Health Check Operations
    # ---------------------------------------------------------------------------

    async def check_provider_health(
        self,
        db: Session,
        provider_id: str,
    ) -> ProviderHealthCheck:
        """Check health of an LLM provider.

        Args:
            db: Database session.
            provider_id: Provider ID to check.

        Returns:
            ProviderHealthCheck result.
        """
        provider = self.get_provider(db, provider_id)

        start_time = datetime.now(timezone.utc)
        status = HealthStatus.UNKNOWN
        error_msg = None
        response_time_ms = None

        try:
            # Get API key
            api_key = None
            if provider.api_key:
                auth_data = decode_auth(provider.api_key)
                api_key = auth_data.get("api_key")

            # Perform health check based on provider type using shared HTTP client
            # First-Party
            from mcpgateway.services.http_client_service import get_http_client  # pylint: disable=import-outside-toplevel

            client = await get_http_client()
            if provider.provider_type == LLMProviderType.OPENAI:
                # Check OpenAI models endpoint
                headers = {"Authorization": f"Bearer {api_key}"}
                base_url = provider.api_base or "https://api.openai.com/v1"
                self._validate_provider_api_base(base_url)
                response = await client.get(f"{base_url}/models", headers=headers, timeout=10.0)
                if response.status_code == 200:
                    status = HealthStatus.HEALTHY
                else:
                    status = HealthStatus.UNHEALTHY
                    error_msg = f"HTTP {response.status_code}"

            elif provider.provider_type == LLMProviderType.OLLAMA:
                # Check Ollama health endpoint
                base_url = provider.api_base or "http://localhost:11434"
                self._validate_provider_api_base(base_url)
                # Handle OpenAI-compatible endpoint (/v1)
                if base_url.rstrip("/").endswith("/v1"):
                    # Use OpenAI-compatible models endpoint
                    response = await client.get(f"{base_url.rstrip('/')}/models", timeout=10.0)
                else:
                    # Use native Ollama API
                    response = await client.get(f"{base_url.rstrip('/')}/api/tags", timeout=10.0)
                if response.status_code == 200:
                    status = HealthStatus.HEALTHY
                else:
                    status = HealthStatus.UNHEALTHY
                    error_msg = f"HTTP {response.status_code}"

            else:
                # Generic check - just verify connectivity
                if provider.api_base:
                    self._validate_provider_api_base(provider.api_base)
                    response = await client.get(provider.api_base, timeout=5.0)
                    status = HealthStatus.HEALTHY if response.status_code < 500 else HealthStatus.UNHEALTHY
                else:
                    status = HealthStatus.UNKNOWN
                    error_msg = "No API base URL configured"

        except ValueError as e:
            # Includes LLMProviderValidationError from api_base validation.
            status = HealthStatus.UNHEALTHY
            error_msg = str(e)
        except httpx.TimeoutException:
            status = HealthStatus.UNHEALTHY
            error_msg = "Connection timeout"
        except httpx.RequestError as e:
            status = HealthStatus.UNHEALTHY
            error_msg = f"Connection error: {str(e)}"
        except Exception as e:
            status = HealthStatus.UNHEALTHY
            error_msg = f"Error: {str(e)}"

        end_time = datetime.now(timezone.utc)
        response_time_ms = (end_time - start_time).total_seconds() * 1000

        # Update provider health status
        provider.health_status = status.value
        provider.last_health_check = end_time
        db.commit()

        return ProviderHealthCheck(
            provider_id=provider.id,
            provider_name=provider.name,
            provider_type=provider.provider_type,
            status=status,
            response_time_ms=response_time_ms,
            error=error_msg,
            checked_at=end_time,
        )

    def to_provider_response(
        self,
        provider: LLMProvider,
        model_count: int = 0,
    ) -> LLMProviderResponse:
        """Convert LLMProvider to LLMProviderResponse.

        Args:
            provider: LLMProvider instance.
            model_count: Number of models for this provider.

        Returns:
            LLMProviderResponse instance.
        """
        return LLMProviderResponse(
            id=provider.id,
            name=provider.name,
            slug=provider.slug,
            description=provider.description,
            provider_type=provider.provider_type,
            api_base=provider.api_base,
            api_version=provider.api_version,
            config=sanitize_provider_config_for_response(provider.config),
            default_model=provider.default_model,
            default_temperature=provider.default_temperature,
            default_max_tokens=provider.default_max_tokens,
            enabled=provider.enabled,
            health_status=provider.health_status,
            last_health_check=provider.last_health_check,
            plugin_ids=provider.plugin_ids,
            created_at=provider.created_at,
            updated_at=provider.updated_at,
            created_by=provider.created_by,
            modified_by=provider.modified_by,
            model_count=model_count,
        )

    def to_model_response(
        self,
        model: LLMModel,
        provider: Optional[LLMProvider] = None,
    ) -> LLMModelResponse:
        """Convert LLMModel to LLMModelResponse.

        Args:
            model: LLMModel instance.
            provider: Optional provider for name/type info.

        Returns:
            LLMModelResponse instance.
        """
        return LLMModelResponse(
            id=model.id,
            provider_id=model.provider_id,
            model_id=model.model_id,
            model_name=model.model_name,
            model_alias=model.model_alias,
            description=model.description,
            supports_chat=model.supports_chat,
            supports_streaming=model.supports_streaming,
            supports_function_calling=model.supports_function_calling,
            supports_vision=model.supports_vision,
            context_window=model.context_window,
            max_output_tokens=model.max_output_tokens,
            enabled=model.enabled,
            deprecated=model.deprecated,
            created_at=model.created_at,
            updated_at=model.updated_at,
            provider_name=provider.name if provider else None,
            provider_type=provider.provider_type if provider else None,
        )