"""Render three matched sparse 14×14×14 volumes (same camera, same point geometry):
    A. blue only — empty/structural view
    B. F-PCA colored — feature volume V
    C. softmax heatmap colored — probability volume at T*

The three PNGs share the same camera, axis styling, and point layout, so they
can be swapped for one another in any figure that steps through
"structure → features → response" along the same viewing axis.

Output: feature_volume_blue.png, feature_volume_feat.png, feature_volume_heat.png
"""
from pathlib import Path

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

DATA_DIR = Path("/data/cameron/para/paper/figs/data/query_arch")
NPZ = DATA_DIR / "example.npz"

# Load everything we need inside a context manager so the NpzFile's underlying
# file handle is closed afterwards. Indexing an NpzFile reads the member into a
# regular in-memory array, so the arrays stay valid after the block exits.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code. That is acceptable only because NPZ is a locally
# produced file; drop the flag if every entry is a plain numeric array.
with np.load(NPZ, allow_pickle=True) as d:
    T_STAR = int(d["T_star"])

    # Source tensors (full resolution)
    F_PCA_RGB = d["F_pca_rgb"]              # (56, 56, 3) F-PCA color per pixel
    SOFTMAX = d["volume_softmax"][T_STAR]   # (32, 56, 56)
    ARGZ, ARGY, ARGX = (
        int(d["argmax_z"][T_STAR]),
        int(d["argmax_y"][T_STAR]),
        int(d["argmax_x"][T_STAR]),
    )

# The sampling in this script hard-codes a 32-bin height axis and a 56×56 image
# plane: z range 0..31, stride-4 xy sampling, and 4×4 xy max-pooling of the
# softmax. Fail loudly on a differently shaped input rather than silently
# sampling the wrong region.
if SOFTMAX.shape != (32, 56, 56) or F_PCA_RGB.shape[:2] != (56, 56):
    raise ValueError(
        f"expected softmax volume (32, 56, 56) and F-PCA image (56, 56, ...); "
        f"got {SOFTMAX.shape} and {F_PCA_RGB.shape}"
    )

# Sparse 14×14×14 grid — same density as the existing feature_volume.png, so
# the blue/feat/heat versions read as the same volume. The xy axes use a
# uniform stride of 4 pixels; z uses 14 evenly spaced bins out of 0..31.
S_XY = 4   # 56/14
# Not used for sampling: a uniform z stride of 32 // 14 == 2 would give 16
# heights, so Z_INDICES below picks 14 evenly spaced bins instead.
S_Z = 32 // 14
N_XY = 56 // S_XY                          # 14
Z_INDICES = np.linspace(0, 31, 14).round().astype(int)
N_Z = len(Z_INDICES)

xs = np.arange(0, 56, S_XY)                # length 14
ys = np.arange(0, 56, S_XY)                # length 14

# Enumerate every (z_idx, y_idx, x_idx) cell of the (N_Z, N_XY, N_XY) grid in
# C order and map each index to its full-resolution pixel / z-bin coordinate.
zz, yy, xx = np.meshgrid(np.arange(N_Z), np.arange(N_XY), np.arange(N_XY), indexing="ij")
pix_y = ys[yy.reshape(-1)]
pix_x = xs[xx.reshape(-1)]
pix_z_idx = Z_INDICES[zz.reshape(-1)]
N = pix_y.size

print(f"sparse grid: {N_Z} × {N_XY} × {N_XY} = {N} voxels")
print(f"argmax (z, y, x) at full res: ({ARGZ}, {ARGY}, {ARGX})")
print(f"argmax mapped to sparse:  "
      f"z_bin {int(np.abs(Z_INDICES - ARGZ).argmin())}, "
      f"y_idx {ARGY // S_XY}, x_idx {ARGX // S_XY}")


def _style_axes(ax):
    """Apply the shared camera and axis styling used by all three renders.

    Labels the axes "image x" / "image y" / "height z" and removes all tick
    marks. Fixes the viewpoint at elev=18, azim=-58 with a cubic box aspect.
    Makes the three background panes fully transparent (no edge, zero-alpha
    fill) and turns the grid off. Routing every figure through this one
    function is what keeps the renders on an identical camera.
    """
    for set_label, text in (
        (ax.set_xlabel, "image x"),
        (ax.set_ylabel, "image y"),
        (ax.set_zlabel, "height z"),
    ):
        set_label(text, fontsize=11, labelpad=8)
    for set_ticks in (ax.set_xticks, ax.set_yticks, ax.set_zticks):
        set_ticks([])
    ax.view_init(elev=18, azim=-58)
    ax.set_box_aspect((1.0, 1.0, 1.0))
    # Panes are made invisible: a fully transparent edge and a zero-alpha
    # white fill, so nothing is painted behind the points.
    invisible_edge = (0, 0, 0, 0.0)
    invisible_fill = (1, 1, 1, 0.0)
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.pane.set_edgecolor(invisible_edge)
        axis.pane.set_facecolor(invisible_fill)
    ax.grid(False)


def _save(fig, name):
    """Write *fig* to ``DATA_DIR / f"{name}.png"``, close it, and return the path.

    The PNG is saved at 200 dpi with a transparent background and a tight
    bounding box. The figure is closed in a ``finally`` clause, so a failed
    save does not leave an open pyplot figure behind. After a successful
    save, prints the written path and its size in KB.
    """
    out = DATA_DIR / f"{name}.png"
    try:
        fig.savefig(out, dpi=200, transparent=True, bbox_inches="tight")
    finally:
        plt.close(fig)
    print(f"  wrote {out}  ({out.stat().st_size/1024:.0f} KB)")
    return out


# Marker styling shared by every per-voxel scatter: small dots, no outline.
SCATTER_KW = {"s": 14, "edgecolor": "none"}

# Scatter positions are the sparse-grid coordinates, converted to float. The
# 3D axes take the three arrays in (x, y, z) order; _style_axes labels them
# image x / image y / height z.
SX, SY, SZ = (coord.astype(float) for coord in (pix_x, pix_y, pix_z_idx))

# ─── A. Blue-only structural view ──────────────────────────────────
# A single flat color for every sample, so only the sampling geometry shows.
print("rendering blue structural volume...")
fig = plt.figure(figsize=(5.5, 5.0), dpi=200)
ax = fig.add_subplot(111, projection="3d")
blue_kw = {**SCATTER_KW, "c": "#3b82f6", "alpha": 0.78}
ax.scatter(SX, SY, SZ, **blue_kw)
_style_axes(ax)
_save(fig, "feature_volume_blue")

# ─── B. F-PCA-colored feature volume ──────────────────────────────
print("rendering F-PCA-colored feature volume...")
# Each sample takes the F-PCA RGB of its (y, x) pixel. The lookup ignores z,
# so every height slice repeats the same color pattern.
colors_yx = F_PCA_RGB[pix_y, pix_x]  # (N, 3); assumed to lie in [0, 1]
fig = plt.figure(figsize=(5.5, 5.0), dpi=200)
ax = fig.add_subplot(111, projection="3d")
feat_kw = {**SCATTER_KW, "c": colors_yx, "alpha": 0.78}
ax.scatter(SX, SY, SZ, **feat_kw)
_style_axes(ax)
_save(fig, "feature_volume_feat")

# ─── C. Softmax heatmap response ──────────────────────────────────
# Max-pool the softmax over each sparse cell's neighborhood so the peak is
# never missed. Each sparse cell covers an S_XY×S_XY xy region. Along z it
# takes the max over every full-resolution z-bin whose nearest Z_INDICES
# entry is this sparse slice.
print("rendering softmax heatmap response (max-pooled per cell)...")

# z_assign[b] = index into Z_INDICES of the slot nearest to full-res z-bin b.
z_assign = np.argmin(
    np.abs(np.arange(SOFTMAX.shape[0])[:, None] - Z_INDICES[None, :]),
    axis=1,
)

# For each (z_idx, y_idx, x_idx) sparse cell, compute the MAX softmax over its
# (xy S_XY×S_XY) × (assigned z-bins) neighborhood.
sparse_probs = np.zeros((N_Z, N_XY, N_XY), dtype=np.float32)
for zi in range(N_Z):
    zbins = np.flatnonzero(z_assign == zi)
    if zbins.size == 0:
        # No full-resolution bin maps to this slot; leave it at 0.
        continue
    block = SOFTMAX[zbins]  # (n_zbins, 56, 56)
    # Split each image axis into (N_XY cells × S_XY pixels) and max over the
    # within-cell pixel axes -> (n_zbins, N_XY, N_XY).
    block = block.reshape(len(zbins), N_XY, S_XY, N_XY, S_XY).max(axis=(2, 4))
    sparse_probs[zi] = block.max(axis=0)

# C-order flatten matches the (z, y, x) meshgrid flattening behind SX/SY/SZ.
probs = sparse_probs.reshape(-1)  # (N,)
p_max = float(probs.max())
print(f"  sparse-cell max softmax (with max-pool): {p_max:.4f}  "
      f"(full softmax peak: {float(SOFTMAX.max()):.4f})")

cmap = plt.cm.plasma
norm = probs / (p_max + 1e-9)
face = cmap(norm)
# Alpha floor 0.18 keeps the structure visible. Alpha rises with normalized
# probability and saturates at 1.0 once norm >= ~0.91, because 0.18 + 0.90 * norm
# exceeds 1 there and is clipped.
face[..., 3] = np.clip(0.18 + 0.90 * norm, 0.18, 1.0)

# Every full-resolution voxel belongs to exactly one sparse cell, so the
# brightest cell is the one containing the global softmax maximum at T_STAR.
# It is drawn as a separate highlighted marker.
peak_flat = int(probs.argmax())
arg_mask = np.zeros(N, dtype=bool)
arg_mask[peak_flat] = True

fig = plt.figure(figsize=(5.5, 5.0), dpi=200)
ax = fig.add_subplot(111, projection="3d")
# By default (computed_zorder=True) mplot3d ignores artist zorder and orders
# collections by projected depth. That usually paints the one-point argmax
# collection *under* the main volume. Disable it so zorder=10 below really
# draws the marker on top; each collection still depth-projects its own points.
ax.computed_zorder = False
ax.scatter(SX[~arg_mask], SY[~arg_mask], SZ[~arg_mask],
           c=face[~arg_mask], **SCATTER_KW)
ax.scatter(SX[arg_mask], SY[arg_mask], SZ[arg_mask],
           s=160, c="#22c55e", edgecolors="black", linewidths=1.6, zorder=10)
_style_axes(ax)
_save(fig, "feature_volume_heat")

print("\ndone — three matched-camera sparse volume renders:")
for n in ("feature_volume_blue", "feature_volume_feat", "feature_volume_heat"):
    # Derive the path from DATA_DIR (where _save writes) rather than repeating
    # the literal directory, so this summary cannot drift from the real files.
    print(f"  {DATA_DIR / f'{n}.png'}")
