Coverage for mcpgateway / transports / redis_event_store.py: 100%

76 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

# -*- coding: utf-8 -*-
"""
Redis-backed event store for Streamable HTTP stateful sessions.

Design goals:
- Multi-worker safe: store+evict is atomic (Lua), so concurrent writers do not corrupt meta/count.
- Bounded memory: per-stream ring buffer with eviction (at most ``max_events_per_stream`` events per stream).
- Bounded index growth: an event_id index entry is deleted when its event is evicted, and otherwise
  expires ``ttl`` seconds after the event was stored (the per-stream keys have their TTL refreshed on
  every store, so they always outlive the index entries written into them).
"""

10 

# Standard
import logging
from typing import TYPE_CHECKING
import uuid

# Third-Party
from mcp.server.streamable_http import EventCallback, EventMessage, EventStore
from mcp.types import JSONRPCMessage
import orjson

# First-Party
from mcpgateway.utils.redis_client import get_redis_client

if TYPE_CHECKING:  # pragma: no cover
    # Third-Party
    from redis.asyncio import Redis

27 

# Module-level logger named after this module; used for debug/warning diagnostics below.
logger = logging.getLogger(__name__)

29 

30 

# Atomic "append + evict" script, run via EVAL by RedisEventStore.store_event.
# In one server-side step it:
#   1. bumps the stream's ``next_seq`` and ``count`` counters in the meta hash,
#   2. adds the event to the events zset (score = seq_num) and the messages hash,
#   3. writes the per-event index key (stream_id + seq_num as JSON) with an EX TTL,
#   4. when ``count`` exceeds ``max_events``, removes the oldest entries from the zset,
#      the messages hash and their index keys, then resets ``count`` / ``start_seq``,
#   5. refreshes the TTL of the three per-stream keys.
# Returns the new seq_num (store_event currently ignores it).
# NOTE(review): the index key is built from ARGV instead of being declared in KEYS.
# That works on a single Redis node, but Redis Cluster requires every accessed key to be
# declared (and in one hash slot). Confirm the deployment topology before using Cluster.
_STORE_EVENT_LUA = r"""
-- KEYS:
-- 1) meta_key
-- 2) events_key (zset: member=event_id, score=seq_num)
-- 3) messages_key (hash: event_id -> message_json)
-- ARGV:
-- 1) event_id
-- 2) message_json (orjson encoded; "null" for priming)
-- 3) ttl_seconds
-- 4) max_events
-- 5) index_prefix (string, eg "mcpgw:eventstore:event_index:")
-- 6) stream_id

local meta_key = KEYS[1]
local events_key = KEYS[2]
local messages_key = KEYS[3]

local event_id = ARGV[1]
local message_json = ARGV[2]
local ttl = tonumber(ARGV[3])
local max_events = tonumber(ARGV[4])
local index_prefix = ARGV[5]
local stream_id = ARGV[6]

local seq_num = redis.call('HINCRBY', meta_key, 'next_seq', 1)
local count = redis.call('HINCRBY', meta_key, 'count', 1)
if count == 1 then
    redis.call('HSET', meta_key, 'start_seq', seq_num)
end

redis.call('ZADD', events_key, seq_num, event_id)
redis.call('HSET', messages_key, event_id, message_json)

local index_key = index_prefix .. event_id
redis.call('SET', index_key, cjson.encode({stream_id=stream_id, seq_num=seq_num}), 'EX', ttl)

if count > max_events then
    local to_evict = count - max_events
    local evicted_ids = redis.call('ZRANGE', events_key, 0, to_evict - 1)
    redis.call('ZREMRANGEBYRANK', events_key, 0, to_evict - 1)

    if #evicted_ids > 0 then
        redis.call('HDEL', messages_key, unpack(evicted_ids))
        for _, ev_id in ipairs(evicted_ids) do
            redis.call('DEL', index_prefix .. ev_id)
        end
    end

    redis.call('HSET', meta_key, 'count', max_events)
    local first = redis.call('ZRANGE', events_key, 0, 0, 'WITHSCORES')
    if #first >= 2 then
        redis.call('HSET', meta_key, 'start_seq', tonumber(first[2]))
    else
        redis.call('HSET', meta_key, 'start_seq', seq_num)
    end
end

redis.call('EXPIRE', meta_key, ttl)
redis.call('EXPIRE', events_key, ttl)
redis.call('EXPIRE', messages_key, ttl)

return seq_num
"""

94 

95 

class RedisEventStore(EventStore):
    """Redis-backed event store for multi-worker Streamable HTTP.

    Per-stream data lives under ``{key_prefix}:{stream_id}:*``:

    * ``meta``     - hash holding the ``next_seq``, ``count`` and ``start_seq`` counters.
    * ``events``   - sorted set of event_ids scored by sequence number.
    * ``messages`` - hash mapping event_id -> JSON-encoded message (``null`` for priming events).

    Every event also gets a ``{key_prefix}:event_index:{event_id}`` key recording its
    stream_id and seq_num, so a replay can locate the stream from an event_id alone.
    All writes go through a single Lua script, so append and ring-buffer eviction are atomic.
    """

    def __init__(self, max_events_per_stream: int = 100, ttl: int = 3600, key_prefix: str = "mcpgw:eventstore"):
        """Initialize Redis event store.

        Args:
            max_events_per_stream: Maximum events per stream (ring buffer size).
            ttl: Stream TTL in seconds.
            key_prefix: Redis key prefix for namespacing this store's data. Primarily useful for test isolation.
        """
        self.max_events = max_events_per_stream
        self.ttl = ttl
        # Normalize so the ":" separators added by the key helpers are never doubled.
        self.key_prefix = key_prefix.rstrip(":")
        logger.debug("RedisEventStore initialized: max_events=%s ttl=%ss", max_events_per_stream, ttl)

    def _get_stream_meta_key(self, stream_id: str) -> str:
        """Return Redis key for stream metadata hash.

        Args:
            stream_id: Unique stream identifier.

        Returns:
            Redis key string.
        """
        return f"{self.key_prefix}:{stream_id}:meta"

    def _get_stream_events_key(self, stream_id: str) -> str:
        """Return Redis key for stream events sorted set.

        Args:
            stream_id: Unique stream identifier.

        Returns:
            Redis key string.
        """
        return f"{self.key_prefix}:{stream_id}:events"

    def _get_stream_messages_key(self, stream_id: str) -> str:
        """Return Redis key for stream messages hash.

        Args:
            stream_id: Unique stream identifier.

        Returns:
            Redis key string.
        """
        return f"{self.key_prefix}:{stream_id}:messages"

    def _event_index_prefix(self) -> str:
        """Return prefix for per-event index keys.

        Returns:
            Prefix string for index keys.
        """
        return f"{self.key_prefix}:event_index:"

    def _event_index_key(self, event_id: str) -> str:
        """Return Redis key for event index lookup.

        Args:
            event_id: Unique event identifier.

        Returns:
            Redis key string.
        """
        return f"{self._event_index_prefix()}{event_id}"

    @staticmethod
    def _as_str(value: "bytes | bytearray | str") -> str:
        """Return a Redis reply value as ``str`` (handles both bytes and decoded replies).

        Args:
            value: Raw value returned by the Redis client.

        Returns:
            The value as a string.
        """
        if isinstance(value, (bytes, bytearray)):
            return value.decode("latin-1")
        return str(value)

    @staticmethod
    def _parse_index_entry(raw: "bytes | str | None") -> "tuple[str, int] | None":
        """Decode an event-index value into ``(stream_id, seq_num)``.

        Args:
            raw: Value stored at the event-index key (JSON object written by the Lua script).

        Returns:
            ``(stream_id, seq_num)``, or None if the value is missing, is not a JSON object,
            or lacks a usable stream_id / integer seq_num.
        """
        if not raw:
            return None
        try:
            info = orjson.loads(raw)
        except (TypeError, ValueError):  # orjson.JSONDecodeError subclasses ValueError
            return None
        if not isinstance(info, dict):
            return None
        stream_id = info.get("stream_id")
        seq_num = info.get("seq_num")
        if not stream_id or seq_num is None:
            return None
        try:
            return stream_id, int(seq_num)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _decode_message(event_id: str, payload: "bytes | str") -> "JSONRPCMessage | None":
        """Rebuild the stored JSON-RPC message for replay.

        Args:
            event_id: Event identifier (used for diagnostics only).
            payload: JSON stored in the messages hash.

        Returns:
            The validated message, or None for priming events (stored as ``null``) and for
            payloads that cannot be decoded or validated. Unusable payloads are logged.
        """
        try:
            data = orjson.loads(payload)
        except (TypeError, ValueError):
            logger.warning("Skipping replay of event %s: stored payload is not valid JSON", event_id)
            return None
        if data is None:
            # Priming event: it carries no message, so there is nothing to resend.
            return None
        try:
            return JSONRPCMessage.model_validate(data)
        except ValueError:  # pydantic ValidationError subclasses ValueError
            logger.warning("Skipping replay of event %s: stored payload is not a valid JSON-RPC message", event_id)
            return None

    async def store_event(self, stream_id: str, message: JSONRPCMessage | None) -> str:
        """Store an event in Redis atomically.

        Pydantic messages are dumped in JSON mode with aliases and without None fields,
        so the stored JSON can be validated back into a ``JSONRPCMessage`` on replay.

        Args:
            stream_id: Unique stream identifier.
            message: JSON-RPC message to store (None for priming events).

        Returns:
            Unique event_id for this event.

        Raises:
            RuntimeError: If Redis client is not available.
        """
        redis: Redis = await get_redis_client()
        if redis is None:
            raise RuntimeError("Redis client not available - cannot store event")

        event_id = str(uuid.uuid4())

        if message is None:
            message_dict = None  # priming event; encoded as JSON "null"
        elif hasattr(message, "model_dump"):
            message_dict = message.model_dump(by_alias=True, mode="json", exclude_none=True)
        else:
            message_dict = dict(message)
        message_json = orjson.dumps(message_dict)

        await redis.eval(
            _STORE_EVENT_LUA,
            3,
            self._get_stream_meta_key(stream_id),
            self._get_stream_events_key(stream_id),
            self._get_stream_messages_key(stream_id),
            event_id,
            message_json,
            int(self.ttl),
            int(self.max_events),
            self._event_index_prefix(),
            stream_id,
        )

        return event_id

    async def replay_events_after(self, last_event_id: str, send_callback: EventCallback) -> str | None:
        """Replay events after a specific event_id.

        Each replayed event is delivered as an ``EventMessage`` carrying the reconstructed
        ``JSONRPCMessage`` and its event_id, so the client can resume again from it.
        Priming events and unreadable payloads are skipped.

        Args:
            last_event_id: Event ID to replay from.
            send_callback: Async callback to receive replayed messages.

        Returns:
            stream_id if found, None if event not found or evicted.
        """
        redis: Redis = await get_redis_client()
        if redis is None:
            logger.debug("Redis client not available - cannot replay events")
            return None

        location = self._parse_index_entry(await redis.get(self._event_index_key(last_event_id)))
        if location is None:
            logger.debug("Event %s not found in Redis event store", last_event_id)
            return None
        stream_id, last_seq = location

        meta_key = self._get_stream_meta_key(stream_id)
        events_key = self._get_stream_events_key(stream_id)
        messages_key = self._get_stream_messages_key(stream_id)

        # Eviction detection: if last_seq < start_seq, the event is gone.
        start_seq_raw = await redis.hget(meta_key, "start_seq")
        if start_seq_raw:
            try:
                start_seq = int(start_seq_raw)
            except (TypeError, ValueError):
                start_seq = None
            if start_seq is not None and last_seq < start_seq:
                logger.debug("Event %s was evicted from stream %s", last_event_id, stream_id)
                return None

        event_ids = [self._as_str(raw_id) for raw_id in await redis.zrangebyscore(events_key, last_seq + 1, "+inf")]
        if not event_ids:
            return stream_id

        # Fetch all payloads in one round trip instead of one HGET per event.
        payloads = await redis.hmget(messages_key, event_ids)
        for event_id, payload in zip(event_ids, payloads):
            if payload is None:
                # e.g. evicted by a concurrent writer between ZRANGEBYSCORE and HMGET.
                continue
            message = self._decode_message(event_id, payload)
            if message is None:
                continue
            await send_callback(EventMessage(message=message, event_id=event_id))

        return stream_id