Module supertokens_python.recipe.dashboard.api.users_get

Expand source code
# Copyright (c) 2021, VRAI Labs and/or its affiliates. All rights reserved.
#
# This software is licensed under the Apache License, Version 2.0 (the
# "License") as published by the Apache Software Foundation.
#
# You may not use this file except in compliance with the License. You may
# obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING, Any, Awaitable, List, Dict

from ...usermetadata import UserMetadataRecipe
from ...usermetadata.asyncio import get_user_metadata
from ..interfaces import DashboardUsersGetResponse
from ..utils import UserWithMetadata

if TYPE_CHECKING:
    from supertokens_python.recipe.dashboard.interfaces import (
        APIOptions,
        APIInterface,
    )

from supertokens_python.exceptions import GeneralError, raise_bad_input_exception
from supertokens_python.asyncio import get_users_newest_first, get_users_oldest_first


async def handle_users_get_api(
    api_implementation: APIInterface,
    tenant_id: str,
    api_options: APIOptions,
    user_context: Dict[str, Any],
) -> DashboardUsersGetResponse:
    """Return one page of users for the dashboard, enriched with their names.

    Query parameters read from the request:

    - ``limit`` (required): page size; must parse with ``int()``.
    - ``timeJoinedOrder`` (optional, ``"ASC"`` or ``"DESC"``, default
      ``"DESC"``): selects ``get_users_oldest_first`` or
      ``get_users_newest_first``.
    - ``paginationToken`` (optional): forwarded unchanged.
    - every other query parameter is forwarded as the search ``query``
      (see ``get_search_params_from_url``).

    If ``UserMetadataRecipe.get_instance()`` succeeds, each user's
    ``first_name`` / ``last_name`` metadata entries are copied onto the
    returned ``UserWithMetadata`` objects. Metadata is fetched for at most
    five users concurrently.

    :param api_implementation: unused by this handler.
    :param tenant_id: tenant whose users are listed.
    :param api_options: supplies the incoming request.
    :param user_context: forwarded to the user-listing and metadata calls.
    :returns: the users (with names when available) and the next pagination
        token from the listing call.
    :raises: whatever ``raise_bad_input_exception`` raises when ``limit`` is
        missing or not an integer, or ``timeJoinedOrder`` is invalid.
    """
    _ = api_implementation

    limit = api_options.request.get_query_param("limit")
    if limit is None:
        raise_bad_input_exception("Missing required parameter 'limit'")

    time_joined_order = api_options.request.get_query_param("timeJoinedOrder", "DESC")
    if time_joined_order not in ["ASC", "DESC"]:
        raise_bad_input_exception("Invalid value received for 'timeJoinedOrder'")

    try:
        parsed_limit = int(limit)
    except ValueError:
        # Report a malformed limit as bad input, like the other parameters,
        # instead of letting a bare ValueError escape the handler.
        raise_bad_input_exception("Invalid value received for 'limit'")

    pagination_token = api_options.request.get_query_param("paginationToken")
    query = get_search_params_from_url(api_options.request.get_original_url())

    get_users = (
        get_users_newest_first if time_joined_order == "DESC" else get_users_oldest_first
    )
    users_response = await get_users(
        tenant_id,
        limit=parsed_limit,
        pagination_token=pagination_token,
        query=query,
        user_context=user_context,
    )

    users = users_response.users
    users_with_metadata: List[UserWithMetadata] = [
        UserWithMetadata().from_user(user) for user in users
    ]

    try:
        UserMetadataRecipe.get_instance()
    except GeneralError:
        # The UserMetadata recipe is unavailable, so there are no names to add.
        return DashboardUsersGetResponse(
            users_with_metadata, users_response.next_pagination_token
        )

    async def add_names(user_idx: int) -> None:
        """Copy first/last name metadata onto ``users_with_metadata[user_idx]``."""
        user_metadata = await get_user_metadata(
            users[user_idx].id, user_context=user_context
        )
        metadata = user_metadata.metadata
        users_with_metadata[user_idx].first_name = metadata.get("first_name")
        users_with_metadata[user_idx].last_name = metadata.get("last_name")

    # Query only this many users' metadata in parallel at a time.
    batch_size = 5
    for batch_start in range(0, len(users), batch_size):
        batch_end = min(batch_start + batch_size, len(users))
        # Create only this batch's coroutines, so a failure in an earlier
        # batch cannot leave later coroutines created but never awaited.
        batch: List[Awaitable[None]] = [
            add_names(idx) for idx in range(batch_start, batch_end)
        ]
        await asyncio.gather(*batch)

    return DashboardUsersGetResponse(
        users_with_metadata,
        users_response.next_pagination_token,
    )


def get_search_params_from_url(path: str) -> Dict[str, str]:
    """Extract the search filters from the query string of ``path``.

    ``path`` is appended to a dummy ``https://example.com`` origin before
    parsing, so the query string can be split off it. The pagination and
    ordering parameters ``limit``, ``timeJoinedOrder`` and ``paginationToken``
    are left out. Keys and values are percent-decoded. For a repeated key only
    its first value is kept. Pairs with a blank value are dropped, which is
    the ``parse_qsl`` default.
    """
    from urllib.parse import parse_qsl, urlsplit

    excluded_keys = ("limit", "timeJoinedOrder", "paginationToken")
    query_string = urlsplit("https://example.com" + path).query

    search_query: Dict[str, str] = {}
    for key, value in parse_qsl(query_string):
        if key in excluded_keys:
            continue
        # setdefault keeps the first occurrence of a repeated key.
        search_query.setdefault(key, value)
    return search_query

Functions

def get_search_params_from_url(path: str) -> Dict[str, str]
async def handle_users_get_api(api_implementation: APIInterface, tenant_id: str, api_options: APIOptions, user_context: Dict[str, Any]) -> DashboardUsersGetResponse