from threading import Lock
from typing import Dict, Sequence
from uuid import UUID, uuid4

from overrides import override

from chromadb.config import System
from chromadb.db.system import SysDB
from chromadb.segment import (
    SegmentImplementation,
    SegmentManager,
    SegmentType,
)
from chromadb.segment.distributed import SegmentDirectory
from chromadb.segment.impl.vector.hnsw_params import PersistentHnswParams
from chromadb.telemetry.opentelemetry import (
    OpenTelemetryGranularity,
    trace_method,
)
from chromadb.types import (
    Collection,
    Operation,
    Segment,
    SegmentScope,
)


class DistributedSegmentManager(SegmentManager):
    _sysdb: SysDB
    _system: System
    _instances: Dict[UUID, SegmentImplementation]
    _segment_directory: SegmentDirectory
    _lock: Lock

    def __init__(self, system: System):
        super().__init__(system)
        self._sysdb = self.require(SysDB)
        self._segment_directory = self.require(SegmentDirectory)
        self._system = system
        self._instances = {}
        self._lock = Lock()

    @trace_method(
        "DistributedSegmentManager.prepare_segments_for_new_collection",
        OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
    )
    @override
    def prepare_segments_for_new_collection(
        self, collection: Collection
    ) -> Sequence[Segment]:
        vector_segment = Segment(
            id=uuid4(),
            type=SegmentType.HNSW_DISTRIBUTED.value,
            scope=SegmentScope.VECTOR,
            collection=collection.id,
            metadata=PersistentHnswParams.extract(collection.metadata)
            if collection.metadata
            else None,
            file_paths={},
        )
        metadata_segment = Segment(
            id=uuid4(),
            type=SegmentType.BLOCKFILE_METADATA.value,
            scope=SegmentScope.METADATA,
            collection=collection.id,
            metadata=None,
            file_paths={},
        )
        record_segment = Segment(
            id=uuid4(),
            type=SegmentType.BLOCKFILE_RECORD.value,
            scope=SegmentScope.RECORD,
            collection=collection.id,
            metadata=None,
            file_paths={},
        )
        return [vector_segment, record_segment, metadata_segment]

    @override
    def delete_segments(self, collection_id: UUID) -> Sequence[UUID]:
        # TODO: this should be a pass, delete_collection is expected to delete segments in
        # distributed
        segments = self._sysdb.get_segments(collection=collection_id)
        return [s["id"] for s in segments]

    @trace_method(
        "DistributedSegmentManager.get_endpoint",
        OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
    )
    def get_endpoint(self, segment: Segment) -> str:
        return self._segment_directory.get_segment_endpoint(segment)

    @trace_method(
        "DistributedSegmentManager.hint_use_collection",
        OpenTelemetryGranularity.OPERATION_AND_SEGMENT,
    )
    @override
    def hint_use_collection(self, collection_id: UUID, hint_type: Operation) -> None:
        pass
