Manages OAuth2 token lifecycle including acquisition and refresh.
Handles token expiration and thread-safe token refresh using asyncio locks.
Implements proper logging for debugging and monitoring token lifecycle events.
Source code in src/tools/core/utils/token_manager.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
class OAuth2ClientCredentialsManager:
    """
    Manages OAuth2 token lifecycle including acquisition and refresh.

    The token is fetched with a ``client_credentials`` grant, cached on the
    instance, and re-fetched once it is within ``refresh_buffer`` seconds of
    its expiry. Refreshes are serialized through an ``asyncio.Lock`` so that
    concurrent tasks awaiting ``get_token()`` issue a single HTTP request.
    Token lifecycle events are logged for debugging and monitoring.
    """

    def __init__(
        self,
        api_key: str,
        client_secret_base64: str,
        token_url: str,
        refresh_buffer: int = 60,
        logger: Optional[logging.Logger] = None
    ) -> None:
        """
        Initialize the token manager.

        Args:
            api_key: API key for authentication (sent as the ``apikey`` header)
            client_secret_base64: Base64 encoded client secret (sent as the
                ``Authorization: Basic`` credential)
            token_url: OAuth2 token endpoint URL
            refresh_buffer: Seconds before expiry to trigger refresh
            logger: Optional custom logger instance; defaults to the
                module-level logger

        Raises:
            ValueError: If required parameters are missing or invalid
        """
        # Validate inputs
        if not all([api_key, client_secret_base64, token_url]):
            raise ValueError("api_key, client_secret_base64, and token_url are required")
        if not token_url.startswith(('http://', 'https://')):
            raise ValueError("token_url must be a valid HTTP(S) URL")
        if refresh_buffer < 0:
            raise ValueError("refresh_buffer must be non-negative")

        self.api_key = api_key
        self.client_secret_base64 = client_secret_base64
        self.token_url = token_url
        self.refresh_buffer = refresh_buffer

        # Token state: no token yet, and an expiry of 0 so the first
        # get_token() call always triggers a refresh.
        self.access_token: Optional[str] = None
        self.expiry_time: float = 0
        self.lock = asyncio.Lock()

        # Set up logging
        self.logger = logger or logging.getLogger(__name__)

    async def _is_token_expired(self) -> bool:
        """
        Check if the current token is expired or close to expiration.

        Returns:
            bool: True if there is no token, or the current wall-clock time is
            past ``expiry_time - refresh_buffer``; False otherwise
        """
        current_time = time.time()
        token_expired = (
            self.access_token is None
            or current_time > (self.expiry_time - self.refresh_buffer)
        )
        if token_expired:
            self.logger.debug(
                "Token status: expired or near expiry. "
                "Current time: %s, Expiry time: %s",
                current_time,
                self.expiry_time,
            )
        return token_expired

    async def _refresh_token(self) -> None:
        """
        Refresh the OAuth token by making an async API request.

        The whole refresh runs under ``self.lock``. On any failure the cached
        token is cleared and the original exception is re-raised.

        Raises:
            aiohttp.ClientError: If network request fails or the endpoint
                returns a non-401 error status
            ValueError: If authentication fails (HTTP 401)
            Exception: For unexpected errors during token refresh (for example
                a response body missing ``access_token`` or ``expires_in``)
        """
        async with self.lock:
            # Double-check expiration after acquiring lock: another task may
            # have completed a refresh while this one was waiting.
            if not await self._is_token_expired():
                self.logger.debug("Token was refreshed by another task")
                return

            headers = {
                "Content-Type": "application/x-www-form-urlencoded",
                "apikey": self.api_key,
                "Authorization": f"Basic {self.client_secret_base64}"
            }
            payload = {
                "grant_type": "client_credentials",
                "scope": "public"
            }
            self.logger.debug("Attempting to refresh token from %s", self.token_url)

            # Tracks whether the failure is the 401 raised below, so the
            # catch-all handler logs it once with the right message instead
            # of also reporting it as an "unexpected" error.
            auth_failed = False
            try:
                async with aiohttp.ClientSession() as session:
                    async with session.post(
                        self.token_url,
                        headers=headers,
                        data=payload
                    ) as response:
                        if response.status == 401:
                            auth_failed = True
                            raise ValueError("Authentication failed")
                        response.raise_for_status()
                        token_info: Dict[str, Any] = await response.json()

                # Parse both fields before touching instance state so the
                # token and its expiry are always updated together.
                access_token = token_info["access_token"]
                expires_in = int(token_info["expires_in"])
            except aiohttp.ClientError as e:
                self.logger.error("Network error during token refresh: %s", e)
                self.access_token = None
                raise
            except Exception as e:
                if auth_failed:
                    self.logger.error(
                        "Authentication failed during token refresh. "
                        "Check credentials."
                    )
                else:
                    self.logger.error("Unexpected error during token refresh: %s", e)
                self.access_token = None
                raise

            self.access_token = access_token
            self.expiry_time = time.time() + expires_in
            self.logger.info(
                "Token refreshed successfully. Expires in %s seconds.", expires_in
            )

    async def get_token(self) -> Optional[str]:
        """
        Get the current access token, refreshing if necessary.

        Returns:
            The cached access token. After a successful refresh this is the
            newly issued token.

        Raises:
            Any exception raised by ``_refresh_token()``; a failed refresh
            propagates to the caller rather than returning ``None``.
        """
        if await self._is_token_expired():
            self.logger.debug("Token expired, initiating refresh")
            await self._refresh_token()
        return self.access_token
|
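Initialize the TokenManager.
Parameters:
- api_key (str): API key for authentication. Required.
- client_secret_base64 (str): Base64 encoded client secret. Required.
- token_url (str): OAuth2 token endpoint URL. Required.
- refresh_buffer (int): Seconds before expiry to trigger refresh. Default: 60.
- logger (Optional[logging.Logger]): Optional custom logger instance. Default: None.
Raises:
- ValueError: If required parameters are missing or invalid.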
Source code in src/tools/core/utils/token_manager.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58 | def __init__(
self,
api_key: str,
client_secret_base64: str,
token_url: str,
refresh_buffer: int = 60,
logger: Optional[logging.Logger] = None
) -> None:
"""
Initialize the TokenManager.
Args:
api_key: API key for authentication
client_secret_base64: Base64 encoded client secret
token_url: OAuth2 token endpoint URL
refresh_buffer: Seconds before expiry to trigger refresh
logger: Optional custom logger instance
Raises:
ValueError: If required parameters are missing or invalid
"""
# Validate inputs
if not all([api_key, client_secret_base64, token_url]):
raise ValueError("api_key, client_secret_base64, and token_url are required")
if not token_url.startswith(('http://', 'https://')):
raise ValueError("token_url must be a valid HTTP(S) URL")
if refresh_buffer < 0:
raise ValueError("refresh_buffer must be non-negative")
self.api_key = api_key
self.client_secret_base64 = client_secret_base64
self.token_url = token_url
self.refresh_buffer = refresh_buffer
# Token state
self.access_token: Optional[str] = None
self.expiry_time: float = 0
self.lock = asyncio.Lock()
# Set up logging
self.logger = logger or logging.getLogger(__name__)
|
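Get the current access token, refreshing if necessary.
Returns:
- Optional[str]: The current access token.
Raises:
- Any exception raised by _refresh_token() propagates to the caller if the refresh fails.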
Source code in src/tools/core/utils/token_manager.py
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152 | async def get_token(self) -> Optional[str]:
"""
Get the current access token, refreshing if necessary.
Returns:
str: The current access token if valid
None: If token refresh fails
Raises:
May raise exceptions from _refresh_token() if refresh fails
"""
if await self._is_token_expired():
self.logger.debug("Token expired, initiating refresh")
await self._refresh_token()
return self.access_token
|