from concurrent import futures
from typing import Any, Dict, List, cast
from uuid import UUID
from overrides import overrides
from chromadb.api.configuration import CollectionConfigurationInternal
from chromadb.config import DEFAULT_DATABASE, DEFAULT_TENANT, Component, System
from chromadb.proto.convert import (
    from_proto_metadata,
    from_proto_update_metadata,
    from_proto_segment,
    from_proto_segment_scope,
    to_proto_collection,
    to_proto_segment,
)
import chromadb.proto.chroma_pb2 as proto
from chromadb.proto.coordinator_pb2 import (
    CreateCollectionRequest,
    CreateCollectionResponse,
    CreateDatabaseRequest,
    CreateDatabaseResponse,
    CreateSegmentRequest,
    CreateSegmentResponse,
    CreateTenantRequest,
    CreateTenantResponse,
    DeleteCollectionRequest,
    DeleteCollectionResponse,
    DeleteSegmentRequest,
    DeleteSegmentResponse,
    GetCollectionsRequest,
    GetCollectionsResponse,
    GetCollectionWithSegmentsRequest,
    GetCollectionWithSegmentsResponse,
    GetDatabaseRequest,
    GetDatabaseResponse,
    GetSegmentsRequest,
    GetSegmentsResponse,
    GetTenantRequest,
    GetTenantResponse,
    ResetStateResponse,
    UpdateCollectionRequest,
    UpdateCollectionResponse,
    UpdateSegmentRequest,
    UpdateSegmentResponse,
)
from chromadb.proto.coordinator_pb2_grpc import (
    SysDBServicer,
    add_SysDBServicer_to_server,
)
import grpc
from google.protobuf.empty_pb2 import Empty
from chromadb.types import Collection, Metadata, Segment, SegmentScope


class GrpcMockSysDB(SysDBServicer, Component):
    """A mock sysdb implementation that can be used for testing the grpc client. It stores
    state in simple python data structures instead of a database."""

    _server: grpc.Server
    _server_port: int
    _segments: Dict[str, Segment] = {}
    _collection_to_segments: Dict[str, List[str]] = {}
    _tenants_to_databases_to_collections: Dict[
        str, Dict[str, Dict[str, Collection]]
    ] = {}
    _tenants_to_database_to_id: Dict[str, Dict[str, UUID]] = {}

    def __init__(self, system: System):
        self._server_port = system.settings.require("chroma_server_grpc_port")
        return super().__init__(system)

    @overrides
    def start(self) -> None:
        self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
        add_SysDBServicer_to_server(self, self._server)  # type: ignore
        self._server.add_insecure_port(f"[::]:{self._server_port}")
        self._server.start()
        return super().start()

    @overrides
    def stop(self) -> None:
        self._server.stop(None)
        return super().stop()

    @overrides
    def reset_state(self) -> None:
        self._segments = {}
        self._tenants_to_databases_to_collections = {}
        # Create defaults
        self._tenants_to_databases_to_collections[DEFAULT_TENANT] = {}
        self._tenants_to_databases_to_collections[DEFAULT_TENANT][DEFAULT_DATABASE] = {}
        self._tenants_to_database_to_id[DEFAULT_TENANT] = {}
        self._tenants_to_database_to_id[DEFAULT_TENANT][DEFAULT_DATABASE] = UUID(int=0)
        return super().reset_state()

    @overrides(check_signature=False)
    def CreateDatabase(
        self, request: CreateDatabaseRequest, context: grpc.ServicerContext
    ) -> CreateDatabaseResponse:
        tenant = request.tenant
        database = request.name
        if tenant not in self._tenants_to_databases_to_collections:
            context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
        if database in self._tenants_to_databases_to_collections[tenant]:
            context.abort(
                grpc.StatusCode.ALREADY_EXISTS, f"Database {database} already exists"
            )
        self._tenants_to_databases_to_collections[tenant][database] = {}
        self._tenants_to_database_to_id[tenant][database] = UUID(hex=request.id)
        return CreateDatabaseResponse()

    @overrides(check_signature=False)
    def GetDatabase(
        self, request: GetDatabaseRequest, context: grpc.ServicerContext
    ) -> GetDatabaseResponse:
        tenant = request.tenant
        database = request.name
        if tenant not in self._tenants_to_databases_to_collections:
            context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
        if database not in self._tenants_to_databases_to_collections[tenant]:
            context.abort(grpc.StatusCode.NOT_FOUND, f"Database {database} not found")
        id = self._tenants_to_database_to_id[tenant][database]
        return GetDatabaseResponse(
            database=proto.Database(id=id.hex, name=database, tenant=tenant),
        )

    @overrides(check_signature=False)
    def CreateTenant(
        self, request: CreateTenantRequest, context: grpc.ServicerContext
    ) -> CreateTenantResponse:
        tenant = request.name
        if tenant in self._tenants_to_databases_to_collections:
            context.abort(
                grpc.StatusCode.ALREADY_EXISTS, f"Tenant {tenant} already exists"
            )
        self._tenants_to_databases_to_collections[tenant] = {}
        self._tenants_to_database_to_id[tenant] = {}
        return CreateTenantResponse()

    @overrides(check_signature=False)
    def GetTenant(
        self, request: GetTenantRequest, context: grpc.ServicerContext
    ) -> GetTenantResponse:
        tenant = request.name
        if tenant not in self._tenants_to_databases_to_collections:
            context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
        return GetTenantResponse(
            tenant=proto.Tenant(name=tenant),
        )

    # We are forced to use check_signature=False because the generated proto code
    # does not have type annotations for the request and response objects.
    # TODO: investigate generating types for the request and response objects
    @overrides(check_signature=False)
    def CreateSegment(
        self, request: CreateSegmentRequest, context: grpc.ServicerContext
    ) -> CreateSegmentResponse:
        segment = from_proto_segment(request.segment)
        return self.CreateSegmentHelper(segment, context)

    def CreateSegmentHelper(self, segment: Segment, context: grpc.ServicerContext) -> CreateSegmentResponse:
        if segment["id"].hex in self._segments:
            context.abort(
                grpc.StatusCode.ALREADY_EXISTS,
                f"Segment {segment['id']} already exists",
            )
        self._segments[segment["id"].hex] = segment
        return CreateSegmentResponse()

    @overrides(check_signature=False)
    def DeleteSegment(
        self, request: DeleteSegmentRequest, context: grpc.ServicerContext
    ) -> DeleteSegmentResponse:
        id_to_delete = request.id
        if id_to_delete in self._segments:
            del self._segments[id_to_delete]
            return DeleteSegmentResponse()
        else:
            context.abort(
                grpc.StatusCode.NOT_FOUND, f"Segment {id_to_delete} not found"
            )

    @overrides(check_signature=False)
    def GetSegments(
        self, request: GetSegmentsRequest, context: grpc.ServicerContext
    ) -> GetSegmentsResponse:
        target_id = UUID(hex=request.id) if request.HasField("id") else None
        target_type = request.type if request.HasField("type") else None
        target_scope = (
            from_proto_segment_scope(request.scope)
            if request.HasField("scope")
            else None
        )
        target_collection = UUID(hex=request.collection)

        found_segments = []
        for segment in self._segments.values():
            if target_id and segment["id"] != target_id:
                continue
            if target_type and segment["type"] != target_type:
                continue
            if target_scope and segment["scope"] != target_scope:
                continue
            if target_collection and segment["collection"] != target_collection:
                continue
            found_segments.append(segment)
        return GetSegmentsResponse(
            segments=[to_proto_segment(segment) for segment in found_segments]
        )

    @overrides(check_signature=False)
    def UpdateSegment(
        self, request: UpdateSegmentRequest, context: grpc.ServicerContext
    ) -> UpdateSegmentResponse:
        id_to_update = UUID(request.id)
        if id_to_update.hex not in self._segments:
            context.abort(
                grpc.StatusCode.NOT_FOUND, f"Segment {id_to_update} not found"
            )
        else:
            segment = self._segments[id_to_update.hex]
            if request.HasField("metadata"):
                target = cast(Dict[str, Any], segment["metadata"])
                if segment["metadata"] is None:
                    segment["metadata"] = {}
                self._merge_metadata(target, request.metadata)
            if request.HasField("reset_metadata") and request.reset_metadata:
                segment["metadata"] = {}
            return UpdateSegmentResponse()

    @overrides(check_signature=False)
    def CreateCollection(
        self, request: CreateCollectionRequest, context: grpc.ServicerContext
    ) -> CreateCollectionResponse:
        collection_name = request.name
        tenant = request.tenant
        database = request.database
        if tenant not in self._tenants_to_databases_to_collections:
            context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
        if database not in self._tenants_to_databases_to_collections[tenant]:
            context.abort(grpc.StatusCode.NOT_FOUND, f"Database {database} not found")

        # Check if the collection already exists globally by id
        for (
            search_tenant,
            databases,
        ) in self._tenants_to_databases_to_collections.items():
            for search_database, search_collections in databases.items():
                if request.id in search_collections:
                    if (
                        search_tenant != request.tenant
                        or search_database != request.database
                    ):
                        context.abort(
                            grpc.StatusCode.ALREADY_EXISTS,
                            f"Collection {request.id} already exists in tenant {search_tenant} database {search_database}",
                        )
                    elif not request.get_or_create:
                        # If the id exists for this tenant and database, and we are not doing a get_or_create, then
                        # we should return an already exists error
                        context.abort(
                            grpc.StatusCode.ALREADY_EXISTS,
                            f"Collection {request.id} already exists in tenant {search_tenant} database {search_database}",
                        )

        # Check if the collection already exists in this database by name
        collections = self._tenants_to_databases_to_collections[tenant][database]
        matches = [c for c in collections.values() if c["name"] == collection_name]
        assert len(matches) <= 1
        if len(matches) > 0:
            if request.get_or_create:
                existing_collection = matches[0]
                return CreateCollectionResponse(
                    collection=to_proto_collection(existing_collection),
                    created=False,
                )
            context.abort(
                grpc.StatusCode.ALREADY_EXISTS,
                f"Collection {collection_name} already exists",
            )

        configuration = CollectionConfigurationInternal.from_json_str(
            request.configuration_json_str
        )

        id = UUID(hex=request.id)
        new_collection = Collection(
            id=id,
            name=request.name,
            configuration=configuration,
            metadata=from_proto_metadata(request.metadata),
            dimension=request.dimension,
            database=database,
            tenant=tenant,
            version=0,
        )
        
        # Check that segments are unique and do not already exist
        # Keep a track of the segments that are being added
        segments_added = []
        # Create segments for the collection
        for segment_proto in request.segments:
            segment = from_proto_segment(segment_proto)
            if segment["id"].hex in self._segments:
                # Remove the already added segment since we need to roll back
                for s in segments_added:
                    self.DeleteSegment(DeleteSegmentRequest(id=s), context)
                context.abort(grpc.StatusCode.ALREADY_EXISTS, f"Segment {segment['id']} already exists")
            self.CreateSegmentHelper(segment, context)
            segments_added.append(segment["id"].hex)

        collections[request.id] = new_collection
        collection_unique_key = f"{tenant}:{database}:{request.id}"
        self._collection_to_segments[collection_unique_key] = segments_added
        return CreateCollectionResponse(
            collection=to_proto_collection(new_collection),
            created=True,
        )

    @overrides(check_signature=False)
    def DeleteCollection(
        self, request: DeleteCollectionRequest, context: grpc.ServicerContext
    ) -> DeleteCollectionResponse:
        collection_id = request.id
        tenant = request.tenant
        database = request.database
        if tenant not in self._tenants_to_databases_to_collections:
            context.abort(grpc.StatusCode.NOT_FOUND, f"Tenant {tenant} not found")
        if database not in self._tenants_to_databases_to_collections[tenant]:
            context.abort(grpc.StatusCode.NOT_FOUND, f"Database {database} not found")
        collections = self._tenants_to_databases_to_collections[tenant][database]
        if collection_id in collections:
            del collections[collection_id]
            collection_unique_key = f"{tenant}:{database}:{collection_id}"
            segment_ids = self._collection_to_segments[collection_unique_key]
            if segment_ids: # Delete segments if provided.
                for segment_id in segment_ids:
                    del self._segments[segment_id]
            return DeleteCollectionResponse()
        else:
            context.abort(
                grpc.StatusCode.NOT_FOUND, f"Collection {collection_id} not found"
            )

    @overrides(check_signature=False)
    def GetCollections(
        self, request: GetCollectionsRequest, context: grpc.ServicerContext
    ) -> GetCollectionsResponse:
        target_id = UUID(hex=request.id) if request.HasField("id") else None
        target_name = request.name if request.HasField("name") else None

        allCollections = {}
        for tenant, databases in self._tenants_to_databases_to_collections.items():
            for database, collections in databases.items():
                if request.tenant != "" and tenant != request.tenant:
                    continue
                if request.database != "" and database != request.database:
                    continue
                allCollections.update(collections)
                print(
                    f"Tenant: {tenant}, Database: {database}, Collections: {collections}"
                )
        found_collections = []
        for collection in allCollections.values():
            if target_id and collection["id"] != target_id:
                continue
            if target_name and collection["name"] != target_name:
                continue
            found_collections.append(collection)
        return GetCollectionsResponse(
            collections=[
                to_proto_collection(collection) for collection in found_collections
            ]
        )

    @overrides(check_signature=False)
    def GetCollectionWithSegments(
        self, request: GetCollectionWithSegmentsRequest, context: grpc.ServicerContext
    ) -> GetCollectionWithSegmentsResponse:
        allCollections = {}
        for tenant, databases in self._tenants_to_databases_to_collections.items():
            for database, collections in databases.items():
                allCollections.update(collections)
                print(
                    f"Tenant: {tenant}, Database: {database}, Collections: {collections}"
                )
        collection = allCollections.get(request.id, None)
        if collection is None:
            context.abort(grpc.StatusCode.NOT_FOUND, f"Collection with id {request.id} not found")
        collection_unique_key = f"{collection.tenant}:{collection.database}:{request.id}"
        segments = [self._segments[id] for id in self._collection_to_segments[collection_unique_key]]
        if {segment["scope"] for segment in segments} != {SegmentScope.METADATA, SegmentScope.RECORD, SegmentScope.VECTOR}:
            context.abort(grpc.StatusCode.INTERNAL, f"Incomplete segments for collection {collection}: {segments}")
            
        return GetCollectionWithSegmentsResponse(
            collection=to_proto_collection(collection),
            segments=[to_proto_segment(segment) for segment in segments]
        )

    @overrides(check_signature=False)
    def UpdateCollection(
        self, request: UpdateCollectionRequest, context: grpc.ServicerContext
    ) -> UpdateCollectionResponse:
        id_to_update = UUID(request.id)
        # Find the collection with this id
        collections = {}
        for tenant, databases in self._tenants_to_databases_to_collections.items():
            for database, maybe_collections in databases.items():
                if id_to_update.hex in maybe_collections:
                    collections = maybe_collections

        if id_to_update.hex not in collections:
            context.abort(
                grpc.StatusCode.NOT_FOUND, f"Collection {id_to_update} not found"
            )
        else:
            collection = collections[id_to_update.hex]
            if request.HasField("name"):
                collection["name"] = request.name
            if request.HasField("dimension"):
                collection["dimension"] = request.dimension
            if request.HasField("metadata"):
                # TODO: IN SysDB SQlite we have technical debt where we
                # replace the entire metadata dict with the new one. We should
                # fix that by merging it. For now we just do the same thing here

                update_metadata = from_proto_update_metadata(request.metadata)
                cleaned_metadata = None
                if update_metadata is not None:
                    cleaned_metadata = {}
                    for key, value in update_metadata.items():
                        if value is not None:
                            cleaned_metadata[key] = value

                collection["metadata"] = cleaned_metadata
            elif request.HasField("reset_metadata"):
                if request.reset_metadata:
                    collection["metadata"] = {}

            return UpdateCollectionResponse()

    @overrides(check_signature=False)
    def ResetState(
        self, request: Empty, context: grpc.ServicerContext
    ) -> ResetStateResponse:
        self.reset_state()
        return ResetStateResponse()

    def _merge_metadata(self, target: Metadata, source: proto.UpdateMetadata) -> None:
        target_metadata = cast(Dict[str, Any], target)
        source_metadata = cast(Dict[str, Any], from_proto_update_metadata(source))
        target_metadata.update(source_metadata)
        # If a key has a None value, remove it from the metadata
        for key, value in source_metadata.items():
            if value is None and key in target:
                del target_metadata[key]
