"""Build a grid video from individual eval rollout videos.

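Each cell shows one side-by-side (sim + gen) rollout video, surrounded by a
colored border that visually groups its sim and gen panels as one rollout.
The border color reflects the stage keyword found in the video's filename:
red for "miss", yellow for "grasp", green for "place", gray otherwise.
Input videos are read from <input_dir>/videos/task_<task_id>/ep*.mp4.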

Usage:
    python build_grid_video.py --input_dir eval_grid_4x2 --rows 4 --cols 2 --out grid.mp4
"""

import argparse
import cv2
import glob
import numpy as np
import os
import subprocess


def extract_frames(video_path):
    """Decode every frame of a video file into RGB arrays.

    Args:
        video_path: path to the video; anything ``str()`` turns into a path
            (e.g. a ``pathlib.Path``) is accepted.

    Returns:
        A list of H x W x 3 frames in RGB channel order, in playback order.
        The list is empty if OpenCV cannot open or decode the file.
    """
    cap = cv2.VideoCapture(str(video_path))
    frames = []
    try:
        # read() reports failure both at end-of-stream and when the file could
        # not be opened, so an unreadable file simply yields an empty list.
        while True:
            ok, frame_bgr = cap.read()
            if not ok:
                break
            # OpenCV decodes to BGR; the rest of the script works in RGB.
            frames.append(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
    finally:
        # Release the decoder even if conversion raises or the loop is interrupted.
        cap.release()
    return frames


# Border colors (RGB) keyed by the stage keyword looked up in each filename.
# Insertion order matters: the first keyword found in the name wins.
_STAGE_COLORS = {
    "miss":  (180, 60, 60),    # red
    "grasp": (180, 180, 60),   # yellow
    "place": (60, 180, 60),    # green
}
_DEFAULT_STAGE_COLOR = (100, 100, 100)  # gray: no stage keyword in the name
_BACKGROUND_LEVEL = 40  # dark gray fill for canvas areas not covered by a cell
_TARGET_DURATION_S = 30  # subsample so the grid video lasts about this long


def _stage_color(name):
    """Return the border color for a video filename based on its stage keyword."""
    for stage, color in _STAGE_COLORS.items():
        if stage in name:
            return color
    return _DEFAULT_STAGE_COLOR


def _load_cells(paths):
    """Decode each cell's video to RGB frames; repeated paths are decoded once.

    Raises:
        RuntimeError: if a video yields no frames.
    """
    decoded = {}
    seqs = []
    for i, path in enumerate(paths):
        if path not in decoded:
            decoded[path] = extract_frames(path)
        frames = decoded[path]
        if not frames:
            raise RuntimeError(f"No frames could be decoded from {path}")
        h, w = frames[0].shape[:2]
        print(f"  Cell {i}: {os.path.basename(path)} — {len(frames)} frames, {w}x{h}")
        seqs.append(frames)
    return seqs


def _transcode_h264(src, dst):
    """Re-encode ``src`` to H.264 at ``dst`` with ffmpeg, then delete ``src``.

    If ffmpeg is missing or exits with an error, ``src`` (OpenCV's mp4v
    encoding) is moved to ``dst`` instead, so the rendered grid is never lost.
    """
    cmd = ["ffmpeg", "-y", "-i", src, "-c:v", "libx264",
           "-preset", "ultrafast", "-crf", "23",
           "-movflags", "+faststart", dst]
    try:
        result = subprocess.run(cmd, capture_output=True)
    except FileNotFoundError:
        print("WARNING: ffmpeg not found; keeping OpenCV mp4v encoding")
        os.replace(src, dst)
        return
    if result.returncode != 0:
        print("WARNING: ffmpeg failed; keeping OpenCV mp4v encoding")
        stderr_lines = result.stderr.decode(errors="replace").strip().splitlines()
        for line in stderr_lines[-5:]:
            print(f"    {line}")
        # Overwrites any partial file ffmpeg may have left at dst.
        os.replace(src, dst)
        return
    os.remove(src)


def build_grid(args):
    """Compose per-episode rollout videos into one rows x cols grid video.

    Reads ``<input_dir>/videos/task_<task_id>/ep*.mp4`` (sorted by name) and
    takes the first ``rows * cols`` of them, repeating the last one if there
    are too few. Each video is scaled to ``cell_width`` pixels wide and framed
    by a ``border``-pixel frame colored by the stage keyword in its filename.
    The result is written to ``args.out`` at ``args.fps``, subsampled to last
    about 30 seconds.

    Args:
        args: namespace with ``input_dir``, ``task_id``, ``rows``, ``cols``,
            ``cell_width``, ``border``, ``fps`` and ``out`` attributes.

    Raises:
        ValueError: if the grid would have no cells.
        FileNotFoundError: if no episode videos exist in the task directory.
        RuntimeError: if a video yields no frames or the writer cannot open.
    """
    n_cells = args.rows * args.cols
    if n_cells < 1:
        raise ValueError(f"grid must have at least one cell, got {args.rows}x{args.cols}")

    vid_dir = os.path.join(args.input_dir, "videos", f"task_{args.task_id}")
    vids = sorted(glob.glob(os.path.join(vid_dir, "ep*.mp4")))
    if not vids:
        raise FileNotFoundError(f"No ep*.mp4 videos found in {vid_dir}")
    if len(vids) < n_cells:
        print(f"WARNING: only {len(vids)} videos found, need {n_cells}")
        vids += [vids[-1]] * (n_cells - len(vids))  # repeat last

    print(f"Building {args.rows}x{args.cols} grid from {len(vids)} videos")
    vids = vids[:n_cells]
    all_seqs = _load_cells(vids)

    # Shorter sequences hold their last frame (index is clamped when drawing).
    max_len = max(len(s) for s in all_seqs)
    target_frames = max(1, args.fps * _TARGET_DURATION_S)
    # Ceiling division keeps the output at or under ~target_frames; floor
    # division could nearly double the intended duration.
    step = max(1, -(-max_len // target_frames))
    frame_indices = range(0, max_len, step)
    print(f"  {max_len} max frames, subsampling every {step} → {len(frame_indices)} grid frames")

    # Every cell is scaled to cell_width wide, using the first video's aspect ratio.
    cell_h, cell_w = all_seqs[0][0].shape[:2]
    thumb_w = args.cell_width
    thumb_h = max(1, int(cell_h * (args.cell_width / cell_w)))

    border = args.border
    slot_w = thumb_w + 2 * border
    slot_h = thumb_h + 2 * border
    grid_w = args.cols * slot_w
    grid_h = args.rows * slot_h
    # libx264 with 4:2:0 chroma subsampling rejects odd dimensions; pad the
    # canvas by one background pixel where needed.
    grid_w += grid_w % 2
    grid_h += grid_h % 2
    print(f"  Cell: {thumb_w}x{thumb_h}, Border: {border}px, Grid: {grid_w}x{grid_h}")

    # Static layer (background + colored cell borders), built once and copied
    # per frame rather than repainted every frame.
    base = np.full((grid_h, grid_w, 3), _BACKGROUND_LEVEL, dtype=np.uint8)
    origins = []
    for idx, path in enumerate(vids):
        r, c = divmod(idx, args.cols)
        x0, y0 = c * slot_w, r * slot_h
        base[y0:y0 + slot_h, x0:x0 + slot_w] = _stage_color(os.path.basename(path))
        origins.append((x0 + border, y0 + border))

    # splitext keeps tmp distinct from out whatever the extension is, and only
    # touches the final suffix (not ".mp4" inside directory names).
    root, ext = os.path.splitext(args.out)
    tmp = f"{root}_tmp{ext or '.mp4'}"
    writer = cv2.VideoWriter(tmp, cv2.VideoWriter_fourcc(*"mp4v"), args.fps, (grid_w, grid_h))
    if not writer.isOpened():
        raise RuntimeError(f"Could not open video writer for {tmp}")
    try:
        # Frames are streamed to the writer instead of being held in memory.
        for t in frame_indices:
            grid = base.copy()
            for seq, (x, y) in zip(all_seqs, origins):
                frame = seq[min(t, len(seq) - 1)]
                grid[y:y + thumb_h, x:x + thumb_w] = cv2.resize(frame, (thumb_w, thumb_h))
            # Frames are composed in RGB; OpenCV writers expect BGR.
            writer.write(cv2.cvtColor(grid, cv2.COLOR_RGB2BGR))
    finally:
        writer.release()

    _transcode_h264(tmp, args.out)
    print(f"  Saved: {args.out} ({len(frame_indices)} frames at {args.fps} fps)")


def _positive_int(text):
    """argparse type: parse an integer and require it to be at least 1."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def _non_negative_int(text):
    """argparse type: parse an integer and require it to be at least 0."""
    value = int(text)
    if value < 0:
        raise argparse.ArgumentTypeError(f"must be a non-negative integer, got {value}")
    return value


def main(argv=None):
    """Parse command-line arguments and build the grid video.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    p = argparse.ArgumentParser(
        description="Tile eval rollout videos into a single grid video.")
    p.add_argument("--input_dir", type=str, required=True,
                   help="Eval output directory containing videos/task_<task_id>/ep*.mp4")
    # Sizes and rates must be positive: zero would produce an empty grid or a
    # division by zero when scaling/subsampling.
    p.add_argument("--rows", type=_positive_int, default=4,
                   help="Number of grid rows")
    p.add_argument("--cols", type=_positive_int, default=2,
                   help="Number of grid columns")
    p.add_argument("--task_id", type=int, default=0,
                   help="Task whose episode videos are tiled")
    p.add_argument("--cell_width", type=_positive_int, default=400,
                   help="Width of each cell in the grid (pixels)")
    p.add_argument("--border", type=_non_negative_int, default=3,
                   help="Border width around each cell")
    p.add_argument("--fps", type=_positive_int, default=15,
                   help="Frame rate of the output video")
    p.add_argument("--out", type=str, default="rollout_grid.mp4",
                   help="Output video path")
    args = p.parse_args(argv)
    build_grid(args)


if __name__ == "__main__":
    main()
