# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict, List

from fusion_base import Fusion
from onnx import TensorProto, helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionLayerNormalization(Fusion):
    def __init__(self, model: OnnxModel):
        super().__init__(model, "LayerNormalization", "ReduceMean")

    def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """
        Fuse Layer Normalization subgraph into one node LayerNormalization:
              +----------------------+
              |                      |
              |                      v
          [Root] --> ReduceMean -->  Sub  --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
                     (axis=2 or -1)  |      (Y=2)   (axis=2 or -1)  (E-6 or E-12)      ^
                                     |                                                 |
                                     +-------------------------------------------------+

         It also handles the case of duplicated Sub nodes exported from older versions of PyTorch:
              +----------------------+
              |                      v
              |           +-------> Sub-----------------------------------------------+
              |           |                                                           |
              |           |                                                           v
          [Root] --> ReduceMean -->  Sub  --> Pow --> ReduceMean --> Add --> Sqrt --> Div  --> Mul --> Add
              |                      ^
              |                      |
              +----------------------+
        """
        subgraph_nodes = []
        children = self.model.get_children(node, input_name_to_nodes)
        if len(children) == 0 or len(children) > 2:
            return

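        # Every Sub child must consume the ReduceMean's own input; that shared tensor
        # is the root of the pattern.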
        root_input = node.input[0]

        if children[0].op_type != "Sub" or children[0].input[0] != root_input:
            return

        if len(children) == 2:
            if children[1].op_type != "Sub" or children[1].input[0] != root_input:
                return

        div_node = None
        for child in children:
            # Check if Sub --> Div exists
            div_node_1 = self.model.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)

            # Check if Sub --> Cast --> Div
            div_node_2 = self.model.match_child_path(child, ["Cast", "Div"], exclude=[])

            if div_node_1 is not None:
                div_node = div_node_1
            elif div_node_2 is not None:
                div_node = div_node_2[-1]
        if div_node is None:
            return

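        # Trace upward from Div through Sqrt -> Add -> ReduceMean -> Pow back to Sub;
        # the second pattern allows a Cast between Pow and Sub.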
        path_id, parent_nodes, _ = self.model.match_parent_paths(
            div_node,
            [
                (["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]),
                (["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"], [1, 0, 0, 0, 0, 0]),
            ],
            output_name_to_node,
        )
        if path_id < 0:
            return

        sub_node = parent_nodes[-1]
        if sub_node not in children:
            return

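        # parent_nodes follows the matched path order: Sqrt, Add, ReduceMean, Pow,
        # (optional Cast,) Sub. The constant feeding the Add is the epsilon.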
        second_add_node = parent_nodes[1]
        _, add_weight = self.model.get_constant_input(second_add_node)
        if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4:
            logger.debug(f"skip LayerNormalization fusion since epsilon value is not expected: {add_weight}")
            return

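        # Variance requires squaring: Pow must take the constant 2.0 as its exponent.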
        pow_node = parent_nodes[3]
        if self.model.find_constant_input(pow_node, 2.0) != 1:
            return

        temp_node = input_name_to_nodes[div_node.output[0]][0]
        if temp_node.op_type == "Cast":
            # Div --> Cast --> Mul
            subgraph_nodes.append(temp_node)  # add Cast node to list of subgraph nodes
            mul_node = input_name_to_nodes[temp_node.output[0]][0]
        else:
            # Div --> Mul
            mul_node = temp_node
        if mul_node.op_type != "Mul":
            return

        last_add_node = input_name_to_nodes[mul_node.output[0]][0]
        if last_add_node.op_type != "Add":
            return

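        # Collect every matched node; fusion proceeds only if no intermediate output
        # is consumed outside this subgraph.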
        subgraph_nodes.append(node)
        subgraph_nodes.extend(children)
        subgraph_nodes.extend(parent_nodes[:-1])

        subgraph_nodes.extend([last_add_node, mul_node, div_node])
        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes,
            last_add_node.output,
            input_name_to_nodes,
            output_name_to_node,
        ):
            logger.debug("It is not safe to fuse LayerNormalization node. Skip")
            return

        node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node
        weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)]
        if not self.model.is_constant_with_specified_dimension(weight_input, 1, "layernorm weight"):
            return

        bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)]
        if not self.model.is_constant_with_specified_dimension(bias_input, 1, "layernorm bias"):
            return

        self.nodes_to_remove.extend(subgraph_nodes)

        normalize_node = helper.make_node(
            "LayerNormalization",
            inputs=[node.input[0], weight_input, bias_input],
            outputs=[last_add_node.output[0]],
            name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"),
        )
        normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))])
        self.nodes_to_add.append(normalize_node)
        self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name

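# A minimal usage sketch (illustrative; the model paths are hypothetical, while
# Fusion.apply, OnnxModel.topological_sort, and OnnxModel.save_model_to_file are
# provided by this package):
#
#     import onnx
#     model = OnnxModel(onnx.load("model.onnx"))
#     FusionLayerNormalization(model).apply()
#     model.topological_sort()
#     model.save_model_to_file("model_fused.onnx")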

class FusionLayerNormalizationNCHW(Fusion):
    def __init__(self, model: OnnxModel):
        super().__init__(model, "LayerNormalization", "ReduceMean")

    def get_weight_or_bias(self, output_name, description):
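        """Return a Cx1x1 constant as a 1-D array of length C, or None if it does not qualify."""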
        value = self.model.get_constant_value(output_name)
        if value is None:
            logger.debug(f"{description} {output_name} is not initializer.")
            return None

        if len(value.shape) != 3 or value.shape[1] != 1 or value.shape[2] != 1:
            logger.debug(f"{description} {output_name} shall have 3 dimensions Cx1x1. Got shape {value.shape}")
            return None

        return value.reshape([value.shape[0]])

    def create_transpose_node(self, input_name: str, perm: List[int], output_name=None):
        """Append a Transpose node after an input"""
        node_name = self.model.create_node_name("Transpose")

        if output_name is None:
            output_name = node_name + "_out-" + input_name

        transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
        transpose_node.attribute.extend([helper.make_attribute("perm", perm)])

        return transpose_node

    def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """
        Fuse Layer Normalization subgraph into one node LayerNormalization:
              +----------------------+
              | NxCxHxW              |
              |                      v                                                     (Cx1x1)  (Cx1x1)
          [Root] --> ReduceMean -->  Sub --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add -->
                     (axes=1)        |      (Y=2)     (axes=1)     (E-6)             ^
                                     |                                               |
                                     +-----------------------------------------------+

        Fused subgraph:
                       (0,2,3,1)                            (0,3,1,2)
            [Root] --> Transpose --> LayerNormalization --> Transpose -->
        """
        axes = OnnxModel.get_node_attribute(node, "axes")
        if (not isinstance(axes, list)) or axes != [1]:
            return

        subgraph_nodes = []
        children = self.model.get_children(node, input_name_to_nodes)
        if len(children) != 1:
            return

        root_input = node.input[0]

        if children[0].op_type != "Sub" or children[0].input[0] != root_input:
            return
        sub = children[0]

        div_node = self.model.find_first_child_by_type(sub, "Div", input_name_to_nodes, recursive=False)
        if div_node is None:
            return

        parent_nodes = self.model.match_parent_path(
            div_node,
            ["Sqrt", "Add", "ReduceMean", "Pow", "Sub"],
            [1, 0, 0, 0, 0],
            output_name_to_node,
        )
        if parent_nodes is None:
            return

        _sqrt_node, second_add_node, reduce_mean_node, pow_node, sub_node = parent_nodes
        if sub != sub_node:
            return

        _, add_weight = self.model.get_constant_input(second_add_node)
        if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4:
            logger.debug(f"skip LayerNormalization fusion since epsilon value is not expected: {add_weight}")
            return

        axes = OnnxModel.get_node_attribute(reduce_mean_node, "axes")
        if not isinstance(axes, list) or axes != [1]:
            return

        if self.model.find_constant_input(pow_node, 2.0) != 1:
            return

        mul_node = input_name_to_nodes[div_node.output[0]][0]
        if mul_node.op_type != "Mul":
            return

        last_add_node = input_name_to_nodes[mul_node.output[0]][0]
        if last_add_node.op_type != "Add":
            return

        subgraph_nodes.append(node)
        subgraph_nodes.extend(parent_nodes)
        subgraph_nodes.extend([last_add_node, mul_node, div_node])

        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes,
            last_add_node.output,
            input_name_to_nodes,
            output_name_to_node,
        ):
            logger.debug("It is not safe to fuse LayerNormalization node. Skip")
            return

        # No Cast handling is needed in this pattern: the Mul consumes the Div output directly.
        weight_input = mul_node.input[1 - self.model.input_index(div_node.output[0], mul_node)]
        weight = self.get_weight_or_bias(weight_input, "layernorm weight")
        if weight is None:
            return

        bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)]
        bias = self.get_weight_or_bias(bias_input, "layernorm bias")
        if bias is None:
            return

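        # Register 1-D copies of gamma/beta under new names; after the Transpose pair
        # below, the normalized axis is last, so LayerNormalization can consume them.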
        weight_nhwc = helper.make_tensor(weight_input + "_NHWC", TensorProto.FLOAT, weight.shape, weight)

        bias_nhwc = helper.make_tensor(bias_input + "_NHWC", TensorProto.FLOAT, bias.shape, bias)
        self.model.add_initializer(weight_nhwc, self.this_graph_name)
        self.model.add_initializer(bias_nhwc, self.this_graph_name)

        self.nodes_to_remove.extend(subgraph_nodes)

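        # Sandwich the fused node between NCHW->NHWC and NHWC->NCHW transposes so
        # normalization runs over the last axis.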
        transpose_input = self.create_transpose_node(node.input[0], [0, 2, 3, 1])

        layernorm_node_name = self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm")

        transpose_output = self.create_transpose_node(
            layernorm_node_name + "_out_nhwc", [0, 3, 1, 2], last_add_node.output[0]
        )

        normalize_node = helper.make_node(
            "LayerNormalization",
            inputs=[transpose_input.output[0], weight_input + "_NHWC", bias_input + "_NHWC"],
            outputs=[layernorm_node_name + "_out_nhwc"],
            name=layernorm_node_name,
        )
        normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))])

        self.nodes_to_add.append(transpose_input)
        self.nodes_to_add.append(normalize_node)
        self.nodes_to_add.append(transpose_output)
        self.node_name_to_graph_name[transpose_input.name] = self.this_graph_name
        self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
        self.node_name_to_graph_name[transpose_output.name] = self.this_graph_name

        counter_name = "LayerNormalization(NHWC)"
        self.increase_counter(counter_name)


class FusionLayerNormalizationTF(Fusion):
    def __init__(self, model: OnnxModel):
        super().__init__(model, "LayerNormalization", "Add", "TF")

    def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """
         Layer Norm from a TensorFlow model (converted with keras2onnx or tf2onnx):
          +------------------------------------+
          |                                    |
          |                                    |
        (Cast_1)                               |
          |                                    |
          |                                    v                                           (B)                             (B)             (A)
         Add --> (Cast_1) --> ReduceMean -->  Sub  --> Mul --> ReduceMean --> (Cast_3) --> Add --> Sqrt --> Reciprocal --> Mul --> Mul --> Sub --> Add
          |                       |                                                                                         |       ^              ^
          |                       |                                                                                         |       |              |
          |                       +--------------------------------------------------(Cast_2)-------------------------------|-------+              |
          |                                                                                                                 v                      |
          +---------------------------------------------------------------------------------------------------------------> Mul--------------------+
        """
        _, parent_nodes, return_indice = self.model.match_parent_paths(
            node,
            [
                (
                    [
                        "Sub",
                        "Mul",
                        "Mul",
                        "Reciprocal",
                        "Sqrt",
                        "Add",
                        "ReduceMean",
                        "Mul",
                        "Sub",
                        "ReduceMean",
                    ],
                    [1, 1, None, 0, 0, 0, None, 0, 0, None],
                ),
                (
                    [
                        "Sub",
                        "Mul",
                        "Mul",
                        "Reciprocal",
                        "Sqrt",
                        "Add",
                        "Cast",
                        "ReduceMean",
                        "Mul",
                        "Sub",
                        "ReduceMean",
                    ],
                    [1, 1, None, 0, 0, 0, 0, None, 0, 0, None],
                ),
            ],
            output_name_to_node,
        )

        if parent_nodes is None:
            return

        assert len(return_indice) == 3
        if not (return_indice[0] in [0, 1] and return_indice[1] in [0, 1] and return_indice[2] in [0, 1]):
            logger.debug("return indice is exepected in [0, 1], but got {return_indice}")
            return

        (
            sub_node_0,
            mul_node_0,
            mul_node_1,
            reciprocal_node,
            sqrt_node,
            add_node_0,
        ) = parent_nodes[:6]
        reduce_mean_node_0, mul_node_2, sub_node_1, reduce_mean_node_1 = parent_nodes[-4:]

        cast_node_3 = None
        if len(parent_nodes) == 11:
            cast_node_3 = parent_nodes[6]
            assert cast_node_3.op_type == "Cast"

        mul_node_3 = self.model.match_parent(node, "Mul", 0, output_name_to_node)
        if mul_node_3 is None:
            logger.debug("mul_node_3 not found")
            return

        node_before_reduce = self.model.get_parent(reduce_mean_node_1, 0, output_name_to_node)
        root_node = (
            node_before_reduce
            if cast_node_3 is None
            else self.model.get_parent(node_before_reduce, 0, output_name_to_node)
        )
        if root_node is None:
            logger.debug("root node is none")
            return

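        # Epsilon must be positive; the 1e-5 upper bound is only enforced when no
        # Cast_3 is present.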
        _, epsilon = self.model.get_constant_input(add_node_0)
        if epsilon is None or epsilon <= 0 or (epsilon > 1.0e-5 and cast_node_3 is None):
            logger.debug("epsilon is not matched")
            return

        if cast_node_3 is None and (
            reduce_mean_node_1.input[0] not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input
        ):
            logger.debug("reduce_mean_node_1 and mul_node_3 shall link from root node")
            return

        if cast_node_3 is not None and (
            node_before_reduce.input[0] not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input
        ):
            logger.debug("reduce_mean_node_1 and mul_node_3 shall link from root node")
            return

        if mul_node_2.input[0] != mul_node_2.input[1]:
            logger.debug("mul_node_2 shall have two same inputs")
            return

        subgraph_nodes = [
            node,
            sub_node_0,
            mul_node_0,
            mul_node_1,
            reciprocal_node,
            sqrt_node,
            add_node_0,
            reduce_mean_node_0,
            mul_node_2,
            sub_node_1,
            reduce_mean_node_1,
            mul_node_3,
        ]

        if cast_node_3 is not None:
            cast_node_2 = self.model.match_parent(mul_node_0, "Cast", 0, output_name_to_node)
            if cast_node_2 is None:
                logger.debug("cast_node_2 not found")
                return
            subgraph_nodes.extend([node_before_reduce, cast_node_2, cast_node_3])

        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes,
            node.output,
            self.model.input_name_to_nodes(),
            self.model.output_name_to_node(),
        ):
            logger.debug("not safe to fuse layer normalization")
            return

        self.nodes_to_remove.extend(subgraph_nodes)

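        # In this TF pattern, gamma (scale) is the second input of mul_node_1 and
        # beta (bias) is the first input of sub_node_0.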
        weight_input = mul_node_1.input[1]
        bias_input = sub_node_0.input[0]

        fused_node = helper.make_node(
            "LayerNormalization",
            inputs=[mul_node_3.input[0], weight_input, bias_input],
            outputs=[node.output[0]],
            name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"),
        )
        fused_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))])
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
