Coverage for mcpgateway / services / dcr_service.py: 100%
158 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/dcr_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Manav Gupta
7OAuth 2.0 Dynamic Client Registration Service.
9This module handles OAuth 2.0 Dynamic Client Registration (DCR) including:
10- AS metadata discovery (RFC 8414)
11- Client registration (RFC 7591)
12- Client management (update, delete)
13"""
15# Standard
16from datetime import datetime, timezone
17import logging
18from typing import Any, Dict, List
20# Third-Party
21import httpx
22import orjson
23from sqlalchemy.orm import Session
25# First-Party
26from mcpgateway.config import get_settings
27from mcpgateway.db import RegisteredOAuthClient
28from mcpgateway.services.encryption_service import get_encryption_service
29from mcpgateway.services.http_client_service import get_http_client
logger = logging.getLogger(name=__name__)

# Process-wide cache of discovered authorization-server metadata.
# Keys are issuer URLs with any trailing "/" removed; each value has the shape
# {"metadata": <metadata dict>, "cached_at": <timezone-aware UTC datetime>}.
_metadata_cache: Dict[str, Dict[str, Any]] = dict()
class DcrService:
    """Service for OAuth 2.0 Dynamic Client Registration (RFC 7591 client).

    Responsibilities:
    - Discover authorization server (AS) metadata via RFC 8414, falling back to
      OIDC discovery, with an in-process cache keyed by the normalized issuer.
    - Register this gateway as an OAuth client (RFC 7591) and persist the result.
    - Update or delete an existing registration through the client
      configuration endpoint (RFC 7592).
    """

    def __init__(self):
        """Initialize DCR service and load application settings."""
        self.settings = get_settings()

    async def _get_client(self) -> httpx.AsyncClient:
        """Get the shared singleton HTTP client.

        Returns:
            Shared httpx.AsyncClient instance with connection pooling
        """
        return await get_http_client()

    def _get_timeout(self) -> float:
        """Get the OAuth request timeout from settings.

        Returns:
            Timeout in seconds for OAuth/DCR requests
        """
        return float(self.settings.oauth_request_timeout)

    @staticmethod
    def _parse_json_object(response: httpx.Response) -> Dict[str, Any]:
        """Decode a response body that is expected to be a JSON object.

        Args:
            response: HTTP response whose body should be a JSON object

        Returns:
            The decoded JSON object

        Raises:
            ValueError: If the body is not valid JSON (``json.JSONDecodeError``
                is a ``ValueError`` subclass) or is valid JSON but not an object
        """
        data = response.json()
        if not isinstance(data, dict):
            raise ValueError(f"expected a JSON object, got {type(data).__name__}")
        return data

    @classmethod
    def _oauth_error_details(cls, response: httpx.Response) -> tuple:
        """Extract ``(error, error_description)`` from an OAuth error response.

        Args:
            response: Non-success HTTP response from the AS

        Returns:
            Tuple of error code and description. When the body is not a JSON
            object (e.g. an HTML error page), returns ``"unknown_error"`` and a
            description containing the HTTP status code.
        """
        try:
            error_data = cls._parse_json_object(response)
        except ValueError:
            return "unknown_error", f"HTTP {response.status_code}"
        return error_data.get("error", "unknown_error"), error_data.get("error_description", str(error_data))

    @staticmethod
    def _secret_expiry(client_secret_expires_at: Any):
        """Convert an RFC 7591 ``client_secret_expires_at`` value to a datetime.

        Args:
            client_secret_expires_at: Seconds since the epoch, or 0 for a secret
                that does not expire

        Returns:
            Timezone-aware UTC datetime, or None when the value is missing, not a
            number, not positive, or not representable as a datetime
        """
        # bool is an int subclass; exclude it so a stray True is not read as epoch second 1.
        if isinstance(client_secret_expires_at, bool) or not isinstance(client_secret_expires_at, (int, float)):
            return None
        if client_secret_expires_at <= 0:
            return None
        try:
            return datetime.fromtimestamp(client_secret_expires_at, tz=timezone.utc)
        except (OverflowError, OSError, ValueError):
            logger.warning(f"Ignoring unrepresentable client_secret_expires_at value: {client_secret_expires_at!r}")
            return None

    @staticmethod
    def _accept_metadata(normalized_issuer: str, metadata: Dict[str, Any], source: str) -> Dict[str, Any]:
        """Validate discovered metadata against the issuer, cache it, and return it.

        Args:
            normalized_issuer: Requested issuer URL without a trailing slash
            metadata: Decoded AS metadata document
            source: Human-readable discovery mechanism used for logging

        Returns:
            The validated metadata (same object that was passed in)

        Raises:
            DcrError: If the metadata ``issuer`` does not match ``normalized_issuer``
        """
        # Normalize the advertised issuer the same way as the requested one before comparing.
        metadata_issuer = str(metadata.get("issuer") or "").rstrip("/")
        if metadata_issuer != normalized_issuer:
            raise DcrError(f"AS metadata issuer mismatch: expected {normalized_issuer}, got {metadata.get('issuer')}")

        _metadata_cache[normalized_issuer] = {"metadata": metadata, "cached_at": datetime.now(timezone.utc)}
        logger.info(f"Discovered AS metadata for {normalized_issuer} via {source}")
        return metadata

    async def discover_as_metadata(self, issuer: str) -> Dict[str, Any]:
        """Discover AS metadata via RFC 8414.

        Tries:
        1. {issuer}/.well-known/oauth-authorization-server (RFC 8414)
        2. {issuer}/.well-known/openid-configuration (OIDC fallback)

        The RFC 8414 attempt falls through to the OIDC fallback on a transport
        error, a non-200 status, or a body that is not a JSON object. An issuer
        mismatch in either document is fatal. Successful results are cached for
        ``dcr_metadata_cache_ttl`` seconds.

        Args:
            issuer: The AS issuer URL

        Returns:
            Dict containing AS metadata

        Raises:
            DcrError: If metadata cannot be discovered, is malformed, or its issuer does not match
        """
        # Normalize issuer URL by removing trailing slash for consistency.
        # Per RFC 8414 Section 3.1, any terminating "/" MUST be removed before
        # inserting "/.well-known/" and the well-known URI suffix.
        # This also works around MCP Python SDK issue #1919 where Pydantic's
        # AnyHttpUrl adds trailing slashes to bare hostnames.
        # See: https://github.com/modelcontextprotocol/python-sdk/issues/1919
        normalized_issuer = issuer.rstrip("/")

        # Check cache first (using normalized issuer as key for consistency)
        cached_entry = _metadata_cache.get(normalized_issuer)
        if cached_entry is not None:
            cache_age = (datetime.now(timezone.utc) - cached_entry["cached_at"]).total_seconds()
            if cache_age < self.settings.dcr_metadata_cache_ttl:
                logger.debug(f"Using cached AS metadata for {normalized_issuer}")
                return cached_entry["metadata"]

        # Try RFC 8414 path first
        rfc8414_url = f"{normalized_issuer}/.well-known/oauth-authorization-server"
        try:
            client = await self._get_client()
            response = await client.get(rfc8414_url, timeout=self._get_timeout())
        except httpx.HTTPError as e:
            logger.debug(f"RFC 8414 discovery failed for {normalized_issuer}: {e}, trying OIDC fallback")
        else:
            if response.status_code == 200:
                try:
                    metadata = self._parse_json_object(response)
                except ValueError as e:
                    logger.debug(f"RFC 8414 discovery returned invalid metadata for {normalized_issuer}: {e}, trying OIDC fallback")
                else:
                    # An issuer mismatch here raises DcrError and does not fall back.
                    return self._accept_metadata(normalized_issuer, metadata, "RFC 8414")
            else:
                logger.debug(f"RFC 8414 discovery for {normalized_issuer} returned status {response.status_code}, trying OIDC fallback")

        # Try OIDC discovery fallback
        oidc_url = f"{normalized_issuer}/.well-known/openid-configuration"
        try:
            client = await self._get_client()
            response = await client.get(oidc_url, timeout=self._get_timeout())
        except httpx.HTTPError as e:
            raise DcrError(f"Failed to discover AS metadata for {normalized_issuer}: {e}") from e

        if response.status_code != 200:
            raise DcrError(f"AS metadata not found for {normalized_issuer} (status: {response.status_code})")

        try:
            metadata = self._parse_json_object(response)
        except ValueError as e:
            raise DcrError(f"Invalid AS metadata for {normalized_issuer}: {e}") from e

        return self._accept_metadata(normalized_issuer, metadata, "OIDC discovery")

    async def register_client(self, gateway_id: str, gateway_name: str, issuer: str, redirect_uri: str, scopes: List[str], db: Session) -> RegisteredOAuthClient:
        """Register as OAuth client with upstream AS (RFC 7591).

        Args:
            gateway_id: Gateway ID
            gateway_name: Gateway name
            issuer: AS issuer URL
            redirect_uri: OAuth redirect URI
            scopes: List of OAuth scopes
            db: Database session

        Returns:
            RegisteredOAuthClient record

        Raises:
            DcrError: If the issuer is not allowed, discovery fails, the AS rejects
                the registration, or the registration response is malformed
        """
        # Normalize issuer URL for consistent storage and lookup
        normalized_issuer = issuer.rstrip("/")

        # Validate issuer if allowlist is configured (normalize both for comparison)
        if self.settings.dcr_allowed_issuers:
            normalized_allowlist = {i.rstrip("/") for i in self.settings.dcr_allowed_issuers}
            if normalized_issuer not in normalized_allowlist:
                raise DcrError(f"Issuer {issuer} is not in allowed issuers list")

        # Discover AS metadata
        metadata = await self.discover_as_metadata(normalized_issuer)

        registration_endpoint = metadata.get("registration_endpoint")
        if not registration_endpoint:
            raise DcrError(f"AS {normalized_issuer} does not support Dynamic Client Registration (no registration_endpoint)")

        # Build registration request (RFC 7591)
        client_name = self.settings.dcr_client_name_template.replace("{gateway_name}", gateway_name)
        requested_scope = " ".join(scopes or [])

        # Determine grant types based on AS metadata
        # Use `or []` to handle both missing key AND explicit null value (prevents TypeError)
        grant_types_supported = metadata.get("grant_types_supported") or []
        requested_grant_types = ["authorization_code"]

        # Only request refresh_token if AS explicitly supports it, or if permissive mode is enabled
        if "refresh_token" in grant_types_supported:
            requested_grant_types.append("refresh_token")
        elif self.settings.dcr_request_refresh_token_when_unsupported and not grant_types_supported:
            # Permissive mode: request refresh_token when AS doesn't advertise grant_types_supported
            # This is useful for AS servers that support refresh tokens but don't advertise it
            requested_grant_types.append("refresh_token")
            logger.debug(f"Requesting refresh_token for {normalized_issuer} (permissive mode, AS omits grant_types_supported)")

        registration_request = {
            "client_name": client_name,
            "redirect_uris": [redirect_uri],
            "grant_types": requested_grant_types,
            "response_types": ["code"],
            "token_endpoint_auth_method": self.settings.dcr_token_endpoint_auth_method,
            "scope": requested_scope,
        }

        # Send registration request
        try:
            client = await self._get_client()
            response = await client.post(registration_endpoint, json=registration_request, timeout=self._get_timeout())
        except httpx.HTTPError as e:
            raise DcrError(f"Failed to register client with {normalized_issuer}: {e}") from e

        # Accept both 200 OK and 201 Created (some servers don't follow RFC 7591 strictly)
        if response.status_code not in (200, 201):
            error_msg, error_desc = self._oauth_error_details(response)
            raise DcrError(f"Client registration failed: {error_msg} - {error_desc}")

        try:
            registration_response = self._parse_json_object(response)
        except ValueError as e:
            raise DcrError(f"Invalid client registration response from {normalized_issuer}: {e}") from e

        # client_id is REQUIRED in the RFC 7591 client information response.
        client_id = registration_response.get("client_id")
        if not client_id:
            raise DcrError(f"Client registration response from {normalized_issuer} is missing client_id")

        def _response_value(key: str, default: Any) -> Any:
            """Return a response field, falling back when it is missing OR explicitly null."""
            value = registration_response.get(key)
            return default if value is None else value

        # Encrypt secrets
        encryption = get_encryption_service(self.settings.auth_encryption_secret)

        client_secret = registration_response.get("client_secret")
        client_secret_encrypted = await encryption.encrypt_secret_async(client_secret) if client_secret else None

        registration_access_token = registration_response.get("registration_access_token")
        registration_access_token_encrypted = await encryption.encrypt_secret_async(registration_access_token) if registration_access_token else None

        # Create database record (use normalized issuer for consistent lookup)
        # Fall back to requested values if the AS response omits them
        registered_client = RegisteredOAuthClient(
            gateway_id=gateway_id,
            issuer=normalized_issuer,
            client_id=client_id,
            client_secret_encrypted=client_secret_encrypted,
            redirect_uris=orjson.dumps(_response_value("redirect_uris", [redirect_uri])).decode(),
            grant_types=orjson.dumps(_response_value("grant_types", requested_grant_types)).decode(),
            response_types=orjson.dumps(_response_value("response_types", ["code"])).decode(),
            scope=_response_value("scope", requested_scope),
            token_endpoint_auth_method=_response_value("token_endpoint_auth_method", self.settings.dcr_token_endpoint_auth_method),
            registration_client_uri=registration_response.get("registration_client_uri"),
            registration_access_token_encrypted=registration_access_token_encrypted,
            created_at=datetime.now(timezone.utc),
            expires_at=self._secret_expiry(registration_response.get("client_secret_expires_at")),
            is_active=True,
        )

        db.add(registered_client)
        try:
            db.commit()
        except Exception:
            # Leave the session usable for the caller. Note the client is already
            # registered at the AS at this point; only the local record is lost.
            db.rollback()
            raise
        db.refresh(registered_client)

        logger.info(f"Successfully registered client {registered_client.client_id} with {normalized_issuer} for gateway {gateway_id}")

        return registered_client

    async def get_or_register_client(self, gateway_id: str, gateway_name: str, issuer: str, redirect_uri: str, scopes: List[str], db: Session) -> RegisteredOAuthClient:
        """Get existing registered client or register new one.

        Args:
            gateway_id: Gateway ID
            gateway_name: Gateway name
            issuer: AS issuer URL
            redirect_uri: OAuth redirect URI
            scopes: List of OAuth scopes
            db: Database session

        Returns:
            RegisteredOAuthClient record

        Raises:
            DcrError: If client not found and auto-register is disabled, or if
                auto-registration fails
        """
        # Normalize issuer for consistent lookup (matches how register_client stores it)
        normalized_issuer = issuer.rstrip("/")

        # Try to find an existing active client using the normalized issuer
        existing_client = (
            db.query(RegisteredOAuthClient)
            .filter(
                RegisteredOAuthClient.gateway_id == gateway_id, RegisteredOAuthClient.issuer == normalized_issuer, RegisteredOAuthClient.is_active.is_(True)
            )  # pylint: disable=singleton-comparison
            .first()
        )

        if existing_client:
            logger.debug(f"Found existing registered client for gateway {gateway_id} and issuer {normalized_issuer}")
            return existing_client

        # No existing client, check if auto-register is enabled
        if not self.settings.dcr_auto_register_on_missing_credentials:
            raise DcrError(
                f"No registered client found for gateway {gateway_id} and issuer {normalized_issuer}. Auto-register is disabled. Set MCPGATEWAY_DCR_AUTO_REGISTER_ON_MISSING_CREDENTIALS=true to enable."
            )

        # Auto-register (pass normalized issuer for consistent storage)
        logger.info(f"No existing client found for gateway {gateway_id}, registering new client with {normalized_issuer}")
        return await self.register_client(gateway_id, gateway_name, normalized_issuer, redirect_uri, scopes, db)

    async def update_client_registration(self, client_record: RegisteredOAuthClient, db: Session) -> RegisteredOAuthClient:
        """Update existing client registration (RFC 7592 section 2.2).

        RFC 7592 treats metadata fields omitted from an update request as
        deletions, so every client metadata field stored on the record is sent.
        If the AS returns a new client secret or registration access token, the
        new value replaces the stored one, as RFC 7592 requires.

        Args:
            client_record: Existing RegisteredOAuthClient record
            db: Database session

        Returns:
            Updated RegisteredOAuthClient record

        Raises:
            DcrError: If update fails
        """
        if not client_record.registration_client_uri:
            raise DcrError("Cannot update client: no registration_client_uri available")

        if not client_record.registration_access_token_encrypted:
            raise DcrError("Cannot update client: no registration_access_token available")

        # Decrypt registration access token
        encryption = get_encryption_service(self.settings.auth_encryption_secret)
        registration_access_token = await encryption.decrypt_secret_async(client_record.registration_access_token_encrypted)

        # Build update request with the full stored metadata (replace, not augment, semantics).
        update_request: Dict[str, Any] = {
            "client_id": client_record.client_id,
            "redirect_uris": orjson.loads(client_record.redirect_uris),
            "grant_types": orjson.loads(client_record.grant_types),
        }
        if isinstance(client_record.response_types, str) and client_record.response_types:
            update_request["response_types"] = orjson.loads(client_record.response_types)
        if isinstance(client_record.scope, str) and client_record.scope:
            update_request["scope"] = client_record.scope
        if isinstance(client_record.token_endpoint_auth_method, str) and client_record.token_endpoint_auth_method:
            update_request["token_endpoint_auth_method"] = client_record.token_endpoint_auth_method

        # Send update request
        try:
            client = await self._get_client()
            headers = {"Authorization": f"Bearer {registration_access_token}"}
            response = await client.put(client_record.registration_client_uri, json=update_request, headers=headers, timeout=self._get_timeout())
        except httpx.HTTPError as e:
            raise DcrError(f"Failed to update client registration: {e}") from e

        if response.status_code != 200:
            try:
                error_data = response.json()
            except ValueError:
                # Non-JSON error body (e.g. an HTML error page)
                error_data = f"HTTP {response.status_code}"
            raise DcrError(f"Failed to update client: {error_data}")

        try:
            updated_response = self._parse_json_object(response)
        except ValueError as e:
            raise DcrError(f"Invalid client update response for {client_record.client_id}: {e}") from e

        # Replace the client secret (and its expiry) if the AS rotated it.
        new_client_secret = updated_response.get("client_secret")
        if new_client_secret:
            client_record.client_secret_encrypted = await encryption.encrypt_secret_async(new_client_secret)
            if "client_secret_expires_at" in updated_response:
                client_record.expires_at = self._secret_expiry(updated_response["client_secret_expires_at"])

        # Replace the registration access token if the AS rotated it.
        new_registration_access_token = updated_response.get("registration_access_token")
        if new_registration_access_token:
            client_record.registration_access_token_encrypted = await encryption.encrypt_secret_async(new_registration_access_token)

        try:
            db.commit()
        except Exception:
            db.rollback()
            raise
        db.refresh(client_record)

        logger.info(f"Successfully updated client registration for {client_record.client_id}")
        return client_record

    async def delete_client_registration(self, client_record: RegisteredOAuthClient, db: Session) -> bool:  # pylint: disable=unused-argument
        """Delete/revoke client registration at the AS (RFC 7592 section 2.3).

        This is best-effort. A missing configuration URI or token, an unexpected
        status, or a transport error is logged and still treated as success.
        This method does not modify the local record; ``db`` is unused. Errors
        raised while decrypting the stored registration access token are not
        caught here.

        Args:
            client_record: RegisteredOAuthClient record to delete
            db: Database session (unused)

        Returns:
            True in every handled case (204 = deleted, 404 = already gone,
            anything else = logged and ignored)
        """
        if not client_record.registration_client_uri:
            logger.warning("Cannot delete client at AS: no registration_client_uri")
            return True  # Consider it deleted locally

        if not client_record.registration_access_token_encrypted:
            logger.warning("Cannot delete client at AS: no registration_access_token")
            return True  # Consider it deleted locally

        # Decrypt registration access token
        encryption = get_encryption_service(self.settings.auth_encryption_secret)
        registration_access_token = await encryption.decrypt_secret_async(client_record.registration_access_token_encrypted)

        # Send delete request
        try:
            client = await self._get_client()
            headers = {"Authorization": f"Bearer {registration_access_token}"}
            response = await client.delete(client_record.registration_client_uri, headers=headers, timeout=self._get_timeout())
        except httpx.HTTPError as e:
            logger.warning(f"Failed to delete client at AS: {e}")
            return True  # Best-effort, don't fail if AS is unreachable

        if response.status_code in (204, 404):  # 204 = deleted, 404 = already gone
            logger.info(f"Successfully deleted client registration for {client_record.client_id}")
            return True

        logger.warning(f"Unexpected status when deleting client: {response.status_code}")
        return True  # Consider it best-effort
class DcrError(Exception):
    """Raised when a Dynamic Client Registration operation cannot be completed.

    Used for failures during AS metadata discovery, issuer allowlist checks,
    client registration, and registration updates.
    """