import asyncio
import os
import shutil
import tempfile
from uuid import UUID
import pytest
from itertools import count
from typing import (
    Generator,
    List,
    Callable,
    Optional,
    Dict,
    Union,
    Iterator,
    Sequence,
    Tuple,
)
from chromadb.errors import BatchSizeExceededError
from chromadb.ingest import Producer, Consumer
from chromadb.db.impl.sqlite import SqliteDB
from chromadb.test.conftest import ProducerFn
from chromadb.types import (
    OperationRecord,
    Operation,
    LogRecord,
    ScalarEncoding,
)
from chromadb.config import System, Settings
from pytest import FixtureRequest, approx
from asyncio import Event, wait_for, TimeoutError
import numpy as np


def sqlite() -> Generator[Tuple[Producer, Consumer], None, None]:
    """Fixture generator for sqlite Producer + Consumer.

    Builds a System from ``Settings(allow_reset=True)``, starts it, and yields
    the required SqliteDB instance twice: once in the Producer position and
    once in the Consumer position.

    The ``stop()`` call after the ``yield`` only runs if the caller resumes
    the generator past its first value.
    """
    settings = Settings(allow_reset=True)
    chroma_system = System(settings)
    sqlite_db = chroma_system.require(SqliteDB)
    chroma_system.start()
    # The same object plays both roles.
    yield sqlite_db, sqlite_db
    chroma_system.stop()


def sqlite_persistent() -> Generator[Tuple[Producer, Consumer], None, None]:
    """Fixture generator for sqlite_persistent Producer + Consumer.

    Creates a fresh temporary directory, configures the System with
    ``is_persistent=True`` pointing at that directory, starts it, and yields
    the required SqliteDB instance as both Producer and Consumer.

    The code after the ``yield`` (stopping the system and deleting the
    temporary directory) only runs if the caller resumes the generator past
    its first value.
    """
    persist_dir = tempfile.mkdtemp()
    settings = Settings(
        allow_reset=True,
        is_persistent=True,
        persist_directory=persist_dir,
    )
    chroma_system = System(settings)
    sqlite_db = chroma_system.require(SqliteDB)
    chroma_system.start()
    # The same object plays both roles.
    yield sqlite_db, sqlite_db
    chroma_system.stop()
    # Remove the on-disk state, if the directory is still there.
    if os.path.exists(persist_dir):
        shutil.rmtree(persist_dir)


def fixtures() -> List[Callable[[], Generator[Tuple[Producer, Consumer], None, None]]]:
    """Return the fixture generators used to parametrize the tests.

    When the ``CHROMA_CLUSTER_TEST_ONLY`` environment variable is present
    (regardless of its value) no generators are returned; otherwise both the
    ``sqlite`` and ``sqlite_persistent`` generators are returned.
    """
    if "CHROMA_CLUSTER_TEST_ONLY" in os.environ:
        # TODO: We should add the new log service here
        return []

    return [sqlite, sqlite_persistent]


@pytest.fixture(scope="module", params=fixtures())
def producer_consumer(
    request: FixtureRequest,
) -> Generator[Tuple[Producer, Consumer], None, None]:
    """Module-scoped (Producer, Consumer) pair, one per entry in ``fixtures()``.

    ``request.param`` is one of the fixture generator functions; its yielded
    pair is handed to the tests.
    """
    # Delegate with `yield from` rather than `yield next(...)`: the latter
    # discards the underlying generator after its first value, so the code
    # after its `yield` (stopping the system, removing temporary files) would
    # never execute. With `yield from`, pytest's teardown resumes the
    # underlying generator and lets it run to completion.
    yield from request.param()


@pytest.fixture(scope="module")
def sample_embeddings() -> Iterator[OperationRecord]:
    """Provide an endless, lazily built stream of ADD operation records.

    Record ``i`` (counting from 0) has the id ``embedding_{i}``, a
    two-element vector derived from ``i``, and metadata only when ``i`` is
    odd. The fixture is module-scoped, so a single iterator is shared by all
    tests in the module and the counter keeps advancing across them.
    """

    def create_record(i: int) -> OperationRecord:
        # Even-numbered records carry no metadata; odd ones carry one key of
        # each scalar type.
        metadata: Optional[Dict[str, Union[str, int, float]]] = None
        if i % 2 != 0:
            metadata = {"str_key": f"value_{i}", "int_key": i, "float_key": i + i * 0.1}

        return OperationRecord(
            id=f"embedding_{i}",
            embedding=np.array([i + i * 0.1, i + 1 + i * 0.1]),
            encoding=ScalarEncoding.FLOAT32,
            metadata=metadata,
            operation=Operation.ADD,
        )

    # count() never terminates; records are only created when requested.
    return map(create_record, count())


class CapturingConsumeFn:
    """Consumer callback that records every batch of LogRecords it is given.

    Instances are callable (to be handed to a consumer as its callback) and
    offer an awaitable ``get`` that blocks until a requested number of records
    has been captured.
    """

    # Every record received so far, in arrival order.
    embeddings: List[LogRecord]
    # Pending get() calls: (number of records awaited, event to set).
    waiters: List[Tuple[int, Event]]

    def __init__(self) -> None:
        """A function that captures embeddings and allows you to wait for a certain
        number of embeddings to be available. It must be constructed in the thread with
        the main event loop
        """
        self.embeddings = []
        self.waiters = []
        # Remembered so __call__ can schedule wake-ups onto this loop.
        self._loop = asyncio.get_event_loop()

    def __call__(self, embeddings: Sequence[LogRecord]) -> None:
        """Append ``embeddings`` and wake every waiter whose target is reached."""
        self.embeddings.extend(embeddings)
        captured = len(self.embeddings)
        # Iterate over a snapshot: get() adds and removes entries from
        # self.waiters, and mutating a list while it is being iterated can
        # cause entries to be skipped.
        for n, event in list(self.waiters):
            if captured >= n:
                # event.set() is not thread safe, so we need to call it in the main event loop
                self._loop.call_soon_threadsafe(event.set)

    async def get(self, n: int, timeout_secs: int = 10) -> Sequence[LogRecord]:
        """Wait until at least N embeddings are available, then return the first N.

        Args:
            n: Number of captured records to wait for.
            timeout_secs: Maximum time to wait, in seconds.

        Returns:
            The first ``n`` captured records.

        Raises:
            asyncio.TimeoutError: If fewer than ``n`` records arrive in time.
        """
        if len(self.embeddings) >= n:
            return self.embeddings[:n]

        waiter = (n, Event())
        self.waiters.append(waiter)
        try:
            # Re-check after registering: records that arrived between the
            # check above and the append would otherwise never set our event
            # (a lost wakeup), and the wait would run into the timeout.
            if len(self.embeddings) < n:
                # timeout so we don't hang forever on failure
                await wait_for(waiter[1].wait(), timeout_secs)
            return self.embeddings[:n]
        finally:
            # Drop the waiter on success, timeout and cancellation alike so
            # finished waits do not accumulate.
            self.waiters.remove(waiter)


def assert_approx_equal(a: Sequence[float], b: Sequence[float]) -> None:
    """Assert that ``a`` and ``b`` have equal length and approximately equal elements.

    Elements are compared pairwise using pytest's ``approx``.

    Raises:
        AssertionError: If the lengths differ or any pair is not approximately equal.
    """
    # zip() stops at the shorter input, so without this check a truncated or
    # over-long vector would be reported as matching.
    assert len(a) == len(b)
    for i, j in zip(a, b):
        assert approx(i) == approx(j)


def assert_records_match(
    inserted_records: Sequence[OperationRecord],
    consumed_records: Sequence[LogRecord],
) -> None:
    """Assert that the consumed log records mirror the inserted records.

    Both sequences must have the same length. Records are compared
    position by position on id, operation, encoding and metadata. When the
    inserted record has an embedding, the consumed record must have one too
    and the two must be approximately equal.
    """
    assert len(consumed_records) == len(inserted_records)
    for expected, log_record in zip(inserted_records, consumed_records):
        # The operation record is nested under the "record" key of the log record.
        actual = log_record["record"]
        assert expected["id"] == actual["id"]
        assert expected["operation"] == actual["operation"]
        assert expected["encoding"] == actual["encoding"]
        assert expected["metadata"] == actual["metadata"]

        expected_embedding = expected["embedding"]
        if expected_embedding is None:
            # Nothing to compare when no embedding was inserted.
            continue
        assert actual["embedding"] is not None
        assert_approx_equal(expected_embedding, actual["embedding"])


@pytest.mark.asyncio
async def test_backfill(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[OperationRecord],
    produce_fns: ProducerFn,
) -> None:
    """Records produced before subscribing are delivered to a later subscriber.

    Three records are produced first; a subscriber starting at the consumer's
    minimum seq id must then receive exactly those three.
    """
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()

    record_count = 3
    collection = UUID("00000000-0000-0000-0000-000000000000")
    # Produce before anyone is subscribed.
    produced = produce_fns(producer, collection, sample_embeddings, record_count)[0]

    capture = CapturingConsumeFn()
    consumer.subscribe(collection, capture, start=consumer.min_seqid())

    received = await capture.get(record_count)
    assert_records_match(produced, received)


@pytest.mark.asyncio
async def test_notifications(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[OperationRecord],
) -> None:
    """An existing subscriber receives each record submitted afterwards.

    Ten records are submitted one at a time; after every submission the
    subscriber must hold all records submitted so far.
    """
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()
    collection = UUID("00000000-0000-0000-0000-000000000000")

    submitted: List[OperationRecord] = []

    capture = CapturingConsumeFn()
    # Subscribe first, then produce.
    consumer.subscribe(collection, capture, start=consumer.min_seqid())

    for expected_count in range(1, 11):
        record = next(sample_embeddings)
        submitted.append(record)
        producer.submit_embedding(collection, record)
        delivered = await capture.get(expected_count)
        assert_records_match(submitted, delivered)


@pytest.mark.asyncio
async def test_multiple_collections(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[OperationRecord],
) -> None:
    """Each subscriber only receives the records of its own collection.

    Records are submitted alternately to two collections, ten rounds in
    total; after every submission the corresponding subscriber must hold
    exactly the records submitted to its collection so far.
    """
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()
    first_collection = UUID("00000000-0000-0000-0000-000000000001")
    second_collection = UUID("00000000-0000-0000-0000-000000000002")

    first_submitted: List[OperationRecord] = []
    second_submitted: List[OperationRecord] = []

    first_capture = CapturingConsumeFn()
    second_capture = CapturingConsumeFn()

    consumer.subscribe(first_collection, first_capture, start=consumer.min_seqid())
    consumer.subscribe(second_collection, second_capture, start=consumer.min_seqid())

    for expected_count in range(1, 11):
        # Round for the first collection.
        first_record = next(sample_embeddings)
        first_submitted.append(first_record)
        producer.submit_embedding(first_collection, first_record)
        first_delivered = await first_capture.get(expected_count)
        assert_records_match(first_submitted, first_delivered)

        # Round for the second collection.
        second_record = next(sample_embeddings)
        second_submitted.append(second_record)
        producer.submit_embedding(second_collection, second_record)
        second_delivered = await second_capture.get(expected_count)
        assert_records_match(second_submitted, second_delivered)


@pytest.mark.asyncio
async def test_start_seq_id(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[OperationRecord],
    produce_fns: ProducerFn,
) -> None:
    """A subscriber given a ``start`` offset only sees later records.

    Five records are produced and consumed by a first subscriber. A second
    subscriber is then registered with ``start`` set to the log offset of the
    last record the first one captured; after five more records are produced
    the test expects the second subscriber's first five records to be exactly
    those five new ones.
    """
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()
    collection = UUID("00000000-0000-0000-0000-000000000000")

    early_capture = CapturingConsumeFn()
    late_capture = CapturingConsumeFn()

    consumer.subscribe(collection, early_capture, start=consumer.min_seqid())

    produced = produce_fns(producer, collection, sample_embeddings, 5)[0]

    early_delivered = await early_capture.get(5)
    assert_records_match(produced, early_delivered)

    # Start the second subscription at the offset of the last record seen so far.
    last_offset = early_capture.embeddings[-1]["log_offset"]
    consumer.subscribe(collection, late_capture, start=last_offset)
    newly_produced = produce_fns(producer, collection, sample_embeddings, 5)[0]
    assert isinstance(produced, list)
    produced.extend(newly_produced)
    late_delivered = await late_capture.get(5)
    assert_records_match(produced[-5:], late_delivered)


@pytest.mark.asyncio
async def test_end_seq_id(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[OperationRecord],
    produce_fns: ProducerFn,
) -> None:
    """A subscriber given an ``end`` offset stops receiving at that record.

    Ten records are produced. A second subscriber is registered with ``end``
    set to the log offset of the fifth-from-last captured record; the test
    expects it to receive the first six records and never a seventh.
    """
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()
    collection = UUID("00000000-0000-0000-0000-000000000000")

    unbounded_capture = CapturingConsumeFn()
    bounded_capture = CapturingConsumeFn()

    consumer.subscribe(collection, unbounded_capture, start=consumer.min_seqid())

    produced = produce_fns(producer, collection, sample_embeddings, 10)[0]

    all_delivered = await unbounded_capture.get(10)
    assert_records_match(produced, all_delivered)

    # Offset of the sixth of the ten captured records.
    end_offset = unbounded_capture.embeddings[-5]["log_offset"]
    consumer.subscribe(
        collection,
        bounded_capture,
        start=consumer.min_seqid(),
        end=end_offset,
    )

    bounded_delivered = await bounded_capture.get(6)
    assert_records_match(produced[:6], bounded_delivered)

    # Should never produce a 7th
    with pytest.raises(TimeoutError):
        await wait_for(bounded_capture.get(7), timeout=1)


@pytest.mark.asyncio
async def test_submit_batch(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[OperationRecord],
) -> None:
    """A batch submitted in one call is delivered completely and in order.

    One hundred records are submitted with a single ``submit_embeddings``
    call; a subscriber registered afterwards must receive all of them.
    """
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()
    collection = UUID("00000000-0000-0000-0000-000000000000")

    batch_size = 100
    batch = [next(sample_embeddings) for _ in range(batch_size)]

    producer.submit_embeddings(collection, embeddings=batch)

    capture = CapturingConsumeFn()
    consumer.subscribe(collection, capture, start=consumer.min_seqid())

    received = await capture.get(batch_size)
    assert_records_match(batch, received)


@pytest.mark.asyncio
async def test_multiple_collections_batch(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[OperationRecord],
    produce_fns: ProducerFn,
) -> None:
    """Batched production to several collections reaches the right subscribers.

    Each collection gets its own subscriber. Records are produced in batches,
    one batch per collection per round, and after every batch the matching
    subscriber must hold all records produced for its collection so far.
    """
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()

    N_TOPICS = 2
    collection_ids = [
        UUID(f"00000000-0000-0000-0000-00000000000{topic}") for topic in range(N_TOPICS)
    ]
    captures = [CapturingConsumeFn() for _ in range(N_TOPICS)]
    for collection_id, capture in zip(collection_ids, captures):
        consumer.subscribe(
            collection_id,
            capture,
            start=consumer.min_seqid(),
        )

    # One independent list of produced records per collection.
    produced_per_topic: List[List[OperationRecord]] = [[] for _ in range(N_TOPICS)]

    PRODUCE_BATCH_SIZE = 10
    N_TO_PRODUCE = 100
    for round_index in range(N_TO_PRODUCE // PRODUCE_BATCH_SIZE):
        # Number of records each collection holds once this round's batch is in.
        expected_total = (round_index + 1) * PRODUCE_BATCH_SIZE
        for collection_id, capture, produced in zip(
            collection_ids, captures, produced_per_topic
        ):
            batch = produce_fns(
                producer,
                collection_id,
                sample_embeddings,
                PRODUCE_BATCH_SIZE,
            )[0]
            produced.extend(batch)
            received = await capture.get(expected_total)
            assert_records_match(produced, received)


@pytest.mark.asyncio
async def test_max_batch_size(
    producer_consumer: Tuple[Producer, Consumer],
    sample_embeddings: Iterator[OperationRecord],
) -> None:
    """The producer accepts ``max_batch_size`` records but rejects one more.

    A batch of exactly ``producer.max_batch_size`` records must be delivered
    to a subscriber; a batch one record larger must raise
    ``BatchSizeExceededError``.
    """
    producer, consumer = producer_consumer
    producer.reset_state()
    consumer.reset_state()
    collection = UUID("00000000-0000-0000-0000-000000000000")
    limit = producer.max_batch_size
    assert limit > 0

    # Make sure that we can produce a batch of size max_batch_size
    full_batch = [next(sample_embeddings) for _ in range(limit)]
    capture = CapturingConsumeFn()
    consumer.subscribe(collection, capture, start=consumer.min_seqid())
    producer.submit_embeddings(collection, embeddings=full_batch)
    # Larger timeout than the default used elsewhere in this module.
    delivered = await capture.get(limit, timeout_secs=120)
    assert_records_match(full_batch, delivered)

    oversized_batch = [next(sample_embeddings) for _ in range(limit + 1)]
    # Make sure that we can't produce a batch of size > max_batch_size
    with pytest.raises(BatchSizeExceededError) as exc_info:
        producer.submit_embeddings(collection, embeddings=oversized_batch)
    assert "Cannot submit more than" in str(exc_info.value)
