Coverage for mcpgateway / transports / rust_mcp_runtime_proxy.py: 100%

123 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/transports/rust_mcp_runtime_proxy.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5 

6Experimental MCP transport proxy for the Rust runtime edge. 

7 

8This module keeps Python auth/path-rewrite middleware in front of MCP traffic 

9while proxying MCP transport requests to the optional Rust runtime sidecar. 

10""" 

11 

12# Future 

13from __future__ import annotations 

14 

15# Standard 

16import asyncio 

17import base64 

18import logging 

19import re 

20from urllib.parse import urlsplit, urlunsplit 

21 

22# Third-Party 

23import httpx 

24import orjson 

25from starlette.types import Receive, Scope, Send 

26 

27# First-Party 

28from mcpgateway.config import settings 

29from mcpgateway.services.http_client_service import get_http_client, get_http_limits 

30from mcpgateway.transports.streamablehttp_transport import get_streamable_http_auth_context 

31from mcpgateway.utils.orjson_response import ORJSONResponse 

32 

logger = logging.getLogger(__name__)

# Matches a mounted path ending in /servers/<id>/mcp (optional trailing slash);
# the id is restricted to hex digits and dashes.
_SERVER_ID_RE = re.compile(r"/servers/(?P<server_id>[a-fA-F0-9\-]+)/mcp/?$")

# Headers this proxy itself attaches to requests sent to the Rust sidecar.
_CONTEXTFORGE_SERVER_ID_HEADER = "x-contextforge-server-id"
_CONTEXTFORGE_AUTH_CONTEXT_HEADER = "x-contextforge-auth-context"
_CONTEXTFORGE_AFFINITY_FORWARDED_HEADER = "x-contextforge-affinity-forwarded"

# Generic detail returned to clients on proxy failure; specifics go to the log only.
_CLIENT_ERROR_DETAIL = "See server logs"

# Per-connection request headers that must not be copied onto the upstream request.
_REQUEST_HOP_BY_HOP_HEADERS = frozenset(("host", "content-length", "connection", "transfer-encoding", "keep-alive"))

# Client-supplied forwarding chain headers are dropped rather than relayed.
_FORWARDED_CHAIN_HEADERS = frozenset(("forwarded", "x-forwarded-for", "x-forwarded-host", "x-forwarded-port", "x-forwarded-proto"))

# Internal coordination headers: stripped from incoming requests so a client cannot
# spoof them; the proxy re-adds its own copies where appropriate.
_INTERNAL_ONLY_REQUEST_HEADERS = frozenset(
    (
        "x-forwarded-internally",
        "x-original-worker",
        "x-mcp-session-id",
        "x-contextforge-mcp-runtime",
        _CONTEXTFORGE_SERVER_ID_HEADER,
        _CONTEXTFORGE_AUTH_CONTEXT_HEADER,
        _CONTEXTFORGE_AFFINITY_FORWARDED_HEADER,
    )
)

# Per-connection response headers that must not be relayed back to the client.
_RESPONSE_HOP_BY_HOP_HEADERS = frozenset(("connection", "transfer-encoding", "keep-alive"))

54 

55 

class RustMCPRuntimeProxy:
    """Proxy MCP transport traffic to the experimental Rust runtime."""

    def __init__(self, python_fallback_app) -> None:
        """Initialize the proxy with the existing Python MCP transport fallback.

        Args:
            python_fallback_app: Python MCP transport app used when Rust cannot handle
                the request.
        """
        self.python_fallback_app = python_fallback_app
        # Lazily created client for the Unix-domain-socket transport (see _get_runtime_client).
        self._uds_client: httpx.AsyncClient | None = None
        self._uds_client_lock = asyncio.Lock()

    async def handle_streamable_http(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Route MCP transport requests to the Rust runtime and preserve Python fallback for others.

        Non-HTTP scopes, and HTTP methods other than GET/POST/DELETE, are handed to
        the Python fallback app unchanged. Everything else is forwarded to the Rust
        sidecar: the POST body is streamed through, and the upstream status, headers
        (minus hop-by-hop ones) and body chunks are relayed to the client.

        If the upstream request fails before any response has been started, the
        client receives a 502 JSON-RPC error. If it fails after the response has
        started, another ``http.response.start`` cannot be sent (ASGI forbids it),
        so the already-started body is closed instead.

        Args:
            scope: Incoming ASGI scope.
            receive: ASGI receive callable.
            send: ASGI send callable.
        """
        if scope.get("type") != "http":
            logger.debug("Rust MCP runtime deferring to Python fallback: scope type %r is not 'http'", scope.get("type"))
            await self.python_fallback_app(scope, receive, send)
            return

        method = str(scope.get("method", "GET")).upper()
        if method not in {"GET", "POST", "DELETE"}:
            logger.debug("Rust MCP runtime deferring to Python fallback: HTTP method %r is not supported", method)
            await self.python_fallback_app(scope, receive, send)
            return

        target_url = _build_runtime_mcp_url(scope)
        headers = _build_forward_headers(scope)
        timeout = httpx.Timeout(settings.experimental_rust_mcp_runtime_timeout_seconds)

        # Track how far the client-facing response has progressed. The error path
        # below must not emit a second http.response.start, nor body messages after
        # the final (more_body=False) message.
        response_started = False
        response_finished = False
        try:
            client = await self._get_runtime_client()
            async with client.stream(
                method,
                target_url,
                # Only POST carries a request body; it is streamed, not buffered.
                content=_stream_request_body(receive) if method == "POST" else b"",
                headers=headers,
                timeout=timeout,
                follow_redirects=False,
            ) as response:
                await send(
                    {
                        "type": "http.response.start",
                        "status": response.status_code,
                        "headers": [(name, value) for name, value in response.headers.raw if name.decode("latin-1").lower() not in _RESPONSE_HOP_BY_HOP_HEADERS],
                    }
                )
                response_started = True
                async for chunk in response.aiter_bytes():
                    if chunk:
                        await send({"type": "http.response.body", "body": chunk, "more_body": True})
                await send({"type": "http.response.body", "body": b"", "more_body": False})
                response_finished = True
        except httpx.HTTPError as exc:
            if response_started:
                # Headers (and possibly part of the body) already went out, e.g. an
                # upstream read timeout mid-stream. Close the body instead of trying
                # to start a second response, which the ASGI server would reject.
                logger.error("Experimental Rust MCP runtime stream failed after response start: %s", exc)
                if not response_finished:
                    await send({"type": "http.response.body", "body": b"", "more_body": False})
                return
            logger.error("Experimental Rust MCP runtime request failed: %s", exc)
            error_response = ORJSONResponse(
                status_code=502,
                content={
                    "jsonrpc": "2.0",
                    "id": None,
                    "error": {
                        "code": -32000,
                        "message": "Experimental Rust MCP runtime unavailable",
                        "data": _CLIENT_ERROR_DETAIL,
                    },
                },
            )
            await error_response(scope, receive, send)

    async def _get_runtime_client(self) -> httpx.AsyncClient:
        """Return the client used for Python -> Rust runtime proxying.

        When ``settings.experimental_rust_mcp_runtime_uds`` is unset, the shared
        client from ``get_http_client()`` is used. Otherwise a dedicated client
        bound to that Unix domain socket is created once and then reused; the
        lock ensures that concurrent first calls create only one client.

        Returns:
            An async HTTP client configured for either UDS or loopback HTTP.
        """
        uds_path = settings.experimental_rust_mcp_runtime_uds
        if not uds_path:
            return await get_http_client()

        # Fast path: skip the lock once the client exists.
        if self._uds_client is not None:
            return self._uds_client

        async with self._uds_client_lock:
            # Re-check under the lock: another task may have created it while we waited.
            if self._uds_client is None:
                self._uds_client = httpx.AsyncClient(
                    transport=httpx.AsyncHTTPTransport(uds=uds_path),
                    limits=get_http_limits(),
                    timeout=httpx.Timeout(settings.experimental_rust_mcp_runtime_timeout_seconds),
                    follow_redirects=False,
                )
            return self._uds_client

153 

154 

async def _stream_request_body(receive: Receive):
    """Relay ASGI request body chunks one at a time, without buffering the whole body.

    Messages other than ``http.request`` and ``http.disconnect`` are ignored.
    Empty chunks are not yielded. The generator stops after the first
    ``http.request`` message without a truthy ``more_body``, or as soon as
    ``http.disconnect`` arrives.

    Args:
        receive: ASGI receive callable for the current request.

    Yields:
        Non-empty raw request body chunks, in arrival order.
    """
    finished = False
    while not finished:
        event = await receive()
        event_type = event["type"]
        if event_type == "http.disconnect":
            break
        if event_type != "http.request":
            continue
        payload = event.get("body", b"")
        finished = not event.get("more_body", False)
        if payload:
            yield payload

175 

176 

def _extract_server_id_from_scope(scope: Scope) -> str | None:
    """Return the server id for requests mounted under ``/servers/<id>/mcp``.

    ``modified_path`` is preferred over ``path`` when present and truthy;
    whichever is used is matched against ``_SERVER_ID_RE``.

    Args:
        scope: Incoming ASGI scope.

    Returns:
        The captured server id, or ``None`` when the path is not server-scoped.
    """
    path = scope.get("modified_path") or scope.get("path") or ""
    found = _SERVER_ID_RE.search(str(path))
    if found is None:
        return None
    return found.group("server_id")

189 

190 

def _build_runtime_mcp_url(scope: Scope) -> str:
    """Compose the Rust sidecar ``/mcp/`` URL for this request.

    The configured runtime URL supplies the scheme, host and base path. The
    path always ends in ``/mcp/``; ``/mcp`` is appended only if the base path
    does not already end with it. The configured query (if any) comes first,
    followed by the incoming request's query string, joined with ``&``. Any
    fragment is dropped.

    Args:
        scope: Incoming ASGI scope.

    Returns:
        Absolute URL for the Rust sidecar MCP endpoint.
    """
    parts = urlsplit(settings.experimental_rust_mcp_runtime_url)

    raw_query = scope.get("query_string", b"")
    if isinstance(raw_query, (bytes, bytearray)):
        incoming_query = raw_query.decode("latin-1")
    else:
        incoming_query = str(raw_query or "")

    prefix = parts.path.rstrip("/")
    if prefix.endswith("/mcp"):
        path = prefix + "/"
    elif prefix:
        path = prefix + "/mcp/"
    else:
        path = "/mcp/"

    query = "&".join(filter(None, (parts.query, incoming_query)))
    return urlunsplit((parts.scheme, parts.netloc, path, query, ""))

212 

213 

def _build_forward_headers(scope: Scope) -> list[tuple[str, str]]:
    """Forward request headers needed by the Rust runtime while stripping internal-only headers.

    Incoming headers are lower-cased and latin-1 decoded. Malformed entries
    (not a 2-item tuple/list of bytes) are skipped. Hop-by-hop, forwarding-chain
    and internal-only headers are dropped. The proxy then appends its own
    headers in this order:

    * the server id, when the path is ``/servers/<id>/mcp``;
    * the serialized auth context, when one is available;
    * the affinity marker, only when the peer address is ``127.0.0.1`` or
      ``::1`` and the incoming ``x-forwarded-internally`` header equals
      ``"true"``.

    Args:
        scope: Incoming ASGI scope.

    Returns:
        Header tuples safe to forward to the Rust sidecar.
    """
    headers: list[tuple[str, str]] = []
    # Captured in the same pass as the filtering below (last occurrence wins), because the
    # header itself is internal-only and is not forwarded.
    forwarded_internally: str | None = None
    for item in scope.get("headers") or []:
        if not isinstance(item, (tuple, list)) or len(item) != 2:
            continue
        name, value = item
        if not isinstance(name, (bytes, bytearray)) or not isinstance(value, (bytes, bytearray)):
            continue
        header_name = name.decode("latin-1").lower()
        header_value = value.decode("latin-1")
        if header_name == "x-forwarded-internally":
            forwarded_internally = header_value
        if header_name in _REQUEST_HOP_BY_HOP_HEADERS or header_name in _FORWARDED_CHAIN_HEADERS or header_name in _INTERNAL_ONLY_REQUEST_HEADERS:
            continue
        headers.append((header_name, header_value))

    server_id = _extract_server_id_from_scope(scope)
    if server_id:
        headers.append((_CONTEXTFORGE_SERVER_ID_HEADER, server_id))

    auth_context = _build_forwarded_auth_context_header()
    if auth_context is not None:
        headers.append((_CONTEXTFORGE_AUTH_CONTEXT_HEADER, auth_context))

    # The affinity marker is honored only for loopback peers, so a remote client cannot
    # claim to be an internally forwarded request.
    client = scope.get("client")
    client_host = client[0] if isinstance(client, (tuple, list)) and client else None
    from_loopback = client_host in ("127.0.0.1", "::1")
    if from_loopback and forwarded_internally == "true":
        headers.append((_CONTEXTFORGE_AFFINITY_FORWARDED_HEADER, "rust"))

    return headers

257 

258 

def _build_forwarded_auth_context_header() -> str | None:
    """Encode the current MCP auth context for the trusted internal Python dispatcher.

    The value from ``get_streamable_http_auth_context()`` is serialized with
    orjson, base64url-encoded, and stripped of trailing ``=`` padding.

    Returns:
        The encoded auth context, or ``None`` when no (or an empty) MCP auth
        context is available.
    """
    context = get_streamable_http_auth_context()
    if context:
        payload = orjson.dumps(context)
        return base64.urlsafe_b64encode(payload).decode("ascii").rstrip("=")
    return None