"""Render UMI Izzy Towel episode-start frames + DINO PCA across 9 components.

For ~10 episodes' start frames:
  Row 1: RGB
  Row 2: PCs 1-3 (as RGB)
  Row 3: PCs 4-6 (as RGB)
  Row 4: PCs 7-9 (as RGB)

PCA basis fit jointly over all frames so colors are consistent across columns. Each
3-PC triplet is normalised with its own joint range (per-PC 2nd/98th percentile pooled
over all frames, then clipped to [0, 1]), so each row is its own RGB space, but within
a row colors are comparable across episodes.
"""
import os, sys
sys.path.insert(0, "/data/cameron/para/libero")
sys.path.insert(0, "/data/cameron/keygrip/dinov3")
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from pathlib import Path

# Default locations of the local DINOv3 checkout and its pretrained weights.
# setdefault() leaves any value already present in the environment untouched.
# NOTE(review): presumably these are read by model_dino_volume_query, which is why they
# are set before the project imports below — confirm before reordering.
os.environ.setdefault("DINO_REPO_DIR",     "/data/cameron/keygrip/dinov3")
os.environ.setdefault("DINO_WEIGHTS_PATH", "/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth")

# Project-local modules (presumably found via the sys.path entries inserted at the top
# of the file).
from data_da3_volume import Smith300DA3VolumeDataset
from model_dino_volume_query import DinoVolumeQuery, IMG_SIZE, PRED_SIZE

# Latest UMI Izzy Towel ckpt (still training — use what's there)
CKPT = "/data/cameron/para/libero/checkpoints/dino_query_umi_izzy_towel_t50_pca1d_v0/latest.pth"
# Fallback if not yet saved: use the desk_collect_1 ckpt (same DINO trunk)
FALLBACK = "/data/cameron/para/libero/checkpoints/dino_query_desk_collect_1_t50_pca1d_v0/latest.pth"

# Destination of the rendered figure; its parent directory is created up front
# (no error if it already exists).
OUT = Path("/data/cameron/para/paper/figs/generated/umi_izzy_towel_dino_pca9.png")
OUT.parent.mkdir(parents=True, exist_ok=True)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the UMI Izzy Towel session and collect the first frame of (up to) 10 episodes.
print("Loading umi_collect_izzy_towel dataset...")
ds = Smith300DA3VolumeDataset(
    root_dir="/data/cameron/mac_robot_datasets",
    sessions_whitelist=["umi_collect_izzy_towel"],
    n_window=1, frame_stride=1,
)

# Number of figure columns: at most 10, fewer when the session has fewer episodes.
N_EPS = min(10, len(ds.episodes))
if N_EPS == 0:
    # Fail here with a readable message instead of letting torch.stack choke on an
    # empty list further down.
    raise RuntimeError(
        "umi_collect_izzy_towel yielded no episodes; nothing to render"
    )
print(f"  → {len(ds.episodes)} episodes, picking first {N_EPS}")

# First entry of each episode's 'frames' list, used below to index ds.rgb_t.
ep_start_global = [int(ep['frames'][0]) for ep in ds.episodes]
ep_start_global = ep_start_global[:N_EPS]

rgb_starts = [ds.rgb_t[g] for g in ep_start_global]
# Expected (N, 3, 504, 504) in [0,1] (smith300 convention) — TODO confirm against dataset.
rgbs = torch.stack(rgb_starts, dim=0)
print(f"  rgbs: {tuple(rgbs.shape)}")

# Build the model and load checkpoint weights (primary checkpoint if present, else the
# fallback). The model is only used for feature extraction, so it is put in eval mode.
print("Loading model trunk + refine...")
m = DinoVolumeQuery(
    n_window=50, n_height_bins=32, n_gripper_bins=32, n_rot_bins=32,
    image_size=IMG_SIZE, pred_size=PRED_SIZE,
    use_eef=True, rotation_mode='1d_pca',
).to(device).eval()

if Path(CKPT).exists():
    ckpt_path = CKPT
else:
    ckpt_path = FALLBACK
    # Make it obvious in the log that the figure is NOT based on the towel checkpoint.
    print(f"  WARNING: {CKPT} not found, falling back to {FALLBACK}")

# NOTE(security): weights_only=False unpickles arbitrary Python objects. Acceptable for
# our own locally written checkpoints; never point this at an untrusted file.
sd = torch.load(ckpt_path, map_location=device, weights_only=False)

# strict=False tolerates key mismatches, so report them: silently skipped keys would
# leave the affected layers at their random initialisation and make the PCA meaningless.
load_result = m.load_state_dict(sd['model_state_dict'], strict=False)
print(f"  loaded {ckpt_path}")
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"  WARNING: non-strict load — "
          f"{len(load_result.missing_keys)} missing key(s), "
          f"{len(load_result.unexpected_keys)} unexpected key(s)")
    if load_result.missing_keys:
        print(f"    missing (first 5): {load_result.missing_keys[:5]}")
    if load_result.unexpected_keys:
        print(f"    unexpected (first 5): {load_result.unexpected_keys[:5]}")

# Run every start frame through the model one at a time (batch of 1): DINO patch
# features -> bilinear upsample to PRED_SIZE x PRED_SIZE -> refine head. Results are
# moved to the CPU so the PCA below can run in NumPy.
all_F = []
with torch.no_grad():
    for ep_idx in range(N_EPS):
        frame = rgbs[ep_idx:ep_idx + 1].to(device)     # slice keeps the batch dim
        patch_tokens, _ = m._extract_dino_features(frame)
        upsampled = F.interpolate(
            patch_tokens,
            size=(PRED_SIZE, PRED_SIZE),
            mode='bilinear',
            align_corners=False,
        )
        refined = m.refine(upsampled)
        all_F.append(refined[0].cpu())                 # drop the batch dim
# Expected shape (N, 32, 56, 56) — TODO confirm against the refine head / PRED_SIZE.
F_all = torch.stack(all_F, dim=0)
print(f"  F_all: {tuple(F_all.shape)}")

# Joint PCA: pool the feature vectors of every pixel of every frame, centre them, and
# take the top 9 right-singular vectors as the shared basis.
print("Fitting joint 9-PCA over all features...")
n_channels = F_all.shape[1]
# Channels-last, then one row per pixel: (N*H*W, C).
F_flat = F_all.numpy().transpose(0, 2, 3, 1).reshape(-1, n_channels)
mu = F_flat.mean(0, keepdims=True)
F_c = F_flat - mu
u, sv, vt = np.linalg.svd(F_c, full_matrices=False)
# Columns are the first 9 principal directions (fewer only if C < 9).
V_proj = vt[:9].T
# Fraction of total variance carried by each component.
ev = (sv ** 2) / (sv ** 2).sum()
print(f"  EV PC1..9: {[f'{e:.3f}' for e in ev[:9]]}")
print(f"  cumulative EV at 3/6/9: {ev[:3].sum():.3f} / {ev[:6].sum():.3f} / {ev[:9].sum():.3f}")

# Project all features to 9 components, then restore the image layout (N, H, W, 9).
H = W = PRED_SIZE
F_pcs_flat = F_c @ V_proj
F_pcs = F_pcs_flat.reshape(N_EPS, H, W, 9)

# Per-triplet joint normalisation: each PC is rescaled by its own 2nd / 98th percentile
# pooled over all images, then clipped to [0, 1] so the triplet can be shown as RGB.
triplet_imgs = []
for lo_i, hi_i in ((0, 3), (3, 6), (6, 9)):
    triplet = F_pcs[..., lo_i:hi_i]
    lo, hi = np.percentile(triplet.reshape(-1, 3), [2, 98], axis=0)
    # The 1e-8 guards against a zero-width range.
    normed = np.clip((triplet - lo) / (hi - lo + 1e-8), 0, 1)
    triplet_imgs.append(normed)                                  # (N, H, W, 3)
    print(f"  triplet PC{lo_i+1}-{hi_i}: EV={ev[lo_i:hi_i].sum():.3f}, ranges={list(zip(lo, hi))}")

# RGB frames as channels-last arrays clipped to [0, 1], ready for imshow.
rgb_imgs = [rgbs[i].numpy().clip(0, 1).transpose(1, 2, 0) for i in range(N_EPS)]

# Grid layout: row 0 = RGB, rows 1-3 = PC triplets; one column per episode.
print(f"Rendering 4 × {N_EPS} grid...")
# squeeze=False always returns a 2-D axes array, so a single-episode figure needs no
# special-case reshape.
fig, axes = plt.subplots(4, N_EPS, figsize=(1.8 * N_EPS, 7.5), squeeze=False,
                         gridspec_kw={'hspace': 0.06, 'wspace': 0.04})

for c in range(N_EPS):
    axes[0, c].imshow(rgb_imgs[c])
    axes[0, c].set_title(f"ep {c}", fontsize=9, pad=3)
    for r_triplet, normed_imgs in enumerate(triplet_imgs):
        # 'nearest' keeps the coarse feature grid crisp instead of blurring it.
        axes[r_triplet + 1, c].imshow(normed_imgs[c], interpolation='nearest')

# Hide ticks everywhere; axes stay "on" so the row labels in column 0 remain visible.
for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])

ev_str = f"{ev[:3].sum():.0%} / {ev[3:6].sum():.0%} / {ev[6:9].sum():.0%}  cum={ev[:9].sum():.0%}"
row_labels = [
    "RGB",
    f"PC1-3\n{ev[:3].sum():.0%} EV",
    f"PC4-6\n{ev[3:6].sum():.0%} EV",
    f"PC7-9\n{ev[6:9].sum():.0%} EV",
]
for ax, label in zip(axes[:, 0], row_labels):
    ax.set_ylabel(label, fontsize=10)

fig.suptitle(f"UMI Izzy Towel — episode start frames + jointly-fit DINO 9-PCA "
             f"(per-triplet EV: {ev_str})", fontsize=11)
# No fig.tight_layout() here: it recomputes hspace/wspace and would throw away the
# explicit gridspec spacing set above. bbox_inches='tight' already trims the margins.
fig.savefig(OUT, dpi=180, bbox_inches='tight', facecolor='white')
plt.close(fig)
print(f"✓ Saved {OUT}")
