"""Quick scratch renders for the volume model.

Loads the pre-extracted example (`volume_kv_example.npz` — same sample used by figure_maker)
and writes 4 PNGs to /data/cameron/para/scratch_figures/:
  1. volume_blue.png      — one 3D voxel scatter per timestep, all blue, opacity = normalised logit
  2. volume_heatmap.png   — same scatter colored by normalised logit (inferno cmap) + opacity
  3. dino_pca.png         — RGB input next to the DINO feature map's per-pixel PCA-to-RGB
  4. volume_all.png       — combined figure: RGB + PCA on top, blue and heatmap strips below

Voxel intensities are the raw logits, percentile-clipped per timestep (50th percentile → 0,
99.9th percentile → 1), so the lower half of each volume is transparent and the cloud is readable.
"""
import os, json
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

# Input archive and output directory; the directory is created up front so the
# savefig calls further down have somewhere to write.
NPZ = "/data/cameron/para/paper/figs/data/volume_kv_example.npz"
OUT = "/data/cameron/para/scratch_figures"
os.makedirs(OUT, exist_ok=True)

# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can execute
# arbitrary code — only point NPZ at files you trust.
d = np.load(NPZ, allow_pickle=True)

# Expected shapes for this sample (not validated here):
#   soft     (T=8, Z=32, H=56, W=56)  softmax per timestep
#   gt_grid  (T, 2)                   in 56-space
#   gt_z     (T,)
#   argmax   (T, 3)                   (z, y, x)
soft, gt_grid, gt_z, argmax = (
    d[key] for key in ('volume_softmax', 'gt_pix_grid', 'gt_z_bin', 'argmax_voxel')
)
#   dino_pca (28, 28, 3) in [0, 1]
#   rgb      (3, 504, 504)
dino_pca, rgb = d['dino_pca_rgb'], d['rgb']

# Volume dimensions are reused as module-level globals by the plotting code below.
T, Z, H, W = soft.shape
print(f"Volume shape: {soft.shape}")
print(f"Per-timestep peak softmax probs: {soft.reshape(T, -1).max(axis=1).round(3)}")


def all_voxels(p_per_t):
    """List every voxel of each timestep together with a [0, 1] intensity.

    `p_per_t` is indexed as (timestep, z, y, x). For each timestep the values are
    rescaled independently: the 50th percentile maps to 0, the 99.9th percentile
    maps to 1, and everything outside that range is clipped.

    Returns a list with one `(z, y, x, v)` tuple per timestep. `z`, `y`, `x` are the
    flattened voxel coordinates (the same arrays are shared by every tuple) and `v`
    is the flattened, normalised intensity in matching order.
    """
    depth, height, width = p_per_t.shape[1:]
    # Row-major coordinates of every cell, flattened to match `volume.ravel()`.
    zz, yy, xx = np.indices((depth, height, width)).reshape(3, -1)

    def _normalise(volume):
        flat = volume.ravel()
        lo, hi = np.percentile(flat, [50, 99.9])
        # Guard the denominator so a constant volume maps to all zeros instead of NaN.
        span = max(hi - lo, 1e-8)
        return np.clip((flat - lo) / span, 0, 1)

    return [(zz, yy, xx, _normalise(step)) for step in p_per_t]

# Drive the renders from the raw logits rather than the softmax, so the point cloud
# keeps some texture away from the single peak.
logits = d['volume_logits']
per_t = all_voxels(p_per_t=logits)
# Sanity print: number of voxel coordinates carried per timestep, averaged over panels.
print(f"Avg voxels per panel: {np.mean([len(t[0]) for t in per_t]):.0f}")


def render_3d_grid(per_t, fig_title, out_path, use_cmap=False):
    """Save a grid of 3D voxel scatters, one panel per entry of `per_t`.

    Args:
        per_t: sequence of `(z, y, x, a)` tuples of equal-length 1-D arrays; `a` is the
            per-voxel intensity, expected in [0, 1] (as produced by `all_voxels`).
        fig_title: text used as the figure suptitle.
        out_path: path of the image file to write.
        use_cmap: if True, colour points with the inferno colormap applied to `a`;
            otherwise every point gets the same fixed blue.

    In both modes the point opacity is `a ** 0.6`. Axis limits are taken from the
    module-level `W`, `H`, `Z`. The figure is closed after saving.
    """
    # Size the grid from the data actually passed in instead of assuming 8 panels:
    # four columns, and never fewer than the two rows of the original 2x4 layout.
    n_cols = 4
    n_rows = max(2, -(-len(per_t) // n_cols))
    fig = plt.figure(figsize=(5 * n_cols, 5 * n_rows))
    cmap = cm.inferno if use_cmap else None
    base_rgba = np.array([[0.255, 0.412, 0.882, 1.0]])

    for t, (z, y, x, a) in enumerate(per_t):
        ax = fig.add_subplot(n_rows, n_cols, t + 1, projection='3d')
        # Zero-intensity voxels would be fully transparent anyway; dropping them
        # keeps the scatter small and fast.
        mask = a > 0
        z, y, x, a = z[mask], y[mask], x[mask], a[mask]
        if use_cmap:
            colors = cmap(a)
        else:
            colors = np.tile(base_rgba, (len(z), 1))
        colors[:, 3] = a ** 0.6   # gentle gamma to make mid-range visible
        ax.scatter(x, y, z, c=colors, s=6, edgecolors='none', depthshade=False)
        ax.set_xlim(0, W)
        ax.set_ylim(0, H)
        ax.set_zlim(0, Z)
        ax.invert_yaxis()
        ax.set_xlabel('u')
        ax.set_ylabel('v')
        ax.set_zlabel('z')
        ax.set_title(f"t={t}", fontsize=11)
        ax.view_init(elev=25, azim=-60)

    fig.suptitle(fig_title, fontsize=14, y=0.98)
    # Operate on `fig` explicitly rather than on pyplot's "current figure".
    fig.tight_layout()
    fig.savefig(out_path, dpi=140, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved: {out_path}")


# `per_t` holds percentile-normalised logits (50th percentile -> 0, 99.9th -> 1), and the
# renderer drops zero-intensity voxels, so the titles describe that rather than a
# softmax / top-1% selection.
# 1. Blue volume
render_3d_grid(per_t,
               "Volume points per timestep — voxels above median logit, opacity = normalised logit",
               os.path.join(OUT, "volume_blue.png"), use_cmap=False)
# 2. Heatmap-colored volume
render_3d_grid(per_t,
               "Volume points per timestep — colored by normalised logit (inferno)",
               os.path.join(OUT, "volume_heatmap.png"), use_cmap=True)

# 3. DINO PCA feature map, shown beside the RGB input
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# Left panel: move the channel axis last, which is the layout imshow expects.
axes[0].imshow(rgb.transpose(1, 2, 0))
axes[0].set_title("RGB input")
axes[0].axis('off')

# Right panel: the PCA image is drawn at its native grid size; nearest-neighbour
# interpolation keeps each patch a flat square instead of blurring between patches.
axes[1].imshow(dino_pca, interpolation='nearest')
axes[1].set_title("DINO patch features → 3-PCA → RGB (28×28)")
axes[1].axis('off')

fig.tight_layout()
out_pca = os.path.join(OUT, "dino_pca.png")
fig.savefig(out_pca, dpi=140, bbox_inches='tight')
plt.close(fig)
print(f"Saved: {out_pca}")

# Also one combined render with everything: RGB + both volume strips + PCA
fig = plt.figure(figsize=(24, 14))

# Row 1: RGB + PCA side by side (top two cells of a 3x2 grid)
ax_rgb = fig.add_subplot(3, 2, 1)
ax_rgb.imshow(np.transpose(rgb, (1, 2, 0)))
ax_rgb.axis('off')
ax_rgb.set_title("RGB input")
ax_pca = fig.add_subplot(3, 2, 2)
ax_pca.imshow(dino_pca, interpolation='nearest')
ax_pca.axis('off')
ax_pca.set_title("DINO patch features (28×28) → 3-PCA → RGB")

# Rows 2 and 3: one column per timestep; blue strip on top, inferno strip below.
for t in range(T):
    z, y, x, p = per_t[t]
    # Zero-intensity voxels end up with alpha 0, so skip them instead of
    # scattering invisible points.
    keep = p > 0
    z, y, x, p = z[keep], y[keep], x[keep], p[keep]
    # Rescale so the brightest voxel is exactly 1; an empty panel has nothing to scale.
    a = p / p.max() if p.size else p
    alphas = a ** 0.4

    blue = np.tile(np.array([[0.255, 0.412, 0.882, 1.0]]), (len(z), 1))
    blue[:, 3] = alphas
    heat = cm.inferno(a)
    heat[:, 3] = alphas

    for row, label, colors in ((1, "blue", blue), (2, "heatmap", heat)):
        ax = fig.add_subplot(3, T, row * T + t + 1, projection='3d')
        ax.scatter(x, y, z, c=colors, s=3, edgecolors='none', depthshade=False)
        ax.set_xlim(0, W)
        ax.set_ylim(0, H)
        ax.set_zlim(0, Z)
        ax.invert_yaxis()
        ax.set_title(f"{label} t={t}", fontsize=9)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_zticks([])

# The title describes what is actually drawn: no GT / argmax markers are plotted here.
fig.suptitle("Volume mechanism per timestep — voxels above median logit, "
             "opacity = normalised logit",
             fontsize=14)
fig.tight_layout()
out_all = os.path.join(OUT, "volume_all.png")
fig.savefig(out_all, dpi=130, bbox_inches='tight')
plt.close(fig)
print(f"Saved: {out_all}")
