Coverage for mcpgateway / services / a2a_service.py: 99%
618 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2# pylint: disable=invalid-name, import-outside-toplevel, unused-import, no-name-in-module
3"""Location: ./mcpgateway/services/a2a_service.py
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8A2A Agent Service
10This module implements A2A (Agent-to-Agent) agent management for the MCP Gateway.
11It handles agent registration, listing, retrieval, updates, activation toggling, deletion,
12and interactions with A2A-compatible agents.
13"""
15# Standard
16import binascii
17from datetime import datetime, timezone
18from typing import Any, AsyncGenerator, Dict, List, Optional, Union
20# Third-Party
21from pydantic import ValidationError
22from sqlalchemy import and_, delete, desc, or_, select
23from sqlalchemy.exc import IntegrityError
24from sqlalchemy.orm import Session
26# First-Party
27from mcpgateway.cache.a2a_stats_cache import a2a_stats_cache
28from mcpgateway.db import A2AAgent as DbA2AAgent
29from mcpgateway.db import A2AAgentMetric, A2AAgentMetricsHourly, EmailTeam, fresh_db_session, get_for_update
30from mcpgateway.schemas import A2AAgentCreate, A2AAgentMetrics, A2AAgentRead, A2AAgentUpdate
31from mcpgateway.services.logging_service import LoggingService
32from mcpgateway.services.metrics_cleanup_service import delete_metrics_in_batches, pause_rollup_during_purge
33from mcpgateway.services.structured_logger import get_structured_logger
34from mcpgateway.services.team_management_service import TeamManagementService
35from mcpgateway.utils.correlation_id import get_correlation_id
36from mcpgateway.utils.create_slug import slugify
37from mcpgateway.utils.pagination import unified_paginate
38from mcpgateway.utils.services_auth import decode_auth, encode_auth
39from mcpgateway.utils.sqlalchemy_modifier import json_contains_tag_expr
# Module-level registry cache handle. It is resolved on first use rather than at
# import time to avoid a circular dependency with the cache package.
_REGISTRY_CACHE = None


def _get_registry_cache():
    """Return the registry cache singleton, importing it on the first call.

    Subsequent calls reuse the handle stored in the module-level ``_REGISTRY_CACHE``.

    Returns:
        RegistryCache instance.
    """
    global _REGISTRY_CACHE  # pylint: disable=global-statement
    if _REGISTRY_CACHE is not None:
        return _REGISTRY_CACHE

    # First-Party
    from mcpgateway.cache.registry_cache import registry_cache  # pylint: disable=import-outside-toplevel

    _REGISTRY_CACHE = registry_cache
    return _REGISTRY_CACHE


# Logging is configured before the rest of the module uses it.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

# Structured logger for A2A agent lifecycle events (e.g. registration).
structured_logger = get_structured_logger("a2a_service")
class A2AAgentError(Exception):
    """Root exception for A2A agent failures raised by this service.

    Catch this type to handle any A2A agent problem generically; the not-found
    and name-conflict errors in this module derive from it.

    Examples:
        >>> str(A2AAgentError("Agent operation failed"))
        'Agent operation failed'
        >>> try:
        ...     raise A2AAgentError("Connection error")
        ... except Exception as exc:
        ...     isinstance(exc, A2AAgentError)
        True
    """
class A2AAgentNotFoundError(A2AAgentError):
    """Signals that no accessible A2A agent matched the requested identifier.

    Because it derives from ``A2AAgentError``, generic handlers catch it too.

    Examples:
        >>> missing = A2AAgentNotFoundError("Agent 'test-agent' not found")
        >>> str(missing)
        "Agent 'test-agent' not found"
        >>> isinstance(A2AAgentNotFoundError("No such agent"), A2AAgentError)
        True
    """
class A2AAgentNameConflictError(A2AAgentError):
    """Raised when an A2A agent name conflicts with an existing one."""

    def __init__(self, name: str, is_active: bool = True, agent_id: Optional[str] = None, visibility: Optional[str] = "public"):
        """Initialize an A2AAgentNameConflictError exception.

        Creates an exception that indicates an agent name conflict, with additional
        context about whether the conflicting agent is active, its ID if known, and
        its visibility scope.

        Args:
            name: The agent name that caused the conflict.
            is_active: Whether the conflicting agent is currently active.
            agent_id: The ID of the conflicting agent, if known.
            visibility: The visibility level of the conflicting agent (private, team, public).
                May be ``None``, in which case the message carries no visibility prefix.

        Examples:
            >>> error = A2AAgentNameConflictError("test-agent")
            >>> error.name
            'test-agent'
            >>> error.is_active
            True
            >>> error.agent_id is None
            True
            >>> "test-agent" in str(error)
            True
            >>>
            >>> # Test inactive agent conflict
            >>> error = A2AAgentNameConflictError("inactive-agent", is_active=False, agent_id="agent-123")
            >>> error.is_active
            False
            >>> error.agent_id
            'agent-123'
            >>> "inactive" in str(error)
            True
            >>> "agent-123" in str(error)
            True
            >>>
            >>> # Visibility is Optional: None must not break message construction
            >>> error = A2AAgentNameConflictError("scoped-agent", visibility=None)
            >>> str(error)
            'A2A Agent already exists with name: scoped-agent'
            >>> error.visibility is None
            True
            >>> str(A2AAgentNameConflictError("team-agent", visibility="team"))
            'Team A2A Agent already exists with name: team-agent'
        """
        self.name = name
        self.is_active = is_active
        self.agent_id = agent_id
        self.visibility = visibility
        # Only prefix the visibility when one is given; calling .capitalize() on None would raise.
        prefix = f"{visibility.capitalize()} " if visibility else ""
        message = f"{prefix}A2A Agent already exists with name: {name}"
        if not is_active:
            message += f" (currently inactive, ID: {agent_id})"
        super().__init__(message)
148class A2AAgentService:
149 """Service for managing A2A agents in the gateway.
151 Provides methods to create, list, retrieve, update, set state, and delete agent records.
152 Also supports interactions with A2A-compatible agents.
153 """
155 def __init__(self) -> None:
156 """Initialize a new A2AAgentService instance."""
157 self._initialized = False
158 self._event_streams: List[AsyncGenerator[str, None]] = []
160 async def initialize(self) -> None:
161 """Initialize the A2A agent service."""
162 if not self._initialized:
163 logger.info("Initializing A2A Agent Service")
164 self._initialized = True
166 async def shutdown(self) -> None:
167 """Shutdown the A2A agent service and cleanup resources."""
168 if self._initialized:
169 logger.info("Shutting down A2A Agent Service")
170 self._initialized = False
172 def _get_team_name(self, db: Session, team_id: Optional[str]) -> Optional[str]:
173 """Retrieve the team name given a team ID.
175 Args:
176 db (Session): Database session for querying teams.
177 team_id (Optional[str]): The ID of the team.
179 Returns:
180 Optional[str]: The name of the team if found, otherwise None.
181 """
182 if not team_id:
183 return None
185 team = db.query(EmailTeam).filter(EmailTeam.id == team_id, EmailTeam.is_active.is_(True)).first()
186 db.commit() # Release transaction to avoid idle-in-transaction
187 return team.name if team else None
189 def _batch_get_team_names(self, db: Session, team_ids: List[str]) -> Dict[str, str]:
190 """Batch retrieve team names for multiple team IDs.
192 This method fetches team names in a single query to avoid N+1 issues
193 when converting multiple agents to schemas in list operations.
195 Args:
196 db (Session): Database session for querying teams.
197 team_ids (List[str]): List of team IDs to look up.
199 Returns:
200 Dict[str, str]: Mapping of team_id -> team_name for active teams.
201 """
202 if not team_ids:
203 return {}
205 # Single query for all teams
206 teams = db.query(EmailTeam.id, EmailTeam.name).filter(EmailTeam.id.in_(team_ids), EmailTeam.is_active.is_(True)).all()
208 return {team.id: team.name for team in teams}
210 def _check_agent_access(
211 self,
212 agent: DbA2AAgent,
213 user_email: Optional[str],
214 token_teams: Optional[List[str]],
215 ) -> bool:
216 """Check if user has access to agent based on visibility rules.
218 Access rules (matching tools/resources/prompts):
219 - token_teams is None: Admin bypass (unrestricted access)
220 - public visibility: Always allowed
221 - team visibility: Allowed if agent.team_id in token_teams
222 - private visibility: Allowed if owner, BUT NOT for public-only tokens
224 Args:
225 agent: The agent to check access for
226 user_email: User's email for owner matching
227 token_teams: Teams from JWT. None = admin, [] = public-only (no owner access)
229 Returns:
230 True if access allowed, False otherwise.
231 """
232 # Admin bypass - token_teams is None means unrestricted access
233 if token_teams is None:
234 return True
236 if agent.visibility == "public":
237 return True
239 if agent.visibility == "team" and token_teams:
240 return agent.team_id in token_teams
242 # Private visibility: owner can access, BUT NOT for public-only tokens
243 # Public-only tokens (empty teams array) should NOT get owner access
244 is_public_only_token = len(token_teams) == 0
245 if agent.visibility == "private" and user_email and not is_public_only_token:
246 return agent.owner_email == user_email
248 return False
250 def _apply_visibility_filter(
251 self,
252 query,
253 user_email: Optional[str],
254 token_teams: List[str],
255 team_id: Optional[str] = None,
256 ) -> Any:
257 """Apply visibility-based access control to query.
259 Access rules (matching tools/resources/prompts):
260 - public: visible to all
261 - team: visible to team members (token_teams contains team_id)
262 - private: visible only to owner, BUT NOT for public-only tokens
264 Args:
265 query: SQLAlchemy query to filter
266 user_email: User's email for owner matching
267 token_teams: Teams from JWT. [] = public-only (no owner access)
268 team_id: Optional specific team filter
270 Returns:
271 Filtered query
272 """
273 # Check if this is a public-only token (empty teams array)
274 # Public-only tokens can ONLY see public resources - no owner access
275 is_public_only_token = len(token_teams) == 0
277 if team_id:
278 # User requesting specific team - verify access
279 if team_id not in token_teams:
280 # Return query that matches nothing (will return empty result)
281 return query.where(False)
283 access_conditions = [
284 and_(DbA2AAgent.team_id == team_id, DbA2AAgent.visibility.in_(["team", "public"])),
285 ]
286 # Only include owner access for non-public-only tokens with user_email
287 if not is_public_only_token and user_email:
288 access_conditions.append(and_(DbA2AAgent.team_id == team_id, DbA2AAgent.owner_email == user_email))
289 return query.where(or_(*access_conditions))
291 # General access: public + team (+ owner if not public-only token)
292 access_conditions = [DbA2AAgent.visibility == "public"]
294 # Only include owner access for non-public-only tokens with user_email
295 if not is_public_only_token and user_email:
296 access_conditions.append(DbA2AAgent.owner_email == user_email)
298 if token_teams:
299 access_conditions.append(and_(DbA2AAgent.team_id.in_(token_teams), DbA2AAgent.visibility.in_(["team", "public"])))
301 return query.where(or_(*access_conditions))
303 async def register_agent(
304 self,
305 db: Session,
306 agent_data: A2AAgentCreate,
307 created_by: Optional[str] = None,
308 created_from_ip: Optional[str] = None,
309 created_via: Optional[str] = None,
310 created_user_agent: Optional[str] = None,
311 import_batch_id: Optional[str] = None,
312 federation_source: Optional[str] = None,
313 team_id: Optional[str] = None,
314 owner_email: Optional[str] = None,
315 visibility: Optional[str] = "public",
316 ) -> A2AAgentRead:
317 """Register a new A2A agent.
319 Args:
320 db (Session): Database session.
321 agent_data (A2AAgentCreate): Data required to create an agent.
322 created_by (Optional[str]): User who created the agent.
323 created_from_ip (Optional[str]): IP address of the creator.
324 created_via (Optional[str]): Method used for creation (e.g., API, import).
325 created_user_agent (Optional[str]): User agent of the creation request.
326 import_batch_id (Optional[str]): UUID of a bulk import batch.
327 federation_source (Optional[str]): Source gateway for federated agents.
328 team_id (Optional[str]): ID of the team to assign the agent to.
329 owner_email (Optional[str]): Email of the agent owner.
330 visibility (Optional[str]): Visibility level ('public', 'team', 'private').
332 Returns:
333 A2AAgentRead: The created agent object.
335 Raises:
336 A2AAgentNameConflictError: If another agent with the same name already exists.
337 IntegrityError: If a database constraint or integrity violation occurs.
338 ValueError: If invalid configuration or data is provided.
339 A2AAgentError: For any other unexpected errors during registration.
341 Examples:
342 # TODO
343 """
344 try:
345 agent_data.slug = slugify(agent_data.name)
346 # Check for existing server with the same slug within the same team or public scope
347 if visibility.lower() == "public":
348 logger.info(f"visibility.lower(): {visibility.lower()}")
349 logger.info(f"agent_data.name: {agent_data.name}")
350 logger.info(f"agent_data.slug: {agent_data.slug}")
351 # Check for existing public a2a agent with the same slug
352 existing_agent = get_for_update(db, DbA2AAgent, where=and_(DbA2AAgent.slug == agent_data.slug, DbA2AAgent.visibility == "public"))
353 if existing_agent:
354 raise A2AAgentNameConflictError(name=agent_data.slug, is_active=existing_agent.enabled, agent_id=existing_agent.id, visibility=existing_agent.visibility)
355 elif visibility.lower() == "team" and team_id:
356 # Check for existing team a2a agent with the same slug
357 existing_agent = get_for_update(db, DbA2AAgent, where=and_(DbA2AAgent.slug == agent_data.slug, DbA2AAgent.visibility == "team", DbA2AAgent.team_id == team_id))
358 if existing_agent:
359 raise A2AAgentNameConflictError(name=agent_data.slug, is_active=existing_agent.enabled, agent_id=existing_agent.id, visibility=existing_agent.visibility)
361 auth_type = getattr(agent_data, "auth_type", None)
362 # Support multiple custom headers
363 auth_value = getattr(agent_data, "auth_value", {})
365 # authentication_headers: Optional[Dict[str, str]] = None
367 if hasattr(agent_data, "auth_headers") and agent_data.auth_headers:
368 # Convert list of {key, value} to dict
369 header_dict = {h["key"]: h["value"] for h in agent_data.auth_headers if h.get("key")}
370 # Keep encoded form for persistence, but pass raw headers for initialization
371 auth_value = encode_auth(header_dict) # Encode the dict for consistency
372 # authentication_headers = {str(k): str(v) for k, v in header_dict.items()}
373 # elif isinstance(auth_value, str) and auth_value:
374 # # Decode persisted auth for initialization
375 # decoded = decode_auth(auth_value)
376 # authentication_headers = {str(k): str(v) for k, v in decoded.items()}
377 else:
378 # authentication_headers = None
379 pass
380 # auth_value = {}
382 oauth_config = getattr(agent_data, "oauth_config", None)
384 # Handle query_param auth - encrypt and prepare for storage
385 auth_query_params_encrypted: Optional[Dict[str, str]] = None
386 if auth_type == "query_param":
387 # Standard
388 from urllib.parse import urlparse # pylint: disable=import-outside-toplevel
390 # First-Party
391 from mcpgateway.config import settings # pylint: disable=import-outside-toplevel
393 # Service-layer enforcement: Check feature flag
394 if not settings.insecure_allow_queryparam_auth:
395 raise ValueError("Query parameter authentication is disabled. Set INSECURE_ALLOW_QUERYPARAM_AUTH=true to enable.")
397 # Service-layer enforcement: Check host allowlist
398 if settings.insecure_queryparam_auth_allowed_hosts:
399 parsed = urlparse(str(agent_data.endpoint_url))
400 hostname = (parsed.hostname or "").lower()
401 allowed_hosts = [h.lower() for h in settings.insecure_queryparam_auth_allowed_hosts]
402 if hostname not in allowed_hosts:
403 allowed = ", ".join(settings.insecure_queryparam_auth_allowed_hosts)
404 raise ValueError(f"Host '{hostname}' is not in the allowed hosts for query param auth. " f"Allowed: {allowed}")
406 # Extract and encrypt query param auth
407 param_key = getattr(agent_data, "auth_query_param_key", None)
408 param_value = getattr(agent_data, "auth_query_param_value", None)
409 if param_key and param_value:
410 # Handle SecretStr
411 if hasattr(param_value, "get_secret_value"):
412 raw_value = param_value.get_secret_value()
413 else:
414 raw_value = str(param_value)
415 # Encrypt for storage
416 encrypted_value = encode_auth({param_key: raw_value})
417 auth_query_params_encrypted = {param_key: encrypted_value}
418 # Query param auth doesn't use auth_value
419 auth_value = None
421 # Create new agent
422 new_agent = DbA2AAgent(
423 name=agent_data.name,
424 description=agent_data.description,
425 endpoint_url=agent_data.endpoint_url,
426 agent_type=agent_data.agent_type,
427 protocol_version=agent_data.protocol_version,
428 capabilities=agent_data.capabilities,
429 config=agent_data.config,
430 auth_type=auth_type,
431 auth_value=auth_value, # This should be encrypted in practice
432 auth_query_params=auth_query_params_encrypted, # Encrypted query param auth
433 oauth_config=oauth_config,
434 tags=agent_data.tags,
435 passthrough_headers=getattr(agent_data, "passthrough_headers", None),
436 # Team scoping fields - use schema values if provided, otherwise fallback to parameters
437 team_id=getattr(agent_data, "team_id", None) or team_id,
438 owner_email=getattr(agent_data, "owner_email", None) or owner_email or created_by,
439 # Endpoint visibility parameter takes precedence over schema default
440 visibility=visibility if visibility is not None else getattr(agent_data, "visibility", "public"),
441 created_by=created_by,
442 created_from_ip=created_from_ip,
443 created_via=created_via,
444 created_user_agent=created_user_agent,
445 import_batch_id=import_batch_id,
446 federation_source=federation_source,
447 )
449 db.add(new_agent)
450 # Commit agent FIRST to ensure it persists even if tool creation fails
451 # This is critical because ToolService.register_tool calls db.rollback()
452 # on error, which would undo a pending (flushed but uncommitted) agent
453 db.commit()
454 db.refresh(new_agent)
456 # Invalidate caches since agent count changed
457 # Wrapped in try/except to ensure cache failures don't fail the request
458 # when the agent is already successfully committed
459 try:
460 a2a_stats_cache.invalidate()
461 cache = _get_registry_cache()
462 await cache.invalidate_agents()
463 # Also invalidate tags cache since agent tags may have changed
464 # First-Party
465 from mcpgateway.cache.admin_stats_cache import admin_stats_cache # pylint: disable=import-outside-toplevel
467 await admin_stats_cache.invalidate_tags()
468 # First-Party
469 from mcpgateway.cache.metrics_cache import metrics_cache # pylint: disable=import-outside-toplevel
471 metrics_cache.invalidate("a2a")
472 except Exception as cache_error:
473 logger.warning(f"Cache invalidation failed after agent commit: {cache_error}")
475 # Automatically create a tool for the A2A agent if not already present
476 # Tool creation is wrapped in try/except to ensure agent registration succeeds
477 # even if tool creation fails (e.g., due to visibility or permission issues)
478 tool_db = None
479 try:
480 # First-Party
481 from mcpgateway.services.tool_service import tool_service
483 tool_db = await tool_service.create_tool_from_a2a_agent(
484 db=db,
485 agent=new_agent,
486 created_by=created_by,
487 created_from_ip=created_from_ip,
488 created_via=created_via,
489 created_user_agent=created_user_agent,
490 )
492 # Associate the tool with the agent using the relationship
493 # This sets both the tool_id foreign key and the tool relationship
494 new_agent.tool = tool_db
495 db.commit()
496 db.refresh(new_agent)
497 logger.info(f"Registered new A2A agent: {new_agent.name} (ID: {new_agent.id}) with tool ID: {tool_db.id}")
498 except Exception as tool_error:
499 # Log the error but don't fail agent registration
500 # Agent was already committed above, so it persists even if tool creation fails
501 logger.warning(f"Failed to create tool for A2A agent {new_agent.name}: {tool_error}")
502 structured_logger.warning(
503 f"A2A agent '{new_agent.name}' created without tool association",
504 user_id=created_by,
505 resource_type="a2a_agent",
506 resource_id=str(new_agent.id),
507 custom_fields={"error": str(tool_error), "agent_name": new_agent.name},
508 )
509 # Refresh the agent to ensure it's in a clean state after any rollback
510 db.refresh(new_agent)
511 logger.info(f"Registered new A2A agent: {new_agent.name} (ID: {new_agent.id}) without tool")
513 # Log A2A agent registration for lifecycle tracking
514 structured_logger.info(
515 f"A2A agent '{new_agent.name}' registered successfully",
516 user_id=created_by,
517 user_email=owner_email,
518 team_id=team_id,
519 resource_type="a2a_agent",
520 resource_id=str(new_agent.id),
521 resource_action="create",
522 custom_fields={
523 "agent_name": new_agent.name,
524 "agent_type": new_agent.agent_type,
525 "protocol_version": new_agent.protocol_version,
526 "visibility": visibility,
527 "endpoint_url": new_agent.endpoint_url,
528 },
529 )
531 return self.convert_agent_to_read(new_agent, db=db)
533 except A2AAgentNameConflictError as ie:
534 db.rollback()
535 raise ie
536 except IntegrityError as ie:
537 db.rollback()
538 logger.error(f"IntegrityErrors in group: {ie}")
539 raise ie
540 except ValueError as ve:
541 raise ve
542 except Exception as e:
543 db.rollback()
544 raise A2AAgentError(f"Failed to register A2A agent: {str(e)}")
546 async def list_agents(
547 self,
548 db: Session,
549 cursor: Optional[str] = None,
550 include_inactive: bool = False,
551 tags: Optional[List[str]] = None,
552 limit: Optional[int] = None,
553 page: Optional[int] = None,
554 per_page: Optional[int] = None,
555 user_email: Optional[str] = None,
556 token_teams: Optional[List[str]] = None,
557 team_id: Optional[str] = None,
558 visibility: Optional[str] = None,
559 ) -> Union[tuple[List[A2AAgentRead], Optional[str]], Dict[str, Any]]:
560 """List A2A agents with cursor pagination and optional team filtering.
562 Args:
563 db: Database session.
564 cursor: Pagination cursor for keyset pagination.
565 include_inactive: Whether to include inactive agents.
566 tags: List of tags to filter by.
567 limit: Maximum number of agents to return. None for default, 0 for unlimited.
568 page: Page number for page-based pagination (1-indexed). Mutually exclusive with cursor.
569 per_page: Items per page for page-based pagination. Defaults to pagination_default_page_size.
570 user_email: Email of user for owner matching in visibility checks.
571 token_teams: Teams from JWT token. None = admin (no filtering),
572 [] = public-only, [...] = team-scoped access.
573 team_id: Optional team ID to filter by specific team.
574 visibility: Optional visibility filter (private, team, public).
576 Returns:
577 If page is provided: Dict with {"data": [...], "pagination": {...}, "links": {...}}
578 If cursor is provided or neither: tuple of (list of A2AAgentRead objects, next_cursor).
580 Examples:
581 >>> from mcpgateway.services.a2a_service import A2AAgentService
582 >>> from unittest.mock import MagicMock
583 >>> from mcpgateway.schemas import A2AAgentRead
584 >>> import asyncio
586 >>> service = A2AAgentService()
587 >>> db = MagicMock()
589 >>> # Mock a single agent object returned by the DB
590 >>> agent_obj = MagicMock()
591 >>> db.execute.return_value.scalars.return_value.all.return_value = [agent_obj]
593 >>> # Mock the A2AAgentRead schema to return a masked string
594 >>> mocked_agent_read = MagicMock()
595 >>> mocked_agent_read.masked.return_value = 'agent_read'
596 >>> A2AAgentRead.model_validate = MagicMock(return_value=mocked_agent_read)
598 >>> # Run the service method
599 >>> agents, cursor = asyncio.run(service.list_agents(db))
600 >>> agents == ['agent_read'] and cursor is None
601 True
603 >>> # Test include_inactive parameter (same mock works)
604 >>> agents_with_inactive, cursor = asyncio.run(service.list_agents(db, include_inactive=True))
605 >>> agents_with_inactive == ['agent_read'] and cursor is None
606 True
608 >>> # Test empty result
609 >>> db.execute.return_value.scalars.return_value.all.return_value = []
610 >>> empty_agents, cursor = asyncio.run(service.list_agents(db))
611 >>> empty_agents == [] and cursor is None
612 True
614 """
615 # ══════════════════════════════════════════════════════════════════════
616 # CACHE READ: Skip cache when ANY access filtering is applied
617 # This prevents leaking admin-level results to filtered requests
618 # Cache only when: user_email is None AND token_teams is None AND page is None
619 # ══════════════════════════════════════════════════════════════════════
620 cache = _get_registry_cache()
621 if cursor is None and user_email is None and token_teams is None and page is None:
622 filters_hash = cache.hash_filters(include_inactive=include_inactive, tags=sorted(tags) if tags else None)
623 cached = await cache.get("agents", filters_hash)
624 if cached is not None:
625 # Reconstruct A2AAgentRead objects from cached dicts
626 cached_agents = [A2AAgentRead.model_validate(a) for a in cached["agents"]]
627 return (cached_agents, cached.get("next_cursor"))
629 # Build base query with ordering
630 query = select(DbA2AAgent).order_by(desc(DbA2AAgent.created_at), desc(DbA2AAgent.id))
632 # Apply active/inactive filter
633 if not include_inactive:
634 query = query.where(DbA2AAgent.enabled)
636 # Apply team-based access control if user_email is provided OR token_teams is explicitly set
637 # This ensures unauthenticated requests with token_teams=[] only see public agents
638 if user_email or token_teams is not None:
639 # Use token_teams if provided (for MCP/API token access), otherwise look up from DB
640 # Default is public-only access (empty teams) when no teams are available.
641 effective_teams: List[str] = []
642 if token_teams is not None:
643 effective_teams = token_teams
644 elif user_email: 644 ↛ 650line 644 didn't jump to line 650 because the condition on line 644 was always true
645 # Look up user's teams from DB (for admin UI / first-party access)
646 team_service = TeamManagementService(db)
647 user_teams = await team_service.get_user_teams(user_email)
648 effective_teams = [team.id for team in user_teams]
650 query = self._apply_visibility_filter(query, user_email, effective_teams, team_id)
652 # IMPORTANT: Apply visibility filter AFTER access control
653 # This allows users to further filter by visibility within their allowed access
654 if visibility:
655 query = query.where(DbA2AAgent.visibility == visibility)
657 # Add tag filtering if tags are provided (supports both List[str] and List[Dict] formats)
658 if tags:
659 query = query.where(json_contains_tag_expr(db, DbA2AAgent.tags, tags, match_any=True))
661 # Use unified pagination helper - handles both page and cursor pagination
662 pag_result = await unified_paginate(
663 db=db,
664 query=query,
665 page=page,
666 per_page=per_page,
667 cursor=cursor,
668 limit=limit,
669 base_url="/admin/a2a", # Used for page-based links
670 query_params={"include_inactive": include_inactive} if include_inactive else {},
671 )
673 next_cursor = None
674 # Extract servers based on pagination type
675 if page is not None:
676 # Page-based: pag_result is a dict
677 a2a_agents_db = pag_result["data"]
678 else:
679 # Cursor-based: pag_result is a tuple
680 a2a_agents_db, next_cursor = pag_result
682 # Fetch team names for the agents (common for both pagination types)
683 team_ids_set = {s.team_id for s in a2a_agents_db if s.team_id}
684 team_map = {}
685 if team_ids_set:
686 teams = db.execute(select(EmailTeam.id, EmailTeam.name).where(EmailTeam.id.in_(team_ids_set), EmailTeam.is_active.is_(True))).all()
687 team_map = {team.id: team.name for team in teams}
689 db.commit() # Release transaction to avoid idle-in-transaction
691 # Convert to A2AAgentRead (common for both pagination types)
692 result = []
693 for s in a2a_agents_db:
694 try:
695 s.team = team_map.get(s.team_id) if s.team_id else None
696 result.append(self.convert_agent_to_read(s, include_metrics=False, db=db, team_map=team_map))
697 except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e:
698 logger.exception(f"Failed to convert A2A agent {getattr(s, 'id', 'unknown')} ({getattr(s, 'name', 'unknown')}): {e}")
699 # Continue with remaining agents instead of failing completely
701 # Return appropriate format based on pagination type
702 if page is not None:
703 # Page-based format
704 return {
705 "data": result,
706 "pagination": pag_result["pagination"],
707 "links": pag_result["links"],
708 }
710 # Cursor-based format
712 # ══════════════════════════════════════════════════════════════════════
713 # CACHE WRITE: Only cache admin-level results (matches read guard)
714 # MUST check token_teams is None to prevent caching scoped responses
715 # ══════════════════════════════════════════════════════════════════════
716 if cursor is None and user_email is None and token_teams is None:
717 try:
718 cache_data = {"agents": [s.model_dump(mode="json") for s in result], "next_cursor": next_cursor}
719 await cache.set("agents", cache_data, filters_hash)
720 except AttributeError:
721 pass # Skip caching if result objects don't support model_dump (e.g., in doctests)
723 return (result, next_cursor)
725 async def list_agents_for_user(
726 self, db: Session, user_info: Dict[str, Any], team_id: Optional[str] = None, visibility: Optional[str] = None, include_inactive: bool = False, skip: int = 0, limit: int = 100
727 ) -> List[A2AAgentRead]:
728 """
729 DEPRECATED: Use list_agents() with user_email parameter instead.
731 This method is maintained for backward compatibility but is no longer used.
732 New code should call list_agents() with user_email, team_id, and visibility parameters.
734 List A2A agents user has access to with team filtering.
736 Args:
737 db: Database session
738 user_info: Object representing identity of the user who is requesting agents
739 team_id: Optional team ID to filter by specific team
740 visibility: Optional visibility filter (private, team, public)
741 include_inactive: Whether to include inactive agents
742 skip: Number of agents to skip for pagination
743 limit: Maximum number of agents to return
745 Returns:
746 List[A2AAgentRead]: A2A agents the user has access to
747 """
749 # Handle case where user_info is a string (email) instead of dict (<0.7.0)
750 if isinstance(user_info, str):
751 user_email = str(user_info)
752 else:
753 user_email = user_info.get("email", "")
755 # Build query following existing patterns from list_prompts()
756 team_service = TeamManagementService(db)
757 user_teams = await team_service.get_user_teams(user_email)
758 team_ids = [team.id for team in user_teams]
760 # Build query following existing patterns from list_agents()
761 query = select(DbA2AAgent)
763 # Apply active/inactive filter
764 if not include_inactive:
765 query = query.where(DbA2AAgent.enabled.is_(True))
767 if team_id:
768 if team_id not in team_ids:
769 return [] # No access to team
771 access_conditions = []
772 # Filter by specific team
773 access_conditions.append(and_(DbA2AAgent.team_id == team_id, DbA2AAgent.visibility.in_(["team", "public"])))
775 access_conditions.append(and_(DbA2AAgent.team_id == team_id, DbA2AAgent.owner_email == user_email))
777 query = query.where(or_(*access_conditions))
778 else:
779 # Get user's accessible teams
780 # Build access conditions following existing patterns
781 access_conditions = []
782 # 1. User's personal resources (owner_email matches)
783 access_conditions.append(DbA2AAgent.owner_email == user_email)
784 # 2. Team A2A Agents where user is member
785 if team_ids:
786 access_conditions.append(and_(DbA2AAgent.team_id.in_(team_ids), DbA2AAgent.visibility.in_(["team", "public"])))
787 # 3. Public resources (if visibility allows)
788 access_conditions.append(DbA2AAgent.visibility == "public")
790 query = query.where(or_(*access_conditions))
792 # Apply visibility filter if specified
793 if visibility:
794 query = query.where(DbA2AAgent.visibility == visibility)
796 # Apply pagination following existing patterns
797 query = query.order_by(desc(DbA2AAgent.created_at))
798 query = query.offset(skip).limit(limit)
800 agents = db.execute(query).scalars().all()
802 # Batch fetch team names to avoid N+1 queries
803 team_ids = list({a.team_id for a in agents if a.team_id})
804 team_map = self._batch_get_team_names(db, team_ids)
806 db.commit() # Release transaction to avoid idle-in-transaction
808 # Skip metrics to avoid N+1 queries in list operations
809 result = []
810 for agent in agents:
811 try:
812 result.append(self.convert_agent_to_read(agent, include_metrics=False, db=db, team_map=team_map))
813 except (ValidationError, ValueError, KeyError, TypeError, binascii.Error) as e:
814 logger.exception(f"Failed to convert A2A agent {getattr(agent, 'id', 'unknown')} ({getattr(agent, 'name', 'unknown')}): {e}")
815 # Continue with remaining agents instead of failing completely
817 return result
async def get_agent(
    self,
    db: Session,
    agent_id: str,
    include_inactive: bool = True,
    user_email: Optional[str] = None,
    token_teams: Optional[List[str]] = None,
) -> A2AAgentRead:
    """Retrieve an A2A agent by ID.

    Args:
        db: Database session.
        agent_id: Agent ID.
        include_inactive: Whether to include inactive a2a agents.
        user_email: User's email for owner matching in visibility checks.
        token_teams: Teams from JWT token. None = admin (no filtering),
            [] = public-only, [...] = team-scoped access.

    Returns:
        Agent data.

    Raises:
        A2AAgentNotFoundError: If the agent does not exist, is disabled while
            ``include_inactive`` is False, or the caller lacks visibility/team
            access (reported as "not found" so private agents are not revealed).

    Examples:
        >>> from unittest.mock import MagicMock
        >>> import asyncio
        >>> from mcpgateway.services.a2a_service import A2AAgentService, A2AAgentNotFoundError

        >>> service = A2AAgentService()
        >>> db = MagicMock()

        >>> # Minimal agent: only the attributes read before conversion matter
        >>> agent_mock = MagicMock()
        >>> agent_mock.enabled = True
        >>> agent_mock.visibility = "public"
        >>> agent_mock.team_id = "team-1"
        >>> agent_mock.owner_email = "owner@example.com"

        >>> # The lookup goes through db.execute(select(...)).scalar_one_or_none()
        >>> db.execute.return_value.scalar_one_or_none.return_value = agent_mock

        >>> # Mock convert_agent_to_read to simplify test
        >>> service.convert_agent_to_read = lambda db_agent, **kwargs: 'agent_read'

        >>> # Active agent
        >>> asyncio.run(service.get_agent(db, 'agent_id'))
        'agent_read'

        >>> # Inactive agent is returned when include_inactive=True
        >>> agent_mock.enabled = False
        >>> asyncio.run(service.get_agent(db, 'agent_id', include_inactive=True))
        'agent_read'

        >>> # ...and hidden when include_inactive=False
        >>> try:
        ...     asyncio.run(service.get_agent(db, 'agent_id', include_inactive=False))
        ... except A2AAgentNotFoundError:
        ...     print("not found")
        not found

        >>> # Unknown ID
        >>> db.execute.return_value.scalar_one_or_none.return_value = None
        >>> try:
        ...     asyncio.run(service.get_agent(db, 'missing'))
        ... except A2AAgentNotFoundError:
        ...     print("not found")
        not found
    """
    query = select(DbA2AAgent).where(DbA2AAgent.id == agent_id)
    agent = db.execute(query).scalar_one_or_none()

    if not agent:
        raise A2AAgentNotFoundError(f"A2A Agent not found with ID: {agent_id}")

    if not agent.enabled and not include_inactive:
        raise A2AAgentNotFoundError(f"A2A Agent not found with ID: {agent_id}")

    # SECURITY: Check visibility/team access
    # Return 404 (not 403) to avoid leaking existence of private agents
    if not self._check_agent_access(agent, user_email, token_teams):
        raise A2AAgentNotFoundError(f"A2A Agent not found with ID: {agent_id}")

    # Delegate conversion and masking to convert_agent_to_read()
    return self.convert_agent_to_read(agent, db=db)
async def get_agent_by_name(self, db: Session, agent_name: str) -> A2AAgentRead:
    """Fetch the single A2A agent whose name equals ``agent_name``.

    No enabled-state or visibility filtering is applied here.

    Args:
        db: Database session.
        agent_name: Exact agent name to look up.

    Returns:
        The agent, converted through ``convert_agent_to_read``.

    Raises:
        A2AAgentNotFoundError: If no agent carries that name.
    """
    # NOTE(review): scalar_one_or_none() raises if several rows share the name.
    found = db.execute(select(DbA2AAgent).where(DbA2AAgent.name == agent_name)).scalar_one_or_none()
    if not found:
        raise A2AAgentNotFoundError(f"A2A Agent not found with name: {agent_name}")
    return self.convert_agent_to_read(found, db=db)
async def update_agent(
    self,
    db: Session,
    agent_id: str,
    agent_data: A2AAgentUpdate,
    modified_by: Optional[str] = None,
    modified_from_ip: Optional[str] = None,
    modified_via: Optional[str] = None,
    modified_user_agent: Optional[str] = None,
    user_email: Optional[str] = None,
) -> A2AAgentRead:
    """Update an existing A2A agent.

    Args:
        db: Database session.
        agent_id: Agent ID.
        agent_data: Agent update data. Only fields explicitly set on the model
            (``model_fields_set``) are applied.
        modified_by: Username who modified this agent.
        modified_from_ip: IP address of modifier.
        modified_via: Modification method.
        modified_user_agent: User agent of modification request.
        user_email: Email of user performing update (for ownership check).

    Returns:
        Updated agent data.

    Raises:
        A2AAgentNotFoundError: If the agent is not found.
        PermissionError: If user doesn't own the agent.
        A2AAgentNameConflictError: If the new name's slug is already used by a
            different agent in the same scope (public, or the same team).
        IntegrityError: If a database integrity error occurs.
        A2AAgentError: For any other error during update. Query-param auth policy
            violations (feature flag disabled, or host not in the allowlist) and
            invalid ``passthrough_headers`` are raised internally and surface as
            ``A2AAgentError`` through the generic handler below.
    """
    try:
        # Acquire row lock for update to avoid lost-update on `version` and other fields
        agent = get_for_update(db, DbA2AAgent, agent_id)

        if not agent:
            raise A2AAgentNotFoundError(f"A2A Agent not found with ID: {agent_id}")

        # Check ownership if user_email provided
        if user_email:
            # First-Party
            from mcpgateway.services.permission_service import PermissionService  # pylint: disable=import-outside-toplevel

            permission_service = PermissionService(db)
            if not await permission_service.check_resource_ownership(user_email, agent):
                raise PermissionError("Only the owner can update this agent")

        # Check for name conflict if name is being updated
        if agent_data.name and agent_data.name != agent.name:
            new_slug = slugify(agent_data.name)
            visibility = agent_data.visibility or agent.visibility
            team_id = agent_data.team_id or agent.team_id
            # Exclude the agent being updated: a rename that slugifies to its
            # current slug (e.g. a letter-case-only change) must not conflict with itself.
            is_other_agent = DbA2AAgent.id != agent.id
            if visibility.lower() == "public":
                # Check for existing public a2a agent with the same slug
                existing_agent = get_for_update(db, DbA2AAgent, where=and_(DbA2AAgent.slug == new_slug, DbA2AAgent.visibility == "public", is_other_agent))
                if existing_agent:
                    raise A2AAgentNameConflictError(name=new_slug, is_active=existing_agent.enabled, agent_id=existing_agent.id, visibility=existing_agent.visibility)
            elif visibility.lower() == "team" and team_id:
                # Check for existing team a2a agent with the same slug
                existing_agent = get_for_update(
                    db,
                    DbA2AAgent,
                    where=and_(DbA2AAgent.slug == new_slug, DbA2AAgent.visibility == "team", DbA2AAgent.team_id == team_id, is_other_agent),
                )
                if existing_agent:
                    raise A2AAgentNameConflictError(name=new_slug, is_active=existing_agent.enabled, agent_id=existing_agent.id, visibility=existing_agent.visibility)
            # Update the slug when name changes
            agent.slug = new_slug

        # Avoid `model_dump()` here: tests use `model_construct()` to create intentionally invalid
        # payloads, and `model_dump()` emits serializer warnings when encountering unexpected types.
        update_data = {field: getattr(agent_data, field) for field in agent_data.model_fields_set}

        # Track original auth_type and endpoint_url before updates
        original_auth_type = agent.auth_type
        original_endpoint_url = agent.endpoint_url

        for field, value in update_data.items():
            if field == "passthrough_headers":
                if value is None:
                    # Explicitly clear when None is sent
                    agent.passthrough_headers = None
                elif isinstance(value, list):
                    # Clean list: remove empty or whitespace-only entries
                    cleaned = [h.strip() for h in value if isinstance(h, str) and h.strip()]
                    agent.passthrough_headers = cleaned or None
                elif isinstance(value, str):
                    # Parse comma-separated string and clean
                    parsed_headers = [h.strip() for h in value.split(",") if h.strip()]
                    agent.passthrough_headers = parsed_headers or None
                else:
                    raise A2AAgentError("Invalid passthrough_headers format: must be list[str] or comma-separated string")
                continue

            # Skip query_param fields - handled separately below
            if field in ("auth_query_param_key", "auth_query_param_value"):
                continue

            if hasattr(agent, field):
                setattr(agent, field, value)

        # Clear auth_query_params when switching away from query_param auth
        if original_auth_type == "query_param" and agent_data.auth_type is not None and agent_data.auth_type != "query_param":
            agent.auth_query_params = None
            logger.debug(f"Cleared auth_query_params for agent {agent.id} (switched from query_param to {agent_data.auth_type})")

        # Handle switching to query_param auth or updating existing query_param credentials
        is_switching_to_queryparam = agent_data.auth_type == "query_param" and original_auth_type != "query_param"
        is_updating_queryparam_creds = original_auth_type == "query_param" and (agent_data.auth_query_param_key is not None or agent_data.auth_query_param_value is not None)
        is_url_changing = agent_data.endpoint_url is not None and str(agent_data.endpoint_url) != original_endpoint_url
        touches_queryparam_creds = is_switching_to_queryparam or is_updating_queryparam_creds

        if touches_queryparam_creds or (is_url_changing and original_auth_type == "query_param"):
            # Standard
            from urllib.parse import urlparse  # pylint: disable=import-outside-toplevel

            # First-Party
            from mcpgateway.config import settings  # pylint: disable=import-outside-toplevel

            # Service-layer enforcement: feature flag. Existing query_param agents are
            # grandfathered (e.g. URL-only edits) unless credentials are being set or changed.
            if not settings.insecure_allow_queryparam_auth and touches_queryparam_creds:
                raise ValueError("Query parameter authentication is disabled. Set INSECURE_ALLOW_QUERYPARAM_AUTH=true to enable.")

            # Service-layer enforcement: host allowlist
            if settings.insecure_queryparam_auth_allowed_hosts:
                check_url = str(agent_data.endpoint_url) if agent_data.endpoint_url else agent.endpoint_url
                parsed_url = urlparse(check_url)
                hostname = (parsed_url.hostname or "").lower()
                allowed_hosts = [h.lower() for h in settings.insecure_queryparam_auth_allowed_hosts]
                if hostname not in allowed_hosts:
                    allowed = ", ".join(settings.insecure_queryparam_auth_allowed_hosts)
                    raise ValueError(f"Host '{hostname}' is not in the allowed hosts for query param auth. " f"Allowed: {allowed}")

            if touches_queryparam_creds:
                param_key = getattr(agent_data, "auth_query_param_key", None)
                param_value = getattr(agent_data, "auth_query_param_value", None)

                # If no key provided but value is, reuse existing key (value-only rotation)
                existing_key = next(iter(agent.auth_query_params.keys()), None) if agent.auth_query_params else None
                if not param_key and param_value and existing_key:
                    param_key = existing_key

                if param_key:
                    # Detect the masked placeholder (user did not change the secret)
                    is_masked_placeholder = False
                    if param_value and hasattr(param_value, "get_secret_value"):
                        raw_value = param_value.get_secret_value()
                        is_masked_placeholder = raw_value == settings.masked_auth_value
                    elif param_value:
                        raw_value = str(param_value)
                    else:
                        raw_value = None

                    if raw_value and not is_masked_placeholder:
                        # New value provided - encrypt for storage
                        encrypted_value = encode_auth({param_key: raw_value})
                        agent.auth_query_params = {param_key: encrypted_value}
                    elif agent.auth_query_params and is_masked_placeholder:
                        # Value unchanged but key may have been renamed: re-encrypt the
                        # existing secret under the new key.
                        if existing_key and existing_key != param_key:
                            existing_encrypted = agent.auth_query_params.get(existing_key, "")
                            if existing_encrypted:
                                decrypted = decode_auth(existing_encrypted)
                                existing_value = decrypted.get(existing_key, "")
                                if existing_value:
                                    encrypted_value = encode_auth({param_key: existing_value})
                                    agent.auth_query_params = {param_key: encrypted_value}

                # Update auth_type if switching
                if is_switching_to_queryparam:
                    agent.auth_type = "query_param"
                    agent.auth_value = None  # Query param auth doesn't use auth_value

        # Update metadata
        if modified_by:
            agent.modified_by = modified_by
        if modified_from_ip:
            agent.modified_from_ip = modified_from_ip
        if modified_via:
            agent.modified_via = modified_via
        if modified_user_agent:
            agent.modified_user_agent = modified_user_agent

        agent.version += 1

        db.commit()
        db.refresh(agent)

        # Invalidate cache after successful update
        cache = _get_registry_cache()
        await cache.invalidate_agents()
        # Also invalidate tags cache since agent tags may have changed
        # First-Party
        from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

        await admin_stats_cache.invalidate_tags()

        # Sync the associated tool. The agent update is the primary operation and
        # has already been committed, so a tool-sync failure is only logged.
        try:
            # First-Party
            from mcpgateway.services.tool_service import tool_service

            await tool_service.update_tool_from_a2a_agent(
                db=db,
                agent=agent,
                modified_by=modified_by,
                modified_from_ip=modified_from_ip,
                modified_via=modified_via,
                modified_user_agent=modified_user_agent,
            )
        except Exception as tool_err:
            logger.warning(f"Failed to sync tool for A2A agent {agent.id}: {tool_err}. Agent update succeeded but tool may be out of sync.")

        logger.info(f"Updated A2A agent: {agent.name} (ID: {agent.id})")
        return self.convert_agent_to_read(agent, db=db)
    except (PermissionError, A2AAgentNameConflictError, A2AAgentNotFoundError):
        db.rollback()
        raise
    except IntegrityError as ie:
        db.rollback()
        logger.error(f"IntegrityError while updating A2A agent {agent_id}: {ie}")
        raise
    except Exception as e:
        db.rollback()
        raise A2AAgentError(f"Failed to update A2A agent: {str(e)}") from e
async def set_agent_state(self, db: Session, agent_id: str, activate: bool, reachable: Optional[bool] = None, user_email: Optional[str] = None) -> A2AAgentRead:
    """Set the activation status of an A2A agent.

    Args:
        db: Database session.
        agent_id: Agent ID.
        activate: True to activate, False to deactivate.
        reachable: Optional reachability status; left unchanged when None.
        user_email: Optional[str] The email of the user to check if the user has permission to modify.

    Returns:
        Updated agent data.

    Raises:
        A2AAgentNotFoundError: If the agent is not found.
        PermissionError: If user doesn't own the agent.
    """
    try:
        query = select(DbA2AAgent).where(DbA2AAgent.id == agent_id)
        agent = db.execute(query).scalar_one_or_none()

        if not agent:
            raise A2AAgentNotFoundError(f"A2A Agent not found with ID: {agent_id}")

        if user_email:
            # First-Party
            from mcpgateway.services.permission_service import PermissionService  # pylint: disable=import-outside-toplevel

            permission_service = PermissionService(db)
            if not await permission_service.check_resource_ownership(user_email, agent):
                raise PermissionError("Only the owner can activate the Agent" if activate else "Only the owner can deactivate the Agent")
    except (A2AAgentNotFoundError, PermissionError):
        # Rejected before any mutation: end the open read transaction so the
        # connection is not left idle-in-transaction (mirrors update_agent/delete_agent).
        db.rollback()
        raise

    agent.enabled = activate
    if reachable is not None:
        agent.reachable = reachable

    db.commit()
    db.refresh(agent)

    # Invalidate caches since agent status changed
    a2a_stats_cache.invalidate()
    cache = _get_registry_cache()
    await cache.invalidate_agents()

    status = "activated" if activate else "deactivated"
    logger.info(f"A2A agent {status}: {agent.name} (ID: {agent.id})")

    structured_logger.log(
        level="INFO",
        message=f"A2A agent {status}",
        event_type="a2a_agent_status_changed",
        component="a2a_service",
        user_email=user_email,
        resource_type="a2a_agent",
        resource_id=str(agent.id),
        custom_fields={
            "agent_name": agent.name,
            "enabled": agent.enabled,
            "reachable": agent.reachable,
        },
    )

    return self.convert_agent_to_read(agent, db=db)
async def delete_agent(self, db: Session, agent_id: str, user_email: Optional[str] = None, purge_metrics: bool = False) -> None:
    """Delete an A2A agent.

    Deletes the associated tool through ``tool_service``, optionally purges
    raw and hourly-rollup metrics, deletes the agent row, commits, and then
    invalidates the agent, stats and tag caches.

    Args:
        db: Database session.
        agent_id: Agent ID.
        user_email: Email of user performing delete (for ownership check).
        purge_metrics: If True, delete raw + rollup metrics for this agent.

    Raises:
        A2AAgentNotFoundError: If the agent is not found.
        PermissionError: If user doesn't own the agent.

    Any failure (including errors from tool cleanup, the metrics purge, or the
    commit) rolls back the session before the original exception propagates
    unchanged, so the caller's session remains usable.
    """
    try:
        query = select(DbA2AAgent).where(DbA2AAgent.id == agent_id)
        agent = db.execute(query).scalar_one_or_none()

        if not agent:
            raise A2AAgentNotFoundError(f"A2A Agent not found with ID: {agent_id}")

        # Check ownership if user_email provided
        if user_email:
            # First-Party
            from mcpgateway.services.permission_service import PermissionService  # pylint: disable=import-outside-toplevel

            permission_service = PermissionService(db)
            if not await permission_service.check_resource_ownership(user_email, agent):
                raise PermissionError("Only the owner can delete this agent")

        # Captured before deletion for logging after the row is gone
        agent_name = agent.name

        # Delete the associated tool before deleting the agent
        # First-Party
        from mcpgateway.services.tool_service import tool_service

        await tool_service.delete_tool_from_a2a_agent(db=db, agent=agent, user_email=user_email, purge_metrics=purge_metrics)

        if purge_metrics:
            with pause_rollup_during_purge(reason=f"purge_a2a_agent:{agent_id}"):
                delete_metrics_in_batches(db, A2AAgentMetric, A2AAgentMetric.a2a_agent_id, agent_id)
                delete_metrics_in_batches(db, A2AAgentMetricsHourly, A2AAgentMetricsHourly.a2a_agent_id, agent_id)
        db.delete(agent)
        db.commit()

        # Invalidate caches since agent count changed
        a2a_stats_cache.invalidate()
        cache = _get_registry_cache()
        await cache.invalidate_agents()
        # Also invalidate tags cache since agent tags may have changed
        # First-Party
        from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

        await admin_stats_cache.invalidate_tags()

        logger.info(f"Deleted A2A agent: {agent_name} (ID: {agent_id})")

        structured_logger.log(
            level="INFO",
            message="A2A agent deleted",
            event_type="a2a_agent_deleted",
            component="a2a_service",
            user_email=user_email,
            resource_type="a2a_agent",
            resource_id=str(agent_id),
            custom_fields={
                "agent_name": agent_name,
                "purge_metrics": purge_metrics,
            },
        )
    except Exception:
        # Previously only PermissionError rolled back; any other failure left the
        # session dirty or in a failed-transaction state for the caller.
        db.rollback()
        raise
async def invoke_agent(
    self,
    db: Session,
    agent_name: str,
    parameters: Dict[str, Any],
    interaction_type: str = "query",
    *,
    user_id: Optional[str] = None,
    user_email: Optional[str] = None,
    token_teams: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Invoke an A2A agent.

    The call runs in three phases:

    1. Briefly lock the agent row, check access and enabled state, and read the
       endpoint and auth data. The DB connection is then released.
    2. Perform the outbound HTTP request with no DB connection held.
    3. Record metrics through the buffered metrics service, and update
       ``last_interaction`` in a fresh session.

    Args:
        db: Database session.
        agent_name: Name of the agent to invoke.
        parameters: Parameters for the interaction.
        interaction_type: Type of interaction.
        user_id: Identifier of the user initiating the call.
        user_email: Email of the user initiating the call.
        token_teams: Teams from JWT token. None = admin (no filtering),
            [] = public-only, [...] = team-scoped access.

    Returns:
        The agent's decoded JSON response. If the agent returned JSON ``null``,
        ``{"error": None}`` is returned instead.

    Raises:
        A2AAgentNotFoundError: If the agent is not found or user lacks access.
        A2AAgentError: If the agent is disabled, its stored credentials cannot
            be decrypted, the endpoint returns a non-200 status, or the call fails.
    """
    # ═══════════════════════════════════════════════════════════════════════════
    # PHASE 1: Acquire a short row lock to read `enabled` + `auth_value`,
    # then release the lock before performing the external HTTP call.
    # This avoids TOCTOU for the critical checks while not holding DB
    # connections during the potentially slow HTTP request.
    # ═══════════════════════════════════════════════════════════════════════════
    auth_query_params_decrypted: Optional[Dict[str, str]] = None
    auth_headers: Dict[str, str] = {}
    try:
        # Lookup the agent id, then lock the row by id using get_for_update
        agent_row = db.execute(select(DbA2AAgent.id).where(DbA2AAgent.name == agent_name)).scalar_one_or_none()
        if not agent_row:
            raise A2AAgentNotFoundError(f"A2A Agent not found with name: {agent_name}")

        agent = get_for_update(db, DbA2AAgent, agent_row)
        if not agent:
            raise A2AAgentNotFoundError(f"A2A Agent not found with name: {agent_name}")

        # SECURITY: Check visibility/team access WHILE ROW IS LOCKED
        # Return 404 (not 403) to avoid leaking existence of private agents
        if not self._check_agent_access(agent, user_email, token_teams):
            raise A2AAgentNotFoundError(f"A2A Agent not found with name: {agent_name}")

        if not agent.enabled:
            raise A2AAgentError(f"A2A Agent '{agent_name}' is disabled")

        # Extract all needed data to local variables before releasing DB connection
        agent_id = agent.id
        agent_endpoint_url = agent.endpoint_url
        agent_type = agent.agent_type
        agent_protocol_version = agent.protocol_version
        agent_auth_type = agent.auth_type
        agent_auth_value = agent.auth_value
        agent_auth_query_params = agent.auth_query_params

        # Handle query_param auth - decrypt and apply to URL
        if agent_auth_type == "query_param" and agent_auth_query_params:
            # First-Party
            from mcpgateway.utils.url_auth import apply_query_param_auth  # pylint: disable=import-outside-toplevel

            auth_query_params_decrypted = {}
            for param_key, encrypted_value in agent_auth_query_params.items():
                if encrypted_value:
                    try:
                        decrypted = decode_auth(encrypted_value)
                        auth_query_params_decrypted[param_key] = decrypted.get(param_key, "")
                    except Exception:
                        logger.debug(f"Failed to decrypt query param '{param_key}' for A2A agent invocation")
            if auth_query_params_decrypted:
                agent_endpoint_url = apply_query_param_auth(agent_endpoint_url, auth_query_params_decrypted)

        # Decode auth_value for supported auth types (before closing session)
        if agent_auth_type in ("basic", "bearer", "authheaders") and agent_auth_value:
            # Decrypt auth_value and extract headers (follows gateway_service pattern)
            if isinstance(agent_auth_value, str):
                try:
                    auth_headers = decode_auth(agent_auth_value)
                except Exception as e:
                    raise A2AAgentError(f"Failed to decrypt authentication for agent '{agent_name}': {e}") from e
            elif isinstance(agent_auth_value, dict):
                auth_headers = {str(k): str(v) for k, v in agent_auth_value.items()}
    except (A2AAgentNotFoundError, A2AAgentError):
        # End the (read-only) transaction so the FOR UPDATE row lock is released
        # immediately instead of lingering until the caller closes the session.
        # Commit (not rollback) matches the read-only convention used below.
        db.commit()
        raise

    # ═══════════════════════════════════════════════════════════════════════════
    # CRITICAL: Release DB connection back to pool BEFORE making HTTP calls
    # This prevents connection pool exhaustion during slow upstream requests.
    # ═══════════════════════════════════════════════════════════════════════════
    db.commit()  # End read-only transaction cleanly (commit not rollback to avoid inflating rollback stats)
    db.close()

    start_time = datetime.now(timezone.utc)
    success = False
    error_message = None
    response = None

    # ═══════════════════════════════════════════════════════════════════════════
    # PHASE 2: Make HTTP call (no DB connection held)
    # ═══════════════════════════════════════════════════════════════════════════

    # Create sanitized URL for logging (redacts auth query params)
    # First-Party
    from mcpgateway.utils.url_auth import sanitize_exception_message, sanitize_url_for_logging  # pylint: disable=import-outside-toplevel

    sanitized_endpoint_url = sanitize_url_for_logging(agent_endpoint_url, auth_query_params_decrypted)

    try:
        # Format request based on agent type and endpoint
        if agent_type in ["generic", "jsonrpc"] or agent_endpoint_url.endswith("/"):
            # Use JSONRPC format for agents that expect it
            request_data = {"jsonrpc": "2.0", "method": parameters.get("method", "message/send"), "params": parameters.get("params", parameters), "id": 1}
        else:
            # Use custom A2A format
            request_data = {"interaction_type": interaction_type, "parameters": parameters, "protocol_version": agent_protocol_version}

        # Make HTTP request to the agent endpoint using shared HTTP client
        # First-Party
        from mcpgateway.services.http_client_service import get_http_client  # pylint: disable=import-outside-toplevel

        client = await get_http_client()
        headers = {"Content-Type": "application/json"}

        # Add authentication if configured (using decoded auth headers)
        headers.update(auth_headers)

        # Add correlation ID to outbound headers for distributed tracing
        correlation_id = get_correlation_id()
        if correlation_id:
            headers["X-Correlation-ID"] = correlation_id

        # Log A2A external call start (with sanitized URL to prevent credential leakage)
        call_start_time = datetime.now(timezone.utc)
        structured_logger.log(
            level="INFO",
            message=f"A2A external call started: {agent_name}",
            component="a2a_service",
            user_id=user_id,
            user_email=user_email,
            correlation_id=correlation_id,
            metadata={
                "event": "a2a_call_started",
                "agent_name": agent_name,
                "agent_id": agent_id,
                "endpoint_url": sanitized_endpoint_url,
                "interaction_type": interaction_type,
                "protocol_version": agent_protocol_version,
            },
        )

        http_response = await client.post(agent_endpoint_url, json=request_data, headers=headers)
        call_duration_ms = (datetime.now(timezone.utc) - call_start_time).total_seconds() * 1000

        if http_response.status_code == 200:
            response = http_response.json()
            success = True

            # Log successful A2A call
            structured_logger.log(
                level="INFO",
                message=f"A2A external call completed: {agent_name}",
                component="a2a_service",
                user_id=user_id,
                user_email=user_email,
                correlation_id=correlation_id,
                duration_ms=call_duration_ms,
                metadata={"event": "a2a_call_completed", "agent_name": agent_name, "agent_id": agent_id, "status_code": http_response.status_code, "success": True},
            )
        else:
            # Sanitize error message to prevent URL secrets from leaking in logs
            raw_error = f"HTTP {http_response.status_code}: {http_response.text}"
            error_message = sanitize_exception_message(raw_error, auth_query_params_decrypted)

            # Log failed A2A call
            structured_logger.log(
                level="ERROR",
                message=f"A2A external call failed: {agent_name}",
                component="a2a_service",
                user_id=user_id,
                user_email=user_email,
                correlation_id=correlation_id,
                duration_ms=call_duration_ms,
                error_details={"error_type": "A2AHTTPError", "error_message": error_message},
                metadata={"event": "a2a_call_failed", "agent_name": agent_name, "agent_id": agent_id, "status_code": http_response.status_code},
            )

            raise A2AAgentError(error_message)

    except A2AAgentError:
        # Re-raise A2AAgentError without wrapping
        raise
    except Exception as e:
        # Sanitize error message to prevent URL secrets from leaking in logs
        error_message = sanitize_exception_message(str(e), auth_query_params_decrypted)
        logger.error(f"Failed to invoke A2A agent '{agent_name}': {error_message}")
        # When credentials are embedded in the URL, suppress the exception chain:
        # the raw exception text (e.g. an HTTP client error quoting the full URL)
        # would otherwise reappear in any logged traceback and defeat sanitization.
        cause = None if auth_query_params_decrypted else e
        raise A2AAgentError(f"Failed to invoke A2A agent: {error_message}") from cause

    finally:
        # ═══════════════════════════════════════════════════════════════════════════
        # PHASE 3: Record metrics via buffered service (batches writes for performance)
        # ═══════════════════════════════════════════════════════════════════════════
        end_time = datetime.now(timezone.utc)
        response_time = (end_time - start_time).total_seconds()

        try:
            # First-Party
            from mcpgateway.services.metrics_buffer_service import get_metrics_buffer_service  # pylint: disable=import-outside-toplevel

            metrics_buffer = get_metrics_buffer_service()
            metrics_buffer.record_a2a_agent_metric_with_duration(
                a2a_agent_id=agent_id,
                response_time=response_time,
                success=success,
                interaction_type=interaction_type,
                error_message=error_message,
            )
        except Exception as metrics_error:
            logger.warning(f"Failed to record A2A metrics for '{agent_name}': {metrics_error}")

        # Update last interaction timestamp (quick separate write)
        try:
            with fresh_db_session() as ts_db:
                # Reacquire short lock and re-check enabled before writing
                db_agent = get_for_update(ts_db, DbA2AAgent, agent_id)
                if db_agent and getattr(db_agent, "enabled", False):
                    db_agent.last_interaction = end_time
                    ts_db.commit()
        except Exception as ts_error:
            logger.warning(f"Failed to update last_interaction for '{agent_name}': {ts_error}")

    # Only the success path reaches here. Falsy-but-valid bodies such as ``{}``
    # are returned as-is; only a JSON ``null`` falls back to the error payload.
    return response if response is not None else {"error": error_message}
async def aggregate_metrics(self, db: Session) -> Dict[str, Any]:
    """Build a fleet-wide summary of A2A agent interaction metrics.

    Agent counts come from ``a2a_stats_cache``. Interaction statistics come
    from ``aggregate_metrics_combined``, which merges recent raw metrics with
    hourly rollups. When the metrics cache is enabled, a cached summary under
    the ``"a2a"`` key is returned directly, and a freshly computed summary is
    stored under that key.

    Args:
        db: Database session.

    Returns:
        Dict with agent counts, interaction counts, success rate (percent) and
        average/min/max response times.
    """
    # First-Party
    from mcpgateway.cache.metrics_cache import is_cache_enabled, metrics_cache  # pylint: disable=import-outside-toplevel

    # Serve a previously computed summary when caching is on.
    if is_cache_enabled():
        cached_summary = metrics_cache.get("a2a")
        if cached_summary is not None:
            return cached_summary

    # Agent totals are cached separately, avoiding two COUNT queries per call.
    agent_counts = a2a_stats_cache.get_counts(db)
    total_agents = agent_counts["total"]
    active_agents = agent_counts["active"]

    # First-Party
    from mcpgateway.services.metrics_query_service import aggregate_metrics_combined  # pylint: disable=import-outside-toplevel

    combined = aggregate_metrics_combined(db, "a2a_agent")
    interactions = combined.total_executions
    successes = combined.successful_executions
    failures = combined.failed_executions

    if interactions > 0:
        success_rate = successes / interactions * 100
    else:
        success_rate = 0.0

    summary = {
        "total_agents": total_agents,
        "active_agents": active_agents,
        "total_interactions": interactions,
        "successful_interactions": successes,
        "failed_interactions": failures,
        "success_rate": success_rate,
        "avg_response_time": float(combined.avg_response_time or 0.0),
        "min_response_time": float(combined.min_response_time or 0.0),
        "max_response_time": float(combined.max_response_time or 0.0),
    }

    if is_cache_enabled():
        metrics_cache.set("a2a", summary)

    return summary
async def reset_metrics(self, db: Session, agent_id: Optional[str] = None) -> None:
    """Delete stored A2A interaction metrics, both raw rows and hourly rollups.

    The deletions are committed on ``db``, after which the cached ``"a2a"``
    aggregate is invalidated and the reset is logged.

    Args:
        db: Database session.
        agent_id: When truthy, only this agent's metric rows are removed;
            otherwise every A2A metric row is removed.
    """
    raw_purge = delete(A2AAgentMetric)
    hourly_purge = delete(A2AAgentMetricsHourly)

    # NOTE(review): the truthiness test means agent_id="" wipes metrics for
    # ALL agents, not none — confirm callers never pass an empty string.
    if agent_id:
        raw_purge = raw_purge.where(A2AAgentMetric.a2a_agent_id == agent_id)
        hourly_purge = hourly_purge.where(A2AAgentMetricsHourly.a2a_agent_id == agent_id)

    db.execute(raw_purge)
    db.execute(hourly_purge)
    db.commit()

    # First-Party
    from mcpgateway.cache.metrics_cache import metrics_cache  # pylint: disable=import-outside-toplevel

    # Drop the cached aggregate so the next read reflects the purge.
    metrics_cache.invalidate("a2a")

    scope_suffix = f" for agent {agent_id}" if agent_id else ""
    logger.info("Reset A2A agent metrics" + scope_suffix)
1646 def _prepare_a2a_agent_for_read(self, agent: DbA2AAgent) -> DbA2AAgent:
1647 """Prepare a a2a agent object for A2AAgentRead validation.
1649 Ensures auth_value is in the correct format (encoded string) for the schema.
1651 Args:
1652 agent: A2A Agent database object
1654 Returns:
1655 A2A Agent object with properly formatted auth_value
1656 """
1657 # If auth_value is a dict, encode it to string for GatewayRead schema
1658 if isinstance(agent.auth_value, dict):
1659 agent.auth_value = encode_auth(agent.auth_value)
1660 return agent
def convert_agent_to_read(self, db_agent: DbA2AAgent, include_metrics: bool = False, db: Optional[Session] = None, team_map: Optional[Dict[str, str]] = None) -> A2AAgentRead:
    """Build the masked ``A2AAgentRead`` schema for a database agent.

    Args:
        db_agent (DbA2AAgent): Database agent model.
        include_metrics (bool): When True, summarize ``db_agent.metrics`` into an
            ``A2AAgentMetrics`` object. Defaults to False; list operations should
            leave it off to avoid N+1 queries.
        db (Optional[Session]): Session used for the team-name lookup, only when no
            team is already attached and ``team_map`` does not apply.
        team_map (Optional[Dict[str, str]]): Pre-fetched team_id -> team_name mapping
            so batch callers avoid one team query per agent.

    Returns:
        A2AAgentRead: The validated schema, returned via its ``masked()`` form.

    Raises:
        A2AAgentNotFoundError: If ``db_agent`` is falsy (e.g. ``None``).
    """
    if not db_agent:
        raise A2AAgentNotFoundError("Agent not found")

    # Resolve the team name unless a non-None ``team`` is already attached
    # (batch callers may pre-populate it).
    if getattr(db_agent, "team", None) is None:
        team_id = getattr(db_agent, "team_id", None)
        if team_map is not None and team_id:
            resolved_team = team_map.get(team_id)
        elif db is not None:
            resolved_team = self._get_team_name(db, team_id)
        else:
            resolved_team = None
        setattr(db_agent, "team", resolved_team)

    metrics = None
    if include_metrics:
        rows = db_agent.metrics
        total = len(rows)
        succeeded = sum(1 for row in rows if row.is_success)
        failed = total - succeeded
        # Rows without a recorded response time are excluded from timing stats.
        timings = [row.response_time for row in rows if row.response_time is not None]
        metrics = A2AAgentMetrics(
            total_executions=total,
            successful_executions=succeeded,
            failed_executions=failed,
            failure_rate=(failed / total * 100) if total > 0 else 0.0,
            min_response_time=min(timings) if timings else None,
            max_response_time=max(timings) if timings else None,
            avg_response_time=(sum(timings) / len(timings)) if timings else None,
            last_execution_time=max((row.timestamp for row in rows), default=None),
        )

    # Start from every schema field present on the ORM object, then overlay the
    # computed values. auth_query_params feeds the _mask_query_param_auth validator.
    agent_data = {field: getattr(db_agent, field, None) for field in A2AAgentRead.model_fields}
    agent_data.update(
        metrics=metrics,
        team=getattr(db_agent, "team", None),
        auth_query_params=getattr(db_agent, "auth_query_params", None),
    )

    # Validate, then hand back the masked view (mirrors GatewayRead handling).
    return A2AAgentRead.model_validate(agent_data).masked()