"""overlay_first_frames.py — Overlay first frames from all demos to visualize object position variance.

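Loads frame 000000.png from each demo under PARSED_DIR (LIBERO-Spatial task 0) and writes
the following to an `out/` directory next to this script:
  1. first_frame_mean.png — per-pixel mean image (static parts stay sharp; objects whose
     placement varies across demos appear blurred/ghosted)
  2. first_frame_overlay_only.png — normalized per-pixel variance heatmap blended onto the
     mean (highlights where objects differ across demos)
  3. first_frame_overlay.png — side-by-side figure of the mean, the variance, and the overlay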

Usage:
    python ood_libero/overlay_first_frames.py
"""

import glob
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np

# Directory of parsed demos for LIBERO-Spatial task 0. main() globs
# demo_*/frames/000000.png beneath it. This is a machine-specific absolute
# path; adjust it for your setup.
PARSED_DIR = "/data/libero/parsed_libero/libero_spatial/task_0"


def _load_rgb_frames(paths):
    """Read every image in *paths* and stack them as float32 RGB.

    Args:
        paths: Image file paths readable by ``cv2.imread``.

    Returns:
        Array of shape (N, H, W, 3), dtype float32, with values in [0, 255].

    Raises:
        OSError: If a file cannot be decoded. ``cv2.imread`` returns None
            rather than raising, so the check is explicit here.
        ValueError: If the frames do not all share the same shape.
    """
    frames = []
    for p in paths:
        bgr = cv2.imread(p)  # default IMREAD_COLOR -> 3-channel BGR
        if bgr is None:
            raise OSError(f"Could not read image: {p}")
        frames.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32))
    shapes = {f.shape for f in frames}
    if len(shapes) > 1:
        raise ValueError(f"First frames have inconsistent shapes: {sorted(shapes)}")
    return np.stack(frames)


def _mean_variance_overlay(imgs, alpha):
    """Compute the mean image, normalized variance map and heatmap overlay.

    Args:
        imgs: Stack of RGB frames, shape (N, H, W, 3), values in [0, 255].
        alpha: Weight of the variance heatmap in the blend. The mean image
            gets weight ``1 - alpha``.

    Returns:
        Tuple ``(mean_u8, var_norm, overlay)``:
        - mean_u8: (H, W, 3) uint8 per-pixel mean.
        - var_norm: (H, W) per-pixel variance averaged over RGB, divided by
          its maximum so values lie in [0, 1].
        - overlay: (H, W, 3) uint8 blend of mean_u8 and the 'hot'-colormapped
          var_norm.
    """
    mean_img = imgs.mean(axis=0)
    var_img = imgs.var(axis=0).mean(axis=-1)
    # The epsilon avoids 0/0 when every frame is identical (all-zero variance).
    var_norm = var_img / (var_img.max() + 1e-8)

    heat = (plt.cm.hot(var_norm)[:, :, :3] * 255).astype(np.uint8)  # drop alpha channel
    mean_u8 = np.clip(mean_img, 0, 255).astype(np.uint8)
    overlay = cv2.addWeighted(mean_u8, 1 - alpha, heat, alpha, 0)
    return mean_u8, var_norm, overlay


def main(parsed_dir=None, out_dir=None, alpha=0.5):
    """Visualize how object placement varies across demos' first frames.

    Loads ``demo_*/frames/000000.png`` under *parsed_dir*. It saves a
    three-panel figure (mean, normalized variance, overlay) plus standalone
    mean and overlay images into *out_dir*.

    Args:
        parsed_dir: Directory containing ``demo_*`` subfolders. Defaults to
            the module-level ``PARSED_DIR``.
        out_dir: Output directory. Defaults to ``out/`` next to this file.
            It is created if missing.
        alpha: Heatmap weight in the overlay blend, in [0, 1].

    Raises:
        ValueError: If *alpha* is outside [0, 1] or frame shapes differ.
        FileNotFoundError: If no first frames match under *parsed_dir*.
        OSError: If a frame cannot be read or an output image cannot be
            written.
    """
    if parsed_dir is None:
        parsed_dir = PARSED_DIR
    if not 0.0 <= alpha <= 1.0:
        raise ValueError(f"alpha must be in [0, 1], got {alpha!r}")

    # Escape the directory so glob metacharacters in the path are matched
    # literally; only the demo_* part is meant as a wildcard.
    pattern = f"{glob.escape(str(parsed_dir))}/demo_*/frames/000000.png"
    paths = sorted(glob.glob(pattern))
    print(f"Found {len(paths)} demos")
    if not paths:
        raise FileNotFoundError(f"No first frames matched {pattern!r}")

    out_dir = Path(__file__).resolve().parent / "out" if out_dir is None else Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    imgs = _load_rgb_frames(paths)
    mean_u8, var_norm, overlay = _mean_variance_overlay(imgs, alpha)

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    try:
        panels = (
            (mean_u8, {}, f"Mean of {len(paths)} first frames"),
            # Pin the color range to [0, 1] so this panel maps values exactly
            # like the overlay's heatmap (imshow would otherwise auto-scale).
            (var_norm, {"cmap": "hot", "vmin": 0.0, "vmax": 1.0},
             "Per-pixel variance (object positions)"),
            (overlay, {}, "Variance heatmap overlaid on mean"),
        )
        for ax, (img, imshow_kwargs, title) in zip(axes, panels):
            ax.imshow(img, **imshow_kwargs)
            ax.set_title(title, fontsize=11)
            ax.axis("off")

        fig.tight_layout()
        out_path = out_dir / "first_frame_overlay.png"
        fig.savefig(str(out_path), dpi=150, bbox_inches="tight")
    finally:
        plt.close(fig)  # close this specific figure, even if saving fails
    print(f"Saved: {out_path}")

    # Also save standalone images. cv2.imwrite signals failure by returning
    # False rather than raising.
    for name, rgb in (
        ("first_frame_mean.png", mean_u8),
        ("first_frame_overlay_only.png", overlay),
    ):
        dst = out_dir / name
        if not cv2.imwrite(str(dst), cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)):
            raise OSError(f"Failed to write {dst}")
    print("Done.")


if __name__ == "__main__":
    # main() runs only when this file is executed as a script, not on import.
    main()
