"""Closed-loop LIBERO eval for the v2 AR policy with 7-DoF heads.

Per step: cache new frame patches → ARHead forward → decode (xy, height, rot, grip)
→ 7-DoF OSC_POSE action → env.step. True closed loop, one action per call (no window).

Usage:
  cd /data/cameron/para/libero
  CUDA_VISIBLE_DEVICES=9 \
  DINO_REPO_DIR=/data/cameron/keygrip/dinov3 \
  DINO_WEIGHTS_PATH=/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth \
  python eval_ar_closed_loop.py \
    --checkpoint checkpoints/ar_v2_libero_spatial_t0_w20h8/best.pth \
    --benchmark libero_spatial --task_id 0 \
    --n_episodes 20 --max_steps 600 \
    --teleport --zero_rotation \
    --out_dir out/eval_ar_closed_loop
"""
import argparse, json, os, sys, time
from pathlib import Path

import cv2
import h5py
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from scipy.spatial.transform import Rotation as ScipyR

sys.path.insert(0, os.path.dirname(__file__))

from model_autoregressive_v2 import (
    ARTransformerPolicyV2, RolloutCache, IMAGE_SIZE, N_HEIGHT_BINS, N_ROT_BINS,
)
from utils import recover_3d_from_direct_keypoint_and_height
# Reuse helpers from eval.py for env-side things
from eval import (
    preprocess_obs, get_camera_params, eef_to_start_kp,
)

from libero.libero import benchmark as bm_mod, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import project_points_from_world_to_camera


# Decoding ranges; these must match the MIN/MAX defaults in train_ar_v2.py.
# Height bin 0 maps to MIN_HEIGHT and the last bin to MAX_HEIGHT.
MIN_HEIGHT, MAX_HEIGHT = 0.85, 1.55
# Per-axis Euler range used to turn rotation bins back into angles.
MIN_ROT = np.full(3, -3.14159)
MAX_ROT = np.full(3, 3.14159)
# Divisors that turn position / rotation-vector deltas into normalized
# OSC_POSE action components (clipped to [-1, 1] after division).
OSC_POS_SCALE = 0.05
OSC_ROT_SCALE = 0.5
# Same value as the --grid_size CLI default.
GRID_SIZE = 56


def decode_ar_action(out, model, cache_eef_xy, camera_pose, cam_K,
                     current_eef_pos, current_eef_quat,
                     zero_rotation=False, teleport=False, max_delta=0.05,
                     action_scale=1.0):
    """Convert ARHead outputs (dict of logits) → 7-D OSC_POSE action.

    Args:
        out: dict from model.ar_head(...) — xy_logits (1, G^2), height_logits (1, N_H),
            gripper_logit (1,), rotation_logits (1, 3, N_R).
        model: policy object; only ``model.ar_head.grid_size`` is read.
        cache_eef_xy: unused; kept so existing call sites keep working.
        camera_pose, cam_K: forwarded unchanged to
            ``recover_3d_from_direct_keypoint_and_height``.
        current_eef_pos: current end-effector position as a length-3 numpy array.
        current_eef_quat: current end-effector orientation, passed to
            ``scipy.spatial.transform.Rotation.from_quat`` (scalar-last x, y, z, w).
        zero_rotation: if True, the rotation part of the action is all zeros and
            the rotation logits are not read.
        teleport: accepted for compatibility. Both settings currently produce the
            same clipped position step, so this flag has no effect here.
        max_delta: cap on the Euclidean norm of the position step (applied after
            ``action_scale``, before normalization).
        action_scale: multiplier on ``pred_3d - current_eef_pos``.

    Returns:
        Tuple ``(action, pred_3d, pred_px)``: the float32 action of shape (7,)
        laid out as [dx, dy, dz, rx, ry, rz, gripper]; the predicted 3D target
        (falls back to a copy of ``current_eef_pos`` when recovery returns None);
        and the predicted pixel ``[px, py]`` at the centre of the argmax grid cell.
    """
    grid = model.ar_head.grid_size
    cell = IMAGE_SIZE / grid

    # Flat argmax index → (row, col) grid cell → pixel at the cell centre.
    xy_idx = int(out["xy_logits"].argmax(dim=-1).item())
    gy, gx = divmod(xy_idx, grid)
    px = (gx + 0.5) * cell
    py = (gy + 0.5) * cell

    # Height bin → value, linearly spaced over [MIN_HEIGHT, MAX_HEIGHT].
    h_bin = int(out["height_logits"].argmax(dim=-1).item())
    height = (h_bin / max(N_HEIGHT_BINS - 1, 1)) * (MAX_HEIGHT - MIN_HEIGHT) + MIN_HEIGHT

    pred_3d = recover_3d_from_direct_keypoint_and_height(
        np.array([px, py], dtype=np.float64), height, camera_pose, cam_K,
    )
    if pred_3d is None:
        # No 3D point recovered: target the current position (zero translation).
        pred_3d = current_eef_pos.copy()

    # Position step. The former teleport / non-teleport branches were identical,
    # so a single path is used: scale, cap the norm at max_delta, then normalize
    # by OSC_POS_SCALE and clip each component to [-1, 1].
    delta_pos = (pred_3d - current_eef_pos) * action_scale
    norm = np.linalg.norm(delta_pos)
    if norm > max_delta:
        delta_pos = delta_pos / norm * max_delta
    delta_norm = np.clip(delta_pos / OSC_POS_SCALE, -1.0, 1.0)

    if zero_rotation:
        delta_rot_norm = np.zeros(3, dtype=np.float64)
    else:
        rot_logits = out["rotation_logits"][0]  # (3, N_R)
        # Per-axis argmax bin → Euler angle, linearly spaced over [MIN_ROT, MAX_ROT].
        euler_pred = np.array([
            (rot_logits[axis].argmax().item() / max(N_ROT_BINS - 1, 1))
            * (MAX_ROT[axis] - MIN_ROT[axis]) + MIN_ROT[axis]
            for axis in range(3)
        ])
        R_pred = ScipyR.from_euler('xyz', euler_pred)
        R_current = ScipyR.from_quat(current_eef_quat)
        # Rotation that takes the current orientation to the predicted one.
        R_delta = R_pred * R_current.inv()
        delta_rot_norm = np.clip(R_delta.as_rotvec() / OSC_ROT_SCALE, -1.0, 1.0)

    # Positive gripper logit → +1.0 command, otherwise -1.0.
    g_logit = float(out["gripper_logit"][0].cpu())
    gripper_cmd = 1.0 if g_logit > 0.0 else -1.0

    action = np.zeros(7, dtype=np.float32)
    action[:3] = delta_norm
    action[3:6] = delta_rot_norm
    action[6] = gripper_cmd
    return action, pred_3d, np.array([px, py])


def load_model(ckpt_path, device, history_len=8, grid_size=56):
    """Build an ARTransformerPolicyV2 on ``device`` and load checkpoint weights.

    The checkpoint may be either a dict holding the weights under
    ``"model_state_dict"`` or a bare state dict. Loading is non-strict: key
    mismatches are reported as counts on stdout rather than raising. The
    returned model is in eval mode.
    """
    policy = ARTransformerPolicyV2(
        target_size=IMAGE_SIZE,
        history_len=history_len,
        grid_size=grid_size,
        freeze_backbone=True,
    )
    policy = policy.to(device)

    checkpoint = torch.load(ckpt_path, map_location=device)
    # Fall back to treating the whole checkpoint as the state dict.
    state_dict = checkpoint.get("model_state_dict", checkpoint)

    missing_keys, unexpected_keys = policy.load_state_dict(state_dict, strict=False)
    if missing_keys:
        print(f"⚠ missing keys (random init): {len(missing_keys)}")
    if unexpected_keys:
        print(f"⚠ unexpected keys (ignored): {len(unexpected_keys)}")

    policy.eval()
    return policy


def main() -> None:
    """Evaluate an AR v2 checkpoint closed-loop on a single LIBERO task.

    Parses CLI arguments, loads the policy, and starts each rollout from the
    first simulator state of a demo (optionally shifted via ``--shift_dx`` /
    ``--shift_dy`` or a ``--positions_file``). At every step the current frame
    is encoded and pushed into a rolling history cache, the AR head output is
    decoded into one 7-D action, and that action is applied to the environment.
    A rollout counts as a success when ``env.step`` reports ``done`` before
    ``--max_steps`` steps have been taken.

    Side effects: writes ``eval_<benchmark>_task<task_id>.json`` (and, with
    ``--save_video``, per-rollout MP4s under ``videos/``) into ``--out_dir``
    and prints a short summary to stdout.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint", type=str, required=True)
    p.add_argument("--benchmark", type=str, default="libero_spatial")
    p.add_argument("--task_id", type=int, default=0)
    p.add_argument("--camera", type=str, default="agentview")
    p.add_argument("--n_episodes", type=int, default=20)
    p.add_argument("--max_steps", type=int, default=600)
    p.add_argument("--history_len", type=int, default=8)
    p.add_argument("--grid_size", type=int, default=56)
    p.add_argument("--teleport", action="store_true")
    p.add_argument("--zero_rotation", action="store_true")
    p.add_argument("--shift_dx", type=float, default=0.0, help="Per-episode object-position shift")
    p.add_argument("--shift_dy", type=float, default=0.0)
    p.add_argument("--positions_file", type=str, default="",
                   help="If set, .npy with (N, 2) (dx, dy) test positions. Runs n_episodes per position; SR averaged.")
    p.add_argument("--action_scale", type=float, default=1.0,
                   help="Multiplier on the delta (pred_3d - current_eef) before OSC normalization. "
                        "Use ~10 to amplify directional predictions when model trained on next-frame targets.")
    p.add_argument("--save_video", action="store_true",
                   help="Save MP4 per episode with frame + predicted target overlay.")
    p.add_argument("--save_video_max", type=int, default=4,
                   help="Cap on how many episodes to save videos for (avoid filling disk).")
    p.add_argument("--out_dir", type=str, default="out/eval_ar_closed_loop")
    p.add_argument("--seed", type=int, default=0)
    args = p.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading model: {args.checkpoint}")
    model = load_model(args.checkpoint, device, args.history_len, args.grid_size)

    bench = bm_mod.get_benchmark_dict()[args.benchmark]()
    task = bench.get_task(args.task_id)
    print(f"Task: {task.name}")

    # Each demo's first recorded simulator state is used as a rollout start state.
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(args.task_id))
    with h5py.File(demo_path, "r") as f:
        keys = sorted(k for k in f["data"].keys() if k.startswith("demo_"))
        init_states = [f[f"data/{k}/states"][0] for k in keys]
    n_eps = min(args.n_episodes, len(init_states))

    bddl = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
    env = OffScreenRenderEnv(
        bddl_file_name=bddl, camera_heights=IMAGE_SIZE, camera_widths=IMAGE_SIZE,
        camera_names=[args.camera],
    )

    successes = []
    step_counts = []
    trajectories = []  # predicted-pixel streams, used for the jerk diagnostic below
    t_start = time.time()

    # The environment holds simulator/render resources, so always release it,
    # including when a rollout raises.
    try:
        env.seed(args.seed)
        env.reset()

        Np = model.patch_encoder.n_patches
        D = model.patch_encoder.embed_dim

        # Build the (ep_idx, dx, dy) plan. With a positions file, every position is
        # run with n_eps start states (len(positions) * n_eps rollouts in total);
        # otherwise each start state is run once with the CLI shift.
        if args.positions_file:
            positions = np.load(args.positions_file)
            plan = [(ep, float(dx), float(dy)) for dx, dy in positions for ep in range(n_eps)]
            print(f"Position-OOD eval: {len(positions)} positions × {n_eps} eps = {len(plan)} rollouts")
        else:
            plan = [(ep, float(args.shift_dx), float(args.shift_dy)) for ep in range(n_eps)]

        video_dir = out_dir / "videos"
        if args.save_video:
            video_dir.mkdir(exist_ok=True)
        n_videos_saved = 0

        for plan_i, (ep, dx, dy) in enumerate(tqdm(plan, desc="Rollouts")):
            env.reset()
            init_state = init_states[ep].copy()
            if abs(dx) > 1e-6 or abs(dy) > 1e-6:
                # NOTE(review): this shifts entries 0 and 1 of the flattened state
                # vector. Confirm these are the object x/y slots for this task and
                # not, e.g., the time stamp or a robot joint.
                init_state[0] += dx
                init_state[1] += dy
            save_this = args.save_video and (n_videos_saved < args.save_video_max)
            frames = [] if save_this else None
            obs = env.set_init_state(init_state)
            # Apply a few all-zero actions before the policy takes over
            # (presumably to let the scene settle — confirm).
            for _ in range(5):
                obs, _, _, _ = env.step(np.zeros(7, dtype=np.float32))

            world_to_camera, camera_pose, cam_K = get_camera_params(env.sim, args.camera, IMAGE_SIZE)

            cache = RolloutCache(args.history_len, Np, D, device)
            ep_trajectory = []
            done = False
            success = False
            # Counted explicitly so the value is well defined (0) when
            # --max_steps <= 0 and never leaks from a previous rollout.
            n_steps = 0

            with torch.no_grad():
                for step_idx in range(args.max_steps):
                    n_steps = step_idx + 1
                    current_eef_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)
                    current_eef_quat = np.array(obs["robot0_eef_quat"], dtype=np.float64)
                    rgb_obs = obs[f"{args.camera}_image"]

                    # Project EEF → pixel (for history token). The projection result
                    # is treated as (row, col) and stored as (x, y), clamped to the image.
                    pix_rc = project_points_from_world_to_camera(
                        current_eef_pos.reshape(1, 3), world_to_camera, IMAGE_SIZE, IMAGE_SIZE,
                    )[0]
                    eef_xy_tensor = torch.tensor(
                        [np.clip(pix_rc[1], 0, IMAGE_SIZE - 1),
                         np.clip(pix_rc[0], 0, IMAGE_SIZE - 1)],
                        dtype=torch.float32, device=device,
                    )

                    img_tensor = preprocess_obs(rgb_obs, IMAGE_SIZE).to(device).unsqueeze(1)  # (1, 1, 3, H, W)
                    new_patches = model.patch_encoder(img_tensor)[0]   # (1, Np, D)
                    cache.push(new_patches, eef_xy_tensor)

                    # The most recent EEF pixel in the history window is the anchor.
                    hist_p, hist_e = cache.window()
                    anchor = hist_e[:, -1]
                    out = model.ar_head(hist_p, hist_e, anchor, IMAGE_SIZE)
                    action, pred_3d, pred_px = decode_ar_action(
                        out, model, hist_e, camera_pose, cam_K,
                        current_eef_pos, current_eef_quat,
                        zero_rotation=args.zero_rotation, teleport=args.teleport,
                        action_scale=args.action_scale,
                    )
                    ep_trajectory.append(pred_px)

                    if frames is not None:
                        # Render: RGB + predicted-pixel crosshair + current EEF marker.
                        # NOTE(review): the frame is flipped vertically here, but the
                        # markers are drawn at un-flipped pixel coordinates. Confirm
                        # preprocess_obs applies the same flip before the model sees it.
                        vis = np.flipud(rgb_obs.astype(np.uint8)).copy()
                        px_int = int(np.clip(pred_px[0], 0, IMAGE_SIZE - 1))
                        py_int = int(np.clip(pred_px[1], 0, IMAGE_SIZE - 1))
                        cv2.drawMarker(vis, (px_int, py_int), (0, 255, 0), cv2.MARKER_CROSS, 18, 2, cv2.LINE_AA)
                        cu = int(eef_xy_tensor[0].item())
                        cv = int(eef_xy_tensor[1].item())
                        cv2.circle(vis, (cu, cv), 6, (255, 255, 255), -1)
                        label = f"ep{plan_i} step{step_idx} g={action[6]:+.0f}"
                        # Thick white text under thin dark text gives an outlined label.
                        cv2.putText(vis, label, (8, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)
                        cv2.putText(vis, label, (8, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (20, 20, 20), 1, cv2.LINE_AA)
                        frames.append(vis)

                    obs, _, done, _ = env.step(action)
                    if done:
                        # `done` from env.step is used as the task-success signal.
                        success = True
                        break

            if frames is not None and len(frames) > 0:
                # Best-effort: a failed video write must not abort the evaluation.
                try:
                    import imageio.v2 as imageio
                    vpath = video_dir / f"ep{plan_i:03d}_{'success' if success else 'fail'}.mp4"
                    imageio.mimwrite(str(vpath), frames, fps=30, codec="libx264", quality=8)
                    n_videos_saved += 1
                except Exception as e:
                    print(f"  video save failed: {e}")

            successes.append(float(success))
            step_counts.append(n_steps)
            trajectories.append(np.stack(ep_trajectory, axis=0) if ep_trajectory else np.zeros((0, 2)))
    finally:
        env.close()

    elapsed = time.time() - t_start
    sr = float(np.mean(successes))

    # Trajectory metrics on predicted pixel streams (no GT here — just self-smoothness):
    # mean norm of the second finite difference of the predicted pixel per rollout.
    jerks = []
    for tr in trajectories:
        if tr.shape[0] >= 3:
            ddp = tr[2:] - 2 * tr[1:-1] + tr[:-2]
            jerks.append(float(np.linalg.norm(ddp, axis=-1).mean()))

    results = {
        "checkpoint": args.checkpoint,
        "benchmark": args.benchmark,
        "task_id": args.task_id,
        "n_episodes_per_pos": n_eps,
        "n_rollouts_total": len(plan),
        "positions_file": args.positions_file,
        "shift_dx": args.shift_dx,
        "shift_dy": args.shift_dy,
        "teleport": args.teleport,
        "zero_rotation": args.zero_rotation,
        "action_scale": args.action_scale,
        "successes": successes,
        "success_rate": sr,
        "step_counts": step_counts,
        "avg_steps": float(np.mean(step_counts)),
        "pred_jerk_mean_px": float(np.mean(jerks)) if jerks else None,
        "elapsed_sec": elapsed,
    }
    out_file = out_dir / f"eval_{args.benchmark}_task{args.task_id}.json"
    with open(out_file, "w") as f:
        json.dump(results, f, indent=2)
    print("\n===")
    # Denominator is the number of rollouts actually run, which exceeds n_eps
    # when a positions file multiplies the plan.
    print(f"Success rate: {100 * sr:.1f}% ({sum(successes):.0f}/{len(successes)})")
    print(f"Avg steps:    {np.mean(step_counts):.1f}")
    print(f"Pred jerk px: {results['pred_jerk_mean_px']}")
    print(f"Saved: {out_file}")


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
