# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from typing import Optional, Union

from fusion_attention import FusionAttention
from fusion_base import Fusion
from onnx import FunctionProto, NodeProto, TensorProto, helper, numpy_helper
from onnx_model import OnnxModel

logger = logging.getLogger(__name__)


class FusionRotaryAttention(FusionAttention):
    """
    Fuse Attention subgraph with rotary positional embeddings into one MultiHeadAttention node.
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
    ):
        """
        Configure the rotary-attention fusion.

        The fused result is always emitted as a MultiHeadAttention node, so the base
        attention fusion is asked to use multi-head attention.

        Args:
            model: The ONNX model wrapper to fuse in place.
            hidden_size: Hidden size of the model (0 or negative skips the divisibility check
                in `create_mha_node`).
            num_heads: Number of attention heads.
        """
        # Node types handed to the base Fusion as search anchors; presumably the base class
        # calls `fuse` on nodes of these types (confirm in fusion_base). Note that `fuse`
        # itself returns early unless the anchor is a Skip*LayerNormalization or an Add.
        anchor_op_types = [
            "SimplifiedLayerNormalization",
            "SkipSimplifiedLayerNormalization",
            "LayerNormalization",
            "SkipLayerNormalization",
            "Add",
        ]
        super().__init__(
            model,
            hidden_size,
            num_heads,
            use_multi_head_attention=True,
            search_op_types=anchor_op_types,
        )

    def create_mha_node(
        self,
        input: str,
        output: str,
        q_rotary: NodeProto,
        k_rotary: NodeProto,
        v_matmul: NodeProto,
        attn_mask: str = "",
        add_qk: str = "",
        past_k: str = "",
        past_v: str = "",
        present_k: str = "",
        present_v: str = "",
        scale: Optional[float] = None,
    ) -> Union[NodeProto, None]:
        """
        Build a com.microsoft MultiHeadAttention node fed by the rotary q/k outputs and v.

        The node is only created and returned; it is not inserted into the graph here.

        Args:
            input: Not read by this method; kept for interface compatibility.
            output: Name of the attention output tensor.
            q_rotary: Node whose first output is used as the query input.
            k_rotary: Node whose first output is used as the key input.
            v_matmul: Node whose first output is used as the value input.
            attn_mask: Tensor name for the key_padding_mask slot ("" leaves it empty).
            add_qk: Tensor name for the attention_bias slot ("" leaves it empty).
            past_k: Tensor name for the past key input ("" leaves it empty).
            past_v: Tensor name for the past value input ("" leaves it empty).
            present_k: Present key output name; only emitted together with `present_v`.
            present_v: Present value output name; only emitted together with `present_k`.
            scale: If given, written as the node's `scale` attribute.

        Returns:
            The new node, or None when hidden_size is positive but not a multiple of num_heads.
        """
        assert self.num_heads > 0

        hidden = self.hidden_size
        if hidden > 0 and hidden % self.num_heads != 0:
            logger.debug(
                f"fuse_rotary_attention: input hidden size {self.hidden_size} is not a multiple of num of heads {self.num_heads}"
            )
            return None

        node_name = self.model.create_node_name("MultiHeadAttention")

        query = q_rotary.output[0]
        key = k_rotary.output[0]
        value = v_matmul.output[0]
        # Positional slots: query, key, value, bias (unused), key_padding_mask,
        # attention_bias, past_key, past_value. Empty strings keep later slots aligned.
        inputs = [query, key, value, "", attn_mask, add_qk, past_k, past_v]

        # Present key/value outputs are added only when both names are provided.
        outputs = [output, present_k, present_v] if present_k and present_v else [output]

        attributes = [helper.make_attribute("num_heads", self.num_heads)]
        if scale is not None:
            attributes.append(helper.make_attribute("scale", scale))
        if self.mask_filter_value is not None:
            attributes.append(helper.make_attribute("mask_filter_value", float(self.mask_filter_value)))

        mha_node = helper.make_node(
            "MultiHeadAttention",
            inputs=inputs,
            outputs=outputs,
            name=node_name,
            domain="com.microsoft",
        )
        mha_node.attribute.extend(attributes)

        self.increase_counter("MultiHeadAttention")
        return mha_node

    def check_runtime_shape_paths_for_function(
        self,
        reshape_qkv_2,  # Reshape after Transpose
        reshape_qkv_1,  # Reshape before Transpose
        reshape_q_2,  # Reshape after RotaryEmbedding
        reshape_k_2,  # Reshape after RotaryEmbedding
        reshape_v_2,  # Reshape after Transpose
        reshape_v_1,  # Reshape before Transpose
        add_qk,  # Add before Softmax
        root_input,  # Root input to attention subgraph
    ) -> bool:
        """
        Verify the runtime shape-computation subgraphs of a function-style attention pattern.

        Each Reshape in the pattern takes its target shape (input 1) from a Concat. The
        Concat inputs are built from `Shape(root_input) --> Gather` nodes, sometimes with an
        intermediate Mul or Add. This method checks three things:

        - The qkv Reshape's two Gather nodes both read `Shape(root_input)`.
        - Every q/k/v shape path reuses those same two Gather nodes.
        - The mask reaching `add_qk` is sliced out of a tensor named `attn_mask` or
          `attention_mask`, with consistent Slice bounds.

        NOTE(review): the parameter names match the nodes unpacked from `qkv_nodes_1` /
        `v_nodes_1` in `fuse`. Presumably this is used for that pattern; confirm at the call site.

        Args:
            reshape_qkv_2: Reshape after the output Transpose.
            reshape_qkv_1: Reshape before the output Transpose.
            reshape_q_2: Reshape after the query RotaryEmbedding.
            reshape_k_2: Reshape after the key RotaryEmbedding.
            reshape_v_2: Reshape after the value Transpose.
            reshape_v_1: Reshape before the value Transpose.
            add_qk: Add before Softmax; its input 1 carries the attention mask.
            root_input: Name of the root input tensor of the attention subgraph.

        Returns:
            True if every expected path exists and is consistent, otherwise False.
        """
        # Check #1: check paths for qkv nodes
        concat_qkv_2_path = self.model.match_parent_path(reshape_qkv_2, ["Concat"], [1])
        concat_qkv_1_path = self.model.match_parent_path(reshape_qkv_1, ["Concat"], [1])
        if concat_qkv_2_path is None or concat_qkv_1_path is None:
            return False
        concat_qkv_2, concat_qkv_1 = concat_qkv_2_path[0], concat_qkv_1_path[0]

        # reshape_qkv_1's Concat uses inputs 0 and 2 (not 0 and 1) for its runtime dimensions.
        reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
        reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
        reshape_qkv_1_path_1 = self.model.match_parent_path(concat_qkv_1, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
        reshape_qkv_1_path_2 = self.model.match_parent_path(concat_qkv_1, ["Unsqueeze", "Gather", "Shape"], [2, 0, 0])
        if (
            reshape_qkv_2_path_1 is None
            or reshape_qkv_2_path_2 is None
            or reshape_qkv_1_path_1 is None
            or reshape_qkv_1_path_2 is None
        ):
            return False

        # gather_1 / gather_2 are the reference Gather nodes that every other shape path
        # below must reuse.
        _, gather_1, shape_1 = reshape_qkv_2_path_1
        _, gather_2, shape_2 = reshape_qkv_2_path_2

        # Check root_input --> Shape --> Gather connection
        if shape_1.input[0] != root_input or shape_2.input[0] != root_input:
            return False

        # NOTE(review): node identity is compared by `.name` here and below. That assumes
        # node names are unique and non-empty in the model; confirm the optimizer guarantees this.
        # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_qkv_1_path_1 and reshape_qkv_1_path_2
        if reshape_qkv_1_path_1[1].name != gather_1.name or reshape_qkv_1_path_2[1].name != gather_2.name:
            return False

        # Check #2: check paths for v nodes
        concat_v_2_path = self.model.match_parent_path(reshape_v_2, ["Concat"], [1])
        concat_v_1_path = self.model.match_parent_path(reshape_v_1, ["Concat"], [1])
        if concat_v_2_path is None or concat_v_1_path is None:
            return False
        concat_v_2, concat_v_1 = concat_v_2_path[0], concat_v_1_path[0]

        # The v Reshape after Transpose derives its dims through Mul / Add of the Gather outputs.
        reshape_v_2_path_1 = self.model.match_parent_path(
            concat_v_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0]
        )
        reshape_v_2_path_2 = self.model.match_parent_path(
            concat_v_2, ["Unsqueeze", "Add", "Gather", "Shape"], [1, 0, 0, 0]
        )
        reshape_v_1_path_1 = self.model.match_parent_path(concat_v_1, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
        reshape_v_1_path_2 = self.model.match_parent_path(concat_v_1, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
        if (
            reshape_v_2_path_1 is None
            or reshape_v_2_path_2 is None
            or reshape_v_1_path_1 is None
            or reshape_v_1_path_2 is None
        ):
            return False

        # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_v_2_path_1
        # Check Gather --> Add --> Unsqueeze --> Concat --> Reshape connection for reshape_v_2_path_2
        # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_v_1_path_1 and reshape_v_1_path_2
        if (
            reshape_v_2_path_1[2].name != gather_1.name
            or reshape_v_2_path_2[2].name != gather_2.name
            or reshape_v_1_path_1[1].name != gather_1.name
            or reshape_v_1_path_2[1].name != gather_2.name
        ):
            return False

        # Check #3: check paths for k nodes
        concat_k_2_path = self.model.match_parent_path(reshape_k_2, ["Concat"], [1])
        if concat_k_2_path is None:
            return False
        concat_k_2 = concat_k_2_path[0]

        # For k, the Add-derived dimension sits at Concat input 2 rather than input 1.
        reshape_k_2_path_1 = self.model.match_parent_path(
            concat_k_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0]
        )
        reshape_k_2_path_2 = self.model.match_parent_path(
            concat_k_2, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 0, 0]
        )
        if reshape_k_2_path_1 is None or reshape_k_2_path_2 is None:
            return False

        # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_k_2_path_1
        # Check Gather --> Add --> Unsqueeze --> Concat --> Reshape connection for reshape_k_2_path_2
        if reshape_k_2_path_1[2].name != gather_1.name or reshape_k_2_path_2[2].name != gather_2.name:
            return False

        # Check #4: check paths for q nodes
        concat_q_2_path = self.model.match_parent_path(reshape_q_2, ["Concat"], [1])
        if concat_q_2_path is None:
            return False
        concat_q_2 = concat_q_2_path[0]

        reshape_q_2_path_1 = self.model.match_parent_path(
            concat_q_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0]
        )
        reshape_q_2_path_2 = self.model.match_parent_path(concat_q_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
        if reshape_q_2_path_1 is None or reshape_q_2_path_2 is None:
            return False

        # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_q_2_path_1
        # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_q_2_path_2
        if reshape_q_2_path_1[2].name != gather_1.name or reshape_q_2_path_2[1].name != gather_2.name:
            return False

        # Check #5: check Mul nodes are the same for q, k, v
        # This compares tensor names (the Gather output), not node names, so it still holds
        # if node names are not unique.
        mul_q = reshape_q_2_path_1[1]
        mul_k = reshape_k_2_path_1[1]
        mul_v = reshape_v_2_path_1[1]
        gather_1_out = gather_1.output[0]
        if mul_q.input[0] != gather_1_out or mul_k.input[0] != gather_1_out or mul_v.input[0] != gather_1_out:
            return False

        # Check #6: check paths for attention mask nodes
        # The mask reaches add_qk either directly from the Concat, or through a Cast after it.
        attn_mask_path_1 = self.model.match_parent_path(add_qk, ["Concat", "Slice", "Slice"], [1, 0, 0])
        attn_mask_path_2 = self.model.match_parent_path(add_qk, ["Cast", "Concat", "Slice", "Slice"], [1, 0, 0, 0])
        if attn_mask_path_1 is not None:
            _, slice_qk_2, slice_qk_1 = attn_mask_path_1
        elif attn_mask_path_2 is not None:
            _, _, slice_qk_2, slice_qk_1 = attn_mask_path_2
        else:
            return False
        # Check first input to Slice #1 is 3D attention mask of shape (B,S,T)
        if slice_qk_1.input[0] not in {"attn_mask", "attention_mask"}:
            return False

        # Both Slices must take their input-2 bound from Unsqueeze(Add(..., Gather(Shape(...)))).
        slice_qk_2_path = self.model.match_parent_path(
            slice_qk_2, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 1, 0]
        )
        slice_qk_1_path_1 = self.model.match_parent_path(
            slice_qk_1, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 1, 0]
        )
        # Slice #1's input-1 bound comes from its own Unsqueeze.
        slice_qk_1_path_2 = self.model.match_parent_path(slice_qk_1, ["Unsqueeze"], [1])
        if slice_qk_2_path is None or slice_qk_1_path_1 is None or slice_qk_1_path_2 is None:
            return False

        # Check Gather --> Add --> Unsqueeze #3 --> Slice #2 connection for slice_qk_2_path
        # Check Gather --> Add --> Unsqueeze #2 --> Slice #1 connection for slice_qk_1_path_1
        if slice_qk_2_path[1].name != slice_qk_1_path_1[1].name or slice_qk_2_path[2].name != slice_qk_1_path_1[2].name:
            return False

        # Check Unsqueeze #1 --> Slice #1 connection for slice_qk_1_path_2
        # Check if first input to Add and Unsqueeze #1 is position ids
        if slice_qk_1_path_1[1].input[0] != slice_qk_1_path_2[0].input[0]:
            return False

        return True

    def check_runtime_shape_paths_for_nodes(
        self,
        reshape_qkv,  # Final reshape before o_proj MatMul
        reshape_q,  # Reshape before q_proj MatMul
        reshape_k,  # Reshape before k_proj MatMul
        reshape_v,  # Reshape before v_proj MatMul
        root_input,  # Root input to attention subgraph
    ):
        """
        Verify that the q/k/v/qkv Reshapes share one runtime shape computation.

        Each Reshape is expected to take its target shape (input 1) from a Concat. Concat
        inputs 0 and 1 must each come from an `Unsqueeze <-- Gather <-- Shape` chain. The two
        Gather nodes found for `reshape_qkv` must read `Shape(root_input)`. The v, k and q
        Reshapes (checked in that order) must reuse exactly those two Gather nodes.

        Returns:
            bool: True when all paths match and are consistent, False otherwise.
        """

        def _unsqueeze_gather_shape_chains(reshape):
            # Return the two Unsqueeze/Gather/Shape chains feeding the Concat that builds
            # `reshape`'s target shape, or None as soon as any piece fails to match.
            concat_match = self.model.match_parent_path(reshape, ["Concat"], [1])
            if concat_match is None:
                return None
            concat = concat_match[0]
            first = self.model.match_parent_path(concat, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0])
            second = self.model.match_parent_path(concat, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0])
            if first is None or second is None:
                return None
            return first, second

        # Check #1: the qkv Reshape defines the reference Gather nodes.
        qkv_chains = _unsqueeze_gather_shape_chains(reshape_qkv)
        if qkv_chains is None:
            return False
        (_, gather_1, shape_1), (_, gather_2, shape_2) = qkv_chains

        # Both reference Gathers must read the shape of the subgraph's root input.
        if not (shape_1.input[0] == root_input and shape_2.input[0] == root_input):
            return False

        # Checks #2-#4: v, k and q must reuse the same two Gather nodes (compared by name).
        reference_names = (gather_1.name, gather_2.name)
        for reshape in (reshape_v, reshape_k, reshape_q):
            chains = _unsqueeze_gather_shape_chains(reshape)
            if chains is None:
                return False
            first, second = chains
            if (first[1].name, second[1].name) != reference_names:
                return False

        return True

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        if normalize_node.op_type not in {"SkipSimplifiedLayerNormalization", "SkipLayerNormalization", "Add"}:
            return

        # qkv_nodes_1 is for LLaMA-2 Microsoft
        # qkv_nodes_2 is for LLaMA-2 Hugging Face
        # qkv_nodes_3 is for LLaMA-2 distribute Hugging Face model
        qkv_nodes = None
        qkv_nodes_1 = self.model.match_parent_path(
            normalize_node,
            ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
            [1, 0, 0, 0, 0],
        )
        qkv_nodes_2 = self.model.match_parent_path(
            normalize_node,
            ["MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 0, 0, 0],
        )
        qkv_nodes_3 = self.model.match_parent_path(
            normalize_node,
            ["AllReduce", "MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 0, 0, 0, 0],
        )
        if qkv_nodes_1 is not None:
            _, reshape_qkv_2, _, reshape_qkv_1, matmul_qkv = qkv_nodes_1
            qkv_nodes = qkv_nodes_1
        elif qkv_nodes_2 is not None:
            _, reshape_qkv, _, matmul_qkv = qkv_nodes_2
            qkv_nodes = qkv_nodes_2
        elif qkv_nodes_3 is not None:
            _, _, reshape_qkv, _, matmul_qkv = qkv_nodes_3
            qkv_nodes = qkv_nodes_3
        else:
            logger.debug("fuse_rotary_attention: failed to match qkv nodes")
            return

        # v_nodes_1 is for LLaMA-2 Microsoft
        # v_nodes_3 is for LLaMA-2 Hugging Face
        # v_nodes_4 is for LLaMA-2 70B model
        # v_nodes_5 is for Phi-2 DirectML
        past_v, present_v, past_seq_len = "", "", ""
        v_nodes = None
        add_v = None
        v_nodes_1 = self.model.match_parent_path(
            matmul_qkv,
            ["Reshape", "Transpose", "Concat", "Transpose", "Reshape", "MatMul"],
            [1, 0, 0, 1, 0, 0],
        )
        v_nodes_2 = self.model.match_parent_path(
            matmul_qkv,
            ["Concat", "Transpose", "Reshape", "MatMul"],
            [1, 1, 0, 0],
        )
        v_nodes_3 = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "MatMul"],
            [1, 0, 0],
        )
        _, v_nodes_4, _ = self.model.match_parent_paths_all(
            matmul_qkv,
            [
                (
                    ["Reshape", "Expand", "Unsqueeze", "Concat", "Transpose", "Reshape", "MatMul"],
                    [1, 0, 0, 0, 1, 0, 0],
                ),
                (
                    [
                        "Reshape",
                        "Expand",
                        "Where",
                        "Equal",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                        "Concat",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
                ),
                (
                    [
                        "Reshape",
                        "Expand",
                        "Where",
                        "Equal",
                        "Mul",
                        "ConstantOfShape",
                        "Shape",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                        "Concat",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0],
                ),
                (
                    [
                        "Reshape",
                        "Expand",
                        "Where",
                        "ConstantOfShape",
                        "Shape",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                        "Concat",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0],
                ),
                (
                    [
                        "Reshape",
                        "Expand",
                        "Where",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                        "Concat",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0],
                ),
                (
                    ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
                    [1, 1, 0, 0, 0, 0, 1, 0, 0],
                ),
                (
                    [
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Mul",
                        "Gather",
                        "Shape",
                        "Concat",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 1, 1, 0, 0, 0, 0, 1, 0, 0],
                ),
                (
                    ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
                    [1, 1, 2, 0, 0, 0, 1, 0, 0],
                ),
                (
                    ["Reshape", "Concat", "Unsqueeze", "Gather", "Shape", "Concat", "Transpose", "Reshape", "MatMul"],
                    [1, 1, 3, 0, 0, 0, 1, 0, 0],
                ),
            ],
            output_name_to_node=None,
        )
        v_nodes_5 = self.model.match_parent_path(
            matmul_qkv,
            ["Concat", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 1, 0, 0, 1],
        )
        if v_nodes_1 is not None:
            reshape_v_2, _, concat_v, _, reshape_v_1, matmul_v = v_nodes_1
            v_nodes = v_nodes_1

            concat_v_path = self.model.match_parent_path(
                concat_v,
                ["Slice", "Unsqueeze"],
                [0, 2],
            )
            if concat_v_path is None:
                logger.debug("fuse_rotary_attention: failed to match past/present concat in v path")
                return

            past_v = concat_v_path[0].input[0]
            past_seq_len = concat_v_path[-1].input[0]
            present_v = concat_v.output[0]
        elif v_nodes_2 is not None:
            concat_v, transpose_v, reshape_v, matmul_v = v_nodes_2
            v_nodes = v_nodes_2
            past_v = concat_v.input[0]
            present_v = concat_v.output[0]
        elif v_nodes_3 is not None:
            transpose_v, reshape_v, matmul_v = v_nodes_3
            v_nodes = v_nodes_3
            present_v = transpose_v.output[0]
        elif v_nodes_4 is not None and len(v_nodes_4) == 9:
            concat_v, transpose_v, reshape_v, matmul_v = v_nodes_4[0][-4:]
            v_nodes = v_nodes_4
            past_v = concat_v.input[0]
            present_v = concat_v.output[0]
        elif v_nodes_5 is not None:
            concat_v, transpose_v, reshape_v, add_v, matmul_v = v_nodes_5
            matmul_v = add_v
            v_nodes = v_nodes_5
            past_v = concat_v.input[0]
            present_v = concat_v.output[0]
        else:
            logger.debug("fuse_rotary_attention: failed to match v path")
            return

        qk_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Softmax", "Add", "Div", "MatMul"],
            [0, 0, 0, 0],
        )
        add_qk, matmul_qk = None, None
        if qk_nodes is not None:
            _, add_qk, _, matmul_qk = qk_nodes
        else:
            logger.debug("fuse_rotary_attention: failed to match qk nodes")
            return

        # attn_mask_nodes_1, attn_mask_nodes_2 are for LLaMA-2 Microsoft's 3D attention mask
        # attn_mask_nodes_3, attn_mask_nodes_4 are for LLaMA-2 Hugging Face's 2D attention mask
        # attn_mask_nodes_5, attn_mask_nodes_6 are for LLaMA-2 Microsoft's model for the DML EP
        # attn_mask_nodes_7 is for LLaMA-2 Hugging Face's changes to the attention mask
        attn_mask, add_qk_str = "", ""
        attn_mask_nodes_1 = self.model.match_parent_path(
            add_qk,
            ["Concat", "Slice", "Slice"],
            [1, 0, 0],
        )
        attn_mask_nodes_2 = self.model.match_parent_path(
            add_qk,
            ["Cast", "Concat", "Slice", "Slice"],
            [1, 0, 0, 0],
        )
        attn_mask_nodes_3 = self.model.match_parent_path(
            add_qk,
            ["Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
            [1, 0, 2, 1, 0, 0, 0],
        )
        attn_mask_nodes_4 = self.model.match_parent_path(
            add_qk,
            ["Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
            [1, 2, 1, 0, 0, 0],
        )
        attn_mask_nodes_5 = self.model.match_parent_path(
            add_qk,
            ["Expand", "Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
            [1, 0, 0, 2, 1, 0, 0, 0],
        )
        attn_mask_nodes_6 = self.model.match_parent_path(
            add_qk,
            ["Expand", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
            [1, 0, 2, 1, 0, 0, 0],
        )
        attn_mask_nodes_7 = self.model.match_parent_path(
            add_qk,
            ["Where", "Cast", "Where", "Cast", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"],
            [1, 0, 0, 0, 0, 1, 0, 0, 0],
        )
        if attn_mask_nodes_1 is not None:
            _, slice_mask_1, slice_mask_2 = attn_mask_nodes_1
            attn_mask = slice_mask_1.output[0]
        elif attn_mask_nodes_2 is not None:
            _, _, slice_mask_1, slice_mask_2 = attn_mask_nodes_2
            attn_mask = slice_mask_1.output[0]
        elif attn_mask_nodes_3 is not None:
            # Reshape from (B,1,S,T) to (B,N,S,T)
            add_qk_str = self.reshape_add_qk(attn_mask_nodes_3[0].output[0])
        elif attn_mask_nodes_4 is not None:
            # Reshape from (B,1,S,T) to (B,N,S,T)
            add_qk_str = self.reshape_add_qk(attn_mask_nodes_4[0].output[0])
        elif attn_mask_nodes_5 is not None:
            # The mask has already been reshaped to (B,N,S,T)
            add_qk_str = attn_mask_nodes_5[0].output[0]
        elif attn_mask_nodes_6 is not None:
            # The mask has already been reshaped to (B,N,S,T)
            add_qk_str = attn_mask_nodes_6[0].output[0]
        elif attn_mask_nodes_7 is not None:
            # Reshape from (B,1,S,T) to (B,N,S,T)
            add_qk_str = self.reshape_add_qk(attn_mask_nodes_7[0].output[0])
        else:
            logger.debug("fuse_rotary_attention: failed to match attention mask nodes")
            return

        # k_nodes_1 is for LLaMA-2 Microsoft
        # k_nodes_2 is for LLaMA-2 Hugging Face
        # k_nodes_4 is for LLaMA-2 70B Hugging Face
        past_k, present_k = "", ""
        k_nodes = None
        slice_k = None
        concat_k_half = None
        k_nodes_1 = self.model.match_parent_path(
            matmul_qk,
            ["Reshape", "Transpose", "Concat", "Transpose", "RotaryEmbedding", "MatMul"],
            [1, 0, 0, 1, 0, 0],
        )
        k_nodes_2 = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
            [1, 0, 0, 0, 0],
        )
        k_nodes_3 = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
            [1, 0, 1, 0, 0, 0],
        )
        _, k_nodes_4, _ = self.model.match_parent_paths_all(
            matmul_qk,
            [
                (
                    [
                        "Transpose",
                        "Reshape",
                        "Expand",
                        "Unsqueeze",
                        "Concat",
                        "RotaryEmbedding",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    [
                        "Transpose",
                        "Reshape",
                        "Expand",
                        "Where",
                        "Equal",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                        "Concat",
                        "RotaryEmbedding",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    [
                        "Transpose",
                        "Reshape",
                        "Expand",
                        "Where",
                        "Equal",
                        "Mul",
                        "ConstantOfShape",
                        "Shape",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                        "Concat",
                        "RotaryEmbedding",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    [
                        "Transpose",
                        "Reshape",
                        "Expand",
                        "Where",
                        "ConstantOfShape",
                        "Shape",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                        "Concat",
                        "RotaryEmbedding",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 0, 1, 1, 0, 0, 0, 3, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    [
                        "Transpose",
                        "Reshape",
                        "Expand",
                        "Where",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                        "Concat",
                        "RotaryEmbedding",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 0, 1, 2, 0, 4, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    [
                        "Transpose",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                        "Concat",
                        "RotaryEmbedding",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    [
                        "Transpose",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Mul",
                        "Gather",
                        "Shape",
                        "Concat",
                        "RotaryEmbedding",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    [
                        "Transpose",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                        "Concat",
                        "RotaryEmbedding",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 1, 2, 0, 0, 0, 1, 0, 0, 0],
                ),
                (
                    [
                        "Transpose",
                        "Reshape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                        "Concat",
                        "RotaryEmbedding",
                        "Transpose",
                        "Reshape",
                        "MatMul",
                    ],
                    [1, 0, 1, 3, 0, 0, 0, 1, 0, 0, 0],
                ),
            ],
            output_name_to_node=None,
        )
        k_nodes_5 = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Concat", "Concat", "RotaryEmbedding", "Slice", "Transpose", "Reshape", "Add", "MatMul"],
            [1, 0, 1, 0, 0, 0, 0, 0, 1],
        )
        if k_nodes_1 is not None:
            reshape_k_2, _, concat_k, _, rotary_k, matmul_k = k_nodes_1
            k_nodes = k_nodes_1

            concat_k_path = self.model.match_parent_path(
                concat_k,
                ["Slice", "Unsqueeze"],
                [0, 2],
            )
            if concat_k_path is None:
                logger.debug("fuse_rotary_attention: failed to match past/present concat in k path")
                return

            past_k = concat_k_path[0].input[0]
            shared_past_seq_len = concat_k_path[-1].input[0]
            present_k = concat_k.output[0]

            assert past_seq_len == shared_past_seq_len
        elif k_nodes_2 is not None:
            _, rotary_k, _, reshape_k, matmul_k = k_nodes_2
            k_nodes = k_nodes_2
            present_k = rotary_k.output[0]
        elif k_nodes_3 is not None:
            _, concat_k, rotary_k, _, reshape_k, matmul_k = k_nodes_3
            k_nodes = k_nodes_3
            past_k = concat_k.input[0]
            present_k = concat_k.output[0]
        elif k_nodes_4 is not None and len(k_nodes_4) == 9:
            reshape_k, matmul_k = k_nodes_4[0][-2:]
            concat_k, rotary_k = k_nodes_4[0][-5:-3]
            k_nodes = k_nodes_4
            past_k = concat_k.input[0]
            present_k = concat_k.output[0]
        elif k_nodes_5 is not None:
            _, concat_k, concat_k_half, rotary_k, slice_k, _, reshape_k, _, matmul_k = k_nodes_5
            k_nodes = k_nodes_5
            past_k = concat_k.input[0]
            present_k = concat_k.output[0]
        else:
            logger.debug("fuse_rotary_attention: failed to match k nodes")
            return

        # q_nodes_1 is for LLaMA-2 Microsoft
        # q_nodes_2 is for LLaMA-2 Hugging Face
        # q_nodes_3 is for Phi-2 DirectML
        q_nodes = None
        slice_q = None
        concat_q_half = None
        q_nodes_1 = self.model.match_parent_path(
            matmul_qk,
            ["Reshape", "Transpose", "RotaryEmbedding", "MatMul"],
            [0, 0, 0, 0],
        )
        q_nodes_2 = self.model.match_parent_path(
            matmul_qk,
            ["RotaryEmbedding", "Transpose", "Reshape", "MatMul"],
            [0, 0, 0, 0],
        )
        q_nodes_3 = self.model.match_parent_path(
            matmul_qk,
            ["Concat", "RotaryEmbedding", "Slice", "Transpose", "Reshape", "Add", "MatMul"],
            [0, 0, 0, 0, 0, 0, 1],
        )
        if q_nodes_1 is not None:
            reshape_q_2, _, rotary_q, matmul_q = q_nodes_1
            q_nodes = q_nodes_1
        elif q_nodes_2 is not None:
            rotary_q, _, reshape_q, matmul_q = q_nodes_2
            q_nodes = q_nodes_2
        elif q_nodes_3 is not None:
            concat_q_half, rotary_q, slice_q, _, reshape_q, _, matmul_q = q_nodes_3
            q_nodes = q_nodes_3
        else:
            logger.debug("fuse_rotary_attention: failed to match q nodes")
            return

        if matmul_q.input[0] != matmul_k.input[0] and matmul_k.input[0] != matmul_v.input[0]:
            logger.debug("fuse_rotary_attention: failed to find the same root_input for q, k, v paths")
            return

        root_output = ""
        if qkv_nodes == qkv_nodes_1:
            if not self.check_runtime_shape_paths_for_function(
                reshape_qkv_2,
                reshape_qkv_1,
                reshape_q_2,
                reshape_k_2,
                reshape_v_2,
                reshape_v_1,
                add_qk,
                matmul_q.input[0],
            ):
                logger.debug("fuse_rotary_attention: failed to verify runtime shape paths")
                return
            root_output = reshape_qkv_2.output[0]

        elif qkv_nodes in (qkv_nodes_2, qkv_nodes_3):
            if not self.check_runtime_shape_paths_for_nodes(
                reshape_qkv,
                reshape_q,
                reshape_k,
                reshape_v,
                matmul_q.input[0],
            ):
                logger.debug("fuse_rotary_attention: failed to verify runtime shape paths")
                return
            root_output = reshape_qkv.output[0]

            # Rename inputs of rotary_q/k so it connects with output of matmul_q/k
            # Before: MatMul --> Reshape --> Transpose --> RotaryEmbedding
            # After: MatMul --> RotaryEmbedding
            rotary_q.input[0] = slice_q.output[0] if slice_q else matmul_q.output[0]
            rotary_k.input[0] = slice_k.output[0] if slice_k else matmul_k.output[0]

            # Rename current output of rotary_k (present_key) so it doesn't match output of MHA (present_key)
            if concat_q_half is None:
                rotary_k.output[0] = rotary_k.name + "_output_0"

            if qkv_nodes == qkv_nodes_3:
                qkv_nodes = qkv_nodes[1:]

        def create_hidden_size_concat_node(reshape_q):
            """Detect num_heads and hidden_size for ONNX model from phi-2
            Args:
                reshape_q (NodeProto): reshape node for q
            Returns:
                hidden_size_concat_node(NodeProto): Concat node to be used by reshape
            """
            concat = self.model.match_parent(reshape_q, "Concat", 1)

            if concat is None:
                logger.debug("fuse_rotary_attention: failed to trace the concat node from reshape_q")
                return None

            # The shape is a tensor like [?, ?, num_heads, head_size]
            num_head_constant_node = self.model.get_constant_value(concat.input[2])
            head_size_constant_node = self.model.get_constant_value(concat.input[3])

            if num_head_constant_node is None or head_size_constant_node is None:
                logger.debug("fuse_rotary_attention: failed to get constant nodes of num_heads or head_size")
                return None

            num_head_value = num_head_constant_node[0]
            head_size_value = head_size_constant_node[0]

            hidden_size = num_head_value * head_size_value

            hidden_size_initilizer = self.model.create_node_name("Initializer", name_prefix="hidden_size")
            if self.model.get_initializer(hidden_size_initilizer) is None:
                self.add_initializer(
                    name=hidden_size_initilizer,
                    data_type=TensorProto.INT64,
                    dims=[1],
                    vals=[hidden_size],
                    raw=False,
                )

            hidden_size_reshape_node_name = self.model.create_node_name("Concat", name_prefix="hidden_size_concat")

            hidden_size_concat_node = helper.make_node(
                "Concat",
                inputs=[
                    concat.input[0],
                    concat.input[1],
                    hidden_size_initilizer,
                ],
                outputs=[hidden_size_reshape_node_name + "output_0"],
                name=hidden_size_reshape_node_name,
            )
            hidden_size_concat_node.attribute.extend([helper.make_attribute("axis", 0)])

            return hidden_size_concat_node

        # Add Tranpose and Reshape nodes for patial rotary embedding applied in phi-2 before passing into MHA
        if concat_q_half and concat_k_half:
            # Transpose the key output of rotary Embedding
            k_transpose_node_name = self.model.create_node_name("Transpose")
            k_tranpose_output_name = k_transpose_node_name + "_output_0"
            k_transpose_node = helper.make_node(
                "Transpose",
                inputs=[concat_k_half.output[0]],
                outputs=[k_tranpose_output_name],
                name=k_transpose_node_name,
            )

            k_transpose_node.attribute.extend([helper.make_attribute("perm", [0, 2, 1, 3])])

            # Transpose the query output of rotary Embedding
            q_transpose_node_name = self.model.create_node_name("Transpose")
            q_tranpose_output_name = q_transpose_node_name + "_output_0"
            q_transpose_node = helper.make_node(
                "Transpose",
                inputs=[concat_q_half.output[0]],
                outputs=[q_tranpose_output_name],
                name=q_transpose_node_name,
            )

            q_transpose_node.attribute.extend([helper.make_attribute("perm", [0, 2, 1, 3])])

            hidden_size_concat_node = create_hidden_size_concat_node(reshape_k)
            if hidden_size_concat_node is None:
                logger.debug("fuse_rotary_attention: failed to create hidden_size_concat_node")
                return

            # Reshape the Rotary Embedding output for key for 4D to 3D
            concat_k_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="concat_k_half")
            concat_k_reshape_node = helper.make_node(
                "Reshape",
                inputs=[k_transpose_node.output[0], hidden_size_concat_node.output[0]],
                outputs=[concat_k_reshape_node_name + "_output_0"],
                name=concat_k_reshape_node_name,
            )

            # Reshape the Rotary Embedding output for query from 4D to 3D
            concat_q_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="concat_q_half")
            concat_q_reshape_node = helper.make_node(
                "Reshape",
                inputs=[q_transpose_node.output[0], hidden_size_concat_node.output[0]],
                outputs=[concat_q_reshape_node_name + "_output_0"],
                name=concat_q_reshape_node_name,
            )

            rotary_k = concat_k_reshape_node
            rotary_q = concat_q_reshape_node

            self.nodes_to_add.append(hidden_size_concat_node)
            self.nodes_to_add.append(k_transpose_node)
            self.nodes_to_add.append(q_transpose_node)
            self.nodes_to_add.append(concat_k_reshape_node)
            self.nodes_to_add.append(concat_q_reshape_node)

            self.node_name_to_graph_name[hidden_size_concat_node.name] = self.this_graph_name
            self.node_name_to_graph_name[k_transpose_node.name] = self.this_graph_name
            self.node_name_to_graph_name[q_transpose_node.name] = self.this_graph_name
            self.node_name_to_graph_name[concat_k_reshape_node.name] = self.this_graph_name
            self.node_name_to_graph_name[concat_q_reshape_node.name] = self.this_graph_name

        new_node = self.create_mha_node(
            matmul_q.input[0],
            root_output,
            rotary_q,
            rotary_k,
            matmul_v,
            attn_mask,
            add_qk_str,
            past_k,
            past_v,
            present_k,
            present_v,
        )
        if new_node is None:
            logger.debug("fuse_rotary_attention: failed to create multi-head attention with rotary embeddings")
            return

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        self.nodes_to_remove.extend(qkv_nodes[1:])

        if v_nodes != v_nodes_4:
            self.nodes_to_remove.extend(v_nodes[:-1] if add_v is None else v_nodes[:-2])
        else:
            nodes_to_keep = [v_nodes[0][-1]]
            for temp_path in v_nodes:
                self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep)

        self.nodes_to_remove.extend(qk_nodes)

        if k_nodes == k_nodes_1:
            self.nodes_to_remove.extend(k_nodes[:-2])
        elif k_nodes == k_nodes_2:
            self.nodes_to_remove.append(k_nodes[0])
            self.nodes_to_remove.append(k_nodes[2])
            self.nodes_to_remove.append(k_nodes[3])
        elif k_nodes == k_nodes_3:
            self.nodes_to_remove.append(k_nodes[0])
            self.nodes_to_remove.append(k_nodes[1])
            self.nodes_to_remove.append(k_nodes[3])
            self.nodes_to_remove.append(k_nodes[4])
        elif k_nodes == k_nodes_5:
            self.nodes_to_remove.append(k_nodes[0])
            self.nodes_to_remove.append(k_nodes[1])
        elif k_nodes == k_nodes_4:
            nodes_to_keep = [k_nodes[0][-1], k_nodes[0][-4]]
            for temp_path in k_nodes:
                self.add_nodes_to_remove_with_nodes_to_keep(temp_path, nodes_to_keep)

        if q_nodes == q_nodes_1:
            self.nodes_to_remove.extend(q_nodes[:-2])
        elif q_nodes == q_nodes_2:
            self.nodes_to_remove.append(q_nodes[1])
            self.nodes_to_remove.append(q_nodes[2])
        self.prune_graph = True


class FusionRotaryEmbeddings(Fusion):
    def __init__(self, model: OnnxModel):
        # Op type of the fused node. ``fuse`` also uses it to recognize exported RotaryEmbedding
        # functions by substring match, so names such as "RotaryEmbedding.1" qualify as well.
        self.base_name = "RotaryEmbedding"
        # "Add" is the starting node of the unfused rotate_half subgraph handled in ``fuse``.
        search_op_types = [self.base_name, f"{self.base_name}.1", "Add"]
        super().__init__(model, self.base_name, search_op_types)

    # The RotaryEmbedding function can have multiple extraneous constant outputs even though it is supposed to produce only one output.
    # This is likely a byproduct of a common subexpression elimination (CSE) bug when using `export_modules_as_functions` in the TorchScript exporter.
    # As a workaround, each extraneous constant value is stored as an initializer, and the nodes that consume it read that initializer instead.
    def reassign_extra_outputs(self, rot_emb_node: NodeProto, function: FunctionProto):
        """Turn extraneous constant outputs of a RotaryEmbedding function call into initializers.

        Every Constant node in ``function`` that has no inputs and whose output is also a function
        output corresponds (by position) to one output of ``rot_emb_node``. For each such output, a
        copy of the Constant's tensor is added to the model as a uniquely named initializer, and every
        graph node that consumed the output is rewired to read the initializer instead.

        Args:
            rot_emb_node (NodeProto): node in the main graph that calls ``function``.
            function (FunctionProto): definition of the function called by ``rot_emb_node``.

        Returns:
            list[str]: names of the ``rot_emb_node`` outputs identified as constant outputs. The caller
            removes these from the outputs of the fused node.
        """
        function_outputs = list(function.output)
        extra_outputs = []
        for fn_node in function.node:
            if (
                fn_node.op_type != "Constant"
                or len(fn_node.input) > 0
                or fn_node.output[0] not in function_outputs
            ):
                continue

            output_index = function_outputs.index(fn_node.output[0])
            if output_index >= len(rot_emb_node.output):
                # The call site does not bind this output, so nothing in the graph can consume it.
                continue

            extra_output = rot_emb_node.output[output_index]
            extra_outputs.append(extra_output)
            if not extra_output:
                # An empty name marks an unused output. Matching "" against node inputs would instead
                # hit every omitted optional input in the graph, so there is nothing to rewire.
                continue

            # Copy the tensor rather than renaming it in place, so the FunctionProto (which may be
            # shared by several call sites) is left untouched.
            initializer = TensorProto()
            initializer.CopyFrom(fn_node.attribute[0].t)
            initializer.name = self.model.create_node_name("Constant")
            self.model.add_initializer(initializer)

            # Point every consumer of the constant output at the new initializer.
            for graph_node in self.model.model.graph.node:
                if extra_output in graph_node.input:
                    OnnxModel.replace_node_input(graph_node, extra_output, initializer.name)

        return extra_outputs

    def create_rotary_embeddings_from_function(self, node: NodeProto):
        """Create a com.microsoft RotaryEmbedding node that replaces a RotaryEmbedding function call.

        Expected pattern around ``node``:
            input 0: Reshape <- MatMul   (the new node reads the MatMul output; the Reshape is removed)
            input 1: position ids
            input 2: cos cache produced by a Constant node
            input 3: sin cache produced by a Constant node

        The caches are stored as FLOAT initializers with the fixed names "cos_cache" and "sin_cache".
        The first successful call creates them from the Constant nodes; later calls reuse them. Extra
        constant outputs of the function are converted by ``reassign_extra_outputs`` so the new node
        has a single output.

        Args:
            node (NodeProto): the RotaryEmbedding function call node.

        Returns:
            NodeProto | None: the new RotaryEmbedding node (``interleaved=1``), or None when the parent
            Reshape/MatMul path does not match or the cos/sin caches cannot be resolved.
        """
        rotary_emb_node_name = self.model.create_node_name(self.base_name)

        matmul_path = self.model.match_parent_path(
            node,
            ["Reshape", "MatMul"],
            [0, 0],
        )
        if matmul_path is None:
            logger.debug("fuse_rotary_embeddings: failed to match MatMul")
            return None
        reshape_node, matmul_node = matmul_path

        rotary_emb_inputs = [
            matmul_node.output[0],  # x is of shape (B,S,D) instead of (B,S,N,H)
            node.input[1],  # position_ids
        ]

        # Convert cos_cache and sin_cache from Constant node attributes to model initializers.
        # Only Constant producers are considered, since their attribute[0] is read as the cache tensor.
        graph_nodes = self.model.model.graph.node
        cos_cache_node = [
            n for n in graph_nodes if n.op_type == "Constant" and n.output and n.output[0] == node.input[2]
        ]
        sin_cache_node = [
            n for n in graph_nodes if n.op_type == "Constant" and n.output and n.output[0] == node.input[3]
        ]
        cos_cache_name, sin_cache_name = "cos_cache", "sin_cache"

        if (
            len(cos_cache_node) == 1
            and len(sin_cache_node) == 1
            and self.model.get_initializer(cos_cache_name) is None
            and self.model.get_initializer(sin_cache_name) is None
        ):
            cos_cache = numpy_helper.to_array(cos_cache_node[0].attribute[0].t).squeeze()
            sin_cache = numpy_helper.to_array(sin_cache_node[0].attribute[0].t).squeeze()

            # Store as FLOAT directly from numpy instead of round-tripping through Python lists.
            cos_cache_tensor = numpy_helper.from_array(cos_cache.astype("float32"), cos_cache_name)
            self.model.add_initializer(cos_cache_tensor, self.this_graph_name)
            sin_cache_tensor = numpy_helper.from_array(sin_cache.astype("float32"), sin_cache_name)
            self.model.add_initializer(sin_cache_tensor, self.this_graph_name)

            self.nodes_to_remove.extend([cos_cache_node[0], sin_cache_node[0]])

        # Without a tensor of that name, the fused node would read an undefined input; give up instead.
        for cache_name in (cos_cache_name, sin_cache_name):
            if self.model.get_initializer(cache_name) is None and self.model.find_graph_input(cache_name) is None:
                logger.debug(f"fuse_rotary_embeddings: failed to find {cache_name}")
                return None

        rotary_emb_inputs.extend([cos_cache_name, sin_cache_name])

        rotary_emb_outputs = list(node.output)
        if len(rotary_emb_outputs) > 1:
            # Re-assign extraneous constant outputs in RotaryEmbedding functions as initializers
            funcs = [fn for fn in self.model.model.functions if fn.name == node.op_type]
            assert len(funcs) == 1
            extra_outputs = self.reassign_extra_outputs(node, funcs[0])
            rotary_emb_outputs = [name for name in rotary_emb_outputs if name not in extra_outputs]
            assert len(rotary_emb_outputs) == 1

        rotary_emb_node = helper.make_node(
            self.base_name,
            inputs=rotary_emb_inputs,
            outputs=rotary_emb_outputs,
            name=rotary_emb_node_name,
            interleaved=1,
        )
        rotary_emb_node.domain = "com.microsoft"

        self.nodes_to_remove.append(reshape_node)

        return rotary_emb_node

    def create_rotary_embeddings_from_nodes(
        self,
        root_input: str,
        position_ids: str,
        cos_slice: str,
        sin_slice: str,
        output: str,
    ):
        """Create a non-interleaved com.microsoft RotaryEmbedding node for a matched rotate_half subgraph.

        The cos/sin caches are stored as FLOAT initializers with the fixed names "cos_cache" and
        "sin_cache". The first successful call creates them from the Constant nodes producing
        ``cos_slice``/``sin_slice`` (keeping only the first half of the last axis); later calls reuse them.

        Args:
            root_input (str): name of the tensor the rotary embedding is applied to.
            position_ids (str): name of the position ids tensor.
            cos_slice (str): name of the full cos cache tensor (in ``fuse``, the data input of the cos Slice).
            sin_slice (str): name of the full sin cache tensor (in ``fuse``, the data input of the sin Slice).
            output (str): output name of the new node.

        Returns:
            NodeProto | None: the new RotaryEmbedding node (``interleaved=0``), or None when the cos/sin
            caches cannot be resolved.
        """
        rotary_emb_node_name = self.model.create_node_name(self.base_name)

        # Convert cos_cache and sin_cache from Constant node attributes to model initializers.
        # Only Constant producers are considered, since their attribute[0] is read as the cache tensor.
        graph_nodes = self.model.model.graph.node
        cos_cache_node = [
            n for n in graph_nodes if n.op_type == "Constant" and n.output and n.output[0] == cos_slice
        ]
        sin_cache_node = [
            n for n in graph_nodes if n.op_type == "Constant" and n.output and n.output[0] == sin_slice
        ]
        cos_cache_name, sin_cache_name = "cos_cache", "sin_cache"

        if (
            len(cos_cache_node) == 1
            and len(sin_cache_node) == 1
            and self.model.get_initializer(cos_cache_name) is None
            and self.model.get_initializer(sin_cache_name) is None
        ):
            cos_cache = numpy_helper.to_array(cos_cache_node[0].attribute[0].t).squeeze()
            sin_cache = numpy_helper.to_array(sin_cache_node[0].attribute[0].t).squeeze()

            # Slice cos/sin cache from (M, H) to (M, H/2) by keeping the first half of the last axis.
            # NOTE(review): presumes both halves are identical (frequencies concatenated with
            # themselves, as in the Hugging Face rotary embedding); confirm for new model sources.
            head_size = cos_cache.shape[1]
            cos_cache = cos_cache[:, : (head_size // 2)]
            sin_cache = sin_cache[:, : (head_size // 2)]

            # Store as FLOAT directly from numpy instead of round-tripping through Python lists.
            cos_cache_tensor = numpy_helper.from_array(cos_cache.astype("float32"), cos_cache_name)
            self.model.add_initializer(cos_cache_tensor, self.this_graph_name)
            sin_cache_tensor = numpy_helper.from_array(sin_cache.astype("float32"), sin_cache_name)
            self.model.add_initializer(sin_cache_tensor, self.this_graph_name)

            self.nodes_to_remove.extend([cos_cache_node[0], sin_cache_node[0]])

        # Without a tensor of that name, the fused node would read an undefined input; give up instead.
        for cache_name in (cos_cache_name, sin_cache_name):
            if self.model.get_initializer(cache_name) is None and self.model.find_graph_input(cache_name) is None:
                logger.debug(f"fuse_rotary_embeddings: failed to find {cache_name}")
                return None

        rotary_emb_node = helper.make_node(
            self.base_name,
            inputs=[root_input, position_ids, cos_cache_name, sin_cache_name],
            outputs=[output],
            name=rotary_emb_node_name,
            interleaved=0,
        )
        rotary_emb_node.domain = "com.microsoft"
        return rotary_emb_node

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        # Node is either RotaryEmbedding function or Add
        if self.base_name not in node.op_type and node.op_type != "Add":
            return

        # Check if node is "RotaryEmbedding nn.Module" exported as a function
        # (e.g. export_modules_as_functions={RotaryEmbedding} in torch.onnx.export)
        rotary_emb_node = None
        if node.op_type != "Add":
            # Verify that function has the correct inputs
            if len(node.input) not in {4, 5} or node.input[1] not in {
                "pos",
                "pos_id",
                "position_id",
                "pos_ids",
                "position_ids",
            }:
                logger.debug("fuse_rotary_embeddings: failed to verify inputs for RotaryEmbedding function")
                return

            rotary_emb_node = self.create_rotary_embeddings_from_function(node)
            if rotary_emb_node is None:
                logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node")
                return

            # Remove RotaryEmbedding function
            self.nodes_to_remove.append(node)

            # Remove RotaryEmbedding function's shape inference stored in value_info
            # The new shape will be calculated during symbolic shape inference
            old_shape_infer = list(
                filter(lambda node: node.name == rotary_emb_node.output[0], self.model.model.graph.value_info)
            )
            assert len(old_shape_infer) == 1
            self.model.model.graph.value_info.remove(old_shape_infer[0])

        else:
            # Rotary embeddings are defined using the below functions:
            #
            # def rotate_half(x):
            #     """Rotates half the hidden dims of the input."""
            #     x1 = x[..., : x.shape[-1] // 2]
            #     x2 = x[..., x.shape[-1] // 2 :]
            #     return torch.cat((-x2, x1), dim=-1)
            #
            # def apply_rope(x, cos, sin, position_ids):
            #     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
            #     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
            #     cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
            #     sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
            #     x_embed = (x * cos) + (rotate_half(x) * sin)
            #     return x_embed

            # Check paths for rotate_half(x)
            rotate_half_x2_path_1_1 = self.model.match_parent_path(
                node,
                ["Mul", "Concat", "Neg", "Slice", "Transpose"],
                [1, 0, 0, 0, 0],
            )

            rotate_half_x2_path_1_2 = self.model.match_parent_path(
                node,
                ["Mul", "Concat", "Neg", "Slice", "Slice"],
                [1, 0, 0, 0, 0],
            )

            rotate_half_x2_path_1 = rotate_half_x2_path_1_1 or rotate_half_x2_path_1_2

            rotate_half_x2_path_2_1 = self.model.match_parent_path(
                node,
                ["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"],
                [1, 0, 0, 0, 1, 0, 0, 0, 0],
            )

            rotate_half_x2_path_2_2 = self.model.match_parent_path(
                node,
                ["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Slice"],
                [1, 0, 0, 0, 1, 0, 0, 0, 0],
            )

            rotate_half_x2_path_2 = rotate_half_x2_path_2_1 or rotate_half_x2_path_2_2

            if rotate_half_x2_path_1 is None or rotate_half_x2_path_2 is None:
                logger.debug("fuse_rotary_embeddings: failed to match x2 in rotate_half")
                return

            rotate_half_x1_path_1_1 = self.model.match_parent_path(
                node,
                ["Mul", "Concat", "Slice", "Transpose"],
                [1, 0, 1, 0],
            )

            rotate_half_x1_path_1_2 = self.model.match_parent_path(
                node,
                ["Mul", "Concat", "Slice", "Slice"],
                [1, 0, 1, 0],
            )

            rotate_half_x1_path_1 = rotate_half_x1_path_1_1 or rotate_half_x1_path_1_2

            rotate_half_x1_path_2_1 = self.model.match_parent_path(
                node,
                ["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"],
                [1, 0, 1, 2, 0, 0, 0, 0],
            )

            rotate_half_x1_path_2_2 = self.model.match_parent_path(
                node,
                ["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Slice"],
                [1, 0, 1, 2, 0, 0, 0, 0],
            )

            rotate_half_x1_path_2 = rotate_half_x1_path_2_1 or rotate_half_x1_path_2_2

            if rotate_half_x1_path_1 is None or rotate_half_x1_path_2 is None:
                logger.debug("fuse_rotary_embeddings: failed to match x1 in rotate_half")
                return

            if (
                rotate_half_x1_path_1[-1].name != rotate_half_x1_path_2[-1].name
                or rotate_half_x2_path_1[-1].name != rotate_half_x2_path_2[-1].name
                or rotate_half_x1_path_1[-1].name != rotate_half_x2_path_1[-1].name
                or rotate_half_x1_path_2[-1].name != rotate_half_x2_path_2[-1].name
            ):
                logger.debug("fuse_rotary_embeddings: failed to match common input in rotate_half")
                return

            # Check path for x
            x_path_1 = self.model.match_parent_path(
                node,
                ["Mul", "Transpose"],
                [0, 0],
            )

            x_path_2 = self.model.match_parent_path(
                node,
                ["Mul", "Slice"],
                [0, 0],
            )

            x_path = x_path_1 or x_path_2

            if x_path is None:
                logger.debug("fuse_rotary_embeddings: failed to match x in rotate_half")
                return

            # Check path for sin
            sin_path, sin_cache, position_ids = None, "", ""
            sin_path_1 = self.model.match_parent_path(
                node,
                ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Gather", "Shape"],
                [1, 1, 0, 0, 0, 0, 2, 0, 0],
            )
            sin_path_2 = self.model.match_parent_path(
                node,
                ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Add"],
                [1, 1, 0, 0, 0, 0, 2, 0],
            )
            sin_path_3 = self.model.match_parent_path(
                node,
                ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Gather", "Shape"],
                [1, 1, 0, 0, 2, 0, 0],
            )
            sin_path_4 = self.model.match_parent_path(
                node,
                ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Add"],
                [1, 1, 0, 0, 2, 0],
            )
            if sin_path_1 is not None:
                sin_path = sin_path_1
                sin_cache = sin_path[-4].input[0]
            elif sin_path_2 is not None:
                sin_path = sin_path_2
                sin_cache = sin_path[-3].input[0]
            elif sin_path_3 is not None:
                sin_path = sin_path_3
                sin_cache = sin_path[-4].input[0]
                position_ids = sin_path[2].input[1]
            elif sin_path_4 is not None:
                sin_path = sin_path_4
                sin_cache = sin_path[-3].input[0]
                position_ids = sin_path[2].input[1]
            else:
                logger.debug("fuse_rotary_embeddings: failed to match sin path in apply_rope")
                return

            # Check path for cos
            cos_path, cos_cache = None, ""
            cos_path_1 = self.model.match_parent_path(
                node,
                ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Gather", "Shape"],
                [0, 1, 0, 0, 0, 0, 2, 0, 0],
            )
            cos_path_2 = self.model.match_parent_path(
                node,
                ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Add"],
                [0, 1, 0, 0, 0, 0, 2, 0],
            )
            cos_path_3 = self.model.match_parent_path(
                node,
                ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Gather", "Shape"],
                [0, 1, 0, 0, 2, 0, 0],
            )
            cos_path_4 = self.model.match_parent_path(
                node,
                ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Add"],
                [0, 1, 0, 0, 2, 0],
            )
            if cos_path_1 is not None:
                cos_path = cos_path_1
                cos_cache = cos_path[-4].input[0]
            elif cos_path_2 is not None:
                cos_path = cos_path_2
                cos_cache = cos_path[-3].input[0]
            elif cos_path_3 is not None:
                cos_path = cos_path_3
                cos_cache = cos_path[-4].input[0]
                position_ids = cos_path[2].input[1]
            elif cos_path_4 is not None:
                cos_path = cos_path_4
                cos_cache = cos_path[-3].input[0]
                position_ids = cos_path[2].input[1]
            else:
                logger.debug("fuse_rotary_embeddings: failed to match sin path in apply_rope")
                return

            # Check path for position ids
            if position_ids == "":
                position_ids_from_sin_path = self.model.match_parent_path(
                    sin_path[2],
                    ["Reshape"],
                    [1],
                )
                position_ids_from_cos_path = self.model.match_parent_path(
                    cos_path[2],
                    ["Reshape"],
                    [1],
                )
                if (
                    position_ids_from_sin_path is None
                    or position_ids_from_cos_path is None
                    or position_ids_from_sin_path[0].name != position_ids_from_cos_path[0].name
                ):
                    logger.debug("fuse_rotary_embeddings: failed to match position ids path in apply_rope")
                    return
                position_ids = position_ids_from_cos_path[0].input[0]
            else:
                position_ids_from_sin_path = []
                position_ids_from_cos_path = []

            past_seq_len_path, curr_seq_len_path = None, None
            if (sin_path == sin_path_1 and cos_path == cos_path_1) or (
                sin_path == sin_path_3 and cos_path == cos_path_3
            ):
                if sin_path[-2].name != cos_path[-2].name or sin_path[-1].name != cos_path[-1].name:
                    logger.debug(
                        "fuse_rotary_embeddings: failed to match common Gather node and Shape node in sin cache and cos cache"
                    )
                    return
            elif (sin_path == sin_path_2 and cos_path == cos_path_2) or (
                sin_path == sin_path_4 and cos_path == cos_path_4
            ):
                if sin_path[-1].name != cos_path[-1].name:
                    logger.debug("fuse_rotary_embeddings: failed to match common Add node in sin cache and cos cache")
                    return
                # Match past sequence length path: past_key --> Shape --> Gather --> Add
                past_seq_len_path = self.model.match_parent_path(
                    sin_path[-1],
                    ["Gather", "Shape"],
                    [1, 0],
                )
                # Match current sequence length path: transpose_k --> Shape --> Gather --> Add
                curr_seq_len_path = self.model.match_parent_path(
                    sin_path[-1],
                    ["Gather", "Shape", "Transpose"],
                    [0, 0, 0],
                )
                if (
                    past_seq_len_path is None
                    or curr_seq_len_path is None
                    or self.model.find_graph_input(past_seq_len_path[-1].input[0]) is None
                    or curr_seq_len_path[-1].op_type != "Transpose"
                ):
                    logger.debug("fuse_rotary_embeddings: failed to match past_seq_len and curr_seq_len paths")
                    return
            else:
                logger.debug("fuse_rotary_embeddings: failed to match common cache paths")

            rotary_emb_node = self.create_rotary_embeddings_from_nodes(
                rotate_half_x1_path_1[-1].output[0],
                position_ids,
                cos_cache,
                sin_cache,
                node.output[0],
            )
            if rotary_emb_node is None:
                logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node")
                return

            # Remove rotary embedding nodes
            self.add_nodes_to_remove([node])
            self.add_nodes_to_remove(rotate_half_x1_path_1[:-1])
            self.add_nodes_to_remove(rotate_half_x1_path_2[:-1])
            self.add_nodes_to_remove(rotate_half_x2_path_1[:-1])
            self.add_nodes_to_remove(rotate_half_x2_path_2[:-1])
            self.add_nodes_to_remove(x_path[:-1])
            self.add_nodes_to_remove(sin_path)
            self.add_nodes_to_remove(cos_path)
            self.add_nodes_to_remove(position_ids_from_sin_path[:-1])
            self.add_nodes_to_remove(position_ids_from_cos_path[:-1])

            if past_seq_len_path is not None and len(self.model.get_children(past_seq_len_path[0])) == 1:
                # In merged HF model, output of Gather in past_seq_len_path is used twice
                # for past_key_values.0.key and once for other past_key_values
                self.add_nodes_to_remove(past_seq_len_path)
            if curr_seq_len_path is not None:
                self.add_nodes_to_remove(curr_seq_len_path[:-1])

        self.increase_counter(self.base_name)
        self.node_name_to_graph_name[rotary_emb_node.name] = self.this_graph_name
        self.nodes_to_add.append(rotary_emb_node)
        self.prune_graph = True
