Coverage for mcpgateway / services / token_storage_service.py: 99%

184 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/services/token_storage_service.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7OAuth Token Storage Service for ContextForge. 

8 

9This module handles the storage, retrieval, and management of OAuth access and refresh tokens 

10for Authorization Code flow implementations. 

11""" 

12 

13# Standard 

14from datetime import datetime, timedelta, timezone 

15import logging 

16from typing import Any, Dict, List, Optional 

17 

18# Third-Party 

19from sqlalchemy import delete, select 

20from sqlalchemy.orm import Session 

21 

22# First-Party 

23from mcpgateway.common.validators import SecurityValidator 

24from mcpgateway.config import get_settings 

25from mcpgateway.db import OAuthToken 

26from mcpgateway.services.encryption_service import get_encryption_service 

27from mcpgateway.services.oauth_manager import OAuthError 

28 

# Module logger, named after this module so its records can be filtered per service.
logger = logging.getLogger(name=__name__)

30 

31 

class TokenStorageService:
    """Manages OAuth token storage and retrieval.

    Tokens are stored per (gateway, ContextForge user) pair. When an encryption
    service is available, access and refresh tokens are encrypted before they are
    written and decrypted when they are read back.

    Examples:
        >>> service = TokenStorageService(None)  # Mock DB for doctest
        >>> service.db is None
        True
        >>> service.encryption is not None or service.encryption is None  # Encryption may or may not be available
        True
        >>> # Test token expiration calculation
        >>> from datetime import datetime, timedelta
        >>> expires_in = 3600  # 1 hour
        >>> now = datetime.now(tz=timezone.utc)
        >>> expires_at = now + timedelta(seconds=expires_in)
        >>> expires_at > now
        True
        >>> # Test scope list handling
        >>> scopes = ["read", "write", "admin"]
        >>> isinstance(scopes, list)
        True
        >>> "read" in scopes
        True
        >>> # Test token encryption detection
        >>> short_token = "abc123"
        >>> len(short_token) < 100
        True
        >>> encrypted_token = "gAAAAABh" + "x" * 100
        >>> len(encrypted_token) > 100
        True
    """

    def __init__(self, db: Session):
        """Initialize Token Storage Service.

        Args:
            db: Database session used for all token reads and writes.
        """
        self.db = db
        try:
            settings = get_settings()
            self.encryption = get_encryption_service(settings.auth_encryption_secret)
        except (ImportError, AttributeError):
            # Fall back to storing tokens unencrypted rather than failing construction.
            logger.warning("OAuth encryption not available, using plain text storage")
            self.encryption = None

    async def store_tokens(self, gateway_id: str, user_id: str, app_user_email: str, access_token: str, refresh_token: Optional[str], expires_in: int, scopes: List[str]) -> OAuthToken:
        """Store OAuth tokens for a gateway-user combination.

        An existing record for the same (gateway_id, app_user_email) pair is updated in
        place; otherwise a new record is created. The change is committed immediately.

        Args:
            gateway_id: ID of the gateway
            user_id: OAuth provider user ID
            app_user_email: ContextForge user email (required)
            access_token: Access token from OAuth provider
            refresh_token: Refresh token from OAuth provider (optional)
            expires_in: Token expiration time in seconds
            scopes: List of OAuth scopes granted

        Returns:
            OAuthToken record

        Raises:
            OAuthError: If token storage fails (the original exception is chained as the cause)
        """
        try:
            # Encrypt sensitive tokens if encryption is available
            encrypted_access = access_token
            encrypted_refresh = refresh_token

            if self.encryption:
                encrypted_access = await self.encryption.encrypt_secret_async(access_token)
                if refresh_token:
                    encrypted_refresh = await self.encryption.encrypt_secret_async(refresh_token)

            # Calculate expiration
            expires_at = datetime.now(timezone.utc) + timedelta(seconds=int(expires_in))
            # Create or update token record - scoped by app_user_email
            token_record = self.db.execute(select(OAuthToken).where(OAuthToken.gateway_id == gateway_id, OAuthToken.app_user_email == app_user_email)).scalar_one_or_none()

            if token_record:
                # Update existing record
                token_record.user_id = user_id  # Update OAuth provider ID in case it changed
                token_record.access_token = encrypted_access
                token_record.refresh_token = encrypted_refresh
                token_record.expires_at = expires_at
                token_record.scopes = scopes
                token_record.updated_at = datetime.now(timezone.utc)
                logger.info(
                    f"Updated OAuth tokens for gateway {SecurityValidator.sanitize_log_message(gateway_id)}, app user {SecurityValidator.sanitize_log_message(app_user_email)}, OAuth user {SecurityValidator.sanitize_log_message(user_id)}"
                )
            else:
                # Create new record
                token_record = OAuthToken(
                    gateway_id=gateway_id, user_id=user_id, app_user_email=app_user_email, access_token=encrypted_access, refresh_token=encrypted_refresh, expires_at=expires_at, scopes=scopes
                )
                self.db.add(token_record)
                logger.info(
                    f"Stored new OAuth tokens for gateway {SecurityValidator.sanitize_log_message(gateway_id)}, app user {SecurityValidator.sanitize_log_message(app_user_email)}, OAuth user {SecurityValidator.sanitize_log_message(user_id)}"
                )

            self.db.commit()
            return token_record

        except Exception as e:
            self.db.rollback()
            logger.error(f"Failed to store OAuth tokens: {str(e)}")
            # Chain the cause so the underlying DB/encryption error stays visible in tracebacks.
            raise OAuthError(f"Token storage failed: {str(e)}") from e

    async def get_user_token(self, gateway_id: str, app_user_email: str, threshold_seconds: int = 300) -> Optional[str]:
        """Get a valid access token for a specific ContextForge user, refreshing if necessary.

        Args:
            gateway_id: ID of the gateway
            app_user_email: ContextForge user email (required)
            threshold_seconds: Seconds before expiry to consider token expired

        Returns:
            Valid (decrypted) access token, or None if no valid token is available for this
            user. Errors are logged and reported as None rather than raised.
        """
        try:
            token_record = self.db.execute(select(OAuthToken).where(OAuthToken.gateway_id == gateway_id, OAuthToken.app_user_email == app_user_email)).scalar_one_or_none()

            if not token_record:
                logger.debug(f"No OAuth tokens found for gateway {SecurityValidator.sanitize_log_message(gateway_id)}, app user {SecurityValidator.sanitize_log_message(app_user_email)}")
                return None

            # Check if token is expired or near expiration
            if self._is_token_expired(token_record, threshold_seconds):
                logger.info(f"OAuth token expired for gateway {SecurityValidator.sanitize_log_message(gateway_id)}, app user {SecurityValidator.sanitize_log_message(app_user_email)}")
                if token_record.refresh_token:
                    # Attempt to refresh token
                    new_token = await self._refresh_access_token(token_record)
                    if new_token:
                        return new_token
                return None

            # Decrypt and return valid token
            if self.encryption:
                return await self.encryption.decrypt_secret_async(token_record.access_token)
            return token_record.access_token

        except Exception as e:
            logger.error(f"Failed to retrieve OAuth token: {str(e)}")
            return None

    # REMOVED: get_any_valid_token() - This was a security vulnerability
    # All OAuth tokens MUST be user-specific to prevent cross-user token access

    @staticmethod
    def _normalize_resource(url: Optional[str], *, preserve_query: bool = False) -> Optional[str]:
        """Normalize a resource indicator URL per RFC 8707.

        The fragment is always removed (RFC 8707 forbids it). The query string is kept
        only when ``preserve_query`` is true, which is used for explicitly configured values.
        Values derived from the gateway URL have the query stripped.

        Args:
            url: Resource URL to normalize.
            preserve_query: If True, preserve the query (explicit config). If False, strip it.

        Returns:
            Normalized URL string, or None if ``url`` is empty or not an absolute URI.

        Examples:
            >>> TokenStorageService._normalize_resource("https://api.example.com/mcp?x=1#frag")
            'https://api.example.com/mcp'
            >>> TokenStorageService._normalize_resource("https://api.example.com/mcp?x=1#frag", preserve_query=True)
            'https://api.example.com/mcp?x=1'
            >>> TokenStorageService._normalize_resource("urn:example:resource")
            'urn:example:resource'
            >>> TokenStorageService._normalize_resource("") is None
            True
        """
        # Standard
        from urllib.parse import urlparse, urlunparse  # pylint: disable=import-outside-toplevel

        if not url:
            return None
        parsed = urlparse(url)
        # RFC 8707: resource MUST be an absolute URI (requires a scheme).
        # Both hierarchical URIs and URNs are supported.
        if not parsed.scheme:
            logger.warning(f"Invalid resource URL (must be absolute URI with scheme): {url}")
            return None
        query = parsed.query if preserve_query else ""
        return urlunparse((parsed.scheme, parsed.netloc, parsed.path, parsed.params, query, ""))

    @classmethod
    def _apply_resource_parameter(cls, oauth_config: Dict[str, Any], gateway_url: Optional[str]) -> None:
        """Set the RFC 8707 ``resource`` entry of ``oauth_config`` in place for a refresh request.

        An explicitly configured ``resource`` (string or list) is normalized with its query
        preserved, and invalid entries are dropped. If no resource is configured and
        ``gateway_url`` is set, the resource is derived from the gateway URL with its query
        stripped. An invalid value results in ``None`` (or an empty list) being stored.

        Args:
            oauth_config: Mutable copy of the gateway's OAuth configuration.
            gateway_url: The gateway's URL, used as a fallback resource.

        Examples:
            >>> cfg = {"resource": ["https://a.example/x?q=1#f"]}
            >>> TokenStorageService._apply_resource_parameter(cfg, "https://gw.example/mcp")
            >>> cfg["resource"]
            ['https://a.example/x?q=1']
            >>> cfg = {}
            >>> TokenStorageService._apply_resource_parameter(cfg, "https://gw.example/mcp?token=abc")
            >>> cfg["resource"]
            'https://gw.example/mcp'
        """
        existing_resource = oauth_config.get("resource")
        if existing_resource:
            # Explicit configuration: preserve the query component.
            if isinstance(existing_resource, list):
                normalized_values = [cls._normalize_resource(r, preserve_query=True) for r in existing_resource]
                oauth_config["resource"] = [r for r in normalized_values if r]
                if not oauth_config["resource"]:
                    logger.warning(f"All {len(existing_resource)} configured resource values were invalid and removed during refresh")
            else:
                normalized = cls._normalize_resource(existing_resource, preserve_query=True)
                if not normalized:
                    logger.warning(f"Configured resource was invalid and removed during refresh: {existing_resource}")
                oauth_config["resource"] = normalized
        elif gateway_url:
            # Derive from the gateway URL when not explicitly configured (strip the query).
            oauth_config["resource"] = cls._normalize_resource(gateway_url)
            if not oauth_config.get("resource"):
                logger.warning(f"Gateway URL is not a valid absolute URI, skipping resource parameter: {gateway_url}")

    async def _decrypt_client_secret(self, oauth_config: Dict[str, Any]) -> None:
        """Decrypt ``oauth_config["client_secret"]`` in place when encryption is enabled.

        This is best-effort. If decryption fails, the value is left unchanged on the
        assumption that it is already stored as plain text.

        Args:
            oauth_config: Mutable copy of the gateway's OAuth configuration.
        """
        if not self.encryption or not oauth_config.get("client_secret"):
            return
        try:
            oauth_config["client_secret"] = await self.encryption.decrypt_secret_async(oauth_config["client_secret"])
        except Exception:  # nosec B110 - intentional plain-text fallback
            logger.debug("client_secret could not be decrypted; assuming it is stored as plain text")

    async def _refresh_access_token(self, token_record: OAuthToken) -> Optional[str]:
        """Refresh an expired access token using the stored refresh token.

        On success, the record's tokens (encrypted when possible), expiry and
        ``updated_at`` are updated and committed. On failure, the session is rolled back.
        If the error message mentions "invalid" or "expired", the record is also deleted
        to force re-authentication. No exception escapes this method.

        Args:
            token_record: OAuth token record to refresh

        Returns:
            New (plain text) access token, or None if the refresh failed
        """
        # Read identifiers before any DB work. A rollback expires loaded ORM instances,
        # and reloading attributes inside the error handler could itself fail.
        gateway_id = token_record.gateway_id
        safe_gateway = SecurityValidator.sanitize_log_message(gateway_id)
        safe_user = SecurityValidator.sanitize_log_message(token_record.app_user_email)

        try:
            if not token_record.refresh_token:
                logger.warning(f"No refresh token available for gateway {safe_gateway}")
                return None

            # Get the gateway configuration to retrieve OAuth settings
            # First-Party
            from mcpgateway.db import Gateway  # pylint: disable=import-outside-toplevel

            gateway = self.db.query(Gateway).filter(Gateway.id == gateway_id).first()
            if not gateway or not gateway.oauth_config:
                logger.error(f"No OAuth configuration found for gateway {safe_gateway}")
                return None

            # Decrypt the refresh token if encryption is available
            refresh_token = token_record.refresh_token
            if self.encryption:
                try:
                    refresh_token = await self.encryption.decrypt_secret_async(refresh_token)
                except Exception as e:
                    logger.error(f"Failed to decrypt refresh token: {str(e)}")
                    return None

            # Work on a shallow copy so the gateway's stored configuration is not modified.
            oauth_config = gateway.oauth_config.copy()
            await self._decrypt_client_secret(oauth_config)
            # RFC 8707: set the resource parameter for JWT access tokens during refresh.
            self._apply_resource_parameter(oauth_config, gateway.url)

            # First-Party
            from mcpgateway.services.oauth_manager import OAuthManager  # pylint: disable=import-outside-toplevel

            logger.info(f"Attempting to refresh token for gateway {safe_gateway}, user {safe_user}")
            token_response = await OAuthManager().refresh_token(refresh_token, oauth_config)

            new_access_token = token_response["access_token"]
            # Providers that do not rotate refresh tokens omit the field (or send null);
            # keep the current refresh token in that case instead of storing None.
            new_refresh_token = token_response.get("refresh_token") or refresh_token
            expires_in = token_response.get("expires_in")
            if expires_in is None:
                expires_in = 3600  # Default lifetime when the provider does not report one

            # Encrypt new tokens if encryption is available
            encrypted_access = new_access_token
            encrypted_refresh = new_refresh_token
            if self.encryption:
                encrypted_access = await self.encryption.encrypt_secret_async(new_access_token)
                encrypted_refresh = await self.encryption.encrypt_secret_async(new_refresh_token)

            now = datetime.now(timezone.utc)
            token_record.access_token = encrypted_access
            token_record.refresh_token = encrypted_refresh
            token_record.expires_at = now + timedelta(seconds=int(expires_in))
            token_record.updated_at = now

            self.db.commit()
            logger.info(f"Successfully refreshed token for gateway {safe_gateway}, user {safe_user}")

            return new_access_token

        except Exception as e:
            # Discard partial changes (e.g. from a failed commit) so the session stays usable,
            # consistent with the other write paths in this service.
            self.db.rollback()
            logger.error(f"Failed to refresh OAuth token for gateway {safe_gateway}: {str(e)}")
            error_text = str(e).lower()
            # If the refresh token was rejected, clear the tokens to force re-authentication.
            if "invalid" in error_text or "expired" in error_text:
                logger.warning(f"Refresh token appears invalid/expired, clearing tokens for gateway {safe_gateway}")
                try:
                    self.db.delete(token_record)
                    self.db.commit()
                except Exception as cleanup_error:
                    self.db.rollback()
                    logger.error(f"Failed to clear OAuth tokens for gateway {safe_gateway}: {str(cleanup_error)}")
            return None

    def _is_token_expired(self, token_record: OAuthToken, threshold_seconds: int = 300) -> bool:
        """Check if token is expired or near expiration.

        A record without ``expires_at`` is treated as never expiring. Naive datetimes
        are interpreted as UTC.

        Args:
            token_record: OAuth token record to check
            threshold_seconds: Seconds before expiry to consider token expired

        Returns:
            True if token is expired or near expiration

        Examples:
            >>> from types import SimpleNamespace
            >>> from datetime import datetime, timedelta
            >>> svc = TokenStorageService(None)
            >>> future = datetime.now(tz=timezone.utc) + timedelta(seconds=600)
            >>> past = datetime.now(tz=timezone.utc) - timedelta(seconds=10)
            >>> rec_future = SimpleNamespace(expires_at=future)
            >>> rec_past = SimpleNamespace(expires_at=past)
            >>> svc._is_token_expired(rec_future, threshold_seconds=300)  # 10 min ahead, 5 min threshold
            False
            >>> svc._is_token_expired(rec_future, threshold_seconds=900)  # 10 min ahead, 15 min threshold
            True
            >>> svc._is_token_expired(rec_past, threshold_seconds=0)
            True
            >>> svc._is_token_expired(SimpleNamespace(expires_at=None))
            False
        """
        if not token_record.expires_at:
            return False
        expires_at = token_record.expires_at
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        return datetime.now(timezone.utc) + timedelta(seconds=threshold_seconds) >= expires_at

    async def get_token_info(self, gateway_id: str, app_user_email: str) -> Optional[Dict[str, Any]]:
        """Get information about stored OAuth tokens (never the token values themselves).

        Args:
            gateway_id: ID of the gateway
            app_user_email: ContextForge user email

        Returns:
            Token information dictionary, or None if not found or on error. Timestamp
            fields are ISO-8601 strings, or None when the record has no value.

        Examples:
            >>> from types import SimpleNamespace
            >>> from datetime import datetime, timedelta
            >>> svc = TokenStorageService(None)
            >>> now = datetime.now(tz=timezone.utc)
            >>> future = now + timedelta(seconds=60)
            >>> rec = SimpleNamespace(user_id='u1', app_user_email='u1', token_type='bearer', expires_at=future, scopes=['s1'], created_at=now, updated_at=now)
            >>> class _Res:
            ...     def scalar_one_or_none(self):
            ...         return rec
            >>> class _DB:
            ...     def execute(self, *_args, **_kw):
            ...         return _Res()
            >>> svc.db = _DB()
            >>> import asyncio
            >>> info = asyncio.run(svc.get_token_info('g1', 'u1'))
            >>> info['user_id']
            'u1'
            >>> isinstance(info['is_expired'], bool)
            True
        """
        try:
            token_record = self.db.execute(select(OAuthToken).where(OAuthToken.gateway_id == gateway_id, OAuthToken.app_user_email == app_user_email)).scalar_one_or_none()

            if not token_record:
                return None

            return {
                "user_id": token_record.user_id,  # OAuth provider user ID
                "app_user_email": token_record.app_user_email,  # ContextForge user
                "token_type": token_record.token_type,
                "expires_at": token_record.expires_at.isoformat() if token_record.expires_at else None,
                "scopes": token_record.scopes,
                # Guard like expires_at: a missing timestamp must not hide the whole record.
                "created_at": token_record.created_at.isoformat() if token_record.created_at else None,
                "updated_at": token_record.updated_at.isoformat() if token_record.updated_at else None,
                "is_expired": self._is_token_expired(token_record, 0),
            }

        except Exception as e:
            logger.error(f"Failed to get token info: {str(e)}")
            return None

    async def revoke_user_tokens(self, gateway_id: str, app_user_email: str) -> bool:
        """Revoke OAuth tokens for a specific user by deleting the stored record.

        Args:
            gateway_id: ID of the gateway
            app_user_email: ContextForge user email

        Returns:
            True if a record was found and deleted; False if none existed or deletion failed

        Examples:
            >>> from types import SimpleNamespace
            >>> from unittest.mock import MagicMock
            >>> svc = TokenStorageService(MagicMock())
            >>> rec = SimpleNamespace()
            >>> svc.db.execute.return_value.scalar_one_or_none.return_value = rec
            >>> svc.db.delete = lambda obj: None
            >>> svc.db.commit = lambda: None
            >>> import asyncio
            >>> asyncio.run(svc.revoke_user_tokens('g1', 'u1'))
            True
            >>> # Not found
            >>> svc.db.execute.return_value.scalar_one_or_none.return_value = None
            >>> asyncio.run(svc.revoke_user_tokens('g1', 'u1'))
            False
        """
        try:
            token_record = self.db.execute(select(OAuthToken).where(OAuthToken.gateway_id == gateway_id, OAuthToken.app_user_email == app_user_email)).scalar_one_or_none()

            if token_record:
                self.db.delete(token_record)
                self.db.commit()
                logger.info(f"Revoked OAuth tokens for gateway {SecurityValidator.sanitize_log_message(gateway_id)}, user {SecurityValidator.sanitize_log_message(app_user_email)}")
                return True

            return False

        except Exception as e:
            self.db.rollback()
            logger.error(f"Failed to revoke OAuth tokens: {str(e)}")
            return False

    async def cleanup_expired_tokens(self, max_age_days: int = 30) -> int:
        """Clean up OAuth tokens that expired more than ``max_age_days`` ago.

        Uses a single SQL DELETE statement instead of loading tokens into memory
        and deleting them one by one. This is more efficient and avoids memory
        issues when many tokens expire at once.

        Args:
            max_age_days: Maximum age of tokens to keep

        Returns:
            Number of tokens cleaned up (0 on error)

        Examples:
            >>> from unittest.mock import MagicMock
            >>> svc = TokenStorageService(MagicMock())
            >>> svc.db.execute.return_value.rowcount = 2
            >>> svc.db.commit = lambda: None
            >>> import asyncio
            >>> asyncio.run(svc.cleanup_expired_tokens(1))
            2
        """
        try:
            cutoff_date = datetime.now(tz=timezone.utc) - timedelta(days=max_age_days)

            result = self.db.execute(delete(OAuthToken).where(OAuthToken.expires_at < cutoff_date))
            count = result.rowcount

            self.db.commit()

            if count > 0:
                logger.info(f"Cleaned up {count} expired OAuth tokens")

            return count

        except Exception as e:
            self.db.rollback()
            logger.error(f"Failed to cleanup expired tokens: {str(e)}")
            return 0