"""Hi-res polar plot for the viewpoint OOD figure (fig4_ood, panel (d)).

Generates the polar-plot PNG (`vp_default_to_all_polar_only.png`; note that the
default `--out` path is named `vp_polar_only.png`) showing the "default-view-only
train, all-other-views test" protocol:
- 1 green dot at center = train (default camera)
- 56 blue dots (with the defaults) = test viewpoints on a polar grid
  (8 azimuths × 7 elevation rings; see `--n_azim` / `--n_elev`)

Replaces the low-res 286×262 file that ships embedded in fig4_ood.svg.

Usage:
    python3 generate_viewpoint_polar.py [--dpi 240] [--size 5.0] [--out /path/to/vp_polar_only.png]
        [--n_azim 8] [--n_elev 7] [--elev_min 3.5] [--elev_max 25.0]
"""
import argparse
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from pathlib import Path


def _build_parser():
    """Return the command-line parser; its defaults reproduce the published panel."""
    ap = argparse.ArgumentParser(
        description="Render the default-view -> all-viewpoints polar plot.")
    ap.add_argument("--dpi", type=int, default=240,
                    help="Output DPI (default 240; >=200 gives crisp embed at 412px wide)")
    ap.add_argument("--size", type=float, default=5.0,
                    help="Figure size in inches (square; default 5.0)")
    ap.add_argument("--out", default="/data/cameron/penpot/figures/extracted/fig4/vp_polar_only.png",
                    help="Output PNG path (parent directories are created)")
    ap.add_argument("--n_azim", type=int, default=8, help="Azimuth spokes (default 8)")
    ap.add_argument("--n_elev", type=int, default=7, help="Elevation rings (default 7)")
    ap.add_argument("--elev_min", type=float, default=3.5, help="Inner elevation degrees")
    ap.add_argument("--elev_max", type=float, default=25.0, help="Outer elevation degrees")
    return ap


def _viewpoint_grid(n_azim: int, n_elev: int, elev_min: float, elev_max: float):
    """Return ``(theta, r)`` arrays for the test viewpoints on a polar grid.

    ``n_azim`` azimuths are spaced evenly over ``[0, 2*pi)`` radians, and
    ``n_elev`` elevations are spaced evenly from ``elev_min`` to ``elev_max``
    inclusive. Every azimuth is paired with every elevation, giving
    ``n_azim * n_elev`` points. They are ordered ring by ring, so the first
    ``n_azim`` entries all have elevation ``elev_min``.
    """
    azimuths = np.linspace(0, 2 * np.pi, n_azim, endpoint=False)
    elevations = np.linspace(elev_min, elev_max, n_elev)
    # tile repeats the whole azimuth sweep once per ring; repeat holds each
    # elevation for a full sweep -> the Cartesian product in ring-major order.
    return np.tile(azimuths, n_elev), np.repeat(elevations, n_azim)


def _radial_ticks(r_max: float, base: int = 5, max_ticks: int = 6):
    """Return radial tick positions: multiples of a step within ``(0, r_max]``.

    The step is the smallest multiple of ``base`` that yields at most
    ``max_ticks`` ticks. With the default ``r_max`` of 27 this gives
    ``[5, 10, 15, 20, 25]``.

    Ticks are kept at or below ``r_max`` because matplotlib's ``set_rticks``
    widens the radial view limits to include every tick. An out-of-range tick
    would therefore silently undo the preceding ``set_rlim`` call.
    """
    step = base
    while r_max // step > max_ticks:
        step += base
    return [step * k for k in range(1, int(r_max // step) + 1)]


def main(argv=None):
    """Render the polar viewpoint plot and write it to ``--out``.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) uses the real command line. Passing a list lets
            other scripts call ``main([...])`` directly.
    """
    ap = _build_parser()
    args = ap.parse_args(argv)
    if args.n_azim < 1 or args.n_elev < 1:
        ap.error("--n_azim and --n_elev must both be >= 1")
    if args.elev_min > args.elev_max:
        ap.error("--elev_min must not exceed --elev_max")

    test_theta, test_r = _viewpoint_grid(args.n_azim, args.n_elev,
                                         args.elev_min, args.elev_max)

    bg = "#0f172a"
    fg_text = "#e2e8f0"
    grid = "#334155"
    train_green = "#22c55e"
    test_blue = "#60a5fa"

    fig, ax = plt.subplots(figsize=(args.size, args.size),
                           subplot_kw={"projection": "polar"})
    try:
        fig.patch.set_facecolor(bg)
        ax.set_facecolor(bg)

        ax.scatter(test_theta, test_r, s=180, c=test_blue,
                   edgecolors="#1e3a8a", linewidths=1.3, alpha=0.95,
                   label="Test", zorder=3)
        # The single training view (default camera) sits at the origin.
        ax.scatter([0], [0], s=240, c=train_green,
                   edgecolors="#14532d", linewidths=1.6,
                   label="Train", zorder=4)

        ax.set_theta_zero_location("E")
        ax.set_theta_direction(1)  # CCW from east — matches standard polar convention

        # Presumably the +2 headroom keeps the outermost ring's markers inside the spine.
        r_max = args.elev_max + 2
        ax.set_rlim(0, r_max)
        r_ticks = _radial_ticks(r_max)
        ax.set_rticks(r_ticks)
        ax.set_rlabel_position(60)
        ax.set_yticklabels([f"{int(t)}" for t in r_ticks],
                           color=fg_text, fontsize=10, fontweight="600")

        # Fix the theta gridlines explicitly (in integer degrees) so the labels
        # are bound to known tick positions. This avoids matplotlib's warning
        # about setting tick labels without a FixedLocator, and avoids float
        # truncation from converting radians back to degrees.
        theta_deg = list(range(0, 360, 45))
        ax.set_thetagrids(theta_deg, labels=[f"{d}°" for d in theta_deg],
                          color=fg_text, fontsize=11, fontweight="600")

        ax.grid(True, color=grid, linewidth=0.9, alpha=0.55)
        ax.spines["polar"].set_color(grid)
        ax.spines["polar"].set_linewidth(1.0)

        ax.set_title("Default → All Viewpoints", color=fg_text,
                     fontsize=14, fontweight="700", pad=14)

        leg = ax.legend(loc="upper right", bbox_to_anchor=(1.18, 1.10),
                        frameon=True, facecolor="#1e293b", edgecolor=grid,
                        labelcolor=fg_text, fontsize=10)
        leg.get_frame().set_linewidth(0.8)

        out = Path(args.out)
        out.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(out, dpi=args.dpi, bbox_inches="tight", facecolor=bg)
    finally:
        # Release the figure even if saving fails (matters when main() is
        # called repeatedly from another script).
        plt.close(fig)
    # Nominal width only: bbox_inches="tight" changes the final canvas size.
    print(f"wrote {out} at {args.dpi} dpi → ~{int(args.size * args.dpi)}px wide")


if __name__ == "__main__":
    # Render the figure only when run as a script, not when this module is imported.
    main()
