# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import argparse
import os
from datetime import datetime
from pathlib import Path

import pandas as pd

import nsys_recipe
from nsys_recipe import log
from nsys_recipe.data_service import DataService
from nsys_recipe.lib import arm_metrics as am
from nsys_recipe.lib import cpu_perf, helpers, nvtx, recipe
from nsys_recipe.lib.args import Option
from nsys_recipe.lib.table_config import CompositeTable
from nsys_recipe.log import logger
from nsys_recipe.recipes.nvtx_cpu_topdown import utils


class NvtxCpuTopDownMetrics(recipe.Recipe):
    """Recipe computing CPU top-down metrics for NVTX ranges across reports.

    Each report is mapped to its NVTX ranges annotated with PMU core events.
    The reducer selects comparable NVTX range instances across reports,
    computes CPU performance metrics for them, and writes report/NVTX/CPU
    metric summaries (parquet, and optionally CSV) plus a notebook.
    """

    def __init__(self, parsed_args):
        super().__init__(parsed_args)
        # Warning messages collected while reducing; written to warnings.parquet.
        self._warnings = []

    @staticmethod
    def _mapper_func(report_path, parsed_args):
        """Load one report and attach CPU core events to its NVTX ranges.

        Args:
            report_path: Path of the report to process.
            parsed_args: Parsed recipe arguments.

        Returns:
            A tuple ``(nvtx_df, filename, cpu_core_event_names, cpu_arch,
            multithread_nvtx)``, or None when the report is skipped (missing
            data, unsupported CPU architecture, or a processing error).
        """
        service = DataService(report_path, parsed_args)

        service.queue_custom_table(CompositeTable.NVTX)
        service.queue_custom_table(CompositeTable.PERF_EVENTS)
        service.queue_table("SCHED_EVENTS")
        service.queue_table("TARGET_INFO_SYSTEM_ENV", ["name", "value"])

        df_dict = service.read_queued_tables()
        if df_dict is None:
            return None

        nvtx_df = df_dict[CompositeTable.NVTX]
        err_msg = service.filter_and_adjust_time(nvtx_df)
        if err_msg is not None:
            logger.error(f"{report_path}: {err_msg}")
            return None

        # Determined before any thread filtering so the reducer can warn that
        # only the main thread's ranges were used.
        multithread_nvtx = len(nvtx_df["tid"].unique()) > 1

        no_nvtx_extra_info_msg = ""

        if multithread_nvtx and not parsed_args.agg_parallel_nvtx_ranges:
            # If there are multiple threads, keep only the NVTX ranges of the
            # main thread, i.e. the thread whose TID equals its PID.
            nvtx_df = nvtx_df[nvtx_df["tid"] == nvtx_df["pid"]]

            no_nvtx_extra_info_msg = " for the main thread"

        if nvtx_df.empty:
            logger.info(
                f"{report_path}: Report was successfully processed, "
                f"but no NVTX ranges were found{no_nvtx_extra_info_msg}."
            )
            return None

        # Drop the columns that are not needed.
        nvtx_df = nvtx_df.drop(
            columns=["endGlobalTid", "eventType", "textId"], errors="ignore"
        )

        perf_df = df_dict[CompositeTable.PERF_EVENTS]
        # Only events carrying a CPU value are treated as core events.
        core_perf_df = perf_df[perf_df["cpu"].notna()]
        if core_perf_df.empty:
            logger.info(
                f"{report_path}: Report was successfully processed, "
                "but no PMU core performance events were found."
            )
            return None

        thread_sched_df = df_dict["SCHED_EVENTS"]
        if thread_sched_df.empty:
            logger.info(
                f"{report_path}: Report was successfully processed, "
                "but no thread scheduling events were found."
            )
            return None

        target_info_df = df_dict["TARGET_INFO_SYSTEM_ENV"]
        if target_info_df.empty:
            logger.info(
                f"{report_path}: Report was successfully processed, "
                "but no information on system environment was found."
            )
            return None

        # TODO: Update the logic here once DTSP-18884 is implemented.
        cpu_arch = utils.get_cpu_arch(target_info_df)

        # Prefix with the report path (like the other mapper messages) so the
        # skipped report can be identified when many reports are processed.
        if cpu_arch == cpu_perf.Architecture.AARCH64_TEGRA:
            logger.warning(
                f"{report_path}: Tegra-based device is not supported by this recipe."
            )
            return None

        if cpu_arch != cpu_perf.Architecture.AARCH64_SBSA:
            logger.warning(
                f"{report_path}: The {cpu_arch} CPU architecture is not "
                "supported by this recipe."
            )
            return None

        filename = Path(report_path).stem
        cpu_core_event_names = core_perf_df["name"].unique()

        try:
            nvtx_df = utils.get_nvtx_w_cpu_events_n_stack(
                nvtx_df,
                core_perf_df,
                thread_sched_df,
                not parsed_args.distribute_across_all_threads,
                parsed_args.agg_parallel_nvtx_ranges,
            )
        except Exception as e:
            # Deliberate per-report boundary: one bad report must not abort
            # processing of the others.
            logger.error(
                f"{report_path}: "
                "The data from this report is skipped from "
                f"processing due to an error: {e}"
            )
            return None

        return nvtx_df, filename, cpu_core_event_names, cpu_arch, multithread_nvtx

    @log.time("Mapper")
    def mapper_func(self, context):
        """Run ``_mapper_func`` on every input report and wait for results."""
        return context.wait(
            context.map(
                self._mapper_func,
                self._parsed_args.input,
                parsed_args=self._parsed_args,
            )
        )

    @log.time("Reducer")
    def reducer_func(self, mapper_res):
        """Combine per-report mapper results and write the recipe outputs.

        Reports are ordered by ascending number of NVTX ranges. The first one
        is the reference: its ``mid_ranges()`` instances select which
        (instance index, key) pairs are compared across all reports.

        Args:
            mapper_res: Results returned by ``mapper_func``; entries may be None.

        Returns:
            True if the outputs were written. False if there is no usable
            data: no report produced NVTX data, or a report shares no NVTX
            ranges with the reference report.
        """
        results = list(helpers.filter_none(mapper_res))

        # Evaluated over every processed report, including those whose NVTX
        # data is None, so the main-thread warning behaves as before.
        any_multithread_nvtx = any(res[4] for res in results)

        # Only reports with NVTX data contribute to the outputs. Dropping the
        # others up front keeps every per-report list index-aligned and makes
        # the first entry the true reference report.
        results = [res for res in results if res[0] is not None]
        if not results:
            return False

        # Sort the results by the number of NVTX ranges (ascending).
        results.sort(key=lambda res: len(res[0]))
        nvtx_dfs, filenames, cpu_core_event_names, cpu_archs, _ = zip(*results)

        details_log_path = (
            self.add_output_file("details.log")
            if self._parsed_args.log_details
            else None
        )

        td_group_strategy = nvtx.TopDownGroupingStrategy()
        td_key_col = td_group_strategy.key_column
        par_td_key_col = td_group_strategy.par_key_column
        td_key_cols = [td_key_col, par_td_key_col]

        inst_idx_col = "instIdx"
        index_cols = [inst_idx_col, td_key_col]

        nvtx_groupers = []
        selected_nvtx_dfs = []
        mean_nvtx_id_df = None
        too_short_nvtx_ids = set()

        for filename, nvtx_df in zip(filenames, nvtx_dfs):
            nvtx_grouper = utils.create_nvtx_grouper(nvtx_df, td_group_strategy)
            nvtx_groupers.append(nvtx_grouper)

            inst_idx_df = utils.number_nvtx_instances(nvtx_grouper, inst_idx_col)

            # Select the mean NVTX ranges by the instance duration from the
            # first (reference) report, and the ranges with the same instance
            # index and key from the remaining reports.
            if mean_nvtx_id_df is None:
                selected_nvtx_df = nvtx_grouper.mid_ranges().merge(
                    inst_idx_df, on="rangeId"
                )
                mean_nvtx_id_df = selected_nvtx_df[index_cols]
            else:
                upd_nvtx_df = nvtx_grouper.df.merge(inst_idx_df, on="rangeId")
                selected_nvtx_df = upd_nvtx_df.merge(mean_nvtx_id_df, on=index_cols)

                if selected_nvtx_df.empty:
                    logger.warning(
                        "No common NVTX ranges were found in the reports: "
                        f"{filenames[0]} and {filename}."
                    )
                    return False

            selected_nvtx_dfs.append(selected_nvtx_df)

            # Collect too short NVTX ranges from each report for further filtering.
            too_short_nvtx_df = selected_nvtx_df[selected_nvtx_df["tooShort"] == 1]
            too_short_nvtx_ids.update(too_short_nvtx_df[td_key_col].values)

        # Loop-invariant metric metadata, fetched once.
        arm_metrics = am.get_arm_metrics()
        cpu_metrics_merge_cols = ["text", inst_idx_col] + td_key_cols

        report_summary_items = []
        nvtx_summary_dfs = []
        cpu_metrics_dfs = []
        for idx, (filename, nvtx_df) in enumerate(zip(filenames, nvtx_dfs)):
            # Aggregate NVTX ranges to compute the NVTX summary of this report.
            report_nvtx_summary_df = utils.aggregate_nvtx_ranges(
                nvtx_groupers[idx], td_key_cols
            )
            report_nvtx_summary_df["report"] = filename
            report_nvtx_summary_df["tooShort"] = report_nvtx_summary_df[
                td_key_col
            ].isin(too_short_nvtx_ids)
            nvtx_summary_dfs.append(report_nvtx_summary_df)

            # Compute CPU metrics for the selected NVTX ranges of this report,
            # excluding keys that were too short in any report.
            selected_nvtx_df = selected_nvtx_dfs[idx]
            selected_nvtx_df = selected_nvtx_df[
                ~selected_nvtx_df[td_key_col].isin(too_short_nvtx_ids)
            ]
            metrics_df = cpu_perf.compute_perf_metrics(
                selected_nvtx_df, "instDuration", cpu_archs[idx]
            )
            cpu_metric_ids = metrics_df.columns.tolist()

            # Index-aligned join of the range identifiers with their metrics.
            report_cpu_metrics_df = selected_nvtx_df[cpu_metrics_merge_cols].join(
                metrics_df
            )
            cpu_metrics_dfs.append(report_cpu_metrics_df)

            # Fill the report summary for this report.
            cpu_metric_names = [
                arm_metrics[am.PerfMetricType.from_name(metric_id)].name
                for metric_id in cpu_metric_ids
            ]
            report_summary_items.append(
                {
                    "Report": filename + ".nsys-rep",
                    "PMU core events": "<br/>".join(cpu_core_event_names[idx]),
                    "CPU core metrics": "<br/>".join(cpu_metric_names),
                }
            )

            if self._parsed_args.log_details:
                utils.log_details(
                    details_log_path,
                    filename,
                    nvtx_df,
                    selected_nvtx_df,
                    report_nvtx_summary_df,
                    report_cpu_metrics_df,
                )

        # Compute total report summary and save it.
        report_summary_df = pd.DataFrame(report_summary_items).reset_index(drop=True)
        report_summary_df.to_parquet(self.add_output_file("report_summary.parquet"))

        # Compute total NVTX summary and save it.
        nvtx_summary_df = utils.create_nvtx_summary(nvtx_summary_dfs, td_key_cols)
        nvtx_summary_df.to_parquet(self.add_output_file("nvtx_summary.parquet"))

        # Compute total CPU metrics and save it.
        cpu_metrics_df = helpers.merge(cpu_metrics_dfs, cpu_metrics_merge_cols)
        cpu_metrics_df = utils.add_callstack_to_duplicated_names(
            cpu_metrics_df, td_key_cols
        )
        cpu_metrics_df = cpu_metrics_df.drop(columns=td_key_cols).rename(
            columns={"text": "NVTX Range"}
        )
        cpu_metrics_df.to_parquet(self.add_output_file("cpu_metrics.parquet"))

        # Collect warnings on data processing and print them.
        if any_multithread_nvtx and not self._parsed_args.agg_parallel_nvtx_ranges:
            self._warnings.append(utils.get_only_main_thread_nvtx_warning())

        if len(too_short_nvtx_ids) >= nvtx_summary_dfs[0].shape[0]:
            self._warnings.append(utils.get_all_nvtx_too_short_warning())

        for warning in self._warnings:
            logger.warning(warning)

        warnings_df = pd.DataFrame(self._warnings, columns=["Message"])
        warnings_df.to_parquet(self.add_output_file("warnings.parquet"))

        if self._parsed_args.csv:
            report_summary_df.to_csv(self.add_output_file("report_summary.csv"))
            nvtx_summary_df.to_csv(self.add_output_file("nvtx_summary.csv"))
            cpu_metrics_df.to_csv(self.add_output_file("cpu_metrics.csv"))
            warnings_df.to_csv(self.add_output_file("warnings.csv"))

        if self._parsed_args.log_details:
            utils.log_details(
                details_log_path,
                "Summary from reports",
                nvtx_summary_df=nvtx_summary_df,
                cpu_metrics_df=cpu_metrics_df,
            )

        return True

    def save_notebook(self):
        """Create the recipe notebook and copy its helper modules."""
        self.create_notebook("nvtx_cpu_topdown.ipynb")
        self.add_notebook_helper_file("nsys_display.py")
        self.add_notebook_helper_file("arm_metrics.py")
        self.add_notebook_helper_file("post_process.py", dir=os.path.dirname(__file__))

    def save_analysis_file(self):
        """Record end time and output files, then write the analysis file."""
        self._analysis_dict.update(
            {
                "EndTime": str(datetime.now()),
                "Outputs": self._output_files,
            }
        )
        self.create_analysis_file()

    def run(self, context):
        """Run mapper and reducer; raise NoDataError if nothing was produced."""
        super().run(context)

        mapper_res = self.mapper_func(context)
        if self.reducer_func(mapper_res):
            self.save_notebook()
            self.save_analysis_file()
        else:
            raise nsys_recipe.NoDataError()

    @classmethod
    def get_argument_parser(cls):
        """Return the recipe parser with the input, CSV and hidden options."""
        parser = super().get_argument_parser()

        parser.add_recipe_argument(Option.INPUT, required=True)
        parser.add_recipe_argument(Option.CSV)
        parser.add_recipe_argument(
            "--distribute-across-all-threads",
            dest="distribute_across_all_threads",
            action="store_true",
            help=argparse.SUPPRESS,
            default=False,
        )
        parser.add_recipe_argument(
            "--aggregate-parallel-nvtx-ranges",
            dest="agg_parallel_nvtx_ranges",
            action="store_true",
            help=argparse.SUPPRESS,
            default=False,
        )

        parser.add_recipe_argument(
            "--log-details",
            dest="log_details",
            action="store_true",
            help=argparse.SUPPRESS,
            default=False,
        )

        return parser
