# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger
from typing import Tuple, Union

import numpy as np
from fusion_base import Fusion
from fusion_utils import NumpyHelper
from onnx import NodeProto, helper, numpy_helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionMultiHeadAttentionSam2(Fusion):
    """
    Fuse MultiHeadAttention subgraph of Segment Anything v2 (SAM2).
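
    Typical usage (illustrative sketch; the hidden_size and num_heads values are examples, and
    Fusion.apply() is assumed to be provided by the fusion base class):

        model = OnnxModel(onnx.load("sam2_decoder.onnx"))
        FusionMultiHeadAttentionSam2(model, hidden_size=256, num_heads=8).apply()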
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
    ):
        super().__init__(model, "MultiHeadAttention", ["LayerNormalization"])
        self.hidden_size = hidden_size
        self.num_heads = num_heads

        # Flags to show warning only once
        self.num_heads_warning = True
        self.hidden_size_warning = True

    def get_decoder_num_heads(self, reshape_q: NodeProto) -> int:
        """Detect num_heads from a reshape node.

        Args:
            reshape_q (NodeProto): reshape node for Q
        Returns:
            int: num_heads, or 0 if not found
        """
        num_heads = 0

        # We assume that Reshape fusion has already run, so the shape is a constant like [0, 0, num_heads, head_size].
        shape_value = self.model.get_constant_value(reshape_q.input[1])
        if shape_value is not None:
            if isinstance(shape_value, np.ndarray) and list(shape_value.shape) == [4]:
                num_heads = int(shape_value[2])

        if isinstance(num_heads, int) and num_heads > 0:
            return num_heads

        return 0

    def get_encoder_num_heads(self, reshape_in: NodeProto) -> int:
        """Detect num_heads from a reshape node.

        Args:
            reshape_in (NodeProto): reshape node after the packed QKV input projection
        Returns:
            int: num_heads, or 0 if not found
        """
        num_heads = 0

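        # The encoder packs Q, K and V in one input projection, so the Reshape target shape has
        # 5 elements (e.g. [batch, sequence, 3, num_heads, head_size]); num_heads is at index 3.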
        shape_value = self.model.get_constant_value(reshape_in.input[1])
        if shape_value is not None:
            if isinstance(shape_value, np.ndarray) and list(shape_value.shape) == [5]:
                num_heads = int(shape_value[3])
        else:
            concat_shape = self.model.match_parent(reshape_in, "Concat", 1)
            if concat_shape is not None and len(concat_shape.input) == 5:
                # The target shape may instead come from a Concat of 5 inputs; num_heads is the constant at index 3.
                shape_value = self.model.get_constant_value(concat_shape.input[3])
                if shape_value is not None:
                    if isinstance(shape_value, np.ndarray) and list(shape_value.shape) == [1]:
                        num_heads = int(shape_value[0])

        if isinstance(num_heads, int) and num_heads > 0:
            return num_heads

        return 0

    def get_hidden_size(self, layernorm_node):
        """Detect hidden_size from LayerNormalization node.
        Args:
            layernorm_node (NodeProto): LayerNormalization node before Q, K and V
        Returns:
            int: hidden_size, or 0 if not found
        """
        layernorm_bias = self.model.get_initializer(layernorm_node.input[2])
        if layernorm_bias:
            return NumpyHelper.to_array(layernorm_bias).shape[0]

        return 0

    def get_num_heads_and_hidden_size(
        self, reshape_q: NodeProto, layernorm_node: NodeProto, is_encoder: bool = False
    ) -> Tuple[int, int]:
        """Detect num_heads and hidden_size.

        Args:
            reshape_q (NodeProto): reshape node for Q (decoder) or for the packed QKV projection (encoder)
            layernorm_node (NodeProto): LayerNormalization node before Q, K and V
            is_encoder (bool): True for the SAM2 encoder pattern (packed QKV), False for the decoder pattern
        Returns:
            Tuple[int, int]: num_heads and hidden_size
        """
        if is_encoder:
            num_heads = self.get_encoder_num_heads(reshape_q)
        else:
            num_heads = self.get_decoder_num_heads(reshape_q)
        if num_heads <= 0:
            num_heads = self.num_heads  # Fall back to user specified value

        if self.num_heads > 0 and num_heads != self.num_heads:
            if self.num_heads_warning:
                logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
                self.num_heads_warning = False  # Do not show the warning more than once

        hidden_size = self.get_hidden_size(layernorm_node)
        if hidden_size <= 0:
            hidden_size = self.hidden_size  # Fall back to user specified value

        if self.hidden_size > 0 and hidden_size != self.hidden_size:
            if self.hidden_size_warning:
                logger.warning(
                    f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
                )
                self.hidden_size_warning = False  # Do not show the warning more than once

        return num_heads, hidden_size

    def create_attention_node(
        self,
        q_matmul: NodeProto,
        q_add: NodeProto,
        k_matmul: NodeProto,
        k_add: NodeProto,
        v_matmul: NodeProto,
        v_add: NodeProto,
        num_heads: int,
        hidden_size: int,
        output: str,
    ) -> Union[NodeProto, None]:
        """Create an Attention node.

        Args:
            q_matmul (NodeProto): MatMul node in the fully-connected layer for Q
            q_add (NodeProto): Add (bias) node in the fully-connected layer for Q
            k_matmul (NodeProto): MatMul node in the fully-connected layer for K
            k_add (NodeProto): Add (bias) node in the fully-connected layer for K
            v_matmul (NodeProto): MatMul node in the fully-connected layer for V
            v_add (NodeProto): Add (bias) node in the fully-connected layer for V
            num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
            hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
            output (str): output name

        Returns:
            Union[NodeProto, None]: the node created or None if failed.
        """
        if hidden_size > 0 and (hidden_size % num_heads) != 0:
            logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
            return None

        q_weight = self.model.get_initializer(q_matmul.input[1])
        k_weight = self.model.get_initializer(k_matmul.input[1])
        v_weight = self.model.get_initializer(v_matmul.input[1])
        if not (q_weight and k_weight and v_weight):
            return None

        qw = NumpyHelper.to_array(q_weight)
        kw = NumpyHelper.to_array(k_weight)
        vw = NumpyHelper.to_array(v_weight)
        logger.debug(f"qw={qw.shape} kw={kw.shape} vw={vw.shape} hidden_size={hidden_size}")

        attention_node_name = self.model.create_node_name("MultiHeadAttention")

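        # Q, K and V are taken after their bias Add nodes, so the input projections (MatMul + Add)
        # remain in the graph and only the attention computation itself is replaced by the fused node.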
        attention_inputs = [
            q_add.output[0],
            k_add.output[0],
            v_add.output[0],
        ]

        attention_node = helper.make_node(
            "MultiHeadAttention",
            inputs=attention_inputs,
            outputs=[output],
            name=attention_node_name,
        )
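        # MultiHeadAttention is a contrib operator defined in the com.microsoft domain.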
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

        counter_name = "MultiHeadAttention ({})".format("cross attention")
        self.increase_counter(counter_name)
        return attention_node

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        if self.fuse_sam_encoder_pattern(normalize_node, input_name_to_nodes, output_name_to_node):
            return

        match_qkv = self.match_attention_subgraph(normalize_node)
        if match_qkv is None:
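            # The output projection may feed a residual Add before this LayerNormalization;
            # if matching from the LayerNormalization fails, retry from that Add node.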
            if normalize_node.input[0] not in output_name_to_node:
                return

            skip_add = output_name_to_node[normalize_node.input[0]]
            if skip_add.op_type != "Add":
                return

            match_qkv = self.match_attention_subgraph(skip_add)

            if match_qkv is None:
                return

        reshape_qkv, transpose_qkv, reshape_q, matmul_q, add_q, matmul_k, add_k, matmul_v, add_v = match_qkv

        attention_last_node = reshape_qkv

        q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q, normalize_node, False)
        if q_num_heads <= 0:
            logger.debug("fuse_attention: failed to detect num_heads")
            return

        # The number of heads is the same for the Q, K and V paths, so q_num_heads is passed
        # when creating the attention node.
        new_node = self.create_attention_node(
            matmul_q,
            add_q,
            matmul_k,
            add_k,
            matmul_v,
            add_v,
            q_num_heads,
            q_hidden_size,
            output=attention_last_node.output[0],
        )
        if new_node is None:
            return

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        self.nodes_to_remove.extend([attention_last_node, transpose_qkv])

        # Use prune graph to remove nodes since they are shared by all attention nodes.
        self.prune_graph = True

    def match_attention_subgraph(self, node_after_output_projection):
        """Match Q, K and V paths exported by PyTorch 2.*"""
        qkv_nodes = self.model.match_parent_path(
            node_after_output_projection,
            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [None, None, None, 0, 0],
        )

        if qkv_nodes is None:
            return None

        (_, _, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes

        v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None])
        if v_nodes is None:
            logger.debug("fuse_attention: failed to match v path")
            return None
        (_, _, add_v, matmul_v) = v_nodes

        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
        if qk_nodes is not None:
            (_softmax_qk, matmul_qk) = qk_nodes
        else:
            logger.debug("fuse_attention: failed to match qk path")
            return None

        q_nodes = self.model.match_parent_path(
            matmul_qk, ["Mul", "Transpose", "Reshape", "Add", "MatMul"], [0, None, 0, 0, None]
        )
        if q_nodes is None:
            logger.debug("fuse_attention: failed to match q path")
            return None
        (mul_q, _transpose_q, reshape_q, add_q, matmul_q) = q_nodes

        k_nodes = self.model.match_parent_path(
            matmul_qk, ["Mul", "Transpose", "Reshape", "Add", "MatMul"], [1, None, 0, 0, None]
        )
        if k_nodes is None:
            logger.debug("fuse_attention: failed to match k path")
            return None

        (_mul_k, _, _, add_k, matmul_k) = k_nodes

        # Both Q and K are scaled by sqrt(1.0 / sqrt(head_size)), so Q*K^T is effectively scaled by 1/sqrt(head_size).
        mul_q_nodes = self.model.match_parent_path(
            mul_q,
            ["Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape", "Transpose", "Reshape"],
            [None, 0, 1, 0, 0, 0, 0, 0],
        )
        if mul_q_nodes is None or mul_q_nodes[-1] != reshape_q:
            logger.debug("fuse_attention: failed to match mul_q path")
            return None

        return reshape_qkv, transpose_qkv, reshape_q, matmul_q, add_q, matmul_k, add_k, matmul_v, add_v

    # --------------------------------------------------------
    # The following are for SAM encoder
    # --------------------------------------------------------
    def fuse_sam_encoder_pattern(self, normalize_node, input_name_to_nodes, output_name_to_node) -> bool:
        # SAM encoder attention layer pattern:
        #           Add -----------+
        #            |             |
        #        LayerNorm         |
        #            |             |
        #        Reshape           |
        #            |             |
        #        Transpose         |
        #            |             |
        #        MatMul            |
        #            |             |
        #           Add            |
        #            |             |
        #         Reshape          |
        #            |             |
        #          Split           |
        #            |             |
        #  Self Attention subgraph |
        #            |             |
        #        Reshape           |
        #            |             |
        #        Transpose         |
        #            |             |
        #        Reshape           |
        #            |             |
        #            Add ----------+
        #            |
        #         LayerNorm (starts from here)

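        # Three variants of the path between the attention output and the residual Add are matched
        # below: Reshape/Transpose/Reshape (likely window un-partitioning in the image encoder), the
        # same with two extra Slice nodes (likely cropping padded windows), and a plain Add for
        # layers without window partitioning.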
        nodes = self.model.match_parent_path(
            normalize_node,
            ["Add", "Reshape", "Transpose", "Reshape"],
            [0, None, 0, 0],
        )
        if nodes is None:
            nodes = self.model.match_parent_path(
                normalize_node,
                ["Add", "Slice", "Slice", "Reshape", "Transpose", "Reshape"],
                [0, None, 0, 0, 0, 0],
            )
        if nodes is None:
            nodes = self.model.match_parent_path(
                normalize_node,
                ["Add"],
                [0],
            )
        if nodes is None:
            return False

        node_after_output_projection = nodes[-1]
        matched_sdpa = self.match_sam_encoder_attention_subgraph(
            node_after_output_projection, input_index=1 if len(nodes) == 1 else None
        )
        if matched_sdpa is None:
            return False

        reshape_out, transpose_out, split_qkv, transpose_q, transpose_k, transpose_v = matched_sdpa

        # B, S, N, H => B, N, S, H
        permutation_q = OnnxModel.get_node_attribute(transpose_q, "perm")
        if (not isinstance(permutation_q, list)) or permutation_q != [0, 2, 1, 3]:
            return False

        # B, S, N, H => B, N, H, S
        permutation_k = OnnxModel.get_node_attribute(transpose_k, "perm")
        if (not isinstance(permutation_k, list)) or permutation_k != [0, 2, 3, 1]:
            return False

        # B, S, N, H => B, N, S, H
        permutation_v = OnnxModel.get_node_attribute(transpose_v, "perm")
        if (not isinstance(permutation_v, list)) or permutation_v != [0, 2, 1, 3]:
            return False

        input_projection_nodes = self.model.match_parent_path(
            split_qkv,
            ["Reshape", "Add", "MatMul"],
            [0, 0, None],
        )
        if input_projection_nodes is None:
            return False
        reshape_in, add_in, matmul_in = input_projection_nodes
        q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_in, normalize_node, True)
        if q_num_heads <= 0:
            logger.debug("fuse_attention: failed to detect num_heads")
            return False

        # Add a Reshape node to convert Q from 4D BxSxNxH to 3D BxSxD (D = N*H), which is the layout
        # required by the MultiHeadAttention operator.
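        # In ONNX Reshape, a 0 in the target shape copies the corresponding input dimension and -1
        # infers the remaining size, so [0, 0, -1] collapses the last two dimensions (N, H) into one.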
        new_dims_name = "bsnh_to_bsd_reshape_dims"
        new_dims = self.model.get_initializer(new_dims_name)
        if new_dims is None:
            new_dims = numpy_helper.from_array(np.array([0, 0, -1], dtype="int64"), name=new_dims_name)
            self.model.add_initializer(new_dims, self.this_graph_name)
        reshape_q_name = self.model.create_node_name("Reshape")
        reshape_q = helper.make_node(
            "Reshape",
            inputs=[transpose_q.input[0], new_dims_name],
            outputs=[transpose_q.input[0] + "_BSD"],
            name=reshape_q_name,
        )
        self.nodes_to_add.append(reshape_q)
        self.node_name_to_graph_name[reshape_q.name] = self.this_graph_name

        # Reuse the transpose_q node to transpose K from BSNH to BNSH. Here we update the input and output of the node.
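        # The original transpose_q only fed the SDPA subgraph being replaced, so it would become dead
        # after fusion; repurposing it avoids adding an extra Transpose node.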
        transpose_k_bnsh = transpose_q
        transpose_k_bnsh.input[0] = transpose_k.input[0]
        transpose_k_bnsh.output[0] = transpose_k.input[0] + "_BNSH"

        logger.debug(f"Found MHA: {q_num_heads=} {q_hidden_size=}")

        # The number of heads is the same for the Q, K and V paths, so q_num_heads is passed
        # when creating the MHA node.
        new_node = self.create_mha_node(
            reshape_q,
            transpose_k_bnsh,
            transpose_v,
            q_num_heads,
        )
        if new_node is None:
            return False

        # Update the input of the next node that consumes the output of the MHA.
        assert len(self.model.get_children(transpose_out, input_name_to_nodes)) == 1
        reshape_out.input[0] = new_node.output[0]

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
        self.nodes_to_remove.extend([transpose_out])

        # Use prune graph to remove nodes since they are shared by all attention nodes.
        self.prune_graph = True
        return True

    def match_sam_encoder_attention_subgraph(self, node_after_output_projection, input_index=None):
        """Match SDPA pattern in SAM2 enconder.*"""

        # Nodes of the output projection and the second MatMul (attention probabilities x V) in SDPA.
        out_nodes = self.model.match_parent_path(
            node_after_output_projection,
            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [input_index, None, None, 0, 0],
        )

        if out_nodes is None:
            return None

        (_, _, reshape_out, transpose_out, matmul_qk_v) = out_nodes

        # Reshape and Split handle the packed QKV projection: Reshape to 5D, Split into Q, K and V,
        # then Squeeze removes the packing axis.
        v_nodes = self.model.match_parent_path(matmul_qk_v, ["Transpose", "Squeeze", "Split", "Reshape"], [1, 0, 0, 0])
        if v_nodes is None:
            logger.debug("failed to match v path")
            return None
        (transpose_v, _, split_qkv, reshape_qkv) = v_nodes

        qk_nodes = self.model.match_parent_path(matmul_qk_v, ["Softmax", "MatMul"], [0, 0])
        if qk_nodes is not None:
            (_softmax_qk, matmul_qk) = qk_nodes
        else:
            logger.debug("failed to match qk path")
            return None

        q_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Squeeze", "Split"], [0, None, 0, 0])
        if q_nodes is None:
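            # Some encoder blocks downsample Q with a MaxPool wrapped between Reshape/Transpose pairs
            # (query pooling in SAM2's Hiera-based image encoder), so match that longer variant too.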
            q_nodes = self.model.match_parent_path(
                matmul_qk,
                ["Mul", "Transpose", "Reshape", "Transpose", "MaxPool", "Transpose", "Reshape", "Squeeze", "Split"],
                [0, None, 0, 0, 0, 0, 0, 0, 0],
            )
            if q_nodes is None:
                logger.debug("failed to match q path")
                return None

        if q_nodes[-1] != split_qkv:
            return None
        transpose_q = q_nodes[1]

        k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose", "Squeeze", "Split"], [1, None, 0, 0])
        if k_nodes is None:
            logger.debug("failed to match k path")
            return None

        if k_nodes[-1] != split_qkv:
            return None
        (_mul_k, transpose_k, _squeeze_k, _) = k_nodes

        return reshape_out, transpose_out, split_qkv, transpose_q, transpose_k, transpose_v

    def create_mha_node(
        self,
        reshape_q: NodeProto,
        transpose_k: NodeProto,
        transpose_v: NodeProto,
        num_heads: int,
    ) -> NodeProto:
        """Create a MultiHeadAttention node for SAM2 encoder.

        Args:
            reshape_q (NodeProto): Reshape node for Q; its output is in 3D BxSx(N*H) format
            transpose_k (NodeProto): Transpose node for K, output is BNSH format
            transpose_v (NodeProto): Transpose node for V, output is BNSH format
            num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.

        Returns:
            NodeProto: the MultiHeadAttention node created.
        """

        attention_node_name = self.model.create_node_name("MultiHeadAttention")

        inputs = [
            reshape_q.output[0],
            transpose_k.output[0],
            transpose_v.output[0],
        ]
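        # Q is 3D (B, S, N*H) while K and V stay in 4D BNSH layout, a combination accepted by the
        # com.microsoft MultiHeadAttention operator.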

        # Create a new output name since the shape is 3D, which is different from the original output shape (4D).
        output = attention_node_name + "_out"

        attention_node = helper.make_node(
            "MultiHeadAttention",
            inputs=inputs,
            outputs=[output],
            name=attention_node_name,
        )
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

        counter_name = "MultiHeadAttention ({})".format("self attention")
        self.increase_counter(counter_name)
        return attention_node
