Coverage for mcpgateway / services / dcr_service.py: 93%

167 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/services/dcr_service.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Manav Gupta 

6 

7OAuth 2.0 Dynamic Client Registration Service. 

8 

9This module handles OAuth 2.0 Dynamic Client Registration (DCR) including: 

10- AS metadata discovery (RFC 8414) 

11- Client registration (RFC 7591) 

12- Client management (update, delete) 

13""" 

14 

15# Standard 

16from datetime import datetime, timezone 

17import logging 

18from typing import Any, Dict, List 

19from urllib.parse import urlsplit 

20 

21# Third-Party 

22import httpx 

23import orjson 

24from sqlalchemy.orm import Session 

25 

26# First-Party 

27from mcpgateway.config import get_settings 

28from mcpgateway.db import RegisteredOAuthClient 

29from mcpgateway.services.encryption_service import get_encryption_service 

30from mcpgateway.services.http_client_service import get_http_client 

31 

# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)

# Process-wide, in-memory cache of authorization-server (AS) metadata.
# Keys are issuer URLs with any trailing "/" stripped; each value has the
# shape {"metadata": dict, "cached_at": timezone-aware datetime}.
_metadata_cache: Dict[str, Dict[str, Any]] = {}

37 

38 

class DcrService:
    """Client-side service for OAuth 2.0 Dynamic Client Registration (RFC 7591).

    Covers authorization-server (AS) metadata discovery, registering this
    gateway as an OAuth client, and updating or deleting an existing
    registration. All outbound HTTP goes through the shared HTTP client and
    uses the configured OAuth request timeout.
    """

    def __init__(self) -> None:
        """Initialize DCR service with the application settings."""
        self.settings = get_settings()

    async def _get_client(self) -> httpx.AsyncClient:
        """Get the shared singleton HTTP client.

        Returns:
            Shared httpx.AsyncClient instance with connection pooling
        """
        return await get_http_client()

    def _get_timeout(self) -> float:
        """Get the OAuth request timeout from settings.

        Returns:
            Timeout in seconds for OAuth/DCR requests
        """
        return float(self.settings.oauth_request_timeout)

    @staticmethod
    def _json_object(response: httpx.Response) -> Dict[str, Any]:
        """Decode a response body that is required to be a JSON object.

        Args:
            response: HTTP response whose body should be decoded

        Returns:
            The decoded JSON object

        Raises:
            ValueError: If the body is not valid JSON or is not a JSON object
        """
        payload = response.json()
        if not isinstance(payload, dict):
            raise ValueError(f"expected a JSON object, got {type(payload).__name__}")
        return payload

    @staticmethod
    def _validate_and_cache_metadata(metadata: Dict[str, Any], normalized_issuer: str, source: str) -> Dict[str, Any]:
        """Check that metadata belongs to the expected issuer, then cache it.

        Args:
            metadata: Decoded AS metadata document
            normalized_issuer: Expected issuer URL without trailing slash
            source: Human-readable discovery mechanism used in the log message

        Returns:
            The validated metadata (same object that was passed in)

        Raises:
            DcrError: If the document's issuer does not match the expected issuer
        """
        # A missing, empty, or non-string issuer is treated as "" so that it
        # fails the comparison below instead of raising AttributeError.
        raw_issuer = metadata.get("issuer")
        metadata_issuer = raw_issuer.rstrip("/") if isinstance(raw_issuer, str) else ""
        if metadata_issuer != normalized_issuer:
            raise DcrError(f"AS metadata issuer mismatch: expected {normalized_issuer}, got {metadata.get('issuer')}")

        _metadata_cache[normalized_issuer] = {"metadata": metadata, "cached_at": datetime.now(timezone.utc)}

        logger.info(f"Discovered AS metadata for {normalized_issuer} via {source}")
        return metadata

    async def discover_as_metadata(self, issuer: str) -> Dict[str, Any]:
        """Discover AS metadata via RFC 8414.

        Tries:
        1. RFC 8414: /.well-known/oauth-authorization-server inserted between host and path
        2. OIDC fallback: {issuer}/.well-known/openid-configuration

        A transport error, a non-200 status, or an undecodable body on the
        RFC 8414 attempt falls through to the OIDC attempt. An issuer mismatch
        on either attempt is raised immediately.

        Args:
            issuer: The AS issuer URL

        Returns:
            Dict containing AS metadata

        Raises:
            DcrError: If metadata cannot be discovered
        """
        # Normalize issuer URL by removing trailing slash for consistency.
        # Per RFC 8414 Section 3.1, any terminating "/" MUST be removed before
        # inserting "/.well-known/" and the well-known URI suffix.
        # This also works around MCP Python SDK issue #1919 where Pydantic's
        # AnyHttpUrl adds trailing slashes to bare hostnames.
        # See: https://github.com/modelcontextprotocol/python-sdk/issues/1919
        normalized_issuer = issuer.rstrip("/")

        # Check cache first (using normalized issuer as key for consistency)
        cached_entry = _metadata_cache.get(normalized_issuer)
        if cached_entry is not None:
            cache_age = (datetime.now(timezone.utc) - cached_entry["cached_at"]).total_seconds()
            if cache_age < self.settings.dcr_metadata_cache_ttl:
                logger.debug(f"Using cached AS metadata for {normalized_issuer}")
                return cached_entry["metadata"]

        # Try RFC 8414 path first
        # Per RFC 8414 Section 3.1: "the well-known URI is formed by inserting the
        # well-known URI string... between the host component and any existing path
        # component of the issuer's identifier".
        # See: https://datatracker.ietf.org/doc/html/rfc8414#section-3.1
        parsed = urlsplit(normalized_issuer)
        # parsed.path is "" when the issuer has no path, so plain concatenation is safe.
        rfc8414_url = f"{parsed.scheme}://{parsed.netloc}/.well-known/oauth-authorization-server{parsed.path}"

        try:
            client = await self._get_client()
            response = await client.get(rfc8414_url, timeout=self._get_timeout())
            if response.status_code == 200:
                metadata = self._json_object(response)
                return self._validate_and_cache_metadata(metadata, normalized_issuer, "RFC 8414")
        except (httpx.HTTPError, ValueError) as e:
            # ValueError covers a 200 response whose body is not a JSON object.
            logger.debug(f"RFC 8414 discovery failed for {normalized_issuer}: {e}, trying OIDC fallback")

        # Try OIDC discovery fallback
        oidc_url = f"{normalized_issuer}/.well-known/openid-configuration"

        try:
            client = await self._get_client()
            response = await client.get(oidc_url, timeout=self._get_timeout())
            if response.status_code == 200:
                metadata = self._json_object(response)
                return self._validate_and_cache_metadata(metadata, normalized_issuer, "OIDC discovery")

            raise DcrError(f"AS metadata not found for {normalized_issuer} (status: {response.status_code})")
        except (httpx.HTTPError, ValueError) as e:
            raise DcrError(f"Failed to discover AS metadata for {normalized_issuer}: {e}") from e

    async def register_client(self, gateway_id: str, gateway_name: str, issuer: str, redirect_uri: str, scopes: List[str], db: Session) -> RegisteredOAuthClient:
        """Register as OAuth client with upstream AS (RFC 7591).

        Args:
            gateway_id: Gateway ID
            gateway_name: Gateway name
            issuer: AS issuer URL
            redirect_uri: OAuth redirect URI
            scopes: List of OAuth scopes
            db: Database session

        Returns:
            RegisteredOAuthClient record

        Raises:
            DcrError: If registration fails
        """
        # Normalize issuer URL for consistent storage and lookup
        normalized_issuer = issuer.rstrip("/")

        # Validate issuer if allowlist is configured (normalize both for comparison)
        allowed_issuers = self.settings.dcr_allowed_issuers
        if allowed_issuers:
            normalized_allowlist = {allowed.rstrip("/") for allowed in allowed_issuers}
            if normalized_issuer not in normalized_allowlist:
                raise DcrError(f"Issuer {issuer} is not in allowed issuers list")

        # Discover AS metadata
        metadata = await self.discover_as_metadata(normalized_issuer)

        registration_endpoint = metadata.get("registration_endpoint")
        if not registration_endpoint:
            raise DcrError(f"AS {normalized_issuer} does not support Dynamic Client Registration (no registration_endpoint)")

        # Build registration request (RFC 7591)
        client_name = self.settings.dcr_client_name_template.replace("{gateway_name}", gateway_name)

        # Determine grant types based on AS metadata
        # Use `or []` to handle both missing key AND explicit null value (prevents TypeError)
        grant_types_supported = metadata.get("grant_types_supported") or []
        requested_grant_types = ["authorization_code"]

        # Only request refresh_token if AS explicitly supports it, or if permissive mode is enabled
        if "refresh_token" in grant_types_supported:
            requested_grant_types.append("refresh_token")
        elif self.settings.dcr_request_refresh_token_when_unsupported and not grant_types_supported:
            # Permissive mode: request refresh_token when AS doesn't advertise grant_types_supported
            # This is useful for AS servers that support refresh tokens but don't advertise it
            requested_grant_types.append("refresh_token")
            logger.debug(f"Requesting refresh_token for {normalized_issuer} (permissive mode, AS omits grant_types_supported)")

        registration_request = {
            "client_name": client_name,
            "redirect_uris": [redirect_uri],
            "grant_types": requested_grant_types,
            "response_types": ["code"],
            "token_endpoint_auth_method": self.settings.dcr_token_endpoint_auth_method,
            "scope": " ".join(scopes),
        }

        # Send registration request
        try:
            client = await self._get_client()
            response = await client.post(registration_endpoint, json=registration_request, timeout=self._get_timeout())
            # Accept both 200 OK and 201 Created (some servers don't follow RFC 7591 strictly)
            if response.status_code in (200, 201):
                registration_response = self._json_object(response)
            else:
                try:
                    error_data = self._json_object(response)
                except ValueError:
                    # Non-JSON error body (e.g. an HTML error page): report the status instead.
                    error_data = {"error_description": f"HTTP {response.status_code}"}
                error_msg = error_data.get("error", "unknown_error")
                error_desc = error_data.get("error_description", str(error_data))
                raise DcrError(f"Client registration failed: {error_msg} - {error_desc}")
        except httpx.HTTPError as e:
            raise DcrError(f"Failed to register client with {normalized_issuer}: {e}") from e
        except ValueError as e:
            # Success status, but the body could not be decoded as a JSON object.
            raise DcrError(f"Client registration with {normalized_issuer} returned an invalid response: {e}") from e

        # client_id is the one field the record cannot be built without.
        client_id = registration_response.get("client_id")
        if not client_id:
            raise DcrError(f"Client registration with {normalized_issuer} returned no client_id")

        # Encrypt secrets
        encryption = get_encryption_service(self.settings.auth_encryption_secret)

        client_secret = registration_response.get("client_secret")
        client_secret_encrypted = await encryption.encrypt_secret_async(client_secret) if client_secret else None

        registration_access_token = registration_response.get("registration_access_token")
        registration_access_token_encrypted = await encryption.encrypt_secret_async(registration_access_token) if registration_access_token else None

        # Calculate expires at. Missing, zero, negative, or non-numeric values
        # leave expires_at unset rather than raising on the comparison.
        expires_at = None
        client_secret_expires_at = registration_response.get("client_secret_expires_at")
        is_numeric_expiry = isinstance(client_secret_expires_at, (int, float)) and not isinstance(client_secret_expires_at, bool)
        if is_numeric_expiry and client_secret_expires_at > 0:
            expires_at = datetime.fromtimestamp(client_secret_expires_at, tz=timezone.utc)

        # Create database record (use normalized issuer for consistent lookup)
        # Fall back to requested grant_types if AS response omits them
        registered_client = RegisteredOAuthClient(
            gateway_id=gateway_id,
            issuer=normalized_issuer,
            client_id=client_id,
            client_secret_encrypted=client_secret_encrypted,
            redirect_uris=orjson.dumps(registration_response.get("redirect_uris", [redirect_uri])).decode(),
            grant_types=orjson.dumps(registration_response.get("grant_types", requested_grant_types)).decode(),
            response_types=orjson.dumps(registration_response.get("response_types", ["code"])).decode(),
            scope=registration_response.get("scope", " ".join(scopes)),
            token_endpoint_auth_method=registration_response.get("token_endpoint_auth_method", self.settings.dcr_token_endpoint_auth_method),
            registration_client_uri=registration_response.get("registration_client_uri"),
            registration_access_token_encrypted=registration_access_token_encrypted,
            created_at=datetime.now(timezone.utc),
            expires_at=expires_at,
            is_active=True,
        )

        db.add(registered_client)
        db.commit()
        db.refresh(registered_client)

        logger.info(f"Successfully registered client {registered_client.client_id} with {normalized_issuer} for gateway {gateway_id}")

        return registered_client

    async def get_or_register_client(self, gateway_id: str, gateway_name: str, issuer: str, redirect_uri: str, scopes: List[str], db: Session) -> RegisteredOAuthClient:
        """Get existing registered client or register new one.

        Args:
            gateway_id: Gateway ID
            gateway_name: Gateway name
            issuer: AS issuer URL
            redirect_uri: OAuth redirect URI
            scopes: List of OAuth scopes
            db: Database session

        Returns:
            RegisteredOAuthClient record

        Raises:
            DcrError: If client not found and auto-register is disabled
        """
        # Normalize issuer for consistent lookup (matches how register_client stores it)
        normalized_issuer = issuer.rstrip("/")

        # Try to find existing client using normalized issuer
        existing_client = (
            db.query(RegisteredOAuthClient)
            .filter(
                RegisteredOAuthClient.gateway_id == gateway_id, RegisteredOAuthClient.issuer == normalized_issuer, RegisteredOAuthClient.is_active.is_(True)
            )  # pylint: disable=singleton-comparison
            .first()
        )

        if existing_client:
            logger.debug(f"Found existing registered client for gateway {gateway_id} and issuer {normalized_issuer}")
            return existing_client

        # No existing client, check if auto-register is enabled
        if not self.settings.dcr_auto_register_on_missing_credentials:
            raise DcrError(
                f"No registered client found for gateway {gateway_id} and issuer {normalized_issuer}. Auto-register is disabled. Set MCPGATEWAY_DCR_AUTO_REGISTER_ON_MISSING_CREDENTIALS=true to enable."
            )

        # Auto-register (pass normalized issuer for consistent storage)
        logger.info(f"No existing client found for gateway {gateway_id}, registering new client with {normalized_issuer}")
        return await self.register_client(gateway_id, gateway_name, normalized_issuer, redirect_uri, scopes, db)

    async def update_client_registration(self, client_record: RegisteredOAuthClient, db: Session) -> RegisteredOAuthClient:
        """Update existing client registration (RFC 7591 section 4.2).

        Args:
            client_record: Existing RegisteredOAuthClient record
            db: Database session

        Returns:
            Updated RegisteredOAuthClient record

        Raises:
            DcrError: If update fails
        """
        if not client_record.registration_client_uri:
            raise DcrError("Cannot update client: no registration_client_uri available")

        if not client_record.registration_access_token_encrypted:
            raise DcrError("Cannot update client: no registration_access_token available")

        # Decrypt registration access token
        encryption = get_encryption_service(self.settings.auth_encryption_secret)
        registration_access_token = await encryption.decrypt_secret_async(client_record.registration_access_token_encrypted)
        if registration_access_token is None:
            raise DcrError("Failed to decrypt registration access token for update operation")

        # Build update request
        update_request = {"client_id": client_record.client_id, "redirect_uris": orjson.loads(client_record.redirect_uris), "grant_types": orjson.loads(client_record.grant_types)}

        # Send update request
        try:
            client = await self._get_client()
            headers = {"Authorization": f"Bearer {registration_access_token}"}
            response = await client.put(client_record.registration_client_uri, json=update_request, headers=headers, timeout=self._get_timeout())
            if response.status_code == 200:
                try:
                    updated_response = self._json_object(response)
                except ValueError as e:
                    raise DcrError(f"Failed to update client registration: invalid response from AS: {e}") from e

                # Update encrypted secret if the AS returned a (new) one
                new_client_secret = updated_response.get("client_secret")
                if new_client_secret:
                    client_record.client_secret_encrypted = await encryption.encrypt_secret_async(new_client_secret)

                db.commit()
                db.refresh(client_record)

                logger.info(f"Successfully updated client registration for {client_record.client_id}")
                return client_record

            try:
                error_data = response.json()
            except ValueError:
                # Non-JSON error body: fall back to the status code.
                error_data = f"HTTP {response.status_code}"
            raise DcrError(f"Failed to update client: {error_data}")
        except httpx.HTTPError as e:
            raise DcrError(f"Failed to update client registration: {e}") from e

    async def delete_client_registration(self, client_record: RegisteredOAuthClient, db: Session) -> bool:  # pylint: disable=unused-argument
        """Delete/revoke client registration (RFC 7591 section 4.3).

        Every failure mode handled here (missing prerequisites, decryption
        failure, unexpected status, HTTP error) is logged and reported through
        the return value; this method does not raise DcrError itself.

        Args:
            client_record: RegisteredOAuthClient record to delete
            db: Database session (unused here)

        Returns:
            bool: True if deletion succeeded at the Authorization Server.
                False if deletion failed (missing prerequisites, decryption error, network error).
                Note: Does not guarantee local database deletion.
        """
        if not client_record.registration_client_uri:
            logger.warning("Cannot delete client at AS: no registration_client_uri")
            return False

        if not client_record.registration_access_token_encrypted:
            logger.warning("Cannot delete client at AS: no registration_access_token")
            return False

        # Decrypt registration access token
        encryption = get_encryption_service(self.settings.auth_encryption_secret)
        registration_access_token = await encryption.decrypt_secret_async(client_record.registration_access_token_encrypted)
        if registration_access_token is None:
            logger.error("Failed to decrypt registration access token; cannot authenticate delete request to AS")
            return False

        # Send delete request
        try:
            client = await self._get_client()
            headers = {"Authorization": f"Bearer {registration_access_token}"}
            response = await client.delete(client_record.registration_client_uri, headers=headers, timeout=self._get_timeout())
        except httpx.HTTPError as e:
            logger.error(f"Failed to delete client at AS: {e}")
            return False

        if response.status_code in (204, 404):  # 204 = deleted, 404 = already gone
            logger.info(f"Successfully deleted client registration for {client_record.client_id}")
            return True

        logger.warning(f"Unexpected status when deleting client: {response.status_code}")
        return False

403 

404 

class DcrError(Exception):
    """Error raised when a Dynamic Client Registration operation fails.

    Raised by DcrService for failures in AS metadata discovery, client
    registration, and client registration updates.
    """