# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.
# -------------------------------------------------------------------------
import logging

import torch
from sam2.modeling.sam2_base import SAM2Base
from sam2_utils import compare_tensors_with_tolerance
from torch import nn

logger = logging.getLogger(__name__)


class SAM2PromptEncoder(nn.Module):
    def __init__(self, sam_model: SAM2Base):
        super().__init__()
        self.prompt_encoder = sam_model.sam_prompt_encoder
        self.model = sam_model

    @torch.no_grad()
    def forward(
        self,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        input_masks: torch.Tensor,
        has_input_masks: torch.Tensor,
    ):
        """Encode prompts.

           Args:
            point_coords (torch.Tensor): [L, P, 2] shape and float32 dtype and contains the absolute pixel
                                         coordinate in (x, y) format of the P input points in image of size 1024x1024.
            point_labels (torch.Tensor): shape [L, P] and int32 dtype, where 1 means
                                         positive (foreground), 0 means negative (background), -1 means padding,
                                         2 (box left upper corner), 3 (box right bottom corner).
            input_masks (torch.Tensor): [L, 1, H/4, W/4]. Low resolution mask input to the model.
                                        Typically coming from a previous iteration.
            has_input_masks (torch.Tensor): [L]. 1.0 if input_masks is used, 0.0 otherwise.
        Returns:
            sparse_embeddings (torch.Tensor): [L, P+1, 256], embedding for points and boxes.
            dense_embeddings (torch.Tensor):  [L, 256, 64, 64]. embedding for input masks.
            image_pe (torch.Tensor, optional): [1, 256, 64, 64]. image positional encoding.
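
        Example:
            A box prompt is encoded as two points, the top-left corner with label 2 and the
            bottom-right corner with label 3 (a sketch; the coordinates are arbitrary):

                point_coords = torch.tensor([[[100.0, 150.0], [400.0, 500.0]]])
                point_labels = torch.tensor([[2, 3]], dtype=torch.int32)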
        """
        sparse_embeddings = self._embed_points(point_coords, point_labels)
        dense_embeddings = self._embed_masks(input_masks, has_input_masks)
        image_pe = self.prompt_encoder.get_dense_pe()

        return sparse_embeddings, dense_embeddings, image_pe

    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
        # Shift coordinates to the center of the pixel.
        point_coords = point_coords + 0.5

        # Always append one padding point with label -1; it is mapped to the learned
        # "not a point" embedding below, and keeps the exported graph structure static.
        padding_point = torch.zeros((point_coords.shape[0], 1, 2), device=point_coords.device)
        padding_label = -torch.ones((point_labels.shape[0], 1), device=point_labels.device)
        point_coords = torch.cat([point_coords, padding_point], dim=1)
        point_labels = torch.cat([point_labels, padding_label], dim=1)

        # Note that the input coordinates are based on image size 1024x1024. Here we normalize it to [0.0, 1.0).
        point_coords[:, :, 0] = point_coords[:, :, 0] / self.model.image_size
        point_coords[:, :, 1] = point_coords[:, :, 1] / self.model.image_size

        point_embedding = self.prompt_encoder.pe_layer._pe_encoding(point_coords)
        # Broadcast the labels to the embedding shape so they can serve as masks below.
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        # Zero out padding points, then add the learned "not a point" embedding to them.
        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = point_embedding + self.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)

        # Add the learned per-label embedding (negative, positive, box corners) to the matching points.
        for i in range(self.prompt_encoder.num_point_embeddings):
            point_embedding = point_embedding + self.prompt_encoder.point_embeddings[i].weight * (point_labels == i)

        return point_embedding

    def _embed_masks(self, input_masks: torch.Tensor, has_input_masks: torch.Tensor) -> torch.Tensor:
        # Downscale the input masks to dense embeddings, then select between them and the
        # learned "no mask" embedding based on has_input_masks.
        mask_embedding = self.prompt_encoder.mask_downscaling(input_masks)
        no_mask_embedding = self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
        logger.info("no_mask_embedding.shape: %s", no_mask_embedding.shape)
        mask_embedding = has_input_masks * mask_embedding + (1.0 - has_input_masks) * no_mask_embedding
        logger.info("mask_embedding.shape: %s", mask_embedding.shape)
        return mask_embedding


def export_prompt_encoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
):
    sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()

    num_labels = 2
    num_points = 3
    point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
    # Note that high is exclusive in torch.randint, so high=2 yields random 0/1 labels.
    point_labels = torch.randint(low=0, high=2, size=(num_labels, num_points), dtype=torch.int32)
    input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
    has_input_masks = torch.ones(1, dtype=torch.float)

    sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
        point_coords, point_labels, input_masks, has_input_masks
    )

    logger.info("point_coords.shape: %s", point_coords.shape)
    logger.info("point_labels.shape: %s", point_labels.shape)
    logger.info("input_masks.shape: %s", input_masks.shape)
    logger.info("has_input_masks.shape: %s", has_input_masks.shape)

    logger.info("sparse_embeddings.shape: %s", sparse_embeddings.shape)
    logger.info("dense_embeddings.shape: %s", dense_embeddings.shape)
    logger.info("image_pe.shape: %s", image_pe.shape)

    torch.onnx.export(
        sam2_prompt_encoder,
        (point_coords, point_labels, input_masks, has_input_masks),
        onnx_model_path,
        export_params=True,
        opset_version=18,
        do_constant_folding=True,
        input_names=["point_coords", "point_labels", "input_masks", "has_input_masks"],
        output_names=["sparse_embeddings", "dense_embeddings", "image_pe"],
        dynamic_axes={
            "point_coords": {0: "num_labels", 1: "num_points"},
            "point_labels": {0: "num_labels", 1: "num_points"},
            "input_masks": {0: "num_labels"},
            "sparse_embeddings": {0: "num_labels", 1: "num_points+1"},
            "dense_embeddings": {0: "num_labels"},
        },
    )

    print("prompt encoder onnx model saved to ", onnx_model_path)


def test_prompt_encoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
):
    sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()

    num_labels = 1
    num_points = 5
    point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
    point_labels = torch.randint(low=0, high=2, size=(num_labels, num_points), dtype=torch.int32)
    input_masks = torch.rand(num_labels, 1, 256, 256, dtype=torch.float)
    has_input_masks = torch.ones(1, dtype=torch.float)

    sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
        point_coords, point_labels, input_masks, has_input_masks
    )

    # Deferred import so that exporting the model does not require onnxruntime.
    import onnxruntime

    ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=onnxruntime.get_available_providers())

    input_names = [model_input.name for model_input in ort_session.get_inputs()]
    logger.info("input_names: %s", input_names)

    output_names = [model_output.name for model_output in ort_session.get_outputs()]
    logger.info("output_names: %s", output_names)

    outputs = ort_session.run(
        output_names,
        {
            "point_coords": point_coords.numpy(),
            "point_labels": point_labels.numpy(),
            "input_masks": input_masks.numpy(),
            "has_input_masks": has_input_masks.numpy(),
        },
    )

    for i, output_name in enumerate(output_names):
        logger.info("output %s shape: %s", output_name, outputs[i].shape)

    ort_sparse_embeddings, ort_dense_embeddings, ort_image_pe = outputs
    if (
        compare_tensors_with_tolerance(
            "sparse_embeddings",
            sparse_embeddings,
            torch.tensor(ort_sparse_embeddings),
            mismatch_percentage_tolerance=0.2,
        )
        and compare_tensors_with_tolerance(
            "dense_embeddings", dense_embeddings, torch.tensor(ort_dense_embeddings), mismatch_percentage_tolerance=0.2
        )
        and compare_tensors_with_tolerance(
            "image_pe", image_pe, torch.tensor(ort_image_pe), mismatch_percentage_tolerance=0.2
        )
    ):
        print(f"onnx model has been verified: {onnx_model_path}")
    else:
        print(f"onnx model verification failed: {onnx_model_path}")
