Coverage for mcpgateway / routers / llmchat_router.py: 97%

337 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/routers/llmchat_router.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Keval Mahajan 

6 

7LLM Chat Router Module 

8 

9This module provides FastAPI endpoints for managing LLM-based chat sessions 

10with MCP (Model Context Protocol) server integration. LLM providers are 

11configured via the Admin UI's LLM Settings and accessed through the gateway 

12provider. 

13 

14The module handles user session management, configuration, and real-time 

15streaming responses for conversational AI applications with unified chat 

16history management via ChatHistoryManager from mcp_client_chat_service. 

17 

18""" 

19 

20# Standard 

21import asyncio 

22import os 

23from typing import Any, Dict, Optional 

24 

25# Third-Party 

26from fastapi import APIRouter, Depends, HTTPException, Request 

27from fastapi.responses import StreamingResponse 

28import orjson 

29from pydantic import BaseModel, Field 

30 

31try: 

32 # Third-Party 

33 import redis.asyncio # noqa: F401 - availability check only 

34 

35 REDIS_AVAILABLE = True 

36except ImportError: 

37 REDIS_AVAILABLE = False 

38 

39# First-Party 

40from mcpgateway.config import settings 

41from mcpgateway.middleware.rbac import get_current_user_with_permissions 

42from mcpgateway.services.logging_service import LoggingService 

43from mcpgateway.services.mcp_client_chat_service import ( 

44 GatewayConfig, 

45 LLMConfig, 

46 MCPChatService, 

47 MCPClientConfig, 

48 MCPServerConfig, 

49) 

50from mcpgateway.utils.redis_client import get_redis_client 

51 

# Router for all /llmchat endpoints (tagged for OpenAPI grouping).
llmchat_router = APIRouter(prefix="/llmchat", tags=["llmchat"])

# Shared Redis client used by the session/config helpers below.
# Remains None until init_redis() is called from the app lifespan,
# or permanently when the gateway is not configured for Redis caching.
redis_client = None

57 

58 

async def init_redis() -> None:
    """Initialize the module-level Redis client via the shared factory.

    Called once during application startup (main.py lifespan). Leaves
    ``redis_client`` as ``None`` unless the gateway is configured with
    ``cache_type == "redis"`` and a ``redis_url``.
    """
    global redis_client
    cache_type = getattr(settings, "cache_type", None)
    redis_url = getattr(settings, "redis_url", None)
    if cache_type != "redis" or not redis_url:
        return
    redis_client = await get_redis_client()
    if redis_client:
        logger.info("LLMChat router connected to shared Redis client")

69 

70 

# Local per-worker stores. These always hold this worker's live sessions;
# when Redis is available it additionally tracks cross-worker ownership,
# otherwise these dicts are the only source of truth.

# Active chat sessions keyed by user id (one session per user).
active_sessions: Dict[str, MCPChatService] = {}

# Last-known client configuration keyed by user id (in-memory fallback
# when Redis is unavailable; see set_user_config/get_user_config).
user_configs: Dict[str, MCPClientConfig] = {}

# Module logger routed through the gateway's logging service.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

81 

82# ---------- MODELS ---------- 

83 

84 

class LLMInput(BaseModel):
    """Selects which gateway-configured model a chat session should use.

    Models are managed in the Admin UI under LLM Settings; ``model`` may
    be either the entry's UUID or its model_id.

    Attributes:
        model: Model ID from the gateway's LLM Settings (UUID or model_id).
        temperature: Optional sampling temperature, constrained to 0.0-2.0.
        max_tokens: Optional positive cap on generated tokens.

    Examples:
        >>> LLMInput(model='gpt-4o').model
        'gpt-4o'
        >>> LLMInput(model='abc123-uuid', temperature=0.5).temperature
        0.5
    """

    model: str = Field(..., description="Model ID from gateway LLM Settings (UUID or model_id)")
    temperature: Optional[float] = Field(None, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: Optional[int] = Field(None, gt=0, description="Maximum tokens to generate")

109 

110 

class ServerInput(BaseModel):
    """Connection parameters for an MCP (Model Context Protocol) server.

    All fields are optional; downstream config building substitutes
    defaults for missing url/transport values.

    Attributes:
        url: MCP server endpoint; None lets build_config() fall back to
            'http://localhost:8000/mcp'.
        transport: Transport protocol name, 'streamable_http' by default.
        auth_token: Optional bearer/JWT token; when absent, the connect
            endpoint fills it from the session cookie.

    Examples:
        >>> ServerInput(url='http://example.com/mcp').transport
        'streamable_http'
        >>> ServerInput().url is None
        True
    """

    url: Optional[str] = None
    transport: Optional[str] = "streamable_http"
    auth_token: Optional[str] = None

136 

137 

class ConnectInput(BaseModel):
    """Payload for establishing (or refreshing) a chat session.

    Bundles the session owner, optional MCP server details, the required
    gateway LLM selection, and the streaming preference.

    Attributes:
        user_id: Unique identifier for the user session; must match the
            authenticated caller.
        server: Optional MCP server settings; defaults apply when omitted.
        llm: Gateway model selection (required).
        streaming: Enable streaming responses; off by default.

    Examples:
        >>> ConnectInput(user_id='user123', llm=LLMInput(model='gpt-4o')).streaming
        False
        >>> ConnectInput(user_id='user456', llm=LLMInput(model='gpt-4o'), streaming=True).user_id
        'user456'
    """

    user_id: str
    server: Optional[ServerInput] = None
    llm: LLMInput = Field(..., description="LLM configuration with model from gateway LLM Settings")
    streaming: bool = False

164 

165 

class ChatInput(BaseModel):
    """Payload for sending one chat message to an active session.

    Attributes:
        user_id: Identifier of the active session's owner.
        message: Message text to process; must be non-empty.
        streaming: Request an SSE token stream instead of a single
            JSON response; off by default.

    Examples:
        >>> len(ChatInput(user_id='user123', message='Hello, AI!').message) > 0
        True
        >>> ChatInput(user_id='user456', message='Tell me a story', streaming=True).streaming
        True
    """

    user_id: str
    message: str
    streaming: bool = False

189 

190 

class DisconnectInput(BaseModel):
    """Payload for tearing down a user's chat session.

    Attributes:
        user_id: Identifier of the session to disconnect.

    Examples:
        >>> DisconnectInput(user_id='user123').user_id
        'user123'
    """

    user_id: str

206 

207 

208# ---------- HELPERS ---------- 

209 

210 

def build_llm_config(llm: LLMInput) -> LLMConfig:
    """Translate API input into a gateway-provider LLMConfig.

    Always produces a ``provider="gateway"`` configuration so requests
    route through the gateway's LLM Settings; the gateway looks up the
    model in its database and instantiates the proper backend.

    Args:
        llm: LLMInput carrying the model ID plus optional temperature
            and max_tokens overrides.

    Returns:
        LLMConfig: Gateway provider configuration.

    Examples:
        >>> build_llm_config(LLMInput(model='gpt-4o')).provider
        'gateway'

    Note:
        All LLM configuration is done via Admin UI -> Settings -> LLM Settings.
    """
    # Default the temperature only when the caller left it unset.
    temperature = 0.7 if llm.temperature is None else llm.temperature
    gateway_cfg = GatewayConfig(
        model=llm.model,
        temperature=temperature,
        max_tokens=llm.max_tokens,
    )
    return LLMConfig(provider="gateway", config=gateway_cfg)

242 

243 

def build_config(input_data: ConnectInput) -> MCPClientConfig:
    """Assemble the full MCP client configuration for a connection request.

    Combines the (optional) server settings with the gateway LLM config
    and the streaming flag into one MCPClientConfig.

    Args:
        input_data: ConnectInput with server, LLM, and streaming settings.

    Returns:
        MCPClientConfig: Complete client configuration ready for
        service initialization.

    Examples:
        >>> cfg = build_config(ConnectInput(user_id='user123', llm=LLMInput(model='gpt-4o')))
        >>> cfg.mcp_server.transport
        'streamable_http'

    Note:
        Missing server fields fall back to local defaults; the LLM side
        always routes through the gateway provider.
    """
    server = input_data.server

    # Start from defaults, then layer on whatever the caller supplied.
    url = "http://localhost:8000/mcp"
    transport = "streamable_http"
    auth_token = None
    if server:
        if server.url:
            url = server.url
        if server.transport:
            transport = server.transport
        auth_token = server.auth_token

    return MCPClientConfig(
        mcp_server=MCPServerConfig(url=url, transport=transport, auth_token=auth_token),
        llm=build_llm_config(input_data.llm),
        enable_streaming=input_data.streaming,
    )

278 

279 

280def _get_user_id_from_context(user: Dict[str, Any]) -> str: 

281 """Extract a stable user identifier from the authenticated user context. 

282 

283 Args: 

284 user: Authenticated user context from RBAC dependency. 

285 

286 Returns: 

287 User identifier string or "unknown" if missing. 

288 """ 

289 if isinstance(user, dict): 

290 return user.get("id") or user.get("user_id") or user.get("sub") or user.get("email") or "unknown" 

291 return "unknown" if user is None else str(getattr(user, "id", user)) 

292 

293 

def _resolve_user_id(input_user_id: Optional[str], user: Dict[str, Any]) -> str:
    """Resolve the authenticated user ID, rejecting spoofed requests.

    Args:
        input_user_id: User ID supplied in the request body, if any.
        user: Authenticated user context from the RBAC dependency.

    Returns:
        The authenticated user's identifier.

    Raises:
        HTTPException: 401 when no identity can be resolved; 403 when the
            client-supplied ID does not match the authenticated identity.
    """
    resolved = _get_user_id_from_context(user)
    if resolved == "unknown":
        raise HTTPException(status_code=401, detail="Authentication required.")
    # A client may omit user_id, but a mismatching one is rejected.
    if input_user_id and input_user_id != resolved:
        raise HTTPException(status_code=403, detail="User ID mismatch.")
    return resolved

313 

314 

315# ---------- SESSION STORAGE HELPERS ---------- 

316 

# Unique per-process identity; written into Redis ownership/lock keys so
# a session stays "sticky" to the worker that created it.
WORKER_ID = str(os.getpid())

# Tunables sourced from gateway settings (overridable via environment).
SESSION_TTL = settings.llmchat_session_ttl  # seconds before an active_session ownership key expires
LOCK_TTL = settings.llmchat_session_lock_ttl  # seconds before an initialization lock auto-expires
LOCK_RETRIES = settings.llmchat_session_lock_retries  # poll attempts while waiting on another holder's lock
LOCK_WAIT = settings.llmchat_session_lock_wait  # seconds to sleep between lock polls

325 

326 

327# Redis key helpers 

328def _cfg_key(user_id: str) -> str: 

329 """Generate Redis key for user configuration storage. 

330 

331 Args: 

332 user_id: User identifier. 

333 

334 Returns: 

335 str: Redis key for storing user configuration. 

336 """ 

337 return f"user_config:{user_id}" 

338 

339 

340def _active_key(user_id: str) -> str: 

341 """Generate Redis key for active session tracking. 

342 

343 Args: 

344 user_id: User identifier. 

345 

346 Returns: 

347 str: Redis key for tracking active sessions. 

348 """ 

349 return f"active_session:{user_id}" 

350 

351 

352def _lock_key(user_id: str) -> str: 

353 """Generate Redis key for session initialization lock. 

354 

355 Args: 

356 user_id: User identifier. 

357 

358 Returns: 

359 str: Redis key for session locks. 

360 """ 

361 return f"session_lock:{user_id}" 

362 

363 

364# ---------- CONFIG HELPERS ---------- 

365 

366 

async def set_user_config(user_id: str, config: MCPClientConfig):
    """Persist a user's configuration (Redis when available, else in-process).

    Args:
        user_id: User identifier.
        config: Complete MCP client configuration to store.
    """
    if not redis_client:
        user_configs[user_id] = config
        return
    # Serialize the pydantic model so any worker can rebuild the config.
    payload = orjson.dumps(config.model_dump())
    await redis_client.set(_cfg_key(user_id), payload)

378 

379 

async def get_user_config(user_id: str) -> Optional[MCPClientConfig]:
    """Load a user's configuration from Redis or the in-memory fallback.

    Args:
        user_id: User identifier.

    Returns:
        Optional[MCPClientConfig]: Stored configuration, or None when absent.
    """
    if not redis_client:
        return user_configs.get(user_id)
    raw = await redis_client.get(_cfg_key(user_id))
    if not raw:
        return None
    # Rehydrate the pydantic model from its serialized dict form.
    return MCPClientConfig(**orjson.loads(raw))

395 

396 

async def delete_user_config(user_id: str):
    """Remove a user's stored configuration from Redis or memory.

    Args:
        user_id: User identifier.
    """
    if not redis_client:
        user_configs.pop(user_id, None)
        return
    await redis_client.delete(_cfg_key(user_id))

407 

408 

409# ---------- SESSION (active) HELPERS with locking & recreate ---------- 

410 

411 

async def set_active_session(user_id: str, session: MCPChatService):
    """Register a session locally and claim Redis ownership with a TTL.

    Args:
        user_id: User identifier.
        session: Initialized MCPChatService instance.
    """
    active_sessions[user_id] = session
    if not redis_client:
        return
    # TTL-bound ownership marker: a dead worker's claim expires on its own.
    await redis_client.set(_active_key(user_id), WORKER_ID, ex=SESSION_TTL)

423 

424 

async def delete_active_session(user_id: str):
    """Drop the local session and release Redis ownership if we hold it.

    The Redis marker is removed via an atomic Lua compare-and-delete so
    that we never clobber another worker's ownership — e.g. when our key
    expired and the session was recreated elsewhere.

    Args:
        user_id: User identifier.
    """
    active_sessions.pop(user_id, None)
    if not redis_client:
        return
    # Only delete the key when its value is still this worker's ID.
    script = """
    if redis.call("get", KEYS[1]) == ARGV[1] then
        return redis.call("del", KEYS[1])
    else
        return 0
    end
    """
    try:
        await redis_client.eval(script, 1, _active_key(user_id), WORKER_ID)
    except Exception as e:
        # Best-effort cleanup: Redis hiccups must not break disconnect flows.
        logger.warning(f"Failed to delete active session for user {user_id}: {e}")

449 

450 

async def _try_acquire_lock(user_id: str) -> bool:
    """Try to take the per-user session-initialization lock.

    Args:
        user_id: User identifier.

    Returns:
        bool: True when the lock was obtained. Without Redis the lock is
        meaningless (single local process), so this always succeeds.
    """
    if not redis_client:
        return True
    # SET NX + EX: one atomic attempt that self-expires after LOCK_TTL.
    return await redis_client.set(_lock_key(user_id), WORKER_ID, nx=True, ex=LOCK_TTL)

463 

464 

async def _release_lock_safe(user_id: str):
    """Release the initialization lock only if this worker still owns it.

    Uses an atomic Lua check-and-delete to avoid the TOCTOU race where a
    lock that expired and was re-acquired by another worker would be
    deleted out from under that worker by a naive get()+delete().

    Args:
        user_id: User identifier.
    """
    if not redis_client:
        return
    # Compare-and-delete: only remove the key if its value is our WORKER_ID.
    script = """
    if redis.call("get", KEYS[1]) == ARGV[1] then
        return redis.call("del", KEYS[1])
    else
        return 0
    end
    """
    try:
        await redis_client.eval(script, 1, _lock_key(user_id), WORKER_ID)
    except Exception as e:
        # Lock release is best-effort; TTL expiry is the backstop.
        logger.warning(f"Failed to release lock for user {user_id}: {e}")

489 

490 

async def _create_local_session_from_config(user_id: str) -> Optional[MCPChatService]:
    """Recreate a chat session on this worker from the persisted config.

    Args:
        user_id: User identifier.

    Returns:
        Optional[MCPChatService]: Ready-to-use service, or None when no
        configuration exists or initialization fails.
    """
    config = await get_user_config(user_id)
    if config is None:
        return None

    # Build and initialize the service (shares the unified history manager
    # through the same redis_client).
    try:
        service = MCPChatService(config, user_id=user_id, redis_client=redis_client)
        await service.initialize()
        await set_active_session(user_id, service)
        return service
    except Exception as e:
        logger.error(f"Failed to initialize MCPChatService for {user_id}: {e}", exc_info=True)
        # Drop any partially-registered state (local map + Redis ownership).
        await delete_active_session(user_id)
        return None

516 

517 

async def get_active_session(user_id: str) -> Optional[MCPChatService]:
    """Retrieve or (if possible) recreate the active session for user_id.

    Behavior:
        - Redis disabled: return the local session or None.
        - Redis enabled:
            * owner == WORKER_ID and local session exists -> return it (refresh TTL)
            * owner == WORKER_ID but local missing -> lock and recreate here
            * no owner -> lock and create here
            * owner is another worker -> return None (its TTL expiry is the
              recovery path; we never steal a live session)

    Args:
        user_id: User identifier.

    Returns:
        Optional[MCPChatService]: Active session if available on this
        worker, None otherwise.
    """
    # Fast path: no Redis => purely local bookkeeping.
    if not redis_client:
        return active_sessions.get(user_id)

    active_key = _active_key(user_id)
    # NOTE(review): the equality checks below assume the client returns str
    # (decode_responses=True); with a bytes client owner would never equal
    # WORKER_ID — confirm against get_redis_client().
    owner = await redis_client.get(active_key)

    # 1) Owned by this worker.
    if owner == WORKER_ID:
        local = active_sessions.get(user_id)
        if local:
            # Refresh the TTL so ownership persists while the session is used.
            try:
                await redis_client.expire(active_key, SESSION_TTL)
            except Exception as e:  # nosec B110
                # Non-fatal: a failed refresh only risks earlier expiry.
                logger.debug(f"Failed to refresh session TTL for {user_id}: {e}")
            return local

        # Redis says we own it but the local object is gone (process restart
        # or eviction): recreate under the init lock.
        acquired = await _try_acquire_lock(user_id)
        if acquired:
            try:
                session = await _create_local_session_from_config(user_id)
                return session
            finally:
                await _release_lock_safe(user_id)
        else:
            # Another coroutine in this process holds the lock; poll for its
            # result for at most LOCK_RETRIES * LOCK_WAIT seconds.
            for _ in range(LOCK_RETRIES):
                await asyncio.sleep(LOCK_WAIT)
                if active_sessions.get(user_id):
                    return active_sessions.get(user_id)
            return None

    # 2) No owner anywhere -> try to claim and create the session locally.
    if owner is None:
        acquired = await _try_acquire_lock(user_id)
        if acquired:
            try:
                session = await _create_local_session_from_config(user_id)
                return session
            finally:
                await _release_lock_safe(user_id)

        # Lock held elsewhere: wait briefly to see who wins ownership.
        for _ in range(LOCK_RETRIES):
            await asyncio.sleep(LOCK_WAIT)
            owner2 = await redis_client.get(active_key)
            if owner2 == WORKER_ID and active_sessions.get(user_id):
                return active_sessions.get(user_id)
            if owner2 is not None and owner2 != WORKER_ID:
                # Another worker now owns it; we cannot serve this session.
                return None

        # Last resort: the earlier holder may have failed without claiming
        # ownership — try the lock once more.
        acquired = await _try_acquire_lock(user_id)
        if acquired:
            try:
                session = await _create_local_session_from_config(user_id)
                return session
            finally:
                await _release_lock_safe(user_id)
        return None

    # 3) Owned by another worker -> not available here. We deliberately do
    # not steal ownership; a stale owner is handled by TTL expiry.
    return None

607 

608 

609# ---------- ROUTES ---------- 

610 

611 

@llmchat_router.post("/connect")
async def connect(input_data: ConnectInput, request: Request, user=Depends(get_current_user_with_permissions)):
    """Create or refresh a chat session for a user.

    Initializes a new MCPChatService for the authenticated user, connecting
    to both the MCP server and the gateway-configured LLM. Any existing
    session for the user is shut down first. When the body omits or blanks
    the server auth token, the JWT is taken from the ``jwt_token`` cookie.

    Args:
        input_data: ConnectInput with user_id, optional server/LLM config,
            and streaming preference.
        request: FastAPI Request object for accessing cookies and headers.
        user: Authenticated user context.

    Returns:
        dict: Connection status response containing:
            - status: 'connected'
            - user_id: The connected user's identifier
            - provider: The LLM provider being used
            - tool_count: Number of available MCP tools
            - tools: List of tool names

    Raises:
        HTTPException:
            400: Invalid user_id, invalid configuration, or LLM config error.
            401: Missing authentication token.
            503: Failed to connect to MCP server.
            500: Service initialization failure or unexpected error.

    Note:
        Existing sessions are automatically terminated before establishing
        new ones. Chat history is cleared on every (re)connect.
    """
    # Rejects spoofed user_ids (403) and missing identities (401).
    user_id = _resolve_user_id(input_data.user_id, user)

    try:
        # Defensive re-check; _resolve_user_id already guarantees a non-empty str.
        if not user_id or not isinstance(user_id, str):
            raise HTTPException(status_code=400, detail="Invalid user ID provided")

        # Fill a missing/blank server auth token from the session cookie.
        empty_token = ""  # nosec B105
        if input_data.server and (input_data.server.auth_token is None or input_data.server.auth_token == empty_token):
            jwt_token = request.cookies.get("jwt_token")
            if not jwt_token:
                raise HTTPException(status_code=401, detail="Authentication required. Please ensure you are logged in.")
            input_data.server.auth_token = jwt_token

        # Gracefully close any previous session before creating a new one.
        existing = await get_active_session(user_id)
        if existing:
            try:
                logger.debug(f"Disconnecting existing session for {user_id} before reconnecting")
                await existing.shutdown()
            except Exception as shutdown_error:
                logger.warning(f"Failed to cleanly shutdown existing session for {user_id}: {shutdown_error}")
            finally:
                # Always drop the registration, even if shutdown failed.
                await delete_active_session(user_id)

        # Build and validate configuration.
        try:
            config = build_config(input_data)
        except ValueError as ve:
            raise HTTPException(status_code=400, detail=f"Invalid configuration: {str(ve)}")
        except Exception as config_error:
            raise HTTPException(status_code=400, detail=f"Configuration error: {str(config_error)}")

        # Persist the config so other workers can recreate the session.
        await set_user_config(user_id, config)

        # Initialize the chat service; on any failure roll back the stored
        # config so a later connect starts from a clean slate.
        try:
            chat_service = MCPChatService(config, user_id=user_id, redis_client=redis_client)
            await chat_service.initialize()

            # Clear chat history on new connection.
            await chat_service.clear_history()
        except ConnectionError as ce:
            await delete_user_config(user_id)
            raise HTTPException(status_code=503, detail=f"Failed to connect to MCP server: {str(ce)}. Please verify the server URL and authentication.")
        except ValueError as ve:
            await delete_user_config(user_id)
            raise HTTPException(status_code=400, detail=f"Invalid LLM configuration: {str(ve)}")
        except Exception as init_error:
            await delete_user_config(user_id)
            raise HTTPException(status_code=500, detail=f"Service initialization failed: {str(init_error)}")

        await set_active_session(user_id, chat_service)

        # Extract tool names for the response (best-effort; reads the
        # service's private _tools attribute).
        tool_names = []
        try:
            if hasattr(chat_service, "_tools") and chat_service._tools:
                for tool in chat_service._tools:
                    tool_name = getattr(tool, "name", None)
                    if tool_name:
                        tool_names.append(tool_name)
        except Exception as tool_error:
            logger.warning(f"Failed to extract tool names: {tool_error}")
            # Continue without tools list.

        return {"status": "connected", "user_id": user_id, "provider": config.llm.provider, "tool_count": len(tool_names), "tools": tool_names}

    except HTTPException:
        # Re-raise HTTP exceptions as-is.
        raise
    except Exception as e:
        logger.error(f"Unexpected error in connect endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Unexpected connection error: {str(e)}")

754 

755 

async def token_streamer(chat_service: MCPChatService, message: str, user_id: str):
    """Stream chat response tokens as Server-Sent Events (SSE).

    Async generator that relays events from ``chat_service.chat_events`` as
    SSE-formatted byte chunks. History persistence is handled inside the
    chat service itself.

    Args:
        chat_service: MCPChatService instance for the user session.
        message: User's chat message to process.
        user_id: User identifier (kept in the signature for logging/tracing
            by callers; not read directly here).

    Yields:
        bytes: SSE frames (``event: <type>\\ndata: <json>\\n\\n``) of:
            - token: {"content": "text chunk"}
            - tool_start / tool_end / tool_error: tool lifecycle events
            - final: {"type": "final", "text": ..., "metadata": {...}}
            - error: {"type": "error", "error": "...", "recoverable": bool}

    Note:
        Every exception is converted into an ``error`` event rather than
        propagating, so the client stream always terminates cleanly.
    """

    async def sse(event_type: str, data: Dict[str, Any]):
        """Yield one SSE event as UTF-8 encoded 'event:'/'data:' lines.

        Args:
            event_type: SSE event type identifier.
            data: Payload dictionary to serialize as JSON.

        Yields:
            bytes: UTF-8 encoded SSE formatted lines.
        """
        yield f"event: {event_type}\n".encode("utf-8")
        yield f"data: {orjson.dumps(data).decode()}\n\n".encode("utf-8")

    try:
        async for ev in chat_service.chat_events(message):
            et = ev.get("type")
            if et == "token":
                content = ev.get("content", "")
                async for part in sse("token", {"content": content}):
                    yield part
            elif et in ("tool_start", "tool_end", "tool_error"):
                async for part in sse(et, ev):
                    yield part
            elif et == "final":
                async for part in sse("final", ev):
                    yield part

    except ConnectionError as ce:
        error_event = {"type": "error", "error": f"Connection lost: {str(ce)}", "recoverable": False}
        async for part in sse("error", error_event):
            yield part
    except TimeoutError:
        error_event = {"type": "error", "error": "Request timed out waiting for LLM response", "recoverable": True}
        async for part in sse("error", error_event):
            yield part
    # Renamed from 're': binding an exception to 're' shadowed the stdlib
    # regex module's name inside the handler.
    except RuntimeError as runtime_err:
        error_event = {"type": "error", "error": f"Service error: {str(runtime_err)}", "recoverable": False}
        async for part in sse("error", error_event):
            yield part
    except Exception as e:
        logger.error(f"Unexpected streaming error: {e}", exc_info=True)
        error_event = {"type": "error", "error": f"Unexpected error: {str(e)}", "recoverable": False}
        async for part in sse("error", error_event):
            yield part

849 

850 

@llmchat_router.post("/chat")
async def chat(input_data: ChatInput, user=Depends(get_current_user_with_permissions)):
    """Send a message to the user's active chat session and receive a response.

    Processes user messages through the configured LLM with MCP tool integration.
    Supports both streaming (SSE) and non-streaming response modes. Chat history
    is managed automatically via the unified ChatHistoryManager.

    Args:
        input_data: ChatInput containing user_id, message, and streaming preference.
        user: Authenticated user context.

    Returns:
        For streaming=False:
            dict: Response containing:
                - user_id: Session identifier
                - response: Complete LLM response text
                - tool_used: Boolean indicating if tools were invoked
                - tools: List of tool names used
                - tool_invocations: Detailed tool call information
                - elapsed_ms: Processing time in milliseconds
        For streaming=True:
            StreamingResponse: SSE stream of token and event data.

    Raises:
        HTTPException: Raised when an HTTP-related error occurs.
            400: Missing user_id, empty message, or no active session.
            503: Session not initialized, chat service error, or connection lost.
            504: Request timeout.
            500: Unexpected error.

    Note:
        Streaming responses use Server-Sent Events (SSE) with 'text/event-stream'
        MIME type. Client must maintain a persistent connection for streaming.
    """
    user_id = _resolve_user_id(input_data.user_id, user)

    # Validate input
    if not user_id:
        raise HTTPException(status_code=400, detail="User ID is required")

    if not input_data.message or not input_data.message.strip():
        raise HTTPException(status_code=400, detail="Message cannot be empty")

    # Check for active session
    chat_service = await get_active_session(user_id)
    if not chat_service:
        raise HTTPException(status_code=400, detail="No active session found. Please connect to a server first.")

    # Verify session is initialized
    if not chat_service.is_initialized:
        raise HTTPException(status_code=503, detail="Session is not properly initialized. Please reconnect.")

    try:
        if input_data.streaming:
            return StreamingResponse(
                token_streamer(chat_service, input_data.message, user_id),
                media_type="text/event-stream",
                headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},  # Disable proxy buffering
            )
        try:
            result = await chat_service.chat_with_metadata(input_data.message)

            return {
                "user_id": user_id,
                "response": result["text"],
                "tool_used": result["tool_used"],
                "tools": result["tools"],
                "tool_invocations": result["tool_invocations"],
                "elapsed_ms": result["elapsed_ms"],
            }
        # NOTE: was `except RuntimeError as re` — renamed to avoid shadowing the
        # stdlib `re` module, and chained with `from` to preserve the traceback.
        except RuntimeError as runtime_err:
            raise HTTPException(status_code=503, detail=f"Chat service error: {str(runtime_err)}") from runtime_err

    except ConnectionError as ce:
        raise HTTPException(status_code=503, detail=f"Lost connection to MCP server: {str(ce)}. Please reconnect.") from ce
    except TimeoutError as te:
        raise HTTPException(status_code=504, detail="Request timed out. The LLM took too long to respond.") from te
    except HTTPException:
        # Re-raise HTTP errors (including the 503 above) untouched.
        raise
    except Exception as e:
        logger.error(f"Unexpected error in chat endpoint for user {user_id}: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}") from e

965 

966 

@llmchat_router.post("/disconnect")
async def disconnect(input_data: DisconnectInput, user=Depends(get_current_user_with_permissions)):
    """End the chat session for a user and clean up resources.

    Gracefully shuts down the MCPChatService instance, closes connections,
    and removes session data from active storage. Safe to call even when
    no active session exists.

    Args:
        input_data: DisconnectInput containing the user_id to disconnect.
        user: Authenticated user context.

    Returns:
        dict: Disconnection status containing:
            - status: One of 'disconnected', 'no_active_session', or 'disconnected_with_errors'
            - user_id: The user identifier
            - message: Human-readable status description
            - warning: (Optional) Error details if cleanup encountered issues

    Raises:
        HTTPException: 400 when the user_id cannot be resolved.

    Note:
        This operation is idempotent - repeated calls for the same user_id
        are safe and never raise.
    """
    user_id = _resolve_user_id(input_data.user_id, user)

    if not user_id:
        raise HTTPException(status_code=400, detail="User ID is required")

    # Grab the service (if any) before tearing down session + config records.
    service = await get_active_session(user_id)
    await delete_active_session(user_id)
    await delete_user_config(user_id)

    if not service:
        return {"status": "no_active_session", "user_id": user_id, "message": "No active session to disconnect"}

    try:
        # Wipe the conversation history, then shut the service down.
        await service.clear_history()
        logger.info(f"Chat session disconnected for {user_id}")
        await service.shutdown()
    except Exception as exc:
        logger.error(f"Error during disconnect for user {user_id}: {exc}", exc_info=True)
        # Session records are already gone, so report success with a warning.
        return {"status": "disconnected_with_errors", "user_id": user_id, "message": "Disconnected but cleanup encountered errors", "warning": str(exc)}

    return {"status": "disconnected", "user_id": user_id, "message": "Successfully disconnected"}

1045 

1046 

@llmchat_router.get("/status/{user_id}")
async def status(user_id: str, user=Depends(get_current_user_with_permissions)):
    """Check whether an active chat session exists for the specified user.

    Lightweight, read-only endpoint for verifying session state. Useful for
    health checks and UI state management.

    Args:
        user_id: User identifier to check session status for.
        user: Authenticated user context.

    Returns:
        dict: Status information containing:
            - user_id: The resolved user identifier
            - connected: Boolean indicating if an active session exists

    Note:
        This endpoint only reports whether a session record exists; it does
        not validate that the session is properly initialized.
    """
    resolved = _resolve_user_id(user_id, user)
    session = await get_active_session(resolved)
    return {"user_id": resolved, "connected": bool(session)}

1090 

1091 

@llmchat_router.get("/config/{user_id}")
async def get_config(user_id: str, user=Depends(get_current_user_with_permissions)):
    """Retrieve the stored configuration for a user's session.

    Returns sanitized configuration data with sensitive fields (API keys,
    auth tokens) stripped for security. Useful for debugging and
    configuration verification.

    Args:
        user_id: User identifier whose configuration to retrieve.
        user: Authenticated user context.

    Returns:
        dict: Sanitized configuration dictionary containing:
            - mcp_server: Server connection settings (without auth_token)
            - llm: LLM provider configuration (without api_key)
            - enable_streaming: Boolean streaming preference

    Raises:
        HTTPException: 404 when no configuration exists for the user.

    Security:
        API keys and authentication tokens are removed before returning.
        Never log or expose these values in responses.
    """
    resolved = _resolve_user_id(user_id, user)
    stored = await get_user_config(resolved)

    if not stored:
        raise HTTPException(status_code=404, detail="No config found for this user.")

    # Serialize, then strip secrets from the LLM provider section in place.
    sanitized = stored.model_dump()
    llm_section = sanitized.get("llm", {})
    if "config" in llm_section:
        for secret_key in ("api_key", "auth_token"):
            llm_section["config"].pop(secret_key, None)

    return sanitized

1156 

1157 

@llmchat_router.get("/gateway/models")
async def get_gateway_models(_user=Depends(get_current_user_with_permissions)):
    """Get available models from configured LLM providers.

    Returns a list of enabled models from enabled providers configured
    in the gateway's LLM Settings. These models can be used with the
    "gateway" provider type in /connect requests.

    Args:
        _user: Authenticated user context.

    Returns:
        dict: Response containing:
            - models: List of available models with provider info
              (id, model_id, model_name, provider_id, provider_name,
              provider_type, capability flags)
            - count: Total number of available models

    Raises:
        HTTPException: 500 when gateway models cannot be retrieved.
    """
    # Import here to avoid circular dependency
    # First-Party
    from mcpgateway.db import SessionLocal
    from mcpgateway.services.llm_provider_service import LLMProviderService

    llm_service = LLMProviderService()

    try:
        with SessionLocal() as db:
            models = llm_service.get_gateway_models(db)
            return {
                "models": [m.model_dump() for m in models],
                "count": len(models),
            }
    except HTTPException:
        # Don't double-wrap HTTP errors raised by the service layer.
        raise
    except Exception as e:
        logger.error(f"Failed to get gateway models: {e}")
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(status_code=500, detail=f"Failed to retrieve gateway models: {str(e)}") from e