try:
    import langchain  # noqa: F401
except ImportError:
    raise ModuleNotFoundError("Please install LangChain to use this feature: 'pip install langchain'")

import logging
import time
from dataclasses import dataclass
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
)
from uuid import UUID

from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema.agent import AgentAction, AgentFinish
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, BaseMessage, FunctionMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from pydantic import BaseModel

from posthog import default_client
from posthog.ai.utils import get_model_params, with_privacy_mode
from posthog.client import Client

# Module-level logger registered under the "posthog" logger name; used below for
# debug traces of callback events and for warnings about missing or mismatched runs.
log = logging.getLogger("posthog")


@dataclass
class SpanMetadata:
    """Bookkeeping recorded for a single run between its start and end callbacks."""

    name: str
    """Name of the run: chain name, model name, etc."""
    start_time: float
    """Start time of the run."""
    end_time: Optional[float]
    """End time of the run."""
    input: Optional[Any]
    """Input of the run: messages, prompt variables, etc."""

    @property
    def latency(self) -> float:
        """
        Elapsed time between `start_time` and `end_time`, in the same unit as those fields.

        Returns 0.0 while the run has not ended yet (`end_time` is None).
        """
        # Compare against None explicitly: an end time of 0.0 is a valid timestamp and
        # must not be mistaken for "not finished".
        if self.end_time is None:
            return 0.0
        return self.end_time - self.start_time


@dataclass
class GenerationMetadata(SpanMetadata):
    """Span metadata extended with optional details about the LLM call: provider, model, parameters and base URL."""

    provider: Optional[str] = None
    """Provider of the run: OpenAI, Anthropic"""
    model: Optional[str] = None
    """Model used in the run"""
    model_params: Optional[Dict[str, Any]] = None
    """Model parameters of the run: temperature, max_tokens, etc."""
    base_url: Optional[str] = None
    """Base URL of the provider's API used in the run."""


# Metadata kept for an in-flight run: either a plain span/trace or an LLM generation.
RunMetadata = Union[SpanMetadata, GenerationMetadata]
# In-flight runs keyed by their run ID.
RunMetadataStorage = Dict[UUID, RunMetadata]


class CallbackHandler(BaseCallbackHandler):
    """
    The PostHog LLM observability callback handler for LangChain.

    Every `on_*_start` hook stores metadata for the run (and its parent link); the matching
    `on_*_end` / `on_*_error` hook pops that metadata and captures a `$ai_trace`, `$ai_span`
    or `$ai_generation` event through the PostHog client.
    """

    _client: Client
    """PostHog client instance."""

    _distinct_id: Optional[Union[str, int, float, UUID]]
    """Distinct ID of the user to associate the trace with."""

    _trace_id: Optional[Union[str, int, float, UUID]]
    """Global trace ID to be sent with every event. Otherwise, the top-level run ID is used."""

    _trace_input: Optional[Any]
    """The input at the start of the trace. Any JSON object."""

    _trace_name: Optional[str]
    """Name of the trace, exposed in the UI."""

    _properties: Optional[Dict[str, Any]]
    """Global properties to be sent with every event."""

    _privacy_mode: bool
    """Whether inputs and outputs are passed through privacy mode before being sent."""

    _groups: Dict[str, Any]
    """PostHog groups to be sent with every event."""

    _runs: RunMetadataStorage
    """Mapping of run IDs to run metadata as run metadata is only available on the start of generation."""

    _parent_tree: Dict[UUID, UUID]
    """
    A dictionary that maps chain run IDs to their parent chain run IDs (parent pointer tree),
    so the top level can be found from a bottom-level run ID.
    """

    def __init__(
        self,
        client: Optional[Client] = None,
        *,
        distinct_id: Optional[Union[str, int, float, UUID]] = None,
        trace_id: Optional[Union[str, int, float, UUID]] = None,
        properties: Optional[Dict[str, Any]] = None,
        privacy_mode: bool = False,
        groups: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            client: PostHog client instance.
            distinct_id: Optional distinct ID of the user to associate the trace with.
            trace_id: Optional trace ID to use for the event.
            properties: Optional additional metadata to use for the trace.
            privacy_mode: Whether to redact the input and output of the trace.
            groups: Optional additional PostHog groups to use for the trace.

        Raises:
            ValueError: If no client is passed and no default client is available.
        """
        posthog_client = client or default_client
        if posthog_client is None:
            raise ValueError("PostHog client is required")
        self._client = posthog_client
        self._distinct_id = distinct_id
        self._trace_id = trace_id
        self._properties = properties or {}
        self._privacy_mode = privacy_mode
        self._groups = groups or {}
        self._runs = {}
        self._parent_tree = {}

    def on_chain_start(
        self,
        serialized: Dict[str, Any],
        inputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """Record the parent link and span metadata for a starting chain run."""
        self._log_debug_event("on_chain_start", run_id, parent_run_id, inputs=inputs)
        self._set_parent_of_run(run_id, parent_run_id)
        self._set_trace_or_span_metadata(serialized, inputs, run_id, parent_run_id, **kwargs)

    def on_chain_end(
        self,
        outputs: Dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ):
        """Capture the trace/span event for a finished chain run."""
        self._log_debug_event("on_chain_end", run_id, parent_run_id, outputs=outputs)
        self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, outputs)

    def on_chain_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ):
        """Capture the trace/span event for a chain run that raised."""
        self._log_debug_event("on_chain_error", run_id, parent_run_id, error=error)
        self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, error)

    def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs,
    ):
        """Record generation metadata for a starting chat model run; message batches are flattened into one list."""
        self._log_debug_event("on_chat_model_start", run_id, parent_run_id, messages=messages)
        self._set_parent_of_run(run_id, parent_run_id)
        input_messages = [_convert_message_to_dict(message) for row in messages for message in row]
        self._set_llm_metadata(serialized, run_id, input_messages, **kwargs)

    def on_llm_start(
        self,
        serialized: Dict[str, Any],
        prompts: List[str],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ):
        """Record generation metadata for a starting (text-completion) LLM run."""
        self._log_debug_event("on_llm_start", run_id, parent_run_id, prompts=prompts)
        self._set_parent_of_run(run_id, parent_run_id)
        self._set_llm_metadata(serialized, run_id, prompts, **kwargs)

    def on_llm_new_token(
        self,
        token: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run on new LLM token. Only available when streaming is enabled."""
        self._log_debug_event("on_llm_new_token", run_id, parent_run_id, token=token)

    def on_llm_end(
        self,
        response: LLMResult,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ):
        """
        The callback works for both streaming and non-streaming runs. For streaming runs, the chain must set `stream_usage=True` in the LLM.
        """
        self._log_debug_event("on_llm_end", run_id, parent_run_id, response=response, kwargs=kwargs)
        self._pop_run_and_capture_generation(run_id, parent_run_id, response)

    def on_llm_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ):
        """Capture the generation event for an LLM run that raised."""
        self._log_debug_event("on_llm_error", run_id, parent_run_id, error=error)
        self._pop_run_and_capture_generation(run_id, parent_run_id, error)

    def on_tool_start(
        self,
        serialized: Optional[Dict[str, Any]],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Record the parent link and span metadata for a starting tool run."""
        self._log_debug_event("on_tool_start", run_id, parent_run_id, input_str=input_str)
        # The parent link is required so the root run (and therefore the trace ID) can be
        # resolved from this run when it ends.
        self._set_parent_of_run(run_id, parent_run_id)
        self._set_trace_or_span_metadata(serialized, input_str, run_id, parent_run_id, **kwargs)

    def on_tool_end(
        self,
        output: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Capture the trace/span event for a finished tool run."""
        self._log_debug_event("on_tool_end", run_id, parent_run_id, output=output)
        self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, output)

    def on_tool_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> Any:
        """Capture the trace/span event for a tool run that raised."""
        self._log_debug_event("on_tool_error", run_id, parent_run_id, error=error)
        self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, error)

    def on_retriever_start(
        self,
        serialized: Optional[Dict[str, Any]],
        query: str,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Record the parent link and span metadata for a starting retriever run."""
        self._log_debug_event("on_retriever_start", run_id, parent_run_id, query=query)
        # Same as for tools: without the parent link the trace ID cannot be resolved
        # from this run when it ends.
        self._set_parent_of_run(run_id, parent_run_id)
        self._set_trace_or_span_metadata(serialized, query, run_id, parent_run_id, **kwargs)

    def on_retriever_end(
        self,
        documents: Sequence[Document],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ):
        """Capture the trace/span event for a finished retriever run."""
        self._log_debug_event("on_retriever_end", run_id, parent_run_id, documents=documents)
        self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, documents)

    def on_retriever_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[list[str]] = None,
        **kwargs: Any,
    ) -> Any:
        """Run when Retriever errors."""
        self._log_debug_event("on_retriever_error", run_id, parent_run_id, error=error)
        self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, error)

    def on_agent_action(
        self,
        action: AgentAction,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Run on agent action."""
        self._log_debug_event("on_agent_action", run_id, parent_run_id, action=action)
        self._set_parent_of_run(run_id, parent_run_id)
        self._set_trace_or_span_metadata(None, action, run_id, parent_run_id, **kwargs)

    def on_agent_finish(
        self,
        finish: AgentFinish,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs: Any,
    ) -> Any:
        """Capture the trace/span event for a finished agent run."""
        self._log_debug_event("on_agent_finish", run_id, parent_run_id, finish=finish)
        self._pop_run_and_capture_trace_or_span(run_id, parent_run_id, finish)

    def _set_parent_of_run(self, run_id: UUID, parent_run_id: Optional[UUID] = None):
        """
        Set the parent run ID for a chain run. If there is no parent, the run is the root.
        """
        if parent_run_id is not None:
            self._parent_tree[run_id] = parent_run_id

    def _pop_parent_of_run(self, run_id: UUID):
        """
        Remove the parent run ID for a chain run. Does nothing if the run has no recorded parent.
        """
        self._parent_tree.pop(run_id, None)

    def _find_root_run(self, run_id: UUID) -> UUID:
        """
        Finds the root ID of a chain run by following parent links until a run without a parent is reached.
        """
        current_id: UUID = run_id
        while current_id in self._parent_tree:
            current_id = self._parent_tree[current_id]
        return current_id

    def _set_trace_or_span_metadata(
        self,
        serialized: Optional[Dict[str, Any]],
        input: Any,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs,
    ):
        """
        Store span metadata for a starting non-LLM run.

        The run name comes from `kwargs`/`serialized` when available and otherwise defaults to
        "trace" for runs without a parent and "span" for nested runs.
        """
        default_name = "trace" if parent_run_id is None else "span"
        run_name = _get_langchain_run_name(serialized, **kwargs) or default_name
        self._runs[run_id] = SpanMetadata(name=run_name, input=input, start_time=time.time(), end_time=None)

    def _set_llm_metadata(
        self,
        serialized: Dict[str, Any],
        run_id: UUID,
        messages: Union[List[Dict[str, Any]], List[str]],
        metadata: Optional[Dict[str, Any]] = None,
        invocation_params: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """
        Store generation metadata for a starting LLM run.

        Model parameters are derived from `invocation_params`, model and provider from the
        `ls_model_name` / `ls_provider` entries of `metadata`, and the base URL from
        `serialized["kwargs"]["openai_api_base"]` when present.
        """
        run_name = _get_langchain_run_name(serialized, **kwargs) or "generation"
        generation = GenerationMetadata(name=run_name, input=messages, start_time=time.time(), end_time=None)
        if isinstance(invocation_params, dict):
            generation.model_params = get_model_params(invocation_params)
        if isinstance(metadata, dict):
            if model := metadata.get("ls_model_name"):
                generation.model = model
            if provider := metadata.get("ls_provider"):
                generation.provider = provider
        try:
            base_url = serialized["kwargs"]["openai_api_base"]
        except (KeyError, TypeError):
            # Missing keys, or a `serialized` / `kwargs` value that is not subscriptable
            # (e.g. None): there is simply no base URL to record.
            base_url = None
        if base_url is not None:
            generation.base_url = base_url
        self._runs[run_id] = generation

    def _pop_run_metadata(self, run_id: UUID) -> Optional[RunMetadata]:
        """
        Remove and return the stored metadata of a run, stamping its end time.

        Returns None (and logs a warning) if no metadata was stored for the run.
        """
        # Take the timestamp first so bookkeeping does not inflate the latency.
        end_time = time.time()
        try:
            run = self._runs.pop(run_id)
        except KeyError:
            log.warning(f"No run metadata found for run {run_id}")
            return None
        run.end_time = end_time
        return run

    def _get_trace_id(self, run_id: UUID):
        """Return the custom trace ID if one was configured, otherwise the root run ID of `run_id`."""
        trace_id = self._trace_id or self._find_root_run(run_id)
        if not trace_id:
            return run_id
        return trace_id

    def _get_parent_run_id(self, trace_id: Any, run_id: UUID, parent_run_id: Optional[UUID]):
        """
        Replace the parent run ID with the trace ID for second level runs when a custom trace ID is set.
        """
        # A parent that has no parent of its own is the root run, so its children hang
        # directly off the trace.
        if parent_run_id is not None and parent_run_id not in self._parent_tree:
            return trace_id
        return parent_run_id

    def _pop_run_and_capture_trace_or_span(self, run_id: UUID, parent_run_id: Optional[UUID], outputs: Any):
        """Finish a non-LLM run: drop its bookkeeping and capture a trace/span event for it."""
        # The trace ID must be resolved before the parent link of this run is removed.
        trace_id = self._get_trace_id(run_id)
        self._pop_parent_of_run(run_id)
        run = self._pop_run_metadata(run_id)
        if not run:
            return
        if isinstance(run, GenerationMetadata):
            log.warning(f"Run {run_id} is a generation, but attempted to be captured as a trace or span.")
            return
        self._capture_trace_or_span(
            trace_id, run_id, run, outputs, self._get_parent_run_id(trace_id, run_id, parent_run_id)
        )

    def _capture_trace_or_span(
        self,
        trace_id: Any,
        run_id: UUID,
        run: SpanMetadata,
        outputs: Any,
        parent_run_id: Optional[UUID],
    ):
        """
        Capture a `$ai_trace` event (no parent) or a `$ai_span` event (has a parent).

        `outputs` is either the run's output, or the exception it raised, in which case the
        event is flagged as an error instead of carrying an output state.
        """
        event_name = "$ai_trace" if parent_run_id is None else "$ai_span"
        event_properties: Dict[str, Any] = {
            "$ai_trace_id": trace_id,
            "$ai_input_state": with_privacy_mode(self._client, self._privacy_mode, run.input),
            "$ai_latency": run.latency,
            "$ai_span_name": run.name,
            "$ai_span_id": run_id,
        }
        if parent_run_id is not None:
            event_properties["$ai_parent_id"] = parent_run_id
        if self._properties:
            event_properties.update(self._properties)

        if isinstance(outputs, BaseException):
            event_properties["$ai_error"] = _stringify_exception(outputs)
            event_properties["$ai_is_error"] = True
        elif outputs is not None:
            event_properties["$ai_output_state"] = with_privacy_mode(self._client, self._privacy_mode, outputs)

        if self._distinct_id is None:
            event_properties["$process_person_profile"] = False

        self._client.capture(
            distinct_id=self._distinct_id or run_id,
            event=event_name,
            properties=event_properties,
            groups=self._groups,
        )

    def _pop_run_and_capture_generation(
        self, run_id: UUID, parent_run_id: Optional[UUID], response: Union[LLMResult, BaseException]
    ):
        """Finish an LLM run: drop its bookkeeping and capture a generation event for it."""
        # The trace ID must be resolved before the parent link of this run is removed.
        trace_id = self._get_trace_id(run_id)
        self._pop_parent_of_run(run_id)
        run = self._pop_run_metadata(run_id)
        if not run:
            return
        if not isinstance(run, GenerationMetadata):
            log.warning(f"Run {run_id} is not a generation, but attempted to be captured as a generation.")
            return
        self._capture_generation(
            trace_id, run_id, run, response, self._get_parent_run_id(trace_id, run_id, parent_run_id)
        )

    def _capture_generation(
        self,
        trace_id: Any,
        run_id: UUID,
        run: GenerationMetadata,
        output: Union[LLMResult, BaseException],
        parent_run_id: Optional[UUID] = None,
    ):
        """
        Capture a `$ai_generation` event.

        `output` is either the LLM result, from which token usage and output choices are
        extracted, or the exception raised by the call, from which the HTTP status and the
        error description are extracted.
        """
        event_properties: Dict[str, Any] = {
            "$ai_trace_id": trace_id,
            "$ai_span_id": run_id,
            "$ai_span_name": run.name,
            "$ai_parent_id": parent_run_id,
            "$ai_provider": run.provider,
            "$ai_model": run.model,
            "$ai_model_parameters": run.model_params,
            "$ai_input": with_privacy_mode(self._client, self._privacy_mode, run.input),
            "$ai_http_status": 200,
            "$ai_latency": run.latency,
            "$ai_base_url": run.base_url,
        }

        if isinstance(output, BaseException):
            event_properties["$ai_http_status"] = _get_http_status(output)
            event_properties["$ai_error"] = _stringify_exception(output)
            event_properties["$ai_is_error"] = True
        else:
            # Add usage
            input_tokens, output_tokens = _parse_usage(output)
            event_properties["$ai_input_tokens"] = input_tokens
            event_properties["$ai_output_tokens"] = output_tokens

            # Generation results: only the last batch entry is reported. An empty result
            # yields an empty list of choices instead of an IndexError.
            generation_result = output.generations[-1] if output.generations else []
            if generation_result and isinstance(generation_result[-1], ChatGeneration):
                completions = [
                    _convert_message_to_dict(cast(ChatGeneration, generation).message)
                    for generation in generation_result
                ]
            else:
                completions = [_extract_raw_esponse(generation) for generation in generation_result]
            event_properties["$ai_output_choices"] = with_privacy_mode(self._client, self._privacy_mode, completions)

        if self._properties:
            event_properties.update(self._properties)

        if self._distinct_id is None:
            event_properties["$process_person_profile"] = False

        self._client.capture(
            distinct_id=self._distinct_id or trace_id,
            event="$ai_generation",
            properties=event_properties,
            groups=self._groups,
        )

    def _log_debug_event(
        self,
        event_name: str,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        **kwargs,
    ):
        """Log a callback invocation at debug level, with run IDs shortened to their first five characters."""
        # Lazy %-style arguments: the (potentially large) kwargs are only rendered when
        # debug logging is actually enabled. This hook runs on every callback, including
        # once per streamed token.
        log.debug(
            "Event: %s, run_id: %s, parent_run_id: %s, kwargs: %s",
            event_name,
            str(run_id)[:5],
            str(parent_run_id)[:5],
            kwargs,
        )


def _extract_raw_esponse(last_response):
    """Extract the response from the last response of the LLM call."""
    # We return the text of the response if not empty
    if last_response.text is not None and last_response.text.strip() != "":
        return last_response.text.strip()
    elif hasattr(last_response, "message"):
        # Additional kwargs contains the response in case of tool usage
        return last_response.message.additional_kwargs
    else:
        # Not tool usage, some LLM responses can be simply empty
        return ""


def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]:
    """
    Convert a LangChain message into a `{"role": ..., "content": ...}` dictionary.

    Known message classes map to fixed role names and keep their content as is. Any other
    message uses its own `type` as the role and its content converted to a string. Entries
    of `additional_kwargs`, if any, are merged on top of the result.
    """
    role_by_class = (
        (HumanMessage, "user"),
        (AIMessage, "assistant"),
        (SystemMessage, "system"),
        (ToolMessage, "tool"),
        (FunctionMessage, "function"),
    )

    converted: Dict[str, Any]
    for message_class, role in role_by_class:
        if isinstance(message, message_class):
            converted = {"role": role, "content": message.content}
            break
    else:
        # Unknown message class: fall back to the message's own type name.
        converted = {"role": message.type, "content": str(message.content)}

    extra = message.additional_kwargs
    if extra:
        converted.update(extra)

    return converted


def _parse_usage_model(
    usage: Union[BaseModel, Dict],
) -> Tuple[Union[int, None], Union[int, None]]:
    """
    Extract `(input_tokens, output_tokens)` from a provider-specific usage object.

    `usage` is either a mapping or a pydantic model (read through its `__dict__`). Each
    known provider key is looked up in order; when several keys for the same side are
    present, the one listed last wins. A side with no matching key is returned as None.
    """
    fields = usage.__dict__ if isinstance(usage, BaseModel) else usage

    provider_keys = (
        # https://pypi.org/project/langchain-anthropic/ (works also for Bedrock-Anthropic)
        ("input_tokens", "input"),
        ("output_tokens", "output"),
        # https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/get-token-count
        ("prompt_token_count", "input"),
        ("candidates_token_count", "output"),
        # Bedrock: https://docs.aws.amazon.com/bedrock/latest/userguide/monitoring-cw.html#runtime-cloudwatch-metrics
        ("inputTokenCount", "input"),
        ("outputTokenCount", "output"),
        # langchain-ibm https://pypi.org/project/langchain-ibm/
        ("input_token_count", "input"),
        ("generated_token_count", "output"),
    )

    counts = {}
    for provider_key, side in provider_keys:
        if provider_key not in fields:
            continue
        count = fields[provider_key]
        if isinstance(count, list):
            # For Bedrock, the token count is a list when streamed
            count = sum(count)
        counts[side] = count

    return counts.get("input"), counts.get("output")


def _parse_usage(response: LLMResult):
    """
    Find token usage in an LLM result and return it as `(input_tokens, output_tokens)`.

    Usage is first looked up in `llm_output` (under "token_usage", then "usage"). The
    generations are scanned afterwards, and usage found there replaces the `llm_output`
    one: within each generation list the first chunk carrying usage is taken, and a later
    list overrides an earlier one. `(None, None)` is returned when nothing is found.
    """
    usage: Tuple[Union[int, None], Union[int, None]] = (None, None)

    if response.llm_output is not None:
        # langchain-anthropic uses the usage field
        for output_key in ("token_usage", "usage"):
            if response.llm_output.get(output_key):
                usage = _parse_usage_model(response.llm_output[output_key])
                break

    if not hasattr(response, "generations"):
        return usage

    for generation in response.generations:
        for chunk in generation:
            if chunk.generation_info and ("usage_metadata" in chunk.generation_info):
                usage = _parse_usage_model(chunk.generation_info["usage_metadata"])
                break

            message = getattr(chunk, "message", {})
            metadata = getattr(message, "response_metadata", {})

            if isinstance(metadata, dict):
                anthropic_usage = metadata.get("usage", None)  # for Bedrock-Anthropic
                titan_usage = metadata.get("amazon-bedrock-invocationMetrics", None)  # for Bedrock-Titan
            else:
                anthropic_usage = None
                titan_usage = None
            message_usage = getattr(message, "usage_metadata", None)  # for Ollama

            found = anthropic_usage or titan_usage or message_usage
            if found:
                usage = _parse_usage_model(found)
                break

    return usage


def _get_http_status(error: BaseException) -> int:
    # OpenAI: https://github.com/openai/openai-python/blob/main/src/openai/_exceptions.py
    # Anthropic: https://github.com/anthropics/anthropic-sdk-python/blob/main/src/anthropic/_exceptions.py
    # Google: https://github.com/googleapis/python-api-core/blob/main/google/api_core/exceptions.py
    status_code = getattr(error, "status_code", getattr(error, "code", 0))
    return status_code


def _get_langchain_run_name(serialized: Optional[Dict[str, Any]], **kwargs: Any) -> Optional[str]:
    """Retrieve the name of a serialized LangChain runnable.

    The prioritization for the determination of the run name is as follows:
    - The value assigned to the "name" key in `kwargs`.
    - The value assigned to the "name" key in `serialized`.
    - The last entry of the value assigned to the "id" key in `serialized`.
    - "<unknown>".

    Args:
        serialized (Optional[Dict[str, Any]]): A dictionary containing the runnable's serialized data.
        **kwargs (Any): Additional keyword arguments, potentially including the 'name' override.

    Returns:
        str: The determined name of the Langchain runnable.
    """
    if "name" in kwargs and kwargs["name"] is not None:
        return kwargs["name"]
    if serialized is None:
        return None
    try:
        return serialized["name"]
    except (KeyError, TypeError):
        pass
    try:
        return serialized["id"][-1]
    except (KeyError, TypeError):
        pass
    return None


def _stringify_exception(exception: BaseException) -> str:
    description = str(exception)
    if description:
        return f"{exception.__class__.__name__}: {description}"
    return exception.__class__.__name__
