Coverage for mcpgateway / middleware / validation_middleware.py: 100%

100 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/middleware/validation_middleware.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7Validation middleware for MCP Gateway input validation and output sanitization. 

8 

9This middleware provides comprehensive input validation and output sanitization 

10for MCP Gateway requests. It validates request parameters, JSON payloads, and 

11resource paths to prevent security vulnerabilities like path traversal, XSS, 

12and injection attacks. 

13 

14Examples: 

15 >>> from mcpgateway.middleware.validation_middleware import ValidationMiddleware # doctest: +SKIP 

16 >>> app.add_middleware(ValidationMiddleware) # doctest: +SKIP 

17""" 

18 

19# Standard 

20import logging 

21from pathlib import Path 

22import re 

23from typing import Any 

24 

25# Third-Party 

26from fastapi import HTTPException, Request, Response 

27import orjson 

28from starlette.middleware.base import BaseHTTPMiddleware 

29 

30# First-Party 

31from mcpgateway.config import settings 

32 

33logger = logging.getLogger(__name__) 

34 

35 

def is_path_traversal(uri: str) -> bool:
    """Check if URI contains path traversal patterns.

    The URI is inspected both as given and after a single round of
    percent-decoding, so encoded forms such as ``%2e%2e/`` or a leading
    ``%2F`` are detected as well as their literal equivalents.

    Args:
        uri (str): URI to check

    Returns:
        bool: True if path traversal detected, i.e. either form of the URI
        contains ``..``, starts with ``/`` or contains a backslash.
    """
    # Function-scope import so the check stays self-contained and does not
    # depend on the module's import block.
    from urllib.parse import unquote

    candidates = (uri, unquote(uri))
    return any(".." in candidate or candidate.startswith("/") or "\\" in candidate for candidate in candidates)

46 

47 

class ValidationMiddleware(BaseHTTPMiddleware):
    """Middleware for validating inputs and sanitizing outputs.

    This middleware validates request parameters, JSON data, and resource paths
    to prevent security vulnerabilities. It can operate in strict or lenient mode
    and optionally sanitizes response content.

    Lenient (log-only) mode applies when the configured environment is
    ``development`` or ``staging`` and strict validation is disabled; in every
    other configuration a validation failure raises ``HTTPException``.
    """

    # Environments in which validation failures may be downgraded to warnings.
    _LENIENT_ENVIRONMENTS = ("development", "staging")

    # Control characters removed from responses: C0 (except \t, \n, \r), DEL and C1.
    _CONTROL_CHARS = re.compile(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f-\x9f]")

    # Leading "scheme://" of a URI (http://, plugin://, ...).
    _URI_SCHEME = re.compile(r"^[a-zA-Z][a-zA-Z0-9+\-.]*://")

    def __init__(self, app):
        """Initialize validation middleware with configuration settings.

        Args:
            app: FastAPI application instance
        """
        super().__init__(app)
        self.enabled = settings.experimental_validate_io
        self.strict = settings.validation_strict
        self.sanitize = settings.sanitize_output
        # Resolve roots once so later containment checks compare normalized paths.
        self.allowed_roots = [Path(root).resolve() for root in settings.allowed_roots]
        # Compile once; these patterns are applied to every validated value.
        self.dangerous_patterns = [re.compile(pattern) for pattern in settings.dangerous_patterns]

    def _is_warn_only(self) -> bool:
        """Tell whether validation failures should be logged instead of raised.

        Returns:
            bool: True when running in development/staging with strict mode off
        """
        return settings.environment in self._LENIENT_ENVIRONMENTS and not self.strict

    async def dispatch(self, request: Request, call_next):
        """Process request with validation and response sanitization.

        Args:
            request: Incoming HTTP request
            call_next: Next middleware/handler in chain

        Returns:
            HTTP response, potentially sanitized

        Raises:
            HTTPException: If validation fails in strict mode
        """
        # Phase 0: Feature disabled - skip entirely
        if not self.enabled:
            return await call_next(request)

        # Phase 1: Log-only mode in dev/staging
        warn_only = self._is_warn_only()

        # Validate input
        try:
            await self._validate_request(request)
        except HTTPException as e:
            if warn_only:
                logger.warning("[VALIDATION] Input validation failed (log-only mode): %s", e.detail)
            else:
                logger.error("[VALIDATION] Input validation failed: %s", e.detail)
                # NOTE(review): an HTTPException raised from middleware may bypass the
                # application's exception handlers and surface as a 500 - confirm how
                # the gateway maps this before relying on the 4xx status code.
                raise

        response = await call_next(request)

        # Sanitize output
        if self.sanitize:
            response = await self._sanitize_response(response)

        return response

    async def _validate_request(self, request: Request):
        """Validate incoming request parameters.

        Path parameters, every query-string value (including repeated keys) and,
        for JSON requests, every string in the body are checked.

        Args:
            request (Request): Incoming HTTP request to validate

        Raises:
            HTTPException: If validation fails in strict mode
        """
        # Validate path parameters
        if hasattr(request, "path_params"):
            for key, value in request.path_params.items():
                self._validate_parameter(key, str(value))

        # Validate query parameters. A multi-dict's items() yields a single value
        # per key, so repeated keys (?q=a&q=b) must be read via multi_items() or
        # the extra values would skip validation.
        query_params = request.query_params
        multi_items = getattr(query_params, "multi_items", None)
        pairs = multi_items() if callable(multi_items) else query_params.items()
        for key, value in pairs:
            self._validate_parameter(key, value)

        # Validate JSON body for resource/tool requests (media types are case-insensitive)
        content_type = request.headers.get("content-type", "")
        if not content_type.lower().startswith("application/json"):
            return

        body = await request.body()
        if not body:
            return

        try:
            data = orjson.loads(body)
        except orjson.JSONDecodeError:
            return  # Let other middleware handle JSON errors

        self._validate_json_data(data)

    def _handle_violation(self, detail: str):
        """Log a validation failure in lenient mode, raise it otherwise.

        Args:
            detail (str): Human-readable description of the violation

        Raises:
            HTTPException: With status 422 unless log-only mode applies
        """
        if self._is_warn_only():
            logger.warning("%s", detail)
            return
        raise HTTPException(status_code=422, detail=detail)

    def _validate_parameter(self, key: str, value: str):
        """Validate individual parameter for length and dangerous patterns.

        Args:
            key (str): Parameter name
            value (str): Parameter value

        Raises:
            HTTPException: If validation fails in strict mode
        """
        if len(value) > settings.max_param_length:
            self._handle_violation(f"Parameter {key} exceeds maximum length")
            # Reached only in log-only mode: skip pattern matching on oversized input.
            return

        if any(pattern.search(value) for pattern in self.dangerous_patterns):
            self._handle_violation(f"Parameter {key} contains dangerous characters")

    def _validate_json_data(self, data: Any, key: str = "body"):
        """Validate every string value found in a JSON data structure.

        The structure is walked with an explicit stack (depth-first, in document
        order) so deeply nested payloads cannot exhaust the interpreter's
        recursion limit. Strings inside lists are reported under the name of the
        enclosing key.

        Args:
            data (Any): JSON data to validate
            key (str): Name used to report strings that have no enclosing
                dictionary key (for example items of a top-level list)

        Raises:
            HTTPException: If validation fails in strict mode
        """
        pending = [(key, data)]
        while pending:
            name, node = pending.pop()
            if isinstance(node, str):
                self._validate_parameter(name, node)
            elif isinstance(node, dict):
                # Push in reverse so entries are popped in their original order.
                pending.extend(reversed(list(node.items())))
            elif isinstance(node, list):
                pending.extend((name, item) for item in reversed(node))

    def validate_resource_path(self, path: str) -> str:
        """Validate and normalize resource paths to prevent traversal attacks.

        Args:
            path (str): Resource path to validate

        Returns:
            str: Normalized path if valid; URIs with a scheme are returned unchanged

        Raises:
            HTTPException: If path is invalid or contains traversal patterns
        """
        # Skip validation for URI schemes (http://, plugin://, etc.)
        #
        # Note: This must run before the '//' traversal check, otherwise every URI
        # would be rejected due to the '://' sequence.
        if self._URI_SCHEME.match(path):
            return path

        # Check explicit path traversal detection
        if ".." in path or "//" in path:
            raise HTTPException(status_code=400, detail="invalid_path: Path traversal detected")

        try:
            resolved_path = Path(path).resolve()
        except (OSError, ValueError, RuntimeError) as exc:
            # ValueError: e.g. embedded NUL byte; RuntimeError: symlink loop on
            # Python versions that report it that way.
            raise HTTPException(status_code=400, detail="invalid_path: Invalid path") from exc

        # Check path depth
        if len(resolved_path.parts) > settings.max_path_depth:
            raise HTTPException(status_code=400, detail="invalid_path: Path too deep")

        # Check against allowed roots. Compare path components rather than string
        # prefixes: a prefix test would accept "/srv/data-secret" for root "/srv/data".
        if self.allowed_roots:
            ancestors = resolved_path.parents
            allowed = any(root == resolved_path or root in ancestors for root in self.allowed_roots)
            if not allowed:
                raise HTTPException(status_code=400, detail="invalid_path: Path outside allowed roots")

        return str(resolved_path)

    async def _sanitize_response(self, response: Response) -> Response:
        """Sanitize response content by removing control characters.

        Responses without a ``body`` attribute and bodies that are not valid
        UTF-8 (binary payloads) are returned untouched. Any unexpected failure is
        logged and the response is returned as-is (best effort).

        Args:
            response: HTTP response to sanitize

        Returns:
            Response: Sanitized response
        """
        if not hasattr(response, "body"):
            return response

        try:
            body = response.body
            if isinstance(body, (bytes, bytearray, memoryview)):
                try:
                    text = bytes(body).decode("utf-8")
                except UnicodeDecodeError:
                    # Not text: decoding with replacement and re-encoding would
                    # corrupt binary content, so leave the body alone.
                    return response
            else:
                text = body

            # Remove control characters except newlines and tabs
            sanitized = self._CONTROL_CHARS.sub("", text)

            response.body = sanitized.encode("utf-8")
            response.headers["content-length"] = str(len(response.body))

        except Exception as e:
            logger.warning("Failed to sanitize response: %s", e)

        return response