"""Cluster training rotations into N centroids in canonical-quaternion space.

K-means in quat-space with sign-canonical normalisation (force w ≥ 0 so antipodal
quaternions q and -q map to the same point). Output, written to the path given by
--out (default: DEFAULT_OUT):

  rotation_kmeans_basis.npz
    centroids_quat     : (N, 4) — quat centroids, xyzw order (sign-canonical, normalised)
    centroids_euler    : (N, 3) — SciPy 'xyz' Euler angles in radians at each centroid (for decode)
    bin_counts         : (N,)   — how many training samples were assigned to each centroid
    n_clusters         : int    — N
    n_samples          : int    — total training samples
    sessions_whitelist : str    — comma-joined session names used for clustering
"""
import argparse, sys, os
sys.path.insert(0, "/data/cameron/para/libero")
import numpy as np
import torch
from scipy.spatial.transform import Rotation as R
from sklearn.cluster import KMeans
from data_da3_volume import Smith300DA3VolumeDataset

# Default destination of the saved k-means basis archive; used as the default
# value of the --out command-line option.
DEFAULT_OUT = "/data/cameron/para/libero/rotation_kmeans_basis.npz"


def canonical_quat(q):
    """Return unit quaternions with a canonical sign so q and -q coincide.

    Args:
        q: Array-like of shape (..., 4) in scalar-last (x, y, z, w) order.
            A single quaternion of shape (4,) is accepted as well as batches
            with any number of leading dimensions.

    Returns:
        A new float64 ndarray with the same shape as ``q``. Each quaternion
        is negated when necessary so that w >= 0; when w is exactly 0 the
        sign is chosen so that the first non-zero of x, y, z is positive.
        The result is then scaled to unit length. An all-zero quaternion is
        returned as zeros. The input is never modified.

    Raises:
        ValueError: If the last axis of ``q`` does not have length 4.
    """
    q = np.asarray(q, dtype=np.float64)
    if q.ndim == 0 or q.shape[-1] != 4:
        raise ValueError(
            f"expected quaternions with shape (..., 4), got {q.shape}"
        )

    x, y, z, w = (q[..., i] for i in range(4))
    # w < 0 is the common case. For w == 0 both q and -q satisfy "w >= 0", so
    # a lexicographic tie-break on (x, y, z) is needed to make the antipodal
    # pair collapse onto a single representative (the same rule SciPy
    # documents for ``Rotation.as_quat(canonical=True)``).
    flip = (w < 0) | (
        (w == 0)
        & ((x < 0) | ((x == 0) & ((y < 0) | ((y == 0) & (z < 0)))))
    )
    # np.where allocates a fresh array, so the caller's data stays untouched.
    signed = np.where(flip[..., np.newaxis], -q, q)

    # The small epsilon keeps all-zero rows finite (they stay zero) instead of
    # producing NaNs from a 0/0 division.
    norms = np.linalg.norm(signed, axis=-1, keepdims=True)
    return signed / (norms + 1e-12)


def main():
    """Cluster end-effector rotations with k-means and save the centroid basis.

    Parses the command line, loads the Euler angles exposed as
    ``eef_euler_t`` by ``Smith300DA3VolumeDataset`` for the whitelisted
    sessions, converts them to sign-canonical unit quaternions, runs Euclidean
    k-means on those 4-vectors, and writes centroids (as quaternions and as
    Euler angles), per-cluster sample counts and bookkeeping scalars to an
    ``.npz`` archive.

    Raises:
        ValueError: If the dataset provides fewer samples than the requested
            number of clusters.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--root_dir", type=str, default="/data/cameron/mac_robot_datasets")
    p.add_argument("--sessions_whitelist", type=str, required=True,
                   help="Comma-separated session names (e.g. 'home_towel').")
    p.add_argument("--n_clusters", type=int, default=8)
    p.add_argument("--out", type=str, default=DEFAULT_OUT)
    args = p.parse_args()
    if args.n_clusters < 1:
        p.error("--n_clusters must be a positive integer")

    # np.savez appends ".npz" when the file name lacks it; resolve the real
    # path up front so the directory handling and the final message refer to
    # the file that is actually written.
    out_path = args.out if args.out.endswith(".npz") else args.out + ".npz"

    wl = [s.strip() for s in args.sessions_whitelist.split(",") if s.strip()]
    print(f"Loading {wl} from {args.root_dir} (n_clusters={args.n_clusters})")
    ds = Smith300DA3VolumeDataset(
        root_dir=args.root_dir, sessions_whitelist=wl,
        n_window=1, frame_stride=1,
    )
    eul = ds.eef_euler_t.numpy()                           # (N, 3) xyz euler
    n_samples = eul.shape[0]
    # Fail early with a clear message: otherwise an empty dataset crashes in
    # the min/max printout below and a too-small one only inside KMeans.
    if n_samples < args.n_clusters:
        raise ValueError(
            f"need at least n_clusters={args.n_clusters} samples, "
            f"but the dataset provides {n_samples}"
        )

    # NOTE(review): lowercase 'xyz' selects extrinsic rotations in SciPy —
    # confirm this matches the convention the dataset stores.
    quats = R.from_euler('xyz', eul).as_quat()             # (N, 4) xyzw
    quats = canonical_quat(quats)
    print(f"  {quats.shape[0]} samples, quat ranges (canonical): "
          f"x=[{quats[:,0].min():+.3f},{quats[:,0].max():+.3f}] "
          f"y=[{quats[:,1].min():+.3f},{quats[:,1].max():+.3f}] "
          f"z=[{quats[:,2].min():+.3f},{quats[:,2].max():+.3f}] "
          f"w=[{quats[:,3].min():+.3f},{quats[:,3].max():+.3f}]")

    print(f"K-means with N={args.n_clusters}...")
    km = KMeans(n_clusters=args.n_clusters, n_init=10, random_state=0).fit(quats)
    # Cluster centres are arithmetic means of unit quaternions, so they are
    # generally not unit length; project them back before converting.
    centroids_quat = canonical_quat(km.cluster_centers_)   # (N, 4)
    labels = km.labels_                                    # (N_samples,)
    # minlength keeps one entry per cluster even if a cluster ends up empty.
    bin_counts = np.bincount(labels, minlength=args.n_clusters)
    print(f"  per-bin sample counts: {bin_counts.tolist()}")
    centroids_euler = R.from_quat(centroids_quat).as_euler('xyz')
    print("  centroid eulers (xyz):")
    for i, (ce, c) in enumerate(zip(centroids_euler, bin_counts)):
        print(f"    bin {i}: {c} samples — euler=[{ce[0]:+.3f}, {ce[1]:+.3f}, {ce[2]:+.3f}]")

    # Create the destination directory if needed so the finished clustering
    # is not lost to a missing folder.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    np.savez(out_path,
             centroids_quat=centroids_quat.astype(np.float32),
             centroids_euler=centroids_euler.astype(np.float32),
             bin_counts=bin_counts.astype(np.int64),
             n_clusters=np.int64(args.n_clusters),
             n_samples=np.int64(quats.shape[0]),
             sessions_whitelist=",".join(wl))
    print(f"Saved {out_path}")


# Run the clustering pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
