Coverage for mcpgateway / transports / streamablehttp_transport.py: 100%
745 statements
coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/transports/streamablehttp_transport.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Keval Mahajan
7Streamable HTTP Transport Implementation.
8This module implements the Streamable HTTP transport for MCP.
10Key components include:
11- SessionManagerWrapper: Manages the lifecycle of streamable HTTP sessions
12- Configuration options for:
13 1. stateful/stateless operation
14 2. JSON response mode or SSE streams
15- InMemoryEventStore: A simple in-memory event storage system for maintaining session state
17Examples:
18 >>> # Test module imports
19 >>> from mcpgateway.transports.streamablehttp_transport import (
20 ... EventEntry, StreamBuffer, InMemoryEventStore, SessionManagerWrapper
21 ... )
22 >>>
23 >>> # Verify classes are available
24 >>> EventEntry.__name__
25 'EventEntry'
26 >>> StreamBuffer.__name__
27 'StreamBuffer'
28 >>> InMemoryEventStore.__name__
29 'InMemoryEventStore'
30 >>> SessionManagerWrapper.__name__
31 'SessionManagerWrapper'
32"""
34# Standard
35import asyncio
36from contextlib import asynccontextmanager, AsyncExitStack
37import contextvars
38from dataclasses import dataclass
39import re
40from typing import Any, AsyncGenerator, Dict, List, Optional, Pattern, Union
41from uuid import uuid4
43# Third-Party
44import anyio
45from fastapi.security.utils import get_authorization_scheme_param
46import httpx
47from mcp import types
48from mcp.server.lowlevel import Server
49from mcp.server.streamable_http import EventCallback, EventId, EventMessage, EventStore, StreamId
50from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
51from mcp.types import JSONRPCMessage
52import orjson
53from sqlalchemy.orm import Session
54from starlette.datastructures import Headers
55from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
56from starlette.types import Receive, Scope, Send
58# First-Party
59from mcpgateway.common.models import LogLevel
60from mcpgateway.config import settings
61from mcpgateway.db import SessionLocal
62from mcpgateway.services.completion_service import CompletionService
63from mcpgateway.services.logging_service import LoggingService
64from mcpgateway.services.prompt_service import PromptService
65from mcpgateway.services.resource_service import ResourceService
66from mcpgateway.services.tool_service import ToolService
67from mcpgateway.transports.redis_event_store import RedisEventStore
68from mcpgateway.utils.orjson_response import ORJSONResponse
69from mcpgateway.utils.verify_credentials import verify_credentials
71# Initialize logging service first
72logging_service = LoggingService()
73logger = logging_service.get_logger(__name__)
75# Precompiled regex for server ID extraction from path
76_SERVER_ID_RE: Pattern[str] = re.compile(r"/servers/(?P<server_id>[a-fA-F0-9\-]+)/mcp")
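# Illustrative sketch, not part of the module: the precompiled pattern above
# extracts the server ID segment from gateway paths of the form
# "/servers/<hex-or-uuid>/mcp". The path below is a made-up example.
#
#     m = _SERVER_ID_RE.search("/servers/3f2a9b1c-7d4e-4f60-9b1c-aa00bb11cc22/mcp")
#     m.group("server_id") if m else None  # -> '3f2a9b1c-7d4e-4f60-9b1c-aa00bb11cc22'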
78# Initialize ToolService, PromptService, ResourceService, CompletionService and MCP Server
79tool_service: ToolService = ToolService()
80prompt_service: PromptService = PromptService()
81resource_service: ResourceService = ResourceService()
82completion_service: CompletionService = CompletionService()
84mcp_app: Server[Any] = Server("mcp-streamable-http")
86server_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("server_id", default="default_server_id")
87request_headers_var: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar("request_headers", default={})
88user_context_var: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar("user_context", default={})
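# Illustrative sketch, not part of the module: ContextVar gives each in-flight
# request/task its own view of these values, so concurrent requests do not
# clobber one another. Typical use (value is a made-up example):
#
#     token = server_id_var.set("abc123")
#     server_id_var.get()      # -> 'abc123' inside this task/context
#     server_id_var.reset(token)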
90# ------------------------------ Event store ------------------------------
93@dataclass
94class EventEntry:
95 """
96 Represents an event entry in the event store.
98 Examples:
99 >>> # Create an event entry
100 >>> from mcp.types import JSONRPCMessage
101 >>> message = JSONRPCMessage(jsonrpc="2.0", method="test", id=1)
102 >>> entry = EventEntry(event_id="test-123", stream_id="stream-456", message=message, seq_num=0)
103 >>> entry.event_id
104 'test-123'
105 >>> entry.stream_id
106 'stream-456'
107 >>> entry.seq_num
108 0
109 >>> # Access message attributes through model_dump() for Pydantic v2
110 >>> message_dict = message.model_dump()
111 >>> message_dict['jsonrpc']
112 '2.0'
113 >>> message_dict['method']
114 'test'
115 >>> message_dict['id']
116 1
117 """
119 event_id: EventId
120 stream_id: StreamId
121 message: JSONRPCMessage
122 seq_num: int
125@dataclass
126class StreamBuffer:
127 """
128 Ring buffer for per-stream event storage with O(1) position lookup.
130 Tracks sequence numbers to enable efficient replay without scanning.
131 Events are stored at position (seq_num % capacity) in the entries list.
133 Examples:
134 >>> # Create a stream buffer with capacity 3
135 >>> buffer = StreamBuffer(entries=[None, None, None])
136 >>> buffer.start_seq
137 0
138 >>> buffer.next_seq
139 0
140 >>> buffer.count
141 0
142 >>> len(buffer)
143 0
145 >>> # Simulate adding an entry
146 >>> buffer.next_seq = 1
147 >>> buffer.count = 1
148 >>> len(buffer)
149 1
150 """
152 entries: list[EventEntry | None]
153 start_seq: int = 0 # oldest seq still buffered
154 next_seq: int = 0 # seq assigned to next insert
155 count: int = 0
157 def __len__(self) -> int:
158 """Return the number of events currently in the buffer.
160 Returns:
161 int: The count of events in the buffer.
162 """
163 return self.count
166class InMemoryEventStore(EventStore):
167 """
168 Simple in-memory implementation of the EventStore interface for resumability.
169 This is primarily intended for examples and testing, not for production use
170 where a persistent storage solution would be more appropriate.
172 This implementation keeps only the last N events per stream for memory efficiency.
173 Uses a ring buffer with per-stream sequence numbers for O(1) event lookup and O(k) replay.
175 Examples:
176 >>> # Create event store with default max events
177 >>> store = InMemoryEventStore()
178 >>> store.max_events_per_stream
179 100
180 >>> len(store.streams)
181 0
182 >>> len(store.event_index)
183 0
185 >>> # Create event store with custom max events
186 >>> store = InMemoryEventStore(max_events_per_stream=50)
187 >>> store.max_events_per_stream
188 50
190 >>> # Test event store initialization
191 >>> store = InMemoryEventStore()
192 >>> hasattr(store, 'streams')
193 True
194 >>> hasattr(store, 'event_index')
195 True
196 >>> isinstance(store.streams, dict)
197 True
198 >>> isinstance(store.event_index, dict)
199 True
200 """
202 def __init__(self, max_events_per_stream: int = 100):
203 """Initialize the event store.
205 Args:
206 max_events_per_stream: Maximum number of events to keep per stream
208 Examples:
209 >>> # Test initialization with default value
210 >>> store = InMemoryEventStore()
211 >>> store.max_events_per_stream
212 100
213 >>> store.streams == {}
214 True
215 >>> store.event_index == {}
216 True
218 >>> # Test initialization with custom value
219 >>> store = InMemoryEventStore(max_events_per_stream=25)
220 >>> store.max_events_per_stream
221 25
222 """
223 self.max_events_per_stream = max_events_per_stream
224 # Per-stream ring buffers for O(1) position lookup
225 self.streams: dict[StreamId, StreamBuffer] = {}
226 # event_id -> EventEntry for quick lookup
227 self.event_index: dict[EventId, EventEntry] = {}
229 async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId:
230 """
231 Stores an event with a generated event ID.
233 Args:
234 stream_id (StreamId): The ID of the stream.
235 message (JSONRPCMessage): The message to store.
237 Returns:
238 EventId: The ID of the stored event.
240 Examples:
241 >>> # Test storing an event
242 >>> import asyncio
243 >>> from mcp.types import JSONRPCMessage
244 >>> store = InMemoryEventStore(max_events_per_stream=5)
245 >>> message = JSONRPCMessage(jsonrpc="2.0", method="test", id=1)
246 >>> event_id = asyncio.run(store.store_event("stream-1", message))
247 >>> isinstance(event_id, str)
248 True
249 >>> len(event_id) > 0
250 True
251 >>> len(store.streams)
252 1
253 >>> len(store.event_index)
254 1
255 >>> "stream-1" in store.streams
256 True
257 >>> event_id in store.event_index
258 True
260 >>> # Test storing multiple events in same stream
261 >>> message2 = JSONRPCMessage(jsonrpc="2.0", method="test2", id=2)
262 >>> event_id2 = asyncio.run(store.store_event("stream-1", message2))
263 >>> len(store.streams["stream-1"])
264 2
265 >>> len(store.event_index)
266 2
268 >>> # Test ring buffer overflow
269 >>> store2 = InMemoryEventStore(max_events_per_stream=2)
270 >>> msg1 = JSONRPCMessage(jsonrpc="2.0", method="m1", id=1)
271 >>> msg2 = JSONRPCMessage(jsonrpc="2.0", method="m2", id=2)
272 >>> msg3 = JSONRPCMessage(jsonrpc="2.0", method="m3", id=3)
273 >>> id1 = asyncio.run(store2.store_event("stream-2", msg1))
274 >>> id2 = asyncio.run(store2.store_event("stream-2", msg2))
275 >>> # Now buffer is full, adding third will remove first
276 >>> id3 = asyncio.run(store2.store_event("stream-2", msg3))
277 >>> len(store2.streams["stream-2"])
278 2
279 >>> id1 in store2.event_index # First event removed
280 False
281 >>> id2 in store2.event_index and id3 in store2.event_index
282 True
283 """
284 # Get or create ring buffer for this stream
285 buffer = self.streams.get(stream_id)
286 if buffer is None:
287 buffer = StreamBuffer(entries=[None] * self.max_events_per_stream)
288 self.streams[stream_id] = buffer
290 # Assign per-stream sequence number
291 seq_num = buffer.next_seq
292 buffer.next_seq += 1
293 idx = seq_num % self.max_events_per_stream
295 # Handle eviction if buffer is full
296 if buffer.count == self.max_events_per_stream:
297 evicted = buffer.entries[idx]
298 if evicted is not None:
299 self.event_index.pop(evicted.event_id, None)
300 buffer.start_seq += 1
301 else:
302 if buffer.count == 0:
303 buffer.start_seq = seq_num
304 buffer.count += 1
306 # Create and store the new event entry
307 event_id = str(uuid4())
308 event_entry = EventEntry(event_id=event_id, stream_id=stream_id, message=message, seq_num=seq_num)
309 buffer.entries[idx] = event_entry
310 self.event_index[event_id] = event_entry
312 return event_id
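# Illustrative sketch, not part of the module: with max_events_per_stream=3,
# slot positions follow seq_num % capacity, so sequence numbers 0..5 land in
# slots 0, 1, 2, 0, 1, 2. Once count reaches capacity, writing seq 3 into
# slot 0 evicts the entry for seq 0 and start_seq advances from 0 to 1.
#
#     capacity = 3
#     [seq % capacity for seq in range(6)]  # -> [0, 1, 2, 0, 1, 2]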
314 async def replay_events_after(
315 self,
316 last_event_id: EventId,
317 send_callback: EventCallback,
318 ) -> Union[StreamId, None]:
319 """
320 Replays events that occurred after the specified event ID.
322 Uses O(1) lookup via event_index and O(k) replay where k is the number
323 of events to replay, avoiding the previous O(n) full scan.
325 Args:
326 last_event_id (EventId): The ID of the last received event. Replay starts after this event.
327 send_callback (EventCallback): Async callback to send each replayed event.
329 Returns:
330 StreamId | None: The stream ID if the event is found and replayed, otherwise None.
332 Examples:
333 >>> # Test replaying events
334 >>> import asyncio
335 >>> from mcp.types import JSONRPCMessage
336 >>> store = InMemoryEventStore()
337 >>> message1 = JSONRPCMessage(jsonrpc="2.0", method="test1", id=1)
338 >>> message2 = JSONRPCMessage(jsonrpc="2.0", method="test2", id=2)
339 >>> message3 = JSONRPCMessage(jsonrpc="2.0", method="test3", id=3)
340 >>>
341 >>> # Store events
342 >>> event_id1 = asyncio.run(store.store_event("stream-1", message1))
343 >>> event_id2 = asyncio.run(store.store_event("stream-1", message2))
344 >>> event_id3 = asyncio.run(store.store_event("stream-1", message3))
345 >>>
346 >>> # Test replay after first event
347 >>> replayed_events = []
348 >>> async def mock_callback(event_message):
349 ... replayed_events.append(event_message)
350 >>>
351 >>> result = asyncio.run(store.replay_events_after(event_id1, mock_callback))
352 >>> result
353 'stream-1'
354 >>> len(replayed_events)
355 2
357 >>> # Test replay with non-existent event
358 >>> result = asyncio.run(store.replay_events_after("non-existent", mock_callback))
359 >>> result is None
360 True
361 """
362 # O(1) lookup in event_index
363 last_event = self.event_index.get(last_event_id)
364 if last_event is None:
365 logger.warning(f"Event ID {last_event_id} not found in store")
366 return None
368 buffer = self.streams.get(last_event.stream_id)
369 if buffer is None:
370 return None
372 # Validate that the event's seq_num is still within the buffer range
373 if last_event.seq_num < buffer.start_seq or last_event.seq_num >= buffer.next_seq:
374 return None
376 # O(k) replay: iterate from last_event.seq_num + 1 to buffer.next_seq - 1
377 for seq in range(last_event.seq_num + 1, buffer.next_seq):
378 entry = buffer.entries[seq % self.max_events_per_stream]
379 # Guard: skip if slot is empty or has been overwritten by a different seq
380 if entry is None or entry.seq_num != seq:
381 continue
382 await send_callback(EventMessage(entry.message, entry.event_id))
384 return last_event.stream_id
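# Illustrative sketch (assumption about the caller, not part of the module):
# on reconnect, the streamable HTTP layer is expected to turn the client's
# last-received event ID into a replay via this store, roughly:
#
#     async def resend(event_message):
#         ...  # push the buffered EventMessage back onto the reconnected stream
#
#     stream_id = await store.replay_events_after(last_event_id, resend)
#     if stream_id is None:
#         ...  # event expired or unknown; the client must start a fresh stream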
387# ------------------------------ Streamable HTTP Transport ------------------------------
390@asynccontextmanager
391async def get_db() -> AsyncGenerator[Session, Any]:
392 """
393 Asynchronous context manager for database sessions.
395 Commits the transaction on successful completion to avoid implicit rollbacks
396 for read-only operations. Rolls back explicitly on exception. Handles
397 asyncio.CancelledError explicitly to prevent transaction leaks when MCP
398 handlers are cancelled (client disconnect, timeout, etc.).
400 Yields:
401 A database session instance from SessionLocal.
402 Ensures the session is closed after use.
404 Raises:
405 asyncio.CancelledError: Re-raised after rollback and close on task cancellation.
406 Exception: Re-raises any exception after rolling back the transaction.
408 Examples:
409 >>> # Test database context manager
410 >>> import asyncio
411 >>> async def test_db():
412 ... async with get_db() as db:
413 ... return db is not None
414 >>> result = asyncio.run(test_db())
415 >>> result
416 True
417 """
418 db = SessionLocal()
419 try:
420 yield db
421 db.commit()
422 except asyncio.CancelledError:
423 # Handle cancellation explicitly to prevent transaction leaks.
424 # When MCP handlers are cancelled (client disconnect, timeout, etc.),
425 # we must rollback and close the session before re-raising.
426 try:
427 db.rollback()
428 except Exception:
429 pass # nosec B110 - Best effort rollback on cancellation
430 try:
431 db.close()
432 except Exception:
433 pass # nosec B110 - Best effort close on cancellation
434 raise
435 except Exception:
436 try:
437 db.rollback()
438 except Exception:
439 try:
440 db.invalidate()
441 except Exception:
442 pass # nosec B110 - Best effort cleanup on connection failure
443 raise
444 finally:
445 db.close()
448def get_user_email_from_context() -> str:
449 """Extract user email from the current user context.
451 Returns:
452 User email address or 'unknown' if not available
453 """
454 user = user_context_var.get()
455 if isinstance(user, dict):
456 # First try 'email', then 'sub' (JWT standard claim)
457 return user.get("email") or user.get("sub") or "unknown"
458 return str(user) if user else "unknown"
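# Illustrative sketch, not part of the module: with a JWT-style context such as
# {"sub": "alice@example.com", "teams": ["core"]}, the helper above falls back
# from "email" to the standard "sub" claim and returns "alice@example.com";
# with no context set it returns "unknown".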
461@mcp_app.call_tool(validate_input=False)
462async def call_tool(name: str, arguments: dict) -> List[Union[types.TextContent, types.ImageContent, types.AudioContent, types.ResourceLink, types.EmbeddedResource]]:
463 """
464 Handles tool invocation via the MCP Server.
466 Note: validate_input=False disables the MCP SDK's built-in JSON Schema validation.
467 This is necessary because the SDK uses jsonschema.validate() which internally calls
468 check_schema() with the default validator. Schemas using older draft features
469 (e.g., Draft 4 style exclusiveMinimum: true) fail this validation. The gateway
470 handles schema validation separately in tool_service.py with multi-draft support.
472 This function supports the MCP protocol's tool calling with structured content validation.
473 It can return either unstructured content only, or both unstructured and structured content
474 when the tool defines an outputSchema.
476 Args:
477 name (str): The name of the tool to invoke.
478 arguments (dict): A dictionary of arguments to pass to the tool.
480 Returns:
481 Union[List[ContentBlock], Tuple[List[ContentBlock], Dict[str, Any]]]:
482 - If structured content is not present: Returns a list of content blocks
483 (TextContent, ImageContent, or EmbeddedResource)
484 - If structured content is present: Returns a tuple of (unstructured_content, structured_content)
485 where structured_content is a dictionary that will be validated against the tool's outputSchema
487 The MCP SDK's call_tool decorator automatically handles both return types:
488 - List return → CallToolResult with content only
489 - Tuple return → CallToolResult with both content and structuredContent fields
491 Raises:
492 Exception: Re-raised after logging so the MCP SDK can convert it to a JSON-RPC error response.
497 Examples:
498 >>> # Test call_tool function signature
499 >>> import inspect
500 >>> sig = inspect.signature(call_tool)
501 >>> list(sig.parameters.keys())
502 ['name', 'arguments']
503 >>> sig.parameters['name'].annotation
504 <class 'str'>
505 >>> sig.parameters['arguments'].annotation
506 <class 'dict'>
507 >>> sig.return_annotation
508 typing.List[typing.Union[mcp.types.TextContent, mcp.types.ImageContent, mcp.types.AudioContent, mcp.types.ResourceLink, mcp.types.EmbeddedResource]]
509 """
510 request_headers = request_headers_var.get()
511 server_id = server_id_var.get()
512 user_context = user_context_var.get()
514 meta_data = None
515 # Extract _meta from request context if available
516 try:
517 ctx = mcp_app.request_context
518 if ctx and ctx.meta is not None:
519 meta_data = ctx.meta.model_dump()
520 except LookupError:
521 # request_context might not be active in some edge cases (e.g. tests)
522 logger.debug("No active request context found")
524 # Extract authorization parameters from user context (same pattern as list_tools)
525 user_email = user_context.get("email") if user_context else None
526 token_teams = user_context.get("teams") if user_context else None
527 is_admin = user_context.get("is_admin", False) if user_context else False
529 # Admin bypass - only when token has NO team restrictions (token_teams is None)
530 # If token has explicit team scope (even empty [] for public-only), respect it
531 if is_admin and token_teams is None:
532 user_email = None
533 # token_teams stays None (unrestricted)
534 elif token_teams is None:
535 token_teams = [] # Non-admin without teams = public-only (secure default)
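# Illustrative summary of the scoping rules above, not part of the module:
#
#     is_admin=True,  token_teams=None  -> unrestricted (user_email cleared)
#     is_admin=True,  token_teams=[..]  -> token's team scope is respected
#     is_admin=False, token_teams=None  -> token_teams=[] (public-only default)
#     is_admin=False, token_teams=[..]  -> token's team scope is respected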
537 app_user_email = get_user_email_from_context() # Keep for OAuth token selection
539 # Multi-worker session affinity: check if we should forward to another worker
540 # Check both x-mcp-session-id (internal/forwarded) and mcp-session-id (client protocol header)
541 mcp_session_id = None
542 if request_headers:
543 request_headers_lower = {k.lower(): v for k, v in request_headers.items()}
544 mcp_session_id = request_headers_lower.get("x-mcp-session-id") or request_headers_lower.get("mcp-session-id")
545 if settings.mcpgateway_session_affinity_enabled and mcp_session_id:
546 try:
547 # First-Party
548 from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache # pylint: disable=import-outside-toplevel
549 from mcpgateway.services.mcp_session_pool import get_mcp_session_pool # pylint: disable=import-outside-toplevel
550 from mcpgateway.services.mcp_session_pool import MCPSessionPool # pylint: disable=import-outside-toplevel
552 if not MCPSessionPool.is_valid_mcp_session_id(mcp_session_id):
553 logger.debug("Invalid MCP session id for Streamable HTTP tool affinity, executing locally")
554 raise RuntimeError("invalid mcp session id")
556 pool = get_mcp_session_pool()
558 # Register session mapping BEFORE checking forwarding (same pattern as SSE)
559 # This ensures ownership is registered atomically so forward_request_to_owner() works
560 try:
561 cached = await tool_lookup_cache.get(name)
562 if cached and cached.get("status") == "active":
563 gateway_info = cached.get("gateway")
564 if gateway_info:
565 url = gateway_info.get("url")
566 gateway_id = gateway_info.get("id", "")
567 transport_type = gateway_info.get("transport", "streamablehttp")
568 if url:
569 await pool.register_session_mapping(mcp_session_id, url, gateway_id, transport_type, user_email)
570 except Exception as e:
571 logger.error(f"Failed to pre-register session mapping for Streamable HTTP: {e}")
573 forwarded_response = await pool.forward_request_to_owner(
574 mcp_session_id,
575 {"method": "tools/call", "params": {"name": name, "arguments": arguments, "_meta": meta_data}, "headers": dict(request_headers) if request_headers else {}},
576 )
577 if forwarded_response is not None:
578 # Request was handled by another worker - convert response to expected format
579 if "error" in forwarded_response:
580 raise Exception(forwarded_response["error"].get("message", "Forwarded request failed")) # pylint: disable=broad-exception-raised
581 result_data = forwarded_response.get("result", {})
583 def _rehydrate_content_items(items: Any) -> list[types.TextContent | types.ImageContent | types.AudioContent | types.ResourceLink | types.EmbeddedResource]:
584 """Convert forwarded tool result items back to MCP content types.
586 Args:
587 items: List of content item dicts from forwarded response.
589 Returns:
590 List of validated MCP content type instances.
591 """
592 if not isinstance(items, list):
593 return []
594 converted: list[types.TextContent | types.ImageContent | types.AudioContent | types.ResourceLink | types.EmbeddedResource] = []
595 for item in items:
596 if not isinstance(item, dict):
597 continue
598 item_type = item.get("type")
599 try:
600 if item_type == "text":
601 converted.append(types.TextContent.model_validate(item))
602 elif item_type == "image":
603 converted.append(types.ImageContent.model_validate(item))
604 elif item_type == "audio":
605 converted.append(types.AudioContent.model_validate(item))
606 elif item_type == "resource_link":
607 converted.append(types.ResourceLink.model_validate(item))
608 elif item_type == "resource":
609 converted.append(types.EmbeddedResource.model_validate(item))
610 else:
611 converted.append(types.TextContent(type="text", text=str(item)))
612 except Exception:
613 converted.append(types.TextContent(type="text", text=str(item)))
614 return converted
616 unstructured = _rehydrate_content_items(result_data.get("content", []))
617 structured = result_data.get("structuredContent") or result_data.get("structured_content")
618 if structured:
619 return (unstructured, structured)
620 return unstructured
621 except RuntimeError:
622 # Pool not initialized - execute locally
623 pass
625 try:
626 async with get_db() as db:
627 result = await tool_service.invoke_tool(
628 db=db,
629 name=name,
630 arguments=arguments,
631 request_headers=request_headers,
632 app_user_email=app_user_email,
633 user_email=user_email,
634 token_teams=token_teams,
635 server_id=server_id,
636 meta_data=meta_data,
637 )
638 if not result or not result.content:
639 logger.warning(f"No content returned by tool: {name}")
640 return []
642 # Normalize unstructured content to MCP SDK types, preserving metadata (annotations, _meta, size)
643 # Helper to convert gateway Annotations to dict for MCP SDK compatibility
644 # (mcpgateway.common.models.Annotations != mcp.types.Annotations)
645 def _convert_annotations(ann: Any) -> dict[str, Any] | None:
646 """Convert gateway Annotations to dict for MCP SDK compatibility.
648 Args:
649 ann: Gateway Annotations object, dict, or None.
651 Returns:
652 Dict representation of annotations, or None.
653 """
654 if ann is None:
655 return None
656 if isinstance(ann, dict):
657 return ann
658 if hasattr(ann, "model_dump"):
659 return ann.model_dump(by_alias=True, mode="json")
660 return None
662 def _convert_meta(meta: Any) -> dict[str, Any] | None:
663 """Convert gateway meta to dict for MCP SDK compatibility.
665 Args:
666 meta: Gateway meta object, dict, or None.
668 Returns:
669 Dict representation of meta, or None.
670 """
671 if meta is None:
672 return None
673 if isinstance(meta, dict):
674 return meta
675 if hasattr(meta, "model_dump"):
676 return meta.model_dump(by_alias=True, mode="json")
677 return None
679 unstructured: list[types.TextContent | types.ImageContent | types.AudioContent | types.ResourceLink | types.EmbeddedResource] = []
680 for content in result.content:
681 if content.type == "text":
682 unstructured.append(
683 types.TextContent(
684 type="text",
685 text=content.text,
686 annotations=_convert_annotations(getattr(content, "annotations", None)),
687 _meta=_convert_meta(getattr(content, "meta", None)),
688 )
689 )
690 elif content.type == "image":
691 unstructured.append(
692 types.ImageContent(
693 type="image",
694 data=content.data,
695 mimeType=content.mime_type,
696 annotations=_convert_annotations(getattr(content, "annotations", None)),
697 _meta=_convert_meta(getattr(content, "meta", None)),
698 )
699 )
700 elif content.type == "audio":
701 unstructured.append(
702 types.AudioContent(
703 type="audio",
704 data=content.data,
705 mimeType=content.mime_type,
706 annotations=_convert_annotations(getattr(content, "annotations", None)),
707 _meta=_convert_meta(getattr(content, "meta", None)),
708 )
709 )
710 elif content.type == "resource_link":
711 unstructured.append(
712 types.ResourceLink(
713 type="resource_link",
714 uri=content.uri,
715 name=content.name,
716 description=getattr(content, "description", None),
717 mimeType=getattr(content, "mime_type", None),
718 size=getattr(content, "size", None),
719 _meta=_convert_meta(getattr(content, "meta", None)),
720 )
721 )
722 elif content.type == "resource":
723 # EmbeddedResource - pass through the model dump as the MCP SDK type requires complex nested structure
724 unstructured.append(types.EmbeddedResource.model_validate(content.model_dump(by_alias=True, mode="json")))
725 else:
726 # Unknown content type - convert to text representation
727 unstructured.append(types.TextContent(type="text", text=str(content.model_dump(by_alias=True, mode="json"))))
729 # If the tool produced structured content (ToolResult.structured_content / structuredContent),
730 # return a combination (unstructured, structured) so the server can validate against outputSchema.
731 # The ToolService may populate structured_content (snake_case) or the model may expose
732 # an alias 'structuredContent' when dumped via model_dump(by_alias=True).
733 structured = None
734 try:
735 # Prefer attribute if present
736 structured = getattr(result, "structured_content", None)
737 except Exception:
738 structured = None
740 # Fallback to by-alias dump (in case the result is a pydantic model with alias fields)
741 if structured is None:
742 try:
743 structured = result.model_dump(by_alias=True).get("structuredContent") if hasattr(result, "model_dump") else None
744 except Exception:
745 structured = None
747 if structured:
748 return (unstructured, structured)
750 return unstructured
751 except Exception as e:
752 logger.exception(f"Error calling tool '{name}': {e}")
753 # Re-raise the exception so the MCP SDK can properly convert it to an error response
754 # This ensures error details are propagated to the client instead of returning empty results
755 raise
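# Illustrative sketch (hypothetical tool name and arguments, not part of the
# module): a client request that ends up in the call_tool handler above has
# the standard MCP tools/call shape, e.g.
#
#     {"jsonrpc": "2.0", "id": 1, "method": "tools/call",
#      "params": {"name": "get_time", "arguments": {"timezone": "UTC"}}}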
758@mcp_app.list_tools()
759async def list_tools() -> List[types.Tool]:
760 """
761 Lists all tools available to the MCP Server.
763 Returns:
764 A list of Tool objects containing metadata such as name, description, and input schema.
765 Logs and returns an empty list on failure.
767 Examples:
768 >>> # Test list_tools function signature
769 >>> import inspect
770 >>> sig = inspect.signature(list_tools)
771 >>> list(sig.parameters.keys())
772 []
773 >>> sig.return_annotation
774 typing.List[mcp.types.Tool]
775 """
776 server_id = server_id_var.get()
777 request_headers = request_headers_var.get()
778 user_context = user_context_var.get()
780 # Extract filtering parameters from user context
781 user_email = user_context.get("email") if user_context else None
782 # Use None as default to distinguish "no teams specified" from "empty teams array"
783 token_teams = user_context.get("teams") if user_context else None
784 is_admin = user_context.get("is_admin", False) if user_context else False
786 # Admin bypass - only when token has NO team restrictions (token_teams is None)
787 # If token has explicit team scope (even empty [] for public-only), respect it
788 if is_admin and token_teams is None:
789 user_email = None
790 # token_teams stays None (unrestricted)
791 elif token_teams is None:
792 token_teams = [] # Non-admin without teams = public-only (secure default)
794 if server_id:
795 try:
796 async with get_db() as db:
797 tools = await tool_service.list_server_tools(db, server_id, user_email=user_email, token_teams=token_teams, _request_headers=request_headers)
798 return [types.Tool(name=tool.name, description=tool.description, inputSchema=tool.input_schema, outputSchema=tool.output_schema, annotations=tool.annotations) for tool in tools]
799 except Exception as e:
800 logger.exception(f"Error listing tools:{e}")
801 return []
802 else:
803 try:
804 async with get_db() as db:
805 tools, _ = await tool_service.list_tools(db, include_inactive=False, limit=0, user_email=user_email, token_teams=token_teams, _request_headers=request_headers)
806 return [types.Tool(name=tool.name, description=tool.description, inputSchema=tool.input_schema, outputSchema=tool.output_schema, annotations=tool.annotations) for tool in tools]
807 except Exception as e:
808 logger.exception(f"Error listing tools:{e}")
809 return []
812@mcp_app.list_prompts()
813async def list_prompts() -> List[types.Prompt]:
814 """
815 Lists all prompts available to the MCP Server.
817 Returns:
818 A list of Prompt objects containing metadata such as name, description, and arguments.
819 Logs and returns an empty list on failure.
821 Examples:
822 >>> import inspect
823 >>> sig = inspect.signature(list_prompts)
824 >>> list(sig.parameters.keys())
825 []
826 >>> sig.return_annotation
827 typing.List[mcp.types.Prompt]
828 """
829 server_id = server_id_var.get()
830 user_context = user_context_var.get()
832 # Extract filtering parameters from user context
833 user_email = user_context.get("email") if user_context else None
834 # Use None as default to distinguish "no teams specified" from "empty teams array"
835 token_teams = user_context.get("teams") if user_context else None
836 is_admin = user_context.get("is_admin", False) if user_context else False
838 # Admin bypass - only when token has NO team restrictions (token_teams is None)
839 # If token has explicit team scope (even empty [] for public-only), respect it
840 if is_admin and token_teams is None:
841 user_email = None
842 # token_teams stays None (unrestricted)
843 elif token_teams is None:
844 token_teams = [] # Non-admin without teams = public-only (secure default)
846 if server_id:
847 try:
848 async with get_db() as db:
849 prompts = await prompt_service.list_server_prompts(db, server_id, user_email=user_email, token_teams=token_teams)
850 return [types.Prompt(name=prompt.name, description=prompt.description, arguments=prompt.arguments) for prompt in prompts]
851 except Exception as e:
852 logger.exception(f"Error listing Prompts:{e}")
853 return []
854 else:
855 try:
856 async with get_db() as db:
857 prompts, _ = await prompt_service.list_prompts(db, include_inactive=False, limit=0, user_email=user_email, token_teams=token_teams)
858 return [types.Prompt(name=prompt.name, description=prompt.description, arguments=prompt.arguments) for prompt in prompts]
859 except Exception as e:
860 logger.exception(f"Error listing prompts:{e}")
861 return []
864@mcp_app.get_prompt()
865async def get_prompt(prompt_id: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
866 """
867 Retrieves a prompt by ID, optionally substituting arguments.
869 Args:
870 prompt_id (str): The ID of the prompt to retrieve.
871 arguments (Optional[dict[str, str]]): Optional dictionary of arguments to substitute into the prompt.
873 Returns:
874 GetPromptResult: Object containing the prompt messages and description.
875 Returns an empty list on failure or if no prompt content is found.
877 Logs exceptions if any errors occur during retrieval.
879 Examples:
880 >>> import inspect
881 >>> sig = inspect.signature(get_prompt)
882 >>> list(sig.parameters.keys())
883 ['prompt_id', 'arguments']
884 >>> sig.return_annotation.__name__
885 'GetPromptResult'
886 """
887 server_id = server_id_var.get()
888 user_context = user_context_var.get()
890 # Extract authorization parameters from user context (same pattern as list_prompts)
891 user_email = user_context.get("email") if user_context else None
892 token_teams = user_context.get("teams") if user_context else None
893 is_admin = user_context.get("is_admin", False) if user_context else False
895 # Admin bypass - only when token has NO team restrictions (token_teams is None)
896 if is_admin and token_teams is None:
897 user_email = None
898 # token_teams stays None (unrestricted)
899 elif token_teams is None:
900 token_teams = [] # Non-admin without teams = public-only (secure default)
902 meta_data = None
903 # Extract _meta from request context if available
904 try:
905 ctx = mcp_app.request_context
906 if ctx and ctx.meta is not None:
907 meta_data = ctx.meta.model_dump()
908 except LookupError:
909 # request_context might not be active in some edge cases (e.g. tests)
910 logger.debug("No active request context found")
912 try:
913 async with get_db() as db:
914 try:
915 result = await prompt_service.get_prompt(
916 db=db,
917 prompt_id=prompt_id,
918 arguments=arguments,
919 user=user_email,
920 server_id=server_id,
921 token_teams=token_teams,
922 _meta_data=meta_data,
923 )
924 except Exception as e:
925 logger.exception(f"Error getting prompt '{prompt_id}': {e}")
926 return []
927 if not result or not result.messages:
928 logger.warning(f"No content returned by prompt: {prompt_id}")
929 return []
930 message_dicts = [message.model_dump() for message in result.messages]
931 return types.GetPromptResult(messages=message_dicts, description=result.description)
932 except Exception as e:
933 logger.exception(f"Error getting prompt '{prompt_id}': {e}")
934 return []
937@mcp_app.list_resources()
938async def list_resources() -> List[types.Resource]:
939 """
940 Lists all resources available to the MCP Server.
942 Returns:
943 A list of Resource objects containing metadata such as uri, name, description, and mimeType.
944 Logs and returns an empty list on failure.
946 Examples:
947 >>> import inspect
948 >>> sig = inspect.signature(list_resources)
949 >>> list(sig.parameters.keys())
950 []
951 >>> sig.return_annotation
952 typing.List[mcp.types.Resource]
953 """
954 server_id = server_id_var.get()
955 user_context = user_context_var.get()
957 # Extract filtering parameters from user context
958 user_email = user_context.get("email") if user_context else None
959 # Use None as default to distinguish "no teams specified" from "empty teams array"
960 token_teams = user_context.get("teams") if user_context else None
961 is_admin = user_context.get("is_admin", False) if user_context else False
963 # Admin bypass - only when token has NO team restrictions (token_teams is None)
964 # If token has explicit team scope (even empty [] for public-only), respect it
965 if is_admin and token_teams is None:
966 user_email = None
967 # token_teams stays None (unrestricted)
968 elif token_teams is None:
969 token_teams = [] # Non-admin without teams = public-only (secure default)
971 if server_id:
972 try:
973 async with get_db() as db:
974 resources = await resource_service.list_server_resources(db, server_id, user_email=user_email, token_teams=token_teams)
975 return [types.Resource(uri=resource.uri, name=resource.name, description=resource.description, mimeType=resource.mime_type) for resource in resources]
976 except Exception as e:
977 logger.exception(f"Error listing Resources:{e}")
978 return []
979 else:
980 try:
981 async with get_db() as db:
982 resources, _ = await resource_service.list_resources(db, include_inactive=False, limit=0, user_email=user_email, token_teams=token_teams)
983 return [types.Resource(uri=resource.uri, name=resource.name, description=resource.description, mimeType=resource.mime_type) for resource in resources]
984 except Exception as e:
985 logger.exception(f"Error listing resources:{e}")
986 return []
989@mcp_app.read_resource()
990async def read_resource(resource_uri: str) -> Union[str, bytes]:
991 """
992 Reads the content of a resource specified by its URI.
994 Args:
995 resource_uri (str): The URI of the resource to read.
997 Returns:
998 Union[str, bytes]: The content of the resource as text or binary data.
999 Returns empty string on failure or if no content is found.
1001 Logs exceptions if any errors occur during reading.
1003 Examples:
1004 >>> import inspect
1005 >>> sig = inspect.signature(read_resource)
1006 >>> list(sig.parameters.keys())
1007 ['resource_uri']
1008 >>> sig.return_annotation
1009 typing.Union[str, bytes]
1010 """
1011 server_id = server_id_var.get()
1012 user_context = user_context_var.get()
1014 # Extract authorization parameters from user context (same pattern as list_resources)
1015 user_email = user_context.get("email") if user_context else None
1016 token_teams = user_context.get("teams") if user_context else None
1017 is_admin = user_context.get("is_admin", False) if user_context else False
1019 # Admin bypass - only when token has NO team restrictions (token_teams is None)
1020 if is_admin and token_teams is None:
1021 user_email = None
1022 # token_teams stays None (unrestricted)
1023 elif token_teams is None:
1024 token_teams = [] # Non-admin without teams = public-only (secure default)
1026 meta_data = None
1027 # Extract _meta from request context if available
1028 try:
1029 ctx = mcp_app.request_context
1030 if ctx and ctx.meta is not None:
1031 meta_data = ctx.meta.model_dump()
1032 except LookupError:
1033 # request_context might not be active in some edge cases (e.g. tests)
1034 logger.debug("No active request context found")
1036 try:
1037 async with get_db() as db:
1038 try:
1039 result = await resource_service.read_resource(
1040 db=db,
1041 resource_uri=str(resource_uri),
1042 user=user_email,
1043 server_id=server_id,
1044 token_teams=token_teams,
1045 meta_data=meta_data,
1046 )
1047 except Exception as e:
1048 logger.exception(f"Error reading resource '{resource_uri}': {e}")
1049 return ""
1051 # Return blob content if available (binary resources)
1052 if result and result.blob:
1053 return result.blob
1055 # Return text content if available (text resources)
1056 if result and result.text:
1057 return result.text
1059 # No content found
1060 logger.warning(f"No content returned by resource: {resource_uri}")
1061 return ""
1062 except Exception as e:
1063 logger.exception(f"Error reading resource '{resource_uri}': {e}")
1064 return ""
1067@mcp_app.list_resource_templates()
1068async def list_resource_templates() -> List[Dict[str, Any]]:
1069 """
1070 Lists all resource templates available to the MCP Server.
1072 Returns:
1073 List[types.ResourceTemplate]: A list of resource templates with their URIs and metadata.
1075 Examples:
1076 >>> import inspect
1077 >>> sig = inspect.signature(list_resource_templates)
1078 >>> list(sig.parameters.keys())
1079 []
1080 >>> sig.return_annotation.__origin__.__name__
1081 'list'
1082 """
1083 # Extract filtering parameters from user context (same pattern as list_resources)
1084 user_context = user_context_var.get()
1085 user_email = user_context.get("email") if user_context else None
1086 token_teams = user_context.get("teams") if user_context else None
1087 is_admin = user_context.get("is_admin", False) if user_context else False
1089 # Admin bypass - only when token has NO team restrictions (token_teams is None)
1090 # If token has explicit team scope (even empty [] for public-only), respect it
1091 if is_admin and token_teams is None:
1092 user_email = None
1093 # token_teams stays None (unrestricted)
1094 elif token_teams is None:
1095 token_teams = [] # Non-admin without teams = public-only (secure default)
1097 try:
1098 async with get_db() as db:
1099 try:
1100 resource_templates = await resource_service.list_resource_templates(
1101 db,
1102 user_email=user_email,
1103 token_teams=token_teams,
1104 )
1105 return [template.model_dump(by_alias=True) for template in resource_templates]
1106 except Exception as e:
1107 logger.exception(f"Error listing resource templates: {e}")
1108 return []
1109 except Exception as e:
1110 logger.exception(f"Error listing resource templates: {e}")
1111 return []
1114@mcp_app.set_logging_level()
1115async def set_logging_level(level: types.LoggingLevel) -> types.EmptyResult:
1116 """
1117 Sets the logging level for the MCP Server.
1119 Args:
1120 level (types.LoggingLevel): The desired logging level (debug, info, notice, warning, error, critical, alert, emergency).
1122 Returns:
1123 types.EmptyResult: An empty result indicating success.
1125 Examples:
1126 >>> import inspect
1127 >>> sig = inspect.signature(set_logging_level)
1128 >>> list(sig.parameters.keys())
1129 ['level']
1130 """
1131 try:
1132 # Convert MCP logging level to our LogLevel enum
1133 level_map = {
1134 "debug": LogLevel.DEBUG,
1135 "info": LogLevel.INFO,
1136 "notice": LogLevel.INFO,
1137 "warning": LogLevel.WARNING,
1138 "error": LogLevel.ERROR,
1139 "critical": LogLevel.CRITICAL,
1140 "alert": LogLevel.CRITICAL,
1141 "emergency": LogLevel.CRITICAL,
1142 }
1143 log_level = level_map.get(level.lower(), LogLevel.INFO)
1144 await logging_service.set_level(log_level)
1145 return types.EmptyResult()
1146 except Exception as e:
1147 logger.exception(f"Error setting logging level: {e}")
1148 return types.EmptyResult()
1151@mcp_app.completion()
1152async def complete(
1153 ref: Union[types.PromptReference, types.ResourceTemplateReference],
1154 argument: types.CompleteRequest,
1155 context: Optional[types.CompletionContext] = None,
1156) -> types.CompleteResult:
1157 """
1158 Provides argument completion suggestions for prompts or resources.
1160 Args:
1161 ref: A reference to a prompt or a resource template. Can be either
1162 `types.PromptReference` or `types.ResourceTemplateReference`.
1163 argument: The completion request specifying the input text and
1164 position for which completion suggestions should be generated.
1165 context: Optional contextual information for the completion request,
1166 such as user, environment, or invocation metadata.
1168 Returns:
1169 types.CompleteResult: A normalized completion result containing
1170 completion values, metadata (total, hasMore), and any additional
1171 MCP-compliant completion fields.
1173 Note:
1174 If completion handling fails internally, the exception is logged and an
1175 empty completion structure is returned instead of being re-raised.
1176 """
1177 try:
1178 async with get_db() as db:
1179 params = {
1180 "ref": ref.model_dump() if hasattr(ref, "model_dump") else ref,
1181 "argument": argument.model_dump() if hasattr(argument, "model_dump") else argument,
1182 "context": context.model_dump() if hasattr(context, "model_dump") else context,
1183 }
1185 result = await completion_service.handle_completion(db, params)
1187 # ✅ Normalize the result for MCP
1188 if isinstance(result, dict):
1189 completion_data = result.get("completion", result)
1190 return types.Completion(**completion_data)
1192 if hasattr(result, "completion"):
1193 completion_obj = result.completion
1195 # If completion itself is a dict
1196 if isinstance(completion_obj, dict):
1197 return types.Completion(**completion_obj)
1199 # If completion is another CompleteResult (nested)
1200 if hasattr(completion_obj, "completion"):
1201 inner_completion = completion_obj.completion.model_dump() if hasattr(completion_obj.completion, "model_dump") else completion_obj.completion
1202 return types.Completion(**inner_completion)
1204 # If completion is already a Completion model
1205 if isinstance(completion_obj, types.Completion):
1206 return completion_obj
1208 # If it's another Pydantic model (e.g., mcpgateway.models.Completion)
1209 if hasattr(completion_obj, "model_dump"):
1210 return types.Completion(**completion_obj.model_dump())
1212 # If result itself is already a types.Completion
1213 if isinstance(result, types.Completion):
1214 return result
1216 # Fallback: return empty completion
1217 return types.Completion(values=[], total=0, hasMore=False)
1219 except Exception as e:
1220 logger.exception(f"Error handling completion: {e}")
1221 return types.Completion(values=[], total=0, hasMore=False)
1224class SessionManagerWrapper:
1225 """
1226 Wrapper class for managing the lifecycle of a StreamableHTTPSessionManager instance.
1227 Provides start, stop, and request handling methods.
1229 Examples:
1230 >>> # Test SessionManagerWrapper initialization
1231 >>> wrapper = SessionManagerWrapper()
1232 >>> wrapper
1233 <mcpgateway.transports.streamablehttp_transport.SessionManagerWrapper object at ...>
1234 >>> hasattr(wrapper, 'session_manager')
1235 True
1236 >>> hasattr(wrapper, 'stack')
1237 True
1238 >>> isinstance(wrapper.stack, AsyncExitStack)
1239 True
1240 """
1242 def __init__(self) -> None:
1243 """
1244 Initializes the session manager and the exit stack used for managing its lifecycle.
1246 Examples:
1247 >>> # Test initialization
1248 >>> wrapper = SessionManagerWrapper()
1249 >>> wrapper.session_manager is not None
1250 True
1251 >>> wrapper.stack is not None
1252 True
1253 """
1255 if settings.use_stateful_sessions:
1256 # Use Redis event store for single-worker stateful deployments
1257 if settings.cache_type == "redis" and settings.redis_url:
1258 event_store = RedisEventStore(max_events_per_stream=settings.streamable_http_max_events_per_stream, ttl=settings.streamable_http_event_ttl)
1259 logger.debug("Using RedisEventStore for stateful sessions (single-worker)")
1260 else:
1261 # Fall back to in-memory for single-worker or when Redis not available
1262 event_store = InMemoryEventStore()
1263 logger.warning("Using InMemoryEventStore - only works with single worker!")
1264 stateless = False
1265 else:
1266 event_store = None
1267 stateless = True
1269 self.session_manager = StreamableHTTPSessionManager(
1270 app=mcp_app,
1271 event_store=event_store,
1272 json_response=settings.json_response_enabled,
1273 stateless=stateless,
1274 )
1275 self.stack = AsyncExitStack()
1277 async def initialize(self) -> None:
1278 """
1279 Starts the Streamable HTTP session manager context.
1281 Examples:
1282 >>> # Test initialize method exists
1283 >>> wrapper = SessionManagerWrapper()
1284 >>> hasattr(wrapper, 'initialize')
1285 True
1286 >>> callable(wrapper.initialize)
1287 True
1288 """
1289 logger.debug("Initializing Streamable HTTP service")
1290 await self.stack.enter_async_context(self.session_manager.run())
1292 async def shutdown(self) -> None:
1293 """
1294 Gracefully shuts down the Streamable HTTP session manager.
1296 Examples:
1297 >>> # Test shutdown method exists
1298 >>> wrapper = SessionManagerWrapper()
1299 >>> hasattr(wrapper, 'shutdown')
1300 True
1301 >>> callable(wrapper.shutdown)
1302 True
1303 """
1304 logger.debug("Stopping Streamable HTTP Session Manager...")
1305 await self.stack.aclose()
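# Illustrative lifecycle sketch (assumption about the calling application, not
# part of the module): the gateway is expected to drive the wrapper roughly as
#
#     wrapper = SessionManagerWrapper()
#     await wrapper.initialize()                                   # app startup
#     await wrapper.handle_streamable_http(scope, receive, send)   # per request
#     await wrapper.shutdown()                                     # app shutdown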
1307 async def handle_streamable_http(self, scope: Scope, receive: Receive, send: Send) -> None:
1308 """
1309 Forwards an incoming ASGI request to the streamable HTTP session manager.
1311 Args:
1312 scope (Scope): ASGI scope object containing connection information.
1313 receive (Receive): ASGI receive callable.
1314 send (Send): ASGI send callable.
1316 Raises:
1317 Exception: Any exception raised during request handling is logged.
1321 Examples:
1322 >>> # Test handle_streamable_http method exists
1323 >>> wrapper = SessionManagerWrapper()
1324 >>> hasattr(wrapper, 'handle_streamable_http')
1325 True
1326 >>> callable(wrapper.handle_streamable_http)
1327 True
1329 >>> # Test method signature
1330 >>> import inspect
1331 >>> sig = inspect.signature(wrapper.handle_streamable_http)
1332 >>> list(sig.parameters.keys())
1333 ['scope', 'receive', 'send']
1334 """
1336 path = scope["modified_path"]
1337 # Uses precompiled regex for server ID extraction
1338 match = _SERVER_ID_RE.search(path)
1340 # Extract request headers from scope (ASGI provides bytes; normalize to lowercase for lookup).
1341 raw_headers = scope.get("headers") or []
1342 headers: dict[str, str] = {}
1343 for item in raw_headers:
1344 if not isinstance(item, (tuple, list)) or len(item) != 2:
1345 continue
1346 k, v = item
1347 if not isinstance(k, (bytes, bytearray)) or not isinstance(v, (bytes, bytearray)):
1348 continue
1349 # latin-1 is a byte-preserving decode; safe for arbitrary header bytes.
1350 headers[k.decode("latin-1").lower()] = v.decode("latin-1")
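# Illustrative sketch, not part of the module: ASGI delivers headers as byte
# pairs, so [(b"MCP-Session-Id", b"abc123")] normalizes to
# {"mcp-session-id": "abc123"} for the case-insensitive lookups below.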
1352 # Log session info for debugging stateful sessions
1353 mcp_session_id = headers.get("mcp-session-id", "not-provided")
1354 method = scope.get("method", "UNKNOWN")
1355 query_string = scope.get("query_string", b"").decode("utf-8")
1356 logger.debug(f"[STATEFUL] Streamable HTTP {method} {path} | MCP-Session-Id: {mcp_session_id} | Query: {query_string} | Stateful: {settings.use_stateful_sessions}")
1358 # Note: mcp-session-id from client is used for gateway-internal session affinity
1359 # routing (stored in request_headers_var), but is NOT renamed or forwarded to
1360 # upstream servers - it's a gateway-side concept, not an end-to-end semantic header
1362 # Multi-worker session affinity: check if we should forward to another worker
1363 # This must happen BEFORE the SDK's session manager handles the request
1364 is_internally_forwarded = headers.get("x-forwarded-internally") == "true"
1366 if settings.mcpgateway_session_affinity_enabled and mcp_session_id != "not-provided":
1367 try:
1368 # First-Party
1369 from mcpgateway.services.mcp_session_pool import MCPSessionPool # pylint: disable=import-outside-toplevel
1371 if not MCPSessionPool.is_valid_mcp_session_id(mcp_session_id):
1372 logger.debug("Invalid MCP session id on Streamable HTTP request, skipping affinity")
1373 mcp_session_id = "not-provided"
1374 except Exception:
1375 mcp_session_id = "not-provided"
1377 # Log session manager ID for debugging
1378 logger.debug(f"[SESSION_MGR_DEBUG] Manager ID: {id(self.session_manager)}")
1380 if is_internally_forwarded:
1381 logger.debug(f"[HTTP_AFFINITY_FORWARDED] Received forwarded request | Method: {method} | Session: {mcp_session_id}")
1383 # Only route POST requests with JSON-RPC body to /rpc
1384 # DELETE and other methods should return success (session cleanup is local)
1385 if method != "POST":
1386 logger.debug("[HTTP_AFFINITY_FORWARDED] Non-POST method, returning 200 OK")
1387 await send({"type": "http.response.start", "status": 200, "headers": [(b"content-type", b"application/json")]})
1388 await send({"type": "http.response.body", "body": b'{"jsonrpc":"2.0","result":{}}'})
1389 return
1391 # For POST requests, bypass SDK session manager and use /rpc directly
1392 # This avoids SDK's session cleanup issues while maintaining stateful upstream connections
1393 try:
1394 # Read request body
1395 body_parts = []
1396 while True:
1397 message = await receive()
1398 if message["type"] == "http.request":
1399 body_parts.append(message.get("body", b""))
1400 if not message.get("more_body", False):
1401 break
1402 elif message["type"] == "http.disconnect":
1403 return
1404 body = b"".join(body_parts)
1406 if not body:
1407 logger.debug("[HTTP_AFFINITY_FORWARDED] Empty body, returning 202 Accepted")
1408 await send({"type": "http.response.start", "status": 202, "headers": []})
1409 await send({"type": "http.response.body", "body": b""})
1410 return
1412 json_body = orjson.loads(body)
1413 rpc_method = json_body.get("method", "")
1414 logger.debug(f"[HTTP_AFFINITY_FORWARDED] Routing to /rpc | Method: {rpc_method}")
1416 # Notifications don't need /rpc routing - just acknowledge
1417 if rpc_method.startswith("notifications/"):
1418 logger.debug("[HTTP_AFFINITY_FORWARDED] Notification, returning 202 Accepted")
1419 await send({"type": "http.response.start", "status": 202, "headers": []})
1420 await send({"type": "http.response.body", "body": b""})
1421 return
1423 async with httpx.AsyncClient() as client:
1424 rpc_headers = {
1425 "content-type": "application/json",
1426 "x-mcp-session-id": mcp_session_id, # Pass session for upstream affinity
1427 "x-forwarded-internally": "true", # Prevent infinite forwarding loops
1428 }
1429 # Copy auth header if present
1430 if "authorization" in headers:
1431 rpc_headers["authorization"] = headers["authorization"]
1433 response = await client.post(
1434 f"http://127.0.0.1:{settings.port}/rpc",
1435 content=body,
1436 headers=rpc_headers,
1437 timeout=30.0,
1438 )
1440 # Return response to client
1441 response_headers = [
1442 (b"content-type", b"application/json"),
1443 (b"content-length", str(len(response.content)).encode()),
1444 ]
1445 if mcp_session_id != "not-provided":
1446 response_headers.append((b"mcp-session-id", mcp_session_id.encode()))
1448 await send(
1449 {
1450 "type": "http.response.start",
1451 "status": response.status_code,
1452 "headers": response_headers,
1453 }
1454 )
1455 await send(
1456 {
1457 "type": "http.response.body",
1458 "body": response.content,
1459 }
1460 )
1461 logger.debug(f"[HTTP_AFFINITY_FORWARDED] Response sent | Status: {response.status_code}")
1462 return
1463 except Exception as e:
1464 logger.error(f"[HTTP_AFFINITY_FORWARDED] Error routing to /rpc: {e}")
1465 # Fall through to SDK handling as fallback
1467 if settings.mcpgateway_session_affinity_enabled and settings.use_stateful_sessions and mcp_session_id != "not-provided" and not is_internally_forwarded:
1468 try:
1469 # First-Party - lazy import to avoid circular dependencies
1470 # First-Party
1471 from mcpgateway.services.mcp_session_pool import get_mcp_session_pool, WORKER_ID # pylint: disable=import-outside-toplevel
1473 pool = get_mcp_session_pool()
1474 owner = await pool.get_streamable_http_session_owner(mcp_session_id)
1475 logger.debug(f"[HTTP_AFFINITY_CHECK] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Owner from Redis: {owner}")
1477 if owner and owner != WORKER_ID:
1478 # Session owned by another worker - forward the entire HTTP request
1479 logger.info(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Owner: {owner} | Forwarding HTTP request")
1481 # Read request body
1482 body_parts = []
1483 while True:
1484 message = await receive()
1485 if message["type"] == "http.request":
1486 body_parts.append(message.get("body", b""))
1487 if not message.get("more_body", False):
1488 break
1489 elif message["type"] == "http.disconnect":
1490 return
1491 body = b"".join(body_parts)
1493 # Forward to owner worker
1494 response = await pool.forward_streamable_http_to_owner(
1495 owner_worker_id=owner,
1496 mcp_session_id=mcp_session_id,
1497 method=method,
1498 path=path,
1499 headers=headers,
1500 body=body,
1501 query_string=query_string,
1502 )
1504 if response:
1505 # Send forwarded response back to client
1506 response_headers = [(k.encode(), v.encode()) for k, v in response["headers"].items() if k.lower() not in ("transfer-encoding", "content-encoding", "content-length")]
1507 response_headers.append((b"content-length", str(len(response["body"])).encode()))
1509 await send(
1510 {
1511 "type": "http.response.start",
1512 "status": response["status"],
1513 "headers": response_headers,
1514 }
1515 )
1516 await send(
1517 {
1518 "type": "http.response.body",
1519 "body": response["body"],
1520 }
1521 )
1522 logger.debug(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Forwarded response sent to client")
1523 return
1525 # Forwarding failed - fall through to local handling
1526 # This may result in "session not found" but it's better than no response
1527 logger.debug(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Forwarding failed, falling back to local")
1529 elif owner == WORKER_ID and method == "POST":
1530 # We own this session - route POST requests to /rpc to avoid SDK session issues
1531 # The SDK's _server_instances gets cleared between requests, so we can't rely on it
1532 logger.debug(f"[HTTP_AFFINITY_LOCAL] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Owner is us, routing to /rpc")
1534 # Read request body
1535 body_parts = []
1536 while True:
1537 message = await receive()
1538 if message["type"] == "http.request":
1539 body_parts.append(message.get("body", b""))
1540 if not message.get("more_body", False):
1541 break
1542 elif message["type"] == "http.disconnect":
1543 return
1544 body = b"".join(body_parts)
1546 if not body:
1547 logger.debug("[HTTP_AFFINITY_LOCAL] Empty body, returning 202 Accepted")
1548 await send({"type": "http.response.start", "status": 202, "headers": []})
1549 await send({"type": "http.response.body", "body": b""})
1550 return
1552 # Parse JSON-RPC and route to /rpc
1553 try:
1554 json_body = orjson.loads(body)
1555 rpc_method = json_body.get("method", "")
1556 logger.debug(f"[HTTP_AFFINITY_LOCAL] Routing to /rpc | Method: {rpc_method}")
1558 # Notifications don't need /rpc routing
1559 if rpc_method.startswith("notifications/"):
1560 logger.debug("[HTTP_AFFINITY_LOCAL] Notification, returning 202 Accepted")
1561 await send({"type": "http.response.start", "status": 202, "headers": []})
1562 await send({"type": "http.response.body", "body": b""})
1563 return
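# Illustrative notification body (an assumed example): JSON-RPC notifications carry no
# "id" and expect no response, so the prefix check above short-circuits with a 202.
# orjson.loads(b'{"jsonrpc": "2.0", "method": "notifications/initialized"}')["method"]
#     -> 'notifications/initialized'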
1565 async with httpx.AsyncClient() as client:
1566 rpc_headers = {
1567 "content-type": "application/json",
1568 "x-mcp-session-id": mcp_session_id,
1569 "x-forwarded-internally": "true",
1570 }
1571 if "authorization" in headers:
1572 rpc_headers["authorization"] = headers["authorization"]
1574 response = await client.post(
1575 f"http://127.0.0.1:{settings.port}/rpc",
1576 content=body,
1577 headers=rpc_headers,
1578 timeout=30.0,
1579 )
1581 response_headers = [
1582 (b"content-type", b"application/json"),
1583 (b"content-length", str(len(response.content)).encode()),
1584 (b"mcp-session-id", mcp_session_id.encode()),
1585 ]
1587 await send(
1588 {
1589 "type": "http.response.start",
1590 "status": response.status_code,
1591 "headers": response_headers,
1592 }
1593 )
1594 await send(
1595 {
1596 "type": "http.response.body",
1597 "body": response.content,
1598 }
1599 )
1600 logger.debug(f"[HTTP_AFFINITY_LOCAL] Response sent | Status: {response.status_code}")
1601 return
1602 except Exception as e:
1603 logger.error(f"[HTTP_AFFINITY_LOCAL] Error routing to /rpc: {e}")
1604 # Fall through to SDK handling as fallback
1606 except RuntimeError:
1607 # Pool not initialized - proceed with local handling
1608 pass
1609 except Exception as e:
1610 logger.debug(f"Session affinity check failed, proceeding locally: {e}")
1612 # Store headers in context for tool invocations
1613 request_headers_var.set(headers)
1615 if match:
1616 server_id = match.group("server_id")
1617 server_id_var.set(server_id)
1618 else:
1619 server_id_var.set(None)
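# Illustrative extraction (assumes `match` above is produced by applying _SERVER_ID_RE
# to the request path earlier in this handler):
# _SERVER_ID_RE.search("/servers/abc123/mcp").group("server_id")
#     -> 'abc123'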
1621 # For session affinity: wrap send to capture session ID from response headers
1622 # This allows us to register ownership for new sessions created by the SDK
1623 captured_session_id: Optional[str] = None
1625 async def send_with_capture(message: Dict[str, Any]) -> None:
1626 """Wrap ASGI send to capture session ID from response headers.
1628 Args:
1629 message: ASGI message dict.
1630 """
1631 nonlocal captured_session_id
1632 if message["type"] == "http.response.start" and settings.mcpgateway_session_affinity_enabled:
1633 # Look for mcp-session-id in response headers
1634 response_headers = message.get("headers", [])
1635 for header_name, header_value in response_headers:
1636 if isinstance(header_name, bytes):
1637 header_name = header_name.decode("latin-1")
1638 if isinstance(header_value, bytes):
1639 header_value = header_value.decode("latin-1")
1640 if header_name.lower() == "mcp-session-id":
1641 captured_session_id = header_value
1642 break
1643 await send(message)
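# Minimal illustration of the capture above (hypothetical in-memory message, no network):
# message = {
#     "type": "http.response.start",
#     "status": 200,
#     "headers": [(b"mcp-session-id", b"1234abcd")],
# }
# After `await send_with_capture(message)`, captured_session_id == "1234abcd"
# (assuming session affinity is enabled in settings).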
1645 try:
1646 await self.session_manager.handle_request(scope, receive, send_with_capture)
1647 logger.debug(f"[STATEFUL] Streamable HTTP request completed successfully | Session: {mcp_session_id}")
1649 # Register ownership for the session we just handled
1650 # This captures both existing sessions (mcp_session_id from request)
1651 # and new sessions (captured_session_id from response)
1652 logger.debug(
1653 f"[HTTP_AFFINITY_DEBUG] affinity_enabled={settings.mcpgateway_session_affinity_enabled} stateful={settings.use_stateful_sessions} captured={captured_session_id} mcp_session_id={mcp_session_id}"
1654 )
1655 if settings.mcpgateway_session_affinity_enabled and settings.use_stateful_sessions:
1656 session_to_register = captured_session_id or (mcp_session_id if mcp_session_id != "not-provided" else None)
1657 logger.debug(f"[HTTP_AFFINITY_DEBUG] session_to_register={session_to_register}")
1658 if session_to_register:
1659 try:
1660 # Lazy import to avoid circular dependencies
1661 # First-Party
1662 from mcpgateway.services.mcp_session_pool import get_mcp_session_pool, WORKER_ID # pylint: disable=import-outside-toplevel
1664 pool = get_mcp_session_pool()
1665 await pool.register_pool_session_owner(session_to_register)
1666 logger.debug(f"[HTTP_AFFINITY_SDK] Worker {WORKER_ID} | Session {session_to_register[:8]}... | Registered ownership after SDK handling")
1667 except Exception as e:
1668 logger.debug(f"[HTTP_AFFINITY_DEBUG] Exception during registration: {e}")
1669 logger.warning(f"Failed to register session ownership: {e}")
1671 except anyio.ClosedResourceError:
1672 # Expected when client closes one side of the stream (normal lifecycle)
1673 logger.debug("Streamable HTTP connection closed by client (ClosedResourceError)")
1674 except Exception as e:
1675 logger.error(f"[STATEFUL] Streamable HTTP request failed | Session: {mcp_session_id} | Error: {e}")
1676 logger.exception(f"Error handling streamable HTTP request: {e}")
1677 raise
1680# ------------------------- Authentication for /mcp routes ------------------------------
1683async def streamable_http_auth(scope: Any, receive: Any, send: Any) -> bool:
1684 """
1685 Perform authentication check in middleware context (ASGI scope).
1687 This function is intended to be used in middleware wrapping ASGI apps.
1688 It authenticates only requests targeting paths ending in "/mcp" or "/mcp/".
1690 Behavior:
1691 - If the path does not end with "/mcp" or "/mcp/", authentication is skipped.
1692 - If mcp_require_auth=True (strict mode):
1693 - Requests without valid auth are rejected with 401.
1694 - If mcp_require_auth=False (default, permissive mode):
1695 - Requests without auth are allowed but get public-only access (token_teams=[]).
1696 - Valid tokens get full scoped access based on their teams.
1697 - If a Bearer token is present, it is verified using `verify_credentials`.
1698 - If verification fails and mcp_require_auth=True, a 401 Unauthorized JSON response is sent.
1700 Args:
1701 scope: The ASGI scope dictionary, which includes request metadata.
1702 receive: ASGI receive callable used to receive events.
1703 send: ASGI send callable used to send events (e.g. a 401 response).
1705 Returns:
1706 bool: True if authentication passes or is skipped.
1707 False if authentication fails and a 401 response is sent.
1709 Examples:
1710 >>> # Test streamable_http_auth function exists
1711 >>> callable(streamable_http_auth)
1712 True
1714 >>> # Test function signature
1715 >>> import inspect
1716 >>> sig = inspect.signature(streamable_http_auth)
1717 >>> list(sig.parameters.keys())
1718 ['scope', 'receive', 'send']
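>>> # Paths that do not end in "/mcp" or "/mcp/" skip authentication entirely
>>> # (illustrative; assumes no event loop is already running here)
>>> import asyncio
>>> asyncio.run(streamable_http_auth({"path": "/health"}, None, None))
True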
1719 """
1720 path = scope.get("path", "")
1721 if not path.endswith("/mcp") and not path.endswith("/mcp/"):
1722 # No auth needed for other paths in this middleware usage
1723 return True
1725 headers = Headers(scope=scope)
1727 # CORS preflight (OPTIONS + Origin + Access-Control-Request-Method) cannot carry auth headers
1728 method = scope.get("method", "")
1729 if method == "OPTIONS":
1730 origin = headers.get("origin")
1731 if origin and headers.get("access-control-request-method"):
1732 return True
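# Illustrative preflight scope (hypothetical origin) that is let through unauthenticated,
# since the browser has not attached credentials at this stage:
# Headers(scope={"type": "http", "headers": [
#     (b"origin", b"https://app.example.com"),
#     (b"access-control-request-method", b"POST"),
# ]}).get("origin")
#     -> 'https://app.example.com'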
1734 authorization = headers.get("authorization")
1735 proxy_user = headers.get(settings.proxy_user_header) if settings.trust_proxy_auth else None
1737 # Determine authentication strategy based on settings
1738 if not settings.mcp_client_auth_enabled and settings.trust_proxy_auth:
1739 # Client auth disabled → allow proxy header
1740 if proxy_user:
1741 # Set enriched user context for proxy-authenticated sessions
1742 user_context_var.set(
1743 {
1744 "email": proxy_user,
1745 "teams": [], # Proxy auth has no team context
1746 "is_authenticated": True,
1747 "is_admin": False,
1748 }
1749 )
1750 return True # Trusted proxy supplied user
1752 # --- Standard JWT authentication flow (client auth enabled) ---
1753 token: str | None = None
1754 if authorization:
1755 scheme, credentials = get_authorization_scheme_param(authorization)
1756 if scheme.lower() == "bearer" and credentials:
1757 token = credentials
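# The FastAPI helper above splits on the first space (token value is illustrative):
# get_authorization_scheme_param("Bearer eyJhbGciOiJIUzI1NiJ9.payload.sig")
#     -> ('Bearer', 'eyJhbGciOiJIUzI1NiJ9.payload.sig')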
1759 try:
1760 if token is None:
1761 raise Exception("No token provided")
1762 user_payload = await verify_credentials(token)
1763 # Store enriched user context with normalized teams
1764 if isinstance(user_payload, dict):
1765 # Resolve teams based on token_use claim
1766 token_use = user_payload.get("token_use")
1767 if token_use == "session": # nosec B105 - Not a password; token_use is a JWT claim type
1768 # Session token: resolve teams from DB/cache
1769 user_email_for_teams = user_payload.get("sub") or user_payload.get("email")
1770 is_admin_flag = user_payload.get("is_admin", False) or user_payload.get("user", {}).get("is_admin", False)
1771 if is_admin_flag:
1772 final_teams = None # Admin bypass
1773 elif user_email_for_teams:
1774 # Resolve teams synchronously with L1 cache (StreamableHTTP uses sync context)
1775 # First-Party
1776 from mcpgateway.auth import _resolve_teams_from_db_sync # pylint: disable=import-outside-toplevel
1778 final_teams = _resolve_teams_from_db_sync(user_email_for_teams, is_admin=False)
1779 else:
1780 final_teams = [] # No email — public-only
1781 else:
1782 # API token or legacy: use embedded teams from JWT
1783 # First-Party
1784 from mcpgateway.auth import normalize_token_teams # pylint: disable=import-outside-toplevel
1786 final_teams = normalize_token_teams(user_payload)
1788 # ═══════════════════════════════════════════════════════════════════════════
1789 # SECURITY: Validate team membership for team-scoped tokens
1790 # Users removed from a team should lose MCP access immediately, not at token expiry
1791 # ═══════════════════════════════════════════════════════════════════════════
1792 user_email = user_payload.get("sub") or user_payload.get("email")
1793 is_admin = user_payload.get("is_admin", False) or user_payload.get("user", {}).get("is_admin", False)
1795 # Only validate membership for team-scoped tokens (non-empty teams list)
1796 # Skip for: public-only tokens ([]), admin unrestricted tokens (None)
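# Membership re-validation matrix implied above (final_teams values are illustrative):
#     None         -> admin / unrestricted token, check skipped
#     []           -> public-only token, check skipped
#     ["team-a"]   -> team-scoped token, active membership verified via cache, then DB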
1797 if final_teams and len(final_teams) > 0 and user_email:
1798 # Import lazily to avoid circular imports
1799 # First-Party
1800 from mcpgateway.cache.auth_cache import get_auth_cache # pylint: disable=import-outside-toplevel
1801 from mcpgateway.db import EmailTeamMember # pylint: disable=import-outside-toplevel
1803 auth_cache = get_auth_cache()
1805 # Check cache first (60s TTL)
1806 cached_result = auth_cache.get_team_membership_valid_sync(user_email, final_teams)
1807 if cached_result is False:
1808 logger.warning(f"MCP auth rejected: User {user_email} no longer member of teams (cached)")
1809 response = ORJSONResponse(
1810 {"detail": "Token invalid: User is no longer a member of the associated team"},
1811 status_code=HTTP_403_FORBIDDEN,
1812 )
1813 await response(scope, receive, send)
1814 return False
1816 if cached_result is None:
1817 # Cache miss - query database
1818 # Third-Party
1819 from sqlalchemy import select # pylint: disable=import-outside-toplevel
1821 db = SessionLocal()
1822 try:
1823 memberships = (
1824 db.execute(
1825 select(EmailTeamMember.team_id).where(
1826 EmailTeamMember.team_id.in_(final_teams),
1827 EmailTeamMember.user_email == user_email,
1828 EmailTeamMember.is_active.is_(True),
1829 )
1830 )
1831 .scalars()
1832 .all()
1833 )
1835 valid_team_ids = set(memberships)
1836 missing_teams = set(final_teams) - valid_team_ids
1838 if missing_teams:
1839 logger.warning(f"MCP auth rejected: User {user_email} no longer member of teams: {missing_teams}")
1840 auth_cache.set_team_membership_valid_sync(user_email, final_teams, False)
1841 response = ORJSONResponse(
1842 {"detail": "Token invalid: User is no longer a member of the associated team"},
1843 status_code=HTTP_403_FORBIDDEN,
1844 )
1845 await response(scope, receive, send)
1846 return False
1848 # Cache positive result
1849 auth_cache.set_team_membership_valid_sync(user_email, final_teams, True)
1850 finally:
1851 # Roll back any implicit transaction before closing to prevent an
1852 # idle-in-transaction state when the connection is returned to the pool
1853 try:
1854 db.rollback()
1855 except Exception:
1856 pass # nosec B110 - Best effort rollback
1857 db.close()
1859 user_context_var.set(
1860 {
1861 "email": user_email,
1862 "teams": final_teams,
1863 "is_authenticated": True,
1864 "is_admin": is_admin,
1865 }
1866 )
1867 elif proxy_user:
1868 # If using proxy auth, store the proxy user
1869 user_context_var.set(
1870 {
1871 "email": proxy_user,
1872 "teams": [],
1873 "is_authenticated": True,
1874 "is_admin": False,
1875 }
1876 )
1877 except Exception:
1878 # If JWT auth fails but we have a trusted proxy user, use that
1879 if settings.trust_proxy_auth and proxy_user:
1880 user_context_var.set(
1881 {
1882 "email": proxy_user,
1883 "teams": [],
1884 "is_authenticated": True,
1885 "is_admin": False,
1886 }
1887 )
1888 return True # Fall back to proxy authentication
1890 # Check mcp_require_auth setting to determine behavior
1891 if settings.mcp_require_auth:
1892 # Strict mode: require authentication, return 401 for unauthenticated requests
1893 response = ORJSONResponse(
1894 {"detail": "Authentication required for MCP endpoints"},
1895 status_code=HTTP_401_UNAUTHORIZED,
1896 headers={"WWW-Authenticate": "Bearer"},
1897 )
1898 await response(scope, receive, send)
1899 return False
1901 # Permissive mode (default): allow unauthenticated access with public-only scope
1902 # Set context indicating unauthenticated user with public-only access (teams=[])
1903 user_context_var.set(
1904 {
1905 "email": None,
1906 "teams": [], # Empty list = public-only access
1907 "is_authenticated": False,
1908 "is_admin": False,
1909 }
1910 )
1911 return True # Allow request to proceed with public-only access
1913 return True
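# Hedged usage sketch (hypothetical wrapper; the actual wiring lives in the app factory):
#
# async def mcp_asgi_app(scope: Scope, receive: Receive, send: Send) -> None:
#     # streamable_http_auth sends the 401/403 response itself when it returns False.
#     if not await streamable_http_auth(scope, receive, send):
#         return
#     await session_manager.handle_streamable_http(scope, receive, send)  # hypothetical handler name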