Coverage for mcpgateway / utils / services_auth.py: 100%
56 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/utils/services_auth.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7mcpgateway.utils.services_auth - Authentication utilities for MCP Gateway
8Doctest examples
9----------------
10>>> import os
11>>> from mcpgateway.utils import services_auth
12>>> os.environ['AUTH_ENCRYPTION_SECRET'] = 'doctest-secret'
13>>> services_auth.settings.auth_encryption_secret = 'doctest-secret'
14>>> key = services_auth.get_key()
15>>> isinstance(key, bytes)
16True
17>>> d = {'user': 'alice'}
18>>> token = services_auth.encode_auth(d)
19>>> isinstance(token, str)
20True
21>>> services_auth.decode_auth(token) == d
22True
23>>> services_auth.encode_auth(None) is None
24True
25>>> services_auth.decode_auth(None) == {}
26True
27>>> services_auth.settings.auth_encryption_secret = ''
28>>> try:
29... services_auth.get_key()
30... except ValueError as e:
31... print('error')
32error
33"""
# Standard
import base64
import hashlib
import os
from typing import Optional, Tuple

# Third-Party
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
import orjson
from pydantic import SecretStr

# First-Party
from mcpgateway.config import settings
# Memo of derived key material, keyed by the passphrase string and mapping to
# (32-byte SHA-256-derived key, AESGCM cipher built from that key).
# The builders below clear it before inserting, so it holds at most one entry:
# a passphrase change replaces the old pair rather than accumulating new ones.
_crypto_cache: dict[str, Tuple[bytes, AESGCM]] = {}
def _get_passphrase() -> str:
    """Extract the encryption passphrase from settings, unwrapping SecretStr.

    Returns:
        str: The passphrase value.

    Raises:
        ValueError: If the passphrase is not set or is empty.
    """
    passphrase = settings.auth_encryption_secret

    # Unwrap SecretStr before validating, so an empty secret is rejected based
    # on its real value rather than on the truthiness of the wrapper object.
    if isinstance(passphrase, SecretStr):
        passphrase = passphrase.get_secret_value()

    if not passphrase:
        raise ValueError("AUTH_ENCRYPTION_SECRET not set in environment.")
    return passphrase
def get_key() -> bytes:
    """
    Generate a 32-byte AES encryption key derived from a passphrase.

    The key is the SHA-256 digest of the configured passphrase. The key and
    its AESGCM instance are cached together under the passphrase value; if
    the passphrase changes, the stale entry is dropped and a new one is built.

    Returns:
        bytes: A 32-byte encryption key.

    Raises:
        ValueError: If the passphrase is not set or empty.

    Doctest:
        >>> import os
        >>> from mcpgateway.utils import services_auth
        >>> os.environ['AUTH_ENCRYPTION_SECRET'] = 'doctest-secret'
        >>> services_auth.settings.auth_encryption_secret = 'doctest-secret'
        >>> key = services_auth.get_key()
        >>> isinstance(key, bytes)
        True
        >>> services_auth.settings.auth_encryption_secret = ''
        >>> try:
        ...     services_auth.get_key()
        ... except ValueError:
        ...     print('error')
        error
    """
    passphrase = _get_passphrase()

    # Single lookup: an `in` check followed by indexing could raise KeyError
    # if the cache is cleared (clear_crypto_cache or a rebuild) in between.
    cached = _crypto_cache.get(passphrase)
    if cached is not None:
        return cached[0]

    # Derive a 32-byte key from the passphrase.
    key = hashlib.sha256(passphrase.encode()).digest()

    # Cache key and AESGCM together; only the current passphrase is kept.
    aesgcm = AESGCM(key)
    _crypto_cache.clear()
    _crypto_cache[passphrase] = (key, aesgcm)

    return key
def _get_aesgcm() -> AESGCM:
    """Get the cached AESGCM instance for the current passphrase, creating it if needed.

    Returns:
        AESGCM: Cipher keyed with the SHA-256 digest of the passphrase.

    Raises:
        ValueError: If the passphrase is not set or empty.
    """
    passphrase = _get_passphrase()

    # Single lookup: an `in` check followed by indexing could raise KeyError
    # if the cache is cleared (clear_crypto_cache or a rebuild) in between.
    cached = _crypto_cache.get(passphrase)
    if cached is not None:
        return cached[1]

    # Derive key and create the cipher.
    key = hashlib.sha256(passphrase.encode()).digest()
    aesgcm = AESGCM(key)

    # Cache both; only the current passphrase is kept.
    _crypto_cache.clear()
    _crypto_cache[passphrase] = (key, aesgcm)

    return aesgcm
def clear_crypto_cache() -> None:
    """Drop every cached (key, AESGCM) pair.

    The next key or cipher lookup re-derives everything from the passphrase
    currently configured. Useful for:

    - isolating tests from one another (e.g. in fixtures), and
    - picking up a passphrase rotated at runtime, if that is supported.
    """
    _crypto_cache.clear()
def encode_auth(auth_value: Optional[dict]) -> Optional[str]:
    """
    Encrypt and encode an authentication dictionary into a compact base64-url string.

    The dictionary is serialized with orjson and encrypted with AES-GCM under
    a fresh random 12-byte nonce. The nonce is prepended to the ciphertext,
    and the result is base64-url encoded with "=" padding stripped.

    Args:
        auth_value (Optional[dict]): The authentication dictionary to encrypt and encode.

    Returns:
        Optional[str]: A base64-url-safe encrypted string representing the
        dictionary, or None if ``auth_value`` is falsy (None or an empty dict).

    Raises:
        ValueError: If the encryption passphrase is not set or empty.
        TypeError: If ``auth_value`` contains values orjson cannot serialize
            (orjson.JSONEncodeError).

    Doctest:
        >>> import os
        >>> from mcpgateway.utils import services_auth
        >>> os.environ['AUTH_ENCRYPTION_SECRET'] = 'doctest-secret'
        >>> services_auth.settings.auth_encryption_secret = 'doctest-secret'
        >>> token = services_auth.encode_auth({'user': 'alice'})
        >>> isinstance(token, str)
        True
        >>> services_auth.encode_auth(None) is None
        True
        >>> services_auth.encode_auth({}) is None
        True
    """
    # Falsy input (None or {}) means "no auth"; decode_auth(None) yields {}.
    if not auth_value:
        return None
    plaintext = orjson.dumps(auth_value)
    aesgcm = _get_aesgcm()
    # 96-bit random nonce; it must never repeat under the same key.
    nonce = os.urandom(12)
    ciphertext = aesgcm.encrypt(nonce, plaintext, None)
    # Token layout: nonce || ciphertext (AESGCM appends its auth tag).
    encoded = base64.urlsafe_b64encode(nonce + ciphertext).rstrip(b"=")
    return encoded.decode()
def decode_auth(encoded_value: Optional[str]) -> dict:
    """
    Decode and decrypt a base64-url-safe encrypted string back into the authentication dictionary.

    Expects the layout produced by ``encode_auth``: base64-url (padding
    optional) of a 12-byte nonce followed by the AES-GCM ciphertext.

    Args:
        encoded_value (Optional[str]): The encrypted base64-url string to decode and decrypt.

    Returns:
        dict: The decrypted authentication dictionary, or an empty dict if
        ``encoded_value`` is falsy (None or "").

    Raises:
        ValueError: If the passphrase is not set, the token is not valid
            base64 (binascii.Error), the decoded token is shorter than
            AESGCM's minimum nonce length, or the plaintext is not valid
            JSON (orjson.JSONDecodeError).
        cryptography.exceptions.InvalidTag: If the token was not produced
            with the current key, or it was truncated or tampered with.

    Doctest:
        >>> import os
        >>> from mcpgateway.utils import services_auth
        >>> os.environ['AUTH_ENCRYPTION_SECRET'] = 'doctest-secret'
        >>> services_auth.settings.auth_encryption_secret = 'doctest-secret'
        >>> d = {'user': 'alice'}
        >>> token = services_auth.encode_auth(d)
        >>> services_auth.decode_auth(token) == d
        True
        >>> services_auth.decode_auth(None) == {}
        True
        >>> services_auth.decode_auth('') == {}
        True
    """
    if not encoded_value:
        return {}
    aesgcm = _get_aesgcm()
    # encode_auth strips "=" padding; restore it to a multiple of 4 characters.
    padded = encoded_value + "=" * (-len(encoded_value) % 4)
    combined = base64.urlsafe_b64decode(padded)
    # First 12 bytes are the nonce; the remainder is ciphertext plus GCM tag.
    nonce, ciphertext = combined[:12], combined[12:]
    plaintext = aesgcm.decrypt(nonce, ciphertext, None)
    return orjson.loads(plaintext)