Coverage for mcpgateway / services / server_classification_service.py: 99%
271 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""
3Server Classification Service.
5Manages hot/cold server classification based on MCP session pool usage patterns.
6Provides staggered polling to optimize resource allocation and reduce polling overhead.
8Classification is based ONLY on upstream MCP pooled session state (gateway -> MCP servers).
10Copyright 2026
11SPDX-License-Identifier: Apache-2.0
12"""
14# flake8: noqa: DAR101, DAR201, DAR401
16# Future
17from __future__ import annotations
19# Standard
20import asyncio
21from dataclasses import asdict, dataclass
22import hashlib
23import logging
24from math import floor
25import time
26from typing import Dict, List, Literal, Optional, TYPE_CHECKING
28# Third-Party
29import orjson
31# First-Party
32from mcpgateway.config import settings
34if TYPE_CHECKING:
35 # Third-Party
36 from redis.asyncio import Redis
38 # First-Party
39 from mcpgateway.services.mcp_session_pool import MCPSessionPool
41logger = logging.getLogger(__name__)
@dataclass
class ServerUsageMetrics:
    """Usage totals for one upstream server, folded together from its pooled sessions.

    Attributes:
        url: Server URL the metrics are keyed by.
        server_last_used: Largest ``last_used`` value seen across the server's
            pooled sessions; stays 0.0 until a session with a positive
            ``last_used`` is seen.
        active_session_count: Number of sessions currently checked out of the
            pool (taken from the pool's active-session mapping).
        total_use_count: Sum of ``use_count`` over all sessions seen.
        pooled_session_count: Number of sessions folded into these metrics.
    """

    url: str
    server_last_used: float = 0.0
    active_session_count: int = 0
    total_use_count: int = 0
    pooled_session_count: int = 0
@dataclass
class ClassificationMetadata:
    """Summary figures describing a single classification run.

    Attributes:
        total_servers: Number of servers considered in the run.
        hot_cap: Upper bound on hot servers (floor of 20% of ``total_servers``).
        hot_actual: Number of servers actually selected as hot.
        eligible_count: Number of servers that qualified for ranking.
        timestamp: Time the classification result was produced.
        underutilized_reason: Explanation of why fewer than ``hot_cap`` servers
            were eligible, or ``None`` when that is not the case.
    """

    total_servers: int
    hot_cap: int
    hot_actual: int
    eligible_count: int
    timestamp: float
    underutilized_reason: Optional[str] = None
@dataclass
class ClassificationResult:
    """Outcome of one classification run: the hot/cold split plus run metadata.

    Attributes:
        hot_servers: URLs of the servers classified as hot.
        cold_servers: URLs of the servers classified as cold.
        metadata: Summary figures for the run that produced this split.
    """

    hot_servers: List[str]
    cold_servers: List[str]
    metadata: ClassificationMetadata
class ServerClassificationService:
    """Classify upstream MCP servers as hot or cold from session-pool usage.

    Classification logic:
        1. Scope: only upstream MCP pooled session state is consulted.
        2. Hot cap: ``floor(20% * total_servers)``.
        3. Eligibility: a server needs at least one pooled (idle or active)
           session with ``last_used > 0``.
        4. Ranking: ``server_last_used`` descending (newest first).
        5. Tie-breakers: active session count, total use count, then URL.
        6. Hot selection: top ``min(hot_cap, eligible_count)``; no backfill.
        7. Cold: every remaining registered server.
        8. Guarantees: hot and cold never overlap and together cover all
           registered servers.

    With a Redis client, a leader lock limits classification to one worker at a
    time and results are published to Redis. Without Redis the instance always
    treats itself as leader, results are only logged, and the polling helpers
    answer "always poll".
    """

    # Redis key templates
    CLASSIFICATION_HOT_KEY = "mcpgateway:server_classification:hot"
    CLASSIFICATION_COLD_KEY = "mcpgateway:server_classification:cold"
    CLASSIFICATION_METADATA_KEY = "mcpgateway:server_classification:metadata"
    CLASSIFICATION_TIMESTAMP_KEY = "mcpgateway:server_classification:timestamp"
    POLL_STATE_KEY_TEMPLATE = "mcpgateway:server_poll_state:{scope_hash}:last_{poll_type}"
    LEADER_KEY = "mcpgateway:server_classification:leader"

    # Lua script for atomic leader lock acquire-or-renew.
    # Executes as a single atomic operation in Redis, preventing the race where
    # the key expires between a GET and EXPIRE in separate round-trips.
    _LEADER_LOCK_SCRIPT = """
    if redis.call('SET', KEYS[1], ARGV[1], 'EX', tonumber(ARGV[2]), 'NX') then
        return 1
    end
    if redis.call('GET', KEYS[1]) == ARGV[1] then
        redis.call('EXPIRE', KEYS[1], tonumber(ARGV[2]))
        return 1
    end
    return 0
    """

    def __init__(self, redis_client: Optional[Redis] = None):
        """Initialize classification service.

        Args:
            redis_client: Redis client for state management (optional for single-worker)
        """
        self._redis = redis_client
        self._classification_task: Optional[asyncio.Task] = None
        # The lock script compares this token to decide whether the caller already
        # owns the lock, so it must be unique across processes and hosts.
        self._instance_id = self._new_instance_id()
        # TTL = 3x interval gives ample margin for classification + sleep.
        # Classification is idempotent (deterministic algorithm), so even if the lock
        # expires and a second worker classifies concurrently, the result is identical.
        self._leader_ttl = self._ttl_seconds(settings.gateway_auto_refresh_interval, 3)
        self._running = False
        self._error_backoff_seconds: float = 30.0  # Back off duration on loop errors (override in tests)
        self._leader_lock_sha: Optional[str] = None  # Cached SHA for leader lock Lua script

    @staticmethod
    def _new_instance_id() -> str:
        """Build a leader-lock owner token that is unique across workers.

        ``id(self)`` is only unique inside one process; separate workers can end
        up with the same value, which would let one worker renew another's lock.

        Returns:
            Token of the form ``classifier_<pid>_<random hex>``
        """
        # Standard
        import os
        import uuid

        return f"classifier_{os.getpid()}_{uuid.uuid4().hex}"

    @staticmethod
    def _ttl_seconds(interval: float, multiplier: int) -> int:
        """Convert ``interval * multiplier`` to a whole-second Redis TTL of at least 1.

        Redis rejects ``SET ... EX 0`` and ``EXPIRE key 0`` deletes the key at once,
        so a truncated sub-second interval must never reach Redis as 0.

        Args:
            interval: Base interval in seconds
            multiplier: Factor applied to the interval

        Returns:
            TTL in seconds, never less than 1
        """
        return max(1, int(interval * multiplier))

    async def start(self) -> None:
        """Start background classification loop (if enabled)."""
        if not settings.hot_cold_classification_enabled:
            logger.info("Hot/cold classification disabled")
            return

        if self._running:
            logger.warning("Classification service already running")
            return

        self._running = True
        self._classification_task = asyncio.create_task(self._run_classification_loop())
        self._classification_task.add_done_callback(self._on_classification_task_done)
        logger.info(f"Server classification service started (instance={self._instance_id}, redis={'enabled' if self._redis else 'disabled'})")

    def _on_classification_task_done(self, task: asyncio.Task) -> None:
        """Log and reset the running flag when the background task ends with an error.

        Args:
            task: The finished classification task
        """
        if task.cancelled():
            return
        exc = task.exception()
        if exc:
            logger.error(f"Classification background task died: {exc}", exc_info=exc)
            # Allow a later start() to launch a fresh loop.
            self._running = False

    async def stop(self) -> None:
        """Stop background classification."""
        self._running = False
        if self._classification_task:
            self._classification_task.cancel()
            try:
                await self._classification_task
            except asyncio.CancelledError:
                logger.info("Classification task cancelled")
            except Exception as e:
                # Task already died with an error — don't let it crash shutdown
                logger.warning(f"Classification task had failed: {e}")

    async def _run_classification_loop(self) -> None:
        """Background loop: classify servers periodically with leader election."""
        while self._running:
            try:
                # Leader election (Redis-based for multi-worker, local-only otherwise)
                is_leader = await self._try_acquire_leader_lock()

                if is_leader:
                    logger.debug(f"Classification leader acquired (instance={self._instance_id})")
                    # Classification is idempotent (deterministic algorithm on shared pool state),
                    # so concurrent execution by multiple workers produces identical results.
                    # Leader election reduces redundant work; it is not a correctness requirement.
                    # Timeout prevents unbounded runs from holding the loop.
                    timeout = self._leader_ttl * 0.8
                    try:
                        await asyncio.wait_for(self._perform_classification(), timeout=timeout)
                    except asyncio.TimeoutError:
                        logger.warning(f"Classification timed out after {timeout:.0f}s, skipping this cycle")
                    # Renew lock after classification to keep it alive during sleep
                    await self._try_acquire_leader_lock()
                else:
                    logger.debug(f"Not classification leader, skipping (instance={self._instance_id})")

                await asyncio.sleep(settings.gateway_auto_refresh_interval)

            except asyncio.CancelledError:
                logger.info("Classification loop cancelled")
                break
            except Exception as e:
                logger.error(f"Classification loop error: {e}", exc_info=True)
                await asyncio.sleep(self._error_backoff_seconds)  # Back off on error

    async def _try_acquire_leader_lock(self) -> bool:
        """Try to acquire or renew leader lock for classification.

        Uses an atomic Lua script that either acquires a new lock (SET NX)
        or renews the TTL if this instance already holds it.

        Returns:
            True if this instance is leader, False otherwise
        """
        if not self._redis:
            # Single-worker mode (no Redis), always leader
            return True

        try:
            # Load Lua script on first call (cached by Redis server via SHA)
            if self._leader_lock_sha is None:
                self._leader_lock_sha = await self._redis.script_load(self._LEADER_LOCK_SCRIPT)

            try:
                result = await self._redis.evalsha(self._leader_lock_sha, 1, self.LEADER_KEY, self._instance_id, str(self._leader_ttl))
            except Exception as evalsha_err:
                # Handle NOSCRIPT (Redis restarted / SCRIPT FLUSH) by re-registering
                if "NOSCRIPT" not in str(evalsha_err):
                    raise
                logger.debug("Lua script evicted, re-registering")
                self._leader_lock_sha = await self._redis.script_load(self._LEADER_LOCK_SCRIPT)
                result = await self._redis.evalsha(self._leader_lock_sha, 1, self.LEADER_KEY, self._instance_id, str(self._leader_ttl))
            return result == 1
        except Exception as e:
            logger.warning(f"Failed to acquire leader lock: {e}")
            return False  # Fail safe: don't classify on error

    async def _perform_classification(self) -> None:
        """Perform classification and publish to Redis (if available).

        Any exception is logged and swallowed so one failed run does not end
        the background loop.
        """
        try:
            # Get MCP session pool
            # First-Party
            from mcpgateway.services.mcp_session_pool import get_mcp_session_pool

            try:
                pool = get_mcp_session_pool()
            except RuntimeError:
                logger.debug("MCP session pool not initialized, skipping classification")
                return

            # Get gateway_id → canonical URL mapping from database
            gateway_url_map = await self._get_gateway_url_map()
            if not gateway_url_map:
                logger.debug("No gateways found, skipping classification")
                return

            # Deduplicate: multiple gateways may share the same upstream URL
            # (different credentials/scopes). Classification operates on unique servers.
            all_gateway_urls = list(dict.fromkeys(gateway_url_map.values()))

            # Perform classification
            result = self._classify_servers_from_pool(pool, all_gateway_urls, gateway_url_map)

            # Publish to Redis (if available)
            if self._redis:
                await self._publish_classification_to_redis(result)

            logger.info(f"Classification completed: {len(result.hot_servers)} hot, {len(result.cold_servers)} cold (N={result.metadata.total_servers}, eligible={result.metadata.eligible_count})")

            if result.metadata.underutilized_reason:
                logger.debug(f"Underutilization: {result.metadata.underutilized_reason}")

        except Exception as e:
            logger.error(f"Classification failed: {e}", exc_info=True)

    def _resolve_canonical_url(self, pool_key: tuple, gateway_url_map: Dict[str, str]) -> Optional[str]:
        """Resolve the canonical gateway URL for a pool key.

        Pool keys may contain auth-mutated URLs (e.g. with query-param secrets).
        Use gateway_id from the pool key to look up the canonical Gateway.url,
        preventing secret leakage into classification Redis sets.

        Args:
            pool_key: PoolKey tuple (user_identity, url, identity_hash, transport_type, gateway_id)
            gateway_url_map: Mapping of gateway_id → canonical URL from database

        Returns:
            Canonical URL if gateway_id resolves, else None
        """
        gateway_id = pool_key[4] if len(pool_key) > 4 else ""
        if gateway_id and gateway_id in gateway_url_map:
            return gateway_url_map[gateway_id]
        return None

    def _classify_servers_from_pool(self, pool: MCPSessionPool, all_gateway_urls: List[str], gateway_url_map: Optional[Dict[str, str]] = None) -> ClassificationResult:
        """Classify servers based on pooled session state.

        Algorithm (deterministic):
            1. Get total servers N
            2. Calculate hot_cap = floor(0.20 * N)
            3. Extract server metrics from pooled sessions (idle + active)
            4. Filter eligible (server_last_used > 0)
            5. Sort by (server_last_used desc, active_count desc, use_count desc, url asc)
            6. Select top min(hot_cap, eligible_count) as hot
            7. Remaining servers are cold

        Args:
            pool: MCP session pool
            all_gateway_urls: All registered gateway URLs
            gateway_url_map: Optional mapping of gateway_id → canonical URL for URL normalization

        Returns:
            ClassificationResult with hot/cold servers and metadata
        """
        total_servers = len(all_gateway_urls)
        hot_cap = floor(0.20 * total_servers)
        canonical_url_set = set(all_gateway_urls)

        # Step 3: Extract server usage from pooled sessions
        server_metrics: Dict[str, ServerUsageMetrics] = {}

        def _url_for(pool_key: tuple) -> str:
            # Prefer the gateway_id lookup; fall back to the raw pool URL.
            canonical = self._resolve_canonical_url(pool_key, gateway_url_map) if gateway_url_map else None
            return canonical or pool_key[1]

        def _accumulate_session(url: str, session: object) -> None:
            # Fold a single PooledSession into the metrics entry for its server.
            if url not in server_metrics:
                server_metrics[url] = ServerUsageMetrics(url=url)
            if hasattr(session, "last_used") and session.last_used > 0:
                server_metrics[url].server_last_used = max(server_metrics[url].server_last_used, session.last_used)
            server_metrics[url].total_use_count += getattr(session, "use_count", 0)
            server_metrics[url].pooled_session_count += 1

        # Idle sessions: pool._pools maps PoolKey -> queue of PooledSession.
        # PoolKey = (user_identity, url, identity_hash, transport_type, gateway_id)
        for pool_key, session_queue in pool._pools.items():  # pylint: disable=protected-access
            url = _url_for(pool_key)

            # Only track URLs that correspond to known gateways
            if url not in canonical_url_set:
                continue

            if url not in server_metrics:
                server_metrics[url] = ServerUsageMetrics(url=url)

            try:
                if hasattr(session_queue, "_queue"):
                    sessions_list = list(session_queue._queue)  # pylint: disable=protected-access
                else:
                    sessions_list = []

                for session in sessions_list:
                    _accumulate_session(url, session)
            except Exception as e:
                # An error abandons the rest of this queue; metrics already folded in are kept.
                logger.warning(f"Error extracting idle metrics for {url}: {e}")

        # Active sessions (checked-out from the pool).
        # This ensures busy servers with all sessions in use are still eligible.
        for pool_key, active_set in pool._active.items():  # pylint: disable=protected-access
            url = _url_for(pool_key)

            if url not in canonical_url_set:
                continue

            if url not in server_metrics:
                server_metrics[url] = ServerUsageMetrics(url=url)

            server_metrics[url].active_session_count += len(active_set)

            # Extract last_used / use_count from active sessions too
            for session in active_set:
                try:
                    _accumulate_session(url, session)
                except Exception as active_err:
                    logger.debug(f"Skipping active session metric for {url}: {active_err}")

        # Step 4: Filter eligible servers (has valid last_used)
        eligible_servers = [metrics for metrics in server_metrics.values() if metrics.server_last_used > 0.0]
        eligible_count = len(eligible_servers)

        # Step 5: Sort by recency (newer first), then tie-breakers
        eligible_servers.sort(
            key=lambda m: (
                -m.server_last_used,  # Primary: most recent first (descending)
                -m.active_session_count,  # Tie-breaker 1: more active sessions
                -m.total_use_count,  # Tie-breaker 2: higher use count
                m.url,  # Tie-breaker 3: deterministic (ascending)
            )
        )

        # Step 6: Select hot servers (up to hot_cap, no backfill)
        hot_actual = min(hot_cap, eligible_count)
        hot_servers = [m.url for m in eligible_servers[:hot_actual]]

        # Step 7: Cold servers = all remaining
        hot_set = set(hot_servers)
        cold_servers = [url for url in all_gateway_urls if url not in hot_set]

        # Step 8: Build metadata
        underutilized_reason = None
        if eligible_count < hot_cap:
            underutilized_reason = f"Only {eligible_count} servers have pooled sessions, below hot_cap={hot_cap}"

        return ClassificationResult(
            hot_servers=hot_servers,
            cold_servers=cold_servers,
            metadata=ClassificationMetadata(
                total_servers=total_servers,
                hot_cap=hot_cap,
                hot_actual=hot_actual,
                eligible_count=eligible_count,
                timestamp=time.time(),
                underutilized_reason=underutilized_reason,
            ),
        )

    async def _get_gateway_url_map(self) -> Dict[str, str]:
        """Get mapping of gateway_id → canonical URL for all enabled gateways.

        Returns:
            Dict mapping gateway ID to its canonical URL; empty dict if the query fails
        """
        # Third-Party
        from sqlalchemy import select

        # First-Party
        from mcpgateway.db import Gateway, SessionLocal

        # NOTE(review): the query below is synchronous and is not awaited, so it
        # runs on the event-loop thread. Confirm this is acceptable for large
        # gateway tables or slow databases.
        try:
            with SessionLocal() as db:
                result = db.execute(select(Gateway.id, Gateway.url).where(Gateway.enabled.is_(True)))
                return {str(row[0]): row[1] for row in result}
        except Exception as e:
            logger.error(f"Failed to get gateway URL map: {e}")
            return {}

    async def _publish_classification_to_redis(self, result: ClassificationResult) -> None:
        """Publish classification result to Redis in one transactional pipeline.

        Failures are logged and swallowed.

        Args:
            result: Classification result to publish
        """
        if not self._redis:
            return

        try:
            # TTL on everything published prevents stale data after a worker crash.
            ttl = self._ttl_seconds(settings.gateway_auto_refresh_interval, 2)

            async with self._redis.pipeline(transaction=True) as pipe:
                # Clear old classification
                await pipe.delete(self.CLASSIFICATION_HOT_KEY, self.CLASSIFICATION_COLD_KEY)

                # Set new classification
                if result.hot_servers:
                    await pipe.sadd(self.CLASSIFICATION_HOT_KEY, *result.hot_servers)

                if result.cold_servers:
                    await pipe.sadd(self.CLASSIFICATION_COLD_KEY, *result.cold_servers)

                # Expire classification sets regardless of whether they had members
                await pipe.expire(self.CLASSIFICATION_HOT_KEY, ttl)
                await pipe.expire(self.CLASSIFICATION_COLD_KEY, ttl)

                # Store metadata and timestamp with the same TTL
                metadata_json = orjson.dumps(asdict(result.metadata))
                await pipe.set(self.CLASSIFICATION_METADATA_KEY, metadata_json, ex=ttl)

                await pipe.set(self.CLASSIFICATION_TIMESTAMP_KEY, result.metadata.timestamp, ex=ttl)

                await pipe.execute()

            logger.debug("Classification published to Redis successfully")

        except Exception as e:
            logger.error(f"Failed to publish classification to Redis: {e}")

    async def get_server_classification(self, url: str) -> Optional[str]:
        """Get classification for a server (hot/cold).

        Args:
            url: Server URL

        Returns:
            "hot", "cold", or None if not classified (also None without Redis or on error)
        """
        if not self._redis:
            return None  # No Redis, classification not available

        try:
            if await self._redis.sismember(self.CLASSIFICATION_HOT_KEY, url):
                return "hot"

            if await self._redis.sismember(self.CLASSIFICATION_COLD_KEY, url):
                return "cold"

            return None  # Not yet classified
        except Exception as e:
            logger.warning(f"Failed to get classification for {url}: {e}")
            return None  # Fail open

    def _poll_state_key(self, url: str, poll_type: str, gateway_id: str = "") -> str:
        """Build the Redis key for poll-state tracking.

        Includes gateway_id when provided so that distinct gateways sharing the
        same upstream URL track their refresh schedules independently. The scope
        is hashed, so the raw URL never appears in the key.

        Args:
            url: Server URL
            poll_type: Type of poll, embedded verbatim in the key
            gateway_id: Optional gateway ID to scope the key per gateway

        Returns:
            Redis key built from POLL_STATE_KEY_TEMPLATE
        """
        scope = f"{url}\0{gateway_id}" if gateway_id else url
        scope_hash = hashlib.sha256(scope.encode()).hexdigest()[:32]
        return self.POLL_STATE_KEY_TEMPLATE.format(scope_hash=scope_hash, poll_type=poll_type)

    async def should_poll_server(self, url: str, poll_type: Literal["health", "tool_discovery"], gateway_id: str = "") -> bool:
        """Determine if server should be polled now based on classification.

        Args:
            url: Server URL
            poll_type: Type of poll (health or tool_discovery)
            gateway_id: Optional gateway ID for per-gateway poll tracking

        Returns:
            True if should poll now, False otherwise
        """
        if not settings.hot_cold_classification_enabled:
            return True  # Feature disabled, always poll

        if not self._redis:
            return True  # No Redis, always poll (single-worker mode)

        try:
            classification = await self.get_server_classification(url)
            if classification is None:
                return True  # Not yet classified, poll anyway

            last_poll_key = self._poll_state_key(url, poll_type, gateway_id)
            last_poll_str = await self._redis.get(last_poll_key)

            if last_poll_str is None:
                # Never polled, should poll now (caller must call mark_poll_completed after)
                return True

            last_poll = float(last_poll_str)
            now = time.time()
            if not 0 < last_poll <= now + 60:
                last_poll = 0.0  # treat as never polled; prevents manipulation via future timestamps
            elapsed = now - last_poll

            # Determine interval based on classification
            interval = settings.hot_server_check_interval if classification == "hot" else settings.cold_server_check_interval

            return elapsed >= interval

        except Exception as e:
            logger.warning(f"Error checking poll status for {url}: {e}")
            return True  # Fail open: poll on error

    async def mark_poll_completed(self, url: str, poll_type: Literal["health", "tool_discovery"], gateway_id: str = "") -> None:
        """Record that a poll was actually performed.

        Call this AFTER the poll/refresh succeeds, not at decision time.
        This prevents wasting poll slots when downstream throttling skips the refresh.

        Args:
            url: Server URL
            poll_type: Type of poll
            gateway_id: Optional gateway ID for per-gateway poll tracking
        """
        if not self._redis:
            return

        try:
            classification = await self.get_server_classification(url)
            # Anything not classified as hot uses the cold interval.
            interval = settings.hot_server_check_interval if classification == "hot" else settings.cold_server_check_interval

            last_poll_key = self._poll_state_key(url, poll_type, gateway_id)
            # Expire after 2x interval
            await self._redis.set(last_poll_key, time.time(), ex=self._ttl_seconds(interval, 2))
        except Exception as e:
            logger.warning(f"Failed to update poll timestamp for {url}: {e}")