Coverage for mcpgateway / services / tag_service.py: 100%

143 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/services/tag_service.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7Tag Service Implementation. 

8This module implements tag management and retrieval for all entities in the MCP Gateway. 

9It handles: 

10- Fetching all unique tags across entities 

11- Filtering tags by entity type 

12- Tag statistics and counts 

13- Retrieving entities that have specific tags 

14""" 

15 

16# Standard 

17import logging 

18from typing import Dict, List, Optional 

19 

20# Third-Party 

21from sqlalchemy import func, select 

22from sqlalchemy.orm import Session 

23 

24# First-Party 

25from mcpgateway.db import Gateway as DbGateway 

26from mcpgateway.db import Prompt as DbPrompt 

27from mcpgateway.db import Resource as DbResource 

28from mcpgateway.db import Server as DbServer 

29from mcpgateway.db import Tool as DbTool 

30from mcpgateway.schemas import TaggedEntity, TagInfo, TagStats 

31from mcpgateway.utils.sqlalchemy_modifier import json_contains_tag_expr 

32 

logger = logging.getLogger(__name__)

# Admin stats cache singleton. It is resolved lazily on first use because
# importing it at module load time would create a circular dependency.
_ADMIN_STATS_CACHE = None


def _get_admin_stats_cache():
    """Return the shared admin stats cache, importing it on first access.

    The imported object is memoised in the module-level ``_ADMIN_STATS_CACHE``
    so the import only happens once.

    Returns:
        AdminStatsCache instance.
    """
    global _ADMIN_STATS_CACHE  # pylint: disable=global-statement
    if _ADMIN_STATS_CACHE is not None:
        return _ADMIN_STATS_CACHE

    # First-Party
    from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

    _ADMIN_STATS_CACHE = admin_stats_cache
    return _ADMIN_STATS_CACHE

52 

53 

class TagService:
    """Service for managing and retrieving tags across all entities.

    This service provides comprehensive tag management functionality across all MCP Gateway
    entity types (tools, resources, prompts, servers, gateways). It handles tag discovery,
    entity lookup by tags, and statistics aggregation.

    Tags stored on an entity may be legacy plain strings or dicts carrying an ``id`` (or
    ``label``) key; both forms are normalised to a string tag id before they are compared,
    counted or returned.

    Example:
        >>> from unittest.mock import MagicMock
        >>> from mcpgateway.schemas import TagInfo, TagStats, TaggedEntity
        >>>
        >>> # Create service instance
        >>> service = TagService()
        >>>
        >>> # Mock database session
        >>> mock_db = MagicMock()
        >>>
        >>> # Test basic functionality
        >>> isinstance(service, TagService)
        True
    """

    # Supported entity-type names. Each one is also the name of the matching counter
    # field on TagStats, which _update_stats relies on.
    _ENTITY_TYPES = ("tools", "resources", "prompts", "servers", "gateways")

    async def get_all_tags(self, db: Session, entity_types: Optional[List[str]] = None, include_entities: bool = False) -> List[TagInfo]:
        """Retrieve all unique tags across specified entity types.

        This method aggregates tags from multiple entity types and returns comprehensive
        statistics about tag usage. It can optionally include detailed information about
        which entities have each tag.

        Results for ``include_entities=False`` are stored in the admin stats cache, keyed by
        the deduplicated set of valid entity types requested (or ``"all"`` when
        ``entity_types`` is None). Entity-level results are never cached.

        Args:
            db: Database session for querying entity data
            entity_types: List of entity types to filter by. Valid types are:
                ['tools', 'resources', 'prompts', 'servers', 'gateways'].
                If None, returns tags from all entity types. Unknown names and
                duplicates are ignored; a list containing no valid type (including
                an empty list) yields an empty result without touching the cache.
            include_entities: Whether to include the list of entities that have each tag.
                If False, only statistics are returned for better performance.

        Returns:
            List of TagInfo objects containing tag details, sorted alphabetically by tag name.
            Each TagInfo includes:
            - name: The tag name
            - stats: Usage statistics per entity type
            - entities: List of entities with this tag (if include_entities=True)

        Example:
            >>> import asyncio
            >>> from unittest.mock import MagicMock, AsyncMock
            >>>
            >>> # Create service and mock database
            >>> service = TagService()
            >>> mock_db = MagicMock()
            >>>
            >>> # Mock empty result
            >>> mock_db.execute.return_value.__iter__ = lambda self: iter([])
            >>>
            >>> # Test with empty database
            >>> async def test_empty():
            ...     tags = await service.get_all_tags(mock_db)
            ...     return len(tags)
            >>> asyncio.run(test_empty())
            0

            >>> # Mock result with tag data
            >>> mock_result = MagicMock()
            >>> mock_result.__iter__ = lambda self: iter([
            ...     (["api", "database"],),
            ...     (["api", "web"],),
            ... ])
            >>> mock_db.execute.return_value = mock_result
            >>>
            >>> # Test with tag data
            >>> async def test_with_tags():
            ...     tags = await service.get_all_tags(mock_db, entity_types=["tools"])
            ...     return len(tags) >= 2  # Should have at least api, database, web tags
            >>> asyncio.run(test_with_tags())
            True

            >>> # include_entities=True path
            >>> from types import SimpleNamespace
            >>> entity = SimpleNamespace(id='1', name='E', description='d', tags=['api'])
            >>> mock_result2 = MagicMock()
            >>> mock_result2.scalars.return_value = [entity]
            >>> mock_db.execute.return_value = mock_result2
            >>> async def test_with_entities():
            ...     tags = await service.get_all_tags(mock_db, entity_types=["tools"], include_entities=True)
            ...     return len(tags) == 1 and tags[0].entities[0].name == 'E'
            >>> asyncio.run(test_with_entities())
            True

            >>> # Duplicate entity types are only processed once
            >>> tags = asyncio.run(service.get_all_tags(mock_db, entity_types=["tools", "tools"], include_entities=True))
            >>> tags[0].stats.tools, tags[0].stats.total
            (1, 1)

            >>> # An explicit empty selection returns nothing
            >>> asyncio.run(service.get_all_tags(mock_db, entity_types=[]))
            []

        Raises:
            SQLAlchemyError: Propagated if a database query fails.
        """
        entity_map = self._get_entity_map()
        requested_types = self._resolve_entity_types(entity_types, entity_map)

        # Nothing valid was requested. Return before touching the cache so an empty result
        # can never be read from, or written to, a cache entry shared with other queries.
        if not requested_types:
            return []

        # Generate cache key from the normalised parameters
        entity_types_key = "all" if entity_types is None else ":".join(sorted(requested_types))
        cache_key = f"{entity_types_key}:{include_entities}"

        # Check cache first (only for non-entity queries as entity data is large)
        if not include_entities:
            cached = await _get_admin_stats_cache().get_tags(cache_key)
            if cached is not None:
                # Reconstruct TagInfo objects from cached dicts
                return [TagInfo.model_validate(t) for t in cached]

        tag_data: Dict[str, Dict] = {}

        for entity_type in requested_types:
            model = entity_map[entity_type]

            if include_entities:
                # Full rows are needed to describe each tagged entity.
                result = db.execute(select(model).where(model.tags.isnot(None)))
                for entity in result.scalars():
                    raw_tags = entity.tags or []
                    if not raw_tags:
                        continue
                    tagged_entity = self._build_tagged_entity(entity, entity_type)
                    for raw_tag in raw_tags:
                        entry = self._get_or_create_tag_entry(tag_data, self._get_tag_id(raw_tag))
                        entry["entities"].append(tagged_entity)
                        self._update_stats(entry["stats"], entity_type)
            else:
                # Only the tag column is needed for statistics.
                result = db.execute(select(model.tags).where(model.tags.isnot(None)))
                for row in result:
                    for raw_tag in row[0] or []:
                        entry = self._get_or_create_tag_entry(tag_data, self._get_tag_id(raw_tag))
                        self._update_stats(entry["stats"], entity_type)

        # Convert to TagInfo list, sorted alphabetically by tag name
        tags = [TagInfo(name=tag, stats=tag_data[tag]["stats"], entities=tag_data[tag]["entities"] if include_entities else []) for tag in sorted(tag_data)]

        # Store in cache (only for non-entity queries)
        if not include_entities:
            await _get_admin_stats_cache().set_tags([t.model_dump() for t in tags], cache_key)

        return tags

    def _update_stats(self, stats: TagStats, entity_type: str) -> None:
        """Update statistics for a specific entity type.

        This helper method increments the appropriate counter in the TagStats object
        based on the entity type and maintains the total count.

        Args:
            stats: TagStats object to update with new counts
            entity_type: Type of entity to increment count for. Must be one of:
                'tools', 'resources', 'prompts', 'servers', 'gateways'.
                Any other value is ignored.

        Example:
            >>> from mcpgateway.schemas import TagStats
            >>> service = TagService()
            >>> stats = TagStats(tools=0, resources=0, prompts=0, servers=0, gateways=0, total=0)
            >>>
            >>> # Test updating tool stats
            >>> service._update_stats(stats, "tools")
            >>> stats.tools
            1
            >>> stats.total
            1
            >>>
            >>> # Test updating resource stats
            >>> service._update_stats(stats, "resources")
            >>> stats.resources
            1
            >>> stats.total
            2
            >>>
            >>> # Test with invalid entity type (should not crash)
            >>> service._update_stats(stats, "invalid")
            >>> stats.total  # Should remain 2
            2
        """
        # Referenced via the class (not self) so the helper also works when called unbound.
        if entity_type not in TagService._ENTITY_TYPES:
            return
        # Entity-type names are the TagStats counter field names.
        setattr(stats, entity_type, getattr(stats, entity_type) + 1)
        stats.total += 1

    def _get_tag_id(self, tag) -> str:
        """Return the tag id for a tag entry which may be a string or a dict.

        Supports legacy string tags and new dict tags with an 'id' field.
        Falls back to 'label' or the string representation when 'id' is missing.
        The result is always a ``str``, even when the stored id is not.

        Args:
            tag: Tag value which may be a string (legacy) or a dict with an
                'id' or 'label' key.

        Returns:
            The normalized tag id as a string.

        Example:
            >>> service = TagService()
            >>> service._get_tag_id("api")
            'api'
            >>> service._get_tag_id({"id": "db", "label": "Database"})
            'db'
            >>> service._get_tag_id({"label": "Web"})
            'Web'
            >>> service._get_tag_id({"id": 7})
            '7'
        """
        if isinstance(tag, str):
            return tag
        if isinstance(tag, dict):
            identifier = tag.get("id") or tag.get("label")
            if identifier:
                return str(identifier)
        return str(tag)

    @staticmethod
    def _get_entity_map() -> Dict[str, type]:
        """Map each supported entity-type name to its ORM model.

        Built on every call rather than stored as a class attribute, so the module-level
        model names are resolved at call time (e.g. when they are patched in tests).

        Returns:
            Ordered mapping of entity-type name to model class.
        """
        return {
            "tools": DbTool,
            "resources": DbResource,
            "prompts": DbPrompt,
            "servers": DbServer,
            "gateways": DbGateway,
        }

    @staticmethod
    def _resolve_entity_types(entity_types: Optional[List[str]], entity_map: Dict[str, type]) -> List[str]:
        """Normalise a requested entity-type selection.

        Args:
            entity_types: Requested entity-type names, or None for all of them.
            entity_map: Mapping of supported entity-type names to models.

        Returns:
            Supported entity types in request order, with duplicates and unknown names
            removed. None selects every type in ``entity_map`` order.

        Example:
            >>> TagService._resolve_entity_types(["tools", "bogus", "tools"], {"tools": object, "prompts": object})
            ['tools']
            >>> TagService._resolve_entity_types(None, {"tools": object, "prompts": object})
            ['tools', 'prompts']
            >>> TagService._resolve_entity_types([], {"tools": object})
            []
        """
        if entity_types is None:
            return list(entity_map)
        # dict.fromkeys de-duplicates while keeping the caller's order.
        return [entity_type for entity_type in dict.fromkeys(entity_types) if entity_type in entity_map]

    @staticmethod
    def _get_or_create_tag_entry(tag_data: Dict[str, Dict], tag: str) -> Dict:
        """Return the accumulator for ``tag``, creating a zeroed one on first sight.

        Args:
            tag_data: Mapping of tag id to ``{"stats": TagStats, "entities": [TaggedEntity, ...]}``.
            tag: Normalised tag id.

        Returns:
            The (possibly newly created) accumulator dict for ``tag``.

        Example:
            >>> data = {}
            >>> entry = TagService._get_or_create_tag_entry(data, "api")
            >>> entry is TagService._get_or_create_tag_entry(data, "api"), entry["stats"].total
            (True, 0)
        """
        entry = tag_data.get(tag)
        if entry is None:
            entry = {"stats": TagStats(tools=0, resources=0, prompts=0, servers=0, gateways=0, total=0), "entities": []}
            tag_data[tag] = entry
        return entry

    @staticmethod
    def _build_tagged_entity(entity, entity_type: str) -> TaggedEntity:
        """Describe an ORM entity as a TaggedEntity.

        The id falls back from ``id`` to a resource's ``uri`` to ``name`` (or ``"unknown"``).
        The display name falls back from ``name`` to ``original_name`` to ``uri`` to the id.

        Args:
            entity: ORM row (or any object exposing the same attributes).
            entity_type: Plural entity-type name, e.g. ``"tools"``.

        Returns:
            TaggedEntity whose ``type`` is the singular form of ``entity_type``.

        Example:
            >>> from types import SimpleNamespace
            >>> info = TagService._build_tagged_entity(SimpleNamespace(id=1, name="n", description="d"), "tools")
            >>> (info.id, info.name, info.type, info.description)
            ('1', 'n', 'tool', 'd')
        """
        raw_id = getattr(entity, "id", None)
        raw_name = getattr(entity, "name", None)

        # Determine the ID
        if raw_id is not None:
            entity_id = str(raw_id)
        elif entity_type == "resources" and hasattr(entity, "uri"):
            entity_id = str(entity.uri)
        else:
            entity_id = str(raw_name if raw_name else "unknown")

        # Determine the name
        if raw_name:
            entity_name = raw_name
        elif getattr(entity, "original_name", None):
            entity_name = entity.original_name
        elif hasattr(entity, "uri"):
            entity_name = str(entity.uri)
        else:
            entity_name = entity_id

        return TaggedEntity(
            id=entity_id,
            name=entity_name,
            type=entity_type[:-1],  # Remove plural 's'
            description=getattr(entity, "description", None),
        )

    async def get_entities_by_tag(self, db: Session, tag_name: str, entity_types: Optional[List[str]] = None) -> List[TaggedEntity]:
        """Get all entities that have a specific tag.

        This method searches across specified entity types to find all entities
        that contain the given tag. It returns simplified entity representations
        optimized for tag-based discovery and filtering.

        Args:
            db: Database session for querying entity data
            tag_name: The exact tag to search for (case sensitive)
            entity_types: Optional list of entity types to search within.
                Valid types: ['tools', 'resources', 'prompts', 'servers', 'gateways']
                If None, searches all entity types. Unknown names and duplicates are ignored.

        Returns:
            List of TaggedEntity objects containing basic entity information.
            Each TaggedEntity includes: id, name, type, and description.
            Results are not sorted and may contain entities from different types.

        Example:
            >>> import asyncio
            >>> from unittest.mock import MagicMock
            >>>
            >>> # Setup service and mock database
            >>> service = TagService()
            >>> mock_db = MagicMock()
            >>> mock_db.get_bind.return_value.dialect.name = "sqlite"
            >>>
            >>> # Mock entity with tag
            >>> mock_entity = MagicMock()
            >>> mock_entity.id = "test-123"
            >>> mock_entity.name = "Test Entity"
            >>> mock_entity.description = "A test entity"
            >>> mock_entity.tags = ["api", "test", "database"]
            >>>
            >>> # Mock database result
            >>> mock_result = MagicMock()
            >>> mock_result.scalars.return_value = [mock_entity]
            >>> mock_db.execute.return_value = mock_result
            >>>
            >>> # Test entity lookup by tag
            >>> async def test_entity_lookup():
            ...     entities = await service.get_entities_by_tag(mock_db, "api", ["tools"])
            ...     return len(entities)
            >>> asyncio.run(test_entity_lookup())
            1

            >>> # Test with non-existent tag
            >>> mock_entity.tags = ["different", "tags"]
            >>> async def test_no_match():
            ...     entities = await service.get_entities_by_tag(mock_db, "api", ["tools"])
            ...     return len(entities)
            >>> asyncio.run(test_no_match())
            0

            >>> # Duplicate and unknown entity types do not produce duplicate results
            >>> mock_entity.tags = ["api"]
            >>> len(asyncio.run(service.get_entities_by_tag(mock_db, "api", ["tools", "tools", "unknown"])))
            1

        Note:
            - Tag matching is exact and case-sensitive
            - Entities without the specified tag are filtered out after database query
            - Performance scales with the number of entities in filtered types
            - Uses json_contains_tag_expr for cross-database filtering (PostgreSQL/SQLite)
        """
        entity_map = self._get_entity_map()
        entities: List[TaggedEntity] = []

        for entity_type in self._resolve_entity_types(entity_types, entity_map):
            model = entity_map[entity_type]

            # Query entities that have this tag
            # Using json_contains_tag_expr for cross-database compatibility (PostgreSQL/SQLite)
            stmt = select(model).where(json_contains_tag_expr(db, model.tags, [tag_name], match_any=True))
            result = db.execute(stmt)

            for entity in result.scalars():
                # Re-check in Python against normalised tag ids (string or dict tags) so only
                # exact, case-sensitive matches are returned.
                if any(self._get_tag_id(raw_tag) == tag_name for raw_tag in entity.tags or []):
                    entities.append(self._build_tagged_entity(entity, entity_type))

        return entities

    async def get_tag_counts(self, db: Session) -> Dict[str, int]:
        """Get the total number of tag instances per entity type.

        This method calculates the total number of tag instances (not unique tag names)
        across all entity types. Useful for analytics and capacity planning.

        Args:
            db: Database session for querying tag data

        Returns:
            Dictionary mapping entity type names to total tag counts.
            Keys: 'tools', 'resources', 'prompts', 'servers', 'gateways'
            Values: Integer counts of total tag instances in each type

        Example:
            >>> import asyncio
            >>> from unittest.mock import MagicMock
            >>>
            >>> # Setup service and mock database
            >>> service = TagService()
            >>> mock_db = MagicMock()
            >>>
            >>> # Execute method with mocked responses (same values reused for every type):
            >>> # 3 entities with 2, 1 and 3 tags each
            >>> class _Res:
            ...     def scalars(self):
            ...         class _S:
            ...             def all(self_inner):
            ...                 return [2, 1, 3]
            ...         return _S()
            >>> mock_db.execute.return_value = _Res()
            >>> counts = asyncio.run(service.get_tag_counts(mock_db))
            >>> counts['tools']
            6
            >>> all(isinstance(v, int) for v in counts.values())
            True
            >>> len(counts)
            5

        Note:
            - Counts tag instances, not unique tag names
            - An entity with 3 tags contributes 3 to the count
            - Empty or null tag arrays contribute 0 to the count
            - Uses json_array_length() for efficient counting
        """
        counts: Dict[str, int] = {}

        for entity_type, model in self._get_entity_map().items():
            stmt = select(func.json_array_length(model.tags)).where(model.tags.isnot(None))
            lengths = db.execute(stmt).scalars().all()
            # A NULL length is treated as 0 so sum() never sees None.
            counts[entity_type] = sum(length or 0 for length in lengths)

        return counts