Module supertokens_python.recipe.thirdparty.providers.config_utils

Expand source code
import copy
from typing import List, Dict, Optional, Any

from supertokens_python.normalised_url_domain import NormalisedURLDomain
from supertokens_python.normalised_url_path import NormalisedURLPath
from .active_directory import ActiveDirectory
from .apple import Apple
from .bitbucket import Bitbucket
from .boxy_saml import BoxySAML
from .discord import Discord
from .facebook import Facebook
from .github import Github
from .gitlab import Gitlab
from .google_workspaces import GoogleWorkspaces
from .google import Google
from .linkedin import Linkedin
from .twitter import Twitter
from .okta import Okta
from .custom import NewProvider
from .utils import do_get_request

from ..provider import (
    ProviderConfig,
    ProviderConfigForClient,
    ProviderInput,
    Provider,
    UserFields,
    UserInfoMap,
)


def _merge_user_fields(
    static_fields: Optional[UserFields], core_fields: Optional[UserFields]
) -> UserFields:
    """Return a new ``UserFields`` combining a static and a core mapping.

    The result starts as a shallow copy of ``static_fields`` (or an empty
    ``UserFields()`` when it is None). Then every ``user_id`` / ``email`` /
    ``email_verified`` value from ``core_fields`` that is not None overwrites
    the corresponding field. Neither argument is modified.
    """
    merged = UserFields() if static_fields is None else copy.copy(static_fields)
    if core_fields is not None:
        if core_fields.user_id is not None:
            merged.user_id = core_fields.user_id
        if core_fields.email is not None:
            merged.email = core_fields.email
        if core_fields.email_verified is not None:
            merged.email_verified = core_fields.email_verified
    return merged


def merge_config(
    config_from_static: ProviderConfig, config_from_core: ProviderConfig
) -> ProviderConfig:
    """Combine a statically configured provider config with its core config.

    Merge rules:

    - ``third_party_id``, ``require_email``, ``generate_fake_email``,
      ``validate_id_token_payload`` and ``validate_access_token`` always come
      from ``config_from_static``.
    - ``name`` and the endpoint / query-param / body-param / header / URI
      fields use the core value when it is not None, else the static value.
    - ``user_info_map`` is merged field by field for both
      ``from_id_token_payload`` and ``from_user_info_api``. Non-None core
      values win over static ones.
    - ``clients`` starts from a copy of the static clients. Each core client
      replaces the first existing client with the same ``client_type``, or is
      appended if there is none.

    Neither input config is mutated; a new ``ProviderConfig`` is returned.
    """
    # Build fresh user-info objects instead of editing the static config's in
    # place. Static ProviderInputs are reused for every merge (and returned
    # as-is by merge_providers_from_core_and_static when the core has no
    # configs), so in-place edits would leak one core config's overrides into
    # later results. The core config is likewise left untouched.
    static_map = config_from_static.user_info_map
    core_map = config_from_core.user_info_map
    user_info_map = (
        UserInfoMap(UserFields(), UserFields())
        if static_map is None
        else copy.copy(static_map)
    )
    user_info_map.from_id_token_payload = _merge_user_fields(
        user_info_map.from_id_token_payload,
        None if core_map is None else core_map.from_id_token_payload,
    )
    user_info_map.from_user_info_api = _merge_user_fields(
        user_info_map.from_user_info_api,
        None if core_map is None else core_map.from_user_info_api,
    )

    result = ProviderConfig(
        third_party_id=config_from_static.third_party_id,
        name=(
            config_from_static.name
            if config_from_core.name is None
            else config_from_core.name
        ),
        authorization_endpoint=(
            config_from_static.authorization_endpoint
            if config_from_core.authorization_endpoint is None
            else config_from_core.authorization_endpoint
        ),
        authorization_endpoint_query_params=(
            config_from_static.authorization_endpoint_query_params
            if config_from_core.authorization_endpoint_query_params is None
            else config_from_core.authorization_endpoint_query_params
        ),
        token_endpoint=(
            config_from_static.token_endpoint
            if config_from_core.token_endpoint is None
            else config_from_core.token_endpoint
        ),
        token_endpoint_body_params=(
            config_from_static.token_endpoint_body_params
            if config_from_core.token_endpoint_body_params is None
            else config_from_core.token_endpoint_body_params
        ),
        user_info_endpoint=(
            config_from_static.user_info_endpoint
            if config_from_core.user_info_endpoint is None
            else config_from_core.user_info_endpoint
        ),
        user_info_endpoint_headers=(
            config_from_static.user_info_endpoint_headers
            if config_from_core.user_info_endpoint_headers is None
            else config_from_core.user_info_endpoint_headers
        ),
        user_info_endpoint_query_params=(
            config_from_static.user_info_endpoint_query_params
            if config_from_core.user_info_endpoint_query_params is None
            else config_from_core.user_info_endpoint_query_params
        ),
        jwks_uri=(
            config_from_static.jwks_uri
            if config_from_core.jwks_uri is None
            else config_from_core.jwks_uri
        ),
        oidc_discovery_endpoint=(
            config_from_static.oidc_discovery_endpoint
            if config_from_core.oidc_discovery_endpoint is None
            else config_from_core.oidc_discovery_endpoint
        ),
        require_email=config_from_static.require_email,
        user_info_map=user_info_map,
        generate_fake_email=config_from_static.generate_fake_email,
        validate_id_token_payload=config_from_static.validate_id_token_payload,
        validate_access_token=config_from_static.validate_access_token,
    )

    # Copy the static list so appending/replacing never touches the static
    # config. A core client replaces the first entry with the same
    # client_type; otherwise (for-else: no break) it is appended.
    merged_clients = list(config_from_static.clients or [])
    for core_client in config_from_core.clients or []:
        for idx, existing_client in enumerate(merged_clients):
            if existing_client.client_type == core_client.client_type:
                merged_clients[idx] = core_client
                break
        else:
            merged_clients.append(core_client)

    result.clients = merged_clients

    return result


def merge_providers_from_core_and_static(
    provider_configs_from_core: List[ProviderConfig],
    provider_inputs_from_static: List[ProviderInput],
    include_all_providers: bool,
) -> List[ProviderInput]:
    """Build the effective provider list from core configs and static inputs.

    When the core supplies no provider configs, return the static inputs that
    either have ``include_in_non_public_tenants_by_default`` set or are
    allowed because ``include_all_providers`` is true. These are returned as
    the same objects, not copies.

    Otherwise, return exactly one ``ProviderInput`` per core config, in core
    order. If a static input shares the core config's ``third_party_id`` (the
    first such input is used), its config is merged via ``merge_config`` and
    its ``override`` is carried over. Static inputs without a core
    counterpart are not included.
    """
    if len(provider_configs_from_core) == 0:
        return [
            static_input
            for static_input in provider_inputs_from_static
            if include_all_providers
            or static_input.include_in_non_public_tenants_by_default
        ]

    effective_inputs: List[ProviderInput] = []
    for core_config in provider_configs_from_core:
        candidate = ProviderInput(core_config)

        matching_static = next(
            (
                static_input
                for static_input in provider_inputs_from_static
                if static_input.config.third_party_id == core_config.third_party_id
            ),
            None,
        )
        if matching_static is not None:
            candidate.config = merge_config(matching_static.config, core_config)
            candidate.override = matching_static.override

        effective_inputs.append(candidate)

    return effective_inputs


def create_provider(provider_input: ProviderInput) -> Provider:
    """Instantiate the built-in provider class matching the third-party id.

    The id is matched by prefix against a fixed, ordered list. The first
    prefix that matches wins. Ids matching no known prefix get a generic
    ``NewProvider``.
    """
    # Order matters: "google-workspaces" must be checked before "google",
    # since any id starting with the former also starts with the latter.
    prefix_table = (
        ("active-directory", ActiveDirectory),
        ("apple", Apple),
        ("bitbucket", Bitbucket),
        ("discord", Discord),
        ("facebook", Facebook),
        ("github", Github),
        ("gitlab", Gitlab),
        ("google-workspaces", GoogleWorkspaces),
        ("google", Google),
        ("okta", Okta),
        ("linkedin", Linkedin),
        ("twitter", Twitter),
        ("boxy-saml", BoxySAML),
    )

    tp_id = provider_input.config.third_party_id
    for prefix, provider_cls in prefix_table:
        if tp_id.startswith(prefix):
            return provider_cls(provider_input)

    return NewProvider(provider_input)


# Module-level cache of fetched OIDC discovery documents, keyed by the exact
# string passed to get_oidc_discovery_info. That function only adds entries;
# it never removes them.
OIDC_INFO_MAP: Dict[str, Any] = {}


async def get_oidc_discovery_info(issuer: str):
    """Return the discovery document for ``issuer``, fetching it at most once.

    On a cache miss, the URL is rebuilt from the normalised domain and
    normalised path of ``issuer``. It is fetched with ``do_get_request``, and
    the response is stored in ``OIDC_INFO_MAP`` before being returned. If the
    fetch raises, nothing is cached and the exception propagates.
    """
    if issuer not in OIDC_INFO_MAP:
        domain_part = NormalisedURLDomain(issuer)
        path_part = NormalisedURLPath(issuer)
        discovery_url = (
            domain_part.get_as_string_dangerous()
            + path_part.get_as_string_dangerous()
        )
        OIDC_INFO_MAP[issuer] = await do_get_request(discovery_url)

    return OIDC_INFO_MAP[issuer]


async def discover_oidc_endpoints(
    config: ProviderConfigForClient,
) -> ProviderConfigForClient:
    """Fill still-unset endpoint fields of ``config`` from OIDC discovery.

    If ``config.oidc_discovery_endpoint`` is None, ``config`` is returned
    untouched. Otherwise the discovery document is fetched (and cached) via
    ``get_oidc_discovery_info``. Each of ``authorization_endpoint``,
    ``token_endpoint``, ``user_info_endpoint`` and ``jwks_uri`` that is
    still None on ``config`` is then set from the document, when the document
    has a non-None value for it. Values already set on ``config`` are kept.

    ``config`` is mutated in place and the same object is returned.
    """
    discovery_endpoint = config.oidc_discovery_endpoint
    if discovery_endpoint is None:
        return config

    discovered = await get_oidc_discovery_info(discovery_endpoint)

    # (key in the discovery document, attribute on the config object).
    # Note the document uses "userinfo_endpoint" for user_info_endpoint.
    field_pairs = (
        ("authorization_endpoint", "authorization_endpoint"),
        ("token_endpoint", "token_endpoint"),
        ("userinfo_endpoint", "user_info_endpoint"),
        ("jwks_uri", "jwks_uri"),
    )
    for doc_key, attr_name in field_pairs:
        if discovered.get(doc_key) is not None and getattr(config, attr_name) is None:
            setattr(config, attr_name, discovered[doc_key])

    return config


async def fetch_and_set_config(
    provider_instance: Provider,
    client_type: Optional[str],
    user_context: Dict[str, Any],
):
    """Resolve and attach the client-specific config for a provider.

    Asks the provider for its config for ``client_type``, completes any
    missing endpoints through ``discover_oidc_endpoints``, and assigns the
    result to ``provider_instance.config``. Returns None.
    """
    client_config = await provider_instance.get_config_for_client_type(
        client_type, user_context
    )
    provider_instance.config = await discover_oidc_endpoints(client_config)


async def find_and_create_provider_instance(
    providers: List[ProviderInput],
    third_party_id: str,
    client_type: Optional[str],
    user_context: Dict[str, Any],
) -> Optional[Provider]:
    """Create a fully configured provider for ``third_party_id``.

    Uses the first entry in ``providers`` whose ``config.third_party_id``
    equals ``third_party_id`` exactly. The provider is built with
    ``create_provider``, and its config is resolved for ``client_type`` via
    ``fetch_and_set_config``. Returns None when no entry matches.
    """
    selected_input = next(
        (
            candidate
            for candidate in providers
            if candidate.config.third_party_id == third_party_id
        ),
        None,
    )
    if selected_input is None:
        return None

    instance = create_provider(selected_input)
    await fetch_and_set_config(instance, client_type, user_context)
    return instance

Functions

def create_provider(provider_input: ProviderInput) -> Provider
async def discover_oidc_endpoints(config: ProviderConfigForClient) -> ProviderConfigForClient
async def fetch_and_set_config(provider_instance: Provider, client_type: Optional[str], user_context: Dict[str, Any])
async def find_and_create_provider_instance(providers: List[ProviderInput], third_party_id: str, client_type: Optional[str], user_context: Dict[str, Any]) -> Optional[Provider]
async def get_oidc_discovery_info(issuer: str)
def merge_config(config_from_static: ProviderConfig, config_from_core: ProviderConfig) -> ProviderConfig
def merge_providers_from_core_and_static(provider_configs_from_core: List[ProviderConfig], provider_inputs_from_static: List[ProviderInput], include_all_providers: bool) -> List[ProviderInput]