Coverage for mcpgateway / services / token_storage_service.py: 98%
183 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/token_storage_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7OAuth Token Storage Service for MCP Gateway.
9This module handles the storage, retrieval, and management of OAuth access and refresh tokens
10for Authorization Code flow implementations.
11"""
13# Standard
14from datetime import datetime, timedelta, timezone
15import logging
16from typing import Any, Dict, List, Optional
18# Third-Party
19from sqlalchemy import delete, select
20from sqlalchemy.orm import Session
22# First-Party
23from mcpgateway.config import get_settings
24from mcpgateway.db import OAuthToken
25from mcpgateway.services.encryption_service import get_encryption_service
26from mcpgateway.services.oauth_manager import OAuthError
# Module-level logger used by TokenStorageService for storage, refresh, revocation and cleanup events.
logger = logging.getLogger(__name__)
class TokenStorageService:
    """Manages OAuth token storage and retrieval.

    Tokens are stored per (gateway, MCP Gateway user email) pair. When an
    encryption service is available, access and refresh tokens are encrypted
    before being persisted and decrypted when read back; otherwise they are
    stored as plain text.

    Examples:
        >>> service = TokenStorageService(None)  # Mock DB for doctest
        >>> service.db is None
        True
        >>> service.encryption is not None or service.encryption is None  # Encryption may or may not be available
        True
        >>> # Test token expiration calculation
        >>> from datetime import datetime, timedelta
        >>> expires_in = 3600  # 1 hour
        >>> now = datetime.now(tz=timezone.utc)
        >>> expires_at = now + timedelta(seconds=expires_in)
        >>> expires_at > now
        True
        >>> # Test scope list handling
        >>> scopes = ["read", "write", "admin"]
        >>> isinstance(scopes, list)
        True
        >>> "read" in scopes
        True
        >>> # Test token encryption detection
        >>> short_token = "abc123"
        >>> len(short_token) < 100
        True
        >>> encrypted_token = "gAAAAABh" + "x" * 100
        >>> len(encrypted_token) > 100
        True
    """

    def __init__(self, db: Session):
        """Initialize Token Storage Service.

        Args:
            db: Database session
        """
        self.db = db
        try:
            settings = get_settings()
            self.encryption = get_encryption_service(settings.auth_encryption_secret)
        except (ImportError, AttributeError):
            logger.warning("OAuth encryption not available, using plain text storage")
            self.encryption = None

    def _get_token_record(self, gateway_id: str, app_user_email: str) -> Optional[OAuthToken]:
        """Fetch the stored token record for a gateway / MCP Gateway user pair.

        Args:
            gateway_id: ID of the gateway
            app_user_email: MCP Gateway user email

        Returns:
            The matching OAuthToken record, or None if nothing is stored.
        """
        stmt = select(OAuthToken).where(OAuthToken.gateway_id == gateway_id, OAuthToken.app_user_email == app_user_email)
        return self.db.execute(stmt).scalar_one_or_none()

    async def store_tokens(self, gateway_id: str, user_id: str, app_user_email: str, access_token: str, refresh_token: Optional[str], expires_in: int, scopes: List[str]) -> OAuthToken:
        """Store OAuth tokens for a gateway-user combination.

        Creates a new record or updates the existing one for
        (``gateway_id``, ``app_user_email``), then commits.

        Args:
            gateway_id: ID of the gateway
            user_id: OAuth provider user ID
            app_user_email: MCP Gateway user email (required)
            access_token: Access token from OAuth provider
            refresh_token: Refresh token from OAuth provider (optional)
            expires_in: Token expiration time in seconds
            scopes: List of OAuth scopes granted

        Returns:
            OAuthToken record

        Raises:
            OAuthError: If token storage fails
        """
        try:
            # Encrypt sensitive tokens if encryption is available
            encrypted_access = access_token
            encrypted_refresh = refresh_token
            if self.encryption:
                encrypted_access = await self.encryption.encrypt_secret_async(access_token)
                if refresh_token:
                    encrypted_refresh = await self.encryption.encrypt_secret_async(refresh_token)

            expires_at = datetime.now(timezone.utc) + timedelta(seconds=int(expires_in))
            # Records are scoped by app_user_email so tokens never leak across gateway users.
            token_record = self._get_token_record(gateway_id, app_user_email)

            if token_record:
                token_record.user_id = user_id  # OAuth provider ID may change between authorizations
                token_record.access_token = encrypted_access
                # NOTE(review): a None refresh_token overwrites any previously stored one;
                # confirm this is intended for providers that omit refresh tokens on re-consent.
                token_record.refresh_token = encrypted_refresh
                token_record.expires_at = expires_at
                token_record.scopes = scopes
                token_record.updated_at = datetime.now(timezone.utc)
                logger.info(f"Updated OAuth tokens for gateway {gateway_id}, app user {app_user_email}, OAuth user {user_id}")
            else:
                token_record = OAuthToken(
                    gateway_id=gateway_id, user_id=user_id, app_user_email=app_user_email, access_token=encrypted_access, refresh_token=encrypted_refresh, expires_at=expires_at, scopes=scopes
                )
                self.db.add(token_record)
                logger.info(f"Stored new OAuth tokens for gateway {gateway_id}, app user {app_user_email}, OAuth user {user_id}")

            self.db.commit()
            return token_record

        except Exception as e:
            self.db.rollback()
            logger.error(f"Failed to store OAuth tokens: {str(e)}")
            raise OAuthError(f"Token storage failed: {str(e)}") from e

    async def get_user_token(self, gateway_id: str, app_user_email: str, threshold_seconds: int = 300) -> Optional[str]:
        """Get a valid access token for a specific MCP Gateway user, refreshing if necessary.

        Args:
            gateway_id: ID of the gateway
            app_user_email: MCP Gateway user email (required)
            threshold_seconds: Seconds before expiry to consider token expired

        Returns:
            Valid access token or None if no valid token available for this user.
            Errors are logged and reported as None rather than raised.
        """
        try:
            token_record = self._get_token_record(gateway_id, app_user_email)
            if not token_record:
                logger.debug(f"No OAuth tokens found for gateway {gateway_id}, app user {app_user_email}")
                return None

            if self._is_token_expired(token_record, threshold_seconds):
                logger.info(f"OAuth token expired for gateway {gateway_id}, app user {app_user_email}")
                if not token_record.refresh_token:
                    return None
                # A failed refresh yields None (falsy results are normalized to None).
                return await self._refresh_access_token(token_record) or None

            if self.encryption:
                return await self.encryption.decrypt_secret_async(token_record.access_token)
            return token_record.access_token

        except Exception as e:
            logger.error(f"Failed to retrieve OAuth token: {str(e)}")
            return None

    # REMOVED: get_any_valid_token() - This was a security vulnerability
    # All OAuth tokens MUST be user-specific to prevent cross-user token access

    @staticmethod
    def _normalize_resource(url: str, *, preserve_query: bool = False) -> str | None:
        """Normalize a resource URL per RFC 8707.

        The resource must be an absolute URI (it needs a scheme; URNs are
        accepted). The fragment is always removed. The query is kept only when
        ``preserve_query`` is True, which is used for explicitly configured values.

        Args:
            url: Resource URL to normalize
            preserve_query: If True, preserve query (for explicit config). If False, strip query.

        Returns:
            Normalized URL string, or None if invalid.

        Examples:
            >>> TokenStorageService._normalize_resource("https://example.com/mcp?x=1#frag")
            'https://example.com/mcp'
            >>> TokenStorageService._normalize_resource("https://example.com/mcp?x=1#frag", preserve_query=True)
            'https://example.com/mcp?x=1'
            >>> TokenStorageService._normalize_resource("urn:ietf:params:oauth:example")
            'urn:ietf:params:oauth:example'
            >>> TokenStorageService._normalize_resource("") is None
            True
        """
        # Standard
        from urllib.parse import urlparse, urlunparse  # pylint: disable=import-outside-toplevel

        if not url:
            return None
        parsed = urlparse(url)
        if not parsed.scheme:
            logger.warning(f"Invalid resource URL (must be absolute URI with scheme): {url}")
            return None
        query = parsed.query if preserve_query else ""
        return urlunparse((parsed.scheme, parsed.netloc, parsed.path, parsed.params, query, ""))

    async def _build_refresh_oauth_config(self, gateway: Any) -> Dict[str, Any]:
        """Build the OAuth configuration passed to the provider for a token refresh.

        Works on a shallow copy of ``gateway.oauth_config``:

        * ``client_secret`` is decrypted when encryption is available. If
          decryption fails, the value is assumed to be plain text already.
        * The RFC 8707 ``resource`` parameter is normalized. An explicitly
          configured string or list is normalized with its query preserved.
          Otherwise the value is derived from ``gateway.url`` with the query
          stripped.

        Args:
            gateway: Gateway record providing ``oauth_config`` (a dict) and ``url``.

        Returns:
            The OAuth configuration dictionary to use for the refresh request.
        """
        oauth_config = gateway.oauth_config.copy()

        if self.encryption and oauth_config.get("client_secret"):
            try:
                oauth_config["client_secret"] = await self.encryption.decrypt_secret_async(oauth_config["client_secret"])
            except Exception:  # nosec B110
                # If decryption fails, assume it's already plain text - intentional fallback
                pass

        existing_resource = oauth_config.get("resource")
        if existing_resource:
            if isinstance(existing_resource, list):
                normalized = [self._normalize_resource(r, preserve_query=True) for r in existing_resource]
                oauth_config["resource"] = [r for r in normalized if r]
                if not oauth_config["resource"]:
                    logger.warning(f"All {len(existing_resource)} configured resource values were invalid and removed during refresh")
            else:
                normalized_resource = self._normalize_resource(existing_resource, preserve_query=True)
                if not normalized_resource:
                    logger.warning(f"Configured resource was invalid and removed during refresh: {existing_resource}")
                oauth_config["resource"] = normalized_resource
        elif gateway.url:
            oauth_config["resource"] = self._normalize_resource(gateway.url)
            if not oauth_config["resource"]:
                logger.warning(f"Gateway URL is not a valid absolute URI, skipping resource parameter: {gateway.url}")

        return oauth_config

    async def _refresh_access_token(self, token_record: OAuthToken) -> Optional[str]:
        """Refresh an expired access token using refresh token.

        On success, the record's tokens and expiry are updated and committed.
        On failure, pending session changes are rolled back. If the error
        message mentions "invalid" or "expired", the record is also deleted so
        that the user must re-authenticate.

        Args:
            token_record: OAuth token record to refresh

        Returns:
            New access token or None if refresh failed
        """
        # Captured up front so error logging does not need to touch a possibly broken session.
        gateway_id = token_record.gateway_id
        try:
            if not token_record.refresh_token:
                logger.warning(f"No refresh token available for gateway {gateway_id}")
                return None

            # First-Party
            from mcpgateway.db import Gateway  # pylint: disable=import-outside-toplevel

            gateway = self.db.query(Gateway).filter(Gateway.id == gateway_id).first()
            if not gateway or not gateway.oauth_config:
                logger.error(f"No OAuth configuration found for gateway {gateway_id}")
                return None

            refresh_token = token_record.refresh_token
            if self.encryption:
                try:
                    refresh_token = await self.encryption.decrypt_secret_async(refresh_token)
                except Exception as e:
                    logger.error(f"Failed to decrypt refresh token: {str(e)}")
                    return None

            oauth_config = await self._build_refresh_oauth_config(gateway)

            # First-Party
            from mcpgateway.services.oauth_manager import OAuthManager  # pylint: disable=import-outside-toplevel

            oauth_manager = OAuthManager()
            logger.info(f"Attempting to refresh token for gateway {gateway_id}, user {token_record.app_user_email}")
            token_response = await oauth_manager.refresh_token(refresh_token, oauth_config)

            new_access_token = token_response["access_token"]
            # Some providers rotate the refresh token; others omit it or send null, so keep the current one then.
            new_refresh_token = token_response.get("refresh_token") or refresh_token
            expires_in = token_response.get("expires_in")
            if expires_in is None:
                expires_in = 3600

            encrypted_access = new_access_token
            encrypted_refresh = new_refresh_token
            if self.encryption:
                encrypted_access = await self.encryption.encrypt_secret_async(new_access_token)
                encrypted_refresh = await self.encryption.encrypt_secret_async(new_refresh_token)

            now = datetime.now(timezone.utc)
            token_record.access_token = encrypted_access
            token_record.refresh_token = encrypted_refresh
            token_record.expires_at = now + timedelta(seconds=int(expires_in))
            token_record.updated_at = now

            self.db.commit()
            logger.info(f"Successfully refreshed token for gateway {gateway_id}, user {token_record.app_user_email}")
            return new_access_token

        except Exception as e:
            logger.error(f"Failed to refresh OAuth token for gateway {gateway_id}: {str(e)}")
            # Discard partial changes (including a failed commit) so the session remains usable.
            self.db.rollback()
            message = str(e).lower()
            # If the provider rejected the refresh token, clear it to force re-authentication.
            if "invalid" in message or "expired" in message:
                logger.warning(f"Refresh token appears invalid/expired, clearing tokens for gateway {gateway_id}")
                try:
                    self.db.delete(token_record)
                    self.db.commit()
                except Exception as cleanup_error:
                    self.db.rollback()
                    logger.error(f"Failed to clear OAuth tokens for gateway {gateway_id}: {str(cleanup_error)}")
            return None

    def _is_token_expired(self, token_record: OAuthToken, threshold_seconds: int = 300) -> bool:
        """Check if token is expired or near expiration.

        Records without ``expires_at`` are treated as never expiring. Naive
        datetimes are interpreted as UTC.

        Args:
            token_record: OAuth token record to check
            threshold_seconds: Seconds before expiry to consider token expired

        Returns:
            True if token is expired or near expiration

        Examples:
            >>> from types import SimpleNamespace
            >>> from datetime import datetime, timedelta
            >>> svc = TokenStorageService(None)
            >>> future = datetime.now(tz=timezone.utc) + timedelta(seconds=600)
            >>> past = datetime.now(tz=timezone.utc) - timedelta(seconds=10)
            >>> rec_future = SimpleNamespace(expires_at=future)
            >>> rec_past = SimpleNamespace(expires_at=past)
            >>> svc._is_token_expired(rec_future, threshold_seconds=300)  # 10 min ahead, 5 min threshold
            False
            >>> svc._is_token_expired(rec_future, threshold_seconds=900)  # 10 min ahead, 15 min threshold
            True
            >>> svc._is_token_expired(rec_past, threshold_seconds=0)
            True
            >>> svc._is_token_expired(SimpleNamespace(expires_at=None))
            False
        """
        if not token_record.expires_at:
            return False
        expires_at = token_record.expires_at
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        return datetime.now(timezone.utc) + timedelta(seconds=threshold_seconds) >= expires_at

    async def get_token_info(self, gateway_id: str, app_user_email: str) -> Optional[Dict[str, Any]]:
        """Get information about stored OAuth tokens (never the token values themselves).

        Args:
            gateway_id: ID of the gateway
            app_user_email: MCP Gateway user email

        Returns:
            Token information dictionary or None if not found (or on error)

        Examples:
            >>> from types import SimpleNamespace
            >>> from datetime import datetime, timedelta
            >>> svc = TokenStorageService(None)
            >>> now = datetime.now(tz=timezone.utc)
            >>> future = now + timedelta(seconds=60)
            >>> rec = SimpleNamespace(user_id='u1', app_user_email='u1', token_type='bearer', expires_at=future, scopes=['s1'], created_at=now, updated_at=now)
            >>> class _Res:
            ...     def scalar_one_or_none(self):
            ...         return rec
            >>> class _DB:
            ...     def execute(self, *_args, **_kw):
            ...         return _Res()
            >>> svc.db = _DB()
            >>> import asyncio
            >>> info = asyncio.run(svc.get_token_info('g1', 'u1'))
            >>> info['user_id']
            'u1'
            >>> isinstance(info['is_expired'], bool)
            True
        """
        try:
            token_record = self._get_token_record(gateway_id, app_user_email)
            if not token_record:
                return None

            return {
                "user_id": token_record.user_id,  # OAuth provider user ID
                "app_user_email": token_record.app_user_email,  # MCP Gateway user
                "token_type": token_record.token_type,
                "expires_at": token_record.expires_at.isoformat() if token_record.expires_at else None,
                "scopes": token_record.scopes,
                "created_at": token_record.created_at.isoformat() if token_record.created_at else None,
                "updated_at": token_record.updated_at.isoformat() if token_record.updated_at else None,
                "is_expired": self._is_token_expired(token_record, 0),
            }

        except Exception as e:
            logger.error(f"Failed to get token info: {str(e)}")
            return None

    async def revoke_user_tokens(self, gateway_id: str, app_user_email: str) -> bool:
        """Revoke OAuth tokens for a specific user by deleting the stored record.

        Args:
            gateway_id: ID of the gateway
            app_user_email: MCP Gateway user email

        Returns:
            True if tokens were revoked successfully; False if none were stored or deletion failed

        Examples:
            >>> from types import SimpleNamespace
            >>> from unittest.mock import MagicMock
            >>> svc = TokenStorageService(MagicMock())
            >>> rec = SimpleNamespace()
            >>> svc.db.execute.return_value.scalar_one_or_none.return_value = rec
            >>> svc.db.delete = lambda obj: None
            >>> svc.db.commit = lambda: None
            >>> import asyncio
            >>> asyncio.run(svc.revoke_user_tokens('g1', 'u1'))
            True
            >>> # Not found
            >>> svc.db.execute.return_value.scalar_one_or_none.return_value = None
            >>> asyncio.run(svc.revoke_user_tokens('g1', 'u1'))
            False
        """
        try:
            token_record = self._get_token_record(gateway_id, app_user_email)
            if not token_record:
                return False

            self.db.delete(token_record)
            self.db.commit()
            logger.info(f"Revoked OAuth tokens for gateway {gateway_id}, user {app_user_email}")
            return True

        except Exception as e:
            self.db.rollback()
            logger.error(f"Failed to revoke OAuth tokens: {str(e)}")
            return False

    async def cleanup_expired_tokens(self, max_age_days: int = 30) -> int:
        """Delete OAuth tokens that expired more than ``max_age_days`` days ago.

        Uses a single SQL DELETE statement instead of loading tokens into memory
        and deleting them one by one. This is more efficient and avoids memory
        issues when many tokens expire at once. Rows with no ``expires_at`` do
        not match the comparison and are kept.

        Args:
            max_age_days: Maximum age of tokens to keep

        Returns:
            Number of tokens cleaned up (0 on error)

        Examples:
            >>> from unittest.mock import MagicMock
            >>> svc = TokenStorageService(MagicMock())
            >>> svc.db.execute.return_value.rowcount = 2
            >>> svc.db.commit = lambda: None
            >>> import asyncio
            >>> asyncio.run(svc.cleanup_expired_tokens(1))
            2
        """
        try:
            cutoff_date = datetime.now(tz=timezone.utc) - timedelta(days=max_age_days)
            result = self.db.execute(delete(OAuthToken).where(OAuthToken.expires_at < cutoff_date))
            count = result.rowcount
            self.db.commit()

            if count > 0:
                logger.info(f"Cleaned up {count} expired OAuth tokens")
            return count

        except Exception as e:
            self.db.rollback()
            logger.error(f"Failed to cleanup expired tokens: {str(e)}")
            return 0