Coverage for mcpgateway / middleware / validation_middleware.py: 100%
100 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/middleware/validation_middleware.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Validation middleware for MCP Gateway input validation and output sanitization.
9This middleware provides comprehensive input validation and output sanitization
10for MCP Gateway requests. It validates request parameters, JSON payloads, and
11resource paths to prevent security vulnerabilities like path traversal, XSS,
12and injection attacks.
14Examples:
15 >>> from mcpgateway.middleware.validation_middleware import ValidationMiddleware # doctest: +SKIP
16 >>> app.add_middleware(ValidationMiddleware) # doctest: +SKIP
17"""
19# Standard
20import logging
21from pathlib import Path
22import re
23from typing import Any
25# Third-Party
26from fastapi import HTTPException, Request, Response
27import orjson
28from starlette.middleware.base import BaseHTTPMiddleware
30# First-Party
31from mcpgateway.config import settings
33logger = logging.getLogger(__name__)
def is_path_traversal(uri: str) -> bool:
    """Check if URI contains path traversal patterns.

    A URI is flagged when it contains a parent-directory sequence (``..``),
    starts with ``/`` (absolute path) or contains a backslash. The same three
    checks are repeated on the percent-decoded form of the URI so that encoded
    variants such as ``%2e%2e``, ``%5c`` or a leading ``%2f`` cannot slip past
    the literal comparison.

    Args:
        uri (str): URI to check

    Returns:
        bool: True if path traversal detected
    """
    # Standard
    from urllib.parse import unquote  # local import: keeps this block self-contained

    def _looks_unsafe(candidate: str) -> bool:
        """Apply the literal traversal checks to one spelling of the URI."""
        return ".." in candidate or candidate.startswith("/") or "\\" in candidate

    # Check the raw value first; only decode when the raw form looks clean.
    return _looks_unsafe(uri) or _looks_unsafe(unquote(uri))
class ValidationMiddleware(BaseHTTPMiddleware):
    """Middleware for validating inputs and sanitizing outputs.

    This middleware validates request parameters, JSON data, and resource paths
    to prevent security vulnerabilities. It can operate in strict or lenient mode
    and optionally sanitizes response content.

    Lenient (log-only) behaviour applies only when the configured environment is
    ``development`` or ``staging`` *and* strict mode is off; in every other
    combination a failed check raises ``HTTPException``.
    """

    # Environments in which failures are only logged, unless strict mode is on.
    _LENIENT_ENVIRONMENTS = ("development", "staging")

    # Matches a leading URI scheme such as "http://" or "plugin://".
    _URI_SCHEME = re.compile(r"^[a-zA-Z][a-zA-Z0-9+\-.]*://")

    # Control characters stripped from responses (newline, carriage return and tab are kept).
    _CONTROL_CHARS = re.compile(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x9f]")

    def __init__(self, app):
        """Initialize validation middleware with configuration settings.

        Args:
            app: FastAPI application instance
        """
        super().__init__(app)
        self.enabled = settings.experimental_validate_io
        self.strict = settings.validation_strict
        self.sanitize = settings.sanitize_output
        self.allowed_roots = [Path(root).resolve() for root in settings.allowed_roots]
        self.dangerous_patterns = [re.compile(pattern) for pattern in settings.dangerous_patterns]

    def _warn_only(self) -> bool:
        """Tell whether validation failures should be logged instead of raised.

        Returns:
            bool: True in development/staging when strict mode is disabled
        """
        return settings.environment in self._LENIENT_ENVIRONMENTS and not self.strict

    async def dispatch(self, request: Request, call_next):
        """Process request with validation and response sanitization.

        Args:
            request: Incoming HTTP request
            call_next: Next middleware/handler in chain

        Returns:
            HTTP response, potentially sanitized

        Raises:
            HTTPException: If validation fails and log-only mode is not active
        """
        # Phase 0: Feature disabled - skip entirely
        if not self.enabled:
            return await call_next(request)

        # Phase 1: Log-only mode in dev/staging
        warn_only = self._warn_only()

        # Validate input
        try:
            await self._validate_request(request)
        except HTTPException as e:
            if not warn_only:
                logger.error("[VALIDATION] Input validation failed: %s", e.detail)
                raise
            logger.warning("[VALIDATION] Input validation failed (log-only mode): %s", e.detail)

        response = await call_next(request)

        # Sanitize output
        if self.sanitize:
            response = await self._sanitize_response(response)

        return response

    async def _validate_request(self, request: Request):
        """Validate incoming request parameters.

        Path parameters, every query-string value and (for JSON requests) the
        decoded body are checked.

        Args:
            request (Request): Incoming HTTP request to validate

        Raises:
            HTTPException: If validation fails and log-only mode is not active
        """
        # Validate path parameters
        if hasattr(request, "path_params"):
            for key, value in request.path_params.items():
                self._validate_parameter(key, str(value))

        # Validate query parameters. On Starlette multi-dicts items() yields only
        # the last value of a repeated key, so prefer multi_items() when present
        # to make sure "?q=<bad>&q=ok" cannot hide the first value.
        query_params = request.query_params
        if hasattr(query_params, "multi_items"):
            query_pairs = query_params.multi_items()
        else:
            query_pairs = query_params.items()
        for key, value in query_pairs:
            self._validate_parameter(key, value)

        # Validate JSON body for resource/tool requests (media types are case-insensitive)
        content_type = request.headers.get("content-type", "")
        if not content_type.lower().startswith("application/json"):
            return

        body = await request.body()
        if not body:
            return
        try:
            data = orjson.loads(body)
        except orjson.JSONDecodeError:
            return  # Let other middleware handle JSON errors
        self._validate_json_data(data)

    def _validate_parameter(self, key: str, value: str):
        """Validate individual parameter for length and dangerous patterns.

        The length check runs first; the pattern check is only reached when the
        value is within the length limit.

        Args:
            key (str): Parameter name
            value (str): Parameter value

        Raises:
            HTTPException: If validation fails and log-only mode is not active
        """
        if len(value) > settings.max_param_length:
            problem = f"Parameter {key} exceeds maximum length"
        elif any(pattern.search(value) for pattern in self.dangerous_patterns):
            problem = f"Parameter {key} contains dangerous characters"
        else:
            return

        # Honour strict mode here as well, otherwise it could never take effect
        # in development/staging because no exception would reach dispatch().
        if self._warn_only():
            logger.warning("%s", problem)
            return
        raise HTTPException(status_code=422, detail=problem)

    def _validate_json_data(self, data: Any):
        """Validate every string found in a decoded JSON document.

        The document is walked depth-first with an explicit stack, so nesting
        depth is not bounded by the interpreter recursion limit. Strings stored
        under a dict key are reported under that key; strings that are list
        elements are reported as ``<parent>[<index>]``; a bare top-level string
        is reported as ``body``. Dict keys themselves and non-string scalars are
        not inspected.

        Args:
            data (Any): JSON data to validate

        Raises:
            HTTPException: If validation fails and log-only mode is not active
        """
        # Each stack entry is an iterator of (label, value) pairs still to visit.
        pending = [iter((("body", data),))]
        while pending:
            for label, value in pending[-1]:
                if isinstance(value, str):
                    self._validate_parameter(label, value)
                elif isinstance(value, dict):
                    pending.append(iter(value.items()))
                    break  # descend; the parent iterator resumes afterwards
                elif isinstance(value, list):
                    # Built eagerly so each label captures the current parent name.
                    labelled = [(f"{label}[{index}]", item) for index, item in enumerate(value)]
                    pending.append(iter(labelled))
                    break
            else:
                # Current level exhausted
                pending.pop()

    def validate_resource_path(self, path: str) -> str:
        """Validate and normalize resource paths to prevent traversal attacks.

        Args:
            path (str): Resource path to validate

        Returns:
            str: The path unchanged when it carries a URI scheme, otherwise the
                resolved absolute path

        Raises:
            HTTPException: If path is invalid or contains traversal patterns
        """
        # Skip validation for URI schemes (http://, plugin://, etc.)
        #
        # Note: This must run before the '//' traversal check, otherwise every URI
        # would be rejected due to the '://' sequence.
        if self._URI_SCHEME.match(path):
            return path

        # Check explicit path traversal detection
        if ".." in path or "//" in path:
            raise HTTPException(status_code=400, detail="invalid_path: Path traversal detected")

        try:
            resolved_path = Path(path).resolve()
        except (OSError, ValueError, RuntimeError) as exc:
            # RuntimeError: symlink loops on Python versions that report them that way
            raise HTTPException(status_code=400, detail="invalid_path: Invalid path") from exc

        # Check path depth
        if len(resolved_path.parts) > settings.max_path_depth:
            raise HTTPException(status_code=400, detail="invalid_path: Path too deep")

        # Check against allowed roots. Compare path components rather than string
        # prefixes so that "/srv/data-private" is not accepted for root "/srv/data".
        if self.allowed_roots:
            ancestors = resolved_path.parents
            allowed = any(root == resolved_path or root in ancestors for root in self.allowed_roots)
            if not allowed:
                raise HTTPException(status_code=400, detail="invalid_path: Path outside allowed roots")

        return str(resolved_path)

    async def _sanitize_response(self, response: Response) -> Response:
        """Sanitize response content by removing control characters.

        Responses without a ``body`` attribute are returned untouched, as are
        byte bodies that are not valid UTF-8 (binary or compressed payloads),
        because rewriting those would corrupt them. Any other failure is logged
        and the response is returned as-is.

        Args:
            response: HTTP response to sanitize

        Returns:
            Response: Sanitized response
        """
        if not hasattr(response, "body"):
            return response

        try:
            body = response.body
            if isinstance(body, bytes):
                try:
                    body = body.decode("utf-8")
                except UnicodeDecodeError:
                    # Not text - leave the payload byte-for-byte intact
                    return response

            # Remove control characters except newlines and tabs
            sanitized = self._CONTROL_CHARS.sub("", body)

            response.body = sanitized.encode("utf-8")
            response.headers["content-length"] = str(len(response.body))
        except Exception as e:  # best-effort: never fail the request because of sanitization
            logger.warning("Failed to sanitize response: %s", e)

        return response