#!/usr/bin/env python3
"""eval_translation_grid_fast.py — Fast 3x3 translation eval, 1 episode each."""
import json
import os
import subprocess
import sys
from pathlib import Path

import numpy as np

# Command-line usage: CHECKPOINT EXPERIMENT_NAME GPU
# Fail with a readable usage message instead of a bare IndexError traceback.
if len(sys.argv) < 4:
    sys.exit(f"usage: {sys.argv[0]} CHECKPOINT EXPERIMENT_NAME GPU")
CHECKPOINT = sys.argv[1]
EXPERIMENT_NAME = sys.argv[2]
try:
    GPU = int(sys.argv[3])  # exposed to eval.py as CUDA_VISIBLE_DEVICES
except ValueError:
    sys.exit(f"GPU must be an integer device index, got {sys.argv[3]!r}")
N_EPISODES = 1  # episodes per grid cell, forwarded as --n_episodes

# Camera offsets swept by the grid (forwarded as --cam_dx / --cam_dy).
dx_vals = np.linspace(-0.10, 0.10, 3)  # 3 points
dy_vals = np.linspace(-0.075, 0.075, 3)  # 3 points
# Object shift applied to every grid cell (forwarded as --shift_dx / --shift_dy).
center_dx, center_dy = 0.0509, -0.2063

results = {}  # viewpoint index -> {"dx", "dy", "rate", "di", "dj"}
out_root = Path(f"results/{EXPERIMENT_NAME}")
out_root.mkdir(parents=True, exist_ok=True)

# Run eval.py once per (dx, dy) camera offset, in row-major order:
# viewpoint index vi == di * len(dy_vals) + dj.
vi = 0
total = len(dx_vals) * len(dy_vals)
for di, dx in enumerate(dx_vals):
    for dj, dy in enumerate(dy_vals):
        # The object shift is identical for every cell; only the camera moves.
        shift_dx = center_dx
        shift_dy = center_dy

        # Build argv as a list (no shell) so CHECKPOINT / EXPERIMENT_NAME are
        # passed verbatim even if they contain spaces or shell metacharacters.
        cmd = [
            "python", "eval.py", "--model_type", "act",
            "--checkpoint", f"{CHECKPOINT}",
            "--benchmark", "libero_spatial", "--task_id", "0",
            "--n_episodes", f"{N_EPISODES}",
            "--teleport", "--zero_rotation", "--clean_scene",
            "--max_steps", "600",
            "--shift_dx", f"{shift_dx}", "--shift_dy", f"{shift_dy}",
            "--cam_dx", f"{dx}", "--cam_dy", f"{dy}", "--cam_dz", "0.0",
            "--out_dir", f"{out_root}/vp_{vi}", "--save_video",
        ]
        env = {**os.environ, "CUDA_VISIBLE_DEVICES": f"{GPU}"}
        result = subprocess.run(cmd, env=env, capture_output=True, text=True)

        # Surface crashes instead of silently recording them as 0% success.
        if result.returncode != 0:
            tail = "\n".join(result.stderr.strip().splitlines()[-5:])
            print(f"  [warn] vp_{vi}: eval.py exited with code "
                  f"{result.returncode}\n{tail}")

        # The last parseable "Success Rate: NN%" line wins; a malformed line is
        # reported rather than aborting the whole sweep.
        rate = 0
        parsed = False
        for line in result.stdout.splitlines():
            if 'Success Rate' not in line:
                continue
            try:
                rate = float(line.split(':')[1].strip().replace('%', '')) / 100
                parsed = True
            except (IndexError, ValueError):
                print(f"  [warn] vp_{vi}: could not parse {line.strip()!r}")
        if not parsed:
            print(f"  [warn] vp_{vi}: no 'Success Rate' found; recording 0%")

        results[vi] = {"dx": float(dx), "dy": float(dy), "rate": rate, "di": di, "dj": dj}
        vi += 1
        status = "OK" if rate > 0 else "FAIL"
        print(f"  [{vi}/{total}] dx={dx:.3f} dy={dy:.3f} → {rate*100:.0f}% {status}",
              flush=True)

# json converts the integer viewpoint keys to strings on disk.
with open(out_root / "grid_results.json", "w") as f:
    json.dump(results, f, indent=2)

# Print a dx (rows) x dy (columns) table of success rates, then the overall
# episode-weighted success rate across all cells.
n_cols = len(dy_vals)
header = "dx\\dy " + "  ".join(f"{d:6.3f}" for d in dy_vals)
print("\n=== Translation Grid ===")
print(header)
total_s, total_n = 0, 0
for di, dx in enumerate(dx_vals):
    cells = []
    for dj in range(n_cols):
        # results is keyed row-major: index = di * n_cols + dj
        rate = results[di * n_cols + dj]["rate"]
        total_s += rate * N_EPISODES
        total_n += N_EPISODES
        cells.append(f"  {rate*100:4.0f}%")
    print(f"{dx:6.3f}" + "".join(cells))
print(f"\nOverall: {total_s/total_n*100:.0f}%")
