Coverage for mcpgateway / services / root_service.py: 100%

98 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/services/root_service.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7Root Service Implementation. 

8This module implements root directory management according to the MCP specification. 

9It handles root registration, validation, and change notifications. 

10""" 

11 

12# Standard 

13import asyncio 

14import os 

15from typing import AsyncGenerator, Dict, List, Optional 

16from urllib.parse import urlparse 

17 

18# First-Party 

19from mcpgateway.common.models import Root 

20from mcpgateway.config import settings 

21from mcpgateway.observability import create_span 

22from mcpgateway.services.logging_service import LoggingService 

23 

# Initialize logging service first: ``logger`` is used throughout this module
# (RootService lifecycle and mutation methods log through it), so it must be
# bound before any of that code runs.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

27 

28 

class RootServiceError(Exception):
    """Root service failure.

    Common ancestor of the errors RootService raises, letting callers handle
    every root-service problem with a single ``except`` clause.

    Examples:
        >>> err = RootServiceError("boom")
        >>> str(err)
        'boom'
        >>> isinstance(err, Exception)
        True
    """

31 

32 

class RootServiceNotFoundError(RootServiceError):
    """Signals that no root is registered under the requested URI.

    Because it derives from RootServiceError, handlers written for the base
    class catch it as well.

    Examples:
        >>> error = RootServiceNotFoundError("Root Service not found")
        >>> str(error)
        'Root Service not found'
        >>> isinstance(error, RootServiceError)
        True
        >>> issubclass(RootServiceNotFoundError, Exception)
        True
    """

43 

44 

class RootService:
    """MCP root service.

    Manages the roots that can be exposed to MCP clients:

    - Root registration and URI validation
    - Change notifications to in-process subscribers

    Roots live in an in-memory dict keyed by ``str(Root.uri)`` (the model's
    form of the URI). Lookups by URI go through ``_lookup_key`` so callers may
    pass either that stored form or the path/URI the root was registered
    with. No permission or access checks are performed.
    """

    def __init__(self) -> None:
        """Initialize root service with no roots and no subscribers."""
        # Registered roots keyed by str(Root.uri).
        self._roots: Dict[str, Root] = {}
        # One queue per active subscribe_changes() consumer.
        self._subscribers: List[asyncio.Queue] = []

    async def initialize(self) -> None:
        """Initialize root service by registering ``settings.default_roots``.

        A default root that fails validation is logged and skipped, so one bad
        entry does not abort start-up.

        Examples:
            >>> from mcpgateway.services.root_service import RootService
            >>> import asyncio
            >>> service = RootService()
            >>> asyncio.run(service.initialize())

            Test with default roots configured:
            >>> from unittest.mock import patch
            >>> service = RootService()
            >>> with patch('mcpgateway.config.settings.default_roots', ['file:///tmp', 'http://example.com']):
            ...     asyncio.run(service.initialize())
            >>> len(service._roots)
            2
        """
        logger.info("Initializing root service")
        for root_uri in settings.default_roots:
            try:
                await self.add_root(root_uri)
            except RootServiceError as e:
                logger.error(f"Failed to add default root {root_uri}: {e}")

    async def shutdown(self) -> None:
        """Shutdown root service, dropping all roots and subscriber queues.

        Active ``subscribe_changes`` generators stop receiving events; they can
        still be cancelled or closed without error afterwards.

        Examples:
            >>> from mcpgateway.services.root_service import RootService
            >>> import asyncio
            >>> service = RootService()
            >>> asyncio.run(service.shutdown())

            Test cleanup of roots and subscribers:
            >>> service = RootService()
            >>> _ = asyncio.run(service.add_root('file:///tmp'))
            >>> service._subscribers.append(asyncio.Queue())
            >>> asyncio.run(service.shutdown())
            >>> len(service._roots)
            0
            >>> len(service._subscribers)
            0
        """
        logger.info("Shutting down root service")
        self._roots.clear()
        self._subscribers.clear()

    async def list_roots(self) -> List[Root]:
        """List available roots.

        Returns:
            List of registered roots, in registration order.

        Examples:
            >>> from mcpgateway.services.root_service import RootService
            >>> import asyncio
            >>> service = RootService()
            >>> asyncio.run(service.list_roots())
            []

            Test with multiple roots:
            >>> service = RootService()
            >>> _ = asyncio.run(service.add_root('file:///tmp'))
            >>> _ = asyncio.run(service.add_root('file:///home'))
            >>> roots = asyncio.run(service.list_roots())
            >>> len(roots)
            2
            >>> sorted([str(r.uri) for r in roots])
            ['file:///home', 'file:///tmp']
        """
        with create_span("root.list", {"root.count": len(self._roots)}):
            return list(self._roots.values())

    async def add_root(self, uri: str, name: Optional[str] = None) -> Root:
        """Add a new root.

        Args:
            uri: Root URI, or a filesystem path (turned into a ``file://`` URI).
            name: Optional root name; defaults to the last path segment of the
                URI, or the URI itself when there is no such segment.

        Returns:
            Created root object

        Raises:
            RootServiceError: If the URI is invalid (including when the Root
                model rejects it) or the root already exists.

        Examples:
            >>> from mcpgateway.services.root_service import RootService
            >>> import asyncio
            >>> service = RootService()
            >>> root = asyncio.run(service.add_root('file:///tmp'))
            >>> root.uri == 'file:///tmp'
            True

            Test with custom name:
            >>> service = RootService()
            >>> root = asyncio.run(service.add_root('file:///home/user', 'MyHome'))
            >>> root.name
            'MyHome'

            Default name ignores a trailing slash:
            >>> service = RootService()
            >>> asyncio.run(service.add_root('file:///home/user/')).name
            'user'

            Test duplicate root error:
            >>> service = RootService()
            >>> _ = asyncio.run(service.add_root('file:///tmp'))
            >>> try:
            ...     asyncio.run(service.add_root('file:///tmp'))
            ... except RootServiceError as e:
            ...     str(e)
            'Root already exists: file:///tmp'

            Test invalid URI error:
            >>> from unittest.mock import patch
            >>> service = RootService()
            >>> with patch.object(service, '_make_root_uri', side_effect=ValueError('Bad URI')):
            ...     try:
            ...         asyncio.run(service.add_root('bad_uri'))
            ...     except RootServiceError as e:
            ...         str(e)
            'Invalid root URI: Bad URI'
        """
        try:
            root_uri = self._make_root_uri(uri)
            # Strip a trailing slash first so "file:///home/user/" is named
            # "user" instead of falling back to the whole URI.
            default_name = os.path.basename(urlparse(root_uri).path.rstrip("/"))
            # Model validation failures are assumed to subclass ValueError (true
            # for pydantic's ValidationError); they are reported like any other
            # invalid URI instead of escaping as a foreign exception type.
            root_obj = Root(uri=root_uri, name=name or default_name or root_uri)
        except ValueError as e:
            raise RootServiceError(f"Invalid root URI: {e}") from e

        # Key by the model's form of the URI so equivalent spellings collide.
        normalized_key = str(root_obj.uri)
        if normalized_key in self._roots:
            raise RootServiceError(f"Root already exists: {root_uri}")

        self._roots[normalized_key] = root_obj
        await self._notify_root_added(root_obj)
        logger.info(f"Added root: {root_uri}")
        return root_obj

    async def get_root_by_uri(self, root_uri: str) -> Root:
        """Get a root by URI.

        Args:
            root_uri: Root URI (or path) to retrieve

        Returns:
            Root: The found root object

        Raises:
            RootServiceNotFoundError: If root not found

        Examples:
            >>> from mcpgateway.services.root_service import RootService
            >>> import asyncio
            >>> service = RootService()
            >>> _ = asyncio.run(service.add_root('file:///tmp'))
            >>> root = asyncio.run(service.get_root_by_uri('file:///tmp'))
            >>> root.uri == 'file:///tmp'
            True

            A root is found with the same spelling it was added with, even if
            the model normalises the stored key:
            >>> service = RootService()
            >>> _ = asyncio.run(service.add_root('/tmp/my dir'))
            >>> asyncio.run(service.get_root_by_uri('/tmp/my dir')).name
            'my dir'

            Test root not found error:
            >>> service = RootService()
            >>> try:
            ...     asyncio.run(service.get_root_by_uri('file:///nonexistent'))
            ... except RootServiceError as e:
            ...     str(e)
            'Root not found: file:///nonexistent'
        """
        return self._roots[self._lookup_key(root_uri)]

    async def update_root(self, root_uri: str, name: Optional[str] = None) -> Root:
        """Update an existing root.

        Args:
            root_uri: Root URI (or path) to update
            name: New name for the root; ``None`` leaves the name unchanged

        Returns:
            Root: The updated root object

        Raises:
            RootServiceNotFoundError: If root is not found

        Examples:
            >>> from mcpgateway.services.root_service import RootService
            >>> import asyncio
            >>> service = RootService()
            >>> _ = asyncio.run(service.add_root('file:///tmp', 'Temp'))
            >>> updated = asyncio.run(service.update_root('file:///tmp', 'Updated Temp'))
            >>> updated.name
            'Updated Temp'

            Test root not found error:
            >>> service = RootService()
            >>> try:
            ...     asyncio.run(service.update_root('file:///nonexistent', 'New Name'))
            ... except RootServiceError as e:
            ...     str(e)
            'Root not found: file:///nonexistent'
        """
        root_obj = self._roots[self._lookup_key(root_uri)]

        if name is not None:
            root_obj.name = name

        # A "root_updated" event is sent even when the name did not change.
        event = {"type": "root_updated", "data": {"uri": root_obj.uri, "name": root_obj.name}}
        await self._notify_subscribers(event)

        logger.info(f"Updated root: {root_uri}, name: {name}")
        return root_obj

    async def remove_root(self, root_uri: str) -> None:
        """Remove a registered root.

        Args:
            root_uri: Root URI (or path) to remove

        Raises:
            RootServiceNotFoundError: If root not found (a RootServiceError
                subclass, so existing ``except RootServiceError`` handlers still
                apply).

        Examples:
            >>> from mcpgateway.services.root_service import RootService
            >>> import asyncio
            >>> service = RootService()
            >>> _ = asyncio.run(service.add_root('file:///tmp'))
            >>> asyncio.run(service.remove_root('file:///tmp'))

            Test root not found error:
            >>> service = RootService()
            >>> try:
            ...     asyncio.run(service.remove_root('file:///nonexistent'))
            ... except RootServiceNotFoundError as e:
            ...     str(e)
            'Root not found: file:///nonexistent'
        """
        root_obj = self._roots.pop(self._lookup_key(root_uri))
        await self._notify_root_removed(root_obj)
        logger.info(f"Removed root: {root_uri}")

    async def subscribe_changes(self) -> AsyncGenerator[Dict, None]:
        """Subscribe to root changes.

        Yields:
            Root change events: dicts with a ``type`` of ``"root_added"``,
            ``"root_updated"`` or ``"root_removed"`` and a ``data`` payload.

        Examples:
            This example demonstrates subscription mechanics:
            >>> import asyncio
            >>> from mcpgateway.services.root_service import RootService
            >>> async def test_subscribe():
            ...     service = RootService()
            ...     events = []
            ...     async def collect_events():
            ...         async for event in service.subscribe_changes():
            ...             events.append(event)
            ...             if event['type'] == 'root_removed':
            ...                 break
            ...     task = asyncio.create_task(collect_events())
            ...     await asyncio.sleep(0)  # Let subscription start
            ...     await service.add_root('file:///tmp')
            ...     await service.remove_root('file:///tmp')
            ...     await task
            ...     return events
            >>> events = asyncio.run(test_subscribe())
            >>> len(events)
            2
            >>> events[0]['type']
            'root_added'
            >>> events[1]['type']
            'root_removed'

            Cancelling a consumer after shutdown() cleared the subscriber list
            does not raise from the cleanup:
            >>> async def cancel_after_shutdown():
            ...     service = RootService()
            ...     async def consume():
            ...         async for _ in service.subscribe_changes():
            ...             pass
            ...     task = asyncio.create_task(consume())
            ...     await asyncio.sleep(0)
            ...     await service.shutdown()
            ...     task.cancel()
            ...     try:
            ...         await task
            ...     except asyncio.CancelledError:
            ...         return 'cancelled cleanly'
            >>> asyncio.run(cancel_after_shutdown())
            'cancelled cleanly'
        """
        queue: asyncio.Queue = asyncio.Queue()
        self._subscribers.append(queue)
        try:
            while True:
                event = await queue.get()
                yield event
        finally:
            # shutdown() may already have cleared the list; an unguarded
            # remove() would raise ValueError here and mask the original
            # exit reason (e.g. CancelledError).
            if queue in self._subscribers:
                self._subscribers.remove(queue)

    def _make_root_uri(self, uri: str) -> str:
        """Convert input to a valid URI.

        If no scheme is provided, the input is treated as a filesystem path:
        relative paths are made absolute against the current working directory
        and the result is prefixed with ``file://``. Inputs that already carry
        a scheme are returned unchanged.

        Args:
            uri: Input URI or filesystem path

        Returns:
            A URI string

        Raises:
            ValueError: If ``uri`` is empty or only whitespace.

        Examples:
            >>> service = RootService()
            >>> service._make_root_uri('/tmp')
            'file:///tmp'
            >>> service._make_root_uri('file:///home')
            'file:///home'
            >>> service._make_root_uri('http://example.com')
            'http://example.com'
            >>> service._make_root_uri('ftp://server/path')
            'ftp://server/path'
            >>> import os
            >>> service._make_root_uri('relative/dir') == 'file://' + os.path.abspath('relative/dir')
            True
            >>> try:
            ...     service._make_root_uri('   ')
            ... except ValueError as e:
            ...     str(e)
            'Root URI must not be empty'
        """
        if not uri or not uri.strip():
            raise ValueError("Root URI must not be empty")
        if urlparse(uri).scheme:
            # A scheme is present (file, http, https, ftp, ...): keep as-is.
            return uri
        # Without this, "tmp/x" would become "file://tmp/x", where "tmp" is
        # parsed as a host name rather than a directory.
        path = uri if os.path.isabs(uri) else os.path.abspath(uri)
        return f"file://{path}"

    def _lookup_key(self, root_uri: str) -> str:
        """Return the registry key under which ``root_uri`` is stored.

        ``add_root`` keys roots by ``str(Root.uri)``, which may differ from the
        raw URI (the model can normalise it). The raw URI from
        ``_make_root_uri`` is tried first; on a miss the model-normalised form
        is tried as well.

        Args:
            root_uri: Root URI or filesystem path supplied by the caller.

        Returns:
            The key present in ``self._roots``.

        Raises:
            RootServiceNotFoundError: If no registered root matches.

        Examples:
            >>> import asyncio
            >>> service = RootService()
            >>> _ = asyncio.run(service.add_root('file:///tmp'))
            >>> service._lookup_key('/tmp')
            'file:///tmp'
            >>> try:
            ...     service._lookup_key('file:///missing')
            ... except RootServiceNotFoundError as e:
            ...     str(e)
            'Root not found: file:///missing'
        """
        try:
            candidate = self._make_root_uri(root_uri)
        except ValueError:
            # An input that cannot form a URI can never have been registered.
            raise RootServiceNotFoundError(f"Root not found: {root_uri}") from None

        if candidate not in self._roots:
            try:
                candidate = str(Root(uri=candidate, name=candidate).uri)
            except ValueError:
                # The model rejects this URI, so add_root could not have
                # stored it; fall through to the not-found error below.
                pass

        if candidate not in self._roots:
            raise RootServiceNotFoundError(f"Root not found: {root_uri}")
        return candidate

    async def _notify_root_added(self, root: Root) -> None:
        """Notify subscribers of root addition.

        Args:
            root: Added root

        Note:
            ``root.uri`` is placed in the event unchanged (not converted to
            ``str``), so subscribers receive the model's URI object.

        Examples:
            >>> import asyncio
            >>> from mcpgateway.services.root_service import RootService
            >>> from mcpgateway.common.models import Root
            >>> service = RootService()
            >>> queue = asyncio.Queue()
            >>> service._subscribers.append(queue)
            >>> root = Root(uri='file:///tmp', name='temp')
            >>> asyncio.run(service._notify_root_added(root))
            >>> event = asyncio.run(queue.get())
            >>> event['type']
            'root_added'
            >>> event['data']['uri']
            FileUrl('file:///tmp')
        """
        event = {"type": "root_added", "data": {"uri": root.uri, "name": root.name}}
        await self._notify_subscribers(event)

    async def _notify_root_removed(self, root: Root) -> None:
        """Notify subscribers of root removal.

        The event payload carries only the URI, not the name.

        Args:
            root: Removed root

        Examples:
            >>> import asyncio
            >>> from mcpgateway.services.root_service import RootService
            >>> from mcpgateway.common.models import Root
            >>> service = RootService()
            >>> queue = asyncio.Queue()
            >>> service._subscribers.append(queue)
            >>> root = Root(uri='file:///tmp', name='temp')
            >>> asyncio.run(service._notify_root_removed(root))
            >>> event = asyncio.run(queue.get())
            >>> event['type']
            'root_removed'
            >>> event['data']['uri']
            FileUrl('file:///tmp')
        """
        event = {"type": "root_removed", "data": {"uri": root.uri}}
        await self._notify_subscribers(event)

    async def _notify_subscribers(self, event: Dict) -> None:
        """Send event to all subscribers.

        Delivery is best-effort: a failure on one subscriber queue is logged
        and does not prevent delivery to the others.

        Args:
            event: Event to send

        Examples:
            >>> import asyncio
            >>> from mcpgateway.services.root_service import RootService
            >>> service = RootService()
            >>> queue1 = asyncio.Queue()
            >>> queue2 = asyncio.Queue()
            >>> service._subscribers.extend([queue1, queue2])
            >>> event = {"type": "test", "data": {}}
            >>> asyncio.run(service._notify_subscribers(event))
            >>> asyncio.run(queue1.get()) == event
            True
            >>> asyncio.run(queue2.get()) == event
            True

            Test error handling with closed queue:
            >>> import logging
            >>> logging.disable(logging.CRITICAL)
            >>> from unittest.mock import AsyncMock
            >>> service = RootService()
            >>> bad_queue = AsyncMock()
            >>> bad_queue.put.side_effect = Exception("Queue error")
            >>> service._subscribers.append(bad_queue)
            >>> asyncio.run(service._notify_subscribers({"type": "test"}))
            >>> logging.disable(logging.NOTSET)
        """
        # Iterate over a snapshot: a subscriber may unsubscribe (its generator
        # finalizer removes its queue) while we are awaiting put().
        for queue in list(self._subscribers):
            try:
                await queue.put(event)
            except Exception as e:
                logger.error(f"Failed to notify subscriber: {e}")

477 

478 

# Lazy singleton - created on first access, not at module import time.
# This avoids instantiation when only exception classes are imported.
# Stays None until the module-level __getattr__ is asked for "root_service".
_root_service_instance: Optional[RootService] = None  # pylint: disable=invalid-name

482 

483 

def __getattr__(name: str):
    """Resolve module attributes that are not defined eagerly (PEP 562 hook).

    Only ``root_service`` is served here. The shared RootService is
    constructed on the first such lookup and reused afterwards.

    Args:
        name: Attribute name requested on this module.

    Returns:
        The shared RootService instance when ``name`` is ``"root_service"``.

    Raises:
        AttributeError: For every other attribute name.
    """
    global _root_service_instance  # pylint: disable=global-statement
    if name != "root_service":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    if _root_service_instance is None:
        # First access: build the singleton now rather than at import time.
        _root_service_instance = RootService()
    return _root_service_instance