# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import contextlib
import inspect
import os
import sys
import threading
import time
import uuid
from collections.abc import Sized
from functools import wraps
from typing import Any, Callable, Final, TypeVar, cast, overload

from streamlit import config, util
from streamlit.logger import get_logger
from streamlit.proto.ForwardMsg_pb2 import ForwardMsg
from streamlit.proto.PageProfile_pb2 import Argument, Command
from streamlit.runtime.scriptrunner_utils.exceptions import RerunException
from streamlit.runtime.scriptrunner_utils.script_run_context import get_script_run_ctx

# Module-level logger used for telemetry warnings and debug output.
_LOGGER: Final = get_logger(__name__)

# Limit the number of commands to keep the page profile message small.
# gather_metrics stops collecting telemetry for a call once
# len(ctx.tracked_commands) has reached this value.
_MAX_TRACKED_COMMANDS: Final = 200
# Only track a maximum of 25 uses per unique command since some apps use
# commands excessively (e.g. calling add_rows thousands of times in one rerun)
# making the page profile useless.
# Calls beyond this limit are still counted in ctx.tracked_commands_counter by
# gather_metrics, but are no longer appended to ctx.tracked_commands.
_MAX_TRACKED_PER_COMMAND: Final = 25

# A mapping to convert from the actual name to preferred/shorter representations.
# Keys have the form "<module path>.<qualified type name>", which is the string
# that _get_type_name builds for non-builtin types before looking it up here.
_OBJECT_NAME_MAPPING: Final = {
    "streamlit.delta_generator.DeltaGenerator": "DG",
    "pandas.core.frame.DataFrame": "DataFrame",
    "plotly.graph_objs._figure.Figure": "PlotlyFigure",
    "bokeh.plotting.figure.Figure": "BokehFigure",
    "matplotlib.figure.Figure": "MatplotlibFigure",
    "pandas.io.formats.style.Styler": "PandasStyler",
    "pandas.core.indexes.base.Index": "PandasIndex",
    "pandas.core.series.Series": "PandasSeries",
    "streamlit.connections.snowpark_connection.SnowparkConnection": "SnowparkConnection",
    "streamlit.connections.sql_connection.SQLConnection": "SQLConnection",
}

# A list of dependencies to check for attribution.
# Each entry is a module name; create_page_profile_message reports the entries
# that are present as keys in sys.modules at the time the message is built.
_ATTRIBUTIONS_TO_CHECK: Final = [
    # DB Clients:
    "pymysql",
    "MySQLdb",
    "mysql",
    "pymongo",
    "ibis",
    "boto3",
    "psycopg2",
    "psycopg3",
    "sqlalchemy",
    "elasticsearch",
    "pyodbc",
    "pymssql",
    "cassandra",
    "azure",
    "redis",
    "sqlite3",
    "neo4j",
    "duckdb",
    "opensearchpy",
    "supabase",
    # Dataframe Libraries:
    "polars",
    "dask",
    "vaex",
    "modin",
    "pyspark",
    "cudf",
    "xarray",
    "ray",
    "geopandas",
    "mars",
    "tables",
    "zarr",
    "datasets",
    # ML & LLM Tools:
    "mistralai",
    "openai",
    "langchain",
    "llama_index",
    "llama_cpp",
    "anthropic",
    "pyllamacpp",
    "cohere",
    "transformers",
    "nomic",
    "diffusers",
    "semantic_kernel",
    "replicate",
    "huggingface_hub",
    "wandb",
    "torch",
    "tensorflow",
    "trubrics",
    "comet_ml",
    "clarifai",
    "reka",
    "hegel",
    "fastchat",
    "assemblyai",
    "openllm",
    "embedchain",
    "haystack",
    "vllm",
    "alpa",
    "jinaai",
    "guidance",
    "litellm",
    "comet_llm",
    "instructor",
    "xgboost",
    "lightgbm",
    "catboost",
    "sklearn",
    # Workflow Tools:
    "prefect",
    "luigi",
    "airflow",
    "dagster",
    # Vector Stores:
    "pgvector",
    "faiss",
    "annoy",
    "pinecone",
    "chromadb",
    "weaviate",
    "qdrant_client",
    "pymilvus",
    "lancedb",
    # Others:
    "snowflake",
    "streamlit_extras",
    "streamlit_pydantic",
    "pydantic",
    "plost",
    "authlib",
]

# Candidate files holding the machine ID. _get_machine_id_v3 reads the first
# one if it exists as a file, otherwise the second one.
_ETC_MACHINE_ID_PATH = "/etc/machine-id"
_DBUS_MACHINE_ID_PATH = "/var/lib/dbus/machine-id"


def _get_machine_id_v3() -> str:
    """Get the machine ID.

    This is a unique identifier for a user for tracking metrics,
    that is broken in different ways in some Linux distros and Docker images.
    - at times just a hash of '', which means many machines map to the same ID
    - at times a hash of the same string, when running in a Docker container

    The candidate machine-id files are probed in order and the content of the
    first one that exists and can be read is used. If no file is usable, the
    hardware address reported by ``uuid.getnode()`` is used instead.

    Returns
    -------
    str
        The raw content of the machine-id file (intentionally not stripped, so
        that previously derived installation IDs stay stable), or
        ``str(uuid.getnode())`` as a fallback.
    """
    for machine_id_path in (_ETC_MACHINE_ID_PATH, _DBUS_MACHINE_ID_PATH):
        if not os.path.isfile(machine_id_path):
            continue

        try:
            with open(machine_id_path) as f:
                return f.read()
        except (OSError, UnicodeDecodeError) as ex:
            # An existing but unreadable file (permissions, I/O error, garbage
            # content) must not break telemetry; try the next candidate.
            _LOGGER.debug(
                "Failed to read machine ID from %s", machine_id_path, exc_info=ex
            )

    # Only query the hardware address when it is actually needed.
    return str(uuid.getnode())


class Installation:
    """Singleton that holds the installation identifier of this machine.

    The identifier is a UUIDv5 (DNS namespace) derived from the machine ID
    returned by ``_get_machine_id_v3``.
    """

    # Guards the creation of the shared instance.
    _instance_lock = threading.Lock()
    # The shared instance; created on first call to ``instance()``.
    _instance: Installation | None = None

    @classmethod
    def instance(cls) -> Installation:
        """Returns the singleton Installation"""
        # Double-checked locking: the common case (instance already created)
        # returns without touching the lock.
        # https://en.wikipedia.org/wiki/Double-checked_locking
        if cls._instance is not None:
            return cls._instance

        with cls._instance_lock:
            # Re-check under the lock, another thread may have won the race.
            if cls._instance is None:
                cls._instance = Installation()
            return cls._instance

    def __init__(self):
        machine_id = _get_machine_id_v3()
        derived_uuid = uuid.uuid5(uuid.NAMESPACE_DNS, machine_id)
        self.installation_id_v3 = str(derived_uuid)

    def __repr__(self) -> str:
        return util.repr_(self)

    @property
    def installation_id(self):
        """The current installation ID (alias for ``installation_id_v3``)."""
        return self.installation_id_v3


def _get_type_name(obj: object) -> str:
    """Return a short, telemetry-friendly name for the type of ``obj``.

    If ``obj`` is itself a class, that class is described; otherwise its type
    is. Types outside of ``builtins`` are prefixed with their module path, and
    names listed in ``_OBJECT_NAME_MAPPING`` are replaced by their short form.
    Any error during the lookup results in the string "failed".
    """
    try:
        described_type = obj if inspect.isclass(obj) else type(obj)

        if hasattr(described_type, "__qualname__"):
            resolved_name = described_type.__qualname__
        elif hasattr(described_type, "__name__"):
            resolved_name = described_type.__name__
        else:
            resolved_name = "unknown"

        if described_type.__module__ != "builtins":
            # Add the full module path
            resolved_name = f"{described_type.__module__}.{resolved_name}"

        # Fall back to the resolved name when no short form is registered.
        return _OBJECT_NAME_MAPPING.get(resolved_name, resolved_name)
    except Exception:
        return "failed"


def _get_top_level_module(func: Callable[..., Any]) -> str:
    """Get the top level module for the given function."""
    module = inspect.getmodule(func)
    if module is None or not module.__name__:
        return "unknown"
    return module.__name__.split(".")[0]


def _get_arg_metadata(arg: object) -> str | None:
    """Get metadata information related to the value of the given object."""
    with contextlib.suppress(Exception):
        if isinstance(arg, (bool)):
            return f"val:{arg}"

        if isinstance(arg, Sized):
            return f"len:{len(arg)}"

    return None


def _get_command_telemetry(
    _command_func: Callable[..., Any], _command_name: str, *args, **kwargs
) -> Command:
    """Build the telemetry ``Command`` for one invocation of a callable.

    Every positional and keyword argument is recorded with its keyword, a
    simplified type name and - where available - value metadata. Positional
    arguments additionally carry their position. An argument whose keyword is
    ``self`` is not recorded.

    The command name is prefixed with ``external:<module>:`` if the callable
    does not live in the ``streamlit`` package, and is replaced by
    ``component:<name>`` for ``create_instance`` calls whose ``self`` has a
    truthy ``name`` attribute.
    """
    positional_keywords = inspect.getfullargspec(_command_func).args
    # If func is a method, ignore the first argument (self) of the arg spec
    keyword_offset = 1 if inspect.ismethod(_command_func) else 0
    receiver: Any | None = None
    collected: list[Argument] = []

    def _with_metadata(argument: Argument, value: object) -> Argument:
        # Attach value metadata to the argument when there is any.
        metadata = _get_arg_metadata(value)
        if metadata:
            argument.m = metadata
        return argument

    for position, value in enumerate(args):
        spec_index = position + keyword_offset
        if spec_index < len(positional_keywords):
            keyword = positional_keywords[spec_index]
        else:
            # More positional values than named parameters (e.g. *args).
            keyword = f"{spec_index}"

        if keyword == "self":
            receiver = value
            continue

        collected.append(
            _with_metadata(
                Argument(k=keyword, t=_get_type_name(value), p=position), value
            )
        )

    for keyword, value in kwargs.items():
        collected.append(
            _with_metadata(Argument(k=keyword, t=_get_type_name(value)), value)
        )

    command_name = _command_name
    top_level_module = _get_top_level_module(_command_func)
    if top_level_module != "streamlit":
        # If the gather_metrics decorator is used outside of streamlit library
        # we enforce a prefix to be added to the tracked command:
        command_name = f"external:{top_level_module}:{command_name}"

    is_named_component = (
        command_name == "create_instance"
        and receiver
        and hasattr(receiver, "name")
        and receiver.name
    )
    if is_named_component:
        command_name = f"component:{receiver.name}"

    return Command(name=command_name, args=collected)


def to_microseconds(seconds: float) -> int:
    """Return ``seconds`` expressed in microseconds, truncated to an int."""
    microseconds_per_second = 1_000_000
    return int(seconds * microseconds_per_second)


F = TypeVar("F", bound=Callable[..., Any])


@overload
def gather_metrics(
    name: str,
    func: F,
) -> F: ...


@overload
def gather_metrics(
    name: str,
    func: None = None,
) -> Callable[[F], F]: ...


def gather_metrics(name: str, func: F | None = None) -> Callable[[F], F] | F:
    """Function decorator to add telemetry tracking to commands.

    Parameters
    ----------
    func : callable
    The function to track for telemetry.

    name : str or None
    Overwrite the function name with a custom name that is used for telemetry tracking.

    Example
    -------
    >>> @st.gather_metrics
    ... def my_command(url):
    ...     return url

    >>> @st.gather_metrics(name="custom_name")
    ... def my_command(url):
    ...     return url
    """

    if not name:
        _LOGGER.warning("gather_metrics: name is empty")
        name = "undefined"

    if func is None:
        # Support passing the params via function decorator
        def wrapper(f: F) -> F:
            return gather_metrics(
                name=name,
                func=f,
            )

        return wrapper
    else:
        # To make mypy type narrow F | None -> F
        non_optional_func = func

    @wraps(non_optional_func)
    def wrapped_func(*args, **kwargs):
        from timeit import default_timer as timer

        exec_start = timer()
        ctx = get_script_run_ctx(suppress_warning=True)

        tracking_activated = (
            ctx is not None
            and ctx.gather_usage_stats
            and not ctx.command_tracking_deactivated
            and len(ctx.tracked_commands)
            < _MAX_TRACKED_COMMANDS  # Prevent too much memory usage
        )

        command_telemetry: Command | None = None
        # This flag is needed to make sure that only the command (the outermost command)
        # that deactivated tracking (via ctx.command_tracking_deactivated) is able to reset it
        # again. This is important to prevent nested commands from reactivating tracking.
        # At this point, we don't know yet if the command will deactivated tracking.
        has_set_command_tracking_deactivated = False

        if ctx and tracking_activated:
            try:
                command_telemetry = _get_command_telemetry(
                    non_optional_func, name, *args, **kwargs
                )

                if (
                    command_telemetry.name not in ctx.tracked_commands_counter
                    or ctx.tracked_commands_counter[command_telemetry.name]
                    < _MAX_TRACKED_PER_COMMAND
                ):
                    ctx.tracked_commands.append(command_telemetry)
                ctx.tracked_commands_counter.update([command_telemetry.name])
                # Deactivate tracking to prevent calls inside already tracked commands
                ctx.command_tracking_deactivated = True
                # The ctx.command_tracking_deactivated flag was set to True,
                # we also need to set has_set_command_tracking_deactivated to True
                # to make sure that this command is able to reset it again.
                has_set_command_tracking_deactivated = True
            except Exception as ex:
                # Always capture all exceptions since we want to make sure that
                # the telemetry never causes any issues.
                _LOGGER.debug("Failed to collect command telemetry", exc_info=ex)
        try:
            result = non_optional_func(*args, **kwargs)
        except RerunException as ex:
            # Duplicated from below, because static analysis tools get confused
            # by deferring the rethrow.
            if tracking_activated and command_telemetry:
                command_telemetry.time = to_microseconds(timer() - exec_start)
            raise ex
        finally:
            # Activate tracking again if command executes without any exceptions
            # we only want to do that if this command has set the
            # flag to deactivate tracking.
            if ctx and has_set_command_tracking_deactivated:
                ctx.command_tracking_deactivated = False

        if tracking_activated and command_telemetry:
            # Set the execution time to the measured value
            command_telemetry.time = to_microseconds(timer() - exec_start)

        return result

    with contextlib.suppress(AttributeError):
        # Make this a well-behaved decorator by preserving important function
        # attributes.
        wrapped_func.__dict__.update(non_optional_func.__dict__)
        wrapped_func.__signature__ = inspect.signature(non_optional_func)  # type: ignore
    return cast(F, wrapped_func)


def create_page_profile_message(
    commands: list[Command],
    exec_time: int,
    prep_time: int,
    uncaught_exception: str | None = None,
) -> ForwardMsg:
    """Create and return the full PageProfile ForwardMsg.

    Besides the given commands and timings, the profile contains the headless
    setting, the names of all manually set config options, the attribution
    modules that are currently imported, the platform, the timezone names,
    the uncaught exception (if any) and whether this is a fragment run.
    """
    msg = ForwardMsg()
    profile = msg.page_profile

    profile.commands.extend(commands)
    profile.exec_time = exec_time
    profile.prep_time = prep_time

    profile.headless = config.get_option("server.headless")

    # Collect all config options that have been manually set
    manually_set_options: set[str] = set()
    for option_name, option in (config._config_options or {}).items():
        # We only care about manually defined options
        if config.is_manually_set(option_name):
            reported_name = (
                f"{option_name}:default" if option.is_default else option_name
            )
            manually_set_options.add(reported_name)

    profile.config.extend(manually_set_options)

    # Check the predefined set of modules for attribution
    imported_attributions: set[str] = set()
    for module_name in _ATTRIBUTIONS_TO_CHECK:
        if module_name in sys.modules:
            imported_attributions.add(module_name)

    profile.os = str(sys.platform)
    profile.timezone = str(time.tzname)
    profile.attributions.extend(imported_attributions)

    if uncaught_exception:
        profile.uncaught_exception = uncaught_exception

    ctx = get_script_run_ctx()
    if ctx:
        profile.is_fragment_run = bool(ctx.fragment_ids_this_run)

    return msg
