Coverage for mcpgateway / middleware / token_usage_middleware.py: 100%
108 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/middleware/token_usage_middleware.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Token Usage Logging Middleware.
9This middleware logs API token usage for analytics and security monitoring.
10It records each request made with an API token, including endpoint, method,
11response time, and status code.
13Note: Implemented as raw ASGI middleware (not BaseHTTPMiddleware) to avoid
14response body buffering issues with streaming responses.
16Examples:
17 >>> from mcpgateway.middleware.token_usage_middleware import TokenUsageMiddleware # doctest: +SKIP
18 >>> app.add_middleware(TokenUsageMiddleware) # doctest: +SKIP
19"""
21# Standard
22import logging
23import time
24from typing import Optional
26# Third-Party
27import jwt as _jwt
28from starlette.datastructures import Headers
29from starlette.requests import Request
30from starlette.types import ASGIApp, Receive, Scope, Send
32# First-Party
33from mcpgateway.db import fresh_db_session
34from mcpgateway.middleware.path_filter import should_skip_auth_context
35from mcpgateway.services.token_catalog_service import TokenCatalogService
36from mcpgateway.utils.verify_credentials import verify_jwt_token_cached
logger = logging.getLogger(__name__)


class TokenUsageMiddleware:
    """Raw ASGI middleware for logging API token usage.

    This middleware tracks when API tokens are used, recording details like:
    - Endpoint accessed
    - HTTP method
    - Response status code
    - Response time
    - Client IP and user agent

    This data is used for security auditing, usage analytics, and detecting
    anomalous token usage patterns.

    Note:
        Two kinds of HTTP requests are recorded after the wrapped app finishes:

        - Requests authenticated with an API token
          (``scope["state"]["auth_method"] == "api_token"``). Responses with a
          4xx status are marked as blocked.
        - Requests answered with 401/403 whose bearer token claims
          ``user.auth_provider == "api_token"`` and whose ``jti`` exists in the
          ``EmailApiToken`` table. These are always marked as blocked.

        Everything else (non-HTTP scopes, paths exempted by
        ``should_skip_auth_context``, other auth methods) passes through
        without being logged. Logging is best-effort: failures are reported at
        DEBUG level and never alter the response. If the wrapped app raises,
        the exception propagates and nothing is recorded.

        Implemented as raw ASGI middleware to avoid BaseHTTPMiddleware issues:
        - BaseHTTPMiddleware buffers entire response bodies (problematic for streaming)
        - Raw ASGI middleware streams responses efficiently
    """

    # Authorization scheme prefix expected on the header (matched case-sensitively).
    _BEARER_PREFIX = "Bearer "

    def __init__(self, app: ASGIApp) -> None:
        """Initialize middleware.

        Args:
            app: ASGI application to wrap
        """
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Process an ASGI request and record API token usage once it completes.

        Args:
            scope: ASGI scope dict
            receive: Receive callable
            send: Send callable
        """
        # Only HTTP requests are tracked; other scope types pass straight through.
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        # Skip paths exempt from auth context (e.g. health checks and static files).
        path = scope.get("path", "")
        if should_skip_auth_context(path):
            await self.app(scope, receive, send)
            return

        # perf_counter() is monotonic, so the measured duration is not skewed by
        # system clock adjustments the way time.time() deltas can be.
        start = time.perf_counter()

        # Stays at 200 if the app never sends "http.response.start".
        status_code = 200

        async def send_wrapper(message: dict) -> None:
            """Forward ``message`` to the server, capturing the response status.

            Args:
                message: ASGI message dict containing response data
            """
            nonlocal status_code
            if message["type"] == "http.response.start":
                status_code = message["status"]
            await send(message)

        await self.app(scope, receive, send_wrapper)
        response_time_ms = round((time.perf_counter() - start) * 1000)

        # Every request that uses (or tries to use) an API token is recorded,
        # including blocked calls with revoked/expired tokens, so usage stats are accurate.
        state = scope.get("state") or {}
        auth_method = state.get("auth_method")

        if auth_method == "api_token":
            identity = await self._identify_authenticated(scope, receive, state)
            if identity is None:
                return
            jti, user_email = identity
            # Reflect the actual outcome: 4xx responses mark the attempt as blocked
            # (e.g. RBAC denied, rate-limited, or server-scoping violation).
            # 5xx errors are backend failures, not security denials, so exclude them.
            blocked = 400 <= status_code < 500
            block_reason = f"http_{status_code}" if blocked else None
        elif status_code in (401, 403):
            # When a revoked or expired API token is used, auth middleware rejects the
            # request before setting auth_method="api_token", so the branch above is
            # never reached. Detect such attempts from the bearer token itself.
            identity = self._identify_rejected(scope)
            if identity is None:
                return
            jti, user_email = identity
            blocked = True
            block_reason = "revoked_or_expired" if status_code == 401 else f"http_{status_code}"
        else:
            return  # Not an API token request — nothing to log

        await self._record_usage(
            scope,
            path=path,
            jti=jti,
            user_email=user_email,
            status_code=status_code,
            response_time_ms=response_time_ms,
            blocked=blocked,
            block_reason=block_reason,
        )

    @classmethod
    def _bearer_token(cls, scope: Scope) -> Optional[str]:
        """Return the token following the ``Bearer `` prefix of the Authorization header.

        Only the leading prefix is removed (not every occurrence in the header).

        Args:
            scope: ASGI scope dict

        Returns:
            The raw token string (possibly empty), or None when the header is
            absent or does not start with the bearer prefix.
        """
        auth_header = Headers(scope=scope).get("authorization")
        if not auth_header or not auth_header.startswith(cls._BEARER_PREFIX):
            return None
        return auth_header[len(cls._BEARER_PREFIX):]

    @classmethod
    async def _identify_authenticated(cls, scope: Scope, receive: Receive, state: dict) -> Optional[tuple]:
        """Resolve ``(jti, user_email)`` for a request authenticated with an API token.

        Values already stored in request state by the auth layer are preferred.
        Missing ones are filled in by verifying the bearer token with
        ``verify_jwt_token_cached``.

        Args:
            scope: ASGI scope dict
            receive: Receive callable (needed to build a Request for token verification)
            state: the ``scope["state"]`` mapping

        Returns:
            ``(jti, user_email)``, or None when either cannot be determined.
        """
        jti = state.get("jti")
        user = state.get("user")
        user_email = getattr(user, "email", None) if user else None
        if not user_email:
            user_email = state.get("user_email")

        if not jti or not user_email:
            try:
                token = cls._bearer_token(scope)
                if token is None:
                    return None
                request = Request(scope, receive)
            except Exception as e:
                logger.debug("Error extracting token information: %s", e)
                return None
            try:
                payload = await verify_jwt_token_cached(token, request)
                jti = jti or payload.get("jti")
                user_email = user_email or payload.get("sub") or payload.get("email")
            except Exception as decode_error:
                logger.debug("Failed to decode token for usage logging: %s", decode_error)
                return None

        if not jti or not user_email:
            logger.debug("Missing JTI or user_email for token usage logging")
            return None
        return jti, user_email

    @classmethod
    def _identify_rejected(cls, scope: Scope) -> Optional[tuple]:
        """Resolve ``(jti, owner_email)`` for a 401/403 response that carried an API token.

        The bearer token is decoded WITHOUT signature or expiry verification, purely
        to identify it, never to authenticate it. Two checks stop a crafted JWT
        (fake jti/sub with ``auth_provider=api_token``) from polluting usage logs:
        the jti must exist in ``EmailApiToken``, and the owner email is taken from
        that DB row rather than from the unverified claims.

        Args:
            scope: ASGI scope dict

        Returns:
            ``(jti, owner_email)``, or None when the request cannot be attributed
            to a known API token (no bearer token, non-API-token JWT, missing
            claims, unknown jti, decode error or DB error).
        """
        try:
            raw_token = cls._bearer_token(scope)
            if raw_token is None:
                return None
            # Decode without signature/expiry check — for identification only, not auth.
            unverified = _jwt.decode(raw_token, options={"verify_signature": False})
            user_info = unverified.get("user", {})
            if user_info.get("auth_provider") != "api_token":
                return None  # Not an API token — nothing to log
            jti = unverified.get("jti")
            claimed_email = unverified.get("sub") or unverified.get("email")
            if not jti or not claimed_email:
                return None
        except Exception as e:
            logger.debug("Failed to extract API token identity from rejected request: %s", e)
            return None

        try:
            # Third-Party
            from sqlalchemy import select  # pylint: disable=import-outside-toplevel

            # First-Party
            from mcpgateway.db import EmailApiToken  # pylint: disable=import-outside-toplevel

            # NOTE(review): this is a synchronous DB query inside the async request
            # path; consider offloading it if it shows up in latency profiles.
            with fresh_db_session() as verify_db:
                stmt = select(EmailApiToken.id, EmailApiToken.user_email).where(EmailApiToken.jti == jti)
                token_row = verify_db.execute(stmt).first()
            if token_row is None:
                return None  # JTI not in DB — forged or unknown token, skip logging
            # Use the DB-stored owner, not the unverified JWT claim.
            owner_email = token_row.user_email
        except Exception as e:
            # DB error — skip logging rather than record unverified data.
            logger.debug("Failed to verify API token JTI for usage logging: %s", e)
            return None

        return jti, owner_email

    @staticmethod
    async def _record_usage(
        scope: Scope,
        *,
        path: str,
        jti: str,
        user_email: str,
        status_code: int,
        response_time_ms: int,
        blocked: bool,
        block_reason: Optional[str],
    ) -> None:
        """Persist one usage record through ``TokenCatalogService.log_token_usage``.

        Best-effort: any exception is logged at DEBUG level and swallowed so that
        usage logging can never affect request handling.

        Args:
            scope: ASGI scope dict (source of client address, method and user agent)
            path: request path that was served
            jti: token identifier
            user_email: owner of the token
            status_code: response status sent by the app
            response_time_ms: elapsed handling time in milliseconds
            blocked: whether the attempt counts as blocked
            block_reason: short reason code when blocked, else None
        """
        try:
            client = scope.get("client")
            ip_address = client[0] if client else None
            user_agent = Headers(scope=scope).get("user-agent")
            with fresh_db_session() as db:
                await TokenCatalogService(db).log_token_usage(
                    jti=jti,
                    user_email=user_email,
                    endpoint=path,
                    method=scope.get("method", "GET"),
                    ip_address=ip_address,
                    user_agent=user_agent,
                    status_code=status_code,
                    response_time_ms=response_time_ms,
                    blocked=blocked,
                    block_reason=block_reason,
                )
        except Exception as e:
            logger.debug("Failed to log token usage: %s", e)