from abc import abstractmethod
import json
from overrides import override
from typing import (
    Any,
    ClassVar,
    Dict,
    List,
    Optional,
    Protocol,
    Union,
    TypeVar,
    cast,
)
from typing_extensions import Self
from multiprocessing import cpu_count

from chromadb.serde import JSONSerializable

# TODO: move out of API


class StaticParameterError(Exception):
    """Represents an error that occurs when a static parameter is set."""

    pass


class InvalidConfigurationError(ValueError):
    """Represents an error that occurs when a configuration is invalid."""

    pass


ParameterValue = Union[str, int, float, bool, "ConfigurationInternal"]


class ParameterValidator(Protocol):
    """Represents an abstract parameter validator."""

    @abstractmethod
    def __call__(self, value: ParameterValue) -> bool:
        """Returns whether the given value is valid."""
        raise NotImplementedError()


class ConfigurationDefinition:
    """Represents the definition of a configuration."""

    name: str
    validator: ParameterValidator
    is_static: bool
    default_value: ParameterValue

    def __init__(
        self,
        name: str,
        validator: ParameterValidator,
        is_static: bool,
        default_value: ParameterValue,
    ):
        self.name = name
        self.validator = validator
        self.is_static = is_static
        self.default_value = default_value


class ConfigurationParameter:
    """Represents a parameter of a configuration."""

    name: str
    value: ParameterValue

    def __init__(self, name: str, value: ParameterValue):
        self.name = name
        self.value = value

    def __repr__(self) -> str:
        return f"ConfigurationParameter({self.name}, {self.value})"

    def __eq__(self, __value: object) -> bool:
        if not isinstance(__value, ConfigurationParameter):
            return NotImplemented
        return self.name == __value.name and self.value == __value.value


T = TypeVar("T", bound="ConfigurationInternal")


class ConfigurationInternal(JSONSerializable["ConfigurationInternal"]):
    """Represents an abstract configuration, used internally by Chroma."""

    # The internal data structure used to store the parameters
    # All expected parameters must be present with defaults or None values at initialization
    parameter_map: Dict[str, ConfigurationParameter]
    definitions: ClassVar[Dict[str, ConfigurationDefinition]]

    def __init__(self, parameters: Optional[List[ConfigurationParameter]] = None):
        """Initializes a new instance of the Configuration class. Respecting defaults and
        validators."""
        self.parameter_map = {}
        if parameters is not None:
            for parameter in parameters:
                if parameter.name not in self.definitions:
                    raise ValueError(f"Invalid parameter name: {parameter.name}")

                definition = self.definitions[parameter.name]
                # Handle the case where we have a recursive configuration definition
                if isinstance(parameter.value, dict):
                    child_type = globals().get(parameter.value.get("_type", None))
                    if child_type is None:
                        raise ValueError(
                            f"Invalid configuration type: {parameter.value}"
                        )
                    parameter.value = child_type.from_json(parameter.value)
                if not isinstance(parameter.value, type(definition.default_value)):
                    raise ValueError(f"Invalid parameter value: {parameter.value}")

                parameter_validator = definition.validator
                if not parameter_validator(parameter.value):
                    raise ValueError(f"Invalid parameter value: {parameter.value}")
                self.parameter_map[parameter.name] = parameter
        # Apply the defaults for any missing parameters
        for name, definition in self.definitions.items():
            if name not in self.parameter_map:
                self.parameter_map[name] = ConfigurationParameter(
                    name=name, value=definition.default_value
                )

        self.configuration_validator()

    def __repr__(self) -> str:
        return f"Configuration({self.parameter_map.values()})"

    def __eq__(self, __value: object) -> bool:
        if not isinstance(__value, ConfigurationInternal):
            return NotImplemented
        return self.parameter_map == __value.parameter_map

    @abstractmethod
    def configuration_validator(self) -> None:
        """Perform custom validation when parameters are dependent on each other.

        Raises an InvalidConfigurationError if the configuration is invalid.
        """
        pass

    def get_parameters(self) -> List[ConfigurationParameter]:
        """Returns the parameters of the configuration."""
        return list(self.parameter_map.values())

    def get_parameter(self, name: str) -> ConfigurationParameter:
        """Returns the parameter with the given name, or except if it doesn't exist."""
        if name not in self.parameter_map:
            raise ValueError(
                f"Invalid parameter name: {name} for configuration {self.__class__.__name__}"
            )
        param_value = cast(ConfigurationParameter, self.parameter_map.get(name))
        return param_value

    def set_parameter(self, name: str, value: Union[str, int, float, bool]) -> None:
        """Sets the parameter with the given name to the given value."""
        if name not in self.definitions:
            raise ValueError(f"Invalid parameter name: {name}")
        definition = self.definitions[name]
        parameter = self.parameter_map[name]
        if definition.is_static:
            raise StaticParameterError(f"Cannot set static parameter: {name}")
        if not definition.validator(value):
            raise ValueError(f"Invalid value for parameter {name}: {value}")
        parameter.value = value

    @override
    def to_json_str(self) -> str:
        """Returns the JSON representation of the configuration."""
        return json.dumps(self.to_json())

    @classmethod
    @override
    def from_json_str(cls, json_str: str) -> Self:
        """Returns a configuration from the given JSON string."""
        try:
            config_json = json.loads(json_str)
        except json.JSONDecodeError:
            raise ValueError(
                f"Unable to decode configuration from JSON string: {json_str}"
            )
        return cls.from_json(config_json)

    @override
    def to_json(self) -> Dict[str, Any]:
        """Returns the JSON compatible dictionary representation of the configuration."""
        json_dict = {
            name: parameter.value.to_json()
            if isinstance(parameter.value, ConfigurationInternal)
            else parameter.value
            for name, parameter in self.parameter_map.items()
        }
        # What kind of configuration is this?
        json_dict["_type"] = self.__class__.__name__
        return json_dict

    @classmethod
    @override
    def from_json(cls, json_map: Dict[str, Any]) -> Self:
        """Returns a configuration from the given JSON string."""
        if cls.__name__ != json_map.get("_type", None):
            raise ValueError(
                f"Trying to instantiate configuration of type {cls.__name__} from JSON with type {json_map['_type']}"
            )
        parameters = []
        for name, value in json_map.items():
            # Type value is only for storage
            if name == "_type":
                continue
            parameters.append(ConfigurationParameter(name=name, value=value))
        return cls(parameters=parameters)


class HNSWConfigurationInternal(ConfigurationInternal):
    """Internal representation of the HNSW configuration.
    Used for validation, defaults, serialization and deserialization."""

    definitions = {
        "space": ConfigurationDefinition(
            name="space",
            validator=lambda value: isinstance(value, str)
            and value in ["l2", "ip", "cosine"],
            is_static=True,
            default_value="l2",
        ),
        "ef_construction": ConfigurationDefinition(
            name="ef_construction",
            validator=lambda value: isinstance(value, int) and value >= 1,
            is_static=True,
            default_value=100,
        ),
        "ef_search": ConfigurationDefinition(
            name="ef_search",
            validator=lambda value: isinstance(value, int) and value >= 1,
            is_static=False,
            default_value=100,
        ),
        "num_threads": ConfigurationDefinition(
            name="num_threads",
            validator=lambda value: isinstance(value, int) and value >= 1,
            is_static=False,
            default_value=cpu_count(),  # By default use all cores available
        ),
        "M": ConfigurationDefinition(
            name="M",
            validator=lambda value: isinstance(value, int) and value >= 1,
            is_static=True,
            default_value=16,
        ),
        "resize_factor": ConfigurationDefinition(
            name="resize_factor",
            validator=lambda value: isinstance(value, float) and value >= 1,
            is_static=True,
            default_value=1.2,
        ),
        "batch_size": ConfigurationDefinition(
            name="batch_size",
            validator=lambda value: isinstance(value, int) and value >= 1,
            is_static=True,
            default_value=100,
        ),
        "sync_threshold": ConfigurationDefinition(
            name="sync_threshold",
            validator=lambda value: isinstance(value, int) and value >= 1,
            is_static=True,
            default_value=1000,
        ),
    }

    @override
    def configuration_validator(self) -> None:
        batch_size = self.parameter_map.get("batch_size")
        sync_threshold = self.parameter_map.get("sync_threshold")

        if (
            batch_size
            and sync_threshold
            and cast(int, batch_size.value) > cast(int, sync_threshold.value)
        ):
            raise InvalidConfigurationError(
                "batch_size must be less than or equal to sync_threshold"
            )

    @classmethod
    def from_legacy_params(cls, params: Dict[str, Any]) -> Self:
        """Returns an HNSWConfiguration from a metadata dict containing legacy HNSW parameters. Used for migration."""

        # We maintain this map to avoid a circular import with HnswParams, and
        # because then names won't change since we intend to deprecate HNSWParams
        # in favor of this type of configuration.
        old_to_new = {
            "hnsw:space": "space",
            "hnsw:construction_ef": "ef_construction",
            "hnsw:search_ef": "ef_search",
            "hnsw:M": "M",
            "hnsw:num_threads": "num_threads",
            "hnsw:resize_factor": "resize_factor",
            "hnsw:batch_size": "batch_size",
            "hnsw:sync_threshold": "sync_threshold",
        }

        parameters = []
        for name, value in params.items():
            if name not in old_to_new:
                raise ValueError(f"Invalid legacy HNSW parameter name: {name}")
            parameters.append(
                ConfigurationParameter(name=old_to_new[name], value=value)
            )
        return cls(parameters)


# This is the user-facing interface for HNSW index configuration parameters.
# Internally, we pass around HNSWConfigurationInternal objects, which perform
# validation, serialization and deserialization. Users don't need to know
# about that and instead get a clean constructor with default arguments.
class HNSWConfigurationInterface(HNSWConfigurationInternal):
    """HNSW index configuration parameters.
    See https://docs.trychroma.com/guides#changing-the-distance-function for more information.
    """

    def __init__(
        self,
        space: str = "l2",
        ef_construction: int = 100,
        ef_search: int = 100,
        num_threads: int = cpu_count(),
        M: int = 16,
        resize_factor: float = 1.2,
        batch_size: int = 100,
        sync_threshold: int = 1000,
    ):
        parameters = [
            ConfigurationParameter(name="space", value=space),
            ConfigurationParameter(name="ef_construction", value=ef_construction),
            ConfigurationParameter(name="ef_search", value=ef_search),
            ConfigurationParameter(name="num_threads", value=num_threads),
            ConfigurationParameter(name="M", value=M),
            ConfigurationParameter(name="resize_factor", value=resize_factor),
            ConfigurationParameter(name="batch_size", value=batch_size),
            ConfigurationParameter(name="sync_threshold", value=sync_threshold),
        ]

        super().__init__(parameters=parameters)


# Alias for user convenience - the user doesn't need to know this is an 'Interface'
HNSWConfiguration = HNSWConfigurationInterface


class CollectionConfigurationInternal(ConfigurationInternal):
    """Internal representation of the collection configuration.
    Used for validation, defaults, and serialization / deserialization."""

    definitions = {
        "hnsw_configuration": ConfigurationDefinition(
            name="hnsw_configuration",
            validator=lambda value: isinstance(value, HNSWConfigurationInternal),
            is_static=True,
            default_value=HNSWConfigurationInternal(),
        ),
    }

    @override
    def configuration_validator(self) -> None:
        pass


# This is the user-facing interface for HNSW index configuration parameters.
# Internally, we pass around HNSWConfigurationInternal objects, which perform
# validation, serialization and deserialization. Users don't need to know
# about that and instead get a clean constructor with default arguments.
class CollectionConfigurationInterface(CollectionConfigurationInternal):
    """Configuration parameters for creating a collection."""

    def __init__(self, hnsw_configuration: Optional[HNSWConfigurationInternal]):
        """Initializes a new instance of the CollectionConfiguration class.
        Args:
            hnsw_configuration: The HNSW configuration to use for the collection.
        """
        if hnsw_configuration is None:
            hnsw_configuration = HNSWConfigurationInternal()
        parameters = [
            ConfigurationParameter(name="hnsw_configuration", value=hnsw_configuration)
        ]
        super().__init__(parameters=parameters)


# Alias for user convenience - the user doesn't need to know this is an 'Interface'.
CollectionConfiguration = CollectionConfigurationInterface


class EmbeddingsQueueConfigurationInternal(ConfigurationInternal):
    definitions = {
        "automatically_purge": ConfigurationDefinition(
            name="automatically_purge",
            validator=lambda value: isinstance(value, bool),
            is_static=False,
            default_value=True,
        ),
    }

    @override
    def configuration_validator(self) -> None:
        pass
