# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

from collections import defaultdict
from collections.abc import Iterator
from typing import TYPE_CHECKING, Any, Callable, Optional

import numpy as np
import pandas as pd

from nsys_recipe.lib.typing_helpers import _PandasNamedTuple


def group_overlapping_ranges(range_df: pd.DataFrame) -> pd.Series:
    """Label each range with the id of the overlapping cluster it belongs to.

    Ranges are visited in order of their start. A range opens a new group
    only when it starts strictly after the largest end seen so far, so ranges
    that merely touch (start equal to a previous end) share a group. Group
    ids are consecutive integers starting at 0, assigned in start order.

    Parameters
    ----------
    range_df : dataframe
        DataFrame containing ranges with 'start' and 'end' columns.

    Returns
    -------
    result : series
        Series containing group identifiers for each range, aligned with the
        index of 'range_df'.
    """
    by_start = range_df.sort_values("start")
    # Largest end among all ranges that start before the current one.
    furthest_end_before = by_start["end"].cummax().shift()
    # The first row compares against NaN, which is False, so it gets group 0.
    opens_new_group = by_start["start"] > furthest_end_before
    return opens_new_group.cumsum().reindex(range_df.index)


def consolidate_ranges(range_df: pd.DataFrame) -> pd.DataFrame:
    """Merge overlapping time ranges into one range per cluster.

    Ranges are clustered with 'group_overlapping_ranges' (ranges that touch
    are merged as well). Each cluster is reduced to its earliest start and
    latest end.

    Parameters
    ----------
    range_df : dataframe
        DataFrame containing ranges with 'start' and 'end' columns.

    Returns
    -------
    result : dataframe
        DataFrame indexed by group identifier, with a 'start' column holding
        the minimum start and an 'end' column holding the maximum end of each
        group.
    """
    group_ids = group_overlapping_ranges(range_df)
    grouped = range_df.groupby(group_ids)
    return grouped.agg(start=("start", "min"), end=("end", "max"))


def process_overlapping_ranges(
    df1: pd.DataFrame,
    df2: pd.DataFrame,
    process_func: Callable,
    fully_contained: bool = False,
) -> None:
    """Process overlapping ranges between two dataframes.

    Both dataframes are swept in order of their start times. 'process_func'
    is called once for every pair of a 'df1' range and a 'df2' range that
    satisfies the matching criterion selected by 'fully_contained'.

    Parameters
    ----------
    df1 : dataframe
        Dataframe containing ranges with 'start' and 'end' columns. If
        'fully_contained' is True, this is checked to see if its ranges
        fully contain the ranges from 'df2'.
    df2 : dataframe
        Dataframe containing ranges with 'start' and 'end' columns. If
        'fully_contained' is True, this is checked to see if its ranges
        are fully contained within the ranges from 'df1'.
    process_func : callable
        Function to process overlapping ranges. It takes two arguments:
        - df1_index: index of a row from 'df1' that overlaps with 'df2_row'.
        - df2_row: namedtuple produced by 'itertuples' for a row from 'df2',
            with 'Index', 'start' and 'end' fields.
    fully_contained : bool
        Selects the matching criterion:
        - False: the ranges intersect, i.e. ``df1.start < df2.end`` and
          ``df1.end > df2.start``.
        - True: the 'df2' range lies within the 'df1' range, i.e.
          ``df1.start <= df2.start``, ``df2.start < df1.end`` and
          ``df2.end <= df1.end``.

    Notes
    -----
    Nothing is processed if either dataframe is None or empty. The order in
    which matching 'df1' indices are passed for a given 'df2' row is not
    specified.
    """
    if df1 is None or df1.empty or df2 is None or df2.empty:
        return

    df2_time_df = pd.DataFrame(
        data={"start": df2["start"], "end": df2["end"]}
    ).sort_values("start")

    # Should be "set[Index]", but we can't consistently infer this on insertion
    df1_active_indices: set[Any] = set()
    df1_start_df = pd.DataFrame(data={"time": df1["start"]}).sort_values("time")
    df1_end_df = pd.DataFrame(data={"time": df1["end"]}).sort_values("time")

    # Per-index lookups for the inner loop; plain dicts are much cheaper than
    # repeated '.loc' calls.
    df1_starts = df1["start"].to_dict()
    df1_ends = df1["end"].to_dict()

    # Our _PandasNamedTuple is a bit more lenient than the one in pandas-stubs
    df2_iter: Iterator[_PandasNamedTuple] = iter(df2_time_df.itertuples())
    df1_start_iter: Iterator[_PandasNamedTuple] = iter(df1_start_df.itertuples())
    df1_end_iter: Iterator[_PandasNamedTuple] = iter(df1_end_df.itertuples())

    df2_row = next(df2_iter)
    df1_start_row: Optional[_PandasNamedTuple] = next(df1_start_iter)
    df1_end_row = next(df1_end_iter)

    while True:
        # Activate every 'df1' range that may match the current 'df2' range.
        if df1_start_row is not None:
            if fully_contained:
                should_include_range = df1_start_row.time <= df2_row.start
            else:
                should_include_range = df1_start_row.time < df2_row.end

            if should_include_range:
                df1_active_indices.add(df1_start_row.Index)

                try:
                    df1_start_row = next(df1_start_iter)
                except StopIteration:
                    df1_start_row = None
                continue

        # Deactivate 'df1' ranges that end no later than the current 'df2'
        # range starts. 'df2' is sorted by start, so they cannot match any
        # later 'df2' range either.
        if df1_end_row.time <= df2_row.start:  # type: ignore
            # 'discard' rather than 'remove': a zero-length (or inverted)
            # 'df1' range can be swept past here before it was activated.
            df1_active_indices.discard(df1_end_row.Index)

            try:
                df1_end_row = next(df1_end_iter)
            except StopIteration:
                break
        else:
            for index in df1_active_indices:
                df1_end = df1_ends[index]
                # Ignore ranges activated after their end was already swept
                # past (see the 'discard' above).
                if df1_end <= df2_row.start:
                    continue

                if fully_contained:
                    # Only the start was checked on activation, so also check
                    # that the end of the 'df2' range is contained.
                    if df2_row.end > df1_end:
                        continue
                elif df1_starts[index] >= df2_row.end:
                    # 'df2' is sorted by start only, so its ends are not
                    # monotonic: a range activated for an earlier, longer
                    # 'df2' range may start after the current one ends.
                    continue

                process_func(index, df2_row)

            try:
                df2_row = next(df2_iter)
            except StopIteration:
                break


def map_overlapping_ranges(
    df1: pd.DataFrame,
    df2: pd.DataFrame,
    key_df: str = "df2",
    fully_contained: bool = False,
) -> dict:
    """Build a mapping between indices of matching ranges of two dataframes.

    Parameters
    ----------
    df1 : dataframe
        Dataframe containing ranges with 'start' and 'end' columns. If
        'fully_contained' is True, this is checked to see if its ranges
        fully contain the ranges from 'df2'.
    df2 : dataframe
        Dataframe containing ranges with 'start' and 'end' columns. If
        'fully_contained' is True, this is checked to see if its ranges
        are fully contained within the ranges from 'df1'.
    key_df : str
        Which dataframe's indices ('df1' or 'df2') become the keys of the
        resulting mapping.
    fully_contained : bool
        Forwarded to 'process_overlapping_ranges'; when True, only 'df2'
        ranges contained in a 'df1' range are matched.

    Returns
    -------
    overlap_map : dict
        A 'defaultdict(set)' mapping each index of 'key_df' that has at least
        one match to the set of matching indices in the other dataframe.

    Raises
    ------
    ValueError
        If 'key_df' is neither 'df1' nor 'df2'.
    """
    if key_df not in ("df1", "df2"):
        raise ValueError("key_df must be either 'df1' or 'df2'.")

    overlap_map = defaultdict(set)
    keyed_by_df1 = key_df == "df1"

    def record(df1_index, df2_row):
        if keyed_by_df1:
            key, value = df1_index, df2_row.Index
        else:
            key, value = df2_row.Index, df1_index
        overlap_map[key].add(value)

    process_overlapping_ranges(df1, df2, record, fully_contained)
    return overlap_map


def calculate_overlapping_ranges(
    df1: pd.DataFrame, df2: Optional[pd.DataFrame] = None
) -> pd.DataFrame:
    """Compute the pieces of each 'df1' range that overlap other ranges.

    Parameters
    ----------
    df1 : dataframe
        DataFrame containing ranges to calculate the overlap from, with 'start'
        and 'end' columns.
    df2 : dataframe, optional
        DataFrame containing ranges to calculate the overlap with, with 'start'
        and 'end' columns. If omitted, 'df1' is compared with itself, and a
        row is never compared with itself.

    Returns
    -------
    result : dataframe
        One row per overlapping pair with a strictly positive overlap,
        grouped in 'df1' row order, with the following columns:
        - start: start position of the overlap (the later of the two starts).
        - end: end position of the overlap (the earlier of the two ends).
        - original_index: index of the 'df1' row the overlap belongs to.
        These ranges may not exactly match the original ranges, as they could
        be created by combining start and end values from different ranges.
    """
    self_overlap = df2 is None
    candidates = defaultdict(set)

    def collect(df1_index, df2_row):
        candidates[df1_index].add((df2_row.Index, df2_row.start, df2_row.end))

    process_overlapping_ranges(df1, df1 if self_overlap else df2, collect)

    rows: list[tuple[int, int, int]] = []

    df1_row: Any  # Can't infer .Index type correctly
    for df1_row in df1.itertuples():
        matches = candidates.get(df1_row.Index)
        if matches is None:
            continue

        other_indices, other_starts, other_ends = (
            np.array(values) for values in zip(*matches)
        )

        # A range trivially overlaps itself when 'df1' is compared with itself.
        if self_overlap:
            keep = other_indices != df1_row.Index
            other_starts = other_starts[keep]
            other_ends = other_ends[keep]

        piece_starts = np.maximum(df1_row.start, other_starts)  # type: ignore[arg-type]
        piece_ends = np.minimum(df1_row.end, other_ends)  # type: ignore[arg-type]
        # Drop touching or non-intersecting candidates.
        positive = piece_ends - piece_starts > 0

        rows.extend(
            (piece_start, piece_end, df1_row.Index)
            for piece_start, piece_end in zip(
                piece_starts[positive], piece_ends[positive]
            )
        )

    return pd.DataFrame(rows, columns=["start", "end", "original_index"])


def calculate_overlap_sum(
    df1: pd.DataFrame, df2: Optional[pd.DataFrame] = None, consolidate: bool = True
) -> pd.Series:
    """Sum, for each 'df1' range, the time it spends overlapping other ranges.

    Parameters
    ----------
    df1 : dataframe
        DataFrame containing ranges to calculate the overlap from, with 'start'
        and 'end' columns.
    df2 : dataframe, optional
        DataFrame containing ranges to calculate the overlap with, with 'start'
        and 'end' columns. If not provided, the function calculates the
        overlap within df1.
    consolidate : bool, optional
        If True, overlap pieces of the same 'df1' row that overlap or touch
        are merged first, so time shared by several other ranges is counted
        once. If False, every overlapping pair contributes its full duration.

    Returns
    -------
    result : series
        Float series indexed like 'df1' with the summed overlap duration of
        each row, rounded to one decimal place. Rows without any overlap get
        0.
    """
    overlaps = calculate_overlapping_ranges(df1, df2)

    if consolidate:
        # NOTE: groups are computed over all rows at once rather than per
        # 'original_index'. This does not merge unrelated pieces: every row is
        # the intersection of its 'df1' range with another range, so no other
        # row can extend into a gap between two pieces of the same 'df1' range.
        group_ids = group_overlapping_ranges(overlaps)
        overlaps = (
            overlaps.assign(groups=group_ids)
            .groupby(["original_index", "groups"])
            .agg({"start": "min", "end": "max"})
        )

    # 'original_index' is a column when unconsolidated and an index level
    # after consolidation; groupby resolves either.
    totals = (
        overlaps.assign(duration=overlaps["end"] - overlaps["start"])
        .groupby("original_index")["duration"]
        .sum()
    )
    return totals.reindex(df1.index, fill_value=0.0).astype(float).round(1)
