"""Sanity check the UMI Izzy Towel calibration / projection.

Loads the dataset, then for a sampling of frames per episode renders the RGB image
with:
  - red bullseye at the projected EEF pixel for the current frame
  - rainbow line through next N_TRAJ projected EEF pixels (color = time)

Also prints numerical stats: 3D bbox of EEF positions, pix bbox, and the correlation
between Δ3D and Δpix per axis (to verify the projection actually has the right sign).
"""
import os, sys, json
sys.path.insert(0, "/data/cameron/para/libero")
import numpy as np
import torch
import cv2
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from pathlib import Path
from data_da3_volume import Smith300DA3VolumeDataset, DA3_INPUT

# Destination of the rendered panel grid; its folder is created up front so the
# final savefig cannot fail on a missing directory.
OUT = Path("/data/cameron/para/paper/figs/generated/umi_izzy_towel_projection_check.png")
OUT.parent.mkdir(exist_ok=True, parents=True)

# Load only the UMI Izzy Towel session.
_ds_kwargs = dict(
    root_dir="/data/cameron/mac_robot_datasets",
    sessions_whitelist=["umi_collect_izzy_towel"],
    n_window=20,
    frame_stride=1,
)
ds = Smith300DA3VolumeDataset(**_ds_kwargs)
_n_samples, _n_episodes = len(ds), len(ds.episodes)
print(f"Loaded: {_n_samples} samples, {_n_episodes} episodes")

# Summary statistics over the dataset's preloaded arrays (per the original note,
# these are already filtered to in-bounds frames).
eef_xy_pix = ds.pix_t.numpy()    # (N, 2) projected eef pixel in 504-space
eef_z = ds.eef_z_t.numpy()       # (N,) world height
_pix_x = eef_xy_pix[:, 0]
_pix_y = eef_xy_pix[:, 1]
print("\n3D / pix bounding boxes (in-bounds frames only):")
print(
    f"  pix x: [{_pix_x.min():.1f}, {_pix_x.max():.1f}]"
    f"  pix y: [{_pix_y.min():.1f}, {_pix_y.max():.1f}]"
    f"  (504×504 image)"
)
print(
    f"  z   : [{eef_z.min():.3f}, {eef_z.max():.3f}] m"
    f"  (std={eef_z.std():.3f})"
)

# Frame-to-frame sanity check, per episode: how the projected pixel moves (Δpix)
# relative to the EEF height (Δz) between frame i and i+1. Only the first five
# episodes are inspected, and episodes shorter than five frames are skipped.
def _safe_corr(a, b):
    """Return the Pearson correlation of 1-D arrays *a* and *b*.

    Returns NaN when the correlation is undefined (fewer than two samples, or
    either input has zero variance) instead of emitting a NumPy warning.
    """
    if a.size < 2 or np.std(a) == 0 or np.std(b) == 0:
        return float("nan")
    return float(np.corrcoef(a, b)[0, 1])


print("\nΔ-correlation sanity check (per episode, between frame i and i+1):")
for ep_idx, ep in enumerate(ds.episodes[:5]):
    frames = ep['frames']
    if len(frames) < 5:
        continue
    # NOTE(review): assumes ep['frames'] holds global indices into the preloaded
    # arrays, in temporal order — confirm against the dataset class.
    g_idx = np.array(frames, dtype=np.int64)
    pix_seq = eef_xy_pix[g_idx]                              # (T_ep, 2)
    z_seq   = eef_z[g_idx]                                   # (T_ep,)
    d_pix = np.diff(pix_seq, axis=0)                         # (T_ep-1, 2)
    d_z   = np.diff(z_seq)                                   # (T_ep-1,)
    # Correlate Δz with each pixel axis so the sign relationship between world
    # motion and image motion is visible as a number, not just as means.
    corr_x = _safe_corr(d_z, d_pix[:, 0])
    corr_y = _safe_corr(d_z, d_pix[:, 1])
    print(f"  ep {ep_idx}: T={len(frames)}  Δpix mean=({d_pix[:,0].mean():+.2f},{d_pix[:,1].mean():+.2f})  std=({d_pix[:,0].std():.2f},{d_pix[:,1].std():.2f})  Δz_mean={d_z.mean():+.4f}  corr(Δz,Δpix)=({corr_x:+.2f},{corr_y:+.2f})")

# Visual check: N_EP_SHOW episodes (rows) × N_TS_SHOW timestamps (columns).
# Each panel shows the RGB frame with the current projected EEF pixel (red dot
# inside a white ring) followed by up to N_FUTURE future pixels coloured by time.
print("\nRendering panel grid...")
N_EP_SHOW = 4
N_TS_SHOW = 4   # cols per episode
N_FUTURE = 10   # future steps drawn after the current frame

if len(ds.episodes) == 0:
    raise SystemExit("No episodes loaded — nothing to render.")
sampled_eps = np.linspace(0, len(ds.episodes) - 1, N_EP_SHOW).astype(int)

# squeeze=False keeps `axes` two-dimensional even if a grid dimension is 1.
fig, axes = plt.subplots(N_EP_SHOW, N_TS_SHOW,
                         figsize=(3.5 * N_TS_SHOW, 3.5 * N_EP_SHOW),
                         squeeze=False)
for r, ep_idx in enumerate(sampled_eps):
    ep = ds.episodes[ep_idx]
    frames = ep['frames']
    if len(frames) == 0:
        # Nothing to draw for this episode; blank the row instead of crashing.
        for ax in axes[r]:
            ax.axis("off")
        continue
    # Pick N_TS_SHOW start frames spread through the episode, leaving room for
    # the future trail. Clamp at 0 so short episodes do not yield negative
    # (wrap-around) indices.
    last_start = max(len(frames) - (N_FUTURE + 2), 0)
    ts_idx = np.linspace(0, last_start, N_TS_SHOW).astype(int)
    for c, t_local in enumerate(ts_idx):
        g_start = int(frames[t_local])
        rgb = ds.rgb_t[g_start].numpy().transpose(1, 2, 0).clip(0, 1)            # (504, 504, 3)
        img = (rgb * 255).astype(np.uint8).copy()
        # Overlay the current pixel (k == 0) plus the next n_future pixels.
        # n_future counts the frames available after t_local, so the loop runs
        # n_future + 1 times to include the current frame as well.
        n_future = min(N_FUTURE, len(frames) - t_local - 1)
        for k in range(n_future + 1):
            g_k = int(frames[t_local + k])
            px = int(eef_xy_pix[g_k, 0])
            py = int(eef_xy_pix[g_k, 1])
            # OpenCV 8-bit hue spans 0..179; sweep 0 (red) → 170 over the trail.
            hue = int(k / max(n_future, 1) * 170)
            col = cv2.cvtColor(np.uint8([[[hue, 255, 255]]]), cv2.COLOR_HSV2RGB)[0, 0].tolist()
            cv2.circle(img, (px, py), 5, col, -1)
            if k == 0:
                # White ring marks the current frame's projected position.
                cv2.circle(img, (px, py), 12, (255, 255, 255), 2)
        ax = axes[r, c]
        ax.imshow(img); ax.set_xticks([]); ax.set_yticks([])
        if c == 0:
            ax.set_ylabel(f"ep {ep_idx}", fontsize=10)
        ax.set_title(f"frame {g_start}  (z={eef_z[g_start]:.3f})", fontsize=8)

fig.suptitle(f"UMI Izzy Towel — projected EEF pixel + {N_FUTURE}-step future (rainbow). "
             f"Workspace pix=[{eef_xy_pix[:,0].min():.0f},{eef_xy_pix[:,0].max():.0f}]×"
             f"[{eef_xy_pix[:,1].min():.0f},{eef_xy_pix[:,1].max():.0f}]  z=[{eef_z.min():.3f},{eef_z.max():.3f}]m",
             fontsize=11)
fig.tight_layout()
fig.savefig(OUT, dpi=140, bbox_inches='tight', facecolor='white')
plt.close(fig)  # release the figure now that it is on disk
print(f"\n✓ Saved {OUT}")
