Coverage for mcpgateway / llm_schemas.py: 100%
228 statements
coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
# -*- coding: utf-8 -*-
"""Location: ./mcpgateway/llm_schemas.py
Copyright 2025
SPDX-License-Identifier: Apache-2.0

LLM Settings Pydantic Schemas.
This module provides Pydantic models for LLM provider configuration, model management,
and chat completions for the internal LLM Chat feature.

The schemas support:
- LLM provider CRUD operations
- Model configuration and capabilities
- Chat completion requests/responses (OpenAI-compatible)
- Embedding requests/responses
"""

# Standard
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union

# Third-Party
from pydantic import BaseModel, ConfigDict, Field

# ---------------------------------------------------------------------------
# Enums
# ---------------------------------------------------------------------------
class LLMProviderTypeEnum(str, Enum):
    """Enumeration of supported LLM provider types."""

    OPENAI = "openai"
    AZURE_OPENAI = "azure_openai"
    ANTHROPIC = "anthropic"
    BEDROCK = "bedrock"
    GOOGLE_VERTEX = "google_vertex"
    WATSONX = "watsonx"
    OLLAMA = "ollama"
    OPENAI_COMPATIBLE = "openai_compatible"
    COHERE = "cohere"
    MISTRAL = "mistral"
    GROQ = "groq"
    TOGETHER = "together"


class HealthStatus(str, Enum):
    """Health status values for LLM providers."""

    HEALTHY = "healthy"
    UNHEALTHY = "unhealthy"
    UNKNOWN = "unknown"


class RequestStatus(str, Enum):
    """Request processing status."""

    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"


class RequestType(str, Enum):
    """Types of LLM requests."""

    CHAT = "chat"
    COMPLETION = "completion"
    EMBEDDING = "embedding"
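# Usage sketch (illustrative, not part of the module): because these enums mix
# in ``str``, members compare equal to their raw string values, so they can be
# passed anywhere a plain provider-type or status string is expected:
#
#     >>> LLMProviderTypeEnum.OLLAMA == "ollama"
#     True
#     >>> HealthStatus("unknown") is HealthStatus.UNKNOWN
#     True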
# ---------------------------------------------------------------------------
# LLM Provider Schemas
# ---------------------------------------------------------------------------


class LLMProviderBase(BaseModel):
    """Base schema for LLM provider data."""

    name: str = Field(..., min_length=1, max_length=255, description="Display name for the provider")
    description: Optional[str] = Field(None, max_length=2000, description="Optional description")
    provider_type: LLMProviderTypeEnum = Field(..., description="Type of LLM provider")
    api_base: Optional[str] = Field(None, max_length=512, description="Base URL for API requests")
    api_version: Optional[str] = Field(None, max_length=50, description="API version (for Azure OpenAI)")
    config: Dict[str, Any] = Field(default_factory=dict, description="Provider-specific configuration")
    default_model: Optional[str] = Field(None, max_length=255, description="Default model ID")
    default_temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Default temperature")
    default_max_tokens: Optional[int] = Field(None, ge=1, description="Default max tokens")
    enabled: bool = Field(default=True, description="Whether provider is enabled")
    plugin_ids: List[str] = Field(default_factory=list, description="Attached plugin IDs")

    def validate_provider_config(self) -> None:
        """Validate provider-specific configuration based on provider type.

        Raises:
            ValueError: If required provider-specific fields are missing.
        """
        # Import here to avoid circular dependency
        # First-Party
        from mcpgateway.llm_provider_configs import get_provider_config  # pylint: disable=import-outside-toplevel

        provider_def = get_provider_config(self.provider_type.value if isinstance(self.provider_type, LLMProviderTypeEnum) else self.provider_type)
        if not provider_def:
            return

        # Validate required config fields
        for field_def in provider_def.config_fields:
            if field_def.required and field_def.name not in self.config:
                raise ValueError(f"Required configuration field '{field_def.name}' missing for provider type '{self.provider_type}'")
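# Usage sketch (illustrative): validate_provider_config() fails fast when a
# required provider-specific field is missing from ``config``. Which fields are
# required comes from mcpgateway.llm_provider_configs, so the 'api_key' name
# below is an assumption for the example, not a documented requirement:
#
#     >>> provider = LLMProviderBase(name="Team OpenAI", provider_type="openai")
#     >>> provider.validate_provider_config()
#     Traceback (most recent call last):
#         ...
#     ValueError: Required configuration field 'api_key' missing for provider type 'openai'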
class LLMProviderCreate(LLMProviderBase):
    """Schema for creating a new LLM provider."""

    api_key: Optional[str] = Field(None, description="API key (will be encrypted)")


class LLMProviderUpdate(BaseModel):
    """Schema for updating an LLM provider."""

    name: Optional[str] = Field(None, min_length=1, max_length=255)
    description: Optional[str] = Field(None, max_length=2000)
    provider_type: Optional[LLMProviderTypeEnum] = None
    api_key: Optional[str] = Field(None, description="API key (will be encrypted)")
    api_base: Optional[str] = Field(None, max_length=512)
    api_version: Optional[str] = Field(None, max_length=50)
    config: Optional[Dict[str, Any]] = None
    default_model: Optional[str] = Field(None, max_length=255)
    default_temperature: Optional[float] = Field(None, ge=0.0, le=2.0)
    default_max_tokens: Optional[int] = Field(None, ge=1)
    enabled: Optional[bool] = None
    plugin_ids: Optional[List[str]] = None
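# Usage sketch (illustrative): every LLMProviderUpdate field is optional, so a
# PATCH-style handler can tell "field not sent" apart from "explicitly set"
# with Pydantic v2's exclude_unset:
#
#     >>> update = LLMProviderUpdate(enabled=False)
#     >>> update.model_dump(exclude_unset=True)
#     {'enabled': False}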
class LLMProviderResponse(BaseModel):
    """Schema for LLM provider response."""

    model_config = ConfigDict(from_attributes=True)

    id: str
    name: str
    slug: str
    description: Optional[str] = None
    provider_type: str
    api_base: Optional[str] = None
    api_version: Optional[str] = None
    config: Dict[str, Any] = Field(default_factory=dict)
    default_model: Optional[str] = None
    default_temperature: float = 0.7
    default_max_tokens: Optional[int] = None
    enabled: bool = True
    health_status: str = "unknown"
    last_health_check: Optional[datetime] = None
    plugin_ids: List[str] = Field(default_factory=list)
    created_at: datetime
    updated_at: datetime
    created_by: Optional[str] = None
    modified_by: Optional[str] = None
    model_count: int = Field(default=0, description="Number of models for this provider")
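# Usage sketch (illustrative): from_attributes=True lets this schema be built
# directly from an ORM row; the SimpleNamespace stand-in below is an assumed
# substitute for the real database model:
#
#     >>> from types import SimpleNamespace
#     >>> from datetime import timezone
#     >>> row = SimpleNamespace(
#     ...     id="prov-1", name="Local Ollama", slug="local-ollama",
#     ...     provider_type="ollama", created_at=datetime.now(timezone.utc),
#     ...     updated_at=datetime.now(timezone.utc))
#     >>> LLMProviderResponse.model_validate(row).slug
#     'local-ollama'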
class LLMProviderListResponse(BaseModel):
    """Schema for paginated list of LLM providers."""

    providers: List[LLMProviderResponse]
    total: int
    page: int = 1
    page_size: int = 50


# ---------------------------------------------------------------------------
# LLM Model Schemas
# ---------------------------------------------------------------------------


class LLMModelBase(BaseModel):
    """Base schema for LLM model data."""

    model_id: str = Field(..., min_length=1, max_length=255, description="Provider's model ID")
    model_name: str = Field(..., min_length=1, max_length=255, description="Display name")
    model_alias: Optional[str] = Field(None, max_length=255, description="Optional routing alias")
    description: Optional[str] = Field(None, max_length=2000, description="Model description")
    supports_chat: bool = Field(default=True, description="Supports chat completions")
    supports_streaming: bool = Field(default=True, description="Supports streaming")
    supports_function_calling: bool = Field(default=False, description="Supports function/tool calling")
    supports_vision: bool = Field(default=False, description="Supports vision/images")
    context_window: Optional[int] = Field(None, ge=1, description="Max context tokens")
    max_output_tokens: Optional[int] = Field(None, ge=1, description="Max output tokens")
    enabled: bool = Field(default=True, description="Whether model is enabled")
    deprecated: bool = Field(default=False, description="Whether model is deprecated")


class LLMModelCreate(LLMModelBase):
    """Schema for creating a new LLM model."""

    provider_id: str = Field(..., description="Provider ID this model belongs to")


class LLMModelUpdate(BaseModel):
    """Schema for updating an LLM model."""

    model_id: Optional[str] = Field(None, min_length=1, max_length=255)
    model_name: Optional[str] = Field(None, min_length=1, max_length=255)
    model_alias: Optional[str] = Field(None, max_length=255)
    description: Optional[str] = Field(None, max_length=2000)
    supports_chat: Optional[bool] = None
    supports_streaming: Optional[bool] = None
    supports_function_calling: Optional[bool] = None
    supports_vision: Optional[bool] = None
    context_window: Optional[int] = Field(None, ge=1)
    max_output_tokens: Optional[int] = Field(None, ge=1)
    enabled: Optional[bool] = None
    deprecated: Optional[bool] = None


class LLMModelResponse(BaseModel):
    """Schema for LLM model response."""

    model_config = ConfigDict(from_attributes=True)

    id: str
    provider_id: str
    model_id: str
    model_name: str
    model_alias: Optional[str] = None
    description: Optional[str] = None
    supports_chat: bool = True
    supports_streaming: bool = True
    supports_function_calling: bool = False
    supports_vision: bool = False
    context_window: Optional[int] = None
    max_output_tokens: Optional[int] = None
    enabled: bool = True
    deprecated: bool = False
    created_at: datetime
    updated_at: datetime
    provider_name: Optional[str] = Field(None, description="Provider name for display")
    provider_type: Optional[str] = Field(None, description="Provider type for display")


class LLMModelListResponse(BaseModel):
    """Schema for paginated list of LLM models."""

    models: List[LLMModelResponse]
    total: int
    page: int = 1
    page_size: int = 50
# ---------------------------------------------------------------------------
# Chat Completion Schemas (OpenAI-compatible)
# ---------------------------------------------------------------------------


class FunctionDefinition(BaseModel):
    """Function definition for tool calling."""

    name: str = Field(..., description="Function name")
    description: Optional[str] = Field(None, description="Function description")
    parameters: Dict[str, Any] = Field(default_factory=dict, description="JSON Schema for parameters")


class ToolDefinition(BaseModel):
    """Tool definition for function calling."""

    type: Literal["function"] = "function"
    function: FunctionDefinition


class ChatMessage(BaseModel):
    """A single chat message."""

    role: Literal["system", "user", "assistant", "tool"] = Field(..., description="Message role")
    content: Optional[str] = Field(None, description="Message content")
    name: Optional[str] = Field(None, description="Optional name for the participant")
    tool_calls: Optional[List[Dict[str, Any]]] = Field(None, description="Tool calls made by assistant")
    tool_call_id: Optional[str] = Field(None, description="ID of tool call this message responds to")
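# Usage sketch (illustrative): a tool-call round trip expressed with these
# message shapes; the call id and function name are made up for the example:
#
#     >>> assistant_turn = ChatMessage(
#     ...     role="assistant",
#     ...     tool_calls=[{"id": "call_1", "type": "function",
#     ...                  "function": {"name": "get_weather",
#     ...                               "arguments": '{"city": "Berlin"}'}}])
#     >>> tool_turn = ChatMessage(role="tool", tool_call_id="call_1",
#     ...                         content='{"temp_c": 21}')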
class ChatCompletionRequest(BaseModel):
    """Request for chat completions (OpenAI-compatible)."""

    model: str = Field(..., description="Model ID to use")
    messages: List[ChatMessage] = Field(..., min_length=1, description="Conversation messages")
    temperature: Optional[float] = Field(None, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: Optional[int] = Field(None, ge=1, description="Maximum tokens to generate")
    stream: bool = Field(default=False, description="Enable streaming response")
    tools: Optional[List[ToolDefinition]] = Field(None, description="Available tools")
    tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(None, description="Tool choice preference")
    top_p: Optional[float] = Field(None, ge=0.0, le=1.0, description="Nucleus sampling")
    frequency_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0, description="Frequency penalty")
    presence_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0, description="Presence penalty")
    stop: Optional[Union[str, List[str]]] = Field(None, description="Stop sequences")
    user: Optional[str] = Field(None, description="User identifier")
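# Usage sketch (illustrative): a minimal OpenAI-compatible request with one
# tool attached; the model id and tool schema are assumptions for the example:
#
#     >>> request = ChatCompletionRequest(
#     ...     model="gpt-4o-mini",
#     ...     messages=[ChatMessage(role="user", content="Weather in Berlin?")],
#     ...     tools=[ToolDefinition(function=FunctionDefinition(
#     ...         name="get_weather",
#     ...         description="Look up current weather for a city",
#     ...         parameters={"type": "object",
#     ...                     "properties": {"city": {"type": "string"}},
#     ...                     "required": ["city"]}))],
#     ...     temperature=0.2)
#     >>> request.stream
#     False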
class UsageStats(BaseModel):
    """Token usage statistics."""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


class ChatChoice(BaseModel):
    """A single chat completion choice."""

    index: int = 0
    message: ChatMessage
    finish_reason: Optional[str] = None


class ChatCompletionResponse(BaseModel):
    """Response from chat completions."""

    id: str = Field(..., description="Unique response ID")
    object: str = "chat.completion"
    created: int = Field(..., description="Unix timestamp")
    model: str = Field(..., description="Model used")
    choices: List[ChatChoice]
    usage: Optional[UsageStats] = None


class ChatCompletionChunk(BaseModel):
    """Streaming chunk for chat completions."""

    id: str
    object: str = "chat.completion.chunk"
    created: int
    model: str
    choices: List[Dict[str, Any]]
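# Usage sketch (illustrative): streaming clients rebuild the reply by
# concatenating each chunk's delta content; the ``choices`` payloads below
# assume the OpenAI wire shape:
#
#     >>> chunks = [
#     ...     ChatCompletionChunk(id="c1", created=1700000000, model="gpt-4o-mini",
#     ...                         choices=[{"index": 0, "delta": {"content": "Hel"}}]),
#     ...     ChatCompletionChunk(id="c1", created=1700000000, model="gpt-4o-mini",
#     ...                         choices=[{"index": 0, "delta": {"content": "lo"}}])]
#     >>> "".join(c.choices[0]["delta"]["content"] for c in chunks)
#     'Hello'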
# ---------------------------------------------------------------------------
# Embedding Schemas
# ---------------------------------------------------------------------------


class EmbeddingRequest(BaseModel):
    """Request for embeddings."""

    model: str = Field(..., description="Model ID to use")
    input: Union[str, List[str]] = Field(..., description="Text to embed")
    encoding_format: Optional[Literal["float", "base64"]] = Field(None, description="Encoding format")
    user: Optional[str] = Field(None, description="User identifier")


class EmbeddingData(BaseModel):
    """A single embedding result."""

    object: str = "embedding"
    embedding: List[float]
    index: int = 0


class EmbeddingResponse(BaseModel):
    """Response from embeddings."""

    object: str = "list"
    data: List[EmbeddingData]
    model: str
    usage: UsageStats
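# Usage sketch (illustrative): an embedding round trip with these schemas; the
# model id and vector values are made up:
#
#     >>> req = EmbeddingRequest(model="text-embedding-3-small",
#     ...                        input=["hello", "world"])
#     >>> resp = EmbeddingResponse(
#     ...     data=[EmbeddingData(embedding=[0.1, 0.2], index=0),
#     ...           EmbeddingData(embedding=[0.3, 0.4], index=1)],
#     ...     model="text-embedding-3-small",
#     ...     usage=UsageStats(prompt_tokens=2, total_tokens=2))
#     >>> len(resp.data)
#     2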
# ---------------------------------------------------------------------------
# Gateway Models Response (for LLM Chat dropdown)
# ---------------------------------------------------------------------------


class GatewayModelInfo(BaseModel):
    """Simplified model info for the LLM Chat dropdown."""

    model_config = ConfigDict(from_attributes=True)

    id: str = Field(..., description="Unique model ID")
    model_id: str = Field(..., description="Provider's model identifier")
    model_name: str = Field(..., description="Display name")
    provider_id: str = Field(..., description="Provider ID")
    provider_name: str = Field(..., description="Provider display name")
    provider_type: str = Field(..., description="Provider type")
    supports_streaming: bool = True
    supports_function_calling: bool = False
    supports_vision: bool = False


class GatewayModelsResponse(BaseModel):
    """Response for /llmchat/gateway/models endpoint."""

    models: List[GatewayModelInfo]
    count: int
# ---------------------------------------------------------------------------
# Health Check Schemas
# ---------------------------------------------------------------------------


class ProviderHealthCheck(BaseModel):
    """Result of a provider health check."""

    provider_id: str
    provider_name: str
    provider_type: str
    status: HealthStatus
    response_time_ms: Optional[float] = None
    error: Optional[str] = None
    checked_at: datetime
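# Usage sketch (illustrative): a health probe result as a checker might emit
# it; the field values are assumptions:
#
#     >>> from datetime import timezone
#     >>> check = ProviderHealthCheck(
#     ...     provider_id="prov-1", provider_name="Local Ollama",
#     ...     provider_type="ollama", status=HealthStatus.HEALTHY,
#     ...     response_time_ms=42.0, checked_at=datetime.now(timezone.utc))
#     >>> check.status.value
#     'healthy'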