Module supertokens_python.recipe.thirdparty.providers.custom
Expand source code
from typing import Any, Callable, Dict, List, Optional, Union
from urllib.parse import parse_qs, urlencode, urlparse
from httpx import AsyncClient
from jwt import decode # type: ignore
from jwt.algorithms import RSAAlgorithm
import pkce
from supertokens_python.recipe.thirdparty.exceptions import ClientTypeNotFoundError
from supertokens_python.recipe.thirdparty.providers.utils import (
DEV_OAUTH_AUTHORIZATION_URL,
DEV_OAUTH_REDIRECT_URL,
do_get_request,
do_post_request,
get_actual_client_id_from_development_client_id,
is_using_oauth_development_client_id,
DEV_KEY_IDENTIFIER,
DEV_OAUTH_CLIENT_IDS,
)
from ..types import RawUserInfoFromProvider, UserInfo, UserInfoEmail
from ..provider import (
AuthorisationRedirect,
Provider,
ProviderClientConfig,
ProviderConfig,
ProviderConfigForClient,
ProviderInput,
RedirectUriInfo,
UserFields,
UserInfoMap,
)
def get_provider_config_for_client(
    config: ProviderConfig, client_config: ProviderClientConfig
) -> ProviderConfigForClient:
    """Combine a provider-wide config with the config of one of its clients.

    Client-specific settings (id, secret, type, scope, PKCE flag and
    additional config) come from ``client_config``. Every other setting is
    copied from the provider-level ``config``.

    :param config: provider-level configuration shared by all clients.
    :param client_config: configuration of a single client of the provider.
    :returns: a new ``ProviderConfigForClient`` holding the merged values.
    """
    common = config
    client = client_config
    return ProviderConfigForClient(
        # Settings shared by every client of this provider
        third_party_id=common.third_party_id,
        name=common.name,
        authorization_endpoint=common.authorization_endpoint,
        authorization_endpoint_query_params=common.authorization_endpoint_query_params,
        token_endpoint=common.token_endpoint,
        token_endpoint_body_params=common.token_endpoint_body_params,
        user_info_endpoint=common.user_info_endpoint,
        user_info_endpoint_query_params=common.user_info_endpoint_query_params,
        user_info_endpoint_headers=common.user_info_endpoint_headers,
        jwks_uri=common.jwks_uri,
        oidc_discovery_endpoint=common.oidc_discovery_endpoint,
        user_info_map=common.user_info_map,
        require_email=common.require_email,
        validate_id_token_payload=common.validate_id_token_payload,
        generate_fake_email=common.generate_fake_email,
        validate_access_token=common.validate_access_token,
        # Settings specific to the selected client
        client_id=client.client_id,
        client_secret=client.client_secret,
        client_type=client.client_type,
        scope=client.scope,
        force_pkce=client.force_pkce,
        additional_config=client.additional_config,
    )
def access_field(obj: Any, key: str) -> Any:
    """Look up a dot-separated path (e.g. ``"a.b.c"``) in nested dicts.

    :param obj: the object to search; traversal only descends through dicts.
    :param key: the path, with segments separated by ``"."``.
    :returns: the value at the path, or ``None`` if a segment is missing or a
        non-dict value is reached before the path is exhausted.
    """
    current = obj
    for segment in key.split("."):
        if not isinstance(current, dict):
            return None
        current = current.get(segment)  # type: ignore
    return current
def _resolve_mapped_field(
    raw_user_info_from_provider: RawUserInfoFromProvider,
    api_path: Optional[str],
    id_token_path: Optional[str],
) -> Any:
    """Resolve one mapped user field, preferring the id token payload.

    Returns the value at ``id_token_path`` in the id token payload if present.
    Otherwise returns the value at ``api_path`` in the user-info API response.
    Returns ``None`` when neither path is configured or neither value exists.
    """
    value: Any = None
    if api_path is not None:
        found = access_field(raw_user_info_from_provider.from_user_info_api, api_path)
        if found is not None:
            value = found
    if id_token_path is not None:
        found = access_field(
            raw_user_info_from_provider.from_id_token_payload, id_token_path
        )
        if found is not None:
            value = found
    return value


def get_supertokens_user_info_result_from_raw_user_info(
    config: ProviderConfigForClient,
    raw_user_info_from_provider: RawUserInfoFromProvider,
) -> UserInfo:
    """Map raw provider responses onto a SuperTokens ``UserInfo``.

    The user id, email and email-verified flag are each looked up through the
    dotted paths in ``config.user_info_map``. Lookup checks the user-info API
    response and the id token payload; a value found in the id token payload
    wins. The email-verified flag is only applied when a non-empty email was
    found, and it is true iff ``str(value).lower() == "true"``.

    ``config`` is not modified.

    :raises Exception: if ``config.user_info_map`` is ``None`` or no
        third-party user id can be resolved.
    """
    user_info_map = config.user_info_map
    if user_info_map is None:
        raise Exception("user info map is missing")

    # Use empty field maps locally rather than writing them back onto the
    # config object, which is shared with the provider's input config.
    api_fields = user_info_map.from_user_info_api
    if api_fields is None:
        api_fields = UserFields()
    id_token_fields = user_info_map.from_id_token_payload
    if id_token_fields is None:
        id_token_fields = UserFields()

    user_id = _resolve_mapped_field(
        raw_user_info_from_provider, api_fields.user_id, id_token_fields.user_id
    )
    third_party_user_id = str(user_id) if user_id is not None else ""
    if third_party_user_id == "":
        raise Exception("third party user id is missing")

    result = UserInfo(
        third_party_user_id=third_party_user_id,
    )

    email = _resolve_mapped_field(
        raw_user_info_from_provider, api_fields.email, id_token_fields.email
    )
    if email is not None and email != "":
        result.email = UserInfoEmail(email, False)
        email_verified = _resolve_mapped_field(
            raw_user_info_from_provider,
            api_fields.email_verified,
            id_token_fields.email_verified,
        )
        if email_verified is not None:
            # Accepts both boolean True and the string "true" (any case).
            result.email.is_verified = str(email_verified).lower() == "true"

    return result
async def verify_id_token_from_jwks_endpoint_and_get_payload(
    id_token: str, jwks_uri: str, audience: str
) -> Dict[str, Any]:
    """Verify an RS256 id token against the keys at a JWKS endpoint.

    The key set is fetched from ``jwks_uri`` on every call. Each RSA key is
    tried in turn with PyJWT ``decode`` (RS256 only, expected audience
    ``audience``). The first successfully decoded payload is returned.

    :raises httpx.HTTPStatusError: if the JWKS endpoint returns an error status.
    :raises Exception: the error from the last key tried when no key verifies
        the token, or "id token verification failed" if the set has no RSA keys.
    """
    async with AsyncClient(timeout=30.0) as client:
        response = await client.get(jwks_uri)  # type:ignore
        # Fail with a clear HTTP error instead of a KeyError/JSON error when
        # the endpoint returns an error body.
        response.raise_for_status()
        key_payload = response.json()

    public_keys: List[RSAAlgorithm] = []
    for key in key_payload["keys"]:
        # Only RS256 is accepted below, so a non-RSA key (e.g. EC) can never
        # verify the token. Skip it instead of letting from_jwk raise and
        # abort verification for the whole key set.
        if key.get("kty") != "RSA":
            continue
        public_keys.append(RSAAlgorithm.from_jwk(key))  # type: ignore

    err: Exception = Exception("id token verification failed")
    for key in public_keys:
        try:
            return decode(jwt=id_token, key=key, audience=[audience], algorithms=["RS256"])  # type: ignore
        except Exception as e:  # try the next key; remember the failure
            err = e
    raise err
def merge_into_dict(src: Dict[str, Any], dest: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``dest`` overlaid with the entries of ``src``.

    A ``None`` value in ``src`` removes that key from the result, if present.
    Any other value adds or overwrites the key. Neither input is modified.
    """
    merged = dest.copy()
    for name, value in src.items():
        if value is not None:
            merged[name] = value
        elif name in merged:
            del merged[name]
    return merged
def is_using_development_client_id(client_id: str) -> bool:
    """Return True if ``client_id`` is a development client id.

    A client id qualifies when it starts with ``DEV_KEY_IDENTIFIER`` or is
    listed in ``DEV_OAUTH_CLIENT_IDS``.
    """
    if client_id.startswith(DEV_KEY_IDENTIFIER):
        return True
    return client_id in DEV_OAUTH_CLIENT_IDS
class GenericProvider(Provider):
    """OAuth2 / OpenID Connect provider driven entirely by a ``ProviderConfig``.

    ``self.input_config`` holds the normalized input config.

    ``self.config`` starts as a placeholder client config (``client_id="temp"``).
    Per the note in ``__init__``, it is replaced with a real client config
    built by ``get_provider_config_for_client``.
    """

    def __init__(self, provider_config: ProviderConfig):
        self.input_config = input_config = self._normalize_input(provider_config)
        provider_config_for_client = ProviderConfigForClient(
            # Will automatically get replaced with correct value
            # in get_provider_config_for_client
            # when fetch_and_set_config function runs
            client_id="temp",
            client_secret=None,
            client_type=None,
            scope=None,
            force_pkce=False,
            additional_config=None,
            # Common fields mirror get_provider_config_for_client so the
            # placeholder carries the same provider-level settings.
            third_party_id=input_config.third_party_id,
            name=input_config.name,
            authorization_endpoint=input_config.authorization_endpoint,
            authorization_endpoint_query_params=input_config.authorization_endpoint_query_params,
            token_endpoint=input_config.token_endpoint,
            token_endpoint_body_params=input_config.token_endpoint_body_params,
            user_info_endpoint=input_config.user_info_endpoint,
            user_info_endpoint_query_params=input_config.user_info_endpoint_query_params,
            user_info_endpoint_headers=input_config.user_info_endpoint_headers,
            jwks_uri=input_config.jwks_uri,
            oidc_discovery_endpoint=input_config.oidc_discovery_endpoint,
            user_info_map=input_config.user_info_map,
            require_email=input_config.require_email,
            validate_id_token_payload=input_config.validate_id_token_payload,
            generate_fake_email=input_config.generate_fake_email,
            validate_access_token=input_config.validate_access_token,
        )
        super().__init__(input_config.third_party_id, provider_config_for_client)

    def _normalize_input(  # pylint: disable=no-self-use
        self, input_config: ProviderConfig
    ) -> ProviderConfig:
        """Fill in default user-field mappings and fake-email generator.

        ``input_config`` is modified in place and returned.

        Missing mappings default to ``"sub"`` (user id), ``"email"`` and
        ``"email_verified"`` for both the id token payload and the user-info
        API response.
        """
        if input_config.user_info_map is None:
            input_config.user_info_map = UserInfoMap(
                from_id_token_payload=UserFields(),
                from_user_info_api=UserFields(),
            )
        if input_config.user_info_map.from_user_info_api is None:
            input_config.user_info_map.from_user_info_api = UserFields()
        if input_config.user_info_map.from_id_token_payload is None:
            input_config.user_info_map.from_id_token_payload = UserFields()

        # These are safe defaults common to most providers. Each provider
        # implementations override these as necessary
        if input_config.user_info_map.from_id_token_payload.user_id is None:
            input_config.user_info_map.from_id_token_payload.user_id = "sub"
        if input_config.user_info_map.from_id_token_payload.email is None:
            input_config.user_info_map.from_id_token_payload.email = "email"
        if input_config.user_info_map.from_id_token_payload.email_verified is None:
            input_config.user_info_map.from_id_token_payload.email_verified = (
                "email_verified"
            )
        if input_config.user_info_map.from_user_info_api.user_id is None:
            input_config.user_info_map.from_user_info_api.user_id = "sub"
        if input_config.user_info_map.from_user_info_api.email is None:
            input_config.user_info_map.from_user_info_api.email = "email"
        if input_config.user_info_map.from_user_info_api.email_verified is None:
            input_config.user_info_map.from_user_info_api.email_verified = (
                "email_verified"
            )

        if input_config.generate_fake_email is None:

            async def default_generate_fake_email(
                _tenant_id: str, third_party_user_id: str, _: Dict[str, Any]
            ) -> str:
                return f"{third_party_user_id}.{input_config.third_party_id}@stfakeemail.supertokens.com"

            input_config.generate_fake_email = default_generate_fake_email

        return input_config

    async def get_config_for_client_type(
        self, client_type: Optional[str], user_context: Dict[str, Any]
    ) -> ProviderConfigForClient:
        """Return the merged config for the client whose type is ``client_type``.

        When ``client_type`` is ``None``, exactly one client must be
        configured, and that client is used.

        :raises ClientTypeNotFoundError: if no suitable client config exists.
        """
        clients = self.input_config.clients

        if client_type is None:
            # Without an explicit client type the choice is only unambiguous
            # when exactly one client is configured.
            if clients is not None and len(clients) == 1:
                return get_provider_config_for_client(self.input_config, clients[0])
            raise ClientTypeNotFoundError(
                "please provide exactly one client config or pass clientType or tenantId"
            )

        matching = next(
            (c for c in (clients or []) if c.client_type == client_type), None
        )
        if matching is None:
            raise ClientTypeNotFoundError(
                f"Could not find client config for clientType: {client_type}"
            )
        return get_provider_config_for_client(self.input_config, matching)

    async def get_authorisation_redirect_url(
        self,
        redirect_uri_on_provider_dashboard: str,
        user_context: Dict[str, Any],
    ) -> AuthorisationRedirect:
        """Build the URL that sends the user to the provider's authorization page.

        Base query parameters:
        - ``client_id``, ``redirect_uri``, ``response_type=code``;
        - ``scope`` (space-joined), if configured;
        - an S256 PKCE challenge, when no client secret is configured or
          ``force_pkce`` is set.

        Entries in ``authorization_endpoint_query_params`` then override these;
        a ``None`` value removes the parameter. Parameters are merged into any
        query string already on the endpoint URL.

        For development client ids, the request goes to
        ``DEV_OAUTH_AUTHORIZATION_URL``, and the real endpoint is passed as
        ``actual_redirect_uri``.

        :returns: the URL and the PKCE code verifier (``None`` without PKCE).
        :raises Exception: if ``authorization_endpoint`` is not configured.
        """
        query_params: Dict[str, str] = {
            "client_id": self.config.client_id,
            "redirect_uri": redirect_uri_on_provider_dashboard,
            "response_type": "code",
        }
        if self.config.scope is not None:
            query_params["scope"] = " ".join(self.config.scope)

        pkce_code_verifier: Union[str, None] = None
        if self.config.client_secret is None or self.config.force_pkce:
            code_verifier, code_challenge = pkce.generate_pkce_pair(64)
            query_params["code_challenge"] = code_challenge
            query_params["code_challenge_method"] = "S256"
            pkce_code_verifier = code_verifier

        if self.config.authorization_endpoint_query_params is not None:
            for k, v in self.config.authorization_endpoint_query_params.items():
                if v is None:
                    # A None value means "drop this parameter". It may never
                    # have been set above (e.g. "scope" with no scope
                    # configured), so remove it without raising KeyError.
                    query_params.pop(k, None)
                else:
                    query_params[k] = v

        if self.config.authorization_endpoint is None:
            raise Exception(
                "ThirdParty provider's authorizationEndpoint is not configured."
            )
        url: str = self.config.authorization_endpoint

        # Transformation needed for dev keys BEGIN
        if is_using_oauth_development_client_id(self.config.client_id):
            query_params["client_id"] = get_actual_client_id_from_development_client_id(
                self.config.client_id
            )
            query_params["actual_redirect_uri"] = url
            url = DEV_OAUTH_AUTHORIZATION_URL
        # Transformation needed for dev keys END

        # Merge into any query string already present on the URL; our values
        # replace existing entries with the same key.
        url_obj = urlparse(url)
        qparams = parse_qs(url_obj.query)
        for k, v in query_params.items():
            qparams[k] = [v]
        url = url_obj._replace(query=urlencode(qparams, doseq=True)).geturl()

        return AuthorisationRedirect(url, pkce_code_verifier)

    async def exchange_auth_code_for_oauth_tokens(
        self, redirect_uri_info: RedirectUriInfo, user_context: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Exchange the authorization code for tokens at the token endpoint.

        The POST body contains:
        - ``client_id``, ``redirect_uri``, ``code``, ``grant_type=authorization_code``;
        - ``client_secret`` and ``code_verifier``, when available.

        ``token_endpoint_body_params`` is then merged in; ``None`` values remove
        keys. Development client ids are rewritten to the real client id, and
        ``redirect_uri`` becomes ``DEV_OAUTH_REDIRECT_URL``.

        :returns: the response body of the token endpoint.
        :raises Exception: if ``token_endpoint`` is not configured.
        """
        cfg = self.config
        token_endpoint = cfg.token_endpoint
        if token_endpoint is None:
            raise Exception("ThirdParty provider's tokenEndpoint is not configured.")

        body_params: Dict[str, str] = {
            "client_id": cfg.client_id,
            "redirect_uri": redirect_uri_info.redirect_uri_on_provider_dashboard,
            "code": redirect_uri_info.redirect_uri_query_params["code"],
            "grant_type": "authorization_code",
        }
        optional_params = {
            "client_secret": cfg.client_secret,
            "code_verifier": redirect_uri_info.pkce_code_verifier,
        }
        body_params.update(
            {name: value for name, value in optional_params.items() if value is not None}
        )

        if cfg.token_endpoint_body_params is not None:
            body_params = merge_into_dict(cfg.token_endpoint_body_params, body_params)

        # Transformation needed for dev keys BEGIN
        if is_using_oauth_development_client_id(cfg.client_id):
            body_params["client_id"] = get_actual_client_id_from_development_client_id(
                cfg.client_id
            )
            body_params["redirect_uri"] = DEV_OAUTH_REDIRECT_URL
        # Transformation needed for dev keys END

        _, response_body = await do_post_request(token_endpoint, body_params)
        return response_body

    async def get_user_info(
        self, oauth_tokens: Dict[str, Any], user_context: Dict[str, Any]
    ) -> UserInfo:
        """Collect the user's profile from the id token and/or user-info endpoint.

        Steps:
        1. If an ``id_token`` is present and ``jwks_uri`` is configured, verify
           it and store its payload. Then run ``validate_id_token_payload``, if set.
        2. If an ``access_token`` is present, run ``validate_access_token``, if set.
        3. If an ``access_token`` is present and ``user_info_endpoint`` is
           configured, call that endpoint with a Bearer header, merged with the
           configured headers and query params.
        4. Map the raw data with ``user_info_map``.

        :returns: the mapped ``UserInfo`` carrying the raw provider data.
        """
        access_token: Union[str, None] = oauth_tokens.get("access_token")
        id_token: Union[str, None] = oauth_tokens.get("id_token")
        raw = RawUserInfoFromProvider({}, {})

        if id_token is not None and self.config.jwks_uri is not None:
            audience = get_actual_client_id_from_development_client_id(
                self.config.client_id
            )
            raw.from_id_token_payload = (
                await verify_id_token_from_jwks_endpoint_and_get_payload(
                    id_token, self.config.jwks_uri, audience
                )
            )
            validate_payload = self.config.validate_id_token_payload
            if validate_payload is not None:
                await validate_payload(
                    raw.from_id_token_payload, self.config, user_context
                )

        validate_token = self.config.validate_access_token
        if validate_token is not None and access_token is not None:
            await validate_token(access_token, self.config, user_context)

        if access_token is not None and self.config.user_info_endpoint is not None:
            headers: Dict[str, str] = {"Authorization": f"Bearer {access_token}"}
            params: Dict[str, str] = {}
            if self.config.user_info_endpoint_headers is not None:
                headers = merge_into_dict(self.config.user_info_endpoint_headers, headers)
            if self.config.user_info_endpoint_query_params is not None:
                params = merge_into_dict(
                    self.config.user_info_endpoint_query_params, params
                )
            raw.from_user_info_api = await do_get_request(
                self.config.user_info_endpoint, params, headers
            )

        mapped = get_supertokens_user_info_result_from_raw_user_info(self.config, raw)
        return UserInfo(
            third_party_user_id=mapped.third_party_user_id,
            email=mapped.email,
            raw_user_info_from_provider=raw,
        )
def NewProvider(
    input_: ProviderInput,
    base_class: Callable[[ProviderConfig], Provider] = GenericProvider,
) -> Provider:
    """Build a provider from ``input_`` and apply its override, if any.

    :param input_: provider input holding the config and an optional override.
    :param base_class: callable that builds the provider from ``input_.config``
        (``GenericProvider`` by default).
    :returns: the override's return value when ``input_.override`` is set,
        otherwise the freshly built provider.
    """
    base_provider = base_class(input_.config)
    override = input_.override
    return base_provider if override is None else override(base_provider)
Functions
def NewProvider(input_: ProviderInput, base_class: Callable[[ProviderConfig], Provider] = supertokens_python.recipe.thirdparty.providers.custom.GenericProvider) ‑> Provider
-
Expand source code
def NewProvider(
    input_: ProviderInput,
    base_class: Callable[[ProviderConfig], Provider] = GenericProvider,
) -> Provider:
    """Build a provider from ``input_`` and apply its override, if any.

    :param input_: provider input holding the config and an optional override.
    :param base_class: callable that builds the provider from ``input_.config``
        (``GenericProvider`` by default).
    :returns: the override's return value when ``input_.override`` is set,
        otherwise the freshly built provider.
    """
    base_provider = base_class(input_.config)
    override = input_.override
    return base_provider if override is None else override(base_provider)
def access_field(obj: Any, key: str) ‑> Any
-
Expand source code
def access_field(obj: Any, key: str) -> Any:
    """Look up a dot-separated path (e.g. ``"a.b.c"``) in nested dicts.

    :param obj: the object to search; traversal only descends through dicts.
    :param key: the path, with segments separated by ``"."``.
    :returns: the value at the path, or ``None`` if a segment is missing or a
        non-dict value is reached before the path is exhausted.
    """
    current = obj
    for segment in key.split("."):
        if not isinstance(current, dict):
            return None
        current = current.get(segment)  # type: ignore
    return current
def get_provider_config_for_client(config: ProviderConfig, client_config: ProviderClientConfig) ‑> ProviderConfigForClient
-
Expand source code
def get_provider_config_for_client(
    config: ProviderConfig, client_config: ProviderClientConfig
) -> ProviderConfigForClient:
    """Combine a provider-wide config with the config of one of its clients.

    Client-specific settings (id, secret, type, scope, PKCE flag and
    additional config) come from ``client_config``. Every other setting is
    copied from the provider-level ``config``.

    :param config: provider-level configuration shared by all clients.
    :param client_config: configuration of a single client of the provider.
    :returns: a new ``ProviderConfigForClient`` holding the merged values.
    """
    common = config
    client = client_config
    return ProviderConfigForClient(
        # Settings shared by every client of this provider
        third_party_id=common.third_party_id,
        name=common.name,
        authorization_endpoint=common.authorization_endpoint,
        authorization_endpoint_query_params=common.authorization_endpoint_query_params,
        token_endpoint=common.token_endpoint,
        token_endpoint_body_params=common.token_endpoint_body_params,
        user_info_endpoint=common.user_info_endpoint,
        user_info_endpoint_query_params=common.user_info_endpoint_query_params,
        user_info_endpoint_headers=common.user_info_endpoint_headers,
        jwks_uri=common.jwks_uri,
        oidc_discovery_endpoint=common.oidc_discovery_endpoint,
        user_info_map=common.user_info_map,
        require_email=common.require_email,
        validate_id_token_payload=common.validate_id_token_payload,
        generate_fake_email=common.generate_fake_email,
        validate_access_token=common.validate_access_token,
        # Settings specific to the selected client
        client_id=client.client_id,
        client_secret=client.client_secret,
        client_type=client.client_type,
        scope=client.scope,
        force_pkce=client.force_pkce,
        additional_config=client.additional_config,
    )
def get_supertokens_user_info_result_from_raw_user_info(config: ProviderConfigForClient, raw_user_info_from_provider: RawUserInfoFromProvider) ‑> UserInfo
-
Expand source code
def _resolve_mapped_field(
    raw_user_info_from_provider: RawUserInfoFromProvider,
    api_path: Optional[str],
    id_token_path: Optional[str],
) -> Any:
    """Resolve one mapped user field, preferring the id token payload.

    Returns the value at ``id_token_path`` in the id token payload if present.
    Otherwise returns the value at ``api_path`` in the user-info API response.
    Returns ``None`` when neither path is configured or neither value exists.
    """
    value: Any = None
    if api_path is not None:
        found = access_field(raw_user_info_from_provider.from_user_info_api, api_path)
        if found is not None:
            value = found
    if id_token_path is not None:
        found = access_field(
            raw_user_info_from_provider.from_id_token_payload, id_token_path
        )
        if found is not None:
            value = found
    return value


def get_supertokens_user_info_result_from_raw_user_info(
    config: ProviderConfigForClient,
    raw_user_info_from_provider: RawUserInfoFromProvider,
) -> UserInfo:
    """Map raw provider responses onto a SuperTokens ``UserInfo``.

    The user id, email and email-verified flag are each looked up through the
    dotted paths in ``config.user_info_map``. Lookup checks the user-info API
    response and the id token payload; a value found in the id token payload
    wins. The email-verified flag is only applied when a non-empty email was
    found, and it is true iff ``str(value).lower() == "true"``.

    ``config`` is not modified.

    :raises Exception: if ``config.user_info_map`` is ``None`` or no
        third-party user id can be resolved.
    """
    user_info_map = config.user_info_map
    if user_info_map is None:
        raise Exception("user info map is missing")

    # Use empty field maps locally rather than writing them back onto the
    # config object, which is shared with the provider's input config.
    api_fields = user_info_map.from_user_info_api
    if api_fields is None:
        api_fields = UserFields()
    id_token_fields = user_info_map.from_id_token_payload
    if id_token_fields is None:
        id_token_fields = UserFields()

    user_id = _resolve_mapped_field(
        raw_user_info_from_provider, api_fields.user_id, id_token_fields.user_id
    )
    third_party_user_id = str(user_id) if user_id is not None else ""
    if third_party_user_id == "":
        raise Exception("third party user id is missing")

    result = UserInfo(
        third_party_user_id=third_party_user_id,
    )

    email = _resolve_mapped_field(
        raw_user_info_from_provider, api_fields.email, id_token_fields.email
    )
    if email is not None and email != "":
        result.email = UserInfoEmail(email, False)
        email_verified = _resolve_mapped_field(
            raw_user_info_from_provider,
            api_fields.email_verified,
            id_token_fields.email_verified,
        )
        if email_verified is not None:
            # Accepts both boolean True and the string "true" (any case).
            result.email.is_verified = str(email_verified).lower() == "true"

    return result
def is_using_development_client_id(client_id: str) ‑> bool
-
Expand source code
def is_using_development_client_id(client_id: str) -> bool:
    """Return True if ``client_id`` is a development client id.

    A client id qualifies when it starts with ``DEV_KEY_IDENTIFIER`` or is
    listed in ``DEV_OAUTH_CLIENT_IDS``.
    """
    if client_id.startswith(DEV_KEY_IDENTIFIER):
        return True
    return client_id in DEV_OAUTH_CLIENT_IDS
def merge_into_dict(src: Dict[str, Any], dest: Dict[str, Any]) ‑> Dict[str, Any]
-
Expand source code
def merge_into_dict(src: Dict[str, Any], dest: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``dest`` overlaid with the entries of ``src``.

    A ``None`` value in ``src`` removes that key from the result, if present.
    Any other value adds or overwrites the key. Neither input is modified.
    """
    merged = dest.copy()
    for name, value in src.items():
        if value is not None:
            merged[name] = value
        elif name in merged:
            del merged[name]
    return merged
async def verify_id_token_from_jwks_endpoint_and_get_payload(id_token: str, jwks_uri: str, audience: str)
-
Expand source code
async def verify_id_token_from_jwks_endpoint_and_get_payload(
    id_token: str, jwks_uri: str, audience: str
) -> Dict[str, Any]:
    """Verify an RS256 id token against the keys at a JWKS endpoint.

    The key set is fetched from ``jwks_uri`` on every call. Each RSA key is
    tried in turn with PyJWT ``decode`` (RS256 only, expected audience
    ``audience``). The first successfully decoded payload is returned.

    :raises httpx.HTTPStatusError: if the JWKS endpoint returns an error status.
    :raises Exception: the error from the last key tried when no key verifies
        the token, or "id token verification failed" if the set has no RSA keys.
    """
    async with AsyncClient(timeout=30.0) as client:
        response = await client.get(jwks_uri)  # type:ignore
        # Fail with a clear HTTP error instead of a KeyError/JSON error when
        # the endpoint returns an error body.
        response.raise_for_status()
        key_payload = response.json()

    public_keys: List[RSAAlgorithm] = []
    for key in key_payload["keys"]:
        # Only RS256 is accepted below, so a non-RSA key (e.g. EC) can never
        # verify the token. Skip it instead of letting from_jwk raise and
        # abort verification for the whole key set.
        if key.get("kty") != "RSA":
            continue
        public_keys.append(RSAAlgorithm.from_jwk(key))  # type: ignore

    err: Exception = Exception("id token verification failed")
    for key in public_keys:
        try:
            return decode(jwt=id_token, key=key, audience=[audience], algorithms=["RS256"])  # type: ignore
        except Exception as e:  # try the next key; remember the failure
            err = e
    raise err
Classes
class GenericProvider (provider_config: ProviderConfig)
-
Expand source code
class GenericProvider(Provider):
    """OAuth2 / OpenID Connect provider driven entirely by a ``ProviderConfig``.

    ``self.input_config`` holds the normalized input config.

    ``self.config`` starts as a placeholder client config (``client_id="temp"``).
    Per the note in ``__init__``, it is replaced with a real client config
    built by ``get_provider_config_for_client``.
    """

    def __init__(self, provider_config: ProviderConfig):
        self.input_config = input_config = self._normalize_input(provider_config)
        provider_config_for_client = ProviderConfigForClient(
            # Will automatically get replaced with correct value
            # in get_provider_config_for_client
            # when fetch_and_set_config function runs
            client_id="temp",
            client_secret=None,
            client_type=None,
            scope=None,
            force_pkce=False,
            additional_config=None,
            # Common fields mirror get_provider_config_for_client so the
            # placeholder carries the same provider-level settings.
            third_party_id=input_config.third_party_id,
            name=input_config.name,
            authorization_endpoint=input_config.authorization_endpoint,
            authorization_endpoint_query_params=input_config.authorization_endpoint_query_params,
            token_endpoint=input_config.token_endpoint,
            token_endpoint_body_params=input_config.token_endpoint_body_params,
            user_info_endpoint=input_config.user_info_endpoint,
            user_info_endpoint_query_params=input_config.user_info_endpoint_query_params,
            user_info_endpoint_headers=input_config.user_info_endpoint_headers,
            jwks_uri=input_config.jwks_uri,
            oidc_discovery_endpoint=input_config.oidc_discovery_endpoint,
            user_info_map=input_config.user_info_map,
            require_email=input_config.require_email,
            validate_id_token_payload=input_config.validate_id_token_payload,
            generate_fake_email=input_config.generate_fake_email,
            validate_access_token=input_config.validate_access_token,
        )
        super().__init__(input_config.third_party_id, provider_config_for_client)

    def _normalize_input(  # pylint: disable=no-self-use
        self, input_config: ProviderConfig
    ) -> ProviderConfig:
        """Fill in default user-field mappings and fake-email generator.

        ``input_config`` is modified in place and returned.

        Missing mappings default to ``"sub"`` (user id), ``"email"`` and
        ``"email_verified"`` for both the id token payload and the user-info
        API response.
        """
        if input_config.user_info_map is None:
            input_config.user_info_map = UserInfoMap(
                from_id_token_payload=UserFields(),
                from_user_info_api=UserFields(),
            )
        if input_config.user_info_map.from_user_info_api is None:
            input_config.user_info_map.from_user_info_api = UserFields()
        if input_config.user_info_map.from_id_token_payload is None:
            input_config.user_info_map.from_id_token_payload = UserFields()

        # These are safe defaults common to most providers. Each provider
        # implementations override these as necessary
        if input_config.user_info_map.from_id_token_payload.user_id is None:
            input_config.user_info_map.from_id_token_payload.user_id = "sub"
        if input_config.user_info_map.from_id_token_payload.email is None:
            input_config.user_info_map.from_id_token_payload.email = "email"
        if input_config.user_info_map.from_id_token_payload.email_verified is None:
            input_config.user_info_map.from_id_token_payload.email_verified = (
                "email_verified"
            )
        if input_config.user_info_map.from_user_info_api.user_id is None:
            input_config.user_info_map.from_user_info_api.user_id = "sub"
        if input_config.user_info_map.from_user_info_api.email is None:
            input_config.user_info_map.from_user_info_api.email = "email"
        if input_config.user_info_map.from_user_info_api.email_verified is None:
            input_config.user_info_map.from_user_info_api.email_verified = (
                "email_verified"
            )

        if input_config.generate_fake_email is None:

            async def default_generate_fake_email(
                _tenant_id: str, third_party_user_id: str, _: Dict[str, Any]
            ) -> str:
                return f"{third_party_user_id}.{input_config.third_party_id}@stfakeemail.supertokens.com"

            input_config.generate_fake_email = default_generate_fake_email

        return input_config

    async def get_config_for_client_type(
        self, client_type: Optional[str], user_context: Dict[str, Any]
    ) -> ProviderConfigForClient:
        """Return the merged config for the client whose type is ``client_type``.

        When ``client_type`` is ``None``, exactly one client must be
        configured, and that client is used.

        :raises ClientTypeNotFoundError: if no suitable client config exists.
        """
        clients = self.input_config.clients

        if client_type is None:
            # Without an explicit client type the choice is only unambiguous
            # when exactly one client is configured.
            if clients is not None and len(clients) == 1:
                return get_provider_config_for_client(self.input_config, clients[0])
            raise ClientTypeNotFoundError(
                "please provide exactly one client config or pass clientType or tenantId"
            )

        matching = next(
            (c for c in (clients or []) if c.client_type == client_type), None
        )
        if matching is None:
            raise ClientTypeNotFoundError(
                f"Could not find client config for clientType: {client_type}"
            )
        return get_provider_config_for_client(self.input_config, matching)

    async def get_authorisation_redirect_url(
        self,
        redirect_uri_on_provider_dashboard: str,
        user_context: Dict[str, Any],
    ) -> AuthorisationRedirect:
        """Build the URL that sends the user to the provider's authorization page.

        Base query parameters:
        - ``client_id``, ``redirect_uri``, ``response_type=code``;
        - ``scope`` (space-joined), if configured;
        - an S256 PKCE challenge, when no client secret is configured or
          ``force_pkce`` is set.

        Entries in ``authorization_endpoint_query_params`` then override these;
        a ``None`` value removes the parameter. Parameters are merged into any
        query string already on the endpoint URL.

        For development client ids, the request goes to
        ``DEV_OAUTH_AUTHORIZATION_URL``, and the real endpoint is passed as
        ``actual_redirect_uri``.

        :returns: the URL and the PKCE code verifier (``None`` without PKCE).
        :raises Exception: if ``authorization_endpoint`` is not configured.
        """
        query_params: Dict[str, str] = {
            "client_id": self.config.client_id,
            "redirect_uri": redirect_uri_on_provider_dashboard,
            "response_type": "code",
        }
        if self.config.scope is not None:
            query_params["scope"] = " ".join(self.config.scope)

        pkce_code_verifier: Union[str, None] = None
        if self.config.client_secret is None or self.config.force_pkce:
            code_verifier, code_challenge = pkce.generate_pkce_pair(64)
            query_params["code_challenge"] = code_challenge
            query_params["code_challenge_method"] = "S256"
            pkce_code_verifier = code_verifier

        if self.config.authorization_endpoint_query_params is not None:
            for k, v in self.config.authorization_endpoint_query_params.items():
                if v is None:
                    # A None value means "drop this parameter". It may never
                    # have been set above (e.g. "scope" with no scope
                    # configured), so remove it without raising KeyError.
                    query_params.pop(k, None)
                else:
                    query_params[k] = v

        if self.config.authorization_endpoint is None:
            raise Exception(
                "ThirdParty provider's authorizationEndpoint is not configured."
            )
        url: str = self.config.authorization_endpoint

        # Transformation needed for dev keys BEGIN
        if is_using_oauth_development_client_id(self.config.client_id):
            query_params["client_id"] = get_actual_client_id_from_development_client_id(
                self.config.client_id
            )
            query_params["actual_redirect_uri"] = url
            url = DEV_OAUTH_AUTHORIZATION_URL
        # Transformation needed for dev keys END

        # Merge into any query string already present on the URL; our values
        # replace existing entries with the same key.
        url_obj = urlparse(url)
        qparams = parse_qs(url_obj.query)
        for k, v in query_params.items():
            qparams[k] = [v]
        url = url_obj._replace(query=urlencode(qparams, doseq=True)).geturl()

        return AuthorisationRedirect(url, pkce_code_verifier)

    async def exchange_auth_code_for_oauth_tokens(
        self, redirect_uri_info: RedirectUriInfo, user_context: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Exchange the authorization code for tokens at the token endpoint.

        The POST body contains:
        - ``client_id``, ``redirect_uri``, ``code``, ``grant_type=authorization_code``;
        - ``client_secret`` and ``code_verifier``, when available.

        ``token_endpoint_body_params`` is then merged in; ``None`` values remove
        keys. Development client ids are rewritten to the real client id, and
        ``redirect_uri`` becomes ``DEV_OAUTH_REDIRECT_URL``.

        :returns: the response body of the token endpoint.
        :raises Exception: if ``token_endpoint`` is not configured.
        """
        cfg = self.config
        token_endpoint = cfg.token_endpoint
        if token_endpoint is None:
            raise Exception("ThirdParty provider's tokenEndpoint is not configured.")

        body_params: Dict[str, str] = {
            "client_id": cfg.client_id,
            "redirect_uri": redirect_uri_info.redirect_uri_on_provider_dashboard,
            "code": redirect_uri_info.redirect_uri_query_params["code"],
            "grant_type": "authorization_code",
        }
        optional_params = {
            "client_secret": cfg.client_secret,
            "code_verifier": redirect_uri_info.pkce_code_verifier,
        }
        body_params.update(
            {name: value for name, value in optional_params.items() if value is not None}
        )

        if cfg.token_endpoint_body_params is not None:
            body_params = merge_into_dict(cfg.token_endpoint_body_params, body_params)

        # Transformation needed for dev keys BEGIN
        if is_using_oauth_development_client_id(cfg.client_id):
            body_params["client_id"] = get_actual_client_id_from_development_client_id(
                cfg.client_id
            )
            body_params["redirect_uri"] = DEV_OAUTH_REDIRECT_URL
        # Transformation needed for dev keys END

        _, response_body = await do_post_request(token_endpoint, body_params)
        return response_body

    async def get_user_info(
        self, oauth_tokens: Dict[str, Any], user_context: Dict[str, Any]
    ) -> UserInfo:
        """Collect the user's profile from the id token and/or user-info endpoint.

        Steps:
        1. If an ``id_token`` is present and ``jwks_uri`` is configured, verify
           it and store its payload. Then run ``validate_id_token_payload``, if set.
        2. If an ``access_token`` is present, run ``validate_access_token``, if set.
        3. If an ``access_token`` is present and ``user_info_endpoint`` is
           configured, call that endpoint with a Bearer header, merged with the
           configured headers and query params.
        4. Map the raw data with ``user_info_map``.

        :returns: the mapped ``UserInfo`` carrying the raw provider data.
        """
        access_token: Union[str, None] = oauth_tokens.get("access_token")
        id_token: Union[str, None] = oauth_tokens.get("id_token")
        raw = RawUserInfoFromProvider({}, {})

        if id_token is not None and self.config.jwks_uri is not None:
            audience = get_actual_client_id_from_development_client_id(
                self.config.client_id
            )
            raw.from_id_token_payload = (
                await verify_id_token_from_jwks_endpoint_and_get_payload(
                    id_token, self.config.jwks_uri, audience
                )
            )
            validate_payload = self.config.validate_id_token_payload
            if validate_payload is not None:
                await validate_payload(
                    raw.from_id_token_payload, self.config, user_context
                )

        validate_token = self.config.validate_access_token
        if validate_token is not None and access_token is not None:
            await validate_token(access_token, self.config, user_context)

        if access_token is not None and self.config.user_info_endpoint is not None:
            headers: Dict[str, str] = {"Authorization": f"Bearer {access_token}"}
            params: Dict[str, str] = {}
            if self.config.user_info_endpoint_headers is not None:
                headers = merge_into_dict(self.config.user_info_endpoint_headers, headers)
            if self.config.user_info_endpoint_query_params is not None:
                params = merge_into_dict(
                    self.config.user_info_endpoint_query_params, params
                )
            raw.from_user_info_api = await do_get_request(
                self.config.user_info_endpoint, params, headers
            )

        mapped = get_supertokens_user_info_result_from_raw_user_info(self.config, raw)
        return UserInfo(
            third_party_user_id=mapped.third_party_user_id,
            email=mapped.email,
            raw_user_info_from_provider=raw,
        )
Ancestors
- supertokens_python.recipe.thirdparty.provider.Provider
Subclasses
- ActiveDirectoryImpl
- AppleImpl
- BitbucketImpl
- BoxySAMLImpl
- DiscordImpl
- FacebookImpl
- GithubImpl
- GitlabImpl
- GoogleImpl
- LinkedinImpl
- OktaImpl
- TwitterImpl
Methods
async def exchange_auth_code_for_oauth_tokens(self, redirect_uri_info: RedirectUriInfo, user_context: Dict[str, Any]) ‑> Dict[str, Any]
-
Expand source code
async def exchange_auth_code_for_oauth_tokens(
    self, redirect_uri_info: RedirectUriInfo, user_context: Dict[str, Any]
) -> Dict[str, Any]:
    """Exchange the authorization code for tokens at the token endpoint.

    The POST body contains:
    - ``client_id``, ``redirect_uri``, ``code``, ``grant_type=authorization_code``;
    - ``client_secret`` and ``code_verifier``, when available.

    ``token_endpoint_body_params`` is then merged in; ``None`` values remove
    keys. Development client ids are rewritten to the real client id, and
    ``redirect_uri`` becomes ``DEV_OAUTH_REDIRECT_URL``.

    :returns: the response body of the token endpoint.
    :raises Exception: if ``token_endpoint`` is not configured.
    """
    cfg = self.config
    token_endpoint = cfg.token_endpoint
    if token_endpoint is None:
        raise Exception("ThirdParty provider's tokenEndpoint is not configured.")

    body_params: Dict[str, str] = {
        "client_id": cfg.client_id,
        "redirect_uri": redirect_uri_info.redirect_uri_on_provider_dashboard,
        "code": redirect_uri_info.redirect_uri_query_params["code"],
        "grant_type": "authorization_code",
    }
    optional_params = {
        "client_secret": cfg.client_secret,
        "code_verifier": redirect_uri_info.pkce_code_verifier,
    }
    body_params.update(
        {name: value for name, value in optional_params.items() if value is not None}
    )

    if cfg.token_endpoint_body_params is not None:
        body_params = merge_into_dict(cfg.token_endpoint_body_params, body_params)

    # Transformation needed for dev keys BEGIN
    if is_using_oauth_development_client_id(cfg.client_id):
        body_params["client_id"] = get_actual_client_id_from_development_client_id(
            cfg.client_id
        )
        body_params["redirect_uri"] = DEV_OAUTH_REDIRECT_URL
    # Transformation needed for dev keys END

    _, response_body = await do_post_request(token_endpoint, body_params)
    return response_body
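async def get_authorisation_redirect_url(self, redirect_uri_on_provider_dashboard: str, user_context: Dict[str, Any]) ‑> AuthorisationRedirect
-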
Expand source code
async def get_authorisation_redirect_url(
    self,
    redirect_uri_on_provider_dashboard: str,
    user_context: Dict[str, Any],
) -> AuthorisationRedirect:
    """Build the URL that sends the user to the provider's authorization page.

    Base query parameters:
    - ``client_id``, ``redirect_uri``, ``response_type=code``;
    - ``scope`` (space-joined), if configured;
    - an S256 PKCE challenge, when no client secret is configured or
      ``force_pkce`` is set.

    Entries in ``authorization_endpoint_query_params`` then override these;
    a ``None`` value removes the parameter. Parameters are merged into any
    query string already on the endpoint URL.

    For development client ids, the request goes to
    ``DEV_OAUTH_AUTHORIZATION_URL``, and the real endpoint is passed as
    ``actual_redirect_uri``.

    :returns: the URL and the PKCE code verifier (``None`` without PKCE).
    :raises Exception: if ``authorization_endpoint`` is not configured.
    """
    query_params: Dict[str, str] = {
        "client_id": self.config.client_id,
        "redirect_uri": redirect_uri_on_provider_dashboard,
        "response_type": "code",
    }
    if self.config.scope is not None:
        query_params["scope"] = " ".join(self.config.scope)

    pkce_code_verifier: Union[str, None] = None
    if self.config.client_secret is None or self.config.force_pkce:
        code_verifier, code_challenge = pkce.generate_pkce_pair(64)
        query_params["code_challenge"] = code_challenge
        query_params["code_challenge_method"] = "S256"
        pkce_code_verifier = code_verifier

    if self.config.authorization_endpoint_query_params is not None:
        for k, v in self.config.authorization_endpoint_query_params.items():
            if v is None:
                # A None value means "drop this parameter". It may never have
                # been set above (e.g. "scope" with no scope configured), so
                # remove it without raising KeyError.
                query_params.pop(k, None)
            else:
                query_params[k] = v

    if self.config.authorization_endpoint is None:
        raise Exception(
            "ThirdParty provider's authorizationEndpoint is not configured."
        )
    url: str = self.config.authorization_endpoint

    # Transformation needed for dev keys BEGIN
    if is_using_oauth_development_client_id(self.config.client_id):
        query_params["client_id"] = get_actual_client_id_from_development_client_id(
            self.config.client_id
        )
        query_params["actual_redirect_uri"] = url
        url = DEV_OAUTH_AUTHORIZATION_URL
    # Transformation needed for dev keys END

    # Merge into any query string already present on the URL; our values
    # replace existing entries with the same key.
    url_obj = urlparse(url)
    qparams = parse_qs(url_obj.query)
    for k, v in query_params.items():
        qparams[k] = [v]
    url = url_obj._replace(query=urlencode(qparams, doseq=True)).geturl()

    return AuthorisationRedirect(url, pkce_code_verifier)
async def get_config_for_client_type(self, client_type: Optional[str], user_context: Dict[str, Any]) ‑> ProviderConfigForClient
-
Expand source code
async def get_config_for_client_type(
    self, client_type: Optional[str], user_context: Dict[str, Any]
) -> ProviderConfigForClient:
    """Return the merged config for the client whose type is ``client_type``.

    When ``client_type`` is ``None``, exactly one client must be configured,
    and that client is used.

    :raises ClientTypeNotFoundError: if no suitable client config exists.
    """
    clients = self.input_config.clients

    if client_type is None:
        # Without an explicit client type the choice is only unambiguous when
        # exactly one client is configured.
        if clients is not None and len(clients) == 1:
            return get_provider_config_for_client(self.input_config, clients[0])
        raise ClientTypeNotFoundError(
            "please provide exactly one client config or pass clientType or tenantId"
        )

    matching = next(
        (c for c in (clients or []) if c.client_type == client_type), None
    )
    if matching is None:
        raise ClientTypeNotFoundError(
            f"Could not find client config for clientType: {client_type}"
        )
    return get_provider_config_for_client(self.input_config, matching)
async def get_user_info(self, oauth_tokens: Dict[str, Any], user_context: Dict[str, Any]) ‑> UserInfo
-
Expand source code
async def get_user_info(
    self, oauth_tokens: Dict[str, Any], user_context: Dict[str, Any]
) -> UserInfo:
    """Collect the user's profile from the id token and/or user-info endpoint.

    Steps:
    1. If an ``id_token`` is present and ``jwks_uri`` is configured, verify it
       and store its payload. Then run ``validate_id_token_payload``, if set.
    2. If an ``access_token`` is present, run ``validate_access_token``, if set.
    3. If an ``access_token`` is present and ``user_info_endpoint`` is
       configured, call that endpoint with a Bearer header, merged with the
       configured headers and query params.
    4. Map the raw data with ``user_info_map``.

    :returns: the mapped ``UserInfo`` carrying the raw provider data.
    """
    access_token: Union[str, None] = oauth_tokens.get("access_token")
    id_token: Union[str, None] = oauth_tokens.get("id_token")
    raw = RawUserInfoFromProvider({}, {})

    if id_token is not None and self.config.jwks_uri is not None:
        audience = get_actual_client_id_from_development_client_id(
            self.config.client_id
        )
        raw.from_id_token_payload = (
            await verify_id_token_from_jwks_endpoint_and_get_payload(
                id_token, self.config.jwks_uri, audience
            )
        )
        validate_payload = self.config.validate_id_token_payload
        if validate_payload is not None:
            await validate_payload(raw.from_id_token_payload, self.config, user_context)

    validate_token = self.config.validate_access_token
    if validate_token is not None and access_token is not None:
        await validate_token(access_token, self.config, user_context)

    if access_token is not None and self.config.user_info_endpoint is not None:
        headers: Dict[str, str] = {"Authorization": f"Bearer {access_token}"}
        params: Dict[str, str] = {}
        if self.config.user_info_endpoint_headers is not None:
            headers = merge_into_dict(self.config.user_info_endpoint_headers, headers)
        if self.config.user_info_endpoint_query_params is not None:
            params = merge_into_dict(self.config.user_info_endpoint_query_params, params)
        raw.from_user_info_api = await do_get_request(
            self.config.user_info_endpoint, params, headers
        )

    mapped = get_supertokens_user_info_result_from_raw_user_info(self.config, raw)
    return UserInfo(
        third_party_user_id=mapped.third_party_user_id,
        email=mapped.email,
        raw_user_info_from_provider=raw,
    )