"""
Generate a 2x2 grid video comparing ACT vs PARA rollouts + feature PCAs.

Layout:
  ┌───────────────┬──────────────┐
  │  ACT Rollout  │  ACT PCA     │
  ├───────────────┼──────────────┤
  │  PARA Rollout │  PARA PCA    │
  └───────────────┴──────────────┘

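Both policies run on the same OOD condition (shifted object positions).
DINO backbone patch features are extracted at each replan step and visualized with a
single PCA fitted jointly over both rollouts, so colors are comparable across panels.
A PNG grid comparing the last-block CLS self-attention heads of both backbones is also
saved, computed from the first frame of the PARA rollout.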

Usage:
    export PYTHONPATH=/data/cameron/LIBERO:/data/cameron/para_normalized_losses:$PYTHONPATH
    export DINO_REPO_DIR=/data/cameron/keygrip/dinov3
    export DINO_WEIGHTS_PATH=/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth
    python ood_libero/generate_feature_comparison.py
"""

import os, sys, argparse
import numpy as np
import torch
import torch.nn.functional as F
import cv2
from pathlib import Path
from sklearn.decomposition import PCA

# Use the para_normalized_losses model files (matches checkpoints)
sys.path.insert(0, "/data/cameron/para_normalized_losses/libero")

from libero.libero.envs import OffScreenRenderEnv
from libero.libero import benchmark as bm_lib, get_libero_path
import h5py

# Side length (px) of the square camera renders, model inputs, and each video/grid cell.
IMAGE_SIZE = 448
# ImageNet per-channel (R, G, B) mean/std used to normalize images before the DINO backbone.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def preprocess_obs(rgb_obs, size=448):
    """Turn a raw LIBERO camera frame into a normalized (1, 3, size, size) float tensor.

    The frame is flipped vertically (the same flip get_clean_rgb applies),
    bilinearly resized to ``size`` x ``size``, scaled to [0, 1] and
    normalized with the ImageNet mean/std.
    """
    upright = np.flipud(rgb_obs).copy()
    resized = cv2.resize(upright, (size, size), interpolation=cv2.INTER_LINEAR)
    normalized = (resized.astype(np.float32) / 255.0 - IMAGENET_MEAN) / IMAGENET_STD
    chw = torch.from_numpy(normalized).permute(2, 0, 1)
    return chw.unsqueeze(0).float()


def extract_backbone_features(model, img_tensor, model_type):
    """Run a policy's DINO backbone by hand and return its spatial patch features.

    Both PARA and ACT expose the backbone as ``model.dino``, so the same token
    pipeline works for either; ``model_type`` is accepted for call-site
    symmetry but is not used here.

    Args:
        model: policy whose ``.dino`` attribute is the backbone.
        img_tensor: preprocessed image batch; only batch element 0 is returned.
        model_type: "act" or "para" (unused).

    Returns:
        (H_p, W_p, C) numpy array of final-norm patch tokens, with the CLS
        token and any storage/register tokens stripped off.
    """
    backbone = model.dino
    # Tokens ahead of the patches: 1 CLS + optional storage tokens.
    prefix_len = 1 + (backbone.n_storage_tokens if hasattr(backbone, 'n_storage_tokens') else 0)
    with torch.no_grad():
        tokens, (grid_h, grid_w) = backbone.prepare_tokens_with_masks(img_tensor)
        for block in backbone.blocks:
            # RoPE tables are requested per block, mirroring the backbone's own forward.
            sincos = backbone.rope_embed(H=grid_h, W=grid_w) if backbone.rope_embed else None
            tokens = block(tokens, sincos)
        tokens = backbone.norm(tokens)
        patch_grid = tokens[:, prefix_len:][0].reshape(grid_h, grid_w, -1)
        result = patch_grid.cpu().numpy()
    return result


def get_clean_rgb(obs, camera="agentview"):
    """Return the ``camera`` frame from ``obs``, flipped vertically and resized to IMAGE_SIZE x IMAGE_SIZE.

    The output is the un-normalized image used for display (no ImageNet scaling).
    """
    raw_frame = obs[f"{camera}_image"]
    upright = np.flipud(raw_frame).copy()
    return cv2.resize(upright, (IMAGE_SIZE, IMAGE_SIZE), interpolation=cv2.INTER_LINEAR)


def compute_joint_pca(all_features, target_size):
    """Fit one 3-component PCA over every frame's patch features and render each frame as RGB.

    The PCA is fitted jointly on the patches of all frames, e.g. both the ACT
    and PARA rollouts, rather than per frame. All frames therefore share one
    color space, and colors are comparable across frames. Each component is
    min-max normalized over the whole stack, not per frame.

    Args:
        all_features: sequence of (H_p, W_p, C) arrays; all must share one shape.
        target_size: (H, W) size each PCA map is bilinearly upsampled to.

    Returns:
        list of (H, W, 3) uint8 images, one per input feature map, in input
        order. An empty input yields an empty list.

    Raises:
        ValueError: if the feature maps do not all share the same shape, which
            would otherwise silently mis-slice the stacked PCA output.
    """
    if len(all_features) == 0:
        return []

    H, W = target_size
    H_p, W_p, C = all_features[0].shape
    for idx, feat in enumerate(all_features):
        if feat.shape != (H_p, W_p, C):
            raise ValueError(
                f"feature map {idx} has shape {feat.shape}, expected {(H_p, W_p, C)}"
            )

    stacked = np.concatenate([f.reshape(-1, C) for f in all_features], axis=0)
    transformed = PCA(n_components=3).fit_transform(stacked)

    # Global (not per-frame) min-max per component keeps colors comparable.
    vmin, vmax = transformed.min(axis=0), transformed.max(axis=0)
    rng = vmax - vmin
    rng[rng == 0] = 1.0  # constant component: avoid divide-by-zero
    transformed = (transformed - vmin) / rng

    # Rows were stacked frame by frame in row-major patch order, so a reshape
    # recovers each frame's (H_p, W_p, 3) map.
    per_frame = transformed.reshape(len(all_features), H_p, W_p, 3)

    pca_images = []
    for pca_rgb in per_frame:
        pca_up = cv2.resize(pca_rgb, (W, H), interpolation=cv2.INTER_LINEAR)
        pca_images.append((pca_up * 255).clip(0, 255).astype(np.uint8))

    return pca_images


def add_label(frame, text, position="top", font_scale=0.55, thickness=2):
    """Return a copy of ``frame`` with ``text`` centered horizontally on a filled black box.

    ``position == "top"`` puts the label near the top edge; any other value
    puts it near the bottom edge. The input frame is not modified.
    """
    out = frame.copy()
    height, width = out.shape[:2]
    face = cv2.FONT_HERSHEY_SIMPLEX
    (text_w, text_h), baseline = cv2.getTextSize(text, face, font_scale, thickness)
    x = (width - text_w) // 2
    y = text_h + 10 if position == "top" else height - 12
    box_top_left = (x - 4, y - text_h - 4)
    box_bottom_right = (x + text_w + 4, y + baseline + 4)
    cv2.rectangle(out, box_top_left, box_bottom_right, (0, 0, 0), -1)
    cv2.putText(out, text, (x, y), face, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)
    return out


def add_badge(frame, text, success=True):
    """Return a copy of ``frame`` with ``text`` on a filled badge in the top-right corner.

    The badge fill is (0, 160, 0) when ``success`` is truthy and (0, 0, 180)
    otherwise. Under OpenCV's BGR ordering these are green and red.
    """
    out = frame.copy()
    _, width = out.shape[:2]
    face, scale, weight = cv2.FONT_HERSHEY_SIMPLEX, 0.45, 1
    (text_w, text_h), baseline = cv2.getTextSize(text, face, scale, weight)
    fill = (0, 160, 0) if success else (0, 0, 180)
    x = width - text_w - 12
    y = text_h + 10
    cv2.rectangle(out, (x - 4, y - text_h - 4), (x + text_w + 4, y + baseline + 4), fill, -1)
    cv2.putText(out, text, (x, y), face, scale, (255, 255, 255), weight, cv2.LINE_AA)
    return out


def _hide_scene_clutter(sim):
    """Best-effort hiding of furniture and distractor objects before a rollout.

    Furniture bodies are moved to (0, 0, -5). Every geom belonging to a
    distractor body gets alpha 0. Body names missing from this scene are
    skipped. ``sim.forward()`` is called afterwards.
    """
    for name in ["wooden_cabinet_1_main", "flat_stove_1_main"]:
        try:
            bid = sim.model.body_name2id(name)
            sim.model.body_pos[bid] = np.array([0, 0, -5.0])
        except Exception:
            pass  # body not present in this scene; nothing to hide
    distractor_names = ["akita_black_bowl_2_main", "cookies_1_main",
                        "glazed_rim_porcelain_ramekin_1_main"]
    distractor_bids = set()
    for name in distractor_names:
        try:
            distractor_bids.add(sim.model.body_name2id(name))
        except Exception:
            pass  # body not present in this scene
    for geom_id in range(sim.model.ngeom):
        if sim.model.geom_bodyid[geom_id] in distractor_bids:
            sim.model.geom_rgba[geom_id][3] = 0.0
    sim.forward()


def run_policy_rollout(env, model, model_type, init_state, device,
                       shift_dx=0, shift_dy=0, cam_theta=0, cam_phi=0,
                       max_steps=300, camera="agentview"):
    """Run a single episode using teleport + zero-rotation servo execution.

    Setup:
      - Shift the pick (bowl) and place (plate) object states in a copy of ``init_state``.
      - Hide furniture and distractors, and optionally re-pose the camera.
      - Settle the scene for 5 no-op steps.

    Then, at every replan step:
      - record the current frame and its DINO backbone features;
      - predict a window of actions and absolute 3D target positions;
      - for each target, servo the end effector toward it (at most 25 env steps,
        stopping within 5 mm) with zero rotation while holding the current
        gripper command, then take one extra env step applying the new gripper.

    Args:
        env: LIBERO OffScreenRenderEnv.
        model: ACT or PARA policy exposing its backbone as ``.dino``.
        model_type: "act" selects the ACT input/decoding path; any other value
            uses the PARA path.
        init_state: flat simulator state vector; copied, never mutated.
        device: torch device for model inputs.
        shift_dx, shift_dy: offsets added to the pick/place object state entries.
        cam_theta, cam_phi: camera re-posing angles; both 0 leaves the camera as is.
        max_steps: cap on env steps taken after settling.
        camera: camera whose images feed the policy.

    Returns:
        rgb_frames: list of (H, W, 3) uint8 frames, one per replan step.
        features: list of (H_p, W_p, C) numpy arrays, one per replan step.
        success: True if the env reported ``done`` before the step budget ran out.
    """
    import model as model_module
    from eval import (get_camera_params, eef_to_start_kp, decode_window_actions,
                      decode_act_actions, _reposition_camera)

    # Servo parameters for reaching each predicted 3D target.
    max_servo = 25
    threshold = 0.005   # 5mm
    delta_scale = 0.05  # position error that maps to a full-scale (+/-1) action

    env.reset()

    # Shift object position (OOD condition)
    state = init_state.copy()
    pick_obj_indices = [10]   # bowl
    place_obj_indices = [38]  # plate
    for si in pick_obj_indices + place_obj_indices:
        state[si] += shift_dx
        state[si + 1] += shift_dy

    obs = env.set_init_state(state)

    # Clean scene: hide furniture and distractors
    _hide_scene_clutter(env.env.sim)

    if cam_theta != 0 or cam_phi != 0:
        _reposition_camera(env.sim, camera, cam_theta, cam_phi)

    # Settle
    for _ in range(5):
        obs, _, _, _ = env.step(np.zeros(7, dtype=np.float32))

    world_to_camera, camera_pose, cam_K = get_camera_params(env.sim, camera, IMAGE_SIZE)

    # ACT is additionally conditioned on a precomputed CLIP embedding, if present.
    clip_embedding = None
    if model_type == "act":
        clip_path = "/data/libero/parsed_libero/libero_spatial/task_0_clip.pt"
        if os.path.exists(clip_path):
            clip_embedding = torch.load(clip_path, map_location=device).unsqueeze(0)

    rgb_frames = []
    feat_maps = []
    step_idx = 0
    done = False
    current_gripper_cmd = -1.0  # gripper command held while servoing, across the whole episode

    while step_idx < max_steps and not done:
        current_eef_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)
        current_eef_quat = np.array(obs["robot0_eef_quat"], dtype=np.float64)
        rgb_obs = obs[f"{camera}_image"]

        rgb_clean = get_clean_rgb(obs, camera)

        start_kp = eef_to_start_kp(current_eef_pos, world_to_camera, IMAGE_SIZE).to(device)
        img_tensor = preprocess_obs(rgb_obs, IMAGE_SIZE).to(device)

        # Backbone features are taken from the same input the policy sees.
        feats = extract_backbone_features(model, img_tensor, model_type)

        rgb_frames.append(rgb_clean)
        feat_maps.append(feats)

        # Run policy to get actions and absolute 3D targets
        if model_type == "act":
            min_pos = np.array(model_module.MIN_POS, dtype=np.float64)
            max_pos = np.array(model_module.MAX_POS, dtype=np.float64)
            eef_norm = torch.tensor(
                (current_eef_pos - min_pos) / (max_pos - min_pos + 1e-8),
                dtype=torch.float32, device=device
            ).clamp(0, 1).unsqueeze(0)
            grip_state = float(obs.get("robot0_gripper_qpos", [0, 0])[0])
            grip_norm = torch.tensor(
                [(grip_state - model_module.MIN_GRIPPER) / (model_module.MAX_GRIPPER - model_module.MIN_GRIPPER + 1e-8)],
                dtype=torch.float32, device=device
            ).clamp(0, 1).unsqueeze(0)
            extra = {}
            if clip_embedding is not None:
                extra['clip_embedding'] = clip_embedding
            with torch.no_grad():
                pos_pred, rot_pred, gripper_pred = model(
                    img_tensor, start_kp,
                    current_eef_pos=eef_norm,
                    current_gripper=grip_norm,
                    **extra,
                )
            window_actions = decode_act_actions(
                pos_pred, rot_pred, gripper_pred,
                current_eef_pos, current_eef_quat,
            )
            # Denormalize predicted positions into absolute 3D targets for teleport mode
            pos_denorm = pos_pred[0].cpu().numpy().astype(np.float64) * (max_pos - min_pos) + min_pos
            pred_3d_targets = [pos_denorm[t] for t in range(pos_denorm.shape[0])]
        else:
            with torch.no_grad():
                volume_logits, _, _, model_feats = model(img_tensor, start_kp)
            window_actions, pred_3d_targets = decode_window_actions(
                volume_logits, model, model_feats,
                camera_pose, cam_K, current_eef_pos, current_eef_quat,
                image_size=IMAGE_SIZE,
            )

        window_actions = list(window_actions)
        if not window_actions:
            # Nothing to execute: step_idx would never advance, so replanning
            # on the same observation would loop (and grow the frame lists) forever.
            break

        # Execute window using teleport + zero-rotation servo
        for t, action in enumerate(window_actions):
            target_pos = pred_3d_targets[t].astype(np.float64)
            new_gripper = action[6]

            # Phase 1: servo to target position with zero rotation, holding current gripper
            for _ in range(max_servo):
                cur_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)
                delta = target_pos - cur_pos
                if np.linalg.norm(delta) < threshold:
                    break
                servo_action = np.zeros(7, dtype=np.float32)
                servo_action[:3] = np.clip(delta / delta_scale, -1.0, 1.0)
                # Rotation entries stay zero on purpose.
                servo_action[6] = current_gripper_cmd
                obs, _, done, _ = env.step(servo_action)
                step_idx += 1
                if done or step_idx >= max_steps:
                    break

            # Phase 2: at target position, apply the new gripper command
            if not done and step_idx < max_steps:
                grip_action = np.zeros(7, dtype=np.float32)
                grip_action[6] = new_gripper
                obs, _, done, _ = env.step(grip_action)
                step_idx += 1
                current_gripper_cmd = new_gripper

            if done or step_idx >= max_steps:
                break

    # Every loop exits as soon as the env reports done, so the final flag is the outcome.
    success = bool(done)
    return rgb_frames, feat_maps, success


def extract_cls_attention_map(dino, img_tensor):
    """Return per-head CLS-to-patch attention weights from the last block of a DINO backbone.

    DINOv3 attention uses ``scaled_dot_product_attention``, which never
    materializes the attention weights. A forward hook on the last block's
    attention module captures its input tokens during a normal forward pass.
    softmax(Q K^T / sqrt(d)) is then recomputed by hand from the module's own
    ``qkv`` projection and RoPE.

    Args:
        dino: DINO backbone (DinoVisionTransformer).
        img_tensor: (1, 3, H, W) preprocessed image tensor.

    Returns:
        (num_heads, H_p, W_p) numpy array holding, for batch element 0, the
        attention of the CLS token (index 0) to each patch token. The weights
        for CLS and storage tokens are dropped, so a head's values need not
        sum to 1.

    Raises:
        RuntimeError: if the forward pass never reached the last block's attention.
    """
    captured = {}

    def hook_fn(module, inputs, output):
        # The block calls attn(norm1(x), rope=...): the first positional input
        # is already the normalized token sequence the attention consumes.
        captured['x'] = inputs[0]

    last_block = dino.blocks[-1]
    attn_module = last_block.attn
    handle = attn_module.register_forward_hook(hook_fn)
    try:
        with torch.no_grad():
            x_tokens, (H_p, W_p) = dino.prepare_tokens_with_masks(img_tensor)
            for blk in dino.blocks:
                rope_sincos = dino.rope_embed(H=H_p, W=W_p) if dino.rope_embed is not None else None
                x_tokens = blk(x_tokens, rope_sincos)
    finally:
        # Always detach the hook, even if the forward pass raised, so the
        # backbone isn't left with a stale hook for later calls.
        handle.remove()

    if 'x' not in captured:
        raise RuntimeError("attention hook on the last DINO block did not fire")

    x = captured['x']
    B, N, C = x.shape
    num_heads = attn_module.num_heads
    head_dim = C // num_heads

    with torch.no_grad():
        qkv = attn_module.qkv(x).reshape(B, N, 3, num_heads, head_dim)
        q, k, _ = torch.unbind(qkv, 2)
        q, k = q.transpose(1, 2), k.transpose(1, 2)  # (B, heads, N, head_dim)

        # Apply RoPE exactly as the block's forward pass did.
        if dino.rope_embed is not None:
            q, k = attn_module.apply_rope(q, k, dino.rope_embed(H=H_p, W=W_p))

        scale = head_dim ** -0.5
        logits = torch.matmul(q * scale, k.transpose(-2, -1))  # (B, heads, N, N)
        attn_weights = torch.softmax(logits.float(), dim=-1)

    # Row 0 is the CLS query; drop the CLS + storage key columns to keep patches only.
    n_prefix = 1 + getattr(dino, 'n_storage_tokens', 0)
    cls_attn = attn_weights[0, :, 0, n_prefix:]  # (heads, num_patches)
    return cls_attn.reshape(num_heads, H_p, W_p).cpu().numpy()


def generate_attention_map_grid(act_model, para_model, first_frame_rgb, device, output_path):
    """Save a one-row PNG comparing last-block CLS self-attention of the ACT and PARA backbones.

    Layout (left to right): RGB | ACT head 1..n_act | PARA head 1..n_para.
    Each head is min-max normalized independently and drawn with the INFERNO
    colormap. Head counts are read from each backbone separately, so the two
    backbones need not match.

    Args:
        act_model: loaded ACT model (backbone at ``.dino``).
        para_model: loaded PARA model (backbone at ``.dino``).
        first_frame_rgb: (H, W, 3) uint8 RGB frame, already flipped/resized
            (e.g. by get_clean_rgb). Every grid cell uses this frame's size.
        device: torch device for the backbone input.
        output_path: PNG destination. Parent directories are created if needed.
    """
    # Same normalization as preprocess_obs, minus the flip/resize (already applied).
    img_f = first_frame_rgb.astype(np.float32) / 255.0
    img_f = (img_f - IMAGENET_MEAN) / IMAGENET_STD
    img_tensor = torch.from_numpy(img_f).permute(2, 0, 1).unsqueeze(0).float().to(device)

    act_attn = extract_cls_attention_map(act_model.dino, img_tensor)    # (heads, H_p, W_p)
    para_attn = extract_cls_attention_map(para_model.dino, img_tensor)  # (heads, H_p, W_p)

    cell_h, cell_w = first_frame_rgb.shape[:2]

    def attn_to_heatmap(attn_head):
        """Min-max normalize one (H_p, W_p) head, upsample to the cell size, and color it (BGR)."""
        a = attn_head.astype(np.float32)
        a = (a - a.min()) / (a.max() - a.min() + 1e-8)
        a_up = cv2.resize(a, (cell_w, cell_h), interpolation=cv2.INTER_LINEAR)
        return cv2.applyColorMap((a_up * 255).astype(np.uint8), cv2.COLORMAP_INFERNO)

    panels = [add_label(cv2.cvtColor(first_frame_rgb, cv2.COLOR_RGB2BGR), "RGB", "top")]
    panels += [add_label(attn_to_heatmap(head), f"ACT head {h + 1}", "top")
               for h, head in enumerate(act_attn)]
    panels += [add_label(attn_to_heatmap(head), f"PARA head {h + 1}", "top")
               for h, head in enumerate(para_attn)]
    grid = np.hstack(panels)

    out_dir = os.path.dirname(output_path)
    if out_dir:  # a bare filename has no directory to create
        os.makedirs(out_dir, exist_ok=True)
    if cv2.imwrite(output_path, grid):
        print(f"Saved attention map grid: {output_path}")
    else:
        print(f"WARNING: failed to write attention map grid: {output_path}")


# Normalization-stat globals on the `model` module that checkpoints may override.
_NORM_STAT_KEYS = ("MIN_POS", "MAX_POS", "MIN_HEIGHT", "MAX_HEIGHT", "MIN_GRIPPER",
                   "MAX_GRIPPER", "MIN_ROT", "MAX_ROT", "REF_ROT")


def _restore_norm_stats(model_module, ckpt):
    """Overwrite ``model_module`` normalization globals with the values stored in ``ckpt``.

    Checkpoints store each stat under its lower-case name. Stats absent from
    the checkpoint keep their current module value. run_policy_rollout reads
    some of these globals (e.g. MIN_POS/MAX_POS for ACT), so call this for the
    matching checkpoint before each rollout.
    """
    for key in _NORM_STAT_KEYS:
        if key.lower() in ckpt:
            setattr(model_module, key, ckpt[key.lower()])


def _reencode_h264(path):
    """Re-encode the video at ``path`` to H.264 in place using ffmpeg.

    Returns True if the file was replaced. On failure, including ffmpeg not
    being installed, the original file is left untouched and any partial
    output is removed.
    """
    import subprocess

    src = Path(path)
    # Always distinct from `src`; keep a suffix so ffmpeg can infer the container.
    tmp = src.with_name(f"{src.stem}_h264{src.suffix or '.mp4'}")
    cmd = ["ffmpeg", "-y", "-i", str(src), "-c:v", "libx264", "-preset", "ultrafast",
           "-crf", "23", "-movflags", "+faststart", str(tmp)]
    try:
        ret = subprocess.run(cmd, stderr=subprocess.DEVNULL).returncode
    except FileNotFoundError:  # ffmpeg not on PATH
        ret = 1
    if ret == 0:
        os.replace(tmp, src)
        return True
    if tmp.exists():
        tmp.unlink()
    return False


def main():
    """Run ACT and PARA on one OOD scene, then write the 2x2 comparison video and attention grid."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_path", type=str,
                        default="/data/cameron/para/.agents/reports/project_site/media/feature_pca_comparison.mp4")
    parser.add_argument("--attn_grid_path", type=str,
                        default="/data/cameron/para/.agents/reports/project_site/media/attention_map_grid.png",
                        help="Output PNG for the CLS self-attention comparison grid")
    parser.add_argument("--shift_dx", type=float, default=-0.08,
                        help="Object position shift applied alongside --shift_dy (OOD)")
    parser.add_argument("--shift_dy", type=float, default=0.18,
                        help="Object position shift (OOD)")
    parser.add_argument("--fps", type=int, default=5,
                        help="Output FPS (each frame is one replan step)")
    parser.add_argument("--max_steps", type=int, default=300)
    parser.add_argument("--episode_idx", type=int, default=1,
                        help="Which init state to use (1 gave PARA success before)")
    args = parser.parse_args()
    if args.max_steps <= 0:
        parser.error("--max_steps must be positive")
    if args.fps <= 0:
        parser.error("--fps must be positive")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # ── Load both models ──
    act_ckpt = "/data/cameron/para_normalized_losses/libero/checkpoints/act_v2_exp4_n64/best.pth"
    act_sd = torch.load(act_ckpt, map_location=device)
    # Assumes the ACT position head emits 3 values (xyz) per window step.
    act_n_window = act_sd["model_state_dict"]["pos_mlp.5.weight"].shape[0] // 3
    print(f"Loading ACT model (N_WINDOW={act_n_window})...")
    import model as model_module
    from model_act import ACTPredictor
    act_model = ACTPredictor(n_window=act_n_window)
    act_model.load_state_dict(act_sd["model_state_dict"])
    act_model = act_model.to(device).eval()
    _restore_norm_stats(model_module, act_sd)

    para_ckpt = "/data/cameron/para_normalized_losses/libero/checkpoints/para_v2_exp4_n64/best.pth"
    para_sd = torch.load(para_ckpt, map_location=device)
    # Assumes the PARA volume head has 32 output channels per window step.
    para_n_window = para_sd["model_state_dict"]["volume_head.weight"].shape[0] // 32
    print(f"Loading PARA model (N_WINDOW={para_n_window})...")
    model_module.N_WINDOW = para_n_window
    from model import TrajectoryHeatmapPredictor
    para_model = TrajectoryHeatmapPredictor(n_window=para_n_window)
    para_model.load_state_dict(para_sd["model_state_dict"])
    para_model = para_model.to(device).eval()
    _restore_norm_stats(model_module, para_sd)

    # ── Setup environment ──
    print("Setting up LIBERO environment...")
    benchmark = bm_lib.get_benchmark_dict()["libero_spatial"]()
    task = benchmark.get_task(0)
    bddl_file = os.path.join(get_libero_path("bddl_files"),
                             task.problem_folder, task.bddl_file)

    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file,
        camera_heights=IMAGE_SIZE,
        camera_widths=IMAGE_SIZE,
        camera_names=["agentview"],
    )
    try:
        env.seed(0)
        env.reset()

        demo_path = os.path.join(get_libero_path("datasets"),
                                 benchmark.get_task_demonstration(0))
        with h5py.File(demo_path, "r") as f:
            # NOTE(review): this takes row `episode_idx` (a timestep) of demo_0's
            # state trajectory, not the initial state of demo `episode_idx` —
            # confirm this is intended.
            init_state = np.array(f["data/demo_0/states"][args.episode_idx])

        # ── Run ACT rollout (with ACT normalization stats) ──
        print(f"\nRunning ACT rollout (shift_dx={args.shift_dx}, shift_dy={args.shift_dy})...")
        _restore_norm_stats(model_module, act_sd)
        act_rgbs, act_feats, act_success = run_policy_rollout(
            env, act_model, "act", init_state, device,
            shift_dx=args.shift_dx, shift_dy=args.shift_dy, max_steps=args.max_steps,
        )
        print(f"  ACT: {len(act_rgbs)} replan steps, {'SUCCESS' if act_success else 'FAILURE'}")

        # ── Run PARA rollout (with PARA normalization stats) ──
        print(f"Running PARA rollout (shift_dx={args.shift_dx}, shift_dy={args.shift_dy})...")
        _restore_norm_stats(model_module, para_sd)
        para_rgbs, para_feats, para_success = run_policy_rollout(
            env, para_model, "para", init_state, device,
            shift_dx=args.shift_dx, shift_dy=args.shift_dy, max_steps=args.max_steps,
        )
        print(f"  PARA: {len(para_rgbs)} replan steps, {'SUCCESS' if para_success else 'FAILURE'}")

        # ── Generate attention map grid from the first OOD frame ──
        print("Generating attention map grid...")
        first_frame_rgb = para_rgbs[0] if para_rgbs else act_rgbs[0]
        generate_attention_map_grid(act_model, para_model, first_frame_rgb, device,
                                    args.attn_grid_path)
    finally:
        env.close()

    # ── Joint PCA across ALL features from BOTH runs (shared color space) ──
    print("Computing joint PCA...")
    all_pca_images = compute_joint_pca(act_feats + para_feats, (IMAGE_SIZE, IMAGE_SIZE))
    act_pca_images = all_pca_images[:len(act_feats)]
    para_pca_images = all_pca_images[len(act_feats):]

    # ── Sync frame counts (pad the shorter run with its last frame) ──
    max_frames = max(len(act_rgbs), len(para_rgbs))
    for rgbs, pcas in ((act_rgbs, act_pca_images), (para_rgbs, para_pca_images)):
        while len(rgbs) < max_frames:
            rgbs.append(rgbs[-1])
            pcas.append(pcas[-1])

    # ── Compose 2x2 grid video ──
    print(f"Composing {max_frames}-frame 2x2 grid video...")
    cell_size = IMAGE_SIZE
    grid_w = cell_size * 2
    grid_h = cell_size * 2

    out_dir = os.path.dirname(args.output_path)
    if out_dir:  # a bare filename has no directory to create
        os.makedirs(out_dir, exist_ok=True)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(args.output_path, fourcc, args.fps, (grid_w, grid_h))
    if not writer.isOpened():
        raise RuntimeError(f"Could not open video writer for {args.output_path}")

    # Badges reflect the actual rollout outcomes rather than a fixed narrative.
    act_badge = "SUCCESS" if act_success else "FAILURE"
    para_badge = "SUCCESS" if para_success else "FAILURE"

    # The caption is identical on every frame, so measure it once.
    font = cv2.FONT_HERSHEY_SIMPLEX
    caption = "Same backbone. Same features. Different action head."
    (cap_w, cap_h), cap_bl = cv2.getTextSize(caption, font, 0.6, 2)
    cap_x = (grid_w - cap_w) // 2
    cap_y = grid_h - 15

    try:
        for act_rgb, act_pca, para_rgb, para_pca in zip(
                act_rgbs, act_pca_images, para_rgbs, para_pca_images):
            # Frames are RGB; OpenCV writes BGR.
            act_rgb_bgr = cv2.cvtColor(act_rgb, cv2.COLOR_RGB2BGR)
            act_pca_bgr = cv2.cvtColor(act_pca, cv2.COLOR_RGB2BGR)
            para_rgb_bgr = cv2.cvtColor(para_rgb, cv2.COLOR_RGB2BGR)
            para_pca_bgr = cv2.cvtColor(para_pca, cv2.COLOR_RGB2BGR)

            act_rgb_bgr = add_label(act_rgb_bgr, "ACT Rollout", "top")
            act_rgb_bgr = add_badge(act_rgb_bgr, act_badge, success=act_success)
            act_pca_bgr = add_label(act_pca_bgr, "ACT Features (DINO PCA)", "top")

            para_rgb_bgr = add_label(para_rgb_bgr, "PARA Rollout", "top")
            para_rgb_bgr = add_badge(para_rgb_bgr, para_badge, success=para_success)
            para_pca_bgr = add_label(para_pca_bgr, "PARA Features (DINO PCA)", "top")

            grid = np.vstack([np.hstack([act_rgb_bgr, act_pca_bgr]),
                              np.hstack([para_rgb_bgr, para_pca_bgr])])

            cv2.rectangle(grid, (cap_x - 6, cap_y - cap_h - 6),
                          (cap_x + cap_w + 6, cap_y + cap_bl + 6), (0, 0, 0), -1)
            cv2.putText(grid, caption, (cap_x, cap_y), font, 0.6, (255, 255, 255), 2, cv2.LINE_AA)

            writer.write(grid)
    finally:
        writer.release()

    # Re-encode H.264 (mp4v output is poorly supported by browsers)
    _reencode_h264(args.output_path)

    print(f"Saved: {args.output_path}")
    print(f"Duration: {max_frames / args.fps:.1f}s at {args.fps} fps")


if __name__ == "__main__":
    main()
