"""Build query_arch_method.svg — diagram for the query-MLP volume head.

Cameron's pitch text (from the figure_maker inbox 2026-05-20):
  RGB → DINO PCA → sparse 3D feature volume (uniform low-stride downsample)
  → point-sampled feature 'lifted under image' between F and the volume
  → positional encoding of height and time concat with feature
  → separately EEF feature produces spatial query via MLP
  → dot product with volume = heatmap
  → probability volume with heatmap colouring
  → argmax for 3D target location.

Same pattern as `build_volume_kv_diagram.py`: pre-rendered matplotlib PNGs
(staged at `/data/cameron/para/paper/figs/data/query_arch/`) get embedded
as base64 into a hand-authored SVG.

Source spec: `/data/cameron/para/paper/figs/data/query_arch/SPEC.md`
"""
import base64
import json
from pathlib import Path

import numpy as np

DATA_DIR = Path("/data/cameron/para/paper/figs/data/query_arch")
NPZ = DATA_DIR / "example.npz"
# np.load on an .npz returns a lazily-reading NpzFile that holds the archive
# open; materialise every array inside a ``with`` so the file handle is closed
# right away. Downstream code only subscripts ``d``, so a plain dict suffices.
# allow_pickle=True permits object arrays to load; only use it on locally
# produced files. NOTE(review): presumably needed for "meta" (decoded below via
# .item() + json.loads) — confirm its stored dtype.
with np.load(NPZ, allow_pickle=True) as _npz:
    d = {key: _npz[key] for key in _npz.files}
meta = json.loads(d["meta"].item())
T_STAR = int(d["T_star"])
SAMPLE_IDX = int(d["sample_idx"])
N_VALID = int(meta["n_valid"])
SPAN_PX = float(meta["span_px"])
Z, H, W = int(meta["Z"]), int(meta["H"]), int(meta["W"])
T = int(meta["T"])
PEAK = float(meta["peak_frac_at_T_star"])
gt_pix = d["gt_pix_504"][T_STAR]
start_pix = d["start_pix_504"]

print(f"sample={SAMPLE_IDX}  T*={T_STAR}  peak={PEAK:.3f}  start_pix={start_pix}  gt_pix={gt_pix}")


def b64(p: Path) -> str:
    """Read the file at *p* and return its bytes as a base64 ASCII string.

    *p* may be a ``Path`` or anything ``Path()`` accepts (e.g. a ``str``).
    The result is suitable for a ``data:image/png;base64,...`` URI.
    """
    raw_bytes = Path(p).read_bytes()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode()


# Pre-rendered panels live in DATA_DIR (arch_overview.png there is dev-only and
# is not loaded). Each PNG is read once, in the order listed, and base64-encoded
# so it can be inlined into the SVG as a data: URI.
# NOTE(review): pvol_b64 (prob_volume.png) is not referenced by the SVG
# template in this file — only pvol_ax_b64 is; confirm whether it is needed.
(
    rgb_b64,
    f_pca_b64,
    f_pca_eef_b64,
    fvol_b64,
    pvol_b64,
    pvol_ax_b64,
) = (
    b64(DATA_DIR / png_name)
    for png_name in (
        "rgb.png",
        "f_pca.png",
        "f_pca_eef.png",
        "feature_volume.png",
        "prob_volume.png",
        "prob_volume_argmax.png",
    )
)

# ─── SVG composition ──────────────────────────────────────────────────
#
# Canvas: 1500 × 820 (W_CANVAS × H_CANVAS below), SVG user units = px.
#
# Header (y=20..90): title, two subtitle lines, divider rule at y=90.
#
# Top side-branch zone (y≈110..280):
#   f_pca_eef (130×130)  →  Res-MLP box  →  q box  →  ↓ into dot-product
#   conditioning "sin(t) AdaLN-Zero" callout above MLP (baseline y=118)
#
# Main horizontal flow (y≈305..665, panel labels + panels + captions):
#   RGB → F-PCA → Feature volume → · ← q → Probability-argmax volume → 3D point
#
# Below the main row: score-formula callout (y=700..748), then the footer.
#
W_CANVAS, H_CANVAS = 1500, 820

# Main-row panels: (x, y, width, height) of each image / box.
RGB_X, RGB_Y, RGB_W, RGB_H               = 20,   380, 170, 170
FPCA_X, FPCA_Y, FPCA_W, FPCA_H           = 220,  380, 170, 170
FVOL_X, FVOL_Y, FVOL_W, FVOL_H           = 410,  325, 320, 320
PVOL_X, PVOL_Y, PVOL_W, PVOL_H           = 870,  325, 320, 320
OUT_X,  OUT_Y,  OUT_W,  OUT_H            = 1240, 385, 240, 200

# Side branch (sits above the main row, which starts at FVOL_Y / PVOL_Y = 325)
EEF_X, EEF_Y, EEF_W, EEF_H               = 410,  150, 130, 130
MLP_X, MLP_Y, MLP_W, MLP_H               = 580,  170, 220, 100
QBX_X, QBX_Y, QBX_W, QBX_H               = 830,  190, 100, 70
ADALN_X, ADALN_Y                          = 690,  118  # label above MLP

# Dot-product (q·V) node centre, placed in the gap between the feature volume
# (ends at x=730) and the probability volume (starts at x=870). DOT_CY equals
# the vertical centre of both volumes (325 + 320/2 = 485).
DOT_CX, DOT_CY                            = 790,  485

# Arrow colours (matching the volume_kv_method.svg convention)
ARROW_GRAY = "#475569"    # main data-flow arrows
ARROW_GREEN = "#16653a"   # argmax / output arrow
ARROW_BLUE = "#1d4ed8"    # query/side-branch arrow
ARROW_ROSE = "#9d174d"    # AdaLN cond
EEF_RED = "#dc2626"       # outline of the EEF + CLS thumbnail


# The whole figure is a single f-string template: {…} fields pull in panel
# geometry, colours, the base64 PNG payloads and scalar values from the NPZ.
# The q·V node is a circle of radius 26 centred at (DOT_CX, DOT_CY); arrows
# touching it are offset by that radius. Arrow markers use refX="0", so each
# arrowhead extends ~7 px past the end of its line/path.
svg = f"""<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"
     viewBox="0 0 {W_CANVAS} {H_CANVAS}" width="{W_CANVAS}" height="{H_CANVAS}"
     font-family="Inter, Arial, sans-serif">
  <defs>
    <marker id="arrow-gray" viewBox="0 0 10 10" refX="0" refY="5"
            markerWidth="7" markerHeight="7" markerUnits="userSpaceOnUse" orient="auto">
      <path d="M0,0 L10,5 L0,10 Z" fill="{ARROW_GRAY}"/>
    </marker>
    <marker id="arrow-green" viewBox="0 0 10 10" refX="0" refY="5"
            markerWidth="7" markerHeight="7" markerUnits="userSpaceOnUse" orient="auto">
      <path d="M0,0 L10,5 L0,10 Z" fill="{ARROW_GREEN}"/>
    </marker>
    <marker id="arrow-blue" viewBox="0 0 10 10" refX="0" refY="5"
            markerWidth="7" markerHeight="7" markerUnits="userSpaceOnUse" orient="auto">
      <path d="M0,0 L10,5 L0,10 Z" fill="{ARROW_BLUE}"/>
    </marker>
    <marker id="arrow-rose" viewBox="0 0 10 10" refX="0" refY="5"
            markerWidth="6" markerHeight="6" markerUnits="userSpaceOnUse" orient="auto">
      <path d="M0,0 L10,5 L0,10 Z" fill="{ARROW_ROSE}"/>
    </marker>
    <filter id="card-shadow" x="-10%" y="-10%" width="120%" height="130%">
      <feDropShadow dx="0" dy="1.5" stdDeviation="2.5" flood-color="#000" flood-opacity="0.10"/>
    </filter>
  </defs>

  <rect width="{W_CANVAS}" height="{H_CANVAS}" fill="#ffffff"/>

  <!-- ─── Title ────────────────────────────────────────────────── -->
  <text x="20" y="34" font-size="17" font-weight="800" fill="#0f172a">
    Query-MLP volume head: q &middot; V &rarr; 3D target
  </text>
  <text x="20" y="56" font-size="11.5" font-weight="500" fill="#64748b">
    EEF + CLS features pass through a 5-layer Res-MLP (AdaLN-Zero on sin(t)) to produce a per-timestep query
    q &isin; &#8477;<tspan font-size="9" dy="-3">64</tspan><tspan dy="3"></tspan>.
  </text>
  <text x="20" y="74" font-size="11.5" font-weight="500" fill="#64748b">
    q is dot-producted against a sparse feature volume V built by lifting per-pixel features F + sin(z) + sin(t) into 3D.
    Argmax over softmax(q&middot;V) gives (z*, y*, x*).
  </text>
  <line x1="20" y1="90" x2="{W_CANVAS-20}" y2="90" stroke="#e5e7eb" stroke-width="1"/>

  <!-- ─── (a) RGB ──────────────────────────────────────────────── -->
  <text x="{RGB_X}" y="{RGB_Y - 10}" font-size="11" font-weight="700" fill="#6b7280" letter-spacing="0.06em">
    (a) RGB INPUT
  </text>
  <image xlink:href="data:image/png;base64,{rgb_b64}"
         x="{RGB_X}" y="{RGB_Y}" width="{RGB_W}" height="{RGB_H}"
         preserveAspectRatio="xMidYMid meet"/>
  <rect x="{RGB_X}" y="{RGB_Y}" width="{RGB_W}" height="{RGB_H}" rx="6"
        fill="none" stroke="#94a3b8" stroke-width="1.2"/>
  <text x="{RGB_X}" y="{RGB_Y + RGB_H + 18}" font-size="10" font-weight="500" fill="#64748b">
    izzy3 sample {SAMPLE_IDX} &middot; {N_VALID} valid timesteps &middot; trajectory span {SPAN_PX:.0f}&thinsp;px
  </text>

  <!-- DINO encoder arrow: RGB → F-PCA -->
  <line x1="{RGB_X + RGB_W + 5}" y1="{RGB_Y + RGB_H/2}" x2="{FPCA_X - 5}" y2="{FPCA_Y + FPCA_H/2}"
        stroke="{ARROW_GRAY}" stroke-width="2.5" marker-end="url(#arrow-gray)"/>
  <text x="{(RGB_X + RGB_W + FPCA_X)/2}" y="{RGB_Y + RGB_H/2 - 18}" text-anchor="middle"
        font-size="11" font-weight="700" fill="{ARROW_GRAY}">DINOv3</text>
  <text x="{(RGB_X + RGB_W + FPCA_X)/2}" y="{RGB_Y + RGB_H/2 - 4}" text-anchor="middle"
        font-size="9" font-style="italic" fill="#94a3b8">ViT-S/16+</text>
  <text x="{(RGB_X + RGB_W + FPCA_X)/2}" y="{RGB_Y + RGB_H/2 + 12}" text-anchor="middle"
        font-size="9" font-style="italic" fill="#94a3b8">+ 1&times;1 conv</text>

  <!-- ─── (b) F-PCA ────────────────────────────────────────────── -->
  <text x="{FPCA_X}" y="{FPCA_Y - 10}" font-size="11" font-weight="700" fill="#6b7280" letter-spacing="0.06em">
    (b) F (DINO PCA)
  </text>
  <image xlink:href="data:image/png;base64,{f_pca_b64}"
         x="{FPCA_X}" y="{FPCA_Y}" width="{FPCA_W}" height="{FPCA_H}"
         preserveAspectRatio="xMidYMid meet" image-rendering="pixelated"/>
  <rect x="{FPCA_X}" y="{FPCA_Y}" width="{FPCA_W}" height="{FPCA_H}" rx="6"
        fill="none" stroke="#94a3b8" stroke-width="1.2"/>
  <text x="{FPCA_X}" y="{FPCA_Y + FPCA_H + 18}" font-size="10" font-weight="500" fill="#64748b">
    F &isin; &#8477;<tspan font-size="8" dy="-3">48&times;56&times;56</tspan><tspan dy="3"></tspan>
    &thinsp;&middot;&thinsp; 3-PCA visualization
  </text>

  <!-- F-PCA → Feature volume: indicate "each voxel samples F at (y,x)" -->
  <path d="M {FPCA_X + FPCA_W} {FPCA_Y + 20} Q {FVOL_X - 20} {FVOL_Y + 100}, {FVOL_X + 30} {FVOL_Y + 240}"
        stroke="{ARROW_GRAY}" stroke-width="2" stroke-dasharray="6,4" fill="none"
        marker-end="url(#arrow-gray)" opacity="0.7"/>
  <text x="{FPCA_X + FPCA_W + 15}" y="{FPCA_Y + 12}" font-size="9" font-style="italic" fill="#64748b">
    sample F[y,x] per voxel
  </text>

  <!-- ─── (c) Feature volume ───────────────────────────────────── -->
  <text x="{FVOL_X}" y="{FVOL_Y - 12}" font-size="11" font-weight="700" fill="#6b7280" letter-spacing="0.06em">
    (c) FEATURE VOLUME V
  </text>
  <image xlink:href="data:image/png;base64,{fvol_b64}"
         x="{FVOL_X}" y="{FVOL_Y}" width="{FVOL_W}" height="{FVOL_H}"
         preserveAspectRatio="xMidYMid meet"/>
  <text x="{FVOL_X}" y="{FVOL_Y + FVOL_H + 4}" font-size="10" font-weight="500" fill="#64748b">
    {Z}&thinsp;heights &times; {H}&thinsp;&times;&thinsp;{W} downsampled to a sparse 14&times;14 grid for display
  </text>
  <text x="{FVOL_X}" y="{FVOL_Y + FVOL_H + 18}" font-size="10" font-style="italic" fill="#64748b">
    cell value = F[y, x] &oplus; sin(z) &oplus; sin(t)
  </text>

  <!-- ─── Side branch: EEF → MLP → q ───────────────────────────── -->
  <!-- conditioning callout -->
  <text x="{ADALN_X + MLP_W/2}" y="{ADALN_Y}" text-anchor="middle"
        font-size="10" font-weight="700" fill="{ARROW_ROSE}" letter-spacing="0.04em">
    sin(t) AdaLN-Zero conditioning
  </text>
  <line x1="{MLP_X + MLP_W/2}" y1="{ADALN_Y + 8}" x2="{MLP_X + MLP_W/2}" y2="{MLP_Y - 4}"
        stroke="{ARROW_ROSE}" stroke-width="1.6" marker-end="url(#arrow-rose)"/>

  <!-- f_pca_eef thumbnail -->
  <text x="{EEF_X}" y="{EEF_Y - 8}" font-size="10" font-weight="700" fill="#6b7280" letter-spacing="0.06em">
    EEF + CLS
  </text>
  <image xlink:href="data:image/png;base64,{f_pca_eef_b64}"
         x="{EEF_X}" y="{EEF_Y}" width="{EEF_W}" height="{EEF_H}"
         preserveAspectRatio="xMidYMid meet"/>
  <rect x="{EEF_X}" y="{EEF_Y}" width="{EEF_W}" height="{EEF_H}" rx="6"
        fill="none" stroke="{EEF_RED}" stroke-width="1.5"/>

  <!-- EEF → MLP arrow -->
  <line x1="{EEF_X + EEF_W + 4}" y1="{EEF_Y + EEF_H/2}" x2="{MLP_X - 4}" y2="{MLP_Y + MLP_H/2}"
        stroke="{ARROW_BLUE}" stroke-width="2.5" marker-end="url(#arrow-blue)"/>

  <!-- Res-MLP box -->
  <rect x="{MLP_X}" y="{MLP_Y}" width="{MLP_W}" height="{MLP_H}" rx="10"
        fill="#dbeafe" stroke="{ARROW_BLUE}" stroke-width="1.6" filter="url(#card-shadow)"/>
  <text x="{MLP_X + MLP_W/2}" y="{MLP_Y + 28}" text-anchor="middle"
        font-size="13" font-weight="800" fill="{ARROW_BLUE}">
    5-layer Res-MLP
  </text>
  <text x="{MLP_X + MLP_W/2}" y="{MLP_Y + 46}" text-anchor="middle"
        font-size="10" fill="#1e3a8a">
    hidden=512, AdaLN-Zero
  </text>
  <text x="{MLP_X + MLP_W/2}" y="{MLP_Y + 62}" text-anchor="middle"
        font-size="10" fill="#1e3a8a">
    conditioned on sin(t)
  </text>
  <text x="{MLP_X + MLP_W/2}" y="{MLP_Y + 78}" text-anchor="middle"
        font-size="9.5" font-style="italic" fill="#1e3a8a">
    EEF + CLS &rarr; q (per t)
  </text>

  <!-- MLP → q box -->
  <line x1="{MLP_X + MLP_W + 4}" y1="{MLP_Y + MLP_H/2}" x2="{QBX_X - 4}" y2="{QBX_Y + QBX_H/2}"
        stroke="{ARROW_BLUE}" stroke-width="2.5" marker-end="url(#arrow-blue)"/>

  <!-- q box -->
  <rect x="{QBX_X}" y="{QBX_Y}" width="{QBX_W}" height="{QBX_H}" rx="10"
        fill="#dbeafe" stroke="{ARROW_BLUE}" stroke-width="1.6" filter="url(#card-shadow)"/>
  <text x="{QBX_X + QBX_W/2}" y="{QBX_Y + 26}" text-anchor="middle"
        font-size="15" font-weight="800" fill="{ARROW_BLUE}">q</text>
  <text x="{QBX_X + QBX_W/2}" y="{QBX_Y + 44}" text-anchor="middle"
        font-size="10" fill="#1e3a8a">
    &isin; &#8477;<tspan font-size="8" dy="-3">64</tspan>
  </text>

  <!-- q → dot product symbol: drop out of the q box, jog left to DOT_CX above
       the panel labels, then descend so the arrowhead lands on the node top -->
  <path d="M {QBX_X + QBX_W/2} {QBX_Y + QBX_H + 4} L {QBX_X + QBX_W/2} {QBX_Y + QBX_H + 28} L {DOT_CX} {QBX_Y + QBX_H + 28} L {DOT_CX} {DOT_CY - 26 - 8}"
        stroke="{ARROW_BLUE}" stroke-width="2.5" fill="none" stroke-linejoin="round"
        marker-end="url(#arrow-blue)"/>

  <!-- Dot product node -->
  <circle cx="{DOT_CX}" cy="{DOT_CY}" r="26"
          fill="#fef3c7" stroke="#b45309" stroke-width="1.8" filter="url(#card-shadow)"/>
  <text x="{DOT_CX}" y="{DOT_CY + 6}" text-anchor="middle"
        font-size="22" font-weight="900" fill="#b45309">&middot;</text>
  <text x="{DOT_CX}" y="{DOT_CY + 50}" text-anchor="middle"
        font-size="10" font-weight="700" fill="#92400e">q &middot; V</text>

  <!-- fvol → dot product arrow -->
  <line x1="{FVOL_X + FVOL_W}" y1="{FVOL_Y + FVOL_H/2}" x2="{DOT_CX - 26}" y2="{DOT_CY}"
        stroke="{ARROW_GRAY}" stroke-width="2.5" marker-end="url(#arrow-gray)"/>

  <!-- dot product → pvol arrow -->
  <line x1="{DOT_CX + 26}" y1="{DOT_CY}" x2="{PVOL_X - 4}" y2="{PVOL_Y + PVOL_H/2}"
        stroke="{ARROW_GRAY}" stroke-width="2.5" marker-end="url(#arrow-gray)"/>
  <text x="{(DOT_CX + PVOL_X)/2}" y="{DOT_CY - 12}" text-anchor="middle"
        font-size="10" font-style="italic" fill="#64748b">softmax(...)</text>

  <!-- ─── (d) Probability volume w/ argmax ─────────────────────── -->
  <text x="{PVOL_X}" y="{PVOL_Y - 12}" font-size="11" font-weight="700" fill="#6b7280" letter-spacing="0.06em">
    (d) PROBABILITY VOLUME &middot; argmax @ t={T_STAR}
  </text>
  <image xlink:href="data:image/png;base64,{pvol_ax_b64}"
         x="{PVOL_X}" y="{PVOL_Y}" width="{PVOL_W}" height="{PVOL_H}"
         preserveAspectRatio="xMidYMid meet"/>
  <text x="{PVOL_X}" y="{PVOL_Y + PVOL_H + 4}" font-size="10" font-weight="500" fill="#64748b">
    top-25 voxels of softmax(q&middot;V), plasma colormap
  </text>
  <text x="{PVOL_X}" y="{PVOL_Y + PVOL_H + 18}" font-size="10" font-style="italic" fill="#16653a">
    bright green = argmax voxel &middot; peak softmax {PEAK*100:.1f}%
  </text>

  <!-- pvol → output box -->
  <line x1="{PVOL_X + PVOL_W - 8}" y1="{PVOL_Y + PVOL_H/2 - 20}"
        x2="{OUT_X - 4}" y2="{OUT_Y + OUT_H/2}"
        stroke="{ARROW_GREEN}" stroke-width="2.5" marker-end="url(#arrow-green)"/>
  <text x="{(PVOL_X + PVOL_W + OUT_X)/2}" y="{OUT_Y - 14}" text-anchor="middle"
        font-size="10" font-style="italic" fill="{ARROW_GREEN}">argmax</text>

  <!-- ─── (e) Output: 3D point ─────────────────────────────────── -->
  <rect x="{OUT_X}" y="{OUT_Y}" width="{OUT_W}" height="{OUT_H}" rx="10"
        fill="#dcfce7" stroke="{ARROW_GREEN}" stroke-width="1.6" filter="url(#card-shadow)"/>
  <text x="{OUT_X + OUT_W/2}" y="{OUT_Y + 28}" text-anchor="middle"
        font-size="13" font-weight="800" fill="{ARROW_GREEN}">
    ARGMAX &rarr; 3D POINT
  </text>
  <text x="{OUT_X + OUT_W/2}" y="{OUT_Y + 52}" text-anchor="middle"
        font-size="11" fill="#0f3a1f">
    (x*, y*, z*)<tspan font-size="9" dy="-3">t</tspan><tspan dy="3"></tspan> at t={T_STAR}
  </text>
  <text x="{OUT_X + OUT_W/2}" y="{OUT_Y + 72}" text-anchor="middle"
        font-size="10" fill="#0f3a1f">
    = ({int(d['argmax_x'][T_STAR])}, {int(d['argmax_y'][T_STAR])}, {int(d['argmax_z'][T_STAR])}) @ {H}&times;{W} grid
  </text>
  <text x="{OUT_X + OUT_W/2}" y="{OUT_Y + 90}" text-anchor="middle"
        font-size="10" fill="#0f3a1f">
    height z* &rarr; world-Z bin center
  </text>
  <text x="{OUT_X + OUT_W/2}" y="{OUT_Y + 116}" text-anchor="middle"
        font-size="9.5" font-style="italic" fill="#0f3a1f">
    per-t target; full traj predicted jointly
  </text>

  <!-- ─── Score formula callout (under main row) ───────────────── -->
  <rect x="350" y="700" width="800" height="48" rx="8"
        fill="#f8fafc" stroke="#cbd5e1" stroke-width="1"/>
  <text x="750" y="720" text-anchor="middle"
        font-size="12" font-weight="700" fill="#0f172a"
        font-family="ui-monospace, monospace">
    score(z, y, x) = &lt;q<tspan font-size="9" dy="3">F</tspan><tspan dy="-3"></tspan>, F[y,x]&gt; + &lt;q<tspan font-size="9" dy="3">z</tspan><tspan dy="-3"></tspan>, sin<tspan font-size="9" dy="3">z</tspan><tspan dy="-3"></tspan>[z]&gt; + &lt;q<tspan font-size="9" dy="3">t</tspan><tspan dy="-3"></tspan>, sin<tspan font-size="9" dy="3">t</tspan><tspan dy="-3"></tspan>[t]&gt;
  </text>
  <text x="750" y="738" text-anchor="middle"
        font-size="10" font-style="italic" fill="#64748b">
    bilinear scoring &middot; cosine similarity between query and (feature, height, time) basis
  </text>

  <!-- ─── Footer ────────────────────────────────────────────────── -->
  <text x="20" y="{H_CANVAS - 16}" font-size="9" fill="#94a3b8" font-style="italic">
    Sample {SAMPLE_IDX} (izzy3 train, 1D-PCA rotation, ckpt `dino_query_izzy3_t50_pca1d_v0/latest.pth`).
    EEF + CLS tokens drive the spatial query; key/value pathways factor through the per-pixel features F.
  </text>
</svg>
"""

# Write the finished figure, creating the output directory on first run.
OUT_SVG = Path("/data/cameron/para/paper/figs/svg/query_arch_method.svg")
OUT_SVG.parent.mkdir(parents=True, exist_ok=True)
# The document declares encoding="UTF-8" and contains non-ASCII characters
# (box-drawing rules and arrows inside its comments, e.g. "─" and "→"), so
# encode as UTF-8 explicitly rather than using the platform locale encoding,
# which could raise UnicodeEncodeError or contradict the XML declaration.
OUT_SVG.write_text(svg, encoding="utf-8")
# Size is reported in KiB of characters (base64 payload is ASCII, so ≈ bytes).
print(f"wrote {OUT_SVG}  ({len(svg)//1024} KB)")
