Coverage for mcpgateway / services / prompt_service.py: 98%
986 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/prompt_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Prompt Service Implementation.
8This module implements prompt template management according to the MCP specification.
9It handles:
10- Prompt template registration and retrieval
11- Prompt argument validation
12- Template rendering with arguments
13- Resource embedding in prompts
14- Active/inactive prompt management
15"""
17# Standard
18import binascii
19from datetime import datetime, timezone
20from functools import lru_cache
21from string import Formatter
22import time
23from typing import Any, AsyncGenerator, Dict, List, Optional, Set, Union
24import uuid
26# Third-Party
27from jinja2 import Environment, meta, select_autoescape, Template
28from mcp import ClientSession
29from mcp.client.sse import sse_client
30from mcp.client.streamable_http import streamablehttp_client
31import orjson
32from pydantic import ValidationError
33from sqlalchemy import and_, delete, desc, not_, or_, select
34from sqlalchemy.exc import IntegrityError, MultipleResultsFound, OperationalError
35from sqlalchemy.orm import joinedload, selectinload, Session
37# First-Party
38from mcpgateway.common.models import Message, PromptResult, Role, TextContent
39from mcpgateway.config import settings
40from mcpgateway.db import EmailTeam
41from mcpgateway.db import EmailTeamMember as DbEmailTeamMember
42from mcpgateway.db import Gateway as DbGateway
43from mcpgateway.db import get_for_update
44from mcpgateway.db import Prompt as DbPrompt
45from mcpgateway.db import PromptMetric, PromptMetricsHourly, server_prompt_association
46from mcpgateway.observability import create_span, set_span_attribute, set_span_error
47from mcpgateway.plugins.framework import get_plugin_manager, GlobalContext, PluginContextTable, PluginManager, PromptHookType, PromptPosthookPayload, PromptPrehookPayload
48from mcpgateway.schemas import PromptCreate, PromptMetrics, PromptRead, PromptUpdate, TopPerformer
49from mcpgateway.services.audit_trail_service import get_audit_trail_service
50from mcpgateway.services.base_service import BaseService
51from mcpgateway.services.content_security import ContentSizeError, get_content_security_service
52from mcpgateway.services.event_service import EventService
53from mcpgateway.services.logging_service import LoggingService
54from mcpgateway.services.mcp_session_pool import get_mcp_session_pool, TransportType
55from mcpgateway.services.metrics_buffer_service import get_metrics_buffer_service
56from mcpgateway.services.metrics_cleanup_service import delete_metrics_in_batches, pause_rollup_during_purge
57from mcpgateway.services.observability_service import current_trace_id, ObservabilityService
58from mcpgateway.services.structured_logger import get_structured_logger
59from mcpgateway.services.team_management_service import TeamManagementService
60from mcpgateway.utils.create_slug import slugify
61from mcpgateway.utils.gateway_access import build_gateway_auth_headers
62from mcpgateway.utils.metrics_common import build_top_performers
63from mcpgateway.utils.pagination import unified_paginate
64from mcpgateway.utils.services_auth import decode_auth
65from mcpgateway.utils.sqlalchemy_modifier import json_contains_tag_expr
66from mcpgateway.utils.trace_context import format_trace_team_scope
67from mcpgateway.utils.trace_redaction import is_input_capture_enabled, is_output_capture_enabled, serialize_trace_payload
68from mcpgateway.utils.url_auth import apply_query_param_auth, sanitize_exception_message
# Shared Jinja environment, created on first use by _get_jinja_env().
_JINJA_ENV: Optional[Environment] = None

# Registry cache handle, resolved lazily to avoid a circular import at module load.
_REGISTRY_CACHE = None
def _get_jinja_env() -> Environment:
    """Return the module-wide Jinja environment, building it on the first call.

    The environment autoescapes via ``select_autoescape(["html", "xml"])`` and
    enables ``trim_blocks`` / ``lstrip_blocks`` for whitespace control.

    Returns:
        Jinja2 Environment with autoescape and trim settings.
    """
    global _JINJA_ENV  # pylint: disable=global-statement
    if _JINJA_ENV is not None:
        return _JINJA_ENV
    _JINJA_ENV = Environment(autoescape=select_autoescape(["html", "xml"]), trim_blocks=True, lstrip_blocks=True)
    return _JINJA_ENV
@lru_cache(maxsize=256)
def _compile_jinja_template(template: str) -> Template:
    """Compile ``template`` with the shared environment, memoizing by template text.

    Up to 256 distinct template strings are kept (LRU eviction beyond that).

    Args:
        template: The template string to compile.

    Returns:
        Compiled Jinja Template object.
    """
    env = _get_jinja_env()
    return env.from_string(template)
def _get_registry_cache():
    """Return the registry cache singleton, importing it on first access.

    The import is deferred to call time to avoid a circular import between
    this module and ``mcpgateway.cache.registry_cache``.

    Returns:
        RegistryCache instance.
    """
    global _REGISTRY_CACHE  # pylint: disable=global-statement
    if _REGISTRY_CACHE is not None:
        return _REGISTRY_CACHE

    # First-Party
    from mcpgateway.cache.registry_cache import registry_cache  # pylint: disable=import-outside-toplevel

    _REGISTRY_CACHE = registry_cache
    return _REGISTRY_CACHE
# The logging service is created before anything else that may log.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

# Telemetry helpers used by prompt operations: structured logger, audit trail, metrics buffer.
structured_logger = get_structured_logger("prompt_service")
audit_trail = get_audit_trail_service()
metrics_buffer = get_metrics_buffer_service()
class PromptError(Exception):
    """Root exception for prompt-service failures.

    Every more specific prompt error in this module derives from this class,
    so catching it handles them all.
    """
class PromptNotFoundError(PromptError):
    """Signals that the requested prompt does not exist."""
class PromptNameConflictError(PromptError):
    """Raised when a prompt name conflicts with existing (active or inactive) prompt."""

    def __init__(self, name: str, enabled: bool = True, prompt_id: Optional[int] = None, visibility: Optional[str] = "public") -> None:
        """Initialize the error with prompt information.

        Args:
            name: The conflicting prompt name
            enabled: Whether the existing prompt is enabled
            prompt_id: ID of the existing prompt if available
            visibility: Prompt visibility level (private, team, public). A missing
                (``None``/empty) value falls back to ``"public"`` so that building the
                error never fails and masks the conflict itself.

        Examples:
            >>> from mcpgateway.services.prompt_service import PromptNameConflictError
            >>> error = PromptNameConflictError("test_prompt")
            >>> error.name
            'test_prompt'
            >>> error.enabled
            True
            >>> error.prompt_id is None
            True
            >>> error = PromptNameConflictError("inactive_prompt", False, 123)
            >>> error.enabled
            False
            >>> error.prompt_id
            123
            >>> PromptNameConflictError("p", visibility="team").visibility
            'team'
            >>> str(PromptNameConflictError("p", visibility=None))
            'Public Prompt already exists with name: p'
        """
        self.name = name
        self.enabled = enabled
        self.prompt_id = prompt_id
        # Callers forward the visibility of an existing DB row; calling .capitalize()
        # on None would raise AttributeError and hide the conflict error.
        self.visibility = visibility or "public"
        message = f"{self.visibility.capitalize()} Prompt already exists with name: {name}"
        if not enabled:
            message += f" (currently inactive, ID: {prompt_id})"
        super().__init__(message)
class PromptValidationError(PromptError):
    """Signals that a prompt definition or its inputs failed validation."""
class PromptArgumentsJSONError(PromptError):
    """Signals that the prompt ``arguments`` payload is not valid JSON (or not a JSON array).

    Attributes:
        field_name: Name of the field that held the bad payload.
        json_error: Parser (or type-check) error message.
        raw_value: At most the first 200 characters of the offending payload.
        context: Optional caller-supplied location hint; empty when not given.
    """

    def __init__(self, field_name: str, json_error: str, raw_value: str = "", context: str = "") -> None:
        """Record the parsing failure and build a human-readable message.

        Args:
            field_name: Name of the field (e.g., "arguments")
            json_error: The JSON parsing error message
            raw_value: The invalid payload; truncated to 200 characters for logging
            context: Optional context about where the error occurred (e.g., "prompt abc-123")

        Examples:
            >>> from mcpgateway.services.prompt_service import PromptArgumentsJSONError
            >>> error = PromptArgumentsJSONError("arguments", "unexpected character: line 1 column 5")
            >>> error.field_name
            'arguments'
            >>> str(error)
            'Invalid JSON in arguments field: unexpected character: line 1 column 5'
        """
        self.field_name = field_name
        self.json_error = json_error
        self.raw_value = "" if not raw_value else raw_value[:200]
        self.context = context
        location = "" if not context else f" for {context}"
        super().__init__(f"Invalid JSON in {field_name} field{location}: {json_error}")
class PromptLockConflictError(PromptError):
    """Signals that a prompt row is locked by another transaction.

    Used when a modification is attempted on a prompt that a concurrent
    request currently holds locked.
    """
def _validate_prompt_team_assignment(db: Session, user_email: Optional[str], target_team_id: Optional[str]) -> None:
    """Check that a prompt may be assigned to ``target_team_id``.

    The team must be given and must exist. When ``user_email`` is provided, that
    user must also hold an active ``owner`` membership in the team.

    Args:
        db: Database session used for the team and membership lookups.
        user_email: Requesting user email; the ownership check is skipped when falsy.
        target_team_id: Team identifier to validate.

    Raises:
        ValueError: If no team id is given, the team does not exist, or the user
            is not an active owner of the team.
    """
    if not target_team_id:
        raise ValueError("Cannot set visibility to 'team' without a team_id")

    matching_team = db.query(EmailTeam).filter(EmailTeam.id == target_team_id).first()
    if not matching_team:
        raise ValueError(f"Team {target_team_id} not found")

    if user_email:
        owner_row = (
            db.query(DbEmailTeamMember)
            .filter(
                DbEmailTeamMember.team_id == target_team_id,
                DbEmailTeamMember.user_email == user_email,
                DbEmailTeamMember.is_active,
                DbEmailTeamMember.role == "owner",
            )
            .first()
        )
        if not owner_row:
            raise ValueError("User membership in team not sufficient for this update.")
253class PromptService(BaseService):
254 """Service for managing prompt templates.
256 Handles:
257 - Template registration and retrieval
258 - Argument validation
259 - Template rendering
260 - Resource embedding
261 - Active/inactive status management
262 """
264 _visibility_model_cls = DbPrompt
266 def __init__(self) -> None:
267 """
268 Initialize the prompt service.
270 Sets up the Jinja2 environment for rendering prompt templates.
271 Although these templates are rendered as JSON for the API, if the output is ever
272 embedded into an HTML page, unescaped content could be exploited for cross-site scripting (XSS) attacks.
273 Enabling autoescaping for 'html' and 'xml' templates via select_autoescape helps mitigate this risk.
275 Examples:
276 >>> from mcpgateway.services.prompt_service import PromptService
277 >>> service = PromptService()
278 >>> isinstance(service._event_service, EventService)
279 True
280 >>> service._jinja_env is not None
281 True
282 """
283 self._event_service = EventService(channel_name="mcpgateway:prompt_events")
284 # Use the module-level singleton for template caching
285 self._jinja_env = _get_jinja_env()
286 self._plugin_manager: PluginManager | None = get_plugin_manager()
288 @staticmethod
289 def _should_fetch_gateway_prompt(prompt: DbPrompt) -> bool:
290 """Return whether a prompt must be executed against its source gateway.
292 Federated prompts are synced into the catalog as metadata via
293 ``list_prompts()``. Those records often have ``template=""``, which
294 means the gateway must call the upstream MCP ``prompts/get`` endpoint
295 instead of trying to render a local template.
297 Args:
298 prompt: Prompt ORM object resolved from the catalog.
300 Returns:
301 ``True`` when the prompt is gateway-backed and has no local template.
302 """
303 return bool(getattr(prompt, "gateway_id", None)) and not bool(getattr(prompt, "template", ""))
305 async def _fetch_gateway_prompt_result(self, prompt: DbPrompt, arguments: Optional[Dict[str, str]], user_identity: Optional[str]) -> PromptResult:
306 """Fetch a rendered prompt from the upstream MCP gateway.
308 Args:
309 prompt: Gateway-backed prompt record from the catalog.
310 arguments: Optional prompt-rendering arguments.
311 user_identity: Effective requester email for session-pool isolation.
313 Returns:
314 Prompt result normalized into ContextForge models.
316 Raises:
317 PromptError: If the gateway prompt cannot be fetched.
318 """
319 gateway = getattr(prompt, "gateway", None)
320 if gateway is None:
321 raise PromptError(f"Prompt '{prompt.name}' is gateway-backed but missing gateway metadata")
323 gateway_url = str(gateway.url)
324 headers = build_gateway_auth_headers(gateway)
325 auth_query_params_decrypted: Optional[Dict[str, str]] = None
327 if getattr(gateway, "auth_type", None) == "query_param" and getattr(gateway, "auth_query_params", None):
328 auth_query_params_decrypted = {}
329 for param_key, encrypted_value in (gateway.auth_query_params or {}).items():
330 try:
331 decoded = decode_auth(encrypted_value)
332 auth_query_params_decrypted[param_key] = decoded.get(param_key, "")
333 except Exception as exc:
334 raise PromptError(f"Failed to decode query-parameter auth for prompt gateway '{gateway.id}'") from exc
335 if auth_query_params_decrypted:
336 gateway_url = apply_query_param_auth(gateway_url, auth_query_params_decrypted)
338 remote_name = getattr(prompt, "original_name", None) or prompt.name
339 pool_user_identity = (user_identity or "anonymous").strip() or "anonymous"
340 gateway_id = str(getattr(gateway, "id", ""))
341 transport = str(getattr(gateway, "transport", "streamable_http") or "streamable_http").lower()
342 pool_transport_type = TransportType.SSE if transport == "sse" else TransportType.STREAMABLE_HTTP
343 prompt_arguments = arguments or None
345 try:
346 if settings.mcp_session_pool_enabled:
347 try:
348 pool = get_mcp_session_pool()
349 except RuntimeError:
350 pool = None
351 if pool is not None:
352 async with pool.session(
353 url=gateway_url,
354 headers=headers,
355 transport_type=pool_transport_type,
356 user_identity=pool_user_identity,
357 gateway_id=gateway_id,
358 ) as pooled:
359 remote_result = await pooled.session.get_prompt(remote_name, arguments=prompt_arguments)
360 return PromptResult(
361 messages=[
362 Message.model_validate(message.model_dump(by_alias=True, exclude_none=True) if hasattr(message, "model_dump") else message)
363 for message in getattr(remote_result, "messages", []) or []
364 ],
365 description=getattr(remote_result, "description", None) or prompt.description,
366 )
368 if transport == "sse":
369 async with sse_client(url=gateway_url, headers=headers, timeout=settings.health_check_timeout) as streams:
370 async with ClientSession(*streams) as session:
371 await session.initialize()
372 remote_result = await session.get_prompt(remote_name, arguments=prompt_arguments)
373 else:
374 async with streamablehttp_client(url=gateway_url, headers=headers, timeout=settings.health_check_timeout) as (read_stream, write_stream, _get_session_id):
375 async with ClientSession(read_stream, write_stream) as session:
376 await session.initialize()
377 remote_result = await session.get_prompt(remote_name, arguments=prompt_arguments)
379 return PromptResult(
380 messages=[
381 Message.model_validate(message.model_dump(by_alias=True, exclude_none=True) if hasattr(message, "model_dump") else message)
382 for message in getattr(remote_result, "messages", []) or []
383 ],
384 description=getattr(remote_result, "description", None) or prompt.description,
385 )
386 except Exception as exc:
387 sanitized_error = sanitize_exception_message(str(exc), auth_query_params_decrypted)
388 raise PromptError(f"Failed to fetch prompt '{remote_name}' from gateway: {sanitized_error}") from exc
390 @staticmethod
391 def validate_arguments_json(args_value: Any, context: str = "") -> List[Dict[str, Any]]:
392 """Validate and parse prompt arguments JSON.
394 Args:
395 args_value: The raw arguments value from form data
396 context: Additional context for error messages (e.g., "prompt 123", "new prompt")
398 Returns:
399 Parsed arguments as a list of dictionaries
401 Raises:
402 PromptArgumentsJSONError: If the JSON is invalid
404 Examples:
405 >>> from mcpgateway.services.prompt_service import PromptService
406 >>> PromptService.validate_arguments_json('[]')
407 []
408 >>> PromptService.validate_arguments_json('[{"name": "test"}]')
409 [{'name': 'test'}]
410 >>> PromptService.validate_arguments_json(None)
411 []
412 >>> try:
413 ... PromptService.validate_arguments_json('invalid json')
414 ... except PromptArgumentsJSONError as e:
415 ... print(e.field_name)
416 arguments
417 """
418 # Handle None or empty values
419 if args_value is None or args_value == "":
420 return []
422 # Ensure it's a string
423 if not isinstance(args_value, str):
424 args_value = str(args_value)
426 # Strip whitespace
427 args_value = args_value.strip()
429 # If still empty after strip, return empty list
430 if not args_value:
431 return []
433 # Parse JSON
434 try:
435 arguments = orjson.loads(args_value)
437 # Ensure the result is a list (JSON array)
438 if not isinstance(arguments, list):
439 error_msg = f"Arguments must be a JSON array, got {type(arguments).__name__}"
440 logger.error(f"Invalid arguments type{' for ' + context if context else ''}: {error_msg}. " f"Raw value (first 200 chars): {args_value[:200]!r}")
441 raise PromptArgumentsJSONError(field_name="arguments", json_error=error_msg, raw_value=args_value, context=context)
443 return arguments
444 except orjson.JSONDecodeError as json_err:
445 # Log the error with context
446 logger.error(f"Invalid JSON in arguments field{' for ' + context if context else ''}: {json_err}. " f"Raw value (first 200 chars): {args_value[:200]!r}")
447 # Raise custom exception
448 raise PromptArgumentsJSONError(field_name="arguments", json_error=str(json_err), raw_value=args_value, context=context) from json_err
450 async def initialize(self) -> None:
451 """Initialize the service."""
452 logger.info("Initializing prompt service")
453 await self._event_service.initialize()
455 async def shutdown(self) -> None:
456 """Shutdown the service.
458 Examples:
459 >>> from mcpgateway.services.prompt_service import PromptService
460 >>> from unittest.mock import AsyncMock
461 >>> import asyncio
462 >>> service = PromptService()
463 >>> service._event_service = AsyncMock()
464 >>> asyncio.run(service.shutdown())
465 >>> # Verify event service shutdown was called
466 >>> service._event_service.shutdown.assert_awaited_once()
467 """
468 await self._event_service.shutdown()
469 logger.info("Prompt service shutdown complete")
471 async def get_top_prompts(self, db: Session, limit: Optional[int] = 5, include_deleted: bool = False) -> List[TopPerformer]:
472 """Retrieve the top-performing prompts based on execution count.
474 Queries the database to get prompts with their metrics, ordered by the number of executions
475 in descending order. Combines recent raw metrics with historical hourly rollups for complete
476 historical coverage. Returns a list of TopPerformer objects containing prompt details and
477 performance metrics. Results are cached for performance.
479 Args:
480 db (Session): Database session for querying prompt metrics.
481 limit (Optional[int]): Maximum number of prompts to return. Defaults to 5.
482 include_deleted (bool): Whether to include deleted prompts from rollups.
484 Returns:
485 List[TopPerformer]: A list of TopPerformer objects, each containing:
486 - id: Prompt ID.
487 - name: Prompt name.
488 - execution_count: Total number of executions.
489 - avg_response_time: Average response time in seconds, or None if no metrics.
490 - success_rate: Success rate percentage, or None if no metrics.
491 - last_execution: Timestamp of the last execution, or None if no metrics.
492 """
493 # Check cache first (if enabled)
494 # First-Party
495 from mcpgateway.cache.metrics_cache import is_cache_enabled, metrics_cache # pylint: disable=import-outside-toplevel
497 effective_limit = limit or 5
498 cache_key = f"top_prompts:{effective_limit}:include_deleted={include_deleted}"
500 if is_cache_enabled():
501 cached = metrics_cache.get(cache_key)
502 if cached is not None:
503 return cached
505 # Use combined query that includes both raw metrics and rollup data
506 # First-Party
507 from mcpgateway.services.metrics_query_service import get_top_performers_combined # pylint: disable=import-outside-toplevel
509 results = get_top_performers_combined(
510 db=db,
511 metric_type="prompt",
512 entity_model=DbPrompt,
513 limit=effective_limit,
514 include_deleted=include_deleted,
515 )
516 top_performers = build_top_performers(results)
518 # Cache the result (if enabled)
519 if is_cache_enabled():
520 metrics_cache.set(cache_key, top_performers)
522 return top_performers
524 def convert_prompt_to_read(self, db_prompt: DbPrompt, include_metrics: bool = False) -> PromptRead:
525 """
526 Convert a DbPrompt instance to a PromptRead Pydantic model,
527 optionally including aggregated metrics computed from the associated PromptMetric records.
529 Args:
530 db_prompt: Db prompt to convert
531 include_metrics: Whether to include metrics in the result. Defaults to False.
532 Set to False for list operations to avoid N+1 query issues.
534 Returns:
535 PromptRead: Pydantic model instance
536 """
537 arg_schema = db_prompt.argument_schema or {}
538 properties = arg_schema.get("properties", {})
539 required_list = arg_schema.get("required", [])
540 arguments_list = []
541 for arg_name, prop in properties.items():
542 arguments_list.append(
543 {
544 "name": arg_name,
545 "description": prop.get("description") or "",
546 "required": arg_name in required_list,
547 }
548 )
550 # Compute aggregated metrics only if requested (avoids N+1 queries in list operations)
551 if include_metrics:
552 # Use metrics_summary which combines raw + hourly rollup data (matches tool_service pattern)
553 metrics = db_prompt.metrics_summary
554 metrics_dict = {
555 "totalExecutions": metrics["total_executions"],
556 "successfulExecutions": metrics["successful_executions"],
557 "failedExecutions": metrics["failed_executions"],
558 "failureRate": metrics["failure_rate"],
559 "minResponseTime": metrics["min_response_time"],
560 "maxResponseTime": metrics["max_response_time"],
561 "avgResponseTime": metrics["avg_response_time"],
562 "lastExecutionTime": metrics["last_execution_time"],
563 }
564 else:
565 metrics_dict = None
567 original_name = getattr(db_prompt, "original_name", None) or db_prompt.name
568 custom_name = getattr(db_prompt, "custom_name", None) or original_name
569 custom_name_slug = getattr(db_prompt, "custom_name_slug", None) or slugify(custom_name)
570 display_name = getattr(db_prompt, "display_name", None) or custom_name
572 prompt_dict = {
573 "id": db_prompt.id,
574 "name": db_prompt.name,
575 "original_name": original_name,
576 "custom_name": custom_name,
577 "custom_name_slug": custom_name_slug,
578 "display_name": display_name,
579 "gateway_id": getattr(db_prompt, "gateway_id", None),
580 "gateway_slug": getattr(db_prompt, "gateway_slug", None),
581 "description": db_prompt.description,
582 "template": db_prompt.template,
583 "arguments": arguments_list,
584 "created_at": db_prompt.created_at,
585 "updated_at": db_prompt.updated_at,
586 "enabled": db_prompt.enabled,
587 "metrics": metrics_dict,
588 "tags": db_prompt.tags or [],
589 "visibility": db_prompt.visibility,
590 "team": getattr(db_prompt, "team", None),
591 # Include metadata fields for proper API response
592 "created_by": getattr(db_prompt, "created_by", None),
593 "modified_by": getattr(db_prompt, "modified_by", None),
594 "created_from_ip": getattr(db_prompt, "created_from_ip", None),
595 "created_via": getattr(db_prompt, "created_via", None),
596 "created_user_agent": getattr(db_prompt, "created_user_agent", None),
597 "modified_from_ip": getattr(db_prompt, "modified_from_ip", None),
598 "modified_via": getattr(db_prompt, "modified_via", None),
599 "modified_user_agent": getattr(db_prompt, "modified_user_agent", None),
600 "version": getattr(db_prompt, "version", None),
601 "team_id": getattr(db_prompt, "team_id", None),
602 "owner_email": getattr(db_prompt, "owner_email", None),
603 }
604 return PromptRead.model_validate(prompt_dict)
606 def _get_team_name(self, db: Session, team_id: Optional[str]) -> Optional[str]:
607 """Retrieve the team name given a team ID.
609 Args:
610 db (Session): Database session for querying teams.
611 team_id (Optional[str]): The ID of the team.
613 Returns:
614 Optional[str]: The name of the team if found, otherwise None.
615 """
616 if not team_id:
617 return None
618 team = db.query(EmailTeam).filter(EmailTeam.id == team_id, EmailTeam.is_active.is_(True)).first()
619 db.commit() # Release transaction to avoid idle-in-transaction
620 return team.name if team else None
622 def _compute_prompt_name(self, custom_name: str, gateway: Optional[Any] = None) -> str:
623 """Compute the stored prompt name from custom_name and gateway context.
625 Args:
626 custom_name: Prompt name to slugify and store.
627 gateway: Optional gateway for namespacing.
629 Returns:
630 The stored prompt name with gateway prefix when applicable.
631 """
632 name_slug = slugify(custom_name)
633 if gateway:
634 gateway_slug = slugify(gateway.name)
635 return f"{gateway_slug}{settings.gateway_tool_name_separator}{name_slug}"
636 return name_slug
638 async def register_prompt(
639 self,
640 db: Session,
641 prompt: PromptCreate,
642 created_by: Optional[str] = None,
643 created_from_ip: Optional[str] = None,
644 created_via: Optional[str] = None,
645 created_user_agent: Optional[str] = None,
646 import_batch_id: Optional[str] = None,
647 federation_source: Optional[str] = None,
648 team_id: Optional[str] = None,
649 owner_email: Optional[str] = None,
650 visibility: Optional[str] = "public",
651 ) -> PromptRead:
652 """Register a new prompt template.
654 Args:
655 db: Database session
656 prompt: Prompt creation schema
657 created_by: Username who created this prompt
658 created_from_ip: IP address of creator
659 created_via: Creation method (ui, api, import, federation)
660 created_user_agent: User agent of creation request
661 import_batch_id: UUID for bulk import operations
662 federation_source: Source gateway for federated prompts
663 team_id (Optional[str]): Team ID to assign the prompt to.
664 owner_email (Optional[str]): Email of the user who owns this prompt.
665 visibility (str): Prompt visibility level (private, team, public).
667 Returns:
668 Created prompt information
670 Raises:
671 IntegrityError: If a database integrity error occurs.
672 PromptNameConflictError: If a prompt with the same name already exists.
673 PromptError: For other prompt registration errors
674 ContentSizeError: For template size exceed
676 Examples:
677 >>> import logging
678 >>> logging.disable(logging.CRITICAL)
679 >>> from mcpgateway.services.prompt_service import PromptService
680 >>> from unittest.mock import AsyncMock, MagicMock
681 >>> service = PromptService()
682 >>> db = MagicMock()
683 >>> prompt = MagicMock()
684 >>> prompt.template = "Hello {{ name }}"
685 >>> prompt.name = "test-prompt"
686 >>> prompt.custom_name = None
687 >>> prompt.display_name = None
688 >>> prompt.arguments = []
689 >>> db.execute.return_value.scalar_one_or_none.return_value = None
690 >>> db.add = MagicMock()
691 >>> db.commit = MagicMock()
692 >>> db.refresh = MagicMock()
693 >>> service._notify_prompt_added = AsyncMock()
694 >>> service.convert_prompt_to_read = MagicMock(return_value={})
695 >>> import asyncio
696 >>> try:
697 ... asyncio.run(service.register_prompt(db, prompt))
698 ... except Exception:
699 ... pass
700 >>> logging.disable(logging.NOTSET)
701 """
702 try:
703 content_security = get_content_security_service()
704 content_security.validate_prompt_size(
705 template=prompt.template,
706 name=prompt.name,
707 user_email=created_by or owner_email,
708 ip_address=created_from_ip,
709 )
711 # Validate template syntax
712 self._validate_template(prompt.template)
714 # Extract required arguments from template
715 required_args = self._get_required_arguments(prompt.template)
717 # Create argument schema
718 argument_schema = {
719 "type": "object",
720 "properties": {},
721 "required": list(required_args),
722 }
723 for arg in prompt.arguments:
724 schema = {"type": "string"}
725 if arg.description is not None:
726 schema["description"] = arg.description
727 argument_schema["properties"][arg.name] = schema
729 custom_name = prompt.custom_name or prompt.name
730 display_name = prompt.display_name or custom_name
732 # Extract gateway_id from prompt if present and look up gateway for namespacing
733 gateway_id = getattr(prompt, "gateway_id", None)
734 gateway = None
735 if gateway_id:
736 gateway = db.execute(select(DbGateway).where(DbGateway.id == gateway_id)).scalar_one_or_none()
738 computed_name = self._compute_prompt_name(custom_name, gateway=gateway)
740 # Create DB model
741 db_prompt = DbPrompt(
742 name=computed_name,
743 original_name=prompt.name,
744 custom_name=custom_name,
745 display_name=display_name,
746 title=prompt.title,
747 description=prompt.description,
748 template=prompt.template,
749 argument_schema=argument_schema,
750 tags=prompt.tags,
751 # Metadata fields
752 created_by=created_by,
753 created_from_ip=created_from_ip,
754 created_via=created_via,
755 created_user_agent=created_user_agent,
756 import_batch_id=import_batch_id,
757 federation_source=federation_source,
758 version=1,
759 # Team scoping fields - use schema values if provided, otherwise fallback to parameters
760 team_id=getattr(prompt, "team_id", None) or team_id,
761 owner_email=getattr(prompt, "owner_email", None) or owner_email or created_by,
762 visibility=getattr(prompt, "visibility", None) or visibility,
763 gateway_id=gateway_id,
764 )
765 # Check for existing server with the same name
766 if visibility.lower() == "public":
767 # Check for existing public prompt with the same name and gateway_id
768 existing_prompt = db.execute(select(DbPrompt).where(DbPrompt.name == computed_name, DbPrompt.visibility == "public", DbPrompt.gateway_id == gateway_id)).scalar_one_or_none()
769 if existing_prompt:
770 raise PromptNameConflictError(computed_name, enabled=existing_prompt.enabled, prompt_id=existing_prompt.id, visibility=existing_prompt.visibility)
771 elif visibility.lower() == "team":
772 # Check for existing team prompt with the same name and gateway_id
773 existing_prompt = db.execute(
774 select(DbPrompt).where(DbPrompt.name == computed_name, DbPrompt.visibility == "team", DbPrompt.team_id == team_id, DbPrompt.gateway_id == gateway_id)
775 ).scalar_one_or_none()
776 if existing_prompt:
777 raise PromptNameConflictError(computed_name, enabled=existing_prompt.enabled, prompt_id=existing_prompt.id, visibility=existing_prompt.visibility)
779 # Set gateway relationship to help the before_insert event handler compute the name correctly
780 if gateway:
781 db_prompt.gateway = gateway
782 db_prompt.gateway_name_cache = gateway.name # type: ignore[attr-defined]
784 # Add to DB
785 db.add(db_prompt)
786 db.commit()
787 db.refresh(db_prompt)
788 # Notify subscribers
789 await self._notify_prompt_added(db_prompt)
791 logger.info(f"Registered prompt: {prompt.name}")
793 # Structured logging: Audit trail for prompt creation
794 audit_trail.log_action(
795 user_id=created_by or "system",
796 action="create_prompt",
797 resource_type="prompt",
798 resource_id=str(db_prompt.id),
799 resource_name=db_prompt.name,
800 user_email=owner_email,
801 team_id=team_id,
802 client_ip=created_from_ip,
803 user_agent=created_user_agent,
804 new_values={
805 "name": db_prompt.name,
806 "visibility": visibility,
807 },
808 context={
809 "created_via": created_via,
810 "import_batch_id": import_batch_id,
811 "federation_source": federation_source,
812 },
813 db=db,
814 )
816 # Structured logging: Log successful prompt creation
817 structured_logger.log(
818 level="INFO",
819 message="Prompt created successfully",
820 event_type="prompt_created",
821 component="prompt_service",
822 user_id=created_by,
823 user_email=owner_email,
824 team_id=team_id,
825 resource_type="prompt",
826 resource_id=str(db_prompt.id),
827 custom_fields={
828 "prompt_name": db_prompt.name,
829 "visibility": visibility,
830 },
831 )
833 db_prompt.team = self._get_team_name(db, db_prompt.team_id)
834 prompt_dict = self.convert_prompt_to_read(db_prompt)
836 # Invalidate cache after successful creation
837 cache = _get_registry_cache()
838 await cache.invalidate_prompts()
839 # Also invalidate tags cache since prompt tags may have changed
840 # First-Party
841 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel
843 await admin_stats_cache.invalidate_tags()
844 # First-Party
845 from mcpgateway.cache.metrics_cache import metrics_cache # pylint: disable=import-outside-toplevel
847 metrics_cache.invalidate_prefix("top_prompts:")
848 metrics_cache.invalidate("prompts")
850 return PromptRead.model_validate(prompt_dict)
852 except IntegrityError as ie:
853 logger.error(f"IntegrityErrors in group: {ie}")
855 structured_logger.log(
856 level="ERROR",
857 message="Prompt creation failed due to database integrity error",
858 event_type="prompt_creation_failed",
859 component="prompt_service",
860 user_id=created_by,
861 user_email=owner_email,
862 error=ie,
863 custom_fields={"prompt_name": prompt.name},
864 )
865 raise ie
866 except PromptNameConflictError as se:
867 db.rollback()
869 structured_logger.log(
870 level="WARNING",
871 message="Prompt creation failed due to name conflict",
872 event_type="prompt_name_conflict",
873 component="prompt_service",
874 user_id=created_by,
875 user_email=owner_email,
876 custom_fields={"prompt_name": prompt.name, "visibility": visibility},
877 )
878 raise se
879 except ContentSizeError as cse:
880 db.rollback()
882 structured_logger.log(
883 level="ERROR",
884 message=f"Prompt template size limit exceeded: {cse.actual_size} bytes (max: {cse.max_size} bytes)",
885 event_type="prompt_size_exceed",
886 component="prompt_service",
887 user_id=created_by,
888 user_email=owner_email,
889 custom_fields={"prompt_name": prompt.name, "visibility": visibility},
890 )
891 raise cse
892 except Exception as e:
893 db.rollback()
895 structured_logger.log(
896 level="ERROR",
897 message="Prompt creation failed",
898 event_type="prompt_creation_failed",
899 component="prompt_service",
900 user_id=created_by,
901 user_email=owner_email,
902 error=e,
903 custom_fields={"prompt_name": prompt.name},
904 )
905 raise PromptError(f"Failed to register prompt: {str(e)}")
async def register_prompts_bulk(
    self,
    db: Session,
    prompts: List[PromptCreate],
    created_by: Optional[str] = None,
    created_from_ip: Optional[str] = None,
    created_via: Optional[str] = None,
    created_user_agent: Optional[str] = None,
    import_batch_id: Optional[str] = None,
    federation_source: Optional[str] = None,
    team_id: Optional[str] = None,
    owner_email: Optional[str] = None,
    visibility: Optional[str] = "public",
    conflict_strategy: str = "skip",
) -> Dict[str, Any]:
    """Register multiple prompts in bulk, committing once per chunk.

    Compared with registering prompts one at a time, this method:
    - Uses db.add_all() instead of individual db.add() calls
    - Performs a single commit per chunk of up to 500 prompts
    - Detects name conflicts with one query per chunk

    Every prompt is counted exactly once in the returned statistics. Per-prompt
    validation failures are recorded and do not abort the chunk. If a chunk
    cannot be committed, it is rolled back and all of its prompts are counted
    as failed; counts gathered for that chunk before the failure are discarded.

    Post-commit work (refresh, subscriber notification, audit trail) is
    best-effort. A failure there is logged but does not mark the already
    committed prompts as failed. When anything was created or updated, the
    registry, tag and metrics caches are invalidated once at the end.

    Args:
        db: Database session
        prompts: List of prompt creation schemas
        created_by: Username who created these prompts
        created_from_ip: IP address of creator
        created_via: Creation method (ui, api, import, federation)
        created_user_agent: User agent of creation request
        import_batch_id: UUID for bulk import operations
        federation_source: Source gateway for federated prompts
        team_id: Team ID to assign the prompts to
        owner_email: Email of the user who owns these prompts
        visibility: Prompt visibility level (private, team, public). When None,
            each prompt's own visibility is used (default "public"), and
            conflict detection is scoped as "public".
        conflict_strategy: How to handle conflicts (skip, update, rename, fail).
            Any other value marks conflicting prompts as failed.

    Returns:
        Dict with statistics:
        - created: Number of prompts created
        - updated: Number of prompts updated
        - skipped: Number of prompts skipped
        - failed: Number of prompts that failed
        - errors: List of error messages

    Examples:
        >>> import logging
        >>> logging.disable(logging.CRITICAL)
        >>> from mcpgateway.services.prompt_service import PromptService
        >>> from unittest.mock import MagicMock
        >>> service = PromptService()
        >>> db = MagicMock()
        >>> p1 = MagicMock()
        >>> p1.name = "prompt-1"
        >>> p1.template = "Hello"
        >>> p1.custom_name = None
        >>> p1.display_name = None
        >>> p1.arguments = []
        >>> p2 = MagicMock()
        >>> p2.name = "prompt-2"
        >>> p2.template = "World"
        >>> p2.custom_name = None
        >>> p2.display_name = None
        >>> p2.arguments = []
        >>> prompts = [p1, p2]
        >>> import asyncio
        >>> try:
        ...     result = asyncio.run(service.register_prompts_bulk(db, prompts))
        ... except Exception:
        ...     pass
        >>> logging.disable(logging.NOTSET)
    """
    if not prompts:
        return {"created": 0, "updated": 0, "skipped": 0, "failed": 0, "errors": []}

    stats: Dict[str, Any] = {"created": 0, "updated": 0, "skipped": 0, "failed": 0, "errors": []}

    # Conflict detection is scoped by the batch-level visibility. ``visibility`` may be None
    # (each prompt then falls back to its own visibility below), so never call .lower() on None.
    scope_visibility = (visibility if visibility is not None else "public").lower()

    # Process in chunks to avoid memory issues and SQLite parameter limits
    chunk_size = 500

    for chunk_start in range(0, len(prompts), chunk_size):
        chunk = prompts[chunk_start : chunk_start + chunk_size]
        # Chunk-local counters. They are merged into ``stats`` only after the chunk commits,
        # so a failed chunk is never counted as both created/updated and failed.
        chunk_stats: Dict[str, Any] = {"created": 0, "updated": 0, "skipped": 0, "failed": 0, "errors": []}
        prompts_to_add: List[Any] = []
        prompts_to_update: List[Any] = []

        try:
            content_security = get_content_security_service()

            # Collect unique gateway_ids and look them up with a single query
            gateway_ids = {gw_id for gw_id in (getattr(p, "gateway_id", None) for p in chunk) if gw_id}
            gateways_map: Dict[str, Any] = {}
            if gateway_ids:
                gateways = db.execute(select(DbGateway).where(DbGateway.id.in_(gateway_ids))).scalars().all()
                gateways_map = {gw.id: gw for gw in gateways}

            # Batch check for existing prompts: build computed names with gateway context
            prompt_names = []
            for prompt in chunk:
                custom_name = getattr(prompt, "custom_name", None) or prompt.name
                gw_id = getattr(prompt, "gateway_id", None)
                gateway = gateways_map.get(gw_id) if gw_id else None
                prompt_names.append(self._compute_prompt_name(custom_name, gateway=gateway))

            if scope_visibility == "public":
                base_conditions = [DbPrompt.name.in_(prompt_names), DbPrompt.visibility == "public"]
            elif scope_visibility == "team" and team_id:
                base_conditions = [DbPrompt.name.in_(prompt_names), DbPrompt.visibility == "team", DbPrompt.team_id == team_id]
            else:
                # Private prompts - check by owner
                base_conditions = [DbPrompt.name.in_(prompt_names), DbPrompt.visibility == "private", DbPrompt.owner_email == (owner_email or created_by)]

            existing_prompts = db.execute(select(DbPrompt).where(*base_conditions)).scalars().all()
            # Use (name, gateway_id) tuple as key for proper conflict detection
            existing_prompts_map = {(p.name, p.gateway_id): p for p in existing_prompts}

            for prompt in chunk:
                try:
                    # Validate template size BEFORE any processing
                    content_security.validate_prompt_size(template=prompt.template, name=prompt.name, user_email=created_by, ip_address=created_from_ip)

                    # Validate template syntax and extract required arguments
                    self._validate_template(prompt.template)
                    required_args = self._get_required_arguments(prompt.template)

                    argument_schema = {
                        "type": "object",
                        "properties": {},
                        "required": list(required_args),
                    }
                    for arg in prompt.arguments:
                        schema = {"type": "string"}
                        if arg.description is not None:
                            schema["description"] = arg.description
                        argument_schema["properties"][arg.name] = schema

                    # Use provided parameters or schema values
                    prompt_team_id = team_id if team_id is not None else getattr(prompt, "team_id", None)
                    prompt_owner_email = owner_email or getattr(prompt, "owner_email", None) or created_by
                    prompt_visibility = visibility if visibility is not None else (getattr(prompt, "visibility", None) or "public")
                    prompt_gateway_id = getattr(prompt, "gateway_id", None)

                    custom_name = getattr(prompt, "custom_name", None) or prompt.name
                    display_name = getattr(prompt, "display_name", None) or custom_name
                    gateway = gateways_map.get(prompt_gateway_id) if prompt_gateway_id else None
                    computed_name = self._compute_prompt_name(custom_name, gateway=gateway)

                    existing_prompt = existing_prompts_map.get((computed_name, prompt_gateway_id))

                    if existing_prompt:
                        if conflict_strategy == "skip":
                            chunk_stats["skipped"] += 1
                            continue
                        if conflict_strategy == "update":
                            existing_prompt.title = getattr(prompt, "title", None)
                            existing_prompt.description = prompt.description
                            existing_prompt.template = prompt.template
                            existing_prompt.argument_schema = argument_schema
                            existing_prompt.tags = prompt.tags or []
                            if getattr(prompt, "custom_name", None) is not None:
                                existing_prompt.custom_name = custom_name
                            if getattr(prompt, "display_name", None) is not None:
                                existing_prompt.display_name = display_name
                            existing_prompt.modified_by = created_by
                            existing_prompt.modified_from_ip = created_from_ip
                            existing_prompt.modified_via = created_via
                            existing_prompt.modified_user_agent = created_user_agent
                            existing_prompt.updated_at = datetime.now(timezone.utc)
                            existing_prompt.version = (existing_prompt.version or 1) + 1

                            prompts_to_update.append(existing_prompt)
                            chunk_stats["updated"] += 1
                            continue
                        if conflict_strategy == "fail":
                            chunk_stats["failed"] += 1
                            chunk_stats["errors"].append(f"Prompt name conflict: {prompt.name}")
                            continue
                        if conflict_strategy != "rename":
                            # Unknown strategy: record it instead of silently dropping the prompt uncounted
                            chunk_stats["failed"] += 1
                            chunk_stats["errors"].append(f"Unknown conflict strategy '{conflict_strategy}' for prompt: {prompt.name}")
                            continue
                        # "rename": fall through and create the prompt under a new, timestamped name
                        custom_name = f"{prompt.name}_imported_{int(datetime.now().timestamp())}"
                        display_name = custom_name
                        computed_name = self._compute_prompt_name(custom_name, gateway=gateway)

                    db_prompt = DbPrompt(
                        name=computed_name,
                        original_name=prompt.name,
                        custom_name=custom_name,
                        display_name=display_name,
                        title=getattr(prompt, "title", None),
                        description=prompt.description,
                        template=prompt.template,
                        argument_schema=argument_schema,
                        tags=prompt.tags or [],
                        created_by=created_by,
                        created_from_ip=created_from_ip,
                        created_via=created_via,
                        created_user_agent=created_user_agent,
                        import_batch_id=import_batch_id,
                        federation_source=federation_source,
                        version=1,
                        team_id=prompt_team_id,
                        owner_email=prompt_owner_email,
                        visibility=prompt_visibility,
                        gateway_id=prompt_gateway_id,
                    )
                    # Set gateway relationship to help the before_insert event handler
                    if gateway:
                        db_prompt.gateway = gateway
                        db_prompt.gateway_name_cache = gateway.name  # type: ignore[attr-defined]
                    prompts_to_add.append(db_prompt)
                    chunk_stats["created"] += 1

                except Exception as e:
                    chunk_stats["failed"] += 1
                    chunk_stats["errors"].append(f"Failed to process prompt {prompt.name}: {str(e)}")
                    logger.warning(f"Failed to process prompt {prompt.name} in bulk operation: {str(e)}")
                    continue

            # Bulk add new prompts and commit the chunk
            if prompts_to_add:
                db.add_all(prompts_to_add)
            db.commit()

        except Exception as e:
            db.rollback()
            logger.error(f"Failed to process chunk in bulk prompt registration: {str(e)}")
            # Nothing from this chunk was persisted: count every prompt in it as failed exactly once
            stats["failed"] += len(chunk)
            stats["errors"].extend(chunk_stats["errors"])
            stats["errors"].append(f"Chunk processing failed: {str(e)}")
            continue

        # The chunk is committed: merge its counters
        for key in ("created", "updated", "skipped", "failed"):
            stats[key] += chunk_stats[key]
        stats["errors"].extend(chunk_stats["errors"])

        if prompts_to_update:
            # Clear template cache once per chunk (not per prompt) to reduce memory growth
            _compile_jinja_template.cache_clear()

        # Post-commit side effects are best-effort; the data is already persisted
        try:
            for db_prompt in prompts_to_add:
                db.refresh(db_prompt)
                await self._notify_prompt_added(db_prompt)

            if prompts_to_add or prompts_to_update:
                audit_trail.log_action(
                    user_id=created_by or "system",
                    action="bulk_create_prompts" if prompts_to_add else "bulk_update_prompts",
                    resource_type="prompt",
                    resource_id=import_batch_id or "bulk_operation",
                    resource_name=f"Bulk operation: {len(prompts_to_add)} created, {len(prompts_to_update)} updated",
                    user_email=owner_email,
                    team_id=team_id,
                    client_ip=created_from_ip,
                    user_agent=created_user_agent,
                    new_values={
                        "prompts_created": len(prompts_to_add),
                        "prompts_updated": len(prompts_to_update),
                        "visibility": visibility,
                    },
                    context={
                        "created_via": created_via,
                        "import_batch_id": import_batch_id,
                        "federation_source": federation_source,
                        "conflict_strategy": conflict_strategy,
                    },
                    db=db,
                )
        except Exception as e:
            # Reset any failed transaction so the next chunk can still use the session
            db.rollback()
            logger.warning(f"Bulk prompt chunk committed but post-commit processing failed: {str(e)}")

        logger.info(f"Bulk registered {len(prompts_to_add)} prompts, updated {len(prompts_to_update)} prompts in chunk")

    # Invalidate caches once for the whole operation, mirroring single-prompt registration
    if stats["created"] or stats["updated"]:
        try:
            cache = _get_registry_cache()
            await cache.invalidate_prompts()
            # First-Party
            from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

            await admin_stats_cache.invalidate_tags()
            # First-Party
            from mcpgateway.cache.metrics_cache import metrics_cache  # pylint: disable=import-outside-toplevel

            metrics_cache.invalidate_prefix("top_prompts:")
            metrics_cache.invalidate("prompts")
        except Exception as e:
            logger.warning(f"Failed to invalidate caches after bulk prompt registration: {str(e)}")

    # Final structured logging
    structured_logger.log(
        level="INFO",
        message="Bulk prompt registration completed",
        event_type="prompts_bulk_created",
        component="prompt_service",
        user_id=created_by,
        user_email=owner_email,
        team_id=team_id,
        resource_type="prompt",
        custom_fields={
            "prompts_created": stats["created"],
            "prompts_updated": stats["updated"],
            "prompts_skipped": stats["skipped"],
            "prompts_failed": stats["failed"],
            "total_prompts": len(prompts),
            "visibility": visibility,
            "conflict_strategy": conflict_strategy,
        },
    )

    return stats
async def list_prompts(
    self,
    db: Session,
    include_inactive: bool = False,
    cursor: Optional[str] = None,
    tags: Optional[List[str]] = None,
    gateway_id: Optional[str] = None,
    limit: Optional[int] = None,
    page: Optional[int] = None,
    per_page: Optional[int] = None,
    user_email: Optional[str] = None,
    team_id: Optional[str] = None,
    visibility: Optional[str] = None,
    token_teams: Optional[List[str]] = None,
) -> Union[tuple[List[PromptRead], Optional[str]], Dict[str, Any]]:
    """
    Retrieve a list of prompt templates from the database with pagination support.

    This method retrieves prompt templates from the database and converts them into a list
    of PromptRead objects. It supports filtering out inactive prompts based on the
    include_inactive parameter and cursor-based pagination.

    Caching: only the first page of an unscoped cursor-style listing is cached, meaning no
    ``cursor``, ``page``, ``user_email``, ``token_teams`` or ``team_id``. ``team_id`` is not
    part of the cache key, so team-filtered listings bypass the cache. Otherwise one filter's
    result could be served for another.

    Args:
        db (Session): The SQLAlchemy database session.
        include_inactive (bool): If True, include inactive prompts in the result.
            Defaults to False.
        cursor (Optional[str], optional): An opaque cursor token for pagination.
            Opaque base64-encoded string containing last item's ID and created_at.
        tags (Optional[List[str]]): Filter prompts by tags. If provided, only prompts with at least one matching tag will be returned.
        gateway_id (Optional[str]): Filter prompts by gateway ID. Accepts the literal value 'null' to match NULL gateway_id.
        limit (Optional[int]): Maximum number of prompts to return. Use 0 for all prompts (no limit).
            If not specified, uses pagination_default_page_size.
        page: Page number for page-based pagination (1-indexed). Mutually exclusive with cursor.
        per_page: Items per page for page-based pagination. Defaults to pagination_default_page_size.
        user_email (Optional[str]): User email for team-based access control. If None, no access control is applied.
        team_id (Optional[str]): Filter by specific team ID. Requires user_email for access validation.
        visibility (Optional[str]): Filter by visibility (private, team, public).
        token_teams (Optional[List[str]]): Override DB team lookup with token's teams. Used for MCP/API token access
            where the token scope should be respected instead of the user's full team memberships.

    Returns:
        If page is provided: Dict with {"data": [...], "pagination": {...}, "links": {...}}
        If cursor is provided or neither: tuple of (list of PromptRead objects, next_cursor).

    Examples:
        >>> from mcpgateway.services.prompt_service import PromptService
        >>> from unittest.mock import MagicMock
        >>> from mcpgateway.schemas import PromptRead
        >>> service = PromptService()
        >>> db = MagicMock()
        >>> prompt_read_obj = MagicMock(spec=PromptRead)
        >>> service.convert_prompt_to_read = MagicMock(return_value=prompt_read_obj)
        >>> db.execute.return_value.scalars.return_value.all.return_value = [MagicMock()]
        >>> import asyncio
        >>> prompts, next_cursor = asyncio.run(service.list_prompts(db))
        >>> prompts == [prompt_read_obj]
        True
    """
    with create_span(
        "prompt.list",
        {
            "include_inactive": include_inactive,
            "tags.count": len(tags) if tags else 0,
            "gateway_id": gateway_id,
            "limit": limit,
            "page": page,
            "per_page": per_page,
            "user.email": user_email,
            "team.scope": format_trace_team_scope(token_teams) if token_teams is not None else None,
            "team.filter": team_id,
            "visibility": visibility,
        },
    ):
        cache = _get_registry_cache()
        # Single eligibility flag shared by cache lookup and cache store, so the two can never
        # disagree. Any caller-specific input that is not part of the cache key disables
        # caching. This prevents cache poisoning, e.g. admin or team-filtered results leaking
        # to public-only requests.
        use_cache = cursor is None and page is None and user_email is None and token_teams is None and team_id is None
        filters_hash = None
        if use_cache:
            filters_hash = cache.hash_filters(include_inactive=include_inactive, tags=sorted(tags) if tags else None, gateway_id=gateway_id, limit=limit, visibility=visibility)
            cached = await cache.get("prompts", filters_hash)
            if cached is not None:
                # Reconstruct PromptRead objects from cached dicts
                cached_prompts = [PromptRead.model_validate(p) for p in cached["prompts"]]
                return (cached_prompts, cached.get("next_cursor"))

        # Build base query with ordering and eager load gateway to avoid N+1
        query = select(DbPrompt).options(joinedload(DbPrompt.gateway)).order_by(desc(DbPrompt.created_at), desc(DbPrompt.id))

        if not include_inactive:
            query = query.where(DbPrompt.enabled)

        query = await self._apply_access_control(query, db, user_email, token_teams, team_id)

        if visibility:
            query = query.where(DbPrompt.visibility == visibility)

        # Add gateway_id filtering if provided ('null' matches prompts without a gateway)
        if gateway_id:
            if gateway_id.lower() == "null":
                query = query.where(DbPrompt.gateway_id.is_(None))
            else:
                query = query.where(DbPrompt.gateway_id == gateway_id)

        # Add tag filtering if tags are provided (supports both List[str] and List[Dict] formats)
        if tags:
            query = query.where(json_contains_tag_expr(db, DbPrompt.tags, tags, match_any=True))

        # Use unified pagination helper - handles both page and cursor pagination
        pag_result = await unified_paginate(
            db=db,
            query=query,
            page=page,
            per_page=per_page,
            cursor=cursor,
            limit=limit,
            base_url="/admin/prompts",  # Used for page-based links
            query_params={"include_inactive": include_inactive} if include_inactive else {},
        )

        next_cursor = None
        # Extract prompts based on pagination type
        if page is not None:
            # Page-based: pag_result is a dict
            prompts_db = pag_result["data"]
        else:
            # Cursor-based: pag_result is a tuple
            prompts_db, next_cursor = pag_result

        # Fetch team names for the prompts in one query (common for both pagination types)
        team_ids_set = {p.team_id for p in prompts_db if p.team_id}
        team_map = {}
        if team_ids_set:
            teams = db.execute(select(EmailTeam.id, EmailTeam.name).where(EmailTeam.id.in_(team_ids_set), EmailTeam.is_active.is_(True))).all()
            team_map = {team.id: team.name for team in teams}

        db.commit()  # Release transaction to avoid idle-in-transaction

        # Convert to PromptRead (common for both pagination types)
        result = []
        for db_prompt in prompts_db:
            try:
                db_prompt.team = team_map.get(db_prompt.team_id) if db_prompt.team_id else None
                result.append(self.convert_prompt_to_read(db_prompt, include_metrics=False))
            except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e:
                # Continue with remaining prompts instead of failing completely
                logger.exception(f"Failed to convert prompt {getattr(db_prompt, 'id', 'unknown')} ({getattr(db_prompt, 'name', 'unknown')}): {e}")

        if page is not None:
            # Page-based format
            return {
                "data": result,
                "pagination": pag_result["pagination"],
                "links": pag_result["links"],
            }

        # Cursor-based format: cache first page results under the same eligibility rule as the lookup
        if use_cache:
            try:
                cache_data = {"prompts": [p.model_dump(mode="json") for p in result], "next_cursor": next_cursor}
                await cache.set("prompts", cache_data, filters_hash)
            except AttributeError:
                pass  # Skip caching if result objects don't support model_dump (e.g., in doctests)

        return (result, next_cursor)
async def list_prompts_for_user(
    self, db: Session, user_email: str, team_id: Optional[str] = None, visibility: Optional[str] = None, include_inactive: bool = False, skip: int = 0, limit: int = 100
) -> List[PromptRead]:
    """
    DEPRECATED: Use list_prompts() with user_email parameter instead.

    This method is maintained for backward compatibility but is no longer used.
    New code should call list_prompts() with user_email, team_id, and visibility parameters.

    List prompts user has access to with team filtering.

    Results are ordered newest first (``created_at`` then ``id``, both descending). This
    keeps ``skip``/``limit`` pages stable between calls.

    Args:
        db: Database session
        user_email: Email of the user requesting prompts
        team_id: Optional team ID to filter by specific team; an empty list is returned
            if the user is not a member of that team
        visibility: Optional visibility filter (private, team, public)
        include_inactive: Whether to include inactive prompts
        skip: Number of prompts to skip for pagination
        limit: Maximum number of prompts to return

    Returns:
        List[PromptRead]: Prompts the user has access to
    """
    team_service = TeamManagementService(db)
    user_teams = await team_service.get_user_teams(user_email)
    team_ids = [team.id for team in user_teams]

    # Eager load gateway to avoid N+1 when accessing gateway_slug.
    # A deterministic ORDER BY is required: OFFSET/LIMIT over an unordered query may return
    # overlapping or missing rows between pages. Same ordering as list_prompts().
    query = select(DbPrompt).options(joinedload(DbPrompt.gateway)).order_by(desc(DbPrompt.created_at), desc(DbPrompt.id))

    # Apply active/inactive filter
    if not include_inactive:
        query = query.where(DbPrompt.enabled)

    if team_id:
        if team_id not in team_ids:
            db.commit()  # Release transaction to avoid idle-in-transaction
            return []  # No access to team

        access_conditions = [
            # Team-visible or public prompts of the requested team
            and_(DbPrompt.team_id == team_id, DbPrompt.visibility.in_(["team", "public"])),
            # The user's own prompts in that team
            and_(DbPrompt.team_id == team_id, DbPrompt.owner_email == user_email),
        ]
    else:
        # 1. User's personal prompts (owner_email matches)
        access_conditions = [DbPrompt.owner_email == user_email]
        # 2. Team prompts where user is member
        if team_ids:
            access_conditions.append(and_(DbPrompt.team_id.in_(team_ids), DbPrompt.visibility.in_(["team", "public"])))
        # 3. Public prompts
        access_conditions.append(DbPrompt.visibility == "public")

    query = query.where(or_(*access_conditions))

    # Apply visibility filter if specified
    if visibility:
        query = query.where(DbPrompt.visibility == visibility)

    query = query.offset(skip).limit(limit)

    prompts = db.execute(query).scalars().all()

    # Batch fetch team names to avoid N+1 queries
    prompt_team_ids = {p.team_id for p in prompts if p.team_id}
    team_map = {}
    if prompt_team_ids:
        teams = db.execute(select(EmailTeam.id, EmailTeam.name).where(EmailTeam.id.in_(prompt_team_ids), EmailTeam.is_active.is_(True))).all()
        team_map = {str(team.id): team.name for team in teams}

    db.commit()  # Release transaction to avoid idle-in-transaction

    result = []
    for t in prompts:
        try:
            t.team = team_map.get(str(t.team_id)) if t.team_id else None
            result.append(self.convert_prompt_to_read(t, include_metrics=False))
        except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e:
            # Continue with remaining prompts instead of failing completely
            logger.exception(f"Failed to convert prompt {getattr(t, 'id', 'unknown')} ({getattr(t, 'name', 'unknown')}): {e}")
    return result
async def list_server_prompts(
    self,
    db: Session,
    server_id: str,
    include_inactive: bool = False,
    include_metrics: bool = False,
    cursor: Optional[str] = None,
    user_email: Optional[str] = None,
    token_teams: Optional[List[str]] = None,
) -> List[PromptRead]:
    """
    List the prompt templates attached to one server, as PromptRead objects.

    Only prompts linked to ``server_id`` through ``server_prompt_association`` are returned.
    Disabled prompts are dropped unless ``include_inactive`` is True. Pagination is not
    implemented yet: ``cursor`` is accepted but only written to the debug log.

    When ``user_email`` or ``token_teams`` is given (an empty-string email counts), the result
    is restricted to the following:
    - public prompts;
    - prompts owned by ``user_email``, except for public-only tokens (``token_teams == []``);
    - team/public prompts of the caller's teams, taken from ``token_teams`` when given,
      otherwise looked up for ``user_email``.

    With neither given, no visibility filter is applied.

    Args:
        db (Session): The SQLAlchemy database session.
        server_id (str): Server ID
        include_inactive (bool): If True, include inactive prompts in the result.
            Defaults to False.
        include_metrics (bool): If True, include metrics data in the response.
            Defaults to False.
        cursor (Optional[str], optional): An opaque cursor token for pagination. Currently,
            this parameter is ignored. Defaults to None.
        user_email (Optional[str]): User email for visibility filtering. If None, no filtering applied.
        token_teams (Optional[List[str]]): Override DB team lookup with token's teams. Used for MCP/API
            token access where the token scope should be respected.

    Returns:
        List[PromptRead]: The server's visible prompts; prompts that fail conversion are logged and omitted.

    Examples:
        >>> from mcpgateway.services.prompt_service import PromptService
        >>> from unittest.mock import MagicMock
        >>> from mcpgateway.schemas import PromptRead
        >>> service = PromptService()
        >>> db = MagicMock()
        >>> prompt_read_obj = MagicMock(spec=PromptRead)
        >>> service.convert_prompt_to_read = MagicMock(return_value=prompt_read_obj)
        >>> db.execute.return_value.scalars.return_value.all.return_value = [MagicMock()]
        >>> import asyncio
        >>> result = asyncio.run(service.list_server_prompts(db, 'server1'))
        >>> result == [prompt_read_obj]
        True
    """
    trace_attributes = {
        "server_id": server_id,
        "include_inactive": include_inactive,
        "include_metrics": include_metrics,
        "user.email": user_email,
        "team.scope": format_trace_team_scope(token_teams) if token_teams is not None else None,
    }
    with create_span("prompt.list", trace_attributes):
        # Gateway is eager-loaded so gateway_slug access does not trigger one query per prompt
        stmt = (
            select(DbPrompt)
            .options(joinedload(DbPrompt.gateway))
            .join(server_prompt_association, DbPrompt.id == server_prompt_association.c.prompt_id)
            .where(server_prompt_association.c.server_id == server_id)
        )
        if include_metrics:
            # Metric collections are loaded up front so conversion does not issue per-prompt queries
            stmt = stmt.options(selectinload(DbPrompt.metrics), selectinload(DbPrompt.metrics_hourly))
        if not include_inactive:
            stmt = stmt.where(DbPrompt.enabled)

        scoped_request = user_email is not None or token_teams is not None
        if scoped_request:
            # Token scope wins over the user's full DB memberships
            if token_teams is not None:
                member_team_ids = token_teams
            elif user_email:
                member_team_ids = [team.id for team in await TeamManagementService(db).get_user_teams(user_email)]
            else:
                member_team_ids = []

            # A token with an empty team list is public-only: it never gains owner access
            public_only = token_teams is not None and len(token_teams) == 0

            visible_if = [DbPrompt.visibility == "public"]
            if user_email and not public_only:
                visible_if.append(DbPrompt.owner_email == user_email)
            if member_team_ids:
                visible_if.append(and_(DbPrompt.team_id.in_(member_team_ids), DbPrompt.visibility.in_(["team", "public"])))
            stmt = stmt.where(or_(*visible_if))

        # Placeholder until cursor pagination is implemented
        logger.debug(cursor)
        rows = db.execute(stmt).scalars().all()

        # Resolve all team names with a single query
        referenced_team_ids = {row.team_id for row in rows if row.team_id}
        team_names = {}
        if referenced_team_ids:
            active_teams = db.execute(select(EmailTeam.id, EmailTeam.name).where(EmailTeam.id.in_(referenced_team_ids), EmailTeam.is_active.is_(True))).all()
            team_names = {str(team.id): team.name for team in active_teams}

        db.commit()  # end the read transaction so the connection is not left idle-in-transaction

        converted = []
        for row in rows:
            try:
                row.team = team_names.get(str(row.team_id)) if row.team_id else None
                converted.append(self.convert_prompt_to_read(row, include_metrics=include_metrics))
            except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e:
                # One bad row must not hide the rest of the server's prompts
                logger.exception(f"Failed to convert prompt {getattr(row, 'id', 'unknown')} ({getattr(row, 'name', 'unknown')}): {e}")
        return converted
async def _record_prompt_metric(self, db: Session, prompt: DbPrompt, start_time: float, success: bool, error_message: Optional[str]) -> None:
    """
    Persist one invocation metric row for ``prompt`` and commit it.

    The stored response time is the ``time.monotonic()`` interval elapsed since ``start_time``.

    Args:
        db: Database session used to add and commit the metric row
        prompt: The prompt that was invoked; only its ``id`` is read
        start_time: ``time.monotonic()`` value captured when the invocation began
        success: Whether the invocation succeeded
        error_message: Error description if the invocation failed, otherwise None
    """
    elapsed = time.monotonic() - start_time
    db.add(
        PromptMetric(
            prompt_id=prompt.id,
            response_time=elapsed,
            is_success=success,
            error_message=error_message,
        )
    )
    db.commit()
async def _check_prompt_access(
    self,
    db: Session,
    prompt: DbPrompt,
    user_email: Optional[str],
    token_teams: Optional[List[str]],
) -> bool:
    """Decide whether the caller may access ``prompt`` under its visibility rules.

    Rules, evaluated in order:
      1. ``public`` prompts (or objects without a ``visibility`` attribute) are open to everyone.
      2. ``token_teams is None`` together with ``user_email is None`` is treated as
         unrestricted admin access.
      3. Without a user email, or with a public-only token (``token_teams == []``),
         non-public prompts are denied.
      4. The owner may access their own ``private`` prompt.
      5. A prompt with a ``team_id`` is accessible when its visibility is ``team``
         and the team is among the caller's teams. Those teams come from
         ``token_teams`` when given, otherwise from a membership lookup through
         ``TeamManagementService``.

    NOTE(review): ``list_prompts()`` grants owners access to their prompts of any
    visibility, whereas rule 4 only covers ``private`` prompts — confirm which
    behaviour is intended.

    Args:
        db: Database session, used only for the team-membership lookup.
        prompt: Prompt ORM object exposing ``visibility``, ``team_id`` and ``owner_email``.
        user_email: Email of the requesting user (None = unauthenticated).
        token_teams: Team IDs carried by the token.
            - None = unrestricted admin access
            - [] = public-only token
            - [...] = team-scoped token

    Returns:
        True if access is allowed, False otherwise.
    """
    visibility = getattr(prompt, "visibility", "public")
    if visibility == "public":
        return True

    unrestricted_admin = token_teams is None and user_email is None
    if unrestricted_admin:
        return True

    # Anonymous callers and public-only tokens never reach non-public prompts.
    public_only = token_teams is not None and not token_teams
    if not user_email or public_only:
        return False

    owner = getattr(prompt, "owner_email", None)
    if visibility == "private" and owner and owner == user_email:
        return True

    team_id = getattr(prompt, "team_id", None)
    if not team_id:
        return False

    if token_teams is None:
        memberships = await TeamManagementService(db).get_user_teams(user_email)
        allowed_teams = [membership.id for membership in memberships]
    else:
        allowed_teams = token_teams

    return visibility in ("team", "public") and team_id in allowed_teams
def _find_prompt_by_name_or_id(
    self,
    db: Session,
    scoped_query: Any,
    prompt_id: str,
) -> Optional[DbPrompt]:
    """Find a prompt by name or ID using a scoped query.

    Uses a single OR query for efficiency, with a fallback to name-only
    lookup if the OR matches multiple rows (e.g. one prompt's name equals
    another prompt's ID).

    Args:
        db: Database session
        scoped_query: Pre-scoped SQLAlchemy query with access control applied
        prompt_id: Name or ID of the prompt to find

    Returns:
        DbPrompt instance if found, None otherwise

    Raises:
        PromptError: If multiple accessible prompts share the same name. The
            underlying ``MultipleResultsFound`` is attached as ``__cause__``.

    Note:
        The scoped_query must already have team-based access control applied
        via _apply_access_control() to ensure multi-tenancy security.
    """
    try:
        return db.execute(scoped_query.where(or_(DbPrompt.name == prompt_id, DbPrompt.id == prompt_id))).scalar_one_or_none()
    except MultipleResultsFound:
        # OR matched multiple rows — retry by name only (the MCP spec's primary key).
        # Assuming ``id`` is unique (primary key), several OR matches mean at
        # least one row matched by name, so this lookup is not expected to be empty.
        try:
            return db.execute(scoped_query.where(DbPrompt.name == prompt_id)).scalar_one_or_none()
        except MultipleResultsFound as exc:
            raise PromptError(f"Prompt name '{prompt_id}' is ambiguous across multiple scopes; use /servers/{{id}}/mcp to disambiguate.") from exc
async def get_prompt(
    self,
    db: Session,
    prompt_id: Union[int, str],
    arguments: Optional[Dict[str, str]] = None,
    user: Optional[str] = None,
    tenant_id: Optional[str] = None,
    server_id: Optional[str] = None,
    request_id: Optional[str] = None,
    token_teams: Optional[List[str]] = None,
    plugin_context_table: Optional[PluginContextTable] = None,
    plugin_global_context: Optional[GlobalContext] = None,
    _meta_data: Optional[Dict[str, Any]] = None,
) -> PromptResult:
    """Retrieve and render a prompt by name or ID.

    This method implements MCP specification-compliant prompt lookup with
    multi-tenancy support and backward compatibility.

    Flow:
      1. Run the ``PROMPT_PRE_FETCH`` plugin hook if registered; a modified
         payload replaces ``arguments``.
      2. Look the prompt up among enabled prompts through an access-scoped
         (and, when ``server_id`` is given, server-scoped) query. If it is not
         found, a disabled prompt visible under the same scoping yields a
         distinct "inactive" error.
      3. Render: fetch from the remote gateway when
         ``_should_fetch_gateway_prompt`` says so; return the raw template as a
         single user message when no arguments are given; otherwise validate
         the arguments and render the template.
      4. Run the ``PROMPT_POST_FETCH`` hook if registered; a modified payload
         replaces the result.
      5. Write the audit trail and a structured log entry.
    Prompt metrics, server metrics (only after the server check passed) and
    the observability span are finalized in ``finally``.

    Args:
        db: Database session
        prompt_id: Name or ID of the prompt to retrieve. Name-based lookup
            is prioritized per MCP spec, with ID fallback for backward
            compatibility. Team-based access control is applied before the
            lookup to ensure multi-tenancy security.
        arguments: Optional arguments for rendering
        user: Optional user email for authorization checks
        tenant_id: Optional tenant identifier for plugin context
        server_id: Optional server ID for server scoping enforcement
        request_id: Optional request ID, generated if not provided
        token_teams: Optional list of team IDs from token for authorization.
            None = unrestricted admin, [] = public-only, [...] = team-scoped.
        plugin_context_table: Optional plugin context table from previous hooks for cross-hook state sharing.
        plugin_global_context: Optional global context from middleware for consistency across hooks.
        _meta_data: Optional metadata for prompt retrieval (not used currently).

    Returns:
        Prompt result with rendered messages

    Raises:
        PluginViolationError: If prompt violates a plugin policy
        PromptNotFoundError: If prompt not found or access denied
        PromptError: For other prompt errors (rendering failures keep the
            original exception as ``__cause__``)
        PluginError: If encounters issue with plugin

    Examples:
        >>> from mcpgateway.services.prompt_service import PromptService
        >>> from unittest.mock import MagicMock
        >>> service = PromptService()
        >>> db = MagicMock()
        >>> db.execute.return_value.scalar_one_or_none.return_value = MagicMock()
        >>> import asyncio
        >>> try:
        ...     asyncio.run(service.get_prompt(db, 'prompt_id'))
        ... except Exception:
        ...     pass
    """
    start_time = time.monotonic()
    success = False
    error_message = None
    prompt = None
    # Only set once the server-attachment check passes; gates server metrics below.
    server_scoped = False

    # Create database span for observability dashboard
    trace_id = current_trace_id.get()
    db_span_id = None
    db_span_ended = False
    observability_service = ObservabilityService() if trace_id else None

    if trace_id and observability_service:
        try:
            db_span_id = observability_service.start_span(
                db=db,
                trace_id=trace_id,
                name="prompt.render",
                attributes={
                    "prompt.id": str(prompt_id),
                    "arguments_count": len(arguments) if arguments else 0,
                    "user": user or "anonymous",
                    "server_id": server_id,
                    "tenant_id": tenant_id,
                    "request_id": request_id or "none",
                },
            )
            logger.debug(f"✓ Created prompt.render span: {db_span_id} for prompt: {prompt_id}")
        except Exception as e:
            # Observability is best-effort: a span failure must not block rendering.
            logger.warning(f"Failed to start observability span for prompt rendering: {e}")
            db_span_id = None

    # Create a trace span for OpenTelemetry export (Jaeger, Zipkin, etc.)
    span_attributes = {
        "prompt.id": prompt_id,
        "arguments_count": len(arguments) if arguments else 0,
        "user": user or "anonymous",
        "server_id": server_id,
        "tenant_id": tenant_id,
        "request_id": request_id or "none",
    }
    if is_input_capture_enabled("prompt.render"):
        span_attributes["langfuse.observation.input"] = serialize_trace_payload(arguments or {})

    with create_span("prompt.render", span_attributes) as span:
        try:
            # Check if any prompt hooks are registered to avoid unnecessary context creation
            has_pre_fetch = self._plugin_manager and self._plugin_manager.has_hooks_for(PromptHookType.PROMPT_PRE_FETCH)
            has_post_fetch = self._plugin_manager and self._plugin_manager.has_hooks_for(PromptHookType.PROMPT_POST_FETCH)

            # Initialize plugin context variables only if hooks are registered
            context_table = None
            global_context = None
            if has_pre_fetch or has_post_fetch:
                context_table = plugin_context_table
                if plugin_global_context:
                    global_context = plugin_global_context
                    # Update fields with prompt-specific information
                    if user:
                        global_context.user = user
                    if server_id:
                        global_context.server_id = server_id
                    if tenant_id:
                        global_context.tenant_id = tenant_id
                else:
                    # Create new context (fallback when middleware didn't run)
                    if not request_id:
                        request_id = uuid.uuid4().hex
                    global_context = GlobalContext(request_id=request_id, user=user, server_id=server_id, tenant_id=tenant_id)

            if has_pre_fetch:
                pre_result, context_table = await self._plugin_manager.invoke_hook(
                    PromptHookType.PROMPT_PRE_FETCH,
                    payload=PromptPrehookPayload(prompt_id=prompt_id, args=arguments),
                    global_context=global_context,
                    local_contexts=context_table,  # Pass context from previous hooks
                    violations_as_exceptions=True,
                )

                # Use modified payload if provided (only the arguments are taken from it)
                if pre_result.modified_payload:
                    payload = pre_result.modified_payload
                    arguments = payload.args

            # ═══════════════════════════════════════════════════════════════════════════
            # SECURITY: Apply team scoping BEFORE lookup to prevent multi-tenancy issues
            # This ensures users only see prompts they have access to, matching list_prompts()
            # Build base query and apply access control (matches list_prompts architecture)
            # ═══════════════════════════════════════════════════════════════════════════
            search_key = str(prompt_id)

            # Build base query with server + team scoping applied FIRST
            base_query = select(DbPrompt).options(joinedload(DbPrompt.gateway)).where(DbPrompt.enabled)
            if server_id:
                base_query = base_query.join(server_prompt_association, DbPrompt.id == server_prompt_association.c.prompt_id).where(server_prompt_association.c.server_id == server_id)
            scoped_query = await self._apply_access_control(base_query, db, user, token_teams, team_id=None)

            # Find prompt by name or ID (active prompts only) using optimized OR query
            prompt = self._find_prompt_by_name_or_id(db, scoped_query, prompt_id)

            # If not found in active prompts, check inactive prompts (with team + server scoping)
            if not prompt:
                inactive_base_query = select(DbPrompt).options(joinedload(DbPrompt.gateway)).where(not_(DbPrompt.enabled))
                if server_id:
                    inactive_base_query = inactive_base_query.join(server_prompt_association, DbPrompt.id == server_prompt_association.c.prompt_id).where(
                        server_prompt_association.c.server_id == server_id
                    )
                inactive_scoped_query = await self._apply_access_control(inactive_base_query, db, user, token_teams, team_id=None)

                # Find in inactive prompts using optimized OR query
                inactive_prompt = self._find_prompt_by_name_or_id(db, inactive_scoped_query, prompt_id)

                if inactive_prompt:
                    raise PromptNotFoundError(f"Prompt '{search_key}' exists but is inactive")

                raise PromptNotFoundError(f"Prompt not found: {search_key}")

            # Access control already applied via scoped query - no additional check needed

            # ═══════════════════════════════════════════════════════════════════════════
            # SECURITY: Enforce server scoping if server_id is provided
            # Prompt must be attached to the specified virtual server
            # ═══════════════════════════════════════════════════════════════════════════
            if server_id:
                server_match = db.execute(
                    select(server_prompt_association.c.prompt_id).where(
                        server_prompt_association.c.server_id == server_id,
                        server_prompt_association.c.prompt_id == prompt.id,
                    )
                ).first()
                if not server_match:
                    raise PromptNotFoundError(f"Prompt not found: {search_key}")
                server_scoped = True

            if self._should_fetch_gateway_prompt(prompt):
                # Release the read transaction before any remote network I/O.
                db.commit()
                result = await self._fetch_gateway_prompt_result(prompt, arguments, user)
            elif not arguments:
                # No arguments: return the template text verbatim as one user message.
                result = PromptResult(
                    messages=[
                        Message(
                            role=Role.USER,
                            content=TextContent(type="text", text=prompt.template),
                        )
                    ],
                    description=prompt.description,
                )
            else:
                try:
                    prompt.validate_arguments(arguments)
                    rendered = self._render_template(prompt.template, arguments)
                    messages = self._parse_messages(rendered)
                    result = PromptResult(messages=messages, description=prompt.description)
                except Exception as e:
                    set_span_error(span, e)
                    # Chain so the root cause (validation / template error) stays inspectable.
                    raise PromptError(f"Failed to process prompt: {str(e)}") from e

            if has_post_fetch:
                post_result, _ = await self._plugin_manager.invoke_hook(
                    PromptHookType.PROMPT_POST_FETCH,
                    payload=PromptPosthookPayload(prompt_id=prompt.name, result=result),
                    global_context=global_context,
                    local_contexts=context_table,
                    violations_as_exceptions=True,
                )
                # Use modified payload if provided
                result = post_result.modified_payload.result if post_result.modified_payload else result

            arguments_supplied = bool(arguments)

            audit_trail.log_action(
                user_id=user or "anonymous",
                action="view_prompt",
                resource_type="prompt",
                resource_id=str(prompt.id),
                resource_name=prompt.name,
                team_id=prompt.team_id,
                context={
                    "tenant_id": tenant_id,
                    "server_id": server_id,
                    "arguments_provided": arguments_supplied,
                    "request_id": request_id,
                },
                db=db,
            )

            structured_logger.log(
                level="INFO",
                message="Prompt retrieved successfully",
                event_type="prompt_viewed",
                component="prompt_service",
                user_id=user,
                team_id=prompt.team_id,
                resource_type="prompt",
                resource_id=str(prompt.id),
                request_id=request_id,
                custom_fields={
                    "prompt_name": prompt.name,
                    "arguments_provided": arguments_supplied,
                    "tenant_id": tenant_id,
                    "server_id": server_id,
                },
            )

            # Set success attributes on span
            if span:
                set_span_attribute(span, "success", True)
                set_span_attribute(span, "duration.ms", (time.monotonic() - start_time) * 1000)
                set_span_attribute(span, "langfuse.observation.prompt.name", prompt.name)
                if getattr(prompt, "version", None) is not None:
                    set_span_attribute(span, "langfuse.observation.prompt.version", int(prompt.version))
                if result and hasattr(result, "messages"):
                    set_span_attribute(span, "messages.count", len(result.messages))
                if is_output_capture_enabled("prompt.render"):
                    set_span_attribute(span, "langfuse.observation.output", serialize_trace_payload(result))

            success = True
            logger.info(f"Retrieved prompt: {prompt.id} successfully")
            return result

        except Exception as e:
            success = False
            error_message = str(e)
            raise
        finally:
            # Record metrics only if we found a prompt
            if prompt:
                try:
                    metrics_buffer.record_prompt_metric(
                        prompt_id=prompt.id,
                        start_time=start_time,
                        success=success,
                        error_message=error_message,
                    )
                except Exception as metrics_error:
                    logger.warning(f"Failed to record prompt metric: {metrics_error}")

            # Record server metrics ONLY when the server scoping check passed.
            # This prevents recording metrics with unvalidated server_id values
            # from admin API headers (X-Server-ID) or RPC params.
            if server_scoped:
                try:
                    # Record server metric only for the specific virtual server being accessed
                    metrics_buffer.record_server_metric(
                        server_id=server_id,
                        start_time=start_time,
                        success=success,
                        error_message=error_message,
                    )
                except Exception as metrics_error:
                    logger.warning(f"Failed to record server metric: {metrics_error}")

            # End database span for observability dashboard
            if db_span_id and observability_service and not db_span_ended:
                try:
                    observability_service.end_span(
                        db=db,
                        span_id=db_span_id,
                        status="ok" if success else "error",
                        status_message=error_message if error_message else None,
                    )
                    db_span_ended = True
                    logger.debug(f"✓ Ended prompt.render span: {db_span_id}")
                except Exception as e:
                    logger.warning(f"Failed to end observability span for prompt rendering: {e}")
async def update_prompt(
    self,
    db: Session,
    prompt_id: Union[int, str],
    prompt_update: PromptUpdate,
    modified_by: Optional[str] = None,
    modified_from_ip: Optional[str] = None,
    modified_via: Optional[str] = None,
    modified_user_agent: Optional[str] = None,
    user_email: Optional[str] = None,
) -> PromptRead:
    """
    Update a prompt template.

    Ownership (when ``user_email`` is given) is verified before any
    name-uniqueness checks, so a non-owner can neither lock other prompts'
    rows nor learn about them from a ``PromptNameConflictError``. Uniqueness
    within the target visibility scope is checked whenever the computed name,
    the visibility, or the team changes.

    Args:
        db: Database session
        prompt_id: ID of prompt to update
        prompt_update: Prompt update object
        modified_by: Username of the person modifying the prompt
        modified_from_ip: IP address where the modification originated
        modified_via: Source of modification (ui/api/import)
        modified_user_agent: User agent string from the modification request
        user_email: Email of user performing update (for ownership check)

    Returns:
        The updated PromptRead object

    Raises:
        PromptNotFoundError: If the prompt is not found
        PermissionError: If user doesn't own the prompt
        IntegrityError: If a database integrity error occurs.
        PromptNameConflictError: If a prompt with the same name already exists.
        PromptError: For other update errors
        ContentSizeError: For template size exceed

    Examples:
        >>> import logging
        >>> logging.disable(logging.CRITICAL)
        >>> from mcpgateway.services.prompt_service import PromptService
        >>> from unittest.mock import AsyncMock, MagicMock
        >>> service = PromptService()
        >>> db = MagicMock()
        >>> existing = MagicMock()
        >>> existing.custom_name = "test-prompt"
        >>> existing.name = "test-prompt"
        >>> existing.gateway = None
        >>> db.execute.return_value.scalar_one_or_none.return_value = existing
        >>> db.commit = MagicMock()
        >>> db.refresh = MagicMock()
        >>> service._notify_prompt_updated = AsyncMock()
        >>> service.convert_prompt_to_read = MagicMock(return_value={})
        >>> update = MagicMock()
        >>> update.name = None
        >>> update.visibility = None
        >>> update.team_id = None
        >>> import asyncio
        >>> try:
        ...     asyncio.run(service.update_prompt(db, 'prompt_name', update))
        ... except Exception:
        ...     pass
        >>> logging.disable(logging.NOTSET)
    """
    try:
        # Acquire a row-level lock for the prompt being updated to make
        # name-checks and the subsequent update atomic in PostgreSQL.
        # For SQLite `get_for_update` falls back to a regular get.
        prompt = get_for_update(db, DbPrompt, prompt_id)
        if not prompt:
            raise PromptNotFoundError(f"Prompt not found: {prompt_id}")

        # Authorize before any name-conflict probing: otherwise a non-owner
        # could lock other prompts' rows and read their id/visibility/enabled
        # state from the PromptNameConflictError.
        if user_email:
            # First-Party
            from mcpgateway.services.permission_service import PermissionService  # pylint: disable=import-outside-toplevel

            permission_service = PermissionService(db)
            if not await permission_service.check_resource_ownership(user_email, prompt):
                raise PermissionError("Only the owner can update this prompt")

        # Effective scope after the update (falls back to current values).
        visibility = prompt_update.visibility or prompt.visibility
        team_id = prompt_update.team_id or prompt.team_id
        owner_email = prompt.owner_email or user_email

        candidate_custom_name = prompt.custom_name

        if prompt_update.name is not None:
            candidate_custom_name = prompt_update.custom_name or prompt_update.name
        elif prompt_update.custom_name is not None:
            candidate_custom_name = prompt_update.custom_name

        computed_name = self._compute_prompt_name(candidate_custom_name, prompt.gateway)
        # Moving an unchanged name into a different scope (e.g. private -> public,
        # or to another team) can collide too, so re-check on scope changes as well.
        scope_changed = visibility != prompt.visibility or team_id != prompt.team_id
        if computed_name != prompt.name or scope_changed:
            if visibility.lower() == "public":
                # Lock any conflicting row so concurrent updates cannot race.
                existing_prompt = get_for_update(db, DbPrompt, where=and_(DbPrompt.name == computed_name, DbPrompt.visibility == "public", DbPrompt.id != prompt.id))
                if existing_prompt:
                    raise PromptNameConflictError(computed_name, enabled=existing_prompt.enabled, prompt_id=existing_prompt.id, visibility=existing_prompt.visibility)
            elif visibility.lower() == "team" and team_id:
                existing_prompt = get_for_update(db, DbPrompt, where=and_(DbPrompt.name == computed_name, DbPrompt.visibility == "team", DbPrompt.team_id == team_id, DbPrompt.id != prompt.id))
                logger.info(f"Existing prompt check result: {existing_prompt}")
                if existing_prompt:
                    raise PromptNameConflictError(computed_name, enabled=existing_prompt.enabled, prompt_id=existing_prompt.id, visibility=existing_prompt.visibility)
            elif visibility.lower() == "private":
                existing_prompt = get_for_update(
                    db, DbPrompt, where=and_(DbPrompt.name == computed_name, DbPrompt.visibility == "private", DbPrompt.owner_email == owner_email, DbPrompt.id != prompt.id)
                )
                if existing_prompt:
                    raise PromptNameConflictError(computed_name, enabled=existing_prompt.enabled, prompt_id=existing_prompt.id, visibility=existing_prompt.visibility)

        if prompt_update.name is not None:
            if prompt.gateway_id:
                # Gateway-sourced prompts keep their original name; only the custom name changes.
                prompt.custom_name = prompt_update.custom_name or prompt_update.name
            else:
                prompt.original_name = prompt_update.name
                if prompt_update.custom_name is None:
                    prompt.custom_name = prompt_update.name
        if prompt_update.custom_name is not None:
            prompt.custom_name = prompt_update.custom_name
        if prompt_update.display_name is not None:
            prompt.display_name = prompt_update.display_name
        if prompt_update.description is not None:
            prompt.description = prompt_update.description
        if prompt_update.title is not None:
            prompt.title = prompt_update.title
        if prompt_update.template is not None:
            # Validate template size before updating
            content_security = get_content_security_service()
            content_security.validate_prompt_size(
                template=prompt_update.template,
                name=prompt.name,
                user_email=modified_by or user_email,
                ip_address=modified_from_ip,
            )
            prompt.template = prompt_update.template
            self._validate_template(prompt.template)
            # Clear template cache to reduce memory growth
            _compile_jinja_template.cache_clear()
        if prompt_update.arguments is not None:
            # Rebuild the JSON-schema-like argument description; "required" is
            # derived from the (possibly just updated) template.
            required_args = self._get_required_arguments(prompt.template)
            argument_schema = {
                "type": "object",
                "properties": {},
                "required": list(required_args),
            }
            for arg in prompt_update.arguments:
                schema = {"type": "string"}
                if arg.description is not None:
                    schema["description"] = arg.description
                argument_schema["properties"][arg.name] = schema
            prompt.argument_schema = argument_schema

        if prompt_update.visibility is not None:
            # Validate visibility transitions
            if prompt_update.visibility == "team":
                target_team_id = prompt_update.team_id if prompt_update.team_id is not None else prompt.team_id
                _validate_prompt_team_assignment(db, user_email, target_team_id)
            prompt.visibility = prompt_update.visibility

        # Update tags if provided
        if prompt_update.tags is not None:
            prompt.tags = prompt_update.tags

        # Update team assignment if provided, validating ownership
        if prompt_update.team_id is not None:
            if prompt_update.team_id != prompt.team_id:
                _validate_prompt_team_assignment(db, user_email, prompt_update.team_id)
            prompt.team_id = prompt_update.team_id

        # Update metadata fields
        prompt.updated_at = datetime.now(timezone.utc)
        if modified_by:
            prompt.modified_by = modified_by
        if modified_from_ip:
            prompt.modified_from_ip = modified_from_ip
        if modified_via:
            prompt.modified_via = modified_via
        if modified_user_agent:
            prompt.modified_user_agent = modified_user_agent
        if hasattr(prompt, "version") and prompt.version is not None:
            prompt.version = prompt.version + 1
        else:
            prompt.version = 1

        db.commit()
        db.refresh(prompt)

        await self._notify_prompt_updated(prompt)

        # Structured logging: Audit trail for prompt update
        audit_trail.log_action(
            user_id=user_email or modified_by or "system",
            action="update_prompt",
            resource_type="prompt",
            resource_id=str(prompt.id),
            resource_name=prompt.name,
            user_email=user_email,
            team_id=prompt.team_id,
            client_ip=modified_from_ip,
            user_agent=modified_user_agent,
            new_values={"name": prompt.name, "version": prompt.version},
            context={"modified_via": modified_via},
            db=db,
        )

        structured_logger.log(
            level="INFO",
            message="Prompt updated successfully",
            event_type="prompt_updated",
            component="prompt_service",
            user_id=modified_by,
            user_email=user_email,
            team_id=prompt.team_id,
            resource_type="prompt",
            resource_id=str(prompt.id),
            custom_fields={"prompt_name": prompt.name, "version": prompt.version},
        )

        prompt.team = self._get_team_name(db, prompt.team_id)

        # Invalidate cache after successful update
        cache = _get_registry_cache()
        await cache.invalidate_prompts()
        # Also invalidate tags cache since prompt tags may have changed
        # First-Party
        from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

        await admin_stats_cache.invalidate_tags()

        return self.convert_prompt_to_read(prompt)

    except PermissionError as pe:
        db.rollback()

        structured_logger.log(
            level="WARNING",
            message="Prompt update failed due to permission error",
            event_type="prompt_update_permission_denied",
            component="prompt_service",
            user_email=user_email,
            resource_type="prompt",
            resource_id=str(prompt_id),
            error=pe,
        )
        raise
    except IntegrityError as ie:
        db.rollback()
        logger.error(f"IntegrityErrors in group: {ie}")

        structured_logger.log(
            level="ERROR",
            message="Prompt update failed due to database integrity error",
            event_type="prompt_update_failed",
            component="prompt_service",
            user_email=user_email,
            resource_type="prompt",
            resource_id=str(prompt_id),
            error=ie,
        )
        raise ie
    except PromptNotFoundError as e:
        db.rollback()
        logger.error(f"Prompt not found: {e}")

        structured_logger.log(
            level="ERROR",
            message="Prompt update failed - prompt not found",
            event_type="prompt_not_found",
            component="prompt_service",
            user_email=user_email,
            resource_type="prompt",
            resource_id=str(prompt_id),
            error=e,
        )
        raise e
    except PromptNameConflictError as pnce:
        db.rollback()
        logger.error(f"Prompt name conflict: {pnce}")

        structured_logger.log(
            level="WARNING",
            message="Prompt update failed due to name conflict",
            event_type="prompt_name_conflict",
            component="prompt_service",
            user_email=user_email,
            resource_type="prompt",
            resource_id=str(prompt_id),
            error=pnce,
        )
        raise pnce
    except ContentSizeError as cse:
        db.rollback()
        logger.error(f"Prompt template size limit exceeded: {cse.actual_size} bytes (max: {cse.max_size} bytes)")
        structured_logger.log(
            level="ERROR",
            message="Prompt update failed - Template size exceeded",
            event_type="prompt_update_failed",
            component="prompt_service",
            user_email=user_email,
            resource_type="prompt",
            resource_id=str(prompt_id),
            error=cse,
        )
        raise cse
    except Exception as e:
        db.rollback()

        structured_logger.log(
            level="ERROR",
            message="Prompt update failed",
            event_type="prompt_update_failed",
            component="prompt_service",
            user_email=user_email,
            resource_type="prompt",
            resource_id=str(prompt_id),
            error=e,
        )
        # Chain so the root cause survives the wrap.
        raise PromptError(f"Failed to update prompt: {str(e)}") from e
async def set_prompt_state(self, db: Session, prompt_id: int, activate: bool, user_email: Optional[str] = None, skip_cache_invalidation: bool = False) -> PromptRead:
    """
    Set the activation status of a prompt.

    The prompt row is fetched with ``get_for_update(..., nowait=True)``. Every
    error path rolls the session back so that a denied or missing prompt
    does not leave the transaction (and any acquired row lock) open.

    Args:
        db: Database session
        prompt_id: Prompt ID
        activate: True to activate, False to deactivate
        user_email: Optional[str] The email of the user to check if the user has permission to modify.
        skip_cache_invalidation: If True, skip cache invalidation (used for batch operations).

    Returns:
        The updated PromptRead object

    Raises:
        PromptNotFoundError: If the prompt is not found.
        PromptLockConflictError: If the prompt is locked by another transaction.
        PromptError: For other errors.
        PermissionError: If user doesn't own the prompt.

    Examples:
        >>> import logging
        >>> logging.disable(logging.CRITICAL)
        >>> from mcpgateway.services.prompt_service import PromptService
        >>> from unittest.mock import AsyncMock, MagicMock
        >>> service = PromptService()
        >>> db = MagicMock()
        >>> prompt = MagicMock()
        >>> db.get.return_value = prompt
        >>> db.commit = MagicMock()
        >>> db.refresh = MagicMock()
        >>> service._notify_prompt_activated = AsyncMock()
        >>> service._notify_prompt_deactivated = AsyncMock()
        >>> service.convert_prompt_to_read = MagicMock(return_value={})
        >>> import asyncio
        >>> try:
        ...     result = asyncio.run(service.set_prompt_state(db, 1, True))
        ... except Exception:
        ...     pass
        >>> logging.disable(logging.NOTSET)
    """
    try:
        # Use nowait=True to fail fast if row is locked, preventing lock contention under high load
        try:
            prompt = get_for_update(db, DbPrompt, prompt_id, nowait=True)
        except OperationalError as lock_err:
            # Row is locked by another transaction - fail fast with 409
            db.rollback()
            raise PromptLockConflictError(f"Prompt {prompt_id} is currently being modified by another request") from lock_err
        if not prompt:
            raise PromptNotFoundError(f"Prompt not found: {prompt_id}")

        if user_email:
            # First-Party
            from mcpgateway.services.permission_service import PermissionService  # pylint: disable=import-outside-toplevel

            permission_service = PermissionService(db)
            if not await permission_service.check_resource_ownership(user_email, prompt):
                raise PermissionError("Only the owner can activate the Prompt" if activate else "Only the owner can deactivate the Prompt")

        # Only persist, notify and log when the state actually changes.
        if prompt.enabled != activate:
            prompt.enabled = activate
            prompt.updated_at = datetime.now(timezone.utc)
            db.commit()
            db.refresh(prompt)

            # Invalidate cache after status change (skip for batch operations)
            if not skip_cache_invalidation:
                cache = _get_registry_cache()
                await cache.invalidate_prompts()

            if activate:
                await self._notify_prompt_activated(prompt)
            else:
                await self._notify_prompt_deactivated(prompt)
            logger.info(f"Prompt {prompt.name} {'activated' if activate else 'deactivated'}")

            # Structured logging: Audit trail for prompt state change
            audit_trail.log_action(
                user_id=user_email or "system",
                action="set_prompt_state",
                resource_type="prompt",
                resource_id=str(prompt.id),
                resource_name=prompt.name,
                user_email=user_email,
                team_id=prompt.team_id,
                new_values={"enabled": prompt.enabled},
                context={"action": "activate" if activate else "deactivate"},
                db=db,
            )

            structured_logger.log(
                level="INFO",
                message=f"Prompt {'activated' if activate else 'deactivated'} successfully",
                event_type="prompt_state_changed",
                component="prompt_service",
                user_email=user_email,
                team_id=prompt.team_id,
                resource_type="prompt",
                resource_id=str(prompt.id),
                custom_fields={"prompt_name": prompt.name, "enabled": prompt.enabled},
            )

        prompt.team = self._get_team_name(db, prompt.team_id)
        return self.convert_prompt_to_read(prompt)
    except PermissionError as e:
        # Release the FOR UPDATE lock / open transaction taken above.
        db.rollback()
        structured_logger.log(
            level="WARNING",
            message="Prompt state change failed due to permission error",
            event_type="prompt_state_change_permission_denied",
            component="prompt_service",
            user_email=user_email,
            resource_type="prompt",
            resource_id=str(prompt_id),
            error=e,
        )
        raise e
    except PromptLockConflictError:
        # Re-raise lock conflicts without wrapping - allows 409 response
        # (the session was already rolled back where this was raised).
        raise
    except PromptNotFoundError:
        # Close the read transaction, then re-raise without wrapping - allows 404 response
        db.rollback()
        raise
    except Exception as e:
        db.rollback()

        structured_logger.log(
            level="ERROR",
            message="Prompt state change failed",
            event_type="prompt_state_change_failed",
            component="prompt_service",
            user_email=user_email,
            resource_type="prompt",
            resource_id=str(prompt_id),
            error=e,
        )
        raise PromptError(f"Failed to set prompt state: {str(e)}") from e
2533 # Get prompt details for admin ui
2535 async def get_prompt_details(self, db: Session, prompt_id: Union[int, str], include_inactive: bool = False) -> Dict[str, Any]: # pylint: disable=unused-argument
2536 """
2537 Get prompt details by ID.
2539 Args:
2540 db: Database session
2541 prompt_id: ID of prompt
2542 include_inactive: Whether to include inactive prompts
2544 Returns:
2545 Dictionary of prompt details
2547 Raises:
2548 PromptNotFoundError: If the prompt is not found
2550 Examples:
2551 >>> from mcpgateway.services.prompt_service import PromptService
2552 >>> from unittest.mock import MagicMock
2553 >>> service = PromptService()
2554 >>> db = MagicMock()
2555 >>> prompt_dict = {'id': '1', 'name': 'test', 'description': 'desc', 'template': 'tpl', 'arguments': [], 'createdAt': '2023-01-01T00:00:00', 'updatedAt': '2023-01-01T00:00:00', 'isActive': True, 'metrics': {}}
2556 >>> service.convert_prompt_to_read = MagicMock(return_value=prompt_dict)
2557 >>> db.execute.return_value.scalar_one_or_none.return_value = MagicMock()
2558 >>> import asyncio
2559 >>> result = asyncio.run(service.get_prompt_details(db, 'prompt_name'))
2560 >>> result == prompt_dict
2561 True
2562 """
2563 prompt = db.get(DbPrompt, prompt_id)
2564 if not prompt:
2565 raise PromptNotFoundError(f"Prompt not found: {prompt_id}")
2566 # Return the fully converted prompt including metrics
2567 prompt.team = self._get_team_name(db, prompt.team_id)
2568 prompt_data = self.convert_prompt_to_read(prompt)
2570 audit_trail.log_action(
2571 user_id="system",
2572 action="view_prompt_details",
2573 resource_type="prompt",
2574 resource_id=str(prompt.id),
2575 resource_name=prompt.name,
2576 team_id=prompt.team_id,
2577 context={"include_inactive": include_inactive},
2578 db=db,
2579 )
2581 structured_logger.log(
2582 level="INFO",
2583 message="Prompt details retrieved",
2584 event_type="prompt_details_viewed",
2585 component="prompt_service",
2586 resource_type="prompt",
2587 resource_id=str(prompt.id),
2588 team_id=prompt.team_id,
2589 custom_fields={
2590 "prompt_name": prompt.name,
2591 "include_inactive": include_inactive,
2592 },
2593 )
2595 return prompt_data
2597 async def delete_prompt(self, db: Session, prompt_id: Union[int, str], user_email: Optional[str] = None, purge_metrics: bool = False) -> None:
2598 """
2599 Delete a prompt template by its ID.
2601 Args:
2602 db (Session): Database session.
2603 prompt_id (str): ID of the prompt to delete.
2604 user_email (Optional[str]): Email of user performing delete (for ownership check).
2605 purge_metrics (bool): If True, delete raw + rollup metrics for this prompt.
2607 Raises:
2608 PromptNotFoundError: If the prompt is not found.
2609 PermissionError: If user doesn't own the prompt.
2610 PromptError: For other deletion errors.
2611 Exception: For unexpected errors.
2613 Examples:
2614 >>> import logging
2615 >>> logging.disable(logging.CRITICAL)
2616 >>> from mcpgateway.services.prompt_service import PromptService
2617 >>> from unittest.mock import AsyncMock, MagicMock
2618 >>> service = PromptService()
2619 >>> db = MagicMock()
2620 >>> prompt = MagicMock()
2621 >>> db.get.return_value = prompt
2622 >>> db.delete = MagicMock()
2623 >>> db.commit = MagicMock()
2624 >>> service._notify_prompt_deleted = AsyncMock()
2625 >>> import asyncio
2626 >>> try:
2627 ... asyncio.run(service.delete_prompt(db, '123'))
2628 ... except Exception:
2629 ... pass
2630 >>> logging.disable(logging.NOTSET)
2631 """
2632 try:
2633 prompt = db.get(DbPrompt, prompt_id)
2634 if not prompt:
2635 raise PromptNotFoundError(f"Prompt not found: {prompt_id}")
2637 # Check ownership if user_email provided
2638 if user_email:
2639 # First-Party
2640 from mcpgateway.services.permission_service import PermissionService # pylint: disable=import-outside-toplevel
2642 permission_service = PermissionService(db)
2643 if not await permission_service.check_resource_ownership(user_email, prompt):
2644 raise PermissionError("Only the owner can delete this prompt")
2646 prompt_info = {"id": prompt.id, "name": prompt.name}
2647 prompt_name = prompt.name
2648 prompt_team_id = prompt.team_id
2650 if purge_metrics:
2651 with pause_rollup_during_purge(reason=f"purge_prompt:{prompt_id}"):
2652 delete_metrics_in_batches(db, PromptMetric, PromptMetric.prompt_id, prompt_id)
2653 delete_metrics_in_batches(db, PromptMetricsHourly, PromptMetricsHourly.prompt_id, prompt_id)
2655 db.delete(prompt)
2656 db.commit()
2657 await self._notify_prompt_deleted(prompt_info)
2658 logger.info(f"Deleted prompt: {prompt_info['name']}")
2660 # Structured logging: Audit trail for prompt deletion
2661 audit_trail.log_action(
2662 user_id=user_email or "system",
2663 action="delete_prompt",
2664 resource_type="prompt",
2665 resource_id=str(prompt_info["id"]),
2666 resource_name=prompt_name,
2667 user_email=user_email,
2668 team_id=prompt_team_id,
2669 old_values={"name": prompt_name},
2670 db=db,
2671 )
2673 # Structured logging: Log successful prompt deletion
2674 structured_logger.log(
2675 level="INFO",
2676 message="Prompt deleted successfully",
2677 event_type="prompt_deleted",
2678 component="prompt_service",
2679 user_email=user_email,
2680 team_id=prompt_team_id,
2681 resource_type="prompt",
2682 resource_id=str(prompt_info["id"]),
2683 custom_fields={
2684 "prompt_name": prompt_name,
2685 "purge_metrics": purge_metrics,
2686 },
2687 )
2689 # Invalidate cache after successful deletion
2690 cache = _get_registry_cache()
2691 await cache.invalidate_prompts()
2692 # Also invalidate tags cache since prompt tags may have changed
2693 # First-Party
2694 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel
2696 await admin_stats_cache.invalidate_tags()
2697 except PermissionError as pe:
2698 db.rollback()
2700 # Structured logging: Log permission error
2701 structured_logger.log(
2702 level="WARNING",
2703 message="Prompt deletion failed due to permission error",
2704 event_type="prompt_delete_permission_denied",
2705 component="prompt_service",
2706 user_email=user_email,
2707 resource_type="prompt",
2708 resource_id=str(prompt_id),
2709 error=pe,
2710 )
2711 raise
2712 except Exception as e:
2713 db.rollback()
2714 if isinstance(e, PromptNotFoundError):
2715 # Structured logging: Log not found error
2716 structured_logger.log(
2717 level="ERROR",
2718 message="Prompt deletion failed - prompt not found",
2719 event_type="prompt_not_found",
2720 component="prompt_service",
2721 user_email=user_email,
2722 resource_type="prompt",
2723 resource_id=str(prompt_id),
2724 error=e,
2725 )
2726 raise e
2728 # Structured logging: Log generic prompt deletion failure
2729 structured_logger.log(
2730 level="ERROR",
2731 message="Prompt deletion failed",
2732 event_type="prompt_deletion_failed",
2733 component="prompt_service",
2734 user_email=user_email,
2735 resource_type="prompt",
2736 resource_id=str(prompt_id),
2737 error=e,
2738 )
2739 raise PromptError(f"Failed to delete prompt: {str(e)}")
2741 async def subscribe_events(self) -> AsyncGenerator[Dict[str, Any], None]:
2742 """Subscribe to Prompt events via the EventService.
2744 Yields:
2745 Prompt event messages.
2746 """
2747 async for event in self._event_service.subscribe_events():
2748 yield event
2750 def _validate_template(self, template: str) -> None:
2751 """Validate template syntax.
2753 Args:
2754 template: Template to validate
2756 Raises:
2757 PromptValidationError: If template is invalid
2759 Examples:
2760 >>> from mcpgateway.services.prompt_service import PromptService
2761 >>> service = PromptService()
2762 >>> service._validate_template("Hello {{ name }}") # Valid template
2763 >>> try:
2764 ... service._validate_template("Hello {{ invalid") # Invalid template
2765 ... except Exception as e:
2766 ... "Invalid template syntax" in str(e)
2767 True
2768 """
2769 try:
2770 self._jinja_env.parse(template)
2771 except Exception as e:
2772 raise PromptValidationError(f"Invalid template syntax: {str(e)}")
2774 def _get_required_arguments(self, template: str) -> Set[str]:
2775 """Extract required arguments from template.
2777 Args:
2778 template: Template to analyze
2780 Returns:
2781 Set of required argument names
2783 Examples:
2784 >>> from mcpgateway.services.prompt_service import PromptService
2785 >>> service = PromptService()
2786 >>> args = service._get_required_arguments("Hello {{ name }} from {{ place }}")
2787 >>> sorted(args)
2788 ['name', 'place']
2789 >>> service._get_required_arguments("No variables") == set()
2790 True
2791 """
2792 ast = self._jinja_env.parse(template)
2793 variables = meta.find_undeclared_variables(ast)
2794 formatter = Formatter()
2795 format_vars = {field_name for _, field_name, _, _ in formatter.parse(template) if field_name is not None}
2796 return variables.union(format_vars)
2798 def _render_template(self, template: str, arguments: Dict[str, str]) -> str:
2799 """Render template with arguments using cached compiled templates.
2801 Args:
2802 template: Template to render
2803 arguments: Arguments for rendering
2805 Returns:
2806 Rendered template text
2808 Raises:
2809 PromptError: If rendering fails
2811 Examples:
2812 >>> from mcpgateway.services.prompt_service import PromptService
2813 >>> service = PromptService()
2814 >>> result = service._render_template("Hello {{ name }}", {"name": "World"})
2815 >>> result
2816 'Hello World'
2817 >>> service._render_template("No variables", {})
2818 'No variables'
2819 """
2820 try:
2821 jinja_template = _compile_jinja_template(template)
2822 return jinja_template.render(**arguments)
2823 except Exception:
2824 try:
2825 return template.format(**arguments)
2826 except Exception as e:
2827 raise PromptError(f"Failed to render template: {str(e)}")
2829 def _parse_messages(self, text: str) -> List[Message]:
2830 """Parse rendered text into messages.
2832 Args:
2833 text: Text to parse
2835 Returns:
2836 List of parsed messages
2838 Examples:
2839 >>> from mcpgateway.services.prompt_service import PromptService
2840 >>> service = PromptService()
2841 >>> messages = service._parse_messages("Simple text")
2842 >>> len(messages)
2843 1
2844 >>> messages[0].role.value
2845 'user'
2846 >>> messages = service._parse_messages("# User:\\nHello\\n# Assistant:\\nHi there")
2847 >>> len(messages)
2848 2
2849 """
2850 messages = []
2851 current_role = Role.USER
2852 current_text = []
2853 for line in text.split("\n"):
2854 if line.startswith("# Assistant:"):
2855 if current_text:
2856 messages.append(
2857 Message(
2858 role=current_role,
2859 content=TextContent(type="text", text="\n".join(current_text).strip()),
2860 )
2861 )
2862 current_role = Role.ASSISTANT
2863 current_text = []
2864 elif line.startswith("# User:"):
2865 if current_text:
2866 messages.append(
2867 Message(
2868 role=current_role,
2869 content=TextContent(type="text", text="\n".join(current_text).strip()),
2870 )
2871 )
2872 current_role = Role.USER
2873 current_text = []
2874 else:
2875 current_text.append(line)
2876 if current_text:
2877 messages.append(
2878 Message(
2879 role=current_role,
2880 content=TextContent(type="text", text="\n".join(current_text).strip()),
2881 )
2882 )
2883 return messages
2885 async def _notify_prompt_added(self, prompt: DbPrompt) -> None:
2886 """
2887 Notify subscribers of prompt addition.
2889 Args:
2890 prompt: Prompt to add
2891 """
2892 event = {
2893 "type": "prompt_added",
2894 "data": {
2895 "id": prompt.id,
2896 "name": prompt.name,
2897 "description": prompt.description,
2898 "enabled": prompt.enabled,
2899 },
2900 "timestamp": datetime.now(timezone.utc).isoformat(),
2901 }
2902 await self._publish_event(event)
2904 async def _notify_prompt_updated(self, prompt: DbPrompt) -> None:
2905 """
2906 Notify subscribers of prompt update.
2908 Args:
2909 prompt: Prompt to update
2910 """
2911 event = {
2912 "type": "prompt_updated",
2913 "data": {
2914 "id": prompt.id,
2915 "name": prompt.name,
2916 "description": prompt.description,
2917 "enabled": prompt.enabled,
2918 },
2919 "timestamp": datetime.now(timezone.utc).isoformat(),
2920 }
2921 await self._publish_event(event)
2923 async def _notify_prompt_activated(self, prompt: DbPrompt) -> None:
2924 """
2925 Notify subscribers of prompt activation.
2927 Args:
2928 prompt: Prompt to activate
2929 """
2930 event = {
2931 "type": "prompt_activated",
2932 "data": {"id": prompt.id, "name": prompt.name, "enabled": True},
2933 "timestamp": datetime.now(timezone.utc).isoformat(),
2934 }
2935 await self._publish_event(event)
2937 async def _notify_prompt_deactivated(self, prompt: DbPrompt) -> None:
2938 """
2939 Notify subscribers of prompt deactivation.
2941 Args:
2942 prompt: Prompt to deactivate
2943 """
2944 event = {
2945 "type": "prompt_deactivated",
2946 "data": {"id": prompt.id, "name": prompt.name, "enabled": False},
2947 "timestamp": datetime.now(timezone.utc).isoformat(),
2948 }
2949 await self._publish_event(event)
2951 async def _notify_prompt_deleted(self, prompt_info: Dict[str, Any]) -> None:
2952 """
2953 Notify subscribers of prompt deletion.
2955 Args:
2956 prompt_info: Dict on prompt to notify as deleted
2957 """
2958 event = {
2959 "type": "prompt_deleted",
2960 "data": prompt_info,
2961 "timestamp": datetime.now(timezone.utc).isoformat(),
2962 }
2963 await self._publish_event(event)
2965 async def _notify_prompt_removed(self, prompt: DbPrompt) -> None:
2966 """
2967 Notify subscribers of prompt removal (deactivation).
2969 Args:
2970 prompt: Prompt to remove
2971 """
2972 event = {
2973 "type": "prompt_removed",
2974 "data": {"id": prompt.id, "name": prompt.name, "enabled": False},
2975 "timestamp": datetime.now(timezone.utc).isoformat(),
2976 }
2977 await self._publish_event(event)
2979 async def _publish_event(self, event: Dict[str, Any]) -> None:
2980 """
2981 Publish event to all subscribers via the EventService.
2983 Args:
2984 event: Event to publish
2985 """
2986 await self._event_service.publish_event(event)
2988 # --- Metrics ---
2989 async def aggregate_metrics(self, db: Session) -> PromptMetrics:
2990 """
2991 Aggregate metrics for all prompt invocations across all prompts.
2993 Combines recent raw metrics (within retention period) with historical
2994 hourly rollups for complete historical coverage. Uses in-memory caching
2995 (10s TTL) to reduce database load under high request rates.
2997 Args:
2998 db: Database session
3000 Returns:
3001 PromptMetrics: Aggregated prompt metrics from raw + hourly rollups.
3003 Examples:
3004 >>> from mcpgateway.services.prompt_service import PromptService
3005 >>> service = PromptService()
3006 >>> # Method exists and is callable
3007 >>> callable(service.aggregate_metrics)
3008 True
3009 """
3010 # Check cache first (if enabled)
3011 # First-Party
3012 from mcpgateway.cache.metrics_cache import is_cache_enabled, metrics_cache # pylint: disable=import-outside-toplevel
3014 if is_cache_enabled():
3015 cached = metrics_cache.get("prompts")
3016 if cached is not None:
3017 return PromptMetrics(**cached)
3019 # Use combined raw + rollup query for full historical coverage
3020 # First-Party
3021 from mcpgateway.services.metrics_query_service import aggregate_metrics_combined # pylint: disable=import-outside-toplevel
3023 result = aggregate_metrics_combined(db, "prompt")
3025 metrics = PromptMetrics(
3026 total_executions=result.total_executions,
3027 successful_executions=result.successful_executions,
3028 failed_executions=result.failed_executions,
3029 failure_rate=result.failure_rate,
3030 min_response_time=result.min_response_time,
3031 max_response_time=result.max_response_time,
3032 avg_response_time=result.avg_response_time,
3033 last_execution_time=result.last_execution_time,
3034 )
3036 # Cache the result as dict for serialization compatibility (if enabled)
3037 if is_cache_enabled():
3038 metrics_cache.set("prompts", metrics.model_dump())
3040 return metrics
3042 async def reset_metrics(self, db: Session) -> None:
3043 """
3044 Reset all prompt metrics by deleting raw and hourly rollup records.
3046 Args:
3047 db: Database session
3049 Examples:
3050 >>> from mcpgateway.services.prompt_service import PromptService
3051 >>> from unittest.mock import MagicMock
3052 >>> service = PromptService()
3053 >>> db = MagicMock()
3054 >>> db.execute = MagicMock()
3055 >>> db.commit = MagicMock()
3056 >>> import asyncio
3057 >>> asyncio.run(service.reset_metrics(db))
3058 """
3060 db.execute(delete(PromptMetric))
3061 db.execute(delete(PromptMetricsHourly))
3062 db.commit()
3064 # Invalidate metrics cache
3065 # First-Party
3066 from mcpgateway.cache.metrics_cache import metrics_cache # pylint: disable=import-outside-toplevel
3068 metrics_cache.invalidate("prompts")
3069 metrics_cache.invalidate_prefix("top_prompts:")
# Lazy singleton: the PromptService instance is built on the first access to
# ``prompt_service`` rather than at import time. Modules that only import the
# exception classes therefore never construct the service.
_prompt_service_instance: Optional["PromptService"] = None  # pylint: disable=invalid-name


def __getattr__(name: str) -> "PromptService":
    """Resolve module attributes that are not defined statically (PEP 562 hook).

    Only ``prompt_service`` is handled. It returns the shared PromptService,
    creating it on the first request.

    Args:
        name: The attribute name being accessed.

    Returns:
        The prompt_service singleton instance if name is "prompt_service".

    Raises:
        AttributeError: If the attribute name is not "prompt_service".
    """
    global _prompt_service_instance  # pylint: disable=global-statement
    if name != "prompt_service":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    if _prompt_service_instance is None:
        _prompt_service_instance = PromptService()
    return _prompt_service_instance