#!/usr/bin/env python3
"""gen_rollout_grid.py — Combine eval videos into a 5x5 grid video (MP4) plus a first-frame PNG thumbnail (per-cell GIF export is not implemented yet)."""
import json
import os
import subprocess
from pathlib import Path

import cv2
import numpy as np

EXPERIMENT_NAME = "act_baseline"
n_views = 8       # viewpoints per axis: n_views thetas x n_views phis
grid_size = 5     # subsample n_views x n_views (8x8) → grid_size x grid_size (5x5)
thumb = 140       # edge length, in pixels, of each cell's thumbnail
max_frames = 32   # frames taken from each rollout video

results_dir = Path(f"results/{EXPERIMENT_NAME}")
# Use a context manager so the JSON file handle is closed deterministically.
with open(results_dir / "grid_results.json", encoding="utf-8") as fh:
    results = json.load(fh)

thetas_deg = np.linspace(0, 25, n_views)
phis_deg = np.linspace(0, 360 * (1 - 1 / n_views), n_views)

# Sample grid_size evenly spaced theta and phi indices from range(n_views).
# Derive the upper bound from n_views rather than hard-coding 7, so changing
# n_views cannot select missing or out-of-range viewpoints.
theta_sample = np.round(np.linspace(0, n_views - 1, grid_size)).astype(int)
phi_sample = np.round(np.linspace(0, n_views - 1, grid_size)).astype(int)
# Flat viewpoint index = theta_index * n_views + phi_index, ordered row-major
# (all sampled phis for the first sampled theta, then the next theta, ...).
selected_vis = [ti * n_views + pi for ti in theta_sample for pi in phi_sample]

def extract_frames(video_path, max_f=max_frames, skip_first=1):
    """Decode up to ``max_f`` frames from ``video_path``, ignoring the first ``skip_first``.

    The result always holds exactly ``max_f`` entries (for ``max_f > 0``). If the
    video runs out early, the last decoded frame is repeated. If nothing could be
    decoded, a single black 448x448x3 uint8 frame is repeated instead.
    """
    capture = cv2.VideoCapture(str(video_path))
    collected = []
    frame_no = -1
    while len(collected) < max_f:
        ok, image = capture.read()
        if not ok:
            break
        frame_no += 1
        if frame_no < skip_first:
            continue  # drop leading frames
        collected.append(image)
    capture.release()

    if len(collected) < max_f:
        # Pad with one shared filler object, exactly like per-step appends of frames[-1].
        filler = collected[-1] if collected else np.zeros((448,448,3), dtype=np.uint8)
        collected.extend([filler] * (max_f - len(collected)))
    return collected

# Load one eval rollout per selected viewpoint -- the first .mp4 (sorted by
# filename) under vp_<index>/videos/task_0 -- plus the (theta_deg, phi_deg, rate)
# triple used to label its grid cell. Viewpoints without a video get
# max_frames black 448x448 frames.
all_seqs = []
vid_labels = []
for vp in selected_vis:
    clip_dir = results_dir / f"vp_{vp}" / "videos" / "task_0"
    clips = sorted(clip_dir.glob("*.mp4")) if clip_dir.exists() else []
    all_seqs.append(
        extract_frames(clips[0]) if clips
        else [np.zeros((448,448,3), dtype=np.uint8)] * max_frames
    )

    # "rate" for this viewpoint from grid_results.json; missing entries read as 0.
    vp_rate = results.get(str(vp), {}).get("rate", 0)
    theta_idx, phi_idx = divmod(vp, n_views)
    vid_labels.append((thetas_deg[theta_idx], phis_deg[phi_idx], vp_rate))

# Label/border colours as (B, G, R) tuples, OpenCV's channel order.
GREEN = (80, 220, 80)
RED = (80, 80, 220)
YELLOW = (50, 190, 250)
WHITE = (255, 255, 255)

def rate_color(r):
    """Return the colour for rate ``r``: > 0.5 → GREEN, > 0.2 → YELLOW, otherwise RED."""
    for threshold, colour in ((0.5, GREEN), (0.2, YELLOW)):
        if r > threshold:
            return colour
    return RED

# Build grid video with labels.
# Layout: a col_header_h-pixel strip of phi headers on top, a row_header_w-pixel
# column of theta headers on the left, then grid_size x grid_size cells, each a
# thumb x thumb thumbnail above a label_h-pixel rate label.
label_h = 20
col_header_h = 20
row_header_w = 50
cell_h = thumb + label_h
cell_w = thumb

# The canvas size and all headers are the same in every frame. Size the canvas
# and draw the headers once, then copy this background per frame. Headers sit
# outside the cell/label areas, so drawing them first gives the same pixels.
canvas_h = col_header_h + grid_size * cell_h
canvas_w = row_header_w + grid_size * cell_w
header_canvas = np.zeros((canvas_h, canvas_w, 3), dtype=np.uint8)

# Column headers (phi)
for c in range(grid_size):
    pi = phi_sample[c]
    x = row_header_w + c * cell_w + cell_w // 2 - 15
    cv2.putText(header_canvas, f"phi={phis_deg[pi]:.0f}", (x, col_header_h - 5),
                cv2.FONT_HERSHEY_SIMPLEX, 0.35, (180,180,180), 1)

# Row headers (theta)
for r in range(grid_size):
    ti = theta_sample[r]
    y_base = col_header_h + r * cell_h
    cv2.putText(header_canvas, f"th={thetas_deg[ti]:.0f}", (2, y_base + cell_h // 2),
                cv2.FONT_HERSHEY_SIMPLEX, 0.33, (180,180,180), 1)

# A cell's colour and label text depend only on its rate, so compute them once.
cell_styles = [(rate_color(rate), f"{rate*100:.0f}%") for _, _, rate in vid_labels]

video_frames = []
for t in range(max_frames):
    canvas = header_canvas.copy()
    for r in range(grid_size):
        y = col_header_h + r * cell_h
        for c in range(grid_size):
            idx = r * grid_size + c
            color, lbl = cell_styles[idx]
            img = cv2.resize(all_seqs[idx][t], (thumb, thumb))

            # 2-pixel border in the rate colour
            img[:2, :] = color; img[-2:, :] = color
            img[:, :2] = color; img[:, -2:] = color

            x = row_header_w + c * cell_w
            canvas[y:y+thumb, x:x+thumb] = img

            # Rate label below the thumbnail (fixed offset, roughly centred)
            cv2.putText(canvas, lbl, (x + thumb//2 - 12, y + thumb + label_h - 5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, color, 1)

    video_frames.append(canvas)

# Save as H.264 video. OpenCV first writes an intermediate 'mp4v' file at 5 fps,
# then ffmpeg re-encodes it to H.264 with +faststart.
h, w = video_frames[0].shape[:2]
tmp = f"/tmp/grid_{EXPERIMENT_NAME}.mp4"
out = str(results_dir / "rollout_grid.mp4")
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(tmp, fourcc, 5, (w, h))
if not writer.isOpened():
    # Without this check, an unopened writer silently drops every frame.
    raise RuntimeError(f"cv2.VideoWriter could not open {tmp} (fourcc 'mp4v')")
for f in video_frames:
    writer.write(f)
writer.release()

# Pass an argument list (no shell) so paths with spaces or shell metacharacters
# are safe. Check ffmpeg's exit status instead of always reporting success.
ffmpeg_cmd = ["ffmpeg", "-y", "-i", tmp, "-c:v", "libx264", "-preset", "ultrafast",
              "-crf", "23", "-movflags", "+faststart", out]
try:
    encoded = subprocess.run(ffmpeg_cmd, stdin=subprocess.DEVNULL,
                             stderr=subprocess.DEVNULL).returncode == 0
except FileNotFoundError:  # ffmpeg is not installed / not on PATH
    encoded = False
if encoded:
    os.remove(tmp)
    print(f"Saved: {out}")
else:
    print(f"WARNING: ffmpeg re-encode failed; intermediate video left at {tmp}")

# Also save a static thumbnail (first frame) as PNG for the website
cv2.imwrite(str(results_dir / "rollout_grid_thumb.png"), video_frames[0])
print(f"Saved thumbnail: {results_dir / 'rollout_grid_thumb.png'}")

# NOTE(review): per-cell GIF export ("fallback for browsers") was planned here
# but is not implemented.
print("Done!")
