# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation.  All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

# This script dumps machine information (CPU, GPU, memory, OS and package versions) for use in notebooks.

import argparse
import json
import logging
import platform
from os import environ
from typing import Dict, List, Optional

import cpuinfo
import psutil
from py3nvml.py3nvml import (
    NVMLError,
    nvmlDeviceGetCount,
    nvmlDeviceGetHandleByIndex,
    nvmlDeviceGetMemoryInfo,
    nvmlDeviceGetName,
    nvmlInit,
    nvmlShutdown,
    nvmlSystemGetDriverVersion,
)


class MachineInfo:
    """Collect CPU, GPU, memory, OS and ML-package details of the current machine.

    Collection runs once in the constructor. The result is stored in
    ``machine_info`` as a dict, or ``None`` when collection raised.
    """

    def __init__(self, silent=False, logger=None):
        """Gather the machine info immediately.

        Args:
            silent: When True, failures of the optional probes (NVML, onnxruntime,
                PyTorch, TensorFlow) are not logged.
            logger: Logger to report problems to. When None, ``logging.basicConfig``
                is called with INFO level and the module logger is used.
        """
        self.silent = silent

        if logger is None:
            logging.basicConfig(
                format="%(asctime)s - %(name)s - %(levelname)s: %(message)s",
                level=logging.INFO,
            )
            logger = logging.getLogger(__name__)
        self.logger = logger

        try:
            self.machine_info = self.get_machine_info()
        except Exception:
            # Top-level boundary: a failure here is always logged, even in silent mode.
            self.logger.exception("Exception in getting machine info.")
            self.machine_info = None

    def get_machine_info(self):
        """Return a dict with gpu, cpu, memory, os, python, packages and framework info."""
        # Query cpuinfo once and share it; get_cpu_info() would otherwise query it again.
        cpu_info = cpuinfo.get_cpu_info()

        return {
            "gpu": self.get_gpu_info_by_nvml(),
            "cpu": self.get_cpu_info(cpu_info),
            "memory": self.get_memory_info(),
            "os": platform.platform(),
            "python": self._try_get(cpu_info, ["python_version"]),
            "packages": self.get_related_packages(),
            "onnxruntime": self.get_onnxruntime_info(),
            "pytorch": self.get_pytorch_info(),
            "tensorflow": self.get_tensorflow_info(),
        }

    def get_memory_info(self) -> Dict:
        """Return the ``total`` and ``available`` values of ``psutil.virtual_memory()``."""
        mem = psutil.virtual_memory()
        return {"total": mem.total, "available": mem.available}

    def _try_get(self, cpu_info: Dict, names: List) -> str:
        """Return the value of the first key from ``names`` that exists in ``cpu_info``.

        List and tuple values are joined into a comma-separated string; any other
        value is returned unchanged. Returns "" when none of the keys is present.
        """
        for name in names:
            if name in cpu_info:
                value = cpu_info[name]
                if isinstance(value, (list, tuple)):
                    return ",".join(str(item) for item in value)
                return value
        return ""

    def get_cpu_info(self, cpu_info: Optional[Dict] = None) -> Dict:
        """Return CPU brand, core counts, frequency, L2 cache size, flags and processor name.

        Args:
            cpu_info: Result of a previous ``cpuinfo.get_cpu_info()`` call to reuse.
                When None, cpuinfo is queried here.
        """
        if cpu_info is None:
            cpu_info = cpuinfo.get_cpu_info()

        return {
            # Two candidate key names are tried for the brand; the first one present wins.
            "brand": self._try_get(cpu_info, ["brand", "brand_raw"]),
            "cores": psutil.cpu_count(logical=False),
            "logical_cores": psutil.cpu_count(logical=True),
            "hz": self._try_get(cpu_info, ["hz_actual"]),
            "l2_cache": self._try_get(cpu_info, ["l2_cache_size"]),
            "flags": self._try_get(cpu_info, ["flags"]),
            "processor": platform.uname().processor,
        }

    def get_gpu_info_by_nvml(self) -> Optional[Dict]:
        """Return driver version and per-device memory/name info obtained through NVML.

        Returns:
            A dict with ``driver_version`` and ``devices`` (plus ``cuda_visible`` when
            the CUDA_VISIBLE_DEVICES environment variable is set), or None when any
            NVML call raises ``NVMLError``.
        """
        devices = []
        try:
            nvmlInit()
            try:
                driver_version = nvmlSystemGetDriverVersion()
                for index in range(nvmlDeviceGetCount()):
                    handle = nvmlDeviceGetHandleByIndex(index)
                    memory = nvmlDeviceGetMemoryInfo(handle)
                    devices.append(
                        {
                            "memory_total": memory.total,
                            "memory_available": memory.free,
                            "name": nvmlDeviceGetName(handle),
                        }
                    )
            finally:
                # Always pair a successful nvmlInit() with nvmlShutdown(), even when a query fails.
                nvmlShutdown()
        except NVMLError as error:
            if not self.silent:
                self.logger.error("Error fetching GPU information using nvml: %s", error)
            return None

        result = {"driver_version": driver_version, "devices": devices}

        if "CUDA_VISIBLE_DEVICES" in environ:
            result["cuda_visible"] = environ["CUDA_VISIBLE_DEVICES"]
        return result

    def get_related_packages(self) -> Dict[str, str]:
        """Return a ``{package name: version}`` dict for the installed packages of interest.

        Names are lower-cased with underscores replaced by dashes before matching.
        """
        related_packages = {
            "onnxruntime-gpu",
            "onnxruntime",
            "onnx",
            "transformers",
            "protobuf",
            "sympy",
            "torch",
            "tensorflow",
            "flatbuffers",
            "numpy",
            "onnxconverter-common",
        }

        try:
            # Standard library since Python 3.8; does not require setuptools to be installed.
            from importlib.metadata import distributions
        except ImportError:
            import pkg_resources

            return {i.key: i.version for i in pkg_resources.working_set if i.key in related_packages}

        versions = {}
        for dist in distributions():
            name = dist.metadata.get("Name")
            if not name:
                continue
            key = name.lower().replace("_", "-")
            if key in related_packages:
                # Keep the first match found when a package is listed more than once.
                versions.setdefault(key, dist.version)
        return versions

    def _log_probe_failure(self, error: Exception) -> None:
        """Log a failed framework probe with its traceback unless running silently."""
        if not self.silent:
            self.logger.exception(error)

    def get_onnxruntime_info(self) -> Optional[Dict]:
        """Return onnxruntime version and CUDA provider availability, or None on any failure."""
        try:
            import onnxruntime

            return {
                "version": onnxruntime.__version__,
                "support_gpu": "CUDAExecutionProvider" in onnxruntime.get_available_providers(),
            }
        except Exception as error:  # includes ImportError: a missing or broken package must not abort collection
            self._log_probe_failure(error)
            return None

    def get_pytorch_info(self) -> Optional[Dict]:
        """Return PyTorch version, CUDA availability and CUDA version, or None on any failure."""
        try:
            import torch

            return {
                "version": torch.__version__,
                "support_gpu": torch.cuda.is_available(),
                "cuda": torch.version.cuda,
            }
        except Exception as error:  # includes ImportError: a missing or broken package must not abort collection
            self._log_probe_failure(error)
            return None

    def get_tensorflow_info(self) -> Optional[Dict]:
        """Return TensorFlow version, git version and CUDA build flag, or None on any failure."""
        try:
            import tensorflow as tf

            return {
                "version": tf.version.VERSION,
                "git_version": tf.version.GIT_VERSION,
                "support_gpu": tf.test.is_built_with_cuda(),
            }
        except Exception as error:  # includes ImportError: a missing or broken package must not abort collection
            self._log_probe_failure(error)
            return None


def parse_arguments():
    """Parse the command line and return the resulting namespace.

    The only option is the boolean flag ``--silent`` (default False).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--silent", required=False, action="store_true", help="Do not print error message")
    cli.set_defaults(silent=False)
    return cli.parse_args()


def get_machine_info(silent=True) -> str:
    """Collect the machine info and return it as a JSON string indented by two spaces."""
    collected = MachineInfo(silent).machine_info
    return json.dumps(collected, indent=2)


if __name__ == "__main__":
    # Script entry point: print the machine info as JSON, honoring the --silent flag.
    print(get_machine_info(parse_arguments().silent))
