"""Multi-panel inference-time visualization for the libero 2view checkpoint.

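Layout (3 rows × 2 cols):
  Row 1: BEV RGB + trajectory overlay           | wrist RGB + trajectory overlay
         (pred = filled dots coloured as a rainbow over time, GT = white rings)
  Row 2: F_bev 3-PCA                            | F_wrist 3-PCA   (shared joint PCA basis)
  Row 3: gripper piano roll                     | rotation (1-D PCA) piano roll
         (per-timestep predicted bin distribution, GT bin marked in red)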

Pred keypoints projected:
  - BEV: argmax of volume_logits → (z, y_bev, x_bev) → pixel directly
  - wrist: argmax voxel → recover world XYZ via BEV camera (using bev_xyz_table) → project
    through current wrist extrinsic + K
"""
import os, sys, argparse
sys.path.insert(0, "/data/cameron/para/libero")
sys.path.insert(0, "/data/cameron/keygrip/dinov3")
import numpy as np
import torch
import torch.nn.functional as F
import cv2
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from pathlib import Path

os.environ.setdefault("DINO_REPO_DIR",     "/data/cameron/keygrip/dinov3")
os.environ.setdefault("DINO_WEIGHTS_PATH", "/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth")

from data_libero_2view import CachedTrajectory2ViewDataset
from model_dino_volume_query_2view import (DinoVolumeQuery2View, PRED_SIZE,
                                            build_bev_world_xyz_table)


def piano_roll(logits, gt_bins, cell_h=6, cell_w=6):
    """Render per-timestep bin probabilities as an RGB "piano roll" image.

    Args:
        logits: (T, n) torch tensor of unnormalised scores; each row is softmaxed.
        gt_bins: indexable where ``int(gt_bins[t])`` is the GT bin of row ``t``.
        cell_h, cell_w: pixel height / width of every (timestep, bin) cell.

    Returns:
        uint8 RGB image of shape (T * cell_h, n * cell_w, 3). Each row is divided
        by its own peak probability before the viridis colormap is applied, and
        every GT cell is painted red.
    """
    n_steps, n_bins = logits.shape
    probs = logits.softmax(dim=-1).cpu().numpy()
    row_peak = probs.max(axis=1, keepdims=True)
    intensity = ((probs / (row_peak + 1e-8)) * 255).astype(np.uint8)
    canvas = cv2.applyColorMap(intensity, cv2.COLORMAP_VIRIDIS)  # BGR output

    rows = np.arange(n_steps)
    cols = np.array([int(gt_bins[t]) for t in range(n_steps)], dtype=np.intp)
    canvas[rows, cols] = (0, 0, 255)  # red in BGR; swapped to RGB below

    canvas = cv2.resize(canvas, (n_bins * cell_w, n_steps * cell_h),
                        interpolation=cv2.INTER_NEAREST)
    return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)


def project_world_to_pixel(xyz, K_pixel, extrinsic):
    """Project one world-frame point through a pinhole camera.

    Args:
        xyz: indexable with at least 3 entries (world X, Y, Z).
        K_pixel: (3, 3) intrinsic matrix in pixel units.
        extrinsic: (4, 4) camera-to-world transform; it is inverted here to go
            world -> camera.

    Returns:
        ``(u, v)`` pixel coordinates as Python floats, or ``None`` when the
        camera-frame depth is <= 1e-3 (behind or essentially on the camera plane).
    """
    cam_from_world = np.linalg.inv(extrinsic)
    homogeneous = np.array([xyz[0], xyz[1], xyz[2], 1.0])
    cam_point = (cam_from_world @ homogeneous)[:3]
    depth = cam_point[2]
    if depth <= 1e-3:
        return None
    u, v = (K_pixel @ (cam_point / depth))[:2]
    return (float(u), float(v))


def _denorm_rgb(rgb_t):
    mean = np.array([0.485, 0.456, 0.406])[:, None, None]
    std  = np.array([0.229, 0.224, 0.225])[:, None, None]
    return (rgb_t.cpu().numpy() * std + mean).clip(0, 1).transpose(1, 2, 0)


def _rainbow_color(t, T):
    """RGB colour for timestep ``t`` of ``T``: OpenCV hue swept linearly from 0 to 170."""
    hue = int(t / max(T - 1, 1) * 170)
    return cv2.cvtColor(np.uint8([[[hue, 255, 255]]]), cv2.COLOR_HSV2RGB)[0, 0].tolist()


def _project_points_to_pixels(points_xyz, K_pixel, extrinsic, fill=-100.0):
    """Project each world point to (u, v); points behind the camera keep ``fill``."""
    pix = np.full((len(points_xyz), 2), fill)
    for t, xyz in enumerate(points_xyz):
        uv = project_world_to_pixel(xyz, K_pixel, extrinsic)
        if uv is not None:
            pix[t] = uv
    return pix


def _draw_trajectory(img, pred_pix, gt_pix, T, image_size=None):
    """Draw pred (filled, rainbow over t) and GT (white ring) dots on ``img`` in place.

    When ``image_size`` is given, points outside ``[0, image_size)`` on either axis
    are skipped; this also drops the negative fill value used for behind-camera points.
    """
    def _inside(uv):
        return image_size is None or (0 <= uv[0] < image_size and 0 <= uv[1] < image_size)

    for t in range(T):
        if _inside(pred_pix[t]):
            cv2.circle(img, (int(pred_pix[t, 0]), int(pred_pix[t, 1])), 5,
                       _rainbow_color(t, T), -1)
        if _inside(gt_pix[t]):
            cv2.circle(img, (int(gt_pix[t, 0]), int(gt_pix[t, 1])), 5, (255, 255, 255), 2)


def _joint_pca_rgb(F_bev, F_wrist):
    """Map two (C, H, W) feature maps onto one shared top-3 PCA basis as [0, 1] RGB images.

    Returns ``(pca_bev, pca_wrist, explained_variance_ratio)``. Each view keeps its own
    spatial size, so the two maps need only share the channel count C.
    """
    fb = F_bev.cpu().numpy().transpose(1, 2, 0).reshape(-1, F_bev.shape[0])
    fw = F_wrist.cpu().numpy().transpose(1, 2, 0).reshape(-1, F_wrist.shape[0])
    joint = np.concatenate([fb, fw], axis=0)
    centred = joint - joint.mean(0, keepdims=True)
    _, sv, vt = np.linalg.svd(centred, full_matrices=False)
    ev = (sv ** 2) / (sv ** 2).sum()
    pcs = centred @ vt[:3].T
    # Robust per-component contrast stretch (2nd–98th percentile) shared by both views.
    lo, hi = np.percentile(pcs, [2, 98], axis=0)
    pcs_n = np.clip((pcs - lo) / (hi - lo + 1e-8), 0, 1)
    Hb, Wb = F_bev.shape[1], F_bev.shape[2]
    Hw, Ww = F_wrist.shape[1], F_wrist.shape[2]
    n_bev = Hb * Wb
    return pcs_n[:n_bev].reshape(Hb, Wb, 3), pcs_n[n_bev:].reshape(Hw, Ww, 3), ev


def _rotation_gt_bins(quats, pca_mean, pca_axis, pca_min, pca_max, n_bins):
    """Bin GT quaternions along the checkpoint's 1-D rotation PCA axis in xyz-Euler space.

    Quaternions are passed to SciPy's ``Rotation.from_quat`` with its default ordering.
    """
    from scipy.spatial.transform import Rotation

    eul = Rotation.from_quat(quats).as_euler('xyz')
    proj = (eul - pca_mean) @ pca_axis
    norm = (proj - pca_min) / max(pca_max - pca_min, 1e-8)
    return (np.clip(norm, 0, 1) * (n_bins - 1)).astype(np.int64)


def main():
    """Run the 2-view checkpoint on one cached sample and save the 3×2 diagnostic figure."""
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint", type=str,
                   default="/data/cameron/para/libero/checkpoints/libero_2view_libero_spatial_t0_v0/latest.pth")
    p.add_argument("--cache_root", type=str, default="/data/libero/parsed_libero_2view")
    p.add_argument("--task_id",    type=int, default=0)
    p.add_argument("--sample_idx", type=int, default=-1)
    p.add_argument("--out", type=str,
                   default="/data/cameron/para/paper/figs/generated/libero_2view_full_sample.png")
    args = p.parse_args()

    # Bin counts used to build the model (must match training).
    n_height_bins = 32
    n_gripper_bins = 32

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(security): weights_only=False unpickles arbitrary objects — only load trusted checkpoints.
    sd = torch.load(args.checkpoint, map_location=device, weights_only=False)
    n_window   = int(sd.get("n_window", 8))
    image_size = int(sd.get("image_size", 448))
    min_h, max_h = float(sd["min_height"]), float(sd["max_height"])
    min_g, max_g = float(sd["min_grip"]),   float(sd["max_grip"])
    bev_K_norm    = np.asarray(sd["bev_K_norm"],    dtype=np.float32)
    bev_extrinsic = np.asarray(sd["bev_extrinsic"], dtype=np.float32)
    rot_pca_mean = np.asarray(sd["rot_pca_mean"], dtype=np.float64)
    rot_pca_axis = np.asarray(sd["rot_pca_axis"], dtype=np.float64)
    rot_pca_min, rot_pca_max = float(sd["rot_pca_min"]), float(sd["rot_pca_max"])

    ds = CachedTrajectory2ViewDataset(
        cache_root=args.cache_root, task_ids=[args.task_id],
        image_size=image_size, n_window=n_window, frame_stride=3, max_demos=0,
    )
    sample_idx = int(torch.randint(0, len(ds), (1,)).item()) if args.sample_idx < 0 else args.sample_idx
    s = ds[sample_idx]
    print(f"Sample {sample_idx} from demo_{int(s['demo_idx'])}, start_t={int(s['start_t'])}")

    m = DinoVolumeQuery2View(
        n_window=n_window, n_height_bins=n_height_bins, n_gripper_bins=n_gripper_bins,
        n_rot_bins=int(sd["n_rot_bins"]),
        image_size=image_size, pred_size=PRED_SIZE, rotation_mode='1d_pca',
    ).to(device).eval()
    incompatible = m.load_state_dict(sd["model_state_dict"], strict=False)
    # strict=False would otherwise hide a checkpoint/model mismatch entirely.
    missing = getattr(incompatible, "missing_keys", [])
    unexpected = getattr(incompatible, "unexpected_keys", [])
    if missing or unexpected:
        print(f"[warn] load_state_dict(strict=False): {len(missing)} missing, "
              f"{len(unexpected)} unexpected keys")

    bev_xyz_table = build_bev_world_xyz_table(
        torch.tensor(bev_K_norm,    dtype=torch.float32, device=device),
        torch.tensor(bev_extrinsic, dtype=torch.float32, device=device),
        n_height_bins, min_h, max_h, PRED_SIZE, PRED_SIZE, image_size, device,
    )

    with torch.no_grad():
        rgb_bev_t   = s["rgb_bev"  ].unsqueeze(0).to(device)
        rgb_wrist_t = s["rgb_wrist"].unsqueeze(0).to(device)
        start_pix = s["trajectory_2d_bev"][0:1].to(device)
        wrist_K   = s["wrist_K_norm"].unsqueeze(0).to(device)
        wrist_ext = s["wrist_extrinsic"].unsqueeze(0).to(device)
        out = m(rgb_bev_t, rgb_wrist_t, start_pix, bev_xyz_table, wrist_K, wrist_ext)
        F_bev   = out["pixel_feats"      ][0]
        F_wrist = out["pixel_feats_wrist"][0]
        vol = out["volume_logits"][0]                          # (T, Z, H, W)
        grip = out["gripper_logits"][0]
        rot  = out["rotation_logits"][0]

    # Per-timestep argmax voxel → (z, y, x) grid indices.
    T, Z, Hg, Wg = vol.shape
    flat = vol.reshape(T, -1).argmax(dim=-1).cpu().numpy()
    z_bins, py_grid, px_grid = np.unravel_index(flat, (Z, Hg, Wg))

    # Voxel centres → BEV image pixels (separate x/y scales so non-square grids also work).
    pred_pix_bev = np.stack([(px_grid + 0.5) * (image_size / Wg),
                             (py_grid + 0.5) * (image_size / Hg)], axis=-1)
    gt_pix_bev   = s["trajectory_2d_bev"].numpy()              # (T, 2)

    # Pred world XYZ per t via bev_xyz_table (Z, H, W, 3).
    pred_xyz = bev_xyz_table.cpu().numpy()[z_bins, py_grid, px_grid]
    gt_xyz   = s["trajectory_3d"].numpy()

    # Project pred + GT through the wrist camera (normalised K → pixel K).
    wrist_K_pix = s["wrist_K_norm"].numpy().astype(np.float64)
    wrist_K_pix[:2] *= image_size
    wrist_ext_np = s["wrist_extrinsic"].numpy().astype(np.float64)
    pred_pix_wrist = _project_points_to_pixels(pred_xyz,   wrist_K_pix, wrist_ext_np)
    gt_pix_wrist   = _project_points_to_pixels(gt_xyz[:T], wrist_K_pix, wrist_ext_np)

    rgb_b = (_denorm_rgb(rgb_bev_t  [0]) * 255).astype(np.uint8).copy()
    rgb_w = (_denorm_rgb(rgb_wrist_t[0]) * 255).astype(np.uint8).copy()
    _draw_trajectory(rgb_b, pred_pix_bev,   gt_pix_bev,   T)
    _draw_trajectory(rgb_w, pred_pix_wrist, gt_pix_wrist, T, image_size=image_size)

    pca_b, pca_w, ev = _joint_pca_rgb(F_bev, F_wrist)

    # GT bins for piano rolls (bin counts taken from the model outputs).
    n_grip_bins = grip.shape[-1]
    gg = ((s["trajectory_gripper"] - min_g) / max(max_g - min_g, 1e-8)).clamp(0, 1)
    gg_bin = (gg * (n_grip_bins - 1)).long().numpy()
    gr_bin = _rotation_gt_bins(s["trajectory_quat"].numpy(), rot_pca_mean, rot_pca_axis,
                               rot_pca_min, rot_pca_max, rot.shape[-1])

    pr_grip = piano_roll(grip.cpu(), gg_bin)
    pr_rot  = piano_roll(rot.cpu(),  gr_bin)

    fig, axes = plt.subplots(3, 2, figsize=(11.5, 13.5),
                             gridspec_kw={'hspace': 0.18, 'wspace': 0.05,
                                          'height_ratios': [1.0, 1.0, 0.55]})
    axes[0, 0].imshow(rgb_b); axes[0, 0].set_title("agentview (BEV) — rainbow pred, white GT", fontsize=10)
    axes[0, 1].imshow(rgb_w); axes[0, 1].set_title("robot0_eye_in_hand — projected pred + GT", fontsize=10)
    axes[1, 0].imshow(pca_b, interpolation='nearest'); axes[1, 0].set_title(f"F_bev 3-PCA (joint, EV={ev[:3].sum():.0%})", fontsize=10)
    axes[1, 1].imshow(pca_w, interpolation='nearest'); axes[1, 1].set_title("F_wrist 3-PCA (same basis)", fontsize=10)
    axes[2, 0].imshow(pr_grip, interpolation='nearest'); axes[2, 0].set_title(f"gripper bin (T rows × {n_grip_bins} cols, GT red)", fontsize=10)
    axes[2, 1].imshow(pr_rot,  interpolation='nearest'); axes[2, 1].set_title(f"rotation 1D-PCA bin (T × {rot.shape[-1]}, GT red)", fontsize=10)
    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])
    fig.suptitle(f"libero_2view ckpt (ep {sd.get('epoch', '?')}) — sample {sample_idx} from demo_{int(s['demo_idx'])} t={int(s['start_t'])}", fontsize=11)
    fig.tight_layout()
    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(args.out, dpi=150, bbox_inches='tight', facecolor='white')
    plt.close(fig)
    print(f"✓ Saved {args.out}")


if __name__ == "__main__":
    # Execute only when run as a script, never on import.
    main()
