# --------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from logging import getLogger
from typing import Optional

from fusion_attention_vae import FusionAttentionVae
from fusion_options import FusionOptions
from onnx import ModelProto
from onnx_model_unet import UnetOnnxModel

logger = getLogger(__name__)


class VaeOnnxModel(UnetOnnxModel):
    """Optimizer model wrapper for the VAE (Variational Auto-Encoder) of Stable Diffusion models."""

    def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
        # Either both values are left as 0 (to be inferred from the graph during fusion),
        # or num_heads is positive and hidden_size is divisible by num_heads.
        assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)
        super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)

    def fuse_multi_head_attention(self, options: Optional[FusionOptions] = None):
        """Fuse self-attention subgraphs into Attention nodes (the VAE has no cross attention)."""
        self_attention_fusion = FusionAttentionVae(self, self.hidden_size, self.num_heads)
        self_attention_fusion.apply()

    def get_fused_operator_statistics(self):
        """Return the node count of each fused operator type."""
        op_count = {}
        ops = [
            "Attention",
            "GroupNorm",
            "SkipGroupNorm",
            "NhwcConv",
        ]
        for op in ops:
            nodes = self.get_nodes_by_op_type(op)
            op_count[op] = len(nodes)

        logger.info(f"Optimized operators:{op_count}")
        return op_count
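

# A minimal usage sketch (for illustration only; the file name below is an assumption,
# and num_heads/hidden_size are left at their defaults so they are inferred during fusion):
#
#   import onnx
#
#   vae = VaeOnnxModel(onnx.load("vae_decoder.onnx"))
#   vae.fuse_multi_head_attention()
#   vae.get_fused_operator_statistics()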
