# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations

import onnx

from ..onnx_model import ONNXModel
from .fusion import Fusion


class FusionGelu(Fusion):
    """Fuse Erf-based Gelu subgraphs into a single ``com.microsoft`` Gelu node.

    Gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))). Exporters spell this formula with slightly
    different node sequences; each ``fuse_N`` method recognises one family of them, always
    starting from the Erf node.
    """

    def __init__(self, model: ONNXModel):
        # The fused op type is "Gelu"; "Erf" is the op type handed to `fuse` as the search anchor.
        super().__init__(model, "Gelu", "Erf")

    def fuse(
        self,
        erf_node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ):
        """
        Interface function that tries to fuse a node sequence containing an Erf node into a single
        Gelu node.

        The patterns are tried in order (fuse_1, fuse_2, fuse_3) and the first match wins. When a
        fusion is recorded, the "com.microsoft" opset import (version 1) is set on the model,
        since the fused Gelu node is created in that domain.

        Args:
            erf_node: The Erf node used as the anchor of the pattern search.
            input_name_to_nodes: Maps a tensor name to the nodes consuming it.
            output_name_to_node: Maps a tensor name to the node producing it.
        """
        if (
            self.fuse_1(erf_node, input_name_to_nodes, output_name_to_node)
            or self.fuse_2(erf_node, input_name_to_nodes, output_name_to_node)
            or self.fuse_3(erf_node, input_name_to_nodes, output_name_to_node)
        ):
            self.model.set_opset_import("com.microsoft", 1)

    def _sole_consumer(
        self,
        node: onnx.NodeProto,
        op_type: str,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
    ) -> onnx.NodeProto | None:
        """Return the only consumer of ``node.output[0]`` if its op type is ``op_type``, else None.

        None is also returned when the output has no registered consumer at all (for example
        when it is not read by any node), so callers never index a missing key.
        """
        consumers = input_name_to_nodes.get(node.output[0], [])
        if len(consumers) != 1 or consumers[0].op_type != op_type:
            return None
        return consumers[0]

    def _match_add_one_then_mul(
        self,
        erf_node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
    ) -> tuple[onnx.NodeProto, onnx.NodeProto] | None:
        """Match the ``Erf --> Add(1) --> Mul`` chain shared by every supported pattern.

        Returns:
            ``(add_node, mul_node)`` when the Erf output feeds exactly one Add that has a constant
            input of 1 and whose output feeds exactly one Mul; otherwise None.
        """
        add_after_erf = self._sole_consumer(erf_node, "Add", input_name_to_nodes)
        if add_after_erf is None:
            return None

        if not self.has_constant_input(add_after_erf, 1):
            return None

        mul_after_add = self._sole_consumer(add_after_erf, "Mul", input_name_to_nodes)
        if mul_after_add is None:
            return None

        return add_after_erf, mul_after_add

    def _replace_with_gelu(
        self,
        subgraph_nodes: list[onnx.NodeProto],
        subgraph_input: str,
        subgraph_output: str,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ) -> bool:
        """Schedule ``subgraph_nodes`` for removal and a Gelu node as their replacement.

        Nothing is scheduled (and False is returned) when `is_safe_to_fuse_nodes` rejects the
        subgraph. Otherwise the nodes are appended to ``self.nodes_to_remove``, a new
        "com.microsoft" Gelu node mapping ``subgraph_input`` to ``subgraph_output`` is appended to
        ``self.nodes_to_add``, and True is returned.
        """
        if not self.is_safe_to_fuse_nodes(subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node):
            return False

        self.nodes_to_remove.extend(subgraph_nodes)
        fused_node = onnx.helper.make_node(
            "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[subgraph_output]
        )
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        return True

    def fuse_1(
        self,
        erf_node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ) -> bool:
        """
        This pattern is from PyTorch model
        Fuse Gelu with Erf into one node:
        Pattern 1:
                       +-------Mul(0.5)---------------------+
                       |                                    |
                       |                                    v
                    [root] --> Div -----> Erf  --> Add --> Mul -->
                              (B=1.4142...)       (1)

        Pattern 2:
                       +------------------------------------+
                       |                                    |
                       |                                    v
                    [root] --> Div -----> Erf  --> Add --> Mul -->Mul -->
                              (B=1.4142...)       (1)            (0.5)

        Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.

        Returns:
            True if one of the two patterns matched and the fusion was scheduled, False otherwise.
        """
        matched = self._match_add_one_then_mul(erf_node, input_name_to_nodes)
        if matched is None:
            return False
        add_after_erf, mul_after_erf = matched

        div = self.match_parent(erf_node, "Div", 0, output_name_to_node)
        if div is None:
            return False

        # The sqrt(2) constant has to be the divisor, i.e. the second input (index 1) of Div.
        if self.find_constant_input(div, 1.4142, delta=0.001) != 1:
            return False

        subgraph_input = div.input[0]

        # Index of the Mul input that does not come from the Add node.
        another = 1 if mul_after_erf.input[0] == add_after_erf.output[0] else 0
        if subgraph_input == mul_after_erf.input[another]:  # pattern 2
            # The multiplication by 0.5 happens after this Mul. A Mul whose output has no
            # consumer cannot match, so bail out instead of raising KeyError on the lookup.
            mul_half = self._sole_consumer(mul_after_erf, "Mul", input_name_to_nodes)
            if mul_half is None:
                return False

            if not self.has_constant_input(mul_half, 0.5):
                return False

            subgraph_output = mul_half.output[0]
        else:  # pattern 1
            # The multiplication by 0.5 is applied to the root on the other branch.
            mul_half = self.match_parent(mul_after_erf, "Mul", another, output_name_to_node)
            if mul_half is None:
                return False

            if not self.has_constant_input(mul_half, 0.5):
                return False

            if subgraph_input not in mul_half.input:
                return False

            subgraph_output = mul_after_erf.output[0]

        subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul_half]
        return self._replace_with_gelu(
            subgraph_nodes, subgraph_input, subgraph_output, input_name_to_nodes, output_name_to_node
        )

    def fuse_2(
        self,
        erf_node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ) -> bool:
        """
        This pattern is from Keras model
        Fuse Gelu with Erf into one node:
                       +------------------------------------------+
                       |                                          |
                       |                                          v
                    [root] --> Div -----> Erf  --> Add --> Mul -->Mul
                              (B=1.4142...)       (A=1)   (A=0.5)

        Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.

        The divisor of Div may also be the output of a Sqrt node that has a constant input of 2
        instead of a literal 1.4142... constant; in that case the Sqrt node is fused as well.

        Returns:
            True if the pattern matched and the fusion was scheduled, False otherwise.
        """
        matched = self._match_add_one_then_mul(erf_node, input_name_to_nodes)
        if matched is None:
            return False
        add_after_erf, mul_after_erf = matched

        if not self.has_constant_input(mul_after_erf, 0.5):
            return False

        mul = self._sole_consumer(mul_after_erf, "Mul", input_name_to_nodes)
        if mul is None:
            return False

        div = self.match_parent(erf_node, "Div", 0, output_name_to_node)
        if div is None:
            return False

        sqrt_node = None
        if self.find_constant_input(div, 1.4142, delta=0.001) != 1:
            # No literal sqrt(2) as the second Div input: accept a Sqrt(2) producer instead.
            sqrt_node = self.match_parent(div, "Sqrt", 1, output_name_to_node)
            if sqrt_node is None:
                return False
            if not self.has_constant_input(sqrt_node, 2.0):
                return False

        subgraph_input = div.input[0]

        # The final Mul must multiply the scaled (1 + erf) branch with the root itself.
        if subgraph_input not in mul.input:
            return False

        subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul]
        if sqrt_node is not None:
            subgraph_nodes.append(sqrt_node)

        return self._replace_with_gelu(
            subgraph_nodes, subgraph_input, mul.output[0], input_name_to_nodes, output_name_to_node
        )

    def fuse_3(
        self,
        erf_node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ) -> bool:
        """
        This pattern is from TensorFlow model
        Fuse Gelu with Erf into one node:
                       +----------------------------------------------+
                       |                                              |
                       |                                              v
                    [root] --> Mul -----> Erf    -->   Add --> Mul -->Mul
                               (A=0.7071067690849304)  (B=1)  (B=0.5)

        Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.

        Returns:
            True if the pattern matched and the fusion was scheduled, False otherwise.
        """
        matched = self._match_add_one_then_mul(erf_node, input_name_to_nodes)
        if matched is None:
            return False
        add_after_erf, mul_half = matched

        if not self.has_constant_input(mul_half, 0.5):
            return False

        first_mul = self.match_parent(erf_node, "Mul", 0, output_name_to_node)
        if first_mul is None:
            return False

        # Multiplying by 1/sqrt(2) replaces the division used by the other patterns.
        i = self.find_constant_input(first_mul, 0.7071067690849304, delta=0.001)
        if i < 0:
            return False

        # The root is whichever input of the first Mul is not the constant.
        root_input_index = 1 - i
        subgraph_input = first_mul.input[root_input_index]

        last_mul = self._sole_consumer(mul_half, "Mul", input_name_to_nodes)
        if last_mul is None:
            return False

        # Membership test instead of indexing input[0]/input[1]: same result for a two-input
        # Mul, and no IndexError on a node with fewer inputs.
        if subgraph_input not in last_mul.input:
            return False

        subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul]
        return self._replace_with_gelu(
            subgraph_nodes, subgraph_input, last_mul.output[0], input_name_to_nodes, output_name_to_node
        )
