# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data marshalling utilities for ArrowTable protobufs, which are used by
CustomComponent for dataframe serialization.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any

from streamlit import dataframe_util
from streamlit.elements.lib import pandas_styler_utils

if TYPE_CHECKING:
    from pandas import DataFrame, Index, Series

    from streamlit.proto.Components_pb2 import ArrowTable as ArrowTableProto


def _maybe_tuple_to_list(item: Any) -> Any:
    """Convert a tuple to a list. Leave as is if it's not a tuple."""
    return list(item) if isinstance(item, tuple) else item


def marshall(
    proto: ArrowTableProto, data: Any, default_uuid: str | None = None
) -> None:
    """Marshall data into an ArrowTable proto.

    Parameters
    ----------
    proto : proto.ArrowTable
        Output. The protobuf for a Streamlit ArrowTable proto.

    data : pandas.DataFrame, pandas.Styler, numpy.ndarray, Iterable, dict, or None
        Something that is or can be converted to a dataframe.

    """
    if dataframe_util.is_pandas_styler(data):
        pandas_styler_utils.marshall_styler(proto, data, default_uuid)  # type: ignore

    df = dataframe_util.convert_anything_to_pandas_df(data)
    _marshall_index(proto, df.index)
    _marshall_columns(proto, df.columns)
    _marshall_data(proto, df)


def _marshall_index(proto: ArrowTableProto, index: Index) -> None:
    """Marshall pandas.DataFrame index into an ArrowTable proto.

    Parameters
    ----------
    proto : proto.ArrowTable
        Output. The protobuf for a Streamlit ArrowTable proto.

    index : pd.Index
        Index to use for resulting frame.
        Will default to RangeIndex (0, 1, 2, ..., n) if no index is provided.

    """
    import pandas as pd

    index = map(_maybe_tuple_to_list, index.values)
    index_df = pd.DataFrame(index)
    proto.index = dataframe_util.convert_pandas_df_to_arrow_bytes(index_df)


def _marshall_columns(proto: ArrowTableProto, columns: Series) -> None:
    """Marshall pandas.DataFrame columns into an ArrowTable proto.

    Parameters
    ----------
    proto : proto.ArrowTable
        Output. The protobuf for a Streamlit ArrowTable proto.

    columns : Series
        Column labels to use for resulting frame.
        Will default to RangeIndex (0, 1, 2, ..., n) if no column labels are provided.

    """
    import pandas as pd

    columns = map(_maybe_tuple_to_list, columns.values)
    columns_df = pd.DataFrame(columns)
    proto.columns = dataframe_util.convert_pandas_df_to_arrow_bytes(columns_df)


def _marshall_data(proto: ArrowTableProto, df: DataFrame) -> None:
    """Marshall pandas.DataFrame data into an ArrowTable proto.

    Parameters
    ----------
    proto : proto.ArrowTable
        Output. The protobuf for a Streamlit ArrowTable proto.

    df : pandas.DataFrame
        A dataframe to marshall.

    """
    proto.data = dataframe_util.convert_pandas_df_to_arrow_bytes(df)


def arrow_proto_to_dataframe(proto: ArrowTableProto) -> DataFrame:
    """Convert ArrowTable proto to pandas.DataFrame.

    Parameters
    ----------
    proto : proto.ArrowTable
        Output. pandas.DataFrame

    """

    if dataframe_util.is_pyarrow_version_less_than("14.0.1"):
        raise RuntimeError(
            "The installed pyarrow version is not compatible with this component. "
            "Please upgrade to 14.0.1 or higher: pip install -U pyarrow"
        )

    import pandas as pd

    data = dataframe_util.convert_arrow_bytes_to_pandas_df(proto.data)
    index = dataframe_util.convert_arrow_bytes_to_pandas_df(proto.index)
    columns = dataframe_util.convert_arrow_bytes_to_pandas_df(proto.columns)

    return pd.DataFrame(
        data.values, index=index.values.T.tolist(), columns=columns.values.T.tolist()
    )
