Module supertokens_python.recipe.session.claim_base_classes.primitive_array_claim

Expand source code
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
#
# This software is licensed under the Apache License, Version 2.0 (the
# "License") as published by the Apache Software Foundation.
#
# You may not use this file except in compliance with the License. You may
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

from typing import Any, Callable, Dict, Optional, TypeVar, Union, Generic, List

from supertokens_python.types import MaybeAwaitable
from supertokens_python.utils import get_timestamp_ms

from ..interfaces import (
    JSONObject,
    JSONPrimitive,
    SessionClaim,
    SessionClaimValidator,
    ClaimValidationResult,
    JSONPrimitiveList,
)


Primitive = TypeVar("Primitive", bound=JSONPrimitive)
PrimitiveList = TypeVar("PrimitiveList", bound=JSONPrimitiveList)

_T = TypeVar("_T")


class SCVMixin(SessionClaimValidator, Generic[_T]):
    def __init__(
        self,
        id_: str,
        claim: SessionClaim[PrimitiveList],
        val: _T,
        max_age_in_sec: Optional[int] = None,
    ):
        super().__init__(id_)
        self.claim: SessionClaim[PrimitiveList] = claim  # TODO:PrimitiveArrayClaim
        self.val = val
        self.max_age_in_sec = max_age_in_sec

    def should_refetch(
        self,
        payload: JSONObject,
        user_context: Dict[str, Any],
    ) -> bool:
        claim = self.claim

        return (claim.get_value_from_payload(payload, user_context) is None) or (
            self.max_age_in_sec is not None
            and (
                payload[claim.key]["t"]
                < get_timestamp_ms() - self.max_age_in_sec * 1000
            )
        )

    async def _validate(
        self,
        payload: JSONObject,
        user_context: Dict[str, Any],
        is_include: bool,
    ):
        val = self.val
        max_age_in_sec = self.max_age_in_sec

        expected_key = "expectedToInclude" if is_include else "expectedToNotInclude"

        assert isinstance(self.claim, PrimitiveArrayClaim)
        claim_val = self.claim.get_value_from_payload(payload, user_context)

        if claim_val is None:
            return ClaimValidationResult(
                is_valid=False,
                reason={
                    "message": "value does not exist",
                    expected_key: val,
                    "actualValue": claim_val,
                },
            )

        last_refetch_time = self.claim.get_last_refetch_time(payload, user_context)
        assert last_refetch_time is not None
        age_in_sec = (get_timestamp_ms() - last_refetch_time) / 1000
        if max_age_in_sec is not None and age_in_sec > max_age_in_sec:
            return ClaimValidationResult(
                is_valid=False,
                reason={
                    "message": "expired",
                    "ageInSeconds": age_in_sec,
                    "maxAgeInSeconds": max_age_in_sec,
                },
            )

        # Doing this to ensure same code in the upcoming steps irrespective of
        # whether self.val is Primitive or PrimitiveList
        vals: List[JSONPrimitive] = (
            val if isinstance(val, list) else [val]
        )  # pyright: reportGeneralTypeIssues=false

        claim_val_set = set(claim_val)
        if is_include:
            for v in vals:
                if v not in claim_val_set:
                    return ClaimValidationResult(
                        is_valid=False,
                        reason={
                            "message": "wrong value",
                            expected_key: val,
                            # other SDKs return the item itself
                            "actualValue": claim_val,
                        },
                    )
        else:
            for v in vals:
                if v in claim_val_set:
                    return ClaimValidationResult(
                        is_valid=False,
                        reason={
                            "message": "wrong value",
                            expected_key: val,
                            # other SDKs return the item itself
                            "actualValue": claim_val,
                        },
                    )

        return ClaimValidationResult(is_valid=True)


class IncludesSCV(SCVMixin[Primitive]):
    async def validate(
        self,
        payload: JSONObject,
        user_context: Dict[str, Any],
    ):
        return await self._validate(payload, user_context, is_include=True)


class ExcludesSCV(SCVMixin[Primitive]):
    async def validate(
        self,
        payload: JSONObject,
        user_context: Dict[str, Any],
    ):
        return await self._validate(payload, user_context, is_include=False)


class IncludesAllSCV(SCVMixin[PrimitiveList]):
    async def validate(
        self,
        payload: JSONObject,
        user_context: Dict[str, Any],
    ):
        return await self._validate(payload, user_context, is_include=True)


class ExcludesAllSCV(SCVMixin[PrimitiveList]):
    async def validate(
        self,
        payload: JSONObject,
        user_context: Dict[str, Any],
    ):
        return await self._validate(payload, user_context, is_include=False)


class PrimitiveArrayClaimValidators(Generic[PrimitiveList]):
    def __init__(
        self,
        claim: SessionClaim[PrimitiveList],
        default_max_age_in_sec: Optional[int] = None,
    ) -> None:
        self.claim = claim
        self.default_max_age_in_sec = default_max_age_in_sec

    def includes(  # pyright: ignore[reportInvalidTypeVarUse]
        self,
        val: Primitive,  # pyright: ignore[reportInvalidTypeVarUse]
        max_age_in_seconds: Optional[int] = None,
        id_: Union[str, None] = None,
    ) -> SessionClaimValidator:
        max_age_in_sec = max_age_in_seconds or self.default_max_age_in_sec
        return IncludesSCV(
            (id_ or self.claim.key), self.claim, val=val, max_age_in_sec=max_age_in_sec
        )

    def excludes(  # pyright: ignore[reportInvalidTypeVarUse]
        self,
        val: Primitive,  # pyright: ignore[reportInvalidTypeVarUse]
        max_age_in_seconds: Optional[int] = None,
        id_: Union[str, None] = None,
    ) -> SessionClaimValidator:
        max_age_in_sec = max_age_in_seconds or self.default_max_age_in_sec
        return ExcludesSCV(
            (id_ or self.claim.key), self.claim, val=val, max_age_in_sec=max_age_in_sec
        )

    def includes_all(
        self,
        val: PrimitiveList,
        max_age_in_seconds: Optional[int] = None,
        id_: Union[str, None] = None,
    ) -> SessionClaimValidator:
        max_age_in_sec = max_age_in_seconds or self.default_max_age_in_sec
        return IncludesAllSCV(
            (id_ or self.claim.key), self.claim, val=val, max_age_in_sec=max_age_in_sec
        )

    def excludes_all(
        self,
        val: PrimitiveList,
        max_age_in_seconds: Optional[int] = None,
        id_: Union[str, None] = None,
    ) -> SessionClaimValidator:
        max_age_in_sec = max_age_in_seconds or self.default_max_age_in_sec
        return ExcludesAllSCV(
            (id_ or self.claim.key), self.claim, val=val, max_age_in_sec=max_age_in_sec
        )


class PrimitiveArrayClaim(SessionClaim[PrimitiveList], Generic[PrimitiveList]):
    def __init__(
        self,
        key: str,
        fetch_value: Callable[
            [str, Dict[str, Any]],
            MaybeAwaitable[Optional[PrimitiveList]],
        ],
        default_max_age_in_sec: Optional[int] = None,
    ) -> None:
        super().__init__(key, fetch_value)

        claim = self
        self.validators = PrimitiveArrayClaimValidators(claim, default_max_age_in_sec)

    def add_to_payload_(
        self,
        payload: Dict[str, Any],
        value: PrimitiveList,
        user_context: Union[Dict[str, Any], None] = None,
    ) -> JSONObject:
        payload[self.key] = {"v": value, "t": get_timestamp_ms()}
        _ = user_context

        return payload

    def remove_from_payload_by_merge_(
        self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None
    ) -> JSONObject:
        payload[self.key] = None
        return payload

    def remove_from_payload(
        self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None
    ) -> JSONObject:
        del payload[self.key]
        return payload

    def get_value_from_payload(
        self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None
    ) -> Union[PrimitiveList, None]:
        _ = user_context

        return payload.get(self.key, {}).get("v")

    def get_last_refetch_time(
        self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None
    ) -> Union[int, None]:
        _ = user_context

        return payload.get(self.key, {}).get("t")

Classes

class ExcludesAllSCV (id_: str, claim: SessionClaim[~PrimitiveList], val: ~_T, max_age_in_sec: Optional[None] = None)

Helper class that provides a standard way to create an ABC using inheritance.

Expand source code
class ExcludesAllSCV(SCVMixin[PrimitiveList]):
    async def validate(
        self,
        payload: JSONObject,
        user_context: Dict[str, Any],
    ):
        return await self._validate(payload, user_context, is_include=False)

Ancestors

Methods

async def validate(self, payload: Dict[str, Any], user_context: Dict[str, Any])
Expand source code
async def validate(
    self,
    payload: JSONObject,
    user_context: Dict[str, Any],
):
    return await self._validate(payload, user_context, is_include=False)
class ExcludesSCV (id_: str, claim: SessionClaim[~PrimitiveList], val: ~_T, max_age_in_sec: Optional[None] = None)

Helper class that provides a standard way to create an ABC using inheritance.

Expand source code
class ExcludesSCV(SCVMixin[Primitive]):
    async def validate(
        self,
        payload: JSONObject,
        user_context: Dict[str, Any],
    ):
        return await self._validate(payload, user_context, is_include=False)

Ancestors

Methods

async def validate(self, payload: Dict[str, Any], user_context: Dict[str, Any])
Expand source code
async def validate(
    self,
    payload: JSONObject,
    user_context: Dict[str, Any],
):
    return await self._validate(payload, user_context, is_include=False)
class IncludesAllSCV (id_: str, claim: SessionClaim[~PrimitiveList], val: ~_T, max_age_in_sec: Optional[None] = None)

Helper class that provides a standard way to create an ABC using inheritance.

Expand source code
class IncludesAllSCV(SCVMixin[PrimitiveList]):
    async def validate(
        self,
        payload: JSONObject,
        user_context: Dict[str, Any],
    ):
        return await self._validate(payload, user_context, is_include=True)

Ancestors

Methods

async def validate(self, payload: Dict[str, Any], user_context: Dict[str, Any])
Expand source code
async def validate(
    self,
    payload: JSONObject,
    user_context: Dict[str, Any],
):
    return await self._validate(payload, user_context, is_include=True)
class IncludesSCV (id_: str, claim: SessionClaim[~PrimitiveList], val: ~_T, max_age_in_sec: Optional[None] = None)

Helper class that provides a standard way to create an ABC using inheritance.

Expand source code
class IncludesSCV(SCVMixin[Primitive]):
    async def validate(
        self,
        payload: JSONObject,
        user_context: Dict[str, Any],
    ):
        return await self._validate(payload, user_context, is_include=True)

Ancestors

Methods

async def validate(self, payload: Dict[str, Any], user_context: Dict[str, Any])
Expand source code
async def validate(
    self,
    payload: JSONObject,
    user_context: Dict[str, Any],
):
    return await self._validate(payload, user_context, is_include=True)
class PrimitiveArrayClaim (key: str, fetch_value: Callable[[str, Dict[str, Any]], Union[Awaitable[Optional[~PrimitiveList]], ~PrimitiveList, ForwardRef(None)]], default_max_age_in_sec: Optional[None] = None)

Helper class that provides a standard way to create an ABC using inheritance.

Args

key
The key to use when storing the claim in the payload.
fetch_value
a method that fetches the current value of this claim for the user. A None return value signifies that we don't want to update the claim payload and or the claim value is not present in the database. For example, this can happen with a second factor auth claim, where we don't want to add the claim to the session automatically
Expand source code
class PrimitiveArrayClaim(SessionClaim[PrimitiveList], Generic[PrimitiveList]):
    def __init__(
        self,
        key: str,
        fetch_value: Callable[
            [str, Dict[str, Any]],
            MaybeAwaitable[Optional[PrimitiveList]],
        ],
        default_max_age_in_sec: Optional[int] = None,
    ) -> None:
        super().__init__(key, fetch_value)

        claim = self
        self.validators = PrimitiveArrayClaimValidators(claim, default_max_age_in_sec)

    def add_to_payload_(
        self,
        payload: Dict[str, Any],
        value: PrimitiveList,
        user_context: Union[Dict[str, Any], None] = None,
    ) -> JSONObject:
        payload[self.key] = {"v": value, "t": get_timestamp_ms()}
        _ = user_context

        return payload

    def remove_from_payload_by_merge_(
        self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None
    ) -> JSONObject:
        payload[self.key] = None
        return payload

    def remove_from_payload(
        self, payload: JSONObject, user_context: Optional[Dict[str, Any]] = None
    ) -> JSONObject:
        del payload[self.key]
        return payload

    def get_value_from_payload(
        self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None
    ) -> Union[PrimitiveList, None]:
        _ = user_context

        return payload.get(self.key, {}).get("v")

    def get_last_refetch_time(
        self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None
    ) -> Union[int, None]:
        _ = user_context

        return payload.get(self.key, {}).get("t")

Ancestors

Subclasses

Methods

def get_last_refetch_time(self, payload: Dict[str, Any], user_context: Optional[Dict[str, Any]] = None) ‑> Optional[None]
Expand source code
def get_last_refetch_time(
    self, payload: JSONObject, user_context: Union[Dict[str, Any], None] = None
) -> Union[int, None]:
    _ = user_context

    return payload.get(self.key, {}).get("t")

Inherited members

class PrimitiveArrayClaimValidators (claim: SessionClaim[~PrimitiveList], default_max_age_in_sec: Optional[None] = None)

Abstract base class for generic types.

A generic type is typically declared by inheriting from this class parameterized with one or more type variables. For example, a generic mapping type might be defined as::

class Mapping(Generic[KT, VT]): def getitem(self, key: KT) -> VT: … # Etc.

This class can then be used as follows::

def lookup_name(mapping: Mapping[KT, VT], key: KT, default: VT) -> VT: try: return mapping[key] except KeyError: return default

Expand source code
class PrimitiveArrayClaimValidators(Generic[PrimitiveList]):
    def __init__(
        self,
        claim: SessionClaim[PrimitiveList],
        default_max_age_in_sec: Optional[int] = None,
    ) -> None:
        self.claim = claim
        self.default_max_age_in_sec = default_max_age_in_sec

    def includes(  # pyright: ignore[reportInvalidTypeVarUse]
        self,
        val: Primitive,  # pyright: ignore[reportInvalidTypeVarUse]
        max_age_in_seconds: Optional[int] = None,
        id_: Union[str, None] = None,
    ) -> SessionClaimValidator:
        max_age_in_sec = max_age_in_seconds or self.default_max_age_in_sec
        return IncludesSCV(
            (id_ or self.claim.key), self.claim, val=val, max_age_in_sec=max_age_in_sec
        )

    def excludes(  # pyright: ignore[reportInvalidTypeVarUse]
        self,
        val: Primitive,  # pyright: ignore[reportInvalidTypeVarUse]
        max_age_in_seconds: Optional[int] = None,
        id_: Union[str, None] = None,
    ) -> SessionClaimValidator:
        max_age_in_sec = max_age_in_seconds or self.default_max_age_in_sec
        return ExcludesSCV(
            (id_ or self.claim.key), self.claim, val=val, max_age_in_sec=max_age_in_sec
        )

    def includes_all(
        self,
        val: PrimitiveList,
        max_age_in_seconds: Optional[int] = None,
        id_: Union[str, None] = None,
    ) -> SessionClaimValidator:
        max_age_in_sec = max_age_in_seconds or self.default_max_age_in_sec
        return IncludesAllSCV(
            (id_ or self.claim.key), self.claim, val=val, max_age_in_sec=max_age_in_sec
        )

    def excludes_all(
        self,
        val: PrimitiveList,
        max_age_in_seconds: Optional[int] = None,
        id_: Union[str, None] = None,
    ) -> SessionClaimValidator:
        max_age_in_sec = max_age_in_seconds or self.default_max_age_in_sec
        return ExcludesAllSCV(
            (id_ or self.claim.key), self.claim, val=val, max_age_in_sec=max_age_in_sec
        )

Ancestors

  • typing.Generic

Methods

def excludes(self, val: ~Primitive, max_age_in_seconds: Optional[None] = None, id_: Optional[str] = None) ‑> SessionClaimValidator
Expand source code
def excludes(  # pyright: ignore[reportInvalidTypeVarUse]
    self,
    val: Primitive,  # pyright: ignore[reportInvalidTypeVarUse]
    max_age_in_seconds: Optional[int] = None,
    id_: Union[str, None] = None,
) -> SessionClaimValidator:
    max_age_in_sec = max_age_in_seconds or self.default_max_age_in_sec
    return ExcludesSCV(
        (id_ or self.claim.key), self.claim, val=val, max_age_in_sec=max_age_in_sec
    )
def excludes_all(self, val: ~PrimitiveList, max_age_in_seconds: Optional[None] = None, id_: Optional[str] = None) ‑> SessionClaimValidator
Expand source code
def excludes_all(
    self,
    val: PrimitiveList,
    max_age_in_seconds: Optional[int] = None,
    id_: Union[str, None] = None,
) -> SessionClaimValidator:
    max_age_in_sec = max_age_in_seconds or self.default_max_age_in_sec
    return ExcludesAllSCV(
        (id_ or self.claim.key), self.claim, val=val, max_age_in_sec=max_age_in_sec
    )
def includes(self, val: ~Primitive, max_age_in_seconds: Optional[None] = None, id_: Optional[str] = None) ‑> SessionClaimValidator
Expand source code
def includes(  # pyright: ignore[reportInvalidTypeVarUse]
    self,
    val: Primitive,  # pyright: ignore[reportInvalidTypeVarUse]
    max_age_in_seconds: Optional[int] = None,
    id_: Union[str, None] = None,
) -> SessionClaimValidator:
    max_age_in_sec = max_age_in_seconds or self.default_max_age_in_sec
    return IncludesSCV(
        (id_ or self.claim.key), self.claim, val=val, max_age_in_sec=max_age_in_sec
    )
def includes_all(self, val: ~PrimitiveList, max_age_in_seconds: Optional[None] = None, id_: Optional[str] = None) ‑> SessionClaimValidator
Expand source code
def includes_all(
    self,
    val: PrimitiveList,
    max_age_in_seconds: Optional[int] = None,
    id_: Union[str, None] = None,
) -> SessionClaimValidator:
    max_age_in_sec = max_age_in_seconds or self.default_max_age_in_sec
    return IncludesAllSCV(
        (id_ or self.claim.key), self.claim, val=val, max_age_in_sec=max_age_in_sec
    )
class SCVMixin (id_: str, claim: SessionClaim[~PrimitiveList], val: ~_T, max_age_in_sec: Optional[None] = None)

Helper class that provides a standard way to create an ABC using inheritance.

Expand source code
class SCVMixin(SessionClaimValidator, Generic[_T]):
    def __init__(
        self,
        id_: str,
        claim: SessionClaim[PrimitiveList],
        val: _T,
        max_age_in_sec: Optional[int] = None,
    ):
        super().__init__(id_)
        self.claim: SessionClaim[PrimitiveList] = claim  # TODO:PrimitiveArrayClaim
        self.val = val
        self.max_age_in_sec = max_age_in_sec

    def should_refetch(
        self,
        payload: JSONObject,
        user_context: Dict[str, Any],
    ) -> bool:
        claim = self.claim

        return (claim.get_value_from_payload(payload, user_context) is None) or (
            self.max_age_in_sec is not None
            and (
                payload[claim.key]["t"]
                < get_timestamp_ms() - self.max_age_in_sec * 1000
            )
        )

    async def _validate(
        self,
        payload: JSONObject,
        user_context: Dict[str, Any],
        is_include: bool,
    ):
        val = self.val
        max_age_in_sec = self.max_age_in_sec

        expected_key = "expectedToInclude" if is_include else "expectedToNotInclude"

        assert isinstance(self.claim, PrimitiveArrayClaim)
        claim_val = self.claim.get_value_from_payload(payload, user_context)

        if claim_val is None:
            return ClaimValidationResult(
                is_valid=False,
                reason={
                    "message": "value does not exist",
                    expected_key: val,
                    "actualValue": claim_val,
                },
            )

        last_refetch_time = self.claim.get_last_refetch_time(payload, user_context)
        assert last_refetch_time is not None
        age_in_sec = (get_timestamp_ms() - last_refetch_time) / 1000
        if max_age_in_sec is not None and age_in_sec > max_age_in_sec:
            return ClaimValidationResult(
                is_valid=False,
                reason={
                    "message": "expired",
                    "ageInSeconds": age_in_sec,
                    "maxAgeInSeconds": max_age_in_sec,
                },
            )

        # Doing this to ensure same code in the upcoming steps irrespective of
        # whether self.val is Primitive or PrimitiveList
        vals: List[JSONPrimitive] = (
            val if isinstance(val, list) else [val]
        )  # pyright: reportGeneralTypeIssues=false

        claim_val_set = set(claim_val)
        if is_include:
            for v in vals:
                if v not in claim_val_set:
                    return ClaimValidationResult(
                        is_valid=False,
                        reason={
                            "message": "wrong value",
                            expected_key: val,
                            # other SDKs return the item itself
                            "actualValue": claim_val,
                        },
                    )
        else:
            for v in vals:
                if v in claim_val_set:
                    return ClaimValidationResult(
                        is_valid=False,
                        reason={
                            "message": "wrong value",
                            expected_key: val,
                            # other SDKs return the item itself
                            "actualValue": claim_val,
                        },
                    )

        return ClaimValidationResult(is_valid=True)

Ancestors

Subclasses

Methods

def should_refetch(self, payload: Dict[str, Any], user_context: Dict[str, Any]) ‑> bool
Expand source code
def should_refetch(
    self,
    payload: JSONObject,
    user_context: Dict[str, Any],
) -> bool:
    claim = self.claim

    return (claim.get_value_from_payload(payload, user_context) is None) or (
        self.max_age_in_sec is not None
        and (
            payload[claim.key]["t"]
            < get_timestamp_ms() - self.max_age_in_sec * 1000
        )
    )