# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Tuple, Union

import numpy as np
from fusion_base import Fusion
from fusion_utils import NumpyHelper
from onnx import NodeProto, TensorProto, helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionAttentionUnet(Fusion):
    """
    Fuse Attention subgraph of UNet into one Attention node.
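
    A minimal usage sketch (illustrative only; the argument values below are hypothetical,
    and apply() is assumed to be inherited from the Fusion base class):

        fusion = FusionAttentionUnet(
            model,
            hidden_size=320,
            num_heads=8,
            is_cross_attention=False,
            enable_packed_qkv=True,
            enable_packed_kv=False,
        )
        fusion.apply()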
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        is_cross_attention: bool,
        enable_packed_qkv: bool,
        enable_packed_kv: bool,
    ):
        super().__init__(
            model,
            "Attention" if is_cross_attention and enable_packed_qkv else "MultiHeadAttention",
            ["LayerNormalization"],
        )
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.is_cross_attention = is_cross_attention

        # Note: packing Q/K/V or K/V weights into one tensor makes it harder to update initializers for LoRA.
        # To support LoRA, it is better to use separate Q, K and V inputs in offline optimization,
        # and let the CUDA operator pre-pack those tensors into the preferred format based on available kernels.
        # In this way, we can support LoRA and get optimal performance at the same time.
        self.enable_packed_qkv = enable_packed_qkv
        self.enable_packed_kv = enable_packed_kv

        # Flags to show warning only once
        self.num_heads_warning = True
        self.hidden_size_warning = True

    def get_num_heads(self, reshape_q: NodeProto, is_torch2: bool = False) -> int:
        """Detect num_heads from a reshape node.

        Args:
            reshape_q (NodeProto): reshape node for Q
            is_torch2 (bool): graph pattern is from PyTorch 2.*
        Returns:
            int: num_heads, or 0 if not found
        """
        num_heads = 0
        if is_torch2:
            # We assume that reshape fusion has been done. For the PyTorch 2.* pattern, the target shape is
            # built by a 4-input Concat node like [0, 0, num_heads, head_size], where input 2 holds num_heads.
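            # Example (hypothetical values): Concat(batch_dim, seq_dim, [8], [40]) -> num_heads = 8.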
            reshape_parent = self.model.get_parent(reshape_q, 1)
            if reshape_parent and reshape_parent.op_type == "Concat" and len(reshape_parent.input) == 4:
                num_heads = self.model.get_constant_value(reshape_parent.input[2])
                if isinstance(num_heads, np.ndarray) and list(num_heads.shape) == [1]:
                    num_heads = int(num_heads)
        else:
            # We assume that reshape fusion has been done, so the shape is a constant tensor like
            # [0, 0, num_heads, head_size].
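            # Example (hypothetical values): a shape constant of [0, 0, 8, 40] gives num_heads = 8.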
            q_shape_value = self.model.get_constant_value(reshape_q.input[1])
            if isinstance(q_shape_value, np.ndarray) and list(q_shape_value.shape) == [4]:
                num_heads = int(q_shape_value[2])

        if isinstance(num_heads, int) and num_heads > 0:
            return num_heads

        return 0

    def get_hidden_size(self, layernorm_node):
        """Detect hidden_size from LayerNormalization node.
        Args:
            layernorm_node (NodeProto): LayerNormalization node before Q, K and V
        Returns:
            int: hidden_size, or 0 if not found
        """
        layernorm_bias = self.model.get_initializer(layernorm_node.input[2])
        if layernorm_bias:
            return NumpyHelper.to_array(layernorm_bias).shape[0]

        return 0

    def get_num_heads_and_hidden_size(
        self, reshape_q: NodeProto, layernorm_node: NodeProto, is_torch2: bool = False
    ) -> Tuple[int, int]:
        """Detect num_heads and hidden_size.

        Args:
            reshape_q (NodeProto): reshape node for Q
            layernorm_node (NodeProto): LayerNormalization node before Q, K and V
            is_torch2 (bool): graph pattern is from PyTorch 2.*
        Returns:
            Tuple[int, int]: num_heads and hidden_size
        """
        num_heads = self.get_num_heads(reshape_q, is_torch2)
        if num_heads <= 0:
            num_heads = self.num_heads  # Fall back to user specified value

        if self.num_heads > 0 and num_heads != self.num_heads:
            if self.num_heads_warning:
                logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
                self.num_heads_warning = False  # Do not show the warning more than once

        hidden_size = self.get_hidden_size(layernorm_node)
        if hidden_size <= 0:
            hidden_size = self.hidden_size  # Fall back to user specified value

        if self.hidden_size > 0 and hidden_size != self.hidden_size:
            if self.hidden_size_warning:
                logger.warning(
                    f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
                )
                self.hidden_size_warning = False  # Do not show the warning more than once

        return num_heads, hidden_size

    def create_attention_node(
        self,
        q_matmul: NodeProto,
        k_matmul: NodeProto,
        v_matmul: NodeProto,
        num_heads: int,
        hidden_size: int,
        input: str,
        output: str,
    ) -> Union[NodeProto, None]:
        """Create an Attention node.

        Args:
            q_matmul (NodeProto): MatMul node of the fully connected layer for Q
            k_matmul (NodeProto): MatMul node of the fully connected layer for K
            v_matmul (NodeProto): MatMul node of the fully connected layer for V
            num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
            hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
            input (str): input name
            output (str): output name

        Returns:
            Union[NodeProto, None]: the node created or None if failed.
        """
        is_self_attention = not self.is_cross_attention

        if is_self_attention:
            if q_matmul.input[0] != input or k_matmul.input[0] != input or v_matmul.input[0] != input:
                logger.debug(
                    "For self attention, input hidden state for q and k/v shall be same. Got %s, %s, %s",
                    q_matmul.input[0],
                    k_matmul.input[0],
                    v_matmul.input[0],
                )
                return None
        else:
            if q_matmul.input[0] != input or (k_matmul.input[0] != v_matmul.input[0]) or (k_matmul.input[0] == input):
                logger.debug(
                    "For cross attention, input hidden state for q and k/v shall be different. Got %s, %s, %s",
                    q_matmul.input[0],
                    k_matmul.input[0],
                    v_matmul.input[0],
                )
                return None

        if hidden_size > 0 and (hidden_size % num_heads) != 0:
            logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
            return None

        q_weight = self.model.get_initializer(q_matmul.input[1])
        k_weight = self.model.get_initializer(k_matmul.input[1])
        v_weight = self.model.get_initializer(v_matmul.input[1])
        if not (q_weight and k_weight and v_weight):
            return None

        # Sometimes weights are stored in fp16
        float_type = q_weight.data_type

        qw = NumpyHelper.to_array(q_weight)
        kw = NumpyHelper.to_array(k_weight)
        vw = NumpyHelper.to_array(v_weight)
        logger.debug(f"qw={qw.shape} kw={kw.shape} vw={vw.shape} hidden_size={hidden_size}")

        # Check that Q, K and V weights have the expected shapes.
        if is_self_attention:
            if qw.shape != kw.shape or qw.shape != vw.shape:
                return None

            qw_in_size = qw.shape[0]

            if hidden_size > 0 and hidden_size != qw_in_size:
                raise ValueError(
                    f"Input hidden size ({hidden_size}) is not same as weight dimension of q,k,v ({qw_in_size}). "
                    "Please provide a correct input hidden size or pass in 0"
                )

            # All the matrices can have the same shape, or the Q and K matrices can share a shape while V differs.
            # For 2D weights, the shape is [in_size, out_size].
            # For 3D weights, the shape is [in_size, a, b] where a * b = out_size.
            qw_out_size = int(np.prod(qw.shape[1:]))

            if self.enable_packed_qkv:
                attention_node_name = self.model.create_node_name("MultiHeadAttention")

                c = qw_in_size
                n = num_heads
                h = qw_out_size // num_heads

                # Concat and interleave weights so that the output of the fused QKV GEMM has [B, S, N, 3, H] shape
                qkv_weight = np.dstack([qw.reshape(c, n, h), kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(
                    c, n * 3 * h
                )
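                # Illustrative layout (hypothetical sizes): with c=320, n=8, h=40, qkv_weight has shape
                # (320, 960), and each head owns a contiguous block of 3*h = 120 columns ordered as
                # [Q_head_i | K_head_i | V_head_i], matching the [0, 0, n, 3, h] reshape below.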

                matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_QKV")
                self.add_initializer(
                    name=matmul_node_name + "_weight",
                    data_type=float_type,
                    dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
                    vals=qkv_weight,
                )

                matmul_node = helper.make_node(
                    "MatMul",
                    inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
                    outputs=[matmul_node_name + "_out"],
                    name=matmul_node_name,
                )
                self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name

                self.add_initializer(
                    name=matmul_node_name + "_reshape_shape",
                    data_type=TensorProto.INT64,
                    dims=[5],
                    vals=[0, 0, n, 3, h],
                    raw=False,
                )

                reshape_node = helper.make_node(
                    "Reshape",
                    inputs=[
                        matmul_node_name + "_out",
                        matmul_node_name + "_reshape_shape",
                    ],
                    outputs=[attention_node_name + "_qkv_input"],
                    name=matmul_node_name + "_reshape",
                )
                self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
                self.nodes_to_add.extend([matmul_node, reshape_node])
                self.nodes_to_remove.extend([q_matmul, k_matmul, v_matmul])

            else:
                qkv_weight = np.stack((qw, kw, vw), axis=1)
                qkv_weight_dim = 3 * qw_out_size

                attention_node_name = self.model.create_node_name("Attention")

                self.add_initializer(
                    name=attention_node_name + "_qkv_weight",
                    data_type=float_type,
                    dims=[qw_in_size, qkv_weight_dim],
                    vals=qkv_weight,
                )
        else:  # cross attention
            attention_node_name = self.model.create_node_name("MultiHeadAttention")
            if self.enable_packed_kv:
                if kw.shape != vw.shape:
                    return None

                kw_in_size = kw.shape[0]
                vw_in_size = vw.shape[0]
                assert kw_in_size == vw_in_size

                qw_out_size = qw.shape[1]
                kw_out_size = kw.shape[1]
                vw_out_size = vw.shape[1]
                assert qw_out_size == vw_out_size and kw_out_size == vw_out_size

                c = kw_in_size
                n = num_heads
                h = kw_out_size // num_heads

                # Concat and interleave weights so that the output of fused KV GEMM has [B, S_kv, N, 2, H] shape
                kv_weight = np.dstack([kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(c, n * 2 * h)
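                # Illustrative layout (hypothetical sizes): with c=768, n=8, h=40, kv_weight has shape
                # (768, 640), and each head owns a contiguous block of 2*h = 80 columns ordered as
                # [K_head_i | V_head_i], matching the [0, 0, n, 2, h] reshape below.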

                matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV")
                self.add_initializer(
                    name=matmul_node_name + "_weight",
                    data_type=float_type,
                    dims=[kv_weight.shape[0], kv_weight.shape[1]],
                    vals=kv_weight,
                )

                matmul_node = helper.make_node(
                    "MatMul",
                    inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
                    outputs=[matmul_node_name + "_out"],
                    name=matmul_node_name,
                )
                self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name

                self.add_initializer(
                    name=matmul_node_name + "_reshape_shape",
                    data_type=TensorProto.INT64,
                    dims=[5],
                    vals=[0, 0, n, 2, h],
                    raw=False,
                )

                reshape_node = helper.make_node(
                    "Reshape",
                    inputs=[
                        matmul_node_name + "_out",
                        matmul_node_name + "_reshape_shape",
                    ],
                    outputs=[attention_node_name + "_kv_input"],
                    name=matmul_node_name + "_reshape",
                )
                self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
                self.nodes_to_add.extend([matmul_node, reshape_node])
                self.nodes_to_remove.extend([k_matmul, v_matmul])

        # No bias, use zeros
        qkv_bias = np.zeros([3, hidden_size], dtype=np.float32)
        qkv_bias_dim = 3 * hidden_size

        self.add_initializer(
            name=attention_node_name + "_qkv_bias",
            data_type=float_type,
            dims=[qkv_bias_dim],
            vals=qkv_bias,
        )

        if is_self_attention:
            if not self.enable_packed_qkv:
                attention_inputs = [
                    input,
                    attention_node_name + "_qkv_weight",
                    attention_node_name + "_qkv_bias",
                ]
            else:
                attention_inputs = [attention_node_name + "_qkv_input"]
        else:
            if not self.enable_packed_kv:
                attention_inputs = [
                    q_matmul.output[0],
                    k_matmul.output[0],
                    v_matmul.output[0],
                    attention_node_name + "_qkv_bias",
                ]
            else:
                attention_inputs = [
                    q_matmul.output[0],
                    attention_node_name + "_kv_input",
                ]

        attention_node = helper.make_node(
            "Attention" if (is_self_attention and not self.enable_packed_qkv) else "MultiHeadAttention",
            inputs=attention_inputs,
            outputs=[output],
            name=attention_node_name,
        )
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

        counter_name = (
            "Attention (self attention)"
            if is_self_attention and not self.enable_packed_qkv
            else "MultiHeadAttention ({})".format(
                "self attention with packed qkv"
                if self.enable_packed_qkv
                else "cross attention with packed kv" if self.enable_packed_kv else "cross attention"
            )
        )
        self.increase_counter(counter_name)
        return attention_node

    def create_attention_node_lora(
        self,
        q_matmul_add: NodeProto,
        k_matmul_add: NodeProto,
        v_matmul_add: NodeProto,
        num_heads: int,
        hidden_size: int,
        input: str,
        output: str,
    ) -> Union[NodeProto, None]:
        """Create an Attention node.

        Args:
            q_matmul_add (NodeProto): Add node that sums the Q MatMul output with its LoRA branch
            k_matmul_add (NodeProto): Add node that sums the K MatMul output with its LoRA branch
            v_matmul_add (NodeProto): Add node that sums the V MatMul output with its LoRA branch
            num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
            hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
            input (str): input name
            output (str): output name

        Returns:
            Union[NodeProto, None]: the node created or None if failed.
        """
        is_self_attention = not self.is_cross_attention

        q_matmul = self.model.match_parent(q_matmul_add, "MatMul", 0)
        k_matmul = self.model.match_parent(k_matmul_add, "MatMul", 0)
        v_matmul = self.model.match_parent(v_matmul_add, "MatMul", 0)

        q_lora_nodes = self.match_lora_path(q_matmul_add)
        if q_lora_nodes is None:
            return None
        (q_lora_last_node, q_lora_matmul_1) = q_lora_nodes

        k_lora_nodes = self.match_lora_path(k_matmul_add)
        if k_lora_nodes is None:
            return None
        (k_lora_last_node, k_lora_matmul_1) = k_lora_nodes

        v_lora_nodes = self.match_lora_path(v_matmul_add)
        if v_lora_nodes is None:
            return None
        (v_lora_last_node, v_lora_matmul_1) = v_lora_nodes

        if is_self_attention:
            if q_matmul.input[0] != input or k_matmul.input[0] != input or v_matmul.input[0] != input:
                logger.debug(
                    "For self attention, input hidden state for q and k/v shall be same. Got %s, %s, %s",
                    q_matmul.input[0],
                    k_matmul.input[0],
                    v_matmul.input[0],
                )
                return None

            if (
                q_lora_matmul_1.input[0] != input
                or k_lora_matmul_1.input[0] != input
                or v_lora_matmul_1.input[0] != input
            ):
                logger.debug(
                    "For self attention, input hidden state for LoRA q and k/v weights shall be same. Got %s, %s, %s",
                    q_lora_matmul_1.input[0],
                    k_lora_matmul_1.input[0],
                    v_lora_matmul_1.input[0],
                )
                return None
        else:
            if q_matmul.input[0] != input or (k_matmul.input[0] != v_matmul.input[0]) or (k_matmul.input[0] == input):
                logger.debug(
                    "For cross attention, input hidden state for q and k/v shall be different. Got %s, %s, %s",
                    q_matmul.input[0],
                    k_matmul.input[0],
                    v_matmul.input[0],
                )
                return None

            if (
                q_lora_matmul_1.input[0] != input
                or (k_lora_matmul_1.input[0] != v_lora_matmul_1.input[0])
                or (k_matmul.input[0] == input)
            ):
                logger.debug(
                    (
                        "For cross attention, input hidden state for LoRA q and k/v weights shall be different. "
                        "Got %s, %s, %s"
                    ),
                    q_lora_matmul_1.input[0],
                    k_lora_matmul_1.input[0],
                    v_lora_matmul_1.input[0],
                )
                return None

        if hidden_size > 0 and (hidden_size % num_heads) != 0:
            logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
            return None

        q_weight = self.model.get_initializer(q_matmul.input[1])
        k_weight = self.model.get_initializer(k_matmul.input[1])
        v_weight = self.model.get_initializer(v_matmul.input[1])
        if not (q_weight and k_weight and v_weight):
            return None

        # Sometimes weights are stored in fp16
        if q_weight.data_type == TensorProto.FLOAT16:
            logger.debug("weights are in fp16. Please run fp16 conversion after optimization")
            return None

        qw = NumpyHelper.to_array(q_weight)
        kw = NumpyHelper.to_array(k_weight)
        vw = NumpyHelper.to_array(v_weight)
        logger.debug(f"qw={qw.shape} kw={kw.shape} vw={vw.shape} hidden_size={hidden_size}")

        # Check that Q, K and V weights have the expected shapes.
        if is_self_attention:
            if qw.shape != kw.shape or qw.shape != vw.shape:
                return None

            qw_in_size = qw.shape[0]

            if hidden_size > 0 and hidden_size != qw_in_size:
                raise ValueError(
                    f"Input hidden size ({hidden_size}) is not same as weight dimension of q,k,v ({qw_in_size}). "
                    "Please provide a correct input hidden size or pass in 0"
                )

            # All the matrices can have the same shape, or the Q and K matrices can share a shape while V differs.
            # For 2D weights, the shape is [in_size, out_size].
            # For 3D weights, the shape is [in_size, a, b] where a * b = out_size.
            qw_out_size = int(np.prod(qw.shape[1:]))

            if self.enable_packed_qkv:
                attention_node_name = self.model.create_node_name("MultiHeadAttention")

                c = qw_in_size
                n = num_heads
                h = qw_out_size // num_heads

                # Concat and interleave weights so that the output of the fused QKV GEMM has [B, S, N, 3, H] shape
                qkv_weight = np.dstack([qw.reshape(c, n, h), kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(
                    c, n * 3 * h
                )

                matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_QKV")
                self.add_initializer(
                    name=matmul_node_name + "_weight",
                    data_type=TensorProto.FLOAT,
                    dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
                    vals=qkv_weight,
                )

                matmul_node = helper.make_node(
                    "MatMul",
                    inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
                    outputs=[matmul_node_name + "_out"],
                    name=matmul_node_name,
                )
                self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name

                # Do the same thing with the LoRA weights, but don't constant fold the result. The goal is to allow
                # the Q/K/V weights to be changed without having to re-run the optimizer.
                lora_weight_shape_tensor_name = q_lora_last_node.name + "_reshape_shape"

                self.add_initializer(
                    name=lora_weight_shape_tensor_name,
                    data_type=TensorProto.INT64,
                    dims=[4],
                    vals=[0, 0, n, h],
                    raw=False,
                )

                # Reshape the LoRA Q output
                q_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_Q")
                q_lora_reshape_node = helper.make_node(
                    "Reshape",
                    inputs=[q_lora_last_node.output[0], lora_weight_shape_tensor_name],
                    outputs=[q_lora_reshape_node_name + "_out"],
                    name=q_lora_reshape_node_name,
                )
                self.node_name_to_graph_name[q_lora_reshape_node.name] = self.this_graph_name

                # Reshape the LoRA K output
                k_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_K")
                k_lora_reshape_node = helper.make_node(
                    "Reshape",
                    inputs=[k_lora_last_node.output[0], lora_weight_shape_tensor_name],
                    outputs=[k_lora_reshape_node_name + "_out"],
                    name=k_lora_reshape_node_name,
                )
                self.node_name_to_graph_name[k_lora_reshape_node.name] = self.this_graph_name

                # Reshape the LoRA V output
                v_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_V")
                v_lora_reshape_node = helper.make_node(
                    "Reshape",
                    inputs=[v_lora_last_node.output[0], lora_weight_shape_tensor_name],
                    outputs=[v_lora_reshape_node_name + "_out"],
                    name=v_lora_reshape_node_name,
                )
                self.node_name_to_graph_name[v_lora_reshape_node.name] = self.this_graph_name

                # Concat the reshaped LoRA Q/K/V outputs together along axis 3
                qkv_lora_concat_node_name = self.model.create_node_name("Concat", name_prefix="Concat_LoRA_QKV")
                qkv_lora_concat_node = helper.make_node(
                    "Concat",
                    inputs=[
                        q_lora_reshape_node.output[0],
                        k_lora_reshape_node.output[0],
                        v_lora_reshape_node.output[0],
                    ],
                    outputs=[qkv_lora_concat_node_name + "_out"],
                    name=qkv_lora_concat_node_name,
                )
                qkv_lora_concat_node.attribute.extend([helper.make_attribute("axis", 3)])
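                # After this Concat, the LoRA result has shape [B, S, n, 3*h], with each head's 3*h slots
                # ordered as Q, K, V to match the packed layout produced by the fused MatMul above.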
                self.node_name_to_graph_name[qkv_lora_concat_node.name] = self.this_graph_name

                # Reshape the concatenated LoRA output to [..., n * 3 * h]
                reshaped_lora_weights_shape_tensor_name = qkv_lora_concat_node.name + "_reshape_shape"
                self.add_initializer(
                    name=reshaped_lora_weights_shape_tensor_name,
                    data_type=TensorProto.INT64,
                    dims=[3],
                    vals=[0, 0, n * 3 * h],
                    raw=False,
                )

                qkv_lora_reshaped_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_QKV")
                qkv_lora_reshaped_node = helper.make_node(
                    "Reshape",
                    inputs=[qkv_lora_concat_node.output[0], reshaped_lora_weights_shape_tensor_name],
                    outputs=[qkv_lora_reshaped_node_name + "_out"],
                    name=qkv_lora_reshaped_node_name,
                )
                self.node_name_to_graph_name[qkv_lora_reshaped_node.name] = self.this_graph_name

                # Add the LoRA Q/K/V contribution to the output of the base packed QKV MatMul
                add_weights_node_name = self.model.create_node_name("Add", name_prefix="Add_Weights_QKV")
                add_weights_node = helper.make_node(
                    "Add",
                    inputs=[qkv_lora_reshaped_node.output[0], matmul_node.output[0]],
                    outputs=[add_weights_node_name + "_out"],
                    name=add_weights_node_name,
                )
                self.node_name_to_graph_name[add_weights_node.name] = self.this_graph_name

                # Finally, reshape the concatenated Q/K/V result to 5D
                shape_tensor_name = add_weights_node_name + "_reshape_shape"
                self.add_initializer(
                    name=shape_tensor_name,
                    data_type=TensorProto.INT64,
                    dims=[5],
                    vals=[0, 0, n, 3, h],
                    raw=False,
                )

                reshape_node = helper.make_node(
                    "Reshape",
                    inputs=[add_weights_node.output[0], shape_tensor_name],
                    outputs=[attention_node_name + "_qkv_input"],
                    name=add_weights_node_name + "_reshape",
                )
                self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name

                self.nodes_to_add.extend(
                    [
                        matmul_node,
                        q_lora_reshape_node,
                        k_lora_reshape_node,
                        v_lora_reshape_node,
                        qkv_lora_concat_node,
                        qkv_lora_reshaped_node,
                        add_weights_node,
                        reshape_node,
                    ]
                )
                self.nodes_to_remove.extend([q_matmul, k_matmul, v_matmul, q_matmul_add, k_matmul_add, v_matmul_add])
            else:
                # TODO: Support non-packed QKV
                return None
        else:  # cross attention
            attention_node_name = self.model.create_node_name("MultiHeadAttention")
            if self.enable_packed_kv:
                if kw.shape != vw.shape:
                    return None

                kw_in_size = kw.shape[0]
                vw_in_size = vw.shape[0]
                assert kw_in_size == vw_in_size

                qw_out_size = qw.shape[1]
                kw_out_size = kw.shape[1]
                vw_out_size = vw.shape[1]
                assert qw_out_size == vw_out_size and kw_out_size == vw_out_size

                c = kw_in_size
                n = num_heads
                h = kw_out_size // num_heads

                # Concat and interleave weights so that the output of fused KV GEMM has [B, S_kv, N, 2, H] shape
                kv_weight = np.dstack([kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(c, n * 2 * h)

                matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV")
                self.add_initializer(
                    name=matmul_node_name + "_weight",
                    data_type=TensorProto.FLOAT,
                    dims=[kv_weight.shape[0], kv_weight.shape[1]],
                    vals=kv_weight,
                )

                matmul_node = helper.make_node(
                    "MatMul",
                    inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
                    outputs=[matmul_node_name + "_out"],
                    name=matmul_node_name,
                )
                self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name

                # Do the same thing with the LoRA weights, but don't constant fold the result. The goal is to allow
                # the K/V weights to be changed without having to re-run the optimizer.
                kv_lora_weight_shape_tensor_name = q_lora_last_node.name + "_reshape_shape"
                self.add_initializer(
                    name=kv_lora_weight_shape_tensor_name,
                    data_type=TensorProto.INT64,
                    dims=[4],
                    vals=[0, 0, n, h],
                    raw=False,
                )

                # Reshape the LoRA K output
                k_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_K")
                k_lora_reshape_node = helper.make_node(
                    "Reshape",
                    inputs=[k_lora_last_node.output[0], kv_lora_weight_shape_tensor_name],
                    outputs=[k_lora_reshape_node_name + "_out"],
                    name=k_lora_reshape_node_name,
                )
                self.node_name_to_graph_name[k_lora_reshape_node.name] = self.this_graph_name

                # Reshape the LoRA V output
                v_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_V")
                v_lora_reshape_node = helper.make_node(
                    "Reshape",
                    inputs=[v_lora_last_node.output[0], kv_lora_weight_shape_tensor_name],
                    outputs=[v_lora_reshape_node_name + "_out"],
                    name=v_lora_reshape_node_name,
                )
                self.node_name_to_graph_name[v_lora_reshape_node.name] = self.this_graph_name

                # Concat the reshaped LoRA K/V outputs together along axis 3
                kv_lora_concat_node_name = self.model.create_node_name("Concat", name_prefix="Concat_LoRA_KV")
                kv_lora_concat_node = helper.make_node(
                    "Concat",
                    inputs=[k_lora_reshape_node.output[0], v_lora_reshape_node.output[0]],
                    outputs=[kv_lora_concat_node_name + "_out"],
                    name=kv_lora_concat_node_name,
                )
                kv_lora_concat_node.attribute.extend([helper.make_attribute("axis", 3)])
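                # After this Concat, the LoRA result has shape [B, S_kv, n, 2*h], with each head's 2*h slots
                # ordered as K, V to match the packed layout produced by the fused MatMul above.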
                self.node_name_to_graph_name[kv_lora_concat_node.name] = self.this_graph_name

                # Reshape the concatenated LoRA output to [..., n * 2 * h]
                reshaped_kv_lora_weights_shape_tensor_name = kv_lora_concat_node.name + "_reshape_shape"
                self.add_initializer(
                    name=reshaped_kv_lora_weights_shape_tensor_name,
                    data_type=TensorProto.INT64,
                    dims=[3],
                    vals=[0, 0, n * 2 * h],
                    raw=False,
                )

                kv_lora_reshaped_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_KV")
                kv_lora_reshaped_node = helper.make_node(
                    "Reshape",
                    inputs=[kv_lora_concat_node.output[0], reshaped_kv_lora_weights_shape_tensor_name],
                    outputs=[kv_lora_reshaped_node_name + "_out"],
                    name=kv_lora_reshaped_node_name,
                )
                self.node_name_to_graph_name[kv_lora_reshaped_node.name] = self.this_graph_name

                # Add the LoRA K/V contribution to the output of the base packed KV MatMul
                add_kv_weights_node_name = self.model.create_node_name("Add", name_prefix="Add_Weights_KV")
                add_kv_weights_node = helper.make_node(
                    "Add",
                    inputs=[kv_lora_reshaped_node.output[0], matmul_node.output[0]],
                    outputs=[add_kv_weights_node_name + "_out"],
                    name=add_kv_weights_node_name,
                )
                self.node_name_to_graph_name[add_kv_weights_node.name] = self.this_graph_name

                # Finally, reshape the concatenated K/V result to 5D
                shape_tensor_name = add_kv_weights_node_name + "_reshape_shape"
                self.add_initializer(
                    name=shape_tensor_name,
                    data_type=TensorProto.INT64,
                    dims=[5],
                    vals=[0, 0, n, 2, h],
                    raw=False,
                )

                reshape_node = helper.make_node(
                    "Reshape",
                    inputs=[add_kv_weights_node.output[0], shape_tensor_name],
                    outputs=[attention_node_name + "_kv_input"],
                    name=add_kv_weights_node_name + "_reshape",
                )
                self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name
                self.nodes_to_add.extend(
                    [
                        matmul_node,
                        k_lora_reshape_node,
                        v_lora_reshape_node,
                        kv_lora_concat_node,
                        kv_lora_reshaped_node,
                        add_kv_weights_node,
                        reshape_node,
                    ]
                )
                self.nodes_to_remove.extend([k_matmul, v_matmul, k_matmul_add, v_matmul_add])
            else:
                # TODO: Support non-packed KV
                return None

        # No bias, use zeros
        qkv_bias = np.zeros([3, hidden_size], dtype=np.float32)
        qkv_bias_dim = 3 * hidden_size
        self.add_initializer(
            name=attention_node_name + "_qkv_bias",
            data_type=TensorProto.FLOAT,
            dims=[qkv_bias_dim],
            vals=qkv_bias,
        )

        if is_self_attention:
            if not self.enable_packed_qkv:
                # TODO: Support non-packed QKV
                return None
            else:
                attention_inputs = [attention_node_name + "_qkv_input"]
        else:
            if not self.enable_packed_kv:
                # TODO: Support non-packed KV
                return None
            else:
                attention_inputs = [
                    q_matmul_add.output[0],
                    attention_node_name + "_kv_input",
                ]

        attention_node = helper.make_node(
            "Attention" if (is_self_attention and not self.enable_packed_qkv) else "MultiHeadAttention",
            inputs=attention_inputs,
            outputs=[output],
            name=attention_node_name,
        )
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

        counter_name = (
            "Attention (self attention)"
            if is_self_attention and not self.enable_packed_qkv
            else "MultiHeadAttention ({})".format(
                "self attention with packed qkv"
                if self.enable_packed_qkv
                else "cross attention with packed kv" if self.enable_packed_kv else "cross attention"
            )
        )
        self.increase_counter(counter_name)
        return attention_node

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        if self.fuse_a1111_fp16(normalize_node, input_name_to_nodes, output_name_to_node):
            return

        node_before_layernorm = self.model.match_parent(normalize_node, "Add", 0)

        # In SD 1.5, for self attention, LayerNorm has parent Reshape
        if node_before_layernorm is None and not self.is_cross_attention:
            node_before_layernorm = self.model.match_parent(normalize_node, "Reshape", 0)

        if node_before_layernorm is None:
            return

        root_input = node_before_layernorm.output[0]

        children_nodes = input_name_to_nodes[root_input]
        skip_add = None
        for node in children_nodes:
            if node.op_type == "Add":  # SkipLayerNormalization fusion is not applied yet
                skip_add = node
                break
        if skip_add is None:
            return

        match_qkv = self.match_qkv_torch1(root_input, skip_add) or self.match_qkv_torch2(root_input, skip_add)
        if match_qkv is not None:
            is_torch2, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v = match_qkv

            attention_last_node = reshape_qkv

            q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_torch2)
            if q_num_heads <= 0:
                logger.debug("fuse_attention: failed to detect num_heads")
                return

            # The number of heads is the same for all paths, so we pass q_num_heads when creating the attention node.
            new_node = self.create_attention_node(
                matmul_q,
                matmul_k,
                matmul_v,
                q_num_heads,
                q_hidden_size,
                input=normalize_node.output[0],
                output=attention_last_node.output[0],
            )
            if new_node is None:
                return
        else:
            # Check if we have a LoRA pattern
            match_qkv = self.match_qkv_torch1_lora(root_input, skip_add) or self.match_qkv_torch2_lora(
                root_input, skip_add
            )
            if match_qkv is None:
                return

            is_torch2, reshape_qkv, transpose_qkv, reshape_q, matmul_add_q, matmul_add_k, matmul_add_v = match_qkv

            attention_last_node = reshape_qkv

            q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, is_torch2)
            if q_num_heads <= 0:
                logger.debug("fuse_attention: failed to detect num_heads")
                return

            # The number of heads is the same for all paths, so we pass q_num_heads when creating the attention node.
            new_node = self.create_attention_node_lora(
                matmul_add_q,
                matmul_add_k,
                matmul_add_v,
                q_num_heads,
                q_hidden_size,
                input=normalize_node.output[0],
                output=attention_last_node.output[0],
            )
            if new_node is None:
                return

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        self.nodes_to_remove.extend([attention_last_node, transpose_qkv])

        # Use prune graph to remove nodes since they are shared by all attention nodes.
        self.prune_graph = True

    def match_qkv_torch1(self, root_input, skip_add):
        """Match Q, K and V paths exported by PyTorch 1.*"""
        another_input = 1 if skip_add.input[0] == root_input else 0
        qkv_nodes = self.model.match_parent_path(
            skip_add,
            ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
            [another_input, None, None, 0, 0, 0],
        )

        if qkv_nodes is None:
            return None

        (_, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes

        # No bias. For cross-attention, the input of the MatMul is encoder_hidden_states graph input.
        v_nodes = self.model.match_parent_path(matmul_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
        if v_nodes is None:
            logger.debug("fuse_attention: failed to match v path")
            return None
        (_, _, _, matmul_v) = v_nodes

        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0])
        if qk_nodes is not None:
            (_softmax_qk, _mul_qk, matmul_qk) = qk_nodes
        else:
            qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
            if qk_nodes is not None:
                (_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes
            else:
                logger.debug("fuse_attention: failed to match qk path")
                return None

        q_nodes = self.model.match_parent_path(matmul_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0])
        if q_nodes is None:
            logger.debug("fuse_attention: failed to match q path")
            return None
        (_, _transpose_q, reshape_q, matmul_q) = q_nodes

        k_nodes = self.model.match_parent_path(
            matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0, 0]
        )
        if k_nodes is None:
            logger.debug("fuse_attention: failed to match k path")
            return None

        (_, _, _, _, matmul_k) = k_nodes

        return False, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v

    def match_qkv_torch2(self, root_input, skip_add):
        """Match Q, K and V paths exported by PyTorch 2.*"""
        another_input = 1 if skip_add.input[0] == root_input else 0
        qkv_nodes = self.model.match_parent_path(
            skip_add,
            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [another_input, None, None, 0, 0],
        )

        if qkv_nodes is None:
            return None

        (_, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes

        v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "MatMul"], [1, 0, 0])
        if v_nodes is None:
            logger.debug("fuse_attention: failed to match v path")
            return None
        (_, _, matmul_v) = v_nodes

        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
        if qk_nodes is not None:
            (_softmax_qk, matmul_qk) = qk_nodes
        else:
            logger.debug("fuse_attention: failed to match qk path")
            return None

        q_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "MatMul"], [0, None, 0, 0])
        if q_nodes is None:
            logger.debug("fuse_attention: failed to match q path")
            return None
        (mul_q, _transpose_q, reshape_q, matmul_q) = q_nodes

        k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "MatMul"], [1, None, 0, 0])
        if k_nodes is None:
            logger.debug("fuse_attention: failed to match k path")
            return None

        (_mul_k, _, _, matmul_k) = k_nodes

        # The scalar for Q and K is sqrt(1.0/sqrt(head_size)).
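        # Equivalently, each of Q and K is scaled by head_size ** -0.25, so that Q*K'
        # carries the standard 1/sqrt(head_size) attention scaling.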
        mul_q_nodes = self.model.match_parent_path(
            mul_q,
            ["Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape", "Transpose", "Reshape"],
            [None, 0, 1, 0, 0, 0, 0, 0],
        )
        if mul_q_nodes is None or mul_q_nodes[-1] != reshape_q:
            logger.debug("fuse_attention: failed to match mul_q path")
            return None

        return True, reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v

    def match_qkv_torch1_lora(self, root_input, skip_add):
        """Match Q, K and V paths exported by PyTorch 1 that contains LoRA patterns.*"""
        another_input = 1 if skip_add.input[0] == root_input else 0
        qkv_nodes = self.model.match_parent_path(
            skip_add,
            ["Add", "Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"],
            [another_input, 0, None, None, 0, 0, 0],
        )
        if qkv_nodes is None:
            return None

        (_, _, _, reshape_qkv, transpose_qkv, _, matmul_qkv) = qkv_nodes

        # No bias. For cross-attention, the input of the MatMul is encoder_hidden_states graph input.
        v_nodes = self.model.match_parent_path(matmul_qkv, ["Reshape", "Transpose", "Reshape", "Add"], [1, 0, 0, 0])
        if v_nodes is None:
            logger.debug("fuse_attention: failed to match LoRA v path")
            return None
        (_, _, _, matmul_add_v) = v_nodes

        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Mul", "MatMul"], [0, 0, 0])
        if qk_nodes is not None:
            (_softmax_qk, _mul_qk, matmul_qk) = qk_nodes
        else:
            qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
            if qk_nodes is not None:
                (_softmax_qk, _add_zero, _mul_qk, matmul_qk) = qk_nodes
            else:
                logger.debug("fuse_attention: failed to match LoRA qk path")
                return None

        q_nodes = self.model.match_parent_path(matmul_qk, ["Reshape", "Transpose", "Reshape", "Add"], [0, 0, 0, 0])
        if q_nodes is None:
            logger.debug("fuse_attention: failed to match LoRA q path")
            return None
        (_, _transpose_q, reshape_q, matmul_add_q) = q_nodes

        k_nodes = self.model.match_parent_path(
            matmul_qk, ["Transpose", "Reshape", "Transpose", "Reshape", "Add"], [1, 0, 0, 0, 0]
        )
        if k_nodes is None:
            logger.debug("fuse_attention: failed to match LoRA k path")
            return None

        (_, _, _, _, matmul_add_k) = k_nodes

        return False, reshape_qkv, transpose_qkv, reshape_q, matmul_add_q, matmul_add_k, matmul_add_v

    def match_qkv_torch2_lora(self, root_input, skip_add):
        """Match Q, K and V paths exported by PyTorch 2 that contains LoRA patterns.*"""
        another_input = 1 if skip_add.input[0] == root_input else 0
        qkv_nodes = self.model.match_parent_path(
            skip_add,
            ["Add", "Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [another_input, 0, None, None, 0, 0],
        )
        if qkv_nodes is None:
            return None

        (_, _, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes

        v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add"], [1, 0, 0])
        if v_nodes is None:
            logger.debug("fuse_attention: failed to match LoRA v path")
            return None
        (_, _, matmul_add_v) = v_nodes

        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
        if qk_nodes is not None:
            (_softmax_qk, matmul_qk) = qk_nodes
        else:
            logger.debug("fuse_attention: failed to match LoRA qk path")
            return None

        q_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "Add"], [0, None, 0, 0])
        if q_nodes is None:
            logger.debug("fuse_attention: failed to match LoRA q path")
            return None
        (mul_q, _transpose_q, reshape_q, matmul_add_q) = q_nodes

        k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Reshape", "Add"], [1, None, 0, 0])
        if k_nodes is None:
            logger.debug("fuse_attention: failed to match LoRA k path")
            return None

        (_mul_k, _, _, matmul_add_k) = k_nodes

        # The scalar for Q and K is sqrt(1.0/sqrt(head_size)).
        mul_q_nodes = self.model.match_parent_path(
            mul_q,
            ["Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape", "Transpose", "Reshape"],
            [None, 0, 1, 0, 0, 0, 0, 0],
        )
        if mul_q_nodes is None or mul_q_nodes[-1] != reshape_q:
            logger.debug("fuse_attention: failed to match LoRA mul_q path")
            return None

        return True, reshape_qkv, transpose_qkv, reshape_q, matmul_add_q, matmul_add_k, matmul_add_v

    def match_lora_path(
        self,
        add_node: NodeProto,
    ):
        # LoRA paths can look like one of the following options (see the sketch after this list):
        # MatMul -> MatMul -> Add
        # MatMul -> MatMul -> Mul -> Add
        # MatMul -> MatMul -> Mul -> Mul -> Add
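        #
        # A typical LoRA branch next to the base projection looks roughly like this
        # (sketch; node roles such as "down"/"up" projection are an interpretation):
        #
        #   hidden_states --> MatMul (base weight) ---------------------------+
        #        |                                                            v
        #        +--> MatMul (LoRA down) --> MatMul (LoRA up) --> [Mul(s)] --> Add
        #
        # This method returns the node that feeds the Add (the last Mul, or the second
        # MatMul when there is no scaling) together with the first LoRA MatMul.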

        # Try matching MatMul -> MatMul -> Add
        lora_nodes = self.model.match_parent_path(
            add_node,
            ["MatMul", "MatMul"],
            [1, 0],
        )

        if lora_nodes is not None:
            (lora_matmul_2_node, lora_matmul_1_node) = lora_nodes
            return (lora_matmul_2_node, lora_matmul_1_node)

        # Try matching MatMul -> MatMul -> Mul -> Add
        lora_nodes = self.model.match_parent_path(
            add_node,
            ["Mul", "MatMul", "MatMul"],
            [1, 0, 0],
        )

        if lora_nodes is not None:
            (lora_mul_node, _, lora_matmul_1_node) = lora_nodes
            return (lora_mul_node, lora_matmul_1_node)

        # Try matching MatMul -> MatMul -> Mul -> Mul -> Add
        lora_nodes = self.model.match_parent_path(
            add_node,
            ["Mul", "Mul", "MatMul", "MatMul"],
            [1, 0, 0, 0],
        )

        if lora_nodes is not None:
            (lora_mul_node, _, _, lora_matmul_1_node) = lora_nodes
            return (lora_mul_node, lora_matmul_1_node)

        return None

    def fuse_a1111_fp16(self, normalize_node, input_name_to_nodes, output_name_to_node):
        """Fuse attention of fp16 UNet exported in A1111 (stable diffusion webui) extension"""
        entry_path = self.model.match_parent_path(normalize_node, ["Cast", "Add"], [0, 0])
        if entry_path is None:
            entry_path = self.model.match_parent_path(normalize_node, ["Cast", "Reshape"], [0, 0])
            if entry_path is None:
                return False
        _cast, node_before_layernorm = entry_path

        root_input = node_before_layernorm.output[0]

        children_nodes = input_name_to_nodes[root_input]
        skip_add = None
        for node in children_nodes:
            if node.op_type == "Add":  # SkipLayerNormalization fusion is not applied yet
                skip_add = node
                break
        if skip_add is None:
            return False

        match_qkv = self.match_qkv_a1111(root_input, skip_add)
        if match_qkv is None:
            return False

        (
            reshape_qkv,
            transpose_qkv,
            reshape_q,
            matmul_q,
            matmul_k,
            matmul_v,
        ) = match_qkv

        cast_q = self.model.match_parent(matmul_q, "Cast", 0)
        cast_k = self.model.match_parent(matmul_k, "Cast", 0)
        cast_v = self.model.match_parent(matmul_v, "Cast", 0)
        if not (
            cast_q is not None
            and cast_k is not None
            and (cast_q == cast_k if not self.is_cross_attention else cast_q != cast_k)
            and cast_k == cast_v
        ):
            return False

        if cast_q.input[0] != normalize_node.output[0]:
            return False

        attention_last_node = reshape_qkv

        q_num_heads = self.get_num_heads(reshape_q, True) or self.get_num_heads(reshape_q, False)
        if q_num_heads <= 0:
            logger.debug("fuse_attention: failed to detect num_heads")
            return False

        q_hidden_size = self.get_hidden_size(normalize_node)

        # The number of heads is the same for all paths, so we pass q_num_heads when creating the attention node.
        new_node = self.create_attention_node(
            matmul_q,
            matmul_k,
            matmul_v,
            q_num_heads,
            q_hidden_size,
            input=matmul_q.input[0],
            output=attention_last_node.output[0],
        )
        if new_node is None:
            return False

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        self.nodes_to_remove.extend([attention_last_node, transpose_qkv])

        # Use prune graph to remove nodes since they are shared by all attention nodes.
        self.prune_graph = True
        return True

    def match_qkv_a1111(self, root_input, skip_add):
        """Match Q, K and V paths exported by A1111 (stable diffusion webui) extension"""
        another_input = 1 if skip_add.input[0] == root_input else 0
        qkv_nodes = self.model.match_parent_path(
            skip_add,
            ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "Einsum"],
            [another_input, None, None, 0, 0, 0],
        )

        if qkv_nodes is None:
            return None

        (_, _, reshape_qkv, transpose_qkv, reshape_einsum, einsum_qkv) = qkv_nodes

        v_nodes = self.model.match_parent_path(einsum_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
        if v_nodes is None:
            logger.debug("fuse_attention: failed to match v path")
            return None
        (_, _, _, matmul_v) = v_nodes

        qk_nodes = self.model.match_parent_path(
            einsum_qkv, ["Cast", "Cast", "Softmax", "Mul", "Einsum"], [0, 0, 0, 0, None]
        )
        if qk_nodes is not None:
            (_, _, _softmax_qk, _, einsum_qk) = qk_nodes
        else:
            logger.debug("fuse_attention: failed to match qk path")
            return None

        q_nodes = self.model.match_parent_path(einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0])
        if q_nodes is None:
            logger.debug("fuse_attention: failed to match q path")
            return None
        (_, _transpose_q, reshape_q, matmul_q) = q_nodes

        k_nodes = self.model.match_parent_path(einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0])
        if k_nodes is None:
            logger.debug("fuse_attention: failed to match k path")
            return None

        (_, _, _, matmul_k) = k_nodes

        return reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v
