Coverage for mcpgateway / transports / websocket_transport.py: 100%
81 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/transports/websocket_transport.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7WebSocket Transport Implementation.
8This module implements WebSocket transport for MCP, providing
9full-duplex communication between client and server.
10"""
12# Standard
13import asyncio
14from typing import Any, AsyncGenerator, Dict, Optional
16# Third-Party
17from fastapi import WebSocket, WebSocketDisconnect
19# First-Party
20from mcpgateway.config import settings
21from mcpgateway.services.logging_service import LoggingService
22from mcpgateway.transports.base import Transport
# Initialize the logging service first; the module-level logger used by
# WebSocketTransport below is obtained from it, named after this module.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)
class WebSocketTransport(Transport):
    """Transport implementation using WebSocket.

    This transport implementation uses WebSocket for full-duplex communication
    between the MCP gateway and clients. It provides real-time bidirectional
    messaging with an optional application-level keepalive: when
    ``settings.websocket_ping_interval`` is positive, :meth:`connect` starts a
    background task that periodically sends ``b"ping"`` and expects ``b"pong"``.

    Examples:
        >>> # Note: WebSocket transport requires a FastAPI WebSocket object
        >>> # and cannot be easily tested in doctest environment
        >>> from unittest.mock import Mock
        >>> mock_websocket = Mock(spec=WebSocket)
        >>> transport = WebSocketTransport(mock_websocket)
        >>> transport
        <mcpgateway.transports.websocket_transport.WebSocketTransport object at ...>

        >>> # Check initial connection state
        >>> transport._connected
        False
        >>> transport._ping_task is None
        True

        >>> # Verify it's a proper Transport subclass
        >>> from mcpgateway.transports.base import Transport
        >>> isinstance(transport, Transport)
        True
        >>> issubclass(WebSocketTransport, Transport)
        True

        >>> # Verify required methods exist
        >>> hasattr(transport, 'connect')
        True
        >>> hasattr(transport, 'disconnect')
        True
        >>> hasattr(transport, 'send_message')
        True
        >>> hasattr(transport, 'receive_message')
        True
        >>> hasattr(transport, 'is_connected')
        True
    """

    def __init__(self, websocket: WebSocket):
        """Initialize WebSocket transport.

        Args:
            websocket: FastAPI WebSocket connection

        Examples:
            >>> # Test initialization with mock WebSocket
            >>> from unittest.mock import Mock
            >>> mock_ws = Mock(spec=WebSocket)
            >>> transport = WebSocketTransport(mock_ws)
            >>> transport._websocket is mock_ws
            True
            >>> transport._connected
            False
            >>> transport._ping_task is None
            True
        """
        self._websocket = websocket
        # Set by connect(); cleared by disconnect() and on receive errors.
        self._connected = False
        # Background keepalive task created by connect() when enabled.
        self._ping_task: Optional[asyncio.Task] = None

    async def connect(self) -> None:
        """Accept the WebSocket connection and start the keepalive task.

        The connection is accepted first. If ``settings.websocket_ping_interval``
        is positive, a background task running :meth:`_ping_loop` is started.

        Examples:
            >>> # Test connection setup with mock WebSocket
            >>> from unittest.mock import Mock, AsyncMock
            >>> mock_ws = Mock(spec=WebSocket)
            >>> mock_ws.accept = AsyncMock()
            >>> transport = WebSocketTransport(mock_ws)
            >>> import asyncio
            >>> asyncio.run(transport.connect())
            >>> # asyncio.run() cancels the background ping task (when enabled) on
            >>> # exit, and the ping task's cleanup calls disconnect(); so check that
            >>> # accept() was awaited instead of checking the connection state
            >>> mock_ws.accept.called
            True
        """
        await self._websocket.accept()
        self._connected = True

        # Start the keepalive task; a zero or negative interval disables it.
        if settings.websocket_ping_interval > 0:
            self._ping_task = asyncio.create_task(self._ping_loop())

        logger.info("WebSocket transport connected")

    async def disconnect(self) -> None:
        """Stop the keepalive task and close the WebSocket connection.

        Safe to call repeatedly, and from inside the ping task itself (the ping
        loop calls it from its own cleanup). Does nothing when no event loop is
        running or the running loop is already closed.

        If the task calling this method is cancelled while it waits for the
        ping task to stop, the cancellation propagates to the caller.

        Raises:
            Exception: Any non-cancellation error the ping task finished with,
                or any error raised by ``WebSocket.close()``.

        Examples:
            >>> # Test disconnection with mock WebSocket
            >>> from unittest.mock import Mock, AsyncMock
            >>> mock_ws = Mock(spec=WebSocket)
            >>> mock_ws.close = AsyncMock()
            >>> transport = WebSocketTransport(mock_ws)
            >>> transport._connected = True
            >>> import asyncio
            >>> asyncio.run(transport.disconnect())
            >>> transport._connected
            False
            >>> mock_ws.close.called
            True

            >>> # Test disconnection when already disconnected
            >>> transport = WebSocketTransport(mock_ws)
            >>> asyncio.run(transport.disconnect())
            >>> transport._connected
            False
        """
        # ────────────────────────────────────────────────────────────────
        # 1. Bail out if asyncio can no longer be used
        # ────────────────────────────────────────────────────────────────
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop (interpreter shutdown, for example)
            return

        if loop.is_closed():
            # The loop is already closed - further asyncio calls are illegal
            return

        # ────────────────────────────────────────────────────────────────
        # 2. Stop the ping task (unless it is the task running this code)
        # ────────────────────────────────────────────────────────────────
        ping_task = getattr(self, "_ping_task", None)

        should_cancel = (
            ping_task is not None  # task exists
            and not ping_task.done()  # still running
            and ping_task is not asyncio.current_task()  # not *this* coroutine
        )

        if should_cancel:
            ping_task.cancel()
            # Use asyncio.wait() instead of ``await ping_task`` guarded by
            # ``except CancelledError: pass``. wait() never raises the awaited
            # task's own outcome, so a CancelledError escaping here can only mean
            # the caller of disconnect() was cancelled, and it must propagate
            # instead of being swallowed.
            await asyncio.wait({ping_task})
            self._ping_task = None
            if not ping_task.cancelled():
                # Re-raise any error the ping task ended with, as awaiting it would.
                ping_task.result()

        # ────────────────────────────────────────────────────────────────
        # 3. Close the WebSocket connection (if still open)
        # ────────────────────────────────────────────────────────────────
        if getattr(self, "_connected", False):
            try:
                await self._websocket.close()
            finally:
                self._connected = False
                logger.info("WebSocket transport disconnected")

    async def send_message(self, message: Dict[str, Any]) -> None:
        """Send a message over WebSocket.

        Args:
            message: Message to send

        Raises:
            RuntimeError: If transport is not connected
            Exception: If unable to send json to websocket (logged, then re-raised)

        Examples:
            >>> # Test sending message when connected
            >>> from unittest.mock import Mock, AsyncMock
            >>> mock_ws = Mock(spec=WebSocket)
            >>> mock_ws.send_json = AsyncMock()
            >>> transport = WebSocketTransport(mock_ws)
            >>> transport._connected = True
            >>> message = {"jsonrpc": "2.0", "method": "test", "id": 1}
            >>> import asyncio
            >>> asyncio.run(transport.send_message(message))
            >>> mock_ws.send_json.called
            True
            >>> mock_ws.send_json.call_args[0][0]
            {'jsonrpc': '2.0', 'method': 'test', 'id': 1}

            >>> # Test sending message when not connected
            >>> transport = WebSocketTransport(mock_ws)
            >>> try:
            ...     asyncio.run(transport.send_message({"test": "message"}))
            ... except RuntimeError as e:
            ...     print("Expected error:", str(e))
            Expected error: Transport not connected

            >>> # Test message format validation
            >>> transport = WebSocketTransport(mock_ws)
            >>> transport._connected = True
            >>> valid_message = {"jsonrpc": "2.0", "method": "initialize", "params": {}}
            >>> isinstance(valid_message, dict)
            True
            >>> "jsonrpc" in valid_message
            True
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        try:
            await self._websocket.send_json(message)
        except Exception as e:
            logger.error(f"Failed to send message: {e}")
            raise

    async def receive_message(self) -> AsyncGenerator[Dict[str, Any], None]:
        """Receive messages from WebSocket.

        Yields decoded JSON messages until the client disconnects or a receive
        error occurs. In every case the transport is cleaned up via
        :meth:`disconnect` when the generator finishes or is closed.

        Yields:
            Received messages

        Raises:
            RuntimeError: If transport is not connected

        Examples:
            >>> # Test receive message when connected
            >>> from unittest.mock import Mock, AsyncMock
            >>> mock_ws = Mock(spec=WebSocket)
            >>> mock_ws.receive_json = AsyncMock(return_value={"test": "message"})
            >>> transport = WebSocketTransport(mock_ws)
            >>> transport._connected = True
            >>> import asyncio
            >>> async def test_receive():
            ...     async for msg in transport.receive_message():
            ...         return msg
            ...     return None
            >>> result = asyncio.run(test_receive())
            >>> result
            {'test': 'message'}

            >>> # Test receive message when not connected
            >>> transport = WebSocketTransport(mock_ws)
            >>> try:
            ...     async def test_receive():
            ...         async for msg in transport.receive_message():
            ...             pass
            ...     asyncio.run(test_receive())
            ... except RuntimeError as e:
            ...     print("Expected error:", str(e))
            Expected error: Transport not connected

            >>> # Verify generator behavior
            >>> transport = WebSocketTransport(mock_ws)
            >>> import inspect
            >>> inspect.isasyncgenfunction(transport.receive_message)
            True
        """
        if not self._connected:
            raise RuntimeError("Transport not connected")

        try:
            while True:
                message = await self._websocket.receive_json()
                yield message

        except WebSocketDisconnect:
            logger.info("WebSocket client disconnected")
            # The peer closed the socket; disconnect() below must not close it again.
            self._connected = False
        except Exception as e:
            logger.error(f"Error receiving message: {e}")
            # NOTE(review): clearing _connected here makes disconnect() below skip
            # WebSocket.close(). After a non-disconnect error (presumably e.g. a
            # frame receive_json() cannot decode), the socket is therefore left
            # for the endpoint to close. Confirm this is intended.
            self._connected = False
        finally:
            await self.disconnect()

    async def is_connected(self) -> bool:
        """Check if transport is connected.

        Returns:
            True if connected

        Examples:
            >>> # Test initial state
            >>> from unittest.mock import Mock
            >>> mock_ws = Mock(spec=WebSocket)
            >>> transport = WebSocketTransport(mock_ws)
            >>> import asyncio
            >>> asyncio.run(transport.is_connected())
            False

            >>> # Test after connection
            >>> transport = WebSocketTransport(mock_ws)
            >>> transport._connected = True
            >>> asyncio.run(transport.is_connected())
            True

            >>> # Test after disconnection
            >>> transport = WebSocketTransport(mock_ws)
            >>> transport._connected = True
            >>> transport._connected = False
            >>> asyncio.run(transport.is_connected())
            False
        """
        return self._connected

    async def _ping_loop(self) -> None:
        """Send periodic ping messages to keep connection alive.

        Every ``settings.websocket_ping_interval`` seconds this sends ``b"ping"``.
        It then waits up to half that interval for a binary reply. A reply other
        than ``b"pong"`` is logged as a warning; a timeout ends the loop. The loop
        always finishes by calling :meth:`disconnect`.

        Examples:
            >>> # Test ping loop method exists
            >>> from unittest.mock import Mock
            >>> mock_ws = Mock(spec=WebSocket)
            >>> transport = WebSocketTransport(mock_ws)
            >>> hasattr(transport, '_ping_loop')
            True
            >>> callable(transport._ping_loop)
            True
        """
        try:
            while self._connected:
                await asyncio.sleep(settings.websocket_ping_interval)
                await self._websocket.send_bytes(b"ping")
                try:
                    # NOTE(review): receive_message() reads from this same socket.
                    # While both run, either coroutine may consume the other's
                    # frame (an application message here, or the pong in
                    # receive_json()). Confirm how clients are expected to reply.
                    resp = await asyncio.wait_for(
                        self._websocket.receive_bytes(),
                        timeout=settings.websocket_ping_interval / 2,
                    )
                    if resp != b"pong":
                        logger.warning("Invalid ping response")
                except asyncio.TimeoutError:
                    logger.warning("Ping timeout")
                    break
        except Exception as e:
            logger.error(f"Ping loop error: {e}")
        finally:
            # disconnect() recognises it is running inside this task and skips
            # cancelling it.
            await self.disconnect()

    async def send_ping(self) -> None:
        """Send a manual ping message; does nothing when not connected.

        Examples:
            >>> # Test manual ping when connected
            >>> from unittest.mock import Mock, AsyncMock
            >>> mock_ws = Mock(spec=WebSocket)
            >>> mock_ws.send_bytes = AsyncMock()
            >>> transport = WebSocketTransport(mock_ws)
            >>> transport._connected = True
            >>> import asyncio
            >>> asyncio.run(transport.send_ping())
            >>> mock_ws.send_bytes.called
            True
            >>> mock_ws.send_bytes.call_args[0][0]
            b'ping'

            >>> # Test manual ping when not connected
            >>> transport = WebSocketTransport(mock_ws)
            >>> transport._connected = False
            >>> asyncio.run(transport.send_ping())
            >>> # Should not call send_bytes when not connected
            >>> mock_ws.send_bytes.call_count
            1
        """
        if self._connected:
            await self._websocket.send_bytes(b"ping")