Coverage for mcpgateway / routers / reverse_proxy.py: 100%
196 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/routers/reverse_proxy.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7FastAPI router for handling reverse proxy connections.
9This module provides WebSocket and SSE endpoints for reverse proxy clients
10to connect and tunnel their local MCP servers through the gateway.
11"""
13# Standard
14import asyncio
15from datetime import datetime, timezone
16from typing import Any, Dict, Optional
17import uuid
19# Third-Party
20from fastapi import APIRouter, Depends, HTTPException, Request, status, WebSocket, WebSocketDisconnect
21from fastapi.responses import StreamingResponse
22from fastapi.security import HTTPAuthorizationCredentials
23import orjson
24from sqlalchemy.orm import Session
26# First-Party
27from mcpgateway.auth import get_current_user
28from mcpgateway.config import settings
29from mcpgateway.db import get_db
30from mcpgateway.middleware.rbac import _ACCESS_DENIED_MSG, PermissionChecker
31from mcpgateway.services.logging_service import LoggingService
32from mcpgateway.utils.verify_credentials import extract_websocket_bearer_token, is_proxy_auth_trust_active, require_auth
# Router grouping every reverse proxy endpoint under the /reverse-proxy prefix.
router = APIRouter(prefix="/reverse-proxy", tags=["reverse-proxy"])

# Module logger, obtained through the gateway's shared logging service.
logging_service = LoggingService()
LOGGER = logging_service.get_logger("mcpgateway.routers.reverse_proxy")
class ReverseProxySession:
    """Manages a reverse proxy session.

    Wraps the WebSocket of one reverse proxy client and keeps simple traffic
    statistics: number of messages received, UTF-8 bytes sent plus received, and
    the time of the last activity.
    """

    def __init__(self, session_id: str, websocket: WebSocket, user: Optional[str | dict] = None):
        """Initialize reverse proxy session.

        Args:
            session_id: Unique session identifier.
            websocket: WebSocket connection.
            user: Authenticated user info (if any).
        """
        self.session_id = session_id
        self.websocket = websocket
        self.user = user
        self.server_info: Dict[str, Any] = {}
        self.connected_at = datetime.now(tz=timezone.utc)
        self.last_activity = datetime.now(tz=timezone.utc)
        # Counts messages received from the client only (send_message does not bump it).
        self.message_count = 0
        # Encoded UTF-8 size of all frames sent and received.
        self.bytes_transferred = 0

    async def send_message(self, message: Dict[str, Any]) -> None:
        """Send message to the client.

        Args:
            message: Message dictionary to send.
        """
        payload = orjson.dumps(message)
        await self.websocket.send_text(payload.decode())
        # Count the encoded bytes, not the characters of the decoded string,
        # so non-ASCII content is measured correctly.
        self.bytes_transferred += len(payload)
        self.last_activity = datetime.now(tz=timezone.utc)

    async def receive_message(self) -> Dict[str, Any]:
        """Receive message from the client.

        Returns:
            Parsed message (normally a dictionary; whatever JSON value the client sent).
        """
        data = await self.websocket.receive_text()
        # len(str) counts code points; encode to get the real byte size.
        # surrogatepass keeps this from raising on unpaired surrogates.
        self.bytes_transferred += len(data.encode("utf-8", "surrogatepass"))
        self.message_count += 1
        self.last_activity = datetime.now(tz=timezone.utc)
        return orjson.loads(data)
class ReverseProxyManager:
    """Manages all reverse proxy sessions.

    Sessions are kept in an in-memory dict keyed by session ID. Mutations
    (add/remove) are serialized with an ``asyncio.Lock``; reads are synchronous
    dict lookups.
    """

    def __init__(self):
        """Initialize the manager."""
        self.sessions: Dict[str, ReverseProxySession] = {}
        self._lock = asyncio.Lock()

    async def add_session(self, session: ReverseProxySession) -> None:
        """Add a new session.

        Args:
            session: Session to add.
        """
        async with self._lock:
            self.sessions[session.session_id] = session
            LOGGER.info(f"Added reverse proxy session: {session.session_id}")

    async def remove_session(self, session_id: str) -> None:
        """Remove a session (no-op if it is not registered).

        Args:
            session_id: Session ID to remove.
        """
        async with self._lock:
            if session_id in self.sessions:
                del self.sessions[session_id]
                LOGGER.info(f"Removed reverse proxy session: {session_id}")

    def get_session(self, session_id: str) -> Optional[ReverseProxySession]:
        """Get a session by ID.

        Args:
            session_id: Session ID to get.

        Returns:
            Session if found, None otherwise.
        """
        return self.sessions.get(session_id)

    @staticmethod
    def _session_owner(session: ReverseProxySession) -> Optional[str]:
        """Return the identifier of the user owning ``session``, if any.

        A string user is used as-is; for a dict user ``"sub"`` is preferred and
        ``"email"`` is the fallback, matching how request credentials are parsed.

        Args:
            session: Session to inspect.

        Returns:
            Owner identifier, or None for unauthenticated sessions.
        """
        user = session.user
        if isinstance(user, str):
            return user
        if isinstance(user, dict):
            return user.get("sub") or user.get("email")
        return None

    def list_sessions(self) -> list[Dict[str, Any]]:
        """List all active sessions.

        Returns:
            List of session information dictionaries.

        Examples:
            >>> from fastapi import WebSocket
            >>> manager = ReverseProxyManager()
            >>> sessions = manager.list_sessions()
            >>> sessions
            []
            >>> isinstance(sessions, list)
            True
        """
        return [
            {
                "session_id": session.session_id,
                "server_info": session.server_info,
                "connected_at": session.connected_at.isoformat(),
                "last_activity": session.last_activity.isoformat(),
                "message_count": session.message_count,
                "bytes_transferred": session.bytes_transferred,
                # Owner used by the /sessions route for per-user filtering.
                "user": self._session_owner(session),
            }
            for session in self.sessions.values()
        ]
# Holding any one of these permissions lets a client open a reverse proxy WebSocket.
_REVERSE_PROXY_CONNECT_PERMISSIONS = ["servers.create", "servers.update", "servers.manage"]

# Module-wide registry of the active reverse proxy sessions.
manager = ReverseProxyManager()
def _get_websocket_bearer_token(websocket: WebSocket) -> Optional[str]:
    """Look up the bearer token carried by an incoming reverse proxy WebSocket.

    Hands the connection's query parameters and headers to
    ``extract_websocket_bearer_token``, using empty mappings when the object
    lacks either attribute. A reverse-proxy-specific warning message is supplied
    for tokens that arrive as a query parameter.

    Args:
        websocket: Incoming WebSocket connection.

    Returns:
        The bearer token when one is present, otherwise None.
    """
    query_params = getattr(websocket, "query_params", {})
    headers = getattr(websocket, "headers", {})
    return extract_websocket_bearer_token(
        query_params,
        headers,
        query_param_warning="Reverse proxy WebSocket token passed via query parameter",
    )
async def _authenticate_reverse_proxy_websocket(websocket: WebSocket) -> Optional[str]:
    """Resolve and authorize the identity behind a reverse proxy WebSocket.

    A bearer token is tried first. Only when no token is present is a trusted
    proxy user header consulted (if proxy-auth trust is active). A resolved
    identity must hold at least one reverse proxy connect permission.

    Args:
        websocket: Incoming WebSocket connection.

    Returns:
        The authenticated user's email, or None when no identity was supplied
        and authentication is not required.

    Raises:
        HTTPException: 401 when token validation fails or an identity is required
            but missing; 403 when the identity lacks every required permission.
    """
    auth_required = settings.auth_required or settings.mcp_client_auth_enabled
    auth_token = _get_websocket_bearer_token(websocket)
    user_context: Optional[dict[str, Any]] = None

    if auth_token:
        bearer = HTTPAuthorizationCredentials(scheme="Bearer", credentials=auth_token)
        try:
            user = await get_current_user(bearer, request=websocket)
        except HTTPException:
            # Propagate the auth layer's own HTTP errors untouched.
            raise
        except Exception as exc:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication failed") from exc
        user_context = {
            "email": user.email,
            "full_name": user.full_name,
            "is_admin": user.is_admin,
            "ip_address": websocket.client.host if websocket.client else None,
            "user_agent": websocket.headers.get("user-agent"),
            "team_id": getattr(websocket.state, "team_id", None),
            "token_teams": getattr(websocket.state, "token_teams", None),
            "token_use": getattr(websocket.state, "token_use", None),
        }
    elif is_proxy_auth_trust_active(settings):
        proxy_user = websocket.headers.get(settings.proxy_user_header)
        if proxy_user:
            user_context = {
                "email": proxy_user,
                "full_name": proxy_user,
                "is_admin": False,
                "ip_address": websocket.client.host if websocket.client else None,
                "user_agent": websocket.headers.get("user-agent"),
            }

    # No identity could be established: refuse only when auth is mandatory.
    if user_context is None:
        if auth_required:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required")
        return None

    checker = PermissionChecker(user_context)
    if not await checker.has_any_permission(_REVERSE_PROXY_CONNECT_PERMISSIONS):
        LOGGER.warning("Reverse proxy permission denied: user=%s", user_context.get("email"))
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=_ACCESS_DENIED_MSG)
    return user_context["email"]
async def _send_error_reply(session: ReverseProxySession, error_message: str) -> bool:
    """Best-effort delivery of an ``error`` frame to a reverse proxy client.

    Args:
        session: Session whose client should receive the error.
        error_message: Text placed in the frame's ``message`` field.

    Returns:
        True if the frame was sent, False if sending failed (e.g. the socket is closed).
    """
    try:
        await session.send_message({"type": "error", "message": error_message})
    except Exception as send_exc:
        # The connection is probably gone; report failure instead of letting the
        # exception escape the caller's message loop unhandled.
        LOGGER.warning(f"Failed to send error to session {session.session_id}: {send_exc}")
        return False
    return True


@router.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    db: Session = Depends(get_db),
):
    """WebSocket endpoint for reverse proxy connections.

    Authentication is REQUIRED when:
    - settings.auth_required is True, OR
    - settings.mcp_client_auth_enabled is True

    Supports:
    - Bearer token in Authorization header
    - Proxy authentication (when trust_proxy_auth is True and mcp_client_auth_enabled is False)

    If authentication or authorization fails, the socket is closed with code 1008
    (policy violation). Otherwise it is accepted and registered with the global
    manager. Client messages are then processed until the client unregisters or
    disconnects, or until an error reply can no longer be delivered. The session
    is always removed from the manager on exit.

    Args:
        websocket: WebSocket connection.
        db: Database session (not used by this handler).
    """
    # NOTE(review): ``db`` is unused here; if ``get_db`` is a yield dependency it keeps a
    # database session checked out for the whole WebSocket lifetime - consider dropping it.
    try:
        user = await _authenticate_reverse_proxy_websocket(websocket)
    except HTTPException as e:
        LOGGER.warning(f"Reverse proxy WebSocket authentication failed: {e.detail}")
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason=str(e.detail))
        return

    # Accept connection only after successful authentication (or when auth not required)
    await websocket.accept()

    # Generate session ID server-side to prevent session hijacking
    # Client-supplied X-Session-ID is ignored for security (prevents collision/hijack attacks)
    session_id = uuid.uuid4().hex

    # Create session with authenticated user
    session = ReverseProxySession(session_id, websocket, user)
    await manager.add_session(session)

    try:
        LOGGER.info(f"Reverse proxy connected: {session_id}")

        # Main message loop
        while True:
            try:
                message = await session.receive_message()

                # Valid JSON that is not an object (list, string, number, ...) cannot carry a "type".
                if not isinstance(message, dict):
                    LOGGER.warning(f"Non-object message from session {session_id}: {type(message).__name__}")
                    if not await _send_error_reply(session, "Invalid message format: expected a JSON object"):
                        break
                    continue

                msg_type = message.get("type")

                if msg_type == "register":
                    server_info = message.get("server", {})
                    # Reject before storing so server_info stays a dict for list_sessions/SSE.
                    if not isinstance(server_info, dict):
                        LOGGER.warning(f"Invalid register payload from session {session_id}: 'server' is not an object")
                        if not await _send_error_reply(session, "Invalid register message: 'server' must be an object"):
                            break
                        continue

                    # Register the server
                    session.server_info = server_info
                    LOGGER.info(f"Registered server for session {session_id}: {session.server_info.get('name')}")

                    # Send acknowledgment
                    await session.send_message({"type": "register_ack", "sessionId": session_id, "status": "success"})

                elif msg_type == "unregister":
                    # Unregister the server
                    LOGGER.info(f"Unregistering server for session {session_id}")
                    break

                elif msg_type == "heartbeat":
                    # Respond to heartbeat
                    await session.send_message({"type": "heartbeat", "sessionId": session_id, "timestamp": datetime.now(tz=timezone.utc).isoformat()})

                elif msg_type in ("response", "notification"):
                    # Handle MCP response/notification from the proxied server
                    # TODO: Route to appropriate MCP client
                    LOGGER.debug(f"Received {msg_type} from session {session_id}")

                else:
                    LOGGER.warning(f"Unknown message type from session {session_id}: {msg_type}")

            except WebSocketDisconnect:
                LOGGER.info(f"WebSocket disconnected: {session_id}")
                break
            except orjson.JSONDecodeError as e:
                LOGGER.error(f"Invalid JSON from session {session_id}: {e}")
                if not await _send_error_reply(session, "Invalid JSON format"):
                    break
            except Exception as e:
                LOGGER.error(f"Error handling message from session {session_id}: {e}")
                # If even the error reply cannot be sent, the connection is unusable.
                if not await _send_error_reply(session, str(e)):
                    break

    finally:
        await manager.remove_session(session_id)
        LOGGER.info(f"Reverse proxy session ended: {session_id}")
@router.get("/sessions")
async def list_sessions(
    request: Request,
    credentials: str | dict = Depends(require_auth),
):
    """Return the reverse proxy sessions the caller is allowed to see.

    Administrators receive every active session. Any other caller receives the
    sessions it owns plus sessions that have no owner (anonymous ones).

    Args:
        request: HTTP request.
        credentials: Authenticated user credentials.

    Returns:
        Mapping with the visible ``sessions`` and their ``total`` count.
    """
    requesting_user, is_admin = _get_user_from_credentials(credentials)
    visible = manager.list_sessions()

    if not is_admin:
        # Keep ownerless sessions and those belonging to the caller.
        visible = [info for info in visible if not info.get("user") or info.get("user") == requesting_user]

    return {"sessions": visible, "total": len(visible)}
@router.delete("/sessions/{session_id}")
async def disconnect_session(
    session_id: str,
    request: Request,
    credentials: str | dict = Depends(require_auth),
):
    """Disconnect a reverse proxy session.

    Requires authentication and validates session ownership.
    Only the session owner or an admin can disconnect a session.
    Closing the WebSocket is best-effort: the session is removed from the
    manager even if the close call fails (e.g. the socket was already closed).

    Args:
        session_id: Session ID to disconnect.
        request: HTTP request.
        credentials: Authenticated user credentials.

    Returns:
        Disconnection status.

    Raises:
        HTTPException: If session is not found or user is not authorized.
    """
    session = manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Session {session_id} not found")

    # Validate session ownership
    _validate_session_ownership(session, credentials, "disconnect")

    # Close the WebSocket connection; never let a close failure leave a stale entry behind.
    try:
        await session.websocket.close()
    except Exception as exc:
        LOGGER.warning(f"Error closing reverse proxy session {session_id}: {exc}")
    finally:
        await manager.remove_session(session_id)

    return {"status": "disconnected", "session_id": session_id}
@router.post("/sessions/{session_id}/request")
async def send_request_to_session(
    session_id: str,
    mcp_request: Dict[str, Any],
    request: Request,
    credentials: str | dict = Depends(require_auth),
):
    """Send an MCP request to a reverse proxy session.

    Requires authentication and validates session ownership.
    Only the session owner or an admin can send requests to a session.

    Args:
        session_id: Session ID to send request to.
        mcp_request: MCP request to send.
        request: HTTP request.
        credentials: Authenticated user credentials.

    Returns:
        Request acknowledgment.

    Raises:
        HTTPException: If session is not found, user is not authorized, or request fails.
    """
    session = manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Session {session_id} not found")

    # Validate session ownership
    _validate_session_ownership(session, credentials, "send request to")

    # Wrap the request in reverse proxy envelope
    message = {"type": "request", "sessionId": session_id, "payload": mcp_request}

    try:
        await session.send_message(message)
    except Exception as e:
        LOGGER.error(f"Failed to send request to session {session_id}: {e}")
        # Chain the cause so the original traceback is preserved.
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to send request: {e}") from e
    return {"status": "sent", "session_id": session_id}
440def _get_user_from_credentials(credentials: str | dict) -> tuple[str | None, bool]:
441 """Extract user and admin status from credentials.
443 Args:
444 credentials: Auth credentials (dict from JWT or string)
446 Returns:
447 Tuple of (username, is_admin)
448 """
449 if isinstance(credentials, dict):
450 user = credentials.get("sub") or credentials.get("email")
451 # Check both top-level is_admin and nested user.is_admin (JWT tokens may nest it)
452 is_admin = credentials.get("is_admin", False) or credentials.get("user", {}).get("is_admin", False)
453 return user, is_admin
454 elif credentials and credentials != "anonymous":
455 return credentials, False
456 return None, False
def _validate_session_ownership(session: ReverseProxySession, credentials: str | dict, action: str) -> None:
    """Validate that the requesting user owns the session or is admin.

    Sessions created without an authenticated user are open to everyone.

    Args:
        session: The session to validate ownership for
        credentials: Auth credentials from require_auth
        action: Description of the action for logging

    Raises:
        HTTPException: 403 if user is not authorized for the session
    """
    if not session.user:
        # Session was created without auth - allow access
        return

    requesting_user, is_admin = _get_user_from_credentials(credentials)

    # Admins can access any session
    if is_admin:
        return

    # Resolve the owner the same way credentials are parsed: a string is the user
    # itself; a dict prefers "sub" and falls back to "email".
    owner = session.user
    if isinstance(owner, str):
        session_owner = owner
    elif isinstance(owner, dict):
        session_owner = owner.get("sub") or owner.get("email")
    else:
        session_owner = None

    # Session owner can access their own session
    if requesting_user and session_owner and requesting_user == session_owner:
        return

    # Not authorized
    LOGGER.warning(f"Session access denied: user {requesting_user} attempted to {action} session owned by {session_owner}")
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized for this session")
def _format_sse_event(event: str, payload: Dict[str, Any]) -> str:
    """Serialize one Server-Sent Events frame.

    Args:
        event: SSE event name.
        payload: JSON-serializable data for the ``data`` field.

    Returns:
        The frame in SSE wire format, terminated by a blank line.
    """
    # orjson emits compact single-line JSON (newlines inside strings are escaped),
    # so a single "data:" line is sufficient.
    return f"event: {event}\ndata: {orjson.dumps(payload).decode()}\n\n"


@router.get("/sse/{session_id}")
async def sse_endpoint(
    session_id: str,
    request: Request,
    credentials: str | dict = Depends(require_auth),
):
    """SSE endpoint for receiving messages from a reverse proxy session.

    Requires authentication via require_auth dependency.
    Additionally validates that the authenticated user owns the session.

    Args:
        session_id: Session ID to subscribe to.
        request: HTTP request.
        credentials: Authenticated user credentials.

    Returns:
        SSE stream.

    Raises:
        HTTPException: If session is not found or user is not authorized.
    """
    session = manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Session {session_id} not found")

    # Validate session ownership
    _validate_session_ownership(session, credentials, "subscribe to SSE for")

    async def event_generator():
        """Generate SSE frames.

        StreamingResponse only accepts str/bytes chunks, so every event is
        serialized to the SSE wire format before being yielded.

        Yields:
            str: An encoded SSE frame.
        """
        # Send initial connection event
        yield _format_sse_event("connected", {"sessionId": session_id, "serverInfo": session.server_info})

        # TODO: Implement message queue for SSE delivery
        while not await request.is_disconnected():
            await asyncio.sleep(30)  # Keepalive
            # Stop streaming once the reverse proxy session has ended (or been replaced).
            # Otherwise keepalives would continue forever for a dead session.
            if manager.get_session(session_id) is not session:
                break
            yield _format_sse_event("keepalive", {"timestamp": datetime.now(tz=timezone.utc).isoformat()})

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )
548 )