import os
import shutil
import tempfile
import pytest
from typing import (
    Generator,
    List,
    Callable,
    Iterator,
    Dict,
    Optional,
    Union,
    Sequence,
)

from chromadb.api.types import validate_metadata
from chromadb.config import System, Settings
from chromadb.db.base import ParameterValue, get_sql
from chromadb.db.impl.sqlite import SqliteDB
from chromadb.test.conftest import ProducerFn
from chromadb.types import (
    OperationRecord,
    MetadataEmbeddingRecord,
    Operation,
    RequestVersionContext,
    ScalarEncoding,
    Segment,
    SegmentScope,
    SeqId,
)
from pypika import Table
from chromadb.ingest import Producer
from chromadb.segment import MetadataReader
import uuid
import time

from chromadb.segment.impl.metadata.sqlite import SqliteMetadataSegment

from pytest import FixtureRequest
from itertools import count


def sqlite() -> Generator[System, None, None]:
    """Yield a started, non-persistent ``System``; stop it once resumed."""
    ephemeral = System(Settings(allow_reset=True, is_persistent=False))
    ephemeral.start()
    yield ephemeral
    ephemeral.stop()


def sqlite_persistent() -> Generator[System, None, None]:
    """Yield a started ``System`` persisted to a fresh temporary directory.

    Once the generator is resumed, the system is stopped and the temporary
    directory is removed if it still exists.
    """
    persist_dir = tempfile.mkdtemp()
    persistent = System(
        Settings(allow_reset=True, is_persistent=True, persist_directory=persist_dir)
    )
    persistent.start()
    yield persistent
    persistent.stop()
    if not os.path.exists(persist_dir):
        return
    shutil.rmtree(persist_dir)


def system_fixtures() -> List[Callable[[], Generator[System, None, None]]]:
    """Return the generator factories used to parametrize ``system``."""
    factories = [sqlite, sqlite_persistent]
    return factories


@pytest.fixture(scope="module", params=system_fixtures())
def system(request: FixtureRequest) -> Generator[System, None, None]:
    """Provide a started ``System`` for each factory in ``system_fixtures()``.

    ``request.param`` is one of the generator factories; this fixture hands
    out the system it yields and lets the factory finish when pytest tears
    the fixture down.
    """
    # Delegate with ``yield from`` instead of ``yield next(...)``: pulling a
    # single value with ``next`` drops the factory's generator immediately, so
    # the statements after its ``yield`` (stopping the system, removing the
    # temporary directory) would never execute.
    yield from request.param()


@pytest.fixture(scope="function")
def sample_embeddings() -> Iterator[OperationRecord]:
    """Provide an endless stream of ADD records named ``embedding_<i>``.

    Record 0 carries no metadata. Every later record carries string, int,
    float and bool keys, a ``div_by_three`` marker when ``i`` is a multiple
    of three, and a ``chroma:document`` built from the digits of ``i``.
    """

    def build_metadata(i: int) -> Optional[Dict[str, Union[str, int, float, bool]]]:
        if i == 0:
            return None
        fields: Dict[str, Union[str, int, float, bool]] = {
            "str_key": f"value_{i}",
            "int_key": i,
            "float_key": i + i * 0.1,
            # True for odd i, False for even i
            "bool_key": i % 2 != 0,
        }
        if i % 3 == 0:
            fields["div_by_three"] = "true"
        fields["chroma:document"] = _build_document(i)
        return fields

    def create_record(i: int) -> OperationRecord:
        return OperationRecord(
            id=f"embedding_{i}",
            embedding=[i + i * 0.1, i + 1 + i * 0.1],  # type: ignore[typeddict-item]
            encoding=ScalarEncoding.FLOAT32,
            metadata=build_metadata(i),
            operation=Operation.ADD,
        )

    def stream() -> Iterator[OperationRecord]:
        for i in count():
            yield create_record(i)

    return stream()


# Spelled-out English name for each decimal digit character.
_digit_map = dict(
    zip(
        "0123456789",
        (
            "zero",
            "one",
            "two",
            "three",
            "four",
            "five",
            "six",
            "seven",
            "eight",
            "nine",
        ),
    )
)


def _build_document(i: int) -> str:
    """Spell out the decimal digits of ``i`` as space-separated words.

    For example, ``42`` becomes ``"four two"``.
    """
    words = [_digit_map[digit] for digit in str(i)]
    return " ".join(words)


# Metadata-scoped segment bound to the collection with UUID int=0. Most tests
# in this module build their SqliteMetadataSegment from this definition.
segment_definition = Segment(
    id=uuid.uuid4(),
    type="test_type",
    scope=SegmentScope.METADATA,
    collection=uuid.UUID(int=0),
    metadata=None,
    file_paths={},
)

# Second metadata-scoped segment, bound to a different collection (UUID int=1).
segment_definition2 = Segment(
    id=uuid.uuid4(),
    type="test_type",
    scope=SegmentScope.METADATA,
    collection=uuid.UUID(int=1),
    metadata=None,
    file_paths={},
)


def sync(
    segment: MetadataReader,
    seq_id: SeqId,
    timeout: float = 5.0,
    poll_interval: float = 0.25,
) -> None:
    """Block until ``segment.max_seqid()`` has reached ``seq_id``.

    Args:
        segment: Reader whose ``max_seqid()`` is polled.
        seq_id: Sequence id the segment must reach (or pass).
        timeout: Maximum number of seconds to keep polling.
        poll_interval: Seconds to sleep between polls.

    Raises:
        TimeoutError: If the segment has not reached ``seq_id`` in time.
    """
    # Use the monotonic clock: wall-clock adjustments must not stretch or
    # shorten the wait.
    deadline = time.monotonic() + timeout
    while True:
        # Always poll at least once, even when the timeout is zero.
        if segment.max_seqid() >= seq_id:
            return
        if time.monotonic() >= deadline:
            break
        time.sleep(poll_interval)
    raise TimeoutError(f"Timed out waiting for seq_id {seq_id}")


def test_insert_and_count(
    system: System,
    sample_embeddings: Iterator[OperationRecord],
    produce_fns: ProducerFn,
) -> None:
    """Check that ``count`` follows the number of records produced."""
    producer = system.instance(Producer)
    system.reset_state()

    collection_id = segment_definition["collection"]
    version_context = RequestVersionContext(collection_version=0, log_position=0)

    # The first batch is produced before the segment is created and started.
    _, seq_ids = produce_fns(producer, collection_id, sample_embeddings, 3)
    max_id = seq_ids[-1]

    segment = SqliteMetadataSegment(system, segment_definition)
    segment.start()

    sync(segment, max_id)
    assert segment.count(request_version_context=version_context) == 3

    # The second batch arrives while the segment is running.
    for _ in range(3):
        max_id = producer.submit_embedding(collection_id, next(sample_embeddings))

    sync(segment, max_id)
    assert segment.count(request_version_context=version_context) == 6


def assert_equiv_records(
    expected: Sequence[OperationRecord], actual: Sequence[MetadataEmbeddingRecord]
) -> None:
    """Assert both sequences hold the same records, ignoring order.

    Records are paired after sorting each sequence by ``id``; every pair must
    agree on ``id`` and ``metadata``.
    """

    def record_id(record):  # type: ignore[no-untyped-def]
        return record["id"]

    assert len(expected) == len(actual)
    pairs = zip(sorted(expected, key=record_id), sorted(actual, key=record_id))
    for want, got in pairs:
        assert want["id"] == got["id"]
        assert want["metadata"] == got["metadata"]


def test_get(
    system: System,
    sample_embeddings: Iterator[OperationRecord],
    produce_fns: ProducerFn,
) -> None:
    """Exercise ``get_metadata`` by id, with paging, and with where filters."""
    producer = system.instance(Producer)
    system.reset_state()
    collection_id = segment_definition["collection"]

    embeddings, seq_ids = produce_fns(producer, collection_id, sample_embeddings, 10)

    segment = SqliteMetadataSegment(system, segment_definition)
    segment.start()
    sync(segment, seq_ids[-1])

    ctx = RequestVersionContext(collection_version=0, log_position=0)

    def hits(where: Dict[str, object]) -> int:
        """Return how many records satisfy ``where``."""
        found = segment.get_metadata(
            where=where,  # type: ignore[arg-type]
            request_version_context=ctx,
        )
        return len(found)

    # get with bool key
    assert hits({"bool_key": True}) == 5
    assert hits({"bool_key": False}) == 4

    # Get all records
    everything = segment.get_metadata(request_version_context=ctx)
    assert_equiv_records(embeddings, everything)

    # get by ID
    first_five = embeddings[0:5]
    by_id = segment.get_metadata(
        ids=[record["id"] for record in first_five],
        request_version_context=ctx,
    )
    assert_equiv_records(first_five, by_id)

    # Get with limit and offset
    # Cannot rely on order (yet), but can rely on retrieving exactly the
    # whole set eventually
    paged: List[MetadataEmbeddingRecord] = []
    paged.extend(segment.get_metadata(limit=3, request_version_context=ctx))
    assert len(paged) == 3
    for offset, running_total in ((3, 6), (6, 9), (9, 10)):
        paged.extend(
            segment.get_metadata(limit=3, offset=offset, request_version_context=ctx)
        )
        assert len(paged) == running_total
    assert_equiv_records(embeddings, paged)

    # Get with simple where
    assert hits({"div_by_three": "true"}) == 3

    # Get with gt/gte/lt/lte on int keys
    assert hits({"int_key": {"$gt": 5}}) == 4
    assert hits({"int_key": {"$gte": 5}}) == 5
    assert hits({"int_key": {"$lt": 5}}) == 4
    assert hits({"int_key": {"$lte": 5}}) == 5

    # Get with gt/lt on float keys with float values
    assert hits({"float_key": {"$gt": 5.01}}) == 5
    assert hits({"float_key": {"$lt": 4.99}}) == 4

    # Get with gt/lt on float keys with int values
    assert hits({"float_key": {"$gt": 5}}) == 5
    assert hits({"float_key": {"$lt": 5}}) == 4

    # Get with gt/lt on int keys with float values
    assert hits({"int_key": {"$gt": 5.01}}) == 4
    assert hits({"int_key": {"$lt": 4.99}}) == 4

    # Get with $ne
    # Returns metadata that has an int_key but not equal to 5, or without an int_key
    assert hits({"int_key": {"$ne": 5}}) == 9

    # get with multiple heterogeneous conditions
    assert hits({"div_by_three": "true", "int_key": {"$gt": 5}}) == 2

    # get with OR conditions
    assert hits({"$or": [{"int_key": 1}, {"int_key": 2}]}) == 2

    # get with AND conditions
    assert hits({"$and": [{"int_key": 3}, {"float_key": {"$gt": 5}}]}) == 0
    assert hits({"$and": [{"int_key": 3}, {"float_key": {"$lt": 5}}]}) == 1


def test_fulltext(
    system: System,
    sample_embeddings: Iterator[OperationRecord],
    produce_fns: ProducerFn,
) -> None:
    """Exercise ``where_document`` filters over 100 generated documents."""
    producer = system.instance(Producer)
    system.reset_state()
    collection_id = segment_definition["collection"]

    segment = SqliteMetadataSegment(system, segment_definition)
    segment.start()

    _, seq_ids = produce_fns(producer, collection_id, sample_embeddings, 100)
    sync(segment, seq_ids[-1])

    ctx = RequestVersionContext(collection_version=0, log_position=0)

    def matching(where_document: Dict[str, object]) -> Sequence[MetadataEmbeddingRecord]:
        """Run a document-only query against the segment."""
        return segment.get_metadata(
            where_document=where_document,  # type: ignore[arg-type]
            request_version_context=ctx,
        )

    def count_docs(keep: Callable[[str], bool]) -> int:
        """Count records 1..99 whose generated document satisfies ``keep``."""
        return sum(1 for i in range(1, 100) if keep(_build_document(i)))

    # Filtering on the document key returns the same record as fetching by id
    by_document = segment.get_metadata(
        where={"chroma:document": "four two"},
        request_version_context=ctx,
    )
    by_id = segment.get_metadata(ids=["embedding_42"], request_version_context=ctx)
    assert by_document == by_id

    # Test single result
    assert len(matching({"$contains": "four two"})) == 1

    # Test not_contains
    # Returns records without documents or with documents not containing the
    # searched text. The first record does not have a document, hence the + 1.
    assert (
        len(matching({"$not_contains": "four two"}))
        == count_docs(lambda doc: "four two" not in doc) + 1
    )

    # Test many results
    assert len(matching({"$contains": "zero"})) == 9

    # Test not_contains (the document-less first record is included: + 1)
    assert (
        len(matching({"$not_contains": "zero"}))
        == count_docs(lambda doc: "zero" not in doc) + 1
    )

    # test $and
    both = matching({"$and": [{"$contains": "four"}, {"$contains": "two"}]})
    assert len(both) == 2
    assert {record["id"] for record in both} == {"embedding_42", "embedding_24"}

    neither = matching(
        {"$and": [{"$not_contains": "four"}, {"$not_contains": "two"}]}
    )
    # The document-less first record is included: + 1
    assert (
        len(neither)
        == count_docs(lambda doc: "four" not in doc and "two" not in doc) + 1
    )

    # test $or
    either = matching({"$or": [{"$contains": "zero"}, {"$contains": "one"}]})
    expected_ids = set()
    for i in range(1, 100):
        document = _build_document(i)
        if "one" in document or "zero" in document:
            expected_ids.add(f"embedding_{i}")
    assert {record["id"] for record in either} == expected_ids

    lacking_one = matching(
        {"$or": [{"$not_contains": "zero"}, {"$not_contains": "one"}]}
    )
    # The document-less first record is included: + 1
    assert (
        len(lacking_one)
        == count_docs(lambda doc: "zero" not in doc or "one" not in doc) + 1
    )

    # test combo with where clause (negative case)
    combined = segment.get_metadata(
        where={"int_key": {"$eq": 42}},  # type:ignore[dict-item]
        where_document={"$contains": "zero"},
        request_version_context=ctx,
    )
    assert len(combined) == 0

    # test combo with where clause (positive case)
    combined = segment.get_metadata(
        where={"int_key": {"$eq": 42}},  # type:ignore[dict-item]
        where_document={"$contains": "four"},
        request_version_context=ctx,
    )
    assert len(combined) == 1

    # test partial words
    assert len(matching({"$contains": "zer"})) == 9


def test_delete(
    system: System,
    sample_embeddings: Iterator[OperationRecord],
    produce_fns: ProducerFn,
) -> None:
    """Check delete by id, repeated delete, and re-adding the deleted record."""
    producer = system.instance(Producer)
    system.reset_state()
    collection_id = segment_definition["collection"]

    segment = SqliteMetadataSegment(system, segment_definition)
    segment.start()

    embeddings, seq_ids = produce_fns(producer, collection_id, sample_embeddings, 10)
    max_id = seq_ids[-1]

    sync(segment, max_id)

    version_context = RequestVersionContext(collection_version=0, log_position=0)
    assert segment.count(request_version_context=version_context) == 10
    results = segment.get_metadata(
        ids=["embedding_0"], request_version_context=version_context
    )
    assert_equiv_records(embeddings[:1], results)

    # Delete by ID
    delete_embedding = OperationRecord(
        id="embedding_0",
        embedding=None,
        encoding=None,
        metadata=None,
        operation=Operation.DELETE,
    )
    max_id = produce_fns(
        producer, collection_id, (delete_embedding for _ in range(1)), 1
    )[1][-1]

    sync(segment, max_id)

    assert segment.count(request_version_context=version_context) == 9
    assert (
        segment.get_metadata(
            ids=["embedding_0"], request_version_context=version_context
        )
        == []
    )

    # Delete is idempotent
    max_id = produce_fns(
        producer, collection_id, (delete_embedding for _ in range(1)), 1
    )[1][-1]

    sync(segment, max_id)
    assert segment.count(request_version_context=version_context) == 9
    assert (
        segment.get_metadata(
            ids=["embedding_0"], request_version_context=version_context
        )
        == []
    )

    # re-add
    max_id = producer.submit_embedding(collection_id, embeddings[0])
    sync(segment, max_id)
    assert segment.count(request_version_context=version_context) == 10
    results = segment.get_metadata(
        ids=["embedding_0"], request_version_context=version_context
    )
    # The re-added record must be readable again, exactly as it was before
    # the delete (previously this result was fetched but never checked).
    assert_equiv_records(embeddings[:1], results)


def test_update(system: System, sample_embeddings: Iterator[OperationRecord]) -> None:
    """Run the shared update checks, then UPDATE an id that does not exist."""
    producer = system.instance(Producer)
    system.reset_state()
    collection_id = segment_definition["collection"]

    segment = SqliteMetadataSegment(system, segment_definition)
    segment.start()

    _test_update(sample_embeddings, producer, segment, collection_id, Operation.UPDATE)

    # Update nonexisting ID: no record may appear and the count stays at 3
    ghost_update = OperationRecord(
        id="no_such_id",
        metadata={"foo": "bar"},
        embedding=None,
        encoding=None,
        operation=Operation.UPDATE,
    )
    sync(segment, producer.submit_embedding(collection_id, ghost_update))

    ctx = RequestVersionContext(collection_version=0, log_position=0)
    missing = segment.get_metadata(ids=["no_such_id"], request_version_context=ctx)
    assert len(missing) == 0
    assert segment.count(request_version_context=ctx) == 3


def test_upsert(
    system: System,
    sample_embeddings: Iterator[OperationRecord],
    produce_fns: ProducerFn,
) -> None:
    """Run the shared update checks, then UPSERT an id that does not exist."""
    producer = system.instance(Producer)
    system.reset_state()
    collection_id = segment_definition["collection"]

    segment = SqliteMetadataSegment(system, segment_definition)
    segment.start()

    _test_update(sample_embeddings, producer, segment, collection_id, Operation.UPSERT)

    # upsert previously nonexisting ID: the record is created
    new_record = OperationRecord(
        id="no_such_id",
        metadata={"foo": "bar"},
        embedding=None,
        encoding=None,
        operation=Operation.UPSERT,
    )
    _, seq_ids = produce_fns(
        producer=producer,
        collection_id=collection_id,
        embeddings=(new_record for _ in range(1)),
        n=1,
    )
    sync(segment, seq_ids[-1])

    ctx = RequestVersionContext(collection_version=0, log_position=0)
    created = segment.get_metadata(ids=["no_such_id"], request_version_context=ctx)
    assert created[0]["metadata"] == {"foo": "bar"}


def _test_update(
    sample_embeddings: Iterator[OperationRecord],
    producer: Producer,
    segment: MetadataReader,
    collection_id: uuid.UUID,
    op: Operation,
) -> None:
    """Checks shared by the update and upsert tests.

    Submits three sample records, then applies a series of ``op`` records to
    ``embedding_0`` and verifies the stored metadata and document search
    results after each one.
    """
    ctx = RequestVersionContext(collection_version=0, log_position=0)
    target_id = "embedding_0"

    def apply(metadata: Dict[str, Optional[Union[str, int, float, bool]]]) -> None:
        """Submit an ``op`` record for the target and wait for the segment."""
        record = OperationRecord(
            id=target_id,
            metadata=metadata,
            embedding=None,
            encoding=None,
            operation=op,
        )
        sync(segment, producer.submit_embedding(collection_id, record))

    def stored_metadata() -> object:
        """Return the metadata currently stored for the target record."""
        rows = segment.get_metadata(ids=[target_id], request_version_context=ctx)
        return rows[0]["metadata"]

    def containing(text: str) -> Sequence[MetadataEmbeddingRecord]:
        """Return the records whose document contains ``text``."""
        return segment.get_metadata(
            where_document={"$contains": text},
            request_version_context=ctx,
        )

    seeded = [next(sample_embeddings) for _ in range(3)]
    last_seq_id = 0
    for record in seeded:
        last_seq_id = producer.submit_embedding(collection_id, record)
    sync(segment, last_seq_id)

    initial = segment.get_metadata(ids=[target_id], request_version_context=ctx)
    assert_equiv_records(seeded[:1], initial)

    # Update embedding with no metadata
    apply({"chroma:document": "foo bar"})
    assert stored_metadata() == {"chroma:document": "foo bar"}
    assert containing("foo")[0]["metadata"] == {"chroma:document": "foo bar"}

    # Update and overwrite key
    apply({"chroma:document": "biz buz"})
    assert stored_metadata() == {"chroma:document": "biz buz"}
    assert containing("biz")[0]["metadata"] == {"chroma:document": "biz buz"}
    assert len(containing("foo")) == 0

    # Update and add key
    apply({"baz": 42})
    assert stored_metadata() == {"chroma:document": "biz buz", "baz": 42}

    # Update and delete key
    apply({"chroma:document": None})
    assert stored_metadata() == {"baz": 42}
    assert len(containing("biz")) == 0


def test_limit(
    system: System,
    sample_embeddings: Iterator[OperationRecord],
    produce_fns: ProducerFn,
) -> None:
    """Check ``limit``/``offset`` handling with two collections populated."""
    producer = system.instance(Producer)
    system.reset_state()

    first_collection = segment_definition["collection"]
    second_collection = segment_definition2["collection"]

    # Three records go to each collection before either segment is started.
    max_id = produce_fns(producer, first_collection, sample_embeddings, 3)[1][-1]
    max_id2 = produce_fns(producer, second_collection, sample_embeddings, 3)[1][-1]

    segment = SqliteMetadataSegment(system, segment_definition)
    segment.start()

    segment2 = SqliteMetadataSegment(system, segment_definition2)
    segment2.start()

    sync(segment, max_id)
    sync(segment2, max_id2)

    ctx = RequestVersionContext(collection_version=0, log_position=0)
    assert segment.count(request_version_context=ctx) == 3

    for _ in range(3):
        max_id = producer.submit_embedding(first_collection, next(sample_embeddings))
    sync(segment, max_id)
    assert segment.count(request_version_context=ctx) == 6

    limited = segment.get_metadata(limit=3, request_version_context=ctx)
    assert len(limited) == 3

    # if limit is negative, throw error
    with pytest.raises(ValueError):
        segment.get_metadata(limit=-1, request_version_context=ctx)

    # if offset is more than number of results, return empty list
    beyond_end = segment.get_metadata(limit=3, offset=10, request_version_context=ctx)
    assert len(beyond_end) == 0


def test_delete_segment(
    system: System,
    sample_embeddings: Iterator[OperationRecord],
    produce_fns: ProducerFn,
) -> None:
    """Check that ``segment.delete()`` removes embedding rows and FTS rows."""
    producer = system.instance(Producer)
    system.reset_state()
    collection_id = segment_definition["collection"]

    segment = SqliteMetadataSegment(system, segment_definition)
    segment.start()

    embeddings, seq_ids = produce_fns(producer, collection_id, sample_embeddings, 10)
    max_id = seq_ids[-1]

    sync(segment, max_id)

    request_version_context = RequestVersionContext(
        collection_version=0, log_position=0
    )
    assert segment.count(request_version_context=request_version_context) == 10
    results = segment.get_metadata(
        ids=["embedding_0"], request_version_context=request_version_context
    )
    assert_equiv_records(embeddings[:1], results)

    _id = segment._id
    _db = system.instance(SqliteDB)
    t = Table("embeddings")
    fts_t = Table("embedding_fulltext_search")

    q = (
        _db.querybuilder()
        .from_(t)
        .select(t.id)
        .where(t.segment_id == ParameterValue(_db.uuid_to_db(_id)))
    )
    sql, params = get_sql(q)

    # Capture the segment's row ids BEFORE deleting. Looking them up afterwards
    # (via a subquery on the already-emptied embeddings table) can never match
    # anything, which would make the FTS check below pass unconditionally.
    with _db.tx() as cur:
        row_ids = [row[0] for row in cur.execute(sql, params).fetchall()]
    assert len(row_ids) == 10

    # Select an explicit column: a PyPika query with an empty select list
    # renders as an empty statement, so nothing would actually be queried.
    q_fts = (
        _db.querybuilder()
        .from_(fts_t)
        .select(fts_t.rowid)
        .where(fts_t.rowid.isin(row_ids))
    )
    fts_sql, fts_params = get_sql(q_fts)

    # Sanity check: there are FTS rows to delete in the first place.
    with _db.tx() as cur:
        assert len(cur.execute(fts_sql, fts_params).fetchall()) > 0

    segment.delete()

    with _db.tx() as cur:
        res = cur.execute(sql, params)
        # assert that the segment is gone
        assert len(res.fetchall()) == 0

    with _db.tx() as cur:
        res = cur.execute(fts_sql, fts_params)
        # assert that all FTS rows are gone
        assert len(res.fetchall()) == 0


def test_delete_single_fts_record(
    system: System,
    sample_embeddings: Iterator[OperationRecord],
    produce_fns: ProducerFn,
) -> None:
    """Check that deleting one record also removes its FTS row."""
    producer = system.instance(Producer)
    system.reset_state()
    collection_id = segment_definition["collection"]

    segment = SqliteMetadataSegment(system, segment_definition)
    segment.start()

    embeddings, seq_ids = produce_fns(producer, collection_id, sample_embeddings, 10)
    max_id = seq_ids[-1]

    sync(segment, max_id)

    request_version_context = RequestVersionContext(
        collection_version=0, log_position=0
    )
    assert segment.count(request_version_context=request_version_context) == 10

    # Target a record that carries a document. The first sample record has no
    # metadata at all, so it would have no full-text row to lose.
    target = embeddings[1]
    target_id = target["id"]
    results = segment.get_metadata(
        ids=[target_id], request_version_context=request_version_context
    )
    assert_equiv_records([target], results)

    _id = segment._id
    _db = system.instance(SqliteDB)
    t = Table("embeddings")
    fts_t = Table("embedding_fulltext_search")

    # Resolve the record's row id BEFORE deleting it; afterwards the embeddings
    # row is gone and a lookup through it could never match an orphaned FTS row.
    q_row = (
        _db.querybuilder()
        .from_(t)
        .select(t.id)
        .where(t.segment_id == ParameterValue(_db.uuid_to_db(_id)))
        .where(t.embedding_id == ParameterValue(target_id))
    )
    row_sql, row_params = get_sql(q_row)
    with _db.tx() as cur:
        row_ids = [row[0] for row in cur.execute(row_sql, row_params).fetchall()]
    assert len(row_ids) == 1

    # Select an explicit column: a PyPika query with an empty select list
    # renders as an empty statement, so nothing would actually be queried.
    q_fts = (
        _db.querybuilder()
        .from_(fts_t)
        .select(fts_t.rowid)
        .where(fts_t.rowid.isin(row_ids))
    )
    fts_sql, fts_params = get_sql(q_fts)

    # Sanity check: the record has an FTS row before it is deleted.
    with _db.tx() as cur:
        assert len(cur.execute(fts_sql, fts_params).fetchall()) == 1

    # Delete by ID
    delete_embedding = OperationRecord(
        id=target_id,
        embedding=None,
        encoding=None,
        metadata=None,
        operation=Operation.DELETE,
    )
    max_id = produce_fns(
        producer, collection_id, (delete_embedding for _ in range(1)), 1
    )[1][-1]

    sync(segment, max_id)

    assert (
        segment.get_metadata(
            ids=[target_id], request_version_context=request_version_context
        )
        == []
    )
    with _db.tx() as cur:
        res = cur.execute(fts_sql, fts_params)
        # assert that the ids that are deleted from the segment are also gone from the fts table
        assert len(res.fetchall()) == 0


def test_include_metadata(
    system: System,
    sample_embeddings: Iterator[OperationRecord],
    produce_fns: ProducerFn,
) -> None:
    """Check that ``include_metadata`` controls whether metadata is returned."""
    producer = system.instance(Producer)
    system.reset_state()
    collection_id = segment_definition["collection"]

    segment = SqliteMetadataSegment(system, segment_definition)
    segment.start()

    embeddings, seq_ids = produce_fns(producer, collection_id, sample_embeddings, 10)
    sync(segment, seq_ids[-1])

    ctx = RequestVersionContext(collection_version=0, log_position=0)
    assert segment.count(request_version_context=ctx) == 10

    default_fetch = segment.get_metadata(
        ids=["embedding_0"], request_version_context=ctx
    )
    assert_equiv_records(embeddings[:1], default_fetch)

    # Test include_metadata=False
    without_metadata = segment.get_metadata(
        ids=["embedding_0"],
        include_metadata=False,
        request_version_context=ctx,
    )
    assert len(without_metadata) == 1
    assert without_metadata[0]["metadata"] is None

    # Test include_metadata=True
    with_metadata = segment.get_metadata(
        ids=["embedding_0"],
        include_metadata=True,
        request_version_context=ctx,
    )
    assert len(with_metadata) == 1
    assert with_metadata[0]["metadata"] == embeddings[0]["metadata"]


def test_metadata_validation_forbidden_key() -> None:
    """``validate_metadata`` must reject the ``chroma:document`` key."""
    forbidden = {"chroma:document": "this is not the document you are looking for"}
    with pytest.raises(ValueError, match="chroma:document"):
        validate_metadata(forbidden)
