Coverage for mcpgateway / services / completion_service.py: 100%
97 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/completion_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Completion Service Implementation.
8This module implements argument completion according to the MCP specification.
9It handles completion suggestions for prompt arguments and resource URIs.
11Examples:
12 >>> from mcpgateway.services.completion_service import CompletionService, CompletionError
13 >>> service = CompletionService()
14 >>> isinstance(service, CompletionService)
15 True
16 >>> service._custom_completions
17 {}
18"""
20# Standard
21from typing import Any, Dict, List, Optional
23# Third-Party
24from sqlalchemy import and_, desc, or_, select
25from sqlalchemy.orm import Session
27# First-Party
28from mcpgateway.common.models import CompleteResult
29from mcpgateway.db import Prompt as DbPrompt
30from mcpgateway.db import Resource as DbResource
31from mcpgateway.services.logging_service import LoggingService
# One LoggingService instance is created when this module is imported; every
# log call below goes through the logger it returns for this module's name.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)
class CompletionError(Exception):
    """Error raised when a completion request cannot be fulfilled.

    The exception message carries the human-readable reason (for example a
    missing prompt name or an unknown reference type).

    Examples:
        >>> from mcpgateway.services.completion_service import CompletionError
        >>> issubclass(CompletionError, Exception)
        True
        >>> str(CompletionError("Invalid reference"))
        'Invalid reference'
    """
class CompletionService:
    """MCP completion service.

    Handles argument completion for:
    - Prompt arguments based on schema
    - Resource URIs with templates
    - Custom completion sources
    """

    # Upper bound on the number of suggestions returned in one result; the
    # ``total``/``hasMore`` fields report how many matches exist beyond it.
    _MAX_COMPLETION_VALUES = 100

    def __init__(self):
        """Initialize completion service.

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService
            >>> service = CompletionService()
            >>> hasattr(service, '_custom_completions')
            True
            >>> service._custom_completions
            {}
        """
        # Argument name -> candidate values registered via register_completions().
        self._custom_completions: Dict[str, List[str]] = {}

    async def initialize(self) -> None:
        """Initialize completion service."""
        logger.info("Initializing completion service")

    async def shutdown(self) -> None:
        """Shutdown completion service and drop all registered custom completions."""
        logger.info("Shutting down completion service")
        self._custom_completions.clear()

    async def handle_completion(
        self,
        db: Session,
        request: Dict[str, Any],
        user_email: Optional[str] = None,
        token_teams: Optional[List[str]] = None,
    ) -> CompleteResult:
        """Handle completion request.

        Args:
            db: Database session
            request: Completion request with ``ref`` (``type`` plus ``name``/``uri``)
                and ``argument`` (``name``/``value``) entries
            user_email: Caller email used for owner/team visibility checks
            token_teams: Normalized token teams (`None` admin bypass, `[]` public-only, list for team scope)

        Returns:
            Completion result with suggestions

        Raises:
            CompletionError: If completion fails. Any other exception raised while
                completing is wrapped in ``CompletionError`` with the original
                exception chained as ``__cause__``.

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService, CompletionError
            >>> from unittest.mock import MagicMock
            >>> service = CompletionService()
            >>> db = MagicMock()
            >>> request = {'ref': {'type': 'ref/prompt', 'name': 'prompt1'}, 'argument': {'name': 'arg1', 'value': ''}}
            >>> db.execute.return_value.scalars.return_value.all.return_value = []
            >>> import asyncio
            >>> try:
            ...     asyncio.run(service.handle_completion(db, request))
            ... except Exception:
            ...     pass
            >>> try:
            ...     asyncio.run(service.handle_completion(db, {'ref': {}, 'argument': {}}))
            ... except CompletionError as e:
            ...     str(e)
            'Missing reference type or argument name'
        """
        try:
            # Get reference and argument info; tolerate explicit nulls from clients.
            ref = request.get("ref") or {}
            ref_type = ref.get("type")
            arg = request.get("argument") or {}
            arg_name = arg.get("name")
            arg_value = arg.get("value")
            if arg_value is None:
                # ``"value": null`` means "no prefix typed yet" - same as "".
                arg_value = ""

            if not ref_type or not arg_name:
                raise CompletionError("Missing reference type or argument name")

            # Handle different reference types
            if ref_type == "ref/prompt":
                return await self._complete_prompt_argument(db, ref, arg_name, arg_value, user_email=user_email, token_teams=token_teams)
            if ref_type == "ref/resource":
                return await self._complete_resource_uri(db, ref, arg_value, user_email=user_email, token_teams=token_teams)
            raise CompletionError(f"Invalid reference type: {ref_type}")

        except CompletionError as e:
            logger.error(f"Completion error: {e}")
            # Already the public error type: re-raise without re-wrapping.
            raise
        except Exception as e:
            logger.error(f"Completion error: {e}")
            # Keep the original traceback reachable via __cause__.
            raise CompletionError(str(e)) from e

    async def _resolve_team_ids(self, db: Session, user_email: Optional[str], token_teams: Optional[List[str]]) -> List[str]:
        """Resolve effective team IDs for scoped visibility checks.

        Args:
            db: Database session
            user_email: Caller email for DB-based team lookup when token teams are not explicit
            token_teams: Explicit token team scope when present

        Returns:
            ``token_teams`` unchanged when it is not ``None``; otherwise the IDs of
            the teams returned by ``TeamManagementService.get_user_teams`` for
            ``user_email``, or ``[]`` when no email is given.
        """
        if token_teams is not None:
            return token_teams
        if not user_email:
            return []

        # First-Party
        from mcpgateway.services.team_management_service import TeamManagementService  # pylint: disable=import-outside-toplevel

        team_service = TeamManagementService(db)
        user_teams = await team_service.get_user_teams(user_email)
        return [team.id for team in user_teams]

    def _apply_visibility_scope(self, stmt, model, user_email: Optional[str], token_teams: Optional[List[str]], team_ids: List[str]):
        """Apply token/user visibility scope to a SQLAlchemy statement.

        The statement is left unfiltered only when both ``token_teams`` and
        ``user_email`` are ``None``. Otherwise rows are visible when any of these hold:

        - ``visibility == "public"``;
        - the row is owned by ``user_email`` (skipped for public-only tokens,
          i.e. ``token_teams == []``);
        - the row belongs to one of ``team_ids`` with ``team`` or ``public`` visibility.

        Args:
            stmt: SQLAlchemy statement to constrain
            model: ORM model that includes visibility/team/owner columns
            user_email: Caller email used for owner visibility
            token_teams: Explicit token team scope when present
            team_ids: Effective team IDs for team visibility

        Returns:
            Scoped SQLAlchemy statement.
        """
        if token_teams is None and user_email is None:
            return stmt

        is_public_only_token = token_teams is not None and len(token_teams) == 0
        access_conditions = [model.visibility == "public"]

        if not is_public_only_token and user_email:
            access_conditions.append(model.owner_email == user_email)

        if team_ids:
            access_conditions.append(and_(model.team_id.in_(team_ids), model.visibility.in_(["team", "public"])))

        return stmt.where(or_(*access_conditions))

    @classmethod
    def _build_result(cls, values: List[Any]) -> CompleteResult:
        """Wrap matched values in a ``CompleteResult``, truncated to the response cap.

        Args:
            values: All matching completion values, in the order they should be offered

        Returns:
            Result whose ``values`` holds at most ``_MAX_COMPLETION_VALUES`` entries,
            with ``total`` set to the full match count and ``hasMore`` flagging truncation.

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService
            >>> r = CompletionService._build_result([str(i) for i in range(150)])
            >>> len(r.completion['values']), r.completion['total'], r.completion['hasMore']
            (100, 150, True)
            >>> CompletionService._build_result([]).completion
            {'values': [], 'total': 0, 'hasMore': False}
        """
        limit = cls._MAX_COMPLETION_VALUES
        return CompleteResult(
            completion={
                "values": values[:limit],
                "total": len(values),
                "hasMore": len(values) > limit,
            }
        )

    @staticmethod
    def _filter_values(candidates: Any, needle: str) -> List[Any]:
        """Return candidates whose string form contains ``needle``, case-insensitively.

        Args:
            candidates: Iterable of candidate values (strings or JSON scalars)
            needle: Text typed so far by the client

        Returns:
            Matching candidates, in their original order and original type.
        """
        lowered = needle.lower()
        return [value for value in candidates if lowered in str(value).lower()]

    @staticmethod
    def _find_argument_schema(argument_schema: Any, arg_name: str) -> Optional[Dict[str, Any]]:
        """Locate the schema entry describing ``arg_name`` in a prompt's argument schema.

        Entries are first matched on an explicit ``"name"`` field (legacy layout),
        then on the ``properties`` key itself (standard JSON-Schema layout, e.g.
        ``{"properties": {"color": {"type": "string"}}}``).

        Args:
            argument_schema: The prompt's stored argument schema (may be ``None``)
            arg_name: Argument to look up

        Returns:
            The matching property dict, or ``None`` when absent or malformed.
        """
        if not isinstance(argument_schema, dict):
            return None
        properties = argument_schema.get("properties") or {}
        if not isinstance(properties, dict):
            return None

        for prop in properties.values():
            if isinstance(prop, dict) and prop.get("name") == arg_name:
                return prop

        keyed = properties.get(arg_name)
        return keyed if isinstance(keyed, dict) else None

    async def _complete_prompt_argument(
        self,
        db: Session,
        ref: Dict[str, Any],
        arg_name: str,
        arg_value: str,
        user_email: Optional[str] = None,
        token_teams: Optional[List[str]] = None,
    ) -> CompleteResult:
        """Complete prompt argument value.

        Suggestions come from the argument's schema ``enum`` when present,
        otherwise from values registered with :meth:`register_completions`.

        Args:
            db: Database session
            ref: Prompt reference
            arg_name: Argument name
            arg_value: Current argument value
            user_email: Caller email used for owner/team visibility checks
            token_teams: Normalized token teams (`None` admin bypass, `[]` public-only, list for team scope)

        Returns:
            Completion suggestions

        Raises:
            CompletionError: If prompt is missing or not found, or the argument
                is not described by the prompt's schema

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService, CompletionError
            >>> from unittest.mock import MagicMock
            >>> import asyncio
            >>> service = CompletionService()
            >>> db = MagicMock()

            >>> # Test missing prompt name
            >>> ref = {}
            >>> try:
            ...     asyncio.run(service._complete_prompt_argument(db, ref, 'arg1', 'val'))
            ... except CompletionError as e:
            ...     str(e)
            'Missing prompt name'

            >>> # Test custom completions
            >>> service.register_completions('color', ['red', 'green', 'blue'])
            >>> db.execute.return_value.scalar_one_or_none.return_value = MagicMock(
            ...     argument_schema={'properties': {'color': {'name': 'color'}}}
            ... )
            >>> result = asyncio.run(service._complete_prompt_argument(
            ...     db, {'name': 'test'}, 'color', 'r'
            ... ))
            >>> result.completion['values']
            ['red', 'green']

            >>> # Standard JSON-Schema layout: properties keyed by argument name
            >>> db.execute.return_value.scalar_one_or_none.return_value = MagicMock(
            ...     argument_schema={'type': 'object', 'properties': {'size': {'type': 'string', 'enum': ['small', 'medium', 'large']}}}
            ... )
            >>> asyncio.run(service._complete_prompt_argument(db, {'name': 'test'}, 'size', 'm')).completion['values']
            ['small', 'medium']
        """
        # Get prompt
        prompt_name = ref.get("name")
        if not prompt_name:
            raise CompletionError("Missing prompt name")

        # Only consider prompts that are enabled and visible to caller
        team_ids = await self._resolve_team_ids(db, user_email, token_teams)
        stmt = select(DbPrompt).where(DbPrompt.name == prompt_name).where(DbPrompt.enabled)
        stmt = self._apply_visibility_scope(stmt, DbPrompt, user_email=user_email, token_teams=token_teams, team_ids=team_ids)
        # Several visible prompts may share a name: take the newest, id breaking ties.
        stmt = stmt.order_by(desc(DbPrompt.created_at), desc(DbPrompt.id)).limit(1)
        prompt = db.execute(stmt).scalar_one_or_none()

        if not prompt:
            raise CompletionError(f"Prompt not found: {prompt_name}")

        arg_schema = self._find_argument_schema(prompt.argument_schema, arg_name)
        if not arg_schema:
            raise CompletionError(f"Argument not found: {arg_name}")

        # Schema-declared enum values take precedence over custom completions.
        if "enum" in arg_schema:
            return self._build_result(self._filter_values(arg_schema["enum"], arg_value))

        if arg_name in self._custom_completions:
            return self._build_result(self._filter_values(self._custom_completions[arg_name], arg_value))

        # No completions available
        return self._build_result([])

    async def _complete_resource_uri(
        self,
        db: Session,
        ref: Dict[str, Any],
        arg_value: str,
        user_email: Optional[str] = None,
        token_teams: Optional[List[str]] = None,
    ) -> CompleteResult:
        """Complete resource URI.

        NOTE(review): ``ref["uri"]`` is only checked for presence; matching is a
        case-insensitive substring test of ``arg_value`` against the URI of every
        enabled resource visible to the caller - the template is not expanded.

        Args:
            db: Database session
            ref: Resource reference
            arg_value: Current URI value
            user_email: Caller email used for owner/team visibility checks
            token_teams: Normalized token teams (`None` admin bypass, `[]` public-only, list for team scope)

        Returns:
            URI completion suggestions

        Raises:
            CompletionError: If URI template is missing

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService, CompletionError
            >>> from unittest.mock import MagicMock
            >>> import asyncio
            >>> service = CompletionService()
            >>> db = MagicMock()

            >>> # Test missing URI template
            >>> ref = {}
            >>> try:
            ...     asyncio.run(service._complete_resource_uri(db, ref, 'test'))
            ... except CompletionError as e:
            ...     str(e)
            'Missing URI template'

            >>> # Test resource filtering
            >>> ref = {'uri': 'template://'}
            >>> mock_resources = [
            ...     MagicMock(uri='file://doc1.txt'),
            ...     MagicMock(uri='file://doc2.txt'),
            ...     MagicMock(uri='http://example.com')
            ... ]
            >>> db.execute.return_value.scalars.return_value.all.return_value = mock_resources
            >>> result = asyncio.run(service._complete_resource_uri(db, ref, 'doc'))
            >>> len(result.completion['values'])
            2
            >>> 'file://doc1.txt' in result.completion['values']
            True
        """
        # Get base URI template
        uri_template = ref.get("uri")
        if not uri_template:
            raise CompletionError("Missing URI template")

        # List matching resources visible to caller
        team_ids = await self._resolve_team_ids(db, user_email, token_teams)
        stmt = select(DbResource).where(DbResource.enabled)
        stmt = self._apply_visibility_scope(stmt, DbResource, user_email=user_email, token_teams=token_teams, team_ids=team_ids)
        resources = db.execute(stmt).scalars().all()

        # Filter by URI pattern (needle lowered once, not per resource)
        needle = arg_value.lower()
        matches = [resource.uri for resource in resources if needle in resource.uri.lower()]

        return self._build_result(matches)

    def register_completions(self, arg_name: str, values: List[str]) -> None:
        """Register custom completion values.

        A copy of ``values`` is stored, replacing any earlier registration for
        ``arg_name``.

        Args:
            arg_name: Argument name
            values: Completion values

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService
            >>> service = CompletionService()
            >>> service.register_completions('arg1', ['a', 'b'])
            >>> service._custom_completions['arg1']
            ['a', 'b']
            >>> service.register_completions('arg2', ['x', 'y', 'z'])
            >>> len(service._custom_completions)
            2
            >>> service.register_completions('arg1', ['c'])  # Overwrite
            >>> service._custom_completions['arg1']
            ['c']
        """
        self._custom_completions[arg_name] = list(values)

    def unregister_completions(self, arg_name: str) -> None:
        """Unregister custom completion values (no-op if none are registered).

        Args:
            arg_name: Argument name

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService
            >>> service = CompletionService()
            >>> service.register_completions('arg1', ['a', 'b'])
            >>> service.unregister_completions('arg1')
            >>> 'arg1' in service._custom_completions
            False
        """
        self._custom_completions.pop(arg_name, None)