Coverage for mcpgateway / services / llm_proxy_service.py: 99%

325 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/services/llm_proxy_service.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5 

6LLM Proxy Service 

7 

8This module implements the internal proxy for routing LLM requests 

9to configured providers. It translates requests to provider-specific 

10formats and handles streaming responses. 

11""" 

12 

13# Standard 

14import time 

15from typing import Any, AsyncGenerator, Dict, Optional, Tuple 

16import uuid 

17 

18# Third-Party 

19import httpx 

20import orjson 

21from sqlalchemy import select 

22from sqlalchemy.orm import Session 

23 

24# First-Party 

25from mcpgateway.common.validators import SecurityValidator 

26from mcpgateway.config import settings 

27from mcpgateway.db import LLMModel, LLMProvider, LLMProviderType 

28from mcpgateway.llm_schemas import ( 

29 ChatChoice, 

30 ChatCompletionRequest, 

31 ChatCompletionResponse, 

32 ChatMessage, 

33 UsageStats, 

34) 

35from mcpgateway.observability import create_span, set_span_attribute 

36from mcpgateway.services.llm_provider_service import ( 

37 decrypt_provider_config_for_runtime, 

38 LLMModelNotFoundError, 

39 LLMProviderNotFoundError, 

40) 

41from mcpgateway.services.logging_service import LoggingService 

42from mcpgateway.utils.services_auth import decode_auth 

43from mcpgateway.utils.trace_redaction import is_input_capture_enabled, is_output_capture_enabled, serialize_trace_payload 

44 

# Initialize logging: build a LoggingService instance and take from it the
# module-level logger, named after this module, used by everything below.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

48 

49 

50def _provider_trace_system(provider: LLMProvider) -> str: 

51 """Map provider type to a stable ``gen_ai.system`` label. 

52 

53 Args: 

54 provider: Provider model used to resolve the tracing system label. 

55 

56 Returns: 

57 Lowercase provider type string for trace attributes. 

58 """ 

59 provider_type = str(provider.provider_type.value if hasattr(provider.provider_type, "value") else provider.provider_type) 

60 return provider_type.lower() 

61 

62 

def _request_trace_input(request: ChatCompletionRequest) -> str:
    """Serialize a chat completion request for the trace input field.

    Args:
        request: Chat completion request payload being proxied.

    Returns:
        The result of ``serialize_trace_payload`` applied to the request's
        JSON-mode dump with ``None`` fields omitted.
    """
    payload = request.model_dump(mode="json", exclude_none=True)
    return serialize_trace_payload(payload)

73 

74 

75def _usage_trace_attrs(response: ChatCompletionResponse) -> Dict[str, int]: 

76 """Extract token usage attributes from a chat completion response. 

77 

78 Args: 

79 response: Provider response carrying token usage metadata. 

80 

81 Returns: 

82 Trace attribute mapping for prompt, completion, and total token counts. 

83 """ 

84 return { 

85 "gen_ai.usage.prompt_tokens": response.usage.prompt_tokens, 

86 "gen_ai.usage.completion_tokens": response.usage.completion_tokens, 

87 "gen_ai.usage.total_tokens": response.usage.total_tokens, 

88 } 

89 

90 

class LLMProxyError(Exception):
    """Common base for every error raised by the LLM proxy.

    Catch this type to handle any proxy failure regardless of its cause.
    """

93 

94 

class LLMProxyAuthError(LLMProxyError):
    """Proxy error signalling that authentication failed."""

97 

98 

class LLMProxyRequestError(LLMProxyError):
    """Proxy error signalling that the request to the provider failed."""

101 

102 

class LLMProxyService:
    """Service for proxying LLM requests to configured providers.

    Translates OpenAI-style chat completion requests into the wire format of
    the resolved provider (OpenAI-compatible, Azure OpenAI, Anthropic, or
    Ollama), sends them through a lazily created ``httpx.AsyncClient``, and
    maps responses and streamed chunks back to the OpenAI format used by the
    internal /v1/chat/completions endpoint.
    """

    # Maximum number of streamed characters retained for trace output capture.
    _MAX_CAPTURED_OUTPUT = 65536

    # Anthropic ``stop_reason`` -> OpenAI ``finish_reason``. Values that are
    # not listed here are passed through unchanged.
    _ANTHROPIC_FINISH_REASONS = {
        "end_turn": "stop",
        "stop_sequence": "stop",
        "max_tokens": "length",
        "tool_use": "tool_calls",
    }

    def __init__(self) -> None:
        """Create the service; the HTTP client is built later by ``initialize``."""
        self._initialized = False
        self._client: Optional[httpx.AsyncClient] = None

    async def initialize(self) -> None:
        """Initialize the proxy service and HTTP client (no-op if already done)."""
        if self._initialized:
            return

        self._client = httpx.AsyncClient(
            timeout=httpx.Timeout(settings.llm_request_timeout, connect=30.0),
            limits=httpx.Limits(
                max_connections=settings.httpx_max_connections,
                max_keepalive_connections=settings.httpx_max_keepalive_connections,
                keepalive_expiry=settings.httpx_keepalive_expiry,
            ),
            verify=not settings.skip_ssl_verify,
        )
        logger.info("Initialized LLM Proxy Service")
        self._initialized = True

    async def shutdown(self) -> None:
        """Shutdown the proxy service and close connections."""
        if self._initialized and self._client:
            await self._client.aclose()
            self._client = None
            logger.info("Shutdown LLM Proxy Service")
            self._initialized = False

    def _resolve_model(
        self,
        db: Session,
        model_id: str,
    ) -> Tuple[LLMProvider, LLMModel]:
        """Resolve a model ID to provider and model.

        Args:
            db: Database session.
            model_id: Model ID (can be model.id, model.model_id, or model.model_alias).

        Returns:
            Tuple of (LLMProvider, LLMModel).

        Raises:
            LLMModelNotFoundError: If model not found or disabled.
            LLMProviderNotFoundError: If provider not found or disabled.
        """
        # Lookup precedence: primary key, then provider-side model_id, then alias.
        model = None
        for column in (LLMModel.id, LLMModel.model_id, LLMModel.model_alias):
            model = db.execute(select(LLMModel).where(column == model_id)).scalar_one_or_none()
            if model:
                break

        if not model:
            raise LLMModelNotFoundError(f"Model not found: {model_id}")

        if not model.enabled:
            raise LLMModelNotFoundError(f"Model is disabled: {model_id}")

        provider = db.execute(select(LLMProvider).where(LLMProvider.id == model.provider_id)).scalar_one_or_none()

        if not provider:
            raise LLMProviderNotFoundError(f"Provider not found for model: {model_id}")

        if not provider.enabled:
            raise LLMProviderNotFoundError(f"Provider is disabled: {provider.name}")

        return provider, model

    def _get_api_key(self, provider: LLMProvider) -> Optional[str]:
        """Extract API key from provider.

        Args:
            provider: LLM provider instance.

        Returns:
            Decoded API key, or None when the provider has no key or the
            stored value cannot be decoded (the failure is logged).
        """
        if not provider.api_key:
            return None

        try:
            auth_data = decode_auth(provider.api_key)
            return auth_data.get("api_key")
        except Exception as e:
            # Deliberate best-effort: a bad stored key must not crash request building.
            logger.error(f"Failed to decode API key for provider {provider.name}: {e}")
            return None

    @staticmethod
    def _resolve_temperature(request: ChatCompletionRequest, provider: LLMProvider) -> Optional[float]:
        """Pick the request temperature, falling back to the provider default.

        Args:
            request: Chat completion request.
            provider: LLM provider.

        Returns:
            Temperature to send, or None when neither side specifies one.
            A value of ``0.0`` is a valid setting and is preserved.
        """
        if request.temperature is not None:
            return request.temperature
        return provider.default_temperature

    @staticmethod
    def _resolve_max_tokens(request: ChatCompletionRequest, provider: LLMProvider) -> Optional[int]:
        """Pick the request max_tokens, falling back to the provider default.

        Args:
            request: Chat completion request.
            provider: LLM provider.

        Returns:
            Token limit to send, or None when the request sets none and the
            provider default is unset or zero.
        """
        if request.max_tokens is not None:
            return request.max_tokens
        return provider.default_max_tokens or None

    def _apply_openai_params(
        self,
        body: Dict[str, Any],
        request: ChatCompletionRequest,
        provider: LLMProvider,
    ) -> None:
        """Copy the optional OpenAI chat parameters from the request into ``body``.

        Shared by the OpenAI and Azure OpenAI builders so both forward the
        same set of parameters.

        Args:
            body: Request body being built; mutated in place.
            request: Chat completion request.
            provider: LLM provider supplying temperature/max_tokens defaults.
        """
        temperature = self._resolve_temperature(request, provider)
        if temperature is not None:
            body["temperature"] = temperature

        max_tokens = self._resolve_max_tokens(request, provider)
        if max_tokens is not None:
            body["max_tokens"] = max_tokens

        if request.stream:
            body["stream"] = True

        if request.tools:
            body["tools"] = [t.model_dump() for t in request.tools]
        if request.tool_choice:
            body["tool_choice"] = request.tool_choice
        if request.top_p is not None:
            body["top_p"] = request.top_p
        if request.frequency_penalty is not None:
            body["frequency_penalty"] = request.frequency_penalty
        if request.presence_penalty is not None:
            body["presence_penalty"] = request.presence_penalty
        if request.stop:
            body["stop"] = request.stop

    def _build_openai_request(
        self,
        request: ChatCompletionRequest,
        provider: LLMProvider,
        model: LLMModel,
    ) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
        """Build request for OpenAI-compatible providers.

        Args:
            request: Chat completion request.
            provider: LLM provider.
            model: LLM model.

        Returns:
            Tuple of (url, headers, body).
        """
        api_key = self._get_api_key(provider)
        base_url = provider.api_base or "https://api.openai.com/v1"
        url = f"{base_url.rstrip('/')}/chat/completions"

        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        body: Dict[str, Any] = {
            "model": model.model_id,
            "messages": [msg.model_dump(exclude_none=True) for msg in request.messages],
        }
        self._apply_openai_params(body, request, provider)

        return url, headers, body

    def _build_azure_request(
        self,
        request: ChatCompletionRequest,
        provider: LLMProvider,
        model: LLMModel,
    ) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
        """Build request for Azure OpenAI.

        Args:
            request: Chat completion request.
            provider: LLM provider.
            model: LLM model.

        Returns:
            Tuple of (url, headers, body).
        """
        api_key = self._get_api_key(provider)
        provider_config = decrypt_provider_config_for_runtime(provider.config)

        # Azure addresses a deployment rather than a model name; fall back to the model ID.
        deployment_name = provider_config.get("deployment_name") or provider_config.get("deployment") or model.model_id
        resource_name = provider_config.get("resource_name", "")
        api_version = provider_config.get("api_version") or provider.api_version or "2024-02-15-preview"

        # Build base URL from resource name if not provided
        if not provider.api_base and resource_name:
            base_url = f"https://{resource_name}.openai.azure.com"
        else:
            base_url = provider.api_base or ""

        url = f"{base_url.rstrip('/')}/openai/deployments/{deployment_name}/chat/completions?api-version={api_version}"

        headers = {
            "Content-Type": "application/json",
            "api-key": api_key or "",
        }

        # The body carries no "model" key: the deployment in the URL selects it.
        body: Dict[str, Any] = {
            "messages": [msg.model_dump(exclude_none=True) for msg in request.messages],
        }
        self._apply_openai_params(body, request, provider)

        return url, headers, body

    def _build_anthropic_request(
        self,
        request: ChatCompletionRequest,
        provider: LLMProvider,
        model: LLMModel,
    ) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
        """Build request for Anthropic Claude.

        Args:
            request: Chat completion request.
            provider: LLM provider.
            model: LLM model.

        Returns:
            Tuple of (url, headers, body).
        """
        api_key = self._get_api_key(provider)
        base_url = provider.api_base or "https://api.anthropic.com"
        provider_config = decrypt_provider_config_for_runtime(provider.config)

        url = f"{base_url.rstrip('/')}/v1/messages"

        anthropic_version = provider_config.get("anthropic_version") or provider.api_version or "2023-06-01"

        headers = {
            "Content-Type": "application/json",
            "x-api-key": api_key or "",
            "anthropic-version": anthropic_version,
        }

        # Anthropic takes the system prompt as a top-level field, not as a message.
        # When several system messages are present the last one is used.
        system_message = None
        messages = []
        for msg in request.messages:
            if msg.role == "system":
                system_message = msg.content
            else:
                messages.append({"role": msg.role, "content": msg.content or ""})

        body: Dict[str, Any] = {
            "model": model.model_id,
            "messages": messages,
            # max_tokens is always sent, hence the hard fallback.
            "max_tokens": request.max_tokens or provider.default_max_tokens or 4096,
        }

        if system_message:
            body["system"] = system_message

        temperature = self._resolve_temperature(request, provider)
        if temperature is not None:
            body["temperature"] = temperature

        if request.stream:
            body["stream"] = True

        return url, headers, body

    @staticmethod
    def _ollama_base_url(provider: LLMProvider) -> str:
        """Return the Ollama base URL without a trailing slash.

        Args:
            provider: LLM provider.

        Returns:
            ``provider.api_base`` or the local default, right-stripped of ``/``.
        """
        return (provider.api_base or "http://localhost:11434").rstrip("/")

    def _ollama_uses_openai_api(self, provider: LLMProvider) -> bool:
        """Tell whether an Ollama provider is addressed through its ``/v1`` endpoint.

        Args:
            provider: LLM provider.

        Returns:
            True when the base URL ends with ``/v1`` (OpenAI-compatible API),
            False when the native ``/api/chat`` endpoint is used.
        """
        return self._ollama_base_url(provider).endswith("/v1")

    def _build_ollama_request(
        self,
        request: ChatCompletionRequest,
        provider: LLMProvider,
        model: LLMModel,
    ) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
        """Build request for Ollama.

        Args:
            request: Chat completion request.
            provider: LLM provider.
            model: LLM model.

        Returns:
            Tuple of (url, headers, body).
        """
        base_url = self._ollama_base_url(provider)
        headers = {"Content-Type": "application/json"}
        body: Dict[str, Any] = {
            "model": model.model_id,
            "messages": [{"role": msg.role, "content": msg.content or ""} for msg in request.messages],
            "stream": request.stream,
        }

        temperature = self._resolve_temperature(request, provider)
        max_tokens = self._resolve_max_tokens(request, provider)

        if base_url.endswith("/v1"):
            # OpenAI-compatible API: sampling parameters are top-level keys.
            url = f"{base_url}/chat/completions"
            if temperature is not None:
                body["temperature"] = temperature
            if max_tokens is not None:
                body["max_tokens"] = max_tokens
        else:
            # Native Ollama API: sampling parameters live under "options" and
            # the token limit is called "num_predict".
            url = f"{base_url}/api/chat"
            options: Dict[str, Any] = {}
            if temperature is not None:
                options["temperature"] = temperature
            if max_tokens is not None:
                options["num_predict"] = max_tokens
            if options:
                body["options"] = options

        return url, headers, body

    def _build_provider_request(
        self,
        request: ChatCompletionRequest,
        provider: LLMProvider,
        model: LLMModel,
    ) -> Tuple[str, Dict[str, str], Dict[str, Any]]:
        """Dispatch to the request builder matching the provider type.

        Args:
            request: Chat completion request.
            provider: LLM provider.
            model: LLM model.

        Returns:
            Tuple of (url, headers, body).
        """
        if provider.provider_type == LLMProviderType.AZURE_OPENAI:
            return self._build_azure_request(request, provider, model)
        if provider.provider_type == LLMProviderType.ANTHROPIC:
            return self._build_anthropic_request(request, provider, model)
        if provider.provider_type == LLMProviderType.OLLAMA:
            return self._build_ollama_request(request, provider, model)
        # Default to OpenAI-compatible
        return self._build_openai_request(request, provider, model)

    @staticmethod
    def _validate_provider_url(url: str) -> None:
        """Validate the constructed URL to prevent SSRF attacks.

        Args:
            url: Fully built provider URL.

        Raises:
            LLMProxyRequestError: If the URL is rejected by ``SecurityValidator``.
        """
        try:
            SecurityValidator.validate_url(url, "LLM provider URL")
        except ValueError as url_err:
            raise LLMProxyRequestError(f"Invalid LLM provider URL: {url_err}") from url_err

    @staticmethod
    def _span_attributes(
        provider: LLMProvider,
        model: LLMModel,
        request: ChatCompletionRequest,
        stream: bool = False,
    ) -> Dict[str, Any]:
        """Build the attributes attached to the ``llm.proxy`` trace span.

        Args:
            provider: LLM provider.
            model: LLM model.
            request: Chat completion request (captured only when input capture is enabled).
            stream: Whether the span describes a streaming request.

        Returns:
            Span attribute mapping.
        """
        system = _provider_trace_system(provider)
        attributes: Dict[str, Any] = {
            "langfuse.observation.type": "generation",
            "gen_ai.system": system,
            "gen_ai.request.model": model.model_id,
            "llm.provider.id": str(provider.id),
            "llm.provider.type": system,
            "llm.model.id": str(model.id),
        }
        if stream:
            attributes["llm.stream"] = True
        if is_input_capture_enabled("llm.proxy"):
            attributes["langfuse.observation.input"] = _request_trace_input(request)
        return attributes

    def _transform_provider_response(
        self,
        provider: LLMProvider,
        model: LLMModel,
        data: Dict[str, Any],
    ) -> ChatCompletionResponse:
        """Convert a non-streaming provider payload to the standard response.

        Args:
            provider: LLM provider that produced the payload.
            model: LLM model.
            data: Decoded JSON payload.

        Returns:
            ChatCompletionResponse in OpenAI format.
        """
        if provider.provider_type == LLMProviderType.ANTHROPIC:
            return self._transform_anthropic_response(data, model.model_id)
        if provider.provider_type == LLMProviderType.OLLAMA and not self._ollama_uses_openai_api(provider):
            return self._transform_ollama_response(data, model.model_id)
        return self._transform_openai_response(data)

    async def chat_completion(
        self,
        db: Session,
        request: ChatCompletionRequest,
    ) -> ChatCompletionResponse:
        """Process a chat completion request (non-streaming).

        Args:
            db: Database session.
            request: Chat completion request.

        Returns:
            ChatCompletionResponse.

        Raises:
            LLMModelNotFoundError: If the model cannot be resolved or is disabled.
            LLMProviderNotFoundError: If the provider is missing or disabled.
            LLMProxyRequestError: If the URL is invalid or the request fails.
        """
        if not self._client:
            await self.initialize()

        provider, model = self._resolve_model(db, request.model)
        url, headers, body = self._build_provider_request(request, provider, model)

        # Ensure non-streaming
        body["stream"] = False

        self._validate_provider_url(url)

        with create_span("llm.proxy", self._span_attributes(provider, model, request)) as span:
            try:
                response = await self._client.post(url, headers=headers, json=body)
                response.raise_for_status()
                data = response.json()

                result = self._transform_provider_response(provider, model, data)

                if span:
                    set_span_attribute(span, "gen_ai.response.model", result.model)
                    for key, value in _usage_trace_attrs(result).items():
                        set_span_attribute(span, key, value)
                    if is_output_capture_enabled("llm.proxy"):
                        set_span_attribute(span, "langfuse.observation.output", serialize_trace_payload(result))

                return result

            except httpx.HTTPStatusError as e:
                logger.error(f"LLM request failed: {e.response.status_code} - {e.response.text}")
                raise LLMProxyRequestError(f"Request failed: {e.response.status_code}") from e
            except httpx.RequestError as e:
                logger.error(f"LLM request error: {e}")
                raise LLMProxyRequestError(f"Connection error: {str(e)}") from e

    async def chat_completion_stream(
        self,
        db: Session,
        request: ChatCompletionRequest,
    ) -> AsyncGenerator[str, None]:
        """Process a streaming chat completion request.

        Upstream HTTP and connection failures are not raised; they are logged
        and reported to the consumer as a single SSE ``error`` chunk.

        Args:
            db: Database session.
            request: Chat completion request.

        Yields:
            SSE-formatted string chunks.

        Raises:
            LLMModelNotFoundError: If the model cannot be resolved or is disabled.
            LLMProviderNotFoundError: If the provider is missing or disabled.
            LLMProxyRequestError: If the provider URL is invalid.
        """
        if not self._client:
            await self.initialize()

        provider, model = self._resolve_model(db, request.model)
        url, headers, body = self._build_provider_request(request, provider, model)

        # Ensure streaming
        body["stream"] = True

        self._validate_provider_url(url)

        response_id = f"chatcmpl-{uuid.uuid4().hex[:24]}"
        created = int(time.time())
        capture_output = is_output_capture_enabled("llm.proxy")
        captured_output = ""
        capture_limit = self._MAX_CAPTURED_OUTPUT

        # Native Ollama emits bare JSON lines; every other provider emits SSE "data:" lines.
        native_ollama = provider.provider_type == LLMProviderType.OLLAMA and not self._ollama_uses_openai_api(provider)
        if provider.provider_type == LLMProviderType.ANTHROPIC:
            translate = self._transform_anthropic_stream_chunk
        elif native_ollama:
            translate = self._transform_ollama_stream_chunk
        else:
            # OpenAI-compatible payloads are forwarded verbatim.
            translate = None

        with create_span("llm.proxy", self._span_attributes(provider, model, request, stream=True)) as span:
            try:
                async with self._client.stream("POST", url, headers=headers, json=body) as response:
                    response.raise_for_status()

                    async for line in response.aiter_lines():
                        if not line:
                            continue

                        if line.startswith("data:"):
                            # Handle SSE format: drop the field name and one optional space.
                            payload = line[5:]
                            if payload.startswith(" "):
                                payload = payload[1:]
                            if payload.strip() == "[DONE]":
                                yield "data: [DONE]\n\n"
                                break
                        elif native_ollama:
                            payload = line
                        else:
                            # Non-data SSE lines (event:, comments, ...) carry nothing to forward.
                            continue

                        try:
                            data = orjson.loads(payload)
                        except orjson.JSONDecodeError:
                            # Skip lines that are not valid JSON instead of aborting the stream.
                            continue

                        if translate is None:
                            chunk = payload
                        else:
                            chunk = translate(data, response_id, created, model.model_id)

                        if chunk:
                            if capture_output and len(captured_output) < capture_limit:
                                captured_output += chunk[: capture_limit - len(captured_output)]
                            yield f"data: {chunk}\n\n"
            except httpx.HTTPStatusError as e:
                # The streamed body has not been read, so only the status code is logged.
                logger.error(f"LLM streaming request failed: {e.response.status_code}")
                error_chunk = {
                    "error": {
                        "message": f"Request failed: {e.response.status_code}",
                        "type": "proxy_error",
                    }
                }
                yield f"data: {orjson.dumps(error_chunk).decode()}\n\n"
            except httpx.RequestError as e:
                logger.error(f"LLM streaming request error: {e}")
                error_chunk = {
                    "error": {
                        "message": f"Connection error: {str(e)}",
                        "type": "proxy_error",
                    }
                }
                yield f"data: {orjson.dumps(error_chunk).decode()}\n\n"
            finally:
                if span and capture_output and captured_output:
                    set_span_attribute(span, "langfuse.observation.output", serialize_trace_payload({"stream": captured_output}))

    def _transform_openai_response(self, data: Dict[str, Any]) -> ChatCompletionResponse:
        """Transform OpenAI response to standard format.

        Args:
            data: Raw OpenAI API response data.

        Returns:
            ChatCompletionResponse in standard format.
        """
        choices = []
        for choice in data.get("choices") or []:
            # Tolerate an explicit null message as well as a missing one.
            message_data = choice.get("message") or {}
            choices.append(
                ChatChoice(
                    index=choice.get("index", 0),
                    message=ChatMessage(
                        role=message_data.get("role", "assistant"),
                        content=message_data.get("content"),
                        tool_calls=message_data.get("tool_calls"),
                    ),
                    finish_reason=choice.get("finish_reason"),
                )
            )

        # Tolerate an explicit null usage object as well as a missing one.
        usage_data = data.get("usage") or {}
        usage = UsageStats(
            prompt_tokens=usage_data.get("prompt_tokens", 0),
            completion_tokens=usage_data.get("completion_tokens", 0),
            total_tokens=usage_data.get("total_tokens", 0),
        )

        return ChatCompletionResponse(
            id=data.get("id", f"chatcmpl-{uuid.uuid4().hex[:24]}"),
            created=data.get("created", int(time.time())),
            model=data.get("model", "unknown"),
            choices=choices,
            usage=usage,
        )

    def _transform_anthropic_response(
        self,
        data: Dict[str, Any],
        model_id: str,
    ) -> ChatCompletionResponse:
        """Transform Anthropic response to OpenAI format.

        Only ``text`` content blocks are kept; they are concatenated in order.

        Args:
            data: Raw Anthropic API response data.
            model_id: Model ID to include in response.

        Returns:
            ChatCompletionResponse in OpenAI format.
        """
        content = "".join(block.get("text", "") for block in data.get("content") or [] if block.get("type") == "text")

        usage_data = data.get("usage") or {}
        input_tokens = usage_data.get("input_tokens", 0)
        output_tokens = usage_data.get("output_tokens", 0)

        # Translate Anthropic's stop_reason vocabulary to OpenAI's finish_reason.
        stop_reason = data.get("stop_reason", "stop")
        finish_reason = self._ANTHROPIC_FINISH_REASONS.get(stop_reason, stop_reason)

        return ChatCompletionResponse(
            id=data.get("id", f"chatcmpl-{uuid.uuid4().hex[:24]}"),
            created=int(time.time()),
            model=model_id,
            choices=[
                ChatChoice(
                    index=0,
                    message=ChatMessage(role="assistant", content=content),
                    finish_reason=finish_reason,
                )
            ],
            usage=UsageStats(
                prompt_tokens=input_tokens,
                completion_tokens=output_tokens,
                total_tokens=input_tokens + output_tokens,
            ),
        )

    def _transform_ollama_response(
        self,
        data: Dict[str, Any],
        model_id: str,
    ) -> ChatCompletionResponse:
        """Transform Ollama response to OpenAI format.

        Args:
            data: Raw Ollama API response data.
            model_id: Model ID to include in response.

        Returns:
            ChatCompletionResponse in OpenAI format.
        """
        message = data.get("message") or {}
        prompt_tokens = data.get("prompt_eval_count", 0)
        completion_tokens = data.get("eval_count", 0)

        return ChatCompletionResponse(
            id=f"chatcmpl-{uuid.uuid4().hex[:24]}",
            created=int(time.time()),
            model=model_id,
            choices=[
                ChatChoice(
                    index=0,
                    message=ChatMessage(
                        role=message.get("role", "assistant"),
                        content=message.get("content", ""),
                    ),
                    finish_reason="stop" if data.get("done") else None,
                )
            ],
            usage=UsageStats(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            ),
        )

    @staticmethod
    def _openai_stream_chunk(
        response_id: str,
        created: int,
        model_id: str,
        delta: Dict[str, Any],
        finish_reason: Optional[str],
    ) -> str:
        """Serialize one OpenAI ``chat.completion.chunk`` with a single choice.

        Args:
            response_id: Response ID for the chunk.
            created: Timestamp for the response.
            model_id: Model ID to include in response.
            delta: Delta payload of the choice.
            finish_reason: Finish reason, or None while the stream continues.

        Returns:
            JSON string chunk in OpenAI format.
        """
        chunk = {
            "id": response_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model_id,
            "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
        }
        return orjson.dumps(chunk).decode()

    def _transform_anthropic_stream_chunk(
        self,
        data: Dict[str, Any],
        response_id: str,
        created: int,
        model_id: str,
    ) -> Optional[str]:
        """Transform Anthropic streaming chunk to OpenAI format.

        Args:
            data: Raw Anthropic streaming event data.
            response_id: Response ID for the chunk.
            created: Timestamp for the response.
            model_id: Model ID to include in response.

        Returns:
            JSON string chunk in OpenAI format, or None if not applicable.
        """
        event_type = data.get("type")

        if event_type == "content_block_delta":
            delta = data.get("delta", {})
            if delta.get("type") == "text_delta":
                return self._openai_stream_chunk(response_id, created, model_id, {"content": delta.get("text", "")}, None)
            # Non-text deltas are not forwarded.
            return None

        if event_type == "message_stop":
            return self._openai_stream_chunk(response_id, created, model_id, {}, "stop")

        return None

    def _transform_ollama_stream_chunk(
        self,
        data: Dict[str, Any],
        response_id: str,
        created: int,
        model_id: str,
    ) -> Optional[str]:
        """Transform Ollama streaming chunk to OpenAI format.

        Args:
            data: Raw Ollama streaming event data.
            response_id: Response ID for the chunk.
            created: Timestamp for the response.
            model_id: Model ID to include in response.

        Returns:
            JSON string chunk in OpenAI format.
        """
        if data.get("done"):
            return self._openai_stream_chunk(response_id, created, model_id, {}, "stop")

        content = data.get("message", {}).get("content", "")
        delta = {"content": content} if content else {}
        return self._openai_stream_chunk(response_id, created, model_id, delta, None)