"""Create train/test splits for OOD viewpoint experiments.

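The viewpoint dataset has N_VIEWS x N_VIEWS viewpoints (theta rings x phi azimuths),
each with DEMOS_PER_VIEW demos. Demos are stored viewpoint-major: demo
`vi * DEMOS_PER_VIEW + d` is the d-th demo of viewpoint `vi`.
Splits are based on which VIEWPOINTS are in train vs test.
All demos for a given viewpoint go together (no mixing).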

Usage:
    python create_viewpoint_splits.py --data_root /data/libero/ood_viewpoint_v2 --splits_root /data/libero/ood_viewpoint_v2_splits
"""
import argparse
import os
import shutil
from pathlib import Path
import numpy as np


def create_split(splits_root, split_name, data_root, viewpoint_indices, demos_per_view):
    """Create a symlink split directory for the given viewpoint indices.

    Builds ``<splits_root>/<split_name>_train/libero_spatial/task_0/``, removing
    any previous contents first. It then fills the directory with ``demo_<k>``
    symlinks, renumbered from 0, that point at the source demos of each
    requested viewpoint. The source demo for demo ``di`` of viewpoint ``vi`` is
    ``demo_<vi * demos_per_view + di>``.

    Missing source demos are skipped rather than treated as an error. The
    number skipped is printed so gaps in the data are visible.

    Args:
        splits_root: Root directory for all splits (``str`` or ``Path``).
        split_name: Split name; the created directory gets a ``_train`` suffix.
        data_root: Dataset root containing ``libero_spatial/task_0/demo_<i>``
            (``str`` or ``Path``).
        viewpoint_indices: Viewpoint indices to include (processed in sorted order).
        demos_per_view: Number of demos recorded per viewpoint.

    Returns:
        Number of demos linked into the split.
    """
    splits_root = Path(splits_root)
    data_root = Path(data_root)
    split_dir = splits_root / f"{split_name}_train" / "libero_spatial" / "task_0"
    if split_dir.exists():
        shutil.rmtree(split_dir)
    split_dir.mkdir(parents=True)

    src_dir = data_root / "libero_spatial" / "task_0"
    new_idx = 0
    missing = []
    for vi in sorted(viewpoint_indices):
        for di in range(demos_per_view):
            real_idx = vi * demos_per_view + di
            src = src_dir / f"demo_{real_idx}"
            if not src.exists():
                # Best-effort: skip, but remember it so the gap gets reported.
                missing.append(real_idx)
                continue
            # Absolute target so the link stays valid regardless of the cwd.
            os.symlink(str(src.resolve()), str(split_dir / f"demo_{new_idx}"))
            new_idx += 1

    print(f"  {split_name}: {len(viewpoint_indices)} viewpoints, {new_idx} demos -> {split_dir}")
    if missing:
        print(f"    WARNING: {len(missing)} source demo(s) missing and skipped "
              f"(first: demo_{missing[0]})")
    return new_idx


def _save_test_sample(rng, pool, path, max_size=10):
    """Draw up to ``max_size`` distinct viewpoints from ``pool`` and save them.

    The draw uses ``rng.choice(..., replace=False)`` on the shared RandomState,
    so the order of calls determines which viewpoints each split receives for a
    given ``--seed``. Pools smaller than ``max_size`` are taken in full.

    Args:
        rng: ``np.random.RandomState`` shared by all splits.
        pool: Sequence of candidate viewpoint indices.
        path: Destination ``.npy`` file; the sorted sample is written there.
        max_size: Upper bound on the number of viewpoints drawn.

    Returns:
        The sorted sample as a list of Python ints.
    """
    picked = np.sort(rng.choice(pool, size=min(max_size, len(pool)), replace=False))
    np.save(path, picked)
    return picked.tolist()


def main():
    """Parse CLI args, read the viewpoint grid metadata, and write every split.

    Writes under ``--splits_root``:
      * ``test_*.npy``: sorted held-out viewpoint indices (at most 10 per file),
        drawn with ``--seed``.
      * ``vp_*_train/libero_spatial/task_0/``: symlinked training demos
        (built by ``create_split``).
      * ``split_meta.npz``: grid size, demos per view, and the angle arrays.

    Viewpoint index ``vi`` maps to theta index ``vi // n_views`` and phi index
    ``vi % n_views``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", type=str, required=True)
    parser.add_argument("--splits_root", type=str, required=True)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    data_root = Path(args.data_root)
    splits_root = Path(args.splits_root)
    splits_root.mkdir(parents=True, exist_ok=True)
    rng = np.random.RandomState(args.seed)

    # Load viewpoint metadata; the context manager closes the .npz file handle.
    with np.load(data_root / "libero_spatial" / "task_0" / "viewpoint_meta.npz") as meta:
        n_views = int(meta["n_views"])
        demos_per_view = int(meta["demos_per_view"])
        thetas_deg = meta["thetas_deg"]
        phis_deg = meta["phis_deg"]
    n_viewpoints = n_views * n_views

    print(f"Viewpoint grid: {n_views}x{n_views} = {n_viewpoints} viewpoints")
    print(f"Demos per viewpoint: {demos_per_view}")
    print(f"Theta: {thetas_deg.round(1)}")
    print(f"Phi: {phis_deg.round(1)}")
    print(f"Total demos: {n_viewpoints * demos_per_view}")

    # Build viewpoint index: vi -> (theta_idx, phi_idx)
    vp_info = {}
    for vi in range(n_viewpoints):
        ti = vi // n_views
        pi = vi % n_views
        vp_info[vi] = {"ti": ti, "pi": pi, "theta": thetas_deg[ti], "phi": phis_deg[pi]}

    all_vi = list(range(n_viewpoints))

    # --- Fixed test viewpoints: up to 10 random viewpoints ---
    # Bounded like every other draw so grids with < 10 viewpoints don't crash.
    test_vis = _save_test_sample(rng, all_vi, splits_root / "test_viewpoints.npy")
    print(f"\nTest viewpoints ({len(test_vis)}): {test_vis}")

    # === RADIAL SPLITS (by theta) ===

    # Default only: train on theta index 0, test on all other rings.
    # All phis at theta=0 are kept for data even though (per the original
    # note) they are presumably the same camera pose.
    train = [vi for vi in all_vi if vp_info[vi]["ti"] == 0]
    test_default = [vi for vi in all_vi if vp_info[vi]["ti"] > 0]
    _save_test_sample(rng, test_default, splits_root / "test_default_only.npy")
    create_split(splits_root, "vp_default_only", data_root, train, demos_per_view)

    # Inner -> Outer: train on rings ti <= threshold, test on rings beyond it.
    mid_ti = n_views // 2
    for threshold_ti, label in [(1, "inner5"), (2, "inner10"), (mid_ti, "inner_half")]:
        train = [vi for vi in all_vi if vp_info[vi]["ti"] <= threshold_ti]
        test_pool = [vi for vi in all_vi if vp_info[vi]["ti"] > threshold_ti]
        _save_test_sample(rng, test_pool, splits_root / f"test_{label}.npy")
        create_split(splits_root, f"vp_{label}", data_root, train, demos_per_view)

    # Outer -> Inner: train on rings ti >= threshold, test on rings inside it.
    for threshold_ti, label in [(n_views - 2, "outer20"), (mid_ti, "outer_half")]:
        train = [vi for vi in all_vi if vp_info[vi]["ti"] >= threshold_ti]
        test_pool = [vi for vi in all_vi if vp_info[vi]["ti"] < threshold_ti]
        _save_test_sample(rng, test_pool, splits_root / f"test_{label}.npy")
        create_split(splits_root, f"vp_{label}", data_root, train, demos_per_view)

    # === AZIMUTHAL SPLITS ===
    # Train on the first 5/8 of the azimuth indices and test on the rest. With
    # 8 azimuths this is phis 0-4 train / 5-7 test. Deriving the cut from
    # n_views keeps every azimuth in exactly one side for other grid sizes.
    n_train_phis = max(1, round(n_views * 5 / 8))
    train_phis = set(range(n_train_phis))
    train = [vi for vi in all_vi if vp_info[vi]["pi"] in train_phis]
    test_pool = [vi for vi in all_vi if vp_info[vi]["pi"] not in train_phis]
    _save_test_sample(rng, test_pool, splits_root / "test_azimuthal.npy")
    create_split(splits_root, "vp_azimuthal", data_root, train, demos_per_view)

    # === COVERAGE SCALING ===
    # A single theta=0 viewpoint (the first one, vi=0).
    center = [vi for vi in all_vi if vp_info[vi]["ti"] == 0][:1]

    # N=1 (center only)
    create_split(splits_root, "vp_n1", data_root, center, demos_per_view)

    # N=1+n_views: center plus the entire outermost ring (ti == n_views-1).
    outer_ring = [vi for vi in all_vi if vp_info[vi]["ti"] == n_views - 1]
    create_split(splits_root, "vp_n_extremes", data_root, center + outer_ring, demos_per_view)

    # N=sparse (every other ring incl. theta=0, every other phi)
    train = [vi for vi in all_vi if vp_info[vi]["ti"] % 2 == 0 and vp_info[vi]["pi"] % 2 == 0]
    create_split(splits_root, "vp_sparse", data_root, train, demos_per_view)

    # N=all (full coverage)
    create_split(splits_root, "vp_all", data_root, all_vi, demos_per_view)

    # Save metadata
    np.savez(splits_root / "split_meta.npz",
             n_views=n_views, demos_per_view=demos_per_view,
             thetas_deg=thetas_deg, phis_deg=phis_deg)

    print("\nAll viewpoint splits created.")


# Build all splits only when run as a script, not when imported.
if __name__ == "__main__":
    main()
