# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from logging import getLogger

from fusion_base import Fusion
from fusion_utils import NumpyHelper
from onnx import helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionSkipLayerNormalization(Fusion):
    """
    Fuse Add + LayerNormalization into one node: SkipLayerNormalization
    Note: This fusion does not check the input shape of Add and LayerNormalization.
    """

    def __init__(
        self,
        model: OnnxModel,
        fused_op_type: str = "SkipLayerNormalization",
        search_op_types: str = "LayerNormalization",
        shape_infer: bool = True,
    ):
        super().__init__(model, fused_op_type, search_op_types)
        if shape_infer:
            # Re-run shape inference since other fusions may have added new edges that do not have shape info yet.
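            # The batch_size and seq_len values below are example dimensions used only to drive symbolic shape inference.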
            self.shape_infer_helper = self.model.infer_runtime_shape({"batch_size": 4, "seq_len": 7}, update=True)
            if self.shape_infer_helper is None:
                # TODO(tianleiwu): support subgraph in shape inference or add broadcasting in SkipLayerNormalization op.
                logger.warning("symbolic shape inference disabled or failed.")

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        add = self.model.get_parent(node, 0, output_name_to_node)

        # In some models there is a pattern input_ids -> Gather -> Add -> LayerNormalization, where one input
        # of the Add node is an initializer with a fixed shape. Such an Add should not be fused into SkipLayerNormalization.
        if add is None or add.op_type != "Add":
            return

        # The number of inputs of add should be 2
        if len(add.input) != 2:
            return

        for add_input in add.input:
            if self.model.get_initializer(add_input) is not None:
                return

        # If the Add node feeds two LayerNormalization children, only one of them can be fused into
        # SkipLayerNormalization; skip the fusion when the Add is already scheduled for removal.
        if add in self.nodes_to_remove:
            return

        # SimplifiedLayerNormalization is Root Mean Square (RMS) Layer Normalization
        simplified = node.op_type == "SimplifiedLayerNormalization"

        if hasattr(self, "shape_infer_helper"):
            if self.shape_infer_helper is not None:
                input_0_shape = self.shape_infer_helper.get_edge_shape(add.input[0])
                if input_0_shape and len(input_0_shape) != 3:
                    logger.debug("skip SkipLayerNormalization fusion since shape of input %s is not 3D", add.input[0])
                    return

                # TODO(tianleiwu): support broadcasting Skip shape (1, sequence_length, hidden_size) or (sequence_length, hidden_size)
                if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]):
                    logger.debug(
                        "skip SkipLayerNormalization fusion since shape of inputs (%s, %s) are not same",
                        add.input[0],
                        add.input[1],
                    )
                    return
            else:
                logger.debug("skip SkipLayerNormalization fusion since symbolic shape inference failed")
                return

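        # If a parent of the Add is a Gather whose indices are neither a graph input nor produced by a
        # ConstantOfShape node, skip the fusion.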
        gather_path = self.model.match_parent_path(add, ["Gather"], [None])
        if gather_path is not None and self.model.find_graph_input(gather_path[0].input[1]) is None:
            if self.model.match_parent_path(gather_path[0], ["ConstantOfShape"], [1]) is None:
                return

        # residual_add_has_multiple_consumers is True when the residual Add before the LayerNormalization
        # produces an output that is consumed by other nodes or is a graph output, in addition to the
        # LayerNormalization itself. We can still go ahead with the SkipLayerNormalization fusion, but we
        # need to preserve the Add's output, which must then be produced by the fused SkipLayerNormalization node.
        add_has_graph_output = self.model.find_graph_output(add.output[0]) is not None
        residual_add_has_multiple_consumers = (
            add_has_graph_output or len(self.model.get_children(add, input_name_to_nodes)) > 1
        )

        # Track the outputs that must be preserved after fusion. Use a copy so that the
        # LayerNormalization node's own output list is not mutated.
        outputs_to_keep = list(node.output)

        if residual_add_has_multiple_consumers:
            outputs_to_keep.append(add.output[0])

        outputs = [node.output[0]]

        # Leave the optional mean and inverse-std-dev outputs empty so that the Add's output maps to
        # the input_skip_bias_sum output of SkipLayerNormalization
        if residual_add_has_multiple_consumers:
            outputs.extend(["", "", add.output[0]])

        if self.model.is_safe_to_fuse_nodes([add, node], outputs_to_keep, input_name_to_nodes, output_name_to_node):
            self.nodes_to_remove.extend([add, node])

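            # The fused inputs are (input, skip, gamma, beta). SimplifiedLayerNormalization (RMSNorm)
            # has no beta input, so only gamma is carried over in that case.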
            inputs = (
                [add.input[0], add.input[1], node.input[1], node.input[2]]
                if not simplified
                else [add.input[0], add.input[1], node.input[1]]
            )
            normalize_node = helper.make_node(
                self.fused_op_type,
                inputs=inputs,
                outputs=outputs,
                name=self.model.create_node_name(self.fused_op_type, name_prefix="SkipLayerNorm"),
            )
            normalize_node.domain = "com.microsoft"

            # Pass attribute "epsilon" from layernorm node to SkipLayerNormalization
            for att in node.attribute:
                if att.name == "epsilon":
                    normalize_node.attribute.extend([att])

            # Set default epsilon if no epsilon exists from layernorm
            if len(normalize_node.attribute) == 0:
                normalize_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)])

            self.nodes_to_add.append(normalize_node)
            self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name


class FusionBiasSkipLayerNormalization(Fusion):
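    """
    Fuse a bias Add (following a MatMul) that feeds SkipLayerNormalization into the node's optional bias input.
    """
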
    def __init__(self, model: OnnxModel):
        super().__init__(model, "SkipLayerNormalization", "SkipLayerNormalization", "add bias")

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
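        # Only fuse when the SkipLayerNormalization node has exactly four inputs
        # (input, skip, gamma, beta), i.e. no bias input has been fused into it yet.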
        if len(node.input) != 4:
            return

        return_indice = []
        nodes = self.model.match_parent_path(node, ["Add", "MatMul"], [None, None], output_name_to_node, return_indice)
        if nodes is not None:
            (add, _matmul) = nodes
        else:
            # In case of fp16, we could have a Cast between the MatMul and the bias Add
            return_indice = []
            nodes = self.model.match_parent_path(
                node, ["Add", "Cast", "MatMul"], [None, None, None], output_name_to_node, return_indice
            )
            if nodes is not None:
                (add, _cast, _matmul) = nodes
            else:
                return

        assert len(return_indice) == 2 or len(return_indice) == 3
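        # return_indice[0] is the input index of the SkipLayerNormalization node at which the bias Add was
        # found; indices 2 and 3 are gamma and beta, so only indices 0 and 1 are valid here.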
        add_input_index = return_indice[0]
        if add_input_index >= 2:
            return
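        # return_indice[1] is the Add input that comes from the MatMul (or Cast); the other Add input is the bias.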
        sln_input = add.input[return_indice[1]]
        bias_input = add.input[1 - return_indice[1]]
        skip_input = node.input[1 - add_input_index]

        # The bias must be a one-dimensional initializer
        initializer = self.model.get_initializer(bias_input)
        if initializer is None:
            return
        bias_weight = NumpyHelper.to_array(initializer)
        if bias_weight is None:
            logger.debug("Bias weight not found")
            return
        if len(bias_weight.shape) != 1:
            logger.debug("Bias weight is not 1D")
            return

        subgraph_nodes = [node, add]
        if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, node.output, input_name_to_nodes, output_name_to_node):
            logger.debug("Skip fusing SkipLayerNormalization with Bias since it is not safe")
            return

        self.nodes_to_remove.extend(subgraph_nodes)
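        # SkipLayerNormalization inputs with bias: (input, skip, gamma, beta, bias)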
        inputs = [
            sln_input,
            skip_input,
            node.input[2],
            node.input[3],
            bias_input,
        ]
        new_node = helper.make_node(
            "SkipLayerNormalization",
            inputs=inputs,
            outputs=node.output,
            name=self.model.create_node_name("SkipLayerNormalization", "SkipLayerNorm_AddBias_"),
        )
        new_node.domain = "com.microsoft"

        # Pass the "epsilon" attribute from the original SkipLayerNormalization node to the fused node
        for att in node.attribute:
            if att.name == "epsilon":
                new_node.attribute.extend([att])

        # Set default epsilon if no epsilon exists from skiplayernorm
        if len(new_node.attribute) == 0:
            new_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)])

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
