Coverage for mcpgateway / transports / rust_mcp_runtime_proxy.py: 100%
123 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/transports/rust_mcp_runtime_proxy.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
6Experimental MCP transport proxy for the Rust runtime edge.
8This module keeps Python auth/path-rewrite middleware in front of MCP traffic
9while proxying MCP transport requests to the optional Rust runtime sidecar.
10"""
12# Future
13from __future__ import annotations
15# Standard
16import asyncio
17import base64
18import logging
19import re
20from urllib.parse import urlsplit, urlunsplit
22# Third-Party
23import httpx
24import orjson
25from starlette.types import Receive, Scope, Send
27# First-Party
28from mcpgateway.config import settings
29from mcpgateway.services.http_client_service import get_http_client, get_http_limits
30from mcpgateway.transports.streamablehttp_transport import get_streamable_http_auth_context
31from mcpgateway.utils.orjson_response import ORJSONResponse
# Module-level logger for proxy diagnostics.
logger = logging.getLogger(__name__)

# Matches a trailing "/servers/<id>/mcp" (optional final slash) and captures the id.
_SERVER_ID_RE = re.compile(r"/servers/(?P<server_id>[a-fA-F0-9\-]+)/mcp/?$")

# Header names this proxy sets itself when forwarding to the Rust runtime.
_CONTEXTFORGE_SERVER_ID_HEADER = "x-contextforge-server-id"
_CONTEXTFORGE_AUTH_CONTEXT_HEADER = "x-contextforge-auth-context"
_CONTEXTFORGE_AFFINITY_FORWARDED_HEADER = "x-contextforge-affinity-forwarded"

# Generic detail string returned to clients instead of the underlying error.
_CLIENT_ERROR_DETAIL = "See server logs"

# Request headers that are never copied onto the upstream request.
_REQUEST_HOP_BY_HOP_HEADERS = frozenset(
    (
        "host",
        "content-length",
        "connection",
        "transfer-encoding",
        "keep-alive",
    )
)
_FORWARDED_CHAIN_HEADERS = frozenset(
    (
        "forwarded",
        "x-forwarded-for",
        "x-forwarded-host",
        "x-forwarded-port",
        "x-forwarded-proto",
    )
)
# Client-supplied values for these are dropped; the proxy re-adds the
# ContextForge ones itself when appropriate.
_INTERNAL_ONLY_REQUEST_HEADERS = frozenset(
    (
        "x-forwarded-internally",
        "x-original-worker",
        "x-mcp-session-id",
        "x-contextforge-mcp-runtime",
        _CONTEXTFORGE_SERVER_ID_HEADER,
        _CONTEXTFORGE_AUTH_CONTEXT_HEADER,
        _CONTEXTFORGE_AFFINITY_FORWARDED_HEADER,
    )
)

# Response headers that are not relayed back to the client.
_RESPONSE_HOP_BY_HOP_HEADERS = frozenset(("connection", "transfer-encoding", "keep-alive"))
class RustMCPRuntimeProxy:
    """Proxy MCP transport traffic to the experimental Rust runtime."""

    def __init__(self, python_fallback_app) -> None:
        """Initialize the proxy with the existing Python MCP transport fallback.

        Args:
            python_fallback_app: Python MCP transport app used when Rust cannot handle
                the request.
        """
        self.python_fallback_app = python_fallback_app
        # Created on first use and only when a UDS path is configured.
        self._uds_client: httpx.AsyncClient | None = None
        self._uds_client_lock = asyncio.Lock()

    async def handle_streamable_http(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Route MCP transport requests to the Rust runtime and preserve Python fallback for others.

        Non-HTTP scopes and HTTP methods other than GET/POST/DELETE are handed to
        the Python fallback app unchanged. Everything else is streamed to the Rust
        runtime and the upstream response is relayed chunk by chunk.

        If the upstream request fails before any response has been started, a 502
        JSON-RPC error is returned. If it fails after the response start message
        was sent, the failure is only logged: a second ``http.response.start``
        must not be emitted for the same request, so the response is left
        unfinished for the ASGI server to deal with.

        Args:
            scope: Incoming ASGI scope.
            receive: ASGI receive callable.
            send: ASGI send callable.
        """
        if scope.get("type") != "http":
            logger.debug("Rust MCP runtime deferring to Python fallback: scope type %r is not 'http'", scope.get("type"))
            await self.python_fallback_app(scope, receive, send)
            return

        method = str(scope.get("method", "GET")).upper()
        if method not in {"GET", "POST", "DELETE"}:
            logger.debug("Rust MCP runtime deferring to Python fallback: HTTP method %r is not supported", method)
            await self.python_fallback_app(scope, receive, send)
            return

        target_url = _build_runtime_mcp_url(scope)
        headers = _build_forward_headers(scope)
        timeout = httpx.Timeout(settings.experimental_rust_mcp_runtime_timeout_seconds)

        # Set as soon as we attempt to send "http.response.start"; from then on
        # the error path may not start another response.
        response_started = False
        try:
            client = await self._get_runtime_client()
            async with client.stream(
                method,
                target_url,
                # Only POST carries a body; it is streamed through without buffering.
                content=_stream_request_body(receive) if method == "POST" else b"",
                headers=headers,
                timeout=timeout,
                follow_redirects=False,
            ) as response:
                response_started = True
                await send(
                    {
                        "type": "http.response.start",
                        "status": response.status_code,
                        "headers": [(name, value) for name, value in response.headers.raw if name.decode("latin-1").lower() not in _RESPONSE_HOP_BY_HOP_HEADERS],
                    }
                )
                async for chunk in response.aiter_bytes():
                    if chunk:
                        await send({"type": "http.response.body", "body": chunk, "more_body": True})
                await send({"type": "http.response.body", "body": b"", "more_body": False})
        except httpx.HTTPError as exc:
            logger.error("Experimental Rust MCP runtime request failed: %s", exc)
            if response_started:
                # Status and headers are already committed, so a 502 can no
                # longer be delivered. Stop here instead of sending a second
                # start message on the same response.
                logger.warning("Experimental Rust MCP runtime stream aborted after response start; response left incomplete")
                return
            error_response = ORJSONResponse(
                status_code=502,
                content={
                    "jsonrpc": "2.0",
                    "id": None,
                    "error": {
                        "code": -32000,
                        "message": "Experimental Rust MCP runtime unavailable",
                        "data": _CLIENT_ERROR_DETAIL,
                    },
                },
            )
            await error_response(scope, receive, send)
            return

    async def _get_runtime_client(self) -> httpx.AsyncClient:
        """Return the client used for Python -> Rust runtime proxying.

        Returns:
            An async HTTP client configured for either UDS or loopback HTTP.
        """
        uds_path = settings.experimental_rust_mcp_runtime_uds
        if not uds_path:
            # No UDS configured: use the shared HTTP client.
            return await get_http_client()

        if self._uds_client is not None:
            return self._uds_client

        async with self._uds_client_lock:
            # Re-check under the lock so concurrent first calls build one client.
            if self._uds_client is None:
                self._uds_client = httpx.AsyncClient(
                    transport=httpx.AsyncHTTPTransport(uds=uds_path),
                    limits=get_http_limits(),
                    timeout=httpx.Timeout(settings.experimental_rust_mcp_runtime_timeout_seconds),
                    follow_redirects=False,
                )
        return self._uds_client
async def _stream_request_body(receive: Receive):
    """Relay the ASGI request body piece by piece instead of collecting it first.

    Args:
        receive: ASGI receive callable for the current request.

    Yields:
        Each non-empty body chunk, in arrival order.
    """
    expecting_more = True
    while expecting_more:
        event = await receive()
        event_type = event["type"]
        if event_type == "http.disconnect":
            # Client went away; stop producing body data.
            break
        if event_type == "http.request":
            piece = event.get("body", b"")
            if piece:
                yield piece
            # The last request message has more_body absent or falsy.
            expecting_more = bool(event.get("more_body", False))
        # Any other message type is ignored and the loop keeps waiting.
def _extract_server_id_from_scope(scope: Scope) -> str | None:
    """Pull the server id out of a path that ends in /servers/<id>/mcp.

    The scope's ``modified_path`` is preferred; ``path`` is used when that is
    missing or empty.

    Args:
        scope: Incoming ASGI scope.

    Returns:
        The captured server id, or ``None`` if the path is not server-scoped.
    """
    candidate_path = scope.get("modified_path") or scope.get("path") or ""
    found = _SERVER_ID_RE.search(str(candidate_path))
    if found is None:
        return None
    return found.group("server_id")
def _build_runtime_mcp_url(scope: Scope) -> str:
    """Compose the Rust runtime /mcp/ URL for this request, keeping its query string.

    Args:
        scope: Incoming ASGI scope.

    Returns:
        Absolute URL for the Rust sidecar MCP endpoint.
    """
    runtime = urlsplit(settings.experimental_rust_mcp_runtime_url)

    raw_query = scope.get("query_string", b"")
    if isinstance(raw_query, (bytes, bytearray)):
        request_query = raw_query.decode("latin-1")
    else:
        request_query = str(raw_query or "")

    # Normalise the configured path so the result always ends in exactly "/mcp/".
    # An empty prefix falls into the second branch and yields "/mcp/".
    prefix = runtime.path.rstrip("/")
    if prefix.endswith("/mcp"):
        path = prefix + "/"
    else:
        path = prefix + "/mcp/"

    # Configured query parameters come first, then the request's own.
    combined_query = "&".join(filter(None, (runtime.query, request_query)))
    return urlunsplit((runtime.scheme, runtime.netloc, path, combined_query, ""))
def _build_forward_headers(scope: Scope) -> list[tuple[str, str]]:
    """Forward request headers needed by the Rust runtime while stripping internal-only headers.

    Client headers are copied with lower-cased names, skipping hop-by-hop,
    forwarded-chain and internal-only headers. The proxy then appends its own
    headers in this order: server id (when the path is server-scoped), auth
    context (when one is available), and the affinity marker (when the request
    came from loopback with ``x-forwarded-internally: true``).

    Args:
        scope: Incoming ASGI scope.

    Returns:
        Header tuples safe to forward to the Rust sidecar.
    """
    headers: list[tuple[str, str]] = []
    # Last seen value of x-forwarded-internally. The header itself is never
    # forwarded; it only decides whether the affinity marker is added below.
    forwarded_internally: str | None = None

    for item in scope.get("headers") or []:
        # Skip anything that is not a well-formed (bytes, bytes) pair.
        if not isinstance(item, (tuple, list)) or len(item) != 2:
            continue
        name, value = item
        if not isinstance(name, (bytes, bytearray)) or not isinstance(value, (bytes, bytearray)):
            continue
        header_name = name.decode("latin-1").lower()
        if header_name == "x-forwarded-internally":
            forwarded_internally = value.decode("latin-1")
        if header_name in _REQUEST_HOP_BY_HOP_HEADERS or header_name in _FORWARDED_CHAIN_HEADERS or header_name in _INTERNAL_ONLY_REQUEST_HEADERS:
            continue
        headers.append((header_name, value.decode("latin-1")))

    server_id = _extract_server_id_from_scope(scope)
    if server_id:
        headers.append((_CONTEXTFORGE_SERVER_ID_HEADER, server_id))

    auth_context = _build_forwarded_auth_context_header()
    if auth_context is not None:
        headers.append((_CONTEXTFORGE_AUTH_CONTEXT_HEADER, auth_context))

    # The internal-forwarding hint is honoured only from loopback peers.
    client = scope.get("client")
    client_host = client[0] if isinstance(client, (tuple, list)) and client else None
    from_loopback = client_host in ("127.0.0.1", "::1")
    if from_loopback and forwarded_internally == "true":
        headers.append((_CONTEXTFORGE_AFFINITY_FORWARDED_HEADER, "rust"))

    return headers
def _build_forwarded_auth_context_header() -> str | None:
    """Encode the current MCP auth context as a header value for internal forwarding.

    Returns:
        The auth context as unpadded base64url-encoded JSON, or ``None`` when
        no MCP auth context is available.
    """
    context = get_streamable_http_auth_context()
    if context:
        payload = orjson.dumps(context)
        # Drop the "=" padding before turning the token into text.
        token = base64.urlsafe_b64encode(payload).rstrip(b"=")
        return token.decode("ascii")
    return None