#!/bin/bash
# Run all 4 cells of the Phase A object-position eval matrix in parallel:
# {1v PARA, 2v dual} × {left, right} test positions.
# Each cell: 20 positions × 5 episodes (one eval.py run per position).

# -e: abort on unhandled errors; -u: catch misspelled/unset variables;
# pipefail: a failing stage fails the whole pipeline.
set -euo pipefail
cd /data/cameron/para_normalized_losses/libero

readonly PARA_CKPT=/data/cameron/para_normalized_losses/libero/checkpoints/para_exp3_left_60min/latest.pth
readonly DUAL_CKPT=/data/cameron/para_normalized_losses/libero/checkpoints/dual_para_exp3_left_2v/latest.pth
readonly LEFT_POS=/data/libero/ood_objpos_splits/exp3_left_test.npy
readonly RIGHT_POS=/data/libero/ood_objpos_splits/exp3_right_test.npy

# Fail fast: a missing checkpoint or position file would otherwise only
# show up inside the per-cell logs after all jobs have been launched.
for required_path in "$PARA_CKPT" "$DUAL_CKPT" "$LEFT_POS" "$RIGHT_POS"; do
    if [[ ! -f "$required_path" ]]; then
        echo "error: required input not found: $required_path" >&2
        exit 1
    fi
done

# Per-cell logs go here; `mkdir -p` also creates the parent logs/ dir.
readonly LOG_DIR=/data/cameron/para_normalized_losses/libero/logs/phase_a_eval
mkdir -p "$LOG_DIR"

#######################################
# Evaluate one cell of the matrix: one checkpoint on every object position
# listed in a .npy file, invoking eval.py once per position.
# Globals:   LOG_DIR (read) - all output goes to "$LOG_DIR/<label>.log"
# Arguments: $1 - model type, passed to eval.py --model_type
#            $2 - checkpoint path, passed to eval.py --checkpoint
#            $3 - .npy file whose rows are (dx, dy) object shifts
#            $4 - cell label (log file name, /tmp/phase_a_<label> out dir)
#            $5 - GPU index, exported as CUDA_VISIBLE_DEVICES
#            $6 - (optional) episodes per position, default 5
# Outputs:   per-position progress plus a "<label> FINAL: s/t = p%" line
#            in the cell log
# Returns:   exit status of the Python driver
#######################################
run_cell() {
    local model=$1
    local ckpt=$2
    local pos_file=$3
    local label=$4
    local gpu=$5
    local n_episodes=${6:-5}

    # Values reach Python via argv (not string interpolation into the
    # source), so paths/labels containing quotes cannot break the program.
    PYTHONPATH=/data/cameron/LIBERO \
    LIBERO_DATA_PATH=/data/libero \
    CUDA_VISIBLE_DEVICES=$gpu \
    DINO_REPO_DIR=/data/cameron/keygrip/dinov3 \
    DINO_WEIGHTS_PATH=/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth \
    MUJOCO_GL=osmesa PYOPENGL_PLATFORM=osmesa \
    python3 - "$model" "$ckpt" "$pos_file" "$label" "$n_episodes" \
        > "$LOG_DIR/${label}.log" 2>&1 <<'PY'
import os, subprocess, sys
import numpy as np

model, ckpt, pos_file, label, n_episodes = sys.argv[1:6]
n_episodes = int(n_episodes)
positions = np.load(pos_file)
print(f'Evaluating {len(positions)} positions for [{label}]', flush=True)
successes, total, failed = 0, 0, 0
for k, (dx, dy) in enumerate(positions):
    cmd = [
        # Run eval.py under this same interpreter instead of whatever
        # `python` happens to resolve to on PATH.
        sys.executable, 'eval.py',
        '--model_type', model, '--checkpoint', ckpt,
        '--benchmark', 'libero_spatial', '--task_id', '0',
        '--n_episodes', str(n_episodes), '--teleport', '--zero_rotation', '--clean_scene',
        '--max_steps', '600',
        '--shift_dx', str(dx), '--shift_dy', str(dy),
        '--out_dir', f'/tmp/phase_a_{label}',
    ]
    r = subprocess.run(cmd, capture_output=True, text=True, env=os.environ,
                       stdin=subprocess.DEVNULL)
    cell_succ, cell_total = 0, 0
    for line in r.stdout.splitlines():
        if 'Success Rate' in line:
            # eval.py reports a percentage; convert back to an episode count.
            rate = float(line.split(':')[1].strip().replace('%', '')) / 100
            cell_succ = int(round(rate * n_episodes))
            cell_total = n_episodes
            successes += cell_succ
            total += cell_total
            break
    print(f'  pos[{k+1}/{len(positions)}] dx={dx:+.3f} dy={dy:+.3f}: {cell_succ}/{cell_total} (cumulative: {successes}/{total})', flush=True)
    if cell_total == 0:
        # Surface crashes instead of silently logging 0/0; such positions
        # are excluded from the totals, so they must be visible here.
        failed += 1
        print(f'    WARNING: eval.py exited with code {r.returncode} and printed no "Success Rate" line; stderr tail:', flush=True)
        for err_line in r.stderr.splitlines()[-20:]:
            print(f'      {err_line}', flush=True)
note = f' ({failed} of {len(positions)} positions produced no result)' if failed else ''
print(f'{label} FINAL: {successes}/{total} = {100*successes/total if total>0 else 0:.0f}%{note}')
PY
}

# Launch the 4 cells concurrently, one GPU each (GPUs 0, 2, 3, 4).
# NOTE(review): GPU 1 is skipped — presumably in use elsewhere; confirm.
pids=()
run_cell para      "$PARA_CKPT" "$LEFT_POS"  "1v_PARA_left"  0 & pids+=("$!")
run_cell para      "$PARA_CKPT" "$RIGHT_POS" "1v_PARA_right" 2 & pids+=("$!")
run_cell dual_para "$DUAL_CKPT" "$LEFT_POS"  "2v_DUAL_left"  3 & pids+=("$!")
run_cell dual_para "$DUAL_CKPT" "$RIGHT_POS" "2v_DUAL_right" 4 & pids+=("$!")

echo "Launched cells: PIDs=${pids[*]}"

# Reap each job on its own: `wait p1 p2 p3 p4` reports only the LAST pid's
# status, so under `set -e` a cell-4 failure would abort before the summary
# while failures in cells 1-3 would be ignored. `if ! wait` is exempt from -e.
failed_cells=0
for pid in "${pids[@]}"; do
    if ! wait "$pid"; then
        failed_cells=$((failed_cells + 1))
    fi
done

echo ""
echo "=========== Phase A Object Position Eval Results ==========="
for cell in 1v_PARA_left 1v_PARA_right 2v_DUAL_left 2v_DUAL_right; do
    grep "FINAL:" "$LOG_DIR/${cell}.log" 2>/dev/null || echo "$cell: missing"
done

# Report failure only after the summary has been printed.
if (( failed_cells > 0 )); then
    echo "$failed_cells cell(s) exited with an error; see $LOG_DIR" >&2
    exit 1
fi
