from typing import (
    Any,
    Awaitable,
    Callable,
    cast,
    Dict,
    Sequence,
    Optional,
    Type,
    TypeVar,
    Tuple,
)
import fastapi
import orjson
from anyio import (
    to_thread,
    CapacityLimiter,
)
from fastapi import FastAPI as _FastAPI, Response, Request
from fastapi.openapi.utils import get_openapi
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import ORJSONResponse
from fastapi.routing import APIRoute
from fastapi import HTTPException, status
from functools import wraps

from chromadb.api.configuration import CollectionConfigurationInternal
from pydantic import BaseModel
from chromadb.api.types import (
    Embedding,
    GetResult,
    QueryResult,
    Embeddings,
    convert_list_embeddings_to_np,
)
from chromadb.auth import UserIdentity
from chromadb.auth import (
    AuthzAction,
    AuthzResource,
    ServerAuthenticationProvider,
    ServerAuthorizationProvider,
)
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.api import ServerAPI
from chromadb.errors import (
    ChromaError,
    InvalidDimensionException,
    InvalidHTTPVersion,
    RateLimitError,
    QuotaError,
)
from chromadb.quota import QuotaEnforcer
from chromadb.rate_limit import AsyncRateLimitEnforcer
from chromadb.server import Server
from chromadb.server.fastapi.types import (
    AddEmbedding,
    CreateDatabase,
    CreateTenant,
    DeleteEmbedding,
    GetEmbedding,
    QueryEmbedding,
    CreateCollection,
    UpdateCollection,
    UpdateEmbedding,
)
from starlette.datastructures import Headers
import logging
import importlib.metadata

from chromadb.telemetry.product.events import ServerStartEvent
from chromadb.utils.fastapi import fastapi_json_response, string_to_uuid as _uuid
from opentelemetry import trace

from chromadb.telemetry.opentelemetry.fastapi import instrument_fastapi
from chromadb.types import Database, Tenant
from chromadb.telemetry.product import ServerContext, ProductTelemetryClient
from chromadb.telemetry.opentelemetry import (
    OpenTelemetryClient,
    OpenTelemetryGranularity,
    add_attributes_to_current_span,
    trace_method,
)
from chromadb.types import Collection as CollectionModel

# Module-level logger; catch_exceptions_middleware uses it to record
# unexpected exceptions before returning a 500 response.
logger = logging.getLogger(__name__)


def rate_limit(func: Callable[..., Awaitable[Any]]) -> Callable[..., Awaitable[Any]]:
    """Route every call of an async method through its instance's rate limiter.

    The enforcer is fetched from ``args[0]._async_rate_limit_enforcer`` on each
    call (not at decoration time), so the attribute only has to exist by the
    time the decorated method is first invoked.
    """

    @wraps(func)
    async def _limited(*args: Any, **kwargs: Any) -> Any:
        instance = args[0]
        enforcer = instance._async_rate_limit_enforcer
        guarded = enforcer.rate_limit(func)
        return await guarded(*args, **kwargs)

    return _limited


def use_route_names_as_operation_ids(app: _FastAPI) -> None:
    """
    Set each API route's operation ID to the route's plain name.

    Generated API clients name their functions after operation IDs, so this
    keeps those names short. Call this only after all routes have been added.
    """
    api_routes = (route for route in app.routes if isinstance(route, APIRoute))
    for api_route in api_routes:
        api_route.operation_id = api_route.name


async def add_trace_id_to_response_middleware(
    request: Request, call_next: Callable[[Request], Any]
) -> Response:
    """HTTP middleware that returns the current OpenTelemetry trace ID to the
    client in the ``Chroma-Trace-Id`` response header, as lowercase hex.

    The trace ID is read before the downstream handler runs.
    """
    span_context = trace.get_current_span().get_span_context()
    trace_id = span_context.trace_id
    response = await call_next(request)
    response.headers["Chroma-Trace-Id"] = f"{trace_id:x}"
    return response


async def catch_exceptions_middleware(
    request: Request, call_next: Callable[[Request], Any]
) -> Response:
    """HTTP middleware converting exceptions from downstream into JSON errors.

    - ``ChromaError``: serialized with ``fastapi_json_response``.
    - ``ValueError`` / ``TypeError``: HTTP 400, ``"InvalidArgumentError"``.
    - any other ``Exception``: logged, then HTTP 500 with ``repr`` of the error.
    """
    try:
        return await call_next(request)
    except ChromaError as exc:
        return fastapi_json_response(exc)
    except (ValueError, TypeError) as exc:
        # Bad argument values/types are reported as client errors.
        return ORJSONResponse(
            content={"error": "InvalidArgumentError", "message": str(exc)},
            status_code=400,
        )
    except Exception as exc:
        logger.exception(exc)
        return ORJSONResponse(content={"error": repr(exc)}, status_code=500)


async def check_http_version_middleware(
    request: Request, call_next: Callable[[Request], Any]
) -> Response:
    """HTTP middleware that lets only HTTP/1.1 and HTTP/2 requests through.

    Raises:
        InvalidHTTPVersion: for any other ``http_version`` in the ASGI scope.
    """
    http_version = request.scope.get("http_version")
    if http_version in ("1.1", "2"):
        return await call_next(request)
    raise InvalidHTTPVersion(f"HTTP version {http_version} is not supported")


# Type variable for Pydantic model classes accepted by validate_model and
# FastAPI.get_openapi_extras_for_body_model.
D = TypeVar("D", bound=BaseModel, contravariant=True)


def validate_model(model: Type[D], data: Any) -> D:  # type: ignore
    """Validate ``data`` against the Pydantic model class ``model``.

    Used for backward compatibility with Pydantic 1.x: 2.x models provide
    ``model_validate``, while 1.x models only provide ``parse_obj``.

    The Pydantic flavour is detected with ``hasattr`` rather than by catching
    ``AttributeError`` around the call. That way an ``AttributeError`` raised
    *inside* validation (for example by a custom validator) propagates
    instead of being mistaken for "this is Pydantic 1.x" and retried.

    Raises:
        Whatever validation error the selected Pydantic method raises.
    """
    if hasattr(model, "model_validate"):  # pydantic 2.x
        return model.model_validate(data)
    return model.parse_obj(data)  # pydantic 1.x


class ChromaAPIRouter(fastapi.APIRouter):  # type: ignore
    """A ``fastapi.APIRouter`` that serves each route with and without a
    trailing "/".

    Only the variant without the trailing "/" appears in the OpenAPI docs.
    If the caller passes ``include_in_schema=False``, neither variant does.
    """

    def add_api_route(self, path: str, *args: Any, **kwargs: Any) -> None:
        # Absent or truthy include_in_schema -> document the non-"/" variant;
        # an explicit False hides both variants.
        exclude_from_schema = not kwargs.get("include_in_schema", True)

        def include_in_schema(route_path: str) -> bool:
            return not exclude_from_schema and not route_path.endswith("/")

        kwargs["include_in_schema"] = include_in_schema(path)
        super().add_api_route(path, *args, **kwargs)

        if path == "/":
            # Stripping the slash from the root would register an empty-string
            # path (which include_in_schema would also put in the docs); the
            # root has no alternate form.
            return

        alternate_path = path[:-1] if path.endswith("/") else path + "/"
        kwargs["include_in_schema"] = include_in_schema(alternate_path)
        super().add_api_route(alternate_path, *args, **kwargs)


class FastAPI(Server):
    def __init__(self, settings: Settings) -> None:
        """
        Assemble the Chroma HTTP server around a FastAPI application.

        This method:
        - builds and starts the component ``System`` described by ``settings``;
        - installs middleware and exception handlers;
        - resolves the authentication/authorization providers, if configured;
        - registers the v1 and v2 routes;
        - captures a ``ServerStartEvent`` via product telemetry.
        """
        ProductTelemetryClient.SERVER_CONTEXT = ServerContext.FASTAPI
        # https://fastapi.tiangolo.com/advanced/custom-response/#use-orjsonresponse
        # NOTE(review): debug=True is hard-coded; confirm this is intended
        # outside development, since Starlette's debug mode renders tracebacks
        # for errors that escape the middleware stack.
        self._app = fastapi.FastAPI(debug=True, default_response_class=ORJSONResponse)
        self._system = System(settings)
        self._api: ServerAPI = self._system.instance(ServerAPI)

        # Filled by get_openapi_extras_for_body_model and merged into the
        # document by generate_openapi, which replaces FastAPI's default
        # openapi() generator.
        self._extra_openapi_schemas: Dict[str, Any] = {}
        self._app.openapi = self.generate_openapi

        self._opentelemetry_client = self._api.require(OpenTelemetryClient)
        # Shared bound on worker threads for handlers that offload blocking
        # work with to_thread.run_sync(..., limiter=self._capacity_limiter).
        self._capacity_limiter = CapacityLimiter(
            settings.chroma_server_thread_pool_size
        )
        self._quota_enforcer = self._system.require(QuotaEnforcer)
        self._system.start()

        # NOTE(review): Starlette applies middleware in reverse registration
        # order (last added = outermost). That would make CORS the outermost
        # layer. It would also run check_http_version_middleware inside
        # catch_exceptions_middleware, so an InvalidHTTPVersion it raises is
        # handled there. Confirm before reordering these calls.
        self._app.middleware("http")(check_http_version_middleware)
        self._app.middleware("http")(catch_exceptions_middleware)
        self._app.middleware("http")(add_trace_id_to_response_middleware)
        self._app.add_middleware(
            CORSMiddleware,
            allow_headers=["*"],
            allow_origins=settings.chroma_server_cors_allow_origins,
            allow_methods=["*"],
        )
        self._app.add_exception_handler(QuotaError, self.quota_exception_handler)
        self._app.add_exception_handler(
            RateLimitError, self.rate_limit_exception_handler
        )
        # Looked up at call time by the @rate_limit decorator on handler methods.
        self._async_rate_limit_enforcer = self._system.require(AsyncRateLimitEnforcer)

        self._app.on_event("shutdown")(self.shutdown)

        # Both providers are optional. When one is left as None,
        # sync_auth_request skips the corresponding step.
        self.authn_provider = None
        if settings.chroma_server_authn_provider:
            self.authn_provider = self._system.require(ServerAuthenticationProvider)

        self.authz_provider = None
        if settings.chroma_server_authz_provider:
            self.authz_provider = self._system.require(ServerAuthorizationProvider)

        # Registers every path both with and without a trailing "/".
        self.router = ChromaAPIRouter()

        self.setup_v1_routes()
        self.setup_v2_routes()

        self._app.include_router(self.router)

        # Must run only after every route has been added.
        use_route_names_as_operation_ids(self._app)
        instrument_fastapi(self._app)
        telemetry_client = self._system.instance(ProductTelemetryClient)
        telemetry_client.capture(ServerStartEvent())

    def generate_openapi(self) -> Dict[str, Any]:
        """Build the app's OpenAPI document.

        Used instead of FastAPI's default ``openapi()`` generation so that the
        manually populated schemas in ``self._extra_openapi_schemas`` (collected
        by ``get_openapi_extras_for_body_model``) are added under
        ``components.schemas``.
        """
        schema: Dict[str, Any] = get_openapi(
            title="Chroma",
            routes=self._app.routes,
            version=importlib.metadata.version("chromadb"),
        )

        if self._extra_openapi_schemas:
            # get_openapi leaves out "components" when it has nothing to put
            # there, so create the containers rather than assuming they exist.
            component_schemas = schema.setdefault("components", {}).setdefault(
                "schemas", {}
            )
            component_schemas.update(self._extra_openapi_schemas)

        return schema

    def get_openapi_extras_for_body_model(
        self, request_model: Type[D]
    ) -> Dict[str, Any]:
        """Return an ``openapi_extra`` dict declaring a required JSON request
        body whose schema is ``request_model``'s JSON schema.

        Nested definitions that Pydantic emits under ``$defs`` are stored in
        ``self._extra_openapi_schemas`` as a side effect. The body schema
        points at them through ``#/components/schemas/{model}`` refs.
        """
        body_schema = request_model.model_json_schema(
            ref_template="#/components/schemas/{model}"
        )
        nested_definitions = body_schema.get("$defs", {})
        self._extra_openapi_schemas.update(nested_definitions)

        json_content = {"application/json": {"schema": body_schema}}
        return {"requestBody": {"content": json_content, "required": True}}

    def setup_v2_routes(self) -> None:
        """
        Register the ``/api/v2`` endpoints on ``self.router``.

        ``ChromaAPIRouter`` additionally registers each path with a trailing
        "/". The handlers read and validate JSON bodies themselves (from
        ``request.body()``), so request-body schemas are documented through
        ``openapi_extra`` built by ``get_openapi_extras_for_body_model``.

        NOTE(review): Starlette matches routes in registration order, so be
        careful when reordering or inserting paths whose templates overlap.
        """
        # Server-level endpoints.
        self.router.add_api_route("/api/v2", self.root, methods=["GET"])
        self.router.add_api_route("/api/v2/reset", self.reset, methods=["POST"])
        self.router.add_api_route("/api/v2/version", self.version, methods=["GET"])
        self.router.add_api_route("/api/v2/heartbeat", self.heartbeat, methods=["GET"])
        self.router.add_api_route(
            "/api/v2/pre-flight-checks", self.pre_flight_checks, methods=["GET"]
        )

        self.router.add_api_route(
            "/api/v2/auth/identity",
            self.get_user_identity,
            methods=["GET"],
            response_model=None,
        )

        # Tenants and databases.
        self.router.add_api_route(
            "/api/v2/tenants/{tenant}/databases",
            self.create_database,
            methods=["POST"],
            response_model=None,
            openapi_extra=self.get_openapi_extras_for_body_model(CreateDatabase),
        )

        self.router.add_api_route(
            "/api/v2/tenants/{tenant}/databases/{database_name}",
            self.get_database,
            methods=["GET"],
            response_model=None,
        )

        self.router.add_api_route(
            "/api/v2/tenants/{tenant}/databases/{database_name}",
            self.delete_database,
            methods=["DELETE"],
            response_model=None,
        )

        self.router.add_api_route(
            "/api/v2/tenants",
            self.create_tenant,
            methods=["POST"],
            response_model=None,
            openapi_extra=self.get_openapi_extras_for_body_model(CreateTenant),
        )

        self.router.add_api_route(
            "/api/v2/tenants/{tenant}",
            self.get_tenant,
            methods=["GET"],
            response_model=None,
        )

        self.router.add_api_route(
            "/api/v2/tenants/{tenant}/databases",
            self.list_databases,
            methods=["GET"],
            response_model=None,
        )

        # Collections within a database.
        self.router.add_api_route(
            "/api/v2/tenants/{tenant}/databases/{database_name}/collections",
            self.list_collections,
            methods=["GET"],
            response_model=None,
        )
        self.router.add_api_route(
            "/api/v2/tenants/{tenant}/databases/{database_name}/collections_count",
            self.count_collections,
            methods=["GET"],
            response_model=None,
        )
        self.router.add_api_route(
            "/api/v2/tenants/{tenant}/databases/{database_name}/collections",
            self.create_collection,
            methods=["POST"],
            response_model=None,
            openapi_extra=self.get_openapi_extras_for_body_model(CreateCollection),
        )

        # Record operations on a single collection, addressed by id.
        self.router.add_api_route(
            "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/add",
            self.add,
            methods=["POST"],
            status_code=status.HTTP_201_CREATED,
            response_model=None,
            openapi_extra=self.get_openapi_extras_for_body_model(AddEmbedding),
        )
        self.router.add_api_route(
            "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/update",
            self.update,
            methods=["POST"],
            response_model=None,
            openapi_extra=self.get_openapi_extras_for_body_model(UpdateEmbedding),
        )
        self.router.add_api_route(
            "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/upsert",
            self.upsert,
            methods=["POST"],
            response_model=None,
            openapi_extra=self.get_openapi_extras_for_body_model(AddEmbedding),
        )
        self.router.add_api_route(
            "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/get",
            self.get,
            methods=["POST"],
            response_model=None,
            openapi_extra=self.get_openapi_extras_for_body_model(GetEmbedding),
        )
        self.router.add_api_route(
            "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/delete",
            self.delete,
            methods=["POST"],
            response_model=None,
            openapi_extra=self.get_openapi_extras_for_body_model(DeleteEmbedding),
        )
        self.router.add_api_route(
            "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/count",
            self.count,
            methods=["GET"],
            response_model=None,
        )
        self.router.add_api_route(
            "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}/query",
            self.get_nearest_neighbors,
            methods=["POST"],
            response_model=None,
            openapi_extra=self.get_openapi_extras_for_body_model(
                request_model=QueryEmbedding
            ),
        )
        # Single-collection endpoints: GET and DELETE take the collection's
        # name, while PUT takes its id.
        self.router.add_api_route(
            "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_name}",
            self.get_collection,
            methods=["GET"],
            response_model=None,
        )
        self.router.add_api_route(
            "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_id}",
            self.update_collection,
            methods=["PUT"],
            response_model=None,
            openapi_extra=self.get_openapi_extras_for_body_model(UpdateCollection),
        )
        self.router.add_api_route(
            "/api/v2/tenants/{tenant}/databases/{database_name}/collections/{collection_name}",
            self.delete_collection,
            methods=["DELETE"],
            response_model=None,
        )

    def shutdown(self) -> None:
        """Stop the component system; registered as the app's "shutdown" handler."""
        system = self._system
        system.stop()

    def app(self) -> fastapi.FastAPI:
        """Return the FastAPI application assembled in ``__init__``."""
        application: fastapi.FastAPI = self._app
        return application

    async def rate_limit_exception_handler(
        self, request: Request, exc: RateLimitError
    ) -> ORJSONResponse:
        """Map a ``RateLimitError`` to an HTTP 429 JSON response.

        Neither the request nor the exception is inspected, so every
        rejection gets the same body.
        """
        body = {"message": "Rate limit exceeded."}
        return ORJSONResponse(content=body, status_code=429)

    def root(self) -> Dict[str, int]:
        """Return the API heartbeat under the key ``"nanosecond heartbeat"``."""
        beat = self._api.heartbeat()
        return {"nanosecond heartbeat": beat}

    async def quota_exception_handler(
        self, request: Request, exc: QuotaError
    ) -> ORJSONResponse:
        """Map a ``QuotaError`` to an HTTP 400 JSON response whose ``message``
        is ``exc.message()``."""
        message = exc.message()
        return ORJSONResponse(content={"message": message}, status_code=400)

    async def heartbeat(self) -> Dict[str, int]:
        """Route handler that returns the same payload as ``root``."""
        payload = self.root()
        return payload

    async def version(self) -> str:
        """Return the version string reported by the server API."""
        reported = self._api.get_version()
        return reported

    def _set_request_context(self, request: Request) -> None:
        """
        Give components that need per-request context the current request.

        At present only the quota enforcer receives it, as ``{"request": request}``.
        """
        context = {"request": request}
        self._quota_enforcer.set_context(context=context)

    @trace_method(
        "auth_request",
        OpenTelemetryGranularity.OPERATION,
    )
    @rate_limit
    async def auth_request(
        self,
        headers: Headers,
        action: AuthzAction,
        tenant: Optional[str],
        database: Optional[str],
        collection: Optional[str],
    ) -> None:
        """Rate-limited async wrapper around ``sync_auth_request``.

        The blocking authn/authz work runs in a worker thread drawn from
        ``self._capacity_limiter``. Every other offloaded call in this class
        uses that same limiter, so authentication traffic also stays within
        ``chroma_server_thread_pool_size``.

        Callers must ``await`` this; an un-awaited call performs no check.
        """
        await to_thread.run_sync(
            self.sync_auth_request,
            headers,
            action,
            tenant,
            database,
            collection,
            limiter=self._capacity_limiter,
        )

    def sync_auth_request(
        self,
        headers: Headers,
        action: AuthzAction,
        tenant: Optional[str],
        database: Optional[str],
        collection: Optional[str],
    ) -> None:
        """
        Authenticate and authorize a request, blocking the calling thread.

        - With no authentication provider configured, nothing is checked.
        - Otherwise the headers are authenticated with
          ``authenticate_or_raise``.
        - If an authorization provider is also configured, the identity is
          then authorized for ``action`` on the tenant/database/collection
          resource with ``authorize_or_raise``.

        Failures propagate as whatever the providers raise. The resulting
        HTTP status is decided by the providers / error handlers, not here.

        On every successful path, ``tenant``, ``database`` and ``collection``
        are recorded as attributes on the current trace span.
        """
        if self.authn_provider:
            user_identity = self.authn_provider.authenticate_or_raise(dict(headers))

            if self.authz_provider:
                authz_resource = AuthzResource(
                    tenant=tenant,
                    database=database,
                    collection=collection,
                )
                self.authz_provider.authorize_or_raise(
                    user_identity, action, authz_resource
                )

        add_attributes_to_current_span(
            {
                "tenant": tenant,
                "database": database,
                "collection": collection,
            }
        )

    @trace_method("FastAPI.get_user_identity", OpenTelemetryGranularity.OPERATION)
    async def get_user_identity(
        self,
        request: Request,
    ) -> UserIdentity:
        """Return the identity of the caller.

        With no authentication provider configured, this returns an identity
        with an empty ``user_id`` scoped to ``DEFAULT_TENANT`` /
        ``DEFAULT_DATABASE``. Otherwise the request headers are authenticated
        in a worker thread bounded by ``self._capacity_limiter``, and whatever
        the provider raises on failure propagates.
        """
        # Bind locally so the type checker sees a non-None provider below.
        authn_provider = self.authn_provider
        if not authn_provider:
            return UserIdentity(
                user_id="", tenant=DEFAULT_TENANT, databases=[DEFAULT_DATABASE]
            )

        return cast(
            UserIdentity,
            await to_thread.run_sync(
                authn_provider.authenticate_or_raise,
                dict(request.headers),
                limiter=self._capacity_limiter,
            ),
        )

    @trace_method("FastAPI.create_database", OpenTelemetryGranularity.OPERATION)
    async def create_database(
        self,
        request: Request,
        tenant: str,
    ) -> None:
        """Create a database in ``tenant`` from a ``CreateDatabase`` JSON body.

        Body validation, authorization and the ``ServerAPI`` call all run in a
        single worker thread bounded by the capacity limiter.
        """

        def _create(tenant_name: str, req_headers: Headers, body: bytes) -> None:
            spec = validate_model(CreateDatabase, orjson.loads(body))
            self.sync_auth_request(
                req_headers,
                AuthzAction.CREATE_DATABASE,
                tenant_name,
                spec.name,
                None,
            )
            self._set_request_context(request=request)
            return self._api.create_database(spec.name, tenant_name)

        raw_body = await request.body()
        await to_thread.run_sync(
            _create, tenant, request.headers, raw_body, limiter=self._capacity_limiter
        )

    @trace_method("FastAPI.get_database", OpenTelemetryGranularity.OPERATION)
    async def get_database(
        self,
        request: Request,
        database_name: str,
        tenant: str,
    ) -> Database:
        """Fetch database ``database_name`` of ``tenant`` once the caller is authorized."""
        await self.auth_request(
            request.headers, AuthzAction.GET_DATABASE, tenant, database_name, None
        )
        database = await to_thread.run_sync(
            self._api.get_database,
            database_name,
            tenant,
            limiter=self._capacity_limiter,
        )
        return cast(Database, database)

    @trace_method("FastAPI.delete_database", OpenTelemetryGranularity.OPERATION)
    async def delete_database(
        self,
        request: Request,
        database_name: str,
        tenant: str,
    ) -> None:
        """Delete database ``database_name`` of ``tenant`` once the caller is authorized.

        ``auth_request`` is a coroutine and must be awaited. Without the
        ``await``, its authorization and rate-limit checks never execute, and
        the deletion proceeds for any caller even when auth is enabled.
        """
        await self.auth_request(
            request.headers,
            AuthzAction.DELETE_DATABASE,
            tenant,
            database_name,
            None,
        )

        await to_thread.run_sync(
            self._api.delete_database,
            database_name,
            tenant,
            limiter=self._capacity_limiter,
        )

    @trace_method("FastAPI.create_tenant", OpenTelemetryGranularity.OPERATION)
    async def create_tenant(
        self,
        request: Request,
    ) -> None:
        """Create a tenant from a ``CreateTenant`` JSON body.

        Validation, authorization and creation run in one worker thread
        bounded by the capacity limiter.
        """
        # NOTE(review): unlike create_database and the other mutating handlers,
        # this does not call _set_request_context; confirm that is intended.

        def _create(req: Request, body: bytes) -> None:
            new_tenant = validate_model(CreateTenant, orjson.loads(body))
            self.sync_auth_request(
                req.headers, AuthzAction.CREATE_TENANT, new_tenant.name, None, None
            )
            return self._api.create_tenant(new_tenant.name)

        raw_body = await request.body()
        await to_thread.run_sync(
            _create, request, raw_body, limiter=self._capacity_limiter
        )

    @trace_method("FastAPI.get_tenant", OpenTelemetryGranularity.OPERATION)
    async def get_tenant(
        self,
        request: Request,
        tenant: str,
    ) -> Tenant:
        """Return ``tenant`` after checking the caller may read it."""
        await self.auth_request(
            request.headers, AuthzAction.GET_TENANT, tenant, None, None
        )
        found = await to_thread.run_sync(
            self._api.get_tenant, tenant, limiter=self._capacity_limiter
        )
        return cast(Tenant, found)

    @trace_method("FastAPI.list_databases", OpenTelemetryGranularity.OPERATION)
    async def list_databases(
        self,
        request: Request,
        tenant: str,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
    ) -> Sequence[Database]:
        """Return a page of ``tenant``'s databases, selected by ``limit``/``offset``."""
        await self.auth_request(
            request.headers, AuthzAction.LIST_DATABASES, tenant, None, None
        )
        databases = await to_thread.run_sync(
            self._api.list_databases,
            limit,
            offset,
            tenant,
            limiter=self._capacity_limiter,
        )
        return cast(Sequence[Database], databases)

    @trace_method("FastAPI.list_collections", OpenTelemetryGranularity.OPERATION)
    async def list_collections(
        self,
        request: Request,
        tenant: str,
        database_name: str,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
    ) -> Sequence[CollectionModel]:
        """Return a page of collections in ``tenant``/``database_name``.

        Authorization and the listing share one worker thread bounded by the
        capacity limiter.
        """

        def _list(
            page_limit: Optional[int],
            page_offset: Optional[int],
            tenant_name: str,
            db_name: str,
        ) -> Sequence[CollectionModel]:
            self.sync_auth_request(
                request.headers, AuthzAction.LIST_COLLECTIONS, tenant_name, db_name, None
            )
            self._set_request_context(request=request)
            add_attributes_to_current_span({"tenant": tenant_name})
            return self._api.list_collections(
                tenant=tenant_name, database=db_name, limit=page_limit, offset=page_offset
            )

        collections = await to_thread.run_sync(
            _list, limit, offset, tenant, database_name, limiter=self._capacity_limiter
        )
        return cast(Sequence[CollectionModel], collections)

    @trace_method("FastAPI.count_collections", OpenTelemetryGranularity.OPERATION)
    async def count_collections(
        self,
        request: Request,
        tenant: str,
        database_name: str,
    ) -> int:
        """Return the number of collections in ``tenant``/``database_name``."""
        await self.auth_request(
            request.headers, AuthzAction.COUNT_COLLECTIONS, tenant, database_name, None
        )
        add_attributes_to_current_span({"tenant": tenant})
        total = await to_thread.run_sync(
            self._api.count_collections,
            tenant,
            database_name,
            limiter=self._capacity_limiter,
        )
        return cast(int, total)

    @trace_method("FastAPI.create_collection", OpenTelemetryGranularity.OPERATION)
    async def create_collection(
        self,
        request: Request,
        tenant: str,
        database_name: str,
    ) -> CollectionModel:
        """Create (or, with ``get_or_create``, fetch) a collection described by
        a ``CreateCollection`` JSON body.

        A missing/empty ``configuration`` yields a default
        ``CollectionConfigurationInternal``. The configuration is parsed
        before the authorization check.
        """

        def _create(
            req: Request, tenant_name: str, db_name: str, body: bytes
        ) -> CollectionModel:
            spec = validate_model(CreateCollection, orjson.loads(body))
            if spec.configuration:
                configuration = CollectionConfigurationInternal.from_json(
                    spec.configuration
                )
            else:
                configuration = CollectionConfigurationInternal()

            self.sync_auth_request(
                req.headers,
                AuthzAction.CREATE_COLLECTION,
                tenant_name,
                db_name,
                spec.name,
            )
            self._set_request_context(request=req)
            add_attributes_to_current_span({"tenant": tenant_name})

            return self._api.create_collection(
                name=spec.name,
                configuration=configuration,
                metadata=spec.metadata,
                get_or_create=spec.get_or_create,
                tenant=tenant_name,
                database=db_name,
            )

        raw_body = await request.body()
        created = await to_thread.run_sync(
            _create,
            request,
            tenant,
            database_name,
            raw_body,
            limiter=self._capacity_limiter,
        )
        return cast(CollectionModel, created)

    @trace_method("FastAPI.get_collection", OpenTelemetryGranularity.OPERATION)
    async def get_collection(
        self,
        request: Request,
        tenant: str,
        database_name: str,
        collection_name: str,
    ) -> CollectionModel:
        """Look up a collection by ``collection_name`` after authorization."""
        scope = (tenant, database_name, collection_name)
        await self.auth_request(request.headers, AuthzAction.GET_COLLECTION, *scope)

        add_attributes_to_current_span({"tenant": tenant})

        found = await to_thread.run_sync(
            self._api.get_collection,
            collection_name,
            tenant,
            database_name,
            limiter=self._capacity_limiter,
        )
        return cast(CollectionModel, found)

    @trace_method("FastAPI.update_collection", OpenTelemetryGranularity.OPERATION)
    async def update_collection(
        self,
        tenant: str,
        database_name: str,
        collection_id: str,
        request: Request,
    ) -> None:
        """Apply an ``UpdateCollection`` JSON body (new name and/or metadata)
        to the collection with id ``collection_id``."""

        def _modify(req: Request, coll_id: str, body: bytes) -> None:
            changes = validate_model(UpdateCollection, orjson.loads(body))
            self.sync_auth_request(
                req.headers,
                AuthzAction.UPDATE_COLLECTION,
                tenant,
                database_name,
                coll_id,
            )
            self._set_request_context(request=req)
            add_attributes_to_current_span({"tenant": tenant})
            return self._api._modify(
                id=_uuid(coll_id),
                new_name=changes.new_name,
                new_metadata=changes.new_metadata,
                tenant=tenant,
                database=database_name,
            )

        raw_body = await request.body()
        await to_thread.run_sync(
            _modify, request, collection_id, raw_body, limiter=self._capacity_limiter
        )

    @trace_method("FastAPI.delete_collection", OpenTelemetryGranularity.OPERATION)
    async def delete_collection(
        self,
        request: Request,
        collection_name: str,
        tenant: str,
        database_name: str,
    ) -> None:
        """Delete the collection named ``collection_name`` after authorization."""
        scope = (tenant, database_name, collection_name)
        await self.auth_request(request.headers, AuthzAction.DELETE_COLLECTION, *scope)
        add_attributes_to_current_span({"tenant": tenant})

        await to_thread.run_sync(
            self._api.delete_collection,
            collection_name,
            tenant,
            database_name,
            limiter=self._capacity_limiter,
        )

    @trace_method("FastAPI.add", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def add(
        self,
        request: Request,
        tenant: str,
        database_name: str,
        collection_id: str,
    ) -> bool:
        """Add the records in an ``AddEmbedding`` JSON body to ``collection_id``.

        An ``InvalidDimensionException`` from the API is re-raised as an
        ``HTTPException`` with status 500 carrying the original message.
        """

        def _add_records(req: Request, body: bytes) -> bool:
            records = validate_model(AddEmbedding, orjson.loads(body))
            self.sync_auth_request(
                req.headers, AuthzAction.ADD, tenant, database_name, collection_id
            )
            self._set_request_context(request=req)
            add_attributes_to_current_span({"tenant": tenant})
            return self._api._add(
                collection_id=_uuid(collection_id),
                ids=records.ids,
                embeddings=cast(
                    Embeddings,
                    None
                    if not records.embeddings
                    else convert_list_embeddings_to_np(records.embeddings),
                ),
                metadatas=records.metadatas,  # type: ignore
                documents=records.documents,  # type: ignore
                uris=records.uris,  # type: ignore
                tenant=tenant,
                database=database_name,
            )

        try:
            outcome = await to_thread.run_sync(
                _add_records,
                request,
                await request.body(),
                limiter=self._capacity_limiter,
            )
        except InvalidDimensionException as e:
            raise HTTPException(status_code=500, detail=str(e))
        return cast(bool, outcome)

    @trace_method("FastAPI.update", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def update(
        self,
        request: Request,
        tenant: str,
        database_name: str,
        collection_id: str,
    ) -> None:
        """Update existing records of ``collection_id`` from an
        ``UpdateEmbedding`` JSON body."""

        def _update_records(req: Request, body: bytes) -> bool:
            changes = validate_model(UpdateEmbedding, orjson.loads(body))
            self.sync_auth_request(
                req.headers, AuthzAction.UPDATE, tenant, database_name, collection_id
            )
            self._set_request_context(request=req)
            add_attributes_to_current_span({"tenant": tenant})
            return self._api._update(
                collection_id=_uuid(collection_id),
                ids=changes.ids,
                embeddings=None
                if not changes.embeddings
                else convert_list_embeddings_to_np(changes.embeddings),
                metadatas=changes.metadatas,  # type: ignore
                documents=changes.documents,  # type: ignore
                uris=changes.uris,  # type: ignore
                tenant=tenant,
                database=database_name,
            )

        raw_body = await request.body()
        await to_thread.run_sync(
            _update_records, request, raw_body, limiter=self._capacity_limiter
        )

    @trace_method("FastAPI.upsert", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def upsert(
        self,
        request: Request,
        tenant: str,
        database_name: str,
        collection_id: str,
    ) -> None:
        """Insert-or-update records of a collection (v2 route).

        The body is validated as an ``AddEmbedding``. Validation, the UPSERT
        authorization check and the ``_upsert`` call run in a worker thread
        bounded by ``self._capacity_limiter``.
        """

        def _apply_upsert(req: Request, raw: bytes) -> bool:
            payload = validate_model(AddEmbedding, orjson.loads(raw))

            self.sync_auth_request(
                req.headers,
                AuthzAction.UPSERT,
                tenant,
                database_name,
                collection_id,
            )
            self._set_request_context(request=req)
            add_attributes_to_current_span({"tenant": tenant})

            target = _uuid(collection_id)
            # An absent or empty embeddings list is forwarded as None.
            vectors = cast(
                Embeddings,
                convert_list_embeddings_to_np(payload.embeddings)
                if payload.embeddings
                else None,
            )
            return self._api._upsert(
                collection_id=target,
                ids=payload.ids,
                embeddings=vectors,
                metadatas=payload.metadatas,  # type: ignore
                documents=payload.documents,  # type: ignore
                uris=payload.uris,  # type: ignore
                tenant=tenant,
                database=database_name,
            )

        body = await request.body()
        await to_thread.run_sync(
            _apply_upsert,
            request,
            body,
            limiter=self._capacity_limiter,
        )

    @trace_method("FastAPI.get", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def get(
        self,
        collection_id: str,
        tenant: str,
        database_name: str,
        request: Request,
    ) -> GetResult:
        """Fetch records from a collection (v2 route).

        Filtering, paging and ``include`` come from the ``GetEmbedding`` body.
        Any returned embeddings are converted with ``tolist()`` so the response
        holds plain lists rather than arrays.
        """

        def _fetch(req: Request, raw: bytes) -> GetResult:
            params = validate_model(GetEmbedding, orjson.loads(raw))
            self.sync_auth_request(
                req.headers,
                AuthzAction.GET,
                tenant,
                database_name,
                collection_id,
            )
            self._set_request_context(request=req)
            add_attributes_to_current_span({"tenant": tenant})
            return self._api._get(
                collection_id=_uuid(collection_id),
                ids=params.ids,
                where=params.where,
                sort=params.sort,
                limit=params.limit,
                offset=params.offset,
                where_document=params.where_document,
                include=params.include,
                tenant=tenant,
                database=database_name,
            )

        body = await request.body()
        fetched = await to_thread.run_sync(
            _fetch,
            request,
            body,
            limiter=self._capacity_limiter,
        )
        result = cast(GetResult, fetched)

        vectors = result["embeddings"]
        if vectors is not None:
            result["embeddings"] = [cast(Embedding, vec).tolist() for vec in vectors]
        return result

    @trace_method("FastAPI.delete", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def delete(
        self,
        collection_id: str,
        tenant: str,
        database_name: str,
        request: Request,
    ) -> None:
        """Delete records from a collection by id and/or filters (v2 route).

        Work after reading the body runs in a worker thread bounded by
        ``self._capacity_limiter``.
        """

        def _remove(req: Request, raw: bytes) -> None:
            criteria = validate_model(DeleteEmbedding, orjson.loads(raw))
            self.sync_auth_request(
                req.headers,
                AuthzAction.DELETE,
                tenant,
                database_name,
                collection_id,
            )
            self._set_request_context(request=req)
            add_attributes_to_current_span({"tenant": tenant})
            return self._api._delete(
                collection_id=_uuid(collection_id),
                ids=criteria.ids,
                where=criteria.where,
                where_document=criteria.where_document,
                tenant=tenant,
                database=database_name,
            )

        body = await request.body()
        await to_thread.run_sync(
            _remove,
            request,
            body,
            limiter=self._capacity_limiter,
        )

    @trace_method("FastAPI.count", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def count(
        self,
        request: Request,
        tenant: str,
        database_name: str,
        collection_id: str,
    ) -> int:
        """Return the number of records in a collection (v2 route).

        Authorization is awaited on the event loop. The ``_count`` call then
        runs in a worker thread bounded by ``self._capacity_limiter``.
        """
        await self.auth_request(
            request.headers,
            AuthzAction.COUNT,
            tenant,
            database_name,
            collection_id,
        )
        add_attributes_to_current_span({"tenant": tenant})

        total = await to_thread.run_sync(
            self._api._count,
            _uuid(collection_id),
            tenant,
            database_name,
            limiter=self._capacity_limiter,
        )
        return cast(int, total)

    @trace_method("FastAPI.reset", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def reset(
        self,
        request: Request,
    ) -> bool:
        """Reset the backing API after a RESET authorization check.

        The check is not scoped to any tenant, database or collection (all
        three are passed as None).
        """
        await self.auth_request(request.headers, AuthzAction.RESET, None, None, None)

        outcome = await to_thread.run_sync(
            self._api.reset,
            limiter=self._capacity_limiter,
        )
        return cast(bool, outcome)

    @trace_method("FastAPI.get_nearest_neighbors", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def get_nearest_neighbors(
        self,
        tenant: str,
        database_name: str,
        collection_id: str,
        request: Request,
    ) -> QueryResult:
        """Run a similarity query against a collection (v2 route).

        Validation, the QUERY authorization check and ``_query`` run in a
        worker thread bounded by ``self._capacity_limiter``. The nested
        ``embeddings`` entry of the result (if present) is converted element
        by element with ``tolist()`` before returning.
        """

        def _search(req: Request, raw: bytes) -> QueryResult:
            spec = validate_model(QueryEmbedding, orjson.loads(raw))

            self.sync_auth_request(
                req.headers,
                AuthzAction.QUERY,
                tenant,
                database_name,
                collection_id,
            )
            self._set_request_context(request=req)
            add_attributes_to_current_span({"tenant": tenant})

            target = _uuid(collection_id)
            probes = cast(
                Embeddings,
                convert_list_embeddings_to_np(spec.query_embeddings)
                if spec.query_embeddings
                else None,
            )
            return self._api._query(
                collection_id=target,
                query_embeddings=probes,
                n_results=spec.n_results,
                where=spec.where,
                where_document=spec.where_document,
                include=spec.include,
                tenant=tenant,
                database=database_name,
            )

        body = await request.body()
        searched = await to_thread.run_sync(
            _search,
            request,
            body,
            limiter=self._capacity_limiter,
        )
        result = cast(QueryResult, searched)

        grouped = result["embeddings"]
        if grouped is not None:
            result["embeddings"] = [
                [cast(Embedding, vec).tolist() for vec in group] for group in grouped
            ]
        return result

    async def pre_flight_checks(self) -> Dict[str, Any]:
        """Return ``{"max_batch_size": ...}`` as reported by the backing API.

        The lookup runs in a worker thread bounded by ``self._capacity_limiter``.
        """

        def _collect() -> Dict[str, Any]:
            checks: Dict[str, Any] = {}
            checks["max_batch_size"] = self._api.get_max_batch_size()
            return checks

        report = await to_thread.run_sync(_collect, limiter=self._capacity_limiter)
        return cast(Dict[str, Any], report)

    # =========================================================================
    # OLD ROUTES FOR BACKWARDS COMPATIBILITY — WILL BE REMOVED
    # =========================================================================
    def setup_v1_routes(self) -> None:
        """Register the legacy ``/api/v1`` endpoints on ``self.router``.

        Routes are registered in exactly the order listed below.
        The first group is added with only ``methods``. Every route in the
        second group is added with ``response_model=None``. When a row names
        a request-body model, the route also gets ``openapi_extra`` built from
        it. When a row names a status code, it is passed as ``status_code``.
        """
        service_routes: Sequence[Tuple[str, Callable[..., Any], str]] = (
            ("/api/v1", self.root, "GET"),
            ("/api/v1/reset", self.reset, "POST"),
            ("/api/v1/version", self.version, "GET"),
            ("/api/v1/heartbeat", self.heartbeat, "GET"),
            ("/api/v1/pre-flight-checks", self.pre_flight_checks, "GET"),
        )
        for path, endpoint, verb in service_routes:
            self.router.add_api_route(path, endpoint, methods=[verb])

        # Row layout: (path, endpoint, HTTP method, body model or None,
        # status code or None).
        data_routes: Sequence[
            Tuple[
                str,
                Callable[..., Any],
                str,
                Optional[Type[BaseModel]],
                Optional[int],
            ]
        ] = (
            ("/api/v1/databases", self.create_database_v1, "POST", CreateDatabase, None),
            ("/api/v1/databases/{database}", self.get_database_v1, "GET", None, None),
            ("/api/v1/tenants", self.create_tenant_v1, "POST", CreateTenant, None),
            ("/api/v1/tenants/{tenant}", self.get_tenant_v1, "GET", None, None),
            ("/api/v1/collections", self.list_collections_v1, "GET", None, None),
            (
                "/api/v1/count_collections",
                self.count_collections_v1,
                "GET",
                None,
                None,
            ),
            (
                "/api/v1/collections",
                self.create_collection_v1,
                "POST",
                CreateCollection,
                None,
            ),
            (
                "/api/v1/collections/{collection_id}/add",
                self.add_v1,
                "POST",
                AddEmbedding,
                status.HTTP_201_CREATED,
            ),
            (
                "/api/v1/collections/{collection_id}/update",
                self.update_v1,
                "POST",
                UpdateEmbedding,
                None,
            ),
            (
                "/api/v1/collections/{collection_id}/upsert",
                self.upsert_v1,
                "POST",
                AddEmbedding,
                None,
            ),
            (
                "/api/v1/collections/{collection_id}/get",
                self.get_v1,
                "POST",
                GetEmbedding,
                None,
            ),
            (
                "/api/v1/collections/{collection_id}/delete",
                self.delete_v1,
                "POST",
                DeleteEmbedding,
                None,
            ),
            (
                "/api/v1/collections/{collection_id}/count",
                self.count_v1,
                "GET",
                None,
                None,
            ),
            (
                "/api/v1/collections/{collection_id}/query",
                self.get_nearest_neighbors_v1,
                "POST",
                QueryEmbedding,
                None,
            ),
            (
                "/api/v1/collections/{collection_name}",
                self.get_collection_v1,
                "GET",
                None,
                None,
            ),
            (
                "/api/v1/collections/{collection_id}",
                self.update_collection_v1,
                "PUT",
                UpdateCollection,
                None,
            ),
            (
                "/api/v1/collections/{collection_name}",
                self.delete_collection_v1,
                "DELETE",
                None,
                None,
            ),
        )
        for path, endpoint, verb, body_model, status_code in data_routes:
            options: Dict[str, Any] = {"methods": [verb], "response_model": None}
            if status_code is not None:
                options["status_code"] = status_code
            if body_model is not None:
                options["openapi_extra"] = self.get_openapi_extras_for_body_model(
                    body_model
                )
            self.router.add_api_route(path, endpoint, **options)

    @trace_method(
        "auth_and_get_tenant_and_database_for_request_v1",
        OpenTelemetryGranularity.OPERATION,
    )
    @rate_limit
    async def auth_and_get_tenant_and_database_for_request(
        self,
        headers: Headers,
        action: AuthzAction,
        tenant: Optional[str],
        database: Optional[str],
        collection: Optional[str],
    ) -> Tuple[Optional[str], Optional[str]]:
        """
        Authenticates and authorizes the request based on the given headers
        and other parameters. If the request cannot be authenticated or cannot
        be authorized (with the configured providers), the provider's
        ``authenticate_or_raise`` / ``authorize_or_raise`` raises (intended
        to surface as an HTTP 401 — confirm against the provider classes).

        If the request is authenticated and authorized, returns the tenant and
        database to be used for the request. These will differ from the passed
        tenant and database if and only if:
        - The request is authenticated
        - chroma_overwrite_singleton_tenant_database_access_from_auth = True
          (presumably checked inside ``singleton_tenant_database_if_applicable``)
        - The passed tenant or database are None or default_{tenant, database}
            (can be overwritten separately)
        - The user has access to a single tenant and/or single database.

        The work is done by the blocking
        ``sync_auth_and_get_tenant_and_database_for_request`` in a worker thread.
        """
        # Must delegate to the *sync* implementation. Passing this coroutine
        # method to ``run_sync`` only builds an un-awaited coroutine in the
        # worker thread: no authentication or authorization runs, callers that
        # ignore the result skip auth entirely, and callers that unpack the
        # result fail with TypeError.
        return await to_thread.run_sync(
            self.sync_auth_and_get_tenant_and_database_for_request,
            headers,
            action,
            tenant,
            database,
            collection,
        )

    def sync_auth_and_get_tenant_and_database_for_request(
        self,
        headers: Headers,
        action: AuthzAction,
        tenant: Optional[str],
        database: Optional[str],
        collection: Optional[str],
    ) -> Tuple[Optional[str], Optional[str]]:
        """Blocking auth check that also resolves the effective tenant/database.

        - With no authn provider, the inputs are recorded on the current span
          and returned unchanged.
        - Otherwise the headers are authenticated. A tenant/database reported
          by ``singleton_tenant_database_if_applicable`` replaces the matching
          input when that input is empty or equal to the default.
        - If an authz provider is also configured, the (possibly substituted)
          tenant/database/collection are authorized and then recorded on the
          span. With authn but no authz, the span is not annotated.

        Returns the (possibly substituted) ``(tenant, database)`` pair.
        """
        if not self.authn_provider:
            add_attributes_to_current_span(
                {"tenant": tenant, "database": database, "collection": collection}
            )
            return tenant, database

        identity = self.authn_provider.authenticate_or_raise(dict(headers))
        singleton_tenant, singleton_database = (
            self.authn_provider.singleton_tenant_database_if_applicable(identity)
        )

        # Only substitute values the caller left unset or at the default.
        if singleton_tenant and (not tenant or tenant == DEFAULT_TENANT):
            tenant = singleton_tenant
        if singleton_database and (not database or database == DEFAULT_DATABASE):
            database = singleton_database

        if self.authz_provider:
            resource = AuthzResource(
                tenant=tenant,
                database=database,
                collection=collection,
            )
            self.authz_provider.authorize_or_raise(identity, action, resource)
            add_attributes_to_current_span(
                {"tenant": tenant, "database": database, "collection": collection}
            )
        return tenant, database

    @trace_method("FastAPI.create_database_v1", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def create_database_v1(
        self,
        request: Request,
        tenant: str = DEFAULT_TENANT,
    ) -> None:
        """Create a database (legacy v1 route).

        Auth may substitute the tenant and/or database name. Falsy values
        returned by the auth helper leave the request's own values in place.
        """

        def _create(tenant_name: str, hdrs: Headers, raw: bytes) -> None:
            spec = validate_model(CreateDatabase, orjson.loads(raw))

            resolved_tenant, resolved_db = (
                self.sync_auth_and_get_tenant_and_database_for_request(
                    hdrs,
                    AuthzAction.CREATE_DATABASE,
                    tenant_name,
                    spec.name,
                    None,
                )
            )
            target_tenant = resolved_tenant or tenant_name
            if resolved_db:
                spec.name = resolved_db

            return self._api.create_database(spec.name, target_tenant)

        hdrs = request.headers
        body = await request.body()
        await to_thread.run_sync(
            _create,
            tenant,
            hdrs,
            body,
            limiter=self._capacity_limiter,
        )

    @trace_method("FastAPI.get_database_v1", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def get_database_v1(
        self,
        request: Request,
        database: str,
        tenant: str = DEFAULT_TENANT,
    ) -> Database:
        """Look up a database by name (legacy v1 route).

        Uses whatever tenant/database the auth helper resolves, falling back
        to the request values when it returns a falsy value.
        """
        resolved_tenant, resolved_db = (
            await self.auth_and_get_tenant_and_database_for_request(
                request.headers,
                AuthzAction.GET_DATABASE,
                tenant,
                database,
                None,
            )
        )
        tenant = resolved_tenant or tenant
        database = resolved_db or database

        found = await to_thread.run_sync(
            self._api.get_database,
            database,
            tenant,
            limiter=self._capacity_limiter,
        )
        return cast(Database, found)

    @trace_method("FastAPI.create_tenant_v1", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def create_tenant_v1(
        self,
        request: Request,
    ) -> None:
        """Create a tenant (legacy v1 route).

        The tenant name comes from the body. The auth helper may replace it.
        """

        def _create(req: Request, raw: bytes) -> None:
            spec = validate_model(CreateTenant, orjson.loads(raw))

            resolved_tenant, _ = self.sync_auth_and_get_tenant_and_database_for_request(
                req.headers,
                AuthzAction.CREATE_TENANT,
                spec.name,
                None,
                None,
            )
            if resolved_tenant:
                spec.name = resolved_tenant

            return self._api.create_tenant(spec.name)

        body = await request.body()
        await to_thread.run_sync(
            _create,
            request,
            body,
            limiter=self._capacity_limiter,
        )

    @trace_method("FastAPI.get_tenant_v1", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def get_tenant_v1(
        self,
        request: Request,
        tenant: str,
    ) -> Tenant:
        """Look up a tenant by name (legacy v1 route)."""
        resolved_tenant, _ = await self.auth_and_get_tenant_and_database_for_request(
            request.headers,
            AuthzAction.GET_TENANT,
            tenant,
            None,
            None,
        )
        tenant = resolved_tenant or tenant

        found = await to_thread.run_sync(
            self._api.get_tenant,
            tenant,
            limiter=self._capacity_limiter,
        )
        return cast(Tenant, found)

    @trace_method("FastAPI.list_collections_v1", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def list_collections_v1(
        self,
        request: Request,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Sequence[CollectionModel]:
        """List collections in a tenant/database, with optional paging (v1)."""
        resolved_tenant, resolved_db = (
            await self.auth_and_get_tenant_and_database_for_request(
                request.headers,
                AuthzAction.LIST_COLLECTIONS,
                tenant,
                database,
                None,
            )
        )
        tenant = resolved_tenant or tenant
        database = resolved_db or database

        listing = await to_thread.run_sync(
            self._api.list_collections,
            limit,
            offset,
            tenant,
            database,
            limiter=self._capacity_limiter,
        )
        return cast(Sequence[CollectionModel], listing)

    @trace_method("FastAPI.count_collections_v1", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def count_collections_v1(
        self,
        request: Request,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> int:
        """Return how many collections exist in a tenant/database (v1)."""
        resolved_tenant, resolved_db = (
            await self.auth_and_get_tenant_and_database_for_request(
                request.headers,
                AuthzAction.COUNT_COLLECTIONS,
                tenant,
                database,
                None,
            )
        )
        tenant = resolved_tenant or tenant
        database = resolved_db or database

        total = await to_thread.run_sync(
            self._api.count_collections,
            tenant,
            database,
            limiter=self._capacity_limiter,
        )
        return cast(int, total)

    @trace_method("FastAPI.create_collection_v1", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def create_collection_v1(
        self,
        request: Request,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        """Create (or, with ``get_or_create``, fetch) a collection (v1 route).

        The collection configuration is parsed from the body before the
        authorization check. A missing or empty configuration yields a default
        ``CollectionConfigurationInternal``.
        """

        def _create(
            req: Request, tenant_name: str, db_name: str, raw: bytes
        ) -> CollectionModel:
            spec = validate_model(CreateCollection, orjson.loads(raw))
            if spec.configuration:
                config = CollectionConfigurationInternal.from_json(spec.configuration)
            else:
                config = CollectionConfigurationInternal()

            resolved_tenant, resolved_db = (
                self.sync_auth_and_get_tenant_and_database_for_request(
                    req.headers,
                    AuthzAction.CREATE_COLLECTION,
                    tenant_name,
                    db_name,
                    spec.name,
                )
            )

            return self._api.create_collection(
                name=spec.name,
                configuration=config,
                metadata=spec.metadata,
                get_or_create=spec.get_or_create,
                tenant=resolved_tenant or tenant_name,
                database=resolved_db or db_name,
            )

        body = await request.body()
        created = await to_thread.run_sync(
            _create,
            request,
            tenant,
            database,
            body,
            limiter=self._capacity_limiter,
        )
        return cast(CollectionModel, created)

    @trace_method("FastAPI.get_collection_v1", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def get_collection_v1(
        self,
        request: Request,
        collection_name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> CollectionModel:
        """Look up a collection by name (legacy v1 route).

        The tenant/database resolved by the auth helper take precedence over
        the request values when they are truthy. The lookup runs in a worker
        thread bounded by ``self._capacity_limiter``.
        """
        (
            maybe_tenant,
            maybe_database,
        ) = await self.auth_and_get_tenant_and_database_for_request(
            request.headers,
            AuthzAction.GET_COLLECTION,
            tenant,
            database,
            collection_name,
        )
        if maybe_tenant:
            tenant = maybe_tenant
        if maybe_database:
            database = maybe_database

        # Await the thread call directly. The previous untyped nested
        # ``async def inner()`` added an extra coroutine without doing anything.
        return cast(
            CollectionModel,
            await to_thread.run_sync(
                self._api.get_collection,
                collection_name,
                tenant,
                database,
                limiter=self._capacity_limiter,
            ),
        )

    @trace_method("FastAPI.update_collection_v1", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def update_collection_v1(
        self,
        collection_id: str,
        request: Request,
    ) -> None:
        """Rename a collection and/or replace its metadata (legacy v1 route).

        The auth check is scoped only to the collection id. The tenant and
        database it resolves are not used.
        """

        def _modify(req: Request, coll_id: str, raw: bytes) -> None:
            changes = validate_model(UpdateCollection, orjson.loads(raw))
            self.sync_auth_and_get_tenant_and_database_for_request(
                req.headers,
                AuthzAction.UPDATE_COLLECTION,
                None,
                None,
                coll_id,
            )
            return self._api._modify(
                id=_uuid(coll_id),
                new_name=changes.new_name,
                new_metadata=changes.new_metadata,
            )

        body = await request.body()
        await to_thread.run_sync(
            _modify,
            request,
            collection_id,
            body,
            limiter=self._capacity_limiter,
        )

    @trace_method("FastAPI.delete_collection_v1", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def delete_collection_v1(
        self,
        request: Request,
        collection_name: str,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        """Delete a collection by name (legacy v1 route)."""
        resolved_tenant, resolved_db = (
            await self.auth_and_get_tenant_and_database_for_request(
                request.headers,
                AuthzAction.DELETE_COLLECTION,
                tenant,
                database,
                collection_name,
            )
        )
        tenant = resolved_tenant or tenant
        database = resolved_db or database

        await to_thread.run_sync(
            self._api.delete_collection,
            collection_name,
            tenant,
            database,
            limiter=self._capacity_limiter,
        )

    @trace_method("FastAPI.add_v1", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def add_v1(
        self,
        request: Request,
        collection_id: str,
    ) -> bool:
        """Add records to a collection (legacy v1 route).

        ``InvalidDimensionException`` raised while reading the body or while
        the worker thread runs is turned into an ``HTTPException`` with status
        500 carrying the exception text.
        """

        def _insert(req: Request, raw: bytes) -> bool:
            payload = validate_model(AddEmbedding, orjson.loads(raw))
            self.sync_auth_and_get_tenant_and_database_for_request(
                req.headers,
                AuthzAction.ADD,
                None,
                None,
                collection_id,
            )
            target = _uuid(collection_id)
            vectors = cast(
                Embeddings,
                convert_list_embeddings_to_np(payload.embeddings)
                if payload.embeddings
                else None,
            )
            return self._api._add(
                collection_id=target,
                ids=payload.ids,
                embeddings=vectors,
                metadatas=payload.metadatas,  # type: ignore
                documents=payload.documents,  # type: ignore
                uris=payload.uris,  # type: ignore
            )

        try:
            body = await request.body()
            inserted = await to_thread.run_sync(
                _insert,
                request,
                body,
                limiter=self._capacity_limiter,
            )
        except InvalidDimensionException as e:
            raise HTTPException(status_code=500, detail=str(e))
        return cast(bool, inserted)

    @trace_method("FastAPI.update_v1", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def update_v1(
        self,
        request: Request,
        collection_id: str,
    ) -> None:
        """Update existing records of a collection (legacy v1 route).

        Decorated with ``@rate_limit`` like the sibling v1 data routes. It was
        previously missing here, so this endpoint bypassed rate limiting.
        Validation, the UPDATE authorization check (scoped to the collection
        id only) and ``_update`` run in a worker thread bounded by
        ``self._capacity_limiter``.
        """

        def process_update(request: Request, raw_body: bytes) -> bool:
            update = validate_model(UpdateEmbedding, orjson.loads(raw_body))

            self.sync_auth_and_get_tenant_and_database_for_request(
                request.headers,
                AuthzAction.UPDATE,
                None,
                None,
                collection_id,
            )

            return self._api._update(
                collection_id=_uuid(collection_id),
                ids=update.ids,
                # An absent or empty embeddings list is forwarded as None.
                embeddings=convert_list_embeddings_to_np(update.embeddings)
                if update.embeddings
                else None,
                metadatas=update.metadatas,  # type: ignore
                documents=update.documents,  # type: ignore
                uris=update.uris,  # type: ignore
            )

        await to_thread.run_sync(
            process_update,
            request,
            await request.body(),
            limiter=self._capacity_limiter,
        )

    @trace_method("FastAPI.upsert_v1", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def upsert_v1(
        self,
        request: Request,
        collection_id: str,
    ) -> None:
        """Insert-or-update records of a collection (legacy v1 route)."""

        def _apply_upsert(req: Request, raw: bytes) -> bool:
            payload = validate_model(AddEmbedding, orjson.loads(raw))

            self.sync_auth_and_get_tenant_and_database_for_request(
                req.headers,
                AuthzAction.UPSERT,
                None,
                None,
                collection_id,
            )

            target = _uuid(collection_id)
            vectors = cast(
                Embeddings,
                convert_list_embeddings_to_np(payload.embeddings)
                if payload.embeddings
                else None,
            )
            return self._api._upsert(
                collection_id=target,
                ids=payload.ids,
                embeddings=vectors,
                metadatas=payload.metadatas,  # type: ignore
                documents=payload.documents,  # type: ignore
                uris=payload.uris,  # type: ignore
            )

        body = await request.body()
        await to_thread.run_sync(
            _apply_upsert,
            request,
            body,
            limiter=self._capacity_limiter,
        )

    @trace_method("FastAPI.get_v1", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def get_v1(
        self,
        collection_id: str,
        request: Request,
    ) -> GetResult:
        """Fetch records from a collection (legacy v1 route).

        Returned embeddings, if any, are converted with ``tolist()``.
        """

        def _fetch(req: Request, raw: bytes) -> GetResult:
            params = validate_model(GetEmbedding, orjson.loads(raw))
            self.sync_auth_and_get_tenant_and_database_for_request(
                req.headers,
                AuthzAction.GET,
                None,
                None,
                collection_id,
            )
            return self._api._get(
                collection_id=_uuid(collection_id),
                ids=params.ids,
                where=params.where,
                sort=params.sort,
                limit=params.limit,
                offset=params.offset,
                where_document=params.where_document,
                include=params.include,
            )

        body = await request.body()
        fetched = await to_thread.run_sync(
            _fetch,
            request,
            body,
            limiter=self._capacity_limiter,
        )
        result = cast(GetResult, fetched)

        vectors = result["embeddings"]
        if vectors is not None:
            result["embeddings"] = [cast(Embedding, vec).tolist() for vec in vectors]
        return result

    @trace_method("FastAPI.delete_v1", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def delete_v1(
        self,
        collection_id: str,
        request: Request,
    ) -> None:
        """Delete records from a collection by id and/or filters (v1 route)."""

        def _remove(req: Request, raw: bytes) -> None:
            criteria = validate_model(DeleteEmbedding, orjson.loads(raw))
            self.sync_auth_and_get_tenant_and_database_for_request(
                req.headers,
                AuthzAction.DELETE,
                None,
                None,
                collection_id,
            )
            return self._api._delete(
                collection_id=_uuid(collection_id),
                ids=criteria.ids,
                where=criteria.where,
                where_document=criteria.where_document,
            )

        body = await request.body()
        await to_thread.run_sync(
            _remove,
            request,
            body,
            limiter=self._capacity_limiter,
        )

    @trace_method("FastAPI.count_v1", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def count_v1(
        self,
        request: Request,
        collection_id: str,
    ) -> int:
        """Return the number of records stored in ``collection_id``.

        Authorization for ``AuthzAction.COUNT`` is awaited directly on the
        event loop. The count itself (``self._api._count``) runs on a worker
        thread bounded by ``self._capacity_limiter``.
        """
        await self.auth_and_get_tenant_and_database_for_request(
            request.headers,
            AuthzAction.COUNT,
            None,
            None,
            collection_id,
        )
        target = _uuid(collection_id)
        total = await to_thread.run_sync(
            self._api._count, target, limiter=self._capacity_limiter
        )
        return cast(int, total)

    @trace_method("FastAPI.reset_v1", OpenTelemetryGranularity.OPERATION)
    @rate_limit
    async def reset_v1(
        self,
        request: Request,
    ) -> bool:
        """Invoke ``self._api.reset`` after authorizing ``AuthzAction.RESET``.

        The authorization check has no tenant, database or collection scope
        (all passed as ``None``). The reset call runs on a worker thread
        bounded by ``self._capacity_limiter``. Its return value is passed
        through unchanged (typed as ``bool``).
        """
        await self.auth_and_get_tenant_and_database_for_request(
            request.headers,
            AuthzAction.RESET,
            None,
            None,
            None,
        )
        outcome = await to_thread.run_sync(
            self._api.reset, limiter=self._capacity_limiter
        )
        return cast(bool, outcome)

    @trace_method(
        "FastAPI.get_nearest_neighbors_v1", OpenTelemetryGranularity.OPERATION
    )
    @rate_limit
    async def get_nearest_neighbors_v1(
        self,
        collection_id: str,
        request: Request,
    ) -> QueryResult:
        """Run a nearest-neighbor query against ``collection_id``.

        The raw JSON body is validated into a ``QueryEmbedding`` model, the
        caller is authorized for ``AuthzAction.QUERY`` on the collection, and
        ``self._api._query`` is invoked. All of that runs on a worker thread
        bounded by ``self._capacity_limiter``.

        Returns:
            The ``QueryResult`` from the API. When it carries embeddings (one
            group per query), every vector in every group is replaced by its
            ``.tolist()`` form.
        """

        def _run_query(req: Request, body: bytes) -> QueryResult:
            params = validate_model(QueryEmbedding, orjson.loads(body))

            self.sync_auth_and_get_tenant_and_database_for_request(
                req.headers,
                AuthzAction.QUERY,
                None,
                None,
                collection_id,
            )

            # Resolve the collection id before converting embeddings so any
            # error surfaces in the same order as before.
            target = _uuid(collection_id)

            # A missing or empty embedding list is forwarded as None.
            query_vectors: Optional[Embeddings]
            if params.query_embeddings:
                query_vectors = convert_list_embeddings_to_np(
                    params.query_embeddings
                )
            else:
                query_vectors = None

            return self._api._query(
                collection_id=target,
                query_embeddings=cast(Embeddings, query_vectors),
                n_results=params.n_results,
                where=params.where,
                where_document=params.where_document,
                include=params.include,
            )

        body = await request.body()
        outcome = await to_thread.run_sync(
            _run_query, request, body, limiter=self._capacity_limiter
        )
        result = cast(QueryResult, outcome)

        groups = result["embeddings"]
        if groups is not None:
            result["embeddings"] = [
                [cast(Embedding, vector).tolist() for vector in group]
                for group in groups
            ]
        return result

    # =========================================================================
