Module supertokens_python.recipe.thirdparty.providers.config_utils
Expand source code
from typing import List, Dict, Optional, Any
from supertokens_python.normalised_url_domain import NormalisedURLDomain
from supertokens_python.normalised_url_path import NormalisedURLPath
from .active_directory import ActiveDirectory
from .apple import Apple
from .bitbucket import Bitbucket
from .boxy_saml import BoxySAML
from .discord import Discord
from .facebook import Facebook
from .github import Github
from .gitlab import Gitlab
from .google_workspaces import GoogleWorkspaces
from .google import Google
from .linkedin import Linkedin
from .twitter import Twitter
from .okta import Okta
from .custom import NewProvider
from .utils import do_get_request
from ..provider import (
ProviderConfig,
ProviderConfigForClient,
ProviderInput,
Provider,
UserFields,
UserInfoMap,
)
def merge_config(
config_from_static: ProviderConfig, config_from_core: ProviderConfig
) -> ProviderConfig:
result = ProviderConfig(
third_party_id=config_from_static.third_party_id,
name=(
config_from_static.name
if config_from_core.name is None
else config_from_core.name
),
authorization_endpoint=(
config_from_static.authorization_endpoint
if config_from_core.authorization_endpoint is None
else config_from_core.authorization_endpoint
),
authorization_endpoint_query_params=(
config_from_static.authorization_endpoint_query_params
if config_from_core.authorization_endpoint_query_params is None
else config_from_core.authorization_endpoint_query_params
),
token_endpoint=(
config_from_static.token_endpoint
if config_from_core.token_endpoint is None
else config_from_core.token_endpoint
),
token_endpoint_body_params=(
config_from_static.token_endpoint_body_params
if config_from_core.token_endpoint_body_params is None
else config_from_core.token_endpoint_body_params
),
user_info_endpoint=(
config_from_static.user_info_endpoint
if config_from_core.user_info_endpoint is None
else config_from_core.user_info_endpoint
),
user_info_endpoint_headers=(
config_from_static.user_info_endpoint_headers
if config_from_core.user_info_endpoint_headers is None
else config_from_core.user_info_endpoint_headers
),
user_info_endpoint_query_params=(
config_from_static.user_info_endpoint_query_params
if config_from_core.user_info_endpoint_query_params is None
else config_from_core.user_info_endpoint_query_params
),
jwks_uri=(
config_from_static.jwks_uri
if config_from_core.jwks_uri is None
else config_from_core.jwks_uri
),
oidc_discovery_endpoint=(
config_from_static.oidc_discovery_endpoint
if config_from_core.oidc_discovery_endpoint is None
else config_from_core.oidc_discovery_endpoint
),
require_email=config_from_static.require_email,
user_info_map=config_from_static.user_info_map,
generate_fake_email=config_from_static.generate_fake_email,
validate_id_token_payload=config_from_static.validate_id_token_payload,
validate_access_token=config_from_static.validate_access_token,
)
if result.user_info_map is None:
result.user_info_map = UserInfoMap(UserFields(), UserFields())
if result.user_info_map.from_user_info_api is None:
result.user_info_map.from_user_info_api = UserFields()
if result.user_info_map.from_id_token_payload is None:
result.user_info_map.from_id_token_payload = UserFields()
if config_from_core.user_info_map is not None:
if config_from_core.user_info_map.from_user_info_api is None:
config_from_core.user_info_map.from_user_info_api = UserFields()
if config_from_core.user_info_map.from_id_token_payload is None:
config_from_core.user_info_map.from_id_token_payload = UserFields()
if config_from_core.user_info_map.from_id_token_payload.user_id is not None:
result.user_info_map.from_id_token_payload.user_id = (
config_from_core.user_info_map.from_id_token_payload.user_id
)
if config_from_core.user_info_map.from_id_token_payload.email is not None:
result.user_info_map.from_id_token_payload.email = (
config_from_core.user_info_map.from_id_token_payload.email
)
if (
config_from_core.user_info_map.from_id_token_payload.email_verified
is not None
):
result.user_info_map.from_id_token_payload.email_verified = (
config_from_core.user_info_map.from_id_token_payload.email_verified
)
if config_from_core.user_info_map.from_user_info_api.user_id is not None:
result.user_info_map.from_user_info_api.user_id = (
config_from_core.user_info_map.from_user_info_api.user_id
)
if config_from_core.user_info_map.from_user_info_api.email is not None:
result.user_info_map.from_user_info_api.email = (
config_from_core.user_info_map.from_user_info_api.email
)
if config_from_core.user_info_map.from_user_info_api.email_verified is not None:
result.user_info_map.from_user_info_api.email_verified = (
config_from_core.user_info_map.from_user_info_api.email_verified
)
merged_clients = (config_from_static.clients or [])[:] # Make a copy
core_config_clients = config_from_core.clients or []
for core_client in core_config_clients:
found = False
for idx, static_client in enumerate(merged_clients):
if static_client.client_type == core_client.client_type:
merged_clients[idx] = core_client
found = True
break
if not found:
merged_clients.append(core_client)
result.clients = merged_clients
return result
def merge_providers_from_core_and_static(
provider_configs_from_core: List[ProviderConfig],
provider_inputs_from_static: List[ProviderInput],
include_all_providers: bool,
) -> List[ProviderInput]:
merged_providers: List[ProviderInput] = []
if len(provider_configs_from_core) == 0:
for config in filter(
lambda provider: include_all_providers
or provider.include_in_non_public_tenants_by_default,
provider_inputs_from_static,
):
merged_providers.append(config)
else:
for provider_config_from_core in provider_configs_from_core:
merged_provider_input = ProviderInput(provider_config_from_core)
for provider_input_from_static in provider_inputs_from_static:
if (
provider_input_from_static.config.third_party_id
== provider_config_from_core.third_party_id
):
merged_provider_input.config = merge_config(
provider_input_from_static.config, provider_config_from_core
)
merged_provider_input.override = provider_input_from_static.override
break
merged_providers.append(merged_provider_input)
return merged_providers
def create_provider(provider_input: ProviderInput) -> Provider:
if provider_input.config.third_party_id.startswith("active-directory"):
return ActiveDirectory(provider_input)
if provider_input.config.third_party_id.startswith("apple"):
return Apple(provider_input)
if provider_input.config.third_party_id.startswith("bitbucket"):
return Bitbucket(provider_input)
if provider_input.config.third_party_id.startswith("discord"):
return Discord(provider_input)
if provider_input.config.third_party_id.startswith("facebook"):
return Facebook(provider_input)
if provider_input.config.third_party_id.startswith("github"):
return Github(provider_input)
if provider_input.config.third_party_id.startswith("gitlab"):
return Gitlab(provider_input)
if provider_input.config.third_party_id.startswith("google-workspaces"):
return GoogleWorkspaces(provider_input)
if provider_input.config.third_party_id.startswith("google"):
return Google(provider_input)
if provider_input.config.third_party_id.startswith("okta"):
return Okta(provider_input)
if provider_input.config.third_party_id.startswith("linkedin"):
return Linkedin(provider_input)
if provider_input.config.third_party_id.startswith("twitter"):
return Twitter(provider_input)
if provider_input.config.third_party_id.startswith("boxy-saml"):
return BoxySAML(provider_input)
return NewProvider(provider_input)
OIDC_INFO_MAP: Dict[str, Any] = {}
async def get_oidc_discovery_info(issuer: str):
if issuer in OIDC_INFO_MAP:
return OIDC_INFO_MAP[issuer]
ndomain = NormalisedURLDomain(issuer)
npath = NormalisedURLPath(issuer)
oidc_info = await do_get_request(
ndomain.get_as_string_dangerous() + npath.get_as_string_dangerous()
)
OIDC_INFO_MAP[issuer] = oidc_info
return oidc_info
async def discover_oidc_endpoints(
config: ProviderConfigForClient,
) -> ProviderConfigForClient:
if config.oidc_discovery_endpoint is None:
return config
oidc_info = await get_oidc_discovery_info(config.oidc_discovery_endpoint)
if (
oidc_info.get("authorization_endpoint") is not None
and config.authorization_endpoint is None
):
config.authorization_endpoint = oidc_info["authorization_endpoint"]
if oidc_info.get("token_endpoint") is not None and config.token_endpoint is None:
config.token_endpoint = oidc_info["token_endpoint"]
if (
oidc_info.get("userinfo_endpoint") is not None
and config.user_info_endpoint is None
):
config.user_info_endpoint = oidc_info["userinfo_endpoint"]
if oidc_info.get("jwks_uri") is not None and config.jwks_uri is None:
config.jwks_uri = oidc_info["jwks_uri"]
return config
async def fetch_and_set_config(
provider_instance: Provider,
client_type: Optional[str],
user_context: Dict[str, Any],
):
config = await provider_instance.get_config_for_client_type(
client_type, user_context
)
config = await discover_oidc_endpoints(config)
provider_instance.config = config
async def find_and_create_provider_instance(
providers: List[ProviderInput],
third_party_id: str,
client_type: Optional[str],
user_context: Dict[str, Any],
) -> Optional[Provider]:
for provider_input in providers:
if provider_input.config.third_party_id == third_party_id:
provider_instance = create_provider(provider_input)
await fetch_and_set_config(provider_instance, client_type, user_context)
return provider_instance
return None
Functions
def create_provider(provider_input: ProviderInput) ‑> Provider
-
Expand source code
def create_provider(provider_input: ProviderInput) -> Provider: if provider_input.config.third_party_id.startswith("active-directory"): return ActiveDirectory(provider_input) if provider_input.config.third_party_id.startswith("apple"): return Apple(provider_input) if provider_input.config.third_party_id.startswith("bitbucket"): return Bitbucket(provider_input) if provider_input.config.third_party_id.startswith("discord"): return Discord(provider_input) if provider_input.config.third_party_id.startswith("facebook"): return Facebook(provider_input) if provider_input.config.third_party_id.startswith("github"): return Github(provider_input) if provider_input.config.third_party_id.startswith("gitlab"): return Gitlab(provider_input) if provider_input.config.third_party_id.startswith("google-workspaces"): return GoogleWorkspaces(provider_input) if provider_input.config.third_party_id.startswith("google"): return Google(provider_input) if provider_input.config.third_party_id.startswith("okta"): return Okta(provider_input) if provider_input.config.third_party_id.startswith("linkedin"): return Linkedin(provider_input) if provider_input.config.third_party_id.startswith("twitter"): return Twitter(provider_input) if provider_input.config.third_party_id.startswith("boxy-saml"): return BoxySAML(provider_input) return NewProvider(provider_input)
async def discover_oidc_endpoints(config: ProviderConfigForClient) ‑> ProviderConfigForClient
-
Expand source code
async def discover_oidc_endpoints( config: ProviderConfigForClient, ) -> ProviderConfigForClient: if config.oidc_discovery_endpoint is None: return config oidc_info = await get_oidc_discovery_info(config.oidc_discovery_endpoint) if ( oidc_info.get("authorization_endpoint") is not None and config.authorization_endpoint is None ): config.authorization_endpoint = oidc_info["authorization_endpoint"] if oidc_info.get("token_endpoint") is not None and config.token_endpoint is None: config.token_endpoint = oidc_info["token_endpoint"] if ( oidc_info.get("userinfo_endpoint") is not None and config.user_info_endpoint is None ): config.user_info_endpoint = oidc_info["userinfo_endpoint"] if oidc_info.get("jwks_uri") is not None and config.jwks_uri is None: config.jwks_uri = oidc_info["jwks_uri"] return config
async def fetch_and_set_config(provider_instance: Provider, client_type: Optional[str], user_context: Dict[str, Any])
-
Expand source code
async def fetch_and_set_config( provider_instance: Provider, client_type: Optional[str], user_context: Dict[str, Any], ): config = await provider_instance.get_config_for_client_type( client_type, user_context ) config = await discover_oidc_endpoints(config) provider_instance.config = config
async def find_and_create_provider_instance(providers: List[ProviderInput], third_party_id: str, client_type: Optional[str], user_context: Dict[str, Any]) ‑> Optional[Provider]
-
Expand source code
async def find_and_create_provider_instance( providers: List[ProviderInput], third_party_id: str, client_type: Optional[str], user_context: Dict[str, Any], ) -> Optional[Provider]: for provider_input in providers: if provider_input.config.third_party_id == third_party_id: provider_instance = create_provider(provider_input) await fetch_and_set_config(provider_instance, client_type, user_context) return provider_instance return None
async def get_oidc_discovery_info(issuer: str)
-
Expand source code
async def get_oidc_discovery_info(issuer: str): if issuer in OIDC_INFO_MAP: return OIDC_INFO_MAP[issuer] ndomain = NormalisedURLDomain(issuer) npath = NormalisedURLPath(issuer) oidc_info = await do_get_request( ndomain.get_as_string_dangerous() + npath.get_as_string_dangerous() ) OIDC_INFO_MAP[issuer] = oidc_info return oidc_info
def merge_config(config_from_static: ProviderConfig, config_from_core: ProviderConfig) ‑> ProviderConfig
-
Expand source code
def merge_config( config_from_static: ProviderConfig, config_from_core: ProviderConfig ) -> ProviderConfig: result = ProviderConfig( third_party_id=config_from_static.third_party_id, name=( config_from_static.name if config_from_core.name is None else config_from_core.name ), authorization_endpoint=( config_from_static.authorization_endpoint if config_from_core.authorization_endpoint is None else config_from_core.authorization_endpoint ), authorization_endpoint_query_params=( config_from_static.authorization_endpoint_query_params if config_from_core.authorization_endpoint_query_params is None else config_from_core.authorization_endpoint_query_params ), token_endpoint=( config_from_static.token_endpoint if config_from_core.token_endpoint is None else config_from_core.token_endpoint ), token_endpoint_body_params=( config_from_static.token_endpoint_body_params if config_from_core.token_endpoint_body_params is None else config_from_core.token_endpoint_body_params ), user_info_endpoint=( config_from_static.user_info_endpoint if config_from_core.user_info_endpoint is None else config_from_core.user_info_endpoint ), user_info_endpoint_headers=( config_from_static.user_info_endpoint_headers if config_from_core.user_info_endpoint_headers is None else config_from_core.user_info_endpoint_headers ), user_info_endpoint_query_params=( config_from_static.user_info_endpoint_query_params if config_from_core.user_info_endpoint_query_params is None else config_from_core.user_info_endpoint_query_params ), jwks_uri=( config_from_static.jwks_uri if config_from_core.jwks_uri is None else config_from_core.jwks_uri ), oidc_discovery_endpoint=( config_from_static.oidc_discovery_endpoint if config_from_core.oidc_discovery_endpoint is None else config_from_core.oidc_discovery_endpoint ), require_email=config_from_static.require_email, user_info_map=config_from_static.user_info_map, generate_fake_email=config_from_static.generate_fake_email, validate_id_token_payload=config_from_static.validate_id_token_payload, validate_access_token=config_from_static.validate_access_token, ) if result.user_info_map is None: result.user_info_map = UserInfoMap(UserFields(), UserFields()) if result.user_info_map.from_user_info_api is None: result.user_info_map.from_user_info_api = UserFields() if result.user_info_map.from_id_token_payload is None: result.user_info_map.from_id_token_payload = UserFields() if config_from_core.user_info_map is not None: if config_from_core.user_info_map.from_user_info_api is None: config_from_core.user_info_map.from_user_info_api = UserFields() if config_from_core.user_info_map.from_id_token_payload is None: config_from_core.user_info_map.from_id_token_payload = UserFields() if config_from_core.user_info_map.from_id_token_payload.user_id is not None: result.user_info_map.from_id_token_payload.user_id = ( config_from_core.user_info_map.from_id_token_payload.user_id ) if config_from_core.user_info_map.from_id_token_payload.email is not None: result.user_info_map.from_id_token_payload.email = ( config_from_core.user_info_map.from_id_token_payload.email ) if ( config_from_core.user_info_map.from_id_token_payload.email_verified is not None ): result.user_info_map.from_id_token_payload.email_verified = ( config_from_core.user_info_map.from_id_token_payload.email_verified ) if config_from_core.user_info_map.from_user_info_api.user_id is not None: result.user_info_map.from_user_info_api.user_id = ( config_from_core.user_info_map.from_user_info_api.user_id ) if config_from_core.user_info_map.from_user_info_api.email is not None: result.user_info_map.from_user_info_api.email = ( config_from_core.user_info_map.from_user_info_api.email ) if config_from_core.user_info_map.from_user_info_api.email_verified is not None: result.user_info_map.from_user_info_api.email_verified = ( config_from_core.user_info_map.from_user_info_api.email_verified ) merged_clients = (config_from_static.clients or [])[:] # Make a copy core_config_clients = config_from_core.clients or [] for core_client in core_config_clients: found = False for idx, static_client in enumerate(merged_clients): if static_client.client_type == core_client.client_type: merged_clients[idx] = core_client found = True break if not found: merged_clients.append(core_client) result.clients = merged_clients return result
def merge_providers_from_core_and_static(provider_configs_from_core: List[ProviderConfig], provider_inputs_from_static: List[ProviderInput], include_all_providers: bool) ‑> List[ProviderInput]
-
Expand source code
def merge_providers_from_core_and_static( provider_configs_from_core: List[ProviderConfig], provider_inputs_from_static: List[ProviderInput], include_all_providers: bool, ) -> List[ProviderInput]: merged_providers: List[ProviderInput] = [] if len(provider_configs_from_core) == 0: for config in filter( lambda provider: include_all_providers or provider.include_in_non_public_tenants_by_default, provider_inputs_from_static, ): merged_providers.append(config) else: for provider_config_from_core in provider_configs_from_core: merged_provider_input = ProviderInput(provider_config_from_core) for provider_input_from_static in provider_inputs_from_static: if ( provider_input_from_static.config.third_party_id == provider_config_from_core.third_party_id ): merged_provider_input.config = merge_config( provider_input_from_static.config, provider_config_from_core ) merged_provider_input.override = provider_input_from_static.override break merged_providers.append(merged_provider_input) return merged_providers