Coverage for mcpgateway / llm_schemas.py: 100%

319 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/llm_schemas.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5 

6LLM Settings Pydantic Schemas. 

7This module provides Pydantic models for LLM provider configuration, model management, 

8and chat completions for the internal LLM Chat feature. 

9 

10The schemas support: 

11- LLM provider CRUD operations 

12- Model configuration and capabilities 

13- Chat completion requests/responses (OpenAI-compatible) 

14- Embedding requests/responses 

15""" 

16 

17# Standard 

18from datetime import datetime 

19from enum import Enum 

20from typing import Any, Dict, List, Literal, Optional, Union 

21 

22# Third-Party 

23from pydantic import BaseModel, ConfigDict, Field, field_validator 

24 

25# First-Party 

26from mcpgateway.common.validators import SecurityValidator, validate_core_url 

27 

28# --------------------------------------------------------------------------- 

29# Enums 

30# --------------------------------------------------------------------------- 

31 

32 

class LLMProviderTypeEnum(str, Enum):
    """Enumeration of supported LLM provider types.

    Inherits from ``str`` so members compare equal to (and serialize as)
    their plain string values.
    """

    OPENAI = "openai"
    AZURE_OPENAI = "azure_openai"
    ANTHROPIC = "anthropic"
    BEDROCK = "bedrock"
    GOOGLE_VERTEX = "google_vertex"
    WATSONX = "watsonx"
    OLLAMA = "ollama"
    OPENAI_COMPATIBLE = "openai_compatible"
    COHERE = "cohere"
    MISTRAL = "mistral"
    GROQ = "groq"
    TOGETHER = "together"

48 

49 

class HealthStatus(str, Enum):
    """Health status values for LLM providers.

    String-valued enum; ``UNKNOWN`` matches the default used by
    ``LLMProviderResponse.health_status``.
    """

    HEALTHY = "healthy"
    UNHEALTHY = "unhealthy"
    UNKNOWN = "unknown"

56 

57 

class RequestStatus(str, Enum):
    """Request processing status (string-valued for direct serialization)."""

    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"

65 

66 

class RequestType(str, Enum):
    """Types of LLM requests (string-valued for direct serialization)."""

    CHAT = "chat"
    COMPLETION = "completion"
    EMBEDDING = "embedding"

73 

74 

75# --------------------------------------------------------------------------- 

76# LLM Provider Schemas 

77# --------------------------------------------------------------------------- 

78 

79 

class LLMProviderBase(BaseModel):
    """Common fields shared by LLM provider schemas.

    Field validators sanitize the user-supplied ``name`` and
    ``description``, validate ``api_base`` as a URL, and bound the
    nesting depth of the free-form ``config`` mapping.
    """

    name: str = Field(..., min_length=1, max_length=255, description="Display name for the provider")
    description: Optional[str] = Field(None, max_length=2000, description="Optional description")
    provider_type: LLMProviderTypeEnum = Field(..., description="Type of LLM provider")
    api_base: Optional[str] = Field(None, max_length=512, description="Base URL for API requests")
    api_version: Optional[str] = Field(None, max_length=50, description="API version (for Azure OpenAI)")
    config: Dict[str, Any] = Field(default_factory=dict, description="Provider-specific configuration")
    default_model: Optional[str] = Field(None, max_length=255, description="Default model ID")
    default_temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Default temperature")
    default_max_tokens: Optional[int] = Field(None, ge=1, description="Default max tokens")
    enabled: bool = Field(default=True, description="Whether provider is enabled")
    plugin_ids: List[str] = Field(default_factory=list, description="Attached plugin IDs")

    @field_validator("name")
    @classmethod
    def _validate_name(cls, value: str) -> str:
        """Run the provider name through the security name validator.

        Args:
            value: Raw name value.

        Returns:
            str: Validated name.
        """
        return SecurityValidator.validate_name(value, "Provider name")

    @field_validator("description")
    @classmethod
    def _validate_description(cls, value: Optional[str]) -> Optional[str]:
        """Sanitize the description for safe display; ``None`` passes through.

        Args:
            value: Raw description value.

        Returns:
            Optional[str]: Sanitized description, or ``None`` unchanged.
        """
        return value if value is None else SecurityValidator.sanitize_display_text(value, "Description")

    @field_validator("api_base")
    @classmethod
    def _validate_api_base(cls, value: Optional[str]) -> Optional[str]:
        """Validate the provider API base URL; ``None`` passes through.

        Args:
            value: Raw URL value.

        Returns:
            Optional[str]: Validated URL, or ``None`` unchanged.
        """
        return value if value is None else validate_core_url(value, "Provider API base URL")

    @field_validator("config")
    @classmethod
    def _validate_config(cls, value: Dict[str, Any]) -> Dict[str, Any]:
        """Reject provider configs nested beyond the allowed JSON depth.

        Args:
            value: Config dictionary.

        Returns:
            Dict[str, Any]: The unchanged config once depth-checked.
        """
        SecurityValidator.validate_json_depth(value)
        return value

    def validate_provider_config(self) -> None:
        """Check that provider-type-specific required config fields are present.

        Raises:
            ValueError: If a required provider-specific field is missing.
        """
        # Imported lazily to break a circular import at module load time.
        # First-Party
        from mcpgateway.llm_provider_configs import get_provider_config  # pylint: disable=import-outside-toplevel

        ptype = self.provider_type.value if isinstance(self.provider_type, LLMProviderTypeEnum) else self.provider_type
        spec = get_provider_config(ptype)
        if not spec:
            # Unknown provider types declare no requirements — nothing to check.
            return

        # Fail on the first required field that is absent from config.
        missing = [cfg_field.name for cfg_field in spec.config_fields if cfg_field.required and cfg_field.name not in self.config]
        if missing:
            raise ValueError(f"Required configuration field '{missing[0]}' missing for provider type '{self.provider_type}'")

170 

171 

class LLMProviderCreate(LLMProviderBase):
    """Schema for creating a new LLM provider.

    Extends the base schema with a write-only API key.
    """

    # Accepted in plaintext on create; the description notes it is encrypted
    # before storage.
    api_key: Optional[str] = Field(None, description="API key (will be encrypted)")

176 

177 

class LLMProviderUpdate(BaseModel):
    """Partial-update schema for an LLM provider.

    Every field is optional; ``None`` means "leave unchanged". The same
    sanitization rules as ``LLMProviderBase`` apply to fields that are set.
    """

    name: Optional[str] = Field(None, min_length=1, max_length=255)
    description: Optional[str] = Field(None, max_length=2000)
    provider_type: Optional[LLMProviderTypeEnum] = None
    api_key: Optional[str] = Field(None, description="API key (will be encrypted)")
    api_base: Optional[str] = Field(None, max_length=512)
    api_version: Optional[str] = Field(None, max_length=50)
    config: Optional[Dict[str, Any]] = None
    default_model: Optional[str] = Field(None, max_length=255)
    default_temperature: Optional[float] = Field(None, ge=0.0, le=2.0)
    default_max_tokens: Optional[int] = Field(None, ge=1)
    enabled: Optional[bool] = None
    plugin_ids: Optional[List[str]] = None

    @field_validator("name")
    @classmethod
    def _validate_name(cls, value: Optional[str]) -> Optional[str]:
        """Sanitize the provider name when set; ``None`` passes through.

        Args:
            value: Raw name value.

        Returns:
            Optional[str]: Validated name, or ``None`` unchanged.
        """
        return value if value is None else SecurityValidator.validate_name(value, "Provider name")

    @field_validator("description")
    @classmethod
    def _validate_description(cls, value: Optional[str]) -> Optional[str]:
        """Sanitize the description when set; ``None`` passes through.

        Args:
            value: Raw description value.

        Returns:
            Optional[str]: Sanitized description, or ``None`` unchanged.
        """
        return value if value is None else SecurityValidator.sanitize_display_text(value, "Description")

    @field_validator("api_base")
    @classmethod
    def _validate_api_base(cls, value: Optional[str]) -> Optional[str]:
        """Validate the API base URL when set; ``None`` passes through.

        Args:
            value: Raw URL value.

        Returns:
            Optional[str]: Validated URL, or ``None`` unchanged.
        """
        return value if value is None else validate_core_url(value, "Provider API base URL")

    @field_validator("config")
    @classmethod
    def _validate_config(cls, value: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Depth-check the config when set; ``None`` passes through.

        Args:
            value: Config dictionary.

        Returns:
            Optional[Dict[str, Any]]: The unchanged config once depth-checked.
        """
        if value is not None:
            SecurityValidator.validate_json_depth(value)
        return value

254 

255 

class LLMProviderResponse(BaseModel):
    """Schema for LLM provider response.

    Read model for a stored provider; ``from_attributes=True`` allows
    construction directly from ORM objects. Secrets (API key) are never
    included.
    """

    model_config = ConfigDict(from_attributes=True)

    id: str
    name: str
    slug: str
    description: Optional[str] = None
    provider_type: str
    api_base: Optional[str] = None
    api_version: Optional[str] = None
    config: Dict[str, Any] = Field(default_factory=dict)
    default_model: Optional[str] = None
    default_temperature: float = 0.7
    default_max_tokens: Optional[int] = None
    enabled: bool = True
    # Defaults to "unknown" until a health check has run (see HealthStatus).
    health_status: str = "unknown"
    last_health_check: Optional[datetime] = None
    plugin_ids: List[str] = Field(default_factory=list)
    created_at: datetime
    updated_at: datetime
    created_by: Optional[str] = None
    modified_by: Optional[str] = None
    model_count: int = Field(default=0, description="Number of models for this provider")

281 

282 

class LLMProviderListResponse(BaseModel):
    """Schema for paginated list of LLM providers.

    ``total`` is the full match count, independent of the returned page.
    """

    providers: List[LLMProviderResponse]
    total: int
    page: int = 1
    page_size: int = 50

290 

291 

292# --------------------------------------------------------------------------- 

293# LLM Model Schemas 

294# --------------------------------------------------------------------------- 

295 

296 

class LLMModelBase(BaseModel):
    """Common fields shared by LLM model schemas.

    Validators use display-text sanitization for ``model_id`` and
    ``model_name`` (provider identifiers carry punctuation such as colons,
    dots, and slashes) and strict name validation for ``model_alias``.
    """

    model_id: str = Field(..., min_length=1, max_length=255, description="Provider's model ID")
    model_name: str = Field(..., min_length=1, max_length=255, description="Display name")
    model_alias: Optional[str] = Field(None, max_length=255, description="Optional routing alias")
    description: Optional[str] = Field(None, max_length=2000, description="Model description")
    supports_chat: bool = Field(default=True, description="Supports chat completions")
    supports_streaming: bool = Field(default=True, description="Supports streaming")
    supports_function_calling: bool = Field(default=False, description="Supports function/tool calling")
    supports_vision: bool = Field(default=False, description="Supports vision/images")
    context_window: Optional[int] = Field(None, ge=1, description="Max context tokens")
    max_output_tokens: Optional[int] = Field(None, ge=1, description="Max output tokens")
    enabled: bool = Field(default=True, description="Whether model is enabled")
    deprecated: bool = Field(default=False, description="Whether model is deprecated")

    @field_validator("model_id")
    @classmethod
    def _validate_model_id(cls, value: str) -> str:
        """Sanitize the model ID while allowing provider punctuation.

        IDs such as ``llama3.2:latest`` contain colons/dots/slashes, so
        display-text sanitization (which blocks HTML/script injection) is
        used instead of the strict tool-name pattern.

        Args:
            value: Raw model ID.

        Returns:
            str: Sanitized model ID.
        """
        return SecurityValidator.sanitize_display_text(value, "Model ID")

    @field_validator("model_name")
    @classmethod
    def _validate_model_name(cls, value: str) -> str:
        """Sanitize the display name while allowing display punctuation.

        Names such as ``GPT-4o (Latest)`` may contain parentheses and
        colons, so display-text sanitization is used instead of the strict
        name pattern.

        Args:
            value: Raw model name.

        Returns:
            str: Sanitized name.
        """
        return SecurityValidator.sanitize_display_text(value, "Model name")

    @field_validator("model_alias")
    @classmethod
    def _validate_model_alias(cls, value: Optional[str]) -> Optional[str]:
        """Strictly validate the routing alias; ``None`` passes through.

        Args:
            value: Raw alias value.

        Returns:
            Optional[str]: Validated alias, or ``None`` unchanged.
        """
        return value if value is None else SecurityValidator.validate_name(value, "Model alias")

    @field_validator("description")
    @classmethod
    def _validate_description(cls, value: Optional[str]) -> Optional[str]:
        """Sanitize the description for safe display; ``None`` passes through.

        Args:
            value: Raw description value.

        Returns:
            Optional[str]: Sanitized description, or ``None`` unchanged.
        """
        return value if value is None else SecurityValidator.sanitize_display_text(value, "Description")

377 

378 

class LLMModelCreate(LLMModelBase):
    """Schema for creating a new LLM model.

    Extends the base schema with the owning provider's ID.
    """

    provider_id: str = Field(..., description="Provider ID this model belongs to")

383 

384 

class LLMModelUpdate(BaseModel):
    """Partial-update schema for an LLM model.

    Every field is optional; ``None`` means "leave unchanged". Fields that
    are set follow the same sanitization rules as ``LLMModelBase``.
    """

    model_id: Optional[str] = Field(None, min_length=1, max_length=255)
    model_name: Optional[str] = Field(None, min_length=1, max_length=255)
    model_alias: Optional[str] = Field(None, max_length=255)
    description: Optional[str] = Field(None, max_length=2000)
    supports_chat: Optional[bool] = None
    supports_streaming: Optional[bool] = None
    supports_function_calling: Optional[bool] = None
    supports_vision: Optional[bool] = None
    context_window: Optional[int] = Field(None, ge=1)
    max_output_tokens: Optional[int] = Field(None, ge=1)
    enabled: Optional[bool] = None
    deprecated: Optional[bool] = None

    @field_validator("model_id")
    @classmethod
    def _validate_model_id(cls, value: Optional[str]) -> Optional[str]:
        """Sanitize the model ID when set; ``None`` passes through.

        Args:
            value: Raw model ID.

        Returns:
            Optional[str]: Sanitized model ID, or ``None`` unchanged.
        """
        return value if value is None else SecurityValidator.sanitize_display_text(value, "Model ID")

    @field_validator("model_name")
    @classmethod
    def _validate_model_name(cls, value: Optional[str]) -> Optional[str]:
        """Sanitize the display name when set; ``None`` passes through.

        Args:
            value: Raw model name.

        Returns:
            Optional[str]: Sanitized name, or ``None`` unchanged.
        """
        return value if value is None else SecurityValidator.sanitize_display_text(value, "Model name")

    @field_validator("model_alias")
    @classmethod
    def _validate_model_alias(cls, value: Optional[str]) -> Optional[str]:
        """Strictly validate the alias when set; ``None`` passes through.

        Args:
            value: Raw alias value.

        Returns:
            Optional[str]: Validated alias, or ``None`` unchanged.
        """
        return value if value is None else SecurityValidator.validate_name(value, "Model alias")

    @field_validator("description")
    @classmethod
    def _validate_description(cls, value: Optional[str]) -> Optional[str]:
        """Sanitize the description when set; ``None`` passes through.

        Args:
            value: Raw description value.

        Returns:
            Optional[str]: Sanitized description, or ``None`` unchanged.
        """
        return value if value is None else SecurityValidator.sanitize_display_text(value, "Description")

460 

461 

class LLMModelResponse(BaseModel):
    """Schema for LLM model response.

    Read model for a stored model; ``from_attributes=True`` allows
    construction directly from ORM objects. ``provider_name`` and
    ``provider_type`` are denormalized for display convenience.
    """

    model_config = ConfigDict(from_attributes=True)

    id: str
    provider_id: str
    model_id: str
    model_name: str
    model_alias: Optional[str] = None
    description: Optional[str] = None
    supports_chat: bool = True
    supports_streaming: bool = True
    supports_function_calling: bool = False
    supports_vision: bool = False
    context_window: Optional[int] = None
    max_output_tokens: Optional[int] = None
    enabled: bool = True
    deprecated: bool = False
    created_at: datetime
    updated_at: datetime
    provider_name: Optional[str] = Field(None, description="Provider name for display")
    provider_type: Optional[str] = Field(None, description="Provider type for display")

485 

486 

class LLMModelListResponse(BaseModel):
    """Schema for paginated list of LLM models.

    ``total`` is the full match count, independent of the returned page.
    """

    models: List[LLMModelResponse]
    total: int
    page: int = 1
    page_size: int = 50

494 

495 

496# --------------------------------------------------------------------------- 

497# Chat Completion Schemas (OpenAI-compatible) 

498# --------------------------------------------------------------------------- 

499 

500 

class FunctionDefinition(BaseModel):
    """Function definition for tool calling.

    Mirrors the OpenAI function object: name, optional description, and a
    JSON Schema describing the accepted arguments.
    """

    name: str = Field(..., description="Function name")
    description: Optional[str] = Field(None, description="Function description")
    parameters: Dict[str, Any] = Field(default_factory=dict, description="JSON Schema for parameters")

507 

508 

class ToolDefinition(BaseModel):
    """Tool definition for function calling.

    ``type`` is fixed to ``"function"`` — the only tool type modeled here.
    """

    type: Literal["function"] = "function"
    function: FunctionDefinition

514 

515 

class ChatMessage(BaseModel):
    """A single chat message.

    ``content`` may be ``None`` (e.g. an assistant turn consisting only of
    ``tool_calls``). ``tool_call_id`` links a ``tool``-role message back to
    the call it answers.
    """

    role: Literal["system", "user", "assistant", "tool"] = Field(..., description="Message role")
    content: Optional[str] = Field(None, description="Message content")
    name: Optional[str] = Field(None, description="Optional name for the participant")
    tool_calls: Optional[List[Dict[str, Any]]] = Field(None, description="Tool calls made by assistant")
    tool_call_id: Optional[str] = Field(None, description="ID of tool call this message responds to")

524 

525 

class ChatCompletionRequest(BaseModel):
    """Request for chat completions (OpenAI-compatible).

    Optional sampling parameters default to ``None`` so that unset values
    can be omitted and provider defaults apply.
    """

    model: str = Field(..., description="Model ID to use")
    messages: List[ChatMessage] = Field(..., min_length=1, description="Conversation messages")
    temperature: Optional[float] = Field(None, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: Optional[int] = Field(None, ge=1, description="Maximum tokens to generate")
    stream: bool = Field(default=False, description="Enable streaming response")
    tools: Optional[List[ToolDefinition]] = Field(None, description="Available tools")
    # Either a string mode (e.g. "auto") or an explicit tool-selection object.
    tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(None, description="Tool choice preference")
    top_p: Optional[float] = Field(None, ge=0.0, le=1.0, description="Nucleus sampling")
    frequency_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0, description="Frequency penalty")
    presence_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0, description="Presence penalty")
    stop: Optional[Union[str, List[str]]] = Field(None, description="Stop sequences")
    user: Optional[str] = Field(None, description="User identifier")

541 

542 

class UsageStats(BaseModel):
    """Token usage statistics (OpenAI-style prompt/completion/total counts)."""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0

549 

550 

class ChatChoice(BaseModel):
    """A single chat completion choice.

    ``finish_reason`` is ``None`` while generation is still in progress.
    """

    index: int = 0
    message: ChatMessage
    finish_reason: Optional[str] = None

557 

558 

class ChatCompletionResponse(BaseModel):
    """Response from chat completions.

    OpenAI-compatible shape; ``object`` is fixed to ``"chat.completion"``.
    """

    id: str = Field(..., description="Unique response ID")
    object: str = "chat.completion"
    created: int = Field(..., description="Unix timestamp")
    model: str = Field(..., description="Model used")
    choices: List[ChatChoice]
    usage: Optional[UsageStats] = None

568 

569 

class ChatCompletionChunk(BaseModel):
    """Streaming chunk for chat completions.

    Choices are kept as raw dicts (delta payloads vary between providers)
    rather than typed ``ChatChoice`` objects.
    """

    id: str
    object: str = "chat.completion.chunk"
    created: int
    model: str
    choices: List[Dict[str, Any]]

578 

579 

580# --------------------------------------------------------------------------- 

581# Embedding Schemas 

582# --------------------------------------------------------------------------- 

583 

584 

class EmbeddingRequest(BaseModel):
    """Request for embeddings.

    ``input`` accepts a single string or a batch of strings, matching the
    OpenAI embeddings API.
    """

    model: str = Field(..., description="Model ID to use")
    input: Union[str, List[str]] = Field(..., description="Text to embed")
    encoding_format: Optional[Literal["float", "base64"]] = Field(None, description="Encoding format")
    user: Optional[str] = Field(None, description="User identifier")

592 

593 

class EmbeddingData(BaseModel):
    """A single embedding result.

    ``index`` positions this vector within a batch request's input list.
    """

    object: str = "embedding"
    embedding: List[float]
    index: int = 0

600 

601 

class EmbeddingResponse(BaseModel):
    """Response from embeddings.

    OpenAI-compatible list envelope; unlike chat, ``usage`` is required.
    """

    object: str = "list"
    data: List[EmbeddingData]
    model: str
    usage: UsageStats

609 

610 

611# --------------------------------------------------------------------------- 

612# Gateway Models Response (for LLM Chat dropdown) 

613# --------------------------------------------------------------------------- 

614 

615 

class GatewayModelInfo(BaseModel):
    """Simplified model info for the LLM Chat dropdown.

    Flattens model and provider identity plus the capability flags the UI
    needs; ``from_attributes=True`` allows construction from ORM objects.
    """

    model_config = ConfigDict(from_attributes=True)

    id: str = Field(..., description="Unique model ID")
    model_id: str = Field(..., description="Provider's model identifier")
    model_name: str = Field(..., description="Display name")
    provider_id: str = Field(..., description="Provider ID")
    provider_name: str = Field(..., description="Provider display name")
    provider_type: str = Field(..., description="Provider type")
    supports_streaming: bool = True
    supports_function_calling: bool = False
    supports_vision: bool = False

630 

631 

class GatewayModelsResponse(BaseModel):
    """Response for /llmchat/gateway/models endpoint.

    ``count`` reports the number of entries in ``models``.
    """

    models: List[GatewayModelInfo]
    count: int

637 

638 

639# --------------------------------------------------------------------------- 

640# Health Check Schemas 

641# --------------------------------------------------------------------------- 

642 

643 

class ProviderHealthCheck(BaseModel):
    """Result of a provider health check.

    ``error`` is populated only when ``status`` indicates a failure;
    ``response_time_ms`` may be absent if the check never completed.
    """

    provider_id: str
    provider_name: str
    provider_type: str
    status: HealthStatus
    response_time_ms: Optional[float] = None
    error: Optional[str] = None
    checked_at: datetime