Coverage for mcpgateway / utils / passthrough_headers.py: 100%

158 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/utils/passthrough_headers.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7HTTP Header Passthrough Utilities. 

8This module provides utilities for handling HTTP header passthrough functionality 

9in the MCP Gateway. It enables forwarding of specific headers from incoming 

10client requests to backing MCP servers while preventing conflicts with 

11existing authentication mechanisms. 

12 

13Key Features: 

14- Global configuration support via environment variables and database 

15- Per-gateway header configuration overrides 

16- Intelligent conflict detection with existing authentication headers 

17- Security-first approach with explicit allowlist handling 

18- Comprehensive logging for debugging and monitoring 

19- Header validation and sanitization 

20 

21The header passthrough system follows a priority hierarchy: 

221. Gateway-specific headers (highest priority) 

232. Global database configuration 

243. Environment variable defaults (lowest priority) 

25 

26Example Usage: 

27 See comprehensive unit tests in tests/unit/mcpgateway/utils/test_passthrough_headers*.py 

28 for detailed examples of header passthrough functionality. 

29""" 

30 

31# Standard 

32import logging 

33import re 

34from typing import Dict, List, Optional 

35 

36# Third-Party 

37from sqlalchemy.orm import Session 

38 

39# First-Party 

40from mcpgateway.cache.global_config_cache import global_config_cache 

41from mcpgateway.config import settings 

42from mcpgateway.db import Gateway as DbGateway 

43from mcpgateway.db import GlobalConfig 

44 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Header name validation regex - allows letters, numbers, and hyphens.
# The pattern is anchored with \Z instead of $ because $ also matches just
# before a trailing newline, which would let a name such as "X-Foo\n" pass
# validation when the pattern is used with .match().
HEADER_NAME_REGEX = re.compile(r"^[A-Za-z0-9\-]+\Z")

# Maximum header value length (4KB)
MAX_HEADER_VALUE_LENGTH = 4096

52 

53 

class PassthroughHeadersError(Exception):
    """Root exception type for failures in the header passthrough utilities.

    It adds nothing to ``Exception``; it exists so callers can catch
    passthrough-header problems specifically.

    Examples:
        >>> error = PassthroughHeadersError("Test error")
        >>> str(error)
        'Test error'
        >>> isinstance(error, Exception)
        True
    """

64 

65 

def sanitize_header_value(value: str, max_length: int = MAX_HEADER_VALUE_LENGTH) -> str:
    """Sanitize header value for security.

    Removes every ASCII control character except horizontal tab (this covers
    CR and LF, which prevents header injection, as well as NUL and DEL), then
    enforces the length limit and finally trims surrounding whitespace.

    Control characters are removed *before* the length limit is applied so
    that stripped characters never consume part of the length budget.

    Args:
        value: Header value to sanitize
        max_length: Maximum allowed length of the returned value

    Returns:
        Sanitized header value (possibly empty)

    Examples:
        Remove CRLF and trim length:
        >>> s = sanitize_header_value('val' + chr(13) + chr(10) + 'more', max_length=6)
        >>> s
        'valmor'
        >>> len(s) <= 6
        True
        >>> sanitize_header_value('  spaced  ')
        'spaced'

        DEL (ASCII 127) is a control character and is removed as well:
        >>> sanitize_header_value('a' + chr(127) + 'b')
        'ab'
    """
    # Keep tab and every character from space upwards, except DEL (0x7F).
    # CR (13) and LF (10) fall below 32 and are therefore dropped here too.
    cleaned = "".join(ch for ch in value if ch == "\t" or (ord(ch) >= 32 and ord(ch) != 127))

    # Enforce the limit on the cleaned text, then trim surrounding whitespace.
    return cleaned[:max_length].strip()

98 

99 

def validate_header_name(name: str) -> bool:
    """Validate header name against allowed pattern.

    The whole name must consist of letters, digits and hyphens. A full match
    is required, so a trailing newline (which a ``$`` anchor alone would
    tolerate) makes the name invalid. Non-string input is reported as invalid
    instead of raising.

    Args:
        name: Header name to validate

    Returns:
        True if valid, False otherwise

    Examples:
        Valid names:
        >>> validate_header_name('X-Tenant-Id')
        True
        >>> validate_header_name('X123-ABC')
        True

        Invalid names:
        >>> validate_header_name('Invalid Header:Name')
        False
        >>> validate_header_name('Bad@Name')
        False
        >>> validate_header_name('X-Tenant-Id' + chr(10))
        False
        >>> validate_header_name(None)
        False
    """
    # fullmatch() requires the match to span the entire string, which closes
    # the "$ matches before a trailing newline" loophole of match().
    return isinstance(name, str) and HEADER_NAME_REGEX.fullmatch(name) is not None

123 

124 

def get_passthrough_headers(request_headers: Dict[str, str], base_headers: Dict[str, str], db: Session, gateway: Optional[DbGateway] = None) -> Dict[str, str]:
    """Get headers that should be passed through to the target gateway.

    This function implements the core logic for HTTP header passthrough in the MCP Gateway.
    It determines which headers from incoming client requests should be forwarded to
    backing MCP servers based on configuration settings and security policies.

    Configuration Priority (highest to lowest):
    1. Gateway-specific passthrough_headers setting
    2. Global headers from global_config_cache.get_passthrough_headers() based on PASSTHROUGH_HEADERS_SOURCE:
       - "db": Database wins if configured, env var DEFAULT_PASSTHROUGH_HEADERS as fallback
       - "env": Environment variable always wins, database ignored
       - "merge": Union of both sources (DB casing wins for duplicates)

    Security Features:
    - Feature flag control (disabled by default)
    - Prevents conflicts with existing base headers (e.g., Content-Type)
    - Blocks Authorization header conflicts with gateway authentication
    - Header name validation (regex pattern matching)
    - Header value sanitization (removes dangerous characters, enforces limits)
    - Logs all conflicts and skipped headers for debugging
    - Uses case-insensitive header matching for robustness
    - Special X-Upstream-Authorization handling: When gateway uses auth, clients can
      send X-Upstream-Authorization header which gets renamed to Authorization for upstream

    Args:
        request_headers (Dict[str, str]): Headers from the incoming HTTP request.
            Keys should be header names, values should be header values.
            Example: {"Authorization": "Bearer token123", "X-Tenant-Id": "acme"}
        base_headers (Dict[str, str]): Base headers that should always be included
            in the final result. These take precedence over passthrough headers
            unless ENABLE_OVERWRITE_BASE_HEADERS is set.
            Example: {"Content-Type": "application/json", "User-Agent": "MCPGateway/1.0"}
        db (Session): SQLAlchemy database session, handed to the global config
            cache so it can load the global passthrough configuration.
        gateway (Optional[DbGateway]): Target gateway instance. If provided, uses
            gateway.passthrough_headers to override global settings. Also checks
            gateway.auth_type to prevent Authorization header conflicts.

    Returns:
        Dict[str, str]: Combined dictionary of base headers plus allowed passthrough
            headers from the request. The input ``base_headers`` dict is not mutated.
            A header name never appears twice under different casings: when a value
            is written, any differently-cased variant of the same name is removed.

    Raises:
        No exceptions are raised deliberately. Sanitization errors are logged as
        warnings and processing continues. Database errors may propagate from the
        global configuration lookup.

    Examples:
        Feature disabled by default (secure by default):
        >>> from unittest.mock import Mock, patch
        >>> from mcpgateway.cache.global_config_cache import global_config_cache
        >>> global_config_cache.invalidate()  # Clear cache for isolated test
        >>> with patch(__name__ + ".settings") as mock_settings:
        ...     mock_settings.enable_header_passthrough = False
        ...     mock_settings.default_passthrough_headers = ["X-Tenant-Id"]
        ...     mock_db = Mock()
        ...     mock_db.query.return_value.first.return_value = None
        ...     request_headers = {"x-tenant-id": "should-be-ignored"}
        ...     base_headers = {"Content-Type": "application/json"}
        ...     get_passthrough_headers(request_headers, base_headers, mock_db)
        {'Content-Type': 'application/json'}

        Enabled with allowlist and conflicts:
        >>> global_config_cache.invalidate()  # Clear cache for isolated test
        >>> with patch(__name__ + ".settings") as mock_settings:
        ...     mock_settings.enable_header_passthrough = True
        ...     mock_settings.default_passthrough_headers = ["X-Tenant-Id", "Authorization"]
        ...     # Mock DB returns no global override
        ...     mock_db = Mock()
        ...     mock_db.query.return_value.first.return_value = None
        ...     # Gateway with basic auth should block Authorization passthrough
        ...     gateway = Mock()
        ...     gateway.passthrough_headers = None
        ...     gateway.auth_type = "basic"
        ...     gateway.name = "gw1"
        ...     req_headers = {"X-Tenant-Id": "acme", "Authorization": "Bearer abc"}
        ...     base = {"Content-Type": "application/json", "Authorization": "Bearer base"}
        ...     res = get_passthrough_headers(req_headers, base, mock_db, gateway)
        ...     ("X-Tenant-Id" in res) and (res["Authorization"] == "Bearer base")
        True

        See comprehensive unit tests in tests/unit/mcpgateway/utils/test_passthrough_headers*.py
        for detailed examples of enabled functionality, conflict detection, and security features.

    Note:
        Header names are matched case-insensitively but emitted in the casing
        used by the allowed_headers configuration. Request header names are
        matched case-insensitively against the request_headers dictionary.
    """
    passthrough_headers = base_headers.copy()

    # Case-insensitive view of the request headers. Built once and reused for
    # both the Authorization special cases and the allowlist loop below.
    request_headers_lower = {k.lower(): v for k, v in request_headers.items()} if request_headers else {}

    # Special handling for X-Upstream-Authorization header (always enabled)
    # If gateway uses auth and client wants to pass Authorization to upstream,
    # client can use X-Upstream-Authorization which gets renamed to Authorization
    upstream_auth = request_headers_lower.get("x-upstream-authorization")

    if upstream_auth:
        try:
            sanitized_value = sanitize_header_value(upstream_auth)
        except Exception as e:
            logger.warning(f"Failed to sanitize X-Upstream-Authorization header: {e}")
        else:
            if sanitized_value:
                # Drop any differently-cased Authorization key inherited from
                # base_headers so the result never carries two of them.
                for stale_key in [k for k in passthrough_headers if k != "Authorization" and k.lower() == "authorization"]:
                    del passthrough_headers[stale_key]
                # Always rename X-Upstream-Authorization to Authorization for upstream
                # This works for both auth and no-auth gateways
                passthrough_headers["Authorization"] = sanitized_value
                logger.debug("Renamed X-Upstream-Authorization to Authorization for upstream passthrough")
    elif gateway and gateway.auth_type == "none":
        # When gateway has no auth, pass through client's Authorization if present
        client_auth = request_headers_lower.get("authorization")
        if client_auth and not any(h.lower() == "authorization" for h in base_headers):
            try:
                sanitized_value = sanitize_header_value(client_auth)
            except Exception as e:
                logger.warning(f"Failed to sanitize Authorization header: {e}")
            else:
                if sanitized_value:
                    passthrough_headers["Authorization"] = sanitized_value
                    logger.debug("Passing through client Authorization header (auth_type=none)")

    # Early return if header passthrough feature is disabled
    if not settings.enable_header_passthrough:
        logger.debug("Header passthrough is disabled via ENABLE_HEADER_PASSTHROUGH flag")
        return passthrough_headers

    overwrite_base_headers = settings.enable_overwrite_base_headers
    if overwrite_base_headers:
        logger.debug("Overwriting base headers is enabled via ENABLE_OVERWRITE_BASE_HEADERS flag")

    # Get global passthrough headers from in-memory cache (Issue #1715)
    # This eliminates redundant DB queries for static configuration
    allowed_headers = global_config_cache.get_passthrough_headers(db, settings.default_passthrough_headers)

    # Gateway specific headers override global config
    if gateway and gateway.passthrough_headers is not None:
        allowed_headers = gateway.passthrough_headers

    # Lower-cased name -> key currently present in the result. Used both for
    # conflict detection and to replace a differently-cased key on overwrite.
    existing_keys = {key.lower(): key for key in passthrough_headers}

    # Copy allowed headers from request
    if request_headers_lower and allowed_headers:
        for header_name in allowed_headers:
            # Validate header name
            if not validate_header_name(header_name):
                logger.warning(f"Invalid header name '{header_name}' - skipping (must match pattern: {HEADER_NAME_REGEX.pattern})")
                continue

            header_lower = header_name.lower()
            header_value = request_headers_lower.get(header_lower)

            if not header_value:
                logger.debug(f"Header {header_name} not found in request headers, skipping passthrough")
                continue

            # Sanitize header value
            try:
                sanitized_value = sanitize_header_value(header_value)
            except Exception as e:
                logger.warning(f"Failed to sanitize header {header_name}: {e} - skipping")
                continue
            if not sanitized_value:
                logger.warning(f"Header {header_name} value became empty after sanitization - skipping")
                continue

            # Skip if header would conflict with a header that is already set
            if header_lower in existing_keys and not overwrite_base_headers:
                logger.warning(f"Skipping {header_name} header passthrough as it conflicts with pre-defined headers")
                continue

            # Skip if header would conflict with gateway auth
            if gateway and header_lower == "authorization" and gateway.auth_type in ("basic", "bearer"):
                logger.warning(f"Skipping Authorization header passthrough due to {gateway.auth_type} auth configuration on gateway {gateway.name}")
                continue

            # When overwriting, remove the previously stored key if its casing
            # differs; otherwise both casings would be sent upstream.
            previous_key = existing_keys.get(header_lower)
            if previous_key is not None and previous_key != header_name:
                passthrough_headers.pop(previous_key, None)

            # Use original header name casing from configuration, sanitized value from request
            passthrough_headers[header_name] = sanitized_value
            existing_keys[header_lower] = header_name
            logger.debug(f"Added passthrough header: {header_name}")

    logger.debug(f"Final passthrough headers: {list(passthrough_headers.keys())}")
    return passthrough_headers

311 

312 

def compute_passthrough_headers_cached(
    request_headers: Dict[str, str],
    base_headers: Dict[str, str],
    allowed_headers: List[str],
    gateway_auth_type: Optional[str] = None,
    gateway_passthrough_headers: Optional[List[str]] = None,
) -> Dict[str, str]:
    """Compute passthrough headers without database query.

    Use this when GlobalConfig has already been fetched and cached, to avoid
    repeated database queries during high-frequency operations like tool invocation.

    This function implements the same header passthrough logic as get_passthrough_headers()
    but accepts pre-fetched configuration values instead of querying the database.

    Args:
        request_headers: Headers from the incoming HTTP request.
        base_headers: Base headers that should always be included (auth, content-type, etc.).
            The dict is copied, never mutated.
        allowed_headers: List of header names allowed to pass through (from GlobalConfig).
        gateway_auth_type: The gateway's auth_type (basic, bearer, oauth, none) if applicable.
        gateway_passthrough_headers: Gateway-specific passthrough headers override.
            When not None (even if empty) it replaces ``allowed_headers``.

    Returns:
        Combined dictionary of base headers plus allowed passthrough headers.
        A header name never appears twice under different casings: when a value
        is written, any differently-cased variant of the same name is removed.

    Examples:
        >>> from unittest.mock import patch
        >>> from mcpgateway.utils.passthrough_headers import compute_passthrough_headers_cached
        >>> request = {"X-Tenant-Id": "acme", "Authorization": "secret"}
        >>> base = {"Content-Type": "application/json"}
        >>> allowed = ["X-Tenant-Id"]
        >>> with patch("mcpgateway.utils.passthrough_headers.settings") as mock_settings:
        ...     mock_settings.enable_header_passthrough = True
        ...     mock_settings.enable_overwrite_base_headers = False
        ...     result = compute_passthrough_headers_cached(request, base, allowed, gateway_auth_type=None)
        >>> "X-Tenant-Id" in result
        True
        >>> result.get("Authorization") is None  # Not in allowed list
        True
    """
    passthrough_headers = base_headers.copy()

    # Case-insensitive view of the request headers, reused throughout.
    request_headers_lower = {k.lower(): v for k, v in request_headers.items()} if request_headers else {}

    # Special handling for X-Upstream-Authorization header (always enabled)
    upstream_auth = request_headers_lower.get("x-upstream-authorization")

    if upstream_auth:
        try:
            sanitized_value = sanitize_header_value(upstream_auth)
        except Exception as e:
            logger.warning(f"Failed to sanitize X-Upstream-Authorization header: {e}")
        else:
            if sanitized_value:
                # Drop any differently-cased Authorization key inherited from
                # base_headers so the result never carries two of them.
                for stale_key in [k for k in passthrough_headers if k != "Authorization" and k.lower() == "authorization"]:
                    del passthrough_headers[stale_key]
                passthrough_headers["Authorization"] = sanitized_value
                logger.debug("Renamed X-Upstream-Authorization to Authorization for upstream passthrough")
    elif gateway_auth_type == "none":
        # When gateway has no auth, pass through client's Authorization if present
        client_auth = request_headers_lower.get("authorization")
        if client_auth and not any(h.lower() == "authorization" for h in base_headers):
            try:
                sanitized_value = sanitize_header_value(client_auth)
            except Exception as e:
                logger.warning(f"Failed to sanitize Authorization header: {e}")
            else:
                if sanitized_value:
                    passthrough_headers["Authorization"] = sanitized_value
                    logger.debug("Passing through client Authorization header (auth_type=none)")

    # Early return if header passthrough feature is disabled
    if not settings.enable_header_passthrough:
        logger.debug("Header passthrough is disabled via ENABLE_HEADER_PASSTHROUGH flag")
        return passthrough_headers

    overwrite_base_headers = settings.enable_overwrite_base_headers

    # Use gateway-specific headers if provided, otherwise use global allowed_headers
    effective_allowed = gateway_passthrough_headers if gateway_passthrough_headers is not None else allowed_headers

    # Lower-cased name -> key currently present in the result. Used both for
    # conflict detection and to replace a differently-cased key on overwrite.
    existing_keys = {key.lower(): key for key in passthrough_headers}

    # Copy allowed headers from request
    if request_headers_lower and effective_allowed:
        for header_name in effective_allowed:
            # Validate header name
            if not validate_header_name(header_name):
                logger.warning(f"Invalid header name '{header_name}' - skipping (must match pattern: {HEADER_NAME_REGEX.pattern})")
                continue

            header_lower = header_name.lower()
            header_value = request_headers_lower.get(header_lower)

            if not header_value:
                logger.debug(f"Header {header_name} not found in request headers, skipping passthrough")
                continue

            # Sanitize header value
            try:
                sanitized_value = sanitize_header_value(header_value)
            except Exception as e:
                logger.warning(f"Failed to sanitize header {header_name}: {e} - skipping")
                continue
            if not sanitized_value:
                logger.warning(f"Header {header_name} value became empty after sanitization - skipping")
                continue

            # Skip if header would conflict with a header that is already set
            if header_lower in existing_keys and not overwrite_base_headers:
                logger.warning(f"Skipping {header_name} header passthrough as it conflicts with pre-defined headers")
                continue

            # Skip if header would conflict with gateway auth
            if gateway_auth_type in ("basic", "bearer") and header_lower == "authorization":
                logger.warning(f"Skipping Authorization header passthrough due to {gateway_auth_type} auth configuration")
                continue

            # When overwriting, remove the previously stored key if its casing
            # differs; otherwise both casings would be sent upstream.
            previous_key = existing_keys.get(header_lower)
            if previous_key is not None and previous_key != header_name:
                passthrough_headers.pop(previous_key, None)

            # Use original header name casing from configuration, sanitized value from request
            passthrough_headers[header_name] = sanitized_value
            existing_keys[header_lower] = header_name
            logger.debug(f"Added passthrough header: {header_name}")

    logger.debug(f"Final passthrough headers (cached): {list(passthrough_headers.keys())}")
    return passthrough_headers

430 

431 

async def set_global_passthrough_headers(db: Session) -> None:
    """Set global passthrough headers in the database if not already configured.

    This function checks if a GlobalConfig row already exists. If not, it
    creates one whose passthrough headers are the valid entries of
    settings.default_passthrough_headers (invalid names are logged and skipped),
    commits it, and invalidates the global config cache.

    When PASSTHROUGH_HEADERS_SOURCE=env, this function skips database access entirely
    since the database configuration is ignored in that mode.

    Args:
        db (Session): SQLAlchemy database session for querying and updating GlobalConfig.

    Raises:
        PassthroughHeadersError: If unable to update passthrough headers in the database.
            The session is rolled back first and the original exception is
            attached as the cause.

    Examples:
        Successful insert of default headers:
        >>> import pytest
        >>> from unittest.mock import Mock, patch
        >>> @pytest.mark.asyncio
        ... @patch("mcpgateway.utils.passthrough_headers.settings")
        ... async def test_default_headers(mock_settings):
        ...     mock_settings.enable_header_passthrough = True
        ...     mock_settings.passthrough_headers_source = "db"
        ...     mock_settings.default_passthrough_headers = ["X-Tenant-Id", "X-Trace-Id"]
        ...     mock_db = Mock()
        ...     mock_db.query.return_value.first.return_value = None
        ...     await set_global_passthrough_headers(mock_db)
        ...     mock_db.add.assert_called_once()
        ...     mock_db.commit.assert_called_once()

        Database write failure:
        >>> import pytest
        >>> from unittest.mock import Mock, patch
        >>> from mcpgateway.utils.passthrough_headers import PassthroughHeadersError
        >>> @pytest.mark.asyncio
        ... @patch("mcpgateway.utils.passthrough_headers.settings")
        ... async def test_db_write_failure(mock_settings):
        ...     mock_settings.enable_header_passthrough = True
        ...     mock_settings.passthrough_headers_source = "db"
        ...     mock_db = Mock()
        ...     mock_db.query.return_value.first.return_value = None
        ...     mock_db.commit.side_effect = Exception("DB write failed")
        ...     with pytest.raises(PassthroughHeadersError):
        ...         await set_global_passthrough_headers(mock_db)
        ...     mock_db.rollback.assert_called_once()

        Config already exists (no DB write):
        >>> import pytest
        >>> from unittest.mock import Mock, patch
        >>> from mcpgateway.common.models import GlobalConfig
        >>> @pytest.mark.asyncio
        ... @patch("mcpgateway.utils.passthrough_headers.settings")
        ... async def test_existing_config(mock_settings):
        ...     mock_settings.enable_header_passthrough = True
        ...     mock_settings.passthrough_headers_source = "db"
        ...     mock_db = Mock()
        ...     existing = Mock(spec=GlobalConfig)
        ...     existing.passthrough_headers = ["X-Tenant-ID", "Authorization"]
        ...     mock_db.query.return_value.first.return_value = existing
        ...     await set_global_passthrough_headers(mock_db)
        ...     mock_db.add.assert_not_called()
        ...     mock_db.commit.assert_not_called()
        ...     assert existing.passthrough_headers == ["X-Tenant-ID", "Authorization"]

        Env mode skips DB entirely:
        >>> import pytest
        >>> from unittest.mock import Mock, patch
        >>> @pytest.mark.asyncio
        ... @patch("mcpgateway.utils.passthrough_headers.settings")
        ... async def test_env_mode_skips_db(mock_settings):
        ...     mock_settings.passthrough_headers_source = "env"
        ...     mock_db = Mock()
        ...     await set_global_passthrough_headers(mock_db)
        ...     mock_db.query.assert_not_called()
        ...     mock_db.add.assert_not_called()

    Note:
        This function is typically called during application startup to ensure
        global configuration is in place before any gateway operations.
    """
    # When source is "env", skip DB operations entirely - env vars always win
    if settings.passthrough_headers_source == "env":
        logger.debug("Passthrough headers source=env: skipping database initialization (env vars always used)")
        return

    # Query DB directly here (not cache) because we need to check if config exists
    # to decide whether to create it
    global_config = db.query(GlobalConfig).first()
    if global_config:
        # Existing configuration is left untouched.
        return

    # Keep only the default header names that pass validation.
    allowed_headers = []
    config_headers = settings.default_passthrough_headers
    if config_headers:
        for header_name in config_headers:
            if not validate_header_name(header_name):
                logger.warning(f"Invalid header name '{header_name}' - skipping (must match pattern: {HEADER_NAME_REGEX.pattern})")
                continue
            allowed_headers.append(header_name)

    try:
        db.add(GlobalConfig(passthrough_headers=allowed_headers))
        db.commit()
        # Invalidate cache so next read picks up new config (Issue #1715)
        global_config_cache.invalidate()
    except Exception as e:
        db.rollback()
        # Chain the original error so the root cause stays visible in tracebacks.
        raise PassthroughHeadersError(f"Failed to update passthrough headers: {e}") from e

540 db.rollback() 

541 raise PassthroughHeadersError(f"Failed to update passthrough headers: {str(e)}")