Coverage for mcpgateway/transports/sse_transport.py: 100%
252 statements
coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
# -*- coding: utf-8 -*-
"""Location: ./mcpgateway/transports/sse_transport.py
Copyright 2025
SPDX-License-Identifier: Apache-2.0
Authors: Mihai Criveti

SSE Transport Implementation.
This module implements Server-Sent Events (SSE) transport for MCP,
providing server-to-client streaming with proper session management.
"""
# Standard
import asyncio
from collections import deque
import logging
import time
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, Optional
import uuid

# Third-Party
import anyio
from anyio._backends._asyncio import CancelScope
from fastapi import Request
import orjson
from sse_starlette.sse import EventSourceResponse as BaseEventSourceResponse
from starlette.types import Receive, Scope, Send

# First-Party
from mcpgateway.config import settings
from mcpgateway.services.logging_service import LoggingService
from mcpgateway.transports.base import Transport

# Initialize logging service first
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)


# =============================================================================
# EXPERIMENTAL WORKAROUND: anyio _deliver_cancellation spin loop (anyio#695)
# =============================================================================
# anyio's _deliver_cancellation can spin at 100% CPU when tasks don't respond
# to CancelledError. This optional monkey-patch adds a max iteration limit to
# prevent indefinite spinning. After the limit is reached, we give up delivering
# cancellation and let the scope exit (tasks will be orphaned but won't spin).
#
# This workaround is DISABLED by default. Enable via:
#   ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=true
#
# Trade-offs:
# - Prevents indefinite CPU spin (good)
# - May leave some tasks uncancelled after max iterations (usually harmless)
# - Worker recycling (GUNICORN_MAX_REQUESTS) cleans up orphaned tasks
#
# This workaround may be removed when anyio or the MCP SDK fix the underlying issue.
# See: https://github.com/agronholm/anyio/issues/695
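#
# Illustrative .env sketch. The iteration-limit variable name is an assumption
# mirroring the anyio_cancel_delivery_max_iterations setting read below, and
# the value shown is hypothetical:
#
#   ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=true
#   ANYIO_CANCEL_DELIVERY_MAX_ITERATIONS=50   # hypothetical value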
# =============================================================================

# Store original for potential restoration and for the patch to call
_original_deliver_cancellation = CancelScope._deliver_cancellation  # type: ignore[attr-defined] # pylint: disable=protected-access
_patch_applied = False  # pylint: disable=invalid-name


def _create_patched_deliver_cancellation(max_iterations: int):  # noqa: C901
    """Create a patched _deliver_cancellation with configurable max iterations.

    Args:
        max_iterations: Maximum iterations before giving up cancellation delivery.

    Returns:
        Patched function that limits cancellation delivery iterations.
    """

    def _patched_deliver_cancellation(self: CancelScope, origin: CancelScope) -> bool:  # pylint: disable=protected-access
        """Patched _deliver_cancellation with max iteration limit to prevent spin.

        This wraps anyio's original _deliver_cancellation to track the iteration count
        and give up after a maximum number of attempts. This prevents the CPU spin
        loop that occurs when tasks don't respond to CancelledError.

        Args:
            self: The cancel scope being processed.
            origin: The cancel scope that originated the cancellation.

        Returns:
            True if delivery should be retried, False if done or max iterations reached.
        """
        # Track iteration count on the origin scope (the one that initiated cancel)
        if not hasattr(origin, "_delivery_iterations"):
            origin._delivery_iterations = 0  # type: ignore[attr-defined] # pylint: disable=protected-access
        origin._delivery_iterations += 1  # type: ignore[attr-defined] # pylint: disable=protected-access

        # Check if we've exceeded the maximum iterations
        if origin._delivery_iterations > max_iterations:  # type: ignore[attr-defined] # pylint: disable=protected-access
            # Log warning and give up - this prevents indefinite spin
            logger.warning(
                "anyio cancel delivery exceeded %d iterations - giving up to prevent CPU spin. "
                "Some tasks may not have been properly cancelled. "
                "Disable with ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=false if this causes issues.",
                max_iterations,
            )
            # Clear the cancel handle to stop further retries
            if hasattr(self, "_cancel_handle") and self._cancel_handle is not None:  # pylint: disable=protected-access
                self._cancel_handle = None  # pylint: disable=protected-access
            return False  # Don't retry

        # Call the original implementation
        return _original_deliver_cancellation(self, origin)

    return _patched_deliver_cancellation


def apply_anyio_cancel_delivery_patch() -> bool:
    """Apply the anyio _deliver_cancellation monkey-patch if enabled in config.

    This function is idempotent - calling it multiple times has no additional effect.

    Returns:
        True if the patch was applied (or already applied), False if disabled.
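
    Examples:
        The result depends on the runtime ANYIO_CANCEL_DELIVERY_PATCH_ENABLED
        setting, so this illustrative check only asserts the return type:

        >>> isinstance(apply_anyio_cancel_delivery_patch(), bool)
        True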
120 """
121 global _patch_applied # pylint: disable=global-statement
123 if _patch_applied:
124 return True
126 try:
127 if not settings.anyio_cancel_delivery_patch_enabled:
128 logger.debug("anyio _deliver_cancellation patch DISABLED. Enable with ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=true if you experience CPU spin loops.")
129 return False
131 max_iterations = settings.anyio_cancel_delivery_max_iterations
132 patched_func = _create_patched_deliver_cancellation(max_iterations)
133 CancelScope._deliver_cancellation = patched_func # type: ignore[method-assign] # pylint: disable=protected-access
134 _patch_applied = True
136 logger.info(
137 "anyio _deliver_cancellation patch ENABLED (max_iterations=%d). "
138 "This is an experimental workaround for anyio#695. "
139 "Disable with ANYIO_CANCEL_DELIVERY_PATCH_ENABLED=false if it causes issues.",
140 max_iterations,
141 )
142 return True
144 except Exception as e:
145 logger.warning("Failed to apply anyio _deliver_cancellation patch: %s", e)
146 return False


def remove_anyio_cancel_delivery_patch() -> bool:
    """Remove the anyio _deliver_cancellation monkey-patch.

    Restores the original anyio implementation.

    Returns:
        True if the patch was removed, False if it wasn't applied.
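
    Examples:
        Calling this in a doctest would unpatch a live process, so this
        example only verifies that the function is exposed and callable:

        >>> callable(remove_anyio_cancel_delivery_patch)
        True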
156 """
157 global _patch_applied # pylint: disable=global-statement
159 if not _patch_applied:
160 return False
162 try:
163 CancelScope._deliver_cancellation = _original_deliver_cancellation # type: ignore[method-assign] # pylint: disable=protected-access
164 _patch_applied = False
165 logger.info("anyio _deliver_cancellation patch removed - restored original implementation")
166 return True
167 except Exception as e:
168 logger.warning("Failed to remove anyio _deliver_cancellation patch: %s", e)
169 return False
172# Apply patch at module load time if enabled
173apply_anyio_cancel_delivery_patch()


def _get_sse_cleanup_timeout() -> float:
    """Get SSE task group cleanup timeout from config.

    This timeout controls how long to wait for SSE task group tasks to respond
    to cancellation before forcing cleanup. Prevents CPU spin loops in anyio's
    _deliver_cancellation when tasks don't properly handle CancelledError.

    Returns:
        Cleanup timeout in seconds (default: 5.0)
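
    Examples:
        Illustrative check only - the concrete value comes from
        settings.sse_task_group_cleanup_timeout (or the 5.0 fallback):

        >>> isinstance(_get_sse_cleanup_timeout(), (int, float))
        True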
185 """
186 try:
187 return settings.sse_task_group_cleanup_timeout
188 except Exception:
189 return 5.0 # Fallback default
192class EventSourceResponse(BaseEventSourceResponse):
193 """Patched EventSourceResponse with CPU spin detection.
195 This mitigates a CPU spin loop issue (anyio#695) where _deliver_cancellation
196 spins at 100% CPU when tasks in the SSE task group don't respond to
197 cancellation.
199 Instead of trying to timeout the task group (which would affect normal
200 SSE connections), we copy the __call__ method and add a deadline to
201 the cancel scope to ensure cleanup doesn't hang indefinitely.
203 See:
204 - https://github.com/agronholm/anyio/issues/695
205 - https://github.com/anthropics/claude-agent-sdk-python/issues/378
206 """

    def enable_compression(self, force: bool = False) -> None:  # noqa: ARG002
        """Enable compression (no-op for SSE streams).

        SSE streams don't support compression, as per sse_starlette.
        This override prevents the NotImplementedError raised by the parent class.

        Args:
            force: Ignored - compression is not supported for SSE.
        """

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Handle an SSE request with a cancel scope deadline to prevent spin.

        This method is copied from sse_starlette with one key modification:
        the task group's cancel_scope gets a deadline set when cancellation
        starts, preventing indefinite spinning if tasks don't respond.

        Args:
            scope: ASGI scope dictionary.
            receive: ASGI receive callable.
            send: ASGI send callable.
        """
        # Copy of sse_starlette.sse.EventSourceResponse.__call__ with the deadline fix
        async with anyio.create_task_group() as task_group:
            # A deadline is added to the cancel scope when cancellation starts
            # (see cancel_on_finish below); it only becomes relevant if cleanup
            # takes longer than the configured timeout.

            async def cancel_on_finish(coro: Callable[[], Awaitable[None]]) -> None:
                """Execute coroutine then cancel task group with bounded deadline.

                This wrapper runs the given coroutine and, upon completion, cancels
                the parent task group with a deadline to prevent indefinite spinning
                if other tasks don't respond to cancellation (anyio#695 mitigation).

                Args:
                    coro: Async callable to execute before triggering cancellation.
                """
                await coro()
                # When cancelling, set a deadline to prevent indefinite spin
                # if other tasks don't respond to cancellation
                task_group.cancel_scope.deadline = anyio.current_time() + _get_sse_cleanup_timeout()
                task_group.cancel_scope.cancel()

            task_group.start_soon(cancel_on_finish, lambda: self._stream_response(send))
            task_group.start_soon(cancel_on_finish, lambda: self._ping(send))
            task_group.start_soon(cancel_on_finish, self._listen_for_exit_signal)

            if self.data_sender_callable:
                task_group.start_soon(self.data_sender_callable)

            # Wait for the client to disconnect last
            task_group.start_soon(cancel_on_finish, lambda: self._listen_for_disconnect(receive))

        if self.background is not None:
            await self.background()


# Pre-computed SSE frame components for performance
_SSE_EVENT_PREFIX = b"event: "
_SSE_DATA_PREFIX = b"\r\ndata: "
_SSE_RETRY_PREFIX = b"\r\nretry: "
_SSE_FRAME_END = b"\r\n\r\n"


def _build_sse_frame(event: bytes, data: bytes, retry: int) -> bytes:
    """Build SSE frame as bytes to avoid encode/decode overhead.

    Args:
        event: SSE event type as bytes (e.g., b'message', b'keepalive', b'error')
        data: JSON data as bytes (from orjson.dumps)
        retry: Retry timeout in milliseconds

    Returns:
        Complete SSE frame as bytes

    Note:
        Uses hardcoded CRLF (\\r\\n) separators matching sse_starlette's
        ServerSentEvent.DEFAULT_SEPARATOR. If custom separators are ever
        needed, this function would need to accept a sep parameter.

    Examples:
        >>> _build_sse_frame(b"message", b'{"test": 1}', 15000)
        b'event: message\\r\\ndata: {"test": 1}\\r\\nretry: 15000\\r\\n\\r\\n'

        >>> _build_sse_frame(b"keepalive", b"{}", 15000)
        b'event: keepalive\\r\\ndata: {}\\r\\nretry: 15000\\r\\n\\r\\n'
    """
    return _SSE_EVENT_PREFIX + event + _SSE_DATA_PREFIX + data + _SSE_RETRY_PREFIX + str(retry).encode() + _SSE_FRAME_END


class SSETransport(Transport):
    """Transport implementation using Server-Sent Events with proper session management.

    This transport implementation uses Server-Sent Events (SSE) for real-time
    communication between the MCP gateway and clients. It provides streaming
    capabilities with automatic session management and keepalive support.

    Examples:
        >>> # Create SSE transport with default URL
        >>> transport = SSETransport()
        >>> transport
        <mcpgateway.transports.sse_transport.SSETransport object at ...>

        >>> # Create SSE transport with custom URL
        >>> transport = SSETransport("http://localhost:8080")
        >>> transport._base_url
        'http://localhost:8080'

        >>> # Check initial connection state
        >>> import asyncio
        >>> asyncio.run(transport.is_connected())
        False

        >>> # Verify it's a proper Transport subclass
        >>> isinstance(transport, Transport)
        True
        >>> issubclass(SSETransport, Transport)
        True

        >>> # Check session ID generation
        >>> transport.session_id
        '...'
        >>> len(transport.session_id) > 0
        True

        >>> # Verify required methods exist
        >>> hasattr(transport, 'connect')
        True
        >>> hasattr(transport, 'disconnect')
        True
        >>> hasattr(transport, 'send_message')
        True
        >>> hasattr(transport, 'receive_message')
        True
        >>> hasattr(transport, 'is_connected')
        True
    """

    # Regex for validating session IDs (alphanumeric, hyphens, underscores, max 128 chars)
    _SESSION_ID_PATTERN = None

    @staticmethod
    def _is_valid_session_id(session_id: str) -> bool:
        """Validate session ID format for security.

        Valid session IDs are alphanumeric with hyphens and underscores, max 128 chars.
        This prevents injection attacks via malformed session IDs.

        Args:
            session_id: The session ID to validate.

        Returns:
            True if the session ID is valid, False otherwise.

        Examples:
            >>> SSETransport._is_valid_session_id("abc123")
            True
            >>> SSETransport._is_valid_session_id("abc-123_def")
            True
            >>> SSETransport._is_valid_session_id("a" * 128)
            True
            >>> SSETransport._is_valid_session_id("a" * 129)
            False
            >>> SSETransport._is_valid_session_id("")
            False
            >>> SSETransport._is_valid_session_id("abc/def")
            False
            >>> SSETransport._is_valid_session_id("abc:def")
            False
        """
        # Standard
        import re  # pylint: disable=import-outside-toplevel

        if not session_id or len(session_id) > 128:
            return False

        # Lazily compile the pattern
        if SSETransport._SESSION_ID_PATTERN is None:
            SSETransport._SESSION_ID_PATTERN = re.compile(r"^[a-zA-Z0-9_-]+$")

        return bool(SSETransport._SESSION_ID_PATTERN.match(session_id))

    def __init__(self, base_url: Optional[str] = None):
        """Initialize SSE transport.

        Args:
            base_url: Base URL for client message endpoints

        Examples:
            >>> # Test default initialization
            >>> transport = SSETransport()
            >>> transport._connected
            False
            >>> transport._message_queue is not None
            True
            >>> transport._client_gone is not None
            True
            >>> len(transport._session_id) > 0
            True

            >>> # Test custom base URL
            >>> transport = SSETransport("https://api.example.com")
            >>> transport._base_url
            'https://api.example.com'

            >>> # Test session ID uniqueness
            >>> transport1 = SSETransport()
            >>> transport2 = SSETransport()
            >>> transport1.session_id != transport2.session_id
            True
        """
        self._base_url = base_url or f"http://{settings.host}:{settings.port}"
        self._connected = False
        self._message_queue = asyncio.Queue()
        self._client_gone = asyncio.Event()
        # Server generates session_id for SSE - client receives it in the endpoint event
        self._session_id = str(uuid.uuid4())

        logger.info("Creating SSE transport with base_url=%s, session_id=%s", self._base_url, self._session_id)

    async def connect(self) -> None:
        """Set up SSE connection.

        Examples:
            >>> # Test connection setup
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport.connect())
            >>> transport._connected
            True
            >>> asyncio.run(transport.is_connected())
            True
        """
        self._connected = True
        logger.info("SSE transport connected: %s", self._session_id)

    async def disconnect(self) -> None:
        """Clean up SSE connection.

        Examples:
            >>> # Test disconnection
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport.connect())
            >>> asyncio.run(transport.disconnect())
            >>> transport._connected
            False
            >>> transport._client_gone.is_set()
            True
            >>> asyncio.run(transport.is_connected())
            False

            >>> # Test disconnection when already disconnected
            >>> transport = SSETransport()
            >>> asyncio.run(transport.disconnect())
            >>> transport._connected
            False
        """
        if self._connected:
            self._connected = False
            self._client_gone.set()
            logger.info("SSE transport disconnected: %s", self._session_id)

    async def send_message(self, message: Dict[str, Any]) -> None:
        """Send a message over SSE.

        Args:
            message: Message to send

        Raises:
            RuntimeError: If transport is not connected
            Exception: If unable to put the message on the queue

        Examples:
            >>> # Test sending message when connected
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport.connect())
            >>> message = {"jsonrpc": "2.0", "method": "test", "id": 1}
            >>> asyncio.run(transport.send_message(message))
            >>> transport._message_queue.qsize()
            1

            >>> # Test sending message when not connected
            >>> transport = SSETransport()
            >>> try:
            ...     asyncio.run(transport.send_message({"test": "message"}))
            ... except RuntimeError as e:
            ...     print("Expected error:", str(e))
            Expected error: Transport not connected

            >>> # Test message format validation
            >>> transport = SSETransport()
            >>> asyncio.run(transport.connect())
            >>> valid_message = {"jsonrpc": "2.0", "method": "initialize", "params": {}}
            >>> isinstance(valid_message, dict)
            True
            >>> "jsonrpc" in valid_message
            True

            >>> # Test behavior when the queue is full
            >>> transport = SSETransport()
            >>> asyncio.run(transport.connect())
            >>> transport._message_queue = asyncio.Queue(maxsize=1)
            >>> asyncio.run(transport._message_queue.put({"dummy": "message"}))
            >>> # The queue is now full, so the next put blocks until wait_for times out
            >>> try:
            ...     asyncio.run(asyncio.wait_for(transport.send_message({"test": "message"}), timeout=0.1))
            ... except asyncio.TimeoutError:
            ...     print("Queue full as expected")
            Queue full as expected
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        try:
            await self._message_queue.put(message)
            logger.debug("Message queued for SSE: %s, method=%s", self._session_id, message.get("method", "(response)"))
        except Exception as e:
            logger.error("Failed to queue message: %s", e)
            raise

    async def receive_message(self) -> AsyncGenerator[Dict[str, Any], None]:
        """Receive messages from the client over SSE transport.

        This method implements a continuous message-receiving pattern for SSE transport.
        Since SSE is primarily a server-to-client communication channel, this method
        yields an initial initialize placeholder message and then enters a waiting loop.
        The actual client messages are received via a separate HTTP POST endpoint
        (not handled in this method).

        The method will continue running until either:
        1. The connection is explicitly disconnected (the client_gone event is set)
        2. The receive loop is cancelled from outside

        Yields:
            Dict[str, Any]: JSON-RPC formatted messages. The first yielded message is always
                an initialize placeholder with the format:
                {"jsonrpc": "2.0", "method": "initialize", "id": 1}

        Raises:
            RuntimeError: If the transport is not connected when this method is called
            asyncio.CancelledError: When the SSE receive loop is cancelled externally

        Examples:
            >>> # Test receive message when connected
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport.connect())
            >>> async def test_receive():
            ...     async for msg in transport.receive_message():
            ...         return msg
            ...     return None
            >>> result = asyncio.run(test_receive())
            >>> result
            {'jsonrpc': '2.0', 'method': 'initialize', 'id': 1}

            >>> # Test receive message when not connected
            >>> transport = SSETransport()
            >>> try:
            ...     async def test_receive():
            ...         async for msg in transport.receive_message():
            ...             pass
            ...     asyncio.run(test_receive())
            ... except RuntimeError as e:
            ...     print("Expected error:", str(e))
            Expected error: Transport not connected

            >>> # Verify generator behavior
            >>> transport = SSETransport()
            >>> import inspect
            >>> inspect.isasyncgenfunction(transport.receive_message)
            True
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        # For SSE, we set up a loop to wait for messages, which are delivered via POST.
        # Most messages come via the POST endpoint, but we yield an initial initialize
        # placeholder to keep the receive loop running.
        yield {"jsonrpc": "2.0", "method": "initialize", "id": 1}

        # Continue waiting until disconnection or cancellation
        try:
            while not self._client_gone.is_set():
                await asyncio.sleep(1.0)
        except asyncio.CancelledError:
            logger.info("SSE receive loop cancelled for session %s", self._session_id)
            raise
        finally:
            logger.info("SSE receive loop ended for session %s", self._session_id)

    async def _get_message_with_timeout(self, timeout: Optional[float]) -> Optional[Dict[str, Any]]:
        """Get a message from the queue with a timeout; returns None on timeout.

        Uses asyncio.wait() to avoid TimeoutError exception overhead.

        Args:
            timeout: Timeout in seconds, or None for no timeout

        Returns:
            Message dict if received, None if a timeout occurred

        Raises:
            asyncio.CancelledError: If the operation is cancelled externally
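
        Examples:
            An empty queue simply times out and returns None (0.01s keeps
            this illustrative doctest fast):

            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport._get_message_with_timeout(0.01)) is None
            True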
615 """
616 if timeout is None:
617 return await self._message_queue.get()
619 get_task = asyncio.create_task(self._message_queue.get())
620 try:
621 done, _ = await asyncio.wait({get_task}, timeout=timeout)
622 except asyncio.CancelledError:
623 get_task.cancel()
624 try:
625 await get_task
626 except asyncio.CancelledError:
627 pass
628 raise
630 if get_task in done:
631 return get_task.result()
633 # Timeout - cancel pending task, but return the result if it completed in the race window.
634 get_task.cancel()
635 try:
636 return await get_task
637 except asyncio.CancelledError:
638 return None

    async def is_connected(self) -> bool:
        """Check if transport is connected.

        Returns:
            True if connected

        Examples:
            >>> # Test initial state
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport.is_connected())
            False

            >>> # Test after connection
            >>> transport = SSETransport()
            >>> asyncio.run(transport.connect())
            >>> asyncio.run(transport.is_connected())
            True

            >>> # Test after disconnection
            >>> transport = SSETransport()
            >>> asyncio.run(transport.connect())
            >>> asyncio.run(transport.disconnect())
            >>> asyncio.run(transport.is_connected())
            False
        """
        return self._connected

    async def create_sse_response(
        self,
        request: Request,
        on_disconnect_callback: Callable[[], Awaitable[None]] | None = None,
    ) -> EventSourceResponse:
        """Create SSE response for streaming.

        Args:
            request: FastAPI request (used for disconnection detection)
            on_disconnect_callback: Optional async callback to run when the client disconnects.
                Used for defensive cleanup (e.g., cancelling respond tasks).

        Returns:
            SSE response object

        Examples:
            >>> # Test SSE response creation
            >>> transport = SSETransport("http://localhost:8000")
            >>> # Note: This method requires a FastAPI Request object
            >>> # and cannot be easily tested in a doctest environment
            >>> callable(transport.create_sse_response)
            True
        """
        endpoint_url = f"{self._base_url}/message?session_id={self._session_id}"

        async def event_generator():
            """Generate SSE events.

            Yields:
                SSE event as bytes (pre-formatted SSE frame)
            """
            # Send the endpoint event first
            yield _build_sse_frame(b"endpoint", endpoint_url.encode(), settings.sse_retry_timeout)

            # Send a keepalive immediately to help establish the connection (if enabled)
            if settings.sse_keepalive_enabled:
                yield _build_sse_frame(b"keepalive", b"{}", settings.sse_retry_timeout)

            consecutive_errors = 0
            max_consecutive_errors = 3  # Exit after 3 consecutive errors (likely client disconnected)

            # Rapid yield detection: if we're yielding faster than expected, the client is likely
            # disconnected but the ASGI server isn't properly signaling it. Track yield timestamps
            # in a sliding window.
            yield_timestamps: Optional[deque] = deque(maxlen=settings.sse_rapid_yield_max + 1) if settings.sse_rapid_yield_max > 0 else None
            rapid_yield_window_sec = settings.sse_rapid_yield_window_ms / 1000.0
            last_yield_time = time.monotonic()
            consecutive_rapid_yields = 0  # Track consecutive fast yields for simpler detection

            def check_rapid_yield() -> bool:
                """Check if yields are happening too fast.

                Returns:
                    True if a spin loop is detected and we should disconnect, False otherwise.
                """
                nonlocal last_yield_time, consecutive_rapid_yields
                now = time.monotonic()
                time_since_last = now - last_yield_time
                last_yield_time = now

                # Track consecutive rapid yields (< 100ms apart).
                # This catches spin loops even without the full deque analysis.
                if time_since_last < 0.1:
                    consecutive_rapid_yields += 1
                    if consecutive_rapid_yields >= 10:  # 10 consecutive fast yields = definite spin
                        logger.error("SSE spin loop detected (%d consecutive rapid yields, last interval %.3fs), client disconnected: %s", consecutive_rapid_yields, time_since_last, self._session_id)
                        return True
                else:
                    consecutive_rapid_yields = 0  # Reset on a normal-speed yield

                # Also use deque-based detection for more nuanced analysis
                if yield_timestamps is None:
                    return False
                yield_timestamps.append(now)
                if time_since_last < 0.01:  # Less than 10ms between yields is very fast
                    if len(yield_timestamps) > settings.sse_rapid_yield_max:
                        oldest = yield_timestamps[0]
                        elapsed = now - oldest
                        if elapsed < rapid_yield_window_sec:
                            logger.error(
                                "SSE rapid yield detected (%d yields in %.3fs, last interval %.3fs), client disconnected: %s", len(yield_timestamps), elapsed, time_since_last, self._session_id
                            )
                            return True
                return False

            try:
                while not self._client_gone.is_set():
                    # Check if the client has disconnected via request state
                    if await request.is_disconnected():
                        logger.info("SSE client disconnected (detected via request): %s", self._session_id)
                        self._client_gone.set()
                        break

                    try:
                        # Use timeout-based polling only when keepalive is enabled
                        if not settings.sse_keepalive_enabled:
                            message = await self._message_queue.get()
                        else:
                            message = await self._get_message_with_timeout(settings.sse_keepalive_interval)

                        if message is not None:
                            json_bytes = orjson.dumps(message, option=orjson.OPT_SERIALIZE_NUMPY)

                            if logger.isEnabledFor(logging.DEBUG):
                                logger.debug("Sending SSE message: %s", json_bytes.decode())

                            yield _build_sse_frame(b"message", json_bytes, settings.sse_retry_timeout)
                            consecutive_errors = 0  # Reset on successful send

                            # Check for rapid yields after a message send
                            if check_rapid_yield():
                                self._client_gone.set()
                                break
                        elif settings.sse_keepalive_enabled:
                            # Timeout - send a keepalive
                            yield _build_sse_frame(b"keepalive", b"{}", settings.sse_retry_timeout)
                            consecutive_errors = 0  # Reset on successful send
                            # Check for rapid yields after keepalives too
                            if check_rapid_yield():
                                self._client_gone.set()
                                break
                        # Note: We don't clear yield_timestamps here. The deque has a maxlen
                        # which automatically drops old entries. Clearing would prevent
                        # detection of spin loops where yields happen faster than expected.

                    except GeneratorExit:
                        # Client disconnected - the generator is being closed
                        logger.info("SSE generator exit (client disconnected): %s", self._session_id)
                        self._client_gone.set()
                        break
                    except Exception as e:
                        consecutive_errors += 1
                        logger.warning("Error processing SSE message (attempt %d/%d): %s", consecutive_errors, max_consecutive_errors, e)
                        if consecutive_errors >= max_consecutive_errors:
                            logger.info("SSE too many consecutive errors, assuming client disconnected: %s", self._session_id)
                            self._client_gone.set()
                            break
                        # Don't yield an error frame - it could cause more errors if the client is gone
            except asyncio.CancelledError:
                logger.info("SSE event generator cancelled: %s", self._session_id)
                self._client_gone.set()
            except GeneratorExit:
                logger.info("SSE generator exit: %s", self._session_id)
                self._client_gone.set()
            except Exception as e:
                logger.error("SSE event generator error: %s", e)
                self._client_gone.set()
            finally:
                logger.info("SSE event generator completed: %s", self._session_id)
                self._client_gone.set()  # Always set client_gone on exit to clean up
                # CRITICAL: Also invoke the disconnect callback on generator end (Finding 3).
                # This covers server-initiated close, errors, and cancellation - not just client close.
                if on_disconnect_callback:
                    try:
                        await on_disconnect_callback()
                    except Exception as e:
                        logger.warning("Disconnect callback in finally failed for %s: %s", self._session_id, e)

        async def on_client_close(_scope: dict) -> None:
            """Handle client close event from sse_starlette."""
            logger.info("SSE client close handler called: %s", self._session_id)
            self._client_gone.set()

            # Defensive cleanup via callback (if provided)
            if on_disconnect_callback:
                try:
                    await on_disconnect_callback()
                except Exception as e:
                    logger.warning("Disconnect callback failed for %s: %s", self._session_id, e)

        return EventSourceResponse(
            event_generator(),
            status_code=200,
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "Content-Type": "text/event-stream",
                "X-MCP-SSE": "true",
            },
            # Timeout for ASGI send() calls - protects against sends that hang indefinitely
            # when the client connection is in a bad state (e.g., the client stopped reading but
            # the TCP connection is not yet closed). Does NOT affect MCP server response times.
            # Set to 0 to disable. The default matches the keepalive interval.
            send_timeout=settings.sse_send_timeout if settings.sse_send_timeout > 0 else None,
            # Callback when the client closes - helps detect disconnects that the ASGI server
            # may not properly propagate via request.is_disconnected()
            client_close_handler_callable=on_client_close,
        )

    async def _client_disconnected(self, _request: Request) -> bool:
        """Check if client has disconnected.

        Args:
            _request: FastAPI Request object

        Returns:
            bool: True if client disconnected

        Examples:
            >>> # Test client disconnected check
            >>> transport = SSETransport()
            >>> import asyncio
            >>> asyncio.run(transport._client_disconnected(None))
            False

            >>> # Test after setting client gone
            >>> transport = SSETransport()
            >>> transport._client_gone.set()
            >>> asyncio.run(transport._client_disconnected(None))
            True
        """
        # We only check our internal client_gone flag.
        # We intentionally don't check connection_lost on the request,
        # as it can be unreliable and cause premature closures.
        return self._client_gone.is_set()

    @property
    def session_id(self) -> str:
        """Get the session ID for this transport.

        Returns:
            str: session_id

        Examples:
            >>> # Test session ID property
            >>> transport = SSETransport()
            >>> session_id = transport.session_id
            >>> isinstance(session_id, str)
            True
            >>> len(session_id) > 0
            True
            >>> session_id == transport._session_id
            True

            >>> # Test session ID uniqueness
            >>> transport1 = SSETransport()
            >>> transport2 = SSETransport()
            >>> transport1.session_id != transport2.session_id
            True
        """
        return self._session_id