Coverage for mcpgateway / middleware / protocol_version.py: 100%

29 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Middleware to validate MCP-Protocol-Version header per MCP spec 2025-06-18.""" 

3 

4# Standard 

5import logging 

6from typing import Callable 

7 

8# Third-Party 

9from fastapi import Request, Response 

10from starlette.middleware.base import BaseHTTPMiddleware 

11 

12# First-Party 

13from mcpgateway.utils.orjson_response import ORJSONResponse 

14 

logger = logging.getLogger(__name__)

# MCP protocol revisions accepted by this gateway (per the MCP specification),
# listed from oldest to newest.
SUPPORTED_PROTOCOL_VERSIONS = [
    "2024-11-05",
    "2025-03-26",
    "2025-06-18",
]

# Revision assumed when a client sends no MCP-Protocol-Version header; the spec
# names 2025-03-26 as the backwards-compatible fallback.
DEFAULT_PROTOCOL_VERSION = "2025-03-26"

20 

21 

class MCPProtocolVersionMiddleware(BaseHTTPMiddleware):
    """
    Validates MCP-Protocol-Version header per MCP spec 2025-06-18.

    Per the MCP specification (basic/transports.mdx):
    - Clients MUST include MCP-Protocol-Version header on all HTTP requests
    - If not provided, server SHOULD assume 2025-03-26 for backwards compatibility
    - If unsupported version provided, server MUST respond with 400 Bad Request

    Only requests whose path is recognised by ``_is_mcp_endpoint`` are checked;
    all other requests are passed straight to the next handler. For checked
    requests, the accepted (or defaulted) version is stored on
    ``request.state.mcp_protocol_version``.

    NOTE(review): Starlette's ``BaseHTTPMiddleware`` only intercepts ``http``
    scopes, so WebSocket handshakes are presumably never seen here even though
    ``/servers/*/ws`` paths are matched below — confirm whether WebSocket
    connections need separate validation.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Validate MCP-Protocol-Version header for MCP protocol endpoints.

        Args:
            request: The incoming HTTP request
            call_next: The next middleware or route handler in the chain

        Returns:
            Response: Either a 400 error for invalid protocol versions or the result of call_next

        Examples:
            Non-MCP endpoints are bypassed:

            >>> import asyncio
            >>> from starlette.requests import Request
            >>> from starlette.responses import Response
            >>> from mcpgateway.middleware.protocol_version import MCPProtocolVersionMiddleware
            >>> async def call_next(req): return Response("ok", media_type="text/plain")
            >>> scope = {
            ...     "type": "http",
            ...     "asgi": {"version": "3.0"},
            ...     "method": "GET",
            ...     "path": "/health",
            ...     "raw_path": b"/health",
            ...     "query_string": b"",
            ...     "headers": [],
            ...     "client": ("testclient", 50000),
            ...     "server": ("testserver", 80),
            ...     "scheme": "http",
            ... }
            >>> resp = asyncio.run(MCPProtocolVersionMiddleware(app=None).dispatch(Request(scope), call_next))
            >>> resp.status_code
            200

            MCP endpoints default the version when the header is missing:

            >>> from mcpgateway.middleware.protocol_version import DEFAULT_PROTOCOL_VERSION
            >>> scope_rpc = {
            ...     "type": "http",
            ...     "asgi": {"version": "3.0"},
            ...     "method": "POST",
            ...     "path": "/rpc",
            ...     "raw_path": b"/rpc",
            ...     "query_string": b"",
            ...     "headers": [],
            ...     "client": ("testclient", 50000),
            ...     "server": ("testserver", 80),
            ...     "scheme": "http",
            ... }
            >>> req = Request(scope_rpc)
            >>> _ = asyncio.run(MCPProtocolVersionMiddleware(app=None).dispatch(req, call_next))
            >>> req.state.mcp_protocol_version == DEFAULT_PROTOCOL_VERSION
            True

            Unsupported versions return `400`:

            >>> bad_scope = {
            ...     "type": "http",
            ...     "asgi": {"version": "3.0"},
            ...     "method": "POST",
            ...     "path": "/rpc",
            ...     "raw_path": b"/rpc",
            ...     "query_string": b"",
            ...     "headers": [(b"mcp-protocol-version", b"bad")],
            ...     "client": ("testclient", 50000),
            ...     "server": ("testserver", 80),
            ...     "scheme": "http",
            ... }
            >>> bad_resp = asyncio.run(MCPProtocolVersionMiddleware(app=None).dispatch(Request(bad_scope), call_next))
            >>> (bad_resp.status_code, b"Unsupported protocol version: bad" in bad_resp.body)
            (400, True)
        """
        path = request.url.path

        # Skip validation for non-MCP endpoints (admin UI, health, openapi, etc.)
        if not self._is_mcp_endpoint(path):
            return await call_next(request)

        # Get the protocol version from headers (case-insensitive)
        protocol_version = request.headers.get("mcp-protocol-version")

        # Only a *missing* header falls back to the default; a present-but-empty
        # header is left as "" and is rejected as unsupported below.
        if protocol_version is None:
            protocol_version = DEFAULT_PROTOCOL_VERSION
            logger.debug("No MCP-Protocol-Version header, assuming %s", DEFAULT_PROTOCOL_VERSION)

        # Validate protocol version
        if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS:
            supported = ", ".join(SUPPORTED_PROTOCOL_VERSIONS)
            # The value is client-controlled: %r escapes any control characters so
            # they cannot garble or forge log lines.
            logger.warning("Unsupported protocol version: %r", protocol_version)
            return ORJSONResponse(
                status_code=400,
                content={"error": "Bad Request", "message": f"Unsupported protocol version: {protocol_version}. Supported versions: {supported}"},
            )

        # Store validated version in request state for use by handlers
        request.state.mcp_protocol_version = protocol_version

        return await call_next(request)

    def _is_mcp_endpoint(self, path: str) -> bool:
        """
        Check if path is an MCP protocol endpoint that requires version validation.

        MCP protocol endpoints include:
        - / and /rpc (JSON-RPC endpoints)
        - /servers/*/sse (Server-Sent Events transport)
        - /servers/*/ws (WebSocket transport)

        Trailing slashes are ignored, so ``/rpc/`` or ``/servers/x/sse/`` are
        validated like their slash-less forms instead of bypassing the check.
        Non-MCP endpoints (admin, health, openapi, etc.) are excluded.

        Args:
            path: The request URL path to check

        Returns:
            bool: True if path is an MCP protocol endpoint, False otherwise

        Examples:
            >>> from mcpgateway.middleware.protocol_version import MCPProtocolVersionMiddleware
            >>> mw = MCPProtocolVersionMiddleware(app=None)
            >>> [mw._is_mcp_endpoint(p) for p in ("/", "/rpc", "/rpc/", "/servers/abc/sse", "/servers/abc/ws/")]
            [True, True, True, True, True]
            >>> [mw._is_mcp_endpoint(p) for p in ("/health", "/admin", "/servers/abc/tools", "/rpcx")]
            [False, False, False, False]
        """
        # Drop trailing slashes so "/rpc/" cannot slip past validation; an all-slash
        # path collapses to "" and is mapped back to the root "/".
        normalized = path.rstrip("/") or "/"

        # Exact match for the JSON-RPC endpoints
        if normalized in ("/rpc", "/"):
            return True

        # Per-server SSE / WebSocket transport endpoints
        return normalized.startswith("/servers/") and normalized.endswith(("/sse", "/ws"))