#!/usr/bin/env python3
"""eval_translation_grid.py — Evaluate a checkpoint over a grid of camera translations (dx, dy, dz)."""
import subprocess, numpy as np, json, os, sys
from pathlib import Path

CHECKPOINT = sys.argv[1] if len(sys.argv) > 1 else "checkpoints/act_defvp_crop50_nokp/best.pth"
EXPERIMENT_NAME = sys.argv[2] if len(sys.argv) > 2 else "translation_test"
GPU = int(sys.argv[3]) if len(sys.argv) > 3 else 4
N_EPISODES = 3

# Camera translation grid, in meters, expressed in camera-local coordinates.
# The horizontal (dx) and vertical (dy) axes are swept; depth (dz) is held
# at a single value for now.
dx_vals = np.linspace(start=-0.15, stop=0.15, num=5)  # left-right offsets
dy_vals = np.linspace(start=-0.10, stop=0.10, num=5)  # up-down offsets
dz_val = 0.0  # no depth change

# Passed unchanged to eval.py as --shift_dx / --shift_dy at every grid point.
center_dx = 0.0509
center_dy = -0.2063

def _parse_success_rate(stdout: str) -> "float | None":
    """Return the last 'Success Rate: NN%' value found in *stdout* as a fraction.

    Returns None when no line containing 'Success Rate' could be parsed, so the
    caller can tell "eval.py reported nothing" apart from a genuine 0 % result.
    """
    rate = None
    for line in stdout.splitlines():
        if 'Success Rate' not in line:
            continue
        try:
            rate = float(line.split(':')[1].strip().replace('%', '')) / 100
        except (IndexError, ValueError):
            # Unexpected formatting on this line; keep any value parsed earlier
            # instead of aborting the whole sweep.
            continue
    return rate


# results maps the running viewpoint index to that grid point's offsets,
# parsed success rate and (di, dj) grid coordinates.
results = {}
out_root = Path(f"results/{EXPERIMENT_NAME}")
out_root.mkdir(parents=True, exist_ok=True)

# Select the GPU through the child's environment instead of a shell prefix, so
# the command can be passed as an argument list (no shell, no quoting issues).
eval_env = {**os.environ, "CUDA_VISIBLE_DEVICES": str(GPU)}

vi = 0
total = len(dx_vals) * len(dy_vals)
for di, dx in enumerate(dx_vals):
    for dj, dy in enumerate(dy_vals):
        # Fixed object position (centered)
        shift_dx = center_dx
        shift_dy = center_dy

        # Argument list rather than a shell string: CHECKPOINT and
        # EXPERIMENT_NAME come from sys.argv and may contain spaces or shell
        # metacharacters. sys.executable runs eval.py with the same
        # interpreter that is running this script.
        cmd = [
            sys.executable, "eval.py", "--model_type", "act",
            "--checkpoint", str(CHECKPOINT),
            "--benchmark", "libero_spatial", "--task_id", "0",
            "--n_episodes", str(N_EPISODES),
            "--teleport", "--zero_rotation", "--clean_scene",
            "--max_steps", "600",
            "--shift_dx", str(shift_dx), "--shift_dy", str(shift_dy),
            "--cam_dx", str(dx), "--cam_dy", str(dy), "--cam_dz", str(dz_val),
            "--out_dir", f"{out_root}/vp_{vi}", "--save_video",
        ]

        result = subprocess.run(cmd, env=eval_env, capture_output=True, text=True)
        rate = _parse_success_rate(result.stdout)
        if result.returncode != 0 or rate is None:
            # Surface failed runs: otherwise a crashed eval.py is
            # indistinguishable from a policy that never succeeds.
            reason = ("no parsable 'Success Rate' line" if rate is None
                      else "non-zero exit code")
            last_err = (result.stderr.strip().splitlines() or [""])[-1]
            print(f"  WARNING: vp_{vi} ({reason}, returncode={result.returncode}): "
                  f"{last_err}", file=sys.stderr)
        if rate is None:
            rate = 0  # still recorded as a 0 % grid point

        results[vi] = {"dx": float(dx), "dy": float(dy), "dz": float(dz_val),
                       "rate": rate, "di": di, "dj": dj}
        vi += 1
        print(f"  [{vi}/{total}] dx={dx:.3f} dy={dy:.3f} → {rate*100:.0f}%")
        sys.stdout.flush()

# Persist the per-viewpoint results as pretty-printed JSON in the output dir.
results_path = out_root / "grid_results.json"
with open(results_path, "w") as handle:
    handle.write(json.dumps(results, indent=2))

# Report: one table row per dx value, one column per dy value, followed by
# the success rate pooled over every episode of every grid point.
print("\n=== Translation Grid (camera-local meters) ===")
header_cells = [f"{d:6.3f}" for d in dy_vals]
print("dx\\dy " + "  ".join(header_cells))

n_cols = len(dy_vals)
total_s, total_n = 0, 0
for row_idx, dx in enumerate(dx_vals):
    cells = [f"{dx:6.3f}"]
    for col_idx in range(n_cols):
        # Results were stored in row-major order under a running index.
        rate = results[row_idx * n_cols + col_idx]["rate"]
        total_s += rate * N_EPISODES
        total_n += N_EPISODES
        cells.append(f"{rate*100:4.0f}%")
    print("  ".join(cells))

print(f"\nOverall: {total_s/total_n*100:.0f}%")
