# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from logging import getLogger

from fusion_base import Fusion
from fusion_utils import NumpyHelper
from onnx import helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionBiasGelu(Fusion):
    """Fold a bias ``Add`` that feeds a Gelu/FastGelu node into one com.microsoft node.

    Matched pattern (the MatMul is kept; only the Add and the Gelu node are removed)::

        MatMul -> Add(bias initializer) -> Gelu      =>  MatMul -> BiasGelu(x, bias)
        MatMul -> Add(bias initializer) -> FastGelu  =>  MatMul -> FastGelu(x, bias)
    """

    def __init__(self, model: OnnxModel, is_fastgelu: bool):
        """Configure the fusion for FastGelu nodes if *is_fastgelu* is true, else for Gelu nodes."""
        if is_fastgelu:
            super().__init__(model, "FastGelu", "FastGelu", "add bias")
        else:
            super().__init__(model, "BiasGelu", "Gelu")

    def _find_bias(self, add):
        """Return ``(input_index, numpy_array)`` for the first initializer input of *add*, or None."""
        for index, name in enumerate(add.input):
            initializer = self.model.get_initializer(name)
            if initializer is not None:
                return index, NumpyHelper.to_array(initializer)
        return None

    def _bias_matches_matmul(self, matmul, bias_length):
        """Return False only when the MatMul weight provably disagrees with *bias_length*.

        The last dimension of a MatMul whose second input has rank >= 2 is that input's
        last dimension. When that weight is an initializer we can compare it with the bias
        length. In every other case the shape is unknown here, so the check passes and the
        fusion behaves as before.
        """
        if len(matmul.input) < 2:
            return True
        weight = self.model.get_initializer(matmul.input[1])
        if weight is None:
            return True
        # Read dims straight from the TensorProto to avoid materializing a large weight.
        weight_dims = list(weight.dims)
        if len(weight_dims) < 2:
            return True
        return weight_dims[-1] == bias_length

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Try to fuse the bias Add feeding *node* (a Gelu or FastGelu) into a single contrib op.

        Returns without queuing anything when the pattern does not match. On success,
        *node* and the Add are appended to ``nodes_to_remove``, and the fused node is
        appended to ``nodes_to_add`` in the "com.microsoft" domain.
        """
        gelu_op_type = node.op_type
        fuse_op_type = "BiasGelu" if gelu_op_type == "Gelu" else "FastGelu"

        # Only single-input activations are fused (e.g. a FastGelu that already
        # carries a bias input is left alone).
        if len(node.input) != 1:
            return

        # Add must be the parent at input 0. None presumably lets the MatMul feed
        # either Add input.
        nodes = self.model.match_parent_path(node, ["Add", "MatMul"], [0, None])
        if nodes is None:
            return
        add, matmul = nodes

        found = self._find_bias(add)
        if found is None:
            return
        bias_index, bias_weight = found

        # The fused op takes a 1-D bias. Add allows broadcasting (e.g. shape (1,)), but
        # the fused op expects bias length == last dim of its input, so reject provable
        # mismatches instead of emitting an invalid node.
        if len(bias_weight.shape) != 1:
            return
        if not self._bias_matches_matmul(matmul, bias_weight.shape[0]):
            return

        # The Add output must have no consumers other than the Gelu node being replaced.
        subgraph_nodes = [node, add]
        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes, [node.output[0]], input_name_to_nodes, output_name_to_node
        ):
            return

        self.nodes_to_remove.extend(subgraph_nodes)

        fused_node = helper.make_node(
            fuse_op_type,
            inputs=[matmul.output[0], add.input[bias_index]],
            outputs=node.output,
            name=self.model.create_node_name(fuse_op_type, gelu_op_type + "_AddBias_"),
        )
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
