Coverage for mcpgateway / utils / token_scoping.py: 100%
30 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/utils/token_scoping.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Mihai Criveti
7Token scoping utilities for extracting and validating token scopes.
8"""
10# Standard
11from typing import Optional
13# Third-Party
14from fastapi import HTTPException, Request
16# First-Party
17from mcpgateway.utils.verify_credentials import verify_jwt_token_cached
async def extract_token_scopes_from_request(request: Request) -> Optional[dict]:
    """Extract token scopes from the Bearer JWT carried by a request.

    The ``Authorization`` header is split into ``<scheme> <credentials>`` at
    the first space. The scheme is compared case-insensitively (RFC 7235
    section 2.1 defines the auth-scheme as a case-insensitive token), and
    whitespace around the credentials is ignored. The remaining token is
    validated with ``verify_jwt_token_cached`` and the ``scopes`` claim of the
    decoded payload is returned.

    Args:
        request: FastAPI request object

    Returns:
        The token's ``scopes`` claim, or None in any of these cases:
        - the header is missing;
        - the header is not a Bearer credential;
        - the token is empty;
        - the token fails validation;
        - the payload has no ``scopes`` claim.

    Note:
        NOTE(review): a missing token and an invalid token both yield None.
        ``validate_server_access`` in this module treats None scopes as
        unrestricted access. Callers presumably rely on a separate
        authentication step to reject invalid tokens; confirm at call sites.

    Examples:
        >>> # Test with no authorization header
        >>> from unittest.mock import Mock
        >>> import asyncio
        >>> mock_request = Mock()
        >>> mock_request.headers = {}
        >>> asyncio.run(extract_token_scopes_from_request(mock_request)) is None
        True
        >>>
        >>> # Test with invalid authorization header
        >>> mock_request = Mock()
        >>> mock_request.headers = {"Authorization": "Invalid token"}
        >>> asyncio.run(extract_token_scopes_from_request(mock_request)) is None
        True
        >>>
        >>> # Test with malformed Bearer token
        >>> mock_request = Mock()
        >>> mock_request.headers = {"Authorization": "Bearer"}
        >>> asyncio.run(extract_token_scopes_from_request(mock_request)) is None
        True
        >>>
        >>> # Test with Bearer but no space
        >>> mock_request = Mock()
        >>> mock_request.headers = {"Authorization": "Bearer123"}
        >>> asyncio.run(extract_token_scopes_from_request(mock_request)) is None
        True
        >>>
        >>> # Test with whitespace-only Bearer token
        >>> mock_request = Mock()
        >>> mock_request.headers = {"Authorization": "Bearer    "}
        >>> asyncio.run(extract_token_scopes_from_request(mock_request)) is None
        True
    """
    auth_header = request.headers.get("Authorization")
    if not auth_header:
        return None

    # "Bearer123" yields scheme "Bearer123" and is rejected below, as before.
    scheme, _, credentials = auth_header.partition(" ")
    if scheme.lower() != "bearer":
        return None

    token = credentials.strip()
    if not token:
        # "Bearer" / "Bearer   " carry no token; skip the verification call.
        return None

    try:
        # Use the centralized verify_jwt_token_cached function for consistent JWT validation
        payload = await verify_jwt_token_cached(token, request)
        return payload.get("scopes")
    except HTTPException:
        # Token validation failed (expired, invalid, etc.)
        return None
    except Exception:
        # Deliberate best effort: any other failure (including a payload
        # without .get) means "no usable scopes".
        return None
def is_token_server_scoped(scopes: Optional[dict]) -> bool:
    """Report whether token scopes restrict the token to a single server.

    A token is considered server-scoped when its scopes mapping is non-empty
    and its ``server_id`` entry is present and not None.

    Args:
        scopes: Token scopes dictionary (may be None or empty).

    Returns:
        bool: True if a ``server_id`` restriction is set, False otherwise.

    Examples:
        >>> is_token_server_scoped({"server_id": "server-123", "permissions": ["tools.read"]})
        True
        >>> is_token_server_scoped({"server_id": None, "permissions": ["*"]})
        False
        >>> is_token_server_scoped(None)
        False
    """
    # None or an empty mapping carries no server restriction at all.
    has_scopes = bool(scopes)
    return has_scopes and scopes.get("server_id") is not None
def get_token_server_id(scopes: Optional[dict]) -> Optional[str]:
    """Return the server ID a token is restricted to, if any.

    Args:
        scopes: Token scopes dictionary (may be None or empty).

    Returns:
        Optional[str]: The ``server_id`` entry when ``scopes`` is non-empty.
        Returns None when ``scopes`` is None or empty, or has no
        ``server_id`` entry.

    Examples:
        >>> get_token_server_id({"server_id": "server-123", "permissions": ["tools.read"]})
        'server-123'
        >>> get_token_server_id({"server_id": None, "permissions": ["*"]}) is None
        True
        >>> get_token_server_id(None) is None
        True
    """
    return scopes.get("server_id") if scopes else None
def validate_server_access(scopes: Optional[dict], requested_server_id: str) -> bool:
    """Decide whether token scopes permit access to a given server.

    Access is granted in each of these cases:
    - ``scopes`` is None or empty (legacy, unscoped tokens);
    - ``scopes`` has no ``server_id`` restriction (global tokens);
    - the restricted ``server_id`` equals ``requested_server_id``.

    Args:
        scopes: Token scopes dictionary (may be None or empty).
        requested_server_id: ID of the server being accessed.

    Returns:
        bool: True if access is allowed.

    Examples:
        >>> scoped = {"server_id": "server-123", "permissions": ["tools.read"]}
        >>> validate_server_access(scoped, "server-123")
        True
        >>> validate_server_access(scoped, "server-456")
        False
        >>> validate_server_access({"server_id": None, "permissions": ["*"]}, "any-server")
        True
        >>> validate_server_access(None, "any-server")
        True
    """
    # Missing/empty scopes and a None server_id both mean "not server-restricted".
    allowed_server = scopes.get("server_id") if scopes else None
    if allowed_server is None:
        return True
    return allowed_server == requested_server_id