import hashlib
import hypothesis
import hypothesis.strategies as st
from typing import Any, Optional, List, Dict, Union, cast
from typing_extensions import TypedDict
import uuid
import numpy as np
import numpy.typing as npt
import chromadb.api.types as types
import re
from hypothesis.strategies._internal.strategies import SearchStrategy
from chromadb.test.conftest import NOT_CLUSTER_ONLY
from dataclasses import dataclass
from chromadb.api.types import (
    Documents,
    Embeddable,
    EmbeddingFunction,
    Embeddings,
    Metadata,
)
from chromadb.types import LiteralValue, WhereOperator, LogicalOperator

# Seed NumPy's global RNG once at import time for reproducibility. Hypothesis is
# said to manage this seed for us as well, so this call is belt-and-braces.
np.random.seed(seed=0)

# See Hypothesis documentation for creating strategies at
# https://hypothesis.readthedocs.io/en/latest/data.html

# NOTE: Because these strategies are used in state machines, we need to
# work around an issue with state machines, in which strategies that are frequently
# marked as invalid (e.g. through the use of `assume` or `.filter`) can cause the
# state machine tests to fail with a hypothesis.errors.Unsatisfiable error.

# Ultimately this is because the entire state machine is run as a single Hypothesis
# example, which ends up drawing from the same strategies an enormous number of times.
# Whenever a strategy marks itself as invalid, Hypothesis tries to start the entire
# state machine run over. See https://github.com/HypothesisWorks/hypothesis/issues/3618

# Because strategy generation is all interrelated, seemingly small changes (especially
# ones called early in a test) can have an outsized effect. Generating lists with
# unique=True, or dictionaries with a min size, seem especially bad.

# Please make changes to these strategies incrementally, testing to make sure they don't
# start generating unsatisfiable examples.

# HNSW index parameters merged into collection metadata when a test asks for
# HNSW params; every parameter is set to the same value.
test_hnsw_config = dict.fromkeys(
    ("hnsw:construction_ef", "hnsw:search_ef", "hnsw:M"), 128
)


class RecordSet(TypedDict):
    """
    A generated set of embeddings, ids, metadatas, and documents that
    represent what a user would pass to the API.

    Every field may hold either a single value or a list of values, so tests
    can exercise both the scalar and the list form of the API arguments.
    """

    # A single ID or a list of IDs.
    ids: Union[types.ID, List[types.ID]]
    # None, a single embedding, or a list of embeddings.
    embeddings: Optional[Union[types.Embeddings, types.Embedding]]
    # None, a single metadata, or a list of (possibly None) metadatas.
    metadatas: Optional[Union[List[Optional[types.Metadata]], types.Metadata]]
    # None, a single document, or a list of documents.
    documents: Optional[Union[List[types.Document], types.Document]]


class NormalizedRecordSet(TypedDict):
    """
    A RecordSet, with all fields normalized to lists.
    """

    # Always a list of IDs.
    ids: List[types.ID]
    # A list of embeddings, or None when no embeddings were supplied.
    embeddings: Optional[types.Embeddings]
    # A list of (possibly None) metadatas, or None.
    metadatas: Optional[List[Optional[types.Metadata]]]
    # A list of documents, or None.
    documents: Optional[List[types.Document]]


class StateMachineRecordSet(TypedDict):
    """
    Represents the internal state of a state machine in hypothesis tests.

    Unlike RecordSet, every field is a (non-None) list.
    """

    # IDs of the tracked records.
    ids: List[types.ID]
    # Embeddings of the tracked records.
    embeddings: types.Embeddings
    # Per-record metadata; an entry is None when a record has no metadata.
    metadatas: List[Optional[types.Metadata]]
    # Per-record document; an entry is None when a record has no document.
    documents: List[Optional[types.Document]]


class Record(TypedDict):
    """
    A single generated record.
    """

    # The record's ID (always present).
    id: types.ID
    # The record's embedding, or None.
    embedding: Optional[types.Embedding]
    # The record's metadata, or None.
    metadata: Optional[types.Metadata]
    # The record's document, or None.
    document: Optional[types.Document]


# TODO: support arbitrary text everywhere so we don't SQL-inject ourselves.
# TODO: support empty strings everywhere
sql_alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_"
# The same alphabet with "_" removed.
sql_alphabet_minus_underscore = sql_alphabet.replace("_", "")

# Each text strategy below rejects strings starting with "_sa" to work around
# FastAPI json encoding peculiarities:
# https://github.com/tiangolo/fastapi/blob/8ac8d70d52bb0dd9eb55ba4e22d3e383943da05c/fastapi/encoders.py#L104
safe_text = st.text(alphabet=sql_alphabet, min_size=1).filter(
    lambda s: not s.startswith("_sa")
)
safe_text_min_size_3 = st.text(
    alphabet=sql_alphabet_minus_underscore, min_size=3
).filter(lambda s: not s.startswith("_sa"))
tenant_database_name = st.text(alphabet=sql_alphabet, min_size=3).filter(
    lambda s: not s.startswith("_sa")
)

# Signed 32-bit range.  TODO: handle longs
safe_integers = st.integers(min_value=-(2**31), max_value=2**31 - 1)

# In distributed chroma, floats are 32 bit hence we need to restrict the
# generation to 32 bit floats.  TODO: handle infinity and NAN
safe_floats = st.floats(
    min_value=-1e6,
    max_value=1e6,
    width=32,
    allow_nan=False,
    allow_infinity=False,
    allow_subnormal=False,
)

# Scalar strategies used for metadata values.
safe_values: List[SearchStrategy[Union[int, float, str, bool]]] = [
    safe_text,
    safe_integers,
    safe_floats,
    st.booleans(),
]


def one_or_both(
    strategy_a: st.SearchStrategy[Any], strategy_b: st.SearchStrategy[Any]
) -> st.SearchStrategy[Any]:
    """Strategy for a pair where at least one side is drawn and the other may be None.

    Produces ``(a, b)``, ``(a, None)`` or ``(None, b)``; never ``(None, None)``.
    """
    both = st.tuples(strategy_a, strategy_b)
    only_a = st.tuples(strategy_a, st.none())
    only_b = st.tuples(st.none(), strategy_b)
    return st.one_of(both, only_a, only_b)


# Temporarily generate only these to avoid SQL formatting issues.
legal_id_characters = (
    "abcdefghijklmnopqrstuvwxyz"
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "0123456789"
    "-_./+"
)

# NumPy dtypes that generated collections sample their embedding dtype from.
float_types = [np.float16, np.float32, np.float64]
# Integer dtypes are not sampled anywhere yet.  TODO: handle int types
int_types = [np.int16, np.int32, np.int64]


@st.composite
def collection_name(draw: st.DrawFn) -> str:
    """Strategy for a valid collection name.

    Names are 3-62 characters long: an ASCII letter, then 1-60 letters, digits
    or hyphens, then a final letter or digit.
    """
    # No ^/$ anchors plus fullmatch=True: the generated string is exactly one
    # match of the pattern. With a trailing `$` (and no fullmatch), from_regex
    # may append "\n", because Python's `$` also matches before a final newline.
    _collection_name_re = re.compile(r"[a-zA-Z][a-zA-Z0-9-]{1,60}[a-zA-Z0-9]")
    _ipv4_address_re = re.compile(r"^([0-9]{1,3}\.){3}[0-9]{1,3}$")
    _two_periods_re = re.compile(r"\.\.")

    name: str = draw(st.from_regex(_collection_name_re, fullmatch=True))
    # The name alphabet contains no ".", so these cannot currently reject; they
    # are kept as guards in case the alphabet is ever widened.
    hypothesis.assume(not _ipv4_address_re.match(name))
    hypothesis.assume(not _two_periods_re.search(name))

    return name


# Either no collection metadata at all, or an arbitrary dict of safe keys to
# safe scalar values.
_collection_metadata_dicts = st.dictionaries(safe_text, st.one_of(*safe_values))
collection_metadata = st.one_of(st.none(), _collection_metadata_dicts)


# TODO: Use a hypothesis strategy while maintaining embedding uniqueness
#       Or handle duplicate embeddings within a known epsilon
def create_embeddings(
    dim: int,
    count: int,
    dtype: npt.DTypeLike,
) -> types.Embeddings:
    """Return ``count`` random embeddings of length ``dim`` as nested Python lists.

    Values come from NumPy's global RNG (uniform over [-1, 1)), are cast to
    ``dtype`` and then converted with ``tolist()``.
    """
    return create_embeddings_ndarray(dim, count, dtype).tolist()


def create_embeddings_ndarray(
    dim: int,
    count: int,
    dtype: npt.DTypeLike,
) -> np.typing.NDArray[Any]:
    """Return a ``(count, dim)`` array of uniform [-1, 1) samples cast to ``dtype``.

    Uses NumPy's global RNG, so the result depends on the current global seed.
    """
    shape = (count, dim)
    samples = np.random.uniform(-1.0, 1.0, shape)
    return samples.astype(dtype)


class hashing_embedding_function(types.EmbeddingFunction[Documents]):
    """Deterministic embedding function derived from SHA-256 digests.

    Each document's 64 hex digest digits are tiled (repeated, then truncated)
    to ``dim`` entries, and every digit ``d`` becomes ``d / 15.0``, so each
    embedding value lies in [0, 1].
    """

    def __init__(self, dim: int, dtype: npt.DTypeLike) -> None:
        self.dim = dim
        self.dtype = dtype

    def __call__(self, input: types.Documents) -> types.Embeddings:
        vectors: List[Any] = []
        for doc in input:
            hex_digits = list(hashlib.sha256(doc.encode("utf-8")).hexdigest())
            # Repeat the digest as many whole times as fits, then pad with a
            # prefix of it to reach exactly `dim` digits.
            full_repeats, remainder = divmod(self.dim, len(hex_digits))
            tiled = hex_digits * full_repeats + hex_digits[:remainder]
            vectors.append(
                np.array([int(digit, 16) / 15.0 for digit in tiled], dtype=self.dtype)
            )
        return vectors

    def __repr__(self) -> str:
        return f"hashing_embedding_function(dim={self.dim}, dtype={self.dtype})"


class not_implemented_embedding_function(types.EmbeddingFunction[Documents]):
    """Embedding function whose every call fails with an AssertionError."""

    def __call__(self, input: Documents) -> Embeddings:
        # Raise explicitly instead of `assert False`: assert statements are
        # removed under `python -O`, which would make this silently return None.
        raise AssertionError("This embedding function is not implemented")


def embedding_function_strategy(
    dim: int, dtype: npt.DTypeLike
) -> st.SearchStrategy[types.EmbeddingFunction[Embeddable]]:
    """Strategy that always yields a ``hashing_embedding_function(dim, dtype)``."""
    hashing_fn = hashing_embedding_function(dim, dtype)
    as_generic = cast(EmbeddingFunction[Embeddable], hashing_fn)
    return st.just(as_generic)


@dataclass
class ExternalCollection:
    """
    An external view of a collection.

    This class only contains information about a collection that a client of Chroma
    sees -- that is, it contains none of Chroma's internal bookkeeping. It should
    be used to test the API and client code.
    """

    # Collection name.
    name: str
    # Collection-level metadata, or None when the collection has none.
    metadata: Optional[types.Metadata]
    # Embedding function attached to the collection, or None.
    embedding_function: Optional[types.EmbeddingFunction[Embeddable]]


@dataclass
class Collection(ExternalCollection):
    """
    An internal view of a collection.

    This class contains all the information Chroma uses internally to manage a
    collection. It is a superset of ExternalCollection and should be used to test
    internal Chroma logic. It also records what the test generators need to keep
    generated records and filters consistent with the collection.
    """

    # Collection id.
    id: uuid.UUID
    # Length of every embedding in the collection.
    dimension: int
    # NumPy dtype used when generating embeddings for this collection.
    dtype: npt.DTypeLike
    # Metadata key/value pairs that generated records may carry; metadata
    # filters (where clauses) are built from these.
    known_metadata_keys: types.Metadata
    # Words that generated documents may contain; document filters
    # (where_document clauses) are built from these.
    known_document_keywords: List[str]
    # Whether generated records include documents.
    has_documents: bool = False
    # Whether generated records include explicit embeddings.
    has_embeddings: bool = False


@st.composite
def collections(
    draw: st.DrawFn,
    add_filterable_data: bool = False,
    with_hnsw_params: bool = False,
    has_embeddings: Optional[bool] = None,
    has_documents: Optional[bool] = None,
    with_persistent_hnsw_params: st.SearchStrategy[bool] = st.just(False),
    max_hnsw_batch_size: int = 2000,
    max_hnsw_sync_threshold: int = 2000,
) -> Collection:
    """Strategy to generate a Collection object.

    Args:
        add_filterable_data: if True, ``known_metadata_keys`` is filled with five
            random key/value pairs and, when the collection has documents,
            ``known_document_keywords`` with five keywords, so that generated
            records and filters are consistent with each other.
        with_hnsw_params: if True, ``test_hnsw_config`` is merged into the
            collection metadata and sometimes a random ``hnsw:space`` is set.
        has_embeddings: force whether records carry embeddings; None lets the
            strategy decide (it is forced to True when there are no documents).
        has_documents: force whether records carry documents; None draws it.
            ``has_embeddings`` and ``has_documents`` must not both be False.
        with_persistent_hnsw_params: strategy deciding whether to also set
            ``hnsw:sync_threshold`` and ``hnsw:batch_size`` (the batch size never
            exceeds the sync threshold).
        max_hnsw_batch_size: upper bound for ``hnsw:batch_size``.
        max_hnsw_sync_threshold: upper bound for ``hnsw:sync_threshold``.

    Raises:
        ValueError: if persistent HNSW params are drawn while
            ``with_hnsw_params`` is False.
    """
    assert has_embeddings is not False or has_documents is not False

    name = draw(collection_name())
    coll_metadata = draw(collection_metadata)
    dimension = draw(st.integers(min_value=2, max_value=2048))
    dtype = draw(st.sampled_from(float_types))
    persistent = draw(with_persistent_hnsw_params)

    if persistent and not with_hnsw_params:
        raise ValueError(
            "with_persistent_hnsw_params requires with_hnsw_params to be true"
        )

    if with_hnsw_params:
        coll_metadata = {} if coll_metadata is None else coll_metadata
        coll_metadata.update(test_hnsw_config)
        if persistent:
            sync_threshold = draw(
                st.integers(min_value=3, max_value=max_hnsw_sync_threshold)
            )
            coll_metadata["hnsw:sync_threshold"] = sync_threshold
            coll_metadata["hnsw:batch_size"] = draw(
                st.integers(
                    min_value=3, max_value=min(sync_threshold, max_hnsw_batch_size)
                )
            )
        # Sometimes, select a space at random
        if draw(st.booleans()):
            # TODO: pull the distance functions from a source of truth that lives not
            # in tests once https://github.com/chroma-core/issues/issues/61 lands
            coll_metadata["hnsw:space"] = draw(
                st.sampled_from(["cosine", "l2", "ip"])
            )

    known_metadata_keys: Dict[str, Union[int, str, float]] = {}
    while add_filterable_data and len(known_metadata_keys) < 5:
        new_key = draw(safe_text)
        known_metadata_keys[new_key] = draw(st.one_of(*safe_values))

    if has_documents is None:
        has_documents = draw(st.booleans())
    assert has_documents is not None

    # For cluster tests (NOT_CLUSTER_ONLY false) keywords must be at least 3
    # characters long and must not contain special characters like _ and % that
    # implicitly involve searching for a regex in sqlite.
    keyword_st = safe_text if NOT_CLUSTER_ONLY else safe_text_min_size_3
    known_document_keywords: List[str] = []
    if has_documents and add_filterable_data:
        known_document_keywords = draw(st.lists(keyword_st, min_size=5, max_size=5))

    # A collection without documents must carry embeddings.
    if not has_documents:
        has_embeddings = True
    elif has_embeddings is None:
        has_embeddings = draw(st.booleans())
    assert has_embeddings is not None

    embedding_function = draw(embedding_function_strategy(dimension, dtype))

    return Collection(
        id=uuid.uuid4(),
        name=name,
        metadata=coll_metadata,
        dimension=dimension,
        dtype=dtype,
        known_metadata_keys=known_metadata_keys,
        has_documents=has_documents,
        known_document_keywords=known_document_keywords,
        has_embeddings=has_embeddings,
        embedding_function=embedding_function,
    )


@st.composite
def metadata(
    draw: st.DrawFn,
    collection: Collection,
    min_size: int = 0,
    max_size: Optional[int] = None,
) -> Optional[types.Metadata]:
    """Strategy for one record's metadata, consistent with ``collection``.

    A random dict of safe keys and values is drawn (``min_size``/``max_size``
    bound this initial draw). If the collection has known metadata keys, any
    randomly drawn key that collides with one is dropped, to avoid type errors
    when comparing, and a random subset of the known key/value pairs is merged
    in. Empty metadata is returned as None because it may not be submitted.
    """
    record_md: types.Metadata = draw(
        st.dictionaries(
            safe_text, st.one_of(*safe_values), min_size=min_size, max_size=max_size
        )
    )

    known = collection.known_metadata_keys
    if known:
        colliding = [k for k in known.keys() if k in record_md]
        for k in colliding:
            del record_md[k]  # type: ignore
        optional_known: Dict[str, st.SearchStrategy[Union[str, int, float]]] = {
            k: st.just(v) for k, v in known.items()
        }
        record_md.update(draw(st.fixed_dictionaries({}, optional=optional_known)))  # type: ignore

    return record_md if record_md != {} else None


@st.composite
def document(draw: st.DrawFn, collection: Collection) -> types.Document:
    """Strategy for a document that could be a part of ``collection``.

    A document is one or more space-separated words. Each word is either one of
    the collection's known keywords (when it has any) or random text.

    Unicode categories Cc and Cs are never generated, because they affect
    sqlite processing (e.g. the null character U+0000 makes sqlite stop
    processing a string). For cluster tests (NOT_CLUSTER_ONLY false) words are
    at least 3 characters long and categories Pc and Po are excluded too, which
    removes characters like _ and % that implicitly involve searching for a
    regex in sqlite.
    """
    if NOT_CLUSTER_ONLY:
        min_word_size = 1
        excluded_categories = ("Cc", "Cs")
    else:
        min_word_size = 3
        excluded_categories = ("Cc", "Cs", "Pc", "Po")

    def random_word_strategy() -> SearchStrategy[str]:
        return st.text(
            min_size=min_word_size,
            alphabet=st.characters(blacklist_categories=excluded_categories),  # type: ignore
        )

    if collection.known_document_keywords:
        known_word_st = st.sampled_from(collection.known_document_keywords)
    else:
        known_word_st = random_word_strategy()
    word_st = st.one_of(known_word_st, random_word_strategy())

    return " ".join(draw(st.lists(word_st, min_size=1)))


@st.composite
def recordsets(
    draw: st.DrawFn,
    collection_strategy: SearchStrategy[Collection] = collections(),
    id_strategy: SearchStrategy[str] = safe_text,
    min_size: int = 1,
    max_size: int = 50,
    # If num_unique_metadata is None, one metadata is generated per record.
    # Otherwise exactly num_unique_metadata metadatas are generated and reused
    # cyclically across the records.
    num_unique_metadata: Optional[int] = None,
    min_metadata_size: int = 0,
    max_metadata_size: Optional[int] = None,
) -> RecordSet:
    """Strategy for a RecordSet belonging to a collection from ``collection_strategy``.

    Draws between ``min_size`` and ``max_size`` unique ids. Embeddings are
    created only if the collection has embeddings, documents are drawn only if
    it has documents. When exactly one record is generated, each field is
    randomly left as a one-element list or unwrapped to its single value.
    """
    collection = draw(collection_strategy)
    ids: List[types.ID] = list(
        draw(st.lists(id_strategy, min_size=min_size, max_size=max_size, unique=True))
    )
    n_records = len(ids)

    embeddings: Optional[Embeddings] = (
        create_embeddings(collection.dimension, n_records, collection.dtype)
        if collection.has_embeddings
        else None
    )

    pool_size = n_records if num_unique_metadata is None else num_unique_metadata
    metadata_pool = draw(
        st.lists(
            metadata(
                collection, min_size=min_metadata_size, max_size=max_metadata_size
            ),
            min_size=pool_size,
            max_size=pool_size,
        )
    )
    # Cycle through the pool so that every record gets a metadata entry.
    metadatas = [metadata_pool[idx % len(metadata_pool)] for idx in range(n_records)]

    documents: Optional[Documents] = None
    if collection.has_documents:
        documents = draw(
            st.lists(document(collection), min_size=n_records, max_size=n_records)
        )

    if n_records != 1:
        return {
            "ids": ids,
            "embeddings": embeddings,
            "metadatas": metadatas,
            "documents": documents,
        }

    # With a single record, sometimes exercise the code that handles individual
    # values rather than lists; each field is unwrapped independently.
    one_id: Union[str, List[str]] = ids[0] if draw(st.booleans()) else ids
    one_embedding: Optional[Union[types.Embeddings, types.Embedding]] = embeddings
    if embeddings is not None and draw(st.booleans()):
        one_embedding = embeddings[0]
    one_metadata: Union[Optional[Metadata], List[Optional[Metadata]]] = (
        metadatas[0] if draw(st.booleans()) else metadatas
    )
    one_document: Optional[Union[List[types.Document], types.Document]] = documents
    if documents is not None and draw(st.booleans()):
        one_document = documents[0]
    return {
        "ids": one_id,
        "embeddings": one_embedding,
        "metadatas": one_metadata,
        "documents": one_document,
    }


def opposite_value(value: LiteralValue) -> SearchStrategy[Any]:
    """
    Returns a strategy that generates valid values of the same type as ``value``
    but never ``value`` itself - used to build operands for testing $in / $nin.
    """
    # bool must be tested before int because bool is a subclass of int.
    if isinstance(value, float):
        return safe_floats.filter(lambda x: x != value)
    if isinstance(value, str):
        return safe_text.filter(lambda x: x != value)
    if isinstance(value, bool):
        # There is exactly one other boolean: generate it directly instead of
        # filtering st.booleans(), which would reject about half of all draws.
        # Frequent rejections are what make the state machine tests fail with
        # Unsatisfiable.
        return st.just(not value)
    if isinstance(value, int):
        # Same 32-bit range as every other generated integer.
        return safe_integers.filter(lambda x: x != value)
    return st.from_type(type(value)).filter(lambda x: x != value)


@st.composite
def where_clause(draw: st.DrawFn, collection: Collection) -> types.Where:
    """Strategy for a single-key metadata filter against ``collection``.

    A key is sampled from ``collection.known_metadata_keys`` and its known value
    is used as the operand, combined with an operator that is legal for the
    value's type. Float operands are nudged by +/-1e-6 and truncated to 32 bits.
    """
    candidates = sorted(collection.known_metadata_keys.keys())
    key = draw(st.sampled_from(candidates))
    value = collection.known_metadata_keys[key]

    # NOTE(review): an earlier comment said the distributed system does not
    # support $in or $nin, yet both are always offered here.
    # TODO: confirm, and drop this note once the distributed system supports them.
    operators: List[Optional[str]] = [None, "$eq", "$ne", "$in", "$nin"]
    if not isinstance(value, (str, bool)):
        # Ordering comparisons are only offered for numeric values.
        operators += ["$gt", "$lt", "$lte", "$gte"]

    if isinstance(value, float):
        # Add or subtract a small number to avoid floating point rounding
        # errors, then truncate to 32 bit.
        value = float(np.float32(value + draw(st.sampled_from([1e-6, -1e-6]))))

    op: WhereOperator = draw(st.sampled_from(operators))

    if op is None:
        return {key: value}
    if op in ("$in", "$nin") and isinstance(value, str) and not value:  # type: ignore
        return {}
    if op == "$in":  # type: ignore
        others = [draw(opposite_value(value)) for _ in range(3)]
        return {key: {op: [value, *others]}}
    if op == "$nin":  # type: ignore
        return {key: {op: [draw(opposite_value(value)) for _ in range(3)]}}
    return {key: {op: value}}  # type: ignore


@st.composite
def where_doc_clause(draw: st.DrawFn, collection: Collection) -> types.WhereDocument:
    """Strategy for a single $contains / $not_contains document filter.

    The keyword is one of the collection's known document keywords when it has
    any, otherwise random safe text. For cluster tests (NOT_CLUSTER_ONLY false)
    random keywords are at least 3 characters and contain no "_", avoiding
    special characters like _ and % that implicitly involve searching for a
    regex in sqlite.
    """
    keywords = collection.known_document_keywords
    if keywords:
        word_st = st.sampled_from(keywords)
    elif NOT_CLUSTER_ONLY:
        word_st = safe_text
    else:
        word_st = safe_text_min_size_3
    word = draw(word_st)

    # NOTE(review): an earlier comment said the distributed system does not
    # support $not_contains, yet both operators are generated here.
    # TODO: confirm, and drop this note once the distributed system supports it.
    use_contains = draw(st.sampled_from(["$contains", "$not_contains"])) == "$contains"
    return {"$contains": word} if use_contains else {"$not_contains": word}


def binary_operator_clause(
    base_st: SearchStrategy[types.Where],
) -> SearchStrategy[types.Where]:
    """Wrap two clauses from ``base_st`` in a single ``$and`` or ``$or``."""
    logical_ops: SearchStrategy[LogicalOperator] = st.sampled_from(["$and", "$or"])
    operand_pairs = st.lists(base_st, min_size=2, max_size=2)
    return st.dictionaries(
        keys=logical_ops, values=operand_pairs, min_size=1, max_size=1
    )


def binary_document_operator_clause(
    base_st: SearchStrategy[types.WhereDocument],
) -> SearchStrategy[types.WhereDocument]:
    """Wrap two document clauses from ``base_st`` in a single ``$and`` or ``$or``."""
    logical_ops: SearchStrategy[LogicalOperator] = st.sampled_from(["$and", "$or"])
    operand_pairs = st.lists(base_st, min_size=2, max_size=2)
    return st.dictionaries(
        keys=logical_ops, values=operand_pairs, min_size=1, max_size=1
    )


@st.composite
def recursive_where_clause(draw: st.DrawFn, collection: Collection) -> types.Where:
    """Strategy for a metadata filter, possibly nested with $and / $or."""
    leaf = where_clause(collection)
    nested = st.recursive(leaf, binary_operator_clause)
    return cast(types.Where, draw(nested))


@st.composite
def recursive_where_doc_clause(
    draw: st.DrawFn, collection: Collection
) -> types.WhereDocument:
    """Strategy for a document filter, possibly nested with $and / $or."""
    leaf = where_doc_clause(collection)
    nested = st.recursive(leaf, binary_document_operator_clause)
    return cast(types.WhereDocument, draw(nested))


class Filter(TypedDict):
    """A generated combination of query filters."""

    # Metadata filter, or None for no metadata filtering.
    where: Optional[types.Where]
    # A single id or a list of ids to restrict to, or None.
    ids: Optional[Union[str, List[str]]]
    # Document-content filter, or None.
    where_document: Optional[types.WhereDocument]


@st.composite
def filters(
    draw: st.DrawFn,
    collection_st: st.SearchStrategy[Collection],
    recordset_st: st.SearchStrategy[RecordSet],
    include_all_ids: bool = False,
) -> Filter:
    """Strategy for a Filter (where / where_document / ids) to query a collection.

    Args:
        collection_st: strategy for the collection whose known metadata keys and
            document keywords the where clauses are built from.
        recordset_st: strategy for the record set whose ids the ``ids`` filter
            is taken from.
        include_all_ids: if True, ``ids`` is every id of the record set;
            otherwise it is None or a non-empty, duplicate-free subset of them.

    ``where`` and ``where_document`` are each independently None or a (possibly
    nested) clause. A one-element ``ids`` list is sometimes unwrapped to a bare
    id to exercise the scalar code path.
    """
    collection = draw(collection_st)
    recordset = draw(recordset_st)

    # Named `where` rather than `where_clause` so the module-level
    # where_clause strategy is not shadowed.
    where = draw(st.one_of(st.none(), recursive_where_clause(collection)))
    where_document = draw(
        st.one_of(st.none(), recursive_where_doc_clause(collection))
    )

    ids: Optional[Union[List[types.ID], types.ID]]
    # Record sets can be a value instead of a list of values if there is only one record
    if isinstance(recordset["ids"], str):
        ids = [recordset["ids"]]
    else:
        ids = recordset["ids"]

    if not include_all_ids:
        ids = draw(st.one_of(st.none(), st.lists(st.sampled_from(ids), min_size=1)))
        if ids is not None:
            # Hypothesis samples with replacement, so remove duplicates. Use
            # dict.fromkeys instead of set() so the resulting order is
            # deterministic: set order for str depends on the per-process hash
            # seed, which would make replayed examples differ between runs.
            ids = list(dict.fromkeys(ids))

    # Test both the single value list and the unwrapped single value case
    if ids is not None and len(ids) == 1 and draw(st.booleans()):
        ids = ids[0]

    return {"where": where, "where_document": where_document, "ids": ids}
