Coverage for mcpgateway / routers / llmchat_router.py: 99%
401 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/routers/llmchat_router.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Keval Mahajan
7LLM Chat Router Module
9This module provides FastAPI endpoints for managing LLM-based chat sessions
10with MCP (Model Context Protocol) server integration. LLM providers are
11configured via the Admin UI's LLM Settings and accessed through the gateway
12provider.
14The module handles user session management, configuration, and real-time
15streaming responses for conversational AI applications with unified chat
16history management via ChatHistoryManager from mcp_client_chat_service.
18"""
20# Standard
21import asyncio
22import os
23import time
24from typing import Any, Dict, Optional
26# Third-Party
27from fastapi import APIRouter, Depends, HTTPException, Request
28from fastapi.responses import StreamingResponse
29import orjson
30from pydantic import BaseModel, Field
32try:
33 # Third-Party
34 import redis.asyncio # noqa: F401 - availability check only
36 REDIS_AVAILABLE = True
37except ImportError:
38 REDIS_AVAILABLE = False
40# First-Party
41from mcpgateway.common.validators import SecurityValidator
42from mcpgateway.config import settings
43from mcpgateway.middleware.rbac import get_current_user_with_permissions, require_permission
44from mcpgateway.services.logging_service import LoggingService
45from mcpgateway.services.mcp_client_chat_service import (
46 ChatProcessingError,
47 GatewayConfig,
48 LLMConfig,
49 MCPChatService,
50 MCPClientConfig,
51 MCPServerConfig,
52)
53from mcpgateway.utils.redis_client import get_redis_client
54from mcpgateway.utils.services_auth import decode_auth, encode_auth
# Initialize router: every route declared in this module is served under the /llmchat prefix
llmchat_router = APIRouter(prefix="/llmchat", tags=["llmchat"])

# Redis client (initialized via init_redis() during app startup).
# It stays None until init_redis() assigns it; the storage helpers in this module
# test it for truthiness and use in-process dictionaries when it is not set.
redis_client = None
async def init_redis() -> None:
    """Bind the module-level ``redis_client`` to the shared Redis client.

    Meant to be awaited once during application startup (main.py lifespan).
    The shared factory is only consulted when settings report
    ``cache_type == "redis"`` together with a non-empty ``redis_url``;
    otherwise ``redis_client`` is left untouched.
    """
    global redis_client

    redis_selected = getattr(settings, "cache_type", None) == "redis"
    if not (redis_selected and getattr(settings, "redis_url", None)):
        return

    redis_client = await get_redis_client()
    if redis_client:
        logger.info("LLMChat router connected to shared Redis client")
# In-process stores.
# active_sessions holds the live MCPChatService objects of this worker, keyed by user_id.
# It is written on every set_active_session() call, whether or not Redis is configured;
# Redis only records which worker owns a session.
active_sessions: Dict[str, MCPChatService] = {}

# Per-user configuration, used only when Redis is unavailable.
# Maps user_id -> (serialized config bytes, time.monotonic() value captured when stored).
user_configs: Dict[str, tuple[bytes, float]] = {}

# Logging
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)
86# ---------- MODELS ----------
class LLMInput(BaseModel):
    """Selection of a gateway-configured LLM plus optional sampling limits.

    The model itself is configured in Admin UI -> LLM Settings; this input
    only names it and optionally overrides generation parameters.

    Attributes:
        model: Model ID from the gateway's LLM Settings (UUID or model_id).
        temperature: Optional sampling temperature, constrained to 0.0-2.0.
        max_tokens: Optional maximum tokens to generate; must be positive.

    Examples:
        >>> llm_input = LLMInput(model='gpt-4o')
        >>> llm_input.model
        'gpt-4o'

        >>> llm_input = LLMInput(model='abc123-uuid', temperature=0.5)
        >>> llm_input.temperature
        0.5
    """

    model: str = Field(default=..., description="Model ID from gateway LLM Settings (UUID or model_id)")
    temperature: Optional[float] = Field(default=None, ge=0.0, le=2.0, description="Sampling temperature")
    max_tokens: Optional[int] = Field(default=None, gt=0, description="Maximum tokens to generate")
class ServerInput(BaseModel):
    """Input configuration for MCP server connection.

    Defines the connection parameters required to establish communication
    with an MCP (Model Context Protocol) server.

    Attributes:
        url: Optional MCP server URL endpoint. When omitted or empty,
            build_config() substitutes 'http://localhost:8000/mcp'.
        transport: Communication transport protocol. Defaults to 'streamable_http';
            build_config() also substitutes that value when this is empty.
        auth_token: Optional authentication token for secure server access.

    Examples:
        >>> server = ServerInput(url='http://example.com/mcp')
        >>> server.transport
        'streamable_http'

        >>> server = ServerInput()
        >>> server.url is None
        True
    """

    # None means "use the default URL chosen by build_config()"
    url: Optional[str] = None
    transport: Optional[str] = "streamable_http"
    # connect() fills this from the jwt_token cookie when it is None or empty
    auth_token: Optional[str] = None
class ConnectInput(BaseModel):
    """Request model for establishing a new chat session.

    Contains all necessary parameters to initialize a user's chat session
    including server connection details, LLM configuration, and streaming preferences.

    Attributes:
        user_id: Unique identifier for the user session. The connect endpoint
            passes it to _resolve_user_id(), which rejects the request when it
            differs from the authenticated user's identifier.
        server: Optional MCP server configuration. Uses defaults if not provided.
        llm: LLM configuration specifying which gateway model to use. Required.
        streaming: Whether to enable streaming responses. Defaults to False.

    Examples:
        >>> connect = ConnectInput(user_id='user123', llm=LLMInput(model='gpt-4o'))
        >>> connect.streaming
        False

        >>> connect = ConnectInput(user_id='user456', llm=LLMInput(model='gpt-4o'), streaming=True)
        >>> connect.user_id
        'user456'
    """

    user_id: str
    # None -> build_config() falls back to its default URL and transport
    server: Optional[ServerInput] = None
    llm: LLMInput = Field(..., description="LLM configuration with model from gateway LLM Settings")
    streaming: bool = False
class ChatInput(BaseModel):
    """Request model for sending chat messages.

    Encapsulates user message data for processing by the chat service.

    Attributes:
        user_id: Unique identifier for the active user session.
        message: The chat message content to be processed.
        streaming: Whether to stream the response. Defaults to False.

    Examples:
        >>> chat = ChatInput(user_id='user123', message='Hello, AI!')
        >>> len(chat.message) > 0
        True

        >>> chat = ChatInput(user_id='user456', message='Tell me a story', streaming=True)
        >>> chat.streaming
        True
    """

    user_id: str
    # No length or emptiness constraint is declared on the model itself
    message: str
    streaming: bool = False
class DisconnectInput(BaseModel):
    """Request model for terminating a chat session.

    Simple model containing only the user identifier for session cleanup.

    Attributes:
        user_id: Unique identifier of the session to disconnect.

    Examples:
        >>> disconnect = DisconnectInput(user_id='user123')
        >>> disconnect.user_id
        'user123'
    """

    # Required; the model declares no default
    user_id: str
212# ---------- HELPERS ----------
def build_llm_config(llm: LLMInput) -> LLMConfig:
    """Translate an ``LLMInput`` into a gateway-provider ``LLMConfig``.

    The provider is always ``"gateway"``; the model ID and ``max_tokens`` are
    passed through unchanged, and a missing temperature becomes ``0.7``.

    Args:
        llm: LLMInput containing model ID and optional temperature/max_tokens.

    Returns:
        LLMConfig: Gateway provider configuration.

    Examples:
        >>> llm_input = LLMInput(model='gpt-4o')
        >>> config = build_llm_config(llm_input)
        >>> config.provider
        'gateway'

    Note:
        Models are configured via Admin UI -> Settings -> LLM Settings; this
        function only references one by ID.
    """
    # An explicit 0.0 must survive, so test against None rather than truthiness.
    temperature = 0.7 if llm.temperature is None else llm.temperature
    gateway_settings = GatewayConfig(
        model=llm.model,
        temperature=temperature,
        max_tokens=llm.max_tokens,
    )
    return LLMConfig(provider="gateway", config=gateway_settings)
def build_config(input_data: ConnectInput) -> MCPClientConfig:
    """Assemble the full ``MCPClientConfig`` for a connect request.

    Server settings come from ``input_data.server`` when present; a missing or
    empty URL becomes ``http://localhost:8000/mcp`` and a missing or empty
    transport becomes ``streamable_http``. The LLM part is produced by
    ``build_llm_config``.

    Args:
        input_data: ConnectInput object containing server, LLM, and streaming settings.

    Returns:
        MCPClientConfig: Complete client configuration ready for service initialization.

    Examples:
        >>> from mcpgateway.routers.llmchat_router import ConnectInput, LLMInput, build_config
        >>> connect = ConnectInput(user_id='user123', llm=LLMInput(model='gpt-4o'))
        >>> config = build_config(connect)
        >>> config.mcp_server.transport
        'streamable_http'
    """
    url = "http://localhost:8000/mcp"
    transport = "streamable_http"
    auth_token = None

    server = input_data.server
    if server:
        url = server.url or url
        transport = server.transport or transport
        auth_token = server.auth_token

    mcp_server = MCPServerConfig(url=url, transport=transport, auth_token=auth_token)
    return MCPClientConfig(
        mcp_server=mcp_server,
        llm=build_llm_config(input_data.llm),
        enable_streaming=input_data.streaming,
    )
284def _get_user_id_from_context(user: Dict[str, Any]) -> str:
285 """Extract a stable user identifier from the authenticated user context.
287 Args:
288 user: Authenticated user context from RBAC dependency.
290 Returns:
291 User identifier string or "unknown" if missing.
292 """
293 if isinstance(user, dict):
294 return user.get("id") or user.get("user_id") or user.get("sub") or user.get("email") or "unknown"
295 return "unknown" if user is None else str(getattr(user, "id", user))
def _resolve_user_id(input_user_id: Optional[str], user: Dict[str, Any]) -> str:
    """Return the authenticated user's ID, refusing requests made for someone else.

    Args:
        input_user_id: User ID provided by the client (optional).
        user: Authenticated user context from RBAC dependency.

    Returns:
        Resolved authenticated user identifier.

    Raises:
        HTTPException: 401 when no identifier can be derived from ``user``;
            403 when a non-empty ``input_user_id`` differs from it.
    """
    authenticated_id = _get_user_id_from_context(user)

    if authenticated_id == "unknown":
        raise HTTPException(status_code=401, detail="Authentication required.")

    # An empty or missing client-supplied ID is accepted and simply ignored.
    claims_other_user = bool(input_user_id) and input_user_id != authenticated_id
    if claims_other_user:
        raise HTTPException(status_code=403, detail="User ID mismatch.")

    return authenticated_id
319# ---------- SESSION STORAGE HELPERS ----------
# Identify this worker (used for sticky session ownership): the value is written to
# Redis as the owner of a session or lock and compared on read.
# NOTE(review): a PID is unique only within one host / PID namespace. If several
# containers or hosts share one Redis, two workers may report the same ID — confirm
# against the deployment model.
WORKER_ID = str(os.getpid())

# Tunables, read once from application settings at import time
# (original note: can be set via environment — presumably through the settings object)
SESSION_TTL = settings.llmchat_session_ttl  # seconds for active_session key TTL
LOCK_TTL = settings.llmchat_session_lock_ttl  # seconds for lock expiry
LOCK_RETRIES = settings.llmchat_session_lock_retries  # how many times to poll while waiting
LOCK_WAIT = settings.llmchat_session_lock_wait  # seconds between polls
USER_CONFIG_TTL = settings.llmchat_session_ttl  # stored configs use the same TTL as sessions

# Keys and version marker of the JSON envelope written by _serialize_user_config_for_storage()
_ENCRYPTED_CONFIG_PAYLOAD_KEY = "_encrypted_payload"
_ENCRYPTED_CONFIG_VERSION_KEY = "_version"
_ENCRYPTED_CONFIG_VERSION = "1"
# Field names treated as sensitive by _is_sensitive_config_key(), which compares
# the stripped, lower-cased key against this set.
_CONFIG_SENSITIVE_KEYS = frozenset(
    {
        "api_key",
        "auth_token",
        "authorization",
        "access_token",
        "refresh_token",
        "client_secret",
        "secret_access_key",
        "session_token",
        "credentials_json",
        "password",
        "private_key",
    }
)
351# Redis key helpers
352def _cfg_key(user_id: str) -> str:
353 """Generate Redis key for user configuration storage.
355 Args:
356 user_id: User identifier.
358 Returns:
359 str: Redis key for storing user configuration.
360 """
361 return f"user_config:{user_id}"
364def _active_key(user_id: str) -> str:
365 """Generate Redis key for active session tracking.
367 Args:
368 user_id: User identifier.
370 Returns:
371 str: Redis key for tracking active sessions.
372 """
373 return f"active_session:{user_id}"
376def _lock_key(user_id: str) -> str:
377 """Generate Redis key for session initialization lock.
379 Args:
380 user_id: User identifier.
382 Returns:
383 str: Redis key for session locks.
384 """
385 return f"session_lock:{user_id}"
def _serialize_user_config_for_storage(config: MCPClientConfig) -> bytes:
    """Wrap a user config in the versioned storage envelope and serialize it.

    The config is dumped to a dict, passed through ``encode_auth`` under the
    key ``"config"``, and the encoded result is stored next to a version
    marker in a small JSON object.

    Args:
        config: User MCP client configuration.

    Returns:
        Serialized bytes containing encrypted config envelope.
    """
    encoded_payload = encode_auth({"config": config.model_dump()})

    envelope: Dict[str, Any] = {}
    envelope[_ENCRYPTED_CONFIG_VERSION_KEY] = _ENCRYPTED_CONFIG_VERSION
    envelope[_ENCRYPTED_CONFIG_PAYLOAD_KEY] = encoded_payload
    return orjson.dumps(envelope)
def _deserialize_user_config_from_storage(data: bytes | str) -> Optional[MCPClientConfig]:
    """Deserialize user config from encrypted or legacy plaintext payloads.

    Two stored shapes are understood: the envelope produced by
    ``_serialize_user_config_for_storage`` (a JSON object carrying an encoded
    payload) and a legacy JSON object holding the config fields directly.
    Every failure mode — unparseable JSON, an undecodable payload, or config
    fields that ``MCPClientConfig`` rejects — yields ``None`` rather than an
    exception, so callers can treat a corrupt entry like a missing one.

    Args:
        data: Serialized config payload from storage.

    Returns:
        Parsed ``MCPClientConfig`` when data is valid, otherwise ``None``.
    """
    try:
        parsed = orjson.loads(data)
    except Exception:
        logger.warning("Failed to parse stored LLM chat config payload")
        return None

    if not isinstance(parsed, dict):
        return None

    if _ENCRYPTED_CONFIG_PAYLOAD_KEY in parsed:
        # New encrypted envelope
        encrypted_payload = parsed.get(_ENCRYPTED_CONFIG_PAYLOAD_KEY)
        if not encrypted_payload:
            return None
        try:
            decoded = decode_auth(encrypted_payload)
        except Exception:
            logger.warning("Failed to decode encrypted LLM chat config payload")
            return None
        config_data = decoded.get("config") if isinstance(decoded, dict) else None
        if not isinstance(config_data, dict):
            logger.warning("Decoded encrypted LLM chat config is invalid")
            return None
    else:
        # Legacy plaintext payload compatibility
        config_data = parsed

    try:
        return MCPClientConfig(**config_data)
    except Exception:
        # The exception text is deliberately not logged: it may echo stored
        # field values, which can include credentials.
        logger.warning("Stored LLM chat config failed validation")
        return None
def _is_sensitive_config_key(key: str) -> bool:
    """Tell whether a config field name denotes a value that must be masked.

    The key is stringified, stripped of surrounding whitespace and lower-cased
    before it is looked up in ``_CONFIG_SENSITIVE_KEYS``.

    Args:
        key: Config field name.

    Returns:
        ``True`` when the normalized key is one of the sensitive key names.
    """
    normalized = str(key).strip().lower()
    return normalized in _CONFIG_SENSITIVE_KEYS
def _mask_sensitive_config_values(value: Any) -> Any:
    """Return a copy of a nested config value with sensitive fields masked.

    Dicts and lists are walked recursively; anything else is returned as is.
    Under a sensitive key, a value equal to ``None`` or ``""`` is kept, and
    any other value is replaced by ``settings.masked_auth_value``.

    Args:
        value: Arbitrary nested config value.

    Returns:
        Value with sensitive fields replaced by configured mask marker.
    """
    if isinstance(value, list):
        return [_mask_sensitive_config_values(element) for element in value]

    if not isinstance(value, dict):
        return value

    result: Dict[str, Any] = {}
    for name, entry in value.items():
        if not _is_sensitive_config_key(name):
            result[name] = _mask_sensitive_config_values(entry)
            continue
        # Leave "nothing configured" visible so callers can tell unset from set.
        is_blank = entry in (None, "")
        result[name] = entry if is_blank else settings.masked_auth_value
    return result
474# ---------- CONFIG HELPERS ----------
async def set_user_config(user_id: str, config: MCPClientConfig):
    """Persist a user's configuration in Redis, or in memory without Redis.

    With Redis the entry is written with ``USER_CONFIG_TTL`` as its expiry;
    without it the serialized bytes are kept together with the current
    ``time.monotonic()`` reading.

    Args:
        user_id: User identifier.
        config: Complete MCP client configuration.
    """
    blob = _serialize_user_config_for_storage(config)

    if not redis_client:
        user_configs[user_id] = (blob, time.monotonic())
        return

    await redis_client.set(_cfg_key(user_id), blob, ex=USER_CONFIG_TTL)
async def get_user_config(user_id: str) -> Optional[MCPClientConfig]:
    """Load a user's configuration from Redis, or from memory without Redis.

    In-memory entries older than ``USER_CONFIG_TTL`` seconds are evicted on
    access and reported as missing.

    Args:
        user_id: User identifier.

    Returns:
        Optional[MCPClientConfig]: User configuration if found, None otherwise.
    """
    if redis_client:
        raw = await redis_client.get(_cfg_key(user_id))
        return _deserialize_user_config_from_storage(raw) if raw else None

    entry = user_configs.get(user_id)
    if not entry:
        return None

    payload, stored_at = entry
    age = time.monotonic() - stored_at
    if age > USER_CONFIG_TTL:
        # Expired: drop it so the dictionary does not keep stale configs.
        user_configs.pop(user_id, None)
        return None

    return _deserialize_user_config_from_storage(payload)
async def delete_user_config(user_id: str):
    """Remove a user's stored configuration from Redis, or from memory without Redis.

    Args:
        user_id: User identifier.
    """
    if not redis_client:
        user_configs.pop(user_id, None)
        return

    await redis_client.delete(_cfg_key(user_id))
528# ---------- SESSION (active) HELPERS with locking & recreate ----------
async def set_active_session(user_id: str, session: MCPChatService):
    """Record a session in this worker and, with Redis, claim ownership of it.

    Args:
        user_id: User identifier.
        session: Initialized MCPChatService instance.
    """
    active_sessions[user_id] = session

    if not redis_client:
        return

    # The ownership marker carries a TTL so a dead worker eventually loses it.
    await redis_client.set(_active_key(user_id), WORKER_ID, ex=SESSION_TTL)
async def delete_active_session(user_id: str):
    """Forget a session locally and drop this worker's ownership marker in Redis.

    The Redis key is removed through a Lua script that deletes it only while
    its value equals ``WORKER_ID``, so a marker that another worker has since
    written for the same user is left alone. Redis failures are logged as
    warnings and not raised.

    Args:
        user_id: User identifier.
    """
    active_sessions.pop(user_id, None)

    if not redis_client:
        return

    # Compare-and-delete in one server-side step (only delete if we own the key)
    delete_if_owner = """
        if redis.call("get", KEYS[1]) == ARGV[1] then
            return redis.call("del", KEYS[1])
        else
            return 0
        end
        """
    try:
        await redis_client.eval(delete_if_owner, 1, _active_key(user_id), WORKER_ID)
    except Exception as e:
        logger.warning(f"Failed to delete active session for user {SecurityValidator.sanitize_log_message(user_id)}: {e}")
async def _try_acquire_lock(user_id: str) -> bool:
    """Attempt to acquire the initialization lock for a user session.

    The lock is a Redis key set with ``nx=True`` (only if absent) and an
    expiry of ``LOCK_TTL`` seconds, holding this worker's ``WORKER_ID``.

    Args:
        user_id: User identifier.

    Returns:
        bool: True if lock acquired, False otherwise.
    """
    if not redis_client:
        return True  # no redis -> local only, no lock required

    # The raw reply is not a bool when the key already exists, so normalize it
    # to honour the annotated return type.
    reply = await redis_client.set(_lock_key(user_id), WORKER_ID, nx=True, ex=LOCK_TTL)
    return bool(reply)
async def _release_lock_safe(user_id: str):
    """Release the initialization lock, but only if this worker still holds it.

    The check and the delete run inside one Lua script, so a lock that expired
    and was re-acquired by another worker cannot be removed by mistake.
    Without Redis this is a no-op; Redis failures are logged as warnings.

    Args:
        user_id: User identifier.
    """
    if redis_client:
        # Compare-and-delete in one server-side step (only delete if we own the key)
        unlock_if_owner = """
        if redis.call("get", KEYS[1]) == ARGV[1] then
            return redis.call("del", KEYS[1])
        else
            return 0
        end
        """
        try:
            await redis_client.eval(unlock_if_owner, 1, _lock_key(user_id), WORKER_ID)
        except Exception as e:
            logger.warning(f"Failed to release lock for user {SecurityValidator.sanitize_log_message(user_id)}: {e}")
async def _create_local_session_from_config(user_id: str) -> Optional[MCPChatService]:
    """Build and register an MCPChatService from the user's stored config.

    Args:
        user_id: User identifier.

    Returns:
        Optional[MCPChatService]: Initialized service, or None when no config
        is stored or when creation/initialization raises.
    """
    stored_config = await get_user_config(user_id)
    if not stored_config:
        return None

    try:
        service = MCPChatService(stored_config, user_id=user_id, redis_client=redis_client)
        await service.initialize()
        await set_active_session(user_id, service)
    except Exception as e:
        logger.error(f"Failed to initialize MCPChatService for {SecurityValidator.sanitize_log_message(user_id)}: {e}", exc_info=True)
        # Leave nothing half-registered: clear local state and our Redis ownership marker.
        await delete_active_session(user_id)
        return None

    return service
async def _create_session_holding_lock(user_id: str) -> Optional[MCPChatService]:
    """Create the local session and release the init lock afterwards, even on error.

    The caller must already hold the lock for ``user_id``.
    """
    try:
        return await _create_local_session_from_config(user_id)
    finally:
        await _release_lock_safe(user_id)


async def get_active_session(user_id: str) -> Optional[MCPChatService]:
    """Return the user's active session, recreating it on this worker when possible.

    Behavior:
        - Without Redis: return the local session or None.
        - With Redis, depending on the owner recorded under the active-session key:
            * this worker, local session present -> return it and refresh the key's TTL
            * this worker, local session missing -> take the lock and recreate it,
              or poll the local store while someone else holds the lock
            * no owner -> take the lock and create the session here; if the lock
              is busy, poll, then make one last attempt
            * another worker -> return None

    Args:
        user_id: User identifier.

    Returns:
        Optional[MCPChatService]: Active session if available, None otherwise.
    """
    if not redis_client:
        return active_sessions.get(user_id)

    active_key = _active_key(user_id)
    # NOTE(review): the comparisons below assume the Redis client returns str values
    # (e.g. decode_responses=True); with bytes replies they would never match — confirm.
    owner = await redis_client.get(active_key)

    # 1) Owned by this worker
    if owner == WORKER_ID:
        local_session = active_sessions.get(user_id)
        if local_session:
            try:
                # Keep ownership alive while the session is in use.
                await redis_client.expire(active_key, SESSION_TTL)
            except Exception as e:  # nosec B110
                # A failed refresh is not fatal; the session is still usable.
                logger.debug(f"Failed to refresh session TTL for {SecurityValidator.sanitize_log_message(user_id)}: {e}")
            return local_session

        # Redis says we own it but the object is gone (process restart or lost).
        if await _try_acquire_lock(user_id):
            return await _create_session_holding_lock(user_id)

        # Someone else is (re)creating it; give them a moment to finish.
        for _ in range(LOCK_RETRIES):
            await asyncio.sleep(LOCK_WAIT)
            recreated = active_sessions.get(user_id)
            if recreated:
                return recreated
        return None

    # 2) No owner -> try to claim and create the session locally
    if owner is None:
        if await _try_acquire_lock(user_id):
            return await _create_session_holding_lock(user_id)

        # Lock is busy: another creation is in flight, so watch the owner key.
        for _ in range(LOCK_RETRIES):
            await asyncio.sleep(LOCK_WAIT)
            current_owner = await redis_client.get(active_key)
            if current_owner == WORKER_ID and active_sessions.get(user_id):
                return active_sessions.get(user_id)
            if current_owner is not None and current_owner != WORKER_ID:
                # Some other worker now owns it.
                return None

        # Last resort once polling is exhausted.
        if await _try_acquire_lock(user_id):
            return await _create_session_holding_lock(user_id)
        return None

    # 3) Owned by another worker -> not available here. No stealing is attempted;
    #    a stale owner disappears when its key's TTL expires.
    return None
728# ---------- ROUTES ----------
@llmchat_router.post("/connect")
@require_permission("llm.invoke")
async def connect(input_data: ConnectInput, request: Request, user=Depends(get_current_user_with_permissions)):
    """Create or refresh a chat session for a user.

    Initializes a new MCPChatService instance for the authenticated user. If a
    session already exists for the user, it is shut down and removed before the
    new one is created. The new service's chat history is cleared.

    When a ``server`` object is supplied without an ``auth_token`` (None or
    empty), the ``jwt_token`` cookie of the request is used instead.

    Args:
        input_data: ConnectInput containing user_id, optional server config, LLM selection, and streaming preference.
        request: FastAPI Request object, used to read the ``jwt_token`` cookie.
        user: Authenticated user context.

    Returns:
        dict: Connection status response containing:
            - status: 'connected'
            - user_id: The connected user's identifier
            - provider: The LLM provider recorded in the built configuration
            - tool_count: Number of available MCP tools
            - tools: List of tool names

    Raises:
        HTTPException: If an error occurs.
            400: Invalid user_id, invalid server URL, invalid configuration, or LLM config error.
            401: No authenticated user, or no auth token and no ``jwt_token`` cookie.
            403: ``user_id`` in the body differs from the authenticated user.
            503: Failed to connect to MCP server.
            500: Service initialization failure or unexpected error.

    Examples:
        This endpoint is called via HTTP POST and cannot be directly tested with doctest.
        Example request body:

        {
            "user_id": "user123",
            "server": {
                "url": "http://localhost:8000/mcp",
                "auth_token": "jwt_token_here"
            },
            "llm": {
                "model": "gpt-4o",
                "temperature": 0.7
            },
            "streaming": false
        }

        Example response:

        {
            "status": "connected",
            "user_id": "user123",
            "provider": "gateway",
            "tool_count": 5,
            "tools": ["search", "calculator", "weather", "translate", "summarize"]
        }

    Note:
        Existing sessions are automatically terminated before establishing new ones.
    """
    user_id = _resolve_user_id(input_data.user_id, user)

    try:
        # Validate user_id
        if not user_id or not isinstance(user_id, str):
            raise HTTPException(status_code=400, detail="Invalid user ID provided")

        # Validate user-supplied server URLs with SSRF protections before any outbound connection setup.
        if input_data.server and input_data.server.url:
            try:
                input_data.server.url = SecurityValidator.validate_url(str(input_data.server.url), "MCP server URL")
            except ValueError as e:
                # The URL is attacker-controlled, so everything derived from it is
                # sanitized before logging, like the other log calls in this module.
                logger.warning(
                    "LLM chat connect URL validation failed for user %s and URL %s: %s",
                    SecurityValidator.sanitize_log_message(user_id),
                    SecurityValidator.sanitize_log_message(str(input_data.server.url)),
                    SecurityValidator.sanitize_log_message(str(e)),
                )
                raise HTTPException(status_code=400, detail="Invalid server URL") from e

        # Handle authentication token: fall back to the caller's JWT cookie
        empty_token = ""  # nosec B105
        if input_data.server and (input_data.server.auth_token is None or input_data.server.auth_token == empty_token):
            jwt_token = request.cookies.get("jwt_token")
            if not jwt_token:
                raise HTTPException(status_code=401, detail="Authentication required. Please ensure you are logged in.")
            input_data.server.auth_token = jwt_token

        # Close old session if it exists
        existing = await get_active_session(user_id)
        if existing:
            try:
                logger.debug(f"Disconnecting existing session for {SecurityValidator.sanitize_log_message(user_id)} before reconnecting")
                await existing.shutdown()
            except Exception as shutdown_error:
                logger.warning(f"Failed to cleanly shutdown existing session for {SecurityValidator.sanitize_log_message(user_id)}: {shutdown_error}")
            finally:
                # Always remove the session from active sessions, even if shutdown failed
                await delete_active_session(user_id)

        # Build and validate configuration
        try:
            config = build_config(input_data)
        except ValueError as ve:
            raise HTTPException(status_code=400, detail=f"Invalid configuration: {str(ve)}") from ve
        except Exception as config_error:
            raise HTTPException(status_code=400, detail=f"Configuration error: {str(config_error)}") from config_error

        # Store user configuration
        await set_user_config(user_id, config)

        # Initialize chat service
        try:
            chat_service = MCPChatService(config, user_id=user_id, redis_client=redis_client)
            await chat_service.initialize()

            # Clear chat history on new connection
            await chat_service.clear_history()
        except ConnectionError as ce:
            # Clean up partial state
            await delete_user_config(user_id)
            raise HTTPException(status_code=503, detail=f"Failed to connect to MCP server: {str(ce)}. Please verify the server URL and authentication.") from ce
        except ValueError as ve:
            # Clean up partial state
            await delete_user_config(user_id)
            raise HTTPException(status_code=400, detail=f"Invalid LLM configuration: {str(ve)}") from ve
        except Exception as init_error:
            # Clean up partial state
            await delete_user_config(user_id)
            raise HTTPException(status_code=500, detail=f"Service initialization failed: {str(init_error)}") from init_error

        await set_active_session(user_id, chat_service)

        # Extract tool names; a failure here only costs the tool list, not the connection
        tool_names = []
        try:
            if hasattr(chat_service, "_tools") and chat_service._tools:
                for tool in chat_service._tools:
                    tool_name = getattr(tool, "name", None)
                    if tool_name:
                        tool_names.append(tool_name)
        except Exception as tool_error:
            logger.warning(f"Failed to extract tool names: {tool_error}")

        return {"status": "connected", "user_id": user_id, "provider": config.llm.provider, "tool_count": len(tool_names), "tools": tool_names}

    except HTTPException:
        # Re-raise HTTP exceptions as-is
        raise
    except Exception as e:
        logger.error(f"Unexpected error in connect endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Unexpected connection error: {str(e)}") from e
async def token_streamer(chat_service: MCPChatService, message: str, user_id: str):
    """Stream chat response events as Server-Sent Events (SSE).

    Iterates ``chat_service.chat_events(message)`` and re-emits each recognized
    event as two byte chunks: an ``event:`` line and a ``data:`` line carrying
    JSON, terminated by a blank line. Events of any other type are skipped.
    Exceptions raised while streaming are turned into a single ``error`` event.

    Args:
        chat_service: MCPChatService instance configured for the user session.
        message: User's chat message to process.
        user_id: User identifier (accepted for logging; not read by this body).

    Yields:
        bytes: UTF-8 encoded SSE chunks for these event types:
            - token: {"content": "text chunk"}
            - tool_start / tool_end / tool_error / final: the service event, unchanged
            - error: {"type": "error", "error": "message", "recoverable": bool}

    Examples:
        This is an async generator used internally by the chat endpoint.
        It cannot be directly tested with standard doctest.

        Example event stream:

        event: token
        data: {"content": "Hello"}

        event: final
        data: {"type": "final", "text": "Hello, how can I help?"}
    """

    def sse_chunks(event_type: str, data: Dict[str, Any]):
        """Lazily yield the two encoded lines of one SSE message.

        Args:
            event_type: SSE event type identifier.
            data: Payload dictionary to serialize as JSON.

        Yields:
            bytes: UTF-8 encoded SSE formatted lines.
        """
        yield f"event: {event_type}\n".encode("utf-8")
        yield f"data: {orjson.dumps(data).decode()}\n\n".encode("utf-8")

    # Event types whose payload is forwarded exactly as the service produced it.
    passthrough_types = ("tool_start", "tool_end", "tool_error", "final")
    failure: Optional[Dict[str, Any]] = None

    try:
        async for ev in chat_service.chat_events(message):
            kind = ev.get("type")
            if kind == "token":
                payload = {"content": ev.get("content", "")}
            elif kind in passthrough_types:
                payload = ev
            else:
                continue
            for chunk in sse_chunks(kind, payload):
                yield chunk
    except ConnectionError as ce:
        failure = {"type": "error", "error": f"Connection lost: {str(ce)}", "recoverable": False}
    except TimeoutError:
        failure = {"type": "error", "error": "Request timed out waiting for LLM response", "recoverable": True}
    except ChatProcessingError as cpe:
        # ChatProcessingError wraps tool/parsing/model errors —
        # the session is still valid, so the frontend can retry.
        failure = {"type": "error", "error": f"Service error: {str(cpe)}", "recoverable": True}
    except RuntimeError as re:
        failure = {"type": "error", "error": f"Service error: {str(re)}", "recoverable": False}
    except Exception as e:
        logger.error(f"Unexpected streaming error: {e}", exc_info=True)
        failure = {"type": "error", "error": f"Unexpected error: {str(e)}", "recoverable": False}

    if failure is not None:
        for chunk in sse_chunks("error", failure):
            yield chunk
@llmchat_router.post("/chat")
@require_permission("llm.invoke")
async def chat(input_data: ChatInput, user=Depends(get_current_user_with_permissions)):
    """Send a message to the user's active chat session and receive a response.

    Processes user messages through the configured LLM with MCP tool integration.
    Supports both streaming (SSE) and non-streaming response modes.

    Args:
        input_data: ChatInput containing user_id, message, and streaming preference.
        user: Authenticated user context.

    Returns:
        For streaming=False:
            dict: Response containing:
                - user_id: Session identifier
                - response: Complete LLM response text
                - tool_used: Boolean indicating if tools were invoked
                - tools: List of tool names used
                - tool_invocations: Tool invocation data as reported by the chat service
                - elapsed_ms: Processing time in milliseconds
        For streaming=True:
            StreamingResponse: SSE stream of token and event data.

    Raises:
        HTTPException: Raised when an HTTP-related error occurs.
            400: Missing user_id, empty message, or no active session.
            503: Session not initialized, chat service error, or connection lost.
            504: Request timeout.
            500: Unexpected error.

    Examples:
        This endpoint is called via HTTP POST and cannot be directly tested with doctest.

        Example non-streaming request:

        {
            "user_id": "user123",
            "message": "What's the weather like?",
            "streaming": false
        }

        Example non-streaming response:

        {
            "user_id": "user123",
            "response": "The weather is sunny and 72°F.",
            "tool_used": true,
            "tools": ["weather"],
            "tool_invocations": 1,
            "elapsed_ms": 450
        }

        Example streaming request:

        {
            "user_id": "user123",
            "message": "Tell me a story",
            "streaming": true
        }

    Note:
        Streaming responses use Server-Sent Events (SSE) with 'text/event-stream' MIME type.
        Client must maintain persistent connection for streaming. The stream body is
        produced by token_streamer, so failures that happen while streaming do not pass
        through the except clauses of this handler.
    """
    user_id = _resolve_user_id(input_data.user_id, user)

    # Validate input
    if not user_id:
        raise HTTPException(status_code=400, detail="User ID is required")

    if not input_data.message or not input_data.message.strip():
        raise HTTPException(status_code=400, detail="Message cannot be empty")

    # Check for active session
    chat_service = await get_active_session(user_id)
    if not chat_service:
        raise HTTPException(status_code=400, detail="No active session found. Please connect to a server first.")

    # Verify session is initialized
    if not chat_service.is_initialized:
        raise HTTPException(status_code=503, detail="Session is not properly initialized. Please reconnect.")

    try:
        if input_data.streaming:
            return StreamingResponse(
                token_streamer(chat_service, input_data.message, user_id),
                media_type="text/event-stream",
                headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},  # Disable proxy buffering
            )

        # Non-streaming path: only here is a RuntimeError mapped to 503.
        try:
            result = await chat_service.chat_with_metadata(input_data.message)

            return {
                "user_id": user_id,
                "response": result["text"],
                "tool_used": result["tool_used"],
                "tools": result["tools"],
                "tool_invocations": result["tool_invocations"],
                "elapsed_ms": result["elapsed_ms"],
            }
        except RuntimeError as exc:
            # Chain the cause so the original traceback is preserved.
            raise HTTPException(status_code=503, detail=f"Chat service error: {str(exc)}") from exc

    except ConnectionError as ce:
        raise HTTPException(status_code=503, detail=f"Lost connection to MCP server: {str(ce)}. Please reconnect.") from ce
    except TimeoutError as te:
        raise HTTPException(status_code=504, detail="Request timed out. The LLM took too long to respond.") from te
    except HTTPException:
        # Already an HTTP error (e.g. the 503 raised above); pass it through untouched.
        raise
    except Exception as e:
        logger.error(f"Unexpected error in chat endpoint for user {SecurityValidator.sanitize_log_message(user_id)}: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}") from e
@llmchat_router.post("/disconnect")
@require_permission("llm.invoke")
async def disconnect(input_data: DisconnectInput, user=Depends(get_current_user_with_permissions)):
    """End the chat session for a user and clean up resources.

    Removes the session and the stored user configuration, then clears the
    chat history and shuts down the chat service instance. Safe to call even
    if no active session exists.

    Args:
        input_data: DisconnectInput containing the user_id to disconnect.
        user: Authenticated user context.

    Returns:
        dict: Disconnection status containing:
            - status: One of 'disconnected', 'no_active_session', or 'disconnected_with_errors'
            - user_id: The user identifier
            - message: Human-readable status description
            - warning: (Optional) Error details if cleanup encountered issues

    Raises:
        HTTPException: Raised when an HTTP-related error occurs.
            400: Missing user_id.

    Examples:
        This endpoint is called via HTTP POST and cannot be directly tested with doctest.

        Example request:

        {
            "user_id": "user123"
        }

        Example successful response:

        {
            "status": "disconnected",
            "user_id": "user123",
            "message": "Successfully disconnected"
        }

        Example response when no session exists:

        {
            "status": "no_active_session",
            "user_id": "user123",
            "message": "No active session to disconnect"
        }

    Note:
        Cleanup failures are reported in the response body
        ('disconnected_with_errors') instead of being raised.
    """
    user_id = _resolve_user_id(input_data.user_id, user)

    if not user_id:
        raise HTTPException(status_code=400, detail="User ID is required")

    # Grab the service before removing the session so it can still be shut down.
    chat_service = await get_active_session(user_id)
    await delete_active_session(user_id)

    # Remove user config
    await delete_user_config(user_id)

    if not chat_service:
        return {"status": "no_active_session", "user_id": user_id, "message": "No active session to disconnect"}

    try:
        try:
            # Clear chat history on disconnect
            await chat_service.clear_history()
            logger.info(f"Chat session disconnected for {SecurityValidator.sanitize_log_message(user_id)}")
        finally:
            # Always attempt shutdown: the session entry is already deleted, so
            # this is the only place the service gets shut down, even when
            # clearing the history failed.
            await chat_service.shutdown()
        return {"status": "disconnected", "user_id": user_id, "message": "Successfully disconnected"}
    except Exception as e:
        logger.error(f"Error during disconnect for user {SecurityValidator.sanitize_log_message(user_id)}: {e}", exc_info=True)
        # Session already removed, so return success with warning
        return {"status": "disconnected_with_errors", "user_id": user_id, "message": "Disconnected but cleanup encountered errors", "warning": str(e)}
@llmchat_router.get("/status/{user_id}")
@require_permission("llm.read")
async def status(user_id: str, user=Depends(get_current_user_with_permissions)):
    """Report whether the given user currently has an active chat session.

    Read-only lookup: the requested identifier is resolved against the
    authenticated user and the session store is queried once.

    Args:
        user_id: User identifier to check session status for.
        user: Authenticated user context.

    Returns:
        dict: Status information containing:
            - user_id: The resolved user identifier
            - connected: Boolean indicating if an active session exists

    Examples:
        This endpoint is called via HTTP GET and cannot be directly tested with doctest.

        Example request:
            GET /llmchat/status/user123

        Example response (connected):

        {
            "user_id": "user123",
            "connected": true
        }

        Example response (not connected):

        {
            "user_id": "user456",
            "connected": false
        }

    Note:
        Only the existence of a session is checked here; whether that session
        is properly initialized is not verified.
    """
    session_owner = _resolve_user_id(user_id, user)
    active_session = await get_active_session(session_owner)
    # Any truthy session object counts as "connected".
    return {"user_id": session_owner, "connected": bool(active_session)}
@llmchat_router.get("/config/{user_id}")
@require_permission("llm.read")
async def get_config(user_id: str, user=Depends(get_current_user_with_permissions)):
    """Return the stored session configuration for a user with secrets masked.

    The configuration is dumped to a plain dictionary and passed through
    _mask_sensitive_config_values before being returned, so sensitive values
    (API keys, auth tokens) are not exposed to the caller.

    Args:
        user_id: User identifier whose configuration to retrieve.
        user: Authenticated user context.

    Returns:
        dict: Sanitized configuration dictionary containing:
            - mcp_server: Server connection settings
            - llm: LLM provider configuration
            - enable_streaming: Boolean streaming preference

    Raises:
        HTTPException: Raised when an HTTP-related error occurs.
            404: No configuration found for the specified user_id.

    Examples:
        This endpoint is called via HTTP GET and cannot be directly tested with doctest.

        Example request:
            GET /llmchat/config/user123

        Example response:

        {
            "mcp_server": {
                "url": "http://localhost:8000/mcp",
                "transport": "streamable_http"
            },
            "llm": {
                "provider": "ollama",
                "config": {
                    "model": "llama3",
                    "temperature": 0.7
                }
            },
            "enable_streaming": false
        }

    Security:
        Never log or expose API keys or authentication tokens in responses.
    """
    config_owner = _resolve_user_id(user_id, user)
    stored_config = await get_user_config(config_owner)

    if not stored_config:
        raise HTTPException(status_code=404, detail="No config found for this user.")

    # Dump to a plain dict and mask secrets in a single step.
    return _mask_sensitive_config_values(stored_config.model_dump())
@llmchat_router.get("/gateway/models")
@require_permission("llm.read")
async def get_gateway_models(_user=Depends(get_current_user_with_permissions)):
    """Get available models from configured LLM providers.

    Returns a list of enabled models from enabled providers configured
    in the gateway's LLM Settings. These models can be used with the
    "gateway" provider type in /connect requests.

    Args:
        _user: Authenticated user context.

    Returns:
        dict: Response containing:
            - models: List of available models with provider info
            - count: Total number of available models

    Raises:
        HTTPException: If there is an error retrieving gateway models (500).

    Examples:
        GET /llmchat/gateway/models

        Response:
        {
            "models": [
                {
                    "id": "abc123",
                    "model_id": "gpt-4o",
                    "model_name": "GPT-4o",
                    "provider_id": "def456",
                    "provider_name": "OpenAI",
                    "provider_type": "openai",
                    "supports_streaming": true,
                    "supports_function_calling": true,
                    "supports_vision": true
                }
            ],
            "count": 1
        }
    """
    # Import here to avoid circular dependency
    # First-Party
    from mcpgateway.db import SessionLocal
    from mcpgateway.services.llm_provider_service import LLMProviderService

    llm_service = LLMProviderService()

    try:
        with SessionLocal() as db:
            models = llm_service.get_gateway_models(db)
            return {
                "models": [m.model_dump() for m in models],
                "count": len(models),
            }
    except Exception as e:
        # Log with traceback, matching the other handlers in this module,
        # and chain the cause so it is not lost on the HTTP error.
        logger.error(f"Failed to get gateway models: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Failed to retrieve gateway models: {str(e)}") from e