"""Build fig4_ood.svg — OOD generalization, 2 rows × 3 cols.

Row 1: Spatial / object position generalization
Row 2: Viewpoint generalization

Each row: (a/d) train/test distribution, (b/e) success-rate chart comparing
PARA and ACT, (c/f) qualitative PARA success vs. ACT failure at the same OOD
condition.

Charts are generated with matplotlib (matching fig4b_pertheta style)
and embedded as PNGs into the SVG.
"""

import base64
import os
import time
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# Wall-clock start; the progress messages below report elapsed time from here.
_t = time.time()

# Shared palette: PARA series, ACT series, dark body text.
PARA_GREEN, ACT_RED, TEXT_DARK = "#16653a", "#a12029", "#0f172a"

# Directory that receives the rendered chart PNGs (created if missing).
CHART_OUT = "/data/cameron/penpot/figures/extracted/fig4"
os.makedirs(CHART_OUT, exist_ok=True)


def _styled_chart_axes(ax):
    """Give *ax* the shared fig4b chart look.

    Hides the top and right spines, greys the left and bottom spines, and
    sets tick labels to dark grey at 11pt.
    """
    for hidden in ("top", "right"):
        ax.spines[hidden].set_visible(False)
    for shown in ("left", "bottom"):
        ax.spines[shown].set_color("#888888")
    ax.tick_params(colors="#444444", labelsize=11)


def _load_spatial_results(data_path):
    """Load per-position eval records from *data_path*, sorted by ``dx``.

    The JSON may be a bare list of records or a dict wrapping that list under
    ``"results"``. Each record must provide ``dx``, ``para`` and ``act``.

    Raises:
        ValueError: if no records are present; the chart's bins and
            annotation positions are derived from the data range, so there
            is nothing sensible to draw.
    """
    import json
    with open(data_path) as f:
        data = json.load(f)
    results = data["results"] if isinstance(data, dict) and "results" in data else data
    if not results:
        raise ValueError(f"{data_path}: no per-position results to plot")
    return sorted(results, key=lambda r: r["dx"])


def _binned_means(xs, ys_a, ys_b, n_bins=5):
    """Average two parallel series over ``n_bins`` equal-width bins spanning *xs*.

    Bin edges run from ``min(xs) - 0.1`` to ``max(xs) + 0.1``, so every point
    falls into exactly one half-open ``[lo, hi)`` bin. Empty bins are skipped.

    Returns:
        (centers, means_a, means_b): parallel lists, one entry per non-empty bin.
    """
    import numpy as np
    edges = np.linspace(min(xs) - 0.1, max(xs) + 0.1, n_bins + 1)
    centers, means_a, means_b = [], [], []
    for lo, hi in zip(edges[:-1], edges[1:]):
        idx = [i for i, x in enumerate(xs) if lo <= x < hi]
        if idx:
            centers.append((lo + hi) / 2)
            means_a.append(sum(ys_a[i] for i in idx) / len(idx))
            means_b.append(sum(ys_b[i] for i in idx) / len(idx))
    return centers, means_a, means_b


def render_spatial_chart(path, data_path="/tmp/fig4_spatial_per_dx.json"):
    """Row 1 middle chart: success rate vs distance from the train boundary.

    Reads per-position eval results from *data_path* and draws the raw
    per-position scatter plus a 5-bin mean line for PARA and ACT. The chart is
    saved as a 300-dpi PNG to *path*. NOTE(review): the default dump was
    described as 19 test positions x 5 episodes each; confirm if regenerated.

    Args:
        path: output PNG path.
        data_path: JSON file of per-position records with ``dx``, ``para``
            and ``act`` fields (``dx`` is scaled by 100 to cm, rates by 100 to %).

    Raises:
        ValueError: if *data_path* contains no records.
    """
    pts = _load_spatial_results(data_path)
    dx_cm = [r["dx"] * 100 for r in pts]   # meters → centimeters
    para = [r["para"] * 100 for r in pts]
    act = [r["act"] * 100 for r in pts]

    fig, ax = plt.subplots(figsize=(6, 3.8), dpi=300)

    for y in (25, 50, 75):
        ax.axhline(y, color="#888888", alpha=0.3, linewidth=0.5, zorder=1)

    # Boundary marker (train ↔ test divider at dx=0)
    ax.axvline(0, color="#94a3b8", linestyle=":", linewidth=1.2, zorder=2)

    # Raw per-position scatter; PARA (zorder 4) is drawn above ACT (zorder 3).
    ax.scatter(dx_cm, para, s=42, color=PARA_GREEN, edgecolor="white",
               linewidth=1.0, zorder=4, alpha=0.85)
    ax.scatter(dx_cm, act, s=42, color=ACT_RED, edgecolor="white",
               linewidth=1.0, marker="s", zorder=3, alpha=0.85)

    # Binned mean trend lines (5 equal-width bins across the dx range).
    bin_centers, bin_para, bin_act = _binned_means(dx_cm, para, act, n_bins=5)

    ax.plot(bin_centers, bin_para, color=PARA_GREEN, linewidth=2.4,
            marker="o", markersize=10, markerfacecolor=PARA_GREEN,
            markeredgecolor="white", markeredgewidth=1.5,
            label="PARA", zorder=6)
    ax.plot(bin_centers, bin_act, color=ACT_RED, linewidth=2.4,
            marker="s", markersize=10, markerfacecolor=ACT_RED,
            markeredgecolor="white", markeredgewidth=1.5,
            linestyle="--", label="ACT", zorder=5)

    ax.set_xlabel("Distance from train boundary (cm)", fontsize=12)
    ax.set_ylabel("Test-set Success Rate (%)", fontsize=12)
    ax.set_yticks([0, 25, 50, 75, 100])
    ax.set_ylim(-3, 105)   # small pad so 0% / 100% markers are not clipped
    _styled_chart_axes(ax)

    # "train side / test side" annotations in the top-left / top-right corners
    ax.text(min(dx_cm) - 1, 95, "train\nside", fontsize=9, color="#94a3b8",
            ha="left", va="top", fontweight="700", linespacing=1.0)
    ax.text(max(dx_cm) + 1, 95, "test\nside", fontsize=9, color="#94a3b8",
            ha="right", va="top", fontweight="700", linespacing=1.0)

    # Legend sits above the axes: inside the axes at "upper right" it was
    # drawn on top of the "test side" annotation in the same corner.
    leg = ax.legend(loc="lower center", bbox_to_anchor=(0.5, 1.0), ncol=2,
                    frameon=True, fontsize=10, edgecolor="#cccccc",
                    facecolor="white")
    leg.get_frame().set_linewidth(0.8)

    fig.tight_layout()
    fig.savefig(path, format="png", bbox_inches="tight", dpi=300,
                facecolor="white", edgecolor="none")
    plt.close(fig)


def render_viewpoint_chart(path):
    """Row 2 middle chart: success rate vs camera elevation angle θ.

    Plots the same hard-coded per-θ success rates as fig4b_pertheta for PARA
    and ACT (θ = 0° is the training viewpoint, per the tick label) and saves a
    300-dpi PNG to *path*.
    """
    fig, ax = plt.subplots(figsize=(6, 3.8), dpi=300)
    theta = [0, 3.6, 7.1, 10.7, 14.3, 17.9, 21.4, 25]
    para = [88, 79, 62, 63, 62, 62, 33, 38]
    act = [67, 54, 42, 17, 12, 0, 0, 0]

    for y in (25, 50, 75):
        ax.axhline(y, color="#888888", alpha=0.3, linewidth=0.5, zorder=1)

    ax.plot(theta, para, color=PARA_GREEN, linewidth=2.2, marker="o",
            markersize=9, markerfacecolor=PARA_GREEN, markeredgecolor="white",
            markeredgewidth=1.2, label="PARA", zorder=5)
    ax.plot(theta, act, color=ACT_RED, linewidth=2.2, marker="s",
            markersize=9, markerfacecolor=ACT_RED, markeredgecolor="white",
            markeredgewidth=1.2, linestyle="--", label="ACT", zorder=4)

    ax.set_xlabel("Camera elevation angle θ (degrees)", fontsize=12)
    ax.set_ylabel("Test-set Success Rate (%)", fontsize=12)
    ax.set_xlim(-1, 26)
    # Pad below 0 / above 100 (as the spatial chart does): with ylim(0, 100)
    # the ACT markers sitting at 0% were clipped in half by the axes edge.
    ax.set_ylim(-3, 105)
    ax.set_xticks(theta)
    ax.set_xticklabels(["0°\n(train)", "3.6", "7.1", "10.7",
                        "14.3", "17.9", "21.4", "25"])
    ax.set_yticks([0, 25, 50, 75, 100])
    _styled_chart_axes(ax)

    leg = ax.legend(loc="upper right", frameon=True, fontsize=11,
                    edgecolor="#cccccc", facecolor="white")
    leg.get_frame().set_linewidth(0.8)

    fig.tight_layout()
    fig.savefig(path, format="png", bbox_inches="tight", dpi=300,
                facecolor="white", edgecolor="none")
    plt.close(fig)


# Render the two middle-column charts; their PNGs are embedded further down.
spatial_chart_path, viewpoint_chart_path = (
    f"{CHART_OUT}/{stem}.png" for stem in ("spatial_chart", "viewpoint_chart")
)
render_spatial_chart(spatial_chart_path)
render_viewpoint_chart(viewpoint_chart_path)
print(f"[{time.time()-_t:.2f}s] generated charts")


def b64(p):
    """Return the bytes of file *p* base64-encoded as a ``str``.

    Used to inline PNGs into the SVG as ``data:image/png;base64,...`` URIs.
    The file is opened in a ``with`` block so the handle is closed as soon as
    it has been read, rather than being left open until garbage collection.
    """
    with open(p, "rb") as f:
        return base64.b64encode(f.read()).decode()


DASHBOARD = "/data/cameron/para/.agents/reports/project_site/media"
_EXTRACTED = "/data/cameron/penpot/figures/extracted"

# Base64 payloads for every raster in the six cells, loaded in this order:
# distribution images, rendered charts, then the qualitative frames.
spatial_dist_b64 = b64(DASHBOARD + "/exp3_leftright_distribution.png")
vp_dist_b64 = b64(CHART_OUT + "/vp_polar_only.png")
spatial_chart_b64, viewpoint_chart_b64 = b64(spatial_chart_path), b64(viewpoint_chart_path)
act_spatial_b64 = b64(_EXTRACTED + "/4c_act_pos_extreme.png")
para_spatial_b64 = b64(_EXTRACTED + "/4c_para_pos_extreme.png")
act_vp_b64 = b64(_EXTRACTED + "/4c_act_vp_extreme.png")
para_vp_b64 = b64(_EXTRACTED + "/4c_para_vp_extreme.png")

# Short colour aliases used inside the SVG templates.
GREEN, RED = PARA_GREEN, ACT_RED
GRAY, LIGHT = "#6b7280", "#e5e7eb"
TXT = TEXT_DARK  # "#0f172a"

# ═══════════════════════════════════════════════════════════════════════════
# Layout: 1400 x 820, 2 rows × 3 cols
# Each row is 360px tall; row 2 starts 30px below row 1 to leave room for its section header.
# ═══════════════════════════════════════════════════════════════════════════

# Rows share one height; the 30px gap between them holds the row-2 header.
ROW1_Y = 50
ROW1_H = ROW2_H = 360
ROW2_Y = ROW1_Y + ROW1_H + 30  # = 440

# Three equal 440px columns starting at x=20 with 20px gutters: [20, 480, 940].
COL_W = [440] * 3
COL_X = [20 + i * (440 + 20) for i in range(3)]

# Qualitative-frame geometry: two square frames side by side.
FRAME_SIZE, FRAME_GAP = 200, 10


def card(x, y, w, h, label, body, body_filter=True):
    """Wrap *body* in a white rounded card with a small caption.

    Draws the card's outer rect at (x, y) with size w x h, writes *label* near
    the top-left corner, then appends *body* verbatim.

    Args:
        x, y, w, h: card geometry in SVG user units.
        label: caption text; XML-escaped so characters such as ``&`` or ``<``
            cannot produce malformed SVG.
        body: pre-built SVG markup, inserted unchanged.
        body_filter: when True (the default) the outer rect gets the
            ``url(#card-shadow)`` drop-shadow filter; pass False for a flat card.

    Returns:
        The card's ``<g>`` element as an SVG string.
    """
    from xml.sax.saxutils import escape

    shadow = ' filter="url(#card-shadow)"' if body_filter else ""
    return f'''
    <g>
      <rect x="{x}" y="{y}" width="{w}" height="{h}" rx="10" fill="#ffffff"
            stroke="{LIGHT}" stroke-width="1.2"{shadow}/>
      <text x="{x + 16}" y="{y + 22}" font-size="11" font-weight="800" fill="{GRAY}" letter-spacing="0.04em">{escape(label)}</text>
      {body}
    </g>
    '''


# ── Row 1 cells ──
# Image cells share one geometry: the PNG is inset 14px from the card's left,
# right and bottom edges and 40px from its top (leaving room for the card
# label), clipped to a rounded rect, and fitted with "meet" so it is scaled to
# fit entirely inside that box rather than cropped.

# (a) Spatial train/test distribution image (loaded from the DASHBOARD media dir)
r1c1_body = f'''
  <clipPath id="clip-r1c1"><rect x="{COL_X[0] + 14}" y="{ROW1_Y + 40}" width="{COL_W[0] - 28}" height="{ROW1_H - 54}" rx="6"/></clipPath>
  <image xlink:href="data:image/png;base64,{spatial_dist_b64}"
         x="{COL_X[0] + 14}" y="{ROW1_Y + 40}" width="{COL_W[0] - 28}" height="{ROW1_H - 54}"
         preserveAspectRatio="xMidYMid meet" clip-path="url(#clip-r1c1)"/>
'''

# (b) Spatial chart PNG produced by render_spatial_chart
r1c2_body = f'''
  <clipPath id="clip-r1c2"><rect x="{COL_X[1] + 14}" y="{ROW1_Y + 40}" width="{COL_W[1] - 28}" height="{ROW1_H - 54}" rx="6"/></clipPath>
  <image xlink:href="data:image/png;base64,{spatial_chart_b64}"
         x="{COL_X[1] + 14}" y="{ROW1_Y + 40}" width="{COL_W[1] - 28}" height="{ROW1_H - 54}"
         preserveAspectRatio="xMidYMid meet" clip-path="url(#clip-r1c2)"/>
'''

# (c) Spatial qualitative: ACT frame on the left, PARA frame on the right,
# both FRAME_SIZE squares separated by FRAME_GAP and centred as a pair in
# column 3.
_r1c3_pair_w = 2 * FRAME_SIZE + FRAME_GAP
_r1c3_x0 = COL_X[2] + (COL_W[2] - _r1c3_pair_w) // 2
_r1c3_act_x, _r1c3_para_x = _r1c3_x0, _r1c3_x0 + FRAME_SIZE + FRAME_GAP
_r1c3_y = ROW1_Y + 50  # frames start below the card label

# Approximate annotation points, given in source-image pixels. The source
# frames are 448x448 (square), so "slice" shows the whole image at scale
# FRAME_SIZE/448 with no crop.
_sp_scale = FRAME_SIZE / 448


def _sp_point(px, py):
    """Scale a source pixel into frame coordinates, truncating to ints."""
    return (int(px * _sp_scale), int(py * _sp_scale))


_sp_act_grip = _sp_point(250, 185)   # where ACT's gripper ends up (wrong spot)
_sp_bowl = _sp_point(325, 330)       # target object; NOTE(review): caption says bowl, earlier note said plate — confirm
_sp_para_grip = _sp_point(250, 270)  # PARA's gripper, on the target

# SVG for cell (c). Each frame is an <image> clipped to a rounded square, with
# a coloured border and a badge (red "ACT ✗" on the left, green "PARA ✓" on the
# right). Annotation circles: on the ACT frame, red at ACT's gripper and green
# at the target; on the PARA frame, green at PARA's gripper. A centred caption
# sits 28px below the bottom of the frames.
r1c3_body = f'''
  <!-- ACT failure (left) -->
  <clipPath id="clip-r1c3-act"><rect x="{_r1c3_act_x}" y="{_r1c3_y}" width="{FRAME_SIZE}" height="{FRAME_SIZE}" rx="6"/></clipPath>
  <image xlink:href="data:image/png;base64,{act_spatial_b64}"
         x="{_r1c3_act_x}" y="{_r1c3_y}" width="{FRAME_SIZE}" height="{FRAME_SIZE}"
         preserveAspectRatio="xMidYMid slice" clip-path="url(#clip-r1c3-act)"/>
  <rect x="{_r1c3_act_x}" y="{_r1c3_y}" width="{FRAME_SIZE}" height="{FRAME_SIZE}" rx="6"
        fill="none" stroke="{RED}" stroke-width="3"/>
  <rect x="{_r1c3_act_x + 6}" y="{_r1c3_y + 6}" width="58" height="20" rx="4" fill="{RED}"/>
  <text x="{_r1c3_act_x + 35}" y="{_r1c3_y + 21}" text-anchor="middle" font-size="11" font-weight="800" fill="#ffffff">ACT ✗</text>
  <!-- Red circle on ACT gripper (wrong), green circle on target -->
  <circle cx="{_r1c3_act_x + _sp_act_grip[0]}" cy="{_r1c3_y + _sp_act_grip[1]}" r="18"
          fill="none" stroke="{RED}" stroke-width="3"/>
  <circle cx="{_r1c3_act_x + _sp_bowl[0]}" cy="{_r1c3_y + _sp_bowl[1]}" r="18"
          fill="none" stroke="{GREEN}" stroke-width="3"/>

  <!-- PARA success (right) -->
  <clipPath id="clip-r1c3-para"><rect x="{_r1c3_para_x}" y="{_r1c3_y}" width="{FRAME_SIZE}" height="{FRAME_SIZE}" rx="6"/></clipPath>
  <image xlink:href="data:image/png;base64,{para_spatial_b64}"
         x="{_r1c3_para_x}" y="{_r1c3_y}" width="{FRAME_SIZE}" height="{FRAME_SIZE}"
         preserveAspectRatio="xMidYMid slice" clip-path="url(#clip-r1c3-para)"/>
  <rect x="{_r1c3_para_x}" y="{_r1c3_y}" width="{FRAME_SIZE}" height="{FRAME_SIZE}" rx="6"
        fill="none" stroke="{GREEN}" stroke-width="3"/>
  <rect x="{_r1c3_para_x + 6}" y="{_r1c3_y + 6}" width="64" height="20" rx="4" fill="{GREEN}"/>
  <text x="{_r1c3_para_x + 38}" y="{_r1c3_y + 21}" text-anchor="middle" font-size="11" font-weight="800" fill="#ffffff">PARA ✓</text>
  <!-- Green circle on PARA gripper (at target) -->
  <circle cx="{_r1c3_para_x + _sp_para_grip[0]}" cy="{_r1c3_y + _sp_para_grip[1]}" r="20"
          fill="none" stroke="{GREEN}" stroke-width="3"/>

  <!-- Caption below frames -->
  <text x="{COL_X[2] + COL_W[2] / 2}" y="{_r1c3_y + FRAME_SIZE + 28}" text-anchor="middle"
        font-size="11" font-weight="600" fill="{GRAY}">
    ACT overshoots target · PARA reaches bowl
  </text>
'''

# ── Row 2 cells ──
# Same inset image-cell geometry as row 1 (14px sides/bottom, 40px top,
# "meet" fit), anchored at ROW2_Y instead of ROW1_Y.

# (d) Viewpoint train/test distribution image (vp_polar_only.png in CHART_OUT)
r2c1_body = f'''
  <clipPath id="clip-r2c1"><rect x="{COL_X[0] + 14}" y="{ROW2_Y + 40}" width="{COL_W[0] - 28}" height="{ROW2_H - 54}" rx="6"/></clipPath>
  <image xlink:href="data:image/png;base64,{vp_dist_b64}"
         x="{COL_X[0] + 14}" y="{ROW2_Y + 40}" width="{COL_W[0] - 28}" height="{ROW2_H - 54}"
         preserveAspectRatio="xMidYMid meet" clip-path="url(#clip-r2c1)"/>
'''

# (e) Viewpoint chart PNG produced by render_viewpoint_chart
r2c2_body = f'''
  <clipPath id="clip-r2c2"><rect x="{COL_X[1] + 14}" y="{ROW2_Y + 40}" width="{COL_W[1] - 28}" height="{ROW2_H - 54}" rx="6"/></clipPath>
  <image xlink:href="data:image/png;base64,{viewpoint_chart_b64}"
         x="{COL_X[1] + 14}" y="{ROW2_Y + 40}" width="{COL_W[1] - 28}" height="{ROW2_H - 54}"
         preserveAspectRatio="xMidYMid meet" clip-path="url(#clip-r2c2)"/>
'''

# (f) Viewpoint qualitative: the same side-by-side ACT/PARA frame layout as
# cell (c), placed in row 2. NOTE(review): annotation points are computed
# below, but r2c3_body currently draws no circles.
_r2c3_x0 = COL_X[2] + (COL_W[2] - (FRAME_SIZE * 2 + FRAME_GAP)) // 2
_r2c3_act_x, _r2c3_para_x = _r2c3_x0, _r2c3_x0 + FRAME_SIZE + FRAME_GAP
_r2c3_y = ROW2_Y + 50

# Viewpoint frames are ~448x373 (aspect ≈1.20). preserveAspectRatio "slice"
# scales by the larger fit ratio (here FRAME_SIZE/373 ≈ 0.536), so the scaled
# image is ~240x200 and ~20px is cropped from each side, centred (xMidYMid).
_vp_scale = max(FRAME_SIZE / 373, FRAME_SIZE / 448)

def _vp_svg(px, py):
    """Map source pixel (px, py) of a 448x373 viewpoint frame into the frame.

    Applies the xMidYMid-slice transform: scale by ``_vp_scale``, then shift by
    half the overflow on each axis. Returns an (x, y) float tuple in
    frame-local SVG units; values outside [0, FRAME_SIZE] land in the
    cropped-away margin.
    """
    crop_x = (448 * _vp_scale - FRAME_SIZE) / 2
    crop_y = (373 * _vp_scale - FRAME_SIZE) / 2
    return (px * _vp_scale - crop_x, py * _vp_scale - crop_y)

# Annotation points for cell (f), in frame-local SVG coordinates.
# NOTE(review): none of these three is referenced by r2c3_body below, so no
# circles are drawn on the viewpoint frames. Either wire them in as cell (c)
# does, or remove them.
_vp_act_grip  = _vp_svg(380, 110)
_vp_bowl      = _vp_svg(385, 290)
_vp_para_grip = _vp_svg(385, 250)

# SVG for cell (f). The ACT frame (red border, "ACT ✗" badge) is on the left and
# the PARA frame (green border, "PARA ✓" badge) is on the right. Each is an
# <image> clipped to a rounded square. A centred caption sits 28px below the
# bottom of the frames.
r2c3_body = f'''
  <!-- ACT failure (left) — extreme viewpoint pulled from vp_default_to_all rollout grid, cell r3c0 -->
  <clipPath id="clip-r2c3-act"><rect x="{_r2c3_act_x}" y="{_r2c3_y}" width="{FRAME_SIZE}" height="{FRAME_SIZE}" rx="6"/></clipPath>
  <image xlink:href="data:image/png;base64,{act_vp_b64}"
         x="{_r2c3_act_x}" y="{_r2c3_y}" width="{FRAME_SIZE}" height="{FRAME_SIZE}"
         preserveAspectRatio="xMidYMid slice" clip-path="url(#clip-r2c3-act)"/>
  <rect x="{_r2c3_act_x}" y="{_r2c3_y}" width="{FRAME_SIZE}" height="{FRAME_SIZE}" rx="6"
        fill="none" stroke="{RED}" stroke-width="3"/>
  <rect x="{_r2c3_act_x + 6}" y="{_r2c3_y + 6}" width="58" height="20" rx="4" fill="{RED}"/>
  <text x="{_r2c3_act_x + 35}" y="{_r2c3_y + 21}" text-anchor="middle" font-size="11" font-weight="800" fill="#ffffff">ACT ✗</text>

  <!-- PARA success (right) — same extreme viewpoint, PARA grid cell r3c0 -->
  <clipPath id="clip-r2c3-para"><rect x="{_r2c3_para_x}" y="{_r2c3_y}" width="{FRAME_SIZE}" height="{FRAME_SIZE}" rx="6"/></clipPath>
  <image xlink:href="data:image/png;base64,{para_vp_b64}"
         x="{_r2c3_para_x}" y="{_r2c3_y}" width="{FRAME_SIZE}" height="{FRAME_SIZE}"
         preserveAspectRatio="xMidYMid slice" clip-path="url(#clip-r2c3-para)"/>
  <rect x="{_r2c3_para_x}" y="{_r2c3_y}" width="{FRAME_SIZE}" height="{FRAME_SIZE}" rx="6"
        fill="none" stroke="{GREEN}" stroke-width="3"/>
  <rect x="{_r2c3_para_x + 6}" y="{_r2c3_y + 6}" width="64" height="20" rx="4" fill="{GREEN}"/>
  <text x="{_r2c3_para_x + 38}" y="{_r2c3_y + 21}" text-anchor="middle" font-size="11" font-weight="800" fill="#ffffff">PARA ✓</text>

  <!-- Caption -->
  <text x="{COL_X[2] + COL_W[2] / 2}" y="{_r2c3_y + FRAME_SIZE + 28}" text-anchor="middle"
        font-size="11" font-weight="600" fill="{GRAY}">
    ACT collapses at new viewpoint · PARA tracks target
  </text>
'''

# Assemble the 1400x820 figure: background, the two row headers, then the six
# cards (a)–(f). The row headers carry no letter so they do not clash with the
# (a)–(f) panel labels on the cards.
svg = f'''<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"
     viewBox="0 0 1400 820" width="1400" height="820"
     font-family="Inter, Arial, sans-serif">
  <defs>
    <filter id="card-shadow" x="-10%" y="-10%" width="120%" height="130%">
      <feDropShadow dx="0" dy="1.5" stdDeviation="2.5" flood-color="#000" flood-opacity="0.08"/>
    </filter>
  </defs>

  <rect width="1400" height="820" fill="#ffffff"/>

  <!-- Row headers -->
  <text x="30" y="32" font-size="15" font-weight="800" fill="{TXT}">Spatial / Object Position Generalization</text>
  <text x="30" y="422" font-size="15" font-weight="800" fill="{TXT}">Viewpoint Generalization</text>

  {card(COL_X[0], ROW1_Y, COL_W[0], ROW1_H, "(a) TRAIN / TEST DISTRIBUTION", r1c1_body)}
  {card(COL_X[1], ROW1_Y, COL_W[1], ROW1_H, "(b) DISTANCE FROM TRAIN → SUCCESS", r1c2_body)}
  {card(COL_X[2], ROW1_Y, COL_W[2], ROW1_H, "(c) QUALITATIVE · same OOD position", r1c3_body)}

  {card(COL_X[0], ROW2_Y, COL_W[0], ROW2_H, "(d) TRAIN / TEST VIEWPOINTS", r2c1_body)}
  {card(COL_X[1], ROW2_Y, COL_W[1], ROW2_H, "(e) CAMERA ANGLE → SUCCESS", r2c2_body)}
  {card(COL_X[2], ROW2_Y, COL_W[2], ROW2_H, "(f) QUALITATIVE · same OOD viewpoint", r2c3_body)}
</svg>
'''

out = "/data/cameron/para/paper/figs/svg/fig4_ood.svg"
# Encode explicitly as UTF-8. The document declares encoding="UTF-8" and
# contains non-ASCII glyphs (→, ·, ✓, ✗). Relying on the locale default
# encoding could raise UnicodeEncodeError or write bytes that contradict the
# XML declaration. Encoding up front also lets the log report a true byte count.
svg_bytes = svg.encode("utf-8")
with open(out, "wb") as f:
    f.write(svg_bytes)
print(f"[{time.time()-_t:.2f}s] wrote {out} ({len(svg_bytes)} bytes)")
