Coverage for mcpgateway / transports / websocket_transport.py: 100%

81 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/transports/websocket_transport.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7WebSocket Transport Implementation. 

8This module implements WebSocket transport for MCP, providing 

9full-duplex communication between client and server. 

10""" 

11 

12# Standard 

13import asyncio 

14from typing import Any, AsyncGenerator, Dict, Optional 

15 

16# Third-Party 

17from fastapi import WebSocket, WebSocketDisconnect 

18 

19# First-Party 

20from mcpgateway.config import settings 

21from mcpgateway.services.logging_service import LoggingService 

22from mcpgateway.transports.base import Transport 

23 

# Initialize logging service first: create this module's LoggingService and
# obtain a logger named after the module (``__name__``). ``logger`` is the
# module-level logger that WebSocketTransport uses for its log messages.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

27 

28 

class WebSocketTransport(Transport):
    """Transport implementation using WebSocket.

    This transport uses a FastAPI WebSocket for full-duplex communication
    between the MCP gateway and a client. When
    ``settings.websocket_ping_interval`` is greater than zero, :meth:`connect`
    also starts a background task (:meth:`_ping_loop`). That task periodically
    sends an application-level ``b"ping"`` frame and expects a ``b"pong"``
    reply.

    NOTE(review): ``_ping_loop`` reads the ping reply with ``receive_bytes()``.
    Meanwhile :meth:`receive_message` may be awaiting ``receive_json()`` on the
    same socket. Two concurrent readers on one socket may consume each other's
    frames. A binary ``b"pong"`` delivered to ``receive_json()``, which reads
    text frames by default, would not decode as a JSON message. TODO confirm
    against the ASGI server and client in use before relying on pings
    together with a receive loop.

    Examples:
        >>> # Note: WebSocket transport requires a FastAPI WebSocket object
        >>> # and cannot be easily tested in doctest environment
        >>> from unittest.mock import Mock
        >>> mock_websocket = Mock(spec=WebSocket)
        >>> transport = WebSocketTransport(mock_websocket)
        >>> transport
        <mcpgateway.transports.websocket_transport.WebSocketTransport object at ...>

        >>> # Check initial connection state
        >>> transport._connected
        False
        >>> transport._ping_task is None
        True

        >>> # Verify it's a proper Transport subclass
        >>> from mcpgateway.transports.base import Transport
        >>> isinstance(transport, Transport)
        True
        >>> issubclass(WebSocketTransport, Transport)
        True

        >>> # Verify required methods exist
        >>> hasattr(transport, 'connect')
        True
        >>> hasattr(transport, 'disconnect')
        True
        >>> hasattr(transport, 'send_message')
        True
        >>> hasattr(transport, 'receive_message')
        True
        >>> hasattr(transport, 'is_connected')
        True
    """

    def __init__(self, websocket: WebSocket):
        """Initialize WebSocket transport.

        The transport starts out disconnected and with no keepalive task;
        call :meth:`connect` to accept the handshake.

        Args:
            websocket: FastAPI WebSocket connection

        Examples:
            >>> # Test initialization with mock WebSocket
            >>> from unittest.mock import Mock
            >>> mock_ws = Mock(spec=WebSocket)
            >>> transport = WebSocketTransport(mock_ws)
            >>> transport._websocket is mock_ws
            True
            >>> transport._connected
            False
            >>> transport._ping_task is None
            True
        """
        self._websocket = websocket
        self._connected = False
        # Background keepalive task; created by connect() only when pings are enabled.
        self._ping_task: Optional[asyncio.Task] = None

    async def connect(self) -> None:
        """Accept the WebSocket handshake and start the optional keepalive task.

        Marks the transport as connected. If
        ``settings.websocket_ping_interval`` is greater than zero, it also
        schedules :meth:`_ping_loop` as a background task.

        Examples:
            >>> # Test connection setup with mock WebSocket
            >>> from unittest.mock import Mock, AsyncMock
            >>> mock_ws = Mock(spec=WebSocket)
            >>> mock_ws.accept = AsyncMock()
            >>> transport = WebSocketTransport(mock_ws)
            >>> import asyncio
            >>> asyncio.run(transport.connect())
            >>> # asyncio.run() cancels the background ping task (if any) on exit,
            >>> # and that task's cleanup calls disconnect(); so check that accept()
            >>> # was awaited instead of inspecting the connection state.
            >>> mock_ws.accept.called
            True
        """
        await self._websocket.accept()
        self._connected = True

        # Start ping task (disabled when the configured interval is 0 or negative)
        if settings.websocket_ping_interval > 0:
            self._ping_task = asyncio.create_task(self._ping_loop())

        logger.info("WebSocket transport connected")

    async def disconnect(self) -> None:
        """Clean up WebSocket connection.

        Cancels the keepalive task, unless this call is made from inside that
        task. It then closes the WebSocket if the transport still considers
        itself connected. ``_connected`` is cleared even if ``close()`` raises,
        and in that case the exception propagates to the caller. When the
        transport is already disconnected, the socket is left untouched.

        Examples:
            >>> # Test disconnection with mock WebSocket
            >>> from unittest.mock import Mock, AsyncMock
            >>> mock_ws = Mock(spec=WebSocket)
            >>> mock_ws.close = AsyncMock()
            >>> transport = WebSocketTransport(mock_ws)
            >>> transport._connected = True
            >>> import asyncio
            >>> asyncio.run(transport.disconnect())
            >>> transport._connected
            False
            >>> mock_ws.close.called
            True

            >>> # Test disconnection when already disconnected
            >>> transport = WebSocketTransport(mock_ws)
            >>> asyncio.run(transport.disconnect())
            >>> transport._connected
            False
        """
        # ────────────────────────────────────────────────────────────────
        # 1. Make sure asyncio calls are still legal
        # ────────────────────────────────────────────────────────────────
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop (interpreter shutdown, for example)
            return

        if loop.is_closed():
            # The loop is already closed - further asyncio calls are illegal
            return

        # ────────────────────────────────────────────────────────────────
        # 2. Stop the keepalive task
        # ────────────────────────────────────────────────────────────────
        # getattr() tolerates instances whose __init__ did not run.
        ping_task = getattr(self, "_ping_task", None)

        # disconnect() is also called from the ping task's own ``finally``
        # block. In that case the task is the current task and must not
        # cancel or await itself.
        should_cancel = ping_task is not None and not ping_task.done() and ping_task is not asyncio.current_task()

        if should_cancel:
            ping_task.cancel()
            try:
                await ping_task  # allow it to exit gracefully
            except asyncio.CancelledError:
                pass

        # ────────────────────────────────────────────────────────────────
        # 3. Close the WebSocket connection (if still open)
        # ────────────────────────────────────────────────────────────────
        if getattr(self, "_connected", False):
            try:
                await self._websocket.close()
            finally:
                # Mark disconnected even if close() failed, so we never retry it.
                self._connected = False
                logger.info("WebSocket transport disconnected")

    async def send_message(self, message: Dict[str, Any]) -> None:
        """Send a message over WebSocket.

        The message is serialized and sent with ``send_json()``. A send
        failure is logged and then re-raised.

        Args:
            message: Message to send

        Raises:
            RuntimeError: If transport is not connected
            Exception: If unable to send json to websocket

        Examples:
            >>> # Test sending message when connected
            >>> from unittest.mock import Mock, AsyncMock
            >>> mock_ws = Mock(spec=WebSocket)
            >>> mock_ws.send_json = AsyncMock()
            >>> transport = WebSocketTransport(mock_ws)
            >>> transport._connected = True
            >>> message = {"jsonrpc": "2.0", "method": "test", "id": 1}
            >>> import asyncio
            >>> asyncio.run(transport.send_message(message))
            >>> mock_ws.send_json.called
            True
            >>> mock_ws.send_json.call_args[0][0]
            {'jsonrpc': '2.0', 'method': 'test', 'id': 1}

            >>> # Test sending message when not connected
            >>> transport = WebSocketTransport(mock_ws)
            >>> try:
            ...     asyncio.run(transport.send_message({"test": "message"}))
            ... except RuntimeError as e:
            ...     print("Expected error:", str(e))
            Expected error: Transport not connected

            >>> # Test message format validation
            >>> transport = WebSocketTransport(mock_ws)
            >>> transport._connected = True
            >>> valid_message = {"jsonrpc": "2.0", "method": "initialize", "params": {}}
            >>> isinstance(valid_message, dict)
            True
            >>> "jsonrpc" in valid_message
            True
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        try:
            await self._websocket.send_json(message)
        except Exception as e:
            logger.error(f"Failed to send message: {e}")
            raise

    async def receive_message(self) -> AsyncGenerator[Dict[str, Any], None]:
        """Receive messages from WebSocket.

        Yields each message decoded by ``receive_json()`` until one of the
        following happens:

        - The client disconnects; this is logged at INFO level.
        - A receive error occurs; this is logged at ERROR level and not raised.
        - The consumer stops iterating.

        In every case :meth:`disconnect` runs on exit. After a receive error,
        our side of the socket is therefore closed as well. Failures while
        closing are logged rather than raised.

        Yields:
            Received messages

        Raises:
            RuntimeError: If transport is not connected

        Examples:
            >>> # Test receive message when connected
            >>> from unittest.mock import Mock, AsyncMock
            >>> mock_ws = Mock(spec=WebSocket)
            >>> mock_ws.receive_json = AsyncMock(return_value={"test": "message"})
            >>> transport = WebSocketTransport(mock_ws)
            >>> transport._connected = True
            >>> import asyncio
            >>> async def test_receive():
            ...     async for msg in transport.receive_message():
            ...         return msg
            ...     return None
            >>> result = asyncio.run(test_receive())
            >>> result
            {'test': 'message'}

            >>> # Test receive message when not connected
            >>> transport = WebSocketTransport(mock_ws)
            >>> try:
            ...     async def test_receive():
            ...         async for msg in transport.receive_message():
            ...             pass
            ...     asyncio.run(test_receive())
            ... except RuntimeError as e:
            ...     print("Expected error:", str(e))
            Expected error: Transport not connected

            >>> # Verify generator behavior
            >>> transport = WebSocketTransport(mock_ws)
            >>> import inspect
            >>> inspect.isasyncgenfunction(transport.receive_message)
            True
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        try:
            while True:
                message = await self._websocket.receive_json()
                yield message

        except WebSocketDisconnect:
            # The peer already closed the socket. Clear the flag so the cleanup
            # below does not try to send a close frame on a dead connection.
            logger.info("WebSocket client disconnected")
            self._connected = False
        except Exception as e:
            # Keep ``_connected`` as-is so disconnect() still closes our side of
            # the socket. Clearing it here would make disconnect() skip close().
            logger.error(f"Error receiving message: {e}")
        finally:
            try:
                await self.disconnect()
            except Exception as close_error:
                # Best-effort cleanup: the socket may already be unusable.
                # disconnect() has cleared ``_connected`` either way.
                logger.warning(f"Error closing WebSocket after receive loop ended: {close_error}")

    async def is_connected(self) -> bool:
        """Check if transport is connected.

        Returns:
            True if connected

        Examples:
            >>> # Test initial state
            >>> from unittest.mock import Mock
            >>> mock_ws = Mock(spec=WebSocket)
            >>> transport = WebSocketTransport(mock_ws)
            >>> import asyncio
            >>> asyncio.run(transport.is_connected())
            False

            >>> # Test after connection
            >>> transport = WebSocketTransport(mock_ws)
            >>> transport._connected = True
            >>> asyncio.run(transport.is_connected())
            True

            >>> # Test after disconnection
            >>> transport = WebSocketTransport(mock_ws)
            >>> transport._connected = True
            >>> transport._connected = False
            >>> asyncio.run(transport.is_connected())
            False
        """
        return self._connected

    async def _ping_loop(self) -> None:
        """Send periodic ping messages to keep connection alive.

        Runs in the following order while the transport is connected:

        1. Sleep for ``settings.websocket_ping_interval`` seconds.
        2. Send ``b"ping"``.
        3. Wait up to half the interval for a reply. A reply other than
           ``b"pong"`` is logged and ignored. A timeout ends the loop.

        Any other exception is logged. However the loop ends, including by
        cancellation, :meth:`disconnect` is called.

        Examples:
            >>> # Test ping loop method exists
            >>> from unittest.mock import Mock
            >>> mock_ws = Mock(spec=WebSocket)
            >>> transport = WebSocketTransport(mock_ws)
            >>> hasattr(transport, '_ping_loop')
            True
            >>> callable(transport._ping_loop)
            True
        """
        try:
            while self._connected:
                await asyncio.sleep(settings.websocket_ping_interval)
                await self._websocket.send_bytes(b"ping")
                try:
                    # NOTE(review): this read may race with receive_message();
                    # see the class docstring.
                    resp = await asyncio.wait_for(
                        self._websocket.receive_bytes(),
                        timeout=settings.websocket_ping_interval / 2,
                    )
                    if resp != b"pong":
                        logger.warning("Invalid ping response")
                except asyncio.TimeoutError:
                    logger.warning("Ping timeout")
                    break
        except Exception as e:
            logger.error(f"Ping loop error: {e}")
        finally:
            # Runs on timeout, error, or cancellation. disconnect() detects that
            # it is running inside this task and will not cancel/await it.
            await self.disconnect()

    async def send_ping(self) -> None:
        """Send a manual ping message.

        Sends ``b"ping"`` only when the transport is connected; otherwise
        this is a no-op.

        Examples:
            >>> # Test manual ping when connected
            >>> from unittest.mock import Mock, AsyncMock
            >>> mock_ws = Mock(spec=WebSocket)
            >>> mock_ws.send_bytes = AsyncMock()
            >>> transport = WebSocketTransport(mock_ws)
            >>> transport._connected = True
            >>> import asyncio
            >>> asyncio.run(transport.send_ping())
            >>> mock_ws.send_bytes.called
            True
            >>> mock_ws.send_bytes.call_args[0][0]
            b'ping'

            >>> # Test manual ping when not connected
            >>> transport = WebSocketTransport(mock_ws)
            >>> transport._connected = False
            >>> asyncio.run(transport.send_ping())
            >>> # Should not call send_bytes when not connected
            >>> mock_ws.send_bytes.call_count
            1
        """
        if self._connected:
            await self._websocket.send_bytes(b"ping")