Coverage for mcpgateway / services / encryption_service.py: 100%
66 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/services/encryption_service.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti, Madhav Kandukuri
7Encryption Service.
9This service provides encryption and decryption functions for client secrets
10using the AUTH_ENCRYPTION_SECRET from configuration.
11"""
13# Standard
14import asyncio
15import base64
16import logging
17import os
18from typing import Optional, Union
20# Third-Party
21from argon2.low_level import hash_secret_raw, Type
22from cryptography.fernet import Fernet
23import orjson
24from pydantic import SecretStr
26# First-Party
27from mcpgateway.config import settings
# Module-level logger; EncryptionService reports encryption/decryption failures through it.
logger = logging.getLogger(__name__)
class EncryptionService:
    """Handles encryption and decryption of client secrets.

    Secrets are encrypted with Fernet using a key derived from the configured
    encryption secret via Argon2id. Each ciphertext is stored as a JSON bundle
    carrying the KDF name, the Argon2id parameters, the random salt and the
    Fernet token, so it can still be decrypted after the default parameters
    change.

    Examples:
        Basic roundtrip:
        >>> enc = EncryptionService(SecretStr('very-secret-key'))
        >>> cipher = enc.encrypt_secret('hello')
        >>> isinstance(cipher, str) and enc.is_encrypted(cipher)
        True
        >>> enc.decrypt_secret(cipher)
        'hello'

        Non-encrypted text detection:
        >>> enc.is_encrypted('plain-text')
        False
    """

    def __init__(
        self, encryption_secret: Union[SecretStr, str], time_cost: Optional[int] = None, memory_cost: Optional[int] = None, parallelism: Optional[int] = None, hash_len: int = 32, salt_len: int = 16
    ):
        """Initialize the encryption handler.

        Args:
            encryption_secret: Secret key for encryption/decryption. A ``SecretStr``
                is unwrapped; any other value is converted with ``str()``.
            time_cost: Argon2id time cost parameter. Falls back to
                ``settings.argon2id_time_cost`` (or 3) when omitted or falsy.
            memory_cost: Argon2id memory cost parameter (in KiB). Falls back to
                ``settings.argon2id_memory_cost`` (or 65536) when omitted or falsy.
            parallelism: Argon2id parallelism parameter. Falls back to
                ``settings.argon2id_parallelism`` (or 1) when omitted or falsy.
            hash_len: Length in bytes of the derived key. Fernet expects a
                32-byte key, so values other than 32 make encryption fail.
            salt_len: Length in bytes of the random salt generated for each
                encryption.
        """
        # Handle both SecretStr and plain string for backwards compatibility
        if isinstance(encryption_secret, SecretStr):
            self.encryption_secret = encryption_secret.get_secret_value().encode()
        else:
            # If a plain string is passed, use it directly (for testing/legacy code)
            self.encryption_secret = str(encryption_secret).encode()
        # `or` (not `is None`) means 0 also falls back to the configured default.
        self.time_cost = time_cost or getattr(settings, "argon2id_time_cost", 3)
        self.memory_cost = memory_cost or getattr(settings, "argon2id_memory_cost", 65536)
        self.parallelism = parallelism or getattr(settings, "argon2id_parallelism", 1)
        self.hash_len = hash_len
        self.salt_len = salt_len

    def derive_key_argon2id(self, passphrase: bytes, salt: bytes, time_cost: int, memory_cost: int, parallelism: int) -> bytes:
        """Derive a Fernet key from a passphrase using Argon2id.

        Args:
            passphrase: The passphrase to derive the key from
            salt: The salt to use in key derivation
            time_cost: Argon2id time cost parameter
            memory_cost: Argon2id memory cost parameter (in KiB)
            parallelism: Argon2id parallelism parameter

        Returns:
            The raw Argon2id output (``self.hash_len`` bytes) encoded as URL-safe
            base64, which is the key format ``Fernet`` accepts.
        """
        raw = hash_secret_raw(
            secret=passphrase,
            salt=salt,
            time_cost=time_cost,
            memory_cost=memory_cost,  # KiB
            parallelism=parallelism,
            hash_len=self.hash_len,
            type=Type.ID,
        )
        return base64.urlsafe_b64encode(raw)

    def encrypt_secret(self, plaintext: str) -> str:
        """Encrypt a plaintext secret.

        A fresh random salt of ``self.salt_len`` bytes is generated on every
        call, so encrypting the same plaintext twice yields different bundles.

        Args:
            plaintext: The secret to encrypt

        Returns:
            A JSON string bundle with keys ``kdf`` (always ``"argon2id"``), ``t``,
            ``m``, ``p`` (the Argon2id parameters used), ``salt`` (standard
            base64) and ``token`` (the Fernet token).

        Raises:
            Exception: If encryption fails (the error is logged and re-raised)
        """
        try:
            # Use the configured salt length; decrypt_secret reads the salt back
            # from the bundle, so any length round-trips.
            salt = os.urandom(self.salt_len)
            key = self.derive_key_argon2id(self.encryption_secret, salt, self.time_cost, self.memory_cost, self.parallelism)
            fernet = Fernet(key)
            encrypted = fernet.encrypt(plaintext.encode())
            return orjson.dumps(
                {
                    "kdf": "argon2id",
                    "t": self.time_cost,
                    "m": self.memory_cost,
                    "p": self.parallelism,
                    "salt": base64.b64encode(salt).decode(),
                    "token": encrypted.decode(),
                }
            ).decode()
        except Exception as e:
            logger.error("Failed to encrypt secret: %s", e)
            raise

    async def encrypt_secret_async(self, plaintext: str) -> str:
        """Encrypt a plaintext secret asynchronously.

        Runs :meth:`encrypt_secret` in a worker thread via ``asyncio.to_thread``
        so the Argon2id key derivation does not block the event loop.

        Args:
            plaintext: The secret to encrypt

        Returns:
            The JSON string bundle produced by :meth:`encrypt_secret`.

        Raises:
            Exception: If encryption fails
        """
        return await asyncio.to_thread(self.encrypt_secret, plaintext)

    def decrypt_secret(self, bundle_json: str) -> Optional[str]:
        """Decrypt an encrypted secret.

        The Argon2id parameters and the salt are read from the bundle itself.
        Bundles created with different settings can therefore still be
        decrypted.

        Args:
            bundle_json: JSON string containing encryption metadata and token,
                as produced by :meth:`encrypt_secret`.

        Returns:
            Decrypted secret string, or None if decryption fails for any reason
            (malformed bundle, missing keys, wrong key, invalid token). The
            failure is logged.
        """
        try:
            b = orjson.loads(bundle_json)
            salt = base64.b64decode(b["salt"])
            # NOTE(review): t/m/p come from the stored bundle. If bundles can be
            # supplied by untrusted parties, an oversized "m" could exhaust memory.
            key = self.derive_key_argon2id(self.encryption_secret, salt, time_cost=b["t"], memory_cost=b["m"], parallelism=b["p"])
            fernet = Fernet(key)
            decrypted = fernet.decrypt(b["token"].encode())
            return decrypted.decode()
        except Exception as e:
            # Deliberate best-effort: callers receive None instead of an exception.
            logger.error("Failed to decrypt secret: %s", e)
            return None

    async def decrypt_secret_async(self, bundle_json: str) -> Optional[str]:
        """Decrypt an encrypted secret asynchronously.

        Runs :meth:`decrypt_secret` in a worker thread via ``asyncio.to_thread``
        so the Argon2id key derivation does not block the event loop.

        Args:
            bundle_json: JSON string containing encryption metadata and token

        Returns:
            Decrypted secret string, or None if decryption fails
        """
        return await asyncio.to_thread(self.decrypt_secret, bundle_json)

    def is_encrypted(self, text: str) -> bool:
        """Check if a string appears to be encrypted.

        This is a heuristic: it never decrypts anything.

        Args:
            text: String to check

        Returns:
            True if the string appears to be encrypted

        Note:
            Supports both legacy PBKDF2 (base64-wrapped Fernet) and new Argon2id
            (JSON bundle) formats. Checks JSON format first, then falls back to
            base64 check for legacy format. ``decrypt_secret`` in this class
            only parses the JSON bundle format.
        """
        if not text:
            return False

        # Check for new Argon2id JSON bundle format
        if text.startswith("{"):
            try:
                obj = orjson.loads(text)
                if isinstance(obj, dict) and obj.get("kdf") == "argon2id":
                    return True
            except (orjson.JSONDecodeError, ValueError):
                # Not valid JSON - continue to legacy check
                pass

        # Check for legacy PBKDF2 base64-wrapped Fernet format
        try:
            decoded = base64.urlsafe_b64decode(text.encode())
            # Encrypted data should be at least 32 bytes (Fernet minimum)
            return len(decoded) >= 32
        except Exception:
            # A predicate must never raise; anything undecodable is "not encrypted".
            return False
def get_encryption_service(encryption_secret: Union[SecretStr, str]) -> EncryptionService:
    """Create a new EncryptionService bound to ``encryption_secret``.

    A fresh instance is constructed on every call. No Argon2id overrides are
    passed, so the service uses its own default parameters.

    Args:
        encryption_secret: Secret key for encryption/decryption, either a
            ``SecretStr`` or a plain string.

    Returns:
        A newly constructed EncryptionService instance.

    Examples:
        >>> enc = get_encryption_service(SecretStr('k'))
        >>> isinstance(enc, EncryptionService)
        True
        >>> enc2 = get_encryption_service('plain-key')
        >>> isinstance(enc2, EncryptionService)
        True
    """
    service = EncryptionService(encryption_secret)
    return service