"""Build volume_kv_method.svg — method diagram for the dino_kv volume architecture.

PARA's volume head factorizes:
    volume_logits[b, t, z, u, v] = F(u,v) · K(t,z)
where F ∈ R^48 per pixel (from DINO + projection) and K(t,z) ∈ R^48 (= t_emb[t] + h_emb[z]).
This figure shows one inference step at a chosen example pixel/voxel, walking the
reader through the factorization.

Inputs: /data/cameron/para/paper/figs/data/volume_kv_example.npz (extracted by
    libero/extract_volume_kv_figure_data.py — see backbones inbox 2026-05-19 for keys).

Output: /data/cameron/para/paper/figs/svg/volume_kv_method.svg

Strategy: matplotlib panels (3D voxel volume, PCA inset, feature heatmap, embedding bars)
are rendered, saved to PNG, and embedded as base64 into a hand-authored SVG canvas.
"""
import base64
import io
import json
import math
from pathlib import Path

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import numpy as np

# ─── 1. Load data ─────────────────────────────────────────────────────
# Archive produced by the extraction script named in the module docstring.
NPZ = "/data/cameron/para/paper/figs/data/volume_kv_example.npz"
# NOTE(review): allow_pickle=True lets numpy unpickle object arrays, which can
# execute arbitrary code — only point NPZ at files you produced yourself.
d = np.load(NPZ, allow_pickle=True)
# "meta" holds a JSON document wrapped in a 0-d array; .item() unwraps it.
meta = json.loads(d["meta"].item())
T_STAR = 4  # chosen timestep to illustrate
# Ground-truth location at T_STAR in feature-grid coordinates, rounded to the
# nearest integer cell so it can be used as an array index.
gt_u, gt_v = d["gt_pix_grid"][T_STAR]
GT_U_GRID = int(round(float(gt_u)))
GT_V_GRID = int(round(float(gt_v)))
GT_Z_BIN = int(d["gt_z_bin"][T_STAR])
# Same ground-truth location, in the coordinates used to draw on the RGB image.
gt_u504, gt_v504 = d["gt_pix_504"][T_STAR]
print(f"Chosen: t={T_STAR}, pixel grid=({GT_U_GRID},{GT_V_GRID}), z-bin={GT_Z_BIN}, "
      f"height={float(d['height_meters'][GT_Z_BIN]):.3f} m")
# Model/config dimensions recorded alongside the arrays.
N_HEIGHT_BINS = meta["n_height_bins"]
N_WINDOW = meta["n_window"]
PRED_SIZE = meta["pred_size"]
KEY_DIM = meta["key_dim"]
DINO_DIM = meta["dino_embed_dim"]

# Scratch directory for the intermediate PNG panels written by save_fig().
OUT_DIR = Path("/tmp/volume_kv_panels")
OUT_DIR.mkdir(parents=True, exist_ok=True)


def save_fig(fig, name, dpi=200, transparent=False):
    """Save *fig* as ``OUT_DIR/<name>.png``, close it, and return the path.

    Args:
        fig: Matplotlib figure to write.
        name: File stem (without extension) for the PNG.
        dpi: Output resolution passed to ``savefig``.
        transparent: When true the figure background is written as "none";
            otherwise the figure's own face colour is kept.

    Returns:
        Path of the PNG that was written.
    """
    if transparent:
        background = "none"
    else:
        background = fig.get_facecolor()

    target = OUT_DIR / f"{name}.png"
    fig.savefig(
        target,
        dpi=dpi,
        bbox_inches="tight",
        facecolor=background,
        edgecolor="none",
    )
    # Release the figure so repeated panel rendering does not accumulate memory.
    plt.close(fig)
    return target


def b64(p):
    """Return the bytes of the file at *p* as a base64 text string."""
    raw = Path(p).read_bytes()
    encoded = base64.b64encode(raw)
    return encoded.decode()


# ─── 2. Matplotlib sub-panels ─────────────────────────────────────────

# ---- Panel A: RGB with crosshair at GT pixel ------------------------
def panel_rgb_crosshair():
    """Render the RGB frame with a green crosshair on the GT pixel; return the PNG path.

    The crosshair and dot are placed at (gt_u504, gt_v504), i.e. in the
    coordinates of the displayed image. The PNG has a transparent background.
    """
    green = "#22c55e"
    hairline = dict(color=green, linewidth=1.5, alpha=0.8)

    fig, ax = plt.subplots(figsize=(4, 4), dpi=200)
    # d["rgb"] is stored channel-first; imshow wants channel-last.
    ax.imshow(np.transpose(d["rgb"], (1, 2, 0)))

    # Full-width/height hairlines, then a dot on top at their intersection.
    ax.axhline(gt_v504, **hairline)
    ax.axvline(gt_u504, **hairline)
    ax.plot(
        gt_u504, gt_v504, "o",
        color=green,
        markersize=12,
        markeredgecolor="white",
        markeredgewidth=1.5,
    )

    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)
    return save_fig(fig, "rgb_crosshair", transparent=True)


# ---- Panel B: DINO PCA inset ----------------------------------------
def panel_dino_pca():
    """Render the precomputed DINO PCA image as a small framed inset; return the PNG path."""
    frame_color = "#94a3b8"

    fig, ax = plt.subplots(figsize=(3, 3), dpi=200)
    # Nearest-neighbour keeps the patch grid crisp instead of blurring it.
    ax.imshow(d["dino_pca_rgb"], interpolation="nearest")

    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_color(frame_color)
        spine.set_linewidth(1.2)
    return save_fig(fig, "dino_pca")


# ---- Panel C: F map (per-pixel feature heatmap, 56×56) --------------
def panel_f_map():
    """Render the per-pixel feature map as a false-colour PCA image; return the PNG path.

    The feature vectors in d["pixel_feats_unit"] (indexed [channel, row, col])
    are centred and projected onto their top three principal components, which
    are min-max scaled independently and shown as R, G, B. The chosen pixel
    (GT_U_GRID, GT_V_GRID) is marked with a green dot.
    """
    fig, ax = plt.subplots(figsize=(4, 4), dpi=200)

    feats = d["pixel_feats_unit"]
    # Take the dimensions from the array itself rather than hard-coding them,
    # so the panel keeps working if the key dim or grid size changes.
    n_ch, n_rows, n_cols = feats.shape
    flat = feats.reshape(n_ch, -1).T  # one row per pixel, one column per channel
    centered = flat - flat.mean(axis=0, keepdims=True)

    # PCA via SVD: U * S gives the per-pixel scores on each principal axis.
    U, S, _ = np.linalg.svd(centered, full_matrices=False)
    pca3 = U[:, :3] * S[:3]
    lo = pca3.min(axis=0)
    hi = pca3.max(axis=0)
    pca3 = (pca3 - lo) / (hi - lo + 1e-8)  # epsilon guards a constant component

    img = pca3.reshape(n_rows, n_cols, 3)
    ax.imshow(img, interpolation="nearest")
    # imshow puts columns on x and rows on y, so (u, v) maps directly to (x, y).
    ax.plot(GT_U_GRID, GT_V_GRID, "o", color="#22c55e",
            markersize=16, markeredgecolor="white", markeredgewidth=2.5)
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)
    return save_fig(fig, "f_map", transparent=True)


# ---- Panel D: F(u,v) 48-d bar ---------------------------------------
def panel_f_vec():
    """Render the feature vector at the chosen pixel as a strip of coloured cells.

    One unit-width cell is drawn per channel of
    d["pixel_feats_unit"][:, GT_V_GRID, GT_U_GRID], coloured by the channel's
    min-max-normalised value on a purple → white → cyan ramp.

    Returns:
        Path of the PNG that was written.
    """
    feats = d["pixel_feats_unit"]
    vec = feats[:, GT_V_GRID, GT_U_GRID]  # NB: feats indexed [c, y, x]
    n_cells = len(vec)

    fig, ax = plt.subplots(figsize=(8, 1.2), dpi=200)
    cmap = LinearSegmentedColormap.from_list("f", ["#a78bfa", "#ffffff", "#22d3ee"])
    lo, hi = vec.min(), vec.max()
    norm_v = (vec - lo) / (hi - lo + 1e-9)  # epsilon guards a constant vector

    for i, level in enumerate(norm_v):
        # facecolor (not color): the `color` kwarg overrides edgecolor, which
        # would silently discard edgecolor="none" and outline every cell.
        ax.add_patch(plt.Rectangle((i, 0), 1, 1,
                                   facecolor=cmap(level), edgecolor="none"))

    ax.set_xlim(0, n_cells)
    ax.set_ylim(0, 1)
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_color("#94a3b8")
        spine.set_linewidth(1.2)
    return save_fig(fig, "f_vec")


# ---- Panel E: Key embedding bars (t_emb + h_emb → key) --------------
def panel_keys():
    """Render three stacked colour strips: t_emb, h_emb and the combined key.

    The first two strips are ILLUSTRATIVE sinusoidal encodings of T_STAR and
    GT_Z_BIN built locally; they are not read from the model. The third strip
    is the actual key stored in d["keys_unit"][T_STAR, GT_Z_BIN]. Each strip
    is min-max normalised on its own before colouring, so colours are not
    comparable across strips.

    Returns:
        Path of the PNG that was written.
    """
    # Reconstruct t_emb and h_emb from sinusoidal formula matching model_dino_volume_kv.py.
    # We don't have direct access to the model's exact sin/cos schedule, only the
    # COMBINED keys_unit, so a plausible t_emb / h_emb pair is synthesised here
    # purely to make the figure readable.
    def sin_emb(pos, dim, max_pos):
        """Interleaved sin/cos encoding of *pos* with wavelengths up to *max_pos*."""
        out = np.zeros(dim)
        for i in range(dim // 2):
            div = max_pos ** (2 * i / dim)
            out[2 * i] = math.sin(pos / div)
            out[2 * i + 1] = math.cos(pos / div)
        return out

    t_e = sin_emb(T_STAR, KEY_DIM, N_WINDOW)
    h_e = sin_emb(GT_Z_BIN, KEY_DIM, N_HEIGHT_BINS)
    key = d["keys_unit"][T_STAR, GT_Z_BIN]  # actual model key (after L2 norm)

    # Labels are derived from the chosen indices so they cannot drift from
    # what is actually plotted (GT_Z_BIN comes from the data file).
    rows = [
        (t_e, f"t_emb[t={T_STAR}]"),
        (h_e, f"h_emb[z={GT_Z_BIN}]"),
        (key, "key = t_emb + h_emb (L2 norm)"),
    ]

    fig, axes = plt.subplots(3, 1, figsize=(8, 2.4), dpi=200, sharex=True)
    cmap = LinearSegmentedColormap.from_list("k", ["#fb7185", "#ffffff", "#fbbf24"])

    for ax, (vec, label) in zip(axes, rows):
        lo, hi = vec.min(), vec.max()
        nv = (vec - lo) / (hi - lo + 1e-9)  # epsilon guards a constant vector
        for i, level in enumerate(nv):
            # facecolor (not color): the `color` kwarg overrides edgecolor,
            # which would silently discard edgecolor="none".
            ax.add_patch(plt.Rectangle((i, 0), 1, 1,
                                       facecolor=cmap(level), edgecolor="none"))
        ax.set_xlim(0, KEY_DIM)
        ax.set_ylim(0, 1)
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_color("#94a3b8")
            spine.set_linewidth(1.0)
        ax.set_ylabel(label, fontsize=9, rotation=0, ha="right", va="center",
                      labelpad=8)
    return save_fig(fig, "keys")


# ---- Panel F: Response volume (3D voxel sheets) ----------------------
def panel_response_volume():
    """Render a few height slices of the softmax volume as stacked 3D sheets.

    Uses the T_STAR timestep of d["volume_softmax"]. Each slice listed in
    Z_INDICES is average-pooled by a factor DS, drawn as a flat surface at its
    normalised physical height, and coloured / alpha-blended by probability
    relative to the global maximum of the whole volume at that timestep.
    The argmax voxel from d["argmax_voxel"] is overlaid as a green marker.

    Returns:
        Path of the PNG that was written.
    """
    softmax = d["volume_softmax"][T_STAR]  # indexed [z, row, col]
    height_m = d["height_meters"]  # read once and reuse below
    Z_INDICES = [4, 12, 20, 28]
    heights = height_m[Z_INDICES]
    # Downsample for visual clarity per spec ("don't render all 56×56")
    DS = 4
    H_OUT = PRED_SIZE // DS
    fig = plt.figure(figsize=(5.2, 4.8), dpi=200, facecolor="white")
    ax = fig.add_subplot(111, projection="3d")
    ax.set_box_aspect((1, 1, 1.1))
    cmap = plt.cm.inferno

    # Cell-corner grid on the unit square, shared by every slice.
    x = np.linspace(0, 1, H_OUT + 1)
    y = np.linspace(0, 1, H_OUT + 1)
    X, Y = np.meshgrid(x, y)

    # Loop invariants: height range for z normalisation, global peak for colour.
    z_min, z_max = float(height_m.min()), float(height_m.max())
    z_span = z_max - z_min + 1e-9
    peak = softmax.max() + 1e-9

    for zi, z in zip(Z_INDICES, heights):
        sl = softmax[zi]
        # Average-pool DS×DS blocks.
        ds_sl = sl.reshape(H_OUT, DS, H_OUT, DS).mean(axis=(1, 3))
        z_norm = (z - z_min) / z_span
        Z = np.full_like(X, z_norm)
        ds_norm = ds_sl / peak  # global max for color scale
        face = cmap(ds_norm)
        # Keep empty cells faintly visible; let hot cells become opaque.
        face[..., 3] = np.clip(0.18 + 0.85 * ds_norm, 0.18, 1.0)
        ax.plot_surface(X, Y, Z, facecolors=face, shade=False,
                        rstride=1, cstride=1, linewidth=0.0, antialiased=False)
        # Faint outline of the slice.
        ax.plot([0, 1, 1, 0, 0], [0, 0, 1, 1, 0], [z_norm] * 5,
                color="#64748b", linewidth=0.8, alpha=0.4)
        # Height label next to the slice.
        ax.text(1.03, 0.0, z_norm, f"z={zi}  ({z*100:.1f} cm)",
                fontsize=8, color="#475569", ha="left", va="center")

    # Highlight the argmax voxel. NOTE(review): unpacked as (z, row, col) —
    # confirm against the extraction script. Its z need not be one of the
    # slices drawn above.
    vox_z, vox_row, vox_col = d["argmax_voxel"][T_STAR]
    vox_z_norm = (float(height_m[int(vox_z)]) - z_min) / z_span
    # Cell centre in unit-square coordinates; PRED_SIZE keeps this consistent
    # with the sheet downsampling above instead of a hard-coded grid size.
    x_v = (vox_col + 0.5) / PRED_SIZE
    y_v = (vox_row + 0.5) / PRED_SIZE
    ax.scatter([x_v], [y_v], [vox_z_norm], s=180, c="#22c55e",
               edgecolors="white", linewidths=2.5, zorder=20, depthshade=False)

    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_zlim(-0.05, 1.05)
    ax.set_xlabel("u (col)", fontsize=8, labelpad=-6)
    ax.set_ylabel("v (row)", fontsize=8, labelpad=-6)
    ax.set_zlabel("height z", fontsize=8, labelpad=-6)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zticks([])
    ax.view_init(elev=20, azim=-60)
    ax.set_facecolor("white")
    return save_fig(fig, "response_volume")


# ─── 3. Generate panels ───────────────────────────────────────────────
print("rendering panels...")
# Render every panel in a fixed order; each call returns the path of its PNG.
p_rgb, p_pca, p_fmap, p_fvec, p_keys, p_vol = [
    render()
    for render in (
        panel_rgb_crosshair,
        panel_dino_pca,
        panel_f_map,
        panel_f_vec,
        panel_keys,
        panel_response_volume,
    )
]
print("panels saved to", OUT_DIR)

# Base64 payloads for the <image> data URIs in the SVG.
rgb_b64, pca_b64, fmap_b64, fvec_b64, keys_b64, vol_b64 = [
    b64(png) for png in (p_rgb, p_pca, p_fmap, p_fvec, p_keys, p_vol)
]


# ─── 4. SVG composition ───────────────────────────────────────────────
# Canvas 1400 × 760, read left→right:
#   col 1 (x=40..340):    RGB image, with the DINO PCA inset at its top-right corner
#   col 2 (x=420..720):   F map → F(u,v) bar
#   col 3 (x=760..1140):  keys (3 stacked bars) above the dot-product node
#   col 4 (x=1160..1380): response volume above the argmax → 3D point card
#
# Special glyphs in text nodes are written as NUMERIC character references
# (&#183; &#215; &#8594; ...). A standalone SVG is parsed as XML, which only
# predefines &amp; &lt; &gt; &quot; &apos; — HTML names such as &middot; or
# &rarr; are undefined there and abort parsing of the whole file.
#
# NOTE(review): panel (g) prints the GT pixel / z-bin and an unconditional
# "matches GT"; it is not checked against d["argmax_voxel"].

W, H = 1400, 760
svg = f"""<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"
     viewBox="0 0 {W} {H}" width="{W}" height="{H}"
     font-family="Inter, Arial, sans-serif">
  <defs>
    <marker id="arrow" viewBox="0 0 10 10" refX="0" refY="5"
            markerWidth="7" markerHeight="7" markerUnits="userSpaceOnUse" orient="auto">
      <path d="M0,0 L10,5 L0,10 Z" fill="#475569"/>
    </marker>
    <marker id="arrow-green" viewBox="0 0 10 10" refX="0" refY="5"
            markerWidth="6" markerHeight="6" markerUnits="userSpaceOnUse" orient="auto">
      <path d="M0,0 L10,5 L0,10 Z" fill="#16653a"/>
    </marker>
    <filter id="card-shadow" x="-10%" y="-10%" width="120%" height="130%">
      <feDropShadow dx="0" dy="1.5" stdDeviation="2.5" flood-color="#000" flood-opacity="0.10"/>
    </filter>
  </defs>

  <rect width="{W}" height="{H}" fill="#ffffff"/>

  <!-- Title -->
  <text x="20" y="32" font-size="16" font-weight="800" fill="#0f172a">
    Volume head = image features &#183; (timestep + height) keys
  </text>
  <text x="20" y="52" font-size="12" font-weight="500" fill="#64748b">
    volume_logits[t,z,u,v] = F(u,v) &#183; key(t,z),
    where F &#8712; &#8477;<tspan font-size="9" dy="-3">{KEY_DIM}&#215;{PRED_SIZE}&#215;{PRED_SIZE}</tspan><tspan dy="3"></tspan>
    is per-pixel from DINOv3, and key(t,z) = t_emb[t] + h_emb[z] &#8712; &#8477;<tspan font-size="9" dy="-3">{KEY_DIM}</tspan>
  </text>
  <line x1="20" y1="68" x2="{W-20}" y2="68" stroke="#e5e7eb" stroke-width="1"/>

  <!-- ─── (a) RGB ─────────────────────────────────────── -->
  <text x="40" y="100" font-size="11" font-weight="700" fill="#6b7280" letter-spacing="0.06em">(a) RGB INPUT</text>
  <image xlink:href="data:image/png;base64,{rgb_b64}"
         x="40" y="110" width="300" height="300"/>
  <rect x="40" y="110" width="300" height="300" rx="6"
        fill="none" stroke="#94a3b8" stroke-width="1.2"/>
  <text x="40" y="430" font-size="10" font-weight="500" fill="#64748b">
    GT pixel @ t={T_STAR}: ({gt_u504:.0f}, {gt_v504:.0f}) in 504&#215;504
  </text>

  <!-- DINO PCA inset (top-right of RGB column) -->
  <text x="245" y="100" font-size="9" font-weight="700" fill="#6b7280" letter-spacing="0.06em">DINO PCA</text>
  <image xlink:href="data:image/png;base64,{pca_b64}"
         x="245" y="110" width="95" height="95"
         style="filter: drop-shadow(0 1px 3px rgba(0,0,0,0.15));"/>

  <!-- DINO arrow → F map -->
  <line x1="345" y1="260" x2="410" y2="260" stroke="#475569" stroke-width="2.5" marker-end="url(#arrow)"/>
  <text x="378" y="252" text-anchor="middle" font-size="11" font-weight="700" fill="#475569">DINOv3</text>
  <text x="378" y="274" text-anchor="middle" font-size="9" font-style="italic" fill="#94a3b8">
    ViT-S/16+ &#8594; 1&#215;1 proj
  </text>
  <text x="378" y="286" text-anchor="middle" font-size="9" font-style="italic" fill="#94a3b8">
    {DINO_DIM}d &#8594; {KEY_DIM}d
  </text>

  <!-- ─── (b) F map: per-pixel features ───────────────── -->
  <text x="420" y="100" font-size="11" font-weight="700" fill="#6b7280" letter-spacing="0.06em">(b) PER-PIXEL FEATURES F</text>
  <image xlink:href="data:image/png;base64,{fmap_b64}"
         x="420" y="110" width="300" height="300"/>
  <rect x="420" y="110" width="300" height="300" rx="6"
        fill="none" stroke="#94a3b8" stroke-width="1.2"/>
  <text x="420" y="430" font-size="10" font-weight="500" fill="#64748b">
    F &#8712; &#8477;<tspan font-size="8" dy="-3">{KEY_DIM}&#215;{PRED_SIZE}&#215;{PRED_SIZE}</tspan>
    <tspan dy="3"></tspan> &#160;&#183;&#160; PCA of {KEY_DIM}-d feature vectors
  </text>

  <!-- pick pixel arrow: F map → F vector below -->
  <line x1="570" y1="412" x2="570" y2="450" stroke="#22c55e" stroke-width="2.5" marker-end="url(#arrow-green)"/>
  <text x="580" y="436" font-size="10" font-style="italic" fill="#16653a">
    pick pixel (u,v)
  </text>

  <!-- ─── (c) F(u,v) vector bar ───────────────────────── -->
  <text x="420" y="478" font-size="11" font-weight="700" fill="#6b7280" letter-spacing="0.06em">(c) F(u,v) &#8712; &#8477;<tspan font-size="9" dy="-3">{KEY_DIM}</tspan></text>
  <image xlink:href="data:image/png;base64,{fvec_b64}"
         x="420" y="488" width="300" height="60"/>
  <rect x="420" y="488" width="300" height="60" rx="4"
        fill="none" stroke="#94a3b8" stroke-width="1"/>

  <!-- ─── (d) Keys ───────────────────────────────────── -->
  <text x="760" y="100" font-size="11" font-weight="700" fill="#6b7280" letter-spacing="0.06em">(d) KEY = t_emb + h_emb</text>
  <image xlink:href="data:image/png;base64,{keys_b64}"
         x="760" y="110" width="380" height="200"/>
  <text x="760" y="328" font-size="10" font-weight="500" fill="#64748b">
    8 timesteps &#215; 32 height bins &#160;&#8594;&#160; bank of 256 key vectors,
    sinusoidal &#8869; chosen (t={T_STAR}, z={GT_Z_BIN})
  </text>

  <!-- ─── (e) Dot product node ───────────────────────── -->
  <g id="dot-product">
    <rect x="760" y="478" width="380" height="80" rx="10"
          fill="#fef3c7" stroke="#b45309" stroke-width="1.6" filter="url(#card-shadow)"/>
    <text x="950" y="510" text-anchor="middle" font-size="14" font-weight="800" fill="#b45309">
      F(u,v) &#183; key(t,z) = volume_logits[t, z, u, v]
    </text>
    <text x="950" y="530" text-anchor="middle" font-size="10" fill="#92400e">
      einsum &quot;chw, tzc &#8594; tzhw&quot; , scaled by exp(logit_scale &#8776; {meta['logit_scale_exp']:.1f})
    </text>
    <text x="950" y="548" text-anchor="middle" font-size="9.5" fill="#92400e" font-style="italic">
      = cosine similarity between every pixel feature and every (t,z) key
    </text>
  </g>

  <!-- arrows into dot-product -->
  <!-- from F(u,v) bar (centred at x=570, y=518) -->
  <path d="M 720 518 Q 745 518, 758 502" stroke="#22c55e" stroke-width="2.5"
        fill="none" marker-end="url(#arrow-green)"/>
  <!-- from key bank (centred at x=950, y=210) -->
  <path d="M 950 312 L 950 472" stroke="#fb7185" stroke-width="2.5"
        fill="none" marker-end="url(#arrow)"/>

  <!-- ─── (f) Response volume ────────────────────────── -->
  <text x="1160" y="100" font-size="11" font-weight="700" fill="#6b7280" letter-spacing="0.06em">(f) RESPONSE VOLUME</text>
  <image xlink:href="data:image/png;base64,{vol_b64}"
         x="1160" y="110" width="220" height="220"/>
  <rect x="1160" y="110" width="220" height="220" rx="6"
        fill="none" stroke="#94a3b8" stroke-width="1"/>
  <text x="1160" y="346" font-size="9.5" font-weight="500" fill="#64748b">
    softmax(volume_logits[t={T_STAR}]) over 4 z-slices
  </text>
  <text x="1160" y="360" font-size="9.5" font-weight="500" fill="#64748b">
    bright voxel = argmax, top-1 prob = {float(d['volume_softmax'][T_STAR].max())*100:.1f}%
  </text>

  <!-- arrow: dot-product → response volume -->
  <path d="M 1140 518 Q 1230 518, 1260 340" stroke="#475569" stroke-width="2.5"
        fill="none" marker-end="url(#arrow)"/>
  <text x="1185" y="478" font-size="10" font-style="italic" fill="#475569">
    softmax
  </text>

  <!-- ─── (g) Output: 3D point ───────────────────────── -->
  <g id="output">
    <rect x="1160" y="478" width="220" height="80" rx="10"
          fill="#dcfce7" stroke="#16653a" stroke-width="1.6" filter="url(#card-shadow)"/>
    <text x="1270" y="500" text-anchor="middle" font-size="11" font-weight="800" fill="#16653a">
      ARGMAX &#8594; 3D POINT
    </text>
    <text x="1270" y="520" text-anchor="middle" font-size="10" fill="#0f3a1f">
      pixel (u*, v*) = ({GT_U_GRID}, {GT_V_GRID})
    </text>
    <text x="1270" y="534" text-anchor="middle" font-size="10" fill="#0f3a1f">
      height z* = {GT_Z_BIN} &#160;({float(d['height_meters'][GT_Z_BIN])*100:.1f} cm)
    </text>
    <text x="1270" y="550" text-anchor="middle" font-size="9.5" font-style="italic" fill="#0f3a1f">
      &#10003; matches GT
    </text>
  </g>

  <!-- argmax → crosshair feedback note (response → RGB) — short callout, not a long arc -->
  <text x="1270" y="582" text-anchor="middle" font-size="9.5" font-style="italic" fill="#16653a">
    &#8593; matches (a) crosshair on RGB
  </text>

  <!-- Footnote: dataset stats anchoring height bins -->
  <text x="20" y="{H-18}" font-size="9" fill="#94a3b8" font-style="italic">
    Numbers from smith300 sample idx {meta['sample_idx']}, ckpt epoch {meta['ckpt_epoch']}, sin/sin embeddings.
    Height bin spacing &#8776; {(float(d['height_meters'][-1])-float(d['height_meters'][0]))/(N_HEIGHT_BINS-1)*1000:.1f} mm
    over [{float(d['height_meters'][0])*100:.1f}, {float(d['height_meters'][-1])*100:.1f}] cm.
  </text>
</svg>
"""

OUT_SVG = Path("/data/cameron/para/paper/figs/svg/volume_kv_method.svg")
OUT_SVG.parent.mkdir(parents=True, exist_ok=True)
# The document declares encoding="UTF-8" in its XML prolog and contains
# non-ASCII characters, so pin the encoding instead of relying on the locale
# default (which may be a legacy code page and would corrupt or reject them).
OUT_SVG.write_text(svg, encoding="utf-8")
print(f"wrote {OUT_SVG}  ({len(svg)//1024} KB)")
