Coverage for mcpgateway / middleware / protocol_version.py: 100%
29 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Middleware to validate MCP-Protocol-Version header per MCP spec 2025-06-18."""
4# Standard
5import logging
6from typing import Callable
8# Third-Party
9from fastapi import Request, Response
10from starlette.middleware.base import BaseHTTPMiddleware
12# First-Party
13from mcpgateway.utils.orjson_response import ORJSONResponse
logger = logging.getLogger(__name__)

# Protocol revisions accepted on MCP endpoints (per the MCP specification), oldest first.
SUPPORTED_PROTOCOL_VERSIONS = [
    "2024-11-05",
    "2025-03-26",
    "2025-06-18",
]

# Revision assumed when a client sends no MCP-Protocol-Version header; the spec
# designates 2025-03-26 as the backwards-compatible default.
DEFAULT_PROTOCOL_VERSION = "2025-03-26"
class MCPProtocolVersionMiddleware(BaseHTTPMiddleware):
    """
    Validates MCP-Protocol-Version header per MCP spec 2025-06-18.

    Per the MCP specification (basic/transports.mdx):
    - Clients MUST include MCP-Protocol-Version header on all HTTP requests
    - If not provided, server SHOULD assume 2025-03-26 for backwards compatibility
    - If unsupported version provided, server MUST respond with 400 Bad Request

    Only requests whose path is recognised by ``_is_mcp_endpoint`` are checked;
    every other request is forwarded unchanged. For checked requests that pass,
    the effective version is stored on ``request.state.mcp_protocol_version``.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Validate MCP-Protocol-Version header for MCP protocol endpoints.

        A missing header is treated as ``DEFAULT_PROTOCOL_VERSION``. Any value not
        listed in ``SUPPORTED_PROTOCOL_VERSIONS`` (including an empty string)
        short-circuits with a JSON ``400`` response, and ``call_next`` is not invoked.

        Args:
            request: The incoming HTTP request
            call_next: The next middleware or route handler in the chain

        Returns:
            Response: Either a 400 error for invalid protocol versions or the result of call_next

        Examples:
            Non-MCP endpoints are bypassed:

            >>> import asyncio
            >>> from starlette.requests import Request
            >>> from starlette.responses import Response
            >>> from mcpgateway.middleware.protocol_version import MCPProtocolVersionMiddleware
            >>> async def call_next(req): return Response("ok", media_type="text/plain")
            >>> scope = {
            ...     "type": "http",
            ...     "asgi": {"version": "3.0"},
            ...     "method": "GET",
            ...     "path": "/health",
            ...     "raw_path": b"/health",
            ...     "query_string": b"",
            ...     "headers": [],
            ...     "client": ("testclient", 50000),
            ...     "server": ("testserver", 80),
            ...     "scheme": "http",
            ... }
            >>> resp = asyncio.run(MCPProtocolVersionMiddleware(app=None).dispatch(Request(scope), call_next))
            >>> resp.status_code
            200

            MCP endpoints default the version when the header is missing:

            >>> from mcpgateway.middleware.protocol_version import DEFAULT_PROTOCOL_VERSION
            >>> scope_rpc = {
            ...     "type": "http",
            ...     "asgi": {"version": "3.0"},
            ...     "method": "POST",
            ...     "path": "/rpc",
            ...     "raw_path": b"/rpc",
            ...     "query_string": b"",
            ...     "headers": [],
            ...     "client": ("testclient", 50000),
            ...     "server": ("testserver", 80),
            ...     "scheme": "http",
            ... }
            >>> req = Request(scope_rpc)
            >>> _ = asyncio.run(MCPProtocolVersionMiddleware(app=None).dispatch(req, call_next))
            >>> req.state.mcp_protocol_version == DEFAULT_PROTOCOL_VERSION
            True

            Unsupported versions return `400`:

            >>> bad_scope = {
            ...     "type": "http",
            ...     "asgi": {"version": "3.0"},
            ...     "method": "POST",
            ...     "path": "/rpc",
            ...     "raw_path": b"/rpc",
            ...     "query_string": b"",
            ...     "headers": [(b"mcp-protocol-version", b"bad")],
            ...     "client": ("testclient", 50000),
            ...     "server": ("testserver", 80),
            ...     "scheme": "http",
            ... }
            >>> bad_resp = asyncio.run(MCPProtocolVersionMiddleware(app=None).dispatch(Request(bad_scope), call_next))
            >>> (bad_resp.status_code, b"Unsupported protocol version: bad" in bad_resp.body)
            (400, True)
        """
        # NOTE(review): request.url.path may include the ASGI root_path when the app is
        # served under a prefix, in which case the exact matches in _is_mcp_endpoint
        # would not fire — confirm this is the intended matching basis.
        path = request.url.path

        # Skip validation for non-MCP endpoints (admin UI, health, openapi, etc.)
        if not self._is_mcp_endpoint(path):
            return await call_next(request)

        # Header lookup is case-insensitive, so the lowercase key matches any casing.
        protocol_version = request.headers.get("mcp-protocol-version")

        # Only an absent header falls back to the default; an empty value is
        # treated as "provided" and is rejected below as unsupported.
        if protocol_version is None:
            protocol_version = DEFAULT_PROTOCOL_VERSION
            # Lazy %-args: this path runs on every header-less MCP request, so avoid
            # formatting the message unless DEBUG logging is actually enabled.
            logger.debug("No MCP-Protocol-Version header, assuming %s", DEFAULT_PROTOCOL_VERSION)

        if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS:
            supported = ", ".join(SUPPORTED_PROTOCOL_VERSIONS)
            logger.warning("Unsupported protocol version: %s", protocol_version)
            return ORJSONResponse(
                status_code=400,
                content={"error": "Bad Request", "message": f"Unsupported protocol version: {protocol_version}. Supported versions: {supported}"},
            )

        # Expose the effective (validated or defaulted) version to downstream handlers.
        request.state.mcp_protocol_version = protocol_version

        return await call_next(request)

    def _is_mcp_endpoint(self, path: str) -> bool:
        """
        Check if path is an MCP protocol endpoint that requires version validation.

        MCP protocol endpoints include:
        - / and /rpc (exact match; /rpc is the main JSON-RPC endpoint)
        - any path under /servers/ ending in /sse (Server-Sent Events transport)
        - any path under /servers/ ending in /ws (WebSocket transport)

        Non-MCP endpoints (admin, health, openapi, etc.) are excluded.

        Args:
            path: The request URL path to check

        Returns:
            bool: True if path is an MCP protocol endpoint, False otherwise

        Examples:
            >>> from mcpgateway.middleware.protocol_version import MCPProtocolVersionMiddleware
            >>> m = MCPProtocolVersionMiddleware(app=None)
            >>> [m._is_mcp_endpoint(p) for p in ("/", "/rpc", "/servers/abc/sse", "/servers/abc/ws", "/health", "/servers/abc")]
            [True, True, True, True, False, False]
        """
        # Exact matches for the root path and the main RPC endpoint
        if path in ("/rpc", "/"):
            return True

        # Per-server transport endpoints; str.endswith accepts a tuple of suffixes.
        return path.startswith("/servers/") and path.endswith(("/sse", "/ws"))