Coverage for mcpgateway / plugins / framework / utils.py: 99%
91 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/plugins/framework/utils.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Teryl Taylor, Mihai Criveti, Fred Araujo
7Utility module for plugins layer.
8This module implements the utility functions associated with
9plugins.
10"""
12# Standard
13from functools import cache
14import importlib
15import logging
16from types import ModuleType
17from typing import Any, Optional
19# Third-Party
20from fastapi.responses import JSONResponse
21import orjson
22from pydantic import BaseModel, ConfigDict
24# First-Party
25from mcpgateway.plugins.framework.models import GlobalContext, PluginCondition
27logger = logging.getLogger(__name__)
class StructuredData(BaseModel):
    """Open-ended Pydantic model exposing dict contents as attributes.

    Payload fields typed as ``Any`` keep nested dicts as plain dicts when
    validated by Pydantic.  :func:`coerce_nested` wraps such dicts in this
    class so that plugin code written against attribute access, e.g.
    ``payload.result.messages[0].content.text``, keeps working.

    The model declares no fields of its own; every keyword passed to the
    constructor is stored as an extra attribute.

    Examples:
        >>> sd = StructuredData(name="test", value=42)
        >>> sd.name
        'test'
        >>> sd.model_dump()
        {'name': 'test', 'value': 42}
    """

    # Accept and retain arbitrary keyword arguments as attributes.
    model_config = ConfigDict(extra="allow")
def coerce_messages(v: Any) -> Any:
    """Give every dict entry of a ``messages`` list attribute-style access.

    Validator helper shared by agent payload ``messages`` fields.  Messages
    deserialized from JSON arrive as plain dicts; each such dict is passed
    through :func:`coerce_nested` so that plugin code such as
    ``payload.messages[0].content.text`` works regardless of the transport.

    Args:
        v: The raw value for the ``messages`` field.

    Returns:
        A new list in which dict entries have been coerced and all other
        entries are kept unchanged.  A value that is not a list is returned
        untouched.
    """
    if not isinstance(v, list):
        return v

    coerced = []
    for entry in v:
        # Only dict entries are converted; anything else is kept as it is.
        coerced.append(coerce_nested(entry) if isinstance(entry, dict) else entry)
    return coerced
# Limits applied by coerce_nested: maximum recursion depth, and maximum
# number of keys per dict / items per list that will still be coerced.
_COERCE_MAX_DEPTH = 20
_COERCE_MAX_BREADTH = 500


def coerce_nested(v: Any, *, _depth: int = 0) -> Any:
    """Recursively wrap dicts in :class:`StructuredData` for attribute access.

    Values that are already Pydantic models are returned unchanged.  To guard
    against resource exhaustion, recursion stops once ``_depth`` reaches
    ``_COERCE_MAX_DEPTH``, and any dict or list larger than
    ``_COERCE_MAX_BREADTH`` is returned as-is after logging a warning.

    Args:
        v: Value to coerce — dict, list, or scalar.
        _depth: Internal recursion depth counter (do not set manually).

    Returns:
        A ``StructuredData`` (for dicts), a new list of coerced items (for
        lists), or the original value unchanged.

    Examples:
        >>> from pydantic import BaseModel
        >>> result = coerce_nested({"messages": [{"role": "user", "content": {"type": "text", "text": "hi"}}]})
        >>> result.messages[0].content.text
        'hi'
        >>> class Existing(BaseModel):
        ...     x: int = 1
        >>> coerce_nested(Existing()) is not None
        True
    """
    # The depth cap is evaluated first; existing models pass through untouched.
    if _depth >= _COERCE_MAX_DEPTH or isinstance(v, BaseModel):
        return v

    child_depth = _depth + 1

    if isinstance(v, dict):
        key_count = len(v)
        if key_count > _COERCE_MAX_BREADTH:
            logger.warning("coerce_nested: dict has %d keys (limit %d); returning as plain dict", key_count, _COERCE_MAX_BREADTH)
            return v
        fields = {}
        for key, value in v.items():
            fields[key] = coerce_nested(value, _depth=child_depth)
        return StructuredData(**fields)

    if isinstance(v, list):
        item_count = len(v)
        if item_count > _COERCE_MAX_BREADTH:
            logger.warning("coerce_nested: list has %d items (limit %d); skipping coercion", item_count, _COERCE_MAX_BREADTH)
            return v
        coerced_items = []
        for element in v:
            coerced_items.append(coerce_nested(element, _depth=child_depth))
        return coerced_items

    return v
# Module names that import_module refuses to load, either as an exact match
# or as the parent package of the requested name.
_BLOCKED_MODULE_PREFIXES = (
    "os",
    "sys",
    "subprocess",
    "shutil",
    "socket",
    "http.server",
    "ctypes",
    "importlib",
    "builtins",
    "code",
    "codeop",
    "compileall",
    "runpy",
)


@cache  # noqa
def import_module(mod_name: str) -> ModuleType:
    """Validate a module name for dynamic loading, then import it.

    Rejects names containing path-like fragments and names that equal, or
    live underneath, an entry of ``_BLOCKED_MODULE_PREFIXES``.  This blocks
    dangerous stdlib modules that could enable arbitrary code execution if
    an attacker controls the plugin ``kind`` field.  Successful results are
    memoized per module name by ``functools.cache``.

    Args:
        mod_name: fully qualified module name

    Returns:
        The imported module.

    Raises:
        ImportError: If the module name is blocked or contains path traversal.

    Examples:
        >>> mod = import_module('mcpgateway.plugins.framework.utils')
        >>> hasattr(mod, 'import_module')
        True
    """
    # Path-traversal-style fragments never belong in a dotted module name.
    forbidden_fragments = ("..", "/", "\\")
    if any(fragment in mod_name for fragment in forbidden_fragments):
        raise ImportError(f"Plugin module name '{mod_name}' contains invalid characters.")

    # A name is blocked when it is a listed module or a submodule of one.
    is_blocked = any(mod_name == blocked or mod_name.startswith(blocked + ".") for blocked in _BLOCKED_MODULE_PREFIXES)
    if is_blocked:
        raise ImportError(f"Plugin module '{mod_name}' is blocked for security reasons.")

    return importlib.import_module(mod_name)
def parse_class_name(name: str) -> tuple[str, str]:
    """Split a qualified class name at its last dot.

    Args:
        name: the qualified class name

    Returns:
        A pair ``(prefix, class_name)``.  ``prefix`` is everything before the
        last dot, or the empty string when ``name`` contains no dot.

    Examples:
        >>> parse_class_name('module.submodule.ClassName')
        ('module.submodule', 'ClassName')
        >>> parse_class_name('SimpleClass')
        ('', 'SimpleClass')
        >>> parse_class_name('package.Class')
        ('package', 'Class')
    """
    # rpartition yields ('', '', name) when there is no dot, which already
    # gives the empty prefix wanted for unqualified names.
    prefix, _, class_name = name.rpartition(".")
    return (prefix, class_name)
def matches(condition: PluginCondition, context: GlobalContext) -> bool:
    """Decide whether a plugin condition applies to the given context.

    Three filters are evaluated in order: server IDs, tenant IDs, and user
    patterns.  A filter that is empty or unset on the condition is skipped.
    The user-pattern filter is also skipped when the context carries no
    user; otherwise at least one pattern must occur as a substring of the
    user name.

    Args:
        condition: the conditions on the plugin that are required for execution.
        context: the global context.

    Returns:
        True if the plugin matches criteria.

    Examples:
        >>> from mcpgateway.plugins.framework import GlobalContext, PluginCondition
        >>> cond = PluginCondition(server_ids={"srv1", "srv2"})
        >>> ctx = GlobalContext(request_id="req1", server_id="srv1")
        >>> matches(cond, ctx)
        True
        >>> ctx2 = GlobalContext(request_id="req2", server_id="srv3")
        >>> matches(cond, ctx2)
        False
        >>> cond2 = PluginCondition(user_patterns=["admin"])
        >>> ctx3 = GlobalContext(request_id="req3", user="admin_user")
        >>> matches(cond2, ctx3)
        True
    """
    # Server filter: the context's server must be one of the listed IDs.
    allowed_servers = condition.server_ids
    if allowed_servers and context.server_id not in allowed_servers:
        return False

    # Tenant filter: same rule, applied to the tenant ID.
    allowed_tenants = condition.tenant_ids
    if allowed_tenants and context.tenant_id not in allowed_tenants:
        return False

    # User filter: plain substring containment (not a regex match).
    patterns = condition.user_patterns
    if not patterns:
        return True
    user = context.user
    if not user:
        return True
    return any(pattern in user for pattern in patterns)
def get_attr(obj: Any, attr: str, default: Any = "") -> Any:
    """Get attribute from object or dictionary with defensive access.

    This utility function provides a consistent way to access attributes
    on objects that may be either ORM model instances or plain dictionaries.

    Dictionaries are resolved by key only, so a key that happens to share
    its name with a ``dict`` method (``items``, ``keys``, ``get``, ...)
    yields the stored value rather than the bound method.  All other
    objects are resolved by attribute lookup.

    Args:
        obj: The object or dictionary to get the attribute from.
        attr: The attribute name to retrieve.
        default: The default value to return if attribute is not found.

    Returns:
        The attribute value, or the default if not found or obj is None.
        A falsy value (``None``, ``""``, ``0``, empty container) is also
        replaced by the default.

    Examples:
        >>> get_attr({"name": "test"}, "name")
        'test'
        >>> get_attr({"name": "test"}, "missing", "default")
        'default'
        >>> get_attr(None, "name", "fallback")
        'fallback'
        >>> class Obj:
        ...     name = "obj_name"
        >>> get_attr(Obj(), "name")
        'obj_name'
        >>> get_attr({"items": [1, 2]}, "items")
        [1, 2]
    """
    if obj is None:
        return default
    # Dicts must be checked before attribute lookup: every dict has
    # attributes such as ``items`` or ``keys`` that would otherwise shadow
    # a stored key of the same name.
    if isinstance(obj, dict):
        return obj.get(attr, default) or default
    return getattr(obj, attr, default) or default
def get_matchable_value(payload: Any, hook_type: str) -> Optional[str]:
    """Return the payload attribute used for conditional matching.

    Each supported hook type is associated with one payload attribute:
    tool hooks use ``name``, prompt hooks use ``prompt_id``, resource hooks
    use ``uri`` and agent hooks use ``agent_id``.

    Args:
        payload: The payload object (e.g., ToolPreInvokePayload, AgentPreInvokePayload).
        hook_type: The hook type identifier.

    Returns:
        The matchable value (e.g., tool name, agent ID, resource URI), or
        None when the hook type is not recognized or the payload lacks the
        attribute.

    Examples:
        >>> from mcpgateway.plugins.framework import GlobalContext
        >>> from mcpgateway.plugins.framework.hooks.tools import ToolPreInvokePayload
        >>> payload = ToolPreInvokePayload(name="calculator", args={})
        >>> get_matchable_value(payload, "tool_pre_invoke")
        'calculator'
        >>> get_matchable_value(payload, "unknown_hook")
    """
    # Payload attribute -> hook types that are matched on that attribute.
    hooks_by_attribute = (
        ("name", ("tool_pre_invoke", "tool_post_invoke")),
        ("prompt_id", ("prompt_pre_fetch", "prompt_post_fetch")),
        ("uri", ("resource_pre_fetch", "resource_post_fetch")),
        ("agent_id", ("agent_pre_invoke", "agent_post_invoke")),
    )
    attribute_for_hook = {hook: attribute for attribute, hooks in hooks_by_attribute for hook in hooks}

    if hook_type not in attribute_for_hook:
        return None
    return getattr(payload, attribute_for_hook[hook_type], None)
def payload_matches(
    payload: Any,
    hook_type: str,
    conditions: list[PluginCondition],
    context: GlobalContext,
) -> bool:
    """Check if a payload matches any of the plugin conditions.

    Conditions are combined with OR logic: the first condition that passes
    both the GlobalContext check (via matches()) and the payload-specific
    check (tools, prompts, resources, agents) makes the payload match.

    The payload-specific check is applied only when the hook type is a
    recognized one, the condition defines a non-empty set for it, and the
    payload yields a truthy matchable value; in every other case that
    check is treated as passed.

    Args:
        payload: The payload object.
        hook_type: The hook type identifier.
        conditions: List of conditions to check against.
        context: The global context.

    Returns:
        True if the payload matches any condition or if no conditions are specified.

    Examples:
        >>> from mcpgateway.plugins.framework import PluginCondition, GlobalContext
        >>> from mcpgateway.plugins.framework.hooks.tools import ToolPreInvokePayload
        >>> payload = ToolPreInvokePayload(name="calculator", args={})
        >>> cond = PluginCondition(tools={"calculator"})
        >>> ctx = GlobalContext(request_id="req1")
        >>> payload_matches(payload, "tool_pre_invoke", [cond], ctx)
        True
        >>> cond2 = PluginCondition(tools={"other_tool"})
        >>> payload_matches(payload, "tool_pre_invoke", [cond2], ctx)
        False
        >>> payload_matches(payload, "tool_pre_invoke", [], ctx)
        True
    """
    # An empty condition list places no restriction on the payload.
    if not conditions:
        return True

    # Hook type -> name of the PluginCondition attribute holding allowed values.
    scoped_attr = {
        "tool_pre_invoke": "tools",
        "tool_post_invoke": "tools",
        "prompt_pre_fetch": "prompts",
        "prompt_post_fetch": "prompts",
        "resource_pre_fetch": "resources",
        "resource_post_fetch": "resources",
        "agent_pre_invoke": "agents",
        "agent_post_invoke": "agents",
    }.get(hook_type)

    def _accepts(candidate: Any) -> bool:
        """Return True when a single condition admits the payload."""
        # The context-level filters must pass first.
        if not matches(candidate, context):
            return False
        if not scoped_attr:
            return True
        allowed = getattr(candidate, scoped_attr, None)
        if not allowed:
            return True
        value = get_matchable_value(payload, hook_type)
        # Reject only when a value is present and falls outside the set.
        return not (value and value not in allowed)

    # any() stops at the first accepting condition.
    return any(_accepts(candidate) for candidate in conditions)
class ORJSONResponse(JSONResponse):
    """JSON response whose body is serialized with orjson.

    Intended as a drop-in replacement for FastAPI's default JSONResponse,
    using orjson for faster serialization.  The framework already depends
    on both fastapi and orjson.

    Example:
        >>> response = ORJSONResponse(content={"status": "healthy"})
        >>> response.media_type
        'application/json'
    """

    media_type = "application/json"

    def render(self, content: Any) -> bytes:
        """Serialize ``content`` to JSON bytes with orjson.

        Args:
            content: The content to serialize to JSON.

        Returns:
            JSON bytes ready for HTTP response.
        """
        # Permit non-string dict keys and numpy values during serialization.
        serialization_options = orjson.OPT_NON_STR_KEYS | orjson.OPT_SERIALIZE_NUMPY
        return orjson.dumps(content, option=serialization_options)