import uuid
from random import randint
from typing import cast, List, Any, Dict
import hypothesis
import numpy as np
import pytest
import hypothesis.strategies as st
from hypothesis import given, settings
from chromadb.api import ClientAPI
from chromadb.api.types import Embeddings, Metadatas
from chromadb.test.conftest import (
    reset,
    NOT_CLUSTER_ONLY,
    override_hypothesis_profile,
)
import chromadb.test.property.strategies as strategies
import chromadb.test.property.invariants as invariants
from chromadb.test.utils.wait_for_version_increase import wait_for_version_increase
from chromadb.utils.batch_utils import create_batches


# Shared collection strategy: it is passed both as the ``collection`` argument
# and into ``strategies.recordsets(...)`` in the tests below. Hypothesis'
# ``st.shared`` with a fixed key draws one value per example, so both uses see
# the same generated collection.
collection_st = st.shared(strategies.collections(with_hnsw_params=True), key="coll")


# Hypothesis tends to generate smaller values, so we explicitly segregate the
# tests into tiers: Small and Medium. Hypothesis also struggles to generate
# large record sets, so we explicitly create a large record set without using
# Hypothesis.
@given(
    collection=collection_st,
    record_set=strategies.recordsets(collection_st, min_size=1, max_size=500),
    should_compact=st.booleans(),
)
@settings(
    deadline=None,
    parent=override_hypothesis_profile(
        normal=hypothesis.settings(max_examples=500),
        fast=hypothesis.settings(max_examples=200),
    ),
)
def test_add_small(
    client: ClientAPI,
    collection: strategies.Collection,
    record_set: strategies.RecordSet,
    should_compact: bool,
) -> None:
    """Small tier: add 1 to 500 generated records via the shared add check."""
    _test_add(
        client=client,
        collection=collection,
        record_set=record_set,
        should_compact=should_compact,
    )


@given(
    collection=collection_st,
    record_set=strategies.recordsets(
        collection_st,
        min_size=250,
        max_size=500,
        num_unique_metadata=5,
        min_metadata_size=1,
        max_metadata_size=5,
    ),
    should_compact=st.booleans(),
)
@settings(
    deadline=None,
    parent=override_hypothesis_profile(
        normal=hypothesis.settings(max_examples=10),
        fast=hypothesis.settings(max_examples=5),
    ),
    suppress_health_check=[
        hypothesis.HealthCheck.too_slow,
        hypothesis.HealthCheck.data_too_large,
        hypothesis.HealthCheck.large_base_example,
        hypothesis.HealthCheck.function_scoped_fixture,
    ],
)
def test_add_medium(
    client: ClientAPI,
    collection: strategies.Collection,
    record_set: strategies.RecordSet,
    should_compact: bool,
) -> None:
    """Medium tier: add 250 to 500 generated records via the shared add check."""
    # Cluster tests send their results over grpc, which has a payload limit.
    # The vector reader returns a payload of dataset size, so an unbatched
    # ann_accuracy invariant breaks by default. Ask for the queries in the
    # ann_accuracy invariant to be batched instead.
    _test_add(
        client=client,
        collection=collection,
        record_set=record_set,
        should_compact=should_compact,
        batch_ann_accuracy=True,
    )


def _test_add(
    client: ClientAPI,
    collection: strategies.Collection,
    record_set: strategies.RecordSet,
    should_compact: bool,
    batch_ann_accuracy: bool = False,
) -> None:
    """Add ``record_set`` to a fresh collection and check the add invariants.

    Resets the client, creates the collection described by ``collection``,
    adds the records batch by batch, optionally waits for the collection
    version to increase, and then runs the ``count`` and ``ann_accuracy``
    invariants.

    Args:
        client: API client under test; it is reset first.
        collection: Generated collection description supplying the name,
            metadata and embedding function.
        record_set: Records to add.
        should_compact: Whether to wait for a version increase after adding.
            Only honoured when ``NOT_CLUSTER_ONLY`` is false and more than
            10 records were added.
        batch_ann_accuracy: When true, run ``ann_accuracy`` once per chunk of
            10 query indices instead of in a single call.
    """
    reset(client)

    # TODO: Generative embedding functions
    coll = client.create_collection(
        name=collection.name,
        metadata=collection.metadata,  # type: ignore
        embedding_function=collection.embedding_function,
    )
    initial_version = cast(int, coll.get_model()["version"])

    wrapped = invariants.wrap_all(record_set)
    expected = cast(strategies.RecordSet, wrapped)

    # TODO: The type of add() is incorrect as it does not allow for metadatas
    # like [{"a": 1}, None, {"a": 3}]
    batches = create_batches(
        api=client,
        ids=cast(List[str], record_set["ids"]),
        embeddings=cast(Embeddings, record_set["embeddings"]),
        metadatas=cast(Metadatas, record_set["metadatas"]),
        documents=cast(List[str], record_set["documents"]),
    )
    for batch in batches:
        coll.add(*batch)

    num_records = len(wrapped["ids"])

    # Only wait for compaction if the collection has some minimal size.
    wait_for_compaction = (
        (not NOT_CLUSTER_ONLY) and should_compact and num_records > 10
    )
    if wait_for_compaction:
        # Wait for the model to be updated
        wait_for_version_increase(client, collection.name, initial_version)

    invariants.count(coll, expected)
    n_results = max(1, num_records // 10)

    if not batch_ann_accuracy:
        invariants.ann_accuracy(
            coll,
            expected,
            n_results=n_results,
            embedding_function=collection.embedding_function,
        )
        return

    # Query in fixed-size chunks of indices; the last chunk may be shorter.
    batch_size = 10
    for start in range(0, num_records, batch_size):
        stop = min(start + batch_size, num_records)
        invariants.ann_accuracy(
            coll,
            expected,
            n_results=n_results,
            embedding_function=collection.embedding_function,
            query_indices=list(range(start, stop)),
        )


# Hypothesis struggles to generate large record sets so we explicitly create
# a large record set
def create_large_recordset(
    min_size: int = 45000,
    max_size: int = 50000,
) -> strategies.RecordSet:
    """Build a record set of random size without going through Hypothesis.

    The number of records is drawn with ``randint(min_size, max_size)``, so
    both bounds are inclusive. Each record gets a fresh UUID4 string id, the
    metadata ``{"some_key": "<index>"}``, the document ``"Document <index>"``
    and the constant embedding ``[1, 2, 3]``.

    Args:
        min_size: Smallest number of records to generate.
        max_size: Largest number of records to generate.

    Returns:
        A record set with ``ids``, ``embeddings``, ``metadatas`` and
        ``documents`` lists of equal length.
    """
    num_records = randint(min_size, max_size)
    indices = range(num_records)

    record_set: Dict[str, List[Any]] = {
        "ids": [str(uuid.uuid4()) for _ in indices],
        # One distinct list per record, so no two records share an object.
        "embeddings": cast(Embeddings, [[1, 2, 3] for _ in indices]),
        "metadatas": [{"some_key": f"{i}"} for i in indices],
        "documents": [f"Document {i}" for i in indices],
    }
    return cast(strategies.RecordSet, record_set)


@given(collection=collection_st, should_compact=st.booleans())
@settings(deadline=None, max_examples=5)
def test_add_large(
    client: ClientAPI, collection: strategies.Collection, should_compact: bool
) -> None:
    """Large tier: add 10000 to 50000 hand-built records and check the count."""
    reset(client)

    record_set = create_large_recordset(min_size=10000, max_size=50000)
    coll = client.create_collection(
        name=collection.name,
        metadata=collection.metadata,  # type: ignore
        embedding_function=collection.embedding_function,
    )
    wrapped = invariants.wrap_all(record_set)
    initial_version = cast(int, coll.get_model()["version"])

    batches = create_batches(
        api=client,
        ids=cast(List[str], record_set["ids"]),
        embeddings=cast(Embeddings, record_set["embeddings"]),
        metadatas=cast(Metadatas, record_set["metadatas"]),
        documents=cast(List[str], record_set["documents"]),
    )
    for batch in batches:
        coll.add(*batch)

    num_records = len(wrapped["ids"])
    if should_compact and not NOT_CLUSTER_ONLY and num_records > 10:
        # Wait for the model to be updated. The record set is larger here, so
        # allow some additional time.
        wait_for_version_increase(
            client, collection.name, initial_version, additional_time=240
        )

    invariants.count(coll, cast(strategies.RecordSet, wrapped))


@given(collection=collection_st)
@settings(deadline=None, max_examples=1)
def test_add_large_exceeding(
    client: ClientAPI, collection: strategies.Collection
) -> None:
    """Check that a single add() larger than the max batch size is rejected.

    Builds a record set that is strictly larger than
    ``client.get_max_batch_size()`` and expects ``coll.add`` to raise an
    exception whose message contains "exceeds maximum batch size".
    """
    reset(client)

    max_batch_size = client.get_max_batch_size()
    # create_large_recordset draws its size with randint(), which includes both
    # bounds. The lower bound therefore has to be one above the limit; with
    # min_size == max_batch_size the record set could come out at exactly the
    # limit, which would not exceed it, and no exception would be raised.
    record_set = create_large_recordset(
        min_size=max_batch_size + 1,
        max_size=max_batch_size + 100,  # Exceed the limit by at most 100 records
    )
    coll = client.create_collection(
        name=collection.name,
        metadata=collection.metadata,  # type: ignore
        embedding_function=collection.embedding_function,
    )

    with pytest.raises(Exception) as e:
        coll.add(**record_set)  # type: ignore[arg-type]
    assert "exceeds maximum batch size" in str(e.value)


# TODO: This test fails right now because the ids are not sorted by the input order
# The reason is built from adjacent string literals. A backslash continuation
# inside a single literal would embed the next line's indentation in the text.
@pytest.mark.xfail(
    reason=(
        "This is expected to fail right now. We should change the API to sort the "
        "ids by input order."
    )
)
def test_out_of_order_ids(client: ClientAPI) -> None:
    """Check that get(ids=...) returns ids in the order they were requested.

    Marked xfail: the returned ids are currently not in input order.
    """
    reset(client)
    ooo_ids = [
        "40",
        "05",
        "8",
        "6",
        "10",
        "01",
        "00",
        "3",
        "04",
        "20",
        "02",
        "9",
        "30",
        "11",
        "13",
        "2",
        "0",
        "7",
        "06",
        "5",
        "50",
        "12",
        "03",
        "4",
        "1",
    ]

    coll = client.create_collection(
        "test", embedding_function=lambda input: [[1, 2, 3] for _ in input]  # type: ignore
    )
    # Every record gets the same embedding; only the id order matters here.
    embeddings: Embeddings = [np.array([1, 2, 3]) for _ in ooo_ids]
    coll.add(ids=ooo_ids, embeddings=embeddings)
    get_ids = coll.get(ids=ooo_ids)["ids"]
    assert get_ids == ooo_ids


def test_add_partial(client: ClientAPI) -> None:
    """Check that add() accepts records whose optional fields are partly None."""
    reset(client)
    coll = client.create_collection("test")

    record_ids = ["1", "2", "3"]
    # All embeddings must be provided, or else None - no partial lists allowed
    full_embeddings = [[1, 2, 3], [1, 2, 3], [1, 2, 3]]
    # Metadatas can always be partial
    partial_metadatas = [{"a": 1}, None, {"a": 3}]
    # Documents are optional if embeddings are provided
    partial_documents = ["a", "b", None]

    # TODO: We need to clean up the api types to support this typing
    coll.add(
        ids=record_ids,
        embeddings=full_embeddings,  # type: ignore
        metadatas=partial_metadatas,  # type: ignore
        documents=partial_documents,  # type: ignore
    )

    # Expected values are spelled out as literals rather than reusing the
    # input lists, so the assertions do not depend on the objects passed in.
    fetched = coll.get()
    assert fetched["ids"] == ["1", "2", "3"]
    assert fetched["metadatas"] == [{"a": 1}, None, {"a": 3}]
    assert fetched["documents"] == ["a", "b", None]
