Module supertokens_python.recipe.session.recipe_implementation

Expand source code
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
#
# This software is licensed under the Apache License, Version 2.0 (the
# "License") as published by the Apache Software Foundation.
#
# You may not use this file except in compliance with the License. You may
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any, Dict, Optional
from supertokens_python.framework.request import BaseRequest
from supertokens_python.logger import log_debug_message
from supertokens_python.normalised_url_path import NormalisedURLPath
from supertokens_python.process_state import AllowedProcessStates, ProcessState
from supertokens_python.utils import (
    execute_async,
    frontend_has_interceptor,
    get_timestamp_ms,
    normalise_http_method,
    resolve,
)
from . import session_functions
from .cookie_and_header import (
    get_access_token_from_cookie,
    get_anti_csrf_header,
    get_id_refresh_token_from_cookie,
    get_refresh_token_from_cookie,
    get_rid_header,
)
from .exceptions import raise_try_refresh_token_exception, raise_unauthorised_exception
from .interfaces import (
    AccessTokenObj,
    RecipeInterface,
    RegenerateAccessTokenOkResult,
    SessionClaim,
    SessionClaimValidator,
    SessionInformationResult,
    SessionObj,
    ClaimsValidationResult,
    SessionDoesNotExistError,
    JSONObject,
    GetClaimValueOkResult,
)
from .session_class import Session
from ...types import MaybeAwaitable
from .utils import (
    SessionConfig,
    validate_claims_in_payload,
)

if TYPE_CHECKING:
    from typing import List, Union

    from supertokens_python.querier import Querier


from .interfaces import SessionContainer


class HandshakeInfo:
    """Values taken from a handshake response, plus the JWT signing keys.

    The signing key list starts out empty and is filled in separately through
    ``set_jwt_signing_public_key_list``.
    """

    def __init__(self, info: Dict[str, Any]):
        """Copy the handshake fields out of *info*.

        Args:
            info: Mapping that must contain ``accessTokenBlacklistingEnabled``,
                ``antiCsrf``, ``accessTokenValidity`` and
                ``refreshTokenValidity``; a missing key raises ``KeyError``.
        """
        # Keys are not read from ``info``; callers install them afterwards.
        self.raw_jwt_signing_public_key_list: List[Dict[str, Any]] = []
        self.access_token_blacklisting_enabled = info["accessTokenBlacklistingEnabled"]
        self.anti_csrf = info["antiCsrf"]
        self.access_token_validity = info["accessTokenValidity"]
        self.refresh_token_validity = info["refreshTokenValidity"]

    def set_jwt_signing_public_key_list(self, updated_list: List[Dict[str, Any]]):
        """Replace the stored key list with *updated_list* (kept by reference)."""
        self.raw_jwt_signing_public_key_list = updated_list

    def get_jwt_signing_public_key_list(self) -> List[Dict[str, Any]]:
        """Return the stored keys whose ``expiryTime`` is later than now.

        "Now" is whatever ``get_timestamp_ms()`` returns at call time; a key
        whose ``expiryTime`` equals that value is treated as expired.
        """
        now = get_timestamp_ms()
        unexpired: List[Dict[str, Any]] = []
        for entry in self.raw_jwt_signing_public_key_list:
            if entry["expiryTime"] > now:
                unexpired.append(entry)
        return unexpired


class RecipeImplementation(RecipeInterface):  # pylint: disable=too-many-public-methods
    def __init__(self, querier: Querier, config: SessionConfig):
        """Store the collaborators and try to pre-load the handshake info.

        The pre-load is best effort: a failure is logged at debug level and
        otherwise ignored, so constructing the recipe never fails because of
        it. ``get_handshake_info`` performs the fetch again on a later call
        when nothing usable has been cached.

        Args:
            querier: Object used to send requests (see ``get_handshake_info``).
            config: Session configuration; ``config.mode`` is handed to
                ``execute_async`` and ``config.anti_csrf`` is read when the
                handshake info is built.
        """
        super().__init__()
        self.querier = querier
        self.config = config
        self.handshake_info: Union[HandshakeInfo, None] = None

        async def call_get_handshake_info():
            try:
                await self.get_handshake_info()
            except Exception as err:
                # Deliberately not re-raised: the warm-up must not break startup.
                log_debug_message(
                    "RecipeImplementation: initial handshake fetch failed: %s",
                    str(err),
                )

        try:
            execute_async(config.mode, call_get_handshake_info)
        except Exception as err:
            # Scheduling the warm-up itself failed; keep going without a cache.
            log_debug_message(
                "RecipeImplementation: could not schedule initial handshake fetch: %s",
                str(err),
            )

    async def get_handshake_info(self, force_refetch: bool = False) -> HandshakeInfo:
        """Return the cached handshake info, fetching it first when needed.

        A fetch is made when nothing is cached, when the cached info has no
        unexpired JWT signing key, or when *force_refetch* is true.

        Args:
            force_refetch: Fetch again even if usable info is cached.

        Returns:
            The ``HandshakeInfo`` stored on ``self.handshake_info``.

        Raises:
            KeyError: If the response lacks a field read here or by
                ``HandshakeInfo``. Errors from the querier propagate unchanged.
        """
        if (
            self.handshake_info is None
            or not self.handshake_info.get_jwt_signing_public_key_list()
            or force_refetch
        ):
            ProcessState.get_instance().add_state(
                AllowedProcessStates.CALLING_SERVICE_IN_GET_HANDSHAKE_INFO
            )
            response = await self.querier.send_post_request(
                NormalisedURLPath("/recipe/handshake"), {}
            )
            # The anti-CSRF setting comes from local config, not the response.
            self.handshake_info = HandshakeInfo(
                {**response, "antiCsrf": self.config.anti_csrf}
            )

            # The key list may be missing from the response (presumably older
            # cores - TODO confirm). Passing None lets
            # update_jwt_signing_public_key_info build a one-element list from
            # the single key instead of failing here with KeyError.
            self.update_jwt_signing_public_key_info(
                response.get("jwtSigningPublicKeyList"),
                response["jwtSigningPublicKey"],
                response["jwtSigningPublicKeyExpiryTime"],
            )

        return self.handshake_info

    def update_jwt_signing_public_key_info(
        self,
        key_list: Union[List[Dict[str, Any]], None],
        public_key: str,
        expiry_time: int,
    ):
        """Install a JWT signing key list on the cached handshake info.

        Args:
            key_list: Keys to store. When ``None``, a one-element list is built
                from *public_key* and *expiry_time*, with ``createdAt`` set to
                the current ``get_timestamp_ms()`` value.
            public_key: Key used only for the ``None`` fallback.
            expiry_time: Expiry used only for the ``None`` fallback.

        Does nothing further when ``self.handshake_info`` is ``None``.
        """
        effective_keys = key_list
        if effective_keys is None:
            fallback_key = {
                "publicKey": public_key,
                "expiryTime": expiry_time,
                "createdAt": get_timestamp_ms(),
            }
            effective_keys = [fallback_key]

        handshake = self.handshake_info
        if handshake is None:
            return
        handshake.set_jwt_signing_public_key_list(effective_keys)

    async def create_new_session(
        self,
        request: BaseRequest,
        user_id: str,
        access_token_payload: Union[None, Dict[str, Any]],
        session_data: Union[None, Dict[str, Any]],
        user_context: Dict[str, Any],
    ) -> SessionContainer:
        """Create a session and attach it to *request*.

        Delegates creation to ``session_functions.create_new_session``, wraps
        the result in a ``Session``, records the new access, refresh and
        id-refresh token info on it (plus the anti-CSRF token when one is
        returned), and stores it on the request via ``request.set_session``.

        ``user_context`` is not used by this implementation.

        Returns:
            The newly built ``Session``.
        """
        created = await session_functions.create_new_session(
            self, user_id, access_token_payload, session_data
        )
        access_token_info = created["accessToken"]
        refresh_token_info = created["refreshToken"]
        id_refresh_token_info = created["idRefreshToken"]

        token = access_token_info["token"]
        details = created["session"]
        container = Session(
            self,
            token,
            details["handle"],
            details["userId"],
            details["userDataInJWT"],
        )
        container.new_access_token_info = access_token_info
        container.new_refresh_token_info = refresh_token_info
        container.new_id_refresh_token_info = id_refresh_token_info

        # Only set when the key is present and its value is not None.
        anti_csrf_token = created.get("antiCsrfToken")
        if anti_csrf_token is not None:
            container.new_anti_csrf_token = anti_csrf_token

        request.set_session(container)
        return container

    async def validate_claims(
        self,
        user_id: str,
        access_token_payload: Dict[str, Any],
        claim_validators: List[SessionClaimValidator],
        user_context: Dict[str, Any],
    ) -> ClaimsValidationResult:
        """Refetch claims where validators ask for it, then validate the payload.

        For each validator that has a claim and whose ``should_refetch`` is
        truthy, the claim value is fetched; a non-``None`` value is written
        into the payload through ``claim.add_to_payload_``. The resulting
        payload is then checked with ``validate_claims_in_payload``.

        Returns:
            A ``ClaimsValidationResult`` holding the invalid claims and, as the
            second argument, the resulting payload if its JSON serialisation
            differs from the original one, otherwise ``None``.
        """
        serialised_before = json.dumps(access_token_payload)
        current_payload = access_token_payload

        for validator in claim_validators:
            log_debug_message(
                "update_claims_in_payload_if_needed checking should_refetch for %s",
                validator.id,
            )
            claim = validator.claim
            if claim is None or not validator.should_refetch(
                current_payload, user_context
            ):
                continue

            log_debug_message(
                "update_claims_in_payload_if_needed refetching for %s", validator.id
            )
            fetched = await resolve(claim.fetch_value(user_id, user_context))
            log_debug_message(
                "update_claims_in_payload_if_needed %s refetch result %s",
                validator.id,
                json.dumps(fetched),
            )
            if fetched is not None:
                current_payload = claim.add_to_payload_(
                    current_payload, fetched, user_context
                )

        # Report an update only if refetching actually changed the payload.
        payload_changed = json.dumps(current_payload) != serialised_before
        access_token_payload_update = current_payload if payload_changed else None

        invalid_claims = await validate_claims_in_payload(
            claim_validators, current_payload, user_context
        )
        return ClaimsValidationResult(invalid_claims, access_token_payload_update)

    async def validate_claims_in_jwt_payload(
        self,
        user_id: str,
        jwt_payload: JSONObject,
        claim_validators: List[SessionClaimValidator],
        user_context: Dict[str, Any],
    ) -> ClaimsValidationResult:
        """Validate *jwt_payload* against *claim_validators* without refetching.

        ``user_id`` is not used by this implementation.

        Returns:
            A ``ClaimsValidationResult`` built from the invalid claims only.
        """
        failures = await validate_claims_in_payload(
            claim_validators, jwt_payload, user_context
        )
        return ClaimsValidationResult(failures)

    async def get_session(
        self,
        request: BaseRequest,
        anti_csrf_check: Union[bool, None],
        session_required: bool,
        user_context: Dict[str, Any],
    ) -> Optional[SessionContainer]:
        """Build a session from the request's cookies and attach it to *request*.

        Args:
            request: Incoming request; tokens and headers are read from it.
            anti_csrf_check: Whether to ask for the anti-CSRF check. When
                ``None`` it defaults to true for every method except GET.
            session_required: Whether a missing session is an error rather
                than a ``None`` result.
            user_context: Not used by this implementation.

        Returns:
            ``request.get_session()`` after the new ``Session`` has been stored
            on the request. ``None`` when the id-refresh-token cookie is
            missing and *session_required* is false, or when the access-token
            cookie is missing and none of these hold: *session_required* is
            true, the frontend has an interceptor, the method is GET.

        Raises:
            Whatever ``raise_unauthorised_exception`` and
            ``raise_try_refresh_token_exception`` raise (both are expected to
            raise), plus anything raised by ``session_functions.get_session``.
        """
        log_debug_message("getSession: Started")

        log_debug_message(
            "getSession: rid in header: %s", str(frontend_has_interceptor(request))
        )
        log_debug_message("getSession: request method: %s", request.method())

        # Step 1: the id refresh token cookie decides whether a session exists.
        id_refresh_token = get_id_refresh_token_from_cookie(request)
        if id_refresh_token is None:
            if not session_required:
                log_debug_message(
                    "getSession: returning None because idRefreshToken is undefined and session_required is false"
                )
                return None
            log_debug_message(
                "getSession: UNAUTHORISED because idRefreshToken from cookies is undefined"
            )
            raise_unauthorised_exception(
                "Session does not exist. Are you sending the session tokens in the request as cookies?",
                False,
            )
        # Step 2: a session exists but the access token cookie may be missing.
        access_token: Union[str, None] = get_access_token_from_cookie(request)
        if access_token is None:
            # Ask the client to refresh when the session is mandatory, when the
            # frontend has an interceptor, or for GET requests; otherwise
            # treat the request as having no session.
            if (
                session_required is True
                or frontend_has_interceptor(request)
                or normalise_http_method(request.method()) == "get"
            ):
                log_debug_message(
                    "getSession: Returning try refresh token because access token from cookies is undefined"
                )
                raise_try_refresh_token_exception(
                    "Access token has expired. Please call the refresh API"
                )
            return None
        anti_csrf_token = get_anti_csrf_header(request)
        # Default: request the anti-CSRF check for every method except GET.
        if anti_csrf_check is None:
            anti_csrf_check = normalise_http_method(request.method()) != "get"

        log_debug_message(
            "getSession: Value of doAntiCsrfCheck is: %s", str(anti_csrf_check)
        )
        # Step 3: verification; the last argument says whether an rid header is present.
        new_session = await session_functions.get_session(
            self,
            access_token,
            anti_csrf_token,
            anti_csrf_check,
            get_rid_header(request) is not None,
        )
        # A new access token in the result replaces the one read from the cookie.
        if "accessToken" in new_session:
            access_token = new_session["accessToken"]["token"]

        if access_token is None:
            raise Exception("Should never come here")
        session = Session(
            self,
            access_token,
            new_session["session"]["handle"],
            new_session["session"]["userId"],
            new_session["session"]["userDataInJWT"],
        )

        # Record the new token info on the session object when one was issued.
        if "accessToken" in new_session:
            session.new_access_token_info = new_session["accessToken"]

        log_debug_message("getSession: Success!")
        request.set_session(session)
        return request.get_session()

    async def refresh_session(
        self, request: BaseRequest, user_context: Dict[str, Any]
    ) -> SessionContainer:
        """Refresh the session described by the request's cookies.

        Both the id-refresh-token and the refresh-token cookies must be
        present; if either is missing, ``raise_unauthorised_exception`` is
        called. The refresh itself is delegated to
        ``session_functions.refresh_session``. The resulting ``Session``
        carries the new token info and is stored on the request.

        ``user_context`` is not used by this implementation.

        Returns:
            The newly built ``Session``.
        """
        log_debug_message("refreshSession: Started")

        if get_id_refresh_token_from_cookie(request) is None:
            log_debug_message(
                "refreshSession: UNAUTHORISED because idRefreshToken from cookies is undefined"
            )
            raise_unauthorised_exception(
                "Session does not exist. Are you sending the session tokens in the request "
                "as cookies?",
                False,
            )

        incoming_refresh_token = get_refresh_token_from_cookie(request)
        if incoming_refresh_token is None:
            log_debug_message(
                "refreshSession: UNAUTHORISED because refresh token from cookies is undefined"
            )
            raise_unauthorised_exception(
                "Refresh token not found. Are you sending the refresh token in the "
                "request as a cookie?"
            )

        anti_csrf_token = get_anti_csrf_header(request)
        has_rid_header = get_rid_header(request) is not None
        refreshed = await session_functions.refresh_session(
            self, incoming_refresh_token, anti_csrf_token, has_rid_header
        )

        access_token_info = refreshed["accessToken"]
        refresh_token_info = refreshed["refreshToken"]
        id_refresh_token_info = refreshed["idRefreshToken"]

        token = access_token_info["token"]
        details = refreshed["session"]
        session = Session(
            self,
            token,
            details["handle"],
            details["userId"],
            details["userDataInJWT"],
        )
        session.new_access_token_info = access_token_info
        session.new_refresh_token_info = refresh_token_info
        session.new_id_refresh_token_info = id_refresh_token_info

        # Only set when the key is present and its value is not None.
        new_anti_csrf_token = refreshed.get("antiCsrfToken")
        if new_anti_csrf_token is not None:
            session.new_anti_csrf_token = new_anti_csrf_token

        log_debug_message("refreshSession: Success!")
        request.set_session(session)
        return session

    async def revoke_session(
        self, session_handle: str, user_context: Dict[str, Any]
    ) -> bool:
        """Delegate to ``session_functions.revoke_session`` and return its result.

        ``user_context`` is not used by this implementation.
        """
        outcome = await session_functions.revoke_session(self, session_handle)
        return outcome

    async def revoke_all_sessions_for_user(
        self, user_id: str, user_context: Dict[str, Any]
    ) -> List[str]:
        """Delegate to ``session_functions.revoke_all_sessions_for_user``.

        ``user_context`` is not used by this implementation.
        """
        outcome = await session_functions.revoke_all_sessions_for_user(self, user_id)
        return outcome

    async def get_all_session_handles_for_user(
        self, user_id: str, user_context: Dict[str, Any]
    ) -> List[str]:
        """Delegate to ``session_functions.get_all_session_handles_for_user``.

        ``user_context`` is not used by this implementation.
        """
        handles = await session_functions.get_all_session_handles_for_user(
            self, user_id
        )
        return handles

    async def revoke_multiple_sessions(
        self, session_handles: List[str], user_context: Dict[str, Any]
    ) -> List[str]:
        """Delegate to ``session_functions.revoke_multiple_sessions``.

        ``user_context`` is not used by this implementation.
        """
        outcome = await session_functions.revoke_multiple_sessions(
            self, session_handles
        )
        return outcome

    async def get_session_information(
        self, session_handle: str, user_context: Dict[str, Any]
    ) -> Union[SessionInformationResult, None]:
        """Delegate to ``session_functions.get_session_information``.

        Other methods of this class treat a ``None`` result as "no such
        session". ``user_context`` is not used by this implementation.
        """
        info = await session_functions.get_session_information(self, session_handle)
        return info

    async def update_session_data(
        self,
        session_handle: str,
        new_session_data: Dict[str, Any],
        user_context: Dict[str, Any],
    ) -> bool:
        """Delegate to ``session_functions.update_session_data``.

        ``user_context`` is not used by this implementation.
        """
        outcome = await session_functions.update_session_data(
            self, session_handle, new_session_data
        )
        return outcome

    async def update_access_token_payload(
        self,
        session_handle: str,
        new_access_token_payload: Dict[str, Any],
        user_context: Dict[str, Any],
    ) -> bool:
        """Delegate to ``session_functions.update_access_token_payload``.

        ``user_context`` is not used by this implementation.
        """
        outcome = await session_functions.update_access_token_payload(
            self, session_handle, new_access_token_payload
        )
        return outcome

    async def get_access_token_lifetime_ms(self, user_context: Dict[str, Any]) -> int:
        """Return ``access_token_validity`` from the handshake info.

        ``user_context`` is not used by this implementation.
        """
        handshake = await self.get_handshake_info()
        return handshake.access_token_validity

    async def get_refresh_token_lifetime_ms(self, user_context: Dict[str, Any]) -> int:
        """Return ``refresh_token_validity`` from the handshake info.

        ``user_context`` is not used by this implementation.
        """
        handshake = await self.get_handshake_info()
        return handshake.refresh_token_validity

    async def merge_into_access_token_payload(
        self,
        session_handle: str,
        access_token_payload_update: Dict[str, Any],
        user_context: Dict[str, Any],
    ) -> bool:
        """Merge *access_token_payload_update* into the session's stored payload.

        Keys in the update overwrite stored keys. A key whose update value is
        ``None`` is removed from the merged payload. ``None`` values that are
        already stored, and not named in the update, are kept.

        Returns:
            ``False`` when ``get_session_information`` returns ``None``;
            otherwise the result of ``update_access_token_payload``.
        """
        session_info = await self.get_session_information(session_handle, user_context)
        if session_info is None:
            return False

        keys_to_drop = {
            key for key, value in access_token_payload_update.items() if value is None
        }
        combined = {
            **session_info.access_token_payload,
            **access_token_payload_update,
        }
        new_access_token_payload = {
            key: value for key, value in combined.items() if key not in keys_to_drop
        }

        return await self.update_access_token_payload(
            session_handle, new_access_token_payload, user_context
        )

    async def fetch_and_set_claim(
        self,
        session_handle: str,
        claim: SessionClaim[Any],
        user_context: Dict[str, Any],
    ) -> bool:
        """Build *claim* for the session's user and merge it into the payload.

        Returns:
            ``False`` when ``get_session_information`` returns ``None``;
            otherwise the result of ``merge_into_access_token_payload``.
        """
        info = await self.get_session_information(session_handle, user_context)
        if info is None:
            return False

        claim_payload = await claim.build(info.user_id, user_context)
        merged = await self.merge_into_access_token_payload(
            session_handle, claim_payload, user_context
        )
        return merged

    async def set_claim_value(
        self,
        session_handle: str,
        claim: SessionClaim[Any],
        value: Any,
        user_context: Dict[str, Any],
    ) -> bool:
        """Write *value* for *claim* into the session's access token payload.

        The update is built by applying ``claim.add_to_payload_`` to an empty
        dict, then merged through ``merge_into_access_token_payload``.

        Returns:
            The result of ``merge_into_access_token_payload``. Per that
            method's annotation this is a ``bool``, which is why the return
            type is now declared here, matching ``remove_claim`` and
            ``fetch_and_set_claim``.
        """
        access_token_payload_update = claim.add_to_payload_({}, value, user_context)
        return await self.merge_into_access_token_payload(
            session_handle, access_token_payload_update, user_context
        )

    async def get_claim_value(
        self,
        session_handle: str,
        claim: SessionClaim[Any],
        user_context: Dict[str, Any],
    ) -> Union[SessionDoesNotExistError, GetClaimValueOkResult[Any]]:
        """Read the value of *claim* from the session's access token payload.

        Returns:
            ``SessionDoesNotExistError()`` when ``get_session_information``
            returns ``None``; otherwise a ``GetClaimValueOkResult`` wrapping
            what ``claim.get_value_from_payload`` extracts.
        """
        info = await self.get_session_information(session_handle, user_context)
        if info is None:
            return SessionDoesNotExistError()

        claim_value = claim.get_value_from_payload(
            info.access_token_payload, user_context
        )
        return GetClaimValueOkResult(value=claim_value)

    def get_global_claim_validators(
        self,
        user_id: str,
        claim_validators_added_by_other_recipes: List[SessionClaimValidator],
        user_context: Dict[str, Any],
    ) -> MaybeAwaitable[List[SessionClaimValidator]]:
        """Return the validators contributed by other recipes, unchanged.

        The same list object is returned, not a copy. ``user_id`` and
        ``user_context`` are not used by this implementation.
        """
        # This implementation adds no validators of its own.
        global_validators = claim_validators_added_by_other_recipes
        return global_validators

    async def remove_claim(
        self,
        session_handle: str,
        claim: SessionClaim[Any],
        user_context: Dict[str, Any],
    ) -> bool:
        """Remove *claim* from the session's access token payload.

        The update is built by applying ``claim.remove_from_payload_by_merge_``
        to an empty dict, then merged through
        ``merge_into_access_token_payload``.

        Returns:
            The result of ``merge_into_access_token_payload``.
        """
        removal_update = claim.remove_from_payload_by_merge_({}, user_context)
        merged = await self.merge_into_access_token_payload(
            session_handle, removal_update, user_context
        )
        return merged

    async def regenerate_access_token(
        self,
        access_token: str,
        new_access_token_payload: Union[Dict[str, Any], None],
        user_context: Dict[str, Any],
    ) -> Union[RegenerateAccessTokenOkResult, None]:
        """Request a regenerated access token carrying a new payload.

        Posts to ``/recipe/session/regenerate`` through the querier. A ``None``
        payload is sent as an empty dict. ``user_context`` is not used by this
        implementation.

        Returns:
            ``None`` when the response status is ``"UNAUTHORISED"``; otherwise
            a ``RegenerateAccessTokenOkResult`` with the session details and,
            if the response contains ``accessToken``, the new token info
            (``None`` in its place otherwise).
        """
        payload = {} if new_access_token_payload is None else new_access_token_payload
        request_body = {"accessToken": access_token, "userDataInJWT": payload}
        response: Dict[str, Any] = await self.querier.send_post_request(
            NormalisedURLPath("/recipe/session/regenerate"),
            request_body,
        )
        if response["status"] == "UNAUTHORISED":
            return None

        access_token_obj: Union[None, AccessTokenObj] = None
        if "accessToken" in response:
            token_info = response["accessToken"]
            access_token_obj = AccessTokenObj(
                token_info["token"],
                token_info["expiry"],
                token_info["createdTime"],
            )

        details = response["session"]
        session = SessionObj(
            details["handle"],
            details["userId"],
            details["userDataInJWT"],
        )
        return RegenerateAccessTokenOkResult(session, access_token_obj)

Classes

class HandshakeInfo (info: Dict[str, Any])
Expand source code
class HandshakeInfo:
    """Values taken from a handshake response, plus the JWT signing keys.

    The signing key list starts out empty and is filled in separately through
    ``set_jwt_signing_public_key_list``.
    """

    def __init__(self, info: Dict[str, Any]):
        """Copy the handshake fields out of *info*.

        Args:
            info: Mapping that must contain ``accessTokenBlacklistingEnabled``,
                ``antiCsrf``, ``accessTokenValidity`` and
                ``refreshTokenValidity``; a missing key raises ``KeyError``.
        """
        # Keys are not read from ``info``; callers install them afterwards.
        self.raw_jwt_signing_public_key_list: List[Dict[str, Any]] = []
        self.access_token_blacklisting_enabled = info["accessTokenBlacklistingEnabled"]
        self.anti_csrf = info["antiCsrf"]
        self.access_token_validity = info["accessTokenValidity"]
        self.refresh_token_validity = info["refreshTokenValidity"]

    def set_jwt_signing_public_key_list(self, updated_list: List[Dict[str, Any]]):
        """Replace the stored key list with *updated_list* (kept by reference)."""
        self.raw_jwt_signing_public_key_list = updated_list

    def get_jwt_signing_public_key_list(self) -> List[Dict[str, Any]]:
        """Return the stored keys whose ``expiryTime`` is later than now.

        "Now" is whatever ``get_timestamp_ms()`` returns at call time; a key
        whose ``expiryTime`` equals that value is treated as expired.
        """
        now = get_timestamp_ms()
        unexpired: List[Dict[str, Any]] = []
        for entry in self.raw_jwt_signing_public_key_list:
            if entry["expiryTime"] > now:
                unexpired.append(entry)
        return unexpired

Methods

def get_jwt_signing_public_key_list(self) ‑> List[Dict[str, Any]]
Expand source code
def get_jwt_signing_public_key_list(self) -> List[Dict[str, Any]]:
    """Return the stored keys whose ``expiryTime`` is later than now.

    "Now" is whatever ``get_timestamp_ms()`` returns at call time; a key
    whose ``expiryTime`` equals that value is treated as expired.
    """
    now = get_timestamp_ms()
    unexpired: List[Dict[str, Any]] = []
    for entry in self.raw_jwt_signing_public_key_list:
        if entry["expiryTime"] > now:
            unexpired.append(entry)
    return unexpired
def set_jwt_signing_public_key_list(self, updated_list: List[Dict[str, Any]])
Expand source code
def set_jwt_signing_public_key_list(self, updated_list: List[Dict[str, Any]]):
    """Replace the stored key list with *updated_list* (kept by reference)."""
    replacement = updated_list
    self.raw_jwt_signing_public_key_list = replacement
class RecipeImplementation (querier: Querier, config: SessionConfig)

Default implementation of the session recipe's RecipeInterface. It delegates session operations to session_functions and sends its own requests through the supplied Querier.

Expand source code
class RecipeImplementation(RecipeInterface):  # pylint: disable=too-many-public-methods
    def __init__(self, querier: Querier, config: SessionConfig):
        super().__init__()
        self.querier = querier
        self.config = config
        self.handshake_info: Union[HandshakeInfo, None] = None

        async def call_get_handshake_info():
            try:
                await self.get_handshake_info()
            except Exception:
                pass

        try:
            execute_async(config.mode, call_get_handshake_info)
        except Exception:
            pass

    async def get_handshake_info(self, force_refetch: bool = False) -> HandshakeInfo:
        if (
            self.handshake_info is None
            or len(self.handshake_info.get_jwt_signing_public_key_list()) == 0
            or force_refetch
        ):
            ProcessState.get_instance().add_state(
                AllowedProcessStates.CALLING_SERVICE_IN_GET_HANDSHAKE_INFO
            )
            response = await self.querier.send_post_request(
                NormalisedURLPath("/recipe/handshake"), {}
            )
            self.handshake_info = HandshakeInfo(
                {**response, "antiCsrf": self.config.anti_csrf}
            )

            self.update_jwt_signing_public_key_info(
                response["jwtSigningPublicKeyList"],
                response["jwtSigningPublicKey"],
                response["jwtSigningPublicKeyExpiryTime"],
            )

        return self.handshake_info

    def update_jwt_signing_public_key_info(
        self,
        key_list: Union[List[Dict[str, Any]], None],
        public_key: str,
        expiry_time: int,
    ):
        if key_list is None:
            key_list = [
                {
                    "publicKey": public_key,
                    "expiryTime": expiry_time,
                    "createdAt": get_timestamp_ms(),
                }
            ]

        if self.handshake_info is not None:
            self.handshake_info.set_jwt_signing_public_key_list(key_list)

    async def create_new_session(
        self,
        request: BaseRequest,
        user_id: str,
        access_token_payload: Union[None, Dict[str, Any]],
        session_data: Union[None, Dict[str, Any]],
        user_context: Dict[str, Any],
    ) -> SessionContainer:
        session = await session_functions.create_new_session(
            self, user_id, access_token_payload, session_data
        )
        access_token = session["accessToken"]
        refresh_token = session["refreshToken"]
        id_refresh_token = session["idRefreshToken"]
        new_session = Session(
            self,
            access_token["token"],
            session["session"]["handle"],
            session["session"]["userId"],
            session["session"]["userDataInJWT"],
        )
        new_session.new_access_token_info = access_token
        new_session.new_refresh_token_info = refresh_token
        new_session.new_id_refresh_token_info = id_refresh_token
        if "antiCsrfToken" in session and session["antiCsrfToken"] is not None:
            new_session.new_anti_csrf_token = session["antiCsrfToken"]
        request.set_session(new_session)
        return new_session

    async def validate_claims(
        self,
        user_id: str,
        access_token_payload: Dict[str, Any],
        claim_validators: List[SessionClaimValidator],
        user_context: Dict[str, Any],
    ) -> ClaimsValidationResult:
        access_token_payload_update = None
        original_access_token_payload = json.dumps(access_token_payload)

        for validator in claim_validators:
            log_debug_message(
                "update_claims_in_payload_if_needed checking should_refetch for %s",
                validator.id,
            )
            if validator.claim is not None and validator.should_refetch(
                access_token_payload, user_context
            ):
                log_debug_message(
                    "update_claims_in_payload_if_needed refetching for %s", validator.id
                )
                value = await resolve(
                    validator.claim.fetch_value(user_id, user_context)
                )
                log_debug_message(
                    "update_claims_in_payload_if_needed %s refetch result %s",
                    validator.id,
                    json.dumps(value),
                )
                if value is not None:
                    access_token_payload = validator.claim.add_to_payload_(
                        access_token_payload, value, user_context
                    )

        if json.dumps(access_token_payload) != original_access_token_payload:
            access_token_payload_update = access_token_payload

        invalid_claims = await validate_claims_in_payload(
            claim_validators, access_token_payload, user_context
        )

        return ClaimsValidationResult(invalid_claims, access_token_payload_update)

    async def validate_claims_in_jwt_payload(
        self,
        user_id: str,
        jwt_payload: JSONObject,
        claim_validators: List[SessionClaimValidator],
        user_context: Dict[str, Any],
    ) -> ClaimsValidationResult:
        invalid_claims = await validate_claims_in_payload(
            claim_validators,
            jwt_payload,
            user_context,
        )

        return ClaimsValidationResult(invalid_claims)

    async def get_session(
        self,
        request: BaseRequest,
        anti_csrf_check: Union[bool, None],
        session_required: bool,
        user_context: Dict[str, Any],
    ) -> Optional[SessionContainer]:
        log_debug_message("getSession: Started")

        log_debug_message(
            "getSession: rid in header: %s", str(frontend_has_interceptor(request))
        )
        log_debug_message("getSession: request method: %s", request.method())

        id_refresh_token = get_id_refresh_token_from_cookie(request)
        if id_refresh_token is None:
            if not session_required:
                log_debug_message(
                    "getSession: returning None because idRefreshToken is undefined and session_required is false"
                )
                return None
            log_debug_message(
                "getSession: UNAUTHORISED because idRefreshToken from cookies is undefined"
            )
            raise_unauthorised_exception(
                "Session does not exist. Are you sending the session tokens in the request as cookies?",
                False,
            )
        access_token: Union[str, None] = get_access_token_from_cookie(request)
        if access_token is None:
            if (
                session_required is True
                or frontend_has_interceptor(request)
                or normalise_http_method(request.method()) == "get"
            ):
                log_debug_message(
                    "getSession: Returning try refresh token because access token from cookies is undefined"
                )
                raise_try_refresh_token_exception(
                    "Access token has expired. Please call the refresh API"
                )
            return None
        anti_csrf_token = get_anti_csrf_header(request)
        if anti_csrf_check is None:
            anti_csrf_check = normalise_http_method(request.method()) != "get"

        log_debug_message(
            "getSession: Value of doAntiCsrfCheck is: %s", str(anti_csrf_check)
        )
        new_session = await session_functions.get_session(
            self,
            access_token,
            anti_csrf_token,
            anti_csrf_check,
            get_rid_header(request) is not None,
        )
        if "accessToken" in new_session:
            access_token = new_session["accessToken"]["token"]

        if access_token is None:
            raise Exception("Should never come here")
        session = Session(
            self,
            access_token,
            new_session["session"]["handle"],
            new_session["session"]["userId"],
            new_session["session"]["userDataInJWT"],
        )

        if "accessToken" in new_session:
            session.new_access_token_info = new_session["accessToken"]

        log_debug_message("getSession: Success!")
        request.set_session(session)
        return request.get_session()

    async def refresh_session(
        self, request: BaseRequest, user_context: Dict[str, Any]
    ) -> SessionContainer:
        log_debug_message("refreshSession: Started")

        id_refresh_token = get_id_refresh_token_from_cookie(request)
        if id_refresh_token is None:
            log_debug_message(
                "refreshSession: UNAUTHORISED because idRefreshToken from cookies is undefined"
            )
            raise_unauthorised_exception(
                "Session does not exist. Are you sending the session tokens in the request "
                "as cookies?",
                False,
            )
        refresh_token = get_refresh_token_from_cookie(request)
        if refresh_token is None:
            log_debug_message(
                "refreshSession: UNAUTHORISED because refresh token from cookies is undefined"
            )
            raise_unauthorised_exception(
                "Refresh token not found. Are you sending the refresh token in the "
                "request as a cookie?"
            )
        anti_csrf_token = get_anti_csrf_header(request)
        new_session = await session_functions.refresh_session(
            self, refresh_token, anti_csrf_token, get_rid_header(request) is not None
        )
        access_token = new_session["accessToken"]
        refresh_token = new_session["refreshToken"]
        id_refresh_token = new_session["idRefreshToken"]
        session = Session(
            self,
            access_token["token"],
            new_session["session"]["handle"],
            new_session["session"]["userId"],
            new_session["session"]["userDataInJWT"],
        )
        session.new_access_token_info = access_token
        session.new_refresh_token_info = refresh_token
        session.new_id_refresh_token_info = id_refresh_token
        if "antiCsrfToken" in new_session and new_session["antiCsrfToken"] is not None:
            session.new_anti_csrf_token = new_session["antiCsrfToken"]

        log_debug_message("refreshSession: Success!")
        request.set_session(session)
        return session

    async def revoke_session(
        self, session_handle: str, user_context: Dict[str, Any]
    ) -> bool:
        return await session_functions.revoke_session(self, session_handle)

    async def revoke_all_sessions_for_user(
        self, user_id: str, user_context: Dict[str, Any]
    ) -> List[str]:
        return await session_functions.revoke_all_sessions_for_user(self, user_id)

    async def get_all_session_handles_for_user(
        self, user_id: str, user_context: Dict[str, Any]
    ) -> List[str]:
        """Return the session handles of ``user_id``.

        Forwards to ``session_functions.get_all_session_handles_for_user``.
        ``user_context`` is not read by this method.
        """
        handles = await session_functions.get_all_session_handles_for_user(
            self, user_id
        )
        return handles

    async def revoke_multiple_sessions(
        self, session_handles: List[str], user_context: Dict[str, Any]
    ) -> List[str]:
        """Revoke each session listed in ``session_handles``.

        Forwards to ``session_functions.revoke_multiple_sessions`` and
        returns its result. ``user_context`` is not read by this method.
        """
        revoked_handles = await session_functions.revoke_multiple_sessions(
            self, session_handles
        )
        return revoked_handles

    async def get_session_information(
        self, session_handle: str, user_context: Dict[str, Any]
    ) -> Union[SessionInformationResult, None]:
        """Look up the stored information for ``session_handle``.

        Forwards to ``session_functions.get_session_information`` and
        returns its result, which callers in this class treat as ``None``
        when the session does not exist. ``user_context`` is not read here.
        """
        info = await session_functions.get_session_information(self, session_handle)
        return info

    async def update_session_data(
        self,
        session_handle: str,
        new_session_data: Dict[str, Any],
        user_context: Dict[str, Any],
    ) -> bool:
        """Store ``new_session_data`` for the session ``session_handle``.

        Forwards to ``session_functions.update_session_data`` and returns
        its result. ``user_context`` is not read by this method.
        """
        updated = await session_functions.update_session_data(
            self, session_handle, new_session_data
        )
        return updated

    async def update_access_token_payload(
        self,
        session_handle: str,
        new_access_token_payload: Dict[str, Any],
        user_context: Dict[str, Any],
    ) -> bool:
        """Set the access token payload of ``session_handle``.

        Forwards to ``session_functions.update_access_token_payload`` and
        returns its result. ``user_context`` is not read by this method.
        """
        updated = await session_functions.update_access_token_payload(
            self, session_handle, new_access_token_payload
        )
        return updated

    async def get_access_token_lifetime_ms(self, user_context: Dict[str, Any]) -> int:
        """Return ``access_token_validity`` from the handshake info.

        The handshake info is obtained through ``self.get_handshake_info()``.
        ``user_context`` is not read by this method.
        """
        handshake_info = await self.get_handshake_info()
        return handshake_info.access_token_validity

    async def get_refresh_token_lifetime_ms(self, user_context: Dict[str, Any]) -> int:
        """Return ``refresh_token_validity`` from the handshake info.

        The handshake info is obtained through ``self.get_handshake_info()``.
        ``user_context`` is not read by this method.
        """
        handshake_info = await self.get_handshake_info()
        return handshake_info.refresh_token_validity

    async def merge_into_access_token_payload(
        self,
        session_handle: str,
        access_token_payload_update: Dict[str, Any],
        user_context: Dict[str, Any],
    ) -> bool:
        """Merge ``access_token_payload_update`` into the stored payload.

        Keys in the update override the existing payload; keys whose update
        value is ``None`` are removed from the merged payload. Returns
        ``False`` when ``get_session_information`` yields ``None``, otherwise
        the result of ``update_access_token_payload``.
        """
        existing = await self.get_session_information(session_handle, user_context)
        if existing is None:
            return False

        merged = {**existing.access_token_payload}
        merged.update(access_token_payload_update)
        # A None value in the update means "drop this key".
        for key in access_token_payload_update.keys():
            if merged[key] is None:
                merged.pop(key)

        return await self.update_access_token_payload(
            session_handle, merged, user_context
        )

    async def fetch_and_set_claim(
        self,
        session_handle: str,
        claim: SessionClaim[Any],
        user_context: Dict[str, Any],
    ) -> bool:
        """Build ``claim`` for the session's user and merge it into the payload.

        Returns ``False`` when ``get_session_information`` yields ``None``;
        otherwise awaits ``claim.build(user_id, user_context)`` and returns
        the result of ``merge_into_access_token_payload``.
        """
        info = await self.get_session_information(session_handle, user_context)
        if info is not None:
            claim_update = await claim.build(info.user_id, user_context)
            return await self.merge_into_access_token_payload(
                session_handle, claim_update, user_context
            )
        return False

    async def set_claim_value(
        self,
        session_handle: str,
        claim: SessionClaim[Any],
        value: Any,
        user_context: Dict[str, Any],
    ) -> bool:
        """Write ``value`` for ``claim`` into the session's access token payload.

        The update is produced by ``claim.add_to_payload_`` applied to an
        empty dict and then handed to ``merge_into_access_token_payload``.

        Returns:
            The boolean returned by ``merge_into_access_token_payload``
            (``False`` when the session information lookup yields ``None``).
        """
        access_token_payload_update = claim.add_to_payload_({}, value, user_context)
        return await self.merge_into_access_token_payload(
            session_handle, access_token_payload_update, user_context
        )

    async def get_claim_value(
        self,
        session_handle: str,
        claim: SessionClaim[Any],
        user_context: Dict[str, Any],
    ) -> Union[SessionDoesNotExistError, GetClaimValueOkResult[Any]]:
        """Read the value of ``claim`` from the session's access token payload.

        Returns ``SessionDoesNotExistError()`` when ``get_session_information``
        yields ``None``; otherwise a ``GetClaimValueOkResult`` wrapping
        ``claim.get_value_from_payload(...)``.
        """
        info = await self.get_session_information(session_handle, user_context)
        if info is not None:
            claim_value = claim.get_value_from_payload(
                info.access_token_payload, user_context
            )
            return GetClaimValueOkResult(value=claim_value)
        return SessionDoesNotExistError()

    def get_global_claim_validators(
        self,
        user_id: str,
        claim_validators_added_by_other_recipes: List[SessionClaimValidator],
        user_context: Dict[str, Any],
    ) -> MaybeAwaitable[List[SessionClaimValidator]]:
        """Return the validators contributed by other recipes, unchanged.

        ``user_id`` and ``user_context`` are not read by this method; the
        same list object that was passed in is returned.
        """
        validators = claim_validators_added_by_other_recipes
        return validators

    async def remove_claim(
        self,
        session_handle: str,
        claim: SessionClaim[Any],
        user_context: Dict[str, Any],
    ) -> bool:
        """Remove ``claim`` from the session's access token payload.

        The update comes from ``claim.remove_from_payload_by_merge_`` applied
        to an empty dict and is handed to ``merge_into_access_token_payload``,
        whose result is returned.
        """
        removal_update = claim.remove_from_payload_by_merge_({}, user_context)
        merged_ok = await self.merge_into_access_token_payload(
            session_handle, removal_update, user_context
        )
        return merged_ok

    async def regenerate_access_token(
        self,
        access_token: str,
        new_access_token_payload: Union[Dict[str, Any], None],
        user_context: Dict[str, Any],
    ) -> Union[RegenerateAccessTokenOkResult, None]:
        """Request a regenerated access token carrying a new payload.

        POSTs to ``/recipe/session/regenerate`` through ``self.querier``. A
        ``None`` payload is sent as an empty dict. Returns ``None`` when the
        response status is ``"UNAUTHORISED"``; otherwise a
        ``RegenerateAccessTokenOkResult`` whose access token object is
        ``None`` if the response has no ``"accessToken"`` entry.
        """
        payload = {} if new_access_token_payload is None else new_access_token_payload
        endpoint = NormalisedURLPath("/recipe/session/regenerate")
        response: Dict[str, Any] = await self.querier.send_post_request(
            endpoint,
            {"accessToken": access_token, "userDataInJWT": payload},
        )
        if response["status"] == "UNAUTHORISED":
            return None

        access_token_obj: Union[None, AccessTokenObj] = None
        if "accessToken" in response:
            token_fields = response["accessToken"]
            access_token_obj = AccessTokenObj(
                token_fields["token"],
                token_fields["expiry"],
                token_fields["createdTime"],
            )

        session_fields = response["session"]
        session = SessionObj(
            session_fields["handle"],
            session_fields["userId"],
            session_fields["userDataInJWT"],
        )
        return RegenerateAccessTokenOkResult(session, access_token_obj)

Ancestors

Methods

async def create_new_session(self, request: BaseRequest, user_id: str, access_token_payload: Union[None, Dict[str, Any]], session_data: Union[None, Dict[str, Any]], user_context: Dict[str, Any]) ‑> SessionContainer
Expand source code
async def create_new_session(
    self,
    request: BaseRequest,
    user_id: str,
    access_token_payload: Union[None, Dict[str, Any]],
    session_data: Union[None, Dict[str, Any]],
    user_context: Dict[str, Any],
) -> SessionContainer:
    """Create a session for ``user_id`` and attach it to ``request``.

    Calls ``session_functions.create_new_session``, wraps the result in a
    ``Session`` carrying the new access, refresh and id-refresh token info
    (plus the anti-CSRF token when one is present and not ``None``), stores
    it via ``request.set_session`` and returns it. ``user_context`` is not
    read by this method.
    """
    created = await session_functions.create_new_session(
        self, user_id, access_token_payload, session_data
    )
    access_info = created["accessToken"]
    refresh_info = created["refreshToken"]
    id_refresh_info = created["idRefreshToken"]
    session_fields = created["session"]

    new_session = Session(
        self,
        access_info["token"],
        session_fields["handle"],
        session_fields["userId"],
        session_fields["userDataInJWT"],
    )
    new_session.new_access_token_info = access_info
    new_session.new_refresh_token_info = refresh_info
    new_session.new_id_refresh_token_info = id_refresh_info

    anti_csrf = created["antiCsrfToken"] if "antiCsrfToken" in created else None
    if anti_csrf is not None:
        new_session.new_anti_csrf_token = anti_csrf

    request.set_session(new_session)
    return new_session
async def fetch_and_set_claim(self, session_handle: str, claim: SessionClaim[Any], user_context: Dict[str, Any]) ‑> bool
Expand source code
async def fetch_and_set_claim(
    self,
    session_handle: str,
    claim: SessionClaim[Any],
    user_context: Dict[str, Any],
) -> bool:
    """Build ``claim`` for the session's user and merge it into the payload.

    Returns ``False`` when ``get_session_information`` yields ``None``;
    otherwise awaits ``claim.build(user_id, user_context)`` and returns the
    result of ``merge_into_access_token_payload``.
    """
    info = await self.get_session_information(session_handle, user_context)
    if info is not None:
        claim_update = await claim.build(info.user_id, user_context)
        return await self.merge_into_access_token_payload(
            session_handle, claim_update, user_context
        )
    return False
async def get_access_token_lifetime_ms(self, user_context: Dict[str, Any]) ‑> int
Expand source code
async def get_access_token_lifetime_ms(self, user_context: Dict[str, Any]) -> int:
    """Return ``access_token_validity`` from the handshake info.

    The handshake info is obtained through ``self.get_handshake_info()``.
    ``user_context`` is not read by this method.
    """
    handshake_info = await self.get_handshake_info()
    return handshake_info.access_token_validity
async def get_all_session_handles_for_user(self, user_id: str, user_context: Dict[str, Any]) ‑> List[str]
Expand source code
async def get_all_session_handles_for_user(
    self, user_id: str, user_context: Dict[str, Any]
) -> List[str]:
    """Return the session handles of ``user_id``.

    Forwards to ``session_functions.get_all_session_handles_for_user``.
    ``user_context`` is not read by this method.
    """
    handles = await session_functions.get_all_session_handles_for_user(
        self, user_id
    )
    return handles
async def get_claim_value(self, session_handle: str, claim: SessionClaim[Any], user_context: Dict[str, Any]) ‑> Union[SessionDoesNotExistError, GetClaimValueOkResult[Any]]
Expand source code
async def get_claim_value(
    self,
    session_handle: str,
    claim: SessionClaim[Any],
    user_context: Dict[str, Any],
) -> Union[SessionDoesNotExistError, GetClaimValueOkResult[Any]]:
    """Read the value of ``claim`` from the session's access token payload.

    Returns ``SessionDoesNotExistError()`` when ``get_session_information``
    yields ``None``; otherwise a ``GetClaimValueOkResult`` wrapping
    ``claim.get_value_from_payload(...)``.
    """
    info = await self.get_session_information(session_handle, user_context)
    if info is not None:
        claim_value = claim.get_value_from_payload(
            info.access_token_payload, user_context
        )
        return GetClaimValueOkResult(value=claim_value)
    return SessionDoesNotExistError()
def get_global_claim_validators(self, user_id: str, claim_validators_added_by_other_recipes: List[SessionClaimValidator], user_context: Dict[str, Any]) ‑> MaybeAwaitable[List[SessionClaimValidator]]
Expand source code
def get_global_claim_validators(
    self,
    user_id: str,
    claim_validators_added_by_other_recipes: List[SessionClaimValidator],
    user_context: Dict[str, Any],
) -> MaybeAwaitable[List[SessionClaimValidator]]:
    """Return the validators contributed by other recipes, unchanged.

    ``user_id`` and ``user_context`` are not read by this method; the same
    list object that was passed in is returned.
    """
    validators = claim_validators_added_by_other_recipes
    return validators
async def get_handshake_info(self, force_refetch: bool = False) ‑> HandshakeInfo
Expand source code
async def get_handshake_info(self, force_refetch: bool = False) -> HandshakeInfo:
    """Return the handshake info, fetching it from the core when needed.

    The value stored on ``self.handshake_info`` is reused when it is set, its
    JWT signing public key list is non-empty and ``force_refetch`` is false.
    Otherwise the method records the ``CALLING_SERVICE_IN_GET_HANDSHAKE_INFO``
    process state, POSTs to ``/recipe/handshake``, stores a new
    ``HandshakeInfo`` (with ``antiCsrf`` taken from ``self.config``) and
    passes the key fields of the response to
    ``update_jwt_signing_public_key_info``.
    """
    cached = self.handshake_info
    cache_usable = (
        cached is not None
        and len(cached.get_jwt_signing_public_key_list()) != 0
        and not force_refetch
    )
    if cache_usable:
        return self.handshake_info

    ProcessState.get_instance().add_state(
        AllowedProcessStates.CALLING_SERVICE_IN_GET_HANDSHAKE_INFO
    )
    handshake_response = await self.querier.send_post_request(
        NormalisedURLPath("/recipe/handshake"), {}
    )
    self.handshake_info = HandshakeInfo(
        {**handshake_response, "antiCsrf": self.config.anti_csrf}
    )
    self.update_jwt_signing_public_key_info(
        handshake_response["jwtSigningPublicKeyList"],
        handshake_response["jwtSigningPublicKey"],
        handshake_response["jwtSigningPublicKeyExpiryTime"],
    )
    return self.handshake_info
async def get_refresh_token_lifetime_ms(self, user_context: Dict[str, Any]) ‑> int
Expand source code
async def get_refresh_token_lifetime_ms(self, user_context: Dict[str, Any]) -> int:
    """Return ``refresh_token_validity`` from the handshake info.

    The handshake info is obtained through ``self.get_handshake_info()``.
    ``user_context`` is not read by this method.
    """
    handshake_info = await self.get_handshake_info()
    return handshake_info.refresh_token_validity
async def get_session(self, request: BaseRequest, anti_csrf_check: Union[bool, None], session_required: bool, user_context: Dict[str, Any]) ‑> Optional[SessionContainer]
Expand source code
async def get_session(
    self,
    request: BaseRequest,
    anti_csrf_check: Union[bool, None],
    session_required: bool,
    user_context: Dict[str, Any],
) -> Optional[SessionContainer]:
    """Verify the session carried by ``request`` and attach it to the request.

    Args:
        request: Incoming request whose cookies and headers are inspected.
        anti_csrf_check: Whether to perform the anti-CSRF check; when ``None``
            it defaults to ``True`` for every HTTP method except GET.
        session_required: When false, a missing id-refresh-token cookie
            yields ``None`` instead of the unauthorised path.
        user_context: Not read by this method.

    Returns:
        ``None`` when no session is present and one is not required (see the
        branches below); otherwise whatever ``request.get_session()`` returns
        after the new ``Session`` has been stored with ``request.set_session``.
    """
    log_debug_message("getSession: Started")

    log_debug_message(
        "getSession: rid in header: %s", str(frontend_has_interceptor(request))
    )
    log_debug_message("getSession: request method: %s", request.method())

    id_refresh_token = get_id_refresh_token_from_cookie(request)
    if id_refresh_token is None:
        if not session_required:
            log_debug_message(
                "getSession: returning None because idRefreshToken is undefined and session_required is false"
            )
            return None
        log_debug_message(
            "getSession: UNAUTHORISED because idRefreshToken from cookies is undefined"
        )
        # NOTE(review): presumably this helper raises and never returns; confirm
        # in .exceptions, since execution would otherwise fall through.
        raise_unauthorised_exception(
            "Session does not exist. Are you sending the session tokens in the request as cookies?",
            False,
        )
    access_token: Union[str, None] = get_access_token_from_cookie(request)
    if access_token is None:
        # Without an access token, ask the client to refresh when a session is
        # required, the frontend interceptor is present, or the method is GET;
        # in every other case report "no session".
        if (
            session_required is True
            or frontend_has_interceptor(request)
            or normalise_http_method(request.method()) == "get"
        ):
            log_debug_message(
                "getSession: Returning try refresh token because access token from cookies is undefined"
            )
            raise_try_refresh_token_exception(
                "Access token has expired. Please call the refresh API"
            )
        return None
    anti_csrf_token = get_anti_csrf_header(request)
    if anti_csrf_check is None:
        # Default: check anti-CSRF on every method except GET.
        anti_csrf_check = normalise_http_method(request.method()) != "get"

    log_debug_message(
        "getSession: Value of doAntiCsrfCheck is: %s", str(anti_csrf_check)
    )
    new_session = await session_functions.get_session(
        self,
        access_token,
        anti_csrf_token,
        anti_csrf_check,
        get_rid_header(request) is not None,
    )
    # When the result carries a new access token, use it in place of the one
    # read from the cookie.
    if "accessToken" in new_session:
        access_token = new_session["accessToken"]["token"]

    if access_token is None:
        raise Exception("Should never come here")
    session = Session(
        self,
        access_token,
        new_session["session"]["handle"],
        new_session["session"]["userId"],
        new_session["session"]["userDataInJWT"],
    )

    if "accessToken" in new_session:
        session.new_access_token_info = new_session["accessToken"]

    log_debug_message("getSession: Success!")
    request.set_session(session)
    return request.get_session()
async def get_session_information(self, session_handle: str, user_context: Dict[str, Any]) ‑> Union[SessionInformationResult, None]
Expand source code
async def get_session_information(
    self, session_handle: str, user_context: Dict[str, Any]
) -> Union[SessionInformationResult, None]:
    """Look up the stored information for ``session_handle``.

    Forwards to ``session_functions.get_session_information`` and returns
    its result. ``user_context`` is not read by this method.
    """
    info = await session_functions.get_session_information(self, session_handle)
    return info
async def merge_into_access_token_payload(self, session_handle: str, access_token_payload_update: Dict[str, Any], user_context: Dict[str, Any]) ‑> bool
Expand source code
async def merge_into_access_token_payload(
    self,
    session_handle: str,
    access_token_payload_update: Dict[str, Any],
    user_context: Dict[str, Any],
) -> bool:
    """Merge ``access_token_payload_update`` into the stored payload.

    Keys in the update override the existing payload; keys whose update
    value is ``None`` are removed from the merged payload. Returns ``False``
    when ``get_session_information`` yields ``None``, otherwise the result
    of ``update_access_token_payload``.
    """
    existing = await self.get_session_information(session_handle, user_context)
    if existing is None:
        return False

    merged = {**existing.access_token_payload}
    merged.update(access_token_payload_update)
    # A None value in the update means "drop this key".
    for key in access_token_payload_update.keys():
        if merged[key] is None:
            merged.pop(key)

    return await self.update_access_token_payload(
        session_handle, merged, user_context
    )
async def refresh_session(self, request: BaseRequest, user_context: Dict[str, Any]) ‑> SessionContainer
Expand source code
async def refresh_session(
    self, request: BaseRequest, user_context: Dict[str, Any]
) -> SessionContainer:
    """Refresh the session carried by ``request`` and attach the new one.

    Calls ``raise_unauthorised_exception`` when either the id-refresh-token
    cookie or the refresh-token cookie is missing. Otherwise it calls
    ``session_functions.refresh_session``, wraps the result in a ``Session``
    carrying the new token info (and the anti-CSRF token when present and
    not ``None``), stores it via ``request.set_session`` and returns it.
    ``user_context`` is not read by this method.
    """
    log_debug_message("refreshSession: Started")

    if get_id_refresh_token_from_cookie(request) is None:
        log_debug_message(
            "refreshSession: UNAUTHORISED because idRefreshToken from cookies is undefined"
        )
        raise_unauthorised_exception(
            "Session does not exist. Are you sending the session tokens in the request "
            "as cookies?",
            False,
        )

    incoming_refresh_token = get_refresh_token_from_cookie(request)
    if incoming_refresh_token is None:
        log_debug_message(
            "refreshSession: UNAUTHORISED because refresh token from cookies is undefined"
        )
        raise_unauthorised_exception(
            "Refresh token not found. Are you sending the refresh token in the "
            "request as a cookie?"
        )

    csrf_header = get_anti_csrf_header(request)
    has_rid_header = get_rid_header(request) is not None
    refreshed = await session_functions.refresh_session(
        self, incoming_refresh_token, csrf_header, has_rid_header
    )

    access_info = refreshed["accessToken"]
    refresh_info = refreshed["refreshToken"]
    id_refresh_info = refreshed["idRefreshToken"]
    session_fields = refreshed["session"]

    session = Session(
        self,
        access_info["token"],
        session_fields["handle"],
        session_fields["userId"],
        session_fields["userDataInJWT"],
    )
    session.new_access_token_info = access_info
    session.new_refresh_token_info = refresh_info
    session.new_id_refresh_token_info = id_refresh_info

    anti_csrf = refreshed["antiCsrfToken"] if "antiCsrfToken" in refreshed else None
    if anti_csrf is not None:
        session.new_anti_csrf_token = anti_csrf

    log_debug_message("refreshSession: Success!")
    request.set_session(session)
    return session
async def regenerate_access_token(self, access_token: str, new_access_token_payload: Union[Dict[str, Any], None], user_context: Dict[str, Any]) ‑> Union[RegenerateAccessTokenOkResult, None]
Expand source code
async def regenerate_access_token(
    self,
    access_token: str,
    new_access_token_payload: Union[Dict[str, Any], None],
    user_context: Dict[str, Any],
) -> Union[RegenerateAccessTokenOkResult, None]:
    """Request a regenerated access token carrying a new payload.

    POSTs to ``/recipe/session/regenerate`` through ``self.querier``. A
    ``None`` payload is sent as an empty dict. Returns ``None`` when the
    response status is ``"UNAUTHORISED"``; otherwise a
    ``RegenerateAccessTokenOkResult`` whose access token object is ``None``
    if the response has no ``"accessToken"`` entry.
    """
    payload = {} if new_access_token_payload is None else new_access_token_payload
    endpoint = NormalisedURLPath("/recipe/session/regenerate")
    response: Dict[str, Any] = await self.querier.send_post_request(
        endpoint,
        {"accessToken": access_token, "userDataInJWT": payload},
    )
    if response["status"] == "UNAUTHORISED":
        return None

    access_token_obj: Union[None, AccessTokenObj] = None
    if "accessToken" in response:
        token_fields = response["accessToken"]
        access_token_obj = AccessTokenObj(
            token_fields["token"],
            token_fields["expiry"],
            token_fields["createdTime"],
        )

    session_fields = response["session"]
    session = SessionObj(
        session_fields["handle"],
        session_fields["userId"],
        session_fields["userDataInJWT"],
    )
    return RegenerateAccessTokenOkResult(session, access_token_obj)
async def remove_claim(self, session_handle: str, claim: SessionClaim[Any], user_context: Dict[str, Any]) ‑> bool
Expand source code
async def remove_claim(
    self,
    session_handle: str,
    claim: SessionClaim[Any],
    user_context: Dict[str, Any],
) -> bool:
    """Remove ``claim`` from the session's access token payload.

    The update comes from ``claim.remove_from_payload_by_merge_`` applied to
    an empty dict and is handed to ``merge_into_access_token_payload``,
    whose result is returned.
    """
    removal_update = claim.remove_from_payload_by_merge_({}, user_context)
    merged_ok = await self.merge_into_access_token_payload(
        session_handle, removal_update, user_context
    )
    return merged_ok
async def revoke_all_sessions_for_user(self, user_id: str, user_context: Dict[str, Any]) ‑> List[str]
Expand source code
async def revoke_all_sessions_for_user(
    self, user_id: str, user_context: Dict[str, Any]
) -> List[str]:
    """Revoke all sessions belonging to ``user_id``.

    Forwards to ``session_functions.revoke_all_sessions_for_user`` and
    returns its result. ``user_context`` is not read by this method.
    """
    revoked_handles = await session_functions.revoke_all_sessions_for_user(
        self, user_id
    )
    return revoked_handles
async def revoke_multiple_sessions(self, session_handles: List[str], user_context: Dict[str, Any]) ‑> List[str]
Expand source code
async def revoke_multiple_sessions(
    self, session_handles: List[str], user_context: Dict[str, Any]
) -> List[str]:
    """Revoke each session listed in ``session_handles``.

    Forwards to ``session_functions.revoke_multiple_sessions`` and returns
    its result. ``user_context`` is not read by this method.
    """
    revoked_handles = await session_functions.revoke_multiple_sessions(
        self, session_handles
    )
    return revoked_handles
async def revoke_session(self, session_handle: str, user_context: Dict[str, Any]) ‑> bool
Expand source code
async def revoke_session(
    self, session_handle: str, user_context: Dict[str, Any]
) -> bool:
    """Revoke the session identified by ``session_handle``.

    Forwards to ``session_functions.revoke_session`` and returns its result.
    ``user_context`` is not read by this method.
    """
    was_revoked = await session_functions.revoke_session(self, session_handle)
    return was_revoked
async def set_claim_value(self, session_handle: str, claim: SessionClaim[Any], value: Any, user_context: Dict[str, Any])
Expand source code
async def set_claim_value(
    self,
    session_handle: str,
    claim: SessionClaim[Any],
    value: Any,
    user_context: Dict[str, Any],
) -> bool:
    """Write ``value`` for ``claim`` into the session's access token payload.

    The update is produced by ``claim.add_to_payload_`` applied to an empty
    dict and then handed to ``merge_into_access_token_payload``.

    Returns:
        The boolean returned by ``merge_into_access_token_payload``
        (``False`` when the session information lookup yields ``None``).
    """
    access_token_payload_update = claim.add_to_payload_({}, value, user_context)
    return await self.merge_into_access_token_payload(
        session_handle, access_token_payload_update, user_context
    )
def update_jwt_signing_public_key_info(self, key_list: Union[List[Dict[str, Any]], None], public_key: str, expiry_time: int)
Expand source code
def update_jwt_signing_public_key_info(
    self,
    key_list: Union[List[Dict[str, Any]], None],
    public_key: str,
    expiry_time: int,
):
    """Store the JWT signing public key list on the cached handshake info.

    When ``key_list`` is ``None`` a single-entry list is built from
    ``public_key`` and ``expiry_time``, stamped with ``get_timestamp_ms()``.
    Nothing is stored when ``self.handshake_info`` is ``None``.
    """
    effective_keys = key_list
    if effective_keys is None:
        single_key = {
            "publicKey": public_key,
            "expiryTime": expiry_time,
            "createdAt": get_timestamp_ms(),
        }
        effective_keys = [single_key]

    handshake_info = self.handshake_info
    if handshake_info is None:
        return
    handshake_info.set_jwt_signing_public_key_list(effective_keys)
async def update_session_data(self, session_handle: str, new_session_data: Dict[str, Any], user_context: Dict[str, Any]) ‑> bool
Expand source code
async def update_session_data(
    self,
    session_handle: str,
    new_session_data: Dict[str, Any],
    user_context: Dict[str, Any],
) -> bool:
    """Store ``new_session_data`` for the session ``session_handle``.

    Forwards to ``session_functions.update_session_data`` and returns its
    result. ``user_context`` is not read by this method.
    """
    updated = await session_functions.update_session_data(
        self, session_handle, new_session_data
    )
    return updated
async def validate_claims(self, user_id: str, access_token_payload: Dict[str, Any], claim_validators: List[SessionClaimValidator], user_context: Dict[str, Any]) ‑> ClaimsValidationResult
Expand source code
async def validate_claims(
    self,
    user_id: str,
    access_token_payload: Dict[str, Any],
    claim_validators: List[SessionClaimValidator],
    user_context: Dict[str, Any],
) -> ClaimsValidationResult:
    """Refetch stale claims, then validate the payload against the validators.

    For each validator that has a claim and reports ``should_refetch``, the
    claim value is fetched and, when not ``None``, added to the payload via
    ``add_to_payload_``. The payload is returned as the update only if its
    JSON serialisation differs from the one taken on entry; otherwise the
    update is ``None``. Invalid claims come from
    ``validate_claims_in_payload``.
    """
    payload_before = json.dumps(access_token_payload)
    payload = access_token_payload

    for validator in claim_validators:
        log_debug_message(
            "update_claims_in_payload_if_needed checking should_refetch for %s",
            validator.id,
        )
        if validator.claim is None or not validator.should_refetch(
            payload, user_context
        ):
            continue

        log_debug_message(
            "update_claims_in_payload_if_needed refetching for %s", validator.id
        )
        fetched = await resolve(validator.claim.fetch_value(user_id, user_context))
        log_debug_message(
            "update_claims_in_payload_if_needed %s refetch result %s",
            validator.id,
            json.dumps(fetched),
        )
        if fetched is not None:
            payload = validator.claim.add_to_payload_(payload, fetched, user_context)

    payload_changed = json.dumps(payload) != payload_before
    access_token_payload_update = payload if payload_changed else None

    invalid_claims = await validate_claims_in_payload(
        claim_validators, payload, user_context
    )

    return ClaimsValidationResult(invalid_claims, access_token_payload_update)
async def validate_claims_in_jwt_payload(self, user_id: str, jwt_payload: JSONObject, claim_validators: List[SessionClaimValidator], user_context: Dict[str, Any]) ‑> ClaimsValidationResult
Expand source code
async def validate_claims_in_jwt_payload(
    self,
    user_id: str,
    jwt_payload: JSONObject,
    claim_validators: List[SessionClaimValidator],
    user_context: Dict[str, Any],
) -> ClaimsValidationResult:
    """Validate ``jwt_payload`` against ``claim_validators``.

    Runs ``validate_claims_in_payload`` and wraps the invalid claims in a
    ``ClaimsValidationResult``. ``user_id`` is not read by this method.
    """
    failures = await validate_claims_in_payload(
        claim_validators, jwt_payload, user_context
    )
    return ClaimsValidationResult(failures)

Inherited members