Coverage for mcpgateway / services / completion_service.py: 100%

71 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/services/completion_service.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7Completion Service Implementation. 

8This module implements argument completion according to the MCP specification. 

9It handles completion suggestions for prompt arguments and resource URIs. 

10 

11Examples: 

12 >>> from mcpgateway.services.completion_service import CompletionService, CompletionError 

13 >>> service = CompletionService() 

14 >>> isinstance(service, CompletionService) 

15 True 

16 >>> service._custom_completions 

17 {} 

18""" 

19 

20# Standard 

21from typing import Any, Dict, List 

22 

23# Third-Party 

24from sqlalchemy import select 

25from sqlalchemy.orm import Session 

26 

27# First-Party 

28from mcpgateway.common.models import CompleteResult 

29from mcpgateway.db import Prompt as DbPrompt 

30from mcpgateway.db import Resource as DbResource 

31from mcpgateway.services.logging_service import LoggingService 

32 

# Initialize logging service first: the module-level ``logger`` below is used by
# the CompletionService methods in this module (initialize/shutdown/error paths).
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

36 

37 

class CompletionError(Exception):
    """Signal that an MCP completion request could not be served.

    Raised by the completion service for malformed requests (missing reference
    type, argument name, prompt name or URI template), unknown reference types,
    prompts or arguments that cannot be found, and unexpected failures that are
    re-raised under this type. The human-readable message is the only payload.

    Examples:
        >>> from mcpgateway.services.completion_service import CompletionError
        >>> err = CompletionError("Invalid reference")
        >>> str(err)
        'Invalid reference'
        >>> isinstance(err, Exception)
        True
    """

49 

50 

class CompletionService:
    """MCP completion service.

    Handles argument completion for:
    - Prompt arguments based on schema
    - Resource URIs with templates
    - Custom completion sources

    Every completion result caps ``values`` at ``_MAX_COMPLETION_VALUES`` entries,
    reports the uncapped match count in ``total`` and sets ``hasMore`` when the
    cap truncated the list.
    """

    # Maximum number of suggestions returned in a single completion result
    # (the MCP completion spec limits responses to 100 values).
    _MAX_COMPLETION_VALUES = 100

    def __init__(self):
        """Initialize completion service.

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService
            >>> service = CompletionService()
            >>> hasattr(service, '_custom_completions')
            True
            >>> service._custom_completions
            {}
        """
        # Argument name -> list of candidate completion values.
        self._custom_completions: Dict[str, List[str]] = {}

    async def initialize(self) -> None:
        """Initialize completion service."""
        logger.info("Initializing completion service")

    async def shutdown(self) -> None:
        """Shutdown completion service and drop all registered custom completions."""
        logger.info("Shutting down completion service")
        self._custom_completions.clear()

    async def handle_completion(self, db: Session, request: Dict[str, Any]) -> CompleteResult:
        """Handle completion request.

        Dispatches on ``request["ref"]["type"]``: ``"ref/prompt"`` completes a
        prompt argument, ``"ref/resource"`` completes a resource URI.

        Args:
            db: Database session
            request: Completion request with ``ref`` (``type`` plus ``name`` or
                ``uri``) and ``argument`` (``name`` and optional ``value``).

        Returns:
            Completion result with suggestions

        Raises:
            CompletionError: If the request is malformed or completion fails for
                any reason; unexpected exceptions are chained as the cause.

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService, CompletionError
            >>> from unittest.mock import MagicMock
            >>> import asyncio
            >>> service = CompletionService()
            >>> db = MagicMock()
            >>> request = {'ref': {'type': 'ref/unknown'}, 'argument': {'name': 'arg1', 'value': ''}}
            >>> try:
            ...     asyncio.run(service.handle_completion(db, request))
            ... except CompletionError as e:
            ...     str(e)
            'Invalid reference type: ref/unknown'
        """
        try:
            # Treat explicit nulls in the payload the same as absent keys.
            ref = request.get("ref") or {}
            ref_type = ref.get("type")
            arg = request.get("argument") or {}
            arg_name = arg.get("name")
            raw_value = arg.get("value")
            # Normalize to str so the case-insensitive matching below cannot fail on None/ints.
            arg_value = "" if raw_value is None else str(raw_value)

            if not ref_type or not arg_name:
                raise CompletionError("Missing reference type or argument name")

            if ref_type == "ref/prompt":
                return await self._complete_prompt_argument(db, ref, arg_name, arg_value)
            if ref_type == "ref/resource":
                return await self._complete_resource_uri(db, ref, arg_value)
            raise CompletionError(f"Invalid reference type: {ref_type}")

        except CompletionError as e:
            # Already the right type and message: log and propagate unchanged.
            logger.error(f"Completion error: {e}")
            raise
        except Exception as e:
            logger.error(f"Completion error: {e}")
            # Chain the original exception so the root cause stays in the traceback.
            raise CompletionError(str(e)) from e

    async def _complete_prompt_argument(self, db: Session, ref: Dict[str, Any], arg_name: str, arg_value: str) -> CompleteResult:
        """Complete prompt argument value.

        Suggestions come from the argument's ``enum`` in the prompt's
        ``argument_schema`` if present, otherwise from custom completions
        registered for ``arg_name``; matching is a case-insensitive substring test.

        Args:
            db: Database session
            ref: Prompt reference
            arg_name: Argument name
            arg_value: Current argument value

        Returns:
            Completion suggestions

        Raises:
            CompletionError: If prompt is missing or not found, or the argument is not in its schema

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService, CompletionError
            >>> from unittest.mock import MagicMock
            >>> import asyncio
            >>> service = CompletionService()
            >>> db = MagicMock()

            >>> # Test missing prompt name
            >>> ref = {}
            >>> try:
            ...     asyncio.run(service._complete_prompt_argument(db, ref, 'arg1', 'val'))
            ... except CompletionError as e:
            ...     str(e)
            'Missing prompt name'

            >>> # Test custom completions
            >>> service.register_completions('color', ['red', 'green', 'blue'])
            >>> db.execute.return_value.scalar_one_or_none.return_value = MagicMock(
            ...     argument_schema={'properties': {'color': {'name': 'color'}}}
            ... )
            >>> result = asyncio.run(service._complete_prompt_argument(
            ...     db, {'name': 'test'}, 'color', 'r'
            ... ))
            >>> result.completion['values']
            ['red', 'green']

            >>> # Test enum values in a standard JSON Schema (properties keyed by name)
            >>> db.execute.return_value.scalar_one_or_none.return_value = MagicMock(
            ...     argument_schema={'properties': {'size': {'type': 'string', 'enum': ['small', 'medium', 'large']}}}
            ... )
            >>> asyncio.run(service._complete_prompt_argument(db, {'name': 'test'}, 'size', 'm')).completion['values']
            ['small', 'medium']
        """
        prompt_name = ref.get("name")
        if not prompt_name:
            raise CompletionError("Missing prompt name")

        # Only consider prompts that are enabled (renamed from `is_active` -> `enabled`)
        prompt = db.execute(select(DbPrompt).where(DbPrompt.name == prompt_name).where(DbPrompt.enabled)).scalar_one_or_none()
        if not prompt:
            raise CompletionError(f"Prompt not found: {prompt_name}")

        arg_schema = self._find_argument_schema(prompt.argument_schema, arg_name)
        # An empty schema ({}) still means "argument exists", so test for None explicitly.
        if arg_schema is None:
            raise CompletionError(f"Argument not found: {arg_name}")

        needle = (arg_value or "").lower()

        # Enumerated values declared in the schema take precedence over custom completions.
        if "enum" in arg_schema:
            return self._build_result(self._filter_values(arg_schema["enum"] or [], needle))

        if arg_name in self._custom_completions:
            return self._build_result(self._filter_values(self._custom_completions[arg_name], needle))

        # No completions available
        return self._build_result([])

    async def _complete_resource_uri(self, db: Session, ref: Dict[str, Any], arg_value: str) -> CompleteResult:
        """Complete resource URI.

        Returns URIs of enabled resources that contain ``arg_value``
        (case-insensitive substring match).

        Args:
            db: Database session
            ref: Resource reference
            arg_value: Current URI value

        Returns:
            URI completion suggestions

        Raises:
            CompletionError: If URI template is missing

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService, CompletionError
            >>> from unittest.mock import MagicMock
            >>> import asyncio
            >>> service = CompletionService()
            >>> db = MagicMock()

            >>> # Test missing URI template
            >>> ref = {}
            >>> try:
            ...     asyncio.run(service._complete_resource_uri(db, ref, 'test'))
            ... except CompletionError as e:
            ...     str(e)
            'Missing URI template'

            >>> # Test resource filtering
            >>> ref = {'uri': 'template://'}
            >>> mock_resources = [
            ...     MagicMock(uri='file://doc1.txt'),
            ...     MagicMock(uri='file://doc2.txt'),
            ...     MagicMock(uri='http://example.com')
            ... ]
            >>> db.execute.return_value.scalars.return_value.all.return_value = mock_resources
            >>> result = asyncio.run(service._complete_resource_uri(db, ref, 'doc'))
            >>> len(result.completion['values'])
            2
            >>> 'file://doc1.txt' in result.completion['values']
            True
        """
        uri_template = ref.get("uri")
        if not uri_template:
            raise CompletionError("Missing URI template")

        # NOTE(review): the template is only checked for presence; matching is a plain
        # substring test over every enabled resource, loaded and filtered in Python.
        resources = db.execute(select(DbResource).where(DbResource.enabled)).scalars().all()

        needle = (arg_value or "").lower()
        matches = [resource.uri for resource in resources if resource.uri and needle in resource.uri.lower()]
        return self._build_result(matches)

    @staticmethod
    def _find_argument_schema(argument_schema: Any, arg_name: str) -> Any:
        """Return the schema dict describing ``arg_name``, or ``None`` if absent.

        JSON Schema keys ``properties`` by argument name, so that lookup is tried
        first; entries that carry an explicit ``"name"`` field are accepted as a
        fallback for schemas stored in that layout.

        Args:
            argument_schema: The prompt's stored argument schema (may be ``None``).
            arg_name: Argument name to locate.

        Returns:
            The argument's schema dict, or ``None`` when it is not declared.
        """
        properties = (argument_schema or {}).get("properties") or {}
        direct = properties.get(arg_name)
        if isinstance(direct, dict):
            return direct
        for candidate in properties.values():
            if isinstance(candidate, dict) and candidate.get("name") == arg_name:
                return candidate
        return None

    @staticmethod
    def _filter_values(candidates: List[Any], needle: str) -> List[str]:
        """Return candidates (as strings) that contain ``needle`` case-insensitively.

        Args:
            candidates: Candidate completion values; non-strings are stringified
                because MCP completion values are strings.
            needle: Lower-cased text the user has typed so far.

        Returns:
            Matching values in their original order.
        """
        return [str(candidate) for candidate in candidates if needle in str(candidate).lower()]

    @classmethod
    def _build_result(cls, values: List[str]) -> CompleteResult:
        """Wrap matching values in a ``CompleteResult``, applying the size cap.

        Args:
            values: All matching completion values.

        Returns:
            Result with at most ``_MAX_COMPLETION_VALUES`` values, the full match
            count in ``total`` and ``hasMore`` set when values were truncated.
        """
        limit = cls._MAX_COMPLETION_VALUES
        return CompleteResult(
            completion={
                "values": values[:limit],
                "total": len(values),
                "hasMore": len(values) > limit,
            }
        )

    def register_completions(self, arg_name: str, values: List[str]) -> None:
        """Register custom completion values.

        Replaces any values previously registered for ``arg_name``; the input is
        copied so later mutation by the caller does not affect the service.

        Args:
            arg_name: Argument name
            values: Completion values

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService
            >>> service = CompletionService()
            >>> service.register_completions('arg1', ['a', 'b'])
            >>> service._custom_completions['arg1']
            ['a', 'b']
            >>> service.register_completions('arg2', ['x', 'y', 'z'])
            >>> len(service._custom_completions)
            2
            >>> service.register_completions('arg1', ['c'])  # Overwrite
            >>> service._custom_completions['arg1']
            ['c']
        """
        self._custom_completions[arg_name] = list(values)

    def unregister_completions(self, arg_name: str) -> None:
        """Unregister custom completion values (no-op if none are registered).

        Args:
            arg_name: Argument name

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService
            >>> service = CompletionService()
            >>> service.register_completions('arg1', ['a', 'b'])
            >>> service.unregister_completions('arg1')
            >>> 'arg1' in service._custom_completions
            False
        """
        self._custom_completions.pop(arg_name, None)