Coverage for mcpgateway / transports / streamablehttp_transport.py: 100%

745 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/transports/streamablehttp_transport.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Keval Mahajan 

6 

7Streamable HTTP Transport Implementation. 

8This module implements Streamable HTTP transport for MCP

9 

10Key components include: 

11- SessionManagerWrapper: Manages the lifecycle of streamable HTTP sessions 

12- Configuration options for: 

13 1. stateful/stateless operation 

14 2. JSON response mode or SSE streams 

15- InMemoryEventStore: A simple in-memory event storage system for maintaining session state 

16 

17Examples: 

18 >>> # Test module imports 

19 >>> from mcpgateway.transports.streamablehttp_transport import ( 

20 ... EventEntry, StreamBuffer, InMemoryEventStore, SessionManagerWrapper 

21 ... ) 

22 >>> 

23 >>> # Verify classes are available 

24 >>> EventEntry.__name__ 

25 'EventEntry' 

26 >>> StreamBuffer.__name__ 

27 'StreamBuffer' 

28 >>> InMemoryEventStore.__name__ 

29 'InMemoryEventStore' 

30 >>> SessionManagerWrapper.__name__ 

31 'SessionManagerWrapper' 

32""" 

33 

34# Standard 

35import asyncio 

36from contextlib import asynccontextmanager, AsyncExitStack 

37import contextvars 

38from dataclasses import dataclass 

39import re 

40from typing import Any, AsyncGenerator, Dict, List, Optional, Pattern, Union 

41from uuid import uuid4 

42 

43# Third-Party 

44import anyio 

45from fastapi.security.utils import get_authorization_scheme_param 

46import httpx 

47from mcp import types 

48from mcp.server.lowlevel import Server 

49from mcp.server.streamable_http import EventCallback, EventId, EventMessage, EventStore, StreamId 

50from mcp.server.streamable_http_manager import StreamableHTTPSessionManager 

51from mcp.types import JSONRPCMessage 

52import orjson 

53from sqlalchemy.orm import Session 

54from starlette.datastructures import Headers 

55from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN 

56from starlette.types import Receive, Scope, Send 

57 

58# First-Party 

59from mcpgateway.common.models import LogLevel 

60from mcpgateway.config import settings 

61from mcpgateway.db import SessionLocal 

62from mcpgateway.services.completion_service import CompletionService 

63from mcpgateway.services.logging_service import LoggingService 

64from mcpgateway.services.prompt_service import PromptService 

65from mcpgateway.services.resource_service import ResourceService 

66from mcpgateway.services.tool_service import ToolService 

67from mcpgateway.transports.redis_event_store import RedisEventStore 

68from mcpgateway.utils.orjson_response import ORJSONResponse 

69from mcpgateway.utils.verify_credentials import verify_credentials 

70 

71# Initialize logging service first 

72logging_service = LoggingService() 

73logger = logging_service.get_logger(__name__) 

74 

75# Precompiled regex for server ID extraction from path 

76_SERVER_ID_RE: Pattern[str] = re.compile(r"/servers/(?P<server_id>[a-fA-F0-9\-]+)/mcp") 

77 

78# Initialize ToolService, PromptService, ResourceService, CompletionService and MCP Server 

79tool_service: ToolService = ToolService() 

80prompt_service: PromptService = PromptService() 

81resource_service: ResourceService = ResourceService() 

82completion_service: CompletionService = CompletionService() 

83 

84mcp_app: Server[Any] = Server("mcp-streamable-http") 

85 

86server_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("server_id", default="default_server_id") 

87request_headers_var: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar("request_headers", default={}) 

88user_context_var: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar("user_context", default={}) 

89 

90# ------------------------------ Event store ------------------------------ 

91 

92 

@dataclass
class EventEntry:
    """A single stored event.

    Bundles the globally unique event id, the stream it belongs to, the raw
    JSON-RPC payload, and the per-stream sequence number used for replay
    ordering and ring-buffer slot lookup.
    """

    event_id: EventId
    stream_id: StreamId
    message: JSONRPCMessage
    seq_num: int

123 

124 

@dataclass
class StreamBuffer:
    """Fixed-capacity ring of events for a single stream.

    The slot for sequence number ``s`` is ``entries[s % capacity]``, giving
    O(1) position lookup. ``start_seq`` and ``next_seq`` bound the window of
    sequence numbers still buffered, so replay never has to scan.
    """

    entries: list[EventEntry | None]
    start_seq: int = 0  # oldest sequence number still held
    next_seq: int = 0  # sequence number the next insert will receive
    count: int = 0

    def __len__(self) -> int:
        """Return how many events the buffer currently holds.

        Returns:
            int: Number of live entries (0..capacity).
        """
        return self.count

164 

165 

class InMemoryEventStore(EventStore):
    """In-memory EventStore implementation supporting stream resumability.

    Intended for examples and testing only; production deployments should use
    a persistent store. Each stream keeps at most the newest N events in a
    :class:`StreamBuffer` ring, while a flat ``event_index`` maps event ids to
    entries for O(1) lookup. Replay after a known event is O(k) in the number
    of events replayed — no full scan.
    """

    def __init__(self, max_events_per_stream: int = 100):
        """Set up empty per-stream buffers and the global event index.

        Args:
            max_events_per_stream: Ring capacity per stream; once exceeded,
                the oldest event in that stream is evicted.
        """
        self.max_events_per_stream = max_events_per_stream
        # stream id -> ring buffer of its most recent events
        self.streams: dict[StreamId, StreamBuffer] = {}
        # event id -> entry, for O(1) resumption lookups
        self.event_index: dict[EventId, EventEntry] = {}

    async def store_event(self, stream_id: StreamId, message: JSONRPCMessage) -> EventId:
        """Store a message under a freshly generated event id.

        Args:
            stream_id (StreamId): Stream the event belongs to.
            message (JSONRPCMessage): Payload to buffer for possible replay.

        Returns:
            EventId: The generated id of the stored event.
        """
        ring = self.streams.get(stream_id)
        if ring is None:
            ring = StreamBuffer(entries=[None] * self.max_events_per_stream)
            self.streams[stream_id] = ring

        # Claim the next per-stream sequence number and its ring slot.
        seq = ring.next_seq
        ring.next_seq = seq + 1
        slot = seq % self.max_events_per_stream

        if ring.count == self.max_events_per_stream:
            # Full ring: the slot we are about to overwrite holds the oldest
            # event — drop it from the index and advance the window start.
            displaced = ring.entries[slot]
            if displaced is not None:
                self.event_index.pop(displaced.event_id, None)
            ring.start_seq += 1
        else:
            if ring.count == 0:
                ring.start_seq = seq
            ring.count += 1

        fresh_id = str(uuid4())
        record = EventEntry(event_id=fresh_id, stream_id=stream_id, message=message, seq_num=seq)
        ring.entries[slot] = record
        self.event_index[fresh_id] = record

        return fresh_id

    async def replay_events_after(
        self,
        last_event_id: EventId,
        send_callback: EventCallback,
    ) -> Union[StreamId, None]:
        """Replay every buffered event newer than ``last_event_id``.

        Resolves the anchor event in O(1) via ``event_index``, then walks
        sequence numbers forward — O(k) for k replayed events.

        Args:
            last_event_id (EventId): Last event the client received; replay
                starts strictly after it.
            send_callback (EventCallback): Async callback invoked once per
                replayed event.

        Returns:
            StreamId | None: The owning stream id when the anchor event is
            known and still within the buffered window, otherwise None.
        """
        anchor = self.event_index.get(last_event_id)
        if anchor is None:
            logger.warning(f"Event ID {last_event_id} not found in store")
            return None

        ring = self.streams.get(anchor.stream_id)
        if ring is None:
            return None

        # The anchor must still lie inside the buffered sequence window;
        # otherwise the events after it may already have been evicted.
        if not (ring.start_seq <= anchor.seq_num < ring.next_seq):
            return None

        for seq in range(anchor.seq_num + 1, ring.next_seq):
            candidate = ring.entries[seq % self.max_events_per_stream]
            # Only send slots that still hold exactly this sequence number
            # (an overwritten or empty slot is silently skipped).
            if candidate is not None and candidate.seq_num == seq:
                await send_callback(EventMessage(candidate.message, candidate.event_id))

        return anchor.stream_id

385 

386 

387# ------------------------------ Streamable HTTP Transport ------------------------------ 

388 

389 

@asynccontextmanager
async def get_db() -> AsyncGenerator[Session, Any]:
    """Provide a database session as an async context manager.

    On clean exit the transaction is committed (avoiding implicit rollbacks
    for read-only work). On task cancellation — e.g. client disconnect or
    timeout in an MCP handler — the session is rolled back and closed on a
    best-effort basis before the cancellation propagates, so no transaction
    is leaked. Any other exception triggers an explicit rollback (falling
    back to invalidating a broken connection) and is re-raised.

    Yields:
        Session: A live session from ``SessionLocal``; always closed on exit.

    Raises:
        asyncio.CancelledError: Re-raised after best-effort rollback/close.
        Exception: Re-raised after rollback (or connection invalidation).
    """

    def _best_effort(action) -> None:
        # Cleanup calls that must never mask the original error/cancellation.
        try:
            action()
        except Exception:
            pass  # nosec B110 - best-effort cleanup

    session = SessionLocal()
    try:
        yield session
        session.commit()
    except asyncio.CancelledError:
        # Cancelled mid-handler: release the transaction before re-raising.
        _best_effort(session.rollback)
        _best_effort(session.close)
        raise
    except Exception:
        try:
            session.rollback()
        except Exception:
            # Rollback itself failed (likely a dead connection) — invalidate
            # so the pooled connection is not reused in a broken state.
            _best_effort(session.invalidate)
        raise
    finally:
        session.close()

446 

447 

def get_user_email_from_context() -> str:
    """Resolve the caller's email from the user-context contextvar.

    Returns:
        str: The ``email`` claim if present, else the JWT ``sub`` claim,
        else the string form of a non-dict truthy context, else ``'unknown'``.
    """
    ctx = user_context_var.get()
    if not isinstance(ctx, dict):
        return str(ctx) if ctx else "unknown"
    # Prefer the explicit email claim, then the standard JWT subject claim.
    return ctx.get("email") or ctx.get("sub") or "unknown"

459 

460 

@mcp_app.call_tool(validate_input=False)
async def call_tool(name: str, arguments: dict) -> List[Union[types.TextContent, types.ImageContent, types.AudioContent, types.ResourceLink, types.EmbeddedResource]]:
    """
    Handles tool invocation via the MCP Server.

    Note: validate_input=False disables the MCP SDK's built-in JSON Schema validation.
    This is necessary because the SDK uses jsonschema.validate() which internally calls
    check_schema() with the default validator. Schemas using older draft features
    (e.g., Draft 4 style exclusiveMinimum: true) fail this validation. The gateway
    handles schema validation separately in tool_service.py with multi-draft support.

    This function supports the MCP protocol's tool calling with structured content validation.
    It can return either unstructured content only, or both unstructured and structured content
    when the tool defines an outputSchema.

    Args:
        name (str): The name of the tool to invoke.
        arguments (dict): A dictionary of arguments to pass to the tool.

    Returns:
        Union[List[ContentBlock], Tuple[List[ContentBlock], Dict[str, Any]]]:
            - If structured content is not present: Returns a list of content blocks
              (TextContent, ImageContent, or EmbeddedResource)
            - If structured content is present: Returns a tuple of (unstructured_content, structured_content)
              where structured_content is a dictionary that will be validated against the tool's outputSchema

        The MCP SDK's call_tool decorator automatically handles both return types:
        - List return → CallToolResult with content only
        - Tuple return → CallToolResult with both content and structuredContent fields

    Raises:
        Exception: Re-raised after logging so the MCP SDK can convert it to a
            JSON-RPC error response (error details propagate to the client).

    Examples:
        >>> # Test call_tool function signature
        >>> import inspect
        >>> sig = inspect.signature(call_tool)
        >>> list(sig.parameters.keys())
        ['name', 'arguments']
        >>> sig.parameters['name'].annotation
        <class 'str'>
        >>> sig.parameters['arguments'].annotation
        <class 'dict'>
        >>> sig.return_annotation
        typing.List[typing.Union[mcp.types.TextContent, mcp.types.ImageContent, mcp.types.AudioContent, mcp.types.ResourceLink, mcp.types.EmbeddedResource]]
    """
    request_headers = request_headers_var.get()
    server_id = server_id_var.get()
    user_context = user_context_var.get()

    meta_data = None
    # Extract _meta from request context if available
    try:
        ctx = mcp_app.request_context
        if ctx and ctx.meta is not None:
            meta_data = ctx.meta.model_dump()
    except LookupError:
        # request_context might not be active in some edge cases (e.g. tests)
        logger.debug("No active request context found")

    # Extract authorization parameters from user context (same pattern as list_tools)
    user_email = user_context.get("email") if user_context else None
    token_teams = user_context.get("teams") if user_context else None
    is_admin = user_context.get("is_admin", False) if user_context else False

    # Admin bypass - only when token has NO team restrictions (token_teams is None)
    # If token has explicit team scope (even empty [] for public-only), respect it
    if is_admin and token_teams is None:
        user_email = None
        # token_teams stays None (unrestricted)
    elif token_teams is None:
        token_teams = []  # Non-admin without teams = public-only (secure default)

    app_user_email = get_user_email_from_context()  # Keep for OAuth token selection

    # Multi-worker session affinity: check if we should forward to another worker
    # Check both x-mcp-session-id (internal/forwarded) and mcp-session-id (client protocol header)
    mcp_session_id = None
    if request_headers:
        request_headers_lower = {k.lower(): v for k, v in request_headers.items()}
        mcp_session_id = request_headers_lower.get("x-mcp-session-id") or request_headers_lower.get("mcp-session-id")
    if settings.mcpgateway_session_affinity_enabled and mcp_session_id:
        try:
            # First-Party
            from mcpgateway.cache.tool_lookup_cache import tool_lookup_cache  # pylint: disable=import-outside-toplevel
            from mcpgateway.services.mcp_session_pool import get_mcp_session_pool  # pylint: disable=import-outside-toplevel
            from mcpgateway.services.mcp_session_pool import MCPSessionPool  # pylint: disable=import-outside-toplevel

            if not MCPSessionPool.is_valid_mcp_session_id(mcp_session_id):
                logger.debug("Invalid MCP session id for Streamable HTTP tool affinity, executing locally")
                # Deliberate: RuntimeError is caught below and falls through to
                # local execution — same path as "pool not initialized".
                raise RuntimeError("invalid mcp session id")

            pool = get_mcp_session_pool()

            # Register session mapping BEFORE checking forwarding (same pattern as SSE)
            # This ensures ownership is registered atomically so forward_request_to_owner() works
            try:
                cached = await tool_lookup_cache.get(name)
                if cached and cached.get("status") == "active":
                    gateway_info = cached.get("gateway")
                    if gateway_info:
                        url = gateway_info.get("url")
                        gateway_id = gateway_info.get("id", "")
                        transport_type = gateway_info.get("transport", "streamablehttp")
                        if url:
                            await pool.register_session_mapping(mcp_session_id, url, gateway_id, transport_type, user_email)
            except Exception as e:
                # Best-effort: failed pre-registration must not block the call.
                logger.error(f"Failed to pre-register session mapping for Streamable HTTP: {e}")

            forwarded_response = await pool.forward_request_to_owner(
                mcp_session_id,
                {"method": "tools/call", "params": {"name": name, "arguments": arguments, "_meta": meta_data}, "headers": dict(request_headers) if request_headers else {}},
            )
            if forwarded_response is not None:
                # Request was handled by another worker - convert response to expected format
                if "error" in forwarded_response:
                    raise Exception(forwarded_response["error"].get("message", "Forwarded request failed"))  # pylint: disable=broad-exception-raised
                result_data = forwarded_response.get("result", {})

                def _rehydrate_content_items(items: Any) -> list[types.TextContent | types.ImageContent | types.AudioContent | types.ResourceLink | types.EmbeddedResource]:
                    """Convert forwarded tool result items back to MCP content types.

                    Args:
                        items: List of content item dicts from forwarded response.

                    Returns:
                        List of validated MCP content type instances.
                    """
                    if not isinstance(items, list):
                        return []
                    converted: list[types.TextContent | types.ImageContent | types.AudioContent | types.ResourceLink | types.EmbeddedResource] = []
                    for item in items:
                        if not isinstance(item, dict):
                            continue
                        item_type = item.get("type")
                        try:
                            if item_type == "text":
                                converted.append(types.TextContent.model_validate(item))
                            elif item_type == "image":
                                converted.append(types.ImageContent.model_validate(item))
                            elif item_type == "audio":
                                converted.append(types.AudioContent.model_validate(item))
                            elif item_type == "resource_link":
                                converted.append(types.ResourceLink.model_validate(item))
                            elif item_type == "resource":
                                converted.append(types.EmbeddedResource.model_validate(item))
                            else:
                                # Unknown type: degrade to a text rendering.
                                converted.append(types.TextContent(type="text", text=str(item)))
                        except Exception:
                            # Validation failure: degrade to a text rendering.
                            converted.append(types.TextContent(type="text", text=str(item)))
                    return converted

                unstructured = _rehydrate_content_items(result_data.get("content", []))
                structured = result_data.get("structuredContent") or result_data.get("structured_content")
                if structured:
                    return (unstructured, structured)
                return unstructured
        except RuntimeError:
            # Pool not initialized - execute locally
            pass

    try:
        async with get_db() as db:
            result = await tool_service.invoke_tool(
                db=db,
                name=name,
                arguments=arguments,
                request_headers=request_headers,
                app_user_email=app_user_email,
                user_email=user_email,
                token_teams=token_teams,
                server_id=server_id,
                meta_data=meta_data,
            )
            if not result or not result.content:
                logger.warning(f"No content returned by tool: {name}")
                return []

            # Normalize unstructured content to MCP SDK types, preserving metadata (annotations, _meta, size)
            # Helper to convert gateway Annotations to dict for MCP SDK compatibility
            # (mcpgateway.common.models.Annotations != mcp.types.Annotations)
            def _convert_annotations(ann: Any) -> dict[str, Any] | None:
                """Convert gateway Annotations to dict for MCP SDK compatibility.

                Args:
                    ann: Gateway Annotations object, dict, or None.

                Returns:
                    Dict representation of annotations, or None.
                """
                if ann is None:
                    return None
                if isinstance(ann, dict):
                    return ann
                if hasattr(ann, "model_dump"):
                    return ann.model_dump(by_alias=True, mode="json")
                return None

            def _convert_meta(meta: Any) -> dict[str, Any] | None:
                """Convert gateway meta to dict for MCP SDK compatibility.

                Args:
                    meta: Gateway meta object, dict, or None.

                Returns:
                    Dict representation of meta, or None.
                """
                if meta is None:
                    return None
                if isinstance(meta, dict):
                    return meta
                if hasattr(meta, "model_dump"):
                    return meta.model_dump(by_alias=True, mode="json")
                return None

            unstructured: list[types.TextContent | types.ImageContent | types.AudioContent | types.ResourceLink | types.EmbeddedResource] = []
            for content in result.content:
                if content.type == "text":
                    unstructured.append(
                        types.TextContent(
                            type="text",
                            text=content.text,
                            annotations=_convert_annotations(getattr(content, "annotations", None)),
                            _meta=_convert_meta(getattr(content, "meta", None)),
                        )
                    )
                elif content.type == "image":
                    unstructured.append(
                        types.ImageContent(
                            type="image",
                            data=content.data,
                            mimeType=content.mime_type,
                            annotations=_convert_annotations(getattr(content, "annotations", None)),
                            _meta=_convert_meta(getattr(content, "meta", None)),
                        )
                    )
                elif content.type == "audio":
                    unstructured.append(
                        types.AudioContent(
                            type="audio",
                            data=content.data,
                            mimeType=content.mime_type,
                            annotations=_convert_annotations(getattr(content, "annotations", None)),
                            _meta=_convert_meta(getattr(content, "meta", None)),
                        )
                    )
                elif content.type == "resource_link":
                    unstructured.append(
                        types.ResourceLink(
                            type="resource_link",
                            uri=content.uri,
                            name=content.name,
                            description=getattr(content, "description", None),
                            mimeType=getattr(content, "mime_type", None),
                            size=getattr(content, "size", None),
                            _meta=_convert_meta(getattr(content, "meta", None)),
                        )
                    )
                elif content.type == "resource":
                    # EmbeddedResource - pass through the model dump as the MCP SDK type requires complex nested structure
                    unstructured.append(types.EmbeddedResource.model_validate(content.model_dump(by_alias=True, mode="json")))
                else:
                    # Unknown content type - convert to text representation
                    unstructured.append(types.TextContent(type="text", text=str(content.model_dump(by_alias=True, mode="json"))))

            # If the tool produced structured content (ToolResult.structured_content / structuredContent),
            # return a combination (unstructured, structured) so the server can validate against outputSchema.
            # The ToolService may populate structured_content (snake_case) or the model may expose
            # an alias 'structuredContent' when dumped via model_dump(by_alias=True).
            structured = None
            try:
                # Prefer attribute if present
                structured = getattr(result, "structured_content", None)
            except Exception:
                structured = None

            # Fallback to by-alias dump (in case the result is a pydantic model with alias fields)
            if structured is None:
                try:
                    structured = result.model_dump(by_alias=True).get("structuredContent") if hasattr(result, "model_dump") else None
                except Exception:
                    structured = None

            if structured:
                return (unstructured, structured)

            return unstructured
    except Exception as e:
        logger.exception(f"Error calling tool '{name}': {e}")
        # Re-raise the exception so the MCP SDK can properly convert it to an error response
        # This ensures error details are propagated to the client instead of returning empty results
        raise

756 

757 

@mcp_app.list_tools()
async def list_tools() -> List[types.Tool]:
    """Return the tools visible to the current request.

    If a virtual server id is bound to the request context, the listing is
    scoped to that server; otherwise all active tools the token may see are
    returned. Any failure is logged and reported as an empty list.

    Returns:
        List[types.Tool]: Tool metadata (name, description, input/output
        schemas, annotations).

    Examples:
        >>> import inspect
        >>> list(inspect.signature(list_tools).parameters)
        []
        >>> inspect.signature(list_tools).return_annotation
        typing.List[mcp.types.Tool]
    """
    server_id = server_id_var.get()
    request_headers = request_headers_var.get()
    user_context = user_context_var.get()

    # Pull authorization scope out of the user context. token_teams defaults
    # to None so "no teams specified" is distinguishable from an explicit [].
    user_email = user_context.get("email") if user_context else None
    token_teams = user_context.get("teams") if user_context else None
    is_admin = user_context.get("is_admin", False) if user_context else False

    if is_admin and token_teams is None:
        # Unrestricted admin token: drop owner filtering entirely
        # (token_teams stays None). Explicit team scopes are still honored.
        user_email = None
    elif token_teams is None:
        # Non-admin token without team scope: public-only (secure default).
        token_teams = []

    try:
        async with get_db() as db:
            if server_id:
                tools = await tool_service.list_server_tools(db, server_id, user_email=user_email, token_teams=token_teams, _request_headers=request_headers)
            else:
                tools, _ = await tool_service.list_tools(db, include_inactive=False, limit=0, user_email=user_email, token_teams=token_teams, _request_headers=request_headers)
            return [types.Tool(name=t.name, description=t.description, inputSchema=t.input_schema, outputSchema=t.output_schema, annotations=t.annotations) for t in tools]
    except Exception as e:
        logger.exception(f"Error listing tools:{e}")
        return []

810 

811 

@mcp_app.list_prompts()
async def list_prompts() -> List[types.Prompt]:
    """
    Lists all prompts available to the MCP Server.

    Returns:
        A list of Prompt objects containing metadata such as name, description, and arguments.
        Logs and returns an empty list on failure.

    Examples:
        >>> import inspect
        >>> sig = inspect.signature(list_prompts)
        >>> list(sig.parameters.keys())
        []
        >>> sig.return_annotation
        typing.List[mcp.types.Prompt]
    """
    server_id = server_id_var.get()
    token_ctx = user_context_var.get()

    # Derive visibility filters from the caller's token context.
    email_filter = token_ctx.get("email") if token_ctx else None
    # None = token carries no team scope; [] = explicitly public-only.
    team_scope = token_ctx.get("teams") if token_ctx else None
    admin = token_ctx.get("is_admin", False) if token_ctx else False

    if admin and team_scope is None:
        # Unscoped admin token: bypass per-user filtering entirely.
        email_filter = None
    elif team_scope is None:
        # Non-admin with no team claim: restrict to public prompts (secure default).
        team_scope = []

    if server_id:
        try:
            async with get_db() as db:
                found = await prompt_service.list_server_prompts(db, server_id, user_email=email_filter, token_teams=team_scope)
                return [types.Prompt(name=p.name, description=p.description, arguments=p.arguments) for p in found]
        except Exception as e:
            logger.exception(f"Error listing Prompts:{e}")
            return []

    try:
        async with get_db() as db:
            found, _ = await prompt_service.list_prompts(db, include_inactive=False, limit=0, user_email=email_filter, token_teams=team_scope)
            return [types.Prompt(name=p.name, description=p.description, arguments=p.arguments) for p in found]
    except Exception as e:
        logger.exception(f"Error listing prompts:{e}")
        return []

862 

863 

@mcp_app.get_prompt()
async def get_prompt(prompt_id: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
    """
    Retrieves a prompt by ID, optionally substituting arguments.

    Args:
        prompt_id (str): The ID of the prompt to retrieve.
        arguments (Optional[dict[str, str]]): Optional dictionary of arguments to substitute into the prompt.

    Returns:
        GetPromptResult: Object containing the prompt messages and description.
        Returns an empty list on failure or if no prompt content is found.

    Logs exceptions if any errors occur during retrieval.

    Examples:
        >>> import inspect
        >>> sig = inspect.signature(get_prompt)
        >>> list(sig.parameters.keys())
        ['prompt_id', 'arguments']
        >>> sig.return_annotation.__name__
        'GetPromptResult'
    """
    server_id = server_id_var.get()
    token_ctx = user_context_var.get()

    # Derive authorization filters from the caller's token context (same pattern as list_prompts).
    email_filter = token_ctx.get("email") if token_ctx else None
    team_scope = token_ctx.get("teams") if token_ctx else None
    admin = token_ctx.get("is_admin", False) if token_ctx else False

    if admin and team_scope is None:
        # Unscoped admin token: bypass per-user filtering entirely.
        email_filter = None
    elif team_scope is None:
        # Non-admin with no team claim: restrict to public prompts (secure default).
        team_scope = []

    # Pull the request's _meta payload, if a request context is active.
    meta_payload = None
    try:
        req_ctx = mcp_app.request_context
        if req_ctx and req_ctx.meta is not None:
            meta_payload = req_ctx.meta.model_dump()
    except LookupError:
        # request_context might not be active in some edge cases (e.g. tests)
        logger.debug("No active request context found")

    # NOTE(review): the error/empty paths return [] while the annotation says
    # GetPromptResult — presumably downstream tolerates the list; confirm
    # before tightening.
    try:
        async with get_db() as db:
            try:
                result = await prompt_service.get_prompt(
                    db=db,
                    prompt_id=prompt_id,
                    arguments=arguments,
                    user=email_filter,
                    server_id=server_id,
                    token_teams=team_scope,
                    _meta_data=meta_payload,
                )
            except Exception as e:
                logger.exception(f"Error getting prompt '{prompt_id}': {e}")
                return []

            if not result or not result.messages:
                logger.warning(f"No content returned by prompt: {prompt_id}")
                return []

            payload = [msg.model_dump() for msg in result.messages]
            return types.GetPromptResult(messages=payload, description=result.description)
    except Exception as e:
        logger.exception(f"Error getting prompt '{prompt_id}': {e}")
        return []

935 

936 

@mcp_app.list_resources()
async def list_resources() -> List[types.Resource]:
    """
    Lists all resources available to the MCP Server.

    Returns:
        A list of Resource objects containing metadata such as uri, name, description, and mimeType.
        Logs and returns an empty list on failure.

    Examples:
        >>> import inspect
        >>> sig = inspect.signature(list_resources)
        >>> list(sig.parameters.keys())
        []
        >>> sig.return_annotation
        typing.List[mcp.types.Resource]
    """
    server_id = server_id_var.get()
    token_ctx = user_context_var.get()

    # Derive visibility filters from the caller's token context.
    email_filter = token_ctx.get("email") if token_ctx else None
    # None = token carries no team scope; [] = explicitly public-only.
    team_scope = token_ctx.get("teams") if token_ctx else None
    admin = token_ctx.get("is_admin", False) if token_ctx else False

    if admin and team_scope is None:
        # Unscoped admin token: bypass per-user filtering entirely.
        email_filter = None
    elif team_scope is None:
        # Non-admin with no team claim: restrict to public resources (secure default).
        team_scope = []

    if server_id:
        try:
            async with get_db() as db:
                found = await resource_service.list_server_resources(db, server_id, user_email=email_filter, token_teams=team_scope)
                return [types.Resource(uri=r.uri, name=r.name, description=r.description, mimeType=r.mime_type) for r in found]
        except Exception as e:
            logger.exception(f"Error listing Resources:{e}")
            return []

    try:
        async with get_db() as db:
            found, _ = await resource_service.list_resources(db, include_inactive=False, limit=0, user_email=email_filter, token_teams=team_scope)
            return [types.Resource(uri=r.uri, name=r.name, description=r.description, mimeType=r.mime_type) for r in found]
    except Exception as e:
        logger.exception(f"Error listing resources:{e}")
        return []

987 

988 

@mcp_app.read_resource()
async def read_resource(resource_uri: str) -> Union[str, bytes]:
    """
    Reads the content of a resource specified by its URI.

    Args:
        resource_uri (str): The URI of the resource to read.

    Returns:
        Union[str, bytes]: The content of the resource as text or binary data.
        Returns empty string on failure or if no content is found.

    Logs exceptions if any errors occur during reading.

    Examples:
        >>> import inspect
        >>> sig = inspect.signature(read_resource)
        >>> list(sig.parameters.keys())
        ['resource_uri']
        >>> sig.return_annotation
        typing.Union[str, bytes]
    """
    server_id = server_id_var.get()
    token_ctx = user_context_var.get()

    # Derive authorization filters from the caller's token context (same pattern as list_resources).
    email_filter = token_ctx.get("email") if token_ctx else None
    team_scope = token_ctx.get("teams") if token_ctx else None
    admin = token_ctx.get("is_admin", False) if token_ctx else False

    if admin and team_scope is None:
        # Unscoped admin token: bypass per-user filtering entirely.
        email_filter = None
    elif team_scope is None:
        # Non-admin with no team claim: restrict to public resources (secure default).
        team_scope = []

    # Pull the request's _meta payload, if a request context is active.
    meta_payload = None
    try:
        req_ctx = mcp_app.request_context
        if req_ctx and req_ctx.meta is not None:
            meta_payload = req_ctx.meta.model_dump()
    except LookupError:
        # request_context might not be active in some edge cases (e.g. tests)
        logger.debug("No active request context found")

    try:
        async with get_db() as db:
            try:
                result = await resource_service.read_resource(
                    db=db,
                    resource_uri=str(resource_uri),
                    user=email_filter,
                    server_id=server_id,
                    token_teams=team_scope,
                    meta_data=meta_payload,
                )
            except Exception as e:
                logger.exception(f"Error reading resource '{resource_uri}': {e}")
                return ""

            if result:
                # Binary content takes precedence over text when both are set.
                if result.blob:
                    return result.blob
                if result.text:
                    return result.text

            # Neither blob nor text content was produced.
            logger.warning(f"No content returned by resource: {resource_uri}")
            return ""
    except Exception as e:
        logger.exception(f"Error reading resource '{resource_uri}': {e}")
        return ""

1065 

1066 

@mcp_app.list_resource_templates()
async def list_resource_templates() -> List[Dict[str, Any]]:
    """
    Lists all resource templates available to the MCP Server.

    Returns:
        List[Dict[str, Any]]: Resource templates serialized with field aliases.
        Logs and returns an empty list on failure.

    Examples:
        >>> import inspect
        >>> sig = inspect.signature(list_resource_templates)
        >>> list(sig.parameters.keys())
        []
        >>> sig.return_annotation.__origin__.__name__
        'list'
    """
    # Extract filtering parameters from user context (same pattern as list_resources)
    user_context = user_context_var.get()
    user_email = user_context.get("email") if user_context else None
    token_teams = user_context.get("teams") if user_context else None
    is_admin = user_context.get("is_admin", False) if user_context else False

    # Admin bypass - only when token has NO team restrictions (token_teams is None)
    # If token has explicit team scope (even empty [] for public-only), respect it
    if is_admin and token_teams is None:
        user_email = None
        # token_teams stays None (unrestricted)
    elif token_teams is None:
        token_teams = []  # Non-admin without teams = public-only (secure default)

    # Single try/except replaces the previous nested pair, which duplicated a
    # byte-identical handler (same log message, same `return []`) at two levels.
    try:
        async with get_db() as db:
            resource_templates = await resource_service.list_resource_templates(
                db,
                user_email=user_email,
                token_teams=token_teams,
            )
            return [template.model_dump(by_alias=True) for template in resource_templates]
    except Exception as e:
        logger.exception(f"Error listing resource templates: {e}")
        return []

1112 

1113 

@mcp_app.set_logging_level()
async def set_logging_level(level: types.LoggingLevel) -> types.EmptyResult:
    """
    Sets the logging level for the MCP Server.

    Args:
        level (types.LoggingLevel): The desired logging level (debug, info, notice, warning, error, critical, alert, emergency).

    Returns:
        types.EmptyResult: An empty result indicating success.

    Examples:
        >>> import inspect
        >>> sig = inspect.signature(set_logging_level)
        >>> list(sig.parameters.keys())
        ['level']
    """
    # Map MCP logging levels onto the gateway's LogLevel enum; syslog-style
    # levels above ERROR all collapse onto CRITICAL, "notice" onto INFO.
    mcp_to_internal = {
        "debug": LogLevel.DEBUG,
        "info": LogLevel.INFO,
        "notice": LogLevel.INFO,
        "warning": LogLevel.WARNING,
        "error": LogLevel.ERROR,
        "critical": LogLevel.CRITICAL,
        "alert": LogLevel.CRITICAL,
        "emergency": LogLevel.CRITICAL,
    }
    try:
        await logging_service.set_level(mcp_to_internal.get(level.lower(), LogLevel.INFO))
    except Exception as e:
        logger.exception(f"Error setting logging level: {e}")
    return types.EmptyResult()

1149 

1150 

@mcp_app.completion()
async def complete(
    ref: Union[types.PromptReference, types.ResourceTemplateReference],
    argument: types.CompleteRequest,
    context: Optional[types.CompletionContext] = None,
) -> types.CompleteResult:
    """
    Provides argument completion suggestions for prompts or resources.

    Args:
        ref: A reference to a prompt or a resource template. Can be either
            `types.PromptReference` or `types.ResourceTemplateReference`.
        argument: The completion request specifying the input text and
            position for which completion suggestions should be generated.
        argument: The completion request specifying the input text and
            position for which completion suggestions should be generated.
        context: Optional contextual information for the completion request,
            such as user, environment, or invocation metadata.

    Returns:
        types.CompleteResult: A normalized completion result containing
            completion values, metadata (total, hasMore), and any additional
            MCP-compliant completion fields.

    Raises:
        Exception: If completion handling fails internally. The method
            logs the exception and returns an empty completion structure.
    """
    # NOTE(review): every return path below yields a types.Completion, not the
    # annotated types.CompleteResult — presumably the MCP SDK's completion()
    # decorator wraps it; confirm against the SDK before changing the annotation.
    try:
        async with get_db() as db:
            # Serialize inputs to plain dicts so the completion service is not
            # coupled to pydantic model types. `context` may be None, in which
            # case the hasattr check fails and None passes through unchanged.
            params = {
                "ref": ref.model_dump() if hasattr(ref, "model_dump") else ref,
                "argument": argument.model_dump() if hasattr(argument, "model_dump") else argument,
                "context": context.model_dump() if hasattr(context, "model_dump") else context,
            }

            result = await completion_service.handle_completion(db, params)

            # Normalize the many shapes the service may return into a single
            # types.Completion for MCP. Branch order matters: each case is a
            # progressively weaker structural match.
            # Case 1: plain dict — either the completion payload itself or a
            # wrapper keyed by "completion".
            if isinstance(result, dict):
                completion_data = result.get("completion", result)
                return types.Completion(**completion_data)

            # Case 2: object exposing a `.completion` attribute.
            if hasattr(result, "completion"):
                completion_obj = result.completion

                # 2a: completion payload is a dict.
                if isinstance(completion_obj, dict):
                    return types.Completion(**completion_obj)

                # 2b: nested CompleteResult (a completion inside a completion).
                if hasattr(completion_obj, "completion"):
                    inner_completion = completion_obj.completion.model_dump() if hasattr(completion_obj.completion, "model_dump") else completion_obj.completion
                    return types.Completion(**inner_completion)

                # 2c: already the target Completion model — pass through.
                if isinstance(completion_obj, types.Completion):
                    return completion_obj

                # 2d: some other pydantic model (e.g. mcpgateway.models.Completion).
                if hasattr(completion_obj, "model_dump"):
                    return types.Completion(**completion_obj.model_dump())

            # Case 3: result itself is already a types.Completion.
            if isinstance(result, types.Completion):
                return result

            # Fallback: unrecognized shape — return an empty completion.
            return types.Completion(values=[], total=0, hasMore=False)

    except Exception as e:
        logger.exception(f"Error handling completion: {e}")
        return types.Completion(values=[], total=0, hasMore=False)

1222 

1223 

1224class SessionManagerWrapper: 

1225 """ 

1226 Wrapper class for managing the lifecycle of a StreamableHTTPSessionManager instance. 

1227 Provides start, stop, and request handling methods. 

1228 

1229 Examples: 

1230 >>> # Test SessionManagerWrapper initialization 

1231 >>> wrapper = SessionManagerWrapper() 

1232 >>> wrapper 

1233 <mcpgateway.transports.streamablehttp_transport.SessionManagerWrapper object at ...> 

1234 >>> hasattr(wrapper, 'session_manager') 

1235 True 

1236 >>> hasattr(wrapper, 'stack') 

1237 True 

1238 >>> isinstance(wrapper.stack, AsyncExitStack) 

1239 True 

1240 """ 

1241 

1242 def __init__(self) -> None: 

1243 """ 

1244 Initializes the session manager and the exit stack used for managing its lifecycle. 

1245 

1246 Examples: 

1247 >>> # Test initialization 

1248 >>> wrapper = SessionManagerWrapper() 

1249 >>> wrapper.session_manager is not None 

1250 True 

1251 >>> wrapper.stack is not None 

1252 True 

1253 """ 

1254 

1255 if settings.use_stateful_sessions: 

1256 # Use Redis event store for single-worker stateful deployments 

1257 if settings.cache_type == "redis" and settings.redis_url: 

1258 event_store = RedisEventStore(max_events_per_stream=settings.streamable_http_max_events_per_stream, ttl=settings.streamable_http_event_ttl) 

1259 logger.debug("Using RedisEventStore for stateful sessions (single-worker)") 

1260 else: 

1261 # Fall back to in-memory for single-worker or when Redis not available 

1262 event_store = InMemoryEventStore() 

1263 logger.warning("Using InMemoryEventStore - only works with single worker!") 

1264 stateless = False 

1265 else: 

1266 event_store = None 

1267 stateless = True 

1268 

1269 self.session_manager = StreamableHTTPSessionManager( 

1270 app=mcp_app, 

1271 event_store=event_store, 

1272 json_response=settings.json_response_enabled, 

1273 stateless=stateless, 

1274 ) 

1275 self.stack = AsyncExitStack() 

1276 

1277 async def initialize(self) -> None: 

1278 """ 

1279 Starts the Streamable HTTP session manager context. 

1280 

1281 Examples: 

1282 >>> # Test initialize method exists 

1283 >>> wrapper = SessionManagerWrapper() 

1284 >>> hasattr(wrapper, 'initialize') 

1285 True 

1286 >>> callable(wrapper.initialize) 

1287 True 

1288 """ 

1289 logger.debug("Initializing Streamable HTTP service") 

1290 await self.stack.enter_async_context(self.session_manager.run()) 

1291 

    async def shutdown(self) -> None:
        """
        Gracefully shuts down the Streamable HTTP session manager.

        Closing the exit stack unwinds the ``run()`` context entered by
        ``initialize()``; calling it before ``initialize()`` is a no-op.

        Examples:
            >>> wrapper = SessionManagerWrapper()
            >>> callable(wrapper.shutdown)
            True
        """
        logger.debug("Stopping Streamable HTTP Session Manager...")
        await self.stack.aclose()

1306 

1307 async def handle_streamable_http(self, scope: Scope, receive: Receive, send: Send) -> None: 

1308 """ 

1309 Forwards an incoming ASGI request to the streamable HTTP session manager. 

1310 

1311 Args: 

1312 scope (Scope): ASGI scope object containing connection information. 

1313 receive (Receive): ASGI receive callable. 

1314 send (Send): ASGI send callable. 

1315 

1316 Raises: 

1317 Exception: Any exception raised during request handling is logged. 

1318 

1319 Logs any exceptions that occur during request handling. 

1320 

1321 Examples: 

1322 >>> # Test handle_streamable_http method exists 

1323 >>> wrapper = SessionManagerWrapper() 

1324 >>> hasattr(wrapper, 'handle_streamable_http') 

1325 True 

1326 >>> callable(wrapper.handle_streamable_http) 

1327 True 

1328 

1329 >>> # Test method signature 

1330 >>> import inspect 

1331 >>> sig = inspect.signature(wrapper.handle_streamable_http) 

1332 >>> list(sig.parameters.keys()) 

1333 ['scope', 'receive', 'send'] 

1334 """ 

1335 

1336 path = scope["modified_path"] 

1337 # Uses precompiled regex for server ID extraction 

1338 match = _SERVER_ID_RE.search(path) 

1339 

1340 # Extract request headers from scope (ASGI provides bytes; normalize to lowercase for lookup). 

1341 raw_headers = scope.get("headers") or [] 

1342 headers: dict[str, str] = {} 

1343 for item in raw_headers: 

1344 if not isinstance(item, (tuple, list)) or len(item) != 2: 

1345 continue 

1346 k, v = item 

1347 if not isinstance(k, (bytes, bytearray)) or not isinstance(v, (bytes, bytearray)): 

1348 continue 

1349 # latin-1 is a byte-preserving decode; safe for arbitrary header bytes. 

1350 headers[k.decode("latin-1").lower()] = v.decode("latin-1") 

1351 

1352 # Log session info for debugging stateful sessions 

1353 mcp_session_id = headers.get("mcp-session-id", "not-provided") 

1354 method = scope.get("method", "UNKNOWN") 

1355 query_string = scope.get("query_string", b"").decode("utf-8") 

1356 logger.debug(f"[STATEFUL] Streamable HTTP {method} {path} | MCP-Session-Id: {mcp_session_id} | Query: {query_string} | Stateful: {settings.use_stateful_sessions}") 

1357 

1358 # Note: mcp-session-id from client is used for gateway-internal session affinity 

1359 # routing (stored in request_headers_var), but is NOT renamed or forwarded to 

1360 # upstream servers - it's a gateway-side concept, not an end-to-end semantic header 

1361 

1362 # Multi-worker session affinity: check if we should forward to another worker 

1363 # This must happen BEFORE the SDK's session manager handles the request 

1364 is_internally_forwarded = headers.get("x-forwarded-internally") == "true" 

1365 

1366 if settings.mcpgateway_session_affinity_enabled and mcp_session_id != "not-provided": 

1367 try: 

1368 # First-Party 

1369 from mcpgateway.services.mcp_session_pool import MCPSessionPool # pylint: disable=import-outside-toplevel 

1370 

1371 if not MCPSessionPool.is_valid_mcp_session_id(mcp_session_id): 

1372 logger.debug("Invalid MCP session id on Streamable HTTP request, skipping affinity") 

1373 mcp_session_id = "not-provided" 

1374 except Exception: 

1375 mcp_session_id = "not-provided" 

1376 

1377 # Log session manager ID for debugging 

1378 logger.debug(f"[SESSION_MGR_DEBUG] Manager ID: {id(self.session_manager)}") 

1379 

1380 if is_internally_forwarded: 

1381 logger.debug(f"[HTTP_AFFINITY_FORWARDED] Received forwarded request | Method: {method} | Session: {mcp_session_id}") 

1382 

1383 # Only route POST requests with JSON-RPC body to /rpc 

1384 # DELETE and other methods should return success (session cleanup is local) 

1385 if method != "POST": 

1386 logger.debug("[HTTP_AFFINITY_FORWARDED] Non-POST method, returning 200 OK") 

1387 await send({"type": "http.response.start", "status": 200, "headers": [(b"content-type", b"application/json")]}) 

1388 await send({"type": "http.response.body", "body": b'{"jsonrpc":"2.0","result":{}}'}) 

1389 return 

1390 

1391 # For POST requests, bypass SDK session manager and use /rpc directly 

1392 # This avoids SDK's session cleanup issues while maintaining stateful upstream connections 

1393 try: 

1394 # Read request body 

1395 body_parts = [] 

1396 while True: 

1397 message = await receive() 

1398 if message["type"] == "http.request": 

1399 body_parts.append(message.get("body", b"")) 

1400 if not message.get("more_body", False): 

1401 break 

1402 elif message["type"] == "http.disconnect": 

1403 return 

1404 body = b"".join(body_parts) 

1405 

1406 if not body: 

1407 logger.debug("[HTTP_AFFINITY_FORWARDED] Empty body, returning 202 Accepted") 

1408 await send({"type": "http.response.start", "status": 202, "headers": []}) 

1409 await send({"type": "http.response.body", "body": b""}) 

1410 return 

1411 

1412 json_body = orjson.loads(body) 

1413 rpc_method = json_body.get("method", "") 

1414 logger.debug(f"[HTTP_AFFINITY_FORWARDED] Routing to /rpc | Method: {rpc_method}") 

1415 

1416 # Notifications don't need /rpc routing - just acknowledge 

1417 if rpc_method.startswith("notifications/"): 

1418 logger.debug("[HTTP_AFFINITY_FORWARDED] Notification, returning 202 Accepted") 

1419 await send({"type": "http.response.start", "status": 202, "headers": []}) 

1420 await send({"type": "http.response.body", "body": b""}) 

1421 return 

1422 

1423 async with httpx.AsyncClient() as client: 

1424 rpc_headers = { 

1425 "content-type": "application/json", 

1426 "x-mcp-session-id": mcp_session_id, # Pass session for upstream affinity 

1427 "x-forwarded-internally": "true", # Prevent infinite forwarding loops 

1428 } 

1429 # Copy auth header if present 

1430 if "authorization" in headers: 

1431 rpc_headers["authorization"] = headers["authorization"] 

1432 

1433 response = await client.post( 

1434 f"http://127.0.0.1:{settings.port}/rpc", 

1435 content=body, 

1436 headers=rpc_headers, 

1437 timeout=30.0, 

1438 ) 

1439 

1440 # Return response to client 

1441 response_headers = [ 

1442 (b"content-type", b"application/json"), 

1443 (b"content-length", str(len(response.content)).encode()), 

1444 ] 

1445 if mcp_session_id != "not-provided": 

1446 response_headers.append((b"mcp-session-id", mcp_session_id.encode())) 

1447 

1448 await send( 

1449 { 

1450 "type": "http.response.start", 

1451 "status": response.status_code, 

1452 "headers": response_headers, 

1453 } 

1454 ) 

1455 await send( 

1456 { 

1457 "type": "http.response.body", 

1458 "body": response.content, 

1459 } 

1460 ) 

1461 logger.debug(f"[HTTP_AFFINITY_FORWARDED] Response sent | Status: {response.status_code}") 

1462 return 

1463 except Exception as e: 

1464 logger.error(f"[HTTP_AFFINITY_FORWARDED] Error routing to /rpc: {e}") 

1465 # Fall through to SDK handling as fallback 

1466 

1467 if settings.mcpgateway_session_affinity_enabled and settings.use_stateful_sessions and mcp_session_id != "not-provided" and not is_internally_forwarded: 

1468 try: 

1469 # First-Party - lazy import to avoid circular dependencies 

1470 # First-Party 

1471 from mcpgateway.services.mcp_session_pool import get_mcp_session_pool, WORKER_ID # pylint: disable=import-outside-toplevel 

1472 

1473 pool = get_mcp_session_pool() 

1474 owner = await pool.get_streamable_http_session_owner(mcp_session_id) 

1475 logger.debug(f"[HTTP_AFFINITY_CHECK] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Owner from Redis: {owner}") 

1476 

1477 if owner and owner != WORKER_ID: 

1478 # Session owned by another worker - forward the entire HTTP request 

1479 logger.info(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Owner: {owner} | Forwarding HTTP request") 

1480 

1481 # Read request body 

1482 body_parts = [] 

1483 while True: 

1484 message = await receive() 

1485 if message["type"] == "http.request": 

1486 body_parts.append(message.get("body", b"")) 

1487 if not message.get("more_body", False): 

1488 break 

1489 elif message["type"] == "http.disconnect": 

1490 return 

1491 body = b"".join(body_parts) 

1492 

1493 # Forward to owner worker 

1494 response = await pool.forward_streamable_http_to_owner( 

1495 owner_worker_id=owner, 

1496 mcp_session_id=mcp_session_id, 

1497 method=method, 

1498 path=path, 

1499 headers=headers, 

1500 body=body, 

1501 query_string=query_string, 

1502 ) 

1503 

1504 if response: 

1505 # Send forwarded response back to client 

1506 response_headers = [(k.encode(), v.encode()) for k, v in response["headers"].items() if k.lower() not in ("transfer-encoding", "content-encoding", "content-length")] 

1507 response_headers.append((b"content-length", str(len(response["body"])).encode())) 

1508 

1509 await send( 

1510 { 

1511 "type": "http.response.start", 

1512 "status": response["status"], 

1513 "headers": response_headers, 

1514 } 

1515 ) 

1516 await send( 

1517 { 

1518 "type": "http.response.body", 

1519 "body": response["body"], 

1520 } 

1521 ) 

1522 logger.debug(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Forwarded response sent to client") 

1523 return 

1524 

1525 # Forwarding failed - fall through to local handling 

1526 # This may result in "session not found" but it's better than no response 

1527 logger.debug(f"[HTTP_AFFINITY] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Forwarding failed, falling back to local") 

1528 

1529 elif owner == WORKER_ID and method == "POST": 

1530 # We own this session - route POST requests to /rpc to avoid SDK session issues 

1531 # The SDK's _server_instances gets cleared between requests, so we can't rely on it 

1532 logger.debug(f"[HTTP_AFFINITY_LOCAL] Worker {WORKER_ID} | Session {mcp_session_id[:8]}... | Owner is us, routing to /rpc") 

1533 

1534 # Read request body 

1535 body_parts = [] 

1536 while True: 

1537 message = await receive() 

1538 if message["type"] == "http.request": 

1539 body_parts.append(message.get("body", b"")) 

1540 if not message.get("more_body", False): 

1541 break 

1542 elif message["type"] == "http.disconnect": 

1543 return 

1544 body = b"".join(body_parts) 

1545 

1546 if not body: 

1547 logger.debug("[HTTP_AFFINITY_LOCAL] Empty body, returning 202 Accepted") 

1548 await send({"type": "http.response.start", "status": 202, "headers": []}) 

1549 await send({"type": "http.response.body", "body": b""}) 

1550 return 

1551 

1552 # Parse JSON-RPC and route to /rpc 

1553 try: 

1554 json_body = orjson.loads(body) 

1555 rpc_method = json_body.get("method", "") 

1556 logger.debug(f"[HTTP_AFFINITY_LOCAL] Routing to /rpc | Method: {rpc_method}") 

1557 

1558 # Notifications don't need /rpc routing 

1559 if rpc_method.startswith("notifications/"): 

1560 logger.debug("[HTTP_AFFINITY_LOCAL] Notification, returning 202 Accepted") 

1561 await send({"type": "http.response.start", "status": 202, "headers": []}) 

1562 await send({"type": "http.response.body", "body": b""}) 

1563 return 

1564 

1565 async with httpx.AsyncClient() as client: 

1566 rpc_headers = { 

1567 "content-type": "application/json", 

1568 "x-mcp-session-id": mcp_session_id, 

1569 "x-forwarded-internally": "true", 

1570 } 

1571 if "authorization" in headers: 

1572 rpc_headers["authorization"] = headers["authorization"] 

1573 

1574 response = await client.post( 

1575 f"http://127.0.0.1:{settings.port}/rpc", 

1576 content=body, 

1577 headers=rpc_headers, 

1578 timeout=30.0, 

1579 ) 

1580 

1581 response_headers = [ 

1582 (b"content-type", b"application/json"), 

1583 (b"content-length", str(len(response.content)).encode()), 

1584 (b"mcp-session-id", mcp_session_id.encode()), 

1585 ] 

1586 

1587 await send( 

1588 { 

1589 "type": "http.response.start", 

1590 "status": response.status_code, 

1591 "headers": response_headers, 

1592 } 

1593 ) 

1594 await send( 

1595 { 

1596 "type": "http.response.body", 

1597 "body": response.content, 

1598 } 

1599 ) 

1600 logger.debug(f"[HTTP_AFFINITY_LOCAL] Response sent | Status: {response.status_code}") 

1601 return 

1602 except Exception as e: 

1603 logger.error(f"[HTTP_AFFINITY_LOCAL] Error routing to /rpc: {e}") 

1604 # Fall through to SDK handling as fallback 

1605 

1606 except RuntimeError: 

1607 # Pool not initialized - proceed with local handling 

1608 pass 

1609 except Exception as e: 

1610 logger.debug(f"Session affinity check failed, proceeding locally: {e}") 

1611 

1612 # Store headers in context for tool invocations 

1613 request_headers_var.set(headers) 

1614 

1615 if match: 

1616 server_id = match.group("server_id") 

1617 server_id_var.set(server_id) 

1618 else: 

1619 server_id_var.set(None) 

1620 

1621 # For session affinity: wrap send to capture session ID from response headers 

1622 # This allows us to register ownership for new sessions created by the SDK 

1623 captured_session_id: Optional[str] = None 

1624 

1625 async def send_with_capture(message: Dict[str, Any]) -> None: 

1626 """Wrap ASGI send to capture session ID from response headers. 

1627 

1628 Args: 

1629 message: ASGI message dict. 

1630 """ 

1631 nonlocal captured_session_id 

1632 if message["type"] == "http.response.start" and settings.mcpgateway_session_affinity_enabled: 

1633 # Look for mcp-session-id in response headers 

1634 response_headers = message.get("headers", []) 

1635 for header_name, header_value in response_headers: 

1636 if isinstance(header_name, bytes): 

1637 header_name = header_name.decode("latin-1") 

1638 if isinstance(header_value, bytes): 

1639 header_value = header_value.decode("latin-1") 

1640 if header_name.lower() == "mcp-session-id": 

1641 captured_session_id = header_value 

1642 break 

1643 await send(message) 

1644 

1645 try: 

1646 await self.session_manager.handle_request(scope, receive, send_with_capture) 

1647 logger.debug(f"[STATEFUL] Streamable HTTP request completed successfully | Session: {mcp_session_id}") 

1648 

1649 # Register ownership for the session we just handled 

1650 # This captures both existing sessions (mcp_session_id from request) 

1651 # and new sessions (captured_session_id from response) 

1652 logger.debug( 

1653 f"[HTTP_AFFINITY_DEBUG] affinity_enabled={settings.mcpgateway_session_affinity_enabled} stateful={settings.use_stateful_sessions} captured={captured_session_id} mcp_session_id={mcp_session_id}" 

1654 ) 

1655 if settings.mcpgateway_session_affinity_enabled and settings.use_stateful_sessions: 

1656 session_to_register = captured_session_id or (mcp_session_id if mcp_session_id != "not-provided" else None) 

1657 logger.debug(f"[HTTP_AFFINITY_DEBUG] session_to_register={session_to_register}") 

1658 if session_to_register: 

1659 try: 

1660 # First-Party - lazy import to avoid circular dependencies 

1661 # First-Party 

1662 from mcpgateway.services.mcp_session_pool import get_mcp_session_pool, WORKER_ID # pylint: disable=import-outside-toplevel 

1663 

1664 pool = get_mcp_session_pool() 

1665 await pool.register_pool_session_owner(session_to_register) 

1666 logger.debug(f"[HTTP_AFFINITY_SDK] Worker {WORKER_ID} | Session {session_to_register[:8]}... | Registered ownership after SDK handling") 

1667 except Exception as e: 

1668 logger.debug(f"[HTTP_AFFINITY_DEBUG] Exception during registration: {e}") 

1669 logger.warning(f"Failed to register session ownership: {e}") 

1670 

1671 except anyio.ClosedResourceError: 

1672 # Expected when client closes one side of the stream (normal lifecycle) 

1673 logger.debug("Streamable HTTP connection closed by client (ClosedResourceError)") 

1674 except Exception as e: 

1675 logger.error(f"[STATEFUL] Streamable HTTP request failed | Session: {mcp_session_id} | Error: {e}") 

1676 logger.exception(f"Error handling streamable HTTP request: {e}") 

1677 raise 

1678 

1679 

1680# ------------------------- Authentication for /mcp routes ------------------------------ 

1681 

1682 

async def streamable_http_auth(scope: Any, receive: Any, send: Any) -> bool:
    """
    Perform authentication check in middleware context (ASGI scope).

    This function is intended to be used in middleware wrapping ASGI apps.
    It authenticates only requests targeting paths ending in "/mcp" or "/mcp/".

    Behavior:
        - If the path does not end with "/mcp", authentication is skipped.
        - If mcp_require_auth=True (strict mode):
            - Requests without valid auth are rejected with 401.
        - If mcp_require_auth=False (default, permissive mode):
            - Requests without auth are allowed but get public-only access (token_teams=[]).
            - Valid tokens get full scoped access based on their teams.
        - If a Bearer token is present, it is verified using `verify_credentials`.
        - If verification fails and mcp_require_auth=True, a 401 Unauthorized JSON response is sent.

    Args:
        scope: The ASGI scope dictionary, which includes request metadata.
        receive: ASGI receive callable used to receive events.
        send: ASGI send callable used to send events (e.g. a 401 response).

    Returns:
        bool: True if authentication passes or is skipped.
              False if authentication fails and a 401/403 response is sent.

    Examples:
        >>> # Test streamable_http_auth function exists
        >>> callable(streamable_http_auth)
        True

        >>> # Test function signature
        >>> import inspect
        >>> sig = inspect.signature(streamable_http_auth)
        >>> list(sig.parameters.keys())
        ['scope', 'receive', 'send']
    """
    # Only guard MCP endpoints; every other path handled by this middleware
    # passes through untouched.
    path = scope.get("path", "")
    if not path.endswith("/mcp") and not path.endswith("/mcp/"):
        # No auth needed for other paths in this middleware usage
        return True

    headers = Headers(scope=scope)

    # CORS preflight (OPTIONS + Origin + Access-Control-Request-Method) cannot carry auth headers,
    # so it must be allowed through for browsers to complete the preflight handshake.
    method = scope.get("method", "")
    if method == "OPTIONS":
        origin = headers.get("origin")
        if origin and headers.get("access-control-request-method"):
            return True

    authorization = headers.get("authorization")
    # Proxy-supplied identity header is only consulted when proxy trust is enabled.
    proxy_user = headers.get(settings.proxy_user_header) if settings.trust_proxy_auth else None

    # Determine authentication strategy based on settings
    if not settings.mcp_client_auth_enabled and settings.trust_proxy_auth:
        # Client auth disabled → allow proxy header
        if proxy_user:
            # Set enriched user context for proxy-authenticated sessions
            user_context_var.set(
                {
                    "email": proxy_user,
                    "teams": [],  # Proxy auth has no team context
                    "is_authenticated": True,
                    "is_admin": False,
                }
            )
            return True  # Trusted proxy supplied user
        # NOTE(review): with client auth disabled and no proxy header present,
        # control intentionally falls through to the JWT flow below.

    # --- Standard JWT authentication flow (client auth enabled) ---
    token: str | None = None
    if authorization:
        scheme, credentials = get_authorization_scheme_param(authorization)
        if scheme.lower() == "bearer" and credentials:
            token = credentials

    try:
        # A missing token is routed through the same except-handler as a bad
        # token so that proxy-fallback / strict / permissive handling is unified.
        if token is None:
            raise Exception("No token provided")
        user_payload = await verify_credentials(token)
        # Store enriched user context with normalized teams
        if isinstance(user_payload, dict):
            # Resolve teams based on token_use claim
            token_use = user_payload.get("token_use")
            if token_use == "session":  # nosec B105 - Not a password; token_use is a JWT claim type
                # Session token: resolve teams from DB/cache
                user_email_for_teams = user_payload.get("sub") or user_payload.get("email")
                is_admin_flag = user_payload.get("is_admin", False) or user_payload.get("user", {}).get("is_admin", False)
                if is_admin_flag:
                    final_teams = None  # Admin bypass
                elif user_email_for_teams:
                    # Resolve teams synchronously with L1 cache (StreamableHTTP uses sync context)
                    # First-Party
                    from mcpgateway.auth import _resolve_teams_from_db_sync  # pylint: disable=import-outside-toplevel

                    final_teams = _resolve_teams_from_db_sync(user_email_for_teams, is_admin=False)
                else:
                    final_teams = []  # No email — public-only
            else:
                # API token or legacy: use embedded teams from JWT
                # First-Party
                from mcpgateway.auth import normalize_token_teams  # pylint: disable=import-outside-toplevel

                final_teams = normalize_token_teams(user_payload)

            # ═══════════════════════════════════════════════════════════════════════════
            # SECURITY: Validate team membership for team-scoped tokens
            # Users removed from a team should lose MCP access immediately, not at token expiry
            # ═══════════════════════════════════════════════════════════════════════════
            user_email = user_payload.get("sub") or user_payload.get("email")
            is_admin = user_payload.get("is_admin", False) or user_payload.get("user", {}).get("is_admin", False)

            # Only validate membership for team-scoped tokens (non-empty teams list)
            # Skip for: public-only tokens ([]), admin unrestricted tokens (None)
            if final_teams and len(final_teams) > 0 and user_email:
                # Import lazily to avoid circular imports
                # First-Party
                from mcpgateway.cache.auth_cache import get_auth_cache  # pylint: disable=import-outside-toplevel
                from mcpgateway.db import EmailTeamMember  # pylint: disable=import-outside-toplevel

                auth_cache = get_auth_cache()

                # Check cache first (60s TTL presumably — confirm against auth_cache implementation)
                cached_result = auth_cache.get_team_membership_valid_sync(user_email, final_teams)
                if cached_result is False:
                    # Cached negative result: reject immediately without a DB round-trip.
                    logger.warning(f"MCP auth rejected: User {user_email} no longer member of teams (cached)")
                    response = ORJSONResponse(
                        {"detail": "Token invalid: User is no longer a member of the associated team"},
                        status_code=HTTP_403_FORBIDDEN,
                    )
                    await response(scope, receive, send)
                    return False

                if cached_result is None:
                    # Cache miss - query database
                    # Third-Party
                    from sqlalchemy import select  # pylint: disable=import-outside-toplevel

                    db = SessionLocal()
                    try:
                        # Fetch only the team ids the user is still an active member of,
                        # restricted to the teams claimed by the token.
                        memberships = (
                            db.execute(
                                select(EmailTeamMember.team_id).where(
                                    EmailTeamMember.team_id.in_(final_teams),
                                    EmailTeamMember.user_email == user_email,
                                    EmailTeamMember.is_active.is_(True),
                                )
                            )
                            .scalars()
                            .all()
                        )

                        valid_team_ids = set(memberships)
                        missing_teams = set(final_teams) - valid_team_ids

                        if missing_teams:
                            # Token claims a team the user no longer belongs to → 403,
                            # and cache the negative result for subsequent requests.
                            logger.warning(f"MCP auth rejected: User {user_email} no longer member of teams: {missing_teams}")
                            auth_cache.set_team_membership_valid_sync(user_email, final_teams, False)
                            response = ORJSONResponse(
                                {"detail": "Token invalid: User is no longer a member of the associated team"},
                                status_code=HTTP_403_FORBIDDEN,
                            )
                            await response(scope, receive, send)
                            return False

                        # Cache positive result
                        auth_cache.set_team_membership_valid_sync(user_email, final_teams, True)
                    finally:
                        # Rollback any implicit transaction before closing to prevent
                        # idle-in-transaction state if the connection is returned to pool
                        try:
                            db.rollback()
                        except Exception:
                            pass  # nosec B110 - Best effort rollback
                        db.close()

            # Token verified (and membership validated where applicable):
            # publish the enriched identity for downstream tool invocations.
            user_context_var.set(
                {
                    "email": user_email,
                    "teams": final_teams,
                    "is_authenticated": True,
                    "is_admin": is_admin,
                }
            )
        elif proxy_user:
            # If using proxy auth, store the proxy user
            user_context_var.set(
                {
                    "email": proxy_user,
                    "teams": [],
                    "is_authenticated": True,
                    "is_admin": False,
                }
            )
    except Exception:
        # Any failure above (missing token, bad signature, expired token, …)
        # lands here; handling depends on proxy trust and strictness settings.
        # If JWT auth fails but we have a trusted proxy user, use that
        if settings.trust_proxy_auth and proxy_user:
            user_context_var.set(
                {
                    "email": proxy_user,
                    "teams": [],
                    "is_authenticated": True,
                    "is_admin": False,
                }
            )
            return True  # Fall back to proxy authentication

        # Check mcp_require_auth setting to determine behavior
        if settings.mcp_require_auth:
            # Strict mode: require authentication, return 401 for unauthenticated requests
            response = ORJSONResponse(
                {"detail": "Authentication required for MCP endpoints"},
                status_code=HTTP_401_UNAUTHORIZED,
                headers={"WWW-Authenticate": "Bearer"},
            )
            await response(scope, receive, send)
            return False

        # Permissive mode (default): allow unauthenticated access with public-only scope
        # Set context indicating unauthenticated user with public-only access (teams=[])
        user_context_var.set(
            {
                "email": None,
                "teams": [],  # Empty list = public-only access
                "is_authenticated": False,
                "is_admin": False,
            }
        )
        return True  # Allow request to proceed with public-only access

    # Success path: credentials verified and user context populated above.
    return True