from typing import List, Optional, Sequence, Tuple, Union, cast
from uuid import UUID
from overrides import overrides
from chromadb.api.configuration import CollectionConfigurationInternal
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, System, logger
from chromadb.db.system import SysDB
from chromadb.errors import NotFoundError, UniqueConstraintError, InternalError
from chromadb.proto.convert import (
    from_proto_collection,
    from_proto_segment,
    to_proto_update_metadata,
    to_proto_segment,
    to_proto_segment_scope,
)
from chromadb.proto.coordinator_pb2 import (
    CreateCollectionRequest,
    CreateDatabaseRequest,
    CreateSegmentRequest,
    CreateTenantRequest,
    DeleteCollectionRequest,
    DeleteDatabaseRequest,
    DeleteSegmentRequest,
    GetCollectionsRequest,
    GetCollectionsResponse,
    GetCollectionWithSegmentsRequest,
    GetCollectionWithSegmentsResponse,
    GetDatabaseRequest,
    GetSegmentsRequest,
    GetTenantRequest,
    ListDatabasesRequest,
    UpdateCollectionRequest,
    UpdateSegmentRequest,
)
from chromadb.proto.coordinator_pb2_grpc import SysDBStub
from chromadb.proto.utils import RetryOnRpcErrorClientInterceptor
from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor
from chromadb.types import (
    Collection,
    CollectionAndSegments,
    Database,
    Metadata,
    OptionalArgument,
    Segment,
    SegmentScope,
    Tenant,
    Unspecified,
    UpdateMetadata,
)
from google.protobuf.empty_pb2 import Empty
import grpc


class GrpcSysDB(SysDB):
    """A gRPC implementation of the SysDB. In the distributed system, the SysDB is also
    called the 'Coordinator'. This implementation is used by Chroma frontend servers
    to call a remote SysDB (Coordinator) service."""

    _sys_db_stub: SysDBStub
    _channel: grpc.Channel
    _coordinator_url: str
    _coordinator_port: int
    _request_timeout_seconds: int

    def __init__(self, system: System):
        self._coordinator_url = system.settings.require("chroma_coordinator_host")
        # TODO: break out coordinator_port into a separate setting?
        self._coordinator_port = system.settings.require("chroma_server_grpc_port")
        self._request_timeout_seconds = system.settings.require(
            "chroma_sysdb_request_timeout_seconds"
        )
        return super().__init__(system)

    @overrides
    def start(self) -> None:
        self._channel = grpc.insecure_channel(
            f"{self._coordinator_url}:{self._coordinator_port}",
        )
        interceptors = [OtelInterceptor(), RetryOnRpcErrorClientInterceptor()]
        self._channel = grpc.intercept_channel(self._channel, *interceptors)
        self._sys_db_stub = SysDBStub(self._channel)  # type: ignore
        return super().start()

    @overrides
    def stop(self) -> None:
        self._channel.close()
        return super().stop()

    @overrides
    def reset_state(self) -> None:
        self._sys_db_stub.ResetState(Empty())
        return super().reset_state()

    @overrides
    def create_database(
        self, id: UUID, name: str, tenant: str = DEFAULT_TENANT
    ) -> None:
        try:
            request = CreateDatabaseRequest(id=id.hex, name=name, tenant=tenant)
            response = self._sys_db_stub.CreateDatabase(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            logger.info(
                f"Failed to create database name {name} and database id {id} for tenant {tenant} due to error: {e}"
            )
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise UniqueConstraintError()
            raise InternalError()

    @overrides
    def get_database(self, name: str, tenant: str = DEFAULT_TENANT) -> Database:
        try:
            request = GetDatabaseRequest(name=name, tenant=tenant)
            response = self._sys_db_stub.GetDatabase(
                request, timeout=self._request_timeout_seconds
            )
            return Database(
                id=UUID(hex=response.database.id),
                name=response.database.name,
                tenant=response.database.tenant,
            )
        except grpc.RpcError as e:
            logger.info(
                f"Failed to get database {name} for tenant {tenant} due to error: {e}"
            )
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise NotFoundError()
            raise InternalError()

    @overrides
    def delete_database(self, name: str, tenant: str = DEFAULT_TENANT) -> None:
        try:
            request = DeleteDatabaseRequest(name=name, tenant=tenant)
            self._sys_db_stub.DeleteDatabase(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            logger.info(
                f"Failed to delete database {name} for tenant {tenant} due to error: {e}"
            )
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise NotFoundError()
            raise InternalError

    @overrides
    def list_databases(
        self,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        tenant: str = DEFAULT_TENANT,
    ) -> Sequence[Database]:
        try:
            request = ListDatabasesRequest(limit=limit, offset=offset, tenant=tenant)
            response = self._sys_db_stub.ListDatabases(
                request, timeout=self._request_timeout_seconds
            )
            results: List[Database] = []
            for proto_database in response.databases:
                results.append(
                    Database(
                        id=UUID(hex=proto_database.id),
                        name=proto_database.name,
                        tenant=proto_database.tenant,
                    )
                )
            return results
        except grpc.RpcError as e:
            logger.info(
                f"Failed to list databases for tenant {tenant} due to error: {e}"
            )
            raise InternalError()

    @overrides
    def create_tenant(self, name: str) -> None:
        try:
            request = CreateTenantRequest(name=name)
            response = self._sys_db_stub.CreateTenant(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            logger.info(f"Failed to create tenant {name} due to error: {e}")
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise UniqueConstraintError()
            raise InternalError()

    @overrides
    def get_tenant(self, name: str) -> Tenant:
        try:
            request = GetTenantRequest(name=name)
            response = self._sys_db_stub.GetTenant(
                request, timeout=self._request_timeout_seconds
            )
            return Tenant(
                name=response.tenant.name,
            )
        except grpc.RpcError as e:
            logger.info(f"Failed to get tenant {name} due to error: {e}")
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise NotFoundError()
            raise InternalError()

    @overrides
    def create_segment(self, segment: Segment) -> None:
        try:
            proto_segment = to_proto_segment(segment)
            request = CreateSegmentRequest(
                segment=proto_segment,
            )
            response = self._sys_db_stub.CreateSegment(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            logger.info(f"Failed to create segment {segment}, error: {e}")
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise UniqueConstraintError()
            raise InternalError()

    @overrides
    def delete_segment(self, collection: UUID, id: UUID) -> None:
        try:
            request = DeleteSegmentRequest(
                id=id.hex,
                collection=collection.hex,
            )
            response = self._sys_db_stub.DeleteSegment(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            logger.info(
                f"Failed to delete segment with id {id} for collection {collection} due to error: {e}"
            )
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise NotFoundError()
            raise InternalError()

    @overrides
    def get_segments(
        self,
        collection: UUID,
        id: Optional[UUID] = None,
        type: Optional[str] = None,
        scope: Optional[SegmentScope] = None,
    ) -> Sequence[Segment]:
        try:
            request = GetSegmentsRequest(
                id=id.hex if id else None,
                type=type,
                scope=to_proto_segment_scope(scope) if scope else None,
                collection=collection.hex,
            )
            response = self._sys_db_stub.GetSegments(
                request, timeout=self._request_timeout_seconds
            )
            results: List[Segment] = []
            for proto_segment in response.segments:
                segment = from_proto_segment(proto_segment)
                results.append(segment)
            return results
        except grpc.RpcError as e:
            logger.info(
                f"Failed to get segment id {id}, type {type}, scope {scope} for collection {collection} due to error: {e}"
            )
            raise InternalError()

    @overrides
    def update_segment(
        self,
        collection: UUID,
        id: UUID,
        metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
    ) -> None:
        try:
            write_metadata = None
            if metadata != Unspecified():
                write_metadata = cast(Union[UpdateMetadata, None], metadata)

            request = UpdateSegmentRequest(
                id=id.hex,
                collection=collection.hex,
                metadata=to_proto_update_metadata(write_metadata)
                if write_metadata
                else None,
            )

            if metadata is None:
                request.ClearField("metadata")
                request.reset_metadata = True

            self._sys_db_stub.UpdateSegment(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            logger.info(
                f"Failed to update segment with id {id} for collection {collection}, error: {e}"
            )
            raise InternalError()

    @overrides
    def create_collection(
        self,
        id: UUID,
        name: str,
        configuration: CollectionConfigurationInternal,
        segments: Sequence[Segment],
        metadata: Optional[Metadata] = None,
        dimension: Optional[int] = None,
        get_or_create: bool = False,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> Tuple[Collection, bool]:
        try:
            request = CreateCollectionRequest(
                id=id.hex,
                name=name,
                configuration_json_str=configuration.to_json_str(),
                metadata=to_proto_update_metadata(metadata) if metadata else None,
                dimension=dimension,
                get_or_create=get_or_create,
                tenant=tenant,
                database=database,
                segments=[to_proto_segment(segment) for segment in segments],
            )
            response = self._sys_db_stub.CreateCollection(
                request, timeout=self._request_timeout_seconds
            )
            collection = from_proto_collection(response.collection)
            return collection, response.created
        except grpc.RpcError as e:
            logger.error(
                f"Failed to create collection id {id}, name {name} for database {database} and tenant {tenant} due to error: {e}"
            )
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise UniqueConstraintError()
            raise InternalError()

    @overrides
    def delete_collection(
        self,
        id: UUID,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
    ) -> None:
        try:
            request = DeleteCollectionRequest(
                id=id.hex,
                tenant=tenant,
                database=database,
            )
            response = self._sys_db_stub.DeleteCollection(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            logger.error(
                f"Failed to delete collection id {id} for database {database} and tenant {tenant} due to error: {e}"
            )
            e = cast(grpc.Call, e)
            logger.error(
                f"Error code: {e.code()}, NotFoundError: {grpc.StatusCode.NOT_FOUND}"
            )
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise NotFoundError()
            raise InternalError()

    @overrides
    def get_collections(
        self,
        id: Optional[UUID] = None,
        name: Optional[str] = None,
        tenant: str = DEFAULT_TENANT,
        database: str = DEFAULT_DATABASE,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
    ) -> Sequence[Collection]:
        try:
            # TODO: implement limit and offset in the gRPC service
            request = None
            if id is not None:
                request = GetCollectionsRequest(
                    id=id.hex,
                    limit=limit,
                    offset=offset,
                )
            if name is not None:
                if tenant is None and database is None:
                    raise ValueError(
                        "If name is specified, tenant and database must also be specified in order to uniquely identify the collection"
                    )
                request = GetCollectionsRequest(
                    name=name,
                    tenant=tenant,
                    database=database,
                    limit=limit,
                    offset=offset,
                )
            if id is None and name is None:
                request = GetCollectionsRequest(
                    tenant=tenant,
                    database=database,
                    limit=limit,
                    offset=offset,
                )
            response: GetCollectionsResponse = self._sys_db_stub.GetCollections(
                request, timeout=self._request_timeout_seconds
            )
            results: List[Collection] = []
            for collection in response.collections:
                results.append(from_proto_collection(collection))
            return results
        except grpc.RpcError as e:
            logger.error(
                f"Failed to get collections with id {id}, name {name}, tenant {tenant}, database {database} due to error: {e}"
            )
            raise InternalError()

    @overrides
    def get_collection_with_segments(
        self, collection_id: UUID
    ) -> CollectionAndSegments:
        try:
            request = GetCollectionWithSegmentsRequest(id=collection_id.hex)
            response: GetCollectionWithSegmentsResponse = (
                self._sys_db_stub.GetCollectionWithSegments(request)
            )
            return CollectionAndSegments(
                collection=from_proto_collection(response.collection),
                segments=[from_proto_segment(segment) for segment in response.segments],
            )
        except grpc.RpcError as e:
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise NotFoundError()
            logger.error(
                f"Failed to get collection {collection_id} and its segments due to error: {e}"
            )
            raise InternalError()

    @overrides
    def update_collection(
        self,
        id: UUID,
        name: OptionalArgument[str] = Unspecified(),
        dimension: OptionalArgument[Optional[int]] = Unspecified(),
        metadata: OptionalArgument[Optional[UpdateMetadata]] = Unspecified(),
    ) -> None:
        try:
            write_name = None
            if name != Unspecified():
                write_name = cast(str, name)

            write_dimension = None
            if dimension != Unspecified():
                write_dimension = cast(Union[int, None], dimension)

            write_metadata = None
            if metadata != Unspecified():
                write_metadata = cast(Union[UpdateMetadata, None], metadata)

            request = UpdateCollectionRequest(
                id=id.hex,
                name=write_name,
                dimension=write_dimension,
                metadata=to_proto_update_metadata(write_metadata)
                if write_metadata
                else None,
            )
            if metadata is None:
                request.ClearField("metadata")
                request.reset_metadata = True

            response = self._sys_db_stub.UpdateCollection(
                request, timeout=self._request_timeout_seconds
            )
        except grpc.RpcError as e:
            e = cast(grpc.Call, e)
            logger.error(
                f"Failed to update collection id {id}, name {name} due to error: {e}"
            )
            if e.code() == grpc.StatusCode.NOT_FOUND:
                raise NotFoundError()
            if e.code() == grpc.StatusCode.ALREADY_EXISTS:
                raise UniqueConstraintError()
            raise InternalError()

    def reset_and_wait_for_ready(self) -> None:
        self._sys_db_stub.ResetState(Empty(), wait_for_ready=True)
