Coverage for mcpgateway / plugins / framework / utils.py: 99%

91 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/plugins/framework/utils.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Teryl Taylor, Mihai Criveti, Fred Araujo 

6 

7Utility module for plugins layer. 

8This module implements the utility functions associated with 

9plugins. 

10""" 

11 

12# Standard 

13from functools import cache 

14import importlib 

15import logging 

16from types import ModuleType 

17from typing import Any, Optional 

18 

19# Third-Party 

20from fastapi.responses import JSONResponse 

21import orjson 

22from pydantic import BaseModel, ConfigDict 

23 

24# First-Party 

25from mcpgateway.plugins.framework.models import GlobalContext, PluginCondition 

26 

27logger = logging.getLogger(__name__) 

28 

29 

class StructuredData(BaseModel):
    """Permissive Pydantic model that exposes arbitrary keys as attributes.

    Framework payload fields declared as ``Any`` keep nested dicts as
    plain dicts after ``model_validate``.  :func:`coerce_nested` wraps
    those dicts in this model so plugin code written against typed
    objects (``payload.result.messages[0].content.text``) keeps working.
    Because ``extra="allow"`` is set, every keyword passed to the
    constructor is stored and readable as an attribute.

    Examples:
        >>> sd = StructuredData(name="test", value=42)
        >>> sd.name
        'test'
        >>> sd.value
        42
        >>> sd.model_dump()
        {'name': 'test', 'value': 42}
    """

    # Accept and retain any field name rather than rejecting unknown keys.
    model_config = ConfigDict(extra="allow")

48 

49 

def coerce_messages(v: Any) -> Any:
    """Give each dict in a ``messages`` list attribute-style access.

    Used as the shared validator for agent payload ``messages`` fields.
    Messages deserialized from JSON are plain dicts; each one is passed
    through :func:`coerce_nested` so that plugin code such as
    ``payload.messages[0].content.text`` works regardless of transport.
    Non-dict elements are kept as they are.

    Args:
        v: The raw value supplied for the ``messages`` field.

    Returns:
        A new list with every dict element coerced, or ``v`` itself when
        it is not a list.
    """
    if not isinstance(v, list):
        return v

    converted = []
    for message in v:
        if isinstance(message, dict):
            converted.append(coerce_nested(message))
        else:
            converted.append(message)
    return converted

67 

68 

_COERCE_MAX_DEPTH = 20
_COERCE_MAX_BREADTH = 500


def coerce_nested(v: Any, *, _depth: int = 0) -> Any:
    """Recursively convert dicts to :class:`StructuredData` for attribute access.

    Already-constructed Pydantic models (e.g. a real ``PromptResult``
    passed by the gateway) are returned as-is. Depth is capped at
    ``_COERCE_MAX_DEPTH`` and breadth (keys per dict / items per list)
    at ``_COERCE_MAX_BREADTH`` to guard against resource exhaustion.

    Dicts containing any non-string key cannot be expanded into keyword
    arguments for :class:`StructuredData`, so they are returned unchanged
    (as plain dicts) instead of raising ``TypeError``.

    Args:
        v: Value to coerce — dict, list, or scalar.
        _depth: Internal recursion depth counter (do not set manually).

    Returns:
        A ``StructuredData`` (for dicts with string keys), a list of
        coerced items, or the original value unchanged.

    Examples:
        >>> from pydantic import BaseModel
        >>> result = coerce_nested({"messages": [{"role": "user", "content": {"type": "text", "text": "hi"}}]})
        >>> result.messages[0].content.text
        'hi'
        >>> class Existing(BaseModel):
        ...     x: int = 1
        >>> coerce_nested(Existing()) is not None
        True
        >>> coerce_nested({1: "one", 2: "two"})
        {1: 'one', 2: 'two'}
    """
    # Past the depth cap, stop descending and hand the value back untouched.
    if _depth >= _COERCE_MAX_DEPTH:
        return v
    # Real Pydantic models already provide attribute access.
    if isinstance(v, BaseModel):
        return v
    if isinstance(v, dict):
        if len(v) > _COERCE_MAX_BREADTH:
            logger.warning("coerce_nested: dict has %d keys (limit %d); returning as plain dict", len(v), _COERCE_MAX_BREADTH)
            return v
        if not all(isinstance(key, str) for key in v):
            # ``StructuredData(**mapping)`` raises "keywords must be strings"
            # for non-str keys, so leave such dicts as plain dicts.
            logger.debug("coerce_nested: dict has non-string keys; returning as plain dict")
            return v
        return StructuredData(**{k: coerce_nested(val, _depth=_depth + 1) for k, val in v.items()})
    if isinstance(v, list):
        if len(v) > _COERCE_MAX_BREADTH:
            logger.warning("coerce_nested: list has %d items (limit %d); skipping coercion", len(v), _COERCE_MAX_BREADTH)
            return v
        return [coerce_nested(item, _depth=_depth + 1) for item in v]
    return v

114 

115 

_BLOCKED_MODULE_PREFIXES = (
    "os",
    # ``os`` re-exports its functions from the platform module
    # (``posix`` on POSIX, ``nt`` on Windows); block those directly too.
    "posix",
    "nt",
    "sys",
    "subprocess",
    # Low-level process-spawning helpers behind subprocess/multiprocessing.
    "_posixsubprocess",
    "_winapi",
    "pty",
    "multiprocessing",
    "shutil",
    "socket",
    "_socket",
    "http.server",
    "ctypes",
    "_ctypes",
    "importlib",
    "builtins",
    "code",
    "codeop",
    "compileall",
    "runpy",
)


@cache  # noqa
def import_module(mod_name: str) -> ModuleType:
    """Import a module after validating the name is safe for dynamic loading.

    Blocks dangerous stdlib modules that could enable arbitrary code
    execution if an attacker controls the plugin ``kind`` field. A name
    is blocked when it equals a blocked entry or is a submodule of one
    (``"os"`` blocks ``"os"`` and ``"os.path"`` but not ``"osmium"``).

    Args:
        mod_name: fully qualified module name

    Returns:
        A module.

    Raises:
        ImportError: If the module name is blocked or contains path traversal.

    Examples:
        >>> mod = import_module('mcpgateway.plugins.framework.utils')
        >>> hasattr(mod, 'import_module')
        True
        >>> import_module('posix')
        Traceback (most recent call last):
        ...
        ImportError: Plugin module 'posix' is blocked for security reasons.
    """
    # Block path-traversal-style names and names with dangerous characters
    if ".." in mod_name or "/" in mod_name or "\\" in mod_name:
        raise ImportError(f"Plugin module name '{mod_name}' contains invalid characters.")
    # Block dangerous stdlib modules (exact match or any submodule).
    for blocked in _BLOCKED_MODULE_PREFIXES:
        if mod_name == blocked or mod_name.startswith(blocked + "."):
            raise ImportError(f"Plugin module '{mod_name}' is blocked for security reasons.")
    return importlib.import_module(mod_name)

162 

163 

def parse_class_name(name: str) -> tuple[str, str]:
    """Split a dotted class path into its module prefix and bare class name.

    Only the last ``.`` is significant; everything before it is the prefix.

    Args:
        name: the qualified class name

    Returns:
        ``(prefix, class_name)``; the prefix is ``""`` when ``name``
        contains no dot.

    Examples:
        >>> parse_class_name('module.submodule.ClassName')
        ('module.submodule', 'ClassName')
        >>> parse_class_name('SimpleClass')
        ('', 'SimpleClass')
        >>> parse_class_name('package.Class')
        ('package', 'Class')
    """
    prefix, separator, class_part = name.rpartition(".")
    if not separator:
        return ("", name)
    return (prefix, class_part)

185 

186 

def matches(condition: PluginCondition, context: GlobalContext) -> bool:
    """Decide whether a plugin condition applies to the current request context.

    Each restriction on ``condition`` is enforced only when it is set
    (truthy):

    * ``server_ids`` / ``tenant_ids``: the context's id must be a member.
    * ``user_patterns``: at least one pattern must be a substring of
      ``context.user``. When the context has no user, this check is
      skipped.

    Args:
        condition: the conditions on the plugin that are required for execution.
        context: the global context.

    Returns:
        True if every configured restriction is satisfied.

    Examples:
        >>> from mcpgateway.plugins.framework import GlobalContext, PluginCondition
        >>> cond = PluginCondition(server_ids={"srv1", "srv2"})
        >>> ctx = GlobalContext(request_id="req1", server_id="srv1")
        >>> matches(cond, ctx)
        True
        >>> ctx2 = GlobalContext(request_id="req2", server_id="srv3")
        >>> matches(cond, ctx2)
        False
        >>> cond2 = PluginCondition(user_patterns=["admin"])
        >>> ctx3 = GlobalContext(request_id="req3", user="admin_user")
        >>> matches(cond2, ctx3)
        True
    """
    allowed_servers = condition.server_ids
    if allowed_servers and context.server_id not in allowed_servers:
        return False

    allowed_tenants = condition.tenant_ids
    if allowed_tenants and context.tenant_id not in allowed_tenants:
        return False

    # Substring match on the user name (not a regex).
    patterns = condition.user_patterns
    if not patterns:
        return True
    user = context.user
    if not user:
        return True
    return any(fragment in user for fragment in patterns)

224 

225 

def get_attr(obj: Any, attr: str, default: Any = "") -> Any:
    """Get attribute from object or dictionary with defensive access.

    This utility function provides a consistent way to access attributes
    on objects that may be either ORM model instances or plain dictionaries.

    For dictionaries, a stored key takes precedence over attribute lookup,
    so keys such as ``"items"``, ``"values"`` or ``"keys"`` return the
    stored value rather than the ``dict`` method of the same name. A plain
    ``dict`` missing the key yields ``default``. Dict subclasses missing
    the key fall back to attribute lookup.

    Any falsy result (``None``, ``0``, ``""``, ``False``, empty containers)
    is replaced by ``default``.

    Args:
        obj: The object or dictionary to get the attribute from.
        attr: The attribute name to retrieve.
        default: The default value to return if attribute is not found.

    Returns:
        The attribute value, or the default if not found, falsy, or obj is None.

    Examples:
        >>> get_attr({"name": "test"}, "name")
        'test'
        >>> get_attr({"name": "test"}, "missing", "default")
        'default'
        >>> get_attr(None, "name", "fallback")
        'fallback'
        >>> class Obj:
        ...     name = "obj_name"
        >>> get_attr(Obj(), "name")
        'obj_name'
        >>> get_attr({"items": [1, 2]}, "items")
        [1, 2]
        >>> get_attr({}, "items", "none")
        'none'
    """
    if obj is None:
        return default
    if isinstance(obj, dict):
        if attr in obj:
            return obj[attr] or default
        if type(obj) is dict:
            # A plain dict has no data attributes: anything hasattr() would
            # find here is a dict method (e.g. ``items``), never a stored value.
            return default
    if hasattr(obj, attr):
        return getattr(obj, attr, default) or default
    return default

259 

260 

def get_matchable_value(payload: Any, hook_type: str) -> Optional[str]:
    """Return the payload attribute used for conditional matching of a hook.

    Tool hooks match on ``name``, prompt hooks on ``prompt_id``, resource
    hooks on ``uri`` and agent hooks on ``agent_id``. Unknown hook types,
    and payloads lacking the attribute, yield ``None``.

    Args:
        payload: The payload object (e.g., ToolPreInvokePayload, AgentPreInvokePayload).
        hook_type: The hook type identifier.

    Returns:
        The matchable value (e.g., tool name, agent ID, resource URI) or None.

    Examples:
        >>> from mcpgateway.plugins.framework import GlobalContext
        >>> from mcpgateway.plugins.framework.hooks.tools import ToolPreInvokePayload
        >>> payload = ToolPreInvokePayload(name="calculator", args={})
        >>> get_matchable_value(payload, "tool_pre_invoke")
        'calculator'
        >>> get_matchable_value(payload, "unknown_hook")
    """
    attribute_for_hook = {
        "tool_pre_invoke": "name",
        "tool_post_invoke": "name",
        "prompt_pre_fetch": "prompt_id",
        "prompt_post_fetch": "prompt_id",
        "resource_pre_fetch": "uri",
        "resource_post_fetch": "uri",
        "agent_pre_invoke": "agent_id",
        "agent_post_invoke": "agent_id",
    }

    attribute = attribute_for_hook.get(hook_type)
    if not attribute:
        return None
    return getattr(payload, attribute, None)

297 return None 

298 

299 

def payload_matches(
    payload: Any,
    hook_type: str,
    conditions: list[PluginCondition],
    context: GlobalContext,
) -> bool:
    """Return True when at least one plugin condition accepts this payload.

    Conditions are OR-ed together. A single condition accepts the payload
    when both of the following hold:

    * its context restrictions pass, as decided by :func:`matches`;
    * for hook types with a payload restriction (tools, prompts,
      resources, agents), the condition's set for that hook is empty, or
      the payload's matchable value is in that set.

    If the payload's matchable value is missing or empty, the payload
    restriction is not applied. An empty ``conditions`` list matches
    everything.

    Args:
        payload: The payload object.
        hook_type: The hook type identifier.
        conditions: List of conditions to check against.
        context: The global context.

    Returns:
        True if the payload matches any condition or if no conditions are specified.

    Examples:
        >>> from mcpgateway.plugins.framework import PluginCondition, GlobalContext
        >>> from mcpgateway.plugins.framework.hooks.tools import ToolPreInvokePayload
        >>> payload = ToolPreInvokePayload(name="calculator", args={})
        >>> cond = PluginCondition(tools={"calculator"})
        >>> ctx = GlobalContext(request_id="req1")
        >>> payload_matches(payload, "tool_pre_invoke", [cond], ctx)
        True
        >>> cond2 = PluginCondition(tools={"other_tool"})
        >>> payload_matches(payload, "tool_pre_invoke", [cond2], ctx)
        False
        >>> payload_matches(payload, "tool_pre_invoke", [], ctx)
        True
    """
    if not conditions:
        return True

    # Which PluginCondition attribute restricts payloads for this hook.
    restriction_attr = {
        "tool_pre_invoke": "tools",
        "tool_post_invoke": "tools",
        "prompt_pre_fetch": "prompts",
        "prompt_post_fetch": "prompts",
        "resource_pre_fetch": "resources",
        "resource_post_fetch": "resources",
        "agent_pre_invoke": "agents",
        "agent_post_invoke": "agents",
    }.get(hook_type)

    for condition in conditions:
        if not matches(condition, context):
            continue

        allowed = getattr(condition, restriction_attr, None) if restriction_attr else None
        if allowed:
            value = get_matchable_value(payload, hook_type)
            if value and value not in allowed:
                continue

        return True

    return False

373 

374 

class ORJSONResponse(JSONResponse):
    """FastAPI ``JSONResponse`` whose body is serialized with orjson.

    Usable anywhere the default ``JSONResponse`` is accepted; fastapi and
    orjson are both existing dependencies of the framework.

    Example:
        >>> response = ORJSONResponse(content={"status": "healthy"})
        >>> response.media_type
        'application/json'
    """

    media_type = "application/json"

    def render(self, content: Any) -> bytes:
        """Serialize ``content`` to JSON bytes via ``orjson.dumps``.

        Non-string dict keys and NumPy values are accepted because the
        ``OPT_NON_STR_KEYS`` and ``OPT_SERIALIZE_NUMPY`` options are set.

        Args:
            content: The content to serialize to JSON.

        Returns:
            JSON bytes ready for HTTP response.
        """
        serialization_flags = orjson.OPT_NON_STR_KEYS | orjson.OPT_SERIALIZE_NUMPY
        return orjson.dumps(content, option=serialization_flags)