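"""Render the six paper figures (plus standalone Fig 4a/4b/4c panels) as PNGs.

Fig 4b is additionally written as SVG. Every figure is written to each
directory in OUT_DIRS:
  /data/cameron/para/paper/figs/generated/figN_*.png
  /data/cameron/para/.agents/reports/paper_figures/media/figN_*.png

Run: python3 render_paper_figures.py [--figs all | KEY[,KEY...]]
     KEY is one of 1, 2, 3, 4, 4a, 4b, 4c, 5, 6. "all" (the default)
     renders 1, 2, 3, 4, 4b, 5 and 6.
"""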
import argparse
import os
import sys

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as patches
import matplotlib.image as mpimg
import numpy as np

# ── Style ──
# Global font defaults applied to every figure rendered by this script.
plt.rcParams.update({"font.family": "DejaVu Sans", "font.size": 11})

# Palette: the two method colours (PARA vs. the ACT baseline) plus neutral tones.
PARA_GREEN = "#16653a"
ACT_RED = "#a12029"
NEUTRAL = "#71717a"
TEXT_DARK = "#1f2d3d"
SUBTLE = "#5a5a5a"
BG_LIGHT = "#f8f9fa"
BORDER = "#d4d4d8"

# Each rendered figure is written once into every one of these directories.
OUT_DIRS = [
    "/data/cameron/para/paper/figs/generated",
    "/data/cameron/para/.agents/reports/paper_figures/media",
]
# Source-image locations read by the individual renderers.
EXTRACTED = "/data/cameron/penpot/figures/extracted"
EXISTING_FIGS = "/data/cameron/para/paper/figs/figma"
DASHBOARD_MEDIA = "/data/cameron/para/.agents/reports/project_site/media"

# Make sure every output directory exists (this runs at import time).
for d in OUT_DIRS:
    os.makedirs(d, exist_ok=True)


# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────

def save(fig, name, also_svg=False, dpi=200):
    """Write ``fig`` as ``<name>.png`` (optionally also ``.svg``) into every OUT_DIRS entry.

    Args:
        fig: Matplotlib figure to save. It is always closed afterwards, even if
            a write fails, so a failed render does not leak the figure.
        name: Output file stem, without extension.
        also_svg: When true, also write a vector ``<name>.svg`` next to each PNG.
        dpi: PNG resolution. ``savefig``'s dpi overrides the dpi the figure was
            created with; the default of 200 keeps the historical output.
    """
    try:
        for out_dir in OUT_DIRS:
            fig.savefig(f"{out_dir}/{name}.png", format="png", bbox_inches="tight",
                        dpi=dpi, facecolor="white")
            if also_svg:
                fig.savefig(f"{out_dir}/{name}.svg", format="svg", bbox_inches="tight",
                            facecolor="white")
    finally:
        plt.close(fig)
    print(f"  wrote {name}.png to {len(OUT_DIRS)} locations" +
          (" + .svg" if also_svg else ""))


def axis_off(ax):
    """Turn ``ax`` into a bare canvas: clear both tick sets and hide every spine."""
    for clear_ticks in (ax.set_xticks, ax.set_yticks):
        clear_ticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


def rounded_box(ax, x, y, w, h, fill, edgecolor=None, linewidth=1.5, rad=0.06):
    """Add a rounded rectangle with lower-left corner (x, y) in data coordinates.

    The outline uses ``edgecolor`` when given (truthy), otherwise the fill colour.
    The patch is unclipped, drawn at zorder 2, and returned to the caller.
    """
    outline = edgecolor if edgecolor else fill
    style = f"round,pad=0,rounding_size={rad}"
    patch = patches.FancyBboxPatch((x, y), w, h,
                                   boxstyle=style,
                                   facecolor=fill,
                                   edgecolor=outline,
                                   linewidth=linewidth,
                                   transform=ax.transData,
                                   clip_on=False,
                                   zorder=2)
    ax.add_patch(patch)
    return patch


def arrow(ax, x1, y1, x2, y2, color=TEXT_DARK, lw=1.6):
    """Draw a simple '->' arrow from (x1, y1) to (x2, y2) in data coordinates."""
    props = {"arrowstyle": "->", "color": color, "lw": lw}
    start, end = (x1, y1), (x2, y2)
    ax.annotate("", xy=end, xytext=start, arrowprops=props, zorder=3)


# ─────────────────────────────────────────────────────────────────────────────
# Figure 1: Overview teaser
# ─────────────────────────────────────────────────────────────────────────────

def render_fig1():
    """Fig 1: Overview teaser. Working draft — to be polished in Figma later.
    No top title (paper caption owns it). No italic punchline.

    Layout is a 1x3 grid (width ratios 1.4 : 1 : 1):
      (a) two box-and-arrow pipelines, global regression (ACT) vs PARA;
      (b) three "where PARA helps" vignettes (coloured circle + title + subtitle);
      (c) three headline-result cards.
    Every panel uses a 0-10 x 0-6.5 data space with the axes hidden.
    Saved via ``save`` as ``fig1_overview``.
    """
    fig = plt.figure(figsize=(12, 5.0), dpi=300)

    gs = gridspec.GridSpec(1, 3, figure=fig, left=0.04, right=0.96,
                           top=0.96, bottom=0.04, wspace=0.18,
                           width_ratios=[1.4, 1.0, 1.0])

    # ── Panel (a): Architecture comparison ──
    axA = fig.add_subplot(gs[0, 0])
    axA.set_xlim(0, 10)
    axA.set_ylim(0, 6.5)
    axis_off(axA)
    axA.text(5, 6.2, "(a) Architecture Comparison",
             ha="center", va="bottom", fontsize=12, fontweight="700")

    # Global Regression row
    axA.text(0.2, 5.4, "Global Regression (ACT)",
             ha="left", va="center", fontsize=9.5, fontweight="700", color=ACT_RED)
    g_boxes = [("Image",   "#eaf2fb"),
               ("DINO",    "#fff3e6"),
               ("CLS",     "#f5e6fb"),
               ("MLP",     "#ffffff"),
               ("(x,y,z)", "#fde9ec")]
    # Boxes sit left-to-right at a fixed pitch of (bw + gap) data units; with
    # 5 boxes the last one ends at x = 9.0, inside the 0-10 x-range.
    g_x = 0.15
    g_y = 4.3
    bw, bh = 1.45, 0.7
    gap = 0.4
    for i, (label, fill) in enumerate(g_boxes):
        bx = g_x + i * (bw + gap)
        rounded_box(axA, bx, g_y, bw, bh, fill, edgecolor=NEUTRAL)
        axA.text(bx + bw/2, g_y + bh/2, label, ha="center", va="center",
                 fontsize=9, fontweight="600", zorder=3)
        # Connect to the next box; the arrow is inset 0.05 from each box edge.
        if i < len(g_boxes) - 1:
            arrow(axA, bx + bw + 0.05, g_y + bh/2, bx + bw + gap - 0.05, g_y + bh/2)

    # PARA row (same geometry as the row above, shifted down to p_y)
    axA.text(0.2, 3.0, "PARA (ours)",
             ha="left", va="center", fontsize=9.5, fontweight="700", color=PARA_GREEN)
    p_boxes = [("Image",     "#eaf2fb"),
               ("DINO",      "#fff3e6"),
               ("Heatmap\nVolume", "#fde9ec"),
               ("Argmax",    "#ffffff"),
               ("3D Point",  "#e8f5ec")]
    p_y = 1.9
    for i, (label, fill) in enumerate(p_boxes):
        bx = g_x + i * (bw + gap)
        rounded_box(axA, bx, p_y, bw, bh, fill, edgecolor=NEUTRAL)
        axA.text(bx + bw/2, p_y + bh/2, label, ha="center", va="center",
                 fontsize=9, fontweight="600", zorder=3)
        if i < len(p_boxes) - 1:
            arrow(axA, bx + bw + 0.05, p_y + bh/2, bx + bw + gap - 0.05, p_y + bh/2)

    # (italic punchline removed — paper caption owns it)

    # ── Panel (b): Vignettes ──
    axB = fig.add_subplot(gs[0, 1])
    axB.set_xlim(0, 10)
    axB.set_ylim(0, 6.5)
    axis_off(axB)
    axB.text(5, 6.2, "(b) Where PARA Helps",
             ha="center", va="bottom", fontsize=12, fontweight="700")

    # (title, subtitle, icon colour) per vignette, stacked top to bottom.
    vignettes = [
        ("OOD Generalization",  "Shifted camera & objects",  PARA_GREEN),
        ("Video Backbone",      "SVD video → robot",         "#3b6fa6"),
        ("Cross-Embodiment",    "Arm-deleted point tracks",  "#5a3da2"),
    ]
    vy_start = 5.2
    vh = 1.5
    for i, (title, sub, color) in enumerate(vignettes):
        vy = vy_start - i * vh
        # Icon circle
        circle = patches.Circle((0.9, vy), 0.4, facecolor=color, edgecolor="white",
                                linewidth=2, zorder=3)
        axB.add_patch(circle)
        axB.text(1.7, vy + 0.15, title, ha="left", va="center",
                 fontsize=11, fontweight="700")
        axB.text(1.7, vy - 0.25, sub, ha="left", va="center",
                 fontsize=9, color=SUBTLE)

    # ── Panel (c): Headlines ──
    axC = fig.add_subplot(gs[0, 2])
    axC.set_xlim(0, 10)
    axC.set_ylim(0, 6.5)
    axis_off(axC)
    axC.text(5, 6.2, "(c) Headline Results",
             ha="center", va="bottom", fontsize=12, fontweight="700")

    # (big number, caption, colour) per card.
    # NOTE(review): "90% vs 0%" disagrees with render_fig5's rollout titles
    # ("SVD + PARA — 92%"); confirm which number is current and align both.
    # The "[TBD]" card is a placeholder until point-track results land.
    headlines = [
        ("97% vs 9%",  "Real Robot (20 demos)",      PARA_GREEN),
        ("90% vs 0%",  "Video Backbone",              PARA_GREEN),
        ("[TBD]",      "Point Track Pretraining",     NEUTRAL),
    ]
    hy_start = 5.0
    hh = 1.55
    for i, (big, sub, color) in enumerate(headlines):
        hy = hy_start - i * hh
        # Card background
        rounded_box(axC, 0.5, hy - 0.55, 9.0, 1.25, "#ffffff",
                    edgecolor=BORDER, linewidth=1.2, rad=0.1)
        axC.text(5, hy + 0.15, big, ha="center", va="center",
                 fontsize=22, fontweight="800", color=color)
        axC.text(5, hy - 0.4, sub, ha="center", va="center",
                 fontsize=10, color=SUBTLE, fontweight="500")

    save(fig, "fig1_overview")


# ─────────────────────────────────────────────────────────────────────────────
# Figure 2: Method details
# ─────────────────────────────────────────────────────────────────────────────

def render_fig2():
    """Fig 2: Method pipeline + height illustration. No top title (paper caption owns it).

    Left (a): the four extracted pipeline stills in a 7-column strip, with
    arrow-only columns between consecutive stills. Right (b): the existing
    height-vs-depth illustration. Saved via ``save`` as ``fig2_method``.
    """
    def show_framed(ax, image):
        # Display an image with ticks removed and a thin light-gray border.
        ax.imshow(image)
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_color("#cccccc")
            spine.set_linewidth(1.0)

    fig = plt.figure(figsize=(12, 5.6), dpi=300)
    # No suptitle — paper caption handles it

    outer = gridspec.GridSpec(1, 2, figure=fig, left=0.02, right=0.98,
                              top=0.92, bottom=0.04, wspace=0.06,
                              width_ratios=[1.45, 1.0])  # b panel ~40% width for readable text

    # Small, muted panel labels placed in figure coordinates above each panel.
    panel_labels = ((0.30, "(a) Inference pipeline visualization"),
                    (0.76, "(b) Height vs Depth"))
    for x_pos, text in panel_labels:
        fig.text(x_pos, 0.96, text,
                 fontsize=11, fontweight="600", color="#555555", ha="center")

    # Panel (a) strip: image | arrow | image | arrow | image | arrow | image
    strip = gridspec.GridSpecFromSubplotSpec(
        1, 7, subplot_spec=outer[0, 0],
        width_ratios=[4, 0.5, 4, 0.5, 4, 0.5, 4], wspace=0.0)

    stages = [
        ("frame_2_stage1_clean.png", "Camera Frustum"),
        ("frame_2_stage2_clean.png", "Heatmap Volume"),
        ("frame_2_stage3_clean.png", "Argmax → 3D Target"),
        ("frame_2_stage4_clean.png", "Robot at Target"),
    ]
    # Stills occupy the even columns 0, 2, 4, 6.
    for slot, (filename, caption) in enumerate(stages):
        still_ax = fig.add_subplot(strip[0, 2 * slot])
        show_framed(still_ax, mpimg.imread(f"{EXTRACTED}/{filename}"))
        still_ax.set_title(caption, fontsize=10, fontweight="600", color=SUBTLE, pad=4)

    # Odd columns hold a single bold connector arrow each.
    for gap_col in (1, 3, 5):
        gap_ax = fig.add_subplot(strip[0, gap_col])
        gap_ax.set_xlim(0, 1)
        gap_ax.set_ylim(0, 1)
        axis_off(gap_ax)
        connector = dict(arrowstyle="-|>,head_length=0.7,head_width=0.5",
                         color="#1f2d3d", lw=2.8)
        gap_ax.annotate("", xy=(0.92, 0.5), xytext=(0.08, 0.5), arrowprops=connector)

    # Panel (b): height illustration
    height_ax = fig.add_subplot(outer[0, 1])
    show_framed(height_ax, mpimg.imread(f"{EXISTING_FIGS}/height_illustration.png"))

    save(fig, "fig2_method")


# ─────────────────────────────────────────────────────────────────────────────
# Figure 3: Real robot results — single image
# ─────────────────────────────────────────────────────────────────────────────

def render_fig3():
    """Fig 3: real-robot results, shown as the pre-made ``para_results.png``.

    The figure is 12 in wide and as tall as the image's aspect ratio requires,
    plus 0.6 in of headroom for the title. Saved via ``save`` as ``fig3_realrobot``.
    """
    results = mpimg.imread(f"{EXISTING_FIGS}/para_results.png")
    height_px, width_px = results.shape[:2]
    width_in = 12
    height_in = width_in * (height_px / width_px) + 0.6  # extra space for title
    fig = plt.figure(figsize=(width_in, height_in), dpi=200)
    # NOTE(review): unlike Figs 1, 2, 4 and 5 (whose captions own the title),
    # this figure still draws an in-image "Figure 3: ..." title.
    fig.suptitle("Figure 3: Real Robot Results (SO-100, 20 demos)",
                 fontsize=16, fontweight="700", y=0.99, color=TEXT_DARK)
    ax = fig.add_subplot(111)
    ax.imshow(results)
    axis_off(ax)
    # Reserve the top 4% of the figure for the suptitle.
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    save(fig, "fig3_realrobot")


# ─────────────────────────────────────────────────────────────────────────────
# Figure 4: OOD Analysis (3 panels)
# ─────────────────────────────────────────────────────────────────────────────

def render_fig4():
    """Composite Fig 4: stacked (a)+(b) on left, (c) on right (taller layout).

    Left column, top: (a) spatial extrapolation, i.e. the distribution image
    next to a PARA-vs-ACT success-rate bar chart. Left column, bottom: (b)
    success rate vs camera elevation θ. Right column: (c) a 2x2 grid of
    start/mid rollout frames with circle annotations on the mid frames.
    Saved via ``save`` as ``fig4_ood``.
    """
    fig = plt.figure(figsize=(12, 7.5), dpi=300)
    # No suptitle — paper caption owns it

    gs = gridspec.GridSpec(1, 2, figure=fig, left=0.04, right=0.985,
                           top=0.96, bottom=0.05, wspace=0.13,
                           width_ratios=[1.05, 1.0])

    # Left column: stack (a) on top, (b) on bottom
    gsL = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=gs[0, 0],
                                           hspace=0.40,
                                           height_ratios=[1.0, 1.05])

    # Sub-labels — (a) nudged left to avoid overlap with distribution plot's
    # built-in legend text ("Left half (j=0~7)...")
    # These are hand-tuned figure fractions, not derived from the grids above;
    # re-check them whenever margins or ratios change.
    fig.text(0.07, 0.965, "(a) Spatial Extrapolation",
             fontsize=12, fontweight="600", color="#555555", ha="left")
    fig.text(0.07, 0.488, "(b) Per-Angle Viewpoint Robustness",
             fontsize=12, fontweight="600", color="#555555", ha="left")
    fig.text(0.555, 0.965, "(c) Qualitative Comparison",
             fontsize=12, fontweight="600", color="#555555", ha="left")

    # ── Panel (a): distribution image + bar chart side-by-side inside ──
    gsA = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=gsL[0, 0],
                                           wspace=0.25, width_ratios=[1.4, 1.0])
    axA1 = fig.add_subplot(gsA[0, 0])
    img = mpimg.imread(f"{DASHBOARD_MEDIA}/exp3_leftright_distribution.png")
    axA1.imshow(img)
    axis_off(axA1)

    axA2 = fig.add_subplot(gsA[0, 1])
    # Faint reference lines at 25% / 50% (zorder 0 keeps them under the bars).
    for y in (25, 50):
        axA2.axhline(y, color="#888888", alpha=0.3, linewidth=0.5, zorder=0)
    # NOTE: the 54 / 1 success rates also appear in render_fig4a_spatial; keep in sync.
    bars = axA2.bar(["PARA", "ACT"], [54, 1],
                    color=[PARA_GREEN, ACT_RED], width=0.55)
    # Value label 3 points above each bar top.
    for b, val in zip(bars, [54, 1]):
        axA2.annotate(f"{val}%",
                      xy=(b.get_x() + b.get_width()/2, b.get_height()),
                      xytext=(0, 3), textcoords="offset points",
                      ha="center", va="bottom", fontsize=12, fontweight="700")
    axA2.set_ylim(0, 70)
    axA2.set_ylabel("Success Rate (%)", fontsize=10)
    axA2.tick_params(labelsize=10, colors="#444444")
    axA2.spines["top"].set_visible(False)
    axA2.spines["right"].set_visible(False)
    axA2.spines["left"].set_color("#888888")
    axA2.spines["bottom"].set_color("#888888")

    # ── Panel (b): per-theta chart ──
    # NOTE: theta / para / act duplicate render_fig4b_standalone; keep both in sync.
    axB = fig.add_subplot(gsL[1, 0])
    theta = [0, 3.6, 7.1, 10.7, 14.3, 17.9, 21.4, 25]
    para  = [88, 79, 62, 63, 62, 62, 33, 38]
    act   = [67, 54, 42, 17, 12,  0,  0,  0]
    for y in (25, 50, 75):
        axB.axhline(y, color="#888888", alpha=0.3, linewidth=0.5, zorder=1)
    axB.plot(theta, para, color=PARA_GREEN, linewidth=2, marker="o", markersize=8,
             markerfacecolor=PARA_GREEN, markeredgecolor="white", markeredgewidth=1.2,
             label="PARA", zorder=5)
    axB.plot(theta, act, color=ACT_RED, linewidth=2, marker="s", markersize=8,
             markerfacecolor=ACT_RED, markeredgecolor="white", markeredgewidth=1.2,
             linestyle="--", label="ACT", zorder=4)
    axB.set_xlabel("Camera Elevation θ (degrees)", fontsize=11)
    axB.set_ylabel("Success Rate (%)", fontsize=11)
    axB.set_xlim(-1, 26)
    axB.set_ylim(0, 100)
    # One tick per measured angle; the 0° tick is labelled as the training view.
    axB.set_xticks(theta)
    axB.set_xticklabels(["0°\n(train)", "3.6", "7.1", "10.7",
                         "14.3", "17.9", "21.4", "25"])
    axB.set_yticks([0, 25, 50, 75, 100])
    axB.spines["top"].set_visible(False)
    axB.spines["right"].set_visible(False)
    axB.spines["left"].set_color("#888888")
    axB.spines["bottom"].set_color("#888888")
    axB.tick_params(labelsize=10, colors="#444444")
    leg = axB.legend(loc="upper right", frameon=True, fontsize=10,
                     edgecolor="#cccccc", facecolor="white")
    leg.get_frame().set_linewidth(0.8)

    # ── Panel (c): 2x2 qualitative grid with annotations ──
    gsC = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=gs[0, 1],
                                           hspace=0.14, wspace=0.04)

    # Circle centres in the frames' pixel (imshow data) coordinates. Presumably
    # hand-measured on the extracted frames; they duplicate the constants in
    # render_fig4c_qualitative, so keep them in sync.
    ACT_GRIPPER_MID  = (380, 110)
    BOWL_POS         = (385, 290)
    PARA_GRIPPER_MID = (385, 230)

    # (file, title, title colour, annotation kind); row 0 = ACT, row 1 = PARA.
    cells = [
        ("4c_act_start_v2.png",  "ACT — start",  ACT_RED,    None),
        ("4c_act_mid_v2.png",    "ACT — mid",    ACT_RED,    "act"),
        ("4c_para_start_v2.png", "PARA — start", PARA_GREEN, None),
        ("4c_para_mid_v2.png",   "PARA — mid",   PARA_GREEN, "para"),
    ]
    for i, (fn, label, color, annot) in enumerate(cells):
        ax = fig.add_subplot(gsC[i // 2, i % 2])
        img = mpimg.imread(f"{EXTRACTED}/{fn}")
        ax.imshow(img)
        ax.set_xticks([])
        ax.set_yticks([])
        for s in ax.spines.values():
            s.set_visible(True)
            s.set_color("#cccccc")
            s.set_linewidth(1.0)
        ax.set_title(label, fontsize=10, fontweight="700", color=color, pad=3)

        # ACT mid frame: red ring on its gripper, green ring on the bowl.
        if annot == "act":
            ax.add_patch(patches.Circle(ACT_GRIPPER_MID, 36, facecolor="none",
                                        edgecolor=ACT_RED, linewidth=2.5, zorder=5))
            ax.add_patch(patches.Circle(BOWL_POS, 36, facecolor="none",
                                        edgecolor=PARA_GREEN, linewidth=2.5, zorder=5))
        # PARA mid frame: green ring on its gripper.
        elif annot == "para":
            ax.add_patch(patches.Circle(PARA_GRIPPER_MID, 42, facecolor="none",
                                        edgecolor=PARA_GREEN, linewidth=2.5, zorder=5))

    save(fig, "fig4_ood")


# ─────────────────────────────────────────────────────────────────────────────
# Figure 5: Video Backbone
# ─────────────────────────────────────────────────────────────────────────────

def render_fig5():
    """Fig 5: Two-stage training diagram + rollout grid comparison.

    Left (a): a box diagram, Video UNet pretraining -> joint fine-tuning, with
    two labelled output branches. Right (b): the extracted PARA and
    global-regression rollout images side by side, with success rates in the
    titles. Saved via ``save`` as ``fig5_video``.
    """
    fig = plt.figure(figsize=(12, 5.0), dpi=300)
    # No suptitle — paper caption owns it

    gs = gridspec.GridSpec(1, 2, figure=fig, left=0.03, right=0.97,
                           top=0.90, bottom=0.05, wspace=0.08,
                           width_ratios=[1.0, 1.4])

    # Sub-titles (smaller, lighter — like Fig 2)
    fig.text(0.23, 0.95, "(a) Two-Stage Training",
             fontsize=11, fontweight="600", color="#555555", ha="center")
    fig.text(0.71, 0.95, "(b) Rollout Comparison",
             fontsize=11, fontweight="600", color="#555555", ha="center")

    # ── Panel (a): two-stage diagram ──
    axA = fig.add_subplot(gs[0, 0])
    axA.set_xlim(0, 10)
    axA.set_ylim(0, 6)
    axis_off(axA)

    # Box 1: Video UNet (spans x 0.8-4.3, y 2.5-4.1)
    rounded_box(axA, 0.8, 2.5, 3.5, 1.6, "#eaf2fb", edgecolor="#3b6fa6",
                linewidth=2, rad=0.15)
    axA.text(2.55, 3.6, "Video UNet", ha="center", va="center",
             fontsize=12, fontweight="700")
    axA.text(2.55, 3.0, "(4K pretrain)", ha="center", va="center",
             fontsize=10, color=SUBTLE)

    # Arrow
    axA.annotate("", xy=(5.4, 3.3), xytext=(4.4, 3.3),
                 arrowprops=dict(
                     arrowstyle="-|>,head_length=0.6,head_width=0.4",
                     color="#1f2d3d", lw=2.5))

    # Box 2: Joint Fine-tune (spans x 5.5-9.2, y 2.5-4.1)
    rounded_box(axA, 5.5, 2.5, 3.7, 1.6, "#fff3e6", edgecolor="#d97e1f",
                linewidth=2, rad=0.15)
    axA.text(7.35, 3.6, "Joint Fine-tune", ha="center", va="center",
             fontsize=12, fontweight="700")
    axA.text(7.35, 3.0, "(3K)", ha="center", va="center",
             fontsize=10, color=SUBTLE)

    # Two output branches — labels positioned away from box
    # NOTE(review): annotate draws the arrow from xytext (the label) to xy, so
    # these heads point from the labels toward Box 2, not out of it. They also
    # sit at x=9.6, right of Box 2's right edge (9.2). Confirm both are intended
    # for "output" branches.
    axA.annotate("PARA Heatmap Head",
                 xy=(9.6, 4.7), xytext=(9.6, 5.4),
                 fontsize=10, fontweight="600", color=PARA_GREEN, ha="center",
                 arrowprops=dict(
                     arrowstyle="-|>,head_length=0.5,head_width=0.3",
                     color=PARA_GREEN, lw=1.6))
    axA.annotate("Video Generation",
                 xy=(9.6, 1.9), xytext=(9.6, 1.2),
                 fontsize=10, fontweight="600", color=SUBTLE, ha="center",
                 arrowprops=dict(
                     arrowstyle="-|>,head_length=0.5,head_width=0.3",
                     color=SUBTLE, lw=1.6))

    # ── Panel (b): rollout comparison ──
    gsB = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=gs[0, 1], wspace=0.06)
    para_img = mpimg.imread(f"{EXTRACTED}/frame_5_para.png")
    glob_img = mpimg.imread(f"{EXTRACTED}/frame_5_global.png")

    axB1 = fig.add_subplot(gsB[0, 0])
    axB1.imshow(para_img)
    axB1.set_xticks([])
    axB1.set_yticks([])
    for s in axB1.spines.values():
        s.set_visible(True)
        s.set_color("#cccccc")
        s.set_linewidth(1.0)
    # NOTE(review): 92% here vs "90% vs 0%" on render_fig1's Video Backbone
    # headline card; confirm the current number and make the two agree.
    axB1.set_title("SVD + PARA — 92%",
                   fontsize=12, fontweight="700", color=PARA_GREEN, pad=4)

    axB2 = fig.add_subplot(gsB[0, 1])
    axB2.imshow(glob_img)
    axB2.set_xticks([])
    axB2.set_yticks([])
    for s in axB2.spines.values():
        s.set_visible(True)
        s.set_color("#cccccc")
        s.set_linewidth(1.0)
    axB2.set_title("SVD + Global Regression — 0%",
                   fontsize=12, fontweight="700", color=ACT_RED, pad=4)

    save(fig, "fig5_video")


# ─────────────────────────────────────────────────────────────────────────────
# Figure 6: Point Track Pretraining (placeholder)
# ─────────────────────────────────────────────────────────────────────────────

def render_fig6():
    """Fig 6: point-track pretraining. PLACEHOLDER content throughout.

    (a) Three cartoon frames of arm-deleted training data: a tan table, a gray
    object, and a green end-effector dot that moves between frames.
    (b) A bar chart of placeholder success rates, pretrained vs from-scratch,
    for PARA and global regression. Saved via ``save`` as ``fig6_pointtrack``.
    """
    fig = plt.figure(figsize=(12, 4.5), dpi=200)
    fig.suptitle("Figure 6: Point Track Pretraining (Preliminary)",
                 fontsize=16, fontweight="700", y=0.99, color=TEXT_DARK)

    outer = gridspec.GridSpec(1, 2, figure=fig, left=0.04, right=0.96,
                              top=0.86, bottom=0.10, wspace=0.18,
                              width_ratios=[1.2, 1.0])

    # ── Panel (a): three cartoon frames ──
    frames = gridspec.GridSpecFromSubplotSpec(1, 3, subplot_spec=outer[0, 0],
                                              wspace=0.08)
    for step in range(3):
        frame_ax = fig.add_subplot(frames[0, step])
        frame_ax.set_xlim(0, 10)
        frame_ax.set_ylim(0, 7)
        axis_off(frame_ax)
        # Table, then object, then EEF dot: insertion order sets draw order.
        frame_ax.add_patch(patches.Rectangle((1, 1), 8, 5, facecolor="#e8d5b7",
                                             edgecolor=BORDER, linewidth=1))
        frame_ax.add_patch(patches.Ellipse((4.0, 3.5), 1.2, 0.6, facecolor="#888888"))
        eef_xy = (5.0 + step * 0.7, 3.8 - step * 0.3)  # drifts right/down each frame
        frame_ax.add_patch(patches.Circle(eef_xy, 0.35, facecolor=PARA_GREEN,
                                          edgecolor="white", linewidth=1.5))
        frame_ax.set_title(f"t = {step*5}", fontsize=10, color=SUBTLE, pad=2)

    fig.text(0.295, 0.92, "(a) Arm-Deleted Training Data",
             fontsize=12, fontweight="700", ha="center")
    fig.text(0.295, 0.04,
             "(robot arm invisible — only EEF position retained as supervision)",
             fontsize=9, color=SUBTLE, ha="center", style="italic")

    # ── Panel (b): placeholder bar chart ──
    bar_ax = fig.add_subplot(outer[0, 1])
    # (tick label, value %, colour, alpha): solid = pretrained, faded = scratch.
    entries = [
        ("PARA\npretrain",   70, PARA_GREEN, 1.0),
        ("PARA\nscratch",    30, PARA_GREEN, 0.4),
        ("Global\npretrain", 20, ACT_RED,    1.0),
        ("Global\nscratch",  15, ACT_RED,    0.4),
    ]
    tick_labels = [entry[0] for entry in entries]
    heights = [entry[1] for entry in entries]
    bar_colors = [entry[2] for entry in entries]
    opacities = [entry[3] for entry in entries]
    positions = range(len(tick_labels))

    bars = bar_ax.bar(positions, heights, color=bar_colors, width=0.65,
                      edgecolor="white", linewidth=1.2)
    for bar, opacity in zip(bars, opacities):
        bar.set_alpha(opacity)
    for bar, height in zip(bars, heights):
        top_centre = (bar.get_x() + bar.get_width()/2, bar.get_height())
        bar_ax.annotate(f"{height}%", xy=top_centre,
                        xytext=(0, 3), textcoords="offset points",
                        ha="center", va="bottom", fontsize=11, fontweight="700")

    bar_ax.set_xticks(positions)
    bar_ax.set_xticklabels(tick_labels, fontsize=9.5)
    bar_ax.set_ylabel("Success Rate (%)", fontsize=10)
    bar_ax.set_ylim(0, 100)
    bar_ax.set_yticks([0, 25, 50, 75, 100])
    bar_ax.tick_params(labelsize=9, colors="#444444")
    for side in ("top", "right"):
        bar_ax.spines[side].set_visible(False)
    for side in ("left", "bottom"):
        bar_ax.spines[side].set_color("#888888")
    bar_ax.set_title("(b) Pretrain → Fine-tune (preliminary)",
                     fontsize=12, fontweight="700", pad=6)
    fig.text(0.74, 0.04,
             "Placeholder values — final results pending from backbones agent",
             fontsize=9, color=NEUTRAL, ha="center", fontweight="500")

    save(fig, "fig6_pointtrack")


# ─────────────────────────────────────────────────────────────────────────────
# Figure 4 standalone panels (4a, 4b, 4c) — re-rendered for consistency
# ─────────────────────────────────────────────────────────────────────────────

def render_fig4a_spatial():
    """Standalone Fig 4a: distribution plot + bar chart, side-by-side.

    Left: the pre-rendered left/right train-test distribution image from
    DASHBOARD_MEDIA. Right: success-rate bars, PARA 54% vs ACT 1%.
    Written as ``fig4a_spatial.png`` at 300 dpi to every OUT_DIRS entry.
    The figure is closed even if a write fails.
    """
    fig = plt.figure(figsize=(8, 4), dpi=300)
    gs = gridspec.GridSpec(1, 2, figure=fig, width_ratios=[1.4, 1.0],
                           wspace=0.30, top=0.86, bottom=0.14, left=0.06, right=0.96)
    ax1 = fig.add_subplot(gs[0, 0])
    img = mpimg.imread(f"{DASHBOARD_MEDIA}/exp3_leftright_distribution.png")
    ax1.imshow(img)
    axis_off(ax1)

    ax2 = fig.add_subplot(gs[0, 1])
    # NOTE: same values as panel (a) of render_fig4(); keep both in sync.
    rates = [54, 1]
    bars = ax2.bar(["PARA", "ACT"], rates, color=[PARA_GREEN, ACT_RED], width=0.55)
    for b, val in zip(bars, rates):
        ax2.annotate(f"{val}%", xy=(b.get_x() + b.get_width()/2, b.get_height()),
                     xytext=(0, 4), textcoords="offset points",
                     ha="center", va="bottom", fontsize=13, fontweight="700")
    ax2.set_ylim(0, 70)
    ax2.set_ylabel("Success Rate (%)", fontsize=11)
    ax2.tick_params(labelsize=11, colors="#444444")
    ax2.spines["top"].set_visible(False)
    ax2.spines["right"].set_visible(False)
    ax2.spines["left"].set_color("#888888")
    ax2.spines["bottom"].set_color("#888888")
    # Subtle gridlines like 4b
    for y in (25, 50):
        ax2.axhline(y, color="#888888", alpha=0.3, linewidth=0.5, zorder=0)

    fig.suptitle("Fig 4a: Spatial Extrapolation (train left, test right)",
                 fontsize=13, fontweight="700", y=0.98, color=TEXT_DARK)
    try:
        for out_dir in OUT_DIRS:
            fig.savefig(f"{out_dir}/fig4a_spatial.png", format="png",
                        bbox_inches="tight", dpi=300, facecolor="white", edgecolor="none")
    finally:
        # Release the figure even when a write fails.
        plt.close(fig)
    # Report the real destination count rather than a hard-coded "2".
    print(f"  wrote fig4a_spatial.png to {len(OUT_DIRS)} locations (dpi=300)")


def render_fig4b_standalone():
    """Standalone Fig 4b: success rate vs camera elevation θ, PARA vs ACT.

    θ = 0° is the training viewpoint. Written as ``fig4b_pertheta.png``
    (300 dpi) and ``fig4b_pertheta.svg`` to every OUT_DIRS entry. The figure
    is closed even if a write fails.
    """
    fig, ax = plt.subplots(figsize=(6, 3.5), dpi=300)
    # NOTE: same data as panel (b) of render_fig4(); keep both in sync.
    theta = [0, 3.6, 7.1, 10.7, 14.3, 17.9, 21.4, 25]
    para  = [88, 79, 62, 63, 62, 62, 33, 38]
    act   = [67, 54, 42, 17, 12,  0,  0,  0]

    # Subtle horizontal gridlines at 25/50/75 (zorder under data)
    for y in (25, 50, 75):
        ax.axhline(y, color="#888888", alpha=0.3, linewidth=0.5, zorder=1)

    ax.plot(theta, para, color=PARA_GREEN, linewidth=2, marker="o", markersize=8,
            markerfacecolor=PARA_GREEN, markeredgecolor="white", markeredgewidth=1.2,
            label="PARA", zorder=5)
    ax.plot(theta, act, color=ACT_RED, linewidth=2, marker="s", markersize=8,
            markerfacecolor=ACT_RED, markeredgecolor="white", markeredgewidth=1.2,
            linestyle="--", label="ACT", zorder=4)

    ax.set_xlabel("Camera Elevation Angle θ (degrees)", fontsize=12)
    ax.set_ylabel("Success Rate (%)", fontsize=12)
    ax.set_xlim(-1, 26)
    ax.set_ylim(0, 100)
    ax.set_xticks(theta)
    # Custom x-tick labels with "(train)" annotation on the 0° tick
    ax.set_xticklabels(["0°\n(train)", "3.6", "7.1", "10.7",
                        "14.3", "17.9", "21.4", "25"])
    ax.set_yticks([0, 25, 50, 75, 100])
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["left"].set_color("#888888")
    ax.spines["bottom"].set_color("#888888")
    ax.tick_params(colors="#444444", labelsize=11)
    leg = ax.legend(loc="upper right", frameon=True, fontsize=10,
                    edgecolor="#cccccc", facecolor="white")
    leg.get_frame().set_linewidth(0.8)

    # Lay out this figure explicitly instead of relying on pyplot's current figure.
    fig.tight_layout()
    # Save with white background (no frame), high dpi
    try:
        for out_dir in OUT_DIRS:
            fig.savefig(f"{out_dir}/fig4b_pertheta.png", format="png",
                        bbox_inches="tight", dpi=300, facecolor="white",
                        edgecolor="none")
            fig.savefig(f"{out_dir}/fig4b_pertheta.svg", format="svg",
                        bbox_inches="tight", facecolor="white", edgecolor="none")
    finally:
        plt.close(fig)
    print(f"  wrote fig4b_pertheta.png + .svg to {len(OUT_DIRS)} locations (dpi=300)")


def render_fig4c_qualitative():
    """Fig 4c: 2x2 qualitative comparison with annotations highlighting divergence.

    Frames at t=0 (start) and t=12 (mid-rollout, clearest divergence).
    Annotations:
      - ACT-mid: red circle on its gripper (wrong location), green circle on bowl
      - PARA-mid: green circle on its gripper (correct, at bowl)
    Written as ``fig4c_qualitative.png`` at 300 dpi to every OUT_DIRS entry.
    The figure is closed even if a write fails.
    """
    fig = plt.figure(figsize=(7.5, 5.5), dpi=300)
    gs = gridspec.GridSpec(2, 2, figure=fig, hspace=0.14, wspace=0.04,
                           top=0.95, bottom=0.04, left=0.04, right=0.96)

    # Approximate gripper / bowl positions in the cropped frame coordinate space
    # (frames are ~448 wide, ~370 tall after cropping)
    # ACT mid: gripper hovers upper-right, NOT touching bowl
    # PARA mid: gripper extends down to bowl
    # Bowl/plate: bottom-right area of all frames
    # NOTE: duplicated in panel (c) of render_fig4(); keep both in sync.
    ACT_GRIPPER_MID  = (380, 110)   # where ACT gripper actually is (upper right)
    BOWL_POS         = (385, 290)   # where the bowl actually is (lower right)
    PARA_GRIPPER_MID = (385, 230)   # where PARA gripper is (close to bowl)

    cells = [
        ("4c_act_start_v2.png",  "ACT — start (t=0)",   ACT_RED,    None,  None),
        ("4c_act_mid_v2.png",    "ACT — mid (t=12)",    ACT_RED,    "act", None),
        ("4c_para_start_v2.png", "PARA — start (t=0)",  PARA_GREEN, None,  None),
        ("4c_para_mid_v2.png",   "PARA — mid (t=12)",   PARA_GREEN, None,  "para"),
    ]

    for i, (fn, label, color, annot_act, annot_para) in enumerate(cells):
        ax = fig.add_subplot(gs[i // 2, i % 2])
        img = mpimg.imread(f"{EXTRACTED}/{fn}")
        ax.imshow(img)
        ax.set_xticks([])
        ax.set_yticks([])
        for s in ax.spines.values():
            s.set_visible(True)
            s.set_color("#cccccc")
            s.set_linewidth(1.0)
        ax.set_title(label, fontsize=12, fontweight="700", color=color, pad=4)

        # ACT-mid annotations: red on wrong gripper, green on correct bowl
        if annot_act:
            # Red circle on ACT gripper (wrong location)
            c1 = patches.Circle(ACT_GRIPPER_MID, 36, facecolor="none",
                                edgecolor=ACT_RED, linewidth=3, zorder=5)
            ax.add_patch(c1)
            ax.annotate("ACT gripper\n(wrong)", xy=ACT_GRIPPER_MID,
                        xytext=(180, 70), fontsize=9, fontweight="600",
                        color=ACT_RED, ha="center",
                        arrowprops=dict(arrowstyle="-", color=ACT_RED, lw=1.2))
            # Green circle on bowl (correct location)
            c2 = patches.Circle(BOWL_POS, 36, facecolor="none",
                                edgecolor=PARA_GREEN, linewidth=3, zorder=5)
            ax.add_patch(c2)
            ax.annotate("bowl\n(target)", xy=BOWL_POS,
                        xytext=(180, 290), fontsize=9, fontweight="600",
                        color=PARA_GREEN, ha="center",
                        arrowprops=dict(arrowstyle="-", color=PARA_GREEN, lw=1.2))

        # PARA-mid annotation: green circle around gripper which IS at bowl
        if annot_para:
            c = patches.Circle(PARA_GRIPPER_MID, 42, facecolor="none",
                               edgecolor=PARA_GREEN, linewidth=3, zorder=5)
            ax.add_patch(c)
            ax.annotate("PARA gripper\nreaches bowl", xy=PARA_GRIPPER_MID,
                        xytext=(180, 200), fontsize=9, fontweight="600",
                        color=PARA_GREEN, ha="center",
                        arrowprops=dict(arrowstyle="-", color=PARA_GREEN, lw=1.2))

    # No suptitle — paper caption owns it
    try:
        for out_dir in OUT_DIRS:
            fig.savefig(f"{out_dir}/fig4c_qualitative.png", format="png",
                        bbox_inches="tight", dpi=300, facecolor="white", edgecolor="none")
    finally:
        # Release the figure even when a write fails.
        plt.close(fig)
    print(f"  wrote fig4c_qualitative.png to {len(OUT_DIRS)} locations (dpi=300)")


# Key accepted by ``--figs`` -> function that renders and saves that figure.
BUILDERS = {
    "1":  render_fig1,
    "2":  render_fig2,
    "3":  render_fig3,
    "4":  render_fig4,
    "4a": render_fig4a_spatial,
    "4b": render_fig4b_standalone,
    "4c": render_fig4c_qualitative,
    "5":  render_fig5,
    "6":  render_fig6,
}


def main():
    """Parse ``--figs`` and render the requested figures in the given order.

    ``--figs all`` (the default) renders 1, 2, 3, 4, 4b, 5 and 6; the
    standalone 4a and 4c panels must be requested explicitly. Otherwise pass a
    comma-separated list of BUILDERS keys, e.g. ``--figs 2,4a``. Every key is
    validated before anything is rendered. The process exits with status 1 on
    an unknown key or an empty selection.
    """
    parser = argparse.ArgumentParser(
        description="Render the paper figures to every directory in OUT_DIRS.")
    parser.add_argument(
        "--figs", default="all",
        help="'all' (1,2,3,4,4b,5,6) or a comma-separated list of keys from: "
             + ", ".join(BUILDERS))
    args = parser.parse_args()
    if args.figs == "all":
        keys = ["1", "2", "3", "4", "4b", "5", "6"]
    else:
        # Ignore blank entries so a trailing comma ("1,2,") is harmless.
        keys = [k.strip() for k in args.figs.split(",") if k.strip()]

    # Validate up front so a typo can't leave some figures re-rendered and
    # others stale.
    unknown = [k for k in keys if k not in BUILDERS]
    if unknown:
        print(f"unknown: {', '.join(unknown)}", file=sys.stderr)
        sys.exit(1)
    if not keys:
        print("no figures requested", file=sys.stderr)
        sys.exit(1)

    for k in keys:
        print(f"=== render fig {k} ===")
        BUILDERS[k]()


if __name__ == "__main__":
    main()
