Coverage for mcpgateway / transports / streamablehttp_transport.py: 99%
1260 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/transports/streamablehttp_transport.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Keval Mahajan
7Streamable HTTP Transport Implementation.
This module implements the Streamable HTTP transport for MCP.
10Key components include:
11- SessionManagerWrapper: Manages the lifecycle of streamable HTTP sessions
12- Configuration options for:
13 1. stateful/stateless operation
14 2. JSON response mode or SSE streams
15- InMemoryEventStore: A simple in-memory event storage system for maintaining session state
17Examples:
18 >>> # Test module imports
19 >>> from mcpgateway.transports.streamablehttp_transport import (
20 ... EventEntry, StreamBuffer, InMemoryEventStore, SessionManagerWrapper
21 ... )
22 >>>
23 >>> # Verify classes are available
24 >>> EventEntry.__name__
25 'EventEntry'
26 >>> StreamBuffer.__name__
27 'StreamBuffer'
28 >>> InMemoryEventStore.__name__
29 'InMemoryEventStore'
30 >>> SessionManagerWrapper.__name__
31 'SessionManagerWrapper'
32"""
34# Standard
35import asyncio
36from contextlib import asynccontextmanager, AsyncExitStack
37import contextvars
38from dataclasses import dataclass
39import re
40from typing import Any, AsyncGenerator, Dict, List, Optional, Pattern, Tuple, Union
41from uuid import uuid4
43# Third-Party
44import anyio
45from fastapi import HTTPException
46from fastapi.security.utils import get_authorization_scheme_param
47import httpx
48from mcp import ClientSession, types
49from mcp.client.streamable_http import streamablehttp_client
50from mcp.server.lowlevel import Server
51from mcp.server.streamable_http import EventCallback, EventId, EventMessage, EventStore, StreamId
52from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
53from mcp.types import JSONRPCMessage, PaginatedRequestParams, ReadResourceRequest, ReadResourceRequestParams
54import orjson
55from sqlalchemy.exc import SQLAlchemyError
56from sqlalchemy.orm import Session
57from starlette.datastructures import Headers
58from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
59from starlette.types import Receive, Scope, Send
61# First-Party
62from mcpgateway.common.models import LogLevel
63from mcpgateway.config import settings
64from mcpgateway.db import SessionLocal
65from mcpgateway.middleware.rbac import _ACCESS_DENIED_MSG
66from mcpgateway.services.completion_service import CompletionService
67from mcpgateway.services.logging_service import LoggingService
68from mcpgateway.services.oauth_manager import OAuthEnforcementUnavailableError, OAuthRequiredError
69from mcpgateway.services.permission_service import PermissionService
70from mcpgateway.services.prompt_service import PromptService
71from mcpgateway.services.resource_service import ResourceService
72from mcpgateway.services.tool_service import ToolService
73from mcpgateway.transports.redis_event_store import RedisEventStore
74from mcpgateway.utils.gateway_access import build_gateway_auth_headers, check_gateway_access, extract_gateway_id_from_headers, GATEWAY_ID_HEADER
75from mcpgateway.utils.orjson_response import ORJSONResponse
76from mcpgateway.utils.verify_credentials import is_proxy_auth_trust_active, require_auth_header_first, verify_credentials
# Initialize logging service first
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

# Precompiled regex for server ID extraction from path.
# Matches "/servers/<id>/mcp" where <id> consists of hex digits and hyphens
# (UUID-like ids). The pattern has no ^/$ anchors.
_SERVER_ID_RE: Pattern[str] = re.compile(r"/servers/(?P<server_id>[a-fA-F0-9\-]+)/mcp")

# ASGI scope key for propagating gateway context from middleware to MCP handlers
_MCPGATEWAY_CONTEXT_KEY = "_mcpgateway_context"

# Initialize ToolService, PromptService, ResourceService, CompletionService and MCP Server.
# These are module-level singletons created once, at import time.
tool_service: ToolService = ToolService()
prompt_service: PromptService = PromptService()
resource_service: ResourceService = ResourceService()
completion_service: CompletionService = CompletionService()

mcp_app: Server[Any] = Server("mcp-streamable-http")

# Per-request context. "default_server_id" is the sentinel that
# _check_server_oauth_enforcement treats as "no server context".
server_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("server_id", default="default_server_id")
# NOTE(review): the `{}` defaults below are single dict objects shared by every
# context that never set the variable; callers must treat the value returned by
# .get() as read-only, or a mutation would leak into unrelated requests.
request_headers_var: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar("request_headers", default={})
user_context_var: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar("user_context", default={})
# Per-context cache flag: set to True by _check_server_oauth_enforcement once the
# OAuth requirement has been satisfied/verified, so repeat calls skip the DB lookup.
_oauth_checked_var: contextvars.ContextVar[bool] = contextvars.ContextVar("_oauth_checked", default=False)
# Process-wide session registry; assigned via set_shared_session_registry().
_shared_session_registry: Optional[Any] = None
102# ------------------------------ Event store ------------------------------
@dataclass()
class EventEntry:
    """
    One stored event: its id, owning stream, payload and per-stream sequence number.

    The sequence number is what lets a ring buffer place the entry in a slot and
    later confirm that the slot still holds this exact event.

    Examples:
        >>> from mcp.types import JSONRPCMessage
        >>> msg = JSONRPCMessage(jsonrpc="2.0", method="ping", id=7)
        >>> entry = EventEntry(event_id="evt-1", stream_id="s-1", message=msg, seq_num=3)
        >>> (entry.event_id, entry.stream_id, entry.seq_num)
        ('evt-1', 's-1', 3)
        >>> entry.message is msg
        True
        >>> entry.message.model_dump()["method"]
        'ping'
    """

    # Identifier handed back to the caller of InMemoryEventStore.store_event().
    event_id: EventId
    # Stream this event was stored under.
    stream_id: StreamId
    # The JSON-RPC message payload, stored as-is.
    message: JSONRPCMessage
    # Monotonic per-stream sequence number (0 for the first event of a stream).
    seq_num: int
@dataclass()
class StreamBuffer:
    """
    Fixed-size ring buffer holding the most recent events of a single stream.

    The capacity is ``len(entries)``. An event with sequence number ``seq`` is kept
    in slot ``seq % capacity``, so a slot can be located directly from a sequence
    number instead of scanning. ``start_seq`` is the oldest sequence number still
    held, ``next_seq`` the number the next insert will receive, and ``count`` the
    number of occupied slots.

    Examples:
        >>> buf = StreamBuffer(entries=[None] * 4)
        >>> (buf.start_seq, buf.next_seq, buf.count, len(buf))
        (0, 0, 0, 0)
        >>> buf.next_seq, buf.count = 2, 2
        >>> len(buf)
        2
    """

    # Slot storage; its length is the buffer capacity.
    entries: list[EventEntry | None]
    # Oldest sequence number still present in the buffer.
    start_seq: int = 0
    # Sequence number that the next stored event will be given.
    next_seq: int = 0
    # How many slots are currently occupied.
    count: int = 0

    def __len__(self) -> int:
        """Report how many events the buffer currently holds.

        Returns:
            int: The value of ``count``.
        """
        occupied = self.count
        return occupied
class InMemoryEventStore(EventStore):
    """
    Simple in-memory implementation of the EventStore interface for resumability.
    This is primarily intended for examples and testing, not for production use
    where a persistent storage solution would be more appropriate.

    This implementation keeps only the last N events per stream for memory efficiency.
    Uses a ring buffer with per-stream sequence numbers for O(1) event lookup and O(k) replay.

    Each stream's ring buffer is sized from ``max_events_per_stream`` when the stream
    is first seen; slot positions are always computed against that buffer's own
    length, so changing ``max_events_per_stream`` later only affects new streams.

    Note:
        This class never removes entries from ``streams`` itself; memory grows with
        the number of distinct stream IDs stored (each capped at its capacity).

    Examples:
        >>> # Create event store with default max events
        >>> store = InMemoryEventStore()
        >>> store.max_events_per_stream
        100
        >>> len(store.streams)
        0
        >>> len(store.event_index)
        0

        >>> # Create event store with custom max events
        >>> store = InMemoryEventStore(max_events_per_stream=50)
        >>> store.max_events_per_stream
        50

        >>> # Test event store initialization
        >>> store = InMemoryEventStore()
        >>> hasattr(store, 'streams')
        True
        >>> hasattr(store, 'event_index')
        True
        >>> isinstance(store.streams, dict)
        True
        >>> isinstance(store.event_index, dict)
        True
    """

    def __init__(self, max_events_per_stream: int = 100):
        """Initialize the event store.

        Args:
            max_events_per_stream: Maximum number of events to keep per stream.

        Raises:
            ValueError: If ``max_events_per_stream`` is less than 1 (a ring buffer
                needs at least one slot; zero would otherwise surface later as a
                ZeroDivisionError inside ``store_event``).

        Examples:
            >>> # Test initialization with default value
            >>> store = InMemoryEventStore()
            >>> store.max_events_per_stream
            100
            >>> store.streams == {}
            True
            >>> store.event_index == {}
            True

            >>> # Test initialization with custom value
            >>> store = InMemoryEventStore(max_events_per_stream=25)
            >>> store.max_events_per_stream
            25

            >>> # A non-positive capacity is rejected up front
            >>> InMemoryEventStore(max_events_per_stream=0)
            Traceback (most recent call last):
                ...
            ValueError: max_events_per_stream must be at least 1, got 0
        """
        if max_events_per_stream < 1:
            raise ValueError(f"max_events_per_stream must be at least 1, got {max_events_per_stream}")
        self.max_events_per_stream = max_events_per_stream
        # Per-stream ring buffers for O(1) position lookup
        self.streams: dict[StreamId, StreamBuffer] = {}
        # event_id -> EventEntry for quick lookup
        self.event_index: dict[EventId, EventEntry] = {}

    async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId:
        """
        Stores an event with a generated event ID.

        Args:
            stream_id (StreamId): The ID of the stream.
            message (JSONRPCMessage): The message to store.

        Returns:
            EventId: The ID of the stored event.

        Examples:
            >>> # Test storing an event
            >>> import asyncio
            >>> from mcp.types import JSONRPCMessage
            >>> store = InMemoryEventStore(max_events_per_stream=5)
            >>> message = JSONRPCMessage(jsonrpc="2.0", method="test", id=1)
            >>> event_id = asyncio.run(store.store_event("stream-1", message))
            >>> isinstance(event_id, str)
            True
            >>> len(event_id) > 0
            True
            >>> len(store.streams)
            1
            >>> len(store.event_index)
            1
            >>> "stream-1" in store.streams
            True
            >>> event_id in store.event_index
            True

            >>> # Test storing multiple events in same stream
            >>> message2 = JSONRPCMessage(jsonrpc="2.0", method="test2", id=2)
            >>> event_id2 = asyncio.run(store.store_event("stream-1", message2))
            >>> len(store.streams["stream-1"])
            2
            >>> len(store.event_index)
            2

            >>> # Test ring buffer overflow
            >>> store2 = InMemoryEventStore(max_events_per_stream=2)
            >>> msg1 = JSONRPCMessage(jsonrpc="2.0", method="m1", id=1)
            >>> msg2 = JSONRPCMessage(jsonrpc="2.0", method="m2", id=2)
            >>> msg3 = JSONRPCMessage(jsonrpc="2.0", method="m3", id=3)
            >>> id1 = asyncio.run(store2.store_event("stream-2", msg1))
            >>> id2 = asyncio.run(store2.store_event("stream-2", msg2))
            >>> # Now buffer is full, adding third will remove first
            >>> id3 = asyncio.run(store2.store_event("stream-2", msg3))
            >>> len(store2.streams["stream-2"])
            2
            >>> id1 in store2.event_index  # First event removed
            False
            >>> id2 in store2.event_index and id3 in store2.event_index
            True

            >>> # An existing stream keeps its original capacity
            >>> store3 = InMemoryEventStore(max_events_per_stream=2)
            >>> _ = asyncio.run(store3.store_event("s", msg1))
            >>> store3.max_events_per_stream = 5
            >>> _ = asyncio.run(store3.store_event("s", msg2))
            >>> _ = asyncio.run(store3.store_event("s", msg3))
            >>> len(store3.streams["s"])
            2
        """
        # Get or create ring buffer for this stream
        buffer = self.streams.get(stream_id)
        if buffer is None:
            buffer = StreamBuffer(entries=[None] * self.max_events_per_stream)
            self.streams[stream_id] = buffer

        # The buffer's own length is its capacity; using it (rather than the
        # mutable max_events_per_stream attribute) keeps slot indices in range.
        capacity = len(buffer.entries)

        # Assign per-stream sequence number
        seq_num = buffer.next_seq
        buffer.next_seq += 1
        idx = seq_num % capacity

        # Handle eviction if buffer is full
        if buffer.count == capacity:
            evicted = buffer.entries[idx]
            if evicted is not None:
                self.event_index.pop(evicted.event_id, None)
            buffer.start_seq += 1
        else:
            if buffer.count == 0:
                buffer.start_seq = seq_num
            buffer.count += 1

        # Create and store the new event entry
        event_id = str(uuid4())
        event_entry = EventEntry(event_id=event_id, stream_id=stream_id, message=message, seq_num=seq_num)
        buffer.entries[idx] = event_entry
        self.event_index[event_id] = event_entry

        return event_id

    async def replay_events_after(
        self,
        last_event_id: EventId,
        send_callback: EventCallback,
    ) -> Union[StreamId, None]:
        """
        Replays events that occurred after the specified event ID.

        Uses O(1) lookup via event_index and O(k) replay where k is the number
        of events to replay, avoiding the previous O(n) full scan.

        Args:
            last_event_id (EventId): The ID of the last received event. Replay starts after this event.
            send_callback (EventCallback): Async callback to send each replayed event.

        Returns:
            StreamId | None: The stream ID if the event is found and replayed, otherwise None.

        Examples:
            >>> # Test replaying events
            >>> import asyncio
            >>> from mcp.types import JSONRPCMessage
            >>> store = InMemoryEventStore()
            >>> message1 = JSONRPCMessage(jsonrpc="2.0", method="test1", id=1)
            >>> message2 = JSONRPCMessage(jsonrpc="2.0", method="test2", id=2)
            >>> message3 = JSONRPCMessage(jsonrpc="2.0", method="test3", id=3)
            >>>
            >>> # Store events
            >>> event_id1 = asyncio.run(store.store_event("stream-1", message1))
            >>> event_id2 = asyncio.run(store.store_event("stream-1", message2))
            >>> event_id3 = asyncio.run(store.store_event("stream-1", message3))
            >>>
            >>> # Test replay after first event
            >>> replayed_events = []
            >>> async def mock_callback(event_message):
            ...     replayed_events.append(event_message)
            >>>
            >>> result = asyncio.run(store.replay_events_after(event_id1, mock_callback))
            >>> result
            'stream-1'
            >>> len(replayed_events)
            2

            >>> # Test replay with non-existent event
            >>> result = asyncio.run(store.replay_events_after("non-existent", mock_callback))
            >>> result is None
            True
        """
        # O(1) lookup in event_index
        last_event = self.event_index.get(last_event_id)
        if last_event is None:
            logger.warning("Event ID %s not found in store", last_event_id)
            return None

        buffer = self.streams.get(last_event.stream_id)
        if buffer is None:
            return None

        # Validate that the event's seq_num is still within the buffer range
        if last_event.seq_num < buffer.start_seq or last_event.seq_num >= buffer.next_seq:
            return None

        capacity = len(buffer.entries)
        # O(k) replay: iterate from last_event.seq_num + 1 to buffer.next_seq - 1.
        # The upper bound is fixed when the loop starts, so events stored while we
        # await send_callback are not part of this replay.
        for seq in range(last_event.seq_num + 1, buffer.next_seq):
            entry = buffer.entries[seq % capacity]
            # Guard: skip if slot is empty or has been overwritten by a different seq
            if entry is None or entry.seq_num != seq:
                continue
            await send_callback(EventMessage(entry.message, entry.event_id))

        return last_event.stream_id
399# ------------------------------ Streamable HTTP Transport ------------------------------
@asynccontextmanager
async def get_db() -> AsyncGenerator[Session, Any]:
    """
    Asynchronous context manager for database sessions.

    Commits the transaction on successful completion to avoid implicit rollbacks
    for read-only operations. Rolls back explicitly on exception. Handles
    asyncio.CancelledError explicitly to prevent transaction leaks when MCP
    handlers are cancelled (client disconnect, timeout, etc.).

    An exception raised inside the caller's ``async with`` body is thrown back
    into this generator at the ``yield`` and handled by the ``except`` clauses
    below; an exception raised by ``db.commit()`` itself takes the same
    rollback path. The session is closed in ``finally`` on every exit path.

    Yields:
        A database session instance from SessionLocal.
        Ensures the session is closed after use.

    Raises:
        asyncio.CancelledError: Re-raised after rollback and close on task cancellation.
        Exception: Re-raises any exception after rolling back the transaction.

    Examples:
        >>> # Test database context manager
        >>> import asyncio
        >>> async def test_db():
        ...     async with get_db() as db:
        ...         return db is not None
        >>> result = asyncio.run(test_db())
        >>> result
        True
    """
    db = SessionLocal()
    try:
        yield db
        # Reached only when the caller's body finished without raising.
        db.commit()
    except asyncio.CancelledError:
        # Handle cancellation explicitly to prevent transaction leaks.
        # When MCP handlers are cancelled (client disconnect, timeout, etc.),
        # we must rollback and close the session before re-raising.
        # CancelledError derives from BaseException (Python 3.8+), so the
        # ``except Exception`` branch below would not catch it.
        try:
            db.rollback()
        except Exception:
            pass  # nosec B110 - Best effort rollback on cancellation
        try:
            db.close()
        except Exception:
            pass  # nosec B110 - Best effort close on cancellation
        # NOTE(review): ``finally`` below calls db.close() again without a guard;
        # if that second close raises, its exception replaces this CancelledError.
        raise
    except Exception:
        try:
            db.rollback()
        except Exception:
            # Rollback itself failed (e.g. the connection is unusable):
            # invalidate the session's connection instead, best effort.
            try:
                db.invalidate()
            except Exception:
                pass  # nosec B110 - Best effort cleanup on connection failure
        raise
    finally:
        db.close()
def get_user_email_from_context() -> str:
    """Return the email of the caller bound to the current context.

    For a dict context the ``email`` key wins, then the JWT ``sub`` claim; any
    other truthy context value is stringified.

    Returns:
        The caller identifier, or ``"unknown"`` when nothing usable is present.
    """
    current = user_context_var.get()
    if not isinstance(current, dict):
        return str(current) if current else "unknown"

    # First truthy value among the preferred keys, in priority order.
    for key in ("email", "sub"):
        candidate = current.get(key)
        if candidate:
            return candidate
    return "unknown"
473def _should_enforce_streamable_rbac(user_context: Optional[dict[str, Any]]) -> bool:
474 """Return True when request originated from authenticated Streamable HTTP middleware.
476 Direct unit tests may call MCP handlers without middleware context; those
477 invocations should preserve historical behavior and avoid forced RBAC checks.
479 Args:
480 user_context: Request user context propagated by Streamable HTTP auth middleware.
482 Returns:
483 bool: ``True`` when permission checks should be enforced for this request.
484 """
485 return isinstance(user_context, dict) and user_context.get("is_authenticated", False) is True
def _build_resource_metadata_url(scope: Scope, server_id: str) -> str:
    """Construct the RFC 9728 OAuth Protected Resource Metadata URL from ASGI scope.

    Inspects ``x-forwarded-proto`` and ``host`` headers first (reverse-proxy
    scenario), then falls back to ``scope["scheme"]`` and ``scope["server"]``.
    Includes ``scope["root_path"]`` so that deployments behind a reverse proxy
    with a path prefix emit the correct public URL.

    When falling back to ``scope["server"]``, a ``None`` port (permitted by the
    ASGI spec, e.g. for Unix-socket servers) is omitted instead of being rendered
    as ``":None"``.

    Args:
        scope: ASGI connection scope.
        server_id: Virtual-server identifier.

    Returns:
        Fully-qualified URL string, or ``""`` if construction fails.
    """
    try:
        headers = Headers(scope=scope)
        forwarded_proto = headers.get("x-forwarded-proto")
        if forwarded_proto:
            # Use the first entry when the header holds a comma-separated list.
            proto = forwarded_proto.split(",")[0].strip().lower()
        else:
            proto = scope.get("scheme", "https")
        if proto not in ("http", "https"):
            proto = "https"

        host = headers.get("host")
        if not host:
            server_tuple = scope.get("server")
            if not server_tuple:
                return ""
            host_addr, port = server_tuple
            # Wrap IPv6 addresses in brackets per RFC 2732
            if ":" in str(host_addr):
                host_addr = f"[{host_addr}]"
            default_port = 443 if proto == "https" else 80
            if port is None or port == default_port:
                host = host_addr
            else:
                host = f"{host_addr}:{port}"

        root_path = scope.get("root_path", "").rstrip("/")
        return f"{proto}://{host}{root_path}/.well-known/oauth-protected-resource/servers/{server_id}/mcp"
    except Exception:
        return ""
async def _check_server_oauth_enforcement(server_id: str, user_context: Optional[dict[str, Any]]) -> None:
    """Reject unauthenticated callers when a server requires OAuth.

    Looks up the server's ``oauth_enabled`` flag and raises
    ``OAuthRequiredError`` when the flag is set but the caller is not
    authenticated. This closes the gap where OAuth capability is
    *advertised* (via RFC 9728 ``experimental.oauth``) but never
    *enforced* on subsequent MCP requests.

    The result is cached in ``_oauth_checked_var`` for the lifetime of
    the request so that handler-level defense-in-depth calls do not
    repeat the DB query already performed by the middleware. The cache
    flag is only set once the lookup has completed successfully, so a
    failed lookup never marks the request as already checked.

    .. note::
        SSE transport is not covered here because it already requires
        authentication unconditionally.

    Args:
        server_id: Virtual-server identifier extracted from the URL path.
        user_context: User context set by ``streamable_http_auth`` middleware.

    Raises:
        OAuthRequiredError: When the server requires OAuth and the caller has
            not provided valid authentication credentials.
        OAuthEnforcementUnavailableError: When the database or session is
            unavailable and the server's ``oauth_enabled`` flag cannot be
            verified (fail-closed).
    """
    if _oauth_checked_var.get(False):
        return  # Already checked during this request

    if not server_id or server_id == "default_server_id":
        return  # No server context — nothing to enforce

    is_authenticated = (user_context or {}).get("is_authenticated", False)
    if is_authenticated:
        _oauth_checked_var.set(True)
        return  # Already authenticated — no need to check

    # Lazy DB lookup to avoid import-time side-effects
    # Third-Party
    from sqlalchemy import select  # pylint: disable=import-outside-toplevel

    # First-Party
    from mcpgateway.db import Server as DbServer  # pylint: disable=import-outside-toplevel

    try:
        async with get_db() as db:
            server = db.execute(select(DbServer).where(DbServer.id == server_id)).scalar_one_or_none()
            if server and server.oauth_enabled:
                logger.warning("OAuth required for server %s but caller is unauthenticated", server_id)
                # Raised inside the session block so get_db() rolls back before it propagates.
                raise OAuthRequiredError(
                    "This server requires OAuth authentication. Please provide a valid access token.",
                    server_id=server_id,
                )
    except SQLAlchemyError as exc:
        # DB lookup failure — fail-closed for security.
        logger.error("OAuth enforcement DB lookup failed for server %s: %s", server_id, exc)
        raise OAuthEnforcementUnavailableError(
            f"Unable to verify OAuth requirements for server {server_id}",
            server_id=server_id,
        ) from exc

    # Cache only after the query AND get_db()'s commit/close on exit succeeded.
    # Setting the flag inside the session block would leave the request marked
    # as checked even when the exit path raised and was converted above.
    _oauth_checked_var.set(True)
597async def _check_streamable_permission(
598 *,
599 user_context: dict[str, Any],
600 permission: str,
601 allow_admin_bypass: bool = True,
602 check_any_team: bool = False,
603) -> bool:
604 """Evaluate RBAC permission for a Streamable HTTP request context.
606 Args:
607 user_context: Authenticated user context from Streamable HTTP middleware.
608 permission: Permission name to evaluate (for example ``tools.execute``).
609 allow_admin_bypass: Whether unrestricted admin tokens can bypass team checks.
610 check_any_team: Whether any matching team grants permission.
612 Returns:
613 bool: ``True`` when the caller is authorized for ``permission``.
614 """
615 user_email = user_context.get("email")
616 if not user_email:
617 return False
619 try:
620 async with get_db() as db:
621 permission_service = PermissionService(db)
622 granted = await permission_service.check_permission(
623 user_email=user_email,
624 permission=permission,
625 token_teams=user_context.get("teams"),
626 allow_admin_bypass=allow_admin_bypass,
627 check_any_team=check_any_team,
628 )
629 if not granted:
630 logger.warning("Streamable HTTP RBAC denied: user=%s, permission=%s", user_email, permission)
631 return granted
632 except Exception as exc:
633 logger.warning("Streamable HTTP RBAC check failed for %s / %s: %s", user_email, permission, exc)
634 return False
637def _check_scoped_permission(user_context: dict[str, Any], permission: str) -> bool:
638 """Check if token scoped permissions allow this operation.
640 Args:
641 user_context: User context dict (may contain 'scoped_permissions' key).
642 permission: Permission to check.
644 Returns:
645 True if allowed (no scope cap, wildcard, or permission present).
646 """
647 scoped = user_context.get("scoped_permissions")
648 if not scoped: # None or empty list = defer to RBAC
649 return True
650 if "*" in scoped:
651 return True
652 allowed = permission in scoped
653 if not allowed:
654 logger.warning("Streamable HTTP token scope denied: user=%s, required=%s", user_context.get("email"), permission)
655 return allowed
def set_shared_session_registry(session_registry: Any) -> None:
    """Publish the registry that Streamable HTTP helpers in this process will use.

    Each call replaces whatever registry was published before.

    Args:
        session_registry: Registry instance created by application bootstrap.
    """
    global _shared_session_registry  # pylint: disable=global-statement
    _shared_session_registry = session_registry
def _get_shared_session_registry() -> Optional[Any]:
    """Fetch the registry last published with ``set_shared_session_registry``.

    Returns:
        Optional[Any]: The published registry, or ``None`` if none is available.
    """
    registry = _shared_session_registry
    return registry
677async def _claim_streamable_session_owner(session_id: str, owner_email: str) -> Optional[str]:
678 """Claim or resolve the logical owner for a Streamable HTTP session.
680 Args:
681 session_id: Logical MCP session identifier to claim.
682 owner_email: Caller email that should own the session.
684 Returns:
685 Optional[str]: Effective owner email after claim, or ``None`` if unavailable.
686 """
687 if not session_id or not owner_email:
688 return None
690 session_registry = _get_shared_session_registry()
691 if session_registry is None:
692 return None
694 try:
695 return await session_registry.claim_session_owner(session_id, owner_email)
696 except Exception as exc:
697 logger.warning("Failed to claim session owner for %s: %s", session_id, exc)
698 return None
701async def _validate_streamable_session_access(
702 *,
703 mcp_session_id: Optional[str],
704 user_context: Optional[dict[str, Any]],
705 rpc_method: Optional[str] = None,
706) -> tuple[bool, int, str]:
707 """Authorize access to a stateful Streamable HTTP session identifier.
709 Args:
710 mcp_session_id: Session identifier from request headers.
711 user_context: Authenticated user context for the current request.
712 rpc_method: JSON-RPC method name when available.
714 Returns:
715 Tuple ``(allowed, deny_status_code, deny_message)``.
716 """
717 if not settings.use_stateful_sessions:
718 return True, 200, ""
720 if not mcp_session_id or mcp_session_id == "not-provided":
721 return True, 200, ""
723 if not _should_enforce_streamable_rbac(user_context):
724 return True, 200, ""
726 # Initialize establishes a new session and is authorized separately.
727 if (rpc_method or "").strip() == "initialize":
728 return True, 200, ""
730 requester_email = user_context.get("email") if isinstance(user_context, dict) else None
731 requester_is_admin = bool(user_context.get("is_admin", False)) if isinstance(user_context, dict) else False
733 session_registry = _get_shared_session_registry()
734 if session_registry is None:
735 return False, HTTP_403_FORBIDDEN, "Session ownership unavailable"
737 try:
738 session_owner = await session_registry.get_session_owner(mcp_session_id)
739 except Exception as exc:
740 logger.warning("Failed to get session owner for %s: %s", mcp_session_id, exc)
741 return False, HTTP_403_FORBIDDEN, "Session ownership unavailable"
743 if session_owner:
744 if requester_is_admin:
745 return True, 200, ""
746 if requester_email and requester_email == session_owner:
747 return True, 200, ""
748 return False, HTTP_403_FORBIDDEN, "Session access denied"
750 try:
751 session_exists = await session_registry.session_exists(mcp_session_id)
752 except Exception as exc:
753 logger.warning("Failed to check session existence for %s: %s", mcp_session_id, exc)
754 return False, HTTP_403_FORBIDDEN, "Session ownership unavailable"
756 if session_exists is False:
757 return False, HTTP_404_NOT_FOUND, "Session not found"
758 return False, HTTP_403_FORBIDDEN, "Session owner metadata unavailable"
async def _proxy_list_tools_to_gateway(gateway: Any, request_headers: dict, user_context: dict, meta: Optional[Any] = None) -> List[types.Tool]:  # pylint: disable=unused-argument
    """Fetch the tool list straight from a remote MCP gateway via the MCP SDK.

    Opens a short-lived Streamable HTTP client session against ``gateway.url``,
    authenticated with the gateway's configured auth headers plus any configured
    passthrough headers copied from the incoming request.

    Args:
        gateway: Gateway ORM instance
        request_headers: Request headers from client
        user_context: User context (not used - _meta comes from MCP SDK)
        meta: Request metadata (_meta) from the original request

    Returns:
        The remote server's tools, or an empty list if anything fails
        (the error is logged).
    """
    try:
        outbound_headers = build_gateway_auth_headers(gateway)

        # Copy configured passthrough headers, trying the lower-cased name first
        # and then the name exactly as configured.
        allowed_names = gateway.passthrough_headers
        if allowed_names and request_headers:
            for name in allowed_names:
                value = request_headers.get(name.lower()) or request_headers.get(name)
                if value:
                    outbound_headers[name] = value

        client = streamablehttp_client(url=gateway.url, headers=outbound_headers, timeout=settings.mcpgateway_direct_proxy_timeout)
        async with client as (read_stream, write_stream, _get_session_id):
            async with ClientSession(read_stream, write_stream) as session:
                await session.initialize()

                list_params = None
                if meta:
                    list_params = PaginatedRequestParams(_meta=meta)
                    logger.debug("Forwarding _meta to remote gateway: %s", meta)

                listing = await session.list_tools(params=list_params)
                return listing.tools

    except Exception as e:
        logger.exception("Error proxying tools/list to gateway %s: %s", gateway.id, e)
        return []
async def _proxy_list_resources_to_gateway(gateway: Any, request_headers: dict, user_context: dict, meta: Optional[Any] = None) -> List[types.Resource]:  # pylint: disable=unused-argument
    """Proxy resources/list request directly to remote MCP gateway using MCP SDK.

    Opens a short-lived Streamable HTTP client session against ``gateway.url``
    using the gateway's auth headers plus any configured passthrough headers
    copied from the incoming request.

    Args:
        gateway: Gateway ORM instance
        request_headers: Request headers from client
        user_context: User context (not used - _meta comes from MCP SDK)
        meta: Request metadata (_meta) from the original request

    Returns:
        List of Resource objects from remote server, or an empty list if the
        proxy call fails (the error is logged).
    """
    try:
        # Prepare headers with gateway auth
        headers = build_gateway_auth_headers(gateway)

        # Forward passthrough headers if configured
        if gateway.passthrough_headers and request_headers:
            for header_name in gateway.passthrough_headers:
                header_value = request_headers.get(header_name.lower()) or request_headers.get(header_name)
                if header_value:
                    headers[header_name] = header_value

        logger.info("Proxying resources/list to gateway %s at %s", gateway.id, gateway.url)
        # Logged once here (before connecting) rather than again inside the session.
        if meta:
            logger.debug("Forwarding _meta to remote gateway: %s", meta)

        # Use MCP SDK to connect and list resources
        async with streamablehttp_client(url=gateway.url, headers=headers, timeout=settings.mcpgateway_direct_proxy_timeout) as (read_stream, write_stream, _get_session_id):
            async with ClientSession(read_stream, write_stream) as session:
                await session.initialize()

                # Prepare params with _meta if provided
                params = PaginatedRequestParams(_meta=meta) if meta else None

                # List resources with _meta forwarded
                result = await session.list_resources(params=params)

                logger.info("Received %s resources from gateway %s", len(result.resources), gateway.id)
                return result.resources

    except Exception as e:
        logger.exception("Error proxying resources/list to gateway %s: %s", gateway.id, e)
        return []
async def _proxy_read_resource_to_gateway(gateway: Any, resource_uri: str, user_context: dict, meta: Optional[Any] = None) -> List[Any]:  # pylint: disable=unused-argument
    """Proxy resources/read request directly to remote MCP gateway using MCP SDK.

    Headers sent upstream are the gateway's auth headers, the
    ``X-Context-Forge-Gateway-Id`` header when the incoming request carries one,
    and any configured passthrough headers (read from ``request_headers_var``).

    Args:
        gateway: Gateway ORM instance
        resource_uri: URI of the resource to read
        user_context: User context (not used - auth comes from gateway config)
        meta: Request metadata (_meta) from the original request

    Returns:
        List of content objects (TextResourceContents or BlobResourceContents) from
        remote server, or an empty list if the proxy call fails (the error is logged).
    """
    try:
        # Prepare headers with gateway auth
        headers = build_gateway_auth_headers(gateway)

        # Get request headers
        request_headers = request_headers_var.get()

        # Forward X-Context-Forge-Gateway-Id header
        gw_id = extract_gateway_id_from_headers(request_headers)
        if gw_id:
            headers[GATEWAY_ID_HEADER] = gw_id

        # Forward passthrough headers if configured
        if gateway.passthrough_headers and request_headers:
            for header_name in gateway.passthrough_headers:
                header_value = request_headers.get(header_name.lower()) or request_headers.get(header_name)
                if header_value:
                    headers[header_name] = header_value

        logger.info("Proxying resources/read for %s to gateway %s at %s", resource_uri, gateway.id, gateway.url)
        if meta:
            logger.debug("Forwarding _meta to remote gateway: %s", meta)

        # Use MCP SDK to connect and read resource
        async with streamablehttp_client(url=gateway.url, headers=headers, timeout=settings.mcpgateway_direct_proxy_timeout) as (read_stream, write_stream, _get_session_id):
            async with ClientSession(read_stream, write_stream) as session:
                await session.initialize()

                if meta:
                    # Populate _meta through its field alias, the same way the list
                    # proxies build PaginatedRequestParams, and send a raw
                    # ReadResourceRequest so the metadata reaches the remote server
                    # (the plain read_resource() call below passes only the URI).
                    request_params = ReadResourceRequestParams(uri=resource_uri, _meta=meta)
                    result = await session.send_request(
                        types.ClientRequest(ReadResourceRequest(params=request_params)),
                        types.ReadResourceResult,
                    )
                else:
                    # No _meta, use simple read_resource
                    result = await session.read_resource(uri=resource_uri)

                logger.info("Received %s content items from gateway %s for resource %s", len(result.contents), gateway.id, resource_uri)
                return result.contents

    except Exception as e:
        logger.exception("Error proxying resources/read to gateway %s for resource %s: %s", gateway.id, resource_uri, e)
        return []
@mcp_app.call_tool(validate_input=False)
async def call_tool(name: str, arguments: dict) -> Union[
    types.CallToolResult,
    List[Union[types.TextContent, types.ImageContent, types.AudioContent, types.ResourceLink, types.EmbeddedResource]],
    Tuple[List[Union[types.TextContent, types.ImageContent, types.AudioContent, types.ResourceLink, types.EmbeddedResource]], Dict[str, Any]],
]:
    """
    Handles tool invocation via the MCP Server.

    Note: validate_input=False disables the MCP SDK's built-in JSON Schema validation.
    This is necessary because the SDK uses jsonschema.validate() which internally calls
    check_schema() with the default validator. Schemas using older draft features
    (e.g., Draft 4 style exclusiveMinimum: true) fail this validation. The gateway
    handles schema validation separately in tool_service.py with multi-draft support.

    This function supports the MCP protocol's tool calling with structured content validation.

    - Direct proxy mode: the request carries an X-Context-Forge-Gateway-Id header naming a
      gateway whose gateway_mode is 'direct_proxy'. The raw CallToolResult from the remote
      server is returned without normalization.
    - Normal mode: the call is forwarded to the worker that owns the MCP session (when
      session affinity applies) or executed locally via tool_service. The content is
      normalized to MCP SDK types and returned either as a plain list, or as a
      ``(content, structured_content)`` tuple. The MCP SDK wraps either form in a
      CallToolResult and can validate the structured part against the tool's outputSchema.

    Args:
        name (str): The name of the tool to invoke.
        arguments (dict): A dictionary of arguments to pass to the tool.

    Returns:
        One of:

        - types.CallToolResult: direct proxy mode. This is the remote result, or an error
          result when access is denied or the proxy call fails.
        - list of MCP content items: normal mode without structured content. The list is
          empty when the tool produced no content.
        - tuple ``(content list, structured content)``: normal mode when structured content
          is present. This includes an empty dict.

    Raises:
        PermissionError: If the caller lacks ``tools.execute`` permission.
        Exception: Re-raised after logging to allow MCP SDK to convert to JSON-RPC error response.

    Examples:
        >>> # Test call_tool function signature
        >>> import inspect
        >>> sig = inspect.signature(call_tool)
        >>> list(sig.parameters.keys())
        ['name', 'arguments']
        >>> sig.parameters['name'].annotation
        <class 'str'>
        >>> sig.parameters['arguments'].annotation
        <class 'dict'>
    """
    server_id, request_headers, user_context = await _get_request_context_or_default()

    meta_data = None
    # Extract _meta from request context if available
    try:
        ctx = mcp_app.request_context
        if ctx and ctx.meta is not None:
            meta_data = ctx.meta.model_dump()
    except LookupError:
        # request_context might not be active in some edge cases (e.g. tests)
        logger.debug("No active request context found")

    # Extract authorization parameters from user context (same pattern as list_tools)
    user_email = user_context.get("email") if user_context else None
    token_teams = user_context.get("teams") if user_context else None
    is_admin = user_context.get("is_admin", False) if user_context else False

    # Admin bypass - only when token has NO team restrictions (token_teams is None)
    # If token has explicit team scope (even empty [] for public-only), respect it
    if is_admin and token_teams is None:
        user_email = None
        # token_teams stays None (unrestricted)
    elif token_teams is None:
        token_teams = []  # Non-admin without teams = public-only (secure default)

    # Enforce per-server OAuth requirement in permissive mode (defense-in-depth).
    # When mcp_require_auth=True, the middleware already guarantees authentication.
    # Note: OAuthEnforcementUnavailableError is intentionally uncaught here —
    # the middleware (streamable_http_auth) catches it and returns 503. If the
    # middleware is somehow bypassed, an uncaught 500 is acceptable and will be
    # logged by the ASGI server.
    if not settings.mcp_require_auth:
        await _check_server_oauth_enforcement(server_id, user_context)

    if _should_enforce_streamable_rbac(user_context):
        # Layer 1: Token scope cap
        if not _check_scoped_permission(user_context, "tools.execute"):
            raise PermissionError(_ACCESS_DENIED_MSG)
        # Layer 2: RBAC check
        # Session tokens have no explicit team_id; check across all team-scoped roles.
        # Mirrors the @require_permission decorator's check_any_team fallback (rbac.py:562-576).
        _is_session_token = user_context.get("token_use") == "session"
        has_execute_permission = await _check_streamable_permission(
            user_context=user_context,
            permission="tools.execute",
            check_any_team=_is_session_token,
        )
        if not has_execute_permission:
            raise PermissionError(_ACCESS_DENIED_MSG)

    # Check if we're in direct_proxy mode by looking for X-Context-Forge-Gateway-Id header
    gateway_id_from_header = extract_gateway_id_from_headers(request_headers)

    # If X-Context-Forge-Gateway-Id header is present, use direct proxy mode
    if gateway_id_from_header:
        try:  # Check if this gateway is in direct_proxy mode
            async with get_db() as check_db:
                # Third-Party
                from sqlalchemy import select  # pylint: disable=import-outside-toplevel

                # First-Party
                from mcpgateway.db import Gateway as DbGateway  # pylint: disable=import-outside-toplevel

                gateway = check_db.execute(select(DbGateway).where(DbGateway.id == gateway_id_from_header)).scalar_one_or_none()
                if gateway and getattr(gateway, "gateway_mode", "cache") == "direct_proxy" and settings.mcpgateway_direct_proxy_enabled:
                    # SECURITY: Check gateway access before allowing direct proxy.
                    # Report "not found" rather than "forbidden" so gateway existence is not leaked.
                    if not await check_gateway_access(check_db, gateway, user_email, token_teams):
                        logger.warning("Access denied to gateway %s in direct_proxy mode for user %s", gateway_id_from_header, user_email)
                        return types.CallToolResult(content=[types.TextContent(type="text", text=f"Tool not found: {name}")], isError=True)

                    logger.info("Using direct_proxy mode for tool '%s' via gateway %s", name, gateway_id_from_header)

                    # Use direct proxy method - returns raw CallToolResult from remote server
                    # Return it directly without any normalization
                    return await tool_service.invoke_tool_direct(
                        gateway_id=gateway_id_from_header,
                        name=name,
                        arguments=arguments,
                        request_headers=request_headers,
                        meta_data=meta_data,
                        user_email=user_email,
                        token_teams=token_teams,
                    )
        except Exception as e:
            logger.error("Direct proxy mode failed for gateway %s: %s", gateway_id_from_header, e)
            return types.CallToolResult(content=[types.TextContent(type="text", text="Direct proxy tool invocation failed")], isError=True)

    # Normal mode: use standard tool invocation with normalization
    # Use the already-recovered user_context (works for both ContextVar and stateful session paths)
    app_user_email = (user_context.get("email") or user_context.get("sub") or "unknown") if user_context else "unknown"

    # Multi-worker session affinity: check if we should forward to another worker
    # Check both x-mcp-session-id (internal/forwarded) and mcp-session-id (client protocol header)
    mcp_session_id = None
    if request_headers:
        request_headers_lower = {k.lower(): v for k, v in request_headers.items()}
        mcp_session_id = request_headers_lower.get("x-mcp-session-id") or request_headers_lower.get("mcp-session-id")
    if settings.mcpgateway_session_affinity_enabled and mcp_session_id:
        try:
            # First-Party
            from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache  # pylint: disable=import-outside-toplevel
            from mcpgateway.services.mcp_session_pool import get_mcp_session_pool  # pylint: disable=import-outside-toplevel
            from mcpgateway.services.mcp_session_pool import MCPSessionPool  # pylint: disable=import-outside-toplevel

            if not MCPSessionPool.is_valid_mcp_session_id(mcp_session_id):
                logger.debug("Invalid MCP session id for Streamable HTTP tool affinity, executing locally")
                # RuntimeError is used as the "execute locally" signal; see the handler below.
                raise RuntimeError("invalid mcp session id")

            pool = get_mcp_session_pool()

            # Register session mapping BEFORE checking forwarding (same pattern as SSE)
            # This ensures ownership is registered atomically so forward_request_to_owner() works
            try:
                cached = await tool_lookup_cache.get(name)
                if cached and cached.get("status") == "active":
                    gateway_info = cached.get("gateway")
                    if gateway_info:
                        url = gateway_info.get("url")
                        gateway_id = gateway_info.get("id", "")
                        transport_type = gateway_info.get("transport", "streamablehttp")
                        if url:
                            await pool.register_session_mapping(mcp_session_id, url, gateway_id, transport_type, user_email)
            except Exception as e:
                logger.error("Failed to pre-register session mapping for Streamable HTTP: %s", e)

            forwarded_response = await pool.forward_request_to_owner(
                mcp_session_id,
                {"method": "tools/call", "params": {"name": name, "arguments": arguments, "_meta": meta_data}, "headers": dict(request_headers) if request_headers else {}},
            )
            if forwarded_response is not None:
                # Request was handled by another worker - convert response to expected format
                if "error" in forwarded_response:
                    # The error payload is normally a JSON-RPC error object, but tolerate a bare
                    # string or null so the real message is not masked by an AttributeError.
                    error_payload = forwarded_response["error"]
                    if isinstance(error_payload, dict):
                        error_message = error_payload.get("message", "Forwarded request failed")
                    elif isinstance(error_payload, str) and error_payload:
                        error_message = error_payload
                    else:
                        error_message = "Forwarded request failed"
                    raise Exception(error_message)  # pylint: disable=broad-exception-raised
                # Treat an explicit null result like a missing one.
                result_data = forwarded_response.get("result") or {}

                def _rehydrate_content_items(items: Any) -> list[types.TextContent | types.ImageContent | types.AudioContent | types.ResourceLink | types.EmbeddedResource]:
                    """Convert forwarded tool result items back to MCP content types.

                    Args:
                        items: List of content item dicts from forwarded response.

                    Returns:
                        List of validated MCP content type instances. Non-dict items are
                        skipped; unknown or invalid dict items become JSON TextContent.
                    """
                    if not isinstance(items, list):
                        return []
                    converted: list[types.TextContent | types.ImageContent | types.AudioContent | types.ResourceLink | types.EmbeddedResource] = []
                    for item in items:
                        if not isinstance(item, dict):
                            continue
                        item_type = item.get("type")
                        try:
                            if item_type == "text":
                                converted.append(types.TextContent.model_validate(item))
                            elif item_type == "image":
                                converted.append(types.ImageContent.model_validate(item))
                            elif item_type == "audio":
                                converted.append(types.AudioContent.model_validate(item))
                            elif item_type == "resource_link":
                                converted.append(types.ResourceLink.model_validate(item))
                            elif item_type == "resource":
                                converted.append(types.EmbeddedResource.model_validate(item))
                            else:
                                converted.append(types.TextContent(type="text", text=item if isinstance(item, str) else orjson.dumps(item).decode()))
                        except Exception:
                            converted.append(types.TextContent(type="text", text=item if isinstance(item, str) else orjson.dumps(item).decode()))
                    return converted

                unstructured = _rehydrate_content_items(result_data.get("content", []))
                # Prefer the camelCase key; fall back to snake_case only when it is missing,
                # and keep an empty dict (it is still valid structured output).
                structured = result_data.get("structuredContent")
                if structured is None:
                    structured = result_data.get("structured_content")
                if structured is not None:
                    return (unstructured, structured)
                return unstructured
        except RuntimeError:
            # Pool not initialized (or invalid session id) - execute locally
            pass

    try:
        async with get_db() as db:
            # Use tool service for all tool invocations (handles direct_proxy internally)
            result = await tool_service.invoke_tool(
                db=db,
                name=name,
                arguments=arguments,
                request_headers=request_headers,
                app_user_email=app_user_email,
                user_email=user_email,
                token_teams=token_teams,
                server_id=server_id,
                meta_data=meta_data,
            )
            if not result or not result.content:
                logger.warning("No content returned by tool: %s", name)
                return []

            # Normalize unstructured content to MCP SDK types, preserving metadata (annotations, _meta, size)
            # Helper to convert gateway Annotations to dict for MCP SDK compatibility
            # (mcpgateway.common.models.Annotations != mcp.types.Annotations)
            def _convert_annotations(ann: Any) -> dict[str, Any] | None:
                """Convert gateway Annotations to dict for MCP SDK compatibility.

                Args:
                    ann: Gateway Annotations object, dict, or None.

                Returns:
                    Dict representation of annotations, or None.
                """
                if ann is None:
                    return None
                if isinstance(ann, dict):
                    return ann
                if hasattr(ann, "model_dump"):
                    return ann.model_dump(by_alias=True, mode="json")
                return None

            def _convert_meta(meta: Any) -> dict[str, Any] | None:
                """Convert gateway meta to dict for MCP SDK compatibility.

                Args:
                    meta: Gateway meta object, dict, or None.

                Returns:
                    Dict representation of meta, or None.
                """
                if meta is None:
                    return None
                if isinstance(meta, dict):
                    return meta
                if hasattr(meta, "model_dump"):
                    return meta.model_dump(by_alias=True, mode="json")
                return None

            unstructured: list[types.TextContent | types.ImageContent | types.AudioContent | types.ResourceLink | types.EmbeddedResource] = []
            for content in result.content:
                if content.type == "text":
                    unstructured.append(
                        types.TextContent(
                            type="text",
                            text=content.text,
                            annotations=_convert_annotations(getattr(content, "annotations", None)),
                            _meta=_convert_meta(getattr(content, "meta", None)),
                        )
                    )
                elif content.type == "image":
                    unstructured.append(
                        types.ImageContent(
                            type="image",
                            data=content.data,
                            mimeType=content.mime_type,
                            annotations=_convert_annotations(getattr(content, "annotations", None)),
                            _meta=_convert_meta(getattr(content, "meta", None)),
                        )
                    )
                elif content.type == "audio":
                    unstructured.append(
                        types.AudioContent(
                            type="audio",
                            data=content.data,
                            mimeType=content.mime_type,
                            annotations=_convert_annotations(getattr(content, "annotations", None)),
                            _meta=_convert_meta(getattr(content, "meta", None)),
                        )
                    )
                elif content.type == "resource_link":
                    unstructured.append(
                        types.ResourceLink(
                            type="resource_link",
                            uri=content.uri,
                            name=content.name,
                            description=getattr(content, "description", None),
                            mimeType=getattr(content, "mime_type", None),
                            size=getattr(content, "size", None),
                            _meta=_convert_meta(getattr(content, "meta", None)),
                        )
                    )
                elif content.type == "resource":
                    # EmbeddedResource - pass through the model dump as the MCP SDK type requires complex nested structure
                    unstructured.append(types.EmbeddedResource.model_validate(content.model_dump(by_alias=True, mode="json")))
                else:
                    # Unknown content type - convert to text representation
                    unstructured.append(types.TextContent(type="text", text=orjson.dumps(content.model_dump(by_alias=True, mode="json")).decode()))

            # If the tool produced structured content (ToolResult.structured_content / structuredContent),
            # return a combination (unstructured, structured) so the server can validate against outputSchema.
            # The ToolService may populate structured_content (snake_case) or the model may expose
            # an alias 'structuredContent' when dumped via model_dump(by_alias=True).
            structured = None
            try:
                # Prefer attribute if present
                structured = getattr(result, "structured_content", None)
            except Exception:
                structured = None

            # Fallback to by-alias dump (in case the result is a pydantic model with alias fields)
            if structured is None:
                try:
                    structured = result.model_dump(by_alias=True).get("structuredContent") if hasattr(result, "model_dump") else None
                except Exception:
                    structured = None

            # Compare against None rather than truthiness: an empty dict is still structured
            # output, and dropping it would leave the SDK with nothing to validate against outputSchema.
            if structured is not None:
                return (unstructured, structured)

            return unstructured
    except Exception as e:
        logger.exception("Error calling tool '%s': %s", name, e)
        # Re-raise the exception so the MCP SDK can properly convert it to an error response
        # This ensures error details are propagated to the client instead of returning empty results
        raise
async def _get_request_context_or_default() -> Tuple[str, dict[str, Any], dict[str, Any]]:
    """Resolve ``(server_id, request_headers, user_context)`` for the current MCP call.

    Sources are tried cheapest-first:

    1. Context variables. If ``server_id_var`` holds anything other than the
       ``"default_server_id"`` placeholder, the middleware populated them in this
       async context and they are returned as-is.
    2. The ASGI scope. The middleware stores a dict under
       ``scope[_MCPGATEWAY_CONTEXT_KEY]`` before handing off to the MCP SDK. The SDK
       passes that same scope through to ``mcp_app.request_context.request``, so the
       stored context survives task-group boundaries where ContextVars may be lost.
    3. Re-authentication. The server id is parsed from the URL path, and identity is
       re-derived from the Authorization header or the ``jwt_token`` cookie. This is
       the most expensive path. Anonymous callers get ``{}`` here instead of the
       middleware's ``{"is_authenticated": False, ...}`` shape, as does a failed
       re-authentication.

    The current context-variable values are returned instead in three cases: no MCP
    request context is active (``LookupError``), the context has no request object,
    or step 3 fails for any other reason.

    Returns:
        Tuple[str, dict[str, Any], dict[str, Any]]: A tuple containing:

        - server_id: The resolved server identifier.
        - request_headers: The request headers as a dictionary.
        - user_context: The resolved user context dictionary.
    """
    server_id = server_id_var.get()

    def _from_context_vars(resolved_server_id: str) -> Tuple[str, dict[str, Any], dict[str, Any]]:
        """Pair a server id with the current header and user-context variables.

        Args:
            resolved_server_id: Server id to place first in the tuple.

        Returns:
            ``(resolved_server_id, request_headers_var.get(), user_context_var.get())``.
        """
        return resolved_server_id, request_headers_var.get(), user_context_var.get()

    # 1. Fast path: context variables already carry real (non-placeholder) data.
    if server_id != "default_server_id":
        return _from_context_vars(server_id)

    # 2. Context stashed on the ASGI scope by the middleware.
    mcp_ctx = None
    try:
        mcp_ctx = mcp_app.request_context
        http_request = mcp_ctx.request
        stashed = getattr(http_request, "scope", {}).get(_MCPGATEWAY_CONTEXT_KEY) if http_request else None
        if isinstance(stashed, dict):
            return (
                stashed.get("server_id") or server_id,
                stashed.get("request_headers", {}),
                stashed.get("user_context", {}),
            )
    except LookupError:
        # No active MCP request: the context variables are all we have.
        return _from_context_vars(server_id)
    except Exception as e:
        logger.debug("Failed to read %s from scope: %s", _MCPGATEWAY_CONTEXT_KEY, e)

    # 3. Re-authenticate from the raw request (stateful session path).
    try:
        if mcp_ctx is None:
            # Step 2 failed before binding the context; look it up again.
            mcp_ctx = mcp_app.request_context
        http_request = mcp_ctx.request
        if not http_request:
            logger.warning("No request object found in MCP context")
            return _from_context_vars(server_id)

        if path_match := _SERVER_ID_RE.search(http_request.url.path):
            server_id = path_match.group("server_id")

        headers = dict(http_request.headers)

        # Same token precedence as streamable_http_auth (require_auth_header_first):
        # Authorization header > request cookies > jwt_token parameter.
        auth_header = headers.get("authorization")
        cookie_token = http_request.cookies.get("jwt_token")

        try:
            raw_payload = await require_auth_header_first(auth_header=auth_header, jwt_token=cookie_token, request=http_request)
            # A dict is a decoded JWT and gets normalized to the canonical user-context
            # shape; the "anonymous" string or anything else maps to an empty context.
            user_ctx = _normalize_jwt_payload(raw_payload) if isinstance(raw_payload, dict) else {}
        except Exception as e:
            logger.warning("Failed to recover user context in stateful session: %s", e)
            user_ctx = {}

        return server_id, headers, user_ctx

    except LookupError:
        # Not in a request context
        return _from_context_vars(server_id)
    except Exception as e:
        logger.exception("Error recovering context in stateful session: %s", e)
        return _from_context_vars(server_id)
1370def _normalize_jwt_payload(payload: dict[str, Any]) -> dict[str, Any]:
1371 """Normalize a raw JWT payload to the canonical user context shape.
1373 Converts raw JWT fields (sub, token_use, nested user.is_admin) into the
1374 canonical ``{email, teams, is_admin, is_authenticated, token_use}`` dict that MCP
1375 handlers expect. This mirrors the normalization performed by
1376 ``streamable_http_auth`` so that the stateful-session fallback path in
1377 ``_get_request_context_or_default`` returns an identical shape.
1379 Args:
1380 payload: Raw JWT payload dict from ``require_auth_header_first``.
1382 Returns:
1383 Canonical user context dict with keys email, teams, is_admin, is_authenticated, token_use.
1384 """
1385 email = payload.get("sub") or payload.get("email")
1386 is_admin = payload.get("is_admin", False)
1387 if not is_admin:
1388 user_info = payload.get("user", {})
1389 is_admin = user_info.get("is_admin", False) if isinstance(user_info, dict) else False
1391 token_use = payload.get("token_use")
1392 if token_use == "session": # nosec B105 - Not a password; token_use is a JWT claim type
1393 # Session token: resolve teams from DB/cache
1394 if is_admin:
1395 final_teams = None # Admin bypass
1396 elif email:
1397 # First-Party
1398 from mcpgateway.auth import _resolve_teams_from_db_sync # pylint: disable=import-outside-toplevel
1400 final_teams = _resolve_teams_from_db_sync(email, is_admin=False)
1401 else:
1402 final_teams = [] # No email — public-only
1403 else:
1404 # API token or legacy: use embedded teams from JWT
1405 # First-Party
1406 from mcpgateway.auth import normalize_token_teams # pylint: disable=import-outside-toplevel
1408 final_teams = normalize_token_teams(payload)
1410 user_ctx: dict[str, Any] = {
1411 "email": email,
1412 "teams": final_teams,
1413 "is_admin": is_admin,
1414 "is_authenticated": True,
1415 "token_use": token_use,
1416 }
1417 # Extract scoped permissions from JWT for per-method enforcement
1418 scopes = payload.get("scopes") or {}
1419 scoped_perms = scopes.get("permissions") or [] if isinstance(scopes, dict) else []
1420 if scoped_perms:
1421 user_ctx["scoped_permissions"] = scoped_perms
1422 return user_ctx
@mcp_app.list_tools()
async def list_tools() -> List[types.Tool]:
    """
    Lists all tools available to the MCP Server.

    Supports two modes based on gateway's gateway_mode:
    - 'cache': Returns tools from database (default behavior)
    - 'direct_proxy': Proxies the request directly to the remote MCP server

    Direct proxy mode is only considered when all of these hold:
    - the request is scoped to a server;
    - it carries the X-Context-Forge-Gateway-Id header naming a gateway whose
      gateway_mode is 'direct_proxy';
    - direct proxying is enabled in settings.

    Returns:
        A list of Tool objects containing metadata such as name, description, and input schema.
        Logs and returns an empty list on failure.

    Raises:
        PermissionError: If the caller lacks ``tools.read`` permission.

    Examples:
        >>> # Test list_tools function signature
        >>> import inspect
        >>> sig = inspect.signature(list_tools)
        >>> list(sig.parameters.keys())
        []
        >>> sig.return_annotation
        typing.List[mcp.types.Tool]
    """
    server_id, request_headers, user_context = await _get_request_context_or_default()

    # Token scope cap: deny early if scoped permissions exclude tools.read
    if _should_enforce_streamable_rbac(user_context):
        if not _check_scoped_permission(user_context, "tools.read"):
            raise PermissionError(_ACCESS_DENIED_MSG)

    # Extract filtering parameters from user context
    user_email = user_context.get("email") if user_context else None
    # Use None as default to distinguish "no teams specified" from "empty teams array"
    token_teams = user_context.get("teams") if user_context else None
    is_admin = user_context.get("is_admin", False) if user_context else False

    # Admin bypass - only when token has NO team restrictions (token_teams is None)
    # If token has explicit team scope (even empty [] for public-only), respect it
    if is_admin and token_teams is None:
        user_email = None
        # token_teams stays None (unrestricted)
    elif token_teams is None:
        token_teams = []  # Non-admin without teams = public-only (secure default)

    # Enforce per-server OAuth requirement in permissive mode (defense-in-depth).
    # When mcp_require_auth=True, the middleware already guarantees authentication.
    # Note: OAuthEnforcementUnavailableError is intentionally uncaught here —
    # the middleware (streamable_http_auth) catches it and returns 503. If the
    # middleware is somehow bypassed, an uncaught 500 is acceptable and will be
    # logged by the ASGI server.
    if not settings.mcp_require_auth:
        await _check_server_oauth_enforcement(server_id, user_context)

    def _to_sdk_tools(tools: Any) -> List[types.Tool]:
        """Convert gateway tool records into MCP SDK ``Tool`` objects.

        Args:
            tools: Iterable of tool records exposing name, description, input_schema,
                output_schema and annotations.

        Returns:
            One ``types.Tool`` per record, in the same order.
        """
        return [types.Tool(name=tool.name, description=tool.description, inputSchema=tool.input_schema, outputSchema=tool.output_schema, annotations=tool.annotations) for tool in tools]

    if server_id:
        try:
            async with get_db() as db:
                # Third-Party
                from sqlalchemy import select  # pylint: disable=import-outside-toplevel

                # Check for X-Context-Forge-Gateway-Id header first - if present, try direct proxy mode
                gateway_id = extract_gateway_id_from_headers(request_headers)

                # If X-Context-Forge-Gateway-Id is provided, check if that gateway is in direct_proxy mode
                if gateway_id:
                    # First-Party
                    from mcpgateway.db import Gateway as DbGateway  # pylint: disable=import-outside-toplevel

                    gateway = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none()
                    if gateway and getattr(gateway, "gateway_mode", "cache") == "direct_proxy" and settings.mcpgateway_direct_proxy_enabled:
                        # SECURITY: Check gateway access before allowing direct proxy
                        if not await check_gateway_access(db, gateway, user_email, token_teams):
                            logger.warning("Access denied to gateway %s in direct_proxy mode for user %s", gateway_id, user_email)
                            return []  # Return empty list for unauthorized access

                        # Direct proxy mode: forward request to remote MCP server
                        # Get _meta from request context if available
                        meta = None
                        try:
                            request_ctx = mcp_app.request_context
                            meta = request_ctx.meta
                            logger.info(
                                "[LIST TOOLS] Using direct_proxy mode for server %s, gateway %s (from %s header). Meta Attached: %s",
                                server_id,
                                gateway.id,
                                GATEWAY_ID_HEADER,
                                meta is not None,
                            )
                        except (LookupError, AttributeError) as e:
                            logger.debug("No request context available for _meta extraction: %s", e)

                        return await _proxy_list_tools_to_gateway(gateway, request_headers, user_context, meta)
                    if gateway:
                        logger.debug("Gateway %s found but not in direct_proxy mode (mode: %s), using cache mode", gateway_id, getattr(gateway, "gateway_mode", "cache"))
                    else:
                        logger.warning("Gateway %s specified in %s header not found", gateway_id, GATEWAY_ID_HEADER)

                # Check if server exists for cache mode
                # First-Party
                from mcpgateway.db import Server as DbServer  # pylint: disable=import-outside-toplevel

                server = db.execute(select(DbServer).where(DbServer.id == server_id)).scalar_one_or_none()
                if not server:
                    logger.warning("Server %s not found in database", server_id)
                    return []

                # Default cache mode: use database
                tools = await tool_service.list_server_tools(db, server_id, user_email=user_email, token_teams=token_teams, _request_headers=request_headers)
                return _to_sdk_tools(tools)
        except Exception as e:
            # logger.exception (not logger.error) keeps the traceback, matching the global branch below.
            logger.exception("Error listing tools:%s", e)
            return []
    else:
        try:
            async with get_db() as db:
                tools, _ = await tool_service.list_tools(db, include_inactive=False, limit=0, user_email=user_email, token_teams=token_teams, _request_headers=request_headers)
                return _to_sdk_tools(tools)
        except Exception as e:
            logger.exception("Error listing tools:%s", e)
            return []
@mcp_app.list_prompts()
async def list_prompts() -> List[types.Prompt]:
    """
    Lists all prompts available to the MCP Server.

    When the request is scoped to a server, only that server's prompts are listed.
    Otherwise every active prompt visible to the caller is returned (no limit).

    Returns:
        A list of Prompt objects containing metadata such as name, description, and arguments.
        Logs and returns an empty list on failure.

    Raises:
        PermissionError: If the user context indicates insufficient permissions (e.g., missing "prompts.read" scope).

    Examples:
        >>> import inspect
        >>> sig = inspect.signature(list_prompts)
        >>> list(sig.parameters.keys())
        []
        >>> sig.return_annotation
        typing.List[mcp.types.Prompt]
    """
    server_id, _, user_context = await _get_request_context_or_default()

    # Token scope cap: refuse outright when the token's scoped permissions exclude prompts.read.
    if _should_enforce_streamable_rbac(user_context) and not _check_scoped_permission(user_context, "prompts.read"):
        raise PermissionError(_ACCESS_DENIED_MSG)

    # Filtering inputs. Teams stays None (not []) when the claim is absent, so that
    # "no teams specified" remains distinguishable from "explicitly no teams".
    if user_context:
        user_email = user_context.get("email")
        token_teams = user_context.get("teams")
        is_admin = user_context.get("is_admin", False)
    else:
        user_email, token_teams, is_admin = None, None, False

    if token_teams is None:
        if is_admin:
            # Admin bypass applies only when the token carries no team scope at all;
            # teams stays None (unrestricted).
            user_email = None
        else:
            # Non-admin without a team claim: public-only (secure default).
            token_teams = []

    # Enforce per-server OAuth requirement in permissive mode (defense-in-depth).
    # When mcp_require_auth=True, the middleware already guarantees authentication.
    # Note: OAuthEnforcementUnavailableError is intentionally uncaught here —
    # the middleware (streamable_http_auth) catches it and returns 503. If the
    # middleware is somehow bypassed, an uncaught 500 is acceptable and will be
    # logged by the ASGI server.
    if not settings.mcp_require_auth:
        await _check_server_oauth_enforcement(server_id, user_context)

    # The two listing paths log failures with slightly different templates; keep both.
    failure_template = "Error listing Prompts:%s" if server_id else "Error listing prompts:%s"
    try:
        async with get_db() as db:
            if server_id:
                found = await prompt_service.list_server_prompts(db, server_id, user_email=user_email, token_teams=token_teams)
            else:
                found, _ = await prompt_service.list_prompts(db, include_inactive=False, limit=0, user_email=user_email, token_teams=token_teams)
            return [types.Prompt(name=item.name, description=item.description, arguments=item.arguments) for item in found]
    except Exception as e:
        logger.exception(failure_template, e)
        return []
@mcp_app.get_prompt()
async def get_prompt(prompt_id: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
    """
    Render a prompt for the current caller, optionally substituting arguments.

    The caller's token scope is checked first. Team visibility is then derived
    from the request's user context, and the prompt is fetched and rendered
    through ``prompt_service.get_prompt``.

    Args:
        prompt_id (str): Identifier of the prompt to render.
        arguments (Optional[dict[str, str]]): Values to substitute into the prompt, if any.

    Returns:
        GetPromptResult: The rendered prompt messages plus the prompt description.
        An empty list is returned instead when rendering fails or the prompt
        produced no messages. Every such case is logged.

    Raises:
        PermissionError: If RBAC enforcement applies and the token scope lacks ``prompts.read``.

    Examples:
        >>> import inspect
        >>> sig = inspect.signature(get_prompt)
        >>> list(sig.parameters.keys())
        ['prompt_id', 'arguments']
        >>> sig.return_annotation.__name__
        'GetPromptResult'
    """
    server_id, _, user_context = await _get_request_context_or_default()

    # Token scope cap is evaluated before anything touches the database.
    if _should_enforce_streamable_rbac(user_context) and not _check_scoped_permission(user_context, "prompts.read"):
        raise PermissionError(_ACCESS_DENIED_MSG)

    # A missing/empty user context behaves like an empty claim set.
    claims = user_context or {}
    user_email = claims.get("email")
    token_teams = claims.get("teams")
    is_admin = claims.get("is_admin", False)

    # Only a token WITHOUT a team claim is normalized here; an explicit team list
    # (even an empty one, meaning public-only) is always respected as-is.
    if token_teams is None:
        if is_admin:
            # Unrestricted admin: clear the email and leave token_teams as None.
            user_email = None
        else:
            # Non-admin with no team claim: secure default of public-only.
            token_teams = []

    # Defense-in-depth per-server OAuth check for permissive mode. When
    # mcp_require_auth is True the middleware has already authenticated the
    # caller. OAuthEnforcementUnavailableError is deliberately not caught here:
    # the streamable_http_auth middleware maps it to a 503.
    if not settings.mcp_require_auth:
        await _check_server_oauth_enforcement(server_id, user_context)

    # Forward the request's _meta (serialized) to the prompt service when present.
    meta_data = None
    try:
        request_ctx = mcp_app.request_context
        if request_ctx and request_ctx.meta is not None:
            meta_data = request_ctx.meta.model_dump()
    except LookupError:
        # No request context is active (e.g. when called outside a live MCP request).
        logger.debug("No active request context found")

    try:
        async with get_db() as db:
            # Service errors are handled inside the session scope so the DB
            # context manager exits normally.
            try:
                rendered = await prompt_service.get_prompt(
                    db=db,
                    prompt_id=prompt_id,
                    arguments=arguments,
                    user=user_email,
                    server_id=server_id,
                    token_teams=token_teams,
                    _meta_data=meta_data,
                )
            except Exception as exc:
                logger.exception("Error getting prompt '%s': %s", prompt_id, exc)
                return []
            if rendered and rendered.messages:
                return types.GetPromptResult(
                    messages=[msg.model_dump() for msg in rendered.messages],
                    description=rendered.description,
                )
            logger.warning("No content returned by prompt: %s", prompt_id)
            return []
    except Exception as exc:
        logger.exception("Error getting prompt '%s': %s", prompt_id, exc)
        return []
@mcp_app.list_resources()
async def list_resources() -> List[types.Resource]:
    """
    Lists all resources available to the MCP Server.

    When a server id is present in the request context, the listing is scoped to
    that server. If the request also carries the gateway-id header
    (``GATEWAY_ID_HEADER``) and that gateway is in ``direct_proxy`` mode (with
    ``settings.mcpgateway_direct_proxy_enabled``), the listing is forwarded to
    the remote MCP server after a gateway access check. Otherwise resources are
    read from the database through ``resource_service``.

    Returns:
        A list of Resource objects containing metadata such as uri, name, description, and mimeType.
        Logs and returns an empty list on failure or when gateway access is denied.

    Raises:
        PermissionError: If the user context indicates insufficient permissions (e.g., missing "resources.read" scope).

    Examples:
        >>> import inspect
        >>> sig = inspect.signature(list_resources)
        >>> list(sig.parameters.keys())
        []
        >>> sig.return_annotation
        typing.List[mcp.types.Resource]
    """
    server_id, request_headers, user_context = await _get_request_context_or_default()

    # Token scope cap: deny early if scoped permissions exclude resources.read
    if _should_enforce_streamable_rbac(user_context):
        if not _check_scoped_permission(user_context, "resources.read"):
            raise PermissionError(_ACCESS_DENIED_MSG)

    # Extract filtering parameters from user context
    user_email = user_context.get("email") if user_context else None
    # Use None as default to distinguish "no teams specified" from "empty teams array"
    token_teams = user_context.get("teams") if user_context else None
    is_admin = user_context.get("is_admin", False) if user_context else False

    # Admin bypass - only when token has NO team restrictions (token_teams is None)
    # If token has explicit team scope (even empty [] for public-only), respect it
    if is_admin and token_teams is None:
        user_email = None
        # token_teams stays None (unrestricted)
    elif token_teams is None:
        token_teams = []  # Non-admin without teams = public-only (secure default)

    # Enforce per-server OAuth requirement in permissive mode (defense-in-depth).
    # When mcp_require_auth=True, the middleware already guarantees authentication.
    # Note: OAuthEnforcementUnavailableError is intentionally uncaught here —
    # the middleware (streamable_http_auth) catches it and returns 503. If the
    # middleware is somehow bypassed, an uncaught 500 is acceptable and will be
    # logged by the ASGI server.
    if not settings.mcp_require_auth:
        await _check_server_oauth_enforcement(server_id, user_context)

    if server_id:
        # Server-scoped listing (optionally via direct proxy).
        try:
            async with get_db() as db:
                # Check for X-Context-Forge-Gateway-Id header first for direct proxy mode
                gateway_id = extract_gateway_id_from_headers(request_headers)

                # If X-Context-Forge-Gateway-Id is provided, check if that gateway is in direct_proxy mode
                if gateway_id:
                    # Third-Party
                    from sqlalchemy import select  # pylint: disable=import-outside-toplevel

                    # First-Party
                    from mcpgateway.db import Gateway as DbGateway  # pylint: disable=import-outside-toplevel

                    gateway = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none()
                    # Direct proxy needs all three: the gateway exists, it is
                    # configured for direct_proxy, and the feature flag is on.
                    if gateway and gateway.gateway_mode == "direct_proxy" and settings.mcpgateway_direct_proxy_enabled:
                        # SECURITY: Check gateway access before allowing direct proxy
                        if not await check_gateway_access(db, gateway, user_email, token_teams):
                            logger.warning("Access denied to gateway %s in direct_proxy mode for user %s", gateway_id, user_email)
                            return []  # Return empty list for unauthorized access

                        # Direct proxy mode: forward request to remote MCP server
                        # Get _meta from request context if available
                        meta = None
                        try:
                            request_ctx = mcp_app.request_context
                            meta = request_ctx.meta
                            # Only the presence of _meta is logged, not its contents.
                            logger.info(
                                "[LIST RESOURCES] Using direct_proxy mode for server %s, gateway %s (from %s header). Meta Attached: %s",
                                server_id,
                                gateway.id,
                                GATEWAY_ID_HEADER,
                                meta is not None,
                            )
                        except (LookupError, AttributeError) as e:
                            logger.debug("No request context available for _meta extraction: %s", e)

                        return await _proxy_list_resources_to_gateway(gateway, request_headers, user_context, meta)
                    # Header present but direct proxy not applicable: log why and
                    # fall through to the database-backed listing below.
                    if gateway:
                        logger.debug("Gateway %s found but not in direct_proxy mode (mode: %s), using cache mode", gateway_id, gateway.gateway_mode)
                    else:
                        logger.warning("Gateway %s specified in %s header not found", gateway_id, GATEWAY_ID_HEADER)

                # Default cache mode: use database
                resources = await resource_service.list_server_resources(db, server_id, user_email=user_email, token_teams=token_teams)
                return [types.Resource(uri=resource.uri, name=resource.name, description=resource.description, mimeType=resource.mime_type) for resource in resources]
        except Exception as e:
            logger.exception("Error listing Resources:%s", e)
            return []
    else:
        # No server in context: list every resource visible to the caller.
        try:
            async with get_db() as db:
                # NOTE(review): limit=0 presumably disables pagination in
                # resource_service.list_resources — confirm against the service.
                resources, _ = await resource_service.list_resources(db, include_inactive=False, limit=0, user_email=user_email, token_teams=token_teams)
                return [types.Resource(uri=resource.uri, name=resource.name, description=resource.description, mimeType=resource.mime_type) for resource in resources]
        except Exception as e:
            logger.exception("Error listing resources:%s", e)
            return []
@mcp_app.read_resource()
async def read_resource(resource_uri: str) -> Union[str, bytes]:
    """
    Reads the content of a resource specified by its URI.

    When the request carries the gateway-id header (``GATEWAY_ID_HEADER``) and
    that gateway is in ``direct_proxy`` mode (with
    ``settings.mcpgateway_direct_proxy_enabled``), the read is forwarded to the
    remote MCP server after a gateway access check. Otherwise the resource is
    read through ``resource_service.read_resource``.

    Args:
        resource_uri (str): The URI of the resource to read.

    Returns:
        Union[str, bytes]: The content of the resource as text or binary data.
        Returns an empty string on failure, when gateway access is denied, or if no content is found.

    Raises:
        PermissionError: If the user does not have the required permissions to read resources.

    Logs exceptions if any errors occur during reading.

    Examples:
        >>> import inspect
        >>> sig = inspect.signature(read_resource)
        >>> list(sig.parameters.keys())
        ['resource_uri']
        >>> sig.return_annotation
        typing.Union[str, bytes]
    """
    server_id, request_headers, user_context = await _get_request_context_or_default()

    # Token scope cap: deny early if scoped permissions exclude resources.read
    if _should_enforce_streamable_rbac(user_context):
        if not _check_scoped_permission(user_context, "resources.read"):
            raise PermissionError(_ACCESS_DENIED_MSG)

    # Extract authorization parameters from user context (same pattern as list_resources)
    user_email = user_context.get("email") if user_context else None
    token_teams = user_context.get("teams") if user_context else None
    is_admin = user_context.get("is_admin", False) if user_context else False

    # Admin bypass - only when token has NO team restrictions (token_teams is None)
    if is_admin and token_teams is None:
        user_email = None
        # token_teams stays None (unrestricted)
    elif token_teams is None:
        token_teams = []  # Non-admin without teams = public-only (secure default)

    # Enforce per-server OAuth requirement in permissive mode (defense-in-depth).
    # When mcp_require_auth=True, the middleware already guarantees authentication.
    # Note: OAuthEnforcementUnavailableError is intentionally uncaught here —
    # the middleware (streamable_http_auth) catches it and returns 503. If the
    # middleware is somehow bypassed, an uncaught 500 is acceptable and will be
    # logged by the ASGI server.
    if not settings.mcp_require_auth:
        await _check_server_oauth_enforcement(server_id, user_context)

    meta_data = None
    # Extract _meta from request context if available (serialized form, used by cache mode)
    try:
        ctx = mcp_app.request_context
        if ctx and ctx.meta is not None:
            meta_data = ctx.meta.model_dump()
    except LookupError:
        # request_context might not be active in some edge cases (e.g. tests)
        logger.debug("No active request context found")

    try:
        async with get_db() as db:
            # Check for X-Context-Forge-Gateway-Id header first for direct proxy mode
            gateway_id = extract_gateway_id_from_headers(request_headers)

            # If X-Context-Forge-Gateway-Id is provided, check if that gateway is in direct_proxy mode
            if gateway_id:
                # Third-Party
                from sqlalchemy import select  # pylint: disable=import-outside-toplevel

                # First-Party
                from mcpgateway.db import Gateway as DbGateway  # pylint: disable=import-outside-toplevel

                gateway = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none()
                if gateway and gateway.gateway_mode == "direct_proxy" and settings.mcpgateway_direct_proxy_enabled:
                    # SECURITY: Check gateway access before allowing direct proxy
                    if not await check_gateway_access(db, gateway, user_email, token_teams):
                        logger.warning("Access denied to gateway %s in direct_proxy mode for user %s", gateway_id, user_email)
                        return ""

                    # Direct proxy mode: forward request to remote MCP server
                    # Get _meta from request context if available (model object, forwarded as-is)
                    meta = None
                    try:
                        request_ctx = mcp_app.request_context
                        meta = request_ctx.meta
                        # Log only whether _meta is present, never its contents:
                        # _meta is client-supplied and may carry arbitrary data.
                        logger.info(
                            "Using direct_proxy mode for resources/read %s, server %s, gateway %s (from %s header). Meta Attached: %s",
                            resource_uri,
                            server_id,
                            gateway.id,
                            GATEWAY_ID_HEADER,
                            meta is not None,
                        )
                    except (LookupError, AttributeError) as e:
                        logger.debug("No request context available for _meta extraction: %s", e)

                    contents = await _proxy_read_resource_to_gateway(gateway, str(resource_uri), user_context, meta)
                    if contents:
                        # Return first content (text or blob)
                        first_content = contents[0]
                        if hasattr(first_content, "text"):
                            return first_content.text
                        if hasattr(first_content, "blob"):
                            return first_content.blob
                    return ""
                # Header present but direct proxy not applicable: fall through to cache mode.
                if gateway:
                    logger.debug("Gateway %s found but not in direct_proxy mode (mode: %s), using cache mode", gateway_id, gateway.gateway_mode)
                else:
                    logger.warning("Gateway %s specified in %s header not found", gateway_id, GATEWAY_ID_HEADER)

            # Default cache mode: use database.
            # Service errors are handled inside the session scope so the DB
            # context manager exits normally.
            try:
                result = await resource_service.read_resource(
                    db=db,
                    resource_uri=str(resource_uri),
                    user=user_email,
                    server_id=server_id,
                    token_teams=token_teams,
                    meta_data=meta_data,
                )
            except Exception as e:
                logger.exception("Error reading resource '%s': %s", resource_uri, e)
                return ""

            # Return blob content if available (binary resources)
            if result and result.blob:
                return result.blob

            # Return text content if available (text resources)
            if result and result.text:
                return result.text

            # No content found
            logger.warning("No content returned by resource: %s", resource_uri)
            return ""
    except Exception as e:
        logger.exception("Error reading resource '%s': %s", resource_uri, e)
        return ""
@mcp_app.list_resource_templates()
async def list_resource_templates() -> List[Dict[str, Any]]:
    """
    List the resource templates visible to the current caller.

    Visibility follows the same rules as ``list_resources``. An admin token
    without a team claim is unrestricted, a non-admin token without a team claim
    sees public templates only, and an explicit team list is honoured as given.

    Returns:
        List[Dict[str, Any]]: One dict per template, produced by
        ``model_dump(by_alias=True)``. Any failure is logged and yields an empty list.

    Raises:
        PermissionError: If the caller lacks ``resources.read`` permission.

    Examples:
        >>> import inspect
        >>> sig = inspect.signature(list_resource_templates)
        >>> list(sig.parameters.keys())
        []
        >>> sig.return_annotation.__origin__.__name__
        'list'
    """
    server_id, _, user_context = await _get_request_context_or_default()

    # Token scope cap: deny before any data access.
    if _should_enforce_streamable_rbac(user_context) and not _check_scoped_permission(user_context, "resources.read"):
        raise PermissionError(_ACCESS_DENIED_MSG)

    # A missing/empty user context behaves like an empty claim set.
    claims = user_context or {}
    user_email = claims.get("email")
    token_teams = claims.get("teams")
    is_admin = claims.get("is_admin", False)

    if token_teams is None:
        if is_admin:
            # Unrestricted admin: clear the email and leave token_teams as None.
            user_email = None
        else:
            # Non-admin with no team claim: secure default of public-only.
            token_teams = []

    # Defense-in-depth per-server OAuth check for permissive mode;
    # OAuthEnforcementUnavailableError intentionally propagates (middleware returns 503).
    if not settings.mcp_require_auth:
        await _check_server_oauth_enforcement(server_id, user_context)

    try:
        async with get_db() as db:
            # Service errors are handled inside the session scope so the DB
            # context manager exits normally.
            try:
                templates = await resource_service.list_resource_templates(
                    db,
                    user_email=user_email,
                    token_teams=token_teams,
                    server_id=server_id,
                )
                serialized = [tpl.model_dump(by_alias=True) for tpl in templates]
            except Exception as exc:
                logger.exception("Error listing resource templates: %s", exc)
                serialized = []
            return serialized
    except Exception as exc:
        logger.exception("Error listing resource templates: %s", exc)
        return []
@mcp_app.set_logging_level()
async def set_logging_level(level: types.LoggingLevel) -> types.EmptyResult:
    """
    Apply an MCP set-level request to the gateway's logging service.

    MCP defines eight syslog-style levels, and the gateway's ``LogLevel`` has
    fewer. ``notice`` therefore maps to INFO, and ``alert``/``emergency`` map to
    CRITICAL. Any unrecognised value falls back to INFO.

    Args:
        level (types.LoggingLevel): The desired logging level (debug, info, notice, warning, error, critical, alert, emergency).

    Returns:
        types.EmptyResult: An empty result. It is returned on success and also
        when applying the level fails for any reason other than PermissionError
        (that failure is logged).

    Examples:
        >>> import inspect
        >>> sig = inspect.signature(set_logging_level)
        >>> list(sig.parameters.keys())
        ['level']

    Raises:
        PermissionError: If RBAC enforcement applies and the caller lacks
            ``admin.system_config`` in its token scope or via the RBAC check,
            or if the logging service itself raises PermissionError.
    """
    server_id, _, user_context = await _get_request_context_or_default()

    # Defense-in-depth per-server OAuth check for permissive mode;
    # OAuthEnforcementUnavailableError intentionally propagates (middleware returns 503).
    if not settings.mcp_require_auth:
        await _check_server_oauth_enforcement(server_id, user_context)

    if _should_enforce_streamable_rbac(user_context):
        required_permission = "admin.system_config"
        # Layer 1 (token scope cap) short-circuits before Layer 2 (RBAC lookup).
        if not _check_scoped_permission(user_context, required_permission) or not await _check_streamable_permission(
            user_context=user_context,
            permission=required_permission,
        ):
            raise PermissionError(_ACCESS_DENIED_MSG)

    try:
        # Three MCP levels collapse onto CRITICAL; the rest map one-to-one
        # (with "notice" folded into INFO).
        severity_by_name = {name: LogLevel.CRITICAL for name in ("critical", "alert", "emergency")}
        severity_by_name.update(
            {
                "debug": LogLevel.DEBUG,
                "info": LogLevel.INFO,
                "notice": LogLevel.INFO,
                "warning": LogLevel.WARNING,
                "error": LogLevel.ERROR,
            }
        )
        await logging_service.set_level(severity_by_name.get(level.lower(), LogLevel.INFO))
        return types.EmptyResult()
    except PermissionError:
        raise
    except Exception as exc:
        logger.exception("Error setting logging level: %s", exc)
        return types.EmptyResult()
def _coerce_mcp_completion_payload(payload: Any) -> Optional[types.Completion]:
    """
    Convert a completion-shaped payload into an ``mcp.types.Completion``.

    Args:
        payload: One of the following:
            - an ``mcp.types.Completion``, returned unchanged;
            - a plain dict of completion fields;
            - any object exposing ``model_dump()``, such as a gateway-side Completion model.

    Returns:
        Optional[types.Completion]: The converted completion, or ``None`` when
        ``payload`` has none of the supported shapes (e.g. ``None``).
    """
    if isinstance(payload, types.Completion):
        return payload
    if isinstance(payload, dict):
        return types.Completion(**payload)
    if hasattr(payload, "model_dump"):
        return types.Completion(**payload.model_dump())
    return None


@mcp_app.completion()
async def complete(
    ref: Union[types.PromptReference, types.ResourceTemplateReference],
    argument: types.CompletionArgument,
    context: Optional[types.CompletionContext] = None,
) -> types.Completion:
    """
    Provides argument completion suggestions for prompts or resources.

    Args:
        ref: A reference to a prompt or a resource template. Can be either
            `types.PromptReference` or `types.ResourceTemplateReference`.
        argument: The argument being completed (its name and the partial value
            typed so far).
        context: Optional contextual information for the completion request,
            such as values of previously resolved arguments.

    Returns:
        types.Completion: The completion values plus ``total``/``hasMore`` metadata.
        It is normalized from whatever shape ``completion_service.handle_completion``
        returns:
        - a dict, either ``{"completion": ...}`` or the completion fields themselves;
        - an object with a ``completion`` attribute, possibly nested one level;
        - a ``types.Completion``.
        Unrecognised shapes and internal failures yield an empty completion.

    Raises:
        PermissionError: If the caller lacks ``tools.read`` permission.
    """
    # Derive caller visibility scope from the current request context.
    server_id, _, user_context = await _get_request_context_or_default()

    # Token scope cap: deny early if scoped permissions exclude tools.read
    if _should_enforce_streamable_rbac(user_context):
        if not _check_scoped_permission(user_context, "tools.read"):
            raise PermissionError(_ACCESS_DENIED_MSG)

    # Enforce per-server OAuth requirement in permissive mode (defense-in-depth).
    # When mcp_require_auth=True, the middleware already guarantees authentication.
    # Note: OAuthEnforcementUnavailableError is intentionally uncaught here —
    # the middleware (streamable_http_auth) catches it and returns 503. If the
    # middleware is somehow bypassed, an uncaught 500 is acceptable and will be
    # logged by the ASGI server.
    if not settings.mcp_require_auth:
        await _check_server_oauth_enforcement(server_id, user_context)

    try:
        user_email = user_context.get("email") if user_context else None
        token_teams = user_context.get("teams") if user_context else None
        is_admin = user_context.get("is_admin", False) if user_context else False

        # Admin bypass only for explicit unrestricted context; otherwise secure default.
        if is_admin and token_teams is None:
            user_email = None
        elif token_teams is None:
            token_teams = []  # Non-admin without explicit teams -> public-only

        async with get_db() as db:
            params = {
                "ref": ref.model_dump() if hasattr(ref, "model_dump") else ref,
                "argument": argument.model_dump() if hasattr(argument, "model_dump") else argument,
                "context": context.model_dump() if hasattr(context, "model_dump") else context,
            }

            result = await completion_service.handle_completion(
                db,
                params,
                user_email=user_email,
                token_teams=token_teams,
            )

            # Normalize the service result into an MCP Completion (inside the DB
            # scope, so a normalization error is seen by the session context).
            if isinstance(result, types.Completion):
                return result
            if isinstance(result, dict):
                # Either {"completion": <payload>} or the completion fields themselves.
                payload = result.get("completion", result)
            elif hasattr(result, "completion"):
                payload = result.completion
                # Tolerate a CompleteResult nested inside another CompleteResult.
                if hasattr(payload, "completion"):
                    payload = payload.completion
            else:
                payload = None

            completion = _coerce_mcp_completion_payload(payload)
            if completion is not None:
                return completion

            # Fallback: return empty completion
            return types.Completion(values=[], total=0, hasMore=False)

    except Exception as e:
        logger.exception("Error handling completion: %s", e)
        return types.Completion(values=[], total=0, hasMore=False)
2198class SessionManagerWrapper:
2199 """
2200 Wrapper class for managing the lifecycle of a StreamableHTTPSessionManager instance.
2201 Provides start, stop, and request handling methods.
2203 Examples:
2204 >>> # Test SessionManagerWrapper initialization
2205 >>> wrapper = SessionManagerWrapper()
2206 >>> wrapper
2207 <mcpgateway.transports.streamablehttp_transport.SessionManagerWrapper object at ...>
2208 >>> hasattr(wrapper, 'session_manager')
2209 True
2210 >>> hasattr(wrapper, 'stack')
2211 True
2212 >>> isinstance(wrapper.stack, AsyncExitStack)
2213 True
2214 """
2216 def __init__(self) -> None:
2217 """
2218 Initializes the session manager and the exit stack used for managing its lifecycle.
2220 Examples:
2221 >>> # Test initialization
2222 >>> wrapper = SessionManagerWrapper()
2223 >>> wrapper.session_manager is not None
2224 True
2225 >>> wrapper.stack is not None
2226 True
2227 """
2229 if settings.use_stateful_sessions:
2230 # Use Redis event store for single-worker stateful deployments
2231 if settings.cache_type == "redis" and settings.redis_url:
2232 event_store = RedisEventStore(max_events_per_stream=settings.streamable_http_max_events_per_stream, ttl=settings.streamable_http_event_ttl)
2233 logger.debug("Using RedisEventStore for stateful sessions (single-worker)")
2234 else:
2235 # Fall back to in-memory for single-worker or when Redis not available
2236 event_store = InMemoryEventStore()
2237 logger.warning("Using InMemoryEventStore - only works with single worker!")
2238 stateless = False
2239 else:
2240 event_store = None
2241 stateless = True
2243 self.session_manager = StreamableHTTPSessionManager(
2244 app=mcp_app,
2245 event_store=event_store,
2246 json_response=settings.json_response_enabled,
2247 stateless=stateless,
2248 )
2249 self.stack = AsyncExitStack()
2251 async def initialize(self) -> None:
2252 """
2253 Starts the Streamable HTTP session manager context.
2255 Examples:
2256 >>> # Test initialize method exists
2257 >>> wrapper = SessionManagerWrapper()
2258 >>> hasattr(wrapper, 'initialize')
2259 True
2260 >>> callable(wrapper.initialize)
2261 True
2262 """
2263 logger.debug("Initializing Streamable HTTP service")
2264 await self.stack.enter_async_context(self.session_manager.run())
2266 async def shutdown(self) -> None:
2267 """
2268 Gracefully shuts down the Streamable HTTP session manager.
2270 Examples:
2271 >>> # Test shutdown method exists
2272 >>> wrapper = SessionManagerWrapper()
2273 >>> hasattr(wrapper, 'shutdown')
2274 True
2275 >>> callable(wrapper.shutdown)
2276 True
2277 """
2278 logger.debug("Stopping Streamable HTTP Session Manager...")
2279 await self.stack.aclose()
2281 async def handle_streamable_http(self, scope: Scope, receive: Receive, send: Send) -> None:
2282 """
2283 Forwards an incoming ASGI request to the streamable HTTP session manager.
2285 Args:
2286 scope (Scope): ASGI scope object containing connection information.
2287 receive (Receive): ASGI receive callable.
2288 send (Send): ASGI send callable.
2290 Raises:
2291 Exception: Any exception raised during request handling is logged.
2293 Logs any exceptions that occur during request handling.
2295 Examples:
2296 >>> # Test handle_streamable_http method exists
2297 >>> wrapper = SessionManagerWrapper()
2298 >>> hasattr(wrapper, 'handle_streamable_http')
2299 True
2300 >>> callable(wrapper.handle_streamable_http)
2301 True
2303 >>> # Test method signature
2304 >>> import inspect
2305 >>> sig = inspect.signature(wrapper.handle_streamable_http)
2306 >>> list(sig.parameters.keys())
2307 ['scope', 'receive', 'send']
2308 """
2310 path = scope["modified_path"]
2311 # Uses precompiled regex for server ID extraction
2312 match = _SERVER_ID_RE.search(path)
2314 # Extract request headers from scope (ASGI provides bytes; normalize to lowercase for lookup).
2315 raw_headers = scope.get("headers") or []
2316 headers: dict[str, str] = {}
2317 for item in raw_headers:
2318 if not isinstance(item, (tuple, list)) or len(item) != 2:
2319 continue
2320 k, v = item
2321 if not isinstance(k, (bytes, bytearray)) or not isinstance(v, (bytes, bytearray)):
2322 continue
2323 # latin-1 is a byte-preserving decode; safe for arbitrary header bytes.
2324 headers[k.decode("latin-1").lower()] = v.decode("latin-1")
2326 # Log session info for debugging stateful sessions
2327 mcp_session_id = headers.get("mcp-session-id", "not-provided")
2328 method = scope.get("method", "UNKNOWN")
2329 query_string = scope.get("query_string", b"").decode("utf-8")
2330 logger.debug("[STATEFUL] Streamable HTTP %s %s | MCP-Session-Id: %s | Query: %s | Stateful: %s", method, path, mcp_session_id, query_string, settings.use_stateful_sessions)
2332 # Note: mcp-session-id from client is used for gateway-internal session affinity
2333 # routing (stored in request_headers_var), but is NOT renamed or forwarded to
2334 # upstream servers - it's a gateway-side concept, not an end-to-end semantic header
2336 # Multi-worker session affinity: check if we should forward to another worker
2337 # This must happen BEFORE the SDK's session manager handles the request
2338 # Only trust x-forwarded-internally from loopback to prevent external spoofing
2339 _client = scope.get("client")
2340 _client_host = _client[0] if _client else None
2341 _from_loopback = _client_host in ("127.0.0.1", "::1") if _client_host else False
2342 is_internally_forwarded = _from_loopback and headers.get("x-forwarded-internally") == "true"
2344 if settings.mcpgateway_session_affinity_enabled and mcp_session_id != "not-provided":
2345 try:
2346 # First-Party
2347 from mcpgateway.services.mcp_session_pool import MCPSessionPool # pylint: disable=import-outside-toplevel
2349 if not MCPSessionPool.is_valid_mcp_session_id(mcp_session_id):
2350 logger.debug("Invalid MCP session id on Streamable HTTP request, skipping affinity")
2351 mcp_session_id = "not-provided"
2352 except Exception:
2353 mcp_session_id = "not-provided"
2355 # Log session manager ID for debugging
2356 logger.debug("[SESSION_MGR_DEBUG] Manager ID: %s", id(self.session_manager))
2358 # Enforce server access parity for server-scoped Streamable HTTP MCP routes.
2359 # This mirrors /servers/{id}/sse and /servers/{id}/message guards.
2360 user_context = user_context_var.get()
2361 if match and _should_enforce_streamable_rbac(user_context):
2362 _is_session = user_context.get("token_use") == "session" if user_context else False
2363 has_server_access = await _check_streamable_permission(
2364 user_context=user_context,
2365 permission="servers.use",
2366 check_any_team=_is_session,
2367 )
2368 if not has_server_access:
2369 response = ORJSONResponse(
2370 {"detail": _ACCESS_DENIED_MSG},
2371 status_code=HTTP_403_FORBIDDEN,
2372 )
2373 await response(scope, receive, send)
2374 return
2376 if is_internally_forwarded:
2377 logger.debug("[HTTP_AFFINITY_FORWARDED] Received forwarded request | Method: %s | Session: %s", method, mcp_session_id)
2379 # Only route POST requests with JSON-RPC body to /rpc
2380 # DELETE and other methods should return success (session cleanup is local)
2381 if method != "POST":
2382 logger.debug("[HTTP_AFFINITY_FORWARDED] Non-POST method, returning 200 OK")
2383 await send({"type": "http.response.start", "status": 200, "headers": [(b"content-type", b"application/json")]})
2384 await send({"type": "http.response.body", "body": b'{"jsonrpc":"2.0","result":{}}'})
2385 return
2387 # For POST requests, bypass SDK session manager and use /rpc directly
2388 # This avoids SDK's session cleanup issues while maintaining stateful upstream connections
2389 try:
2390 # Read request body
2391 body_parts = []
2392 while True:
2393 message = await receive()
2394 if message["type"] == "http.request":
2395 body_parts.append(message.get("body", b""))
2396 if not message.get("more_body", False):
2397 break
2398 elif message["type"] == "http.disconnect":
2399 return
2400 body = b"".join(body_parts)
2402 if not body:
2403 logger.debug("[HTTP_AFFINITY_FORWARDED] Empty body, returning 202 Accepted")
2404 await send({"type": "http.response.start", "status": 202, "headers": []})
2405 await send({"type": "http.response.body", "body": b""})
2406 return
2408 json_body = orjson.loads(body)
2409 rpc_method = json_body.get("method", "")
2410 logger.debug("[HTTP_AFFINITY_FORWARDED] Routing to /rpc | Method: %s", rpc_method)
2412 session_allowed, deny_status, deny_detail = await _validate_streamable_session_access(
2413 mcp_session_id=mcp_session_id,
2414 user_context=user_context,
2415 rpc_method=rpc_method,
2416 )
2417 if not session_allowed:
2418 response = ORJSONResponse({"detail": deny_detail}, status_code=deny_status)
2419 await response(scope, receive, send)
2420 return
2422 # Notifications don't need /rpc routing - just acknowledge
2423 if rpc_method.startswith("notifications/"):
2424 logger.debug("[HTTP_AFFINITY_FORWARDED] Notification, returning 202 Accepted")
2425 await send({"type": "http.response.start", "status": 202, "headers": []})
2426 await send({"type": "http.response.body", "body": b""})
2427 return
2429 # Inject server_id from URL path into params for /rpc routing
2430 if match:
2431 server_id = match.group("server_id")
2432 if not isinstance(json_body.get("params"), dict):
2433 json_body["params"] = {}
2434 json_body["params"]["server_id"] = server_id
2435 # Re-serialize body with injected server_id
2436 body = orjson.dumps(json_body)
2437 logger.debug("[HTTP_AFFINITY_FORWARDED] Injected server_id %s into /rpc params", server_id)
2439 async with httpx.AsyncClient() as client:
2440 rpc_headers = {
2441 "content-type": "application/json",
2442 "x-mcp-session-id": mcp_session_id, # Pass session for upstream affinity
2443 "x-forwarded-internally": "true", # Prevent infinite forwarding loops
2444 }
2445 # Copy auth header if present
2446 if "authorization" in headers:
2447 rpc_headers["authorization"] = headers["authorization"]
2449 response = await client.post(
2450 f"http://127.0.0.1:{settings.port}/rpc",
2451 content=body,
2452 headers=rpc_headers,
2453 timeout=30.0,
2454 )
2456 # Return response to client
2457 response_headers = [
2458 (b"content-type", b"application/json"),
2459 (b"content-length", str(len(response.content)).encode()),
2460 ]
2461 if mcp_session_id != "not-provided":
2462 response_headers.append((b"mcp-session-id", mcp_session_id.encode()))
2464 await send(
2465 {
2466 "type": "http.response.start",
2467 "status": response.status_code,
2468 "headers": response_headers,
2469 }
2470 )
2471 await send(
2472 {
2473 "type": "http.response.body",
2474 "body": response.content,
2475 }
2476 )
2477 logger.debug("[HTTP_AFFINITY_FORWARDED] Response sent | Status: %s", response.status_code)
2478 return
2479 except Exception as e:
2480 logger.error("[HTTP_AFFINITY_FORWARDED] Error routing to /rpc: %s", e)
2481 # Fall through to SDK handling as fallback
2483 if settings.mcpgateway_session_affinity_enabled and settings.use_stateful_sessions and mcp_session_id != "not-provided" and not is_internally_forwarded:
2484 try:
2485 # First-Party - lazy import to avoid circular dependencies
2486 # First-Party
2487 from mcpgateway.services.mcp_session_pool import get_mcp_session_pool, WORKER_ID # pylint: disable=import-outside-toplevel
2489 pool = get_mcp_session_pool()
2490 owner = await pool.get_streamable_http_session_owner(mcp_session_id)
2491 logger.debug("[HTTP_AFFINITY_CHECK] Worker %s | Session %s... | Owner from Redis: %s", WORKER_ID, mcp_session_id[:8], owner)
2493 if owner and owner != WORKER_ID:
2494 # Session owned by another worker - forward the entire HTTP request
2495 logger.info("[HTTP_AFFINITY] Worker %s | Session %s... | Owner: %s | Forwarding HTTP request", WORKER_ID, mcp_session_id[:8], owner)
2497 # Read request body
2498 body_parts = []
2499 while True:
2500 message = await receive()
2501 if message["type"] == "http.request":
2502 body_parts.append(message.get("body", b""))
2503 if not message.get("more_body", False):
2504 break
2505 elif message["type"] == "http.disconnect":
2506 return
2507 body = b"".join(body_parts)
2509 # Forward to owner worker
2510 response = await pool.forward_streamable_http_to_owner(
2511 owner_worker_id=owner,
2512 mcp_session_id=mcp_session_id,
2513 method=method,
2514 path=path,
2515 headers=headers,
2516 body=body,
2517 query_string=query_string,
2518 )
2520 if response:
2521 # Send forwarded response back to client
2522 response_headers = [(k.encode(), v.encode()) for k, v in response["headers"].items() if k.lower() not in ("transfer-encoding", "content-encoding", "content-length")]
2523 response_headers.append((b"content-length", str(len(response["body"])).encode()))
2525 await send(
2526 {
2527 "type": "http.response.start",
2528 "status": response["status"],
2529 "headers": response_headers,
2530 }
2531 )
2532 await send(
2533 {
2534 "type": "http.response.body",
2535 "body": response["body"],
2536 }
2537 )
2538 logger.debug("[HTTP_AFFINITY] Worker %s | Session %s... | Forwarded response sent to client", WORKER_ID, mcp_session_id[:8])
2539 return
2541 # Forwarding failed - fall through to local handling
2542 # This may result in "session not found" but it's better than no response
2543 logger.debug("[HTTP_AFFINITY] Worker %s | Session %s... | Forwarding failed, falling back to local", WORKER_ID, mcp_session_id[:8])
2545 elif owner == WORKER_ID and method == "POST":
2546 # We own this session - route POST requests to /rpc to avoid SDK session issues
2547 # The SDK's _server_instances gets cleared between requests, so we can't rely on it
2548 logger.debug("[HTTP_AFFINITY_LOCAL] Worker %s | Session %s... | Owner is us, routing to /rpc", WORKER_ID, mcp_session_id[:8])
2550 # Read request body
2551 body_parts = []
2552 while True:
2553 message = await receive()
2554 if message["type"] == "http.request":
2555 body_parts.append(message.get("body", b""))
2556 if not message.get("more_body", False):
2557 break
2558 elif message["type"] == "http.disconnect":
2559 return
2560 body = b"".join(body_parts)
2562 if not body:
2563 logger.debug("[HTTP_AFFINITY_LOCAL] Empty body, returning 202 Accepted")
2564 await send({"type": "http.response.start", "status": 202, "headers": []})
2565 await send({"type": "http.response.body", "body": b""})
2566 return
2568 # Parse JSON-RPC and route to /rpc
2569 try:
2570 json_body = orjson.loads(body)
2571 rpc_method = json_body.get("method", "")
2572 logger.debug("[HTTP_AFFINITY_LOCAL] Routing to /rpc | Method: %s", rpc_method)
2574 session_allowed, deny_status, deny_detail = await _validate_streamable_session_access(
2575 mcp_session_id=mcp_session_id,
2576 user_context=user_context,
2577 rpc_method=rpc_method,
2578 )
2579 if not session_allowed:
2580 response = ORJSONResponse({"detail": deny_detail}, status_code=deny_status)
2581 await response(scope, receive, send)
2582 return
2584 # Notifications don't need /rpc routing
2585 if rpc_method.startswith("notifications/"):
2586 logger.debug("[HTTP_AFFINITY_LOCAL] Notification, returning 202 Accepted")
2587 await send({"type": "http.response.start", "status": 202, "headers": []})
2588 await send({"type": "http.response.body", "body": b""})
2589 return
2591 # Inject server_id from URL path into params for /rpc routing
2592 if match:
2593 server_id = match.group("server_id")
2594 if not isinstance(json_body.get("params"), dict):
2595 json_body["params"] = {}
2596 json_body["params"]["server_id"] = server_id
2597 # Re-serialize body with injected server_id
2598 body = orjson.dumps(json_body)
2599 logger.debug("[HTTP_AFFINITY_LOCAL] Injected server_id %s into /rpc params", server_id)
2601 async with httpx.AsyncClient() as client:
2602 rpc_headers = {
2603 "content-type": "application/json",
2604 "x-mcp-session-id": mcp_session_id,
2605 "x-forwarded-internally": "true",
2606 }
2607 if "authorization" in headers:
2608 rpc_headers["authorization"] = headers["authorization"]
2610 response = await client.post(
2611 f"http://127.0.0.1:{settings.port}/rpc",
2612 content=body,
2613 headers=rpc_headers,
2614 timeout=30.0,
2615 )
2617 response_headers = [
2618 (b"content-type", b"application/json"),
2619 (b"content-length", str(len(response.content)).encode()),
2620 (b"mcp-session-id", mcp_session_id.encode()),
2621 ]
2623 await send(
2624 {
2625 "type": "http.response.start",
2626 "status": response.status_code,
2627 "headers": response_headers,
2628 }
2629 )
2630 await send(
2631 {
2632 "type": "http.response.body",
2633 "body": response.content,
2634 }
2635 )
2636 logger.debug("[HTTP_AFFINITY_LOCAL] Response sent | Status: %s", response.status_code)
2637 return
2638 except Exception as e:
2639 logger.error("[HTTP_AFFINITY_LOCAL] Error routing to /rpc: %s", e)
2640 # Fall through to SDK handling as fallback
2642 except RuntimeError:
2643 # Pool not initialized - proceed with local handling
2644 pass
2645 except Exception as e:
2646 logger.debug("Session affinity check failed, proceeding locally: %s", e)
2648 # Store headers in context for tool invocations
2649 request_headers_var.set(headers)
2651 if match:
2652 server_id = match.group("server_id")
2653 server_id_var.set(server_id)
2654 else:
2655 server_id_var.set(None)
2657 # For session affinity: wrap send to capture session ID from response headers
2658 # This allows us to register ownership for new sessions created by the SDK
2659 captured_session_id: Optional[str] = None
2661 async def send_with_capture(message: Dict[str, Any]) -> None:
2662 """Wrap ASGI send to capture session ID from response headers.
2664 Args:
2665 message: ASGI message dict.
2666 """
2667 nonlocal captured_session_id
2668 if message["type"] == "http.response.start" and settings.mcpgateway_session_affinity_enabled:
2669 # Look for mcp-session-id in response headers
2670 response_headers = message.get("headers", [])
2671 for header_name, header_value in response_headers:
2672 if isinstance(header_name, bytes):
2673 header_name = header_name.decode("latin-1")
2674 if isinstance(header_value, bytes):
2675 header_value = header_value.decode("latin-1")
2676 if header_name.lower() == "mcp-session-id":
2677 captured_session_id = header_value
2678 break
2679 await send(message)
2681 # Propagate middleware-resolved context via ASGI scope so that MCP
2682 # handlers can retrieve it even when ContextVars are lost (the SDK's
2683 # task group was created at startup, so spawned handler tasks inherit
2684 # the startup context rather than the per-request context).
2685 scope[_MCPGATEWAY_CONTEXT_KEY] = {
2686 "server_id": server_id_var.get(),
2687 "request_headers": headers,
2688 "user_context": user_context,
2689 }
2691 try:
2692 await self.session_manager.handle_request(scope, receive, send_with_capture)
2693 logger.debug("[STATEFUL] Streamable HTTP request completed successfully | Session: %s", mcp_session_id)
2695 # Register ownership for the session we just handled
2696 # This captures both existing sessions (mcp_session_id from request)
2697 # and new sessions (captured_session_id from response)
2698 logger.debug(
2699 "[HTTP_AFFINITY_DEBUG] affinity_enabled=%s stateful=%s captured=%s mcp_session_id=%s",
2700 settings.mcpgateway_session_affinity_enabled,
2701 settings.use_stateful_sessions,
2702 captured_session_id,
2703 mcp_session_id,
2704 )
2705 if settings.mcpgateway_session_affinity_enabled and settings.use_stateful_sessions:
2706 session_to_register: Optional[str] = None
2708 # Only server-emitted session IDs (from successful initialize) can
2709 # establish new ownership state for affinity.
2710 if captured_session_id:
2711 session_to_register = captured_session_id
2713 requester_email = user_context.get("email") if isinstance(user_context, dict) else None
2714 if requester_email:
2715 effective_owner = await _claim_streamable_session_owner(captured_session_id, requester_email)
2716 if effective_owner and effective_owner != requester_email and not bool(user_context.get("is_admin", False)):
2717 logger.warning("Session owner mismatch for %s... (requester=%s, owner=%s)", captured_session_id[:8], requester_email, effective_owner)
2718 elif mcp_session_id != "not-provided":
2719 # Existing client-provided IDs may only refresh affinity when they
2720 # are already bound to the caller's principal.
2721 session_allowed, _deny_status, _deny_detail = await _validate_streamable_session_access(
2722 mcp_session_id=mcp_session_id,
2723 user_context=user_context,
2724 rpc_method=None,
2725 )
2726 if session_allowed:
2727 session_to_register = mcp_session_id
2729 logger.debug("[HTTP_AFFINITY_DEBUG] session_to_register=%s", session_to_register)
2730 if session_to_register:
2731 try:
2732 # First-Party - lazy import to avoid circular dependencies
2733 # First-Party
2734 from mcpgateway.services.mcp_session_pool import get_mcp_session_pool, WORKER_ID # pylint: disable=import-outside-toplevel
2736 pool = get_mcp_session_pool()
2737 await pool.register_pool_session_owner(session_to_register)
2738 logger.debug("[HTTP_AFFINITY_SDK] Worker %s | Session %s... | Registered ownership after SDK handling", WORKER_ID, session_to_register[:8])
2739 except Exception as e:
2740 logger.debug("[HTTP_AFFINITY_DEBUG] Exception during registration: %s", e)
2741 logger.warning("Failed to register session ownership: %s", e)
2743 except anyio.ClosedResourceError:
2744 # Expected when client closes one side of the stream (normal lifecycle)
2745 logger.debug("Streamable HTTP connection closed by client (ClosedResourceError)")
2746 except Exception as e:
2747 logger.error("[STATEFUL] Streamable HTTP request failed | Session: %s | Error: %s", mcp_session_id, e)
2748 logger.exception("Error handling streamable HTTP request: %s", e)
2749 raise
2752# ------------------------- Authentication for /mcp routes ------------------------------
def _set_proxy_user_context(proxy_user: str) -> None:
    """Populate the request-scoped user context for a trusted-proxy identity.

    The stored context marks the caller as authenticated, gives it no team
    memberships (an empty ``teams`` list) and never grants admin rights.

    Args:
        proxy_user: Email address of the proxy-authenticated user.
    """
    proxy_context = dict(
        email=proxy_user,
        teams=[],
        is_authenticated=True,
        is_admin=False,
    )
    user_context_var.set(proxy_context)
class _StreamableHttpAuthHandler:
    """Per-request handler that authenticates MCP StreamableHTTP requests.

    Encapsulates the ASGI triple (scope, receive, send) so that helper methods
    can send error responses without threading these values through every call.
    """

    __slots__ = ("scope", "receive", "send")

    def __init__(self, scope: Any, receive: Any, send: Any) -> None:
        """Capture the ASGI triple for the request being authenticated.

        Args:
            scope: ASGI connection scope for the request.
            receive: ASGI receive callable.
            send: ASGI send callable, used to emit rejection responses.
        """
        self.scope = scope
        self.receive = receive
        self.send = send

    async def _send_error(self, *, detail: str, status_code: int = HTTP_401_UNAUTHORIZED, headers: dict[str, str] | None = None) -> bool:
        """Send an error response and return False (auth rejected).

        Args:
            detail: Error message for the JSON response body.
            status_code: HTTP status code (default 401).
            headers: Optional response headers (e.g. WWW-Authenticate).

        Returns:
            Always ``False`` so callers can ``return await self._send_error(...)``.
        """
        response = ORJSONResponse({"detail": detail}, status_code=status_code, headers=headers or {})
        await response(self.scope, self.receive, self.send)
        return False

    @staticmethod
    def _bearer_rejection(detail: str) -> tuple[str, int, dict[str, str] | None]:
        """Build a 401 rejection carrying a ``WWW-Authenticate: Bearer`` challenge.

        Args:
            detail: Error message for the JSON response body.

        Returns:
            A ``(detail, status_code, headers)`` tuple; a fresh headers dict is
            created on every call so no mutable state is shared between requests.
        """
        return detail, HTTP_401_UNAUTHORIZED, {"WWW-Authenticate": "Bearer"}

    @staticmethod
    def _is_admin_claim(user_payload: dict[str, Any]) -> Any:
        """Read the admin flag from decoded JWT claims.

        The top-level ``is_admin`` claim wins when truthy; otherwise the nested
        ``user.is_admin`` value is used. A missing, ``null`` or non-object
        ``user`` claim is treated as non-admin instead of raising
        ``AttributeError`` (which previously turned a valid token into a 401).

        Args:
            user_payload: Decoded JWT claims.

        Returns:
            The top-level ``is_admin`` value when truthy, else the nested
            ``user.is_admin`` value (``False`` when absent).
        """
        nested_user = user_payload.get("user")
        nested_admin = nested_user.get("is_admin", False) if isinstance(nested_user, dict) else False
        return user_payload.get("is_admin", False) or nested_admin

    @staticmethod
    def _team_membership_is_valid(user_email: str, team_ids: List[str]) -> bool:
        """Check that *user_email* is still an active member of every team in *team_ids*.

        Consults the auth cache first and falls back to a database query on a
        cache miss. Both positive and negative DB results are written back to the cache.

        Args:
            user_email: Email of the token subject.
            team_ids: Team IDs the token is scoped to (non-empty).

        Returns:
            True when all memberships are still active, False otherwise.

        Raises:
            SQLAlchemyError: Propagated from the membership query on DB failure.
        """
        # Import lazily to avoid circular imports
        # Third-Party
        from sqlalchemy import select  # pylint: disable=import-outside-toplevel

        # First-Party
        from mcpgateway.cache.auth_cache import get_auth_cache  # pylint: disable=import-outside-toplevel
        from mcpgateway.db import EmailTeamMember  # pylint: disable=import-outside-toplevel

        auth_cache = get_auth_cache()

        # Check cache first (60s TTL)
        cached_result = auth_cache.get_team_membership_valid_sync(user_email, team_ids)
        if cached_result is False:
            logger.warning("MCP auth rejected: User %s no longer member of teams (cached)", user_email)
            return False
        if cached_result is not None:
            return True

        # Cache miss - query database; the session is closed before any response is sent.
        with SessionLocal() as db:
            memberships = (
                db.execute(
                    select(EmailTeamMember.team_id).where(
                        EmailTeamMember.team_id.in_(team_ids),
                        EmailTeamMember.user_email == user_email,
                        EmailTeamMember.is_active.is_(True),
                    )
                )
                .scalars()
                .all()
            )

        missing_teams = set(team_ids) - set(memberships)
        if missing_teams:
            logger.warning("MCP auth rejected: User %s no longer member of teams: %s", user_email, missing_teams)
            auth_cache.set_team_membership_valid_sync(user_email, team_ids, False)
            return False

        # Cache positive result
        auth_cache.set_team_membership_valid_sync(user_email, team_ids, True)
        return True

    async def authenticate(self) -> bool:
        """Perform authentication check in middleware context (ASGI scope).

        Authenticates only requests targeting paths ending in "/mcp" or "/mcp/".

        Behavior:
        - If the path does not end with "/mcp" (or is under "/.well-known/"), authentication is skipped.
        - CORS preflight requests (OPTIONS with Origin and Access-Control-Request-Method) are allowed.
        - A trusted proxy user header, when proxy trust is active, authenticates the request.
        - If mcp_require_auth=True (strict mode): requests without valid auth are rejected with 401.
        - If mcp_require_auth=False (permissive mode):
          - Requests without auth are allowed but get public-only access (token_teams=[]).
          - EXCEPTION: if the target server has oauth_enabled=True, unauthenticated
            requests are rejected with 401 regardless of the global setting.
          - Valid tokens get full scoped access based on their teams.
        - Malformed/invalid Bearer tokens are rejected with 401 (no silent downgrade).
        - If a Bearer token is present, it is verified using ``verify_credentials``.

        Returns:
            True if authentication passes or is skipped.
            False if authentication fails and an error response (401/403/503) is sent.
        """
        path = self.scope.get("path", "")
        if not path.endswith(("/mcp", "/mcp/")) or path.startswith("/.well-known/"):
            # No auth for non-MCP paths or RFC 9728 metadata endpoints
            return True

        # Reset per-request OAuth enforcement cache so keep-alive connections
        # re-evaluate on every request instead of inheriting a stale True.
        _oauth_checked_var.set(False)

        headers = Headers(scope=self.scope)

        # CORS preflight (OPTIONS + Origin + Access-Control-Request-Method) cannot carry auth headers
        if self.scope.get("method", "") == "OPTIONS":
            if headers.get("origin") and headers.get("access-control-request-method"):
                return True

        authorization = headers.get("authorization")
        proxy_trusted = is_proxy_auth_trust_active(settings)
        proxy_user = headers.get(settings.proxy_user_header) if proxy_trusted else None

        # Determine authentication strategy based on settings
        if proxy_trusted and proxy_user:
            _set_proxy_user_context(proxy_user)
            return True  # Trusted proxy supplied user

        # --- Standard JWT authentication flow (client auth enabled) ---
        token: str | None = None
        bearer_header_supplied = False
        if authorization:
            scheme, credentials = get_authorization_scheme_param(authorization)
            if scheme.lower() == "bearer":
                bearer_header_supplied = True
                if credentials:
                    token = credentials

        if token is None:
            return await self._auth_no_token(path=path, bearer_header_supplied=bearer_header_supplied)

        return await self._auth_jwt(token=token)

    async def _auth_no_token(self, *, path: str, bearer_header_supplied: bool) -> bool:
        """Handle unauthenticated MCP requests (no Bearer token present).

        Args:
            path: Request path (used for per-server OAuth enforcement).
            bearer_header_supplied: True when Authorization: Bearer was present but empty.

        Returns:
            True if the request is allowed with public-only access, False if rejected.
        """
        # If client supplied a Bearer header but with empty credentials, fail closed
        if bearer_header_supplied:
            return await self._send_error(detail="Invalid authentication credentials", headers={"WWW-Authenticate": "Bearer"})

        # Strict mode: require authentication
        if settings.mcp_require_auth:
            return await self._send_error(detail="Authentication required for MCP endpoints", headers={"WWW-Authenticate": "Bearer"})

        # Permissive mode: allow unauthenticated access with public-only scope
        # BUT first check if this specific server requires OAuth (per-server enforcement)
        match = _SERVER_ID_RE.search(path)
        if match:
            per_server_id = match.group("server_id")
            try:
                await _check_server_oauth_enforcement(per_server_id, {"is_authenticated": False})
            except OAuthRequiredError:
                resource_metadata = _build_resource_metadata_url(self.scope, per_server_id)
                www_auth = f'Bearer resource_metadata="{resource_metadata}"' if resource_metadata else "Bearer"
                return await self._send_error(detail="This server requires OAuth authentication", headers={"WWW-Authenticate": www_auth})
            except OAuthEnforcementUnavailableError:
                logger.exception("OAuth enforcement check failed for server %s", per_server_id)
                return await self._send_error(detail="Service unavailable — unable to verify server authentication requirements", status_code=503)

        # Set context indicating unauthenticated user with public-only access (teams=[])
        user_context_var.set(
            {
                "email": None,
                "teams": [],  # Empty list = public-only access
                "is_authenticated": False,
                "is_admin": False,
            }
        )
        return True  # Allow request to proceed with public-only access

    async def _auth_jwt(self, *, token: str) -> bool:
        """Verify a JWT Bearer token and populate the user context.

        All checks run inside ``_authenticate_jwt_claims``; the error response is
        sent outside the exception handlers so that a failing ``send`` is never
        followed by a second response attempt.

        Args:
            token: Bearer token value extracted from the Authorization header.

        Returns:
            True if verification succeeds, False if rejected (401/403/503 sent).
        """
        try:
            rejection = await self._authenticate_jwt_claims(token)
        except HTTPException:
            # JWT verification failed (expired, malformed, bad signature, etc.)
            return await self._send_error(detail="Invalid authentication credentials", headers={"WWW-Authenticate": "Bearer"})
        except SQLAlchemyError:
            # DB failure during team resolution or membership validation
            logger.exception("Database error during MCP authentication")
            return await self._send_error(detail="Service unavailable — unable to verify authentication", status_code=503)
        except Exception:
            # Unexpected error during authentication — fail closed with 401.
            logger.exception("Unexpected error during MCP JWT authentication")
            return await self._send_error(detail="Authentication failed", headers={"WWW-Authenticate": "Bearer"})

        if rejection is not None:
            detail, status_code, headers = rejection
            return await self._send_error(detail=detail, status_code=status_code, headers=headers)
        return True

    async def _authenticate_jwt_claims(self, token: str) -> tuple[str, int, dict[str, str] | None] | None:
        """Verify *token*, run account/team checks and set the user context.

        No response is sent here; rejections are returned to the caller.

        Args:
            token: Bearer token value extracted from the Authorization header.

        Returns:
            ``None`` when the request may proceed, otherwise a
            ``(detail, status_code, headers)`` rejection tuple.

        Raises:
            HTTPException: When JWT verification fails (e.g. from ``verify_credentials``).
            SQLAlchemyError: On database failures during team resolution/validation.
        """
        user_payload = await verify_credentials(token)
        if not isinstance(user_payload, dict):
            # NOTE(review): a non-dict payload is accepted without setting a user
            # context (pre-existing behavior) — confirm this is intended.
            return None

        jti = user_payload.get("jti")
        if jti:
            # First-Party
            from mcpgateway.auth import _check_token_revoked_sync  # pylint: disable=import-outside-toplevel

            try:
                is_revoked = await asyncio.to_thread(_check_token_revoked_sync, jti)
            except Exception as exc:
                logger.warning("MCP token revocation check failed for jti=%s; allowing request (fail-open): %s", jti, exc)
                is_revoked = False
            if is_revoked:
                return self._bearer_rejection("Token has been revoked")

        user_email = user_payload.get("sub") or user_payload.get("email")
        if user_email:
            # First-Party
            from mcpgateway.auth import _get_user_by_email_sync  # pylint: disable=import-outside-toplevel

            try:
                user_record = await asyncio.to_thread(_get_user_by_email_sync, user_email)
            except Exception as exc:
                # Lookup failures are deliberately fail-open; account checks are skipped.
                logger.warning("MCP user lookup failed for user=%s; allowing request (fail-open): %s", user_email, exc)
            else:
                if user_record and not getattr(user_record, "is_active", True):
                    return self._bearer_rejection("Account disabled")
                if user_record is None and settings.require_user_in_db and user_email != getattr(settings, "platform_admin_email", "admin@example.com"):
                    return self._bearer_rejection("User not found in database")

        is_admin = self._is_admin_claim(user_payload)

        # Resolve teams based on token_use claim
        token_use = user_payload.get("token_use")
        if token_use == "session":  # nosec B105 - Not a password; token_use is a JWT claim type
            # Session token: resolve teams from DB/cache
            if is_admin:
                final_teams = None  # Admin bypass
            elif user_email:
                # Resolve teams synchronously with L1 cache (StreamableHTTP uses sync context)
                # First-Party
                from mcpgateway.auth import _resolve_teams_from_db_sync  # pylint: disable=import-outside-toplevel

                final_teams = _resolve_teams_from_db_sync(user_email, is_admin=False)
            else:
                final_teams = []  # No email — public-only
        else:
            # API token or legacy: use embedded teams from JWT
            # First-Party
            from mcpgateway.auth import normalize_token_teams  # pylint: disable=import-outside-toplevel

            final_teams = normalize_token_teams(user_payload)

        # SECURITY: Validate team membership for team-scoped tokens.
        # Users removed from a team should lose MCP access immediately, not at token expiry.
        # Skipped for public-only tokens ([]) and admin unrestricted tokens (None).
        if final_teams and user_email and not self._team_membership_is_valid(user_email, final_teams):
            return ("Token invalid: User is no longer a member of the associated team", HTTP_403_FORBIDDEN, None)

        auth_user_ctx: dict[str, Any] = {
            "email": user_email,
            "teams": final_teams,
            "is_authenticated": True,
            "is_admin": is_admin,
            "token_use": token_use,  # propagated for downstream RBAC (check_any_team)
        }
        # Extract scoped permissions from JWT for per-method enforcement
        jwt_scopes = user_payload.get("scopes") or {}
        jwt_scoped_perms = (jwt_scopes.get("permissions") or []) if isinstance(jwt_scopes, dict) else []
        if jwt_scoped_perms:
            auth_user_ctx["scoped_permissions"] = jwt_scoped_perms
        user_context_var.set(auth_user_ctx)
        return None
async def streamable_http_auth(scope: Any, receive: Any, send: Any) -> bool:
    """Authenticate an MCP StreamableHTTP request at the ASGI middleware layer.

    Thin entry point: it wraps the ASGI triple in a fresh
    :class:`_StreamableHttpAuthHandler` and awaits its ``authenticate``
    coroutine, which also emits any rejection response itself.

    Args:
        scope: The ASGI scope dictionary, which includes request metadata.
        receive: ASGI receive callable used to receive events.
        send: ASGI send callable used to send events (e.g. a 401 response).

    Returns:
        bool: True if authentication passes or is skipped.
        False if authentication fails and an error response is sent.

    Examples:
        >>> # Test streamable_http_auth function exists
        >>> callable(streamable_http_auth)
        True

        >>> # Test function signature
        >>> import inspect
        >>> sig = inspect.signature(streamable_http_auth)
        >>> list(sig.parameters.keys())
        ['scope', 'receive', 'send']
    """
    handler = _StreamableHttpAuthHandler(scope, receive, send)
    return await handler.authenticate()