Coverage for mcpgateway / services / completion_service.py: 100%

97 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/services/completion_service.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7Completion Service Implementation. 

8This module implements argument completion according to the MCP specification. 

9It handles completion suggestions for prompt arguments and resource URIs. 

10 

11Examples: 

12 >>> from mcpgateway.services.completion_service import CompletionService, CompletionError 

13 >>> service = CompletionService() 

14 >>> isinstance(service, CompletionService) 

15 True 

16 >>> service._custom_completions 

17 {} 

18""" 

19 

20# Standard 

21from typing import Any, Dict, List, Optional 

22 

23# Third-Party 

24from sqlalchemy import and_, desc, or_, select 

25from sqlalchemy.orm import Session 

26 

27# First-Party 

28from mcpgateway.common.models import CompleteResult 

29from mcpgateway.db import Prompt as DbPrompt 

30from mcpgateway.db import Resource as DbResource 

31from mcpgateway.services.logging_service import LoggingService 

32 

# Module-wide logger for this service, obtained through the gateway's
# LoggingService. Both names are module-level and stay importable.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

36 

37 

class CompletionError(Exception):
    """Raised when a completion request cannot be satisfied.

    Root exception type for the completion service; it adds nothing on top
    of ``Exception`` beyond a distinct type callers can catch.

    Examples:
        >>> from mcpgateway.services.completion_service import CompletionError
        >>> error = CompletionError("Invalid reference")
        >>> isinstance(error, Exception)
        True
        >>> str(error)
        'Invalid reference'
    """

49 

50 

class CompletionService:
    """MCP completion service.

    Produces completion suggestions for:
    - Prompt arguments, from the ``enum`` of the argument's schema or from
      values registered via :meth:`register_completions`
    - Resource URIs, by substring match against enabled, visible resources

    All matching is a case-insensitive substring test, and every response is
    truncated to ``_MAX_VALUES`` entries.
    """

    # Upper bound on the number of values returned in a single response.
    _MAX_VALUES: int = 100

    def __init__(self):
        """Initialize completion service.

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService
            >>> service = CompletionService()
            >>> hasattr(service, '_custom_completions')
            True
            >>> service._custom_completions
            {}
        """
        # Maps an argument name to its registered completion candidates.
        self._custom_completions: Dict[str, List[str]] = {}

    async def initialize(self) -> None:
        """Initialize completion service (only logs; no state is created here)."""
        logger.info("Initializing completion service")

    async def shutdown(self) -> None:
        """Shutdown completion service and drop all registered custom completions."""
        logger.info("Shutting down completion service")
        self._custom_completions.clear()

    async def handle_completion(
        self,
        db: Session,
        request: Dict[str, Any],
        user_email: Optional[str] = None,
        token_teams: Optional[List[str]] = None,
    ) -> CompleteResult:
        """Handle completion request.

        Dispatches on ``request["ref"]["type"]`` to prompt-argument or
        resource-URI completion. A ``ref``, ``argument`` or ``value`` that is
        missing or explicitly ``None`` is treated as empty.

        Args:
            db: Database session
            request: Completion request
            user_email: Caller email used for owner/team visibility checks
            token_teams: Normalized token teams (`None` admin bypass, `[]` public-only, list for team scope)

        Returns:
            Completion result with suggestions

        Raises:
            CompletionError: If the reference type or argument name is missing,
                the reference type is unsupported, or completion fails for any
                other reason (the underlying exception is chained as the cause)

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService
            >>> from unittest.mock import MagicMock
            >>> service = CompletionService()
            >>> db = MagicMock()
            >>> request = {'ref': {'type': 'ref/prompt', 'name': 'prompt1'}, 'argument': {'name': 'arg1', 'value': ''}}
            >>> db.execute.return_value.scalars.return_value.all.return_value = []
            >>> import asyncio
            >>> try:
            ...     asyncio.run(service.handle_completion(db, request))
            ... except Exception:
            ...     pass
        """
        try:
            # ``or {}`` also covers an explicit null, which ``.get(key, {})`` would not.
            ref = request.get("ref") or {}
            arg = request.get("argument") or {}
            ref_type = ref.get("type")
            arg_name = arg.get("name")
            arg_value = arg.get("value")
            if arg_value is None:
                arg_value = ""

            if not ref_type or not arg_name:
                raise CompletionError("Missing reference type or argument name")

            if ref_type == "ref/prompt":
                return await self._complete_prompt_argument(db, ref, arg_name, arg_value, user_email=user_email, token_teams=token_teams)
            if ref_type == "ref/resource":
                return await self._complete_resource_uri(db, ref, arg_value, user_email=user_email, token_teams=token_teams)
            raise CompletionError(f"Invalid reference type: {ref_type}")

        except CompletionError as e:
            # Already the right type: log and propagate without re-wrapping.
            logger.error(f"Completion error: {e}")
            raise
        except Exception as e:
            logger.error(f"Completion error: {e}")
            # Chain the cause so the original traceback is not lost.
            raise CompletionError(str(e)) from e

    async def _resolve_team_ids(self, db: Session, user_email: Optional[str], token_teams: Optional[List[str]]) -> List[str]:
        """Resolve effective team IDs for scoped visibility checks.

        An explicit ``token_teams`` (including an empty list) always wins.
        Otherwise the teams are looked up for ``user_email``; with no email
        there is nothing to look up and the result is empty.

        Args:
            db: Database session
            user_email: Caller email for DB-based team lookup when token teams are not explicit
            token_teams: Explicit token team scope when present

        Returns:
            Effective team IDs used to build visibility filters.
        """
        if token_teams is not None:
            return token_teams
        if not user_email:
            return []

        # Imported at call time rather than at module import time
        # (presumably to avoid a circular import - confirm before moving).
        # First-Party
        from mcpgateway.services.team_management_service import TeamManagementService  # pylint: disable=import-outside-toplevel

        team_service = TeamManagementService(db)
        user_teams = await team_service.get_user_teams(user_email)
        return [team.id for team in user_teams]

    def _apply_visibility_scope(self, stmt, model, user_email: Optional[str], token_teams: Optional[List[str]], team_ids: List[str]):
        """Apply token/user visibility scope to a SQLAlchemy statement.

        With neither token teams nor a user email the statement is returned
        unchanged. Otherwise a row is visible when it is public, when it is
        owned by ``user_email`` (skipped for an empty, public-only token), or
        when it belongs to one of ``team_ids`` with ``team`` or ``public``
        visibility.

        Args:
            stmt: SQLAlchemy statement to constrain
            model: ORM model that includes visibility/team/owner columns
            user_email: Caller email used for owner visibility
            token_teams: Explicit token team scope when present
            team_ids: Effective team IDs for team visibility

        Returns:
            Scoped SQLAlchemy statement.
        """
        if token_teams is None and user_email is None:
            return stmt

        # An explicit empty team list marks a public-only token.
        is_public_only_token = token_teams is not None and len(token_teams) == 0
        access_conditions = [model.visibility == "public"]

        if not is_public_only_token and user_email:
            access_conditions.append(model.owner_email == user_email)

        if team_ids:
            access_conditions.append(and_(model.team_id.in_(team_ids), model.visibility.in_(["team", "public"])))

        return stmt.where(or_(*access_conditions))

    @staticmethod
    def _find_argument_schema(argument_schema: Optional[Dict[str, Any]], arg_name: str) -> Optional[Dict[str, Any]]:
        """Locate the schema of one argument inside a prompt's argument schema.

        A property whose own ``"name"`` field equals ``arg_name`` is preferred.
        If none matches, the property stored under the key ``arg_name`` is
        used, which is how JSON Schema ``properties`` are normally keyed.

        Args:
            argument_schema: Prompt argument schema; may be ``None``
            arg_name: Argument name to look up

        Returns:
            The matching property schema (possibly an empty dict), or ``None``.

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService
            >>> find = CompletionService._find_argument_schema
            >>> find({'properties': {'color': {'enum': ['red']}}}, 'color')
            {'enum': ['red']}
            >>> find({'properties': {'c': {'name': 'color'}}}, 'color')
            {'name': 'color'}
            >>> find(None, 'color') is None
            True
        """
        properties = (argument_schema or {}).get("properties") or {}

        # First pass: match on the embedded "name" field.
        for candidate in properties.values():
            if isinstance(candidate, dict) and candidate.get("name") == arg_name:
                return candidate

        # Fallback: the property key itself is the argument name.
        keyed = properties.get(arg_name)
        return keyed if isinstance(keyed, dict) else None

    @staticmethod
    def _filter_matches(candidates, arg_value: Optional[str]) -> list:
        """Keep the candidates whose text contains ``arg_value``, ignoring case.

        Args:
            candidates: Iterable of candidate values; non-strings are compared via ``str()``
            arg_value: Text typed so far; ``None`` or ``""`` matches everything

        Returns:
            Matching candidates, unmodified and in their original order.

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService
            >>> CompletionService._filter_matches(['Red', 'green', 'blue'], 'R')
            ['Red', 'green']
        """
        needle = (arg_value or "").lower()
        return [candidate for candidate in candidates if needle in str(candidate).lower()]

    @classmethod
    def _build_result(cls, values: list) -> CompleteResult:
        """Wrap matched values in a ``CompleteResult``, truncated to ``_MAX_VALUES``.

        Args:
            values: All matching values, before truncation

        Returns:
            Result whose ``total`` is the untruncated count and whose
            ``hasMore`` is true when values were cut off.
        """
        return CompleteResult(
            completion={
                "values": values[: cls._MAX_VALUES],
                "total": len(values),
                "hasMore": len(values) > cls._MAX_VALUES,
            }
        )

    async def _complete_prompt_argument(
        self,
        db: Session,
        ref: Dict[str, Any],
        arg_name: str,
        arg_value: str,
        user_email: Optional[str] = None,
        token_teams: Optional[List[str]] = None,
    ) -> CompleteResult:
        """Complete prompt argument value.

        Looks up the newest enabled prompt named ``ref["name"]`` that the
        caller may see, then offers the argument's ``enum`` values if it has
        any, otherwise the custom completions registered for ``arg_name``.

        Args:
            db: Database session
            ref: Prompt reference
            arg_name: Argument name
            arg_value: Current argument value
            user_email: Caller email used for owner/team visibility checks
            token_teams: Normalized token teams (`None` admin bypass, `[]` public-only, list for team scope)

        Returns:
            Completion suggestions

        Raises:
            CompletionError: If the prompt name is missing, the prompt is not
                found, or the prompt has no argument called ``arg_name``

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService, CompletionError
            >>> from unittest.mock import MagicMock
            >>> import asyncio
            >>> service = CompletionService()
            >>> db = MagicMock()

            >>> # Test missing prompt name
            >>> ref = {}
            >>> try:
            ...     asyncio.run(service._complete_prompt_argument(db, ref, 'arg1', 'val'))
            ... except CompletionError as e:
            ...     str(e)
            'Missing prompt name'

            >>> # Test custom completions
            >>> service.register_completions('color', ['red', 'green', 'blue'])
            >>> db.execute.return_value.scalar_one_or_none.return_value = MagicMock(
            ...     argument_schema={'properties': {'color': {'name': 'color'}}}
            ... )
            >>> result = asyncio.run(service._complete_prompt_argument(
            ...     db, {'name': 'test'}, 'color', 'r'
            ... ))
            >>> result.completion['values']
            ['red', 'green']
        """
        prompt_name = ref.get("name")
        if not prompt_name:
            raise CompletionError("Missing prompt name")

        # Only consider prompts that are enabled and visible to caller;
        # when several share the name, the most recently created one wins.
        team_ids = await self._resolve_team_ids(db, user_email, token_teams)
        stmt = select(DbPrompt).where(DbPrompt.name == prompt_name).where(DbPrompt.enabled)
        stmt = self._apply_visibility_scope(stmt, DbPrompt, user_email=user_email, token_teams=token_teams, team_ids=team_ids)
        stmt = stmt.order_by(desc(DbPrompt.created_at), desc(DbPrompt.id)).limit(1)
        prompt = db.execute(stmt).scalar_one_or_none()

        if not prompt:
            raise CompletionError(f"Prompt not found: {prompt_name}")

        arg_schema = self._find_argument_schema(prompt.argument_schema, arg_name)
        # ``is None`` rather than falsiness: an empty schema still means the
        # argument exists, it just carries no enum.
        if arg_schema is None:
            raise CompletionError(f"Argument not found: {arg_name}")

        # Enum values defined in the schema take precedence over custom ones.
        if "enum" in arg_schema:
            return self._build_result(self._filter_matches(arg_schema["enum"], arg_value))

        if arg_name in self._custom_completions:
            return self._build_result(self._filter_matches(self._custom_completions[arg_name], arg_value))

        # No completions available
        return self._build_result([])

    async def _complete_resource_uri(
        self,
        db: Session,
        ref: Dict[str, Any],
        arg_value: str,
        user_email: Optional[str] = None,
        token_teams: Optional[List[str]] = None,
    ) -> CompleteResult:
        """Complete resource URI.

        ``ref["uri"]`` must be present but is not otherwise used for matching:
        suggestions are the URIs of all enabled resources visible to the
        caller that contain ``arg_value`` (case-insensitive).

        Args:
            db: Database session
            ref: Resource reference
            arg_value: Current URI value
            user_email: Caller email used for owner/team visibility checks
            token_teams: Normalized token teams (`None` admin bypass, `[]` public-only, list for team scope)

        Returns:
            URI completion suggestions

        Raises:
            CompletionError: If URI template is missing

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService, CompletionError
            >>> from unittest.mock import MagicMock
            >>> import asyncio
            >>> service = CompletionService()
            >>> db = MagicMock()

            >>> # Test missing URI template
            >>> ref = {}
            >>> try:
            ...     asyncio.run(service._complete_resource_uri(db, ref, 'test'))
            ... except CompletionError as e:
            ...     str(e)
            'Missing URI template'

            >>> # Test resource filtering
            >>> ref = {'uri': 'template://'}
            >>> mock_resources = [
            ...     MagicMock(uri='file://doc1.txt'),
            ...     MagicMock(uri='file://doc2.txt'),
            ...     MagicMock(uri='http://example.com')
            ... ]
            >>> db.execute.return_value.scalars.return_value.all.return_value = mock_resources
            >>> result = asyncio.run(service._complete_resource_uri(db, ref, 'doc'))
            >>> len(result.completion['values'])
            2
            >>> 'file://doc1.txt' in result.completion['values']
            True
        """
        uri_template = ref.get("uri")
        if not uri_template:
            raise CompletionError("Missing URI template")

        # List enabled resources visible to caller
        team_ids = await self._resolve_team_ids(db, user_email, token_teams)
        stmt = select(DbResource).where(DbResource.enabled)
        stmt = self._apply_visibility_scope(stmt, DbResource, user_email=user_email, token_teams=token_teams, team_ids=team_ids)
        resources = db.execute(stmt).scalars().all()

        # The substring filter runs in Python over every visible row.
        needle = (arg_value or "").lower()
        matches = [resource.uri for resource in resources if needle in resource.uri.lower()]

        return self._build_result(matches)

    def register_completions(self, arg_name: str, values: List[str]) -> None:
        """Register custom completion values.

        Replaces any values previously registered for ``arg_name``. The
        values are copied, so later changes to the caller's list have no
        effect on the registry.

        Args:
            arg_name: Argument name
            values: Completion values

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService
            >>> service = CompletionService()
            >>> service.register_completions('arg1', ['a', 'b'])
            >>> service._custom_completions['arg1']
            ['a', 'b']
            >>> service.register_completions('arg2', ['x', 'y', 'z'])
            >>> len(service._custom_completions)
            2
            >>> service.register_completions('arg1', ['c'])  # Overwrite
            >>> service._custom_completions['arg1']
            ['c']
        """
        self._custom_completions[arg_name] = list(values)

    def unregister_completions(self, arg_name: str) -> None:
        """Unregister custom completion values.

        Unknown argument names are ignored.

        Args:
            arg_name: Argument name

        Examples:
            >>> from mcpgateway.services.completion_service import CompletionService
            >>> service = CompletionService()
            >>> service.register_completions('arg1', ['a', 'b'])
            >>> service.unregister_completions('arg1')
            >>> 'arg1' in service._custom_completions
            False
        """
        self._custom_completions.pop(arg_name, None)