Coverage for mcpgateway / middleware / protocol_version.py: 100%
31 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Middleware to validate MCP-Protocol-Version header for MCP HTTP endpoints."""
4# Standard
5import logging
6from typing import Callable
8# Third-Party
9from fastapi import Request, Response
10from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS as MCP_SUPPORTED_PROTOCOL_VERSIONS
11from mcp.types import LATEST_PROTOCOL_VERSION
12from starlette.middleware.base import BaseHTTPMiddleware
14# First-Party
15from mcpgateway.utils.orjson_response import ORJSONResponse
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Supported MCP protocol versions, copied into a list from the MCP SDK so the
# gateway stays aligned with the SDK (and therefore with schema.ts).
SUPPORTED_PROTOCOL_VERSIONS = [*MCP_SUPPORTED_PROTOCOL_VERSIONS]

# Protocol version assumed by this implementation by default: the latest one
# known to the MCP SDK.
DEFAULT_PROTOCOL_VERSION = LATEST_PROTOCOL_VERSION
class MCPProtocolVersionMiddleware(BaseHTTPMiddleware):
    """Validate the MCP-Protocol-Version header on MCP protocol HTTP endpoints.

    Requests whose path is not an MCP protocol endpoint are passed through
    untouched. For MCP endpoints, a missing header falls back to
    ``DEFAULT_PROTOCOL_VERSION``, an unsupported value is rejected with an
    HTTP 400 JSON response, and an accepted version is stored on
    ``request.state.mcp_protocol_version`` before the request continues.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Validate MCP-Protocol-Version header for MCP protocol endpoints.

        Args:
            request: The incoming HTTP request
            call_next: The next middleware or route handler in the chain

        Returns:
            Response: Either a 400 error for invalid protocol versions or the result of call_next

        Examples:
            Non-MCP endpoints are bypassed:

            >>> import asyncio
            >>> from starlette.requests import Request
            >>> from starlette.responses import Response
            >>> from mcpgateway.middleware.protocol_version import MCPProtocolVersionMiddleware
            >>> async def call_next(req): return Response("ok", media_type="text/plain")
            >>> scope = {
            ...     "type": "http",
            ...     "asgi": {"version": "3.0"},
            ...     "method": "GET",
            ...     "path": "/health",
            ...     "raw_path": b"/health",
            ...     "query_string": b"",
            ...     "headers": [],
            ...     "client": ("testclient", 50000),
            ...     "server": ("testserver", 80),
            ...     "scheme": "http",
            ... }
            >>> resp = asyncio.run(MCPProtocolVersionMiddleware(app=None).dispatch(Request(scope), call_next))
            >>> resp.status_code
            200

            MCP endpoints default the version when the header is missing:

            >>> from mcpgateway.middleware.protocol_version import DEFAULT_PROTOCOL_VERSION
            >>> scope_rpc = {
            ...     "type": "http",
            ...     "asgi": {"version": "3.0"},
            ...     "method": "POST",
            ...     "path": "/rpc",
            ...     "raw_path": b"/rpc",
            ...     "query_string": b"",
            ...     "headers": [],
            ...     "client": ("testclient", 50000),
            ...     "server": ("testserver", 80),
            ...     "scheme": "http",
            ... }
            >>> req = Request(scope_rpc)
            >>> _ = asyncio.run(MCPProtocolVersionMiddleware(app=None).dispatch(req, call_next))
            >>> req.state.mcp_protocol_version == DEFAULT_PROTOCOL_VERSION
            True

            Unsupported versions return `400`:

            >>> bad_scope = {
            ...     "type": "http",
            ...     "asgi": {"version": "3.0"},
            ...     "method": "POST",
            ...     "path": "/rpc",
            ...     "raw_path": b"/rpc",
            ...     "query_string": b"",
            ...     "headers": [(b"mcp-protocol-version", b"bad")],
            ...     "client": ("testclient", 50000),
            ...     "server": ("testserver", 80),
            ...     "scheme": "http",
            ... }
            >>> bad_resp = asyncio.run(MCPProtocolVersionMiddleware(app=None).dispatch(Request(bad_scope), call_next))
            >>> (bad_resp.status_code, b"Unsupported protocol version: bad" in bad_resp.body)
            (400, True)
        """
        path = request.url.path

        # Skip validation for non-MCP endpoints (admin UI, health, openapi, etc.)
        if not self._is_mcp_endpoint(path):
            return await call_next(request)

        # Get the protocol version from headers (case-insensitive)
        protocol_version = request.headers.get("mcp-protocol-version")

        # If no protocol version provided, assume default version (backwards compatibility)
        if protocol_version is None:
            protocol_version = DEFAULT_PROTOCOL_VERSION
            # Lazy %-style arguments: the message is only formatted when DEBUG is enabled.
            logger.debug("No MCP-Protocol-Version header, assuming %s", DEFAULT_PROTOCOL_VERSION)

        # Validate protocol version
        if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS:
            supported = ", ".join(SUPPORTED_PROTOCOL_VERSIONS)
            logger.warning("Unsupported protocol version: %s", protocol_version)
            return ORJSONResponse(
                status_code=400,
                content={"error": "Bad Request", "message": f"Unsupported protocol version: {protocol_version}. Supported versions: {supported}"},
            )

        # Store validated version in request state for use by handlers
        request.state.mcp_protocol_version = protocol_version

        return await call_next(request)

    def _is_mcp_endpoint(self, path: str) -> bool:
        """
        Check if path is an MCP protocol endpoint that requires version validation.

        MCP protocol endpoints include:
        - /mcp and /mcp/ (Streamable HTTP transport)
        - /rpc and /rpc/ (gateway JSON-RPC endpoint)
        - /servers/*/sse (SSE transport)
        - /servers/*/ws (WebSocket transport)

        Non-MCP endpoints (admin, health, openapi, etc.) are excluded. The
        /servers/ suffix check is exact: a trailing slash after ``/sse`` or
        ``/ws`` does not match.

        Args:
            path: The request URL path to check

        Returns:
            bool: True if path is an MCP protocol endpoint, False otherwise
        """
        # Exact match for main RPC endpoint
        if path in ("/mcp", "/mcp/", "/rpc", "/rpc/"):
            return True

        # Prefix + suffix match for per-server SSE/WebSocket endpoints
        return path.startswith("/servers/") and path.endswith(("/sse", "/ws"))