"""Create train/test splits for OOD object position experiments.

Creates symlink directories for each experiment split, pointing back to
the base dataset. Also creates test position .npy files for eval.

Usage:
    python create_splits.py --data_root /data/libero/ood_objpos_centered --splits_root /data/libero/ood_objpos_centered_splits
"""
import argparse
import os
import shutil
from pathlib import Path
import numpy as np


def create_split(splits_root, split_name, data_root, demo_indices, grid_size=16):
    """Create a symlink split directory with sequential demo names.

    Builds ``<splits_root>/<split_name>_train/libero_spatial/task_0`` and fills
    it with symlinks ``demo_0 .. demo_{k-1}``. Each link points at the base
    dataset's ``<data_root>/libero_spatial/task_0/demo_<real_idx>``, with the
    real indices taken in ascending order. Any existing directory at that
    location is replaced.

    Args:
        splits_root: Directory (str or Path) under which splits are created.
        split_name: Split prefix; the directory is named ``f"{split_name}_train"``.
        data_root: Base dataset directory (str or Path).
        demo_indices: Iterable of base-dataset demo indices to include.
        grid_size: Unused; kept so existing callers keep working.

    Raises:
        ValueError: If ``demo_indices`` contains duplicates.
        FileNotFoundError: If any referenced source demo does not exist. Both
            checks run before the existing split is removed, so a failed call
            leaves any previous split untouched.
    """
    splits_root = Path(splits_root)
    data_root = Path(data_root)

    real_indices = sorted(demo_indices)
    if len(set(real_indices)) != len(real_indices):
        raise ValueError(f"{split_name}: duplicate demo indices in {real_indices}")

    src_task_dir = data_root / "libero_spatial" / "task_0"
    sources = [src_task_dir / f"demo_{real_idx}" for real_idx in real_indices]
    # os.symlink does not check its target, so a missing demo would otherwise
    # turn into a silently dangling link.
    missing = [str(src) for src in sources if not src.exists()]
    if missing:
        raise FileNotFoundError(f"{split_name}: source demos not found: {missing}")

    split_dir = splits_root / f"{split_name}_train" / "libero_spatial" / "task_0"
    if split_dir.exists():
        shutil.rmtree(split_dir)
    split_dir.mkdir(parents=True)

    # Renumber sequentially so the split looks like a contiguous dataset.
    for new_idx, src in enumerate(sources):
        os.symlink(str(src.resolve()), str(split_dir / f"demo_{new_idx}"))

    print(f"  {split_name}: {len(real_indices)} demos -> {split_dir}")


def main():
    """Build every experiment split plus the held-out test-position files.

    Reads ``grid_meta.npz`` (keys ``dx_vals``, ``dy_vals``, ``center_dx``,
    ``center_dy``) from ``<data_root>/libero_spatial/task_0``. Base demo ``k``
    is taken to be grid cell ``(i, j) = (k // Ny, k % Ny)``, i.e.
    ``k = i * Ny + j``. Its position is
    ``(center_dx + dx_vals[i], center_dy + dy_vals[j])``.

    Writes under ``splits_root``:
      * ``test_indices.npy`` / ``test_positions.npy``: 20 cells sampled from
        the whole grid.
      * ``exp3_far_test.npy`` / ``exp3_far_test_indices.npy``: 20 cells from
        rows ``i >= Nx // 2``.
      * ``exp3_right_test.npy`` / ``exp3_right_test_indices.npy``: 20 cells
        from columns ``j >= Ny // 2``.
      * one ``<name>_train`` symlink split per experiment, via ``create_split``.

    All random draws come from a single ``RandomState(seed)`` in a fixed
    order. Reordering the draws changes every split for a given seed.

    NOTE(review): training splits are not filtered against the sampled test
    cells, so they can overlap; confirm this is intended.

    Raises:
        ValueError: If the grid is smaller than 4x4 (the inner-square split
            needs a centred 4x4 block).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_root", type=str, required=True,
                        help="Path to base OOD dataset (e.g. /data/libero/ood_objpos_centered)")
    parser.add_argument("--splits_root", type=str, required=True,
                        help="Path to output splits directory")
    # Grid dimensions are read from grid_meta.npz; the flag is still accepted
    # so existing command lines keep working.
    parser.add_argument("--grid_size", type=int, default=16,
                        help="Unused; grid dimensions are read from grid_meta.npz")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    data_root = Path(args.data_root)
    splits_root = Path(args.splits_root)
    splits_root.mkdir(parents=True, exist_ok=True)
    rng = np.random.RandomState(args.seed)

    # Load grid metadata. The NpzFile holds an open file handle, so close it
    # as soon as the arrays are read.
    with np.load(data_root / "libero_spatial" / "task_0" / "grid_meta.npz") as meta:
        dx_vals = meta["dx_vals"]
        dy_vals = meta["dy_vals"]
        center_dx = float(meta["center_dx"])
        center_dy = float(meta["center_dy"])
    Nx = len(dx_vals)  # number of dx values (rows, index i)
    Ny = len(dy_vals)  # number of dy values (columns, index j); row stride of demo_idx = i * Ny + j

    # The inner-square split uses a centred 4x4 block. Smaller grids would
    # produce negative (nonexistent) demo indices.
    if Nx < 4 or Ny < 4:
        raise ValueError(f"Grid must be at least 4x4, got {Nx}x{Ny}")

    print(f"Grid: {Nx}x{Ny} = {Nx*Ny} demos, dx=[{dx_vals[0]:.3f}, {dx_vals[-1]:.3f}], dy=[{dy_vals[0]:.3f}, {dy_vals[-1]:.3f}]")
    print(f"Center offset: ({center_dx:+.4f}, {center_dy:+.4f})")

    def cell_positions(indices):
        """Map flat demo indices to an array of [center_dx + dx, center_dy + dy] rows."""
        return np.array([
            [center_dx + dx_vals[idx // Ny], center_dy + dy_vals[idx % Ny]]
            for idx in indices
        ])

    # --- Fixed test set: 20 uniformly sampled positions (RNG draw #1) ---
    all_indices = list(range(Nx * Ny))
    test_indices = sorted(rng.choice(all_indices, size=20, replace=False))
    np.save(splits_root / "test_indices.npy", np.array(test_indices))
    np.save(splits_root / "test_positions.npy", cell_positions(test_indices))
    print(f"\nTest set: {len(test_indices)} positions")
    # int() so NumPy >= 2 prints plain numbers rather than np.int64(...) reprs.
    print(f"  Indices: {[int(idx) for idx in test_indices]}")

    # --- Exp 1: Inner Square (center 4×4) ---
    inner_i = range(Nx // 2 - 2, Nx // 2 + 2)
    inner_j = range(Ny // 2 - 2, Ny // 2 + 2)
    exp1_indices = [i * Ny + j for i in inner_i for j in inner_j]
    create_split(splits_root, "exp1_inner", data_root, exp1_indices, Ny)

    # --- Exp 2: Random 10 (RNG draw #2) ---
    exp2_indices = sorted(rng.choice(all_indices, size=10, replace=False))
    create_split(splits_root, "exp2_random10", data_root, exp2_indices, Ny)

    # --- Exp 3a: Near → Far (train on rows i < Nx//2, test on the rest; RNG draw #3) ---
    exp3_near = [i * Ny + j for i in range(Nx // 2) for j in range(Ny)]
    exp3_far_indices = [i * Ny + j for i in range(Nx // 2, Nx) for j in range(Ny)]
    exp3_far_test = sorted(rng.choice(exp3_far_indices, size=20, replace=False))
    np.save(splits_root / "exp3_far_test.npy", cell_positions(exp3_far_test))
    np.save(splits_root / "exp3_far_test_indices.npy", np.array(exp3_far_test))
    create_split(splits_root, "exp3_near", data_root, exp3_near, Ny)

    # --- Exp 3b: Left → Right (train on columns j < Ny//2, test on the rest; RNG draw #4) ---
    exp3_left = [i * Ny + j for i in range(Nx) for j in range(Ny // 2)]
    exp3_right_indices = [i * Ny + j for i in range(Nx) for j in range(Ny // 2, Ny)]
    exp3_right_test = sorted(rng.choice(exp3_right_indices, size=20, replace=False))
    np.save(splits_root / "exp3_right_test.npy", cell_positions(exp3_right_test))
    np.save(splits_root / "exp3_right_test_indices.npy", np.array(exp3_right_test))
    create_split(splits_root, "exp3_left", data_root, exp3_left, Ny)

    # --- Exp 4: Corner scaling with evenly spaced lattices ---
    def evenly_spaced_grid(n_per_side):
        """Sorted flat indices of an n_per_side x n_per_side lattice spanning the grid corners."""
        i_indices = np.round(np.linspace(0, Nx - 1, n_per_side)).astype(int)
        j_indices = np.round(np.linspace(0, Ny - 1, n_per_side)).astype(int)
        return sorted(set(i * Ny + j for i in i_indices for j in j_indices))

    # Split names give nominal sizes. The lattices actually hold n_per_side**2
    # cells (4, 9, 16, 36, 64), or fewer if rounding merges lattice points.
    exp4_lattices = (
        ("exp4_n4", 2),
        ("exp4_n8", 3),
        ("exp4_n16", 4),
        ("exp4_n32", 6),
        ("exp4_n64", 8),
    )
    for split_name, n_per_side in exp4_lattices:
        create_split(splits_root, split_name, data_root, evenly_spaced_grid(n_per_side), Ny)

    print("\nAll splits created.")
    print("\nSummary:")
    for name in ["exp1_inner", "exp2_random10", "exp3_near", "exp3_left",
                 "exp4_n4", "exp4_n8", "exp4_n16", "exp4_n32", "exp4_n64"]:
        d = splits_root / f"{name}_train" / "libero_spatial" / "task_0"
        n = len(list(d.iterdir())) if d.exists() else 0
        print(f"  {name}: {n} demos")


if __name__ == "__main__":
    # Script entry point: build all splits; main() returns None, so exit status is 0.
    raise SystemExit(main())
