Coverage for mcpgateway / utils / validate_signature.py: 100%
71 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3"""Location: ./mcpgateway/utils/validate_signature.py
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Madhav Kandukuri
8Utility to validate Ed25519 signatures.
9Given data, signature, and public key PEM, verifies authenticity.
10"""
12# Future
13from __future__ import annotations
15# Standard
16import hashlib
18# Logging setup
19import logging
20from typing import Tuple
22# Third-Party
23from cryptography.hazmat.primitives import serialization
24from cryptography.hazmat.primitives.asymmetric import ed25519
26# First-Party
27from mcpgateway.config import get_settings
logger = logging.getLogger(__name__)

# Cache for signature validation results (populated by validate_signature).
# Key: (data_hash, signature_hex, public_key_hash), Value: bool
#   - data_hash: SHA-256 hex digest of the message bytes
#   - signature_hex: hex form of the raw signature bytes
#   - public_key_hash: SHA-256 hex digest of the public key PEM text
# Both True and False outcomes are stored. No lock guards this dict in this file.
_signature_validation_cache: dict[Tuple[str, str, str], bool] = {}

# Cache for loaded public keys (populated by _load_public_key_cached).
# Key: SHA-256 hex digest of the public key PEM text, Value: parsed public key object
_public_key_cache: dict[str, ed25519.Ed25519PublicKey] = {}
39# ---------------------------------------------------------------------------
40# Helper: sign data using Ed25519 private key
41# ---------------------------------------------------------------------------
def sign_data(data: bytes, private_key_pem: str) -> str:
    """Produce a hex-encoded Ed25519 signature over ``data``.

    The PEM text is parsed as an unencrypted private key (``password=None``).
    Any key type other than Ed25519 is rejected. Every exception raised while
    loading the key or signing is logged at ERROR level and then re-raised
    unchanged.

    Args:
        data: Message bytes to sign.
        private_key_pem: PEM-formatted private key string.

    Returns:
        str: Hex-encoded signature.

    Raises:
        TypeError: If the provided key is not an Ed25519 private key.

    Examples:
        >>> from cryptography.hazmat.primitives.asymmetric import ed25519
        >>> from cryptography.hazmat.primitives import serialization
        >>>
        >>> # Generate a test key pair
        >>> private_key = ed25519.Ed25519PrivateKey.generate()
        >>> private_pem = private_key.private_bytes(
        ...     encoding=serialization.Encoding.PEM,
        ...     format=serialization.PrivateFormat.PKCS8,
        ...     encryption_algorithm=serialization.NoEncryption()
        ... ).decode()
        >>>
        >>> # Sign some data
        >>> data = b"test message"
        >>> signature = sign_data(data, private_pem)
        >>> isinstance(signature, str)
        True
        >>> len(signature) == 128  # 64 bytes = 128 hex chars
        True
    """
    try:
        loaded_key = serialization.load_pem_private_key(private_key_pem.encode(), password=None)
        if isinstance(loaded_key, ed25519.Ed25519PrivateKey):
            raw_signature = loaded_key.sign(data)
            return raw_signature.hex()
        raise TypeError("Expected an Ed25519 private key")
    except Exception as exc:
        # Record the failure, then let the caller handle the original exception.
        logger.error(f"Error signing data: {exc}")
        raise
87# ---------------------------------------------------------------------------
88# Validate Ed25519 signature
89# ---------------------------------------------------------------------------
def _load_public_key_cached(public_key_pem: str) -> ed25519.Ed25519PublicKey:
    """Load an Ed25519 public key from PEM, memoizing the parsed object.

    Parsed keys are stored in ``_public_key_cache``, keyed by the SHA-256 hex
    digest of the PEM text. Once the cache holds 100 entries it is cleared
    wholesale before the next key is inserted, so it never exceeds 100 entries.

    Args:
        public_key_pem: PEM-formatted public key string.

    Returns:
        ed25519.Ed25519PublicKey: The loaded public key.

    Raises:
        ValueError: If the key cannot be loaded. Errors from
            ``serialization.load_pem_public_key`` are propagated unchanged.
        TypeError: If the PEM encodes a public key that is not Ed25519.
    """
    key_hash = hashlib.sha256(public_key_pem.encode()).hexdigest()

    # Cached values are key objects and are never None.
    cached_key = _public_key_cache.get(key_hash)
    if cached_key is not None:
        return cached_key

    public_key = serialization.load_pem_public_key(public_key_pem.encode())

    # Honour the declared return type. Other key types (e.g. RSA) have a
    # different verify() signature and would otherwise fail later with a
    # confusing argument error. Rejecting them here also keeps them out of the cache.
    if not isinstance(public_key, ed25519.Ed25519PublicKey):
        raise TypeError("Expected an Ed25519 public key")

    # Limit cache size: use ">=" so the cache never holds more than 100 keys.
    max_cached_keys = 100
    if len(_public_key_cache) >= max_cached_keys:
        _public_key_cache.clear()

    _public_key_cache[key_hash] = public_key
    return public_key
def validate_signature(data: bytes | str, signature: bytes | str, public_key_pem: str) -> bool:
    """Validate an Ed25519 signature with caching.

    Caches validation results to avoid repeated cryptographic verification
    for the same data/signature/key combination. Both valid and invalid
    outcomes are cached. Signatures that are not exactly 64 bytes long can
    never be valid Ed25519 signatures. They are rejected before the cache is
    consulted, so oversized caller-supplied values are never stored in it.

    Args:
        data: Original message bytes. A ``str`` is encoded with ``str.encode()``
            (UTF-8) before verification.
        signature: Signature bytes or hex string to verify.
        public_key_pem: PEM-formatted public key string.

    Returns:
        bool: True if signature is valid, False otherwise. Malformed hex, a
        wrong signature length, and any error raised while loading the key or
        verifying all produce False.

    Examples:
        >>> from cryptography.hazmat.primitives.asymmetric import ed25519
        >>> from cryptography.hazmat.primitives import serialization
        >>>
        >>> # Generate a test key pair
        >>> private_key = ed25519.Ed25519PrivateKey.generate()
        >>> public_key = private_key.public_key()
        >>> public_pem = public_key.public_bytes(
        ...     encoding=serialization.Encoding.PEM,
        ...     format=serialization.PublicFormat.SubjectPublicKeyInfo
        ... ).decode()
        >>>
        >>> # Sign and verify
        >>> data = b"test message"
        >>> signature = private_key.sign(data)
        >>> validate_signature(data, signature, public_pem)
        True
        >>>
        >>> # Test with hex signature
        >>> hex_sig = signature.hex()
        >>> validate_signature(data, hex_sig, public_pem)
        True
        >>>
        >>> # Test invalid signature
        >>> validate_signature(b"wrong data", signature, public_pem)
        False
        >>>
        >>> # Test with string data (gets encoded)
        >>> validate_signature("test message", signature, public_pem)
        True
        >>>
        >>> # Test invalid hex signature format
        >>> validate_signature(data, "not-valid-hex", public_pem)
        False
        >>>
        >>> # Wrong-length signatures are rejected without verification
        >>> validate_signature(data, b"short", public_pem)
        False
    """
    if isinstance(data, str):
        data = data.encode()

    # Accept hex-encoded signatures
    if isinstance(signature, str):
        try:
            signature_bytes = bytes.fromhex(signature)
        except ValueError:
            logger.error("Invalid hex signature format.")
            return False
    else:
        signature_bytes = signature

    # An Ed25519 signature is always exactly 64 bytes, so any other length can
    # never verify. Rejecting it here keeps arbitrarily long caller-controlled
    # values out of the cache key, which embeds the signature hex verbatim.
    expected_signature_length = 64
    if len(signature_bytes) != expected_signature_length:
        logger.error(f"Signature validation failed: expected {expected_signature_length}-byte signature, got {len(signature_bytes)} bytes")
        return False

    # Create cache key: digests of the message and PEM, plus the (now bounded) signature hex
    data_hash = hashlib.sha256(data).hexdigest()
    signature_hex = signature_bytes.hex()
    key_hash = hashlib.sha256(public_key_pem.encode()).hexdigest()
    cache_key = (data_hash, signature_hex, key_hash)

    # Check cache
    if cache_key in _signature_validation_cache:
        return _signature_validation_cache[cache_key]

    # Validate signature. Key-loading errors and verification failures both land in the except branch.
    try:
        public_key = _load_public_key_cached(public_key_pem)
        public_key.verify(signature_bytes, data)
        result = True
    except Exception as e:
        logger.error(f"Signature validation failed: {e}")
        result = False

    # Bound the cache: once it exceeds 1000 entries, keep the 500 most recently
    # *inserted* entries. Dict order is insertion order, and cache hits do not refresh it.
    if len(_signature_validation_cache) > 1000:
        items = list(_signature_validation_cache.items())
        _signature_validation_cache.clear()
        _signature_validation_cache.update(items[-500:])

    _signature_validation_cache[cache_key] = result
    return result
def clear_signature_caches() -> None:
    """Empty both module-level caches used for signature validation.

    Resets the parsed public key cache and the memoized validation results.

    Call this function:
    - In test fixtures to ensure test isolation
    - After key rotation
    """
    for cache in (_public_key_cache, _signature_validation_cache):
        cache.clear()
222# ---------------------------------------------------------------------------
223# Helper: re-sign data after verifying old signature
224# ---------------------------------------------------------------------------
def resign_data(
    data: bytes,
    old_public_key_pem: str,
    old_signature: bytes | str,
    new_private_key_pem: str,
) -> str | None:
    """Re-sign data after verifying old signature.

    If ``old_signature`` is empty, the data is signed with the new key without
    any verification (first-time signing). Otherwise the old signature must
    verify against ``old_public_key_pem`` before a new signature is produced.

    Args:
        data: Message bytes to verify and re-sign.
        old_public_key_pem: PEM-formatted old public key.
        old_signature: Existing signature, either as raw bytes or as a hex
            string (the format returned by ``sign_data``). An empty value
            means the data has never been signed.
        new_private_key_pem: PEM-formatted new private key.

    Returns:
        str | None: New hex-encoded signature, or None if the old signature
        fails verification.

    Examples:
        >>> from cryptography.hazmat.primitives.asymmetric import ed25519
        >>> from cryptography.hazmat.primitives import serialization
        >>>
        >>> # Generate old and new key pairs
        >>> old_private = ed25519.Ed25519PrivateKey.generate()
        >>> old_public = old_private.public_key()
        >>> new_private = ed25519.Ed25519PrivateKey.generate()
        >>>
        >>> old_public_pem = old_public.public_bytes(
        ...     encoding=serialization.Encoding.PEM,
        ...     format=serialization.PublicFormat.SubjectPublicKeyInfo
        ... ).decode()
        >>> new_private_pem = new_private.private_bytes(
        ...     encoding=serialization.Encoding.PEM,
        ...     format=serialization.PrivateFormat.PKCS8,
        ...     encryption_algorithm=serialization.NoEncryption()
        ... ).decode()
        >>>
        >>> # Test first-time signing (no old signature)
        >>> data = b"test message"
        >>> new_sig = resign_data(data, old_public_pem, "", new_private_pem)
        >>> isinstance(new_sig, str)
        True
        >>>
        >>> # Test re-signing with valid old signature
        >>> old_sig = old_private.sign(data)
        >>> new_sig2 = resign_data(data, old_public_pem, old_sig, new_private_pem)
        >>> isinstance(new_sig2, str)
        True
        >>> new_sig2 != old_sig.hex()  # New signature should be different
        True
        >>>
        >>> # Hex-encoded old signatures (as produced by sign_data) are accepted
        >>> isinstance(resign_data(data, old_public_pem, old_sig.hex(), new_private_pem), str)
        True
        >>>
        >>> # Test with invalid old signature (should return None)
        >>> bad_sig = b"invalid signature bytes"
        >>> result = resign_data(data, old_public_pem, bad_sig, new_private_pem)
        >>> result is None
        True
    """
    # Handle first-time signing (no old signature)
    if not old_signature:
        logger.info("No existing signature found — signing for the first time.")
        return sign_data(data, new_private_key_pem)

    # Pass hex strings through unchanged: validate_signature decodes them itself.
    # Calling .encode() here would hand the verifier the ASCII hex characters
    # instead of the signature bytes, so every valid hex signature would fail.
    if not validate_signature(data, old_signature, old_public_key_pem):
        logger.warning("Old signature invalid — not re-signing.")
        return None

    logger.info("Old signature valid — re-signing with new key.")
    return sign_data(data, new_private_key_pem)
if __name__ == "__main__":
    # Example usage: sign a fixed message with the configured private key and
    # verify it with the derived public key.
    # Configure a handler at INFO level. Otherwise the logger.info() result
    # below would be dropped by logging's default (WARNING-level) fallback.
    logging.basicConfig(level=logging.INFO)

    settings = get_settings()

    private_key_pem = settings.ed25519_private_key
    private_key_obj = serialization.load_pem_private_key(
        private_key_pem.encode(),
        password=None,
    )
    public_key = private_key_obj.public_key()

    message = b"test message"
    sig = private_key_obj.sign(message)

    public_pem = public_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    ).decode()

    # logging %-formats the message with its args, so an arg needs a %s
    # placeholder. Without one the record fails ("not all arguments converted").
    logger.info("Signature valid: %s", validate_signature(message, sig, public_pem))