# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import argparse
import logging
import os
from typing import List, Union

import coloredlogs
from constants import (
    AttentionInputIDs,
    AttentionOutputIDs,
    MultiHeadAttentionInputIDs,
    MultiHeadAttentionOutputIDs,
    Operators,
)
from onnx import helper, load_model
from onnx_model import NodeProto, OnnxModel
from shape_infer_helper import SymbolicShapeInferenceHelper

# Module-level logger, named after this module; used by every class and function below.
logger = logging.getLogger(__name__)


class PackingAttentionBase:
    """Common driver that rewrites an attention graph to run in packing mode.

    The conversion wraps the attention stack between a ``RemovePadding`` node
    (inserted in front of the first attention node) and a ``RestorePadding``
    node (inserted after the last LayerNormalization / SkipLayerNormalization
    node), and asks the subclass to swap each attention node for its packed
    counterpart. Subclasses must implement ``_are_attentions_supported`` and
    ``_replace_attention_with_packing_attention``.
    """

    def __init__(self, model: OnnxModel, attention_op_type: str):
        """Remember the model and collect every node of ``attention_op_type``."""
        self.model: OnnxModel = model
        # Pending graph edits; they are applied in one go at the end of convert().
        self.nodes_to_remove: List = []
        self.nodes_to_add: List = []
        self.prune_graph: bool = False
        # Maps the name of each node queued in nodes_to_add to the graph it belongs to.
        self.node_name_to_graph_name: dict = {}
        self.this_graph_name: str = self.model.model.graph.name
        self.attention_op_type = attention_op_type
        self.attention_nodes = self.model.get_nodes_by_op_type(attention_op_type)

    def _try_getting_attention_mask(self) -> Union[str, None]:
        """Return the mask input shared by all attention nodes, or None.

        None is returned when there is no attention node, when any attention
        node lacks the mask input slot, or when the nodes do not all use the
        same mask name.
        """
        if self.attention_op_type == Operators.ATTENTION:
            mask_index = AttentionInputIDs.MASK_INDEX
        else:
            mask_index = MultiHeadAttentionInputIDs.KEY_PADDING_MASK

        first = self._try_getting_first_attention()
        # The first attention node must exist and expose the mask slot.
        if not first or len(first.input) <= mask_index:
            return None

        shared_mask = first.input[mask_index]

        # Every attention node has to consume exactly the same mask.
        same_mask_everywhere = all(
            len(node.input) > mask_index and node.input[mask_index] == shared_mask for node in self.attention_nodes
        )
        return shared_mask if same_mask_everywhere else None

    def _try_getting_first_attention(self) -> Union[NodeProto, None]:
        """Return the first collected attention node, or None when there is none."""
        return self.attention_nodes[0] if len(self.attention_nodes) > 0 else None

    def _try_getting_last_layernorm(self) -> Union[NodeProto, None]:
        """Return the last (Skip)LayerNormalization node in ``self.model.nodes()`` order, or None."""
        layernorm_ops = (Operators.LAYERNORM, Operators.SKIPLAYERNORM)
        found = None
        for candidate in self.model.nodes():
            if candidate.op_type in layernorm_ops:
                found = candidate
        return found

    def _are_attentions_supported(self) -> bool:
        """Tell whether every attention node can be converted (implemented by subclasses)."""
        raise NotImplementedError()

    def _add_contrib_node(self, op_type: str, inputs: List[str], outputs: List[str]) -> None:
        """Build a ``com.microsoft`` node of ``op_type`` and queue it for insertion into this graph."""
        node_name = self.model.create_node_name(op_type)
        node = helper.make_node(op_type, inputs=inputs, outputs=outputs, name=node_name)
        node.domain = "com.microsoft"
        self.nodes_to_add.append(node)
        self.node_name_to_graph_name[node.name] = self.this_graph_name

    def _insert_removepadding_node(self, inputs: List[str], outputs: List[str]) -> None:
        """Queue a RemovePadding node with the given input and output names."""
        self._add_contrib_node(Operators.REMOVEPADDING, inputs, outputs)

    def _insert_restorepadding_node(self, inputs: List[str], outputs: List[str]) -> None:
        """Queue a RestorePadding node with the given input and output names."""
        self._add_contrib_node(Operators.RESTOREPADDING, inputs, outputs)

    def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None:
        """Queue the packed replacement of every attention node (implemented by subclasses)."""
        raise NotImplementedError()

    def _get_input_to_remove_padding(self, first_attention_node) -> Union[str, None]:
        """Return the tensor name that RemovePadding should consume, or None if unknown.

        The base implementation only handles the Attention operator, for which
        it is the ``INPUT`` slot of the first attention node.
        """
        if self.attention_op_type != Operators.ATTENTION:
            return None
        return first_attention_node.input[AttentionInputIDs.INPUT]

    def convert(self, use_symbolic_shape_infer: bool = True) -> None:
        """Convert ``self.model`` to packing mode in place.

        The method returns without touching the graph when the attention nodes
        are unsupported, when they do not share one mask, when no
        (Skip)LayerNormalization node exists, or when the input to unpad
        cannot be determined.

        Args:
            use_symbolic_shape_infer: when True, run symbolic shape inference
                on the converted model and keep its result if one is returned.
        """
        logger.debug("start converting to packing model...")

        if not self._are_attentions_supported():
            return

        mask_name = self._try_getting_attention_mask()
        if not mask_name:
            return

        first_attention = self._try_getting_first_attention()
        final_layernorm = self._try_getting_last_layernorm()
        if not final_layernorm:
            return

        padded_input = self._get_input_to_remove_padding(first_attention)
        if not padded_input:
            return

        # --- RemovePadding: its four outputs are all named after the padded input.
        unpadded_output, token_offset, cumulated_seq_len, max_seq_len = (
            padded_input + suffix
            for suffix in ("_no_padding", "_token_offset", "_cumulated_seq_len", "_max_seq_len")
        )
        self._insert_removepadding_node(
            [padded_input, mask_name],
            [unpadded_output, token_offset, cumulated_seq_len, max_seq_len],
        )
        # The new node is only queued (not in the graph yet), so this rewires
        # existing consumers only and leaves RemovePadding's own input alone.
        self.model.replace_input_of_all_nodes(padded_input, unpadded_output)
        logger.debug("inserted RemovePadding before Attention")

        # --- RestorePadding: it takes over the original output name of the last layernorm.
        final_output = final_layernorm.output[0]
        restore_input = final_output + "_restore_input"
        self._insert_restorepadding_node([restore_input, token_offset], [final_output])
        self.model.replace_output_of_all_nodes(final_output, restore_input)
        logger.debug(f"inserted RestorePadding after last {final_layernorm.op_type} layer")

        # --- Packed attention nodes (subclass specific).
        self._replace_attention_with_packing_attention(token_offset, cumulated_seq_len)
        logger.debug(f"replaced {self.attention_op_type} with Packed{self.attention_op_type}")

        # Apply all queued edits.
        self.model.remove_nodes(self.nodes_to_remove)
        self.model.add_nodes(self.nodes_to_add, self.node_name_to_graph_name)

        if self.prune_graph:
            self.model.prune_graph()
        elif self.nodes_to_remove or self.nodes_to_add:
            self.model.update_graph()
        self.model.clean_shape_infer()

        if not use_symbolic_shape_infer:
            return

        # Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc)
        # are not recognized by onnx shape inference.
        shape_inference = SymbolicShapeInferenceHelper(self.model.model, verbose=0)
        inferred = shape_inference.infer_shapes(self.model.model, auto_merge=True, guess_output_rank=False)
        if inferred:
            self.model.model = inferred


class PackingAttention(PackingAttentionBase):
    """Packing-mode conversion for graphs built on the ``Attention`` operator."""

    def __init__(self, model: OnnxModel):
        super().__init__(model, Operators.ATTENTION)

    def _are_attentions_supported(self) -> bool:
        """Return True only if every Attention node can be replaced by PackedAttention.

        A node is rejected when it carries the ``past_present_share_buffer`` or
        ``do_rotary`` attribute, when ``unidirectional`` is set to a non-zero
        value, or when it has a non-empty ``past`` / ``past_sequence_length``
        input.
        """
        for node in self.attention_nodes:
            if OnnxModel.get_node_attribute(node, "past_present_share_buffer") is not None:
                return False
            if OnnxModel.get_node_attribute(node, "do_rotary") is not None:
                return False
            unidirection_attr = OnnxModel.get_node_attribute(node, "unidirectional")
            if unidirection_attr is not None and unidirection_attr != 0:
                return False
            # The replacement node built below forwards only input, weights, bias and
            # attention_bias, so a past state that is actually connected (non-empty
            # name) would be dropped. An empty name only marks an omitted optional
            # input (e.g. when attention_bias follows it) and is fine.
            if len(node.input) > AttentionInputIDs.PAST and node.input[AttentionInputIDs.PAST]:
                return False
            if (
                len(node.input) > AttentionInputIDs.PAST_SEQUENCE_LENGTH
                and node.input[AttentionInputIDs.PAST_SEQUENCE_LENGTH]
            ):
                return False
        return True

    def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None:
        """Queue a PackedAttention node for each Attention node and mark the original for removal.

        Args:
            token_offset: name of the token offset tensor produced by RemovePadding.
            cumulative_sequence_length: name of the cumulative sequence length
                tensor produced by RemovePadding.
        """
        for attention in self.attention_nodes:
            # attention_bias is optional; pass an empty name when the slot is absent.
            attention_bias = (
                attention.input[AttentionInputIDs.ATTENTION_BIAS]
                if len(attention.input) > AttentionInputIDs.ATTENTION_BIAS
                else ""
            )
            packed_attention = helper.make_node(
                Operators.PACKEDATTENTION,
                inputs=[
                    attention.input[AttentionInputIDs.INPUT],
                    attention.input[AttentionInputIDs.WEIGHTS],
                    attention.input[AttentionInputIDs.BIAS],
                    token_offset,
                    cumulative_sequence_length,
                    attention_bias,
                ],
                outputs=[attention.output[AttentionOutputIDs.OUTPUT]],
                name=self.model.create_node_name(Operators.PACKEDATTENTION),
            )

            # Only these attributes are copied over to the packed node.
            packed_attention.attribute.extend(
                attr for attr in attention.attribute if attr.name in ["num_heads", "qkv_hidden_sizes", "scale"]
            )
            packed_attention.domain = "com.microsoft"
            self.nodes_to_add.append(packed_attention)
            self.nodes_to_remove.append(attention)
            self.node_name_to_graph_name[packed_attention.name] = self.this_graph_name

        logger.info("Converted %d Attention nodes to PackedAttention.", len(self.attention_nodes))


class PackingMultiHeadAttention(PackingAttentionBase):
    """Packing-mode conversion for graphs built on the ``MultiHeadAttention`` operator."""

    # Attributes that may appear on a MultiHeadAttention node and are copied to the packed node.
    _SUPPORTED_ATTRIBUTES = ("num_heads", "mask_filter_value", "scale")

    def __init__(self, model: OnnxModel):
        super().__init__(model, Operators.MULTI_HEAD_ATTENTION)

    @staticmethod
    def _optional_input(node, index: int) -> str:
        """Return the input name at ``index``, or an empty string when the node has no such slot."""
        return node.input[index] if len(node.input) > index else ""

    def _check_empty_input(self, node, index: int, name: str):
        """Check a node does not have given input.

        Returns True when the input slot is absent or empty; otherwise logs an
        error mentioning ``name`` and returns False.
        """
        if len(node.input) > index and len(node.input[index]) > 0:
            logger.error(f"node input {index} ({name}) is not supported in PackedMultiHeadAttention: {node}")
            return False
        return True

    def _check_empty_output(self, node, index: int, name: str):
        """Check a node does not have given output.

        Returns True when the output slot is absent or empty; otherwise logs an
        error mentioning ``name`` and returns False.
        """
        if len(node.output) > index and len(node.output[index]) > 0:
            logger.error(f"node output {index} ({name}) is not supported in PackedMultiHeadAttention: {node}")
            return False
        return True

    def _are_attentions_supported(self) -> bool:
        """Return True only if every MultiHeadAttention node can be converted.

        A node is rejected (with an error log) when it has an attribute outside
        the supported set, when it has a key input but no value input, or when
        any past/present key/value input or output is connected.
        """
        for node in self.attention_nodes:
            for attr in node.attribute:
                if attr.name not in self._SUPPORTED_ATTRIBUTES:
                    logger.error(f"node attribute {attr.name} is not supported in PackedMultiHeadAttention: {node}")
                    return False

            # Trailing optional inputs may be omitted entirely, so read key/value
            # through the length-checked accessor instead of indexing directly.
            key = self._optional_input(node, MultiHeadAttentionInputIDs.KEY)
            value = self._optional_input(node, MultiHeadAttentionInputIDs.VALUE)
            if key and not value:
                logger.error("packed kv format is not supported in PackedMultiHeadAttention")
                return False

            if not (
                self._check_empty_input(node, MultiHeadAttentionInputIDs.PAST_KEY, "past_key")
                and self._check_empty_input(node, MultiHeadAttentionInputIDs.PAST_VALUE, "past_value")
                and self._check_empty_output(node, MultiHeadAttentionOutputIDs.PRESENT_KEY, "present_key")
                and self._check_empty_output(node, MultiHeadAttentionOutputIDs.PRESENT_VALUE, "present_value")
            ):
                return False

        return True

    def _replace_attention_with_packing_attention(self, token_offset: str, cumulative_sequence_length: str) -> None:
        """Queue a PackedMultiHeadAttention node for each MultiHeadAttention node.

        The original nodes are marked for removal. When the attention bias of a
        node is produced by a six-input GatedRelativePositionBias node, the
        token offset is appended to that node as an extra input.

        Args:
            token_offset: name of the token offset tensor produced by RemovePadding.
            cumulative_sequence_length: name of the cumulative sequence length
                tensor produced by RemovePadding.
        """
        gated_relative_pos_bias_count = 0
        for mha in self.attention_nodes:
            attention_bias = self._optional_input(mha, MultiHeadAttentionInputIDs.ATTENTION_BIAS)
            packed_mha = helper.make_node(
                Operators.PACKED_MULTI_HEAD_ATTENTION,
                inputs=[
                    mha.input[MultiHeadAttentionInputIDs.QUERY],
                    self._optional_input(mha, MultiHeadAttentionInputIDs.KEY),
                    self._optional_input(mha, MultiHeadAttentionInputIDs.VALUE),
                    self._optional_input(mha, MultiHeadAttentionInputIDs.BIAS),
                    token_offset,
                    cumulative_sequence_length,
                    attention_bias,
                ],
                outputs=[mha.output[MultiHeadAttentionOutputIDs.OUTPUT]],
                name=self.model.create_node_name(Operators.PACKED_MULTI_HEAD_ATTENTION),
            )

            packed_mha.attribute.extend(attr for attr in mha.attribute if attr.name in self._SUPPORTED_ATTRIBUTES)
            packed_mha.domain = "com.microsoft"
            self.nodes_to_add.append(packed_mha)
            self.nodes_to_remove.append(mha)
            self.node_name_to_graph_name[packed_mha.name] = self.this_graph_name

            # Append token_offset input to GatedRelativePositionBias
            if attention_bias:
                rel_pos_bias_node = self.model.get_parent(mha, MultiHeadAttentionInputIDs.ATTENTION_BIAS)
                # The input-count check also keeps a bias node shared by several
                # attention nodes from receiving token_offset more than once.
                if (
                    rel_pos_bias_node
                    and rel_pos_bias_node.op_type == "GatedRelativePositionBias"
                    and len(rel_pos_bias_node.input) == 6
                ):
                    rel_pos_bias_node.input.append(token_offset)
                    gated_relative_pos_bias_count += 1

        logger.info("Converted %d MultiHeadAttention nodes to PackedMultiHeadAttention.", len(self.attention_nodes))
        logger.info("Converted %d GatedRelativePositionBias nodes to packing mode.", gated_relative_pos_bias_count)

    def _get_input_to_remove_padding(self, first_attention_node) -> Union[str, None]:
        """Return the first input of the MatMul feeding input 0 of the node, or None if there is no such MatMul."""
        # When there are query, key and value inputs, we need to find the first input of the parent MatMul node.
        matmul = self.model.get_parent(first_attention_node, 0)
        if matmul and matmul.op_type == "MatMul":
            return matmul.input[0]
        return None


class PackingMode:
    """Entry point that picks the packing converter matching the attention operator in the model."""

    def __init__(self, model: OnnxModel):
        self.model = model

    def convert(self, use_symbolic_shape_infer: bool = True) -> None:
        """Convert the model to packing mode, dispatching on the attention operator it contains.

        An error is logged and nothing is converted when the graph contains
        both Attention and MultiHeadAttention nodes, or neither of them.

        Args:
            use_symbolic_shape_infer: forwarded to the selected converter.
        """
        attention_nodes = self.model.get_nodes_by_op_type(Operators.ATTENTION)
        mha_nodes = self.model.get_nodes_by_op_type(Operators.MULTI_HEAD_ATTENTION)

        if attention_nodes and mha_nodes:
            logger.error("Packing mode does not support both Attention and MultiHeadAttention in same graph.")
            return None

        if not attention_nodes and not mha_nodes:
            logger.error("Packing mode requires either Attention or MultiHeadAttention node in onnx graph.")
            return None

        # Exactly one of the two operator kinds is present at this point.
        converter_type = PackingAttention if attention_nodes else PackingMultiHeadAttention
        return converter_type(self.model).convert(use_symbolic_shape_infer)


def _parse_arguments():
    """Parse the command line and return the resulting argparse namespace.

    ``--input`` and ``--output`` are required paths; ``--verbose`` and
    ``--use_external_data_format`` are boolean switches that default to False.
    """
    parser = argparse.ArgumentParser(
        description="Convert to packing mode tool for ONNX Runtime. It converts BERT like model to use packing mode."
    )

    # Required model paths.
    for flag, help_text in (
        ("--input", "input onnx model path"),
        ("--output", "optimized onnx model path"),
    ):
        parser.add_argument(flag, required=True, type=str, help=help_text)

    # Optional boolean switches.
    parser.add_argument("--verbose", required=False, action="store_true", help="show debug information.")
    parser.add_argument(
        "--use_external_data_format",
        required=False,
        action="store_true",
        help="use external data format to store large model (>2GB)",
    )
    parser.set_defaults(verbose=False, use_external_data_format=False)

    return parser.parse_args()


def _setup_logger(verbose):
    """Install coloredlogs, with DEBUG level and a detailed format when ``verbose`` is truthy."""
    if not verbose:
        coloredlogs.install(fmt="%(funcName)20s: %(message)s")
        return

    coloredlogs.install(
        level="DEBUG",
        fmt="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s",
    )


def main():
    """Command-line entry point: load the input model, convert it to packing mode and save it.

    The model is written to the output path whether or not the conversion
    changed anything. A warning is logged when input and output resolve to the
    same path.
    """
    cli_args = _parse_arguments()
    _setup_logger(cli_args.verbose)
    logger.debug(f"arguments:{cli_args}")

    source_path = cli_args.input
    target_path = cli_args.output
    if os.path.realpath(source_path) == os.path.realpath(target_path):
        logger.warning("Specified the same input and output path. Note that this may overwrite the original model")

    onnx_model = OnnxModel(load_model(source_path))
    converter = PackingMode(onnx_model)
    converter.convert()
    converter.model.save_model_to_file(target_path, use_external_data_format=cli_args.use_external_data_format)


# Run the command-line conversion only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
