"""Closed-loop LIBERO eval for the voxel-token AR variants (B abs / C rel).

Mirrors eval_ar_closed_loop.py but uses model_voxel_ar.VoxelARPolicyAbs / VoxelARPolicyRel.
Per step:
  - get RGB + EEF pos/quat
  - project EEF → pixel
  - one DINO patch pass on the new frame
  - push (patches, EEF pixel) into a ring buffer of the past H frames
  - run voxel_builder(current_frame_patches, cam_K, cam_extrinsic, anchor) → voxels
    (anchor = episode-start EEF position for the rel variant, None for abs)
  - run ar_head(history_patches, history_eef_xy, voxel_feats) → 7-DoF heads
  - decode → 7-DoF OSC_POSE action → env.step

Usage:
  CUDA_VISIBLE_DEVICES=9 MUJOCO_GL=egl PYTHONPATH=/data/cameron/LIBERO:$PYTHONPATH \
  DINO_REPO_DIR=... DINO_WEIGHTS_PATH=... \
  python eval_voxel_closed_loop.py \
    --checkpoint checkpoints/B_voxel_abs_pilot/best.pth --variant abs \
    --benchmark libero_spatial --task_id 0 \
    --n_episodes 1 --max_steps 600 --teleport --zero_rotation \
    --save_video --out_dir out/eval_B_single
"""
import argparse, json, os, sys, time
from pathlib import Path

import cv2
import h5py
import numpy as np
import torch
from tqdm import tqdm
from scipy.spatial.transform import Rotation as ScipyR

sys.path.insert(0, os.path.dirname(__file__))

from model_voxel_ar import (
    VoxelARPolicyAbs, VoxelARPolicyRel, IMAGE_SIZE, N_HEIGHT_BINS, N_ROT_BINS,
)
from utils import recover_3d_from_direct_keypoint_and_height
from eval import preprocess_obs, get_camera_params  # reuse helpers

from libero.libero import benchmark as bm_mod, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import project_points_from_world_to_camera


# World-z range covered by the height bins (presumably metres — TODO confirm
# against the training-time discretisation).
MIN_HEIGHT = 0.85
MAX_HEIGHT = 1.55
# Per-axis Euler range covered by the rotation bins. 3.14159 (not np.pi) is
# kept as-is, presumably to match the training binning.
MIN_ROT = np.full(3, -3.14159)
MAX_ROT = -MIN_ROT
# Divisors that map a position delta / rotation-vector delta into the
# [-1, 1] normalised OSC_POSE command range (see decode_action).
OSC_POS_SCALE = 0.05
OSC_ROT_SCALE = 0.5


def decode_action(out, grid_size, camera_pose, cam_K, current_eef_pos, current_eef_quat,
                  zero_rotation, action_scale, max_delta=0.05):
    """Convert the AR head's logits into a single 7-DoF OSC_POSE action.

    The xy grid cell, height bin and (unless ``zero_rotation``) per-axis
    rotation bins are decoded by argmax. The cell centre plus height is lifted
    to 3D with ``recover_3d_from_direct_keypoint_and_height``, and the action is
    the scaled, norm-clamped, normalised delta from the current EEF pose.

    Args:
        out: dict of head outputs; reads ``xy_logits``, ``height_logits``,
            ``gripper_logit`` and, unless ``zero_rotation``, ``rotation_logits``.
            The ``.item()`` calls require a single-sample batch.
        grid_size: side length G of the G x G pixel grid indexed by ``xy_logits``.
        camera_pose: camera extrinsic forwarded to the 3D recovery helper.
        cam_K: camera intrinsics forwarded to the 3D recovery helper.
        current_eef_pos: current end-effector position (3,).
        current_eef_quat: current end-effector quaternion, passed to
            ``ScipyR.from_quat`` (scalar-last by default).
        zero_rotation: when True the rotation command is all zeros.
        action_scale: multiplier on the raw position delta before clamping.
        max_delta: upper bound on the norm of the scaled position delta.

    Returns:
        Tuple ``(action, pred_3d, pred_px)``: a float32 array of 7 values
        (pos[3], rotvec[3], gripper), the decoded 3D target, and the decoded
        pixel centre ``[px, py]``.
    """
    n_cells = grid_size
    cell_px = IMAGE_SIZE / n_cells
    flat_idx = int(out["xy_logits"].argmax(dim=-1).item())
    row, col = divmod(flat_idx, n_cells)
    # Use the centre of the chosen grid cell as the keypoint.
    px = (col + 0.5) * cell_px
    py = (row + 0.5) * cell_px

    height_bin = int(out["height_logits"].argmax(dim=-1).item())
    height_frac = height_bin / max(N_HEIGHT_BINS - 1, 1)
    height = height_frac * (MAX_HEIGHT - MIN_HEIGHT) + MIN_HEIGHT

    target = recover_3d_from_direct_keypoint_and_height(
        np.array([px, py], dtype=np.float64), height, camera_pose, cam_K)
    if target is None:
        # Lifting failed: hold the current position (zero translation).
        target = current_eef_pos.copy()

    step = (target - current_eef_pos) * action_scale
    step_len = np.linalg.norm(step)
    if step_len > max_delta:
        step = step / step_len * max_delta
    pos_cmd = np.clip(step / OSC_POS_SCALE, -1.0, 1.0)

    if zero_rotation:
        rot_cmd = np.zeros(3, dtype=np.float64)
    else:
        per_axis = out["rotation_logits"][0]
        span = MAX_ROT - MIN_ROT
        rot_den = max(N_ROT_BINS - 1, 1)
        euler = np.empty(3, dtype=np.float64)
        for axis_idx in range(3):
            chosen = per_axis[axis_idx].argmax().item()
            euler[axis_idx] = (chosen / rot_den) * span[axis_idx] + MIN_ROT[axis_idx]
        # Rotation that takes the current orientation to the predicted one.
        relative = ScipyR.from_euler('xyz', euler) * ScipyR.from_quat(current_eef_quat).inv()
        rot_cmd = np.clip(relative.as_rotvec() / OSC_ROT_SCALE, -1.0, 1.0)

    gripper_cmd = 1.0 if float(out["gripper_logit"][0].cpu()) > 0.0 else -1.0

    action = np.zeros(7, dtype=np.float32)
    action[0:3], action[3:6], action[6] = pos_cmd, rot_cmd, gripper_cmd
    return action, target, np.array([px, py])


class PatchRollingCache:
    """Fixed-length history of per-frame patch tensors and EEF pixel coords.

    Storage is ``patches`` of shape (1, H, Np, D) and ``eef_xy`` of shape
    (1, H, 2), ordered oldest → newest. ``fill`` counts how many slots hold
    real data; once it reaches H, each push evicts the oldest entry.
    """

    def __init__(self, H, Np, D, device):
        self.H = H
        self.patches = torch.zeros(1, H, Np, D, device=device)
        self.eef_xy = torch.zeros(1, H, 2, device=device)
        self.fill = 0

    def push(self, new_patches, new_eef_xy):
        """Append one frame (``new_patches[0]``) and its EEF pixel as the newest entry."""
        if new_eef_xy.dim() == 1:
            new_eef_xy = new_eef_xy.view(1, 2)
        if self.fill >= self.H:
            # Full: rotate everything one slot towards the front (in place, so the
            # buffers keep their identity); the last slot is overwritten below.
            self.patches.copy_(self.patches.roll(-1, dims=1))
            self.eef_xy.copy_(self.eef_xy.roll(-1, dims=1))
            slot = self.H - 1
        else:
            slot = self.fill
            self.fill += 1
        self.patches[0, slot] = new_patches[0]
        self.eef_xy[0, slot] = new_eef_xy[0]

    def window(self):
        """Return ``(patches, eef_xy)`` covering all H slots.

        While the buffer is not yet full, a padded copy is returned in which every
        unfilled slot repeats the newest filled entry (slot 0 if nothing was pushed).
        Once full, the live buffers themselves are returned.
        """
        if self.fill >= self.H:
            return self.patches, self.eef_xy
        padded_p = self.patches.clone()
        padded_e = self.eef_xy.clone()
        src = max(0, self.fill - 1)
        # Read the source from the un-cloned buffer so it never overlaps the write target.
        padded_p[0, self.fill:] = self.patches[0, src]
        padded_e[0, self.fill:] = self.eef_xy[0, src]
        return padded_p, padded_e


def load_voxel_model(ckpt_path, variant, device, history_len=8, grid_size=56,
                     voxel_xy=28, voxel_z=16):
    """Build a voxel AR policy and load checkpoint weights into it.

    Args:
        ckpt_path: path to a ``torch.save`` checkpoint. Weights are read from its
            ``"model_state_dict"`` entry, or from the checkpoint itself if that key
            is absent (i.e. a bare state dict).
        variant: ``"abs"`` → ``VoxelARPolicyAbs``, ``"rel"`` → ``VoxelARPolicyRel``.
        device: torch device for the model and for ``map_location``.
        history_len, grid_size, voxel_xy, voxel_z: architecture hyper-parameters;
            these must match the ones the checkpoint was trained with.

    Returns:
        The model in eval mode.

    Raises:
        ValueError: if ``variant`` is neither ``"abs"`` nor ``"rel"``.
    """
    # Validate up front: previously any unknown string silently selected the Rel model.
    if variant not in ("abs", "rel"):
        raise ValueError(f"unknown variant {variant!r}; expected 'abs' or 'rel'")
    Cls = VoxelARPolicyAbs if variant == "abs" else VoxelARPolicyRel
    model = Cls(target_size=IMAGE_SIZE, history_len=history_len, grid_size=grid_size,
                voxel_xy=voxel_xy, voxel_z=voxel_z, freeze_backbone=True,
                min_height=MIN_HEIGHT, max_height=MAX_HEIGHT).to(device)
    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which may
    # reject checkpoints that pickle non-tensor objects — confirm if loading fails.
    ckpt = torch.load(ckpt_path, map_location=device)
    sd = ckpt.get("model_state_dict", ckpt)
    # Non-strict on purpose: report mismatches (with sample names) instead of failing.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing:
        print(f"⚠ missing keys: {len(missing)} (e.g. {list(missing)[:5]})")
    if unexpected:
        print(f"⚠ unexpected keys: {len(unexpected)} (e.g. {list(unexpected)[:5]})")
    model.eval()
    return model


def _parse_args():
    """Parse the CLI for one closed-loop evaluation run and validate counts."""
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint", type=str, required=True)
    p.add_argument("--variant", type=str, required=True, choices=["abs", "rel"])
    p.add_argument("--benchmark", type=str, default="libero_spatial")
    p.add_argument("--task_id", type=int, default=0)
    p.add_argument("--camera", type=str, default="agentview")
    p.add_argument("--n_episodes", type=int, default=1)
    p.add_argument("--max_steps", type=int, default=600)
    p.add_argument("--history_len", type=int, default=8)
    p.add_argument("--grid_size", type=int, default=56)
    p.add_argument("--voxel_xy", type=int, default=28)
    p.add_argument("--voxel_z", type=int, default=16)
    p.add_argument("--teleport", action="store_true")
    p.add_argument("--zero_rotation", action="store_true")
    p.add_argument("--action_scale", type=float, default=1.0)
    p.add_argument("--save_video", action="store_true")
    p.add_argument("--save_video_max", type=int, default=4)
    p.add_argument("--out_dir", type=str, default="out/eval_voxel_closed_loop")
    p.add_argument("--seed", type=int, default=0)
    args = p.parse_args()
    # Zero episodes/steps would leave the summary statistics undefined.
    if args.n_episodes < 1:
        p.error("--n_episodes must be >= 1")
    if args.max_steps < 1:
        p.error("--max_steps must be >= 1")
    return args


def _load_init_states(demo_path):
    """Return the first recorded sim state of every ``demo_*`` group in the HDF5 file."""
    with h5py.File(demo_path, "r") as f:
        keys = sorted(k for k in f["data"].keys() if k.startswith("demo_"))
        return [f[f"data/{k}/states"][0] for k in keys]


def _draw_overlay(rgb_obs, pred_px, eef_xy, label):
    """Return a vertically flipped copy of ``rgb_obs`` annotated with the predicted
    pixel (green cross), the current EEF pixel (white dot) and a text ``label``."""
    vis = np.flipud(rgb_obs.astype(np.uint8)).copy()
    px_int = int(np.clip(pred_px[0], 0, IMAGE_SIZE - 1))
    py_int = int(np.clip(pred_px[1], 0, IMAGE_SIZE - 1))
    cv2.drawMarker(vis, (px_int, py_int), (0, 255, 0), cv2.MARKER_CROSS, 18, 2, cv2.LINE_AA)
    eu = int(eef_xy[0].item()); ev = int(eef_xy[1].item())
    cv2.circle(vis, (eu, ev), 6, (255, 255, 255), -1)
    # Thick light text underneath thin dark text, for legibility on any background.
    cv2.putText(vis, label, (8, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.putText(vis, label, (8, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (20, 20, 20), 1, cv2.LINE_AA)
    return vis


def _save_video(frames, vpath):
    """Write ``frames`` to ``vpath`` as a 30 fps libx264 video; return True on success."""
    try:
        import imageio.v2 as imageio
        imageio.mimwrite(str(vpath), frames, fps=30, codec="libx264", quality=8)
    except Exception as e:  # best-effort: a missing codec/ffmpeg must not abort the eval
        print(f"  video save failed: {e}")
        return False
    return True


def _mean_pred_jerk(trajectories):
    """Mean over episodes of the average L2 norm of the second finite difference
    of the predicted-pixel trajectory; episodes with < 3 steps are skipped.
    Returns None if no episode qualifies."""
    jerks = [
        float(np.linalg.norm(tr[2:] - 2 * tr[1:-1] + tr[:-2], axis=-1).mean())
        for tr in trajectories if tr.shape[0] >= 3
    ]
    return float(np.mean(jerks)) if jerks else None


def _run_episode(env, model, args, init_state, ep, Np, D, device, record):
    """Roll out one episode from ``init_state``.

    Returns:
        ``(success, n_steps, trajectory, frames)``: ``success`` is True when
        ``env.step`` reported ``done``; ``trajectory`` is an (n_steps, 2) array of
        predicted pixels; ``frames`` is a list of annotated images when ``record``
        is set, else None.
    """
    env.reset()
    obs = env.set_init_state(init_state)
    for _ in range(5):
        # A few no-op steps before the policy takes over.
        obs, _, _, _ = env.step(np.zeros(7, dtype=np.float32))

    world_to_camera, camera_pose, cam_K = get_camera_params(env.sim, args.camera, IMAGE_SIZE)
    # Static-per-episode tensors for the voxel builder.
    cam_K_t = torch.tensor(cam_K, dtype=torch.float32, device=device).unsqueeze(0)
    cam_E_t = torch.tensor(camera_pose, dtype=torch.float32, device=device).unsqueeze(0)
    # The rel variant anchors voxels at the EEF position at episode start; abs uses no anchor.
    eef_start = np.array(obs["robot0_eef_pos"], dtype=np.float64)
    anchor_t = torch.tensor(eef_start, dtype=torch.float32, device=device).unsqueeze(0)
    use_anchor = anchor_t if args.variant == "rel" else None

    cache = PatchRollingCache(args.history_len, Np, D, device)
    frames = [] if record else None
    trajectory = []
    success = False
    n_steps = 0
    with torch.no_grad():
        for step_idx in range(args.max_steps):
            current_eef_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)
            current_eef_quat = np.array(obs["robot0_eef_quat"], dtype=np.float64)
            rgb_obs = obs[f"{args.camera}_image"]

            # Projection yields (row, col); the model consumes (x, y) = (col, row).
            pix_rc = project_points_from_world_to_camera(
                current_eef_pos.reshape(1, 3), world_to_camera, IMAGE_SIZE, IMAGE_SIZE)[0]
            eef_xy = torch.tensor(
                [np.clip(pix_rc[1], 0, IMAGE_SIZE - 1),
                 np.clip(pix_rc[0], 0, IMAGE_SIZE - 1)],
                dtype=torch.float32, device=device)

            img_tensor = preprocess_obs(rgb_obs, IMAGE_SIZE).to(device).unsqueeze(1)
            new_patches = model.patch_encoder(img_tensor)[0]
            cache.push(new_patches, eef_xy)
            hist_p, hist_e = cache.window()
            # Voxel features come from the most recent frame only.
            current_patches = hist_p[:, -1]
            voxel_feats, _ = model.voxel_builder(current_patches, cam_K_t, cam_E_t, use_anchor)
            out = model.ar_head(hist_p, hist_e, voxel_feats, IMAGE_SIZE)

            action, _, pred_px = decode_action(
                out, args.grid_size, camera_pose, cam_K,
                current_eef_pos, current_eef_quat,
                args.zero_rotation, args.action_scale)
            trajectory.append(pred_px)

            if frames is not None:
                label = f"voxel-{args.variant} ep{ep} step{step_idx} g={action[6]:+.0f}"
                frames.append(_draw_overlay(rgb_obs, pred_px, eef_xy, label))

            obs, _, done, _ = env.step(action)
            # Explicit counter: never relies on the loop variable surviving the loop.
            n_steps = step_idx + 1
            if done:
                success = True
                break

    traj = np.stack(trajectory, axis=0) if trajectory else np.zeros((0, 2))
    return success, n_steps, traj, frames


def main():
    """Run the closed-loop eval for one task and write a JSON summary (plus optional videos)."""
    args = _parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    out_dir = Path(args.out_dir); out_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading voxel-{args.variant} model: {args.checkpoint}")
    model = load_voxel_model(args.checkpoint, args.variant, device,
                             args.history_len, args.grid_size, args.voxel_xy, args.voxel_z)

    bench = bm_mod.get_benchmark_dict()[args.benchmark]()
    task = bench.get_task(args.task_id)
    print(f"Task: {task.name}")
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(args.task_id))
    init_states = _load_init_states(demo_path)
    n_eps = min(args.n_episodes, len(init_states))
    if n_eps == 0:
        raise RuntimeError(f"no demo_* initial states found in {demo_path}")
    if args.teleport:
        # The flag is accepted and recorded in the results, but nothing in this script acts on it.
        print("⚠ --teleport is recorded in the results but has no effect in this script")

    bddl = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
    env = OffScreenRenderEnv(bddl_file_name=bddl, camera_heights=IMAGE_SIZE, camera_widths=IMAGE_SIZE,
                             camera_names=[args.camera])
    successes = []; step_counts = []; trajectories = []
    try:
        env.seed(args.seed); env.reset()

        Np = model.patch_encoder.n_patches
        D = model.patch_encoder.embed_dim

        video_dir = out_dir / "videos"
        if args.save_video:
            video_dir.mkdir(exist_ok=True)
        n_videos_saved = 0
        t_start = time.time()

        for ep in tqdm(range(n_eps), desc="Eps"):
            record = args.save_video and (n_videos_saved < args.save_video_max)
            success, n_steps, traj, frames = _run_episode(
                env, model, args, init_states[ep], ep, Np, D, device, record)

            if frames:
                vpath = video_dir / f"ep{ep:03d}_{'success' if success else 'fail'}.mp4"
                if _save_video(frames, vpath):
                    n_videos_saved += 1

            successes.append(float(success))
            step_counts.append(n_steps)
            trajectories.append(traj)
    finally:
        # Release the renderer / simulator even if the rollout raises.
        env.close()

    elapsed = time.time() - t_start
    sr = float(np.mean(successes))

    results = {
        "checkpoint": args.checkpoint, "variant": args.variant,
        "benchmark": args.benchmark, "task_id": args.task_id,
        "n_episodes": n_eps, "teleport": args.teleport, "zero_rotation": args.zero_rotation,
        "successes": successes, "success_rate": sr,
        "step_counts": step_counts, "avg_steps": float(np.mean(step_counts)),
        "pred_jerk_mean_px": _mean_pred_jerk(trajectories),
        "elapsed_sec": elapsed,
    }
    out_file = out_dir / f"eval_{args.benchmark}_task{args.task_id}.json"
    with open(out_file, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\nSR={100*sr:.1f}%  avg_steps={np.mean(step_counts):.1f}  pred_jerk={results['pred_jerk_mean_px']}")
    print(f"Saved: {out_file}")


if __name__ == "__main__":
    main()
