# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
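"""
Graph transformations for the Dynamo-exported Phi-2 ONNX model.

Decoder-layer function nodes are split ("fissioned") into ONNX Runtime contrib ops
(Attention, MultiHeadAttention, GroupQueryAttention, PagedAttention, RotaryEmbedding,
FastGelu, LayerNormalization / SkipLayerNormalization), and the graph inputs/outputs
are rewritten to match the selected attention operator.
"""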

from logging import getLogger
from typing import List, Optional

import numpy as np
from dynamo_onnx_helper import DynamoOnnxHelper
from fusion_base import Fusion
from fusion_options import AttentionOpType, FusionOptions
from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization
from fusion_utils import NumpyHelper
from onnx import ModelProto, NodeProto, TensorProto, helper, numpy_helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


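# The functors below transform initializer values before they are re-added to the graph:
# transpose GEMM weights from (out_features, in_features) to (in_features, out_features),
# split fused QKV MatMul weights/biases into separate Q, K and V parts, and truncate the
# rotary cos/sin cache to its first half for the partial rotary embedding.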
class ProcessGemmWFunc:
    def __call__(self, x):
        return np.transpose(x, (1, 0))


class ProcessMatMulQFunc:
    def __call__(self, x):
        return np.transpose(np.split(x, 3, 0)[0], (1, 0))


class ProcessMatMulKFunc:
    def __call__(self, x):
        return np.transpose(np.split(x, 3, 0)[1], (1, 0))


class ProcessMatMulVFunc:
    def __call__(self, x):
        return np.transpose(np.split(x, 3, 0)[2], (1, 0))


class ProcessBiasQFunc:
    def __call__(self, x):
        return np.split(x, 3, -1)[0]


class ProcessBiasKFunc:
    def __call__(self, x):
        return np.split(x, 3, -1)[1]


class ProcessBiasVFunc:
    def __call__(self, x):
        return np.split(x, 3, -1)[2]


class ProcessRotCacheFunc:
    def __call__(self, x):
        # Partial (half) rotary embedding: keep only the first half of a 32-wide cos/sin cache.
        assert len(x.shape) == 2
        if x.shape[1] == 32:
            return x[:, 0:16]
        return x


# TODO: move to a separate file
class Fission(Fusion):
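    """Base class that replaces ("fissions") a Dynamo-exported function node with an equivalent subgraph of ONNX ops.

    It reuses the Fusion machinery only to locate function nodes by name; the fused op type "DONOTUSE" is a placeholder.
    """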
    def __init__(
        self,
        model: OnnxModel,
        nodes_to_find: List[str],
    ):
        super().__init__(model, "DONOTUSE", nodes_to_find)

    def set_attention_op_type(self, attn_op_type: AttentionOpType):
        self.attn_op_type = attn_op_type

    def get_uname(self, layer_id, name):
        return name + "_" + str(layer_id)

    def get_edge_by_name(self, edges, name):
        for edge in edges:
            if edge == name or edge.endswith(name) or edge.startswith(name):
                return edge
        raise ValueError(f"Edge {name} not found")

    def get_input_by_name(self, node, name):
        return self.get_edge_by_name(node.input, name)

    def get_output_by_name(self, node, name):
        return self.get_edge_by_name(node.output, name)

    def process_initializer(self, initializer_name, functor, custom_name=None):
        i = self.model.get_initializer(initializer_name)
        i_np_array = NumpyHelper.to_array(i)
        processed_i_np_array = functor(i_np_array)
        new_tensor = helper.make_tensor(
            initializer_name + "_processed" if custom_name is None else custom_name,
            data_type=TensorProto.FLOAT,
            dims=processed_i_np_array.shape,
            vals=processed_i_np_array.flatten().tobytes(),
            raw=True,
        )
        self.model.add_initializer(new_tensor, self.this_graph_name)
        return new_tensor.name

    def add_fp32_value_info(self, name):
        new_value_info = self.model.graph().value_info.add()
        new_value_info.name = name
        new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT

    def add_int64_value_info(self, name):
        new_value_info = self.model.graph().value_info.add()
        new_value_info.name = name
        new_value_info.type.tensor_type.elem_type = TensorProto.INT64

    def replace_fp32_value_info(self, name, shape):
        for value_info in self.model.graph().value_info:
            if value_info.name == name:
                self.model.graph().value_info.remove(value_info)
                break
        new_value_info = helper.make_tensor_value_info(
            name,
            elem_type=TensorProto.FLOAT,
            shape=shape,
        )
        self.model.graph().value_info.extend([new_value_info])

    def set_unique_name_and_add_nodes(
        self, subgraph_nodes: List[NodeProto], layer_id: int, layer_known_edges_names: List[str]
    ):
        for new_node in subgraph_nodes:
            for i, name in enumerate(new_node.input):
                if name == "":
                    continue
                elif name not in layer_known_edges_names:
                    new_node.input[i] = self.get_uname(layer_id, name)
                    self.add_fp32_value_info(new_node.input[i])
            for i, name in enumerate(new_node.output):
                if name == "":
                    continue
                elif name not in layer_known_edges_names:
                    new_node.output[i] = self.get_uname(layer_id, name)
                    self.add_fp32_value_info(new_node.output[i])
            new_node.name = self.get_uname(layer_id, new_node.name)
            self.nodes_to_add.append(new_node)
            self.node_name_to_graph_name[new_node.name] = self.this_graph_name

    def layernorm(self, inputs: List[str], outputs: List[str], prefix: str = ""):
        assert len(inputs) == 3
        assert len(outputs) == 1
        node = helper.make_node(
            "LayerNormalization",
            inputs=inputs,
            outputs=outputs,
            name=prefix + "_LayerNormalization",
            epsilon=9.999999747378752e-06,  # 1e-5, stored as float32
        )
        return [node]

    def gemm(self, inputs: List[str], outputs: List[str], prefix: str = ""):
        assert len(inputs) == 3
        assert len(outputs) == 1
        matmul = helper.make_node(
            "MatMul",
            inputs=[inputs[0], inputs[1]],
            outputs=[prefix + "matmul_out"],
            name=prefix + "MatMul",
        )
        add = helper.make_node(
            "Add",
            inputs=[prefix + "matmul_out", inputs[2]],
            outputs=outputs,
            name=prefix + "Bias",
        )
        return [matmul, add]

    def rotary(self, inputs: List[str], outputs: List[str], prefix: str = "", rot_dim=32, num_heads=32):
        assert len(inputs) == 4
        assert len(outputs) == 1
        node = helper.make_node(
            "RotaryEmbedding",
            inputs=inputs,
            outputs=outputs,
            name=prefix + "RotaryEmbedding",
            domain="com.microsoft",
            rotary_embedding_dim=rot_dim,
            num_heads=num_heads,
        )
        return [node]

    def fastgelu(self, inputs: List[str], outputs: List[str], prefix: str = ""):
        assert len(inputs) == 1
        assert len(outputs) == 1
        node = helper.make_node(
            "FastGelu",
            inputs=inputs,
            outputs=outputs,
            name=prefix + "FastGelu",
            domain="com.microsoft",
        )
        return [node]

    def add(self, inputs: List[str], outputs: List[str], prefix: str = ""):
        assert len(inputs) == 2
        assert len(outputs) == 1
        node = helper.make_node(
            "Add",
            inputs=inputs,
            outputs=outputs,
            name=prefix + "Add",
        )
        return [node]

    def mha(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32):
        assert len(inputs) == 8
        assert len(outputs) == 3
        node = helper.make_node(
            "MultiHeadAttention",
            inputs=inputs,
            outputs=outputs,
            name=prefix + "MultiHeadAttention",
            domain="com.microsoft",
            num_heads=num_heads,
            unidirectional=1,
        )
        return [node]

    def gqa(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32):
        assert len(inputs) == 7
        assert len(outputs) == 3
        node = helper.make_node(
            "GroupQueryAttention",
            inputs=inputs,
            outputs=outputs,
            name=prefix + "GroupQueryAttention",
            domain="com.microsoft",
            num_heads=num_heads,
            kv_num_heads=num_heads,
        )
        return [node]

    def attention(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32):
        assert len(inputs) == 5
        assert len(outputs) == 2
        node = helper.make_node(
            "Attention",
            inputs=inputs,
            outputs=outputs,
            name=prefix + "Attention",
            domain="com.microsoft",
            num_heads=num_heads,
            unidirectional=1,
            do_rotary=1,
            rotary_embedding_dim=32,  # partial rotary embedding: only the first 32 head dims are rotated
        )
        return [node]

    def paged_attn(
        self,
        inputs: List[str],
        outputs: List[str],
        prefix: str = "",
        num_heads=32,
        head_size=80,
        scale=0.11180339753627777,  # 1 / sqrt(head_size=80)
    ):
        assert len(inputs) == 6
        assert len(outputs) == 1
        node = helper.make_node(
            "PagedAttention",
            inputs=inputs,
            outputs=outputs,
            name=prefix + "PagedAttention",
            domain="vllm.ort.ext",
            num_heads=num_heads,
            num_kv_heads=num_heads,
            head_size=head_size,
            scale=scale,
        )
        return [node]


class Phi2PreProcessor(DynamoOnnxHelper):
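    """Prepares the Dynamo-exported Phi-2 model for fission: unrolls the top-level model function,
    renames edges to canonical names, and rewrites graph inputs/outputs for the selected attention op type.
    """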
    def __init__(self, model: ModelProto, num_heads: int, hidden_size: int):
        super().__init__(model)
        self.num_hidden_layers = 32
        self.num_attention_heads = num_heads
        self.hidden_size = hidden_size

        self.func_name = "modeling_phi_PhiModel_model_1"

    def get_phi2_edge_dict(self) -> dict:
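        """Maps Dynamo-generated edge names to canonical names (input_ids, logits, and per-layer
        past_key/past_value and present_key/present_value)."""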
        edge_dict = {}
        edge_dict["lm_head_1"] = "logits"
        edge_dict["l_input_ids_"] = "input_ids"
        edge_dict["key_states"] = "past_key_0"
        edge_dict["value_states"] = "past_value_0"
        for i in range(1, self.num_hidden_layers):
            edge_dict[f"key_states_{i}"] = f"past_key_{i}"
            edge_dict[f"value_states_{i}"] = f"past_value_{i}"
            edge_dict[f"model_layers_{i}_1"] = f"present_key_{i}"
            edge_dict[f"model_layers_{i}_1_1"] = f"present_value_{i}"

        outputs = [o.name for o in self.model.graph.output]
        if "model_layers_0_1_1" in outputs and "model_layers_0_1_2" in outputs:
            edge_dict["model_layers_0_1_1"] = "present_key_0"
            edge_dict["model_layers_0_1_2"] = "present_value_0"
        else:
            assert "model_layers_0_1" in outputs and "model_layers_0_1_1" in outputs
            edge_dict["model_layers_0_1"] = "present_key_0"
            edge_dict["model_layers_0_1_1"] = "present_value_0"
        return edge_dict

    def simplify_phi2_op_type(self):
        phi2_transformer_layer_name = "modeling_phi_PhiDecoderLayer_model_layers"
        for node in self.model.graph.node:
            index = node.op_type.find(phi2_transformer_layer_name)
            if index != -1:
                node.op_type = node.op_type[index:]

    def process_graph_io(self, attn_op_type: AttentionOpType):
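        """Rebuilds graph inputs/outputs for the selected attention op type:
        default (MHA/GQA): int32 input_ids, step, attention_mask plus per-layer 4-D past_key/past_value tensors;
        Attention: a single combined 5-D past/present tensor per layer;
        PagedAttention (vLLM): int64 input_ids, position_ids, input_metadata plus block-layout KV caches.
        """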
        self.use_attn = attn_op_type == AttentionOpType.Attention
        self.use_vllm = attn_op_type == AttentionOpType.PagedAttention
        graph = self.model.graph
        new_inputs = []
        for vi in graph.input:
            if "input_ids" in vi.name:
                vi_iid = helper.make_tensor_value_info(
                    vi.name,
                    elem_type=TensorProto.INT32 if not self.use_vllm else TensorProto.INT64,
                    shape=["batch_size", "seq_len"],
                )
                vi_step = helper.make_tensor_value_info(
                    "step",
                    elem_type=TensorProto.INT64,
                    shape=[1],
                )
                vi_pid = helper.make_tensor_value_info(
                    "position_ids",
                    elem_type=TensorProto.INT64,
                    shape=["batch_size", "seq_len"],
                )
                vi_mask = helper.make_tensor_value_info(
                    "attention_mask",
                    elem_type=TensorProto.INT32,
                    shape=["batch_size", "seq_len"],
                )
                vi_meta = helper.make_tensor_value_info(
                    "input_metadata",
                    elem_type=TensorProto.INT64,
                    shape=[1],
                )
                if not self.use_vllm:
                    new_inputs.extend([vi_iid, vi_step, vi_mask])
                else:
                    new_inputs.extend([vi_iid, vi_pid, vi_meta])
            if self.use_attn:
                if "past_key" in vi.name:
                    vi_cache = helper.make_tensor_value_info(
                        vi.name.replace("past_key", "past"),
                        elem_type=vi.type.tensor_type.elem_type,
                        shape=[
                            2,
                            "batch_size",
                            self.num_attention_heads,
                            "past_seq_len",
                            self.hidden_size // self.num_attention_heads,
                        ],
                    )
                    new_inputs.extend([vi_cache])
            elif self.use_vllm:
                if "past_key" in vi.name:
                    vi_cache = helper.make_tensor_value_info(
                        vi.name,
                        elem_type=vi.type.tensor_type.elem_type,
                        shape=["num_blocks", "num_heads", "head_size_x", "block_size", "block_x"],
                    )
                    new_inputs.extend([vi_cache])
                if "past_value" in vi.name:
                    vi_cache = helper.make_tensor_value_info(
                        vi.name,
                        elem_type=vi.type.tensor_type.elem_type,
                        shape=[
                            "num_blocks",
                            "num_heads",
                            "head_size",
                            "block_size",
                        ],
                    )
                    new_inputs.extend([vi_cache])
            else:
                if "past_key" in vi.name or "past_value" in vi.name:
                    vi_cache = helper.make_tensor_value_info(
                        vi.name,
                        elem_type=vi.type.tensor_type.elem_type,
                        shape=[
                            "batch_size",
                            self.num_attention_heads,
                            "past_seq_len",
                            self.hidden_size // self.num_attention_heads,
                        ],
                    )
                    new_inputs.extend([vi_cache])

        graph.ClearField("input")
        graph.input.extend(new_inputs)

        new_outputs = []
        for i, vi in enumerate(graph.output):
            if i == 0:
                new_outputs.extend([vi])  # the first graph output (logits) is kept unchanged
            else:
                if self.use_attn:
                    if "present_key" in vi.name:
                        vi_cache = helper.make_tensor_value_info(
                            vi.name.replace("present_key", "present"),
                            elem_type=vi.type.tensor_type.elem_type,
                            shape=[
                                2,
                                "batch_size",
                                self.num_attention_heads,
                                "total_seq_len",
                                self.hidden_size // self.num_attention_heads,
                            ],
                        )
                        new_outputs.extend([vi_cache])
                elif self.use_vllm:
                    pass
                else:
                    vi_cache = helper.make_tensor_value_info(
                        vi.name,
                        elem_type=vi.type.tensor_type.elem_type,
                        shape=[
                            "batch_size",
                            self.num_attention_heads,
                            "total_seq_len",
                            self.hidden_size // self.num_attention_heads,
                        ],
                    )
                    new_outputs.extend([vi_cache])

        graph.ClearField("output")
        graph.output.extend(new_outputs)

    def preprocess_onnx(self, attn_op_type: AttentionOpType):
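        """Unrolls the top-level model function, renames edges, simplifies decoder-layer op types,
        removes dropout (and the LM head when targeting PagedAttention), then rewrites graph I/O."""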
        function_name = None
        for func in self.model.functions:
            if func.name.endswith(self.func_name):
                function_name = func.name
                break
        assert function_name is not None
        self.unroll_function(function_name)
        self.update_edges(self.get_phi2_edge_dict())
        self.simplify_phi2_op_type()
        self.remove_dropout_layer()
        if attn_op_type == AttentionOpType.PagedAttention:
            self.remove_lm_head_layer()
        self.process_graph_io(attn_op_type)


class FissionTransformerEmbeddingPhi(Fission):
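    """Replaces the token-embedding function node with a plain Gather over embed_tokens.weight."""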
    def __init__(
        self,
        model: OnnxModel,
    ):
        super().__init__(model, ["torch_nn_modules_sparse_Embedding_model_embed_tokens_1"])

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        logger.info("Optimizing %s...", node.name)

        assert len(node.input) == 2
        assert len(node.output) == 1

        input = node.input[0]
        output = node.output[0]

        embedding = self.get_input_by_name(node, "embed_tokens.weight")

        layer_known_edges_names = [input, output, embedding]

        subgraph_nodes = [
            helper.make_node(
                "Gather",
                inputs=[embedding, input],
                outputs=[output],
                name="Embedding_Gather",
            ),
        ]

        self.set_unique_name_and_add_nodes(subgraph_nodes, 0, layer_known_edges_names)
        self.nodes_to_remove.append(node)
        self.prune_graph = True


class FissionTransformerLayerNormPhi(Fission):
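    """Replaces the final-layernorm function node with a LayerNormalization node."""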
    def __init__(
        self,
        model: OnnxModel,
    ):
        super().__init__(model, ["torch_nn_modules_normalization_LayerNorm_model_final_layernorm_1"])

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        logger.info("Optimizing %s...", node.name)

        assert len(node.input) == 3
        assert len(node.output) == 1

        input = node.input[0]
        output = node.output[0]

        ln_weight = self.get_input_by_name(node, "final_layernorm.weight")
        ln_bias = self.get_input_by_name(node, "final_layernorm.bias")

        layer_known_edges_names = [input, output, ln_weight, ln_bias]

        subgraph_nodes = []
        subgraph_nodes.extend(self.layernorm([input, ln_weight, ln_bias], [output], "Final"))

        self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names)

        self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"])
        self.replace_fp32_value_info(output, ["batch_size", "seq_len", "hidden_size"])

        self.nodes_to_remove.append(node)
        self.prune_graph = True


class FissionTransformerCausalLMHeadPhi(Fission):
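    """Replaces the LM-head function node with a MatMul + Add that projects hidden states to vocabulary logits."""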
    def __init__(
        self,
        model: OnnxModel,
    ):
        super().__init__(model, ["torch_nn_modules_linear_Linear_lm_head_1"])

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        logger.info("Optimizing %s...", node.name)

        assert len(node.input) == 5
        assert len(node.output) == 1

        input = node.input[2]
        output = node.output[0]

        fc_weight = self.process_initializer(self.get_input_by_name(node, "lm_head.weight"), ProcessGemmWFunc())
        fc_bias = self.get_input_by_name(node, "lm_head.bias")

        layer_known_edges_names = [input, output, fc_weight, fc_bias]

        subgraph_nodes = []
        subgraph_nodes.extend(self.gemm([input, fc_weight, fc_bias], [output], "LMHead_"))

        self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names)

        self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"])
        self.replace_fp32_value_info(output, ["batch_size", "seq_len", 51200])

        self.nodes_to_remove.append(node)
        self.prune_graph = True


class FissionTransformerBlockPhi(Fission):
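    """Replaces each Phi-2 decoder-layer function node with primitive/contrib ops:
    LayerNormalization, Q/K/V projections, RotaryEmbedding, the selected attention op
    (MultiHeadAttention, GroupQueryAttention, PagedAttention, or fused Attention), and the
    parallel MLP (FC1 -> FastGelu -> FC2) joined by residual additions.
    """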
    def __init__(
        self,
        model: OnnxModel,
        num_heads: int,
    ):
        self.num_heads = num_heads
        max_num_layers = 32
        self.func_to_layer_id = {}
        nodes_to_find = []
        for layer in range(max_num_layers):
            func_name = f"modeling_phi_PhiDecoderLayer_model_layers_{layer}_1"
            nodes_to_find.append(func_name)
            self.func_to_layer_id[func_name] = layer

        super().__init__(model, nodes_to_find)

    def get_layer_id(self, node):
        return self.func_to_layer_id[node.op_type]

    def get_gqa_aux_nodes(self):
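        """Builds the helper nodes shared by all GroupQueryAttention layers: they derive seqlens_k
        (per-row mask sum minus one) and total_sequence_length from attention_mask."""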
        gqa_aux_nodes = [
            helper.make_node(
                "Cast",
                inputs=["attention_mask"],
                outputs=["mask_int64"],
                name="Cast_gqa_aux_0",
                to=TensorProto.INT64,
            ),
            helper.make_node(
                "ReduceSum",
                inputs=["mask_int64", "one"],
                outputs=["mask_row_sums"],
                name="ReduceSum_gqa_aux",
            ),
            helper.make_node(
                "Sub",
                inputs=["mask_row_sums", "one"],
                outputs=["seqlens_k_int64"],
                name="Sub_gqa_aux",
            ),
            helper.make_node(
                "Cast",
                inputs=["seqlens_k_int64"],
                outputs=["seqlens_k"],
                name="Cast_gqa_aux_1",
                to=TensorProto.INT32,
            ),
            helper.make_node("Shape", inputs=["mask_int64"], outputs=["mask_shape"], name="Shape_gqa_aux_0"),
            helper.make_node(
                "Gather",
                inputs=["mask_shape", "one"],
                outputs=["total_seq_len_int64"],
                name="Gather_gqa_aux_0",
                axis=0,
            ),
            helper.make_node(
                "Cast",
                inputs=["total_seq_len_int64"],
                outputs=["total_sequence_length"],
                name="Cast_gqa_aux_2",
                to=TensorProto.INT32,
            ),
        ]
        return gqa_aux_nodes

    def pack_qkv_gemm(self, q_w, k_w, v_w, q_b, k_b, v_b, weight_name, bias_name):
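        """Packs separate Q/K/V projection weights and biases into fused QKV initializers for the Attention op."""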
        q_weight = self.model.get_initializer(q_w)
        k_weight = self.model.get_initializer(k_w)
        v_weight = self.model.get_initializer(v_w)
        qw = np.transpose(NumpyHelper.to_array(q_weight), (1, 0))
        kw = np.transpose(NumpyHelper.to_array(k_weight), (1, 0))
        vw = np.transpose(NumpyHelper.to_array(v_weight), (1, 0))
        qkv_weight = np.stack((qw, kw, vw), axis=1)

        q_bias = self.model.get_initializer(q_b)
        k_bias = self.model.get_initializer(k_b)
        v_bias = self.model.get_initializer(v_b)
        qb = NumpyHelper.to_array(q_bias)
        kb = NumpyHelper.to_array(k_bias)
        vb = NumpyHelper.to_array(v_bias)
        qkv_bias = np.stack((qb, kb, vb), axis=0)

        hidden_size = qkv_weight.shape[0]

        weight = helper.make_tensor(
            weight_name,
            data_type=TensorProto.FLOAT,
            dims=[hidden_size, hidden_size * 3],
            vals=qkv_weight.flatten().tobytes(),
            raw=True,
        )
        self.model.add_initializer(weight, self.this_graph_name)

        bias = helper.make_tensor(
            bias_name,
            data_type=TensorProto.FLOAT,
            dims=[hidden_size * 3],
            vals=qkv_bias.flatten().tobytes(),
            raw=True,
        )
        self.model.add_initializer(bias, self.this_graph_name)

        self.add_fp32_value_info(weight.name)
        self.add_fp32_value_info(bias.name)

        return weight_name, bias_name

    def fuse(
        self,
        node,
        input_name_to_nodes,
        output_name_to_node,
    ):
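        """Rewrites one decoder layer: the input LayerNorm feeds both the attention path
        (for MHA/GQA/PagedAttention: Q/K/V GEMMs and RotaryEmbedding; for the fused Attention op
        these are folded into the node) followed by the output projection, and the MLP path
        (FC1, FastGelu, FC2); the two paths are summed and added back to the layer input
        (parallel residual) to form the layer output.
        """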
        logger.info("Optimizing %s...", node.name)

        logger.info(f"AttentionOpType: {self.attn_op_type}")

        layer_id = self.get_layer_id(node)

        i_hidden_states = node.input[0]
        i_key_cache = self.get_input_by_name(node, "past_key")
        i_value_cache = self.get_input_by_name(node, "past_value")

        o_hidden_states = node.output[-1]
        o_key_cache = self.get_output_by_name(node, "present_key")
        o_value_cache = self.get_output_by_name(node, "present_value")

        ln_weight = self.get_input_by_name(node, "input_layernorm.weight")
        ln_bias = self.get_input_by_name(node, "input_layernorm.bias")

        attn_q_weight, attn_q_bias, attn_k_weight, attn_k_bias, attn_v_weight, attn_v_bias = (
            None,
            None,
            None,
            None,
            None,
            None,
        )
        attn_qkv_weight, attn_qkv_bias = None, None
        cos_cache, sin_cache = None, None

        if self.attn_op_type != AttentionOpType.Attention:
            attn_q_weight = self.process_initializer(
                self.get_input_by_name(node, "self_attn.q_proj.weight"), ProcessGemmWFunc()
            )
            attn_k_weight = self.process_initializer(
                self.get_input_by_name(node, "self_attn.k_proj.weight"), ProcessGemmWFunc()
            )
            attn_v_weight = self.process_initializer(
                self.get_input_by_name(node, "self_attn.v_proj.weight"), ProcessGemmWFunc()
            )
            attn_q_bias = self.get_input_by_name(node, "self_attn.q_proj.bias")
            attn_k_bias = self.get_input_by_name(node, "self_attn.k_proj.bias")
            attn_v_bias = self.get_input_by_name(node, "self_attn.v_proj.bias")

            cos_cache = self.process_initializer(
                self.get_input_by_name(node, "rotary_emb.cos_cached"), ProcessRotCacheFunc()
            )
            sin_cache = self.process_initializer(
                self.get_input_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc()
            )
        else:
            attn_qkv_weight, attn_qkv_bias = self.pack_qkv_gemm(
                self.get_input_by_name(node, "self_attn.q_proj.weight"),
                self.get_input_by_name(node, "self_attn.k_proj.weight"),
                self.get_input_by_name(node, "self_attn.v_proj.weight"),
                self.get_input_by_name(node, "self_attn.q_proj.bias"),
                self.get_input_by_name(node, "self_attn.k_proj.bias"),
                self.get_input_by_name(node, "self_attn.v_proj.bias"),
                self.get_uname(layer_id, "attn_qkv_weight"),
                self.get_uname(layer_id, "attn_qkv_bias"),
            )

        attn_out_weight = self.process_initializer(
            self.get_input_by_name(node, "self_attn.dense.weight"), ProcessGemmWFunc()
        )
        attn_out_bias = self.get_input_by_name(node, "self_attn.dense.bias")

        mlp_fc1_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc1.weight"), ProcessGemmWFunc())
        mlp_fc2_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc2.weight"), ProcessGemmWFunc())
        mlp_fc1_bias = self.get_input_by_name(node, "mlp.fc1.bias")
        mlp_fc2_bias = self.get_input_by_name(node, "mlp.fc2.bias")

        layer_known_edges_names = []
        layer_known_edges_names.extend([i_hidden_states, i_key_cache, i_value_cache])
        layer_known_edges_names.extend([o_hidden_states, o_key_cache, o_value_cache])
        layer_known_edges_names.extend([ln_weight, ln_bias])
        if self.attn_op_type != AttentionOpType.Attention:
            layer_known_edges_names.extend(
                [
                    attn_q_weight,
                    attn_q_bias,
                    attn_k_weight,
                    attn_k_bias,
                    attn_v_weight,
                    attn_v_bias,
                    cos_cache,
                    sin_cache,
                ]
            )
        else:
            layer_known_edges_names.extend([attn_qkv_weight, attn_qkv_bias])
        layer_known_edges_names.extend(
            [attn_out_weight, attn_out_bias, mlp_fc1_weight, mlp_fc1_bias, mlp_fc2_weight, mlp_fc2_bias]
        )
        layer_known_edges_names.extend(
            ["attention_mask", "step", "seqlens_k", "total_sequence_length", "input_metadata", "position_ids"]
        )

        subgraph_nodes = []
        subgraph_nodes.extend(self.layernorm([i_hidden_states, ln_weight, ln_bias], ["ln_out"]))
        subgraph_nodes.extend(self.gemm(["attn_out", attn_out_weight, attn_out_bias], ["attn_add_out"], "OutProj_"))
        subgraph_nodes.extend(self.gemm(["ln_out", mlp_fc1_weight, mlp_fc1_bias], ["fc1_out"], "FC1_"))
        subgraph_nodes.extend(self.fastgelu(["fc1_out"], ["gelu_out"]))
        subgraph_nodes.extend(self.gemm(["gelu_out", mlp_fc2_weight, mlp_fc2_bias], ["fc2_out"], "FC2_"))
        subgraph_nodes.extend(self.add(["attn_add_out", "fc2_out"], ["residual_1_out"], "Residual_1"))
        subgraph_nodes.extend(self.add([i_hidden_states, "residual_1_out"], [o_hidden_states], "Residual_2"))
        if self.attn_op_type != AttentionOpType.Attention:
            subgraph_nodes.extend(self.gemm(["ln_out", attn_q_weight, attn_q_bias], ["query"], "Q_"))
            subgraph_nodes.extend(self.gemm(["ln_out", attn_k_weight, attn_k_bias], ["key"], "K_"))
            subgraph_nodes.extend(self.gemm(["ln_out", attn_v_weight, attn_v_bias], ["value"], "V_"))
            # vllm engine requires full position ids as the input
            pos_ids_name = "position_ids" if self.attn_op_type == AttentionOpType.PagedAttention else "step"
            subgraph_nodes.extend(self.rotary(["query", pos_ids_name, cos_cache, sin_cache], ["query_rot"], "Q_"))
            subgraph_nodes.extend(self.rotary(["key", pos_ids_name, cos_cache, sin_cache], ["key_rot"], "K_"))
            if self.attn_op_type == AttentionOpType.MultiHeadAttention:
                subgraph_nodes.extend(
                    self.mha(
                        ["query_rot", "key_rot", "value", "", "attention_mask", "", i_key_cache, i_value_cache],
                        ["attn_out", o_key_cache, o_value_cache],
                    )
                )
            elif self.attn_op_type == AttentionOpType.GroupQueryAttention:
                subgraph_nodes.extend(
                    self.gqa(
                        [
                            "query_rot",
                            "key_rot",
                            "value",
                            i_key_cache,
                            i_value_cache,
                            "seqlens_k",
                            "total_sequence_length",
                        ],
                        ["attn_out", o_key_cache, o_value_cache],
                    )
                )
                if layer_id == 0:
                    gqa_aux_nodes = self.get_gqa_aux_nodes()
                    for new_node in gqa_aux_nodes:
                        self.nodes_to_add.append(new_node)
                        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
                    self.model.add_initializer(
                        numpy_helper.from_array(np.array([1], dtype="int64"), name="one"), self.this_graph_name
                    )
            elif self.attn_op_type == AttentionOpType.PagedAttention:
                subgraph_nodes.extend(
                    self.paged_attn(
                        ["query_rot", "key_rot", "value", i_key_cache, i_value_cache, "input_metadata"],
                        ["attn_out"],
                    )
                )
        else:
            past_name = f"past_{layer_id}"
            present_name = f"present_{layer_id}"
            layer_known_edges_names.extend([past_name, present_name])
            subgraph_nodes.extend(
                self.attention(
                    ["ln_out", attn_qkv_weight, attn_qkv_bias, "attention_mask", past_name], ["attn_out", present_name]
                )
            )

        self.set_unique_name_and_add_nodes(subgraph_nodes, layer_id, layer_known_edges_names)

        self.replace_fp32_value_info(i_hidden_states, ["batch_size", "seq_len", "hidden_size"])
        self.replace_fp32_value_info(o_hidden_states, ["batch_size", "seq_len", "hidden_size"])

        self.nodes_to_remove.append(node)
        self.prune_graph = True


class PhiOnnxModel(OnnxModel):
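    """OnnxModel wrapper that runs Phi-2 preprocessing and the fission/fusion passes defined above."""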
    def __init__(self, model: ModelProto, num_heads: int, hidden_size: int):
        super().__init__(model)
        self.phi2_preprocessor = Phi2PreProcessor(self.model, num_heads, hidden_size)
        self.fission_transformer_block = FissionTransformerBlockPhi(self, num_heads)
        self.fission_causal_lm_head = FissionTransformerCausalLMHeadPhi(self)
        self.fission_transformer_layernorm = FissionTransformerLayerNormPhi(self)
        self.fission_transformer_embedding = FissionTransformerEmbeddingPhi(self)

    def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False):
        assert options is not None
        attn_op_type = options.attention_op_type

        self.fission_transformer_block.set_attention_op_type(attn_op_type)

        self.phi2_preprocessor.preprocess_onnx(attn_op_type)

        self.fission_transformer_block.apply()
        self.fission_transformer_layernorm.apply()
        self.fission_causal_lm_head.apply()
        self.fission_transformer_embedding.apply()

        super().prune_graph()

        # The SkipLayerNormalization fusions are constructed here intentionally to delay symbolic shape inference
        # until after graph pruning.
        self.fuse_sln = FusionSkipLayerNormalization(self)
        self.fuse_bias_sln = FusionBiasSkipLayerNormalization(self)
        self.fuse_sln.apply()
        self.fuse_bias_sln.apply()

    def get_fused_operator_statistics(self):
        """
        Returns node count of fused operators.
        """
        op_count = {}
        ops = [
            "Attention",
            "MultiHeadAttention",
            "GroupQueryAttention",
            "PagedAttention",
            "Gelu",
            "BiasGelu",
            "FastGelu",
            "LayerNormalization",
            "SkipLayerNormalization",
        ]
        for op in ops:
            nodes = self.get_nodes_by_op_type(op)
            op_count[op] = len(nodes)

        logger.info(f"Optimized operators: {op_count}")
        return op_count

    def is_fully_optimized(self, fused_op_count=None):
        """
        Returns True when the model is fully optimized.
        """
        if fused_op_count is None:
            fused_op_count = self.get_fused_operator_statistics()

        def op_count(op_name: str):
            return fused_op_count.get(op_name) or 0

        attention = (
            op_count("Attention")
            + op_count("MultiHeadAttention")
            + op_count("GroupQueryAttention")
            + op_count("PagedAttention")
        )
        gelu = op_count("Gelu") + op_count("BiasGelu") + op_count("FastGelu")
        layer_norm = op_count("LayerNormalization") + op_count("SkipLayerNormalization")

        is_perfect = (attention > 0) and (attention == gelu) and (layer_norm >= attention)

        if layer_norm == 0:
            logger.debug("Layer Normalization not fused")

        if gelu == 0:
            logger.debug("Gelu (or FastGelu) not fused")

        if attention == 0:
            logger.warning("Attention (or MultiHeadAttention) not fused")

        return is_perfect
