Module supertokens_python.recipe.session.jwks

Expand source code
# Copyright (c) 2023, VRAI Labs and/or its affiliates. All rights reserved.
#
# This software is licensed under the Apache License, Version 2.0 (the
# "License") as published by the Apache Software Foundation.
#
# You may not use this file except in compliance with the License. You may
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

import requests
from os import environ
from typing import List, Optional
from typing_extensions import TypedDict

from jwt import PyJWK, PyJWKSet

from supertokens_python.recipe.session.utils import SessionConfig
from supertokens_python.utils import RWMutex, RWLockContext, get_timestamp_ms
from supertokens_python.querier import Querier
from supertokens_python.logger import log_debug_message


class JWKSConfigType(TypedDict):
    """Shape of the module-level JWKS settings dictionary."""

    # Timeout for the JWKS HTTP request, in milliseconds: get_latest_keys
    # divides this value by 1000 before passing it to requests.get.
    request_timeout: int


# Module-wide JWKS settings; get_latest_keys reads "request_timeout"
# (milliseconds) every time it fetches the key set from a core.
JWKSConfig: JWKSConfigType = {
    "request_timeout": 10000,  # 10s
}


class CachedKeys:
    """A fetched list of JWKs together with the time it was stored.

    The entry counts as fresh while fewer than ``refresh_interval_sec``
    seconds (compared in milliseconds) have passed since construction.
    """

    def __init__(self, keys: List[PyJWK], refresh_interval_sec: int):
        self.refresh_interval_sec = refresh_interval_sec
        self.keys = keys
        # Stamp the entry at construction time; is_fresh() measures age from here.
        self.last_refresh_time = get_timestamp_ms()

    def is_fresh(self):
        """Return True while the entry is younger than the refresh interval."""
        age_ms = get_timestamp_ms() - self.last_refresh_time
        max_age_ms = self.refresh_interval_sec * 1000
        return age_ms < max_age_ms


# Most recently fetched JWKS entry; None until a fetch succeeds or after
# reset_jwks_cache() clears it.
cached_keys: Optional[CachedKeys] = None
# Lock used (through RWLockContext) around reads and writes of cached_keys.
mutex = RWMutex()


# only for testing purposes
def reset_jwks_cache():
    """Drop the cached JWKS entry (intended for tests only).

    The module-level ``cached_keys`` is set back to ``None`` while holding
    ``mutex`` via ``RWLockContext(..., read=False)``.
    """
    global cached_keys

    with RWLockContext(mutex, read=False):
        cached_keys = None


def get_cached_keys() -> Optional[List[PyJWK]]:
    """Return the cached keys if an unexpired cache entry exists, else None.

    A fresh entry is returned as-is, which also means the SDK will not query
    any other core (when several are configured) until the entry expires.
    """
    if cached_keys is None:
        # Nothing has been fetched yet (or the cache was reset).
        return None

    # Only hand out the keys while the entry is within its refresh interval;
    # an expired entry is reported as a miss so the caller fetches again.
    return cached_keys.keys if cached_keys.is_fresh() else None


def find_matching_keys(
    keys: Optional[List[PyJWK]], kid: Optional[str]
) -> Optional[List[PyJWK]]:
    """Narrow ``keys`` down to the entries whose ``key_id`` equals ``kid``.

    Args:
        keys: Candidate keys, or None when no keys are available.
        kid: Key id taken from the token, or None when the token has none.

    Returns:
        ``keys`` unchanged (possibly None) when either argument is None,
        the non-empty list of keys whose ``key_id`` matches ``kid``, or
        None when no key matches.
    """
    if keys is None or kid is None:
        # Without a kid there is nothing to filter on, so every key is a candidate.
        return keys

    selected = list(filter(lambda jwk: jwk.key_id == kid, keys))  # type: ignore
    # An empty selection is reported as None so callers can tell "unknown kid".
    return selected if selected else None


def get_latest_keys(config: SessionConfig, kid: Optional[str] = None) -> List[PyJWK]:
    """Return the JWKs matching ``kid``, from the cache or freshly fetched.

    The cache is consulted first; on a miss (no fresh entry, or no key with
    the requested ``kid``) each core JWKS URL is tried in order until one
    yields a key set, which then replaces the cached entry.

    Args:
        config: Session config; ``jwks_refresh_interval_sec`` is stored with
            the newly cached keys.
        kid: Key id to look for. When None, all keys are returned.

    Returns:
        The matching keys (all keys when ``kid`` is None).

    Raises:
        Exception: When no core URL is available, when the fetched key set
            contains no key for ``kid``, or when every fetch attempt failed
            (in which case the last fetch error is raised).
    """
    global cached_keys

    if environ.get("SUPERTOKENS_ENV") == "testing":
        log_debug_message("Called find_jwk_client")

    # First attempt: serve from the cache while holding the lock with read=True.
    with RWLockContext(mutex, read=True):
        from_cache = find_matching_keys(get_cached_keys(), kid)
        if from_cache is not None:
            if environ.get("SUPERTOKENS_ENV") == "testing":
                log_debug_message("Returning JWKS from cache")
            return from_cache
        # Cache miss or unknown kid: fall through and reload the keys.

    jwks_urls = Querier.get_instance().get_all_core_urls_for_path(
        "./.well-known/jwks.json"
    )
    if len(jwks_urls) == 0:
        raise Exception(
            "No SuperTokens core available to query. Please pass supertokens > connection_uri to the init function, or override all the functions of the recipe you are using."
        )

    # Raised at the end if none of the URLs yields a usable key set.
    last_error: Exception = Exception("No valid JWKS found")

    with RWLockContext(mutex, read=False):
        # Look at the cache once more: another thread might have fetched the
        # keys while this one was waiting for the lock.
        from_cache = find_matching_keys(get_cached_keys(), kid)
        if from_cache is not None:
            return from_cache

        for jwks_url in jwks_urls:
            if environ.get("SUPERTOKENS_ENV") == "testing":
                log_debug_message("Attempting to fetch JWKS from path: %s", jwks_url)

            fetched_keys: Optional[List[PyJWK]] = None
            try:
                log_debug_message("Fetching jwk set from the configured uri")
                # request_timeout is configured in milliseconds; requests expects seconds.
                timeout_sec = JWKSConfig["request_timeout"] / 1000
                with requests.get(jwks_url, timeout=timeout_sec) as response:
                    response.raise_for_status()
                    fetched_keys = PyJWKSet.from_dict(response.json()).keys  # type: ignore
            except Exception as err:
                # Remember the failure and move on to the next core URL.
                last_error = err

            if fetched_keys is None:
                continue

            # A key set was fetched: it replaces the cache entry, and no
            # further URLs are tried regardless of whether ``kid`` matches.
            cached_keys = CachedKeys(fetched_keys, config.jwks_refresh_interval_sec)
            log_debug_message("Returning JWKS from fetch")
            fresh_match = find_matching_keys(get_cached_keys(), kid)
            if fresh_match is None:
                raise Exception("No matching JWKS found")
            return fresh_match

    raise last_error

Functions

def find_matching_keys(keys: Optional[List[jwt.api_jwk.PyJWK]], kid: Optional[str]) ‑> Optional[List[jwt.api_jwk.PyJWK]]
def get_cached_keys() ‑> Optional[List[jwt.api_jwk.PyJWK]]
def get_latest_keys(config: SessionConfig, kid: Optional[str] = None) ‑> List[jwt.api_jwk.PyJWK]
def reset_jwks_cache()

Classes

class CachedKeys (keys: List[jwt.api_jwk.PyJWK], refresh_interval_sec: int)
Expand source code
class CachedKeys:
    """A fetched list of JWKs together with the time it was stored.

    The entry counts as fresh while fewer than ``refresh_interval_sec``
    seconds (compared in milliseconds) have passed since construction.
    """

    def __init__(self, keys: List[PyJWK], refresh_interval_sec: int):
        self.refresh_interval_sec = refresh_interval_sec
        self.keys = keys
        # Stamp the entry at construction time; is_fresh() measures age from here.
        self.last_refresh_time = get_timestamp_ms()

    def is_fresh(self):
        """Return True while the entry is younger than the refresh interval."""
        age_ms = get_timestamp_ms() - self.last_refresh_time
        max_age_ms = self.refresh_interval_sec * 1000
        return age_ms < max_age_ms

Methods

def is_fresh(self)
class JWKSConfigType (*args, **kwargs)

dict() -> new empty dictionary
dict(mapping) -> new dictionary initialized from a mapping object's (key, value) pairs
dict(iterable) -> new dictionary initialized as if via:
    d = {}
    for k, v in iterable:
        d[k] = v
dict(**kwargs) -> new dictionary initialized with the name=value pairs in the keyword argument list. For example: dict(one=1, two=2)

Expand source code
class JWKSConfigType(TypedDict):
    """Shape of the module-level JWKS settings dictionary."""

    # Timeout for the JWKS HTTP request, in milliseconds: get_latest_keys
    # divides this value by 1000 before passing it to requests.get.
    request_timeout: int

Ancestors

  • builtins.dict

Class variables

var request_timeout : int