Coverage for mcpgateway/transports/sse_transport.py: 100% — 253 statements (coverage.py v7.13.4, created at 2026-03-09 03:05 +0000)
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/transports/sse_transport.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7SSE Transport Implementation.
8This module implements Server-Sent Events (SSE) transport for MCP,
9providing server-to-client streaming with proper session management.
10"""
# Standard
import asyncio
from collections import deque
import logging
import re
import time
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, Optional
import uuid
20# Third-Party
21import anyio
22from anyio._backends._asyncio import CancelScope
23from fastapi import Request
24import orjson
25from sse_starlette.sse import EventSourceResponse as BaseEventSourceResponse
26from starlette.types import Receive, Scope, Send
28# First-Party
29from mcpgateway.config import settings
30from mcpgateway.services.logging_service import LoggingService
31from mcpgateway.transports.base import Transport
33# Initialize logging service first
34logging_service = LoggingService()
35logger = logging_service.get_logger(__name__)
38# =============================================================================
39# EXPERIMENTAL WORKAROUND: anyio _deliver_cancellation spin loop (anyio#695)
40# =============================================================================
41# anyio's _deliver_cancellation can spin at 100% CPU when tasks don't respond
42# to CancelledError. This optional monkey-patch adds a max iteration limit to
43# prevent indefinite spinning. After the limit is reached, we give up delivering
44# cancellation and let the scope exit (tasks will be orphaned but won't spin).
45#
46# This workaround is DISABLED by default. Enable via:
47# ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=true
48#
49# Trade-offs:
50# - Prevents indefinite CPU spin (good)
51# - May leave some tasks uncancelled after max iterations (usually harmless)
52# - Worker recycling (GUNICORN_MAX_REQUESTS) cleans up orphaned tasks
53#
54# This workaround may be removed when anyio or MCP SDK fix the underlying issue.
55# See: https://github.com/agronholm/anyio/issues/695
56# =============================================================================
# Store original for potential restoration and for the patch to call.
# The patched wrapper delegates to this after its iteration-limit check.
_original_deliver_cancellation = CancelScope._deliver_cancellation  # type: ignore[attr-defined] # pylint: disable=protected-access
# Module-level flag: True while the monkey-patch is installed (guards idempotency).
_patch_applied = False  # pylint: disable=invalid-name
def _create_patched_deliver_cancellation(max_iterations: int):  # noqa: C901
    """Create a patched _deliver_cancellation with configurable max iterations.

    The returned closure captures ``max_iterations`` so the limit is fixed at
    patch-install time (see apply_anyio_cancel_delivery_patch).

    Args:
        max_iterations: Maximum iterations before giving up cancellation delivery.

    Returns:
        Patched function that limits cancellation delivery iterations.
    """

    def _patched_deliver_cancellation(self: CancelScope, origin: CancelScope) -> bool:  # pylint: disable=protected-access
        """Patched _deliver_cancellation with max iteration limit to prevent spin.

        This wraps anyio's original _deliver_cancellation to track iteration count
        and give up after a maximum number of attempts. This prevents the CPU spin
        loop that occurs when tasks don't respond to CancelledError.

        Args:
            self: The cancel scope being processed.
            origin: The cancel scope that originated the cancellation.

        Returns:
            True if delivery should be retried, False if done or max iterations reached.
        """
        # Track iteration count on the origin scope (the one that initiated cancel).
        # The counter is stored as a dynamic attribute because CancelScope is a
        # third-party class we cannot extend.
        if not hasattr(origin, "_delivery_iterations"):
            origin._delivery_iterations = 0  # type: ignore[attr-defined] # pylint: disable=protected-access

        origin._delivery_iterations += 1  # type: ignore[attr-defined] # pylint: disable=protected-access

        # Check if we've exceeded the maximum iterations
        if origin._delivery_iterations > max_iterations:  # type: ignore[attr-defined] # pylint: disable=protected-access
            # Log warning and give up - this prevents indefinite spin
            logger.warning(
                "anyio cancel delivery exceeded %d iterations - giving up to prevent CPU spin. "
                "Some tasks may not have been properly cancelled. "
                "Disable with ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=false if this causes issues.",
                max_iterations,
            )
            # Clear the cancel handle to stop further retries
            # NOTE(review): relies on anyio's private _cancel_handle attribute —
            # verify against the pinned anyio version on upgrade.
            if hasattr(self, "_cancel_handle") and self._cancel_handle is not None:  # pylint: disable=protected-access
                self._cancel_handle = None  # pylint: disable=protected-access
            return False  # Don't retry

        # Call the original implementation
        return _original_deliver_cancellation(self, origin)

    return _patched_deliver_cancellation
def apply_anyio_cancel_delivery_patch() -> bool:
    """Install the anyio ``_deliver_cancellation`` monkey-patch when enabled.

    Reads the enable flag and iteration limit from settings. Idempotent:
    once the patch is active, repeated calls return immediately.

    Returns:
        True if the patch is active (newly installed or already installed),
        False if disabled via settings or installation failed.
    """
    global _patch_applied  # pylint: disable=global-statement

    # Fast path: already installed.
    if _patch_applied:
        return True

    try:
        if not settings.anyio_cancel_delivery_patch_enabled:
            logger.debug("anyio _deliver_cancellation patch DISABLED. Enable with ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=true if you experience CPU spin loops.")
            return False

        limit = settings.anyio_cancel_delivery_max_iterations
        # Swap in the bounded wrapper on the (private) anyio CancelScope API.
        CancelScope._deliver_cancellation = _create_patched_deliver_cancellation(limit)  # type: ignore[method-assign] # pylint: disable=protected-access
        _patch_applied = True

        logger.info(
            "anyio _deliver_cancellation patch ENABLED (max_iterations=%d). "
            "This is an experimental workaround for anyio#695. "
            "Disable with ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=false if it causes issues.",
            limit,
        )
        return True
    except Exception as e:
        # Best-effort: a failed patch must never break module import.
        logger.warning("Failed to apply anyio _deliver_cancellation patch: %s", e)
        return False
def remove_anyio_cancel_delivery_patch() -> bool:
    """Uninstall the anyio ``_deliver_cancellation`` monkey-patch.

    Puts anyio's original implementation back in place.

    Returns:
        True on successful removal, False if the patch was never applied
        (or restoration failed).
    """
    global _patch_applied  # pylint: disable=global-statement

    # Nothing to do if the patch was never installed.
    if not _patch_applied:
        return False

    try:
        CancelScope._deliver_cancellation = _original_deliver_cancellation  # type: ignore[method-assign] # pylint: disable=protected-access
        _patch_applied = False
        logger.info("anyio _deliver_cancellation patch removed - restored original implementation")
        return True
    except Exception as e:
        logger.warning("Failed to remove anyio _deliver_cancellation patch: %s", e)
        return False
# Apply patch at module load time if enabled (no-op unless
# ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=true in settings).
apply_anyio_cancel_delivery_patch()
def _get_sse_cleanup_timeout() -> float:
    """Return the SSE task-group cleanup timeout from configuration.

    Bounds how long SSE task-group tasks get to react to cancellation
    before cleanup is forced, mitigating the anyio#695 CPU spin loop
    when tasks ignore CancelledError.

    Returns:
        Cleanup timeout in seconds (default: 5.0)
    """
    try:
        timeout = settings.sse_task_group_cleanup_timeout
    except Exception:
        # Best-effort: fall back to a sane default if settings are unavailable.
        timeout = 5.0
    return timeout
class EventSourceResponse(BaseEventSourceResponse):
    """Patched EventSourceResponse with CPU spin detection.

    This mitigates a CPU spin loop issue (anyio#695) where _deliver_cancellation
    spins at 100% CPU when tasks in the SSE task group don't respond to
    cancellation.

    Instead of trying to timeout the task group (which would affect normal
    SSE connections), we copy the __call__ method and add a deadline to
    the cancel scope to ensure cleanup doesn't hang indefinitely.

    See:
        - https://github.com/agronholm/anyio/issues/695
        - https://github.com/anthropics/claude-agent-sdk-python/issues/378
    """

    def enable_compression(self, force: bool = False) -> None:  # noqa: ARG002
        """Enable compression (no-op for SSE streams).

        SSE streams don't support compression as per sse_starlette.
        This override prevents NotImplementedError from parent class.

        Args:
            force: Ignored - compression not supported for SSE.
        """

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Handle SSE request with cancel scope deadline to prevent spin.

        This method is copied from sse_starlette with one key modification:
        the task group's cancel_scope gets a deadline set when cancellation
        starts, preventing indefinite spinning if tasks don't respond.

        Args:
            scope: ASGI scope dictionary.
            receive: ASGI receive callable.
            send: ASGI send callable.
        """
        # Copy of sse_starlette.sse.EventSourceResponse.__call__ with deadline fix
        async with anyio.create_task_group() as task_group:
            # Add deadline to cancel scope to prevent indefinite spin on cleanup
            # The deadline is set far in the future initially, and only becomes
            # relevant if the scope is cancelled and cleanup takes too long

            async def cancel_on_finish(coro: Callable[[], Awaitable[None]]) -> None:
                """Execute coroutine then cancel task group with bounded deadline.

                This wrapper runs the given coroutine and, upon completion, cancels
                the parent task group with a deadline to prevent indefinite spinning
                if other tasks don't respond to cancellation (anyio#695 mitigation).

                Args:
                    coro: Async callable to execute before triggering cancellation.
                """
                await coro()
                # When cancelling, set a deadline to prevent indefinite spin
                # if other tasks don't respond to cancellation
                task_group.cancel_scope.deadline = anyio.current_time() + _get_sse_cleanup_timeout()
                task_group.cancel_scope.cancel()

            # Whichever of these finishes first triggers a bounded cancellation
            # of the whole group (stream end, ping loop exit, or exit signal).
            task_group.start_soon(cancel_on_finish, lambda: self._stream_response(send))
            task_group.start_soon(cancel_on_finish, lambda: self._ping(send))
            task_group.start_soon(cancel_on_finish, self._listen_for_exit_signal)

            if self.data_sender_callable:
                task_group.start_soon(self.data_sender_callable)

            # Wait for the client to disconnect last
            task_group.start_soon(cancel_on_finish, lambda: self._listen_for_disconnect(receive))

        # Run any starlette background task after the stream has fully closed.
        if self.background is not None:
            await self.background()
# Pre-computed SSE frame components for performance.
# CRLF separators match sse_starlette's ServerSentEvent.DEFAULT_SEPARATOR.
_SSE_EVENT_PREFIX = b"event: "
_SSE_DATA_PREFIX = b"\r\ndata: "
_SSE_RETRY_PREFIX = b"\r\nretry: "
_SSE_FRAME_END = b"\r\n\r\n"
273def _build_sse_frame(event: bytes, data: bytes, retry: int) -> bytes:
274 """Build SSE frame as bytes to avoid encode/decode overhead.
276 Args:
277 event: SSE event type as bytes (e.g., b'message', b'keepalive', b'error')
278 data: JSON data as bytes (from orjson.dumps)
279 retry: Retry timeout in milliseconds
281 Returns:
282 Complete SSE frame as bytes
284 Note:
285 Uses hardcoded CRLF (\\r\\n) separators matching sse_starlette's
286 ServerSentEvent.DEFAULT_SEPARATOR. If custom separators are ever
287 needed, this function would need to accept a sep parameter.
289 Examples:
290 >>> _build_sse_frame(b"message", b'{"test": 1}', 15000)
291 b'event: message\\r\\ndata: {"test": 1}\\r\\nretry: 15000\\r\\n\\r\\n'
293 >>> _build_sse_frame(b"keepalive", b"{}", 15000)
294 b'event: keepalive\\r\\ndata: {}\\r\\nretry: 15000\\r\\n\\r\\n'
295 """
296 return _SSE_EVENT_PREFIX + event + _SSE_DATA_PREFIX + data + _SSE_RETRY_PREFIX + str(retry).encode() + _SSE_FRAME_END
299class SSETransport(Transport):
300 """Transport implementation using Server-Sent Events with proper session management.
302 This transport implementation uses Server-Sent Events (SSE) for real-time
303 communication between the MCP gateway and clients. It provides streaming
304 capabilities with automatic session management and keepalive support.
306 Examples:
307 >>> # Create SSE transport with default URL
308 >>> transport = SSETransport()
309 >>> transport
310 <mcpgateway.transports.sse_transport.SSETransport object at ...>
312 >>> # Create SSE transport with custom URL
313 >>> transport = SSETransport("http://localhost:8080")
314 >>> transport._base_url
315 'http://localhost:8080'
317 >>> # Check initial connection state
318 >>> import asyncio
319 >>> asyncio.run(transport.is_connected())
320 False
322 >>> # Verify it's a proper Transport subclass
323 >>> isinstance(transport, Transport)
324 True
325 >>> issubclass(SSETransport, Transport)
326 True
328 >>> # Check session ID generation
329 >>> transport.session_id
330 '...'
331 >>> len(transport.session_id) > 0
332 True
334 >>> # Verify required methods exist
335 >>> hasattr(transport, 'connect')
336 True
337 >>> hasattr(transport, 'disconnect')
338 True
339 >>> hasattr(transport, 'send_message')
340 True
341 >>> hasattr(transport, 'receive_message')
342 True
343 >>> hasattr(transport, 'is_connected')
344 True
345 """
    # Regex for validating session IDs (alphanumeric, hyphens, underscores, max 128 chars).
    # Lazily compiled on first use by _is_valid_session_id; None until then.
    _SESSION_ID_PATTERN = None
350 @staticmethod
351 def _is_valid_session_id(session_id: str) -> bool:
352 """Validate session ID format for security.
354 Valid session IDs are alphanumeric with hyphens and underscores, max 128 chars.
355 This prevents injection attacks via malformed session IDs.
357 Args:
358 session_id: The session ID to validate.
360 Returns:
361 True if the session ID is valid, False otherwise.
363 Examples:
364 >>> SSETransport._is_valid_session_id("abc123")
365 True
366 >>> SSETransport._is_valid_session_id("abc-123_def")
367 True
368 >>> SSETransport._is_valid_session_id("a" * 128)
369 True
370 >>> SSETransport._is_valid_session_id("a" * 129)
371 False
372 >>> SSETransport._is_valid_session_id("")
373 False
374 >>> SSETransport._is_valid_session_id("abc/def")
375 False
376 >>> SSETransport._is_valid_session_id("abc:def")
377 False
378 """
379 # Standard
380 import re # pylint: disable=import-outside-toplevel
382 if not session_id or len(session_id) > 128:
383 return False
385 # Lazy compile pattern
386 if SSETransport._SESSION_ID_PATTERN is None:
387 SSETransport._SESSION_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+$")
389 return bool(SSETransport._SESSION_ID_PATTERN.match(session_id))
391 def __init__(self, base_url: str = None):
392 """Initialize SSE transport.
394 Args:
395 base_url: Base URL for client message endpoints
397 Examples:
398 >>> # Test default initialization
399 >>> transport = SSETransport()
400 >>> transport._connected
401 False
402 >>> transport._message_queue is not None
403 True
404 >>> transport._client_gone is not None
405 True
406 >>> len(transport._session_id) > 0
407 True
409 >>> # Test custom base URL
410 >>> transport = SSETransport("https://api.example.com")
411 >>> transport._base_url
412 'https://api.example.com'
414 >>> # Test session ID uniqueness
415 >>> transport1 = SSETransport()
416 >>> transport2 = SSETransport()
417 >>> transport1.session_id != transport2.session_id
418 True
419 """
420 self._base_url = base_url or f"http://{settings.host}:{settings.port}"
421 self._connected = False
422 self._message_queue = asyncio.Queue()
423 self._client_gone = asyncio.Event()
424 # Server generates session_id for SSE - client receives it in endpoint event
425 self._session_id = str(uuid.uuid4())
427 logger.info("Creating SSE transport with base_url=%s, session_id=%s", self._base_url, self._session_id)
429 async def connect(self) -> None:
430 """Set up SSE connection.
432 Examples:
433 >>> # Test connection setup
434 >>> transport = SSETransport()
435 >>> import asyncio
436 >>> asyncio.run(transport.connect())
437 >>> transport._connected
438 True
439 >>> asyncio.run(transport.is_connected())
440 True
441 """
442 self._connected = True
443 logger.info("SSE transport connected: %s", self._session_id)
445 async def disconnect(self) -> None:
446 """Clean up SSE connection.
448 Examples:
449 >>> # Test disconnection
450 >>> transport = SSETransport()
451 >>> import asyncio
452 >>> asyncio.run(transport.connect())
453 >>> asyncio.run(transport.disconnect())
454 >>> transport._connected
455 False
456 >>> transport._client_gone.is_set()
457 True
458 >>> asyncio.run(transport.is_connected())
459 False
461 >>> # Test disconnection when already disconnected
462 >>> transport = SSETransport()
463 >>> asyncio.run(transport.disconnect())
464 >>> transport._connected
465 False
466 """
467 if self._connected:
468 self._connected = False
469 self._client_gone.set()
470 logger.info("SSE transport disconnected: %s", self._session_id)
472 async def send_message(self, message: Dict[str, Any]) -> None:
473 """Send a message over SSE.
475 Args:
476 message: Message to send
478 Raises:
479 RuntimeError: If transport is not connected
480 Exception: If unable to put message to queue
482 Examples:
483 >>> # Test sending message when connected
484 >>> transport = SSETransport()
485 >>> import asyncio
486 >>> asyncio.run(transport.connect())
487 >>> message = {"jsonrpc": "2.0", "method": "test", "id": 1}
488 >>> asyncio.run(transport.send_message(message))
489 >>> transport._message_queue.qsize()
490 1
492 >>> # Test sending message when not connected
493 >>> transport = SSETransport()
494 >>> try:
495 ... asyncio.run(transport.send_message({"test": "message"}))
496 ... except RuntimeError as e:
497 ... print("Expected error:", str(e))
498 Expected error: Transport not connected
500 >>> # Test message format validation
501 >>> transport = SSETransport()
502 >>> asyncio.run(transport.connect())
503 >>> valid_message = {"jsonrpc": "2.0", "method": "initialize", "params": {}}
504 >>> isinstance(valid_message, dict)
505 True
506 >>> "jsonrpc" in valid_message
507 True
509 >>> # Test exception handling in queue put
510 >>> transport = SSETransport()
511 >>> asyncio.run(transport.connect())
512 >>> # Create a full queue to trigger exception
513 >>> transport._message_queue = asyncio.Queue(maxsize=1)
514 >>> asyncio.run(transport._message_queue.put({"dummy": "message"}))
515 >>> # Now queue is full, next put should fail
516 >>> try:
517 ... asyncio.run(asyncio.wait_for(transport.send_message({"test": "message"}), timeout=0.1))
518 ... except asyncio.TimeoutError:
519 ... print("Queue full as expected")
520 Queue full as expected
521 """
522 if not self._connected:
523 raise RuntimeError("Transport not connected")
525 try:
526 await self._message_queue.put(message)
527 logger.debug("Message queued for SSE: %s, method=%s", self._session_id, message.get("method", "(response)"))
528 except Exception as e:
529 logger.error("Failed to queue message: %s", e)
530 raise
532 async def receive_message(self) -> AsyncGenerator[Dict[str, Any], None]:
533 """Receive messages from the client over SSE transport.
535 This method implements a continuous message-receiving pattern for SSE transport.
536 Since SSE is primarily a server-to-client communication channel, this method
537 yields an initial initialize placeholder message and then enters a waiting loop.
538 The actual client messages are received via a separate HTTP POST endpoint
539 (not handled in this method).
541 The method will continue running until either:
542 1. The connection is explicitly disconnected (client_gone event is set)
543 2. The receive loop is cancelled from outside
545 Yields:
546 Dict[str, Any]: JSON-RPC formatted messages. The first yielded message is always
547 an initialize placeholder with the format:
548 {"jsonrpc": "2.0", "method": "initialize", "id": 1}
550 Raises:
551 RuntimeError: If the transport is not connected when this method is called
552 asyncio.CancelledError: When the SSE receive loop is cancelled externally
554 Examples:
555 >>> # Test receive message when connected
556 >>> transport = SSETransport()
557 >>> import asyncio
558 >>> asyncio.run(transport.connect())
559 >>> async def test_receive():
560 ... async for msg in transport.receive_message():
561 ... return msg
562 ... return None
563 >>> result = asyncio.run(test_receive())
564 >>> result
565 {'jsonrpc': '2.0', 'method': 'initialize', 'id': 1}
567 >>> # Test receive message when not connected
568 >>> transport = SSETransport()
569 >>> try:
570 ... async def test_receive():
571 ... async for msg in transport.receive_message():
572 ... pass
573 ... asyncio.run(test_receive())
574 ... except RuntimeError as e:
575 ... print("Expected error:", str(e))
576 Expected error: Transport not connected
578 >>> # Verify generator behavior
579 >>> transport = SSETransport()
580 >>> import inspect
581 >>> inspect.isasyncgenfunction(transport.receive_message)
582 True
583 """
584 if not self._connected:
585 raise RuntimeError("Transport not connected")
587 # For SSE, we set up a loop to wait for messages which are delivered via POST
588 # Most messages come via the POST endpoint, but we yield an initial initialize placeholder
589 # to keep the receive loop running
590 yield {"jsonrpc": "2.0", "method": "initialize", "id": 1}
592 # Continue waiting for cancellation
593 try:
594 while not self._client_gone.is_set():
595 await asyncio.sleep(1.0)
596 except asyncio.CancelledError:
597 logger.info("SSE receive loop cancelled for session %s", self._session_id)
598 raise
599 finally:
600 logger.info("SSE receive loop ended for session %s", self._session_id)
602 async def _get_message_with_timeout(self, timeout: Optional[float]) -> Optional[Dict[str, Any]]:
603 """Get message from queue with timeout, returns None on timeout.
605 Uses asyncio.wait() to avoid TimeoutError exception overhead.
607 Args:
608 timeout: Timeout in seconds, or None for no timeout
610 Returns:
611 Message dict if received, None if timeout occurred
613 Raises:
614 asyncio.CancelledError: If the operation is cancelled externally
615 """
616 if timeout is None:
617 return await self._message_queue.get()
619 get_task = asyncio.create_task(self._message_queue.get())
620 try:
621 done, _ = await asyncio.wait({get_task}, timeout=timeout)
622 except asyncio.CancelledError:
623 get_task.cancel()
624 try:
625 await get_task
626 except asyncio.CancelledError:
627 pass
628 raise
630 if get_task in done:
631 return get_task.result()
633 # Timeout - cancel pending task, but return the result if it completed in the race window.
634 get_task.cancel()
635 try:
636 return await get_task
637 except asyncio.CancelledError:
638 return None
640 async def is_connected(self) -> bool:
641 """Check if transport is connected.
643 Returns:
644 True if connected
646 Examples:
647 >>> # Test initial state
648 >>> transport = SSETransport()
649 >>> import asyncio
650 >>> asyncio.run(transport.is_connected())
651 False
653 >>> # Test after connection
654 >>> transport = SSETransport()
655 >>> asyncio.run(transport.connect())
656 >>> asyncio.run(transport.is_connected())
657 True
659 >>> # Test after disconnection
660 >>> transport = SSETransport()
661 >>> asyncio.run(transport.connect())
662 >>> asyncio.run(transport.disconnect())
663 >>> asyncio.run(transport.is_connected())
664 False
665 """
666 return self._connected
    async def create_sse_response(
        self,
        request: Request,
        on_disconnect_callback: Callable[[], Awaitable[None]] | None = None,
    ) -> EventSourceResponse:
        """Create SSE response for streaming.

        Args:
            request: FastAPI request (used for disconnection detection)
            on_disconnect_callback: Optional async callback to run when client disconnects.
                Used for defensive cleanup (e.g., cancelling respond tasks).

        Returns:
            SSE response object

        Examples:
            >>> # Test SSE response creation
            >>> transport = SSETransport("http://localhost:8000")
            >>> # Note: This method requires a FastAPI Request object
            >>> # and cannot be easily tested in doctest environment
            >>> callable(transport.create_sse_response)
            True
        """
        # Clients POST their messages back to this URL, keyed by our session ID.
        endpoint_url = f"{self._base_url}/message?session_id={self._session_id}"

        async def event_generator():
            """Generate SSE events.

            Yields:
                SSE event as bytes (pre-formatted SSE frame)

            Raises:
                asyncio.CancelledError: If the generator is cancelled.
            """
            # Send the endpoint event first
            yield _build_sse_frame(b"endpoint", endpoint_url.encode(), settings.sse_retry_timeout)

            # Send keepalive immediately to help establish connection (if enabled)
            if settings.sse_keepalive_enabled:
                yield _build_sse_frame(b"keepalive", b"{}", settings.sse_retry_timeout)

            consecutive_errors = 0
            max_consecutive_errors = 3  # Exit after 3 consecutive errors (likely client disconnected)

            # Rapid yield detection: If we're yielding faster than expected, client is likely disconnected
            # but ASGI server isn't properly signaling it. Track yield timestamps in a sliding window.
            # yield_timestamps is None when the feature is disabled (sse_rapid_yield_max <= 0).
            yield_timestamps: deque = deque(maxlen=settings.sse_rapid_yield_max + 1) if settings.sse_rapid_yield_max > 0 else None
            rapid_yield_window_sec = settings.sse_rapid_yield_window_ms / 1000.0
            last_yield_time = time.monotonic()
            consecutive_rapid_yields = 0  # Track consecutive fast yields for simpler detection

            def check_rapid_yield() -> bool:
                """Check if yields are happening too fast.

                Returns:
                    True if spin loop detected and should disconnect, False otherwise.
                """
                nonlocal last_yield_time, consecutive_rapid_yields
                now = time.monotonic()
                time_since_last = now - last_yield_time
                last_yield_time = now

                # Track consecutive rapid yields (< 100ms apart)
                # This catches spin loops even without full deque analysis
                if time_since_last < 0.1:
                    consecutive_rapid_yields += 1
                    if consecutive_rapid_yields >= 10:  # 10 consecutive fast yields = definite spin
                        logger.error("SSE spin loop detected (%d consecutive rapid yields, last interval %.3fs), client disconnected: %s", consecutive_rapid_yields, time_since_last, self._session_id)
                        return True
                else:
                    consecutive_rapid_yields = 0  # Reset on normal-speed yield

                # Also use deque-based detection for more nuanced analysis
                if yield_timestamps is None:
                    return False
                yield_timestamps.append(now)
                if time_since_last < 0.01:  # Less than 10ms between yields is very fast
                    if len(yield_timestamps) > settings.sse_rapid_yield_max:
                        oldest = yield_timestamps[0]
                        elapsed = now - oldest
                        if elapsed < rapid_yield_window_sec:
                            logger.error(
                                "SSE rapid yield detected (%d yields in %.3fs, last interval %.3fs), client disconnected: %s", len(yield_timestamps), elapsed, time_since_last, self._session_id
                            )
                            return True
                return False

            try:
                while not self._client_gone.is_set():
                    # Check if client has disconnected via request state
                    if await request.is_disconnected():
                        logger.info("SSE client disconnected (detected via request): %s", self._session_id)
                        self._client_gone.set()
                        break

                    try:
                        # Use timeout-based polling only when keepalive is enabled
                        if not settings.sse_keepalive_enabled:
                            message = await self._message_queue.get()
                        else:
                            message = await self._get_message_with_timeout(settings.sse_keepalive_interval)

                        if message is not None:
                            json_bytes = orjson.dumps(message, option=orjson.OPT_SERIALIZE_NUMPY)

                            # Guard the decode() behind the level check - it's only for logging.
                            if logger.isEnabledFor(logging.DEBUG):
                                logger.debug("Sending SSE message: %s", json_bytes.decode())

                            yield _build_sse_frame(b"message", json_bytes, settings.sse_retry_timeout)
                            consecutive_errors = 0  # Reset on successful send

                            # Check for rapid yields after message send
                            if check_rapid_yield():
                                self._client_gone.set()
                                break
                        elif settings.sse_keepalive_enabled:
                            # Timeout - send keepalive
                            yield _build_sse_frame(b"keepalive", b"{}", settings.sse_retry_timeout)
                            consecutive_errors = 0  # Reset on successful send
                            # Check for rapid yields after keepalive too
                            if check_rapid_yield():
                                self._client_gone.set()
                                break
                        # Note: We don't clear yield_timestamps here. The deque has maxlen
                        # which automatically drops old entries. Clearing would prevent
                        # detection of spin loops where yields happen faster than expected.

                    except GeneratorExit:
                        # Client disconnected - generator is being closed
                        logger.info("SSE generator exit (client disconnected): %s", self._session_id)
                        self._client_gone.set()
                        break
                    except Exception as e:
                        consecutive_errors += 1
                        logger.warning("Error processing SSE message (attempt %d/%d): %s", consecutive_errors, max_consecutive_errors, e)
                        if consecutive_errors >= max_consecutive_errors:
                            logger.info("SSE too many consecutive errors, assuming client disconnected: %s", self._session_id)
                            self._client_gone.set()
                            break
                        # Don't yield error frame - it could cause more errors if client is gone

            except asyncio.CancelledError:
                logger.info("SSE event generator cancelled: %s", self._session_id)
                self._client_gone.set()
                raise
            except GeneratorExit:
                logger.info("SSE generator exit: %s", self._session_id)
                self._client_gone.set()
            except Exception as e:
                logger.error("SSE event generator error: %s", e)
                self._client_gone.set()
            finally:
                logger.info("SSE event generator completed: %s", self._session_id)
                self._client_gone.set()  # Always set client_gone on exit to clean up
                # CRITICAL: Also invoke disconnect callback on generator end (Finding 3)
                # This covers server-initiated close, errors, and cancellation - not just client close
                if on_disconnect_callback:
                    try:
                        await on_disconnect_callback()
                    except Exception as e:
                        logger.warning("Disconnect callback in finally failed for %s: %s", self._session_id, e)

        async def on_client_close(_scope: dict) -> None:
            """Handle client close event from sse_starlette.

            Args:
                _scope: ASGI scope dict supplied by sse_starlette (unused).
            """
            logger.info("SSE client close handler called: %s", self._session_id)
            self._client_gone.set()

            # Defensive cleanup via callback (if provided)
            if on_disconnect_callback:
                try:
                    await on_disconnect_callback()
                except Exception as e:
                    logger.warning("Disconnect callback failed for %s: %s", self._session_id, e)

        return EventSourceResponse(
            event_generator(),
            status_code=200,
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Content-Type": "text/event-stream",
                "X-MCP-SSE": "true",
            },
            # Timeout for ASGI send() calls - protects against sends that hang indefinitely
            # when client connection is in a bad state (e.g., client stopped reading but TCP
            # connection not yet closed). Does NOT affect MCP server response times.
            # Set to 0 to disable. Default matches keepalive interval.
            send_timeout=settings.sse_send_timeout if settings.sse_send_timeout > 0 else None,
            # Callback when client closes - helps detect disconnects that ASGI server
            # may not properly propagate via request.is_disconnected()
            client_close_handler_callable=on_client_close,
        )
861 async def _client_disconnected(self, _request: Request) -> bool:
862 """Check if client has disconnected.
864 Args:
865 _request: FastAPI Request object
867 Returns:
868 bool: True if client disconnected
870 Examples:
871 >>> # Test client disconnected check
872 >>> transport = SSETransport()
873 >>> import asyncio
874 >>> asyncio.run(transport._client_disconnected(None))
875 False
877 >>> # Test after setting client gone
878 >>> transport = SSETransport()
879 >>> transport._client_gone.set()
880 >>> asyncio.run(transport._client_disconnected(None))
881 True
882 """
883 # We only check our internal client_gone flag
884 # We intentionally don't check connection_lost on the request
885 # as it can be unreliable and cause premature closures
886 return self._client_gone.is_set()
888 @property
889 def session_id(self) -> str:
890 """
891 Get the session ID for this transport.
893 Returns:
894 str: session_id
896 Examples:
897 >>> # Test session ID property
898 >>> transport = SSETransport()
899 >>> session_id = transport.session_id
900 >>> isinstance(session_id, str)
901 True
902 >>> len(session_id) > 0
903 True
904 >>> session_id == transport._session_id
905 True
907 >>> # Test session ID uniqueness
908 >>> transport1 = SSETransport()
909 >>> transport2 = SSETransport()
910 >>> transport1.session_id != transport2.session_id
911 True
912 """
913 return self._session_id