# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict, Optional

from fusion_base import Fusion
from onnx import helper
from onnx_model import OnnxModel

# Module-level logger, named after this module.
logger = getLogger(name=__name__)


class FusionGelu(Fusion):
    """Fuse Erf-based GELU subgraphs, x * 0.5 * (1 + erf(x / sqrt(2))), into one com.microsoft Gelu node.

    Three exporter-specific layouts are recognized (see fuse_1, fuse_2 and fuse_3).
    Each matcher returns True when it scheduled a fusion and None otherwise.
    """

    def __init__(self, model: OnnxModel):
        # "Gelu" is the fused op type and "Erf" the anchor op type whose nodes are passed to fuse()
        # (argument order per Fusion's constructor; confirm in fusion_base).
        super().__init__(model, "Gelu", "Erf")

    def fuse(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """Try each known Gelu layout around erf_node, stopping at the first one that fuses."""
        for matcher in (self.fuse_1, self.fuse_2, self.fuse_3):
            if matcher(erf_node, input_name_to_nodes, output_name_to_node):
                return

    @staticmethod
    def _single_consumer(node, op_type: str, input_name_to_nodes: Dict):
        """Return the only consumer of node.output[0] if it has op_type, else None.

        A missing entry (output with no consumers, e.g. a graph output) yields None instead of KeyError.
        """
        consumers = input_name_to_nodes.get(node.output[0], [])
        if len(consumers) != 1 or consumers[0].op_type != op_type:
            return None
        return consumers[0]

    def _match_erf_add_mul(self, erf_node, input_name_to_nodes: Dict):
        """Match the chain Erf -> Add(1) -> Mul shared by all three layouts.

        Each edge must have exactly one consumer, and the Add must carry the constant 1.

        Returns:
            (add_node, mul_node) on success, otherwise None.
        """
        add_node = self._single_consumer(erf_node, "Add", input_name_to_nodes)
        if add_node is None or not self.model.has_constant_input(add_node, 1):
            return None
        mul_node = self._single_consumer(add_node, "Mul", input_name_to_nodes)
        if mul_node is None:
            return None
        return add_node, mul_node

    def _fuse_subgraph(
        self,
        subgraph_nodes,
        subgraph_input: str,
        subgraph_output: str,
        input_name_to_nodes: Dict,
        output_name_to_node: Dict,
    ) -> Optional[bool]:
        """Schedule replacement of subgraph_nodes by Gelu(subgraph_input) -> subgraph_output.

        Returns None without touching anything when OnnxModel.is_safe_to_fuse_nodes rejects the
        subgraph; otherwise records the removal/addition on the base-class lists and returns True.
        """
        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node
        ):
            return None

        self.nodes_to_remove.extend(subgraph_nodes)
        fused_node = helper.make_node(
            "Gelu", inputs=[subgraph_input], outputs=[subgraph_output], name=self.model.create_node_name("Gelu")
        )
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
        self.increase_counter("Gelu")
        return True

    def fuse_1(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:
        """
        This pattern is from PyTorch model
        Fuse Gelu with Erf into one node:
        Pattern 1:
                       +-------Mul(0.5)---------------------+
                       |                                    |
                       |                                    v
                    [root] --> Div -----> Erf  --> Add --> Mul -->
                              (B=1.4142...)       (1)

        Pattern 2:
                       +------------------------------------+
                       |                                    |
                       |                                    v
                    [root] --> Div -----> Erf  --> Add --> Mul -->Mul -->
                              (B=1.4142...)       (1)            (0.5)

        Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.

        Returns:
            True if a Gelu fusion was scheduled, otherwise None.
        """
        matched = self._match_erf_add_mul(erf_node, input_name_to_nodes)
        if matched is None:
            return None
        add_after_erf, mul_after_erf = matched

        div = self.model.match_parent(erf_node, "Div", 0, output_name_to_node)
        if div is None:
            return None
        # The divisor must be the constant sqrt(2) at input 1, i.e. x / 1.4142...
        if self.model.find_constant_input(div, 1.4142, delta=0.001) != 1:
            return None

        subgraph_input = div.input[0]

        # Index of mul_after_erf's operand that is NOT the (1 + erf) term.
        another = 1 if mul_after_erf.input[0] == add_after_erf.output[0] else 0
        if subgraph_input == mul_after_erf.input[another]:  # pattern 2: x * (1 + erf) then * 0.5
            # _single_consumer also guards against mul_after_erf's output having no consumers.
            mul_half = self._single_consumer(mul_after_erf, "Mul", input_name_to_nodes)
            if mul_half is None or not self.model.has_constant_input(mul_half, 0.5):
                return None
            subgraph_output = mul_half.output[0]
        else:  # pattern 1: (x * 0.5) * (1 + erf)
            mul_half = self.model.match_parent(mul_after_erf, "Mul", another, output_name_to_node)
            if mul_half is None or not self.model.has_constant_input(mul_half, 0.5):
                return None
            if subgraph_input not in mul_half.input:
                return None
            subgraph_output = mul_after_erf.output[0]

        subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul_half]
        return self._fuse_subgraph(
            subgraph_nodes, subgraph_input, subgraph_output, input_name_to_nodes, output_name_to_node
        )

    def fuse_2(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:
        """
        This pattern is from Keras model
        Fuse Gelu with Erf into one node:
                       +------------------------------------------+
                       |                                          |
                       |                                          v
                    [root] --> Div -----> Erf  --> Add --> Mul -->Mul
                              (B=1.4142...)       (A=1)   (A=0.5)

        The divisor may also be computed as Sqrt(2.0) feeding Div's second input.
        Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.

        Returns:
            True if a Gelu fusion was scheduled, otherwise None.
        """
        matched = self._match_erf_add_mul(erf_node, input_name_to_nodes)
        if matched is None:
            return None
        add_after_erf, mul_after_erf = matched

        if not self.model.has_constant_input(mul_after_erf, 0.5):
            return None

        mul = self._single_consumer(mul_after_erf, "Mul", input_name_to_nodes)
        if mul is None:
            return None

        div = self.model.match_parent(erf_node, "Div", 0, output_name_to_node)
        if div is None:
            return None

        # Divisor is either the constant sqrt(2) at input 1, or Sqrt(2.0) producing input 1.
        sqrt_node = None
        if self.model.find_constant_input(div, 1.4142, delta=0.001) != 1:
            sqrt_node = self.model.match_parent(div, "Sqrt", 1, output_name_to_node)
            if sqrt_node is None or not self.model.has_constant_input(sqrt_node, 2.0):
                return None

        # Use the tensor Div actually reads. The producer's output[0] is wrong when the producer
        # has several outputs, and there is no producer at all when x is a graph input.
        subgraph_input = div.input[0]
        if subgraph_input not in mul.input:
            return None

        subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul]
        if sqrt_node is not None:
            subgraph_nodes.append(sqrt_node)

        return self._fuse_subgraph(
            subgraph_nodes, subgraph_input, mul.output[0], input_name_to_nodes, output_name_to_node
        )

    def fuse_3(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:
        """
        This pattern is from TensorFlow model
        Fuse Gelu with Erf into one node:
                       +----------------------------------------------+
                       |                                              |
                       |                                              v
                    [root] --> Mul -----> Erf    -->   Add --> Mul -->Mul
                               (A=0.7071067690849304)  (B=1)  (B=0.5)

        Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.

        Returns:
            True if a Gelu fusion was scheduled, otherwise None.
        """
        matched = self._match_erf_add_mul(erf_node, input_name_to_nodes)
        if matched is None:
            return None
        add_after_erf, mul_half = matched

        if not self.model.has_constant_input(mul_half, 0.5):
            return None

        first_mul = self.model.match_parent(erf_node, "Mul", 0, output_name_to_node)
        if first_mul is None:
            return None

        # first_mul multiplies x by 1/sqrt(2); i is the index of that constant operand.
        i = self.model.find_constant_input(first_mul, 0.7071067690849304, delta=0.001)
        if i < 0:
            return None
        # Take the exact tensor first_mul reads rather than its producer's output[0].
        subgraph_input = first_mul.input[0 if i == 1 else 1]

        last_mul = self._single_consumer(mul_half, "Mul", input_name_to_nodes)
        if last_mul is None or subgraph_input not in last_mul.input:
            return None

        subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul]
        return self._fuse_subgraph(
            subgraph_nodes, subgraph_input, last_mul.output[0], input_name_to_nodes, output_name_to_node
        )
