Coverage for mcpgateway / transports / sse_transport.py: 100%

252 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/transports/sse_transport.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7SSE Transport Implementation. 

8This module implements Server-Sent Events (SSE) transport for MCP, 

9providing server-to-client streaming with proper session management. 

10""" 

11 

# Standard
import asyncio
from collections import deque
import logging
import re
import time
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, Optional
import uuid

# Third-Party
import anyio
from anyio._backends._asyncio import CancelScope
from fastapi import Request
import orjson
from sse_starlette.sse import EventSourceResponse as BaseEventSourceResponse
from starlette.types import Receive, Scope, Send

# First-Party
from mcpgateway.config import settings
from mcpgateway.services.logging_service import LoggingService
from mcpgateway.transports.base import Transport

32 

# Initialize logging service first
# Logger is obtained via the gateway's LoggingService (not logging.getLogger
# directly) so that log configuration stays centralized across the gateway.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

36 

37 

38# ============================================================================= 

39# EXPERIMENTAL WORKAROUND: anyio _deliver_cancellation spin loop (anyio#695) 

40# ============================================================================= 

41# anyio's _deliver_cancellation can spin at 100% CPU when tasks don't respond 

42# to CancelledError. This optional monkey-patch adds a max iteration limit to 

43# prevent indefinite spinning. After the limit is reached, we give up delivering 

44# cancellation and let the scope exit (tasks will be orphaned but won't spin). 

45# 

46# This workaround is DISABLED by default. Enable via: 

47# ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=true 

48# 

49# Trade-offs: 

50# - Prevents indefinite CPU spin (good) 

51# - May leave some tasks uncancelled after max iterations (usually harmless) 

52# - Worker recycling (GUNICORN_MAX_REQUESTS) cleans up orphaned tasks 

53# 

54# This workaround may be removed when anyio or MCP SDK fix the underlying issue. 

55# See: https://github.com/agronholm/anyio/issues/695 

56# ============================================================================= 

57 

# Store original for potential restoration and for the patch to call
_original_deliver_cancellation = CancelScope._deliver_cancellation  # type: ignore[attr-defined] # pylint: disable=protected-access
# Tracks whether the monkey-patch is currently installed; guarded by the
# apply/remove functions so installation stays idempotent.
_patch_applied = False  # pylint: disable=invalid-name

61 

62 

def _create_patched_deliver_cancellation(max_iterations: int):  # noqa: C901
    """Build a bounded replacement for anyio's ``CancelScope._deliver_cancellation``.

    The returned function counts delivery attempts per originating scope and
    stops retrying once ``max_iterations`` is exceeded, breaking the 100% CPU
    spin loop described in anyio#695.

    Args:
        max_iterations: Attempt budget before cancellation delivery is abandoned.

    Returns:
        A function suitable for assignment to ``CancelScope._deliver_cancellation``.
    """

    def _patched_deliver_cancellation(self: CancelScope, origin: CancelScope) -> bool:  # pylint: disable=protected-access
        """Bounded variant of anyio's ``_deliver_cancellation``.

        Delegates to the original implementation while the attempt budget
        lasts; afterwards it logs once, clears the pending cancel handle and
        reports that delivery should not be retried.

        Args:
            self: The cancel scope being processed.
            origin: The cancel scope that originated the cancellation.

        Returns:
            True if delivery should be retried, False if done or max iterations reached.
        """
        # The counter lives on the originating scope so every scope involved
        # in the same cancellation shares one budget.
        attempts = getattr(origin, "_delivery_iterations", 0) + 1
        origin._delivery_iterations = attempts  # type: ignore[attr-defined] # pylint: disable=protected-access

        if attempts <= max_iterations:
            # Still within budget - defer to anyio's real implementation.
            return _original_deliver_cancellation(self, origin)

        # Budget exhausted: warn once per overflow and stop retrying so the
        # event loop does not spin indefinitely.
        logger.warning(
            "anyio cancel delivery exceeded %d iterations - giving up to prevent CPU spin. "
            "Some tasks may not have been properly cancelled. "
            "Disable with ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=false if this causes issues.",
            max_iterations,
        )
        # Drop the pending cancel handle so the loop stops re-invoking us.
        if getattr(self, "_cancel_handle", None) is not None:  # pylint: disable=protected-access
            self._cancel_handle = None  # pylint: disable=protected-access
        return False  # Don't retry

    return _patched_deliver_cancellation

111 

112 

def apply_anyio_cancel_delivery_patch() -> bool:
    """Install the anyio ``_deliver_cancellation`` monkey-patch when configured.

    Safe to call repeatedly: once the patch is in place, further calls are
    no-ops that simply report success.

    Returns:
        True if patch was applied (or already applied), False if disabled.
    """
    global _patch_applied  # pylint: disable=global-statement

    if _patch_applied:
        return True

    try:
        if not settings.anyio_cancel_delivery_patch_enabled:
            logger.debug("anyio _deliver_cancellation patch DISABLED. Enable with ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=true if you experience CPU spin loops.")
            return False

        budget = settings.anyio_cancel_delivery_max_iterations
        # Swap in the bounded implementation built around the configured budget.
        CancelScope._deliver_cancellation = _create_patched_deliver_cancellation(budget)  # type: ignore[method-assign] # pylint: disable=protected-access
        _patch_applied = True

        logger.info(
            "anyio _deliver_cancellation patch ENABLED (max_iterations=%d). "
            "This is an experimental workaround for anyio#695. "
            "Disable with ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=false if it causes issues.",
            budget,
        )
        return True

    except Exception as exc:  # pylint: disable=broad-except
        # Best-effort: a broken settings object must not prevent module import.
        logger.warning("Failed to apply anyio _deliver_cancellation patch: %s", exc)
        return False

147 

148 

def remove_anyio_cancel_delivery_patch() -> bool:
    """Uninstall the anyio ``_deliver_cancellation`` monkey-patch.

    Puts anyio's original implementation back on ``CancelScope``.

    Returns:
        True if patch was removed, False if it wasn't applied.
    """
    global _patch_applied  # pylint: disable=global-statement

    if not _patch_applied:
        return False

    try:
        # Restore the implementation captured at import time.
        CancelScope._deliver_cancellation = _original_deliver_cancellation  # type: ignore[method-assign] # pylint: disable=protected-access
        _patch_applied = False
        logger.info("anyio _deliver_cancellation patch removed - restored original implementation")
        return True
    except Exception as exc:  # pylint: disable=broad-except
        logger.warning("Failed to remove anyio _deliver_cancellation patch: %s", exc)
        return False

170 

171 

# Apply patch at module load time if enabled
# (reads settings.anyio_cancel_delivery_patch_enabled; no-op when disabled)
apply_anyio_cancel_delivery_patch()

174 

175 

176def _get_sse_cleanup_timeout() -> float: 

177 """Get SSE task group cleanup timeout from config. 

178 

179 This timeout controls how long to wait for SSE task group tasks to respond 

180 to cancellation before forcing cleanup. Prevents CPU spin loops in anyio's 

181 _deliver_cancellation when tasks don't properly handle CancelledError. 

182 

183 Returns: 

184 Cleanup timeout in seconds (default: 5.0) 

185 """ 

186 try: 

187 return settings.sse_task_group_cleanup_timeout 

188 except Exception: 

189 return 5.0 # Fallback default 

190 

191 

class EventSourceResponse(BaseEventSourceResponse):
    """Patched EventSourceResponse with CPU spin detection.

    This mitigates a CPU spin loop issue (anyio#695) where _deliver_cancellation
    spins at 100% CPU when tasks in the SSE task group don't respond to
    cancellation.

    Instead of trying to timeout the task group (which would affect normal
    SSE connections), we copy the __call__ method and add a deadline to
    the cancel scope to ensure cleanup doesn't hang indefinitely.

    See:
        - https://github.com/agronholm/anyio/issues/695
        - https://github.com/anthropics/claude-agent-sdk-python/issues/378
    """

    def enable_compression(self, force: bool = False) -> None:  # noqa: ARG002
        """Enable compression (no-op for SSE streams).

        SSE streams don't support compression as per sse_starlette.
        This override prevents NotImplementedError from parent class.

        Args:
            force: Ignored - compression not supported for SSE.
        """
        # Intentionally empty: the parent raises NotImplementedError here.

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Handle SSE request with cancel scope deadline to prevent spin.

        This method is copied from sse_starlette with one key modification:
        the task group's cancel_scope gets a deadline set when cancellation
        starts, preventing indefinite spinning if tasks don't respond.

        Args:
            scope: ASGI scope dictionary.
            receive: ASGI receive callable.
            send: ASGI send callable.
        """
        # Copy of sse_starlette.sse.EventSourceResponse.__call__ with deadline fix
        # NOTE(review): the start_soon ordering below mirrors sse_starlette's
        # original; keep it unchanged when syncing with upstream.
        async with anyio.create_task_group() as task_group:
            # Add deadline to cancel scope to prevent indefinite spin on cleanup
            # The deadline is set far in the future initially, and only becomes
            # relevant if the scope is cancelled and cleanup takes too long

            async def cancel_on_finish(coro: Callable[[], Awaitable[None]]) -> None:
                """Execute coroutine then cancel task group with bounded deadline.

                This wrapper runs the given coroutine and, upon completion, cancels
                the parent task group with a deadline to prevent indefinite spinning
                if other tasks don't respond to cancellation (anyio#695 mitigation).

                Args:
                    coro: Async callable to execute before triggering cancellation.
                """
                await coro()
                # When cancelling, set a deadline to prevent indefinite spin
                # if other tasks don't respond to cancellation
                task_group.cancel_scope.deadline = anyio.current_time() + _get_sse_cleanup_timeout()
                task_group.cancel_scope.cancel()

            # Whichever of these finishes first tears down the whole group.
            task_group.start_soon(cancel_on_finish, lambda: self._stream_response(send))
            task_group.start_soon(cancel_on_finish, lambda: self._ping(send))
            task_group.start_soon(cancel_on_finish, self._listen_for_exit_signal)

            if self.data_sender_callable:
                task_group.start_soon(self.data_sender_callable)

            # Wait for the client to disconnect last
            task_group.start_soon(cancel_on_finish, lambda: self._listen_for_disconnect(receive))

            # Background task (Starlette convention) runs after the group is wired up.
            if self.background is not None:
                await self.background()

264 

265 

266# Pre-computed SSE frame components for performance 

267_SSE_EVENT_PREFIX = b"event: " 

268_SSE_DATA_PREFIX = b"\r\ndata: " 

269_SSE_RETRY_PREFIX = b"\r\nretry: " 

270_SSE_FRAME_END = b"\r\n\r\n" 

271 

272 

273def _build_sse_frame(event: bytes, data: bytes, retry: int) -> bytes: 

274 """Build SSE frame as bytes to avoid encode/decode overhead. 

275 

276 Args: 

277 event: SSE event type as bytes (e.g., b'message', b'keepalive', b'error') 

278 data: JSON data as bytes (from orjson.dumps) 

279 retry: Retry timeout in milliseconds 

280 

281 Returns: 

282 Complete SSE frame as bytes 

283 

284 Note: 

285 Uses hardcoded CRLF (\\r\\n) separators matching sse_starlette's 

286 ServerSentEvent.DEFAULT_SEPARATOR. If custom separators are ever 

287 needed, this function would need to accept a sep parameter. 

288 

289 Examples: 

290 >>> _build_sse_frame(b"message", b'{"test": 1}', 15000) 

291 b'event: message\\r\\ndata: {"test": 1}\\r\\nretry: 15000\\r\\n\\r\\n' 

292 

293 >>> _build_sse_frame(b"keepalive", b"{}", 15000) 

294 b'event: keepalive\\r\\ndata: {}\\r\\nretry: 15000\\r\\n\\r\\n' 

295 """ 

296 return _SSE_EVENT_PREFIX + event + _SSE_DATA_PREFIX + data + _SSE_RETRY_PREFIX + str(retry).encode() + _SSE_FRAME_END 

297 

298 

299class SSETransport(Transport): 

300 """Transport implementation using Server-Sent Events with proper session management. 

301 

302 This transport implementation uses Server-Sent Events (SSE) for real-time 

303 communication between the MCP gateway and clients. It provides streaming 

304 capabilities with automatic session management and keepalive support. 

305 

306 Examples: 

307 >>> # Create SSE transport with default URL 

308 >>> transport = SSETransport() 

309 >>> transport 

310 <mcpgateway.transports.sse_transport.SSETransport object at ...> 

311 

312 >>> # Create SSE transport with custom URL 

313 >>> transport = SSETransport("http://localhost:8080") 

314 >>> transport._base_url 

315 'http://localhost:8080' 

316 

317 >>> # Check initial connection state 

318 >>> import asyncio 

319 >>> asyncio.run(transport.is_connected()) 

320 False 

321 

322 >>> # Verify it's a proper Transport subclass 

323 >>> isinstance(transport, Transport) 

324 True 

325 >>> issubclass(SSETransport, Transport) 

326 True 

327 

328 >>> # Check session ID generation 

329 >>> transport.session_id 

330 '...' 

331 >>> len(transport.session_id) > 0 

332 True 

333 

334 >>> # Verify required methods exist 

335 >>> hasattr(transport, 'connect') 

336 True 

337 >>> hasattr(transport, 'disconnect') 

338 True 

339 >>> hasattr(transport, 'send_message') 

340 True 

341 >>> hasattr(transport, 'receive_message') 

342 True 

343 >>> hasattr(transport, 'is_connected') 

344 True 

345 """ 

346 

347 # Regex for validating session IDs (alphanumeric, hyphens, underscores, max 128 chars) 

348 _SESSION_ID_PATTERN = None 

349 

350 @staticmethod 

351 def _is_valid_session_id(session_id: str) -> bool: 

352 """Validate session ID format for security. 

353 

354 Valid session IDs are alphanumeric with hyphens and underscores, max 128 chars. 

355 This prevents injection attacks via malformed session IDs. 

356 

357 Args: 

358 session_id: The session ID to validate. 

359 

360 Returns: 

361 True if the session ID is valid, False otherwise. 

362 

363 Examples: 

364 >>> SSETransport._is_valid_session_id("abc123") 

365 True 

366 >>> SSETransport._is_valid_session_id("abc-123_def") 

367 True 

368 >>> SSETransport._is_valid_session_id("a" * 128) 

369 True 

370 >>> SSETransport._is_valid_session_id("a" * 129) 

371 False 

372 >>> SSETransport._is_valid_session_id("") 

373 False 

374 >>> SSETransport._is_valid_session_id("abc/def") 

375 False 

376 >>> SSETransport._is_valid_session_id("abc:def") 

377 False 

378 """ 

379 # Standard 

380 import re # pylint: disable=import-outside-toplevel 

381 

382 if not session_id or len(session_id) > 128: 

383 return False 

384 

385 # Lazy compile pattern 

386 if SSETransport._SESSION_ID_PATTERN is None: 

387 SSETransport._SESSION_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+$") 

388 

389 return bool(SSETransport._SESSION_ID_PATTERN.match(session_id)) 

390 

391 def __init__(self, base_url: str = None): 

392 """Initialize SSE transport. 

393 

394 Args: 

395 base_url: Base URL for client message endpoints 

396 

397 Examples: 

398 >>> # Test default initialization 

399 >>> transport = SSETransport() 

400 >>> transport._connected 

401 False 

402 >>> transport._message_queue is not None 

403 True 

404 >>> transport._client_gone is not None 

405 True 

406 >>> len(transport._session_id) > 0 

407 True 

408 

409 >>> # Test custom base URL 

410 >>> transport = SSETransport("https://api.example.com") 

411 >>> transport._base_url 

412 'https://api.example.com' 

413 

414 >>> # Test session ID uniqueness 

415 >>> transport1 = SSETransport() 

416 >>> transport2 = SSETransport() 

417 >>> transport1.session_id != transport2.session_id 

418 True 

419 """ 

420 self._base_url = base_url or f"http://{settings.host}:{settings.port}" 

421 self._connected = False 

422 self._message_queue = asyncio.Queue() 

423 self._client_gone = asyncio.Event() 

424 # Server generates session_id for SSE - client receives it in endpoint event 

425 self._session_id = str(uuid.uuid4()) 

426 

427 logger.info("Creating SSE transport with base_url=%s, session_id=%s", self._base_url, self._session_id) 

428 

429 async def connect(self) -> None: 

430 """Set up SSE connection. 

431 

432 Examples: 

433 >>> # Test connection setup 

434 >>> transport = SSETransport() 

435 >>> import asyncio 

436 >>> asyncio.run(transport.connect()) 

437 >>> transport._connected 

438 True 

439 >>> asyncio.run(transport.is_connected()) 

440 True 

441 """ 

442 self._connected = True 

443 logger.info("SSE transport connected: %s", self._session_id) 

444 

445 async def disconnect(self) -> None: 

446 """Clean up SSE connection. 

447 

448 Examples: 

449 >>> # Test disconnection 

450 >>> transport = SSETransport() 

451 >>> import asyncio 

452 >>> asyncio.run(transport.connect()) 

453 >>> asyncio.run(transport.disconnect()) 

454 >>> transport._connected 

455 False 

456 >>> transport._client_gone.is_set() 

457 True 

458 >>> asyncio.run(transport.is_connected()) 

459 False 

460 

461 >>> # Test disconnection when already disconnected 

462 >>> transport = SSETransport() 

463 >>> asyncio.run(transport.disconnect()) 

464 >>> transport._connected 

465 False 

466 """ 

467 if self._connected: 

468 self._connected = False 

469 self._client_gone.set() 

470 logger.info("SSE transport disconnected: %s", self._session_id) 

471 

472 async def send_message(self, message: Dict[str, Any]) -> None: 

473 """Send a message over SSE. 

474 

475 Args: 

476 message: Message to send 

477 

478 Raises: 

479 RuntimeError: If transport is not connected 

480 Exception: If unable to put message to queue 

481 

482 Examples: 

483 >>> # Test sending message when connected 

484 >>> transport = SSETransport() 

485 >>> import asyncio 

486 >>> asyncio.run(transport.connect()) 

487 >>> message = {"jsonrpc": "2.0", "method": "test", "id": 1} 

488 >>> asyncio.run(transport.send_message(message)) 

489 >>> transport._message_queue.qsize() 

490 1 

491 

492 >>> # Test sending message when not connected 

493 >>> transport = SSETransport() 

494 >>> try: 

495 ... asyncio.run(transport.send_message({"test": "message"})) 

496 ... except RuntimeError as e: 

497 ... print("Expected error:", str(e)) 

498 Expected error: Transport not connected 

499 

500 >>> # Test message format validation 

501 >>> transport = SSETransport() 

502 >>> asyncio.run(transport.connect()) 

503 >>> valid_message = {"jsonrpc": "2.0", "method": "initialize", "params": {}} 

504 >>> isinstance(valid_message, dict) 

505 True 

506 >>> "jsonrpc" in valid_message 

507 True 

508 

509 >>> # Test exception handling in queue put 

510 >>> transport = SSETransport() 

511 >>> asyncio.run(transport.connect()) 

512 >>> # Create a full queue to trigger exception 

513 >>> transport._message_queue = asyncio.Queue(maxsize=1) 

514 >>> asyncio.run(transport._message_queue.put({"dummy": "message"})) 

515 >>> # Now queue is full, next put should fail 

516 >>> try: 

517 ... asyncio.run(asyncio.wait_for(transport.send_message({"test": "message"}), timeout=0.1)) 

518 ... except asyncio.TimeoutError: 

519 ... print("Queue full as expected") 

520 Queue full as expected 

521 """ 

522 if not self._connected: 

523 raise RuntimeError("Transport not connected") 

524 

525 try: 

526 await self._message_queue.put(message) 

527 logger.debug("Message queued for SSE: %s, method=%s", self._session_id, message.get("method", "(response)")) 

528 except Exception as e: 

529 logger.error("Failed to queue message: %s", e) 

530 raise 

531 

532 async def receive_message(self) -> AsyncGenerator[Dict[str, Any], None]: 

533 """Receive messages from the client over SSE transport. 

534 

535 This method implements a continuous message-receiving pattern for SSE transport. 

536 Since SSE is primarily a server-to-client communication channel, this method 

537 yields an initial initialize placeholder message and then enters a waiting loop. 

538 The actual client messages are received via a separate HTTP POST endpoint 

539 (not handled in this method). 

540 

541 The method will continue running until either: 

542 1. The connection is explicitly disconnected (client_gone event is set) 

543 2. The receive loop is cancelled from outside 

544 

545 Yields: 

546 Dict[str, Any]: JSON-RPC formatted messages. The first yielded message is always 

547 an initialize placeholder with the format: 

548 {"jsonrpc": "2.0", "method": "initialize", "id": 1} 

549 

550 Raises: 

551 RuntimeError: If the transport is not connected when this method is called 

552 asyncio.CancelledError: When the SSE receive loop is cancelled externally 

553 

554 Examples: 

555 >>> # Test receive message when connected 

556 >>> transport = SSETransport() 

557 >>> import asyncio 

558 >>> asyncio.run(transport.connect()) 

559 >>> async def test_receive(): 

560 ... async for msg in transport.receive_message(): 

561 ... return msg 

562 ... return None 

563 >>> result = asyncio.run(test_receive()) 

564 >>> result 

565 {'jsonrpc': '2.0', 'method': 'initialize', 'id': 1} 

566 

567 >>> # Test receive message when not connected 

568 >>> transport = SSETransport() 

569 >>> try: 

570 ... async def test_receive(): 

571 ... async for msg in transport.receive_message(): 

572 ... pass 

573 ... asyncio.run(test_receive()) 

574 ... except RuntimeError as e: 

575 ... print("Expected error:", str(e)) 

576 Expected error: Transport not connected 

577 

578 >>> # Verify generator behavior 

579 >>> transport = SSETransport() 

580 >>> import inspect 

581 >>> inspect.isasyncgenfunction(transport.receive_message) 

582 True 

583 """ 

584 if not self._connected: 

585 raise RuntimeError("Transport not connected") 

586 

587 # For SSE, we set up a loop to wait for messages which are delivered via POST 

588 # Most messages come via the POST endpoint, but we yield an initial initialize placeholder 

589 # to keep the receive loop running 

590 yield {"jsonrpc": "2.0", "method": "initialize", "id": 1} 

591 

592 # Continue waiting for cancellation 

593 try: 

594 while not self._client_gone.is_set(): 

595 await asyncio.sleep(1.0) 

596 except asyncio.CancelledError: 

597 logger.info("SSE receive loop cancelled for session %s", self._session_id) 

598 raise 

599 finally: 

600 logger.info("SSE receive loop ended for session %s", self._session_id) 

601 

602 async def _get_message_with_timeout(self, timeout: Optional[float]) -> Optional[Dict[str, Any]]: 

603 """Get message from queue with timeout, returns None on timeout. 

604 

605 Uses asyncio.wait() to avoid TimeoutError exception overhead. 

606 

607 Args: 

608 timeout: Timeout in seconds, or None for no timeout 

609 

610 Returns: 

611 Message dict if received, None if timeout occurred 

612 

613 Raises: 

614 asyncio.CancelledError: If the operation is cancelled externally 

615 """ 

616 if timeout is None: 

617 return await self._message_queue.get() 

618 

619 get_task = asyncio.create_task(self._message_queue.get()) 

620 try: 

621 done, _ = await asyncio.wait({get_task}, timeout=timeout) 

622 except asyncio.CancelledError: 

623 get_task.cancel() 

624 try: 

625 await get_task 

626 except asyncio.CancelledError: 

627 pass 

628 raise 

629 

630 if get_task in done: 

631 return get_task.result() 

632 

633 # Timeout - cancel pending task, but return the result if it completed in the race window. 

634 get_task.cancel() 

635 try: 

636 return await get_task 

637 except asyncio.CancelledError: 

638 return None 

639 

640 async def is_connected(self) -> bool: 

641 """Check if transport is connected. 

642 

643 Returns: 

644 True if connected 

645 

646 Examples: 

647 >>> # Test initial state 

648 >>> transport = SSETransport() 

649 >>> import asyncio 

650 >>> asyncio.run(transport.is_connected()) 

651 False 

652 

653 >>> # Test after connection 

654 >>> transport = SSETransport() 

655 >>> asyncio.run(transport.connect()) 

656 >>> asyncio.run(transport.is_connected()) 

657 True 

658 

659 >>> # Test after disconnection 

660 >>> transport = SSETransport() 

661 >>> asyncio.run(transport.connect()) 

662 >>> asyncio.run(transport.disconnect()) 

663 >>> asyncio.run(transport.is_connected()) 

664 False 

665 """ 

666 return self._connected 

667 

    async def create_sse_response(
        self,
        request: Request,
        on_disconnect_callback: Callable[[], Awaitable[None]] | None = None,
    ) -> EventSourceResponse:
        """Create SSE response for streaming.

        Builds the event stream for this session: an initial "endpoint" event
        telling the client where to POST messages, optional keepalives, and
        one "message" event per queued JSON-RPC payload. Includes spin-loop
        detection to catch disconnects the ASGI server fails to signal.

        Args:
            request: FastAPI request (used for disconnection detection)
            on_disconnect_callback: Optional async callback to run when client disconnects.
                Used for defensive cleanup (e.g., cancelling respond tasks).

        Returns:
            SSE response object

        Examples:
            >>> # Test SSE response creation
            >>> transport = SSETransport("http://localhost:8000")
            >>> # Note: This method requires a FastAPI Request object
            >>> # and cannot be easily tested in doctest environment
            >>> callable(transport.create_sse_response)
            True
        """
        endpoint_url = f"{self._base_url}/message?session_id={self._session_id}"

        async def event_generator():
            """Generate SSE events.

            Yields:
                SSE event as bytes (pre-formatted SSE frame)
            """
            # Send the endpoint event first
            yield _build_sse_frame(b"endpoint", endpoint_url.encode(), settings.sse_retry_timeout)

            # Send keepalive immediately to help establish connection (if enabled)
            if settings.sse_keepalive_enabled:
                yield _build_sse_frame(b"keepalive", b"{}", settings.sse_retry_timeout)

            consecutive_errors = 0
            max_consecutive_errors = 3  # Exit after 3 consecutive errors (likely client disconnected)

            # Rapid yield detection: If we're yielding faster than expected, client is likely disconnected
            # but ASGI server isn't properly signaling it. Track yield timestamps in a sliding window.
            # None when the feature is disabled (sse_rapid_yield_max <= 0).
            yield_timestamps: Optional[deque] = deque(maxlen=settings.sse_rapid_yield_max + 1) if settings.sse_rapid_yield_max > 0 else None
            rapid_yield_window_sec = settings.sse_rapid_yield_window_ms / 1000.0
            last_yield_time = time.monotonic()
            consecutive_rapid_yields = 0  # Track consecutive fast yields for simpler detection

            def check_rapid_yield() -> bool:
                """Check if yields are happening too fast.

                Two detectors run in sequence: a simple consecutive-fast-yield
                counter (< 100ms apart, 10 in a row), then a sliding-window
                check over the timestamp deque for very fast (< 10ms) yields.

                Returns:
                    True if spin loop detected and should disconnect, False otherwise.
                """
                nonlocal last_yield_time, consecutive_rapid_yields
                now = time.monotonic()
                time_since_last = now - last_yield_time
                last_yield_time = now

                # Track consecutive rapid yields (< 100ms apart)
                # This catches spin loops even without full deque analysis
                if time_since_last < 0.1:
                    consecutive_rapid_yields += 1
                    if consecutive_rapid_yields >= 10:  # 10 consecutive fast yields = definite spin
                        logger.error("SSE spin loop detected (%d consecutive rapid yields, last interval %.3fs), client disconnected: %s", consecutive_rapid_yields, time_since_last, self._session_id)
                        return True
                else:
                    consecutive_rapid_yields = 0  # Reset on normal-speed yield

                # Also use deque-based detection for more nuanced analysis
                if yield_timestamps is None:
                    return False
                yield_timestamps.append(now)
                if time_since_last < 0.01:  # Less than 10ms between yields is very fast
                    if len(yield_timestamps) > settings.sse_rapid_yield_max:
                        oldest = yield_timestamps[0]
                        elapsed = now - oldest
                        if elapsed < rapid_yield_window_sec:
                            logger.error(
                                "SSE rapid yield detected (%d yields in %.3fs, last interval %.3fs), client disconnected: %s", len(yield_timestamps), elapsed, time_since_last, self._session_id
                            )
                            return True
                return False

            try:
                while not self._client_gone.is_set():
                    # Check if client has disconnected via request state
                    if await request.is_disconnected():
                        logger.info("SSE client disconnected (detected via request): %s", self._session_id)
                        self._client_gone.set()
                        break

                    try:
                        # Use timeout-based polling only when keepalive is enabled
                        if not settings.sse_keepalive_enabled:
                            message = await self._message_queue.get()
                        else:
                            message = await self._get_message_with_timeout(settings.sse_keepalive_interval)

                        if message is not None:
                            json_bytes = orjson.dumps(message, option=orjson.OPT_SERIALIZE_NUMPY)

                            # Decode only when DEBUG is actually enabled to skip the cost otherwise.
                            if logger.isEnabledFor(logging.DEBUG):
                                logger.debug("Sending SSE message: %s", json_bytes.decode())

                            yield _build_sse_frame(b"message", json_bytes, settings.sse_retry_timeout)
                            consecutive_errors = 0  # Reset on successful send

                            # Check for rapid yields after message send
                            if check_rapid_yield():
                                self._client_gone.set()
                                break
                        elif settings.sse_keepalive_enabled:
                            # Timeout - send keepalive
                            yield _build_sse_frame(b"keepalive", b"{}", settings.sse_retry_timeout)
                            consecutive_errors = 0  # Reset on successful send
                            # Check for rapid yields after keepalive too
                            if check_rapid_yield():
                                self._client_gone.set()
                                break
                        # Note: We don't clear yield_timestamps here. The deque has maxlen
                        # which automatically drops old entries. Clearing would prevent
                        # detection of spin loops where yields happen faster than expected.

                    except GeneratorExit:
                        # Client disconnected - generator is being closed
                        logger.info("SSE generator exit (client disconnected): %s", self._session_id)
                        self._client_gone.set()
                        break
                    except Exception as e:
                        consecutive_errors += 1
                        logger.warning("Error processing SSE message (attempt %d/%d): %s", consecutive_errors, max_consecutive_errors, e)
                        if consecutive_errors >= max_consecutive_errors:
                            logger.info("SSE too many consecutive errors, assuming client disconnected: %s", self._session_id)
                            self._client_gone.set()
                            break
                        # Don't yield error frame - it could cause more errors if client is gone

            except asyncio.CancelledError:
                logger.info("SSE event generator cancelled: %s", self._session_id)
                self._client_gone.set()
            except GeneratorExit:
                logger.info("SSE generator exit: %s", self._session_id)
                self._client_gone.set()
            except Exception as e:
                logger.error("SSE event generator error: %s", e)
                self._client_gone.set()
            finally:
                logger.info("SSE event generator completed: %s", self._session_id)
                self._client_gone.set()  # Always set client_gone on exit to clean up
                # CRITICAL: Also invoke disconnect callback on generator end (Finding 3)
                # This covers server-initiated close, errors, and cancellation - not just client close
                if on_disconnect_callback:
                    try:
                        await on_disconnect_callback()
                    except Exception as e:
                        logger.warning("Disconnect callback in finally failed for %s: %s", self._session_id, e)

        async def on_client_close(_scope: dict) -> None:
            """Handle client close event from sse_starlette.

            Args:
                _scope: ASGI scope passed by sse_starlette (unused).
            """
            logger.info("SSE client close handler called: %s", self._session_id)
            self._client_gone.set()

            # Defensive cleanup via callback (if provided)
            if on_disconnect_callback:
                try:
                    await on_disconnect_callback()
                except Exception as e:
                    logger.warning("Disconnect callback failed for %s: %s", self._session_id, e)

        return EventSourceResponse(
            event_generator(),
            status_code=200,
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Content-Type": "text/event-stream",
                "X-MCP-SSE": "true",
            },
            # Timeout for ASGI send() calls - protects against sends that hang indefinitely
            # when client connection is in a bad state (e.g., client stopped reading but TCP
            # connection not yet closed). Does NOT affect MCP server response times.
            # Set to 0 to disable. Default matches keepalive interval.
            send_timeout=settings.sse_send_timeout if settings.sse_send_timeout > 0 else None,
            # Callback when client closes - helps detect disconnects that ASGI server
            # may not properly propagate via request.is_disconnected()
            client_close_handler_callable=on_client_close,
        )

856 

857 async def _client_disconnected(self, _request: Request) -> bool: 

858 """Check if client has disconnected. 

859 

860 Args: 

861 _request: FastAPI Request object 

862 

863 Returns: 

864 bool: True if client disconnected 

865 

866 Examples: 

867 >>> # Test client disconnected check 

868 >>> transport = SSETransport() 

869 >>> import asyncio 

870 >>> asyncio.run(transport._client_disconnected(None)) 

871 False 

872 

873 >>> # Test after setting client gone 

874 >>> transport = SSETransport() 

875 >>> transport._client_gone.set() 

876 >>> asyncio.run(transport._client_disconnected(None)) 

877 True 

878 """ 

879 # We only check our internal client_gone flag 

880 # We intentionally don't check connection_lost on the request 

881 # as it can be unreliable and cause premature closures 

882 return self._client_gone.is_set() 

883 

884 @property 

885 def session_id(self) -> str: 

886 """ 

887 Get the session ID for this transport. 

888 

889 Returns: 

890 str: session_id 

891 

892 Examples: 

893 >>> # Test session ID property 

894 >>> transport = SSETransport() 

895 >>> session_id = transport.session_id 

896 >>> isinstance(session_id, str) 

897 True 

898 >>> len(session_id) > 0 

899 True 

900 >>> session_id == transport._session_id 

901 True 

902 

903 >>> # Test session ID uniqueness 

904 >>> transport1 = SSETransport() 

905 >>> transport2 = SSETransport() 

906 >>> transport1.session_id != transport2.session_id 

907 True 

908 """ 

909 return self._session_id