# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Dict, Optional

from fusion_base import Fusion
from onnx import helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionFastGelu(Fusion):
    """Fuse tanh-approximated GELU subgraphs into a single ``FastGelu`` node.

    Every supported pattern computes

        0.5 * x * (1 + tanh(0.7978... * (x + 0.0447... * x^3)))

    but in a different graph shape (see ``fuse_1``, ``fuse_2`` and ``fuse_3``).
    Matching starts from a ``Tanh`` node. On success the matched nodes are queued in
    ``self.nodes_to_remove``. A ``com.microsoft`` ``FastGelu`` node that reads ``x`` and
    writes the subgraph's final output is queued in ``self.nodes_to_add``.
    """

    def __init__(self, model: OnnxModel):
        # "FastGelu" is the fused op type; "Tanh" is the op type whose nodes are
        # handed to ``fuse`` (assumed from the Fusion base-class contract -- verify there).
        super().__init__(model, "FastGelu", "Tanh")

    def fuse(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """Try each supported FastGelu pattern around ``tanh_node``, stopping at the first match.

        Args:
            tanh_node: a Tanh node of the graph.
            input_name_to_nodes: maps a tensor name to the list of nodes consuming it.
            output_name_to_node: maps a tensor name to the node producing it.
        """
        for match_pattern in (self.fuse_1, self.fuse_2, self.fuse_3):
            if match_pattern(tanh_node, input_name_to_nodes, output_name_to_node):
                return

    # ------------------------------------------------------------------
    # Shared matching helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _get_only_child(node, op_type: str, input_name_to_nodes: Dict):
        """Return the sole consumer of ``node.output[0]`` if its type is ``op_type``, else None."""
        children = input_name_to_nodes.get(node.output[0], [])
        if len(children) != 1 or children[0].op_type != op_type:
            return None
        return children[0]

    def _match_tanh_tail(self, tanh_node, input_name_to_nodes: Dict) -> Optional[tuple]:
        """Match ``Tanh -> Add(1) -> Mul``, where each edge has exactly one consumer.

        Returns:
            ``(add_after_tanh, mul_after_add)`` on success, otherwise None.
        """
        add_after_tanh = self._get_only_child(tanh_node, "Add", input_name_to_nodes)
        if add_after_tanh is None or not self.model.has_constant_input(add_after_tanh, 1.0):
            return None

        mul_after_add = self._get_only_child(add_after_tanh, "Mul", input_name_to_nodes)
        if mul_after_add is None:
            return None
        return add_after_tanh, mul_after_add

    def _match_pow_branch(self, tanh_node, root_input: str, root_node, output_name_to_node: Dict) -> Optional[list]:
        """Match ``Mul(0.7978) -> Tanh`` fed by ``Add(root, Mul(0.0447, Pow(root, 3)))``.

        Both ``Pow`` and the ``Add`` before ``Tanh`` must consume ``root_input`` directly.
        The matched branch therefore computes ``0.7978 * (x + 0.0447 * x^3)`` for ``x = root_input``.

        Args:
            tanh_node: the Tanh node the branch feeds.
            root_input: name of the GELU input tensor ``x``.
            root_node: producer of ``root_input``, or None when it is a graph input. It is
                excluded when searching for the ``Mul`` that feeds the ``Add``, so the root
                cannot be mistaken for that ``Mul``.
            output_name_to_node: maps a tensor name to the node producing it.

        Returns:
            ``[mul_before_tanh, add_before_tanh, mul_after_pow, pow_node]`` on success, otherwise None.
        """
        mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
        if mul_before_tanh is None:
            return None

        i = self.model.find_constant_input(mul_before_tanh, 0.7978, delta=0.0001)
        if i < 0:
            return None

        # The Add is whichever Mul input is not the constant.
        add_before_tanh = self.model.match_parent(mul_before_tanh, "Add", 0 if i == 1 else 1, output_name_to_node)
        if add_before_tanh is None:
            return None

        mul_after_pow = self.model.match_parent(
            add_before_tanh,
            "Mul",
            None,
            output_name_to_node,
            exclude=[root_node] if root_node is not None else [],
        )
        if mul_after_pow is None:
            return None

        i = self.model.find_constant_input(mul_after_pow, 0.0447, delta=0.0001)
        if i < 0:
            return None

        pow_node = self.model.match_parent(mul_after_pow, "Pow", 0 if i == 1 else 1, output_name_to_node)
        if pow_node is None:
            return None

        if not self.model.has_constant_input(pow_node, 3.0):
            return None

        if pow_node.input[0] != root_input:
            return None

        # The Add's other operand must be x itself. Without this check a subgraph such as
        # tanh(0.7978 * (y + 0.0447 * x^3)) with y != x would be fused into FastGelu(x).
        if root_input not in add_before_tanh.input:
            return None

        return [mul_before_tanh, add_before_tanh, mul_after_pow, pow_node]

    def _replace_with_fast_gelu(
        self,
        subgraph_nodes: list,
        root_input: str,
        last_node,
        input_name_to_nodes: Dict,
        output_name_to_node: Dict,
    ) -> Optional[bool]:
        """Queue ``subgraph_nodes`` for removal and add a FastGelu node in their place.

        Nothing is changed when ``self.model.is_safe_to_fuse_nodes`` rejects the subgraph.
        ``last_node.output[0]`` is passed to that check as the only output kept.

        Returns:
            True when the fusion was queued, otherwise None.
        """
        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes,
            [last_node.output[0]],
            input_name_to_nodes,
            output_name_to_node,
        ):
            return None

        self.nodes_to_remove.extend(subgraph_nodes)
        fused_node = helper.make_node(
            "FastGelu",
            inputs=[root_input],
            outputs=last_node.output,
            name=self.model.create_node_name("FastGelu"),
        )
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
        return True

    # ------------------------------------------------------------------
    # Patterns
    # ------------------------------------------------------------------

    def fuse_1(self, tanh_node, input_name_to_nodes, output_name_to_node) -> Optional[bool]:
        """
        Fuse Gelu with tanh into one node:
              +---------------------------+
              |                           |
              |                           v
            [root] --> Pow --> Mul -----> Add  --> Mul --> Tanh --> Add --> Mul
              |       (Y=3)   (B=0.0447...)       (B=0.7978...)    (B=1)     ^
              |                                                              |
              +------> Mul(B=0.5)--------------------------------------------+
        The constant input of an Add or Mul may be its first or second input (A=0.5 or B=0.5).
        [root] may be a graph input (it has no producer node in that case).

        Returns:
            True if the pattern was matched and the fusion queued, otherwise None.
        """
        tail = self._match_tanh_tail(tanh_node, input_name_to_nodes)
        if tail is None:
            return None
        add_after_tanh, mul_after_tanh = tail

        mul_half = self.model.match_parent(mul_after_tanh, "Mul", None, output_name_to_node)
        if mul_half is None:
            return None

        i = self.model.find_constant_input(mul_half, 0.5)
        if i < 0:
            return None

        root_index = 0 if i == 1 else 1
        root_input = mul_half.input[root_index]
        # root_node is None when root_input is a graph input.
        root_node = self.model.get_parent(mul_half, root_index, output_name_to_node)

        branch = self._match_pow_branch(tanh_node, root_input, root_node, output_name_to_node)
        if branch is None:
            return None

        subgraph_nodes = [mul_after_tanh, mul_half, add_after_tanh, tanh_node, *branch]
        return self._replace_with_fast_gelu(
            subgraph_nodes, root_input, mul_after_tanh, input_name_to_nodes, output_name_to_node
        )

    def fuse_2(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:
        """
        This pattern is from Tensorflow model.
        Fuse Gelu with tanh into one node:
              +---------------------------+
              |                           |
              |                           v
            [root] --> Pow --> Mul -----> Add  --> Mul --> Tanh --> Add --> Mul(B=0.5)-->Mul-->
              |       (Y=3)   (B=0.0447...)       (B=0.7978...)    (B=1)                  ^
              |                                                                           |
              +---------------------------------------------------------------------------+
        The constant input of an Add or Mul may be its first or second input (A=0.5 or B=0.5).
        [root] may be a graph input (it has no producer node in that case).

        Returns:
            True if the pattern was matched and the fusion queued, otherwise None.
        """
        tail = self._match_tanh_tail(tanh_node, input_name_to_nodes)
        if tail is None:
            return None
        add_after_tanh, mul_half = tail

        if self.model.find_constant_input(mul_half, 0.5) < 0:
            return None

        mul_after_mul_half = self._get_only_child(mul_half, "Mul", input_name_to_nodes)
        if mul_after_mul_half is None:
            return None

        # The root is whichever operand of the final Mul is not the 0.5 * (...) product.
        # Use the tensor name actually consumed here, not producer.output[0], because the
        # producer may have several outputs.
        root_index = 0 if mul_after_mul_half.input[1] == mul_half.output[0] else 1
        root_input = mul_after_mul_half.input[root_index]
        # root_node is None when root_input is a graph input.
        root_node = self.model.get_parent(mul_after_mul_half, root_index, output_name_to_node)

        branch = self._match_pow_branch(tanh_node, root_input, root_node, output_name_to_node)
        if branch is None:
            return None

        subgraph_nodes = [mul_after_mul_half, mul_half, add_after_tanh, tanh_node, *branch]
        return self._replace_with_fast_gelu(
            subgraph_nodes, root_input, mul_after_mul_half, input_name_to_nodes, output_name_to_node
        )

    def fuse_3(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:
        """
        OpenAI's gelu implementation, also used in Megatron:
           Gelu(x) = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))

        Fuse subgraph into a FastGelu node:
            +------------ Mul (B=0.79788456) -------------------+
            |                                                   |
            +-------------------------------+                   |
            |                               |                   |
            |                               v                   v
          [root] --> Mul (B=0.044715) --> Mul --> Add(B=1) --> Mul --> Tanh --> Add(B=1) --> Mul-->
            |                                                                                 ^
            |                                                                                 |
            +-----------> Mul (B=0.5) --------------------------------------------------------+

        Returns:
            True if the pattern was matched and the fusion queued, otherwise None.
        """
        tail = self._match_tanh_tail(tanh_node, input_name_to_nodes)
        if tail is None:
            return None
        add_after_tanh, mul_last = tail

        mul_half = self.model.match_parent(mul_last, "Mul", None, output_name_to_node)
        if mul_half is None:
            return None

        i = self.model.find_constant_input(mul_half, 0.5)
        if i < 0:
            return None

        root_input = mul_half.input[0 if i == 1 else 1]

        mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
        if mul_before_tanh is None:
            return None

        # mul_before_tanh = (0.7978 * x) * (1 + 0.0447 * x * x); its parents may appear in either order.
        add_1 = self.model.match_parent(mul_before_tanh, "Add", None, output_name_to_node)
        if add_1 is None:
            return None
        j = self.model.find_constant_input(add_1, 1.0)
        if j < 0:
            return None

        mul_7978 = self.model.match_parent(mul_before_tanh, "Mul", None, output_name_to_node)
        if mul_7978 is None:
            return None
        k = self.model.find_constant_input(mul_7978, 0.7978, delta=0.0001)
        if k < 0:
            return None
        if mul_7978.input[0 if k == 1 else 1] != root_input:
            return None

        mul_before_add_1 = self.model.match_parent(add_1, "Mul", 0 if j == 1 else 1, output_name_to_node)
        if mul_before_add_1 is None:
            return None

        # One factor of x * (0.0447 * x) must be x itself; the other comes from the 0.0447 Mul.
        if mul_before_add_1.input[0] == root_input:
            other_index = 1
        elif mul_before_add_1.input[1] == root_input:
            other_index = 0
        else:
            return None

        mul_0447 = self.model.match_parent(mul_before_add_1, "Mul", other_index, output_name_to_node)
        if mul_0447 is None:
            return None
        m = self.model.find_constant_input(mul_0447, 0.0447, delta=0.0001)
        if m < 0:
            return None

        if mul_0447.input[0 if m == 1 else 1] != root_input:
            return None

        subgraph_nodes = [
            mul_0447,
            mul_before_add_1,
            add_1,
            mul_before_tanh,
            tanh_node,
            add_after_tanh,
            mul_7978,
            mul_half,
            mul_last,
        ]
        return self._replace_with_fast_gelu(
            subgraph_nodes, root_input, mul_last, input_name_to_nodes, output_name_to_node
        )
