# -------------------------------------------------------------------------
# Copyright (R) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import warnings

import torch
from sam2.modeling.sam2_base import SAM2Base
from sam2_utils import compare_tensors_with_tolerance, random_sam2_input_image
from torch import nn

import onnxruntime

logger = logging.getLogger(__name__)


class SAM2ImageEncoder(nn.Module):
    def __init__(self, sam_model: SAM2Base) -> None:
        super().__init__()
        self.model = sam_model
        self.image_encoder = sam_model.image_encoder
        self.no_mem_embed = sam_model.no_mem_embed

    def forward(
        self,
        image: torch.Tensor,
        enable_nvtx_profile: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Encodes images into features.

        Only supports H=W=1024. If you want to use different image sizes like 512x512,
        see https://github.com/facebookresearch/segment-anything-2/issues/138.

        Args:
            image (torch.Tensor): images of shape [B, 3, H, W], B is batch size, H and W are height and width.
            enable_nvtx_profile (bool): enable NVTX profiling.

        Returns:
            image_features_0: image features of shape [B, 32, H/4, W/4] - high resolution features of level 0
            image_features_1: image features of shape [B, 64, H/8, W/8] - high resolution features of level 1
            image_embeddings: image features of shape [B, 256, H/16, W/16] - 16 is the backbone_stride
        """
        nvtx_helper = None
        if enable_nvtx_profile:
            from nvtx_helper import NvtxHelper

            nvtx_helper = NvtxHelper(["image_encoder", "post_process"])

        if nvtx_helper is not None:
            nvtx_helper.start_profile("image_encoder")

        backbone_out = self.image_encoder(image)

        if nvtx_helper is not None:
            nvtx_helper.stop_profile("image_encoder")
            nvtx_helper.start_profile("post_process")

        # precompute projected level 0 and level 1 features in SAM decoder
        # to avoid running it again on every SAM click
        backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
        backbone_out["backbone_fpn"][1] = self.model.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])

        # Prepare and flatten visual features.
        feature_maps = backbone_out["backbone_fpn"][-self.model.num_feature_levels :]
        vision_pos_embeds = backbone_out["vision_pos_enc"][-self.model.num_feature_levels :]
        feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]

        # flatten NxCxHxW to HWxNxC
        # TODO: we should avoid this transpose since it will be transposed back to NCHW later.
        vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]

        vision_feats[-1] = vision_feats[-1] + self.no_mem_embed

        feats = [
            feat.permute(1, 2, 0).reshape(1, -1, *feat_size)
            for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1])
        ][::-1]

        if nvtx_helper is not None:
            nvtx_helper.stop_profile("post_process")
            nvtx_helper.print_latency()

        return feats[0], feats[1], feats[2]


def export_image_encoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
    dynamic_batch_axes: bool = False,
    verbose: bool = False,
):
    image = random_sam2_input_image()

    sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
    image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
    logger.info("image.shape: %s", image.shape)
    logger.info("image_features_0.shape: %s", image_features_0.shape)
    logger.info("image_features_1.shape: %s", image_features_1.shape)
    logger.info("image_embeddings.shape: %s", image_embeddings.shape)

    dynamic_axes = None
    if dynamic_batch_axes:
        dynamic_axes = {
            "image": {0: "batch_size"},
            "image_features_0": {0: "batch_size"},
            "image_features_1": {0: "batch_size"},
            "image_embeddings": {0: "batch_size"},
        }

    with warnings.catch_warnings():
        if not verbose:
            warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
            warnings.filterwarnings("ignore", category=UserWarning)
        torch.onnx.export(
            sam2_encoder,
            image,
            onnx_model_path,
            export_params=True,
            opset_version=17,
            do_constant_folding=True,
            input_names=["image"],
            output_names=["image_features_0", "image_features_1", "image_embeddings"],
            dynamic_axes=dynamic_axes,
        )

    print("encoder onnx model saved to", onnx_model_path)


def test_image_encoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
    dynamic_batch_axes=False,
):
    ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers())

    model_inputs = ort_session.get_inputs()
    input_names = [model_inputs[i].name for i in range(len(model_inputs))]
    logger.info("input_names: %s", input_names)

    model_outputs = ort_session.get_outputs()
    output_names = [model_outputs[i].name for i in range(len(model_outputs))]
    logger.info("output_names: %s", output_names)

    batch_sizes = [1, 2] if dynamic_batch_axes else [1]
    for batch_size in batch_sizes:
        image = random_sam2_input_image(batch_size)

        sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
        image_features_0, image_features_1, image_embeddings = sam2_encoder(image.clone())

        logger.info("image.shape: %s", image.shape)
        logger.info("image_features_0.shape: %s", image_features_0.shape)
        logger.info("image_features_1.shape: %s", image_features_1.shape)
        logger.info("image_embeddings.shape: %s", image_embeddings.shape)

        outputs = ort_session.run(output_names, {"image": image.numpy()})
        for i, output_name in enumerate(output_names):
            logger.info("output %s shape %s", output_name, outputs[i].shape)
        ort_image_features_0, ort_image_features_1, ort_image_embeddings = outputs

        # ONNXRuntime and PyTorch has about 0.75% mismatched elements, but seems not impacting segmentation results.
        if (
            compare_tensors_with_tolerance(
                "image_features_0",
                image_features_0,
                torch.tensor(ort_image_features_0),
                mismatch_percentage_tolerance=1,
            )
            and compare_tensors_with_tolerance(
                "image_features_1",
                image_features_1,
                torch.tensor(ort_image_features_1),
                mismatch_percentage_tolerance=1,
            )
            and compare_tensors_with_tolerance(
                "image_embeddings",
                image_embeddings,
                torch.tensor(ort_image_embeddings),
                mismatch_percentage_tolerance=1,
            )
        ):
            print(f"onnx model has been verified for batch_size={batch_size}: {onnx_model_path}")
        else:
            print(f"onnx model verification failed for batch_size={batch_size}: {onnx_model_path}")
