# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import argparse
import copy
import json
import sys
from collections import OrderedDict
from pprint import pprint
from typing import Any, Dict, List

import onnx

TuningResults = Dict[str, Any]

_TUNING_RESULTS_KEY = "tuning_results"


def _find_tuning_results_in_props(metadata_props):
    for idx, prop in enumerate(metadata_props):
        if prop.key == _TUNING_RESULTS_KEY:
            return idx
    return -1


def extract(model: onnx.ModelProto):
    idx = _find_tuning_results_in_props(model.metadata_props)
    if idx < 0:
        return None

    tuning_results_prop = model.metadata_props[idx]
    return json.loads(tuning_results_prop.value)


def embed(model: onnx.ModelProto, tuning_results: List[TuningResults], overwrite=False):
    idx = _find_tuning_results_in_props(model.metadata_props)
    assert overwrite or idx <= 0, "the supplied onnx file already have tuning results embedded!"

    if idx >= 0:
        model.metadata_props.pop(idx)

    entry = model.metadata_props.add()
    entry.key = _TUNING_RESULTS_KEY
    entry.value = json.dumps(tuning_results)
    return model


class Merger:
    class EpAndValidators:
        def __init__(self, ep: str, validators: Dict[str, str]):
            self.ep = ep
            self.validators = copy.deepcopy(validators)
            self.key = (ep, tuple(sorted(validators.items())))

        def __hash__(self):
            return hash(self.key)

        def __eq__(self, other):
            return self.ep == other.ep and self.key == other.key

    def __init__(self):
        self.ev_to_results = OrderedDict()

    def merge(self, tuning_results: List[TuningResults]):
        for trs in tuning_results:
            self._merge_one(trs)

    def get_merged(self):
        tuning_results = []
        for ev, flat_results in self.ev_to_results.items():
            results = {}
            trs = {
                "ep": ev.ep,
                "validators": ev.validators,
                "results": results,
            }
            for (op_sig, params_sig), kernel_id in flat_results.items():
                kernel_map = results.setdefault(op_sig, {})
                kernel_map[params_sig] = kernel_id
            tuning_results.append(trs)
        return tuning_results

    def _merge_one(self, trs: TuningResults):
        ev = Merger.EpAndValidators(trs["ep"], trs["validators"])
        flat_results = self.ev_to_results.setdefault(ev, {})
        for op_sig, kernel_map in trs["results"].items():
            for params_sig, kernel_id in kernel_map.items():
                if (op_sig, params_sig) not in flat_results:
                    flat_results[(op_sig, params_sig)] = kernel_id


def parse_args():
    parser = argparse.ArgumentParser()
    sub_parsers = parser.add_subparsers(help="Command to execute", dest="cmd")

    extract_parser = sub_parsers.add_parser("extract", help="Extract embedded tuning results from an onnx file.")
    extract_parser.add_argument("input_onnx")
    extract_parser.add_argument("output_json")

    embed_parser = sub_parsers.add_parser("embed", help="Embed the tuning results into an onnx file.")
    embed_parser.add_argument("--force", "-f", action="store_true", help="Overwrite the tuning results if it existed.")
    embed_parser.add_argument("output_onnx", help="Path of the output onnx file.")
    embed_parser.add_argument("input_onnx", help="Path of the input onnx file.")
    embed_parser.add_argument("input_json", nargs="+", help="Path(s) of the tuning results file(s) to be embedded.")

    merge_parser = sub_parsers.add_parser("merge", help="Merge multiple tuning results files as a single one.")
    merge_parser.add_argument("output_json", help="Path of the output tuning results file.")
    merge_parser.add_argument("input_json", nargs="+", help="Paths of the tuning results files to be merged.")

    pprint_parser = sub_parsers.add_parser("pprint", help="Pretty print the tuning results.")
    pprint_parser.add_argument("json_or_onnx", help="A tuning results json file or an onnx file.")

    args = parser.parse_args()
    if len(vars(args)) == 0:
        parser.print_help()
        exit(-1)
    return args


def main():
    args = parse_args()
    if args.cmd == "extract":
        tuning_results = extract(onnx.load_model(args.input_onnx))
        if tuning_results is None:
            sys.stderr.write(f"{args.input_onnx} does not have tuning results embedded!\n")
            sys.exit(-1)
        json.dump(tuning_results, open(args.output_json, "w"))  # noqa: SIM115
    elif args.cmd == "embed":
        model = onnx.load_model(args.input_onnx)
        merger = Merger()
        for tuning_results in [json.load(open(f)) for f in args.input_json]:  # noqa: SIM115
            merger.merge(tuning_results)
        model = embed(model, merger.get_merged(), args.force)
        onnx.save_model(model, args.output_onnx)
    elif args.cmd == "merge":
        merger = Merger()
        for tuning_results in [json.load(open(f)) for f in args.input_json]:  # noqa: SIM115
            merger.merge(tuning_results)
        json.dump(merger.get_merged(), open(args.output_json, "w"))  # noqa: SIM115
    elif args.cmd == "pprint":
        tuning_results = None
        try:  # noqa: SIM105
            tuning_results = json.load(open(args.json_or_onnx))  # noqa: SIM115
        except Exception:
            # it might be an onnx file otherwise, try it latter
            pass

        if tuning_results is None:
            try:
                model = onnx.load_model(args.json_or_onnx)
                tuning_results = extract(model)
                if tuning_results is None:
                    sys.stderr.write(f"{args.input_onnx} does not have tuning results embedded!\n")
                    sys.exit(-1)
            except Exception:
                pass

        if tuning_results is None:
            sys.stderr.write(f"{args.json_or_onnx} is not a valid tuning results file or onnx file!")
            sys.exit(-1)

        pprint(tuning_results)
    else:
        # invalid choice will be handled by the parser
        pass


if __name__ == "__main__":
    main()
