"""Visualize RTX GIFs as "first n frames at X fps" sampling.

Shows what you get when you sample the first n frames at a given FPS:
uses each GIF's per-frame durations to pick the frame index at times
0, 1/fps, 2/fps, ..., (n-1)/fps seconds. Default n=4, fps=20.
"""

import argparse
import random
from pathlib import Path

import numpy as np
from PIL import Image

try:
    import matplotlib
    # Select the non-interactive Agg backend before pyplot is imported so the
    # figure can be rendered and saved to a file without a display.
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    HAS_MPL = True
except ImportError:
    # Plotting is optional at import time; main() checks HAS_MPL and prints a
    # message and returns instead of crashing.
    HAS_MPL = False

# Default for --data-root: a directory of subdirectories holding GIFs
# (layout root/*/*.gif, as discovered by find_gifs).
DEFAULT_DATA_ROOT = Path("/data/RTX/RTX_Video")


def find_gifs(root: Path):
    """List every GIF exactly one directory level below ``root``.

    Matches the pattern ``root/*/*.gif`` (same discovery as train_rtx.py).
    Subdirectories are visited in sorted order and the GIFs inside each one
    are sorted too. Plain files directly under ``root`` are ignored.

    Args:
        root: Dataset root; resolved to an absolute path first.

    Returns:
        List of absolute ``Path`` objects.
    """
    base = Path(root).resolve()
    return [
        gif_path
        for child in sorted(base.iterdir())
        if child.is_dir()
        for gif_path in sorted(child.glob("*.gif"))
    ]


def get_gif_info(path: Path):
    """Read the frame count and per-frame timing of a GIF.

    Args:
        path: GIF file to open with PIL.

    Returns:
        ``(n_frames, total_duration_ms, durations_ms)``.
        ``durations_ms[i]`` is frame i's ``duration`` info value as an int;
        a missing or ``None`` duration counts as 0. ``total_duration_ms`` is
        the sum of those values. An image without an ``n_frames`` attribute
        is treated as a single frame.
    """
    per_frame_ms = []
    with Image.open(path) as gif:
        frame_count = getattr(gif, "n_frames", 1)
        for frame_no in range(frame_count):
            gif.seek(frame_no)
            raw = gif.info.get("duration", 0)
            per_frame_ms.append(0 if raw is None else int(raw))
    return frame_count, sum(per_frame_ms), per_frame_ms


def sample_frame_indices_at_fps(durations_ms: list, num_frames: int, fps: float):
    """Map uniformly spaced sample times onto GIF frame indices.

    Sample times are 0, 1/fps, 2/fps, ..., (num_frames-1)/fps seconds.
    Frame k is on screen during [cumsum[k], cumsum[k+1]) ms, where cumsum is
    the running sum of ``durations_ms`` starting at 0. A sample time at or
    past the end of the GIF maps to the last frame (no looping).

    Args:
        durations_ms: Per-frame display durations in milliseconds.
        num_frames: Number of samples to take.
        fps: Sampling rate in samples per second.

    Returns:
        List of plain ``int`` frame indices.

        When there is no usable timing, this falls back to consecutive
        indices ``0, 1, ...`` truncated to the GIF's frame count (``[0]`` for
        an empty list), so the result may be shorter than ``num_frames``.
        "No usable timing" means an empty ``durations_ms``, ``fps <= 0``, or
        all durations zero.
    """
    n_gif_frames = len(durations_ms) if durations_ms else 0
    if n_gif_frames:
        cumsum_ms = np.cumsum([0] + list(durations_ms))
        total_ms = cumsum_ms[-1]
    else:
        cumsum_ms, total_ms = None, 0
    # With all-zero durations every sample time (even t=0) would land at or
    # past the end and pick the last frame, so treat it like "no timing".
    if n_gif_frames == 0 or fps <= 0 or total_ms <= 0:
        return list(range(min(num_frames, n_gif_frames or 1)))

    last = n_gif_frames - 1
    indices = []
    for i in range(num_frames):
        # i * 1000 / fps rounds once (exact on integer-ms boundaries), whereas
        # i / fps * 1000 rounds twice and can land just below a boundary.
        t_ms = i * 1000.0 / fps
        if t_ms >= total_ms:
            indices.append(last)
            continue
        # Find frame k such that cumsum_ms[k] <= t_ms < cumsum_ms[k+1];
        # side="right" skips zero-duration frames sharing the same start time.
        k = int(np.searchsorted(cumsum_ms, t_ms, side="right")) - 1
        indices.append(max(0, min(k, last)))
    return indices


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser; each flag's help string documents it."""
    p = argparse.ArgumentParser(
        description="Visualize RTX GIFs: first n frames sampled at X fps (default n=4, fps=20)"
    )
    p.add_argument(
        "--data-root",
        type=str,
        default=str(DEFAULT_DATA_ROOT),
        help="Root with subdirs of .gif files",
    )
    p.add_argument(
        "--num-frames",
        type=int,
        default=4,
        help="Number of frames to show per episode (default 4).",
    )
    p.add_argument(
        "--fps",
        type=float,
        default=20.0,
        help="Sampling FPS: frame at t=0, 1/fps, 2/fps, ... seconds (default 20).",
    )
    p.add_argument(
        "--num-episodes",
        type=int,
        default=6,
        help="Number of episodes (GIFs) to show, one row per episode (default 6).",
    )
    p.add_argument(
        "--max-samples",
        type=int,
        default=200,
        help="Max number of GIFs to consider; use 0 for all.",
    )
    p.add_argument(
        "--out",
        type=str,
        default="rtx_fps_vis",
        help="Output path prefix for saved figure (default: rtx_fps_vis).",
    )
    p.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for shuffling GIF order (default: none).",
    )
    return p


def _strip_image_axes(ax) -> None:
    """Hide an image axis' ticks and frame while keeping its y-label visible.

    ``ax.set_axis_off()`` would also hide the y-label, which is where the
    episode (GIF) name is drawn.
    """
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


def _plot_episode_row(axes_row, path: Path, n_f: int, fps: float) -> None:
    """Draw one GIF's sampled frames into a row of axes, one frame per column."""
    _, _, durations_ms = get_gif_info(path)
    indices = sample_frame_indices_at_fps(durations_ms, n_f, fps)
    label = path.name[:20] + "..." if len(path.name) > 20 else path.name
    with Image.open(path) as im:
        for col, frame_idx in enumerate(indices):
            im.seek(int(frame_idx))
            ax = axes_row[col]
            ax.imshow(np.array(im.convert("RGB")))
            _strip_image_axes(ax)
            if col == 0:
                ax.set_ylabel(label, fontsize=6)
            t_sec = col / fps
            ax.set_title(f"t={t_sec:.2f}s", fontsize=6)
    # The sampler may return fewer than n_f indices (e.g. a GIF with fewer
    # frames and no timing info); blank the unused cells instead of leaving
    # empty tick frames.
    for ax in axes_row[len(indices):]:
        ax.set_axis_off()


def main():
    """CLI entry point: sample frames from RTX GIFs and save a grid figure.

    Writes ``<out>_frames.png`` with one row per episode (GIF) and one column
    per sample time t = col / fps.

    Raises:
        FileNotFoundError: if --data-root is not a directory or contains no
            root/*/*.gif files.
    """
    p = _build_arg_parser()
    args = p.parse_args()
    # Reject values that would otherwise crash later with obscure errors
    # (empty subplot grid, division by zero in the column titles).
    if args.num_frames < 1:
        p.error("--num-frames must be >= 1")
    if args.num_episodes < 1:
        p.error("--num-episodes must be >= 1")
    if args.fps <= 0:
        p.error("--fps must be > 0")

    root = Path(args.data_root).resolve()
    if not root.is_dir():
        raise FileNotFoundError(f"Data root not found: {root}")

    gifs = find_gifs(root)
    if not gifs:
        raise FileNotFoundError(f"No *.gif under {root}. Expected structure: root/*/*.gif")

    # NOTE: truncation happens before shuffling, so episodes are drawn from the
    # first --max-samples GIFs in sorted order, not from the whole dataset.
    if args.max_samples > 0:
        gifs = gifs[: args.max_samples]
    # Private RNG: for a given seed this yields the same permutation as seeding
    # the global `random` module, without mutating global state.
    random.Random(args.seed).shuffle(gifs)

    n_episodes = min(args.num_episodes, len(gifs))
    n_f = args.num_frames
    fps = args.fps

    if not HAS_MPL:
        print("matplotlib not available; cannot save figure.")
        return

    out_prefix = Path(args.out)
    out_prefix.parent.mkdir(parents=True, exist_ok=True)

    # squeeze=False keeps `axes` 2-D for every grid shape, including a single
    # row and/or a single column.
    fig, axes = plt.subplots(
        n_episodes, n_f, figsize=(2 * n_f, 2 * n_episodes), squeeze=False
    )
    for row, path in enumerate(gifs[:n_episodes]):
        _plot_episode_row(axes[row], path, n_f, fps)

    fig.suptitle("First %d frames @ %.1f fps (one row per episode)" % (n_f, fps), y=1.02)
    fig.tight_layout()
    # bbox_inches="tight" enlarges the saved canvas to include the suptitle,
    # which sits above the figure (y > 1) and would otherwise be clipped.
    fig.savefig(str(out_prefix) + "_frames.png", dpi=100, bbox_inches="tight")
    plt.close(fig)
    print("Saved %s_frames.png (%d episodes × %d frames @ %.1f fps)" % (out_prefix, n_episodes, n_f, fps))


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
