"""Render izzy3 episode-start frames + DINO PCA, jointly fit so colors are consistent.

For each episode's first frame in izzy_home_recording_3:
  - Top row: RGB
  - Bottom row: DINO feature PCA (3-PCA of F, the refined 32-dim feature map)

The PCA basis is fit JOINTLY across all episodes' features so the colors mean the
same thing across columns. Colour normalisation is also joint: each PC is rescaled
by a single 2nd-98th percentile range computed across all data, then clipped to [0, 1].
"""
import os, sys, math
sys.path.insert(0, "/data/cameron/para/libero")
sys.path.insert(0, "/data/cameron/keygrip/dinov3")
import numpy as np
import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from pathlib import Path

# Default locations of the DINOv3 repo and its weights. setdefault leaves any
# value already present in the environment untouched. These are set before the
# project imports that follow — presumably those modules read them; confirm.
for _env_key, _env_default in (
    ("DINO_REPO_DIR", "/data/cameron/keygrip/dinov3"),
    ("DINO_WEIGHTS_PATH", "/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth"),
):
    os.environ.setdefault(_env_key, _env_default)

from data_da3_volume import Smith300DA3VolumeDataset, DA3_INPUT
from model_dino_volume_query import DinoVolumeQuery, IMG_SIZE, PRED_SIZE

# Checkpoint to visualise and the destination of the rendered figure.
CKPT = "/data/cameron/para/libero/checkpoints/dino_query_izzy3_t50_pca1d_v0/latest.pth"
OUT_DIR = Path("/data/cameron/para/paper/figs/generated")
OUT_PNG = OUT_DIR.joinpath("izzy3_start_frames_dino_pca.png")
# Create the output directory (and any missing parents) up front; an existing
# directory is not an error.
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ────────── Dataset (n_window=1 — only need rgb and start-frame indices) ──────────
print("Loading izzy3 dataset...")
ds = Smith300DA3VolumeDataset(
    sessions_whitelist=['izzy_home_recording_3'],
    n_window=1, frame_stride=1,
)
# Fail early with a readable message: with no episodes, torch.stack below would
# raise an error about an empty tensor list that hides the real cause.
if len(ds.episodes) == 0:
    raise RuntimeError(
        "No episodes found for session 'izzy_home_recording_3'; nothing to render."
    )
# ds.episodes is a list of dicts; the first frame of each episode is the one we want.
# The `frames` list per episode holds global frame indices.
ep_start_global = [int(ep['frames'][0]) for ep in ds.episodes]
print(f"  → {len(ds.episodes)} episodes, first-frame global indices: {ep_start_global}")

# Pull the RGB tensor for each start frame directly from the dataset's preloaded
# buffer, indexing one frame at a time, then stack along a new leading axis.
rgb_starts = [ds.rgb_t[g] for g in ep_start_global]                      # list of (3, 504, 504)
rgbs = torch.stack(rgb_starts, dim=0)                                    # (N_ep, 3, 504, 504)
N = rgbs.shape[0]                                                        # number of episodes / grid columns
print(f"  rgbs: {tuple(rgbs.shape)}")

# ────────── Model (trunk only — we use refine output as F) ──────────
print("Loading model trunk + refine...")
m = DinoVolumeQuery(
    n_window=50, n_height_bins=32, n_gripper_bins=32, n_rot_bins=32,
    image_size=IMG_SIZE, pred_size=PRED_SIZE,
    use_eef=True, rotation_mode='1d_pca',
).to(device).eval()
# NOTE(security): weights_only=False unpickles arbitrary Python objects, so this
# must only ever be pointed at a trusted checkpoint file.
sd = torch.load(CKPT, map_location=device, weights_only=False)
# strict=False tolerates key mismatches between checkpoint and model. Report them
# instead of discarding the result, so a figure rendered from partly
# uninitialised weights does not go unnoticed.
load_result = m.load_state_dict(sd['model_state_dict'], strict=False)
if load_result.missing_keys:
    print(f"  ! {len(load_result.missing_keys)} model keys missing from checkpoint "
          f"(left at init values), e.g. {load_result.missing_keys[:8]}")
if load_result.unexpected_keys:
    print(f"  ! {len(load_result.unexpected_keys)} checkpoint keys not used by the model, "
          f"e.g. {load_result.unexpected_keys[:8]}")

# Extract F for each start frame (one at a time to keep memory low)
# Imported once here, outside the loop, rather than on every iteration.
import torch.nn.functional as F

all_F = []
with torch.no_grad():
    for frame in rgbs:
        rgb = frame.unsqueeze(0).to(device)                              # (1, 3, 504, 504)
        patch, _ = m._extract_dino_features(rgb)                         # (1, embed, H_p, W_p)
        # Resample the patch grid to the prediction resolution before refining.
        feat_up = F.interpolate(patch, size=(PRED_SIZE, PRED_SIZE),
                                mode='bilinear', align_corners=False)
        # Drop the batch axis and move to CPU so only one frame lives on the device.
        F_feat = m.refine(feat_up)[0].cpu()                              # (32, 56, 56)
        all_F.append(F_feat)
F_all = torch.stack(all_F, dim=0)                                        # (N, 32, 56, 56)
print(f"  F_all: {tuple(F_all.shape)}")

# ────────── Joint PCA ──────────
print("Fitting joint 3-PCA over all features...")
# Channels-last, then one row per pixel across every image: (N * 56 * 56, 32)
n_channels = F_all.shape[1]
F_flat = F_all.numpy().transpose(0, 2, 3, 1).reshape(-1, n_channels)
F_centered = F_flat - F_flat.mean(0, keepdims=True)
# Only the singular values and right singular vectors are needed; the rows of
# vt are the principal directions, so the first three form the projection.
_, sv, vt = np.linalg.svd(F_centered, full_matrices=False)
V_proj = vt[:3].T                                                         # (32, 3)
# Fraction of total variance carried by each component.
ev = (sv ** 2) / (sv ** 2).sum()
print(f"  EV: PC1={ev[0]:.3f}  PC2={ev[1]:.3f}  PC3={ev[2]:.3f}")

# Project every pixel onto the three components (one shared basis for all images).
F_pcs_flat = np.matmul(F_centered, V_proj)                                # (N*56*56, 3)
# One 2nd/98th-percentile range per component, computed over all images at once,
# keeps the colour mapping identical from column to column.
pc_bounds = np.percentile(F_pcs_flat, [2, 98], axis=0)                    # (2, 3)
lo, hi = pc_bounds[0], pc_bounds[1]                                       # (3,) each
print(f"  PC ranges: {list(zip(lo, hi))}")
# Rescale to [0, 1]; values outside the percentile range are clipped.
F_pcs_norm = np.clip((F_pcs_flat - lo) / (hi - lo + 1e-8), 0, 1)
# Back to image layout: (N, 56, 56, 3)
H, W = PRED_SIZE, PRED_SIZE
F_rgb_all = F_pcs_norm.reshape(N, H, W, 3)

# ────────── RGB for display (smith300 stores in [0,1] directly, no ImageNet norm) ──────────
# Clamp each frame to [0, 1] and move channels last, the layout imshow expects.
rgb_imgs = [np.clip(frame.numpy(), 0, 1).transpose(1, 2, 0) for frame in rgbs]

# ────────── Render grid: 2 rows × N columns ──────────
print(f"Rendering grid: 2 × {N}...")
# squeeze=False keeps `axes` two-dimensional even for a single column, so
# [row, column] indexing works for every N.
fig, axes = plt.subplots(
    2, N, squeeze=False, figsize=(1.8 * N, 4.0),
    gridspec_kw={'hspace': 0.06, 'wspace': 0.04},
)
for i, (photo, pca_img) in enumerate(zip(rgb_imgs, F_rgb_all)):
    top_ax, bottom_ax = axes[0, i], axes[1, i]
    top_ax.imshow(photo)
    top_ax.set_title(f"ep {i}", fontsize=9, pad=3)
    # Nearest-neighbour display shows the PCA map cell by cell, without smoothing.
    bottom_ax.imshow(pca_img, interpolation='nearest')
    # Same frame styling for both rows: no ticks, thin border.
    for ax in (top_ax, bottom_ax):
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_linewidth(0.5)
# Row labels go on the leftmost column only.
axes[0, 0].set_ylabel('RGB', fontsize=10)
axes[1, 0].set_ylabel('DINO 3-PCA\n(joint)', fontsize=10)
fig.suptitle(f"izzy3 — episode start frames + jointly-fit DINO PCA "
             f"(PC1+PC2+PC3 = {ev[0]+ev[1]+ev[2]:.1%} EV)", fontsize=11)
fig.tight_layout()
fig.savefig(OUT_PNG, dpi=180, bbox_inches='tight', facecolor='white')
plt.close(fig)
print(f"✓ Saved {OUT_PNG}")
