Coverage for mcpgateway / utils / passthrough_headers.py: 99%
258 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/utils/passthrough_headers.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7HTTP Header Passthrough Utilities.
8This module provides utilities for handling HTTP header passthrough functionality
9in ContextForge. It enables forwarding of specific headers from incoming
10client requests to backing MCP servers while preventing conflicts with
11existing authentication mechanisms.
13Key Features:
14- Global configuration support via environment variables and database
15- Per-gateway header configuration overrides
16- Intelligent conflict detection with existing authentication headers
17- Security-first approach with explicit allowlist handling
18- Comprehensive logging for debugging and monitoring
19- Header validation and sanitization
21The header passthrough system follows a priority hierarchy:
221. Gateway-specific headers (highest priority)
232. Global database configuration
243. Environment variable defaults (lowest priority)
26Example Usage:
27 See comprehensive unit tests in tests/unit/mcpgateway/utils/test_passthrough_headers*.py
28 for detailed examples of header passthrough functionality.
29"""
31# Standard
32import logging
33import re
34import threading
35import time
36from typing import Dict, List, Optional
38# Third-Party
39from sqlalchemy.orm import Session
41# First-Party
42from mcpgateway.cache.global_config_cache import global_config_cache
43from mcpgateway.config import settings
44from mcpgateway.db import Gateway as DbGateway
45from mcpgateway.db import GlobalConfig
logger = logging.getLogger(__name__)

# Accepted characters for a passthrough header name: ASCII letters, digits and
# hyphens only (narrower than the general HTTP token grammar, e.g. no "_").
HEADER_NAME_REGEX = re.compile(r"^[A-Za-z0-9\-]+$")

# Upper bound, in characters, kept from a header value during sanitization (4 KiB).
MAX_HEADER_VALUE_LENGTH = 4 * 1024
class PassthroughHeadersError(Exception):
    """Root exception for failures in passthrough-header handling.

    Raised, for example, when the default passthrough headers cannot be
    persisted to the GlobalConfig table.

    Examples:
        >>> err = PassthroughHeadersError("Test error")
        >>> str(err)
        'Test error'
        >>> isinstance(err, Exception)
        True
    """
def sanitize_header_value(value: str, max_length: int = MAX_HEADER_VALUE_LENGTH) -> str:
    """Sanitize a header value before it is forwarded.

    Processing order:

    1. Remove CR and LF so the value cannot inject additional header lines.
    2. Truncate to ``max_length`` characters.
    3. Drop the remaining ASCII control characters (0x00-0x1F and DEL 0x7F),
       keeping horizontal tab, which HTTP permits inside field values.
    4. Strip leading and trailing whitespace.

    Args:
        value: Header value to sanitize. Non-string input raises (e.g.
            ``AttributeError`` for ``None``); callers in this module catch and
            log such failures.
        max_length: Maximum number of characters kept. Truncation happens
            before control characters are dropped, so the result may be shorter.

    Returns:
        Sanitized header value (possibly empty).

    Examples:
        Remove CRLF and trim length:
        >>> s = sanitize_header_value('val' + chr(13) + chr(10) + 'more', max_length=6)
        >>> s
        'valmor'
        >>> len(s) <= 6
        True
        >>> sanitize_header_value(' spaced ')
        'spaced'

        DEL (0x7F) is a control character and is removed as well:
        >>> sanitize_header_value('a' + chr(127) + 'b')
        'ab'
    """
    # Remove newlines and carriage returns to prevent header injection
    value = value.replace("\r", "").replace("\n", "")

    # Trim to max length
    value = value[:max_length]

    # Remove control characters (C0 range and DEL) except horizontal tab.
    value = "".join(c for c in value if c == "\t" or (ord(c) >= 32 and ord(c) != 127))

    return value.strip()
def validate_header_name(name: str) -> bool:
    """Validate a header name against ``HEADER_NAME_REGEX``.

    The entire string must match. ``re.fullmatch`` is used deliberately: with
    ``re.match`` the pattern's ``$`` anchor also matches just before a trailing
    newline, which would let a name ending in a newline pass validation.

    Args:
        name: Header name to validate

    Returns:
        True if valid, False otherwise

    Examples:
        Valid names:
        >>> validate_header_name('X-Tenant-Id')
        True
        >>> validate_header_name('X123-ABC')
        True

        Invalid names:
        >>> validate_header_name('Invalid Header:Name')
        False
        >>> validate_header_name('Bad@Name')
        False
        >>> validate_header_name('X-Tenant-Id' + chr(10))
        False
    """
    return HEADER_NAME_REGEX.fullmatch(name) is not None
def _apply_upstream_authorization(
    passthrough_headers: Dict[str, str],
    request_headers_lower: Dict[str, str],
    base_headers: Dict[str, str],
    gateway_auth_type: Optional[str],
) -> None:
    """Apply the always-on Authorization handling to *passthrough_headers* in place.

    - A client ``X-Upstream-Authorization`` header is sanitized and sent upstream
      as ``Authorization``. This works for both auth and no-auth gateways and
      replaces any Authorization already present, under any casing.
    - Otherwise, when the gateway ``auth_type`` is ``"none"`` and the base headers
      carry no Authorization of their own, the client's ``Authorization`` header
      is sanitized and passed through.

    This runs regardless of ENABLE_HEADER_PASSTHROUGH.

    Args:
        passthrough_headers: Output headers (starts as a copy of base headers); mutated.
        request_headers_lower: Client request headers keyed by lower-cased name.
        base_headers: Original base headers, used for the no-auth conflict check.
        gateway_auth_type: The target gateway's ``auth_type``, or None if unknown.
    """
    upstream_auth = request_headers_lower.get("x-upstream-authorization")

    if upstream_auth:
        try:
            sanitized_value = sanitize_header_value(upstream_auth)
            if sanitized_value:
                # Remove differently-cased Authorization keys first: httpx concatenates
                # case-different duplicates instead of picking one.
                for stale_key in [k for k in passthrough_headers if k.lower() == "authorization" and k != "Authorization"]:
                    del passthrough_headers[stale_key]
                passthrough_headers["Authorization"] = sanitized_value
                logger.debug("Renamed X-Upstream-Authorization to Authorization for upstream passthrough")
        except Exception as e:
            logger.warning("Failed to sanitize X-Upstream-Authorization header: %s", e)
    elif gateway_auth_type == "none":
        # When gateway has no auth, pass through client's Authorization if present
        client_auth = request_headers_lower.get("authorization")
        if client_auth and not any(h.lower() == "authorization" for h in base_headers):
            try:
                sanitized_value = sanitize_header_value(client_auth)
                if sanitized_value:
                    passthrough_headers["Authorization"] = sanitized_value
                    logger.debug("Passing through client Authorization header (auth_type=none)")
            except Exception as e:
                logger.warning("Failed to sanitize Authorization header: %s", e)


def _apply_allowlisted_headers(
    passthrough_headers: Dict[str, str],
    request_headers_lower: Dict[str, str],
    allowed_headers: Optional[List[str]],
    gateway_auth_type: Optional[str],
    gateway: Optional[DbGateway] = None,
) -> None:
    """Copy allowlisted request headers into *passthrough_headers* in place.

    For each allowlisted name (matched case-insensitively, first spelling wins):
    validate the name, sanitize the request value, then skip it when it would
    clash with an existing header (unless ENABLE_OVERWRITE_BASE_HEADERS is set)
    or with basic/bearer gateway authentication. In overwrite mode, the
    existing header key is replaced rather than duplicated under another casing.

    Args:
        passthrough_headers: Output headers; mutated.
        request_headers_lower: Client request headers keyed by lower-cased name.
        allowed_headers: Effective allowlist (global or gateway-specific).
        gateway_auth_type: The target gateway's ``auth_type``, or None if unknown.
        gateway: Gateway object, used only to name it in conflict warnings.
    """
    if not request_headers_lower or not allowed_headers:
        return

    # Lower-cased name -> actual key of headers already present (base headers,
    # plus an Authorization that may have been set from X-Upstream-Authorization).
    base_headers_keys = {key.lower(): key for key in passthrough_headers}
    seen: set[str] = set()

    for header_name in allowed_headers:
        if not validate_header_name(header_name):
            logger.warning("Invalid header name '%s' - skipping (must match pattern: %s)", header_name, HEADER_NAME_REGEX.pattern)
            continue

        header_lower = header_name.lower()
        if header_lower in seen:
            # The allowlist is case-insensitive; emitting two casings of one header
            # would make httpx concatenate their values.
            continue
        seen.add(header_lower)

        header_value = request_headers_lower.get(header_lower)
        if not header_value:
            logger.debug("Header %s not found in request headers, skipping passthrough", header_name)
            continue

        try:
            sanitized_value = sanitize_header_value(header_value)
        except Exception as e:
            logger.warning("Failed to sanitize header %s: %s - skipping", header_name, e)
            continue
        if not sanitized_value:
            logger.warning("Header %s value became empty after sanitization - skipping", header_name)
            continue

        # Skip if header would conflict with existing (base/auth) headers
        existing_key = base_headers_keys.get(header_lower)
        if existing_key is not None and not settings.enable_overwrite_base_headers:
            logger.warning("Skipping %s header passthrough as it conflicts with pre-defined headers", header_name)
            continue

        # Skip if header would conflict with gateway auth
        if gateway_auth_type in ("basic", "bearer") and header_lower == "authorization":
            where = f" on gateway {gateway.name}" if gateway else ""
            logger.warning("Skipping Authorization header passthrough due to %s auth configuration%s", gateway_auth_type, where)
            continue

        if existing_key is not None and existing_key != header_name:
            # Overwrite mode: replace the existing header instead of adding a
            # second, differently-cased copy of it.
            del passthrough_headers[existing_key]

        # Use original header name casing from configuration, sanitized value from request
        passthrough_headers[header_name] = sanitized_value
        logger.debug("Added passthrough header: %s", header_name)


def get_passthrough_headers(request_headers: Dict[str, str], base_headers: Dict[str, str], db: Session, gateway: Optional[DbGateway] = None) -> Dict[str, str]:
    """Get headers that should be passed through to the target gateway.

    This function implements the core logic for HTTP header passthrough in ContextForge.
    It determines which headers from incoming client requests should be forwarded to
    backing MCP servers based on configuration settings and security policies.

    Configuration Priority (highest to lowest):
    1. Gateway-specific passthrough_headers setting
    2. Global headers from get_passthrough_headers() based on PASSTHROUGH_HEADERS_SOURCE:
       - "db": Database wins if configured, env var DEFAULT_PASSTHROUGH_HEADERS as fallback
       - "env": Environment variable always wins, database ignored
       - "merge": Union of both sources (DB casing wins for duplicates)

    Security Features:
    - Feature flag control (disabled by default)
    - Prevents conflicts with existing base headers (e.g., Content-Type)
    - Blocks Authorization header conflicts with gateway authentication
    - Header name validation (regex pattern matching)
    - Header value sanitization (removes dangerous characters, enforces limits)
    - Logs all conflicts and skipped headers for debugging
    - Uses case-insensitive header matching for robustness; a header is never
      emitted twice under different casings
    - Special X-Upstream-Authorization handling: clients can send an
      X-Upstream-Authorization header, which is renamed to Authorization for
      upstream (replacing any existing Authorization, whatever its casing)

    Args:
        request_headers (Dict[str, str]): Headers from the incoming HTTP request.
            Keys should be header names, values should be header values.
            Example: {"Authorization": "Bearer token123", "X-Tenant-Id": "acme"}
        base_headers (Dict[str, str]): Base headers that should always be included
            in the final result. These take precedence over passthrough headers.
            Example: {"Content-Type": "application/json", "User-Agent": "MCPGateway/1.0"}
        db (Session): SQLAlchemy database session for querying global configuration.
            Used to retrieve GlobalConfig.passthrough_headers setting.
        gateway (Optional[DbGateway]): Target gateway instance. If provided, uses
            gateway.passthrough_headers to override global settings. Also checks
            gateway.auth_type to prevent Authorization header conflicts.

    Returns:
        Dict[str, str]: Combined dictionary of base headers plus allowed passthrough
        headers from the request. Base headers are preserved, and passthrough
        headers are added only if they don't conflict with security policies.

    Raises:
        No exceptions are raised. Errors are logged as warnings and processing continues.
        Database connection issues may propagate from the db.query() call.

    Examples:
        Feature disabled by default (secure by default):
        >>> from unittest.mock import Mock, patch
        >>> from mcpgateway.cache.global_config_cache import global_config_cache
        >>> global_config_cache.invalidate()  # Clear cache for isolated test
        >>> with patch(__name__ + ".settings") as mock_settings:
        ...     mock_settings.enable_header_passthrough = False
        ...     mock_settings.default_passthrough_headers = ["X-Tenant-Id"]
        ...     mock_db = Mock()
        ...     mock_db.query.return_value.first.return_value = None
        ...     request_headers = {"x-tenant-id": "should-be-ignored"}
        ...     base_headers = {"Content-Type": "application/json"}
        ...     get_passthrough_headers(request_headers, base_headers, mock_db)
        {'Content-Type': 'application/json'}

        Enabled with allowlist and conflicts:
        >>> global_config_cache.invalidate()  # Clear cache for isolated test
        >>> with patch(__name__ + ".settings") as mock_settings:
        ...     mock_settings.enable_header_passthrough = True
        ...     mock_settings.default_passthrough_headers = ["X-Tenant-Id", "Authorization"]
        ...     # Mock DB returns no global override
        ...     mock_db = Mock()
        ...     mock_db.query.return_value.first.return_value = None
        ...     # Gateway with basic auth should block Authorization passthrough
        ...     gateway = Mock()
        ...     gateway.passthrough_headers = None
        ...     gateway.auth_type = "basic"
        ...     gateway.name = "gw1"
        ...     req_headers = {"X-Tenant-Id": "acme", "Authorization": "Bearer abc"}
        ...     base = {"Content-Type": "application/json", "Authorization": "Bearer base"}
        ...     res = get_passthrough_headers(req_headers, base, mock_db, gateway)
        ...     ("X-Tenant-Id" in res) and (res["Authorization"] == "Bearer base")
        True

        See comprehensive unit tests in tests/unit/mcpgateway/utils/test_passthrough_headers*.py
        for detailed examples of enabled functionality, conflict detection, and security features.

    Note:
        Header names are matched case-insensitively but preserved in their original
        case from the allowed_headers configuration. Request header values are
        matched case-insensitively against the request_headers dictionary.
    """
    passthrough_headers = base_headers.copy()
    request_headers_lower = {k.lower(): v for k, v in request_headers.items()} if request_headers else {}
    gateway_auth_type = gateway.auth_type if gateway else None

    # X-Upstream-Authorization / no-auth Authorization handling is always enabled.
    _apply_upstream_authorization(passthrough_headers, request_headers_lower, base_headers, gateway_auth_type)

    # Early return if header passthrough feature is disabled
    if not settings.enable_header_passthrough:
        logger.debug("Header passthrough is disabled via ENABLE_HEADER_PASSTHROUGH flag")
        return passthrough_headers

    if settings.enable_overwrite_base_headers:
        logger.debug("Overwriting base headers is enabled via ENABLE_OVERWRITE_BASE_HEADERS flag")

    # Get global passthrough headers from in-memory cache (Issue #1715)
    # This eliminates redundant DB queries for static configuration
    allowed_headers = global_config_cache.get_passthrough_headers(db, settings.default_passthrough_headers)

    # Gateway specific headers override global config
    if gateway and gateway.passthrough_headers is not None:
        allowed_headers = gateway.passthrough_headers

    _apply_allowlisted_headers(passthrough_headers, request_headers_lower, allowed_headers, gateway_auth_type, gateway)

    logger.debug("Final passthrough headers: %s", list(passthrough_headers.keys()))
    return passthrough_headers


def compute_passthrough_headers_cached(
    request_headers: Dict[str, str],
    base_headers: Dict[str, str],
    allowed_headers: List[str],
    gateway_auth_type: Optional[str] = None,
    gateway_passthrough_headers: Optional[List[str]] = None,
) -> Dict[str, str]:
    """Compute passthrough headers without database query.

    Use this when GlobalConfig has already been fetched and cached, to avoid
    repeated database queries during high-frequency operations like tool invocation.

    This function applies the same header passthrough rules as get_passthrough_headers()
    (it shares the same helpers) but accepts pre-fetched configuration values
    instead of querying the database.

    Args:
        request_headers: Headers from the incoming HTTP request.
        base_headers: Base headers that should always be included (auth, content-type, etc.).
        allowed_headers: List of header names allowed to pass through (from GlobalConfig).
        gateway_auth_type: The gateway's auth_type (basic, bearer, oauth, none) if applicable.
        gateway_passthrough_headers: Gateway-specific passthrough headers override.

    Returns:
        Combined dictionary of base headers plus allowed passthrough headers.

    Examples:
        >>> from unittest.mock import patch
        >>> from mcpgateway.utils.passthrough_headers import compute_passthrough_headers_cached
        >>> request = {"X-Tenant-Id": "acme", "Authorization": "secret"}
        >>> base = {"Content-Type": "application/json"}
        >>> allowed = ["X-Tenant-Id"]
        >>> with patch("mcpgateway.utils.passthrough_headers.settings") as mock_settings:
        ...     mock_settings.enable_header_passthrough = True
        ...     mock_settings.enable_overwrite_base_headers = False
        ...     result = compute_passthrough_headers_cached(request, base, allowed, gateway_auth_type=None)
        >>> "X-Tenant-Id" in result
        True
        >>> result.get("Authorization") is None  # Not in allowed list
        True

        Overwrite mode replaces a base header instead of duplicating it under
        another casing:
        >>> with patch("mcpgateway.utils.passthrough_headers.settings") as mock_settings:
        ...     mock_settings.enable_header_passthrough = True
        ...     mock_settings.enable_overwrite_base_headers = True
        ...     compute_passthrough_headers_cached({"X-Tenant-Id": "acme"}, {"x-tenant-id": "base"}, ["X-Tenant-Id"])
        {'X-Tenant-Id': 'acme'}

        X-Upstream-Authorization replaces a lower-case base Authorization:
        >>> with patch("mcpgateway.utils.passthrough_headers.settings") as mock_settings:
        ...     mock_settings.enable_header_passthrough = False
        ...     compute_passthrough_headers_cached({"X-Upstream-Authorization": "Bearer u"}, {"authorization": "Bearer gw"}, [])
        {'Authorization': 'Bearer u'}
    """
    passthrough_headers = base_headers.copy()
    request_headers_lower = {k.lower(): v for k, v in request_headers.items()} if request_headers else {}

    # X-Upstream-Authorization / no-auth Authorization handling is always enabled.
    _apply_upstream_authorization(passthrough_headers, request_headers_lower, base_headers, gateway_auth_type)

    # Early return if header passthrough feature is disabled
    if not settings.enable_header_passthrough:
        logger.debug("Header passthrough is disabled via ENABLE_HEADER_PASSTHROUGH flag")
        return passthrough_headers

    # Use gateway-specific headers if provided, otherwise use global allowed_headers
    effective_allowed = gateway_passthrough_headers if gateway_passthrough_headers is not None else allowed_headers

    _apply_allowlisted_headers(passthrough_headers, request_headers_lower, effective_allowed, gateway_auth_type)

    logger.debug("Final passthrough headers (cached): %s", list(passthrough_headers.keys()))
    return passthrough_headers
async def set_global_passthrough_headers(db: Session) -> None:
    """Set global passthrough headers in the database if not already configured.

    This function checks if the global passthrough headers are already set in the
    GlobalConfig table. If not, it initializes them with the default headers from
    settings.default_passthrough_headers (names failing validate_header_name()
    are logged and skipped).

    When PASSTHROUGH_HEADERS_SOURCE=env, this function skips database writes entirely
    since the database configuration is ignored in that mode.

    Args:
        db (Session): SQLAlchemy database session for querying and updating GlobalConfig.

    Raises:
        PassthroughHeadersError: If unable to update passthrough headers in the database.
            The original exception is chained as ``__cause__``.

    Examples:
        Successful insert of default headers:
        >>> import pytest
        >>> from unittest.mock import Mock, patch
        >>> @pytest.mark.asyncio
        ... @patch("mcpgateway.utils.passthrough_headers.settings")
        ... async def test_default_headers(mock_settings):
        ...     mock_settings.enable_header_passthrough = True
        ...     mock_settings.passthrough_headers_source = "db"
        ...     mock_settings.default_passthrough_headers = ["X-Tenant-Id", "X-Trace-Id"]
        ...     mock_db = Mock()
        ...     mock_db.query.return_value.first.return_value = None
        ...     await set_global_passthrough_headers(mock_db)
        ...     mock_db.add.assert_called_once()
        ...     mock_db.commit.assert_called_once()

        Database write failure:
        >>> import pytest
        >>> from unittest.mock import Mock, patch
        >>> from mcpgateway.utils.passthrough_headers import PassthroughHeadersError
        >>> @pytest.mark.asyncio
        ... @patch("mcpgateway.utils.passthrough_headers.settings")
        ... async def test_db_write_failure(mock_settings):
        ...     mock_settings.enable_header_passthrough = True
        ...     mock_settings.passthrough_headers_source = "db"
        ...     mock_db = Mock()
        ...     mock_db.query.return_value.first.return_value = None
        ...     mock_db.commit.side_effect = Exception("DB write failed")
        ...     with pytest.raises(PassthroughHeadersError):
        ...         await set_global_passthrough_headers(mock_db)
        ...     mock_db.rollback.assert_called_once()

        Config already exists (no DB write):
        >>> import pytest
        >>> from unittest.mock import Mock, patch
        >>> from mcpgateway.common.models import GlobalConfig
        >>> @pytest.mark.asyncio
        ... @patch("mcpgateway.utils.passthrough_headers.settings")
        ... async def test_existing_config(mock_settings):
        ...     mock_settings.enable_header_passthrough = True
        ...     mock_settings.passthrough_headers_source = "db"
        ...     mock_db = Mock()
        ...     existing = Mock(spec=GlobalConfig)
        ...     existing.passthrough_headers = ["X-Tenant-ID", "Authorization"]
        ...     mock_db.query.return_value.first.return_value = existing
        ...     await set_global_passthrough_headers(mock_db)
        ...     mock_db.add.assert_not_called()
        ...     mock_db.commit.assert_not_called()
        ...     assert existing.passthrough_headers == ["X-Tenant-ID", "Authorization"]

        Env mode skips DB entirely:
        >>> import pytest
        >>> from unittest.mock import Mock, patch
        >>> @pytest.mark.asyncio
        ... @patch("mcpgateway.utils.passthrough_headers.settings")
        ... async def test_env_mode_skips_db(mock_settings):
        ...     mock_settings.passthrough_headers_source = "env"
        ...     mock_db = Mock()
        ...     await set_global_passthrough_headers(mock_db)
        ...     mock_db.query.assert_not_called()
        ...     mock_db.add.assert_not_called()

    Note:
        This function is typically called during application startup to ensure
        global configuration is in place before any gateway operations.
    """
    # When source is "env", skip DB operations entirely - env vars always win
    if settings.passthrough_headers_source == "env":
        logger.debug("Passthrough headers source=env: skipping database initialization (env vars always used)")
        return

    # Query DB directly here (not cache) because we need to check if config exists
    # to decide whether to create it
    global_config = db.query(GlobalConfig).first()
    if global_config:
        # An existing row is left untouched.
        return

    allowed_headers: List[str] = []
    for header_name in settings.default_passthrough_headers or []:
        if not validate_header_name(header_name):
            logger.warning(f"Invalid header name '{header_name}' - skipping (must match pattern: {HEADER_NAME_REGEX.pattern})")
            continue
        allowed_headers.append(header_name)

    try:
        db.add(GlobalConfig(passthrough_headers=allowed_headers))
        db.commit()
        # Invalidate both global and loopback caches so next read picks up new config (Issue #1715, #3640)
        invalidate_passthrough_header_caches()
    except Exception as e:
        db.rollback()
        # Chain the original error so the root cause stays visible in tracebacks.
        raise PassthroughHeadersError(f"Failed to update passthrough headers: {e}") from e
546# Headers that must never be forwarded via loopback — they are set by the caller
547# or are gateway-internal routing/loop-prevention headers.
548# IMPORTANT: keep this set in sync with internal headers set at merge sites
549# (session_registry generate_response, WebSocket relay, Streamable HTTP affinity).
550# httpx concatenates case-different duplicate keys rather than picking one, so an
551# omission here could silently corrupt the internal header value.
552_LOOPBACK_SKIP_HEADERS: frozenset[str] = frozenset(
553 {
554 "authorization",
555 "connection",
556 "content-type",
557 "content-length",
558 "host",
559 "keep-alive",
560 "mcp-session-id",
561 "proxy-connection",
562 "te",
563 "trailer",
564 "transfer-encoding",
565 "upgrade",
566 "x-mcp-session-id",
567 "x-forwarded-internally",
568 }
569)
def _loopback_skip_set() -> frozenset[str]:
    """Build the complete set of header names excluded from loopback forwarding.

    Starts from ``_LOOPBACK_SKIP_HEADERS`` and adds the lower-cased
    ``settings.proxy_user_header`` (default ``X-Authenticated-User``). This
    ensures passthrough can never overwrite the gateway-internal proxy-user
    identity, even if that header is mistakenly added to an allowlist.

    Returns:
        frozenset[str]: Lower-cased header names to skip during loopback forwarding.
    """
    proxy_header = settings.proxy_user_header.lower()
    if proxy_header not in _LOOPBACK_SKIP_HEADERS:
        return _LOOPBACK_SKIP_HEADERS.union((proxy_header,))
    return _LOOPBACK_SKIP_HEADERS
class _LoopbackAllowlistCache:
    """TTL cache for the merged passthrough header allowlist (global + all gateways).

    Avoids a full Gateway table scan on every loopback call by caching the union
    of global and gateway-specific passthrough headers for ``ttl_seconds``
    (60 s by default, the same TTL used by global_config_cache).

    Expiry is tracked with ``time.monotonic()``. With wall-clock time, a clock
    step (NTP correction, manual change) could keep a stale allowlist alive far
    past its TTL, or force a refresh on every call.
    """

    def __init__(self, ttl_seconds: float = 60.0):
        self._cache: Optional[frozenset[str]] = None
        self._populated: bool = False
        # Monotonic-clock deadline; only meaningful while _populated is True.
        self._expiry: float = 0
        self._ttl = ttl_seconds
        self._lock = threading.Lock()

    def get(self, db: Session) -> frozenset[str]:
        """Return the cached merged allowlist, refreshing from DB when expired.

        Falls back to the last known good value during transient DB failures,
        and briefly extends its lifetime (at most 10 s). This avoids a
        thundering-herd of failing queries on every loopback call.

        Args:
            db: SQLAlchemy database session for querying gateway configurations.

        Returns:
            Frozen set of allowed passthrough header names (union of global and
            all gateway-specific configurations).

        Raises:
            Exception: Re-raised from DB query when no stale cache is available
                to fall back to (first call after startup with a broken DB).
        """
        # CPython GIL ensures atomic attribute reads on the fast path.
        if time.monotonic() < self._expiry and self._populated:
            return self._cache  # type: ignore[return-value]  # _populated guarantees non-None
        with self._lock:
            # Re-read the clock: this thread may have waited on the lock while
            # another thread refreshed the cache.
            now = time.monotonic()
            if now < self._expiry and self._populated:
                return self._cache  # type: ignore[return-value]  # _populated guarantees non-None
            try:
                merged: set[str] = set(global_config_cache.get_passthrough_headers(db, settings.default_passthrough_headers or []) or [])
                gw_rows = db.query(DbGateway.passthrough_headers).filter(DbGateway.passthrough_headers.isnot(None)).all()
                for (gw_headers,) in gw_rows:
                    if gw_headers:
                        merged.update(gw_headers)
                self._cache = frozenset(merged)
                self._populated = True
                self._expiry = now + self._ttl
            except Exception:
                logger.warning("Failed to refresh loopback allowlist cache from DB; using stale value if available", exc_info=True)
                if self._populated and self._cache is not None:
                    # Extend TTL briefly to avoid hammering DB on every request
                    self._expiry = now + min(self._ttl, 10.0)
                else:
                    raise
            return self._cache  # type: ignore[return-value]  # _populated guarantees non-None

    def invalidate(self) -> None:
        """Force a refresh on next access."""
        with self._lock:
            self._populated = False
            self._expiry = 0


# Process-wide instance read by extract_headers_for_loopback() and cleared by
# invalidate_passthrough_header_caches().
_loopback_allowlist_cache = _LoopbackAllowlistCache()
def invalidate_passthrough_header_caches() -> None:
    """Drop every cached view of the passthrough-header configuration.

    Clears ``global_config_cache`` first, then the loopback allowlist cache, so
    the next lookup re-reads configuration. Call this after any change to global
    or per-gateway passthrough headers. Loopback transports (SSE, WebSocket,
    Streamable HTTP) then see the change at the same time as direct /rpc,
    instead of after TTL expiry.
    """
    for cache in (global_config_cache, _loopback_allowlist_cache):
        cache.invalidate()
    logger.debug("Invalidated global_config_cache and _loopback_allowlist_cache for passthrough headers")
def filter_loopback_skip_headers(headers: Dict[str, str]) -> Dict[str, str]:
    """Copy *headers*, leaving out gateway-internal loopback headers.

    This is a defense-in-depth filter used at loopback merge sites (SSE
    generate_response, WebSocket relay). Passthrough headers can never override
    the gateway's internal JWT, content-type, proxy-user or session/routing
    headers, even if ``extract_headers_for_loopback`` was bypassed or its
    skip-list drifted.

    Every kept value is passed through ``sanitize_header_value()`` again, so
    input does not need to be pre-sanitized. Values that fail sanitization are
    dropped with a warning.

    Args:
        headers: Candidate passthrough headers to filter.

    Returns:
        A new dictionary holding only the headers whose lower-cased name is not
        in ``_loopback_skip_set()`` (``_LOOPBACK_SKIP_HEADERS`` plus the
        configurable ``proxy_user_header``), with sanitized values.
    """
    blocked = _loopback_skip_set()
    kept: Dict[str, str] = {}
    for name, raw_value in headers.items():
        if name.lower() not in blocked:
            try:
                kept[name] = sanitize_header_value(raw_value)
            except Exception:
                logger.warning("Dropped unsafe header %s during loopback filter", name, exc_info=True)
    return kept
def extract_headers_for_loopback(request_headers: Dict[str, str]) -> Dict[str, str]:
    """Pick the passthrough-relevant client headers to forward on loopback /rpc calls.

    SSE and WebSocket transports reach /rpc through internal loopback HTTP calls.
    Client passthrough headers (such as X-Upstream-Authorization) must ride along
    on those calls, so that /rpc can hand them to upstream MCP servers via
    get_passthrough_headers().

    Always extracted:
      - x-upstream-authorization (always enabled by design; renamed to
        Authorization upstream)

    Additionally, when ENABLE_HEADER_PASSTHROUGH is True, headers whose names
    appear in the cached union of:
      - the global allowlist from global_config_cache.get_passthrough_headers()
        (honours PASSTHROUGH_HEADERS_SOURCE priority: env, db, merge)
      - every Gateway's own passthrough_headers

    That merged allowlist is cached for 60 s (matching global_config_cache), so
    the gateway table scan happens at most once per TTL window, not per request.

    Every extracted value goes through sanitize_header_value() for
    defense-in-depth, even though /rpc sanitizes again in get_passthrough_headers().
    Allowlisted names present in _LOOPBACK_SKIP_HEADERS (authorization,
    content-type, gateway-internal routing/session headers) or equal to the
    configured proxy-user header are never returned, whatever the configuration.

    Args:
        request_headers: Headers from the incoming client HTTP request or WebSocket
            handshake. Keys are header names, values are header values.

    Returns:
        Dictionary, keyed by lower-cased header name, to merge into the loopback
        /rpc request. It never contains authorization, content-type, or
        gateway-internal headers (the caller sets those itself).

    Examples:
        X-Upstream-Authorization is always extracted:
        >>> from unittest.mock import patch
        >>> with patch("mcpgateway.utils.passthrough_headers.settings") as s:
        ...     s.enable_header_passthrough = False
        ...     s.default_passthrough_headers = []
        ...     extract_headers_for_loopback({"X-Upstream-Authorization": "Bearer tok"})
        {'x-upstream-authorization': 'Bearer tok'}

        Empty when no relevant headers present:
        >>> from unittest.mock import patch
        >>> with patch("mcpgateway.utils.passthrough_headers.settings") as s:
        ...     s.enable_header_passthrough = False
        ...     s.default_passthrough_headers = []
        ...     extract_headers_for_loopback({"Accept": "text/html"})
        {}
    """
    if not request_headers:
        return {}

    lowered = {name.lower(): value for name, value in request_headers.items()}
    extracted: Dict[str, str] = {}

    # x-upstream-authorization is always forwarded. If sanitization fails, the header
    # is dropped rather than sent raw: sanitization is what blocks CRLF/control-character
    # injection, so bypassing it is unsafe.
    upstream_value = lowered.get("x-upstream-authorization")
    if upstream_value:
        try:
            extracted["x-upstream-authorization"] = sanitize_header_value(upstream_value)
        except Exception:
            logger.warning("Failed to sanitize x-upstream-authorization; dropping header for safety", exc_info=True)

    # Allowlisted headers are only forwarded when the passthrough feature is on.
    # Any failure here (including reading the allowlist) degrades to forwarding
    # only what was extracted above.
    try:
        if settings.enable_header_passthrough:
            # First-Party
            from mcpgateway.db import SessionLocal  # pylint: disable=import-outside-toplevel

            with SessionLocal() as db:
                allowlist = _loopback_allowlist_cache.get(db)
            blocked = _loopback_skip_set()
            for lowered_name in (name.lower() for name in allowlist):
                if lowered_name in blocked or lowered_name not in lowered:
                    continue
                try:
                    extracted[lowered_name] = sanitize_header_value(lowered[lowered_name])
                except Exception:
                    logger.warning("Failed to sanitize passthrough header %s; skipping", lowered_name, exc_info=True)
    except Exception:
        logger.warning("Failed to read passthrough header allowlist; forwarding only previously extracted headers", exc_info=True)

    if extracted:
        logger.debug("Extracted %d passthrough header(s) for loopback: %s", len(extracted), list(extracted.keys()))

    return extracted
def safe_extract_headers_for_loopback(request_headers: Dict[str, str], transport_name: str = "transport") -> Dict[str, str]:
    """Safely extract passthrough headers, returning ``{}`` on failure.

    Wraps :func:`extract_headers_for_loopback` so that SSE / WebSocket setup
    is never blocked by passthrough configuration issues.

    Every ``Exception`` subclass escaping the extractor is logged with a
    traceback and swallowed. That includes ``ImportError``, which
    ``extract_headers_for_loopback`` already catches around its own
    ``SessionLocal`` import. Only non-``Exception`` errors such as
    ``KeyboardInterrupt`` or ``SystemExit`` propagate.

    Args:
        request_headers: Incoming HTTP headers to extract from.
        transport_name: Label for warning logs on failure.

    Returns:
        Dict[str, str]: Extracted passthrough headers, or empty dict on error.
    """
    try:
        return extract_headers_for_loopback(request_headers)
    except Exception:
        logger.warning("Failed to extract passthrough headers for %s; upstream auth may fail", transport_name, exc_info=True)
        return {}
def safe_extract_and_filter_for_loopback(request_headers: Dict[str, str]) -> Dict[str, str]:
    """Extract, then filter, loopback passthrough headers; ``{}`` on any failure.

    Runs :func:`extract_headers_for_loopback` and feeds the result through
    :func:`filter_loopback_skip_headers`. Any exception from either step is
    logged and replaced by an empty dict, so Streamable HTTP affinity loopback
    calls degrade gracefully.

    Args:
        request_headers: Incoming HTTP headers to extract and filter.

    Returns:
        Dict[str, str]: Filtered passthrough headers, or empty dict on error.
    """
    try:
        extracted = extract_headers_for_loopback(request_headers)
        return filter_loopback_skip_headers(extracted)
    except Exception:
        logger.warning("Failed to extract passthrough headers for loopback; upstream auth may fail", exc_info=True)
        return {}
834 return {}