Coverage for mcpgateway / routers / reverse_proxy.py: 100%

196 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/routers/reverse_proxy.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7FastAPI router for handling reverse proxy connections. 

8 

9This module provides WebSocket and SSE endpoints for reverse proxy clients 

10to connect and tunnel their local MCP servers through the gateway. 

11""" 

12 

13# Standard 

14import asyncio 

15from datetime import datetime, timezone 

16from typing import Any, Dict, Optional 

17import uuid 

18 

19# Third-Party 

20from fastapi import APIRouter, Depends, HTTPException, Request, status, WebSocket, WebSocketDisconnect 

21from fastapi.responses import StreamingResponse 

22from fastapi.security import HTTPAuthorizationCredentials 

23import orjson 

24from sqlalchemy.orm import Session 

25 

26# First-Party 

27from mcpgateway.auth import get_current_user 

28from mcpgateway.config import settings 

29from mcpgateway.db import get_db 

30from mcpgateway.middleware.rbac import _ACCESS_DENIED_MSG, PermissionChecker 

31from mcpgateway.services.logging_service import LoggingService 

32from mcpgateway.utils.verify_credentials import extract_websocket_bearer_token, is_proxy_auth_trust_active, require_auth 

33 

# Module-wide logger used by every handler defined in this router
logging_service = LoggingService()
LOGGER = logging_service.get_logger("mcpgateway.routers.reverse_proxy")

# Every route registered on this router is served under the /reverse-proxy prefix
router = APIRouter(prefix="/reverse-proxy", tags=["reverse-proxy"])

39 

40 

class ReverseProxySession:
    """Manages a reverse proxy session.

    Wraps one WebSocket connection together with the bookkeeping the router
    exposes about it: who opened it, what server it registered, and simple
    traffic counters.
    """

    def __init__(self, session_id: str, websocket: WebSocket, user: Optional[str | dict] = None):
        """Initialize reverse proxy session.

        Args:
            session_id: Unique session identifier.
            websocket: WebSocket connection.
            user: Authenticated user info (if any).
        """
        self.session_id = session_id
        self.websocket = websocket
        self.user = user
        # Filled in later, when the client sends a "register" message
        self.server_info: Dict[str, Any] = {}
        now = datetime.now(tz=timezone.utc)
        self.connected_at = now
        self.last_activity = now
        # Counts inbound messages only (see receive_message)
        self.message_count = 0
        # Size of all payloads sent and received, in UTF-8 encoded bytes
        self.bytes_transferred = 0

    async def send_message(self, message: Dict[str, Any]) -> None:
        """Send message to the client.

        Args:
            message: Message dictionary to send.
        """
        # orjson.dumps returns UTF-8 bytes; measure those so the counter is a
        # true byte count rather than a character count of the decoded text.
        payload = orjson.dumps(message)
        await self.websocket.send_text(payload.decode())
        self.bytes_transferred += len(payload)
        self.last_activity = datetime.now(tz=timezone.utc)

    async def receive_message(self) -> Dict[str, Any]:
        """Receive message from the client.

        Counters and the activity timestamp are updated before the payload is
        parsed, so a malformed message is still accounted for.

        Returns:
            Parsed message dictionary.
        """
        data = await self.websocket.receive_text()
        # Count encoded bytes, not characters; errors="replace" keeps the
        # accounting from raising on text that cannot be encoded.
        self.bytes_transferred += len(data.encode("utf-8", errors="replace"))
        self.message_count += 1
        self.last_activity = datetime.now(tz=timezone.utc)
        return orjson.loads(data)

83 

84 

class ReverseProxyManager:
    """Manages all reverse proxy sessions.

    Sessions are kept in a dictionary keyed by session ID. Mutations go
    through an asyncio lock; lookups and listings read the dictionary directly.
    """

    def __init__(self):
        """Initialize the manager."""
        self.sessions: Dict[str, ReverseProxySession] = {}
        self._lock = asyncio.Lock()

    async def add_session(self, session: ReverseProxySession) -> None:
        """Add a new session.

        An existing entry with the same session ID is replaced.

        Args:
            session: Session to add.
        """
        key = session.session_id
        async with self._lock:
            self.sessions[key] = session
            LOGGER.info(f"Added reverse proxy session: {session.session_id}")

    async def remove_session(self, session_id: str) -> None:
        """Remove a session.

        Unknown session IDs are ignored silently (nothing is logged).

        Args:
            session_id: Session ID to remove.
        """
        async with self._lock:
            removed = self.sessions.pop(session_id, None)
            if removed is not None:
                LOGGER.info(f"Removed reverse proxy session: {session_id}")

    def get_session(self, session_id: str) -> Optional[ReverseProxySession]:
        """Get a session by ID.

        Args:
            session_id: Session ID to get.

        Returns:
            Session if found, None otherwise.
        """
        return self.sessions.get(session_id)

    def list_sessions(self) -> list[Dict[str, Any]]:
        """List all active sessions.

        Returns:
            List of session information dictionaries.

        Examples:
            >>> from fastapi import WebSocket
            >>> manager = ReverseProxyManager()
            >>> sessions = manager.list_sessions()
            >>> sessions
            []
            >>> isinstance(sessions, list)
            True
        """
        summaries: list[Dict[str, Any]] = []
        for entry in self.sessions.values():
            # The owner is reported as a plain string: the user itself when it
            # is a string, its "sub" value when it is a dict, otherwise None.
            owner = entry.user
            if isinstance(owner, str):
                user_label = owner
            elif isinstance(owner, dict):
                user_label = owner.get("sub")
            else:
                user_label = None

            summaries.append(
                {
                    "session_id": entry.session_id,
                    "server_info": entry.server_info,
                    "connected_at": entry.connected_at.isoformat(),
                    "last_activity": entry.last_activity.isoformat(),
                    "message_count": entry.message_count,
                    "bytes_transferred": entry.bytes_transferred,
                    "user": user_label,
                }
            )
        return summaries

152 

153 

# Single process-wide registry of reverse proxy sessions shared by all routes
manager = ReverseProxyManager()

# Holding any one of these permissions is enough to open a reverse proxy connection
_REVERSE_PROXY_CONNECT_PERMISSIONS = [
    "servers.create",
    "servers.update",
    "servers.manage",
]

162 

163 

def _get_websocket_bearer_token(websocket: WebSocket) -> Optional[str]:
    """Extract the bearer token carried by a WebSocket handshake.

    Delegates to ``extract_websocket_bearer_token``, handing it both the
    connection's query parameters and its headers.

    Args:
        websocket: Incoming WebSocket connection.

    Returns:
        Bearer token value when present, otherwise None.
    """
    # getattr with an empty-dict default tolerates objects lacking these attributes
    query_params = getattr(websocket, "query_params", {})
    headers = getattr(websocket, "headers", {})
    return extract_websocket_bearer_token(
        query_params,
        headers,
        query_param_warning="Reverse proxy WebSocket token passed via query parameter",
    )

178 

179 

async def _authenticate_reverse_proxy_websocket(websocket: WebSocket) -> Optional[str]:
    """Authenticate and authorize a reverse-proxy WebSocket connection.

    Identity is resolved in this order: a bearer token on the connection, then
    the trusted proxy user header (only when proxy auth trust is active). Any
    resolved identity must hold at least one of the reverse proxy connect
    permissions.

    Args:
        websocket: Incoming WebSocket connection.

    Returns:
        Authenticated user email when available, otherwise None.

    Raises:
        HTTPException: If authentication fails or required permissions are missing.
    """
    auth_required = settings.auth_required or settings.mcp_client_auth_enabled
    auth_token = _get_websocket_bearer_token(websocket)

    def _connection_metadata() -> dict[str, Any]:
        """Collect the peer address and user agent recorded in the user context.

        Returns:
            Dictionary with ``ip_address`` and ``user_agent`` entries.
        """
        peer = websocket.client
        return {
            "ip_address": peer.host if peer else None,
            "user_agent": websocket.headers.get("user-agent"),
        }

    user_context: Optional[dict[str, Any]] = None

    if auth_token:
        credentials = HTTPAuthorizationCredentials(scheme="Bearer", credentials=auth_token)
        try:
            user = await get_current_user(credentials, request=websocket)
        except HTTPException:
            # Already carries the right status code; let it through untouched
            raise
        except Exception as exc:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication failed") from exc

        # Team/token attributes are read from the connection state after the
        # user lookup; each falls back to None when absent.
        state = websocket.state
        user_context = {
            "email": user.email,
            "full_name": user.full_name,
            "is_admin": user.is_admin,
            **_connection_metadata(),
            "team_id": getattr(state, "team_id", None),
            "token_teams": getattr(state, "token_teams", None),
            "token_use": getattr(state, "token_use", None),
        }
    else:
        if is_proxy_auth_trust_active(settings):
            proxy_user = websocket.headers.get(settings.proxy_user_header)
            if proxy_user:
                # Identities taken from the proxy header are never treated as admin
                user_context = {
                    "email": proxy_user,
                    "full_name": proxy_user,
                    "is_admin": False,
                    **_connection_metadata(),
                }
        # No token and no trusted proxy identity: reject only when auth is mandatory
        if user_context is None and auth_required:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required")

    if not user_context:
        return None

    checker = PermissionChecker(user_context)
    allowed = await checker.has_any_permission(_REVERSE_PROXY_CONNECT_PERMISSIONS)
    if not allowed:
        LOGGER.warning("Reverse proxy permission denied: user=%s", user_context.get("email"))
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=_ACCESS_DENIED_MSG)
    return user_context["email"]

237 

238 

@router.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    db: Session = Depends(get_db),
):
    """WebSocket endpoint for reverse proxy connections.

    Authentication is REQUIRED when:
    - settings.auth_required is True, OR
    - settings.mcp_client_auth_enabled is True

    The connection is authenticated before it is accepted. On failure the
    socket is closed with a policy-violation code and the handler returns
    without raising.

    Once accepted, the handler loops over incoming JSON messages and reacts to
    their ``type``: ``register``, ``unregister``, ``heartbeat``, ``response``
    and ``notification``. Malformed messages are answered with an ``error``
    message and the loop keeps running. The session is always removed from the
    manager when the loop ends.

    Args:
        websocket: WebSocket connection.
        db: Database session (injected; not used directly by this handler).
    """
    try:
        user = await _authenticate_reverse_proxy_websocket(websocket)
    except HTTPException as e:
        LOGGER.warning(f"Reverse proxy WebSocket authentication failed: {e.detail}")
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason=str(e.detail))
        return

    # Accept connection only after successful authentication (or when auth not required)
    await websocket.accept()

    # Generate session ID server-side to prevent session hijacking
    # Client-supplied X-Session-ID is ignored for security (prevents collision/hijack attacks)
    session_id = uuid.uuid4().hex

    # Create session with authenticated user
    session = ReverseProxySession(session_id, websocket, user)
    await manager.add_session(session)

    try:
        LOGGER.info(f"Reverse proxy connected: {session_id}")

        # Main message loop
        while True:
            try:
                message = await session.receive_message()
                # Valid JSON is not necessarily an object; reject arrays and
                # scalars explicitly instead of failing on .get() below.
                if not isinstance(message, dict):
                    raise ValueError("Message must be a JSON object")
                msg_type = message.get("type")

                if msg_type == "register":
                    # Validate before storing so a bad payload cannot replace
                    # the session's server_info with a non-dict value.
                    server_info = message.get("server", {})
                    if not isinstance(server_info, dict):
                        raise ValueError("Field 'server' must be a JSON object")
                    session.server_info = server_info
                    LOGGER.info(f"Registered server for session {session_id}: {session.server_info.get('name')}")

                    # Send acknowledgment
                    await session.send_message({"type": "register_ack", "sessionId": session_id, "status": "success"})

                elif msg_type == "unregister":
                    # Unregister the server
                    LOGGER.info(f"Unregistering server for session {session_id}")
                    break

                elif msg_type == "heartbeat":
                    # Respond to heartbeat
                    await session.send_message({"type": "heartbeat", "sessionId": session_id, "timestamp": datetime.now(tz=timezone.utc).isoformat()})

                elif msg_type in ("response", "notification"):
                    # Handle MCP response/notification from the proxied server
                    # TODO: Route to appropriate MCP client
                    LOGGER.debug(f"Received {msg_type} from session {session_id}")

                else:
                    LOGGER.warning(f"Unknown message type from session {session_id}: {msg_type}")

            except WebSocketDisconnect:
                LOGGER.info(f"WebSocket disconnected: {session_id}")
                break
            except orjson.JSONDecodeError as e:
                LOGGER.error(f"Invalid JSON from session {session_id}: {e}")
                await session.send_message({"type": "error", "message": "Invalid JSON format"})
            except Exception as e:
                # Report the problem to the client and keep the session alive
                LOGGER.error(f"Error handling message from session {session_id}: {e}")
                await session.send_message({"type": "error", "message": str(e)})

    finally:
        await manager.remove_session(session_id)
        LOGGER.info(f"Reverse proxy session ended: {session_id}")

326 

327 

@router.get("/sessions")
async def list_sessions(
    request: Request,
    credentials: str | dict = Depends(require_auth),
):
    """List active reverse proxy sessions.

    Admins receive every session. Other callers receive the sessions whose
    owner matches them, plus sessions that have no owner at all.

    Args:
        request: HTTP request.
        credentials: Authenticated user credentials.

    Returns:
        List of session information (filtered by ownership).
    """
    requesting_user, is_admin = _get_user_from_credentials(credentials)

    if is_admin:
        # Unfiltered view for admins
        return {"sessions": manager.list_sessions(), "total": len(manager.sessions)}

    # Keep a session when it is ownerless (anonymous) or owned by the caller
    visible = [info for info in manager.list_sessions() if not (owner := info.get("user")) or owner == requesting_user]
    return {"sessions": visible, "total": len(visible)}

361 

362 

@router.delete("/sessions/{session_id}")
async def disconnect_session(
    session_id: str,
    request: Request,
    credentials: str | dict = Depends(require_auth),
):
    """Disconnect a reverse proxy session.

    Requires authentication and validates session ownership.
    Only the session owner or an admin can disconnect a session.

    Args:
        session_id: Session ID to disconnect.
        request: HTTP request.
        credentials: Authenticated user credentials.

    Returns:
        Disconnection status.

    Raises:
        HTTPException: If session is not found or user is not authorized.
    """
    session = manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Session {session_id} not found")

    # Validate session ownership
    _validate_session_ownership(session, credentials, "disconnect")

    try:
        # Close the WebSocket connection
        await session.websocket.close()
    finally:
        # Drop the registry entry even when close() raises, so a socket that
        # fails to close cleanly does not leave a stale session listed.
        await manager.remove_session(session_id)

    return {"status": "disconnected", "session_id": session_id}

397 

398 

@router.post("/sessions/{session_id}/request")
async def send_request_to_session(
    session_id: str,
    mcp_request: Dict[str, Any],
    request: Request,
    credentials: str | dict = Depends(require_auth),
):
    """Send an MCP request to a reverse proxy session.

    Requires authentication and validates session ownership.
    Only the session owner or an admin can send requests to a session.

    Args:
        session_id: Session ID to send request to.
        mcp_request: MCP request to send.
        request: HTTP request.
        credentials: Authenticated user credentials.

    Returns:
        Request acknowledgment.

    Raises:
        HTTPException: If session is not found, user is not authorized, or request fails.
    """
    session = manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Session {session_id} not found")

    # Validate session ownership
    _validate_session_ownership(session, credentials, "send request to")

    # Wrap the request in reverse proxy envelope
    message = {"type": "request", "sessionId": session_id, "payload": mcp_request}

    try:
        await session.send_message(message)
    except Exception as e:
        # Chain the cause so the original failure stays visible in tracebacks
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to send request: {e}") from e

    return {"status": "sent", "session_id": session_id}

438 

439 

440def _get_user_from_credentials(credentials: str | dict) -> tuple[str | None, bool]: 

441 """Extract user and admin status from credentials. 

442 

443 Args: 

444 credentials: Auth credentials (dict from JWT or string) 

445 

446 Returns: 

447 Tuple of (username, is_admin) 

448 """ 

449 if isinstance(credentials, dict): 

450 user = credentials.get("sub") or credentials.get("email") 

451 # Check both top-level is_admin and nested user.is_admin (JWT tokens may nest it) 

452 is_admin = credentials.get("is_admin", False) or credentials.get("user", {}).get("is_admin", False) 

453 return user, is_admin 

454 elif credentials and credentials != "anonymous": 

455 return credentials, False 

456 return None, False 

457 

458 

def _validate_session_ownership(session: ReverseProxySession, credentials: str | dict, action: str) -> None:
    """Ensure the caller may act on the given session.

    Access is granted when the session has no owner, when the caller is an
    admin, or when the caller's identity equals the session owner. Anything
    else is logged and rejected.

    Args:
        session: The session to validate ownership for
        credentials: Auth credentials from require_auth
        action: Description of the action for logging

    Raises:
        HTTPException: 403 if user is not authorized for the session
    """
    owner = session.user
    if not owner:
        # Ownerless session (created without auth): open to any caller
        return

    requesting_user, is_admin = _get_user_from_credentials(credentials)
    if is_admin:
        # Admins can access any session
        return

    # Reduce the stored owner to a comparable identifier
    if isinstance(owner, str):
        session_owner = owner
    elif isinstance(owner, dict):
        session_owner = owner.get("sub")
    else:
        session_owner = None

    if not (requesting_user and session_owner and requesting_user == session_owner):
        LOGGER.warning(f"Session access denied: user {requesting_user} attempted to {action} session owned by {session_owner}")
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized for this session")

488 

489 

def _format_sse_event(event: str, data: str) -> str:
    """Serialize one event in the text/event-stream wire format.

    Args:
        event: Event name written to the ``event:`` field.
        data: Event payload; each line is written as its own ``data:`` field.

    Returns:
        The encoded event, terminated by the blank line that ends an SSE event.
    """
    data_lines = "\n".join(f"data: {line}" for line in data.split("\n"))
    return f"event: {event}\n{data_lines}\n\n"


@router.get("/sse/{session_id}")
async def sse_endpoint(
    session_id: str,
    request: Request,
    credentials: str | dict = Depends(require_auth),
):
    """SSE endpoint for receiving messages from a reverse proxy session.

    Requires authentication via require_auth dependency.
    Additionally validates that the authenticated user owns the session.

    The stream currently carries one ``connected`` event followed by a
    ``keepalive`` event every 30 seconds until the client disconnects.

    Args:
        session_id: Session ID to subscribe to.
        request: HTTP request.
        credentials: Authenticated user credentials.

    Returns:
        SSE stream.

    Raises:
        HTTPException: If session is not found or user is not authorized.
    """
    session = manager.get_session(session_id)
    if not session:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"Session {session_id} not found")

    # Validate session ownership
    _validate_session_ownership(session, credentials, "subscribe to SSE for")

    async def event_generator():
        """Generate SSE events.

        StreamingResponse writes each yielded chunk to the wire as-is, so the
        events are yielded as already-formatted text rather than as dicts.

        Yields:
            str: One SSE-formatted event.
        """
        # Send initial connection event
        yield _format_sse_event("connected", orjson.dumps({"sessionId": session_id, "serverInfo": session.server_info}).decode())

        # TODO: Implement message queue for SSE delivery
        while not await request.is_disconnected():
            await asyncio.sleep(30)  # Keepalive
            yield _format_sse_event("keepalive", orjson.dumps({"timestamp": datetime.now(tz=timezone.utc).isoformat()}).decode())

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )

548 )