Coverage for mcpgateway / routers / llmchat_router.py: 99%

397 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/routers/llmchat_router.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Keval Mahajan 

6 

7LLM Chat Router Module 

8 

9This module provides FastAPI endpoints for managing LLM-based chat sessions 

10with MCP (Model Context Protocol) server integration. LLM providers are 

11configured via the Admin UI's LLM Settings and accessed through the gateway 

12provider. 

13 

14The module handles user session management, configuration, and real-time 

15streaming responses for conversational AI applications with unified chat 

16history management via ChatHistoryManager from mcp_client_chat_service. 

17 

18""" 

19 

20# Standard 

21import asyncio 

22import os 

23import time 

24from typing import Any, Dict, Optional 

25 

26# Third-Party 

27from fastapi import APIRouter, Depends, HTTPException, Request 

28from fastapi.responses import StreamingResponse 

29import orjson 

30from pydantic import BaseModel, Field 

31 

32try: 

33 # Third-Party 

34 import redis.asyncio # noqa: F401 - availability check only 

35 

36 REDIS_AVAILABLE = True 

37except ImportError: 

38 REDIS_AVAILABLE = False 

39 

40# First-Party 

41from mcpgateway.common.validators import SecurityValidator 

42from mcpgateway.config import settings 

43from mcpgateway.middleware.rbac import get_current_user_with_permissions, require_permission 

44from mcpgateway.services.logging_service import LoggingService 

45from mcpgateway.services.mcp_client_chat_service import ( 

46 GatewayConfig, 

47 LLMConfig, 

48 MCPChatService, 

49 MCPClientConfig, 

50 MCPServerConfig, 

51) 

52from mcpgateway.utils.redis_client import get_redis_client 

53from mcpgateway.utils.services_auth import decode_auth, encode_auth 

54 

# Initialize router for all /llmchat endpoints
llmchat_router = APIRouter(prefix="/llmchat", tags=["llmchat"])

# Redis client (initialized via init_redis() during app startup).
# Stays None when Redis is not configured; storage helpers then fall back
# to the in-memory dictionaries defined below.
redis_client = None

60 

61 

async def init_redis() -> None:
    """Initialize the module-level Redis client via the shared factory.

    Intended to be called once during application startup (main.py lifespan).
    Leaves ``redis_client`` unset unless the gateway is configured for the
    Redis cache backend and a client can actually be obtained.
    """
    global redis_client
    cache_backend = getattr(settings, "cache_type", None)
    redis_url = getattr(settings, "redis_url", None)
    if cache_backend != "redis" or not redis_url:
        # Redis not configured; in-memory fallbacks will be used.
        return
    redis_client = await get_redis_client()
    if redis_client:
        logger.info("LLMChat router connected to shared Redis client")

72 

73 

# Fallback in-memory stores (used when Redis unavailable).
# Active chat sessions per user: user_id -> initialized MCPChatService
active_sessions: Dict[str, MCPChatService] = {}

# Per-user configuration: user_id -> (serialized encrypted config, monotonic store time)
# The timestamp is used to emulate the Redis TTL for in-memory entries.
user_configs: Dict[str, tuple[bytes, float]] = {}

# Logging
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

84 

85# ---------- MODELS ---------- 

86 

87 

class LLMInput(BaseModel):
    """Selects which gateway-configured LLM model a chat session should use.

    Models must already exist in the gateway's LLM Settings (Admin UI ->
    Settings -> LLM Settings); this input only references one of them and
    optionally tunes its sampling behavior.

    Attributes:
        model: Model ID from the gateway's LLM Settings (UUID or model_id).
        temperature: Optional sampling temperature, constrained to [0.0, 2.0].
        max_tokens: Optional strictly-positive cap on generated tokens.

    Examples:
        >>> llm_input = LLMInput(model='gpt-4o')
        >>> llm_input.model
        'gpt-4o'

        >>> llm_input = LLMInput(model='abc123-uuid', temperature=0.5)
        >>> llm_input.temperature
        0.5
    """

    model: str = Field(..., description="Model ID from gateway LLM Settings (UUID or model_id)")
    temperature: Optional[float] = Field(None, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: Optional[int] = Field(None, gt=0, description="Maximum tokens to generate")

112 

113 

class ServerInput(BaseModel):
    """Connection parameters for an MCP (Model Context Protocol) server.

    Every field is optional; values omitted here are filled with gateway
    defaults when the final client configuration is assembled.

    Attributes:
        url: MCP server URL endpoint; when omitted, a default of
            'http://localhost:8000/mcp' is applied at config-build time.
        transport: Communication transport protocol. Defaults to 'streamable_http'.
        auth_token: Optional bearer/JWT token for authenticated server access.

    Examples:
        >>> server = ServerInput(url='http://example.com/mcp')
        >>> server.transport
        'streamable_http'

        >>> server = ServerInput()
        >>> server.url is None
        True
    """

    url: Optional[str] = None
    transport: Optional[str] = "streamable_http"
    auth_token: Optional[str] = None

139 

140 

class ConnectInput(BaseModel):
    """Request payload for establishing (or refreshing) a chat session.

    Bundles the target MCP server, the gateway LLM selection, and the
    streaming preference for a single user's session.

    Attributes:
        user_id: Unique identifier for the user session; must match the
            authenticated caller (enforced server-side).
        server: Optional MCP server configuration; defaults used when omitted.
        llm: Required LLM selection referencing a gateway-configured model.
        streaming: Whether responses should be streamed. Defaults to False.

    Examples:
        >>> connect = ConnectInput(user_id='user123', llm=LLMInput(model='gpt-4o'))
        >>> connect.streaming
        False

        >>> connect = ConnectInput(user_id='user456', llm=LLMInput(model='gpt-4o'), streaming=True)
        >>> connect.user_id
        'user456'
    """

    user_id: str
    server: Optional[ServerInput] = None
    llm: LLMInput = Field(..., description="LLM configuration with model from gateway LLM Settings")
    streaming: bool = False

167 

168 

class ChatInput(BaseModel):
    """Request payload for sending a single chat message.

    Carries the message text plus the session identifier and the caller's
    per-request streaming preference.

    Attributes:
        user_id: Identifier of the active user session.
        message: Chat message content to be processed.
        streaming: Whether to stream the response. Defaults to False.

    Examples:
        >>> chat = ChatInput(user_id='user123', message='Hello, AI!')
        >>> len(chat.message) > 0
        True

        >>> chat = ChatInput(user_id='user456', message='Tell me a story', streaming=True)
        >>> chat.streaming
        True
    """

    user_id: str
    message: str
    streaming: bool = False

192 

193 

class DisconnectInput(BaseModel):
    """Request payload for terminating a chat session.

    Only the user identifier is needed; all session state is looked up
    and cleaned up server-side.

    Attributes:
        user_id: Identifier of the session to disconnect.

    Examples:
        >>> disconnect = DisconnectInput(user_id='user123')
        >>> disconnect.user_id
        'user123'
    """

    user_id: str

209 

210 

211# ---------- HELPERS ---------- 

212 

213 

def build_llm_config(llm: LLMInput) -> LLMConfig:
    """Translate an :class:`LLMInput` into a gateway-provider ``LLMConfig``.

    The resulting configuration always routes through the ``gateway``
    provider, which resolves the model from the database entries created
    via Admin UI -> Settings -> LLM Settings.

    Args:
        llm: LLMInput containing the model ID and optional temperature/max_tokens.

    Returns:
        LLMConfig: Gateway provider configuration.

    Examples:
        >>> llm_input = LLMInput(model='gpt-4o')
        >>> config = build_llm_config(llm_input)
        >>> config.provider
        'gateway'

    Note:
        All LLM configuration is done via Admin UI -> Settings -> LLM Settings.
        The gateway provider looks up models from the database and creates
        the appropriate LLM instance based on provider type.
    """
    # Default temperature mirrors the gateway's conventional 0.7 when unset.
    effective_temperature = 0.7 if llm.temperature is None else llm.temperature
    gateway_cfg = GatewayConfig(
        model=llm.model,
        temperature=effective_temperature,
        max_tokens=llm.max_tokens,
    )
    return LLMConfig(provider="gateway", config=gateway_cfg)

245 

246 

def build_config(input_data: ConnectInput) -> MCPClientConfig:
    """Assemble the complete MCP client configuration for a connection request.

    Combines the (possibly omitted) MCP server settings with the gateway
    LLM configuration and the streaming flag.

    Args:
        input_data: ConnectInput with server, LLM, and streaming settings.

    Returns:
        MCPClientConfig: Complete client configuration ready for service initialization.

    Examples:
        >>> from mcpgateway.routers.llmchat_router import ConnectInput, LLMInput, build_config
        >>> connect = ConnectInput(user_id='user123', llm=LLMInput(model='gpt-4o'))
        >>> config = build_config(connect)
        >>> config.mcp_server.transport
        'streamable_http'

    Note:
        MCP server settings fall back to defaults when not provided;
        the LLM configuration always routes through the gateway provider.
    """
    server = input_data.server

    # Resolve each server field independently so partial ServerInput works.
    url = server.url if server and server.url else "http://localhost:8000/mcp"
    transport = server.transport if server and server.transport else "streamable_http"
    token = server.auth_token if server else None

    return MCPClientConfig(
        mcp_server=MCPServerConfig(url=url, transport=transport, auth_token=token),
        llm=build_llm_config(input_data.llm),
        enable_streaming=input_data.streaming,
    )

281 

282 

283def _get_user_id_from_context(user: Dict[str, Any]) -> str: 

284 """Extract a stable user identifier from the authenticated user context. 

285 

286 Args: 

287 user: Authenticated user context from RBAC dependency. 

288 

289 Returns: 

290 User identifier string or "unknown" if missing. 

291 """ 

292 if isinstance(user, dict): 

293 return user.get("id") or user.get("user_id") or user.get("sub") or user.get("email") or "unknown" 

294 return "unknown" if user is None else str(getattr(user, "id", user)) 

295 

296 

def _resolve_user_id(input_user_id: Optional[str], user: Dict[str, Any]) -> str:
    """Resolve the authenticated user ID, rejecting unauthenticated or spoofed requests.

    Args:
        input_user_id: User ID supplied in the request body (optional).
        user: Authenticated user context from the RBAC dependency.

    Returns:
        The authenticated user identifier.

    Raises:
        HTTPException: 401 when no identity can be resolved; 403 when the
            client-supplied ID does not match the authenticated identity.
    """
    authenticated_id = _get_user_id_from_context(user)
    if authenticated_id == "unknown":
        raise HTTPException(status_code=401, detail="Authentication required.")
    # A client may omit user_id entirely, but must not claim someone else's.
    if input_user_id and input_user_id != authenticated_id:
        raise HTTPException(status_code=403, detail="User ID mismatch.")
    return authenticated_id

316 

317 

# ---------- SESSION STORAGE HELPERS ----------

# Identify this worker uniquely (used for sticky session ownership in Redis).
# Sessions hold live network state, so they are pinned to the creating process.
WORKER_ID = str(os.getpid())

# Tunables sourced from gateway settings (configurable via environment)
SESSION_TTL = settings.llmchat_session_ttl  # seconds for active_session key TTL
LOCK_TTL = settings.llmchat_session_lock_ttl  # seconds for lock expiry
LOCK_RETRIES = settings.llmchat_session_lock_retries  # how many times to poll while waiting
LOCK_WAIT = settings.llmchat_session_lock_wait  # seconds between polls
USER_CONFIG_TTL = settings.llmchat_session_ttl  # stored configs share the session TTL

# Envelope keys/version for the encrypted config payloads kept in Redis/memory
_ENCRYPTED_CONFIG_PAYLOAD_KEY = "_encrypted_payload"
_ENCRYPTED_CONFIG_VERSION_KEY = "_version"
_ENCRYPTED_CONFIG_VERSION = "1"
# Config field names whose values must be masked in API responses
_CONFIG_SENSITIVE_KEYS = frozenset(
    {
        "api_key",
        "auth_token",
        "authorization",
        "access_token",
        "refresh_token",
        "client_secret",
        "secret_access_key",
        "session_token",
        "credentials_json",
        "password",
        "private_key",
    }
)

348 

349 

350# Redis key helpers 

351def _cfg_key(user_id: str) -> str: 

352 """Generate Redis key for user configuration storage. 

353 

354 Args: 

355 user_id: User identifier. 

356 

357 Returns: 

358 str: Redis key for storing user configuration. 

359 """ 

360 return f"user_config:{user_id}" 

361 

362 

363def _active_key(user_id: str) -> str: 

364 """Generate Redis key for active session tracking. 

365 

366 Args: 

367 user_id: User identifier. 

368 

369 Returns: 

370 str: Redis key for tracking active sessions. 

371 """ 

372 return f"active_session:{user_id}" 

373 

374 

375def _lock_key(user_id: str) -> str: 

376 """Generate Redis key for session initialization lock. 

377 

378 Args: 

379 user_id: User identifier. 

380 

381 Returns: 

382 str: Redis key for session locks. 

383 """ 

384 return f"session_lock:{user_id}" 

385 

386 

def _serialize_user_config_for_storage(config: MCPClientConfig) -> bytes:
    """Encrypt and serialize a user config for the storage backends.

    The config is encrypted via ``encode_auth`` and wrapped in a small
    versioned envelope so future format changes can be detected on read.

    Args:
        config: User MCP client configuration.

    Returns:
        Serialized bytes containing the encrypted config envelope.
    """
    encrypted = encode_auth({"config": config.model_dump()})
    envelope = {
        _ENCRYPTED_CONFIG_VERSION_KEY: _ENCRYPTED_CONFIG_VERSION,
        _ENCRYPTED_CONFIG_PAYLOAD_KEY: encrypted,
    }
    return orjson.dumps(envelope)

403 

404 

def _deserialize_user_config_from_storage(data: bytes | str) -> Optional[MCPClientConfig]:
    """Parse a stored config payload, supporting encrypted and legacy formats.

    Encrypted envelopes (the current format) are decrypted via ``decode_auth``;
    plain dict payloads are treated as legacy plaintext configs.

    Args:
        data: Serialized config payload from storage.

    Returns:
        Parsed ``MCPClientConfig`` when the data is valid, otherwise ``None``.
    """
    try:
        parsed = orjson.loads(data)
    except Exception:
        logger.warning("Failed to parse stored LLM chat config payload")
        return None

    if not isinstance(parsed, dict):
        return None

    if _ENCRYPTED_CONFIG_PAYLOAD_KEY not in parsed:
        # Legacy plaintext payload compatibility
        return MCPClientConfig(**parsed)

    # Current encrypted envelope format
    encrypted_payload = parsed.get(_ENCRYPTED_CONFIG_PAYLOAD_KEY)
    if not encrypted_payload:
        return None
    decoded = decode_auth(encrypted_payload)
    config_data = decoded.get("config") if isinstance(decoded, dict) else None
    if not isinstance(config_data, dict):
        logger.warning("Decoded encrypted LLM chat config is invalid")
        return None
    return MCPClientConfig(**config_data)

437 

438 

def _is_sensitive_config_key(key: str) -> bool:
    """Report whether a config field name must be masked in responses.

    Comparison is case-insensitive and ignores surrounding whitespace.

    Args:
        key: Config field name.

    Returns:
        ``True`` when the normalized key is in the sensitive-key set.
    """
    normalized = str(key).strip().lower()
    return normalized in _CONFIG_SENSITIVE_KEYS

449 

450 

def _mask_sensitive_config_values(value: Any) -> Any:
    """Recursively mask sensitive config values before returning them to clients.

    Dicts are walked key by key; keys flagged by ``_is_sensitive_config_key``
    have non-empty values replaced with the configured mask marker. Lists are
    masked element-wise; scalars pass through unchanged.

    Args:
        value: Arbitrary nested config value.

    Returns:
        Value with sensitive fields replaced by the configured mask marker.
    """
    if isinstance(value, dict):
        result: Dict[str, Any] = {}
        for field, nested in value.items():
            if not _is_sensitive_config_key(field):
                result[field] = _mask_sensitive_config_values(nested)
            elif nested in (None, ""):
                # Preserve empty/absent secrets as-is so callers can tell
                # "not set" apart from "set but hidden".
                result[field] = nested
            else:
                result[field] = settings.masked_auth_value
        return result
    if isinstance(value, list):
        return [_mask_sensitive_config_values(entry) for entry in value]
    return value

471 

472 

473# ---------- CONFIG HELPERS ---------- 

474 

475 

async def set_user_config(user_id: str, config: MCPClientConfig):
    """Persist a user's configuration in Redis, or in memory as fallback.

    The config is encrypted before storage; Redis entries carry a TTL while
    in-memory entries record a timestamp for manual expiry on read.

    Args:
        user_id: User identifier.
        config: Complete MCP client configuration.
    """
    serialized = _serialize_user_config_for_storage(config)
    if not redis_client:
        user_configs[user_id] = (serialized, time.monotonic())
        return
    await redis_client.set(_cfg_key(user_id), serialized, ex=USER_CONFIG_TTL)

488 

489 

async def get_user_config(user_id: str) -> Optional[MCPClientConfig]:
    """Load a user's configuration from Redis, or from memory as fallback.

    In-memory entries honor the same TTL as Redis by checking the stored
    monotonic timestamp and evicting stale entries on read.

    Args:
        user_id: User identifier.

    Returns:
        Optional[MCPClientConfig]: User configuration if found and fresh, else None.
    """
    if redis_client:
        payload = await redis_client.get(_cfg_key(user_id))
        if not payload:
            return None
        return _deserialize_user_config_from_storage(payload)

    entry = user_configs.get(user_id)
    if not entry:
        return None
    serialized, stored_at = entry
    if (time.monotonic() - stored_at) > USER_CONFIG_TTL:
        # Emulate Redis TTL: drop expired in-memory entries lazily.
        user_configs.pop(user_id, None)
        return None
    return _deserialize_user_config_from_storage(serialized)

513 

514 

async def delete_user_config(user_id: str):
    """Remove a user's stored configuration from Redis or memory.

    Args:
        user_id: User identifier.
    """
    if not redis_client:
        user_configs.pop(user_id, None)
        return
    await redis_client.delete(_cfg_key(user_id))

525 

526 

527# ---------- SESSION (active) HELPERS with locking & recreate ---------- 

528 

529 

async def set_active_session(user_id: str, session: MCPChatService):
    """Register a session locally and record ownership in Redis with a TTL.

    Args:
        user_id: User identifier.
        session: Initialized MCPChatService instance.
    """
    active_sessions[user_id] = session
    if not redis_client:
        return
    # Ownership marker expires on its own, so a dead worker eventually
    # loses the session and another worker can recreate it.
    await redis_client.set(_active_key(user_id), WORKER_ID, ex=SESSION_TTL)

541 

542 

async def delete_active_session(user_id: str):
    """Remove the active session locally and release Redis ownership atomically.

    Uses a Lua script so the Redis key is deleted only if this worker still
    owns it; this prevents the race where our expired marker was re-claimed
    by another worker and we would otherwise delete *their* ownership.

    Args:
        user_id: User identifier.
    """
    active_sessions.pop(user_id, None)
    if redis_client:
        try:
            # Lua script for atomic check-and-delete (only delete if we own the key)
            release_script = """
            if redis.call("get", KEYS[1]) == ARGV[1] then
                return redis.call("del", KEYS[1])
            else
                return 0
            end
            """
            # EVAL with 1 key: delete active_session:<user_id> only when its
            # stored value still equals this worker's id.
            await redis_client.eval(release_script, 1, _active_key(user_id), WORKER_ID)
        except Exception as e:
            # Best-effort: local state is already gone; Redis TTL will expire the marker.
            logger.warning(f"Failed to delete active session for user {user_id}: {e}")

567 

568 

async def _try_acquire_lock(user_id: str) -> bool:
    """Try to take the per-user session-initialization lock.

    The lock is a Redis ``SET NX`` with an expiry, so a crashed holder
    cannot block other workers forever.

    Args:
        user_id: User identifier.

    Returns:
        bool: True when the lock was acquired (or Redis is disabled), False otherwise.
    """
    if not redis_client:
        # Single-process/local mode: no cross-worker contention, so no lock.
        return True
    return await redis_client.set(_lock_key(user_id), WORKER_ID, nx=True, ex=LOCK_TTL)

581 

582 

async def _release_lock_safe(user_id: str):
    """Release the initialization lock atomically, only if this worker owns it.

    A Lua check-and-delete avoids the TOCTOU race where the lock expires
    between a get() and a delete() and another worker's fresh lock would
    be removed by mistake.

    Args:
        user_id: User identifier.
    """
    if not redis_client:
        return
    try:
        # Lua script for atomic check-and-delete (only delete if we own the key)
        release_script = """
        if redis.call("get", KEYS[1]) == ARGV[1] then
            return redis.call("del", KEYS[1])
        else
            return 0
        end
        """
        await redis_client.eval(release_script, 1, _lock_key(user_id), WORKER_ID)
    except Exception as e:
        # Non-fatal: LOCK_TTL guarantees the lock expires even if release fails.
        logger.warning(f"Failed to release lock for user {user_id}: {e}")

607 

608 

async def _create_local_session_from_config(user_id: str) -> Optional[MCPChatService]:
    """Recreate an MCPChatService on this worker from the stored config.

    Args:
        user_id: User identifier.

    Returns:
        Optional[MCPChatService]: Initialized, registered service, or None when
        no config exists or initialization fails.
    """
    config = await get_user_config(user_id)
    if not config:
        return None

    try:
        # Build and initialize with the unified history manager, then
        # register ownership so subsequent lookups find the session.
        service = MCPChatService(config, user_id=user_id, redis_client=redis_client)
        await service.initialize()
        await set_active_session(user_id, service)
        return service
    except Exception as e:
        logger.error(f"Failed to initialize MCPChatService for {user_id}: {e}", exc_info=True)
        # Roll back any partial local/Redis registration so nothing half-built remains.
        await delete_active_session(user_id)
        return None

634 

635 

async def get_active_session(user_id: str) -> Optional[MCPChatService]:
    """Retrieve — or, when this worker can claim ownership, recreate — the
    active session for ``user_id``.

    Behavior:
        - If Redis is disabled: return the local session or None.
        - If Redis is enabled:
            * owner == WORKER_ID and local session exists -> return it (refresh TTL)
            * owner == WORKER_ID but local missing -> acquire lock and recreate
            * no owner -> acquire lock and create the session here
            * owner is another worker -> return None (that worker serves it)

    Args:
        user_id: User identifier.

    Returns:
        Optional[MCPChatService]: Active session if available, None otherwise.
    """
    # Fast path: no redis => purely local, single-process semantics.
    if not redis_client:
        return active_sessions.get(user_id)

    active_key = _active_key(user_id)
    # NOTE(review): comparing owner directly to WORKER_ID (a str) assumes the
    # Redis client is configured with decode_responses — confirm in the factory.
    owner = await redis_client.get(active_key)

    # 1) Owned by this worker
    if owner == WORKER_ID:
        local = active_sessions.get(user_id)
        if local:
            # Refresh TTL so ownership persists while the session is in use.
            try:
                await redis_client.expire(active_key, SESSION_TTL)
            except Exception as e:  # nosec B110
                # Non-fatal: a missed refresh only shortens ownership, never corrupts it.
                logger.debug(f"Failed to refresh session TTL for {user_id}: {e}")
            return local

        # Redis says we own it but the local object is gone (process restart
        # or eviction). Take the init lock and rebuild from stored config.
        acquired = await _try_acquire_lock(user_id)
        if acquired:
            try:
                session = await _create_local_session_from_config(user_id)
                return session
            finally:
                await _release_lock_safe(user_id)
        else:
            # Another coroutine/worker is already (re)creating; poll briefly.
            for _ in range(LOCK_RETRIES):
                await asyncio.sleep(LOCK_WAIT)
                if active_sessions.get(user_id):
                    return active_sessions.get(user_id)
            return None

    # 2) No owner -> try to claim ownership and create the session locally.
    if owner is None:
        acquired = await _try_acquire_lock(user_id)
        if acquired:
            try:
                session = await _create_local_session_from_config(user_id)
                return session
            finally:
                await _release_lock_safe(user_id)

        # Lock held elsewhere: wait a short time for the holder to finish.
        for _ in range(LOCK_RETRIES):
            await asyncio.sleep(LOCK_WAIT)
            owner2 = await redis_client.get(active_key)
            if owner2 == WORKER_ID and active_sessions.get(user_id):
                return active_sessions.get(user_id)
            if owner2 is not None and owner2 != WORKER_ID:
                # Some other worker claimed it; they serve the session.
                return None

        # Final attempt to acquire the lock (last resort after polling).
        acquired = await _try_acquire_lock(user_id)
        if acquired:
            try:
                session = await _create_local_session_from_config(user_id)
                return session
            finally:
                await _release_lock_safe(user_id)
        return None

    # 3) Owned by another worker -> we don't have it locally.
    # We could "steal" a stale owner, but the TTL on the marker handles that.
    return None

725 

726 

727# ---------- ROUTES ---------- 

728 

729 

@llmchat_router.post("/connect")
@require_permission("llm.invoke")
async def connect(input_data: ConnectInput, request: Request, user=Depends(get_current_user_with_permissions)):
    """Create or refresh a chat session for a user.

    Initializes a new MCPChatService instance for the authenticated user,
    establishing connections to both the MCP server and the gateway-configured
    LLM. Any existing session for the user is shut down first.

    Authentication is handled via the JWT token cookie when no auth token is
    explicitly provided in the request body.

    Args:
        input_data: ConnectInput containing user_id, optional server/LLM config, and streaming preference.
        request: FastAPI Request object for accessing cookies and headers.
        user: Authenticated user context.

    Returns:
        dict: Connection status response containing:
            - status: 'connected'
            - user_id: The connected user's identifier
            - provider: The LLM provider being used (always 'gateway', see build_llm_config)
            - tool_count: Number of available MCP tools
            - tools: List of tool names

    Raises:
        HTTPException: If an error occurs.
            400: Invalid user_id, invalid server URL, or LLM config error.
            401: Missing authentication token.
            403: Body user_id does not match the authenticated user.
            503: Failed to connect to MCP server.
            500: Service initialization failure or unexpected error.

    Examples:
        This endpoint is called via HTTP POST and cannot be directly tested with doctest.
        Example request body:

            {
                "user_id": "user123",
                "server": {
                    "url": "http://localhost:8000/mcp",
                    "auth_token": "jwt_token_here"
                },
                "llm": {
                    "model": "gpt-4o"
                },
                "streaming": false
            }

        Example response:

            {
                "status": "connected",
                "user_id": "user123",
                "provider": "gateway",
                "tool_count": 5,
                "tools": ["search", "calculator", "weather", "translate", "summarize"]
            }

    Note:
        Existing sessions are automatically terminated before establishing new ones.
        All configuration values support environment variable fallbacks.
    """
    user_id = _resolve_user_id(input_data.user_id, user)

    try:
        # Validate user_id (defense in depth; _resolve_user_id already rejected "unknown")
        if not user_id or not isinstance(user_id, str):
            raise HTTPException(status_code=400, detail="Invalid user ID provided")

        # Validate user-supplied server URLs with SSRF protections before any outbound connection setup.
        if input_data.server and input_data.server.url:
            try:
                input_data.server.url = SecurityValidator.validate_url(str(input_data.server.url), "MCP server URL")
            except ValueError as e:
                logger.warning("LLM chat connect URL validation failed for user %s and URL %s: %s", user_id, input_data.server.url, e)
                raise HTTPException(status_code=400, detail="Invalid server URL")

        # Fill a missing/empty auth token from the caller's JWT cookie.
        empty_token = ""  # nosec B105
        if input_data.server and (input_data.server.auth_token is None or input_data.server.auth_token == empty_token):
            jwt_token = request.cookies.get("jwt_token")
            if not jwt_token:
                raise HTTPException(status_code=401, detail="Authentication required. Please ensure you are logged in.")
            input_data.server.auth_token = jwt_token

        # Close the old session if it exists
        existing = await get_active_session(user_id)
        if existing:
            try:
                logger.debug(f"Disconnecting existing session for {user_id} before reconnecting")
                await existing.shutdown()
            except Exception as shutdown_error:
                logger.warning(f"Failed to cleanly shutdown existing session for {user_id}: {shutdown_error}")
            finally:
                # Always remove the session from active sessions, even if shutdown failed
                await delete_active_session(user_id)

        # Build and validate configuration
        try:
            config = build_config(input_data)
        except ValueError as ve:
            raise HTTPException(status_code=400, detail=f"Invalid configuration: {str(ve)}")
        except Exception as config_error:
            raise HTTPException(status_code=400, detail=f"Configuration error: {str(config_error)}")

        # Store the (encrypted) user configuration so the session can be recreated later
        await set_user_config(user_id, config)

        # Initialize chat service; map failure modes to distinct HTTP statuses
        try:
            chat_service = MCPChatService(config, user_id=user_id, redis_client=redis_client)
            await chat_service.initialize()

            # Clear chat history on new connection
            await chat_service.clear_history()
        except ConnectionError as ce:
            # Clean up partial state
            await delete_user_config(user_id)
            raise HTTPException(status_code=503, detail=f"Failed to connect to MCP server: {str(ce)}. Please verify the server URL and authentication.")
        except ValueError as ve:
            # Clean up partial state
            await delete_user_config(user_id)
            raise HTTPException(status_code=400, detail=f"Invalid LLM configuration: {str(ve)}")
        except Exception as init_error:
            # Clean up partial state
            await delete_user_config(user_id)
            raise HTTPException(status_code=500, detail=f"Service initialization failed: {str(init_error)}")

        await set_active_session(user_id, chat_service)

        # Extract tool names (best effort; the private _tools attribute may be absent)
        tool_names = []
        try:
            if hasattr(chat_service, "_tools") and chat_service._tools:
                for tool in chat_service._tools:
                    tool_name = getattr(tool, "name", None)
                    if tool_name:
                        tool_names.append(tool_name)
        except Exception as tool_error:
            logger.warning(f"Failed to extract tool names: {tool_error}")
            # Continue without tools list

        return {"status": "connected", "user_id": user_id, "provider": config.llm.provider, "tool_count": len(tool_names), "tools": tool_names}

    except HTTPException:
        # Re-raise HTTP exceptions as-is
        raise
    except Exception as e:
        logger.error(f"Unexpected error in connect endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Unexpected connection error: {str(e)}")

881 

882 

async def token_streamer(chat_service: MCPChatService, message: str, user_id: str):
    """Yield SSE-formatted byte chunks for a streaming chat response.

    Drives ``chat_service.chat_events`` and translates each event into
    Server-Sent Event frames of the form ``event: <type>\\ndata: <json>\\n\\n``.
    Any exception raised while streaming is converted into an ``error``
    event so clients see structured failure data instead of a dropped
    connection. History persistence is handled inside the chat service.

    Args:
        chat_service: Initialized MCPChatService for the user's session.
        message: The user's chat message to process.
        user_id: Session identifier (accepted for the caller's interface;
            not otherwise used in this generator).

    Yields:
        bytes: UTF-8 encoded SSE frame fragments. Emitted event types:
            - token: {"content": "<chunk>"}
            - tool_start / tool_end / tool_error: raw event payload
            - final: complete response with metadata
            - error: {"type": "error", "error": "<msg>", "recoverable": bool}
    """

    def _frame(kind: str, payload: Dict[str, Any]):
        """Build the two byte chunks that make up one SSE frame.

        Args:
            kind: SSE event type identifier.
            payload: Dictionary serialized as the JSON data line.

        Returns:
            list: Two UTF-8 encoded fragments (event line, data line).
        """
        return [
            f"event: {kind}\n".encode("utf-8"),
            f"data: {orjson.dumps(payload).decode()}\n\n".encode("utf-8"),
        ]

    try:
        async for event in chat_service.chat_events(message):
            kind = event.get("type")
            if kind == "token":
                chunks = _frame("token", {"content": event.get("content", "")})
            elif kind in ("tool_start", "tool_end", "tool_error", "final"):
                chunks = _frame(kind, event)
            else:
                # Unknown event types are silently dropped.
                continue
            for chunk in chunks:
                yield chunk
    except ConnectionError as conn_exc:
        # Transport-level failure: not recoverable within this stream.
        for chunk in _frame("error", {"type": "error", "error": f"Connection lost: {str(conn_exc)}", "recoverable": False}):
            yield chunk
    except TimeoutError:
        # The client may safely retry the same request.
        for chunk in _frame("error", {"type": "error", "error": "Request timed out waiting for LLM response", "recoverable": True}):
            yield chunk
    except RuntimeError as runtime_exc:
        for chunk in _frame("error", {"type": "error", "error": f"Service error: {str(runtime_exc)}", "recoverable": False}):
            yield chunk
    except Exception as exc:
        logger.error(f"Unexpected streaming error: {exc}", exc_info=True)
        for chunk in _frame("error", {"type": "error", "error": f"Unexpected error: {str(exc)}", "recoverable": False}):
            yield chunk

976 

977 

@llmchat_router.post("/chat")
@require_permission("llm.invoke")
async def chat(input_data: ChatInput, user=Depends(get_current_user_with_permissions)):
    """Send a message to the user's active chat session and receive a response.

    Processes the message through the configured LLM with MCP tool
    integration. When ``input_data.streaming`` is true, the reply is
    delivered as a Server-Sent Events stream; otherwise a single JSON
    payload is returned. Chat history is persisted automatically by the
    unified ChatHistoryManager.

    Args:
        input_data: ChatInput containing user_id, message, and streaming preference.
        user: Authenticated user context.

    Returns:
        StreamingResponse: SSE token/event stream when streaming is enabled.
        dict: Otherwise, a payload with user_id, response text, tool_used,
            tools, tool_invocations, and elapsed_ms.

    Raises:
        HTTPException: Raised when an HTTP-related error occurs.
            400: Missing user_id, empty message, or no active session.
            503: Session not initialized, chat service error, or connection lost.
            504: Request timed out.
            500: Unexpected error.

    Note:
        Streaming responses use Server-Sent Events (SSE) with the
        'text/event-stream' MIME type; the client must hold the
        connection open for the duration of the stream.
    """
    user_id = _resolve_user_id(input_data.user_id, user)

    # Validate input
    if not user_id:
        raise HTTPException(status_code=400, detail="User ID is required")

    if not input_data.message or not input_data.message.strip():
        raise HTTPException(status_code=400, detail="Message cannot be empty")

    # Check for active session
    chat_service = await get_active_session(user_id)
    if not chat_service:
        raise HTTPException(status_code=400, detail="No active session found. Please connect to a server first.")

    # Verify session is initialized
    if not chat_service.is_initialized:
        raise HTTPException(status_code=503, detail="Session is not properly initialized. Please reconnect.")

    try:
        if input_data.streaming:
            return StreamingResponse(
                token_streamer(chat_service, input_data.message, user_id),
                media_type="text/event-stream",
                headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},  # Disable proxy buffering
            )

        try:
            result = await chat_service.chat_with_metadata(input_data.message)

            return {
                "user_id": user_id,
                "response": result["text"],
                "tool_used": result["tool_used"],
                "tools": result["tools"],
                "tool_invocations": result["tool_invocations"],
                "elapsed_ms": result["elapsed_ms"],
            }
        except RuntimeError as runtime_err:
            # Chain the cause so the original traceback survives (B904);
            # renamed from 're' to avoid shadowing the common module name.
            raise HTTPException(status_code=503, detail=f"Chat service error: {str(runtime_err)}") from runtime_err

    except ConnectionError as conn_err:
        raise HTTPException(status_code=503, detail=f"Lost connection to MCP server: {str(conn_err)}. Please reconnect.") from conn_err
    except TimeoutError as timeout_err:
        raise HTTPException(status_code=504, detail="Request timed out. The LLM took too long to respond.") from timeout_err
    except HTTPException:
        # Re-raise HTTP exceptions (including the 503 raised above) as-is.
        raise
    except Exception as exc:
        logger.error(f"Unexpected error in chat endpoint for user {user_id}: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(exc)}") from exc

1093 

1094 

@llmchat_router.post("/disconnect")
@require_permission("llm.invoke")
async def disconnect(input_data: DisconnectInput, user=Depends(get_current_user_with_permissions)):
    """Terminate a user's chat session and release its resources.

    Shuts down the user's MCPChatService, clears its history, and removes
    both the session and the stored configuration. Idempotent: invoking
    it for a user with no active session is safe and simply reports
    ``no_active_session``.

    Args:
        input_data: DisconnectInput carrying the user_id to disconnect.
        user: Authenticated user context.

    Returns:
        dict: Disconnection result with:
            - status: 'disconnected', 'no_active_session', or
              'disconnected_with_errors'
            - user_id: The user identifier
            - message: Human-readable status description
            - warning: (optional) cleanup error details

    Raises:
        HTTPException: Raised when an HTTP-related error occurs.
            400: Missing user_id.
    """
    uid = _resolve_user_id(input_data.user_id, user)

    if not uid:
        raise HTTPException(status_code=400, detail="User ID is required")

    # Grab the service handle first, then drop every piece of stored state
    # for this user regardless of whether cleanup below succeeds.
    service = await get_active_session(uid)
    await delete_active_session(uid)
    await delete_user_config(uid)

    if not service:
        return {"status": "no_active_session", "user_id": uid, "message": "No active session to disconnect"}

    try:
        # Wipe chat history before tearing the service down.
        await service.clear_history()
        logger.info(f"Chat session disconnected for {uid}")

        await service.shutdown()
        return {"status": "disconnected", "user_id": uid, "message": "Successfully disconnected"}
    except Exception as cleanup_exc:
        logger.error(f"Error during disconnect for user {uid}: {cleanup_exc}", exc_info=True)
        # Session state is already removed, so report success with a warning.
        return {
            "status": "disconnected_with_errors",
            "user_id": uid,
            "message": "Disconnected but cleanup encountered errors",
            "warning": str(cleanup_exc),
        }

1174 

1175 

@llmchat_router.get("/status/{user_id}")
@require_permission("llm.read")
async def status(user_id: str, user=Depends(get_current_user_with_permissions)):
    """Report whether an active chat session exists for the given user.

    Read-only check suited to health probes and UI state refreshes. It
    only tests for the session's presence in active storage — it does not
    verify that the session is fully initialized.

    Args:
        user_id: User identifier to check session status for.
        user: Authenticated user context.

    Returns:
        dict: ``{"user_id": <resolved id>, "connected": <bool>}``.
    """
    uid = _resolve_user_id(user_id, user)
    session = await get_active_session(uid)
    return {"user_id": uid, "connected": bool(session)}

1220 

1221 

@llmchat_router.get("/config/{user_id}")
@require_permission("llm.read")
async def get_config(user_id: str, user=Depends(get_current_user_with_permissions)):
    """Return a user's stored session configuration with secrets masked.

    Sensitive values (API keys, auth tokens) are masked by
    ``_mask_sensitive_config_values`` before the payload is returned, so
    the response is safe to surface for debugging and configuration
    verification.

    Args:
        user_id: User identifier whose configuration to retrieve.
        user: Authenticated user context.

    Returns:
        dict: Sanitized configuration (mcp_server, llm, enable_streaming)
        with credentials removed.

    Raises:
        HTTPException: Raised when an HTTP-related error occurs.
            404: No configuration found for the specified user_id.

    Security:
        Never log or expose API keys / auth tokens in responses.
    """
    uid = _resolve_user_id(user_id, user)
    stored = await get_user_config(uid)

    if not stored:
        raise HTTPException(status_code=404, detail="No config found for this user.")

    return _mask_sensitive_config_values(stored.model_dump())

1281 

1282 

@llmchat_router.get("/gateway/models")
@require_permission("llm.read")
async def get_gateway_models(_user=Depends(get_current_user_with_permissions)):
    """Get available models from configured LLM providers.

    Returns the enabled models of enabled providers configured in the
    gateway's LLM Settings. These models can be used with the "gateway"
    provider type in /connect requests.

    Args:
        _user: Authenticated user context (used for authorization only).

    Returns:
        dict: Response containing:
            - models: List of available models with provider info
            - count: Total number of available models

    Raises:
        HTTPException: 500 if gateway models cannot be retrieved.
    """
    # Import here to avoid circular dependency
    # First-Party
    from mcpgateway.db import SessionLocal
    from mcpgateway.services.llm_provider_service import LLMProviderService

    llm_service = LLMProviderService()

    try:
        with SessionLocal() as db:
            models = llm_service.get_gateway_models(db)
            return {
                "models": [m.model_dump() for m in models],
                "count": len(models),
            }
    except Exception as e:
        # exc_info for parity with the other handlers in this module, and
        # chain the cause so the original traceback is preserved (B904).
        logger.error(f"Failed to get gateway models: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Failed to retrieve gateway models: {str(e)}") from e