"""Code generated by Speakeasy (https://speakeasyapi.dev). DO NOT EDIT."""

import json
from typing import Optional, Tuple, Union

import google.auth
import google.auth.credentials
import google.auth.transport
import google.auth.transport.requests
import httpx
from mistralai_gcp import models
from mistralai_gcp._hooks import BeforeRequestHook, SDKHooks
from mistralai_gcp.chat import Chat
from mistralai_gcp.fim import Fim
from mistralai_gcp.types import Nullable

from .basesdk import BaseSDK
from .httpclient import AsyncHttpClient, HttpClient
from .sdkconfiguration import SDKConfiguration
from .utils.logger import Logger, NoOpLogger
from .utils.retries import RetryConfig

# Models whose identifier in the request URL uses the legacy "<name>@<version>"
# form, keyed by the dash-separated model name that callers pass in.
LEGACY_MODEL_ID_FORMAT = {
    "codestral-2405": "codestral@2405",
    "mistral-large-2407": "mistral-large@2407",
    "mistral-nemo-2407": "mistral-nemo@2407",
}


def get_model_info(model: str) -> Tuple[str, str]:
    """Return the ``(model_name, model_id)`` pair for *model*.

    For a model listed in ``LEGACY_MODEL_ID_FORMAT`` the name is the input
    with its trailing ``-<segment>`` removed and the ID is the mapped legacy
    value. Any other input is returned unchanged in both positions.

    :param model: The model identifier supplied by the caller
    :return: A tuple of ``(model_name, model_id)``
    """
    if model not in LEGACY_MODEL_ID_FORMAT:
        # Not a legacy model: the name and the ID are the same string.
        return model, model

    # Everything before the last "-" is the versionless model name.
    versionless_name, _, _ = model.rpartition("-")
    return versionless_name, LEGACY_MODEL_ID_FORMAT[model]



class MistralGoogleCloud(BaseSDK):
    r"""Mistral AI API: Our Chat Completion and Embeddings APIs specification. Create your account on [La Plateforme](https://console.mistral.ai) to get access and read the [docs](https://docs.mistral.ai) to learn how to use it."""

    chat: Chat
    r"""Chat Completion API"""
    fim: Fim
    r"""Fill-in-the-middle (FIM) completion API"""

    def __init__(
        self,
        region: str = "europe-west4",
        project_id: Optional[str] = None,
        access_token: Optional[str] = None,
        client: Optional[HttpClient] = None,
        async_client: Optional[AsyncHttpClient] = None,
        retry_config: Optional[Nullable[RetryConfig]] = None,
        debug_logger: Optional[Logger] = None,
    ) -> None:
        r"""Instantiates the SDK configuring it with the provided parameters.

        :param region: The Google Cloud region to use for all methods
        :param project_id: The project ID to use for all methods. When omitted and no
            ``access_token`` is given, the project returned by ``google.auth.default()`` is used
        :param access_token: A static access token sent with every request. When provided,
            ``google.auth.default()`` is not consulted and the token is never refreshed
        :param client: The HTTP client to use for all synchronous methods
        :param async_client: The Async HTTP client to use for all asynchronous methods
        :param retry_config: The retry configuration to use for all supported methods
        :param debug_logger: The logger to use for debug output; defaults to a no-op logger
        :raises models.SDKError: If no project ID can be determined, or if the default
            credentials are not ``google.auth.credentials.Credentials``
        """

        if not access_token:
            credentials, loaded_project_id = google.auth.default(
                scopes=["https://www.googleapis.com/auth/cloud-platform"],
            )

            # Validate the credentials object *before* calling methods on it.
            if not isinstance(credentials, google.auth.credentials.Credentials):
                raise models.SDKError(
                    "credentials must be an instance of google.auth.credentials.Credentials"
                )

            # Fetch an initial token eagerly so misconfiguration surfaces at
            # construction time rather than on the first API call.
            credentials.refresh(google.auth.transport.requests.Request())

            project_id = project_id or loaded_project_id

        if project_id is None:
            raise models.SDKError("project_id must be provided")

        def auth_token() -> str:
            """Return the bearer token to send with the next request."""
            if access_token:
                return access_token
            # Only refresh when the cached token is missing or expired, so a
            # still-valid token does not cost a refresh on every API call.
            if not credentials.valid:
                credentials.refresh(google.auth.transport.requests.Request())
            token = credentials.token
            if not token:
                raise models.SDKError("Failed to get token from credentials")
            return token

        def security() -> models.Security:
            """Build the security model lazily so each request gets a current token."""
            return models.Security(api_key=auth_token())

        if client is None:
            client = httpx.Client()

        assert issubclass(
            type(client), HttpClient
        ), "The provided client must implement the HttpClient protocol."

        if async_client is None:
            async_client = httpx.AsyncClient()

        assert issubclass(
            type(async_client), AsyncHttpClient
        ), "The provided async_client must implement the AsyncHttpClient protocol."

        if debug_logger is None:
            debug_logger = NoOpLogger()

        BaseSDK.__init__(
            self,
            SDKConfiguration(
                client=client,
                async_client=async_client,
                security=security,
                server_url=f"https://{region}-aiplatform.googleapis.com",
                server=None,
                retry_config=retry_config,
                debug_logger=debug_logger,
            ),
        )

        hooks = SDKHooks()

        # Rewrites each request's URL path to include project, region and model.
        hook = GoogleCloudBeforeRequestHook(region, project_id)
        hooks.register_before_request_hook(hook)

        # Init hooks may replace the server URL and/or the HTTP client.
        current_server_url, *_ = self.sdk_configuration.get_server_details()
        server_url, self.sdk_configuration.client = hooks.sdk_init(
            current_server_url, self.sdk_configuration.client
        )
        if current_server_url != server_url:
            self.sdk_configuration.server_url = server_url

        # pylint: disable=protected-access
        self.sdk_configuration.__dict__["_hooks"] = hooks

        self._init_sdks()

    def _init_sdks(self) -> None:
        """Create the sub-SDKs that share this instance's configuration."""
        self.chat = Chat(self.sdk_configuration)
        self.fim = Fim(self.sdk_configuration)


class GoogleCloudBeforeRequestHook(BeforeRequestHook):
    """Before-request hook that rewrites the URL path for Google Cloud.

    The region, project ID and model are templated into the request path so
    that callers only have to supply the model in the request body.
    """

    def __init__(self, region: str, project_id: str) -> None:
        """Store the region and project ID used to build request paths.

        :param region: The Google Cloud region inserted into the URL path
        :param project_id: The project ID inserted into the URL path
        """
        self.region = region
        self.project_id = project_id

    def before_request(
        self, hook_ctx, request: httpx.Request
    ) -> Union[httpx.Request, Exception]:
        """Return a copy of *request* targeting the model-specific endpoint.

        The JSON body's ``model`` field is read to build the URL path and is
        rewritten to the model name returned by ``get_model_info``.

        :param hook_ctx: The hook context supplied by the SDK (unused here)
        :param request: The outgoing request to rewrite
        :return: A new request with the rewritten path and body
        :raises models.SDKError: If the request carries no model
        """
        # The goal of this function is to template in the region, project and model into the URL path
        # We do this here so that the API remains more user-friendly
        model_id = None
        new_content = None
        if request.content:
            parsed = json.loads(request.content.decode("utf-8"))
            model_name, model_id = get_model_info(parsed.get("model"))
            parsed["model"] = model_name
            new_content = json.dumps(parsed).encode("utf-8")

        # Covers an empty body, a missing or null "model" key, and an empty
        # string; otherwise the path would be built with the literal "None".
        if not model_id:
            raise models.SDKError("model must be provided")

        # Keep the streaming variant of the endpoint if the original path used it.
        stream = "streamRawPredict" in request.url.path
        specifier = "streamRawPredict" if stream else "rawPredict"
        url = f"/v1/projects/{self.project_id}/locations/{self.region}/publishers/mistralai/models/{model_id}:{specifier}"

        headers = dict(request.headers)
        # Delete content-length header as it will need to be recalculated
        headers.pop("content-length", None)

        return httpx.Request(
            method=request.method,
            url=request.url.copy_with(path=url),
            headers=headers,
            content=new_content,
            stream=None,
        )
