Module supertokens_python.recipe.oauth2provider.api.utils

Expand source code
# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved.
#
# This software is licensed under the Apache License, Version 2.0 (the
# "License") as published by the Apache Software Foundation.
#
# You may not use this file except in compliance with the License. You may
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.


from __future__ import annotations

import time
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
from urllib.parse import parse_qs, urlparse

from supertokens_python.recipe.multitenancy.constants import DEFAULT_TENANT_ID
from supertokens_python.recipe.session.interfaces import SessionClaimValidator
from supertokens_python.recipe.session.recipe import SessionRecipe
from supertokens_python.recipe.session.session_request_functions import (
    get_session_from_request,
)
from supertokens_python.types import MaybeAwaitable

from ..constants import AUTH_PATH, END_SESSION_PATH, LOGIN_PATH

if TYPE_CHECKING:
    from supertokens_python.recipe.session.interfaces import SessionContainer
    from supertokens_python.supertokens import AppInfo

    from ..interfaces import (
        ErrorOAuth2Response,
        RecipeInterface,
        RedirectResponse,
    )


async def login_get(
    recipe_implementation: RecipeInterface,
    login_challenge: str,
    session: Optional[SessionContainer],
    should_try_refresh: bool,
    cookies: Optional[List[str]],
    is_direct_call: bool,
    user_context: Dict[str, Any],
) -> Union[RedirectResponse, ErrorOAuth2Response]:
    """Decide how to respond to an OAuth2 login challenge.

    The login request is loaded for ``login_challenge`` and, depending on the
    supplied session and the ``prompt`` / ``st_prompt`` / ``max_age`` /
    ``tenant_id`` query parameters of the original authorization URL, the
    login request is accepted, rejected, or the caller is redirected to a
    frontend page (try-refresh or login).

    Args:
        recipe_implementation: OAuth2 provider recipe implementation used to
            load, accept and reject the login request and to build frontend
            redirection URLs.
        login_challenge: Challenge identifying the login request.
        session: Current session, if any. It is treated as absent when the
            session recipe returns no session information for its handle.
        should_try_refresh: When true (and ``prompt`` is not ``"login"``), a
            failure to accept results in a redirect to the try-refresh URL.
        cookies: Cookies copied unchanged onto every ``RedirectResponse``
            this function builds.
        is_direct_call: When true, ``prompt=login`` and ``max_age=0`` do not
            prevent an existing session from being accepted.
        user_context: User context forwarded to every recipe call.

    Returns:
        A ``RedirectResponse`` pointing at the accept/reject redirect target
        or a frontend URL, or the ``ErrorOAuth2Response`` returned by
        ``get_login_request``.
    """
    # Runtime import: the module-level import of these names is under
    # TYPE_CHECKING only, but they are needed for isinstance and construction.
    from ..interfaces import (
        ErrorOAuth2Response,
        FrontendRedirectionURLTypeLogin,
        FrontendRedirectionURLTypeTryRefresh,
        RedirectResponse,
    )

    login_request = await recipe_implementation.get_login_request(
        challenge=login_challenge,
        user_context=user_context,
    )

    if isinstance(login_request, ErrorOAuth2Response):
        return login_request

    # Look up the session information for the given session (if any); a
    # session without session information is treated as if it were not passed.
    session_info = (
        await SessionRecipe.get_instance().recipe_implementation.get_session_information(
            session.get_handle(), user_context
        )
        if session
        else None
    )
    if not session_info:
        session = None

    # The parameters of the original authorization request are read from the
    # query string of the URL stored on the login request.
    incoming_auth_url_query_params = parse_qs(urlparse(login_request.request_url).query)
    # "prompt" takes precedence; "st_prompt" is only consulted when "prompt"
    # is missing.
    prompt_param = (
        incoming_auth_url_query_params.get("prompt", [None])[0]
        or incoming_auth_url_query_params.get("st_prompt", [None])[0]
    )
    max_age_param = incoming_auth_url_query_params.get("max_age", [None])[0]

    # Validate max_age up front: it must parse as a non-negative integer,
    # otherwise the login request is rejected with "invalid_request".
    if max_age_param is not None:
        try:
            max_age_parsed = int(max_age_param)

            if max_age_parsed < 0:
                reject = await recipe_implementation.reject_login_request(
                    challenge=login_challenge,
                    error=ErrorOAuth2Response(
                        error="invalid_request",
                        error_description="max_age cannot be negative",
                    ),
                    user_context=user_context,
                )
                return RedirectResponse(
                    redirect_to=reject.redirect_to,
                    cookies=cookies,
                )

        except ValueError:
            reject = await recipe_implementation.reject_login_request(
                challenge=login_challenge,
                error=ErrorOAuth2Response(
                    error="invalid_request",
                    error_description="max_age must be an integer",
                ),
                user_context=user_context,
            )
            return RedirectResponse(
                redirect_to=reject.redirect_to,
                cookies=cookies,
            )

    tenant_id_param = incoming_auth_url_query_params.get("tenant_id", [None])[0]

    # Accept the login request with the existing session only when all of the
    # following hold:
    #   - the session (and its information) exists,
    #   - the login request has no subject, or it equals the session's user id,
    #   - no tenant_id was requested, or it equals the session's tenant id,
    #   - prompt is not "login", unless this is a direct call,
    #   - max_age is absent, or is "0" on a direct call, or the session's age
    #     is below max_age. The age is computed in milliseconds from
    #     time.time(); NOTE(review): this presumes session_info.time_created
    #     is a millisecond epoch timestamp - confirm against the session recipe.
    if (
        session
        and session_info
        and (
            not login_request.subject or session.get_user_id() == login_request.subject
        )
        and (not tenant_id_param or session.get_tenant_id() == tenant_id_param)
        and (prompt_param != "login" or is_direct_call)
        and (
            max_age_param is None
            or (max_age_param == "0" and is_direct_call)
            or int(max_age_param) * 1000
            > time.time() * 1000 - session_info.time_created
        )
    ):
        accept = await recipe_implementation.accept_login_request(
            challenge=login_challenge,
            acr=None,
            amr=None,
            context=None,
            extend_session_lifespan=None,
            subject=session.get_user_id(),
            identity_provider_session_id=session.get_handle(),
            user_context=user_context,
        )
        return RedirectResponse(
            redirect_to=accept.redirect_to,
            cookies=cookies,
        )

    # The session could not be used as-is: send the caller to the try-refresh
    # frontend URL when requested, unless a fresh login was explicitly asked for.
    if should_try_refresh and prompt_param != "login":
        return RedirectResponse(
            redirect_to=await recipe_implementation.get_frontend_redirection_url(
                params=FrontendRedirectionURLTypeTryRefresh(
                    login_challenge=login_challenge,
                ),
                user_context=user_context,
            ),
            cookies=cookies,
        )

    # prompt=none: reject with "login_required" instead of redirecting to the
    # frontend login URL.
    if prompt_param == "none":
        reject = await recipe_implementation.reject_login_request(
            challenge=login_challenge,
            error=ErrorOAuth2Response(
                error="login_required",
                error_description="The Authorization Server requires End-User authentication. Prompt 'none' was requested, but no existing or expired login session was found.",
            ),
            user_context=user_context,
        )
        return RedirectResponse(
            redirect_to=reject.redirect_to,
            cookies=cookies,
        )

    # Fall back to the frontend login URL. force_fresh_auth is set when a
    # session exists but was not accepted above, or when prompt=login.
    return RedirectResponse(
        redirect_to=await recipe_implementation.get_frontend_redirection_url(
            params=FrontendRedirectionURLTypeLogin(
                login_challenge=login_challenge,
                force_fresh_auth=session is not None or prompt_param == "login",
                tenant_id=tenant_id_param or DEFAULT_TENANT_ID,
                hint=(
                    login_request.oidc_context.get("login_hint")
                    if login_request.oidc_context
                    else None
                ),
            ),
            user_context=user_context,
        ),
        cookies=cookies,
    )


def get_merged_cookies(orig_cookies: str, new_cookies: Optional[List[str]]) -> str:
    """Overlay ``Set-Cookie`` style values onto a ``Cookie`` header string.

    Args:
        orig_cookies: Existing cookie header, pairs separated by ``;``.
        new_cookies: Set-Cookie style strings; only the leading ``name=value``
            segment (before the first ``;``) of each entry is used.

    Returns:
        ``orig_cookies`` unchanged when ``new_cookies`` is empty or ``None``;
        otherwise the merged ``name=value`` pairs joined with ``;``, where a
        new cookie replaces an existing one of the same name.
    """
    if not new_cookies:
        return orig_cookies

    # Note: this is a simplified parser. Only the name/value pair is looked
    # at; cookie attributes (Path, Expires, ...) are ignored entirely.
    existing_pairs = orig_cookies.split(";")
    incoming_pairs = [entry.split(";")[0].strip() for entry in new_cookies]

    merged: Dict[str, str] = {}
    for pair in existing_pairs + incoming_pairs:
        key, separator, val = pair.partition("=")
        if separator:
            # Later pairs overwrite earlier ones with the same name.
            merged[key.strip()] = val

    return ";".join(f"{key}={val}" for key, val in merged.items())


def merge_set_cookie_headers(
    set_cookie1: Optional[List[str]] = None, set_cookie2: Optional[List[str]] = None
) -> List[str]:
    """Combine two lists of Set-Cookie header values.

    When only one list is non-empty it is returned as-is. When both are
    non-empty and contain the same set of values, the first list is returned;
    otherwise the two lists are concatenated (first, then second).
    """
    if set_cookie1 and set_cookie2:
        # Both present: only concatenate when they actually differ.
        if set(set_cookie1) != set(set_cookie2):
            return set_cookie1 + set_cookie2
        return set_cookie1

    if set_cookie1:
        return set_cookie1

    return set_cookie2 or []


def is_login_internal_redirect(app_info: AppInfo, redirect_to: str) -> bool:
    """Return whether ``redirect_to`` points at this API's login or auth endpoint.

    The check is a prefix match against the API domain plus API base path,
    followed by either ``LOGIN_PATH`` or ``AUTH_PATH``.
    """
    prefix = (
        f"{app_info.api_domain.get_as_string_dangerous()}"
        f"{app_info.api_base_path.get_as_string_dangerous()}"
    )
    internal_targets = tuple(
        f"{prefix}{endpoint}" for endpoint in (LOGIN_PATH, AUTH_PATH)
    )
    # str.startswith accepts a tuple and matches if any prefix matches.
    return redirect_to.startswith(internal_targets)


def is_logout_internal_redirect(app_info: AppInfo, redirect_to: str) -> bool:
    """Return whether ``redirect_to`` points at this API's end-session endpoint.

    The check is a prefix match against the API domain plus API base path
    followed by ``END_SESSION_PATH``.
    """
    domain = app_info.api_domain.get_as_string_dangerous()
    api_root = app_info.api_base_path.get_as_string_dangerous()
    end_session_url = f"{domain}{api_root}{END_SESSION_PATH}"
    return redirect_to.startswith(end_session_url)


async def handle_login_internal_redirects(
    app_info: AppInfo,
    response: RedirectResponse,
    recipe_implementation: RecipeInterface,
    session: Optional[SessionContainer],
    should_try_refresh: bool,
    cookie: str,
    user_context: Dict[str, Any],
) -> Union[RedirectResponse, ErrorOAuth2Response]:
    """Follow redirects to this API's own login/auth endpoints in-process.

    While ``response.redirect_to`` targets the login or auth endpoint of this
    API (see ``is_login_internal_redirect``), the corresponding handler is
    invoked directly instead of returning the redirect to the client. At most
    10 such hops are followed; after that the current response is returned
    as-is.

    Args:
        app_info: App configuration used to recognise internal redirects.
        response: The redirect response to start from.
        recipe_implementation: OAuth2 provider recipe implementation.
        session: Current session, if any; forwarded to each handler.
        should_try_refresh: Forwarded to ``login_get``.
        cookie: Cookie header string; cookies set by each intermediate
            response are merged into it before the next hop.
        user_context: User context forwarded to every call.

    Returns:
        The first response that is not an internal redirect, the response
        reached when the hop limit is hit, or the first
        ``ErrorOAuth2Response`` produced by a handler.

    Raises:
        Exception: If a login redirect carries no login challenge, or an
            internal redirect matches neither the login nor the auth path.
    """
    from ..interfaces import ErrorOAuth2Response, RedirectResponse

    hops_left = 10

    while hops_left > 0 and is_login_internal_redirect(
        app_info, response.redirect_to
    ):
        hops_left -= 1

        target_url = response.redirect_to
        pending_cookies = response.cookies

        # Carry cookies set by the previous hop into the next request.
        cookie = get_merged_cookies(cookie, pending_cookies)

        # Everything after the first "?" (empty when there is none).
        raw_query = target_url.partition("?")[2]
        query = parse_qs(raw_query)

        if LOGIN_PATH in target_url:
            challenge = (
                query.get("login_challenge", [None])[0]
                or query.get("loginChallenge", [None])[0]
            )
            if not challenge:
                raise Exception(f"Expected loginChallenge in {response.redirect_to}")

            hop_result = await login_get(
                recipe_implementation=recipe_implementation,
                login_challenge=challenge,
                session=session,
                should_try_refresh=should_try_refresh,
                cookies=pending_cookies,
                is_direct_call=False,
                user_context=user_context,
            )
        elif AUTH_PATH in target_url:
            # Only the first value of each query parameter is forwarded.
            first_values = {key: values[0] for key, values in query.items()}
            hop_result = await recipe_implementation.authorization(
                params=first_values,
                cookies=cookie,
                session=session,
                user_context=user_context,
            )
        else:
            raise Exception(f"Unexpected internal redirect {response.redirect_to}")

        if isinstance(hop_result, ErrorOAuth2Response):
            return hop_result

        response = RedirectResponse(
            redirect_to=hop_result.redirect_to,
            cookies=merge_set_cookie_headers(hop_result.cookies, pending_cookies),
        )

    return response


async def handle_logout_internal_redirects(
    app_info: AppInfo,
    response: RedirectResponse,
    recipe_implementation: RecipeInterface,
    session: Optional[SessionContainer],
    user_context: Dict[str, Any],
) -> Union[RedirectResponse, ErrorOAuth2Response]:
    """Follow redirects to this API's own end-session endpoint in-process.

    While ``response.redirect_to`` targets the end-session endpoint of this
    API (see ``is_logout_internal_redirect``), ``end_session`` is invoked
    directly instead of returning the redirect to the client. At most 10 such
    hops are followed; after that the current response is returned as-is.

    Args:
        app_info: App configuration used to recognise internal redirects.
        response: The redirect response to start from.
        recipe_implementation: OAuth2 provider recipe implementation.
        session: Current session, if any; forwarded to ``end_session``.
        user_context: User context forwarded to every call.

    Returns:
        The first response that is not an internal logout redirect, the
        response reached when the hop limit is hit, or the first
        ``ErrorOAuth2Response`` returned by ``end_session``.

    Raises:
        Exception: If an internal redirect does not contain the end-session
            path.
    """
    # Runtime import: at module level this name is only imported under
    # TYPE_CHECKING, so without this the isinstance check below raises
    # NameError.
    from ..interfaces import ErrorOAuth2Response

    if not is_logout_internal_redirect(app_info, response.redirect_to):
        return response

    max_redirects = 10
    redirect_count = 0

    while redirect_count < max_redirects and is_logout_internal_redirect(
        app_info, response.redirect_to
    ):
        query_string = (
            response.redirect_to.split("?", 1)[1] if "?" in response.redirect_to else ""
        )
        params = parse_qs(query_string)

        if END_SESSION_PATH in response.redirect_to:
            end_session_res = await recipe_implementation.end_session(
                # Only the first value of each query parameter is forwarded.
                params={k: v[0] for k, v in params.items()},
                session=session,
                should_try_refresh=False,
                user_context=user_context,
            )
            if isinstance(end_session_res, ErrorOAuth2Response):
                return end_session_res
            response = end_session_res
        else:
            raise Exception(f"Unexpected internal redirect {response.redirect_to}")

        redirect_count += 1

    return response


async def get_session(
    request: Any,
    session_required: Optional[bool] = None,
    anti_csrf_check: Optional[bool] = None,
    check_database: Optional[bool] = None,
    override_global_claim_validators: Optional[
        Callable[
            [List[SessionClaimValidator], SessionContainer, Dict[str, Any]],
            MaybeAwaitable[List[SessionClaimValidator]],
        ]
    ] = None,
    user_context: Union[None, Dict[str, Any]] = None,
) -> Union[SessionContainer, None]:
    """Load the session for ``request`` through the Session recipe.

    Delegates to ``get_session_from_request`` using the Session recipe
    instance's config and recipe implementation.

    Args:
        request: The incoming request object.
        session_required: Treated as ``True`` when ``None``.
        anti_csrf_check: Forwarded unchanged.
        check_database: Forwarded unchanged.
        override_global_claim_validators: Forwarded unchanged.
        user_context: Treated as an empty dict when ``None``.

    Returns:
        Whatever ``get_session_from_request`` returns: a session container
        or ``None``.
    """
    effective_context = {} if user_context is None else user_context
    must_have_session = True if session_required is None else session_required

    session_recipe = SessionRecipe.get_instance()

    return await get_session_from_request(
        request,
        session_recipe.config,
        session_recipe.recipe_implementation,
        session_required=must_have_session,
        anti_csrf_check=anti_csrf_check,
        check_database=check_database,
        override_global_claim_validators=override_global_claim_validators,
        user_context=effective_context,
    )

Functions

def get_merged_cookies(orig_cookies: str, new_cookies: Optional[List[str]]) -> str
async def get_session(request: Any, session_required: Optional[bool] = None, anti_csrf_check: Optional[bool] = None, check_database: Optional[bool] = None, override_global_claim_validators: Optional[Callable[[List[SessionClaimValidator], SessionContainer, Dict[str, Any]], MaybeAwaitable[List[SessionClaimValidator]]]] = None, user_context: Union[None, Dict[str, Any]] = None)
async def handle_login_internal_redirects(app_info: AppInfo, response: RedirectResponse, recipe_implementation: RecipeInterface, session: Optional[SessionContainer], should_try_refresh: bool, cookie: str, user_context: Dict[str, Any])
async def handle_logout_internal_redirects(app_info: AppInfo, response: RedirectResponse, recipe_implementation: RecipeInterface, session: Optional[SessionContainer], user_context: Dict[str, Any])
def is_login_internal_redirect(app_info: AppInfo, redirect_to: str)
def is_logout_internal_redirect(app_info: AppInfo, redirect_to: str)
async def login_get(recipe_implementation: RecipeInterface, login_challenge: str, session: Optional[SessionContainer], should_try_refresh: bool, cookies: Optional[List[str]], is_direct_call: bool, user_context: Dict[str, Any])