Coverage for mcpgateway / services / encryption_service.py: 99%
191 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/encryption_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti, Madhav Kandukuri, Mohan Lakshmaiah
7Encryption Service for Client Secrets.
9Handles encryption and decryption of client secrets using Argon2id-derived
10Fernet keys with explicit format markers for secure detection.
12## Format & Detection
14**New Format (v2):**
15- Encrypted bundles are JSON objects prefixed with "v2:" marker
16- Format: "v2:{...json...}"
17- Contains explicit version, KDF type, parameters, salt, and encrypted token
18- Always detectable by strict validation
20**Legacy Support:**
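21- Fernet binary format (version byte 0x80 marker): detected by is_encrypted() only; decryption parses JSON bundles, so raw Fernet tokens cannot be decrypted by this service
22- JSON bundles with argon2id KDF (without v2: prefix)
23- JSON bundles are accepted for reading, but all new encryptions use v2 format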
25## Detection & Security
27**Strict Detection:**
28- Checks v2: prefix first (most reliable)
29- Falls back to legacy Fernet version byte (0x80)
30- Validates all required JSON keys before considering data encrypted
31- Returns False for ambiguous data (safe default)
33**WARNING: Do NOT use is_encrypted() for security decisions:**
34- Edge cases exist where plaintext JSON could theoretically match encrypted structure
35- Always validate encryption state at storage boundaries
36- Use explicit markers when possible (e.g., in database schema)
38## API Usage
40**Strict Mode** (for validation/auditing):
41- `encrypt_secret(plaintext: str) -> str` – Raises if already encrypted
42- `decrypt_secret(bundle: str) -> str` – Raises if not encrypted or fails
43- Forces calling code to be explicit about intent
45**Idempotent Mode** (for resilience):
46- `decrypt_secret_or_plaintext(bundle: str) -> Optional[str]` – Returns plaintext if not encrypted
47- `decrypt_secret_async(bundle: str) -> Optional[str]` – Backward compatible async wrapper
49**Async Variants:**
50- `encrypt_secret_async()` – Async encryption
51- `decrypt_secret_async()` – Idempotent async (backward compatible)
52- `decrypt_secret_strict_async()` – Strict async
53- `decrypt_secret_or_plaintext_async()` – Idempotent async
55## Error Handling
57| Scenario | Strict Mode | Idempotent Mode |
58|----------|-------------|-----------------|
59| Encrypt plaintext | Returns encrypted bundle | Returns encrypted bundle |
60| Encrypt already-encrypted | Raises `AlreadyEncryptedError` | Raises `AlreadyEncryptedError` |
61| Decrypt valid bundle | Returns plaintext | Returns plaintext |
62| Decrypt plaintext | Raises `NotEncryptedError` | Returns plaintext unchanged |
63| Decrypt corrupted data | Raises `ValueError` | Returns None |
64| Wrong decryption key | Raises `ValueError` | Returns None |
66## Migration Strategy
681. **Phase 1 (Current)**: New encryptions use v2 format, decryptions accept both
692. **Phase 2 (Next sprint)**: Background job migrates legacy data to v2
703. **Phase 3 (When 95%+ migrated)**: Deprecate legacy format support
714. **Phase 4 (Next release)**: Remove legacy code
73## Performance Notes
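75- Argon2id KDF: defaults to time_cost=3, memory_cost=65536 KiB, parallelism=1 (overridable via settings or constructor); the key is re-derived on every encrypt/decrypt call, so each call pays the full KDF cost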
76- Random salt per encryption: unique ciphertexts for same plaintext
77- Thread-safe: Each call derives unique salt/nonce
78- Async via `asyncio.to_thread()`: scales to thread pool
79"""
81# Standard
82import asyncio
83import base64
84import binascii
85import logging
86import os
87from typing import Any, Optional, Union
89# Third-Party
90from argon2.low_level import hash_secret_raw, Type
91from cryptography.fernet import Fernet, InvalidToken
92import orjson
93from pydantic import SecretStr
95# First-Party
96from mcpgateway.common.oauth import is_sensitive_oauth_key
97from mcpgateway.config import settings
99logger = logging.getLogger(__name__)
class AlreadyEncryptedError(ValueError):
    """Signals that encrypt_secret() received input that is already an encrypted bundle.

    Subclasses ValueError so existing ``except ValueError`` handlers still catch it.
    """
class NotEncryptedError(ValueError):
    """Signals that decrypt_secret() received input that is not an encrypted bundle.

    Subclasses ValueError so existing ``except ValueError`` handlers still catch it.
    """
class EncryptionService:
    """Service for encrypting/decrypting client secrets using Argon2id-derived Fernet.

    Provides strict and idempotent modes for different use cases:
    - Strict mode: `encrypt_secret()` and `decrypt_secret()` for explicit validation
    - Idempotent mode: `decrypt_secret_or_plaintext()` for resilient decryption

    All new encryptions produce v2 format (v2:{json}). ``is_encrypted()`` also
    recognizes legacy formats (un-prefixed argon2id JSON bundles and raw Fernet
    tokens), but only JSON bundles can actually be decrypted by this class.

    Example (Strict Mode):
    ```python
    svc = EncryptionService(SecretStr("key"))
    encrypted = svc.encrypt_secret("my_secret")  # Returns "v2:{...}"
    plaintext = svc.decrypt_secret(encrypted)  # Returns "my_secret"
    # Raises AlreadyEncryptedError if encrypt_secret() is called on already-encrypted data
    ```

    Example (Idempotent Mode):
    ```python
    # Returns plaintext unchanged, or None on error
    result = svc.decrypt_secret_or_plaintext(data)
    ```

    Thread-safety: instance attributes are only written in ``__init__``; every
    encryption draws a fresh random salt, so calls do not share mutable state.
    """

    # Format marker for new encrypted bundles
    FORMAT_MARKER = "v2:"
    FORMAT_VERSION = "v2"

    # Keys every JSON bundle must carry to be decryptable.
    _REQUIRED_BUNDLE_KEYS = frozenset({"salt", "token", "t", "m", "p"})

    # Fernet token layout: version (1) + timestamp (8) + IV (16) + HMAC (32) = 57
    # bytes of overhead, plus AES-CBC ciphertext that is a non-empty multiple of
    # 16 bytes (PKCS7 padding always adds at least one byte).
    _FERNET_OVERHEAD = 57
    _FERNET_BLOCK = 16

    def __init__(
        self,
        encryption_secret: Union[SecretStr, str],
        time_cost: Optional[int] = None,
        memory_cost: Optional[int] = None,
        parallelism: Optional[int] = None,
        hash_len: int = 32,
        salt_len: int = 16,
    ):
        """Initialize the encryption service.

        Args:
            encryption_secret: Secret key for encryption/decryption (SecretStr or string)
            time_cost: Argon2id time cost parameter (default: from settings or 3)
            memory_cost: Argon2id memory cost parameter in KiB (default: from settings or 65536)
            parallelism: Argon2id parallelism parameter (default: from settings or 1)
            hash_len: Length of derived key in bytes (default: 32)
            salt_len: Length of the random salt generated per encryption, in bytes (default: 16)
        """
        if isinstance(encryption_secret, SecretStr):
            self.encryption_secret = encryption_secret.get_secret_value().encode()
        else:
            self.encryption_secret = str(encryption_secret).encode()

        # `or` means an explicit 0 also falls back to the configured default.
        self.time_cost = time_cost or getattr(settings, "argon2id_time_cost", 3)
        self.memory_cost = memory_cost or getattr(settings, "argon2id_memory_cost", 65536)
        self.parallelism = parallelism or getattr(settings, "argon2id_parallelism", 1)
        self.hash_len = hash_len
        self.salt_len = salt_len

    def derive_key_argon2id(self, passphrase: bytes, salt: bytes, time_cost: int, memory_cost: int, parallelism: int) -> bytes:
        """Derive encryption key using Argon2id KDF.

        Args:
            passphrase: Secret passphrase to derive key from
            salt: Random salt for key derivation
            time_cost: Argon2id time cost parameter
            memory_cost: Argon2id memory cost parameter (in KiB)
            parallelism: Argon2id parallelism parameter

        Returns:
            Base64-encoded derived key ready for Fernet
        """
        raw = hash_secret_raw(
            secret=passphrase,
            salt=salt,
            time_cost=time_cost,
            memory_cost=memory_cost,
            parallelism=parallelism,
            hash_len=self.hash_len,
            type=Type.ID,
        )
        return base64.urlsafe_b64encode(raw)

    def encrypt_secret(self, plaintext: str) -> str:
        """Encrypt plaintext to v2 format with explicit marker.

        STRICT: Raises AlreadyEncryptedError if input is already encrypted.
        Caller must check is_encrypted() first if input origin is uncertain.

        Args:
            plaintext: Unencrypted secret to encrypt

        Returns:
            str: Encrypted bundle as "v2:{json}" string

        Raises:
            AlreadyEncryptedError: If input is already encrypted
            ValueError: If encryption fails
        """
        if self.is_encrypted(plaintext):
            raise AlreadyEncryptedError("Input is already encrypted. Use decrypt_secret() first, or use decrypt_secret_or_plaintext() if you need idempotent behavior.")

        try:
            # Honour the configured salt length (previously hard-coded to 16).
            salt = os.urandom(self.salt_len)
            key = self.derive_key_argon2id(self.encryption_secret, salt, self.time_cost, self.memory_cost, self.parallelism)
            token = Fernet(key).encrypt(plaintext.encode()).decode()

            # KDF parameters are stored alongside the token so decryption can
            # re-derive the same key even if the configured defaults change later.
            bundle_obj = {
                "version": self.FORMAT_VERSION,
                "kdf": "argon2id",
                "t": self.time_cost,
                "m": self.memory_cost,
                "p": self.parallelism,
                "salt": base64.b64encode(salt).decode(),
                "token": token,
            }

            json_str = orjson.dumps(bundle_obj).decode()
            return f"{self.FORMAT_MARKER}{json_str}"
        except Exception as e:
            logger.error("Failed to encrypt secret: %s", e)
            raise ValueError(f"Encryption failed: {e}") from e

    async def encrypt_secret_async(self, plaintext: str) -> str:
        """Async wrapper for encrypt_secret() that runs the KDF in a worker thread.

        Args:
            plaintext: Unencrypted secret to encrypt

        Returns:
            str: Encrypted bundle as "v2:{json}" string
        """
        return await asyncio.to_thread(self.encrypt_secret, plaintext)

    def decrypt_secret(self, bundle_json: str) -> str:
        """Decrypt an encrypted bundle (strict mode).

        STRICT: Raises NotEncryptedError if input is not encrypted, and
        ValueError if the bundle is corrupted/invalid.

        Use decrypt_secret_or_plaintext() if you need idempotent behavior.

        Args:
            bundle_json: Encrypted bundle (with or without v2: prefix)

        Returns:
            str: Decrypted plaintext

        Raises:
            NotEncryptedError: If input is not encrypted
            ValueError: If decryption fails (corrupted/invalid data)
        """
        if not self.is_encrypted(bundle_json):
            raise NotEncryptedError("Input is not encrypted. Use decrypt_secret_or_plaintext() for idempotent behavior.")

        return self._decrypt_bundle(bundle_json)

    async def decrypt_secret_strict_async(self, bundle_json: str) -> str:
        """Async wrapper for decrypt_secret() (STRICT mode).

        Raises exceptions if input is not encrypted or decryption fails.
        Use this when you need explicit error handling.

        Args:
            bundle_json: Encrypted bundle (with or without v2: prefix)

        Returns:
            str: Decrypted plaintext

        Raises:
            NotEncryptedError: If input is not encrypted
            ValueError: If decryption fails (corrupted/invalid data)
        """
        return await asyncio.to_thread(self.decrypt_secret, bundle_json)

    # NOTE: This async wrapper remains IDEMPOTENT for backward compatibility.
    # - Returns plaintext unchanged if input is not encrypted.
    # - Returns decrypted plaintext if input is encrypted.
    # - Returns None if decryption fails.
    # Prefer `decrypt_secret_strict_async()` or `decrypt_secret()` when strict validation is required.
    async def decrypt_secret_async(self, bundle_json: str) -> Optional[str]:
        """Async wrapper for decrypt_secret_or_plaintext() (IDEMPOTENT for backward compatibility).

        BACKWARD COMPATIBLE: This is idempotent for existing code.
        - Returns plaintext if not encrypted
        - Returns decrypted plaintext if encrypted
        - Returns None if decryption fails

        For strict error handling, use decrypt_secret_strict_async() or decrypt_secret().

        Args:
            bundle_json: Encrypted bundle or plaintext

        Returns:
            Optional[str]: Decrypted plaintext if encrypted, original input if plaintext, or None on failure
        """
        return await asyncio.to_thread(self.decrypt_secret_or_plaintext, bundle_json)

    # Idempotent helper: safe to call repeatedly. Returns original input for plaintext.
    def decrypt_secret_or_plaintext(self, bundle_json: str) -> Optional[str]:
        """Decrypt if encrypted, return plaintext unchanged if not (idempotent).

        Args:
            bundle_json: Encrypted bundle or plaintext. ``None`` and ``""`` are
                returned unchanged.

        Returns:
            Optional[str]: Decrypted plaintext if encrypted, original input if plaintext.
            None if the input carries the v2: marker but is not a valid bundle,
            or if the bundle is encrypted but decryption fails.

        This method is idempotent: calling it multiple times is safe.
        Use decrypt_secret() if you need strict error handling.
        """
        if not bundle_json:
            # Empty/None input can never be an encrypted bundle; hand it back as-is.
            return bundle_json

        if not self.is_encrypted(bundle_json):
            if bundle_json.startswith(self.FORMAT_MARKER):
                # Has the v2: marker but failed validation: almost certainly a
                # corrupted bundle. Don't log its content — if it is actually a
                # plaintext secret that happens to start with "v2:", logging it
                # would leak the secret.
                logger.error("Input has v2: prefix but failed validation (length=%d)", len(bundle_json))
                return None

            # No encryption markers - treat as plaintext
            return bundle_json

        try:
            return self._decrypt_bundle(bundle_json)
        except ValueError as e:
            # _decrypt_bundle wraps every failure in ValueError.
            logger.error("Failed to decrypt secret: %s", e)
            return None

    async def decrypt_secret_or_plaintext_async(self, bundle_json: str) -> Optional[str]:
        """Async wrapper for decrypt_secret_or_plaintext().

        Args:
            bundle_json: Encrypted bundle or plaintext

        Returns:
            Optional[str]: Decrypted plaintext if encrypted, original input if plaintext, or None on failure
        """
        return await asyncio.to_thread(self.decrypt_secret_or_plaintext, bundle_json)

    def _decrypt_bundle(self, bundle_json: str) -> str:
        """Internal method to decrypt an already-validated JSON bundle.

        Only JSON bundles (v2-prefixed or legacy argon2id) can be decrypted; a
        raw legacy Fernet token fails JSON parsing and raises ValueError.

        Args:
            bundle_json: Validated encrypted bundle (with or without v2: prefix)

        Returns:
            str: Decrypted plaintext

        Raises:
            ValueError: If bundle is corrupted or decryption fails
        """
        json_str = bundle_json[len(self.FORMAT_MARKER) :] if bundle_json.startswith(self.FORMAT_MARKER) else bundle_json

        try:
            obj = orjson.loads(json_str)
            if not isinstance(obj, dict):
                raise ValueError("Encrypted bundle must be a JSON object")

            if not self._REQUIRED_BUNDLE_KEYS.issubset(obj.keys()):
                raise ValueError(f"Encrypted bundle missing required keys. Found: {set(obj.keys())}, Need: {set(self._REQUIRED_BUNDLE_KEYS)}")

            # Re-derive the key using the KDF parameters recorded in the bundle.
            salt = base64.b64decode(obj["salt"])
            key = self.derive_key_argon2id(self.encryption_secret, salt, time_cost=obj["t"], memory_cost=obj["m"], parallelism=obj["p"])
            decrypted = Fernet(key).decrypt(obj["token"].encode())
            return decrypted.decode()
        except (InvalidToken, binascii.Error) as e:
            raise ValueError(f"Decryption failed (corrupted or wrong key): {e}") from e
        except ValueError:
            raise
        except Exception as e:
            raise ValueError(f"Decryption failed: {e}") from e

    def is_encrypted(self, text: str) -> bool:
        """Detect whether text is encrypted (best-effort heuristic).

        Checks for:
        1. v2: prefix with valid JSON bundle (most reliable)
        2. Legacy argon2id JSON format (for backward compatibility)
        3. Legacy Fernet format (base64 with version byte 0x80 and Fernet-shaped length)

        ⚠️ SECURITY WARNING - READ BEFORE USING:

        This uses heuristics and has limitations:
        - NOT suitable for security-critical code paths
        - May fail to detect edge-case encrypted formats
        - May falsely identify structured plaintext as encrypted
        - ONLY use for non-security purposes (caching, logging, display)

        ALWAYS validate encryption state at storage/trust boundaries using:
        - Database schema constraints (e.g., separate plaintext/encrypted columns)
        - Explicit markers in data structure
        - Cryptographic signatures/MACs
        - Hardware security modules

        Args:
            text: Text to check for encryption markers

        Returns:
            bool: True if text appears to be encrypted, False otherwise (safe default)

        Examples:
            >>> enc = EncryptionService(SecretStr("key"))
            >>> encrypted = enc.encrypt_secret("secret")
            >>> enc.is_encrypted(encrypted)
            True
            >>> enc.is_encrypted("plaintext")
            False
        """
        if not text:
            return False

        # Check for v2: prefix (most reliable)
        if text.startswith(self.FORMAT_MARKER):
            return self._is_valid_v2_bundle(text[len(self.FORMAT_MARKER) :])

        # Check for JSON bundle (legacy or without prefix)
        if text.startswith("{"):
            return self._is_valid_json_bundle(text)

        # Check for legacy Fernet binary format
        return self._is_valid_fernet_format(text)

    @staticmethod
    def _load_json_object(json_str: str) -> Optional[dict]:
        """Parse ``json_str`` and return it only if it is a JSON object.

        Args:
            json_str: Candidate JSON text

        Returns:
            Optional[dict]: The parsed object, or None if parsing fails or the
            top-level value is not an object.
        """
        try:
            obj = orjson.loads(json_str)
        except ValueError:  # orjson.JSONDecodeError subclasses ValueError
            return None
        return obj if isinstance(obj, dict) else None

    def _is_valid_v2_bundle(self, json_str: str) -> bool:
        """Validate v2: prefixed bundle.

        Strictly validates that:
        1. JSON parses successfully to an object
        2. Has version: "v2"
        3. Contains all required keys

        Args:
            json_str: JSON string to validate (without v2: prefix)

        Returns:
            bool: True if valid v2 bundle, False otherwise
        """
        obj = self._load_json_object(json_str)
        if obj is None or obj.get("version") != self.FORMAT_VERSION:
            return False
        return self._REQUIRED_BUNDLE_KEYS.issubset(obj.keys())

    def _is_valid_json_bundle(self, json_str: str) -> bool:
        """Validate legacy JSON bundle (without v2: prefix).

        Requires either an explicit v2 version or ``kdf == "argon2id"``, plus
        all required keys.

        Args:
            json_str: JSON string to validate

        Returns:
            bool: True if valid legacy JSON bundle, False otherwise
        """
        obj = self._load_json_object(json_str)
        if obj is None:
            return False
        if obj.get("version") != self.FORMAT_VERSION and obj.get("kdf") != "argon2id":
            return False
        return self._REQUIRED_BUNDLE_KEYS.issubset(obj.keys())

    def _is_valid_fernet_format(self, text: str) -> bool:
        """Validate legacy Fernet binary format (urlsafe base64, version byte 0x80).

        Besides the version byte, the decoded length must match the Fernet
        layout: 57 bytes of fixed overhead plus a non-empty multiple of 16
        bytes of ciphertext. This rejects much structured plaintext that the
        version-byte check alone would accept.

        Args:
            text: Text to validate

        Returns:
            bool: True if valid Fernet binary format, False otherwise
        """
        try:
            decoded = base64.urlsafe_b64decode(text.encode())
        except ValueError:  # binascii.Error and UnicodeEncodeError both subclass ValueError
            return False

        ciphertext_len = len(decoded) - self._FERNET_OVERHEAD
        return decoded[:1] == b"\x80" and ciphertext_len >= self._FERNET_BLOCK and ciphertext_len % self._FERNET_BLOCK == 0
def get_encryption_service(encryption_secret: Union[SecretStr, str]) -> EncryptionService:
    """Build an EncryptionService using the default Argon2id parameters.

    Args:
        encryption_secret: Secret key for encryption (as SecretStr or string)

    Returns:
        EncryptionService: Configured encryption service instance
    """
    service = EncryptionService(encryption_secret)
    return service
async def _encrypt_oauth_secret_value(value: Any, existing_value: Any, encryption: EncryptionService) -> Any:
    """Encrypt a sensitive oauth value while preserving masked placeholders.

    - dicts are delegated to ``_protect_oauth_config_value`` (only nested
      sensitive keys are encrypted);
    - lists are processed item-by-item, pairing each item with the item at the
      same index in ``existing_value``;
    - None, ``""`` and non-string scalars are returned unchanged;
    - the masked placeholder is replaced by the existing stored value (or None
      when there is none). That stored value is encrypted if it is not already,
      so a legacy plaintext secret is never re-persisted in the clear;
    - any other string is encrypted unless it already looks encrypted.

    Args:
        value: Incoming oauth value to protect.
        existing_value: Existing stored value used for masked placeholder preservation.
        encryption: Encryption service instance.

    Returns:
        Any: Protected value suitable for persistence.
    """
    if isinstance(value, dict):
        return await _protect_oauth_config_value(value, existing_value, encryption)

    if isinstance(value, list):
        existing_list = existing_value if isinstance(existing_value, list) else []
        protected_list = []
        for idx, item in enumerate(value):
            prior_item = existing_list[idx] if idx < len(existing_list) else None
            protected_list.append(await _encrypt_oauth_secret_value(item, prior_item, encryption))
        return protected_list

    if not isinstance(value, str) or value == "":
        # None, empty strings and non-string scalars are stored as-is.
        return value

    if value == settings.masked_auth_value:
        if not (isinstance(existing_value, str) and existing_value):
            return None
        # The client echoed the mask back: keep the stored secret, but run it
        # through the same encrypt-if-needed path below so plaintext left over
        # from before encryption (or a decrypted runtime copy) is not persisted.
        value = existing_value

    if encryption.is_encrypted(value):
        return value

    return await encryption.encrypt_secret_async(value)
async def _protect_oauth_config_value(value: Any, existing_value: Any, encryption: EncryptionService) -> Any:
    """Walk an oauth_config fragment and encrypt values stored under sensitive keys.

    Dict entries whose key is sensitive (per ``is_sensitive_oauth_key``) are
    handed to ``_encrypt_oauth_secret_value``; all other entries and list items
    are walked recursively. The matching fragment of ``existing_value`` (same
    key, or same list index) travels alongside so masked placeholders can be
    resolved. Any other value is returned unchanged.

    Args:
        value: Incoming oauth_config fragment.
        existing_value: Existing oauth_config fragment for masked placeholder preservation.
        encryption: Encryption service instance.

    Returns:
        Any: Protected oauth_config fragment.
    """
    if isinstance(value, list):
        prior_items = existing_value if isinstance(existing_value, list) else []
        walked = []
        for position, element in enumerate(value):
            prior = prior_items[position] if position < len(prior_items) else None
            walked.append(await _protect_oauth_config_value(element, prior, encryption))
        return walked

    if not isinstance(value, dict):
        return value

    prior_mapping = existing_value if isinstance(existing_value, dict) else {}
    walked_mapping: dict[str, Any] = {}
    for name, element in value.items():
        handler = _encrypt_oauth_secret_value if is_sensitive_oauth_key(name) else _protect_oauth_config_value
        walked_mapping[name] = await handler(element, prior_mapping.get(name), encryption)
    return walked_mapping
async def protect_oauth_config_for_storage(oauth_config: Any, existing_oauth_config: Any = None) -> Any:
    """Encrypt every sensitive value in an oauth_config payload before it is persisted.

    A fresh EncryptionService is built from ``settings.auth_encryption_secret``
    for each non-None payload.

    Args:
        oauth_config: Incoming oauth_config payload.
        existing_oauth_config: Existing oauth_config payload, if any.

    Returns:
        Any: Protected oauth_config payload, or None when ``oauth_config`` is None.
    """
    if oauth_config is not None:
        service = get_encryption_service(settings.auth_encryption_secret)
        return await _protect_oauth_config_value(oauth_config, existing_oauth_config, service)
    return None
async def _decrypt_oauth_config_value(value: Any, encryption: EncryptionService) -> Any:
    """Recursively decrypt sensitive oauth_config values for runtime use.

    Dict entries under a sensitive key are decrypted via
    ``_decrypt_oauth_secret_value``, which mirrors how
    ``_encrypt_oauth_secret_value`` encrypted them (including string items of a
    list stored under a sensitive key). Everything else is walked recursively
    and returned with the same structure.

    Args:
        value: oauth_config fragment.
        encryption: Encryption service instance.

    Returns:
        Any: Decrypted oauth_config fragment.
    """
    if isinstance(value, dict):
        decrypted: dict[str, Any] = {}
        for key, item in value.items():
            if is_sensitive_oauth_key(key):
                decrypted[key] = await _decrypt_oauth_secret_value(item, encryption)
            else:
                decrypted[key] = await _decrypt_oauth_config_value(item, encryption)
        return decrypted

    if isinstance(value, list):
        return [await _decrypt_oauth_config_value(item, encryption) for item in value]

    return value


async def _decrypt_oauth_secret_value(value: Any, encryption: EncryptionService) -> Any:
    """Decrypt a value that was stored under a sensitive oauth key.

    - lists: each item is decrypted as a secret, because the storage path
      encrypts each string item of a list under a sensitive key;
    - non-empty strings that look encrypted (and are not the masked
      placeholder): decrypted, keeping the stored value if decryption fails;
    - anything else (dicts, plaintext, scalars): walked by
      ``_decrypt_oauth_config_value``.

    Args:
        value: Value found under a sensitive key.
        encryption: Encryption service instance.

    Returns:
        Any: Runtime-ready value.
    """
    if isinstance(value, list):
        return [await _decrypt_oauth_secret_value(item, encryption) for item in value]

    if isinstance(value, str) and value and value != settings.masked_auth_value and encryption.is_encrypted(value):
        decrypted_value = await encryption.decrypt_secret_async(value)
        # On failure decrypt_secret_async returns None; keep the stored value rather than dropping it.
        return decrypted_value if decrypted_value is not None else value

    return await _decrypt_oauth_config_value(value, encryption)
async def decrypt_oauth_config_for_runtime(oauth_config: Any, encryption: Optional[EncryptionService] = None) -> Any:
    """Decrypt the sensitive values of a stored oauth_config at the point of use.

    When no service is supplied, one is built from
    ``settings.auth_encryption_secret``.

    Args:
        oauth_config: Stored oauth_config payload.
        encryption: Optional shared encryption service instance.

    Returns:
        Any: Runtime-ready oauth_config payload, or None when ``oauth_config`` is None.
    """
    if oauth_config is not None:
        service = encryption or get_encryption_service(settings.auth_encryption_secret)
        return await _decrypt_oauth_config_value(oauth_config, service)
    return None