#!/bin/bash
# Phase A — Object position OOD: 1view vs 2view
#
# Pipeline:
#   1. Wait for /data/libero/ood_objpos_task0_2view to finish (128 left-half demos).
#   2. Create symlink splits at /data/libero/ood_objpos_splits_2view/exp3_left_train.
#   3. Train DualParaPredictor on the 2view exp3_left split.
#   4. Eval both PARA (existing ckpt) and DualPara (new ckpt) at left + right test positions.
#   5. Print comparison table.

set -e
# Work from the project dir so the relative script paths used below
# (create_splits_2view.py, train.py, eval.py) resolve.
ROOT=/data/cameron/para_normalized_losses/libero
cd "$ROOT"

# Inputs: the 2view renders, their splits, and the original 1view splits.
DATA_2V=/data/libero/ood_objpos_task0_2view
SPLITS_2V=/data/libero/ood_objpos_splits_2view
SPLITS_1V=/data/libero/ood_objpos_splits

# Checkpoints: the dual-view model trained in Step 3 and the existing 1view PARA.
DUAL_NAME=dual_para_exp3_left_2v
DUAL_CKPT="$ROOT/checkpoints/${DUAL_NAME}/best.pth"
PARA_CKPT="$ROOT/checkpoints/para_v2_exp3_left/best.pth"

# Summary log; this first write truncates it so each run starts fresh.
RESULTS="$ROOT/logs/phase_a_objpos_results.txt"
mkdir -p -- "$ROOT/logs"
{
    echo "=== Phase A: OOD object position 1view vs 2view ==="
    echo "Started $(date)"
} | tee "$RESULTS"

# ----- Step 1: wait for rendering -----
# Poll once a minute until N_DEMOS demos exist, counting one wrist_w2c.npy
# file per demo.
#   PHASE_A_N_DEMOS          expected demo count (default 128)
#   PHASE_A_WAIT_TIMEOUT_MIN give up after this many minutes (default 0 = never)
N_DEMOS=${PHASE_A_N_DEMOS:-128}
WAIT_TIMEOUT_MIN=${PHASE_A_WAIT_TIMEOUT_MIN:-0}
echo "[Step 1] Waiting for ${N_DEMOS} 2view demos at $DATA_2V ..." | tee -a "$RESULTS"
wait_start=$SECONDS
while :; do
    # Count exactly once per poll so the logged value is the one tested.
    # `|| true`: $DATA_2V may not exist yet, and find then exits non-zero
    # (which would abort the script if pipefail were enabled).
    n_done=$(find "$DATA_2V" -name 'wrist_w2c.npy' 2>/dev/null | wc -l) || true
    if (( n_done >= N_DEMOS )); then
        break
    fi
    if (( WAIT_TIMEOUT_MIN > 0 && SECONDS - wait_start >= WAIT_TIMEOUT_MIN * 60 )); then
        echo "  $(date +%H:%M:%S) ✗ timed out after ${WAIT_TIMEOUT_MIN} min with ${n_done}/${N_DEMOS} demos" \
            | tee -a "$RESULTS" >&2
        exit 1
    fi
    echo "  $(date +%H:%M:%S) demos=${n_done}/${N_DEMOS}"
    sleep 60
done
echo "  $(date +%H:%M:%S) ✓ all ${N_DEMOS} demos rendered" | tee -a "$RESULTS"

# ----- Step 2: create splits -----
echo "[Step 2] Creating 2view splits ..." | tee -a "$RESULTS"
python create_splits_2view.py \
    --data_root "$DATA_2V" \
    --splits_root "$SPLITS_2V" 2>&1 | tee -a "$RESULTS"
# Without pipefail, `set -e` only sees tee's exit status, so check the python
# stage explicitly; otherwise a failed split would silently feed Step 3.
split_status=${PIPESTATUS[0]}
if (( split_status != 0 )); then
    echo "[Step 2] ✗ create_splits_2view.py failed (exit ${split_status})" | tee -a "$RESULTS" >&2
    exit "$split_status"
fi

# ----- Step 3: train dual-view -----
echo "[Step 3] Training DualParaPredictor on 2view exp3_left ..." | tee -a "$RESULTS"
# `:?` aborts instead of wiping the wrong directory if a name is ever empty.
rm -rf -- "${ROOT:?}/checkpoints/${DUAL_NAME:?}"
# NOTE(review): this removes dataset_stats.json from EVERY checkpoint dir,
# including the PARA checkpoint dir that Step 4 evaluates. Confirm that
# eval.py does not need that file for the existing PARA checkpoint.
find "$ROOT/checkpoints" -name "dataset_stats.json" -delete 2>/dev/null || true

# GPU for training; override with PHASE_A_GPU (check nvidia-smi for a free one).
GPU_ID=${PHASE_A_GPU:-1}
CUDA_VISIBLE_DEVICES=$GPU_ID \
    DINO_REPO_DIR=/data/cameron/dinov3 \
    DINO_WEIGHTS_PATH=/data/cameron/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth \
    python train.py --model_type dual_para --run_name "$DUAL_NAME" \
        --benchmark libero_spatial --task_id 0 \
        --cache_root "$SPLITS_2V/exp3_left_train" \
        --batch_size 8 --lr 1e-4 --epochs 9999 --max_minutes 60 \
        --skip_rotation --vis_every_steps 0 \
        --wandb_project para_libero --wandb_mode online 2>&1 | tee "$ROOT/logs/${DUAL_NAME}.log"
# Without pipefail, a train.py failure is hidden behind tee's exit status.
train_status=${PIPESTATUS[0]}
if (( train_status != 0 )); then
    echo "[Step 3] ✗ train.py failed (exit ${train_status}); see $ROOT/logs/${DUAL_NAME}.log" \
        | tee -a "$RESULTS" >&2
    exit "$train_status"
fi
# Step 4 evaluates this checkpoint; fail here rather than producing empty results.
if [[ ! -f "$DUAL_CKPT" ]]; then
    echo "[Step 3] ✗ training finished but $DUAL_CKPT is missing" | tee -a "$RESULTS" >&2
    exit 1
fi

echo "[Step 3] ✓ Training done" | tee -a "$RESULTS"

# ----- Step 4: eval both at left + right -----
#######################################
# Evaluate one checkpoint at every object-position shift in a split file and
# print a one-line summary "LABEL: successes/total = pct%".
# Arguments:
#   $1 - model type passed to eval.py --model_type (e.g. para, dual_para)
#   $2 - checkpoint path passed to eval.py --checkpoint
#   $3 - .npy file of (dx, dy) shifts; each row is one eval.py run
#   $4 - label printed at the start of the summary line
#   $5 - episodes per position (optional, default 5)
# Outputs:
#   The last two lines of the helper's combined stdout/stderr: normally the
#   summary line (with a failed-position count appended when any run produced
#   no success rate), or the tail of a Python traceback.
#######################################
eval_at() {
    local model=$1
    local ckpt=$2
    local pos_file=$3
    local label=$4
    local n_episodes=${5:-5}
    # Values reach Python via sys.argv instead of being spliced into the source
    # text, so quotes or backslashes in paths/labels cannot break the program.
    python3 -c '
import re
import subprocess
import sys

import numpy as np

model, ckpt, pos_file, label, n_episodes = sys.argv[1:6]
n_episodes = int(n_episodes)
# atleast_2d: a split holding a single (dx, dy) pair still iterates as one row.
positions = np.atleast_2d(np.load(pos_file))
successes, total, failed = 0, 0, 0
for dx, dy in positions:
    cmd = [
        "python", "eval.py",
        "--model_type", model, "--checkpoint", ckpt,
        "--benchmark", "libero_spatial", "--task_id", "0",
        "--n_episodes", str(n_episodes),
        "--teleport", "--zero_rotation", "--clean_scene",
        "--max_steps", "600",
        "--shift_dx", str(dx), "--shift_dy", str(dy),
        "--out_dir", "/tmp/phase_a_eval",
    ]
    r = subprocess.run(cmd, capture_output=True, text=True)
    # Assumes eval.py reports "... Success Rate: <pct>%" on stdout (TODO confirm).
    # Only the last such line is used, so one run never adds more than
    # n_episodes to the total.
    rate_lines = [ln for ln in r.stdout.splitlines() if "Success Rate" in ln]
    match = None
    if rate_lines:
        tail = rate_lines[-1].split(":", 1)[-1]
        match = re.search(r"(\d+(?:\.\d+)?)\s*%", tail) or re.search(r"(\d+(?:\.\d+)?)", tail)
    if match is None:
        # No parsable rate: count the position as failed instead of dropping it silently.
        failed += 1
        continue
    rate = float(match.group(1)) / 100
    successes += int(round(rate * n_episodes))
    total += n_episodes
pct = 100 * successes / total if total > 0 else 0
note = f" ({failed}/{len(positions)} positions failed)" if failed else ""
print(f"{label}: {successes}/{total} = {pct:.0f}%{note}")
' "$model" "$ckpt" "$pos_file" "$label" "$n_episodes" 2>&1 | tail -2
}

echo "[Step 4] Eval matrix:" | tee -a "$RESULTS"

# One row per evaluation, executed in order. "|"-separated fields:
#   header | model type | checkpoint | test-position file | result label
# "left" positions are the in-distribution test, "right" positions the OOD test.
eval_rows=(
    "1v PARA   @ left  ...|para|$PARA_CKPT|$SPLITS_1V/exp3_left_test.npy|1v_PARA_left"
    "1v PARA   @ right ...|para|$PARA_CKPT|$SPLITS_1V/exp3_right_test.npy|1v_PARA_right"
    "2v DUAL   @ left  ...|dual_para|$DUAL_CKPT|$SPLITS_1V/exp3_left_test.npy|2v_DUAL_left"
    "2v DUAL   @ right ...|dual_para|$DUAL_CKPT|$SPLITS_1V/exp3_right_test.npy|2v_DUAL_right"
)
for row in "${eval_rows[@]}"; do
    # IFS holds only "|", so spaces inside the header are preserved verbatim.
    IFS='|' read -r header model ckpt pos_file label <<<"$row"
    echo "  ${header}" | tee -a "$RESULTS"
    eval_at "$model" "$ckpt" "$pos_file" "$label" | tee -a "$RESULTS"
done

echo "=== Phase A complete $(date) ===" | tee -a "$RESULTS"
