# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
import warnings

import torch
from image_encoder import SAM2ImageEncoder, random_sam2_input_image
from prompt_encoder import SAM2PromptEncoder
from sam2.modeling.sam2_base import SAM2Base
from torch import nn

logger = logging.getLogger(__name__)


class SAM2MaskDecoder(nn.Module):
    def __init__(
        self,
        sam_model: SAM2Base,
        multimask_output: bool,
        dynamic_multimask_via_stability: bool = True,
    ) -> None:
        super().__init__()
        self.mask_decoder = sam_model.sam_mask_decoder
        self.prompt_encoder = sam_model.sam_prompt_encoder
        self.model = sam_model
        self.multimask_output = multimask_output
        self.dynamic_multimask_via_stability = dynamic_multimask_via_stability

    @torch.no_grad()
    def forward(
        self,
        image_features_0: torch.Tensor,
        image_features_1: torch.Tensor,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        sparse_embeddings: torch.Tensor,
        dense_embeddings: torch.Tensor,
    ):
        """
        Decode masks from image and prompt embeddings. Only H = W = 1024 is supported.

        Args:
            image_features_0 (torch.Tensor): [1, 32, H/4, W/4]. High resolution features of level 0 from the image encoder.
            image_features_1 (torch.Tensor): [1, 64, H/8, W/8]. High resolution features of level 1 from the image encoder.
            image_embeddings (torch.Tensor): [1, 256, H/16, W/16]. Image embedding from the image encoder.
            image_pe (torch.Tensor): [1, 256, H/16, W/16]. Image positional encoding.
            sparse_embeddings (torch.Tensor): [L, P+1, 256]. Embeddings for points and boxes, where L is the
                number of prompts (labels) and P is the number of points per prompt.
            dense_embeddings (torch.Tensor): [L, 256, H/16, W/16]. Embeddings for input masks.

        Returns:
            low_res_masks (torch.Tensor): [L, M, H/4, W/4]. Low resolution masks, where M is 3 when
                multimask_output is True and 1 otherwise.
            iou_predictions (torch.Tensor): [L, M]. Predicted IoU scores of the M masks.
        """
        low_res_masks, iou_predictions, _, _ = self.mask_decoder.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            repeat_image=sparse_embeddings.shape[0] > 1,  # batch mode: repeat image embeddings per prompt
            high_res_features=[image_features_0, image_features_1],
        )

        if self.multimask_output:
            low_res_masks = low_res_masks[:, 1:, :, :]
            iou_predictions = iou_predictions[:, 1:]
        elif self.dynamic_multimask_via_stability:
            # When outputting a single mask, if the stability score of the current single-mask
            # output (from output token 0) falls below a threshold, we instead select, from the
            # multi-mask outputs (output tokens 1-3), the mask with the highest predicted IoU score.
            low_res_masks, iou_predictions = self.mask_decoder._dynamic_multimask_via_stability(
                low_res_masks, iou_predictions
            )
        else:
            low_res_masks = low_res_masks[:, 0:1, :, :]
            iou_predictions = iou_predictions[:, 0:1]

        return low_res_masks, iou_predictions
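

# Standalone usage of SAM2MaskDecoder (a minimal sketch, assuming a SAM2Base
# model has already been built; see export_mask_decoder_onnx below for a full
# end-to-end example of producing each input tensor):
#   decoder = SAM2MaskDecoder(sam2_model, multimask_output=True)
#   low_res_masks, iou_predictions = decoder(
#       image_features_0, image_features_1, image_embeddings,
#       image_pe, sparse_embeddings, dense_embeddings,
#   )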


def export_mask_decoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
    multimask_output: bool,
    dynamic_multimask_via_stability: bool = True,
    verbose=False,
):
    sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()

    image = random_sam2_input_image()
    sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
    image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
    logger.info("image_features_0.shape: %s", image_features_0.shape)
    logger.info("image_features_1.shape: %s", image_features_1.shape)
    logger.info("image_embeddings.shape: %s", image_embeddings.shape)

    # encode a random prompt with 2 labels (prompts) and 3 points per prompt
    num_labels = 2
    num_points = 3
    point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
    # Point labels are 1 for foreground and 0 for background. high is exclusive in
    # torch.randint, so high=2 samples both values (high=1 would always yield 0).
    point_labels = torch.randint(low=0, high=2, size=(num_labels, num_points), dtype=torch.float)
    input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
    has_input_masks = torch.ones(1, dtype=torch.float)

    sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
        point_coords, point_labels, input_masks, has_input_masks
    )

    logger.info("sparse_embeddings.shape: %s", sparse_embeddings.shape)
    logger.info("dense_embeddings.shape: %s", dense_embeddings.shape)
    logger.info("image_pe.shape: %s", image_pe.shape)

    sam2_mask_decoder = SAM2MaskDecoder(sam2_model, multimask_output, dynamic_multimask_via_stability)
    inputs = (image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings)
    low_res_masks, iou_predictions = sam2_mask_decoder(*inputs)
    logger.info("low_res_masks.shape: %s", low_res_masks.shape)
    logger.info("iou_predictions.shape: %s", iou_predictions.shape)

    with warnings.catch_warnings():
        if not verbose:
            warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
            warnings.filterwarnings("ignore", category=UserWarning)
        torch.onnx.export(
            sam2_mask_decoder,
            inputs,
            onnx_model_path,
            export_params=True,
            opset_version=18,
            do_constant_folding=True,
            input_names=[
                "image_features_0",
                "image_features_1",
                "image_embeddings",
                "image_pe",
                "sparse_embeddings",
                "dense_embeddings",
            ],
            output_names=["low_res_masks", "iou_predictions"],
            dynamic_axes={
                "sparse_embeddings": {0: "num_labels", 1: "num_points+1"},
                "dense_embeddings": {0: "num_labels"},
                "low_res_masks": {0: "num_labels"},
                "iou_predictions": {0: "num_labels"},
            },
        )
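    # Optional sanity check (a sketch, assuming the `onnx` package is available):
    #   import onnx
    #   onnx.checker.check_model(onnx.load(onnx_model_path))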

    print("mask decoder onnx model saved to", onnx_model_path)


def test_mask_decoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
    multimask_output: bool,
    dynamic_multimask_via_stability: bool,
):
    sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()

    image = random_sam2_input_image()
    sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
    image_features_0, image_features_1, image_embeddings = sam2_encoder(image)

    num_labels = 1
    num_points = 5
    point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
    point_labels = torch.randint(low=0, high=2, size=(num_labels, num_points), dtype=torch.float)  # 1=foreground, 0=background
    input_masks = torch.rand(num_labels, 1, 256, 256, dtype=torch.float)
    has_input_masks = torch.ones(1, dtype=torch.float)

    sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
        point_coords, point_labels, input_masks, has_input_masks
    )

    sam2_mask_decoder = SAM2MaskDecoder(sam2_model, multimask_output, dynamic_multimask_via_stability)
    inputs = (image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings)
    low_res_masks, iou_predictions = sam2_mask_decoder(*inputs)

    import onnxruntime

    ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers())

    model_inputs = ort_session.get_inputs()
    input_names = [model_input.name for model_input in model_inputs]
    logger.info("input_names: %s", input_names)

    model_outputs = ort_session.get_outputs()
    output_names = [model_output.name for model_output in model_outputs]
    logger.info("output_names: %s", output_names)

    outputs = ort_session.run(
        output_names,
        {
            "image_features_0": image_features_0.numpy(),
            "image_features_1": image_features_1.numpy(),
            "image_embeddings": image_embeddings.numpy(),
            "image_pe": image_pe.numpy(),
            "sparse_embeddings": sparse_embeddings.numpy(),
            "dense_embeddings": dense_embeddings.numpy(),
        },
    )

    for i, output_name in enumerate(output_names):
        logger.info("output %s shape: %s", output_name, outputs[i].shape)

    ort_low_res_masks, ort_iou_predictions = outputs
    torch.testing.assert_close(low_res_masks, torch.from_numpy(ort_low_res_masks), atol=5e-3, rtol=1e-4)
    torch.testing.assert_close(iou_predictions, torch.from_numpy(ort_iou_predictions), atol=5e-3, rtol=1e-4)
    print(f"onnx model has been verified: {onnx_model_path}")
