Coverage for mcpgateway / services / dcr_service.py: 100%

158 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/services/dcr_service.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Manav Gupta 

6 

7OAuth 2.0 Dynamic Client Registration Service. 

8 

9This module handles OAuth 2.0 Dynamic Client Registration (DCR) including: 

10- AS metadata discovery (RFC 8414) 

11- Client registration (RFC 7591) 

12- Client management (update, delete) 

13""" 

14 

15# Standard 

16from datetime import datetime, timezone 

17import logging 

18from typing import Any, Dict, List 

19 

20# Third-Party 

21import httpx 

22import orjson 

23from sqlalchemy.orm import Session 

24 

25# First-Party 

26from mcpgateway.config import get_settings 

27from mcpgateway.db import RegisteredOAuthClient 

28from mcpgateway.services.encryption_service import get_encryption_service 

29from mcpgateway.services.http_client_service import get_http_client 

30 

logger = logging.getLogger(__name__)

# Module-level cache of discovered authorization-server metadata, keyed by normalized issuer URL.
# Each value is shaped {"metadata": <metadata dict>, "cached_at": <timezone-aware datetime>}.
_metadata_cache: Dict[str, Dict[str, Any]] = dict()

36 

37 

class DcrService:
    """Service for OAuth 2.0 Dynamic Client Registration (RFC 7591 client).

    Handles authorization server metadata discovery (RFC 8414 with an OIDC
    Discovery fallback), client registration (RFC 7591) and management of an
    existing registration through its registration client URI (RFC 7592).
    """

    def __init__(self):
        """Initialize DCR service."""
        self.settings = get_settings()

    async def _get_client(self) -> httpx.AsyncClient:
        """Get the shared singleton HTTP client.

        Returns:
            Shared httpx.AsyncClient instance with connection pooling
        """
        return await get_http_client()

    def _get_timeout(self) -> float:
        """Get the OAuth request timeout from settings.

        Returns:
            Timeout in seconds for OAuth/DCR requests
        """
        return float(self.settings.oauth_request_timeout)

    @staticmethod
    def _metadata_urls(normalized_issuer: str) -> List[tuple[str, str]]:
        """Build the ordered list of metadata discovery locations for an issuer.

        Args:
            normalized_issuer: Issuer URL with any trailing slash already removed

        Returns:
            List of (url, label) pairs to try in order; the label names the discovery mechanism for logging
        """
        from urllib.parse import urlsplit, urlunsplit  # pylint: disable=import-outside-toplevel

        candidates: List[tuple[str, str]] = []

        # RFC 8414 Section 3.1: for an issuer with a path component the well-known suffix is
        # inserted between the host and the path, e.g.
        #   https://as.example.com/tenant1 -> https://as.example.com/.well-known/oauth-authorization-server/tenant1
        # For a bare-host issuer this equals the appended form below, so it is only added when a path exists.
        try:
            parts = urlsplit(normalized_issuer)
        except ValueError:
            parts = None
        if parts is not None and parts.scheme and parts.netloc and parts.path:
            inserted_url = urlunsplit((parts.scheme, parts.netloc, f"/.well-known/oauth-authorization-server{parts.path}", "", ""))
            candidates.append((inserted_url, "RFC 8414"))

        # Path-appended RFC 8414 location; many servers publish metadata here even for path-based issuers.
        candidates.append((f"{normalized_issuer}/.well-known/oauth-authorization-server", "RFC 8414"))
        # OIDC Discovery fallback, which appends the well-known suffix to the issuer.
        candidates.append((f"{normalized_issuer}/.well-known/openid-configuration", "OIDC discovery"))
        return candidates

    @staticmethod
    def _decode_error_body(response: httpx.Response) -> Any:
        """Decode a non-success response body for inclusion in a DcrError message.

        Args:
            response: HTTP response returned by the AS

        Returns:
            The parsed JSON body, or the (truncated) raw text - or a status placeholder when the body is empty - if the body is not valid JSON
        """
        try:
            return response.json()
        except ValueError:
            return (response.text or "")[:500] or f"HTTP {response.status_code}"

    @staticmethod
    def _secret_expiry(value: Any) -> "datetime | None":
        """Convert a ``client_secret_expires_at`` value into a timezone-aware UTC datetime.

        Args:
            value: Value returned by the AS; per RFC 7591 seconds since the epoch, with 0 meaning "does not expire"

        Returns:
            Expiry as an aware UTC datetime, or None when the value is missing, non-positive, non-numeric or out of range
        """
        # bool is a subclass of int; a JSON true/false is not a timestamp.
        if isinstance(value, bool) or not isinstance(value, (int, float)) or value <= 0:
            return None
        try:
            return datetime.fromtimestamp(value, tz=timezone.utc)
        except (OverflowError, OSError, ValueError):
            logger.warning(f"Ignoring out-of-range client_secret_expires_at value: {value}")
            return None

    @staticmethod
    def _commit(db: Session) -> None:
        """Commit the session, rolling back before re-raising so the session remains usable.

        Args:
            db: Database session

        Raises:
            Exception: Whatever ``db.commit()`` raised, after the session has been rolled back
        """
        try:
            db.commit()
        except Exception:
            db.rollback()
            raise

    async def discover_as_metadata(self, issuer: str) -> Dict[str, Any]:
        """Discover AS metadata via RFC 8414.

        Tries, in order:
        1. {host}/.well-known/oauth-authorization-server{path} (RFC 8414 Section 3.1; only when the issuer has a path)
        2. {issuer}/.well-known/oauth-authorization-server (path-appended RFC 8414 form)
        3. {issuer}/.well-known/openid-configuration (OIDC fallback)

        A transport error, a non-200 status or a body that is not a JSON object moves on to the next
        location; the failure of the last location is reported. Successful results are cached per
        normalized issuer for ``dcr_metadata_cache_ttl`` seconds.

        Args:
            issuer: The AS issuer URL

        Returns:
            Dict containing AS metadata

        Raises:
            DcrError: If metadata cannot be discovered, or a location returns metadata for a different issuer
        """
        # Normalize issuer URL by removing trailing slash for consistency.
        # Per RFC 8414 Section 3.1, any terminating "/" MUST be removed before
        # inserting "/.well-known/" and the well-known URI suffix.
        # This also works around MCP Python SDK issue #1919 where Pydantic's
        # AnyHttpUrl adds trailing slashes to bare hostnames.
        # See: https://github.com/modelcontextprotocol/python-sdk/issues/1919
        normalized_issuer = issuer.rstrip("/")

        # Check cache first (using normalized issuer as key for consistency)
        cached_entry = _metadata_cache.get(normalized_issuer)
        if cached_entry is not None:
            cache_age = (datetime.now(timezone.utc) - cached_entry["cached_at"]).total_seconds()
            if cache_age < self.settings.dcr_metadata_cache_ttl:
                logger.debug(f"Using cached AS metadata for {normalized_issuer}")
                return cached_entry["metadata"]

        candidates = self._metadata_urls(normalized_issuer)
        last_index = len(candidates) - 1

        for index, (url, source) in enumerate(candidates):
            is_last = index == last_index

            try:
                client = await self._get_client()
                response = await client.get(url, timeout=self._get_timeout())
            except httpx.HTTPError as e:
                if is_last:
                    raise DcrError(f"Failed to discover AS metadata for {normalized_issuer}: {e}") from e
                logger.debug(f"{source} discovery failed for {normalized_issuer} at {url}: {e}, trying next location")
                continue

            if response.status_code != 200:
                if is_last:
                    raise DcrError(f"AS metadata not found for {normalized_issuer} (status: {response.status_code})")
                continue

            # A 200 with an unparsable or non-object body (e.g. an HTML catch-all page) is treated as "not here".
            try:
                metadata = response.json()
            except ValueError:
                metadata = None
            if not isinstance(metadata, dict):
                if is_last:
                    raise DcrError(f"Invalid AS metadata for {normalized_issuer} at {url}: response body is not a JSON object")
                logger.debug(f"{source} discovery for {normalized_issuer} at {url} returned a non-JSON-object body, trying next location")
                continue

            # RFC 8414 Section 3.3: metadata whose issuer differs from the requested one MUST NOT be used.
            # Trailing slashes are normalized on both sides for the comparison.
            metadata_issuer = str(metadata.get("issuer") or "").rstrip("/")
            if metadata_issuer != normalized_issuer:
                raise DcrError(f"AS metadata issuer mismatch: expected {normalized_issuer}, got {metadata.get('issuer')}")

            _metadata_cache[normalized_issuer] = {"metadata": metadata, "cached_at": datetime.now(timezone.utc)}
            logger.info(f"Discovered AS metadata for {normalized_issuer} via {source}")
            return metadata

        # Unreachable: the candidate list always ends with the OIDC location, which returns or raises above.
        raise DcrError(f"Failed to discover AS metadata for {normalized_issuer}")

    async def register_client(self, gateway_id: str, gateway_name: str, issuer: str, redirect_uri: str, scopes: List[str], db: Session) -> RegisteredOAuthClient:
        """Register as OAuth client with upstream AS (RFC 7591).

        Discovers the AS registration endpoint, posts a registration request, encrypts any returned
        client secret / registration access token and persists a RegisteredOAuthClient record.

        Args:
            gateway_id: Gateway ID
            gateway_name: Gateway name
            issuer: AS issuer URL
            redirect_uri: OAuth redirect URI
            scopes: List of OAuth scopes (None is treated as no scopes)
            db: Database session

        Returns:
            RegisteredOAuthClient record

        Raises:
            DcrError: If the issuer is not allowed, discovery fails, the AS lacks a registration endpoint,
                or the registration request fails or returns an unusable response
        """
        # Normalize issuer URL for consistent storage and lookup
        normalized_issuer = issuer.rstrip("/")

        # Validate issuer if allowlist is configured (normalize both for comparison)
        if self.settings.dcr_allowed_issuers:
            normalized_allowlist = [i.rstrip("/") for i in self.settings.dcr_allowed_issuers]
            if normalized_issuer not in normalized_allowlist:
                raise DcrError(f"Issuer {issuer} is not in allowed issuers list")

        metadata = await self.discover_as_metadata(normalized_issuer)

        registration_endpoint = metadata.get("registration_endpoint")
        if not registration_endpoint:
            raise DcrError(f"AS {normalized_issuer} does not support Dynamic Client Registration (no registration_endpoint)")

        # Build registration request (RFC 7591)
        client_name = self.settings.dcr_client_name_template.replace("{gateway_name}", gateway_name)
        requested_scope = " ".join(scopes or [])

        # Use `or []` to handle both missing key AND explicit null value (prevents TypeError)
        grant_types_supported = metadata.get("grant_types_supported") or []
        requested_grant_types = ["authorization_code"]

        # Only request refresh_token if AS explicitly supports it, or if permissive mode is enabled
        if "refresh_token" in grant_types_supported:
            requested_grant_types.append("refresh_token")
        elif self.settings.dcr_request_refresh_token_when_unsupported and not grant_types_supported:
            # Permissive mode: request refresh_token when AS doesn't advertise grant_types_supported
            # This is useful for AS servers that support refresh tokens but don't advertise it
            requested_grant_types.append("refresh_token")
            logger.debug(f"Requesting refresh_token for {normalized_issuer} (permissive mode, AS omits grant_types_supported)")

        registration_request = {
            "client_name": client_name,
            "redirect_uris": [redirect_uri],
            "grant_types": requested_grant_types,
            "response_types": ["code"],
            "token_endpoint_auth_method": self.settings.dcr_token_endpoint_auth_method,
            "scope": requested_scope,
        }

        try:
            client = await self._get_client()
            response = await client.post(registration_endpoint, json=registration_request, timeout=self._get_timeout())
        except httpx.HTTPError as e:
            raise DcrError(f"Failed to register client with {normalized_issuer}: {e}") from e

        # Accept both 200 OK and 201 Created (some servers don't follow RFC 7591 strictly)
        if response.status_code not in (200, 201):
            error_data = self._decode_error_body(response)
            if isinstance(error_data, dict):
                error_msg = error_data.get("error", "unknown_error")
                error_desc = error_data.get("error_description", str(error_data))
            else:
                error_msg = "unknown_error"
                error_desc = str(error_data)
            raise DcrError(f"Client registration failed: {error_msg} - {error_desc}")

        try:
            registration_response = response.json()
        except ValueError as e:
            raise DcrError(f"Client registration with {normalized_issuer} returned an invalid JSON response: {e}") from e
        if not isinstance(registration_response, dict) or not registration_response.get("client_id"):
            raise DcrError(f"Client registration with {normalized_issuer} returned no client_id")

        # Encrypt secrets
        encryption = get_encryption_service(self.settings.auth_encryption_secret)

        client_secret = registration_response.get("client_secret")
        client_secret_encrypted = await encryption.encrypt_secret_async(client_secret) if client_secret else None

        registration_access_token = registration_response.get("registration_access_token")
        registration_access_token_encrypted = await encryption.encrypt_secret_async(registration_access_token) if registration_access_token else None

        expires_at = self._secret_expiry(registration_response.get("client_secret_expires_at"))

        # Create database record (use normalized issuer for consistent lookup).
        # Fall back to the requested values when the AS omits a field or returns it as null/empty.
        registered_client = RegisteredOAuthClient(
            gateway_id=gateway_id,
            issuer=normalized_issuer,
            client_id=registration_response["client_id"],
            client_secret_encrypted=client_secret_encrypted,
            redirect_uris=orjson.dumps(registration_response.get("redirect_uris") or [redirect_uri]).decode(),
            grant_types=orjson.dumps(registration_response.get("grant_types") or requested_grant_types).decode(),
            response_types=orjson.dumps(registration_response.get("response_types") or ["code"]).decode(),
            scope=registration_response.get("scope") or requested_scope,
            token_endpoint_auth_method=registration_response.get("token_endpoint_auth_method") or self.settings.dcr_token_endpoint_auth_method,
            registration_client_uri=registration_response.get("registration_client_uri"),
            registration_access_token_encrypted=registration_access_token_encrypted,
            created_at=datetime.now(timezone.utc),
            expires_at=expires_at,
            is_active=True,
        )

        db.add(registered_client)
        self._commit(db)
        db.refresh(registered_client)

        logger.info(f"Successfully registered client {registered_client.client_id} with {normalized_issuer} for gateway {gateway_id}")

        return registered_client

    async def get_or_register_client(self, gateway_id: str, gateway_name: str, issuer: str, redirect_uri: str, scopes: List[str], db: Session) -> RegisteredOAuthClient:
        """Get existing registered client or register new one.

        Args:
            gateway_id: Gateway ID
            gateway_name: Gateway name
            issuer: AS issuer URL
            redirect_uri: OAuth redirect URI
            scopes: List of OAuth scopes
            db: Database session

        Returns:
            RegisteredOAuthClient record

        Raises:
            DcrError: If client not found and auto-register is disabled, or if auto-registration fails
        """
        # Normalize issuer for consistent lookup (matches how register_client stores it)
        normalized_issuer = issuer.rstrip("/")

        existing_client = (
            db.query(RegisteredOAuthClient)
            .filter(
                RegisteredOAuthClient.gateway_id == gateway_id,
                RegisteredOAuthClient.issuer == normalized_issuer,
                RegisteredOAuthClient.is_active.is_(True),
            )
            .first()
        )

        if existing_client:
            logger.debug(f"Found existing registered client for gateway {gateway_id} and issuer {normalized_issuer}")
            return existing_client

        if not self.settings.dcr_auto_register_on_missing_credentials:
            raise DcrError(
                f"No registered client found for gateway {gateway_id} and issuer {normalized_issuer}. Auto-register is disabled. Set MCPGATEWAY_DCR_AUTO_REGISTER_ON_MISSING_CREDENTIALS=true to enable."
            )

        # Auto-register (pass normalized issuer for consistent storage)
        logger.info(f"No existing client found for gateway {gateway_id}, registering new client with {normalized_issuer}")
        return await self.register_client(gateway_id, gateway_name, normalized_issuer, redirect_uri, scopes, db)

    async def update_client_registration(self, client_record: RegisteredOAuthClient, db: Session) -> RegisteredOAuthClient:
        """Update existing client registration (RFC 7592 Section 2.2).

        PUTs the stored client metadata to the registration client URI. RFC 7592 treats omitted
        metadata fields as a request to delete them, so every stored field is sent. A new client
        secret or registration access token returned by the AS replaces the stored one.

        Args:
            client_record: Existing RegisteredOAuthClient record
            db: Database session

        Returns:
            Updated RegisteredOAuthClient record

        Raises:
            DcrError: If the record lacks management credentials, the request fails, or the AS rejects or garbles the update
        """
        if not client_record.registration_client_uri:
            raise DcrError("Cannot update client: no registration_client_uri available")

        if not client_record.registration_access_token_encrypted:
            raise DcrError("Cannot update client: no registration_access_token available")

        encryption = get_encryption_service(self.settings.auth_encryption_secret)
        registration_access_token = await encryption.decrypt_secret_async(client_record.registration_access_token_encrypted)

        # Include all stored metadata; fields left out of an RFC 7592 update are deleted by the AS.
        update_request: Dict[str, Any] = {
            "client_id": client_record.client_id,
            "redirect_uris": orjson.loads(client_record.redirect_uris),
            "grant_types": orjson.loads(client_record.grant_types),
        }
        if client_record.response_types:
            update_request["response_types"] = orjson.loads(client_record.response_types)
        if client_record.scope:
            update_request["scope"] = client_record.scope
        if client_record.token_endpoint_auth_method:
            update_request["token_endpoint_auth_method"] = client_record.token_endpoint_auth_method

        try:
            client = await self._get_client()
            headers = {"Authorization": f"Bearer {registration_access_token}"}
            response = await client.put(client_record.registration_client_uri, json=update_request, headers=headers, timeout=self._get_timeout())
        except httpx.HTTPError as e:
            raise DcrError(f"Failed to update client registration: {e}") from e

        if response.status_code != 200:
            raise DcrError(f"Failed to update client: {self._decode_error_body(response)}")

        try:
            updated_response = response.json()
        except ValueError as e:
            raise DcrError(f"Failed to update client: response is not valid JSON: {e}") from e
        if not isinstance(updated_response, dict):
            raise DcrError(f"Failed to update client: unexpected response body {updated_response!r}")

        # Replace the stored secret only when the AS actually issued a new one (a null value is ignored).
        new_secret = updated_response.get("client_secret")
        if new_secret:
            client_record.client_secret_encrypted = await encryption.encrypt_secret_async(new_secret)
            client_record.expires_at = self._secret_expiry(updated_response.get("client_secret_expires_at"))

        # The AS may rotate the registration access token; keeping the old one would break later updates/deletes.
        new_access_token = updated_response.get("registration_access_token")
        if new_access_token:
            client_record.registration_access_token_encrypted = await encryption.encrypt_secret_async(new_access_token)

        new_client_uri = updated_response.get("registration_client_uri")
        if new_client_uri:
            client_record.registration_client_uri = new_client_uri

        self._commit(db)
        db.refresh(client_record)

        logger.info(f"Successfully updated client registration for {client_record.client_id}")
        return client_record

    async def delete_client_registration(self, client_record: RegisteredOAuthClient, db: Session) -> bool:  # pylint: disable=unused-argument
        """Delete/revoke client registration (RFC 7592 Section 2.3).

        Best-effort: missing management credentials, transport errors and unexpected statuses are
        logged and still reported as success so the caller can remove its local record.

        Args:
            client_record: RegisteredOAuthClient record to delete
            db: Database session (unused; kept for interface symmetry)

        Returns:
            True in every non-exceptional case (see above)
        """
        if not client_record.registration_client_uri:
            logger.warning("Cannot delete client at AS: no registration_client_uri")
            return True  # Consider it deleted locally

        if not client_record.registration_access_token_encrypted:
            logger.warning("Cannot delete client at AS: no registration_access_token")
            return True  # Consider it deleted locally

        encryption = get_encryption_service(self.settings.auth_encryption_secret)
        registration_access_token = await encryption.decrypt_secret_async(client_record.registration_access_token_encrypted)

        try:
            client = await self._get_client()
            headers = {"Authorization": f"Bearer {registration_access_token}"}
            response = await client.delete(client_record.registration_client_uri, headers=headers, timeout=self._get_timeout())
        except httpx.HTTPError as e:
            logger.warning(f"Failed to delete client at AS: {e}")
            return True  # Best-effort, don't fail if AS is unreachable

        if response.status_code in [204, 404]:  # 204 = deleted, 404 = already gone
            logger.info(f"Successfully deleted client registration for {client_record.client_id}")
            return True

        logger.warning(f"Unexpected status when deleting client: {response.status_code}")
        return True  # Consider it best-effort

388 

389 

class DcrError(Exception):
    """Raised by DcrService when OAuth 2.0 Dynamic Client Registration work fails."""