Coverage for mcpgateway / utils / passthrough_headers.py: 100%
158 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/utils/passthrough_headers.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7HTTP Header Passthrough Utilities.
8This module provides utilities for handling HTTP header passthrough functionality
9in the MCP Gateway. It enables forwarding of specific headers from incoming
10client requests to backing MCP servers while preventing conflicts with
11existing authentication mechanisms.
13Key Features:
14- Global configuration support via environment variables and database
15- Per-gateway header configuration overrides
16- Intelligent conflict detection with existing authentication headers
17- Security-first approach with explicit allowlist handling
18- Comprehensive logging for debugging and monitoring
19- Header validation and sanitization
21The header passthrough system follows a priority hierarchy:
221. Gateway-specific headers (highest priority)
232. Global database configuration
243. Environment variable defaults (lowest priority)
26Example Usage:
27 See comprehensive unit tests in tests/unit/mcpgateway/utils/test_passthrough_headers*.py
28 for detailed examples of header passthrough functionality.
29"""
31# Standard
32import logging
33import re
34from typing import Dict, List, Optional
36# Third-Party
37from sqlalchemy.orm import Session
39# First-Party
40from mcpgateway.cache.global_config_cache import global_config_cache
41from mcpgateway.config import settings
42from mcpgateway.db import Gateway as DbGateway
43from mcpgateway.db import GlobalConfig
logger = logging.getLogger(__name__)

# Header names must consist solely of ASCII letters, digits and hyphens.
HEADER_NAME_REGEX = re.compile(r"^[A-Za-z0-9\-]+$")

# Upper bound on a forwarded header value, in characters (4 KiB).
MAX_HEADER_VALUE_LENGTH = 4 * 1024
class PassthroughHeadersError(Exception):
    """Root error type for passthrough-header handling.

    ``set_global_passthrough_headers`` raises it when the default passthrough
    configuration cannot be persisted, so callers can catch this single type.

    Examples:
        >>> err = PassthroughHeadersError("Test error")
        >>> str(err)
        'Test error'
        >>> isinstance(err, Exception)
        True
    """
def sanitize_header_value(value: str, max_length: int = MAX_HEADER_VALUE_LENGTH) -> str:
    """Sanitize header value for security.

    Removes CR/LF (header-injection vectors), truncates to ``max_length``
    characters, drops every remaining ASCII control character except horizontal
    tab (C0 controls below 0x20 and DEL, 0x7F), and trims surrounding whitespace.

    Args:
        value: Header value to sanitize
        max_length: Maximum allowed length

    Returns:
        Sanitized header value

    Examples:
        Remove CRLF and trim length:
        >>> s = sanitize_header_value('val' + chr(13) + chr(10) + 'more', max_length=6)
        >>> s
        'valmor'
        >>> len(s) <= 6
        True
        >>> sanitize_header_value(' spaced ')
        'spaced'

        DEL is a control character and is removed too:
        >>> sanitize_header_value('a' + chr(127) + 'b')
        'ab'
    """
    # Remove newlines and carriage returns to prevent header injection
    value = value.replace("\r", "").replace("\n", "")

    # Trim to max length first so the per-character pass below is bounded
    value = value[:max_length]

    # Remove control characters except tab: C0 controls (< 32) and DEL (127).
    # RFC 9110 field values allow only VCHAR (0x21-0x7E), SP, HTAB and obs-text.
    value = "".join(c for c in value if c == "\t" or (ord(c) >= 32 and ord(c) != 127))

    return value.strip()
def validate_header_name(name: str) -> bool:
    """Validate header name against allowed pattern.

    The whole name must consist of letters, digits and hyphens. ``fullmatch``
    is used instead of ``match`` because a ``$`` anchor also matches just before
    a trailing newline. With ``match``, a name ending in a newline character
    would be accepted.

    Args:
        name: Header name to validate

    Returns:
        True if valid, False otherwise

    Examples:
        Valid names:
        >>> validate_header_name('X-Tenant-Id')
        True
        >>> validate_header_name('X123-ABC')
        True

        Invalid names:
        >>> validate_header_name('Invalid Header:Name')
        False
        >>> validate_header_name('Bad@Name')
        False

        A trailing newline is rejected:
        >>> validate_header_name('X-Tenant-Id' + chr(10))
        False
    """
    return HEADER_NAME_REGEX.fullmatch(name) is not None
def get_passthrough_headers(request_headers: Dict[str, str], base_headers: Dict[str, str], db: Session, gateway: Optional[DbGateway] = None) -> Dict[str, str]:
    """Get headers that should be passed through to the target gateway.

    This function implements the core logic for HTTP header passthrough in the MCP Gateway.
    It determines which headers from incoming client requests should be forwarded to
    backing MCP servers based on configuration settings and security policies.

    Configuration Priority (highest to lowest):
    1. Gateway-specific passthrough_headers setting
    2. Global headers from global_config_cache.get_passthrough_headers() based on PASSTHROUGH_HEADERS_SOURCE:
       - "db": Database wins if configured, env var DEFAULT_PASSTHROUGH_HEADERS as fallback
       - "env": Environment variable always wins, database ignored
       - "merge": Union of both sources (DB casing wins for duplicates)

    Security Features:
    - Feature flag control (disabled by default)
    - Prevents conflicts with existing base headers (e.g., Content-Type)
    - Blocks Authorization header conflicts with gateway authentication
    - Header name validation (regex pattern matching)
    - Header value sanitization (removes dangerous characters, enforces limits)
    - Logs all conflicts and skipped headers for debugging
    - Uses case-insensitive header matching for robustness
    - Never emits two headers whose names differ only in case: a written header
      replaces any existing key with a differently-cased name
    - Special X-Upstream-Authorization handling: When gateway uses auth, clients can
      send X-Upstream-Authorization header which gets renamed to Authorization for upstream

    Args:
        request_headers (Dict[str, str]): Headers from the incoming HTTP request.
            Keys should be header names, values should be header values.
            Example: {"Authorization": "Bearer token123", "X-Tenant-Id": "acme"}
        base_headers (Dict[str, str]): Base headers that should always be included
            in the final result. These take precedence over passthrough headers.
            Example: {"Content-Type": "application/json", "User-Agent": "MCPGateway/1.0"}
        db (Session): SQLAlchemy database session handed to the global config cache
            to resolve the global passthrough header allowlist.
        gateway (Optional[DbGateway]): Target gateway instance. If provided, uses
            gateway.passthrough_headers to override global settings. Also checks
            gateway.auth_type to prevent Authorization header conflicts.

    Returns:
        Dict[str, str]: Combined dictionary of base headers plus allowed passthrough
        headers from the request. Base headers are preserved, and passthrough
        headers are added only if they don't conflict with security policies.

    Raises:
        No exceptions are raised. Errors are logged as warnings and processing continues.
        Database errors raised while the global config cache loads configuration may propagate.

    Examples:
        Feature disabled by default (secure by default):
        >>> from unittest.mock import Mock, patch
        >>> from mcpgateway.cache.global_config_cache import global_config_cache
        >>> global_config_cache.invalidate()  # Clear cache for isolated test
        >>> with patch(__name__ + ".settings") as mock_settings:
        ...     mock_settings.enable_header_passthrough = False
        ...     mock_settings.default_passthrough_headers = ["X-Tenant-Id"]
        ...     mock_db = Mock()
        ...     mock_db.query.return_value.first.return_value = None
        ...     request_headers = {"x-tenant-id": "should-be-ignored"}
        ...     base_headers = {"Content-Type": "application/json"}
        ...     get_passthrough_headers(request_headers, base_headers, mock_db)
        {'Content-Type': 'application/json'}

        Enabled with allowlist and conflicts:
        >>> global_config_cache.invalidate()  # Clear cache for isolated test
        >>> with patch(__name__ + ".settings") as mock_settings:
        ...     mock_settings.enable_header_passthrough = True
        ...     mock_settings.default_passthrough_headers = ["X-Tenant-Id", "Authorization"]
        ...     # Mock DB returns no global override
        ...     mock_db = Mock()
        ...     mock_db.query.return_value.first.return_value = None
        ...     # Gateway with basic auth should block Authorization passthrough
        ...     gateway = Mock()
        ...     gateway.passthrough_headers = None
        ...     gateway.auth_type = "basic"
        ...     gateway.name = "gw1"
        ...     req_headers = {"X-Tenant-Id": "acme", "Authorization": "Bearer abc"}
        ...     base = {"Content-Type": "application/json", "Authorization": "Bearer base"}
        ...     res = get_passthrough_headers(req_headers, base, mock_db, gateway)
        ...     ("X-Tenant-Id" in res) and (res["Authorization"] == "Bearer base")
        True

        X-Upstream-Authorization replaces a differently-cased base header:
        >>> with patch(__name__ + ".settings") as mock_settings:
        ...     mock_settings.enable_header_passthrough = False
        ...     res = get_passthrough_headers({"X-Upstream-Authorization": "Bearer up"}, {"authorization": "Bearer base"}, Mock())
        >>> res
        {'Authorization': 'Bearer up'}

    See comprehensive unit tests in tests/unit/mcpgateway/utils/test_passthrough_headers*.py
    for detailed examples of enabled functionality, conflict detection, and security features.

    Note:
        Header names are matched case-insensitively but preserved in their original
        case from the allowed_headers configuration. Request header values are
        matched case-insensitively against the request_headers dictionary.
    """
    passthrough_headers = base_headers.copy()

    def _put(name: str, value: str) -> None:
        """Set ``name`` on the outgoing headers, dropping any differently-cased duplicate.

        Header names are case-insensitive. Writing ``Authorization`` next to an
        existing ``authorization`` key would otherwise forward both upstream.
        """
        lowered = name.lower()
        for existing in [key for key in passthrough_headers if key != name and key.lower() == lowered]:
            del passthrough_headers[existing]
        passthrough_headers[name] = value

    # Case-insensitive view of the request headers, shared by every lookup below
    request_headers_lower = {k.lower(): v for k, v in request_headers.items()} if request_headers else {}

    # Special handling for X-Upstream-Authorization header (always enabled)
    # If gateway uses auth and client wants to pass Authorization to upstream,
    # client can use X-Upstream-Authorization which gets renamed to Authorization
    upstream_auth = request_headers_lower.get("x-upstream-authorization")

    if upstream_auth:
        try:
            sanitized_value = sanitize_header_value(upstream_auth)
            if sanitized_value:
                # Always rename X-Upstream-Authorization to Authorization for upstream
                # This works for both auth and no-auth gateways
                _put("Authorization", sanitized_value)
                logger.debug("Renamed X-Upstream-Authorization to Authorization for upstream passthrough")
        except Exception as e:
            logger.warning(f"Failed to sanitize X-Upstream-Authorization header: {e}")
    elif gateway and gateway.auth_type == "none":
        # When gateway has no auth, pass through client's Authorization if present
        client_auth = request_headers_lower.get("authorization")
        if client_auth and "authorization" not in [h.lower() for h in base_headers.keys()]:
            try:
                sanitized_value = sanitize_header_value(client_auth)
                if sanitized_value:
                    _put("Authorization", sanitized_value)
                    logger.debug("Passing through client Authorization header (auth_type=none)")
            except Exception as e:
                logger.warning(f"Failed to sanitize Authorization header: {e}")

    # Early return if header passthrough feature is disabled
    if not settings.enable_header_passthrough:
        logger.debug("Header passthrough is disabled via ENABLE_HEADER_PASSTHROUGH flag")
        return passthrough_headers

    if settings.enable_overwrite_base_headers:
        logger.debug("Overwriting base headers is enabled via ENABLE_OVERWRITE_BASE_HEADERS flag")

    # Get global passthrough headers from in-memory cache (Issue #1715)
    # This eliminates redundant DB queries for static configuration
    allowed_headers = global_config_cache.get_passthrough_headers(db, settings.default_passthrough_headers)

    # Gateway specific headers override global config (an empty list disables passthrough)
    if gateway and gateway.passthrough_headers is not None:
        allowed_headers = gateway.passthrough_headers

    # Lower-cased names already present: base headers plus any Authorization added above
    predefined_lower = {key.lower() for key in passthrough_headers}

    # Copy allowed headers from request
    if request_headers_lower and allowed_headers:
        for header_name in allowed_headers:
            # Validate header name
            if not validate_header_name(header_name):
                logger.warning(f"Invalid header name '{header_name}' - skipping (must match pattern: {HEADER_NAME_REGEX.pattern})")
                continue

            header_lower = header_name.lower()
            header_value = request_headers_lower.get(header_lower)

            if not header_value:
                logger.debug(f"Header {header_name} not found in request headers, skipping passthrough")
                continue

            # Sanitize header value
            try:
                sanitized_value = sanitize_header_value(header_value)
                if not sanitized_value:
                    logger.warning(f"Header {header_name} value became empty after sanitization - skipping")
                    continue
            except Exception as e:
                logger.warning(f"Failed to sanitize header {header_name}: {e} - skipping")
                continue

            # Skip if header would conflict with existing auth headers
            if header_lower in predefined_lower and not settings.enable_overwrite_base_headers:
                logger.warning(f"Skipping {header_name} header passthrough as it conflicts with pre-defined headers")
                continue

            # Skip if header would conflict with gateway auth
            if gateway:
                if gateway.auth_type == "basic" and header_lower == "authorization":
                    logger.warning(f"Skipping Authorization header passthrough due to basic auth configuration on gateway {gateway.name}")
                    continue
                if gateway.auth_type == "bearer" and header_lower == "authorization":
                    logger.warning(f"Skipping Authorization header passthrough due to bearer auth configuration on gateway {gateway.name}")
                    continue

            # Use original header name casing from configuration, sanitized value from request.
            # _put also removes a differently-cased base key when overwriting is enabled.
            _put(header_name, sanitized_value)
            logger.debug(f"Added passthrough header: {header_name}")

    logger.debug(f"Final passthrough headers: {list(passthrough_headers.keys())}")
    return passthrough_headers
def compute_passthrough_headers_cached(
    request_headers: Dict[str, str],
    base_headers: Dict[str, str],
    allowed_headers: List[str],
    gateway_auth_type: Optional[str] = None,
    gateway_passthrough_headers: Optional[List[str]] = None,
) -> Dict[str, str]:
    """Compute passthrough headers without database query.

    Use this when GlobalConfig has already been fetched and cached, to avoid
    repeated database queries during high-frequency operations like tool invocation.

    This function implements the same header passthrough logic as get_passthrough_headers()
    but accepts pre-fetched configuration values instead of querying the database.
    A written header replaces any existing key whose name differs only in case,
    so the result never carries two copies of the same header.

    Args:
        request_headers: Headers from the incoming HTTP request.
        base_headers: Base headers that should always be included (auth, content-type, etc.).
        allowed_headers: List of header names allowed to pass through (from GlobalConfig).
        gateway_auth_type: The gateway's auth_type (basic, bearer, oauth, none) if applicable.
        gateway_passthrough_headers: Gateway-specific passthrough headers override.

    Returns:
        Combined dictionary of base headers plus allowed passthrough headers.

    Examples:
        >>> from unittest.mock import patch
        >>> from mcpgateway.utils.passthrough_headers import compute_passthrough_headers_cached
        >>> request = {"X-Tenant-Id": "acme", "Authorization": "secret"}
        >>> base = {"Content-Type": "application/json"}
        >>> allowed = ["X-Tenant-Id"]
        >>> with patch("mcpgateway.utils.passthrough_headers.settings") as mock_settings:
        ...     mock_settings.enable_header_passthrough = True
        ...     mock_settings.enable_overwrite_base_headers = False
        ...     result = compute_passthrough_headers_cached(request, base, allowed, gateway_auth_type=None)
        >>> "X-Tenant-Id" in result
        True
        >>> result.get("Authorization") is None  # Not in allowed list
        True

        Overwriting a base header with different casing keeps a single entry:
        >>> with patch("mcpgateway.utils.passthrough_headers.settings") as mock_settings:
        ...     mock_settings.enable_header_passthrough = True
        ...     mock_settings.enable_overwrite_base_headers = True
        ...     merged = compute_passthrough_headers_cached({"Content-Type": "text/plain"}, {"content-type": "application/json"}, ["Content-Type"])
        >>> merged
        {'Content-Type': 'text/plain'}
    """
    passthrough_headers = base_headers.copy()

    def _put(name: str, value: str) -> None:
        """Set ``name`` on the outgoing headers, dropping any differently-cased duplicate.

        Header names are case-insensitive. Writing ``Authorization`` next to an
        existing ``authorization`` key would otherwise forward both upstream.
        """
        lowered = name.lower()
        for existing in [key for key in passthrough_headers if key != name and key.lower() == lowered]:
            del passthrough_headers[existing]
        passthrough_headers[name] = value

    # Special handling for X-Upstream-Authorization header (always enabled)
    request_headers_lower = {k.lower(): v for k, v in request_headers.items()} if request_headers else {}
    upstream_auth = request_headers_lower.get("x-upstream-authorization")

    if upstream_auth:
        try:
            sanitized_value = sanitize_header_value(upstream_auth)
            if sanitized_value:
                _put("Authorization", sanitized_value)
                logger.debug("Renamed X-Upstream-Authorization to Authorization for upstream passthrough")
        except Exception as e:
            logger.warning(f"Failed to sanitize X-Upstream-Authorization header: {e}")
    elif gateway_auth_type == "none":
        # When gateway has no auth, pass through client's Authorization if present
        client_auth = request_headers_lower.get("authorization")
        if client_auth and "authorization" not in [h.lower() for h in base_headers.keys()]:
            try:
                sanitized_value = sanitize_header_value(client_auth)
                if sanitized_value:
                    _put("Authorization", sanitized_value)
                    logger.debug("Passing through client Authorization header (auth_type=none)")
            except Exception as e:
                logger.warning(f"Failed to sanitize Authorization header: {e}")

    # Early return if header passthrough feature is disabled
    if not settings.enable_header_passthrough:
        logger.debug("Header passthrough is disabled via ENABLE_HEADER_PASSTHROUGH flag")
        return passthrough_headers

    # Use gateway-specific headers if provided, otherwise use global allowed_headers
    effective_allowed = gateway_passthrough_headers if gateway_passthrough_headers is not None else allowed_headers

    # Lower-cased names already present: base headers plus any Authorization added above
    predefined_lower = {key.lower() for key in passthrough_headers}

    # Copy allowed headers from request
    if request_headers_lower and effective_allowed:
        for header_name in effective_allowed:
            # Validate header name
            if not validate_header_name(header_name):
                logger.warning(f"Invalid header name '{header_name}' - skipping (must match pattern: {HEADER_NAME_REGEX.pattern})")
                continue

            header_lower = header_name.lower()
            header_value = request_headers_lower.get(header_lower)

            if not header_value:
                logger.debug(f"Header {header_name} not found in request headers, skipping passthrough")
                continue

            # Sanitize header value
            try:
                sanitized_value = sanitize_header_value(header_value)
                if not sanitized_value:
                    logger.warning(f"Header {header_name} value became empty after sanitization - skipping")
                    continue
            except Exception as e:
                logger.warning(f"Failed to sanitize header {header_name}: {e} - skipping")
                continue

            # Skip if header would conflict with existing auth headers
            if header_lower in predefined_lower and not settings.enable_overwrite_base_headers:
                logger.warning(f"Skipping {header_name} header passthrough as it conflicts with pre-defined headers")
                continue

            # Skip if header would conflict with gateway auth
            if gateway_auth_type in ("basic", "bearer") and header_lower == "authorization":
                logger.warning(f"Skipping Authorization header passthrough due to {gateway_auth_type} auth configuration")
                continue

            # Use original header name casing from configuration, sanitized value from request.
            # _put also removes a differently-cased base key when overwriting is enabled.
            _put(header_name, sanitized_value)
            logger.debug(f"Added passthrough header: {header_name}")

    logger.debug(f"Final passthrough headers (cached): {list(passthrough_headers.keys())}")
    return passthrough_headers
async def set_global_passthrough_headers(db: Session) -> None:
    """Set global passthrough headers in the database if not already configured.

    This function checks if the global passthrough headers are already set in the
    GlobalConfig table. If not, it initializes them with the default headers from
    settings.default_passthrough_headers. Invalid header names are logged and skipped.

    When PASSTHROUGH_HEADERS_SOURCE=env, this function skips database writes entirely
    since the database configuration is ignored in that mode.

    Args:
        db (Session): SQLAlchemy database session for querying and updating GlobalConfig.

    Raises:
        PassthroughHeadersError: If unable to update passthrough headers in the database.
            The underlying database exception is chained as ``__cause__``.

    Examples:
        Successful insert of default headers:
        >>> import pytest
        >>> from unittest.mock import Mock, patch
        >>> @pytest.mark.asyncio
        ... @patch("mcpgateway.utils.passthrough_headers.settings")
        ... async def test_default_headers(mock_settings):
        ...     mock_settings.enable_header_passthrough = True
        ...     mock_settings.passthrough_headers_source = "db"
        ...     mock_settings.default_passthrough_headers = ["X-Tenant-Id", "X-Trace-Id"]
        ...     mock_db = Mock()
        ...     mock_db.query.return_value.first.return_value = None
        ...     await set_global_passthrough_headers(mock_db)
        ...     mock_db.add.assert_called_once()
        ...     mock_db.commit.assert_called_once()

        Database write failure:
        >>> import pytest
        >>> from unittest.mock import Mock, patch
        >>> from mcpgateway.utils.passthrough_headers import PassthroughHeadersError
        >>> @pytest.mark.asyncio
        ... @patch("mcpgateway.utils.passthrough_headers.settings")
        ... async def test_db_write_failure(mock_settings):
        ...     mock_settings.enable_header_passthrough = True
        ...     mock_settings.passthrough_headers_source = "db"
        ...     mock_db = Mock()
        ...     mock_db.query.return_value.first.return_value = None
        ...     mock_db.commit.side_effect = Exception("DB write failed")
        ...     with pytest.raises(PassthroughHeadersError):
        ...         await set_global_passthrough_headers(mock_db)
        ...     mock_db.rollback.assert_called_once()

        Config already exists (no DB write):
        >>> import pytest
        >>> from unittest.mock import Mock, patch
        >>> from mcpgateway.common.models import GlobalConfig
        >>> @pytest.mark.asyncio
        ... @patch("mcpgateway.utils.passthrough_headers.settings")
        ... async def test_existing_config(mock_settings):
        ...     mock_settings.enable_header_passthrough = True
        ...     mock_settings.passthrough_headers_source = "db"
        ...     mock_db = Mock()
        ...     existing = Mock(spec=GlobalConfig)
        ...     existing.passthrough_headers = ["X-Tenant-ID", "Authorization"]
        ...     mock_db.query.return_value.first.return_value = existing
        ...     await set_global_passthrough_headers(mock_db)
        ...     mock_db.add.assert_not_called()
        ...     mock_db.commit.assert_not_called()
        ...     assert existing.passthrough_headers == ["X-Tenant-ID", "Authorization"]

        Env mode skips DB entirely:
        >>> import pytest
        >>> from unittest.mock import Mock, patch
        >>> @pytest.mark.asyncio
        ... @patch("mcpgateway.utils.passthrough_headers.settings")
        ... async def test_env_mode_skips_db(mock_settings):
        ...     mock_settings.passthrough_headers_source = "env"
        ...     mock_db = Mock()
        ...     await set_global_passthrough_headers(mock_db)
        ...     mock_db.query.assert_not_called()
        ...     mock_db.add.assert_not_called()

    Note:
        This function is typically called during application startup to ensure
        global configuration is in place before any gateway operations.
    """
    # When source is "env", skip DB operations entirely - env vars always win
    if settings.passthrough_headers_source == "env":
        logger.debug("Passthrough headers source=env: skipping database initialization (env vars always used)")
        return

    # Query DB directly here (not cache) because we need to check if config exists
    # to decide whether to create it
    global_config = db.query(GlobalConfig).first()
    if global_config:
        # An existing configuration is left untouched
        return

    allowed_headers: List[str] = []
    for header_name in settings.default_passthrough_headers or []:
        # Validate header name
        if not validate_header_name(header_name):
            logger.warning(f"Invalid header name '{header_name}' - skipping (must match pattern: {HEADER_NAME_REGEX.pattern})")
            continue
        allowed_headers.append(header_name)

    try:
        db.add(GlobalConfig(passthrough_headers=allowed_headers))
        db.commit()
        # Invalidate cache so next read picks up new config (Issue #1715)
        global_config_cache.invalidate()
    except Exception as e:
        db.rollback()
        # Chain the original error so the root cause stays visible in tracebacks
        raise PassthroughHeadersError(f"Failed to update passthrough headers: {e}") from e