# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import copy
import json
from collections.abc import Mapping
from enum import Enum
from typing import TYPE_CHECKING, Final, Literal, Union

from typing_extensions import TypeAlias

from streamlit.dataframe_util import DataFormat
from streamlit.elements.lib.column_types import ColumnConfig, ColumnType
from streamlit.elements.lib.dicttools import remove_none_values
from streamlit.errors import StreamlitAPIException

if TYPE_CHECKING:
    import pyarrow as pa
    from pandas import DataFrame, Index, Series

    from streamlit.proto.Arrow_pb2 import Arrow as ArrowProto


# Special key that refers to the dataframe index (rather than a regular column).
# It is used as a key in column config mappings (e.g. to hide the index) and
# as the name of the index entry in a `DataframeSchema`.
IndexIdentifierType = Literal["_index"]
INDEX_IDENTIFIER: IndexIdentifierType = "_index"

# This is used as prefix for columns that are configured via the numerical position.
# The integer value is converted into a string key with this prefix
# (e.g. column position ``2`` becomes ``"_pos:2"``) when the column config is
# serialized to JSON.
# This needs to match with the prefix configured in the frontend.
_NUMERICAL_POSITION_PREFIX = "_pos:"


# The column data kind is used to describe the type of the data within the column.
class ColumnDataKind(str, Enum):
    """Kind of the values stored in a dataframe column.

    Because this is a ``str`` enum, members compare equal to their string
    values (e.g. ``ColumnDataKind.INTEGER == "integer"``).
    """

    INTEGER = "integer"
    FLOAT = "float"
    DATE = "date"
    TIME = "time"
    DATETIME = "datetime"
    BOOLEAN = "boolean"
    STRING = "string"
    TIMEDELTA = "timedelta"
    PERIOD = "period"
    INTERVAL = "interval"
    BYTES = "bytes"
    DECIMAL = "decimal"
    COMPLEX = "complex"
    LIST = "list"
    DICT = "dict"
    # Column without any values (arrow null type / pandas inferred "empty").
    EMPTY = "empty"
    # Fallback when no other kind could be determined.
    UNKNOWN = "unknown"


# The dataframe schema is a mapping from the name of the column
# in the underlying dataframe to the column data kind.
# The index column uses `_index` (INDEX_IDENTIFIER) as name.
# Instances are built by `determine_dataframe_schema`.
DataframeSchema: TypeAlias = dict[str, ColumnDataKind]

# This mapping contains all editable column types mapped to the data kinds
# that the column type is compatible for editing.
# Column types that are not listed here (non-editable types) are treated as
# compatible with every data kind by `is_type_compatible`.
_EDITING_COMPATIBILITY_MAPPING: Final[dict[ColumnType, list[ColumnDataKind]]] = {
    "text": [ColumnDataKind.STRING, ColumnDataKind.EMPTY],
    "number": [
        ColumnDataKind.INTEGER,
        ColumnDataKind.FLOAT,
        ColumnDataKind.DECIMAL,
        ColumnDataKind.STRING,
        ColumnDataKind.TIMEDELTA,
        ColumnDataKind.EMPTY,
    ],
    "checkbox": [
        ColumnDataKind.BOOLEAN,
        ColumnDataKind.STRING,
        ColumnDataKind.INTEGER,
        ColumnDataKind.EMPTY,
    ],
    "selectbox": [
        ColumnDataKind.STRING,
        ColumnDataKind.BOOLEAN,
        ColumnDataKind.INTEGER,
        ColumnDataKind.FLOAT,
        ColumnDataKind.EMPTY,
    ],
    "date": [ColumnDataKind.DATE, ColumnDataKind.DATETIME, ColumnDataKind.EMPTY],
    "time": [ColumnDataKind.TIME, ColumnDataKind.DATETIME, ColumnDataKind.EMPTY],
    "datetime": [
        ColumnDataKind.DATETIME,
        ColumnDataKind.DATE,
        ColumnDataKind.TIME,
        ColumnDataKind.EMPTY,
    ],
    "link": [ColumnDataKind.STRING, ColumnDataKind.EMPTY],
}


def is_type_compatible(column_type: ColumnType, data_kind: ColumnDataKind) -> bool:
    """Return whether ``column_type`` may be used to edit data of ``data_kind``.

    Only editable column types (e.g. number or text) are restricted. Any
    column type without an entry in the editing compatibility mapping
    (e.g. bar_chart or image) is accepted for every data kind; this might
    change in the future.

    Parameters
    ----------
    column_type : ColumnType
        The configured column type.

    data_kind : ColumnDataKind
        The data kind of the underlying column.

    Returns
    -------
    bool
        True if the combination is allowed, False otherwise.
    """
    allowed_kinds = _EDITING_COMPATIBILITY_MAPPING.get(column_type)
    if allowed_kinds is None:
        # Non-editable column type: no restrictions apply.
        return True
    return data_kind in allowed_kinds


def _determine_data_kind_via_arrow(field: pa.Field) -> ColumnDataKind:
    """Determine the data kind via the arrow type information.

    The column data kind refers to the shared data type of the values
    in the column (e.g. int, float, str, bool).

    Besides the regular arrow string, list and binary types, this also
    recognizes their "large" (64-bit offset) and fixed-size variants. Some
    arrow producers (e.g. Polars) emit these instead of the regular types.

    Parameters
    ----------

    field : pa.Field
        The arrow field from the arrow table schema.

    Returns
    -------
    ColumnDataKind
        The data kind of the field, or ``ColumnDataKind.UNKNOWN`` if the arrow
        type is not mapped to any data kind.
    """
    import pyarrow as pa

    field_type = field.type
    if pa.types.is_integer(field_type):
        return ColumnDataKind.INTEGER

    if pa.types.is_floating(field_type):
        return ColumnDataKind.FLOAT

    if pa.types.is_boolean(field_type):
        return ColumnDataKind.BOOLEAN

    if pa.types.is_string(field_type) or pa.types.is_large_string(field_type):
        return ColumnDataKind.STRING

    if pa.types.is_date(field_type):
        return ColumnDataKind.DATE

    if pa.types.is_time(field_type):
        return ColumnDataKind.TIME

    if pa.types.is_timestamp(field_type):
        return ColumnDataKind.DATETIME

    if pa.types.is_duration(field_type):
        return ColumnDataKind.TIMEDELTA

    if (
        pa.types.is_list(field_type)
        or pa.types.is_large_list(field_type)
        or pa.types.is_fixed_size_list(field_type)
    ):
        return ColumnDataKind.LIST

    if pa.types.is_decimal(field_type):
        return ColumnDataKind.DECIMAL

    if pa.types.is_null(field_type):
        return ColumnDataKind.EMPTY

    # Interval does not seem to work correctly:
    # if pa.types.is_interval(field_type):
    #     return ColumnDataKind.INTERVAL

    if (
        pa.types.is_binary(field_type)
        or pa.types.is_large_binary(field_type)
        or pa.types.is_fixed_size_binary(field_type)
    ):
        return ColumnDataKind.BYTES

    if pa.types.is_struct(field_type):
        return ColumnDataKind.DICT

    return ColumnDataKind.UNKNOWN


def _determine_data_kind_via_pandas_dtype(
    column: Series | Index,
) -> ColumnDataKind:
    """Determine the data kind by using the pandas dtype.

    The column data kind refers to the shared data type of the values
    in the column (e.g. int, float, str, bool).

    Parameters
    ----------
    column : pd.Series, pd.Index
        The column for which the data kind should be determined.

    Returns
    -------
    ColumnDataKind
        The data kind of the column, or ``ColumnDataKind.UNKNOWN`` if the dtype
        does not map to any data kind.
    """
    import pandas as pd

    column_dtype = column.dtype
    if pd.api.types.is_bool_dtype(column_dtype):
        return ColumnDataKind.BOOLEAN

    if pd.api.types.is_integer_dtype(column_dtype):
        return ColumnDataKind.INTEGER

    if pd.api.types.is_float_dtype(column_dtype):
        return ColumnDataKind.FLOAT

    if pd.api.types.is_datetime64_any_dtype(column_dtype):
        return ColumnDataKind.DATETIME

    if pd.api.types.is_timedelta64_dtype(column_dtype):
        return ColumnDataKind.TIMEDELTA

    if isinstance(column_dtype, pd.PeriodDtype):
        return ColumnDataKind.PERIOD

    if isinstance(column_dtype, pd.IntervalDtype):
        return ColumnDataKind.INTERVAL

    if pd.api.types.is_complex_dtype(column_dtype):
        return ColumnDataKind.COMPLEX

    # `is_string_dtype` also reports True for the generic object dtype, which
    # can hold arbitrary Python objects. Only treat dedicated (non-object)
    # string dtypes as STRING here.
    if not pd.api.types.is_object_dtype(
        column_dtype
    ) and pd.api.types.is_string_dtype(column_dtype):
        return ColumnDataKind.STRING

    return ColumnDataKind.UNKNOWN


def _determine_data_kind_via_inferred_type(
    column: Series | Index,
) -> ColumnDataKind:
    """Infer the data kind of a column by inspecting its values.

    Uses pandas' ``infer_dtype`` to classify the values and translates the
    inferred type name into a ``ColumnDataKind``. The data kind describes
    the shared type of the values in the column (e.g. int, float, str, bool).

    Parameters
    ----------
    column : pd.Series, pd.Index
        The column whose values should be inspected.

    Returns
    -------
    ColumnDataKind
        The inferred data kind, or ``ColumnDataKind.UNKNOWN`` if the inferred
        type has no corresponding data kind.
    """
    from pandas.api.types import infer_dtype

    # Inferred type names without an entry here (e.g. "mixed",
    # "unknown-array", "categorical", "mixed-integer") map to UNKNOWN.
    kind_for_inferred_type = {
        "string": ColumnDataKind.STRING,
        "bytes": ColumnDataKind.BYTES,
        "floating": ColumnDataKind.FLOAT,
        "mixed-integer-float": ColumnDataKind.FLOAT,
        "integer": ColumnDataKind.INTEGER,
        "decimal": ColumnDataKind.DECIMAL,
        "complex": ColumnDataKind.COMPLEX,
        "boolean": ColumnDataKind.BOOLEAN,
        "datetime64": ColumnDataKind.DATETIME,
        "datetime": ColumnDataKind.DATETIME,
        "date": ColumnDataKind.DATE,
        "timedelta64": ColumnDataKind.TIMEDELTA,
        "timedelta": ColumnDataKind.TIMEDELTA,
        "time": ColumnDataKind.TIME,
        "period": ColumnDataKind.PERIOD,
        "interval": ColumnDataKind.INTERVAL,
        "empty": ColumnDataKind.EMPTY,
    }

    inferred = infer_dtype(column)
    return kind_for_inferred_type.get(inferred, ColumnDataKind.UNKNOWN)


def _determine_data_kind(
    column: Series | Index, field: pa.Field | None = None
) -> ColumnDataKind:
    """Work out the data kind of a single column (or index).

    The data kind describes the shared type of the values in the column
    (e.g. int, float, str, bool). Sources are consulted in this order:

    1. For categorical columns, the categories are inspected.
    2. The arrow field type, if a field is given and it maps to a known kind.
    3. For object columns, the values are inspected; otherwise the pandas
       dtype is used.

    Parameters
    ----------
    column : pd.Series, pd.Index
        The column whose data kind should be determined.
    field : pa.Field, optional
        The matching field from the arrow table schema, if available.

    Returns
    -------
    ColumnDataKind
        The data kind of the column.
    """
    import pandas as pd

    dtype = column.dtype

    # The kind of a categorical column depends on the kind of its categories.
    if isinstance(dtype, pd.CategoricalDtype):
        return _determine_data_kind_via_inferred_type(dtype.categories)

    if field is not None:
        arrow_kind = _determine_data_kind_via_arrow(field)
        if arrow_kind != ColumnDataKind.UNKNOWN:
            return arrow_kind

    # The object dtype says nothing about the values, so look at them directly.
    if dtype.name == "object":
        return _determine_data_kind_via_inferred_type(column)

    return _determine_data_kind_via_pandas_dtype(column)


def determine_dataframe_schema(
    data_df: DataFrame, arrow_schema: pa.Schema
) -> DataframeSchema:
    """Build a mapping from every column (and the index) to its data kind.

    Parameters
    ----------
    data_df : pd.DataFrame
        The dataframe to analyze.
    arrow_schema : pa.Schema
        The Arrow schema of the dataframe. Its fields are matched to the
        dataframe columns by position.

    Returns
    -------

    DataframeSchema
        The detected data kind per column, keyed by the column name in the
        underlying dataframe. The index is stored under ``_index``.
    """

    # TODO(lukasmasuch): We need to apply changes here to support multiindex.
    schema: DataframeSchema = {INDEX_IDENTIFIER: _determine_data_kind(data_df.index)}

    for position, (name, values) in enumerate(data_df.items()):
        schema[name] = _determine_data_kind(values, arrow_schema.field(position))

    return schema


# A mapping of column names/IDs to column configs.
# The dataframe index is addressed via the `_index` key (INDEX_IDENTIFIER).
ColumnConfigMapping: TypeAlias = dict[Union[IndexIdentifierType, str], ColumnConfig]
# User-provided form of the mapping, as accepted by `process_config_mapping`:
# a value of `None` hides the column, a `str` is used as the column label,
# and a `ColumnConfig` dict is used as-is (after being copied).
# NOTE(review): keys are annotated as `str`, but the JSON serialization also
# handles `int` keys (columns addressed by numerical position) — consider
# widening the key type.
ColumnConfigMappingInput: TypeAlias = Mapping[
    Union[IndexIdentifierType, str],
    Union[ColumnConfig, None, str],
]


def process_config_mapping(
    column_config: ColumnConfigMappingInput | None = None,
) -> ColumnConfigMapping:
    """Normalize a user-provided column config mapping for the frontend.

    Each value is converted as follows:

    - ``None`` hides the column.
    - A ``str`` is used as the column label.
    - A ``dict`` (column config) is deep-copied.

    Parameters
    ----------
    column_config: dict or None
        The user-provided column config mapping.

    Returns
    -------
    dict
        A new, normalized column config mapping (empty if ``column_config``
        is ``None``).

    Raises
    ------
    StreamlitAPIException
        If a value is not ``None``, a ``str`` or a ``dict``.
    """
    if column_config is None:
        return {}

    normalized: ColumnConfigMapping = {}
    for name, raw_config in column_config.items():
        if raw_config is None:
            normalized[name] = ColumnConfig(hidden=True)
            continue

        if isinstance(raw_config, str):
            normalized[name] = ColumnConfig(label=raw_config)
            continue

        if not isinstance(raw_config, dict):
            raise StreamlitAPIException(
                f"Invalid column config for column `{name}`. "
                f"Expected `None`, `str` or `dict`, but got `{type(raw_config)}`."
            )

        # Work on a deep copy: the config is later modified in place, which
        # must not leak back into the caller's object.
        normalized[name] = copy.deepcopy(raw_config)

    return normalized


def update_column_config(
    column_config_mapping: ColumnConfigMapping, column: str, column_config: ColumnConfig
) -> None:
    """Merge ``column_config`` into the config stored for ``column`` (in place).

    If the mapping has no entry for the column yet, an empty config is created
    first. Keys in ``column_config`` overwrite existing keys of the same name.

    Parameters
    ----------

    column_config_mapping : ColumnConfigMapping
        The column config mapping to modify.

    column : str
        The column whose config should be updated.

    column_config : ColumnConfig
        The config values to merge into the column's config.
    """
    column_config_mapping.setdefault(column, {}).update(column_config)


def apply_data_specific_configs(
    columns_config: ColumnConfigMapping,
    data_format: DataFormat,
) -> None:
    """Apply column config defaults that depend on the original data format.

    Only ``columns_config`` is modified (in place). Currently this hides the
    index for data formats where the index is not meaningful to the user.

    Parameters
    ----------
    columns_config : ColumnConfigMapping
        A mapping of column names/ids to column configurations.

    data_format : DataFormat
        The format of the data.
    """

    # Pandas adds a range index as default to all datastructures, but for
    # most non-pandas data objects showing this index is unnecessary, so it
    # gets hidden by default for these formats.
    formats_with_hidden_index = (
        DataFormat.SET_OF_VALUES,
        DataFormat.TUPLE_OF_VALUES,
        DataFormat.LIST_OF_VALUES,
        DataFormat.NUMPY_LIST,
        DataFormat.NUMPY_MATRIX,
        DataFormat.LIST_OF_RECORDS,
        DataFormat.LIST_OF_ROWS,
        DataFormat.COLUMN_VALUE_MAPPING,
        # Dataframe-like objects that don't have an index:
        DataFormat.PANDAS_ARRAY,
        DataFormat.PANDAS_INDEX,
        DataFormat.POLARS_DATAFRAME,
        DataFormat.POLARS_SERIES,
        DataFormat.POLARS_LAZYFRAME,
        DataFormat.PYARROW_ARRAY,
        DataFormat.RAY_DATASET,
    )
    if data_format not in formats_with_hidden_index:
        return

    update_column_config(columns_config, INDEX_IDENTIFIER, {"hidden": True})


def _convert_column_config_to_json(column_config_mapping: ColumnConfigMapping) -> str:
    """Serialize the column config mapping into a strict JSON string.

    ``None`` values are dropped via ``remove_none_values``. Integer keys
    (columns configured by their numerical position) are converted into string
    keys prefixed with ``_NUMERICAL_POSITION_PREFIX``.

    Parameters
    ----------
    column_config_mapping : ColumnConfigMapping
        The column config mapping to serialize.

    Returns
    -------
    str
        The JSON representation of the column config mapping.

    Raises
    ------
    StreamlitAPIException
        If the mapping cannot be serialized into strict JSON. This covers
        ``NaN``/``inf`` values (``ValueError``) as well as values of types
        that JSON cannot represent (``TypeError``).
    """
    # Ignore all None values and prefix columns specified by numerical index:
    prepared_config = {
        (f"{_NUMERICAL_POSITION_PREFIX}{str(k)}" if isinstance(k, int) else k): v
        for (k, v) in remove_none_values(column_config_mapping).items()
    }
    try:
        return json.dumps(prepared_config, allow_nan=False)
    except (ValueError, TypeError) as ex:
        # ValueError: NaN/inf with allow_nan=False.
        # TypeError: value (or key) of a type that is not JSON serializable.
        raise StreamlitAPIException(
            f"The provided column config cannot be serialized into JSON: {ex}"
        ) from ex


def marshall_column_config(
    proto: ArrowProto, column_config_mapping: ColumnConfigMapping
) -> None:
    """Store the JSON-serialized column config in the Arrow proto.

    The mapping is serialized to JSON and written to ``proto.columns``.

    Parameters
    ----------
    proto : ArrowProto
        The proto whose ``columns`` field is set.

    column_config_mapping : ColumnConfigMapping
        The column config to serialize.
    """
    serialized_config = _convert_column_config_to_json(column_config_mapping)
    proto.columns = serialized_config
