Coverage for mcpgateway / services / tag_service.py: 100%
143 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/tag_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Tag Service Implementation.
8This module implements tag management and retrieval for all entities in the MCP Gateway.
9It handles:
10- Fetching all unique tags across entities
11- Filtering tags by entity type
12- Tag statistics and counts
13- Retrieving entities that have specific tags
14"""
16# Standard
17import logging
18from typing import Dict, List, Optional
20# Third-Party
21from sqlalchemy import func, select
22from sqlalchemy.orm import Session
24# First-Party
25from mcpgateway.db import Gateway as DbGateway
26from mcpgateway.db import Prompt as DbPrompt
27from mcpgateway.db import Resource as DbResource
28from mcpgateway.db import Server as DbServer
29from mcpgateway.db import Tool as DbTool
30from mcpgateway.schemas import TaggedEntity, TagInfo, TagStats
31from mcpgateway.utils.sqlalchemy_modifier import json_contains_tag_expr
33logger = logging.getLogger(__name__)
# Module-level slot for the admin stats cache. It stays None until the first
# call to _get_admin_stats_cache(); the import is deferred to call time to
# avoid circular dependencies.
_ADMIN_STATS_CACHE = None


def _get_admin_stats_cache():
    """Return the shared admin stats cache, importing it on first use.

    The first call imports ``admin_stats_cache`` from
    ``mcpgateway.cache.admin_stats_cache`` and stores it in the module-level
    ``_ADMIN_STATS_CACHE`` slot; every later call returns that stored object
    without importing again.

    Returns:
        AdminStatsCache instance.
    """
    global _ADMIN_STATS_CACHE  # pylint: disable=global-statement

    # Fast path: the cache object has already been resolved.
    if _ADMIN_STATS_CACHE is not None:
        return _ADMIN_STATS_CACHE

    # First-Party
    from mcpgateway.cache.admin_stats_cache import admin_stats_cache as shared_cache  # pylint: disable=import-outside-toplevel

    _ADMIN_STATS_CACHE = shared_cache
    return shared_cache
class TagService:
    """Service for managing and retrieving tags across all entities.

    This service provides comprehensive tag management functionality across all MCP Gateway
    entity types (tools, resources, prompts, servers, gateways). It handles tag discovery,
    entity lookup by tags, and statistics aggregation.

    Example:
        >>> from unittest.mock import MagicMock
        >>> from mcpgateway.schemas import TagInfo, TagStats, TaggedEntity
        >>>
        >>> # Create service instance
        >>> service = TagService()
        >>>
        >>> # Mock database session
        >>> mock_db = MagicMock()
        >>>
        >>> # Test basic functionality
        >>> isinstance(service, TagService)
        True
    """

    # Supported entity types, in processing order. Each name is also the name of
    # the per-type counter that _update_stats increments on a stats object.
    _ENTITY_TYPES = ("tools", "resources", "prompts", "servers", "gateways")

    @staticmethod
    def _entity_models() -> Dict[str, type]:
        """Return the mapping from entity type name to its ORM model class.

        Built on demand (rather than stored as a class attribute) so the class
        body itself does not reference the model classes at definition time.

        Returns:
            Dict keyed by 'tools', 'resources', 'prompts', 'servers', 'gateways'.
        """
        return {
            "tools": DbTool,
            "resources": DbResource,
            "prompts": DbPrompt,
            "servers": DbServer,
            "gateways": DbGateway,
        }

    def _resolve_entity_types(self, entity_types: Optional[List[str]]) -> List[str]:
        """Normalize a caller-supplied entity type filter.

        Args:
            entity_types: Requested entity types, or None for "all types".

        Returns:
            All supported types when ``entity_types`` is None; otherwise the
            supported names from ``entity_types`` in first-seen order, with
            unknown names and duplicates removed. An empty list stays empty.

        Example:
            >>> service = TagService()
            >>> service._resolve_entity_types(None)
            ['tools', 'resources', 'prompts', 'servers', 'gateways']
            >>> service._resolve_entity_types(["servers", "bogus", "servers", "tools"])
            ['servers', 'tools']
            >>> service._resolve_entity_types([])
            []
        """
        if entity_types is None:
            return list(self._ENTITY_TYPES)
        # dict.fromkeys de-duplicates while keeping the caller's order, so a
        # type listed twice is not scanned (and counted) twice.
        return [name for name in dict.fromkeys(entity_types) if name in self._ENTITY_TYPES]

    @staticmethod
    def _new_tag_entry() -> Dict:
        """Create the accumulator stored per tag while scanning entities.

        Returns:
            Dict with a zeroed ``stats`` (TagStats) and an empty ``entities`` list.
        """
        return {"stats": TagStats(tools=0, resources=0, prompts=0, servers=0, gateways=0, total=0), "entities": []}

    @staticmethod
    def _build_tagged_entity(entity, entity_type: str) -> TaggedEntity:
        """Build the TaggedEntity summary for one database entity.

        Args:
            entity: Object exposing some of ``id``, ``uri``, ``name``,
                ``original_name`` and ``description``; missing attributes are tolerated.
            entity_type: Plural entity type name (e.g. 'tools').

        Returns:
            TaggedEntity whose ``type`` is the singular form of ``entity_type``.
        """
        # ID preference: explicit id, then the URI (resources only), then the name.
        if getattr(entity, "id", None) is not None:
            entity_id = str(entity.id)
        elif entity_type == "resources" and hasattr(entity, "uri"):
            entity_id = str(entity.uri)
        else:
            entity_id = str(getattr(entity, "name", None) or "unknown")

        # Name preference: name, original_name, URI, and finally the ID.
        if getattr(entity, "name", None):
            entity_name = entity.name
        elif getattr(entity, "original_name", None):
            entity_name = entity.original_name
        elif hasattr(entity, "uri"):
            entity_name = str(entity.uri)
        else:
            entity_name = entity_id

        return TaggedEntity(
            id=entity_id,
            name=entity_name,
            type=entity_type[:-1],  # Remove plural 's'
            description=getattr(entity, "description", None),
        )

    async def get_all_tags(self, db: Session, entity_types: Optional[List[str]] = None, include_entities: bool = False) -> List[TagInfo]:
        """Retrieve all unique tags across specified entity types.

        This method aggregates tags from multiple entity types and returns comprehensive
        statistics about tag usage. It can optionally include detailed information about
        which entities have each tag.

        Args:
            db: Database session for querying entity data
            entity_types: List of entity types to filter by. Valid types are:
                ['tools', 'resources', 'prompts', 'servers', 'gateways'].
                If None, returns tags from all entity types. Unknown and duplicate
                names are ignored; a list with no valid names yields an empty result.
            include_entities: Whether to include the list of entities that have each tag.
                If False, only statistics are returned for better performance.

        Returns:
            List of TagInfo objects containing tag details, sorted alphabetically by tag name.
            Each TagInfo includes:
            - name: The tag name
            - stats: Usage statistics per entity type
            - entities: List of entities with this tag (if include_entities=True)

        Example:
            >>> import asyncio
            >>> from unittest.mock import MagicMock, AsyncMock
            >>>
            >>> # Create service and mock database
            >>> service = TagService()
            >>> mock_db = MagicMock()
            >>>
            >>> # Mock empty result
            >>> mock_db.execute.return_value.__iter__ = lambda self: iter([])
            >>>
            >>> # Test with empty database
            >>> async def test_empty():
            ...     tags = await service.get_all_tags(mock_db)
            ...     return len(tags)
            >>> asyncio.run(test_empty())
            0

            >>> # Mock result with tag data
            >>> mock_result = MagicMock()
            >>> mock_result.__iter__ = lambda self: iter([
            ...     (["api", "database"],),
            ...     (["api", "web"],),
            ... ])
            >>> mock_db.execute.return_value = mock_result
            >>>
            >>> # Test with tag data
            >>> async def test_with_tags():
            ...     tags = await service.get_all_tags(mock_db, entity_types=["tools"])
            ...     return len(tags) >= 2  # Should have at least api, database, web tags
            >>> asyncio.run(test_with_tags())
            True

            >>> # include_entities=True path
            >>> from types import SimpleNamespace
            >>> entity = SimpleNamespace(id='1', name='E', description='d', tags=['api'])
            >>> mock_result2 = MagicMock()
            >>> mock_result2.scalars.return_value = [entity]
            >>> mock_db.execute.return_value = mock_result2
            >>> async def test_with_entities():
            ...     tags = await service.get_all_tags(mock_db, entity_types=["tools"], include_entities=True)
            ...     return len(tags) == 1 and tags[0].entities[0].name == 'E'
            >>> asyncio.run(test_with_entities())
            True

        Raises:
            SQLAlchemyError: If database query fails
            ValidationError: If invalid entity types are processed
        """
        requested = self._resolve_entity_types(entity_types)
        if not requested:
            # Nothing valid was requested. Return early so this empty answer is
            # never cached under (or served from) the "all" key.
            return []

        # Generate cache key from parameters. Only a missing filter (None) maps
        # to "all"; an explicit list is keyed by its normalized, sorted names.
        entity_types_key = "all" if entity_types is None else ":".join(sorted(requested))
        cache_key = f"{entity_types_key}:{include_entities}"

        # Check cache first (only for non-entity queries as entity data is large)
        if not include_entities:
            cache = _get_admin_stats_cache()
            cached = await cache.get_tags(cache_key)
            if cached is not None:
                # Reconstruct TagInfo objects from cached dicts
                return [TagInfo.model_validate(t) for t in cached]

        tag_data: Dict[str, Dict] = {}
        models = self._entity_models()

        # Collect tags from each requested entity type
        for entity_type in requested:
            model = models[entity_type]

            if include_entities:
                # Load full rows so each tag can list the entities carrying it.
                result = db.execute(select(model).where(model.tags.isnot(None)))
                for entity in result.scalars():
                    for raw_tag in entity.tags or []:
                        tag = self._get_tag_id(raw_tag)
                        if tag not in tag_data:
                            tag_data[tag] = self._new_tag_entry()
                        tag_data[tag]["entities"].append(self._build_tagged_entity(entity, entity_type))
                        self._update_stats(tag_data[tag]["stats"], entity_type)
            else:
                # Statistics only: fetch just the tags column.
                result = db.execute(select(model.tags).where(model.tags.isnot(None)))
                for row in result:
                    for raw_tag in row[0] or []:
                        tag = self._get_tag_id(raw_tag)
                        if tag not in tag_data:
                            tag_data[tag] = self._new_tag_entry()
                        self._update_stats(tag_data[tag]["stats"], entity_type)

        # Convert to TagInfo list, ordered by tag name
        tags = [TagInfo(name=tag, stats=tag_data[tag]["stats"], entities=tag_data[tag]["entities"] if include_entities else []) for tag in sorted(tag_data)]

        # Store in cache (only for non-entity queries)
        if not include_entities:
            cache = _get_admin_stats_cache()
            await cache.set_tags([t.model_dump() for t in tags], cache_key)

        return tags

    def _update_stats(self, stats: TagStats, entity_type: str) -> None:
        """Update statistics for a specific entity type.

        This helper method increments the appropriate counter in the TagStats object
        based on the entity type and maintains the total count.

        Args:
            stats: TagStats object to update with new counts
            entity_type: Type of entity to increment count for. Must be one of:
                'tools', 'resources', 'prompts', 'servers', 'gateways'

        Example:
            >>> from mcpgateway.schemas import TagStats
            >>> service = TagService()
            >>> stats = TagStats(tools=0, resources=0, prompts=0, servers=0, gateways=0, total=0)
            >>>
            >>> # Test updating tool stats
            >>> service._update_stats(stats, "tools")
            >>> stats.tools
            1
            >>> stats.total
            1
            >>>
            >>> # Test updating resource stats
            >>> service._update_stats(stats, "resources")
            >>> stats.resources
            1
            >>> stats.total
            2
            >>>
            >>> # Test with invalid entity type (should not crash)
            >>> service._update_stats(stats, "invalid")
            >>> stats.total  # Should remain 2
            2
        """
        if entity_type not in self._ENTITY_TYPES:
            # Invalid entity types are ignored (no increment)
            return
        # The per-type counter on the stats object is named after the entity type.
        setattr(stats, entity_type, getattr(stats, entity_type) + 1)
        stats.total += 1

    def _get_tag_id(self, tag) -> str:
        """Return the tag id for a tag entry which may be a string or a dict.

        Supports legacy string tags and new dict tags with an 'id' field.
        Falls back to 'label' or the string representation when 'id' is missing.

        Args:
            tag: Tag value which may be a string (legacy) or a dict with an
                'id' or 'label' key.

        Returns:
            The normalized tag id, always as a string (a non-string 'id' or
            'label' value is converted with ``str``).

        Example:
            >>> service = TagService()
            >>> service._get_tag_id("api")
            'api'
            >>> service._get_tag_id({"id": "api", "label": "API"})
            'api'
            >>> service._get_tag_id({"label": "API"})
            'API'
            >>> service._get_tag_id({"id": 7})
            '7'
        """
        if isinstance(tag, str):
            return tag
        if isinstance(tag, dict):
            # str() keeps the return type stable so ids can be sorted and
            # compared against string tag names.
            return str(tag.get("id") or tag.get("label") or tag)
        return str(tag)

    async def get_entities_by_tag(self, db: Session, tag_name: str, entity_types: Optional[List[str]] = None) -> List[TaggedEntity]:
        """Get all entities that have a specific tag.

        This method searches across specified entity types to find all entities
        that contain the given tag. It returns simplified entity representations
        optimized for tag-based discovery and filtering.

        Args:
            db: Database session for querying entity data
            tag_name: The exact tag to search for (case sensitive)
            entity_types: Optional list of entity types to search within.
                Valid types: ['tools', 'resources', 'prompts', 'servers', 'gateways']
                If None, searches all entity types. Unknown and duplicate names
                are ignored.

        Returns:
            List of TaggedEntity objects containing basic entity information.
            Each TaggedEntity includes: id, name, type, and description.
            Results are not sorted and may contain entities from different types.

        Example:
            >>> import asyncio
            >>> from unittest.mock import MagicMock
            >>>
            >>> # Setup service and mock database
            >>> service = TagService()
            >>> mock_db = MagicMock()
            >>> mock_db.get_bind.return_value.dialect.name = "sqlite"
            >>>
            >>> # Mock entity with tag
            >>> mock_entity = MagicMock()
            >>> mock_entity.id = "test-123"
            >>> mock_entity.name = "Test Entity"
            >>> mock_entity.description = "A test entity"
            >>> mock_entity.tags = ["api", "test", "database"]
            >>>
            >>> # Mock database result
            >>> mock_result = MagicMock()
            >>> mock_result.scalars.return_value = [mock_entity]
            >>> mock_db.execute.return_value = mock_result
            >>>
            >>> # Test entity lookup by tag
            >>> async def test_entity_lookup():
            ...     entities = await service.get_entities_by_tag(mock_db, "api", ["tools"])
            ...     return len(entities)
            >>> asyncio.run(test_entity_lookup())
            1

            >>> # Test with non-existent tag
            >>> mock_entity.tags = ["different", "tags"]
            >>> async def test_no_match():
            ...     entities = await service.get_entities_by_tag(mock_db, "api", ["tools"])
            ...     return len(entities)
            >>> asyncio.run(test_no_match())
            0

        Note:
            - Tag matching is exact and case-sensitive
            - Entities without the specified tag are filtered out after database query
            - Performance scales with the number of entities in filtered types
            - Uses json_contains_tag_expr for cross-database filtering (PostgreSQL/SQLite)
        """
        entities: List[TaggedEntity] = []
        models = self._entity_models()

        for entity_type in self._resolve_entity_types(entity_types):
            model = models[entity_type]

            # Query entities that have this tag
            # Using json_contains_tag_expr for cross-database compatibility (PostgreSQL/SQLite)
            stmt = select(model).where(json_contains_tag_expr(db, model.tags, [tag_name], match_any=True))

            for entity in db.execute(stmt).scalars():
                # Re-check in Python against the normalized tag ids so only
                # entities whose tag id equals tag_name exactly are returned.
                if any(self._get_tag_id(raw_tag) == tag_name for raw_tag in entity.tags or []):
                    entities.append(self._build_tagged_entity(entity, entity_type))

        return entities

    async def get_tag_counts(self, db: Session) -> Dict[str, int]:
        """Get the total number of tag assignments per entity type.

        This method calculates the total number of tag instances (not unique tag names)
        across all entity types. Useful for analytics and capacity planning.

        Args:
            db: Database session for querying tag data

        Returns:
            Dictionary mapping entity type names to total tag counts.
            Keys: 'tools', 'resources', 'prompts', 'servers', 'gateways'
            Values: Integer counts of total tag instances in each type

        Example:
            >>> import asyncio
            >>> from unittest.mock import MagicMock
            >>>
            >>> # Setup service and mock database
            >>> service = TagService()
            >>> mock_db = MagicMock()
            >>>
            >>> # Mock tag count results
            >>> mock_db.execute.return_value.scalars.return_value.all.return_value = [2, 1, 3]  # 3 entities with 2, 1, 3 tags each
            >>>
            >>> # Execute method with mocked responses (same values reused for simplicity)
            >>> class _Res:
            ...     def scalars(self):
            ...         class _S:
            ...             def all(self_inner):
            ...                 return [2, 1, 3]
            ...         return _S()
            >>> mock_db.execute.return_value = _Res()
            >>> counts = asyncio.run(service.get_tag_counts(mock_db))
            >>> counts['tools']
            6
            >>> all(isinstance(v, int) for v in counts.values())
            True
            >>> len(counts)
            5

        Note:
            - Counts tag instances, not unique tag names
            - An entity with 3 tags contributes 3 to the count
            - Empty or null tag arrays contribute 0 to the count
            - Uses json_array_length() for efficient counting
        """
        counts: Dict[str, int] = {}

        for entity_type, model in self._entity_models().items():
            # One row per entity with non-NULL tags: the length of its tag array.
            stmt = select(func.json_array_length(model.tags)).where(model.tags.isnot(None))
            lengths = db.execute(stmt).scalars().all()
            # "or 0" guards against a NULL length coming back from the database,
            # which would otherwise make sum() raise TypeError.
            counts[entity_type] = sum(length or 0 for length in lengths)

        return counts