Coverage for mcpgateway / plugins / framework / utils.py: 96%
52 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/plugins/framework/utils.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Teryl Taylor, Mihai Criveti
7Utility module for plugins layer.
8This module implements the utility functions associated with
9plugins.
10"""
12# Standard
13from functools import cache
14import importlib
15from types import ModuleType
16from typing import Any, Optional
18# First-Party
19from mcpgateway.plugins.framework.models import GlobalContext, PluginCondition
@cache  # noqa
def import_module(mod_name: str) -> ModuleType:
    """Load a module by its dotted name, memoizing the result.

    The ``functools.cache`` decorator stores the returned module per
    ``mod_name``, so repeated calls with the same name reuse the first result.

    Args:
        mod_name: fully qualified module name

    Returns:
        The imported module object.

    Examples:
        >>> import sys
        >>> mod = import_module('sys')
        >>> mod is sys
        True
        >>> os_mod = import_module('os')
        >>> hasattr(os_mod, 'path')
        True
    """
    module = importlib.import_module(mod_name)
    return module
def parse_class_name(name: str) -> tuple[str, str]:
    """Split a qualified class name at its last dot.

    Args:
        name: the qualified class name

    Returns:
        A ``(prefix, class_name)`` pair. The prefix is everything before the
        last ``.``; it is the empty string when ``name`` contains no dot.

    Examples:
        >>> parse_class_name('module.submodule.ClassName')
        ('module.submodule', 'ClassName')
        >>> parse_class_name('SimpleClass')
        ('', 'SimpleClass')
        >>> parse_class_name('package.Class')
        ('package', 'Class')
    """
    # rpartition yields ('', '', name) when no separator is present, which is
    # exactly the "no prefix" result we want.
    prefix, _sep, cls_name = name.rpartition(".")
    return (prefix, cls_name)
def matches(condition: PluginCondition, context: GlobalContext) -> bool:
    """Decide whether a plugin condition is satisfied by the global context.

    Three independent checks are applied in order; the first failing check
    ends the evaluation:

    * server: if ``condition.server_ids`` is non-empty, ``context.server_id``
      must be one of them.
    * tenant: if ``condition.tenant_ids`` is non-empty, ``context.tenant_id``
      must be one of them.
    * user: if ``condition.user_patterns`` is non-empty *and* the context has
      a user, at least one pattern must be contained in ``context.user``
      (plain ``in`` test, not a regex). When the context has no user this
      check is skipped.

    Args:
        condition: the conditions on the plugin that are required for execution.
        context: the global context.

    Returns:
        True if every applicable check passes, False otherwise.

    Examples:
        >>> from mcpgateway.plugins.framework import GlobalContext, PluginCondition
        >>> cond = PluginCondition(server_ids={"srv1", "srv2"})
        >>> ctx = GlobalContext(request_id="req1", server_id="srv1")
        >>> matches(cond, ctx)
        True
        >>> ctx2 = GlobalContext(request_id="req2", server_id="srv3")
        >>> matches(cond, ctx2)
        False
        >>> cond2 = PluginCondition(user_patterns=["admin"])
        >>> ctx3 = GlobalContext(request_id="req3", user="admin_user")
        >>> matches(cond2, ctx3)
        True
    """
    # Server restriction: only enforced when the condition lists servers.
    server_ok = not condition.server_ids or context.server_id in condition.server_ids
    if not server_ok:
        return False

    # Tenant restriction: only enforced when the condition lists tenants.
    tenant_ok = not condition.tenant_ids or context.tenant_id in condition.tenant_ids
    if not tenant_ok:
        return False

    # User restriction: simple containment test, skipped when no user is set.
    if condition.user_patterns and context.user:
        return any(fragment in context.user for fragment in condition.user_patterns)
    return True
def get_attr(obj: Any, attr: str, default: Any = "") -> Any:
    """Get attribute from object or dictionary with defensive access.

    This utility function provides a consistent way to access attributes
    on objects that may be either ORM model instances or plain dictionaries.

    Dictionaries are always resolved by key. Attribute lookup is used only
    for non-dict objects, so a key that happens to share its name with a
    dict method (``items``, ``keys``, ``get``, ...) yields the stored value
    rather than the bound method.

    Note that a falsy result (``None``, ``""``, ``0``, an empty container)
    is replaced by ``default``.

    Args:
        obj: The object or dictionary to get the attribute from.
        attr: The attribute name (or dictionary key) to retrieve.
        default: The value to return if the attribute is missing or falsy.

    Returns:
        The attribute value, or ``default`` if it is missing, falsy, or
        ``obj`` is None.

    Examples:
        >>> get_attr({"name": "test"}, "name")
        'test'
        >>> get_attr({"name": "test"}, "missing", "default")
        'default'
        >>> get_attr(None, "name", "fallback")
        'fallback'
        >>> class Obj:
        ...     name = "obj_name"
        >>> get_attr(Obj(), "name")
        'obj_name'
        >>> get_attr({"items": [1, 2]}, "items")
        [1, 2]
    """
    if obj is None:
        return default

    # Dict check must come before any attribute probing: every dict has
    # attributes such as ``items``/``keys``/``get``, so an attribute-first
    # lookup would shadow same-named keys with bound methods.
    if isinstance(obj, dict):
        return obj.get(attr, default) or default

    # getattr's own default covers the "attribute missing" case; the trailing
    # ``or default`` keeps the falsy-to-default behaviour.
    return getattr(obj, attr, default) or default
def get_matchable_value(payload: Any, hook_type: str) -> Optional[str]:
    """Return the payload field used for conditional matching on a hook.

    Each supported hook type is associated with one payload attribute
    (``name`` for tool hooks, ``prompt_id`` for prompt hooks, ``uri`` for
    resource hooks, ``agent_id`` for agent hooks). The value of that
    attribute is read from ``payload``.

    Args:
        payload: The payload object (e.g., ToolPreInvokePayload, AgentPreInvokePayload).
        hook_type: The hook type identifier.

    Returns:
        The value of the associated payload attribute, or None when the hook
        type is not recognised or the payload lacks that attribute.

    Examples:
        >>> from mcpgateway.plugins.framework import GlobalContext
        >>> from mcpgateway.plugins.framework.hooks.tools import ToolPreInvokePayload
        >>> payload = ToolPreInvokePayload(name="calculator", args={})
        >>> get_matchable_value(payload, "tool_pre_invoke")
        'calculator'
        >>> get_matchable_value(payload, "unknown_hook")
    """
    # hook type -> name of the payload attribute to match on
    payload_attr_by_hook = {
        "tool_pre_invoke": "name",
        "tool_post_invoke": "name",
        "prompt_pre_fetch": "prompt_id",
        "prompt_post_fetch": "prompt_id",
        "resource_pre_fetch": "uri",
        "resource_post_fetch": "uri",
        "agent_pre_invoke": "agent_id",
        "agent_post_invoke": "agent_id",
    }

    if hook_type not in payload_attr_by_hook:
        return None
    return getattr(payload, payload_attr_by_hook[hook_type], None)
def payload_matches(
    payload: Any,
    hook_type: str,
    conditions: list[PluginCondition],
    context: GlobalContext,
) -> bool:
    """Check if a payload matches any of the plugin conditions.

    Conditions are combined with OR: the first condition that is satisfied
    makes the whole call succeed. A single condition is satisfied when

    1. its GlobalContext restrictions pass (delegated to ``matches()``), and
    2. its payload-specific set for this hook type (``tools``, ``prompts``,
       ``resources`` or ``agents``) is either empty/absent or contains the
       payload's matchable value.

    A payload whose matchable value is missing or empty is not rejected by
    step 2; only a present value that is absent from the set is.

    Args:
        payload: The payload object.
        hook_type: The hook type identifier.
        conditions: List of conditions to check against.
        context: The global context.

    Returns:
        True if the payload matches any condition or if no conditions are specified.

    Examples:
        >>> from mcpgateway.plugins.framework import PluginCondition, GlobalContext
        >>> from mcpgateway.plugins.framework.hooks.tools import ToolPreInvokePayload
        >>> payload = ToolPreInvokePayload(name="calculator", args={})
        >>> cond = PluginCondition(tools={"calculator"})
        >>> ctx = GlobalContext(request_id="req1")
        >>> payload_matches(payload, "tool_pre_invoke", [cond], ctx)
        True
        >>> cond2 = PluginCondition(tools={"other_tool"})
        >>> payload_matches(payload, "tool_pre_invoke", [cond2], ctx)
        False
        >>> payload_matches(payload, "tool_pre_invoke", [], ctx)
        True
    """
    # Nothing to restrict on: every payload is accepted.
    if not conditions:
        return True

    # hook type -> name of the PluginCondition attribute holding allowed values
    condition_field_by_hook = {
        "tool_pre_invoke": "tools",
        "tool_post_invoke": "tools",
        "prompt_pre_fetch": "prompts",
        "prompt_post_fetch": "prompts",
        "resource_pre_fetch": "resources",
        "resource_post_fetch": "resources",
        "agent_pre_invoke": "agents",
        "agent_post_invoke": "agents",
    }
    # The hook type is fixed for the whole call, so resolve the field once.
    condition_field = condition_field_by_hook.get(hook_type)

    for candidate in conditions:
        # Context-level restrictions (server / tenant / user) come first.
        if not matches(candidate, context):
            continue

        allowed_values = getattr(candidate, condition_field, None) if condition_field else None
        if not allowed_values:
            # No payload-specific restriction applies to this condition.
            return True

        value = get_matchable_value(payload, hook_type)
        if not value or value in allowed_values:
            return True
        # Otherwise the value is present but not allowed; try the next condition.

    # Every condition was rejected.
    return False