Coverage for mcpgateway/transports/redis_event_store.py: 100%
76 statements
coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
# -*- coding: utf-8 -*-
"""
Redis-backed event store for Streamable HTTP stateful sessions.

Design goals:
- Multi-worker safe: store+evict is atomic (Lua), so concurrent writers do not corrupt meta/count.
- Bounded memory: per-stream ring buffer with eviction.
- Bounded index growth: event_id index entries expire with the stream TTL.
"""

# Standard
import logging
from typing import TYPE_CHECKING
import uuid

# Third-Party
from mcp.server.streamable_http import EventCallback, EventStore
from mcp.types import JSONRPCMessage
import orjson

# First-Party
from mcpgateway.utils.redis_client import get_redis_client

if TYPE_CHECKING:  # pragma: no cover
    # Third-Party
    from redis.asyncio import Redis

logger = logging.getLogger(__name__)
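
# Redis key layout (illustrative; default key_prefix is "mcpgw:eventstore"):
#   {prefix}:{stream_id}:meta        hash   - next_seq, count, start_seq
#   {prefix}:{stream_id}:events      zset   - member=event_id, score=seq_num
#   {prefix}:{stream_id}:messages    hash   - event_id -> message_json
#   {prefix}:event_index:{event_id}  string - {"stream_id": ..., "seq_num": ...}, expires with the stream TTL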

_STORE_EVENT_LUA = r"""
-- KEYS:
-- 1) meta_key
-- 2) events_key (zset: member=event_id, score=seq_num)
-- 3) messages_key (hash: event_id -> message_json)
-- ARGV:
-- 1) event_id
-- 2) message_json (orjson encoded; "null" for priming)
-- 3) ttl_seconds
-- 4) max_events
-- 5) index_prefix (string, e.g. "mcpgw:eventstore:event_index:")
-- 6) stream_id

local meta_key = KEYS[1]
local events_key = KEYS[2]
local messages_key = KEYS[3]

local event_id = ARGV[1]
local message_json = ARGV[2]
local ttl = tonumber(ARGV[3])
local max_events = tonumber(ARGV[4])
local index_prefix = ARGV[5]
local stream_id = ARGV[6]

local seq_num = redis.call('HINCRBY', meta_key, 'next_seq', 1)
local count = redis.call('HINCRBY', meta_key, 'count', 1)
if count == 1 then
    redis.call('HSET', meta_key, 'start_seq', seq_num)
end

redis.call('ZADD', events_key, seq_num, event_id)
redis.call('HSET', messages_key, event_id, message_json)

local index_key = index_prefix .. event_id
redis.call('SET', index_key, cjson.encode({stream_id=stream_id, seq_num=seq_num}), 'EX', ttl)
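
-- Note (editorial): this per-event index is what lets replay_events_after
-- recover (stream_id, seq_num) from a bare event_id without scanning streams;
-- it shares the stream TTL, so it cannot outlive the data it points at.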

if count > max_events then
    local to_evict = count - max_events
    local evicted_ids = redis.call('ZRANGE', events_key, 0, to_evict - 1)
    redis.call('ZREMRANGEBYRANK', events_key, 0, to_evict - 1)

    if #evicted_ids > 0 then
        redis.call('HDEL', messages_key, unpack(evicted_ids))
        for _, ev_id in ipairs(evicted_ids) do
            redis.call('DEL', index_prefix .. ev_id)
        end
    end

    redis.call('HSET', meta_key, 'count', max_events)
    local first = redis.call('ZRANGE', events_key, 0, 0, 'WITHSCORES')
    if #first >= 2 then
        redis.call('HSET', meta_key, 'start_seq', tonumber(first[2]))
    else
        redis.call('HSET', meta_key, 'start_seq', seq_num)
    end
end

redis.call('EXPIRE', meta_key, ttl)
redis.call('EXPIRE', events_key, ttl)
redis.call('EXPIRE', messages_key, ttl)

return seq_num
"""


class RedisEventStore(EventStore):
    """Redis-backed event store for multi-worker Streamable HTTP."""

    def __init__(self, max_events_per_stream: int = 100, ttl: int = 3600, key_prefix: str = "mcpgw:eventstore"):
        """Initialize Redis event store.

        Args:
            max_events_per_stream: Maximum events per stream (ring buffer size).
            ttl: Stream TTL in seconds.
            key_prefix: Redis key prefix for namespacing this store's data. Primarily useful for test isolation.
        """
        self.max_events = max_events_per_stream
        self.ttl = ttl
        self.key_prefix = key_prefix.rstrip(":")
        logger.debug("RedisEventStore initialized: max_events=%s ttl=%ss", max_events_per_stream, ttl)
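
    # Construction sketch (hypothetical values; assumes a Redis reachable via
    # get_redis_client()): a store capped at 50 events per stream whose keys
    # expire 10 minutes after the last stored event:
    #
    #     store = RedisEventStore(max_events_per_stream=50, ttl=600, key_prefix="test:eventstore")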

    def _get_stream_meta_key(self, stream_id: str) -> str:
        """Return Redis key for stream metadata hash.

        Args:
            stream_id: Unique stream identifier.

        Returns:
            Redis key string.
        """
        return f"{self.key_prefix}:{stream_id}:meta"

    def _get_stream_events_key(self, stream_id: str) -> str:
        """Return Redis key for stream events sorted set.

        Args:
            stream_id: Unique stream identifier.

        Returns:
            Redis key string.
        """
        return f"{self.key_prefix}:{stream_id}:events"

    def _get_stream_messages_key(self, stream_id: str) -> str:
        """Return Redis key for stream messages hash.

        Args:
            stream_id: Unique stream identifier.

        Returns:
            Redis key string.
        """
        return f"{self.key_prefix}:{stream_id}:messages"

    def _event_index_prefix(self) -> str:
        """Return prefix for per-event index keys.

        Returns:
            Prefix string for index keys.
        """
        return f"{self.key_prefix}:event_index:"

    def _event_index_key(self, event_id: str) -> str:
        """Return Redis key for event index lookup.

        Args:
            event_id: Unique event identifier.

        Returns:
            Redis key string.
        """
        return f"{self._event_index_prefix()}{event_id}"

    async def store_event(self, stream_id: str, message: JSONRPCMessage | None) -> str:
        """Store an event in Redis atomically.

        Args:
            stream_id: Unique stream identifier.
            message: JSON-RPC message to store (None for priming events).

        Returns:
            Unique event_id for this event.

        Raises:
            RuntimeError: If Redis client is not available.
        """
        redis: Redis = await get_redis_client()
        if redis is None:
            raise RuntimeError("Redis client not available - cannot store event")

        event_id = str(uuid.uuid4())

        # Convert message to dict for serialization (Pydantic model -> dict)
        message_dict = None if message is None else (message.model_dump() if hasattr(message, "model_dump") else dict(message))
        message_json = orjson.dumps(message_dict)

        meta_key = self._get_stream_meta_key(stream_id)
        events_key = self._get_stream_events_key(stream_id)
        messages_key = self._get_stream_messages_key(stream_id)

        await redis.eval(
            _STORE_EVENT_LUA,
            3,
            meta_key,
            events_key,
            messages_key,
            event_id,
            message_json,
            int(self.ttl),
            int(self.max_events),
            self._event_index_prefix(),
            stream_id,
        )

        return event_id
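
    # Call sketch (hypothetical): passing message=None stores a priming event
    # whose payload is the JSON literal "null":
    #
    #     event_id = await store.store_event("session-abc", None)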

    async def replay_events_after(self, last_event_id: str, send_callback: EventCallback) -> str | None:
        """Replay events after a specific event_id.

        Args:
            last_event_id: Event ID to replay from.
            send_callback: Async callback to receive replayed messages.

        Returns:
            stream_id if found, None if event not found or evicted.
        """
        redis: Redis = await get_redis_client()
        if redis is None:
            logger.debug("Redis client not available - cannot replay events")
            return None

        index_data = await redis.get(self._event_index_key(last_event_id))
        if not index_data:
            return None

        try:
            info = orjson.loads(index_data)
        except Exception:
            return None

        stream_id = info.get("stream_id")
        last_seq = info.get("seq_num")
        if not stream_id or last_seq is None:
            return None

        meta_key = self._get_stream_meta_key(stream_id)
        events_key = self._get_stream_events_key(stream_id)
        messages_key = self._get_stream_messages_key(stream_id)

        # Eviction detection: if last_seq < start_seq, the event is gone.
        start_seq_bytes = await redis.hget(meta_key, "start_seq")
        if start_seq_bytes:
            try:
                start_seq = int(start_seq_bytes)
            except Exception:
                start_seq = None
            if start_seq is not None and int(last_seq) < start_seq:
                return None

        event_ids = await redis.zrangebyscore(events_key, int(last_seq) + 1, "+inf")
        for event_id_bytes in event_ids:
            ev_id = event_id_bytes.decode("latin-1") if isinstance(event_id_bytes, (bytes, bytearray)) else str(event_id_bytes)
            msg_json = await redis.hget(messages_key, ev_id)
            if msg_json is None:
                continue
            try:
                msg = orjson.loads(msg_json)
            except Exception:
                msg = None
            await send_callback(msg)

        return stream_id
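
# Replay sketch (hypothetical; this module's replay passes each decoded message
# dict, or None for priming/undecodable entries, to the callback):
#
#     async def collect(msg):
#         replayed.append(msg)
#
#     stream_id = await store.replay_events_after(last_event_id, collect)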