Coverage for mcpgateway / services / token_storage_service.py: 98%

183 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/services/token_storage_service.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7OAuth Token Storage Service for MCP Gateway. 

8 

9This module handles the storage, retrieval, and management of OAuth access and refresh tokens 

10for Authorization Code flow implementations. 

11""" 

12 

13# Standard 

14from datetime import datetime, timedelta, timezone 

15import logging 

16from typing import Any, Dict, List, Optional 

17 

18# Third-Party 

19from sqlalchemy import delete, select 

20from sqlalchemy.orm import Session 

21 

22# First-Party 

23from mcpgateway.config import get_settings 

24from mcpgateway.db import OAuthToken 

25from mcpgateway.services.encryption_service import get_encryption_service 

26from mcpgateway.services.oauth_manager import OAuthError 

27 

# Module-scoped logger: named after this module so its records can be filtered/configured per component.
logger: logging.Logger = logging.getLogger(__name__)

29 

30 

class TokenStorageService:
    """Manages OAuth token storage and retrieval.

    Tokens are stored per (gateway_id, app_user_email) pair. When an encryption
    service is available, access and refresh tokens are encrypted before they are
    persisted and decrypted when read back; otherwise they are stored as given.

    Examples:
        >>> service = TokenStorageService(None)  # Mock DB for doctest
        >>> service.db is None
        True
        >>> service.encryption is not None or service.encryption is None  # Encryption may or may not be available
        True
        >>> # Token expiration calculation
        >>> from datetime import datetime, timedelta, timezone
        >>> expires_in = 3600  # 1 hour
        >>> now = datetime.now(tz=timezone.utc)
        >>> expires_at = now + timedelta(seconds=expires_in)
        >>> expires_at > now
        True
        >>> # RFC 8707 resource normalization: fragment always dropped, query kept only on request
        >>> TokenStorageService._normalize_resource("https://api.example.com/mcp?x=1#frag")
        'https://api.example.com/mcp'
        >>> TokenStorageService._normalize_resource("https://api.example.com/mcp?x=1#frag", preserve_query=True)
        'https://api.example.com/mcp?x=1'
        >>> TokenStorageService._normalize_resource("urn:example:resource")
        'urn:example:resource'
        >>> TokenStorageService._normalize_resource("") is None
        True
    """

    def __init__(self, db: Session):
        """Initialize Token Storage Service.

        Args:
            db: Database session used for all token reads and writes.
        """
        self.db = db
        try:
            settings = get_settings()
            self.encryption = get_encryption_service(settings.auth_encryption_secret)
        except (ImportError, AttributeError):
            logger.warning("OAuth encryption not available, using plain text storage")
            self.encryption = None

    def _get_token_record(self, gateway_id: str, app_user_email: str) -> Optional[OAuthToken]:
        """Load the stored token record for one gateway/app-user pair.

        Args:
            gateway_id: ID of the gateway
            app_user_email: MCP Gateway user email

        Returns:
            The matching OAuthToken record, or None if none is stored.
        """
        stmt = select(OAuthToken).where(OAuthToken.gateway_id == gateway_id, OAuthToken.app_user_email == app_user_email)
        return self.db.execute(stmt).scalar_one_or_none()

    async def store_tokens(self, gateway_id: str, user_id: str, app_user_email: str, access_token: str, refresh_token: Optional[str], expires_in: int, scopes: List[str]) -> OAuthToken:
        """Store OAuth tokens for a gateway-user combination.

        An existing record for the same (gateway_id, app_user_email) is updated in
        place; otherwise a new record is created. The session is committed on
        success and rolled back on any failure.

        Args:
            gateway_id: ID of the gateway
            user_id: OAuth provider user ID
            app_user_email: MCP Gateway user email (required)
            access_token: Access token from OAuth provider
            refresh_token: Refresh token from OAuth provider (optional)
            expires_in: Token expiration time in seconds
            scopes: List of OAuth scopes granted

        Returns:
            OAuthToken record

        Raises:
            OAuthError: If token storage fails
        """
        try:
            # Encrypt sensitive tokens if encryption is available
            encrypted_access = access_token
            encrypted_refresh = refresh_token
            if self.encryption:
                encrypted_access = await self.encryption.encrypt_secret_async(access_token)
                if refresh_token:
                    encrypted_refresh = await self.encryption.encrypt_secret_async(refresh_token)

            now = datetime.now(timezone.utc)
            expires_at = now + timedelta(seconds=int(expires_in))

            # Create or update token record - scoped by app_user_email
            token_record = self._get_token_record(gateway_id, app_user_email)
            if token_record:
                token_record.user_id = user_id  # OAuth provider ID may change between authorizations
                token_record.access_token = encrypted_access
                token_record.refresh_token = encrypted_refresh
                token_record.expires_at = expires_at
                token_record.scopes = scopes
                token_record.updated_at = now
                logger.info(f"Updated OAuth tokens for gateway {gateway_id}, app user {app_user_email}, OAuth user {user_id}")
            else:
                token_record = OAuthToken(
                    gateway_id=gateway_id, user_id=user_id, app_user_email=app_user_email, access_token=encrypted_access, refresh_token=encrypted_refresh, expires_at=expires_at, scopes=scopes
                )
                self.db.add(token_record)
                logger.info(f"Stored new OAuth tokens for gateway {gateway_id}, app user {app_user_email}, OAuth user {user_id}")

            self.db.commit()
            return token_record

        except Exception as e:
            self.db.rollback()
            logger.error(f"Failed to store OAuth tokens: {str(e)}")
            # Chain the original exception so the root cause is kept in tracebacks
            raise OAuthError(f"Token storage failed: {str(e)}") from e

    async def get_user_token(self, gateway_id: str, app_user_email: str, threshold_seconds: int = 300) -> Optional[str]:
        """Get a valid access token for a specific MCP Gateway user, refreshing if necessary.

        Args:
            gateway_id: ID of the gateway
            app_user_email: MCP Gateway user email (required)
            threshold_seconds: Seconds before expiry to consider token expired

        Returns:
            Valid (decrypted) access token, or None if no valid token is available
            for this user. Errors are logged and reported as None.
        """
        try:
            token_record = self._get_token_record(gateway_id, app_user_email)
            if not token_record:
                logger.debug(f"No OAuth tokens found for gateway {gateway_id}, app user {app_user_email}")
                return None

            if self._is_token_expired(token_record, threshold_seconds):
                logger.info(f"OAuth token expired for gateway {gateway_id}, app user {app_user_email}")
                if token_record.refresh_token:
                    new_token = await self._refresh_access_token(token_record)
                    if new_token:
                        return new_token
                return None

            if self.encryption:
                return await self.encryption.decrypt_secret_async(token_record.access_token)
            return token_record.access_token

        except Exception as e:
            logger.error(f"Failed to retrieve OAuth token: {str(e)}")
            return None

    # REMOVED: get_any_valid_token() - This was a security vulnerability
    # All OAuth tokens MUST be user-specific to prevent cross-user token access

    @staticmethod
    def _normalize_resource(url: str, *, preserve_query: bool = False) -> Optional[str]:
        """Normalize a resource URL per RFC 8707.

        The fragment is always removed. The query is kept only when
        ``preserve_query`` is True, which is used for explicitly configured values.

        Args:
            url: Resource URL to normalize
            preserve_query: If True, preserve query (for explicit config). If False, strip query.

        Returns:
            Normalized URL string, or None if empty or not an absolute URI.
        """
        # Standard
        from urllib.parse import urlparse, urlunparse  # pylint: disable=import-outside-toplevel

        if not url:
            return None
        parsed = urlparse(url)
        # RFC 8707: resource MUST be an absolute URI (requires a scheme); URNs qualify too
        if not parsed.scheme:
            logger.warning(f"Invalid resource URL (must be absolute URI with scheme): {url}")
            return None
        query = parsed.query if preserve_query else ""
        return urlunparse((parsed.scheme, parsed.netloc, parsed.path, parsed.params, query, ""))

    async def _build_refresh_oauth_config(self, gateway: Any) -> Dict[str, Any]:
        """Prepare a per-request copy of a gateway's OAuth config for a token refresh.

        The gateway's own ``oauth_config`` is not mutated. ``client_secret`` is
        decrypted when possible. The RFC 8707 ``resource`` parameter is normalized
        when configured, or otherwise derived from ``gateway.url``.

        Args:
            gateway: Gateway record exposing ``oauth_config`` and ``url``.

        Returns:
            Shallow copy of ``gateway.oauth_config`` ready for the refresh request.
        """
        oauth_config = gateway.oauth_config.copy()

        if oauth_config.get("client_secret") and self.encryption:
            try:
                oauth_config["client_secret"] = await self.encryption.decrypt_secret_async(oauth_config["client_secret"])
            except Exception:
                # Deliberate fallback: a secret that cannot be decrypted is assumed to be plain text.
                logger.debug("client_secret could not be decrypted; using configured value as plain text")

        existing_resource = oauth_config.get("resource")
        if existing_resource:
            # Explicitly configured resource(s): normalize but keep the query string
            if isinstance(existing_resource, list):
                normalized = [self._normalize_resource(r, preserve_query=True) for r in existing_resource]
                oauth_config["resource"] = [r for r in normalized if r]
                if not oauth_config["resource"]:
                    logger.warning(f"All {len(existing_resource)} configured resource values were invalid and removed during refresh")
            else:
                normalized_single = self._normalize_resource(existing_resource, preserve_query=True)
                if not normalized_single:
                    logger.warning(f"Configured resource was invalid and removed during refresh: {existing_resource}")
                oauth_config["resource"] = normalized_single
        elif gateway.url:
            # Derive from gateway.url when not explicitly configured (query stripped)
            oauth_config["resource"] = self._normalize_resource(gateway.url)
            if not oauth_config.get("resource"):
                logger.warning(f"Gateway URL is not a valid absolute URI, skipping resource parameter: {gateway.url}")

        return oauth_config

    async def _refresh_access_token(self, token_record: OAuthToken) -> Optional[str]:
        """Refresh an expired access token using the stored refresh token.

        On success the record is updated with the new (encrypted) tokens and
        expiry, then committed. On failure None is returned. If the error message
        mentions "invalid" or "expired", the stored record is deleted so the user
        is forced to re-authenticate.

        Args:
            token_record: OAuth token record to refresh

        Returns:
            New access token (plain text) or None if refresh failed
        """
        gateway_id = token_record.gateway_id
        try:
            if not token_record.refresh_token:
                logger.warning(f"No refresh token available for gateway {gateway_id}")
                return None

            # First-Party
            from mcpgateway.db import Gateway  # pylint: disable=import-outside-toplevel

            gateway = self.db.query(Gateway).filter(Gateway.id == gateway_id).first()
            if not gateway or not gateway.oauth_config:
                logger.error(f"No OAuth configuration found for gateway {gateway_id}")
                return None

            refresh_token = token_record.refresh_token
            if self.encryption:
                try:
                    refresh_token = await self.encryption.decrypt_secret_async(refresh_token)
                except Exception as e:
                    logger.error(f"Failed to decrypt refresh token: {str(e)}")
                    return None

            oauth_config = await self._build_refresh_oauth_config(gateway)

            # First-Party
            from mcpgateway.services.oauth_manager import OAuthManager  # pylint: disable=import-outside-toplevel

            oauth_manager = OAuthManager()

            logger.info(f"Attempting to refresh token for gateway {gateway_id}, user {token_record.app_user_email}")
            token_response = await oauth_manager.refresh_token(refresh_token, oauth_config)

            new_access_token = token_response["access_token"]
            # Some providers rotate the refresh token; others omit it or send null/"".
            # In that case keep the current one rather than storing an empty value.
            new_refresh_token = token_response.get("refresh_token") or refresh_token
            # Treat a missing or null expires_in as the default lifetime (avoids int(None))
            expires_in = token_response.get("expires_in")
            if expires_in is None:
                expires_in = 3600

            encrypted_access = new_access_token
            encrypted_refresh = new_refresh_token
            if self.encryption:
                encrypted_access = await self.encryption.encrypt_secret_async(new_access_token)
                encrypted_refresh = await self.encryption.encrypt_secret_async(new_refresh_token)

            now = datetime.now(timezone.utc)
            token_record.access_token = encrypted_access
            token_record.refresh_token = encrypted_refresh
            token_record.expires_at = now + timedelta(seconds=int(expires_in))
            token_record.updated_at = now

            try:
                self.db.commit()
            except Exception:
                # A failed commit leaves the session unusable until rolled back
                self.db.rollback()
                raise

            logger.info(f"Successfully refreshed token for gateway {gateway_id}, user {token_record.app_user_email}")
            return new_access_token

        except Exception as e:
            logger.error(f"Failed to refresh OAuth token for gateway {gateway_id}: {str(e)}")
            error_text = str(e).lower()
            # If the refresh token itself is rejected, clear the record to force re-authentication
            if "invalid" in error_text or "expired" in error_text:
                logger.warning(f"Refresh token appears invalid/expired, clearing tokens for gateway {gateway_id}")
                try:
                    self.db.delete(token_record)
                    self.db.commit()
                except Exception as cleanup_error:
                    # Cleanup is best-effort; never let it escape a method that reports failure as None
                    self.db.rollback()
                    logger.error(f"Failed to clear invalid OAuth tokens for gateway {gateway_id}: {str(cleanup_error)}")
            return None

    def _is_token_expired(self, token_record: OAuthToken, threshold_seconds: int = 300) -> bool:
        """Check if token is expired or near expiration.

        Naive ``expires_at`` values are interpreted as UTC. A record without an
        expiry is treated as never expiring.

        Args:
            token_record: OAuth token record to check
            threshold_seconds: Seconds before expiry to consider token expired

        Returns:
            True if token is expired or near expiration

        Examples:
            >>> from types import SimpleNamespace
            >>> from datetime import datetime, timedelta, timezone
            >>> svc = TokenStorageService(None)
            >>> future = datetime.now(tz=timezone.utc) + timedelta(seconds=600)
            >>> past = datetime.now(tz=timezone.utc) - timedelta(seconds=10)
            >>> rec_future = SimpleNamespace(expires_at=future)
            >>> rec_past = SimpleNamespace(expires_at=past)
            >>> svc._is_token_expired(rec_future, threshold_seconds=300)  # 10 min ahead, 5 min threshold
            False
            >>> svc._is_token_expired(rec_future, threshold_seconds=900)  # 10 min ahead, 15 min threshold
            True
            >>> svc._is_token_expired(rec_past, threshold_seconds=0)
            True
            >>> svc._is_token_expired(SimpleNamespace(expires_at=None))
            False
        """
        if not token_record.expires_at:
            return False
        expires_at = token_record.expires_at
        if expires_at.tzinfo is None:
            expires_at = expires_at.replace(tzinfo=timezone.utc)
        return datetime.now(timezone.utc) + timedelta(seconds=threshold_seconds) >= expires_at

    async def get_token_info(self, gateway_id: str, app_user_email: str) -> Optional[Dict[str, Any]]:
        """Get information about stored OAuth tokens (never the token values themselves).

        Args:
            gateway_id: ID of the gateway
            app_user_email: MCP Gateway user email

        Returns:
            Token information dictionary or None if not found (or on error).
            Timestamp fields are ISO-8601 strings, or None when unset.

        Examples:
            >>> from types import SimpleNamespace
            >>> from datetime import datetime, timedelta, timezone
            >>> svc = TokenStorageService(None)
            >>> now = datetime.now(tz=timezone.utc)
            >>> future = now + timedelta(seconds=60)
            >>> rec = SimpleNamespace(user_id='u1', app_user_email='u1', token_type='bearer', expires_at=future, scopes=['s1'], created_at=now, updated_at=now)
            >>> class _Res:
            ...     def scalar_one_or_none(self):
            ...         return rec
            >>> class _DB:
            ...     def execute(self, *_args, **_kw):
            ...         return _Res()
            >>> svc.db = _DB()
            >>> import asyncio
            >>> info = asyncio.run(svc.get_token_info('g1', 'u1'))
            >>> info['user_id']
            'u1'
            >>> isinstance(info['is_expired'], bool)
            True
            >>> rec.updated_at = None
            >>> asyncio.run(svc.get_token_info('g1', 'u1'))['updated_at'] is None
            True
        """
        try:
            token_record = self._get_token_record(gateway_id, app_user_email)
            if not token_record:
                return None

            def _iso(value: Optional[datetime]) -> Optional[str]:
                # Unset timestamps are reported as None instead of failing the whole lookup
                return value.isoformat() if value else None

            return {
                "user_id": token_record.user_id,  # OAuth provider user ID
                "app_user_email": token_record.app_user_email,  # MCP Gateway user
                "token_type": token_record.token_type,
                "expires_at": _iso(token_record.expires_at),
                "scopes": token_record.scopes,
                "created_at": _iso(token_record.created_at),
                "updated_at": _iso(token_record.updated_at),
                "is_expired": self._is_token_expired(token_record, 0),
            }

        except Exception as e:
            logger.error(f"Failed to get token info: {str(e)}")
            return None

    async def revoke_user_tokens(self, gateway_id: str, app_user_email: str) -> bool:
        """Revoke (delete) stored OAuth tokens for a specific user.

        Args:
            gateway_id: ID of the gateway
            app_user_email: MCP Gateway user email

        Returns:
            True if a token record was found and deleted; False if none existed or deletion failed.

        Examples:
            >>> from types import SimpleNamespace
            >>> from unittest.mock import MagicMock
            >>> svc = TokenStorageService(MagicMock())
            >>> rec = SimpleNamespace()
            >>> svc.db.execute.return_value.scalar_one_or_none.return_value = rec
            >>> svc.db.delete = lambda obj: None
            >>> svc.db.commit = lambda: None
            >>> import asyncio
            >>> asyncio.run(svc.revoke_user_tokens('g1', 'u1'))
            True
            >>> # Not found
            >>> svc.db.execute.return_value.scalar_one_or_none.return_value = None
            >>> asyncio.run(svc.revoke_user_tokens('g1', 'u1'))
            False
        """
        try:
            token_record = self._get_token_record(gateway_id, app_user_email)
            if not token_record:
                return False

            self.db.delete(token_record)
            self.db.commit()
            logger.info(f"Revoked OAuth tokens for gateway {gateway_id}, user {app_user_email}")
            return True

        except Exception as e:
            self.db.rollback()
            logger.error(f"Failed to revoke OAuth tokens: {str(e)}")
            return False

    async def cleanup_expired_tokens(self, max_age_days: int = 30) -> int:
        """Delete OAuth tokens whose expiry is more than ``max_age_days`` in the past.

        Uses a single SQL DELETE statement instead of loading tokens into memory
        and deleting them one by one. This is more efficient and avoids memory
        issues when many tokens expire at once. Records with no ``expires_at`` are
        not matched by the ``<`` comparison and are therefore kept.

        Args:
            max_age_days: Maximum age of tokens to keep

        Returns:
            Number of tokens cleaned up (0 on error)

        Examples:
            >>> from unittest.mock import MagicMock
            >>> svc = TokenStorageService(MagicMock())
            >>> svc.db.execute.return_value.rowcount = 2
            >>> svc.db.commit = lambda: None
            >>> import asyncio
            >>> asyncio.run(svc.cleanup_expired_tokens(1))
            2
        """
        try:
            cutoff_date = datetime.now(tz=timezone.utc) - timedelta(days=max_age_days)

            result = self.db.execute(delete(OAuthToken).where(OAuthToken.expires_at < cutoff_date))
            count = result.rowcount

            self.db.commit()

            if count > 0:
                logger.info(f"Cleaned up {count} expired OAuth tokens")

            return count

        except Exception as e:
            self.db.rollback()
            logger.error(f"Failed to cleanup expired tokens: {str(e)}")
            return 0