Coverage for mcpgateway / utils / passthrough_headers.py: 99%

258 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/utils/passthrough_headers.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7HTTP Header Passthrough Utilities. 

8This module provides utilities for handling HTTP header passthrough functionality 

9in ContextForge. It enables forwarding of specific headers from incoming 

10client requests to backing MCP servers while preventing conflicts with 

11existing authentication mechanisms. 

12 

13Key Features: 

14- Global configuration support via environment variables and database 

15- Per-gateway header configuration overrides 

16- Intelligent conflict detection with existing authentication headers 

17- Security-first approach with explicit allowlist handling 

18- Comprehensive logging for debugging and monitoring 

19- Header validation and sanitization 

20 

21The header passthrough system follows a priority hierarchy: 

221. Gateway-specific headers (highest priority) 

232. Global database configuration 

243. Environment variable defaults (lowest priority) 

25 

26Example Usage: 

27 See comprehensive unit tests in tests/unit/mcpgateway/utils/test_passthrough_headers*.py 

28 for detailed examples of header passthrough functionality. 

29""" 

30 

31# Standard 

32import logging 

33import re 

34import threading 

35import time 

36from typing import Dict, List, Optional 

37 

38# Third-Party 

39from sqlalchemy.orm import Session 

40 

41# First-Party 

42from mcpgateway.cache.global_config_cache import global_config_cache 

43from mcpgateway.config import settings 

44from mcpgateway.db import Gateway as DbGateway 

45from mcpgateway.db import GlobalConfig 

46 

logger = logging.getLogger(__name__)

# Allowed header-name alphabet: ASCII letters, digits and "-" only
# (so e.g. "Bad@Name" or "Invalid Header:Name" do not match).
HEADER_NAME_REGEX = re.compile(r"^[A-Za-z0-9\-]+$")

# Upper bound, in characters, for a forwarded header value (4 KiB).
MAX_HEADER_VALUE_LENGTH = 4 * 1024

54 

55 

class PassthroughHeadersError(Exception):
    """Base exception for passthrough-header failures.

    Raised, for example, when the global passthrough allowlist cannot be
    written to the database.

    Examples:
        >>> err = PassthroughHeadersError("Test error")
        >>> str(err)
        'Test error'
        >>> isinstance(err, Exception)
        True
    """

66 

67 

def sanitize_header_value(value: str, max_length: int = MAX_HEADER_VALUE_LENGTH) -> str:
    """Sanitize a header value before it is forwarded upstream.

    Every ASCII control character except horizontal tab is removed. This
    covers CR/LF (header injection) as well as NUL and DEL (0x7F), none of
    which are valid in an HTTP field value. The cleaned value is then
    truncated to ``max_length`` characters and stripped of surrounding
    whitespace.

    Args:
        value: Header value to sanitize.
        max_length: Maximum number of characters to keep. Negative values are
            treated as 0.

    Returns:
        Sanitized header value, at most ``max_length`` characters long.

    Examples:
        Remove CRLF and trim length:
        >>> s = sanitize_header_value('val' + chr(13) + chr(10) + 'more', max_length=6)
        >>> s
        'valmor'
        >>> len(s) <= 6
        True
        >>> sanitize_header_value(' spaced ')
        'spaced'

        Other control characters (NUL, DEL) are removed before truncation:
        >>> sanitize_header_value('val' + chr(0) + chr(127) + 'more', max_length=6)
        'valmor'

        Tabs inside the value are preserved:
        >>> sanitize_header_value('a' + chr(9) + 'b') == 'a' + chr(9) + 'b'
        True
    """
    # Drop CTLs (0x00-0x1F except TAB) and DEL in one pass. Doing this before
    # truncation means removed characters never count towards the limit.
    cleaned = "".join(ch for ch in value if ch == "\t" or (" " <= ch and ch != "\x7f"))

    # Clamp so a negative limit cannot turn the slice into "drop from the end".
    return cleaned[: max(max_length, 0)].strip()

100 

101 

def validate_header_name(name: str) -> bool:
    """Validate a header name against ``HEADER_NAME_REGEX``.

    The entire string must match. ``re.match`` with a ``$``-anchored pattern
    would also accept a name that ends in a single newline character, which
    could smuggle a line break into an outgoing header name; ``fullmatch``
    rejects it.

    Args:
        name: Header name to validate.

    Returns:
        True if ``name`` is non-empty and consists solely of ASCII letters,
        digits and hyphens, False otherwise.

    Examples:
        Valid names:
        >>> validate_header_name('X-Tenant-Id')
        True
        >>> validate_header_name('X123-ABC')
        True

        Invalid names:
        >>> validate_header_name('Invalid Header:Name')
        False
        >>> validate_header_name('Bad@Name')
        False
        >>> validate_header_name('')
        False

        A trailing newline is rejected:
        >>> validate_header_name('X-Tenant-Id' + chr(10))
        False
    """
    # fullmatch (not match): "$" on its own still matches before a trailing "\n".
    return HEADER_NAME_REGEX.fullmatch(name) is not None

125 

126 

def get_passthrough_headers(request_headers: Dict[str, str], base_headers: Dict[str, str], db: Session, gateway: Optional[DbGateway] = None) -> Dict[str, str]:
    """Get headers that should be passed through to the target gateway.

    This function implements the core logic for HTTP header passthrough in ContextForge.
    It determines which headers from incoming client requests should be forwarded to
    backing MCP servers based on configuration settings and security policies.

    Configuration Priority (highest to lowest):
    1. Gateway-specific passthrough_headers setting
    2. Global headers from get_passthrough_headers() based on PASSTHROUGH_HEADERS_SOURCE:
       - "db": Database wins if configured, env var DEFAULT_PASSTHROUGH_HEADERS as fallback
       - "env": Environment variable always wins, database ignored
       - "merge": Union of both sources (DB casing wins for duplicates)

    Security Features:
    - Feature flag control (disabled by default)
    - Prevents conflicts with existing base headers (e.g., Content-Type)
    - Blocks Authorization header conflicts with gateway authentication
    - Header name validation (regex pattern matching)
    - Header value sanitization (removes dangerous characters, enforces limits)
    - Logs all conflicts and skipped headers for debugging
    - Uses case-insensitive header matching for robustness
    - Special X-Upstream-Authorization handling: When gateway uses auth, clients can
      send X-Upstream-Authorization header which gets renamed to Authorization for upstream
    - Never emits two keys for the same header that differ only in casing
      (httpx would concatenate their values)

    Args:
        request_headers (Dict[str, str]): Headers from the incoming HTTP request.
            Keys should be header names, values should be header values.
            Example: {"Authorization": "Bearer token123", "X-Tenant-Id": "acme"}
        base_headers (Dict[str, str]): Base headers that should always be included
            in the final result. These take precedence over passthrough headers.
            Example: {"Content-Type": "application/json", "User-Agent": "MCPGateway/1.0"}
        db (Session): SQLAlchemy database session, handed to global_config_cache to
            resolve the global passthrough allowlist.
        gateway (Optional[DbGateway]): Target gateway instance. If provided, uses
            gateway.passthrough_headers to override global settings. Also checks
            gateway.auth_type to prevent Authorization header conflicts.

    Returns:
        Dict[str, str]: Combined dictionary of base headers plus allowed passthrough
            headers from the request. Base headers are preserved, and passthrough
            headers are added only if they don't conflict with security policies.

    Raises:
        Header processing itself raises nothing: sanitization failures are logged
        as warnings and the affected header is skipped. Errors from the global
        config lookup (global_config_cache.get_passthrough_headers(), which
        receives ``db``) may propagate.

    Examples:
        Feature disabled by default (secure by default):
        >>> from unittest.mock import Mock, patch
        >>> from mcpgateway.cache.global_config_cache import global_config_cache
        >>> global_config_cache.invalidate()  # Clear cache for isolated test
        >>> with patch(__name__ + ".settings") as mock_settings:
        ...     mock_settings.enable_header_passthrough = False
        ...     mock_settings.default_passthrough_headers = ["X-Tenant-Id"]
        ...     mock_db = Mock()
        ...     mock_db.query.return_value.first.return_value = None
        ...     request_headers = {"x-tenant-id": "should-be-ignored"}
        ...     base_headers = {"Content-Type": "application/json"}
        ...     get_passthrough_headers(request_headers, base_headers, mock_db)
        {'Content-Type': 'application/json'}

        Enabled with allowlist and conflicts:
        >>> global_config_cache.invalidate()  # Clear cache for isolated test
        >>> with patch(__name__ + ".settings") as mock_settings:
        ...     mock_settings.enable_header_passthrough = True
        ...     mock_settings.default_passthrough_headers = ["X-Tenant-Id", "Authorization"]
        ...     # Mock DB returns no global override
        ...     mock_db = Mock()
        ...     mock_db.query.return_value.first.return_value = None
        ...     # Gateway with basic auth should block Authorization passthrough
        ...     gateway = Mock()
        ...     gateway.passthrough_headers = None
        ...     gateway.auth_type = "basic"
        ...     gateway.name = "gw1"
        ...     req_headers = {"X-Tenant-Id": "acme", "Authorization": "Bearer abc"}
        ...     base = {"Content-Type": "application/json", "Authorization": "Bearer base"}
        ...     res = get_passthrough_headers(req_headers, base, mock_db, gateway)
        ...     ("X-Tenant-Id" in res) and (res["Authorization"] == "Bearer base")
        True

        See comprehensive unit tests in tests/unit/mcpgateway/utils/test_passthrough_headers*.py
        for detailed examples of enabled functionality, conflict detection, and security features.

    Note:
        Allowlist header names are looked up in ``request_headers``
        case-insensitively; the emitted key uses the casing from the allowlist
        configuration. If the allowlist names the same header more than once
        with different casing, only the first entry is used.
    """
    passthrough_headers = base_headers.copy()

    # Case-insensitive view of the request, built once and reused below.
    request_headers_lower = {k.lower(): v for k, v in request_headers.items()} if request_headers else {}

    # Special handling for X-Upstream-Authorization header (always enabled).
    # If gateway uses auth and client wants to pass Authorization to upstream,
    # client can use X-Upstream-Authorization which gets renamed to Authorization.
    upstream_auth = request_headers_lower.get("x-upstream-authorization")

    if upstream_auth:
        try:
            sanitized_value = sanitize_header_value(upstream_auth)
            if sanitized_value:
                # Always rename X-Upstream-Authorization to Authorization for upstream
                # (works for both auth and no-auth gateways). Remove any
                # differently-cased Authorization key inherited from base_headers
                # first so only one Authorization key is sent.
                for existing_key in [k for k in passthrough_headers if k.lower() == "authorization"]:
                    del passthrough_headers[existing_key]
                passthrough_headers["Authorization"] = sanitized_value
                logger.debug("Renamed X-Upstream-Authorization to Authorization for upstream passthrough")
        except Exception as e:
            logger.warning(f"Failed to sanitize X-Upstream-Authorization header: {e}")
    elif gateway and gateway.auth_type == "none":
        # When gateway has no auth, pass through client's Authorization if present
        client_auth = request_headers_lower.get("authorization")
        if client_auth and not any(h.lower() == "authorization" for h in base_headers):
            try:
                sanitized_value = sanitize_header_value(client_auth)
                if sanitized_value:
                    passthrough_headers["Authorization"] = sanitized_value
                    logger.debug("Passing through client Authorization header (auth_type=none)")
            except Exception as e:
                logger.warning(f"Failed to sanitize Authorization header: {e}")

    # Early return if header passthrough feature is disabled
    if not settings.enable_header_passthrough:
        logger.debug("Header passthrough is disabled via ENABLE_HEADER_PASSTHROUGH flag")
        return passthrough_headers

    if settings.enable_overwrite_base_headers:
        logger.debug("Overwriting base headers is enabled via ENABLE_OVERWRITE_BASE_HEADERS flag")

    # Get global passthrough headers from in-memory cache (Issue #1715).
    # This eliminates redundant DB queries for static configuration.
    allowed_headers = global_config_cache.get_passthrough_headers(db, settings.default_passthrough_headers)

    # Gateway specific headers override global config
    if gateway and gateway.passthrough_headers is not None:
        allowed_headers = gateway.passthrough_headers

    # Lowercased name -> actual key for everything already present (base headers
    # plus any Authorization set above). Used to detect conflicts and, when
    # overwriting is enabled, to replace the existing key instead of duplicating it.
    base_headers_keys = {key.lower(): key for key in passthrough_headers}
    # Lowercased allowlist names already handled, so case-variant duplicates
    # in the allowlist cannot emit two keys.
    seen_headers: set[str] = set()

    # Copy allowed headers from request
    if request_headers_lower and allowed_headers:
        for header_name in allowed_headers:
            # Validate header name
            if not validate_header_name(header_name):
                logger.warning(f"Invalid header name '{header_name}' - skipping (must match pattern: {HEADER_NAME_REGEX.pattern})")
                continue

            header_lower = header_name.lower()
            if header_lower in seen_headers:
                logger.debug(f"Duplicate passthrough header {header_name} in allowlist - skipping")
                continue
            seen_headers.add(header_lower)

            header_value = request_headers_lower.get(header_lower)
            if not header_value:
                logger.debug(f"Header {header_name} not found in request headers, skipping passthrough")
                continue

            # Sanitize header value
            try:
                sanitized_value = sanitize_header_value(header_value)
            except Exception as e:
                logger.warning(f"Failed to sanitize header {header_name}: {e} - skipping")
                continue
            if not sanitized_value:
                logger.warning(f"Header {header_name} value became empty after sanitization - skipping")
                continue

            # Skip if header would conflict with existing base/auth headers
            if header_lower in base_headers_keys and not settings.enable_overwrite_base_headers:
                logger.warning(f"Skipping {header_name} header passthrough as it conflicts with pre-defined headers")
                continue

            # Skip if header would conflict with gateway auth
            if gateway and header_lower == "authorization" and gateway.auth_type in ("basic", "bearer"):
                logger.warning(f"Skipping Authorization header passthrough due to {gateway.auth_type} auth configuration on gateway {gateway.name}")
                continue

            # Overwrite enabled: drop a differently-cased base key so the header is sent once.
            existing_key = base_headers_keys.get(header_lower)
            if existing_key is not None and existing_key != header_name:
                del passthrough_headers[existing_key]

            # Use original header name casing from configuration, sanitized value from request
            passthrough_headers[header_name] = sanitized_value
            logger.debug(f"Added passthrough header: {header_name}")

    logger.debug(f"Final passthrough headers: {list(passthrough_headers.keys())}")
    return passthrough_headers

313 

314 

def compute_passthrough_headers_cached(
    request_headers: Dict[str, str],
    base_headers: Dict[str, str],
    allowed_headers: List[str],
    gateway_auth_type: Optional[str] = None,
    gateway_passthrough_headers: Optional[List[str]] = None,
) -> Dict[str, str]:
    """Compute passthrough headers without database query.

    Use this when GlobalConfig has already been fetched and cached, to avoid
    repeated database queries during high-frequency operations like tool invocation.

    This function implements the same header passthrough logic as get_passthrough_headers()
    but accepts pre-fetched configuration values instead of querying the database.

    Duplicate-key handling (httpx concatenates keys that differ only in casing):
        - X-Upstream-Authorization replaces any base Authorization header,
          whatever its casing, with a single ``Authorization`` key.
        - With ENABLE_OVERWRITE_BASE_HEADERS on, a passthrough header replaces
          the matching base header even when the casing differs.
        - Allowlist entries that differ only in casing are processed once
          (first entry wins).

    Args:
        request_headers: Headers from the incoming HTTP request.
        base_headers: Base headers that should always be included (auth, content-type, etc.).
        allowed_headers: List of header names allowed to pass through (from GlobalConfig).
        gateway_auth_type: The gateway's auth_type (basic, bearer, oauth, none) if applicable.
        gateway_passthrough_headers: Gateway-specific passthrough headers override.

    Returns:
        Combined dictionary of base headers plus allowed passthrough headers.

    Examples:
        >>> from unittest.mock import patch
        >>> from mcpgateway.utils.passthrough_headers import compute_passthrough_headers_cached
        >>> request = {"X-Tenant-Id": "acme", "Authorization": "secret"}
        >>> base = {"Content-Type": "application/json"}
        >>> allowed = ["X-Tenant-Id"]
        >>> with patch("mcpgateway.utils.passthrough_headers.settings") as mock_settings:
        ...     mock_settings.enable_header_passthrough = True
        ...     mock_settings.enable_overwrite_base_headers = False
        ...     result = compute_passthrough_headers_cached(request, base, allowed, gateway_auth_type=None)
        >>> "X-Tenant-Id" in result
        True
        >>> result.get("Authorization") is None  # Not in allowed list
        True

        Overwriting a base header keeps a single key:
        >>> with patch("mcpgateway.utils.passthrough_headers.settings") as mock_settings:
        ...     mock_settings.enable_header_passthrough = True
        ...     mock_settings.enable_overwrite_base_headers = True
        ...     compute_passthrough_headers_cached({"x-trace-id": "t1"}, {"x-trace-id": "base"}, ["X-Trace-Id"])
        {'X-Trace-Id': 't1'}

        X-Upstream-Authorization replaces a base Authorization of any casing:
        >>> with patch("mcpgateway.utils.passthrough_headers.settings") as mock_settings:
        ...     mock_settings.enable_header_passthrough = False
        ...     compute_passthrough_headers_cached({"X-Upstream-Authorization": "Bearer up"}, {"authorization": "Bearer gw"}, [])
        {'Authorization': 'Bearer up'}
    """
    passthrough_headers = base_headers.copy()

    # Case-insensitive view of the request.
    request_headers_lower = {k.lower(): v for k, v in request_headers.items()} if request_headers else {}

    # Special handling for X-Upstream-Authorization header (always enabled)
    upstream_auth = request_headers_lower.get("x-upstream-authorization")

    if upstream_auth:
        try:
            sanitized_value = sanitize_header_value(upstream_auth)
            if sanitized_value:
                # Remove any differently-cased Authorization key from the base
                # headers so exactly one Authorization key is sent upstream.
                for existing_key in [k for k in passthrough_headers if k.lower() == "authorization"]:
                    del passthrough_headers[existing_key]
                passthrough_headers["Authorization"] = sanitized_value
                logger.debug("Renamed X-Upstream-Authorization to Authorization for upstream passthrough")
        except Exception as e:
            logger.warning(f"Failed to sanitize X-Upstream-Authorization header: {e}")
    elif gateway_auth_type == "none":
        # When gateway has no auth, pass through client's Authorization if present
        client_auth = request_headers_lower.get("authorization")
        if client_auth and not any(h.lower() == "authorization" for h in base_headers):
            try:
                sanitized_value = sanitize_header_value(client_auth)
                if sanitized_value:
                    passthrough_headers["Authorization"] = sanitized_value
                    logger.debug("Passing through client Authorization header (auth_type=none)")
            except Exception as e:
                logger.warning(f"Failed to sanitize Authorization header: {e}")

    # Early return if header passthrough feature is disabled
    if not settings.enable_header_passthrough:
        logger.debug("Header passthrough is disabled via ENABLE_HEADER_PASSTHROUGH flag")
        return passthrough_headers

    # Use gateway-specific headers if provided, otherwise use global allowed_headers
    effective_allowed = gateway_passthrough_headers if gateway_passthrough_headers is not None else allowed_headers

    # Lowercased name -> actual key for everything already present; used for
    # conflict detection and for in-place replacement when overwriting.
    base_headers_keys = {key.lower(): key for key in passthrough_headers}
    # Lowercased allowlist names already handled (first casing wins).
    seen_headers: set[str] = set()

    # Copy allowed headers from request
    if request_headers_lower and effective_allowed:
        for header_name in effective_allowed:
            # Validate header name
            if not validate_header_name(header_name):
                logger.warning(f"Invalid header name '{header_name}' - skipping (must match pattern: {HEADER_NAME_REGEX.pattern})")
                continue

            header_lower = header_name.lower()
            if header_lower in seen_headers:
                logger.debug(f"Duplicate passthrough header {header_name} in allowlist - skipping")
                continue
            seen_headers.add(header_lower)

            header_value = request_headers_lower.get(header_lower)
            if not header_value:
                logger.debug(f"Header {header_name} not found in request headers, skipping passthrough")
                continue

            # Sanitize header value
            try:
                sanitized_value = sanitize_header_value(header_value)
            except Exception as e:
                logger.warning(f"Failed to sanitize header {header_name}: {e} - skipping")
                continue
            if not sanitized_value:
                logger.warning(f"Header {header_name} value became empty after sanitization - skipping")
                continue

            # Skip if header would conflict with existing base/auth headers
            if header_lower in base_headers_keys and not settings.enable_overwrite_base_headers:
                logger.warning(f"Skipping {header_name} header passthrough as it conflicts with pre-defined headers")
                continue

            # Skip if header would conflict with gateway auth
            if gateway_auth_type in ("basic", "bearer") and header_lower == "authorization":
                logger.warning(f"Skipping Authorization header passthrough due to {gateway_auth_type} auth configuration")
                continue

            # Overwrite enabled: drop a differently-cased base key so the header is sent once.
            existing_key = base_headers_keys.get(header_lower)
            if existing_key is not None and existing_key != header_name:
                del passthrough_headers[existing_key]

            # Use original header name casing from configuration, sanitized value from request
            passthrough_headers[header_name] = sanitized_value
            logger.debug(f"Added passthrough header: {header_name}")

    logger.debug(f"Final passthrough headers (cached): {list(passthrough_headers.keys())}")
    return passthrough_headers

432 

433 

async def set_global_passthrough_headers(db: Session) -> None:
    """Set global passthrough headers in the database if not already configured.

    This function checks if the global passthrough headers are already set in the
    GlobalConfig table. If not, it initializes them with the default headers from
    settings.default_passthrough_headers (invalid header names are skipped with a
    warning). An existing GlobalConfig row is never modified.

    When PASSTHROUGH_HEADERS_SOURCE=env, this function skips database writes entirely
    since the database configuration is ignored in that mode.

    After a successful commit, both the global config cache and the loopback
    allowlist cache are invalidated so the new configuration is read on next use.

    Args:
        db (Session): SQLAlchemy database session for querying and updating GlobalConfig.

    Raises:
        PassthroughHeadersError: If unable to update passthrough headers in the database.
            The original exception is chained as ``__cause__``.

    Examples:
        Successful insert of default headers:
        >>> import pytest
        >>> from unittest.mock import Mock, patch
        >>> @pytest.mark.asyncio
        ... @patch("mcpgateway.utils.passthrough_headers.settings")
        ... async def test_default_headers(mock_settings):
        ...     mock_settings.enable_header_passthrough = True
        ...     mock_settings.passthrough_headers_source = "db"
        ...     mock_settings.default_passthrough_headers = ["X-Tenant-Id", "X-Trace-Id"]
        ...     mock_db = Mock()
        ...     mock_db.query.return_value.first.return_value = None
        ...     await set_global_passthrough_headers(mock_db)
        ...     mock_db.add.assert_called_once()
        ...     mock_db.commit.assert_called_once()

        Database write failure:
        >>> import pytest
        >>> from unittest.mock import Mock, patch
        >>> from mcpgateway.utils.passthrough_headers import PassthroughHeadersError
        >>> @pytest.mark.asyncio
        ... @patch("mcpgateway.utils.passthrough_headers.settings")
        ... async def test_db_write_failure(mock_settings):
        ...     mock_settings.enable_header_passthrough = True
        ...     mock_settings.passthrough_headers_source = "db"
        ...     mock_db = Mock()
        ...     mock_db.query.return_value.first.return_value = None
        ...     mock_db.commit.side_effect = Exception("DB write failed")
        ...     with pytest.raises(PassthroughHeadersError):
        ...         await set_global_passthrough_headers(mock_db)
        ...     mock_db.rollback.assert_called_once()

        Config already exists (no DB write):
        >>> import pytest
        >>> from unittest.mock import Mock, patch
        >>> from mcpgateway.common.models import GlobalConfig
        >>> @pytest.mark.asyncio
        ... @patch("mcpgateway.utils.passthrough_headers.settings")
        ... async def test_existing_config(mock_settings):
        ...     mock_settings.enable_header_passthrough = True
        ...     mock_settings.passthrough_headers_source = "db"
        ...     mock_db = Mock()
        ...     existing = Mock(spec=GlobalConfig)
        ...     existing.passthrough_headers = ["X-Tenant-ID", "Authorization"]
        ...     mock_db.query.return_value.first.return_value = existing
        ...     await set_global_passthrough_headers(mock_db)
        ...     mock_db.add.assert_not_called()
        ...     mock_db.commit.assert_not_called()
        ...     assert existing.passthrough_headers == ["X-Tenant-ID", "Authorization"]

        Env mode skips DB entirely:
        >>> import pytest
        >>> from unittest.mock import Mock, patch
        >>> @pytest.mark.asyncio
        ... @patch("mcpgateway.utils.passthrough_headers.settings")
        ... async def test_env_mode_skips_db(mock_settings):
        ...     mock_settings.passthrough_headers_source = "env"
        ...     mock_db = Mock()
        ...     await set_global_passthrough_headers(mock_db)
        ...     mock_db.query.assert_not_called()
        ...     mock_db.add.assert_not_called()

    Note:
        This function is typically called during application startup to ensure
        global configuration is in place before any gateway operations.
    """
    # When source is "env", skip DB operations entirely - env vars always win
    if settings.passthrough_headers_source == "env":
        logger.debug("Passthrough headers source=env: skipping database initialization (env vars always used)")
        return

    # Query DB directly here (not cache) because we need to check if config exists
    # to decide whether to create it
    global_config = db.query(GlobalConfig).first()
    if global_config:
        # Existing configuration is left untouched.
        return

    allowed_headers = []
    for header_name in settings.default_passthrough_headers or []:
        # Validate header name
        if not validate_header_name(header_name):
            logger.warning(f"Invalid header name '{header_name}' - skipping (must match pattern: {HEADER_NAME_REGEX.pattern})")
            continue
        allowed_headers.append(header_name)

    # Only the DB write belongs in the try: a failure here is a failed update.
    try:
        db.add(GlobalConfig(passthrough_headers=allowed_headers))
        db.commit()
    except Exception as e:
        db.rollback()
        raise PassthroughHeadersError(f"Failed to update passthrough headers: {str(e)}") from e

    # Invalidate both global and loopback caches so next read picks up new config (Issue #1715, #3640)
    invalidate_passthrough_header_caches()

544 

545 

# Lowercase header names that loopback forwarding must never copy from the
# client: the caller supplies its own values for these, or they are
# gateway-internal routing / loop-prevention headers.
# IMPORTANT: keep this in sync with the internal headers injected at the merge
# sites (session_registry generate_response, WebSocket relay, Streamable HTTP
# affinity). httpx joins the values of keys that differ only by case instead of
# picking one, so a missing entry here could silently corrupt an internal header.
_LOOPBACK_SKIP_HEADERS: frozenset[str] = frozenset(
    (
        # Set by the loopback caller itself
        "authorization",
        "content-type",
        "content-length",
        "host",
        # Connection-management (hop-by-hop) headers
        "connection",
        "keep-alive",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
        # Gateway-internal session / routing / loop-prevention headers
        "mcp-session-id",
        "x-mcp-session-id",
        "x-forwarded-internally",
    )
)

570 

571 

def _loopback_skip_set() -> frozenset[str]:
    """Build the complete set of header names skipped by loopback forwarding.

    This is ``_LOOPBACK_SKIP_HEADERS`` plus the lowercased, configurable
    ``settings.proxy_user_header`` (``X-Authenticated-User`` by default), so the
    gateway-internal proxy-user identity header can never be overwritten by a
    passthrough header, even if it is added to the passthrough allowlist by
    mistake.

    Returns:
        frozenset[str]: Lowercased header names to skip during loopback forwarding.
    """
    proxy_header = settings.proxy_user_header.lower()
    # Reuse the constant as-is when the proxy header is already covered.
    return _LOOPBACK_SKIP_HEADERS if proxy_header in _LOOPBACK_SKIP_HEADERS else _LOOPBACK_SKIP_HEADERS.union((proxy_header,))

588 

589 

class _LoopbackAllowlistCache:
    """TTL cache for the merged passthrough header allowlist (global + all gateways).

    Avoids a full Gateway table scan on every loopback call by caching the union
    of global and gateway-specific passthrough headers with the same 60 s TTL used
    by global_config_cache.

    Expiry is measured with ``time.monotonic()`` so wall-clock adjustments (NTP
    steps, manual clock changes) cannot stretch or collapse the TTL.
    """

    def __init__(self, ttl_seconds: float = 60.0):
        """Create an empty cache.

        Args:
            ttl_seconds: How long, in seconds, a successful refresh stays valid.
        """
        self._cache: Optional[frozenset[str]] = None
        # True once a refresh has succeeded at least once; guards reads of _cache.
        self._populated: bool = False
        # time.monotonic() deadline after which the next get() refreshes.
        self._expiry: float = 0
        self._ttl = ttl_seconds
        self._lock = threading.Lock()

    def get(self, db: Session) -> frozenset[str]:
        """Return the cached merged allowlist, refreshing from DB when expired.

        Falls back to the last known good value during transient DB failures to
        avoid a thundering-herd of failing queries on every loopback call.

        Args:
            db: SQLAlchemy database session for querying gateway configurations.

        Returns:
            Frozen set of allowed passthrough header names (union of global and
            all gateway-specific configurations).

        Raises:
            Exception: Re-raised from DB query when no stale cache is available
                to fall back to (first call after startup with a broken DB).
        """
        now = time.monotonic()
        # CPython GIL ensures atomic attribute reads on the fast path.
        if now < self._expiry and self._populated:
            return self._cache  # type: ignore[return-value] # _populated guarantees non-None
        with self._lock:
            # Re-read the clock: we may have waited on the lock while another
            # thread refreshed the cache.
            now = time.monotonic()
            if now < self._expiry and self._populated:
                return self._cache  # type: ignore[return-value] # _populated guarantees non-None
            try:
                merged: set[str] = set(global_config_cache.get_passthrough_headers(db, settings.default_passthrough_headers or []) or [])
                gw_rows = db.query(DbGateway.passthrough_headers).filter(DbGateway.passthrough_headers.isnot(None)).all()
                for (gw_headers,) in gw_rows:
                    if gw_headers:
                        merged.update(gw_headers)
                self._cache = frozenset(merged)
                self._populated = True
                self._expiry = now + self._ttl
            except Exception:
                logger.warning("Failed to refresh loopback allowlist cache from DB; using stale value if available", exc_info=True)
                if self._populated and self._cache is not None:
                    # Extend TTL briefly to avoid hammering DB on every request
                    self._expiry = now + min(self._ttl, 10.0)
                else:
                    raise
            return self._cache  # type: ignore[return-value] # _populated guarantees non-None

    def invalidate(self) -> None:
        """Force a refresh on next access."""
        with self._lock:
            self._populated = False
            self._expiry = 0


# Module-level singleton; reset by invalidate_passthrough_header_caches().
_loopback_allowlist_cache = _LoopbackAllowlistCache()

655 

656 

def invalidate_passthrough_header_caches() -> None:
    """Drop every cached view of the passthrough header configuration.

    Clears ``global_config_cache`` and then the loopback allowlist cache. Invoke
    it after changing the global or any per-gateway passthrough configuration so
    that loopback transports (SSE, WebSocket, Streamable HTTP) see the change on
    their next call, in step with direct /rpc, instead of after the TTL lapses.
    """
    for cache in (global_config_cache, _loopback_allowlist_cache):
        cache.invalidate()
    logger.debug("Invalidated global_config_cache and _loopback_allowlist_cache for passthrough headers")

667 

668 

def filter_loopback_skip_headers(headers: Dict[str, str]) -> Dict[str, str]:
    """Return a copy of *headers* with gateway-internal loopback headers removed.

    Defense-in-depth filter applied at loopback merge sites (SSE generate_response,
    WebSocket relay) to ensure passthrough headers can never override the gateway's
    internal JWT, content-type, proxy-user, or session/routing headers — even if
    ``extract_headers_for_loopback`` is bypassed or its skip-list is out of sync.

    To be safe on input that has not been pre-validated:

    - header names failing ``validate_header_name()`` (anything other than ASCII
      letters, digits and hyphens, e.g. a name containing CR/LF) are dropped;
    - values are re-sanitized via ``sanitize_header_value()``; a value that
      cannot be sanitized drops its header.

    Args:
        headers: Candidate passthrough headers to filter. ``None`` or empty
            input yields an empty dictionary.

    Returns:
        New dictionary containing only headers whose lowercased names are
        **not** in the skip set (``_LOOPBACK_SKIP_HEADERS`` plus the
        configurable ``proxy_user_header``) and whose names are valid, with
        values sanitized.
    """
    filtered: Dict[str, str] = {}
    if not headers:
        return filtered

    skip = _loopback_skip_set()
    for name, value in headers.items():
        if name.lower() in skip:
            continue
        # Values are sanitized below, but names are emitted verbatim, so they
        # must be validated too.
        if not validate_header_name(name):
            logger.warning("Dropped header with invalid name %r during loopback filter", name)
            continue
        try:
            filtered[name] = sanitize_header_value(value)
        except Exception:
            logger.warning("Dropped unsafe header %s during loopback filter", name, exc_info=True)
    return filtered

698 

699 

def extract_headers_for_loopback(request_headers: Dict[str, str]) -> Dict[str, str]:
    """Collect the client headers that must ride along on internal loopback /rpc calls.

    The SSE and WebSocket transports reach /rpc through an internal loopback HTTP
    request. For /rpc to hand client passthrough headers (such as
    X-Upstream-Authorization) to upstream MCP servers via get_passthrough_headers(),
    those headers must first be copied onto the loopback request. This function
    selects them.

    Selection rules:

    - ``x-upstream-authorization`` is always selected, regardless of
      ENABLE_HEADER_PASSTHROUGH (it is renamed to Authorization upstream).
    - When ``settings.enable_header_passthrough`` is true, every header named in the
      cached loopback allowlist is also selected. That allowlist is the union of the
      global allowlist from global_config_cache.get_passthrough_headers()
      (PASSTHROUGH_HEADERS_SOURCE priority: env, db, merge) and every Gateway's
      passthrough_headers. It is cached with a 60 s TTL, so the gateway table scan
      runs at most once per TTL window rather than per request.
    - Allowlist entries whose lowercased name is in the loopback skip set
      (_LOOPBACK_SKIP_HEADERS: authorization, content-type and gateway-internal
      routing/session headers) are never selected.

    Every selected value goes through sanitize_header_value() as defense-in-depth,
    even though /rpc re-sanitizes via get_passthrough_headers(). A value that fails
    sanitization is dropped with a warning rather than forwarded raw. If reading the
    allowlist fails, a warning is logged and the headers gathered so far are returned.

    Args:
        request_headers: Headers from the incoming client HTTP request or WebSocket
            handshake, mapping header name to header value. Names are matched
            case-insensitively.

    Returns:
        Dictionary of lowercased header names to sanitized values, to merge into the
        loopback /rpc request. It does not include authorization, content-type or
        gateway-internal headers; the caller handles those separately.

    Examples:
        X-Upstream-Authorization is always extracted:
        >>> from unittest.mock import patch
        >>> with patch("mcpgateway.utils.passthrough_headers.settings") as s:
        ...     s.enable_header_passthrough = False
        ...     s.default_passthrough_headers = []
        ...     extract_headers_for_loopback({"X-Upstream-Authorization": "Bearer tok"})
        {'x-upstream-authorization': 'Bearer tok'}

        Empty when no relevant headers present:
        >>> from unittest.mock import patch
        >>> with patch("mcpgateway.utils.passthrough_headers.settings") as s:
        ...     s.enable_header_passthrough = False
        ...     s.default_passthrough_headers = []
        ...     extract_headers_for_loopback({"Accept": "text/html"})
        {}
    """
    selected: Dict[str, str] = {}
    if not request_headers:
        return selected

    # Case-insensitive view of the incoming headers. If two names differ only by
    # case, the later one wins.
    by_lower_name = {name.lower(): value for name, value in request_headers.items()}

    # x-upstream-authorization is an always-on passthrough header. On sanitization
    # failure the header is dropped instead of forwarded unsanitized: sanitization is
    # what prevents CRLF/control-character injection.
    upstream_auth_value = by_lower_name.get("x-upstream-authorization")
    if upstream_auth_value:
        try:
            selected["x-upstream-authorization"] = sanitize_header_value(upstream_auth_value)
        except Exception:
            logger.warning("Failed to sanitize x-upstream-authorization; dropping header for safety", exc_info=True)

    # Allowlist-driven forwarding applies only when the passthrough feature is on.
    # Any failure here (settings access, DB session, cache, skip set) is logged, and
    # whatever was selected so far is still returned.
    try:
        if settings.enable_header_passthrough:
            # First-Party
            from mcpgateway.db import SessionLocal  # pylint: disable=import-outside-toplevel

            with SessionLocal() as db:
                allowlist = _loopback_allowlist_cache.get(db)
                never_forward = _loopback_skip_set()
                for candidate in (entry.lower() for entry in allowlist):
                    if candidate in never_forward or candidate not in by_lower_name:
                        continue
                    try:
                        selected[candidate] = sanitize_header_value(by_lower_name[candidate])
                    except Exception:
                        logger.warning("Failed to sanitize passthrough header %s; skipping", candidate, exc_info=True)
    except Exception:
        logger.warning("Failed to read passthrough header allowlist; forwarding only previously extracted headers", exc_info=True)

    if selected:
        logger.debug("Extracted %d passthrough header(s) for loopback: %s", len(selected), list(selected.keys()))

    return selected

794 

795 

def safe_extract_headers_for_loopback(request_headers: Dict[str, str], transport_name: str = "transport") -> Dict[str, str]:
    """Safely extract passthrough headers, returning ``{}`` on failure.

    Wraps :func:`extract_headers_for_loopback` so that SSE / WebSocket setup
    is never blocked by passthrough configuration issues. ``ImportError`` is
    deliberately re-raised instead of being swallowed: a missing module means a
    broken deployment, which should fail loudly rather than silently degrade to
    "no passthrough headers".

    Args:
        request_headers: Incoming HTTP headers to extract from.
        transport_name: Label used in the warning log when extraction fails.

    Returns:
        Dict[str, str]: Extracted passthrough headers, or an empty dict if
        extraction raised any non-``ImportError`` exception.

    Raises:
        ImportError: Propagated unchanged if raised by
            :func:`extract_headers_for_loopback`.
    """
    try:
        return extract_headers_for_loopback(request_headers)
    except ImportError:
        # ImportError subclasses Exception, so it must be re-raised explicitly
        # here; otherwise the catch-all below would swallow it and the broken
        # deployment would go unnoticed.
        raise
    except Exception:
        logger.warning("Failed to extract passthrough headers for %s; upstream auth may fail", transport_name, exc_info=True)
        return {}

815 

816 

def safe_extract_and_filter_for_loopback(request_headers: Dict[str, str]) -> Dict[str, str]:
    """Extract passthrough headers and apply the loopback skip filter without raising.

    Intended for Streamable HTTP affinity loopback calls, which should degrade
    gracefully rather than fail. The headers chosen by
    :func:`extract_headers_for_loopback` are passed through
    :func:`filter_loopback_skip_headers`. If either step raises an ``Exception``,
    a warning is logged and an empty dict is returned instead.

    Args:
        request_headers: Incoming HTTP headers to extract and filter.

    Returns:
        Dict[str, str]: Filtered passthrough headers, or an empty dict on error.
    """
    try:
        extracted = extract_headers_for_loopback(request_headers)
        filtered = filter_loopback_skip_headers(extracted)
    except Exception:
        logger.warning("Failed to extract passthrough headers for loopback; upstream auth may fail", exc_info=True)
        filtered = {}
    return filtered