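"""Build fig1_overview.svg with a 3D volume cloud replacing the 2D heatmap.

Runs the PARA model on a LIBERO frame, max-pools the 32×64×64 heatmap volume
to 32×16×16, projects each voxel isometrically, and emits SVG circles for the
voxel cloud plus a wireframe cube outline. The cloud is embedded in the
panel (a) architecture diagram, combined with the panel (b) results cards, and
the complete figure is written to
/data/cameron/para/paper/figs/svg/fig1_overview.svg.
"""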

import os
import sys
import math
import base64
import time
_t_start = time.time()
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
print(f"[{time.time()-_t_start:5.2f}s] imports complete")

sys.path.insert(0, "/data/cameron/para_normalized_losses/libero")
import model as model_module
from model import TrajectoryHeatmapPredictor, N_HEIGHT_BINS, PRED_SIZE
from libero.libero.envs import OffScreenRenderEnv
from libero.libero import benchmark as bm, get_libero_path
from robosuite.utils.camera_utils import (
    get_camera_transform_matrix, project_points_from_world_to_camera,
)
import h5py

# ═══════════════════════════════════════════════════════════════════════════
# Step 1: Run PARA to get the 3D heatmap
# ═══════════════════════════════════════════════════════════════════════════

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMAGE_SIZE = 448  # square agentview render resolution used as model input
# Per-channel RGB normalization statistics applied to the input image below.
MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD  = np.array([0.229, 0.224, 0.225], dtype=np.float32)

print(f"[{time.time()-_t_start:5.2f}s] loading PARA model...")
ckpt = torch.load(
    "/data/cameron/para_normalized_losses/libero/checkpoints/para_v2_exp4_n64/best.pth",
    map_location="cpu")
# The height range is stored on the model module as globals; set it from the
# checkpoint before the model is constructed (presumably read by the model —
# confirm in model.py).
model_module.MIN_HEIGHT = float(ckpt["min_height"])
model_module.MAX_HEIGHT = float(ckpt["max_height"])
# Recover the window length from the volume head's output dimension
# (assumes that dimension is n_window * N_HEIGHT_BINS).
n_window = ckpt["model_state_dict"]["volume_head.weight"].shape[0] // N_HEIGHT_BINS
model = TrajectoryHeatmapPredictor(n_window=n_window)
# strict=False tolerates key mismatches, but silently ignoring them could
# produce a figure from partially random-initialized weights — report them.
load_result = model.load_state_dict(ckpt["model_state_dict"], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"  WARNING: checkpoint/model key mismatch: "
          f"{len(load_result.missing_keys)} missing {load_result.missing_keys[:5]}, "
          f"{len(load_result.unexpected_keys)} unexpected {load_result.unexpected_keys[:5]}")
model = model.to(device).eval()

print(f"[{time.time()-_t_start:5.2f}s] model loaded; initializing LIBERO...")
benchmark = bm.get_benchmark_dict()["libero_spatial"]()
task = benchmark.get_task(0)
bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
env = OffScreenRenderEnv(
    bddl_file_name=bddl_file,
    camera_heights=IMAGE_SIZE, camera_widths=IMAGE_SIZE,
    camera_names=["agentview"])
env.seed(0)
env.reset()
sim = env.env.sim


def _body_id(name):
    """Return the MuJoCo body id for *name*, or None if the scene has no such body.

    Only the name lookup is guarded, so genuine errors in the edits below still
    surface instead of being swallowed by a bare ``except``.
    """
    try:
        return sim.model.body_name2id(name)
    except (ValueError, KeyError):  # unknown body name
        print(f"  (scene has no body {name!r}; skipping)")
        return None


# Declutter the scene: move distractor furniture to z = -5 (below the scene,
# presumably out of the camera's view) ...
for name in ["wooden_cabinet_1_main", "flat_stove_1_main"]:
    body_id = _body_id(name)
    if body_id is not None:
        sim.model.body_pos[body_id] = [0, 0, -5]
# ... and make distractor objects invisible by zeroing the alpha of every geom they own.
for name in ["akita_black_bowl_2_main", "cookies_1_main", "glazed_rim_porcelain_ramekin_1_main"]:
    body_id = _body_id(name)
    if body_id is None:
        continue
    for gid in range(sim.model.ngeom):
        if sim.model.geom_bodyid[gid] == body_id:
            sim.model.geom_rgba[gid][3] = 0.0
sim.forward()

# Restore the simulator to step 32 of demo_0 for task 0, grab the agentview
# frame + end-effector position, and run PARA once on that frame.
with h5py.File(os.path.join(get_libero_path("datasets"), benchmark.get_task_demonstration(0)), "r") as f:
    state = np.array(f["data/demo_0/states"][32])  # saved sim state at index 32 of demo_0
obs = env.set_init_state(state); sim.forward()
# Re-query observations after sim.forward() (presumably so the render reflects
# the restored state and the geometry edits above — confirm).
obs = env.env._get_observations()
# Vertical flip; assumes the offscreen render comes back upside-down — TODO confirm.
rgb = np.flipud(np.asarray(obs["agentview_image"]).copy())
eef = np.array(obs["robot0_eef_pos"])

# Scale to [0, 1] (assumes 0–255 input), normalize per channel with MEAN/STD,
# then HWC → CHW and add a batch dimension.
img = (rgb.astype(np.float32)/255 - MEAN) / STD
img_t = torch.from_numpy(img.transpose(2,0,1)).float().unsqueeze(0).to(device)
# Camera matrix for agentview at the render resolution; used to project the EEF
# position into pixel space.
w2c = get_camera_transform_matrix(sim, "agentview", IMAGE_SIZE, IMAGE_SIZE)
# NOTE(review): pix_rc is presumably (row, col) for the sim's native image
# orientation, while rgb was flipped above — confirm the keypoint matches.
pix_rc = project_points_from_world_to_camera(eef.reshape(1,3), w2c, IMAGE_SIZE, IMAGE_SIZE)[0]
# Start keypoint passed to the model in (pix_rc[1], pix_rc[0]) order, i.e. (col, row).
start_kp = torch.tensor([float(pix_rc[1]), float(pix_rc[0])], dtype=torch.float32).to(device)
print(f"[{time.time()-_t_start:5.2f}s] LIBERO ready; running PARA inference...")
with torch.no_grad():
    # The model returns a 4-tuple; only the volume logits are used here.
    vol_logits, _, _, _ = model(img_t, start_kp)
print(f"[{time.time()-_t_start:5.2f}s] inference done; subsampling + projecting voxels...")

vol_t = vol_logits[0, -1]  # batch 0, last window timestep
# Softmax over the whole flattened volume, so vol_probs sums to 1 across all voxels.
vol_probs = F.softmax(vol_t.reshape(-1), dim=0).reshape(vol_t.shape)
vol_np = vol_probs.cpu().numpy()  # expected (32, 64, 64)
print(f"Raw volume: shape={vol_np.shape}, sum={vol_np.sum():.4f}, max={vol_np.max():.4f}")

# The simulator is not used again below; everything after this is post-processing.
env.close()

# ═══════════════════════════════════════════════════════════════════════════
# Step 2: Subsample and transform
# ═══════════════════════════════════════════════════════════════════════════

# Max pool 4x4 over spatial dims: 64 → 16 per axis
# vol_np layout: (32_h, 64_v, 64_u). Check it explicitly: a volume with the
# same element count but a different layout would otherwise reshape silently
# into a meaningless pooling instead of failing.
if vol_np.shape != (32, 64, 64):
    raise ValueError(f"expected a (32, 64, 64) heatmap volume, got {vol_np.shape}")
vol_sub = vol_np.reshape(32, 16, 4, 16, 4).max(axis=(2, 4))  # (32, 16, 16)
ah, av, au = np.unravel_index(vol_sub.argmax(), vol_sub.shape)
print(f"Subsampled: shape={vol_sub.shape}, max={vol_sub.max():.4f}, min={vol_sub.min():.3e}")
print(f"Argmax voxel: h={ah}, v={av}, u={au}")

# Normalize to [0, 1] (softmax output is strictly positive, so max > 0)
vol_norm = vol_sub / vol_sub.max()

# Log-scale visibility to spread values across the full [0, 1] range.
# Without this, the peak dominates and everything else is invisible.
eps = 1e-10
log_norm = np.log10(vol_norm + eps)  # range roughly [-10, 0]
log_min, log_max = log_norm.min(), log_norm.max()
vol_vis = (log_norm - log_min) / (log_max - log_min + 1e-8)  # [0, 1]
print(f"Log-vis: min={vol_vis.min():.3f}, max={vol_vis.max():.3f}, mean={vol_vis.mean():.3f}")

# ═══════════════════════════════════════════════════════════════════════════
# Step 3: Isometric projection
# ═══════════════════════════════════════════════════════════════════════════

# Map voxel indices (u, v, h) to SVG coordinates:
#   u (pixel col)  → screen down-right at 30°
#   v (pixel row)  → screen down-left at 30°
#   h (height bin) → straight up (SVG y decreases)
#
#   sx = (u - v) * cos30 * STEP_UV
#   sy = (u + v) * sin30 * STEP_UV - h * STEP_H

STEP_UV = 4.2   # SVG units per u / v voxel step
STEP_H  = 2.3   # SVG units per height-bin step
COS30, SIN30 = math.cos(math.radians(30)), math.sin(math.radians(30))


def project(u, v, h):
    """Return the un-offset SVG position (sx, sy) of voxel (u, v, h)."""
    return ((u - v) * COS30 * STEP_UV,
            (u + v) * SIN30 * STEP_UV - h * STEP_H)


# Screen-space extent of the grid, taken from its 8 corner voxel centres.
all_corners = [project(cu, cv, ch) for cu in (0, 15) for cv in (0, 15) for ch in (0, 31)]
xs = tuple(px for px, _ in all_corners)
ys = tuple(py for _, py in all_corners)
min_x, max_x = min(xs), max(xs)
min_y, max_y = min(ys), max(ys)
cube_w, cube_h = max_x - min_x, max_y - min_y
print(f"Cube projection bounds: {cube_w:.1f} x {cube_h:.1f}")

# Heatmap Volume box in the main SVG (user units).
BOX_X, BOX_Y = 240, 80
BOX_W, BOX_H = 180, 160

# Translation that centres the projected grid inside that box.
offset_x = BOX_X + (BOX_W - cube_w) / 2 - min_x
offset_y = BOX_Y + (BOX_H - cube_h) / 2 - min_y

# ═══════════════════════════════════════════════════════════════════════════
# Step 4: Generate voxel cloud + wireframe
# ═══════════════════════════════════════════════════════════════════════════

plasma = plt.cm.plasma

# Render ALL voxels (alpha comes from the log-scaled visibility) with the
# painter's algorithm: sort back-to-front so nearer voxels are drawn last.
#
# project() is an orthographic (isometric) view only if one height bin spans
# STEP_H / STEP_UV of a u/v step; in those world units the camera looks along
# (+1, +1, +1), so depth along the view direction is
#     u + v + (STEP_H / STEP_UV) * h      (larger = closer to the viewer).
# A plain u + v + h over-weights height and mis-orders partially overlapping
# circles that are stacked vertically on screen.
depth_h_weight = STEP_H / STEP_UV
voxels = []
for h in range(32):
    for v in range(16):
        for u in range(16):
            p_vis = float(vol_vis[h, v, u])
            sx, sy = project(u, v, h)
            sx += offset_x
            sy += offset_y
            depth = u + v + depth_h_weight * h
            voxels.append((sx, sy, p_vis, depth))

# Stable ascending sort by depth: farthest drawn first (ends up behind).
voxels.sort(key=lambda vox: vox[3])
print(f"Voxels to render: {len(voxels)} / {32*16*16}")

# One SVG <circle> per voxel, emitted in painter's (back-to-front) order:
#   - colour:  plasma(0.2 + 0.8 * vis), offset so faint voxels aren't near-black navy
#   - radius:  1.4 (faint fog) growing linearly to 6.4 (peak)
#   - opacity: vis ** 1.5 with a 0.08 floor so the fog stays faintly visible
def _voxel_circle(cx, cy, vis):
    """Return the SVG <circle> element for a voxel at (cx, cy) with visibility *vis* in [0, 1]."""
    red, green, blue = (int(channel * 255) for channel in plasma(0.2 + 0.8 * vis)[:3])
    radius = 1.4 + 5.0 * (vis ** 1.0)
    opacity = max(0.08, vis ** 1.5)
    return (f'<circle cx="{cx:.2f}" cy="{cy:.2f}" r="{radius:.2f}" '
            f'fill="rgb({red},{green},{blue})" fill-opacity="{opacity:.3f}"/>')


voxel_svg_parts = [_voxel_circle(cx, cy, vis) for cx, cy, vis, _depth in voxels]
voxel_svg = '\n      '.join(voxel_svg_parts)
print(f"[{time.time()-_t_start:5.2f}s] {len(voxel_svg_parts)} voxel SVG shapes generated")

# Wireframe cube outline (8 corners, 12 edges)
def corner_svg(u, v, h):
    """Return the absolute SVG position of grid point (u, v, h): project() plus the centring offset."""
    px, py = project(u, v, h)
    return px + offset_x, py + offset_y

# Absolute SVG positions of the 8 corner voxel centres of the 16×16×32 grid.
corners = {(cu, cv, ch): corner_svg(cu, cv, ch)
           for cu in (0, 15) for cv in (0, 15) for ch in (0, 31)}

# The 12 cube edges: the bottom (h=0) ring, the top (h=31) ring, then the four
# vertical edges. Each ring is walked (0,0) → (15,0) → (15,15) → (0,15) → back.
_ring = [(0, 0), (15, 0), (15, 15), (0, 15)]
edges = [((u0, v0, ring_h), (u1, v1, ring_h))
         for ring_h in (0, 31)
         for (u0, v0), (u1, v1) in zip(_ring, _ring[1:] + _ring[:1])]
edges += [((cu, cv, 0), (cu, cv, 31)) for cu, cv in _ring]

wireframe_parts = [
    '<line x1="{:.2f}" y1="{:.2f}" x2="{:.2f}" y2="{:.2f}" '.format(*corners[start], *corners[end])
    + 'stroke="#888888" stroke-width="1.2" stroke-opacity="0.65"/>'
    for start, end in edges
]
wireframe_svg = '\n      '.join(wireframe_parts)

# ═══════════════════════════════════════════════════════════════════════════
# Step 5: Assemble the full fig1_overview.svg with the new volume block
# ═══════════════════════════════════════════════════════════════════════════

def b64(p):
    """Return the bytes of file *p* base64-encoded as a str (for ``data:`` URIs).

    The file is opened in a ``with`` block so the handle is closed promptly
    instead of being left for the garbage collector.
    """
    with open(p, "rb") as fh:
        return base64.b64encode(fh.read()).decode("ascii")

# Pre-rendered images inlined into the SVG as base64 data URIs.
# Only images the SVG actually references are loaded: the older fig1v2 row
# illustrations (row1_towel, row2_train, row2_test_v3, row3_svd_clean) and
# fig6_arm_deleted_0 were read and encoded but never used, costing I/O and
# crashing the build whenever one of those files was missing.
EXTRACTED_DIR = "/data/cameron/penpot/figures/extracted"

# Panel (a) architecture diagram
rgb_b64 = b64(os.path.join(EXTRACTED_DIR, "fig1a_rgb_hires.png"))
pca_b64 = b64(os.path.join(EXTRACTED_DIR, "fig1a_pca_hires.png"))
hm_b64  = b64(os.path.join(EXTRACTED_DIR, "fig1a_heatmap_sq.png"))

# fig1 v3 (advisor-meeting rebuild, 2026-05-12): 3 stacked OOD-axis rows (panel b)
V3_DIR = os.path.join(EXTRACTED_DIR, "fig1v3")
v3_r1_train_b64 = b64(os.path.join(V3_DIR, "row1_train.png"))
v3_r1_test_b64  = b64(os.path.join(V3_DIR, "row1_test.png"))
v3_r2_train_b64 = b64(os.path.join(V3_DIR, "row2_train.png"))
v3_r2_test_b64  = b64(os.path.join(V3_DIR, "row2_test.png"))
v3_r3_test_b64  = b64(os.path.join(V3_DIR, "row3_test.png"))
v3_r3_ta_b64    = b64(os.path.join(V3_DIR, "row3_train_a.png"))
v3_r3_tb_b64    = b64(os.path.join(V3_DIR, "row3_train_b.png"))
v3_r3_tc_b64    = b64(os.path.join(V3_DIR, "row3_train_c.png"))

# Full figure markup. This is an f-string: any literal "{" or "}" in the SVG
# must be doubled. The Heatmap Volume frame is drawn from BOX_X/BOX_Y/BOX_W/
# BOX_H, the same constants that centre the voxel cloud, so the frame and the
# cloud cannot drift apart if the box is moved or resized.
svg = f'''<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"
     viewBox="0 0 1400 700" width="1400" height="700"
     font-family="Inter, Arial, sans-serif">
  <defs>
    <!-- refX=0 places the triangle BASE at the line endpoint so the stroke's
         butt cap is covered; the tip then extends markerWidth units past it.
         markerUnits="userSpaceOnUse" makes markerWidth/Height literal user-unit
         values instead of multiples of stroke-width (default), so the arrow
         tip extends exactly 6 units past the line endpoint regardless of the
         line's stroke-width. Without this, a stroke-width-3 arrow would render
         a marker 6×3 = 18 units wide and the tip would bleed into the target. -->
    <marker id="arrow-gray" viewBox="0 0 10 10" refX="0" refY="5" markerWidth="6" markerHeight="6" markerUnits="userSpaceOnUse" orient="auto">
      <path d="M0,0 L10,5 L0,10 Z" fill="#9ca3af"/>
    </marker>
    <marker id="arrow-green" viewBox="0 0 10 10" refX="0" refY="5" markerWidth="6" markerHeight="6" markerUnits="userSpaceOnUse" orient="auto">
      <path d="M0,0 L10,5 L0,10 Z" fill="#16653a"/>
    </marker>
    <marker id="arrow-red" viewBox="0 0 10 10" refX="0" refY="5" markerWidth="6" markerHeight="6" markerUnits="userSpaceOnUse" orient="auto">
      <path d="M0,0 L10,5 L0,10 Z" fill="#a12029"/>
    </marker>
    <marker id="arrow-slate" viewBox="0 0 10 10" refX="0" refY="5" markerWidth="6" markerHeight="6" markerUnits="userSpaceOnUse" orient="auto">
      <path d="M0,0 L10,5 L0,10 Z" fill="#64748b"/>
    </marker>
    <filter id="card-shadow" x="-10%" y="-10%" width="120%" height="130%">
      <feDropShadow dx="0" dy="1" stdDeviation="2" flood-color="#000" flood-opacity="0.06"/>
    </filter>
    <filter id="brighten">
      <feComponentTransfer>
        <feFuncR type="linear" slope="1.35" intercept="0.05"/>
        <feFuncG type="linear" slope="1.35" intercept="0.05"/>
        <feFuncB type="linear" slope="1.35" intercept="0.05"/>
      </feComponentTransfer>
    </filter>
  </defs>

  <rect width="1400" height="700" fill="#ffffff"/>

  <!-- ══════════════════════════════════════════════════════════════
       PANEL (b) RIGHT — Benefits of PARA: 4 result rows
       Wrapped in transform to shift entire stats panel to right side.
       ══════════════════════════════════════════════════════════════ -->
  <g transform="translate(750, 0)" id="panel-b-benefits">
  <text x="20" y="48" font-size="13" font-weight="700" fill="#6b7280" letter-spacing="0.02em">(b) Benefits of PARA</text>
  <line x1="20" y1="58" x2="550" y2="58" stroke="#e5e7eb" stroke-width="1"/>

  <!-- ─── Row 1: NEW OBJECT POSITION ─── -->
  <g id="v3-row1-position">
    <rect x="10" y="70" width="540" height="170" rx="12" fill="#ffffff" stroke="#e5e7eb" stroke-width="1" filter="url(#card-shadow)"/>
    <text x="22" y="92" font-size="11" font-weight="700" fill="#6b7280" letter-spacing="0.09em">NEW OBJECT POSITION</text>
    <text x="22" y="108" font-size="10" font-weight="500" fill="#475569" font-style="italic">Train on diversity → robust at held-out positions.</text>

    <!-- Train viz (smaller, LEFT) -->
    <clipPath id="v3-clip-r1tr"><rect x="22" y="118" width="140" height="112" rx="4"/></clipPath>
    <image xlink:href="data:image/png;base64,{v3_r1_train_b64}"
           x="22" y="118" width="140" height="112" preserveAspectRatio="xMidYMid meet" clip-path="url(#v3-clip-r1tr)"/>
    <rect x="22" y="118" width="140" height="112" rx="4" fill="none" stroke="#cbd5e1" stroke-width="1"/>
    <rect x="22" y="218" width="140" height="12" fill="#000" fill-opacity="0.6"/>
    <text x="92" y="227" text-anchor="middle" font-size="8" font-weight="700" fill="#ffffff" letter-spacing="0.05em">TRAIN COVERAGE</text>

    <!-- Arrow -->
    <line x1="168" y1="174" x2="192" y2="174" stroke="#64748b" stroke-width="2.5" marker-end="url(#arrow-slate)"/>

    <!-- Test viz (larger, RIGHT) -->
    <clipPath id="v3-clip-r1te"><rect x="200" y="118" width="330" height="112" rx="4"/></clipPath>
    <image xlink:href="data:image/png;base64,{v3_r1_test_b64}"
           x="200" y="118" width="330" height="112" preserveAspectRatio="xMidYMid slice" clip-path="url(#v3-clip-r1te)"/>
    <rect x="200" y="118" width="330" height="112" rx="4" fill="none" stroke="#cbd5e1" stroke-width="1"/>
    <!-- Result badges -->
    <g id="v3-r1-badges">
      <rect x="208" y="124" width="86" height="18" rx="9" fill="#dcfce7" stroke="#16653a" stroke-width="1.2"/>
      <text x="251" y="137" text-anchor="middle" font-size="10" font-weight="800" fill="#16653a">PARA ✓ 97%</text>
      <rect x="300" y="124" width="76" height="18" rx="9" fill="#fef2f2" stroke="#a12029" stroke-width="1.2"/>
      <text x="338" y="137" text-anchor="middle" font-size="10" font-weight="800" fill="#a12029">ACT ✗ 9%</text>
    </g>
    <rect x="200" y="218" width="330" height="12" fill="#000" fill-opacity="0.6"/>
    <text x="365" y="227" text-anchor="middle" font-size="8" font-weight="700" fill="#ffffff" letter-spacing="0.05em">HELD-OUT POSITION ROLLOUT</text>
  </g>

  <!-- ─── Row 2: NEW VIEWPOINT ─── -->
  <g id="v3-row2-viewpoint">
    <rect x="10" y="248" width="540" height="170" rx="12" fill="#ffffff" stroke="#e5e7eb" stroke-width="1" filter="url(#card-shadow)"/>
    <text x="22" y="270" font-size="11" font-weight="700" fill="#6b7280" letter-spacing="0.09em">NEW VIEWPOINT</text>
    <text x="22" y="286" font-size="10" font-weight="500" fill="#475569" font-style="italic">Train on diverse views → robust at held-out angles.</text>

    <clipPath id="v3-clip-r2tr"><rect x="22" y="296" width="140" height="112" rx="4"/></clipPath>
    <image xlink:href="data:image/png;base64,{v3_r2_train_b64}"
           x="22" y="296" width="140" height="112" preserveAspectRatio="xMidYMid meet" clip-path="url(#v3-clip-r2tr)"/>
    <rect x="22" y="296" width="140" height="112" rx="4" fill="none" stroke="#cbd5e1" stroke-width="1"/>
    <rect x="22" y="396" width="140" height="12" fill="#000" fill-opacity="0.6"/>
    <text x="92" y="405" text-anchor="middle" font-size="8" font-weight="700" fill="#ffffff" letter-spacing="0.05em">TRAIN COVERAGE</text>

    <line x1="168" y1="352" x2="192" y2="352" stroke="#64748b" stroke-width="2.5" marker-end="url(#arrow-slate)"/>

    <clipPath id="v3-clip-r2te"><rect x="200" y="296" width="330" height="112" rx="4"/></clipPath>
    <image xlink:href="data:image/png;base64,{v3_r2_test_b64}"
           x="200" y="296" width="330" height="112" preserveAspectRatio="xMidYMid slice" clip-path="url(#v3-clip-r2te)"/>
    <rect x="200" y="296" width="330" height="112" rx="4" fill="none" stroke="#cbd5e1" stroke-width="1"/>
    <g id="v3-r2-badges">
      <rect x="208" y="302" width="86" height="18" rx="9" fill="#dcfce7" stroke="#16653a" stroke-width="1.2"/>
      <text x="251" y="315" text-anchor="middle" font-size="10" font-weight="800" fill="#16653a">PARA ✓ 52%</text>
      <rect x="300" y="302" width="76" height="18" rx="9" fill="#fef2f2" stroke="#a12029" stroke-width="1.2"/>
      <text x="338" y="315" text-anchor="middle" font-size="10" font-weight="800" fill="#a12029">ACT ✗ 0%</text>
    </g>
    <rect x="200" y="396" width="330" height="12" fill="#000" fill-opacity="0.6"/>
    <text x="365" y="405" text-anchor="middle" font-size="8" font-weight="700" fill="#ffffff" letter-spacing="0.05em">HELD-OUT VIEWPOINT ROLLOUT</text>
  </g>

  <!-- ─── Row 3: NEW ENVIRONMENT ─── -->
  <g id="v3-row3-environment">
    <rect x="10" y="426" width="540" height="170" rx="12" fill="#ffffff" stroke="#e5e7eb" stroke-width="1" filter="url(#card-shadow)"/>
    <text x="22" y="448" font-size="11" font-weight="700" fill="#6b7280" letter-spacing="0.09em">NEW ENVIRONMENT</text>
    <text x="22" y="464" font-size="10" font-weight="500" fill="#475569" font-style="italic">Train across setups → robust in a held-out env.</text>

    <!-- Train: 3 stacked thumbnails labeled as different lab setups -->
    <clipPath id="v3-clip-r3ta"><rect x="22" y="472" width="140" height="32" rx="3"/></clipPath>
    <image xlink:href="data:image/png;base64,{v3_r3_ta_b64}"
           x="22" y="472" width="140" height="32" preserveAspectRatio="xMidYMid slice" clip-path="url(#v3-clip-r3ta)"/>
    <rect x="22" y="472" width="140" height="32" rx="3" fill="none" stroke="#cbd5e1" stroke-width="1"/>
    <clipPath id="v3-clip-r3tb"><rect x="22" y="507" width="140" height="32" rx="3"/></clipPath>
    <image xlink:href="data:image/png;base64,{v3_r3_tb_b64}"
           x="22" y="507" width="140" height="32" preserveAspectRatio="xMidYMid slice" clip-path="url(#v3-clip-r3tb)"/>
    <rect x="22" y="507" width="140" height="32" rx="3" fill="none" stroke="#cbd5e1" stroke-width="1"/>
    <clipPath id="v3-clip-r3tc"><rect x="22" y="542" width="140" height="32" rx="3"/></clipPath>
    <image xlink:href="data:image/png;base64,{v3_r3_tc_b64}"
           x="22" y="542" width="140" height="32" preserveAspectRatio="xMidYMid slice" clip-path="url(#v3-clip-r3tc)"/>
    <rect x="22" y="542" width="140" height="32" rx="3" fill="none" stroke="#cbd5e1" stroke-width="1"/>
    <rect x="22" y="582" width="140" height="12" fill="#000" fill-opacity="0.6"/>
    <text x="92" y="591" text-anchor="middle" font-size="8" font-weight="700" fill="#ffffff" letter-spacing="0.05em">3 LAB SETUPS</text>

    <line x1="168" y1="523" x2="192" y2="523" stroke="#64748b" stroke-width="2.5" marker-end="url(#arrow-slate)"/>

    <clipPath id="v3-clip-r3te"><rect x="200" y="472" width="330" height="112" rx="4"/></clipPath>
    <image xlink:href="data:image/png;base64,{v3_r3_test_b64}"
           x="200" y="472" width="330" height="112" preserveAspectRatio="xMidYMid slice" clip-path="url(#v3-clip-r3te)"/>
    <rect x="200" y="472" width="330" height="112" rx="4" fill="none" stroke="#cbd5e1" stroke-width="1"/>
    <g id="v3-r3-badges">
      <rect x="208" y="478" width="86" height="18" rx="9" fill="#dcfce7" stroke="#16653a" stroke-width="1.2"/>
      <text x="251" y="491" text-anchor="middle" font-size="10" font-weight="800" fill="#16653a">PARA ✓ 94%</text>
      <rect x="300" y="478" width="76" height="18" rx="9" fill="#fef2f2" stroke="#a12029" stroke-width="1.2"/>
      <text x="338" y="491" text-anchor="middle" font-size="10" font-weight="800" fill="#a12029">ACT ✗ 0%</text>
    </g>
    <rect x="200" y="582" width="330" height="12" fill="#000" fill-opacity="0.6"/>
    <text x="365" y="591" text-anchor="middle" font-size="8" font-weight="700" fill="#ffffff" letter-spacing="0.05em">HELD-OUT ENVIRONMENT ROLLOUT</text>
  </g>

  <!-- ─── Anchor numbers strip ─── -->
  <rect x="10" y="604" width="540" height="22" rx="6" fill="#f8fafc" stroke="#e5e7eb" stroke-width="1"/>
  <text x="280" y="619" text-anchor="middle" font-size="10" font-weight="600" fill="#475569">
    <tspan font-weight="800" fill="#16653a">SO-100 anchor results</tspan>
    <tspan>  ·  Position 97 / 9</tspan>
    <tspan>  ·  Viewpoint 52 / 0</tspan>
    <tspan>  ·  New env 94 / 0</tspan>
  </text>

  <!-- ─── Caption (cup-task disclaimer) ─── -->
  <text x="280" y="644" text-anchor="middle" font-size="9.5" font-style="italic" fill="#64748b">In submission: cup task on custom arm — held-out condition rollouts in progress.</text>
  </g>

  <!-- ══════════════════════════════════════════════════════════════
       PANEL (a) LEFT — PARA EEF Regression Modification
       Architecture diagram wrapped in a transform so the internal voxel/wireframe
       coordinates (which come from BOX_X/BOX_Y) stay unchanged.
       ══════════════════════════════════════════════════════════════ -->
  <text x="20" y="48" font-size="13" font-weight="700" fill="#6b7280" letter-spacing="0.02em">(a) PARA — EEF Regression Modification</text>
  <line x1="20" y1="58" x2="750" y2="58" stroke="#e5e7eb" stroke-width="1"/>

  <g transform="translate(-43, -2) scale(1.2)" id="panel-a-architecture">
    <text x="330" y="64" text-anchor="middle" font-size="13" font-weight="800" fill="#16653a" letter-spacing="0.03em">PARA EEF Regression (Ours)</text>
    <text x="330" y="78" text-anchor="middle" font-size="10" font-weight="600" fill="#16653a" font-style="italic">Pixel-Aligned Heatmap Volume</text>

    <!-- TOP ROW: Heatmap Volume (3D voxels) -->
    <g id="box-3d-volume">
      <rect x="{BOX_X}" y="{BOX_Y}" width="{BOX_W}" height="{BOX_H}" rx="10" fill="#ffffff" stroke="#16653a" stroke-width="2.5"/>
      {wireframe_svg}
      {voxel_svg}
    </g>

    <!-- Unproject arrow: Heatmap Volume → 3D Scene (visual tip leaves 4px gap before box) -->
    <line x1="422" y1="160" x2="478" y2="160" stroke="#16653a" stroke-width="3.5" marker-end="url(#arrow-green)"/>
    <text x="454" y="152" text-anchor="middle" font-size="11" font-weight="600" fill="#16653a" font-style="italic">Unproject</text>
    <text x="454" y="178" text-anchor="middle" font-size="9.5" font-weight="500" fill="#16653a" font-style="italic">(Pose, Intr.)</text>

    <!-- 3D Scene visualization -->
    <g id="box-3d-scene">
      <rect x="488" y="80" width="160" height="160" rx="10" fill="#fef2f2" stroke="#16653a" stroke-width="2.5"/>
      <image xlink:href="data:image/png;base64,{hm_b64}"
             x="492" y="84" width="152" height="152" preserveAspectRatio="xMidYMid slice"/>
    </g>

    <!-- argmax arrow: 3D Scene → (x,y,z) PARA (4px gap before PARA top y=299) -->
    <line x1="568" y1="244" x2="568" y2="289" stroke="#16653a" stroke-width="3.5" marker-end="url(#arrow-green)"/>
    <text x="578" y="268" font-size="11" font-weight="600" fill="#16653a" font-style="italic">argmax</text>

    <!-- (x,y,z) PARA output (below 3D Scene) — lowered per user edit -->
    <g id="box-para-out">
      <rect x="506" y="299" width="124" height="44" rx="10" fill="#f0fdf4" stroke="#16653a" stroke-width="2"/>
      <text x="568" y="321" text-anchor="middle" font-size="16" font-weight="800" fill="#16653a">(x, y, z)</text>
      <text x="568" y="336" text-anchor="middle" font-size="10" font-weight="700" fill="#16653a">PARA</text>
    </g>

    <!-- CNN block: Image Features → Heatmap Volume -->
    <g id="box-cnn">
      <rect x="300" y="256" width="60" height="34" rx="8" fill="#f0fdf4" stroke="#16653a" stroke-width="2"/>
      <text x="330" y="278" text-anchor="middle" font-size="13" font-weight="800" fill="#16653a">CNN</text>
    </g>
    <!-- Up arrow: PCA top → CNN bottom (4px gap before y=290) -->
    <line x1="330" y1="318" x2="330" y2="300" stroke="#16653a" stroke-width="3" marker-end="url(#arrow-green)"/>
    <!-- Up arrow: CNN top → Heatmap bottom (4px gap before y=240) -->
    <line x1="330" y1="254" x2="330" y2="250" stroke="#16653a" stroke-width="3" marker-end="url(#arrow-green)"/>

    <!-- Image box (above DINO) -->
    <text x="134" y="154" text-anchor="middle" font-size="13" font-weight="700" fill="#1f2937">Image</text>
    <g id="box-image">
      <rect x="44" y="160" width="180" height="160" rx="10" fill="#ffffff" stroke="#bfdbfe" stroke-width="2"/>
      <image xlink:href="data:image/png;base64,{rgb_b64}"
             x="48" y="164" width="172" height="152" preserveAspectRatio="xMidYMid slice"/>
    </g>

    <!-- Down arrow: Image → DINO (4px gap before y=360) -->
    <line x1="134" y1="322" x2="134" y2="350" stroke="#9ca3af" stroke-width="2.5" marker-end="url(#arrow-gray)"/>

    <!-- DINO box (below Image) -->
    <g id="box-dino">
      <rect x="94" y="360" width="80" height="80" rx="10" fill="#fffbeb" stroke="#d97e1f" stroke-width="2.5"/>
      <text x="134" y="406" text-anchor="middle" font-size="16" font-weight="800" fill="#d97e1f">DINO</text>
    </g>

    <!-- Right arrow: DINO → Image Features (PCA) (4px gap before x=240) -->
    <line x1="176" y1="400" x2="230" y2="400" stroke="#9ca3af" stroke-width="2.5" marker-end="url(#arrow-gray)"/>

    <!-- Image Features (PCA) box -->
    <g id="box-pca">
      <rect x="240" y="320" width="180" height="160" rx="10" fill="#ffffff" stroke="#6b7280" stroke-width="2"/>
      <image xlink:href="data:image/png;base64,{pca_b64}"
             x="244" y="324" width="172" height="152" preserveAspectRatio="xMidYMid slice"/>

      <!-- CLS badge: flatter, wider, lower per user edit -->
      <g id="cls-badge">
        <rect x="313" y="457" width="100" height="17" rx="5" fill="#ffffff" fill-opacity="0.92"
              stroke="#a12029" stroke-width="1.5"/>
        <text x="363" y="470" text-anchor="middle" font-size="10" font-weight="700" fill="#a12029">CLS token</text>
      </g>
    </g>

    <!-- Bottom-track label: positioned under the CLS Token / bottom row -->
    <text x="363" y="510" text-anchor="middle" font-size="13" font-weight="800" fill="#a12029" letter-spacing="0.03em">Global EEF Regression (Traditional)</text>

    <!-- Right path: CLS → MLP → (x,y,z) ACT — lowered by 10 to align with new CLS y-center 466 -->
    <line x1="414" y1="466" x2="434" y2="466" stroke="#a12029" stroke-width="3" marker-end="url(#arrow-red)"/>

    <g id="box-mlp">
      <rect x="444" y="446" width="60" height="40" rx="8" fill="#fef2f2" stroke="#a12029" stroke-width="2"/>
      <text x="474" y="471" text-anchor="middle" font-size="13" font-weight="800" fill="#a12029">MLP</text>
    </g>

    <line x1="506" y1="466" x2="520" y2="466" stroke="#a12029" stroke-width="3" marker-end="url(#arrow-red)"/>

    <g id="box-act-out">
      <rect x="530" y="438" width="124" height="56" rx="10" fill="#fef2f2" stroke="#a12029" stroke-width="2"/>
      <text x="592" y="464" text-anchor="middle" font-size="16" font-weight="800" fill="#a12029">(x, y, z)</text>
      <text x="592" y="484" text-anchor="middle" font-size="11" font-weight="700" fill="#a12029">ACT</text>
    </g>

    <line x1="48" y1="530" x2="660" y2="530" stroke="#f3f4f6" stroke-width="1"/>
    <text x="354" y="555" text-anchor="middle" font-size="13" fill="#4b5563">
      CLS collapses spatial structure.
      <tspan font-weight="700" fill="#16653a">PARA preserves it.</tspan>
    </text>
  </g>
</svg>
'''

# The document declares encoding="UTF-8" and contains non-ASCII glyphs
# (✓ ✗ → · ═ ─), so write it as UTF-8 explicitly instead of relying on the
# platform's locale-dependent default encoding.
out_path = "/data/cameron/para/paper/figs/svg/fig1_overview.svg"
with open(out_path, "w", encoding="utf-8") as f:
    f.write(svg)
# len(svg) counts characters, not bytes (the two differ for UTF-8 glyphs);
# report the real on-disk size instead.
print(f"[{time.time()-_t_start:5.2f}s] wrote fig1_overview.svg ({os.path.getsize(out_path)} bytes)")
