"""Precompute PCA-1D rotation basis for the dataset.

For Cameron's "Model surgery" task 2026-05-20: collapse 3-axis Euler rotation prediction
to a single PCA-1D axis. This script computes the basis once and saves it for the data
loader + decode helpers to use.

Sanity gate: PC1 explained-variance ratio must be ≥ 0.85. If not, the data has too much
rotational variance for 1D collapse — abort and ping Cameron.

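Output: rotation_pca_basis.npz with keys
  mean (3,)            — μ
  principal_axis (3,)  — v1 (first principal axis)
  pca_min, pca_max     — observed range of projected values
  ev_ratio_pc1, ev_ratio_pc2, ev_ratio_pc3 — explained-variance ratios
  n_samples            — number of rotation samples used for the fit
  root_dir, sessions_whitelist — provenance: dataset root and comma-joined
                         session whitelist ("" when no whitelist was given)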
"""
import argparse, sys, os
sys.path.insert(0, "/data/cameron/para/libero")
import numpy as np
import torch
from data_da3_volume import Smith300DA3VolumeDataset

# Default destination of the saved basis; overridable with --out.
DEFAULT_OUT = (
    "/data/cameron/para/libero"
    "/rotation_pca_basis.npz"
)


def _fit_rotation_pca(rot):
    """Fit PCA to rotation samples and project them onto the first axis.

    Args:
        rot: float array of shape (N, 3) with N >= 3 and only finite entries.

    Returns:
        Tuple ``(mu, v1, ev, pmin, pmax)`` where ``mu`` is the per-axis mean
        (3,), ``v1`` is the first principal axis (3,), ``ev`` holds the three
        explained-variance ratios in descending order, and ``pmin``/``pmax``
        are the min/max of the centred samples projected onto ``v1``.

    Raises:
        ValueError: if the samples have zero total variance, in which case
            the explained-variance ratios are undefined (0/0).
    """
    mu = rot.mean(axis=0)
    centred = rot - mu
    # Thin SVD: rows of vt are the principal axes in descending
    # singular-value order. The left singular vectors are not needed.
    # NOTE: the sign of each axis is whatever the SVD returns; it is not
    # normalised here, so consumers must use the saved axis as-is.
    _, s, vt = np.linalg.svd(centred, full_matrices=False)
    var = s ** 2
    total = var.sum()
    if not total > 0:
        raise ValueError("rotation samples have zero variance; PCA is undefined")
    ev = var / total
    v1 = vt[0]
    proj = centred @ v1
    return mu, v1, ev, float(proj.min()), float(proj.max())


def main(argv=None):
    """Compute the PCA-1D rotation basis for a dataset and save it as .npz.

    Loads the dataset, reads its ``eef_euler_t`` rotations, fits PCA, prints
    summary statistics, enforces the PC1 explained-variance gate, and writes
    the basis with ``np.savez``.

    Args:
        argv: optional list of command-line arguments. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, as before.

    Exits with status 1 (via ``sys.exit``) when the data is unusable (wrong
    shape, fewer than 3 samples, non-finite values, zero variance) or when
    the PC1 explained-variance ratio is below ``--gate_ev``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--root_dir", type=str,
                   default="/data/cameron/mac_robot_datasets/first_mobile_collection")
    p.add_argument("--sessions_whitelist", type=str, default="",
                   help="Comma-separated session dirs to include (empty = all).")
    p.add_argument("--out", type=str, default=DEFAULT_OUT)
    p.add_argument("--depth_subdir", type=str, default="da3_depth_large")
    p.add_argument("--gate_ev", type=float, default=0.85,
                   help="Minimum PC1 explained-variance ratio. Abort if PC1 EV < this.")
    args = p.parse_args(argv)

    # `or None`: a whitelist made only of commas/whitespace must mean "all
    # sessions" (None), not an empty list handed to the dataset.
    wl = [s.strip() for s in args.sessions_whitelist.split(",") if s.strip()] or None
    print(f"Loading dataset from {args.root_dir}"
          + (f" (whitelist: {wl})" if wl else ""))
    ds = Smith300DA3VolumeDataset(root_dir=args.root_dir, depth_subdir=args.depth_subdir,
                                   sessions_whitelist=wl)
    # Presumably one Euler triple per sample — confirm against the dataset
    # class. Upcast so the mean/SVD accumulate in float64; outputs are cast
    # back to float32 when saved.
    rot = np.asarray(ds.eef_euler_t.numpy(), dtype=np.float64)
    if rot.ndim != 2 or rot.shape[1] != 3:
        sys.exit(f"!! ABORT: expected rotations of shape (N, 3), got {rot.shape}")
    N = rot.shape[0]
    print(f"Collected {N} rotation samples (3 euler axes each)")
    if N < 3:
        # With fewer than 3 rows the thin SVD yields fewer than 3 singular
        # values, so the three EV ratios reported below would not exist.
        sys.exit(f"!! ABORT: need at least 3 rotation samples for PCA, got {N}")
    if not np.isfinite(rot).all():
        sys.exit("!! ABORT: rotation samples contain NaN/inf values")

    for i, name in enumerate(("roll", "pitch", "yaw")):
        col = rot[:, i]
        label = f"({name})"
        print(f"  axis {i} {label:<7}:  min={col.min():+.3f}  max={col.max():+.3f}  std={col.std():.3f}")

    try:
        mu, v1, ev, pmin, pmax = _fit_rotation_pca(rot)
    except ValueError as err:
        sys.exit(f"!! ABORT: {err}")

    print(f"\nPCA explained-variance ratios: PC1={ev[0]:.4f}  PC2={ev[1]:.4f}  PC3={ev[2]:.4f}")
    print(f"Principal axis (v1): [{v1[0]:+.4f}, {v1[1]:+.4f}, {v1[2]:+.4f}]")
    print(f"Mean (μ): [{mu[0]:+.4f}, {mu[1]:+.4f}, {mu[2]:+.4f}]")
    print(f"PCA-1D range: [{pmin:.4f}, {pmax:.4f}]")
    print(f"Bin spacing at N_ROT_BINS=48: {(pmax-pmin)/48:.4f}")

    # SANITY GATE
    # Written as `not (>=)` so that a NaN ratio fails the gate instead of
    # slipping through (every comparison with NaN is False).
    if not ev[0] >= args.gate_ev:
        print(f"\n!! ABORT: PC1 explained-variance ratio = {ev[0]:.4f} < threshold {args.gate_ev}")
        print(f"!! Rotation is too multi-axis for 1D collapse. Pinging Cameron — DO NOT proceed.")
        sys.exit(1)
    else:
        print(f"\n✓ PC1 EV ratio {ev[0]:.4f} ≥ {args.gate_ev} threshold — sanity gate passed.")

    # np.savez appends ".npz" when the name lacks it; resolve the final path
    # up front so the directory check and the "Saved:" message are accurate.
    out_path = args.out if args.out.endswith(".npz") else args.out + ".npz"
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    np.savez(out_path,
             mean=mu.astype(np.float32),
             principal_axis=v1.astype(np.float32),
             pca_min=np.float32(pmin),
             pca_max=np.float32(pmax),
             ev_ratio_pc1=np.float32(ev[0]),
             ev_ratio_pc2=np.float32(ev[1]),
             ev_ratio_pc3=np.float32(ev[2]),
             n_samples=np.int64(N),
             root_dir=args.root_dir,
             sessions_whitelist=",".join(wl) if wl else "")
    print(f"Saved: {out_path}")


# Run the precompute only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()
