Module supertokens_python.recipe.session.with_jwt.recipe_implementation

Expand source code
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
#
# This software is licensed under the Apache License, Version 2.0 (the
# "License") as published by the Apache Software Foundation.
#
# You may not use this file except in compliance with the License. You may
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, Union, Optional

from jwt import decode

from supertokens_python.utils import get_timestamp_ms

from .constants import ACCESS_TOKEN_PAYLOAD_JWT_PROPERTY_NAME_KEY
from .session_class import get_session_with_jwt
from .utills import add_jwt_to_access_token_payload

if TYPE_CHECKING:
    from supertokens_python.recipe.session.utils import SessionConfig
    from supertokens_python.recipe.session.interfaces import (
        RecipeInterface,
        SessionContainer,
        SessionInformationResult,
    )
    from supertokens_python.framework.types import BaseRequest

from math import ceil

from supertokens_python.recipe.openid.interfaces import (
    RecipeInterface as OpenIdRecipeInterface,
)

EXPIRY_OFFSET_SECONDS = 30


def get_jwt_expiry(access_token_expiry: int):
    return access_token_expiry + EXPIRY_OFFSET_SECONDS


def get_recipe_implementation_with_jwt(
    original_implementation: RecipeInterface,
    config: SessionConfig,
    openid_recipe_implementation: OpenIdRecipeInterface,
) -> RecipeInterface:

    og_create_new_session = original_implementation.create_new_session

    async def create_new_session(
        request: BaseRequest,
        user_id: str,
        access_token_payload: Union[None, Dict[str, Any]],
        session_data: Union[None, Dict[str, Any]],
        user_context: Dict[str, Any],
    ) -> SessionContainer:
        if access_token_payload is None:
            access_token_payload = {}
        access_token_validity_in_seconds = ceil(
            await original_implementation.get_access_token_lifetime_ms(user_context)
            / 1000
        )
        access_token_payload = await add_jwt_to_access_token_payload(
            access_token_payload=access_token_payload,
            jwt_expiry=get_jwt_expiry(access_token_validity_in_seconds),
            user_id=user_id,
            jwt_property_name=config.jwt.property_name_in_access_token_payload,
            openid_recipe_implementation=openid_recipe_implementation,
            user_context=user_context,
        )
        session = await og_create_new_session(
            request,
            user_id,
            access_token_payload,
            session_data,
            user_context=user_context,
        )
        return get_session_with_jwt(session, openid_recipe_implementation)

    og_get_session = original_implementation.get_session

    async def get_session(
        request: BaseRequest,
        anti_csrf_check: Union[bool, None],
        session_required: bool,
        user_context: Dict[str, Any],
    ) -> Union[SessionContainer, None]:
        session_container = await og_get_session(
            request,
            anti_csrf_check,
            session_required,
            user_context,
        )
        if session_container is None:
            return None
        return get_session_with_jwt(session_container, openid_recipe_implementation)

    og_refresh_session = original_implementation.refresh_session

    async def refresh_session(
        request: BaseRequest, user_context: Dict[str, Any]
    ) -> SessionContainer:
        access_token_validity_in_seconds = ceil(
            await original_implementation.get_access_token_lifetime_ms(user_context)
            / 1000
        )

        # Refresh session first because this will create a new access token
        new_session = await og_refresh_session(request, user_context)
        access_token_payload = new_session.get_access_token_payload()
        access_token_payload = await add_jwt_to_access_token_payload(
            access_token_payload=access_token_payload,
            jwt_expiry=get_jwt_expiry(access_token_validity_in_seconds),
            user_id=new_session.get_user_id(),
            jwt_property_name=config.jwt.property_name_in_access_token_payload,
            openid_recipe_implementation=openid_recipe_implementation,
            user_context=user_context,
        )

        await new_session.update_access_token_payload(
            access_token_payload, user_context
        )
        return get_session_with_jwt(new_session, openid_recipe_implementation)

    og_update_access_token_payload = original_implementation.update_access_token_payload

    async def jwt_aware_update_access_token_payload(
        session_information: SessionInformationResult,
        new_access_token_payload: Dict[str, Any],
        user_context: Dict[str, Any],
    ) -> bool:
        access_token_payload = session_information.access_token_payload

        if ACCESS_TOKEN_PAYLOAD_JWT_PROPERTY_NAME_KEY not in access_token_payload:
            return await og_update_access_token_payload(
                session_information.session_handle,
                new_access_token_payload,
                user_context,
            )

        existing_jwt_property_name = access_token_payload[
            ACCESS_TOKEN_PAYLOAD_JWT_PROPERTY_NAME_KEY
        ]

        assert existing_jwt_property_name in access_token_payload
        existing_jwt = access_token_payload[existing_jwt_property_name]

        current_time_in_seconds = ceil(get_timestamp_ms() / 1000)
        decoded_payload: Dict[str, Any] = decode(
            jwt=existing_jwt, options={"verify_signature": False, "verify_exp": False}
        )

        if decoded_payload is None or decoded_payload.get("exp") is None:
            raise Exception("Error reading JWT from session")

        jwt_expiry = 1
        if "exp" in decoded_payload:
            exp = decoded_payload["exp"]
            if exp > current_time_in_seconds:
                # it can come here if someone calls this function well after
                # the access token and the jwt payload have expired. In this case,
                # we still want the jwt payload to update, but the resulting JWT should
                # not be alive for too long (since it's expired already). So we set it to
                # 1 second lifetime.
                jwt_expiry = exp - current_time_in_seconds

        new_access_token_payload = await add_jwt_to_access_token_payload(
            access_token_payload=new_access_token_payload,
            jwt_expiry=jwt_expiry,
            user_id=session_information.user_id,
            jwt_property_name=existing_jwt_property_name,
            openid_recipe_implementation=openid_recipe_implementation,
            user_context=user_context,
        )

        return await og_update_access_token_payload(
            session_information.session_handle, new_access_token_payload, user_context
        )

    async def update_access_token_payload(
        session_handle: str,
        new_access_token_payload: Optional[Dict[str, Any]],
        user_context: Dict[str, Any],
    ) -> bool:
        if new_access_token_payload is None:
            new_access_token_payload = {}

        session_information = await original_implementation.get_session_information(
            session_handle, user_context
        )
        if session_information is None:
            return False

        return await jwt_aware_update_access_token_payload(
            session_information, new_access_token_payload, user_context
        )

    async def merge_into_access_token_payload(
        session_handle: str,
        access_token_payload_update: Dict[str, Any],
        user_context: Dict[str, Any],
    ) -> bool:
        session_information = await original_implementation.get_session_information(
            session_handle, user_context
        )
        if session_information is None:
            return False

        new_access_token_payload = {
            **session_information.access_token_payload,
            **access_token_payload_update,
        }
        for k in access_token_payload_update.keys():
            if new_access_token_payload[k] is None:
                del new_access_token_payload[k]

        return await jwt_aware_update_access_token_payload(
            session_information, new_access_token_payload, user_context
        )

    original_implementation.create_new_session = create_new_session
    original_implementation.get_session = get_session
    original_implementation.refresh_session = refresh_session
    original_implementation.update_access_token_payload = update_access_token_payload
    original_implementation.merge_into_access_token_payload = (
        merge_into_access_token_payload
    )
    return original_implementation

Functions

def get_jwt_expiry(access_token_expiry: int)
Expand source code
def get_jwt_expiry(access_token_expiry: int):
    return access_token_expiry + EXPIRY_OFFSET_SECONDS
def get_recipe_implementation_with_jwt(original_implementation: RecipeInterface, config: SessionConfig, openid_recipe_implementation: OpenIdRecipeInterface) ‑> RecipeInterface
Expand source code
def get_recipe_implementation_with_jwt(
    original_implementation: RecipeInterface,
    config: SessionConfig,
    openid_recipe_implementation: OpenIdRecipeInterface,
) -> RecipeInterface:

    og_create_new_session = original_implementation.create_new_session

    async def create_new_session(
        request: BaseRequest,
        user_id: str,
        access_token_payload: Union[None, Dict[str, Any]],
        session_data: Union[None, Dict[str, Any]],
        user_context: Dict[str, Any],
    ) -> SessionContainer:
        if access_token_payload is None:
            access_token_payload = {}
        access_token_validity_in_seconds = ceil(
            await original_implementation.get_access_token_lifetime_ms(user_context)
            / 1000
        )
        access_token_payload = await add_jwt_to_access_token_payload(
            access_token_payload=access_token_payload,
            jwt_expiry=get_jwt_expiry(access_token_validity_in_seconds),
            user_id=user_id,
            jwt_property_name=config.jwt.property_name_in_access_token_payload,
            openid_recipe_implementation=openid_recipe_implementation,
            user_context=user_context,
        )
        session = await og_create_new_session(
            request,
            user_id,
            access_token_payload,
            session_data,
            user_context=user_context,
        )
        return get_session_with_jwt(session, openid_recipe_implementation)

    og_get_session = original_implementation.get_session

    async def get_session(
        request: BaseRequest,
        anti_csrf_check: Union[bool, None],
        session_required: bool,
        user_context: Dict[str, Any],
    ) -> Union[SessionContainer, None]:
        session_container = await og_get_session(
            request,
            anti_csrf_check,
            session_required,
            user_context,
        )
        if session_container is None:
            return None
        return get_session_with_jwt(session_container, openid_recipe_implementation)

    og_refresh_session = original_implementation.refresh_session

    async def refresh_session(
        request: BaseRequest, user_context: Dict[str, Any]
    ) -> SessionContainer:
        access_token_validity_in_seconds = ceil(
            await original_implementation.get_access_token_lifetime_ms(user_context)
            / 1000
        )

        # Refresh session first because this will create a new access token
        new_session = await og_refresh_session(request, user_context)
        access_token_payload = new_session.get_access_token_payload()
        access_token_payload = await add_jwt_to_access_token_payload(
            access_token_payload=access_token_payload,
            jwt_expiry=get_jwt_expiry(access_token_validity_in_seconds),
            user_id=new_session.get_user_id(),
            jwt_property_name=config.jwt.property_name_in_access_token_payload,
            openid_recipe_implementation=openid_recipe_implementation,
            user_context=user_context,
        )

        await new_session.update_access_token_payload(
            access_token_payload, user_context
        )
        return get_session_with_jwt(new_session, openid_recipe_implementation)

    og_update_access_token_payload = original_implementation.update_access_token_payload

    async def jwt_aware_update_access_token_payload(
        session_information: SessionInformationResult,
        new_access_token_payload: Dict[str, Any],
        user_context: Dict[str, Any],
    ) -> bool:
        access_token_payload = session_information.access_token_payload

        if ACCESS_TOKEN_PAYLOAD_JWT_PROPERTY_NAME_KEY not in access_token_payload:
            return await og_update_access_token_payload(
                session_information.session_handle,
                new_access_token_payload,
                user_context,
            )

        existing_jwt_property_name = access_token_payload[
            ACCESS_TOKEN_PAYLOAD_JWT_PROPERTY_NAME_KEY
        ]

        assert existing_jwt_property_name in access_token_payload
        existing_jwt = access_token_payload[existing_jwt_property_name]

        current_time_in_seconds = ceil(get_timestamp_ms() / 1000)
        decoded_payload: Dict[str, Any] = decode(
            jwt=existing_jwt, options={"verify_signature": False, "verify_exp": False}
        )

        if decoded_payload is None or decoded_payload.get("exp") is None:
            raise Exception("Error reading JWT from session")

        jwt_expiry = 1
        if "exp" in decoded_payload:
            exp = decoded_payload["exp"]
            if exp > current_time_in_seconds:
                # it can come here if someone calls this function well after
                # the access token and the jwt payload have expired. In this case,
                # we still want the jwt payload to update, but the resulting JWT should
                # not be alive for too long (since it's expired already). So we set it to
                # 1 second lifetime.
                jwt_expiry = exp - current_time_in_seconds

        new_access_token_payload = await add_jwt_to_access_token_payload(
            access_token_payload=new_access_token_payload,
            jwt_expiry=jwt_expiry,
            user_id=session_information.user_id,
            jwt_property_name=existing_jwt_property_name,
            openid_recipe_implementation=openid_recipe_implementation,
            user_context=user_context,
        )

        return await og_update_access_token_payload(
            session_information.session_handle, new_access_token_payload, user_context
        )

    async def update_access_token_payload(
        session_handle: str,
        new_access_token_payload: Optional[Dict[str, Any]],
        user_context: Dict[str, Any],
    ) -> bool:
        if new_access_token_payload is None:
            new_access_token_payload = {}

        session_information = await original_implementation.get_session_information(
            session_handle, user_context
        )
        if session_information is None:
            return False

        return await jwt_aware_update_access_token_payload(
            session_information, new_access_token_payload, user_context
        )

    async def merge_into_access_token_payload(
        session_handle: str,
        access_token_payload_update: Dict[str, Any],
        user_context: Dict[str, Any],
    ) -> bool:
        session_information = await original_implementation.get_session_information(
            session_handle, user_context
        )
        if session_information is None:
            return False

        new_access_token_payload = {
            **session_information.access_token_payload,
            **access_token_payload_update,
        }
        for k in access_token_payload_update.keys():
            if new_access_token_payload[k] is None:
                del new_access_token_payload[k]

        return await jwt_aware_update_access_token_payload(
            session_information, new_access_token_payload, user_context
        )

    original_implementation.create_new_session = create_new_session
    original_implementation.get_session = get_session
    original_implementation.refresh_session = refresh_session
    original_implementation.update_access_token_payload = update_access_token_payload
    original_implementation.merge_into_access_token_payload = (
        merge_into_access_token_payload
    )
    return original_implementation