Module supertokens_python.framework.fastapi.fastapi_middleware

Expand source code
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
#
# This software is licensed under the Apache License, Version 2.0 (the
# "License") as published by the Apache Software Foundation.
#
# You may not use this file except in compliance with the License. You may
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from typing import Union


def get_middleware():
    """Build the SuperTokens ASGI middleware class for FastAPI/Starlette apps.

    Starlette and SuperTokens are imported lazily inside this function, so that
    importing this module alone does not pull them in.

    Returns:
        ``ASGIMiddleware``: a class whose constructor takes the ASGI app to wrap.
        For every HTTP request an instance:

        * lets ``Supertokens.middleware`` try to handle the request; if it
          returns a response, applies the session's response mutators to it
          (when a session is attached) and sends it;
        * otherwise forwards the request to the wrapped app and, when a
          ``SessionContainer`` is found on ``request.state.supertokens``,
          rewrites the headers of the app's ``http.response.start`` message via
          ``manage_session_post_response``;
        * turns a ``SuperTokensError`` raised in either step into a response via
          ``Supertokens.handle_supertokens_error``.

        Non-HTTP scopes are passed through untouched. If no response can be
        sent, ``Exception("Should never come here")`` is raised.
    """
    from supertokens_python import Supertokens
    from supertokens_python.utils import default_user_context
    from supertokens_python.exceptions import SuperTokensError
    from supertokens_python.framework import BaseResponse
    from supertokens_python.recipe.session import SessionContainer
    from supertokens_python.supertokens import manage_session_post_response

    from starlette.requests import Request
    from starlette.responses import Response
    from starlette.types import ASGIApp, Message, Receive, Scope, Send

    from supertokens_python.framework.fastapi.fastapi_request import (
        FastApiRequest,
    )
    from supertokens_python.framework.fastapi.fastapi_response import (
        FastApiResponse,
    )

    class ASGIMiddleware:
        """ASGI middleware wiring SuperTokens into request/response handling."""

        def __init__(self, app: ASGIApp) -> None:
            self.app = app

        async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
            # Only HTTP requests are handled; anything else is passed through.
            if scope["type"] != "http":
                await self.app(scope, receive, send)
                return

            st = Supertokens.get_instance()

            request = Request(scope, receive=receive)
            custom_request = FastApiRequest(request)
            user_context = default_user_context(custom_request)

            def get_session() -> Union[SessionContainer, None]:
                # The session, if any, is read from `request.state.supertokens`;
                # presumably it is attached there by SuperTokens' session handling
                # during this request. Anything else found there is ignored.
                session = getattr(request.state, "supertokens", None)
                return session if isinstance(session, SessionContainer) else None

            try:
                response = FastApiResponse(Response())
                result: Union[BaseResponse, None] = await st.middleware(
                    custom_request, response, user_context
                )

                if result is None:
                    # SuperTokens did not handle the request. Forward it to the
                    # wrapped app, but apply the header changes required by the
                    # session's response mutators to the app's response.
                    async def send_wrapper(message: Message) -> None:
                        if message["type"] == "http.response.start":
                            session = get_session()
                            if session is not None:
                                # ASGI makes `headers` optional and allows any
                                # iterable of pairs; copy into a list because the
                                # mutators edit the header list in place.
                                start_response = Response()
                                start_response.raw_headers = list(
                                    message.get("headers", [])
                                )
                                manage_session_post_response(
                                    session,
                                    FastApiResponse(start_response),
                                    user_context,
                                )
                                message["headers"] = start_response.raw_headers

                        # The start message (possibly with updated headers) and all
                        # other messages are forwarded as-is.
                        await send(message)

                    await self.app(scope, receive, send_wrapper)
                    return

                # SuperTokens handled the request: apply the session's mutators to
                # its response and send that response.
                session = get_session()
                if session is not None:
                    manage_session_post_response(session, result, user_context)

                if isinstance(result, FastApiResponse):
                    await result.response(scope, receive, send)
                    return
                # Any other response type cannot be sent from here; fall through to
                # the explicit error below rather than returning without a response.

            except SuperTokensError as e:
                error_response = FastApiResponse(Response())
                result = await st.handle_supertokens_error(
                    custom_request, e, error_response, user_context
                )
                if isinstance(result, FastApiResponse):
                    await result.response(scope, receive, send)
                    return

            raise Exception("Should never come here")

    return ASGIMiddleware

Functions

def get_middleware()