# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.  See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from __future__ import annotations

import numpy as np
import torch
from transformers import AutoConfig, AutoTokenizer

from onnxruntime import InferenceSession, OrtValue


# Get position_ids from attention_mask
def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool):
    """Derive position ids from an attention mask.

    Each position is the running count of attended tokens minus one. Entries
    where the mask is 0 are overwritten with 1.

    Args:
        attention_mask: Mask tensor; the cumulative sum runs over its last dim.
        use_past_kv: When True, keep only the last column of the result.

    Returns:
        The full position id tensor (same shape as the mask), or only its last
        column re-expanded with a trailing dim of size 1 when `use_past_kv`.
    """
    is_padding = attention_mask == 0
    running_count = torch.cumsum(attention_mask.long(), dim=-1)
    # Out-of-place fill yields the same values as the in-place variant
    position_ids = (running_count - 1).masked_fill(is_padding, 1)

    if not use_past_kv:
        # Shape: (batch_size, sequence_length)
        return position_ids

    # Shape: (batch_size, 1)
    return torch.unsqueeze(position_ids[:, -1], dim=-1)


# Inputs for first pass to get initial past_key_values
#   input_ids: (batch_size, sequence_length)
#   attention_mask: (batch_size, sequence_length)
#   position_ids: (batch_size, sequence_length)
def get_sample_inputs(
    config: AutoConfig,
    device: torch.device,
    batch_size: int,
    seq_len: int,
    engine: str = "pt",
    return_dict: bool = False,
):
    """Create random sample inputs for a pass without past key/values.

    Args:
        config: Model config; only `vocab_size` is read.
        device: Device the tensors are moved to when `engine` is not "ort".
        batch_size: Number of rows in every input.
        seq_len: Number of columns in every input.
        engine: "ort" converts every input to a NumPy array; any other value
            keeps torch tensors and sends them to `device`.
        return_dict: When True return a name -> value dict, otherwise a tuple.

    Returns:
        `(input_ids, attention_mask, position_ids)` as a tuple, or the same
        values in a dict keyed by those names.
    """
    shape = (batch_size, seq_len)
    raw = {}
    raw["input_ids"] = torch.randint(low=0, high=config.vocab_size, size=shape, dtype=torch.int64)
    raw["attention_mask"] = torch.ones(shape, dtype=torch.int64)
    raw["position_ids"] = get_position_ids(raw["attention_mask"], use_past_kv=False)

    # NumPy arrays for ORT, device tensors for PyTorch
    for_ort = engine == "ort"
    inputs = {name: (tensor.numpy() if for_ort else tensor.to(device)) for name, tensor in raw.items()}

    if return_dict:
        return inputs

    # For export
    return (inputs["input_ids"], inputs["attention_mask"], inputs["position_ids"])


# Inputs for subsequent passes with past_key_values
#   input_ids: (batch_size, 1)
#   attention_mask: (batch_size, past_sequence_length + 1)
#   position_ids: (batch_size, 1)
#   past_key: (batch_size, num_heads, past_sequence_length, head_size)
#   past_value: (batch_size, num_heads, past_sequence_length, head_size)
def get_sample_with_past_kv_inputs(
    config: AutoConfig,
    device: torch.device,
    batch_size: int,
    past_seq_len: int,
    use_fp16: bool = False,
    engine: str = "pt",
    return_dict: bool = False,
    world_size: int = 1,
):
    """Create random sample inputs for a single-token pass with past key/values.

    Args:
        config: Model config; `vocab_size` is read here, the rest is consumed
            by `get_past_kv_inputs`.
        device: Device the tensors are moved to when `engine` is not "ort".
        batch_size: Number of rows in every input.
        past_seq_len: Length of the past; the attention mask has
            `past_seq_len + 1` columns.
        use_fp16: Forwarded to `get_past_kv_inputs`.
        engine: "ort" converts inputs to NumPy and flattens the KV cache into
            a dict; any other value keeps torch tensors on `device`.
        return_dict: When True return a dict, otherwise a tuple.
        world_size: Forwarded to `get_past_kv_inputs`.

    Returns:
        `(input_ids, attention_mask, position_ids, past_kv)` when
        `return_dict` is False, else a dict of named inputs. For "ort" the
        flattened cache entries are merged into the dict; otherwise the cache
        is stored under "past_key_values".

    Raises:
        AssertionError: If `return_dict` is False while `engine` is "ort",
            because the flattened cache is a dict rather than a list.
    """
    for_ort = engine == "ort"

    def place(tensor):
        # NumPy for ORT, device tensor for PyTorch
        return tensor.numpy() if for_ort else tensor.to(device)

    token_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, 1), dtype=torch.int64)
    mask = torch.ones(batch_size, past_seq_len + 1, dtype=torch.int64)
    # Only the last position is kept: shape (batch_size, 1)
    positions = get_position_ids(mask, use_past_kv=True)
    kv_cache = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)

    input_ids = place(token_ids)
    attention_mask = place(mask)
    position_ids = place(positions)
    if for_ort:
        past_kv = flatten_past_kv_inputs(kv_cache)
    else:
        past_kv = [(key.to(device), value.to(device)) for key, value in kv_cache]

    if not return_dict:
        # For export
        assert isinstance(past_kv, list)
        return (input_ids, attention_mask, position_ids, past_kv)

    inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }
    if for_ort:
        assert isinstance(past_kv, dict)
        inputs.update(past_kv)
    else:
        assert isinstance(past_kv, list)
        inputs["past_key_values"] = past_kv

    return inputs


# Inputs for all passes with past_key_values
#   input_ids: (batch_size, sequence_length)
#   attention_mask: (batch_size, past_sequence_length + sequence_length)
#   position_ids: (batch_size, sequence_length)
#   past_key: (batch_size, num_heads, kv_sequence_length, head_size)
#      For models with GQA, kv_sequence_length = max_sequence_length
#      For models without GQA, kv_sequence_length = past_sequence_length
#   past_value: (batch_size, num_heads, kv_sequence_length, head_size)
#      For models with GQA, kv_sequence_length = max_sequence_length
#      For models without GQA, kv_sequence_length = past_sequence_length
def get_merged_sample_with_past_kv_inputs(
    config: AutoConfig,
    device: torch.device,
    batch_size: int,
    seq_len: int,
    past_seq_len: int,
    max_seq_len: int,
    use_fp16: bool = False,
    use_buffer_share: bool = False,
    engine: str = "pt",
    return_dict: bool = False,
    world_size: int = 1,
):
    """Create random sample inputs for a merged (prompt or token) pass.

    Args:
        config: Model config; `vocab_size` is read here, the rest is consumed
            by `get_past_kv_inputs`.
        device: Device the tensors are moved to when `engine` is not "ort".
        batch_size: Number of rows in every input.
        seq_len: Number of new tokens in `input_ids`.
        past_seq_len: Number of already-processed tokens; the attention mask
            has `past_seq_len + seq_len` columns.
        max_seq_len: Cache length used when `use_buffer_share` re-allocates the
            KV caches.
        use_fp16: Forwarded to `get_past_kv_inputs`.
        use_buffer_share: Only honoured when `engine` is "ort" and
            `return_dict` is True; the flattened caches are then passed through
            `enable_past_present_share_buffer`.
        engine: "ort" converts inputs to NumPy and flattens the KV cache into
            a dict; any other value keeps torch tensors on `device`.
        return_dict: When True return a dict, otherwise a tuple.
        world_size: Forwarded to `get_past_kv_inputs`.

    Returns:
        `(input_ids, attention_mask, position_ids, past_kv)` when
        `return_dict` is False, else a dict of named inputs.

    Raises:
        AssertionError: If `return_dict` is False while `engine` is "ort",
            because the flattened cache is a dict rather than a list.
    """
    for_ort = engine == "ort"

    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seq_len), dtype=torch.int64)
    attention_mask = torch.ones(batch_size, past_seq_len + seq_len, dtype=torch.int64)
    # Keep the positions of the new tokens only, so position_ids always has shape
    # (batch_size, seq_len). Taking just the last column whenever a past exists
    # would yield (batch_size, 1) even for seq_len > 1 and mismatch input_ids.
    # For past_seq_len == 0 or seq_len == 1 the result is the same either way.
    position_ids = get_position_ids(attention_mask, use_past_kv=False)[:, past_seq_len:]
    past_kv = get_past_kv_inputs(config, batch_size, past_seq_len, use_fp16, world_size=world_size)

    # Convert inputs to NumPy (for ORT) or send to device (for PyTorch)
    input_ids = input_ids.numpy() if for_ort else input_ids.to(device)
    attention_mask = attention_mask.numpy() if for_ort else attention_mask.to(device)
    position_ids = position_ids.numpy() if for_ort else position_ids.to(device)
    if for_ort:
        past_kv = flatten_past_kv_inputs(past_kv)
    else:
        past_kv = [(key.to(device), value.to(device)) for key, value in past_kv]

    if not return_dict:
        # For export
        assert isinstance(past_kv, list)
        return (input_ids, attention_mask, position_ids, past_kv)

    inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }
    if for_ort:
        assert isinstance(past_kv, dict)
        inputs.update(past_kv)

        if use_buffer_share:
            inputs = enable_past_present_share_buffer(inputs, past_seq_len, max_seq_len)

    else:
        assert isinstance(past_kv, list)
        inputs["past_key_values"] = past_kv

    return inputs


# Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx
def get_msft_sample_inputs(
    config: AutoConfig,
    batch_size: int,
    past_seq_len: int,
    seq_len: int,
    max_seq_len: int,
    use_fp16: bool,
    use_buffer_share: bool,
    split_kv: bool,
):
    """Create random NumPy inputs for the Microsoft Llama-2-Onnx export.

    Args:
        config: Model config; `hidden_size`, `num_attention_heads` and
            `num_hidden_layers` are read.
        batch_size: Leading dim of every array input.
        past_seq_len: Past length used for the caches and stored in "pos".
        seq_len: Second dim of "x".
        max_seq_len: Side length of the square attention mask, and the cache
            length when buffers are shared.
        use_fp16: Selects float16 instead of float32 for floating inputs.
        use_buffer_share: Only honoured when `split_kv` is True; the per-layer
            caches are then passed through `enable_past_present_share_buffer`.
        split_kv: When True emit one `k_{i}_cache`/`v_{i}_cache` pair per
            layer and an int32 mask; otherwise emit stacked `k_cache`/`v_cache`
            arrays and a floating-point mask.

    Returns:
        Dict mapping input names to NumPy arrays.
    """
    np_dtype = np.float16 if use_fp16 else np.float32
    hidden_size = config.hidden_size
    num_heads = config.num_attention_heads
    num_layers = config.num_hidden_layers
    head_size = hidden_size // num_heads
    mask_shape = (batch_size, max_seq_len, max_seq_len)

    def random_array(*shape):
        # Uniform samples cast to the requested floating type
        return np.random.rand(*shape).astype(np_dtype)

    if split_kv:
        # 0 strictly above the diagonal, -1 on and below it
        int_mask = np.triu(np.ones(mask_shape, dtype=np.int32), k=1) - 1
        ort_inputs = {
            "x": random_array(batch_size, seq_len, hidden_size),
            "attn_mask": int_mask.astype(np.int32),
            "pos": np.array(past_seq_len, dtype=np.int64),
        }
        layer_cache_shape = (batch_size, num_heads, past_seq_len, head_size)
        for layer in range(num_layers):
            ort_inputs[f"k_{layer}_cache"] = random_array(*layer_cache_shape)
            ort_inputs[f"v_{layer}_cache"] = random_array(*layer_cache_shape)

        if use_buffer_share:
            ort_inputs = enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)
        return ort_inputs

    stacked_cache_shape = (batch_size, num_layers, past_seq_len, num_heads, head_size)
    return {
        "x": random_array(batch_size, seq_len, hidden_size),
        # -10000 strictly above the diagonal, zero elsewhere
        "attn_mask": (-10000.0 * np.triu(np.ones(mask_shape), k=1)).astype(np_dtype),
        "k_cache": random_array(*stacked_cache_shape),
        "v_cache": random_array(*stacked_cache_shape),
        "pos": np.array(past_seq_len, dtype=np.int64),
    }


# Create past_key_values
# Each is of shape (batch_size, num_heads, past_sequence_length, head_size)
def get_past_kv_inputs(config: AutoConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1):
    """Create random past key/value tensors for every hidden layer.

    Args:
        config: Model config; `num_key_value_heads`, `num_hidden_layers` and
            either `head_dim` or `hidden_size`/`num_attention_heads` are read.
        batch_size: Leading dim of every tensor.
        past_seq_len: Third dim of every tensor.
        use_fp16: Selects torch.float16 instead of torch.float32.
        world_size: The number of KV heads is integer-divided by this value.

    Returns:
        A list with one `(key, value)` tuple per layer, each tensor of shape
        `(batch_size, num_heads, past_seq_len, head_size)`.
    """
    num_heads = config.num_key_value_heads // world_size
    # A config may expose `head_dim` but leave it unset (None); in that case
    # fall back to the derived size instead of passing None to torch.rand.
    head_size = getattr(config, "head_dim", None)
    if head_size is None:
        head_size = config.hidden_size // config.num_attention_heads
    torch_dtype = torch.float16 if use_fp16 else torch.float32
    kv_shape = (batch_size, num_heads, past_seq_len, head_size)
    return [
        (
            torch.rand(kv_shape, dtype=torch_dtype),
            torch.rand(kv_shape, dtype=torch_dtype),
        )
        for _ in range(config.num_hidden_layers)
    ]


# Convert list of past_key_values to dict of past_key and past_value
def flatten_past_kv_inputs(past_key_values: list[tuple[torch.Tensor, torch.Tensor]]):
    """Flatten per-layer `(key, value)` tensors into a name -> NumPy array dict.

    Args:
        past_key_values: One `(key, value)` pair of tensors per layer.

    Returns:
        Dict with `past_key_values.{i}.key` and `past_key_values.{i}.value`
        entries for layer index `i`, each detached, moved to CPU and converted
        to NumPy.
    """

    def to_numpy(tensor):
        return tensor.detach().cpu().numpy()

    flattened = {}
    for layer_idx, (key_tensor, value_tensor) in enumerate(past_key_values):
        prefix = f"past_key_values.{layer_idx}"
        flattened.update(
            {
                f"{prefix}.key": to_numpy(key_tensor),
                f"{prefix}.value": to_numpy(value_tensor),
            }
        )
    return flattened


# Format PyTorch inputs to ONNX Runtime inputs
def convert_inputs_for_ort(
    pt_inputs: dict,
    use_buffer_share: bool = False,
    past_seq_len: int = 0,
    max_seq_len: int = 2048,
):
    """Convert a dict of PyTorch inputs into ONNX Runtime (NumPy) inputs.

    Args:
        pt_inputs: Name -> value dict. NumPy arrays pass through unchanged, the
            "past_key_values" entry is flattened with `flatten_past_kv_inputs`,
            and every other value is detached, moved to CPU and converted.
        use_buffer_share: When True the result is passed through
            `enable_past_present_share_buffer`.
        past_seq_len: Forwarded to `enable_past_present_share_buffer`.
        max_seq_len: Forwarded to `enable_past_present_share_buffer`.

    Returns:
        Dict of ONNX Runtime inputs.
    """
    ort_inputs = {}
    for name, value in pt_inputs.items():
        if isinstance(value, np.ndarray):
            # Already in the format ORT expects
            ort_inputs[name] = value
            continue
        if name == "past_key_values":
            ort_inputs.update(flatten_past_kv_inputs(value))
            continue
        ort_inputs[name] = value.detach().cpu().numpy()

    if not use_buffer_share:
        return ort_inputs

    # Re-allocate KV caches for past-present-share-buffer
    return enable_past_present_share_buffer(ort_inputs, past_seq_len, max_seq_len)


# Re-allocate KV caches from (batch_size, num_heads, past_sequence_length, head_size) to
# (batch_size, num_heads, max_sequence_length, head_size) for past-present buffer sharing
def enable_past_present_share_buffer(ort_inputs: dict, past_seq_len: int, max_seq_len: int):
    """Re-allocate KV cache inputs with `max_seq_len` along the sequence dim.

    Every entry whose name contains "cache" or "past_key_values" is replaced
    by a zero-filled array of shape `(batch, heads, max_seq_len, head_size)`
    whose first `past_seq_len` sequence slots hold the original values.

    Args:
        ort_inputs: Name -> NumPy array dict; modified in place.
        past_seq_len: Number of leading sequence slots that receive the
            original data.
        max_seq_len: Sequence length of the re-allocated buffers.

    Returns:
        The same `ort_inputs` dict object.
    """
    kv_names = [name for name in ort_inputs if "cache" in name or "past_key_values" in name]
    for name in kv_names:
        cache = ort_inputs[name]
        batch_size, num_heads, _, head_size = cache.shape
        padded = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=cache.dtype)
        # Copy the existing past into the front of the larger buffer
        padded[:, :, :past_seq_len, :] = cache
        ort_inputs[name] = padded
    return ort_inputs


# Verify ONNX Runtime inputs with model
def verify_ort_inputs(model: InferenceSession, ort_inputs: dict):
    """Check `ort_inputs` against the model's inputs and drop unneeded entries.

    Args:
        model: Session whose `get_inputs()` names define the required inputs.
        ort_inputs: Name -> value dict; entries the model does not take are
            deleted from it in place.

    Returns:
        The same `ort_inputs` dict object, restricted to the model's inputs.

    Raises:
        Exception: If any model input is absent from `ort_inputs`; the missing
            names are printed first.
    """
    expected_names = {model_input.name for model_input in model.get_inputs()}
    provided_names = set(ort_inputs.keys())

    # Every model input must be supplied
    missing_inputs = expected_names - provided_names
    if missing_inputs:
        print(f"The following model inputs are missing: {missing_inputs}")
        raise Exception("There are missing inputs to the model. Please add them and try again.")

    # Drop anything the model does not take
    for extra_name in provided_names - expected_names:
        del ort_inputs[extra_name]

    return ort_inputs


# Add IO bindings for execution providers using OrtValue
# Use when you need to run inference once or twice to save memory
def add_io_bindings_as_ortvalues(
    model: InferenceSession,
    ort_inputs: dict,
    device: str,
    device_id: int,
    use_buffer_share: bool,
    kv_cache_ortvalues: dict,
):
    """Create an IO binding whose inputs are bound as OrtValues.

    Args:
        model: Session providing `io_binding()`, `get_inputs()` and
            `get_outputs()`.
        ort_inputs: Name -> NumPy array dict; names the model does not take
            are skipped.
        device: Device type passed to OrtValue creation and output binding.
        device_id: Device id passed alongside `device`.
        use_buffer_share: When True, KV cache inputs are kept in
            `kv_cache_ortvalues` and the matching outputs are bound to them.
        kv_cache_ortvalues: Name -> OrtValue dict of KV caches; new entries
            are added and existing ones are updated in place.

    Returns:
        Tuple `(io_binding, kv_cache_ortvalues)`.
    """
    io_binding = model.io_binding()
    model_input_names = {model_input.name for model_input in model.get_inputs()}

    for name, array in ort_inputs.items():
        # Handles scenarios such as INT4 CUDA and FP16 CUDA models with
        # GQA + RotaryEmbedding fusion where `position_ids` is removed as an ONNX
        # model input but is still a PyTorch model input
        if name not in model_input_names:
            continue

        shares_kv_buffer = use_buffer_share and ("cache" in name or "past_key_values" in name)
        if shares_kv_buffer and name in kv_cache_ortvalues:
            # Reuse the existing device buffer and refresh its contents
            existing = kv_cache_ortvalues[name]
            existing.update_inplace(array)
            io_binding.bind_ortvalue_input(name, existing)
            continue

        ort_value = OrtValue.ortvalue_from_numpy(array, device_type=device, device_id=device_id)
        io_binding.bind_ortvalue_input(name, ort_value)
        if shares_kv_buffer:
            # Remember the buffer so the matching output can be bound to it
            kv_cache_ortvalues[name] = ort_value

    for model_output in model.get_outputs():
        output_name = model_output.name
        bind_to_past = use_buffer_share and ("out" in output_name or "present" in output_name)
        if not bind_to_past:
            io_binding.bind_output(output_name, device_type=device, device_id=device_id)
            continue

        # Bind present KV cache outputs to past KV cache inputs in order to buffer share
        input_name = output_name.replace("out", "cache").replace("present", "past_key_values")
        io_binding.bind_ortvalue_output(output_name, kv_cache_ortvalues[input_name])

    return io_binding, kv_cache_ortvalues


# Add IO bindings for execution providers using PyTorch tensors
# Use when you need to run inference many times
def add_io_bindings_as_tensors(
    model: InferenceSession, inputs: dict, outputs: dict, use_fp16: bool, use_buffer_share: bool
):
    """Create an IO binding that points directly at PyTorch tensor memory.

    Args:
        model: Session providing `io_binding()` and `get_outputs()`.
        inputs: Name -> tensor dict; first filtered by `verify_ort_inputs`.
        outputs: Name -> tensor dict supplying the output buffers.
        use_fp16: Output element type is float16 when True, else float32.
        use_buffer_share: When True, outputs whose name contains "present" are
            bound to the input found under the same name with "present"
            replaced by "past_key_values".

    Returns:
        The populated IO binding.
    """
    # Verify model inputs
    inputs = verify_ort_inputs(model, inputs)

    numpy_type_by_torch_repr = {
        "torch.int32": np.int32,
        "torch.int64": np.int64,
        "torch.float16": np.float16,
        "torch.float32": np.float32,
    }
    output_element_type = np.float16 if use_fp16 else np.float32

    def device_index(dev):
        # CPU tensors are always bound with id 0
        return 0 if dev.type == "cpu" else dev.index

    io_binding = model.io_binding()

    # Outputs are bound on the device of the last input
    device = None
    for input_name, tensor in inputs.items():
        io_binding.bind_input(
            name=input_name,
            device_type=tensor.device.type,
            device_id=device_index(tensor.device),
            element_type=numpy_type_by_torch_repr[repr(tensor.dtype)],
            shape=tuple(tensor.shape),
            buffer_ptr=tensor.data_ptr(),
        )
        device = tensor.device

    for model_output in model.get_outputs():
        output_name = model_output.name
        if use_buffer_share and "present" in output_name:
            # Bind KV cache outputs to KV cache inputs
            buffer = inputs[output_name.replace("present", "past_key_values")]
        else:
            buffer = outputs[output_name]
        io_binding.bind_output(
            name=output_name,
            device_type=device.type,
            device_id=device_index(device),
            element_type=output_element_type,
            shape=tuple(buffer.shape),
            buffer_ptr=buffer.data_ptr(),
        )

    return io_binding


# Get actual inputs when using real data (instead of sample data) and initialize outputs
def get_initial_inputs_and_outputs(
    config: AutoConfig,
    tokenizer: AutoTokenizer,
    requested_length: int,
    prompt: list[str],
    device: torch.device,
    use_fp16: bool,
    use_buffer_share: bool,
    engine: str,
):
    """Tokenize real prompts and build the initial model inputs and outputs.

    The tokenized batch is truncated or left-padded to exactly
    `requested_length` columns. Left padding repeats the first column of
    `input_ids` and of `attention_mask`.

    Args:
        config: Model config; `max_position_embeddings`, `num_key_value_heads`,
            `num_hidden_layers`, `vocab_size` and either `head_dim` or
            `hidden_size`/`num_attention_heads` are read.
        tokenizer: Tokenizer used to encode `prompt`. Its `pad_token` is
            overwritten with its `eos_token`.
        requested_length: Target number of columns for the inputs.
        prompt: Batch of prompt strings.
        device: Device on which all tensors are created.
        use_fp16: Selects torch.float16 instead of torch.float32 for the KV
            caches and the outputs.
        use_buffer_share: When True the KV cache inputs are allocated with
            `max_position_embeddings` sequence slots (otherwise 0) and no
            `present.*` outputs are allocated.
        engine: "ort" produces flattened `past_key_values.{i}.key/value`
            inputs and pre-allocated outputs; any other value stores the
            caches as a list under "past_key_values" and returns no outputs.

    Returns:
        Tuple `(inputs, outputs)`; `outputs` is None unless `engine` is "ort".

    Raises:
        AssertionError: If the adjusted inputs do not end up with
            `requested_length` columns.
    """
    # NOTE: mutates the caller's tokenizer
    tokenizer.pad_token = tokenizer.eos_token
    encodings_dict = tokenizer.batch_encode_plus(prompt, padding=True)
    torch_dtype = torch.float16 if use_fp16 else torch.float32
    for_ort = engine == "ort"

    input_ids = torch.tensor(encodings_dict["input_ids"], device=device, dtype=torch.int64)
    attention_mask = torch.tensor(encodings_dict["attention_mask"], device=device, dtype=torch.int64)

    # Check if tokenized prompt length matches the requested prompt length
    tokenized_length = input_ids.shape[-1]
    if tokenized_length > requested_length:
        # Shorten the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
        input_ids = input_ids[:, :requested_length]
        attention_mask = attention_mask[:, :requested_length]
    elif tokenized_length < requested_length:
        # Lengthen the inputs from (batch_size, tokenized_length) to (batch_size, requested_length)
        # by prepending copies of the first column. Build the padding in one shot rather than
        # re-stacking the whole tensor once per missing column.
        pad_length = requested_length - tokenized_length
        input_ids = torch.cat((input_ids[:, :1].repeat(1, pad_length), input_ids), dim=1)
        attention_mask = torch.cat((attention_mask[:, :1].repeat(1, pad_length), attention_mask), dim=1)

    # Position ids only depend on the final attention mask, so compute them once
    position_ids = get_position_ids(attention_mask, use_past_kv=False)

    tokenized_length = input_ids.shape[-1]
    assert tokenized_length == requested_length

    # Create inputs
    inputs = {
        "input_ids": input_ids.contiguous() if for_ort else input_ids,
        "attention_mask": attention_mask.contiguous() if for_ort else attention_mask,
        "position_ids": position_ids.contiguous() if for_ort else position_ids,
    }
    if not for_ort:
        inputs["past_key_values"] = []

    # Get shape of KV cache inputs
    batch_size, sequence_length = input_ids.shape
    max_sequence_length = config.max_position_embeddings
    num_heads = config.num_key_value_heads
    # A config may expose `head_dim` but leave it unset (None); fall back to the derived size
    head_size = getattr(config, "head_dim", None)
    if head_size is None:
        head_size = config.hidden_size // config.num_attention_heads
    past_shape = (batch_size, num_heads, max_sequence_length if use_buffer_share else 0, head_size)

    # Create KV cache inputs
    for i in range(config.num_hidden_layers):
        past_key = torch.zeros(past_shape, device=device, dtype=torch_dtype)
        past_value = torch.zeros(past_shape, device=device, dtype=torch_dtype)
        if for_ort:
            inputs[f"past_key_values.{i}.key"] = past_key
            inputs[f"past_key_values.{i}.value"] = past_value
        else:
            inputs["past_key_values"].append((past_key, past_value))

    outputs = None
    if for_ort:
        # Create outputs
        logits = torch.zeros(batch_size, sequence_length, config.vocab_size, device=device, dtype=torch_dtype)
        outputs = {"logits": logits}
        if not use_buffer_share:
            # With buffer sharing the present KV outputs reuse the past KV input buffers instead
            present_shape = (batch_size, num_heads, sequence_length, head_size)
            for i in range(config.num_hidden_layers):
                outputs[f"present.{i}.key"] = torch.zeros(present_shape, device=device, dtype=torch_dtype)
                outputs[f"present.{i}.value"] = torch.zeros(present_shape, device=device, dtype=torch_dtype)

    return inputs, outputs
