Coverage for mcpgateway / services / llm_provider_service.py: 99%

384 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/services/llm_provider_service.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5 

6LLM Provider Service 

7 

8This module implements LLM provider management for ContextForge. 

9It handles provider registration, CRUD operations, model management, 

10and health checks for the internal LLM Chat feature. 

11""" 

12 

13# Standard 

14from datetime import datetime, timezone 

15import re 

16from typing import Any, Dict, List, Optional, Tuple 

17 

18# Third-Party 

19import httpx 

20from sqlalchemy import and_, func, select 

21from sqlalchemy.exc import IntegrityError 

22from sqlalchemy.orm import Session 

23 

24# First-Party 

25from mcpgateway.common.validators import SecurityValidator 

26from mcpgateway.config import settings 

27from mcpgateway.db import LLMModel, LLMProvider, LLMProviderType 

28from mcpgateway.llm_provider_configs import PROVIDER_CONFIGS 

29from mcpgateway.llm_schemas import ( 

30 GatewayModelInfo, 

31 HealthStatus, 

32 LLMModelCreate, 

33 LLMModelResponse, 

34 LLMModelUpdate, 

35 LLMProviderCreate, 

36 LLMProviderResponse, 

37 LLMProviderUpdate, 

38 ProviderHealthCheck, 

39) 

40from mcpgateway.services.logging_service import LoggingService 

41from mcpgateway.utils.create_slug import slugify 

42from mcpgateway.utils.services_auth import decode_auth, encode_auth 

43 

44# Initialize logging 

45logging_service = LoggingService() 

46logger = logging_service.get_logger(__name__) 

47 

# Marker key: a dict holding a string under this key is treated as an encrypted envelope.
_ENCRYPTED_PROVIDER_CONFIG_KEY = "_mcpgateway_encrypted_value_v1"
# Wrapper key used for the plaintext inside the payload handed to encode_auth.
_PROVIDER_CONFIG_DATA_KEY = "data"
# Alternative wrapper key that is still accepted when decrypting stored payloads.
_PROVIDER_CONFIG_LEGACY_VALUE_KEY = "value"
# Field names always handled as secrets; they are normalized before matching.
_BASE_SENSITIVE_PROVIDER_CONFIG_KEYS = {
    "access_token",
    "api_key",
    "auth_token",
    "authorization",
    "aws_secret_access_key",
    "aws_session_token",
    "client_secret",
    "credentials_json",
    "password",
    "private_key",
    "refresh_token",
    "secret_access_key",
    "session_token",
}

66 

67 

def _normalize_provider_config_key(key: str) -> str:
    """Canonicalize a provider config field name so spelling variants compare equal.

    Args:
        key: Raw provider config field name.

    Returns:
        Lowercased name in which every run of characters outside ``a-z0-9``
        is collapsed into a single underscore, with no leading or trailing
        underscores.
    """
    lowered = str(key).strip().lower()
    return re.sub(r"[^a-z0-9]+", "_", lowered).strip("_")

80 

81 

def _build_sensitive_provider_config_keys() -> set[str]:
    """Collect every config key name that must be treated as a secret.

    Combines the built-in base key list with every field that a provider
    definition in ``PROVIDER_CONFIGS`` declares with ``field_type == "password"``.

    Returns:
        Set of normalized keys that should be treated as sensitive.
    """
    base_keys = {_normalize_provider_config_key(name) for name in _BASE_SENSITIVE_PROVIDER_CONFIG_KEYS}
    password_keys = {
        _normalize_provider_config_key(field.name)
        for definition in PROVIDER_CONFIGS.values()
        for field in definition.config_fields
        if field.field_type == "password"
    }
    return base_keys | password_keys


# Computed once at import time; immutable so lookups cannot be altered at runtime.
_SENSITIVE_PROVIDER_CONFIG_KEYS = frozenset(_build_sensitive_provider_config_keys())

98 

99 

def _is_sensitive_provider_config_key(key: str) -> bool:
    """Tell whether a provider config key names a secret value.

    Args:
        key: Candidate provider config key (any spelling/casing).

    Returns:
        ``True`` when the normalized key is in the sensitive key set;
        otherwise ``False``.
    """
    canonical = _normalize_provider_config_key(key)
    return canonical in _SENSITIVE_PROVIDER_CONFIG_KEYS

110 

111 

def _is_encrypted_provider_config_value(value: Any) -> bool:
    """Tell whether a config fragment has the encrypted-envelope shape.

    Args:
        value: Config fragment to inspect.

    Returns:
        ``True`` when ``value`` is a dict holding a string under the
        encrypted-envelope marker key; otherwise ``False``.
    """
    if not isinstance(value, dict):
        return False
    payload = value.get(_ENCRYPTED_PROVIDER_CONFIG_KEY)
    return isinstance(payload, str)

122 

123 

def _encrypt_provider_config_secret(value: Any, existing_value: Any = None) -> Any:
    """Produce the storage form of one sensitive provider config value.

    Args:
        value: Incoming value from create/update payload.
        existing_value: Value currently stored for the same key; consulted
            only when ``value`` is the masked placeholder.

    Returns:
        * ``value`` unchanged when it is ``None``, ``""`` or already an
          encrypted envelope;
        * for the masked placeholder: ``None`` when nothing is stored, the
          stored envelope when one exists, or the stored plaintext encrypted;
        * otherwise a fresh encrypted envelope wrapping ``value``.
    """
    # Empty inputs pass straight through.
    if value is None or value == "":
        return value

    if value == settings.masked_auth_value:
        # The client echoed the mask back: keep whatever is already stored.
        if existing_value in (None, ""):
            return None
        already_encrypted = _is_encrypted_provider_config_value(existing_value)
        return existing_value if already_encrypted else _encrypt_provider_config_secret(existing_value, None)

    # Never wrap an envelope inside another envelope.
    if _is_encrypted_provider_config_value(value):
        return value

    ciphertext = encode_auth({_PROVIDER_CONFIG_DATA_KEY: value})
    return {_ENCRYPTED_PROVIDER_CONFIG_KEY: ciphertext}

149 

150 

def _protect_provider_config_fragment(config_fragment: Any, existing_fragment: Any = None) -> Any:
    """Walk a config fragment and encrypt every value stored under a sensitive key.

    Args:
        config_fragment: Incoming config fragment to protect.
        existing_fragment: Stored counterpart of the fragment; matched by key
            for dicts and by position for lists so masked values can be merged.

    Returns:
        A new dict/list mirroring the input with sensitive values protected;
        any other value is returned unchanged.
    """
    if isinstance(config_fragment, list):
        prior_items = existing_fragment if isinstance(existing_fragment, list) else []
        protected_items = []
        for position, item in enumerate(config_fragment):
            prior_item = prior_items[position] if position < len(prior_items) else None
            protected_items.append(_protect_provider_config_fragment(item, prior_item))
        return protected_items

    if not isinstance(config_fragment, dict):
        # Scalars (and anything else) carry no key information; leave them alone.
        return config_fragment

    prior_values = existing_fragment if isinstance(existing_fragment, dict) else {}
    protected: Dict[str, Any] = {}
    for name, incoming in config_fragment.items():
        # Sensitive keys are encrypted as a whole; other keys are descended into.
        handler = _encrypt_provider_config_secret if _is_sensitive_provider_config_key(name) else _protect_provider_config_fragment
        protected[name] = handler(incoming, prior_values.get(name))
    return protected

177 

178 

def protect_provider_config_for_storage(config: Optional[Dict[str, Any]], existing_config: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Prepare a provider config for persistence by encrypting its sensitive fields.

    Args:
        config: Incoming provider configuration payload.
        existing_config: Existing stored configuration used for masked-value merges.

    Returns:
        Protected copy of ``config``; an empty dict when ``config`` is not a dict.
    """
    if isinstance(config, dict):
        return _protect_provider_config_fragment(config, existing_config)
    return {}

192 

193 

def _decrypt_provider_config_fragment(config_fragment: Any) -> Any:
    """Walk a stored config fragment and replace encrypted envelopes with plaintext.

    Args:
        config_fragment: Stored config fragment, possibly encrypted.

    Returns:
        The decrypted value for an envelope, a rebuilt dict/list for
        containers, or the input itself. An envelope that cannot be decoded
        (or whose payload has neither wrapper key) is returned unchanged.
    """
    if _is_encrypted_provider_config_value(config_fragment):
        try:
            decoded = decode_auth(config_fragment.get(_ENCRYPTED_PROVIDER_CONFIG_KEY))
            if isinstance(decoded, dict):
                # Current wrapper key first, then the legacy one.
                for wrapper_key in (_PROVIDER_CONFIG_DATA_KEY, _PROVIDER_CONFIG_LEGACY_VALUE_KEY):
                    if wrapper_key in decoded:
                        return decoded[wrapper_key]
        except Exception as exc:
            # Best effort: a failed decode is logged and the envelope is kept as-is.
            logger.warning("Failed to decrypt provider config fragment: %s", exc)
        return config_fragment

    if isinstance(config_fragment, list):
        return [_decrypt_provider_config_fragment(item) for item in config_fragment]

    if isinstance(config_fragment, dict):
        return {name: _decrypt_provider_config_fragment(item) for name, item in config_fragment.items()}

    return config_fragment

223 

224 

def decrypt_provider_config_for_runtime(config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Turn a stored provider config into its runtime form with secrets decrypted.

    Args:
        config: Stored provider configuration payload.

    Returns:
        Decrypted copy of ``config``; an empty dict when ``config`` is not a dict.
    """
    if isinstance(config, dict):
        return _decrypt_provider_config_fragment(config)
    return {}

237 

238 

def _mask_provider_config_fragment(config_fragment: Any) -> Any:
    """Walk a runtime config fragment and hide every value under a sensitive key.

    Args:
        config_fragment: Runtime config fragment.

    Returns:
        A new dict/list in which non-empty sensitive values are replaced by
        ``settings.masked_auth_value``; empty (``None``/``""``) sensitive
        values and all non-container inputs are returned as they are.
    """
    if isinstance(config_fragment, list):
        return [_mask_provider_config_fragment(item) for item in config_fragment]

    if not isinstance(config_fragment, dict):
        return config_fragment

    masked: Dict[str, Any] = {}
    for name, value in config_fragment.items():
        if not _is_sensitive_provider_config_key(name):
            masked[name] = _mask_provider_config_fragment(value)
        elif value in (None, ""):
            # Nothing to hide; keep the empty marker so callers can see it is unset.
            masked[name] = value
        else:
            masked[name] = settings.masked_auth_value
    return masked

261 

262 

def sanitize_provider_config_for_response(config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Build the API-facing view of a stored provider config.

    The stored config is decrypted first and then masked, so sensitive keys
    show the mask marker instead of an encrypted envelope.

    Args:
        config: Stored provider configuration payload.

    Returns:
        Provider config suitable for API responses with masked secrets.
    """
    return _mask_provider_config_fragment(decrypt_provider_config_for_runtime(config))

274 

275 

class LLMProviderError(Exception):
    """Root of the exception hierarchy raised by the LLM provider service."""

278 

279 

class LLMProviderNotFoundError(LLMProviderError):
    """Signals that no LLM provider matches the requested ID or slug."""

282 

283 

class LLMProviderNameConflictError(LLMProviderError):
    """Signals that an LLM provider name is already taken by another provider."""

    def __init__(self, name: str, provider_id: Optional[str] = None):
        """Store the conflict details and build the error message.

        Args:
            name: The conflicting provider name.
            provider_id: Optional ID of the existing provider; appended to the
                message when given.
        """
        self.name = name
        self.provider_id = provider_id
        id_suffix = f" (ID: {provider_id})" if provider_id else ""
        super().__init__(f"LLM Provider already exists with name: {name}{id_suffix}")

300 

301 

class LLMProviderValidationError(LLMProviderError, ValueError):
    """Signals an invalid provider payload; also catchable as ``ValueError``."""

304 

305 

class LLMModelNotFoundError(LLMProviderError):
    """Signals that no LLM model matches the requested ID."""

308 

309 

class LLMModelConflictError(LLMProviderError):
    """Signals that an LLM model collides with one that already exists."""

312 

313 

314class LLMProviderService: 

315 """Service for managing LLM providers and models. 

316 

317 Provides methods to create, list, retrieve, update, and delete 

318 provider and model records. Also supports health checks. 

319 """ 

320 

def __init__(self) -> None:
    """Create the service in its not-yet-initialized state."""
    # Flipped by initialize() / shutdown().
    self._initialized: bool = False

324 

325 @staticmethod 

326 def _validate_provider_api_base(api_base: Optional[str]) -> None: 

327 """Validate provider api_base against core URL + SSRF rules. 

328 

329 Args: 

330 api_base: Provider base URL to validate. 

331 

332 Raises: 

333 LLMProviderValidationError: If URL fails core or SSRF validation. 

334 """ 

335 if api_base: 

336 try: 

337 SecurityValidator.validate_url(api_base, "Provider API base URL") 

338 except ValueError as exc: 

339 raise LLMProviderValidationError(str(exc)) from exc 

340 

async def initialize(self) -> None:
    """Mark the service as initialized; a repeated call does nothing."""
    if self._initialized:
        return
    logger.info("Initializing LLM Provider Service")
    self._initialized = True

346 

async def shutdown(self) -> None:
    """Mark the service as shut down; does nothing when it was never initialized."""
    if not self._initialized:
        return
    logger.info("Shutting down LLM Provider Service")
    self._initialized = False

352 

353 # --------------------------------------------------------------------------- 

354 # Provider CRUD Operations 

355 # --------------------------------------------------------------------------- 

356 

def create_provider(
    self,
    db: Session,
    provider_data: LLMProviderCreate,
    created_by: Optional[str] = None,
) -> LLMProvider:
    """Persist a new LLM provider.

    Args:
        db: Database session.
        provider_data: Provider data to create.
        created_by: Username of creator.

    Returns:
        Created LLMProvider instance, refreshed from the database.

    Raises:
        LLMProviderNameConflictError: If a provider with the same name is
            found, or if the commit fails with an ``IntegrityError``.
        LLMProviderValidationError: If ``api_base`` fails URL validation.
    """
    # Reject duplicates up front so the caller gets the existing provider's ID.
    name_lookup = select(LLMProvider).where(LLMProvider.name == provider_data.name)
    duplicate = db.execute(name_lookup).scalar_one_or_none()
    if duplicate:
        raise LLMProviderNameConflictError(provider_data.name, duplicate.id)

    # The API key is stored encoded; a missing or empty key is stored as None.
    stored_api_key = encode_auth({"api_key": provider_data.api_key}) if provider_data.api_key else None

    self._validate_provider_api_base(provider_data.api_base)

    provider_fields = {
        "name": provider_data.name,
        "slug": slugify(provider_data.name),
        "description": provider_data.description,
        "provider_type": provider_data.provider_type.value,
        "api_key": stored_api_key,
        "api_base": provider_data.api_base,
        "api_version": provider_data.api_version,
        "config": protect_provider_config_for_storage(provider_data.config),
        "default_model": provider_data.default_model,
        "default_temperature": provider_data.default_temperature,
        "default_max_tokens": provider_data.default_max_tokens,
        "enabled": provider_data.enabled,
        "plugin_ids": provider_data.plugin_ids,
        "created_by": created_by,
    }
    provider = LLMProvider(**provider_fields)

    try:
        db.add(provider)
        db.commit()
        db.refresh(provider)
    except IntegrityError as e:
        # A constraint failure at commit time is reported as a name conflict.
        db.rollback()
        logger.error(f"Failed to create LLM provider: {e}")
        raise LLMProviderNameConflictError(provider_data.name)
    else:
        logger.info(f"Created LLM provider: {provider.name} (ID: {provider.id})")
        return provider

417 

def get_provider(self, db: Session, provider_id: str) -> LLMProvider:
    """Fetch a single LLM provider by its ID.

    Args:
        db: Database session.
        provider_id: Provider ID to retrieve.

    Returns:
        LLMProvider instance.

    Raises:
        LLMProviderNotFoundError: If provider not found.
    """
    lookup = select(LLMProvider).where(LLMProvider.id == provider_id)
    found = db.execute(lookup).scalar_one_or_none()
    if found:
        return found
    raise LLMProviderNotFoundError(f"Provider not found: {provider_id}")

437 

def get_provider_by_slug(self, db: Session, slug: str) -> LLMProvider:
    """Fetch a single LLM provider by its slug.

    Args:
        db: Database session.
        slug: Provider slug to retrieve.

    Returns:
        LLMProvider instance.

    Raises:
        LLMProviderNotFoundError: If provider not found.
    """
    lookup = select(LLMProvider).where(LLMProvider.slug == slug)
    found = db.execute(lookup).scalar_one_or_none()
    if found:
        return found
    raise LLMProviderNotFoundError(f"Provider not found: {slug}")

457 

def list_providers(
    self,
    db: Session,
    enabled_only: bool = False,
    page: int = 1,
    page_size: int = 50,
) -> Tuple[List[LLMProvider], int]:
    """Return one page of LLM providers ordered by name, plus the total count.

    Args:
        db: Database session.
        enabled_only: Only return enabled providers.
        page: Page number (1-indexed).
        page_size: Items per page.

    Returns:
        Tuple of (providers list, total count). The count covers all rows
        matching the filter, not just the returned page.
    """
    # The same filter is applied to the count statement and the page statement.
    criteria = [LLMProvider.enabled.is_(True)] if enabled_only else []

    count_stmt = select(func.count(LLMProvider.id))  # pylint: disable=not-callable
    page_stmt = select(LLMProvider)
    for criterion in criteria:
        count_stmt = count_stmt.where(criterion)
        page_stmt = page_stmt.where(criterion)

    total = db.execute(count_stmt).scalar() or 0

    skip = (page - 1) * page_size
    page_stmt = page_stmt.offset(skip).limit(page_size).order_by(LLMProvider.name)
    rows = db.execute(page_stmt).scalars().all()
    return list(rows), total

493 

def update_provider(
    self,
    db: Session,
    provider_id: str,
    provider_data: LLMProviderUpdate,
    modified_by: Optional[str] = None,
) -> LLMProvider:
    """Update an LLM provider.

    Only fields whose incoming value is not ``None`` are changed. All checks
    (name conflict, ``api_base`` validation) and all derived values (slug,
    encoded API key, protected config) are computed before the loaded
    provider object is touched, so a failed check never leaves a partially
    modified object pending in the caller's session.

    Args:
        db: Database session.
        provider_id: Provider ID to update.
        provider_data: Updated provider data.
        modified_by: Username of modifier.

    Returns:
        Updated LLMProvider instance, refreshed from the database.

    Raises:
        LLMProviderNotFoundError: If provider not found.
        LLMProviderNameConflictError: If new name conflicts.
        LLMProviderValidationError: If the new ``api_base`` fails validation.
        IntegrityError: If database constraint violation.
    """
    provider = self.get_provider(db, provider_id)

    # ---- Phase 1: validate and derive; `provider` is not modified here. ----
    renaming = bool(provider_data.name) and provider_data.name != provider.name
    new_slug = None
    if renaming:
        existing = db.execute(
            select(LLMProvider).where(
                and_(
                    LLMProvider.name == provider_data.name,
                    LLMProvider.id != provider_id,
                )
            )
        ).scalar_one_or_none()
        if existing:
            raise LLMProviderNameConflictError(provider_data.name, existing.id)

    if provider_data.api_base is not None:
        self._validate_provider_api_base(provider_data.api_base)

    if renaming:
        new_slug = slugify(provider_data.name)

    new_api_key = None
    if provider_data.api_key is not None:
        new_api_key = encode_auth({"api_key": provider_data.api_key})

    new_config = None
    if provider_data.config is not None:
        # The stored config is passed along so masked placeholders keep their stored secret.
        new_config = protect_provider_config_for_storage(
            provider_data.config,
            existing_config=provider.config if isinstance(provider.config, dict) else None,
        )

    new_provider_type = None
    if provider_data.provider_type is not None:
        new_provider_type = provider_data.provider_type.value

    # ---- Phase 2: apply the changes. ----
    if renaming:
        provider.name = provider_data.name
        provider.slug = new_slug
    if provider_data.provider_type is not None:
        provider.provider_type = new_provider_type
    if provider_data.api_key is not None:
        provider.api_key = new_api_key
    if provider_data.api_base is not None:
        provider.api_base = provider_data.api_base
    if provider_data.config is not None:
        provider.config = new_config

    # Fields that are copied over verbatim when provided.
    for field_name in (
        "description",
        "api_version",
        "default_model",
        "default_temperature",
        "default_max_tokens",
        "enabled",
        "plugin_ids",
    ):
        incoming = getattr(provider_data, field_name)
        if incoming is not None:
            setattr(provider, field_name, incoming)

    provider.modified_by = modified_by

    try:
        db.commit()
        db.refresh(provider)
    except IntegrityError as e:
        db.rollback()
        logger.error(f"Failed to update LLM provider: {e}")
        raise
    logger.info(f"Updated LLM provider: {provider.name} (ID: {provider.id})")
    return provider

575 

def delete_provider(self, db: Session, provider_id: str) -> bool:
    """Remove an LLM provider and commit the deletion.

    Args:
        db: Database session.
        provider_id: Provider ID to delete.

    Returns:
        True if deleted successfully.

    Raises:
        LLMProviderNotFoundError: If provider not found.
    """
    doomed = self.get_provider(db, provider_id)
    # Read the name before the delete/commit; it is only needed for the log line.
    label = doomed.name

    db.delete(doomed)
    db.commit()
    logger.info(f"Deleted LLM provider: {label} (ID: {provider_id})")
    return True

596 

def set_provider_state(self, db: Session, provider_id: str, activate: Optional[bool] = None) -> LLMProvider:
    """Enable, disable or toggle an LLM provider.

    Args:
        db: Database session.
        provider_id: Provider ID to update.
        activate: If provided, sets enabled to this value. If None, inverts current state (legacy behavior).

    Returns:
        Updated LLMProvider instance, refreshed from the database.
    """
    provider = self.get_provider(db, provider_id)
    # activate=None keeps the legacy toggle behavior for backward compatibility.
    provider.enabled = (not provider.enabled) if activate is None else activate
    db.commit()
    db.refresh(provider)
    logger.info(f"Set LLM provider state: {provider.name} enabled={provider.enabled}")
    return provider

618 

619 # --------------------------------------------------------------------------- 

620 # Model CRUD Operations 

621 # --------------------------------------------------------------------------- 

622 

def create_model(
    self,
    db: Session,
    model_data: LLMModelCreate,
) -> LLMModel:
    """Persist a new LLM model under an existing provider.

    Args:
        db: Database session.
        model_data: Model data to create.

    Returns:
        Created LLMModel instance, refreshed from the database.

    Raises:
        LLMProviderNotFoundError: If provider not found.
        LLMModelConflictError: If the provider already has a model with the
            same ``model_id``, or if the commit fails with an ``IntegrityError``.
    """
    # Raises LLMProviderNotFoundError when the parent provider is missing.
    self.get_provider(db, model_data.provider_id)

    same_model = and_(
        LLMModel.provider_id == model_data.provider_id,
        LLMModel.model_id == model_data.model_id,
    )
    duplicate = db.execute(select(LLMModel).where(same_model)).scalar_one_or_none()
    if duplicate:
        raise LLMModelConflictError(f"Model {model_data.model_id} already exists for provider {model_data.provider_id}")

    model_fields = {
        "provider_id": model_data.provider_id,
        "model_id": model_data.model_id,
        "model_name": model_data.model_name,
        "model_alias": model_data.model_alias,
        "description": model_data.description,
        "supports_chat": model_data.supports_chat,
        "supports_streaming": model_data.supports_streaming,
        "supports_function_calling": model_data.supports_function_calling,
        "supports_vision": model_data.supports_vision,
        "context_window": model_data.context_window,
        "max_output_tokens": model_data.max_output_tokens,
        "enabled": model_data.enabled,
        "deprecated": model_data.deprecated,
    }
    model = LLMModel(**model_fields)

    try:
        db.add(model)
        db.commit()
        db.refresh(model)
    except IntegrityError as e:
        # A constraint failure at commit time is reported as a model conflict.
        db.rollback()
        logger.error(f"Failed to create LLM model: {e}")
        raise LLMModelConflictError(f"Model conflict: {model_data.model_id}")
    else:
        logger.info(f"Created LLM model: {model.model_id} (ID: {model.id})")
        return model

683 

def get_model(self, db: Session, model_id: str) -> LLMModel:
    """Fetch a single LLM model by its primary ID.

    Args:
        db: Database session.
        model_id: Model ID to retrieve (matched against ``LLMModel.id``).

    Returns:
        LLMModel instance.

    Raises:
        LLMModelNotFoundError: If model not found.
    """
    lookup = select(LLMModel).where(LLMModel.id == model_id)
    found = db.execute(lookup).scalar_one_or_none()
    if found:
        return found
    raise LLMModelNotFoundError(f"Model not found: {model_id}")

703 

def list_models(
    self,
    db: Session,
    provider_id: Optional[str] = None,
    enabled_only: bool = False,
    page: int = 1,
    page_size: int = 50,
) -> Tuple[List[LLMModel], int]:
    """Return one page of LLM models ordered by model name, plus the total count.

    Args:
        db: Database session.
        provider_id: Filter by provider ID.
        enabled_only: Only return enabled models.
        page: Page number (1-indexed).
        page_size: Items per page.

    Returns:
        Tuple of (models list, total count). The count covers all rows
        matching the filters, not just the returned page.
    """
    # Collect the filters once and apply them to both statements.
    criteria = []
    if provider_id:
        criteria.append(LLMModel.provider_id == provider_id)
    if enabled_only:
        criteria.append(LLMModel.enabled.is_(True))

    count_stmt = select(func.count(LLMModel.id))  # pylint: disable=not-callable
    page_stmt = select(LLMModel)
    for criterion in criteria:
        count_stmt = count_stmt.where(criterion)
        page_stmt = page_stmt.where(criterion)

    total = db.execute(count_stmt).scalar() or 0

    skip = (page - 1) * page_size
    page_stmt = page_stmt.offset(skip).limit(page_size).order_by(LLMModel.model_name)
    rows = db.execute(page_stmt).scalars().all()
    return list(rows), total

745 

def update_model(
    self,
    db: Session,
    model_id: str,
    model_data: LLMModelUpdate,
) -> LLMModel:
    """Update an LLM model.

    Only fields whose incoming value is not ``None`` are changed. A commit
    that fails with an ``IntegrityError`` is rolled back and logged before
    the error is re-raised, in line with ``update_provider``.

    Args:
        db: Database session.
        model_id: Model ID to update.
        model_data: Updated model data.

    Returns:
        Updated LLMModel instance, refreshed from the database.

    Raises:
        LLMModelNotFoundError: If model not found.
        IntegrityError: If database constraint violation.
    """
    model = self.get_model(db, model_id)

    # Every updatable field is copied verbatim when a value is provided.
    for field_name in (
        "model_id",
        "model_name",
        "model_alias",
        "description",
        "supports_chat",
        "supports_streaming",
        "supports_function_calling",
        "supports_vision",
        "context_window",
        "max_output_tokens",
        "enabled",
        "deprecated",
    ):
        incoming = getattr(model_data, field_name)
        if incoming is not None:
            setattr(model, field_name, incoming)

    try:
        db.commit()
        db.refresh(model)
    except IntegrityError as e:
        # Leave the session usable for the caller, then surface the same error type.
        db.rollback()
        logger.error(f"Failed to update LLM model: {e}")
        raise
    logger.info(f"Updated LLM model: {model.model_id} (ID: {model.id})")
    return model

793 

def delete_model(self, db: Session, model_id: str) -> bool:
    """Remove an LLM model and commit the deletion.

    Args:
        db: Database session.
        model_id: Model ID to delete.

    Returns:
        True if deleted successfully.
    """
    doomed = self.get_model(db, model_id)
    # Read the provider-side model identifier before the delete/commit, for the log line.
    label = doomed.model_id

    db.delete(doomed)
    db.commit()
    logger.info(f"Deleted LLM model: {label} (ID: {model_id})")
    return True

811 

def set_model_state(self, db: Session, model_id: str, activate: Optional[bool] = None) -> LLMModel:
    """Enable, disable or toggle an LLM model.

    Args:
        db: Database session.
        model_id: Model ID to update.
        activate: If provided, sets enabled to this value. If None, inverts current state (legacy behavior).

    Returns:
        Updated LLMModel instance, refreshed from the database.
    """
    model = self.get_model(db, model_id)
    # activate=None keeps the legacy toggle behavior for backward compatibility.
    model.enabled = (not model.enabled) if activate is None else activate
    db.commit()
    db.refresh(model)
    logger.info(f"Set LLM model state: {model.model_id} enabled={model.enabled}")
    return model

833 

834 # --------------------------------------------------------------------------- 

835 # Gateway Models (for LLM Chat dropdown) 

836 # --------------------------------------------------------------------------- 

837 

def get_gateway_models(self, db: Session) -> List[GatewayModelInfo]:
    """List the chat-capable models offered in the LLM Chat dropdown.

    A model is included only when it is enabled, supports chat, and belongs
    to an enabled provider. Results are ordered by provider name, then
    model name.

    Args:
        db: Database session.

    Returns:
        List of GatewayModelInfo for enabled models.
    """
    eligible = and_(
        LLMModel.enabled.is_(True),
        LLMProvider.enabled.is_(True),
        LLMModel.supports_chat.is_(True),
    )
    stmt = select(LLMModel, LLMProvider).join(LLMProvider, LLMModel.provider_id == LLMProvider.id).where(eligible).order_by(LLMProvider.name, LLMModel.model_name)

    return [
        GatewayModelInfo(
            id=model.id,
            model_id=model.model_id,
            model_name=model.model_name,
            provider_id=provider.id,
            provider_name=provider.name,
            provider_type=provider.provider_type,
            supports_streaming=model.supports_streaming,
            supports_function_calling=model.supports_function_calling,
            supports_vision=model.supports_vision,
        )
        for model, provider in db.execute(stmt).all()
    ]

880 

881 # --------------------------------------------------------------------------- 

882 # Health Check Operations 

883 # --------------------------------------------------------------------------- 

884 

async def check_provider_health(
    self,
    db: Session,
    provider_id: str,
) -> ProviderHealthCheck:
    """Check health of an LLM provider.

    Probes the provider over HTTP, records the outcome on the provider row
    (``health_status`` and ``last_health_check``) and commits it.

    Probe per provider type:
        * OpenAI: ``GET {api_base}/models`` (default base
          ``https://api.openai.com/v1``); a bearer token is sent only when an
          API key is stored. Healthy on HTTP 200.
        * Ollama: ``GET {api_base}/models`` when the base ends in ``/v1``,
          otherwise ``GET {api_base}/api/tags`` (default base
          ``http://localhost:11434``). Healthy on HTTP 200.
        * Any other type: ``GET {api_base}``; healthy when the status is
          below 500, unknown when no base URL is configured.

    Every base URL is validated before the request is sent. Validation
    failures, timeouts, connection errors and any other exception raised
    while probing are reported as an unhealthy result rather than raised.

    Args:
        db: Database session.
        provider_id: Provider ID to check.

    Returns:
        ProviderHealthCheck result.

    Raises:
        LLMProviderNotFoundError: If provider not found.
    """
    provider = self.get_provider(db, provider_id)

    start_time = datetime.now(timezone.utc)
    status = HealthStatus.UNKNOWN
    error_msg = None

    try:
        # Get API key
        api_key = None
        if provider.api_key:
            auth_data = decode_auth(provider.api_key)
            api_key = auth_data.get("api_key")

        # Perform health check based on provider type using shared HTTP client
        # First-Party
        from mcpgateway.services.http_client_service import get_http_client  # pylint: disable=import-outside-toplevel

        client = await get_http_client()
        if provider.provider_type == LLMProviderType.OPENAI:
            # Send credentials only when a key exists; never a literal "Bearer None".
            headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
            base_url = provider.api_base or "https://api.openai.com/v1"
            self._validate_provider_api_base(base_url)
            # Strip a trailing slash so the path never becomes ".../v1//models".
            response = await client.get(f"{base_url.rstrip('/')}/models", headers=headers, timeout=10.0)
            if response.status_code == 200:
                status = HealthStatus.HEALTHY
            else:
                status = HealthStatus.UNHEALTHY
                error_msg = f"HTTP {response.status_code}"

        elif provider.provider_type == LLMProviderType.OLLAMA:
            base_url = provider.api_base or "http://localhost:11434"
            self._validate_provider_api_base(base_url)
            trimmed_url = base_url.rstrip("/")
            if trimmed_url.endswith("/v1"):
                # OpenAI-compatible endpoint
                response = await client.get(f"{trimmed_url}/models", timeout=10.0)
            else:
                # Native Ollama API
                response = await client.get(f"{trimmed_url}/api/tags", timeout=10.0)
            if response.status_code == 200:
                status = HealthStatus.HEALTHY
            else:
                status = HealthStatus.UNHEALTHY
                error_msg = f"HTTP {response.status_code}"

        else:
            # Generic check - just verify connectivity
            if provider.api_base:
                self._validate_provider_api_base(provider.api_base)
                response = await client.get(provider.api_base, timeout=5.0)
                if response.status_code < 500:
                    status = HealthStatus.HEALTHY
                else:
                    status = HealthStatus.UNHEALTHY
                    # Report the failing status, as the other branches do.
                    error_msg = f"HTTP {response.status_code}"
            else:
                status = HealthStatus.UNKNOWN
                error_msg = "No API base URL configured"

    except ValueError as e:
        # Includes LLMProviderValidationError, which subclasses ValueError.
        status = HealthStatus.UNHEALTHY
        error_msg = str(e)
    except httpx.TimeoutException:
        status = HealthStatus.UNHEALTHY
        error_msg = "Connection timeout"
    except httpx.RequestError as e:
        status = HealthStatus.UNHEALTHY
        error_msg = f"Connection error: {str(e)}"
    except Exception as e:
        # Deliberate catch-all: a health probe reports failures instead of raising.
        status = HealthStatus.UNHEALTHY
        error_msg = f"Error: {str(e)}"

    end_time = datetime.now(timezone.utc)
    response_time_ms = (end_time - start_time).total_seconds() * 1000

    # Update provider health status
    provider.health_status = status.value
    provider.last_health_check = end_time
    db.commit()

    return ProviderHealthCheck(
        provider_id=provider.id,
        provider_name=provider.name,
        provider_type=provider.provider_type,
        status=status,
        response_time_ms=response_time_ms,
        error=error_msg,
        checked_at=end_time,
    )

987 

def to_provider_response(
    self,
    provider: LLMProvider,
    model_count: int = 0,
) -> LLMProviderResponse:
    """Build the API response schema for a provider record.

    All listed attributes are copied from the provider unchanged, except
    ``config``, which is first passed through
    ``sanitize_provider_config_for_response``. ``model_count`` is not read
    from the provider; it is taken from the argument.

    Args:
        provider: LLMProvider instance.
        model_count: Number of models for this provider.

    Returns:
        LLMProviderResponse instance.
    """
    # Attributes copied verbatim from the provider onto the response.
    copied_fields = (
        "id",
        "name",
        "slug",
        "description",
        "provider_type",
        "api_base",
        "api_version",
        "default_model",
        "default_temperature",
        "default_max_tokens",
        "enabled",
        "health_status",
        "last_health_check",
        "plugin_ids",
        "created_at",
        "updated_at",
        "created_by",
        "modified_by",
    )
    payload = {field_name: getattr(provider, field_name) for field_name in copied_fields}

    # The stored config is never returned as-is; it goes through the sanitizer first.
    payload["config"] = sanitize_provider_config_for_response(provider.config)
    payload["model_count"] = model_count

    return LLMProviderResponse(**payload)

1024 

def to_model_response(
    self,
    model: LLMModel,
    provider: Optional[LLMProvider] = None,
) -> LLMModelResponse:
    """Build the API response schema for a model record.

    The listed model attributes are copied unchanged. ``provider_name``
    and ``provider_type`` come from ``provider`` when one is given and are
    ``None`` otherwise.

    Args:
        model: LLMModel instance.
        provider: Optional provider for name/type info.

    Returns:
        LLMModelResponse instance.
    """
    # Attributes copied verbatim from the model onto the response.
    copied_fields = (
        "id",
        "provider_id",
        "model_id",
        "model_name",
        "model_alias",
        "description",
        "supports_chat",
        "supports_streaming",
        "supports_function_calling",
        "supports_vision",
        "context_window",
        "max_output_tokens",
        "enabled",
        "deprecated",
        "created_at",
        "updated_at",
    )
    payload = {field_name: getattr(model, field_name) for field_name in copied_fields}

    # Provider details are optional; a falsy provider leaves both fields as None.
    if provider:
        payload["provider_name"] = provider.name
        payload["provider_type"] = provider.provider_type
    else:
        payload["provider_name"] = None
        payload["provider_type"] = None

    return LLMModelResponse(**payload)