Coverage for mcpgateway / transports / sse_transport.py: 100%

253 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/transports/sse_transport.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7SSE Transport Implementation. 

8This module implements Server-Sent Events (SSE) transport for MCP, 

9providing server-to-client streaming with proper session management. 

10""" 

11 

12# Standard 

13import asyncio 

14from collections import deque 

15import logging 

16import time 

17from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, Optional 

18import uuid 

19 

20# Third-Party 

21import anyio 

22from anyio._backends._asyncio import CancelScope 

23from fastapi import Request 

24import orjson 

25from sse_starlette.sse import EventSourceResponse as BaseEventSourceResponse 

26from starlette.types import Receive, Scope, Send 

27 

28# First-Party 

29from mcpgateway.config import settings 

30from mcpgateway.services.logging_service import LoggingService 

31from mcpgateway.transports.base import Transport 

32 

# Initialize logging service first: the anyio patch helpers below log via
# `logger` during module import, so this must precede them.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

36 

37 

38# ============================================================================= 

39# EXPERIMENTAL WORKAROUND: anyio _deliver_cancellation spin loop (anyio#695) 

40# ============================================================================= 

41# anyio's _deliver_cancellation can spin at 100% CPU when tasks don't respond 

42# to CancelledError. This optional monkey-patch adds a max iteration limit to 

43# prevent indefinite spinning. After the limit is reached, we give up delivering 

44# cancellation and let the scope exit (tasks will be orphaned but won't spin). 

45# 

46# This workaround is DISABLED by default. Enable via: 

47# ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=true 

48# 

49# Trade-offs: 

50# - Prevents indefinite CPU spin (good) 

51# - May leave some tasks uncancelled after max iterations (usually harmless) 

52# - Worker recycling (GUNICORN_MAX_REQUESTS) cleans up orphaned tasks 

53# 

54# This workaround may be removed when anyio or MCP SDK fix the underlying issue. 

55# See: https://github.com/agronholm/anyio/issues/695 

56# ============================================================================= 

57 

# Store original for potential restoration and for the patch to call.
# Captured once at import time so remove_anyio_cancel_delivery_patch() can
# restore it and the patched wrapper can delegate to it.
_original_deliver_cancellation = CancelScope._deliver_cancellation  # type: ignore[attr-defined]  # pylint: disable=protected-access
# Module-level idempotency flag guarded by apply/remove functions below.
_patch_applied = False  # pylint: disable=invalid-name

61 

62 

def _create_patched_deliver_cancellation(max_iterations: int):  # noqa: C901
    """Build an iteration-bounded replacement for anyio's ``_deliver_cancellation``.

    Args:
        max_iterations: Maximum iterations before giving up cancellation delivery.

    Returns:
        Patched function that limits cancellation delivery iterations.
    """

    def _patched_deliver_cancellation(self: CancelScope, origin: CancelScope) -> bool:  # pylint: disable=protected-access
        """Bounded wrapper around the original ``_deliver_cancellation``.

        Counts delivery attempts on the originating cancel scope and, once the
        configured budget is exhausted, stops retrying instead of spinning at
        100% CPU on tasks that ignore CancelledError (anyio#695).

        Args:
            self: The cancel scope being processed.
            origin: The cancel scope that originated the cancellation.

        Returns:
            True if delivery should be retried, False if done or max iterations reached.
        """
        # Attempt counter lives on the origin scope (the scope that started
        # the cancellation), so all deliveries for one cancel share a budget.
        attempts = getattr(origin, "_delivery_iterations", 0) + 1  # type: ignore[attr-defined]  # pylint: disable=protected-access
        origin._delivery_iterations = attempts  # type: ignore[attr-defined]  # pylint: disable=protected-access

        if attempts <= max_iterations:
            # Budget remaining - delegate to anyio's real implementation.
            return _original_deliver_cancellation(self, origin)

        # Budget exhausted: warn and stop retrying to avoid an indefinite spin.
        logger.warning(
            "anyio cancel delivery exceeded %d iterations - giving up to prevent CPU spin. "
            "Some tasks may not have been properly cancelled. "
            "Disable with ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=false if this causes issues.",
            max_iterations,
        )
        # Dropping the cancel handle prevents anyio from scheduling more retries.
        if getattr(self, "_cancel_handle", None) is not None:  # pylint: disable=protected-access
            self._cancel_handle = None  # pylint: disable=protected-access
        return False

    return _patched_deliver_cancellation

111 

112 

def apply_anyio_cancel_delivery_patch() -> bool:
    """Install the bounded ``_deliver_cancellation`` monkey-patch if enabled in config.

    This function is idempotent - calling it multiple times has no additional effect.

    Returns:
        True if patch was applied (or already applied), False if disabled.
    """
    global _patch_applied  # pylint: disable=global-statement

    if _patch_applied:
        # Already installed by an earlier call - nothing to do.
        return True

    try:
        if not settings.anyio_cancel_delivery_patch_enabled:
            logger.debug("anyio _deliver_cancellation patch DISABLED. Enable with ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=true if you experience CPU spin loops.")
            return False

        limit = settings.anyio_cancel_delivery_max_iterations
        # Swap in the bounded wrapper on the CancelScope class itself.
        CancelScope._deliver_cancellation = _create_patched_deliver_cancellation(limit)  # type: ignore[method-assign]  # pylint: disable=protected-access
        _patch_applied = True

        logger.info(
            "anyio _deliver_cancellation patch ENABLED (max_iterations=%d). "
            "This is an experimental workaround for anyio#695. "
            "Disable with ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=false if it causes issues.",
            limit,
        )
        return True

    except Exception as e:
        # Config missing or assignment failed - run unpatched rather than crash.
        logger.warning("Failed to apply anyio _deliver_cancellation patch: %s", e)
        return False

147 

148 

def remove_anyio_cancel_delivery_patch() -> bool:
    """Remove the anyio ``_deliver_cancellation`` monkey-patch.

    Restores the original anyio implementation captured at import time.

    Returns:
        True if patch was removed, False if it wasn't applied.
    """
    global _patch_applied  # pylint: disable=global-statement

    if _patch_applied:
        try:
            # Put the saved original back on the CancelScope class.
            CancelScope._deliver_cancellation = _original_deliver_cancellation  # type: ignore[method-assign]  # pylint: disable=protected-access
            _patch_applied = False
            logger.info("anyio _deliver_cancellation patch removed - restored original implementation")
            return True
        except Exception as e:
            logger.warning("Failed to remove anyio _deliver_cancellation patch: %s", e)
    return False

170 

171 

# Apply patch at module load time if enabled. No-op (returns False) unless
# ANYIO_CANCEL_DELIVERY_PATCH_ENABLED is set; safe to call again later.
apply_anyio_cancel_delivery_patch()

174 

175 

176def _get_sse_cleanup_timeout() -> float: 

177 """Get SSE task group cleanup timeout from config. 

178 

179 This timeout controls how long to wait for SSE task group tasks to respond 

180 to cancellation before forcing cleanup. Prevents CPU spin loops in anyio's 

181 _deliver_cancellation when tasks don't properly handle CancelledError. 

182 

183 Returns: 

184 Cleanup timeout in seconds (default: 5.0) 

185 """ 

186 try: 

187 return settings.sse_task_group_cleanup_timeout 

188 except Exception: 

189 return 5.0 # Fallback default 

190 

191 

class EventSourceResponse(BaseEventSourceResponse):
    """Patched EventSourceResponse with CPU spin detection.

    This mitigates a CPU spin loop issue (anyio#695) where _deliver_cancellation
    spins at 100% CPU when tasks in the SSE task group don't respond to
    cancellation.

    Instead of trying to timeout the task group (which would affect normal
    SSE connections), we copy the __call__ method and add a deadline to
    the cancel scope to ensure cleanup doesn't hang indefinitely.

    See:
        - https://github.com/agronholm/anyio/issues/695
        - https://github.com/anthropics/claude-agent-sdk-python/issues/378
    """

    def enable_compression(self, force: bool = False) -> None:  # noqa: ARG002
        """Enable compression (no-op for SSE streams).

        SSE streams don't support compression as per sse_starlette.
        This override prevents NotImplementedError from parent class.

        Args:
            force: Ignored - compression not supported for SSE.
        """
        # Intentionally empty: overriding purely to neutralize the parent's
        # NotImplementedError.

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Handle SSE request with cancel scope deadline to prevent spin.

        This method is copied from sse_starlette with one key modification:
        the task group's cancel_scope gets a deadline set when cancellation
        starts, preventing indefinite spinning if tasks don't respond.

        NOTE(review): the task start order below mirrors sse_starlette's
        original __call__ - keep it unchanged when editing.

        Args:
            scope: ASGI scope dictionary.
            receive: ASGI receive callable.
            send: ASGI send callable.
        """
        # Copy of sse_starlette.sse.EventSourceResponse.__call__ with deadline fix
        async with anyio.create_task_group() as task_group:
            # The deadline is only assigned at cancellation time (below), so it
            # never affects a healthy, long-lived SSE connection.

            async def cancel_on_finish(coro: Callable[[], Awaitable[None]]) -> None:
                """Execute coroutine then cancel task group with bounded deadline.

                This wrapper runs the given coroutine and, upon completion, cancels
                the parent task group with a deadline to prevent indefinite spinning
                if other tasks don't respond to cancellation (anyio#695 mitigation).

                Args:
                    coro: Async callable to execute before triggering cancellation.
                """
                await coro()
                # When cancelling, set a deadline to prevent indefinite spin
                # if other tasks don't respond to cancellation
                task_group.cancel_scope.deadline = anyio.current_time() + _get_sse_cleanup_timeout()
                task_group.cancel_scope.cancel()

            # Whichever of these finishes first tears down the whole group.
            task_group.start_soon(cancel_on_finish, lambda: self._stream_response(send))
            task_group.start_soon(cancel_on_finish, lambda: self._ping(send))
            task_group.start_soon(cancel_on_finish, self._listen_for_exit_signal)

            if self.data_sender_callable:
                task_group.start_soon(self.data_sender_callable)

            # Wait for the client to disconnect last
            task_group.start_soon(cancel_on_finish, lambda: self._listen_for_disconnect(receive))

            if self.background is not None:
                await self.background()

264 

265 

266# Pre-computed SSE frame components for performance 

267_SSE_EVENT_PREFIX = b"event: " 

268_SSE_DATA_PREFIX = b"\r\ndata: " 

269_SSE_RETRY_PREFIX = b"\r\nretry: " 

270_SSE_FRAME_END = b"\r\n\r\n" 

271 

272 

273def _build_sse_frame(event: bytes, data: bytes, retry: int) -> bytes: 

274 """Build SSE frame as bytes to avoid encode/decode overhead. 

275 

276 Args: 

277 event: SSE event type as bytes (e.g., b'message', b'keepalive', b'error') 

278 data: JSON data as bytes (from orjson.dumps) 

279 retry: Retry timeout in milliseconds 

280 

281 Returns: 

282 Complete SSE frame as bytes 

283 

284 Note: 

285 Uses hardcoded CRLF (\\r\\n) separators matching sse_starlette's 

286 ServerSentEvent.DEFAULT_SEPARATOR. If custom separators are ever 

287 needed, this function would need to accept a sep parameter. 

288 

289 Examples: 

290 >>> _build_sse_frame(b"message", b'{"test": 1}', 15000) 

291 b'event: message\\r\\ndata: {"test": 1}\\r\\nretry: 15000\\r\\n\\r\\n' 

292 

293 >>> _build_sse_frame(b"keepalive", b"{}", 15000) 

294 b'event: keepalive\\r\\ndata: {}\\r\\nretry: 15000\\r\\n\\r\\n' 

295 """ 

296 return _SSE_EVENT_PREFIX + event + _SSE_DATA_PREFIX + data + _SSE_RETRY_PREFIX + str(retry).encode() + _SSE_FRAME_END 

297 

298 

class SSETransport(Transport):
    """Transport implementation using Server-Sent Events with proper session management.

    This transport implementation uses Server-Sent Events (SSE) for real-time
    communication between the MCP gateway and clients. It provides streaming
    capabilities with automatic session management and keepalive support.

    Examples:
        >>> # Create SSE transport with default URL
        >>> transport = SSETransport()
        >>> transport
        <mcpgateway.transports.sse_transport.SSETransport object at ...>

        >>> # Create SSE transport with custom URL
        >>> transport = SSETransport("http://localhost:8080")
        >>> transport._base_url
        'http://localhost:8080'

        >>> # Check initial connection state
        >>> import asyncio
        >>> asyncio.run(transport.is_connected())
        False

        >>> # Verify it's a proper Transport subclass
        >>> isinstance(transport, Transport)
        True
        >>> issubclass(SSETransport, Transport)
        True

        >>> # Check session ID generation
        >>> transport.session_id
        '...'
        >>> len(transport.session_id) > 0
        True

        >>> # Verify required methods exist
        >>> hasattr(transport, 'connect')
        True
        >>> hasattr(transport, 'disconnect')
        True
        >>> hasattr(transport, 'send_message')
        True
        >>> hasattr(transport, 'receive_message')
        True
        >>> hasattr(transport, 'is_connected')
        True
    """

    # Regex for validating session IDs (alphanumeric, hyphens, underscores, max 128 chars).
    # Lazily compiled on first use in _is_valid_session_id.
    _SESSION_ID_PATTERN = None

    @staticmethod
    def _is_valid_session_id(session_id: str) -> bool:
        """Validate session ID format for security.

        Valid session IDs are alphanumeric with hyphens and underscores, max 128 chars.
        This prevents injection attacks via malformed session IDs.

        Args:
            session_id: The session ID to validate.

        Returns:
            True if the session ID is valid, False otherwise.

        Examples:
            >>> SSETransport._is_valid_session_id("abc123")
            True
            >>> SSETransport._is_valid_session_id("abc-123_def")
            True
            >>> SSETransport._is_valid_session_id("a" * 128)
            True
            >>> SSETransport._is_valid_session_id("a" * 129)
            False
            >>> SSETransport._is_valid_session_id("")
            False
            >>> SSETransport._is_valid_session_id("abc/def")
            False
            >>> SSETransport._is_valid_session_id("abc:def")
            False
        """
        # Standard
        import re  # pylint: disable=import-outside-toplevel

        # Cheap length/emptiness check before touching the regex.
        if not session_id or len(session_id) > 128:
            return False

        # Lazy compile pattern (cached on the class after the first call)
        if SSETransport._SESSION_ID_PATTERN is None:
            SSETransport._SESSION_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+$")

        return bool(SSETransport._SESSION_ID_PATTERN.match(session_id))

    def __init__(self, base_url: Optional[str] = None):
        """Initialize SSE transport.

        Args:
            base_url: Base URL for client message endpoints

        Examples:
            >>> # Test default initialization
            >>> transport = SSETransport()
            >>> transport._connected
            False
            >>> transport._message_queue is not None
            True
            >>> transport._client_gone is not None
            True
            >>> len(transport._session_id) > 0
            True

            >>> # Test custom base URL
            >>> transport = SSETransport("https://api.example.com")
            >>> transport._base_url
            'https://api.example.com'

            >>> # Test session ID uniqueness
            >>> transport1 = SSETransport()
            >>> transport2 = SSETransport()
            >>> transport1.session_id != transport2.session_id
            True
        """
        self._base_url = base_url or f"http://{settings.host}:{settings.port}"
        self._connected = False
        # Outbound queue: send_message() puts, the SSE event generator gets.
        self._message_queue = asyncio.Queue()
        # Set once the client is gone (disconnect, error, or generator exit).
        self._client_gone = asyncio.Event()
        # Server generates session_id for SSE - client receives it in endpoint event
        self._session_id = str(uuid.uuid4())

        logger.info("Creating SSE transport with base_url=%s, session_id=%s", self._base_url, self._session_id)

    async def connect(self) -> None:
        """Set up SSE connection.

        Examples:
            >>> # Test connection setup
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport.connect())
            >>> transport._connected
            True
            >>> asyncio.run(transport.is_connected())
            True
        """
        self._connected = True
        logger.info("SSE transport connected: %s", self._session_id)

    async def disconnect(self) -> None:
        """Clean up SSE connection.

        Examples:
            >>> # Test disconnection
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport.connect())
            >>> asyncio.run(transport.disconnect())
            >>> transport._connected
            False
            >>> transport._client_gone.is_set()
            True
            >>> asyncio.run(transport.is_connected())
            False

            >>> # Test disconnection when already disconnected
            >>> transport = SSETransport()
            >>> asyncio.run(transport.disconnect())
            >>> transport._connected
            False
        """
        # Idempotent: only flip state and signal once.
        if self._connected:
            self._connected = False
            self._client_gone.set()
            logger.info("SSE transport disconnected: %s", self._session_id)

    async def send_message(self, message: Dict[str, Any]) -> None:
        """Send a message over SSE.

        Args:
            message: Message to send

        Raises:
            RuntimeError: If transport is not connected
            Exception: If unable to put message to queue

        Examples:
            >>> # Test sending message when connected
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport.connect())
            >>> message = {"jsonrpc": "2.0", "method": "test", "id": 1}
            >>> asyncio.run(transport.send_message(message))
            >>> transport._message_queue.qsize()
            1

            >>> # Test sending message when not connected
            >>> transport = SSETransport()
            >>> try:
            ...     asyncio.run(transport.send_message({"test": "message"}))
            ... except RuntimeError as e:
            ...     print("Expected error:", str(e))
            Expected error: Transport not connected

            >>> # Test message format validation
            >>> transport = SSETransport()
            >>> asyncio.run(transport.connect())
            >>> valid_message = {"jsonrpc": "2.0", "method": "initialize", "params": {}}
            >>> isinstance(valid_message, dict)
            True
            >>> "jsonrpc" in valid_message
            True

            >>> # Test exception handling in queue put
            >>> transport = SSETransport()
            >>> asyncio.run(transport.connect())
            >>> # Create a full queue to trigger exception
            >>> transport._message_queue = asyncio.Queue(maxsize=1)
            >>> asyncio.run(transport._message_queue.put({"dummy": "message"}))
            >>> # Now queue is full, next put should fail
            >>> try:
            ...     asyncio.run(asyncio.wait_for(transport.send_message({"test": "message"}), timeout=0.1))
            ... except asyncio.TimeoutError:
            ...     print("Queue full as expected")
            Queue full as expected
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        try:
            # Delivery is asynchronous: the SSE event generator drains this queue.
            await self._message_queue.put(message)
            logger.debug("Message queued for SSE: %s, method=%s", self._session_id, message.get("method", "(response)"))
        except Exception as e:
            logger.error("Failed to queue message: %s", e)
            raise

    async def receive_message(self) -> AsyncGenerator[Dict[str, Any], None]:
        """Receive messages from the client over SSE transport.

        This method implements a continuous message-receiving pattern for SSE transport.
        Since SSE is primarily a server-to-client communication channel, this method
        yields an initial initialize placeholder message and then enters a waiting loop.
        The actual client messages are received via a separate HTTP POST endpoint
        (not handled in this method).

        The method will continue running until either:
        1. The connection is explicitly disconnected (client_gone event is set)
        2. The receive loop is cancelled from outside

        Yields:
            Dict[str, Any]: JSON-RPC formatted messages. The first yielded message is always
                an initialize placeholder with the format:
                {"jsonrpc": "2.0", "method": "initialize", "id": 1}

        Raises:
            RuntimeError: If the transport is not connected when this method is called
            asyncio.CancelledError: When the SSE receive loop is cancelled externally

        Examples:
            >>> # Test receive message when connected
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport.connect())
            >>> async def test_receive():
            ...     async for msg in transport.receive_message():
            ...         return msg
            ...     return None
            >>> result = asyncio.run(test_receive())
            >>> result
            {'jsonrpc': '2.0', 'method': 'initialize', 'id': 1}

            >>> # Test receive message when not connected
            >>> transport = SSETransport()
            >>> try:
            ...     async def test_receive():
            ...         async for msg in transport.receive_message():
            ...             pass
            ...     asyncio.run(test_receive())
            ... except RuntimeError as e:
            ...     print("Expected error:", str(e))
            Expected error: Transport not connected

            >>> # Verify generator behavior
            >>> transport = SSETransport()
            >>> import inspect
            >>> inspect.isasyncgenfunction(transport.receive_message)
            True
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        # For SSE, we set up a loop to wait for messages which are delivered via POST
        # Most messages come via the POST endpoint, but we yield an initial initialize placeholder
        # to keep the receive loop running
        yield {"jsonrpc": "2.0", "method": "initialize", "id": 1}

        # Continue waiting for cancellation; 1s polling keeps the coroutine cheap.
        try:
            while not self._client_gone.is_set():
                await asyncio.sleep(1.0)
        except asyncio.CancelledError:
            logger.info("SSE receive loop cancelled for session %s", self._session_id)
            raise
        finally:
            logger.info("SSE receive loop ended for session %s", self._session_id)

    async def _get_message_with_timeout(self, timeout: Optional[float]) -> Optional[Dict[str, Any]]:
        """Get message from queue with timeout, returns None on timeout.

        Uses asyncio.wait() to avoid TimeoutError exception overhead.

        Args:
            timeout: Timeout in seconds, or None for no timeout

        Returns:
            Message dict if received, None if timeout occurred

        Raises:
            asyncio.CancelledError: If the operation is cancelled externally
        """
        if timeout is None:
            return await self._message_queue.get()

        get_task = asyncio.create_task(self._message_queue.get())
        try:
            done, _ = await asyncio.wait({get_task}, timeout=timeout)
        except asyncio.CancelledError:
            # Propagate external cancellation, but first reap the helper task
            # so the queue item (if any) isn't silently dropped into a dead task.
            get_task.cancel()
            try:
                await get_task
            except asyncio.CancelledError:
                pass
            raise

        if get_task in done:
            return get_task.result()

        # Timeout - cancel pending task, but return the result if it completed in the race window.
        get_task.cancel()
        try:
            return await get_task
        except asyncio.CancelledError:
            return None

    async def is_connected(self) -> bool:
        """Check if transport is connected.

        Returns:
            True if connected

        Examples:
            >>> # Test initial state
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport.is_connected())
            False

            >>> # Test after connection
            >>> transport = SSETransport()
            >>> asyncio.run(transport.connect())
            >>> asyncio.run(transport.is_connected())
            True

            >>> # Test after disconnection
            >>> transport = SSETransport()
            >>> asyncio.run(transport.connect())
            >>> asyncio.run(transport.disconnect())
            >>> asyncio.run(transport.is_connected())
            False
        """
        return self._connected

    async def create_sse_response(
        self,
        request: Request,
        on_disconnect_callback: Callable[[], Awaitable[None]] | None = None,
    ) -> EventSourceResponse:
        """Create SSE response for streaming.

        Args:
            request: FastAPI request (used for disconnection detection)
            on_disconnect_callback: Optional async callback to run when client disconnects.
                Used for defensive cleanup (e.g., cancelling respond tasks).

        Returns:
            SSE response object

        Examples:
            >>> # Test SSE response creation
            >>> transport = SSETransport("http://localhost:8000")
            >>> # Note: This method requires a FastAPI Request object
            >>> # and cannot be easily tested in doctest environment
            >>> callable(transport.create_sse_response)
            True
        """
        # Clients POST follow-up messages to this URL (announced via the
        # initial "endpoint" event below).
        endpoint_url = f"{self._base_url}/message?session_id={self._session_id}"

        async def event_generator():
            """Generate SSE events.

            Yields:
                SSE event as bytes (pre-formatted SSE frame)

            Raises:
                asyncio.CancelledError: If the generator is cancelled.
            """
            # Send the endpoint event first
            yield _build_sse_frame(b"endpoint", endpoint_url.encode(), settings.sse_retry_timeout)

            # Send keepalive immediately to help establish connection (if enabled)
            if settings.sse_keepalive_enabled:
                yield _build_sse_frame(b"keepalive", b"{}", settings.sse_retry_timeout)

            consecutive_errors = 0
            max_consecutive_errors = 3  # Exit after 3 consecutive errors (likely client disconnected)

            # Rapid yield detection: If we're yielding faster than expected, client is likely disconnected
            # but ASGI server isn't properly signaling it. Track yield timestamps in a sliding window.
            yield_timestamps: Optional[deque] = deque(maxlen=settings.sse_rapid_yield_max + 1) if settings.sse_rapid_yield_max > 0 else None
            rapid_yield_window_sec = settings.sse_rapid_yield_window_ms / 1000.0
            last_yield_time = time.monotonic()
            consecutive_rapid_yields = 0  # Track consecutive fast yields for simpler detection

            def check_rapid_yield() -> bool:
                """Check if yields are happening too fast.

                Returns:
                    True if spin loop detected and should disconnect, False otherwise.
                """
                nonlocal last_yield_time, consecutive_rapid_yields
                now = time.monotonic()
                time_since_last = now - last_yield_time
                last_yield_time = now

                # Track consecutive rapid yields (< 100ms apart)
                # This catches spin loops even without full deque analysis
                if time_since_last < 0.1:
                    consecutive_rapid_yields += 1
                    if consecutive_rapid_yields >= 10:  # 10 consecutive fast yields = definite spin
                        logger.error("SSE spin loop detected (%d consecutive rapid yields, last interval %.3fs), client disconnected: %s", consecutive_rapid_yields, time_since_last, self._session_id)
                        return True
                else:
                    consecutive_rapid_yields = 0  # Reset on normal-speed yield

                # Also use deque-based detection for more nuanced analysis
                if yield_timestamps is None:
                    return False
                yield_timestamps.append(now)
                if time_since_last < 0.01:  # Less than 10ms between yields is very fast
                    if len(yield_timestamps) > settings.sse_rapid_yield_max:
                        oldest = yield_timestamps[0]
                        elapsed = now - oldest
                        if elapsed < rapid_yield_window_sec:
                            logger.error(
                                "SSE rapid yield detected (%d yields in %.3fs, last interval %.3fs), client disconnected: %s", len(yield_timestamps), elapsed, time_since_last, self._session_id
                            )
                            return True
                return False

            try:
                while not self._client_gone.is_set():
                    # Check if client has disconnected via request state
                    if await request.is_disconnected():
                        logger.info("SSE client disconnected (detected via request): %s", self._session_id)
                        self._client_gone.set()
                        break

                    try:
                        # Use timeout-based polling only when keepalive is enabled
                        if not settings.sse_keepalive_enabled:
                            message = await self._message_queue.get()
                        else:
                            message = await self._get_message_with_timeout(settings.sse_keepalive_interval)

                        if message is not None:
                            json_bytes = orjson.dumps(message, option=orjson.OPT_SERIALIZE_NUMPY)

                            # Guard the decode(): it's only for debug logging.
                            if logger.isEnabledFor(logging.DEBUG):
                                logger.debug("Sending SSE message: %s", json_bytes.decode())

                            yield _build_sse_frame(b"message", json_bytes, settings.sse_retry_timeout)
                            consecutive_errors = 0  # Reset on successful send

                            # Check for rapid yields after message send
                            if check_rapid_yield():
                                self._client_gone.set()
                                break
                        elif settings.sse_keepalive_enabled:
                            # Timeout - send keepalive
                            yield _build_sse_frame(b"keepalive", b"{}", settings.sse_retry_timeout)
                            consecutive_errors = 0  # Reset on successful send
                            # Check for rapid yields after keepalive too
                            if check_rapid_yield():
                                self._client_gone.set()
                                break
                        # Note: We don't clear yield_timestamps here. The deque has maxlen
                        # which automatically drops old entries. Clearing would prevent
                        # detection of spin loops where yields happen faster than expected.

                    except GeneratorExit:
                        # Client disconnected - generator is being closed
                        logger.info("SSE generator exit (client disconnected): %s", self._session_id)
                        self._client_gone.set()
                        break
                    except Exception as e:
                        consecutive_errors += 1
                        logger.warning("Error processing SSE message (attempt %d/%d): %s", consecutive_errors, max_consecutive_errors, e)
                        if consecutive_errors >= max_consecutive_errors:
                            logger.info("SSE too many consecutive errors, assuming client disconnected: %s", self._session_id)
                            self._client_gone.set()
                            break
                        # Don't yield error frame - it could cause more errors if client is gone

            except asyncio.CancelledError:
                logger.info("SSE event generator cancelled: %s", self._session_id)
                self._client_gone.set()
                raise
            except GeneratorExit:
                logger.info("SSE generator exit: %s", self._session_id)
                self._client_gone.set()
            except Exception as e:
                logger.error("SSE event generator error: %s", e)
                self._client_gone.set()
            finally:
                logger.info("SSE event generator completed: %s", self._session_id)
                self._client_gone.set()  # Always set client_gone on exit to clean up
                # CRITICAL: Also invoke disconnect callback on generator end (Finding 3)
                # This covers server-initiated close, errors, and cancellation - not just client close
                if on_disconnect_callback:
                    try:
                        await on_disconnect_callback()
                    except Exception as e:
                        logger.warning("Disconnect callback in finally failed for %s: %s", self._session_id, e)

        async def on_client_close(_scope: dict) -> None:
            """Handle client close event from sse_starlette."""
            logger.info("SSE client close handler called: %s", self._session_id)
            self._client_gone.set()

            # Defensive cleanup via callback (if provided)
            if on_disconnect_callback:
                try:
                    await on_disconnect_callback()
                except Exception as e:
                    logger.warning("Disconnect callback failed for %s: %s", self._session_id, e)

        return EventSourceResponse(
            event_generator(),
            status_code=200,
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Content-Type": "text/event-stream",
                "X-MCP-SSE": "true",
            },
            # Timeout for ASGI send() calls - protects against sends that hang indefinitely
            # when client connection is in a bad state (e.g., client stopped reading but TCP
            # connection not yet closed). Does NOT affect MCP server response times.
            # Set to 0 to disable. Default matches keepalive interval.
            send_timeout=settings.sse_send_timeout if settings.sse_send_timeout > 0 else None,
            # Callback when client closes - helps detect disconnects that ASGI server
            # may not properly propagate via request.is_disconnected()
            client_close_handler_callable=on_client_close,
        )

    async def _client_disconnected(self, _request: Request) -> bool:
        """Check if client has disconnected.

        Args:
            _request: FastAPI Request object

        Returns:
            bool: True if client disconnected

        Examples:
            >>> # Test client disconnected check
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport._client_disconnected(None))
            False

            >>> # Test after setting client gone
            >>> transport = SSETransport()
            >>> transport._client_gone.set()
            >>> asyncio.run(transport._client_disconnected(None))
            True
        """
        # We only check our internal client_gone flag
        # We intentionally don't check connection_lost on the request
        # as it can be unreliable and cause premature closures
        return self._client_gone.is_set()

    @property
    def session_id(self) -> str:
        """
        Get the session ID for this transport.

        Returns:
            str: session_id

        Examples:
            >>> # Test session ID property
            >>> transport = SSETransport()
            >>> session_id = transport.session_id
            >>> isinstance(session_id, str)
            True
            >>> len(session_id) > 0
            True
            >>> session_id == transport._session_id
            True

            >>> # Test session ID uniqueness
            >>> transport1 = SSETransport()
            >>> transport2 = SSETransport()
            >>> transport1.session_id != transport2.session_id
            True
        """
        return self._session_id