Module supertokens_python.recipe.session.recipe

Expand source code
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
#
# This software is licensed under the Apache License, Version 2.0 (the
# "License") as published by the Apache Software Foundation.
#
# You may not use this file except in compliance with the License. You may
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from os import environ
from typing import TYPE_CHECKING, Any, Dict, List, Union

from supertokens_python.framework.response import BaseResponse
from typing_extensions import Literal

from .api import handle_refresh_api, handle_signout_api
from .cookie_and_header import get_cors_allowed_headers
from .exceptions import (SuperTokensSessionError, TokenTheftError,
                         UnauthorisedError)

if TYPE_CHECKING:
    from supertokens_python.framework import BaseRequest
    from supertokens_python.supertokens import AppInfo

from supertokens_python.exceptions import (SuperTokensError,
                                           raise_general_exception)
from supertokens_python.logger import log_debug_message
from supertokens_python.normalised_url_path import NormalisedURLPath
from supertokens_python.querier import Querier
from supertokens_python.recipe.openid.recipe import OpenIdRecipe
from supertokens_python.recipe.session.with_jwt import \
    get_recipe_implementation_with_jwt
from supertokens_python.recipe_module import APIHandled, RecipeModule

from .api.implementation import APIImplementation
from .constants import SESSION_REFRESH, SIGNOUT
from .interfaces import APIInterface, APIOptions, RecipeInterface
from .recipe_implementation import RecipeImplementation
from .utils import (InputErrorHandlers, InputOverrideConfig, JWTConfig,
                    validate_and_normalise_user_input)


class SessionRecipe(RecipeModule):
    recipe_id = 'session'
    __instance = None

    def __init__(self, recipe_id: str, app_info: AppInfo,
                 cookie_domain: Union[str, None] = None,
                 cookie_secure: Union[bool, None] = None,
                 cookie_same_site: Union[Literal["lax",
                                                 "none", "strict"], None] = None,
                 session_expired_status_code: Union[int, None] = None,
                 anti_csrf: Union[Literal["VIA_TOKEN",
                                          "VIA_CUSTOM_HEADER", "NONE"], None] = None,
                 error_handlers: Union[InputErrorHandlers, None] = None,
                 override: Union[InputOverrideConfig, None] = None,
                 jwt: Union[JWTConfig, None] = None):
        super().__init__(recipe_id, app_info)
        self.openid_recipe: Union[None, OpenIdRecipe] = None
        self.config = validate_and_normalise_user_input(app_info, cookie_domain,
                                                        cookie_secure,
                                                        cookie_same_site,
                                                        session_expired_status_code,
                                                        anti_csrf,
                                                        error_handlers,
                                                        override,
                                                        jwt)
        log_debug_message("session init: anti_csrf: %s", self.config.anti_csrf)
        if self.config.cookie_domain is not None:
            log_debug_message("session init: cookie_domain: %s", self.config.cookie_domain)
        else:
            log_debug_message("session init: cookie_domain: None")
        log_debug_message("session init: cookie_same_site: %s", self.config.cookie_same_site)
        log_debug_message("session init: cookie_secure: %s", str(self.config.cookie_secure))
        log_debug_message("session init: refresh_token_path: %s ", self.config.refresh_token_path.get_as_string_dangerous())
        log_debug_message("session init: session_expired_status_code: %s", str(self.config.session_expired_status_code))

        if self.config.jwt.enable:
            openid_feature_override = None
            if override is not None:
                openid_feature_override = override.openid_feature
            self.openid_recipe = OpenIdRecipe(recipe_id, app_info, None, self.config.jwt.issuer,
                                              openid_feature_override)
            recipe_implementation = RecipeImplementation(
                Querier.get_instance(recipe_id), self.config)
            recipe_implementation = get_recipe_implementation_with_jwt(recipe_implementation, self.config, self.openid_recipe.recipe_implementation)
        else:
            recipe_implementation = RecipeImplementation(
                Querier.get_instance(recipe_id), self.config)
        self.recipe_implementation: RecipeInterface = recipe_implementation if self.config.override.functions is None else self.config.override.functions(
            recipe_implementation)
        api_implementation = APIImplementation()
        self.api_implementation: APIInterface = api_implementation if self.config.override.apis is None else self.config.override.apis(
            api_implementation)

    def is_error_from_this_recipe_based_on_instance(
            self, err: Exception) -> bool:
        return isinstance(err, SuperTokensError) and (
            isinstance(err, SuperTokensSessionError)
            or
            (self.openid_recipe is not None and self.openid_recipe.is_error_from_this_recipe_based_on_instance(err))
        )

    def get_apis_handled(self) -> List[APIHandled]:
        apis_handled = [
            APIHandled(NormalisedURLPath(SESSION_REFRESH), 'post', SESSION_REFRESH,
                       self.api_implementation.disable_refresh_post),
            APIHandled(NormalisedURLPath(SIGNOUT), 'post', SIGNOUT,
                       self.api_implementation.disable_signout_post)
        ]
        if self.openid_recipe is not None:
            apis_handled = apis_handled + self.openid_recipe.get_apis_handled()

        return apis_handled

    async def handle_api_request(self, request_id: str, request: BaseRequest, path: NormalisedURLPath, method: str, response: BaseResponse) -> Union[BaseResponse, None]:
        if request_id == SESSION_REFRESH:
            return await handle_refresh_api(self.api_implementation,
                                            APIOptions(
                                                request,
                                                response,
                                                self.recipe_id,
                                                self.config,
                                                self.recipe_implementation
                                            ))
        if request_id == SIGNOUT:
            return await handle_signout_api(self.api_implementation,
                                            APIOptions(
                                                request,
                                                response,
                                                self.recipe_id,
                                                self.config,
                                                self.recipe_implementation
                                            ))
        if self.openid_recipe is not None:
            return await self.openid_recipe.handle_api_request(request_id, request, path, method, response)
        return None

    async def handle_error(self, request: BaseRequest, err: SuperTokensError, response: BaseResponse) -> BaseResponse:
        if isinstance(err, UnauthorisedError):
            log_debug_message("errorHandler: returning UNAUTHORISED")
            return await self.config.error_handlers.on_unauthorised(self, err.clear_cookies, request, str(err), response)
        if isinstance(err, TokenTheftError):
            log_debug_message("errorHandler: returning TOKEN_THEFT_DETECTED")
            return await self.config.error_handlers.on_token_theft_detected(self, request, err.session_handle, err.user_id, response)
        log_debug_message("errorHandler: returning TRY_REFRESH_TOKEN")
        return await self.config.error_handlers.on_try_refresh_token(request, str(err), response)

    def get_all_cors_headers(self) -> List[str]:
        cors_headers = get_cors_allowed_headers()
        if self.openid_recipe is not None:
            cors_headers = cors_headers + self.openid_recipe.get_all_cors_headers()

        return cors_headers

    @staticmethod
    def init(cookie_domain: Union[str, None] = None,
             cookie_secure: Union[bool, None] = None,
             cookie_same_site: Union[Literal["lax",
                                             "none", "strict"], None] = None,
             session_expired_status_code: Union[int, None] = None,
             anti_csrf: Union[Literal["VIA_TOKEN",
                                      "VIA_CUSTOM_HEADER", "NONE"], None] = None,
             error_handlers: Union[InputErrorHandlers, None] = None,
             override: Union[InputOverrideConfig, None] = None,
             jwt: Union[JWTConfig, None] = None):
        def func(app_info: AppInfo):
            if SessionRecipe.__instance is None:
                SessionRecipe.__instance = SessionRecipe(
                    SessionRecipe.recipe_id, app_info, cookie_domain,
                    cookie_secure,
                    cookie_same_site,
                    session_expired_status_code,
                    anti_csrf,
                    error_handlers,
                    override,
                    jwt
                )
                return SessionRecipe.__instance
            raise_general_exception(
                'Session recipe has already been initialised. Please check your code for bugs.')

        return func

    @staticmethod
    def get_instance() -> SessionRecipe:
        if SessionRecipe.__instance is not None:
            return SessionRecipe.__instance
        raise_general_exception(
            'Initialisation not done. Did you forget to call the SuperTokens.init function?')

    @staticmethod
    def reset():
        if ('SUPERTOKENS_ENV' not in environ) or (
                environ['SUPERTOKENS_ENV'] != 'testing'):
            raise_general_exception(
                'calling testing function in non testing env')
        SessionRecipe.__instance = None

    async def verify_session(self, request: BaseRequest,
                             anti_csrf_check: Union[bool, None],
                             session_required: bool, user_context: Dict[str, Any]):
        return await self.api_implementation.verify_session(
            APIOptions(
                request,
                None,
                self.recipe_id,
                self.config,
                self.recipe_implementation
            ), anti_csrf_check,
            session_required, user_context)

Classes

class SessionRecipe (recipe_id: str, app_info: AppInfo, cookie_domain: Union[str, None] = None, cookie_secure: Union[bool, None] = None, cookie_same_site: "Union[Literal['lax', 'none', 'strict'], None]" = None, session_expired_status_code: Union[int, None] = None, anti_csrf: "Union[Literal['VIA_TOKEN', 'VIA_CUSTOM_HEADER', 'NONE'], None]" = None, error_handlers: Union[InputErrorHandlers, None] = None, override: Union[InputOverrideConfig, None] = None, jwt: Union[JWTConfig, None] = None)

Helper class that provides a standard way to create an ABC using inheritance.

Expand source code
class SessionRecipe(RecipeModule):
    recipe_id = 'session'
    __instance = None

    def __init__(self, recipe_id: str, app_info: AppInfo,
                 cookie_domain: Union[str, None] = None,
                 cookie_secure: Union[bool, None] = None,
                 cookie_same_site: Union[Literal["lax",
                                                 "none", "strict"], None] = None,
                 session_expired_status_code: Union[int, None] = None,
                 anti_csrf: Union[Literal["VIA_TOKEN",
                                          "VIA_CUSTOM_HEADER", "NONE"], None] = None,
                 error_handlers: Union[InputErrorHandlers, None] = None,
                 override: Union[InputOverrideConfig, None] = None,
                 jwt: Union[JWTConfig, None] = None):
        super().__init__(recipe_id, app_info)
        self.openid_recipe: Union[None, OpenIdRecipe] = None
        self.config = validate_and_normalise_user_input(app_info, cookie_domain,
                                                        cookie_secure,
                                                        cookie_same_site,
                                                        session_expired_status_code,
                                                        anti_csrf,
                                                        error_handlers,
                                                        override,
                                                        jwt)
        log_debug_message("session init: anti_csrf: %s", self.config.anti_csrf)
        if self.config.cookie_domain is not None:
            log_debug_message("session init: cookie_domain: %s", self.config.cookie_domain)
        else:
            log_debug_message("session init: cookie_domain: None")
        log_debug_message("session init: cookie_same_site: %s", self.config.cookie_same_site)
        log_debug_message("session init: cookie_secure: %s", str(self.config.cookie_secure))
        log_debug_message("session init: refresh_token_path: %s ", self.config.refresh_token_path.get_as_string_dangerous())
        log_debug_message("session init: session_expired_status_code: %s", str(self.config.session_expired_status_code))

        if self.config.jwt.enable:
            openid_feature_override = None
            if override is not None:
                openid_feature_override = override.openid_feature
            self.openid_recipe = OpenIdRecipe(recipe_id, app_info, None, self.config.jwt.issuer,
                                              openid_feature_override)
            recipe_implementation = RecipeImplementation(
                Querier.get_instance(recipe_id), self.config)
            recipe_implementation = get_recipe_implementation_with_jwt(recipe_implementation, self.config, self.openid_recipe.recipe_implementation)
        else:
            recipe_implementation = RecipeImplementation(
                Querier.get_instance(recipe_id), self.config)
        self.recipe_implementation: RecipeInterface = recipe_implementation if self.config.override.functions is None else self.config.override.functions(
            recipe_implementation)
        api_implementation = APIImplementation()
        self.api_implementation: APIInterface = api_implementation if self.config.override.apis is None else self.config.override.apis(
            api_implementation)

    def is_error_from_this_recipe_based_on_instance(
            self, err: Exception) -> bool:
        return isinstance(err, SuperTokensError) and (
            isinstance(err, SuperTokensSessionError)
            or
            (self.openid_recipe is not None and self.openid_recipe.is_error_from_this_recipe_based_on_instance(err))
        )

    def get_apis_handled(self) -> List[APIHandled]:
        apis_handled = [
            APIHandled(NormalisedURLPath(SESSION_REFRESH), 'post', SESSION_REFRESH,
                       self.api_implementation.disable_refresh_post),
            APIHandled(NormalisedURLPath(SIGNOUT), 'post', SIGNOUT,
                       self.api_implementation.disable_signout_post)
        ]
        if self.openid_recipe is not None:
            apis_handled = apis_handled + self.openid_recipe.get_apis_handled()

        return apis_handled

    async def handle_api_request(self, request_id: str, request: BaseRequest, path: NormalisedURLPath, method: str, response: BaseResponse) -> Union[BaseResponse, None]:
        if request_id == SESSION_REFRESH:
            return await handle_refresh_api(self.api_implementation,
                                            APIOptions(
                                                request,
                                                response,
                                                self.recipe_id,
                                                self.config,
                                                self.recipe_implementation
                                            ))
        if request_id == SIGNOUT:
            return await handle_signout_api(self.api_implementation,
                                            APIOptions(
                                                request,
                                                response,
                                                self.recipe_id,
                                                self.config,
                                                self.recipe_implementation
                                            ))
        if self.openid_recipe is not None:
            return await self.openid_recipe.handle_api_request(request_id, request, path, method, response)
        return None

    async def handle_error(self, request: BaseRequest, err: SuperTokensError, response: BaseResponse) -> BaseResponse:
        if isinstance(err, UnauthorisedError):
            log_debug_message("errorHandler: returning UNAUTHORISED")
            return await self.config.error_handlers.on_unauthorised(self, err.clear_cookies, request, str(err), response)
        if isinstance(err, TokenTheftError):
            log_debug_message("errorHandler: returning TOKEN_THEFT_DETECTED")
            return await self.config.error_handlers.on_token_theft_detected(self, request, err.session_handle, err.user_id, response)
        log_debug_message("errorHandler: returning TRY_REFRESH_TOKEN")
        return await self.config.error_handlers.on_try_refresh_token(request, str(err), response)

    def get_all_cors_headers(self) -> List[str]:
        cors_headers = get_cors_allowed_headers()
        if self.openid_recipe is not None:
            cors_headers = cors_headers + self.openid_recipe.get_all_cors_headers()

        return cors_headers

    @staticmethod
    def init(cookie_domain: Union[str, None] = None,
             cookie_secure: Union[bool, None] = None,
             cookie_same_site: Union[Literal["lax",
                                             "none", "strict"], None] = None,
             session_expired_status_code: Union[int, None] = None,
             anti_csrf: Union[Literal["VIA_TOKEN",
                                      "VIA_CUSTOM_HEADER", "NONE"], None] = None,
             error_handlers: Union[InputErrorHandlers, None] = None,
             override: Union[InputOverrideConfig, None] = None,
             jwt: Union[JWTConfig, None] = None):
        def func(app_info: AppInfo):
            if SessionRecipe.__instance is None:
                SessionRecipe.__instance = SessionRecipe(
                    SessionRecipe.recipe_id, app_info, cookie_domain,
                    cookie_secure,
                    cookie_same_site,
                    session_expired_status_code,
                    anti_csrf,
                    error_handlers,
                    override,
                    jwt
                )
                return SessionRecipe.__instance
            raise_general_exception(
                'Session recipe has already been initialised. Please check your code for bugs.')

        return func

    @staticmethod
    def get_instance() -> SessionRecipe:
        if SessionRecipe.__instance is not None:
            return SessionRecipe.__instance
        raise_general_exception(
            'Initialisation not done. Did you forget to call the SuperTokens.init function?')

    @staticmethod
    def reset():
        if ('SUPERTOKENS_ENV' not in environ) or (
                environ['SUPERTOKENS_ENV'] != 'testing'):
            raise_general_exception(
                'calling testing function in non testing env')
        SessionRecipe.__instance = None

    async def verify_session(self, request: BaseRequest,
                             anti_csrf_check: Union[bool, None],
                             session_required: bool, user_context: Dict[str, Any]):
        return await self.api_implementation.verify_session(
            APIOptions(
                request,
                None,
                self.recipe_id,
                self.config,
                self.recipe_implementation
            ), anti_csrf_check,
            session_required, user_context)

Ancestors

Class variables

var recipe_id

Static methods

def get_instance() ‑> SessionRecipe
Expand source code
@staticmethod
def get_instance() -> SessionRecipe:
    if SessionRecipe.__instance is not None:
        return SessionRecipe.__instance
    raise_general_exception(
        'Initialisation not done. Did you forget to call the SuperTokens.init function?')
def init(cookie_domain: Union[str, None] = None, cookie_secure: Union[bool, None] = None, cookie_same_site: "Union[Literal['lax', 'none', 'strict'], None]" = None, session_expired_status_code: Union[int, None] = None, anti_csrf: "Union[Literal['VIA_TOKEN', 'VIA_CUSTOM_HEADER', 'NONE'], None]" = None, error_handlers: Union[InputErrorHandlers, None] = None, override: Union[InputOverrideConfig, None] = None, jwt: Union[JWTConfig, None] = None)
Expand source code
@staticmethod
def init(cookie_domain: Union[str, None] = None,
         cookie_secure: Union[bool, None] = None,
         cookie_same_site: Union[Literal["lax",
                                         "none", "strict"], None] = None,
         session_expired_status_code: Union[int, None] = None,
         anti_csrf: Union[Literal["VIA_TOKEN",
                                  "VIA_CUSTOM_HEADER", "NONE"], None] = None,
         error_handlers: Union[InputErrorHandlers, None] = None,
         override: Union[InputOverrideConfig, None] = None,
         jwt: Union[JWTConfig, None] = None):
    def func(app_info: AppInfo):
        if SessionRecipe.__instance is None:
            SessionRecipe.__instance = SessionRecipe(
                SessionRecipe.recipe_id, app_info, cookie_domain,
                cookie_secure,
                cookie_same_site,
                session_expired_status_code,
                anti_csrf,
                error_handlers,
                override,
                jwt
            )
            return SessionRecipe.__instance
        raise_general_exception(
            'Session recipe has already been initialised. Please check your code for bugs.')

    return func
def reset()
Expand source code
@staticmethod
def reset():
    if ('SUPERTOKENS_ENV' not in environ) or (
            environ['SUPERTOKENS_ENV'] != 'testing'):
        raise_general_exception(
            'calling testing function in non testing env')
    SessionRecipe.__instance = None

Methods

def get_all_cors_headers(self) ‑> List[str]
Expand source code
def get_all_cors_headers(self) -> List[str]:
    cors_headers = get_cors_allowed_headers()
    if self.openid_recipe is not None:
        cors_headers = cors_headers + self.openid_recipe.get_all_cors_headers()

    return cors_headers
def get_apis_handled(self) ‑> List[APIHandled]
Expand source code
def get_apis_handled(self) -> List[APIHandled]:
    apis_handled = [
        APIHandled(NormalisedURLPath(SESSION_REFRESH), 'post', SESSION_REFRESH,
                   self.api_implementation.disable_refresh_post),
        APIHandled(NormalisedURLPath(SIGNOUT), 'post', SIGNOUT,
                   self.api_implementation.disable_signout_post)
    ]
    if self.openid_recipe is not None:
        apis_handled = apis_handled + self.openid_recipe.get_apis_handled()

    return apis_handled
async def handle_api_request(self, request_id: str, request: BaseRequest, path: NormalisedURLPath, method: str, response: BaseResponse) ‑> Union[BaseResponse, None]
Expand source code
async def handle_api_request(self, request_id: str, request: BaseRequest, path: NormalisedURLPath, method: str, response: BaseResponse) -> Union[BaseResponse, None]:
    if request_id == SESSION_REFRESH:
        return await handle_refresh_api(self.api_implementation,
                                        APIOptions(
                                            request,
                                            response,
                                            self.recipe_id,
                                            self.config,
                                            self.recipe_implementation
                                        ))
    if request_id == SIGNOUT:
        return await handle_signout_api(self.api_implementation,
                                        APIOptions(
                                            request,
                                            response,
                                            self.recipe_id,
                                            self.config,
                                            self.recipe_implementation
                                        ))
    if self.openid_recipe is not None:
        return await self.openid_recipe.handle_api_request(request_id, request, path, method, response)
    return None
async def handle_error(self, request: BaseRequest, err: SuperTokensError, response: BaseResponse) ‑> BaseResponse
Expand source code
async def handle_error(self, request: BaseRequest, err: SuperTokensError, response: BaseResponse) -> BaseResponse:
    if isinstance(err, UnauthorisedError):
        log_debug_message("errorHandler: returning UNAUTHORISED")
        return await self.config.error_handlers.on_unauthorised(self, err.clear_cookies, request, str(err), response)
    if isinstance(err, TokenTheftError):
        log_debug_message("errorHandler: returning TOKEN_THEFT_DETECTED")
        return await self.config.error_handlers.on_token_theft_detected(self, request, err.session_handle, err.user_id, response)
    log_debug_message("errorHandler: returning TRY_REFRESH_TOKEN")
    return await self.config.error_handlers.on_try_refresh_token(request, str(err), response)
def is_error_from_this_recipe_based_on_instance(self, err: Exception) ‑> bool
Expand source code
def is_error_from_this_recipe_based_on_instance(
        self, err: Exception) -> bool:
    return isinstance(err, SuperTokensError) and (
        isinstance(err, SuperTokensSessionError)
        or
        (self.openid_recipe is not None and self.openid_recipe.is_error_from_this_recipe_based_on_instance(err))
    )
async def verify_session(self, request: BaseRequest, anti_csrf_check: Union[bool, None], session_required: bool, user_context: Dict[str, Any])
Expand source code
async def verify_session(self, request: BaseRequest,
                         anti_csrf_check: Union[bool, None],
                         session_required: bool, user_context: Dict[str, Any]):
    return await self.api_implementation.verify_session(
        APIOptions(
            request,
            None,
            self.recipe_id,
            self.config,
            self.recipe_implementation
        ), anti_csrf_check,
        session_required, user_context)