Coverage for mcpgateway / transports / sse_transport.py: 100%

255 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/transports/sse_transport.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7SSE Transport Implementation. 

8This module implements Server-Sent Events (SSE) transport for MCP, 

9providing server-to-client streaming with proper session management. 

10""" 

11 

12# Standard 

13import asyncio 

14from collections import deque 

15import logging 

16import time 

17from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, Optional 

18import uuid 

19 

20# Third-Party 

21import anyio 

22from anyio._backends._asyncio import CancelScope 

23from fastapi import Request 

24import orjson 

25from sse_starlette.sse import EventSourceResponse as BaseEventSourceResponse 

26from starlette.types import Receive, Scope, Send 

27 

28# First-Party 

29from mcpgateway.config import settings 

30from mcpgateway.services.logging_service import LoggingService 

31from mcpgateway.transports.base import Transport 

32from mcpgateway.utils.trace_context import set_trace_session_id 

33 

# Initialize logging service first: every helper below reports through `logger`.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)


# =============================================================================
# EXPERIMENTAL WORKAROUND: anyio _deliver_cancellation spin loop (anyio#695)
# =============================================================================
# anyio's _deliver_cancellation can spin at 100% CPU when tasks don't respond
# to CancelledError. This optional monkey-patch adds a maximum iteration limit
# to prevent indefinite spinning. Once the limit is reached, we stop delivering
# cancellation and let the scope exit (tasks may be orphaned, but won't spin).
#
# The workaround is DISABLED by default. Enable it with:
#     ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=true
#
# Trade-offs:
# - Prevents indefinite CPU spin (good)
# - May leave some tasks uncancelled after max iterations (usually harmless)
# - Worker recycling (GUNICORN_MAX_REQUESTS) cleans up orphaned tasks
#
# This workaround may be removed once anyio or the MCP SDK fixes the
# underlying issue. See: https://github.com/agronholm/anyio/issues/695
# =============================================================================

# anyio's unpatched implementation. The patched wrapper delegates to it, and
# remove_anyio_cancel_delivery_patch() assigns it back onto CancelScope.
_original_deliver_cancellation = CancelScope._deliver_cancellation  # type: ignore[attr-defined]  # pylint: disable=protected-access

# Module-level flag that makes apply/remove idempotent.
_patch_applied = False  # pylint: disable=invalid-name

62 

63 

def _create_patched_deliver_cancellation(max_iterations: int):  # noqa: C901
    """Create a patched _deliver_cancellation with configurable max iterations.

    The returned wrapper counts cancellation-delivery cycles on the originating
    cancel scope. Once more than ``max_iterations`` cycles have been attempted, it
    logs a warning, clears the scope's pending cancel handle and reports "no retry",
    so the delivery loop stops instead of spinning (anyio#695).

    Args:
        max_iterations: Maximum delivery cycles before giving up cancellation delivery.

    Returns:
        Patched function that limits cancellation delivery iterations.
    """

    def _patched_deliver_cancellation(self: CancelScope, origin: CancelScope) -> bool:  # pylint: disable=protected-access
        """Patched _deliver_cancellation with max iteration limit to prevent spin.

        Only the top-level call for a delivery cycle (``self is origin``) is
        counted. Calls for nested scopes are delegated straight to the original
        implementation, so the limit measures cycles rather than
        cycles multiplied by the number of nested scopes.

        Args:
            self: The cancel scope being processed.
            origin: The cancel scope that originated the cancellation.

        Returns:
            True if delivery should be retried, False if done or max iterations reached.
        """
        # NOTE(review): anyio's asyncio backend recurses into child scopes via
        # ``scope._deliver_cancellation(origin)``. Because this patch is installed
        # on the class, those recursive calls land here too. Counting them would
        # exhaust the limit far sooner than intended. Confirm against the
        # installed anyio version.
        if self is not origin:
            return _original_deliver_cancellation(self, origin)

        # Track the iteration count on the origin scope (the one that initiated cancel).
        iterations = getattr(origin, "_delivery_iterations", 0) + 1
        origin._delivery_iterations = iterations  # type: ignore[attr-defined]  # pylint: disable=protected-access

        if iterations > max_iterations:
            # Give up - this prevents indefinite spin.
            logger.warning(
                "anyio cancel delivery exceeded %d iterations - giving up to prevent CPU spin. "
                "Some tasks may not have been properly cancelled. "
                "Disable with ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=false if this causes issues.",
                max_iterations,
            )
            # Clear the cancel handle to stop further retries.
            if getattr(self, "_cancel_handle", None) is not None:
                self._cancel_handle = None  # pylint: disable=protected-access
            return False  # Don't retry

        # Within budget: run anyio's original implementation.
        return _original_deliver_cancellation(self, origin)

    return _patched_deliver_cancellation

112 

113 

def apply_anyio_cancel_delivery_patch() -> bool:
    """Install the anyio ``_deliver_cancellation`` workaround when configured.

    Idempotent: once the patch is installed, later calls return True without
    touching ``CancelScope`` again.

    Returns:
        True when the patch is active (newly installed or already present);
        False when it is disabled by configuration or installation failed.
    """
    global _patch_applied  # pylint: disable=global-statement

    if _patch_applied:
        return True

    try:
        enabled = settings.anyio_cancel_delivery_patch_enabled
        if not enabled:
            logger.debug("anyio _deliver_cancellation patch DISABLED. Enable with ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=true if you experience CPU spin loops.")
            return False

        limit = settings.anyio_cancel_delivery_max_iterations
        CancelScope._deliver_cancellation = _create_patched_deliver_cancellation(limit)  # type: ignore[method-assign]  # pylint: disable=protected-access
        _patch_applied = True

        logger.info(
            "anyio _deliver_cancellation patch ENABLED (max_iterations=%d). "
            "This is an experimental workaround for anyio#695. "
            "Disable with ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=false if it causes issues.",
            limit,
        )
    except Exception as exc:
        # Best effort: a failure here must never prevent the module from loading.
        logger.warning("Failed to apply anyio _deliver_cancellation patch: %s", exc)
        return False

    return True

148 

149 

def remove_anyio_cancel_delivery_patch() -> bool:
    """Undo the anyio ``_deliver_cancellation`` monkey-patch.

    Puts anyio's original implementation back on ``CancelScope``.

    Returns:
        True if the patch was removed; False if it was not applied or the
        restore failed.
    """
    global _patch_applied  # pylint: disable=global-statement

    if not _patch_applied:
        return False

    try:
        CancelScope._deliver_cancellation = _original_deliver_cancellation  # type: ignore[method-assign]  # pylint: disable=protected-access
        _patch_applied = False
        logger.info("anyio _deliver_cancellation patch removed - restored original implementation")
    except Exception as exc:
        logger.warning("Failed to remove anyio _deliver_cancellation patch: %s", exc)
        return False

    return True


# Apply patch at module load time if enabled
apply_anyio_cancel_delivery_patch()

175 

176 

177def _get_sse_cleanup_timeout() -> float: 

178 """Get SSE task group cleanup timeout from config. 

179 

180 This timeout controls how long to wait for SSE task group tasks to respond 

181 to cancellation before forcing cleanup. Prevents CPU spin loops in anyio's 

182 _deliver_cancellation when tasks don't properly handle CancelledError. 

183 

184 Returns: 

185 Cleanup timeout in seconds (default: 5.0) 

186 """ 

187 try: 

188 return settings.sse_task_group_cleanup_timeout 

189 except Exception: 

190 return 5.0 # Fallback default 

191 

192 

class EventSourceResponse(BaseEventSourceResponse):
    """sse_starlette EventSourceResponse with a deadline on task-group teardown.

    Mitigation for the anyio#695 CPU spin, where _deliver_cancellation loops at
    100% CPU because tasks in the SSE task group ignore cancellation.

    Rather than putting a timeout on the whole task group (which would cut off
    healthy long-lived SSE connections), ``__call__`` is a copy of the upstream
    method. The only change is that a deadline is attached to the task group's
    cancel scope when teardown begins.

    See:
    - https://github.com/agronholm/anyio/issues/695
    - https://github.com/anthropics/claude-agent-sdk-python/issues/378
    """

    def enable_compression(self, force: bool = False) -> None:  # noqa: ARG002
        """Do nothing: SSE streams are never compressed.

        sse_starlette does not support compression for SSE. This override
        replaces the parent's NotImplementedError with a silent no-op.

        Args:
            force: Ignored - compression not supported for SSE.
        """

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Serve the SSE stream, tearing down the task group under a deadline.

        Mirrors sse_starlette's implementation, except that whichever wrapped
        coroutine finishes first sets a deadline on the group's cancel scope and
        then cancels it.

        Args:
            scope: ASGI scope dictionary.
            receive: ASGI receive callable.
            send: ASGI send callable.
        """
        # Copy of sse_starlette.sse.EventSourceResponse.__call__ with deadline fix.
        # No deadline exists while streaming; one is assigned only at the
        # moment teardown starts (see cancel_on_finish).
        async with anyio.create_task_group() as group:

            async def cancel_on_finish(job: Callable[[], Awaitable[None]]) -> None:
                """Await ``job``, then cancel the enclosing task group.

                Before cancelling, a deadline of now + cleanup timeout is put on
                the group's cancel scope (anyio#695 mitigation).

                Args:
                    job: Async callable whose completion triggers teardown.
                """
                await job()
                cancel_scope = group.cancel_scope
                # NOTE(review): cancel() is issued immediately after the deadline
                # is set. Confirm that the deadline adds a bound beyond the
                # explicit cancel for the anyio version in use.
                cancel_scope.deadline = anyio.current_time() + _get_sse_cleanup_timeout()
                cancel_scope.cancel()

            group.start_soon(cancel_on_finish, lambda: self._stream_response(send))
            group.start_soon(cancel_on_finish, lambda: self._ping(send))
            group.start_soon(cancel_on_finish, self._listen_for_exit_signal)

            if self.data_sender_callable:
                group.start_soon(self.data_sender_callable)

            # Wait for the client to disconnect last
            group.start_soon(cancel_on_finish, lambda: self._listen_for_disconnect(receive))

        if self.background is not None:
            await self.background()

265 

266 

267# Pre-computed SSE frame components for performance 

268_SSE_EVENT_PREFIX = b"event: " 

269_SSE_DATA_PREFIX = b"\r\ndata: " 

270_SSE_RETRY_PREFIX = b"\r\nretry: " 

271_SSE_FRAME_END = b"\r\n\r\n" 

272 

273 

274def _build_sse_frame(event: bytes, data: bytes, retry: int) -> bytes: 

275 """Build SSE frame as bytes to avoid encode/decode overhead. 

276 

277 Args: 

278 event: SSE event type as bytes (e.g., b'message', b'keepalive', b'error') 

279 data: JSON data as bytes (from orjson.dumps) 

280 retry: Retry timeout in milliseconds 

281 

282 Returns: 

283 Complete SSE frame as bytes 

284 

285 Note: 

286 Uses hardcoded CRLF (\\r\\n) separators matching sse_starlette's 

287 ServerSentEvent.DEFAULT_SEPARATOR. If custom separators are ever 

288 needed, this function would need to accept a sep parameter. 

289 

290 Examples: 

291 >>> _build_sse_frame(b"message", b'{"test": 1}', 15000) 

292 b'event: message\\r\\ndata: {"test": 1}\\r\\nretry: 15000\\r\\n\\r\\n' 

293 

294 >>> _build_sse_frame(b"keepalive", b"{}", 15000) 

295 b'event: keepalive\\r\\ndata: {}\\r\\nretry: 15000\\r\\n\\r\\n' 

296 """ 

297 return _SSE_EVENT_PREFIX + event + _SSE_DATA_PREFIX + data + _SSE_RETRY_PREFIX + str(retry).encode() + _SSE_FRAME_END 

298 

299 

class SSETransport(Transport):
    """Transport implementation using Server-Sent Events with proper session management.

    This transport implementation uses Server-Sent Events (SSE) for real-time
    communication between the MCP gateway and clients. It provides streaming
    capabilities with automatic session management and keepalive support.

    Examples:
        >>> # Create SSE transport with default URL
        >>> transport = SSETransport()
        >>> transport
        <mcpgateway.transports.sse_transport.SSETransport object at ...>

        >>> # Create SSE transport with custom URL
        >>> transport = SSETransport("http://localhost:8080")
        >>> transport._base_url
        'http://localhost:8080'

        >>> # Check initial connection state
        >>> import asyncio
        >>> asyncio.run(transport.is_connected())
        False

        >>> # Verify it's a proper Transport subclass
        >>> isinstance(transport, Transport)
        True
        >>> issubclass(SSETransport, Transport)
        True

        >>> # Check session ID generation
        >>> transport.session_id
        '...'
        >>> len(transport.session_id) > 0
        True

        >>> # Verify required methods exist
        >>> hasattr(transport, 'connect')
        True
        >>> hasattr(transport, 'disconnect')
        True
        >>> hasattr(transport, 'send_message')
        True
        >>> hasattr(transport, 'receive_message')
        True
        >>> hasattr(transport, 'is_connected')
        True
    """

    # Compiled regex for session IDs (ASCII letters, digits, hyphens, underscores).
    # Built lazily by _is_valid_session_id; the 128-char limit is checked separately.
    _SESSION_ID_PATTERN = None

    @staticmethod
    def _is_valid_session_id(session_id: str) -> bool:
        """Validate session ID format for security.

        Valid session IDs are non-empty, at most 128 characters, and consist only
        of ASCII letters, digits, hyphens and underscores. The whole string must
        match. ``fullmatch`` is used because a ``$`` anchor with ``re.match``
        also accepts a single trailing newline (e.g. ``"abc\\n"``), which would
        let control characters through this check.

        Args:
            session_id: The session ID to validate.

        Returns:
            True if the session ID is valid, False otherwise.

        Examples:
            >>> SSETransport._is_valid_session_id("abc123")
            True
            >>> SSETransport._is_valid_session_id("abc-123_def")
            True
            >>> SSETransport._is_valid_session_id("a" * 128)
            True
            >>> SSETransport._is_valid_session_id("a" * 129)
            False
            >>> SSETransport._is_valid_session_id("")
            False
            >>> SSETransport._is_valid_session_id("abc/def")
            False
            >>> SSETransport._is_valid_session_id("abc:def")
            False
            >>> SSETransport._is_valid_session_id("abc\\n")
            False
        """
        # Standard
        import re  # pylint: disable=import-outside-toplevel

        if not session_id or len(session_id) > 128:
            return False

        # Lazy compile pattern (anchoring is provided by fullmatch below)
        if SSETransport._SESSION_ID_PATTERN is None:
            SSETransport._SESSION_ID_PATTERN = re.compile(r"[a-zA-Z0-9_-]+")

        return SSETransport._SESSION_ID_PATTERN.fullmatch(session_id) is not None

    def __init__(self, base_url: Optional[str] = None):
        """Initialize SSE transport.

        Generates a fresh UUID4 session ID and registers it as the trace
        session ID.

        Args:
            base_url: Base URL for client message endpoints. Defaults to
                ``http://{settings.host}:{settings.port}``.

        Examples:
            >>> # Test default initialization
            >>> transport = SSETransport()
            >>> transport._connected
            False
            >>> transport._message_queue is not None
            True
            >>> transport._client_gone is not None
            True
            >>> len(transport._session_id) > 0
            True

            >>> # Test custom base URL
            >>> transport = SSETransport("https://api.example.com")
            >>> transport._base_url
            'https://api.example.com'

            >>> # Test session ID uniqueness
            >>> transport1 = SSETransport()
            >>> transport2 = SSETransport()
            >>> transport1.session_id != transport2.session_id
            True
        """
        self._base_url = base_url or f"http://{settings.host}:{settings.port}"
        self._connected = False
        self._message_queue = asyncio.Queue()
        self._client_gone = asyncio.Event()
        # Server generates session_id for SSE - client receives it in endpoint event
        self._session_id = str(uuid.uuid4())
        set_trace_session_id(self._session_id)

        logger.info("Creating SSE transport with base_url=%s, session_id=%s", self._base_url, self._session_id)

    async def connect(self) -> None:
        """Set up SSE connection.

        Examples:
            >>> # Test connection setup
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport.connect())
            >>> transport._connected
            True
            >>> asyncio.run(transport.is_connected())
            True
        """
        self._connected = True
        logger.info("SSE transport connected: %s", self._session_id)

    async def disconnect(self) -> None:
        """Clean up SSE connection.

        Marks the transport disconnected and sets the client-gone event. This is
        a no-op if the transport is not connected.

        Examples:
            >>> # Test disconnection
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport.connect())
            >>> asyncio.run(transport.disconnect())
            >>> transport._connected
            False
            >>> transport._client_gone.is_set()
            True
            >>> asyncio.run(transport.is_connected())
            False

            >>> # Test disconnection when already disconnected
            >>> transport = SSETransport()
            >>> asyncio.run(transport.disconnect())
            >>> transport._connected
            False
        """
        if self._connected:
            self._connected = False
            self._client_gone.set()
            logger.info("SSE transport disconnected: %s", self._session_id)

    async def send_message(self, message: Dict[str, Any]) -> None:
        """Send a message over SSE.

        The message is placed on the internal queue; the SSE event generator
        created by ``create_sse_response`` drains it and writes it to the client.

        Args:
            message: Message to send

        Raises:
            RuntimeError: If transport is not connected
            Exception: If unable to put message to queue

        Examples:
            >>> # Test sending message when connected
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport.connect())
            >>> message = {"jsonrpc": "2.0", "method": "test", "id": 1}
            >>> asyncio.run(transport.send_message(message))
            >>> transport._message_queue.qsize()
            1

            >>> # Test sending message when not connected
            >>> transport = SSETransport()
            >>> try:
            ...     asyncio.run(transport.send_message({"test": "message"}))
            ... except RuntimeError as e:
            ...     print("Expected error:", str(e))
            Expected error: Transport not connected

            >>> # Test message format validation
            >>> transport = SSETransport()
            >>> asyncio.run(transport.connect())
            >>> valid_message = {"jsonrpc": "2.0", "method": "initialize", "params": {}}
            >>> isinstance(valid_message, dict)
            True
            >>> "jsonrpc" in valid_message
            True

            >>> # Test exception handling in queue put
            >>> transport = SSETransport()
            >>> asyncio.run(transport.connect())
            >>> # Create a full queue to trigger exception
            >>> transport._message_queue = asyncio.Queue(maxsize=1)
            >>> asyncio.run(transport._message_queue.put({"dummy": "message"}))
            >>> # Now queue is full, next put should fail
            >>> try:
            ...     asyncio.run(asyncio.wait_for(transport.send_message({"test": "message"}), timeout=0.1))
            ... except asyncio.TimeoutError:
            ...     print("Queue full as expected")
            Queue full as expected
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        try:
            await self._message_queue.put(message)
            logger.debug("Message queued for SSE: %s, method=%s", self._session_id, message.get("method", "(response)"))
        except Exception as e:
            logger.error("Failed to queue message: %s", e)
            raise

    async def receive_message(self) -> AsyncGenerator[Dict[str, Any], None]:
        """Receive messages from the client over SSE transport.

        This method implements a continuous message-receiving pattern for SSE transport.
        Since SSE is primarily a server-to-client communication channel, this method
        yields an initial initialize placeholder message and then enters a waiting loop.
        The actual client messages are received via a separate HTTP POST endpoint
        (not handled in this method).

        The method will continue running until either:
        1. The connection is explicitly disconnected (client_gone event is set)
        2. The receive loop is cancelled from outside

        Yields:
            Dict[str, Any]: JSON-RPC formatted messages. The first yielded message is always
                an initialize placeholder with the format:
                {"jsonrpc": "2.0", "method": "initialize", "id": 1}

        Raises:
            RuntimeError: If the transport is not connected when this method is called
            asyncio.CancelledError: When the SSE receive loop is cancelled externally

        Examples:
            >>> # Test receive message when connected
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport.connect())
            >>> async def test_receive():
            ...     async for msg in transport.receive_message():
            ...         return msg
            ...     return None
            >>> result = asyncio.run(test_receive())
            >>> result
            {'jsonrpc': '2.0', 'method': 'initialize', 'id': 1}

            >>> # Test receive message when not connected
            >>> transport = SSETransport()
            >>> try:
            ...     async def test_receive():
            ...         async for msg in transport.receive_message():
            ...             pass
            ...     asyncio.run(test_receive())
            ... except RuntimeError as e:
            ...     print("Expected error:", str(e))
            Expected error: Transport not connected

            >>> # Verify generator behavior
            >>> transport = SSETransport()
            >>> import inspect
            >>> inspect.isasyncgenfunction(transport.receive_message)
            True
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        # For SSE, we set up a loop to wait for messages which are delivered via POST
        # Most messages come via the POST endpoint, but we yield an initial initialize placeholder
        # to keep the receive loop running
        yield {"jsonrpc": "2.0", "method": "initialize", "id": 1}

        # Continue waiting for cancellation (polls the client-gone flag once per second)
        try:
            while not self._client_gone.is_set():
                await asyncio.sleep(1.0)
        except asyncio.CancelledError:
            logger.info("SSE receive loop cancelled for session %s", self._session_id)
            raise
        finally:
            logger.info("SSE receive loop ended for session %s", self._session_id)

    async def _get_message_with_timeout(self, timeout: Optional[float]) -> Optional[Dict[str, Any]]:
        """Get message from queue with timeout, returns None on timeout.

        Uses asyncio.wait() to avoid TimeoutError exception overhead.

        Note:
            If the caller is cancelled while waiting, the inner get task is
            cancelled and awaited; a message it had already dequeued in that
            race window is discarded.

        Args:
            timeout: Timeout in seconds, or None for no timeout

        Returns:
            Message dict if received, None if timeout occurred

        Raises:
            asyncio.CancelledError: If the operation is cancelled externally
        """
        if timeout is None:
            return await self._message_queue.get()

        get_task = asyncio.create_task(self._message_queue.get())
        try:
            done, _ = await asyncio.wait({get_task}, timeout=timeout)
        except asyncio.CancelledError:
            get_task.cancel()
            try:
                await get_task
            except asyncio.CancelledError:
                pass
            raise

        if get_task in done:
            return get_task.result()

        # Timeout - cancel pending task, but return the result if it completed in the race window.
        get_task.cancel()
        try:
            return await get_task
        except asyncio.CancelledError:
            return None

    async def is_connected(self) -> bool:
        """Check if transport is connected.

        Returns:
            True if connected

        Examples:
            >>> # Test initial state
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport.is_connected())
            False

            >>> # Test after connection
            >>> transport = SSETransport()
            >>> asyncio.run(transport.connect())
            >>> asyncio.run(transport.is_connected())
            True

            >>> # Test after disconnection
            >>> transport = SSETransport()
            >>> asyncio.run(transport.connect())
            >>> asyncio.run(transport.disconnect())
            >>> asyncio.run(transport.is_connected())
            False
        """
        return self._connected

    async def create_sse_response(
        self,
        request: Request,
        on_disconnect_callback: Callable[[], Awaitable[None]] | None = None,
    ) -> EventSourceResponse:
        """Create SSE response for streaming.

        The stream starts with an ``endpoint`` event carrying the POST URL for
        this session. When keepalive is enabled, an immediate ``keepalive`` event
        follows. The stream then relays queued messages, sending periodic
        keepalives when idle. It stops when the client is detected as gone,
        after three consecutive processing errors, or when a rapid-yield spin
        loop is detected.

        Args:
            request: FastAPI request (used for disconnection detection)
            on_disconnect_callback: Optional async callback to run when client disconnects.
                Used for defensive cleanup (e.g., cancelling respond tasks). It may be
                invoked more than once (from the client-close handler and again when
                the event generator finishes), so it should be idempotent.

        Returns:
            SSE response object

        Examples:
            >>> # Test SSE response creation
            >>> transport = SSETransport("http://localhost:8000")
            >>> # Note: This method requires a FastAPI Request object
            >>> # and cannot be easily tested in doctest environment
            >>> callable(transport.create_sse_response)
            True
        """
        endpoint_url = f"{self._base_url}/message?session_id={self._session_id}"

        async def event_generator():
            """Generate SSE events.

            Yields:
                SSE event as bytes (pre-formatted SSE frame)

            Raises:
                asyncio.CancelledError: If the generator is cancelled.
            """
            # Send the endpoint event first
            yield _build_sse_frame(b"endpoint", endpoint_url.encode(), settings.sse_retry_timeout)

            # Send keepalive immediately to help establish connection (if enabled)
            if settings.sse_keepalive_enabled:
                yield _build_sse_frame(b"keepalive", b"{}", settings.sse_retry_timeout)

            consecutive_errors = 0
            max_consecutive_errors = 3  # Exit after 3 consecutive errors (likely client disconnected)

            # Rapid yield detection: If we're yielding faster than expected, client is likely disconnected
            # but ASGI server isn't properly signaling it. Track yield timestamps in a sliding window.
            yield_timestamps: Optional[deque] = deque(maxlen=settings.sse_rapid_yield_max + 1) if settings.sse_rapid_yield_max > 0 else None
            rapid_yield_window_sec = settings.sse_rapid_yield_window_ms / 1000.0
            last_yield_time = time.monotonic()
            consecutive_rapid_yields = 0  # Track consecutive fast yields for simpler detection

            def check_rapid_yield() -> bool:
                """Check if yields are happening too fast.

                Returns:
                    True if spin loop detected and should disconnect, False otherwise.
                """
                nonlocal last_yield_time, consecutive_rapid_yields
                now = time.monotonic()
                time_since_last = now - last_yield_time
                last_yield_time = now

                # Track consecutive rapid yields (< 100ms apart)
                # This catches spin loops even without full deque analysis
                if time_since_last < 0.1:
                    consecutive_rapid_yields += 1
                    if consecutive_rapid_yields >= 10:  # 10 consecutive fast yields = definite spin
                        logger.error("SSE spin loop detected (%d consecutive rapid yields, last interval %.3fs), client disconnected: %s", consecutive_rapid_yields, time_since_last, self._session_id)
                        return True
                else:
                    consecutive_rapid_yields = 0  # Reset on normal-speed yield

                # Also use deque-based detection for more nuanced analysis
                if yield_timestamps is None:
                    return False
                yield_timestamps.append(now)
                if time_since_last < 0.01:  # Less than 10ms between yields is very fast
                    if len(yield_timestamps) > settings.sse_rapid_yield_max:
                        oldest = yield_timestamps[0]
                        elapsed = now - oldest
                        if elapsed < rapid_yield_window_sec:
                            logger.error(
                                "SSE rapid yield detected (%d yields in %.3fs, last interval %.3fs), client disconnected: %s", len(yield_timestamps), elapsed, time_since_last, self._session_id
                            )
                            return True
                return False

            try:
                while not self._client_gone.is_set():
                    # Check if client has disconnected via request state
                    if await request.is_disconnected():
                        logger.info("SSE client disconnected (detected via request): %s", self._session_id)
                        self._client_gone.set()
                        break

                    try:
                        # Use timeout-based polling only when keepalive is enabled
                        if not settings.sse_keepalive_enabled:
                            message = await self._message_queue.get()
                        else:
                            message = await self._get_message_with_timeout(settings.sse_keepalive_interval)

                        if message is not None:
                            json_bytes = orjson.dumps(message, option=orjson.OPT_SERIALIZE_NUMPY)

                            if logger.isEnabledFor(logging.DEBUG):
                                logger.debug("Sending SSE message: %s", json_bytes.decode())

                            yield _build_sse_frame(b"message", json_bytes, settings.sse_retry_timeout)
                            consecutive_errors = 0  # Reset on successful send

                            # Check for rapid yields after message send
                            if check_rapid_yield():
                                self._client_gone.set()
                                break
                        elif settings.sse_keepalive_enabled:
                            # Timeout - send keepalive
                            yield _build_sse_frame(b"keepalive", b"{}", settings.sse_retry_timeout)
                            consecutive_errors = 0  # Reset on successful send
                            # Check for rapid yields after keepalive too
                            if check_rapid_yield():
                                self._client_gone.set()
                                break
                            # Note: We don't clear yield_timestamps here. The deque has maxlen
                            # which automatically drops old entries. Clearing would prevent
                            # detection of spin loops where yields happen faster than expected.

                    except GeneratorExit:
                        # Client disconnected - generator is being closed
                        logger.info("SSE generator exit (client disconnected): %s", self._session_id)
                        self._client_gone.set()
                        break
                    except Exception as e:
                        consecutive_errors += 1
                        logger.warning("Error processing SSE message (attempt %d/%d): %s", consecutive_errors, max_consecutive_errors, e)
                        if consecutive_errors >= max_consecutive_errors:
                            logger.info("SSE too many consecutive errors, assuming client disconnected: %s", self._session_id)
                            self._client_gone.set()
                            break
                        # Don't yield error frame - it could cause more errors if client is gone

            except asyncio.CancelledError:
                logger.info("SSE event generator cancelled: %s", self._session_id)
                self._client_gone.set()
                raise
            except GeneratorExit:
                logger.info("SSE generator exit: %s", self._session_id)
                self._client_gone.set()
            except Exception as e:
                logger.error("SSE event generator error: %s", e)
                self._client_gone.set()
            finally:
                logger.info("SSE event generator completed: %s", self._session_id)
                self._client_gone.set()  # Always set client_gone on exit to clean up
                # CRITICAL: Also invoke disconnect callback on generator end (Finding 3)
                # This covers server-initiated close, errors, and cancellation - not just client close
                if on_disconnect_callback:
                    try:
                        await on_disconnect_callback()
                    except Exception as e:
                        logger.warning("Disconnect callback in finally failed for %s: %s", self._session_id, e)

        async def on_client_close(_scope: dict) -> None:
            """Handle client close event from sse_starlette."""
            logger.info("SSE client close handler called: %s", self._session_id)
            self._client_gone.set()

            # Defensive cleanup via callback (if provided)
            if on_disconnect_callback:
                try:
                    await on_disconnect_callback()
                except Exception as e:
                    logger.warning("Disconnect callback failed for %s: %s", self._session_id, e)

        return EventSourceResponse(
            event_generator(),
            status_code=200,
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Content-Type": "text/event-stream",
                "X-MCP-SSE": "true",
            },
            # Timeout for ASGI send() calls - protects against sends that hang indefinitely
            # when client connection is in a bad state (e.g., client stopped reading but TCP
            # connection not yet closed). Does NOT affect MCP server response times.
            # Set to 0 to disable.
            send_timeout=settings.sse_send_timeout if settings.sse_send_timeout > 0 else None,
            # Callback when client closes - helps detect disconnects that ASGI server
            # may not properly propagate via request.is_disconnected()
            client_close_handler_callable=on_client_close,
        )

    async def _client_disconnected(self, _request: Request) -> bool:
        """Check if client has disconnected.

        Args:
            _request: FastAPI Request object (unused)

        Returns:
            bool: True if client disconnected

        Examples:
            >>> # Test client disconnected check
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport._client_disconnected(None))
            False

            >>> # Test after setting client gone
            >>> transport = SSETransport()
            >>> transport._client_gone.set()
            >>> asyncio.run(transport._client_disconnected(None))
            True
        """
        # We only check our internal client_gone flag
        # We intentionally don't check connection_lost on the request
        # as it can be unreliable and cause premature closures
        return self._client_gone.is_set()

    @property
    def session_id(self) -> str:
        """
        Get the session ID for this transport.

        Returns:
            str: session_id

        Examples:
            >>> # Test session ID property
            >>> transport = SSETransport()
            >>> session_id = transport.session_id
            >>> isinstance(session_id, str)
            True
            >>> len(session_id) > 0
            True
            >>> session_id == transport._session_id
            True

            >>> # Test session ID uniqueness
            >>> transport1 = SSETransport()
            >>> transport2 = SSETransport()
            >>> transport1.session_id != transport2.session_id
            True
        """
        return self._session_id