"""Visualize RTX Video GIFs: show loaded frames and FPS stats to decide extraction FPS.

Uses the same data root and find_gifs as train_rtx.py. For each GIF we compute:
- n_frames, total_duration_ms, effective_fps = n_frames / (total_duration_ms/1000).
Then we show a grid of sample frames and a summary of FPS distribution.
"""

import argparse
import random
from pathlib import Path

import numpy as np
from PIL import Image

try:
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    HAS_MPL = True
except ImportError:
    HAS_MPL = False

# Default dataset root; used as the default value of the --data-root option.
DEFAULT_DATA_ROOT = Path("/data/RTX/RTX_Video")
# Default value of the --num-frames option (parsed by main() but not otherwise used there).
NUM_FRAMES_TRAIN = 4  # match train_rtx (first N frames loaded)


def find_gifs(root: Path):
    """Collect the ``*.gif`` files located exactly one directory below ``root``.

    ``root`` is resolved to an absolute path first. Its immediate
    subdirectories are visited in sorted order and, within each one, the
    matching files are listed in sorted order. Files sitting directly in
    ``root`` and files nested more deeply are not returned. The layout is
    meant to mirror the discovery done by train_rtx.py.

    Returns:
        A list of absolute ``Path`` objects (empty if nothing matches).
    """
    base = Path(root).resolve()
    return [
        gif
        for folder in sorted(base.iterdir())
        if folder.is_dir()
        for gif in sorted(folder.glob("*.gif"))
    ]


def get_gif_info(path: Path):
    """Read frame count and timing metadata from the image at ``path``.

    Every frame is visited with ``seek`` and its ``"duration"`` entry is read
    from the image's ``info`` dict. A missing or ``None`` duration is counted
    as 0, and each value is converted with ``int``.

    Returns:
        A tuple ``(n_frames, total_ms, durations_ms)``: the frame count (1 if
        the image exposes no ``n_frames`` attribute), the sum of the per-frame
        durations, and the list of per-frame durations.
    """
    with Image.open(path) as gif:
        frame_count = getattr(gif, "n_frames", 1)
        per_frame_ms = []
        for index in range(frame_count):
            # The "duration" entry is re-read after each seek to the next frame.
            gif.seek(index)
            raw = gif.info.get("duration", 0)
            per_frame_ms.append(0 if raw is None else int(raw))
    return frame_count, sum(per_frame_ms), per_frame_ms


def _build_arg_parser():
    """Build the command-line parser for this script."""
    p = argparse.ArgumentParser(
        description="Visualize RTX Video GIFs: frames + FPS stats (for choosing extraction FPS)"
    )
    p.add_argument(
        "--data-root",
        type=str,
        default=str(DEFAULT_DATA_ROOT),
        help="Root with subdirs of .gif files",
    )
    p.add_argument(
        "--max-samples",
        type=int,
        default=200,
        help="Max number of GIFs to scan for stats (default 200). Use 0 for all.",
    )
    p.add_argument(
        "--num-episodes",
        type=int,
        default=4,
        help="Number of episodes (GIFs) to show, one row per episode (default 4).",
    )
    p.add_argument(
        "--show-frames",
        type=int,
        default=30,
        help="Number of frames to sample per episode, uniformly along the GIF (default 30).",
    )
    p.add_argument(
        "--num-frames",
        type=int,
        default=NUM_FRAMES_TRAIN,
        help=f"Unused when showing frames-per-episode; kept for compatibility (default {NUM_FRAMES_TRAIN}).",
    )
    p.add_argument(
        "--out",
        type=str,
        default="rtx_gif_vis",
        help="Output path prefix for saved figures (default: rtx_gif_vis).",
    )
    p.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for shuffling GIF order (default: none).",
    )
    return p


def _collect_stats(gifs):
    """Read per-GIF frame counts and durations, skipping files that cannot be read.

    Returns:
        ``(readable, n_frames_list, total_duration_ms_list, fps_arr)`` where the
        first three are parallel lists over the GIFs that were read
        successfully, and ``fps_arr`` is a NumPy array of effective FPS values
        for only those GIFs whose total duration is positive.
    """
    readable = []
    n_frames_list = []
    total_duration_ms_list = []
    fps_values = []
    for path in gifs:
        try:
            n_frames, total_ms, _ = get_gif_info(path)
        except (OSError, EOFError, ValueError) as err:
            # One corrupt or truncated file should not abort the whole scan.
            print(f"  Skipping unreadable GIF {path}: {err}")
            continue
        readable.append(path)
        n_frames_list.append(n_frames)
        total_duration_ms_list.append(total_ms)
        if total_ms > 0:
            fps_values.append(n_frames / (total_ms / 1000.0))
    return readable, n_frames_list, total_duration_ms_list, np.array(fps_values)


def _print_summary(n_scanned, n_frames_list, total_duration_ms_list, fps_arr):
    """Print frame-count, duration and effective-FPS statistics to stdout."""
    print("\n--- RTX GIF summary ---")
    print(f"  Total GIFs scanned: {n_scanned}")
    print(f"  Frames per GIF: min={min(n_frames_list)}, max={max(n_frames_list)}, mean={np.mean(n_frames_list):.1f}")
    print(f"  Total duration (ms) per GIF: min={min(total_duration_ms_list)}, max={max(total_duration_ms_list)}, mean={np.mean(total_duration_ms_list):.0f}")
    print(f"  GIFs with duration info: {fps_arr.size}/{n_scanned}")
    if fps_arr.size:
        median = np.median(fps_arr)
        print(f"  Effective FPS: min={fps_arr.min():.2f}, max={fps_arr.max():.2f}, mean={fps_arr.mean():.2f}, median={median:.2f}")
        print(f"\n  Suggested: use an extraction FPS near the median (e.g. --sample_fps {median:.1f}) so clips match GIF timing.")
    else:
        print("  (No duration metadata in GIFs; cannot compute FPS. Extraction FPS is your choice.)")


def _save_fps_hist(fps_arr, out_path):
    """Save a histogram of effective FPS values with the median marked."""
    n_valid = int(fps_arr.size)
    median = np.median(fps_arr)
    fig, ax = plt.subplots(1, 1, figsize=(8, 4))
    ax.hist(fps_arr, bins=min(50, max(10, n_valid // 5)), color="steelblue", edgecolor="white")
    ax.axvline(median, color="red", linestyle="--", label=f"median = {median:.2f}")
    ax.set_xlabel("Effective FPS (frames / total_duration)")
    ax.set_ylabel("Count")
    ax.set_title("RTX GIF effective FPS distribution")
    ax.legend()
    fig.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)
    print(f"  Saved {out_path}")


def _save_frame_count_hist(n_frames_list, out_path):
    """Save a histogram of the number of frames per GIF."""
    spread = max(n_frames_list) - min(n_frames_list) + 1
    fig, ax = plt.subplots(1, 1, figsize=(8, 4))
    ax.hist(n_frames_list, bins=min(50, max(2, spread)), color="seagreen", edgecolor="white")
    ax.set_xlabel("Number of frames in GIF")
    ax.set_ylabel("Count")
    ax.set_title("RTX GIF frame count distribution")
    fig.tight_layout()
    fig.savefig(out_path, dpi=120)
    plt.close(fig)
    print(f"  Saved {out_path}")


def _save_frame_grid(gifs, n_episodes, n_frames_per_ep, out_path):
    """Save a grid with one row per GIF and ``n_frames_per_ep`` sampled frames per row.

    Frame indices are spread evenly from the first to the last frame of each
    GIF. Both ``n_episodes`` and ``n_frames_per_ep`` must be >= 1.
    """
    # squeeze=False keeps `axes` 2-D even when there is a single row and/or a
    # single column, so axes[row, col] is always valid.
    fig, axes = plt.subplots(
        n_episodes,
        n_frames_per_ep,
        figsize=(0.5 * n_frames_per_ep, 1.2 * n_episodes),
        squeeze=False,
    )
    for row, path in enumerate(gifs[:n_episodes]):
        label = path.name[:20] + "..." if len(path.name) > 20 else path.name
        with Image.open(path) as im:
            n_available = getattr(im, "n_frames", 1)
            indices = np.linspace(0, n_available - 1, num=n_frames_per_ep).astype(int)
            for col, frame_idx in enumerate(indices):
                im.seek(int(frame_idx))
                ax = axes[row, col]
                ax.imshow(np.array(im.convert("RGB")))
                # Hide ticks and spines individually instead of set_axis_off(),
                # which would also hide the row label set below.
                ax.set_xticks([])
                ax.set_yticks([])
                for spine in ax.spines.values():
                    spine.set_visible(False)
                if col == 0:
                    ax.set_ylabel(label, fontsize=6)
    fig.suptitle(f"{n_frames_per_ep} frames sampled per episode (uniform in time), {n_episodes} episodes", y=1.01)
    fig.tight_layout()
    # The title sits slightly above the canvas (y=1.01); a tight bounding box
    # keeps it from being clipped in the saved image.
    fig.savefig(out_path, dpi=100, bbox_inches="tight")
    plt.close(fig)
    print(f"  Saved {out_path}")


def main():
    """Scan GIFs under --data-root, print FPS statistics and save figures.

    Steps: discover GIFs with ``find_gifs``, optionally truncate to
    --max-samples, shuffle, read frame/duration metadata, print a summary,
    and (when matplotlib is importable) write ``<out>_fps_hist.png``,
    ``<out>_frame_count_hist.png`` and ``<out>_frames.png``.

    Raises:
        FileNotFoundError: if the data root is not a directory or contains no
            ``root/*/*.gif`` files.
        RuntimeError: if none of the selected GIFs could be read.
    """
    args = _build_arg_parser().parse_args()

    root = Path(args.data_root).resolve()
    if not root.is_dir():
        raise FileNotFoundError(f"Data root not found: {root}")

    gifs = find_gifs(root)
    if not gifs:
        raise FileNotFoundError(f"No *.gif under {root}. Expected structure: root/*/*.gif")

    # NOTE(review): truncation happens before the shuffle, so only the first
    # --max-samples GIFs in sorted order are ever scanned; the shuffle only
    # changes which of those end up in the frame grid.
    if args.max_samples > 0:
        gifs = gifs[: args.max_samples]
    # A local RNG yields the same permutation as seeding the global one, without
    # touching global random state. seed=None draws the seed from OS entropy.
    random.Random(args.seed).shuffle(gifs)
    print(f"Scanning {len(gifs)} GIFs under {root} (shuffled) ...")

    gifs, n_frames_list, total_duration_ms_list, fps_arr = _collect_stats(gifs)
    if not gifs:
        raise RuntimeError(f"None of the selected GIFs under {root} could be read")

    _print_summary(len(gifs), n_frames_list, total_duration_ms_list, fps_arr)

    if not HAS_MPL:
        print("\n(matplotlib not available; skipping figure export)")
        return

    out_prefix = Path(args.out)
    out_prefix.parent.mkdir(parents=True, exist_ok=True)

    # 1) FPS histogram (only when at least one GIF carries duration metadata)
    if fps_arr.size:
        _save_fps_hist(fps_arr, f"{out_prefix}_fps_hist.png")

    # 2) Frame count histogram
    _save_frame_count_hist(n_frames_list, f"{out_prefix}_frame_count_hist.png")

    # 3) Frame grid: one row per episode, --show-frames frames per row
    n_episodes = min(args.num_episodes, len(gifs))
    if n_episodes < 1 or args.show_frames < 1:
        # plt.subplots rejects a zero-sized grid, so skip instead of crashing.
        print("  (--num-episodes or --show-frames is < 1; skipping frame grid)")
    else:
        _save_frame_grid(gifs, n_episodes, args.show_frames, f"{out_prefix}_frames.png")
    print("Done.")


# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
