# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

"""
Benchmark performance of the SAM2 image encoder or image decoder with ONNX Runtime (ORT) or PyTorch. See benchmark_sam2.sh for usage.
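
Example (illustrative; adjust paths and flags for your setup):
    python3 benchmark_sam2.py --engine ort --use_gpu --model_type sam2_hiera_large --component image_encoder \
        --onnx_path ./sam2_onnx_models/sam2_hiera_large_image_encoder.onnx --sam2_dir ./segment-anything-2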
"""

import argparse
import csv
import statistics
import time
from datetime import datetime
from typing import List, Mapping, Optional

import torch
from image_decoder import SAM2ImageDecoder
from image_encoder import SAM2ImageEncoder
from sam2_utils import decoder_shape_dict, encoder_shape_dict, load_sam2_model

from onnxruntime import InferenceSession, SessionOptions, get_available_providers
from onnxruntime.transformers.io_binding_helper import CudaSession


class TestConfig:
    def __init__(
        self,
        model_type: str,
        onnx_path: str,
        sam2_dir: str,
        device: torch.device,
        component: str = "image_encoder",
        provider="CPUExecutionProvider",
        torch_compile_mode="max-autotune",
        batch_size: int = 1,
        height: int = 1024,
        width: int = 1024,
        num_labels: int = 1,
        num_points: int = 1,
        num_masks: int = 1,
        multi_mask_output: bool = False,
        use_tf32: bool = True,
        enable_cuda_graph: bool = False,
        dtype=torch.float32,
        prefer_nhwc: bool = False,
        warm_up: int = 5,
        enable_nvtx_profile: bool = False,
        enable_torch_profile: bool = False,
        repeats: int = 1000,
        verbose: bool = False,
    ):
        assert model_type in ["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"]
        assert 160 <= height <= 4096
        assert 160 <= width <= 4096

        self.model_type = model_type
        self.onnx_path = onnx_path
        self.sam2_dir = sam2_dir
        self.component = component
        self.provider = provider
        self.torch_compile_mode = torch_compile_mode
        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.num_labels = num_labels
        self.num_points = num_points
        self.num_masks = num_masks
        self.multi_mask_output = multi_mask_output
        self.device = device
        self.use_tf32 = use_tf32
        self.enable_cuda_graph = enable_cuda_graph
        self.dtype = dtype
        self.prefer_nhwc = prefer_nhwc
        self.warm_up = warm_up
        self.enable_nvtx_profile = enable_nvtx_profile
        self.enable_torch_profile = enable_torch_profile
        self.repeats = repeats
        self.verbose = verbose

        if self.component == "image_encoder":
            assert self.height == 1024 and self.width == 1024, "Only image size 1024x1024 is allowed for image encoder."

    def __repr__(self):
        return f"{vars(self)}"

    def shape_dict(self) -> Mapping[str, List[int]]:
        if self.component == "image_encoder":
            return encoder_shape_dict(self.batch_size, self.height, self.width)
        else:
            return decoder_shape_dict(self.height, self.width, self.num_labels, self.num_points, self.num_masks)

    def random_inputs(self) -> Mapping[str, torch.Tensor]:
        dtype = self.dtype
        if self.component == "image_encoder":
            return {"image": torch.randn(self.batch_size, 3, self.height, self.width, dtype=dtype, device=self.device)}
        else:
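            # These shapes are assumed to match the encoder outputs for a
            # 1024x1024 input (see run_torch below): two high-resolution feature
            # maps plus the image embedding, followed by prompt tensors.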
            return {
                "image_features_0": torch.rand(1, 32, 256, 256, dtype=dtype, device=self.device),
                "image_features_1": torch.rand(1, 64, 128, 128, dtype=dtype, device=self.device),
                "image_embeddings": torch.rand(1, 256, 64, 64, dtype=dtype, device=self.device),
                "point_coords": torch.randint(
                    0, 1024, (self.num_labels, self.num_points, 2), dtype=dtype, device=self.device
                ),
                "point_labels": torch.randint(
                    0, 1, (self.num_labels, self.num_points), dtype=torch.int32, device=self.device
                ),
                "input_masks": torch.zeros(self.num_labels, 1, 256, 256, dtype=dtype, device=self.device),
                "has_input_masks": torch.ones(self.num_labels, dtype=dtype, device=self.device),
                "original_image_size": torch.tensor([self.height, self.width], dtype=torch.int32, device=self.device),
            }


def create_ort_session(config: TestConfig, session_options=None) -> InferenceSession:
    if config.verbose:
        print(f"create session for {vars(config)}")

    if config.provider == "CUDAExecutionProvider":
        device_id = torch.cuda.current_device() if isinstance(config.device, str) else config.device.index
        provider_options = CudaSession.get_cuda_provider_options(device_id, config.enable_cuda_graph)
        provider_options["use_tf32"] = int(config.use_tf32)
        if config.prefer_nhwc:
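            # prefer_nhwc=1 asks the CUDA EP to prefer NHWC (channels-last)
            # kernels, which can help convolution-heavy models on Tensor Cores.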
            provider_options["prefer_nhwc"] = 1
        providers = [(config.provider, provider_options), "CPUExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]

    ort_session = InferenceSession(config.onnx_path, session_options, providers=providers)
    return ort_session


def create_session(config: TestConfig, session_options=None) -> CudaSession:
    ort_session = create_ort_session(config, session_options)
    cuda_session = CudaSession(ort_session, config.device, config.enable_cuda_graph)
    cuda_session.allocate_buffers(config.shape_dict())
    return cuda_session
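

# A minimal sketch (illustrative, not part of the benchmark flow) of driving
# these helpers directly, assuming a CUDA device and an exported encoder model:
#
#   config = TestConfig(
#       model_type="sam2_hiera_large",
#       onnx_path="./sam2_onnx_models/sam2_hiera_large_image_encoder.onnx",
#       sam2_dir="./segment-anything-2",
#       device=torch.device("cuda", 0),
#       provider="CUDAExecutionProvider",
#   )
#   session = create_session(config)
#   outputs = session.infer(config.random_inputs())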


class OrtTestSession:
    """A wrapper of ORT session to test relevance and performance."""

    def __init__(self, config: TestConfig, session_options=None):
        self.ort_session = create_session(config, session_options)
        self.feed_dict = config.random_inputs()

    def infer(self):
        return self.ort_session.infer(self.feed_dict)


def measure_latency(cuda_session: CudaSession, input_dict):
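    # CudaSession.infer is a blocking call (it runs with ORT I/O binding), so
    # the wall-clock interval below includes device execution time.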
    start = time.perf_counter()
    _ = cuda_session.infer(input_dict)
    end = time.perf_counter()
    return end - start


def run_torch(config: TestConfig):
    device_type = config.device.type
    is_cuda = device_type == "cuda"

    # Turn on TF32 for Ampere or newer GPUs; it can help performance when the data type is float32.
    if is_cuda and torch.cuda.get_device_properties(config.device).major >= 8 and config.use_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    enabled_auto_cast = is_cuda and config.dtype != torch.float32
    ort_inputs = config.random_inputs()

    with torch.inference_mode(), torch.autocast(device_type=device_type, dtype=config.dtype, enabled=enabled_auto_cast):
        sam2_model = load_sam2_model(config.sam2_dir, config.model_type, device=config.device)
        if config.component == "image_encoder":
            if is_cuda and config.torch_compile_mode != "none":
                sam2_model.image_encoder.forward = torch.compile(
                    sam2_model.image_encoder.forward,
                    mode=config.torch_compile_mode,  # "reduce-overhead" if you want to reduce latency of first run.
                    fullgraph=True,
                    dynamic=False,
                )

            image_shape = config.shape_dict()["image"]
            img = torch.randn(image_shape).to(device=config.device, dtype=config.dtype)
            sam2_encoder = SAM2ImageEncoder(sam2_model)

            if is_cuda and config.torch_compile_mode != "none":
                print(f"Running warm up. It will take a while since torch compile mode is {config.torch_compile_mode}.")

            for _ in range(config.warm_up):
                _image_features_0, _image_features_1, _image_embeddings = sam2_encoder(img)

            if is_cuda and config.enable_nvtx_profile:
                import nvtx
                from cuda import cudart

                cudart.cudaProfilerStart()
                print("Start nvtx profiling on encoder ...")
                with nvtx.annotate("one_run"):
                    sam2_encoder(img, enable_nvtx_profile=True)
                cudart.cudaProfilerStop()

            if is_cuda and config.enable_torch_profile:
                with torch.profiler.profile(
                    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
                    record_shapes=True,
                ) as prof:
                    print("Start torch profiling on encoder ...")
                    with torch.profiler.record_function("encoder"):
                        sam2_encoder(img)
                print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
                prof.export_chrome_trace("torch_image_encoder.json")

            if config.repeats == 0:
                return

            print(f"Start {config.repeats} runs of performance tests...")
            start = time.perf_counter()
            for _ in range(config.repeats):
                _image_features_0, _image_features_1, _image_embeddings = sam2_encoder(img)
                if is_cuda:
                    torch.cuda.synchronize()
        else:
            torch_inputs = (
                ort_inputs["image_features_0"],
                ort_inputs["image_features_1"],
                ort_inputs["image_embeddings"],
                ort_inputs["point_coords"],
                ort_inputs["point_labels"],
                ort_inputs["input_masks"],
                ort_inputs["has_input_masks"],
                ort_inputs["original_image_size"],
            )

            sam2_decoder = SAM2ImageDecoder(
                sam2_model,
                multimask_output=config.multi_mask_output,
            )

            if is_cuda and config.torch_compile_mode != "none":
                sam2_decoder.forward = torch.compile(
                    sam2_decoder.forward,
                    mode=config.torch_compile_mode,
                    fullgraph=True,
                    dynamic=False,
                )

            # warm up
            for _ in range(config.warm_up):
                _masks, _iou_predictions, _low_res_masks = sam2_decoder(*torch_inputs)

            if is_cuda and config.enable_nvtx_profile:
                import nvtx
                from cuda import cudart

                cudart.cudaProfilerStart()
                print("Start nvtx profiling on decoder...")
                with nvtx.annotate("one_run"):
                    sam2_decoder(*torch_inputs, enable_nvtx_profile=True)
                cudart.cudaProfilerStop()

            if is_cuda and config.enable_torch_profile:
                with torch.profiler.profile(
                    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
                    record_shapes=True,
                ) as prof:
                    print("Start torch profiling on decoder ...")
                    with torch.profiler.record_function("decoder"):
                        sam2_decoder(*torch_inputs)
                print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
                prof.export_chrome_trace("torch_image_decoder.json")

            if config.repeats == 0:
                return

            print(f"Start {config.repeats} runs of performance tests...")
            start = time.perf_counter()
            for _ in range(config.repeats):
                _masks, _iou_predictions, _low_res_masks = sam2_decoder(*torch_inputs)
                if is_cuda:
                    torch.cuda.synchronize()

        end = time.perf_counter()
        return (end - start) / config.repeats


def run_test(
    args: argparse.Namespace,
    csv_writer: Optional[csv.DictWriter] = None,
):
    use_gpu: bool = args.use_gpu
    enable_cuda_graph: bool = args.use_cuda_graph
    repeats: int = args.repeats

    if use_gpu:
        device_id = torch.cuda.current_device()
        device = torch.device("cuda", device_id)
        provider = "CUDAExecutionProvider"
    else:
        device_id = 0
        device = torch.device("cpu")
        enable_cuda_graph = False
        provider = "CPUExecutionProvider"

    dtypes = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
    config = TestConfig(
        model_type=args.model_type,
        onnx_path=args.onnx_path,
        sam2_dir=args.sam2_dir,
        component=args.component,
        provider=provider,
        batch_size=args.batch_size,
        height=args.height,
        width=args.width,
        device=device,
        use_tf32=True,
        enable_cuda_graph=enable_cuda_graph,
        dtype=dtypes[args.dtype],
        prefer_nhwc=args.prefer_nhwc,
        repeats=args.repeats,
        warm_up=args.warm_up,
        enable_nvtx_profile=args.enable_nvtx_profile,
        enable_torch_profile=args.enable_torch_profile,
        torch_compile_mode=args.torch_compile_mode,
        verbose=False,
    )

    if args.engine == "ort":
        sess_options = SessionOptions()
        sess_options.intra_op_num_threads = args.intra_op_num_threads
        if config.enable_nvtx_profile:
            sess_options.enable_profiling = True
            sess_options.log_severity_level = 4
            sess_options.log_verbosity_level = 0

        session = create_session(config, sess_options)
        input_dict = config.random_inputs()

        # warm up session
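        # (lets ORT finish lazy initialization, e.g. kernel selection and, when
        # enabled, CUDA graph capture, before any timing below)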
        try:
            for _ in range(config.warm_up):
                _ = measure_latency(session, input_dict)
        except Exception as e:
            print(f"Failed to run {config=}. Exception: {e}")
            return

        if config.enable_nvtx_profile:
            import nvtx
            from cuda import cudart

            cudart.cudaProfilerStart()
            with nvtx.annotate("one_run"):
                _ = session.infer(input_dict)
            cudart.cudaProfilerStop()
            session.ort_session.end_profiling()

        if repeats == 0:
            return

        latency_list = []
        for _ in range(repeats):
            latency = measure_latency(session, input_dict)
            latency_list.append(latency)
        average_latency = statistics.mean(latency_list)

        del session
    else:  # torch
        with torch.no_grad():
            try:
                average_latency = run_torch(config)
            except Exception as e:
                print(f"Failed to run {config=}. Exception: {e}")
                return

        if repeats == 0:
            return

    engine = args.engine + ":" + ("cuda" if use_gpu else "cpu")
    row = {
        "model_type": args.model_type,
        "component": args.component,
        "dtype": args.dtype,
        "use_gpu": use_gpu,
        "enable_cuda_graph": enable_cuda_graph,
        "prefer_nhwc": config.prefer_nhwc,
        "use_tf32": config.use_tf32,
        "batch_size": args.batch_size,
        "height": args.height,
        "width": args.width,
        "multi_mask_output": args.multimask_output,
        "num_labels": config.num_labels,
        "num_points": config.num_points,
        "num_masks": config.num_masks,
        "intra_op_num_threads": args.intra_op_num_threads,
        "warm_up": config.warm_up,
        "repeats": repeats,
        "enable_nvtx_profile": args.enable_nvtx_profile,
        "torch_compile_mode": args.torch_compile_mode,
        "engine": engine,
        "average_latency": average_latency,
    }

    if csv_writer is not None:
        csv_writer.writerow(row)

    print(f"{vars(config)}")
    print(f"{row}")


def run_perf_test(args):
    features = "gpu" if args.use_gpu else "cpu"
    csv_filename = "benchmark_sam_{}_{}_{}.csv".format(
        features,
        args.engine,
        datetime.now().strftime("%Y%m%d-%H%M%S"),
    )
    with open(csv_filename, mode="a", newline="") as csv_file:
        column_names = [
            "model_type",
            "component",
            "dtype",
            "use_gpu",
            "enable_cuda_graph",
            "prefer_nhwc",
            "use_tf32",
            "batch_size",
            "height",
            "width",
            "multi_mask_output",
            "num_labels",
            "num_points",
            "num_masks",
            "intra_op_num_threads",
            "warm_up",
            "repeats",
            "enable_nvtx_profile",
            "torch_compile_mode",
            "engine",
            "average_latency",
        ]
        csv_writer = csv.DictWriter(csv_file, fieldnames=column_names)
        csv_writer.writeheader()

        run_test(args, csv_writer)


def _parse_arguments():
    parser = argparse.ArgumentParser(description="Benchmark SMA2 for ONNX Runtime and PyTorch.")

    parser.add_argument(
        "--component",
        required=False,
        choices=["image_encoder", "image_decoder"],
        default="image_encoder",
        help="component to benchmark. Choices are image_encoder and image_decoder.",
    )

    parser.add_argument(
        "--dtype", required=False, choices=["fp32", "fp16", "bf16"], default="fp32", help="Data type for inference."
    )

    parser.add_argument(
        "--use_gpu",
        required=False,
        action="store_true",
        help="Use GPU for inference.",
    )
    parser.set_defaults(use_gpu=False)

    parser.add_argument(
        "--use_cuda_graph",
        required=False,
        action="store_true",
        help="Use cuda graph in onnxruntime.",
    )
    parser.set_defaults(use_cuda_graph=False)

    parser.add_argument(
        "--intra_op_num_threads",
        required=False,
        type=int,
        choices=[0, 1, 2, 4, 8, 16],
        default=0,
        help="intra_op_num_threads for onnxruntime. ",
    )

    parser.add_argument(
        "--batch_size",
        required=False,
        type=int,
        default=1,
        help="batch size",
    )

    parser.add_argument(
        "--height",
        required=False,
        type=int,
        default=1024,
        help="image height",
    )

    parser.add_argument(
        "--width",
        required=False,
        type=int,
        default=1024,
        help="image width",
    )

    parser.add_argument(
        "--repeats",
        required=False,
        type=int,
        default=1000,
        help="number of repeats for performance test. Default is 1000.",
    )

    parser.add_argument(
        "--warm_up",
        required=False,
        type=int,
        default=5,
        help="number of runs for warm up. Default is 5.",
    )

    parser.add_argument(
        "--engine",
        required=False,
        type=str,
        default="ort",
        choices=["ort", "torch"],
        help="engine for inference",
    )

    parser.add_argument(
        "--multimask_output",
        required=False,
        default=False,
        action="store_true",
        help="Export mask_decoder or image_decoder with multimask_output",
    )

    parser.add_argument(
        "--prefer_nhwc",
        required=False,
        default=False,
        action="store_true",
        help="Use prefer_nhwc=1 provider option for CUDAExecutionProvider",
    )

    parser.add_argument(
        "--enable_nvtx_profile",
        required=False,
        default=False,
        action="store_true",
        help="Enable nvtx profiling. It will add an extra run for profiling before performance test.",
    )

    parser.add_argument(
        "--enable_torch_profile",
        required=False,
        default=False,
        action="store_true",
        help="Enable PyTorch profiling. It will add an extra run for profiling before performance test.",
    )

    parser.add_argument(
        "--model_type",
        required=False,
        type=str,
        default="sam2_hiera_large",
        choices=["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"],
        help="sam2 model name",
    )

    parser.add_argument(
        "--sam2_dir",
        required=False,
        type=str,
        default="./segment-anything-2",
        help="The directory of segment-anything-2 git root directory",
    )

    parser.add_argument(
        "--onnx_path",
        required=False,
        type=str,
        default="./sam2_onnx_models/sam2_hiera_large_image_encoder.onnx",
        help="path of onnx model",
    )

    parser.add_argument(
        "--torch_compile_mode",
        required=False,
        type=str,
        default=None,
        choices=["reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs", "none"],
        help="torch compile mode. none will disable torch compile.",
    )

    args = parser.parse_args()

    return args


if __name__ == "__main__":
    args = _parse_arguments()
    print(f"arguments:{args}")

    if args.torch_compile_mode is None:
        # image decoder will fail with compile modes other than "none".
        args.torch_compile_mode = "max-autotune" if args.component == "image_encoder" else "none"

    if args.use_gpu:
        assert torch.cuda.is_available()
        if args.engine == "ort":
            assert "CUDAExecutionProvider" in get_available_providers()
            args.enable_torch_profile = False
    else:
        # Profiling is supported only on CUDA for now.
        assert not args.enable_nvtx_profile
        assert not args.enable_torch_profile

    if args.enable_nvtx_profile or args.enable_torch_profile:
        run_test(args)
    else:
        run_perf_test(args)
