# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Timer: helps measure CPU and CUDA times easily and reliably.
"""

import time
from contextlib import ContextDecorator
from functools import wraps
from typing import Callable

import torch

from cosmos_policy._src.imaginaire.utils import log


def _autoformat_time_us(time_us: float) -> str:
    """
    Automatically format time in nanoseconds.
    """
    if time_us >= 1e6:
        time_s = time_us * 1e-6
        return f"{time_s:.2f} s"

    if time_us >= 1e3:
        time_ms = time_us * 1e-3
        return f"{time_ms:.2f} ms"

    return f"{time_us:.2f} us"


def format_time_str(time_us: float, unit: str | None = None) -> str:
    """
    Automatically format time in nanoseconds either automatically or based on
    desired unit.
    """
    if unit is None:
        return _autoformat_time_us(time_us)

    if unit == "us":
        return f"{time_us:.2f} us"

    if unit == "ms":
        return f"{time_us * 1e-3:.2f} ms"

    if unit == "s":
        return f"{time_us * 1e-6:.2f} s"

    raise NotImplementedError(f"Time unit {unit} is not supported.")


def format_time(time_us: float, unit: str) -> float:
    """
    Format time in nanoseconds based on desired unit.
    """

    if unit == "us":
        return time_us

    if unit == "ms":
        return time_us * 1e-3

    if unit == "s":
        return time_us * 1e-6

    raise NotImplementedError(f"Time unit {unit} is not supported.")


class Timer(ContextDecorator):
    """
    Reliable CPU and CUDA Timer.

    Args:
        tag (str | None): Optional tag used in logs/prints.

        measure_cpu (bool): Whether to measure CPU time (using `time`). Default: `True`.

        measure_cuda (bool): Whether to measure CUDA time (using CUDA events). Default: `True`.

        unit (str | None): Optional time unit. Must be either "s" (seconds), "ms" (microseconds),
            "us" (nanoseconds), or None (format automatically based on value).

        debug (bool): Whether to log results in debug mode instead of info. Default is False.

    Examples:
        ```python
        with Timer(measure_cpu=True, measure_cuda=True, unit="ms"):
            model(x)
        ```

        ```python
        @Timer(measure_cpu=True, measure_cuda=True, unit="ms")
        def func(x):
            return model(x)
        ```
    """

    def __init__(
        self,
        tag: str | None = None,
        measure_cpu: bool = True,
        measure_cuda: bool = True,
        unit: str | None = None,
        debug: bool = False,
    ):
        self.measure_cpu = measure_cpu
        self.measure_cuda = measure_cuda

        self.measured = False
        self.cpu_time_us = 0
        self.cuda_time_us = 0

        self.busy = False
        self.cpu_time_start = None
        self.cuda_start_event = None
        self.cuda_end_event = None
        self.cuda_stream = None

        self.tag = "unknown" if tag is None else tag
        self.unit = unit
        if self.unit is not None and self.unit not in ["s", "ms", "us"]:
            raise NotImplementedError(f"Time unit {self.unit} is not supported.")

        self.debug = debug

    def _log(self, msg: str):
        if self.debug:
            log.debug(msg)
        else:
            log.info(msg)

    def __enter__(self):
        self.start()

    def __exit__(self, exc_type, exc_value, traceback):
        self.end()
        self.report()

    def __call__(self, func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args, **kwargs):  # noqa: ANN202
            self.start()
            result = func(*args, **kwargs)
            self.end()
            self.report()
            return result

        return wrapper  # type: ignore

    def report(self):
        """
        Reports measurements.
        """
        if self.measure_cpu and self.measure_cuda:
            self._log(f"Time spent on {self.tag}: CPU: {self.get_cpu_time_str()}, CUDA: {self.get_cuda_time_str()}")
        elif self.measure_cpu:
            self._log(f"Time spent on {self.tag}: {self.get_cpu_time_str()}")
        elif self.measure_cuda:
            self._log(f"CUDA time spent on {self.tag}: {self.get_cuda_time_str()}")
        else:
            raise NotImplementedError()

    def get_cpu_time(self) -> float:
        """
        Returns CPU time measurement.
        """
        if not self.measure_cpu:
            raise RuntimeError(f"CPU timer is disabled ({self.measure_cpu=}).")

        if not self.measured:
            raise RuntimeError("No measurements were made yet!")

        if self.unit is None:
            raise RuntimeError("No unit was specified. Please use get_cpu_time_str() instead.")

        assert self.unit is not None
        return format_time(self.cpu_time_us, unit=self.unit)

    def get_cuda_time(self) -> float:
        """
        Returns CUDA time measurement.
        """
        if not self.measure_cuda:
            raise RuntimeError(f"CUDA timer is disabled ({self.measure_cuda=}).")

        if not self.measured:
            raise RuntimeError("No measurements were made yet!")

        if self.unit is None:
            raise RuntimeError("No unit was specified. Please use get_cuda_time_str() instead.")

        assert self.unit is not None
        return format_time(self.cuda_time_us, unit=self.unit)

    def get_cpu_time_str(self) -> str:
        """
        Returns CPU time measurement in string format.
        """
        if not self.measure_cpu:
            raise RuntimeError(f"CPU timer is disabled ({self.measure_cpu=}).")

        if not self.measured:
            raise RuntimeError("No measurements were made yet!")

        return format_time_str(self.cpu_time_us, unit=self.unit)

    def get_cuda_time_str(self) -> str:
        """
        Returns CUDA time measurement in string format.
        """
        if not self.measure_cuda:
            raise RuntimeError(f"CUDA timer is disabled ({self.measure_cuda=}).")

        if not self.measured:
            raise RuntimeError("No measurements were made yet!")

        return format_time_str(self.cuda_time_us, unit=self.unit)

    def reset(self):
        """
        Resets recorded measurements
        """
        self.measured = False
        self.cpu_time_us = 0
        self.cuda_time_us = 0

    def start(self, cuda_device: torch.device | None = None, cuda_stream: torch.cuda.Stream | None = None):
        """
        Start time measurements.

        Args:
            cuda_device (torch.device | None): CUDA device. Will use default CUDA device if not indicated.

            cuda_stream (torch.cuda.Stream | None): CUDA stream to use for CUDA time measurement.
                Will use default stream for current CUDA device if not indicated.
        """
        if self.busy:
            raise RuntimeError("Already called Timer.start() once!")

        self.busy = True

        if self.measure_cuda:
            self.cuda_stream = cuda_stream if cuda_stream is not None else torch.cuda.current_stream(cuda_device)
            self.cuda_stream.synchronize()

        if self.measure_cpu:
            self.cpu_time_start = time.time()

        if self.measure_cuda:
            self.cuda_start_event = torch.cuda.Event(enable_timing=True)
            self.cuda_end_event = torch.cuda.Event(enable_timing=True)
            self.cuda_stream.record_event(self.cuda_start_event)

    def end(self):
        """
        Ends time measurements.

        NOTE: must be done on the same CUDA device and stream as start().
        """
        if not self.busy:
            raise RuntimeError("Timer.start() must be called exactly once before end()!")

        if self.measure_cuda:
            self.cuda_stream.record_event(self.cuda_end_event)
            self.cuda_end_event.synchronize()

        if self.measure_cpu:
            self.cpu_time_end = time.time()
            self.cpu_time_us = (self.cpu_time_end - self.cpu_time_start) * 1e6

        if self.measure_cuda:
            self.cuda_time_us = self.cuda_start_event.elapsed_time(self.cuda_end_event) * 1e3

        self.busy = False
        self.measured = True
