Coverage for mcpgateway / transports / streamablehttp_transport.py: 99%
1519 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/transports/streamablehttp_transport.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Keval Mahajan
7Streamable HTTP Transport Implementation.
8This module implements Streamable Http transport for MCP
10Key components include:
11- SessionManagerWrapper: Manages the lifecycle of streamable HTTP sessions
12- Configuration options for:
13 1. stateful/stateless operation
14 2. JSON response mode or SSE streams
15- InMemoryEventStore: A simple in-memory event storage system for maintaining session state
17Examples:
18 >>> # Test module imports
19 >>> from mcpgateway.transports.streamablehttp_transport import (
20 ... EventEntry, StreamBuffer, InMemoryEventStore, SessionManagerWrapper
21 ... )
22 >>>
23 >>> # Verify classes are available
24 >>> EventEntry.__name__
25 'EventEntry'
26 >>> StreamBuffer.__name__
27 'StreamBuffer'
28 >>> InMemoryEventStore.__name__
29 'InMemoryEventStore'
30 >>> SessionManagerWrapper.__name__
31 'SessionManagerWrapper'
32"""
34# Standard
35import asyncio
36from contextlib import asynccontextmanager, AsyncExitStack, ExitStack
37import contextvars
38from dataclasses import dataclass
39import re
40from typing import Any, AsyncGenerator, ContextManager, Dict, List, Optional, Pattern, Tuple, Union
41from urllib.parse import urlsplit, urlunsplit
42from uuid import uuid4
44# Third-Party
45import anyio
46from fastapi import HTTPException
47from fastapi.security.utils import get_authorization_scheme_param
48import httpx
49from mcp import ClientSession, types
50from mcp.client.streamable_http import streamablehttp_client
51from mcp.server.lowlevel import Server
52from mcp.server.streamable_http import EventCallback, EventId, EventMessage, EventStore, StreamId
53from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
54from mcp.types import JSONRPCMessage, PaginatedRequestParams, ReadResourceRequest, ReadResourceRequestParams
55import orjson
56from sqlalchemy.exc import SQLAlchemyError
57from sqlalchemy.orm import Session
58from starlette.datastructures import Headers
59from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN, HTTP_404_NOT_FOUND
60from starlette.types import Receive, Scope, Send
62# First-Party
63from mcpgateway.cache.global_config_cache import global_config_cache
64from mcpgateway.common.models import LogLevel
65from mcpgateway.config import settings
66from mcpgateway.db import SessionLocal
67from mcpgateway.middleware.rbac import _ACCESS_DENIED_MSG
68from mcpgateway.observability import create_span
69from mcpgateway.services.completion_service import CompletionService
70from mcpgateway.services.http_client_service import get_http_client, get_http_limits
71from mcpgateway.services.logging_service import LoggingService
72from mcpgateway.services.metrics import mcp_auth_cache_events_counter
73from mcpgateway.services.oauth_manager import OAuthEnforcementUnavailableError, OAuthRequiredError
74from mcpgateway.services.permission_service import PermissionService
75from mcpgateway.services.prompt_service import PromptService
76from mcpgateway.services.resource_service import ResourceService
77from mcpgateway.services.tool_service import ToolService
78from mcpgateway.transports.redis_event_store import RedisEventStore
79from mcpgateway.utils.gateway_access import build_gateway_auth_headers, check_gateway_access, extract_gateway_id_from_headers, GATEWAY_ID_HEADER
80from mcpgateway.utils.internal_http import internal_loopback_base_url, internal_loopback_verify
81from mcpgateway.utils.orjson_response import ORJSONResponse
82from mcpgateway.utils.passthrough_headers import compute_passthrough_headers_cached
83from mcpgateway.utils.trace_context import set_trace_context_from_teams, set_trace_session_id
84from mcpgateway.utils.verify_credentials import is_proxy_auth_trust_active, require_auth_header_first, verify_credentials
# Build the logging service up front so the module-level logger below is
# available to every helper defined in this module.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)
def _maybe_open_initialize_span(body: bytes, *, mcp_session_id: Optional[str], server_id: Optional[str]) -> Optional[ContextManager[Any]]:
    """Create an ``mcp.initialize`` span for a raw MCP ``initialize`` request.

    Args:
        body: Raw JSON-RPC request body bytes.
        mcp_session_id: Session identifier from the request headers when present.
            The placeholder value ``"not-provided"`` is ignored.
        server_id: Effective virtual server identifier for the request, if any.

    Returns:
        The value returned by ``create_span`` when ``body`` is a JSON object whose
        ``method`` is ``initialize``; ``None`` when the body is not valid JSON,
        is not a JSON object, or names any other method.
    """
    try:
        request = orjson.loads(body)
    except orjson.JSONDecodeError:
        return None

    if not isinstance(request, dict):
        return None
    method_name = str(request.get("method") or "").strip()
    if method_name != "initialize":
        return None

    raw_params = request.get("params")
    params = raw_params if isinstance(raw_params, dict) else {}

    # Prefer a session id carried in the request params; fall back to the
    # header-provided one unless it is empty or the "not-provided" placeholder.
    session_id = params.get("sessionId") or params.get("session_id")
    if not session_id:
        if mcp_session_id and mcp_session_id != "not-provided":
            session_id = mcp_session_id

    span_attributes: Dict[str, Any] = {
        "mcp.protocol_version": params.get("protocolVersion") or params.get("protocol_version"),
        "mcp.session_id": session_id,
        "server.id": server_id,
    }
    return create_span("mcp.initialize", span_attributes)
def _normalize_mcp_prompt_arguments(arguments: Any) -> Optional[List[types.PromptArgument]]:
    """Coerce prompt arguments into ``mcp.types.PromptArgument`` instances.

    Internal services hand back their own schema models, while the MCP
    transport must emit ``mcp.types.PromptArgument``. Pydantic does not treat
    distinct model classes as interchangeable, so passing the internal objects
    straight through fails validation during prompt listing.

    Args:
        arguments: Prompt arguments from internal services. Each item may
            already be an ``mcp.types.PromptArgument``, a dict, or another
            object/model exposing matching attributes.

    Returns:
        A new list of MCP prompt arguments, or ``None`` when ``arguments`` is
        ``None`` (the prompt has no argument list).
    """
    if arguments is None:
        return None

    # Items that are already the MCP type are reused as-is; everything else is
    # validated into the MCP model (attribute access allowed via from_attributes).
    return [item if isinstance(item, types.PromptArgument) else types.PromptArgument.model_validate(item, from_attributes=True) for item in arguments]
def _safe_str_attr(obj: Any, attr: str) -> Optional[str]:
    """Read ``attr`` from ``obj`` and return it only when it is a string.

    Args:
        obj: The object to read the attribute from.
        attr: The attribute name to extract.

    Returns:
        The attribute value when it is a ``str`` (including the empty string);
        ``None`` when the attribute is missing or holds any other type.
    """
    candidate = getattr(obj, attr, None)
    if isinstance(candidate, str):
        return candidate
    return None
def _to_mcp_prompt(prompt: Any) -> types.Prompt:
    """Build the MCP transport ``Prompt`` model from an internal prompt object.

    Args:
        prompt: Internal prompt object returned by prompt_service. ``name`` and
            ``description`` are read directly; ``title``, ``meta`` and
            ``arguments`` are optional attributes.

    Returns:
        MCP prompt model suitable for protocol responses.
    """
    # Non-string titles and non-dict meta values are dropped rather than passed
    # on to the MCP model.
    title = _safe_str_attr(prompt, "title")
    raw_meta = getattr(prompt, "meta", None)
    meta = raw_meta if isinstance(raw_meta, dict) else None

    return types.Prompt(
        name=prompt.name,
        title=title,
        description=prompt.description,
        arguments=_normalize_mcp_prompt_arguments(getattr(prompt, "arguments", None)),
        meta=meta,
    )
def _record_mcp_auth_cache_event(outcome: str) -> None:
    """Increment the MCP auth-cache Prometheus counter, ignoring any failure.

    Args:
        outcome: Cache-flow outcome label to emit.
    """
    try:
        labelled_counter = mcp_auth_cache_events_counter.labels(outcome=outcome)
        labelled_counter.inc()
    except Exception:
        pass  # nosec B110 - Metrics must not break auth flow
# Compiled once at import time: extracts the server ID from a server-scoped MCP path.
# SECURITY: the ID segment is matched with [^/]+ (any run of non-slash characters)
# instead of a restrictive hex-only class so that ALL server-scoped paths are
# captured. A narrow regex caused non-hex IDs (e.g. "xyz") to silently fall
# through to unscoped global behaviour (#3891).
_SERVER_ID_RE: Pattern[str] = re.compile(r"^/servers/(?P<server_id>[^/]+)/mcp")

# Defense-in-depth guard: recognises a server-scoped MCP path even when
# _SERVER_ID_RE does not match (e.g. the empty segment in /servers//mcp).
_SERVER_SCOPED_PATH_RE: Pattern[str] = re.compile(r"^/servers/.*/mcp(?:/)?$")

# Sentinel returned by _validate_server_id to signal that an error response
# has already been sent and the caller should return immediately.
_REJECT = object()

# ASGI scope key under which middleware hands gateway context to MCP handlers.
_MCPGATEWAY_CONTEXT_KEY = "_mcpgateway_context"

# Module-level service singletons and the MCP server instance.
tool_service: ToolService = ToolService()
prompt_service: PromptService = PromptService()
resource_service: ResourceService = ResourceService()
completion_service: CompletionService = CompletionService()

mcp_app: Server[Any] = Server("mcp-streamable-http")

# Context variables carrying request state into the MCP handlers.
server_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("server_id", default="default_server_id")
request_headers_var: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar("request_headers", default={})
user_context_var: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar("user_context", default={})
_oauth_checked_var: contextvars.ContextVar[bool] = contextvars.ContextVar("_oauth_checked", default=False)

# Process-wide state: the shared session registry plus the lazily created HTTP
# client (and its creation lock) used for Rust event-store calls.
_shared_session_registry: Optional[Any] = None
_rust_event_store_client: Optional[httpx.AsyncClient] = None
_rust_event_store_client_lock = asyncio.Lock()
_RUST_EVENT_STORE_DEFAULT_KEY_PREFIX = "mcpgw:eventstore"
234# ------------------------------ Event store ------------------------------
@dataclass
class EventEntry:
    """
    A single stored event: its ID, owning stream, payload and sequence number.

    Examples:
        >>> # Create an event entry
        >>> from mcp.types import JSONRPCMessage
        >>> message = JSONRPCMessage(jsonrpc="2.0", method="test", id=1)
        >>> entry = EventEntry(event_id="test-123", stream_id="stream-456", message=message, seq_num=0)
        >>> entry.event_id
        'test-123'
        >>> entry.stream_id
        'stream-456'
        >>> entry.seq_num
        0
        >>> # Access message attributes through model_dump() for Pydantic v2
        >>> message_dict = message.model_dump()
        >>> message_dict['jsonrpc']
        '2.0'
        >>> message_dict['method']
        'test'
        >>> message_dict['id']
        1
    """

    # Identifier handed back to callers when the event is stored.
    event_id: EventId
    # Stream the event belongs to.
    stream_id: StreamId
    # JSON-RPC payload kept for replay.
    message: JSONRPCMessage
    # Sequence number of the event within its stream.
    seq_num: int
@dataclass
class StreamBuffer:
    """
    Fixed-size ring buffer holding the most recent events of one stream.

    Sequence numbers are tracked so replay can jump straight to a position
    instead of scanning: an event lives at index (seq_num % capacity) of
    ``entries``.

    Examples:
        >>> # Create a stream buffer with capacity 3
        >>> buffer = StreamBuffer(entries=[None, None, None])
        >>> buffer.start_seq
        0
        >>> buffer.next_seq
        0
        >>> buffer.count
        0
        >>> len(buffer)
        0

        >>> # Simulate adding an entry
        >>> buffer.next_seq = 1
        >>> buffer.count = 1
        >>> len(buffer)
        1
    """

    # Backing slots; ``None`` marks a slot that has never been written.
    entries: list[EventEntry | None]
    # Oldest sequence number still buffered.
    start_seq: int = 0
    # Sequence number that the next insert will receive.
    next_seq: int = 0
    # Number of events currently held.
    count: int = 0

    def __len__(self) -> int:
        """Report how many events the buffer currently holds.

        Returns:
            int: The count of events in the buffer.
        """
        return self.count
class InMemoryEventStore(EventStore):
    """
    In-memory EventStore implementation providing stream resumability.

    Intended mainly for examples and testing; production deployments would
    normally want a persistent storage solution instead.

    Only the last N events of each stream are retained. Every stream owns a
    ring buffer addressed by per-stream sequence numbers, giving O(1) lookup of
    an event and O(k) replay of the k events that follow it.

    Examples:
        >>> # Create event store with default max events
        >>> store = InMemoryEventStore()
        >>> store.max_events_per_stream
        100
        >>> len(store.streams)
        0
        >>> len(store.event_index)
        0

        >>> # Create event store with custom max events
        >>> store = InMemoryEventStore(max_events_per_stream=50)
        >>> store.max_events_per_stream
        50

        >>> # Test event store initialization
        >>> store = InMemoryEventStore()
        >>> hasattr(store, 'streams')
        True
        >>> hasattr(store, 'event_index')
        True
        >>> isinstance(store.streams, dict)
        True
        >>> isinstance(store.event_index, dict)
        True
    """

    def __init__(self, max_events_per_stream: int = 100):
        """Set up an empty event store.

        Args:
            max_events_per_stream: Maximum number of events to keep per stream

        Examples:
            >>> # Test initialization with default value
            >>> store = InMemoryEventStore()
            >>> store.max_events_per_stream
            100
            >>> store.streams == {}
            True
            >>> store.event_index == {}
            True

            >>> # Test initialization with custom value
            >>> store = InMemoryEventStore(max_events_per_stream=25)
            >>> store.max_events_per_stream
            25
        """
        self.max_events_per_stream = max_events_per_stream
        # stream_id -> ring buffer holding that stream's retained events
        self.streams: dict[StreamId, StreamBuffer] = {}
        # event_id -> EventEntry, so a replay anchor is found without scanning
        self.event_index: dict[EventId, EventEntry] = {}

    async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId:
        """
        Store ``message`` on ``stream_id`` under a freshly generated event ID.

        Once a stream's buffer is full, the oldest event is evicted (and removed
        from ``event_index``) to make room for the new one.

        Args:
            stream_id (StreamId): The ID of the stream.
            message (JSONRPCMessage): The message to store.

        Returns:
            EventId: The ID of the stored event.

        Examples:
            >>> # Test storing an event
            >>> import asyncio
            >>> from mcp.types import JSONRPCMessage
            >>> store = InMemoryEventStore(max_events_per_stream=5)
            >>> message = JSONRPCMessage(jsonrpc="2.0", method="test", id=1)
            >>> event_id = asyncio.run(store.store_event("stream-1", message))
            >>> isinstance(event_id, str)
            True
            >>> len(event_id) > 0
            True
            >>> len(store.streams)
            1
            >>> len(store.event_index)
            1
            >>> "stream-1" in store.streams
            True
            >>> event_id in store.event_index
            True

            >>> # Test storing multiple events in same stream
            >>> message2 = JSONRPCMessage(jsonrpc="2.0", method="test2", id=2)
            >>> event_id2 = asyncio.run(store.store_event("stream-1", message2))
            >>> len(store.streams["stream-1"])
            2
            >>> len(store.event_index)
            2

            >>> # Test ring buffer overflow
            >>> store2 = InMemoryEventStore(max_events_per_stream=2)
            >>> msg1 = JSONRPCMessage(jsonrpc="2.0", method="m1", id=1)
            >>> msg2 = JSONRPCMessage(jsonrpc="2.0", method="m2", id=2)
            >>> msg3 = JSONRPCMessage(jsonrpc="2.0", method="m3", id=3)
            >>> id1 = asyncio.run(store2.store_event("stream-2", msg1))
            >>> id2 = asyncio.run(store2.store_event("stream-2", msg2))
            >>> # Now buffer is full, adding third will remove first
            >>> id3 = asyncio.run(store2.store_event("stream-2", msg3))
            >>> len(store2.streams["stream-2"])
            2
            >>> id1 in store2.event_index  # First event removed
            False
            >>> id2 in store2.event_index and id3 in store2.event_index
            True
        """
        # Look up the stream's ring buffer, creating it on first use.
        ring = self.streams.get(stream_id)
        if ring is None:
            ring = StreamBuffer(entries=[None] * self.max_events_per_stream)
            self.streams[stream_id] = ring

        # Claim the next per-stream sequence number and derive its slot.
        seq_num = ring.next_seq
        ring.next_seq += 1
        slot = seq_num % self.max_events_per_stream

        if ring.count == self.max_events_per_stream:
            # Buffer full: the slot being reused holds the oldest event, so drop
            # it from the index and advance the window start.
            stale = ring.entries[slot]
            if stale is not None:
                self.event_index.pop(stale.event_id, None)
            ring.start_seq += 1
        else:
            if ring.count == 0:
                ring.start_seq = seq_num
            ring.count += 1

        # Record the new entry in both the ring buffer and the lookup index.
        new_id = str(uuid4())
        new_entry = EventEntry(event_id=new_id, stream_id=stream_id, message=message, seq_num=seq_num)
        ring.entries[slot] = new_entry
        self.event_index[new_id] = new_entry

        return new_id

    async def replay_events_after(
        self,
        last_event_id: EventId,
        send_callback: EventCallback,
    ) -> Union[StreamId, None]:
        """
        Send every retained event that follows ``last_event_id`` on its stream.

        The anchor event is found through ``event_index`` in O(1); replay then
        walks only the k later sequence numbers, so no full scan is needed.

        Args:
            last_event_id (EventId): The ID of the last received event. Replay starts after this event.
            send_callback (EventCallback): Async callback to send each replayed event.

        Returns:
            StreamId | None: The stream ID if the event is found and replayed, otherwise None.

        Examples:
            >>> # Test replaying events
            >>> import asyncio
            >>> from mcp.types import JSONRPCMessage
            >>> store = InMemoryEventStore()
            >>> message1 = JSONRPCMessage(jsonrpc="2.0", method="test1", id=1)
            >>> message2 = JSONRPCMessage(jsonrpc="2.0", method="test2", id=2)
            >>> message3 = JSONRPCMessage(jsonrpc="2.0", method="test3", id=3)
            >>>
            >>> # Store events
            >>> event_id1 = asyncio.run(store.store_event("stream-1", message1))
            >>> event_id2 = asyncio.run(store.store_event("stream-1", message2))
            >>> event_id3 = asyncio.run(store.store_event("stream-1", message3))
            >>>
            >>> # Test replay after first event
            >>> replayed_events = []
            >>> async def mock_callback(event_message):
            ...     replayed_events.append(event_message)
            >>>
            >>> result = asyncio.run(store.replay_events_after(event_id1, mock_callback))
            >>> result
            'stream-1'
            >>> len(replayed_events)
            2

            >>> # Test replay with non-existent event
            >>> result = asyncio.run(store.replay_events_after("non-existent", mock_callback))
            >>> result is None
            True
        """
        anchor = self.event_index.get(last_event_id)
        if anchor is None:
            logger.warning("Event ID %s not found in store", last_event_id)
            return None

        ring = self.streams.get(anchor.stream_id)
        if ring is None:
            return None

        # The anchor must still fall inside the window the buffer retains.
        if not ring.start_seq <= anchor.seq_num < ring.next_seq:
            return None

        for seq in range(anchor.seq_num + 1, ring.next_seq):
            candidate = ring.entries[seq % self.max_events_per_stream]
            # Only send when the slot still holds exactly this sequence number;
            # an empty slot or one overwritten by a different seq is skipped.
            if candidate is not None and candidate.seq_num == seq:
                await send_callback(EventMessage(candidate.message, candidate.event_id))

        return anchor.stream_id
class RustEventStore(EventStore):
    """Rust-backed event store that delegates resumable stream state to the sidecar."""

    def __init__(self, max_events_per_stream: int = 100, ttl: int = 3600, key_prefix: str = _RUST_EVENT_STORE_DEFAULT_KEY_PREFIX):
        """Initialize the Rust-backed event store wrapper.

        Args:
            max_events_per_stream: Maximum number of events retained per stream.
            ttl: Event retention time in seconds.
            key_prefix: Redis key prefix shared with the Rust sidecar. Trailing
                ``:`` characters are stripped.
        """
        self.max_events_per_stream = max_events_per_stream
        self.ttl = ttl
        self.key_prefix = key_prefix.rstrip(":")

    async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId:
        """Store an event in the Rust-backed resumable event store.

        Args:
            stream_id: Stream that owns the event.
            message: JSON-RPC payload to persist for replay, or ``None``.

        Returns:
            The generated event identifier returned by the Rust sidecar.

        Raises:
            RuntimeError: If the sidecar response is not a JSON object or does
                not carry a non-empty string ``eventId``.
            httpx.HTTPStatusError: If the sidecar answers with an error status
                (raised by ``response.raise_for_status()``). Other httpx
                transport errors propagate unchanged.
        """
        client = await _get_rust_event_store_client()
        message_dict = None if message is None else (message.model_dump() if hasattr(message, "model_dump") else dict(message))
        response = await client.post(
            _build_rust_runtime_internal_url("/_internal/event-store/store"),
            json={
                "streamId": stream_id,
                "message": message_dict,
                "keyPrefix": self.key_prefix,
                "maxEventsPerStream": self.max_events_per_stream,
                "ttlSeconds": self.ttl,
            },
            timeout=httpx.Timeout(settings.experimental_rust_mcp_runtime_timeout_seconds),
            follow_redirects=False,
        )
        response.raise_for_status()
        payload = response.json()
        # A body that is valid JSON but not an object (list, string, null, ...)
        # must surface as the documented RuntimeError, not as an AttributeError
        # from calling ``.get`` on it.
        event_id = payload.get("eventId") if isinstance(payload, dict) else None
        if not isinstance(event_id, str) or not event_id:
            raise RuntimeError("Rust event store returned an invalid eventId")
        return event_id

    async def replay_events_after(self, last_event_id: EventId, send_callback: EventCallback) -> Union[StreamId, None]:
        """Replay events newer than ``last_event_id`` through the provided callback.

        Args:
            last_event_id: Last event acknowledged by the reconnecting client.
            send_callback: Callback invoked for each replayed event payload.

        Returns:
            The associated stream identifier when replay succeeds, else ``None``
            (including when the sidecar response is not a JSON object or lacks a
            non-empty string ``streamId``).

        Raises:
            httpx.HTTPStatusError: If the sidecar answers with an error status
                (raised by ``response.raise_for_status()``). Other httpx
                transport errors propagate unchanged.
        """
        client = await _get_rust_event_store_client()
        response = await client.post(
            _build_rust_runtime_internal_url("/_internal/event-store/replay"),
            json={
                "lastEventId": last_event_id,
                "keyPrefix": self.key_prefix,
            },
            timeout=httpx.Timeout(settings.experimental_rust_mcp_runtime_timeout_seconds),
            follow_redirects=False,
        )
        response.raise_for_status()
        payload = response.json()
        if not isinstance(payload, dict):
            # Malformed response: treat it like "nothing to replay" instead of
            # failing with an AttributeError on ``payload.get``.
            logger.warning("Rust event store returned a non-object replay payload")
            return None

        stream_id = payload.get("streamId")
        if not isinstance(stream_id, str) or not stream_id:
            return None

        # ``events`` may be absent, null or malformed; only a real list is replayed.
        events = payload.get("events")
        if not isinstance(events, list):
            events = []

        for event in events:
            if not isinstance(event, dict):
                continue
            # NOTE(review): the callback receives the raw ``message`` value from
            # the sidecar payload, whereas InMemoryEventStore passes an
            # ``EventMessage``; confirm the callback accepts this shape.
            await send_callback(event.get("message"))
        return stream_id
async def _get_rust_event_store_client() -> httpx.AsyncClient:
    """Provide the HTTP client for Python -> Rust event-store calls.

    When no Unix-domain-socket path is configured, the client from
    ``get_http_client()`` is returned. Otherwise a dedicated UDS-backed client
    is created on first use and then reused from the module-level variable.

    Returns:
        An async HTTP client configured for Rust event-store access.
    """
    global _rust_event_store_client  # pylint: disable=global-statement

    socket_path = settings.experimental_rust_mcp_runtime_uds
    if not socket_path:
        return await get_http_client()

    if _rust_event_store_client is None:
        async with _rust_event_store_client_lock:
            # Check again under the lock: another task may have created the
            # client while this one was waiting.
            if _rust_event_store_client is None:
                _rust_event_store_client = httpx.AsyncClient(
                    transport=httpx.AsyncHTTPTransport(uds=socket_path),
                    limits=get_http_limits(),
                    timeout=httpx.Timeout(settings.experimental_rust_mcp_runtime_timeout_seconds),
                    follow_redirects=False,
                )
    return _rust_event_store_client
def _build_rust_runtime_internal_url(path: str) -> str:
    """Join ``path`` onto the configured Rust runtime base URL.

    Any query string or fragment on the configured URL is discarded, and a
    trailing slash on its path is removed before ``path`` is appended.

    Args:
        path: Internal Rust runtime path to append to the configured base URL.

    Returns:
        Absolute URL targeting the Rust sidecar over HTTP or UDS-backed transport.
    """
    parts = urlsplit(settings.experimental_rust_mcp_runtime_url)
    prefix = parts.path.rstrip("/")
    if prefix:
        full_path = f"{prefix}{path}"
    else:
        full_path = path
    return urlunsplit((parts.scheme, parts.netloc, full_path, "", ""))
653# ------------------------------ Streamable HTTP Transport ------------------------------
@asynccontextmanager
async def get_db() -> AsyncGenerator[Session, Any]:
    """
    Async context manager that hands out one database session per use.

    The transaction is committed when the body finishes normally, which avoids
    implicit rollbacks for read-only work. On any exception the transaction is
    rolled back and the exception re-raised. ``asyncio.CancelledError`` gets its
    own branch so that a cancelled MCP handler (client disconnect, timeout,
    etc.) does not leak an open transaction.

    Yields:
        A database session instance from SessionLocal.
        Ensures the session is closed after use.

    Raises:
        asyncio.CancelledError: Re-raised after rollback and close on task cancellation.
        Exception: Re-raises any exception after rolling back the transaction.

    Examples:
        >>> # Test database context manager
        >>> import asyncio
        >>> async def test_db():
        ...     async with get_db() as db:
        ...         return db is not None
        >>> result = asyncio.run(test_db())
        >>> result
        True
    """
    session = SessionLocal()
    try:
        yield session
        session.commit()
    except asyncio.CancelledError:
        # Cancellation path: roll back and close right away, each step being
        # best effort, then let the cancellation continue to propagate.
        try:
            session.rollback()
        except Exception:
            pass  # nosec B110 - Best effort rollback on cancellation
        try:
            session.close()
        except Exception:
            pass  # nosec B110 - Best effort close on cancellation
        raise
    except Exception:
        # Failure path: try a rollback; if even that fails, fall back to
        # invalidating the session before re-raising the original error.
        try:
            session.rollback()
        except Exception:
            try:
                session.invalidate()
            except Exception:
                pass  # nosec B110 - Best effort cleanup on connection failure
        raise
    finally:
        session.close()
def get_user_email_from_context() -> str:
    """Resolve the current user's email from ``user_context_var``.

    Returns:
        User email address or 'unknown' if not available
    """
    context = user_context_var.get()
    if not isinstance(context, dict):
        # Non-dict context: use its string form when truthy.
        return str(context) if context else "unknown"
    # Prefer 'email', then fall back to 'sub' (JWT standard claim).
    return context.get("email") or context.get("sub") or "unknown"
def _should_enforce_streamable_rbac(user_context: Optional[dict[str, Any]]) -> bool:
    """Decide whether RBAC checks apply to the current request.

    Enforcement is tied to the context produced by the authenticated Streamable
    HTTP middleware. Direct unit tests may invoke MCP handlers without that
    context; such calls keep their historical behavior and skip forced RBAC.

    Args:
        user_context: Request user context propagated by Streamable HTTP auth middleware.

    Returns:
        bool: ``True`` when permission checks should be enforced for this request,
        i.e. the context is a dict whose ``is_authenticated`` value is exactly ``True``.
    """
    if not isinstance(user_context, dict):
        return False
    # Identity check on purpose: truthy non-bool values do not count.
    return user_context.get("is_authenticated", False) is True
def _build_resource_metadata_url(scope: Scope, server_id: str) -> str:
    """Construct the RFC 9728 OAuth Protected Resource Metadata URL from ASGI scope.

    Inspects ``x-forwarded-proto`` and ``host`` headers first (reverse-proxy
    scenario), then falls back to ``scope["scheme"]`` and ``scope["server"]``.
    Includes ``scope["root_path"]`` so that deployments behind a reverse proxy
    with a path prefix emit the correct public URL.

    Args:
        scope: ASGI connection scope.
        server_id: Virtual-server identifier.

    Returns:
        Fully-qualified URL string, or ``""`` if construction fails. This
        includes the case where no ``host`` header is present and the server
        is bound to a unix socket (``scope["server"]`` port is ``None``), as
        no network host can be derived from a socket path.
    """
    try:
        headers = Headers(scope=scope)

        # Scheme: first value of x-forwarded-proto if present, else the scope's
        # scheme; anything other than http/https is coerced to https.
        forwarded_proto = headers.get("x-forwarded-proto")
        if forwarded_proto:
            proto = forwarded_proto.split(",")[0].strip().lower()
        else:
            proto = scope.get("scheme", "https")
        if proto not in ("http", "https"):
            proto = "https"

        host = headers.get("host")
        if not host:
            server_tuple = scope.get("server")
            if not server_tuple:
                return ""
            host_addr, port = server_tuple
            if port is None:
                # ASGI uses [path, None] for unix sockets: the first element is
                # a filesystem path, not a host, so no valid URL can be built.
                return ""
            # Wrap IPv6 addresses in brackets per RFC 2732
            if ":" in str(host_addr):
                host_addr = f"[{host_addr}]"
            # Omit the port when it is the default for the chosen scheme.
            default_port = 443 if proto == "https" else 80
            host = f"{host_addr}:{port}" if port != default_port else host_addr

        root_path = scope.get("root_path", "").rstrip("/")
        return f"{proto}://{host}{root_path}/.well-known/oauth-protected-resource/servers/{server_id}/mcp"
    except Exception:
        # Best effort: any malformed scope yields "" rather than an error.
        return ""
async def _check_server_oauth_enforcement(server_id: str, user_context: Optional[dict[str, Any]]) -> None:
    """Reject unauthenticated callers when a server requires OAuth.

    The server's ``oauth_enabled`` flag is looked up and ``OAuthRequiredError``
    is raised when it is set but the caller is not authenticated. This closes
    the gap where OAuth capability is *advertised* (via RFC 9728
    ``experimental.oauth``) but never *enforced* on subsequent MCP requests.

    A passing check is recorded in the ``_oauth_checked_var`` context variable
    so that handler-level defense-in-depth calls do not repeat the DB query
    already performed by the middleware.

    .. note::
        SSE transport is not covered here because it already requires
        authentication unconditionally.

    Args:
        server_id: Virtual-server identifier extracted from the URL path.
        user_context: User context set by ``streamable_http_auth`` middleware.

    Raises:
        OAuthRequiredError: When the server requires OAuth and the caller has
            not provided valid authentication credentials.
        OAuthEnforcementUnavailableError: When the database or session is
            unavailable and the server's ``oauth_enabled`` flag cannot be
            verified (fail-closed).
    """
    # Already checked during this request.
    if _oauth_checked_var.get(False):
        return

    # No server context, so there is nothing to enforce.
    if not server_id or server_id == "default_server_id":
        return

    # Authenticated callers pass without a DB lookup.
    if (user_context or {}).get("is_authenticated", False):
        _oauth_checked_var.set(True)
        return

    # Lazy DB lookup to avoid import-time side-effects
    # Third-Party
    from sqlalchemy import select  # pylint: disable=import-outside-toplevel

    # First-Party
    from mcpgateway.db import Server as DbServer  # pylint: disable=import-outside-toplevel

    try:
        async with get_db() as db:
            lookup = select(DbServer).where(DbServer.id == server_id)
            db_server = db.execute(lookup).scalar_one_or_none()
            if db_server and db_server.oauth_enabled:
                logger.warning("OAuth required for server %s but caller is unauthenticated", server_id)
                raise OAuthRequiredError(
                    "This server requires OAuth authentication. Please provide a valid access token.",
                    server_id=server_id,
                )
        _oauth_checked_var.set(True)
    except SQLAlchemyError as exc:
        # The flag could not be verified, so fail closed for security.
        logger.error("OAuth enforcement DB lookup failed for server %s: %s", server_id, exc)
        raise OAuthEnforcementUnavailableError(
            f"Unable to verify OAuth requirements for server {server_id}",
            server_id=server_id,
        ) from exc
async def _check_streamable_permission(
    *,
    user_context: dict[str, Any],
    permission: str,
    allow_admin_bypass: bool = True,
    check_any_team: bool = False,
) -> bool:
    """Run the RBAC permission check for a Streamable HTTP request context.

    The check fails closed: a context without an email, a denied permission,
    and any exception raised during the lookup all produce ``False``.

    Args:
        user_context: Authenticated user context from Streamable HTTP middleware.
        permission: Permission name to evaluate (for example ``tools.execute``).
        allow_admin_bypass: Whether unrestricted admin tokens can bypass team checks.
        check_any_team: Whether any matching team grants permission.

    Returns:
        bool: ``True`` when the caller is authorized for ``permission``.
    """
    email = user_context.get("email")
    if not email:
        return False

    try:
        async with get_db() as db:
            service = PermissionService(db)
            is_granted = await service.check_permission(
                user_email=email,
                permission=permission,
                token_teams=user_context.get("teams"),
                allow_admin_bypass=allow_admin_bypass,
                check_any_team=check_any_team,
            )
            if not is_granted:
                logger.warning("Streamable HTTP RBAC denied: user=%s, permission=%s", email, permission)
            return is_granted
    except Exception as exc:
        # Any failure while checking is treated as a denial.
        logger.warning("Streamable HTTP RBAC check failed for %s / %s: %s", email, permission, exc)
        return False
891def _check_scoped_permission(user_context: dict[str, Any], permission: str) -> bool:
892 """Check if token scoped permissions allow this operation.
894 Args:
895 user_context: User context dict (may contain 'scoped_permissions' key).
896 permission: Permission to check.
898 Returns:
899 True if allowed (no scope cap, wildcard, or permission present).
900 """
901 scoped = user_context.get("scoped_permissions")
902 if not scoped: # None or empty list = defer to RBAC
903 return True
904 if "*" in scoped:
905 return True
906 allowed = permission in scoped
907 if not allowed:
908 logger.warning("Streamable HTTP token scope denied: user=%s, required=%s", user_context.get("email"), permission)
909 return allowed
912def _check_any_team_for_server_scoped_rbac(user_context: dict[str, Any] | None, server_id: str | None) -> bool:
913 """Return whether Streamable HTTP RBAC should check across team-scoped roles.
915 Server-scoped MCP routes (``/servers/<id>/mcp``) should authorize team-bound
916 callers against the specific virtual server context. Session tokens already do
917 this via ``check_any_team=True`` because they have no single explicit team_id.
918 Team-scoped API tokens need the same treatment on server-scoped routes; otherwise
919 they are evaluated only in global scope and incorrectly denied.
921 Args:
922 user_context: Current authenticated MCP user context, if any.
923 server_id: Effective virtual server identifier for the MCP request.
925 Returns:
926 ``True`` when RBAC should search across the caller's token teams.
927 """
928 if not user_context:
929 return False
930 if user_context.get("token_use") == "session":
931 return True
932 return bool(server_id) and bool(user_context.get("teams"))
def set_shared_session_registry(session_registry: Any) -> None:
    """Install the session registry that Streamable HTTP helpers consult.

    The reference is stored in the module-level ``_shared_session_registry``
    variable, replacing whatever was stored before.

    Args:
        session_registry: Registry instance to publish (``None`` clears it).
    """
    global _shared_session_registry  # pylint: disable=global-statement
    _shared_session_registry = session_registry
def _get_shared_session_registry() -> Optional[Any]:
    """Read the module-level session registry reference.

    Returns:
        Optional[Any]: The current value of ``_shared_session_registry``
        (``None`` when no registry is stored).
    """
    return _shared_session_registry
954async def _claim_streamable_session_owner(session_id: str, owner_email: str) -> Optional[str]:
955 """Claim or resolve the logical owner for a Streamable HTTP session.
957 Args:
958 session_id: Logical MCP session identifier to claim.
959 owner_email: Caller email that should own the session.
961 Returns:
962 Optional[str]: Effective owner email after claim, or ``None`` if unavailable.
963 """
964 if not session_id or not owner_email:
965 return None
967 session_registry = _get_shared_session_registry()
968 if session_registry is None:
969 return None
971 try:
972 return await session_registry.claim_session_owner(session_id, owner_email)
973 except Exception as exc:
974 logger.warning("Failed to claim session owner for %s: %s", session_id, exc)
975 return None
async def _validate_streamable_session_access(
    *,
    mcp_session_id: Optional[str],
    user_context: Optional[dict[str, Any]],
    rpc_method: Optional[str] = None,
) -> tuple[bool, int, str]:
    """Decide whether the caller may use a stateful Streamable HTTP session id.

    Access is allowed without an ownership lookup when stateful sessions are
    disabled, when no real session id was supplied, when RBAC is not enforced
    for this caller, when the context is flagged ``_rust_session_validated``,
    or when the JSON-RPC method is ``initialize``. Otherwise the shared session
    registry is consulted: admins and the recorded owner are allowed, everyone
    else is denied. Every registry failure denies access (fail-closed).

    Args:
        mcp_session_id: Session identifier from request headers.
        user_context: Authenticated user context for the current request.
        rpc_method: JSON-RPC method name when available.

    Returns:
        Tuple ``(allowed, deny_status_code, deny_message)``; allowed results
        are always ``(True, 200, "")``.
    """
    allow: tuple[bool, int, str] = (True, 200, "")

    if not settings.use_stateful_sessions:
        return allow
    if not mcp_session_id or mcp_session_id == "not-provided":
        return allow
    if not _should_enforce_streamable_rbac(user_context):
        return allow

    context_is_dict = isinstance(user_context, dict)
    if context_is_dict and user_context.get("_rust_session_validated") is True:
        return allow

    # Initialize establishes a new session and is authorized separately.
    if (rpc_method or "").strip() == "initialize":
        return allow

    requester_email = user_context.get("email") if context_is_dict else None
    requester_is_admin = bool(user_context.get("is_admin", False)) if context_is_dict else False

    registry = _get_shared_session_registry()
    if registry is None:
        return False, HTTP_403_FORBIDDEN, "Session ownership unavailable"

    try:
        owner = await registry.get_session_owner(mcp_session_id)
    except Exception as exc:
        logger.warning("Failed to get session owner for %s: %s", mcp_session_id, exc)
        return False, HTTP_403_FORBIDDEN, "Session ownership unavailable"

    if owner:
        # Known owner: only an admin or the owner themself may proceed.
        if requester_is_admin or (requester_email and requester_email == owner):
            return allow
        return False, HTTP_403_FORBIDDEN, "Session access denied"

    # No owner recorded: distinguish "session is gone" from "metadata missing".
    try:
        exists = await registry.session_exists(mcp_session_id)
    except Exception as exc:
        logger.warning("Failed to check session existence for %s: %s", mcp_session_id, exc)
        return False, HTTP_403_FORBIDDEN, "Session ownership unavailable"

    if exists is False:
        return False, HTTP_404_NOT_FOUND, "Session not found"
    return False, HTTP_403_FORBIDDEN, "Session owner metadata unavailable"
async def _proxy_list_tools_to_gateway(gateway: Any, request_headers: dict, user_context: dict, meta: Optional[Any] = None) -> List[types.Tool]:  # pylint: disable=unused-argument
    """Forward a ``tools/list`` request straight to a remote MCP gateway.

    Outbound headers start from ``build_gateway_auth_headers(gateway)`` and,
    when client headers are present, are extended through
    ``compute_passthrough_headers_cached``. The allow-list is the gateway's own
    ``passthrough_headers`` when set, otherwise the globally configured one.

    Args:
        gateway: Gateway ORM instance
        request_headers: Request headers from client
        user_context: User context (unused by this function)
        meta: Request metadata (_meta) from the original request

    Returns:
        List of Tool objects from the remote server, or an empty list when
        anything in the proxy path raises (the error is logged).
    """
    try:
        # Start from the gateway's own auth headers.
        outbound_headers = build_gateway_auth_headers(gateway)

        # Forward passthrough headers using shared utility (includes X-Upstream-Authorization rename)
        if request_headers:
            gateway_allowlist = getattr(gateway, "passthrough_headers", None)
            if gateway_allowlist is None:
                # No per-gateway list: fall back to the global configuration.
                with SessionLocal() as db:
                    allowed_headers = global_config_cache.get_passthrough_headers(db, settings.default_passthrough_headers)
            else:
                allowed_headers = gateway_allowlist
            outbound_headers = compute_passthrough_headers_cached(
                request_headers,
                outbound_headers,
                allowed_headers,
                gateway_auth_type=getattr(gateway, "auth_type", None),
                gateway_passthrough_headers=gateway_allowlist,
            )

        # Use MCP SDK to connect and list tools
        async with streamablehttp_client(url=gateway.url, headers=outbound_headers, timeout=settings.mcpgateway_direct_proxy_timeout) as (read_stream, write_stream, _get_session_id):
            async with ClientSession(read_stream, write_stream) as session:
                await session.initialize()

                # Only build request params when there is _meta to forward.
                if meta:
                    list_params = PaginatedRequestParams(_meta=meta)
                    logger.debug("Forwarding _meta to remote gateway: %s", meta)
                else:
                    list_params = None

                listing = await session.list_tools(params=list_params)
                return listing.tools

    except Exception as e:
        logger.exception("Error proxying tools/list to gateway %s: %s", gateway.id, e)
        return []
async def _proxy_list_resources_to_gateway(gateway: Any, request_headers: dict, user_context: dict, meta: Optional[Any] = None) -> List[types.Resource]:  # pylint: disable=unused-argument
    """Proxy resources/list request directly to remote MCP gateway using MCP SDK.

    Outbound headers start from ``build_gateway_auth_headers(gateway)`` and,
    when client headers are present, are extended through
    ``compute_passthrough_headers_cached``. The allow-list is the gateway's own
    ``passthrough_headers`` when set, otherwise the globally configured one.

    Args:
        gateway: Gateway ORM instance
        request_headers: Request headers from client
        user_context: User context (unused by this function)
        meta: Request metadata (_meta) from the original request

    Returns:
        List of Resource objects from the remote server, or an empty list when
        anything in the proxy path raises (the error is logged).
    """
    try:
        # Prepare headers with gateway auth
        headers = build_gateway_auth_headers(gateway)

        # Forward passthrough headers using shared utility (includes X-Upstream-Authorization rename)
        if request_headers:
            gw_passthrough = getattr(gateway, "passthrough_headers", None)
            if gw_passthrough is not None:
                passthrough_allowed = gw_passthrough
            else:
                # No per-gateway list: fall back to the global configuration.
                with SessionLocal() as db:
                    passthrough_allowed = global_config_cache.get_passthrough_headers(db, settings.default_passthrough_headers)
            headers = compute_passthrough_headers_cached(
                request_headers,
                headers,
                passthrough_allowed,
                gateway_auth_type=getattr(gateway, "auth_type", None),
                gateway_passthrough_headers=gw_passthrough,
            )

        logger.info("Proxying resources/list to gateway %s at %s", gateway.id, gateway.url)
        if meta:
            # Logged once, before connecting, so it is emitted even if the connection fails.
            logger.debug("Forwarding _meta to remote gateway: %s", meta)

        # Use MCP SDK to connect and list resources
        async with streamablehttp_client(url=gateway.url, headers=headers, timeout=settings.mcpgateway_direct_proxy_timeout) as (read_stream, write_stream, _get_session_id):
            async with ClientSession(read_stream, write_stream) as session:
                await session.initialize()

                # Build request params only when there is _meta to forward.
                params = PaginatedRequestParams(_meta=meta) if meta else None

                # List resources with _meta forwarded
                result = await session.list_resources(params=params)

                logger.info("Received %s resources from gateway %s", len(result.resources), gateway.id)
                return result.resources

    except Exception as e:
        logger.exception("Error proxying resources/list to gateway %s: %s", gateway.id, e)
        return []
async def _proxy_read_resource_to_gateway(gateway: Any, resource_uri: str, user_context: dict, meta: Optional[Any] = None) -> List[Any]:  # pylint: disable=unused-argument
    """Forward a ``resources/read`` request straight to a remote MCP gateway.

    Client headers are taken from ``request_headers_var``. The gateway-id
    header, when present on the incoming request, is copied onto the outbound
    request; passthrough headers are then applied via
    ``compute_passthrough_headers_cached``.

    Args:
        gateway: Gateway ORM instance
        resource_uri: URI of the resource to read
        user_context: User context (unused by this function)
        meta: Request metadata (_meta) from the original request

    Returns:
        List of content objects from the remote server's read result, or an
        empty list when anything in the proxy path raises (the error is logged).
    """
    try:
        # Start from the gateway's own auth headers.
        outbound_headers = build_gateway_auth_headers(gateway)

        # Client headers come from the per-request context variable.
        inbound_headers = request_headers_var.get()

        # Forward X-Context-Forge-Gateway-Id header
        forwarded_gateway_id = extract_gateway_id_from_headers(inbound_headers)
        if forwarded_gateway_id:
            outbound_headers[GATEWAY_ID_HEADER] = forwarded_gateway_id

        # Forward passthrough headers using shared utility (includes X-Upstream-Authorization rename)
        if inbound_headers:
            gateway_allowlist = getattr(gateway, "passthrough_headers", None)
            if gateway_allowlist is None:
                # No per-gateway list: fall back to the global configuration.
                with SessionLocal() as db:
                    allowed_headers = global_config_cache.get_passthrough_headers(db, settings.default_passthrough_headers)
            else:
                allowed_headers = gateway_allowlist
            outbound_headers = compute_passthrough_headers_cached(
                inbound_headers,
                outbound_headers,
                allowed_headers,
                gateway_auth_type=getattr(gateway, "auth_type", None),
                gateway_passthrough_headers=gateway_allowlist,
            )

        logger.info("Proxying resources/read for %s to gateway %s at %s", resource_uri, gateway.id, gateway.url)
        if meta:
            logger.debug("Forwarding _meta to remote gateway: %s", meta)

        # Use MCP SDK to connect and read resource
        async with streamablehttp_client(url=gateway.url, headers=outbound_headers, timeout=settings.mcpgateway_direct_proxy_timeout) as (read_stream, write_stream, _get_session_id):
            async with ClientSession(read_stream, write_stream) as session:
                await session.initialize()

                if not meta:
                    # No _meta, use simple read_resource
                    read_result = await session.read_resource(uri=resource_uri)
                else:
                    # Dump the params to a dict, inject _meta, and re-validate
                    # so the request carries the caller's metadata.
                    params_payload = ReadResourceRequestParams(uri=resource_uri).model_dump()
                    params_payload["_meta"] = meta
                    read_request = ReadResourceRequest(params=ReadResourceRequestParams.model_validate(params_payload))
                    read_result = await session.send_request(
                        types.ClientRequest(read_request),
                        types.ReadResourceResult,
                    )

                logger.info("Received %s content items from gateway %s for resource %s", len(read_result.contents), gateway.id, resource_uri)
                return read_result.contents

    except Exception as e:
        logger.exception("Error proxying resources/read to gateway %s for resource %s: %s", gateway.id, resource_uri, e)
        return []
@mcp_app.call_tool(validate_input=False)
async def call_tool(name: str, arguments: dict) -> Union[
    types.CallToolResult,
    List[Union[types.TextContent, types.ImageContent, types.AudioContent, types.ResourceLink, types.EmbeddedResource]],
    Tuple[List[Union[types.TextContent, types.ImageContent, types.AudioContent, types.ResourceLink, types.EmbeddedResource]], Dict[str, Any]],
]:
    """
    Handles tool invocation via the MCP Server.

    Note: validate_input=False disables the MCP SDK's built-in JSON Schema validation.
    This is necessary because the SDK uses jsonschema.validate() which internally calls
    check_schema() with the default validator. Schemas using older draft features
    (e.g., Draft 4 style exclusiveMinimum: true) fail this validation. The gateway
    handles schema validation separately in tool_service.py with multi-draft support.

    This function supports the MCP protocol's tool calling with structured content validation.
    In direct_proxy mode, returns the raw CallToolResult from the remote server.
    In normal mode, converts ToolResult to CallToolResult with content normalization.

    Processing order, as implemented below:

    1. Resolve server id, request headers and user context.
    2. Enforce per-server OAuth (only when ``settings.mcp_require_auth`` is off),
       then the token scope cap and RBAC for ``tools.execute``.
    3. If the request carries a gateway-id header and that gateway is in
       ``direct_proxy`` mode, invoke the tool directly on the remote gateway.
    4. If session affinity is enabled and a session id header is present, try
       to forward the call to the worker that owns the session.
    5. Otherwise invoke the tool locally through ``tool_service`` and normalize
       the returned content to MCP SDK types.

    Args:
        name (str): The name of the tool to invoke.
        arguments (dict): A dictionary of arguments to pass to the tool.

    Returns:
        types.CallToolResult: MCP SDK CallToolResult with content and optional structuredContent
        (direct-proxy path, including its error results). On the forwarded and local paths the
        function instead returns a list of MCP content items, or a ``(content_list, structured)``
        tuple when structured content is present; an empty list is returned when the tool
        produced no content.

    Raises:
        PermissionError: If the caller lacks ``tools.execute`` permission.
        Exception: Re-raised after logging to allow MCP SDK to convert to JSON-RPC error response.

    Examples:
        >>> # Test call_tool function signature
        >>> import inspect
        >>> sig = inspect.signature(call_tool)
        >>> list(sig.parameters.keys())
        ['name', 'arguments']
        >>> sig.parameters['name'].annotation
        <class 'str'>
        >>> sig.parameters['arguments'].annotation
        <class 'dict'>
    """
    server_id, request_headers, user_context = await _get_request_context_or_default()

    meta_data = None
    # Extract _meta from request context if available
    try:
        ctx = mcp_app.request_context
        if ctx and ctx.meta is not None:
            meta_data = ctx.meta.model_dump()
    except LookupError:
        # request_context might not be active in some edge cases (e.g. tests)
        logger.debug("No active request context found")

    # Extract authorization parameters from user context (same pattern as list_tools)
    user_email = user_context.get("email") if user_context else None
    token_teams = user_context.get("teams") if user_context else None
    is_admin = user_context.get("is_admin", False) if user_context else False

    # Admin bypass - only when token has NO team restrictions (token_teams is None)
    # If token has explicit team scope (even empty [] for public-only), respect it
    if is_admin and token_teams is None:
        user_email = None
        # token_teams stays None (unrestricted)
    elif token_teams is None:
        token_teams = []  # Non-admin without teams = public-only (secure default)

    # Enforce per-server OAuth requirement in permissive mode (defense-in-depth).
    # When mcp_require_auth=True, the middleware already guarantees authentication.
    # Note: OAuthEnforcementUnavailableError is intentionally uncaught here —
    # the middleware (streamable_http_auth) catches it and returns 503. If the
    # middleware is somehow bypassed, an uncaught 500 is acceptable and will be
    # logged by the ASGI server.
    if not settings.mcp_require_auth:
        await _check_server_oauth_enforcement(server_id, user_context)

    if _should_enforce_streamable_rbac(user_context):
        # Layer 1: Token scope cap
        if not _check_scoped_permission(user_context, "tools.execute"):
            raise PermissionError(_ACCESS_DENIED_MSG)
        # Layer 2: RBAC check
        # Session tokens have no explicit team_id; check across all team-scoped roles.
        # Mirrors the @require_permission decorator's check_any_team fallback (rbac.py:562-576).
        has_execute_permission = await _check_streamable_permission(
            user_context=user_context,
            permission="tools.execute",
            check_any_team=_check_any_team_for_server_scoped_rbac(user_context, server_id),
        )
        if not has_execute_permission:
            raise PermissionError(_ACCESS_DENIED_MSG)

    # Check if we're in direct_proxy mode by looking for X-Context-Forge-Gateway-Id header
    gateway_id_from_header = extract_gateway_id_from_headers(request_headers)

    # If X-Context-Forge-Gateway-Id header is present, use direct proxy mode
    if gateway_id_from_header:
        try:  # Check if this gateway is in direct_proxy mode
            async with get_db() as check_db:
                # Third-Party
                from sqlalchemy import select  # pylint: disable=import-outside-toplevel

                # First-Party
                from mcpgateway.db import Gateway as DbGateway  # pylint: disable=import-outside-toplevel

                gateway = check_db.execute(select(DbGateway).where(DbGateway.id == gateway_id_from_header)).scalar_one_or_none()
                # Direct proxy requires: the gateway exists, is configured for
                # direct_proxy, and the feature flag is enabled. Otherwise fall
                # through to the normal invocation path below.
                if gateway and getattr(gateway, "gateway_mode", "cache") == "direct_proxy" and settings.mcpgateway_direct_proxy_enabled:
                    # SECURITY: Check gateway access before allowing direct proxy
                    if not await check_gateway_access(check_db, gateway, user_email, token_teams):
                        logger.warning("Access denied to gateway %s in direct_proxy mode for user %s", gateway_id_from_header, user_email)
                        # Reported as "Tool not found" rather than an explicit access-denied message.
                        return types.CallToolResult(content=[types.TextContent(type="text", text=f"Tool not found: {name}")], isError=True)

                    logger.info("Using direct_proxy mode for tool '%s' via gateway %s", name, gateway_id_from_header)

                    # Use direct proxy method - returns raw CallToolResult from remote server
                    # Return it directly without any normalization
                    return await tool_service.invoke_tool_direct(
                        gateway_id=gateway_id_from_header,
                        name=name,
                        arguments=arguments,
                        request_headers=request_headers,
                        meta_data=meta_data,
                        user_email=user_email,
                        token_teams=token_teams,
                    )
        except Exception as e:
            # Any failure on the direct-proxy path is logged and reported to
            # the client as a generic error result (details are not exposed).
            logger.error("Direct proxy mode failed for gateway %s: %s", gateway_id_from_header, e)
            return types.CallToolResult(content=[types.TextContent(type="text", text="Direct proxy tool invocation failed")], isError=True)

    # Normal mode: use standard tool invocation with normalization
    # Use the already-recovered user_context (works for both ContextVar and stateful session paths)
    app_user_email = (user_context.get("email") or user_context.get("sub") or "unknown") if user_context else "unknown"

    # Multi-worker session affinity: check if we should forward to another worker
    # Check both x-mcp-session-id (internal/forwarded) and mcp-session-id (client protocol header)
    mcp_session_id = None
    if request_headers:
        # Header names are lower-cased so the lookup is case-insensitive.
        request_headers_lower = {k.lower(): v for k, v in request_headers.items()}
        mcp_session_id = request_headers_lower.get("x-mcp-session-id") or request_headers_lower.get("mcp-session-id")
    if settings.mcpgateway_session_affinity_enabled and mcp_session_id:
        try:
            # First-Party
            from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache  # pylint: disable=import-outside-toplevel
            from mcpgateway.services.mcp_session_pool import get_mcp_session_pool  # pylint: disable=import-outside-toplevel
            from mcpgateway.services.mcp_session_pool import MCPSessionPool  # pylint: disable=import-outside-toplevel

            if not MCPSessionPool.is_valid_mcp_session_id(mcp_session_id):
                logger.debug("Invalid MCP session id for Streamable HTTP tool affinity, executing locally")
                # RuntimeError is the "skip affinity, run locally" signal handled at the end of this try.
                raise RuntimeError("invalid mcp session id")

            pool = get_mcp_session_pool()

            # Register session mapping BEFORE checking forwarding (same pattern as SSE)
            # This ensures ownership is registered atomically so forward_request_to_owner() works
            try:
                cached = await tool_lookup_cache.get(name)
                if cached and cached.get("status") == "active":
                    gateway_info = cached.get("gateway")
                    if gateway_info:
                        url = gateway_info.get("url")
                        gateway_id = gateway_info.get("id", "")
                        transport_type = gateway_info.get("transport", "streamablehttp")
                        if url:
                            await pool.register_session_mapping(mcp_session_id, url, gateway_id, transport_type, user_email)
            except Exception as e:
                # Pre-registration is best-effort; forwarding is still attempted.
                logger.error("Failed to pre-register session mapping for Streamable HTTP: %s", e)

            forwarded_response = await pool.forward_request_to_owner(
                mcp_session_id,
                {"method": "tools/call", "params": {"name": name, "arguments": arguments, "_meta": meta_data}, "headers": dict(request_headers) if request_headers else {}},
            )
            # A None response falls through to local execution below.
            if forwarded_response is not None:
                # Request was handled by another worker - convert response to expected format
                if "error" in forwarded_response:
                    # Not a RuntimeError, so it propagates to the caller instead of falling back to local execution.
                    raise Exception(forwarded_response["error"].get("message", "Forwarded request failed"))  # pylint: disable=broad-exception-raised
                result_data = forwarded_response.get("result", {})

                def _rehydrate_content_items(items: Any) -> list[types.TextContent | types.ImageContent | types.AudioContent | types.ResourceLink | types.EmbeddedResource]:
                    """Convert forwarded tool result items back to MCP content types.

                    Non-dict items are skipped. Items with an unknown ``type``, or
                    items that fail model validation, are wrapped as a text item
                    holding their JSON serialization.

                    Args:
                        items: List of content item dicts from forwarded response.

                    Returns:
                        List of validated MCP content type instances.
                    """
                    if not isinstance(items, list):
                        return []
                    converted: list[types.TextContent | types.ImageContent | types.AudioContent | types.ResourceLink | types.EmbeddedResource] = []
                    for item in items:
                        if not isinstance(item, dict):
                            continue
                        item_type = item.get("type")
                        try:
                            if item_type == "text":
                                converted.append(types.TextContent.model_validate(item))
                            elif item_type == "image":
                                converted.append(types.ImageContent.model_validate(item))
                            elif item_type == "audio":
                                converted.append(types.AudioContent.model_validate(item))
                            elif item_type == "resource_link":
                                converted.append(types.ResourceLink.model_validate(item))
                            elif item_type == "resource":
                                converted.append(types.EmbeddedResource.model_validate(item))
                            else:
                                converted.append(types.TextContent(type="text", text=item if isinstance(item, str) else orjson.dumps(item).decode()))
                        except Exception:
                            # Validation failed: fall back to a JSON text rendering of the item.
                            converted.append(types.TextContent(type="text", text=item if isinstance(item, str) else orjson.dumps(item).decode()))
                    return converted

                unstructured = _rehydrate_content_items(result_data.get("content", []))
                # Accept either the camelCase or the snake_case key for structured content.
                structured = result_data.get("structuredContent") or result_data.get("structured_content")
                if structured:
                    return (unstructured, structured)
                return unstructured
        except RuntimeError:
            # Pool not initialized - execute locally
            pass

    try:
        async with get_db() as db:
            # Use tool service for all tool invocations (handles direct_proxy internally)
            result = await tool_service.invoke_tool(
                db=db,
                name=name,
                arguments=arguments,
                request_headers=request_headers,
                app_user_email=app_user_email,
                user_email=user_email,
                token_teams=token_teams,
                server_id=server_id,
                meta_data=meta_data,
            )
            if not result or not result.content:
                logger.warning("No content returned by tool: %s", name)
                return []

            # Normalize unstructured content to MCP SDK types, preserving metadata (annotations, _meta, size)
            # Helper to convert gateway Annotations to dict for MCP SDK compatibility
            # (mcpgateway.common.models.Annotations != mcp.types.Annotations)
            def _convert_annotations(ann: Any) -> dict[str, Any] | None:
                """Convert gateway Annotations to dict for MCP SDK compatibility.

                Args:
                    ann: Gateway Annotations object, dict, or None.

                Returns:
                    Dict representation of annotations, or None.
                """
                if ann is None:
                    return None
                if isinstance(ann, dict):
                    return ann
                if hasattr(ann, "model_dump"):
                    return ann.model_dump(by_alias=True, mode="json")
                return None

            def _convert_meta(meta: Any) -> dict[str, Any] | None:
                """Convert gateway meta to dict for MCP SDK compatibility.

                Args:
                    meta: Gateway meta object, dict, or None.

                Returns:
                    Dict representation of meta, or None.
                """
                if meta is None:
                    return None
                if isinstance(meta, dict):
                    return meta
                if hasattr(meta, "model_dump"):
                    return meta.model_dump(by_alias=True, mode="json")
                return None

            unstructured: list[types.TextContent | types.ImageContent | types.AudioContent | types.ResourceLink | types.EmbeddedResource] = []
            # Map each gateway content item onto the matching MCP SDK type by its ``type`` tag.
            for content in result.content:
                if content.type == "text":
                    unstructured.append(
                        types.TextContent(
                            type="text",
                            text=content.text,
                            annotations=_convert_annotations(getattr(content, "annotations", None)),
                            _meta=_convert_meta(getattr(content, "meta", None)),
                        )
                    )
                elif content.type == "image":
                    unstructured.append(
                        types.ImageContent(
                            type="image",
                            data=content.data,
                            mimeType=content.mime_type,
                            annotations=_convert_annotations(getattr(content, "annotations", None)),
                            _meta=_convert_meta(getattr(content, "meta", None)),
                        )
                    )
                elif content.type == "audio":
                    unstructured.append(
                        types.AudioContent(
                            type="audio",
                            data=content.data,
                            mimeType=content.mime_type,
                            annotations=_convert_annotations(getattr(content, "annotations", None)),
                            _meta=_convert_meta(getattr(content, "meta", None)),
                        )
                    )
                elif content.type == "resource_link":
                    unstructured.append(
                        types.ResourceLink(
                            type="resource_link",
                            uri=content.uri,
                            name=content.name,
                            description=getattr(content, "description", None),
                            mimeType=getattr(content, "mime_type", None),
                            size=getattr(content, "size", None),
                            _meta=_convert_meta(getattr(content, "meta", None)),
                        )
                    )
                elif content.type == "resource":
                    # EmbeddedResource - pass through the model dump as the MCP SDK type requires complex nested structure
                    unstructured.append(types.EmbeddedResource.model_validate(content.model_dump(by_alias=True, mode="json")))
                else:
                    # Unknown content type - convert to text representation
                    unstructured.append(types.TextContent(type="text", text=orjson.dumps(content.model_dump(by_alias=True, mode="json")).decode()))

            # If the tool produced structured content (ToolResult.structured_content / structuredContent),
            # return a combination (unstructured, structured) so the server can validate against outputSchema.
            # The ToolService may populate structured_content (snake_case) or the model may expose
            # an alias 'structuredContent' when dumped via model_dump(by_alias=True).
            structured = None
            try:
                # Prefer attribute if present
                structured = getattr(result, "structured_content", None)
            except Exception:
                structured = None

            # Fallback to by-alias dump (in case the result is a pydantic model with alias fields)
            if structured is None:
                try:
                    structured = result.model_dump(by_alias=True).get("structuredContent") if hasattr(result, "model_dump") else None
                except Exception:
                    structured = None

            # Falsy structured content (None or empty) yields the plain content list.
            if structured:
                return (unstructured, structured)

            return unstructured
    except Exception as e:
        logger.exception("Error calling tool '%s': %s", name, e)
        # Re-raise the exception so the MCP SDK can properly convert it to an error response
        # This ensures error details are propagated to the client instead of returning empty results
        raise
async def _get_request_context_or_default() -> Tuple[str, dict[str, Any], dict[str, Any]]:
    """Resolve ``(server_id, request_headers, user_context)`` for the current call.

    Sources are tried in this order:

    1. Context variables (fast path). Used whenever ``server_id_var`` holds
       something other than its ``"default_server_id"`` placeholder.
    2. ASGI scope. The dict stored under ``scope[_MCPGATEWAY_CONTEXT_KEY]`` on
       the MCP SDK's request object, when present.
    3. Re-authentication fallback. The server id is parsed from the request
       path and identity is re-derived from the Authorization header or the
       ``jwt_token`` cookie. Anonymous callers, unexpected payload types and
       authentication failures all yield an empty user context here.

    Outside a request context, or when recovery fails unexpectedly, the
    context-variable values are returned.

    Returns:
        Tuple[str, dict[str, Any], dict[str, Any]]: A tuple containing:

        - server_id: The resolved server identifier.
        - request_headers: The request headers as a dictionary.
        - user_context: The resolved user context dictionary.
    """
    # 1. Context variables: populated means "not the placeholder default".
    resolved_server_id = server_id_var.get()
    if resolved_server_id != "default_server_id":
        return resolved_server_id, request_headers_var.get(), user_context_var.get()

    # 2. Context injected into the ASGI scope.
    sdk_ctx = None
    try:
        sdk_ctx = mcp_app.request_context
        sdk_request = sdk_ctx.request
        if sdk_request:
            injected = getattr(sdk_request, "scope", {}).get(_MCPGATEWAY_CONTEXT_KEY)
            if isinstance(injected, dict):
                return (
                    injected.get("server_id") or resolved_server_id,
                    injected.get("request_headers", {}),
                    injected.get("user_context", {}),
                )
    except LookupError:
        # Not in a request context: ContextVar defaults are all we have.
        return resolved_server_id, request_headers_var.get(), user_context_var.get()
    except Exception as e:
        logger.debug("Failed to read %s from scope: %s", _MCPGATEWAY_CONTEXT_KEY, e)

    # 3. Re-authentication fallback (stateful session path).
    try:
        # Reuse the context object from step 2 when it was obtained.
        if sdk_ctx is None:
            sdk_ctx = mcp_app.request_context
        sdk_request = sdk_ctx.request
        if not sdk_request:
            logger.warning("No request object found in MCP context")
            return resolved_server_id, request_headers_var.get(), user_context_var.get()

        # Server id comes from the URL path when it matches the server route pattern.
        path_match = _SERVER_ID_RE.search(sdk_request.url.path)
        if path_match:
            resolved_server_id = path_match.group("server_id")

        inbound_headers = dict(sdk_request.headers)

        # Credentials are handed to require_auth_header_first: the Authorization
        # header plus the jwt_token cookie.
        bearer_header = inbound_headers.get("authorization")
        cookie_jwt = sdk_request.cookies.get("jwt_token")

        try:
            auth_payload = await require_auth_header_first(auth_header=bearer_header, jwt_token=cookie_jwt, request=sdk_request)
            # Only a dict payload is normalized; a string payload ("anonymous")
            # or any other type yields an empty context.
            recovered_ctx = await _normalize_jwt_payload(auth_payload) if isinstance(auth_payload, dict) else {}
        except Exception as e:
            logger.warning("Failed to recover user context in stateful session: %s", e)
            recovered_ctx = {}

        return resolved_server_id, inbound_headers, recovered_ctx

    except LookupError:
        # Not in a request context
        return resolved_server_id, request_headers_var.get(), user_context_var.get()
    except Exception as e:
        logger.exception("Error recovering context in stateful session: %s", e)
        return resolved_server_id, request_headers_var.get(), user_context_var.get()
async def _normalize_jwt_payload(payload: dict[str, Any]) -> dict[str, Any]:
    """Build the canonical user context dict from a raw JWT payload.

    Reads the identity (``sub`` falling back to ``email``), the admin flag
    (top-level ``is_admin`` or nested ``user.is_admin``), ``token_use`` and
    ``scopes.permissions`` from the payload and assembles the
    ``{email, teams, is_admin, is_authenticated, token_use}`` dict that MCP
    handlers expect. This mirrors the normalization performed by
    ``streamable_http_auth`` so that the stateful-session fallback path in
    ``_get_request_context_or_default`` returns an identical shape.

    Args:
        payload: Raw JWT payload dict from ``require_auth_header_first``.

    Returns:
        Canonical user context dict with keys email, teams, is_admin, is_authenticated, token_use.
        ``scoped_permissions`` is added only when the token carries a non-empty permission list.
    """
    email = payload.get("sub") or payload.get("email")

    # The admin flag may be a top-level claim or nested under "user".
    is_admin = payload.get("is_admin", False)
    if not is_admin:
        nested_user = payload.get("user", {})
        if isinstance(nested_user, dict):
            is_admin = nested_user.get("is_admin", False)
        else:
            is_admin = False

    token_use = payload.get("token_use")
    if token_use != "session":  # nosec B105 - Not a password; token_use is a JWT claim type
        # API token or legacy: use embedded teams from JWT
        # First-Party
        from mcpgateway.auth import normalize_token_teams  # pylint: disable=import-outside-toplevel

        final_teams = normalize_token_teams(payload)
    else:
        # Session token: resolve teams from DB/cache via single policy point
        # First-Party
        from mcpgateway.auth import resolve_session_teams  # pylint: disable=import-outside-toplevel

        final_teams = await resolve_session_teams(payload, email, {"is_admin": is_admin})

    normalized: dict[str, Any] = {
        "email": email,
        "teams": final_teams,
        "is_admin": is_admin,
        "is_authenticated": True,
        "token_use": token_use,
    }

    # Extract scoped permissions from JWT for per-method enforcement
    scopes = payload.get("scopes") or {}
    if isinstance(scopes, dict):
        granted = scopes.get("permissions") or []
    else:
        granted = []
    if granted:
        normalized["scoped_permissions"] = granted
    return normalized
@mcp_app.list_tools()
async def list_tools() -> List[types.Tool]:
    """
    Lists all tools available to the MCP Server.

    Supports two modes based on gateway's gateway_mode:
    - 'cache': Returns tools from database (default behavior)
    - 'direct_proxy': Proxies the request directly to the remote MCP server

    Returns:
        A list of Tool objects containing metadata such as name, description, and input schema.
        Logs and returns an empty list on failure, when the server is not found, or when
        access to a direct_proxy gateway is denied.

    Raises:
        PermissionError: If the caller lacks ``tools.read`` permission.

    Examples:
        >>> # Test list_tools function signature
        >>> import inspect
        >>> sig = inspect.signature(list_tools)
        >>> list(sig.parameters.keys())
        []
        >>> sig.return_annotation
        typing.List[mcp.types.Tool]
    """
    server_id, request_headers, user_context = await _get_request_context_or_default()

    # Token scope cap: deny early if scoped permissions exclude tools.read
    if _should_enforce_streamable_rbac(user_context):
        if not _check_scoped_permission(user_context, "tools.read"):
            raise PermissionError(_ACCESS_DENIED_MSG)

    # Extract filtering parameters from user context
    user_email = user_context.get("email") if user_context else None
    # Use None as default to distinguish "no teams specified" from "empty teams array"
    token_teams = user_context.get("teams") if user_context else None
    is_admin = user_context.get("is_admin", False) if user_context else False

    # Admin bypass - only when token has NO team restrictions (token_teams is None)
    # If token has explicit team scope (even empty [] for public-only), respect it
    if is_admin and token_teams is None:
        user_email = None
        # token_teams stays None (unrestricted)
    elif token_teams is None:
        token_teams = []  # Non-admin without teams = public-only (secure default)

    # Enforce per-server OAuth requirement in permissive mode (defense-in-depth).
    # When mcp_require_auth=True, the middleware already guarantees authentication.
    # Note: OAuthEnforcementUnavailableError is intentionally uncaught here —
    # the middleware (streamable_http_auth) catches it and returns 503. If the
    # middleware is somehow bypassed, an uncaught 500 is acceptable and will be
    # logged by the ASGI server.
    if not settings.mcp_require_auth:
        await _check_server_oauth_enforcement(server_id, user_context)

    def _as_mcp_tools(records: Any) -> List[types.Tool]:
        """Convert tool records from the tool service into MCP ``Tool`` objects.

        Args:
            records: Iterable of tool records exposing name, description,
                input_schema, output_schema and annotations attributes.

        Returns:
            List of MCP Tool objects, one per record.
        """
        return [
            types.Tool(
                name=tool.name,
                title=_safe_str_attr(tool, "title"),
                description=tool.description or "",
                inputSchema=tool.input_schema,
                outputSchema=tool.output_schema,
                annotations=tool.annotations,
            )
            for tool in records
        ]

    if server_id:
        try:
            async with get_db() as db:
                # Third-Party
                from sqlalchemy import select  # pylint: disable=import-outside-toplevel

                # Check for X-Context-Forge-Gateway-Id header first - if present, try direct proxy mode
                gateway_id = extract_gateway_id_from_headers(request_headers)

                # If X-Context-Forge-Gateway-Id is provided, check if that gateway is in direct_proxy mode
                if gateway_id:
                    # First-Party
                    from mcpgateway.db import Gateway as DbGateway  # pylint: disable=import-outside-toplevel

                    gateway = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none()
                    if gateway and getattr(gateway, "gateway_mode", "cache") == "direct_proxy" and settings.mcpgateway_direct_proxy_enabled:
                        # SECURITY: Check gateway access before allowing direct proxy
                        if not await check_gateway_access(db, gateway, user_email, token_teams):
                            logger.warning("Access denied to gateway %s in direct_proxy mode for user %s", gateway_id, user_email)
                            return []  # Return empty list for unauthorized access

                        # Direct proxy mode: forward request to remote MCP server
                        # Get _meta from request context if available
                        meta = None
                        try:
                            request_ctx = mcp_app.request_context
                            meta = request_ctx.meta
                            logger.info(
                                "[LIST TOOLS] Using direct_proxy mode for server %s, gateway %s (from %s header). Meta Attached: %s",
                                server_id,
                                gateway.id,
                                GATEWAY_ID_HEADER,
                                meta is not None,
                            )
                        except (LookupError, AttributeError) as e:
                            logger.debug("No request context available for _meta extraction: %s", e)

                        return await _proxy_list_tools_to_gateway(gateway, request_headers, user_context, meta)
                    if gateway:
                        logger.debug("Gateway %s found but not in direct_proxy mode (mode: %s), using cache mode", gateway_id, getattr(gateway, "gateway_mode", "cache"))
                    else:
                        logger.warning("Gateway %s specified in %s header not found", gateway_id, GATEWAY_ID_HEADER)

                # Check if server exists for cache mode
                # First-Party
                from mcpgateway.db import Server as DbServer  # pylint: disable=import-outside-toplevel

                server = db.execute(select(DbServer).where(DbServer.id == server_id)).scalar_one_or_none()
                if not server:
                    logger.warning("Server %s not found in database", server_id)
                    return []

                # Default cache mode: use database
                tools = await tool_service.list_server_tools(db, server_id, user_email=user_email, token_teams=token_teams, _request_headers=request_headers)
                return _as_mcp_tools(tools)
        except Exception as e:
            # Use logger.exception so the traceback is kept, consistent with the
            # global listing path below and the other list handlers.
            logger.exception("Error listing tools:%s", e)
            return []
    else:
        try:
            async with get_db() as db:
                tools, _ = await tool_service.list_tools(db, include_inactive=False, limit=0, user_email=user_email, token_teams=token_teams, _request_headers=request_headers)
                return _as_mcp_tools(tools)
        except Exception as e:
            logger.exception("Error listing tools:%s", e)
            return []
@mcp_app.list_prompts()
async def list_prompts() -> List[types.Prompt]:
    """
    Lists the prompts visible to the caller.

    When the request is scoped to a server, only that server's prompts are
    listed; otherwise all active prompts are listed. Results are filtered by
    the caller's email and team scope.

    Returns:
        A list of Prompt objects containing metadata such as name, description, and arguments.
        Logs and returns an empty list on failure.

    Raises:
        PermissionError: If the user context indicates insufficient permissions (e.g., missing "prompts.read" scope).

    Examples:
        >>> import inspect
        >>> sig = inspect.signature(list_prompts)
        >>> list(sig.parameters.keys())
        []
        >>> sig.return_annotation
        typing.List[mcp.types.Prompt]
    """
    server_id, _, user_context = await _get_request_context_or_default()

    # Token scope cap: deny early if scoped permissions exclude prompts.read
    if _should_enforce_streamable_rbac(user_context) and not _check_scoped_permission(user_context, "prompts.read"):
        raise PermissionError(_ACCESS_DENIED_MSG)

    # Filtering parameters; a missing/empty user context behaves like an empty dict.
    claims = user_context or {}
    user_email = claims.get("email")
    # None means "no teams specified", which is distinct from an empty teams array.
    token_teams = claims.get("teams")
    if token_teams is None:
        if claims.get("is_admin", False):
            # Admin bypass applies only when the token has NO team restrictions;
            # token_teams stays None (unrestricted).
            user_email = None
        else:
            token_teams = []  # Non-admin without teams = public-only (secure default)

    # Enforce per-server OAuth requirement in permissive mode (defense-in-depth).
    # When mcp_require_auth=True, the middleware already guarantees authentication.
    # Note: OAuthEnforcementUnavailableError is intentionally uncaught here —
    # the middleware (streamable_http_auth) catches it and returns 503. If the
    # middleware is somehow bypassed, an uncaught 500 is acceptable and will be
    # logged by the ASGI server.
    if not settings.mcp_require_auth:
        await _check_server_oauth_enforcement(server_id, user_context)

    if not server_id:
        # Global listing: no server scope on this request.
        try:
            async with get_db() as db:
                all_prompts, _ = await prompt_service.list_prompts(db, include_inactive=False, limit=0, user_email=user_email, token_teams=token_teams)
                return [_to_mcp_prompt(item) for item in all_prompts]
        except Exception as e:
            logger.exception("Error listing prompts:%s", e)
            return []

    # Server-scoped listing.
    try:
        async with get_db() as db:
            server_prompts = await prompt_service.list_server_prompts(db, server_id, user_email=user_email, token_teams=token_teams)
            return [_to_mcp_prompt(item) for item in server_prompts]
    except Exception as e:
        logger.exception("Error listing Prompts:%s", e)
        return []
@mcp_app.get_prompt()
async def get_prompt(prompt_id: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
    """
    Retrieves a prompt by ID, optionally substituting arguments.

    Args:
        prompt_id (str): The ID of the prompt to retrieve.
        arguments (Optional[dict[str, str]]): Optional dictionary of arguments to substitute into the prompt.

    Returns:
        GetPromptResult: Object containing the prompt messages and description.
        Note that, despite the annotation, an empty list is returned on failure
        or if no prompt content is found.

    Raises:
        PermissionError: If the user context indicates insufficient permissions (e.g., missing "prompts.read" scope).

    Logs exceptions if any errors occur during retrieval.

    Examples:
        >>> import inspect
        >>> sig = inspect.signature(get_prompt)
        >>> list(sig.parameters.keys())
        ['prompt_id', 'arguments']
        >>> sig.return_annotation.__name__
        'GetPromptResult'
    """
    server_id, _, user_context = await _get_request_context_or_default()

    # Token scope cap: deny early if scoped permissions exclude prompts.read
    if _should_enforce_streamable_rbac(user_context) and not _check_scoped_permission(user_context, "prompts.read"):
        raise PermissionError(_ACCESS_DENIED_MSG)

    # Authorization parameters; a missing/empty user context behaves like an empty dict.
    claims = user_context or {}
    user_email = claims.get("email")
    token_teams = claims.get("teams")
    if token_teams is None:
        if claims.get("is_admin", False):
            # Admin bypass applies only when the token has NO team restrictions;
            # token_teams stays None (unrestricted).
            user_email = None
        else:
            token_teams = []  # Non-admin without teams = public-only (secure default)

    # Enforce per-server OAuth requirement in permissive mode (defense-in-depth).
    # When mcp_require_auth=True, the middleware already guarantees authentication.
    # Note: OAuthEnforcementUnavailableError is intentionally uncaught here —
    # the middleware (streamable_http_auth) catches it and returns 503. If the
    # middleware is somehow bypassed, an uncaught 500 is acceptable and will be
    # logged by the ASGI server.
    if not settings.mcp_require_auth:
        await _check_server_oauth_enforcement(server_id, user_context)

    # Extract _meta from request context if available
    meta_data = None
    try:
        request_ctx = mcp_app.request_context
        if request_ctx and request_ctx.meta is not None:
            meta_data = request_ctx.meta.model_dump()
    except LookupError:
        # request_context might not be active in some edge cases (e.g. tests)
        logger.debug("No active request context found")

    try:
        async with get_db() as db:
            # Service errors are handled inside the session block so they are
            # logged and turned into an empty result before the session exits.
            try:
                rendered = await prompt_service.get_prompt(
                    db=db,
                    prompt_id=prompt_id,
                    arguments=arguments,
                    user=user_email,
                    server_id=server_id,
                    token_teams=token_teams,
                    _meta_data=meta_data,
                )
            except Exception as e:
                logger.exception("Error getting prompt '%s': %s", prompt_id, e)
                return []

            if not (rendered and rendered.messages):
                logger.warning("No content returned by prompt: %s", prompt_id)
                return []

            return types.GetPromptResult(
                messages=[msg.model_dump() for msg in rendered.messages],
                description=rendered.description,
            )
    except Exception as e:
        logger.exception("Error getting prompt '%s': %s", prompt_id, e)
        return []
@mcp_app.list_resources()
async def list_resources() -> List[types.Resource]:
    """
    Lists the resources visible to the caller.

    For a server-scoped request, a gateway named in the gateway-id request
    header is consulted first: if it is in ``direct_proxy`` mode (and direct
    proxying is enabled) the request is forwarded to it after an access check;
    otherwise the server's resources are read from the database. Without a
    server scope, all active resources are listed.

    Returns:
        A list of Resource objects containing metadata such as uri, name, description, and mimeType.
        Logs and returns an empty list on failure or when direct_proxy gateway access is denied.

    Raises:
        PermissionError: If the user context indicates insufficient permissions (e.g., missing "resources.read" scope).

    Examples:
        >>> import inspect
        >>> sig = inspect.signature(list_resources)
        >>> list(sig.parameters.keys())
        []
        >>> sig.return_annotation
        typing.List[mcp.types.Resource]
    """
    server_id, request_headers, user_context = await _get_request_context_or_default()

    # Token scope cap: deny early if scoped permissions exclude resources.read
    if _should_enforce_streamable_rbac(user_context) and not _check_scoped_permission(user_context, "resources.read"):
        raise PermissionError(_ACCESS_DENIED_MSG)

    # Filtering parameters; a missing/empty user context behaves like an empty dict.
    claims = user_context or {}
    user_email = claims.get("email")
    # None means "no teams specified", which is distinct from an empty teams array.
    token_teams = claims.get("teams")
    if token_teams is None:
        if claims.get("is_admin", False):
            # Admin bypass applies only when the token has NO team restrictions;
            # token_teams stays None (unrestricted).
            user_email = None
        else:
            token_teams = []  # Non-admin without teams = public-only (secure default)

    # Enforce per-server OAuth requirement in permissive mode (defense-in-depth).
    # When mcp_require_auth=True, the middleware already guarantees authentication.
    # Note: OAuthEnforcementUnavailableError is intentionally uncaught here —
    # the middleware (streamable_http_auth) catches it and returns 503. If the
    # middleware is somehow bypassed, an uncaught 500 is acceptable and will be
    # logged by the ASGI server.
    if not settings.mcp_require_auth:
        await _check_server_oauth_enforcement(server_id, user_context)

    def _as_mcp_resources(records: Any) -> List[types.Resource]:
        """Convert resource records from the resource service into MCP ``Resource`` objects.

        Args:
            records: Iterable of resource records exposing uri, name, description and mime_type attributes.

        Returns:
            List of MCP Resource objects, one per record.
        """
        return [
            types.Resource(uri=record.uri, name=record.name, title=_safe_str_attr(record, "title"), description=record.description, mimeType=record.mime_type)
            for record in records
        ]

    if not server_id:
        # Global listing: no server scope on this request.
        try:
            async with get_db() as db:
                all_resources, _ = await resource_service.list_resources(db, include_inactive=False, limit=0, user_email=user_email, token_teams=token_teams)
                return _as_mcp_resources(all_resources)
        except Exception as e:
            logger.exception("Error listing resources:%s", e)
            return []

    try:
        async with get_db() as db:
            # Check for X-Context-Forge-Gateway-Id header first for direct proxy mode
            gateway_id = extract_gateway_id_from_headers(request_headers)

            # If X-Context-Forge-Gateway-Id is provided, check if that gateway is in direct_proxy mode
            if gateway_id:
                # Third-Party
                from sqlalchemy import select  # pylint: disable=import-outside-toplevel

                # First-Party
                from mcpgateway.db import Gateway as DbGateway  # pylint: disable=import-outside-toplevel

                gateway = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none()
                if gateway and gateway.gateway_mode == "direct_proxy" and settings.mcpgateway_direct_proxy_enabled:
                    # SECURITY: Check gateway access before allowing direct proxy
                    access_granted = await check_gateway_access(db, gateway, user_email, token_teams)
                    if not access_granted:
                        logger.warning("Access denied to gateway %s in direct_proxy mode for user %s", gateway_id, user_email)
                        return []  # Return empty list for unauthorized access

                    # Direct proxy mode: forward request to remote MCP server,
                    # attaching _meta from the request context if available.
                    meta = None
                    try:
                        meta = mcp_app.request_context.meta
                        logger.info(
                            "[LIST RESOURCES] Using direct_proxy mode for server %s, gateway %s (from %s header). Meta Attached: %s",
                            server_id,
                            gateway.id,
                            GATEWAY_ID_HEADER,
                            meta is not None,
                        )
                    except (LookupError, AttributeError) as e:
                        logger.debug("No request context available for _meta extraction: %s", e)

                    return await _proxy_list_resources_to_gateway(gateway, request_headers, user_context, meta)

                if not gateway:
                    logger.warning("Gateway %s specified in %s header not found", gateway_id, GATEWAY_ID_HEADER)
                else:
                    logger.debug("Gateway %s found but not in direct_proxy mode (mode: %s), using cache mode", gateway_id, gateway.gateway_mode)

            # Default cache mode: use database
            server_resources = await resource_service.list_server_resources(db, server_id, user_email=user_email, token_teams=token_teams)
            return _as_mcp_resources(server_resources)
    except Exception as e:
        logger.exception("Error listing Resources:%s", e)
        return []
@mcp_app.read_resource()
async def read_resource(resource_uri: str) -> Union[str, bytes]:
    """
    Reads the content of a resource specified by its URI.

    A gateway named in the gateway-id request header is consulted first: if it
    is in ``direct_proxy`` mode (and direct proxying is enabled) the read is
    forwarded to it after an access check and the first returned content item
    is used. Otherwise the resource is read through the resource service.

    Args:
        resource_uri (str): The URI of the resource to read.

    Returns:
        Union[str, bytes]: The content of the resource as text or binary data.
        Returns empty string on failure, on denied gateway access, or if no content is found.

    Raises:
        PermissionError: If the user does not have the required permissions to read resources.

    Logs exceptions if any errors occur during reading.

    Examples:
        >>> import inspect
        >>> sig = inspect.signature(read_resource)
        >>> list(sig.parameters.keys())
        ['resource_uri']
        >>> sig.return_annotation
        typing.Union[str, bytes]
    """
    server_id, request_headers, user_context = await _get_request_context_or_default()

    # Token scope cap: deny early if scoped permissions exclude resources.read
    if _should_enforce_streamable_rbac(user_context) and not _check_scoped_permission(user_context, "resources.read"):
        raise PermissionError(_ACCESS_DENIED_MSG)

    # Authorization parameters; a missing/empty user context behaves like an empty dict.
    claims = user_context or {}
    user_email = claims.get("email")
    token_teams = claims.get("teams")
    if token_teams is None:
        if claims.get("is_admin", False):
            # Admin bypass applies only when the token has NO team restrictions;
            # token_teams stays None (unrestricted).
            user_email = None
        else:
            token_teams = []  # Non-admin without teams = public-only (secure default)

    # Enforce per-server OAuth requirement in permissive mode (defense-in-depth).
    # When mcp_require_auth=True, the middleware already guarantees authentication.
    # Note: OAuthEnforcementUnavailableError is intentionally uncaught here —
    # the middleware (streamable_http_auth) catches it and returns 503. If the
    # middleware is somehow bypassed, an uncaught 500 is acceptable and will be
    # logged by the ASGI server.
    if not settings.mcp_require_auth:
        await _check_server_oauth_enforcement(server_id, user_context)

    # Extract _meta from request context if available (used by the cache-mode read)
    meta_data = None
    try:
        active_ctx = mcp_app.request_context
        if active_ctx and active_ctx.meta is not None:
            meta_data = active_ctx.meta.model_dump()
    except LookupError:
        # request_context might not be active in some edge cases (e.g. tests)
        logger.debug("No active request context found")

    try:
        async with get_db() as db:
            # Check for X-Context-Forge-Gateway-Id header first for direct proxy mode
            gateway_id = extract_gateway_id_from_headers(request_headers)

            # If X-Context-Forge-Gateway-Id is provided, check if that gateway is in direct_proxy mode
            if gateway_id:
                # Third-Party
                from sqlalchemy import select  # pylint: disable=import-outside-toplevel

                # First-Party
                from mcpgateway.db import Gateway as DbGateway  # pylint: disable=import-outside-toplevel

                gateway = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none()
                if gateway and gateway.gateway_mode == "direct_proxy" and settings.mcpgateway_direct_proxy_enabled:
                    # SECURITY: Check gateway access before allowing direct proxy
                    access_granted = await check_gateway_access(db, gateway, user_email, token_teams)
                    if not access_granted:
                        logger.warning("Access denied to gateway %s in direct_proxy mode for user %s", gateway_id, user_email)
                        return ""

                    # Direct proxy mode: forward request to remote MCP server,
                    # attaching the raw _meta from the request context if available.
                    meta = None
                    try:
                        meta = mcp_app.request_context.meta
                        logger.info(
                            "Using direct_proxy mode for resources/read %s, server %s, gateway %s (from %s header), forwarding _meta: %s",
                            resource_uri,
                            server_id,
                            gateway.id,
                            GATEWAY_ID_HEADER,
                            meta,
                        )
                    except (LookupError, AttributeError) as e:
                        logger.debug("No request context available for _meta extraction: %s", e)

                    contents = await _proxy_read_resource_to_gateway(gateway, str(resource_uri), user_context, meta)
                    if contents:
                        # Only the first content item is returned; text is preferred over blob.
                        first_item = contents[0]
                        for attr_name in ("text", "blob"):
                            if hasattr(first_item, attr_name):
                                return getattr(first_item, attr_name)
                    return ""

                if not gateway:
                    logger.warning("Gateway %s specified in %s header not found", gateway_id, GATEWAY_ID_HEADER)
                else:
                    logger.debug("Gateway %s found but not in direct_proxy mode (mode: %s), using cache mode", gateway_id, gateway.gateway_mode)

            # Default cache mode: use database
            try:
                result = await resource_service.read_resource(
                    db=db,
                    resource_uri=str(resource_uri),
                    user=user_email,
                    server_id=server_id,
                    token_teams=token_teams,
                    meta_data=meta_data,
                )
            except Exception as e:
                logger.exception("Error reading resource '%s': %s", resource_uri, e)
                return ""

            if result:
                # Binary content takes precedence over text content.
                if result.blob:
                    return result.blob
                if result.text:
                    return result.text

            # No content found
            logger.warning("No content returned by resource: %s", resource_uri)
            return ""
    except Exception as e:
        logger.exception("Error reading resource '%s': %s", resource_uri, e)
        return ""
@mcp_app.list_resource_templates()
async def list_resource_templates() -> List[Dict[str, Any]]:
    """
    Lists the resource templates visible to the caller.

    Returns:
        List[Dict[str, Any]]: One dict per resource template, produced by
        ``model_dump(by_alias=True)`` on each template returned by the resource service.
        Logs and returns an empty list on failure.

    Raises:
        PermissionError: If the caller lacks ``resources.read`` permission.

    Examples:
        >>> import inspect
        >>> sig = inspect.signature(list_resource_templates)
        >>> list(sig.parameters.keys())
        []
        >>> sig.return_annotation.__origin__.__name__
        'list'
    """
    server_id, _, user_context = await _get_request_context_or_default()

    # Token scope cap: deny early if scoped permissions exclude resources.read
    if _should_enforce_streamable_rbac(user_context) and not _check_scoped_permission(user_context, "resources.read"):
        raise PermissionError(_ACCESS_DENIED_MSG)

    # Filtering parameters; a missing/empty user context behaves like an empty dict.
    claims = user_context or {}
    user_email = claims.get("email")
    token_teams = claims.get("teams")
    if token_teams is None:
        if claims.get("is_admin", False):
            # Admin bypass applies only when the token has NO team restrictions;
            # token_teams stays None (unrestricted).
            user_email = None
        else:
            token_teams = []  # Non-admin without teams = public-only (secure default)

    # Enforce per-server OAuth requirement in permissive mode (defense-in-depth).
    # When mcp_require_auth=True, the middleware already guarantees authentication.
    # Note: OAuthEnforcementUnavailableError is intentionally uncaught here —
    # the middleware (streamable_http_auth) catches it and returns 503. If the
    # middleware is somehow bypassed, an uncaught 500 is acceptable and will be
    # logged by the ASGI server.
    if not settings.mcp_require_auth:
        await _check_server_oauth_enforcement(server_id, user_context)

    try:
        async with get_db() as db:
            # Service and serialization errors are handled inside the session
            # block so they are logged and turned into an empty result before
            # the session exits.
            try:
                templates = await resource_service.list_resource_templates(
                    db,
                    user_email=user_email,
                    token_teams=token_teams,
                    server_id=server_id,
                )
                serialized = [entry.model_dump(by_alias=True) for entry in templates]
                return serialized
            except Exception as e:
                logger.exception("Error listing resource templates: %s", e)
                return []
    except Exception as e:
        logger.exception("Error listing resource templates: %s", e)
        return []
@mcp_app.set_logging_level()
async def set_logging_level(level: types.LoggingLevel) -> types.EmptyResult:
    """
    Sets the logging level for the MCP Server.

    The MCP level name is mapped onto the gateway's ``LogLevel`` enum
    (``notice`` maps to INFO; ``alert`` and ``emergency`` map to CRITICAL;
    unrecognized names fall back to INFO) and applied via the logging service.

    Args:
        level (types.LoggingLevel): The desired logging level (debug, info, notice, warning, error, critical, alert, emergency).

    Returns:
        types.EmptyResult: An empty result. It is also returned when applying
        the level fails with a non-permission error, which is logged.

    Examples:
        >>> import inspect
        >>> sig = inspect.signature(set_logging_level)
        >>> list(sig.parameters.keys())
        ['level']

    Raises:
        PermissionError: If the user does not have permission to set the logging level.
    """
    server_id, _, user_context = await _get_request_context_or_default()

    # Enforce per-server OAuth requirement in permissive mode (defense-in-depth).
    # When mcp_require_auth=True, the middleware already guarantees authentication.
    # Note: OAuthEnforcementUnavailableError is intentionally uncaught here —
    # the middleware (streamable_http_auth) catches it and returns 503. If the
    # middleware is somehow bypassed, an uncaught 500 is acceptable and will be
    # logged by the ASGI server.
    if not settings.mcp_require_auth:
        await _check_server_oauth_enforcement(server_id, user_context)

    if _should_enforce_streamable_rbac(user_context):
        required_permission = "admin.system_config"

        # Layer 1: Token scope cap
        if not _check_scoped_permission(user_context, required_permission):
            raise PermissionError(_ACCESS_DENIED_MSG)

        # Layer 2: RBAC check
        rbac_allowed = await _check_streamable_permission(
            user_context=user_context,
            permission=required_permission,
            check_any_team=_check_any_team_for_server_scoped_rbac(user_context, server_id),
        )
        if not rbac_allowed:
            raise PermissionError(_ACCESS_DENIED_MSG)

    try:
        # Convert MCP logging level to our LogLevel enum
        mcp_to_gateway_level = {
            "debug": LogLevel.DEBUG,
            "info": LogLevel.INFO,
            "notice": LogLevel.INFO,
            "warning": LogLevel.WARNING,
            "error": LogLevel.ERROR,
            "critical": LogLevel.CRITICAL,
            "alert": LogLevel.CRITICAL,
            "emergency": LogLevel.CRITICAL,
        }
        target_level = mcp_to_gateway_level.get(level.lower(), LogLevel.INFO)
        await logging_service.set_level(target_level)
        return types.EmptyResult()
    except PermissionError:
        # Permission failures must reach the caller rather than be swallowed below.
        raise
    except Exception as e:
        logger.exception("Error setting logging level: %s", e)
        return types.EmptyResult()
def _coerce_to_mcp_completion(result: Any) -> types.Completion:
    """Normalize a completion-service result into an MCP ``types.Completion``.

    Args:
        result: Value returned by ``completion_service.handle_completion``.
            Handled shapes are: a dict (optionally wrapping the payload
            under a ``"completion"`` key), an object exposing a
            ``completion`` attribute (dict, nested wrapper, ``types.Completion``
            or any model with ``model_dump``), or a ``types.Completion``.

    Returns:
        types.Completion: The normalized completion. An empty completion is
        returned when ``result`` matches none of the handled shapes.
    """
    if isinstance(result, dict):
        return types.Completion(**result.get("completion", result))

    if hasattr(result, "completion"):
        completion_obj = result.completion

        # Payload is a plain dict.
        if isinstance(completion_obj, dict):
            return types.Completion(**completion_obj)

        # Payload is itself a wrapper carrying another ``completion``.
        if hasattr(completion_obj, "completion"):
            inner_completion = completion_obj.completion
            if hasattr(inner_completion, "model_dump"):
                inner_completion = inner_completion.model_dump()
            return types.Completion(**inner_completion)

        # Payload is already the MCP model.
        if isinstance(completion_obj, types.Completion):
            return completion_obj

        # Payload is some other model exposing ``model_dump``.
        if hasattr(completion_obj, "model_dump"):
            return types.Completion(**completion_obj.model_dump())

    if isinstance(result, types.Completion):
        return result

    # Unknown shape: return an empty completion.
    return types.Completion(values=[], total=0, hasMore=False)


@mcp_app.completion()
async def complete(
    ref: Union[types.PromptReference, types.ResourceTemplateReference],
    argument: types.CompletionArgument,
    context: Optional[types.CompletionContext] = None,
) -> types.Completion:
    """
    Provides argument completion suggestions for prompts or resources.

    Args:
        ref: A reference to a prompt or a resource template. Can be either
            `types.PromptReference` or `types.ResourceTemplateReference`.
        argument: The argument being completed (its name and the partial
            value typed so far).
        context: Optional contextual information for the completion request.

    Returns:
        types.Completion: The normalized completion (values, total, hasMore).
        If the completion service fails, the error is logged and an empty
        completion is returned instead of propagating the exception.

    Raises:
        PermissionError: If RBAC is enforced for the caller and the token's
            scoped permissions exclude ``tools.read``.
    """
    # Derive caller visibility scope from the current request context.
    server_id, _, user_context = await _get_request_context_or_default()

    # Token scope cap: deny early if scoped permissions exclude tools.read
    if _should_enforce_streamable_rbac(user_context):
        if not _check_scoped_permission(user_context, "tools.read"):
            raise PermissionError(_ACCESS_DENIED_MSG)

    # Enforce per-server OAuth requirement in permissive mode (defense-in-depth).
    # When mcp_require_auth=True, the middleware already guarantees authentication.
    # Note: OAuthEnforcementUnavailableError is intentionally uncaught here —
    # the middleware (streamable_http_auth) catches it and returns 503. If the
    # middleware is somehow bypassed, an uncaught 500 is acceptable and will be
    # logged by the ASGI server.
    if not settings.mcp_require_auth:
        await _check_server_oauth_enforcement(server_id, user_context)

    try:
        user_email = user_context.get("email") if user_context else None
        token_teams = user_context.get("teams") if user_context else None
        is_admin = user_context.get("is_admin", False) if user_context else False

        # Admin bypass only for explicit unrestricted context; otherwise secure default.
        if is_admin and token_teams is None:
            user_email = None
        elif token_teams is None:
            token_teams = []  # Non-admin without explicit teams -> public-only

        async with get_db() as db:
            params = {
                "ref": ref.model_dump() if hasattr(ref, "model_dump") else ref,
                "argument": argument.model_dump() if hasattr(argument, "model_dump") else argument,
                "context": context.model_dump() if hasattr(context, "model_dump") else context,
            }

            result = await completion_service.handle_completion(
                db,
                params,
                user_email=user_email,
                token_teams=token_teams,
            )

            # Normalize the service result for MCP.
            return _coerce_to_mcp_completion(result)

    except Exception as e:
        # Best-effort: completion failures must not break the session.
        logger.exception("Error handling completion: %s", e)
        return types.Completion(values=[], total=0, hasMore=False)
2526class SessionManagerWrapper:
2527 """
2528 Wrapper class for managing the lifecycle of a StreamableHTTPSessionManager instance.
2529 Provides start, stop, and request handling methods.
2531 Examples:
2532 >>> # Test SessionManagerWrapper initialization
2533 >>> wrapper = SessionManagerWrapper()
2534 >>> wrapper
2535 <mcpgateway.transports.streamablehttp_transport.SessionManagerWrapper object at ...>
2536 >>> hasattr(wrapper, 'session_manager')
2537 True
2538 >>> hasattr(wrapper, 'stack')
2539 True
2540 >>> isinstance(wrapper.stack, AsyncExitStack)
2541 True
2542 """
2544 def __init__(self) -> None:
2545 """
2546 Initializes the session manager and the exit stack used for managing its lifecycle.
2548 Examples:
2549 >>> # Test initialization
2550 >>> wrapper = SessionManagerWrapper()
2551 >>> wrapper.session_manager is not None
2552 True
2553 >>> wrapper.stack is not None
2554 True
2555 """
2557 if settings.use_stateful_sessions:
2558 if settings.experimental_rust_mcp_runtime_enabled and settings.experimental_rust_mcp_session_auth_reuse_enabled and settings.experimental_rust_mcp_event_store_enabled:
2559 event_store = RustEventStore(
2560 max_events_per_stream=settings.streamable_http_max_events_per_stream,
2561 ttl=settings.streamable_http_event_ttl,
2562 )
2563 logger.debug("Using RustEventStore for stateful sessions")
2564 # Use Redis event store for single-worker stateful deployments
2565 elif settings.cache_type == "redis" and settings.redis_url:
2566 event_store = RedisEventStore(max_events_per_stream=settings.streamable_http_max_events_per_stream, ttl=settings.streamable_http_event_ttl)
2567 logger.debug("Using RedisEventStore for stateful sessions (single-worker)")
2568 else:
2569 # Fall back to in-memory for single-worker or when Redis not available
2570 event_store = InMemoryEventStore()
2571 logger.warning("Using InMemoryEventStore - only works with single worker!")
2572 stateless = False
2573 else:
2574 event_store = None
2575 stateless = True
2577 self.session_manager = StreamableHTTPSessionManager(
2578 app=mcp_app,
2579 event_store=event_store,
2580 json_response=settings.json_response_enabled,
2581 stateless=stateless,
2582 )
2583 self.stack = AsyncExitStack()
2585 async def initialize(self) -> None:
2586 """
2587 Starts the Streamable HTTP session manager context.
2589 Examples:
2590 >>> # Test initialize method exists
2591 >>> wrapper = SessionManagerWrapper()
2592 >>> hasattr(wrapper, 'initialize')
2593 True
2594 >>> callable(wrapper.initialize)
2595 True
2596 """
2597 logger.debug("Initializing Streamable HTTP service")
2598 await self.stack.enter_async_context(self.session_manager.run())
2600 async def shutdown(self) -> None:
2601 """
2602 Gracefully shuts down the Streamable HTTP session manager.
2604 Examples:
2605 >>> # Test shutdown method exists
2606 >>> wrapper = SessionManagerWrapper()
2607 >>> hasattr(wrapper, 'shutdown')
2608 True
2609 >>> callable(wrapper.shutdown)
2610 True
2611 """
2612 logger.debug("Stopping Streamable HTTP Session Manager...")
2613 await self.stack.aclose()
2615 @staticmethod
2616 async def _validate_server_id(match: "re.Match[str] | None", path: str, scope: Scope, receive: Receive, send: Send) -> str | None:
2617 """Validate and resolve the server_id from the request path.
2619 Args:
2620 match: Result of ``_SERVER_ID_RE.search(path)``.
2621 path: Original request path (``scope["modified_path"]``).
2622 scope: ASGI scope dict.
2623 receive: ASGI receive callable.
2624 send: ASGI send callable.
2626 Returns:
2627 The validated server_id string, ``None`` when the path is
2628 not server-scoped (legitimate global ``/mcp``), or the
2629 sentinel ``_REJECT`` when an error response has already been
2630 sent and the caller should return immediately.
2631 """
2632 if match:
2633 server_id = match.group("server_id")
2634 # SECURITY: Validate that the server_id exists in the database
2635 # to prevent unauthorized access via invalid server IDs.
2636 # Uses the shared BaseService.entity_exists() for a lightweight
2637 # EXISTS check — no row data is loaded.
2638 try:
2639 # First-Party
2640 from mcpgateway.services.server_service import server_service as _server_svc # pylint: disable=import-outside-toplevel,no-name-in-module
2642 async with get_db() as db:
2643 if not await _server_svc.entity_exists(db, server_id):
2644 logger.warning("Invalid server ID in MCP request path: %s", server_id)
2645 response = ORJSONResponse({"detail": "Server not found"}, status_code=404)
2646 await response(scope, receive, send)
2647 return _REJECT
2648 except Exception as e:
2649 logger.error("Failed to validate server ID %s: %s", server_id, e)
2650 response = ORJSONResponse({"detail": "Service unavailable — unable to verify server"}, status_code=503)
2651 await response(scope, receive, send)
2652 return _REJECT
2653 return server_id
2655 # SECURITY (defense-in-depth): If the path looks server-scoped but
2656 # the primary regex didn't capture a server_id (e.g. empty segment
2657 # /servers//mcp, or an encoding edge case), reject immediately
2658 # rather than falling through to unscoped global behaviour (#3891).
2659 if _SERVER_SCOPED_PATH_RE.search(path):
2660 logger.warning("Server-scoped MCP path with unparseable server ID rejected: %s", path)
2661 response = ORJSONResponse({"detail": "Invalid server identifier"}, status_code=404)
2662 await response(scope, receive, send)
2663 return _REJECT
2665 return None # Legitimate unscoped /mcp path
2667 async def handle_streamable_http(self, scope: Scope, receive: Receive, send: Send) -> None:
2668 """
2669 Forwards an incoming ASGI request to the streamable HTTP session manager.
2671 Args:
2672 scope (Scope): ASGI scope object containing connection information.
2673 receive (Receive): ASGI receive callable.
2674 send (Send): ASGI send callable.
2676 Raises:
2677 Exception: Any exception raised during request handling is logged.
2679 Logs any exceptions that occur during request handling.
2681 Examples:
2682 >>> # Test handle_streamable_http method exists
2683 >>> wrapper = SessionManagerWrapper()
2684 >>> hasattr(wrapper, 'handle_streamable_http')
2685 True
2686 >>> callable(wrapper.handle_streamable_http)
2687 True
2689 >>> # Test method signature
2690 >>> import inspect
2691 >>> sig = inspect.signature(wrapper.handle_streamable_http)
2692 >>> list(sig.parameters.keys())
2693 ['scope', 'receive', 'send']
2694 """
2696 path = scope["modified_path"]
2697 # Uses precompiled regex for server ID extraction
2698 match = _SERVER_ID_RE.search(path)
2700 # Extract request headers from scope (ASGI provides bytes; normalize to lowercase for lookup).
2701 raw_headers = scope.get("headers") or []
2702 headers: dict[str, str] = {}
2703 for item in raw_headers:
2704 if not isinstance(item, (tuple, list)) or len(item) != 2:
2705 continue
2706 k, v = item
2707 if not isinstance(k, (bytes, bytearray)) or not isinstance(v, (bytes, bytearray)):
2708 continue
2709 # latin-1 is a byte-preserving decode; safe for arbitrary header bytes.
2710 headers[k.decode("latin-1").lower()] = v.decode("latin-1")
2712 # Log session info for debugging stateful sessions
2713 mcp_session_id = headers.get("x-mcp-session-id") or headers.get("mcp-session-id") or "not-provided"
2714 if mcp_session_id != "not-provided":
2715 set_trace_session_id(mcp_session_id)
2716 method = scope.get("method", "UNKNOWN")
2717 query_string = scope.get("query_string", b"").decode("utf-8")
2718 logger.debug("[STATEFUL] Streamable HTTP %s %s | MCP-Session-Id: %s | Query: %s | Stateful: %s", method, path, mcp_session_id, query_string, settings.use_stateful_sessions)
2720 # Note: mcp-session-id from client is used for gateway-internal session affinity
2721 # routing (stored in request_headers_var), but is NOT renamed or forwarded to
2722 # upstream servers - it's a gateway-side concept, not an end-to-end semantic header
2724 # Multi-worker session affinity: check if we should forward to another worker
2725 # This must happen BEFORE the SDK's session manager handles the request
2726 # Only trust x-forwarded-internally from loopback to prevent external spoofing
2727 _client = scope.get("client")
2728 _client_host = _client[0] if _client else None
2729 _from_loopback = _client_host in ("127.0.0.1", "::1") if _client_host else False
2730 is_internally_forwarded = _from_loopback and headers.get("x-forwarded-internally") == "true"
2732 if settings.mcpgateway_session_affinity_enabled and mcp_session_id != "not-provided":
2733 try:
2734 # First-Party
2735 from mcpgateway.services.mcp_session_pool import MCPSessionPool # pylint: disable=import-outside-toplevel
2737 if not MCPSessionPool.is_valid_mcp_session_id(mcp_session_id):
2738 logger.debug("Invalid MCP session id on Streamable HTTP request, skipping affinity")
2739 mcp_session_id = "not-provided"
2740 except Exception:
2741 mcp_session_id = "not-provided"
2743 # Log session manager ID for debugging
2744 logger.debug("[SESSION_MGR_DEBUG] Manager ID: %s", id(self.session_manager))
2746 # Enforce server access parity for server-scoped Streamable HTTP MCP routes.
2747 # This mirrors /servers/{id}/sse and /servers/{id}/message guards.
2748 user_context = user_context_var.get()
2749 if match and _should_enforce_streamable_rbac(user_context):
2750 _server_id = match.group("server_id")
2751 has_server_access = await _check_streamable_permission(
2752 user_context=user_context,
2753 permission="servers.use",
2754 check_any_team=_check_any_team_for_server_scoped_rbac(user_context, _server_id),
2755 )
2756 if not has_server_access:
2757 response = ORJSONResponse(
2758 {"detail": _ACCESS_DENIED_MSG},
2759 status_code=HTTP_403_FORBIDDEN,
2760 )
2761 await response(scope, receive, send)
2762 return
2764 # SECURITY: Validate server existence early — before affinity routing
2765 # can shortcut to /rpc, which checks token scoping but not server
2766 # existence. Without this, nonexistent server IDs that reach the
2767 # affinity branches would bypass the 404 and get empty-scoped results.
2768 validated = await self._validate_server_id(match, path, scope, receive, send)
2769 if validated is _REJECT:
2770 return
2772 if is_internally_forwarded:
2773 logger.debug("[HTTP_AFFINITY_FORWARDED] Received forwarded request | Method: %s | Session: %s", method, mcp_session_id)
2775 # Only route POST requests with JSON-RPC body to /rpc
2776 # DELETE and other methods should return success (session cleanup is local)
2777 if method != "POST":
2778 logger.debug("[HTTP_AFFINITY_FORWARDED] Non-POST method, returning 200 OK")
2779 await send({"type": "http.response.start", "status": 200, "headers": [(b"content-type", b"application/json")]})
2780 await send({"type": "http.response.body", "body": b'{"jsonrpc":"2.0","result":{}}'})
2781 return
2783 # For POST requests, bypass SDK session manager and use /rpc directly
2784 # This avoids SDK's session cleanup issues while maintaining stateful upstream connections
2785 try:
2786 # Read request body
2787 body_parts = []
2788 while True:
2789 message = await receive()
2790 if message["type"] == "http.request":
2791 body_parts.append(message.get("body", b""))
2792 if not message.get("more_body", False):
2793 break
2794 elif message["type"] == "http.disconnect":
2795 return
2796 body = b"".join(body_parts)
2798 if not body:
2799 logger.debug("[HTTP_AFFINITY_FORWARDED] Empty body, returning 202 Accepted")
2800 await send({"type": "http.response.start", "status": 202, "headers": []})
2801 await send({"type": "http.response.body", "body": b""})
2802 return
2804 json_body = orjson.loads(body)
2805 rpc_method = json_body.get("method", "")
2806 logger.debug("[HTTP_AFFINITY_FORWARDED] Routing to /rpc | Method: %s", rpc_method)
2808 session_allowed, deny_status, deny_detail = await _validate_streamable_session_access(
2809 mcp_session_id=mcp_session_id,
2810 user_context=user_context,
2811 rpc_method=rpc_method,
2812 )
2813 if not session_allowed:
2814 response = ORJSONResponse({"detail": deny_detail}, status_code=deny_status)
2815 await response(scope, receive, send)
2816 return
2818 # Notifications don't need /rpc routing - just acknowledge
2819 if rpc_method.startswith("notifications/"):
2820 logger.debug("[HTTP_AFFINITY_FORWARDED] Notification, returning 202 Accepted")
2821 await send({"type": "http.response.start", "status": 202, "headers": []})
2822 await send({"type": "http.response.body", "body": b""})
2823 return
2825 # Inject server_id from URL path into params for /rpc routing
2826 if match:
2827 server_id = match.group("server_id")
2828 if not isinstance(json_body.get("params"), dict):
2829 json_body["params"] = {}
2830 json_body["params"]["server_id"] = server_id
2831 # Re-serialize body with injected server_id
2832 body = orjson.dumps(json_body)
2833 logger.debug("[HTTP_AFFINITY_FORWARDED] Injected server_id %s into /rpc params", server_id)
2835 async with httpx.AsyncClient(verify=internal_loopback_verify()) as client:
2836 rpc_headers = {
2837 "content-type": "application/json",
2838 "x-mcp-session-id": mcp_session_id, # Pass session for upstream affinity
2839 "x-forwarded-internally": "true", # Prevent infinite forwarding loops
2840 }
2841 # Copy auth header if present
2842 if "authorization" in headers:
2843 rpc_headers["authorization"] = headers["authorization"]
2844 # Forward passthrough headers for upstream MCP servers (see #3640).
2845 # First-Party
2846 from mcpgateway.utils.passthrough_headers import safe_extract_and_filter_for_loopback # pylint: disable=import-outside-toplevel
2848 rpc_headers.update(safe_extract_and_filter_for_loopback(headers))
2850 response = await client.post(
2851 f"{internal_loopback_base_url()}/rpc",
2852 content=body,
2853 headers=rpc_headers,
2854 timeout=30.0,
2855 )
2857 # Return response to client
2858 response_headers = [
2859 (b"content-type", b"application/json"),
2860 (b"content-length", str(len(response.content)).encode()),
2861 ]
2862 if mcp_session_id != "not-provided":
2863 response_headers.append((b"mcp-session-id", mcp_session_id.encode()))
2865 await send(
2866 {
2867 "type": "http.response.start",
2868 "status": response.status_code,
2869 "headers": response_headers,
2870 }
2871 )
2872 await send(
2873 {
2874 "type": "http.response.body",
2875 "body": response.content,
2876 }
2877 )
2878 logger.debug("[HTTP_AFFINITY_FORWARDED] Response sent | Status: %s", response.status_code)
2879 return
2880 except Exception as e:
2881 logger.error("[HTTP_AFFINITY_FORWARDED] Error routing to /rpc: %s", e)
2882 # Fall through to SDK handling as fallback
2884 if settings.mcpgateway_session_affinity_enabled and settings.use_stateful_sessions and mcp_session_id != "not-provided" and not is_internally_forwarded:
2885 try:
2886 # First-Party - lazy import to avoid circular dependencies
2887 # First-Party
2888 from mcpgateway.services.mcp_session_pool import get_mcp_session_pool, WORKER_ID # pylint: disable=import-outside-toplevel
2890 pool = get_mcp_session_pool()
2891 owner = await pool.get_streamable_http_session_owner(mcp_session_id)
2892 logger.debug("[HTTP_AFFINITY_CHECK] Worker %s | Session %s... | Owner from Redis: %s", WORKER_ID, mcp_session_id[:8], owner)
2894 if owner and owner != WORKER_ID:
2895 # Session owned by another worker - forward the entire HTTP request
2896 logger.info("[HTTP_AFFINITY] Worker %s | Session %s... | Owner: %s | Forwarding HTTP request", WORKER_ID, mcp_session_id[:8], owner)
2898 # Read request body
2899 body_parts = []
2900 while True:
2901 message = await receive()
2902 if message["type"] == "http.request":
2903 body_parts.append(message.get("body", b""))
2904 if not message.get("more_body", False):
2905 break
2906 elif message["type"] == "http.disconnect":
2907 return
2908 body = b"".join(body_parts)
2910 # Forward to owner worker
2911 response = await pool.forward_streamable_http_to_owner(
2912 owner_worker_id=owner,
2913 mcp_session_id=mcp_session_id,
2914 method=method,
2915 path=path,
2916 headers=headers,
2917 body=body,
2918 query_string=query_string,
2919 )
2921 if response:
2922 # Send forwarded response back to client
2923 response_headers = [(k.encode(), v.encode()) for k, v in response["headers"].items() if k.lower() not in ("transfer-encoding", "content-encoding", "content-length")]
2924 response_headers.append((b"content-length", str(len(response["body"])).encode()))
2926 await send(
2927 {
2928 "type": "http.response.start",
2929 "status": response["status"],
2930 "headers": response_headers,
2931 }
2932 )
2933 await send(
2934 {
2935 "type": "http.response.body",
2936 "body": response["body"],
2937 }
2938 )
2939 logger.debug("[HTTP_AFFINITY] Worker %s | Session %s... | Forwarded response sent to client", WORKER_ID, mcp_session_id[:8])
2940 return
2942 # Forwarding failed - fall through to local handling
2943 # This may result in "session not found" but it's better than no response
2944 logger.debug("[HTTP_AFFINITY] Worker %s | Session %s... | Forwarding failed, falling back to local", WORKER_ID, mcp_session_id[:8])
2946 elif owner == WORKER_ID and method == "POST":
2947 # We own this session - route POST requests to /rpc to avoid SDK session issues
2948 # The SDK's _server_instances gets cleared between requests, so we can't rely on it
2949 logger.debug("[HTTP_AFFINITY_LOCAL] Worker %s | Session %s... | Owner is us, routing to /rpc", WORKER_ID, mcp_session_id[:8])
2951 # Read request body
2952 body_parts = []
2953 while True:
2954 message = await receive()
2955 if message["type"] == "http.request":
2956 body_parts.append(message.get("body", b""))
2957 if not message.get("more_body", False):
2958 break
2959 elif message["type"] == "http.disconnect":
2960 return
2961 body = b"".join(body_parts)
2963 if not body:
2964 logger.debug("[HTTP_AFFINITY_LOCAL] Empty body, returning 202 Accepted")
2965 await send({"type": "http.response.start", "status": 202, "headers": []})
2966 await send({"type": "http.response.body", "body": b""})
2967 return
2969 # Parse JSON-RPC and route to /rpc
2970 try:
2971 json_body = orjson.loads(body)
2972 rpc_method = json_body.get("method", "")
2973 logger.debug("[HTTP_AFFINITY_LOCAL] Routing to /rpc | Method: %s", rpc_method)
2975 session_allowed, deny_status, deny_detail = await _validate_streamable_session_access(
2976 mcp_session_id=mcp_session_id,
2977 user_context=user_context,
2978 rpc_method=rpc_method,
2979 )
2980 if not session_allowed:
2981 response = ORJSONResponse({"detail": deny_detail}, status_code=deny_status)
2982 await response(scope, receive, send)
2983 return
2985 # Notifications don't need /rpc routing
2986 if rpc_method.startswith("notifications/"):
2987 logger.debug("[HTTP_AFFINITY_LOCAL] Notification, returning 202 Accepted")
2988 await send({"type": "http.response.start", "status": 202, "headers": []})
2989 await send({"type": "http.response.body", "body": b""})
2990 return
2992 # Inject server_id from URL path into params for /rpc routing
2993 if match:
2994 server_id = match.group("server_id")
2995 if not isinstance(json_body.get("params"), dict):
2996 json_body["params"] = {}
2997 json_body["params"]["server_id"] = server_id
2998 # Re-serialize body with injected server_id
2999 body = orjson.dumps(json_body)
3000 logger.debug("[HTTP_AFFINITY_LOCAL] Injected server_id %s into /rpc params", server_id)
3002 async with httpx.AsyncClient(verify=internal_loopback_verify()) as client:
3003 rpc_headers = {
3004 "content-type": "application/json",
3005 "x-mcp-session-id": mcp_session_id,
3006 "x-forwarded-internally": "true",
3007 }
3008 if "authorization" in headers:
3009 rpc_headers["authorization"] = headers["authorization"]
3010 # Forward passthrough headers for upstream MCP servers (see #3640).
3011 # First-Party
3012 from mcpgateway.utils.passthrough_headers import safe_extract_and_filter_for_loopback # pylint: disable=import-outside-toplevel
3014 rpc_headers.update(safe_extract_and_filter_for_loopback(headers))
3016 response = await client.post(
3017 f"{internal_loopback_base_url()}/rpc",
3018 content=body,
3019 headers=rpc_headers,
3020 timeout=30.0,
3021 )
3023 response_headers = [
3024 (b"content-type", b"application/json"),
3025 (b"content-length", str(len(response.content)).encode()),
3026 (b"mcp-session-id", mcp_session_id.encode()),
3027 ]
3029 await send(
3030 {
3031 "type": "http.response.start",
3032 "status": response.status_code,
3033 "headers": response_headers,
3034 }
3035 )
3036 await send(
3037 {
3038 "type": "http.response.body",
3039 "body": response.content,
3040 }
3041 )
3042 logger.debug("[HTTP_AFFINITY_LOCAL] Response sent | Status: %s", response.status_code)
3043 return
3044 except Exception as e:
3045 logger.error("[HTTP_AFFINITY_LOCAL] Error routing to /rpc: %s", e)
3046 # Fall through to SDK handling as fallback
3048 except RuntimeError:
3049 # Pool not initialized - proceed with local handling
3050 pass
3051 except Exception as e:
3052 logger.debug("Session affinity check failed, proceeding locally: %s", e)
3054 # Store headers in context for tool invocations
3055 request_headers_var.set(headers)
3057 server_id_var.set(validated)
3059 # For session affinity: wrap send to capture session ID from response headers
3060 # This allows us to register ownership for new sessions created by the SDK
3061 captured_session_id: Optional[str] = None
3063 async def send_with_capture(message: Dict[str, Any]) -> None:
3064 """Wrap ASGI send to capture session ID from response headers.
3066 Args:
3067 message: ASGI message dict.
3068 """
3069 nonlocal captured_session_id
3070 if message["type"] == "http.response.start" and settings.mcpgateway_session_affinity_enabled:
3071 # Look for mcp-session-id in response headers
3072 response_headers = message.get("headers", [])
3073 for header_name, header_value in response_headers:
3074 if isinstance(header_name, bytes):
3075 header_name = header_name.decode("latin-1")
3076 if isinstance(header_value, bytes):
3077 header_value = header_value.decode("latin-1")
3078 if header_name.lower() == "mcp-session-id":
3079 captured_session_id = header_value
3080 break
3081 await send(message)
3083 # Propagate middleware-resolved context via ASGI scope so that MCP
3084 # handlers can retrieve it even when ContextVars are lost (the SDK's
3085 # task group was created at startup, so spawned handler tasks inherit
3086 # the startup context rather than the per-request context).
3087 scope[_MCPGATEWAY_CONTEXT_KEY] = {
3088 "server_id": server_id_var.get(),
3089 "request_headers": headers,
3090 "user_context": user_context,
3091 }
3093 buffered_request_body = bytearray()
3094 initialize_span_cm: Optional[ContextManager[Any]] = None
3095 initialize_span_stack: Optional[ExitStack] = None
3096 initialize_span_active = False
3098 async def receive_with_initialize_trace() -> Dict[str, Any]:
3099 """Capture initialize requests so the public MCP handshake is traced.
3101 Returns:
3102 The next ASGI receive message, with initialize payloads recorded so
3103 tracing can wrap the SDK-managed handshake path.
3104 """
3105 nonlocal initialize_span_cm, initialize_span_stack, initialize_span_active
3106 message = await receive()
3107 if method == "POST" and not initialize_span_active and message.get("type") == "http.request":
3108 buffered_request_body.extend(message.get("body", b""))
3109 if not message.get("more_body", False):
3110 initialize_span_cm = _maybe_open_initialize_span(
3111 bytes(buffered_request_body),
3112 mcp_session_id=mcp_session_id,
3113 server_id=validated,
3114 )
3115 if initialize_span_cm is not None:
3116 initialize_span_stack = ExitStack()
3117 initialize_span_stack.enter_context(initialize_span_cm)
3118 initialize_span_active = True
3119 return message
3121 span_exit_exc: tuple[Any, Any, Any] = (None, None, None)
3123 try:
3124 await self.session_manager.handle_request(scope, receive_with_initialize_trace, send_with_capture)
3125 logger.debug("[STATEFUL] Streamable HTTP request completed successfully | Session: %s", mcp_session_id)
3127 # Register ownership for the session we just handled
3128 # This captures both existing sessions (mcp_session_id from request)
3129 # and new sessions (captured_session_id from response)
3130 logger.debug(
3131 "[HTTP_AFFINITY_DEBUG] affinity_enabled=%s stateful=%s captured=%s mcp_session_id=%s",
3132 settings.mcpgateway_session_affinity_enabled,
3133 settings.use_stateful_sessions,
3134 captured_session_id,
3135 mcp_session_id,
3136 )
3137 if settings.mcpgateway_session_affinity_enabled and settings.use_stateful_sessions:
3138 session_to_register: Optional[str] = None
3140 # Only server-emitted session IDs (from successful initialize) can
3141 # establish new ownership state for affinity.
3142 if captured_session_id:
3143 session_to_register = captured_session_id
3145 requester_email = user_context.get("email") if isinstance(user_context, dict) else None
3146 if requester_email:
3147 effective_owner = await _claim_streamable_session_owner(captured_session_id, requester_email)
3148 if effective_owner and effective_owner != requester_email and not bool(user_context.get("is_admin", False)):
3149 logger.warning("Session owner mismatch for %s... (requester=%s, owner=%s)", captured_session_id[:8], requester_email, effective_owner)
3150 elif mcp_session_id != "not-provided":
3151 # Existing client-provided IDs may only refresh affinity when they
3152 # are already bound to the caller's principal.
3153 session_allowed, _deny_status, _deny_detail = await _validate_streamable_session_access(
3154 mcp_session_id=mcp_session_id,
3155 user_context=user_context,
3156 rpc_method=None,
3157 )
3158 if session_allowed:
3159 session_to_register = mcp_session_id
3161 logger.debug("[HTTP_AFFINITY_DEBUG] session_to_register=%s", session_to_register)
3162 if session_to_register:
3163 try:
3164 # First-Party - lazy import to avoid circular dependencies
3165 # First-Party
3166 from mcpgateway.services.mcp_session_pool import get_mcp_session_pool, WORKER_ID # pylint: disable=import-outside-toplevel
3168 pool = get_mcp_session_pool()
3169 await pool.register_pool_session_owner(session_to_register)
3170 logger.debug("[HTTP_AFFINITY_SDK] Worker %s | Session %s... | Registered ownership after SDK handling", WORKER_ID, session_to_register[:8])
3171 except Exception as e:
3172 logger.debug("[HTTP_AFFINITY_DEBUG] Exception during registration: %s", e)
3173 logger.warning("Failed to register session ownership: %s", e)
3175 except anyio.ClosedResourceError:
3176 # Expected when client closes one side of the stream (normal lifecycle)
3177 logger.debug("Streamable HTTP connection closed by client (ClosedResourceError)")
3178 except Exception as e:
3179 span_exit_exc = (type(e), e, e.__traceback__)
3180 logger.error("[STATEFUL] Streamable HTTP request failed | Session: %s | Error: %s", mcp_session_id, e)
3181 logger.exception("Error handling streamable HTTP request: %s", e)
3182 raise
3183 finally:
3184 if initialize_span_active and initialize_span_stack is not None:
3185 initialize_span_stack.__exit__(*span_exit_exc)
3188# ------------------------- Authentication for /mcp routes ------------------------------
def _set_proxy_user_context(proxy_user: str) -> None:
    """Install the request user context for a proxy-authenticated caller.

    The stored context marks the caller as authenticated and non-admin, with
    an empty team list; the same identity is passed to the trace context.

    Args:
        proxy_user: Email address of the proxy-authenticated user.
    """
    proxy_context = {
        "email": proxy_user,
        "teams": [],
        "is_authenticated": True,
        "is_admin": False,
        "permission_is_admin": False,
        "auth_method": "proxy",
    }
    user_context_var.set(proxy_context)
    set_trace_context_from_teams([], user_email=proxy_user, is_admin=False, auth_method="proxy")
def get_streamable_http_auth_context() -> dict[str, Any]:
    """Return the current StreamableHTTP auth context for trusted internal forwarding.

    The Rust MCP proxy uses this to carry already-authenticated MCP request context
    across the Python -> Rust -> Python seam so the internal dispatcher does not
    need to repeat JWT verification and team normalization on the hot path.

    Returns:
        A shallow copy of the allow-listed auth context fields that are present
        in the current user context (list values are copied into new lists),
        or an empty dict when the current user context is not a dict.
    """
    context = user_context_var.get()
    if not isinstance(context, dict):
        return {}

    # Allow-list of fields that may cross the internal MCP seam.
    forwardable_keys = (
        "email",
        "teams",
        "team_name",
        "is_authenticated",
        "is_admin",
        "auth_method",
        "token_use",
        "permission_is_admin",
        "scoped_permissions",
        "scoped_server_id",
    )
    return {
        # Copy lists so the caller cannot mutate the stored context in place.
        name: list(context[name]) if isinstance(context[name], list) else context[name]
        for name in forwardable_keys
        if name in context
    }
class _StreamableHttpAuthHandler:
    """Per-request handler that authenticates MCP StreamableHTTP requests.

    Encapsulates the ASGI triple (scope, receive, send) so that helper methods
    can send error responses without threading these values through every call.
    """

    __slots__ = ("scope", "receive", "send")

    def __init__(self, scope: Any, receive: Any, send: Any) -> None:
        """Bind the handler to a single ASGI request.

        Args:
            scope: The ASGI scope dictionary for the request.
            receive: ASGI receive callable.
            send: ASGI send callable (used to emit error responses).
        """
        self.scope = scope
        self.receive = receive
        self.send = send

    async def _send_error(self, *, detail: str, status_code: int = HTTP_401_UNAUTHORIZED, headers: dict[str, str] | None = None) -> bool:
        """Send an error response and return False (auth rejected).

        Args:
            detail: Error message for the JSON response body.
            status_code: HTTP status code (default 401).
            headers: Optional response headers (e.g. WWW-Authenticate).

        Returns:
            Always ``False`` so callers can ``return await self._send_error(...)``.
        """
        response = ORJSONResponse({"detail": detail}, status_code=status_code, headers=headers or {})
        await response(self.scope, self.receive, self.send)
        return False

    async def authenticate(self) -> bool:
        """Perform authentication check in middleware context (ASGI scope).

        Authenticates requests targeting MCP transport paths: ``/mcp``, ``/mcp/``,
        ``/mcp/sse``, and ``/mcp/message`` (including ``/servers/{id}/...`` prefixed variants).

        Behavior:
        - If the path is not an MCP transport path, authentication is skipped.
        - Any other ``/mcp/<something>`` sub-path is rejected with 404.
        - CORS preflight requests (OPTIONS with Origin and
          Access-Control-Request-Method) are let through without credentials.
        - If proxy-auth trust is active and the proxy user header is present,
          the request is accepted as that user.
        - If mcp_require_auth=True (strict mode): requests without valid auth are rejected with 401.
        - If mcp_require_auth=False (permissive mode):
            - Requests without auth are allowed but get public-only access (token_teams=[]).
            - EXCEPTION: if the target server has oauth_enabled=True, unauthenticated
              requests are rejected with 401 regardless of the global setting.
            - Valid tokens get full scoped access based on their teams.
            - Malformed/invalid Bearer tokens are rejected with 401 (no silent downgrade).
        - If a Bearer token is present, it is verified using ``verify_credentials``.

        Returns:
            True if authentication passes or is skipped.
            False if the request was rejected and an error response
            (401, 403, 404 or 503) has already been sent.
        """
        path = self.scope.get("path", "")
        # Normalize trailing slash for consistent matching
        normalized = path.rstrip("/")
        # Check if this is an MCP-related path that requires authentication.
        # path.startswith("/mcp/") catches /mcp/{server_id} paths that the
        # Starlette mount at /mcp routes but that don't endswith("/mcp").
        is_mcp_path = normalized.endswith(("/mcp", "/mcp/sse", "/mcp/message")) or path.startswith("/mcp/")
        if not is_mcp_path or path.startswith("/.well-known/"):
            # No auth for non-MCP paths or RFC 9728 metadata endpoints
            return True

        # Reject undocumented /mcp/* sub-paths that the Starlette mount would
        # otherwise route to the global MCP handler. Only /mcp, /mcp/,
        # /mcp/sse, and /mcp/message are valid direct-access endpoints;
        # server-scoped access uses /servers/{id}/mcp (rewritten by middleware).
        if path.startswith("/mcp/"):
            _sub = normalized.removeprefix("/mcp")
            if _sub and _sub not in ("/sse", "/message"):
                return await self._send_error(detail="Not found", status_code=404)

        # Reset per-request OAuth enforcement cache so keep-alive connections
        # re-evaluate on every request instead of inheriting a stale True.
        _oauth_checked_var.set(False)

        headers = Headers(scope=self.scope)

        # CORS preflight (OPTIONS + Origin + Access-Control-Request-Method) cannot carry auth headers
        method = self.scope.get("method", "")
        if method == "OPTIONS":
            origin = headers.get("origin")
            if origin and headers.get("access-control-request-method"):
                return True

        authorization = headers.get("authorization")
        proxy_trusted = is_proxy_auth_trust_active(settings)
        # The proxy header is only read when proxy trust is active.
        proxy_user = headers.get(settings.proxy_user_header) if proxy_trusted else None

        # Determine authentication strategy based on settings
        if proxy_trusted and proxy_user:
            _set_proxy_user_context(proxy_user)
            return True  # Trusted proxy supplied user

        # --- Standard JWT authentication flow (client auth enabled) ---
        token: str | None = None
        bearer_header_supplied = False
        if authorization:
            scheme, credentials = get_authorization_scheme_param(authorization)
            if scheme.lower() == "bearer":
                bearer_header_supplied = True
                if credentials:
                    token = credentials

        if token is None:
            return await self._auth_no_token(path=path, bearer_header_supplied=bearer_header_supplied)

        return await self._auth_jwt(token=token)

    async def _auth_no_token(self, *, path: str, bearer_header_supplied: bool) -> bool:
        """Handle unauthenticated MCP requests (no Bearer token present).

        Args:
            path: Request path (used for per-server OAuth enforcement).
            bearer_header_supplied: True when Authorization: Bearer was present but empty.

        Returns:
            True if the request is allowed with public-only access, False if
            rejected (a 401 or 503 response has been sent).
        """
        # If client supplied a Bearer header but with empty credentials, fail closed
        if bearer_header_supplied:
            return await self._send_error(detail="Invalid authentication credentials", headers={"WWW-Authenticate": "Bearer"})

        # Per-server OAuth enforcement MUST run before the global auth check so that
        # oauth_enabled servers always return 401 with resource_metadata URL (RFC 9728).
        # Without this, strict mode (mcp_require_auth=True) returns a generic
        # WWW-Authenticate: Bearer with no resource_metadata, and MCP clients cannot
        # discover the OAuth server to authenticate. (Fixes #3752)
        match = _SERVER_ID_RE.search(path)
        if match:
            per_server_id = match.group("server_id")
            try:
                await _check_server_oauth_enforcement(per_server_id, {"is_authenticated": False})
            except OAuthRequiredError:
                resource_metadata = _build_resource_metadata_url(self.scope, per_server_id)
                www_auth = f'Bearer resource_metadata="{resource_metadata}"' if resource_metadata else "Bearer"
                return await self._send_error(detail="This server requires OAuth authentication", headers={"WWW-Authenticate": www_auth})
            except OAuthEnforcementUnavailableError:
                logger.exception("OAuth enforcement check failed for server %s", per_server_id)
                return await self._send_error(detail="Service unavailable — unable to verify server authentication requirements", status_code=503)

        # Strict mode: require authentication (non-OAuth servers get generic 401)
        if settings.mcp_require_auth:
            return await self._send_error(detail="Authentication required for MCP endpoints", headers={"WWW-Authenticate": "Bearer"})

        # Permissive mode: allow unauthenticated access with public-only scope
        # Set context indicating unauthenticated user with public-only access (teams=[])
        user_context_var.set(
            {
                "email": None,
                "teams": [],  # Empty list = public-only access
                "is_authenticated": False,
                "is_admin": False,
                "permission_is_admin": False,
                "auth_method": "anonymous",
            }
        )
        set_trace_context_from_teams([], auth_method="anonymous")
        return True  # Allow request to proceed with public-only access

    @staticmethod
    def _query_active_team_ids(user_email: str, team_ids: list[Any]) -> set[Any]:
        """Return the subset of ``team_ids`` where the user has an active membership.

        This helper is synchronous and opens its own DB session so it can be
        run through ``asyncio.to_thread`` without blocking the event loop.

        Args:
            user_email: Email of the user whose memberships are checked.
            team_ids: Team identifiers claimed by the token.

        Returns:
            Set of team ids from ``team_ids`` that have an active
            ``EmailTeamMember`` row for ``user_email``.
        """
        # Import lazily to avoid circular imports
        # Third-Party
        from sqlalchemy import select  # pylint: disable=import-outside-toplevel

        # First-Party
        from mcpgateway.db import EmailTeamMember  # pylint: disable=import-outside-toplevel

        with SessionLocal() as db:
            memberships = (
                db.execute(
                    select(EmailTeamMember.team_id).where(
                        EmailTeamMember.team_id.in_(team_ids),
                        EmailTeamMember.user_email == user_email,
                        EmailTeamMember.is_active.is_(True),
                    )
                )
                .scalars()
                .all()
            )
        return set(memberships)

    async def _auth_jwt(self, *, token: str) -> bool:
        """Verify a JWT Bearer token and populate the user context.

        After ``verify_credentials`` succeeds, revocation / account state is
        resolved through up to three tiers, each tried only if the previous
        one produced nothing:

        1. the auth cache (when ``settings.auth_cache_enabled``),
        2. a single batched DB lookup (when ``settings.auth_cache_batch_queries``),
        3. individual revocation and user lookups, which fail open when the
           lookup itself raises.

        Teams are then resolved (DB-backed for session tokens, JWT-embedded
        otherwise), membership is re-validated for non-session tokens, and the
        resulting context is stored in ``user_context_var`` and the trace
        context.

        Args:
            token: Bearer token value extracted from the Authorization header.

        Returns:
            True if verification succeeds, False if rejected (401/403/503 sent).
        """
        try:
            user_payload = await verify_credentials(token)
            # Store enriched user context with normalized teams
            if not isinstance(user_payload, dict):
                return True

            # First-Party
            from mcpgateway.auth import _get_auth_context_batched_sync, resolve_trace_team_name  # pylint: disable=import-outside-toplevel
            from mcpgateway.cache.auth_cache import CachedAuthContext, get_auth_cache  # pylint: disable=import-outside-toplevel

            jti = user_payload.get("jti")
            user_email = user_payload.get("sub") or user_payload.get("email")
            nested_user = user_payload.get("user", {})
            nested_is_admin = nested_user.get("is_admin", False) if isinstance(nested_user, dict) else False
            is_admin = user_payload.get("is_admin", False) or nested_is_admin
            token_use = user_payload.get("token_use")
            db_user_is_admin = False
            user_record = None
            auth_cache = get_auth_cache() if settings.auth_cache_enabled else None
            cached_ctx: CachedAuthContext | None = None
            batched_auth_ctx: dict[str, Any] | None = None
            cached_team_ids: list[str] | None = None
            platform_admin_email = getattr(settings, "platform_admin_email", "admin@example.com")

            # --- Tier 1: auth cache -------------------------------------------------
            if user_email and auth_cache is not None:
                try:
                    cached_ctx = await auth_cache.get_auth_context(user_email, jti)
                    if cached_ctx is not None:
                        _record_mcp_auth_cache_event("auth_context_hit")
                        if cached_ctx.is_token_revoked:
                            _record_mcp_auth_cache_event("auth_context_hit_revoked")
                            return await self._send_error(detail="Token has been revoked", headers={"WWW-Authenticate": "Bearer"})

                        cached_user = cached_ctx.user
                        if cached_user and not cached_user.get("is_active", True):
                            _record_mcp_auth_cache_event("auth_context_hit_inactive")
                            return await self._send_error(detail="Account disabled", headers={"WWW-Authenticate": "Bearer"})

                        if cached_user:
                            db_user_is_admin = bool(cached_user.get("is_admin", False))
                        elif settings.require_user_in_db and user_email != platform_admin_email:
                            return await self._send_error(detail="User not found in database", headers={"WWW-Authenticate": "Bearer"})

                        if token_use == "session" and not is_admin:  # nosec B105 - token_use is a JWT claim type, not a password
                            cached_team_ids = await auth_cache.get_user_teams(f"{user_email}:True")
                            if cached_team_ids is not None:
                                _record_mcp_auth_cache_event("teams_cache_hit")
                    else:
                        _record_mcp_auth_cache_event("auth_context_miss")
                except HTTPException:
                    raise
                except Exception as cache_error:
                    # A broken cache must not block auth: treat it as a miss.
                    _record_mcp_auth_cache_event("auth_context_cache_error")
                    logger.debug("MCP auth cache lookup failed for %s: %s", user_email, cache_error)
                    cached_ctx = None

            # --- Tier 2: batched DB lookup ------------------------------------------
            if user_email and cached_ctx is None and settings.auth_cache_batch_queries:
                try:
                    batched_auth_ctx = await asyncio.to_thread(_get_auth_context_batched_sync, user_email, jti)
                    _record_mcp_auth_cache_event("auth_context_batch_hit")
                    if batched_auth_ctx.get("is_token_revoked", False):
                        _record_mcp_auth_cache_event("auth_context_batch_revoked")
                        return await self._send_error(detail="Token has been revoked", headers={"WWW-Authenticate": "Bearer"})

                    cached_user = batched_auth_ctx.get("user")
                    if cached_user and not cached_user.get("is_active", True):
                        _record_mcp_auth_cache_event("auth_context_batch_inactive")
                        return await self._send_error(detail="Account disabled", headers={"WWW-Authenticate": "Bearer"})

                    if cached_user:
                        db_user_is_admin = bool(cached_user.get("is_admin", False))
                    elif settings.require_user_in_db and user_email != platform_admin_email:
                        return await self._send_error(detail="User not found in database", headers={"WWW-Authenticate": "Bearer"})

                    if auth_cache is not None:
                        await auth_cache.set_auth_context(
                            user_email,
                            jti,
                            CachedAuthContext(
                                user=cached_user,
                                personal_team_id=batched_auth_ctx.get("personal_team_id"),
                                is_token_revoked=bool(batched_auth_ctx.get("is_token_revoked", False)),
                            ),
                        )
                        if token_use == "session" and not is_admin:  # nosec B105 - token_use is a JWT claim type, not a password
                            cached_team_ids = list(batched_auth_ctx.get("team_ids") or [])
                            await auth_cache.set_user_teams(f"{user_email}:True", cached_team_ids)
                            _record_mcp_auth_cache_event("teams_batch_hit")
                except HTTPException:
                    raise
                except Exception as batch_error:
                    _record_mcp_auth_cache_event("auth_context_batch_error")
                    logger.warning("Batched MCP auth lookup failed for user=%s; falling back to individual checks: %s", user_email, batch_error)
                    batched_auth_ctx = None

            # --- Tier 3: individual lookups (fail-open on lookup errors) ------------
            if user_email and cached_ctx is None and batched_auth_ctx is None:
                _record_mcp_auth_cache_event("auth_context_fallback")
                # First-Party
                from mcpgateway.auth import _check_token_revoked_sync, _get_user_by_email_sync  # pylint: disable=import-outside-toplevel

                is_revoked = False
                if jti:
                    try:
                        is_revoked = await asyncio.to_thread(_check_token_revoked_sync, jti)
                    except Exception as exc:
                        logger.warning("MCP token revocation check failed for jti=%s; allowing request (fail-open): %s", jti, exc)
                        is_revoked = False
                if is_revoked:
                    return await self._send_error(detail="Token has been revoked", headers={"WWW-Authenticate": "Bearer"})

                user_lookup_succeeded = True
                try:
                    user_record = await asyncio.to_thread(_get_user_by_email_sync, user_email)
                except Exception as exc:
                    user_lookup_succeeded = False
                    user_record = None
                    logger.warning("MCP user lookup failed for user=%s; allowing request (fail-open): %s", user_email, exc)

                if user_lookup_succeeded:
                    if user_record and not getattr(user_record, "is_active", True):
                        return await self._send_error(detail="Account disabled", headers={"WWW-Authenticate": "Bearer"})
                    if user_record:
                        db_user_is_admin = bool(getattr(user_record, "is_admin", False))
                    if user_record is None and settings.require_user_in_db and user_email != platform_admin_email:
                        return await self._send_error(detail="User not found in database", headers={"WWW-Authenticate": "Bearer"})

                    # Cache only after a successful lookup so a transient DB
                    # failure is never stored as "user absent".
                    if auth_cache is not None:
                        try:
                            await auth_cache.set_auth_context(
                                user_email,
                                jti,
                                CachedAuthContext(
                                    user=(
                                        {
                                            "email": user_record.email,
                                            "password_hash": user_record.password_hash,
                                            "full_name": user_record.full_name,
                                            "is_admin": bool(user_record.is_admin),
                                            "is_active": bool(user_record.is_active),
                                            "auth_provider": user_record.auth_provider,
                                            "password_change_required": bool(user_record.password_change_required),
                                            "email_verified_at": user_record.email_verified_at,
                                            "created_at": user_record.created_at,
                                            "updated_at": user_record.updated_at,
                                        }
                                        if user_record is not None
                                        else None
                                    ),
                                    personal_team_id=None,
                                    is_token_revoked=is_revoked,
                                ),
                            )
                        except Exception as cache_set_error:
                            logger.debug("Failed to cache MCP auth context for %s: %s", user_email, cache_set_error)

            if token_use == "session":  # nosec B105 - Not a password; token_use is a JWT claim type
                # Session token: resolve teams via single policy point (DB-first intersection)
                # First-Party
                from mcpgateway.auth import resolve_session_teams  # pylint: disable=import-outside-toplevel

                if cached_team_ids is not None:
                    final_teams = await resolve_session_teams(user_payload, user_email, {"is_admin": is_admin}, preresolved_db_teams=cached_team_ids)
                elif batched_auth_ctx is not None:
                    preresolved = None if is_admin else list(batched_auth_ctx.get("team_ids") or [])
                    final_teams = await resolve_session_teams(user_payload, user_email, {"is_admin": is_admin}, preresolved_db_teams=preresolved)
                else:
                    _record_mcp_auth_cache_event("teams_db_resolve")
                    final_teams = await resolve_session_teams(user_payload, user_email, {"is_admin": is_admin})
            else:
                # API token or legacy: use embedded teams from JWT
                # First-Party
                from mcpgateway.auth import normalize_token_teams  # pylint: disable=import-outside-toplevel

                final_teams = normalize_token_teams(user_payload)

            # ═══════════════════════════════════════════════════════════════════════════
            # SECURITY: Validate team membership for team-scoped tokens
            # Users removed from a team should lose MCP access immediately, not at token expiry
            # ═══════════════════════════════════════════════════════════════════════════
            # Validate membership for API/legacy tokens whose teams come from
            # the JWT and have never been checked against the DB. Session tokens
            # are skipped: resolve_session_teams() already resolved teams from
            # DB/cache, so a second membership query would be redundant.
            if token_use != "session" and final_teams and user_email:  # nosec B105
                # Fetched unconditionally here: ``auth_cache`` above is None
                # when the auth cache is disabled, but the membership cache is
                # always consulted.
                membership_cache = get_auth_cache()

                # Check cache first (60s TTL)
                cached_result = membership_cache.get_team_membership_valid_sync(user_email, final_teams)
                if cached_result is False:
                    _record_mcp_auth_cache_event("team_membership_cache_reject")
                    logger.warning("MCP auth rejected: User %s no longer member of teams (cached)", user_email)
                    return await self._send_error(detail="Token invalid: User is no longer a member of the associated team", status_code=HTTP_403_FORBIDDEN)

                if cached_result is None:
                    _record_mcp_auth_cache_event("team_membership_cache_miss")
                    # Cache miss - query database. Run in a worker thread,
                    # like the other DB lookups in this method, so the
                    # synchronous query does not block the event loop.
                    valid_team_ids = await asyncio.to_thread(self._query_active_team_ids, user_email, final_teams)
                    missing_teams = set(final_teams) - valid_team_ids

                    if missing_teams:
                        logger.warning("MCP auth rejected: User %s no longer member of teams: %s", user_email, missing_teams)
                        membership_cache.set_team_membership_valid_sync(user_email, final_teams, False)
                        return await self._send_error(detail="Token invalid: User is no longer a member of the associated team", status_code=HTTP_403_FORBIDDEN)

                    # Cache positive result
                    membership_cache.set_team_membership_valid_sync(user_email, final_teams, True)
                else:
                    _record_mcp_auth_cache_event("team_membership_cache_hit")

            auth_user_ctx: dict[str, Any] = {
                "email": user_email,
                "teams": final_teams,
                "is_authenticated": True,
                "is_admin": is_admin,
                "auth_method": "jwt",
                "permission_is_admin": db_user_is_admin or is_admin,
                "token_use": token_use,  # propagated for downstream RBAC (check_any_team)
            }
            trace_team_name = await resolve_trace_team_name(user_payload, final_teams, preresolved_team_names=batched_auth_ctx.get("team_names") if batched_auth_ctx else None)
            if trace_team_name:
                auth_user_ctx["team_name"] = trace_team_name
            # Extract scoped permissions from JWT for per-method enforcement
            jwt_scopes = user_payload.get("scopes") or {}
            jwt_scoped_perms = jwt_scopes.get("permissions") or [] if isinstance(jwt_scopes, dict) else []
            if jwt_scoped_perms:
                auth_user_ctx["scoped_permissions"] = jwt_scoped_perms
            scoped_server_id = jwt_scopes.get("server_id") if isinstance(jwt_scopes, dict) else None
            if isinstance(scoped_server_id, str) and scoped_server_id:
                auth_user_ctx["scoped_server_id"] = scoped_server_id
            user_context_var.set(auth_user_ctx)
            set_trace_context_from_teams(
                final_teams,
                user_email=user_email,
                is_admin=bool(db_user_is_admin or is_admin),
                auth_method="jwt",
                team_name=trace_team_name,
            )
        except HTTPException:
            # JWT verification failed (expired, malformed, bad signature, etc.)
            return await self._send_error(detail="Invalid authentication credentials", headers={"WWW-Authenticate": "Bearer"})
        except SQLAlchemyError:
            # DB failure during team resolution or membership validation
            logger.exception("Database error during MCP authentication")
            return await self._send_error(detail="Service unavailable — unable to verify authentication", status_code=503)
        except Exception:
            # Unexpected error during authentication — fail closed with 401.
            logger.exception("Unexpected error during MCP JWT authentication")
            return await self._send_error(detail="Authentication failed", headers={"WWW-Authenticate": "Bearer"})

        return True
async def streamable_http_auth(scope: Any, receive: Any, send: Any) -> bool:
    """Run the MCP authentication check for one ASGI request.

    Thin entry point: builds a :class:`_StreamableHttpAuthHandler` around the
    ASGI triple and returns the outcome of its ``authenticate()`` coroutine.

    Args:
        scope: The ASGI scope dictionary, which includes request metadata.
        receive: ASGI receive callable used to receive events.
        send: ASGI send callable used to send events (e.g. an error response).

    Returns:
        bool: True if authentication passes or is skipped.
              False if authentication fails and an error response is sent.

    Examples:
        >>> # Test streamable_http_auth function exists
        >>> callable(streamable_http_auth)
        True

        >>> # Test function signature
        >>> import inspect
        >>> sig = inspect.signature(streamable_http_auth)
        >>> list(sig.parameters.keys())
        ['scope', 'receive', 'send']
    """
    handler = _StreamableHttpAuthHandler(scope, receive, send)
    return await handler.authenticate()