# What is this?
## Log success + failure events to Braintrust

import copy
import os
from datetime import datetime
from typing import Optional, Dict

import httpx
from pydantic import BaseModel

import litellm
from litellm import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import (
    HTTPHandler,
    get_async_httpx_client,
    httpxSpecialProvider,
)
from litellm.utils import print_verbose

global_braintrust_http_handler = get_async_httpx_client(llm_provider=httpxSpecialProvider.LoggingCallback)
global_braintrust_sync_http_handler = HTTPHandler()
API_BASE = "https://api.braintrustdata.com/v1"


def get_utc_datetime():
    import datetime as dt
    from datetime import datetime

    if hasattr(dt, "UTC"):
        return datetime.now(dt.UTC)  # type: ignore
    else:
        return datetime.utcnow()  # type: ignore


class BraintrustLogger(CustomLogger):
    def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None) -> None:
        super().__init__()
        self.validate_environment(api_key=api_key)
        self.api_base = api_base or API_BASE
        self.default_project_id = None
        self.api_key: str = api_key or os.getenv("BRAINTRUST_API_KEY")  # type: ignore
        self.headers = {
            "Authorization": "Bearer " + self.api_key,
            "Content-Type": "application/json",
        }
        self._project_id_cache: Dict[str, str] = {}  # Cache mapping project names to IDs

    def validate_environment(self, api_key: Optional[str]):
        """
        Expects
        BRAINTRUST_API_KEY

        in the environment
        """
        missing_keys = []
        if api_key is None and os.getenv("BRAINTRUST_API_KEY", None) is None:
            missing_keys.append("BRAINTRUST_API_KEY")

        if len(missing_keys) > 0:
            raise Exception("Missing keys={} in environment.".format(missing_keys))

    def get_project_id_sync(self, project_name: str) -> str:
        """
        Get project ID from name, using cache if available.
        If project doesn't exist, creates it.
        """
        if project_name in self._project_id_cache:
            return self._project_id_cache[project_name]

        try:
            response = global_braintrust_sync_http_handler.post(
                f"{self.api_base}/project", headers=self.headers, json={"name": project_name}
            )
            project_dict = response.json()
            project_id = project_dict["id"]
            self._project_id_cache[project_name] = project_id
            return project_id
        except httpx.HTTPStatusError as e:
            raise Exception(f"Failed to register project: {e.response.text}")

    async def get_project_id_async(self, project_name: str) -> str:
        """
        Async version of get_project_id_sync
        """
        if project_name in self._project_id_cache:
            return self._project_id_cache[project_name]

        try:
            response = await global_braintrust_http_handler.post(
                f"{self.api_base}/project/register", headers=self.headers, json={"name": project_name}
            )
            project_dict = response.json()
            project_id = project_dict["id"]
            self._project_id_cache[project_name] = project_id
            return project_id
        except httpx.HTTPStatusError as e:
            raise Exception(f"Failed to register project: {e.response.text}")

    @staticmethod
    def add_metadata_from_header(litellm_params: dict, metadata: dict) -> dict:
        """
        Adds metadata from proxy request headers to Langfuse logging if keys start with "langfuse_"
        and overwrites litellm_params.metadata if already included.

        For example if you want to append your trace to an existing `trace_id` via header, send
        `headers: { ..., langfuse_existing_trace_id: your-existing-trace-id }` via proxy request.
        """
        if litellm_params is None:
            return metadata

        if litellm_params.get("proxy_server_request") is None:
            return metadata

        if metadata is None:
            metadata = {}

        proxy_headers = litellm_params.get("proxy_server_request", {}).get("headers", {}) or {}

        for metadata_param_key in proxy_headers:
            if metadata_param_key.startswith("braintrust"):
                trace_param_key = metadata_param_key.replace("braintrust", "", 1)
                if trace_param_key in metadata:
                    verbose_logger.warning(f"Overwriting Braintrust `{trace_param_key}` from request header")
                else:
                    verbose_logger.debug(f"Found Braintrust `{trace_param_key}` in request header")
                metadata[trace_param_key] = proxy_headers.get(metadata_param_key)

        return metadata

    async def create_default_project_and_experiment(self):
        project = await global_braintrust_http_handler.post(
            f"{self.api_base}/project", headers=self.headers, json={"name": "litellm"}
        )

        project_dict = project.json()

        self.default_project_id = project_dict["id"]

    def create_sync_default_project_and_experiment(self):
        project = global_braintrust_sync_http_handler.post(
            f"{self.api_base}/project", headers=self.headers, json={"name": "litellm"}
        )

        project_dict = project.json()

        self.default_project_id = project_dict["id"]

    def log_success_event(  # noqa: PLR0915
        self, kwargs, response_obj, start_time, end_time
    ):
        verbose_logger.debug("REACHES BRAINTRUST SUCCESS")
        try:
            litellm_call_id = kwargs.get("litellm_call_id")
            prompt = {"messages": kwargs.get("messages")}
            output = None
            choices = []
            if response_obj is not None and (
                kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, litellm.EmbeddingResponse)
            ):
                output = None
            elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse):
                output = response_obj["choices"][0]["message"].json()
                choices = response_obj["choices"]
            elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse):
                output = response_obj.choices[0].text
                choices = response_obj.choices
            elif response_obj is not None and isinstance(response_obj, litellm.ImageResponse):
                output = response_obj["data"]

            litellm_params = kwargs.get("litellm_params", {})
            metadata = litellm_params.get("metadata", {}) or {}  # if litellm_params['metadata'] == None
            metadata = self.add_metadata_from_header(litellm_params, metadata)
            clean_metadata = {}
            try:
                metadata = copy.deepcopy(metadata)  # Avoid modifying the original metadata
            except Exception:
                new_metadata = {}
                for key, value in metadata.items():
                    if (
                        isinstance(value, list)
                        or isinstance(value, dict)
                        or isinstance(value, str)
                        or isinstance(value, int)
                        or isinstance(value, float)
                    ):
                        new_metadata[key] = copy.deepcopy(value)
                metadata = new_metadata

            # Get project_id from metadata or create default if needed
            project_id = metadata.get("project_id")
            if project_id is None:
                project_name = metadata.get("project_name")
                project_id = self.get_project_id_sync(project_name) if project_name else None

            if project_id is None:
                if self.default_project_id is None:
                    self.create_sync_default_project_and_experiment()
                project_id = self.default_project_id

            tags = []
            if isinstance(metadata, dict):
                for key, value in metadata.items():
                    # generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy
                    if (
                        litellm.langfuse_default_tags is not None
                        and isinstance(litellm.langfuse_default_tags, list)
                        and key in litellm.langfuse_default_tags
                    ):
                        tags.append(f"{key}:{value}")

                    # clean litellm metadata before logging
                    if key in [
                        "headers",
                        "endpoint",
                        "caching_groups",
                        "previous_models",
                    ]:
                        continue
                    else:
                        clean_metadata[key] = value

            cost = kwargs.get("response_cost", None)
            if cost is not None:
                clean_metadata["litellm_response_cost"] = cost

            metrics: Optional[dict] = None
            usage_obj = getattr(response_obj, "usage", None)
            if usage_obj and isinstance(usage_obj, litellm.Usage):
                litellm.utils.get_logging_id(start_time, response_obj)
                metrics = {
                    "prompt_tokens": usage_obj.prompt_tokens,
                    "completion_tokens": usage_obj.completion_tokens,
                    "total_tokens": usage_obj.total_tokens,
                    "total_cost": cost,
                    "time_to_first_token": end_time.timestamp() - start_time.timestamp(),
                    "start": start_time.timestamp(),
                    "end": end_time.timestamp(),
                }

            request_data = {
                "id": litellm_call_id,
                "input": prompt["messages"],
                "metadata": clean_metadata,
                "tags": tags,
                "span_attributes": {"name": "Chat Completion", "type": "llm"},
            }
            if choices is not None:
                request_data["output"] = [choice.dict() for choice in choices]
            else:
                request_data["output"] = output

            if metrics is not None:
                request_data["metrics"] = metrics

            try:
                print_verbose(f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}")
                global_braintrust_sync_http_handler.post(
                    url=f"{self.api_base}/project_logs/{project_id}/insert",
                    json={"events": [request_data]},
                    headers=self.headers,
                )
            except httpx.HTTPStatusError as e:
                raise Exception(e.response.text)
        except Exception as e:
            raise e  # don't use verbose_logger.exception, if exception is raised

    async def async_log_success_event(  # noqa: PLR0915
        self, kwargs, response_obj, start_time, end_time
    ):
        verbose_logger.debug("REACHES BRAINTRUST SUCCESS")
        try:
            litellm_call_id = kwargs.get("litellm_call_id")
            prompt = {"messages": kwargs.get("messages")}
            output = None
            choices = []
            if response_obj is not None and (
                kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, litellm.EmbeddingResponse)
            ):
                output = None
            elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse):
                output = response_obj["choices"][0]["message"].json()
                choices = response_obj["choices"]
            elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse):
                output = response_obj.choices[0].text
                choices = response_obj.choices
            elif response_obj is not None and isinstance(response_obj, litellm.ImageResponse):
                output = response_obj["data"]

            litellm_params = kwargs.get("litellm_params", {})
            metadata = litellm_params.get("metadata", {}) or {}  # if litellm_params['metadata'] == None
            metadata = self.add_metadata_from_header(litellm_params, metadata)
            clean_metadata = {}
            new_metadata = {}
            for key, value in metadata.items():
                if (
                    isinstance(value, list)
                    or isinstance(value, str)
                    or isinstance(value, int)
                    or isinstance(value, float)
                ):
                    new_metadata[key] = value
                elif isinstance(value, BaseModel):
                    new_metadata[key] = value.model_dump_json()
                elif isinstance(value, dict):
                    for k, v in value.items():
                        if isinstance(v, datetime):
                            value[k] = v.isoformat()
                    new_metadata[key] = value

            # Get project_id from metadata or create default if needed
            project_id = metadata.get("project_id")
            if project_id is None:
                project_name = metadata.get("project_name")
                project_id = await self.get_project_id_async(project_name) if project_name else None

            if project_id is None:
                if self.default_project_id is None:
                    await self.create_default_project_and_experiment()
                project_id = self.default_project_id

            tags = []
            if isinstance(metadata, dict):
                for key, value in metadata.items():
                    # generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy
                    if (
                        litellm.langfuse_default_tags is not None
                        and isinstance(litellm.langfuse_default_tags, list)
                        and key in litellm.langfuse_default_tags
                    ):
                        tags.append(f"{key}:{value}")

                    # clean litellm metadata before logging
                    if key in [
                        "headers",
                        "endpoint",
                        "caching_groups",
                        "previous_models",
                    ]:
                        continue
                    else:
                        clean_metadata[key] = value

            cost = kwargs.get("response_cost", None)
            if cost is not None:
                clean_metadata["litellm_response_cost"] = cost

            metrics: Optional[dict] = None
            usage_obj = getattr(response_obj, "usage", None)
            if usage_obj and isinstance(usage_obj, litellm.Usage):
                litellm.utils.get_logging_id(start_time, response_obj)
                metrics = {
                    "prompt_tokens": usage_obj.prompt_tokens,
                    "completion_tokens": usage_obj.completion_tokens,
                    "total_tokens": usage_obj.total_tokens,
                    "total_cost": cost,
                    "start": start_time.timestamp(),
                    "end": end_time.timestamp(),
                }

                api_call_start_time = kwargs.get("api_call_start_time")
                completion_start_time = kwargs.get("completion_start_time")

                if api_call_start_time is not None and completion_start_time is not None:
                    metrics["time_to_first_token"] = completion_start_time.timestamp() - api_call_start_time.timestamp()

            request_data = {
                "id": litellm_call_id,
                "input": prompt["messages"],
                "output": output,
                "metadata": clean_metadata,
                "tags": tags,
                "span_attributes": {"name": "Chat Completion", "type": "llm"},
            }
            if choices is not None:
                request_data["output"] = [choice.dict() for choice in choices]
            else:
                request_data["output"] = output

            if metrics is not None:
                request_data["metrics"] = metrics

            if metrics is not None:
                request_data["metrics"] = metrics

            try:
                await global_braintrust_http_handler.post(
                    url=f"{self.api_base}/project_logs/{project_id}/insert",
                    json={"events": [request_data]},
                    headers=self.headers,
                )
            except httpx.HTTPStatusError as e:
                raise Exception(e.response.text)
        except Exception as e:
            raise e  # don't use verbose_logger.exception, if exception is raised

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        return super().log_failure_event(kwargs, response_obj, start_time, end_time)
