Module supertokens_python.recipe.session.cookie_and_header

Expand source code
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
#
# This software is licensed under the Apache License, Version 2.0 (the
# "License") as published by the Apache Software Foundation.
#
# You may not use this file except in compliance with the License. You may
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from typing import TYPE_CHECKING

from typing_extensions import Literal

from urllib.parse import quote, unquote

from .constants import (ACCESS_CONTROL_EXPOSE_HEADERS, ACCESS_TOKEN_COOKIE_KEY,
                        ANTI_CSRF_HEADER_KEY, FRONT_TOKEN_HEADER_SET_KEY,
                        ID_REFRESH_TOKEN_COOKIE_KEY,
                        ID_REFRESH_TOKEN_HEADER_SET_KEY,
                        REFRESH_TOKEN_COOKIE_KEY, RID_HEADER_KEY)

if TYPE_CHECKING:
    from supertokens_python.framework.request import BaseRequest
    from supertokens_python.framework.response import BaseResponse
    from .recipe import SessionRecipe

from json import dumps
from typing import Any, Dict, Union

from supertokens_python.exceptions import raise_general_exception
from supertokens_python.utils import get_header, utf_base64encode


def set_front_token_in_headers(response: BaseResponse, user_id: str, expires_at: int, jwt_payload: Union[None, Dict[str, Any]] = None):
    """Attach the front-token header to ``response`` and list it as exposed.

    The header value is the compact, key-sorted JSON encoding of
    ``{"ate": expires_at, "uid": user_id, "up": jwt_payload}`` passed through
    ``utf_base64encode``. The header name is then appended (comma-joined) to
    the ``ACCESS_CONTROL_EXPOSE_HEADERS`` header.

    Args:
        response: Framework response whose headers are modified.
        user_id: Stored under the ``uid`` key.
        expires_at: Stored under the ``ate`` key (units not visible here;
            presumably a timestamp -- confirm with callers).
        jwt_payload: Stored under the ``up`` key; ``None`` becomes ``{}``.
    """
    payload = {} if jwt_payload is None else jwt_payload
    # Compact separators + sorted keys give a deterministic encoding.
    serialized = dumps(
        {'uid': user_id, 'ate': expires_at, 'up': payload},
        separators=(',', ':'),
        sort_keys=True,
    )
    set_header(response, FRONT_TOKEN_HEADER_SET_KEY,
               utf_base64encode(serialized), False)
    set_header(response, ACCESS_CONTROL_EXPOSE_HEADERS,
               FRONT_TOKEN_HEADER_SET_KEY, True)


def get_cors_allowed_headers():
    """Return a new list of the request header names used by this module.

    The list holds the anti-CSRF header and the rid header, in that order;
    per the function name it is meant for a CORS allowed-headers list.
    """
    return list((ANTI_CSRF_HEADER_KEY, RID_HEADER_KEY))


def set_header(response: BaseResponse,
               key: str, value: str, allow_duplicate: bool):
    """Set header ``key`` on ``response``, optionally appending to it.

    When ``allow_duplicate`` is true and the header already has a value,
    ``value`` is comma-joined after the existing one; otherwise the header
    is simply overwritten with ``value``.

    Raises:
        Whatever ``raise_general_exception`` raises (not visible here) when
        reading or writing the header fails for any reason.
    """
    try:
        new_value = value
        if allow_duplicate:
            existing = response.get_header(key)
            if existing is not None:
                new_value = existing + ',' + value
        response.set_header(key, new_value)
    except Exception:
        raise_general_exception(
            'Error while setting header with key: ' + key +
            ' and value: ' + value)


def get_cookie(request: BaseRequest, key: str):
    """Return the percent-decoded value of cookie ``key`` from ``request``.

    Returns ``None`` when the request carries no such cookie. Decoding with
    ``unquote`` mirrors the ``quote`` applied by ``set_cookie``.
    """
    raw_value = request.get_cookie(key)
    return None if raw_value is None else unquote(raw_value)


def set_cookie(recipe: SessionRecipe, response: BaseResponse, key: str, value: str,
               expires: int, path_type: Literal['refresh_token_path', 'access_token_path']):
    """Write an HttpOnly cookie using the recipe's cookie settings.

    Domain, ``secure`` and SameSite come from ``recipe.config``. The value is
    percent-encoded as UTF-8 (``get_cookie`` reverses this with ``unquote``).

    Args:
        recipe: Session recipe whose ``config`` supplies cookie attributes.
        response: Framework response the cookie is set on.
        key: Cookie name.
        value: Raw cookie value; encoded before being written.
        expires: Passed straight through to ``response.set_cookie``.
        path_type: ``'refresh_token_path'`` uses the configured refresh
            path; ``'access_token_path'`` uses ``'/'``. Any other value
            yields an empty path.
    """
    config = recipe.config
    domain, secure, same_site = (
        config.cookie_domain, config.cookie_secure, config.cookie_same_site)
    if path_type == 'refresh_token_path':
        cookie_path = config.refresh_token_path.get_as_string_dangerous()
    elif path_type == 'access_token_path':
        cookie_path = '/'
    else:
        cookie_path = ''
    response.set_cookie(
        key=key,
        value=quote(value, encoding='utf-8'),
        expires=expires,
        path=cookie_path,
        domain=domain,
        secure=secure,
        httponly=True,
        samesite=same_site,
    )


def attach_anti_csrf_header(response: BaseResponse, value: str):
    """Set the anti-CSRF header on ``response`` (overwriting any prior value)
    and append its name to the ``ACCESS_CONTROL_EXPOSE_HEADERS`` header."""
    set_header(response, ANTI_CSRF_HEADER_KEY, value, allow_duplicate=False)
    set_header(response, ACCESS_CONTROL_EXPOSE_HEADERS, ANTI_CSRF_HEADER_KEY,
               allow_duplicate=True)


def get_anti_csrf_header(request: BaseRequest):
    """Read the anti-CSRF header from ``request`` via ``utils.get_header``."""
    header_name = ANTI_CSRF_HEADER_KEY
    return get_header(request, header_name)


def get_rid_header(request: BaseRequest):
    """Read the rid header from ``request`` via ``utils.get_header``."""
    header_name = RID_HEADER_KEY
    return get_header(request, header_name)


def attach_access_token_to_cookie(
        recipe: SessionRecipe, response: BaseResponse, token: str, expires_at: int):
    """Store ``token`` in the access-token cookie using the access-token path."""
    set_cookie(recipe, response, key=ACCESS_TOKEN_COOKIE_KEY, value=token,
               expires=expires_at, path_type='access_token_path')


def attach_refresh_token_to_cookie(
        recipe: SessionRecipe, response: BaseResponse, token: str, expires_at: int):
    """Store ``token`` in the refresh-token cookie using the configured
    refresh-token path."""
    set_cookie(recipe, response, key=REFRESH_TOKEN_COOKIE_KEY, value=token,
               expires=expires_at, path_type='refresh_token_path')


def attach_id_refresh_token_to_cookie_and_header(
        recipe: SessionRecipe, response: BaseResponse, token: str, expires_at: int):
    """Publish the id-refresh token both as a response header and a cookie.

    The header value has the form ``"<token>;<expires_at>"`` and its name is
    appended to ``ACCESS_CONTROL_EXPOSE_HEADERS``; the cookie is written with
    the access-token path.
    """
    header_value = token + ';' + str(expires_at)
    set_header(response, ID_REFRESH_TOKEN_HEADER_SET_KEY, header_value, False)
    set_header(response, ACCESS_CONTROL_EXPOSE_HEADERS,
               ID_REFRESH_TOKEN_HEADER_SET_KEY, True)
    set_cookie(recipe, response, ID_REFRESH_TOKEN_COOKIE_KEY, token,
               expires_at, 'access_token_path')


def get_access_token_from_cookie(request: BaseRequest):
    """Return the decoded access-token cookie from ``request``, or ``None``."""
    return get_cookie(request, key=ACCESS_TOKEN_COOKIE_KEY)


def get_refresh_token_from_cookie(request: BaseRequest):
    """Return the decoded refresh-token cookie from ``request``, or ``None``."""
    return get_cookie(request, key=REFRESH_TOKEN_COOKIE_KEY)


def get_id_refresh_token_from_cookie(request: BaseRequest):
    """Return the decoded id-refresh-token cookie from ``request``, or ``None``."""
    return get_cookie(request, key=ID_REFRESH_TOKEN_COOKIE_KEY)


def clear_cookies(recipe: SessionRecipe, response: BaseResponse):
    """Blank out all session cookies and signal id-refresh-token removal.

    Each session cookie is rewritten with an empty value and ``expires=0``
    on the same path it was originally set with. The id-refresh header is
    set to ``"remove"`` and its name appended to
    ``ACCESS_CONTROL_EXPOSE_HEADERS``. Does nothing when ``response`` is
    ``None``.
    """
    if response is None:
        return
    # (cookie name, path type) in the order the cookies are cleared.
    expired_cookies = (
        (ACCESS_TOKEN_COOKIE_KEY, 'access_token_path'),
        (ID_REFRESH_TOKEN_COOKIE_KEY, 'access_token_path'),
        (REFRESH_TOKEN_COOKIE_KEY, 'refresh_token_path'),
    )
    for cookie_key, path_type in expired_cookies:
        set_cookie(recipe, response, cookie_key, '', 0, path_type)
    set_header(response, ID_REFRESH_TOKEN_HEADER_SET_KEY, "remove", False)
    set_header(response, ACCESS_CONTROL_EXPOSE_HEADERS,
               ID_REFRESH_TOKEN_HEADER_SET_KEY, True)

Functions

Expand source code
def attach_access_token_to_cookie(
        recipe: SessionRecipe, response: BaseResponse, token: str, expires_at: int):
    """Store ``token`` in the access-token cookie using the access-token path."""
    set_cookie(recipe, response, key=ACCESS_TOKEN_COOKIE_KEY, value=token,
               expires=expires_at, path_type='access_token_path')
def attach_anti_csrf_header(response: BaseResponse, value: str)
Expand source code
def attach_anti_csrf_header(response: BaseResponse, value: str):
    """Set the anti-CSRF header on ``response`` (overwriting any prior value)
    and append its name to the ``ACCESS_CONTROL_EXPOSE_HEADERS`` header."""
    set_header(response, ANTI_CSRF_HEADER_KEY, value, allow_duplicate=False)
    set_header(response, ACCESS_CONTROL_EXPOSE_HEADERS, ANTI_CSRF_HEADER_KEY,
               allow_duplicate=True)
Expand source code
def attach_id_refresh_token_to_cookie_and_header(
        recipe: SessionRecipe, response: BaseResponse, token: str, expires_at: int):
    """Publish the id-refresh token both as a response header and a cookie.

    The header value has the form ``"<token>;<expires_at>"`` and its name is
    appended to ``ACCESS_CONTROL_EXPOSE_HEADERS``; the cookie is written with
    the access-token path.
    """
    header_value = token + ';' + str(expires_at)
    set_header(response, ID_REFRESH_TOKEN_HEADER_SET_KEY, header_value, False)
    set_header(response, ACCESS_CONTROL_EXPOSE_HEADERS,
               ID_REFRESH_TOKEN_HEADER_SET_KEY, True)
    set_cookie(recipe, response, ID_REFRESH_TOKEN_COOKIE_KEY, token,
               expires_at, 'access_token_path')
Expand source code
def attach_refresh_token_to_cookie(
        recipe: SessionRecipe, response: BaseResponse, token: str, expires_at: int):
    """Store ``token`` in the refresh-token cookie using the configured
    refresh-token path."""
    set_cookie(recipe, response, key=REFRESH_TOKEN_COOKIE_KEY, value=token,
               expires=expires_at, path_type='refresh_token_path')
def clear_cookies(recipe: SessionRecipe, response: BaseResponse)
Expand source code
def clear_cookies(recipe: SessionRecipe, response: BaseResponse):
    """Blank out all session cookies and signal id-refresh-token removal.

    Each session cookie is rewritten with an empty value and ``expires=0``
    on the same path it was originally set with. The id-refresh header is
    set to ``"remove"`` and its name appended to
    ``ACCESS_CONTROL_EXPOSE_HEADERS``. Does nothing when ``response`` is
    ``None``.
    """
    if response is None:
        return
    # (cookie name, path type) in the order the cookies are cleared.
    expired_cookies = (
        (ACCESS_TOKEN_COOKIE_KEY, 'access_token_path'),
        (ID_REFRESH_TOKEN_COOKIE_KEY, 'access_token_path'),
        (REFRESH_TOKEN_COOKIE_KEY, 'refresh_token_path'),
    )
    for cookie_key, path_type in expired_cookies:
        set_cookie(recipe, response, cookie_key, '', 0, path_type)
    set_header(response, ID_REFRESH_TOKEN_HEADER_SET_KEY, "remove", False)
    set_header(response, ACCESS_CONTROL_EXPOSE_HEADERS,
               ID_REFRESH_TOKEN_HEADER_SET_KEY, True)
Expand source code
def get_access_token_from_cookie(request: BaseRequest):
    """Return the decoded access-token cookie from ``request``, or ``None``."""
    return get_cookie(request, key=ACCESS_TOKEN_COOKIE_KEY)
def get_anti_csrf_header(request: BaseRequest)
Expand source code
def get_anti_csrf_header(request: BaseRequest):
    """Read the anti-CSRF header from ``request`` via ``utils.get_header``."""
    header_name = ANTI_CSRF_HEADER_KEY
    return get_header(request, header_name)
Expand source code
def get_cookie(request: BaseRequest, key: str):
    """Return the percent-decoded value of cookie ``key`` from ``request``.

    Returns ``None`` when the request carries no such cookie. Decoding with
    ``unquote`` mirrors the ``quote`` applied by ``set_cookie``.
    """
    raw_value = request.get_cookie(key)
    return None if raw_value is None else unquote(raw_value)
def get_cors_allowed_headers()
Expand source code
def get_cors_allowed_headers():
    """Return a new list of the request header names used by this module.

    The list holds the anti-CSRF header and the rid header, in that order;
    per the function name it is meant for a CORS allowed-headers list.
    """
    return list((ANTI_CSRF_HEADER_KEY, RID_HEADER_KEY))
Expand source code
def get_id_refresh_token_from_cookie(request: BaseRequest):
    """Return the decoded id-refresh-token cookie from ``request``, or ``None``."""
    return get_cookie(request, key=ID_REFRESH_TOKEN_COOKIE_KEY)
Expand source code
def get_refresh_token_from_cookie(request: BaseRequest):
    """Return the decoded refresh-token cookie from ``request``, or ``None``."""
    return get_cookie(request, key=REFRESH_TOKEN_COOKIE_KEY)
def get_rid_header(request: BaseRequest)
Expand source code
def get_rid_header(request: BaseRequest):
    """Read the rid header from ``request`` via ``utils.get_header``."""
    header_name = RID_HEADER_KEY
    return get_header(request, header_name)
Expand source code
def set_cookie(recipe: SessionRecipe, response: BaseResponse, key: str, value: str,
               expires: int, path_type: Literal['refresh_token_path', 'access_token_path']):
    """Write an HttpOnly cookie using the recipe's cookie settings.

    Domain, ``secure`` and SameSite come from ``recipe.config``. The value is
    percent-encoded as UTF-8 (``get_cookie`` reverses this with ``unquote``).

    Args:
        recipe: Session recipe whose ``config`` supplies cookie attributes.
        response: Framework response the cookie is set on.
        key: Cookie name.
        value: Raw cookie value; encoded before being written.
        expires: Passed straight through to ``response.set_cookie``.
        path_type: ``'refresh_token_path'`` uses the configured refresh
            path; ``'access_token_path'`` uses ``'/'``. Any other value
            yields an empty path.
    """
    config = recipe.config
    domain, secure, same_site = (
        config.cookie_domain, config.cookie_secure, config.cookie_same_site)
    if path_type == 'refresh_token_path':
        cookie_path = config.refresh_token_path.get_as_string_dangerous()
    elif path_type == 'access_token_path':
        cookie_path = '/'
    else:
        cookie_path = ''
    response.set_cookie(
        key=key,
        value=quote(value, encoding='utf-8'),
        expires=expires,
        path=cookie_path,
        domain=domain,
        secure=secure,
        httponly=True,
        samesite=same_site,
    )
def set_front_token_in_headers(response: BaseResponse, user_id: str, expires_at: int, jwt_payload: Union[None, Dict[str, Any]] = None)
Expand source code
def set_front_token_in_headers(response: BaseResponse, user_id: str, expires_at: int, jwt_payload: Union[None, Dict[str, Any]] = None):
    """Attach the front-token header to ``response`` and list it as exposed.

    The header value is the compact, key-sorted JSON encoding of
    ``{"ate": expires_at, "uid": user_id, "up": jwt_payload}`` passed through
    ``utf_base64encode``. The header name is then appended (comma-joined) to
    the ``ACCESS_CONTROL_EXPOSE_HEADERS`` header.

    Args:
        response: Framework response whose headers are modified.
        user_id: Stored under the ``uid`` key.
        expires_at: Stored under the ``ate`` key (units not visible here;
            presumably a timestamp -- confirm with callers).
        jwt_payload: Stored under the ``up`` key; ``None`` becomes ``{}``.
    """
    payload = {} if jwt_payload is None else jwt_payload
    # Compact separators + sorted keys give a deterministic encoding.
    serialized = dumps(
        {'uid': user_id, 'ate': expires_at, 'up': payload},
        separators=(',', ':'),
        sort_keys=True,
    )
    set_header(response, FRONT_TOKEN_HEADER_SET_KEY,
               utf_base64encode(serialized), False)
    set_header(response, ACCESS_CONTROL_EXPOSE_HEADERS,
               FRONT_TOKEN_HEADER_SET_KEY, True)
def set_header(response: BaseResponse, key: str, value: str, allow_duplicate: bool)
Expand source code
def set_header(response: BaseResponse,
               key: str, value: str, allow_duplicate: bool):
    """Set header ``key`` on ``response``, optionally appending to it.

    When ``allow_duplicate`` is true and the header already has a value,
    ``value`` is comma-joined after the existing one; otherwise the header
    is simply overwritten with ``value``.

    Raises:
        Whatever ``raise_general_exception`` raises (not visible here) when
        reading or writing the header fails for any reason.
    """
    try:
        new_value = value
        if allow_duplicate:
            existing = response.get_header(key)
            if existing is not None:
                new_value = existing + ',' + value
        response.set_header(key, new_value)
    except Exception:
        raise_general_exception(
            'Error while setting header with key: ' + key +
            ' and value: ' + value)