Coverage for mcpgateway / routers / reverse_proxy.py: 100%

203 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/routers/reverse_proxy.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7FastAPI router for handling reverse proxy connections. 

8 

9This module provides WebSocket and SSE endpoints for reverse proxy clients 

10to connect and tunnel their local MCP servers through the gateway. 

11""" 

12 

13# Standard 

14import asyncio 

15from datetime import datetime, timezone 

16from typing import Any, Dict, Optional 

17import uuid 

18 

19# Third-Party 

20from fastapi import APIRouter, Depends, HTTPException, Request, status, WebSocket, WebSocketDisconnect 

21from fastapi.responses import StreamingResponse 

22import orjson 

23from sqlalchemy.orm import Session 

24 

25# First-Party 

26from mcpgateway.config import settings 

27from mcpgateway.db import get_db 

28from mcpgateway.services.logging_service import LoggingService 

29from mcpgateway.utils.verify_credentials import require_auth, verify_jwt_token 

30 

# Router-scoped logger obtained from the gateway's logging service
logging_service = LoggingService()
LOGGER = logging_service.get_logger("mcpgateway.routers.reverse_proxy")

# Every endpoint defined in this module is mounted under /reverse-proxy
router = APIRouter(tags=["reverse-proxy"], prefix="/reverse-proxy")

36 

37 

class ReverseProxySession:
    """Manages a reverse proxy session.

    Holds the WebSocket of one connected reverse proxy client together with the
    server information it registered and simple traffic statistics.
    """

    def __init__(self, session_id: str, websocket: WebSocket, user: Optional[str | dict] = None):
        """Initialize reverse proxy session.

        Args:
            session_id: Unique session identifier.
            websocket: WebSocket connection.
            user: Authenticated user info (if any).
        """
        now = datetime.now(tz=timezone.utc)
        self.session_id = session_id
        self.websocket = websocket
        self.user = user
        self.server_info: Dict[str, Any] = {}
        # Use one timestamp so connected_at and last_activity start out identical.
        self.connected_at = now
        self.last_activity = now
        # Counts messages received from the client only; sends are not counted.
        self.message_count = 0
        # UTF-8 encoded size of every text frame sent and received.
        self.bytes_transferred = 0

    async def send_message(self, message: Dict[str, Any]) -> None:
        """Send message to the client.

        Args:
            message: Message dictionary to send; serialized as JSON text.
        """
        payload = orjson.dumps(message)
        await self.websocket.send_text(payload.decode())
        # orjson already produced UTF-8 bytes, so this is the real wire size.
        self.bytes_transferred += len(payload)
        self.last_activity = datetime.now(tz=timezone.utc)

    async def receive_message(self) -> Dict[str, Any]:
        """Receive message from the client.

        Returns:
            Parsed message dictionary.

        Raises:
            orjson.JSONDecodeError: If the frame is not valid JSON.
            ValueError: If the frame is valid JSON but not a JSON object.
        """
        data = await self.websocket.receive_text()
        # len(str) counts code points; encode to count actual bytes for non-ASCII text.
        self.bytes_transferred += len(data.encode("utf-8"))
        self.message_count += 1
        self.last_activity = datetime.now(tz=timezone.utc)
        message = orjson.loads(data)
        if not isinstance(message, dict):
            raise ValueError("Message must be a JSON object")
        return message

80 

81 

class ReverseProxyManager:
    """Manages all reverse proxy sessions, keyed by session ID."""

    def __init__(self):
        """Initialize the manager."""
        self.sessions: Dict[str, ReverseProxySession] = {}
        # Serializes add/remove; reads below do not await, so they need no lock.
        self._lock = asyncio.Lock()

    @staticmethod
    def _session_owner(session: ReverseProxySession) -> Optional[str]:
        """Return the owner name recorded on a session, if any.

        String users are returned as-is. Dict users (JWT-style claims) resolve to
        ``sub``, falling back to ``email``, mirroring how requesting users are
        identified elsewhere in this router.

        Args:
            session: Session whose owner to resolve.

        Returns:
            Owner name, or None when the session has no usable owner.
        """
        user = session.user
        if isinstance(user, str):
            return user
        if isinstance(user, dict):
            return user.get("sub") or user.get("email")
        return None

    async def add_session(self, session: ReverseProxySession) -> None:
        """Add a new session.

        Args:
            session: Session to add.
        """
        async with self._lock:
            self.sessions[session.session_id] = session
        LOGGER.info(f"Added reverse proxy session: {session.session_id}")

    async def remove_session(self, session_id: str) -> None:
        """Remove a session; unknown IDs are ignored.

        Args:
            session_id: Session ID to remove.
        """
        async with self._lock:
            removed = self.sessions.pop(session_id, None)
        if removed is not None:
            LOGGER.info(f"Removed reverse proxy session: {session_id}")

    def get_session(self, session_id: str) -> Optional[ReverseProxySession]:
        """Get a session by ID.

        Args:
            session_id: Session ID to get.

        Returns:
            Session if found, None otherwise.
        """
        return self.sessions.get(session_id)

    def list_sessions(self) -> list[Dict[str, Any]]:
        """List all active sessions.

        Returns:
            List of session information dictionaries.

        Examples:
            >>> manager = ReverseProxyManager()
            >>> sessions = manager.list_sessions()
            >>> sessions
            []
            >>> isinstance(sessions, list)
            True
        """
        return [
            {
                "session_id": session.session_id,
                "server_info": session.server_info,
                "connected_at": session.connected_at.isoformat(),
                "last_activity": session.last_activity.isoformat(),
                "message_count": session.message_count,
                "bytes_transferred": session.bytes_transferred,
                "user": self._session_owner(session),
            }
            for session in self.sessions.values()
        ]

149 

150 

# Module-global registry of active reverse proxy sessions, used by every endpoint below.
manager: ReverseProxyManager = ReverseProxyManager()

153 

154 

async def _verify_ws_token(token: str, label: str) -> Optional[str]:
    """Verify a JWT presented during the WebSocket handshake.

    Args:
        token: Raw JWT string.
        label: Token origin used in log messages ("JWT" or "query token").

    Returns:
        The token's ``sub`` claim (falling back to ``email``), or None when
        verification fails or neither claim is present.
    """
    try:
        payload = await verify_jwt_token(token)
        user = payload.get("sub") or payload.get("email")
        if not user:
            raise ValueError("Token missing subject claim")
    except HTTPException as e:
        LOGGER.warning(f"WebSocket {label} authentication failed: {e.detail}")
        return None
    except Exception as e:
        LOGGER.warning(f"WebSocket {label} authentication failed: {e}")
        return None
    LOGGER.debug(f"WebSocket authenticated via {label}: {user}")
    return user


async def _authenticate_websocket(websocket: WebSocket) -> tuple[bool, Optional[str]]:
    """Authenticate a WebSocket handshake before it is accepted.

    Credential sources are tried in order: ``Authorization: Bearer`` header,
    ``?token=`` query parameter, then the trusted proxy user header (only when
    ``trust_proxy_auth`` is on and ``mcp_client_auth_enabled`` is off).

    Args:
        websocket: The not-yet-accepted WebSocket.

    Returns:
        ``(allowed, user)``. When ``allowed`` is False the socket has already been
        closed with policy-violation code 1008. ``user`` is None when auth is not required.
    """
    if not (settings.auth_required or settings.mcp_client_auth_enabled):
        return True, None

    auth_header = websocket.headers.get("Authorization", "")
    if auth_header.startswith("Bearer "):
        user = await _verify_ws_token(auth_header.split(" ", 1)[1], "JWT")
        failure_reason = "Authentication failed"
    elif "token" in websocket.query_params:
        user = await _verify_ws_token(websocket.query_params["token"], "query token")
        failure_reason = "Authentication failed"
    elif settings.trust_proxy_auth and not settings.mcp_client_auth_enabled:
        user = websocket.headers.get(settings.proxy_user_header)
        if user:
            LOGGER.debug(f"WebSocket authenticated via proxy header: {user}")
        else:
            LOGGER.warning("WebSocket proxy authentication failed: no proxy header")
        failure_reason = "Authentication required"
    else:
        LOGGER.warning("WebSocket authentication required but no credentials provided")
        user = None
        failure_reason = "Authentication required"

    if not user:
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason=failure_reason)
        return False, None
    return True, user


async def _handle_ws_message(session: ReverseProxySession, message: Dict[str, Any]) -> bool:
    """Process one decoded message from a reverse proxy client.

    Args:
        session: Session the message arrived on.
        message: Decoded JSON message.

    Returns:
        False when the client asked to unregister (the receive loop should stop), True otherwise.

    Raises:
        ValueError: If the message, or a register message's ``server`` field, is not a JSON object.
    """
    session_id = session.session_id
    if not isinstance(message, dict):
        raise ValueError("Message must be a JSON object")
    msg_type = message.get("type")

    if msg_type == "register":
        server_info = message.get("server", {})
        # Validate before storing so a bad payload cannot corrupt session.server_info.
        if not isinstance(server_info, dict):
            raise ValueError("Register message 'server' field must be a JSON object")
        session.server_info = server_info
        LOGGER.info(f"Registered server for session {session_id}: {server_info.get('name')}")
        await session.send_message({"type": "register_ack", "sessionId": session_id, "status": "success"})
    elif msg_type == "unregister":
        LOGGER.info(f"Unregistering server for session {session_id}")
        return False
    elif msg_type == "heartbeat":
        await session.send_message({"type": "heartbeat", "sessionId": session_id, "timestamp": datetime.now(tz=timezone.utc).isoformat()})
    elif msg_type in ("response", "notification"):
        # Handle MCP response/notification from the proxied server
        # TODO: Route to appropriate MCP client
        LOGGER.debug(f"Received {msg_type} from session {session_id}")
    else:
        LOGGER.warning(f"Unknown message type from session {session_id}: {msg_type}")
    return True


async def _send_ws_error(session: ReverseProxySession, text: str) -> bool:
    """Best-effort delivery of an error frame to the client.

    Args:
        session: Session to notify.
        text: Error description placed in the frame's ``message`` field.

    Returns:
        True if the frame was sent, False if the connection can no longer be written to.
    """
    try:
        await session.send_message({"type": "error", "message": text})
        return True
    except Exception as send_error:
        LOGGER.warning(f"Failed to send error to session {session.session_id}: {send_error}")
        return False


@router.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    db: Session = Depends(get_db),
):
    """WebSocket endpoint for reverse proxy connections.

    Authentication is REQUIRED when:
    - settings.auth_required is True, OR
    - settings.mcp_client_auth_enabled is True

    Supports:
    - Bearer token in Authorization header
    - Token in query parameter (?token=...)
    - Proxy authentication (when trust_proxy_auth is True and mcp_client_auth_enabled is False)

    Authentication failures close the socket with code 1008 before it is accepted;
    they are not raised to the caller.

    Args:
        websocket: WebSocket connection.
        db: Database session (not used by this endpoint).
    """
    # Check authentication BEFORE accepting the connection
    allowed, user = await _authenticate_websocket(websocket)
    if not allowed:
        return

    await websocket.accept()

    # Generate session ID server-side to prevent session hijacking;
    # a client-supplied X-Session-ID is deliberately ignored.
    session_id = uuid.uuid4().hex
    session = ReverseProxySession(session_id, websocket, user)
    await manager.add_session(session)

    try:
        LOGGER.info(f"Reverse proxy connected: {session_id}")

        while True:
            try:
                message = await session.receive_message()
                if not await _handle_ws_message(session, message):
                    break
            except WebSocketDisconnect:
                LOGGER.info(f"WebSocket disconnected: {session_id}")
                break
            except orjson.JSONDecodeError as e:
                LOGGER.error(f"Invalid JSON from session {session_id}: {e}")
                if not await _send_ws_error(session, "Invalid JSON format"):
                    break
            except Exception as e:
                LOGGER.error(f"Error handling message from session {session_id}: {e}")
                # If even the error frame cannot be sent, the connection is dead: stop looping.
                if not await _send_ws_error(session, str(e)):
                    break
    finally:
        await manager.remove_session(session_id)
        LOGGER.info(f"Reverse proxy session ended: {session_id}")

294 

@router.get("/sessions")
async def list_sessions(
    request: Request,
    credentials: str | dict = Depends(require_auth),
):
    """Return the reverse proxy sessions visible to the caller.

    Admins receive every active session. Any other caller receives the sessions
    recorded under their own user name plus sessions with no recorded owner.

    Args:
        request: HTTP request (not used by this handler).
        credentials: Authenticated user credentials from ``require_auth``.

    Returns:
        Dict with ``sessions`` (list of session info dicts) and ``total`` (their count).
    """
    caller, caller_is_admin = _get_user_from_credentials(credentials)
    visible = manager.list_sessions()

    if not caller_is_admin:
        # Keep the caller's own sessions and anonymous (owner-less) ones.
        visible = [info for info in visible if not info.get("user") or info.get("user") == caller]

    return {"sessions": visible, "total": len(visible)}

328 

329 

@router.delete("/sessions/{session_id}")
async def disconnect_session(
    session_id: str,
    request: Request,
    credentials: str | dict = Depends(require_auth),
):
    """Disconnect a reverse proxy session.

    Requires authentication and validates session ownership.
    Only the session owner or an admin can disconnect a session.
    The session is removed from the registry even if closing its socket fails.

    Args:
        session_id: Session ID to disconnect.
        request: HTTP request.
        credentials: Authenticated user credentials.

    Returns:
        Disconnection status.

    Raises:
        HTTPException: If session is not found or user is not authorized.
    """
    session = manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Session {session_id} not found")

    # Validate session ownership
    _validate_session_ownership(session, credentials, "disconnect")

    try:
        await session.websocket.close()
    except (RuntimeError, OSError, WebSocketDisconnect) as exc:
        # The socket may already be closing or gone (e.g. the client dropped concurrently);
        # that is not a reason to fail the request or leave a stale registry entry.
        LOGGER.warning(f"Error closing reverse proxy session {session_id}: {exc}")
    finally:
        await manager.remove_session(session_id)

    return {"status": "disconnected", "session_id": session_id}

364 

365 

@router.post("/sessions/{session_id}/request")
async def send_request_to_session(
    session_id: str,
    mcp_request: Dict[str, Any],
    request: Request,
    credentials: str | dict = Depends(require_auth),
):
    """Send an MCP request to a reverse proxy session.

    Requires authentication and validates session ownership.
    Only the session owner or an admin can send requests to a session.

    Args:
        session_id: Session ID to send request to.
        mcp_request: MCP request to send.
        request: HTTP request.
        credentials: Authenticated user credentials.

    Returns:
        Request acknowledgment.

    Raises:
        HTTPException: If session is not found, user is not authorized, or request fails.
    """
    session = manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Session {session_id} not found")

    # Validate session ownership
    _validate_session_ownership(session, credentials, "send request to")

    # Wrap the request in the reverse proxy envelope understood by the client
    message = {"type": "request", "sessionId": session_id, "payload": mcp_request}

    try:
        await session.send_message(message)
    except Exception as e:
        LOGGER.error(f"Failed to send request to session {session_id}: {e}")
        # Chain the original error so the root cause survives in tracebacks.
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to send request: {e}") from e

    return {"status": "sent", "session_id": session_id}

405 

406 

407def _get_user_from_credentials(credentials: str | dict) -> tuple[str | None, bool]: 

408 """Extract user and admin status from credentials. 

409 

410 Args: 

411 credentials: Auth credentials (dict from JWT or string) 

412 

413 Returns: 

414 Tuple of (username, is_admin) 

415 """ 

416 if isinstance(credentials, dict): 

417 user = credentials.get("sub") or credentials.get("email") 

418 # Check both top-level is_admin and nested user.is_admin (JWT tokens may nest it) 

419 is_admin = credentials.get("is_admin", False) or credentials.get("user", {}).get("is_admin", False) 

420 return user, is_admin 

421 elif credentials and credentials != "anonymous": 

422 return credentials, False 

423 return None, False 

424 

425 

def _validate_session_ownership(session: ReverseProxySession, credentials: str | dict, action: str) -> None:
    """Validate that the requesting user owns the session or is admin.

    Sessions created without an authenticated user are accessible to anyone.
    A dict session owner is identified by ``sub``, falling back to ``email``,
    matching how ``_get_user_from_credentials`` identifies the requester.

    Args:
        session: The session to validate ownership for
        credentials: Auth credentials from require_auth
        action: Description of the action for logging

    Raises:
        HTTPException: 403 if user is not authorized for the session
    """
    owner = session.user
    if not owner:
        # Session was created without auth - allow access
        return

    requesting_user, is_admin = _get_user_from_credentials(credentials)

    # Admins can access any session
    if is_admin:
        return

    if isinstance(owner, str):
        session_owner = owner
    elif isinstance(owner, dict):
        session_owner = owner.get("sub") or owner.get("email")
    else:
        session_owner = None

    # Session owner can access their own session
    if requesting_user and session_owner and requesting_user == session_owner:
        return

    LOGGER.warning(f"Session access denied: user {requesting_user} attempted to {action} session owned by {session_owner}")
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized for this session")

455 

456 

def _format_sse(event: str, data: Dict[str, Any]) -> str:
    """Serialize one Server-Sent Events frame.

    StreamingResponse writes raw str/bytes chunks, so each event must already be
    in SSE wire format: an ``event:`` line, a ``data:`` line and a blank line.
    orjson's compact output contains no raw newlines, so one ``data:`` line suffices.

    Args:
        event: SSE event name.
        data: JSON-serializable payload.

    Returns:
        The encoded frame.
    """
    return f"event: {event}\ndata: {orjson.dumps(data).decode()}\n\n"


@router.get("/sse/{session_id}")
async def sse_endpoint(
    session_id: str,
    request: Request,
    credentials: str | dict = Depends(require_auth),
):
    """SSE endpoint for receiving messages from a reverse proxy session.

    Requires authentication via require_auth dependency.
    Additionally validates that the authenticated user owns the session.
    The stream ends when the client disconnects or the proxied session goes away.

    Args:
        session_id: Session ID to subscribe to.
        request: HTTP request.
        credentials: Authenticated user credentials.

    Returns:
        SSE stream.

    Raises:
        HTTPException: If session is not found or user is not authorized.
    """
    session = manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Session {session_id} not found")

    # Validate session ownership
    _validate_session_ownership(session, credentials, "subscribe to SSE for")

    async def event_generator():
        """Generate SSE frames.

        Yields:
            str: SSE-formatted event frames.
        """
        try:
            # Send initial connection event
            yield _format_sse("connected", {"sessionId": session_id, "serverInfo": session.server_info})

            # TODO: Implement message queue for SSE delivery
            while not await request.is_disconnected():
                await asyncio.sleep(30)  # Keepalive interval
                if manager.get_session(session_id) is not session:
                    # The reverse proxy session ended; nothing more will arrive on this stream.
                    break
                yield _format_sse("keepalive", {"timestamp": datetime.now(tz=timezone.utc).isoformat()})
        except asyncio.CancelledError:
            # Let cancellation propagate so the server can tear the task down cleanly.
            LOGGER.debug(f"SSE stream cancelled for session {session_id}")
            raise

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )