Coverage for mcpgateway / services / completion_service.py: 100%
71 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/completion_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Completion Service Implementation.
8This module implements argument completion according to the MCP specification.
9It handles completion suggestions for prompt arguments and resource URIs.
11Examples:
12 >>> from mcpgateway.services.completion_service import CompletionService, CompletionError
13 >>> service = CompletionService()
14 >>> isinstance(service, CompletionService)
15 True
16 >>> service._custom_completions
17 {}
18"""
20# Standard
21from typing import Any, Dict, List
23# Third-Party
24from sqlalchemy import select
25from sqlalchemy.orm import Session
27# First-Party
28from mcpgateway.common.models import CompleteResult
29from mcpgateway.db import Prompt as DbPrompt
30from mcpgateway.db import Resource as DbResource
31from mcpgateway.services.logging_service import LoggingService
# Initialize logging service first: the module-level ``logger`` below is obtained
# from it, and CompletionService uses that logger for its lifecycle (info) and
# completion-failure (error) messages.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)
class CompletionError(Exception):
    """Raised when an MCP completion request cannot be fulfilled.

    The constructor message is the human-readable reason for the failure and
    is exactly what ``str()`` returns for the exception.

    Examples:
        >>> from mcpgateway.services.completion_service import CompletionError
        >>> err = CompletionError("Invalid reference")
        >>> str(err)
        'Invalid reference'
        >>> isinstance(err, Exception)
        True
    """
class CompletionService:
    """MCP completion service.

    Handles argument completion for:
    - Prompt arguments based on schema
    - Resource URIs with templates
    - Custom completion sources

    Every result has the shape ``{"values": [...], "total": int, "hasMore": bool}``:
    ``values`` holds at most ``_MAX_COMPLETION_VALUES`` matches, ``total`` counts all
    matches before truncation, and ``hasMore`` reports whether any were dropped.
    """

    # Per-response cap on completion values, as required by the MCP specification.
    _MAX_COMPLETION_VALUES = 100

    def __init__(self):
        """Initialize completion service.

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService
            >>> service = CompletionService()
            >>> hasattr(service, '_custom_completions')
            True
            >>> service._custom_completions
            {}
        """
        # Argument name -> values registered through register_completions().
        self._custom_completions: Dict[str, List[str]] = {}

    async def initialize(self) -> None:
        """Initialize completion service."""
        logger.info("Initializing completion service")

    async def shutdown(self) -> None:
        """Shutdown completion service and drop all registered custom completions."""
        logger.info("Shutting down completion service")
        self._custom_completions.clear()

    async def handle_completion(self, db: Session, request: Dict[str, Any]) -> CompleteResult:
        """Handle completion request.

        The request is read as ``{"ref": {"type": ..., ...}, "argument": {"name": ..., "value": ...}}``.
        ``ref.type`` selects the completion source: ``"ref/prompt"`` completes a prompt
        argument, ``"ref/resource"`` completes a resource URI. A missing or ``null``
        argument value is treated as an empty prefix, which matches every candidate.

        Args:
            db: Database session
            request: Completion request

        Returns:
            Completion result with suggestions

        Raises:
            CompletionError: If completion fails. Any exception raised while handling the
                request is logged and re-raised as CompletionError, chained to the original.

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService
            >>> from unittest.mock import MagicMock
            >>> service = CompletionService()
            >>> db = MagicMock()
            >>> request = {'ref': {'type': 'ref/prompt', 'name': 'prompt1'}, 'argument': {'name': 'arg1', 'value': ''}}
            >>> db.execute.return_value.scalars.return_value.all.return_value = []
            >>> import asyncio
            >>> try:
            ...     asyncio.run(service.handle_completion(db, request))
            ... except Exception:
            ...     pass
            >>> try:
            ...     asyncio.run(service.handle_completion(db, {'ref': {'type': 'ref/prompt'}}))
            ... except CompletionError as e:
            ...     print(e)
            Missing reference type or argument name
        """
        try:
            # ``or {}`` also covers explicit JSON nulls, which ``.get(key, {})`` would pass through.
            ref = request.get("ref") or {}
            ref_type = ref.get("type")
            arg = request.get("argument") or {}
            arg_name = arg.get("name")
            raw_value = arg.get("value")
            # Downstream matching calls ``.lower()``, so always hand it a string.
            arg_value = "" if raw_value is None else str(raw_value)

            if not ref_type or not arg_name:
                raise CompletionError("Missing reference type or argument name")

            # Handle different reference types
            if ref_type == "ref/prompt":
                return await self._complete_prompt_argument(db, ref, arg_name, arg_value)
            if ref_type == "ref/resource":
                return await self._complete_resource_uri(db, ref, arg_value)
            raise CompletionError(f"Invalid reference type: {ref_type}")

        except Exception as e:
            logger.error(f"Completion error: {e}")
            # Keep the original exception as __cause__ so the root cause survives in tracebacks.
            raise CompletionError(str(e)) from e

    async def _complete_prompt_argument(self, db: Session, ref: Dict[str, Any], arg_name: str, arg_value: str) -> CompleteResult:
        """Complete prompt argument value.

        The argument is looked up in the enabled prompt's ``argument_schema``. Suggestions
        come from the argument's ``enum`` when it declares one, otherwise from values
        registered via ``register_completions``. Matching is a case-insensitive substring
        test against ``arg_value``.

        Args:
            db: Database session
            ref: Prompt reference
            arg_name: Argument name
            arg_value: Current argument value

        Returns:
            Completion suggestions

        Raises:
            CompletionError: If prompt is missing or not found, or its schema does not declare the argument

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService, CompletionError
            >>> from unittest.mock import MagicMock
            >>> import asyncio
            >>> service = CompletionService()
            >>> db = MagicMock()

            >>> # Test missing prompt name
            >>> ref = {}
            >>> try:
            ...     asyncio.run(service._complete_prompt_argument(db, ref, 'arg1', 'val'))
            ... except CompletionError as e:
            ...     str(e)
            'Missing prompt name'

            >>> # Test custom completions
            >>> service.register_completions('color', ['red', 'green', 'blue'])
            >>> db.execute.return_value.scalar_one_or_none.return_value = MagicMock(
            ...     argument_schema={'properties': {'color': {'name': 'color'}}}
            ... )
            >>> result = asyncio.run(service._complete_prompt_argument(
            ...     db, {'name': 'test'}, 'color', 'r'
            ... ))
            >>> result.completion['values']
            ['red', 'green']

            >>> # Standard JSON Schema layout: properties keyed by argument name
            >>> db.execute.return_value.scalar_one_or_none.return_value = MagicMock(
            ...     argument_schema={'properties': {'size': {'type': 'string', 'enum': ['small', 'large']}}}
            ... )
            >>> asyncio.run(service._complete_prompt_argument(db, {'name': 'test'}, 'size', 'LA')).completion['values']
            ['large']
        """
        # Get prompt
        prompt_name = ref.get("name")
        if not prompt_name:
            raise CompletionError("Missing prompt name")

        # Only consider prompts that are enabled (renamed from `is_active` -> `enabled`)
        prompt = db.execute(select(DbPrompt).where(DbPrompt.name == prompt_name).where(DbPrompt.enabled)).scalar_one_or_none()

        if not prompt:
            raise CompletionError(f"Prompt not found: {prompt_name}")

        arg_schema = self._find_argument_schema(prompt.argument_schema, arg_name)
        if arg_schema is None:
            raise CompletionError(f"Argument not found: {arg_name}")

        needle = arg_value.lower()

        # Enum values declared in the schema take precedence over custom completions.
        if "enum" in arg_schema:
            return self._build_result([v for v in arg_schema["enum"] or [] if needle in str(v).lower()])

        # Check custom completions
        if arg_name in self._custom_completions:
            return self._build_result([v for v in self._custom_completions[arg_name] if needle in str(v).lower()])

        # No completions available
        return self._build_result([])

    @staticmethod
    def _find_argument_schema(argument_schema: Any, arg_name: str) -> Any:
        """Return the sub-schema that describes ``arg_name``, or ``None`` if it is not declared.

        Property entries carrying an explicit ``"name"`` field equal to ``arg_name`` are
        matched first (the lookup this service has always used). After that, the standard
        JSON Schema layout is consulted, where ``properties`` is keyed by argument name.
        A missing or non-object schema declares no arguments.

        Args:
            argument_schema: The prompt's JSON argument schema (may be ``None``)
            arg_name: Argument name to look up

        Returns:
            The argument's schema dict, or ``None`` when it is not found

        Examples:
            >>> CompletionService._find_argument_schema({'properties': {'x': {'name': 'color'}}}, 'color')
            {'name': 'color'}
            >>> CompletionService._find_argument_schema({'properties': {'color': {'type': 'string'}}}, 'color')
            {'type': 'string'}
            >>> CompletionService._find_argument_schema(None, 'color') is None
            True
        """
        if not isinstance(argument_schema, dict):
            return None
        properties = argument_schema.get("properties")
        if not isinstance(properties, dict):
            return None

        for prop in properties.values():
            if isinstance(prop, dict) and prop.get("name") == arg_name:
                return prop

        prop = properties.get(arg_name)
        return prop if isinstance(prop, dict) else None

    def _build_result(self, values: List[Any]) -> CompleteResult:
        """Wrap matched values in a CompleteResult, truncated to the MCP per-response cap.

        Args:
            values: All matching completion values, in the order they should be offered

        Returns:
            Completion result with at most ``_MAX_COMPLETION_VALUES`` values, the untruncated
            ``total`` and a ``hasMore`` flag

        Examples:
            >>> result = CompletionService()._build_result([f"v{i}" for i in range(150)])
            >>> len(result.completion['values']), result.completion['total'], result.completion['hasMore']
            (100, 150, True)
            >>> CompletionService()._build_result([]).completion == {'values': [], 'total': 0, 'hasMore': False}
            True
        """
        limit = self._MAX_COMPLETION_VALUES
        return CompleteResult(
            completion={
                "values": values[:limit],
                "total": len(values),
                "hasMore": len(values) > limit,
            }
        )

    async def _complete_resource_uri(self, db: Session, ref: Dict[str, Any], arg_value: str) -> CompleteResult:
        """Complete resource URI.

        The URI template in ``ref["uri"]`` is only required to be present. Candidates are
        the URIs of all enabled resources that contain ``arg_value`` (case-insensitive).

        Args:
            db: Database session
            ref: Resource reference
            arg_value: Current URI value

        Returns:
            URI completion suggestions

        Raises:
            CompletionError: If URI template is missing

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService, CompletionError
            >>> from unittest.mock import MagicMock
            >>> import asyncio
            >>> service = CompletionService()
            >>> db = MagicMock()

            >>> # Test missing URI template
            >>> ref = {}
            >>> try:
            ...     asyncio.run(service._complete_resource_uri(db, ref, 'test'))
            ... except CompletionError as e:
            ...     str(e)
            'Missing URI template'

            >>> # Test resource filtering
            >>> ref = {'uri': 'template://'}
            >>> mock_resources = [
            ...     MagicMock(uri='file://doc1.txt'),
            ...     MagicMock(uri='file://doc2.txt'),
            ...     MagicMock(uri='http://example.com')
            ... ]
            >>> db.execute.return_value.scalars.return_value.all.return_value = mock_resources
            >>> result = asyncio.run(service._complete_resource_uri(db, ref, 'doc'))
            >>> len(result.completion['values'])
            2
            >>> 'file://doc1.txt' in result.completion['values']
            True
        """
        # Get base URI template
        uri_template = ref.get("uri")
        if not uri_template:
            raise CompletionError("Missing URI template")

        # NOTE(review): the template is not used to narrow the candidates; every enabled
        # resource is loaded and filtered by substring below.
        resources = db.execute(select(DbResource).where(DbResource.enabled)).scalars().all()

        needle = arg_value.lower()
        return self._build_result([resource.uri for resource in resources if needle in resource.uri.lower()])

    def register_completions(self, arg_name: str, values: List[str]) -> None:
        """Register custom completion values.

        Any previously registered values for ``arg_name`` are replaced. ``values`` is
        copied, so later mutation of the caller's list does not affect suggestions.

        Args:
            arg_name: Argument name
            values: Completion values

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService
            >>> service = CompletionService()
            >>> service.register_completions('arg1', ['a', 'b'])
            >>> service._custom_completions['arg1']
            ['a', 'b']
            >>> service.register_completions('arg2', ['x', 'y', 'z'])
            >>> len(service._custom_completions)
            2
            >>> service.register_completions('arg1', ['c'])  # Overwrite
            >>> service._custom_completions['arg1']
            ['c']
        """
        self._custom_completions[arg_name] = list(values)

    def unregister_completions(self, arg_name: str) -> None:
        """Unregister custom completion values.

        Unregistering a name that was never registered is a no-op.

        Args:
            arg_name: Argument name

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService
            >>> service = CompletionService()
            >>> service.register_completions('arg1', ['a', 'b'])
            >>> service.unregister_completions('arg1')
            >>> 'arg1' in service._custom_completions
            False
        """
        self._custom_completions.pop(arg_name, None)