"""Build paper figures in the para_paper_figures Penpot file.

Each figure = one frame on the page. Uses a hybrid of:
  - Native rect/text shapes (editable in Penpot)
  - Image media uploads + image shape placement (for charts, photos, video frames)

Run with --figs to specify which figures to build:
    python3 build_paper_figures.py --figs 4b
    python3 build_paper_figures.py --figs all
    python3 build_paper_figures.py --figs 1,2,3
"""

import argparse
import json
import subprocess
import sys
import uuid
from pathlib import Path

# ── Penpot identifiers ──
# Target Penpot file and page: every change below is sent to FILE_ID and
# every shape is added to PAGE_ID.
FILE_ID  = "5688a1b5-92ba-8001-8007-dc49a465085a"
PAGE_ID  = "5688a1b5-92ba-8001-8007-dc49a465085b"
# All-zero UUID used as parentId/frameId for the top-level figure frames.
ROOT     = "00000000-0000-0000-0000-000000000000"
# Cookie file handed to curl via "-b" on every RPC call; presumably holds a
# logged-in Penpot session (refresh it if requests start failing).
COOKIE   = "/tmp/penpot_user_cookies.txt"
BASE_URL = "https://penpot.omidlab.net"

# ── Color palette ──
# PARA results use PARA_GREEN; ACT / global-regression baselines use ACT_RED.
PARA_GREEN = "#16653a"
ACT_RED    = "#a12029"
NEUTRAL    = "#71717a"
TEXT_DARK  = "#1f2d3d"   # default text color
SUBTLE     = "#5a5a5a"   # captions and secondary labels
BG_LIGHT   = "#f8f9fa"   # default frame fill
BORDER     = "#d4d4d8"   # card / placeholder outlines

# ── Frame layout ──
FRAME_W = 1200
FRAME_H_DEFAULT = 600    # not referenced by the builders in this file section
FRAME_GAP = 80   # vertical gap between figure frames
FRAME_X = 100    # left margin

# Last-known file revision number. Fetched lazily by commit_changes() on the
# first commit and advanced after each one, so consecutive commits can be
# chained without re-querying get-file.
_revn_cache = None

# ─────────────────────────────────────────────────────────────────────────────
# Penpot API helpers
# ─────────────────────────────────────────────────────────────────────────────

def get_revn():
    """Fetch the current revision number (``revn``) of FILE_ID from Penpot.

    Returns:
        The ``revn`` value reported by the ``get-file`` RPC command.

    Raises:
        RuntimeError: if curl fails, the response is not JSON, or the
            response has no ``revn`` key (e.g. an error object returned
            because the session cookie is no longer valid).
    """
    r = subprocess.run([
        "curl", "-s", "-b", COOKIE,
        f"{BASE_URL}/api/rpc/command/get-file?id={FILE_ID}",
        "-H", "Accept: application/json",
    ], capture_output=True, text=True)
    if r.returncode != 0:
        raise RuntimeError(
            f"get-file request failed (curl exit {r.returncode}): {r.stderr[:300]}")
    try:
        d = json.loads(r.stdout)
    except json.JSONDecodeError:
        raise RuntimeError(f"non-JSON response: {r.stdout[:500]}") from None
    if not isinstance(d, dict) or "revn" not in d:
        raise RuntimeError(f"get-file returned no revn: {r.stdout[:500]}")
    return d["revn"]


def upload_media(file_path, name):
    """Upload an image to the Penpot file as a local media object.

    Args:
        file_path: path of the image on disk.
        name: display name given to the media object in Penpot.

    Returns:
        (media_id, width, height, mtype) as reported by Penpot.

    Raises:
        FileNotFoundError: if ``file_path`` is not an existing file. This is
            checked up front because curl would otherwise produce no JSON and
            the failure would surface as an opaque decode error.
        RuntimeError: if curl fails or the response is not a media object.
    """
    path = Path(file_path)
    print(f"  Uploading {name} ({path.name})...")
    if not path.is_file():
        raise FileNotFoundError(f"media file not found: {path}")
    r = subprocess.run([
        "curl", "-s", "-b", COOKIE,
        "-X", "POST", f"{BASE_URL}/api/rpc/command/upload-file-media-object",
        "-H", "Accept: application/json",
        "-F", f"file-id={FILE_ID}",
        "-F", "is-local=true",
        "-F", f"name={name}",
        "-F", f"content=@{file_path};type={_mtype_for(file_path)}",
    ], capture_output=True, text=True)
    if r.returncode != 0:
        raise RuntimeError(
            f"upload of {path.name} failed (curl exit {r.returncode}): {r.stderr[:300]}")
    try:
        d = json.loads(r.stdout)
    except json.JSONDecodeError:
        raise RuntimeError(f"upload failed (non-JSON response): {r.stdout[:300]}") from None
    if not isinstance(d, dict) or "id" not in d:
        raise RuntimeError(f"upload failed: {r.stdout[:300]}")
    return d["id"], d["width"], d["height"], d["mtype"]


def _mtype_for(path):
    p = str(path).lower()
    if p.endswith(".png"):  return "image/png"
    if p.endswith(".jpg") or p.endswith(".jpeg"): return "image/jpeg"
    if p.endswith(".svg"):  return "image/svg+xml"
    return "application/octet-stream"


def commit_changes(changes):
    """POST ``update-file`` with the given changes and advance the revn cache.

    Args:
        changes: list of Penpot change dicts (e.g. built by add_obj_change).

    Returns:
        The decoded JSON response.

    Raises:
        RuntimeError: if curl fails, the response is not JSON, or Penpot
            returns an error object (validation or otherwise).
    """
    global _revn_cache
    if _revn_cache is None:
        _revn_cache = get_revn()
    payload = {"id": FILE_ID, "sessionId": str(uuid.uuid4()),
               "revn": _revn_cache, "vern": 0, "changes": changes}
    # Stream the body on stdin ("--data-binary @-") rather than passing it as
    # a "-d" argument. A figure's change list can be large, and Linux caps a
    # single argv string at 128 KiB (MAX_ARG_STRLEN), which makes exec fail.
    r = subprocess.run([
        "curl", "-s", "-b", COOKIE,
        "-X", "POST", f"{BASE_URL}/api/rpc/command/update-file",
        "-H", "Content-Type: application/json", "-H", "Accept: application/json",
        "--data-binary", "@-",
    ], input=json.dumps(payload), capture_output=True, text=True)
    if r.returncode != 0:
        raise RuntimeError(
            f"update-file request failed (curl exit {r.returncode}): {r.stderr[:300]}")
    try:
        d = json.loads(r.stdout)
    except json.JSONDecodeError:
        raise RuntimeError(f"non-JSON response: {r.stdout[:500]}") from None
    if isinstance(d, dict) and d.get("type") == "validation":
        raise RuntimeError(f"validation error: {r.stdout[:600]}")
    if isinstance(d, dict) and "type" in d and "code" in d:
        # Other Penpot error kinds (e.g. authentication, not-found) would
        # otherwise be mistaken for success and advance the revn cache.
        # Assumes Penpot RPC errors carry both "type" and "code" keys.
        raise RuntimeError(f"update-file error ({d['type']}): {r.stdout[:600]}")
    # NOTE(review): when the response carries "revn", the next commit uses
    # revn + 1. Confirm whether Penpot reports the pre- or post-commit revision.
    new_revn = d.get("revn", _revn_cache) if isinstance(d, dict) else _revn_cache
    _revn_cache = new_revn + 1   # next call will be a new revn
    print(f"  Committed {len(changes)} changes (revn -> {_revn_cache})")
    return d


# ─────────────────────────────────────────────────────────────────────────────
# Shape constructors
# ─────────────────────────────────────────────────────────────────────────────

def _geom(x, y, w, h):
    return {
        "x": x, "y": y, "width": w, "height": h, "rotation": 0,
        "selrect": {"x": x, "y": y, "x1": x, "y1": y, "x2": x+w, "y2": y+h, "width": w, "height": h},
        "points": [{"x": x, "y": y}, {"x": x+w, "y": y},
                   {"x": x+w, "y": y+h}, {"x": x, "y": y+h}],
        "transform":        {"a":1,"b":0,"c":0,"d":1,"e":0,"f":0},
        "transformInverse": {"a":1,"b":0,"c":0,"d":1,"e":0,"f":0},
    }


def frame(name, x, y, w, h, fill=BG_LIGHT, sid=None):
    """Build a top-level frame shape (parented to ROOT) with a solid fill.

    Returns:
        (sid, obj): the shape id (``sid``, or a fresh UUID if falsy) and the
        shape dict, ready for add_obj_change.
    """
    if not sid:
        sid = str(uuid.uuid4())
    obj = {"id": sid, "type": "frame", "name": name}
    obj.update(_geom(x, y, w, h))
    obj.update(
        parentId=ROOT,
        frameId=ROOT,
        fills=[{"fillColor": fill, "fillOpacity": 1}],
        shapes=[],
    )
    return sid, obj


def rect(name, x, y, w, h, parent_frame, fill=None, stroke=None, rx=0, sid=None):
    """Build a rectangle shape inside ``parent_frame``.

    Args:
        fill: fill color, or a falsy value for no fill.
        stroke: optional ``(width, color)`` tuple for a solid, center-aligned stroke.
        rx: corner radius applied to both rx and ry; 0 keeps square corners.
        sid: shape id to reuse; a fresh UUID is generated when falsy.

    Returns:
        (sid, obj) for use with add_obj_change.
    """
    sid = sid if sid else str(uuid.uuid4())
    obj = {
        "id": sid, "type": "rect", "name": name, **_geom(x, y, w, h),
        "parentId": parent_frame, "frameId": parent_frame,
        "fills": [{"fillColor": fill, "fillOpacity": 1}] if fill else [],
    }
    if stroke:
        stroke_width, stroke_color = stroke
        obj["strokes"] = [{
            "strokeColor": stroke_color, "strokeOpacity": 1, "strokeStyle": "solid",
            "strokeWidth": stroke_width, "strokeAlignment": "center",
        }]
    if rx:
        obj.update(rx=rx, ry=rx)
    return sid, obj


def text(name, x, y, w, h, content, parent_frame, size=14, weight="400",
         color=TEXT_DARK, align="left", sid=None):
    """Build a fixed-size, single-run text shape inside ``parent_frame``.

    The content tree is root -> paragraph-set -> paragraph -> one text leaf,
    set in Source Sans Pro at ``size`` (sent as a string) and ``weight``.

    Returns:
        (sid, obj) for use with add_obj_change.
    """
    sid = sid or str(uuid.uuid4())
    leaf = {
        "text": content,
        "fontId": "sourcesanspro",
        "fontFamily": "sourcesanspro",
        "fontSize": str(size),
        "fontWeight": weight,
        "fills": [{"fillColor": color, "fillOpacity": 1}],
    }
    paragraph = {"type": "paragraph", "textAlign": align, "children": [leaf]}
    paragraph_set = {"type": "paragraph-set", "children": [paragraph]}
    obj = {
        "id": sid, "type": "text", "name": name, **_geom(x, y, w, h),
        "parentId": parent_frame, "frameId": parent_frame,
        "growType": "fixed",
        "content": {"type": "root", "children": [paragraph_set]},
    }
    return sid, obj


def image(name, x, y, w, h, media_id, mw, mh, parent_frame, mtype="image/png", sid=None):
    """Build an image shape that displays uploaded media ``media_id``.

    Args:
        x, y, w, h: placement and displayed size in the canvas.
        media_id, mw, mh, mtype: media id, intrinsic width/height and MIME
            type, as returned by upload_media.

    Returns:
        (sid, obj) for use with add_obj_change.
    """
    shape_id = sid if sid else str(uuid.uuid4())
    metadata = {"id": media_id, "width": mw, "height": mh, "mtype": mtype}
    obj = {"id": shape_id, "type": "image", "name": name}
    obj.update(_geom(x, y, w, h))
    obj["parentId"] = parent_frame
    obj["frameId"] = parent_frame
    obj["metadata"] = metadata
    return shape_id, obj


def add_obj_change(obj, parent):
    """Wrap a shape dict in an ``add-obj`` change on PAGE_ID under ``parent``."""
    return dict(
        type="add-obj",
        id=obj["id"],
        pageId=PAGE_ID,
        frameId=parent,
        parentId=parent,
        obj=obj,
    )


# ─────────────────────────────────────────────────────────────────────────────
# Figure builders
# ─────────────────────────────────────────────────────────────────────────────

def fig_y(figure_index, custom_heights=None):
    """Return the top y-coordinate of figure ``figure_index`` (1..6).

    Figures are stacked vertically from y=100. Each earlier figure takes its
    height from ``custom_heights`` (falling back to 600 for missing entries)
    plus FRAME_GAP. When ``custom_heights`` is None, the built-in per-figure
    heights are used.
    """
    heights = (custom_heights if custom_heights is not None
               else {1: 600, 2: 540, 3: 700, 4: 550, 5: 480, 6: 480})
    return 100 + sum(heights.get(n, 600) + FRAME_GAP for n in range(1, figure_index))


# ─────────────────────────────────────────────────────────────────────────────
# Figure 1: Overview teaser
# ─────────────────────────────────────────────────────────────────────────────

def build_fig1_overview():
    """Build Figure 1, the overview teaser, as a single frame.

    Three side-by-side panels made entirely of native rect/text shapes (no
    media uploads):
      (a) two box-and-connector rows contrasting global regression (ACT) with PARA;
      (b) three "where PARA helps" vignettes, each with a colored circle icon;
      (c) three headline-result cards.
    All shapes are committed in a single update-file call, in creation order.
    """
    fy = fig_y(1)
    fh = 600
    print(f"\n=== Fig 1: Overview teaser ===")
    fid, fobj = frame("Fig1_Overview", FRAME_X, fy, FRAME_W, fh)
    changes = [add_obj_change(fobj, ROOT)]

    # Title
    _, t = text("fig1-title", FRAME_X+30, fy+15, FRAME_W-60, 32,
                "Figure 1: PARA — Pixel-Aligned Robot Actions",
                fid, size=22, weight="700")
    changes.append(add_obj_change(t, fid))

    # ── 3 panels: left (diagram), middle (vignettes), right (numbers) ──
    panel_w = (FRAME_W - 90) // 3   # 30 + p1 + 15 + p2 + 15 + p3 + 30
    panel_x = [FRAME_X+30, FRAME_X+30 + panel_w + 15,
               FRAME_X+30 + 2*(panel_w + 15)]
    panel_y = fy + 65
    panel_h = fh - 100

    # ── LEFT PANEL: Comparison diagram ──
    px = panel_x[0]
    _, lbl = text("p1-label", px, panel_y, panel_w, 22,
                  "(a) Architecture Comparison", fid, size=13, weight="700")
    changes.append(add_obj_change(lbl, fid))

    # Global Regression row
    grow_y = panel_y + 40
    _, grow_lbl = text("p1-grow-lbl", px, grow_y, panel_w, 18,
                       "Global Regression (ACT)", fid, size=11, weight="600", color=ACT_RED)
    changes.append(add_obj_change(grow_lbl, fid))

    box_w, box_h = 60, 38
    # (label, fill, stroke) for each pipeline stage, left to right
    boxes_g = [("Image", "#eaf2fb", "#3b6fa6"),
               ("DINO", "#fff3e6", "#d97e1f"),
               ("CLS", "#f5e6fb", "#c2415a"),
               ("MLP", "#fff", "#71717a"),
               ("(x,y,z)", "#e8f5ec", ACT_RED)]
    # Spread the boxes evenly so the last one ends at the panel's right edge
    spacing = (panel_w - box_w * len(boxes_g)) // (len(boxes_g) - 1)
    for i, (label, fill, stroke) in enumerate(boxes_g):
        bx = px + i * (box_w + spacing)
        _, b = rect(f"p1-grow-box-{i}", bx, grow_y + 22, box_w, box_h,
                    fid, fill=fill, stroke=(1.5, stroke), rx=6)
        _, bt = text(f"p1-grow-box-{i}-t", bx, grow_y + 22 + (box_h-14)//2,
                     box_w, 16, label, fid, size=10, weight="600", align="center")
        changes.extend([add_obj_change(b, fid), add_obj_change(bt, fid)])
        # Arrow between boxes: a 2 px horizontal bar (no arrowhead)
        if i < len(boxes_g) - 1:
            ax = bx + box_w + 2
            ay = grow_y + 22 + box_h // 2 - 1
            _, a = rect(f"p1-grow-arr-{i}", ax, ay, spacing - 6, 2, fid, fill=NEUTRAL)
            changes.append(add_obj_change(a, fid))

    # PARA row
    prow_y = grow_y + 95
    _, prow_lbl = text("p1-prow-lbl", px, prow_y, panel_w, 18,
                       "PARA", fid, size=11, weight="600", color=PARA_GREEN)
    changes.append(add_obj_change(prow_lbl, fid))

    boxes_p = [("Image", "#eaf2fb", "#3b6fa6"),
               ("DINO", "#fff3e6", "#d97e1f"),
               ("Heatmap\nVolume", "#fde9ec", "#c2415a"),
               ("Argmax", "#fff", "#71717a"),
               ("3D Pt", "#e8f5ec", PARA_GREEN)]
    box_w2 = 60
    spacing2 = (panel_w - box_w2 * len(boxes_p)) // (len(boxes_p) - 1)
    for i, (label, fill, stroke) in enumerate(boxes_p):
        bx = px + i * (box_w2 + spacing2)
        _, b = rect(f"p1-para-box-{i}", bx, prow_y + 22, box_w2, box_h,
                    fid, fill=fill, stroke=(1.5, stroke), rx=6)
        # Only the first line of a multi-line label is drawn (e.g. "Heatmap"
        # from "Heatmap\nVolume"), at the same size as the other box labels.
        first_line = label.split("\n")[0] if "\n" in label else label
        _, bt = text(f"p1-para-box-{i}-t", bx, prow_y + 22 + (box_h-14)//2,
                     box_w2, 16, first_line, fid, size=10, weight="600", align="center")
        changes.extend([add_obj_change(b, fid), add_obj_change(bt, fid)])
        if i < len(boxes_p) - 1:
            ax = bx + box_w2 + 2
            ay = prow_y + 22 + box_h // 2 - 1
            _, a = rect(f"p1-para-arr-{i}", ax, ay, spacing2 - 6, 2, fid, fill=NEUTRAL)
            changes.append(add_obj_change(a, fid))

    # Punchline text below
    _, punch = text("p1-punch", px, prow_y + 100, panel_w, 32,
                    "CLS collapses spatial info; PARA preserves it.",
                    fid, size=11, weight="500", color=SUBTLE, align="left")
    changes.append(add_obj_change(punch, fid))

    # ── MIDDLE PANEL: 3 vignettes ──
    px = panel_x[1]
    _, lbl2 = text("p2-label", px, panel_y, panel_w, 22,
                   "(b) Where PARA Helps", fid, size=13, weight="700")
    changes.append(add_obj_change(lbl2, fid))

    # (title, subtitle, icon color)
    vignettes = [
        ("OOD Generalization",   "Shifted camera & objects",   PARA_GREEN),
        ("Video Backbone",       "SVD video → robot",          "#3b6fa6"),
        ("Cross-Embodiment",     "Arm-deleted point tracks",   "#5a3da2"),
    ]
    vh = (panel_h - 50) // 3
    for i, (title, sub, color) in enumerate(vignettes):
        vy = panel_y + 35 + i * (vh + 8)
        # Icon box (placeholder colored circle: 36x36 rect with radius 18)
        _, ic = rect(f"p2-vig-{i}-icon", px+10, vy+12, 36, 36, fid,
                     fill=color, rx=18)
        _, tt = text(f"p2-vig-{i}-title", px+58, vy+10, panel_w-70, 22,
                     title, fid, size=12, weight="700")
        _, st = text(f"p2-vig-{i}-sub", px+58, vy+30, panel_w-70, 18,
                     sub, fid, size=10, weight="400", color=SUBTLE)
        for x in (ic, tt, st):
            changes.append(add_obj_change(x, fid))

    # ── RIGHT PANEL: 3 headline numbers ──
    px = panel_x[2]
    _, lbl3 = text("p3-label", px, panel_y, panel_w, 22,
                   "(c) Headline Results", fid, size=13, weight="700")
    changes.append(add_obj_change(lbl3, fid))

    # (big number, subtitle, number color, second color)
    headlines = [
        ("97% vs 9%",  "Real Robot (20 demos)",  PARA_GREEN, ACT_RED),
        ("90% vs 0%",  "Video Backbone",          PARA_GREEN, ACT_RED),
        ("[TBD]",      "Point Track Pretraining", NEUTRAL, NEUTRAL),
    ]
    hh = (panel_h - 50) // 3
    for i, (big, sub, c1, c2) in enumerate(headlines):
        # NOTE: c2 is unpacked but currently unused; the big number is drawn
        # entirely in c1.
        hy = panel_y + 35 + i * (hh + 8)
        # Background card
        _, card = rect(f"p3-card-{i}", px, hy, panel_w, hh, fid,
                       fill="#ffffff", stroke=(1.5, BORDER), rx=10)
        _, big_t = text(f"p3-big-{i}", px, hy+12, panel_w, 36,
                        big, fid, size=24, weight="800", color=c1, align="center")
        _, sub_t = text(f"p3-sub-{i}", px, hy+50, panel_w, 22,
                        sub, fid, size=11, weight="500", color=SUBTLE, align="center")
        for x in (card, big_t, sub_t):
            changes.append(add_obj_change(x, fid))

    commit_changes(changes)
    print(f"  Done. Frame: {fid}")


# ─────────────────────────────────────────────────────────────────────────────
# Figure 2: Method details
# ─────────────────────────────────────────────────────────────────────────────

def build_fig2_method():
    """Build Figure 2 (method pipeline) as a single frame.

    (a) Four uploaded method-stage frames laid out left to right, each with a
        caption, and 2 px connector bars between them.
    (b) The uploaded height-vs-depth illustration, 360 px wide (or capped at
        380 px tall), with a caption underneath.
    Media are uploaded while the shapes are built; all shapes are committed
    in a single update-file call at the end.
    """
    fy = fig_y(2)
    fh = 540
    print(f"\n=== Fig 2: Method Details ===")
    fid, fobj = frame("Fig2_Method", FRAME_X, fy, FRAME_W, fh)
    changes = [add_obj_change(fobj, ROOT)]

    _, t = text("fig2-title", FRAME_X+30, fy+15, FRAME_W-60, 30,
                "Figure 2: PARA Method Pipeline",
                fid, size=20, weight="700")
    changes.append(add_obj_change(t, fid))

    # Sub-labels
    _, la = text("fig2-a-label", FRAME_X+30, fy+60, 760, 22,
                 "(a) Visualization of the PARA inference pipeline",
                 fid, size=13, weight="700")
    _, lb = text("fig2-b-label", FRAME_X+820, fy+60, 360, 22,
                 "(b) Height vs Depth", fid, size=13, weight="700")
    changes.extend([add_obj_change(la, fid), add_obj_change(lb, fid)])

    # ── Panel (a): horizontal strip of 4 method frames ──
    strip_w = 760
    # NOTE(review): strip_h is not used below; each image's height comes
    # from its own aspect ratio instead.
    strip_h = 360
    sx = FRAME_X + 30
    sy = fy + 95

    # Upload all 4 frames
    method_frames = []
    for i, fname in enumerate([
        "frame_2_stage1.png", "frame_2_stage2.png",
        "frame_2_stage3.png", "frame_2_stage4.png"
    ]):
        mid, mw, mh, mtype = upload_media(
            f"/data/cameron/penpot/figures/extracted/{fname}",
            f"Method Stage {i+1}")
        method_frames.append((mid, mw, mh, mtype))

    img_w = (strip_w - 30) // 4   # 4 images, 3 gaps of 10
    gap = 10
    captions = ["Camera Frustum", "Heatmap Volume", "Argmax → 3D", "Robot at Target"]
    for i, ((mid, mw, mh, mtype), cap) in enumerate(zip(method_frames, captions)):
        ix = sx + i * (img_w + gap)
        ih = int(img_w * mh / mw)   # keep each frame's aspect ratio
        _, im = image(f"fig2-a-img-{i}", ix, sy, img_w, ih, mid, mw, mh, fid, mtype=mtype)
        changes.append(add_obj_change(im, fid))
        _, ct = text(f"fig2-a-cap-{i}", ix, sy + ih + 6, img_w, 18,
                     cap, fid, size=11, weight="600", color=SUBTLE, align="center")
        changes.append(add_obj_change(ct, fid))
        # Arrow between: a 2 px bar filling the gap, vertically centered on this image
        if i < 3:
            ax = ix + img_w + 1
            ay = sy + ih // 2 - 1
            _, ar = rect(f"fig2-a-arr-{i}", ax, ay, gap-2, 2, fid, fill=TEXT_DARK)
            changes.append(add_obj_change(ar, fid))

    # ── Panel (b): height illustration ──
    bx = FRAME_X + 820
    by = fy + 95
    bw = 360
    hi_id, hi_w, hi_h, hi_mt = upload_media(
        "/data/cameron/para/paper/figs/figma/height_illustration.png",
        "Height Illustration")
    bh = int(bw * hi_h / hi_w)
    # Cap the height at 380 px and shrink the width to keep the aspect ratio.
    # The image stays left-aligned at bx; the caption keeps the full 360 px width.
    if bh > 380:
        bh = 380
        bw = int(bh * hi_w / hi_h)
    _, im_h = image("fig2-b-img", bx, by, bw, bh, hi_id, hi_w, hi_h, fid, mtype=hi_mt)
    changes.append(add_obj_change(im_h, fid))
    _, cap_h = text("fig2-b-cap", FRAME_X+820, by + bh + 12, 360, 18,
                    "Height is camera-invariant; depth is not.",
                    fid, size=11, weight="500", color=SUBTLE, align="center")
    changes.append(add_obj_change(cap_h, fid))

    commit_changes(changes)
    print(f"  Done. Frame: {fid}")


# ─────────────────────────────────────────────────────────────────────────────
# Figure 3: Real robot results — single image import
# ─────────────────────────────────────────────────────────────────────────────

def build_fig3_real_robot():
    """Build Figure 3 (real robot results) from one imported results image.

    The image is scaled to the frame's inner width (30 px margins). If that
    would be taller than ``fh - 90``, it is scaled to that height instead.
    Either way it is centered horizontally under the title and committed in
    a single update-file call.
    """
    fy, fh = fig_y(3), 700
    print(f"\n=== Fig 3: Real Robot Results ===")
    fid, frame_obj = frame("Fig3_RealRobot", FRAME_X, fy, FRAME_W, fh)

    _, title = text("fig3-title", FRAME_X+30, fy+15, FRAME_W-60, 30,
                    "Figure 3: Real Robot Results (SO-100, 20 demos)",
                    fid, size=20, weight="700")

    # Import the existing Figma asset
    media_id, media_w, media_h, media_type = upload_media(
        "/data/cameron/para/paper/figs/figma/para_results.png",
        "Para Results")

    max_h = fh - 90
    width = FRAME_W - 60
    height = int(width * media_h / media_w)
    if height > max_h:
        height = max_h
        width = int(height * media_w / media_h)
    left = FRAME_X + (FRAME_W - width) // 2
    _, results = image("fig3-img", left, fy + 60, width, height,
                       media_id, media_w, media_h, fid, mtype=media_type)

    commit_changes([
        add_obj_change(frame_obj, ROOT),
        add_obj_change(title, fid),
        add_obj_change(results, fid),
    ])
    print(f"  Done. Frame: {fid}")


# ─────────────────────────────────────────────────────────────────────────────
# Figure 4: OOD Analysis (rebuild — full version with all 3 panels)
# ─────────────────────────────────────────────────────────────────────────────

def build_fig4_full():
    """Build Figure 4 (OOD generalization analysis) with all three panels.

    Panels, left to right:
      (a) train/test distribution plot plus a two-bar PARA-vs-ACT summary
          (54% vs 1%, hard-coded below);
      (b) the per-theta viewpoint-robustness chart;
      (c) a 2x2 grid of ACT/PARA rollout frames with a verdict line.
    Media are uploaded as each panel is built; all shapes are committed in a
    single update-file call at the end.
    """
    fy = fig_y(4)
    fh = 550
    print(f"\n=== Fig 4: OOD Analysis (full) ===")
    fid, fobj = frame("Fig4_OOD_Analysis", FRAME_X, fy, FRAME_W, fh)
    changes = [add_obj_change(fobj, ROOT)]

    _, t = text("fig4-title", FRAME_X+30, fy+15, FRAME_W-60, 30,
                "Figure 4: OOD Generalization Analysis",
                fid, size=20, weight="700")
    changes.append(add_obj_change(t, fid))

    # Three equal-width panels with a 30 px margin on each side, matching the
    # other figures. Fixed 380 px panels plus two 20 px gaps would run from
    # FRAME_X+30 to FRAME_X+1210, past the frame's right edge.
    margin = 30
    panel_gap = 20
    panel_w = (FRAME_W - 2 * margin - 2 * panel_gap) // 3
    px_a = FRAME_X + margin
    px_b = px_a + panel_w + panel_gap
    px_c = px_b + panel_w + panel_gap
    py = fy + 65

    # Sub-labels
    _, la = text("fig4-a-label", px_a, py, panel_w, 22,
                 "(a) Spatial Extrapolation", fid, size=13, weight="700")
    _, lb = text("fig4-b-label", px_b, py, panel_w, 22,
                 "(b) Per-Angle Viewpoint Robustness", fid, size=13, weight="700")
    _, lc = text("fig4-c-label", px_c, py, panel_w, 22,
                 "(c) Qualitative Comparison", fid, size=13, weight="700")
    changes.extend([add_obj_change(x, fid) for x in (la, lb, lc)])

    panel_y = py + 30
    panel_inner_h = 360

    # Panel (a): distribution plot + bar chart
    dist_id, dw, dh, dmt = upload_media(
        "/data/cameron/para/.agents/reports/project_site/media/exp3_leftright_distribution.png",
        "Train/Test Distribution")
    dist_w = panel_w
    dist_h = int(dist_w * dh / dw)
    # Cap at 240 px tall, shrinking the width to keep the aspect ratio
    if dist_h > 240:
        dist_h = 240
        dist_w = int(dist_h * dw / dh)
    _, dim = image("fig4-a-dist", px_a + (panel_w-dist_w)//2, panel_y, dist_w, dist_h,
                   dist_id, dw, dh, fid, mtype=dmt)
    changes.append(add_obj_change(dim, fid))

    # Bar chart underneath: PARA 54% vs ACT 1%, bars bottom-aligned in a 60 px band
    bar_y = panel_y + dist_h + 10
    bar_h = 60
    bar_w_unit = (panel_w - 40) // 2
    # PARA bar (54%)
    para_bar_h = int(bar_h * 0.54)
    _, pb = rect("fig4-a-para-bar", px_a + 20, bar_y + (bar_h - para_bar_h),
                 (bar_w_unit-10), para_bar_h, fid, fill=PARA_GREEN)
    _, pbl = text("fig4-a-para-lbl", px_a + 20, bar_y + bar_h + 4, bar_w_unit-10, 18,
                  "PARA 54%", fid, size=11, weight="700", color=PARA_GREEN, align="center")
    # ACT bar (1%) — at least 2 px tall so it stays visible
    act_bar_h = max(2, int(bar_h * 0.01))
    _, ab = rect("fig4-a-act-bar", px_a + 20 + bar_w_unit, bar_y + (bar_h - act_bar_h),
                 (bar_w_unit-10), act_bar_h, fid, fill=ACT_RED)
    _, abl = text("fig4-a-act-lbl", px_a + 20 + bar_w_unit, bar_y + bar_h + 4, bar_w_unit-10, 18,
                  "ACT 1%", fid, size=11, weight="700", color=ACT_RED, align="center")
    for x in (pb, pbl, ab, abl):
        changes.append(add_obj_change(x, fid))

    # Panel (b): per-theta chart
    pt_id, pw, ph_, pmt = upload_media(
        "/data/cameron/penpot/figures/per_theta.png", "Per-Theta Chart (2)")
    chart_w = panel_w
    chart_h = int(chart_w * ph_ / pw)
    if chart_h > panel_inner_h:
        chart_h = panel_inner_h
        chart_w = int(chart_h * pw / ph_)
    _, ci = image("fig4-b-chart", px_b + (panel_w-chart_w)//2, panel_y, chart_w, chart_h,
                  pt_id, pw, ph_, fid, mtype=pmt)
    changes.append(add_obj_change(ci, fid))

    # Panel (c): 2x2 grid of feature comparison frames.
    # NOTE(review): cells are square regardless of the source images' aspect
    # ratio, so non-square frames would be stretched. Confirm the extracted
    # 4c_* frames are square.
    cell_w = (panel_w - 10) // 2
    cell_h = cell_w
    cells = [
        ("4c_act_start.png",  px_c,            panel_y,          "ACT start"),
        ("4c_act_mid.png",    px_c+cell_w+10,  panel_y,          "ACT mid"),
        ("4c_para_start.png", px_c,            panel_y+cell_h+24,"PARA start"),
        ("4c_para_mid.png",   px_c+cell_w+10,  panel_y+cell_h+24,"PARA mid"),
    ]
    for i, (fn, cx, cy, label) in enumerate(cells):
        mid, mw, mh, mtype = upload_media(
            f"/data/cameron/penpot/figures/extracted/{fn}",
            f"Fig4c {label}")
        _, im = image(f"fig4-c-{i}", cx, cy, cell_w, cell_h, mid, mw, mh, fid, mtype=mtype)
        _, cap = text(f"fig4-c-cap-{i}", cx, cy + cell_h + 2, cell_w, 18,
                      label, fid, size=10, weight="600", color=SUBTLE, align="center")
        changes.extend([add_obj_change(im, fid), add_obj_change(cap, fid)])

    # Verdict text under panel c
    _, vd = text("fig4-c-verdict", px_c, panel_y + 2*(cell_h+24) + 10, panel_w, 18,
                 "ACT reaches wrong location · PARA reaches bowl correctly",
                 fid, size=10, weight="500", color=SUBTLE, align="center")
    changes.append(add_obj_change(vd, fid))

    commit_changes(changes)
    print(f"  Done. Frame: {fid}")


# ─────────────────────────────────────────────────────────────────────────────
# Figure 5: Video Backbone
# ─────────────────────────────────────────────────────────────────────────────

def build_fig5_video_backbone():
    """Build Figure 5 (video models as policy backbones) as a single frame.

    (a) Left: a native-shape diagram of two-stage training. A video UNet
        pretraining box feeds a joint fine-tune box, with labels for the PARA
        heatmap head and video-generation outputs.
    (b) Right: uploaded rollout frames for SVD + PARA and SVD + global
        regression side by side, each with a success-rate caption (hard-coded
        92% / 0%).

    Each rollout image is fitted inside its slot using its own aspect ratio
    and centered horizontally. Neither image is stretched when the 280 px
    height cap applies or when the two source images differ in proportions.
    All shapes are committed in a single update-file call.
    """
    fy = fig_y(5)
    fh = 480
    print(f"\n=== Fig 5: Video Backbone ===")
    fid, fobj = frame("Fig5_VideoBackbone", FRAME_X, fy, FRAME_W, fh)
    changes = [add_obj_change(fobj, ROOT)]

    _, t = text("fig5-title", FRAME_X+30, fy+15, FRAME_W-60, 30,
                "Figure 5: Video Models as Policy Backbones",
                fid, size=20, weight="700")
    changes.append(add_obj_change(t, fid))

    # Left: training diagram
    diag_x = FRAME_X + 30
    diag_y = fy + 70
    diag_w = 480
    _, lab = text("fig5-diag-label", diag_x, diag_y, diag_w, 22,
                  "(a) Two-Stage Training", fid, size=13, weight="700")
    changes.append(add_obj_change(lab, fid))

    # (two-line label, x, y, w, h, fill, stroke); line 1 is the title, line 2 the subtitle
    boxes = [
        ("Video UNet\n(4K pretrain)",       diag_x + 20,  diag_y + 100, 160, 72, "#eaf2fb", "#3b6fa6"),
        ("Joint Fine-tune\n(3K)",            diag_x + 240, diag_y + 100, 160, 72, "#fff3e6", "#d97e1f"),
    ]
    for name, x, y, w, h, fill, stroke in boxes:
        _, b = rect(f"fig5-{name[:8]}", x, y, w, h, fid, fill=fill,
                    stroke=(2, stroke), rx=10)
        first_line = name.split("\n")[0]
        sub = name.split("\n")[1] if "\n" in name else ""
        _, bt = text(f"fig5-{name[:8]}-t", x, y+18, w, 22, first_line,
                     fid, size=13, weight="700", align="center")
        _, bs = text(f"fig5-{name[:8]}-s", x, y+40, w, 18, sub,
                     fid, size=10, weight="400", color=SUBTLE, align="center")
        changes.extend([add_obj_change(b, fid), add_obj_change(bt, fid), add_obj_change(bs, fid)])

    # Arrow between: 2 px bar from the first box's right edge to the second's left edge
    _, ar = rect("fig5-arr", diag_x + 180, diag_y + 134, 60, 2, fid, fill=TEXT_DARK)
    changes.append(add_obj_change(ar, fid))

    # Two output branches from joint fine-tune
    out_y_top = diag_y + 90
    out_y_bot = diag_y + 200
    _, out1 = text("fig5-out1", diag_x + 240, out_y_top - 20, 160, 16,
                   "↑ PARA Heatmap Head", fid, size=10, weight="600", color=PARA_GREEN, align="center")
    _, out2 = text("fig5-out2", diag_x + 240, out_y_bot - 5, 160, 16,
                   "↓ Video Generation", fid, size=10, weight="600", color=SUBTLE, align="center")
    changes.extend([add_obj_change(out1, fid), add_obj_change(out2, fid)])

    # Right: rollout grids comparison
    right_x = FRAME_X + 540
    right_y = fy + 70
    right_w = 620
    _, rlab = text("fig5-right-label", right_x, right_y, right_w, 22,
                   "(b) Rollout Comparison", fid, size=13, weight="700")
    changes.append(add_obj_change(rlab, fid))

    para_id, pw, ph_, pmt = upload_media(
        "/data/cameron/penpot/figures/extracted/frame_5_para.png",
        "Rollout PARA")
    glob_id, gw, gh, gmt = upload_media(
        "/data/cameron/penpot/figures/extracted/frame_5_global.png",
        "Rollout Global")

    # Slot size: half the panel width each, height from the PARA image's
    # aspect ratio, capped at 280 px.
    cell_w = (right_w - 20) // 2
    cell_h = int(cell_w * ph_ / pw)
    if cell_h > 280:
        cell_h = 280

    def _fit(media_w, media_h):
        """Largest (w, h) with the media's aspect ratio inside cell_w x cell_h."""
        w = cell_w
        h = int(w * media_h / media_w)
        if h > cell_h:
            h = cell_h
            w = int(h * media_w / media_h)
        return w, h

    slot_y = right_y + 30
    glob_slot_x = right_x + cell_w + 20
    para_w, para_h = _fit(pw, ph_)
    glob_w, glob_h = _fit(gw, gh)

    _, im_p = image("fig5-rollout-para", right_x + (cell_w - para_w) // 2, slot_y,
                    para_w, para_h, para_id, pw, ph_, fid, mtype=pmt)
    _, im_g = image("fig5-rollout-global", glob_slot_x + (cell_w - glob_w) // 2, slot_y,
                    glob_w, glob_h, glob_id, gw, gh, fid, mtype=gmt)
    changes.extend([add_obj_change(im_p, fid), add_obj_change(im_g, fid)])

    _, cp = text("fig5-rollout-para-cap", right_x, slot_y + cell_h + 6, cell_w, 22,
                 "SVD + PARA — 92%", fid, size=14, weight="700", color=PARA_GREEN, align="center")
    _, cg = text("fig5-rollout-global-cap", glob_slot_x, slot_y + cell_h + 6, cell_w, 22,
                 "SVD + Global Regression — 0%", fid, size=14, weight="700", color=ACT_RED, align="center")
    changes.extend([add_obj_change(cp, fid), add_obj_change(cg, fid)])

    commit_changes(changes)
    print(f"  Done. Frame: {fid}")


# ─────────────────────────────────────────────────────────────────────────────
# Figure 6: Point Track Pretraining (placeholder)
# ─────────────────────────────────────────────────────────────────────────────

def build_fig6_point_tracks():
    """Build Figure 6 (point-track pretraining) with placeholder content.

    (a) Left: three hand-drawn placeholder frames, each a tan "table" rect
        with a gray object and a green EEF dot, labelled t = 0, 5, 10.
    (b) Right: a native-shape bar chart of the hard-coded values in
        ``bar_data``, captioned as preliminary.
    No media are uploaded; all shapes are committed in a single update-file call.
    """
    fy = fig_y(6)
    fh = 480
    print(f"\n=== Fig 6: Point Track Pretraining (placeholder) ===")
    fid, fobj = frame("Fig6_PointTracks", FRAME_X, fy, FRAME_W, fh)
    changes = [add_obj_change(fobj, ROOT)]

    _, t = text("fig6-title", FRAME_X+30, fy+15, FRAME_W-60, 30,
                "Figure 6: Point Track Pretraining (Preliminary)",
                fid, size=20, weight="700")
    changes.append(add_obj_change(t, fid))

    # Left: 3 placeholder frames
    left_x = FRAME_X + 30
    left_y = fy + 70
    left_w = 540
    _, lab = text("fig6-left-label", left_x, left_y, left_w, 22,
                  "(a) Arm-Deleted Training Data", fid, size=13, weight="700")
    changes.append(add_obj_change(lab, fid))

    # Three 4:3 cells separated by 10 px gaps
    cell_w = (left_w - 20) // 3
    cell_h = cell_w * 3 // 4
    table_color = "#e8d5b7"
    for i in range(3):
        cx = left_x + i * (cell_w + 10)
        cy = left_y + 35
        # Placeholder: tan rectangle (table) with a small green dot (EEF)
        _, table = rect(f"fig6-table-{i}", cx, cy, cell_w, cell_h, fid,
                        fill=table_color, stroke=(1, BORDER), rx=4)
        # Object (small rounded-rect stand-in, 16x8)
        ox = cx + cell_w//2 - 8
        oy = cy + cell_h//2
        _, obj = rect(f"fig6-obj-{i}", ox, oy, 16, 8, fid, fill="#888888", rx=4)
        # Green dot (EEF position); shifted right by i*4 px per frame,
        # presumably to suggest motion over time
        dx = cx + cell_w//2 + 20 + i*4
        dy = cy + cell_h//2 - 12
        _, dot = rect(f"fig6-dot-{i}", dx, dy, 8, 8, fid, fill=PARA_GREEN, rx=4)
        # Frame number
        _, fn = text(f"fig6-fn-{i}", cx, cy + cell_h + 2, cell_w, 16,
                     f"t = {i*5}", fid, size=10, color=SUBTLE, align="center")
        for x in (table, obj, dot, fn):
            changes.append(add_obj_change(x, fid))

    _, leftcap = text("fig6-leftcap", left_x, left_y + 35 + cell_h + 24, left_w, 18,
                      "(robot arm invisible; only EEF position retained as supervision)",
                      fid, size=10, weight="500", color=SUBTLE, align="center")
    changes.append(add_obj_change(leftcap, fid))

    # Right: placeholder bar chart
    right_x = FRAME_X + 600
    right_y = fy + 70
    right_w = 560
    _, rlab = text("fig6-right-label", right_x, right_y, right_w, 22,
                   "(b) Pretrain → Fine-tune Results (preliminary)",
                   fid, size=13, weight="700")
    changes.append(add_obj_change(rlab, fid))

    # 4 bars: PARA pretrain, PARA scratch, Global pretrain, Global scratch
    # (label, value in percent, color). Values are placeholders.
    bar_data = [
        ("PARA\npretrain",  70, PARA_GREEN),
        ("PARA\nscratch",   30, PARA_GREEN),
        ("Global\npretrain", 20, ACT_RED),
        ("Global\nscratch",  15, ACT_RED),
    ]
    plot_x = right_x + 30
    plot_y = right_y + 60
    plot_w = right_w - 60
    plot_h = 240
    bar_w = (plot_w - 40) // 4
    bar_gap = (plot_w - bar_w*4) // 5   # equal gaps before, between and after the bars

    # Y-axis line
    _, ya = rect("fig6-y-axis", plot_x, plot_y, 2, plot_h, fid, fill="#888888")
    # X-axis line
    _, xa = rect("fig6-x-axis", plot_x, plot_y + plot_h - 1, plot_w, 2, fid, fill="#888888")
    changes.extend([add_obj_change(ya, fid), add_obj_change(xa, fid)])

    for i, (label, val, color) in enumerate(bar_data):
        bh = int(plot_h * val / 100)   # 100% == full plot height
        bx = plot_x + bar_gap + i * (bar_w + bar_gap)
        by = plot_y + plot_h - bh
        _, b = rect(f"fig6-bar-{i}", bx, by, bar_w, bh, fid, fill=color)
        # NOTE: no hatching or overlay is drawn. "pretrain" and "scratch" bars
        # of the same method share one fill color and differ only by label.
        _, bv = text(f"fig6-bv-{i}", bx, by - 18, bar_w, 16,
                     f"{val}%", fid, size=11, weight="700", align="center")
        # Two-line labels are flattened to one line with a space
        _, bl = text(f"fig6-bl-{i}", bx-4, plot_y + plot_h + 4, bar_w+8, 28,
                     label.replace("\n", " "), fid, size=9, weight="500",
                     color=SUBTLE, align="center")
        for x in (b, bv, bl):
            changes.append(add_obj_change(x, fid))

    _, rcap = text("fig6-rcap", right_x, right_y + plot_h + 80, right_w, 18,
                   "Preliminary — final values pending from backbones agent",
                   fid, size=10, weight="600", color=NEUTRAL, align="center")
    changes.append(add_obj_change(rcap, fid))

    commit_changes(changes)
    print(f"  Done. Frame: {fid}")


def build_fig4_with_chart_only():
    """Fig 4: OOD Generalization Analysis, partial version (CLI key "4b").

    Places only panel (b), the per-theta chart. Panels (a) and (c) get
    bordered placeholder boxes with "pending" captions; build_fig4_full
    produces the complete figure.

    Uses the same frame height (550) and three-panel geometry as
    build_fig4_full. The frame therefore fits the slot fig_y() reserves for
    figure 4, does not overlap Figure 5, and panel (c) stays inside the
    frame's right edge. The chart is scaled down if needed to keep a 30 px
    bottom margin.
    """
    fy = fig_y(4)
    fh = 550
    print(f"\n=== Fig 4: OOD Analysis (placing panel 4b first) ===")
    print(f"  Frame at ({FRAME_X},{fy}) {FRAME_W}x{fh}")

    fid, fobj = frame("Fig4_OOD_Analysis", FRAME_X, fy, FRAME_W, fh)
    changes = [add_obj_change(fobj, ROOT)]

    # Title
    _, t = text("fig4-title", FRAME_X + 30, fy + 20, FRAME_W - 60, 30,
                "Figure 4: OOD Generalization Analysis",
                fid, size=20, weight="700", align="left")
    changes.append(add_obj_change(t, fid))

    # Three equal panels with 30 px side margins and 20 px gaps
    margin = 30
    panel_gap = 20
    panel_w = (FRAME_W - 2 * margin - 2 * panel_gap) // 3
    px_a = FRAME_X + margin
    px_b = px_a + panel_w + panel_gap
    px_c = px_b + panel_w + panel_gap

    # Panel labels
    _, lab_a = text("panel-a-label", px_a, fy + 70, panel_w, 22,
                    "(a) Spatial Extrapolation", fid, size=14, weight="700", align="left")
    _, lab_b = text("panel-b-label", px_b, fy + 70, panel_w, 22,
                    "(b) Per-Angle Viewpoint Robustness", fid, size=14, weight="700", align="left")
    _, lab_c = text("panel-c-label", px_c, fy + 70, panel_w, 22,
                    "(c) Qualitative Comparison", fid, size=14, weight="700", align="left")
    for lab in (lab_a, lab_b, lab_c):
        changes.append(add_obj_change(lab, fid))

    # Upload + place per-theta chart in panel (b), keeping its aspect ratio
    media_id, mw, mh, mtype = upload_media(
        "/data/cameron/penpot/figures/per_theta.png",
        "Per-Theta Chart")
    chart_y = fy + 110
    max_chart_h = fh - (chart_y - fy) - margin
    chart_w = panel_w
    chart_h = int(chart_w * mh / mw)
    if chart_h > max_chart_h:
        chart_h = max_chart_h
        chart_w = int(chart_h * mw / mh)
    chart_x = px_b + (panel_w - chart_w) // 2
    _, img = image("per-theta-chart", chart_x, chart_y, chart_w, chart_h,
                   media_id, mw, mh, fid, mtype=mtype)
    changes.append(add_obj_change(img, fid))

    # Placeholders for 4a and 4c (to be filled in later)
    _, ph_a = rect("placeholder-4a", px_a, fy + 110, panel_w, 280, fid,
                   stroke=(2, BORDER), rx=8)
    _, ph_c = rect("placeholder-4c", px_c, fy + 110, panel_w, 280, fid,
                   stroke=(2, BORDER), rx=8)
    changes.append(add_obj_change(ph_a, fid))
    changes.append(add_obj_change(ph_c, fid))
    _, ph_a_t = text("placeholder-4a-text", px_a, fy + 240, panel_w, 30,
                     "(distribution plot — pending)", fid, size=12,
                     color=NEUTRAL, align="center")
    _, ph_c_t = text("placeholder-4c-text", px_c, fy + 240, panel_w, 30,
                     "(rollout frames — pending)", fid, size=12,
                     color=NEUTRAL, align="center")
    changes.append(add_obj_change(ph_a_t, fid))
    changes.append(add_obj_change(ph_c_t, fid))

    commit_changes(changes)
    print(f"  Done. Frame: {fid}")


# ─────────────────────────────────────────────────────────────────────────────
# Main
# ─────────────────────────────────────────────────────────────────────────────

# CLI key (--figs) -> figure builder. "4" and "4b" both create a frame named
# "Fig4_OOD_Analysis" at fig_y(4). "4b" is a partial version that places
# only the per-theta chart plus placeholders.
BUILDERS = {
    "1":  build_fig1_overview,
    "2":  build_fig2_method,
    "3":  build_fig3_real_robot,
    "4":  build_fig4_full,
    "4b": build_fig4_with_chart_only,
    "5":  build_fig5_video_backbone,
    "6":  build_fig6_point_tracks,
}

def main():
    """Parse ``--figs`` and run the requested figure builders in order.

    ``--figs`` is a comma-separated list of BUILDERS keys, or ``all``.
    ``all`` builds every key except "4b". "4b" is a partial variant of
    Figure 4 that occupies the same slot, so building both would stack two
    Fig4 frames. All keys are validated before anything is built, so a typo
    cannot leave the page half-updated. Duplicate keys are built once, and
    empty items (e.g. from a trailing comma) are ignored.

    Exits with status 1 (message on stderr) if any key is unknown.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--figs", default="4b", help="comma-separated list, or 'all'")
    args = parser.parse_args()

    if args.figs.strip() == "all":
        keys = [k for k in BUILDERS if k != "4b"]
    else:
        # dict.fromkeys de-duplicates while preserving first-seen order.
        keys = list(dict.fromkeys(
            k.strip() for k in args.figs.split(",") if k.strip()))

    unknown = [k for k in keys if k not in BUILDERS]
    if unknown:
        print(f"unknown figure: {', '.join(unknown)} (have: {list(BUILDERS.keys())})",
              file=sys.stderr)
        sys.exit(1)

    for k in keys:
        BUILDERS[k]()


if __name__ == "__main__":
    main()
