"""Extract a concrete inference example for the query-MLP architecture figure.

Loads dino_query_izzy3_t50_pca1d_v0 (latest.pth checkpoint, EEF+CLS query-MLP, 1D PCA
rotation), picks an interesting training sample, runs a forward pass, and renders the
panel PNGs the figure_maker needs.

Outputs in /data/cameron/para/paper/figs/data/query_arch/:
  example.npz             — raw tensors for layout-level scripting
  rgb.png                 — input RGB (504×504)
  f_pca.png               — top-3 PCA components of the F feature map as RGB (56×56, native grid, not upsampled)
  f_pca_eef.png           — F PCA with the EEF start pixel marked in red
  feature_volume.png      — sparse 3D scatter of voxels colored by F PCA
  prob_volume.png         — top-25 subsampled voxels by softmax probability at one timestep t*
  prob_volume_argmax.png  — prob_volume plus the full-resolution argmax voxel at t* (green)
  arch_overview.png       — debug composite (all panels in one image)
"""
import os, sys, json, math
sys.path.insert(0, "/data/cameron/para/libero")
sys.path.insert(0, "/data/cameron/keygrip/dinov3")
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap, ListedColormap
from pathlib import Path

# DINOv3 repo / weights locations. setdefault: values already present in the environment win.
# Set BEFORE the project imports below — NOTE(review): presumably those modules read these
# variables at import time; confirm before reordering.
os.environ.setdefault("DINO_REPO_DIR",     "/data/cameron/keygrip/dinov3")
os.environ.setdefault("DINO_WEIGHTS_PATH", "/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth")

# Project-local modules, importable thanks to the sys.path entries inserted above.
from data_da3_volume import Smith300DA3VolumeDataset, DA3_INPUT
from model_dino_volume_query import DinoVolumeQuery, IMG_SIZE, PRED_SIZE

# Checkpoint whose 'model_state_dict' is loaded into the query-MLP model.
CKPT = "/data/cameron/para/libero/checkpoints/dino_query_izzy3_t50_pca1d_v0/latest.pth"
# Rotation PCA basis, passed to the dataset as rot_pca_path.
PCA  = "/data/cameron/para/libero/rotation_pca_basis_izzy3.npz"
# Every PNG / NPZ this script writes lands in OUT_DIR (created eagerly at import).
OUT_DIR = Path("/data/cameron/para/paper/figs/data/query_arch")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Run on GPU when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# ──────────────────────────── 1. Dataset + model ────────────────────────────
print("Loading dataset (izzy3, n_window=50, 1D PCA rotation)...")
ds = Smith300DA3VolumeDataset(
    sessions_whitelist=['izzy_home_recording_3'],
    n_window=50, frame_stride=1,
    rot_pca_path=PCA,
)
print(f"  → {len(ds)} samples")

print("Loading query-MLP model (rotation_mode=1d_pca)...")
m = DinoVolumeQuery(
    n_window=50, n_height_bins=32, n_gripper_bins=32, n_rot_bins=32,
    image_size=IMG_SIZE, pred_size=PRED_SIZE,
    use_eef=True, rotation_mode='1d_pca',
).to(device).eval()
# Deserialize on CPU: the checkpoint is a dict carrying more than the weights (at least
# 'epoch'; possibly optimizer state too), and load_state_dict copies each tensor onto the
# model's device anyway, so nothing extra needs to occupy GPU memory.
# weights_only=False unpickles arbitrary Python objects — only point CKPT at trusted files.
sd = torch.load(CKPT, map_location="cpu", weights_only=False)
# strict=False tolerates key mismatches; list the offending keys so a partially-loaded
# model (and therefore a misleading figure) is visible in the log, not just a count.
missing, unexpected = m.load_state_dict(sd['model_state_dict'], strict=False)
print(f"  loaded ckpt epoch={sd['epoch']}; missing={len(missing)}, unexpected={len(unexpected)}")
if missing:
    print(f"    missing keys (first 5): {list(missing)[:5]}")
if unexpected:
    print(f"    unexpected keys (first 5): {list(unexpected)[:5]}")


# ──────────────────────────── 2. Pick a clean sample ────────────────────────────
# Want: visible trajectory span (not stationary), several real future steps (not all padded),
# clean scene composition.
print("Selecting sample with wide trajectory + many valid futures...")
MIN_VALID_STEPS = 15                      # need at least this many real future steps
candidates = []
for i in range(len(ds)):
    s = ds[i]
    n_valid = int(s['gt_pix_valid'].sum())
    if n_valid < MIN_VALID_STEPS:
        continue
    # Extent of the REAL future steps only, so padded entries cannot inflate the span.
    # Assumes gt_pix_valid is a per-step (n_window,) mask aligned with gt_pix_504's rows.
    real_pix = s['gt_pix_504'][s['gt_pix_valid'].bool()]
    span = (real_pix.max(0).values - real_pix.min(0).values).norm().item()
    candidates.append((span, n_valid, i))
if not candidates:
    raise RuntimeError(f"no sample in the dataset has >= {MIN_VALID_STEPS} valid future steps")
candidates.sort(reverse=True)             # widest spans first
# Index len//10 of the descending list = median of the top-20% widest spans:
# visually interesting but not extreme.
chosen_span, chosen_nvalid, sid = candidates[len(candidates) // 10]
print(f"  → sample {sid}, span={chosen_span:.1f}px, n_valid={chosen_nvalid}")
s = ds[sid]

rgb       = s['rgb']                            # (3, 504, 504), ImageNet-normalised
gt_pix504 = s['gt_pix_504'].numpy()             # (50, 2)
gt_z      = s['gt_z_bin'].numpy()               # (50,)
gt_grip   = s['gt_grip_bin'].numpy()
gt_rot    = s['gt_rot_bin'].numpy()
valid     = s['gt_pix_valid'].numpy()
start_pix = s['start_pix_504']                  # (2,)


# ──────────────────────────── 3. Forward pass ────────────────────────────
batch_rgb = rgb.unsqueeze(0).to(device)
batch_start = start_pix.unsqueeze(0).to(device)
with torch.no_grad():
    out = m(batch_rgb, start_pix=batch_start)
    vol = out['volume_logits'][0].cpu()        # (T, Z, H, W)
    F_feat = out['pixel_feats'][0].cpu()       # (32, H, W)

T, Z, H, W = vol.shape
print(f"  vol={tuple(vol.shape)}  F={tuple(F_feat.shape)}")

# Per-t probability volume: one softmax over all Z*H*W voxels of each timestep.
vol_flat = vol.reshape(T, -1)
prob_flat = torch.softmax(vol_flat, dim=-1)
prob = prob_flat.reshape(T, Z, H, W)           # (T, Z, H, W)

# Per-t most likely voxel, unravelled from the flat (Z, H, W) index into separate z / y / x.
best_flat = vol_flat.argmax(dim=-1).numpy()    # (T,)
am_z, am_y, am_x = np.unravel_index(best_flat, (Z, H, W))
print(f"  argmax voxels (z, y, x) per-t (first 5): {list(zip(am_z, am_y, am_x))[:5]}")


# ──────────────────────────── 4. F → 3-PCA RGB image ────────────────────────────
print("Computing F-PCA (32 → 3) for feature viz...")
n_ch = F_feat.shape[0]
feat_rows = F_feat.numpy().reshape(n_ch, -1).T                 # (H*W, C): one row per pixel
feat_rows = feat_rows - feat_rows.mean(axis=0, keepdims=True)  # centre every channel
_, sing_vals, basis = np.linalg.svd(feat_rows, full_matrices=False)
pcs = feat_rows @ basis[:3].T                                  # (H*W, 3) scores on the top-3 PCs
# Robust per-component normalisation: 2nd / 98th percentiles map to 0 / 1, then clip.
lo, hi = np.percentile(pcs, [2, 98], axis=0)
F_rgb = np.clip((pcs - lo) / (hi - lo + 1e-8), 0, 1).reshape(H, W, 3)   # (56, 56, 3)
ev = sing_vals ** 2 / (sing_vals ** 2).sum()                   # explained-variance ratios
print(f"  F_PCA EV: PC1={ev[0]:.3f} PC2={ev[1]:.3f} PC3={ev[2]:.3f}")


# ──────────────────────────── 5. Pick the timestep to show ────────────────────────────
# Want a mid-motion step early in the real trajectory (roughly its first third, skipping
# the first couple of steps) whose softmax is clearly localised, i.e. the top-1 voxel
# holds a meaningful fraction of the mass. Search that window for the most peaked step,
# rather than rescaling the index of the globally most peaked step (which lands on a
# step whose localisation was never checked).
peak_frac = prob_flat.max(dim=-1).values                                   # (T,)
T_MIN = 2
t_hi = max(T_MIN + 1, min(chosen_nvalid, T) // 3 + 1)                      # exclusive window end
T_STAR = T_MIN + int(peak_frac[T_MIN:t_hi].argmax().item())
print(f"  T_STAR={T_STAR}  peak_frac={peak_frac[T_STAR].item():.4f}")


# ──────────────────────────── 6. De-normalise RGB ────────────────────────────
# Undo ImageNet normalisation (per-channel mean/std, CHW layout), then go to HWC uint8.
imagenet_mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
imagenet_std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
rgb_chw = rgb.numpy() * imagenet_std + imagenet_mean
rgb_img = np.clip(rgb_chw, 0, 1).transpose(1, 2, 0)
rgb_uint8 = (255 * rgb_img).astype(np.uint8)


# ──────────────────────────── 7. Render PNGs ────────────────────────────
def save_png(name, arr_or_fig, dpi=200, transparent=True):
    """Write *arr_or_fig* to ``OUT_DIR / name``, log the file name, and return the path.

    Anything exposing ``savefig`` (a matplotlib Figure) is saved tightly cropped with no
    padding at *dpi* / *transparent* and then closed. Anything else is treated as an
    image array and written with ``plt.imsave``; *dpi* and *transparent* are ignored
    for arrays.
    """
    out_path = OUT_DIR / name
    is_figure = hasattr(arr_or_fig, 'savefig')
    if not is_figure:
        plt.imsave(out_path, arr_or_fig)
    else:
        arr_or_fig.savefig(out_path, dpi=dpi, bbox_inches='tight',
                           transparent=transparent, pad_inches=0.0)
        plt.close(arr_or_fig)   # free the figure; scripts render many of them
    print(f"  saved {out_path.name}")
    return out_path


print("Rendering panels...")

# 7.1 — RGB and 7.2 — F PCA: plain image arrays, so save_png writes them via plt.imsave.
save_png("rgb.png", rgb_uint8)
save_png("f_pca.png", (F_rgb * 255).astype(np.uint8))

# 7.3 — F PCA with the EEF start pixel marked (filled red dot inside a red ring).
# start_pix is rescaled by (grid size / DA3_INPUT) onto the (H, W) feature grid —
# assumes start_pix is in DA3_INPUT-resolution pixels (it comes from start_pix_504).
sp_x_grid = float(start_pix[0]) * (W / DA3_INPUT)
sp_y_grid = float(start_pix[1]) * (H / DA3_INPUT)
fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(F_rgb, interpolation='nearest')
ax.set_xticks([]); ax.set_yticks([])
marker_specs = (
    dict(s=140, edgecolor='white', facecolor='red'),    # filled dot
    dict(s=520, edgecolor='red', facecolor='none'),     # surrounding ring
)
for spec in marker_specs:
    ax.scatter([sp_x_grid], [sp_y_grid], linewidth=2.0, zorder=10, **spec)
for spine in ax.spines.values():
    spine.set_visible(False)
fig.tight_layout(pad=0)
save_png("f_pca_eef.png", fig, dpi=200, transparent=True)


# 7.4 — Sparse 3D voxel scatter, colored by F PCA
def render_volume_3d(colors_zyx, alphas_zyx, out_name,
                     stride_xy=4, stride_z=4, marker_size=22,
                     elev=22, azim=-60):
    """Render a strided subset of a voxel grid as a 3D scatter and save it via save_png.

    Args:
        colors_zyx: (Z, H, W, 3) per-voxel RGB in [0, 1].
        alphas_zyx: (Z, H, W) per-voxel opacity in [0, 1]. Its shape defines the grid
            extent and the axis limits (no module-level sizes are consulted).
        out_name: file name written under OUT_DIR.
        stride_xy: keep every stride_xy-th voxel along image y and x.
        stride_z: keep every stride_z-th voxel along height z.
        marker_size: scatter marker area (points**2).
        elev, azim: camera angles passed to view_init.
    """
    n_z, n_y, n_x = np.shape(alphas_zyx)[:3]
    # Strided voxel indices flattened z-major (z outermost, x innermost).
    zz, yy, xx = (g.ravel() for g in np.meshgrid(np.arange(0, n_z, stride_z),
                                                  np.arange(0, n_y, stride_xy),
                                                  np.arange(0, n_x, stride_xy),
                                                  indexing='ij'))
    cols = np.asarray(colors_zyx)[zz, yy, xx]                            # (N, 3)
    alphs = np.asarray(alphas_zyx)[zz, yy, xx].astype(np.float64)        # (N,)
    rgba = np.concatenate([cols, alphs[:, None]], axis=-1)               # (N, 4)

    fig = plt.figure(figsize=(5.5, 5.5))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(xx, yy, zz, c=rgba, s=marker_size, marker='o', depthshade=False,
               edgecolors='none')
    # Image-aligned axes — y runs n_y → 0 so the image's top row sits at the "back".
    ax.set_xlim(0, n_x); ax.set_ylim(n_y, 0); ax.set_zlim(0, n_z)
    ax.set_xlabel('image x'); ax.set_ylabel('image y'); ax.set_zlabel('height z')
    ax.set_xticks([]); ax.set_yticks([]); ax.set_zticks([])
    ax.view_init(elev=elev, azim=azim)
    ax.grid(False)
    # Transparent panes so only the voxels show on a transparent background.
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color((1, 1, 1, 0))
    fig.tight_layout(pad=0.1)
    save_png(out_name, fig, dpi=200, transparent=True)


# Feature volume — every voxel in a (y, x) column takes that pixel's F-PCA colour; alpha is uniform.
feature_colors = np.repeat(F_rgb[None], Z, axis=0)          # (Z, H, W, 3), independent copy per z
feature_alphas = np.full((Z, H, W), 0.55, dtype=np.float32)
render_volume_3d(
    feature_colors, feature_alphas, "feature_volume.png",
    stride_xy=5, stride_z=4, marker_size=26,
)

# Probability volume at T_STAR — show ONLY the top-K voxels by probability,
# coloured bright plasma. Rest invisible. This makes the heatmap concentration
# pop on a white/transparent background.
p_t = prob[T_STAR].numpy()                                                   # (Z, H, W)
# pyplot.get_cmap works across Matplotlib versions; matplotlib.cm.get_cmap was
# deprecated in 3.7 and removed in 3.9.
plasma = plt.get_cmap('plasma')
# Candidate pool = the SUBSAMPLED voxel grid with the same strides as the render call
# below, so every top-K voxel chosen here lands on a voxel that actually gets drawn.
stride_xy_p = 4; stride_z_p = 3
voxel_pool = [
    (zi, yi, xi, p_t[zi, yi, xi])
    for zi in range(0, Z, stride_z_p)
    for yi in range(0, H, stride_xy_p)
    for xi in range(0, W, stride_xy_p)
]
voxel_pool.sort(key=lambda c: -c[3])          # highest probability first; stable on ties
TOPN = 25
topn = voxel_pool[:TOPN]
# Sparse colour/alpha volumes over the full grid; only the top-N entries get non-zero alpha.
prob_colors = np.zeros((Z, H, W, 3), dtype=np.float32)
prob_alphas = np.zeros((Z, H, W),    dtype=np.float32)
# Rescale the top-N probabilities to [0, 1] for the colormap (guarding a flat top-N).
p_vals = np.array([c[3] for c in topn])
p_min, p_max = float(p_vals.min()), float(p_vals.max())
p_range = max(p_max - p_min, 1e-12)
for (zi, yi, xi, pv) in topn:
    frac = (pv - p_min) / p_range
    prob_colors[zi, yi, xi] = plasma(frac)[:3]
    prob_alphas[zi, yi, xi] = float(0.65 + 0.35 * frac)   # hotter voxels are more opaque
render_volume_3d(prob_colors, prob_alphas, "prob_volume.png",
                 stride_xy=stride_xy_p, stride_z=stride_z_p, marker_size=70)

# 7.5 — Argmax marker volume: the same top-K probability voxels as prob_volume.png
# (same strides, so the hot voxels line up) plus the full-resolution argmax voxel at T_STAR.
# `~(<= 0.01)` rather than `> 0.01` keeps the original keep/skip rule exactly (incl. NaN).
visible = ~(prob_alphas[::stride_z_p, ::stride_xy_p, ::stride_xy_p] <= 0.01)   # skip empty voxels
sub_z, sub_y, sub_x = np.nonzero(visible)       # row-major: z, then y, then x scan order
pts_z = sub_z * stride_z_p
pts_y = sub_y * stride_xy_p
pts_x = sub_x * stride_xy_p
cols = prob_colors[pts_z, pts_y, pts_x]
alphs = prob_alphas[pts_z, pts_y, pts_x].astype(np.float64)

fig = plt.figure(figsize=(5.5, 5.5))
ax = fig.add_subplot(111, projection='3d')
if pts_x.size > 0:
    rgba = np.concatenate([cols, alphs[:, None]], axis=-1)
    ax.scatter(pts_x, pts_y, pts_z, c=rgba, s=70, marker='o', depthshade=False, edgecolors='none')
# Argmax voxel (full resolution, not snapped to the stride grid) in bright green
ax.scatter([am_x[T_STAR]], [am_y[T_STAR]], [am_z[T_STAR]],
           s=300, marker='o', c=[(0.2, 1.0, 0.3, 1.0)], edgecolors='black', linewidths=2.0,
           depthshade=False, zorder=20)
# Image-aligned axes; y runs H → 0 so the image's top row sits at the back.
ax.set_xlim(0, W); ax.set_ylim(H, 0); ax.set_zlim(0, Z)
ax.set_xlabel('image x'); ax.set_ylabel('image y'); ax.set_zlabel('height z')
ax.set_xticks([]); ax.set_yticks([]); ax.set_zticks([])
ax.view_init(elev=22, azim=-60)
ax.grid(False)
for pane_axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    pane_axis.set_pane_color((1, 1, 1, 0))
fig.tight_layout(pad=0.1)
save_png("prob_volume_argmax.png", fig, dpi=200, transparent=True)


# 7.6 — Debug composite (all panels in one) for quick visual sanity
fig, axes = plt.subplots(2, 3, figsize=(13, 9))

# Top row: input image and the F-PCA map twice (EEF marker / argmax pixel overlays).
top_row = (
    (rgb_uint8, 'rgb'),
    (F_rgb, 'F (DINO PCA)'),
    (F_rgb, f'argmax pixel @ t={T_STAR}'),
)
for panel_ax, (panel_img, panel_title) in zip(axes[0], top_row):
    panel_ax.imshow(panel_img)
    panel_ax.set_title(panel_title)
    panel_ax.axis('off')
# EEF start pixel on the F panel: filled red dot inside a red ring
axes[0, 1].scatter([sp_x_grid], [sp_y_grid], s=140, edgecolor='white', facecolor='red')
axes[0, 1].scatter([sp_x_grid], [sp_y_grid], s=520, edgecolor='red', facecolor='none', linewidth=2)
# Argmax voxel projected into the 2D F coords (its z is dropped)
axes[0, 2].scatter([am_x[T_STAR]], [am_y[T_STAR]], s=240, edgecolor='black', facecolor='lime')

# Bottom row: the 3D volume renders, read back from disk
volume_pngs = ('feature_volume.png', 'prob_volume.png', 'prob_volume_argmax.png')
for panel_ax, png_name in zip(axes[1], volume_pngs):
    panel_ax.imshow(plt.imread(str(OUT_DIR / png_name)))
    panel_ax.set_title(png_name.replace('.png', ''))
    panel_ax.axis('off')

fig.suptitle(f'Query-MLP architecture figure intermediates (sample {sid}, t*={T_STAR})')
fig.tight_layout()
save_png("arch_overview.png", fig, dpi=150, transparent=False)


# ──────────────────────────── 8. NPZ dump ────────────────────────────
# Scalar metadata, JSON-encoded into the archive's `meta` entry.
meta = {
    "sample_idx": sid,
    "n_valid": int(chosen_nvalid),
    "span_px": float(chosen_span),
    "T_star": T_STAR,
    "Z": Z, "H": H, "W": W, "T": T,
    "ckpt": CKPT,
    "rotation_mode": "1d_pca",
    "peak_frac_at_T_star": float(peak_frac[T_STAR].item()),
}
npz_path = OUT_DIR / "example.npz"
np.savez(
    npz_path,
    rgb=rgb.numpy(),
    gt_pix_504=gt_pix504, gt_z=gt_z, gt_grip=gt_grip, gt_rot=gt_rot,
    valid=valid,
    start_pix_504=start_pix.numpy(),
    volume_logits=vol.numpy(),
    volume_softmax=prob.numpy(),
    argmax_z=am_z, argmax_y=am_y, argmax_x=am_x,
    pixel_feats=F_feat.numpy(),
    F_pca_rgb=F_rgb,
    T_star=T_STAR,
    sample_idx=sid,
    meta=json.dumps(meta),
)
print(f"Saved {npz_path}")

print(f"\n✓ All intermediates in {OUT_DIR}")
