"""Hi-res spatial train/test distribution illustration for fig4_ood panel (a).

Matches the dark-navy aesthetic of the viewpoint polar plot — both are
the two distribution illustrations on the left of fig4, so they should
read as a matched pair.

Train positions = full left half of a 16×16 workspace grid (128 cells).
Test positions  = `--n_test` held-out cells (default 20), sampled without
                  replacement from the right half using seed `--seed`.

Replaces `exp3_leftright_distribution.png` (originally a baked-in-legend
LIBERO render whose source script we lost). The new illustration is
fully synthetic so we own the source going forward.

Usage:
    python3 generate_spatial_distribution.py [--dpi 240] [--out /path/to/png]
        [--mirror /path/to/copy.png | --mirror ""] [--n_test 20] [--seed 7]
"""
import argparse
from pathlib import Path
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch, Rectangle


def main():
    """Render the train/held-out spatial split illustration and save it as a PNG.

    Command-line options:
        --dpi     Resolution passed to ``savefig`` (default 240).
        --out     Primary output path; missing parent directories are created.
        --mirror  Optional second path the PNG is copied to; pass an empty
                  string to skip. Skipped if it resolves to ``--out`` itself.
        --n_test  Number of held-out cells drawn from the right half
                  (default 20); must lie in [0, number of right-half cells].
        --seed    Seed for the RNG that picks the held-out cells (default 7).

    Every cell in the left half of the grid is drawn as a train position.
    ``--n_test`` cells are sampled without replacement from the right half, so
    the figure is reproducible for a fixed ``--seed`` / ``--n_test`` pair.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dpi", type=int, default=240)
    # NOTE: the default paths are machine-specific absolute paths.
    ap.add_argument("--out",
                    default="/data/cameron/para/.agents/reports/project_site/media/exp3_leftright_distribution.png")
    ap.add_argument("--mirror",
                    default="/data/cameron/para/.agents/reports/backbones/media/exp3_leftright_distribution.png",
                    help="Second copy of the output (set empty to skip)")
    ap.add_argument("--n_test", type=int, default=20)
    ap.add_argument("--seed", type=int, default=7,
                    help="RNG seed for selecting test cells")
    args = ap.parse_args()

    # Workspace grid (logical units: one unit per cell).
    NX = 16
    NY = 16
    half = NX // 2  # column index of the left/right divider
    n_right = (NX - half) * NY

    # Validate up front: rng.choice without replacement cannot draw more than
    # the population, and its error message would be opaque to a CLI user.
    if not 0 <= args.n_test <= n_right:
        ap.error(f"--n_test must be between 0 and {n_right} "
                 f"(number of right-half cells), got {args.n_test}")

    BG     = "#0f172a"
    FG     = "#e2e8f0"
    GRID   = "#334155"
    TABLE  = "#d4a574"  # warm wood color so green/blue contrast pops
    TABLE_E = "#8b6b3a"
    GREEN  = "#22c55e"
    GREEN_E = "#14532d"
    BLUE   = "#60a5fa"
    BLUE_E = "#1e3a8a"

    fig, ax = plt.subplots(figsize=(5.5, 5.0))
    fig.patch.set_facecolor(BG)
    ax.set_facecolor(BG)

    # Cell centers sit at half-integer coordinates.
    x_centers = np.linspace(0.5, NX - 0.5, NX)
    y_centers = np.linspace(0.5, NY - 0.5, NY)

    # Table backdrop — rounded rect spanning the dot grid
    pad = 0.7
    ax.add_patch(FancyBboxPatch(
        (0 - pad, 0 - pad),
        NX + 2 * pad, NY + 2 * pad,
        boxstyle="round,pad=0,rounding_size=0.7",
        linewidth=1.6, edgecolor=TABLE_E, facecolor=TABLE, zorder=1))

    # Faint cell-boundary lines. vlines/hlines are bounded to the grid; the
    # axvline/axhline equivalents span the full axes and would bleed past the
    # table onto the background and through the side labels.
    ax.vlines(np.arange(NX + 1), 0, NY, colors=TABLE_E,
              linewidths=0.4, alpha=0.25, zorder=2)
    ax.hlines(np.arange(NY + 1), 0, NX, colors=TABLE_E,
              linewidths=0.4, alpha=0.25, zorder=2)

    # Left/right divider — dashed
    ax.plot([half, half], [-pad, NY + pad],
            color="#1e293b", linewidth=2.0, linestyle="--", alpha=0.85, zorder=3)

    # Train: every cell of the left half (cols 0..half-1, rows 0..NY-1).
    train_cols, train_rows = np.meshgrid(np.arange(half), np.arange(NY), indexing="ij")
    ax.scatter(x_centers[train_cols.ravel()], y_centers[train_rows.ravel()],
               s=160, c=GREEN, edgecolors=GREEN_E, linewidths=1.2, alpha=0.95,
               label="Train", zorder=4)

    # Test: n_test distinct right-half cells. The flat index enumerates cells
    # column-major (column outer, row inner), matching the original list
    # ordering, so a given seed selects the same cells as before.
    rng = np.random.default_rng(args.seed)
    test_idx = rng.choice(n_right, size=args.n_test, replace=False)
    test_cols = half + test_idx // NY
    test_rows = test_idx % NY
    ax.scatter(x_centers[test_cols], y_centers[test_rows],
               s=180, c=BLUE, edgecolors=BLUE_E, linewidths=1.3, alpha=0.95,
               label="Test", zorder=5)

    # Side labels, centered under each half, just below the table edge.
    ax.text(half / 2, -pad - 0.6, "train side", ha="center", va="top",
            color=GREEN, fontsize=11, fontweight="700", fontstyle="italic")
    ax.text(half + (NX - half) / 2, -pad - 0.6, "held-out side", ha="center", va="top",
            color=BLUE, fontsize=11, fontweight="700", fontstyle="italic")

    # Extra bottom margin leaves room for the side labels.
    ax.set_xlim(-pad - 1.2, NX + pad + 1.2)
    ax.set_ylim(-pad - 2.0, NY + pad + 1.4)
    ax.set_aspect("equal")
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

    ax.set_title("Train / Held-out Position Split", color=FG,
                 fontsize=14, fontweight="700", pad=12)

    leg = ax.legend(loc="upper right",
                    bbox_to_anchor=(1.02, 1.04),
                    frameon=True, facecolor="#1e293b", edgecolor=GRID,
                    labelcolor=FG, fontsize=11)
    leg.get_frame().set_linewidth(0.8)

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=args.dpi, bbox_inches="tight", facecolor=BG)
    plt.close(fig)
    print(f"wrote {out_path} at {args.dpi} dpi")

    if args.mirror:
        from shutil import copy2
        mp = Path(args.mirror)
        # copy2 raises SameFileError when source and destination coincide.
        if mp.resolve() == out_path.resolve():
            print(f"mirror path {mp} is the output path; skipping copy")
        else:
            mp.parent.mkdir(parents=True, exist_ok=True)
            copy2(out_path, mp)
            print(f"mirrored to {mp}")


if __name__ == "__main__":
    # Script entry point: main() returns None, so this exits with status 0
    # after the figure has been written (exceptions propagate unchanged).
    raise SystemExit(main())
