Coverage for mcpgateway / plugins / framework / utils.py: 96%

52 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/plugins/framework/utils.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Teryl Taylor, Mihai Criveti 

6 

7Utility module for plugins layer. 

8This module implements the utility functions associated with 

9plugins. 

10""" 

11 

12# Standard 

13from functools import cache 

14import importlib 

15from types import ModuleType 

16from typing import Any, Optional 

17 

18# First-Party 

19from mcpgateway.plugins.framework.models import GlobalContext, PluginCondition 

20 

21 

@cache  # noqa
def import_module(mod_name: str) -> ModuleType:
    """Load a module by its dotted name.

    The call is memoized with ``functools.cache``, so asking for the same
    name again returns the module object obtained on the first call.

    Args:
        mod_name: fully qualified module name

    Returns:
        The imported module object.

    Examples:
        >>> import sys
        >>> mod = import_module('sys')
        >>> mod is sys
        True
        >>> os_mod = import_module('os')
        >>> hasattr(os_mod, 'path')
        True
    """
    loaded = importlib.import_module(mod_name)
    return loaded

42 

43 

def parse_class_name(name: str) -> tuple[str, str]:
    """Split a qualified class name at its last dot.

    Args:
        name: the qualified class name

    Returns:
        A pair ``(prefix, class_name)``. The prefix is the empty string when
        ``name`` contains no dot.

    Examples:
        >>> parse_class_name('module.submodule.ClassName')
        ('module.submodule', 'ClassName')
        >>> parse_class_name('SimpleClass')
        ('', 'SimpleClass')
        >>> parse_class_name('package.Class')
        ('package', 'Class')
    """
    # rpartition yields ('', '', name) when there is no dot, which is exactly
    # the "no prefix" result, so no length check is needed.
    prefix, _, class_name = name.rpartition(".")
    return (prefix, class_name)

65 

66 

def matches(condition: PluginCondition, context: GlobalContext) -> bool:
    """Decide whether a plugin condition applies to the current context.

    Three filters are evaluated in order: server IDs, tenant IDs, then user
    patterns. A filter that is empty/unset on the condition is ignored.
    The user-pattern filter is also ignored when the context carries no
    user, so such a context still matches.

    Args:
        condition: the conditions on the plugin that are required for execution.
        context: the global context.

    Returns:
        True if the plugin matches criteria.

    Examples:
        >>> from mcpgateway.plugins.framework import GlobalContext, PluginCondition
        >>> cond = PluginCondition(server_ids={"srv1", "srv2"})
        >>> ctx = GlobalContext(request_id="req1", server_id="srv1")
        >>> matches(cond, ctx)
        True
        >>> ctx2 = GlobalContext(request_id="req2", server_id="srv3")
        >>> matches(cond, ctx2)
        False
        >>> cond2 = PluginCondition(user_patterns=["admin"])
        >>> ctx3 = GlobalContext(request_id="req3", user="admin_user")
        >>> matches(cond2, ctx3)
        True
    """
    # Server filter: the context's server must be one of the listed IDs.
    allowed_servers = condition.server_ids
    if allowed_servers and context.server_id not in allowed_servers:
        return False

    # Tenant filter: same rule, applied to the tenant ID.
    allowed_tenants = condition.tenant_ids
    if allowed_tenants and context.tenant_id not in allowed_tenants:
        return False

    # User filter: plain substring containment (not regex).
    patterns = condition.user_patterns
    if not patterns:
        return True
    user = context.user
    if not user:
        return True
    return any(fragment in user for fragment in patterns)

104 

105 

def get_attr(obj: Any, attr: str, default: Any = "") -> Any:
    """Get attribute from object or dictionary with defensive access.

    This utility function provides a consistent way to access attributes
    on objects that may be either ORM model instances or plain dictionaries.

    For dictionaries the lookup is by key, and a key always wins over the
    built-in ``dict`` API: a stored ``"items"`` or ``"values"`` entry is
    returned as data rather than the bound ``dict.items``/``dict.values``
    method, and a missing key with such a name yields ``default``.

    Falsy results (``None``, ``""``, ``0``, empty containers) are replaced
    by ``default``.

    Args:
        obj: The object or dictionary to get the attribute from.
        attr: The attribute name to retrieve.
        default: The default value to return if attribute is not found.

    Returns:
        The attribute value, or the default if not found, falsy, or obj is None.

    Examples:
        >>> get_attr({"name": "test"}, "name")
        'test'
        >>> get_attr({"name": "test"}, "missing", "default")
        'default'
        >>> get_attr(None, "name", "fallback")
        'fallback'
        >>> class Obj:
        ...     name = "obj_name"
        >>> get_attr(Obj(), "name")
        'obj_name'
        >>> get_attr({"items": [1, 2]}, "items")
        [1, 2]
        >>> get_attr({}, "keys", "none")
        'none'
        >>> get_attr({"count": 0}, "count", 5)
        5
    """
    if obj is None:
        return default

    if isinstance(obj, dict):
        # Keys must be consulted before attributes: otherwise a key such as
        # "items" or "values" is shadowed by the dict method of the same name.
        if attr in obj:
            return obj[attr] or default
        # Members of the built-in dict type (items, keys, get, ...) are API,
        # not data, so a missing key with such a name is simply "not found".
        if hasattr(dict, attr):
            return default

    # Plain objects, plus extra attributes defined on dict subclasses.
    return getattr(obj, attr, default) or default

139 

140 

def get_matchable_value(payload: Any, hook_type: str) -> Optional[str]:
    """Return the payload field used for conditional matching on a hook.

    Each known hook type is associated with one payload attribute (tool
    hooks use ``name``, prompt hooks ``prompt_id``, resource hooks ``uri``,
    agent hooks ``agent_id``). The value of that attribute is read from the
    payload.

    Args:
        payload: The payload object (e.g., ToolPreInvokePayload, AgentPreInvokePayload).
        hook_type: The hook type identifier.

    Returns:
        The matchable value (e.g., tool name, agent ID, resource URI), or
        None when the hook type is unknown or the payload lacks the attribute.

    Examples:
        >>> from mcpgateway.plugins.framework.hooks.tools import ToolPreInvokePayload
        >>> payload = ToolPreInvokePayload(name="calculator", args={})
        >>> get_matchable_value(payload, "tool_pre_invoke")
        'calculator'
        >>> get_matchable_value(payload, "unknown_hook")
    """
    # hook type -> name of the payload attribute that identifies the target
    payload_attr_by_hook = {
        "tool_pre_invoke": "name",
        "tool_post_invoke": "name",
        "prompt_pre_fetch": "prompt_id",
        "prompt_post_fetch": "prompt_id",
        "resource_pre_fetch": "uri",
        "resource_post_fetch": "uri",
        "agent_pre_invoke": "agent_id",
        "agent_post_invoke": "agent_id",
    }

    payload_attr = payload_attr_by_hook.get(hook_type)
    if not payload_attr:
        return None
    return getattr(payload, payload_attr, None)

178 

179 

def payload_matches(
    payload: Any,
    hook_type: str,
    conditions: list[PluginCondition],
    context: GlobalContext,
) -> bool:
    """Check if a payload matches any of the plugin conditions.

    Conditions are alternatives: the first one that is satisfied makes the
    whole check succeed. A condition is satisfied when it passes the
    GlobalContext check (via matches()) and, if it restricts the entity kind
    of this hook (tools, prompts, resources, agents), the payload's matchable
    value is in that restriction. A payload with no matchable value is not
    rejected by the entity restriction.

    Args:
        payload: The payload object.
        hook_type: The hook type identifier.
        conditions: List of conditions to check against.
        context: The global context.

    Returns:
        True if the payload matches any condition or if no conditions are specified.

    Examples:
        >>> from mcpgateway.plugins.framework import PluginCondition, GlobalContext
        >>> from mcpgateway.plugins.framework.hooks.tools import ToolPreInvokePayload
        >>> payload = ToolPreInvokePayload(name="calculator", args={})
        >>> cond = PluginCondition(tools={"calculator"})
        >>> ctx = GlobalContext(request_id="req1")
        >>> payload_matches(payload, "tool_pre_invoke", [cond], ctx)
        True
        >>> cond2 = PluginCondition(tools={"other_tool"})
        >>> payload_matches(payload, "tool_pre_invoke", [cond2], ctx)
        False
        >>> payload_matches(payload, "tool_pre_invoke", [], ctx)
        True
    """
    # An empty condition list places no restriction at all.
    if not conditions:
        return True

    # hook type -> name of the PluginCondition attribute holding allowed entities
    restriction_attr_by_hook = {
        "tool_pre_invoke": "tools",
        "tool_post_invoke": "tools",
        "prompt_pre_fetch": "prompts",
        "prompt_post_fetch": "prompts",
        "resource_pre_fetch": "resources",
        "resource_post_fetch": "resources",
        "agent_pre_invoke": "agents",
        "agent_post_invoke": "agents",
    }

    for candidate in conditions:
        # Context-level filters (server, tenant, user) come first.
        if not matches(candidate, context):
            continue

        restriction_attr = restriction_attr_by_hook.get(hook_type)
        allowed = getattr(candidate, restriction_attr, None) if restriction_attr else None
        if not allowed:
            # Unknown hook type or no entity restriction: context match suffices.
            return True

        value = get_matchable_value(payload, hook_type)
        if not value or value in allowed:
            return True
        # Otherwise the payload names an entity outside this condition's set;
        # try the next condition.

    # Every condition was rejected.
    return False