Coverage for mcpgateway / services / token_storage_service.py: 99%
184 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/token_storage_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7OAuth Token Storage Service for ContextForge.
9This module handles the storage, retrieval, and management of OAuth access and refresh tokens
10for Authorization Code flow implementations.
11"""
13# Standard
14from datetime import datetime, timedelta, timezone
15import logging
16from typing import Any, Dict, List, Optional
18# Third-Party
19from sqlalchemy import delete, select
20from sqlalchemy.orm import Session
22# First-Party
23from mcpgateway.common.validators import SecurityValidator
24from mcpgateway.config import get_settings
25from mcpgateway.db import OAuthToken
26from mcpgateway.services.encryption_service import get_encryption_service
27from mcpgateway.services.oauth_manager import OAuthError
# Module-scoped logger; named after this module so log records can be filtered per service.
logger: logging.Logger = logging.getLogger(name=__name__)
class TokenStorageService:
    """Manages OAuth token storage and retrieval.

    Tokens are stored per ``(gateway_id, app_user_email)`` pair and are
    encrypted at rest whenever an encryption service could be obtained in
    ``__init__``; otherwise they are stored as plain text.

    Examples:
        >>> service = TokenStorageService(None)  # Mock DB for doctest
        >>> service.db is None
        True
        >>> service.encryption is not None or service.encryption is None  # Encryption may or may not be available
        True
        >>> # Test token expiration calculation
        >>> from datetime import datetime, timedelta, timezone
        >>> expires_in = 3600  # 1 hour
        >>> now = datetime.now(tz=timezone.utc)
        >>> expires_at = now + timedelta(seconds=expires_in)
        >>> expires_at > now
        True
        >>> # Test scope list handling
        >>> scopes = ["read", "write", "admin"]
        >>> isinstance(scopes, list)
        True
        >>> "read" in scopes
        True
        >>> # Test token encryption detection
        >>> short_token = "abc123"
        >>> len(short_token) < 100
        True
        >>> encrypted_token = "gAAAAABh" + "x" * 100
        >>> len(encrypted_token) > 100
        True
    """

    def __init__(self, db: Session):
        """Initialize Token Storage Service.

        Args:
            db: Database session used for all token reads and writes.
        """
        self.db = db
        try:
            settings = get_settings()
            self.encryption = get_encryption_service(settings.auth_encryption_secret)
        except (ImportError, AttributeError):
            # Degrade to plain-text storage instead of failing construction.
            logger.warning("OAuth encryption not available, using plain text storage")
            self.encryption = None

    def _find_token_record(self, gateway_id: str, app_user_email: str) -> Optional[OAuthToken]:
        """Return the stored token row for a gateway/app-user pair, or None if absent.

        Args:
            gateway_id: ID of the gateway
            app_user_email: ContextForge user email

        Returns:
            The matching OAuthToken record, or None.
        """
        stmt = select(OAuthToken).where(OAuthToken.gateway_id == gateway_id, OAuthToken.app_user_email == app_user_email)
        return self.db.execute(stmt).scalar_one_or_none()

    async def store_tokens(self, gateway_id: str, user_id: str, app_user_email: str, access_token: str, refresh_token: Optional[str], expires_in: int, scopes: List[str]) -> OAuthToken:
        """Store OAuth tokens for a gateway-user combination.

        An existing record for the same ``(gateway_id, app_user_email)`` pair is
        updated in place; otherwise a new record is created. The transaction is
        committed on success and rolled back on failure.

        Args:
            gateway_id: ID of the gateway
            user_id: OAuth provider user ID
            app_user_email: ContextForge user email (required)
            access_token: Access token from OAuth provider
            refresh_token: Refresh token from OAuth provider (optional)
            expires_in: Token expiration time in seconds
            scopes: List of OAuth scopes granted

        Returns:
            OAuthToken record

        Raises:
            OAuthError: If token storage fails; the underlying exception is chained as ``__cause__``.
        """
        try:
            # Encrypt sensitive tokens if encryption is available
            encrypted_access = access_token
            encrypted_refresh = refresh_token
            if self.encryption:
                encrypted_access = await self.encryption.encrypt_secret_async(access_token)
                if refresh_token:
                    encrypted_refresh = await self.encryption.encrypt_secret_async(refresh_token)

            expires_at = datetime.now(timezone.utc) + timedelta(seconds=int(expires_in))

            # Records are scoped by app_user_email so one user can never read another's tokens.
            token_record = self._find_token_record(gateway_id, app_user_email)

            if token_record:
                token_record.user_id = user_id  # OAuth provider ID may change between authorizations
                token_record.access_token = encrypted_access
                token_record.refresh_token = encrypted_refresh
                token_record.expires_at = expires_at
                token_record.scopes = scopes
                token_record.updated_at = datetime.now(timezone.utc)
                logger.info(
                    f"Updated OAuth tokens for gateway {SecurityValidator.sanitize_log_message(gateway_id)}, app user {SecurityValidator.sanitize_log_message(app_user_email)}, OAuth user {SecurityValidator.sanitize_log_message(user_id)}"
                )
            else:
                token_record = OAuthToken(
                    gateway_id=gateway_id, user_id=user_id, app_user_email=app_user_email, access_token=encrypted_access, refresh_token=encrypted_refresh, expires_at=expires_at, scopes=scopes
                )
                self.db.add(token_record)
                logger.info(
                    f"Stored new OAuth tokens for gateway {SecurityValidator.sanitize_log_message(gateway_id)}, app user {SecurityValidator.sanitize_log_message(app_user_email)}, OAuth user {SecurityValidator.sanitize_log_message(user_id)}"
                )

            self.db.commit()
            return token_record

        except Exception as e:
            self.db.rollback()
            logger.error(f"Failed to store OAuth tokens: {str(e)}")
            raise OAuthError(f"Token storage failed: {str(e)}") from e

    async def get_user_token(self, gateway_id: str, app_user_email: str, threshold_seconds: int = 300) -> Optional[str]:
        """Get a valid access token for a specific ContextForge user, refreshing if necessary.

        Args:
            gateway_id: ID of the gateway
            app_user_email: ContextForge user email (required)
            threshold_seconds: Seconds before expiry to consider token expired

        Returns:
            Valid (decrypted) access token or None if no valid token available for this user.
            Errors are logged and reported as None rather than raised.
        """
        try:
            token_record = self._find_token_record(gateway_id, app_user_email)

            if not token_record:
                logger.debug(f"No OAuth tokens found for gateway {SecurityValidator.sanitize_log_message(gateway_id)}, app user {SecurityValidator.sanitize_log_message(app_user_email)}")
                return None

            if self._is_token_expired(token_record, threshold_seconds):
                logger.info(f"OAuth token expired for gateway {SecurityValidator.sanitize_log_message(gateway_id)}, app user {SecurityValidator.sanitize_log_message(app_user_email)}")
                if not token_record.refresh_token:
                    return None
                # A falsy refresh result (None / empty) is reported uniformly as None.
                return await self._refresh_access_token(token_record) or None

            if self.encryption:
                return await self.encryption.decrypt_secret_async(token_record.access_token)
            return token_record.access_token

        except Exception as e:
            logger.error(f"Failed to retrieve OAuth token: {str(e)}")
            return None

    # REMOVED: get_any_valid_token() - This was a security vulnerability
    # All OAuth tokens MUST be user-specific to prevent cross-user token access

    @staticmethod
    def _normalize_resource(url: str, *, preserve_query: bool = False) -> Optional[str]:
        """Normalize a resource URI per RFC 8707.

        The fragment is always removed; the query is kept only when
        ``preserve_query`` is True (explicitly configured resources).

        Args:
            url: Resource URL to normalize
            preserve_query: If True, preserve query (for explicit config). If False, strip query.

        Returns:
            Normalized URL string, or None if empty or not an absolute URI.

        Examples:
            >>> TokenStorageService._normalize_resource("https://api.example.com/mcp?x=1#frag")
            'https://api.example.com/mcp'
            >>> TokenStorageService._normalize_resource("https://api.example.com/mcp?x=1#frag", preserve_query=True)
            'https://api.example.com/mcp?x=1'
            >>> TokenStorageService._normalize_resource("urn:example:resource")
            'urn:example:resource'
            >>> TokenStorageService._normalize_resource("") is None
            True
        """
        # Standard
        from urllib.parse import urlparse, urlunparse  # pylint: disable=import-outside-toplevel

        if not url:
            return None
        parsed = urlparse(url)
        # RFC 8707: resource MUST be an absolute URI (requires a scheme); URNs are accepted too.
        if not parsed.scheme:
            logger.warning(f"Invalid resource URL (must be absolute URI with scheme): {url}")
            return None
        # Fragment MUST NOT be sent; query is preserved only for explicit configuration.
        query = parsed.query if preserve_query else ""
        return urlunparse((parsed.scheme, parsed.netloc, parsed.path, parsed.params, query, ""))

    async def _build_refresh_oauth_config(self, gateway: Any) -> Dict[str, Any]:
        """Return a copy of the gateway's OAuth config prepared for a refresh request.

        The copy has its ``client_secret`` decrypted when possible and its RFC 8707
        ``resource`` parameter normalized (explicit config) or derived from
        ``gateway.url``. The gateway's own config is not mutated.

        Args:
            gateway: Gateway row with ``oauth_config`` and ``url`` attributes.

        Returns:
            The prepared OAuth configuration dictionary.
        """
        oauth_config = gateway.oauth_config.copy()

        if oauth_config.get("client_secret") and self.encryption:
            try:
                oauth_config["client_secret"] = await self.encryption.decrypt_secret_async(oauth_config["client_secret"])
            except Exception:  # nosec B110
                # Intentional fallback: an undecryptable secret is assumed to already be plain text.
                logger.debug("OAuth client_secret could not be decrypted; using configured value as-is")

        existing_resource = oauth_config.get("resource")
        if existing_resource:
            # Explicitly configured resources keep their query string.
            if isinstance(existing_resource, list):
                normalized_list = [self._normalize_resource(r, preserve_query=True) for r in existing_resource]
                oauth_config["resource"] = [r for r in normalized_list if r]
                if not oauth_config["resource"]:
                    logger.warning(f"All {len(existing_resource)} configured resource values were invalid and removed during refresh")
            else:
                normalized = self._normalize_resource(existing_resource, preserve_query=True)
                if not normalized:
                    logger.warning(f"Configured resource was invalid and removed during refresh: {existing_resource}")
                oauth_config["resource"] = normalized
        elif gateway.url:
            # Derive from gateway.url when not explicitly configured (query stripped).
            oauth_config["resource"] = self._normalize_resource(gateway.url)
            if not oauth_config["resource"]:
                logger.warning(f"Gateway URL is not a valid absolute URI, skipping resource parameter: {gateway.url}")

        return oauth_config

    async def _refresh_access_token(self, token_record: OAuthToken) -> Optional[str]:
        """Refresh an expired access token using its stored refresh token.

        On success the record is updated in place and committed. If the refresh
        fails with an error whose message mentions "invalid" or "expired", the
        record is deleted so the user is forced to re-authenticate.

        Args:
            token_record: OAuth token record to refresh

        Returns:
            New (plain-text) access token or None if refresh failed
        """
        # Captured up front so failure logging never depends on reloading
        # attributes from a session that may have been rolled back.
        gateway_id = token_record.gateway_id
        app_user_email = token_record.app_user_email
        safe_gateway = SecurityValidator.sanitize_log_message(gateway_id)
        safe_user = SecurityValidator.sanitize_log_message(app_user_email)

        try:
            if not token_record.refresh_token:
                logger.warning(f"No refresh token available for gateway {safe_gateway}")
                return None

            # First-Party
            from mcpgateway.db import Gateway  # pylint: disable=import-outside-toplevel

            gateway = self.db.query(Gateway).filter(Gateway.id == gateway_id).first()
            if not gateway or not gateway.oauth_config:
                logger.error(f"No OAuth configuration found for gateway {safe_gateway}")
                return None

            refresh_token = token_record.refresh_token
            if self.encryption:
                try:
                    refresh_token = await self.encryption.decrypt_secret_async(refresh_token)
                except Exception as e:
                    logger.error(f"Failed to decrypt refresh token: {str(e)}")
                    return None

            oauth_config = await self._build_refresh_oauth_config(gateway)

            # First-Party
            from mcpgateway.services.oauth_manager import OAuthManager  # pylint: disable=import-outside-toplevel

            oauth_manager = OAuthManager()
            logger.info(f"Attempting to refresh token for gateway {safe_gateway}, user {safe_user}")
            token_response = await oauth_manager.refresh_token(refresh_token, oauth_config)

            new_access_token = token_response["access_token"]
            # Providers may rotate the refresh token, omit it, or send null; keep the current one unless a new one is given.
            new_refresh_token = token_response.get("refresh_token") or refresh_token
            expires_in = token_response.get("expires_in")
            if expires_in is None:
                expires_in = 3600  # default lifetime when the provider omits (or nulls) expires_in

            encrypted_access = new_access_token
            encrypted_refresh = new_refresh_token
            if self.encryption:
                encrypted_access = await self.encryption.encrypt_secret_async(new_access_token)
                encrypted_refresh = await self.encryption.encrypt_secret_async(new_refresh_token)

            now = datetime.now(timezone.utc)
            token_record.access_token = encrypted_access
            token_record.refresh_token = encrypted_refresh
            token_record.expires_at = now + timedelta(seconds=int(expires_in))
            token_record.updated_at = now

            try:
                self.db.commit()
            except Exception:
                # A failed commit leaves the session unusable until it is rolled back.
                self.db.rollback()
                raise

            logger.info(f"Successfully refreshed token for gateway {safe_gateway}, user {safe_user}")
            return new_access_token

        except Exception as e:
            logger.error(f"Failed to refresh OAuth token for gateway {safe_gateway}: {str(e)}")
            message = str(e).lower()
            # If the refresh token itself was rejected, clear it to force re-authentication.
            if "invalid" in message or "expired" in message:
                logger.warning(f"Refresh token appears invalid/expired, clearing tokens for gateway {safe_gateway}")
                try:
                    self.db.delete(token_record)
                    self.db.commit()
                except Exception as cleanup_error:
                    # Keep the "return None on failure" contract even if cleanup fails.
                    self.db.rollback()
                    logger.error(f"Failed to clear OAuth tokens for gateway {safe_gateway}: {str(cleanup_error)}")
            return None

    def _is_token_expired(self, token_record: OAuthToken, threshold_seconds: int = 300) -> bool:
        """Check if token is expired or near expiration.

        Naive ``expires_at`` values are interpreted as UTC. A record without
        ``expires_at`` is treated as never expiring.

        Args:
            token_record: OAuth token record to check
            threshold_seconds: Seconds before expiry to consider token expired

        Returns:
            True if token is expired or near expiration

        Examples:
            >>> from types import SimpleNamespace
            >>> from datetime import datetime, timedelta
            >>> svc = TokenStorageService(None)
            >>> future = datetime.now(tz=timezone.utc) + timedelta(seconds=600)
            >>> past = datetime.now(tz=timezone.utc) - timedelta(seconds=10)
            >>> rec_future = SimpleNamespace(expires_at=future)
            >>> rec_past = SimpleNamespace(expires_at=past)
            >>> svc._is_token_expired(rec_future, threshold_seconds=300)  # 10 min ahead, 5 min threshold
            False
            >>> svc._is_token_expired(rec_future, threshold_seconds=900)  # 10 min ahead, 15 min threshold
            True
            >>> svc._is_token_expired(rec_past, threshold_seconds=0)
            True
            >>> svc._is_token_expired(SimpleNamespace(expires_at=None))
            False
        """
        if not token_record.expires_at:
            return False
        expires_at = token_record.expires_at
        if expires_at.tzinfo is None:
            # Some backends return naive datetimes; stored values are written as UTC.
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        return datetime.now(timezone.utc) + timedelta(seconds=threshold_seconds) >= expires_at

    async def get_token_info(self, gateway_id: str, app_user_email: str) -> Optional[Dict[str, Any]]:
        """Get information about stored OAuth tokens (never the token values themselves).

        Args:
            gateway_id: ID of the gateway
            app_user_email: ContextForge user email

        Returns:
            Token information dictionary or None if not found (or on error).
            Timestamp fields are ISO-8601 strings, or None when unset.

        Examples:
            >>> from types import SimpleNamespace
            >>> from datetime import datetime, timedelta
            >>> svc = TokenStorageService(None)
            >>> now = datetime.now(tz=timezone.utc)
            >>> future = now + timedelta(seconds=60)
            >>> rec = SimpleNamespace(user_id='u1', app_user_email='u1', token_type='bearer', expires_at=future, scopes=['s1'], created_at=now, updated_at=now)
            >>> class _Res:
            ...     def scalar_one_or_none(self):
            ...         return rec
            >>> class _DB:
            ...     def execute(self, *_args, **_kw):
            ...         return _Res()
            >>> svc.db = _DB()
            >>> import asyncio
            >>> info = asyncio.run(svc.get_token_info('g1', 'u1'))
            >>> info['user_id']
            'u1'
            >>> isinstance(info['is_expired'], bool)
            True
            >>> rec.created_at = None
            >>> asyncio.run(svc.get_token_info('g1', 'u1'))['created_at'] is None
            True
        """
        try:
            token_record = self._find_token_record(gateway_id, app_user_email)

            if not token_record:
                return None

            return {
                "user_id": token_record.user_id,  # OAuth provider user ID
                "app_user_email": token_record.app_user_email,  # ContextForge user
                "token_type": token_record.token_type,
                "expires_at": token_record.expires_at.isoformat() if token_record.expires_at else None,
                "scopes": token_record.scopes,
                "created_at": token_record.created_at.isoformat() if token_record.created_at else None,
                "updated_at": token_record.updated_at.isoformat() if token_record.updated_at else None,
                "is_expired": self._is_token_expired(token_record, 0),
            }

        except Exception as e:
            logger.error(f"Failed to get token info: {str(e)}")
            return None

    async def revoke_user_tokens(self, gateway_id: str, app_user_email: str) -> bool:
        """Revoke (delete) OAuth tokens for a specific user.

        Args:
            gateway_id: ID of the gateway
            app_user_email: ContextForge user email

        Returns:
            True if a token record was found and deleted; False if none existed
            or the deletion failed (the transaction is rolled back on failure).

        Examples:
            >>> from types import SimpleNamespace
            >>> from unittest.mock import MagicMock
            >>> svc = TokenStorageService(MagicMock())
            >>> rec = SimpleNamespace()
            >>> svc.db.execute.return_value.scalar_one_or_none.return_value = rec
            >>> svc.db.delete = lambda obj: None
            >>> svc.db.commit = lambda: None
            >>> import asyncio
            >>> asyncio.run(svc.revoke_user_tokens('g1', 'u1'))
            True
            >>> # Not found
            >>> svc.db.execute.return_value.scalar_one_or_none.return_value = None
            >>> asyncio.run(svc.revoke_user_tokens('g1', 'u1'))
            False
        """
        try:
            token_record = self._find_token_record(gateway_id, app_user_email)
            if not token_record:
                return False

            self.db.delete(token_record)
            self.db.commit()
            logger.info(f"Revoked OAuth tokens for gateway {SecurityValidator.sanitize_log_message(gateway_id)}, user {SecurityValidator.sanitize_log_message(app_user_email)}")
            return True

        except Exception as e:
            self.db.rollback()
            logger.error(f"Failed to revoke OAuth tokens: {str(e)}")
            return False

    async def cleanup_expired_tokens(self, max_age_days: int = 30) -> int:
        """Clean up OAuth tokens whose expiry is more than ``max_age_days`` in the past.

        Uses a single SQL DELETE statement instead of loading tokens into memory
        and deleting them one by one. This is more efficient and avoids memory
        issues when many tokens expire at once.

        Args:
            max_age_days: Maximum age of tokens to keep

        Returns:
            Number of tokens cleaned up as reported by the driver's rowcount
            (0 on failure, after rolling back).

        Examples:
            >>> from unittest.mock import MagicMock
            >>> svc = TokenStorageService(MagicMock())
            >>> svc.db.execute.return_value.rowcount = 2
            >>> svc.db.commit = lambda: None
            >>> import asyncio
            >>> asyncio.run(svc.cleanup_expired_tokens(1))
            2
        """
        try:
            cutoff_date = datetime.now(tz=timezone.utc) - timedelta(days=max_age_days)

            result = self.db.execute(delete(OAuthToken).where(OAuthToken.expires_at < cutoff_date))
            count = result.rowcount

            self.db.commit()

            if count > 0:
                logger.info(f"Cleaned up {count} expired OAuth tokens")

            return count

        except Exception as e:
            self.db.rollback()
            logger.error(f"Failed to cleanup expired tokens: {str(e)}")
            return 0