#!/bin/bash
# Run all OOD object position experiments on centered_v2 dataset
# Usage: CUDA_VISIBLE_DEVICES=4 bash run_experiments_v2.sh para
#    or: CUDA_VISIBLE_DEVICES=5 bash run_experiments_v2.sh act

# Abort on the first failing simple command. Note: pipefail is NOT enabled, so
# in "cmd | tee" pipelines only tee's exit status is checked.
set -e

# Validate the argument before touching the filesystem so the usage hint is
# shown even when the project directory is unavailable on this host.
MODEL=$1
if [ -z "$MODEL" ]; then echo "Usage: $0 <para|act>" >&2; exit 1; fi

# All relative paths below (train.py, eval.py, checkpoints/, logs/) resolve
# against this directory.
cd /data/cameron/para_normalized_losses/libero

# Prepend LIBERO; only add the ':' separator when PYTHONPATH already has a
# value, so no empty path entry is created.
export PYTHONPATH="/data/cameron/LIBERO${PYTHONPATH:+:$PYTHONPATH}"
export DINO_REPO_DIR=/data/cameron/keygrip/dinov3
export DINO_WEIGHTS_PATH=/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth

SPLITS_DIR=/data/libero/ood_objpos_v3_splits
RESULTS_FILE="logs/${MODEL}_v2_results.txt"
mkdir -p logs
# Start every invocation with an empty results file; later steps append to it.
: > "$RESULTS_FILE"

#######################################
# Evaluate one checkpoint at every (dx, dy) shift stored in a .npy file and
# print a single aggregated success line.
# Globals:   MODEL (read) - model type forwarded to eval.py
# Arguments: $1 - checkpoint path
#            $2 - .npy file of positions; each row is unpacked as (dx, dy)
#            $3 - label used as the prefix of the summary line
# Outputs:   stdout: "<label>: S/T = P%" or "<label>: no results"
#            stderr: one warning per position whose eval.py run failed or
#                    printed no parsable "Success Rate" line
# Returns:   exit status of the python3 driver
#######################################
eval_at_positions() {
    local ckpt=$1
    local positions_file=$2
    local label=$3

    # The driver is a single-quoted string (it must not contain single quotes)
    # and receives every value through sys.argv, so spaces or quotes in paths
    # and labels can neither break the Python source nor the eval.py argv.
    local py_driver='
import subprocess
import sys

import numpy as np

model, ckpt, positions_file, label = sys.argv[1:5]
n_episodes = 5  # episodes per position; also used to turn a rate into a count

successes, total, failed = 0, 0, 0
for dx, dy in np.load(positions_file):
    cmd = [
        "python", "eval.py",
        "--model_type", model,
        "--checkpoint", ckpt,
        "--benchmark", "libero_spatial",
        "--task_id", "0",
        "--n_episodes", str(n_episodes),
        "--teleport", "--zero_rotation", "--clean_scene",
        "--max_steps", "600",
        "--shift_dx", f"{dx}",
        "--shift_dy", f"{dy}",
        "--out_dir", f"/tmp/{model}_v2_eval",
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)

    counted = False
    for line in result.stdout.split("\n"):
        if "Success Rate" in line:
            try:
                rate = float(line.split(":")[1].strip().replace("%", "")) / 100
            except (IndexError, ValueError):
                continue  # malformed line: reported below if nothing else parses
            successes += int(round(rate * n_episodes))
            total += n_episodes
            counted = True

    if not counted:
        failed += 1
    if result.returncode != 0 or not counted:
        last_err = (result.stderr.strip().splitlines() or [""])[-1]
        print(
            f"[{label}] eval.py problem at dx={dx} dy={dy} "
            f"(exit {result.returncode}): {last_err}",
            file=sys.stderr,
        )

if total > 0:
    print(f"{label}: {successes}/{total} = {100*successes/total:.0f}%")
else:
    print(f"{label}: no results")
if failed:
    print(f"{label}: warning: {failed} position(s) produced no result", file=sys.stderr)
'

    python3 -c "$py_driver" "${MODEL:?MODEL must be set}" "$ckpt" "$positions_file" "$label"
}

#######################################
# Train one model from scratch on a cached split, then evaluate its best
# checkpoint at a set of test positions and append the summary to the results.
# Globals:   MODEL (read), RESULTS_FILE (appended to)
# Arguments: $1 - experiment name (run name becomes "${MODEL}_$1")
#            $2 - cache root passed to train.py --cache_root
#            $3 - .npy file of test positions passed to eval_at_positions
# Outputs:   banner and training output on stdout; summary line is tee'd into
#            RESULTS_FILE
# Returns:   0 on success; under `set -e` a failing train.py aborts the script
#######################################
run_experiment() {
    local exp_name=$1
    local cache_root=$2
    local test_file=$3
    local run_name="${MODEL:?MODEL must be set}_${exp_name}"
    local ckpt_dir="checkpoints/${run_name}"
    local ckpt="${ckpt_dir}/best.pth"

    echo "========================================"
    echo "  $MODEL — $exp_name"
    echo "========================================"

    # Start from a clean slate for this run.
    rm -rf -- "$ckpt_dir"

    # On a fresh checkout checkpoints/ may not exist yet; without it `find`
    # fails and `set -e` would silently end the whole batch.
    mkdir -p checkpoints
    # Removes dataset_stats.json anywhere under checkpoints/, i.e. also for
    # other runs. NOTE(review): if another model is being run concurrently
    # from this directory its stats are deleted too — confirm this is intended.
    # Kept best-effort so a cleanup hiccup does not abort the batch.
    find checkpoints -name "dataset_stats.json" -delete \
        || echo "warning: could not remove all dataset_stats.json files" >&2

    python train.py --model_type "$MODEL" --run_name "$run_name" \
        --benchmark libero_spatial --task_id 0 \
        --cache_root "$cache_root" \
        --batch_size 8 --lr 1e-4 --epochs 9999 --max_minutes 10 \
        --skip_rotation --vis_every_steps 0 \
        --wandb_project para_libero --wandb_mode online

    echo "--- Eval ---"
    # Without a checkpoint every per-position eval would fail; record that
    # once instead of spawning the evaluations.
    if [ ! -f "$ckpt" ]; then
        echo "${exp_name}: no results (checkpoint ${ckpt} not found)" | tee -a "$RESULTS_FILE"
        return 0
    fi

    # The pipeline status is tee's, so an eval failure does not stop the
    # remaining experiments.
    eval_at_positions "$ckpt" "$test_file" "$exp_name" | tee -a "$RESULTS_FILE"
}

# Experiment table. Each entry is "<name> <train split> <test positions>";
# the last two are resolved relative to SPLITS_DIR. Entries run in list order.
_v2_specs=(
    "v2_exp1_inner exp1_inner_train test_positions.npy"        # Exp 1: Inner square
    "v2_exp2_random10 exp2_random10_train test_positions.npy"  # Exp 2: Random 10
    "v2_exp3_near exp3_near_train exp3_far_test.npy"           # Exp 3a: Near → Far
    "v2_exp3_left exp3_left_train exp3_right_test.npy"         # Exp 3b: Left → Right
)

# Exp 4: Corner scaling — one run per training-set size.
for _count in 4 8 16 32 64; do
    _v2_specs+=("v2_exp4_n${_count} exp4_n${_count}_train test_positions.npy")
done

for _spec in "${_v2_specs[@]}"; do
    # A here-string feeds only `read`; run_experiment keeps the script's stdin.
    read -r _spec_name _spec_train _spec_test <<< "$_spec"
    run_experiment "$_spec_name" "${SPLITS_DIR}/${_spec_train}" "${SPLITS_DIR}/${_spec_test}"
done

# Exp 5: Distractor robustness — train on N=64 clean, eval WITH distractors
# Reuses the checkpoint trained in Exp 4 (n=64); no training happens here.
echo "========================================"
echo "  $MODEL — v2_distractor_robustness"
echo "========================================"
echo "--- Eval (using v2_exp4_n64 checkpoint, NO --clean_scene) ---"
CKPT="checkpoints/${MODEL}_v2_exp4_n64/best.pth"
if [ -f "$CKPT" ]; then
    # Values are handed to the driver via sys.argv (not spliced into the
    # source), and eval.py gets a real argv list, so unusual characters in
    # paths cannot break either layer. The driver must not contain single
    # quotes because it is a single-quoted shell string.
    python3 -c '
import subprocess
import sys

import numpy as np

model, ckpt, positions_file = sys.argv[1:4]
label = "v2_distractor_robustness"
n_episodes = 5  # episodes per position; also used to turn a rate into a count

successes, total, failed = 0, 0, 0
for dx, dy in np.load(positions_file):
    # Same flags as the clean evaluation except that --clean_scene is omitted.
    cmd = [
        "python", "eval.py",
        "--model_type", model,
        "--checkpoint", ckpt,
        "--benchmark", "libero_spatial",
        "--task_id", "0",
        "--n_episodes", str(n_episodes),
        "--teleport", "--zero_rotation",
        "--max_steps", "600",
        "--shift_dx", f"{dx}",
        "--shift_dy", f"{dy}",
        "--out_dir", f"/tmp/{model}_v2_distractor_eval",
    ]
    result = subprocess.run(cmd, capture_output=True, text=True)

    counted = False
    for line in result.stdout.split("\n"):
        if "Success Rate" in line:
            try:
                rate = float(line.split(":")[1].strip().replace("%", "")) / 100
            except (IndexError, ValueError):
                continue  # malformed line: reported below if nothing else parses
            successes += int(round(rate * n_episodes))
            total += n_episodes
            counted = True

    if not counted:
        failed += 1
    if result.returncode != 0 or not counted:
        last_err = (result.stderr.strip().splitlines() or [""])[-1]
        print(
            f"[{label}] eval.py problem at dx={dx} dy={dy} "
            f"(exit {result.returncode}): {last_err}",
            file=sys.stderr,
        )

if total > 0:
    print(f"{label}: {successes}/{total} = {100*successes/total:.0f}%")
else:
    print(f"{label}: no results")
if failed:
    print(f"{label}: warning: {failed} position(s) produced no result", file=sys.stderr)
' "$MODEL" "$CKPT" "${SPLITS_DIR}/test_positions.npy" | tee -a "$RESULTS_FILE"
else
    # Make the skip visible in both the console and the results file instead
    # of silently omitting the experiment from the final summary.
    echo "v2_distractor_robustness: no results (checkpoint $CKPT not found)" | tee -a "$RESULTS_FILE"
fi

# Final recap: a blank line, a banner naming the model, then every summary
# line collected in the results file during this invocation.
printf '\n'
printf '%s\n' "========================================"
printf '  ALL RESULTS — %s\n' "$MODEL"
printf '%s\n' "========================================"
cat "$RESULTS_FILE"
