Coverage for mcpgateway / routers / reverse_proxy.py: 100%
203 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/routers/reverse_proxy.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7FastAPI router for handling reverse proxy connections.
9This module provides WebSocket and SSE endpoints for reverse proxy clients
10to connect and tunnel their local MCP servers through the gateway.
11"""
13# Standard
14import asyncio
15from datetime import datetime, timezone
16from typing import Any, Dict, Optional
17import uuid
19# Third-Party
20from fastapi import APIRouter, Depends, HTTPException, Request, status, WebSocket, WebSocketDisconnect
21from fastapi.responses import StreamingResponse
22import orjson
23from sqlalchemy.orm import Session
25# First-Party
26from mcpgateway.config import settings
27from mcpgateway.db import get_db
28from mcpgateway.services.logging_service import LoggingService
29from mcpgateway.utils.verify_credentials import require_auth, verify_jwt_token
# Initialize logging
# One module-level LoggingService instance supplies the logger used by every
# handler in this module.
logging_service = LoggingService()
LOGGER = logging_service.get_logger("mcpgateway.routers.reverse_proxy")

# Every route registered on this router is declared with the /reverse-proxy prefix.
router = APIRouter(prefix="/reverse-proxy", tags=["reverse-proxy"])
class ReverseProxySession:
    """State and traffic counters for one reverse proxy WebSocket connection."""

    def __init__(self, session_id: str, websocket: WebSocket, user: Optional[str | dict] = None):
        """Initialize reverse proxy session.

        Args:
            session_id: Unique session identifier.
            websocket: WebSocket connection.
            user: Authenticated user info (if any).
        """
        self.session_id = session_id
        self.websocket = websocket
        self.user = user
        # Filled in when the client sends a "register" message.
        self.server_info: Dict[str, Any] = {}
        # A single timestamp is used so both fields start out equal.
        now = datetime.now(tz=timezone.utc)
        self.connected_at = now
        self.last_activity = now
        # Number of messages received from the client.
        self.message_count = 0
        # Size of all payloads sent and received, in UTF-8 encoded bytes.
        self.bytes_transferred = 0

    async def send_message(self, message: Dict[str, Any]) -> None:
        """Serialize a message to JSON and send it to the client as a text frame.

        Args:
            message: Message dictionary to send.
        """
        # orjson.dumps returns bytes, so measure the payload before decoding:
        # len() of the decoded str would count code points, not bytes.
        payload = orjson.dumps(message)
        await self.websocket.send_text(payload.decode())
        self.bytes_transferred += len(payload)
        self.last_activity = datetime.now(tz=timezone.utc)

    async def receive_message(self) -> Dict[str, Any]:
        """Receive one text frame from the client and parse it as JSON.

        The counters are updated before parsing, so a frame that fails to
        parse is still accounted for.

        Returns:
            Parsed message dictionary.
        """
        data = await self.websocket.receive_text()
        # Count encoded bytes rather than characters so non-ASCII payloads
        # are not under-reported.
        self.bytes_transferred += len(data.encode("utf-8"))
        self.message_count += 1
        self.last_activity = datetime.now(tz=timezone.utc)
        return orjson.loads(data)
class ReverseProxyManager:
    """Registry of the reverse proxy sessions known to this process."""

    def __init__(self):
        """Create an empty registry together with the lock guarding mutations."""
        self._lock = asyncio.Lock()
        self.sessions: Dict[str, ReverseProxySession] = {}

    async def add_session(self, session: ReverseProxySession) -> None:
        """Register a session under its own session ID.

        Args:
            session: Session to add.
        """
        key = session.session_id
        async with self._lock:
            self.sessions[key] = session
            LOGGER.info(f"Added reverse proxy session: {session.session_id}")

    async def remove_session(self, session_id: str) -> None:
        """Drop a session from the registry; unknown IDs are ignored silently.

        Args:
            session_id: Session ID to remove.
        """
        async with self._lock:
            try:
                del self.sessions[session_id]
            except KeyError:
                # Nothing registered under this ID: no removal, no log line.
                return
            LOGGER.info(f"Removed reverse proxy session: {session_id}")

    def get_session(self, session_id: str) -> Optional[ReverseProxySession]:
        """Look up a session by ID.

        Args:
            session_id: Session ID to get.

        Returns:
            Session if found, None otherwise.
        """
        try:
            return self.sessions[session_id]
        except KeyError:
            return None

    def list_sessions(self) -> list[Dict[str, Any]]:
        """Summarize every registered session as a plain dictionary.

        Returns:
            List of session information dictionaries.

        Examples:
            >>> manager = ReverseProxyManager()
            >>> sessions = manager.list_sessions()
            >>> sessions
            []
            >>> isinstance(sessions, list)
            True
        """
        summaries: list[Dict[str, Any]] = []
        for entry in self.sessions.values():
            # The owner is reported as a plain string: the user itself when it
            # is a str, its "sub" key when it is a dict, otherwise None.
            owner = entry.user
            if isinstance(owner, str):
                owner_name = owner
            elif isinstance(owner, dict):
                owner_name = owner.get("sub")
            else:
                owner_name = None

            summaries.append(
                {
                    "session_id": entry.session_id,
                    "server_info": entry.server_info,
                    "connected_at": entry.connected_at.isoformat(),
                    "last_activity": entry.last_activity.isoformat(),
                    "message_count": entry.message_count,
                    "bytes_transferred": entry.bytes_transferred,
                    "user": owner_name,
                }
            )
        return summaries
# Global manager instance
# Single module-level registry shared by the WebSocket endpoint and the HTTP
# routes defined in this module.
manager = ReverseProxyManager()
async def _verify_websocket_token(token: str, source: str) -> Optional[str]:
    """Verify a JWT presented during the WebSocket handshake.

    Args:
        token: Raw JWT string supplied by the client.
        source: Label for where the token came from ("JWT" for the
            Authorization header, "query token" for ``?token=``). It is only
            used to build log messages.

    Returns:
        The ``sub`` claim (falling back to ``email``), or None when the token
        is rejected or carries neither claim.
    """
    try:
        payload = await verify_jwt_token(token)
        user = payload.get("sub") or payload.get("email")
    except HTTPException as e:
        LOGGER.warning(f"WebSocket {source} authentication failed: {e.detail}")
        return None
    except Exception as e:
        LOGGER.warning(f"WebSocket {source} authentication failed: {e}")
        return None

    if not user:
        LOGGER.warning(f"WebSocket {source} authentication failed: Token missing subject claim")
        return None

    LOGGER.debug(f"WebSocket authenticated via {source}: {user}")
    return user


async def _authenticate_websocket(websocket: WebSocket) -> tuple[Optional[str], Optional[str]]:
    """Decide whether a WebSocket handshake may proceed.

    Credentials are tried in a fixed order: Bearer token in the Authorization
    header, then the ``token`` query parameter, then the trusted proxy header.
    Only the first one present is evaluated; a failure there is final.

    Args:
        websocket: WebSocket connection that has not been accepted yet.

    Returns:
        Tuple of (user, rejection_reason). ``rejection_reason`` is None when
        the connection may be accepted; otherwise it is the reason string to
        send with the close frame. ``user`` is None when authentication is not
        required or was rejected.
    """
    if not (settings.auth_required or settings.mcp_client_auth_enabled):
        return None, None

    auth_header = websocket.headers.get("Authorization", "")

    # Bearer token from the Authorization header
    if auth_header.startswith("Bearer "):
        user = await _verify_websocket_token(auth_header.split(" ", 1)[1], "JWT")
        return user, (None if user else "Authentication failed")

    # Token from the query string
    if "token" in websocket.query_params:
        user = await _verify_websocket_token(websocket.query_params["token"], "query token")
        return user, (None if user else "Authentication failed")

    # Proxy authentication (only when mcp_client_auth_enabled is False and trust_proxy_auth is True)
    if settings.trust_proxy_auth and not settings.mcp_client_auth_enabled:
        proxy_user = websocket.headers.get(settings.proxy_user_header)
        if proxy_user:
            LOGGER.debug(f"WebSocket authenticated via proxy header: {proxy_user}")
            return proxy_user, None
        LOGGER.warning("WebSocket proxy authentication failed: no proxy header")
        return None, "Authentication required"

    LOGGER.warning("WebSocket authentication required but no credentials provided")
    return None, "Authentication required"


@router.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    db: Session = Depends(get_db),
):
    """WebSocket endpoint for reverse proxy connections.

    Authentication is REQUIRED when:
    - settings.auth_required is True, OR
    - settings.mcp_client_auth_enabled is True

    Supports:
    - Bearer token in Authorization header
    - Token in query parameter (?token=...)
    - Proxy authentication (when trust_proxy_auth is True and mcp_client_auth_enabled is False)

    Authentication failures never raise out of this function: the socket is
    closed with a policy-violation code before it is accepted.

    Args:
        websocket: WebSocket connection.
        db: Database session. NOTE(review): not used in this function; the
            dependency is kept so the endpoint signature stays unchanged.
    """
    # Check authentication BEFORE accepting connection
    user, rejection_reason = await _authenticate_websocket(websocket)
    if rejection_reason is not None:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason=rejection_reason)
        return

    # Accept connection only after successful authentication (or when auth not required)
    await websocket.accept()

    # Generate session ID server-side to prevent session hijacking
    # Client-supplied X-Session-ID is ignored for security (prevents collision/hijack attacks)
    session_id = uuid.uuid4().hex

    # Create session with authenticated user
    session = ReverseProxySession(session_id, websocket, user)
    await manager.add_session(session)

    try:
        LOGGER.info(f"Reverse proxy connected: {session_id}")

        # Main message loop
        while True:
            try:
                message = await session.receive_message()
                msg_type = message.get("type")

                if msg_type == "register":
                    # Register the server
                    session.server_info = message.get("server", {})
                    LOGGER.info(f"Registered server for session {session_id}: {session.server_info.get('name')}")

                    # Send acknowledgment
                    await session.send_message({"type": "register_ack", "sessionId": session_id, "status": "success"})

                elif msg_type == "unregister":
                    # Unregister the server
                    LOGGER.info(f"Unregistering server for session {session_id}")
                    break

                elif msg_type == "heartbeat":
                    # Respond to heartbeat
                    await session.send_message({"type": "heartbeat", "sessionId": session_id, "timestamp": datetime.now(tz=timezone.utc).isoformat()})

                elif msg_type in ("response", "notification"):
                    # Handle MCP response/notification from the proxied server
                    # TODO: Route to appropriate MCP client
                    LOGGER.debug(f"Received {msg_type} from session {session_id}")

                else:
                    LOGGER.warning(f"Unknown message type from session {session_id}: {msg_type}")

            except WebSocketDisconnect:
                LOGGER.info(f"WebSocket disconnected: {session_id}")
                break
            except orjson.JSONDecodeError as e:
                LOGGER.error(f"Invalid JSON from session {session_id}: {e}")
                await session.send_message({"type": "error", "message": "Invalid JSON format"})
            except Exception as e:
                LOGGER.error(f"Error handling message from session {session_id}: {e}")
                await session.send_message({"type": "error", "message": str(e)})

    finally:
        # Always deregister, whichever way the loop ended.
        await manager.remove_session(session_id)
        LOGGER.info(f"Reverse proxy session ended: {session_id}")
@router.get("/sessions")
async def list_sessions(
    request: Request,
    credentials: str | dict = Depends(require_auth),
):
    """Report the reverse proxy sessions visible to the caller.

    Admins receive every session. Other callers receive the sessions whose
    owner matches their identity, plus sessions that have no owner.

    Args:
        request: HTTP request.
        credentials: Authenticated user credentials.

    Returns:
        Dictionary with the visible sessions and their count.
    """
    requesting_user, is_admin = _get_user_from_credentials(credentials)

    if is_admin:
        # Unfiltered view for admins
        return {"sessions": manager.list_sessions(), "total": len(manager.sessions)}

    # Keep a session when it is anonymous (falsy owner) or owned by the caller.
    visible = [info for info in manager.list_sessions() if not (owner := info.get("user")) or owner == requesting_user]
    return {"sessions": visible, "total": len(visible)}
@router.delete("/sessions/{session_id}")
async def disconnect_session(
    session_id: str,
    request: Request,
    credentials: str | dict = Depends(require_auth),
):
    """Disconnect a reverse proxy session.

    Requires authentication and validates session ownership.
    Only the session owner or an admin can disconnect a session.

    Args:
        session_id: Session ID to disconnect.
        request: HTTP request.
        credentials: Authenticated user credentials.

    Returns:
        Disconnection status.

    Raises:
        HTTPException: If session is not found or user is not authorized.
    """
    session = manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Session {session_id} not found")

    # Validate session ownership
    _validate_session_ownership(session, credentials, "disconnect")

    try:
        # Close the WebSocket connection
        await session.websocket.close()
    finally:
        # Deregister even when close() raises, so a failed close cannot leave
        # the session listed in the manager. The error itself still propagates.
        await manager.remove_session(session_id)

    return {"status": "disconnected", "session_id": session_id}
@router.post("/sessions/{session_id}/request")
async def send_request_to_session(
    session_id: str,
    mcp_request: Dict[str, Any],
    request: Request,
    credentials: str | dict = Depends(require_auth),
):
    """Forward an MCP request to the client behind a reverse proxy session.

    The caller must be authenticated and must be the session owner or an
    admin.

    Args:
        session_id: Session ID to send request to.
        mcp_request: MCP request to send.
        request: HTTP request.
        credentials: Authenticated user credentials.

    Returns:
        Request acknowledgment.

    Raises:
        HTTPException: If session is not found, user is not authorized, or request fails.
    """
    target = manager.get_session(session_id)
    if not target:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Session {session_id} not found")

    # Ownership check raises 403 on its own when the caller is not allowed.
    _validate_session_ownership(target, credentials, "send request to")

    # Reverse proxy envelope around the caller's MCP request
    envelope = {"type": "request", "sessionId": session_id, "payload": mcp_request}

    try:
        await target.send_message(envelope)
    except Exception as e:
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to send request: {e}")
    else:
        return {"status": "sent", "session_id": session_id}
407def _get_user_from_credentials(credentials: str | dict) -> tuple[str | None, bool]:
408 """Extract user and admin status from credentials.
410 Args:
411 credentials: Auth credentials (dict from JWT or string)
413 Returns:
414 Tuple of (username, is_admin)
415 """
416 if isinstance(credentials, dict):
417 user = credentials.get("sub") or credentials.get("email")
418 # Check both top-level is_admin and nested user.is_admin (JWT tokens may nest it)
419 is_admin = credentials.get("is_admin", False) or credentials.get("user", {}).get("is_admin", False)
420 return user, is_admin
421 elif credentials and credentials != "anonymous":
422 return credentials, False
423 return None, False
def _validate_session_ownership(session: ReverseProxySession, credentials: str | dict, action: str) -> None:
    """Ensure the caller may act on a session, raising 403 otherwise.

    Access is granted when the session has no owner, when the caller is an
    admin, or when the caller's identity equals the session owner.

    Args:
        session: The session to validate ownership for
        credentials: Auth credentials from require_auth
        action: Description of the action for logging

    Raises:
        HTTPException: 403 if user is not authorized for the session
    """
    owner_info = session.user
    if not owner_info:
        # Ownerless session (created without auth): open to any caller.
        return

    requesting_user, is_admin = _get_user_from_credentials(credentials)
    if is_admin:
        return

    # Reduce the stored owner to a comparable string.
    if isinstance(owner_info, str):
        session_owner = owner_info
    elif isinstance(owner_info, dict):
        session_owner = owner_info.get("sub")
    else:
        session_owner = None

    caller_is_owner = bool(requesting_user) and bool(session_owner) and requesting_user == session_owner
    if caller_is_owner:
        return

    LOGGER.warning(f"Session access denied: user {requesting_user} attempted to {action} session owned by {session_owner}")
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized for this session")
@router.get("/sse/{session_id}")
async def sse_endpoint(
    session_id: str,
    request: Request,
    credentials: str | dict = Depends(require_auth),
):
    """SSE endpoint for receiving messages from a reverse proxy session.

    Requires authentication via require_auth dependency.
    Additionally validates that the authenticated user owns the session.

    Args:
        session_id: Session ID to subscribe to.
        request: HTTP request.
        credentials: Authenticated user credentials.

    Returns:
        SSE stream.

    Raises:
        HTTPException: If session is not found or user is not authorized.
    """
    session = manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Session {session_id} not found")

    # Validate session ownership
    _validate_session_ownership(session, credentials, "subscribe to SSE for")

    def format_event(event: str, payload: Dict[str, Any]) -> str:
        """Render one event in text/event-stream framing.

        Args:
            event: Event name for the ``event:`` field.
            payload: JSON-serializable body for the ``data:`` field.

        Returns:
            The framed event, terminated by a blank line.
        """
        # orjson emits compact single-line JSON, so one data: line suffices.
        return f"event: {event}\ndata: {orjson.dumps(payload).decode()}\n\n"

    async def event_generator():
        """Generate SSE events.

        StreamingResponse needs str or bytes chunks, so each event is yielded
        as framed text rather than as a dict.

        Yields:
            str: One SSE-framed event.
        """
        try:
            # Send initial connection event
            yield format_event("connected", {"sessionId": session_id, "serverInfo": session.server_info})

            # TODO: Implement message queue for SSE delivery
            while not await request.is_disconnected():
                await asyncio.sleep(30)  # Keepalive
                yield format_event("keepalive", {"timestamp": datetime.now(tz=timezone.utc).isoformat()})

        except asyncio.CancelledError:
            # Kept from the original implementation: cancellation ends the
            # stream quietly.
            pass

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )