Module supertokens_python.recipe.dashboard.api.users_get

Expand source code
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
#
# This software is licensed under the Apache License, Version 2.0 (the
# "License") as published by the Apache Software Foundation.
#
# You may not use this file except in compliance with the License. You may
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING, Any, Awaitable, List, Dict
from typing_extensions import Literal

from supertokens_python.supertokens import Supertokens

from ...usermetadata import UserMetadataRecipe
from ...usermetadata.asyncio import get_user_metadata
from ..interfaces import DashboardUsersGetResponse
from ..utils import UserWithMetadata

if TYPE_CHECKING:
    from supertokens_python.recipe.dashboard.interfaces import (
        APIOptions,
        APIInterface,
    )
    from supertokens_python.types import APIResponse

from supertokens_python.exceptions import GeneralError, raise_bad_input_exception


async def handle_users_get_api(
    api_implementation: APIInterface,
    tenant_id: str,
    api_options: APIOptions,
    user_context: Dict[str, Any],
) -> APIResponse:
    """Handle the dashboard "list users" request.

    Reads ``limit`` (required), ``timeJoinedOrder`` (defaults to ``"DESC"``)
    and ``paginationToken`` from the request's query parameters, fetches the
    matching page of users, and then attaches ``first_name`` / ``last_name``
    taken from each user's metadata.

    Args:
        api_implementation: Unused by this handler.
        tenant_id: Tenant passed through to ``Supertokens.get_users``.
        api_options: Provides the incoming request (query parameters).
        user_context: Passed through to ``get_user_metadata``.

    Returns:
        A ``DashboardUsersGetResponse`` carrying the users and the next
        pagination token. When ``UserMetadataRecipe.get_instance()`` raises
        ``GeneralError`` the users are returned as fetched, without the
        metadata-derived name fields.

    Raises:
        Whatever ``raise_bad_input_exception`` raises, when ``limit`` is
        missing or not an integer, or when ``timeJoinedOrder`` is neither
        ``"ASC"`` nor ``"DESC"``.
    """
    _ = api_implementation

    limit = api_options.request.get_query_param("limit")
    if limit is None:
        raise_bad_input_exception("Missing required parameter 'limit'")

    # The limit comes straight from the query string, so a non-numeric value
    # must be reported as bad input rather than escaping as a ValueError.
    try:
        parsed_limit = int(limit)
    except ValueError:
        raise_bad_input_exception("Invalid value for 'limit': expected an integer")

    time_joined_order: Literal["ASC", "DESC"] = api_options.request.get_query_param(  # type: ignore
        "timeJoinedOrder", "DESC"
    )
    if time_joined_order not in ["ASC", "DESC"]:
        raise_bad_input_exception("Invalid value recieved for 'timeJoinedOrder'")

    pagination_token = api_options.request.get_query_param("paginationToken")

    users_response = await Supertokens.get_instance().get_users(
        tenant_id,
        time_joined_order=time_joined_order,
        limit=parsed_limit,
        pagination_token=pagination_token,
        include_recipe_ids=None,
        query=api_options.request.get_query_params(),
    )

    # Without a usable user metadata recipe there is nothing to enrich the
    # users with, so return them as they were fetched.
    try:
        UserMetadataRecipe.get_instance()
    except GeneralError:
        return DashboardUsersGetResponse(
            users_response.users, users_response.next_pagination_token
        )

    users_with_metadata: List[UserWithMetadata] = [
        UserWithMetadata().from_user(user) for user in users_response.users
    ]

    async def get_user_metadata_and_update_user(user_idx: int) -> None:
        """Fetch one user's metadata and copy the name fields onto the result."""
        user = users_response.users[user_idx]
        user_metadata = await get_user_metadata(user.user_id, user_context)

        # None becomes null which is acceptable for the dashboard.
        users_with_metadata[user_idx].first_name = user_metadata.metadata.get(
            "first_name"
        )
        users_with_metadata[user_idx].last_name = user_metadata.metadata.get(
            "last_name"
        )

    # We want to query only 5 in parallel at a time.
    batch_size = 5
    user_count = len(users_response.users)

    for batch_start in range(0, user_count, batch_size):
        batch_end = min(batch_start + batch_size, user_count)
        # Coroutines are created per batch (not all up front) so that a
        # failure in one batch does not leave never-awaited coroutine objects
        # behind for the remaining users.
        batch: List[Awaitable[None]] = [
            get_user_metadata_and_update_user(user_idx)
            for user_idx in range(batch_start, batch_end)
        ]
        await asyncio.gather(*batch)

    return DashboardUsersGetResponse(
        users_with_metadata,
        users_response.next_pagination_token,
    )

Functions

async def handle_users_get_api(api_implementation: APIInterface, tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any]) ‑> APIResponse
Expand source code
async def handle_users_get_api(
    api_implementation: APIInterface,
    tenant_id: str,
    api_options: APIOptions,
    user_context: Dict[str, Any],
) -> APIResponse:
    """Handle the dashboard "list users" request.

    Reads ``limit`` (required), ``timeJoinedOrder`` (defaults to ``"DESC"``)
    and ``paginationToken`` from the request's query parameters, fetches the
    matching page of users, and then attaches ``first_name`` / ``last_name``
    taken from each user's metadata.

    Args:
        api_implementation: Unused by this handler.
        tenant_id: Tenant passed through to ``Supertokens.get_users``.
        api_options: Provides the incoming request (query parameters).
        user_context: Passed through to ``get_user_metadata``.

    Returns:
        A ``DashboardUsersGetResponse`` carrying the users and the next
        pagination token. When ``UserMetadataRecipe.get_instance()`` raises
        ``GeneralError`` the users are returned as fetched, without the
        metadata-derived name fields.

    Raises:
        Whatever ``raise_bad_input_exception`` raises, when ``limit`` is
        missing or not an integer, or when ``timeJoinedOrder`` is neither
        ``"ASC"`` nor ``"DESC"``.
    """
    _ = api_implementation

    limit = api_options.request.get_query_param("limit")
    if limit is None:
        raise_bad_input_exception("Missing required parameter 'limit'")

    # The limit comes straight from the query string, so a non-numeric value
    # must be reported as bad input rather than escaping as a ValueError.
    try:
        parsed_limit = int(limit)
    except ValueError:
        raise_bad_input_exception("Invalid value for 'limit': expected an integer")

    time_joined_order: Literal["ASC", "DESC"] = api_options.request.get_query_param(  # type: ignore
        "timeJoinedOrder", "DESC"
    )
    if time_joined_order not in ["ASC", "DESC"]:
        raise_bad_input_exception("Invalid value recieved for 'timeJoinedOrder'")

    pagination_token = api_options.request.get_query_param("paginationToken")

    users_response = await Supertokens.get_instance().get_users(
        tenant_id,
        time_joined_order=time_joined_order,
        limit=parsed_limit,
        pagination_token=pagination_token,
        include_recipe_ids=None,
        query=api_options.request.get_query_params(),
    )

    # Without a usable user metadata recipe there is nothing to enrich the
    # users with, so return them as they were fetched.
    try:
        UserMetadataRecipe.get_instance()
    except GeneralError:
        return DashboardUsersGetResponse(
            users_response.users, users_response.next_pagination_token
        )

    users_with_metadata: List[UserWithMetadata] = [
        UserWithMetadata().from_user(user) for user in users_response.users
    ]

    async def get_user_metadata_and_update_user(user_idx: int) -> None:
        """Fetch one user's metadata and copy the name fields onto the result."""
        user = users_response.users[user_idx]
        user_metadata = await get_user_metadata(user.user_id, user_context)

        # None becomes null which is acceptable for the dashboard.
        users_with_metadata[user_idx].first_name = user_metadata.metadata.get(
            "first_name"
        )
        users_with_metadata[user_idx].last_name = user_metadata.metadata.get(
            "last_name"
        )

    # We want to query only 5 in parallel at a time.
    batch_size = 5
    user_count = len(users_response.users)

    for batch_start in range(0, user_count, batch_size):
        batch_end = min(batch_start + batch_size, user_count)
        # Coroutines are created per batch (not all up front) so that a
        # failure in one batch does not leave never-awaited coroutine objects
        # behind for the remaining users.
        batch: List[Awaitable[None]] = [
            get_user_metadata_and_update_user(user_idx)
            for user_idx in range(batch_start, batch_end)
        ]
        await asyncio.gather(*batch)

    return DashboardUsersGetResponse(
        users_with_metadata,
        users_response.next_pagination_token,
    )