Coverage for mcpgateway / services / tag_service.py: 100%
168 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/tag_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Tag Service Implementation.
8This module implements tag management and retrieval for all entities in ContextForge.
9It handles:
10- Fetching all unique tags across entities
11- Filtering tags by entity type
12- Tag statistics and counts
13- Retrieving entities that have specific tags
14"""
16# Standard
17import logging
18from typing import Dict, List, Optional
20# Third-Party
21from sqlalchemy import and_, func, or_, select
22from sqlalchemy.orm import Session
24# First-Party
25from mcpgateway.db import Gateway as DbGateway
26from mcpgateway.db import Prompt as DbPrompt
27from mcpgateway.db import Resource as DbResource
28from mcpgateway.db import Server as DbServer
29from mcpgateway.db import Tool as DbTool
30from mcpgateway.schemas import TaggedEntity, TagInfo, TagStats
31from mcpgateway.utils.sqlalchemy_modifier import json_contains_tag_expr
logger = logging.getLogger(__name__)

# Handle to the admin stats cache. It stays None until the first call to
# _get_admin_stats_cache(); the import is deferred there to avoid a circular
# dependency at module load time.
_ADMIN_STATS_CACHE = None


def _get_admin_stats_cache():
    """Return the shared admin stats cache, importing it on first use.

    The first call imports ``admin_stats_cache`` and stores it in the module-level
    ``_ADMIN_STATS_CACHE``; subsequent calls return the stored object directly.

    Returns:
        AdminStatsCache instance.
    """
    global _ADMIN_STATS_CACHE  # pylint: disable=global-statement
    if _ADMIN_STATS_CACHE is not None:
        return _ADMIN_STATS_CACHE

    # First-Party
    from mcpgateway.cache.admin_stats_cache import admin_stats_cache  # pylint: disable=import-outside-toplevel

    _ADMIN_STATS_CACHE = admin_stats_cache
    return _ADMIN_STATS_CACHE
class TagService:
    """Service for managing and retrieving tags across all entities.

    This service provides comprehensive tag management functionality across all ContextForge
    entity types (tools, resources, prompts, servers, gateways). It handles tag discovery,
    entity lookup by tags, and statistics aggregation.

    Example:
        >>> from unittest.mock import MagicMock
        >>> from mcpgateway.schemas import TagInfo, TagStats, TaggedEntity
        >>>
        >>> # Create service instance
        >>> service = TagService()
        >>>
        >>> # Mock database session
        >>> mock_db = MagicMock()
        >>>
        >>> # Test basic functionality
        >>> isinstance(service, TagService)
        True
    """

    # Plural entity-type keys; these also name the per-type counters on TagStats.
    _ENTITY_TYPES = ("tools", "resources", "prompts", "servers", "gateways")

    @staticmethod
    def _entity_map() -> Dict[str, type]:
        """Return the mapping from plural entity-type names to ORM models.

        The mapping is built per call so the models are resolved from the module
        globals at call time rather than captured at class-definition time.

        Returns:
            Insertion-ordered dict of entity type name -> ORM model class.
        """
        return {
            "tools": DbTool,
            "resources": DbResource,
            "prompts": DbPrompt,
            "servers": DbServer,
            "gateways": DbGateway,
        }

    @staticmethod
    def _normalize_entity_types(entity_types: Optional[List[str]], entity_map: Dict[str, type]) -> List[str]:
        """Resolve a requested entity-type filter into an ordered, de-duplicated list.

        Args:
            entity_types: Requested type names, or None for every known type.
            entity_map: Mapping of known type names to ORM models.

        Returns:
            Known type names in first-occurrence order. Unknown names and
            duplicates are dropped, so no type is queried (or counted) twice.
        """
        if entity_types is None:
            return list(entity_map)
        return [entity_type for entity_type in dict.fromkeys(entity_types) if entity_type in entity_map]

    @staticmethod
    def _get_or_create_bucket(tag_data: Dict[str, Dict], tag: str) -> Dict:
        """Return the accumulator for ``tag``, creating a zeroed one if missing.

        Args:
            tag_data: Accumulator dict keyed by tag id.
            tag: Normalized tag id.

        Returns:
            Dict with a ``"stats"`` TagStats and an ``"entities"`` list.
        """
        bucket = tag_data.get(tag)
        if bucket is None:
            bucket = {"stats": TagStats(tools=0, resources=0, prompts=0, servers=0, gateways=0, total=0), "entities": []}
            tag_data[tag] = bucket
        return bucket

    @staticmethod
    def _build_tagged_entity(entity, entity_type: str) -> TaggedEntity:
        """Build the TaggedEntity summary for an ORM row (or any row-like object).

        Args:
            entity: Object exposing some of ``id``, ``uri``, ``name``,
                ``original_name`` and ``description``.
            entity_type: Plural entity type key, e.g. ``"tools"``.

        Returns:
            TaggedEntity with a best-effort id and display name.
        """
        # ID preference: primary key, then (resources only) URI, then name, then "unknown".
        raw_id = getattr(entity, "id", None)
        if raw_id is not None:
            entity_id = str(raw_id)
        elif entity_type == "resources" and hasattr(entity, "uri"):
            entity_id = str(entity.uri)
        else:
            entity_id = str(getattr(entity, "name", None) or "unknown")

        # Display-name preference: name, original_name, URI, then the id chosen above.
        if getattr(entity, "name", None):
            entity_name = entity.name
        elif getattr(entity, "original_name", None):
            entity_name = entity.original_name
        elif hasattr(entity, "uri"):
            entity_name = str(entity.uri)
        else:
            entity_name = entity_id

        return TaggedEntity(
            id=entity_id,
            name=entity_name,
            type=entity_type[:-1],  # Remove plural 's'
            description=getattr(entity, "description", None),
        )

    def _unique_tag_ids(self, raw_tags) -> List[str]:
        """Normalize a raw tag list to tag ids, dropping repeats but keeping order.

        Args:
            raw_tags: The entity's stored tags (strings and/or dicts), or a falsy value.

        Returns:
            Distinct normalized tag ids. A repeated tag on one entity must not
            count that entity twice.
        """
        return list(dict.fromkeys(self._get_tag_id(tag) for tag in raw_tags or []))

    async def get_all_tags(
        self,
        db: Session,
        entity_types: Optional[List[str]] = None,
        include_entities: bool = False,
        user_email: Optional[str] = None,
        token_teams: Optional[List[str]] = None,
    ) -> List[TagInfo]:
        """Retrieve all unique tags across specified entity types.

        This method aggregates tags from multiple entity types and returns comprehensive
        statistics about tag usage. It can optionally include detailed information about
        which entities have each tag.

        Args:
            db: Database session for querying entity data
            entity_types: List of entity types to filter by. Valid types are:
                ['tools', 'resources', 'prompts', 'servers', 'gateways'].
                If None, returns tags from all entity types. Unknown and duplicate
                names are ignored; a list with no valid types yields an empty result.
            include_entities: Whether to include the list of entities that have each tag.
                If False, only statistics are returned for better performance.
            user_email: Caller email used for owner/team visibility checks
            token_teams: Normalized token teams (`None` admin bypass, `[]` public-only, list for team scope)

        Returns:
            List of TagInfo objects containing tag details, sorted alphabetically by tag name.
            Each TagInfo includes:
            - name: The tag name
            - stats: Usage statistics per entity type (each entity counted once per tag)
            - entities: List of entities with this tag (if include_entities=True)

        Example:
            >>> import asyncio
            >>> from unittest.mock import MagicMock, AsyncMock
            >>>
            >>> # Create service and mock database
            >>> service = TagService()
            >>> mock_db = MagicMock()
            >>>
            >>> # Mock empty result
            >>> mock_db.execute.return_value.__iter__ = lambda self: iter([])
            >>>
            >>> # Test with empty database
            >>> async def test_empty():
            ...     tags = await service.get_all_tags(mock_db)
            ...     return len(tags)
            >>> asyncio.run(test_empty())
            0

            >>> # Mock result with tag data
            >>> mock_result = MagicMock()
            >>> mock_result.__iter__ = lambda self: iter([
            ...     (["api", "database"],),
            ...     (["api", "web"],),
            ... ])
            >>> mock_db.execute.return_value = mock_result
            >>>
            >>> # Test with tag data
            >>> async def test_with_tags():
            ...     tags = await service.get_all_tags(mock_db, entity_types=["tools"])
            ...     return len(tags) >= 2  # Should have at least api, database, web tags
            >>> asyncio.run(test_with_tags())
            True

            >>> # include_entities=True path
            >>> from types import SimpleNamespace
            >>> entity = SimpleNamespace(id='1', name='E', description='d', tags=['api'])
            >>> mock_result2 = MagicMock()
            >>> mock_result2.scalars.return_value = [entity]
            >>> mock_db.execute.return_value = mock_result2
            >>> async def test_with_entities():
            ...     tags = await service.get_all_tags(mock_db, entity_types=["tools"], include_entities=True)
            ...     return len(tags) == 1 and tags[0].entities[0].name == 'E'
            >>> asyncio.run(test_with_entities())
            True

        Raises:
            SQLAlchemyError: If database query fails
            ValidationError: If invalid entity types are processed
        """
        entity_map = self._entity_map()
        requested_types = self._normalize_entity_types(entity_types, entity_map)
        if not requested_types:
            # Nothing valid to query. Returning here also guarantees an empty
            # filter never reads or writes the "all" cache entry.
            return []

        # Cache key: "all" only when no filter was given, otherwise the sorted valid types.
        entity_types_key = "all" if entity_types is None else ":".join(sorted(requested_types))
        cache_key = f"{entity_types_key}:{include_entities}"

        # SECURITY: Only use cache for unrestricted (admin-bypass) queries.
        # Scoped queries (public-only/team/user) are user/token specific and must not
        # reuse global cached results.
        is_scoped_query = user_email is not None or token_teams is not None
        use_cache = not include_entities and not is_scoped_query

        if use_cache:
            cached = await _get_admin_stats_cache().get_tags(cache_key)
            if cached is not None:
                # Reconstruct TagInfo objects from cached dicts
                return [TagInfo.model_validate(t) for t in cached]

        tag_data: Dict[str, Dict] = {}
        team_ids = await self._resolve_team_ids(db, user_email, token_teams)

        for entity_type in requested_types:
            model = entity_map[entity_type]

            if include_entities:
                # Full rows are needed to build the TaggedEntity summaries.
                stmt = select(model).where(model.tags.isnot(None))
                stmt = self._apply_visibility_scope(stmt, model, user_email=user_email, token_teams=token_teams, team_ids=team_ids)
                for entity in db.execute(stmt).scalars():
                    tag_ids = self._unique_tag_ids(entity.tags)
                    if not tag_ids:
                        continue
                    tagged_entity = self._build_tagged_entity(entity, entity_type)
                    for tag in tag_ids:
                        bucket = self._get_or_create_bucket(tag_data, tag)
                        bucket["entities"].append(tagged_entity)
                        self._update_stats(bucket["stats"], entity_type)
            else:
                # Statistics only: fetch just the tags column.
                stmt = select(model.tags).where(model.tags.isnot(None))
                stmt = self._apply_visibility_scope(stmt, model, user_email=user_email, token_teams=token_teams, team_ids=team_ids)
                for row in db.execute(stmt):
                    for tag in self._unique_tag_ids(row[0]):
                        self._update_stats(self._get_or_create_bucket(tag_data, tag)["stats"], entity_type)

        tags = [
            TagInfo(
                name=tag,
                stats=tag_data[tag]["stats"],
                entities=tag_data[tag]["entities"] if include_entities else [],
            )
            for tag in sorted(tag_data)
        ]

        # Store in cache (only for non-entity + unrestricted queries)
        if use_cache:
            await _get_admin_stats_cache().set_tags([t.model_dump() for t in tags], cache_key)

        return tags

    def _update_stats(self, stats: TagStats, entity_type: str) -> None:
        """Update statistics for a specific entity type.

        Increments the TagStats counter named after ``entity_type`` and the
        ``total`` counter. Unknown entity types are ignored.

        Args:
            stats: TagStats object to update with new counts
            entity_type: Type of entity to increment count for. Must be one of:
                'tools', 'resources', 'prompts', 'servers', 'gateways'

        Example:
            >>> from mcpgateway.schemas import TagStats
            >>> service = TagService()
            >>> stats = TagStats(tools=0, resources=0, prompts=0, servers=0, gateways=0, total=0)
            >>>
            >>> # Test updating tool stats
            >>> service._update_stats(stats, "tools")
            >>> stats.tools
            1
            >>> stats.total
            1
            >>>
            >>> # Test updating resource stats
            >>> service._update_stats(stats, "resources")
            >>> stats.resources
            1
            >>> stats.total
            2
            >>>
            >>> # Test with invalid entity type (should not crash)
            >>> service._update_stats(stats, "invalid")
            >>> stats.total  # Should remain 2
            2
        """
        if entity_type not in self._ENTITY_TYPES:
            # Invalid entity types are ignored (no increment)
            return
        setattr(stats, entity_type, getattr(stats, entity_type) + 1)
        stats.total += 1

    def _get_tag_id(self, tag) -> str:
        """Return the tag id for a tag entry which may be a string or a dict.

        Supports legacy string tags and new dict tags with an 'id' field.
        Falls back to 'label' or the string representation when 'id' is missing.

        Args:
            tag: Tag value which may be a string (legacy) or a dict with an
                'id' or 'label' key.

        Returns:
            The normalized tag id, always as a string (non-string ids such as
            integers are converted so they sort and compare consistently).
        """
        if isinstance(tag, str):
            return tag
        if isinstance(tag, dict):
            return str(tag.get("id") or tag.get("label") or tag)
        return str(tag)

    async def _resolve_team_ids(self, db: Session, user_email: Optional[str], token_teams: Optional[List[str]]) -> List[str]:
        """Resolve effective team IDs for scoped visibility checks.

        Args:
            db: Database session
            user_email: Caller email for DB-based team lookup when token teams are not explicit
            token_teams: Explicit token team scope when present

        Returns:
            Effective team IDs used to build visibility filters: ``token_teams`` when
            given, ``[]`` when there is no user, otherwise the user's teams from the DB.
        """
        if token_teams is not None:
            return token_teams
        if not user_email:
            return []

        # First-Party
        from mcpgateway.services.team_management_service import TeamManagementService  # pylint: disable=import-outside-toplevel

        team_service = TeamManagementService(db)
        user_teams = await team_service.get_user_teams(user_email)
        return [team.id for team in user_teams]

    def _apply_visibility_scope(self, stmt, model, user_email: Optional[str], token_teams: Optional[List[str]], team_ids: List[str]):
        """Apply token/user visibility scope to a SQLAlchemy statement.

        Semantics mirror list/read endpoints:
        - token_teams is None and user_email is None -> unrestricted (admin bypass)
        - token_teams == [] -> public-only
        - token_teams == [...] -> public + matching-team (+ owner if user_email present)
        - token_teams is None and user_email present -> use DB team memberships

        Args:
            stmt: SQLAlchemy statement to constrain
            model: ORM model that includes visibility/team/owner columns
            user_email: Caller email used for owner visibility
            token_teams: Explicit token team scope when present
            team_ids: Effective team IDs for team visibility

        Returns:
            Scoped SQLAlchemy statement.
        """
        if token_teams is None and user_email is None:
            return stmt

        is_public_only_token = token_teams is not None and len(token_teams) == 0
        access_conditions = [model.visibility == "public"]

        # A public-only token must not widen access through ownership.
        if not is_public_only_token and user_email:
            access_conditions.append(model.owner_email == user_email)

        # Team members see team- and public-visibility rows of their teams (not private ones).
        if team_ids:
            access_conditions.append(and_(model.team_id.in_(team_ids), model.visibility.in_(["team", "public"])))

        return stmt.where(or_(*access_conditions))

    async def get_entities_by_tag(
        self,
        db: Session,
        tag_name: str,
        entity_types: Optional[List[str]] = None,
        user_email: Optional[str] = None,
        token_teams: Optional[List[str]] = None,
    ) -> List[TaggedEntity]:
        """Get all entities that have a specific tag.

        This method searches across specified entity types to find all entities
        that contain the given tag. It returns simplified entity representations
        optimized for tag-based discovery and filtering.

        Args:
            db: Database session for querying entity data
            tag_name: The exact tag to search for (case sensitive)
            entity_types: Optional list of entity types to search within.
                Valid types: ['tools', 'resources', 'prompts', 'servers', 'gateways']
                If None, searches all entity types. Unknown and duplicate names are ignored.
            user_email: Caller email used for owner/team visibility checks
            token_teams: Normalized token teams (`None` admin bypass, `[]` public-only, list for team scope)

        Returns:
            List of TaggedEntity objects containing basic entity information.
            Each TaggedEntity includes: id, name, type, and description.
            Results are not sorted and may contain entities from different types.

        Example:
            >>> import asyncio
            >>> from unittest.mock import MagicMock
            >>>
            >>> # Setup service and mock database
            >>> service = TagService()
            >>> mock_db = MagicMock()
            >>> mock_db.get_bind.return_value.dialect.name = "sqlite"
            >>>
            >>> # Mock entity with tag
            >>> mock_entity = MagicMock()
            >>> mock_entity.id = "test-123"
            >>> mock_entity.name = "Test Entity"
            >>> mock_entity.description = "A test entity"
            >>> mock_entity.tags = ["api", "test", "database"]
            >>>
            >>> # Mock database result
            >>> mock_result = MagicMock()
            >>> mock_result.scalars.return_value = [mock_entity]
            >>> mock_db.execute.return_value = mock_result
            >>>
            >>> # Test entity lookup by tag
            >>> async def test_entity_lookup():
            ...     entities = await service.get_entities_by_tag(mock_db, "api", ["tools"])
            ...     return len(entities)
            >>> asyncio.run(test_entity_lookup())
            1

            >>> # Test with non-existent tag
            >>> mock_entity.tags = ["different", "tags"]
            >>> async def test_no_match():
            ...     entities = await service.get_entities_by_tag(mock_db, "api", ["tools"])
            ...     return len(entities)
            >>> asyncio.run(test_no_match())
            0

        Note:
            - Tag matching is exact and case-sensitive
            - Entities without the specified tag are filtered out after database query
            - Performance scales with the number of entities in filtered types
            - Uses json_contains_tag_expr for cross-database filtering (PostgreSQL/SQLite)
        """
        entity_map = self._entity_map()
        requested_types = self._normalize_entity_types(entity_types, entity_map)
        if not requested_types:
            return []

        team_ids = await self._resolve_team_ids(db, user_email, token_teams)
        entities: List[TaggedEntity] = []

        for entity_type in requested_types:
            model = entity_map[entity_type]

            # Query entities that have this tag
            # Using json_contains_tag_expr for cross-database compatibility (PostgreSQL/SQLite)
            stmt = select(model).where(json_contains_tag_expr(db, model.tags, [tag_name], match_any=True))
            stmt = self._apply_visibility_scope(stmt, model, user_email=user_email, token_teams=token_teams, team_ids=team_ids)

            for entity in db.execute(stmt).scalars():
                # Re-check in Python against normalized ids, since stored tags may be strings or dicts.
                if tag_name in {self._get_tag_id(tag) for tag in entity.tags or []}:
                    entities.append(self._build_tagged_entity(entity, entity_type))

        return entities

    async def get_tag_counts(self, db: Session) -> Dict[str, int]:
        """Get the total number of tag instances per entity type.

        This method calculates the total number of tag instances (not unique tag names)
        across all entity types. Useful for analytics and capacity planning.

        Args:
            db: Database session for querying tag data

        Returns:
            Dictionary mapping entity type names to total tag counts.
            Keys: 'tools', 'resources', 'prompts', 'servers', 'gateways'
            Values: Integer counts of total tag instances in each type

        Example:
            >>> import asyncio
            >>> from unittest.mock import MagicMock
            >>>
            >>> # Setup service and mock database
            >>> service = TagService()
            >>> mock_db = MagicMock()
            >>>
            >>> # Mock tag count results
            >>> mock_db.execute.return_value.scalars.return_value.all.return_value = [2, 1, 3]  # 3 entities with 2, 1, 3 tags each
            >>>
            >>> # Execute method with mocked responses (same values reused for simplicity)
            >>> class _Res:
            ...     def scalars(self):
            ...         class _S:
            ...             def all(self_inner):
            ...                 return [2, 1, 3]
            ...         return _S()
            >>> mock_db.execute.return_value = _Res()
            >>> counts = asyncio.run(service.get_tag_counts(mock_db))
            >>> counts['tools']
            6
            >>> all(isinstance(v, int) for v in counts.values())
            True
            >>> len(counts)
            5

        Note:
            - Counts tag instances, not unique tag names
            - An entity with 3 tags contributes 3 to the count
            - Empty or null tag arrays contribute 0 to the count
            - Uses json_array_length() for efficient counting
        """
        counts: Dict[str, int] = {}
        for entity_type, model in self._entity_map().items():
            stmt = select(func.json_array_length(model.tags)).where(model.tags.isnot(None))
            lengths = db.execute(stmt).scalars().all()
            # Treat a NULL length as 0 so a single odd row cannot make sum() raise TypeError.
            counts[entity_type] = sum(length or 0 for length in lengths)
        return counts