# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.  See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import logging
import os
import tempfile
from pathlib import Path
from typing import List, Optional, Union

import numpy
import onnx
import torch
from io_binding_helper import TypeHelper
from models.t5.past_helper import PastKeyValuesHelper
from onnx_model import OnnxModel
from torch_onnx_export_helper import torch_onnx_export
from transformers import WhisperConfig, file_utils
from whisper_openai_helper import WhisperDecoderInitOpenai

from onnxruntime import InferenceSession

logger = logging.getLogger(__name__)


class WhisperDecoderInit(torch.nn.Module):
    """A Whisper decoder to create initial past key values.
    This model is only called once during starting decoding.
    """

    def __init__(
        self,
        decoder: torch.nn.Module,
        config: WhisperConfig,
        decoder_start_token_id: Optional[int] = None,
    ):
        super().__init__()
        self.decoder = decoder
        self.config = config
        self.decoder_start_token_id = (
            decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
        )

    def forward(
        self,
        decoder_input_ids: torch.Tensor,
        encoder_hidden_states: torch.FloatTensor,
    ):
        encoder_outputs = file_utils.ModelOutput()
        encoder_outputs["last_hidden_state"] = encoder_hidden_states
        encoder_outputs["hidden_states"] = None
        encoder_outputs["attentions"] = None

        out = self.decoder.model(
            None,
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            past_key_values=None,
            use_cache=True,
            return_dict=True,
        )
        logits = self.decoder.proj_out(out[0])
        return logits, out.past_key_values, out.encoder_last_hidden_state


class WhisperDecoder(torch.nn.Module):
    """A Whisper decoder with past key values"""

    def __init__(self, decoder, config, model_impl: str = "hf", model: torch.nn.Module = None):
        super().__init__()
        self.decoder = decoder
        self.config = config
        self.model_impl = model_impl
        if model is not None:
            self.whisper_decoder_openai_init = WhisperDecoderInitOpenai(model, decoder)

    def forward(self, decoder_input_ids, *past):
        encoder_outputs = file_utils.ModelOutput()
        dummy_encoder_hidden_states = torch.randn((decoder_input_ids.shape[0], 3000, int(self.config.d_model)))
        encoder_outputs["last_hidden_state"] = dummy_encoder_hidden_states
        encoder_outputs["hidden_states"] = dummy_encoder_hidden_states
        encoder_outputs["attentions"] = None

        if self.model_impl == "openai":
            dummy_encoder_hidden_states.unsqueeze(0)
            dec_out, present = self.whisper_decoder_openai_init(
                decoder_input_ids, dummy_encoder_hidden_states, past=past
            )
            return dec_out, present

        if len(past) == 0:
            past_key_values = None
        else:
            past_key_values = PastKeyValuesHelper.back_group_by_layer(past)

        decoder_out = self.decoder(
            None,
            encoder_outputs=encoder_outputs,
            decoder_input_ids=decoder_input_ids,
            past_key_values=past_key_values,
            use_cache=True,
            return_dict=True,
        )
        logits = decoder_out[0]
        present_self, _ = PastKeyValuesHelper.group_by_self_and_cross(decoder_out.past_key_values)
        return logits, present_self


class WhisperDecoderInputs:
    def __init__(
        self,
        decoder_input_ids,
        past_key_values=None,
    ):
        self.decoder_input_ids: torch.LongTensor = decoder_input_ids
        self.past_key_values: Union[List[torch.FloatTensor], List[torch.HalfTensor], None] = past_key_values

    @staticmethod
    def create_dummy(
        config: WhisperConfig,
        batch_size: int,
        encode_sequence_length: int,
        past_decode_sequence_length: int,
        device: torch.device,
        float16: bool = False,
        use_int32_inputs: bool = False,
        model_impl: str = "hf",
    ):  # -> WhisperDecoderInputs:
        """Create dummy inputs for WhisperDecoder.

        Args:
            decoder: decoder
            batch_size (int): batch size
            encode_sequence_length (int): sequence length of input_ids for encoder
            past_decode_sequence_length (int): past sequence length of input_ids for decoder
            device (torch.device): device of output tensors
            float16 (bool): whether the model uses float32 or float16 in input
            use_int32_inputs(bool): whether use int32 instead of int64 for some inputs

        Returns:
            WhisperDecoderInputs: dummy inputs for decoder
        """
        num_attention_heads: int = config.encoder_attention_heads
        num_layers: int = config.decoder_layers  # + config.encoder_layers
        vocab_size: int = config.vocab_size

        # Use head_size, use hidden_size / num_attention_heads here.
        # For example, whisper-large, d_model=1280 and num_heads=20
        head_size: int = config.d_model // config.encoder_attention_heads

        sequence_length: int = 1  # fixed for decoding
        decoder_input_ids = torch.randint(
            low=0,
            high=vocab_size - 1,
            size=(batch_size, sequence_length),
            dtype=(torch.int32 if use_int32_inputs else torch.int64),
            device=device,
        )

        float_type = torch.float16 if float16 else torch.float32

        if past_decode_sequence_length > 0:
            self_attention_past_shape = [
                batch_size,
                num_attention_heads,
                past_decode_sequence_length,
                head_size,
            ]
            cross_attention_past_shape = [
                batch_size,
                num_attention_heads,
                encode_sequence_length if model_impl == "hf" else past_decode_sequence_length,
                head_size,
            ]

            past = []
            for _ in range(2 * num_layers):
                past.append(torch.rand(self_attention_past_shape, dtype=float_type, device=device))

            for _ in range(2 * num_layers):
                past.append(torch.rand(cross_attention_past_shape, dtype=float_type, device=device))
        else:
            past = None

        return WhisperDecoderInputs(decoder_input_ids, past)

    def to_list(self) -> List:
        input_list = [self.decoder_input_ids]
        if self.past_key_values:
            input_list.extend(self.past_key_values)
        return input_list

    def to_fp32(self):
        past = [p.to(dtype=torch.float32) for p in self.past_key_values] if self.past_key_values else None
        return WhisperDecoderInputs(
            self.decoder_input_ids.clone(),
            past,
        )


class WhisperDecoderHelper:
    @staticmethod
    def export_onnx(
        decoder: WhisperDecoder,
        device: torch.device,
        onnx_model_path: str,
        verbose: bool = True,
        use_external_data_format: bool = False,
        use_int32_inputs: bool = False,
    ):
        """Export decoder to ONNX

        Args:
            decoder (Union[WhisperDecoder, WhisperDecoderNoPastState]): decoder object
            device (torch.device): device of decoder object
            onnx_model_path (str): onnx path
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
            use_int32_inputs (bool, optional): use int32 inputs
        """
        assert isinstance(decoder, (WhisperDecoder, WhisperDecoderInit))

        inputs = WhisperDecoderInputs.create_dummy(
            decoder.config,
            batch_size=2,
            encode_sequence_length=3000,
            past_decode_sequence_length=6 if isinstance(decoder, WhisperDecoder) else 0,
            device=device,
            use_int32_inputs=use_int32_inputs,
            model_impl=decoder.model_impl,
        )
        input_list = inputs.to_list()

        # Fix past disappearing bug - duplicate first past entry
        # input_list.insert(2, input_list[2])

        past_names = PastKeyValuesHelper.get_past_names(decoder.config.decoder_layers, present=False)
        present_names = PastKeyValuesHelper.get_past_names(decoder.config.decoder_layers, present=True)
        present_self_names = present_names[: 2 * decoder.config.decoder_layers]

        input_past_names = past_names if isinstance(decoder, WhisperDecoder) else []
        output_present_names = present_self_names if isinstance(decoder, WhisperDecoder) else present_names
        output_names = ["logits", *output_present_names]

        # Shape of input tensors (sequence_length==1):
        #    input_ids: (batch_size, sequence_length)
        #    past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
        #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)

        # Shape of output tensors:
        #    logits: (batch_size, sequence_length, vocab_size)
        #    past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, head_size)
        #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)

        input_names = ["input_ids"]
        input_names.extend(input_past_names)

        dynamic_axes = {
            "input_ids": {0: "batch_size"},
            "encoder_hidden_states": {0: "batch_size", 1: "encode_sequence_length / 2"},
            "logits": {0: "batch_size", 1: "sequence_length"},
        }

        for name in input_past_names:
            dynamic_axes[name] = {
                0: "batch_size",
                2: "past_decode_sequence_length" if "self" in name else "encode_sequence_length",
            }

        for name in output_present_names:
            if "cross" in name:
                dynamic_axes[name] = {0: "batch_size", 2: "encode_sequence_length"}
            else:  # self attention past state
                if isinstance(decoder, WhisperDecoder):
                    dynamic_axes[name] = {
                        0: "batch_size",
                        2: "past_decode_sequence_length + 1",
                    }
                else:
                    dynamic_axes[name] = {
                        0: "batch_size",
                        # 2: 'sequence_length'
                    }

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        with tempfile.TemporaryDirectory() as tmp_dir_name:
            temp_onnx_model_path = os.path.join(tmp_dir_name, "decoder.onnx")
            Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
            torch_onnx_export(
                decoder,
                args=tuple(input_list),
                f=temp_onnx_model_path if use_external_data_format else onnx_model_path,
                export_params=True,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                opset_version=17,
                do_constant_folding=True,
                use_external_data_format=use_external_data_format,
                verbose=verbose,
            )

            if use_external_data_format:
                model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
                OnnxModel.save(
                    model,
                    onnx_model_path,
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                )

    @staticmethod
    def onnxruntime_inference(ort_session, inputs: WhisperDecoderInputs):
        """Run inference of ONNX model."""
        logger.debug("start onnxruntime_inference")

        ort_inputs = {
            "input_ids": numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy()),
        }

        if inputs.past_key_values:
            assert len(inputs.past_key_values) % 4 == 0
            num_layers = int(len(inputs.past_key_values) / 4)
            past_names = PastKeyValuesHelper.get_past_names(num_layers)
            for i, past_tensor in enumerate(inputs.past_key_values):
                ort_inputs[past_names[i]] = numpy.ascontiguousarray(past_tensor.cpu().numpy())

        ort_outputs = ort_session.run(None, ort_inputs)
        return ort_outputs

    @staticmethod
    def verify_onnx(
        model: Union[WhisperDecoder, WhisperDecoderInit],
        ort_session: InferenceSession,
        device: torch.device,
        use_int32_inputs: bool,
        max_cases: int = 4,
    ):
        """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
        float16: bool = TypeHelper.get_input_type(ort_session, "past_key_self_0") == "tensor(float16)"

        test_cases = [(4, 11, 3), (1, 2, 5), (3, 1, 1), (8, 5, 2)]
        test_cases_max_diff = []
        for (
            batch_size,
            encode_sequence_length,
            past_decode_sequence_length,
        ) in test_cases[:max_cases]:
            if isinstance(model, WhisperDecoderInit):
                dec_seq_len = 0
            else:
                dec_seq_len = past_decode_sequence_length

            inputs = WhisperDecoderInputs.create_dummy(
                model.config,
                batch_size,
                encode_sequence_length,
                dec_seq_len,
                device=device,
                float16=float16,
                use_int32_inputs=use_int32_inputs,
            )

            # We use fp32 PyTroch model as baseline even when ONNX model is fp16
            input_list = inputs.to_fp32().to_list()

            # Run inference of PyTorch model
            with torch.no_grad():
                torch_outputs = model(*input_list)

            ort_outputs = WhisperDecoderHelper.onnxruntime_inference(ort_session, inputs)

            max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
            max_diff_all = max_diff
            logger.debug(f"logits max_diff={max_diff}")

            for i in range(2 * model.config.num_layers):
                max_diff = numpy.amax(numpy.abs(torch_outputs[1][i].cpu().numpy() - ort_outputs[1 + i]))
                logger.debug(f"self attention past state {i} max_diff={max_diff}")
                max_diff_all = max(max_diff_all, max_diff)

            if isinstance(model, WhisperDecoderInit):
                for i in range(2 * model.config.num_layers):
                    max_diff = numpy.amax(
                        numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[1 + 2 * model.config.num_layers + i])
                    )
                    logger.debug(f"cross attention past state {i} max_diff={max_diff}")
                    max_diff_all = max(max_diff_all, max_diff)

            test_cases_max_diff.append(max_diff_all)
            logger.info(
                "batch_size=%s, encode_sequence_length=%s, past_decode_sequence_length=%s, max_diff=%s",
                batch_size,
                encode_sequence_length,
                past_decode_sequence_length,
                max_diff_all,
            )

        return max_diff_all
