from typing import Literal, Optional, Union

import httpx

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObject
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    _get_httpx_client,
    get_async_httpx_client,
)
from litellm.llms.vertex_ai.vertex_ai_non_gemini import VertexAIError
from litellm.llms.vertex_ai.vertex_llm_base import VertexBase
from litellm.types.llms.vertex_ai import *
from litellm.types.utils import EmbeddingResponse

from .types import *


class VertexEmbedding(VertexBase):
    """
    Embedding handler for Vertex AI and Google AI Studio (``gemini``) models.

    Each call resolves credentials and the endpoint URL through ``VertexBase``,
    converts the OpenAI-style ``input`` into a Vertex embedding request via
    ``litellm.vertexAITextEmbeddingConfig``, POSTs it with an httpx-based
    handler, and converts the JSON reply back via the same config object.
    """

    def __init__(self) -> None:
        super().__init__()

    def embedding(
        self,
        model: str,
        input: Union[list, str],
        print_verbose,
        model_response: EmbeddingResponse,
        optional_params: dict,
        logging_obj: LiteLLMLoggingObject,
        custom_llm_provider: Literal[
            "vertex_ai", "vertex_ai_beta", "gemini"
        ],  # if it's vertex_ai or gemini (google ai studio)
        timeout: Optional[Union[float, httpx.Timeout]],
        api_key: Optional[str] = None,
        encoding=None,
        aembedding=False,
        api_base: Optional[str] = None,
        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
        vertex_project: Optional[str] = None,
        vertex_location: Optional[str] = None,
        vertex_credentials: Optional[str] = None,
        gemini_api_key: Optional[str] = None,
        extra_headers: Optional[dict] = None,
    ) -> EmbeddingResponse:
        """
        Create embeddings with a blocking request, or hand off to ``async_embedding``.

        Args:
            model: Embedding model name.
            input: A single string or a list of inputs to embed.
            print_verbose: Accepted for interface compatibility; not used here.
            model_response: Passed to the response transform, whose result is returned.
            optional_params: Provider params; also consulted to decide whether
                v1beta1 features are needed.
            logging_obj: Receives the pre-call and post-call logging events.
            custom_llm_provider: ``"vertex_ai"``, ``"vertex_ai_beta"`` or ``"gemini"``.
            timeout: Passed to a newly created HTTP client when truthy; not
                applied when an existing ``client`` is reused.
            api_key: Accepted for interface compatibility; not used here
                (``gemini_api_key`` is what is passed to URL/token resolution).
            encoding: Forwarded to ``async_embedding``; not otherwise used.
            aembedding: When ``True``, return the un-awaited coroutine from
                ``async_embedding`` instead of making a blocking request.
            api_base: Optional endpoint override passed to URL resolution.
            client: Reused when it matches the execution mode (``HTTPHandler``
                for sync, ``AsyncHTTPHandler`` for async); otherwise a new
                client is created.
            vertex_project, vertex_location, vertex_credentials: Vertex AI
                routing/auth settings passed to token and URL resolution.
            gemini_api_key: Google AI Studio API key.
            extra_headers: Additional headers passed to ``set_headers``.

        Returns:
            The result of ``transform_vertex_response_to_openai``, or, when
            ``aembedding`` is ``True``, the coroutine from ``async_embedding``.

        Raises:
            VertexAIError: with the HTTP status code and response body on an
                error status, or with status 408 on an httpx timeout.
        """
        if aembedding is True:
            return self.async_embedding(  # type: ignore
                model=model,
                input=input,
                logging_obj=logging_obj,
                model_response=model_response,
                optional_params=optional_params,
                encoding=encoding,
                custom_llm_provider=custom_llm_provider,
                timeout=timeout,
                api_base=api_base,
                # Forward a caller-supplied async client instead of silently
                # discarding it; for anything else async_embedding builds one.
                client=client if isinstance(client, AsyncHTTPHandler) else None,
                vertex_project=vertex_project,
                vertex_location=vertex_location,
                vertex_credentials=vertex_credentials,
                gemini_api_key=gemini_api_key,
                extra_headers=extra_headers,
            )

        should_use_v1beta1_features = self.is_using_v1beta1_features(
            optional_params=optional_params
        )
        _auth_header, vertex_project = self._ensure_access_token(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider=custom_llm_provider,
        )
        headers, api_base, vertex_request = self._build_embedding_request(
            model=model,
            input=input,
            optional_params=optional_params,
            auth_header=_auth_header,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            gemini_api_key=gemini_api_key,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            should_use_v1beta1_features=should_use_v1beta1_features,
            extra_headers=extra_headers,
        )

        # Only a sync handler can be used here; anything else gets a new client.
        if not isinstance(client, HTTPHandler):
            client = _get_httpx_client(params=self._embedding_client_params(timeout))

        self._log_embedding_pre_call(
            logging_obj=logging_obj,
            vertex_request=vertex_request,
            api_base=api_base,
            headers=headers,
        )

        try:
            response = client.post(api_base, headers=headers, json=vertex_request)  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            raise VertexAIError(
                status_code=err.response.status_code, message=err.response.text
            ) from err
        except httpx.TimeoutException as err:
            raise VertexAIError(
                status_code=408, message="Timeout error occurred."
            ) from err

        return self._handle_embedding_response(
            response=response,
            input=input,
            logging_obj=logging_obj,
            model=model,
            model_response=model_response,
        )

    async def async_embedding(
        self,
        model: str,
        input: Union[list, str],
        model_response: litellm.EmbeddingResponse,
        logging_obj: LiteLLMLoggingObject,
        optional_params: dict,
        custom_llm_provider: Literal[
            "vertex_ai", "vertex_ai_beta", "gemini"
        ],  # if it's vertex_ai or gemini (google ai studio)
        timeout: Optional[Union[float, httpx.Timeout]],
        api_base: Optional[str] = None,
        client: Optional[AsyncHTTPHandler] = None,
        vertex_project: Optional[str] = None,
        vertex_location: Optional[str] = None,
        vertex_credentials: Optional[str] = None,
        gemini_api_key: Optional[str] = None,
        extra_headers: Optional[dict] = None,
        encoding=None,
    ) -> litellm.EmbeddingResponse:
        """
        Async embedding implementation.

        Mirrors ``embedding`` but resolves the access token with
        ``_ensure_access_token_async`` and sends the request with an
        ``AsyncHTTPHandler``. ``client`` is reused only when it is an
        ``AsyncHTTPHandler``; otherwise one is created via
        ``get_async_httpx_client`` (with ``timeout`` when truthy).
        ``encoding`` is accepted for interface compatibility and not used.

        Raises:
            VertexAIError: with the HTTP status code and response body on an
                error status, or with status 408 on an httpx timeout.
        """
        should_use_v1beta1_features = self.is_using_v1beta1_features(
            optional_params=optional_params
        )
        _auth_header, vertex_project = await self._ensure_access_token_async(
            credentials=vertex_credentials,
            project_id=vertex_project,
            custom_llm_provider=custom_llm_provider,
        )
        headers, api_base, vertex_request = self._build_embedding_request(
            model=model,
            input=input,
            optional_params=optional_params,
            auth_header=_auth_header,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            gemini_api_key=gemini_api_key,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            should_use_v1beta1_features=should_use_v1beta1_features,
            extra_headers=extra_headers,
        )

        # Only an async handler can be awaited here; anything else gets a new client.
        if not isinstance(client, AsyncHTTPHandler):
            client = get_async_httpx_client(
                params=self._embedding_client_params(timeout),
                llm_provider=litellm.LlmProviders.VERTEX_AI,
            )

        self._log_embedding_pre_call(
            logging_obj=logging_obj,
            vertex_request=vertex_request,
            api_base=api_base,
            headers=headers,
        )

        try:
            response = await client.post(api_base, headers=headers, json=vertex_request)  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            raise VertexAIError(
                status_code=err.response.status_code, message=err.response.text
            ) from err
        except httpx.TimeoutException as err:
            raise VertexAIError(
                status_code=408, message="Timeout error occurred."
            ) from err

        return self._handle_embedding_response(
            response=response,
            input=input,
            logging_obj=logging_obj,
            model=model,
            model_response=model_response,
        )

    # ------------------------------------------------------------------
    # Private helpers shared by the sync and async paths
    # ------------------------------------------------------------------

    def _build_embedding_request(
        self,
        *,
        model: str,
        input: Union[list, str],
        optional_params: dict,
        auth_header,
        vertex_project,
        vertex_location: Optional[str],
        vertex_credentials: Optional[str],
        gemini_api_key: Optional[str],
        custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
        api_base: Optional[str],
        should_use_v1beta1_features,
        extra_headers: Optional[dict],
    ):
        """
        Resolve the endpoint and headers, and build the Vertex embedding request body.

        Returns:
            A ``(headers, api_base, vertex_request)`` tuple. ``api_base`` is the
            URL returned by ``_get_token_and_url``.
        """
        auth_header, api_base = self._get_token_and_url(
            model=model,
            gemini_api_key=gemini_api_key,
            auth_header=auth_header,
            vertex_project=vertex_project,
            vertex_location=vertex_location,
            vertex_credentials=vertex_credentials,
            stream=False,
            custom_llm_provider=custom_llm_provider,
            api_base=api_base,
            should_use_v1beta1_features=should_use_v1beta1_features,
            mode="embedding",
        )
        headers = self.set_headers(auth_header=auth_header, extra_headers=extra_headers)
        vertex_request: VertexEmbeddingRequest = (
            litellm.vertexAITextEmbeddingConfig.transform_openai_request_to_vertex_embedding_request(
                input=input, optional_params=optional_params, model=model
            )
        )
        return headers, api_base, vertex_request

    @staticmethod
    def _embedding_client_params(
        timeout: Optional[Union[float, httpx.Timeout]],
    ) -> dict:
        """Return kwargs for the httpx client factories, including ``timeout`` only when truthy."""
        return {"timeout": timeout} if timeout else {}

    @staticmethod
    def _log_embedding_pre_call(
        logging_obj: LiteLLMLoggingObject,
        vertex_request,
        api_base,
        headers,
    ) -> None:
        """Emit the pre-call logging event for an embedding request (API key is blanked)."""
        logging_obj.pre_call(
            input=vertex_request,
            api_key="",
            additional_args={
                "complete_input_dict": vertex_request,
                "api_base": api_base,
                "headers": headers,
            },
        )

    @staticmethod
    def _handle_embedding_response(
        response,
        input: Union[list, str],
        logging_obj: LiteLLMLoggingObject,
        model: str,
        model_response: EmbeddingResponse,
    ) -> EmbeddingResponse:
        """Log the raw JSON reply and convert it with ``transform_vertex_response_to_openai``."""
        json_response = response.json()
        ## LOGGING POST-CALL
        logging_obj.post_call(
            input=input, api_key=None, original_response=json_response
        )
        return litellm.vertexAITextEmbeddingConfig.transform_vertex_response_to_openai(
            response=json_response, model=model, model_response=model_response
        )
