Coverage for mcpgateway / services / tag_service.py: 100%

168 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/services/tag_service.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7Tag Service Implementation. 

8This module implements tag management and retrieval for all entities in ContextForge. 

9It handles: 

10- Fetching all unique tags across entities 

11- Filtering tags by entity type 

12- Tag statistics and counts 

13- Retrieving entities that have specific tags 

14""" 

15 

16# Standard 

17import logging 

18from typing import Dict, List, Optional 

19 

20# Third-Party 

21from sqlalchemy import and_, func, or_, select 

22from sqlalchemy.orm import Session 

23 

24# First-Party 

25from mcpgateway.db import Gateway as DbGateway 

26from mcpgateway.db import Prompt as DbPrompt 

27from mcpgateway.db import Resource as DbResource 

28from mcpgateway.db import Server as DbServer 

29from mcpgateway.db import Tool as DbTool 

30from mcpgateway.schemas import TaggedEntity, TagInfo, TagStats 

31from mcpgateway.utils.sqlalchemy_modifier import json_contains_tag_expr 

32 

logger = logging.getLogger(__name__)

# The admin stats cache singleton is imported on first use rather than at
# module import time, which avoids a circular dependency between modules.
_ADMIN_STATS_CACHE = None


def _get_admin_stats_cache():
    """Return the shared admin stats cache, importing it on first access.

    The singleton is memoized in the module-level ``_ADMIN_STATS_CACHE`` so the
    deferred import only runs once.

    Returns:
        AdminStatsCache instance.
    """
    global _ADMIN_STATS_CACHE  # pylint: disable=global-statement
    if _ADMIN_STATS_CACHE is not None:
        return _ADMIN_STATS_CACHE

    # First-Party
    from mcpgateway.cache.admin_stats_cache import admin_stats_cache as cache_singleton  # pylint: disable=import-outside-toplevel

    _ADMIN_STATS_CACHE = cache_singleton
    return _ADMIN_STATS_CACHE

52 

53 

class TagService:
    """Service for managing and retrieving tags across all entities.

    This service provides comprehensive tag management functionality across all ContextForge
    entity types (tools, resources, prompts, servers, gateways). It handles tag discovery,
    entity lookup by tags, and statistics aggregation.

    Example:
        >>> from unittest.mock import MagicMock
        >>> from mcpgateway.schemas import TagInfo, TagStats, TaggedEntity
        >>>
        >>> # Create service instance
        >>> service = TagService()
        >>>
        >>> # Mock database session
        >>> mock_db = MagicMock()
        >>>
        >>> # Test basic functionality
        >>> isinstance(service, TagService)
        True
    """

    # Entity type names accepted by the public methods. Each name is also the
    # name of the matching counter attribute on TagStats.
    _ENTITY_TYPES = ("tools", "resources", "prompts", "servers", "gateways")

    @staticmethod
    def _entity_models() -> Dict[str, type]:
        """Map each supported entity type name to its ORM model.

        Built on demand (rather than as a class attribute) so that defining the
        class does not require the ORM models to be resolved.

        Returns:
            Dictionary keyed by entity type name, in the same order as ``_ENTITY_TYPES``.
        """
        return {
            "tools": DbTool,
            "resources": DbResource,
            "prompts": DbPrompt,
            "servers": DbServer,
            "gateways": DbGateway,
        }

    def _normalize_entity_types(self, entity_types: Optional[List[str]]) -> List[str]:
        """Reduce a caller-supplied entity type filter to valid, unique names.

        Args:
            entity_types: Requested entity types, or None to select every supported type.

        Returns:
            All supported types when ``entity_types`` is None; otherwise the supported
            names from ``entity_types`` in first-seen order with unknown names and
            duplicates dropped (possibly an empty list).
        """
        if entity_types is None:
            return list(self._ENTITY_TYPES)

        selected: List[str] = []
        for entity_type in entity_types:
            # Unknown types are ignored; duplicates would double-count statistics.
            if entity_type in self._ENTITY_TYPES and entity_type not in selected:
                selected.append(entity_type)
        return selected

    @staticmethod
    def _tag_bucket(tag_data: Dict[str, Dict], tag: str) -> Dict:
        """Return the accumulator for ``tag``, creating a zeroed one on first use.

        Args:
            tag_data: Mapping of tag id to its accumulator, mutated in place.
            tag: Normalized tag id.

        Returns:
            Dictionary with a ``stats`` TagStats object and an ``entities`` list.
        """
        if tag not in tag_data:
            tag_data[tag] = {"stats": TagStats(tools=0, resources=0, prompts=0, servers=0, gateways=0, total=0), "entities": []}
        return tag_data[tag]

    def _build_tagged_entity(self, entity, entity_type: str) -> TaggedEntity:
        """Convert an ORM entity into its simplified TaggedEntity representation.

        Args:
            entity: ORM row (or any object exposing the same attributes).
            entity_type: Plural entity type name, e.g. ``"tools"``.

        Returns:
            TaggedEntity with resolved id, display name, singular type and description.
        """
        raw_name = getattr(entity, "name", None)
        raw_id = getattr(entity, "id", None)

        # ID preference: primary key, then the URI for resources, then the name.
        if raw_id is not None:
            entity_id = str(raw_id)
        elif entity_type == "resources" and hasattr(entity, "uri"):
            entity_id = str(entity.uri)
        else:
            entity_id = str(raw_name or "unknown")

        # Name preference: name, original_name, uri, and finally the resolved ID.
        if raw_name:
            entity_name = raw_name
        elif getattr(entity, "original_name", None):
            entity_name = entity.original_name
        elif hasattr(entity, "uri"):
            entity_name = str(entity.uri)
        else:
            entity_name = entity_id

        return TaggedEntity(
            id=entity_id,
            name=entity_name,
            type=entity_type[:-1],  # Remove plural 's'
            description=getattr(entity, "description", None),
        )

    async def get_all_tags(
        self,
        db: Session,
        entity_types: Optional[List[str]] = None,
        include_entities: bool = False,
        user_email: Optional[str] = None,
        token_teams: Optional[List[str]] = None,
    ) -> List[TagInfo]:
        """Retrieve all unique tags across specified entity types.

        This method aggregates tags from multiple entity types and returns comprehensive
        statistics about tag usage. It can optionally include detailed information about
        which entities have each tag.

        Args:
            db: Database session for querying entity data
            entity_types: List of entity types to filter by. Valid types are:
                ['tools', 'resources', 'prompts', 'servers', 'gateways'].
                If None, returns tags from all entity types. Unknown and duplicate
                names are ignored; if no valid type remains the result is empty.
            include_entities: Whether to include the list of entities that have each tag.
                If False, only statistics are returned for better performance.
            user_email: Caller email used for owner/team visibility checks
            token_teams: Normalized token teams (`None` admin bypass, `[]` public-only, list for team scope)

        Returns:
            List of TagInfo objects containing tag details, sorted alphabetically by tag name.
            Each TagInfo includes:
                - name: The tag name
                - stats: Usage statistics per entity type
                - entities: List of entities with this tag (if include_entities=True)

        Example:
            >>> import asyncio
            >>> from unittest.mock import MagicMock, AsyncMock
            >>>
            >>> # Create service and mock database
            >>> service = TagService()
            >>> mock_db = MagicMock()
            >>>
            >>> # Mock empty result
            >>> mock_db.execute.return_value.__iter__ = lambda self: iter([])
            >>>
            >>> # Test with empty database
            >>> async def test_empty():
            ...     tags = await service.get_all_tags(mock_db)
            ...     return len(tags)
            >>> asyncio.run(test_empty())
            0

            >>> # Mock result with tag data
            >>> mock_result = MagicMock()
            >>> mock_result.__iter__ = lambda self: iter([
            ...     (["api", "database"],),
            ...     (["api", "web"],),
            ... ])
            >>> mock_db.execute.return_value = mock_result
            >>>
            >>> # Test with tag data
            >>> async def test_with_tags():
            ...     tags = await service.get_all_tags(mock_db, entity_types=["tools"])
            ...     return len(tags) >= 2  # Should have at least api, database, web tags
            >>> asyncio.run(test_with_tags())
            True

            >>> # include_entities=True path
            >>> from types import SimpleNamespace
            >>> entity = SimpleNamespace(id='1', name='E', description='d', tags=['api'])
            >>> mock_result2 = MagicMock()
            >>> mock_result2.scalars.return_value = [entity]
            >>> mock_db.execute.return_value = mock_result2
            >>> async def test_with_entities():
            ...     tags = await service.get_all_tags(mock_db, entity_types=["tools"], include_entities=True)
            ...     return len(tags) == 1 and tags[0].entities[0].name == 'E'
            >>> asyncio.run(test_with_entities())
            True

        Raises:
            SQLAlchemyError: If database query fails
            ValidationError: If collected tag data fails TagInfo/TaggedEntity schema validation
        """
        requested_types = self._normalize_entity_types(entity_types)
        if not requested_types:
            # Nothing valid was requested. Return before touching the cache so an
            # explicit empty filter can never be stored under (or served from)
            # the "all" cache entry.
            return []

        # Generate cache key from parameters. Only a missing filter (None) maps to "all".
        entity_types_key = "all" if entity_types is None else ":".join(sorted(requested_types))
        cache_key = f"{entity_types_key}:{include_entities}"

        # SECURITY: Only use cache for unrestricted (admin-bypass) queries.
        # Scoped queries (public-only/team/user) are user/token specific and must not
        # reuse global cached results.
        is_scoped_query = user_email is not None or token_teams is not None
        use_cache = not include_entities and not is_scoped_query

        if use_cache:
            cache = _get_admin_stats_cache()
            cached = await cache.get_tags(cache_key)
            if cached is not None:
                # Reconstruct TagInfo objects from cached dicts
                return [TagInfo.model_validate(t) for t in cached]

        tag_data: Dict[str, Dict] = {}
        entity_map = self._entity_models()
        team_ids = await self._resolve_team_ids(db, user_email, token_teams)

        # Collect tags from each requested entity type
        for entity_type in requested_types:
            model = entity_map[entity_type]

            if include_entities:
                # Load full rows so each tag can list the entities carrying it
                stmt = select(model).where(model.tags.isnot(None))
                stmt = self._apply_visibility_scope(stmt, model, user_email=user_email, token_teams=token_teams, team_ids=team_ids)

                for entity in db.execute(stmt).scalars():
                    entity_tags = entity.tags or []
                    if not entity_tags:
                        continue

                    entity_info = self._build_tagged_entity(entity, entity_type)
                    for raw_tag in entity_tags:
                        bucket = self._tag_bucket(tag_data, self._get_tag_id(raw_tag))
                        bucket["entities"].append(entity_info)
                        self._update_stats(bucket["stats"], entity_type)
            else:
                # Only the tags column is needed for statistics
                stmt = select(model.tags).where(model.tags.isnot(None))
                stmt = self._apply_visibility_scope(stmt, model, user_email=user_email, token_teams=token_teams, team_ids=team_ids)

                for row in db.execute(stmt):
                    for raw_tag in row[0] or []:
                        bucket = self._tag_bucket(tag_data, self._get_tag_id(raw_tag))
                        self._update_stats(bucket["stats"], entity_type)

        # Convert to TagInfo list, ordered by tag name
        tags = [
            TagInfo(
                name=tag,
                stats=tag_data[tag]["stats"],
                entities=tag_data[tag]["entities"] if include_entities else [],
            )
            for tag in sorted(tag_data)
        ]

        # Store in cache (only for non-entity + unrestricted queries)
        if use_cache:
            cache = _get_admin_stats_cache()
            await cache.set_tags([t.model_dump() for t in tags], cache_key)

        return tags

    def _update_stats(self, stats: TagStats, entity_type: str) -> None:
        """Update statistics for a specific entity type.

        This helper method increments the appropriate counter in the TagStats object
        based on the entity type and maintains the total count.

        Args:
            stats: TagStats object to update with new counts
            entity_type: Type of entity to increment count for. Must be one of:
                'tools', 'resources', 'prompts', 'servers', 'gateways'

        Example:
            >>> from mcpgateway.schemas import TagStats
            >>> service = TagService()
            >>> stats = TagStats(tools=0, resources=0, prompts=0, servers=0, gateways=0, total=0)
            >>>
            >>> # Test updating tool stats
            >>> service._update_stats(stats, "tools")
            >>> stats.tools
            1
            >>> stats.total
            1
            >>>
            >>> # Test updating resource stats
            >>> service._update_stats(stats, "resources")
            >>> stats.resources
            1
            >>> stats.total
            2
            >>>
            >>> # Test with invalid entity type (should not crash)
            >>> service._update_stats(stats, "invalid")
            >>> stats.total  # Should remain 2
            2
        """
        if entity_type not in self._ENTITY_TYPES:
            # Invalid entity types are ignored (no increment)
            return

        # The per-type counter attribute shares its name with the entity type.
        setattr(stats, entity_type, getattr(stats, entity_type) + 1)
        stats.total += 1

    def _get_tag_id(self, tag: object) -> str:
        """Return the tag id for a tag entry which may be a string or a dict.

        Supports legacy string tags and new dict tags with an 'id' field.
        Falls back to 'label' or the string representation when 'id' is missing.

        Args:
            tag: Tag value which may be a string (legacy) or a dict with an
                'id' or 'label' key.

        Returns:
            The normalized tag id, always as a string (non-string 'id'/'label'
            values are converted with ``str``).
        """
        if isinstance(tag, str):
            return tag
        if isinstance(tag, dict):
            # Coerce to str so callers can rely on string keys for sorting and schemas.
            return str(tag.get("id") or tag.get("label") or tag)
        return str(tag)

    async def _resolve_team_ids(self, db: Session, user_email: Optional[str], token_teams: Optional[List[str]]) -> List[str]:
        """Resolve effective team IDs for scoped visibility checks.

        Args:
            db: Database session
            user_email: Caller email for DB-based team lookup when token teams are not explicit
            token_teams: Explicit token team scope when present

        Returns:
            Effective team IDs used to build visibility filters.
        """
        # An explicit token scope (including the empty list) always wins.
        if token_teams is not None:
            return token_teams
        if not user_email:
            return []

        # First-Party
        from mcpgateway.services.team_management_service import TeamManagementService  # pylint: disable=import-outside-toplevel

        team_service = TeamManagementService(db)
        user_teams = await team_service.get_user_teams(user_email)
        return [team.id for team in user_teams]

    def _apply_visibility_scope(self, stmt, model, user_email: Optional[str], token_teams: Optional[List[str]], team_ids: List[str]):
        """Apply token/user visibility scope to a SQLAlchemy statement.

        Semantics mirror list/read endpoints:
        - token_teams is None and user_email is None -> unrestricted (admin bypass)
        - token_teams == [] -> public-only
        - token_teams == [...] -> public + matching-team (+ owner if user_email present)
        - token_teams is None and user_email present -> use DB team memberships

        Args:
            stmt: SQLAlchemy statement to constrain
            model: ORM model that includes visibility/team/owner columns
            user_email: Caller email used for owner visibility
            token_teams: Explicit token team scope when present
            team_ids: Effective team IDs for team visibility

        Returns:
            Scoped SQLAlchemy statement.
        """
        if token_teams is None and user_email is None:
            return stmt

        is_public_only_token = token_teams is not None and len(token_teams) == 0

        # Public rows are always visible to scoped callers.
        access_conditions = [model.visibility == "public"]

        if not is_public_only_token and user_email:
            access_conditions.append(model.owner_email == user_email)

        if team_ids:
            access_conditions.append(and_(model.team_id.in_(team_ids), model.visibility.in_(["team", "public"])))

        return stmt.where(or_(*access_conditions))

    async def get_entities_by_tag(
        self,
        db: Session,
        tag_name: str,
        entity_types: Optional[List[str]] = None,
        user_email: Optional[str] = None,
        token_teams: Optional[List[str]] = None,
    ) -> List[TaggedEntity]:
        """Get all entities that have a specific tag.

        This method searches across specified entity types to find all entities
        that contain the given tag. It returns simplified entity representations
        optimized for tag-based discovery and filtering.

        Args:
            db: Database session for querying entity data
            tag_name: The exact tag to search for (case sensitive)
            entity_types: Optional list of entity types to search within.
                Valid types: ['tools', 'resources', 'prompts', 'servers', 'gateways']
                If None, searches all entity types. Unknown and duplicate names are ignored.
            user_email: Caller email used for owner/team visibility checks
            token_teams: Normalized token teams (`None` admin bypass, `[]` public-only, list for team scope)

        Returns:
            List of TaggedEntity objects containing basic entity information.
            Each TaggedEntity includes: id, name, type, and description.
            Results are not sorted and may contain entities from different types.

        Example:
            >>> import asyncio
            >>> from unittest.mock import MagicMock
            >>>
            >>> # Setup service and mock database
            >>> service = TagService()
            >>> mock_db = MagicMock()
            >>> mock_db.get_bind.return_value.dialect.name = "sqlite"
            >>>
            >>> # Mock entity with tag
            >>> mock_entity = MagicMock()
            >>> mock_entity.id = "test-123"
            >>> mock_entity.name = "Test Entity"
            >>> mock_entity.description = "A test entity"
            >>> mock_entity.tags = ["api", "test", "database"]
            >>>
            >>> # Mock database result
            >>> mock_result = MagicMock()
            >>> mock_result.scalars.return_value = [mock_entity]
            >>> mock_db.execute.return_value = mock_result
            >>>
            >>> # Test entity lookup by tag
            >>> async def test_entity_lookup():
            ...     entities = await service.get_entities_by_tag(mock_db, "api", ["tools"])
            ...     return len(entities)
            >>> asyncio.run(test_entity_lookup())
            1

            >>> # Test with non-existent tag
            >>> mock_entity.tags = ["different", "tags"]
            >>> async def test_no_match():
            ...     entities = await service.get_entities_by_tag(mock_db, "api", ["tools"])
            ...     return len(entities)
            >>> asyncio.run(test_no_match())
            0

        Note:
            - Tag matching is exact and case-sensitive
            - Entities without the specified tag are filtered out after database query
            - Performance scales with the number of entities in filtered types
            - Uses json_contains_tag_expr for cross-database filtering (PostgreSQL/SQLite)
        """
        entities: List[TaggedEntity] = []

        requested_types = self._normalize_entity_types(entity_types)
        if not requested_types:
            # No valid type to search; skip the team lookup and queries entirely.
            return entities

        entity_map = self._entity_models()
        team_ids = await self._resolve_team_ids(db, user_email, token_teams)

        for entity_type in requested_types:
            model = entity_map[entity_type]

            # Query entities that have this tag
            # Using json_contains_tag_expr for cross-database compatibility (PostgreSQL/SQLite)
            stmt = select(model).where(json_contains_tag_expr(db, model.tags, [tag_name], match_any=True))
            stmt = self._apply_visibility_scope(stmt, model, user_email=user_email, token_teams=token_teams, team_ids=team_ids)

            for entity in db.execute(stmt).scalars():
                # Re-check in Python against normalized tag ids (string or dict tags).
                entity_tag_ids = [self._get_tag_id(t) for t in entity.tags or []]
                if tag_name in entity_tag_ids:
                    entities.append(self._build_tagged_entity(entity, entity_type))

        return entities

    async def get_tag_counts(self, db: Session) -> Dict[str, int]:
        """Get the total number of tag instances per entity type.

        This method calculates the total number of tag instances (not unique tag names)
        for each entity type. Useful for analytics and capacity planning.

        Args:
            db: Database session for querying tag data

        Returns:
            Dictionary mapping entity type names to total tag counts.
            Keys: 'tools', 'resources', 'prompts', 'servers', 'gateways'
            Values: Integer counts of total tag instances in each type

        Example:
            >>> import asyncio
            >>> from unittest.mock import MagicMock
            >>>
            >>> # Setup service and mock database
            >>> service = TagService()
            >>> mock_db = MagicMock()
            >>>
            >>> # Mock tag count results
            >>> mock_db.execute.return_value.scalars.return_value.all.return_value = [2, 1, 3]  # 3 entities with 2, 1, 3 tags each
            >>>
            >>> # Execute method with mocked responses (same values reused for simplicity)
            >>> class _Res:
            ...     def scalars(self):
            ...         class _S:
            ...             def all(self_inner):
            ...                 return [2, 1, 3]
            ...         return _S()
            >>> mock_db.execute.return_value = _Res()
            >>> counts = asyncio.run(service.get_tag_counts(mock_db))
            >>> counts['tools']
            6
            >>> all(isinstance(v, int) for v in counts.values())
            True
            >>> len(counts)
            5

        Note:
            - Counts tag instances, not unique tag names
            - An entity with 3 tags contributes 3 to the count
            - Rows whose tags column is NULL are excluded from the query
            - Uses json_array_length() for efficient counting
        """
        counts: Dict[str, int] = {}

        for entity_type, model in self._entity_models().items():
            # One length per row with non-NULL tags; their sum is the instance count.
            lengths_stmt = select(func.json_array_length(model.tags)).where(model.tags.isnot(None))
            tag_lengths = db.execute(lengths_stmt).scalars().all()
            counts[entity_type] = sum(tag_lengths)

        return counts