Coverage for mcpgateway / cache / a2a_stats_cache.py: 100%

40 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/cache/a2a_stats_cache.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5 

6A2A Agent Statistics In-Memory Cache. 

7 

8This module implements a thread-safe in-memory cache for A2A agent counts with TTL expiration. 

9Since aggregate agent counts (total, active) are queried frequently but change rarely, 

10caching these values eliminates thousands of redundant database queries under load. 

11 

12Performance Impact: 

13 - Before: 2 COUNT queries per /metrics call (10,000+ queries under load) 

14 - After: 1 combined query per TTL period (default 30 seconds) 

15 

16Security Considerations: 

17 - Stale count window: Changes take up to TTL to propagate 

18 - Mitigation: Invalidation is called after agent mutations 

19 - Cache poisoning: Not applicable (populated from DB only) 

20 - Information leakage: Not applicable (stores counts only) 

21 

22Examples: 

23 >>> from unittest.mock import Mock, patch 

24 >>> from mcpgateway.cache.a2a_stats_cache import A2AStatsCache 

25 

26 >>> # Test cache miss and hit 

27 >>> cache = A2AStatsCache(ttl_seconds=30) 

28 >>> mock_db = Mock() 

29 >>> mock_result = Mock() 

30 >>> mock_result.total = 5 

31 >>> mock_result.active = 3 

32 >>> mock_db.execute.return_value.one.return_value = mock_result 

33 

34 >>> # First call - cache miss, queries DB 

35 >>> result = cache.get_counts(mock_db) 

36 >>> result == {"total": 5, "active": 3} 

37 True 

38 >>> mock_db.execute.return_value.one.call_count 

39 1 

40 

41 >>> # Second call - cache hit, no DB query 

42 >>> result = cache.get_counts(mock_db) 

43 >>> mock_db.execute.return_value.one.call_count 

44 1 

45 

46 >>> # After invalidation - queries DB again 

47 >>> cache.invalidate() 

48 >>> result = cache.get_counts(mock_db) 

49 >>> mock_db.execute.return_value.one.call_count 

50 2 

51""" 

52 

53# Standard 

54import logging 

55import threading 

56import time 

57from typing import Dict 

58 

59# Use standard logging to avoid circular imports with services 

60logger = logging.getLogger(__name__) 

61 

62 

63class A2AStatsCache: 

64 """ 

65 Thread-safe in-memory cache for A2A agent statistics with TTL. 

66 

67 This cache stores aggregate counts (total agents, active agents) to avoid 

68 repeated COUNT queries on the a2a_agents table. These counts are queried 

69 frequently via the /metrics endpoint but change only when agents are 

70 created, toggled, or deleted. 

71 

72 Attributes: 

73 ttl_seconds: Time-to-live in seconds before cache refresh 

74 _cache: Cached statistics dict (or sentinel _NOT_CACHED) 

75 _expiry: Timestamp when cache expires 

76 _lock: Threading lock for thread-safe operations 

77 

78 Examples: 

79 >>> from unittest.mock import Mock 

80 >>> cache = A2AStatsCache(ttl_seconds=30) 

81 >>> mock_db = Mock() 

82 >>> mock_result = Mock() 

83 >>> mock_result.total = 0 

84 >>> mock_result.active = 0 

85 >>> mock_db.execute.return_value.one.return_value = mock_result 

86 

87 >>> # Returns counts when no agents exist 

88 >>> result = cache.get_counts(mock_db) 

89 >>> result == {"total": 0, "active": 0} 

90 True 

91 """ 

92 

93 # Sentinel value to distinguish "not cached" from "cached with 0 agents" 

94 _NOT_CACHED = object() 

95 

96 def __init__(self, ttl_seconds: int = 30): 

97 """ 

98 Initialize the A2A stats cache. 

99 

100 Args: 

101 ttl_seconds: Time-to-live in seconds (default: 30). 

102 After this duration, the cache will refresh from DB. 

103 

104 Examples: 

105 >>> cache = A2AStatsCache(ttl_seconds=15) 

106 >>> cache.ttl_seconds 

107 15 

108 """ 

109 self.ttl_seconds = ttl_seconds 

110 self._cache = self._NOT_CACHED # Use sentinel to distinguish from cached empty 

111 self._expiry: float = 0 

112 self._lock = threading.Lock() 

113 self._hit_count = 0 

114 self._miss_count = 0 

115 

116 def get_counts(self, db) -> Dict[str, int]: 

117 """ 

118 Get A2A agent counts from cache or database. 

119 

120 Uses a double-checked locking pattern for thread safety with minimal 

121 lock contention on the hot path (cache hit). 

122 

123 This method combines what was previously 2 separate COUNT queries into 

124 a single query with conditional aggregation. 

125 

126 Args: 

127 db: SQLAlchemy database session 

128 

129 Returns: 

130 Dict with 'total' and 'active' agent counts 

131 

132 Examples: 

133 >>> from unittest.mock import Mock 

134 >>> cache = A2AStatsCache(ttl_seconds=30) 

135 >>> mock_db = Mock() 

136 >>> mock_result = Mock() 

137 >>> mock_result.total = 10 

138 >>> mock_result.active = 7 

139 >>> mock_db.execute.return_value.one.return_value = mock_result 

140 >>> cache.get_counts(mock_db) 

141 {'total': 10, 'active': 7} 

142 """ 

143 now = time.time() 

144 

145 # Fast path: cache hit (no lock needed for read) 

146 # Use sentinel check to properly cache zero counts 

147 if now < self._expiry and self._cache is not self._NOT_CACHED: 

148 self._hit_count += 1 

149 return self._cache 

150 

151 # Slow path: cache miss or expired - acquire lock 

152 with self._lock: 

153 # Double-check after acquiring lock (another thread may have refreshed) 

154 if now < self._expiry and self._cache is not self._NOT_CACHED: 

155 self._hit_count += 1 

156 return self._cache 

157 

158 # Import here to avoid circular imports 

159 # Third-Party 

160 from sqlalchemy import case, func, select # pylint: disable=import-outside-toplevel 

161 

162 # First-Party 

163 from mcpgateway.db import A2AAgent # pylint: disable=import-outside-toplevel 

164 

165 # Single query with conditional aggregation (replaces 2 separate queries) 

166 result = db.execute( 

167 select( 

168 func.count(A2AAgent.id).label("total"), # pylint: disable=not-callable 

169 func.sum(case((A2AAgent.enabled.is_(True), 1), else_=0)).label("active"), 

170 ) 

171 ).one() 

172 

173 self._cache = { 

174 "total": result.total or 0, 

175 "active": int(result.active or 0), 

176 } 

177 self._expiry = now + self.ttl_seconds 

178 self._miss_count += 1 

179 

180 logger.debug(f"A2A stats cache refreshed: {self._cache} (TTL: {self.ttl_seconds}s)") 

181 

182 return self._cache 

183 

184 def invalidate(self) -> None: 

185 """ 

186 Invalidate the cache, forcing a refresh on next access. 

187 

188 Call this method after creating, toggling, or deleting A2A agents 

189 to ensure changes propagate immediately. 

190 

191 Examples: 

192 >>> cache = A2AStatsCache(ttl_seconds=30) 

193 >>> cache._expiry = time.time() + 1000 # Set future expiry 

194 >>> cache.invalidate() 

195 >>> cache._expiry 

196 0 

197 """ 

198 with self._lock: 

199 self._cache = self._NOT_CACHED 

200 self._expiry = 0 

201 logger.info("A2A stats cache invalidated") 

202 

203 def stats(self) -> dict: 

204 """ 

205 Get cache statistics. 

206 

207 Returns: 

208 Dictionary with hit_count, miss_count, hit_rate, ttl_seconds, and is_cached 

209 

210 Examples: 

211 >>> cache = A2AStatsCache(ttl_seconds=30) 

212 >>> cache._hit_count = 90 

213 >>> cache._miss_count = 10 

214 >>> stats = cache.stats() 

215 >>> stats["hit_rate"] 

216 0.9 

217 """ 

218 total = self._hit_count + self._miss_count 

219 return { 

220 "hit_count": self._hit_count, 

221 "miss_count": self._miss_count, 

222 "hit_rate": self._hit_count / total if total > 0 else 0.0, 

223 "ttl_seconds": self.ttl_seconds, 

224 "is_cached": self._cache is not self._NOT_CACHED and time.time() < self._expiry, 

225 } 

226 

227 

228# Global singleton instance 

229# This is the primary interface for accessing cached A2A stats 

230a2a_stats_cache = A2AStatsCache()