Module supertokens_python.recipe.thirdparty.providers.saml
Expand source code
# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved.
#
# This software is licensed under the Apache License, Version 2.0 (the
# "License") as published by the Apache Software Foundation.
#
# You may not use this file except in compliance with the License. You may
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import Any, Dict, Union
from ..provider import (
AuthorisationRedirect,
Provider,
ProviderInput,
RedirectUriInfo,
)
from ..types import RawUserInfoFromProvider, UserInfo, UserInfoEmail
from .custom import GenericProvider, NewProvider
class SAMLProviderImpl(GenericProvider):
    """Third-party provider whose sign-in flow is backed by the SAML recipe.

    The authorisation redirect points at the SAML recipe's login endpoint on
    this API, and user info is fetched through the SAML recipe implementation
    rather than through an OAuth token exchange.
    """

    def __init__(self, provider_config: Any):
        """Initialise the generic provider and tag it as a SAML provider."""
        super().__init__(provider_config)
        self.provider_type = "saml"

    async def get_authorisation_redirect_url(
        self,
        redirect_uri_on_provider_dashboard: str,
        user_context: Dict[str, Any],
    ) -> AuthorisationRedirect:
        """Build the redirect to the SAML recipe's login endpoint.

        Args:
            redirect_uri_on_provider_dashboard: Sent as the ``redirect_uri``
                query parameter.
            user_context: Not used by this method.

        Returns:
            An ``AuthorisationRedirect`` carrying the login URL and no PKCE
            code verifier.
        """
        from urllib.parse import urlencode

        from supertokens_python.supertokens import Supertokens

        app_info = Supertokens.get_instance().app_info

        # Falls back to "public" when no ``tenant_id`` attribute has been set
        # on this provider instance.
        tenant = getattr(self, "tenant_id", "public")

        # The tenant id is part of the path so the SAML handler receives it.
        login_endpoint = (
            app_info.api_domain.get_as_string_dangerous()
            + app_info.api_base_path.get_as_string_dangerous()
            + f"/{tenant}/saml/login"
        )
        encoded_query = urlencode(
            {
                "client_id": self.config.client_id,
                "redirect_uri": redirect_uri_on_provider_dashboard,
            }
        )

        return AuthorisationRedirect(
            url_with_query_params=f"{login_endpoint}?{encoded_query}",
            pkce_code_verifier=None,
        )

    async def exchange_auth_code_for_oauth_tokens(
        self, redirect_uri_info: RedirectUriInfo, user_context: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Always raise: SAML providers have no auth-code exchange step.

        Raises:
            Exception: Unconditionally.
        """
        raise Exception(
            "SAML providers do not support exchangeAuthCodeForOAuthTokens. "
            "The thirdparty sign-in-up flow handles SAML token extraction directly."
        )

    async def get_user_info(
        self, oauth_tokens: Dict[str, Any], user_context: Dict[str, Any]
    ) -> UserInfo:
        """Resolve the signed-in user through the SAML recipe.

        Args:
            oauth_tokens: Read for ``"access_token"`` (default ``""``) and
                ``"_tenant_id"`` (default ``"public"``).
            user_context: Forwarded unchanged to the SAML recipe.

        Returns:
            A ``UserInfo`` built from the recipe result. When the result has
            an email it is reported as verified; the result's claims are
            exposed as the raw user-info-API payload.

        Raises:
            Exception: If the recipe does not return a ``GetUserInfoOkResult``.
        """
        from supertokens_python.recipe.saml.recipe import SAMLRecipe

        token = oauth_tokens.get("access_token", "")
        configured_client_id = self.config.client_id
        recipe = SAMLRecipe.get_instance()

        # tenant_id is passed through oauth_tokens by the sign_in_up_post handler
        tenant = oauth_tokens.get("_tenant_id", "public")

        lookup = await recipe.recipe_implementation.get_user_info(
            tenant_id=tenant,
            access_token=token,
            client_id=configured_client_id,
            user_context=user_context,
        )

        from supertokens_python.recipe.saml.types import GetUserInfoOkResult

        if not isinstance(lookup, GetUserInfoOkResult):
            raise Exception("Failed to get user info from SAML provider")

        verified_email: Union[UserInfoEmail, None] = (
            UserInfoEmail(email=lookup.email, is_verified=True)
            if lookup.email
            else None
        )
        provider_payload = RawUserInfoFromProvider(
            from_id_token_payload=None,
            from_user_info_api=lookup.claims,
        )
        return UserInfo(
            third_party_user_id=lookup.sub,
            email=verified_email,
            raw_user_info_from_provider=provider_payload,
        )
def SAML(input: ProviderInput) -> Provider:  # pylint: disable=redefined-builtin
    """Create a SAML-backed third-party provider from ``input``.

    If the supplied config has no (truthy) name, it is set to ``"SAML"`` on
    the passed-in config object before the provider is built.

    Args:
        input: Provider input whose ``config`` is read and possibly updated.

    Returns:
        The result of ``NewProvider(input, SAMLProviderImpl)``.
    """
    provider_config = input.config
    if not provider_config.name:
        provider_config.name = "SAML"

    return NewProvider(input, SAMLProviderImpl)
Functions
def SAML(input: ProviderInput) ‑> Provider
Classes
class SAMLProviderImpl (provider_config: Any)-
Expand source code
class SAMLProviderImpl(GenericProvider): def __init__(self, provider_config: Any): super().__init__(provider_config) self.provider_type = "saml" async def get_authorisation_redirect_url( self, redirect_uri_on_provider_dashboard: str, user_context: Dict[str, Any], ) -> AuthorisationRedirect: from urllib.parse import urlencode from supertokens_python.supertokens import Supertokens st_instance = Supertokens.get_instance() app_info = st_instance.app_info tenant_id = getattr(self, "tenant_id", "public") # Build URL to the SAML recipe's login endpoint # Include tenant_id in the path so the SAML handler receives it saml_login_url = ( app_info.api_domain.get_as_string_dangerous() + app_info.api_base_path.get_as_string_dangerous() + f"/{tenant_id}/saml/login" ) query_params: Dict[str, str] = { "client_id": self.config.client_id, "redirect_uri": redirect_uri_on_provider_dashboard, } url = f"{saml_login_url}?{urlencode(query_params)}" return AuthorisationRedirect(url_with_query_params=url, pkce_code_verifier=None) async def exchange_auth_code_for_oauth_tokens( self, redirect_uri_info: RedirectUriInfo, user_context: Dict[str, Any] ) -> Dict[str, Any]: raise Exception( "SAML providers do not support exchangeAuthCodeForOAuthTokens. " "The thirdparty sign-in-up flow handles SAML token extraction directly." 
) async def get_user_info( self, oauth_tokens: Dict[str, Any], user_context: Dict[str, Any] ) -> UserInfo: from supertokens_python.recipe.saml.recipe import SAMLRecipe access_token = oauth_tokens.get("access_token", "") client_id = self.config.client_id # Use the SAML recipe to get user info saml_recipe = SAMLRecipe.get_instance() # tenant_id is passed through oauth_tokens by the sign_in_up_post handler tenant_id = oauth_tokens.get("_tenant_id", "public") result = await saml_recipe.recipe_implementation.get_user_info( tenant_id=tenant_id, access_token=access_token, client_id=client_id, user_context=user_context, ) from supertokens_python.recipe.saml.types import GetUserInfoOkResult if isinstance(result, GetUserInfoOkResult): email_info: Union[UserInfoEmail, None] = None if result.email: email_info = UserInfoEmail(email=result.email, is_verified=True) raw_user_info = RawUserInfoFromProvider( from_id_token_payload=None, from_user_info_api=result.claims, ) return UserInfo( third_party_user_id=result.sub, email=email_info, raw_user_info_from_provider=raw_user_info, ) raise Exception("Failed to get user info from SAML provider")Ancestors
Methods
async def exchange_auth_code_for_oauth_tokens(self, redirect_uri_info: RedirectUriInfo, user_context: Dict[str, Any]) ‑> Dict[str, Any]
async def get_user_info(self, oauth_tokens: Dict[str, Any], user_context: Dict[str, Any]) ‑> UserInfo