Coverage for mcpgateway / reverse_proxy.py: 99%

333 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/reverse_proxy.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7MCP Reverse Proxy - Bridge local MCP servers to remote gateways. 

8This module implements a reverse proxy that connects local MCP servers 

9(running via stdio) to remote gateways, enabling servers behind firewalls 

10or NATs to be accessible without inbound network access. 

11 

12The reverse proxy establishes an outbound WebSocket or SSE connection to 

13a remote gateway and registers the local server. All MCP protocol messages 

14are then tunneled through this persistent connection. 

15 

16Environment variables: 

17- REVERSE_PROXY_GATEWAY: Remote gateway URL (required) 

18- REVERSE_PROXY_TOKEN: Bearer token for authentication (optional) 

19- REVERSE_PROXY_RECONNECT_DELAY: Initial reconnection delay in seconds (default 1) 

20- REVERSE_PROXY_MAX_RETRIES: Maximum reconnection attempts (default 0 = infinite) 

21- REVERSE_PROXY_LOG_LEVEL: Python log level (default INFO) 

22 

23Example: 

24 $ export REVERSE_PROXY_GATEWAY=https://gateway.example.com 

25 $ export REVERSE_PROXY_TOKEN=$(python3 -m mcpgateway.utils.create_jwt_token --username admin@example.com --exp 10080 --secret key) 

26 $ python3 -m mcpgateway.reverse_proxy --local-stdio "uvx mcp-server-git" 

27""" 

28 

29# Future 

30from __future__ import annotations 

31 

32# Standard 

33import argparse 

34import asyncio 

35from contextlib import suppress 

36from enum import Enum 

37import logging 

38import os 

39import shlex 

40import signal 

41import sys 

42from typing import Any, cast, Dict, List, Optional 

43from urllib.parse import urljoin, urlparse 

44import uuid 

45 

46# Third-Party 

47import orjson 

48 

49try: 

50 # Third-Party 

51 import httpx 

52except ImportError: 

53 httpx = None # type: ignore[assignment] 

54 

55try: 

56 # Third-Party 

57 import websockets 

58except ImportError: 

59 websockets = None # type: ignore[assignment] 

60 

61 

62try: 

63 # Third-Party 

64 import yaml 

65except ImportError: 

66 yaml = None # type: ignore[assignment] 

67 

68# First-Party 

69from mcpgateway.services.logging_service import LoggingService 

70 

# ``websockets`` is an optional dependency, so the client connection object is
# typed loosely instead of importing its protocol class.
WSClientProtocol = Any  # type: ignore[assignment]

# Module logger, obtained from the gateway's shared logging service.
logging_service = LoggingService()
LOGGER = logging_service.get_logger("mcpgateway.reverse_proxy")

# Environment variables recognised by the reverse proxy (see module docstring).
ENV_GATEWAY: str = "REVERSE_PROXY_GATEWAY"
ENV_TOKEN: str = "REVERSE_PROXY_TOKEN"  # nosec B105 - environment variable name, not a secret
ENV_RECONNECT_DELAY: str = "REVERSE_PROXY_RECONNECT_DELAY"
ENV_MAX_RETRIES: str = "REVERSE_PROXY_MAX_RETRIES"
ENV_LOG_LEVEL: str = "REVERSE_PROXY_LOG_LEVEL"

# Built-in fallbacks used when nothing else configures a setting.
DEFAULT_RECONNECT_DELAY: float = 1.0  # seconds, initial back-off
DEFAULT_MAX_RETRIES: int = 0  # 0 means retry forever
DEFAULT_KEEPALIVE_INTERVAL: int = 30  # seconds between heartbeats
DEFAULT_REQUEST_TIMEOUT: int = 90  # seconds

90 

91 

class ConnectionState(Enum):
    """Lifecycle states tracked in ``ReverseProxyClient.state``.

    Examples:
        >>> ConnectionState.DISCONNECTED.value
        'disconnected'
        >>> ConnectionState.CONNECTED.value
        'connected'
        >>> ConnectionState.CONNECTING.value
        'connecting'
        >>> ConnectionState("shutting_down") is ConnectionState.SHUTTING_DOWN
        True
    """

    DISCONNECTED = "disconnected"  # idle; connect() only proceeds from this state
    CONNECTING = "connecting"  # connect() is in progress
    CONNECTED = "connected"  # gateway transport is up
    RECONNECTING = "reconnecting"  # waiting out the back-off delay before retrying
    SHUTTING_DOWN = "shutting_down"  # disconnect() is tearing everything down

109 

110 

class MessageType(Enum):
    """Envelope ``type`` values exchanged between the reverse proxy and the gateway.

    Examples:
        >>> MessageType.REGISTER.value
        'register'
        >>> MessageType.REQUEST.value
        'request'
        >>> MessageType.HEARTBEAT.value
        'heartbeat'
        >>> MessageType("notification") is MessageType.NOTIFICATION
        True
    """

    # --- control plane ---
    REGISTER = "register"  # announces the local server to the gateway
    UNREGISTER = "unregister"  # sent when the proxy disconnects
    HEARTBEAT = "heartbeat"  # keepalive; the proxy also answers gateway heartbeats
    ERROR = "error"  # error reported by the gateway

    # --- MCP traffic (JSON-RPC payloads) ---
    REQUEST = "request"  # gateway -> local server
    RESPONSE = "response"  # local server -> gateway
    NOTIFICATION = "notification"  # local server -> gateway

133 

134 

class StdioProcess:
    """Manage a local MCP server subprocess that talks JSON-RPC over stdio.

    Each non-empty line the subprocess writes to stdout is decoded and handed
    to every registered message handler, in registration order.
    """

    # Buffer limit (bytes) for the subprocess' stdout reader. asyncio's default
    # of 64 KiB is easily exceeded by a single JSON-RPC response, and
    # StreamReader.readline() raises ValueError for any longer line.
    STDOUT_LIMIT = 16 * 1024 * 1024

    def __init__(self, command: str):
        """Initialize stdio process manager.

        Args:
            command: The command to run as a subprocess. It is split with
                ``shlex.split`` and executed directly (no shell).
        """
        self.command = command
        self.process: Optional[asyncio.subprocess.Process] = None
        self._stdout_reader_task: Optional[asyncio.Task] = None
        self._message_handlers: List[Any] = []

    async def start(self) -> None:
        """Start the stdio subprocess and the background stdout reader.

        Calling start() while the subprocess is still running is a no-op, so a
        caller that starts the server on every (re)connect does not spawn
        duplicate, orphaned processes.

        Raises:
            RuntimeError: If subprocess creation fails with stdio.
        """
        current = self.process
        if current is not None and current.returncode is None:
            LOGGER.info(f"Local MCP server already running (PID: {current.pid})")
            return

        LOGGER.info(f"Starting local MCP server: {self.command}")

        self.process = await asyncio.create_subprocess_exec(
            *shlex.split(self.command),
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=sys.stderr,  # Pass through for debugging
            limit=self.STDOUT_LIMIT,
        )

        if not self.process.stdin or not self.process.stdout:
            raise RuntimeError(f"Failed to create subprocess with stdio: {self.command}")

        # Start reading stdout
        self._stdout_reader_task = asyncio.create_task(self._read_stdout())
        LOGGER.info(f"Local MCP server started (PID: {self.process.pid})")

    async def stop(self) -> None:
        """Stop the stdio subprocess gracefully.

        Cancels the stdout reader, then calls terminate(), waits up to 5
        seconds, and falls back to kill(). Safe to call when the process has
        already exited or stop() has already run.
        """
        process = self.process
        if not process:
            return

        LOGGER.info(f"Stopping local MCP server (PID: {process.pid})")

        # Cancel stdout reader
        reader = self._stdout_reader_task
        self._stdout_reader_task = None
        if reader:
            reader.cancel()
            with suppress(asyncio.CancelledError):
                await reader

        # Already gone (crashed, or stopped earlier): signalling it again
        # would raise ProcessLookupError.
        if process.returncode is not None:
            return

        # Terminate process
        with suppress(ProcessLookupError):
            process.terminate()
        with suppress(asyncio.TimeoutError):
            await asyncio.wait_for(process.wait(), timeout=5)

        # Force kill if needed
        if process.returncode is None:
            LOGGER.warning("Force killing subprocess")
            with suppress(ProcessLookupError):
                process.kill()
            await process.wait()

    async def send(self, message: str) -> None:
        """Send a message to the subprocess stdin.

        Args:
            message: JSON-RPC message to send (a trailing newline is added).

        Raises:
            RuntimeError: If subprocess is not running.
        """
        if not self.process or not self.process.stdin:
            raise RuntimeError("Subprocess not running")

        LOGGER.debug(f"→ stdio: {message[:200]}...")
        self.process.stdin.write((message + "\n").encode())
        await self.process.stdin.drain()

    def add_message_handler(self, handler) -> None:
        """Add a handler for messages from stdout.

        Args:
            handler: Async function called with each decoded stdout line.
        """
        self._message_handlers.append(handler)

    async def _read_stdout(self) -> None:
        """Read messages from subprocess stdout until EOF.

        A malformed line (invalid UTF-8, or longer than STDOUT_LIMIT) is
        logged and skipped instead of stopping the reader, so one bad line
        cannot silently cut the server off from its handlers.

        Raises:
            asyncio.CancelledError: When the read task is cancelled.
        """
        if not self.process or not self.process.stdout:
            return
        stdout = self.process.stdout

        try:
            while True:
                try:
                    line = await stdout.readline()
                except ValueError as e:
                    # Line exceeded STDOUT_LIMIT; asyncio has already discarded
                    # the buffered data, so move on to the next line.
                    LOGGER.error(f"Dropping oversized stdout line: {e}")
                    continue
                if not line:
                    LOGGER.warning("Local MCP server closed its stdout")
                    break

                # errors="replace" keeps the reader alive on non-UTF-8 output.
                message = line.decode(errors="replace").strip()
                if not message:
                    continue

                LOGGER.debug(f"← stdio: {message[:200]}...")

                # Notify handlers
                for handler in self._message_handlers:
                    try:
                        await handler(message)
                    except Exception as e:
                        LOGGER.error(f"Handler error: {e}")

        except asyncio.CancelledError:  # pylint: disable=try-except-raise
            raise
        except Exception as e:
            LOGGER.error(f"Error reading stdout: {e}")

251 

252 

class ReverseProxyClient:
    """Reverse proxy client that bridges a local stdio MCP server to a remote gateway.

    The client starts the local server, opens an outbound connection to the
    gateway, registers, and then relays traffic in both directions. Gateway
    ``request`` envelopes are written to the server's stdin. Every line the
    server prints is wrapped in a ``response``/``notification`` envelope and
    sent to the gateway.
    """

    # Seconds _register() waits for the local server to answer the client's own
    # ``initialize`` request before registering with the gateway anyway.
    INIT_TIMEOUT = 30.0

    def __init__(
        self,
        gateway_url: str,
        local_command: str,
        token: Optional[str] = None,
        reconnect_delay: float = DEFAULT_RECONNECT_DELAY,
        max_retries: int = DEFAULT_MAX_RETRIES,
        keepalive_interval: float = DEFAULT_KEEPALIVE_INTERVAL,
    ):
        """Initialize reverse proxy client.

        Args:
            gateway_url: Remote gateway URL.
            local_command: Local MCP server command.
            token: Optional bearer token for authentication.
            reconnect_delay: Initial reconnection delay in seconds.
            max_retries: Maximum reconnection attempts (0 = infinite).
            keepalive_interval: Heartbeat interval in seconds.
        """
        self.gateway_url = gateway_url
        self.local_command = local_command
        self.token = token
        self.reconnect_delay = reconnect_delay
        self.max_retries = max_retries
        self.keepalive_interval = keepalive_interval

        # ws/wss/http/https use the WebSocket transport; any other scheme
        # selects the (not yet implemented) SSE transport.
        parsed = urlparse(gateway_url)
        self.use_websocket = parsed.scheme in ("ws", "wss", "http", "https")

        # Connection state
        self.state = ConnectionState.DISCONNECTED
        self.connection: Optional[WSClientProtocol] = None
        self.session_id = uuid.uuid4().hex
        self.retry_count = 0

        # Components
        self.stdio_process = StdioProcess(local_command)
        self.stdio_process.add_message_handler(self._handle_stdio_message)

        # Background tasks of the current session
        self._keepalive_task: Optional[asyncio.Task] = None
        self._receive_task: Optional[asyncio.Task] = None

        # Requests this client itself sent to the local server, keyed by
        # JSON-RPC id. Their responses resolve the future and are not
        # forwarded to the gateway.
        self._pending_requests: Dict[Any, asyncio.Future] = {}

    async def connect(self) -> None:
        """Establish connection to remote gateway.

        Leftovers from a previous session are released first: background
        tasks, the gateway connection and the local server process. Repeated
        calls from run_with_reconnect() therefore never leak subprocesses or
        tasks. A failed attempt cleans up everything it started before the
        exception propagates.

        Raises:
            Exception: If connection fails.
        """
        if self.state != ConnectionState.DISCONNECTED:
            return

        self.state = ConnectionState.CONNECTING

        try:
            await self._teardown_session()

            # Start local server first
            await self.stdio_process.start()

            # Connect to gateway
            if self.use_websocket:
                await self._connect_websocket()
            else:
                await self._connect_sse()

            # disconnect() may have run while we were awaiting above; do not
            # resurrect a session that is being shut down.
            if self.state != ConnectionState.CONNECTING:
                raise RuntimeError("Connection attempt aborted by shutdown")

            self.state = ConnectionState.CONNECTED
            self.retry_count = 0

            # Register with gateway
            await self._register()

            # Start keepalive
            self._keepalive_task = asyncio.create_task(self._keepalive_loop())

            LOGGER.info(f"Connected to gateway: {self.gateway_url}")

        except Exception as e:
            LOGGER.error(f"Connection failed: {e}")
            await self._teardown_session()
            if self.state != ConnectionState.SHUTTING_DOWN:
                self.state = ConnectionState.DISCONNECTED
            raise

    async def _cancel_background_tasks(self) -> None:
        """Cancel the keepalive and receive tasks and wait for them to finish."""
        tasks = [task for task in (self._keepalive_task, self._receive_task) if task is not None]
        self._keepalive_task = None
        self._receive_task = None
        for task in tasks:
            task.cancel()
        for task in tasks:
            with suppress(asyncio.CancelledError, Exception):
                await task

    async def _close_connection(self) -> None:
        """Close and forget the gateway connection, logging (not raising) errors."""
        conn = self.connection
        self.connection = None
        if conn is None:
            return
        try:
            await cast(Any, conn).close()
        except Exception as e:
            LOGGER.warning(f"Error closing gateway connection: {e}")

    async def _teardown_session(self) -> None:
        """Release everything a previous connection attempt may have left behind.

        Errors are logged rather than raised, so this is safe to call from
        cleanup paths.
        """
        await self._cancel_background_tasks()
        await self._close_connection()
        try:
            await self.stdio_process.stop()
        except Exception as e:  # stop() may fail if the process already exited
            LOGGER.warning(f"Error stopping local MCP server: {e}")

    def _websocket_url(self) -> str:
        """Derive the reverse-proxy WebSocket endpoint from ``gateway_url``.

        ``http``/``https`` become ``ws``/``wss``, and a URL without a scheme is
        treated as ``wss``. If the path does not already contain
        ``/reverse-proxy``, then ``/reverse-proxy/ws`` is appended to the
        existing path. This keeps the path prefix of a gateway served under a
        sub-path.

        Returns:
            The WebSocket URL to connect to.

        Examples:
            >>> client = ReverseProxyClient.__new__(ReverseProxyClient)
            >>> client.gateway_url = "https://gw.example.com"
            >>> client._websocket_url()
            'wss://gw.example.com/reverse-proxy/ws'
            >>> client.gateway_url = "http://localhost:4444/gateway/"
            >>> client._websocket_url()
            'ws://localhost:4444/gateway/reverse-proxy/ws'
            >>> client.gateway_url = "wss://gw.example.com/reverse-proxy/ws"
            >>> client._websocket_url()
            'wss://gw.example.com/reverse-proxy/ws'
        """
        url = self.gateway_url
        if "://" not in url:
            url = f"wss://{url}"
        parsed = urlparse(url)
        scheme = {"http": "ws", "https": "wss"}.get(parsed.scheme, parsed.scheme)
        path = parsed.path
        if "/reverse-proxy" not in path:
            path = path.rstrip("/") + "/reverse-proxy/ws"
        return parsed._replace(scheme=scheme, path=path).geturl()

    async def _open_websocket(self, ws_url: str, headers: Dict[str, str]) -> Any:
        """Open the WebSocket, supporting both websockets header-keyword spellings.

        Args:
            ws_url: WebSocket URL to connect to.
            headers: Extra HTTP headers for the opening handshake.

        Returns:
            The connected websocket client.

        Raises:
            TypeError: If ``websockets.connect`` rejects the arguments for a
                reason other than the header keyword name.
        """
        ws_module = cast(Any, websockets)
        options: Dict[str, Any] = {"ping_interval": 20, "ping_timeout": 10}
        try:
            return await ws_module.connect(ws_url, extra_headers=headers, **options)
        except TypeError as e:
            if "extra_headers" not in str(e):
                raise
        # websockets >= 14 defaults to the new asyncio client, where the
        # keyword is called additional_headers.
        return await ws_module.connect(ws_url, additional_headers=headers, **options)

    async def _connect_websocket(self) -> None:
        """Connect via WebSocket.

        Raises:
            ImportError: If websockets package is not installed.
        """
        if not websockets:
            raise ImportError("websockets package required for WebSocket support")

        ws_url = self._websocket_url()

        # Build headers
        headers: Dict[str, str] = {}
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"
        headers["X-Session-ID"] = self.session_id

        LOGGER.info(f"Connecting to WebSocket: {ws_url}")

        self.connection = await self._open_websocket(ws_url, headers)

        # Start receiving messages
        self._receive_task = asyncio.create_task(self._receive_websocket())

    async def _connect_sse(self) -> None:
        """Connect via SSE (fallback).

        Raises:
            ImportError: If httpx package is not installed.
            NotImplementedError: SSE transport not yet implemented.
        """
        if not httpx:
            raise ImportError("httpx package required for SSE support")

        # SSE implementation would establish SSE connection
        # and use POST endpoint for sending messages
        raise NotImplementedError("SSE transport not yet implemented")

    async def _register(self) -> None:
        """Initialize the local server, then register it with the gateway.

        The client sends an ``initialize`` request to the local server and
        waits up to INIT_TIMEOUT seconds for the matching response. That
        response is consumed here and not forwarded to the gateway.
        Registration is sent even if the server does not answer in time.
        """
        request_id = "init-" + uuid.uuid4().hex
        init_request = {
            "jsonrpc": "2.0",
            "id": request_id,
            "method": "initialize",
            "params": {"protocolVersion": "2024-11-05", "capabilities": {}, "clientInfo": {"name": "reverse-proxy", "version": "1.0.0"}},
        }

        reply_future: asyncio.Future = asyncio.get_running_loop().create_future()
        self._pending_requests[request_id] = reply_future
        try:
            await self.stdio_process.send(orjson.dumps(init_request).decode())
            try:
                reply = await asyncio.wait_for(reply_future, timeout=self.INIT_TIMEOUT)
            except asyncio.TimeoutError:
                LOGGER.warning(f"Local MCP server did not answer initialize within {self.INIT_TIMEOUT}s; registering anyway")
            else:
                if isinstance(reply, dict) and "error" in reply:
                    LOGGER.warning(f"Local MCP server rejected initialize: {reply['error']}")
        finally:
            self._pending_requests.pop(request_id, None)

        # Send registration to gateway
        register_msg = {
            "type": MessageType.REGISTER.value,
            "sessionId": self.session_id,
            "server": {
                "name": f"reverse-proxy-{self.session_id[:8]}",
                "description": f"Reverse proxied: {self.local_command}",
                "protocol": "stdio",
            },
        }

        await self._send_to_gateway(orjson.dumps(register_msg).decode())

    async def _send_to_gateway(self, message: str) -> None:
        """Send message to remote gateway.

        Args:
            message: Message to send.

        Raises:
            RuntimeError: If not connected to gateway.
            NotImplementedError: If SSE transport is used (not implemented).
        """
        conn = self.connection
        if not conn:
            raise RuntimeError("Not connected to gateway")

        if self.use_websocket:
            await cast(Any, conn).send(message)
        else:
            # SSE would POST to message endpoint
            raise NotImplementedError("SSE transport not yet implemented")

    async def _handle_stdio_message(self, message: str) -> None:
        """Handle message from local stdio server.

        Responses to requests issued by this client itself (see _register)
        resolve the waiting future. Everything else is wrapped in a
        ``response`` envelope (payload has an ``id``) or a ``notification``
        envelope and forwarded to the gateway.

        Args:
            message: JSON-RPC message from stdio.
        """
        try:
            data = orjson.loads(message)

            if isinstance(data, dict) and "method" not in data:
                request_id = data.get("id")
                if isinstance(request_id, (str, int)) and request_id in self._pending_requests:
                    waiter = self._pending_requests.pop(request_id)
                    if not waiter.done():
                        waiter.set_result(data)
                    return

            # Wrap in reverse proxy envelope
            envelope = {"type": MessageType.RESPONSE.value if "id" in data else MessageType.NOTIFICATION.value, "sessionId": self.session_id, "payload": data}

            # Forward to gateway
            await self._send_to_gateway(orjson.dumps(envelope).decode())

        except Exception as e:
            LOGGER.error(f"Error forwarding stdio message: {e}")

    async def _receive_websocket(self) -> None:
        """Receive messages from the WebSocket connection until it ends.

        If the loop ends while the session is CONNECTED, the state drops to
        DISCONNECTED so that run_with_reconnect() reconnects. Other states are
        deliberately left alone. CONNECTING means a new attempt is tearing
        this stale task down. SHUTTING_DOWN means disconnect() is cancelling
        it, and overwriting that state would trigger a reconnect during
        shutdown.
        """
        if not self.connection:
            return

        try:
            conn = cast(Any, self.connection)
            async for message in conn:
                await self._handle_gateway_message(message)
        except Exception as e:  # Catch broad exceptions to avoid dependency-specific attribute errors
            closed_exc = None
            if websockets is not None:
                ex_mod = getattr(websockets, "exceptions", None)
                if ex_mod is not None:
                    closed_exc = getattr(ex_mod, "ConnectionClosed", None)
            if closed_exc and isinstance(e, closed_exc):
                LOGGER.warning("WebSocket connection closed")
            else:
                LOGGER.error(f"WebSocket receive error: {e}")
        finally:
            if self.state == ConnectionState.CONNECTED:
                self.state = ConnectionState.DISCONNECTED

    async def _handle_gateway_message(self, message: str) -> None:
        """Handle message from remote gateway.

        Args:
            message: Message from gateway.
        """
        try:
            data = orjson.loads(message)
            msg_type = data.get("type")

            if msg_type == MessageType.REQUEST.value:
                # Forward request to local server
                payload = data.get("payload", {})
                await self.stdio_process.send(orjson.dumps(payload).decode())

            elif msg_type == MessageType.HEARTBEAT.value:
                # Respond to heartbeat
                pong = {
                    "type": MessageType.HEARTBEAT.value,
                    "sessionId": self.session_id,
                }
                await self._send_to_gateway(orjson.dumps(pong).decode())

            elif msg_type == MessageType.ERROR.value:
                LOGGER.error(f"Gateway error: {data.get('message', 'Unknown error')}")

            else:
                LOGGER.warning(f"Unknown message type: {msg_type}")

        except Exception as e:
            LOGGER.error(f"Error handling gateway message: {e}")

    async def _keepalive_loop(self) -> None:
        """Send periodic heartbeats while the session is CONNECTED.

        A failed send means the gateway connection is unusable. The session
        is then marked DISCONNECTED so that run_with_reconnect() reconnects,
        instead of waiting on a dead connection.
        """
        try:
            while self.state == ConnectionState.CONNECTED:
                await asyncio.sleep(self.keepalive_interval)
                if self.state != ConnectionState.CONNECTED:
                    break

                heartbeat = {
                    "type": MessageType.HEARTBEAT.value,
                    "sessionId": self.session_id,
                }

                try:
                    await self._send_to_gateway(orjson.dumps(heartbeat).decode())
                except Exception as e:
                    LOGGER.warning(f"Keepalive failed: {e}")
                    if self.state == ConnectionState.CONNECTED:
                        self.state = ConnectionState.DISCONNECTED
                    break

        except asyncio.CancelledError:
            pass

    async def disconnect(self) -> None:
        """Disconnect from gateway and stop local server.

        Background tasks are cancelled and awaited, a best-effort
        ``unregister`` is sent, and the connection is closed. The local server
        is stopped even if closing the connection fails.
        """
        if self.state == ConnectionState.SHUTTING_DOWN:
            return

        self.state = ConnectionState.SHUTTING_DOWN
        LOGGER.info("Disconnecting reverse proxy...")

        # Cancel tasks
        await self._cancel_background_tasks()

        # Send unregister message
        if self.connection:
            try:
                unregister = {
                    "type": MessageType.UNREGISTER.value,
                    "sessionId": self.session_id,
                }
                await self._send_to_gateway(orjson.dumps(unregister).decode())
            except Exception:
                pass  # nosec B110 - Intentionally swallow errors during cleanup

        # Close connection
        await self._close_connection()

        # Stop local server
        try:
            await self.stdio_process.stop()
        except Exception as e:  # stop() may fail if the process already exited
            LOGGER.warning(f"Error stopping local MCP server: {e}")

        self.state = ConnectionState.DISCONNECTED
        LOGGER.info("Reverse proxy disconnected")

    async def _pause_before_reconnect(self, delay: float) -> None:
        """Sleep ``delay`` seconds in the RECONNECTING state.

        The state only returns to DISCONNECTED if nothing changed it during
        the sleep, so a shutdown requested while waiting is not overwritten.

        Args:
            delay: Seconds to wait.
        """
        if self.state == ConnectionState.SHUTTING_DOWN:
            return
        self.state = ConnectionState.RECONNECTING
        await asyncio.sleep(delay)
        if self.state == ConnectionState.RECONNECTING:
            self.state = ConnectionState.DISCONNECTED

    async def run_with_reconnect(self) -> None:
        """Run the reverse proxy with automatic reconnection.

        Failed connection attempts back off exponentially
        (``reconnect_delay * 2**retry_count``, capped at 60 s). The loop stops
        after ``max_retries`` consecutive failures when ``max_retries > 0``.
        A session that drops after connecting successfully is retried after
        ``reconnect_delay`` seconds. The loop exits once disconnect() has
        started.
        """
        while self.state != ConnectionState.SHUTTING_DOWN:
            try:
                await self.connect()

                # Wait for disconnection
                while self.state == ConnectionState.CONNECTED:
                    await asyncio.sleep(1)

                if self.state == ConnectionState.SHUTTING_DOWN:
                    break

                # The session ended without an error. Pause briefly so a
                # gateway that accepts and immediately drops connections
                # does not cause a tight reconnect loop.
                LOGGER.warning(f"Gateway connection lost; reconnecting in {self.reconnect_delay}s")
                await self._pause_before_reconnect(self.reconnect_delay)
                continue

            except Exception as e:
                LOGGER.error(f"Connection error: {e}")

            if self.state == ConnectionState.SHUTTING_DOWN:
                break

            # Check retry limit
            self.retry_count += 1
            if self.max_retries > 0 and self.retry_count >= self.max_retries:
                LOGGER.error(f"Max retries ({self.max_retries}) exceeded")
                break

            # Calculate backoff delay
            delay = min(self.reconnect_delay * (2**self.retry_count), 60)
            LOGGER.info(f"Reconnecting in {delay}s (attempt {self.retry_count})")

            await self._pause_before_reconnect(delay)

600 

601 

def _env_number(parser: argparse.ArgumentParser, name: str, convert: type) -> Any:
    """Read a numeric setting from environment variable *name*.

    Args:
        parser: Parser used to report an invalid value (exits the program).
        name: Environment variable to read.
        convert: ``int`` or ``float``.

    Returns:
        The converted value, or None when the variable is unset or blank.
    """
    raw = (os.getenv(name) or "").strip()
    if not raw:
        return None
    value: Any = None
    try:
        value = convert(raw)
    except ValueError:
        parser.error(f"{name} must be a valid {convert.__name__}, got {raw!r}")
    return value


def _load_config(parser: argparse.ArgumentParser, path: str) -> Dict[str, Any]:
    """Load a YAML (``.yaml``/``.yml``) or JSON configuration file.

    PyYAML is only required for YAML files; JSON files are parsed with orjson.

    Args:
        parser: Parser used to report errors (exits the program).
        path: Path of the configuration file.

    Returns:
        The top-level mapping from the file.
    """
    is_yaml = path.endswith((".yaml", ".yml"))
    if is_yaml and not yaml:
        parser.error("PyYAML package required for configuration file support")

    config: Any = None
    try:
        with open(path, "r", encoding="utf-8") as f:
            config = cast(Any, yaml).safe_load(f) if is_yaml else orjson.loads(f.read())
    except (OSError, ValueError) as e:
        parser.error(f"Cannot load configuration file {path}: {e}")

    if not isinstance(config, dict):
        parser.error("Configuration file must contain a JSON/YAML object at the top level")
    return config


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse command line arguments.

    Each setting is resolved with the precedence: command line, then
    environment variables (``REVERSE_PROXY_*``), then the ``--config`` file,
    then built-in defaults. ``--verbose`` always forces DEBUG logging.

    Args:
        argv: Command line arguments (default: sys.argv[1:]).

    Returns:
        Parsed arguments.

    Examples:
        >>> import os
        >>> os.environ['REVERSE_PROXY_GATEWAY'] = 'https://example.com'
        >>> args = parse_args(['--local-stdio', 'mcp-server'])
        >>> args.local_stdio
        'mcp-server'
        >>> args.gateway
        'https://example.com'
        >>> args.log_level
        'INFO'
        >>> args = parse_args(['--local-stdio', 'mcp-server', '--verbose'])
        >>> args.log_level
        'DEBUG'
        >>> args = parse_args(['--local-stdio', 'mcp-server', '--max-retries', '5'])
        >>> args.max_retries
        5
    """
    log_levels = ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL")

    parser = argparse.ArgumentParser(
        prog="mcpgateway.reverse_proxy",
        description="Bridge local MCP servers to remote gateways",
    )

    # Required arguments
    parser.add_argument(
        "--local-stdio",
        required=True,
        help="Local MCP server command to run via stdio",
    )

    parser.add_argument(
        "--gateway",
        help="Remote gateway URL (can also use REVERSE_PROXY_GATEWAY env var)",
    )

    # Authentication
    parser.add_argument(
        "--token",
        help="Bearer token for authentication (can also use REVERSE_PROXY_TOKEN env var)",
    )

    # Connection options. Defaults are None so that "not given on the command
    # line" can be told apart from an explicit value. The real defaults are
    # applied after the environment and the config file have been consulted.
    parser.add_argument(
        "--reconnect-delay",
        type=float,
        default=None,
        help=f"Initial reconnection delay in seconds (can also use {ENV_RECONNECT_DELAY} env var; default: {DEFAULT_RECONNECT_DELAY})",
    )

    parser.add_argument(
        "--max-retries",
        type=int,
        default=None,
        help=f"Maximum reconnection attempts, 0=infinite (can also use {ENV_MAX_RETRIES} env var; default: {DEFAULT_MAX_RETRIES})",
    )

    parser.add_argument(
        "--keepalive",
        type=int,
        default=None,
        help=f"Keepalive interval in seconds (default: {DEFAULT_KEEPALIVE_INTERVAL})",
    )

    # Configuration file
    parser.add_argument(
        "--config",
        help="Configuration file (YAML or JSON)",
    )

    # Logging
    parser.add_argument(
        "--log-level",
        choices=list(log_levels),
        default=None,
        help=f"Log level (can also use {ENV_LOG_LEVEL} env var; default: INFO)",
    )

    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose logging (same as --log-level DEBUG)",
    )

    args = parser.parse_args(argv)

    # 1. Environment variables fill anything not given on the command line.
    if not args.gateway:
        args.gateway = os.getenv(ENV_GATEWAY)
    if not args.token:
        args.token = os.getenv(ENV_TOKEN)
    if args.reconnect_delay is None:
        args.reconnect_delay = _env_number(parser, ENV_RECONNECT_DELAY, float)
    if args.max_retries is None:
        args.max_retries = _env_number(parser, ENV_MAX_RETRIES, int)
    if args.log_level is None:
        args.log_level = os.getenv(ENV_LOG_LEVEL) or None

    # 2. The configuration file fills whatever is still unset. Keys may be
    #    written with dashes ("reconnect-delay") or underscores.
    if args.config:
        for key, value in _load_config(parser, args.config).items():
            attr = str(key).replace("-", "_")
            if getattr(args, attr, None) is None:
                setattr(args, attr, value)

    # 3. Built-in defaults.
    for attr, default in (
        ("reconnect_delay", DEFAULT_RECONNECT_DELAY),
        ("max_retries", DEFAULT_MAX_RETRIES),
        ("keepalive", DEFAULT_KEEPALIVE_INTERVAL),
        ("log_level", "INFO"),
    ):
        if getattr(args, attr) is None:
            setattr(args, attr, default)

    # Handle verbose flag; otherwise validate levels that bypassed argparse's
    # choices (environment or config file).
    if args.verbose:
        args.log_level = "DEBUG"
    else:
        args.log_level = str(args.log_level).upper()
        if args.log_level not in log_levels:
            parser.error(f"invalid log level {args.log_level!r}; choose from {', '.join(log_levels)}")

    # Checked last so that a gateway from the config file is accepted.
    if not args.gateway:
        parser.error("--gateway or REVERSE_PROXY_GATEWAY environment variable required")

    return args

730 

731 

async def main(argv: Optional[List[str]] = None) -> None:
    """Main entry point for reverse proxy.

    Runs the client until one of two things happens: a SIGINT/SIGTERM
    arrives, or the reconnect loop gives up on its own (for example when
    ``--max-retries`` is exhausted). It then disconnects cleanly.

    Args:
        argv: Command line arguments.
    """
    args = parse_args(argv)

    # Configure logging
    logging.basicConfig(
        level=getattr(logging, args.log_level),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%dT%H:%M:%S",
        stream=sys.stderr,
    )

    # Create reverse proxy client
    client = ReverseProxyClient(
        gateway_url=args.gateway,
        local_command=args.local_stdio,
        token=args.token,
        reconnect_delay=args.reconnect_delay,
        max_retries=args.max_retries,
        keepalive_interval=args.keepalive,
    )

    # Handle shutdown signals
    shutdown_event = asyncio.Event()

    def signal_handler(*_args: object) -> None:
        """Handle shutdown signals gracefully.

        Args:
            *_args: Signal handler positional arguments (ignored).
        """
        LOGGER.info("Shutdown signal received")
        shutdown_event.set()

    # Register signal handlers (not supported on every platform/loop)
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        with suppress(NotImplementedError):
            loop.add_signal_handler(sig, signal_handler)

    # Run client with reconnection
    client_task = asyncio.create_task(client.run_with_reconnect())
    shutdown_task = asyncio.create_task(shutdown_event.wait())

    try:
        # Also stop when the reconnect loop returns by itself; waiting for the
        # shutdown signal alone would leave the process hanging forever.
        await asyncio.wait({client_task, shutdown_task}, return_when=asyncio.FIRST_COMPLETED)
        if client_task.done() and not shutdown_event.is_set():
            LOGGER.warning("Reverse proxy client stopped; shutting down")
    finally:
        shutdown_task.cancel()
        with suppress(asyncio.CancelledError):
            await shutdown_task

        # Clean shutdown
        await client.disconnect()
        client_task.cancel()
        with suppress(asyncio.CancelledError):
            await client_task

788 

789 

def run() -> None:
    """Console-script entry point: run :func:`main` on a fresh event loop.

    Ctrl+C prints a short notice and exits with status 0. Any other error is
    printed to stderr and exits with status 1. A normal return from main()
    simply returns.
    """
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        notice, status = "\nShutdown complete", 0
    except Exception as exc:  # top-level boundary: report and exit non-zero
        notice, status = f"Error: {exc}", 1
    else:
        return
    print(notice, file=sys.stderr)
    sys.exit(status)


if __name__ == "__main__":
    run()