Coverage for mcpgateway / plugins / framework / hooks / prompts.py: 100%

30 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/plugins/framework/hooks/prompts.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Teryl Taylor, Fred Araujo 

6 

7Pydantic models for prompt plugins. 

8This module implements the pydantic models associated with 

9the base plugin layer including configurations, and contexts. 

10""" 

11 

12# Standard 

13from enum import Enum 

14from typing import Any, Optional 

15 

16# Third-Party 

17from pydantic import Field, field_validator 

18 

19# First-Party 

20from mcpgateway.plugins.framework.models import PluginPayload, PluginResult 

21from mcpgateway.plugins.framework.protocols import PromptResultLike # noqa: F401 # pylint: disable=unused-import 

22from mcpgateway.plugins.framework.utils import coerce_nested 

23 

24 

class PromptHookType(str, Enum):
    """MCP Forge Gateway prompt hook points.

    Only the prompt hooks are defined here. Members also subclass ``str``,
    so each member compares equal to its string value and can be looked up
    from that value.

    Attributes:
        PROMPT_PRE_FETCH: The prompt pre hook.
        PROMPT_POST_FETCH: The prompt post hook.

    Examples:
        >>> PromptHookType.PROMPT_PRE_FETCH
        <PromptHookType.PROMPT_PRE_FETCH: 'prompt_pre_fetch'>
        >>> PromptHookType.PROMPT_PRE_FETCH.value
        'prompt_pre_fetch'
        >>> PromptHookType('prompt_post_fetch')
        <PromptHookType.PROMPT_POST_FETCH: 'prompt_post_fetch'>
        >>> list(PromptHookType)
        [<PromptHookType.PROMPT_PRE_FETCH: 'prompt_pre_fetch'>, <PromptHookType.PROMPT_POST_FETCH: 'prompt_post_fetch'>]
    """

    PROMPT_PRE_FETCH = "prompt_pre_fetch"
    PROMPT_POST_FETCH = "prompt_post_fetch"

49 

50 

class PromptPrehookPayload(PluginPayload):
    """A prompt payload for a prompt prehook.

    Attributes:
        prompt_id (str): The ID of the prompt template.
        args (Optional[dict[str, str]]): The prompt template arguments.
            Defaults to a new empty dict when omitted.

    Examples:
        >>> payload = PromptPrehookPayload(prompt_id="123", args={"user": "alice"})
        >>> payload.prompt_id
        '123'
        >>> payload.args
        {'user': 'alice'}
        >>> payload2 = PromptPrehookPayload(prompt_id="empty")
        >>> payload2.args
        {}
        >>> p = PromptPrehookPayload(prompt_id="123", args={"name": "Bob", "time": "morning"})
        >>> p.prompt_id
        '123'
        >>> p.args["name"]
        'Bob'
    """

    prompt_id: str
    # default_factory gives every instance its own dict instead of a shared one.
    args: Optional[dict[str, str]] = Field(default_factory=dict)

76 

77 

class PromptPosthookPayload(PluginPayload):
    """Payload passed to a prompt posthook.

    Attributes:
        prompt_id (str): The prompt ID.
        result (Any): The prompt result; any object satisfying
            PromptResultLike is accepted.

    Examples:
        >>> from types import SimpleNamespace
        >>> result = SimpleNamespace(messages=[], description=None)
        >>> payload = PromptPosthookPayload(prompt_id="123", result=result)
        >>> payload.prompt_id
        '123'
    """

    prompt_id: str
    result: Any  # Expected to satisfy the PromptResultLike protocol (messages, description attributes)

    @field_validator("result", mode="before")
    @classmethod
    def _coerce_result(cls, v: Any) -> Any:
        """Give dict-shaped results attribute-style access.

        On external server flows the payload is deserialized from JSON, so
        ``result`` shows up as a plain dict. It is passed through
        ``coerce_nested`` to obtain a
        :class:`~mcpgateway.plugins.framework.utils.StructuredData`, which
        lets plugin code such as ``payload.result.messages[0].content.text``
        behave the same way whatever the transport.

        Args:
            v: The raw value supplied for the ``result`` field.

        Returns:
            The coerced value when ``v`` is a dict; otherwise ``v`` unchanged.
        """
        # Only dicts are converted; every other value passes straight through.
        return coerce_nested(v) if isinstance(v, dict) else v

116 

117 

# Result aliases for the two prompt hooks: PluginResult parameterized with the
# matching payload model. _register_prompt_hooks registers each alias together
# with its payload class.
PromptPrehookResult = PluginResult[PromptPrehookPayload]
PromptPosthookResult = PluginResult[PromptPosthookPayload]

120 

121 

def _register_prompt_hooks() -> None:
    """Register the prompt hooks in the global hook registry.

    Each hook type is checked and registered on its own, so the call stays
    idempotent even if only one of the two hooks is already present in the
    registry.

    The registry is imported lazily, inside the function, to avoid circular
    import issues.
    """
    # Import here to avoid circular dependency at module load time
    # First-Party
    from mcpgateway.plugins.framework.hooks.registry import get_hook_registry  # pylint: disable=import-outside-toplevel

    registry = get_hook_registry()

    hooks = (
        (PromptHookType.PROMPT_PRE_FETCH, PromptPrehookPayload, PromptPrehookResult),
        (PromptHookType.PROMPT_POST_FETCH, PromptPosthookPayload, PromptPosthookResult),
    )
    for hook_type, payload_cls, result_cls in hooks:
        # Only register if not already registered (idempotent per hook).
        if not registry.is_registered(hook_type):
            registry.register_hook(hook_type, payload_cls, result_cls)

137 

138 

# Runs at import time, so the prompt hooks are registered as soon as this module is loaded.
_register_prompt_hooks()