# Source code for flask_ligand.extensions.jwt
"""JWT authentication and authorization."""
# ======================================================================================================================
# Imports
# ======================================================================================================================
from __future__ import annotations
import json
from requests import get
from http import HTTPStatus
from functools import wraps
from flask import current_app
from typing import TYPE_CHECKING
from dataclasses import dataclass
from jwt.algorithms import RSAAlgorithm
from flask_ligand.extensions.api import abort
from requests.exceptions import RequestException
from flask_jwt_extended import verify_jwt_in_request, get_current_user, JWTManager
# ======================================================================================================================
# Type Checking
# ======================================================================================================================
if TYPE_CHECKING: # pragma: no cover
from flask import Flask
from typing import Callable, Any
# ======================================================================================================================
# Globals
# ======================================================================================================================
# Module-level JWTManager shared by this extension: the user lookup callback below is registered on it, and
# init_app() binds it to the Flask application.
JWT = JWTManager()
# ======================================================================================================================
# Classes: Public
# ======================================================================================================================
@dataclass
class User:
    """
    A simple class representing pertinent user information.

    Args:
        id: The UUID of the user.
        roles: A list of roles that the user has been assigned.
    """

    id: str
    roles: list[str]
# ======================================================================================================================
# Decorators: Public
# ======================================================================================================================
def jwt_role_required(role: str):  # type: ignore
    """A decorator for restricting access to an endpoint based on role membership.

    Note: This decorator style was chosen because of: https://stackoverflow.com/a/42581103

    The wrapped endpoint runs only after three checks pass, in this order:

    1. ``verify_jwt_in_request()`` accepts the request's JWT.
    2. ``role`` is listed in the application's ``ALLOWED_ROLES`` setting; otherwise ``abort`` is called with
       HTTP 500, since this indicates a misconfigured endpoint rather than a bad request.
    3. ``role`` is among the current user's roles; otherwise ``abort`` is called with HTTP 403.

    Args:
        role: The role membership required by the user in order to access this endpoint.

    Returns:
        A decorator which wraps an endpoint function with the checks described above.
    """

    # The decorated endpoint may take any arguments, hence 'Callable[..., Any]' rather than a one-argument callable
    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # standard flask_jwt_extended token verifications
            verify_jwt_in_request()

            # guard against an endpoint being declared with a role the application does not recognise
            if role not in current_app.config["ALLOWED_ROLES"]:
                abort(HTTPStatus.INTERNAL_SERVER_ERROR, message="Endpoint required role is not an allowed role!")

            # custom role membership verification
            user: User = get_current_user()

            if role not in user.roles:
                abort(
                    HTTPStatus.FORBIDDEN,
                    message=f"This endpoint requires the user to have the '{role}' role!",
                )

            return fn(*args, **kwargs)

        return wrapper

    return decorator
# ======================================================================================================================
# Functions: Callbacks
# ======================================================================================================================
@JWT.user_lookup_loader
def user_lookup_callback(_jwt_header: dict[str, Any], jwt_data: dict[str, Any]) -> User:
    """This callback function is used to convert a JWT into a python object that can be used in a protected endpoint.

    Note: https://flask-jwt-extended.readthedocs.io/en/stable/api/#flask_jwt_extended.JWTManager.user_lookup_loader

    Args:
        _jwt_header: Header data of the JWT. (Unused argument)
        jwt_data: Payload data of the JWT.

    Returns:
        A ``User`` whose ``id`` is the token's ``sub`` claim and whose ``roles`` come from the token's
        ``realm_access.roles`` claim. The role list is empty when the token carries no such claim.
    """
    # A token without a 'realm_access' claim (or without 'roles' inside it) describes a user with no roles; treat
    # it as such so that role checks reject the request instead of this callback failing with a KeyError.
    realm_access = jwt_data.get("realm_access") or {}
    roles = realm_access.get("roles") or []

    return User(
        id=jwt_data["sub"],
        roles=roles,
    )
# ======================================================================================================================
# Functions: Public
# ======================================================================================================================
def _select_signing_jwk(jwks: dict[str, Any]) -> dict[str, Any]:
    """Pick the JWK used to verify token signatures from a JWKS document."""
    keys = jwks["keys"]

    for jwk in keys:
        # 'use' is an optional JWK member; only an explicit non-'sig' value (e.g. 'enc') disqualifies a key
        if jwk.get("kty") == "RSA" and jwk.get("use", "sig") == "sig":
            return jwk

    # No key advertises itself as an RSA signing key: fall back to the first entry
    return keys[0]


def init_app(app: Flask) -> None:  # pragma: no cover (Covered by integration tests)
    """Initialize JWT.

    Downloads the OIDC discovery document from the ``OIDC_DISCOVERY_URL`` setting, follows its ``jwks_uri`` to the
    key set, stores the RSA public key built from the selected JWK under ``JWT_PUBLIC_KEY`` and finally binds the
    module-level ``JWT`` manager to ``app``.

    Args:
        app: The Flask application to configure. Reads the ``VERIFY_SSL_CERT`` and ``OIDC_DISCOVERY_URL`` settings,
            plus the optional ``OIDC_REQUEST_TIMEOUT`` setting (seconds, defaults to 10).

    Raises:
        RuntimeError: If either HTTP request fails or returns an error status, or if the responses do not contain
            the expected ``jwks_uri`` / ``keys`` data.
    """
    verify_ssl_cert = app.config["VERIFY_SSL_CERT"]
    discovery_url = app.config["OIDC_DISCOVERY_URL"]
    # Without a timeout an unresponsive identity provider would block application start-up indefinitely
    timeout = app.config.get("OIDC_REQUEST_TIMEOUT", 10)

    try:
        # Retrieve master openid-configuration endpoint from issuer realm
        oidc_config_response = get(discovery_url, verify=verify_ssl_cert, timeout=timeout)
        oidc_config_response.raise_for_status()
        jwks_uri = oidc_config_response.json()["jwks_uri"]

        # Retrieve data from jwks_uri endpoint
        jwks_response = get(jwks_uri, verify=verify_ssl_cert, timeout=timeout)
        jwks_response.raise_for_status()

        signing_jwk = _select_signing_jwk(jwks_response.json())
    except (RequestException, ValueError, KeyError, IndexError, TypeError, AttributeError) as err:
        # ValueError covers invalid JSON bodies on Requests versions whose decode error is not a RequestException;
        # the remaining types cover responses that are valid JSON but do not have the expected structure.
        raise RuntimeError(
            f"Failed to retrieve public key from the '{discovery_url}' OIDC Discovery URL!"
        ) from err

    # Construct the RSA public key from the selected jwk entry
    app.config["JWT_PUBLIC_KEY"] = RSAAlgorithm.from_jwk(json.dumps(signing_jwk))  # type: ignore

    JWT.init_app(app)