Coverage for mcpgateway / services / server_classification_service.py: 99%

271 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

1# -*- coding: utf-8 -*- 

2""" 

3Server Classification Service. 

4 

5Manages hot/cold server classification based on MCP session pool usage patterns. 

6Provides staggered polling to optimize resource allocation and reduce polling overhead. 

7 

8Classification is based ONLY on upstream MCP pooled session state (gateway -> MCP servers). 

9 

10Copyright 2026 

11SPDX-License-Identifier: Apache-2.0 

12""" 

13 

14# flake8: noqa: DAR101, DAR201, DAR401 

15 

16# Future 

17from __future__ import annotations 

18 

19# Standard 

20import asyncio 

21from dataclasses import asdict, dataclass 

22import hashlib 

23import logging 

24from math import floor 

25import time 

26from typing import Dict, List, Literal, Optional, TYPE_CHECKING 

27 

28# Third-Party 

29import orjson 

30 

31# First-Party 

32from mcpgateway.config import settings 

33 

34if TYPE_CHECKING: 

35 # Third-Party 

36 from redis.asyncio import Redis 

37 

38 # First-Party 

39 from mcpgateway.services.mcp_session_pool import MCPSessionPool 

40 

41logger = logging.getLogger(__name__) 

42 

43 

@dataclass
class ServerUsageMetrics:
    """Usage figures for one upstream server, aggregated over its sessions.

    Attributes:
        url: Server URL the figures belong to.
        server_last_used: Largest ``last_used`` value seen across the
            server's sessions; ``0.0`` when none has been recorded.
        active_session_count: Number of sessions currently checked out.
        total_use_count: Sum of ``use_count`` over all counted sessions.
        pooled_session_count: Number of sessions that were counted.
    """

    url: str
    server_last_used: float = 0.0
    active_session_count: int = 0
    total_use_count: int = 0
    pooled_session_count: int = 0

53 

54 

@dataclass
class ClassificationMetadata:
    """Summary figures describing a single classification run.

    Attributes:
        total_servers: Number of servers that were classified.
        hot_cap: Upper bound on hot servers (20% of ``total_servers``).
        hot_actual: Number of servers actually selected as hot.
        eligible_count: Number of servers that qualified for hot selection.
        timestamp: Time of the run, as seconds since the epoch.
        underutilized_reason: Explanation of why fewer than ``hot_cap``
            servers were selected, or ``None`` when not applicable.
    """

    total_servers: int
    hot_cap: int
    hot_actual: int
    eligible_count: int
    timestamp: float
    underutilized_reason: Optional[str] = None

65 

66 

@dataclass
class ClassificationResult:
    """Outcome of a classification run.

    Attributes:
        hot_servers: URLs of the servers classified as hot.
        cold_servers: URLs of the servers classified as cold.
        metadata: Summary figures for the run that produced this result.
    """

    hot_servers: List[str]
    cold_servers: List[str]
    metadata: "ClassificationMetadata"

74 

75 

class ServerClassificationService:
    """Classify upstream MCP servers as "hot" or "cold" from session-pool usage.

    Classification logic:
        1. Scope: only upstream MCP pooled-session state is considered.
        2. Hot cap: ``floor(20% * total_servers)``.
        3. Eligibility: a server needs at least one idle or active session
           with a positive ``last_used`` timestamp.
        4. Ranking: ``server_last_used`` descending (most recent first).
        5. Tie-breakers: active session count, total use count, then URL.
        6. Hot selection: the top ``min(hot_cap, eligible_count)`` servers.
        7. Cold: every remaining server.
        8. Guarantees: no overlap, full coverage, deterministic ordering.

    With a Redis client, workers coordinate through a leader lock and publish
    results to shared Redis keys. Without one the service runs in local-only
    mode: it always considers itself leader, nothing is published, and
    ``should_poll_server`` always answers ``True``.
    """

    # Redis key templates
    CLASSIFICATION_HOT_KEY = "mcpgateway:server_classification:hot"
    CLASSIFICATION_COLD_KEY = "mcpgateway:server_classification:cold"
    CLASSIFICATION_METADATA_KEY = "mcpgateway:server_classification:metadata"
    CLASSIFICATION_TIMESTAMP_KEY = "mcpgateway:server_classification:timestamp"
    POLL_STATE_KEY_TEMPLATE = "mcpgateway:server_poll_state:{scope_hash}:last_{poll_type}"
    LEADER_KEY = "mcpgateway:server_classification:leader"

    # Lua script for atomic leader lock acquire-or-renew.
    # Executes as a single atomic operation in Redis, preventing the race where
    # the key expires between a GET and EXPIRE in separate round-trips.
    _LEADER_LOCK_SCRIPT = """
    if redis.call('SET', KEYS[1], ARGV[1], 'EX', tonumber(ARGV[2]), 'NX') then
        return 1
    end
    if redis.call('GET', KEYS[1]) == ARGV[1] then
        redis.call('EXPIRE', KEYS[1], tonumber(ARGV[2]))
        return 1
    end
    return 0
    """

    def __init__(self, redis_client: Optional["Redis"] = None):
        """Initialize classification service.

        Args:
            redis_client: Redis client for state management (optional for single-worker)
        """
        self._redis = redis_client
        self._classification_task: Optional[asyncio.Task] = None
        # The lock script treats "stored value == our id" as proof of ownership,
        # so the id has to be unique across worker processes, not just within one.
        self._instance_id = self._generate_instance_id()
        # TTL = 3x interval gives ample margin for classification + sleep.
        # Classification is idempotent (deterministic algorithm), so even if the lock
        # expires and a second worker classifies concurrently, the result is identical.
        self._leader_ttl = self._ttl_seconds(settings.gateway_auto_refresh_interval, 3)
        self._running = False
        self._error_backoff_seconds: float = 30.0  # Back off duration on loop errors (override in tests)
        self._leader_lock_sha: Optional[str] = None  # Cached SHA for leader lock Lua script

    @staticmethod
    def _generate_instance_id() -> str:
        """Return a leader-lock identity that is unique across processes.

        ``id(self)`` is only unique inside one interpreter; separate worker
        processes can easily produce the same value, which would let several
        workers "renew" the same lock. A random UUID avoids that; the pid is
        included purely to make log lines easier to correlate.

        Returns:
            Identifier of the form ``classifier_<pid>_<uuid hex>``
        """
        # Standard
        import os
        import uuid

        return f"classifier_{os.getpid()}_{uuid.uuid4().hex}"

    @staticmethod
    def _ttl_seconds(interval: float, multiplier: float) -> int:
        """Convert ``interval * multiplier`` to a whole-second Redis TTL (>= 1).

        Truncating a sub-second product would give 0, which Redis rejects for
        ``SET ... EX`` and which makes ``EXPIRE`` delete the key immediately.

        Args:
            interval: Base interval in seconds
            multiplier: Factor applied to the interval

        Returns:
            TTL in seconds, never below 1
        """
        return max(1, int(interval * multiplier))

    async def start(self) -> None:
        """Start background classification loop (if enabled)."""
        if not settings.hot_cold_classification_enabled:
            logger.info("Hot/cold classification disabled")
            return

        if self._running:
            logger.warning("Classification service already running")
            return

        self._running = True
        self._classification_task = asyncio.create_task(self._run_classification_loop())
        self._classification_task.add_done_callback(self._on_classification_task_done)
        logger.info(f"Server classification service started (instance={self._instance_id}, redis={'enabled' if self._redis else 'disabled'})")

    def _on_classification_task_done(self, task: asyncio.Task) -> None:
        """Callback invoked when the classification background task finishes.

        Logs the exception if the task died with one. For any non-cancelled
        exit the running flag is cleared so that ``start()`` can launch a new
        task instead of reporting the service as already running.

        Args:
            task: The finished background task
        """
        if task.cancelled():
            return
        exc = task.exception()
        if exc:
            logger.error(f"Classification background task died: {exc}", exc_info=exc)
        self._running = False

    async def stop(self) -> None:
        """Stop background classification."""
        self._running = False
        if self._classification_task:
            self._classification_task.cancel()
            try:
                await self._classification_task
            except asyncio.CancelledError:
                logger.info("Classification task cancelled")
            except Exception as e:
                # Task already died with an error — don't let it crash shutdown
                logger.warning(f"Classification task had failed: {e}")

    async def _run_classification_loop(self) -> None:
        """Background loop: classify servers periodically with leader election."""
        while self._running:
            try:
                # Leader election (Redis-based for multi-worker, local-only otherwise)
                is_leader = await self._try_acquire_leader_lock()

                if is_leader:
                    logger.debug(f"Classification leader acquired (instance={self._instance_id})")
                    # Classification is idempotent (deterministic algorithm on shared pool state),
                    # so concurrent execution by multiple workers produces identical results.
                    # Leader election reduces redundant work; it is not a correctness requirement.
                    # Timeout prevents unbounded runs from holding the loop.
                    try:
                        await asyncio.wait_for(self._perform_classification(), timeout=self._leader_ttl * 0.8)
                    except asyncio.TimeoutError:
                        logger.warning(f"Classification timed out after {self._leader_ttl * 0.8:.0f}s, skipping this cycle")
                    # Renew lock after classification to keep it alive during sleep
                    await self._try_acquire_leader_lock()
                else:
                    logger.debug(f"Not classification leader, skipping (instance={self._instance_id})")

                await asyncio.sleep(settings.gateway_auto_refresh_interval)

            except asyncio.CancelledError:
                logger.info("Classification loop cancelled")
                break
            except Exception as e:
                logger.error(f"Classification loop error: {e}", exc_info=True)
                await asyncio.sleep(self._error_backoff_seconds)  # Back off on error

    async def _try_acquire_leader_lock(self) -> bool:
        """Try to acquire or renew leader lock for classification.

        Uses an atomic Lua script that either acquires a new lock (SET NX)
        or renews the TTL if this instance already holds it. The script
        runs as a single Redis operation, preventing the race where the
        key expires between a GET and EXPIRE in separate round-trips.

        Returns:
            True if this instance is leader, False otherwise
        """
        if not self._redis:
            # Single-worker mode (no Redis), always leader
            return True

        try:
            # Load Lua script on first call (cached by Redis server via SHA)
            if self._leader_lock_sha is None:
                self._leader_lock_sha = await self._redis.script_load(self._LEADER_LOCK_SCRIPT)

            try:
                result = await self._redis.evalsha(self._leader_lock_sha, 1, self.LEADER_KEY, self._instance_id, str(self._leader_ttl))
            except Exception as evalsha_err:
                # Only NOSCRIPT (Redis restarted / SCRIPT FLUSH) is recoverable here
                if "NOSCRIPT" not in str(evalsha_err):
                    raise
                logger.debug("Lua script evicted, re-registering")
                self._leader_lock_sha = await self._redis.script_load(self._LEADER_LOCK_SCRIPT)
                result = await self._redis.evalsha(self._leader_lock_sha, 1, self.LEADER_KEY, self._instance_id, str(self._leader_ttl))
            return result == 1
        except Exception as e:
            logger.warning(f"Failed to acquire leader lock: {e}")
            return False  # Fail safe: don't classify on error

    async def _perform_classification(self) -> None:
        """Perform classification and publish to Redis (if available).

        Every failure is logged and swallowed so that one bad cycle does not
        terminate the background loop.
        """
        try:
            # Get MCP session pool
            # First-Party
            from mcpgateway.services.mcp_session_pool import get_mcp_session_pool

            try:
                pool = get_mcp_session_pool()
            except RuntimeError:
                logger.debug("MCP session pool not initialized, skipping classification")
                return

            # Get gateway_id → canonical URL mapping from database
            gateway_url_map = await self._get_gateway_url_map()
            if not gateway_url_map:
                logger.debug("No gateways found, skipping classification")
                return

            # Deduplicate: multiple gateways may share the same upstream URL
            # (different credentials/scopes). Classification operates on unique servers.
            all_gateway_urls = list(dict.fromkeys(gateway_url_map.values()))

            # Perform classification
            result = self._classify_servers_from_pool(pool, all_gateway_urls, gateway_url_map)

            # Publish to Redis (if available)
            if self._redis:
                await self._publish_classification_to_redis(result)

            logger.info(f"Classification completed: {len(result.hot_servers)} hot, {len(result.cold_servers)} cold (N={result.metadata.total_servers}, eligible={result.metadata.eligible_count})")

            if result.metadata.underutilized_reason:
                logger.debug(f"Underutilization: {result.metadata.underutilized_reason}")

        except Exception as e:
            logger.error(f"Classification failed: {e}", exc_info=True)

    def _resolve_canonical_url(self, pool_key: tuple, gateway_url_map: Dict[str, str]) -> Optional[str]:
        """Resolve the canonical gateway URL for a pool key.

        Pool keys may contain auth-mutated URLs (e.g. with query-param secrets).
        Use gateway_id from the pool key to look up the canonical Gateway.url,
        preventing secret leakage into classification Redis sets.

        Args:
            pool_key: PoolKey tuple (user_identity, url, identity_hash, transport_type, gateway_id)
            gateway_url_map: Mapping of gateway_id → canonical URL from database

        Returns:
            Canonical URL if gateway_id resolves, else None
        """
        gateway_id = pool_key[4] if len(pool_key) > 4 else ""
        if gateway_id and gateway_id in gateway_url_map:
            return gateway_url_map[gateway_id]
        return None

    def _classify_servers_from_pool(self, pool: "MCPSessionPool", all_gateway_urls: List[str], gateway_url_map: Optional[Dict[str, str]] = None) -> "ClassificationResult":
        """Classify servers based on pooled session state.

        Algorithm (deterministic):
            1. Get total servers N
            2. Calculate hot_cap = floor(0.20 * N)
            3. Extract server metrics from pooled sessions (idle + active)
            4. Filter eligible (server_last_used > 0)
            5. Sort by (server_last_used desc, active_count desc, use_count desc, url asc)
            6. Select top min(hot_cap, eligible_count) as hot
            7. Remaining servers are cold

        Args:
            pool: MCP session pool
            all_gateway_urls: All registered gateway URLs
            gateway_url_map: Optional mapping of gateway_id → canonical URL for URL normalization

        Returns:
            ClassificationResult with hot/cold servers and metadata
        """
        total_servers = len(all_gateway_urls)
        hot_cap = floor(0.20 * total_servers)
        canonical_url_set = set(all_gateway_urls)

        # Step 3: Extract server usage from pooled sessions
        server_metrics: Dict[str, ServerUsageMetrics] = {}

        def _canonical_url_for(pool_key: tuple) -> str:
            # Prefer the gateway_id lookup; fall back to the raw pool URL.
            resolved = self._resolve_canonical_url(pool_key, gateway_url_map) if gateway_url_map else None
            return resolved or pool_key[1]

        def _accumulate_session(url: str, session: object) -> None:
            # Fold one PooledSession into the per-URL aggregate.
            if url not in server_metrics:
                server_metrics[url] = ServerUsageMetrics(url=url)
            entry = server_metrics[url]
            if hasattr(session, "last_used") and session.last_used > 0:
                entry.server_last_used = max(entry.server_last_used, session.last_used)
            entry.total_use_count += getattr(session, "use_count", 0)
            entry.pooled_session_count += 1

        # Idle sessions: pool._pools is Dict[PoolKey, Queue[PooledSession]]
        # PoolKey = (user_identity, url, identity_hash, transport_type, gateway_id)
        for pool_key, session_queue in pool._pools.items():  # pylint: disable=protected-access
            url = _canonical_url_for(pool_key)

            # Only track URLs that correspond to known gateways
            if url not in canonical_url_set:
                continue

            if url not in server_metrics:
                server_metrics[url] = ServerUsageMetrics(url=url)

            try:
                if hasattr(session_queue, "_queue"):
                    sessions_list = list(session_queue._queue)  # pylint: disable=protected-access
                else:
                    sessions_list = []

                for session in sessions_list:
                    _accumulate_session(url, session)
            except Exception as e:
                logger.warning(f"Error extracting idle metrics for {url}: {e}")
                continue

        # Active sessions (checked-out from the pool).
        # This ensures busy servers with all sessions in use are still eligible.
        for pool_key, active_set in pool._active.items():  # pylint: disable=protected-access
            url = _canonical_url_for(pool_key)

            if url not in canonical_url_set:
                continue

            if url not in server_metrics:
                server_metrics[url] = ServerUsageMetrics(url=url)

            server_metrics[url].active_session_count += len(active_set)

            # Extract last_used / use_count from active sessions too
            for session in active_set:
                try:
                    _accumulate_session(url, session)
                except Exception as active_err:
                    logger.debug(f"Skipping active session metric for {url}: {active_err}")
                    continue

        # Step 4: Filter eligible servers (has valid last_used)
        eligible_servers = [metrics for metrics in server_metrics.values() if metrics.server_last_used > 0.0]
        eligible_count = len(eligible_servers)

        # Step 5: Sort by recency (newer first), then tie-breakers
        eligible_servers.sort(
            key=lambda m: (
                -m.server_last_used,  # Primary: most recent first (descending)
                -m.active_session_count,  # Tie-breaker 1: more active sessions
                -m.total_use_count,  # Tie-breaker 2: higher use count
                m.url,  # Tie-breaker 3: deterministic (ascending)
            )
        )

        # Step 6: Select hot servers (up to hot_cap, no backfill)
        hot_actual = min(hot_cap, eligible_count)
        hot_servers = [m.url for m in eligible_servers[:hot_actual]]

        # Step 7: Cold servers = all remaining
        hot_set = set(hot_servers)
        cold_servers = [url for url in all_gateway_urls if url not in hot_set]

        # Step 8: Build metadata
        underutilized_reason = None
        if eligible_count < hot_cap:
            underutilized_reason = f"Only {eligible_count} servers have pooled sessions, below hot_cap={hot_cap}"

        return ClassificationResult(
            hot_servers=hot_servers,
            cold_servers=cold_servers,
            metadata=ClassificationMetadata(
                total_servers=total_servers,
                hot_cap=hot_cap,
                hot_actual=hot_actual,
                eligible_count=eligible_count,
                timestamp=time.time(),
                underutilized_reason=underutilized_reason,
            ),
        )

    async def _get_gateway_url_map(self) -> Dict[str, str]:
        """Get mapping of gateway_id → canonical URL for all enabled gateways.

        Returns:
            Dict mapping gateway ID to its canonical URL; empty on database error
        """
        # Third-Party
        from sqlalchemy import select

        # First-Party
        from mcpgateway.db import Gateway, SessionLocal

        # NOTE(review): this is a synchronous query inside a coroutine, so the
        # event loop is blocked while it runs. Consider moving it off-loop once
        # the engine's thread-safety configuration has been confirmed.
        try:
            with SessionLocal() as db:
                result = db.execute(select(Gateway.id, Gateway.url).where(Gateway.enabled.is_(True)))
                return {str(row[0]): row[1] for row in result}
        except Exception as e:
            logger.error(f"Failed to get gateway URL map: {e}")
            return {}

    async def _publish_classification_to_redis(self, result: "ClassificationResult") -> None:
        """Publish classification result to Redis atomically.

        The hot/cold sets, metadata and timestamp are replaced in a single
        transactional pipeline. Failures are logged and not re-raised.

        Args:
            result: Classification result to publish
        """
        if not self._redis:
            return

        try:
            # TTL on every published key prevents stale data after a worker crash.
            ttl = self._ttl_seconds(settings.gateway_auto_refresh_interval, 2)

            # Atomic pipeline for transactional updates
            async with self._redis.pipeline(transaction=True) as pipe:
                # Clear old classification
                await pipe.delete(self.CLASSIFICATION_HOT_KEY, self.CLASSIFICATION_COLD_KEY)

                # Set new classification
                if result.hot_servers:
                    await pipe.sadd(self.CLASSIFICATION_HOT_KEY, *result.hot_servers)

                if result.cold_servers:
                    await pipe.sadd(self.CLASSIFICATION_COLD_KEY, *result.cold_servers)

                # Expire classification sets regardless of whether they had members
                await pipe.expire(self.CLASSIFICATION_HOT_KEY, ttl)
                await pipe.expire(self.CLASSIFICATION_COLD_KEY, ttl)

                # Store metadata and timestamp with the same TTL
                metadata_json = orjson.dumps(asdict(result.metadata))
                await pipe.set(self.CLASSIFICATION_METADATA_KEY, metadata_json, ex=ttl)
                await pipe.set(self.CLASSIFICATION_TIMESTAMP_KEY, result.metadata.timestamp, ex=ttl)

                await pipe.execute()

            logger.debug("Classification published to Redis successfully")

        except Exception as e:
            logger.error(f"Failed to publish classification to Redis: {e}")

    async def get_server_classification(self, url: str) -> Optional[str]:
        """Get classification for a server (hot/cold).

        Args:
            url: Server URL

        Returns:
            "hot", "cold", or None if not classified
        """
        if not self._redis:
            return None  # No Redis, classification not available

        try:
            if await self._redis.sismember(self.CLASSIFICATION_HOT_KEY, url):
                return "hot"

            if await self._redis.sismember(self.CLASSIFICATION_COLD_KEY, url):
                return "cold"

            return None  # Not yet classified
        except Exception as e:
            logger.warning(f"Failed to get classification for {url}: {e}")
            return None  # Fail open

    def _poll_state_key(self, url: str, poll_type: str, gateway_id: str = "") -> str:
        """Build the Redis key for poll-state tracking.

        Includes gateway_id when provided so that distinct gateways sharing the
        same upstream URL track their refresh schedules independently.

        Args:
            url: Server URL
            poll_type: Type of poll the key tracks
            gateway_id: Optional gateway ID to scope the key

        Returns:
            Redis key containing a 32-hex-character SHA-256 prefix of the scope
        """
        scope = f"{url}\0{gateway_id}" if gateway_id else url
        scope_hash = hashlib.sha256(scope.encode()).hexdigest()[:32]
        return self.POLL_STATE_KEY_TEMPLATE.format(scope_hash=scope_hash, poll_type=poll_type)

    async def should_poll_server(self, url: str, poll_type: Literal["health", "tool_discovery"], gateway_id: str = "") -> bool:
        """Determine if server should be polled now based on classification.

        Args:
            url: Server URL
            poll_type: Type of poll (health or tool_discovery)
            gateway_id: Optional gateway ID for per-gateway poll tracking

        Returns:
            True if should poll now, False otherwise
        """
        if not settings.hot_cold_classification_enabled:
            return True  # Feature disabled, always poll

        if not self._redis:
            return True  # No Redis, always poll (single-worker mode)

        try:
            classification = await self.get_server_classification(url)
            if classification is None:
                return True  # Not yet classified, poll anyway

            last_poll_key = self._poll_state_key(url, poll_type, gateway_id)
            last_poll_str = await self._redis.get(last_poll_key)

            if last_poll_str is None:
                # Never polled, should poll now (caller must call mark_poll_completed after)
                return True

            last_poll = float(last_poll_str)
            now = time.time()
            if not 0 < last_poll <= now + 60:
                last_poll = 0.0  # treat as never polled; prevents manipulation via future timestamps
            elapsed = now - last_poll

            # Determine interval based on classification
            interval = settings.hot_server_check_interval if classification == "hot" else settings.cold_server_check_interval

            return elapsed >= interval

        except Exception as e:
            logger.warning(f"Error checking poll status for {url}: {e}")
            return True  # Fail open: poll on error

    async def mark_poll_completed(self, url: str, poll_type: Literal["health", "tool_discovery"], gateway_id: str = "") -> None:
        """Record that a poll was actually performed.

        Call this AFTER the poll/refresh succeeds, not at decision time.
        This prevents wasting poll slots when downstream throttling skips the refresh.

        Args:
            url: Server URL
            poll_type: Type of poll
            gateway_id: Optional gateway ID for per-gateway poll tracking
        """
        if not self._redis:
            return

        try:
            classification = await self.get_server_classification(url)
            interval = settings.hot_server_check_interval if classification == "hot" else settings.cold_server_check_interval

            last_poll_key = self._poll_state_key(url, poll_type, gateway_id)
            # Expire after 2x interval (never below 1s, which Redis would reject)
            await self._redis.set(last_poll_key, time.time(), ex=self._ttl_seconds(interval, 2))
        except Exception as e:
            logger.warning(f"Failed to update poll timestamp for {url}: {e}")

578 logger.warning(f"Failed to update poll timestamp for {url}: {e}")