Coverage for mcpgateway / transports / sse_transport.py: 100%
255 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/transports/sse_transport.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7SSE Transport Implementation.
8This module implements Server-Sent Events (SSE) transport for MCP,
9providing server-to-client streaming with proper session management.
10"""
12# Standard
13import asyncio
14from collections import deque
15import logging
16import time
17from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, Optional
18import uuid
20# Third-Party
21import anyio
22from anyio._backends._asyncio import CancelScope
23from fastapi import Request
24import orjson
25from sse_starlette.sse import EventSourceResponse as BaseEventSourceResponse
26from starlette.types import Receive, Scope, Send
28# First-Party
29from mcpgateway.config import settings
30from mcpgateway.services.logging_service import LoggingService
31from mcpgateway.transports.base import Transport
32from mcpgateway.utils.trace_context import set_trace_session_id
# Initialize logging service first: the module-level ``logger`` below is used
# by the anyio patch helpers and the SSE classes defined later in this module.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)


# =============================================================================
# EXPERIMENTAL WORKAROUND: anyio _deliver_cancellation spin loop (anyio#695)
# =============================================================================
# anyio's _deliver_cancellation can spin at 100% CPU when tasks don't respond
# to CancelledError. This optional monkey-patch adds a max iteration limit to
# prevent indefinite spinning. After the limit is reached, we give up delivering
# cancellation and let the scope exit (tasks will be orphaned but won't spin).
#
# This workaround is DISABLED by default. Enable via:
# ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=true
#
# Trade-offs:
# - Prevents indefinite CPU spin (good)
# - May leave some tasks uncancelled after max iterations (usually harmless)
# - Worker recycling (GUNICORN_MAX_REQUESTS) cleans up orphaned tasks
#
# This workaround may be removed when anyio or MCP SDK fix the underlying issue.
# See: https://github.com/agronholm/anyio/issues/695
# =============================================================================

# Keep a reference to anyio's original method. The patched version delegates to
# it, and remove_anyio_cancel_delivery_patch() puts it back on CancelScope.
_original_deliver_cancellation = CancelScope._deliver_cancellation  # type: ignore[attr-defined]  # pylint: disable=protected-access
# True while the patched method is installed on CancelScope; read and written by
# apply_anyio_cancel_delivery_patch() / remove_anyio_cancel_delivery_patch().
_patch_applied = False  # pylint: disable=invalid-name
def _create_patched_deliver_cancellation(max_iterations: int):  # noqa: C901
    """Create a patched _deliver_cancellation with configurable max iterations.

    Only top-level delivery rounds count toward the limit, i.e. calls where the
    scope being processed *is* the origin scope. Calls made for child scopes
    (``self is not origin``) are passed straight through to the original.
    anyio 4.x recurses into child scopes through ``_deliver_cancellation``
    itself (verify for the installed version), which routes those calls
    through this patch. Counting them would exhaust ``max_iterations`` after
    fewer rounds the more nested scopes a task group has.

    Args:
        max_iterations: Maximum delivery rounds before giving up cancellation delivery.

    Returns:
        Patched function that limits cancellation delivery iterations.
    """

    def _patched_deliver_cancellation(self: CancelScope, origin: CancelScope) -> bool:  # pylint: disable=protected-access
        """Patched _deliver_cancellation with max iteration limit to prevent spin.

        This wraps anyio's original _deliver_cancellation to track how many
        delivery rounds have run for ``origin``. It gives up after
        ``max_iterations`` rounds, preventing the CPU spin loop that occurs
        when tasks don't respond to CancelledError.

        Args:
            self: The cancel scope being processed.
            origin: The cancel scope that originated the cancellation.

        Returns:
            True if delivery should be retried, False if done or max iterations reached.
        """
        # Recursive call for a child scope: part of the current round, not a new one.
        if self is not origin:
            return _original_deliver_cancellation(self, origin)

        # Track the round count on the origin scope (the one that initiated cancel).
        iterations = getattr(origin, "_delivery_iterations", 0) + 1
        origin._delivery_iterations = iterations  # type: ignore[attr-defined]  # pylint: disable=protected-access

        if iterations > max_iterations:
            # Log warning and give up - this prevents indefinite spin
            logger.warning(
                "anyio cancel delivery exceeded %d iterations - giving up to prevent CPU spin. "
                "Some tasks may not have been properly cancelled. "
                "Disable with ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=false if this causes issues.",
                max_iterations,
            )
            # Clear the cancel handle so no stale retry handle is left on the scope;
            # returning without calling the original schedules no further retry.
            if getattr(self, "_cancel_handle", None) is not None:
                self._cancel_handle = None  # pylint: disable=protected-access
            return False  # Don't retry

        # Within budget: run the original delivery for this round.
        return _original_deliver_cancellation(self, origin)

    return _patched_deliver_cancellation
def apply_anyio_cancel_delivery_patch() -> bool:
    """Install the anyio ``_deliver_cancellation`` monkey-patch when config enables it.

    Safe to call repeatedly: once the patch is installed, later calls return
    immediately without touching ``CancelScope`` again.

    Returns:
        True if patch was applied (or already applied), False if disabled.
    """
    global _patch_applied  # pylint: disable=global-statement

    if _patch_applied:
        return True

    try:
        if settings.anyio_cancel_delivery_patch_enabled:
            limit = settings.anyio_cancel_delivery_max_iterations
            CancelScope._deliver_cancellation = _create_patched_deliver_cancellation(limit)  # type: ignore[method-assign]  # pylint: disable=protected-access
            _patch_applied = True
            logger.info(
                "anyio _deliver_cancellation patch ENABLED (max_iterations=%d). "
                "This is an experimental workaround for anyio#695. "
                "Disable with ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=false if it causes issues.",
                limit,
            )
            return True

        # Disabled by configuration: leave anyio untouched.
        logger.debug("anyio _deliver_cancellation patch DISABLED. Enable with ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=true if you experience CPU spin loops.")
        return False
    except Exception as exc:
        # Best effort: a failure to patch must never break module import.
        logger.warning("Failed to apply anyio _deliver_cancellation patch: %s", exc)
        return False
def remove_anyio_cancel_delivery_patch() -> bool:
    """Undo the anyio ``_deliver_cancellation`` monkey-patch.

    Puts anyio's original implementation back on ``CancelScope``.

    Returns:
        True if patch was removed, False if it wasn't applied.
    """
    global _patch_applied  # pylint: disable=global-statement

    if not _patch_applied:
        return False

    try:
        CancelScope._deliver_cancellation = _original_deliver_cancellation  # type: ignore[method-assign]  # pylint: disable=protected-access
        _patch_applied = False
        logger.info("anyio _deliver_cancellation patch removed - restored original implementation")
    except Exception as exc:
        logger.warning("Failed to remove anyio _deliver_cancellation patch: %s", exc)
        return False
    return True
# Apply patch at module load time if enabled. The call returns False and leaves
# anyio untouched when ANYIO_CANCEL_DELIVERY_PATCH_ENABLED is false.
apply_anyio_cancel_delivery_patch()
def _get_sse_cleanup_timeout() -> float:
    """Return the SSE task-group cleanup timeout (seconds) from settings.

    The value bounds how long SSE task-group tasks get to react to
    cancellation before cleanup is forced. It guards against the anyio
    ``_deliver_cancellation`` CPU spin seen when tasks ignore CancelledError.

    Returns:
        Cleanup timeout in seconds (default: 5.0)
    """
    fallback_seconds = 5.0  # used when the setting cannot be read
    try:
        timeout = settings.sse_task_group_cleanup_timeout
    except Exception:
        timeout = fallback_seconds
    return timeout
class EventSourceResponse(BaseEventSourceResponse):
    """Patched EventSourceResponse with CPU spin detection.

    This mitigates a CPU spin loop issue (anyio#695) where _deliver_cancellation
    spins at 100% CPU when tasks in the SSE task group don't respond to
    cancellation.

    Instead of putting a timeout on the whole task group (which would affect
    normal long-lived SSE connections), this copies the parent ``__call__``.
    It sets a deadline on the task group's cancel scope only at the moment
    one of the wrapped tasks finishes and the group is cancelled.

    See:
        - https://github.com/agronholm/anyio/issues/695
        - https://github.com/anthropics/claude-agent-sdk-python/issues/378
    """

    def enable_compression(self, force: bool = False) -> None:  # noqa: ARG002
        """Enable compression (no-op for SSE streams).

        SSE streams don't support compression as per sse_starlette.
        This override prevents NotImplementedError from parent class.

        Args:
            force: Ignored - compression not supported for SSE.
        """

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Handle SSE request with cancel scope deadline to prevent spin.

        This method is copied from sse_starlette with one key modification:
        the task group's cancel_scope gets a deadline set when cancellation
        starts, intended to bound cleanup if tasks don't respond.

        Args:
            scope: ASGI scope dictionary.
            receive: ASGI receive callable.
            send: ASGI send callable.
        """
        # Copy of sse_starlette.sse.EventSourceResponse.__call__ with deadline fix
        async with anyio.create_task_group() as task_group:
            # No deadline is set up front, so normal long-lived streams are
            # unaffected. A bounded deadline is applied only when a wrapped task
            # finishes and the whole group is cancelled (see cancel_on_finish).

            async def cancel_on_finish(coro: Callable[[], Awaitable[None]]) -> None:
                """Execute coroutine then cancel task group with bounded deadline.

                This wrapper runs the given coroutine. When it completes, the
                wrapper cancels the parent task group with a deadline intended
                to prevent indefinite spinning if other tasks don't respond to
                cancellation (anyio#695 mitigation).

                Args:
                    coro: Async callable to execute before triggering cancellation.
                """
                await coro()
                # NOTE(review): anyio's asyncio backend appears to cancel any pending
                # deadline timer inside CancelScope.cancel(). That would make this
                # deadline a no-op once cancel() runs on the next line, leaving the
                # optional _deliver_cancellation patch as the effective guard.
                # Verify against the installed anyio version before relying on it.
                task_group.cancel_scope.deadline = anyio.current_time() + _get_sse_cleanup_timeout()
                task_group.cancel_scope.cancel()

            task_group.start_soon(cancel_on_finish, lambda: self._stream_response(send))
            task_group.start_soon(cancel_on_finish, lambda: self._ping(send))
            task_group.start_soon(cancel_on_finish, self._listen_for_exit_signal)

            if self.data_sender_callable:
                # Not wrapped: the data sender finishing does not end the stream.
                task_group.start_soon(self.data_sender_callable)

            # Wait for the client to disconnect last
            task_group.start_soon(cancel_on_finish, lambda: self._listen_for_disconnect(receive))

        if self.background is not None:
            await self.background()
# Pre-computed SSE frame components for performance
_SSE_EVENT_PREFIX = b"event: "
_SSE_DATA_PREFIX = b"\r\ndata: "
_SSE_RETRY_PREFIX = b"\r\nretry: "
_SSE_FRAME_END = b"\r\n\r\n"


def _build_sse_frame(event: bytes, data: bytes, retry: int) -> bytes:
    """Assemble one complete SSE frame directly as bytes.

    Working in bytes end to end skips a str encode/decode round trip for
    every event sent.

    Args:
        event: SSE event type as bytes (e.g., b'message', b'keepalive', b'error')
        data: JSON data as bytes (from orjson.dumps)
        retry: Retry timeout in milliseconds

    Returns:
        Complete SSE frame as bytes

    Note:
        Uses hardcoded CRLF (\\r\\n) separators matching sse_starlette's
        ServerSentEvent.DEFAULT_SEPARATOR. If custom separators are ever
        needed, this function would need to accept a sep parameter.

    Examples:
        >>> _build_sse_frame(b"message", b'{"test": 1}', 15000)
        b'event: message\\r\\ndata: {"test": 1}\\r\\nretry: 15000\\r\\n\\r\\n'

        >>> _build_sse_frame(b"keepalive", b"{}", 15000)
        b'event: keepalive\\r\\ndata: {}\\r\\nretry: 15000\\r\\n\\r\\n'
    """
    retry_field = str(retry).encode()
    parts = (_SSE_EVENT_PREFIX, event, _SSE_DATA_PREFIX, data, _SSE_RETRY_PREFIX, retry_field, _SSE_FRAME_END)
    return b"".join(parts)
class SSETransport(Transport):
    """Transport implementation using Server-Sent Events with proper session management.

    This transport implementation uses Server-Sent Events (SSE) for real-time
    communication between the MCP gateway and clients. It provides streaming
    capabilities with automatic session management and keepalive support.

    Examples:
        >>> # Create SSE transport with default URL
        >>> transport = SSETransport()
        >>> transport
        <mcpgateway.transports.sse_transport.SSETransport object at ...>

        >>> # Create SSE transport with custom URL
        >>> transport = SSETransport("http://localhost:8080")
        >>> transport._base_url
        'http://localhost:8080'

        >>> # Check initial connection state
        >>> import asyncio
        >>> asyncio.run(transport.is_connected())
        False

        >>> # Verify it's a proper Transport subclass
        >>> isinstance(transport, Transport)
        True
        >>> issubclass(SSETransport, Transport)
        True

        >>> # Check session ID generation
        >>> transport.session_id
        '...'
        >>> len(transport.session_id) > 0
        True

        >>> # Verify required methods exist
        >>> hasattr(transport, 'connect')
        True
        >>> hasattr(transport, 'disconnect')
        True
        >>> hasattr(transport, 'send_message')
        True
        >>> hasattr(transport, 'receive_message')
        True
        >>> hasattr(transport, 'is_connected')
        True
    """

    # Compiled lazily by _is_valid_session_id: alphanumeric, hyphens, underscores.
    # The 128-char maximum is enforced separately, before matching.
    _SESSION_ID_PATTERN = None

    @staticmethod
    def _is_valid_session_id(session_id: str) -> bool:
        """Validate session ID format for security.

        Valid session IDs are alphanumeric with hyphens and underscores, max 128 chars.
        This prevents injection attacks via malformed session IDs.

        Args:
            session_id: The session ID to validate.

        Returns:
            True if the session ID is valid, False otherwise.

        Examples:
            >>> SSETransport._is_valid_session_id("abc123")
            True
            >>> SSETransport._is_valid_session_id("abc-123_def")
            True
            >>> SSETransport._is_valid_session_id("a" * 128)
            True
            >>> SSETransport._is_valid_session_id("a" * 129)
            False
            >>> SSETransport._is_valid_session_id("")
            False
            >>> SSETransport._is_valid_session_id("abc/def")
            False
            >>> SSETransport._is_valid_session_id("abc:def")
            False
            >>> SSETransport._is_valid_session_id("abc\\n")
            False
        """
        # Standard
        import re  # pylint: disable=import-outside-toplevel

        if not session_id or len(session_id) > 128:
            return False

        # Lazy compile pattern
        if SSETransport._SESSION_ID_PATTERN is None:
            SSETransport._SESSION_ID_PATTERN = re.compile(r"[a-zA-Z0-9_-]+")

        # fullmatch rather than match() with "^...$": "$" also matches just
        # before a trailing newline, so match() would accept "abc\n".
        return SSETransport._SESSION_ID_PATTERN.fullmatch(session_id) is not None

    def __init__(self, base_url: Optional[str] = None):
        """Initialize SSE transport.

        Args:
            base_url: Base URL for client message endpoints. Defaults to
                ``http://{settings.host}:{settings.port}`` when not given.

        Examples:
            >>> # Test default initialization
            >>> transport = SSETransport()
            >>> transport._connected
            False
            >>> transport._message_queue is not None
            True
            >>> transport._client_gone is not None
            True
            >>> len(transport._session_id) > 0
            True

            >>> # Test custom base URL
            >>> transport = SSETransport("https://api.example.com")
            >>> transport._base_url
            'https://api.example.com'

            >>> # Test session ID uniqueness
            >>> transport1 = SSETransport()
            >>> transport2 = SSETransport()
            >>> transport1.session_id != transport2.session_id
            True
        """
        self._base_url = base_url or f"http://{settings.host}:{settings.port}"
        self._connected = False
        # Outgoing messages waiting to be streamed by create_sse_response().
        self._message_queue = asyncio.Queue()
        # Set once the client is gone; stops the receive and event-generator loops.
        self._client_gone = asyncio.Event()
        # Server generates session_id for SSE - client receives it in endpoint event
        self._session_id = str(uuid.uuid4())
        set_trace_session_id(self._session_id)

        logger.info("Creating SSE transport with base_url=%s, session_id=%s", self._base_url, self._session_id)

    async def connect(self) -> None:
        """Set up SSE connection.

        Examples:
            >>> # Test connection setup
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport.connect())
            >>> transport._connected
            True
            >>> asyncio.run(transport.is_connected())
            True
        """
        self._connected = True
        logger.info("SSE transport connected: %s", self._session_id)

    async def disconnect(self) -> None:
        """Clean up SSE connection.

        Marks the transport disconnected and sets the client-gone event. Does
        nothing if the transport is not connected.

        Examples:
            >>> # Test disconnection
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport.connect())
            >>> asyncio.run(transport.disconnect())
            >>> transport._connected
            False
            >>> transport._client_gone.is_set()
            True
            >>> asyncio.run(transport.is_connected())
            False

            >>> # Test disconnection when already disconnected
            >>> transport = SSETransport()
            >>> asyncio.run(transport.disconnect())
            >>> transport._connected
            False
        """
        if self._connected:
            self._connected = False
            self._client_gone.set()
            logger.info("SSE transport disconnected: %s", self._session_id)

    async def send_message(self, message: Dict[str, Any]) -> None:
        """Send a message over SSE.

        The message is put on the internal queue; the event generator built by
        create_sse_response() serializes and streams it.

        Args:
            message: Message to send

        Raises:
            RuntimeError: If transport is not connected
            Exception: If unable to put message to queue

        Examples:
            >>> # Test sending message when connected
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport.connect())
            >>> message = {"jsonrpc": "2.0", "method": "test", "id": 1}
            >>> asyncio.run(transport.send_message(message))
            >>> transport._message_queue.qsize()
            1

            >>> # Test sending message when not connected
            >>> transport = SSETransport()
            >>> try:
            ...     asyncio.run(transport.send_message({"test": "message"}))
            ... except RuntimeError as e:
            ...     print("Expected error:", str(e))
            Expected error: Transport not connected

            >>> # Test message format validation
            >>> transport = SSETransport()
            >>> asyncio.run(transport.connect())
            >>> valid_message = {"jsonrpc": "2.0", "method": "initialize", "params": {}}
            >>> isinstance(valid_message, dict)
            True
            >>> "jsonrpc" in valid_message
            True

            >>> # Test exception handling in queue put
            >>> transport = SSETransport()
            >>> asyncio.run(transport.connect())
            >>> # Create a full queue to trigger exception
            >>> transport._message_queue = asyncio.Queue(maxsize=1)
            >>> asyncio.run(transport._message_queue.put({"dummy": "message"}))
            >>> # Now queue is full, next put should fail
            >>> try:
            ...     asyncio.run(asyncio.wait_for(transport.send_message({"test": "message"}), timeout=0.1))
            ... except asyncio.TimeoutError:
            ...     print("Queue full as expected")
            Queue full as expected
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        try:
            await self._message_queue.put(message)
            logger.debug("Message queued for SSE: %s, method=%s", self._session_id, message.get("method", "(response)"))
        except Exception as e:
            logger.error("Failed to queue message: %s", e)
            raise

    async def receive_message(self) -> AsyncGenerator[Dict[str, Any], None]:
        """Receive messages from the client over SSE transport.

        This method implements a continuous message-receiving pattern for SSE transport.
        Since SSE is primarily a server-to-client communication channel, this method
        yields an initial initialize placeholder message and then enters a waiting loop.
        The actual client messages are received via a separate HTTP POST endpoint
        (not handled in this method).

        The method will continue running until either:
        1. The connection is explicitly disconnected (client_gone event is set)
        2. The receive loop is cancelled from outside

        Yields:
            Dict[str, Any]: JSON-RPC formatted messages. The first yielded message is always
                an initialize placeholder with the format:
                {"jsonrpc": "2.0", "method": "initialize", "id": 1}

        Raises:
            RuntimeError: If the transport is not connected when this method is called
            asyncio.CancelledError: When the SSE receive loop is cancelled externally

        Examples:
            >>> # Test receive message when connected
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport.connect())
            >>> async def test_receive():
            ...     async for msg in transport.receive_message():
            ...         return msg
            ...     return None
            >>> result = asyncio.run(test_receive())
            >>> result
            {'jsonrpc': '2.0', 'method': 'initialize', 'id': 1}

            >>> # Test receive message when not connected
            >>> transport = SSETransport()
            >>> try:
            ...     async def test_receive():
            ...         async for msg in transport.receive_message():
            ...             pass
            ...     asyncio.run(test_receive())
            ... except RuntimeError as e:
            ...     print("Expected error:", str(e))
            Expected error: Transport not connected

            >>> # Verify generator behavior
            >>> transport = SSETransport()
            >>> import inspect
            >>> inspect.isasyncgenfunction(transport.receive_message)
            True
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        # For SSE, we set up a loop to wait for messages which are delivered via POST
        # Most messages come via the POST endpoint, but we yield an initial initialize placeholder
        # to keep the receive loop running
        yield {"jsonrpc": "2.0", "method": "initialize", "id": 1}

        # Continue waiting for cancellation (polls the client-gone flag once per second)
        try:
            while not self._client_gone.is_set():
                await asyncio.sleep(1.0)
        except asyncio.CancelledError:
            logger.info("SSE receive loop cancelled for session %s", self._session_id)
            raise
        finally:
            logger.info("SSE receive loop ended for session %s", self._session_id)

    async def _get_message_with_timeout(self, timeout: Optional[float]) -> Optional[Dict[str, Any]]:
        """Get message from queue with timeout, returns None on timeout.

        Uses asyncio.wait() to avoid TimeoutError exception overhead.

        Args:
            timeout: Timeout in seconds, or None for no timeout

        Returns:
            Message dict if received, None if timeout occurred

        Raises:
            asyncio.CancelledError: If the operation is cancelled externally
        """
        if timeout is None:
            return await self._message_queue.get()

        get_task = asyncio.create_task(self._message_queue.get())
        try:
            done, _ = await asyncio.wait({get_task}, timeout=timeout)
        except asyncio.CancelledError:
            # Don't leave an orphaned queue.get() task behind before propagating.
            get_task.cancel()
            try:
                await get_task
            except asyncio.CancelledError:
                pass
            raise

        if get_task in done:
            return get_task.result()

        # Timeout - cancel pending task, but return the result if it completed in the race window.
        get_task.cancel()
        try:
            return await get_task
        except asyncio.CancelledError:
            # NOTE(review): if the *calling* task is cancelled while awaiting the
            # cancelled get_task, that CancelledError is also swallowed here and
            # None is returned. Consider distinguishing the two cases with
            # Task.cancelling() (Python 3.11+).
            return None

    async def is_connected(self) -> bool:
        """Check if transport is connected.

        Returns:
            True if connected

        Examples:
            >>> # Test initial state
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport.is_connected())
            False

            >>> # Test after connection
            >>> transport = SSETransport()
            >>> asyncio.run(transport.connect())
            >>> asyncio.run(transport.is_connected())
            True

            >>> # Test after disconnection
            >>> transport = SSETransport()
            >>> asyncio.run(transport.connect())
            >>> asyncio.run(transport.disconnect())
            >>> asyncio.run(transport.is_connected())
            False
        """
        return self._connected

    async def create_sse_response(
        self,
        request: Request,
        on_disconnect_callback: Callable[[], Awaitable[None]] | None = None,
    ) -> EventSourceResponse:
        """Create SSE response for streaming.

        The stream first sends an ``endpoint`` event carrying the message URL
        for this session. It then sends an optional initial keepalive, and after
        that one ``message`` event per queued message, with ``keepalive`` events
        on idle timeouts when keepalive is enabled.

        Args:
            request: FastAPI request (used for disconnection detection)
            on_disconnect_callback: Optional async callback to run when client disconnects.
                Used for defensive cleanup (e.g., cancelling respond tasks).
                NOTE(review): it may be invoked twice per connection, once from
                the client-close handler and once from the generator's
                ``finally``. Presumably it must be idempotent; confirm with callers.

        Returns:
            SSE response object

        Examples:
            >>> # Test SSE response creation
            >>> transport = SSETransport("http://localhost:8000")
            >>> # Note: This method requires a FastAPI Request object
            >>> # and cannot be easily tested in doctest environment
            >>> callable(transport.create_sse_response)
            True
        """
        endpoint_url = f"{self._base_url}/message?session_id={self._session_id}"

        async def event_generator():
            """Generate SSE events.

            Yields:
                SSE event as bytes (pre-formatted SSE frame)

            Raises:
                asyncio.CancelledError: If the generator is cancelled.
            """
            # Send the endpoint event first
            yield _build_sse_frame(b"endpoint", endpoint_url.encode(), settings.sse_retry_timeout)

            # Send keepalive immediately to help establish connection (if enabled)
            if settings.sse_keepalive_enabled:
                yield _build_sse_frame(b"keepalive", b"{}", settings.sse_retry_timeout)

            consecutive_errors = 0
            max_consecutive_errors = 3  # Exit after 3 consecutive errors (likely client disconnected)

            # Rapid yield detection: If we're yielding faster than expected, client is likely disconnected
            # but ASGI server isn't properly signaling it. Track yield timestamps in a sliding window.
            # None when sse_rapid_yield_max <= 0 (deque-based detection disabled).
            yield_timestamps: Optional[deque] = deque(maxlen=settings.sse_rapid_yield_max + 1) if settings.sse_rapid_yield_max > 0 else None
            rapid_yield_window_sec = settings.sse_rapid_yield_window_ms / 1000.0
            last_yield_time = time.monotonic()
            consecutive_rapid_yields = 0  # Track consecutive fast yields for simpler detection

            def check_rapid_yield() -> bool:
                """Check if yields are happening too fast.

                Returns:
                    True if spin loop detected and should disconnect, False otherwise.
                """
                nonlocal last_yield_time, consecutive_rapid_yields
                now = time.monotonic()
                time_since_last = now - last_yield_time
                last_yield_time = now

                # Track consecutive rapid yields (< 100ms apart)
                # This catches spin loops even without full deque analysis.
                # NOTE(review): this check runs even when sse_rapid_yield_max is 0,
                # and a burst of 10+ queued messages written back-to-back could
                # trip it on a healthy connection. Confirm this is intended.
                if time_since_last < 0.1:
                    consecutive_rapid_yields += 1
                    if consecutive_rapid_yields >= 10:  # 10 consecutive fast yields = definite spin
                        logger.error("SSE spin loop detected (%d consecutive rapid yields, last interval %.3fs), client disconnected: %s", consecutive_rapid_yields, time_since_last, self._session_id)
                        return True
                else:
                    consecutive_rapid_yields = 0  # Reset on normal-speed yield

                # Also use deque-based detection for more nuanced analysis
                if yield_timestamps is None:
                    return False
                yield_timestamps.append(now)
                if time_since_last < 0.01:  # Less than 10ms between yields is very fast
                    if len(yield_timestamps) > settings.sse_rapid_yield_max:
                        oldest = yield_timestamps[0]
                        elapsed = now - oldest
                        if elapsed < rapid_yield_window_sec:
                            logger.error(
                                "SSE rapid yield detected (%d yields in %.3fs, last interval %.3fs), client disconnected: %s", len(yield_timestamps), elapsed, time_since_last, self._session_id
                            )
                            return True
                return False

            try:
                while not self._client_gone.is_set():
                    # Check if client has disconnected via request state
                    if await request.is_disconnected():
                        logger.info("SSE client disconnected (detected via request): %s", self._session_id)
                        self._client_gone.set()
                        break

                    try:
                        # Use timeout-based polling only when keepalive is enabled
                        if not settings.sse_keepalive_enabled:
                            message = await self._message_queue.get()
                        else:
                            message = await self._get_message_with_timeout(settings.sse_keepalive_interval)

                        if message is not None:
                            json_bytes = orjson.dumps(message, option=orjson.OPT_SERIALIZE_NUMPY)

                            if logger.isEnabledFor(logging.DEBUG):
                                logger.debug("Sending SSE message: %s", json_bytes.decode())

                            yield _build_sse_frame(b"message", json_bytes, settings.sse_retry_timeout)
                            consecutive_errors = 0  # Reset on successful send

                            # Check for rapid yields after message send
                            if check_rapid_yield():
                                self._client_gone.set()
                                break
                        elif settings.sse_keepalive_enabled:
                            # Timeout - send keepalive
                            yield _build_sse_frame(b"keepalive", b"{}", settings.sse_retry_timeout)
                            consecutive_errors = 0  # Reset on successful send
                            # Check for rapid yields after keepalive too
                            if check_rapid_yield():
                                self._client_gone.set()
                                break
                            # Note: We don't clear yield_timestamps here. The deque has maxlen
                            # which automatically drops old entries. Clearing would prevent
                            # detection of spin loops where yields happen faster than expected.

                    except GeneratorExit:
                        # Client disconnected - generator is being closed
                        logger.info("SSE generator exit (client disconnected): %s", self._session_id)
                        self._client_gone.set()
                        break
                    except Exception as e:
                        consecutive_errors += 1
                        logger.warning("Error processing SSE message (attempt %d/%d): %s", consecutive_errors, max_consecutive_errors, e)
                        if consecutive_errors >= max_consecutive_errors:
                            logger.info("SSE too many consecutive errors, assuming client disconnected: %s", self._session_id)
                            self._client_gone.set()
                            break
                        # Don't yield error frame - it could cause more errors if client is gone

            except asyncio.CancelledError:
                logger.info("SSE event generator cancelled: %s", self._session_id)
                self._client_gone.set()
                raise
            except GeneratorExit:
                logger.info("SSE generator exit: %s", self._session_id)
                self._client_gone.set()
            except Exception as e:
                logger.error("SSE event generator error: %s", e)
                self._client_gone.set()
            finally:
                logger.info("SSE event generator completed: %s", self._session_id)
                self._client_gone.set()  # Always set client_gone on exit to clean up
                # CRITICAL: Also invoke disconnect callback on generator end (Finding 3)
                # This covers server-initiated close, errors, and cancellation - not just client close
                if on_disconnect_callback:
                    try:
                        await on_disconnect_callback()
                    except Exception as e:
                        logger.warning("Disconnect callback in finally failed for %s: %s", self._session_id, e)

        async def on_client_close(_scope: dict) -> None:
            """Handle client close event from sse_starlette."""
            logger.info("SSE client close handler called: %s", self._session_id)
            self._client_gone.set()

            # Defensive cleanup via callback (if provided)
            if on_disconnect_callback:
                try:
                    await on_disconnect_callback()
                except Exception as e:
                    logger.warning("Disconnect callback failed for %s: %s", self._session_id, e)

        return EventSourceResponse(
            event_generator(),
            status_code=200,
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Content-Type": "text/event-stream",
                "X-MCP-SSE": "true",
            },
            # Timeout for ASGI send() calls - protects against sends that hang indefinitely
            # when client connection is in a bad state (e.g., client stopped reading but TCP
            # connection not yet closed). Does NOT affect MCP server response times.
            # Set to 0 to disable. Default matches keepalive interval.
            send_timeout=settings.sse_send_timeout if settings.sse_send_timeout > 0 else None,
            # Callback when client closes - helps detect disconnects that ASGI server
            # may not properly propagate via request.is_disconnected()
            client_close_handler_callable=on_client_close,
        )

    async def _client_disconnected(self, _request: Request) -> bool:
        """Check if client has disconnected.

        Args:
            _request: FastAPI Request object

        Returns:
            bool: True if client disconnected

        Examples:
            >>> # Test client disconnected check
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport._client_disconnected(None))
            False

            >>> # Test after setting client gone
            >>> transport = SSETransport()
            >>> transport._client_gone.set()
            >>> asyncio.run(transport._client_disconnected(None))
            True
        """
        # We only check our internal client_gone flag
        # We intentionally don't check connection_lost on the request
        # as it can be unreliable and cause premature closures
        return self._client_gone.is_set()

    @property
    def session_id(self) -> str:
        """
        Get the session ID for this transport.

        Returns:
            str: session_id

        Examples:
            >>> # Test session ID property
            >>> transport = SSETransport()
            >>> session_id = transport.session_id
            >>> isinstance(session_id, str)
            True
            >>> len(session_id) > 0
            True
            >>> session_id == transport._session_id
            True

            >>> # Test session ID uniqueness
            >>> transport1 = SSETransport()
            >>> transport2 = SSETransport()
            >>> transport1.session_id != transport2.session_id
            True
        """
        return self._session_id