Coverage for mcpgateway / middleware / protocol_version.py: 100%

31 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2"""Middleware to validate MCP-Protocol-Version header for MCP HTTP endpoints.""" 

3 

4# Standard 

5import logging 

6from typing import Callable 

7 

8# Third-Party 

9from fastapi import Request, Response 

10from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS as MCP_SUPPORTED_PROTOCOL_VERSIONS 

11from mcp.types import LATEST_PROTOCOL_VERSION 

12from starlette.middleware.base import BaseHTTPMiddleware 

13 

14# First-Party 

15from mcpgateway.utils.orjson_response import ORJSONResponse 

16 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Supported MCP protocol versions, copied from the MCP SDK so that they stay
# aligned with schema.ts.
SUPPORTED_PROTOCOL_VERSIONS = [*MCP_SUPPORTED_PROTOCOL_VERSIONS]

# Version assumed for this implementation by default: the SDK's latest protocol.
# NOTE(review): the MCP Streamable HTTP transport spec appears to recommend
# assuming 2025-03-26 when the MCP-Protocol-Version header is absent; confirm
# that defaulting to the latest version is intentional.
DEFAULT_PROTOCOL_VERSION = LATEST_PROTOCOL_VERSION

23 

24 

class MCPProtocolVersionMiddleware(BaseHTTPMiddleware):
    """Validate the MCP-Protocol-Version header on MCP protocol HTTP endpoints.

    Requests whose path is not an MCP protocol endpoint are passed through
    untouched. On MCP endpoints a missing header falls back to
    ``DEFAULT_PROTOCOL_VERSION``, a value outside ``SUPPORTED_PROTOCOL_VERSIONS``
    is rejected with HTTP 400, and an accepted version is stored on
    ``request.state.mcp_protocol_version`` before the request continues.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Validate MCP-Protocol-Version header for MCP protocol endpoints.

        Args:
            request: The incoming HTTP request
            call_next: The next middleware or route handler in the chain

        Returns:
            Response: Either a 400 error for invalid protocol versions or the result of call_next

        Examples:
            Non-MCP endpoints are bypassed:

            >>> import asyncio
            >>> from starlette.requests import Request
            >>> from starlette.responses import Response
            >>> from mcpgateway.middleware.protocol_version import MCPProtocolVersionMiddleware
            >>> async def call_next(req): return Response("ok", media_type="text/plain")
            >>> scope = {
            ...     "type": "http",
            ...     "asgi": {"version": "3.0"},
            ...     "method": "GET",
            ...     "path": "/health",
            ...     "raw_path": b"/health",
            ...     "query_string": b"",
            ...     "headers": [],
            ...     "client": ("testclient", 50000),
            ...     "server": ("testserver", 80),
            ...     "scheme": "http",
            ... }
            >>> resp = asyncio.run(MCPProtocolVersionMiddleware(app=None).dispatch(Request(scope), call_next))
            >>> resp.status_code
            200

            MCP endpoints default the version when the header is missing:

            >>> from mcpgateway.middleware.protocol_version import DEFAULT_PROTOCOL_VERSION
            >>> scope_rpc = {
            ...     "type": "http",
            ...     "asgi": {"version": "3.0"},
            ...     "method": "POST",
            ...     "path": "/rpc",
            ...     "raw_path": b"/rpc",
            ...     "query_string": b"",
            ...     "headers": [],
            ...     "client": ("testclient", 50000),
            ...     "server": ("testserver", 80),
            ...     "scheme": "http",
            ... }
            >>> req = Request(scope_rpc)
            >>> _ = asyncio.run(MCPProtocolVersionMiddleware(app=None).dispatch(req, call_next))
            >>> req.state.mcp_protocol_version == DEFAULT_PROTOCOL_VERSION
            True

            Unsupported versions return `400`:

            >>> bad_scope = {
            ...     "type": "http",
            ...     "asgi": {"version": "3.0"},
            ...     "method": "POST",
            ...     "path": "/rpc",
            ...     "raw_path": b"/rpc",
            ...     "query_string": b"",
            ...     "headers": [(b"mcp-protocol-version", b"bad")],
            ...     "client": ("testclient", 50000),
            ...     "server": ("testserver", 80),
            ...     "scheme": "http",
            ... }
            >>> bad_resp = asyncio.run(MCPProtocolVersionMiddleware(app=None).dispatch(Request(bad_scope), call_next))
            >>> (bad_resp.status_code, b"Unsupported protocol version: bad" in bad_resp.body)
            (400, True)
        """
        # Skip validation for non-MCP endpoints (admin UI, health, openapi, etc.)
        if not self._is_mcp_endpoint(request.url.path):
            return await call_next(request)

        # Get the protocol version from headers (case-insensitive)
        protocol_version = request.headers.get("mcp-protocol-version")

        # If no protocol version provided, assume default version (backwards compatibility)
        if protocol_version is None:
            protocol_version = DEFAULT_PROTOCOL_VERSION
            # Lazy %-style arguments: the message is only formatted when DEBUG
            # is actually enabled, which matters on this per-request path.
            logger.debug("No MCP-Protocol-Version header, assuming %s", DEFAULT_PROTOCOL_VERSION)

        # Reject anything that is not a known protocol version
        if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS:
            supported = ", ".join(SUPPORTED_PROTOCOL_VERSIONS)
            logger.warning("Unsupported protocol version: %s", protocol_version)
            return ORJSONResponse(
                status_code=400,
                content={"error": "Bad Request", "message": f"Unsupported protocol version: {protocol_version}. Supported versions: {supported}"},
            )

        # Store validated version in request state for use by handlers
        request.state.mcp_protocol_version = protocol_version

        return await call_next(request)

    def _is_mcp_endpoint(self, path: str) -> bool:
        """
        Check if path is an MCP protocol endpoint that requires version validation.

        MCP protocol endpoints include:
        - /mcp and /mcp/ (Streamable HTTP transport)
        - /rpc and /rpc/ (gateway JSON-RPC endpoint)
        - /servers/*/sse (SSE transport)
        - /servers/*/ws (WebSocket transport)

        Non-MCP endpoints (admin, health, openapi, etc.) are excluded.

        The check is purely textual: the first four paths must match exactly,
        and the server paths are recognised by their "/servers/" prefix combined
        with a "/sse" or "/ws" suffix.

        Args:
            path: The request URL path to check

        Returns:
            bool: True if path is an MCP protocol endpoint, False otherwise
        """
        # Exact match for the main MCP/RPC endpoints
        if path in ("/mcp", "/mcp/", "/rpc", "/rpc/"):
            return True

        # Prefix + suffix match for per-server SSE/WebSocket endpoints
        return path.startswith("/servers/") and path.endswith(("/sse", "/ws"))