import httpx
from typing import Optional, Sequence
from uuid import UUID
from overrides import override

from chromadb.api.models.Collection import CollectionName
from chromadb.auth import UserIdentity
from chromadb.auth.utils import maybe_set_tenant_and_database
from chromadb.api import AsyncAdminAPI, AsyncClientAPI, AsyncServerAPI
from chromadb.api.configuration import CollectionConfiguration
from chromadb.api.models.AsyncCollection import AsyncCollection
from chromadb.api.shared_system_client import SharedSystemClient
from chromadb.api.types import (
    CollectionMetadata,
    DataLoader,
    Documents,
    Embeddable,
    EmbeddingFunction,
    Embeddings,
    GetResult,
    IDs,
    Include,
    Loadable,
    Metadatas,
    QueryResult,
    URIs,
)
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Settings, System
from chromadb.errors import ChromaError
from chromadb.types import Database, Tenant, Where, WhereDocument
import chromadb.utils.embedding_functions as ef


class AsyncClient(SharedSystemClient, AsyncClientAPI):
    """A client for Chroma. This is the main entrypoint for interacting with Chroma.
    A client internally stores its tenant and database and proxies calls to a
    Server API instance of Chroma. It treats the Server API and corresponding System
    as a singleton, so multiple clients connecting to the same resource will share the
    same API instance.

    Client implementations should implement their own API-caching strategies.

    Instances are meant to be built with the async factories ``create()`` or
    ``from_system_async()``; the synchronous ``from_system()`` always raises
    ``NotImplementedError``.
    """

    # An internal admin client for verifying that databases and tenants exist
    _admin_client: AsyncAdminAPI

    # Tenant/database that every proxied call below is scoped to.
    tenant: str = DEFAULT_TENANT
    database: str = DEFAULT_DATABASE

    # Server API instance all calls are delegated to; assigned in ``create()``.
    _server: AsyncServerAPI

    @classmethod
    async def create(
        cls,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
        settings: Settings = Settings(),
    ) -> "AsyncClient":
        """Build a client, resolve its tenant/database and validate them.

        Args:
            tenant: Tenant to use unless ``maybe_set_tenant_and_database``
                returns a replacement for it.
            database: Database to use unless ``maybe_set_tenant_and_database``
                returns a replacement for it.
            settings: Settings the client is constructed with.

        Returns:
            The initialized client.

        Raises:
            ValueError: If validation cannot reach the server, or the tenant
                lookup fails with a non-Chroma error.
            ChromaError: Propagated unchanged from the tenant lookup.
        """
        # Construct the client itself; the admin client is created further down,
        # once the effective tenant/database are known.
        self = cls(settings=settings)
        SharedSystemClient._populate_data_from_system(self._system)

        self.tenant = tenant
        self.database = database

        # Get the root system component we want to interact with
        self._server = self._system.instance(AsyncServerAPI)

        user_identity = await self.get_user_identity()

        # The identity returned by the server may replace the caller-provided
        # tenant and/or database; falsy results leave the provided values as-is.
        maybe_tenant, maybe_database = maybe_set_tenant_and_database(
            user_identity,
            overwrite_singleton_tenant_database_access_from_auth=settings.chroma_overwrite_singleton_tenant_database_access_from_auth,
            user_provided_tenant=tenant,
            user_provided_database=database,
        )
        if maybe_tenant:
            self.tenant = maybe_tenant
        if maybe_database:
            self.database = maybe_database

        # Admin client used for verifying that databases and tenants exist
        self._admin_client = AsyncAdminClient.from_system(self._system)
        await self._validate_tenant_database(tenant=self.tenant, database=self.database)

        self._submit_client_start_event()

        return self

    @classmethod
    # (we can't override and use from_system() because it's synchronous)
    async def from_system_async(
        cls,
        system: System,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> "AsyncClient":
        """Create a client from an existing system. This is useful for testing and debugging.

        Args:
            system: System whose ``settings`` are used to build the client.
            tenant: Tenant forwarded to ``create()``.
            database: Database forwarded to ``create()``.

        Returns:
            The client produced by ``create()``.
        """
        # Go through ``cls`` (not the hard-coded class) so subclasses get an
        # instance of their own type, matching how ``create()`` builds ``cls``.
        return await cls.create(
            tenant=tenant, database=database, settings=system.settings
        )

    @classmethod
    @override
    def from_system(
        cls,
        system: System,
    ) -> "SharedSystemClient":
        """AsyncClient cannot be created synchronously. Use .from_system_async() instead."""
        raise NotImplementedError(
            "AsyncClient cannot be created synchronously. Use .from_system_async() instead."
        )

    @override
    async def get_user_identity(self) -> UserIdentity:
        return await self._server.get_user_identity()

    @override
    async def set_tenant(self, tenant: str, database: str = DEFAULT_DATABASE) -> None:
        # Validate first so a failed lookup leaves the current tenant/database intact.
        await self._validate_tenant_database(tenant=tenant, database=database)
        self.tenant = tenant
        self.database = database

    @override
    async def set_database(self, database: str) -> None:
        # Validate against the current tenant before switching databases.
        await self._validate_tenant_database(tenant=self.tenant, database=database)
        self.database = database

    async def _validate_tenant_database(self, tenant: str, database: str) -> None:
        """Check through the admin client that ``tenant`` and ``database`` can be fetched.

        Args:
            tenant: Tenant name to look up.
            database: Database name to look up within ``tenant``.

        Raises:
            ValueError: If either lookup hits ``httpx.ConnectError``, or the tenant
                lookup fails with any other non-Chroma exception.
            ChromaError: Propagated unchanged from the tenant lookup.
        """
        try:
            await self._admin_client.get_tenant(name=tenant)
        except httpx.ConnectError as e:
            # Chain the cause so the underlying connection failure stays visible.
            raise ValueError(
                "Could not connect to a Chroma server. Are you sure it is running?"
            ) from e
        # Propagate ChromaErrors
        except ChromaError:
            raise
        except Exception as e:
            raise ValueError(
                f"Could not connect to tenant {tenant}. Are you sure it exists?"
            ) from e

        # Only connection failures are translated here; any other error from the
        # database lookup propagates to the caller unchanged.
        try:
            await self._admin_client.get_database(name=database, tenant=tenant)
        except httpx.ConnectError as e:
            raise ValueError(
                "Could not connect to a Chroma server. Are you sure it is running?"
            ) from e

    # region BaseAPI Methods
    # Note - we could do this in less verbose ways, but they break type checking
    @override
    async def heartbeat(self) -> int:
        return await self._server.heartbeat()

    @override
    async def list_collections(
        self, limit: Optional[int] = None, offset: Optional[int] = None
    ) -> Sequence[CollectionName]:
        models = await self._server.list_collections(
            limit, offset, tenant=self.tenant, database=self.database
        )
        # Only the names are exposed; each server model is wrapped in CollectionName.
        return [CollectionName(model.name) for model in models]

    @override
    async def count_collections(self) -> int:
        return await self._server.count_collections(
            tenant=self.tenant, database=self.database
        )

    @override
    async def create_collection(
        self,
        name: str,
        configuration: Optional[CollectionConfiguration] = None,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
        get_or_create: bool = False,
    ) -> AsyncCollection:
        model = await self._server.create_collection(
            name=name,
            configuration=configuration,
            metadata=metadata,
            tenant=self.tenant,
            database=self.database,
            get_or_create=get_or_create,
        )
        # The embedding function and data loader are attached to the returned
        # wrapper only; they are not part of the server call above.
        return AsyncCollection(
            client=self._server,
            model=model,
            embedding_function=embedding_function,
            data_loader=data_loader,
        )

    @override
    async def get_collection(
        self,
        name: str,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
    ) -> AsyncCollection:
        model = await self._server.get_collection(
            name=name,
            tenant=self.tenant,
            database=self.database,
        )
        return AsyncCollection(
            client=self._server,
            model=model,
            embedding_function=embedding_function,
            data_loader=data_loader,
        )

    @override
    async def get_or_create_collection(
        self,
        name: str,
        configuration: Optional[CollectionConfiguration] = None,
        metadata: Optional[CollectionMetadata] = None,
        embedding_function: Optional[
            EmbeddingFunction[Embeddable]
        ] = ef.DefaultEmbeddingFunction(),  # type: ignore
        data_loader: Optional[DataLoader[Loadable]] = None,
    ) -> AsyncCollection:
        model = await self._server.get_or_create_collection(
            name=name,
            configuration=configuration,
            metadata=metadata,
            tenant=self.tenant,
            database=self.database,
        )
        return AsyncCollection(
            client=self._server,
            model=model,
            embedding_function=embedding_function,
            data_loader=data_loader,
        )

    @override
    async def _modify(
        self,
        id: UUID,
        new_name: Optional[str] = None,
        new_metadata: Optional[CollectionMetadata] = None,
    ) -> None:
        return await self._server._modify(
            id=id,
            new_name=new_name,
            new_metadata=new_metadata,
            tenant=self.tenant,
            database=self.database,
        )

    @override
    async def delete_collection(
        self,
        name: str,
    ) -> None:
        return await self._server.delete_collection(
            name=name,
            tenant=self.tenant,
            database=self.database,
        )

    #
    # ITEM METHODS
    #

    @override
    async def _add(
        self,
        ids: IDs,
        collection_id: UUID,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        return await self._server._add(
            ids=ids,
            collection_id=collection_id,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
            tenant=self.tenant,
            database=self.database,
        )

    @override
    async def _update(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Optional[Embeddings] = None,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        return await self._server._update(
            collection_id=collection_id,
            ids=ids,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
            tenant=self.tenant,
            database=self.database,
        )

    @override
    async def _upsert(
        self,
        collection_id: UUID,
        ids: IDs,
        embeddings: Embeddings,
        metadatas: Optional[Metadatas] = None,
        documents: Optional[Documents] = None,
        uris: Optional[URIs] = None,
    ) -> bool:
        return await self._server._upsert(
            collection_id=collection_id,
            ids=ids,
            embeddings=embeddings,
            metadatas=metadatas,
            documents=documents,
            uris=uris,
            tenant=self.tenant,
            database=self.database,
        )

    @override
    async def _count(self, collection_id: UUID) -> int:
        # Scope the call to this client's tenant/database, as every other item
        # method in this class does.
        # NOTE(review): assumes AsyncServerAPI._count accepts tenant/database
        # keywords like _get/_query do - confirm against the server API.
        return await self._server._count(
            collection_id=collection_id,
            tenant=self.tenant,
            database=self.database,
        )

    @override
    async def _peek(self, collection_id: UUID, n: int = 10) -> GetResult:
        # Scope the call to this client's tenant/database, as every other item
        # method in this class does.
        # NOTE(review): assumes AsyncServerAPI._peek accepts tenant/database
        # keywords like _get/_query do - confirm against the server API.
        return await self._server._peek(
            collection_id=collection_id,
            n=n,
            tenant=self.tenant,
            database=self.database,
        )

    @override
    async def _get(
        self,
        collection_id: UUID,
        ids: Optional[IDs] = None,
        where: Optional[Where] = None,
        sort: Optional[str] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        page: Optional[int] = None,
        page_size: Optional[int] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = ["embeddings", "metadatas", "documents"],  # type: ignore[list-item]
    ) -> GetResult:
        return await self._server._get(
            collection_id=collection_id,
            ids=ids,
            where=where,
            sort=sort,
            limit=limit,
            offset=offset,
            page=page,
            page_size=page_size,
            where_document=where_document,
            include=include,
            tenant=self.tenant,
            database=self.database,
        )

    @override
    async def _delete(
        self,
        collection_id: UUID,
        ids: Optional[IDs],
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
    ) -> None:
        await self._server._delete(
            collection_id=collection_id,
            ids=ids,
            where=where,
            where_document=where_document,
            tenant=self.tenant,
            database=self.database,
        )

    @override
    async def _query(
        self,
        collection_id: UUID,
        query_embeddings: Embeddings,
        n_results: int = 10,
        where: Optional[Where] = None,
        where_document: Optional[WhereDocument] = None,
        include: Include = ["embeddings", "metadatas", "documents", "distances"],  # type: ignore[list-item]
    ) -> QueryResult:
        return await self._server._query(
            collection_id=collection_id,
            query_embeddings=query_embeddings,
            n_results=n_results,
            where=where,
            where_document=where_document,
            include=include,
            tenant=self.tenant,
            database=self.database,
        )

    @override
    async def reset(self) -> bool:
        return await self._server.reset()

    @override
    async def get_version(self) -> str:
        return await self._server.get_version()

    @override
    def get_settings(self) -> Settings:
        # Synchronous on purpose: the server's get_settings() is called without await.
        return self._server.get_settings()

    @override
    async def get_max_batch_size(self) -> int:
        return await self._server.get_max_batch_size()

    # endregion


class AsyncAdminClient(SharedSystemClient, AsyncAdminAPI):
    """Async admin client that forwards tenant and database management calls
    to the ``AsyncServerAPI`` instance of its system.
    """

    # Server API instance every admin call is delegated to.
    _server: AsyncServerAPI

    def __init__(self, settings: Settings = Settings()) -> None:
        super().__init__(settings)
        # Resolve the server component once; all methods below go through it.
        server_api = self._system.instance(AsyncServerAPI)
        self._server = server_api

    @classmethod
    @override
    def from_system(
        cls,
        system: System,
    ) -> "AsyncAdminClient":
        # Register the system's data first, then build the client from its settings.
        SharedSystemClient._populate_data_from_system(system)
        return cls(settings=system.settings)

    # --- database management -------------------------------------------------

    @override
    async def create_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
        # Pure pass-through to the server API.
        return await self._server.create_database(name=name, tenant=tenant)

    @override
    async def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
        found = await self._server.get_database(name=name, tenant=tenant)
        return found

    @override
    async def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
        # Pure pass-through to the server API.
        return await self._server.delete_database(name=name, tenant=tenant)

    @override
    async def list_databases(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
    ) -> Sequence[Database]:
        # Paging arguments are forwarded untouched.
        databases = await self._server.list_databases(
            limit=limit,
            offset=offset,
            tenant=tenant,
        )
        return databases

    # --- tenant management ---------------------------------------------------

    @override
    async def create_tenant(self, name: str) -> None:
        # Pure pass-through to the server API.
        return await self._server.create_tenant(name=name)

    @override
    async def get_tenant(self, name: str) -> Tenant:
        found = await self._server.get_tenant(name=name)
        return found
