from __future__ import annotations

from typing import TYPE_CHECKING, Any, overload

from altair import (
    Chart,
    ConcatChart,
    ConcatSpecGenericSpec,
    FacetChart,
    FacetedUnitSpec,
    FacetSpec,
    HConcatChart,
    HConcatSpecGenericSpec,
    LayerChart,
    LayerSpec,
    NonNormalizedSpec,
    TopLevelConcatSpec,
    TopLevelFacetSpec,
    TopLevelHConcatSpec,
    TopLevelLayerSpec,
    TopLevelUnitSpec,
    TopLevelVConcatSpec,
    UnitSpec,
    UnitSpecWithFrame,
    VConcatChart,
    VConcatSpecGenericSpec,
    data_transformers,
)
from altair.utils._vegafusion_data import get_inline_tables, import_vegafusion
from altair.utils.schemapi import Undefined

if TYPE_CHECKING:
    import sys
    from collections.abc import Iterable

    if sys.version_info >= (3, 10):
        from typing import TypeAlias
    else:
        from typing_extensions import TypeAlias

    from altair.typing import ChartType
    from altair.utils.core import DataFrameLike

# A scope locates a (possibly nested) group mark inside a Vega spec: each entry
# is the index of a group among the "group"-type marks at that nesting level.
# The empty tuple refers to the top level of the specification.
Scope: TypeAlias = tuple[int, ...]
# Maps (facet_name, facet_scope) to the (dataset_name, dataset_scope) that the
# facet is derived from.
FacetMapping: TypeAlias = dict[tuple[str, Scope], tuple[str, Scope]]


# For the transformed_data functionality, the chart classes in the values
# can be considered equivalent to the chart class in the key.
# Each value is a tuple of classes so that it can be passed directly as the
# second argument of ``isinstance`` when dispatching on the kind of chart.
_chart_class_mapping = {
    Chart: (
        Chart,
        TopLevelUnitSpec,
        FacetedUnitSpec,
        UnitSpec,
        UnitSpecWithFrame,
        NonNormalizedSpec,
    ),
    LayerChart: (LayerChart, TopLevelLayerSpec, LayerSpec),
    ConcatChart: (ConcatChart, TopLevelConcatSpec, ConcatSpecGenericSpec),
    HConcatChart: (HConcatChart, TopLevelHConcatSpec, HConcatSpecGenericSpec),
    VConcatChart: (VConcatChart, TopLevelVConcatSpec, VConcatSpecGenericSpec),
    FacetChart: (FacetChart, TopLevelFacetSpec, FacetSpec),
}


@overload
def transformed_data(
    chart: Chart | FacetChart,
    row_limit: int | None = None,
    exclude: Iterable[str] | None = None,
) -> DataFrameLike | None: ...


@overload
def transformed_data(
    chart: LayerChart | HConcatChart | VConcatChart | ConcatChart,
    row_limit: int | None = None,
    exclude: Iterable[str] | None = None,
) -> list[DataFrameLike]: ...


def transformed_data(chart, row_limit=None, exclude=None):
    """
    Evaluate a Chart's transforms.

    Compile the chart to Vega, locate the dataset feeding each (sub)view and
    evaluate those datasets with VegaFusion, returning the transformed data
    as one or more DataFrames.

    Parameters
    ----------
    chart : Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart
        Altair chart to evaluate transforms on
    row_limit : int (optional)
        Maximum number of rows to return for each DataFrame. None (default) for unlimited
    exclude : iterable of str
        Set of the names of charts to exclude

    Returns
    -------
    DataFrame or list of DataFrames or None
        If input chart is a Chart or Facet Chart, returns a DataFrame of the
        transformed data (or None if that chart was excluded). Otherwise,
        returns a list of DataFrames of the transformed data

    Raises
    ------
    ValueError
        If the dataset for one of the chart's views cannot be located in the
        compiled Vega specification.
    """
    vegafusion = import_vegafusion()

    # Vega-Lite requires a mark, so fall back to a point mark when none is set
    lacks_mark = isinstance(chart, Chart) and chart.mark == Undefined
    if lacks_mark:
        chart = chart.mark_point()

    # Work on a deep copy: naming the views below must not leak to the caller
    working_chart = chart.copy(deep=True)

    # Every view needs a name so that it can be found in the compiled Vega spec
    view_names = name_views(working_chart, 0, exclude=exclude)

    # Compile to Vega and pull out the inline DataFrames
    with data_transformers.enable("vegafusion"):
        vega_spec = working_chart.to_dict(
            format="vega", context={"pre_transform": False}
        )
        inline_datasets = get_inline_tables(vega_spec)

    # Map each view name to the (scoped) Vega dataset that feeds it
    view_to_dataset = get_datasets_for_view_names(
        vega_spec, view_names, get_facet_mapping(vega_spec)
    )

    # The requested datasets must follow the order of the chart components
    if any(view_name not in view_to_dataset for view_name in view_names):
        msg = "Failed to locate all datasets"
        raise ValueError(msg)
    requested_datasets = [view_to_dataset[view_name] for view_name in view_names]

    # Let VegaFusion evaluate the transforms
    results, _ = vegafusion.runtime.pre_transform_datasets(
        vega_spec,
        requested_datasets,
        row_limit=row_limit,
        inline_datasets=inline_datasets,
    )

    if not isinstance(working_chart, (Chart, FacetChart)):
        # Compound charts yield one DataFrame per (non-excluded) view
        return results

    # Simple charts yield a single DataFrame, or None when it was excluded
    return results[0] if results else None


# The equivalent classes from _chart_class_mapping should also be added
# to the type hints below for `chart` as the function would also work for them.
# However, this was not possible so far as mypy then complains about
# "Overloaded function signatures 1 and 2 overlap with incompatible return types [misc]"
# This might be due to the complex type hierarchy of the chart classes.
# See also https://github.com/python/mypy/issues/5119
# and https://github.com/python/mypy/issues/4020 which show that mypy might not have
# a very consistent behavior for overloaded functions.
# The same error appeared when trying it with Protocols for the concat and layer charts.
# This function is only used internally and so we accept this inconsistency for now.
def name_views(
    chart: ChartType, i: int = 0, exclude: Iterable[str] | None = None
) -> list[str]:
    """
    Name unnamed chart views.

    Walk the chart and give every unnamed leaf view a generated name, so that
    the views can be looked up later in the compiled Vega spec.

    Note: This function mutates the input chart by applying names to
    unnamed views.

    Parameters
    ----------
    chart : Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart
        Altair chart to apply names to
    i : int (default 0)
        Starting chart index
    exclude : iterable of str
        Names of charts to exclude

    Returns
    -------
    list of str
        List of the names of the charts and subcharts

    Raises
    ------
    ValueError
        If ``chart`` (or one of its subcharts) is not one of the supported
        chart types.
    """
    excluded_names = set() if exclude is None else set(exclude)

    leaf_classes = (_chart_class_mapping[Chart], _chart_class_mapping[FacetChart])
    if isinstance(chart, leaf_classes):
        if chart.name in excluded_names:
            return []
        if chart.name in {None, Undefined}:
            # No name was given, so generate one
            chart.name = Chart._get_name()
        return [chart.name]

    # Compound chart: find the attribute that holds its children
    children: list[Any]
    if isinstance(chart, _chart_class_mapping[LayerChart]):
        children = chart.layer
    elif isinstance(chart, _chart_class_mapping[HConcatChart]):
        children = chart.hconcat
    elif isinstance(chart, _chart_class_mapping[VConcatChart]):
        children = chart.vconcat
    elif isinstance(chart, _chart_class_mapping[ConcatChart]):
        children = chart.concat
    else:
        msg = (
            "transformed_data accepts an instance of "
            "Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart\n"
            f"Received value of type: {type(chart)}"
        )
        raise ValueError(msg)

    collected: list[str] = []
    for child in children:
        collected.extend(
            name_views(child, i=i + len(collected), exclude=excluded_names)
        )
    return collected


def get_group_mark_for_scope(
    vega_spec: dict[str, Any], scope: Scope
) -> dict[str, Any] | None:
    """
    Get the group mark at a particular scope.

    Each entry of ``scope`` selects, at one nesting level, the n-th mark of
    type ``"group"`` (marks of other types are not counted).

    Parameters
    ----------
    vega_spec : dict
        Top-level Vega specification dictionary
    scope : tuple of int
        Scope tuple. If empty, the original Vega specification is returned.
        Otherwise, the nested group mark at the scope specified is returned.

    Returns
    -------
    dict or None
        Top-level Vega spec (if scope is empty)
        or group mark (if scope is non-empty)
        or None (if group mark at scope does not exist)

    Examples
    --------
    >>> spec = {
    ...     "marks": [
    ...         {"type": "group", "marks": [{"type": "symbol"}]},
    ...         {"type": "group", "marks": [{"type": "rect"}]},
    ...     ]
    ... }
    >>> get_group_mark_for_scope(spec, (1,))
    {'type': 'group', 'marks': [{'type': 'rect'}]}
    """
    current = vega_spec

    for wanted_index in scope:
        # Lazily enumerate only the group marks nested directly in `current`
        nested_groups = (
            mark for mark in current.get("marks", []) if mark.get("type") == "group"
        )
        selected = next(
            (
                mark
                for position, mark in enumerate(nested_groups)
                if position == wanted_index
            ),
            None,
        )
        if selected is None:
            # No group mark exists at this position
            return None
        current = selected

    return current


def get_datasets_for_scope(vega_spec: dict[str, Any], scope: Scope) -> list[str]:
    """
    Get the names of the datasets that are defined at a given scope.

    Besides the entries of the group's ``data`` list, the name of the facet
    dataset declared by the group itself (``from.facet.name``) is included,
    listed last.

    Parameters
    ----------
    vega_spec : dict
        Top-level Vega specification
    scope : tuple of int
        Scope tuple. If empty, the names of top-level datasets are returned
        Otherwise, the names of the datasets defined in the nested group mark
        at the specified scope are returned.

    Returns
    -------
    list of str
        List of the names of the datasets defined at the specified scope

    Examples
    --------
    >>> spec = {
    ...     "data": [{"name": "data1"}],
    ...     "marks": [
    ...         {
    ...             "type": "group",
    ...             "data": [{"name": "data2"}],
    ...             "marks": [{"type": "symbol"}],
    ...         },
    ...         {
    ...             "type": "group",
    ...             "data": [
    ...                 {"name": "data3"},
    ...                 {"name": "data4"},
    ...             ],
    ...             "marks": [{"type": "rect"}],
    ...         },
    ...     ],
    ... }

    >>> get_datasets_for_scope(spec, ())
    ['data1']

    >>> get_datasets_for_scope(spec, (0,))
    ['data2']

    >>> get_datasets_for_scope(spec, (1,))
    ['data3', 'data4']

    Returns empty when no group mark exists at scope
    >>> get_datasets_for_scope(spec, (1, 3))
    []
    """
    target = get_group_mark_for_scope(vega_spec, scope)
    if not target:
        # Missing (or empty) group: nothing is defined here
        return []

    # Datasets declared directly on the group
    names = [entry["name"] for entry in target.get("data", [])]

    # The facet dataset that the group itself introduces, if any
    facet_name = target.get("from", {}).get("facet", {}).get("name", None)
    if facet_name:
        names.append(facet_name)
    return names


def get_definition_scope_for_data_reference(
    vega_spec: dict[str, Any], data_name: str, usage_scope: Scope
) -> Scope | None:
    """
    Return the scope that a dataset is defined at, for a given usage scope.

    The search starts at the usage scope itself and moves outwards through
    the enclosing scopes up to the top level; the innermost scope that
    defines the dataset wins.

    Parameters
    ----------
    vega_spec: dict
        Top-level Vega specification
    data_name: str
        The name of a dataset reference
    usage_scope: tuple of int
        The scope that the dataset is referenced in

    Returns
    -------
    tuple of int
        The scope where the referenced dataset is defined,
        or None if no such dataset is found

    Examples
    --------
    >>> spec = {
    ...     "data": [{"name": "data1"}],
    ...     "marks": [
    ...         {
    ...             "type": "group",
    ...             "data": [{"name": "data2"}],
    ...             "marks": [
    ...                 {
    ...                     "type": "symbol",
    ...                     "encode": {
    ...                         "update": {
    ...                             "x": {"field": "x", "data": "data1"},
    ...                             "y": {"field": "y", "data": "data2"},
    ...                         }
    ...                     },
    ...                 }
    ...             ],
    ...         }
    ...     ],
    ... }

    data1 is referenced at scope [0] and defined at scope []
    >>> get_definition_scope_for_data_reference(spec, "data1", (0,))
    ()

    data2 is referenced at scope [0] and defined at scope [0]
    >>> get_definition_scope_for_data_reference(spec, "data2", (0,))
    (0,)

    data2 is not visible at scope [] (the top level),
    because it's defined in scope [0]
    >>> repr(get_definition_scope_for_data_reference(spec, "data2", ()))
    'None'
    """
    depth = len(usage_scope)
    while depth >= 0:
        candidate_scope = usage_scope[:depth]
        if data_name in get_datasets_for_scope(vega_spec, candidate_scope):
            return candidate_scope
        # Not defined here: retry one level further out
        depth -= 1
    return None


def get_facet_mapping(group: dict[str, Any], scope: Scope = ()) -> FacetMapping:
    """
    Create mapping from facet definitions to source datasets.

    Every group mark nested (at any depth) below ``scope`` is visited. A
    facet is recorded only when it has both a name and a source dataset, and
    that source dataset can be resolved from the scope enclosing the group.

    Parameters
    ----------
    group : dict
        Top-level Vega spec or nested group mark
    scope : tuple of int
        Scope of the group dictionary within a top-level Vega spec

    Returns
    -------
    dict
        Dictionary from (facet_name, facet_scope) to (dataset_name, dataset_scope)

    Examples
    --------
    >>> spec = {
    ...     "data": [{"name": "data1"}],
    ...     "marks": [
    ...         {
    ...             "type": "group",
    ...             "from": {
    ...                 "facet": {
    ...                     "name": "facet1",
    ...                     "data": "data1",
    ...                     "groupby": ["colA"],
    ...                 }
    ...             },
    ...         }
    ...     ],
    ... }
    >>> get_facet_mapping(spec)
    {('facet1', (0,)): ('data1', ())}
    """
    mapping: FacetMapping = {}
    container = get_group_mark_for_scope(group, scope) or {}

    # Only group marks take part in scope numbering
    nested_groups = (
        mark
        for mark in container.get("marks", [])
        if mark.get("type", None) == "group"
    )
    for index, nested in enumerate(nested_groups):
        nested_scope = (*scope, index)

        facet = nested.get("from", {}).get("facet", None)
        if facet is not None:
            facet_name = facet.get("name", None)
            source_name = facet.get("data", None)
            if facet_name is not None and source_name is not None:
                # The facet's source is referenced from the enclosing scope
                source_scope = get_definition_scope_for_data_reference(
                    group, source_name, scope
                )
                if source_scope is not None:
                    mapping[facet_name, nested_scope] = (source_name, source_scope)

        # Descend into the nested group
        mapping.update(get_facet_mapping(group, scope=nested_scope))

    return mapping


def get_from_facet_mapping(
    scoped_dataset: tuple[str, Scope], facet_mapping: FacetMapping
) -> tuple[str, Scope]:
    """
    Apply facet mapping to a scoped dataset.

    The mapping is followed repeatedly until a scoped dataset is reached that
    is not itself a key of ``facet_mapping``.

    Parameters
    ----------
    scoped_dataset : (str, tuple of int)
        A dataset name and scope tuple
    facet_mapping : dict from (str, tuple of int) to (str, tuple of int)
        The facet mapping produced by get_facet_mapping

    Returns
    -------
    (str, tuple of int)
        Dataset name and scope tuple that has been mapped as many times as possible

    Examples
    --------
    Facet mapping as produced by get_facet_mapping
    >>> facet_mapping = {
    ...     ("facet1", (0,)): ("data1", ()),
    ...     ("facet2", (0, 1)): ("facet1", (0,)),
    ... }
    >>> get_from_facet_mapping(("facet2", (0, 1)), facet_mapping)
    ('data1', ())
    """
    resolved = scoped_dataset
    while True:
        if resolved not in facet_mapping:
            # Reached a dataset that is not derived from a facet
            return resolved
        resolved = facet_mapping[resolved]


def get_datasets_for_view_names(
    group: dict[str, Any],
    vl_chart_names: list[str],
    facet_mapping: FacetMapping,
    scope: Scope = (),
) -> dict[str, tuple[str, Scope]]:
    """
    Get the Vega datasets that correspond to the provided Altair view names.

    Views are matched to marks by name: a mark named ``<view>_cell`` is
    treated as the facet cell of that view, and a non-group mark named
    ``<view>_..._marks`` (or simply ``<view>_marks``) is treated as one of the
    view's marks. Nested group marks are searched recursively; a result found
    at the current scope takes precedence over one found in a nested group.

    Parameters
    ----------
    group : dict
        Top-level Vega spec or nested group mark
    vl_chart_names : list of str
        List of the Vega-Lite view (chart) names to find datasets for
    facet_mapping : dict from (str, tuple of int) to (str, tuple of int)
        The facet mapping produced by get_facet_mapping
    scope : tuple of int
        Scope of the group dictionary within a top-level Vega spec

    Returns
    -------
    dict from str to (str, tuple of int)
        Dict from Altair view names to scoped datasets. Views whose dataset
        could not be determined are absent from the result.
    """
    datasets: dict[str, tuple[str, Scope]] = {}
    group_index = 0
    mark_group = get_group_mark_for_scope(group, scope) or {}
    for mark in mark_group.get("marks", []):
        name = mark.get("name") or ""

        # Facet cell of a view: the dataset is the source of the cell's facet
        for vl_chart_name in vl_chart_names:
            if name == f"{vl_chart_name}_cell":
                # A cell mark without a facet definition (or without a source
                # dataset) is skipped instead of raising AttributeError or
                # recording a ``None`` dataset name.
                facet = mark.get("from", {}).get("facet") or {}
                data_name = facet.get("data")
                if data_name is not None:
                    datasets[vl_chart_name] = get_from_facet_mapping(
                        (data_name, scope), facet_mapping
                    )
                break

        if mark.get("type", "") == "group":
            group_data_names = get_datasets_for_view_names(
                group, vl_chart_names, facet_mapping, scope=(*scope, group_index)
            )
            # Keep what was already found at this scope
            for k, v in group_data_names.items():
                datasets.setdefault(k, v)
            group_index += 1
        else:
            for vl_chart_name in vl_chart_names:
                # The "_" separator is part of the prefix so that a view named
                # "view_1" does not claim the marks of "view_10".
                if name.startswith(f"{vl_chart_name}_") and name.endswith("_marks"):
                    data_name = mark.get("from", {}).get("data", None)
                    scoped_data = get_definition_scope_for_data_reference(
                        group, data_name, scope
                    )
                    if scoped_data is not None:
                        datasets[vl_chart_name] = get_from_facet_mapping(
                            (data_name, scoped_data), facet_mapping
                        )
                        break

    return datasets
