# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import dataclasses
import inspect
import types
from collections import ChainMap, UserDict, UserList
from collections.abc import (
    AsyncGenerator,
    Generator,
    ItemsView,
    Iterable,
    KeysView,
    ValuesView,
)
from io import StringIO
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Final,
    cast,
)

from streamlit import dataframe_util, type_util
from streamlit.errors import StreamlitAPIException
from streamlit.logger import get_logger
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.string_util import (
    is_mem_address_str,
    max_char_sequence,
)

if TYPE_CHECKING:
    from streamlit.delta_generator import DeltaGenerator

# Special methods:
HELP_TYPES: Final[tuple[type[Any], ...]] = (
    types.BuiltinFunctionType,
    types.BuiltinMethodType,
    types.FunctionType,
    types.MethodType,
    types.ModuleType,
)

_LOGGER: Final = get_logger(__name__)

_TEXT_CURSOR: Final = " ▏"


class StreamingOutput(list[Any]):
    pass


class WriteMixin:
    @gather_metrics("write_stream")
    def write_stream(
        self,
        stream: Callable[..., Any]
        | Generator[Any, Any, Any]
        | Iterable[Any]
        | AsyncGenerator[Any, Any],
    ) -> list[Any] | str:
        """Stream a generator, iterable, or stream-like sequence to the app.

        ``st.write_stream`` iterates through the given sequences and writes all
        chunks to the app. String chunks will be written using a typewriter effect.
        Other data types will be written using ``st.write``.

        Parameters
        ----------
        stream : Callable, Generator, Iterable, OpenAI Stream, or LangChain Stream
            The generator or iterable to stream.

            If you pass an async generator, Streamlit will internally convert
            it to a sync generator.

            .. note::
                To use additional LLM libraries, you can create a wrapper to
                manually define a generator function and include custom output
                parsing.

        Returns
        -------
        str or list
            The full response. If the streamed output only contains text, this
            is a string. Otherwise, this is a list of all the streamed objects.
            The return value is fully compatible as input for ``st.write``.

        Example
        -------
        You can pass an OpenAI stream as shown in our tutorial, `Build a \
        basic LLM chat app <https://docs.streamlit.io/develop/tutorials/llms\
        /build-conversational-apps#build-a-chatgpt-like-app>`_. Alternatively,
        you can pass a generic generator function as input:

        >>> import time
        >>> import numpy as np
        >>> import pandas as pd
        >>> import streamlit as st
        >>>
        >>> _LOREM_IPSUM = \"\"\"
        >>> Lorem ipsum dolor sit amet, **consectetur adipiscing** elit, sed do eiusmod tempor
        >>> incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis
        >>> nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat.
        >>> \"\"\"
        >>>
        >>>
        >>> def stream_data():
        >>>     for word in _LOREM_IPSUM.split(" "):
        >>>         yield word + " "
        >>>         time.sleep(0.02)
        >>>
        >>>     yield pd.DataFrame(
        >>>         np.random.randn(5, 10),
        >>>         columns=["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"],
        >>>     )
        >>>
        >>>     for word in _LOREM_IPSUM.split(" "):
        >>>         yield word + " "
        >>>         time.sleep(0.02)
        >>>
        >>>
        >>> if st.button("Stream data"):
        >>>     st.write_stream(stream_data)

        ..  output::
            https://doc-write-stream-data.streamlit.app/
            height: 550px

        """

        # Just apply some basic checks for common iterable types that should
        # not be passed in here.
        if isinstance(stream, str) or dataframe_util.is_dataframe_like(stream):
            raise StreamlitAPIException(
                "`st.write_stream` expects a generator or stream-like object as input "
                f"not {type(stream)}. Please use `st.write` instead for "
                "this data type."
            )

        stream_container: DeltaGenerator | None = None
        streamed_response: str = ""
        written_content: list[Any] = StreamingOutput()

        def flush_stream_response():
            """Write the full response to the app."""
            nonlocal streamed_response
            nonlocal stream_container

            if streamed_response and stream_container:
                # Replace the stream_container element the full response
                stream_container.markdown(streamed_response)
                written_content.append(streamed_response)
                stream_container = None
                streamed_response = ""

        # Make sure we have a generator and not just a generator function.
        if inspect.isgeneratorfunction(stream) or inspect.isasyncgenfunction(stream):
            stream = stream()

        # If the stream is an async generator, convert it to a sync generator:
        if inspect.isasyncgen(stream):
            stream = type_util.async_generator_to_sync(stream)

        try:
            iter(stream)  # type: ignore
        except TypeError as exc:
            raise StreamlitAPIException(
                f"The provided input (type: {type(stream)}) cannot be iterated. "
                "Please make sure that it is a generator, generator function or iterable."
            ) from exc

        # Iterate through the generator and write each chunk to the app
        # with a type writer effect.
        for chunk in stream:  # type: ignore
            if type_util.is_openai_chunk(chunk):
                # Try to convert OpenAI chat completion chunk to a string:
                try:
                    if len(chunk.choices) == 0 or chunk.choices[0].delta is None:
                        # The choices list can be empty. E.g. when using the
                        # AzureOpenAI client, the first chunk will always be empty.
                        chunk = ""
                    else:
                        chunk = chunk.choices[0].delta.content or ""
                except AttributeError as err:
                    raise StreamlitAPIException(
                        "Failed to parse the OpenAI ChatCompletionChunk. "
                        "The most likely cause is a change of the chunk object structure "
                        "due to a recent OpenAI update. You might be able to fix this "
                        "by downgrading the OpenAI library or upgrading Streamlit. Also, "
                        "please report this issue to: https://github.com/streamlit/streamlit/issues."
                    ) from err

            if type_util.is_type(chunk, "langchain_core.messages.ai.AIMessageChunk"):
                # Try to convert LangChain message chunk to a string:
                try:
                    chunk = chunk.content or ""
                except AttributeError as err:
                    raise StreamlitAPIException(
                        "Failed to parse the LangChain AIMessageChunk. "
                        "The most likely cause is a change of the chunk object structure "
                        "due to a recent LangChain update. You might be able to fix this "
                        "by downgrading the OpenAI library or upgrading Streamlit. Also, "
                        "please report this issue to: https://github.com/streamlit/streamlit/issues."
                    ) from err

            if isinstance(chunk, str):
                if not chunk:
                    # Empty strings can be ignored
                    continue

                first_text = False
                if not stream_container:
                    stream_container = self.dg.empty()
                    first_text = True
                streamed_response += chunk
                # Only add the streaming symbol on the second text chunk
                stream_container.markdown(
                    streamed_response + ("" if first_text else _TEXT_CURSOR),
                )
            elif callable(chunk):
                flush_stream_response()
                chunk()
            else:
                flush_stream_response()
                self.write(chunk)
                written_content.append(chunk)

        flush_stream_response()

        if not written_content:
            # If nothing was streamed, return an empty string.
            return ""
        elif len(written_content) == 1 and isinstance(written_content[0], str):
            # If the output only contains a single string, return it as a string
            return written_content[0]

        # Otherwise return it as a list of write-compatible objects
        return written_content

    @gather_metrics("write")
    def write(self, *args: Any, unsafe_allow_html: bool = False, **kwargs) -> None:
        """Displays arguments in the app.

        This is the Swiss Army knife of Streamlit commands: it does different
        things depending on what you throw at it. Unlike other Streamlit
        commands, ``st.write()`` has some unique properties:

        - You can pass in multiple arguments, all of which will be displayed.
        - Its behavior depends on the input type(s).

        Parameters
        ----------
        *args : any
            One or many objects to display in the app.

            .. list-table:: Each type of argument is handled as follows:
                :header-rows: 1

                * - Type
                  - Handling
                * - ``str``
                  - Uses ``st.markdown()``.
                * - dataframe-like, ``dict``, or ``list``
                  - Uses ``st.dataframe()``.
                * - ``Exception``
                  - Uses ``st.exception()``.
                * - function, module, or class
                  - Uses ``st.help()``.
                * - ``DeltaGenerator``
                  - Uses ``st.help()``.
                * - Altair chart
                  - Uses ``st.altair_chart()``.
                * - Bokeh figure
                  - Uses ``st.bokeh_chart()``.
                * - Graphviz graph
                  - Uses ``st.graphviz_chart()``.
                * - Keras model
                  - Converts model and uses ``st.graphviz_chart()``.
                * - Matplotlib figure
                  - Uses ``st.pyplot()``.
                * - Plotly figure
                  - Uses ``st.plotly_chart()``.
                * - ``PIL.Image``
                  - Uses ``st.image()``.
                * - generator or stream (like ``openai.Stream``)
                  - Uses ``st.write_stream()``.
                * - SymPy expression
                  - Uses ``st.latex()``.
                * - An object with ``._repr_html()``
                  - Uses ``st.html()``.
                * - Database cursor
                  - Displays DB API 2.0 cursor results in a table.
                * - Any
                  - Displays ``str(arg)`` as inline code.

        unsafe_allow_html : bool
            Whether to render HTML within ``*args``. This only applies to
            strings or objects falling back on ``_repr_html_()``. If this is
            ``False`` (default), any HTML tags found in ``body`` will be
            escaped and therefore treated as raw text. If this is ``True``, any
            HTML expressions within ``body`` will be rendered.

            Adding custom HTML to your app impacts safety, styling, and
            maintainability.

            .. note::
                If you only want to insert HTML or CSS without Markdown text,
                we recommend using ``st.html`` instead.

        **kwargs : any
            Keyword arguments. Not used.

        .. deprecated::
            ``**kwargs`` is deprecated and will be removed in a later version.
            Use other, more specific Streamlit commands to pass additional
            keyword arguments.

        Returns
        -------
        None

        Examples
        --------

        Its basic use case is to draw Markdown-formatted text, whenever the
        input is a string:

        >>> import streamlit as st
        >>>
        >>> st.write("Hello, *World!* :sunglasses:")

        ..  output::
            https://doc-write1.streamlit.app/
            height: 150px

        As mentioned earlier, ``st.write()`` also accepts other data formats, such as
        numbers, data frames, styled data frames, and assorted objects:

        >>> import streamlit as st
        >>> import pandas as pd
        >>>
        >>> st.write(1234)
        >>> st.write(
        ...     pd.DataFrame(
        ...         {
        ...             "first column": [1, 2, 3, 4],
        ...             "second column": [10, 20, 30, 40],
        ...         }
        ...     )
        ... )

        ..  output::
            https://doc-write2.streamlit.app/
            height: 350px

        Finally, you can pass in multiple arguments to do things like:

        >>> import streamlit as st
        >>>
        >>> st.write("1 + 1 = ", 2)
        >>> st.write("Below is a DataFrame:", data_frame, "Above is a dataframe.")

        ..  output::
            https://doc-write3.streamlit.app/
            height: 410px

        Oh, one more thing: ``st.write`` accepts chart objects too! For example:

        >>> import streamlit as st
        >>> import pandas as pd
        >>> import numpy as np
        >>> import altair as alt
        >>>
        >>> df = pd.DataFrame(np.random.randn(200, 3), columns=["a", "b", "c"])
        >>> c = (
        ...     alt.Chart(df)
        ...     .mark_circle()
        ...     .encode(x="a", y="b", size="c", color="c", tooltip=["a", "b", "c"])
        ... )
        >>>
        >>> st.write(c)

        ..  output::
            https://doc-vega-lite-chart.streamlit.app/
            height: 300px

        """
        if kwargs:
            _LOGGER.warning(
                'Invalid arguments were passed to "st.write" function. Support for '
                "passing such unknown keywords arguments will be dropped in future. "
                "Invalid arguments were: %s",
                kwargs,
            )

        string_buffer: list[str] = []

        # This bans some valid cases like: e = st.empty(); e.write("a", "b").
        # BUT: 1) such cases are rare, 2) this rule is easy to understand,
        # and 3) this rule should be removed once we have st.container()
        if not self.dg._is_top_level and len(args) > 1:
            raise StreamlitAPIException(
                "Cannot replace a single element with multiple elements.\n\n"
                "The `write()` method only supports multiple elements when "
                "inserting elements rather than replacing. That is, only "
                "when called as `st.write()` or `st.sidebar.write()`."
            )

        def flush_buffer():
            if string_buffer:
                text_content = " ".join(string_buffer)
                # The usage of empty here prevents
                # some grey out effects:
                text_container = self.dg.empty()
                text_container.markdown(
                    text_content,
                    unsafe_allow_html=unsafe_allow_html,
                )
                string_buffer[:] = []

        for arg in args:
            # Order matters!
            if isinstance(arg, str):
                string_buffer.append(arg)
            elif isinstance(arg, StreamingOutput):
                flush_buffer()
                for item in arg:
                    if callable(item):
                        flush_buffer()
                        item()
                    else:
                        self.write(item, unsafe_allow_html=unsafe_allow_html)
            elif isinstance(arg, Exception):
                flush_buffer()
                self.dg.exception(arg)
            elif type_util.is_delta_generator(arg):
                flush_buffer()
                self.dg.help(arg)
            elif dataframe_util.is_dataframe_like(arg):
                flush_buffer()
                self.dg.dataframe(arg)
            elif type_util.is_altair_chart(arg):
                flush_buffer()
                self.dg.altair_chart(arg)
            elif type_util.is_type(arg, "matplotlib.figure.Figure"):
                flush_buffer()
                self.dg.pyplot(arg)
            elif type_util.is_plotly_chart(arg):
                flush_buffer()
                self.dg.plotly_chart(arg)
            elif type_util.is_type(arg, "bokeh.plotting.figure.Figure"):
                flush_buffer()
                self.dg.bokeh_chart(arg)
            elif type_util.is_graphviz_chart(arg):
                flush_buffer()
                self.dg.graphviz_chart(arg)
            elif type_util.is_sympy_expression(arg):
                flush_buffer()
                self.dg.latex(arg)
            elif type_util.is_pillow_image(arg):
                flush_buffer()
                self.dg.image(arg)
            elif type_util.is_keras_model(arg):
                from tensorflow.python.keras.utils import vis_utils

                flush_buffer()
                dot = vis_utils.model_to_dot(arg)
                self.dg.graphviz_chart(dot.to_string())
            elif (
                isinstance(
                    arg,
                    (
                        dict,
                        list,
                        map,
                        enumerate,
                        types.MappingProxyType,
                        UserDict,
                        ChainMap,
                        UserList,
                        ItemsView,
                        KeysView,
                        ValuesView,
                    ),
                )
                or type_util.is_custom_dict(arg)
                or type_util.is_namedtuple(arg)
                or type_util.is_pydantic_model(arg)
            ):
                flush_buffer()
                self.dg.json(arg)
            elif type_util.is_pydeck(arg):
                flush_buffer()
                self.dg.pydeck_chart(arg)
            elif isinstance(arg, StringIO):
                flush_buffer()
                self.dg.markdown(arg.getvalue())
            elif (
                inspect.isgenerator(arg)
                or inspect.isgeneratorfunction(arg)
                or inspect.isasyncgenfunction(arg)
                or inspect.isasyncgen(arg)
                or type_util.is_type(arg, "openai.Stream")
            ):
                flush_buffer()
                self.write_stream(arg)
            elif isinstance(arg, HELP_TYPES):
                flush_buffer()
                self.dg.help(arg)
            elif dataclasses.is_dataclass(arg):
                flush_buffer()
                self.dg.help(arg)
            elif inspect.isclass(arg):
                flush_buffer()
                # We cast arg to type here to appease mypy, due to bug in mypy:
                # https://github.com/python/mypy/issues/12933
                self.dg.help(cast(type, arg))
            elif unsafe_allow_html and type_util.has_callable_attr(arg, "_repr_html_"):
                self.dg.html(arg._repr_html_())
            elif type_util.has_callable_attr(
                arg, "to_pandas"
            ) or type_util.has_callable_attr(arg, "__dataframe__"):
                # This object can very likely be converted to a DataFrame
                # using the to_pandas, to_arrow, or the dataframe interchange
                # protocol.
                flush_buffer()
                self.dg.dataframe(arg)
            else:
                stringified_arg = str(arg)

                if is_mem_address_str(stringified_arg):
                    flush_buffer()
                    self.dg.help(arg)

                elif "\n" in stringified_arg:
                    # With a multi-line string, use a preformatted block
                    # To fully escape backticks, we wrap with backticks larger than
                    # the largest sequence of backticks in the string.
                    backtick_count = max(3, max_char_sequence(stringified_arg, "`") + 1)
                    backtick_wrapper = "`" * backtick_count
                    string_buffer.append(
                        f"{backtick_wrapper}\n{stringified_arg}\n{backtick_wrapper}"
                    )
                else:
                    # With a single-line string, use a preformatted text
                    # To fully escape backticks, we wrap with backticks larger than
                    # the largest sequence of backticks in the string.
                    backtick_count = max_char_sequence(stringified_arg, "`") + 1
                    backtick_wrapper = "`" * backtick_count
                    string_buffer.append(
                        f"{backtick_wrapper}{stringified_arg}{backtick_wrapper}"
                    )

        flush_buffer()

    @property
    def dg(self) -> DeltaGenerator:
        """Get our DeltaGenerator."""
        return cast("DeltaGenerator", self)
