from typing import Dict, Optional, Sequence, Tuple, TypedDict, Union, cast
from uuid import UUID

import numpy as np
from numpy.typing import NDArray

import chromadb.proto.chroma_pb2 as chroma_pb
import chromadb.proto.query_executor_pb2 as query_pb
from chromadb.api.configuration import CollectionConfigurationInternal
from chromadb.api.types import Embedding, Where, WhereDocument
from chromadb.execution.expression.operator import (
    KNN,
    Filter,
    Limit,
    Projection,
    Scan,
)
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
from chromadb.types import (
    Collection,
    LogRecord,
    Metadata,
    Operation,
    OperationRecord,
    RequestVersionContext,
    ScalarEncoding,
    Segment,
    SegmentScope,
    SeqId,
    UpdateMetadata,
    Vector,
    VectorEmbeddingRecord,
    VectorQueryResult,
)


class ProjectionRecord(TypedDict):
    """Dict form of a single projected record.

    Instances are built by ``from_proto_projection_record`` from a
    ``query_pb.ProjectionRecord`` message.
    """

    # Record identifier, copied from the proto message.
    id: str
    # Document text; None when the proto document is empty/falsy.
    document: Optional[str]
    # Decoded embedding vector, or None.
    embedding: Optional[Vector]
    # Decoded metadata mapping, or None when the proto metadata map is empty.
    metadata: Optional[Metadata]


class KNNProjectionRecord(TypedDict):
    """Dict form of a single KNN result: a projected record plus its distance.

    Instances are built by ``from_proto_knn_projection_record``.
    """

    # The projected record itself.
    record: ProjectionRecord
    # Distance value copied from the proto record; typed as optional.
    distance: Optional[float]


# TODO: Unit tests for this file, handling optional states etc
def to_proto_vector(vector: Vector, encoding: ScalarEncoding) -> chroma_pb.Vector:
    """Serialize a vector into a ``chroma_pb.Vector`` message.

    The values are converted to a NumPy array of the dtype selected by
    ``encoding`` and stored as raw bytes; ``dimension`` is the number of
    elements in that array.

    Args:
        vector: The vector values to serialize.
        encoding: Scalar encoding to use (``FLOAT32`` or ``INT32``).

    Returns:
        The populated ``chroma_pb.Vector`` message.

    Raises:
        ValueError: If ``encoding`` is not a supported scalar encoding.
    """
    if encoding == ScalarEncoding.FLOAT32:
        as_array = np.array(vector, dtype=np.float32)
        proto_encoding = chroma_pb.ScalarEncoding.FLOAT32
    elif encoding == ScalarEncoding.INT32:
        as_array = np.array(vector, dtype=np.int32)
        proto_encoding = chroma_pb.ScalarEncoding.INT32
    else:
        # Adjacent f-strings are used instead of a backslash continuation so
        # the source indentation does not leak into the message.
        raise ValueError(
            f"Unknown encoding {encoding}, expected one of "
            f"{ScalarEncoding.FLOAT32} or {ScalarEncoding.INT32}"
        )

    # Take the dimension from the converted array so the reported size always
    # matches the serialized bytes, even if a plain sequence was passed in.
    return chroma_pb.Vector(
        dimension=as_array.size,
        vector=as_array.tobytes(),
        encoding=proto_encoding,
    )


def from_proto_vector(vector: chroma_pb.Vector) -> Tuple[Embedding, ScalarEncoding]:
    """Decode a ``chroma_pb.Vector`` message into an array and its encoding.

    Args:
        vector: The proto vector whose ``vector`` bytes are interpreted
            according to its ``encoding`` field.

    Returns:
        A ``(array, encoding)`` tuple, where ``array`` is created with
        ``np.frombuffer`` over the message bytes.

    Raises:
        ValueError: If the proto encoding is neither ``FLOAT32`` nor ``INT32``.
    """
    encoding = vector.encoding
    as_array: Union[NDArray[np.int32], NDArray[np.float32]]
    if encoding == chroma_pb.ScalarEncoding.FLOAT32:
        as_array = np.frombuffer(vector.vector, dtype=np.float32)
        out_encoding = ScalarEncoding.FLOAT32
    elif encoding == chroma_pb.ScalarEncoding.INT32:
        as_array = np.frombuffer(vector.vector, dtype=np.int32)
        out_encoding = ScalarEncoding.INT32
    else:
        # Adjacent f-strings are used instead of a backslash continuation so
        # the source indentation does not leak into the message.
        raise ValueError(
            f"Unknown encoding {encoding}, expected one of "
            f"{chroma_pb.ScalarEncoding.FLOAT32} or {chroma_pb.ScalarEncoding.INT32}"
        )

    return (as_array, out_encoding)


def from_proto_operation(operation: chroma_pb.Operation) -> Operation:
    """Map a proto ``Operation`` enum value to the Python ``Operation`` enum.

    Raises:
        RuntimeError: If ``operation`` matches none of ADD, UPDATE, UPSERT
            or DELETE.
    """
    known_operations = (
        (chroma_pb.Operation.ADD, Operation.ADD),
        (chroma_pb.Operation.UPDATE, Operation.UPDATE),
        (chroma_pb.Operation.UPSERT, Operation.UPSERT),
        (chroma_pb.Operation.DELETE, Operation.DELETE),
    )
    for proto_value, converted in known_operations:
        if operation == proto_value:
            return converted
    # TODO: full error
    raise RuntimeError(f"Unknown operation {operation}")


def from_proto_metadata(metadata: chroma_pb.UpdateMetadata) -> Optional[Metadata]:
    """Decode proto metadata into a ``Metadata`` dict.

    Delegates to ``_from_proto_metadata_handle_none`` with ``is_update=False``,
    so an entry without a value raises instead of becoming ``None``.
    """
    converted = _from_proto_metadata_handle_none(metadata, False)
    return cast(Optional[Metadata], converted)


def from_proto_update_metadata(
    metadata: chroma_pb.UpdateMetadata,
) -> Optional[UpdateMetadata]:
    """Decode proto metadata into an ``UpdateMetadata`` dict.

    Delegates to ``_from_proto_metadata_handle_none`` with ``is_update=True``,
    so an entry without a value is decoded as ``None``.
    """
    converted = _from_proto_metadata_handle_none(metadata, True)
    return cast(Optional[UpdateMetadata], converted)


def _from_proto_metadata_handle_none(
    metadata: chroma_pb.UpdateMetadata, is_update: bool
) -> Optional[Union[UpdateMetadata, Metadata]]:
    """Decode a proto metadata map into a plain dict.

    Args:
        metadata: Proto message whose ``metadata`` map is decoded.
        is_update: When true, an entry with no value field set is decoded as
            ``None``; when false, such an entry raises.

    Returns:
        ``None`` if the proto map is empty, otherwise the decoded dict.

    Raises:
        ValueError: If an entry has no value set and ``is_update`` is false.
    """
    entries = metadata.metadata
    if not entries:
        return None

    # Checked in this order; the first field that is set provides the value.
    value_fields = ("bool_value", "string_value", "int_value", "float_value")
    decoded: Dict[str, Union[str, int, float, bool, None]] = {}
    for key, proto_value in entries.items():
        for field_name in value_fields:
            if proto_value.HasField(field_name):
                decoded[key] = getattr(proto_value, field_name)
                break
        else:
            # No value field is set on this entry.
            if not is_update:
                raise ValueError(f"Metadata key {key} value cannot be None")
            decoded[key] = None
    return decoded


def to_proto_update_metadata(metadata: UpdateMetadata) -> chroma_pb.UpdateMetadata:
    """Encode an ``UpdateMetadata`` dict as a ``chroma_pb.UpdateMetadata``.

    Each value is encoded with ``to_proto_metadata_update_value``.
    """
    proto_values = {
        key: to_proto_metadata_update_value(value) for key, value in metadata.items()
    }
    return chroma_pb.UpdateMetadata(metadata=proto_values)


def from_proto_submit(
    operation_record: chroma_pb.OperationRecord, seq_id: SeqId
) -> LogRecord:
    """Build a ``LogRecord`` from a proto operation record and a sequence id.

    Args:
        operation_record: The proto record to decode.
        seq_id: Stored as the ``log_offset`` of the resulting log record.

    Returns:
        A ``LogRecord`` wrapping the decoded ``OperationRecord``.
    """
    vector, scalar_encoding = from_proto_vector(operation_record.vector)
    decoded = OperationRecord(
        id=operation_record.id,
        embedding=vector,
        encoding=scalar_encoding,
        metadata=from_proto_update_metadata(operation_record.metadata),
        operation=from_proto_operation(operation_record.operation),
    )
    return LogRecord(log_offset=seq_id, record=decoded)


def from_proto_segment(segment: chroma_pb.Segment) -> Segment:
    """Decode a ``chroma_pb.Segment`` message into a ``Segment``.

    The ``id`` and ``collection`` fields are parsed as hex UUIDs, metadata is
    ``None`` when the proto field is unset, and each ``file_paths`` entry is
    converted to a plain list.
    """
    # Conversions run in the same order as the Segment fields below.
    segment_id = UUID(hex=segment.id)
    scope = from_proto_segment_scope(segment.scope)
    collection_id = UUID(hex=segment.collection)

    metadata = None
    if segment.HasField("metadata"):
        metadata = from_proto_metadata(segment.metadata)

    file_paths = {
        name: list(entry.paths) for name, entry in segment.file_paths.items()
    }
    return Segment(
        id=segment_id,
        type=segment.type,
        scope=scope,
        collection=collection_id,
        metadata=metadata,
        file_paths=file_paths,
    )


def to_proto_segment(segment: Segment) -> chroma_pb.Segment:
    """Encode a ``Segment`` as a ``chroma_pb.Segment`` message.

    UUIDs are written as hex strings, metadata is omitted when it is ``None``,
    and each ``file_paths`` entry is wrapped in a ``chroma_pb.FilePaths``.
    """
    # Conversions run in the same order as the message fields below.
    segment_id = segment["id"].hex
    segment_type = segment["type"]
    scope = to_proto_segment_scope(segment["scope"])
    collection_id = segment["collection"].hex

    raw_metadata = segment["metadata"]
    metadata = None
    if raw_metadata is not None:
        metadata = to_proto_update_metadata(raw_metadata)

    file_paths = {
        name: chroma_pb.FilePaths(paths=paths)
        for name, paths in segment["file_paths"].items()
    }
    return chroma_pb.Segment(
        id=segment_id,
        type=segment_type,
        scope=scope,
        collection=collection_id,
        metadata=metadata,
        file_paths=file_paths,
    )


def from_proto_segment_scope(segment_scope: chroma_pb.SegmentScope) -> SegmentScope:
    """Map a proto ``SegmentScope`` value to the Python ``SegmentScope`` enum.

    Raises:
        RuntimeError: If the value is not VECTOR, METADATA or RECORD.
    """
    known_scopes = (
        (chroma_pb.SegmentScope.VECTOR, SegmentScope.VECTOR),
        (chroma_pb.SegmentScope.METADATA, SegmentScope.METADATA),
        (chroma_pb.SegmentScope.RECORD, SegmentScope.RECORD),
    )
    for proto_value, converted in known_scopes:
        if segment_scope == proto_value:
            return converted
    raise RuntimeError(f"Unknown segment scope {segment_scope}")


def to_proto_segment_scope(segment_scope: SegmentScope) -> chroma_pb.SegmentScope:
    """Map a Python ``SegmentScope`` to the proto ``SegmentScope`` value.

    Raises:
        RuntimeError: If the value is not VECTOR, METADATA or RECORD.
    """
    known_scopes = (
        (SegmentScope.VECTOR, chroma_pb.SegmentScope.VECTOR),
        (SegmentScope.METADATA, chroma_pb.SegmentScope.METADATA),
        (SegmentScope.RECORD, chroma_pb.SegmentScope.RECORD),
    )
    for python_value, converted in known_scopes:
        if segment_scope == python_value:
            return converted
    raise RuntimeError(f"Unknown segment scope {segment_scope}")


def to_proto_metadata_update_value(
    value: Union[str, int, float, bool, None]
) -> chroma_pb.UpdateMetadataValue:
    """Encode one metadata value as a ``chroma_pb.UpdateMetadataValue``.

    Args:
        value: A ``bool``, ``str``, ``int`` or ``float`` to store, or ``None``
            to produce a message with no value field set.

    Returns:
        The proto value with the field matching the Python type populated.

    Raises:
        ValueError: If ``value`` is of any other type.
    """
    # Be careful with the order here. Since bools are a subtype of int in python,
    # isinstance(value, bool) and isinstance(value, int) both return true
    # for a value of bool type.
    if isinstance(value, bool):
        return chroma_pb.UpdateMetadataValue(bool_value=value)
    elif isinstance(value, str):
        return chroma_pb.UpdateMetadataValue(string_value=value)
    elif isinstance(value, int):
        return chroma_pb.UpdateMetadataValue(int_value=value)
    elif isinstance(value, float):
        return chroma_pb.UpdateMetadataValue(float_value=value)
    # None is used to delete the metadata key.
    elif value is None:
        return chroma_pb.UpdateMetadataValue()
    else:
        # The message lists every accepted type (bool included) and is built
        # from adjacent f-strings so no source indentation leaks into it.
        raise ValueError(
            f"Unknown metadata value type {type(value)}, expected one of "
            f"str, int, float, bool, or None"
        )


def from_proto_collection(collection: chroma_pb.Collection) -> Collection:
    """Decode a ``chroma_pb.Collection`` message into a ``Collection``.

    The id is parsed as a hex UUID and the configuration is rebuilt from its
    JSON string. Metadata is ``None`` when the proto field is unset, and
    dimension is ``None`` when the proto field is unset or falsy.
    """
    # Conversions run in the same order as the Collection fields below.
    collection_id = UUID(hex=collection.id)
    configuration = CollectionConfigurationInternal.from_json_str(
        collection.configuration_json_str
    )

    metadata = None
    if collection.HasField("metadata"):
        metadata = from_proto_metadata(collection.metadata)

    dimension = None
    if collection.HasField("dimension") and collection.dimension:
        dimension = collection.dimension

    return Collection(
        id=collection_id,
        name=collection.name,
        configuration=configuration,
        metadata=metadata,
        dimension=dimension,
        database=collection.database,
        tenant=collection.tenant,
        version=collection.version,
        log_position=collection.log_position,
    )


def to_proto_collection(collection: Collection) -> chroma_pb.Collection:
    """Encode a ``Collection`` as a ``chroma_pb.Collection`` message.

    The id is written as a hex string, the configuration as its JSON string,
    and metadata is omitted when it is ``None``.
    """
    # Conversions run in the same order as the message fields below.
    collection_id = collection["id"].hex
    name = collection["name"]
    configuration_json = collection.get_configuration().to_json_str()

    raw_metadata = collection["metadata"]
    metadata = None
    if raw_metadata is not None:
        metadata = to_proto_update_metadata(raw_metadata)

    return chroma_pb.Collection(
        id=collection_id,
        name=name,
        configuration_json_str=configuration_json,
        metadata=metadata,
        dimension=collection["dimension"],
        tenant=collection["tenant"],
        database=collection["database"],
        log_position=collection["log_position"],
        version=collection["version"],
    )


def to_proto_operation(operation: Operation) -> chroma_pb.Operation:
    """Map a Python ``Operation`` to the proto ``Operation`` enum value.

    Args:
        operation: One of ``Operation.ADD``, ``UPDATE``, ``UPSERT`` or
            ``DELETE``.

    Returns:
        The corresponding ``chroma_pb.Operation`` value.

    Raises:
        ValueError: If ``operation`` is none of the four known operations.
    """
    if operation == Operation.ADD:
        return chroma_pb.Operation.ADD
    elif operation == Operation.UPDATE:
        return chroma_pb.Operation.UPDATE
    elif operation == Operation.UPSERT:
        return chroma_pb.Operation.UPSERT
    elif operation == Operation.DELETE:
        return chroma_pb.Operation.DELETE
    else:
        # The message lists each supported operation exactly once and is built
        # from adjacent f-strings so no source indentation leaks into it.
        raise ValueError(
            f"Unknown operation {operation}, expected one of {Operation.ADD}, "
            f"{Operation.UPDATE}, {Operation.UPSERT}, or {Operation.DELETE}"
        )


def to_proto_submit(
    submit_record: OperationRecord,
) -> chroma_pb.OperationRecord:
    """Encode an ``OperationRecord`` as a ``chroma_pb.OperationRecord``.

    The vector is only encoded when both the embedding and its encoding are
    present; metadata is only encoded when it is not ``None``.
    """
    has_vector = (
        submit_record["embedding"] is not None
        and submit_record["encoding"] is not None
    )
    proto_vector = (
        to_proto_vector(submit_record["embedding"], submit_record["encoding"])
        if has_vector
        else None
    )

    raw_metadata = submit_record["metadata"]
    proto_metadata = (
        to_proto_update_metadata(raw_metadata) if raw_metadata is not None else None
    )

    return chroma_pb.OperationRecord(
        id=submit_record["id"],
        vector=proto_vector,
        metadata=proto_metadata,
        operation=to_proto_operation(submit_record["operation"]),
    )


def from_proto_vector_embedding_record(
    embedding_record: chroma_pb.VectorEmbeddingRecord,
) -> VectorEmbeddingRecord:
    """Decode a proto ``VectorEmbeddingRecord``; the scalar encoding is dropped."""
    record_id = embedding_record.id
    decoded_vector, _ = from_proto_vector(embedding_record.vector)
    return VectorEmbeddingRecord(id=record_id, embedding=decoded_vector)


def to_proto_vector_embedding_record(
    embedding_record: VectorEmbeddingRecord,
    encoding: ScalarEncoding,
) -> chroma_pb.VectorEmbeddingRecord:
    """Encode a ``VectorEmbeddingRecord`` using the given scalar encoding."""
    record_id = embedding_record["id"]
    proto_vector = to_proto_vector(embedding_record["embedding"], encoding)
    return chroma_pb.VectorEmbeddingRecord(id=record_id, vector=proto_vector)


def from_proto_vector_query_result(
    vector_query_result: chroma_pb.VectorQueryResult,
) -> VectorQueryResult:
    """Decode a proto ``VectorQueryResult``; the scalar encoding is dropped."""
    result_id = vector_query_result.id
    distance = vector_query_result.distance
    decoded_vector, _ = from_proto_vector(vector_query_result.vector)
    return VectorQueryResult(
        id=result_id,
        distance=distance,
        embedding=decoded_vector,
    )


def from_proto_request_version_context(
    request_version_context: chroma_pb.RequestVersionContext,
) -> RequestVersionContext:
    """Copy collection version and log position out of the proto message."""
    version = request_version_context.collection_version
    position = request_version_context.log_position
    return RequestVersionContext(collection_version=version, log_position=position)


def to_proto_request_version_context(
    request_version_context: RequestVersionContext,
) -> chroma_pb.RequestVersionContext:
    """Copy collection version and log position into a proto message."""
    version = request_version_context["collection_version"]
    position = request_version_context["log_position"]
    return chroma_pb.RequestVersionContext(
        collection_version=version,
        log_position=position,
    )


def to_proto_where(where: Where) -> chroma_pb.Where:
    """Convert a ``Where`` filter dict into a ``chroma_pb.Where`` message.

    The dict must contain exactly one entry, which takes one of these forms:

    * ``{"$and": [...]}`` / ``{"$or": [...]}`` - each list element is
      converted recursively and stored as the message's ``children``.
    * ``{"key": value}`` - an implicit equality comparison against a str,
      bool, int or float value.
    * ``{"key": {"$operator": operand}}`` - an explicit comparison, where the
      operator is ``$in``/``$nin`` (list operand) or one of ``$eq``, ``$ne``,
      ``$gt``, ``$lt``, ``$gte``, ``$lte`` (single operand).

    Args:
        where: The filter expression to convert.

    Returns:
        The populated ``chroma_pb.Where`` message.

    Raises:
        ValueError: If the dict does not have exactly one entry, the key is
            not a str, an operand has an unsupported type, a list operand is
            empty or mixed-type, or an ordering operator is applied to a str
            or bool operand.
    """
    response = chroma_pb.Where()
    if len(where) != 1:
        raise ValueError(f"Expected where to have exactly one operator, got {where}")

    # The length check above means this loop body runs exactly once.
    for key, value in where.items():
        if not isinstance(key, str):
            raise ValueError(f"Expected where key to be a str, got {key}")

        if key == "$and" or key == "$or":
            if not isinstance(value, list):
                raise ValueError(
                    f"Expected where value for $and or $or to be a list of where expressions, got {value}"
                )
            # Recursively convert every nested expression.
            children: chroma_pb.WhereChildren = chroma_pb.WhereChildren(
                children=[to_proto_where(w) for w in value]
            )
            if key == "$and":
                children.operator = chroma_pb.BooleanOperator.AND
            else:
                children.operator = chroma_pb.BooleanOperator.OR

            response.children.CopyFrom(children)
            return response

        # At this point we know we're at a direct comparison. It can either
        # be of the form {"key": "value"} or {"key": {"$operator": "value"}}.

        dc = chroma_pb.DirectComparison()
        dc.key = key

        if not isinstance(value, dict):
            # {'key': 'value'} case
            # Exact type checks (``type(...) is``) are used rather than
            # isinstance, so a bool is never matched by the int branch.
            if type(value) is str:
                ssc = chroma_pb.SingleStringComparison()
                ssc.value = value
                ssc.comparator = chroma_pb.GenericComparator.EQ
                dc.single_string_operand.CopyFrom(ssc)
            elif type(value) is bool:
                sbc = chroma_pb.SingleBoolComparison()
                sbc.value = value
                sbc.comparator = chroma_pb.GenericComparator.EQ
                dc.single_bool_operand.CopyFrom(sbc)
            elif type(value) is int:
                sic = chroma_pb.SingleIntComparison()
                sic.value = value
                sic.generic_comparator = chroma_pb.GenericComparator.EQ
                dc.single_int_operand.CopyFrom(sic)
            elif type(value) is float:
                sdc = chroma_pb.SingleDoubleComparison()
                sdc.value = value
                sdc.generic_comparator = chroma_pb.GenericComparator.EQ
                dc.single_double_operand.CopyFrom(sdc)
            else:
                raise ValueError(
                    f"Expected where value to be a string, int, or float, got {value}"
                )
        else:
            # {'key': {'$operator': operand}} case. Every entry of the operator
            # dict is applied in turn to the same DirectComparison.
            for operator, operand in value.items():
                if operator in ["$in", "$nin"]:
                    if not isinstance(operand, list):
                        raise ValueError(
                            f"Expected where value for $in or $nin to be a list of values, got {value}"
                        )
                    # Homogeneity is checked with isinstance against the type
                    # of the first element.
                    if len(operand) == 0 or not all(
                        isinstance(x, type(operand[0])) for x in operand
                    ):
                        raise ValueError(
                            f"Expected where operand value to be a non-empty list, and all values to be of the same type "
                            f"got {operand}"
                        )
                    list_operator = None
                    if operator == "$in":
                        list_operator = chroma_pb.ListOperator.IN
                    else:
                        list_operator = chroma_pb.ListOperator.NIN
                    # The exact type of the first element selects which list
                    # comparison message is used.
                    if type(operand[0]) is str:
                        slo = chroma_pb.StringListComparison()
                        for x in operand:
                            slo.values.extend([x])  # type: ignore
                        slo.list_operator = list_operator
                        dc.string_list_operand.CopyFrom(slo)
                    elif type(operand[0]) is bool:
                        blo = chroma_pb.BoolListComparison()
                        for x in operand:
                            blo.values.extend([x])  # type: ignore
                        blo.list_operator = list_operator
                        dc.bool_list_operand.CopyFrom(blo)
                    elif type(operand[0]) is int:
                        ilo = chroma_pb.IntListComparison()
                        for x in operand:
                            ilo.values.extend([x])  # type: ignore
                        ilo.list_operator = list_operator
                        dc.int_list_operand.CopyFrom(ilo)
                    elif type(operand[0]) is float:
                        dlo = chroma_pb.DoubleListComparison()
                        for x in operand:
                            dlo.values.extend([x])  # type: ignore
                        dlo.list_operator = list_operator
                        dc.double_list_operand.CopyFrom(dlo)
                    else:
                        raise ValueError(
                            f"Expected where operand value to be a list of strings, ints, or floats, got {operand}"
                        )
                elif operator in ["$eq", "$ne", "$gt", "$lt", "$gte", "$lte"]:
                    # Direct comparison to a single value.
                    # str and bool operands accept only $eq/$ne; int and float
                    # operands accept all six operators.
                    if type(operand) is str:
                        ssc = chroma_pb.SingleStringComparison()
                        ssc.value = operand
                        if operator == "$eq":
                            ssc.comparator = chroma_pb.GenericComparator.EQ
                        elif operator == "$ne":
                            ssc.comparator = chroma_pb.GenericComparator.NE
                        else:
                            raise ValueError(
                                f"Expected where operator to be $eq or $ne, got {operator}"
                            )
                        dc.single_string_operand.CopyFrom(ssc)
                    elif type(operand) is bool:
                        sbc = chroma_pb.SingleBoolComparison()
                        sbc.value = operand
                        if operator == "$eq":
                            sbc.comparator = chroma_pb.GenericComparator.EQ
                        elif operator == "$ne":
                            sbc.comparator = chroma_pb.GenericComparator.NE
                        else:
                            raise ValueError(
                                f"Expected where operator to be $eq or $ne, got {operator}"
                            )
                        dc.single_bool_operand.CopyFrom(sbc)
                    elif type(operand) is int:
                        # Equality operators set generic_comparator; ordering
                        # operators set number_comparator.
                        sic = chroma_pb.SingleIntComparison()
                        sic.value = operand
                        if operator == "$eq":
                            sic.generic_comparator = chroma_pb.GenericComparator.EQ
                        elif operator == "$ne":
                            sic.generic_comparator = chroma_pb.GenericComparator.NE
                        elif operator == "$gt":
                            sic.number_comparator = chroma_pb.NumberComparator.GT
                        elif operator == "$lt":
                            sic.number_comparator = chroma_pb.NumberComparator.LT
                        elif operator == "$gte":
                            sic.number_comparator = chroma_pb.NumberComparator.GTE
                        elif operator == "$lte":
                            sic.number_comparator = chroma_pb.NumberComparator.LTE
                        else:
                            raise ValueError(
                                f"Expected where operator to be one of $eq, $ne, $gt, $lt, $gte, $lte, got {operator}"
                            )
                        dc.single_int_operand.CopyFrom(sic)
                    elif type(operand) is float:
                        # Same comparator split as the int branch above.
                        sfc = chroma_pb.SingleDoubleComparison()
                        sfc.value = operand
                        if operator == "$eq":
                            sfc.generic_comparator = chroma_pb.GenericComparator.EQ
                        elif operator == "$ne":
                            sfc.generic_comparator = chroma_pb.GenericComparator.NE
                        elif operator == "$gt":
                            sfc.number_comparator = chroma_pb.NumberComparator.GT
                        elif operator == "$lt":
                            sfc.number_comparator = chroma_pb.NumberComparator.LT
                        elif operator == "$gte":
                            sfc.number_comparator = chroma_pb.NumberComparator.GTE
                        elif operator == "$lte":
                            sfc.number_comparator = chroma_pb.NumberComparator.LTE
                        else:
                            raise ValueError(
                                f"Expected where operator to be one of $eq, $ne, $gt, $lt, $gte, $lte, got {operator}"
                            )
                        dc.single_double_operand.CopyFrom(sfc)
                    else:
                        raise ValueError(
                            f"Expected where operand value to be a string, int, or float, got {operand}"
                        )
                else:
                    # This case should never happen, as we've already
                    # handled the case for direct comparisons.
                    # NOTE(review): an unrecognized operator is silently
                    # ignored here, leaving ``dc`` with only its key set;
                    # presumably callers validate operators first - confirm.
                    pass

        response.direct_comparison.CopyFrom(dc)
    return response


def to_proto_where_document(where_document: WhereDocument) -> chroma_pb.WhereDocument:
    """Convert a ``WhereDocument`` filter into a ``chroma_pb.WhereDocument``.

    The dict must hold exactly one entry: either ``$and``/``$or`` with a list
    of nested expressions (converted recursively), or ``$contains`` /
    ``$not_contains`` with a string operand.

    Raises:
        ValueError: If the dict does not have exactly one entry, a boolean
            operator's operand is not a list, a direct operand is not a str,
            or the operator is not recognized.
    """
    if len(where_document) != 1:
        raise ValueError(
            f"Expected where_document to have exactly one operator, got {where_document}"
        )

    proto = chroma_pb.WhereDocument()
    # The length check above guarantees there is exactly one pair to read.
    operator, operand = next(iter(where_document.items()))

    if operator in ("$and", "$or"):
        # Nested boolean expression: convert every child recursively.
        if not isinstance(operand, list):
            raise ValueError(
                f"Expected where_document value for $and or $or to be a list of where_document expressions, got {operand}"
            )
        nested: chroma_pb.WhereDocumentChildren = chroma_pb.WhereDocumentChildren(
            children=[to_proto_where_document(child) for child in operand]
        )
        nested.operator = (
            chroma_pb.BooleanOperator.AND
            if operator == "$and"
            else chroma_pb.BooleanOperator.OR
        )
        proto.children.CopyFrom(nested)
        return proto

    # Direct comparison of the document against a single string.
    if not isinstance(operand, str):
        raise ValueError(
            f"Expected where_document operand to be a string, got {operand}"
        )
    direct = chroma_pb.DirectWhereDocument()
    direct.document = operand
    if operator == "$contains":
        direct.operator = chroma_pb.WhereDocumentOperator.CONTAINS
    elif operator == "$not_contains":
        direct.operator = chroma_pb.WhereDocumentOperator.NOT_CONTAINS
    else:
        raise ValueError(
            f"Expected where_document operator to be one of $contains, $not_contains, got {operator}"
        )
    proto.direct.CopyFrom(direct)
    return proto


def to_proto_scan(scan: Scan) -> query_pb.ScanOperator:
    """Encode a ``Scan`` operator: its collection plus its three segments."""
    collection = to_proto_collection(scan.collection)
    knn_segment = to_proto_segment(scan.knn)
    metadata_segment = to_proto_segment(scan.metadata)
    record_segment = to_proto_segment(scan.record)
    return query_pb.ScanOperator(
        collection=collection,
        knn=knn_segment,
        metadata=metadata_segment,
        record=record_segment,
    )


def to_proto_filter(filter: Filter) -> query_pb.FilterOperator:
    """Encode a ``Filter`` operator as a ``query_pb.FilterOperator``.

    ``ids`` is set whenever ``user_ids`` is not ``None`` (an empty list is
    still encoded); ``where`` and ``where_document`` are set only when truthy.
    """
    user_ids = None
    if filter.user_ids is not None:
        user_ids = chroma_pb.UserIds(ids=filter.user_ids)

    where = None
    if filter.where:
        where = to_proto_where(filter.where)

    where_document = None
    if filter.where_document:
        where_document = to_proto_where_document(filter.where_document)

    return query_pb.FilterOperator(
        ids=user_ids,
        where=where,
        where_document=where_document,
    )


def to_proto_knn(knn: KNN) -> query_pb.KNNOperator:
    """Encode a ``KNN`` operator; every embedding is serialized as FLOAT32."""
    proto_embeddings = []
    for embedding in knn.embeddings:
        proto_embeddings.append(
            to_proto_vector(vector=embedding, encoding=ScalarEncoding.FLOAT32)
        )
    return query_pb.KNNOperator(embeddings=proto_embeddings, fetch=knn.fetch)


def to_proto_limit(limit: Limit) -> query_pb.LimitOperator:
    """Encode a ``Limit`` operator by copying its ``skip`` and ``fetch`` values."""
    skip, fetch = limit.skip, limit.fetch
    return query_pb.LimitOperator(skip=skip, fetch=fetch)


def to_proto_projection(projection: Projection) -> query_pb.ProjectionOperator:
    """Encode a ``Projection`` by copying its document/embedding/metadata values."""
    selected = {
        "document": projection.document,
        "embedding": projection.embedding,
        "metadata": projection.metadata,
    }
    return query_pb.ProjectionOperator(**selected)


def to_proto_knn_projection(projection: Projection) -> query_pb.KNNProjectionOperator:
    """Encode a ``Projection`` for KNN; ``projection.rank`` is sent as ``distance``."""
    base_projection = to_proto_projection(projection)
    return query_pb.KNNProjectionOperator(
        projection=base_projection,
        distance=projection.rank,
    )


def to_proto_count_plan(count: CountPlan) -> query_pb.CountPlan:
    """Encode a ``CountPlan``, which consists only of its scan operator."""
    scan_operator = to_proto_scan(count.scan)
    return query_pb.CountPlan(scan=scan_operator)


def from_proto_count_result(result: query_pb.CountResult) -> int:
    """Return the ``count`` value carried by a ``query_pb.CountResult``."""
    count = result.count
    return count


def to_proto_get_plan(get: GetPlan) -> query_pb.GetPlan:
    """Encode a ``GetPlan`` from its scan, filter, limit and projection."""
    scan_operator = to_proto_scan(get.scan)
    filter_operator = to_proto_filter(get.filter)
    limit_operator = to_proto_limit(get.limit)
    projection_operator = to_proto_projection(get.projection)
    return query_pb.GetPlan(
        scan=scan_operator,
        filter=filter_operator,
        limit=limit_operator,
        projection=projection_operator,
    )


def from_proto_projection_record(record: query_pb.ProjectionRecord) -> ProjectionRecord:
    """Decode a ``query_pb.ProjectionRecord`` into a ``ProjectionRecord`` dict.

    Args:
        record: The proto record to decode.

    Returns:
        A dict with the record id, plus ``document``, ``embedding`` and
        ``metadata`` set to ``None`` when the proto record has no value for
        them.
    """
    # Accessing a message-typed protobuf field never yields None (an unset
    # field reads as an empty default message), so presence has to be tested
    # with HasField. An ``is not None`` check would decode an unset embedding
    # into an empty array instead of None.
    embedding = (
        from_proto_vector(record.embedding)[0]
        if record.HasField("embedding")
        else None
    )
    return ProjectionRecord(
        id=record.id,
        document=record.document if record.document else None,
        embedding=embedding,
        metadata=from_proto_metadata(record.metadata),
    )


def from_proto_get_result(result: query_pb.GetResult) -> Sequence[ProjectionRecord]:
    """Decode every record of a ``query_pb.GetResult`` into a list."""
    decoded = []
    for proto_record in result.records:
        decoded.append(from_proto_projection_record(proto_record))
    return decoded


def to_proto_knn_plan(knn: KNNPlan) -> query_pb.KNNPlan:
    """Encode a ``KNNPlan`` from its scan, filter, knn and projection."""
    scan_operator = to_proto_scan(knn.scan)
    filter_operator = to_proto_filter(knn.filter)
    knn_operator = to_proto_knn(knn.knn)
    projection_operator = to_proto_knn_projection(knn.projection)
    return query_pb.KNNPlan(
        scan=scan_operator,
        filter=filter_operator,
        knn=knn_operator,
        projection=projection_operator,
    )


def from_proto_knn_projection_record(
    record: query_pb.KNNProjectionRecord,
) -> KNNProjectionRecord:
    """Decode a KNN proto record into its projected record and distance."""
    projected = from_proto_projection_record(record.record)
    return KNNProjectionRecord(record=projected, distance=record.distance)


def from_proto_knn_batch_result(
    results: query_pb.KNNBatchResult,
) -> Sequence[Sequence[KNNProjectionRecord]]:
    """Decode a KNN batch result into one list of records per inner result."""
    batches = []
    for single_result in results.results:
        decoded_records = []
        for proto_record in single_result.records:
            decoded_records.append(from_proto_knn_projection_record(proto_record))
        batches.append(decoded_records)
    return batches
