Coverage for mcpgateway / cache / a2a_stats_cache.py: 100%
40 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/cache/a2a_stats_cache.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
6A2A Agent Statistics In-Memory Cache.
8This module implements a thread-safe in-memory cache for A2A agent counts with TTL expiration.
9Since aggregate agent counts (total, active) are queried frequently but change rarely,
10caching these values eliminates thousands of redundant database queries under load.
12Performance Impact:
13 - Before: 2 COUNT queries per /metrics call (10,000+ queries under load)
14 - After: 1 combined query per TTL period (default 30 seconds)
16Security Considerations:
17 - Stale count window: Changes take up to TTL to propagate
18 - Mitigation: Invalidation is called after agent mutations
19 - Cache poisoning: Not applicable (populated from DB only)
20 - Information leakage: Not applicable (stores counts only)
22Examples:
23 >>> from unittest.mock import Mock, patch
24 >>> from mcpgateway.cache.a2a_stats_cache import A2AStatsCache
26 >>> # Test cache miss and hit
27 >>> cache = A2AStatsCache(ttl_seconds=30)
28 >>> mock_db = Mock()
29 >>> mock_result = Mock()
30 >>> mock_result.total = 5
31 >>> mock_result.active = 3
32 >>> mock_db.execute.return_value.one.return_value = mock_result
34 >>> # First call - cache miss, queries DB
35 >>> result = cache.get_counts(mock_db)
36 >>> result == {"total": 5, "active": 3}
37 True
38 >>> mock_db.execute.return_value.one.call_count
39 1
41 >>> # Second call - cache hit, no DB query
42 >>> result = cache.get_counts(mock_db)
43 >>> mock_db.execute.return_value.one.call_count
44 1
46 >>> # After invalidation - queries DB again
47 >>> cache.invalidate()
48 >>> result = cache.get_counts(mock_db)
49 >>> mock_db.execute.return_value.one.call_count
50 2
51"""
53# Standard
54import logging
55import threading
56import time
57from typing import Dict
# Module-level logger obtained from the standard library ``logging`` package;
# standard logging is used here to avoid circular imports with services.
logger: logging.Logger = logging.getLogger(__name__)
class A2AStatsCache:
    """
    Thread-safe in-memory cache for A2A agent statistics with TTL.

    This cache stores aggregate counts (total agents, active agents) to avoid
    repeated COUNT queries on the a2a_agents table. These counts are queried
    frequently via the /metrics endpoint but change only when agents are
    created, toggled, or deleted.

    Attributes:
        ttl_seconds: Time-to-live in seconds before cache refresh
        _cache: Cached statistics dict (or sentinel _NOT_CACHED)
        _expiry: Timestamp when cache expires
        _lock: Threading lock guarding refresh and invalidation
        _hit_count: Number of calls answered from the cache. It is bumped
            without the lock on the fast path, so treat it as approximate
            under heavy concurrency.
        _miss_count: Number of calls that refreshed the cache from the DB

    Examples:
        >>> from unittest.mock import Mock
        >>> cache = A2AStatsCache(ttl_seconds=30)
        >>> mock_db = Mock()
        >>> mock_result = Mock()
        >>> mock_result.total = 0
        >>> mock_result.active = 0
        >>> mock_db.execute.return_value.one.return_value = mock_result

        >>> # Returns counts when no agents exist
        >>> result = cache.get_counts(mock_db)
        >>> result == {"total": 0, "active": 0}
        True
    """

    # Sentinel value to distinguish "not cached" from "cached with 0 agents"
    _NOT_CACHED = object()

    def __init__(self, ttl_seconds: int = 30):
        """
        Initialize the A2A stats cache.

        Args:
            ttl_seconds: Time-to-live in seconds (default: 30).
                After this duration, the cache will refresh from DB.

        Examples:
            >>> cache = A2AStatsCache(ttl_seconds=15)
            >>> cache.ttl_seconds
            15
        """
        self.ttl_seconds = ttl_seconds
        self._cache = self._NOT_CACHED  # Use sentinel to distinguish from cached empty
        self._expiry: float = 0
        self._lock = threading.Lock()
        self._hit_count = 0
        self._miss_count = 0

    def get_counts(self, db) -> Dict[str, int]:
        """
        Get A2A agent counts from cache or database.

        Uses a double-checked locking pattern for thread safety with minimal
        lock contention on the hot path (cache hit).

        This method combines what was previously 2 separate COUNT queries into
        a single query with conditional aggregation.

        The returned dict is always a copy of the cached value, so callers may
        mutate it freely without affecting what other callers see.

        Args:
            db: SQLAlchemy database session

        Returns:
            Dict with 'total' and 'active' agent counts

        Examples:
            >>> from unittest.mock import Mock
            >>> cache = A2AStatsCache(ttl_seconds=30)
            >>> mock_db = Mock()
            >>> mock_result = Mock()
            >>> mock_result.total = 10
            >>> mock_result.active = 7
            >>> mock_db.execute.return_value.one.return_value = mock_result
            >>> cache.get_counts(mock_db)
            {'total': 10, 'active': 7}
        """
        now = time.time()

        # Fast path: cache hit (no lock needed for read).
        # Snapshot both fields into locals exactly once. Re-reading
        # self._cache after the check would let a concurrent invalidate()
        # slip the sentinel in between the check and the return.
        # _expiry is read before _cache because writers store _cache first:
        # a fresh expiry therefore always pairs with an equally fresh (or
        # invalidated) cache value, never with a stale one.
        expiry = self._expiry
        cached = self._cache
        if cached is not self._NOT_CACHED and now < expiry:
            self._hit_count += 1
            return dict(cached)

        # Slow path: cache miss or expired - acquire lock
        with self._lock:
            # Double-check after acquiring lock (another thread may have refreshed)
            cached = self._cache
            if cached is not self._NOT_CACHED and now < self._expiry:
                self._hit_count += 1
                return dict(cached)

            # Import here to avoid circular imports
            # Third-Party
            from sqlalchemy import case, func, select  # pylint: disable=import-outside-toplevel

            # First-Party
            from mcpgateway.db import A2AAgent  # pylint: disable=import-outside-toplevel

            # Single query with conditional aggregation (replaces 2 separate queries)
            result = db.execute(
                select(
                    func.count(A2AAgent.id).label("total"),  # pylint: disable=not-callable
                    func.sum(case((A2AAgent.enabled.is_(True), 1), else_=0)).label("active"),
                )
            ).one()

            # "or 0" turns a NULL aggregate (e.g. SUM over an empty table) into 0.
            counts = {
                "total": result.total or 0,
                "active": int(result.active or 0),
            }
            self._cache = counts
            self._expiry = now + self.ttl_seconds
            self._miss_count += 1

            # Lazy %-formatting: the message is only rendered when DEBUG is enabled.
            logger.debug("A2A stats cache refreshed: %s (TTL: %ss)", counts, self.ttl_seconds)

            # Hand out a copy so callers cannot mutate the shared cached dict.
            return dict(counts)

    def invalidate(self) -> None:
        """
        Invalidate the cache, forcing a refresh on next access.

        Call this method after creating, toggling, or deleting A2A agents
        to ensure changes propagate immediately.

        Examples:
            >>> cache = A2AStatsCache(ttl_seconds=30)
            >>> cache._expiry = time.time() + 1000  # Set future expiry
            >>> cache.invalidate()
            >>> cache._expiry
            0
        """
        with self._lock:
            # Clear the value before the expiry so a lock-free reader that
            # still sees the old expiry finds the sentinel and takes the slow path.
            self._cache = self._NOT_CACHED
            self._expiry = 0
            logger.info("A2A stats cache invalidated")

    def stats(self) -> dict:
        """
        Get cache statistics.

        Returns:
            Dictionary with hit_count, miss_count, hit_rate, ttl_seconds, and is_cached.
            hit_rate is 0.0 when no lookups have been recorded yet.

        Examples:
            >>> cache = A2AStatsCache(ttl_seconds=30)
            >>> cache._hit_count = 90
            >>> cache._miss_count = 10
            >>> stats = cache.stats()
            >>> stats["hit_rate"]
            0.9
        """
        hits = self._hit_count
        misses = self._miss_count
        total = hits + misses
        return {
            "hit_count": hits,
            "miss_count": misses,
            "hit_rate": hits / total if total > 0 else 0.0,
            "ttl_seconds": self.ttl_seconds,
            "is_cached": self._cache is not self._NOT_CACHED and time.time() < self._expiry,
        }
# Shared module-level instance, created once at import time with the default TTL.
# This is the primary interface for accessing cached A2A stats.
a2a_stats_cache: A2AStatsCache = A2AStatsCache()