Coverage for mcpgateway / services / dcr_service.py: 93%
167 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/dcr_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Manav Gupta
7OAuth 2.0 Dynamic Client Registration Service.
9This module handles OAuth 2.0 Dynamic Client Registration (DCR) including:
10- AS metadata discovery (RFC 8414)
11- Client registration (RFC 7591)
12- Client management (update, delete)
13"""
15# Standard
16from datetime import datetime, timezone
17import logging
18from typing import Any, Dict, List
19from urllib.parse import urlsplit
21# Third-Party
22import httpx
23import orjson
24from sqlalchemy.orm import Session
26# First-Party
27from mcpgateway.config import get_settings
28from mcpgateway.db import RegisteredOAuthClient
29from mcpgateway.services.encryption_service import get_encryption_service
30from mcpgateway.services.http_client_service import get_http_client
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Process-local cache of discovered authorization-server metadata, keyed by the
# issuer URL with any trailing "/" removed.
# Each value has the shape {"metadata": dict, "cached_at": datetime}.
_metadata_cache: Dict[str, Dict[str, Any]] = dict()
class DcrService:
    """Service for OAuth 2.0 Dynamic Client Registration (client side).

    Implements:

    - RFC 8414 authorization server metadata discovery, with an OIDC discovery fallback,
    - RFC 7591 dynamic client registration,
    - RFC 7592 client configuration management (update, delete).

    Discovered metadata is cached in the module-level ``_metadata_cache`` for
    ``settings.dcr_metadata_cache_ttl`` seconds.
    """

    def __init__(self):
        """Initialize DCR service with the current application settings."""
        self.settings = get_settings()

    async def _get_client(self) -> httpx.AsyncClient:
        """Get the shared singleton HTTP client.

        Returns:
            Shared httpx.AsyncClient instance with connection pooling
        """
        return await get_http_client()

    def _get_timeout(self) -> float:
        """Get the OAuth request timeout from settings.

        Returns:
            Timeout in seconds for OAuth/DCR requests
        """
        return float(self.settings.oauth_request_timeout)

    @staticmethod
    def _decode_json_object(response: httpx.Response) -> Dict[str, Any]:
        """Decode a response body that is required to be a JSON object.

        Args:
            response: HTTP response to decode.

        Returns:
            The decoded JSON object.

        Raises:
            ValueError: If the body is not valid JSON (``response.json()`` raises a
                ``json.JSONDecodeError``, a ``ValueError`` subclass) or is valid JSON
                but not an object.
        """
        data = response.json()
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")
        return data

    @staticmethod
    def _error_payload(response: httpx.Response) -> Dict[str, Any]:
        """Best-effort decode of an error response body.

        Error responses may come back as HTML or plain text (e.g. from proxies or
        5xx pages). Rather than letting a JSON decode error escape where callers
        expect DcrError, fall back to a synthetic payload that carries the HTTP
        status and a truncated text body.

        Args:
            response: Non-success HTTP response.

        Returns:
            The JSON error object if the body is one, otherwise
            ``{"error_description": "HTTP <status>: <first 500 chars of body>"}``.
        """
        try:
            return DcrService._decode_json_object(response)
        except ValueError:
            return {"error_description": f"HTTP {response.status_code}: {response.text[:500]}"}

    @staticmethod
    def _secret_expiry(client_secret_expires_at: Any) -> "datetime | None":
        """Convert an RFC 7591 ``client_secret_expires_at`` value into a datetime.

        Args:
            client_secret_expires_at: Seconds since the epoch, or 0 when the secret
                does not expire (per RFC 7591). May be absent/None.

        Returns:
            Timezone-aware UTC expiry time, or None when the value is not a positive number.
        """
        if isinstance(client_secret_expires_at, (int, float)) and client_secret_expires_at > 0:
            return datetime.fromtimestamp(client_secret_expires_at, tz=timezone.utc)
        return None

    @staticmethod
    def _accept_metadata(normalized_issuer: str, metadata: Dict[str, Any], source: str) -> Dict[str, Any]:
        """Validate discovered metadata against the expected issuer, then cache it.

        Args:
            normalized_issuer: Expected issuer URL with trailing slashes removed.
            metadata: Decoded AS metadata document.
            source: Discovery mechanism name, used only in the log message.

        Returns:
            The validated metadata.

        Raises:
            DcrError: If the metadata ``issuer`` does not match ``normalized_issuer``.
        """
        raw_issuer = metadata.get("issuer")
        # Normalize the advertised issuer the same way as the expected one; a missing
        # or non-string value can never match.
        metadata_issuer = raw_issuer.rstrip("/") if isinstance(raw_issuer, str) else ""
        if metadata_issuer != normalized_issuer:
            raise DcrError(f"AS metadata issuer mismatch: expected {normalized_issuer}, got {raw_issuer}")

        _metadata_cache[normalized_issuer] = {"metadata": metadata, "cached_at": datetime.now(timezone.utc)}
        logger.info(f"Discovered AS metadata for {normalized_issuer} via {source}")
        return metadata

    async def discover_as_metadata(self, issuer: str) -> Dict[str, Any]:
        """Discover AS metadata via RFC 8414.

        Tries:
        1. RFC 8414: /.well-known/oauth-authorization-server inserted between host and path
        2. OIDC fallback: {issuer}/.well-known/openid-configuration

        A network error, a non-200 status, or a body that is not a JSON object at the
        RFC 8414 URL falls through to the OIDC URL. The same failures at the OIDC URL
        raise DcrError. A 200 response whose ``issuer`` does not match raises DcrError
        immediately, without trying the fallback.

        Args:
            issuer: The AS issuer URL

        Returns:
            Dict containing AS metadata

        Raises:
            DcrError: If metadata cannot be discovered or fails issuer validation
        """
        # Normalize issuer URL by removing trailing slash for consistency.
        # Per RFC 8414 Section 3.1, any terminating "/" MUST be removed before
        # inserting "/.well-known/" and the well-known URI suffix.
        # This also works around MCP Python SDK issue #1919 where Pydantic's
        # AnyHttpUrl adds trailing slashes to bare hostnames.
        # See: https://github.com/modelcontextprotocol/python-sdk/issues/1919
        normalized_issuer = issuer.rstrip("/")

        # Check cache first (using normalized issuer as key for consistency)
        cached_entry = _metadata_cache.get(normalized_issuer)
        if cached_entry is not None:
            cache_age = (datetime.now(timezone.utc) - cached_entry["cached_at"]).total_seconds()
            if cache_age < self.settings.dcr_metadata_cache_ttl:
                logger.debug(f"Using cached AS metadata for {normalized_issuer}")
                return cached_entry["metadata"]

        # Per RFC 8414 Section 3.1: "the well-known URI is formed by inserting the
        # well-known URI string... between the host component and any existing path
        # component of the issuer's identifier".
        # See: https://datatracker.ietf.org/doc/html/rfc8414#section-3.1
        parsed = urlsplit(normalized_issuer)
        rfc8414_url = f"{parsed.scheme}://{parsed.netloc}/.well-known/oauth-authorization-server{parsed.path}"

        try:
            client = await self._get_client()
            response = await client.get(rfc8414_url, timeout=self._get_timeout())
        except httpx.HTTPError as e:
            logger.debug(f"RFC 8414 discovery failed for {normalized_issuer}: {e}, trying OIDC fallback")
        else:
            if response.status_code == 200:
                try:
                    metadata = self._decode_json_object(response)
                except ValueError as e:
                    # A host may answer unknown paths with a 200 HTML page; treat an
                    # undecodable body as "not served here" and try the OIDC URL.
                    logger.debug(f"RFC 8414 metadata for {normalized_issuer} is not a JSON object: {e}, trying OIDC fallback")
                else:
                    return self._accept_metadata(normalized_issuer, metadata, "RFC 8414")

        # OIDC discovery fallback
        oidc_url = f"{normalized_issuer}/.well-known/openid-configuration"
        try:
            client = await self._get_client()
            response = await client.get(oidc_url, timeout=self._get_timeout())
        except httpx.HTTPError as e:
            raise DcrError(f"Failed to discover AS metadata for {normalized_issuer}: {e}") from e

        if response.status_code != 200:
            raise DcrError(f"AS metadata not found for {normalized_issuer} (status: {response.status_code})")

        try:
            metadata = self._decode_json_object(response)
        except ValueError as e:
            raise DcrError(f"Invalid AS metadata for {normalized_issuer}: {e}") from e

        return self._accept_metadata(normalized_issuer, metadata, "OIDC discovery")

    async def register_client(self, gateway_id: str, gateway_name: str, issuer: str, redirect_uri: str, scopes: List[str], db: Session) -> RegisteredOAuthClient:
        """Register as OAuth client with upstream AS (RFC 7591).

        Args:
            gateway_id: Gateway ID
            gateway_name: Gateway name
            issuer: AS issuer URL
            redirect_uri: OAuth redirect URI
            scopes: List of OAuth scopes
            db: Database session

        Returns:
            RegisteredOAuthClient record (committed and refreshed)

        Raises:
            DcrError: If the issuer is not allowed, discovery fails, the AS has no
                registration endpoint, the HTTP request fails, or the AS returns an
                error or a malformed registration response
        """
        # Normalize issuer URL for consistent storage and lookup
        normalized_issuer = issuer.rstrip("/")

        # Validate issuer if allowlist is configured (normalize both for comparison)
        if self.settings.dcr_allowed_issuers:
            normalized_allowlist = {allowed.rstrip("/") for allowed in self.settings.dcr_allowed_issuers}
            if normalized_issuer not in normalized_allowlist:
                raise DcrError(f"Issuer {issuer} is not in allowed issuers list")

        metadata = await self.discover_as_metadata(normalized_issuer)

        registration_endpoint = metadata.get("registration_endpoint")
        if not registration_endpoint:
            raise DcrError(f"AS {normalized_issuer} does not support Dynamic Client Registration (no registration_endpoint)")

        # Build registration request (RFC 7591)
        client_name = self.settings.dcr_client_name_template.replace("{gateway_name}", gateway_name)

        # Use `or []` to handle both missing key AND explicit null value (prevents TypeError)
        grant_types_supported = metadata.get("grant_types_supported") or []
        requested_grant_types = ["authorization_code"]

        # Only request refresh_token if AS explicitly supports it, or if permissive mode is enabled
        if "refresh_token" in grant_types_supported:
            requested_grant_types.append("refresh_token")
        elif self.settings.dcr_request_refresh_token_when_unsupported and not grant_types_supported:
            # Permissive mode: request refresh_token when AS doesn't advertise grant_types_supported.
            # Useful for AS servers that support refresh tokens but don't advertise it.
            requested_grant_types.append("refresh_token")
            logger.debug(f"Requesting refresh_token for {normalized_issuer} (permissive mode, AS omits grant_types_supported)")

        registration_request = {
            "client_name": client_name,
            "redirect_uris": [redirect_uri],
            "grant_types": requested_grant_types,
            "response_types": ["code"],
            "token_endpoint_auth_method": self.settings.dcr_token_endpoint_auth_method,
            "scope": " ".join(scopes),
        }

        try:
            client = await self._get_client()
            response = await client.post(registration_endpoint, json=registration_request, timeout=self._get_timeout())
        except httpx.HTTPError as e:
            raise DcrError(f"Failed to register client with {normalized_issuer}: {e}") from e

        # Accept both 200 OK and 201 Created (some servers don't follow RFC 7591 strictly)
        if response.status_code not in (200, 201):
            error_data = self._error_payload(response)
            error_msg = error_data.get("error", "unknown_error")
            error_desc = error_data.get("error_description", str(error_data))
            raise DcrError(f"Client registration failed: {error_msg} - {error_desc}")

        try:
            registration_response = self._decode_json_object(response)
        except ValueError as e:
            raise DcrError(f"Invalid client registration response from {normalized_issuer}: {e}") from e

        client_id = registration_response.get("client_id")
        if not client_id:
            raise DcrError(f"Client registration response from {normalized_issuer} is missing client_id")

        def _reported(key: str, default: Any) -> Any:
            """Return the AS-reported value for *key*, or *default* if it is absent or null."""
            value = registration_response.get(key)
            return default if value is None else value

        # Encrypt secrets
        encryption = get_encryption_service(self.settings.auth_encryption_secret)

        client_secret = registration_response.get("client_secret")
        client_secret_encrypted = await encryption.encrypt_secret_async(client_secret) if client_secret else None

        registration_access_token = registration_response.get("registration_access_token")
        registration_access_token_encrypted = await encryption.encrypt_secret_async(registration_access_token) if registration_access_token else None

        # Create database record (use normalized issuer for consistent lookup).
        # Fall back to the requested values when the AS omits them or returns null.
        registered_client = RegisteredOAuthClient(
            gateway_id=gateway_id,
            issuer=normalized_issuer,
            client_id=client_id,
            client_secret_encrypted=client_secret_encrypted,
            redirect_uris=orjson.dumps(_reported("redirect_uris", [redirect_uri])).decode(),
            grant_types=orjson.dumps(_reported("grant_types", requested_grant_types)).decode(),
            response_types=orjson.dumps(_reported("response_types", ["code"])).decode(),
            scope=_reported("scope", " ".join(scopes)),
            token_endpoint_auth_method=_reported("token_endpoint_auth_method", self.settings.dcr_token_endpoint_auth_method),
            registration_client_uri=registration_response.get("registration_client_uri"),
            registration_access_token_encrypted=registration_access_token_encrypted,
            created_at=datetime.now(timezone.utc),
            expires_at=self._secret_expiry(registration_response.get("client_secret_expires_at")),
            is_active=True,
        )

        db.add(registered_client)
        db.commit()
        db.refresh(registered_client)

        logger.info(f"Successfully registered client {registered_client.client_id} with {normalized_issuer} for gateway {gateway_id}")

        return registered_client

    async def get_or_register_client(self, gateway_id: str, gateway_name: str, issuer: str, redirect_uri: str, scopes: List[str], db: Session) -> RegisteredOAuthClient:
        """Get existing registered client or register new one.

        Args:
            gateway_id: Gateway ID
            gateway_name: Gateway name
            issuer: AS issuer URL
            redirect_uri: OAuth redirect URI
            scopes: List of OAuth scopes
            db: Database session

        Returns:
            RegisteredOAuthClient record

        Raises:
            DcrError: If no client is found and auto-register is disabled, or if
                auto-registration fails
        """
        # Normalize issuer for consistent lookup (matches how register_client stores it)
        normalized_issuer = issuer.rstrip("/")

        existing_client = (
            db.query(RegisteredOAuthClient)
            .filter(
                RegisteredOAuthClient.gateway_id == gateway_id,
                RegisteredOAuthClient.issuer == normalized_issuer,
                RegisteredOAuthClient.is_active.is_(True),
            )  # pylint: disable=singleton-comparison
            .first()
        )

        if existing_client:
            logger.debug(f"Found existing registered client for gateway {gateway_id} and issuer {normalized_issuer}")
            return existing_client

        if not self.settings.dcr_auto_register_on_missing_credentials:
            raise DcrError(
                f"No registered client found for gateway {gateway_id} and issuer {normalized_issuer}. Auto-register is disabled. Set MCPGATEWAY_DCR_AUTO_REGISTER_ON_MISSING_CREDENTIALS=true to enable."
            )

        # Auto-register (pass normalized issuer for consistent storage)
        logger.info(f"No existing client found for gateway {gateway_id}, registering new client with {normalized_issuer}")
        return await self.register_client(gateway_id, gateway_name, normalized_issuer, redirect_uri, scopes, db)

    async def update_client_registration(self, client_record: RegisteredOAuthClient, db: Session) -> RegisteredOAuthClient:
        """Update existing client registration (RFC 7592 section 2.2).

        Args:
            client_record: Existing RegisteredOAuthClient record
            db: Database session

        Returns:
            Updated RegisteredOAuthClient record (committed and refreshed)

        Raises:
            DcrError: If prerequisites are missing, the token cannot be decrypted,
                the HTTP request fails, or the AS returns an error or malformed response
        """
        if not client_record.registration_client_uri:
            raise DcrError("Cannot update client: no registration_client_uri available")

        if not client_record.registration_access_token_encrypted:
            raise DcrError("Cannot update client: no registration_access_token available")

        encryption = get_encryption_service(self.settings.auth_encryption_secret)
        registration_access_token = await encryption.decrypt_secret_async(client_record.registration_access_token_encrypted)
        if registration_access_token is None:
            raise DcrError("Failed to decrypt registration access token for update operation")

        # RFC 7592 section 2.2: the AS treats omitted metadata fields as null/empty
        # (i.e. deletes them), so replay every field we persisted at registration.
        update_request: Dict[str, Any] = {
            "client_id": client_record.client_id,
            "redirect_uris": orjson.loads(client_record.redirect_uris),
            "grant_types": orjson.loads(client_record.grant_types),
        }
        if client_record.response_types:
            update_request["response_types"] = orjson.loads(client_record.response_types)
        if client_record.scope:
            update_request["scope"] = client_record.scope
        if client_record.token_endpoint_auth_method:
            update_request["token_endpoint_auth_method"] = client_record.token_endpoint_auth_method

        try:
            client = await self._get_client()
            headers = {"Authorization": f"Bearer {registration_access_token}"}
            response = await client.put(client_record.registration_client_uri, json=update_request, headers=headers, timeout=self._get_timeout())
        except httpx.HTTPError as e:
            raise DcrError(f"Failed to update client registration: {e}") from e

        if response.status_code != 200:
            error_data = self._error_payload(response)
            raise DcrError(f"Failed to update client: {error_data}")

        try:
            updated_response = self._decode_json_object(response)
        except ValueError as e:
            raise DcrError(f"Invalid client update response for {client_record.client_id}: {e}") from e

        # Update encrypted secret if the AS issued a new one (ignore an explicit null)
        new_client_secret = updated_response.get("client_secret")
        if new_client_secret:
            client_record.client_secret_encrypted = await encryption.encrypt_secret_async(new_client_secret)

        # RFC 7592 section 3: if the AS returns a new registration access token the
        # client must discard the old one; keeping it would break later update/delete calls.
        new_access_token = updated_response.get("registration_access_token")
        if new_access_token:
            client_record.registration_access_token_encrypted = await encryption.encrypt_secret_async(new_access_token)

        if "client_secret_expires_at" in updated_response:
            client_record.expires_at = self._secret_expiry(updated_response["client_secret_expires_at"])

        db.commit()
        db.refresh(client_record)

        logger.info(f"Successfully updated client registration for {client_record.client_id}")
        return client_record

    async def delete_client_registration(self, client_record: RegisteredOAuthClient, db: Session) -> bool:  # pylint: disable=unused-argument
        """Delete/revoke client registration (RFC 7592 section 2.3).

        Failures are logged and reported through the return value rather than
        raised as DcrError. The local database record is not modified.

        Args:
            client_record: RegisteredOAuthClient record to delete
            db: Database session (unused)

        Returns:
            bool: True if the AS confirmed deletion (204, or 200 from servers that
            don't follow RFC 7592 strictly) or reports the client as already gone (404).
            False if deletion could not be performed (missing prerequisites,
            decryption error, network error, unexpected status).
        """
        if not client_record.registration_client_uri:
            logger.warning("Cannot delete client at AS: no registration_client_uri")
            return False

        if not client_record.registration_access_token_encrypted:
            logger.warning("Cannot delete client at AS: no registration_access_token")
            return False

        encryption = get_encryption_service(self.settings.auth_encryption_secret)
        registration_access_token = await encryption.decrypt_secret_async(client_record.registration_access_token_encrypted)
        if registration_access_token is None:
            logger.error("Failed to decrypt registration access token; cannot authenticate delete request to AS")
            return False

        try:
            client = await self._get_client()
            headers = {"Authorization": f"Bearer {registration_access_token}"}
            response = await client.delete(client_record.registration_client_uri, headers=headers, timeout=self._get_timeout())
        except httpx.HTTPError as e:
            logger.error(f"Failed to delete client at AS: {e}")
            return False

        # 204 = deleted (RFC 7592), 200 = deleted by a lenient server, 404 = already gone
        if response.status_code in (200, 204, 404):
            logger.info(f"Successfully deleted client registration for {client_record.client_id}")
            return True

        logger.warning(f"Unexpected status when deleting client: {response.status_code}")
        return False
class DcrError(Exception):
    """Raised when a Dynamic Client Registration operation fails.

    Used by DcrService for metadata discovery failures, issuer allowlist
    rejections, and registration/update errors returned by the authorization server.
    """