Coverage for mcpgateway / services / a2a_service.py: 100%
604 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2# pylint: disable=invalid-name, import-outside-toplevel, unused-import, no-name-in-module
3"""Location: ./mcpgateway/services/a2a_service.py
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8A2A Agent Service
10This module implements A2A (Agent-to-Agent) agent management for ContextForge.
11It handles agent registration, listing, retrieval, updates, activation toggling, deletion,
12and interactions with A2A-compatible agents.
13"""
15# Standard
16import binascii
17from datetime import datetime, timezone
18from typing import Any, AsyncGenerator, Dict, List, Optional, Union
20# Third-Party
21from pydantic import ValidationError
22from sqlalchemy import and_, delete, desc, or_, select
23from sqlalchemy.exc import IntegrityError
24from sqlalchemy.orm import Session
26# First-Party
27from mcpgateway.cache.a2a_stats_cache import a2a_stats_cache
28from mcpgateway.db import A2AAgent as DbA2AAgent
29from mcpgateway.db import A2AAgentMetric, A2AAgentMetricsHourly, EmailTeam, fresh_db_session, get_for_update
30from mcpgateway.schemas import A2AAgentAggregateMetrics, A2AAgentCreate, A2AAgentMetrics, A2AAgentRead, A2AAgentUpdate
31from mcpgateway.services.base_service import BaseService
32from mcpgateway.services.encryption_service import protect_oauth_config_for_storage
33from mcpgateway.services.logging_service import LoggingService
34from mcpgateway.services.metrics_cleanup_service import delete_metrics_in_batches, pause_rollup_during_purge
35from mcpgateway.services.structured_logger import get_structured_logger
36from mcpgateway.services.team_management_service import TeamManagementService
37from mcpgateway.utils.correlation_id import get_correlation_id
38from mcpgateway.utils.create_slug import slugify
39from mcpgateway.utils.pagination import unified_paginate
40from mcpgateway.utils.services_auth import decode_auth, encode_auth
41from mcpgateway.utils.sqlalchemy_modifier import json_contains_tag_expr
# Cache import (lazy to avoid circular dependencies)
_REGISTRY_CACHE = None


def _get_registry_cache():
    """Return the process-wide registry cache, importing it on first use.

    The import is deferred to call time so that importing this module does not
    pull in ``mcpgateway.cache.registry_cache`` eagerly (circular import guard).
    The resolved object is memoised in the module-level ``_REGISTRY_CACHE``.

    Returns:
        RegistryCache instance.
    """
    global _REGISTRY_CACHE  # pylint: disable=global-statement
    if _REGISTRY_CACHE is not None:
        return _REGISTRY_CACHE

    # First-Party
    from mcpgateway.cache.registry_cache import registry_cache  # pylint: disable=import-outside-toplevel

    _REGISTRY_CACHE = registry_cache
    return _REGISTRY_CACHE
# Initialize logging service first
logging_service = LoggingService()
# Module logger used by the service methods below
logger = logging_service.get_logger(__name__)

# Initialize structured logger for A2A lifecycle tracking
# (register_agent emits its "registered" info event and "created without tool" warning through it)
structured_logger = get_structured_logger("a2a_service")
class A2AAgentError(Exception):
    """Common base class for the A2A agent errors defined in this module.

    The more specific errors (not-found, name-conflict) derive from it, so a
    single ``except A2AAgentError`` handles all of them.

    Examples:
        >>> err = A2AAgentError("Agent operation failed")
        >>> str(err)
        'Agent operation failed'
        >>> isinstance(A2AAgentError("Connection error"), Exception)
        True
    """
class A2AAgentNotFoundError(A2AAgentError):
    """Signal that no A2A agent is available for the requested identifier.

    Lookups also raise it when the caller is not allowed to see the agent, so
    callers cannot distinguish "missing" from "hidden".

    Examples:
        >>> err = A2AAgentNotFoundError("Agent 'test-agent' not found")
        >>> str(err)
        "Agent 'test-agent' not found"
        >>> issubclass(A2AAgentNotFoundError, A2AAgentError)
        True
    """
class A2AAgentNameConflictError(A2AAgentError):
    """Raised when an A2A agent name conflicts with an existing one."""

    def __init__(self, name: str, is_active: bool = True, agent_id: Optional[str] = None, visibility: Optional[str] = "public"):
        """Initialize an A2AAgentNameConflictError exception.

        Creates an exception that indicates an agent name conflict, with additional
        context about whether the conflicting agent is active and its ID if known.

        Args:
            name: The agent name that caused the conflict.
            is_active: Whether the conflicting agent is currently active.
            agent_id: The ID of the conflicting agent, if known.
            visibility: The visibility level of the conflicting agent (private, team, public).
                When ``None`` or empty, the message omits the visibility prefix.

        Examples:
            >>> error = A2AAgentNameConflictError("test-agent")
            >>> error.name
            'test-agent'
            >>> error.is_active
            True
            >>> error.agent_id is None
            True
            >>> "test-agent" in str(error)
            True
            >>> str(error)
            'Public A2A Agent already exists with name: test-agent'
            >>>
            >>> # Test inactive agent conflict
            >>> error = A2AAgentNameConflictError("inactive-agent", is_active=False, agent_id="agent-123")
            >>> error.is_active
            False
            >>> error.agent_id
            'agent-123'
            >>> "inactive" in str(error)
            True
            >>> "agent-123" in str(error)
            True
            >>>
            >>> # A missing visibility must not crash message construction
            >>> error = A2AAgentNameConflictError("team-agent", visibility=None)
            >>> error.visibility is None
            True
            >>> str(error)
            'A2A Agent already exists with name: team-agent'
        """
        self.name = name
        self.is_active = is_active
        self.agent_id = agent_id
        self.visibility = visibility
        # visibility is Optional (e.g. taken straight from a DB row); calling
        # .capitalize() on None would raise AttributeError while building the error.
        prefix = f"{visibility.capitalize()} " if visibility else ""
        message = f"{prefix}A2A Agent already exists with name: {name}"
        if not is_active:
            message += f" (currently inactive, ID: {agent_id})"
        super().__init__(message)
class A2AAgentService(BaseService):
    """Service for managing A2A agents in the gateway.

    Provides methods to create, list, retrieve, update, set state, and delete agent records.
    Also supports interactions with A2A-compatible agents.
    """

    # ORM model handed to BaseService; presumably consumed by its visibility /
    # access-control helpers (e.g. ``_apply_access_control`` used by list_agents) — confirm in base_service.
    _visibility_model_cls = DbA2AAgent
159 def __init__(self) -> None:
160 """Initialize a new A2AAgentService instance."""
161 self._initialized = False
162 self._event_streams: List[AsyncGenerator[str, None]] = []
164 async def initialize(self) -> None:
165 """Initialize the A2A agent service."""
166 if not self._initialized:
167 logger.info("Initializing A2A Agent Service")
168 self._initialized = True
170 async def shutdown(self) -> None:
171 """Shutdown the A2A agent service and cleanup resources."""
172 if self._initialized:
173 logger.info("Shutting down A2A Agent Service")
174 self._initialized = False
176 def _get_team_name(self, db: Session, team_id: Optional[str]) -> Optional[str]:
177 """Retrieve the team name given a team ID.
179 Args:
180 db (Session): Database session for querying teams.
181 team_id (Optional[str]): The ID of the team.
183 Returns:
184 Optional[str]: The name of the team if found, otherwise None.
185 """
186 if not team_id:
187 return None
189 team = db.query(EmailTeam).filter(EmailTeam.id == team_id, EmailTeam.is_active.is_(True)).first()
190 db.commit() # Release transaction to avoid idle-in-transaction
191 return team.name if team else None
193 def _batch_get_team_names(self, db: Session, team_ids: List[str]) -> Dict[str, str]:
194 """Batch retrieve team names for multiple team IDs.
196 This method fetches team names in a single query to avoid N+1 issues
197 when converting multiple agents to schemas in list operations.
199 Args:
200 db (Session): Database session for querying teams.
201 team_ids (List[str]): List of team IDs to look up.
203 Returns:
204 Dict[str, str]: Mapping of team_id -> team_name for active teams.
205 """
206 if not team_ids:
207 return {}
209 # Single query for all teams
210 teams = db.query(EmailTeam.id, EmailTeam.name).filter(EmailTeam.id.in_(team_ids), EmailTeam.is_active.is_(True)).all()
212 return {team.id: team.name for team in teams}
214 def _check_agent_access(
215 self,
216 agent: DbA2AAgent,
217 user_email: Optional[str],
218 token_teams: Optional[List[str]],
219 ) -> bool:
220 """Check if user has access to agent based on visibility rules.
222 Access rules (matching tools/resources/prompts):
223 - public visibility: Always allowed
224 - token_teams is None AND user_email is None: Admin bypass (unrestricted access)
225 - No user context (but not admin): Deny access to non-public agents
226 - team visibility: Allowed if agent.team_id in token_teams
227 - private visibility: Allowed if owner (requires user_email and non-empty token_teams)
229 Args:
230 agent: The agent to check access for
231 user_email: User's email for owner matching
232 token_teams: Teams from JWT. None = admin bypass, [] = public-only (no owner access)
234 Returns:
235 True if access allowed, False otherwise.
236 """
237 # Public agents are accessible by everyone
238 if agent.visibility == "public":
239 return True
241 # Admin bypass: token_teams=None AND user_email=None means unrestricted admin
242 # This happens when is_admin=True and no team scoping in token
243 if token_teams is None and user_email is None:
244 return True
246 # No user context (but not admin) = deny access to non-public agents
247 if not user_email:
248 return False
250 # Public-only tokens (empty teams array) can ONLY access public agents
251 is_public_only_token = token_teams is not None and len(token_teams) == 0
252 if is_public_only_token:
253 return False # Already checked public above
255 # Owner can access their own private agents
256 if agent.visibility == "private" and agent.owner_email and agent.owner_email == user_email:
257 return True
259 # Team agents: check team membership
260 # At this point token_teams is guaranteed to be a non-empty list
261 # (None handled by admin bypass, [] by public-only check)
262 if agent.visibility == "team":
263 return agent.team_id in token_teams
265 return False
267 async def register_agent(
268 self,
269 db: Session,
270 agent_data: A2AAgentCreate,
271 created_by: Optional[str] = None,
272 created_from_ip: Optional[str] = None,
273 created_via: Optional[str] = None,
274 created_user_agent: Optional[str] = None,
275 import_batch_id: Optional[str] = None,
276 federation_source: Optional[str] = None,
277 team_id: Optional[str] = None,
278 owner_email: Optional[str] = None,
279 visibility: Optional[str] = "public",
280 ) -> A2AAgentRead:
281 """Register a new A2A agent.
283 Args:
284 db (Session): Database session.
285 agent_data (A2AAgentCreate): Data required to create an agent.
286 created_by (Optional[str]): User who created the agent.
287 created_from_ip (Optional[str]): IP address of the creator.
288 created_via (Optional[str]): Method used for creation (e.g., API, import).
289 created_user_agent (Optional[str]): User agent of the creation request.
290 import_batch_id (Optional[str]): UUID of a bulk import batch.
291 federation_source (Optional[str]): Source gateway for federated agents.
292 team_id (Optional[str]): ID of the team to assign the agent to.
293 owner_email (Optional[str]): Email of the agent owner.
294 visibility (Optional[str]): Visibility level ('public', 'team', 'private').
296 Returns:
297 A2AAgentRead: The created agent object.
299 Raises:
300 A2AAgentNameConflictError: If another agent with the same name already exists.
301 IntegrityError: If a database constraint or integrity violation occurs.
302 ValueError: If invalid configuration or data is provided.
303 A2AAgentError: For any other unexpected errors during registration.
305 Examples:
306 # TODO
307 """
308 try:
309 agent_data.slug = slugify(agent_data.name)
310 # Check for existing server with the same slug within the same team or public scope
311 if visibility.lower() == "public":
312 logger.info(f"visibility.lower(): {visibility.lower()}")
313 logger.info(f"agent_data.name: {agent_data.name}")
314 logger.info(f"agent_data.slug: {agent_data.slug}")
315 # Check for existing public a2a agent with the same slug
316 existing_agent = get_for_update(db, DbA2AAgent, where=and_(DbA2AAgent.slug == agent_data.slug, DbA2AAgent.visibility == "public"))
317 if existing_agent:
318 raise A2AAgentNameConflictError(name=agent_data.slug, is_active=existing_agent.enabled, agent_id=existing_agent.id, visibility=existing_agent.visibility)
319 elif visibility.lower() == "team" and team_id:
320 # Check for existing team a2a agent with the same slug
321 existing_agent = get_for_update(db, DbA2AAgent, where=and_(DbA2AAgent.slug == agent_data.slug, DbA2AAgent.visibility == "team", DbA2AAgent.team_id == team_id))
322 if existing_agent:
323 raise A2AAgentNameConflictError(name=agent_data.slug, is_active=existing_agent.enabled, agent_id=existing_agent.id, visibility=existing_agent.visibility)
325 auth_type = getattr(agent_data, "auth_type", None)
326 # Support multiple custom headers
327 auth_value = getattr(agent_data, "auth_value", {})
329 # authentication_headers: Optional[Dict[str, str]] = None
331 if hasattr(agent_data, "auth_headers") and agent_data.auth_headers:
332 # Convert list of {key, value} to dict
333 header_dict = {h["key"]: h["value"] for h in agent_data.auth_headers if h.get("key")}
334 # Keep encoded form for persistence, but pass raw headers for initialization
335 auth_value = encode_auth(header_dict) # Encode the dict for consistency
336 # authentication_headers = {str(k): str(v) for k, v in header_dict.items()}
337 # elif isinstance(auth_value, str) and auth_value:
338 # # Decode persisted auth for initialization
339 # decoded = decode_auth(auth_value)
340 # authentication_headers = {str(k): str(v) for k, v in decoded.items()}
341 else:
342 # authentication_headers = None
343 pass
344 # auth_value = {}
346 oauth_config = await protect_oauth_config_for_storage(getattr(agent_data, "oauth_config", None))
348 # Handle query_param auth - encrypt and prepare for storage
349 auth_query_params_encrypted: Optional[Dict[str, str]] = None
350 if auth_type == "query_param":
351 # Standard
352 from urllib.parse import urlparse # pylint: disable=import-outside-toplevel
354 # First-Party
355 from mcpgateway.config import settings # pylint: disable=import-outside-toplevel
357 # Service-layer enforcement: Check feature flag
358 if not settings.insecure_allow_queryparam_auth:
359 raise ValueError("Query parameter authentication is disabled. Set INSECURE_ALLOW_QUERYPARAM_AUTH=true to enable.")
361 # Service-layer enforcement: Check host allowlist
362 if settings.insecure_queryparam_auth_allowed_hosts:
363 parsed = urlparse(str(agent_data.endpoint_url))
364 hostname = (parsed.hostname or "").lower()
365 allowed_hosts = [h.lower() for h in settings.insecure_queryparam_auth_allowed_hosts]
366 if hostname not in allowed_hosts:
367 allowed = ", ".join(settings.insecure_queryparam_auth_allowed_hosts)
368 raise ValueError(f"Host '{hostname}' is not in the allowed hosts for query param auth. " f"Allowed: {allowed}")
370 # Extract and encrypt query param auth
371 param_key = getattr(agent_data, "auth_query_param_key", None)
372 param_value = getattr(agent_data, "auth_query_param_value", None)
373 if param_key and param_value:
374 # Handle SecretStr
375 if hasattr(param_value, "get_secret_value"):
376 raw_value = param_value.get_secret_value()
377 else:
378 raw_value = str(param_value)
379 # Encrypt for storage
380 encrypted_value = encode_auth({param_key: raw_value})
381 auth_query_params_encrypted = {param_key: encrypted_value}
382 # Query param auth doesn't use auth_value
383 auth_value = None
385 # Create new agent
386 new_agent = DbA2AAgent(
387 name=agent_data.name,
388 description=agent_data.description,
389 endpoint_url=agent_data.endpoint_url,
390 agent_type=agent_data.agent_type,
391 protocol_version=agent_data.protocol_version,
392 capabilities=agent_data.capabilities,
393 config=agent_data.config,
394 auth_type=auth_type,
395 auth_value=auth_value, # This should be encrypted in practice
396 auth_query_params=auth_query_params_encrypted, # Encrypted query param auth
397 oauth_config=oauth_config,
398 tags=agent_data.tags,
399 passthrough_headers=getattr(agent_data, "passthrough_headers", None),
400 # Team scoping fields - use schema values if provided, otherwise fallback to parameters
401 team_id=getattr(agent_data, "team_id", None) or team_id,
402 owner_email=getattr(agent_data, "owner_email", None) or owner_email or created_by,
403 # Endpoint visibility parameter takes precedence over schema default
404 visibility=visibility if visibility is not None else getattr(agent_data, "visibility", "public"),
405 created_by=created_by,
406 created_from_ip=created_from_ip,
407 created_via=created_via,
408 created_user_agent=created_user_agent,
409 import_batch_id=import_batch_id,
410 federation_source=federation_source,
411 )
413 db.add(new_agent)
414 # Commit agent FIRST to ensure it persists even if tool creation fails
415 # This is critical because ToolService.register_tool calls db.rollback()
416 # on error, which would undo a pending (flushed but uncommitted) agent
417 db.commit()
418 db.refresh(new_agent)
420 # Invalidate caches since agent count changed
421 # Wrapped in try/except to ensure cache failures don't fail the request
422 # when the agent is already successfully committed
423 try:
424 a2a_stats_cache.invalidate()
425 cache = _get_registry_cache()
426 await cache.invalidate_agents()
427 # Also invalidate tags cache since agent tags may have changed
428 # First-Party
429 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel
431 await admin_stats_cache.invalidate_tags()
432 # First-Party
433 from mcpgateway.cache.metrics_cache import metrics_cache # pylint: disable=import-outside-toplevel
435 metrics_cache.invalidate("a2a")
436 except Exception as cache_error:
437 logger.warning(f"Cache invalidation failed after agent commit: {cache_error}")
439 # Automatically create a tool for the A2A agent if not already present
440 # Tool creation is wrapped in try/except to ensure agent registration succeeds
441 # even if tool creation fails (e.g., due to visibility or permission issues)
442 tool_db = None
443 try:
444 # First-Party
445 from mcpgateway.services.tool_service import tool_service
447 tool_db = await tool_service.create_tool_from_a2a_agent(
448 db=db,
449 agent=new_agent,
450 created_by=created_by,
451 created_from_ip=created_from_ip,
452 created_via=created_via,
453 created_user_agent=created_user_agent,
454 )
456 # Associate the tool with the agent using the relationship
457 # This sets both the tool_id foreign key and the tool relationship
458 new_agent.tool = tool_db
459 db.commit()
460 db.refresh(new_agent)
461 logger.info(f"Registered new A2A agent: {new_agent.name} (ID: {new_agent.id}) with tool ID: {tool_db.id}")
462 except Exception as tool_error:
463 # Log the error but don't fail agent registration
464 # Agent was already committed above, so it persists even if tool creation fails
465 logger.warning(f"Failed to create tool for A2A agent {new_agent.name}: {tool_error}")
466 structured_logger.warning(
467 f"A2A agent '{new_agent.name}' created without tool association",
468 user_id=created_by,
469 resource_type="a2a_agent",
470 resource_id=str(new_agent.id),
471 custom_fields={"error": str(tool_error), "agent_name": new_agent.name},
472 )
473 # Refresh the agent to ensure it's in a clean state after any rollback
474 db.refresh(new_agent)
475 logger.info(f"Registered new A2A agent: {new_agent.name} (ID: {new_agent.id}) without tool")
477 # Log A2A agent registration for lifecycle tracking
478 structured_logger.info(
479 f"A2A agent '{new_agent.name}' registered successfully",
480 user_id=created_by,
481 user_email=owner_email,
482 team_id=team_id,
483 resource_type="a2a_agent",
484 resource_id=str(new_agent.id),
485 resource_action="create",
486 custom_fields={
487 "agent_name": new_agent.name,
488 "agent_type": new_agent.agent_type,
489 "protocol_version": new_agent.protocol_version,
490 "visibility": visibility,
491 "endpoint_url": new_agent.endpoint_url,
492 },
493 )
495 return self.convert_agent_to_read(new_agent, db=db)
497 except A2AAgentNameConflictError as ie:
498 db.rollback()
499 raise ie
500 except IntegrityError as ie:
501 db.rollback()
502 logger.error(f"IntegrityErrors in group: {ie}")
503 raise ie
504 except ValueError as ve:
505 raise ve
506 except Exception as e:
507 db.rollback()
508 raise A2AAgentError(f"Failed to register A2A agent: {str(e)}")
510 async def list_agents(
511 self,
512 db: Session,
513 cursor: Optional[str] = None,
514 include_inactive: bool = False,
515 tags: Optional[List[str]] = None,
516 limit: Optional[int] = None,
517 page: Optional[int] = None,
518 per_page: Optional[int] = None,
519 user_email: Optional[str] = None,
520 token_teams: Optional[List[str]] = None,
521 team_id: Optional[str] = None,
522 visibility: Optional[str] = None,
523 ) -> Union[tuple[List[A2AAgentRead], Optional[str]], Dict[str, Any]]:
524 """List A2A agents with cursor pagination and optional team filtering.
526 Args:
527 db: Database session.
528 cursor: Pagination cursor for keyset pagination.
529 include_inactive: Whether to include inactive agents.
530 tags: List of tags to filter by.
531 limit: Maximum number of agents to return. None for default, 0 for unlimited.
532 page: Page number for page-based pagination (1-indexed). Mutually exclusive with cursor.
533 per_page: Items per page for page-based pagination. Defaults to pagination_default_page_size.
534 user_email: Email of user for owner matching in visibility checks.
535 token_teams: Teams from JWT token. None = admin (no filtering),
536 [] = public-only, [...] = team-scoped access.
537 team_id: Optional team ID to filter by specific team.
538 visibility: Optional visibility filter (private, team, public).
540 Returns:
541 If page is provided: Dict with {"data": [...], "pagination": {...}, "links": {...}}
542 If cursor is provided or neither: tuple of (list of A2AAgentRead objects, next_cursor).
544 Examples:
545 >>> from mcpgateway.services.a2a_service import A2AAgentService
546 >>> from unittest.mock import MagicMock
547 >>> from mcpgateway.schemas import A2AAgentRead
548 >>> import asyncio
550 >>> service = A2AAgentService()
551 >>> db = MagicMock()
553 >>> # Mock a single agent object returned by the DB
554 >>> agent_obj = MagicMock()
555 >>> db.execute.return_value.scalars.return_value.all.return_value = [agent_obj]
557 >>> # Mock the A2AAgentRead schema to return a masked string
558 >>> mocked_agent_read = MagicMock()
559 >>> mocked_agent_read.masked.return_value = 'agent_read'
560 >>> A2AAgentRead.model_validate = MagicMock(return_value=mocked_agent_read)
562 >>> # Run the service method
563 >>> agents, cursor = asyncio.run(service.list_agents(db))
564 >>> agents == ['agent_read'] and cursor is None
565 True
567 >>> # Test include_inactive parameter (same mock works)
568 >>> agents_with_inactive, cursor = asyncio.run(service.list_agents(db, include_inactive=True))
569 >>> agents_with_inactive == ['agent_read'] and cursor is None
570 True
572 >>> # Test empty result
573 >>> db.execute.return_value.scalars.return_value.all.return_value = []
574 >>> empty_agents, cursor = asyncio.run(service.list_agents(db))
575 >>> empty_agents == [] and cursor is None
576 True
578 """
579 # ══════════════════════════════════════════════════════════════════════
580 # CACHE READ: Skip cache when ANY access filtering is applied
581 # This prevents leaking admin-level results to filtered requests
582 # Cache only when: user_email is None AND token_teams is None AND page is None
583 # ══════════════════════════════════════════════════════════════════════
584 cache = _get_registry_cache()
585 if cursor is None and user_email is None and token_teams is None and page is None:
586 filters_hash = cache.hash_filters(include_inactive=include_inactive, tags=sorted(tags) if tags else None)
587 cached = await cache.get("agents", filters_hash)
588 if cached is not None:
589 # Reconstruct A2AAgentRead objects from cached dicts
590 cached_agents = [A2AAgentRead.model_validate(a).masked() for a in cached["agents"]]
591 return (cached_agents, cached.get("next_cursor"))
593 # Build base query with ordering
594 query = select(DbA2AAgent).order_by(desc(DbA2AAgent.created_at), desc(DbA2AAgent.id))
596 # Apply active/inactive filter
597 if not include_inactive:
598 query = query.where(DbA2AAgent.enabled)
600 query = await self._apply_access_control(query, db, user_email, token_teams, team_id)
602 if visibility:
603 query = query.where(DbA2AAgent.visibility == visibility)
605 # Add tag filtering if tags are provided (supports both List[str] and List[Dict] formats)
606 if tags:
607 query = query.where(json_contains_tag_expr(db, DbA2AAgent.tags, tags, match_any=True))
609 # Use unified pagination helper - handles both page and cursor pagination
610 pag_result = await unified_paginate(
611 db=db,
612 query=query,
613 page=page,
614 per_page=per_page,
615 cursor=cursor,
616 limit=limit,
617 base_url="/admin/a2a", # Used for page-based links
618 query_params={"include_inactive": include_inactive} if include_inactive else {},
619 )
621 next_cursor = None
622 # Extract servers based on pagination type
623 if page is not None:
624 # Page-based: pag_result is a dict
625 a2a_agents_db = pag_result["data"]
626 else:
627 # Cursor-based: pag_result is a tuple
628 a2a_agents_db, next_cursor = pag_result
630 # Fetch team names for the agents (common for both pagination types)
631 team_ids_set = {s.team_id for s in a2a_agents_db if s.team_id}
632 team_map = {}
633 if team_ids_set:
634 teams = db.execute(select(EmailTeam.id, EmailTeam.name).where(EmailTeam.id.in_(team_ids_set), EmailTeam.is_active.is_(True))).all()
635 team_map = {team.id: team.name for team in teams}
637 db.commit() # Release transaction to avoid idle-in-transaction
639 # Convert to A2AAgentRead (common for both pagination types)
640 result = []
641 for s in a2a_agents_db:
642 try:
643 s.team = team_map.get(s.team_id) if s.team_id else None
644 result.append(self.convert_agent_to_read(s, include_metrics=False, db=db, team_map=team_map))
645 except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e:
646 logger.exception(f"Failed to convert A2A agent {getattr(s, 'id', 'unknown')} ({getattr(s, 'name', 'unknown')}): {e}")
647 # Continue with remaining agents instead of failing completely
649 # Return appropriate format based on pagination type
650 if page is not None:
651 # Page-based format
652 return {
653 "data": result,
654 "pagination": pag_result["pagination"],
655 "links": pag_result["links"],
656 }
658 # Cursor-based format
660 # ══════════════════════════════════════════════════════════════════════
661 # CACHE WRITE: Only cache admin-level results (matches read guard)
662 # MUST check token_teams is None to prevent caching scoped responses
663 # ══════════════════════════════════════════════════════════════════════
664 if cursor is None and user_email is None and token_teams is None:
665 try:
666 cache_data = {"agents": [s.model_dump(mode="json") for s in result], "next_cursor": next_cursor}
667 await cache.set("agents", cache_data, filters_hash)
668 except AttributeError:
669 pass # Skip caching if result objects don't support model_dump (e.g., in doctests)
671 return (result, next_cursor)
673 async def list_agents_for_user(
674 self, db: Session, user_info: Dict[str, Any], team_id: Optional[str] = None, visibility: Optional[str] = None, include_inactive: bool = False, skip: int = 0, limit: int = 100
675 ) -> List[A2AAgentRead]:
676 """
677 DEPRECATED: Use list_agents() with user_email parameter instead.
679 This method is maintained for backward compatibility but is no longer used.
680 New code should call list_agents() with user_email, team_id, and visibility parameters.
682 List A2A agents user has access to with team filtering.
684 Args:
685 db: Database session
686 user_info: Object representing identity of the user who is requesting agents
687 team_id: Optional team ID to filter by specific team
688 visibility: Optional visibility filter (private, team, public)
689 include_inactive: Whether to include inactive agents
690 skip: Number of agents to skip for pagination
691 limit: Maximum number of agents to return
693 Returns:
694 List[A2AAgentRead]: A2A agents the user has access to
695 """
697 # Handle case where user_info is a string (email) instead of dict (<0.7.0)
698 if isinstance(user_info, str):
699 user_email = str(user_info)
700 else:
701 user_email = user_info.get("email", "")
703 # Build query following existing patterns from list_prompts()
704 team_service = TeamManagementService(db)
705 user_teams = await team_service.get_user_teams(user_email)
706 team_ids = [team.id for team in user_teams]
708 # Build query following existing patterns from list_agents()
709 query = select(DbA2AAgent)
711 # Apply active/inactive filter
712 if not include_inactive:
713 query = query.where(DbA2AAgent.enabled.is_(True))
715 if team_id:
716 if team_id not in team_ids:
717 return [] # No access to team
719 access_conditions = []
720 # Filter by specific team
721 access_conditions.append(and_(DbA2AAgent.team_id == team_id, DbA2AAgent.visibility.in_(["team", "public"])))
723 access_conditions.append(and_(DbA2AAgent.team_id == team_id, DbA2AAgent.owner_email == user_email))
725 query = query.where(or_(*access_conditions))
726 else:
727 # Get user's accessible teams
728 # Build access conditions following existing patterns
729 access_conditions = []
730 # 1. User's personal resources (owner_email matches)
731 access_conditions.append(DbA2AAgent.owner_email == user_email)
732 # 2. Team A2A Agents where user is member
733 if team_ids:
734 access_conditions.append(and_(DbA2AAgent.team_id.in_(team_ids), DbA2AAgent.visibility.in_(["team", "public"])))
735 # 3. Public resources (if visibility allows)
736 access_conditions.append(DbA2AAgent.visibility == "public")
738 query = query.where(or_(*access_conditions))
740 # Apply visibility filter if specified
741 if visibility:
742 query = query.where(DbA2AAgent.visibility == visibility)
744 # Apply pagination following existing patterns
745 query = query.order_by(desc(DbA2AAgent.created_at))
746 query = query.offset(skip).limit(limit)
748 agents = db.execute(query).scalars().all()
750 # Batch fetch team names to avoid N+1 queries
751 team_ids = list({a.team_id for a in agents if a.team_id})
752 team_map = self._batch_get_team_names(db, team_ids)
754 db.commit() # Release transaction to avoid idle-in-transaction
756 # Skip metrics to avoid N+1 queries in list operations
757 result = []
758 for agent in agents:
759 try:
760 result.append(self.convert_agent_to_read(agent, include_metrics=False, db=db, team_map=team_map))
761 except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e:
762 logger.exception(f"Failed to convert A2A agent {getattr(agent, 'id', 'unknown')} ({getattr(agent, 'name', 'unknown')}): {e}")
763 # Continue with remaining agents instead of failing completely
765 return result
async def get_agent(
    self,
    db: Session,
    agent_id: str,
    include_inactive: bool = True,
    user_email: Optional[str] = None,
    token_teams: Optional[List[str]] = None,
) -> A2AAgentRead:
    """Retrieve an A2A agent by ID.

    Args:
        db: Database session.
        agent_id: Agent ID.
        include_inactive: Whether to include inactive a2a agents.
        user_email: User's email for owner matching in visibility checks.
        token_teams: Teams from JWT token. None = admin (no filtering),
            [] = public-only, [...] = team-scoped access.

    Returns:
        Agent data.

    Raises:
        A2AAgentNotFoundError: If the agent does not exist, is disabled while
            ``include_inactive`` is False, or the caller lacks access. All three
            cases use the same message so private agents are not disclosed.

    Examples:
        >>> from unittest.mock import MagicMock
        >>> import asyncio
        >>> from mcpgateway.services.a2a_service import A2AAgentService, A2AAgentNotFoundError

        >>> service = A2AAgentService()
        >>> db = MagicMock()
        >>> agent_mock = MagicMock()
        >>> agent_mock.enabled = True

        >>> # The lookup goes through db.execute(...).scalar_one_or_none()
        >>> db.execute.return_value.scalar_one_or_none.return_value = agent_mock

        >>> # Stub the access check and conversion to keep the example self-contained
        >>> service._check_agent_access = lambda agent, email, teams: True
        >>> service.convert_agent_to_read = lambda db_agent, **kwargs: 'agent_read'

        >>> # Active agent
        >>> asyncio.run(service.get_agent(db, 'agent_id'))
        'agent_read'

        >>> # Inactive agent is returned when include_inactive=True ...
        >>> agent_mock.enabled = False
        >>> asyncio.run(service.get_agent(db, 'agent_id', include_inactive=True))
        'agent_read'

        >>> # ... and reported as not found when include_inactive=False
        >>> try:
        ...     asyncio.run(service.get_agent(db, 'agent_id', include_inactive=False))
        ... except A2AAgentNotFoundError as exc:
        ...     print(type(exc).__name__)
        A2AAgentNotFoundError

        >>> # A missing agent raises the same error
        >>> db.execute.return_value.scalar_one_or_none.return_value = None
        >>> try:
        ...     asyncio.run(service.get_agent(db, 'missing'))
        ... except A2AAgentNotFoundError as exc:
        ...     print(type(exc).__name__)
        A2AAgentNotFoundError
    """
    query = select(DbA2AAgent).where(DbA2AAgent.id == agent_id)
    agent = db.execute(query).scalar_one_or_none()

    if not agent:
        raise A2AAgentNotFoundError(f"A2A Agent not found with ID: {agent_id}")

    if not agent.enabled and not include_inactive:
        raise A2AAgentNotFoundError(f"A2A Agent not found with ID: {agent_id}")

    # SECURITY: Check visibility/team access
    # Return 404 (not 403) to avoid leaking existence of private agents
    if not self._check_agent_access(agent, user_email, token_teams):
        raise A2AAgentNotFoundError(f"A2A Agent not found with ID: {agent_id}")

    # Delegate conversion and masking to convert_agent_to_read()
    return self.convert_agent_to_read(agent, db=db)
async def get_agent_by_name(
    self,
    db: Session,
    agent_name: str,
    *,
    user_email: Optional[str] = None,
    token_teams: Optional[List[str]] = None,
) -> A2AAgentRead:
    """Retrieve an A2A agent by name.

    Args:
        db: Database session.
        agent_name: Agent name.
        user_email: Optional user's email for owner matching in visibility checks.
        token_teams: Optional teams from JWT token (None = admin / no filtering,
            [] = public-only, [...] = team-scoped access).

    Notes:
        The visibility check (same as ``get_agent``) runs only when the caller
        passes ``user_email`` or ``token_teams``. When neither is passed, the
        lookup is unfiltered, as it was before these parameters existed.

    Returns:
        Agent data.

    Raises:
        A2AAgentNotFoundError: If the agent is not found, or access was checked
            and denied. Denial returns the same "not found" error so private
            agents are not disclosed.
    """
    query = select(DbA2AAgent).where(DbA2AAgent.name == agent_name)
    agent = db.execute(query).scalar_one_or_none()

    if not agent:
        raise A2AAgentNotFoundError(f"A2A Agent not found with name: {agent_name}")

    # Opt-in access check, consistent with get_agent()/invoke_agent()
    access_requested = user_email is not None or token_teams is not None
    if access_requested and not self._check_agent_access(agent, user_email, token_teams):
        raise A2AAgentNotFoundError(f"A2A Agent not found with name: {agent_name}")

    return self.convert_agent_to_read(agent, db=db)
async def update_agent(
    self,
    db: Session,
    agent_id: str,
    agent_data: A2AAgentUpdate,
    modified_by: Optional[str] = None,
    modified_from_ip: Optional[str] = None,
    modified_via: Optional[str] = None,
    modified_user_agent: Optional[str] = None,
    user_email: Optional[str] = None,
) -> A2AAgentRead:
    """Update an existing A2A agent.

    Args:
        db: Database session.
        agent_id: Agent ID.
        agent_data: Agent update data.
        modified_by: Username who modified this agent.
        modified_from_ip: IP address of modifier.
        modified_via: Modification method.
        modified_user_agent: User agent of modification request.
        user_email: Email of user performing update (for ownership check).

    Returns:
        Updated agent data.

    Raises:
        A2AAgentNotFoundError: If the agent is not found.
        PermissionError: If user doesn't own the agent.
        A2AAgentNameConflictError: If the new name's slug is already used by a
            different agent in the same public or team scope.
        IntegrityError: If a database integrity error occurs.
        A2AAgentError: For any other failure during the update. This includes
            invalid ``passthrough_headers`` and query-param auth policy
            violations (feature flag disabled, or host not in the allowlist).
            Those are raised internally (as ``A2AAgentError``/``ValueError``)
            and wrapped by the generic handler at the end of this method.
    """
    try:
        # Acquire row lock for update to avoid lost-update on `version` and other fields
        agent = get_for_update(db, DbA2AAgent, agent_id)

        if not agent:
            raise A2AAgentNotFoundError(f"A2A Agent not found with ID: {agent_id}")

        # Check ownership if user_email provided
        if user_email:
            # First-Party
            from mcpgateway.services.permission_service import PermissionService  # pylint: disable=import-outside-toplevel

            permission_service = PermissionService(db)
            if not await permission_service.check_resource_ownership(user_email, agent):
                raise PermissionError("Only the owner can update this agent")

        # Check for name conflict if name is being updated
        if agent_data.name and agent_data.name != agent.name:
            new_slug = slugify(agent_data.name)
            visibility = agent_data.visibility or agent.visibility
            team_id = agent_data.team_id or agent.team_id
            # Exclude this agent from the conflict lookup. A rename that keeps the
            # same slug (e.g. "My Agent" -> "my agent") must not conflict with itself.
            not_this_agent = DbA2AAgent.id != agent.id
            existing_agent = None
            if visibility.lower() == "public":
                # Check for another public a2a agent with the same slug
                existing_agent = get_for_update(
                    db,
                    DbA2AAgent,
                    where=and_(DbA2AAgent.slug == new_slug, DbA2AAgent.visibility == "public", not_this_agent),
                )
            elif visibility.lower() == "team" and team_id:
                # Check for another a2a agent of the same team with the same slug
                existing_agent = get_for_update(
                    db,
                    DbA2AAgent,
                    where=and_(DbA2AAgent.slug == new_slug, DbA2AAgent.visibility == "team", DbA2AAgent.team_id == team_id, not_this_agent),
                )
            if existing_agent:
                raise A2AAgentNameConflictError(name=new_slug, is_active=existing_agent.enabled, agent_id=existing_agent.id, visibility=existing_agent.visibility)
            # Update the slug when name changes
            agent.slug = new_slug

        # Avoid `model_dump()` here: tests use `model_construct()` to create intentionally invalid
        # payloads, and `model_dump()` emits serializer warnings when encountering unexpected types.
        update_data = {field: getattr(agent_data, field) for field in agent_data.model_fields_set}

        # Track original auth_type and endpoint_url before updates
        original_auth_type = agent.auth_type
        original_endpoint_url = agent.endpoint_url

        for field, value in update_data.items():
            if field == "passthrough_headers":
                if value is not None:
                    if isinstance(value, list):
                        # Clean list: remove empty or whitespace-only entries
                        cleaned = [h.strip() for h in value if isinstance(h, str) and h.strip()]
                        agent.passthrough_headers = cleaned or None
                    elif isinstance(value, str):
                        # Parse comma-separated string and clean
                        parsed_headers: List[str] = [h.strip() for h in value.split(",") if h.strip()]
                        agent.passthrough_headers = parsed_headers or None
                    else:
                        raise A2AAgentError("Invalid passthrough_headers format: must be list[str] or comma-separated string")
                else:
                    # Explicitly set to None if value is None
                    agent.passthrough_headers = None
                continue

            # Skip query_param fields - handled separately below
            if field in ("auth_query_param_key", "auth_query_param_value"):
                continue

            if field == "oauth_config":
                value = await protect_oauth_config_for_storage(value, existing_oauth_config=agent.oauth_config)

            if hasattr(agent, field):
                setattr(agent, field, value)

        # Handle query_param auth updates
        # Clear auth_query_params when switching away from query_param auth
        if original_auth_type == "query_param" and agent_data.auth_type is not None and agent_data.auth_type != "query_param":
            agent.auth_query_params = None
            logger.debug(f"Cleared auth_query_params for agent {agent.id} (switched from query_param to {agent_data.auth_type})")

        # Handle switching to query_param auth or updating existing query_param credentials
        is_switching_to_queryparam = agent_data.auth_type == "query_param" and original_auth_type != "query_param"
        is_updating_queryparam_creds = original_auth_type == "query_param" and (agent_data.auth_query_param_key is not None or agent_data.auth_query_param_value is not None)
        is_url_changing = agent_data.endpoint_url is not None and str(agent_data.endpoint_url) != original_endpoint_url

        if is_switching_to_queryparam or is_updating_queryparam_creds or (is_url_changing and original_auth_type == "query_param"):
            # Standard
            from urllib.parse import urlparse  # pylint: disable=import-outside-toplevel

            # First-Party
            from mcpgateway.config import settings  # pylint: disable=import-outside-toplevel

            # Service-layer enforcement: Check feature flag
            if not settings.insecure_allow_queryparam_auth:
                # Grandfather clause: Allow updates to existing query_param agents
                # unless they're trying to change credentials
                if is_switching_to_queryparam or is_updating_queryparam_creds:
                    raise ValueError("Query parameter authentication is disabled. Set INSECURE_ALLOW_QUERYPARAM_AUTH=true to enable.")

            # Service-layer enforcement: Check host allowlist
            if settings.insecure_queryparam_auth_allowed_hosts:
                check_url = str(agent_data.endpoint_url) if agent_data.endpoint_url else agent.endpoint_url
                parsed_url = urlparse(check_url)
                hostname = (parsed_url.hostname or "").lower()
                allowed_hosts = [h.lower() for h in settings.insecure_queryparam_auth_allowed_hosts]
                if hostname not in allowed_hosts:
                    allowed = ", ".join(settings.insecure_queryparam_auth_allowed_hosts)
                    raise ValueError(f"Host '{hostname}' is not in the allowed hosts for query param auth. " f"Allowed: {allowed}")

            if is_switching_to_queryparam or is_updating_queryparam_creds:
                # Get query param key and value
                param_key = getattr(agent_data, "auth_query_param_key", None)
                param_value = getattr(agent_data, "auth_query_param_value", None)

                # If no key provided but value is, reuse existing key (value-only rotation)
                existing_key = next(iter(agent.auth_query_params.keys()), None) if agent.auth_query_params else None
                if not param_key and param_value and existing_key:
                    param_key = existing_key

                if param_key:
                    # Check if value is masked (user didn't change it) or new value provided
                    is_masked_placeholder = False
                    if param_value and hasattr(param_value, "get_secret_value"):
                        raw_value = param_value.get_secret_value()
                        # `settings` is already imported at the top of this branch
                        is_masked_placeholder = raw_value == settings.masked_auth_value
                    elif param_value:
                        raw_value = str(param_value)
                    else:
                        raw_value = None

                    if raw_value and not is_masked_placeholder:
                        # New value provided - encrypt for storage
                        encrypted_value = encode_auth({param_key: raw_value})
                        agent.auth_query_params = {param_key: encrypted_value}
                    elif agent.auth_query_params and is_masked_placeholder:
                        # Use existing encrypted value (user didn't change the password)
                        # But key may have changed, so preserve with new key if different
                        if existing_key and existing_key != param_key:
                            # Key changed but value is masked - decrypt and re-encrypt with new key
                            existing_encrypted = agent.auth_query_params.get(existing_key, "")
                            if existing_encrypted:
                                decrypted = decode_auth(existing_encrypted)
                                existing_value = decrypted.get(existing_key, "")
                                if existing_value:
                                    encrypted_value = encode_auth({param_key: existing_value})
                                    agent.auth_query_params = {param_key: encrypted_value}

                # Update auth_type if switching
                if is_switching_to_queryparam:
                    agent.auth_type = "query_param"
                    agent.auth_value = None  # Query param auth doesn't use auth_value

        # Update metadata
        if modified_by:
            agent.modified_by = modified_by
        if modified_from_ip:
            agent.modified_from_ip = modified_from_ip
        if modified_via:
            agent.modified_via = modified_via
        if modified_user_agent:
            agent.modified_user_agent = modified_user_agent

        agent.version += 1

        db.commit()
        db.refresh(agent)

        # Invalidate cache after successful update
        cache = _get_registry_cache()
        await cache.invalidate_agents()
        # Also invalidate tags cache since agent tags may have changed
        # First-Party
        from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

        await admin_stats_cache.invalidate_tags()

        # Update the associated tool if it exists
        # Wrap in try/except to handle tool sync failures gracefully - the agent
        # update is the primary operation and should succeed even if tool sync fails
        try:
            # First-Party
            from mcpgateway.services.tool_service import tool_service

            await tool_service.update_tool_from_a2a_agent(
                db=db,
                agent=agent,
                modified_by=modified_by,
                modified_from_ip=modified_from_ip,
                modified_via=modified_via,
                modified_user_agent=modified_user_agent,
            )
        except Exception as tool_err:
            logger.warning(f"Failed to sync tool for A2A agent {agent.id}: {tool_err}. Agent update succeeded but tool may be out of sync.")

        logger.info(f"Updated A2A agent: {agent.name} (ID: {agent.id})")
        return self.convert_agent_to_read(agent, db=db)
    except PermissionError:
        db.rollback()
        raise
    except A2AAgentNameConflictError as ie:
        db.rollback()
        raise ie
    except A2AAgentNotFoundError as nf:
        db.rollback()
        raise nf
    except IntegrityError as ie:
        db.rollback()
        logger.error(f"IntegrityErrors in group: {ie}")
        raise ie
    except Exception as e:
        db.rollback()
        raise A2AAgentError(f"Failed to update A2A agent: {str(e)}") from e
async def set_agent_state(self, db: Session, agent_id: str, activate: bool, reachable: Optional[bool] = None, user_email: Optional[str] = None) -> A2AAgentRead:
    """Set the activation status of an A2A agent.

    Args:
        db: Database session.
        agent_id: Agent ID.
        activate: True to activate, False to deactivate.
        reachable: Optional reachability status.
        user_email: Optional[str] The email of the user to check if the user has permission to modify.

    Returns:
        Updated agent data.

    Raises:
        A2AAgentNotFoundError: If the agent is not found.
        PermissionError: If user doesn't own the agent.
    """
    # Everything up to and including the commit runs in one guarded section. On any
    # failure the session is rolled back, matching update_agent/delete_agent, so it
    # is not left with an open or failed transaction. The exception is re-raised
    # unchanged.
    try:
        query = select(DbA2AAgent).where(DbA2AAgent.id == agent_id)
        agent = db.execute(query).scalar_one_or_none()

        if not agent:
            raise A2AAgentNotFoundError(f"A2A Agent not found with ID: {agent_id}")

        if user_email:
            # First-Party
            from mcpgateway.services.permission_service import PermissionService  # pylint: disable=import-outside-toplevel

            permission_service = PermissionService(db)
            if not await permission_service.check_resource_ownership(user_email, agent):
                raise PermissionError("Only the owner can activate the Agent" if activate else "Only the owner can deactivate the Agent")

        agent.enabled = activate
        if reachable is not None:
            agent.reachable = reachable

        db.commit()
    except Exception:
        db.rollback()
        raise

    db.refresh(agent)

    # Invalidate caches since agent status changed
    a2a_stats_cache.invalidate()
    cache = _get_registry_cache()
    await cache.invalidate_agents()

    status = "activated" if activate else "deactivated"
    logger.info(f"A2A agent {status}: {agent.name} (ID: {agent.id})")

    structured_logger.log(
        level="INFO",
        message=f"A2A agent {status}",
        event_type="a2a_agent_status_changed",
        component="a2a_service",
        user_email=user_email,
        resource_type="a2a_agent",
        resource_id=str(agent.id),
        custom_fields={
            "agent_name": agent.name,
            "enabled": agent.enabled,
            "reachable": agent.reachable,
        },
    )

    return self.convert_agent_to_read(agent, db=db)
async def delete_agent(self, db: Session, agent_id: str, user_email: Optional[str] = None, purge_metrics: bool = False) -> None:
    """Delete an A2A agent.

    Args:
        db: Database session.
        agent_id: Agent ID.
        user_email: Email of user performing delete (for ownership check).
        purge_metrics: If True, delete raw + rollup metrics for this agent.

    Raises:
        A2AAgentNotFoundError: If the agent is not found.
        PermissionError: If user doesn't own the agent.

    Any other error raised before the commit (tool deletion, metric purge, the
    commit itself) triggers a rollback and is re-raised unchanged.
    """
    try:
        query = select(DbA2AAgent).where(DbA2AAgent.id == agent_id)
        agent = db.execute(query).scalar_one_or_none()

        if not agent:
            raise A2AAgentNotFoundError(f"A2A Agent not found with ID: {agent_id}")

        # Check ownership if user_email provided
        if user_email:
            # First-Party
            from mcpgateway.services.permission_service import PermissionService  # pylint: disable=import-outside-toplevel

            permission_service = PermissionService(db)
            if not await permission_service.check_resource_ownership(user_email, agent):
                raise PermissionError("Only the owner can delete this agent")

        agent_name = agent.name

        # Delete the associated tool before deleting the agent
        # First-Party
        from mcpgateway.services.tool_service import tool_service

        await tool_service.delete_tool_from_a2a_agent(db=db, agent=agent, user_email=user_email, purge_metrics=purge_metrics)

        if purge_metrics:
            with pause_rollup_during_purge(reason=f"purge_a2a_agent:{agent_id}"):
                delete_metrics_in_batches(db, A2AAgentMetric, A2AAgentMetric.a2a_agent_id, agent_id)
                delete_metrics_in_batches(db, A2AAgentMetricsHourly, A2AAgentMetricsHourly.a2a_agent_id, agent_id)
        db.delete(agent)
        db.commit()
    except Exception:
        # Previously only PermissionError rolled back. Any failure before the commit
        # completes would leave the session dirty or in a failed state.
        db.rollback()
        raise

    # Invalidate caches since agent count changed
    a2a_stats_cache.invalidate()
    cache = _get_registry_cache()
    await cache.invalidate_agents()
    # Also invalidate tags cache since agent tags may have changed
    # First-Party
    from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

    await admin_stats_cache.invalidate_tags()

    logger.info(f"Deleted A2A agent: {agent_name} (ID: {agent_id})")

    structured_logger.log(
        level="INFO",
        message="A2A agent deleted",
        event_type="a2a_agent_deleted",
        component="a2a_service",
        user_email=user_email,
        resource_type="a2a_agent",
        resource_id=str(agent_id),
        custom_fields={
            "agent_name": agent_name,
            "purge_metrics": purge_metrics,
        },
    )
async def invoke_agent(
    self,
    db: Session,
    agent_name: str,
    parameters: Dict[str, Any],
    interaction_type: str = "query",
    *,
    user_id: Optional[str] = None,
    user_email: Optional[str] = None,
    token_teams: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Invoke an A2A agent.

    Args:
        db: Database session.
        agent_name: Name of the agent to invoke.
        parameters: Parameters for the interaction.
        interaction_type: Type of interaction.
        user_id: Identifier of the user initiating the call.
        user_email: Email of the user initiating the call.
        token_teams: Teams from JWT token. None = admin (no filtering),
            [] = public-only, [...] = team-scoped access.

    Returns:
        The JSON body returned by the agent (even if it is a falsy value such
        as ``{}``). ``{"error": None}`` is returned only when the agent replies
        with a JSON ``null`` body.

    Raises:
        A2AAgentNotFoundError: If the agent is not found or user lacks access.
        A2AAgentError: If the agent is disabled or invocation fails.
    """
    # ═══════════════════════════════════════════════════════════════════════════
    # PHASE 1: Acquire a short row lock to read `enabled` + `auth_value`,
    #          then release the lock before performing the external HTTP call.
    #          This avoids TOCTOU for the critical checks while not holding DB
    #          connections during the potentially slow HTTP request.
    # ═══════════════════════════════════════════════════════════════════════════
    try:
        # Lookup the agent id, then lock the row by id using get_for_update
        agent_pk = db.execute(select(DbA2AAgent.id).where(DbA2AAgent.name == agent_name)).scalar_one_or_none()
        if not agent_pk:
            raise A2AAgentNotFoundError(f"A2A Agent not found with name: {agent_name}")

        agent = get_for_update(db, DbA2AAgent, agent_pk)
        if not agent:
            raise A2AAgentNotFoundError(f"A2A Agent not found with name: {agent_name}")

        # SECURITY: Check visibility/team access WHILE ROW IS LOCKED
        # Return 404 (not 403) to avoid leaking existence of private agents
        if not self._check_agent_access(agent, user_email, token_teams):
            raise A2AAgentNotFoundError(f"A2A Agent not found with name: {agent_name}")

        if not agent.enabled:
            raise A2AAgentError(f"A2A Agent '{agent_name}' is disabled")

        # Extract all needed data to local variables before releasing DB connection
        agent_id = agent.id
        # Keep the configured URL separately. The request-format decision below
        # must not depend on auth query parameters appended to the URL.
        base_endpoint_url = agent.endpoint_url
        agent_endpoint_url = base_endpoint_url
        agent_type = agent.agent_type
        agent_protocol_version = agent.protocol_version
        agent_auth_type = agent.auth_type
        agent_auth_value = agent.auth_value
        agent_auth_query_params = agent.auth_query_params

        # Handle query_param auth - decrypt and apply to URL
        auth_query_params_decrypted: Optional[Dict[str, str]] = None
        if agent_auth_type == "query_param" and agent_auth_query_params:
            # First-Party
            from mcpgateway.utils.url_auth import apply_query_param_auth  # pylint: disable=import-outside-toplevel

            auth_query_params_decrypted = {}
            for param_key, encrypted_value in agent_auth_query_params.items():
                if encrypted_value:
                    try:
                        decrypted = decode_auth(encrypted_value)
                        auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
                    except Exception:
                        logger.debug(f"Failed to decrypt query param '{param_key}' for A2A agent invocation")
            if auth_query_params_decrypted:
                agent_endpoint_url = apply_query_param_auth(agent_endpoint_url, auth_query_params_decrypted)

        # Decode auth_value for supported auth types (before closing session)
        auth_headers = {}
        if agent_auth_type in ("basic", "bearer", "authheaders") and agent_auth_value:
            # Decrypt auth_value and extract headers (follows gateway_service pattern)
            if isinstance(agent_auth_value, str):
                try:
                    auth_headers = decode_auth(agent_auth_value)
                except Exception as e:
                    raise A2AAgentError(f"Failed to decrypt authentication for agent '{agent_name}': {e}")
            elif isinstance(agent_auth_value, dict):
                auth_headers = {str(k): str(v) for k, v in agent_auth_value.items()}
    except Exception:
        # Nothing has been written yet. Roll back so the FOR UPDATE row lock and the
        # read transaction are released before the error reaches the caller.
        db.rollback()
        raise

    # ═══════════════════════════════════════════════════════════════════════════
    # CRITICAL: Release DB connection back to pool BEFORE making HTTP calls
    # This prevents connection pool exhaustion during slow upstream requests.
    # ═══════════════════════════════════════════════════════════════════════════
    db.commit()  # End read-only transaction cleanly (commit not rollback to avoid inflating rollback stats)
    db.close()

    start_time = datetime.now(timezone.utc)
    success = False
    error_message = None
    response = None

    # ═══════════════════════════════════════════════════════════════════════════
    # PHASE 2: Make HTTP call (no DB connection held)
    # ═══════════════════════════════════════════════════════════════════════════

    # Create sanitized URL for logging (redacts auth query params)
    # First-Party
    from mcpgateway.utils.url_auth import sanitize_exception_message, sanitize_url_for_logging  # pylint: disable=import-outside-toplevel

    sanitized_endpoint_url = sanitize_url_for_logging(agent_endpoint_url, auth_query_params_decrypted)

    try:
        # Prepare the request to the A2A agent.
        # The format is decided from the agent type and the *configured* endpoint.
        # Checking the auth-augmented URL would make query_param agents lose the
        # trailing-slash JSON-RPC detection.
        if agent_type in ["generic", "jsonrpc"] or base_endpoint_url.endswith("/"):
            # Use JSONRPC format for agents that expect it
            request_data = {"jsonrpc": "2.0", "method": parameters.get("method", "message/send"), "params": parameters.get("params", parameters), "id": 1}
        else:
            # Use custom A2A format
            request_data = {"interaction_type": interaction_type, "parameters": parameters, "protocol_version": agent_protocol_version}

        # Make HTTP request to the agent endpoint using shared HTTP client
        # First-Party
        from mcpgateway.services.http_client_service import get_http_client  # pylint: disable=import-outside-toplevel

        client = await get_http_client()
        headers = {"Content-Type": "application/json"}

        # Add authentication if configured (using decoded auth headers)
        headers.update(auth_headers)

        # Add correlation ID to outbound headers for distributed tracing
        correlation_id = get_correlation_id()
        if correlation_id:
            headers["X-Correlation-ID"] = correlation_id

        # Log A2A external call start (with sanitized URL to prevent credential leakage)
        call_start_time = datetime.now(timezone.utc)
        structured_logger.log(
            level="INFO",
            message=f"A2A external call started: {agent_name}",
            component="a2a_service",
            user_id=user_id,
            user_email=user_email,
            correlation_id=correlation_id,
            metadata={
                "event": "a2a_call_started",
                "agent_name": agent_name,
                "agent_id": agent_id,
                "endpoint_url": sanitized_endpoint_url,
                "interaction_type": interaction_type,
                "protocol_version": agent_protocol_version,
            },
        )

        http_response = await client.post(agent_endpoint_url, json=request_data, headers=headers)
        call_duration_ms = (datetime.now(timezone.utc) - call_start_time).total_seconds() * 1000

        if http_response.status_code == 200:
            response = http_response.json()
            success = True

            # Log successful A2A call
            structured_logger.log(
                level="INFO",
                message=f"A2A external call completed: {agent_name}",
                component="a2a_service",
                user_id=user_id,
                user_email=user_email,
                correlation_id=correlation_id,
                duration_ms=call_duration_ms,
                metadata={"event": "a2a_call_completed", "agent_name": agent_name, "agent_id": agent_id, "status_code": http_response.status_code, "success": True},
            )
        else:
            # Sanitize error message to prevent URL secrets from leaking in logs
            raw_error = f"HTTP {http_response.status_code}: {http_response.text}"
            error_message = sanitize_exception_message(raw_error, auth_query_params_decrypted)

            # Log failed A2A call
            structured_logger.log(
                level="ERROR",
                message=f"A2A external call failed: {agent_name}",
                component="a2a_service",
                user_id=user_id,
                user_email=user_email,
                correlation_id=correlation_id,
                duration_ms=call_duration_ms,
                error_details={"error_type": "A2AHTTPError", "error_message": error_message},
                metadata={"event": "a2a_call_failed", "agent_name": agent_name, "agent_id": agent_id, "status_code": http_response.status_code},
            )

            raise A2AAgentError(error_message)

    except A2AAgentError:
        # Re-raise A2AAgentError without wrapping
        raise
    except Exception as e:
        # Sanitize error message to prevent URL secrets from leaking in logs
        error_message = sanitize_exception_message(str(e), auth_query_params_decrypted)
        logger.error(f"Failed to invoke A2A agent '{agent_name}': {error_message}")
        raise A2AAgentError(f"Failed to invoke A2A agent: {error_message}")

    finally:
        # ═══════════════════════════════════════════════════════════════════════════
        # PHASE 3: Record metrics via buffered service (batches writes for performance)
        # ═══════════════════════════════════════════════════════════════════════════
        end_time = datetime.now(timezone.utc)
        response_time = (end_time - start_time).total_seconds()

        try:
            # First-Party
            from mcpgateway.services.metrics_buffer_service import get_metrics_buffer_service  # pylint: disable=import-outside-toplevel

            metrics_buffer = get_metrics_buffer_service()
            metrics_buffer.record_a2a_agent_metric_with_duration(
                a2a_agent_id=agent_id,
                response_time=response_time,
                success=success,
                interaction_type=interaction_type,
                error_message=error_message,
            )
        except Exception as metrics_error:
            logger.warning(f"Failed to record A2A metrics for '{agent_name}': {metrics_error}")

        # Update last interaction timestamp (quick separate write)
        try:
            with fresh_db_session() as ts_db:
                # Reacquire short lock and re-check enabled before writing
                db_agent = get_for_update(ts_db, DbA2AAgent, agent_id)
                if db_agent and getattr(db_agent, "enabled", False):
                    db_agent.last_interaction = end_time
                    ts_db.commit()
        except Exception as ts_error:
            logger.warning(f"Failed to update last_interaction for '{agent_name}': {ts_error}")

    # Only the success path reaches this point. Return the agent's body as-is.
    # `response or {...}` would turn a valid empty reply like {} into an error marker.
    return response if response is not None else {"error": error_message}
async def aggregate_metrics(self, db: Session) -> A2AAgentAggregateMetrics:
    """Compute interaction metrics aggregated over every A2A agent.

    Raw metrics still inside the retention window are merged with the hourly
    rollup tables, so history older than the raw retention is included. When
    the metrics cache is enabled, a recent result (short TTL) is served
    from memory instead of querying the database.

    Args:
        db: Database session.

    Returns:
        A2AAgentAggregateMetrics: Aggregated metrics from raw + hourly rollup tables.
    """
    # First-Party
    from mcpgateway.cache.metrics_cache import is_cache_enabled, metrics_cache  # pylint: disable=import-outside-toplevel

    # Fast path: a cached payload (stored as a plain dict) short-circuits everything.
    cached_payload = metrics_cache.get("a2a") if is_cache_enabled() else None
    if cached_payload is not None:
        return A2AAgentAggregateMetrics(**cached_payload)

    # Agent totals come from the stats cache rather than two COUNT queries.
    agent_counts = a2a_stats_cache.get_counts(db)

    # First-Party
    from mcpgateway.services.metrics_query_service import aggregate_metrics_combined  # pylint: disable=import-outside-toplevel

    combined = aggregate_metrics_combined(db, "a2a_agent")
    interactions = combined.total_executions
    succeeded = combined.successful_executions

    if interactions > 0:
        rate_percent = succeeded / interactions * 100
    else:
        rate_percent = 0.0

    summary = A2AAgentAggregateMetrics(
        total_agents=agent_counts["total"],
        active_agents=agent_counts["active"],
        total_interactions=interactions,
        successful_interactions=succeeded,
        failed_interactions=combined.failed_executions,
        success_rate=rate_percent,
        avg_response_time=float(combined.avg_response_time or 0.0),
        min_response_time=float(combined.min_response_time or 0.0),
        max_response_time=float(combined.max_response_time or 0.0),
    )

    # Store a dict rather than the model so the cache stays serialization-friendly.
    if is_cache_enabled():
        metrics_cache.set("a2a", summary.model_dump())

    return summary
async def reset_metrics(self, db: Session, agent_id: Optional[str] = None) -> None:
    """Delete A2A metrics, both raw rows and hourly rollups.

    Args:
        db: Database session.
        agent_id: When given, only this agent's metrics are removed;
            otherwise metrics for every agent are wiped.
    """
    if agent_id:
        raw_stmt = delete(A2AAgentMetric).where(A2AAgentMetric.a2a_agent_id == agent_id)
        hourly_stmt = delete(A2AAgentMetricsHourly).where(A2AAgentMetricsHourly.a2a_agent_id == agent_id)
    else:
        raw_stmt = delete(A2AAgentMetric)
        hourly_stmt = delete(A2AAgentMetricsHourly)

    # Raw table first, then rollups, committed together.
    for stmt in (raw_stmt, hourly_stmt):
        db.execute(stmt)
    db.commit()

    # Drop the cached aggregate so the next read reflects the reset.
    # First-Party
    from mcpgateway.cache.metrics_cache import metrics_cache  # pylint: disable=import-outside-toplevel

    metrics_cache.invalidate("a2a")

    scope_suffix = f" for agent {agent_id}" if agent_id else ""
    logger.info("Reset A2A agent metrics" + scope_suffix)
def _prepare_a2a_agent_for_read(self, agent: DbA2AAgent) -> DbA2AAgent:
    """Normalize an A2A agent's ``auth_value`` ahead of ``A2AAgentRead`` validation.

    The read schema expects ``auth_value`` as an encoded string. A dict value is
    therefore run through ``encode_auth``; any other value is left as it is.

    Args:
        agent: A2A agent database object. It is modified in place when its
            ``auth_value`` is a dict.

    Returns:
        The same ``agent`` object, with ``auth_value`` encoded if it was a dict.
    """
    raw_auth = agent.auth_value
    if not isinstance(raw_auth, dict):
        # Not a dict (e.g. already encoded): nothing to convert.
        return agent
    # NOTE(review): this assigns to the ORM attribute. If `agent` is attached to a
    # session that is committed afterwards, the encoded form may be persisted.
    # Confirm that callers never commit after calling this.
    agent.auth_value = encode_auth(raw_auth)
    return agent
def convert_agent_to_read(self, db_agent: DbA2AAgent, include_metrics: bool = False, db: Optional[Session] = None, team_map: Optional[Dict[str, str]] = None) -> A2AAgentRead:
    """Build a masked ``A2AAgentRead`` schema from an A2A agent database row.

    Args:
        db_agent (DbA2AAgent): Database agent model.
        include_metrics (bool): When True, per-agent execution metrics are computed
            from ``db_agent.metrics``. Defaults to False. Keep it False for list
            operations to avoid N+1 query issues.
        db (Optional[Session]): Database session. It is used for the team-name
            lookup only when the name is not already set on ``db_agent`` and either
            ``team_map`` is None or the agent has no ``team_id``.
        team_map (Optional[Dict[str, str]]): Pre-fetched team_id -> team_name mapping.
            When provided and the agent has a ``team_id``, the name is taken from this
            mapping instead of querying per agent (avoids N+1 queries in list operations).

    Returns:
        A2AAgentRead: Validated agent schema, returned through ``masked()``.

    Raises:
        A2AAgentNotFoundError: If ``db_agent`` is falsy (e.g. ``None``).
    """
    if not db_agent:
        raise A2AAgentNotFoundError("Agent not found")

    # Batch callers may have pre-populated `team`. Resolve it only when it is missing/None.
    if getattr(db_agent, "team", None) is None:
        team_id = getattr(db_agent, "team_id", None)
        resolved_team = None
        if team_map is not None and team_id:
            resolved_team = team_map.get(team_id)
        elif db is not None:
            resolved_team = self._get_team_name(db, team_id)
        db_agent.team = resolved_team

    # Metrics are opt-in: walking db_agent.metrics per row is what causes N+1 queries in lists.
    metrics = None
    if include_metrics:
        records = db_agent.metrics
        total = len(records)
        succeeded = sum(1 for rec in records if rec.is_success)
        failed = total - succeeded
        durations = [rec.response_time for rec in records if rec.response_time is not None]
        if durations:
            fastest, slowest = min(durations), max(durations)
            mean = sum(durations) / len(durations)
        else:
            fastest = slowest = mean = None
        metrics = A2AAgentMetrics(
            total_executions=total,
            successful_executions=succeeded,
            failed_executions=failed,
            # Percentage of failed executions; 0.0 when there is no execution history.
            failure_rate=(failed / total * 100) if total > 0 else 0.0,
            min_response_time=fastest,
            max_response_time=slowest,
            avg_response_time=mean,
            last_execution_time=max((rec.timestamp for rec in records), default=None),
        )

    # Start from every schema field readable on the ORM object, then overlay computed values.
    agent_data = {field: getattr(db_agent, field, None) for field in A2AAgentRead.model_fields}
    agent_data.update(
        metrics=metrics,
        team=getattr(db_agent, "team", None),
        # Supplied so the schema's _mask_query_param_auth validator can see it.
        auth_query_params=getattr(db_agent, "auth_query_params", None),
    )

    return A2AAgentRead.model_validate(agent_data).masked()