"""Multi-frame viz for the libero 2view ckpt — N frames across one episode.

Layout (4 rows × N cols, where N = number of frames sampled):
  Row 1: BEV RGB + predicted trajectory (filled dots, rainbow hue = timestep) and GT (white rings)
  Row 2: wrist RGB + the same pred/GT overlays, projected into the wrist camera
  Row 3: F_bev 3-PCA (single joint PCA basis fit across ALL frames + both views)
  Row 4: F_wrist 3-PCA (same basis)

The joint PCA spans every frame + both views so the colors mean the same thing across
the entire grid — you can see how feature regions move + persist as the episode plays out.
"""
import os, sys, argparse
sys.path.insert(0, "/data/cameron/para/libero")
sys.path.insert(0, "/data/cameron/keygrip/dinov3")
import numpy as np
import torch
import cv2
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.spatial.transform import Rotation as ScipyR

# Default DINOv3 repo / weight locations. setdefault keeps any value already exported in the
# environment. These are set before the project imports below — presumably because the model
# module reads them (at import or construction time; confirm in model_dino_volume_query_2view).
os.environ.setdefault("DINO_REPO_DIR",     "/data/cameron/keygrip/dinov3")
os.environ.setdefault("DINO_WEIGHTS_PATH", "/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth")

from data_libero_2view import CachedTrajectory2ViewDataset
from model_dino_volume_query_2view import (DinoVolumeQuery2View, PRED_SIZE,
                                            build_bev_world_xyz_table)


def project_world_to_pixel(xyz, K_pixel, extrinsic):
    """Project one world-frame point to (u, v) pixel coordinates.

    `extrinsic` is inverted to obtain the world->camera transform, i.e. it is treated as a
    4x4 camera-to-world pose. Points whose camera-frame z is <= 1e-3 (behind or too close to
    the camera) give None. Otherwise the first two entries of K_pixel applied to the
    depth-normalised point are returned as floats — this assumes K_pixel's last row is
    [0, 0, 1], so no second perspective divide is needed.
    """
    homog = np.append(np.asarray(xyz, dtype=float)[:3], 1.0)
    cam = np.linalg.inv(extrinsic) @ homog
    depth = cam[2]
    if depth <= 1e-3:
        return None
    u, v = (K_pixel @ (cam[:3] / depth))[:2]
    return float(u), float(v)


def _denorm_rgb(rgb_t):
    mean = np.array([0.485, 0.456, 0.406])[:, None, None]
    std  = np.array([0.229, 0.224, 0.225])[:, None, None]
    return (rgb_t.cpu().numpy() * std + mean).clip(0, 1).transpose(1, 2, 0)


def _decode_volume_argmax(vol, image_size):
    """Decode the per-timestep argmax cell of a (T, Z, Hg, Wg) logit volume.

    Args:
        vol: torch tensor of logits; unpacked below as (T, Z, Hg, Wg).
        image_size: side length, in pixels, of the RGB image the Hg x Wg grid maps onto.

    Returns:
        (z_bins, py_grid, px_grid, pred_pix): three int arrays of length T with the winning
        height bin / grid row / grid column, and a float (T, 2) array of (x, y) pixel
        coordinates at the centre of each winning grid cell.
    """
    T, _, Hg, Wg = vol.shape
    flat = vol.reshape(T, -1).argmax(dim=-1).cpu().numpy()
    z_bins, yx = np.divmod(flat, Hg * Wg)
    py_grid, px_grid = np.divmod(yx, Wg)
    # Scale x and y separately so a non-square prediction grid still maps correctly.
    pred_pix = np.stack([(px_grid + 0.5) * (image_size / Wg),
                         (py_grid + 0.5) * (image_size / Hg)], axis=-1)
    return z_bins, py_grid, px_grid, pred_pix


def _project_points(points_xyz, K_pixel, extrinsic, missing=-100.0):
    """Project each world point with project_world_to_pixel.

    Returns a (len(points_xyz), 2) array. Rows that cannot be projected keep the `missing`
    sentinel, which lies off-image and is therefore rejected by the drawing bounds check.
    """
    pix = np.full((len(points_xyz), 2), missing)
    for i, xyz in enumerate(points_xyz):
        uv = project_world_to_pixel(xyz, K_pixel, extrinsic)
        if uv is not None:
            pix[i] = uv
    return pix


def _joint_pca_rgb(feature_maps, n_components=3, clip_pct=(2, 98)):
    """Fit ONE PCA basis over every pixel of every map; return per-map false-colour images.

    Args:
        feature_maps: sequence of CPU torch tensors shaped (C, H, W). All share C; H and W
            may differ between maps (each map is sliced back with its own size).
        n_components: number of principal components kept (3 -> RGB).
        clip_pct: (low, high) percentiles used to min-max each component into [0, 1].

    Returns:
        (images, ev): list of (H, W, n_components) arrays in [0, 1], in the same order as
        `feature_maps`, plus the explained-variance ratio of every component (descending).
    """
    arrays = [f.numpy() for f in feature_maps]
    # (C, H, W) -> (H*W, C), rows in row-major (y, x) order.
    joint = np.concatenate([a.reshape(a.shape[0], -1).T for a in arrays], axis=0)
    centred = joint - joint.mean(0, keepdims=True)
    # Eigen-decompose the C x C scatter matrix instead of SVD-ing the (n_pixels, C) data:
    # same principal axes and variances, without materialising the data-sized U factor.
    evals, evecs = np.linalg.eigh((centred.T @ centred).astype(np.float64))
    evals = np.clip(evals[::-1], 0.0, None)     # eigh is ascending; round-off can dip below 0
    evecs = evecs[:, ::-1]
    total = evals.sum()
    ev = evals / total if total > 0 else np.zeros_like(evals)
    V = evecs[:, :n_components]
    # Eigenvector signs are arbitrary; make each one's largest-|loading| entry positive so
    # the colours are reproducible between runs.
    signs = np.sign(V[np.abs(V).argmax(axis=0), np.arange(V.shape[1])])
    V = V * np.where(signs == 0, 1.0, signs)
    pcs = centred @ V
    lo, hi = np.percentile(pcs, clip_pct, axis=0)
    pcs_n = np.clip((pcs - lo) / (hi - lo + 1e-8), 0, 1)

    images, cursor = [], 0
    for a in arrays:
        h, w = a.shape[1], a.shape[2]
        images.append(pcs_n[cursor:cursor + h * w].reshape(h, w, -1))
        cursor += h * w
    return images, ev


def _rainbow_colours(T):
    """One RGB colour per timestep: OpenCV HSV hue swept 0 -> 170 (of 0-179), full S and V."""
    colours = []
    for t in range(T):
        hue = int(t / max(T - 1, 1) * 170)
        colours.append(cv2.cvtColor(np.uint8([[[hue, 255, 255]]]), cv2.COLOR_HSV2RGB)[0, 0].tolist())
    return colours


def _draw_overlay(img, pred_pix, gt_pix, colours, bounds=None):
    """Draw pred (filled, coloured) and GT (white ring) markers onto `img` in place.

    Markers are drawn per timestep in order (pred then GT). When `bounds` is given, a marker
    is drawn only if its centre lies in [0, bounds) on both axes; this also drops the -100
    "not projectable" sentinel.
    """
    def visible(pt):
        return bounds is None or (0 <= pt[0] < bounds and 0 <= pt[1] < bounds)

    for col, pp, gp in zip(colours, pred_pix, gt_pix):
        if visible(pp):
            cv2.circle(img, (int(pp[0]), int(pp[1])), 4, col, -1)
        if visible(gp):
            cv2.circle(img, (int(gp[0]), int(gp[1])), 4, (255, 255, 255), 2)


def main():
    """Render the 4 x N episode grid for one demo and save it to --out.

    Loads the checkpoint and the cached 2-view dataset, then runs the model on N evenly
    spaced start frames of one demo. It overlays the argmax-decoded predicted and GT
    trajectories on the BEV and wrist images, colours both views' pixel features with a
    single joint PCA basis, and writes the figure as an image file.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint", type=str,
                   default="/data/cameron/para/libero/checkpoints/libero_2view_libero_spatial_t0_v0/latest.pth")
    p.add_argument("--cache_root", type=str, default="/data/libero/parsed_libero_2view")
    p.add_argument("--task_id",    type=int, default=0)
    p.add_argument("--demo_idx",   type=int, default=-1, help="demo index in cache (random if -1)")
    p.add_argument("--n_frames",   type=int, default=5)
    p.add_argument("--frame_stride", type=int, default=3,
                   help="dataset frame stride; also sizes the future-window margin at the episode end")
    p.add_argument("--out", type=str,
                   default="/data/cameron/para/paper/figs/generated/libero_2view_episode.png")
    args = p.parse_args()
    if args.n_frames < 1:
        p.error("--n_frames must be >= 1")
    if args.frame_stride < 1:
        p.error("--frame_stride must be >= 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # weights_only=False unpickles arbitrary objects — only point this at trusted checkpoints.
    sd = torch.load(args.checkpoint, map_location=device, weights_only=False)
    n_window   = int(sd.get("n_window", 8))
    image_size = int(sd.get("image_size", 448))
    min_h, max_h = float(sd["min_height"]), float(sd["max_height"])
    bev_K_norm    = np.asarray(sd["bev_K_norm"],    dtype=np.float32)
    bev_extrinsic = np.asarray(sd["bev_extrinsic"], dtype=np.float32)
    n_height_bins = 32   # shared by the model and the BEV xyz lookup table below

    ds = CachedTrajectory2ViewDataset(
        cache_root=args.cache_root, task_ids=[args.task_id],
        image_size=image_size, n_window=n_window, frame_stride=args.frame_stride, max_demos=0,
    )
    n_demos = len(ds.demos)
    if n_demos == 0:
        p.error(f"no cached demos for task {args.task_id} under {args.cache_root}")
    if args.demo_idx >= n_demos:
        p.error(f"--demo_idx {args.demo_idx} out of range (cache has {n_demos} demos)")
    demo_idx = (args.demo_idx if args.demo_idx >= 0
                else int(torch.randint(0, n_demos, (1,)).item()))
    T_demo = ds.demos[demo_idx]["T"]

    # Evenly spaced start frames; stop short of the end so each sample still has its
    # n_window future steps at the dataset stride.
    max_t = max(T_demo - n_window * args.frame_stride - 1, 1)
    # np.unique drops the repeats linspace+astype(int) yields when n_frames > max_t + 1.
    ts = np.unique(np.linspace(0, max_t, args.n_frames).astype(int))
    # Map (demo_idx, t) to an index in the dataset's flat sample list.
    # NOTE(review): assumes the flat list holds one sample per frame of every demo, in demo
    # order. If the dataset drops tail frames, this offset drifts — cross-check the
    # "frame N" panel titles against the frames printed here.
    offset = sum(d["T"] for d in ds.demos[:demo_idx])
    sample_idxs = [offset + int(t) for t in ts]
    print(f"Demo {demo_idx} (T={T_demo}), sampling frames {ts.tolist()}")

    m = DinoVolumeQuery2View(
        n_window=n_window, n_height_bins=n_height_bins, n_gripper_bins=32, n_rot_bins=int(sd["n_rot_bins"]),
        image_size=image_size, pred_size=PRED_SIZE, rotation_mode='1d_pca',
    ).to(device).eval()
    # strict=False tolerates key mismatches; report them so a partially loaded model isn't silent.
    load_res = m.load_state_dict(sd["model_state_dict"], strict=False)
    if load_res.missing_keys or load_res.unexpected_keys:
        print(f"⚠ load_state_dict(strict=False): {len(load_res.missing_keys)} missing, "
              f"{len(load_res.unexpected_keys)} unexpected keys")

    bev_xyz_table = build_bev_world_xyz_table(
        torch.tensor(bev_K_norm,    dtype=torch.float32, device=device),
        torch.tensor(bev_extrinsic, dtype=torch.float32, device=device),
        n_height_bins, min_h, max_h, PRED_SIZE, PRED_SIZE, image_size, device,
    )
    xyz_table_np = bev_xyz_table.cpu().numpy()

    # Pass 1: forward each frame, collect features + pred/GT projections in both views.
    rows_data = []
    feature_maps = []   # interleaved [bev_0, wrist_0, bev_1, wrist_1, ...]
    for sidx in sample_idxs:
        s = ds[sidx]
        with torch.no_grad():
            out = m(s["rgb_bev"].unsqueeze(0).to(device),
                    s["rgb_wrist"].unsqueeze(0).to(device),
                    s["trajectory_2d_bev"][0:1].to(device),
                    bev_xyz_table,
                    s["wrist_K_norm"].unsqueeze(0).to(device),
                    s["wrist_extrinsic"].unsqueeze(0).to(device))
        feature_maps.append(out["pixel_feats"][0].cpu())
        feature_maps.append(out["pixel_feats_wrist"][0].cpu())

        z_bins, py_grid, px_grid, pred_pix_bev = _decode_volume_argmax(out["volume_logits"][0], image_size)
        # Lift each argmax cell to world xyz via the precomputed BEV lookup table.
        pred_xyz = xyz_table_np[z_bins, py_grid, px_grid]
        gt_xyz   = s["trajectory_3d"].numpy()

        # Normalised wrist intrinsics -> pixel units (rows 0 and 1 scale with image size).
        wrist_K_pix = s["wrist_K_norm"].numpy().astype(np.float64)
        wrist_K_pix[:2] *= image_size
        wrist_ext_np = s["wrist_extrinsic"].numpy().astype(np.float64)

        rows_data.append({
            "rgb_bev_t":      s["rgb_bev"],
            "rgb_wrist_t":    s["rgb_wrist"],
            "pred_pix_bev":   pred_pix_bev,
            "gt_pix_bev":     s["trajectory_2d_bev"].numpy(),
            "pred_pix_wrist": _project_points(pred_xyz, wrist_K_pix, wrist_ext_np),
            "gt_pix_wrist":   _project_points(gt_xyz,   wrist_K_pix, wrist_ext_np),
            "start_t":        int(s["start_t"]),
        })

    # Joint PCA across ALL frames + both views, so colours mean the same thing grid-wide.
    pca_maps, ev = _joint_pca_rgb(feature_maps, n_components=3)
    pca_bev_per_frame, pca_wrist_per_frame = pca_maps[0::2], pca_maps[1::2]

    # Render
    N = len(rows_data)
    fig, axes = plt.subplots(4, N, figsize=(2.6 * N, 10.5), squeeze=False,
                             gridspec_kw={'hspace': 0.10, 'wspace': 0.04})
    for c, row in enumerate(rows_data):
        # cv2 drawing needs a C-contiguous buffer; the denormalised image is a transposed
        # view whose layout survives the arithmetic/astype, so force a contiguous copy.
        rgb_b = np.ascontiguousarray((_denorm_rgb(row["rgb_bev_t"]) * 255).astype(np.uint8))
        rgb_w = np.ascontiguousarray((_denorm_rgb(row["rgb_wrist_t"]) * 255).astype(np.uint8))
        colours = _rainbow_colours(len(row["pred_pix_bev"]))
        _draw_overlay(rgb_b, row["pred_pix_bev"],   row["gt_pix_bev"],   colours)
        _draw_overlay(rgb_w, row["pred_pix_wrist"], row["gt_pix_wrist"], colours, bounds=image_size)
        axes[0, c].imshow(rgb_b)
        axes[0, c].set_title(f"frame {row['start_t']}", fontsize=9, pad=2)
        axes[1, c].imshow(rgb_w)
        axes[2, c].imshow(pca_bev_per_frame[c],   interpolation='nearest')
        axes[3, c].imshow(pca_wrist_per_frame[c], interpolation='nearest')
        for ax in axes[:, c]:
            ax.set_xticks([])
            ax.set_yticks([])
    for r, label in enumerate(("BEV +\ntraj", "wrist +\ntraj", "F_bev\nPCA", "F_wrist\nPCA")):
        axes[r, 0].set_ylabel(label, fontsize=10)
    fig.suptitle(f"libero_2view ckpt (ep {sd.get('epoch', '?')}) — demo {demo_idx}, {N} frames, "
                 f"joint PCA across all frames+views (EV={ev[:3].sum():.0%})", fontsize=11)
    fig.tight_layout()
    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(args.out, dpi=140, bbox_inches='tight', facecolor='white')
    plt.close(fig)
    print(f"✓ Saved {args.out}")


if __name__ == "__main__":
    main()
