#!/usr/bin/env python3
"""eval_full_grid.py — Evaluate at all 64 viewpoints, save results + videos."""
import subprocess, numpy as np, json, os, sys
from pathlib import Path

# Config — override via command line args or edit here
CHECKPOINT = sys.argv[1] if len(sys.argv) > 1 else "/data/cameron/para_normalized_losses/libero/checkpoints/act_v2_exp4_n64/best.pth"
EXPERIMENT_NAME = sys.argv[2] if len(sys.argv) > 2 else "act_baseline"
N_EPISODES = 3          # per viewpoint
GPU = int(sys.argv[3]) if len(sys.argv) > 3 else 4
SAVE_VIDEOS = True

n_views = 8
thetas = np.linspace(0, 25, n_views)
phis = np.linspace(0, 360*(1-1/n_views), n_views)
center_dx, center_dy = 0.0509, -0.2063
dx_min, dx_max = -0.40, -0.01
dy_min, dy_max = -0.30, 0.30

rng = np.random.RandomState(42)
results = {}
out_root = Path(f"results/{EXPERIMENT_NAME}")
out_root.mkdir(parents=True, exist_ok=True)

# Evaluate every (theta, phi) cell of the camera grid.  Each viewpoint runs in
# its own eval.py subprocess with a freshly sampled object offset.
n_viewpoints = n_views * n_views  # 64 for the default 8x8 grid


def _parse_success_rate(stdout):
    """Return the last ``Success Rate: NN%`` value in *stdout* as a fraction in [0, 1].

    Returns None when no such line is present or the value cannot be parsed,
    so the caller can tell "eval produced no result" apart from a genuine 0%.
    """
    rate = None
    for line in stdout.splitlines():
        if 'Success Rate' not in line:
            continue
        try:
            rate = float(line.split(':')[1].strip().replace('%', '')) / 100
        except (IndexError, ValueError):
            # Unexpected format: keep any earlier parse rather than aborting the grid.
            pass
    return rate


for vi in range(n_viewpoints):
    ti, pi = divmod(vi, n_views)
    theta, phi = thetas[ti], phis[pi]

    # Random object position (drawn in the same order as before, so the
    # seeded sequence of placements is unchanged).
    dx = rng.uniform(dx_min, dx_max)
    dy = rng.uniform(dy_min, dy_max)
    shift_dx = center_dx + dx
    shift_dy = center_dy + dy

    # Argument list + explicit env instead of a shell string: CHECKPOINT and
    # EXPERIMENT_NAME come from argv and may contain spaces or shell
    # metacharacters.
    cmd = ["python", "eval.py", "--model_type", "act",
           "--checkpoint", str(CHECKPOINT),
           "--benchmark", "libero_spatial", "--task_id", "0",
           "--n_episodes", str(N_EPISODES),
           "--teleport", "--zero_rotation", "--clean_scene",
           "--max_steps", "600",
           "--shift_dx", str(shift_dx), "--shift_dy", str(shift_dy),
           "--cam_theta", str(theta), "--cam_phi", str(phi),
           "--out_dir", f"{out_root}/vp_{vi}"]
    if SAVE_VIDEOS:
        cmd.append("--save_video")
    env = {**os.environ, "CUDA_VISIBLE_DEVICES": str(GPU)}

    result = subprocess.run(cmd, env=env, capture_output=True, text=True)
    parsed = _parse_success_rate(result.stdout)
    if result.returncode != 0 or parsed is None:
        # Surface crashes / missing output instead of silently logging 0%.
        stderr_tail = "\n".join(result.stderr.strip().splitlines()[-5:])
        print(f"  WARNING vp_{vi}: eval.py returncode={result.returncode}, "
              f"success rate {'missing' if parsed is None else 'parsed'}"
              f"\n{stderr_tail}", file=sys.stderr, flush=True)
    rate = parsed if parsed is not None else 0.0

    results[vi] = {"theta": float(theta), "phi": float(phi),
                   "rate": rate, "ti": ti, "pi": pi,
                   "shift_dx": float(shift_dx), "shift_dy": float(shift_dy),
                   "returncode": result.returncode}

    # Checkpoint after every viewpoint so a crash late in the grid keeps
    # the results already collected.
    with open(out_root / "grid_results.json", "w") as f:
        json.dump(results, f, indent=2)

    print(f"  [{vi+1}/{n_viewpoints}] θ={theta:.1f}° φ={phi:.1f}° → {rate*100:.0f}%",
          flush=True)

# Persist the full per-viewpoint results dict next to the eval outputs.
results_path = out_root / "grid_results.json"
results_path.write_text(json.dumps(results, indent=2))

# Console summary: a theta (rows) x phi (columns) table of success rates,
# followed by the overall episode-weighted rate and the mean rate per theta.
print("\n=== Results Grid ===")
header_cells = [f"{p:5.0f}" for p in phis]
print("θ\\φ   " + "  ".join(header_cells))

total_s, total_n = 0, 0
per_theta = {}
for row_idx, theta in enumerate(thetas):
    first_vp = row_idx * n_views
    row_rates = [results[first_vp + col]["rate"] for col in range(n_views)]
    for r in row_rates:
        total_s += r * N_EPISODES
        total_n += N_EPISODES
    per_theta[theta] = np.mean(row_rates)
    cells = "".join(f"  {r*100:4.0f}%" for r in row_rates)
    print(f"{theta:5.1f}" + cells)

print(f"\nOverall: {total_s/total_n*100:.0f}%")
theta_summaries = [f"{t:.0f}°={r*100:.0f}%" for t, r in per_theta.items()]
print("Per-θ: " + "  ".join(theta_summaries))
