Module supertokens_python.recipe.session.recipe_implementation
Expand source code
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
#
# This software is licensed under the Apache License, Version 2.0 (the
# "License") as published by the Apache Software Foundation.
#
# You may not use this file except in compliance with the License. You may
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from .session_class import Session
from supertokens_python.process_state import ProcessState, AllowedProcessStates
from supertokens_python.normalised_url_path import NormalisedURLPath
from typing import TYPE_CHECKING
from .interfaces import RecipeInterface
from .exceptions import raise_unauthorised_exception, raise_try_refresh_token_exception
from .cookie_and_header import get_id_refresh_token_from_cookie, get_access_token_from_cookie, get_anti_csrf_header, \
get_rid_header, get_refresh_token_from_cookie
from . import session_functions
from supertokens_python.utils import execute_in_background, FRAMEWORKS, frontend_has_interceptor, \
normalise_http_method, get_timestamp_ms
if TYPE_CHECKING:
from typing import Union, List
from .utils import SessionConfig
from supertokens_python.querier import Querier
class HandshakeInfo:
def __init__(self, info):
self.access_token_blacklisting_enabled = info['accessTokenBlacklistingEnabled']
self.raw_jwt_signing_public_key_list = None
self.anti_csrf = info['antiCsrf']
self.access_token_validity = info['accessTokenValidity']
self.refresh_token_validity = info['refreshTokenValidity']
def set_jwt_signing_public_key_list(self, updated_list: List):
self.raw_jwt_signing_public_key_list = updated_list
def get_jwt_signing_public_key_list(self) -> List:
time_now = get_timestamp_ms()
return [key for key in self.raw_jwt_signing_public_key_list if key['expiryTime'] > time_now]
class RecipeImplementation(RecipeInterface):
def __init__(self, querier: Querier, config: SessionConfig):
super().__init__()
self.querier = querier
self.config = config
self.handshake_info: Union[HandshakeInfo, None] = None
async def call_get_handshake_info():
try:
await self.get_handshake_info()
except Exception:
pass
try:
execute_in_background(config.mode, call_get_handshake_info)
except Exception:
pass
async def get_handshake_info(self, force_refetch=False) -> HandshakeInfo:
if self.handshake_info is None or len(
self.handshake_info.get_jwt_signing_public_key_list()) == 0 or force_refetch:
ProcessState.get_instance().add_state(
AllowedProcessStates.CALLING_SERVICE_IN_GET_HANDSHAKE_INFO)
response = await self.querier.send_post_request(NormalisedURLPath('/recipe/handshake'), {})
self.handshake_info = HandshakeInfo({
**response,
'antiCsrf': self.config.anti_csrf
})
self.update_jwt_signing_public_key_info(response['jwtSigningPublicKeyList'],
response['jwtSigningPublicKey'],
response['jwtSigningPublicKeyExpiryTime'])
return self.handshake_info
def update_jwt_signing_public_key_info(self, key_list: Union[List, None], public_key: str, expiry_time: int):
if key_list is None:
key_list = [{
'publicKey': public_key,
'expiryTime': expiry_time,
'createdAt': get_timestamp_ms()
}]
if self.handshake_info is not None:
self.handshake_info.set_jwt_signing_public_key_list(key_list)
async def create_new_session(self, request: any, user_id: str, access_token_payload: Union[dict, None] = None,
session_data: Union[dict, None] = None) -> Session:
if not hasattr(request, 'wrapper_used') or not request.wrapper_used:
request = FRAMEWORKS[self.config.framework].wrap_request(request)
session = await session_functions.create_new_session(self, user_id, access_token_payload, session_data)
access_token = session['accessToken']
refresh_token = session['refreshToken']
id_refresh_token = session['idRefreshToken']
new_session = Session(self, access_token['token'], session['session']['handle'],
session['session']['userId'], session['session']['userDataInJWT'])
new_session.new_access_token_info = access_token
new_session.new_refresh_token_info = refresh_token
new_session.new_id_refresh_token_info = id_refresh_token
if 'antiCsrfToken' in session and session['antiCsrfToken'] is not None:
new_session.new_anti_csrf_token = session['antiCsrfToken']
request.set_session(new_session)
return request.get_session()
async def get_session(self, request: any, anti_csrf_check: Union[bool, None] = None,
session_required: bool = True) -> Union[Session, None]:
if not hasattr(request, 'wrapper_used') or not request.wrapper_used:
request = FRAMEWORKS[self.config.framework].wrap_request(request)
id_refresh_token = get_id_refresh_token_from_cookie(request)
if id_refresh_token is None:
if not session_required:
return None
raise_unauthorised_exception('Session does not exist. Are you sending the session tokens in the '
'request as cookies?', False)
access_token = get_access_token_from_cookie(request)
if access_token is None:
if session_required is True or frontend_has_interceptor(request) or normalise_http_method(
request.method()) == 'get':
raise_try_refresh_token_exception(
'Access token has expired. Please call the refresh API')
return None
anti_csrf_token = get_anti_csrf_header(request)
if anti_csrf_check is None:
anti_csrf_check = normalise_http_method(request.method()) != 'get'
new_session = await session_functions.get_session(self, access_token, anti_csrf_token, anti_csrf_check,
get_rid_header(request) is not None)
if 'accessToken' in new_session:
access_token = new_session['accessToken']['token']
session = Session(self, access_token, new_session['session']['handle'],
new_session['session']['userId'], new_session['session']['userDataInJWT'])
if 'accessToken' in new_session:
session.new_access_token_info = new_session['accessToken']
request.set_session(session)
return request.get_session()
async def refresh_session(self, request: any) -> Session:
if not hasattr(request, 'wrapper_used') or not request.wrapper_used:
request = FRAMEWORKS[self.config.framework].wrap_request(request)
id_refresh_token = get_id_refresh_token_from_cookie(request)
if id_refresh_token is None:
raise_unauthorised_exception('Session does not exist. Are you sending the session tokens in the request '
'as cookies?', False)
refresh_token = get_refresh_token_from_cookie(request)
if refresh_token is None:
raise_unauthorised_exception('Refresh token not found. Are you sending the refresh token in the '
'request as a cookie?')
anti_csrf_token = get_anti_csrf_header(request)
new_session = await session_functions.refresh_session(self, refresh_token, anti_csrf_token,
get_rid_header(request) is not None)
access_token = new_session['accessToken']
refresh_token = new_session['refreshToken']
id_refresh_token = new_session['idRefreshToken']
session = Session(self, access_token['token'], new_session['session']['handle'],
new_session['session']['userId'], new_session['session']['userDataInJWT'])
session.new_access_token_info = access_token
session.new_refresh_token_info = refresh_token
session.new_id_refresh_token_info = id_refresh_token
if 'antiCsrfToken' in new_session and new_session['antiCsrfToken'] is not None:
session.new_anti_csrf_token = new_session['antiCsrfToken']
request.set_session(session)
return request.get_session()
async def revoke_session(self, session_handle: str) -> bool:
return await session_functions.revoke_session(self, session_handle)
async def revoke_all_sessions_for_user(self, user_id: str) -> List[str]:
return await session_functions.revoke_all_sessions_for_user(self, user_id)
async def get_all_session_handles_for_user(self, user_id: str) -> List[str]:
return await session_functions.get_all_session_handles_for_user(self, user_id)
async def revoke_multiple_sessions(self, session_handles: List[str]) -> List[str]:
return await session_functions.revoke_multiple_sessions(self, session_handles)
async def get_session_information(self, session_handle: str) -> dict:
return await session_functions.get_session_information(self, session_handle)
async def update_session_data(self, session_handle: str, new_session_data: dict) -> None:
await session_functions.update_session_data(self, session_handle, new_session_data)
async def update_access_token_payload(self, session_handle: str, new_access_token_payload: dict) -> None:
await session_functions.update_access_token_payload(self, session_handle, new_access_token_payload)
async def get_access_token_lifetime_ms(self) -> int:
return (await self.get_handshake_info()).access_token_validity
async def get_refresh_token_lifetime_ms(self) -> int:
return (await self.get_handshake_info()).refresh_token_validity
Classes
class HandshakeInfo (info)
-
Expand source code
class HandshakeInfo: def __init__(self, info): self.access_token_blacklisting_enabled = info['accessTokenBlacklistingEnabled'] self.raw_jwt_signing_public_key_list = None self.anti_csrf = info['antiCsrf'] self.access_token_validity = info['accessTokenValidity'] self.refresh_token_validity = info['refreshTokenValidity'] def set_jwt_signing_public_key_list(self, updated_list: List): self.raw_jwt_signing_public_key_list = updated_list def get_jwt_signing_public_key_list(self) -> List: time_now = get_timestamp_ms() return [key for key in self.raw_jwt_signing_public_key_list if key['expiryTime'] > time_now]
Methods
def get_jwt_signing_public_key_list(self) ‑> List
-
Expand source code
def get_jwt_signing_public_key_list(self) -> List: time_now = get_timestamp_ms() return [key for key in self.raw_jwt_signing_public_key_list if key['expiryTime'] > time_now]
def set_jwt_signing_public_key_list(self, updated_list: List)
-
Expand source code
def set_jwt_signing_public_key_list(self, updated_list: List): self.raw_jwt_signing_public_key_list = updated_list
class RecipeImplementation (querier: Querier, config: SessionConfig)
-
Helper class that provides a standard way to create an ABC using inheritance.
Expand source code
class RecipeImplementation(RecipeInterface): def __init__(self, querier: Querier, config: SessionConfig): super().__init__() self.querier = querier self.config = config self.handshake_info: Union[HandshakeInfo, None] = None async def call_get_handshake_info(): try: await self.get_handshake_info() except Exception: pass try: execute_in_background(config.mode, call_get_handshake_info) except Exception: pass async def get_handshake_info(self, force_refetch=False) -> HandshakeInfo: if self.handshake_info is None or len( self.handshake_info.get_jwt_signing_public_key_list()) == 0 or force_refetch: ProcessState.get_instance().add_state( AllowedProcessStates.CALLING_SERVICE_IN_GET_HANDSHAKE_INFO) response = await self.querier.send_post_request(NormalisedURLPath('/recipe/handshake'), {}) self.handshake_info = HandshakeInfo({ **response, 'antiCsrf': self.config.anti_csrf }) self.update_jwt_signing_public_key_info(response['jwtSigningPublicKeyList'], response['jwtSigningPublicKey'], response['jwtSigningPublicKeyExpiryTime']) return self.handshake_info def update_jwt_signing_public_key_info(self, key_list: Union[List, None], public_key: str, expiry_time: int): if key_list is None: key_list = [{ 'publicKey': public_key, 'expiryTime': expiry_time, 'createdAt': get_timestamp_ms() }] if self.handshake_info is not None: self.handshake_info.set_jwt_signing_public_key_list(key_list) async def create_new_session(self, request: any, user_id: str, access_token_payload: Union[dict, None] = None, session_data: Union[dict, None] = None) -> Session: if not hasattr(request, 'wrapper_used') or not request.wrapper_used: request = FRAMEWORKS[self.config.framework].wrap_request(request) session = await session_functions.create_new_session(self, user_id, access_token_payload, session_data) access_token = session['accessToken'] refresh_token = session['refreshToken'] id_refresh_token = session['idRefreshToken'] new_session = Session(self, access_token['token'], session['session']['handle'], session['session']['userId'], session['session']['userDataInJWT']) new_session.new_access_token_info = access_token new_session.new_refresh_token_info = refresh_token new_session.new_id_refresh_token_info = id_refresh_token if 'antiCsrfToken' in session and session['antiCsrfToken'] is not None: new_session.new_anti_csrf_token = session['antiCsrfToken'] request.set_session(new_session) return request.get_session() async def get_session(self, request: any, anti_csrf_check: Union[bool, None] = None, session_required: bool = True) -> Union[Session, None]: if not hasattr(request, 'wrapper_used') or not request.wrapper_used: request = FRAMEWORKS[self.config.framework].wrap_request(request) id_refresh_token = get_id_refresh_token_from_cookie(request) if id_refresh_token is None: if not session_required: return None raise_unauthorised_exception('Session does not exist. Are you sending the session tokens in the ' 'request as cookies?', False) access_token = get_access_token_from_cookie(request) if access_token is None: if session_required is True or frontend_has_interceptor(request) or normalise_http_method( request.method()) == 'get': raise_try_refresh_token_exception( 'Access token has expired. Please call the refresh API') return None anti_csrf_token = get_anti_csrf_header(request) if anti_csrf_check is None: anti_csrf_check = normalise_http_method(request.method()) != 'get' new_session = await session_functions.get_session(self, access_token, anti_csrf_token, anti_csrf_check, get_rid_header(request) is not None) if 'accessToken' in new_session: access_token = new_session['accessToken']['token'] session = Session(self, access_token, new_session['session']['handle'], new_session['session']['userId'], new_session['session']['userDataInJWT']) if 'accessToken' in new_session: session.new_access_token_info = new_session['accessToken'] request.set_session(session) return request.get_session() async def refresh_session(self, request: any) -> Session: if not hasattr(request, 'wrapper_used') or not request.wrapper_used: request = FRAMEWORKS[self.config.framework].wrap_request(request) id_refresh_token = get_id_refresh_token_from_cookie(request) if id_refresh_token is None: raise_unauthorised_exception('Session does not exist. Are you sending the session tokens in the request ' 'as cookies?', False) refresh_token = get_refresh_token_from_cookie(request) if refresh_token is None: raise_unauthorised_exception('Refresh token not found. Are you sending the refresh token in the ' 'request as a cookie?') anti_csrf_token = get_anti_csrf_header(request) new_session = await session_functions.refresh_session(self, refresh_token, anti_csrf_token, get_rid_header(request) is not None) access_token = new_session['accessToken'] refresh_token = new_session['refreshToken'] id_refresh_token = new_session['idRefreshToken'] session = Session(self, access_token['token'], new_session['session']['handle'], new_session['session']['userId'], new_session['session']['userDataInJWT']) session.new_access_token_info = access_token session.new_refresh_token_info = refresh_token session.new_id_refresh_token_info = id_refresh_token if 'antiCsrfToken' in new_session and new_session['antiCsrfToken'] is not None: session.new_anti_csrf_token = new_session['antiCsrfToken'] request.set_session(session) return request.get_session() async def revoke_session(self, session_handle: str) -> bool: return await session_functions.revoke_session(self, session_handle) async def revoke_all_sessions_for_user(self, user_id: str) -> List[str]: return await session_functions.revoke_all_sessions_for_user(self, user_id) async def get_all_session_handles_for_user(self, user_id: str) -> List[str]: return await session_functions.get_all_session_handles_for_user(self, user_id) async def revoke_multiple_sessions(self, session_handles: List[str]) -> List[str]: return await session_functions.revoke_multiple_sessions(self, session_handles) async def get_session_information(self, session_handle: str) -> dict: return await session_functions.get_session_information(self, session_handle) async def update_session_data(self, session_handle: str, new_session_data: dict) -> None: await session_functions.update_session_data(self, session_handle, new_session_data) async def update_access_token_payload(self, session_handle: str, new_access_token_payload: dict) -> None: await session_functions.update_access_token_payload(self, session_handle, new_access_token_payload) async def get_access_token_lifetime_ms(self) -> int: return (await self.get_handshake_info()).access_token_validity async def get_refresh_token_lifetime_ms(self) -> int: return (await self.get_handshake_info()).refresh_token_validity
Ancestors
- RecipeInterface
- abc.ABC
Methods
async def create_new_session(self, request: any, user_id: str, access_token_payload: Union[dict, None] = None, session_data: Union[dict, None] = None) ‑> Session
-
Expand source code
async def create_new_session(self, request: any, user_id: str, access_token_payload: Union[dict, None] = None, session_data: Union[dict, None] = None) -> Session: if not hasattr(request, 'wrapper_used') or not request.wrapper_used: request = FRAMEWORKS[self.config.framework].wrap_request(request) session = await session_functions.create_new_session(self, user_id, access_token_payload, session_data) access_token = session['accessToken'] refresh_token = session['refreshToken'] id_refresh_token = session['idRefreshToken'] new_session = Session(self, access_token['token'], session['session']['handle'], session['session']['userId'], session['session']['userDataInJWT']) new_session.new_access_token_info = access_token new_session.new_refresh_token_info = refresh_token new_session.new_id_refresh_token_info = id_refresh_token if 'antiCsrfToken' in session and session['antiCsrfToken'] is not None: new_session.new_anti_csrf_token = session['antiCsrfToken'] request.set_session(new_session) return request.get_session()
async def get_access_token_lifetime_ms(self) ‑> int
-
Expand source code
async def get_access_token_lifetime_ms(self) -> int: return (await self.get_handshake_info()).access_token_validity
async def get_all_session_handles_for_user(self, user_id: str) ‑> List[str]
-
Expand source code
async def get_all_session_handles_for_user(self, user_id: str) -> List[str]: return await session_functions.get_all_session_handles_for_user(self, user_id)
async def get_handshake_info(self, force_refetch=False) ‑> HandshakeInfo
-
Expand source code
async def get_handshake_info(self, force_refetch=False) -> HandshakeInfo: if self.handshake_info is None or len( self.handshake_info.get_jwt_signing_public_key_list()) == 0 or force_refetch: ProcessState.get_instance().add_state( AllowedProcessStates.CALLING_SERVICE_IN_GET_HANDSHAKE_INFO) response = await self.querier.send_post_request(NormalisedURLPath('/recipe/handshake'), {}) self.handshake_info = HandshakeInfo({ **response, 'antiCsrf': self.config.anti_csrf }) self.update_jwt_signing_public_key_info(response['jwtSigningPublicKeyList'], response['jwtSigningPublicKey'], response['jwtSigningPublicKeyExpiryTime']) return self.handshake_info
async def get_refresh_token_lifetime_ms(self) ‑> int
-
Expand source code
async def get_refresh_token_lifetime_ms(self) -> int: return (await self.get_handshake_info()).refresh_token_validity
async def get_session(self, request: any, anti_csrf_check: Union[bool, None] = None, session_required: bool = True) ‑> Union[Session, None]
-
Expand source code
async def get_session(self, request: any, anti_csrf_check: Union[bool, None] = None, session_required: bool = True) -> Union[Session, None]: if not hasattr(request, 'wrapper_used') or not request.wrapper_used: request = FRAMEWORKS[self.config.framework].wrap_request(request) id_refresh_token = get_id_refresh_token_from_cookie(request) if id_refresh_token is None: if not session_required: return None raise_unauthorised_exception('Session does not exist. Are you sending the session tokens in the ' 'request as cookies?', False) access_token = get_access_token_from_cookie(request) if access_token is None: if session_required is True or frontend_has_interceptor(request) or normalise_http_method( request.method()) == 'get': raise_try_refresh_token_exception( 'Access token has expired. Please call the refresh API') return None anti_csrf_token = get_anti_csrf_header(request) if anti_csrf_check is None: anti_csrf_check = normalise_http_method(request.method()) != 'get' new_session = await session_functions.get_session(self, access_token, anti_csrf_token, anti_csrf_check, get_rid_header(request) is not None) if 'accessToken' in new_session: access_token = new_session['accessToken']['token'] session = Session(self, access_token, new_session['session']['handle'], new_session['session']['userId'], new_session['session']['userDataInJWT']) if 'accessToken' in new_session: session.new_access_token_info = new_session['accessToken'] request.set_session(session) return request.get_session()
async def get_session_information(self, session_handle: str) ‑> dict
-
Expand source code
async def get_session_information(self, session_handle: str) -> dict: return await session_functions.get_session_information(self, session_handle)
async def refresh_session(self, request: any) ‑> Session
-
Expand source code
async def refresh_session(self, request: any) -> Session: if not hasattr(request, 'wrapper_used') or not request.wrapper_used: request = FRAMEWORKS[self.config.framework].wrap_request(request) id_refresh_token = get_id_refresh_token_from_cookie(request) if id_refresh_token is None: raise_unauthorised_exception('Session does not exist. Are you sending the session tokens in the request ' 'as cookies?', False) refresh_token = get_refresh_token_from_cookie(request) if refresh_token is None: raise_unauthorised_exception('Refresh token not found. Are you sending the refresh token in the ' 'request as a cookie?') anti_csrf_token = get_anti_csrf_header(request) new_session = await session_functions.refresh_session(self, refresh_token, anti_csrf_token, get_rid_header(request) is not None) access_token = new_session['accessToken'] refresh_token = new_session['refreshToken'] id_refresh_token = new_session['idRefreshToken'] session = Session(self, access_token['token'], new_session['session']['handle'], new_session['session']['userId'], new_session['session']['userDataInJWT']) session.new_access_token_info = access_token session.new_refresh_token_info = refresh_token session.new_id_refresh_token_info = id_refresh_token if 'antiCsrfToken' in new_session and new_session['antiCsrfToken'] is not None: session.new_anti_csrf_token = new_session['antiCsrfToken'] request.set_session(session) return request.get_session()
async def revoke_all_sessions_for_user(self, user_id: str) ‑> List[str]
-
Expand source code
async def revoke_all_sessions_for_user(self, user_id: str) -> List[str]: return await session_functions.revoke_all_sessions_for_user(self, user_id)
async def revoke_multiple_sessions(self, session_handles: List[str]) ‑> List[str]
-
Expand source code
async def revoke_multiple_sessions(self, session_handles: List[str]) -> List[str]: return await session_functions.revoke_multiple_sessions(self, session_handles)
async def revoke_session(self, session_handle: str) ‑> bool
-
Expand source code
async def revoke_session(self, session_handle: str) -> bool: return await session_functions.revoke_session(self, session_handle)
async def update_access_token_payload(self, session_handle: str, new_access_token_payload: dict) ‑> None
-
Expand source code
async def update_access_token_payload(self, session_handle: str, new_access_token_payload: dict) -> None: await session_functions.update_access_token_payload(self, session_handle, new_access_token_payload)
def update_jwt_signing_public_key_info(self, key_list: Union[List, None], public_key: str, expiry_time: int)
-
Expand source code
def update_jwt_signing_public_key_info(self, key_list: Union[List, None], public_key: str, expiry_time: int): if key_list is None: key_list = [{ 'publicKey': public_key, 'expiryTime': expiry_time, 'createdAt': get_timestamp_ms() }] if self.handshake_info is not None: self.handshake_info.set_jwt_signing_public_key_list(key_list)
async def update_session_data(self, session_handle: str, new_session_data: dict) ‑> None
-
Expand source code
async def update_session_data(self, session_handle: str, new_session_data: dict) -> None: await session_functions.update_session_data(self, session_handle, new_session_data)