# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging

import numpy as np
from fusion_attention import AttentionMask, FusionAttention
from onnx import TensorProto, helper
from onnx_model import OnnxModel

logger = logging.getLogger(__name__)


class FusionBartAttention(FusionAttention):
    """
    Fuse Bart Attention subgraph into one Attention node.

    Two export styles are recognized: the Hugging Face style (Reshape -> Transpose -> Reshape around the
    projections) and the OpenAI Whisper implementation (Transpose/Reshape without the extra Reshape, with Q
    and K each multiplied by a Mul node before the QK MatMul). Encoder attention and decoder attention
    without past are fused into an Attention node. Decoder attention with past and decoder cross attention
    are fused into a MultiHeadAttention node, but only when ``use_multi_head_attention`` is enabled.
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        """
        Args:
            model: The ONNX model wrapper being optimized.
            hidden_size: Hidden size forwarded to ``FusionAttention``.
            num_heads: Number of attention heads forwarded to ``FusionAttention``.
            attention_mask: Attention mask helper forwarded to ``FusionAttention``.
        """
        super().__init__(model, hidden_size, num_heads, attention_mask)

    def check_runtime_shape_path(
        self,
        reshape_qkv_2,
        reshape_qkv_1,
        reshape_q_2,
        reshape_k_2,
        reshape_v_2,
        root_input,
    ):
        """Check that the runtime shape subgraphs of the Hugging Face pattern derive from ``root_input``.

        Three conditions must hold:

        * The shape Concat of the output Reshape (``reshape_qkv_2``) takes inputs 0 and 1 from
          ``Unsqueeze(Gather(Shape(root_input)))``. These are presumably the batch and sequence
          dimensions.
        * The inner Reshape (``reshape_qkv_1``) reuses those same two Gather nodes.
        * The Q/K/V Reshape nodes take their leading dimension from a Mul whose first input is the
          first Gather's output.

        Returns:
            True if every shape path matches, False otherwise.
        """
        concat_qkv_2_path = self.model.match_parent_path(reshape_qkv_2, ["Concat"], [1])
        if concat_qkv_2_path is None:
            return False
        concat_qkv_2 = concat_qkv_2_path[0]

        reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
        reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
        if reshape_qkv_2_path_1 is None or reshape_qkv_2_path_2 is None:
            return False

        _, gather_1, shape_1 = reshape_qkv_2_path_1
        _, gather_2, shape_2 = reshape_qkv_2_path_2

        if shape_1.input[0] != root_input or shape_2.input[0] != root_input:
            return False

        # The inner Reshape must use the very same Gather outputs as the outer Reshape.
        reshape_qkv_1_path_1 = self.model.match_parent_path(reshape_qkv_1, ["Concat", "Unsqueeze", "Gather"], [1, 0, 0])
        reshape_qkv_1_path_2 = self.model.match_parent_path(reshape_qkv_1, ["Concat", "Unsqueeze", "Gather"], [1, 2, 0])
        if reshape_qkv_1_path_1 is None or reshape_qkv_1_path_2 is None:
            return False
        if reshape_qkv_1_path_1[-1].name != gather_1.name or reshape_qkv_1_path_2[-1].name != gather_2.name:
            return False

        reshape_q_2_path = self.model.match_parent_path(reshape_q_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0])
        reshape_k_2_path = self.model.match_parent_path(reshape_k_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0])
        reshape_v_2_path = self.model.match_parent_path(reshape_v_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0])
        if reshape_q_2_path is None or reshape_k_2_path is None or reshape_v_2_path is None:
            return False

        mul_q = reshape_q_2_path[-1]
        mul_k = reshape_k_2_path[-1]
        mul_v = reshape_v_2_path[-1]

        gather_1_out = gather_1.output[0]
        if mul_q.input[0] != gather_1_out or mul_k.input[0] != gather_1_out or mul_v.input[0] != gather_1_out:
            return False

        return True

    def check_runtime_shape_path_openai(
        self,
        reshape_qkv_2,
        matmul_qkv,
        add_qk,
        matmul_qk,
        add_q,
    ):
        """Check the runtime shape subgraphs of the OpenAI Whisper pattern.

        The following conditions must hold:

        * The output Reshape reads its shape from ``Shape(matmul_qkv output)``.
        * Both inputs of ``matmul_qk`` are scaled by Mul nodes that share the same scale tensor, computed
          via ``Pow(Cast(Div(Gather(Shape(...)))))``.
        * At least one of those Shape nodes reads the Q projection output (``add_q``).

        When ``add_qk`` is given (decoder attention), its Slice input must additionally be derived from
        ``add_q`` through the two matched Unsqueeze/Gather/Shape paths.

        Returns:
            True if every shape path matches, False otherwise.
        """
        reshape_qkv_2_path = self.model.match_parent_path(
            reshape_qkv_2, ["Concat", "Slice", "Gather", "Shape"], [1, 0, 0, 0]
        )
        if reshape_qkv_2_path is None or reshape_qkv_2_path[-1].input[0] != matmul_qkv.output[0]:
            return False

        matmul_qk_path_1 = self.model.match_parent_path(
            matmul_qk, ["Mul", "Pow", "Cast", "Div", "Gather", "Shape"], [0, 1, 0, 0, 0, 0]
        )
        matmul_qk_path_2 = self.model.match_parent_path(
            matmul_qk, ["Mul", "Pow", "Cast", "Div", "Gather", "Shape"], [1, 1, 0, 0, 0, 0]
        )
        if matmul_qk_path_1 is None or matmul_qk_path_2 is None:
            return False

        mul_1 = matmul_qk_path_1[0]
        mul_2 = matmul_qk_path_2[0]
        # Q and K must be scaled by the same scale tensor.
        if mul_1.input[1] != mul_2.input[1]:
            return False
        # At least one of the scale computations must take its shape from the Q projection output.
        if matmul_qk_path_1[-1].input[0] != add_q.output[0] and matmul_qk_path_2[-1].input[0] != add_q.output[0]:
            return False

        # For decoder attentions only
        if add_qk is not None:
            add_qk_path = self.model.match_parent_path(add_qk, ["Slice"], [1])
            if add_qk_path is None:
                return False
            slice_q_path_1 = self.model.match_parent_path(
                add_qk_path[0], ["Slice", "Unsqueeze", "Gather", "Shape"], [0, 2, 0, 0]
            )
            slice_q_path_2 = self.model.match_parent_path(add_qk_path[0], ["Unsqueeze", "Gather", "Shape"], [2, 0, 0])
            # Both paths are unpacked below, so both must have matched.
            if slice_q_path_1 is None or slice_q_path_2 is None:
                return False
            _, unsqueeze_1, _, _ = slice_q_path_1
            unsqueeze_2, _, _ = slice_q_path_2
            if unsqueeze_1.input[0] != unsqueeze_2.input[0]:
                return False
            if slice_q_path_1[-1].input[0] != add_q.output[0] and slice_q_path_2[-1].input[0] != add_q.output[0]:
                return False

        return True

    def _get_past_present_from_graph_input(self, nodes, graph_output_names):
        """Return the ``(past, present)`` tensor names for a matched cross-attention K/V path.

        ``nodes[-1]`` is the node that consumes the past tensor (its first input). Its output becomes
        ``present`` when it is a graph output. Otherwise, if exactly one Identity node consumes the past
        tensor, that Identity's output is used. If neither applies, ``present`` is "".
        """
        past = nodes[-1].input[0]
        present = nodes[-1].output[0]
        if present not in graph_output_names:
            identity_nodes = [node for node in self.model.input_name_to_nodes()[past] if node.op_type == "Identity"]
            present = identity_nodes[0].output[0] if len(identity_nodes) == 1 else ""
        return past, present

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        """Fuse the attention subgraph that feeds ``normalize_node``, if a supported pattern is found.

        Args:
            normalize_node: The (Skip)LayerNormalization node at the end of the attention block.
            input_name_to_nodes: Mapping from tensor name to the list of nodes consuming it.
            output_name_to_node: Mapping from tensor name to the node producing it.

        On success:

        * The new Attention/MultiHeadAttention node is appended to ``self.nodes_to_add``.
        * The replaced nodes are appended to ``self.nodes_to_remove``.
        * ``self.prune_graph`` is set.

        If any pattern check fails, the method returns early without adding a fused node.
        """
        # Track if fusion is occurring for OpenAI implementation of Whisper
        model_impl_openai = False

        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
        qkv_nodes = self.model.match_parent_path(
            normalize_node,
            ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
            [1, 1, 0, 0, 0, 0],
        )
        qkv_nodes_openai = self.model.match_parent_path(
            normalize_node,
            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 1, 0, 0, 0],
        )
        if qkv_nodes is not None:
            (
                add_out,
                matmul_out,
                reshape_qkv_2,
                transpose_qkv,
                reshape_qkv_1,
                matmul_qkv,
            ) = qkv_nodes
        elif qkv_nodes_openai is not None:
            qkv_nodes = qkv_nodes_openai
            (
                add_out,
                matmul_out,
                reshape_qkv_2,
                transpose_qkv,
                matmul_qkv,
            ) = qkv_nodes
            # Set model implementation to openai
            model_impl_openai = True
        else:
            return

        # The root input is the normalize_node input that is produced by some node and is not the output
        # of the attention block's final Add.
        other_inputs = [
            input_name
            for input_name in normalize_node.input
            if input_name in output_name_to_node and input_name != qkv_nodes[0].output[0]
        ]
        if len(other_inputs) != 1:
            return
        root_input = other_inputs[0]

        # Sometimes the input name to the attention MatMul nodes does not match the input name to the end
        # SkipLayerNormalization node (name saved in root_input). We find the true input name to the MatMul
        # nodes by getting the initial SkipLayerNormalization node and checking how many MatMul nodes are
        # children nodes for each of its output names.
        #
        #                                 root_input
        #             +---------------------------------------------------+
        #             |                                                   |
        #             |                                                   |
        # SkipLayerNormalization --> Attention --> MatMul --> SkipLayerNormalization
        skip_layernorm = output_name_to_node[root_input]
        # For some attention blocks, the end SkipLayerNormalization node may point to an Add node whose
        # child is the LayerNormalization node.
        if skip_layernorm.op_type == "Add":
            skip_layernorm = self.model.get_children(skip_layernorm)[0]
        for output in skip_layernorm.output:
            if not output:
                continue
            # An output with no consumers (e.g. only a graph output) is simply not a MatMul input.
            children = input_name_to_nodes.get(output, [])
            if any(child.op_type == "MatMul" for child in children):
                root_input = output
                break

        graph_input_names = {node.name for node in self.model.graph().input}
        graph_output_names = {node.name for node in self.model.graph().output}

        # ---------------------------------------------------------------- V path
        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Reshape", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 0, 0, 0, None],
        )
        v_nodes_openai = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "Add", "MatMul"],
            [1, 0, 0, None],
        )
        v_nodes_with_past_self_attn = self.model.match_parent_path(
            # Decoder attention with past value concatenated before MatMul
            matmul_qkv,
            ["Reshape", "Concat", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 0, 1, 0, 0, None],
        )
        v_nodes_with_past_cross_attn = self.model.match_parent_path(
            # Decoder attention with past value directly used in MatMul
            matmul_qkv,
            ["Reshape"],
            [1],
        )
        v_nodes_with_past_cross_attn_openai = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "Reshape", "Transpose"],
            [1, 0, 0, 0],
        )
        past_v, present_v = "", ""
        # matmul_v stays None for the cross-attention-with-past patterns, where V comes from a graph input.
        reshape_v_2, add_v, matmul_v = None, None, None
        if v_nodes is not None:
            (reshape_v_2, transpose_v, _, add_v, matmul_v) = v_nodes
            # For initial pass through encoder-decoder_with_past to get starting past values (beam search)
            present_v = transpose_v.output[0]
        elif v_nodes_openai is not None:
            v_nodes = v_nodes_openai
            (_, reshape_v_1, add_v, matmul_v) = v_nodes
            # For initial pass through encoder-decoder_with_past to get starting past values (beam search)

            # Find the child path to access the correct present_v values
            # Openai impl provides present/past v values in 3D format
            # whereas ort MultiHeadAttention expects v values in 4D, hence the
            # additional Reshape and Transpose nodes are added
            # For encoder attention types
            # Add -> Reshape -> Transpose -> Present_V
            reshape_path = self.model.match_child_path(
                add_v,
                ["Reshape", "Transpose"],
                exclude=[reshape_v_1],
            )
            # For decoder attention types
            # add_v_node                     Reshape <- Transpose <-Past_V
            #           \                  /
            #             \              /
            #               -> Concat <-
            #                    |
            #                    |--> Reshape -> Transpose -> Present_V
            concat_path = self.model.match_child_path(add_v, ["Concat", "Reshape", "Transpose"])
            if reshape_path is not None:
                (_, transpose_add_v) = reshape_path
                if transpose_add_v.output[0] in graph_output_names:
                    present_v = transpose_add_v.output[0]
            if concat_path is not None:
                (concat_v, _, transpose_concat_v) = concat_path
                if transpose_concat_v.output[0] in graph_output_names:
                    present_v = transpose_concat_v.output[0]
                concat_nodes = self.model.match_parent_path(concat_v, ["Reshape", "Transpose"], [0, 0])
                if concat_nodes is None:
                    logger.debug("fuse_attention: failed to match past v path")
                    return
                _, transpose_concat_v_in = concat_nodes
                past_v = transpose_concat_v_in.input[0]
        elif v_nodes_with_past_self_attn is not None:
            (reshape_v_2, concat_v, _, _, add_v, matmul_v) = v_nodes_with_past_self_attn
            v_nodes = v_nodes_with_past_self_attn
            past_v = concat_v.input[0]
            present_v = concat_v.output[0]
        elif (
            v_nodes_with_past_cross_attn is not None and v_nodes_with_past_cross_attn[-1].input[0] in graph_input_names
        ):
            v_nodes = v_nodes_with_past_cross_attn
            past_v, present_v = self._get_past_present_from_graph_input(v_nodes, graph_output_names)
        elif (
            v_nodes_with_past_cross_attn_openai is not None
            and v_nodes_with_past_cross_attn_openai[-1].input[0] in graph_input_names
        ):
            v_nodes = v_nodes_with_past_cross_attn_openai
            past_v, present_v = self._get_past_present_from_graph_input(v_nodes, graph_output_names)
        else:
            logger.debug("fuse_attention: failed to match v path")
            return
        past_v = past_v if past_v in graph_input_names else ""
        present_v = present_v if present_v in graph_output_names else ""

        # ---------------------------------------------------------------- QK path
        qk_nodes_1 = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
        qk_nodes_2 = self.model.match_parent_path(
            matmul_qkv, ["Softmax", "Reshape", "Add", "Reshape", "MatMul"], [0, 0, 0, 0, 0]
        )
        qk_nodes_2_openai = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
        add_qk = None
        if qk_nodes_1 is not None:
            _, matmul_qk = qk_nodes_1
            qk_nodes = qk_nodes_1
        elif qk_nodes_2 is not None:
            _, _, add_qk, _, matmul_qk = qk_nodes_2
            qk_nodes = qk_nodes_2
        elif qk_nodes_2_openai is not None:
            _, add_qk, matmul_qk = qk_nodes_2_openai
            qk_nodes = qk_nodes_2_openai
        else:
            return

        # ---------------------------------------------------------------- Q path
        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"],
            [0, 0, 0, 0, 0, 1],
        )
        q_nodes_openai = self.model.match_parent_path(
            matmul_qk,
            ["Mul", "Transpose", "Reshape", "Add", "MatMul"],
            [0, 0, 0, 0, 1],
        )
        reshape_q_2 = None
        if q_nodes is not None:
            reshape_q_2, _, reshape_q_1, _, add_q, matmul_q = q_nodes
        elif q_nodes_openai is not None:
            q_nodes = q_nodes_openai
            _, _, reshape_q_1, add_q, matmul_q = q_nodes
        else:
            return

        # ---------------------------------------------------------------- K path
        k_nodes_with_bias = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 0, 0, 0, 0, 1],
        )
        k_nodes_with_bias_openai = self.model.match_parent_path(
            matmul_qk,
            ["Mul", "Transpose", "Reshape", "MatMul"],
            [1, 0, 0, 0],
        )
        k_nodes_no_bias = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"],
            [1, 0, 0, 0, 0],
        )
        k_nodes_no_bias_with_past_self_attn = self.model.match_parent_path(
            # Decoder attention with past key concatenated before MatMul
            matmul_qk,
            ["Transpose", "Reshape", "Concat", "Transpose", "Reshape", "MatMul"],
            [1, 0, 0, 1, 0, 0],
        )
        k_nodes_no_bias_with_past_cross_attn = self.model.match_parent_path(
            # Decoder attention with past key directly used in MatMul
            matmul_qk,
            ["Transpose", "Reshape"],
            [1, 0],
        )
        k_nodes_no_bias_with_past_cross_attn_openai = self.model.match_parent_path(
            # Decoder attention with past key directly used in MatMul
            matmul_qk,
            ["Mul", "Transpose", "Reshape", "Reshape", "Transpose"],
            [1, 0, 0, 0, 0],
        )
        past_k, present_k = "", ""
        # matmul_k stays None for the cross-attention-with-past patterns, where K comes from a graph input.
        reshape_k_2, reshape_k_1, matmul_k, add_k = None, None, None, None
        if k_nodes_with_bias is not None:
            _, reshape_k_2, _, reshape_k_1, add_k, matmul_k = k_nodes_with_bias
            k_nodes = k_nodes_with_bias
        elif k_nodes_with_bias_openai is not None:
            _, _, reshape_k_1, matmul_k = k_nodes_with_bias_openai
            k_nodes = k_nodes_with_bias_openai
            present_k = matmul_k.output[0]

            # Find the child path to access the correct present_k values
            # Openai impl provides present/past k values in 3D format
            # whereas ort MultiHeadAttention expects k values in 4D, hence the
            # additional Reshape and Transpose nodes are added
            # For encoder attention types
            # Matmul -> Reshape -> Transpose -> Present_K
            reshape_path = self.model.match_child_path(
                matmul_k,
                ["Reshape", "Transpose"],
                exclude=[reshape_k_1],
            )
            # For decoder attention types
            # matmul_k_node                  Reshape <- Transpose <- Past_K
            #           \                  /
            #             \              /
            #               -> Concat <-
            #                    |
            #                    |--> Reshape -> Transpose -> Present_K
            concat_path = self.model.match_child_path(matmul_k, ["Concat", "Reshape", "Transpose"])
            if reshape_path is not None:
                (_, transpose_matmul_k) = reshape_path
                if transpose_matmul_k.output[0] in graph_output_names:
                    present_k = transpose_matmul_k.output[0]
            if concat_path is not None:
                (concat_k, _, transpose_concat_k) = concat_path
                if transpose_concat_k.output[0] in graph_output_names:
                    present_k = transpose_concat_k.output[0]
                concat_nodes = self.model.match_parent_path(concat_k, ["Reshape", "Transpose"], [0, 0])
                if concat_nodes is None:
                    logger.debug("fuse_attention: failed to match past k path")
                    return
                _, transpose_concat_k_in = concat_nodes
                past_k = transpose_concat_k_in.input[0]
        elif k_nodes_no_bias is not None:
            _, reshape_k_2, transpose_k_1, reshape_k_1, matmul_k = k_nodes_no_bias
            k_nodes = k_nodes_no_bias
            # For initial pass through encoder-decoder_with_past to get starting past values (beam search)
            present_k = transpose_k_1.output[0]
        elif k_nodes_no_bias_with_past_self_attn is not None:
            _, reshape_k_2, concat_k, _, reshape_k_1, matmul_k = k_nodes_no_bias_with_past_self_attn
            k_nodes = k_nodes_no_bias_with_past_self_attn
            past_k = concat_k.input[0]
            present_k = concat_k.output[0]
        elif (
            k_nodes_no_bias_with_past_cross_attn is not None
            and k_nodes_no_bias_with_past_cross_attn[-1].input[0] in graph_input_names
        ):
            k_nodes = k_nodes_no_bias_with_past_cross_attn
            past_k, present_k = self._get_past_present_from_graph_input(k_nodes, graph_output_names)
        elif (
            k_nodes_no_bias_with_past_cross_attn_openai is not None
            and k_nodes_no_bias_with_past_cross_attn_openai[-1].input[0] in graph_input_names
        ):
            k_nodes = k_nodes_no_bias_with_past_cross_attn_openai
            past_k, present_k = self._get_past_present_from_graph_input(k_nodes, graph_output_names)
        else:
            return
        past_k = past_k if past_k in graph_input_names else ""
        present_k = present_k if present_k in graph_output_names else ""

        if k_nodes in (k_nodes_with_bias_openai, k_nodes_no_bias, k_nodes_no_bias_with_past_self_attn):
            # These K patterns have no bias Add. Create an Add with a zero-valued bias (sized like V's bias)
            # so the node builders below can treat Q/K/V uniformly. This Add node is only passed to those
            # builders; it is not added to the graph here.
            bias_initializer = None
            if add_v is not None:
                # The bias may be on either input of the V Add (the MatMul side is matched with index None).
                for add_v_input in add_v.input:
                    bias_initializer = self.model.get_initializer(add_v_input)
                    if bias_initializer is not None:
                        break
            if bias_initializer is None:
                logger.debug("fuse_attention: failed to find v bias initializer")
                return
            bias_dim = bias_initializer.dims[0]
            empty_bias_name = "empty_bias"
            empty_tensor = self.model.get_initializer(empty_bias_name)
            if empty_tensor is None:
                self.add_initializer(
                    empty_bias_name,
                    TensorProto.FLOAT,
                    dims=[bias_dim],
                    vals=np.array([0.0] * bias_dim, dtype=np.float32),
                )

            add_name = self.model.create_node_name("Add")
            add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k_1.name], add_name)

        # Without past K, validate that the runtime shape computations belong to this attention block.
        if (
            model_impl_openai
            and not past_k
            and not self.check_runtime_shape_path_openai(
                reshape_qkv_2,
                matmul_qkv,
                add_qk,
                matmul_qk,
                add_q,
            )
        ):
            return
        elif (
            not model_impl_openai
            and not past_k
            and not self.check_runtime_shape_path(
                reshape_qkv_2,
                reshape_qkv_1,
                reshape_q_2,
                reshape_k_2,
                reshape_v_2,
                root_input,
            )
        ):
            return

        # Cross attention with past reads K and V directly from graph inputs, so no K/V MatMul was matched.
        three_root_inputs = bool(past_k and past_v and matmul_k is None and matmul_v is None)
        if not three_root_inputs and (matmul_k is None or matmul_v is None):
            # K or V came from a graph input, but the pattern is not the supported three-input case.
            logger.debug("fuse_attention: unsupported combination of k and v paths")
            return
        one_root_input = (
            not three_root_inputs
            and matmul_k.input[0] == root_input
            and matmul_q.input[0] == root_input
            and matmul_v.input[0] == root_input
        )
        two_root_inputs = (
            not three_root_inputs
            and matmul_q.input[0] == root_input
            and matmul_k.input[0] == matmul_v.input[0]
            and matmul_k.input[0] != matmul_q.input[0]
        )

        # There are 5 types of attention:
        # 1) Encoder attention with one_root_input=True and qk_nodes=qk_nodes_1
        # 2) Decoder attention with one_root_input=True and qk_nodes=qk_nodes_2
        # 3) Decoder attention with past with one_root_input=True and qk_nodes=qk_nodes_1 and past_k=past_decoder_key and past_v=past_decoder_value
        # 4) Decoder cross attention with two_root_inputs=True and qk_nodes=qk_nodes_1
        # 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_1
        encoder_attention = one_root_input and qk_nodes == qk_nodes_1
        decoder_attention = one_root_input and qk_nodes in (qk_nodes_2, qk_nodes_2_openai)
        decoder_attention_with_past = (
            (encoder_attention if not model_impl_openai else decoder_attention) and past_k and past_v
        )
        decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_1
        decoder_cross_attention_with_past = three_root_inputs and qk_nodes == qk_nodes_1

        # For decoder_attention, the attention mask needs to be included in the attention node
        mask_index = None
        if decoder_attention:
            mask_nodes_bart = self.model.match_parent_path(
                add_qk,
                ["Where"],
                [1],
            )
            mask_nodes_whisper = self.model.match_parent_path(
                add_qk,
                ["Expand", "Unsqueeze", "Unsqueeze", "Where"],
                [1, 0, 0, 0],
            )
            if mask_nodes_whisper is not None:
                mask_index = mask_nodes_whisper[0].output[-1]
            elif mask_nodes_bart is not None:
                mask_index = mask_nodes_bart[0].output[-1]

        if not (
            encoder_attention
            or decoder_attention
            or decoder_attention_with_past
            or decoder_cross_attention
            or decoder_cross_attention_with_past
        ):
            return

        attention_last_node = reshape_qkv_2
        num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q_1)

        if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
            logger.debug("fuse_attention: failed to detect num_heads or hidden_size")
            return

        use_mha = decoder_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past
        new_node = None
        if use_mha:
            # Note: Decoder attention with past key and past value is fused as multihead attention
            # rather than attention because multihead attention supports separate past key and past
            # value whereas attention supports concatenated past key and past value.
            new_node = (
                self.create_multihead_attention_node(
                    matmul_q,
                    matmul_k if decoder_cross_attention or decoder_attention_with_past else past_k,
                    matmul_v if decoder_cross_attention or decoder_attention_with_past else past_v,
                    add_q,
                    add_k if decoder_cross_attention or decoder_attention_with_past else None,
                    add_v if decoder_cross_attention or decoder_attention_with_past else None,
                    num_heads,
                    hidden_size,
                    attention_last_node.output[0],
                    past_k=past_k if decoder_attention_with_past else "",
                    past_v=past_v if decoder_attention_with_past else "",
                    present_k=present_k,
                    present_v=present_v,
                    packed_qkv=decoder_attention_with_past,
                )
                if self.use_multi_head_attention
                else None
            )
        else:
            # Temporarily set multihead attention flag to false so an Attention node is created; restore it
            # even if node creation raises.
            use_multi_head_attention_ground_truth = self.use_multi_head_attention
            self.use_multi_head_attention = False
            try:
                new_node = self.create_attention_node(
                    None,
                    matmul_q,
                    matmul_k,
                    matmul_v,
                    add_q,
                    add_k,
                    add_v,
                    num_heads,
                    hidden_size,
                    root_input,
                    attention_last_node.output[0],
                    add_qk_str=mask_index if decoder_attention else None,
                    past_k=past_k,
                    past_v=past_v,
                    present_k=present_k,
                    present_v=present_v,
                )
            finally:
                self.use_multi_head_attention = use_multi_head_attention_ground_truth
        if new_node is None:
            return

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
        self.nodes_to_remove.extend(qk_nodes)

        # When using multihead attention, keep MatMul nodes in original graph
        if use_mha:
            if q_nodes[-1].op_type == "MatMul":
                q_nodes.pop()
            if k_nodes[-1].op_type == "MatMul":
                k_nodes.pop()
            if v_nodes[-1].op_type == "MatMul":
                v_nodes.pop()
            # With bias disabled for cross attention, the bias Add nodes also stay in the graph.
            if self.disable_multi_head_attention_bias and (
                decoder_cross_attention or decoder_cross_attention_with_past
            ):
                if q_nodes[-1].op_type == "Add":
                    q_nodes.pop()
                if k_nodes[-1].op_type == "Add":
                    k_nodes.pop()
                if v_nodes[-1].op_type == "Add":
                    v_nodes.pop()

        self.nodes_to_remove.extend(q_nodes)
        self.nodes_to_remove.extend(k_nodes)
        self.nodes_to_remove.extend(v_nodes)

        # Use prune graph to remove mask nodes since they are shared by all attention nodes.
        self.prune_graph = True
