Coverage for mcpgateway/llm_schemas.py: 100%

228 statements

coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

# -*- coding: utf-8 -*-
"""Location: ./mcpgateway/llm_schemas.py
Copyright 2025
SPDX-License-Identifier: Apache-2.0

LLM Settings Pydantic Schemas.
This module provides Pydantic models for LLM provider configuration, model management,
and chat completions for the internal LLM Chat feature.

The schemas support:
- LLM provider CRUD operations
- Model configuration and capabilities
- Chat completion requests/responses (OpenAI-compatible)
- Embedding requests/responses
"""

# Standard
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union

# Third-Party
from pydantic import BaseModel, ConfigDict, Field

# ---------------------------------------------------------------------------
# Enums
# ---------------------------------------------------------------------------


class LLMProviderTypeEnum(str, Enum):
    """Enumeration of supported LLM provider types."""

    OPENAI = "openai"
    AZURE_OPENAI = "azure_openai"
    ANTHROPIC = "anthropic"
    BEDROCK = "bedrock"
    GOOGLE_VERTEX = "google_vertex"
    WATSONX = "watsonx"
    OLLAMA = "ollama"
    OPENAI_COMPATIBLE = "openai_compatible"
    COHERE = "cohere"
    MISTRAL = "mistral"
    GROQ = "groq"
    TOGETHER = "together"


class HealthStatus(str, Enum):
    """Health status values for LLM providers."""

    HEALTHY = "healthy"
    UNHEALTHY = "unhealthy"
    UNKNOWN = "unknown"


class RequestStatus(str, Enum):
    """Request processing status."""

    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"


class RequestType(str, Enum):
    """Types of LLM requests."""

    CHAT = "chat"
    COMPLETION = "completion"
    EMBEDDING = "embedding"
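
Since every enum above subclasses str, members behave as plain strings: they can be looked up by raw value, compare equal to raw strings, and serialize without extra handling in JSON payloads. A minimal sketch of that round-trip behavior:

assert LLMProviderTypeEnum("openai") is LLMProviderTypeEnum.OPENAI  # lookup by value
assert HealthStatus.HEALTHY == "healthy"  # str subclass compares to raw strings
assert RequestType.CHAT.value == "chat"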

# ---------------------------------------------------------------------------
# LLM Provider Schemas
# ---------------------------------------------------------------------------


class LLMProviderBase(BaseModel):
    """Base schema for LLM provider data."""

    name: str = Field(..., min_length=1, max_length=255, description="Display name for the provider")
    description: Optional[str] = Field(None, max_length=2000, description="Optional description")
    provider_type: LLMProviderTypeEnum = Field(..., description="Type of LLM provider")
    api_base: Optional[str] = Field(None, max_length=512, description="Base URL for API requests")
    api_version: Optional[str] = Field(None, max_length=50, description="API version (for Azure OpenAI)")
    config: Dict[str, Any] = Field(default_factory=dict, description="Provider-specific configuration")
    default_model: Optional[str] = Field(None, max_length=255, description="Default model ID")
    default_temperature: float = Field(default=0.7, ge=0.0, le=2.0, description="Default temperature")
    default_max_tokens: Optional[int] = Field(None, ge=1, description="Default max tokens")
    enabled: bool = Field(default=True, description="Whether provider is enabled")
    plugin_ids: List[str] = Field(default_factory=list, description="Attached plugin IDs")

    def validate_provider_config(self) -> None:
        """Validate provider-specific configuration based on provider type.

        Raises:
            ValueError: If required provider-specific fields are missing.
        """
        # Import here to avoid circular dependency
        # First-Party
        from mcpgateway.llm_provider_configs import get_provider_config  # pylint: disable=import-outside-toplevel

        provider_def = get_provider_config(self.provider_type.value if isinstance(self.provider_type, LLMProviderTypeEnum) else self.provider_type)
        if not provider_def:
            return

        # Validate required config fields
        for field_def in provider_def.config_fields:
            if field_def.required and field_def.name not in self.config:
                raise ValueError(f"Required configuration field '{field_def.name}' missing for provider type '{self.provider_type}'")
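
An illustrative sketch of exercising that validation. The required keys come from the provider registry in mcpgateway.llm_provider_configs, so the exact config fields (here, a hypothetical deployment_name) depend on that registry, not on this schema:

provider = LLMProviderCreate(
    name="Team Azure",
    provider_type=LLMProviderTypeEnum.AZURE_OPENAI,
    api_base="https://example.openai.azure.com",
    api_version="2024-02-01",
    config={"deployment_name": "gpt-4o"},  # hypothetical registry-required key
)
try:
    provider.validate_provider_config()
except ValueError as exc:
    print(f"Provider rejected: {exc}")  # raised when a required config key is absent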

class LLMProviderCreate(LLMProviderBase):
    """Schema for creating a new LLM provider."""

    api_key: Optional[str] = Field(None, description="API key (will be encrypted)")


class LLMProviderUpdate(BaseModel):
    """Schema for updating an LLM provider."""

    name: Optional[str] = Field(None, min_length=1, max_length=255)
    description: Optional[str] = Field(None, max_length=2000)
    provider_type: Optional[LLMProviderTypeEnum] = None
    api_key: Optional[str] = Field(None, description="API key (will be encrypted)")
    api_base: Optional[str] = Field(None, max_length=512)
    api_version: Optional[str] = Field(None, max_length=50)
    config: Optional[Dict[str, Any]] = None
    default_model: Optional[str] = Field(None, max_length=255)
    default_temperature: Optional[float] = Field(None, ge=0.0, le=2.0)
    default_max_tokens: Optional[int] = Field(None, ge=1)
    enabled: Optional[bool] = None
    plugin_ids: Optional[List[str]] = None
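
Because every field here is optional, the natural service-layer pattern is PATCH-style merging: validate the payload, then dump only what the caller actually sent so unset fields never clobber stored values. A minimal sketch (the merge into the database row is assumed, not shown in this module):

payload = {"description": "Primary provider", "enabled": False}
update = LLMProviderUpdate.model_validate(payload)
changes = update.model_dump(exclude_unset=True)  # drops the untouched None fields
assert changes == {"description": "Primary provider", "enabled": False}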

class LLMProviderResponse(BaseModel):
    """Schema for LLM provider response."""

    model_config = ConfigDict(from_attributes=True)

    id: str
    name: str
    slug: str
    description: Optional[str] = None
    provider_type: str
    api_base: Optional[str] = None
    api_version: Optional[str] = None
    config: Dict[str, Any] = Field(default_factory=dict)
    default_model: Optional[str] = None
    default_temperature: float = 0.7
    default_max_tokens: Optional[int] = None
    enabled: bool = True
    health_status: str = "unknown"
    last_health_check: Optional[datetime] = None
    plugin_ids: List[str] = Field(default_factory=list)
    created_at: datetime
    updated_at: datetime
    created_by: Optional[str] = None
    modified_by: Optional[str] = None
    model_count: int = Field(default=0, description="Number of models for this provider")
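
With from_attributes=True the response can be built straight from an ORM row via attribute access. A minimal sketch using a stand-in object rather than a real database model; fields missing from the object should fall back to their schema defaults:

from types import SimpleNamespace

row = SimpleNamespace(
    id="prov-1",
    name="Local Ollama",
    slug="local-ollama",
    provider_type="ollama",
    created_at=datetime(2025, 1, 1),
    updated_at=datetime(2025, 1, 1),
)
response = LLMProviderResponse.model_validate(row)
assert response.health_status == "unknown"  # default applied for the absent attribute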

class LLMProviderListResponse(BaseModel):
    """Schema for paginated list of LLM providers."""

    providers: List[LLMProviderResponse]
    total: int
    page: int = 1
    page_size: int = 50


# ---------------------------------------------------------------------------
# LLM Model Schemas
# ---------------------------------------------------------------------------


class LLMModelBase(BaseModel):
    """Base schema for LLM model data."""

    model_id: str = Field(..., min_length=1, max_length=255, description="Provider's model ID")
    model_name: str = Field(..., min_length=1, max_length=255, description="Display name")
    model_alias: Optional[str] = Field(None, max_length=255, description="Optional routing alias")
    description: Optional[str] = Field(None, max_length=2000, description="Model description")
    supports_chat: bool = Field(default=True, description="Supports chat completions")
    supports_streaming: bool = Field(default=True, description="Supports streaming")
    supports_function_calling: bool = Field(default=False, description="Supports function/tool calling")
    supports_vision: bool = Field(default=False, description="Supports vision/images")
    context_window: Optional[int] = Field(None, ge=1, description="Max context tokens")
    max_output_tokens: Optional[int] = Field(None, ge=1, description="Max output tokens")
    enabled: bool = Field(default=True, description="Whether model is enabled")
    deprecated: bool = Field(default=False, description="Whether model is deprecated")


class LLMModelCreate(LLMModelBase):
    """Schema for creating a new LLM model."""

    provider_id: str = Field(..., description="Provider ID this model belongs to")
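
An illustrative registration under this schema; the capability flags and token limits below are placeholder values, not authoritative numbers for any real model:

model = LLMModelCreate(
    provider_id="prov-1",  # assumed to reference an existing provider
    model_id="gpt-4o",
    model_name="GPT-4o",
    model_alias="default-vision",  # optional routing alias
    supports_function_calling=True,
    supports_vision=True,
    context_window=128000,
    max_output_tokens=16384,
)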

class LLMModelUpdate(BaseModel):
    """Schema for updating an LLM model."""

    model_id: Optional[str] = Field(None, min_length=1, max_length=255)
    model_name: Optional[str] = Field(None, min_length=1, max_length=255)
    model_alias: Optional[str] = Field(None, max_length=255)
    description: Optional[str] = Field(None, max_length=2000)
    supports_chat: Optional[bool] = None
    supports_streaming: Optional[bool] = None
    supports_function_calling: Optional[bool] = None
    supports_vision: Optional[bool] = None
    context_window: Optional[int] = Field(None, ge=1)
    max_output_tokens: Optional[int] = Field(None, ge=1)
    enabled: Optional[bool] = None
    deprecated: Optional[bool] = None


class LLMModelResponse(BaseModel):
    """Schema for LLM model response."""

    model_config = ConfigDict(from_attributes=True)

    id: str
    provider_id: str
    model_id: str
    model_name: str
    model_alias: Optional[str] = None
    description: Optional[str] = None
    supports_chat: bool = True
    supports_streaming: bool = True
    supports_function_calling: bool = False
    supports_vision: bool = False
    context_window: Optional[int] = None
    max_output_tokens: Optional[int] = None
    enabled: bool = True
    deprecated: bool = False
    created_at: datetime
    updated_at: datetime
    provider_name: Optional[str] = Field(None, description="Provider name for display")
    provider_type: Optional[str] = Field(None, description="Provider type for display")


class LLMModelListResponse(BaseModel):
    """Schema for paginated list of LLM models."""

    models: List[LLMModelResponse]
    total: int
    page: int = 1
    page_size: int = 50


# ---------------------------------------------------------------------------
# Chat Completion Schemas (OpenAI-compatible)
# ---------------------------------------------------------------------------


class FunctionDefinition(BaseModel):
    """Function definition for tool calling."""

    name: str = Field(..., description="Function name")
    description: Optional[str] = Field(None, description="Function description")
    parameters: Dict[str, Any] = Field(default_factory=dict, description="JSON Schema for parameters")


class ToolDefinition(BaseModel):
    """Tool definition for function calling."""

    type: Literal["function"] = "function"
    function: FunctionDefinition
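
Since parameters holds a raw JSON Schema dict, a tool is declared exactly as in the OpenAI tools format. A small illustrative example (the weather function is hypothetical):

get_weather = ToolDefinition(
    function=FunctionDefinition(
        name="get_weather",
        description="Look up current weather for a city.",
        parameters={
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    )
)
assert get_weather.type == "function"  # pinned by the Literal default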

class ChatMessage(BaseModel):
    """A single chat message."""

    role: Literal["system", "user", "assistant", "tool"] = Field(..., description="Message role")
    content: Optional[str] = Field(None, description="Message content")
    name: Optional[str] = Field(None, description="Optional name for the participant")
    tool_calls: Optional[List[Dict[str, Any]]] = Field(None, description="Tool calls made by assistant")
    tool_call_id: Optional[str] = Field(None, description="ID of tool call this message responds to")


class ChatCompletionRequest(BaseModel):
    """Request for chat completions (OpenAI-compatible)."""

    model: str = Field(..., description="Model ID to use")
    messages: List[ChatMessage] = Field(..., min_length=1, description="Conversation messages")
    temperature: Optional[float] = Field(None, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: Optional[int] = Field(None, ge=1, description="Maximum tokens to generate")
    stream: bool = Field(default=False, description="Enable streaming response")
    tools: Optional[List[ToolDefinition]] = Field(None, description="Available tools")
    tool_choice: Optional[Union[str, Dict[str, Any]]] = Field(None, description="Tool choice preference")
    top_p: Optional[float] = Field(None, ge=0.0, le=1.0, description="Nucleus sampling")
    frequency_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0, description="Frequency penalty")
    presence_penalty: Optional[float] = Field(None, ge=-2.0, le=2.0, description="Presence penalty")
    stop: Optional[Union[str, List[str]]] = Field(None, description="Stop sequences")
    user: Optional[str] = Field(None, description="User identifier")
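
A sketch of a typical request; the model name is a placeholder. Dumping with exclude_none keeps the outbound payload minimal and OpenAI-shaped, since unset sampling knobs are simply omitted:

request = ChatCompletionRequest(
    model="gpt-4o-mini",  # placeholder model ID
    messages=[
        ChatMessage(role="system", content="You are a concise assistant."),
        ChatMessage(role="user", content="Summarize RFC 2119 in one sentence."),
    ],
    temperature=0.2,
    max_tokens=128,
)
body = request.model_dump(exclude_none=True)  # None-valued fields are dropped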

class UsageStats(BaseModel):
    """Token usage statistics."""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


class ChatChoice(BaseModel):
    """A single chat completion choice."""

    index: int = 0
    message: ChatMessage
    finish_reason: Optional[str] = None


class ChatCompletionResponse(BaseModel):
    """Response from chat completions."""

    id: str = Field(..., description="Unique response ID")
    object: str = "chat.completion"
    created: int = Field(..., description="Unix timestamp")
    model: str = Field(..., description="Model used")
    choices: List[ChatChoice]
    usage: Optional[UsageStats] = None
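
Validating a raw provider payload into the typed response gives safe attribute access to the first choice. An illustrative round trip with made-up values:

raw = {
    "id": "chatcmpl-123",
    "created": 1700000000,
    "model": "gpt-4o-mini",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Hello!"},
            "finish_reason": "stop",
        }
    ],
    "usage": {"prompt_tokens": 9, "completion_tokens": 2, "total_tokens": 11},
}
completion = ChatCompletionResponse.model_validate(raw)
assert completion.choices[0].message.content == "Hello!"
assert completion.object == "chat.completion"  # filled in by the default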

class ChatCompletionChunk(BaseModel):
    """Streaming chunk for chat completions."""

    id: str
    object: str = "chat.completion.chunk"
    created: int
    model: str
    choices: List[Dict[str, Any]]
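
Unlike the non-streaming response, chunk choices stay untyped dicts, so consumers should read the incremental delta defensively. A minimal sketch:

chunk = ChatCompletionChunk(
    id="chatcmpl-123",
    created=1700000000,
    model="gpt-4o-mini",
    choices=[{"index": 0, "delta": {"content": "Hel"}, "finish_reason": None}],
)
text = chunk.choices[0].get("delta", {}).get("content") or ""  # tolerate empty deltas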

# ---------------------------------------------------------------------------
# Embedding Schemas
# ---------------------------------------------------------------------------


class EmbeddingRequest(BaseModel):
    """Request for embeddings."""

    model: str = Field(..., description="Model ID to use")
    input: Union[str, List[str]] = Field(..., description="Text to embed")
    encoding_format: Optional[Literal["float", "base64"]] = Field(None, description="Encoding format")
    user: Optional[str] = Field(None, description="User identifier")


class EmbeddingData(BaseModel):
    """A single embedding result."""

    object: str = "embedding"
    embedding: List[float]
    index: int = 0


class EmbeddingResponse(BaseModel):
    """Response from embeddings."""

    object: str = "list"
    data: List[EmbeddingData]
    model: str
    usage: UsageStats
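
An illustrative request/response pair; the model name is a placeholder and the three-dimensional vectors are toy-sized stand-ins for real embeddings:

req = EmbeddingRequest(model="text-embedding-3-small", input=["alpha", "beta"])
resp = EmbeddingResponse(
    model=req.model,
    data=[
        EmbeddingData(embedding=[0.1, 0.2, 0.3], index=0),
        EmbeddingData(embedding=[0.4, 0.5, 0.6], index=1),
    ],
    usage=UsageStats(prompt_tokens=2, total_tokens=2),
)
vectors = [item.embedding for item in resp.data]  # one vector per input string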

# ---------------------------------------------------------------------------
# Gateway Models Response (for LLM Chat dropdown)
# ---------------------------------------------------------------------------


class GatewayModelInfo(BaseModel):
    """Simplified model info for the LLM Chat dropdown."""

    model_config = ConfigDict(from_attributes=True)

    id: str = Field(..., description="Unique model ID")
    model_id: str = Field(..., description="Provider's model identifier")
    model_name: str = Field(..., description="Display name")
    provider_id: str = Field(..., description="Provider ID")
    provider_name: str = Field(..., description="Provider display name")
    provider_type: str = Field(..., description="Provider type")
    supports_streaming: bool = True
    supports_function_calling: bool = False
    supports_vision: bool = False


class GatewayModelsResponse(BaseModel):
    """Response for /llmchat/gateway/models endpoint."""

    models: List[GatewayModelInfo]
    count: int
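
A sketch of assembling the dropdown payload from flattened model info; in practice the entries would come from enabled providers and models, which this example fakes with a single hand-built record:

models = [
    GatewayModelInfo(
        id="mdl-1",
        model_id="llama3",
        model_name="Llama 3",
        provider_id="prov-1",
        provider_name="Local Ollama",
        provider_type="ollama",
    )
]
dropdown = GatewayModelsResponse(models=models, count=len(models))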

# ---------------------------------------------------------------------------
# Health Check Schemas
# ---------------------------------------------------------------------------


class ProviderHealthCheck(BaseModel):
    """Result of a provider health check."""

    provider_id: str
    provider_name: str
    provider_type: str
    status: HealthStatus
    response_time_ms: Optional[float] = None
    error: Optional[str] = None
    checked_at: datetime
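
An illustrative healthy-check result; checked_at is naturally a timezone-aware timestamp here, though the schema itself does not enforce awareness:

from datetime import timezone

check = ProviderHealthCheck(
    provider_id="prov-1",
    provider_name="Local Ollama",
    provider_type="ollama",
    status=HealthStatus.HEALTHY,
    response_time_ms=42.0,
    checked_at=datetime.now(timezone.utc),  # error stays None on success
)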