Coverage for mcpgateway / utils / jwt_config_helper.py: 100%

67 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/utils/jwt_config_helper.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7JWT Configuration Helper Utilities with caching. 

8This module provides JWT configuration validation and key retrieval functions. 

9Configuration is validated once and cached for performance. 

10Key files are cached with mtime tracking to avoid repeated disk I/O. 

11""" 

12 

13# Standard 

14from functools import lru_cache 

15from pathlib import Path 

16from typing import Tuple 

17 

18# First-Party 

19from mcpgateway.config import settings 

20 

# Module-level memo of key-file contents, shared by the key lookups below.
# Each entry maps (path string, st_mtime at read time) -> file text, so a file
# whose modification time has changed is looked up under a new key.
_KeyFileCacheKey = Tuple[str, float]
_key_file_cache: dict[_KeyFileCacheKey, str] = {}

24 

25 

class JWTConfigurationError(Exception):
    """Signal that the JWT settings are missing or inconsistent.

    Raised by the validation helpers in this module, e.g. when an HMAC
    algorithm is selected without a secret, or an asymmetric algorithm is
    selected without usable key files.

    Examples:
        >>> err = JWTConfigurationError("JWT public key path is invalid: /nope")
        >>> str(err)
        'JWT public key path is invalid: /nope'
        >>> issubclass(JWTConfigurationError, Exception)
        True
    """

37 

38 

def _read_key_file_cached(path: Path) -> str:
    """Read a key file, reusing the previous contents while its mtime is unchanged.

    The cache is keyed by ``(str(path), st_mtime)``. When a file is (re)read,
    only stale entries for that same path are evicted, so the public and the
    private key files can stay cached side by side.

    Args:
        path: Path to the key file.

    Returns:
        str: The key file contents.

    Raises:
        IOError: If the file cannot be stat'ed, opened, or decoded (including
            when it does not exist); the original exception is chained as
            ``__cause__``.
    """
    try:
        path_str = str(path)
        mtime = path.stat().st_mtime
        cache_key = (path_str, mtime)

        cached = _key_file_cache.get(cache_key)
        if cached is not None:
            return cached

        with open(path, "r", encoding="utf-8") as f:
            content = f.read()

        # Evict only outdated entries for *this* path; clearing the whole
        # dict would also drop the other key file's cached contents.
        for stale_key in [key for key in _key_file_cache if key[0] == path_str]:
            del _key_file_cache[stale_key]
        _key_file_cache[cache_key] = content

        return content
    except Exception as e:
        raise IOError(f"Failed to read key file {path}: {e}") from e

72 

73 

@lru_cache(maxsize=1)
def _get_validated_config() -> str:
    """Validate the JWT configuration once and memoize the result.

    HMAC algorithms (names starting with ``HS``, compared case-insensitively
    to match the key getters in this module) need a non-empty
    ``jwt_secret_key``; every other algorithm is checked by
    ``_validate_asymmetric_keys``. ``lru_cache`` does not memoize raised
    exceptions, so a failing configuration is re-validated on each call.

    Returns:
        str: The configured algorithm string, exactly as read from settings.

    Raises:
        JWTConfigurationError: If the configuration is invalid.
    """
    algorithm = settings.jwt_algorithm

    if algorithm.upper().startswith("HS"):
        raw_secret = settings.jwt_secret_key
        # Handle SecretStr type from Pydantic v2; plain values are used as-is.
        secret_key = raw_secret.get_secret_value() if hasattr(raw_secret, "get_secret_value") else raw_secret
        if not secret_key:
            raise JWTConfigurationError(f"JWT algorithm {algorithm} requires jwt_secret_key to be set")
    else:
        _validate_asymmetric_keys(algorithm)

    return algorithm

94 

95 

def validate_jwt_algo_and_keys() -> None:
    """Validate the JWT algorithm and its key material.

    Delegates to the memoized ``_get_validated_config``: after the first
    successful call, later calls return immediately. A failed validation is
    not memoized, so the next call checks again. Call ``clear_jwt_caches()``
    to force re-validation after a configuration change.

    Raises:
        JWTConfigurationError: If the secret is missing for an HMAC algorithm,
            or if the asymmetric key paths are unset or do not point to files.
    """
    _get_validated_config()

107 

108 

def _validate_asymmetric_keys(algorithm: str) -> None:
    """Check that both asymmetric key paths are configured and point to files.

    Relative paths are resolved against the current working directory. The
    public key is checked before the private key.

    Args:
        algorithm: The JWT algorithm in use (only used in the error message).

    Raises:
        JWTConfigurationError: If either key path is unset, or if a resolved
            path does not refer to an existing regular file.
    """
    if not settings.jwt_public_key_path or not settings.jwt_private_key_path:
        raise JWTConfigurationError(f"JWT algorithm {algorithm} requires both jwt_public_key_path and jwt_private_key_path to be set")

    # Same resolve-and-check for each key; order (public first) is preserved.
    for label, configured in (("public", settings.jwt_public_key_path), ("private", settings.jwt_private_key_path)):
        key_path = Path(configured)
        if not key_path.is_absolute():
            key_path = Path.cwd() / key_path
        if not key_path.is_file():
            raise JWTConfigurationError(f"JWT {label} key path is invalid: {key_path}")

136 

137 

@lru_cache(maxsize=1)
def get_jwt_private_key_or_secret() -> str:
    """Return the signing key for the configured algorithm (memoized).

    For HMAC algorithms (``HS*``, case-insensitive) this is the shared secret;
    for any other algorithm it is the contents of ``jwt_private_key_path``
    (relative paths are resolved against the current working directory).

    The result is memoized by ``lru_cache``, so a changed key file is not
    picked up until ``clear_jwt_caches()`` (or this function's
    ``cache_clear()``) is called.

    Returns:
        str: The signing key as string

    Raises:
        JWTConfigurationError: If an asymmetric algorithm is configured but
            ``jwt_private_key_path`` is not set.
        IOError: If the private key file cannot be read.

    Examples:
        >>> # Function returns a string key
        >>> result = get_jwt_private_key_or_secret()
        >>> isinstance(result, str)
        True
    """
    algorithm = settings.jwt_algorithm.upper()

    if algorithm.startswith("HS"):
        secret = settings.jwt_secret_key
        # Handle SecretStr type from Pydantic v2
        return secret.get_secret_value() if hasattr(secret, "get_secret_value") else secret

    if not settings.jwt_private_key_path:
        # Fail with a clear configuration error instead of Path(None)'s TypeError.
        raise JWTConfigurationError(f"JWT algorithm {settings.jwt_algorithm} requires jwt_private_key_path to be set")

    path = Path(settings.jwt_private_key_path)
    if not path.is_absolute():
        path = Path.cwd() / path
    return _read_key_file_cached(path)

164 

165 

@lru_cache(maxsize=1)
def get_jwt_public_key_or_secret() -> str:
    """Return the verification key for the configured algorithm (memoized).

    For HMAC algorithms (``HS*``, case-insensitive) this is the shared secret;
    for any other algorithm it is the contents of ``jwt_public_key_path``
    (relative paths are resolved against the current working directory).

    The result is memoized by ``lru_cache``, so a changed key file is not
    picked up until ``clear_jwt_caches()`` (or this function's
    ``cache_clear()``) is called.

    Returns:
        str: The verification key as string

    Raises:
        JWTConfigurationError: If an asymmetric algorithm is configured but
            ``jwt_public_key_path`` is not set.
        IOError: If the public key file cannot be read.

    Examples:
        >>> # Function returns a string key
        >>> result = get_jwt_public_key_or_secret()
        >>> isinstance(result, str)
        True
    """
    algorithm = settings.jwt_algorithm.upper()

    if algorithm.startswith("HS"):
        secret = settings.jwt_secret_key
        # Handle SecretStr type from Pydantic v2
        return secret.get_secret_value() if hasattr(secret, "get_secret_value") else secret

    if not settings.jwt_public_key_path:
        # Fail with a clear configuration error instead of Path(None)'s TypeError.
        raise JWTConfigurationError(f"JWT algorithm {settings.jwt_algorithm} requires jwt_public_key_path to be set")

    path = Path(settings.jwt_public_key_path)
    if not path.is_absolute():
        path = Path.cwd() / path
    return _read_key_file_cached(path)

192 

193 

def clear_jwt_caches() -> None:
    """Reset every memoized JWT value held by this module.

    Clears the ``lru_cache`` of the config validator and of both key getters,
    then empties the key-file content cache, so the next call to any of them
    runs its body again.

    Typical callers:
    - test fixtures that need isolation between cases
    - code that reloads configuration at runtime (if supported)
    - code that rotates key files at runtime (if supported)

    Note: In production, JWT config/key changes require application restart.
    """
    memoized = (_get_validated_config, get_jwt_public_key_or_secret, get_jwt_private_key_or_secret)
    for fn in memoized:
        fn.cache_clear()
    _key_file_cache.clear()