Coverage for mcpgateway / services / encryption_service.py: 100%

66 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/services/encryption_service.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti, Madhav Kandukuri 

6 

7Encryption Service. 

8 

9This service provides encryption and decryption functions for client secrets 

10using the AUTH_ENCRYPTION_SECRET from configuration. 

11""" 

12 

13# Standard 

14import asyncio 

15import base64 

16import logging 

17import os 

18from typing import Optional, Union 

19 

20# Third-Party 

21from argon2.low_level import hash_secret_raw, Type 

22from cryptography.fernet import Fernet 

23import orjson 

24from pydantic import SecretStr 

25 

26# First-Party 

27from mcpgateway.config import settings 

28 

# Module-scoped logger; records are emitted under this module's dotted name.
logger = logging.getLogger(name=__name__)

30 

31 

class EncryptionService:
    """Handles encryption and decryption of client secrets.

    Each secret is encrypted with Fernet under a key derived from the configured
    passphrase via Argon2id. The result is a JSON bundle that records the KDF
    name, the Argon2id cost parameters, the random salt and the Fernet token,
    so decryption re-derives the key from the parameters stored in the bundle
    rather than from the instance's current settings.

    Examples:
        Basic roundtrip:
        >>> enc = EncryptionService(SecretStr('very-secret-key'))
        >>> cipher = enc.encrypt_secret('hello')
        >>> isinstance(cipher, str) and enc.is_encrypted(cipher)
        True
        >>> enc.decrypt_secret(cipher)
        'hello'

        The configured salt length is used for new ciphertexts:
        >>> short_salt = EncryptionService('k', salt_len=8)
        >>> len(base64.b64decode(orjson.loads(short_salt.encrypt_secret('x'))['salt']))
        8

        Non-encrypted text detection:
        >>> enc.is_encrypted('plain-text')
        False
    """

    def __init__(
        self,
        encryption_secret: Union[SecretStr, str],
        time_cost: Optional[int] = None,
        memory_cost: Optional[int] = None,
        parallelism: Optional[int] = None,
        hash_len: int = 32,
        salt_len: int = 16,
    ):
        """Initialize the encryption handler.

        Args:
            encryption_secret: Passphrase used for key derivation. A ``SecretStr`` is
                unwrapped; any other value is converted with ``str()``. It is stored
                as UTF-8 bytes.
            time_cost: Argon2id time cost. When not given (or falsy), falls back to
                ``settings.argon2id_time_cost``, or 3 if that setting is absent.
            memory_cost: Argon2id memory cost in KiB. When not given (or falsy), falls
                back to ``settings.argon2id_memory_cost``, or 65536 if absent.
            parallelism: Argon2id parallelism. When not given (or falsy), falls back to
                ``settings.argon2id_parallelism``, or 1 if absent.
            hash_len: Length in bytes of the raw derived key. Fernet requires a
                32-byte key, so other values make encryption/decryption fail.
            salt_len: Length in bytes of the random salt generated for each
                encryption (Argon2 requires at least 8 bytes).
        """
        # Accept both SecretStr and plain strings for backwards compatibility
        # (plain strings are used by tests and legacy callers).
        if isinstance(encryption_secret, SecretStr):
            self.encryption_secret = encryption_secret.get_secret_value().encode()
        else:
            self.encryption_secret = str(encryption_secret).encode()
        self.time_cost = time_cost or getattr(settings, "argon2id_time_cost", 3)
        self.memory_cost = memory_cost or getattr(settings, "argon2id_memory_cost", 65536)
        self.parallelism = parallelism or getattr(settings, "argon2id_parallelism", 1)
        self.hash_len = hash_len
        self.salt_len = salt_len

    def derive_key_argon2id(self, passphrase: bytes, salt: bytes, time_cost: int, memory_cost: int, parallelism: int) -> bytes:
        """Derive a Fernet key from a passphrase using Argon2id.

        Args:
            passphrase: The passphrase to derive the key from.
            salt: The salt to use in key derivation.
            time_cost: Argon2id time cost parameter.
            memory_cost: Argon2id memory cost parameter (in KiB).
            parallelism: Argon2id parallelism parameter.

        Returns:
            The URL-safe base64 encoding of the raw ``self.hash_len``-byte Argon2id
            output, which is the key format ``Fernet`` accepts.
        """
        raw = hash_secret_raw(
            secret=passphrase,
            salt=salt,
            time_cost=time_cost,
            memory_cost=memory_cost,  # KiB
            parallelism=parallelism,
            hash_len=self.hash_len,
            type=Type.ID,
        )
        return base64.urlsafe_b64encode(raw)

    def encrypt_secret(self, plaintext: str) -> str:
        """Encrypt a plaintext secret.

        A fresh random salt of ``self.salt_len`` bytes is generated on every call,
        so encrypting the same plaintext twice yields different bundles.

        Args:
            plaintext: The secret to encrypt.

        Returns:
            A JSON string with keys ``kdf`` (always ``"argon2id"``), ``t``, ``m``,
            ``p`` (the Argon2id cost parameters), ``salt`` (standard base64) and
            ``token`` (the Fernet token).

        Raises:
            Exception: If key derivation, encryption or serialization fails; the
                error is logged and re-raised unchanged.
        """
        try:
            salt = os.urandom(self.salt_len)
            key = self.derive_key_argon2id(self.encryption_secret, salt, self.time_cost, self.memory_cost, self.parallelism)
            token = Fernet(key).encrypt(plaintext.encode())
            return orjson.dumps(
                {
                    "kdf": "argon2id",
                    "t": self.time_cost,
                    "m": self.memory_cost,
                    "p": self.parallelism,
                    "salt": base64.b64encode(salt).decode(),
                    "token": token.decode(),
                }
            ).decode()
        except Exception as e:
            logger.error("Failed to encrypt secret: %s", e)
            raise

    async def encrypt_secret_async(self, plaintext: str) -> str:
        """Encrypt a plaintext secret in a worker thread.

        Argon2id key derivation is deliberately expensive, so it is offloaded
        with ``asyncio.to_thread`` to avoid blocking the event loop.

        Args:
            plaintext: The secret to encrypt.

        Returns:
            The JSON bundle string produced by :meth:`encrypt_secret`.

        Raises:
            Exception: If encryption fails (propagated from :meth:`encrypt_secret`).
        """
        return await asyncio.to_thread(self.encrypt_secret, plaintext)

    def decrypt_secret(self, bundle_json: str) -> Optional[str]:
        """Decrypt an encrypted secret.

        The key is re-derived from the salt and Argon2id parameters stored in the
        bundle, not from this instance's current cost settings.

        Args:
            bundle_json: JSON string produced by :meth:`encrypt_secret`.

        Returns:
            The decrypted secret, or ``None`` if any step (JSON parsing, missing
            keys, key derivation, or Fernet decryption) raises. The failure is
            logged.
        """
        try:
            bundle = orjson.loads(bundle_json)
            salt = base64.b64decode(bundle["salt"])
            key = self.derive_key_argon2id(self.encryption_secret, salt, time_cost=bundle["t"], memory_cost=bundle["m"], parallelism=bundle["p"])
            decrypted = Fernet(key).decrypt(bundle["token"].encode())
            return decrypted.decode()
        except Exception as e:
            # Best-effort by design: callers receive None instead of an exception.
            logger.error("Failed to decrypt secret: %s", e)
            return None

    async def decrypt_secret_async(self, bundle_json: str) -> Optional[str]:
        """Decrypt an encrypted secret in a worker thread.

        Args:
            bundle_json: JSON string produced by :meth:`encrypt_secret`.

        Returns:
            The decrypted secret, or ``None`` if decryption fails.
        """
        return await asyncio.to_thread(self.decrypt_secret, bundle_json)

    def is_encrypted(self, text: str) -> bool:
        """Check whether a string appears to be encrypted.

        Args:
            text: String to check.

        Returns:
            True if ``text`` is a JSON object whose ``kdf`` is ``"argon2id"``, or if
            it URL-safe-base64-decodes to at least 32 bytes (legacy format).

        Note:
            Supports both the legacy PBKDF2 (base64-wrapped Fernet) and the current
            Argon2id (JSON bundle) formats. The JSON format is checked first; the
            legacy check is a heuristic, so any sufficiently long base64-decodable
            string is also reported as encrypted.
        """
        if not text:
            return False

        # Current Argon2id JSON bundle format.
        if text.startswith("{"):
            try:
                obj = orjson.loads(text)
                if isinstance(obj, dict) and obj.get("kdf") == "argon2id":
                    return True
            except ValueError:
                # orjson.JSONDecodeError subclasses ValueError. Not valid JSON:
                # fall through to the legacy check.
                pass

        # Legacy PBKDF2 base64-wrapped Fernet format.
        try:
            decoded = base64.urlsafe_b64decode(text.encode())
            # Encrypted data should be at least 32 bytes (Fernet minimum).
            return len(decoded) >= 32
        except Exception:
            return False

208 

209 

def get_encryption_service(encryption_secret: Union[SecretStr, str]) -> EncryptionService:
    """Build a new EncryptionService for the given secret.

    Every call constructs a fresh instance; no caching is performed. The cost,
    hash and salt parameters are left at the constructor's defaults.

    Args:
        encryption_secret: Passphrase used for key derivation, supplied either as a
            ``SecretStr`` or as a plain string.

    Returns:
        A newly constructed EncryptionService.

    Examples:
        >>> svc = get_encryption_service(SecretStr('k'))
        >>> isinstance(svc, EncryptionService)
        True
        >>> svc_plain = get_encryption_service('plain-key')
        >>> isinstance(svc_plain, EncryptionService)
        True
    """
    service = EncryptionService(encryption_secret=encryption_secret)
    return service