Coverage for mcpgateway / routers / llmchat_router.py: 99%
397 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/routers/llmchat_router.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Keval Mahajan
7LLM Chat Router Module
9This module provides FastAPI endpoints for managing LLM-based chat sessions
10with MCP (Model Context Protocol) server integration. LLM providers are
11configured via the Admin UI's LLM Settings and accessed through the gateway
12provider.
14The module handles user session management, configuration, and real-time
15streaming responses for conversational AI applications with unified chat
16history management via ChatHistoryManager from mcp_client_chat_service.
18"""
20# Standard
21import asyncio
22import os
23import time
24from typing import Any, Dict, Optional
26# Third-Party
27from fastapi import APIRouter, Depends, HTTPException, Request
28from fastapi.responses import StreamingResponse
29import orjson
30from pydantic import BaseModel, Field
32try:
33 # Third-Party
34 import redis.asyncio # noqa: F401 - availability check only
36 REDIS_AVAILABLE = True
37except ImportError:
38 REDIS_AVAILABLE = False
40# First-Party
41from mcpgateway.common.validators import SecurityValidator
42from mcpgateway.config import settings
43from mcpgateway.middleware.rbac import get_current_user_with_permissions, require_permission
44from mcpgateway.services.logging_service import LoggingService
45from mcpgateway.services.mcp_client_chat_service import (
46 GatewayConfig,
47 LLMConfig,
48 MCPChatService,
49 MCPClientConfig,
50 MCPServerConfig,
51)
52from mcpgateway.utils.redis_client import get_redis_client
53from mcpgateway.utils.services_auth import decode_auth, encode_auth
# Initialize router for all /llmchat endpoints
llmchat_router = APIRouter(prefix="/llmchat", tags=["llmchat"])

# Module-level Redis client shared by every storage helper below.
# Stays None until init_redis() runs during app startup; while None, all
# helpers fall back to the worker-local in-memory dicts defined below.
redis_client = None
async def init_redis() -> None:
    """Initialize the module-level Redis client from the shared factory.

    Intended to be called exactly once during application startup
    (from the main.py lifespan). Only connects when the gateway is
    configured with a Redis cache backend and a Redis URL.
    """
    global redis_client
    cache_backend = getattr(settings, "cache_type", None)
    redis_url = getattr(settings, "redis_url", None)
    if cache_backend == "redis" and redis_url:
        redis_client = await get_redis_client()
        if redis_client:
            logger.info("LLMChat router connected to shared Redis client")
# Fallback in-memory stores (used when Redis unavailable)
# Active chat sessions per user — worker-local, lost on process restart
active_sessions: Dict[str, MCPChatService] = {}

# Per-user configuration as (serialized_encrypted_config, stored_at_monotonic);
# the timestamp lets get_user_config() enforce USER_CONFIG_TTL without Redis
user_configs: Dict[str, tuple[bytes, float]] = {}

# Logging
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)
85# ---------- MODELS ----------
class LLMInput(BaseModel):
    """Input configuration for LLM provider selection.

    Specifies which gateway-configured model to use. Models must be
    configured beforehand via Admin UI -> LLM Settings; this model only
    references them by ID.

    Attributes:
        model: Model ID from the gateway's LLM Settings (UUID or model_id).
        temperature: Optional sampling temperature, validated to 0.0-2.0.
        max_tokens: Optional maximum tokens to generate, validated to be > 0.

    Examples:
        >>> llm_input = LLMInput(model='gpt-4o')
        >>> llm_input.model
        'gpt-4o'

        >>> llm_input = LLMInput(model='abc123-uuid', temperature=0.5)
        >>> llm_input.temperature
        0.5
    """

    model: str = Field(..., description="Model ID from gateway LLM Settings (UUID or model_id)")
    temperature: Optional[float] = Field(None, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: Optional[int] = Field(None, gt=0, description="Maximum tokens to generate")
class ServerInput(BaseModel):
    """Input configuration for MCP server connection.

    Defines the connection parameters required to establish communication
    with an MCP (Model Context Protocol) server. All fields are optional;
    build_config() supplies defaults for url/transport when omitted.

    Attributes:
        url: Optional MCP server URL endpoint. When None, build_config()
            falls back to 'http://localhost:8000/mcp'.
        transport: Communication transport protocol. Defaults to 'streamable_http'.
        auth_token: Optional authentication token; when empty, the /connect
            endpoint fills it from the caller's jwt_token cookie.

    Examples:
        >>> server = ServerInput(url='http://example.com/mcp')
        >>> server.transport
        'streamable_http'

        >>> server = ServerInput()
        >>> server.url is None
        True
    """

    url: Optional[str] = None
    transport: Optional[str] = "streamable_http"
    auth_token: Optional[str] = None
class ConnectInput(BaseModel):
    """Request model for establishing a new chat session.

    Contains all necessary parameters to initialize a user's chat session
    including server connection details, LLM configuration, and streaming
    preferences.

    Attributes:
        user_id: Unique identifier for the user session. Must match the
            authenticated identity (enforced by _resolve_user_id).
        server: Optional MCP server configuration. Uses defaults if not provided.
        llm: LLM configuration specifying which gateway model to use. Required.
        streaming: Whether to enable streaming responses. Defaults to False.

    Examples:
        >>> connect = ConnectInput(user_id='user123', llm=LLMInput(model='gpt-4o'))
        >>> connect.streaming
        False

        >>> connect = ConnectInput(user_id='user456', llm=LLMInput(model='gpt-4o'), streaming=True)
        >>> connect.user_id
        'user456'
    """

    user_id: str
    server: Optional[ServerInput] = None
    llm: LLMInput = Field(..., description="LLM configuration with model from gateway LLM Settings")
    streaming: bool = False
class ChatInput(BaseModel):
    """Request model for sending chat messages.

    Encapsulates user message data for processing by the chat service.

    Attributes:
        user_id: Unique identifier for the active user session. Must match
            the authenticated identity (enforced by _resolve_user_id).
        message: The chat message content to be processed.
        streaming: Whether to stream the response via SSE. Defaults to False.

    Examples:
        >>> chat = ChatInput(user_id='user123', message='Hello, AI!')
        >>> len(chat.message) > 0
        True

        >>> chat = ChatInput(user_id='user456', message='Tell me a story', streaming=True)
        >>> chat.streaming
        True
    """

    user_id: str
    message: str
    streaming: bool = False
class DisconnectInput(BaseModel):
    """Request model for terminating a chat session.

    Simple model containing only the user identifier for session cleanup.

    Attributes:
        user_id: Unique identifier of the session to disconnect. Must match
            the authenticated identity (enforced by _resolve_user_id).

    Examples:
        >>> disconnect = DisconnectInput(user_id='user123')
        >>> disconnect.user_id
        'user123'
    """

    user_id: str
211# ---------- HELPERS ----------
def build_llm_config(llm: LLMInput) -> LLMConfig:
    """Create the gateway-provider LLMConfig for the given input.

    Every chat session routes through the gateway provider, which looks
    up the requested model in the gateway's LLM Settings (Admin UI) and
    instantiates the matching backend.

    Args:
        llm: LLMInput carrying the model ID plus optional temperature
            and max_tokens overrides.

    Returns:
        LLMConfig: Configuration with provider fixed to "gateway".

    Examples:
        >>> config = build_llm_config(LLMInput(model='gpt-4o'))
        >>> config.provider
        'gateway'
    """
    # Default temperature of 0.7 applies only when the caller left it unset.
    gateway = GatewayConfig(
        model=llm.model,
        temperature=0.7 if llm.temperature is None else llm.temperature,
        max_tokens=llm.max_tokens,
    )
    return LLMConfig(provider="gateway", config=gateway)
def build_config(input_data: ConnectInput) -> MCPClientConfig:
    """Assemble the full MCP client configuration for a connection request.

    Combines the (possibly partial) server settings with the gateway LLM
    configuration and the streaming preference into one config object.

    Args:
        input_data: ConnectInput with server, LLM, and streaming settings.

    Returns:
        MCPClientConfig: Complete configuration ready for MCPChatService.

    Examples:
        >>> from mcpgateway.routers.llmchat_router import ConnectInput, LLMInput, build_config
        >>> config = build_config(ConnectInput(user_id='user123', llm=LLMInput(model='gpt-4o')))
        >>> config.mcp_server.transport
        'streamable_http'

    Note:
        Missing server fields fall back to local defaults; the LLM always
        routes through the gateway provider.
    """
    server = input_data.server

    # Resolve server connection parameters, applying defaults field by field.
    url = "http://localhost:8000/mcp"
    transport = "streamable_http"
    auth_token = None
    if server:
        if server.url:
            url = server.url
        if server.transport:
            transport = server.transport
        auth_token = server.auth_token

    return MCPClientConfig(
        mcp_server=MCPServerConfig(url=url, transport=transport, auth_token=auth_token),
        llm=build_llm_config(input_data.llm),
        enable_streaming=input_data.streaming,
    )
283def _get_user_id_from_context(user: Dict[str, Any]) -> str:
284 """Extract a stable user identifier from the authenticated user context.
286 Args:
287 user: Authenticated user context from RBAC dependency.
289 Returns:
290 User identifier string or "unknown" if missing.
291 """
292 if isinstance(user, dict):
293 return user.get("id") or user.get("user_id") or user.get("sub") or user.get("email") or "unknown"
294 return "unknown" if user is None else str(getattr(user, "id", user))
def _resolve_user_id(input_user_id: Optional[str], user: Dict[str, Any]) -> str:
    """Return the authenticated user's ID, rejecting spoofed requests.

    Args:
        input_user_id: User ID supplied in the request body (optional).
        user: Authenticated user context from the RBAC dependency.

    Returns:
        The resolved authenticated user identifier.

    Raises:
        HTTPException: 401 when authentication is missing; 403 when the
            client-supplied ID does not match the authenticated identity.
    """
    resolved = _get_user_id_from_context(user)
    if resolved == "unknown":
        raise HTTPException(status_code=401, detail="Authentication required.")
    # A body-supplied ID must match the token identity, or the request is rejected.
    if input_user_id and input_user_id != resolved:
        raise HTTPException(status_code=403, detail="User ID mismatch.")
    return resolved
# ---------- SESSION STORAGE HELPERS ----------

# Identify this worker uniquely (used for sticky session ownership).
# The PID is stable for the life of the process, so Redis ownership markers
# written with this value tell other workers who holds a given session.
WORKER_ID = str(os.getpid())

# Tunables (sourced from gateway settings; overridable via environment)
SESSION_TTL = settings.llmchat_session_ttl  # seconds for active_session key TTL
LOCK_TTL = settings.llmchat_session_lock_ttl  # seconds for lock expiry
LOCK_RETRIES = settings.llmchat_session_lock_retries  # how many times to poll while waiting
LOCK_WAIT = settings.llmchat_session_lock_wait  # seconds between polls
USER_CONFIG_TTL = settings.llmchat_session_ttl  # stored configs share the session TTL

# Keys/version for the encrypted envelope written by
# _serialize_user_config_for_storage() and read back by its counterpart.
_ENCRYPTED_CONFIG_PAYLOAD_KEY = "_encrypted_payload"
_ENCRYPTED_CONFIG_VERSION_KEY = "_version"
_ENCRYPTED_CONFIG_VERSION = "1"
# Config field names whose values are masked in API responses
# (compared case-insensitively after strip; see _is_sensitive_config_key).
_CONFIG_SENSITIVE_KEYS = frozenset(
    {
        "api_key",
        "auth_token",
        "authorization",
        "access_token",
        "refresh_token",
        "client_secret",
        "secret_access_key",
        "session_token",
        "credentials_json",
        "password",
        "private_key",
    }
)
350# Redis key helpers
351def _cfg_key(user_id: str) -> str:
352 """Generate Redis key for user configuration storage.
354 Args:
355 user_id: User identifier.
357 Returns:
358 str: Redis key for storing user configuration.
359 """
360 return f"user_config:{user_id}"
363def _active_key(user_id: str) -> str:
364 """Generate Redis key for active session tracking.
366 Args:
367 user_id: User identifier.
369 Returns:
370 str: Redis key for tracking active sessions.
371 """
372 return f"active_session:{user_id}"
375def _lock_key(user_id: str) -> str:
376 """Generate Redis key for session initialization lock.
378 Args:
379 user_id: User identifier.
381 Returns:
382 str: Redis key for session locks.
383 """
384 return f"session_lock:{user_id}"
def _serialize_user_config_for_storage(config: MCPClientConfig) -> bytes:
    """Encrypt and serialize a user config into a versioned envelope.

    The config (which can contain auth tokens) is encrypted via encode_auth
    and wrapped in a small JSON envelope carrying a format version, so the
    on-disk/Redis representation never holds plaintext credentials.

    Args:
        config: User MCP client configuration.

    Returns:
        Serialized bytes containing the encrypted config envelope.
    """
    encrypted = encode_auth({"config": config.model_dump()})
    envelope = {
        _ENCRYPTED_CONFIG_VERSION_KEY: _ENCRYPTED_CONFIG_VERSION,
        _ENCRYPTED_CONFIG_PAYLOAD_KEY: encrypted,
    }
    return orjson.dumps(envelope)
def _deserialize_user_config_from_storage(data: bytes | str) -> Optional[MCPClientConfig]:
    """Parse a stored config payload (encrypted envelope or legacy plaintext).

    Args:
        data: Serialized config payload from storage.

    Returns:
        Parsed ``MCPClientConfig`` when the payload is valid, else ``None``.
    """
    try:
        parsed = orjson.loads(data)
    except Exception:
        logger.warning("Failed to parse stored LLM chat config payload")
        return None

    if not isinstance(parsed, dict):
        return None

    # Current format: encrypted envelope written by the serializer above.
    if _ENCRYPTED_CONFIG_PAYLOAD_KEY in parsed:
        encrypted_payload = parsed.get(_ENCRYPTED_CONFIG_PAYLOAD_KEY)
        if not encrypted_payload:
            return None
        decoded = decode_auth(encrypted_payload)
        config_data = decoded.get("config") if isinstance(decoded, dict) else None
        if not isinstance(config_data, dict):
            logger.warning("Decoded encrypted LLM chat config is invalid")
            return None
        return MCPClientConfig(**config_data)

    # Legacy format: the config dict was stored as plaintext JSON.
    return MCPClientConfig(**parsed)
def _is_sensitive_config_key(key: str) -> bool:
    """Report whether a config field name holds credential-like data.

    Args:
        key: Config field name (compared case-insensitively after strip).

    Returns:
        ``True`` when the normalized key is in the sensitive-key set.
    """
    normalized = str(key).strip().lower()
    return normalized in _CONFIG_SENSITIVE_KEYS
def _mask_sensitive_config_values(value: Any) -> Any:
    """Recursively replace sensitive config values with the mask marker.

    Walks dicts and lists; any dict entry whose key is sensitive has its
    (non-empty) value replaced by ``settings.masked_auth_value``. All other
    values are returned unchanged.

    Args:
        value: Arbitrary nested config value.

    Returns:
        A copy of ``value`` with sensitive fields masked.
    """
    if isinstance(value, dict):
        result: Dict[str, Any] = {}
        for field, inner in value.items():
            if not _is_sensitive_config_key(field):
                result[field] = _mask_sensitive_config_values(inner)
            elif inner in (None, ""):
                # Keep empty/absent secrets as-is so callers can tell "unset"
                # apart from "set but hidden".
                result[field] = inner
            else:
                result[field] = settings.masked_auth_value
        return result
    if isinstance(value, list):
        return [_mask_sensitive_config_values(inner) for inner in value]
    return value
473# ---------- CONFIG HELPERS ----------
async def set_user_config(user_id: str, config: MCPClientConfig):
    """Persist a user's encrypted configuration.

    Writes to Redis with a TTL when available, otherwise to the in-memory
    fallback store together with a monotonic timestamp so the same TTL can
    be enforced on read.

    Args:
        user_id: User identifier.
        config: Complete MCP client configuration.
    """
    payload = _serialize_user_config_for_storage(config)
    if not redis_client:
        user_configs[user_id] = (payload, time.monotonic())
        return
    await redis_client.set(_cfg_key(user_id), payload, ex=USER_CONFIG_TTL)
async def get_user_config(user_id: str) -> Optional[MCPClientConfig]:
    """Fetch a user's stored configuration, if present and not expired.

    Args:
        user_id: User identifier.

    Returns:
        Optional[MCPClientConfig]: The configuration, or None when missing,
        expired, or undecodable.
    """
    if redis_client:
        raw = await redis_client.get(_cfg_key(user_id))
        return _deserialize_user_config_from_storage(raw) if raw else None

    entry = user_configs.get(user_id)
    if not entry:
        return None
    payload, stored_at = entry
    # Mirror the TTL Redis would enforce: drop entries older than the limit.
    if (time.monotonic() - stored_at) > USER_CONFIG_TTL:
        user_configs.pop(user_id, None)
        return None
    return _deserialize_user_config_from_storage(payload)
async def delete_user_config(user_id: str):
    """Remove a user's stored configuration from whichever backend holds it.

    Args:
        user_id: User identifier.
    """
    if not redis_client:
        user_configs.pop(user_id, None)
        return
    await redis_client.delete(_cfg_key(user_id))
527# ---------- SESSION (active) HELPERS with locking & recreate ----------
async def set_active_session(user_id: str, session: MCPChatService):
    """Register a session on this worker and claim Redis ownership.

    Args:
        user_id: User identifier.
        session: Initialized MCPChatService instance.
    """
    active_sessions[user_id] = session
    if not redis_client:
        return
    # Ownership carries a TTL so a crashed worker's claim expires on its own.
    await redis_client.set(_active_key(user_id), WORKER_ID, ex=SESSION_TTL)
async def delete_active_session(user_id: str):
    """Remove active session locally and from Redis atomically.

    Uses a Lua script to ensure we only delete the Redis key if we own it,
    preventing race conditions where another worker's session marker could
    be deleted if our session expired and was recreated by another worker.

    Args:
        user_id: User identifier.
    """
    # Local removal is unconditional; only the Redis marker needs the
    # ownership check.
    active_sessions.pop(user_id, None)
    if redis_client:
        try:
            # Lua script for atomic check-and-delete (only delete if we own the key)
            release_script = """
            if redis.call("get", KEYS[1]) == ARGV[1] then
                return redis.call("del", KEYS[1])
            else
                return 0
            end
            """
            # eval(script, numkeys, key, arg): compares the stored owner
            # against this worker's WORKER_ID before deleting.
            await redis_client.eval(release_script, 1, _active_key(user_id), WORKER_ID)
        except Exception as e:
            # Best-effort: a leftover marker is cleaned up by its TTL.
            logger.warning(f"Failed to delete active session for user {user_id}: {e}")
async def _try_acquire_lock(user_id: str) -> bool:
    """Attempt to take the per-user session-initialization mutex.

    Args:
        user_id: User identifier.

    Returns:
        bool: True if the lock was acquired (or no Redis coordination is
        needed), False when another worker currently holds it.
    """
    if not redis_client:
        # Without Redis, everything is worker-local; no coordination needed.
        return True
    # SET NX acts as a mutex; EX bounds how long a crashed holder can block.
    lock_name = _lock_key(user_id)
    return await redis_client.set(lock_name, WORKER_ID, nx=True, ex=LOCK_TTL)
async def _release_lock_safe(user_id: str):
    """Release the lock atomically only if we own it.

    Uses a Lua script to ensure atomic check-and-delete, preventing
    the TOCTOU race condition where another worker's lock could be
    deleted if the original lock expired between get() and delete().

    Args:
        user_id: User identifier.
    """
    if not redis_client:
        return
    try:
        # Lua script for atomic check-and-delete (only delete if we own the key)
        release_script = """
        if redis.call("get", KEYS[1]) == ARGV[1] then
            return redis.call("del", KEYS[1])
        else
            return 0
        end
        """
        # eval(script, numkeys, key, arg): only deletes when the stored
        # owner still equals this worker's WORKER_ID.
        await redis_client.eval(release_script, 1, _lock_key(user_id), WORKER_ID)
    except Exception as e:
        # Non-fatal: an unreleased lock expires via LOCK_TTL.
        logger.warning(f"Failed to release lock for user {user_id}: {e}")
async def _create_local_session_from_config(user_id: str) -> Optional[MCPChatService]:
    """Recreate an MCPChatService on this worker from the stored config.

    Used when this worker owns (or is claiming) a session but has no local
    instance, e.g. after a process restart.

    Args:
        user_id: User identifier.

    Returns:
        Optional[MCPChatService]: Initialized service, or None when no
        config is stored or initialization fails.
    """
    stored = await get_user_config(user_id)
    if not stored:
        return None

    try:
        service = MCPChatService(stored, user_id=user_id, redis_client=redis_client)
        await service.initialize()
        # Registers locally and (re)claims Redis ownership.
        await set_active_session(user_id, service)
        return service
    except Exception as e:
        logger.error(f"Failed to initialize MCPChatService for {user_id}: {e}", exc_info=True)
        # Roll back any partial local/Redis state from the failed attempt.
        await delete_active_session(user_id)
        return None
async def get_active_session(user_id: str) -> Optional[MCPChatService]:
    """
    Retrieve or (if possible) create the active session for user_id.

    Behavior:
        - If Redis is disabled: return local session or None.
        - If Redis enabled:
            * If owner == WORKER_ID and local session exists -> return it (and refresh TTL)
            * If owner == WORKER_ID but local missing -> try to acquire lock and recreate
            * If no owner -> try to acquire lock and create session here
            * If owner != WORKER_ID -> wait a short time for owner to appear or return None

    Args:
        user_id: User identifier.

    Returns:
        Optional[MCPChatService]: Active session if available, None otherwise.
    """
    # Fast path: no redis => purely local
    if not redis_client:
        return active_sessions.get(user_id)

    active_key = _active_key(user_id)
    owner = await redis_client.get(active_key)

    # 1) Owned by this worker
    if owner == WORKER_ID:
        local = active_sessions.get(user_id)
        if local:
            # refresh TTL so ownership persists while active
            try:
                await redis_client.expire(active_key, SESSION_TTL)
            except Exception as e:  # nosec B110
                # non-fatal if expire fails, just log the error
                logger.debug(f"Failed to refresh session TTL for {user_id}: {e}")
            return local

        # Owner in Redis points to this worker but local session missing (process restart or lost).
        # Try to recreate it (acquire lock).
        acquired = await _try_acquire_lock(user_id)
        if acquired:
            try:
                # create new local session from the stored config
                session = await _create_local_session_from_config(user_id)
                return session
            finally:
                # always release so other workers are not blocked by a stale lock
                await _release_lock_safe(user_id)
        else:
            # someone else is (re)creating; poll briefly for them to finish
            for _ in range(LOCK_RETRIES):
                await asyncio.sleep(LOCK_WAIT)
                if active_sessions.get(user_id):
                    return active_sessions.get(user_id)
            return None

    # 2) No owner -> try to claim & create session locally
    if owner is None:
        acquired = await _try_acquire_lock(user_id)
        if acquired:
            try:
                session = await _create_local_session_from_config(user_id)
                return session
            finally:
                await _release_lock_safe(user_id)

        # if we couldn't acquire lock, someone else is creating; wait a short time
        for _ in range(LOCK_RETRIES):
            await asyncio.sleep(LOCK_WAIT)
            owner2 = await redis_client.get(active_key)
            if owner2 == WORKER_ID and active_sessions.get(user_id):
                return active_sessions.get(user_id)
            if owner2 is not None and owner2 != WORKER_ID:
                # some other worker now owns it
                return None

        # final attempt to acquire lock (last resort)
        acquired = await _try_acquire_lock(user_id)
        if acquired:
            try:
                session = await _create_local_session_from_config(user_id)
                return session
            finally:
                await _release_lock_safe(user_id)
        return None

    # 3) Owned by another worker -> we don't have it locally.
    # We deliberately do not "steal" the session: if the owner is dead, its
    # TTL on the active_session key expires and case (2) takes over.
    return None
727# ---------- ROUTES ----------
@llmchat_router.post("/connect")
@require_permission("llm.invoke")
async def connect(input_data: ConnectInput, request: Request, user=Depends(get_current_user_with_permissions)):
    """Create or refresh a chat session for a user.

    Initializes a new MCPChatService instance for the specified user, establishing
    connections to both the MCP server and the configured LLM provider. If a session
    already exists for the user, it is gracefully shutdown before creating a new one.

    Authentication is handled via JWT token from cookies if not explicitly provided
    in the request body.

    Args:
        input_data: ConnectInput containing user_id, optional server/LLM config, and streaming preference.
        request: FastAPI Request object for accessing cookies and headers.
        user: Authenticated user context.

    Returns:
        dict: Connection status response containing:
            - status: 'connected'
            - user_id: The connected user's identifier
            - provider: The LLM provider being used
            - tool_count: Number of available MCP tools
            - tools: List of tool names

    Raises:
        HTTPException: If an error occurs.
            400: Invalid user_id, invalid configuration, or LLM config error.
            401: Missing authentication token.
            503: Failed to connect to MCP server.
            500: Service initialization failure or unexpected error.

    Examples:
        This endpoint is called via HTTP POST and cannot be directly tested with doctest.
        Example request body:

            {
                "user_id": "user123",
                "server": {
                    "url": "http://localhost:8000/mcp",
                    "auth_token": "jwt_token_here"
                },
                "llm": {
                    "model": "gpt-4o",
                    "temperature": 0.7
                },
                "streaming": false
            }

        Example response:

            {
                "status": "connected",
                "user_id": "user123",
                "provider": "gateway",
                "tool_count": 5,
                "tools": ["search", "calculator", "weather", "translate", "summarize"]
            }

    Note:
        Existing sessions are automatically terminated before establishing new ones.
        All LLM configuration routes through the gateway provider; models are
        configured via Admin UI -> LLM Settings.
    """
    user_id = _resolve_user_id(input_data.user_id, user)

    try:
        # Validate user_id
        if not user_id or not isinstance(user_id, str):
            raise HTTPException(status_code=400, detail="Invalid user ID provided")

        # Validate user-supplied server URLs with SSRF protections before any outbound connection setup.
        if input_data.server and input_data.server.url:
            try:
                input_data.server.url = SecurityValidator.validate_url(str(input_data.server.url), "MCP server URL")
            except ValueError as e:
                logger.warning("LLM chat connect URL validation failed for user %s and URL %s: %s", user_id, input_data.server.url, e)
                raise HTTPException(status_code=400, detail="Invalid server URL")

        # Handle authentication token: an empty/missing token falls back to the
        # caller's jwt_token cookie so browser sessions work without a body token.
        empty_token = ""  # nosec B105
        if input_data.server and (input_data.server.auth_token is None or input_data.server.auth_token == empty_token):
            jwt_token = request.cookies.get("jwt_token")
            if not jwt_token:
                raise HTTPException(status_code=401, detail="Authentication required. Please ensure you are logged in.")
            input_data.server.auth_token = jwt_token

        # Close old session if it exists
        existing = await get_active_session(user_id)
        if existing:
            try:
                logger.debug(f"Disconnecting existing session for {user_id} before reconnecting")
                await existing.shutdown()
            except Exception as shutdown_error:
                logger.warning(f"Failed to cleanly shutdown existing session for {user_id}: {shutdown_error}")
            finally:
                # Always remove the session from active sessions, even if shutdown failed
                await delete_active_session(user_id)

        # Build and validate configuration
        try:
            config = build_config(input_data)
        except ValueError as ve:
            raise HTTPException(status_code=400, detail=f"Invalid configuration: {str(ve)}")
        except Exception as config_error:
            raise HTTPException(status_code=400, detail=f"Configuration error: {str(config_error)}")

        # Store user configuration (needed so other workers can recreate the session)
        await set_user_config(user_id, config)

        # Initialize chat service
        try:
            chat_service = MCPChatService(config, user_id=user_id, redis_client=redis_client)
            await chat_service.initialize()

            # Clear chat history on new connection
            await chat_service.clear_history()
        except ConnectionError as ce:
            # Clean up partial state
            await delete_user_config(user_id)
            raise HTTPException(status_code=503, detail=f"Failed to connect to MCP server: {str(ce)}. Please verify the server URL and authentication.")
        except ValueError as ve:
            # Clean up partial state
            await delete_user_config(user_id)
            raise HTTPException(status_code=400, detail=f"Invalid LLM configuration: {str(ve)}")
        except Exception as init_error:
            # Clean up partial state
            await delete_user_config(user_id)
            raise HTTPException(status_code=500, detail=f"Service initialization failed: {str(init_error)}")

        await set_active_session(user_id, chat_service)

        # Extract tool names for the response; failures here are non-fatal.
        tool_names = []
        try:
            if hasattr(chat_service, "_tools") and chat_service._tools:
                for tool in chat_service._tools:
                    tool_name = getattr(tool, "name", None)
                    if tool_name:
                        tool_names.append(tool_name)
        except Exception as tool_error:
            logger.warning(f"Failed to extract tool names: {tool_error}")
            # Continue without tools list

        return {"status": "connected", "user_id": user_id, "provider": config.llm.provider, "tool_count": len(tool_names), "tools": tool_names}

    except HTTPException:
        # Re-raise HTTP exceptions as-is
        raise
    except Exception as e:
        logger.error(f"Unexpected error in connect endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Unexpected connection error: {str(e)}")
async def token_streamer(chat_service: MCPChatService, message: str, user_id: str):
    """Stream chat response tokens as Server-Sent Events (SSE).

    Asynchronous generator that yields SSE-formatted chunks containing tokens,
    tool invocation updates, and final response data from the chat service.
    Uses the unified ChatHistoryManager for history persistence.

    Args:
        chat_service: MCPChatService instance configured for the user session.
        message: User's chat message to process.
        user_id: User identifier for logging.

    Yields:
        bytes: SSE-formatted event data containing:
            - token events: Incremental content chunks
            - tool_start: Tool invocation beginning
            - tool_end: Tool invocation completion
            - tool_error: Tool execution failure
            - final: Complete response with metadata
            - error: Error information with recovery status

    Event Types:
        - token: {"content": "text chunk"}
        - tool_start: {"type": "tool_start", "tool": "name", ...}
        - tool_end: {"type": "tool_end", "tool": "name", ...}
        - tool_error: {"type": "tool_error", "tool": "name", "error": "message"}
        - final: {"type": "final", "text": "complete response", "metadata": {...}}
        - error: {"type": "error", "error": "message", "recoverable": bool}

    Examples:
        This is an async generator used internally by the chat endpoint.
        It cannot be directly tested with standard doctest.

        Example event stream:

            event: token
            data: {"content": "Hello"}

            event: token
            data: {"content": ", how"}

            event: final
            data: {"type": "final", "text": "Hello, how can I help?"}

    Note:
        SSE format requires 'event: <type>\\ndata: <json>\\n\\n' structure.
        All exceptions are caught and converted to error events for client
        handling; errors are never raised out of this generator.
    """

    async def sse(event_type: str, data: Dict[str, Any]):
        """Format data as Server-Sent Event.

        Args:
            event_type: SSE event type identifier.
            data: Payload dictionary to serialize as JSON.

        Yields:
            bytes: UTF-8 encoded SSE formatted lines (event line, then
            data line terminated by the blank line SSE requires).
        """
        yield f"event: {event_type}\n".encode("utf-8")
        yield f"data: {orjson.dumps(data).decode()}\n\n".encode("utf-8")

    try:
        async for ev in chat_service.chat_events(message):
            et = ev.get("type")
            if et == "token":
                content = ev.get("content", "")
                # Tokens are re-wrapped in a minimal payload; other event
                # types are forwarded as-is.
                async for part in sse("token", {"content": content}):
                    yield part
            elif et in ("tool_start", "tool_end", "tool_error"):
                async for part in sse(et, ev):
                    yield part
            elif et == "final":
                async for part in sse("final", ev):
                    yield part

    except ConnectionError as ce:
        # Transport to the MCP server/LLM dropped mid-stream; not retryable here.
        error_event = {"type": "error", "error": f"Connection lost: {str(ce)}", "recoverable": False}
        async for part in sse("error", error_event):
            yield part
    except TimeoutError:
        # The client may simply retry the same request.
        error_event = {"type": "error", "error": "Request timed out waiting for LLM response", "recoverable": True}
        async for part in sse("error", error_event):
            yield part
    except RuntimeError as re:
        error_event = {"type": "error", "error": f"Service error: {str(re)}", "recoverable": False}
        async for part in sse("error", error_event):
            yield part
    except Exception as e:
        logger.error(f"Unexpected streaming error: {e}", exc_info=True)
        error_event = {"type": "error", "error": f"Unexpected error: {str(e)}", "recoverable": False}
        async for part in sse("error", error_event):
            yield part
@llmchat_router.post("/chat")
@require_permission("llm.invoke")
async def chat(input_data: ChatInput, user=Depends(get_current_user_with_permissions)):
    """Send a message to the user's active chat session and receive a response.

    Processes user messages through the configured LLM with MCP tool integration.
    Supports both streaming (SSE) and non-streaming response modes. Chat history
    is managed automatically via the unified ChatHistoryManager.

    Args:
        input_data: ChatInput containing user_id, message, and streaming preference.
        user: Authenticated user context.

    Returns:
        For streaming=False:
            dict: Response containing:
                - user_id: Session identifier
                - response: Complete LLM response text
                - tool_used: Boolean indicating if tools were invoked
                - tools: List of tool names used
                - tool_invocations: Detailed tool call information
                - elapsed_ms: Processing time in milliseconds
        For streaming=True:
            StreamingResponse: SSE stream of token and event data.

    Raises:
        HTTPException: Raised when an HTTP-related error occurs.
            400: Missing user_id, empty message, or no active session.
            503: Session not initialized, chat service error, or connection lost.
            504: Request timeout.
            500: Unexpected error.

    Examples:
        This endpoint is called via HTTP POST and cannot be directly tested with doctest.

        Example non-streaming request:

            {
                "user_id": "user123",
                "message": "What's the weather like?",
                "streaming": false
            }

        Example non-streaming response:

            {
                "user_id": "user123",
                "response": "The weather is sunny and 72°F.",
                "tool_used": true,
                "tools": ["weather"],
                "tool_invocations": 1,
                "elapsed_ms": 450
            }

        Example streaming request:

            {
                "user_id": "user123",
                "message": "Tell me a story",
                "streaming": true
            }

    Note:
        Streaming responses use Server-Sent Events (SSE) with 'text/event-stream' MIME type.
        Client must maintain persistent connection for streaming.
    """
    user_id = _resolve_user_id(input_data.user_id, user)

    # Validate input
    if not user_id:
        raise HTTPException(status_code=400, detail="User ID is required")

    if not input_data.message or not input_data.message.strip():
        raise HTTPException(status_code=400, detail="Message cannot be empty")

    # Check for active session
    chat_service = await get_active_session(user_id)
    if not chat_service:
        raise HTTPException(status_code=400, detail="No active session found. Please connect to a server first.")

    # Verify session is initialized
    if not chat_service.is_initialized:
        raise HTTPException(status_code=503, detail="Session is not properly initialized. Please reconnect.")

    try:
        if input_data.streaming:
            return StreamingResponse(
                token_streamer(chat_service, input_data.message, user_id),
                media_type="text/event-stream",
                headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},  # Disable proxy buffering
            )
        else:
            try:
                result = await chat_service.chat_with_metadata(input_data.message)
                return {
                    "user_id": user_id,
                    "response": result["text"],
                    "tool_used": result["tool_used"],
                    "tools": result["tools"],
                    "tool_invocations": result["tool_invocations"],
                    "elapsed_ms": result["elapsed_ms"],
                }
            except RuntimeError as runtime_exc:
                # Chain the cause so the original traceback survives into logs (B904).
                raise HTTPException(status_code=503, detail=f"Chat service error: {str(runtime_exc)}") from runtime_exc
    except ConnectionError as conn_exc:
        raise HTTPException(status_code=503, detail=f"Lost connection to MCP server: {str(conn_exc)}. Please reconnect.") from conn_exc
    except TimeoutError as timeout_exc:
        raise HTTPException(status_code=504, detail="Request timed out. The LLM took too long to respond.") from timeout_exc
    except HTTPException:
        # Already an HTTP error (including the 503 raised above) — propagate untouched.
        raise
    except Exception as exc:
        logger.error(f"Unexpected error in chat endpoint for user {user_id}: {exc}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(exc)}") from exc
@llmchat_router.post("/disconnect")
@require_permission("llm.invoke")
async def disconnect(input_data: DisconnectInput, user=Depends(get_current_user_with_permissions)):
    """Tear down a user's chat session and release its resources.

    Shuts down the user's MCPChatService, clears stored history, and removes
    both the active-session entry and the persisted user configuration.
    Calling this when no session exists is harmless.

    Args:
        input_data: DisconnectInput carrying the user_id to disconnect.
        user: Authenticated user context.

    Returns:
        dict: Disconnection outcome with:
            - status: 'disconnected', 'no_active_session', or 'disconnected_with_errors'
            - user_id: The user identifier
            - message: Human-readable status description
            - warning: (Only on cleanup failure) the error text

    Raises:
        HTTPException: Raised when an HTTP-related error occurs.
            400: Missing user_id.

    Examples:
        This endpoint is called via HTTP POST and cannot be directly tested with doctest.

        Example request:

            {
                "user_id": "user123"
            }

        Example successful response:

            {
                "status": "disconnected",
                "user_id": "user123",
                "message": "Successfully disconnected"
            }

        Example response when no session exists:

            {
                "status": "no_active_session",
                "user_id": "user123",
                "message": "No active session to disconnect"
            }

    Note:
        Idempotent: repeated calls for the same user_id succeed without error.
    """
    user_id = _resolve_user_id(input_data.user_id, user)
    if not user_id:
        raise HTTPException(status_code=400, detail="User ID is required")

    # Grab the live service (if any) before wiping all session state,
    # so we can still shut it down afterwards.
    chat_service = await get_active_session(user_id)
    await delete_active_session(user_id)
    await delete_user_config(user_id)

    if not chat_service:
        return {"status": "no_active_session", "user_id": user_id, "message": "No active session to disconnect"}

    try:
        # History is purged as part of disconnect.
        await chat_service.clear_history()
        logger.info(f"Chat session disconnected for {user_id}")
        await chat_service.shutdown()
    except Exception as cleanup_error:
        logger.error(f"Error during disconnect for user {user_id}: {cleanup_error}", exc_info=True)
        # State is already removed above, so report success with a warning.
        return {"status": "disconnected_with_errors", "user_id": user_id, "message": "Disconnected but cleanup encountered errors", "warning": str(cleanup_error)}

    return {"status": "disconnected", "user_id": user_id, "message": "Successfully disconnected"}
@llmchat_router.get("/status/{user_id}")
@require_permission("llm.read")
async def status(user_id: str, user=Depends(get_current_user_with_permissions)):
    """Report whether the given user currently has an active chat session.

    Read-only probe intended for UI state and health checks; it mutates
    nothing and does not verify that the session is fully initialized —
    only that one is registered.

    Args:
        user_id: User identifier to look up.
        user: Authenticated user context.

    Returns:
        dict: Contains:
            - user_id: The resolved user identifier
            - connected: True when an active session exists, else False

    Examples:
        This endpoint is called via HTTP GET and cannot be directly tested with doctest.

        Example request:
            GET /llmchat/status/user123

        Example response (connected):

            {
                "user_id": "user123",
                "connected": true
            }

        Example response (not connected):

            {
                "user_id": "user456",
                "connected": false
            }
    """
    effective_id = _resolve_user_id(user_id, user)
    session = await get_active_session(effective_id)
    return {"user_id": effective_id, "connected": bool(session)}
@llmchat_router.get("/config/{user_id}")
@require_permission("llm.read")
async def get_config(user_id: str, user=Depends(get_current_user_with_permissions)):
    """Return the stored session configuration for a user, with secrets masked.

    Sensitive fields (API keys, auth tokens) are stripped/masked before the
    configuration is returned, making this safe for debugging and
    configuration verification.

    Args:
        user_id: User identifier whose configuration to fetch.
        user: Authenticated user context.

    Returns:
        dict: Sanitized configuration with:
            - mcp_server: Server connection settings (without auth_token)
            - llm: LLM provider configuration (without api_key)
            - enable_streaming: Boolean streaming preference

    Raises:
        HTTPException: Raised when an HTTP-related error occurs.
            404: No configuration stored for this user_id.

    Examples:
        This endpoint is called via HTTP GET and cannot be directly tested with doctest.

        Example request:
            GET /llmchat/config/user123

        Example response:

            {
                "mcp_server": {
                    "url": "http://localhost:8000/mcp",
                    "transport": "streamable_http"
                },
                "llm": {
                    "provider": "ollama",
                    "config": {
                        "model": "llama3",
                        "temperature": 0.7
                    }
                },
                "enable_streaming": false
            }

    Security:
        API keys and authentication tokens are explicitly removed before returning.
        Never log or expose these values in responses.
    """
    resolved_id = _resolve_user_id(user_id, user)
    stored_config = await get_user_config(resolved_id)
    if not stored_config:
        raise HTTPException(status_code=404, detail="No config found for this user.")

    return _mask_sensitive_config_values(stored_config.model_dump())
@llmchat_router.get("/gateway/models")
@require_permission("llm.read")
async def get_gateway_models(_user=Depends(get_current_user_with_permissions)):
    """Get available models from configured LLM providers.

    Returns a list of enabled models from enabled providers configured
    in the gateway's LLM Settings. These models can be used with the
    "gateway" provider type in /connect requests.

    Args:
        _user: Authenticated user context (unused beyond authorization).

    Returns:
        dict: Response containing:
            - models: List of available models with provider info
            - count: Total number of available models

    Raises:
        HTTPException: 500 if there is an error retrieving gateway models.

    Examples:
        GET /llmchat/gateway/models

        Response:
            {
                "models": [
                    {
                        "id": "abc123",
                        "model_id": "gpt-4o",
                        "model_name": "GPT-4o",
                        "provider_id": "def456",
                        "provider_name": "OpenAI",
                        "provider_type": "openai",
                        "supports_streaming": true,
                        "supports_function_calling": true,
                        "supports_vision": true
                    }
                ],
                "count": 1
            }
    """
    # Import here to avoid circular dependency
    # First-Party
    from mcpgateway.db import SessionLocal
    from mcpgateway.services.llm_provider_service import LLMProviderService

    llm_service = LLMProviderService()

    try:
        with SessionLocal() as db:
            models = llm_service.get_gateway_models(db)
            return {
                "models": [m.model_dump() for m in models],
                "count": len(models),
            }
    except Exception as e:
        # Log with traceback (exc_info) for parity with the file's other handlers,
        # and chain the cause so the original error is preserved (B904).
        logger.error(f"Failed to get gateway models: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Failed to retrieve gateway models: {str(e)}") from e