from typing import Optional, Sequence

from overrides import overrides

from chromadb.api.types import GetResult, IncludeEnum, Metadata, QueryResult
from chromadb.config import System
from chromadb.execution.executor.abstract import Executor
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
from chromadb.segment import MetadataReader, VectorReader
from chromadb.segment.impl.manager.local import LocalSegmentManager
from chromadb.types import Collection, VectorQuery, VectorQueryResult


def _clean_metadata(metadata: Optional[Metadata]) -> Optional[Metadata]:
    """Remove any chroma-specific metadata keys that the client shouldn't see from a metadata map."""
    if not metadata:
        return None
    result = {}
    for k, v in metadata.items():
        if not k.startswith("chroma:"):
            result[k] = v
    if len(result) == 0:
        return None
    return result


def _doc(metadata: Optional[Metadata]) -> Optional[str]:
    """Retrieve the document (if any) from a Metadata map"""

    if metadata and "chroma:document" in metadata:
        return str(metadata["chroma:document"])
    return None


def _uri(metadata: Optional[Metadata]) -> Optional[str]:
    """Retrieve the uri (if any) from a Metadata map"""

    if metadata and "chroma:uri" in metadata:
        return str(metadata["chroma:uri"])
    return None


class LocalExecutor(Executor):
    """Executor that evaluates plans against segments from a ``LocalSegmentManager``.

    Filtering, documents, URIs and metadata are read from the collection's
    ``MetadataReader`` segment; embeddings and nearest-neighbour results are
    read from its ``VectorReader`` segment.
    """

    _manager: LocalSegmentManager

    def __init__(self, system: System):
        """Initialise the executor and obtain the ``LocalSegmentManager`` via ``require``."""
        super().__init__(system)
        self._manager = self.require(LocalSegmentManager)

    @overrides
    def count(self, plan: CountPlan) -> int:
        """Return the record count reported by the collection's metadata segment.

        Args:
            plan: Count plan; only ``plan.scan.collection`` and
                ``plan.scan.version`` are read.
        """
        return self._metadata_segment(plan.scan.collection).count(plan.scan.version)

    @overrides
    def get(self, plan: GetPlan) -> GetResult:
        """Fetch records matching the plan's filter and project the requested fields.

        The metadata segment is queried with the plan's ``where`` /
        ``where_document`` / ``user_ids`` filter and ``limit`` window. Each of
        ``embeddings``, ``documents``, ``uris`` and ``metadatas`` is populated
        only when the matching ``plan.projection`` flag is set and is ``None``
        otherwise; ``included`` lists the populated fields.

        Args:
            plan: Get plan describing the scan target, filter, limit and projection.

        Returns:
            A ``GetResult`` whose ``ids`` follow the order returned by the
            metadata segment.
        """
        records = self._metadata_segment(plan.scan.collection).get_metadata(
            request_version_context=plan.scan.version,
            where=plan.filter.where,
            where_document=plan.filter.where_document,
            ids=plan.filter.user_ids,
            limit=plan.limit.fetch,
            offset=plan.limit.skip,
            include_metadata=True,
        )

        ids = [r["id"] for r in records]
        embeddings = None
        documents = None
        uris = None
        metadatas = None
        included = []

        if plan.projection.embedding:
            if ids:
                # NOTE(review): embeddings are taken in the order get_vectors
                # returns them; this assumes that order matches `ids` —
                # confirm against the VectorReader contract.
                vectors = self._vector_segment(plan.scan.collection).get_vectors(
                    ids=ids, request_version_context=plan.scan.version
                )
                embeddings = [v["embedding"] for v in vectors]
            else:
                # Nothing matched: skip the vector segment entirely.
                embeddings = []
            included.append(IncludeEnum.embeddings)

        if plan.projection.document:
            documents = [_doc(r["metadata"]) for r in records]
            included.append(IncludeEnum.documents)

        if plan.projection.uri:
            uris = [_uri(r["metadata"]) for r in records]
            included.append(IncludeEnum.uris)

        if plan.projection.metadata:
            metadatas = [_clean_metadata(r["metadata"]) for r in records]
            included.append(IncludeEnum.metadatas)

        # TODO: Fix typing
        return GetResult(
            ids=ids,
            embeddings=embeddings,
            documents=documents,  # type: ignore[typeddict-item]
            uris=uris,  # type: ignore[typeddict-item]
            data=None,
            metadatas=metadatas,  # type: ignore[typeddict-item]
            included=included,
        )

    @overrides
    def knn(self, plan: KNNPlan) -> QueryResult:
        """Run a nearest-neighbour query for each embedding in the plan.

        Steps:
          1. If any of ``user_ids`` / ``where`` / ``where_document`` is set,
             resolve the filter to a list of allowed ids via the metadata
             segment.
          2. Query the vector segment (restricted to the allowed ids, if any)
             unless the filter matched nothing.
          3. If documents, URIs or metadata were requested, look them up in
             the metadata segment for the ids that appear in the results.

        Args:
            plan: KNN plan describing the scan target, filter, query
                embeddings, ``k`` (``plan.knn.fetch``) and projection.

        Returns:
            A ``QueryResult`` with one inner list per query embedding. Fields
            not requested by ``plan.projection`` are ``None``; ``included``
            lists the populated fields.
        """
        prefiltered_ids = None
        if plan.filter.user_ids or plan.filter.where or plan.filter.where_document:
            records = self._metadata_segment(plan.scan.collection).get_metadata(
                request_version_context=plan.scan.version,
                where=plan.filter.where,
                where_document=plan.filter.where_document,
                ids=plan.filter.user_ids,
                limit=None,
                offset=0,
                include_metadata=False,
            )
            prefiltered_ids = [r["id"] for r in records]

        # Default: one empty result list per query embedding. Built with a
        # comprehension (not `[[]] * n`) so the inner lists are independent
        # objects rather than n references to the same list.
        knns: Sequence[Sequence[VectorQueryResult]] = [
            [] for _ in range(len(plan.knn.embeddings))
        ]

        # Query vectors only when the user did not specify a filter or when the filter
        # yields non-empty ids. Otherwise, the user specified a filter but it yields
        # no matching ids, in which case we can return an empty result.
        if prefiltered_ids is None or len(prefiltered_ids) > 0:
            query = VectorQuery(
                vectors=plan.knn.embeddings,
                k=plan.knn.fetch,
                allowed_ids=prefiltered_ids,
                include_embeddings=plan.projection.embedding,
                options=None,
                request_version_context=plan.scan.version,
            )
            knns = self._vector_segment(plan.scan.collection).query_vectors(query)

        ids = [[r["id"] for r in result] for result in knns]
        embeddings = None
        documents = None
        uris = None
        metadatas = None
        distances = None
        included = []

        if plan.projection.embedding:
            embeddings = [[r["embedding"] for r in result] for result in knns]
            included.append(IncludeEnum.embeddings)

        if plan.projection.rank:
            distances = [[r["distance"] for r in result] for result in knns]
            included.append(IncludeEnum.distances)

        if plan.projection.document or plan.projection.metadata or plan.projection.uri:
            # The same id may appear in the results of several query
            # embeddings; de-duplicate so each record is fetched once.
            unique_ids = list(
                {record_id for result in ids for record_id in result}
            )
            # With no result ids there is nothing to hydrate, so the metadata
            # segment is not queried at all.
            hydrated_records = (
                self._metadata_segment(plan.scan.collection).get_metadata(
                    request_version_context=plan.scan.version,
                    where=None,
                    where_document=None,
                    ids=unique_ids,
                    limit=None,
                    offset=0,
                    include_metadata=True,
                )
                if unique_ids
                else []
            )
            metadata_by_id = {r["id"]: r["metadata"] for r in hydrated_records}

            # Ids missing from the hydrated records resolve to None and are
            # handled by the helper functions.
            if plan.projection.document:
                documents = [
                    [_doc(metadata_by_id.get(record_id, None)) for record_id in result]
                    for result in ids
                ]
                included.append(IncludeEnum.documents)

            if plan.projection.uri:
                uris = [
                    [_uri(metadata_by_id.get(record_id, None)) for record_id in result]
                    for result in ids
                ]
                included.append(IncludeEnum.uris)

            if plan.projection.metadata:
                metadatas = [
                    [
                        _clean_metadata(metadata_by_id.get(record_id, None))
                        for record_id in result
                    ]
                    for result in ids
                ]
                included.append(IncludeEnum.metadatas)

        # TODO: Fix typing
        return QueryResult(
            ids=ids,
            embeddings=embeddings,  # type: ignore[typeddict-item]
            documents=documents,  # type: ignore[typeddict-item]
            uris=uris,  # type: ignore[typeddict-item]
            data=None,
            metadatas=metadatas,  # type: ignore[typeddict-item]
            distances=distances,
            included=included,
        )

    def _metadata_segment(self, collection: Collection) -> MetadataReader:
        """Return the ``MetadataReader`` segment the manager holds for ``collection.id``."""
        return self._manager.get_segment(collection.id, MetadataReader)

    def _vector_segment(self, collection: Collection) -> VectorReader:
        """Return the ``VectorReader`` segment the manager holds for ``collection.id``."""
        return self._manager.get_segment(collection.id, VectorReader)
