Coverage for mcpgateway / middleware / token_usage_middleware.py: 100%

108 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-09 03:05 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/middleware/token_usage_middleware.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Mihai Criveti 

6 

7Token Usage Logging Middleware. 

8 

9This middleware logs API token usage for analytics and security monitoring. 

10It records each request made with an API token, including endpoint, method, 

11response time, and status code. 

12 

13Note: Implemented as raw ASGI middleware (not BaseHTTPMiddleware) to avoid 

14response body buffering issues with streaming responses. 

15 

16Examples: 

17 >>> from mcpgateway.middleware.token_usage_middleware import TokenUsageMiddleware # doctest: +SKIP 

18 >>> app.add_middleware(TokenUsageMiddleware) # doctest: +SKIP 

19""" 

20 

21# Standard 

22import logging 

23import time 

24from typing import Optional 

25 

26# Third-Party 

27import jwt as _jwt 

28from starlette.datastructures import Headers 

29from starlette.requests import Request 

30from starlette.types import ASGIApp, Receive, Scope, Send 

31 

32# First-Party 

33from mcpgateway.db import fresh_db_session 

34from mcpgateway.middleware.path_filter import should_skip_auth_context 

35from mcpgateway.services.token_catalog_service import TokenCatalogService 

36from mcpgateway.utils.verify_credentials import verify_jwt_token_cached 

37 

38logger = logging.getLogger(__name__) 

39 

40 

class TokenUsageMiddleware:
    """Raw ASGI middleware for logging API token usage.

    This middleware tracks when API tokens are used, recording details like:
    - Endpoint accessed
    - HTTP method
    - Response status code
    - Response time
    - Client IP and user agent

    This data is used for security auditing, usage analytics, and detecting
    anomalous token usage patterns.

    Note:
        Two kinds of requests are logged:

        - Requests authenticated with an API token (identified by
          ``scope["state"]["auth_method"] == "api_token"``). A 4xx response
          marks the attempt as blocked.
        - Requests rejected with 401/403 whose Bearer token identifies itself
          as an API token and whose JTI exists in the database (for example a
          revoked or expired token).

        Implemented as raw ASGI middleware to avoid BaseHTTPMiddleware issues:
        - BaseHTTPMiddleware buffers entire response bodies (problematic for streaming)
        - Raw ASGI middleware streams responses efficiently
    """

    # Authorization scheme prefix (including the separating space) that
    # precedes the token in the ``Authorization`` header.
    _BEARER_PREFIX = "Bearer "

    def __init__(self, app: ASGIApp) -> None:
        """Initialize middleware.

        Args:
            app: ASGI application to wrap
        """
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Process ASGI request.

        The wrapped application is always invoked first; usage logging happens
        afterwards and never raises, so a logging failure cannot affect the
        response that was already sent.

        Args:
            scope: ASGI scope dict
            receive: Receive callable
            send: Send callable
        """
        # Only process HTTP requests
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Skip health checks and static files
        path = scope.get("path", "")
        if should_skip_auth_context(path):
            await self.app(scope, receive, send)
            return

        # Use a monotonic clock: wall-clock time can jump (NTP, manual changes)
        # and would then yield negative or inflated durations.
        start_time = time.perf_counter()

        # Capture response status
        status_code = 200  # Default when no "http.response.start" is observed

        async def send_wrapper(message: dict) -> None:
            """Wrap send to capture response status.

            Args:
                message: ASGI message dict containing response data
            """
            nonlocal status_code
            if message["type"] == "http.response.start":
                status_code = message["status"]
            await send(message)

        # Process request
        await self.app(scope, receive, send_wrapper)

        # Calculate response time
        response_time_ms = round((time.perf_counter() - start_time) * 1000)

        # Log API token usage — covers both successful requests and auth-rejected attempts.
        # Every request that uses (or tries to use) an API token is recorded,
        # including blocked calls with revoked/expired tokens, so that usage stats are accurate.
        state = scope.get("state", {})
        auth_method = state.get("auth_method") if state else None

        block_reason: Optional[str] = None

        if auth_method == "api_token":
            # --- Successfully authenticated API token request ---
            identity = await self._identify_authenticated_token(scope, receive, state)
            if identity is None:
                return

            # Reflect the actual outcome — 4xx responses mark the attempt as
            # blocked (e.g. RBAC denied, rate-limited, or server-scoping violation).
            # 5xx errors are backend failures, not security denials, so exclude them.
            blocked = 400 <= status_code < 500
            if blocked:
                block_reason = f"http_{status_code}"
        elif status_code in (401, 403):
            # --- Auth-rejected request: check if the Bearer token was an API token ---
            identity = self._identify_rejected_token(scope)
            if identity is None:
                return

            blocked = True
            block_reason = "revoked_or_expired" if status_code == 401 else f"http_{status_code}"
        else:
            return  # Not an API token request — nothing to log

        jti, user_email = identity

        # Shared logging path for both authenticated and blocked API token requests
        await self._record_usage(
            scope=scope,
            path=path,
            jti=jti,
            user_email=user_email,
            status_code=status_code,
            response_time_ms=response_time_ms,
            blocked=blocked,
            block_reason=block_reason,
        )

    @classmethod
    def _extract_bearer_token(cls, scope: Scope) -> Optional[str]:
        """Return the token from an ``Authorization: Bearer <token>`` header.

        Only the leading scheme prefix is stripped; the remainder of the header
        value is returned untouched.

        Args:
            scope: ASGI scope dict

        Returns:
            The token string (possibly empty), or None when the header is
            missing or does not start with the Bearer prefix.
        """
        auth_header = Headers(scope=scope).get("authorization")
        if not auth_header or not auth_header.startswith(cls._BEARER_PREFIX):
            return None
        return auth_header[len(cls._BEARER_PREFIX):]

    async def _identify_authenticated_token(self, scope: Scope, receive: Receive, state: dict) -> Optional[tuple[str, str]]:
        """Resolve the JTI and owner email for an authenticated API token request.

        Values already present in the request state are preferred. Anything
        missing is filled in by verifying the Bearer token from the request.

        Args:
            scope: ASGI scope dict
            receive: Receive callable (needed to build a Request for verification)
            state: The ``scope["state"]`` mapping; non-empty here because the
                caller read ``auth_method`` from it

        Returns:
            ``(jti, user_email)`` or None when the identity cannot be determined.
        """
        jti = state.get("jti")
        user = state.get("user")
        user_email = getattr(user, "email", None) if user else None
        if not user_email:
            user_email = state.get("user_email")

        # If we don't have JTI or email, try to decode the token from the header
        if not jti or not user_email:
            try:
                token = self._extract_bearer_token(scope)
                if token is None:
                    return None
                request = Request(scope, receive)
                try:
                    payload = await verify_jwt_token_cached(token, request)
                    jti = jti or payload.get("jti")
                    user_email = user_email or payload.get("sub") or payload.get("email")
                except Exception as decode_error:
                    logger.debug("Failed to decode token for usage logging: %s", decode_error)
                    return None
            except Exception as e:
                logger.debug("Error extracting token information: %s", e)
                return None

        if not jti or not user_email:
            logger.debug("Missing JTI or user_email for token usage logging")
            return None

        return jti, user_email

    def _identify_rejected_token(self, scope: Scope) -> Optional[tuple[str, str]]:
        """Resolve the JTI and owner email for a 401/403-rejected API token attempt.

        When a revoked or expired API token is used, auth middleware rejects the
        request before setting ``auth_method="api_token"``. The attempt is
        detected here by decoding the JWT payload without verifying it, then
        confirming the JTI against the database.

        Args:
            scope: ASGI scope dict

        Returns:
            ``(jti, user_email)`` where ``user_email`` is the DB-stored owner,
            or None when the request did not carry a known API token.
        """
        try:
            raw_token = self._extract_bearer_token(scope)
            if raw_token is None:
                return None

            # Decode without signature/expiry check — for identification only, not auth.
            unverified = _jwt.decode(raw_token, options={"verify_signature": False})
            user_info = unverified.get("user", {})
            # The claims are untrusted: "user" may be absent, null, or not a mapping.
            if not isinstance(user_info, dict) or user_info.get("auth_provider") != "api_token":
                return None  # Not an API token — nothing to log

            jti = unverified.get("jti")
            claimed_email = unverified.get("sub") or unverified.get("email")
            if not jti or not claimed_email:
                return None

            # Verify JTI belongs to a real API token and use the DB-stored
            # owner email instead of the unverified JWT claim. Without this,
            # an attacker could craft a JWT with a fake jti/sub and
            # auth_provider=api_token to pollute usage logs, or forge a JWT
            # carrying a known JTI with an arbitrary sub/email to poison
            # another user's usage stats.
            try:
                # Third-Party
                from sqlalchemy import select  # pylint: disable=import-outside-toplevel

                # First-Party
                from mcpgateway.db import EmailApiToken  # pylint: disable=import-outside-toplevel

                with fresh_db_session() as verify_db:
                    token_row = verify_db.execute(select(EmailApiToken.id, EmailApiToken.user_email).where(EmailApiToken.jti == jti)).first()
                    if token_row is None:
                        return None  # JTI not in DB — forged token, skip logging
                    # Use the DB-stored owner, not the unverified JWT claim
                    return jti, token_row.user_email
            except Exception as db_error:
                # DB error — skip logging rather than log unverified data
                logger.debug("Failed to verify API token JTI for usage logging: %s", db_error)
                return None
        except Exception as e:
            logger.debug("Failed to extract API token identity from rejected request: %s", e)
            return None

    async def _record_usage(
        self,
        *,
        scope: Scope,
        path: str,
        jti: str,
        user_email: str,
        status_code: int,
        response_time_ms: int,
        blocked: bool,
        block_reason: Optional[str],
    ) -> None:
        """Persist one token usage record; failures are logged and swallowed.

        Args:
            scope: ASGI scope dict (source of client address, method and user agent)
            path: Request path that was accessed
            jti: Token identifier
            user_email: Owner of the token
            status_code: HTTP status code of the response
            response_time_ms: Request duration in milliseconds
            blocked: Whether the attempt was denied
            block_reason: Short reason for the denial, or None
        """
        try:
            with fresh_db_session() as db:
                token_service = TokenCatalogService(db)
                client = scope.get("client")
                ip_address = client[0] if client else None
                user_agent = Headers(scope=scope).get("user-agent")

                await token_service.log_token_usage(
                    jti=jti,
                    user_email=user_email,
                    endpoint=path,
                    method=scope.get("method", "GET"),
                    ip_address=ip_address,
                    user_agent=user_agent,
                    status_code=status_code,
                    response_time_ms=response_time_ms,
                    blocked=blocked,
                    block_reason=block_reason,
                )
        except Exception as e:
            logger.debug("Failed to log token usage: %s", e)