Coverage for mcpgateway / routers / llmchat_router.py: 99%

401 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/routers/llmchat_router.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Keval Mahajan 

6 

7LLM Chat Router Module 

8 

9This module provides FastAPI endpoints for managing LLM-based chat sessions 

10with MCP (Model Context Protocol) server integration. LLM providers are 

11configured via the Admin UI's LLM Settings and accessed through the gateway 

12provider. 

13 

14The module handles user session management, configuration, and real-time 

15streaming responses for conversational AI applications with unified chat 

16history management via ChatHistoryManager from mcp_client_chat_service. 

17 

18""" 

19 

20# Standard 

21import asyncio 

22import os 

23import time 

24from typing import Any, Dict, Optional 

25 

26# Third-Party 

27from fastapi import APIRouter, Depends, HTTPException, Request 

28from fastapi.responses import StreamingResponse 

29import orjson 

30from pydantic import BaseModel, Field 

31 

32try: 

33 # Third-Party 

34 import redis.asyncio # noqa: F401 - availability check only 

35 

36 REDIS_AVAILABLE = True 

37except ImportError: 

38 REDIS_AVAILABLE = False 

39 

40# First-Party 

41from mcpgateway.common.validators import SecurityValidator 

42from mcpgateway.config import settings 

43from mcpgateway.middleware.rbac import get_current_user_with_permissions, require_permission 

44from mcpgateway.services.logging_service import LoggingService 

45from mcpgateway.services.mcp_client_chat_service import ( 

46 ChatProcessingError, 

47 GatewayConfig, 

48 LLMConfig, 

49 MCPChatService, 

50 MCPClientConfig, 

51 MCPServerConfig, 

52) 

53from mcpgateway.utils.redis_client import get_redis_client 

54from mcpgateway.utils.services_auth import decode_auth, encode_auth 

55 

# Router that groups every LLM chat endpoint under the /llmchat prefix
llmchat_router = APIRouter(prefix="/llmchat", tags=["llmchat"])

# Module-wide Redis handle; remains None until init_redis() assigns it at app startup
redis_client = None

61 

62 

async def init_redis() -> None:
    """Bind the module-level ``redis_client`` to the shared Redis client.

    Meant to be invoked during application startup (main.py lifespan).
    The shared factory is only consulted when settings select the ``redis``
    cache type and a Redis URL is configured; otherwise ``redis_client`` is
    left as it was.
    """
    global redis_client
    cache_type = getattr(settings, "cache_type", None)
    redis_url = getattr(settings, "redis_url", None)
    if cache_type != "redis" or not redis_url:
        return

    redis_client = await get_redis_client()
    if redis_client:
        logger.info("LLMChat router connected to shared Redis client")

73 

74 

# In-process fallbacks used whenever no Redis client is available.
# Live chat services, keyed by user id
active_sessions: Dict[str, MCPChatService] = {}

# Serialized user configuration plus the time.monotonic() stamp taken when it was stored
user_configs: Dict[str, tuple[bytes, float]] = {}

# Module logger obtained from the gateway logging service
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

85 

86# ---------- MODELS ---------- 

87 

88 

class LLMInput(BaseModel):
    """Selects the gateway-configured model a chat session should use.

    The referenced model has to be configured via Admin UI -> LLM Settings.

    Attributes:
        model: Identifier of the model in the gateway's LLM Settings (UUID or model_id).
        temperature: Sampling temperature; constrained to 0.0-2.0 when supplied.
        max_tokens: Cap on generated tokens; must be greater than zero when supplied.

    Examples:
        >>> llm_input = LLMInput(model='gpt-4o')
        >>> llm_input.model
        'gpt-4o'

        >>> llm_input = LLMInput(model='abc123-uuid', temperature=0.5)
        >>> llm_input.temperature
        0.5
    """

    model: str = Field(..., description="Model ID from gateway LLM Settings (UUID or model_id)")
    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: Optional[int] = Field(default=None, gt=0, description="Maximum tokens to generate")

113 

114 

class ServerInput(BaseModel):
    """Connection parameters for an MCP (Model Context Protocol) server.

    Every field is optional; ``build_config`` fills in fallbacks for values
    that are missing or empty.

    Attributes:
        url: MCP server endpoint. When absent, ``build_config`` substitutes
            'http://localhost:8000/mcp'.
        transport: Transport protocol name. Defaults to 'streamable_http'.
        auth_token: Token presented to the MCP server, if any.

    Examples:
        >>> server = ServerInput(url='http://example.com/mcp')
        >>> server.transport
        'streamable_http'

        >>> server = ServerInput()
        >>> server.url is None
        True
    """

    url: Optional[str] = None
    transport: Optional[str] = "streamable_http"
    auth_token: Optional[str] = None

140 

141 

class ConnectInput(BaseModel):
    """Payload accepted by the connect endpoint to open a chat session.

    Bundles the user identifier, the optional MCP server details, the LLM
    selection and the streaming preference.

    Attributes:
        user_id: Identifier of the user the session belongs to.
        server: MCP server settings; fallbacks apply when omitted.
        llm: Required LLM selection naming a gateway-configured model.
        streaming: Enables streaming responses when True. Defaults to False.

    Examples:
        >>> connect = ConnectInput(user_id='user123', llm=LLMInput(model='gpt-4o'))
        >>> connect.streaming
        False

        >>> connect = ConnectInput(user_id='user456', llm=LLMInput(model='gpt-4o'), streaming=True)
        >>> connect.user_id
        'user456'
    """

    user_id: str
    server: Optional[ServerInput] = None
    llm: LLMInput = Field(..., description="LLM configuration with model from gateway LLM Settings")
    streaming: bool = False

168 

169 

class ChatInput(BaseModel):
    """Payload carrying one chat message for an existing session.

    Attributes:
        user_id: Identifier of the user whose session receives the message.
        message: Text of the message to process.
        streaming: Requests a streamed response when True. Defaults to False.

    Examples:
        >>> chat = ChatInput(user_id='user123', message='Hello, AI!')
        >>> len(chat.message) > 0
        True

        >>> chat = ChatInput(user_id='user456', message='Tell me a story', streaming=True)
        >>> chat.streaming
        True
    """

    user_id: str
    message: str
    streaming: bool = False

193 

194 

class DisconnectInput(BaseModel):
    """Payload used to end a chat session.

    Attributes:
        user_id: Identifier of the user whose session should be disconnected.

    Examples:
        >>> disconnect = DisconnectInput(user_id='user123')
        >>> disconnect.user_id
        'user123'
    """

    user_id: str

210 

211 

212# ---------- HELPERS ---------- 

213 

214 

def build_llm_config(llm: LLMInput) -> LLMConfig:
    """Translate an ``LLMInput`` into a gateway-provider ``LLMConfig``.

    The provider is always ``"gateway"``; the model named in ``llm`` must be
    configured via Admin UI -> Settings -> LLM Settings.

    Args:
        llm: Model selection with optional temperature and max_tokens.

    Returns:
        LLMConfig: Configuration whose provider is ``"gateway"``.

    Examples:
        >>> llm_input = LLMInput(model='gpt-4o')
        >>> config = build_llm_config(llm_input)
        >>> config.provider
        'gateway'
    """
    # An unset temperature falls back to 0.7; an explicit 0.0 is preserved.
    temperature = 0.7 if llm.temperature is None else llm.temperature
    gateway_settings = GatewayConfig(
        model=llm.model,
        temperature=temperature,
        max_tokens=llm.max_tokens,
    )
    return LLMConfig(provider="gateway", config=gateway_settings)

246 

247 

def build_config(input_data: ConnectInput) -> MCPClientConfig:
    """Assemble the full MCP client configuration for a connect request.

    Missing or empty server values are replaced by fallbacks: the URL becomes
    'http://localhost:8000/mcp', the transport 'streamable_http', and the
    auth token ``None`` when no server section was sent.

    Args:
        input_data: Connect payload with server, LLM and streaming settings.

    Returns:
        MCPClientConfig: Configuration combining MCP server and LLM settings.

    Examples:
        >>> from mcpgateway.routers.llmchat_router import ConnectInput, LLMInput, build_config
        >>> connect = ConnectInput(user_id='user123', llm=LLMInput(model='gpt-4o'))
        >>> config = build_config(connect)
        >>> config.mcp_server.transport
        'streamable_http'
    """
    url = "http://localhost:8000/mcp"
    transport = "streamable_http"
    auth_token = None

    server = input_data.server
    if server:
        url = server.url or url
        transport = server.transport or transport
        auth_token = server.auth_token

    mcp_server = MCPServerConfig(url=url, transport=transport, auth_token=auth_token)
    return MCPClientConfig(
        mcp_server=mcp_server,
        llm=build_llm_config(input_data.llm),
        enable_streaming=input_data.streaming,
    )

282 

283 

284def _get_user_id_from_context(user: Dict[str, Any]) -> str: 

285 """Extract a stable user identifier from the authenticated user context. 

286 

287 Args: 

288 user: Authenticated user context from RBAC dependency. 

289 

290 Returns: 

291 User identifier string or "unknown" if missing. 

292 """ 

293 if isinstance(user, dict): 

294 return user.get("id") or user.get("user_id") or user.get("sub") or user.get("email") or "unknown" 

295 return "unknown" if user is None else str(getattr(user, "id", user)) 

296 

297 

def _resolve_user_id(input_user_id: Optional[str], user: Dict[str, Any]) -> str:
    """Return the authenticated user's ID, rejecting requests made for someone else.

    Args:
        input_user_id: User ID sent by the client; may be empty or None.
        user: Authenticated user context from RBAC dependency.

    Returns:
        The identifier derived from the authenticated context.

    Raises:
        HTTPException: 401 when no identifier can be derived from the
            context, 403 when the client-supplied ID differs from it.
    """
    authenticated_id = _get_user_id_from_context(user)
    if authenticated_id == "unknown":
        raise HTTPException(status_code=401, detail="Authentication required.")

    # A missing or empty client-supplied ID is accepted; only a differing one is rejected.
    claims_other_user = bool(input_user_id) and input_user_id != authenticated_id
    if claims_other_user:
        raise HTTPException(status_code=403, detail="User ID mismatch.")
    return authenticated_id

317 

318 

319# ---------- SESSION STORAGE HELPERS ---------- 

320 

# Process id of this worker as a string; stored in Redis to mark session and lock ownership
WORKER_ID = str(os.getpid())

# Timing knobs, all read from gateway settings
SESSION_TTL = settings.llmchat_session_ttl  # seconds the active_session key lives
LOCK_TTL = settings.llmchat_session_lock_ttl  # seconds until an init lock expires
LOCK_RETRIES = settings.llmchat_session_lock_retries  # number of polls while waiting on a lock
LOCK_WAIT = settings.llmchat_session_lock_wait  # seconds slept between polls
USER_CONFIG_TTL = settings.llmchat_session_ttl  # stored configs share the session TTL

# Layout of the stored config envelope
_ENCRYPTED_CONFIG_PAYLOAD_KEY = "_encrypted_payload"
_ENCRYPTED_CONFIG_VERSION_KEY = "_version"
_ENCRYPTED_CONFIG_VERSION = "1"

# Lower-case field names whose values are masked before a config is returned
_CONFIG_SENSITIVE_KEYS = frozenset(
    (
        "api_key",
        "auth_token",
        "authorization",
        "access_token",
        "refresh_token",
        "client_secret",
        "secret_access_key",
        "session_token",
        "credentials_json",
        "password",
        "private_key",
    )
)

349 

350 

351# Redis key helpers 

352def _cfg_key(user_id: str) -> str: 

353 """Generate Redis key for user configuration storage. 

354 

355 Args: 

356 user_id: User identifier. 

357 

358 Returns: 

359 str: Redis key for storing user configuration. 

360 """ 

361 return f"user_config:{user_id}" 

362 

363 

364def _active_key(user_id: str) -> str: 

365 """Generate Redis key for active session tracking. 

366 

367 Args: 

368 user_id: User identifier. 

369 

370 Returns: 

371 str: Redis key for tracking active sessions. 

372 """ 

373 return f"active_session:{user_id}" 

374 

375 

376def _lock_key(user_id: str) -> str: 

377 """Generate Redis key for session initialization lock. 

378 

379 Args: 

380 user_id: User identifier. 

381 

382 Returns: 

383 str: Redis key for session locks. 

384 """ 

385 return f"session_lock:{user_id}" 

386 

387 

def _serialize_user_config_for_storage(config: MCPClientConfig) -> bytes:
    """Wrap a user config in the versioned storage envelope.

    The dumped config is passed through ``encode_auth`` and the result is
    stored under the payload key next to the envelope version.

    Args:
        config: User MCP client configuration.

    Returns:
        JSON bytes of the envelope holding the encoded config.
    """
    encoded_config = encode_auth({"config": config.model_dump()})
    envelope = {
        _ENCRYPTED_CONFIG_VERSION_KEY: _ENCRYPTED_CONFIG_VERSION,
        _ENCRYPTED_CONFIG_PAYLOAD_KEY: encoded_config,
    }
    return orjson.dumps(envelope)

404 

405 

def _deserialize_user_config_from_storage(data: bytes | str) -> Optional[MCPClientConfig]:
    """Deserialize user config from encrypted or legacy plaintext payloads.

    Two layouts are understood: the envelope written by
    ``_serialize_user_config_for_storage`` (a dict holding the encoded
    payload) and a legacy layout where the JSON object is the config itself.

    Args:
        data: Serialized config payload from storage.

    Returns:
        Parsed ``MCPClientConfig`` when data is valid, otherwise ``None``.
        Undecodable payloads and configs that fail model validation are
        reported as ``None`` rather than raised, so a corrupt or stale entry
        behaves like a missing one.
    """
    try:
        parsed = orjson.loads(data)
    except Exception:
        logger.warning("Failed to parse stored LLM chat config payload")
        return None

    if not isinstance(parsed, dict):
        return None

    # Legacy plaintext payloads carry the config fields at the top level.
    config_data = parsed

    # New encrypted envelope
    if _ENCRYPTED_CONFIG_PAYLOAD_KEY in parsed:
        encrypted_payload = parsed.get(_ENCRYPTED_CONFIG_PAYLOAD_KEY)
        if not encrypted_payload:
            return None
        try:
            decoded = decode_auth(encrypted_payload)
        except Exception:
            logger.warning("Failed to decode encrypted LLM chat config payload")
            return None
        config_data = decoded.get("config") if isinstance(decoded, dict) else None
        if not isinstance(config_data, dict):
            logger.warning("Decoded encrypted LLM chat config is invalid")
            return None

    try:
        return MCPClientConfig(**config_data)
    except Exception:
        # The exception text is deliberately not logged: validation errors
        # may echo field values, and this config can hold credentials.
        logger.warning("Stored LLM chat config failed validation")
        return None

438 

439 

def _is_sensitive_config_key(key: str) -> bool:
    """Tell whether a config field name denotes a secret that must be masked.

    The name is stringified, stripped of surrounding whitespace and
    lower-cased before it is looked up.

    Args:
        key: Config field name.

    Returns:
        ``True`` when the normalized name is one of the sensitive keys.
    """
    normalized = str(key).strip().lower()
    return normalized in _CONFIG_SENSITIVE_KEYS

450 

451 

def _mask_sensitive_config_values(value: Any) -> Any:
    """Return a copy of a nested config with secret values replaced by the mask.

    Dicts and lists are walked recursively; any other value is returned
    unchanged. Under a sensitive key, ``None`` and the empty string are kept
    as they are, and every other value becomes ``settings.masked_auth_value``.

    Args:
        value: Arbitrary nested config value.

    Returns:
        The value with sensitive fields masked.
    """
    if isinstance(value, list):
        return [_mask_sensitive_config_values(element) for element in value]
    if not isinstance(value, dict):
        return value

    result: Dict[str, Any] = {}
    for name, entry in value.items():
        if not _is_sensitive_config_key(name):
            result[name] = _mask_sensitive_config_values(entry)
            continue
        is_blank = entry in (None, "")
        result[name] = entry if is_blank else settings.masked_auth_value
    return result

472 

473 

474# ---------- CONFIG HELPERS ---------- 

475 

476 

async def set_user_config(user_id: str, config: MCPClientConfig):
    """Persist a user's configuration in Redis, or in memory when Redis is absent.

    Redis entries expire after ``USER_CONFIG_TTL`` seconds; in-memory entries
    record a monotonic timestamp that ``get_user_config`` checks on read.

    Args:
        user_id: User identifier.
        config: Complete MCP client configuration.
    """
    payload = _serialize_user_config_for_storage(config)
    if not redis_client:
        user_configs[user_id] = (payload, time.monotonic())
        return
    await redis_client.set(_cfg_key(user_id), payload, ex=USER_CONFIG_TTL)

489 

490 

async def get_user_config(user_id: str) -> Optional[MCPClientConfig]:
    """Load a user's configuration from Redis, or from memory when Redis is absent.

    In-memory entries older than ``USER_CONFIG_TTL`` seconds are dropped and
    treated as missing.

    Args:
        user_id: User identifier.

    Returns:
        Optional[MCPClientConfig]: User configuration if found, None otherwise.
    """
    if redis_client:
        raw = await redis_client.get(_cfg_key(user_id))
        return _deserialize_user_config_from_storage(raw) if raw else None

    entry = user_configs.get(user_id)
    if not entry:
        return None

    payload, stored_at = entry
    age = time.monotonic() - stored_at
    if age > USER_CONFIG_TTL:
        user_configs.pop(user_id, None)
        return None
    return _deserialize_user_config_from_storage(payload)

514 

515 

async def delete_user_config(user_id: str):
    """Remove a user's stored configuration from Redis, or from memory when Redis is absent.

    Args:
        user_id: User identifier.
    """
    if not redis_client:
        user_configs.pop(user_id, None)
        return
    await redis_client.delete(_cfg_key(user_id))

526 

527 

528# ---------- SESSION (active) HELPERS with locking & recreate ---------- 

529 

530 

async def set_active_session(user_id: str, session: MCPChatService):
    """Record a session in this worker and, with Redis, claim ownership of it.

    Args:
        user_id: User identifier.
        session: Initialized MCPChatService instance.
    """
    active_sessions[user_id] = session
    if not redis_client:
        return
    # The ownership marker expires after SESSION_TTL so a dead worker eventually loses it.
    await redis_client.set(_active_key(user_id), WORKER_ID, ex=SESSION_TTL)

542 

543 

async def delete_active_session(user_id: str):
    """Drop the local session and release this worker's ownership marker in Redis.

    The Redis key is removed through a Lua script that deletes it only when
    its value equals this worker's id, so a marker written by another worker
    is never removed. Redis failures are logged and not raised.

    Args:
        user_id: User identifier.
    """
    active_sessions.pop(user_id, None)
    if not redis_client:
        return

    # Compare-and-delete performed inside Redis as a single script
    owner_checked_delete = """
            if redis.call("get", KEYS[1]) == ARGV[1] then
                return redis.call("del", KEYS[1])
            else
                return 0
            end
            """
    try:
        await redis_client.eval(owner_checked_delete, 1, _active_key(user_id), WORKER_ID)
    except Exception as e:
        logger.warning(f"Failed to delete active session for user {SecurityValidator.sanitize_log_message(user_id)}: {e}")

568 

569 

async def _try_acquire_lock(user_id: str) -> bool:
    """Attempt to acquire the initialization lock for a user session.

    Without Redis there is nothing to coordinate, so the lock counts as
    acquired. With Redis the lock key is set to this worker's id only if it
    does not exist yet, and expires after ``LOCK_TTL`` seconds.

    Args:
        user_id: User identifier.

    Returns:
        bool: True if lock acquired, False otherwise.
    """
    if not redis_client:
        return True  # no redis -> local only, no lock required
    # SET ... NX replies with None when the key already exists; normalize the
    # reply so the function honours its declared bool return type.
    return bool(await redis_client.set(_lock_key(user_id), WORKER_ID, nx=True, ex=LOCK_TTL))

582 

583 

async def _release_lock_safe(user_id: str):
    """Release the initialization lock, but only if this worker holds it.

    The check and the delete run inside one Lua script, so a lock that
    expired and was re-acquired by another worker is not removed. Redis
    failures are logged and not raised. Does nothing without Redis.

    Args:
        user_id: User identifier.
    """
    if not redis_client:
        return

    # Compare-and-delete performed inside Redis as a single script
    owner_checked_delete = """
        if redis.call("get", KEYS[1]) == ARGV[1] then
            return redis.call("del", KEYS[1])
        else
            return 0
        end
        """
    try:
        await redis_client.eval(owner_checked_delete, 1, _lock_key(user_id), WORKER_ID)
    except Exception as e:
        logger.warning(f"Failed to release lock for user {SecurityValidator.sanitize_log_message(user_id)}: {e}")

608 

609 

async def _create_local_session_from_config(user_id: str) -> Optional[MCPChatService]:
    """Build and register a chat service in this worker from the stored config.

    Args:
        user_id: User identifier.

    Returns:
        Optional[MCPChatService]: The initialized service, or None when no
        config is stored or initialization/registration fails.
    """
    stored_config = await get_user_config(user_id)
    if not stored_config:
        return None

    try:
        service = MCPChatService(stored_config, user_id=user_id, redis_client=redis_client)
        await service.initialize()
        await set_active_session(user_id, service)
    except Exception as e:
        logger.error(f"Failed to initialize MCPChatService for {SecurityValidator.sanitize_log_message(user_id)}: {e}", exc_info=True)
        # Leave nothing half-registered: clear local state and our Redis ownership marker.
        await delete_active_session(user_id)
        return None
    return service

635 

636 

async def _create_session_under_lock(user_id: str) -> tuple[bool, Optional[MCPChatService]]:
    """Create the session from stored config while holding the init lock.

    Args:
        user_id: User identifier.

    Returns:
        ``(acquired, session)``. ``acquired`` is False when the lock could
        not be taken, in which case ``session`` is None. When the lock was
        taken it is always released again, and ``session`` is the outcome of
        ``_create_local_session_from_config`` (possibly None).
    """
    if not await _try_acquire_lock(user_id):
        return False, None
    try:
        return True, await _create_local_session_from_config(user_id)
    finally:
        await _release_lock_safe(user_id)


async def get_active_session(user_id: str) -> Optional[MCPChatService]:
    """Retrieve or (if possible) create the active session for user_id.

    Behavior:
    - If Redis is disabled: return local session or None.
    - If Redis enabled:
        * If owner == WORKER_ID and local session exists -> return it (and refresh TTL)
        * If owner == WORKER_ID but local missing -> try to acquire lock and recreate
        * If no owner -> try to acquire lock and create session here
        * If owner != WORKER_ID -> return None (TTL expiry handles stale owners)

    Args:
        user_id: User identifier.

    Returns:
        Optional[MCPChatService]: Active session if available, None otherwise.
    """
    # Fast path: no redis => purely local
    if not redis_client:
        return active_sessions.get(user_id)

    active_key = _active_key(user_id)
    # NOTE(review): the comparisons below assume the Redis client returns str
    # values (decoded responses) - confirm against the shared client factory.
    owner = await redis_client.get(active_key)

    # 1) Owned by this worker
    if owner == WORKER_ID:
        if local := active_sessions.get(user_id):
            # refresh TTL so ownership persists while active
            try:
                await redis_client.expire(active_key, SESSION_TTL)
            except Exception as e:  # nosec B110
                # non-fatal if expire fails, just log the error
                logger.debug(f"Failed to refresh session TTL for {SecurityValidator.sanitize_log_message(user_id)}: {e}")
            return local

        # Redis names this worker as owner but the local session is missing
        # (process restart or lost). Try to recreate it under the lock.
        acquired, session = await _create_session_under_lock(user_id)
        if acquired:
            return session

        # someone else is (re)creating; wait a bit for them to finish
        for _ in range(LOCK_RETRIES):
            await asyncio.sleep(LOCK_WAIT)
            if local := active_sessions.get(user_id):
                return local
        return None

    # 2) No owner -> try to claim & create session locally
    if owner is None:
        acquired, session = await _create_session_under_lock(user_id)
        if acquired:
            return session

        # if we couldn't acquire lock, someone else is creating; wait a short time
        for _ in range(LOCK_RETRIES):
            await asyncio.sleep(LOCK_WAIT)
            current_owner = await redis_client.get(active_key)
            if current_owner == WORKER_ID and (local := active_sessions.get(user_id)):
                return local
            if current_owner is not None and current_owner != WORKER_ID:
                # some other worker now owns it
                return None

        # final attempt to acquire lock (last resort); session is None when the lock is still held
        _, session = await _create_session_under_lock(user_id)
        return session

    # 3) Owned by another worker -> we don't have it locally.
    # No attempt is made to steal a stale owner; TTL expiry handles that.
    return None

726 

727 

728# ---------- ROUTES ---------- 

729 

730 

@llmchat_router.post("/connect")
@require_permission("llm.invoke")
async def connect(input_data: ConnectInput, request: Request, user=Depends(get_current_user_with_permissions)):
    """Create or refresh a chat session for a user.

    Initializes a new MCPChatService instance for the authenticated user. If a
    session already exists for the user, it is shut down and removed before the
    new one is created. The chat history is cleared on every new connection.

    When a server section is sent without an auth token, the token is taken
    from the ``jwt_token`` cookie of the request.

    Args:
        input_data: ConnectInput containing user_id, optional server config, LLM selection, and streaming preference.
        request: FastAPI Request object for accessing cookies.
        user: Authenticated user context.

    Returns:
        dict: Connection status response containing:
            - status: 'connected'
            - user_id: The connected user's identifier
            - provider: The LLM provider being used
            - tool_count: Number of available MCP tools
            - tools: List of tool names

    Raises:
        HTTPException: If an error occurs.
            400: Invalid user_id, invalid server URL, invalid configuration, or LLM config error.
            401: Missing authentication (no user context, or no token for the MCP server).
            403: Client-supplied user_id differs from the authenticated user.
            503: Failed to connect to MCP server.
            500: Service initialization failure or unexpected error.

    Examples:
        This endpoint is called via HTTP POST and cannot be directly tested with doctest.
        Example request body:

        {
            "user_id": "user123",
            "server": {
                "url": "http://localhost:8000/mcp",
                "auth_token": "jwt_token_here"
            },
            "llm": {
                "model": "gpt-4o",
                "temperature": 0.7
            },
            "streaming": false
        }

        Example response:

        {
            "status": "connected",
            "user_id": "user123",
            "provider": "gateway",
            "tool_count": 5,
            "tools": ["search", "calculator", "weather", "translate", "summarize"]
        }

    Note:
        Existing sessions are automatically terminated before establishing new ones.
    """
    user_id = _resolve_user_id(input_data.user_id, user)

    try:
        # Validate user_id
        if not user_id or not isinstance(user_id, str):
            raise HTTPException(status_code=400, detail="Invalid user ID provided")

        # Validate user-supplied server URLs with SSRF protections before any outbound connection setup.
        if input_data.server and input_data.server.url:
            try:
                input_data.server.url = SecurityValidator.validate_url(str(input_data.server.url), "MCP server URL")
            except ValueError as e:
                # The URL (and possibly the error text) is attacker-controlled:
                # sanitize everything before it reaches the log.
                logger.warning(
                    "LLM chat connect URL validation failed for user %s and URL %s: %s",
                    SecurityValidator.sanitize_log_message(user_id),
                    SecurityValidator.sanitize_log_message(str(input_data.server.url)),
                    SecurityValidator.sanitize_log_message(str(e)),
                )
                raise HTTPException(status_code=400, detail="Invalid server URL") from e

        # Handle authentication token
        empty_token = ""  # nosec B105
        if input_data.server and (input_data.server.auth_token is None or input_data.server.auth_token == empty_token):
            jwt_token = request.cookies.get("jwt_token")
            if not jwt_token:
                raise HTTPException(status_code=401, detail="Authentication required. Please ensure you are logged in.")
            input_data.server.auth_token = jwt_token

        # Close old session if it exists
        existing = await get_active_session(user_id)
        if existing:
            try:
                logger.debug(f"Disconnecting existing session for {SecurityValidator.sanitize_log_message(user_id)} before reconnecting")
                await existing.shutdown()
            except Exception as shutdown_error:
                logger.warning(f"Failed to cleanly shutdown existing session for {SecurityValidator.sanitize_log_message(user_id)}: {shutdown_error}")
            finally:
                # Always remove the session from active sessions, even if shutdown failed
                await delete_active_session(user_id)

        # Build and validate configuration
        try:
            config = build_config(input_data)
        except ValueError as ve:
            raise HTTPException(status_code=400, detail=f"Invalid configuration: {str(ve)}") from ve
        except Exception as config_error:
            raise HTTPException(status_code=400, detail=f"Configuration error: {str(config_error)}") from config_error

        # Store user configuration
        await set_user_config(user_id, config)

        # Initialize chat service
        try:
            chat_service = MCPChatService(config, user_id=user_id, redis_client=redis_client)
            await chat_service.initialize()

            # Clear chat history on new connection
            await chat_service.clear_history()
        except ConnectionError as ce:
            # Clean up partial state
            await delete_user_config(user_id)
            raise HTTPException(status_code=503, detail=f"Failed to connect to MCP server: {str(ce)}. Please verify the server URL and authentication.") from ce
        except ValueError as ve:
            # Clean up partial state
            await delete_user_config(user_id)
            raise HTTPException(status_code=400, detail=f"Invalid LLM configuration: {str(ve)}") from ve
        except Exception as init_error:
            # Clean up partial state
            await delete_user_config(user_id)
            raise HTTPException(status_code=500, detail=f"Service initialization failed: {str(init_error)}") from init_error

        await set_active_session(user_id, chat_service)

        # Extract tool names; a failure here only costs the tools list, not the connection
        tool_names = []
        try:
            if hasattr(chat_service, "_tools") and chat_service._tools:
                tool_names = [name for name in (getattr(tool, "name", None) for tool in chat_service._tools) if name]
        except Exception as tool_error:
            logger.warning(f"Failed to extract tool names: {tool_error}")
            # Continue without tools list

        return {"status": "connected", "user_id": user_id, "provider": config.llm.provider, "tool_count": len(tool_names), "tools": tool_names}

    except HTTPException:
        # Re-raise HTTP exceptions as-is
        raise
    except Exception as e:
        logger.error(f"Unexpected error in connect endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Unexpected connection error: {str(e)}") from e

882 

883 

async def token_streamer(chat_service: MCPChatService, message: str, user_id: str):
    """Relay chat-service events to the client as Server-Sent Events (SSE).

    Consumes ``chat_service.chat_events(message)`` and re-emits every recognised
    event as two UTF-8 encoded chunks: an ``event:`` line followed by a ``data:``
    line carrying the JSON-serialised payload and a blank-line terminator.
    Events whose ``type`` is not listed below are dropped without output.

    Args:
        chat_service: MCPChatService instance whose event stream is consumed.
        message: User's chat message forwarded to the chat service.
        user_id: User identifier of the session. It is accepted for interface
            compatibility; this generator does not read it.

    Yields:
        bytes: SSE-formatted chunks for the following event types:
            - token: ``{"content": "text chunk"}``
            - tool_start / tool_end / tool_error: the service event, forwarded unchanged
            - final: the service event, forwarded unchanged
            - error: ``{"type": "error", "error": "message", "recoverable": bool}``

    Examples:
        This is an async generator used internally by the chat endpoint.
        It cannot be directly tested with standard doctest.

        Example event stream:

            event: token
            data: {"content": "Hello"}

            event: final
            data: {"type": "final", "text": "Hello, how can I help?"}

    Note:
        Failures raised while consuming the event stream are not propagated;
        they are converted into a single trailing ``error`` event. Timeouts and
        ChatProcessingError are reported as recoverable, everything else is not.
    """
    tool_event_types = ("tool_start", "tool_end", "tool_error")

    async def _frame(event_type: str, data: Dict[str, Any]):
        """Serialise one event into its two SSE chunks.

        Args:
            event_type: SSE event type identifier.
            data: Payload dictionary to serialize as JSON.

        Yields:
            bytes: The ``event:`` line, then the ``data:`` line with terminator.
        """
        header = f"event: {event_type}\n"
        yield header.encode("utf-8")
        # The payload is serialised only after the header chunk has been consumed.
        body = orjson.dumps(data).decode()
        yield f"data: {body}\n\n".encode("utf-8")

    def _failure(text: str, recoverable: bool) -> Dict[str, Any]:
        """Build the payload of an ``error`` event.

        Args:
            text: Human-readable error description.
            recoverable: Whether the client may retry on the same session.

        Returns:
            Dict[str, Any]: Error event payload.
        """
        return {"type": "error", "error": text, "recoverable": recoverable}

    failure: Optional[Dict[str, Any]] = None

    try:
        async for event in chat_service.chat_events(message):
            kind = event.get("type")
            if kind == "token":
                label, payload = "token", {"content": event.get("content", "")}
            elif kind in tool_event_types:
                label, payload = kind, event
            elif kind == "final":
                label, payload = "final", event
            else:
                # Unrecognised event types produce no output.
                continue

            async for chunk in _frame(label, payload):
                yield chunk

    except ConnectionError as conn_err:
        failure = _failure(f"Connection lost: {str(conn_err)}", False)
    except TimeoutError:
        failure = _failure("Request timed out waiting for LLM response", True)
    except ChatProcessingError as processing_err:
        # ChatProcessingError wraps tool/parsing/model errors —
        # the session is still valid, so the frontend can retry.
        failure = _failure(f"Service error: {str(processing_err)}", True)
    except RuntimeError as runtime_err:
        failure = _failure(f"Service error: {str(runtime_err)}", False)
    except Exception as unexpected:
        logger.error(f"Unexpected streaming error: {unexpected}", exc_info=True)
        failure = _failure(f"Unexpected error: {str(unexpected)}", False)

    # Single emit point for whichever failure (if any) was recorded above.
    if failure is not None:
        async for chunk in _frame("error", failure):
            yield chunk

983 

984 

@llmchat_router.post("/chat")
@require_permission("llm.invoke")
async def chat(input_data: ChatInput, user=Depends(get_current_user_with_permissions)):
    """Forward a message to the caller's active chat session and return the reply.

    The effective user id is resolved from the request body and the
    authenticated user. The message is then handed to that user's active
    MCPChatService, either as an SSE stream or as a single JSON response,
    depending on ``input_data.streaming``.

    Args:
        input_data: ChatInput containing user_id, message, and streaming preference.
        user: Authenticated user context.

    Returns:
        For streaming=False:
            dict: Response containing:
                - user_id: Resolved session identifier
                - response: The ``text`` field of the chat service result
                - tool_used: Copied from the chat service result
                - tools: Copied from the chat service result
                - tool_invocations: Copied from the chat service result
                - elapsed_ms: Copied from the chat service result
        For streaming=True:
            StreamingResponse: ``text/event-stream`` response fed by token_streamer.

    Raises:
        HTTPException: Raised when an HTTP-related error occurs.
            400: Missing user_id, empty/whitespace-only message, or no active session.
            503: Session not initialized, RuntimeError from the chat service, or connection lost.
            504: Request timeout.
            500: Any other unexpected error.

    Examples:
        This endpoint is called via HTTP POST and cannot be directly tested with doctest.

        Example non-streaming request:

            {
                "user_id": "user123",
                "message": "What's the weather like?",
                "streaming": false
            }

        Example streaming request:

            {
                "user_id": "user123",
                "message": "Tell me a story",
                "streaming": true
            }

    Note:
        In streaming mode the response is returned before the chat service is
        consumed, so failures during generation surface as SSE ``error`` events
        rather than HTTP error statuses.
    """
    user_id = _resolve_user_id(input_data.user_id, user)

    # --- Request validation -------------------------------------------------
    if not user_id:
        raise HTTPException(status_code=400, detail="User ID is required")

    text = input_data.message
    if not (text and text.strip()):
        raise HTTPException(status_code=400, detail="Message cannot be empty")

    # --- Session lookup -----------------------------------------------------
    chat_service = await get_active_session(user_id)
    if not chat_service:
        raise HTTPException(status_code=400, detail="No active session found. Please connect to a server first.")

    if not chat_service.is_initialized:
        raise HTTPException(status_code=503, detail="Session is not properly initialized. Please reconnect.")

    try:
        if input_data.streaming:
            event_stream = token_streamer(chat_service, text, user_id)
            sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}  # Disable proxy buffering
            return StreamingResponse(event_stream, media_type="text/event-stream", headers=sse_headers)

        try:
            result = await chat_service.chat_with_metadata(text)

            reply = {"user_id": user_id, "response": result["text"]}
            for field in ("tool_used", "tools", "tool_invocations", "elapsed_ms"):
                reply[field] = result[field]
            return reply
        except RuntimeError as runtime_err:
            raise HTTPException(status_code=503, detail=f"Chat service error: {str(runtime_err)}")

    except HTTPException:
        # Already carries the intended status code; pass it through untouched.
        raise
    except ConnectionError as conn_err:
        raise HTTPException(status_code=503, detail=f"Lost connection to MCP server: {str(conn_err)}. Please reconnect.")
    except TimeoutError:
        raise HTTPException(status_code=504, detail="Request timed out. The LLM took too long to respond.")
    except Exception as unexpected:
        logger.error(f"Unexpected error in chat endpoint for user {SecurityValidator.sanitize_log_message(user_id)}: {unexpected}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(unexpected)}")

1100 

1101 

@llmchat_router.post("/disconnect")
@require_permission("llm.invoke")
async def disconnect(input_data: DisconnectInput, user=Depends(get_current_user_with_permissions)):
    """End the chat session for a user and clean up resources.

    Removes the user's active session and stored configuration, then clears
    the chat history and shuts down the MCPChatService instance. The two
    cleanup steps are attempted independently, so a failure while clearing
    history does not prevent the service from being shut down.

    Args:
        input_data: DisconnectInput containing the user_id to disconnect.
        user: Authenticated user context.

    Returns:
        dict: Disconnection status containing:
            - status: One of 'disconnected', 'no_active_session', or 'disconnected_with_errors'
            - user_id: The resolved user identifier
            - message: Human-readable status description
            - warning: (Only with 'disconnected_with_errors') the cleanup error
              message(s); multiple messages are joined with "; "

    Raises:
        HTTPException: Raised when an HTTP-related error occurs.
            400: Missing user_id.

    Examples:
        This endpoint is called via HTTP POST and cannot be directly tested with doctest.

        Example request:

            {
                "user_id": "user123"
            }

        Example successful response:

            {
                "status": "disconnected",
                "user_id": "user123",
                "message": "Successfully disconnected"
            }

        Example response when no session exists:

            {
                "status": "no_active_session",
                "user_id": "user123",
                "message": "No active session to disconnect"
            }

    Note:
        Calling this endpoint when no session exists is not an error; it
        returns the 'no_active_session' status. Errors raised by the cleanup
        steps are logged and reported in the response instead of being raised.
    """
    user_id = _resolve_user_id(input_data.user_id, user)

    if not user_id:
        raise HTTPException(status_code=400, detail="User ID is required")

    # Fetch the service before deleting its session record so it can still be
    # cleaned up afterwards.
    chat_service = await get_active_session(user_id)
    await delete_active_session(user_id)

    # Remove user config
    await delete_user_config(user_id)

    if not chat_service:
        return {"status": "no_active_session", "user_id": user_id, "message": "No active session to disconnect"}

    cleanup_errors = []

    # Step 1: clear chat history. A failure here must not skip the shutdown below.
    try:
        await chat_service.clear_history()
        logger.info(f"Chat session disconnected for {SecurityValidator.sanitize_log_message(user_id)}")
    except Exception as e:
        logger.error(f"Error during disconnect for user {SecurityValidator.sanitize_log_message(user_id)}: {e}", exc_info=True)
        cleanup_errors.append(str(e))

    # Step 2: always attempt to shut the service down. The session record is
    # already deleted, so this is the last point at which the service can be closed.
    try:
        await chat_service.shutdown()
    except Exception as e:
        logger.error(f"Error during disconnect for user {SecurityValidator.sanitize_log_message(user_id)}: {e}", exc_info=True)
        cleanup_errors.append(str(e))

    if cleanup_errors:
        # Session already removed, so return success with warning
        return {
            "status": "disconnected_with_errors",
            "user_id": user_id,
            "message": "Disconnected but cleanup encountered errors",
            "warning": "; ".join(cleanup_errors),
        }

    return {"status": "disconnected", "user_id": user_id, "message": "Successfully disconnected"}

1181 

1182 

@llmchat_router.get("/status/{user_id}")
@require_permission("llm.read")
async def status(user_id: str, user=Depends(get_current_user_with_permissions)):
    """Report whether the given user currently has an active chat session.

    Read-only lookup: the user id is resolved against the authenticated user
    and the active session for the resolved id is fetched and tested for
    truthiness. Nothing is created, modified, or deleted.

    Args:
        user_id: User identifier to check session status for.
        user: Authenticated user context.

    Returns:
        dict: Status information containing:
            - user_id: The resolved user identifier
            - connected: Boolean indicating if an active session was found

    Examples:
        This endpoint is called via HTTP GET and cannot be directly tested with doctest.

        Example request:
            GET /llmchat/status/user123

        Example response (connected):

            {
                "user_id": "user123",
                "connected": true
            }

        Example response (not connected):

            {
                "user_id": "user456",
                "connected": false
            }

    Note:
        Only the existence of a session is checked; the session's
        ``is_initialized`` flag is not consulted here.
    """
    effective_user_id = _resolve_user_id(user_id, user)
    session = await get_active_session(effective_user_id)
    return {"user_id": effective_user_id, "connected": bool(session)}

1227 

1228 

@llmchat_router.get("/config/{user_id}")
@require_permission("llm.read")
async def get_config(user_id: str, user=Depends(get_current_user_with_permissions)):
    """Return the stored session configuration for a user with secrets masked.

    The user id is resolved against the authenticated user, the stored
    configuration is loaded, dumped to a plain dictionary, and passed through
    ``_mask_sensitive_config_values`` before being returned.

    Args:
        user_id: User identifier whose configuration to retrieve.
        user: Authenticated user context.

    Returns:
        dict: The configuration dictionary as returned by
        ``_mask_sensitive_config_values`` (sensitive values such as API keys
        and auth tokens are expected to be masked by that helper).

    Raises:
        HTTPException: Raised when an HTTP-related error occurs.
            404: No configuration found for the resolved user_id.

    Examples:
        This endpoint is called via HTTP GET and cannot be directly tested with doctest.

        Example request:
            GET /llmchat/config/user123

        Example response:

            {
                "mcp_server": {
                    "url": "http://localhost:8000/mcp",
                    "transport": "streamable_http"
                },
                "llm": {
                    "provider": "ollama",
                    "config": {
                        "model": "llama3",
                        "temperature": 0.7
                    }
                },
                "enable_streaming": false
            }

    Security:
        The raw configuration is never returned directly; it always goes
        through the masking helper first. Never log or expose secret values.
    """
    target_user_id = _resolve_user_id(user_id, user)
    stored_config = await get_user_config(target_user_id)

    if stored_config:
        return _mask_sensitive_config_values(stored_config.model_dump())

    raise HTTPException(status_code=404, detail="No config found for this user.")

1288 

1289 

@llmchat_router.get("/gateway/models")
@require_permission("llm.read")
async def get_gateway_models(_user=Depends(get_current_user_with_permissions)):
    """Get available models from configured LLM providers.

    Opens a database session, asks LLMProviderService for the gateway models,
    and returns them serialised with ``model_dump()`` together with their count.
    These models can be used with the "gateway" provider type in /connect requests.

    Args:
        _user: Authenticated user context.

    Returns:
        dict: Response containing:
            - models: List of available models with provider info
            - count: Total number of available models

    Raises:
        HTTPException: 500 if retrieving the gateway models fails for any reason.

    Examples:
        GET /llmchat/gateway/models

        Response:
        {
            "models": [
                {
                    "id": "abc123",
                    "model_id": "gpt-4o",
                    "model_name": "GPT-4o",
                    "provider_id": "def456",
                    "provider_name": "OpenAI",
                    "provider_type": "openai",
                    "supports_streaming": true,
                    "supports_function_calling": true,
                    "supports_vision": true
                }
            ],
            "count": 1
        }
    """
    # Import here to avoid circular dependency
    # First-Party
    from mcpgateway.db import SessionLocal
    from mcpgateway.services.llm_provider_service import LLMProviderService

    llm_service = LLMProviderService()

    try:
        # The context manager closes the DB session even if the lookup fails.
        with SessionLocal() as db:
            models = llm_service.get_gateway_models(db)
            return {
                "models": [m.model_dump() for m in models],
                "count": len(models),
            }
    except Exception as e:
        # Include the traceback, as the other endpoint error handlers in this module do.
        logger.error(f"Failed to get gateway models: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Failed to retrieve gateway models: {str(e)}") from e