Coverage for mcpgateway / llm_schemas.py: 100%
319 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/llm_schemas.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
6LLM Settings Pydantic Schemas.
7This module provides Pydantic models for LLM provider configuration, model management,
8and chat completions for the internal LLM Chat feature.
10The schemas support:
11- LLM provider CRUD operations
12- Model configuration and capabilities
13- Chat completion requests/responses (OpenAI-compatible)
14- Embedding requests/responses
15"""
17# Standard
18from datetime import datetime
19from enum import Enum
20from typing import Any, Dict, List, Literal, Optional, Union
22# Third-Party
23from pydantic import BaseModel, ConfigDict, Field, field_validator
25# First-Party
26from mcpgateway.common.validators import SecurityValidator, validate_core_url
28# ---------------------------------------------------------------------------
29# Enums
30# ---------------------------------------------------------------------------
class LLMProviderTypeEnum(str, Enum):
    """Enumeration of supported LLM provider types.

    Subclasses ``str`` so members compare equal to their lowercase string
    values and serialize directly in API payloads.
    """

    OPENAI = "openai"
    AZURE_OPENAI = "azure_openai"
    ANTHROPIC = "anthropic"
    BEDROCK = "bedrock"
    GOOGLE_VERTEX = "google_vertex"
    WATSONX = "watsonx"
    OLLAMA = "ollama"
    OPENAI_COMPATIBLE = "openai_compatible"
    COHERE = "cohere"
    MISTRAL = "mistral"
    GROQ = "groq"
    TOGETHER = "together"
class HealthStatus(str, Enum):
    """Health status values for LLM providers.

    Used by ``ProviderHealthCheck.status`` to report the result of a
    provider health probe.
    """

    HEALTHY = "healthy"
    UNHEALTHY = "unhealthy"
    # Default state before the first health check has run.
    UNKNOWN = "unknown"
class RequestStatus(str, Enum):
    """Request processing status for LLM requests."""

    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
class RequestType(str, Enum):
    """Types of LLM requests handled by the gateway."""

    CHAT = "chat"
    COMPLETION = "completion"
    EMBEDDING = "embedding"
75# ---------------------------------------------------------------------------
76# LLM Provider Schemas
77# ---------------------------------------------------------------------------
class LLMProviderBase(BaseModel):
    """Common fields and validation shared by LLM provider schemas.

    All user-supplied text fields are sanitized against XSS/injection, the
    API base URL is validated, and the free-form ``config`` dict is checked
    for excessive nesting.
    """

    name: str = Field(..., min_length=1, max_length=255, description="Display name for the provider")
    description: Optional[str] = Field(None, max_length=2000, description="Optional description")
    provider_type: LLMProviderTypeEnum = Field(..., description="Type of LLM provider")
    api_base: Optional[str] = Field(None, max_length=512, description="Base URL for API requests")
    api_version: Optional[str] = Field(None, max_length=50, description="API version (for Azure OpenAI)")
    config: Dict[str, Any] = Field(default_factory=dict, description="Provider-specific configuration")
    default_model: Optional[str] = Field(None, max_length=255, description="Default model ID")
    default_temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Default temperature")
    default_max_tokens: Optional[int] = Field(None, ge=1, description="Default max tokens")
    enabled: bool = Field(default=True, description="Whether provider is enabled")
    plugin_ids: List[str] = Field(default_factory=list, description="Attached plugin IDs")

    @field_validator("name")
    @classmethod
    def _clean_name(cls, value: str) -> str:
        """Sanitize the provider name against XSS and injection.

        Args:
            value: Raw name value.

        Returns:
            str: Validated name.
        """
        return SecurityValidator.validate_name(value, "Provider name")

    @field_validator("description")
    @classmethod
    def _clean_description(cls, value: Optional[str]) -> Optional[str]:
        """Sanitize the provider description for safe display.

        Args:
            value: Raw description value.

        Returns:
            Optional[str]: Sanitized description, or None when absent.
        """
        return value if value is None else SecurityValidator.sanitize_display_text(value, "Description")

    @field_validator("api_base")
    @classmethod
    def _clean_api_base(cls, value: Optional[str]) -> Optional[str]:
        """Validate the provider API base URL.

        Args:
            value: Raw URL value.

        Returns:
            Optional[str]: Validated URL, or None when absent.
        """
        return value if value is None else validate_core_url(value, "Provider API base URL")

    @field_validator("config")
    @classmethod
    def _clean_config(cls, value: Dict[str, Any]) -> Dict[str, Any]:
        """Reject excessively nested provider config.

        Args:
            value: Config dictionary.

        Returns:
            Dict[str, Any]: Validated config.
        """
        SecurityValidator.validate_json_depth(value)
        return value

    def validate_provider_config(self) -> None:
        """Validate provider-specific configuration based on provider type.

        Looks up the provider definition for ``provider_type`` and ensures
        every required config field is present in ``config``.

        Raises:
            ValueError: If required provider-specific fields are missing.
        """
        # Imported lazily to avoid a circular dependency.
        # First-Party
        from mcpgateway.llm_provider_configs import get_provider_config  # pylint: disable=import-outside-toplevel

        type_key = self.provider_type.value if isinstance(self.provider_type, LLMProviderTypeEnum) else self.provider_type
        provider_def = get_provider_config(type_key)
        if not provider_def:
            # Unknown provider type: nothing to enforce.
            return

        for field_def in provider_def.config_fields:
            if field_def.required and field_def.name not in self.config:
                raise ValueError(f"Required configuration field '{field_def.name}' missing for provider type '{self.provider_type}'")
class LLMProviderCreate(LLMProviderBase):
    """Schema for creating a new LLM provider.

    Extends the base provider fields with a write-only API key; the key is
    encrypted before storage (see field description) and never echoed back
    in responses.
    """

    api_key: Optional[str] = Field(None, description="API key (will be encrypted)")
class LLMProviderUpdate(BaseModel):
    """Schema for partially updating an LLM provider.

    Every field is optional; only fields present in the payload are applied.
    Text fields are sanitized with the same validators as creation.
    """

    name: Optional[str] = Field(None, min_length=1, max_length=255)
    description: Optional[str] = Field(None, max_length=2000)
    provider_type: Optional[LLMProviderTypeEnum] = None
    api_key: Optional[str] = Field(None, description="API key (will be encrypted)")
    api_base: Optional[str] = Field(None, max_length=512)
    api_version: Optional[str] = Field(None, max_length=50)
    config: Optional[Dict[str, Any]] = None
    default_model: Optional[str] = Field(None, max_length=255)
    default_temperature: Optional[float] = Field(None, ge=0.0, le=2.0)
    default_max_tokens: Optional[int] = Field(None, ge=1)
    enabled: Optional[bool] = None
    plugin_ids: Optional[List[str]] = None

    @field_validator("name")
    @classmethod
    def _clean_name(cls, value: Optional[str]) -> Optional[str]:
        """Sanitize the provider name against XSS and injection.

        Args:
            value: Raw name value.

        Returns:
            Optional[str]: Validated name, or None when absent.
        """
        return value if value is None else SecurityValidator.validate_name(value, "Provider name")

    @field_validator("description")
    @classmethod
    def _clean_description(cls, value: Optional[str]) -> Optional[str]:
        """Sanitize the provider description for safe display.

        Args:
            value: Raw description value.

        Returns:
            Optional[str]: Sanitized description, or None when absent.
        """
        return value if value is None else SecurityValidator.sanitize_display_text(value, "Description")

    @field_validator("api_base")
    @classmethod
    def _clean_api_base(cls, value: Optional[str]) -> Optional[str]:
        """Validate the provider API base URL.

        Args:
            value: Raw URL value.

        Returns:
            Optional[str]: Validated URL, or None when absent.
        """
        return value if value is None else validate_core_url(value, "Provider API base URL")

    @field_validator("config")
    @classmethod
    def _clean_config(cls, value: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Reject excessively nested provider config.

        Args:
            value: Config dictionary.

        Returns:
            Optional[Dict[str, Any]]: Validated config, or None when absent.
        """
        if value is not None:
            SecurityValidator.validate_json_depth(value)
        return value
class LLMProviderResponse(BaseModel):
    """Schema for LLM provider response.

    Read-only view of a stored provider. Note the API key is deliberately
    absent from this schema.
    """

    # from_attributes allows construction directly from ORM objects.
    model_config = ConfigDict(from_attributes=True)

    id: str
    name: str
    slug: str
    description: Optional[str] = None
    provider_type: str
    api_base: Optional[str] = None
    api_version: Optional[str] = None
    config: Dict[str, Any] = Field(default_factory=dict)
    default_model: Optional[str] = None
    default_temperature: float = 0.7
    default_max_tokens: Optional[int] = None
    enabled: bool = True
    # Mirrors HealthStatus values; defaults to "unknown" before first check.
    health_status: str = "unknown"
    last_health_check: Optional[datetime] = None
    plugin_ids: List[str] = Field(default_factory=list)
    created_at: datetime
    updated_at: datetime
    created_by: Optional[str] = None
    modified_by: Optional[str] = None
    model_count: int = Field(default=0, description="Number of models for this provider")
class LLMProviderListResponse(BaseModel):
    """Schema for paginated list of LLM providers."""

    providers: List[LLMProviderResponse]
    # Total number of providers across all pages, not just this page.
    total: int
    page: int = 1
    page_size: int = 50
292# ---------------------------------------------------------------------------
293# LLM Model Schemas
294# ---------------------------------------------------------------------------
class LLMModelBase(BaseModel):
    """Common fields and validation shared by LLM model schemas.

    ``model_id`` and ``model_name`` use lenient display-text sanitization
    because provider-issued identifiers contain punctuation; the routing
    ``model_alias`` uses the strict name validator.
    """

    model_id: str = Field(..., min_length=1, max_length=255, description="Provider's model ID")
    model_name: str = Field(..., min_length=1, max_length=255, description="Display name")
    model_alias: Optional[str] = Field(None, max_length=255, description="Optional routing alias")
    description: Optional[str] = Field(None, max_length=2000, description="Model description")
    supports_chat: bool = Field(default=True, description="Supports chat completions")
    supports_streaming: bool = Field(default=True, description="Supports streaming")
    supports_function_calling: bool = Field(default=False, description="Supports function/tool calling")
    supports_vision: bool = Field(default=False, description="Supports vision/images")
    context_window: Optional[int] = Field(None, ge=1, description="Max context tokens")
    max_output_tokens: Optional[int] = Field(None, ge=1, description="Max output tokens")
    enabled: bool = Field(default=True, description="Whether model is enabled")
    deprecated: bool = Field(default=False, description="Whether model is deprecated")

    @field_validator("model_id")
    @classmethod
    def _clean_model_id(cls, value: str) -> str:
        """Sanitize the model ID against XSS while allowing provider punctuation.

        Provider model IDs commonly contain colons (``llama3.2:latest``),
        dots, slashes, and other punctuation, so display-text sanitization
        (rejecting HTML tags / script injection) is used instead of the
        strict tool-name pattern.

        Args:
            value: Raw model ID.

        Returns:
            str: Sanitized model ID.
        """
        return SecurityValidator.sanitize_display_text(value, "Model ID")

    @field_validator("model_name")
    @classmethod
    def _clean_model_name(cls, value: str) -> str:
        """Sanitize the display name against XSS while allowing display punctuation.

        Provider model names may contain parentheses, colons, and other
        punctuation (e.g. ``GPT-4o (Latest)``), so display-text
        sanitization is used instead of the strict name pattern.

        Args:
            value: Raw model name.

        Returns:
            str: Sanitized name.
        """
        return SecurityValidator.sanitize_display_text(value, "Model name")

    @field_validator("model_alias")
    @classmethod
    def _clean_model_alias(cls, value: Optional[str]) -> Optional[str]:
        """Sanitize the model alias against XSS and injection.

        Args:
            value: Raw alias value.

        Returns:
            Optional[str]: Validated alias, or None when absent.
        """
        return value if value is None else SecurityValidator.validate_name(value, "Model alias")

    @field_validator("description")
    @classmethod
    def _clean_description(cls, value: Optional[str]) -> Optional[str]:
        """Sanitize the model description for safe display.

        Args:
            value: Raw description value.

        Returns:
            Optional[str]: Sanitized description, or None when absent.
        """
        return value if value is None else SecurityValidator.sanitize_display_text(value, "Description")
class LLMModelCreate(LLMModelBase):
    """Schema for creating a new LLM model.

    Extends the base model fields with the owning provider's ID.
    """

    provider_id: str = Field(..., description="Provider ID this model belongs to")
class LLMModelUpdate(BaseModel):
    """Schema for partially updating an LLM model.

    Every field is optional; only fields present in the payload are applied.
    Text fields reuse the same sanitization rules as model creation.
    """

    model_id: Optional[str] = Field(None, min_length=1, max_length=255)
    model_name: Optional[str] = Field(None, min_length=1, max_length=255)
    model_alias: Optional[str] = Field(None, max_length=255)
    description: Optional[str] = Field(None, max_length=2000)
    supports_chat: Optional[bool] = None
    supports_streaming: Optional[bool] = None
    supports_function_calling: Optional[bool] = None
    supports_vision: Optional[bool] = None
    context_window: Optional[int] = Field(None, ge=1)
    max_output_tokens: Optional[int] = Field(None, ge=1)
    enabled: Optional[bool] = None
    deprecated: Optional[bool] = None

    @field_validator("model_id")
    @classmethod
    def _clean_model_id(cls, value: Optional[str]) -> Optional[str]:
        """Sanitize the model ID against XSS while allowing provider punctuation.

        Args:
            value: Raw model ID.

        Returns:
            Optional[str]: Sanitized model ID, or None when absent.
        """
        return value if value is None else SecurityValidator.sanitize_display_text(value, "Model ID")

    @field_validator("model_name")
    @classmethod
    def _clean_model_name(cls, value: Optional[str]) -> Optional[str]:
        """Sanitize the display name against XSS while allowing display punctuation.

        Args:
            value: Raw model name.

        Returns:
            Optional[str]: Sanitized name, or None when absent.
        """
        return value if value is None else SecurityValidator.sanitize_display_text(value, "Model name")

    @field_validator("model_alias")
    @classmethod
    def _clean_model_alias(cls, value: Optional[str]) -> Optional[str]:
        """Sanitize the model alias against XSS and injection.

        Args:
            value: Raw alias value.

        Returns:
            Optional[str]: Validated alias, or None when absent.
        """
        return value if value is None else SecurityValidator.validate_name(value, "Model alias")

    @field_validator("description")
    @classmethod
    def _clean_description(cls, value: Optional[str]) -> Optional[str]:
        """Sanitize the model description for safe display.

        Args:
            value: Raw description value.

        Returns:
            Optional[str]: Sanitized description, or None when absent.
        """
        return value if value is None else SecurityValidator.sanitize_display_text(value, "Description")
class LLMModelResponse(BaseModel):
    """Schema for LLM model response.

    Read-only view of a stored model, optionally enriched with the owning
    provider's name and type for display purposes.
    """

    # from_attributes allows construction directly from ORM objects.
    model_config = ConfigDict(from_attributes=True)

    id: str
    provider_id: str
    model_id: str
    model_name: str
    model_alias: Optional[str] = None
    description: Optional[str] = None
    supports_chat: bool = True
    supports_streaming: bool = True
    supports_function_calling: bool = False
    supports_vision: bool = False
    context_window: Optional[int] = None
    max_output_tokens: Optional[int] = None
    enabled: bool = True
    deprecated: bool = False
    created_at: datetime
    updated_at: datetime
    provider_name: Optional[str] = Field(None, description="Provider name for display")
    provider_type: Optional[str] = Field(None, description="Provider type for display")
class LLMModelListResponse(BaseModel):
    """Schema for paginated list of LLM models."""

    models: List[LLMModelResponse]
    # Total number of models across all pages, not just this page.
    total: int
    page: int = 1
    page_size: int = 50
496# ---------------------------------------------------------------------------
497# Chat Completion Schemas (OpenAI-compatible)
498# ---------------------------------------------------------------------------
class FunctionDefinition(BaseModel):
    """Function definition for tool calling (OpenAI-compatible shape)."""

    name: str = Field(..., description="Function name")
    description: Optional[str] = Field(None, description="Function description")
    parameters: Dict[str, Any] = Field(default_factory=dict, description="JSON Schema for parameters")
class ToolDefinition(BaseModel):
    """Tool definition for function calling.

    Only the ``function`` tool type is supported; ``type`` is pinned via
    a Literal.
    """

    type: Literal["function"] = "function"
    function: FunctionDefinition
class ChatMessage(BaseModel):
    """A single chat message (OpenAI-compatible shape).

    ``content`` may be None for assistant messages that carry only
    ``tool_calls``; ``tool_call_id`` links a ``tool``-role message back to
    the call it answers.
    """

    role: Literal["system", "user", "assistant", "tool"] = Field(..., description="Message role")
    content: Optional[str] = Field(None, description="Message content")
    name: Optional[str] = Field(None, description="Optional name for the participant")
    tool_calls: Optional[List[Dict[str, Any]]] = Field(None, description="Tool calls made by assistant")
    tool_call_id: Optional[str] = Field(None, description="ID of tool call this message responds to")
class ChatCompletionRequest(BaseModel):
    """Request for chat completions (OpenAI-compatible).

    Sampling fields left as None defer to provider/model defaults.
    """

    model: str = Field(..., description="Model ID to use")
    # At least one message is required.
    messages: List[ChatMessage] = Field(..., min_length=1, description="Conversation messages")
    temperature: Optional[float] = Field(None, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: Optional[int] = Field(None, ge=1, description="Maximum tokens to generate")
    stream: bool = Field(default=False, description="Enable streaming response")
    tools: Optional[List[ToolDefinition]] = Field(None, description="Available tools")
    # Either a string mode (e.g. "auto") or an explicit tool selection object.
    tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(None, description="Tool choice preference")
    top_p: Optional[float] = Field(None, ge=0.0, le=1.0, description="Nucleus sampling")
    frequency_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0, description="Frequency penalty")
    presence_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0, description="Presence penalty")
    stop: Optional[Union[str, List[str]]] = Field(None, description="Stop sequences")
    user: Optional[str] = Field(None, description="User identifier")
class UsageStats(BaseModel):
    """Token usage statistics (OpenAI-compatible shape)."""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
class ChatChoice(BaseModel):
    """A single chat completion choice."""

    index: int = 0
    message: ChatMessage
    # e.g. "stop", "length", "tool_calls" — provider-dependent; None while streaming.
    finish_reason: Optional[str] = None
class ChatCompletionResponse(BaseModel):
    """Response from chat completions (OpenAI-compatible)."""

    id: str = Field(..., description="Unique response ID")
    object: str = "chat.completion"
    created: int = Field(..., description="Unix timestamp")
    model: str = Field(..., description="Model used")
    choices: List[ChatChoice]
    usage: Optional[UsageStats] = None
class ChatCompletionChunk(BaseModel):
    """Streaming chunk for chat completions (OpenAI-compatible).

    Choices are kept as raw dicts since chunk deltas vary by provider.
    """

    id: str
    object: str = "chat.completion.chunk"
    created: int
    model: str
    choices: List[Dict[str, Any]]
580# ---------------------------------------------------------------------------
581# Embedding Schemas
582# ---------------------------------------------------------------------------
class EmbeddingRequest(BaseModel):
    """Request for embeddings (OpenAI-compatible)."""

    model: str = Field(..., description="Model ID to use")
    # A single string or a batch of strings to embed.
    input: Union[str, List[str]] = Field(..., description="Text to embed")
    encoding_format: Optional[Literal["float", "base64"]] = Field(None, description="Encoding format")
    user: Optional[str] = Field(None, description="User identifier")
class EmbeddingData(BaseModel):
    """A single embedding result."""

    object: str = "embedding"
    embedding: List[float]
    # Position of this result within the request's input batch.
    index: int = 0
class EmbeddingResponse(BaseModel):
    """Response from embeddings (OpenAI-compatible)."""

    object: str = "list"
    data: List[EmbeddingData]
    model: str
    usage: UsageStats
611# ---------------------------------------------------------------------------
612# Gateway Models Response (for LLM Chat dropdown)
613# ---------------------------------------------------------------------------
class GatewayModelInfo(BaseModel):
    """Simplified model info for the LLM Chat dropdown."""

    # from_attributes allows construction directly from ORM objects.
    model_config = ConfigDict(from_attributes=True)

    id: str = Field(..., description="Unique model ID")
    model_id: str = Field(..., description="Provider's model identifier")
    model_name: str = Field(..., description="Display name")
    provider_id: str = Field(..., description="Provider ID")
    provider_name: str = Field(..., description="Provider display name")
    provider_type: str = Field(..., description="Provider type")
    supports_streaming: bool = True
    supports_function_calling: bool = False
    supports_vision: bool = False
class GatewayModelsResponse(BaseModel):
    """Response for /llmchat/gateway/models endpoint."""

    models: List[GatewayModelInfo]
    # Number of models returned in ``models``.
    count: int
639# ---------------------------------------------------------------------------
640# Health Check Schemas
641# ---------------------------------------------------------------------------
class ProviderHealthCheck(BaseModel):
    """Result of a provider health check."""

    provider_id: str
    provider_name: str
    provider_type: str
    status: HealthStatus
    # Round-trip latency of the probe; None when the check did not complete.
    response_time_ms: Optional[float] = None
    # Error detail populated only for failed checks.
    error: Optional[str] = None
    checked_at: datetime