"""Extract a concrete inference example for the figure_maker.

Loads dino_kv_hsin_tsin (S/16+ sin/sin baseline, train_pix=9.32), picks a training
sample with all 8 future timesteps valid and a wide-but-typical GT trajectory (about the
80th percentile of GT pixel span across such samples), runs a forward pass, and dumps
everything needed to draw the volume mechanism figure.

Output: /data/cameron/para/paper/figs/data/volume_kv_example.npz with keys:
  rgb              : (3, 504, 504) float32  - the input image
  gt_pix_504       : (8, 2) float32         - GT EEF pixel per future timestep in 504-space
  gt_pix_grid      : (8, 2) float32         - gt_pix_504 * (56/504), i.e. GT pixel in 56-grid space
  gt_z_bin         : (8,) int64             - GT z-bin per timestep (in 32-bin space)
  valid            : (8,) bool              - True for all (we picked a full-8 sample)
  height_meters    : (32,) float32          - world-Z value at each bin center (from dataset)
  volume_logits    : (8, 32, 56, 56) float32 - raw output logits (T, Z, H_out, W_out)
  volume_softmax   : (8, 32, 56, 56) float32 - softmax over (Z, H, W) per timestep
  argmax_voxel     : (8, 3) int64           - (z*, v*, u*) per timestep at 56-grid scale
  pixel_feats_raw  : (48, 56, 56) float32   - the model's per-pixel features as output
  pixel_feats_unit : (48, 56, 56) float32   - F = pixel_norm(features), then L2-normalised per pixel
  keys_unit        : (8, 32, 48) float32    - key(t, z) = t_emb[t] + h_emb[z] after L2-norm
  dino_patches     : (28, 28, 384) float32  - raw DINO patch tokens (for PCA inset)
  dino_pca_rgb     : (28, 28, 3) float32 in [0,1] - precomputed 3-PCA RGB feature viz
  meta             : 0-d str array          - JSON text; decode with json.loads(str(d["meta"])).
                     Holds sample_idx, ckpt_path / ckpt_epoch, grid and embedding sizes,
                     img_size, scale_504_to_grid (= 56/504; multiply 504-coords by it to get
                     grid coords), logit_scale_exp (= exp(temperature) used in scoring), and
                     the min/max height range in metres.
"""
import os, sys, json
sys.path.insert(0, "/data/cameron/para/libero")
sys.path.insert(0, "/data/cameron/keygrip/dinov3")
sys.path.insert(0, "/data/cameron/da3_repo/src")
import numpy as np
import torch
import torch.nn.functional as F
from data_da3_volume import Smith300DA3VolumeDataset, DA3_INPUT, N_WINDOW
from model_dino_volume_kv import DinoVolumeKV

CKPT = "/data/cameron/para/libero/checkpoints/da3_dino_kv_hsin_tsin/latest.pth"
OUT  = "/data/cameron/para/paper/figs/data/volume_kv_example.npz"
# Prefer the GPU, but fall back to CPU so the extraction still runs on machines
# without CUDA (the script only does a single model forward pass).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset, reading depth from the "da3_depth_large" subdirectory.
print("Loading dataset…")
ds = Smith300DA3VolumeDataset(depth_subdir="da3_depth_large")

# Model with sin height / sin time encodings (the "hsin_tsin" in the checkpoint
# directory name), moved to `device` and put in eval mode before loading weights.
print("Loading model…")
m = DinoVolumeKV(height_enc='sin', time_enc='sin')
m.to(device)
m.eval()
# NOTE: weights_only=False unpickles arbitrary objects; acceptable only because
# CKPT is a locally produced, trusted checkpoint.
sd = torch.load(CKPT, map_location=device, weights_only=False)
m.load_state_dict(sd['model_state_dict'])
ckpt_epoch_str, ckpt_val_v_str = sd.get('epoch', '?'), sd.get('val_v', '?')
print(f"  ckpt epoch={ckpt_epoch_str}, val_v={ckpt_val_v_str}")

# Pick a sample with all 8 timesteps valid and a wide pixel span (interesting trajectory).
# Span = length of the diagonal of the GT pixel track's bounding box, in 504-space.
# Each ds[i] returns the full sample (including 'rgb'), so this scan can be slow.
candidates = []
for i in range(len(ds)):
    s = ds[i]
    if int(s['gt_pix_valid'].sum()) == 8:
        span = (s['gt_pix_504'].max(0).values - s['gt_pix_504'].min(0).values).norm().item()
        candidates.append((span, i))
if not candidates:
    raise RuntimeError("No sample with all 8 GT timesteps valid; cannot pick an example.")
# Rank widest-first. Rather than the absolute widest (often atypical), take the
# sample one fifth of the way down the ranking, i.e. around the 80th percentile
# of span: "interesting but typical".
candidates.sort(reverse=True)
pick_span, sid = candidates[len(candidates) // 5]
print(f"Picked sample {sid} (span={pick_span:.1f}px)")

# Fetch the chosen sample and unpack the fields the figure needs.
# Shapes: rgb (3, 504, 504); gt_pix (8, 2); gt_z (8,); valid (8,)
s = ds[sid]
rgb, gt_pix, gt_z, valid = (
    s[field] for field in ('rgb', 'gt_pix_504', 'gt_z_bin', 'gt_pix_valid')
)
for report_line in (f"  gt_pix: \n{gt_pix.numpy()}", f"  gt_z:   {gt_z.numpy()}"):
    print(report_line)

import math

# Single forward pass on the chosen image (batch of one), then re-derive the
# unit-norm keys and per-pixel features the figure visualises.
with torch.no_grad():
    out = m(rgb.unsqueeze(0).to(device))
    # Strip the batch dimension from each model output.
    vol = out['volume_logits'][0]                    # (T, Z, h, w) = (8, 32, 56, 56)
    pixel_feats = out['pixel_feats'][0]              # (48, 56, 56)
    dino_patches = out['dino_feats'][0]              # (1, N=784, 384) for S/16+@448 → 28x28

    # L2-normalise each key along its last dim; the eps guards against /0.
    keys = m._build_keys()                           # (T, Z, 48)
    keys_unit = keys / (keys.norm(dim=-1, keepdim=True) + 1e-6)

    # Clamp log-scale at log(100) before exponentiating.
    max_log_scale = math.log(100.0)
    logit_scale = m.logit_scale.clamp(max=max_log_scale).exp().item()
    print(f"  logit_scale (exp(temp)) = {logit_scale:.3f}")

    # pixel_norm is applied channels-last; permute back to channels-first and
    # L2-normalise across the channel dim.
    channels_last = pixel_feats.permute(1, 2, 0)
    f_ln = m.pixel_norm(channels_last).permute(2, 0, 1)
    f_unit = f_ln / (f_ln.norm(dim=0, keepdim=True) + 1e-6)  # (48, 56, 56)

# Per-timestep softmax over the joint (Z, H, W) volume, plus each timestep's
# peak-logit location unravelled back into (z, y, x) grid indices.
T, Z, H_, W_ = vol.shape
vol_soft = vol.flatten(1).softmax(dim=-1).reshape(T, Z, H_, W_)
peak_flat = vol.flatten(1).argmax(dim=-1).cpu().numpy()
argmax_voxel = np.stack(np.unravel_index(peak_flat, (Z, H_, W_)), axis=-1).astype(np.int64)
print(f"  argmax voxels: {argmax_voxel}")

# Per-bin world-Z values for labeling
bin_centers = ds.bin_centers.cpu().numpy()           # (32,) in metres

# DINO patches → PCA RGB inset (28x28)
patches_2d = dino_patches[0].cpu().float().numpy()   # (N=784, 384)
n_patches = patches_2d.shape[0]
side = int(round(np.sqrt(n_patches)))                 # 28
if side * side != n_patches:
    # A non-square token count (e.g. if non-patch tokens were not stripped)
    # cannot be laid out as a square grid; fail with a clear message instead
    # of a cryptic reshape error further down.
    raise ValueError(f"expected a square DINO patch grid, got {n_patches} tokens")
patches_c = patches_2d - patches_2d.mean(axis=0, keepdims=True)
# Rows of vt are the principal directions of the centred tokens.
_, _, vt = np.linalg.svd(patches_c, full_matrices=False)
pcs = patches_c @ vt[:3].T                            # (N, 3)
# Min-max each component to [0, 1] independently so it can be shown as RGB.
pcs = (pcs - pcs.min(0)) / (pcs.max(0) - pcs.min(0) + 1e-8)
dino_pca_rgb = pcs.reshape(side, side, 3).astype(np.float32)

scale_504_to_grid = H_ / DA3_INPUT
gt_pix_grid = gt_pix.numpy() * scale_504_to_grid       # (8, 2) in 56-space
print(f"  gt_pix_grid (56-space): \n{gt_pix_grid}")

# Sanity check: does the volume's peak land near GT for each timestep?
# GT pixels are truncated to integer 56-grid cells; the distance ignores z.
print("\nGT vs argmax check (in 56-grid coords):")
for t in range(T):
    gt_u_grid, gt_v_grid = (int(coord) for coord in gt_pix_grid[t, :2])
    am_z, am_y, am_x = argmax_voxel[t]
    planar_dist = ((am_x - gt_u_grid) ** 2 + (am_y - gt_v_grid) ** 2) ** 0.5
    print(f"  t={t}: GT=(z={gt_z[t].item():2d},y={gt_v_grid:2d},x={gt_u_grid:2d}) "
          f"argmax=(z={am_z:2d},y={am_y:2d},x={am_x:2d})  Δpx={planar_dist:.1f}")

# Run metadata, stored JSON-encoded under the npz's `meta` key. Every value must
# be a plain Python type for json.dumps (hence the int()/float()/tolist() casts).
meta = {
    "sample_idx":       int(sid),
    "ckpt_path":        CKPT,
    "ckpt_epoch":       int(sd.get('epoch', -1)),
    "n_window":         T,
    "n_height_bins":    Z,
    "pred_size":        H_,
    "dino_grid":        int(side),
    "dino_embed_dim":   int(patches_2d.shape[-1]),
    "key_dim":          int(keys_unit.shape[-1]),
    "img_size":         DA3_INPUT,
    "scale_504_to_grid": float(scale_504_to_grid),
    "logit_scale_exp":  float(logit_scale),
    "min_height_m":     float(ds.min_height),
    "max_height_m":     float(ds.max_height),
    # Per-timestep peak (argmax) voxel as [z, y, x] on the 56-grid, and the GT
    # z-bin, so the figure can annotate peaks straight from the metadata.
    "argmax_voxel_zyx": argmax_voxel.tolist(),
    "gt_z_bin":         gt_z.tolist(),
}

os.makedirs(os.path.dirname(OUT), exist_ok=True)
np.savez_compressed(
    OUT,
    rgb=rgb.numpy().astype(np.float32),
    gt_pix_504=gt_pix.numpy().astype(np.float32),
    gt_pix_grid=gt_pix_grid.astype(np.float32),
    gt_z_bin=gt_z.numpy().astype(np.int64),
    valid=valid.numpy().astype(bool),
    height_meters=bin_centers.astype(np.float32),
    volume_logits=vol.cpu().float().numpy().astype(np.float32),
    volume_softmax=vol_soft.cpu().float().numpy().astype(np.float32),
    argmax_voxel=argmax_voxel.astype(np.int64),
    pixel_feats_raw=pixel_feats.cpu().float().numpy().astype(np.float32),
    pixel_feats_unit=f_unit.cpu().float().numpy().astype(np.float32),
    keys_unit=keys_unit.cpu().float().numpy().astype(np.float32),
    dino_patches=patches_2d.reshape(side, side, -1).astype(np.float32),
    dino_pca_rgb=dino_pca_rgb,
    meta=json.dumps(meta),
    # Scalars are also exposed as top-level 0-d arrays (besides inside `meta`)
    # so consumers can read e.g. d["logit_scale"] directly.
    logit_scale=np.float64(logit_scale),
    sample_idx=np.int64(sid),
    scale_504_to_grid=np.float64(scale_504_to_grid),
)
print(f"\nSaved: {OUT}")
print(f"File size: {os.path.getsize(OUT)/1e6:.1f} MB")
