Coverage for mcpgateway / plugins / framework / manager.py: 98%
321 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/plugins/framework/manager.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Teryl Taylor, Mihai Criveti, Fred Araujo
7Plugin manager.
8Module that manages and calls plugins at hookpoints throughout the gateway.
10This module provides the core plugin management functionality including:
11- Plugin lifecycle management (initialization, execution, shutdown)
12- Timeout protection for plugin execution
13- Context management with automatic cleanup
14- Priority-based plugin ordering
15- Conditional plugin execution based on prompts/servers/tenants
17Examples:
18 >>> # Initialize plugin manager with configuration
19 >>> manager = PluginManager("plugins/config.yaml")
20 >>> # await manager.initialize() # Called in async context
22 >>> # Create test payload and context
23 >>> from mcpgateway.plugins.framework.models import GlobalContext
24 >>> from mcpgateway.plugins.framework.hooks.prompts import PromptPrehookPayload
25 >>> payload = PromptPrehookPayload(prompt_id="123", name="test", args={"user": "input"})
26 >>> context = GlobalContext(request_id="123")
27 >>> # result, contexts = await manager.prompt_pre_fetch(payload, context) # Called in async context
28"""
30# Standard
31import asyncio
32import copy
33import logging
34import threading
35from typing import Any, Optional, Union
37# Third-Party
38from pydantic import BaseModel, RootModel
40# First-Party
41from mcpgateway.observability import create_span
42from mcpgateway.plugins.framework.base import HookRef, Plugin
43from mcpgateway.plugins.framework.errors import convert_exception_to_error, PluginError, PluginViolationError
44from mcpgateway.plugins.framework.hooks.policies import apply_policy, DefaultHookPolicy, HookPayloadPolicy
45from mcpgateway.plugins.framework.loader.config import ConfigLoader
46from mcpgateway.plugins.framework.loader.plugin import PluginLoader
47from mcpgateway.plugins.framework.memory import copyonwrite
48from mcpgateway.plugins.framework.models import Config, GlobalContext, PluginContext, PluginContextTable, PluginErrorModel, PluginMode, PluginPayload, PluginResult
49from mcpgateway.plugins.framework.observability import current_trace_id, ObservabilityProvider
50from mcpgateway.plugins.framework.registry import PluginInstanceRegistry
51from mcpgateway.plugins.framework.settings import settings
52from mcpgateway.plugins.framework.utils import payload_matches
# Plain stdlib logging is used here on purpose: importing the services layer
# would create an import cycle (plugins -> services -> plugins).
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Configuration constants
# ---------------------------------------------------------------------------
# Default per-plugin execution budget, in seconds.
DEFAULT_PLUGIN_TIMEOUT = 30
# Size ceiling enforced by PluginExecutor._validate_payload_size; it is compared
# against len(str(...)) of payload args / result, i.e. a character count.
MAX_PAYLOAD_SIZE = 1000000
# Context housekeeping intervals, in seconds.
CONTEXT_CLEANUP_INTERVAL = 5 * 60
CONTEXT_MAX_AGE = 60 * 60
# Hook whose results record which plugin made the final decision.
HTTP_AUTH_CHECK_PERMISSION_HOOK = "http_auth_check_permission"
DECISION_PLUGIN_METADATA_KEY = "_decision_plugin"
# Metadata keys plugins may not set themselves (filtered out during aggregation).
RESERVED_INTERNAL_METADATA_KEYS = frozenset((DECISION_PLUGIN_METADATA_KEY,))
class PluginTimeoutError(Exception):
    """Signals that a plugin ran longer than its allotted execution time."""
class PayloadSizeError(ValueError):
    """Signals that a hook payload is larger than the permitted maximum size."""
class PluginExecutor:
    """Executes a list of plugins with timeout protection and error handling.

    This class manages the execution of plugins in priority order, handling:
    - Timeout protection for each plugin
    - Context management between plugins
    - Error isolation to prevent plugin failures from affecting the gateway
    - Metadata aggregation from multiple plugins
    - Policy-controlled merging of payload modifications along the chain

    Examples:
        >>> executor = PluginExecutor()
        >>> # In async context:
        >>> # result, contexts = await executor.execute(
        >>> #     hook_refs=hook_refs,
        >>> #     payload=payload,
        >>> #     global_context=context,
        >>> #     hook_type="prompt_pre_fetch",
        >>> # )
    """

    def __init__(
        self,
        config: Optional[Config] = None,
        timeout: int = DEFAULT_PLUGIN_TIMEOUT,
        observability: Optional[ObservabilityProvider] = None,
        hook_policies: Optional[dict[str, HookPayloadPolicy]] = None,
    ):
        """Initialize the plugin executor.

        Args:
            config: The plugin manager configuration. Its
                ``plugin_settings.fail_on_plugin_error`` decides whether plugin
                errors and timeouts are re-raised.
            timeout: Maximum execution time per plugin in seconds.
            observability: Optional observability provider implementing ObservabilityProvider protocol.
            hook_policies: Per-hook-type payload modification policies, keyed by hook type.
        """
        self.timeout = timeout
        self.config = config
        self.observability = observability
        self.hook_policies: dict[str, HookPayloadPolicy] = hook_policies or {}
        # Governs payload modifications on hooks that have no explicit policy.
        self.default_hook_policy = DefaultHookPolicy(settings.default_hook_policy)

    async def execute(
        self,
        hook_refs: list[HookRef],
        payload: PluginPayload,
        global_context: GlobalContext,
        hook_type: str,
        local_contexts: Optional[PluginContextTable] = None,
        violations_as_exceptions: bool = False,
    ) -> tuple[PluginResult, PluginContextTable | None]:
        """Execute plugins in priority order with timeout protection.

        Args:
            hook_refs: List of hook references to execute, sorted by priority.
            payload: The payload to be processed by plugins.
            global_context: Shared context for all plugins containing request metadata.
            hook_type: The hook type identifier (e.g., "tool_pre_invoke").
            local_contexts: Optional existing contexts from previous hook executions.
            violations_as_exceptions: Raise violations as exceptions rather than as returns.

        Returns:
            A tuple containing:
            - PluginResult with processing status, modified payload, and metadata
            - PluginContextTable with updated local contexts for each plugin
              (None when ``hook_refs`` is empty)

        Raises:
            PayloadSizeError: If the payload exceeds MAX_PAYLOAD_SIZE.
            PluginError: If there is an error inside a plugin.
            PluginViolationError: If a violation occurs and violations_as_exceptions is set.

        Examples:
            >>> # Execute plugins with timeout protection
            >>> from mcpgateway.plugins.framework.hooks.prompts import PromptHookType
            >>> executor = PluginExecutor(timeout=30)
            >>> # hook_refs: hook references registered for PromptHookType.PROMPT_PRE_FETCH
            >>> # In async context:
            >>> # result, contexts = await executor.execute(
            >>> #     hook_refs=hook_refs,
            >>> #     payload=PromptPrehookPayload(prompt_id="123", name="test", args={}),
            >>> #     global_context=GlobalContext(request_id="123"),
            >>> #     hook_type=PromptHookType.PROMPT_PRE_FETCH,
            >>> # )
        """
        if not hook_refs:
            return (PluginResult(modified_payload=None), None)

        # Validate payload size
        self._validate_payload_size(payload)

        # Look up the policy for this hook type (may be None)
        policy = self.hook_policies.get(hook_type)

        res_local_contexts = {}
        combined_metadata: dict[str, Any] = {}
        current_payload: PluginPayload | None = None
        decision_plugin_name: Optional[str] = None
        max_retry_delay_ms: int = 0
        executed_plugins = 0
        skipped_plugins = 0
        stopped_by_plugin: Optional[str] = None

        with create_span(
            "plugin.hook.invoke",
            {
                "plugin.hook.type": hook_type,
                "plugin.chain.length": len(hook_refs),
            },
        ) as hook_chain_span:
            for hook_ref in hook_refs:
                # Skip disabled plugins
                if hook_ref.plugin_ref.mode == PluginMode.DISABLED:
                    skipped_plugins += 1
                    continue

                # Check if plugin conditions match current context
                if hook_ref.plugin_ref.conditions and not payload_matches(payload, hook_type, hook_ref.plugin_ref.conditions, global_context):
                    logger.debug("Skipping plugin %s - conditions not met", hook_ref.plugin_ref.name)
                    skipped_plugins += 1
                    continue

                # Each plugin sees copy-on-write views of the shared state/metadata;
                # execute_plugin merges them back only for enforce modes.
                tmp_global_context = GlobalContext(
                    request_id=global_context.request_id,
                    user=global_context.user,
                    tenant_id=global_context.tenant_id,
                    server_id=global_context.server_id,
                    state={} if not global_context.state else copyonwrite(global_context.state),
                    metadata={} if not global_context.metadata else copyonwrite(global_context.metadata),
                )
                # Get or create local context for this plugin
                local_context_key = global_context.request_id + hook_ref.plugin_ref.uuid
                if local_contexts and local_context_key in local_contexts:
                    local_context = local_contexts[local_context_key]
                    local_context.global_context = tmp_global_context
                else:
                    local_context = PluginContext(global_context=tmp_global_context)
                res_local_contexts[local_context_key] = local_context

                # When a policy exists or default=deny is active, deep-copy the
                # payload before handing it to the plugin. The plugin operates on
                # the copy, so in-place nested mutations cannot pollute the live chain.
                effective_payload = current_payload if current_payload is not None else payload
                needs_isolation = policy or self.default_hook_policy == DefaultHookPolicy.DENY or isinstance(effective_payload, RootModel)
                if needs_isolation:
                    plugin_input = effective_payload.model_copy(deep=True) if isinstance(effective_payload, BaseModel) else copy.deepcopy(effective_payload)
                else:
                    plugin_input = effective_payload

                result = await self.execute_plugin(
                    hook_ref,
                    plugin_input,
                    local_context,
                    violations_as_exceptions,
                    global_context,
                    combined_metadata,
                )
                executed_plugins += 1

                # Propagate retry signal — take the largest delay requested by any plugin
                max_retry_delay_ms = max(max_retry_delay_ms, result.retry_delay_ms)

                # Apply policy-based controlled merge (per-plugin)
                if result.modified_payload is not None:
                    if policy:
                        if isinstance(result.modified_payload, type(effective_payload)) and isinstance(effective_payload, BaseModel):
                            filtered = apply_policy(
                                effective_payload,
                                result.modified_payload,
                                policy,
                            )
                            if filtered is not None:
                                current_payload = filtered
                                decision_plugin_name = hook_ref.plugin_ref.name
                        elif isinstance(result.modified_payload, (PluginPayload, dict)):
                            logger.debug(
                                "Plugin %s returned cross-type payload (%s -> %s) on hook %s; accepting without field filtering",
                                hook_ref.plugin_ref.name,
                                type(effective_payload).__name__,
                                type(result.modified_payload).__name__,
                                hook_type,
                            )
                            current_payload = result.modified_payload
                            decision_plugin_name = hook_ref.plugin_ref.name
                        else:
                            logger.warning(
                                "Plugin %s returned unexpected type %s on hook %s; ignoring modification",
                                hook_ref.plugin_ref.name,
                                type(result.modified_payload).__name__,
                                hook_type,
                            )
                    elif self.default_hook_policy == DefaultHookPolicy.ALLOW:
                        current_payload = result.modified_payload
                        decision_plugin_name = hook_ref.plugin_ref.name
                    else:
                        logger.warning(
                            "Plugin %s attempted payload modification on hook %s but no policy is defined and default is deny",
                            hook_ref.plugin_ref.name,
                            hook_type,
                        )

                # Both ENFORCE and ENFORCE_IGNORE_ERROR honour continue_processing=False
                # and halt the chain. They differ only in error handling.
                if not result.continue_processing and hook_ref.plugin_ref.mode in (PluginMode.ENFORCE, PluginMode.ENFORCE_IGNORE_ERROR):
                    stopped_by_plugin = hook_ref.plugin_ref.name
                    if hook_chain_span is not None:
                        hook_chain_span.set_attribute("plugin.chain.stopped", True)
                        hook_chain_span.set_attribute("plugin.chain.stopped_by", hook_ref.plugin_ref.name)
                        hook_chain_span.set_attribute("plugin.executed_count", executed_plugins)
                        hook_chain_span.set_attribute("plugin.skipped_count", skipped_plugins)
                    if hook_type == HTTP_AUTH_CHECK_PERMISSION_HOOK and decision_plugin_name:
                        combined_metadata[DECISION_PLUGIN_METADATA_KEY] = decision_plugin_name
                    return (
                        PluginResult(
                            continue_processing=False,
                            modified_payload=current_payload,
                            violation=result.violation,
                            metadata=combined_metadata,
                        ),
                        res_local_contexts,
                    )

            # Record chain statistics while the span is still open.
            if hook_chain_span is not None:
                hook_chain_span.set_attribute("plugin.executed_count", executed_plugins)
                hook_chain_span.set_attribute("plugin.skipped_count", skipped_plugins)
                hook_chain_span.set_attribute("plugin.chain.stopped", stopped_by_plugin is not None)

        if hook_type == HTTP_AUTH_CHECK_PERMISSION_HOOK and decision_plugin_name:
            combined_metadata[DECISION_PLUGIN_METADATA_KEY] = decision_plugin_name

        return (PluginResult(continue_processing=True, modified_payload=current_payload, violation=None, metadata=combined_metadata, retry_delay_ms=max_retry_delay_ms), res_local_contexts)

    async def execute_plugin(
        self,
        hook_ref: HookRef,
        payload: PluginPayload,
        local_context: PluginContext,
        violations_as_exceptions: bool,
        global_context: Optional[GlobalContext] = None,
        combined_metadata: Optional[dict[str, Any]] = None,
    ) -> PluginResult:
        """Execute a single plugin with timeout protection.

        Args:
            hook_ref: Hooking structure that contains the plugin and hook.
            payload: The payload to be processed by the plugin.
            local_context: The plugin's local context.
            violations_as_exceptions: Raise violations as exceptions rather than as returns.
            global_context: Shared context for all plugins containing request metadata.
                For enforce modes, the plugin's state/metadata changes are merged into it.
            combined_metadata: Running combination of the metadata of all plugins; updated
                in place (reserved internal keys are filtered out).

        Returns:
            The PluginResult produced by the plugin. When an ENFORCE-mode plugin blocks,
            a new result carrying the input payload, the violation and the combined metadata.
            When a plugin fails or times out and the error is not re-raised, a result with
            ``continue_processing=True``.

        Raises:
            PluginError: If the plugin errors or times out and either
                ``fail_on_plugin_error`` is enabled or the plugin runs in ENFORCE mode.
            PluginViolationError: If an ENFORCE-mode plugin blocks and
                ``violations_as_exceptions`` is set.
        """
        try:
            # Execute plugin with timeout protection
            result = await self._execute_with_timeout(hook_ref, payload, local_context)
            # Only merge global state for enforce modes; permissive plugins
            # operate on copy-on-write snapshots and should not mutate shared state.
            if local_context.global_context and global_context and hook_ref.plugin_ref.mode in (PluginMode.ENFORCE, PluginMode.ENFORCE_IGNORE_ERROR):
                global_context.state.update(local_context.global_context.state)
                global_context.metadata.update(local_context.global_context.metadata)
            # Aggregate metadata from all plugins, never letting a plugin set reserved keys.
            if result.metadata and combined_metadata is not None:
                combined_metadata.update({k: v for k, v in result.metadata.items() if k not in RESERVED_INTERNAL_METADATA_KEYS})

            # Set plugin name in violation if present
            if result.violation:
                result.violation.plugin_name = hook_ref.plugin_ref.plugin.name

            # Handle plugin blocking the request
            if not result.continue_processing:
                if hook_ref.plugin_ref.mode == PluginMode.ENFORCE:
                    logger.warning("Plugin %s blocked request in enforce mode", hook_ref.plugin_ref.plugin.name)
                    if violations_as_exceptions:
                        if result.violation:
                            plugin_name = result.violation.plugin_name
                            violation_reason = result.violation.reason
                            violation_desc = result.violation.description
                            violation_code = result.violation.code
                            raise PluginViolationError(
                                f"{hook_ref.name} blocked by plugin {plugin_name}: {violation_code} - {violation_reason} ({violation_desc})",
                                violation=result.violation,
                            )
                        raise PluginViolationError(f"{hook_ref.name} blocked by plugin")
                    return PluginResult(
                        continue_processing=False,
                        modified_payload=payload,
                        violation=result.violation,
                        metadata=combined_metadata,
                    )
                if hook_ref.plugin_ref.mode == PluginMode.PERMISSIVE:
                    logger.warning(
                        "Plugin %s would block (permissive mode): %s",
                        hook_ref.plugin_ref.plugin.name,
                        result.violation.description if result.violation else "No description",
                    )
            return result
        except asyncio.TimeoutError as exc:
            logger.error("Plugin %s timed out after %ds", hook_ref.plugin_ref.name, self.timeout)
            if (self.config and self.config.plugin_settings.fail_on_plugin_error) or hook_ref.plugin_ref.mode == PluginMode.ENFORCE:
                raise PluginError(
                    error=PluginErrorModel(
                        message=f"Plugin {hook_ref.plugin_ref.name} exceeded {self.timeout}s timeout",
                        plugin_name=hook_ref.plugin_ref.name,
                    )
                ) from exc
            # In permissive or enforce_ignore_error mode, continue with next plugin
        except PluginViolationError:
            raise
        except PluginError as pe:
            logger.error("Plugin %s failed with error: %s", hook_ref.plugin_ref.name, str(pe))
            if (self.config and self.config.plugin_settings.fail_on_plugin_error) or hook_ref.plugin_ref.mode == PluginMode.ENFORCE:
                raise
        except Exception as e:
            logger.error("Plugin %s failed with error: %s", hook_ref.plugin_ref.name, str(e))
            if (self.config and self.config.plugin_settings.fail_on_plugin_error) or hook_ref.plugin_ref.mode == PluginMode.ENFORCE:
                raise PluginError(error=convert_exception_to_error(e, hook_ref.plugin_ref.name)) from e
            # In permissive or enforce_ignore_error mode, continue with next plugin
        # Return a result indicating processing should continue despite the error
        return PluginResult(continue_processing=True)

    async def _execute_with_timeout(self, hook_ref: HookRef, payload: PluginPayload, context: PluginContext) -> PluginResult:
        """Execute a plugin with timeout protection.

        Opens an OpenTelemetry span for the call and, when a trace is active and an
        observability provider is configured, a provider span as well. The provider
        span is always closed: with status "ok" on success, "error" on any failure
        (including cancellation).

        Args:
            hook_ref: Reference to the hook and plugin to execute.
            payload: Payload to process.
            context: Plugin execution context.

        Returns:
            Result from plugin execution.

        Raises:
            asyncio.TimeoutError: If plugin exceeds timeout.
            asyncio.CancelledError: If plugin execution is cancelled.
            Exception: Re-raised from plugin hook execution failures.
        """
        plugin_ref = hook_ref.plugin_ref
        mode_label = plugin_ref.mode.value if hasattr(plugin_ref.mode, "value") else str(plugin_ref.mode)

        # Start observability span if tracing is active
        trace_id = current_trace_id.get()
        span_id = None

        if trace_id and self.observability:
            try:
                span_id = self.observability.start_span(
                    trace_id=trace_id,
                    name=f"plugin.execute.{plugin_ref.name}",
                    kind="internal",
                    resource_type="plugin",
                    resource_name=plugin_ref.name,
                    attributes={
                        "plugin.name": plugin_ref.name,
                        "plugin.uuid": plugin_ref.uuid,
                        "plugin.mode": mode_label,
                        "plugin.priority": plugin_ref.priority,
                        "plugin.timeout": self.timeout,
                    },
                )
            except Exception as e:
                logger.debug("Plugin observability start_span failed: %s", e)

        with create_span(
            "plugin.execute",
            {
                "plugin.name": plugin_ref.name,
                "plugin.uuid": plugin_ref.uuid,
                "plugin.mode": mode_label,
                "plugin.priority": plugin_ref.priority,
                "plugin.timeout": self.timeout,
                "plugin.hook.type": hook_ref.name,
                "plugin.kind": getattr(getattr(plugin_ref.plugin, "config", None), "kind", None),
                "contextforge.runtime": "python",
            },
        ) as otel_span:
            # Execute plugin
            try:
                result = await asyncio.wait_for(hook_ref.hook(payload, context), timeout=self.timeout)
            except BaseException:
                # BaseException (not Exception) so that asyncio.CancelledError, which
                # derives from BaseException, also closes the provider span instead of
                # leaking it. The exception is always re-raised.
                if span_id is not None:
                    try:
                        self.observability.end_span(span_id=span_id, status="error")
                    except Exception as e:
                        logger.debug("Plugin observability end_span failed: %s", e)
                raise

            if otel_span is not None:
                otel_span.set_attribute("plugin.had_violation", result.violation is not None)
                otel_span.set_attribute("plugin.modified_payload", result.modified_payload is not None)
                otel_span.set_attribute("plugin.continue_processing", result.continue_processing)
                otel_span.set_attribute("plugin.stopped_chain", not result.continue_processing)

            # End span with success
            if span_id is not None:
                try:
                    self.observability.end_span(
                        span_id=span_id,
                        status="ok",
                        attributes={
                            "plugin.had_violation": result.violation is not None,
                            "plugin.modified_payload": result.modified_payload is not None,
                            "plugin.continue_processing": result.continue_processing,
                        },
                    )
                except Exception as e:
                    logger.debug("Plugin observability end_span failed: %s", e)

        return result

    def _validate_payload_size(self, payload: Any) -> None:
        """Validate that payload doesn't exceed size limits.

        Only two payload shapes are checked: payloads with non-empty ``args``
        (summed ``len(str(value))`` over the arg values) and, otherwise, payloads
        with a non-empty ``result`` (``len(str(result))``). The measure is a
        character count of the string form, used as an estimate of the size.

        Args:
            payload: The payload to validate.

        Raises:
            PayloadSizeError: If payload exceeds MAX_PAYLOAD_SIZE.
        """
        # For PromptPrehookPayload, check args size
        if hasattr(payload, "args") and payload.args:
            total_size = sum(len(str(v)) for v in payload.args.values())
            if total_size > MAX_PAYLOAD_SIZE:
                raise PayloadSizeError(f"Payload size {total_size} exceeds limit of {MAX_PAYLOAD_SIZE} bytes")
        # For PromptPosthookPayload, check result size
        elif hasattr(payload, "result") and payload.result:
            # Estimate size of result messages
            total_size = len(str(payload.result))
            if total_size > MAX_PAYLOAD_SIZE:
                raise PayloadSizeError(f"Result size {total_size} exceeds limit of {MAX_PAYLOAD_SIZE} bytes")
518class PluginManager:
519 """Plugin manager for managing the plugin lifecycle.
521 This class implements a thread-safe Borg singleton pattern to ensure consistent
522 plugin management across the application. It handles:
523 - Plugin discovery and loading from configuration
524 - Plugin lifecycle management (initialization, execution, shutdown)
525 - Context management with automatic cleanup
526 - Hook execution orchestration
528 Thread Safety:
529 Uses double-checked locking to prevent race conditions when multiple threads
530 create PluginManager instances simultaneously. The first instance to acquire
531 the lock loads the configuration; subsequent instances reuse the shared state.
533 Attributes:
534 config: The loaded plugin configuration.
535 plugin_count: Number of currently loaded plugins.
536 initialized: Whether the manager has been initialized.
538 Examples:
539 >>> # Initialize plugin manager
540 >>> manager = PluginManager("plugins/config.yaml")
541 >>> # In async context:
542 >>> # await manager.initialize()
543 >>> # print(f"Loaded {manager.plugin_count} plugins")
544 >>>
545 >>> # Execute prompt hooks
546 >>> from mcpgateway.plugins.framework.models import GlobalContext
547 >>> from mcpgateway.plugins.framework.hooks.prompts import PromptPrehookPayload
548 >>> payload = PromptPrehookPayload(prompt_id="123", name="test", args={})
549 >>> context = GlobalContext(request_id="req-123")
550 >>> # In async context:
551 >>> # result, contexts = await manager.prompt_pre_fetch(payload, context)
552 >>>
553 >>> # Shutdown when done
554 >>> # await manager.shutdown()
555 """
557 __shared_state: dict[Any, Any] = {}
558 __lock: threading.Lock = threading.Lock() # Thread safety for synchronous init
559 _async_lock: asyncio.Lock | None = None # Async lock for initialize/shutdown
560 _loader: PluginLoader = PluginLoader()
561 _initialized: bool = False
562 _registry: PluginInstanceRegistry = PluginInstanceRegistry()
563 _config: Config | None = None
564 _config_path: str | None = None
565 _executor: PluginExecutor | None = None
567 def __init__(
568 self,
569 config: str = "",
570 timeout: int = DEFAULT_PLUGIN_TIMEOUT,
571 observability: Optional[ObservabilityProvider] = None,
572 hook_policies: Optional[dict[str, HookPayloadPolicy]] = None,
573 ):
574 """Initialize plugin manager.
576 PluginManager implements a thread-safe Borg singleton:
577 - Shared state is initialized only once across all instances.
578 - Subsequent instantiations reuse same state and skip config reload.
579 - Uses double-checked locking to prevent race conditions in multi-threaded environments.
581 Thread Safety:
582 The initialization uses a double-checked locking pattern to ensure that
583 config loading only happens once, even when multiple threads create
584 PluginManager instances simultaneously.
586 Args:
587 config: Path to plugin configuration file (YAML).
588 timeout: Maximum execution time per plugin in seconds.
589 observability: Optional observability provider implementing ObservabilityProvider protocol.
590 hook_policies: Per-hook-type payload modification policies (injected by gateway).
592 Examples:
593 >>> # Initialize with configuration file
594 >>> manager = PluginManager("plugins/config.yaml")
596 >>> # Initialize with custom timeout
597 >>> manager = PluginManager("plugins/config.yaml", timeout=60)
598 """
599 self.__dict__ = self.__shared_state
601 # Only initialize once (first instance when shared state is empty)
602 # Use lock to prevent race condition in multi-threaded environments
603 if not self.__shared_state:
604 with self.__lock:
605 # Double-check after acquiring lock (another thread may have initialized)
606 if not self.__shared_state:
607 if config:
608 self._config = ConfigLoader.load_config(config)
609 self._config_path = config
611 # Update executor with timeout, observability, and policies
612 self._executor = PluginExecutor(
613 config=self._config,
614 timeout=timeout,
615 observability=observability,
616 hook_policies=hook_policies,
617 )
618 elif hook_policies:
619 # Allow hook policies to be injected after initial Borg creation.
620 # This handles the case where the first PluginManager instantiation
621 # (e.g. from a service) didn't have policies, but a later one does.
622 with self.__lock:
623 executor = self._get_executor()
624 # Only update timeout if caller provided a non-default value
625 if timeout != DEFAULT_PLUGIN_TIMEOUT:
626 executor.timeout = timeout
627 if not executor.hook_policies:
628 executor.hook_policies = hook_policies
629 elif executor.hook_policies != hook_policies:
630 logger.warning("PluginManager: hook_policies already set; ignoring new policies (call reset() first to replace them)")
631 if observability and not executor.observability:
632 executor.observability = observability
633 elif self._executor is None:
634 # Defensive initialization for unusual state transitions in tests.
635 with self.__lock:
636 if self._executor is None:
637 self._executor = PluginExecutor(config=self._config, timeout=timeout, observability=observability)
639 def _get_executor(self) -> PluginExecutor:
640 """Get plugin executor, creating it lazily if necessary.
642 Returns:
643 PluginExecutor: The plugin executor instance.
644 """
645 if self._executor is None:
646 self._executor = PluginExecutor(config=self._config)
647 return self._executor
649 @property
650 def executor(self) -> PluginExecutor:
651 """Expose executor for tests and internal callers.
653 Returns:
654 PluginExecutor: The plugin executor instance.
655 """
656 return self._get_executor()
658 @executor.setter
659 def executor(self, value: PluginExecutor) -> None:
660 """Set the plugin executor instance.
662 Args:
663 value: The plugin executor to assign.
664 """
665 self._executor = value
667 @classmethod
668 def reset(cls) -> None:
669 """Reset the Borg pattern shared state.
671 This method clears all shared state, allowing a fresh PluginManager
672 instance to be created with new configuration. Primarily used for testing.
674 Thread-safe: Uses lock to ensure atomic reset operation.
676 Examples:
677 >>> # Between tests, reset shared state
678 >>> PluginManager.reset()
679 >>> manager = PluginManager("new_config.yaml")
680 """
681 with cls.__lock:
682 cls.__shared_state.clear()
683 cls._initialized = False
684 cls._config = None
685 cls._config_path = None
686 cls._async_lock = None
687 cls._registry = PluginInstanceRegistry()
688 cls._executor = None
689 cls._loader = PluginLoader()
691 @property
692 def config(self) -> Config | None:
693 """Plugin manager configuration.
695 Returns:
696 The plugin configuration object or None if not configured.
697 """
698 return self._config
700 @property
701 def plugin_count(self) -> int:
702 """Number of plugins loaded.
704 Returns:
705 The number of currently loaded plugins.
706 """
707 return self._registry.plugin_count
709 @property
710 def initialized(self) -> bool:
711 """Plugin manager initialization status.
713 Returns:
714 True if the plugin manager has been initialized.
715 """
716 return self._initialized
718 @property
719 def observability(self) -> Optional[ObservabilityProvider]:
720 """Current observability provider.
722 Returns:
723 The observability provider or None if not configured.
724 """
725 return self._executor.observability
727 @observability.setter
728 def observability(self, provider: Optional[ObservabilityProvider]) -> None:
729 """Set the observability provider.
731 Thread-safe: uses lock to prevent races with concurrent readers.
733 Args:
734 provider: ObservabilityProvider to inject into the executor.
735 """
736 with self.__lock:
737 self._executor.observability = provider
739 def get_plugin(self, name: str) -> Optional[Plugin]:
740 """Get a plugin by name.
742 Args:
743 name: the name of the plugin to return.
745 Returns:
746 A plugin.
747 """
748 plugin_ref = self._registry.get_plugin(name)
749 return plugin_ref.plugin if plugin_ref else None
751 def has_hooks_for(self, hook_type: str) -> bool:
752 """Check if there are any hooks registered for a specific hook type.
754 Args:
755 hook_type: The type of hook to check for.
757 Returns:
758 True if there are hooks registered for the specified type, False otherwise.
759 """
760 return self._registry.has_hooks_for(hook_type)
762 async def initialize(self) -> None:
763 """Initialize the plugin manager and load all configured plugins.
765 This method:
766 1. Loads plugin configurations from the config file
767 2. Instantiates each enabled plugin
768 3. Registers plugins with the registry
769 4. Validates plugin initialization
771 Thread Safety:
772 Uses asyncio.Lock to prevent concurrent initialization from multiple
773 coroutines or async tasks. Combined with threading.Lock in __init__
774 for full multi-threaded safety.
776 Raises:
777 RuntimeError: If plugin initialization fails with an exception.
778 ValueError: If a plugin cannot be initialized or registered.
780 Examples:
781 >>> manager = PluginManager("plugins/config.yaml")
782 >>> # In async context:
783 >>> # await manager.initialize()
784 >>> # Manager is now ready to execute plugins
785 """
786 # Initialize async lock lazily (can't create asyncio.Lock in class definition)
787 with self.__lock:
788 if self._async_lock is None:
789 self._async_lock = asyncio.Lock()
791 async with self._async_lock:
792 # Double-check after acquiring lock
793 if self._initialized:
794 logger.debug("Plugin manager already initialized")
795 return
797 # Defensive cleanup: registry should be empty when not initialized
798 if self._registry.plugin_count:
799 logger.debug("Plugin registry not empty before initialize; clearing stale plugins")
800 await self._registry.shutdown()
802 plugins = self._config.plugins if self._config and self._config.plugins else []
803 loaded_count = 0
805 for plugin_config in plugins:
806 try:
807 # For disabled plugins, create a stub plugin without full instantiation
808 if plugin_config.mode != PluginMode.DISABLED:
809 # Fully instantiate enabled plugins
810 plugin = await self._loader.load_and_instantiate_plugin(plugin_config)
811 if plugin:
812 self._registry.register(plugin)
813 loaded_count += 1
814 logger.info("Loaded plugin: %s (mode: %s)", plugin_config.name, plugin_config.mode)
815 else:
816 raise ValueError(f"Unable to instantiate plugin: {plugin_config.name}")
817 else:
818 logger.info("Plugin: %s is disabled. Ignoring.", plugin_config.name)
820 except Exception as e:
821 # Clean error message without stack trace spam
822 logger.error("Failed to load plugin %s: {%s}", plugin_config.name, str(e))
823 if self._config and not self._config.plugin_settings.fail_on_plugin_error:
824 logger.warning("Skipping plugin %s because fail_on_plugin_error is disabled", plugin_config.name)
825 continue
826 # Let it crash gracefully with a clean error
827 raise RuntimeError(f"Plugin initialization failed: {plugin_config.name} - {str(e)}") from e
829 self._initialized = True
830 logger.info("Plugin manager initialized with %s plugins", loaded_count)
async def shutdown(self) -> None:
    """Shut down all registered plugins and mark the manager uninitialized.

    This method:
    1. Calls ``shutdown()`` on the plugin registry (the same call ``initialize()``
       uses to clear stale plugins from the registry)
    2. Resets the initialization flag so ``initialize()`` can run again

    If the manager is not currently initialized, it only logs at debug level and
    returns without touching the registry. Repeated calls are therefore harmless.

    Thread Safety:
        The shared asyncio.Lock is created lazily under the instance's threading
        lock. It is then held for the whole shutdown, so it cannot interleave with
        ``initialize()`` or with another ``shutdown()`` call.

    Note: The config is preserved to allow modifying settings and re-initializing.
    To fully reset for a new config, create a new PluginManager instance.

    Examples:
        >>> manager = PluginManager("plugins/config.yaml")
        >>> # In async context:
        >>> # await manager.initialize()
        >>> # ... use the manager ...
        >>> # await manager.shutdown()
    """
    # asyncio.Lock cannot be created at class-definition time, so build it on first
    # use. The threading lock stops two threads from racing to create two locks.
    with self.__lock:
        if self._async_lock is None:
            self._async_lock = asyncio.Lock()

    async with self._async_lock:
        if not self._initialized:
            logger.debug("Plugin manager not initialized, nothing to shutdown")
            return

        logger.info("Shutting down plugin manager")

        # If the registry shutdown raises, the exception propagates and the
        # manager stays marked as initialized.
        await self._registry.shutdown()

        # Reset state to allow re-initialization (config is intentionally kept).
        self._initialized = False

        logger.info("Plugin manager shutdown complete")
async def invoke_hook(
    self,
    hook_type: str,
    payload: PluginPayload,
    global_context: GlobalContext,
    local_contexts: Optional[PluginContextTable] = None,
    violations_as_exceptions: bool = False,
) -> tuple[PluginResult, PluginContextTable | None]:
    """Invoke a set of plugins configured for the hook point in priority order.

    Looks up the hook references registered for ``hook_type``. It passes them,
    together with the payload and contexts, to the plugin executor and returns
    the executor's result unchanged.

    Args:
        hook_type: The type of hook to execute (e.g. ``"resource_pre_fetch"``).
        payload: The plugin payload for which the plugins will analyze and modify.
        global_context: Shared context for all plugins with request metadata.
        local_contexts: Optional existing contexts from previous hook executions,
            forwarded to the executor.
        violations_as_exceptions: Raise violations as exceptions rather than as returns.

    Returns:
        A tuple containing:
        - PluginResult with processing status and modified payload
        - PluginContextTable with plugin contexts for state management (or None)

    Examples:
        >>> manager = PluginManager("plugins/config.yaml")
        >>> # In async context:
        >>> # await manager.initialize()
        >>> # payload = ResourcePreFetchPayload(uri="file:///data.txt")
        >>> # context = GlobalContext(request_id="123", server_id="srv1")
        >>> # result, contexts = await manager.invoke_hook("resource_pre_fetch", payload, context)
        >>> # if result.continue_processing:
        >>> # # Use modified payload
        >>> # uri = result.modified_payload.uri
    """
    # Get plugins configured for this hook. Ordering and priority handling
    # presumably live in the registry/executor (not visible here).
    hook_refs = self._registry.get_hook_refs_for_hook(hook_type=hook_type)

    # Execute plugins and hand the executor's (result, contexts) tuple straight back.
    return await self._get_executor().execute(
        hook_refs,
        payload,
        global_context,
        hook_type,
        local_contexts,
        violations_as_exceptions,
    )
async def invoke_hook_for_plugin(
    self,
    name: str,
    hook_type: str,
    payload: Union[PluginPayload, dict[str, Any], str],
    context: Union[PluginContext, GlobalContext],
    violations_as_exceptions: bool = False,
    payload_as_json: bool = False,
) -> PluginResult:
    """Run one hook of a single, named plugin directly.

    ``invoke_hook`` runs every plugin registered for a hook point. This method
    instead targets exactly one plugin by name and skips the normal
    priority-ordered chain. That is handy for exercising a plugin in isolation
    or triggering it on its own.

    Args:
        name: Registered name of the plugin to run.
        hook_type: The hook to execute on that plugin (e.g. "prompt_pre_fetch").
        payload: The data handed to the hook. This must be a PluginPayload when
            ``payload_as_json`` is False, or a JSON ``str``/``dict`` when it is True.
        context: A PluginContext, or a GlobalContext that is wrapped in a fresh
            PluginContext before the hook runs.
        violations_as_exceptions: Raise violations as exceptions rather than returns.
        payload_as_json: When True, the plugin's ``json_to_payload`` converts
            ``payload`` into the hook's pydantic payload before execution.

    Returns:
        PluginResult with processing status, modified payload, and metadata.

    Raises:
        PluginError: If no ``hook_type`` hook is registered for plugin ``name``.
        ValueError: If the payload's type does not match the ``payload_as_json`` setting.

    Examples:
        >>> manager = PluginManager("plugins/config.yaml")
        >>> # In async context:
        >>> # await manager.initialize()
        >>> # payload = PromptPrehookPayload(prompt_id="123", name="test", args={})
        >>> # context = PluginContext(global_context=GlobalContext(request_id="123"))
        >>> # result = await manager.invoke_hook_for_plugin(
        >>> #     name="auth_plugin",
        >>> #     hook_type="prompt_pre_fetch",
        >>> #     payload=payload,
        >>> #     context=context
        >>> # )
    """
    # Accept a bare GlobalContext for convenience; the executor works with PluginContext.
    plugin_context = PluginContext(global_context=context) if isinstance(context, GlobalContext) else context

    hook_ref = self._registry.get_plugin_hook_by_name(name, hook_type)
    if not hook_ref:
        missing = PluginErrorModel(
            message=f"Unable to find {hook_type} for plugin {name}. Make sure the plugin is registered.",
            plugin_name=name,
        )
        raise PluginError(error=missing)

    if payload_as_json:
        target_plugin = hook_ref.plugin_ref.plugin
        if not isinstance(payload, (str, dict)):
            raise ValueError(f"When payload_as_json=True, payload must be str or dict, got {type(payload)}")
        # Let the owning plugin decode raw JSON into the model its hook expects.
        hook_payload = target_plugin.json_to_payload(hook_type, payload)
    elif isinstance(payload, PluginPayload):
        hook_payload = payload
    else:
        raise ValueError(f"When payload_as_json=False, payload must be a PluginPayload, got {type(payload)}")

    return await self._get_executor().execute_plugin(hook_ref, hook_payload, plugin_context, violations_as_exceptions)