Source code for oso.framework.auth.extension

#
# (c) Copyright IBM Corp. 2025
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Authentication Flask Extension."""


from collections.abc import Iterable
from functools import wraps
from typing import Any, Callable, ClassVar

from flask import Flask, current_app, g, request
from werkzeug.datastructures import ImmutableDict
from werkzeug.exceptions import Forbidden, InternalServerError, Unauthorized

from oso.framework.core.logging import get_logger
from oso.framework.exceptions import StartupException

from .common import ALLOWLIST, EXT_NAME, AuthConfig, AuthResult


class AuthExtension:
    """Authentication Extension.

    An extension to manage authentication states for `flask.Flask`
    applications.

    Attributes
    ----------
    NAME : str
        The literal ``oso-auth``.

    Parameters
    ----------
    config : `.common.AuthConfig`
        Configuration with all parsers defined.
    """

    NAME: ClassVar[str] = EXT_NAME

    def __init__(self, config: AuthConfig):
        self.logger = get_logger(EXT_NAME)
        self.config = config
        # Parser classes keyed by their ``NAME``; if two configured parsers
        # share a name, the later one replaces the earlier one.
        self.parsers = {
            parser.type.NAME: parser.type for parser in self.config.parsers
        }
        # Parsed allowlists, keyed first by parser name, then by allowlist name.
        # NOTE(review): allowlist names are stored exactly as configured, while
        # ``_get_allowlist`` looks them up with ``key.lower()`` — confirm that
        # AuthConfig normalises the names, otherwise mixed-case names are
        # unreachable.
        setattr(
            self,
            ALLOWLIST,
            ImmutableDict(
                {
                    parser.type.NAME: {
                        k: parser.type.parse_allowlist(allowlist)
                        for k, allowlist in parser.allowlist.items()
                    }
                    for parser in self.config.parsers
                }
            ),
        )

    def init_app(self, app: Flask) -> None:
        """Attach to a `flask.Flask` application.

        Registers `parse_request_auth` as a ``before_request`` hook and stores
        this instance in ``app.extensions`` so `current_auth_ext` can find it.

        Parameters
        ----------
        app : `flask.Flask`
            Application.

        Raises
        ------
        `oso.framework.exceptions.StartupException`
            If an instance of this extension is already attached to ``app``.
        """
        # Setup basic application
        app.extensions = getattr(app, "extensions", {})
        ext_state = app.extensions.setdefault(EXT_NAME, {})
        if "self" in ext_state:
            raise StartupException("Plugin already initialized")
        app.before_request(self.parse_request_auth)
        ext_state["self"] = self
        self.logger.info(
            f"Authentication Loaded with: [ {', '.join(self.parsers)} ]"
        )

    def parse_request_auth(self) -> None:
        """Authenticate a `flask.Request` with all parsers.

        Each parser's result is stored, keyed by parser name, in an immutable
        mapping set on ``flask.g`` under the extension name.
        """
        results = ImmutableDict(
            {name: parser.parse(request) for name, parser in self.parsers.items()}
        )
        self._log_results(results)
        setattr(g, EXT_NAME, results)

    def _log_results(self, results: ImmutableDict[str, AuthResult]) -> None:
        """Log, at debug level, each parser's outcome prefixed by its name."""
        for name, result in results.items():
            # Compute the outcome first: in the previous single expression the
            # conditional bound looser than ``+``, which dropped the parser
            # name from messages about failed authentication.
            outcome = (
                "Authorized" if result["authorized"] else ", ".join(result["errors"])
            )
            self.logger.debug(f"{name}: {outcome}")
def current_auth_ext() -> AuthExtension:
    """Return the authentication extension bound to the active application.

    Returns
    -------
    `.AuthExtension`
        The instance that `AuthExtension.init_app` stored under the
        extension's entry in ``current_app.extensions``.
    """
    ext_state = current_app.extensions[EXT_NAME]
    return ext_state["self"]
def _raise_on_unauthorized(handler_name: str) -> None:
    """Check if the user is authorized under an `AuthParser`.

    Parameters
    ----------
    handler_name : str
        Parser's name.

    Raises
    ------
    `werkzeug.exceptions.Unauthorized`
        If the parser did not mark the request as authorized, or if no result
        for ``handler_name`` was recorded on ``flask.g``. Returns a
        ``HTTP 401: Unauthorized`` response.
    `werkzeug.exceptions.InternalServerError`
        If reading the recorded result fails unexpectedly. Should never happen.
    """
    try:
        authorized = getattr(g, EXT_NAME)[handler_name]["authorized"]
    except (KeyError, AttributeError):
        # There is a missing key, so authentication parsers were not successful.
        # Handle this like the user was not authorized.
        authorized = False
    except Exception as exc:
        # Keep the original error as the cause so the 500 is debuggable.
        raise InternalServerError() from exc
    # Only an explicit ``True`` counts as authorized.
    if authorized is not True:
        raise Unauthorized()


def _get_user(handler_name: str) -> str | None:
    """Return the user from a `~flask.Request`.

    Parameters
    ----------
    handler_name : str
        The `AuthParser`'s name.

    Returns
    -------
    str | None
        The ``_user`` entry recorded by the parser for this request, or
        ``None`` if no result or no user was recorded.
    """
    try:
        return getattr(g, EXT_NAME, {})[handler_name]["_user"]
    except (KeyError, AttributeError):
        return None


def _get_allowlist(handler_name: str, key: str) -> Iterable:
    """Return the allowlist for an `AuthParser`.

    Parameters
    ----------
    handler_name : str
        The `AuthParser`'s name.
    key : str
        The allowlist's name; it is lower-cased before lookup.

    Returns
    -------
    `collections.abc.Iterable`
        The parsed allowlist, or an empty list if the parser or the allowlist
        is unknown.
    """
    try:
        return getattr(current_auth_ext(), ALLOWLIST)[handler_name][key.lower()]
    except (KeyError, AttributeError):
        return []
def RequireAuth(handler_name: str, allowlist: str, *allowlists: str) -> Callable:
    """Mark an endpoint as requiring authentication.

    The request must first be authorized by the ``handler_name`` parser. The
    parser's user must then appear in at least one of the named allowlists.

    Parameters
    ----------
    handler_name : str
        The `AuthParser` that is required.
    allowlist : str
        The first allowlist that is allowed for the endpoint.
    *allowlists : str
        Additional allowlists that are allowed for the endpoint.

    Returns
    -------
    `~typing.Callable`
        A decorator that wraps the endpoint function.

    Raises
    ------
    `werkzeug.exceptions.Unauthorized`
        Raised by the wrapped endpoint if the parser did not authorize the
        request.
    `werkzeug.exceptions.Forbidden`
        Raised by the wrapped endpoint if the user is in none of the
        allowlists.
    """
    allowlist_names = (allowlist, *allowlists)

    def decorator(f: Callable) -> Callable:
        @wraps(f)
        def decorated(*args, **kwargs) -> Any:
            _raise_on_unauthorized(handler_name)
            # The user does not depend on which allowlist is being checked,
            # so resolve it once instead of on every iteration.
            user = _get_user(handler_name)
            for key in allowlist_names:
                if user in _get_allowlist(handler_name, key):
                    # User is allowed to utilize this path
                    return f(*args, **kwargs)
            # User is forbidden from utilizing this path.
            raise Forbidden()

        return decorated

    return decorator