Module supertokens_python.recipe.oauth2provider.api.end_session

Expand source code
# Copyright (c) 2024, VRAI Labs and/or its affiliates. All rights reserved.
#
# This software is licensed under the Apache License, Version 2.0 (the
# "License") as published by the Apache Software Foundation.
#
# You may not use this file except in compliance with the License. You may
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import urllib.parse
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, Optional, Union

from supertokens_python.exceptions import raise_bad_input_exception
from supertokens_python.framework import BaseResponse
from supertokens_python.types.response import GeneralErrorResponse
from supertokens_python.utils import send_200_response, send_non_200_response

from .utils import get_session

if TYPE_CHECKING:
    # Imported only for static type checking (TYPE_CHECKING is False at runtime).
    # The runtime classes that end_session_common needs for isinstance checks
    # are imported inside that function instead.
    from supertokens_python.recipe.session import SessionContainer

    from ..interfaces import (
        APIInterface,
        APIOptions,
        ErrorOAuth2Response,
        RedirectResponse,
    )

    # Signature of the end-session implementation that end_session_common calls:
    #   (params, api_options, session_or_None, should_try_refresh, user_context)
    #   -> awaitable RedirectResponse | ErrorOAuth2Response | GeneralErrorResponse
    # Because the alias exists only under TYPE_CHECKING, it may be used solely in
    # annotations, which `from __future__ import annotations` keeps unevaluated.
    EndSessionCallable = Callable[
        [Dict[str, str], APIOptions, Optional[SessionContainer], bool, Dict[str, Any]],
        Awaitable[Union[RedirectResponse, ErrorOAuth2Response, GeneralErrorResponse]],
    ]


async def end_session_get(
    _tenant_id: str,
    api_implementation: APIInterface,
    api_options: APIOptions,
    user_context: Dict[str, Any],
) -> Optional[BaseResponse]:
    """Handle ``GET`` requests to the OAuth2 end-session endpoint.

    The parameters are taken from the query string of the original request
    URL. They are forwarded to ``end_session_common`` together with
    ``api_implementation.end_session_get``.

    :param _tenant_id: Tenant identifier (not used by this handler).
    :param api_implementation: Provides the ``end_session_get`` callable and
        the ``disable_end_session_get`` flag.
    :param api_options: Carries the framework request and response objects.
    :param user_context: Per-request context forwarded to downstream calls.
    :returns: ``None`` when ``disable_end_session_get`` is ``True``; otherwise
        whatever ``end_session_common`` returns.
    """
    if api_implementation.disable_end_session_get is True:
        return None

    orig_url = api_options.request.get_original_url()
    # The query component runs from the first "?" up to an optional "#"
    # fragment (RFC 3986 section 3.4). Cutting at "#" keeps a fragment from
    # being glued onto the last parameter's value. An absent "?" yields "".
    query = orig_url.partition("?")[2].partition("#")[0]
    # Blank values (e.g. "state=") are kept. For repeated keys, dict() keeps
    # the last occurrence. parse_qsl("") returns [], so params may be {}.
    params = dict(urllib.parse.parse_qsl(query, keep_blank_values=True))

    return await end_session_common(
        params, api_implementation.end_session_get, api_options, user_context
    )


async def end_session_post(
    _tenant_id: str,
    api_implementation: APIInterface,
    api_options: APIOptions,
    user_context: Dict[str, Any],
):
    """Handle ``POST`` requests to the OAuth2 end-session endpoint.

    The parameters come from the request body (JSON or form data). They are
    forwarded to ``end_session_common`` together with
    ``api_implementation.end_session_post``.

    :param _tenant_id: Tenant identifier (not used by this handler).
    :param api_implementation: Provides the ``end_session_post`` callable and
        the ``disable_end_session_post`` flag.
    :param api_options: Carries the framework request and response objects.
    :param user_context: Per-request context forwarded to downstream calls.
    :returns: ``None`` when ``disable_end_session_post`` is ``True``; otherwise
        whatever ``end_session_common`` returns.
    :raises: Whatever ``raise_bad_input_exception`` raises, when
        ``get_json_or_form_data()`` yields ``None``.
    """
    endpoint_disabled = api_implementation.disable_end_session_post is True
    if endpoint_disabled:
        return None

    request = api_options.request
    body = await request.get_json_or_form_data()
    if body is None:
        raise_bad_input_exception("Please provide a JSON body or form data")

    handler = api_implementation.end_session_post
    return await end_session_common(body, handler, api_options, user_context)


async def end_session_common(
    params: Dict[str, str],
    api_implementation: Optional[EndSessionCallable],
    options: APIOptions,
    user_context: Dict[str, Any],
) -> Optional[BaseResponse]:
    """Shared logic behind the GET and POST end-session endpoints.

    Looks up the caller's session (if any) and invokes ``api_implementation``.
    The result is then written to ``options.response`` as follows:

    * ``RedirectResponse`` becomes a redirect to ``redirect_to``.
    * ``ErrorOAuth2Response`` becomes a JSON error body with its
      ``status_code``, defaulting to 400 when that is falsy.
    * Anything else is sent with status 200: a dict as-is, otherwise its
      ``to_json()``.

    :param params: Request parameters (query string or request body).
    :param api_implementation: The end-session callable, or ``None``.
    :param options: Carries the framework request and response objects.
    :param user_context: Per-request context forwarded to downstream calls.
    :returns: ``None`` if ``api_implementation`` is ``None``; otherwise the
        response object produced by the redirect or send helpers.
    """
    # Imported at call time rather than module level (presumably to avoid an
    # import cycle -- TODO confirm).
    from supertokens_python.recipe.session.exceptions import TryRefreshTokenError

    from ..interfaces import ErrorOAuth2Response, RedirectResponse

    if api_implementation is None:
        return None

    try:
        session = await get_session(
            options.request,
            False,
            user_context=user_context,
        )
    except Exception as err:
        # Any failure is treated as "no session": the implementation redirects
        # to the frontend, which is expected to deal with the validation error.
        # Only a TryRefreshTokenError asks the implementation to try a refresh.
        session = None
        should_try_refresh = isinstance(err, TryRefreshTokenError)
    else:
        should_try_refresh = False

    result = await api_implementation(
        params,
        options,
        session,
        should_try_refresh,
        user_context,
    )

    if isinstance(result, RedirectResponse):
        return options.response.redirect(result.redirect_to)

    if isinstance(result, ErrorOAuth2Response):
        error_body = {
            "error": result.error,
            "error_description": result.error_description,
        }
        return send_non_200_response(
            error_body, result.status_code or 400, options.response
        )

    payload = result if isinstance(result, dict) else result.to_json()
    return send_200_response(payload, options.response)

Functions

async def end_session_common(params: Dict[str, str], api_implementation: Optional[EndSessionCallable], options: APIOptions, user_context: Dict[str, Any]) -> Optional[BaseResponse]
async def end_session_get(_tenant_id: str, api_implementation: APIInterface, api_options: APIOptions, user_context: Dict[str, Any])
async def end_session_post(_tenant_id: str, api_implementation: APIInterface, api_options: APIOptions, user_context: Dict[str, Any])