"""Closed-loop LIBERO eval for the volume AR model.

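Per model forward (one replanning cycle):
  - capture RGB + EEF pos/quat from sim
  - assemble past_eef_world (20 most recent EEFs; pad with earliest at episode start)
  - forward → 8 predicted voxel indices for the next 8 timesteps
  - pick which predictions to execute: by default the first --actions_per_forward (4);
    with --actions_per_forward 1, only the single --target_horizon (default t=7, the farthest)
  - decode each (voxel_center_world, grip, rot[axis]) → 7-DoF OSC_POSE action
  - step env once per decoded action (or servo to each target with --teleport)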

Save video w/ rainbow predicted polyline + chosen target marker.
"""
import argparse, json, os, sys, time
from collections import deque
from pathlib import Path

import cv2
import h5py
import numpy as np
import torch
from tqdm import tqdm
from scipy.spatial.transform import Rotation as ScipyR

sys.path.insert(0, os.path.dirname(__file__))

from robot_volume import (
    voxel_centers_world, world_to_pixel_torch,
    N_PAST_EEF, T_FUTURE, N_ROT_BINS, MIN_ROT, MAX_ROT, IMAGE_SIZE,
)
from model_volume_ar import VolumeARModel
from model_volume_smooth import SmoothVolumeARModel
from eval import preprocess_obs, get_camera_params

from libero.libero import benchmark as bm_mod, get_libero_path
from libero.libero.envs import OffScreenRenderEnv

# Normalisation factors for the 7-D action: a position delta is divided by
# OSC_POS_SCALE and a rotation vector by OSC_ROT_SCALE before being clipped to
# [-1, 1].
# NOTE(review): presumably these mirror the OSC_POSE controller's output limits
# (metres / radians per unit action) — confirm against the controller config.
OSC_POS_SCALE = 0.05
OSC_ROT_SCALE = 0.5


def rainbow_bgr(t, T):
    """Return the BGR colour for index ``t`` of a ``T``-step rainbow.

    Index 0 maps to red and index ``T - 1`` to magenta, with fully saturated,
    full-brightness colours in between.

    Args:
        t: Position along the sequence (expected in ``[0, T - 1]``; values
            outside are clamped to the nearest end).
        T: Number of steps in the sequence. ``T <= 1`` always yields red.

    Returns:
        A ``(b, g, r)`` tuple of Python ints in ``[0, 255]``, suitable for
        OpenCV drawing calls.
    """
    # OpenCV stores 8-bit hue as degrees / 2, i.e. in [0, 180). Scaling all the
    # way to 180 leaves the documented range and makes the last index come out
    # the same red as index 0, so stop at 150 (= 300 degrees, magenta) to keep
    # both ends of the polyline distinguishable.
    max_hue = 150.0
    frac = min(max(t / max(T - 1, 1), 0.0), 1.0)
    hsv_pixel = np.array([[[int(max_hue * frac), 255, 255]]], dtype=np.uint8)
    bgr_pixel = cv2.cvtColor(hsv_pixel, cv2.COLOR_HSV2BGR)[0, 0]
    return tuple(int(channel) for channel in bgr_pixel)


def load_volume_model(ckpt_path, device, variant="ar"):
    """Instantiate a volume model, load checkpoint weights, and put it in eval mode.

    Args:
        ckpt_path: Path to a checkpoint readable by ``torch.load``. The state
            dict is taken from its ``"model_state_dict"`` entry when present,
            otherwise the loaded object itself is used as the state dict.
        device: Device the model is moved to and the checkpoint is mapped onto.
        variant: ``"smooth"`` selects ``SmoothVolumeARModel``; any other value
            selects ``VolumeARModel``.

    Returns:
        The model, on ``device`` and in eval mode.
    """
    model_cls = {"smooth": SmoothVolumeARModel}.get(variant, VolumeARModel)
    model = model_cls().to(device)

    # NOTE: torch.load may unpickle arbitrary objects — only point this at
    # checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device)
    state_dict = checkpoint.get("model_state_dict", checkpoint)

    # Non-strict load: key mismatches are only reported (as counts), never fatal.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if missing_keys:
        print(f"⚠ missing keys: {len(missing_keys)}")
    if unexpected_keys:
        print(f"⚠ unexpected keys: {len(unexpected_keys)}")

    model.eval()
    return model


def decode_action(target_world, current_eef_pos, current_eef_quat, target_grip_logit,
                  target_rot_logits, zero_rotation=False, max_delta=0.05):
    """Turn one predicted target into a normalised 7-D OSC action.

    Args:
        target_world: Target position (3-vector) the end effector should move to.
        current_eef_pos: Current end-effector position (3-vector).
        current_eef_quat: Current end-effector quaternion (4-vector). Only read
            when ``zero_rotation`` is false.
        target_grip_logit: Scalar gripper logit; ``> 0`` yields ``+1``, else ``-1``.
        target_rot_logits: Per-axis rotation-bin logits, indexable as
            ``[axis]`` for axes 0..2. Only read when ``zero_rotation`` is false.
        zero_rotation: If true, the rotation part of the action is all zeros.
        max_delta: Maximum Euclidean length of the position delta before it is
            normalised by ``OSC_POS_SCALE``.

    Returns:
        ``float32`` array of shape ``(7,)``: ``[0:3]`` position delta and
        ``[3:6]`` rotation vector, both clipped to ``[-1, 1]``; ``[6]`` the
        gripper command (``+1`` or ``-1``).
    """
    # --- translation: cap the step length, then scale into [-1, 1] ---
    offset = target_world - current_eef_pos
    distance = np.linalg.norm(offset)
    if distance > max_delta:
        offset = offset / distance * max_delta
    pos_cmd = np.clip(offset / OSC_POS_SCALE, -1.0, 1.0)

    # --- rotation ---
    if not zero_rotation:
        # Arg-max bin per axis, mapped linearly onto [MIN_ROT[axis], MAX_ROT[axis]].
        last_bin = max(N_ROT_BINS - 1, 1)
        angles = []
        for axis in range(3):
            bin_frac = target_rot_logits[axis].argmax().item() / last_bin
            angles.append(bin_frac * (MAX_ROT[axis] - MIN_ROT[axis]) + MIN_ROT[axis])
        rot_target = ScipyR.from_euler('xyz', np.array(angles))

        # The quaternion is treated as WXYZ and reordered to scipy's XYZW.
        # NOTE(review): robosuite's transform utilities document an (x, y, z, w)
        # convention — confirm the observation really arrives as WXYZ, otherwise
        # this reorder scrambles the current orientation.
        rot_current = ScipyR.from_quat(current_eef_quat[[1, 2, 3, 0]])

        # Rotation taking the current orientation to the target, as a rotation
        # vector scaled into [-1, 1].
        rot_error = (rot_target * rot_current.inv()).as_rotvec()
        rot_cmd = np.clip(rot_error / OSC_ROT_SCALE, -1.0, 1.0)
    else:
        rot_cmd = np.zeros(3, dtype=np.float64)

    # --- assemble ---
    action = np.zeros(7, dtype=np.float32)
    action[:3] = pos_cmd
    action[3:6] = rot_cmd
    action[6] = 1.0 if float(target_grip_logit) > 0.0 else -1.0
    return action


def _build_arg_parser():
    """Create the command-line parser for the single-scene closed-loop eval."""
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint", type=str, required=True)
    p.add_argument("--benchmark", type=str, default="libero_spatial")
    p.add_argument("--task_id", type=int, default=0)
    p.add_argument("--demo_idx", type=int, default=0, help="Init state to use (single scene).")
    p.add_argument("--camera", type=str, default="agentview")
    p.add_argument("--max_steps", type=int, default=600)
    p.add_argument("--target_horizon", type=int, default=7,
                   help="Which of the T=8 predicted timesteps to chase per env step (0=first, 7=farthest). "
                        "Used only if --actions_per_forward == 1.")
    p.add_argument("--actions_per_forward", type=int, default=4,
                   help="Number of predicted targets (0..N-1) to execute before replanning. "
                        "Default 4 = take first 4 of 8 predictions then re-forward.")
    p.add_argument("--zero_rotation", action="store_true")
    p.add_argument("--teleport", action="store_true",
                   help="Servo to predicted 3D target via up-to-25 small OSC steps, then 1 grip step. "
                        "Matches para_normalized_losses eval.py teleport semantics.")
    p.add_argument("--teleport_max_servo", type=int, default=25)
    p.add_argument("--teleport_threshold", type=float, default=0.005, help="meters")
    p.add_argument("--variant", type=str, default="ar", choices=["ar", "smooth"])
    # Video is on by default; --save_video is kept for backward compatibility and
    # --no_save_video is the only way to actually switch it off.
    p.add_argument("--save_video", action="store_true", default=True)
    p.add_argument("--no_save_video", dest="save_video", action="store_false",
                   help="Do not record/write the overlay video (it is written by default).")
    p.add_argument("--out_dir", type=str, default="out/eval_volume_single")
    return p


def _draw_overlay(rgb_obs, pred_pix, cur_pix, exec_steps, label):
    """Render one video frame: the prediction polyline drawn over the camera image.

    Args:
        rgb_obs: Camera image from the observation dict; it is cast to uint8 and
            flipped vertically before drawing.
        pred_pix: Array of projected prediction pixels, one row per timestep,
            with x in column 0 and y in column 1.
        cur_pix: Projected pixel of the current end effector (x, y).
        exec_steps: Indices into ``pred_pix`` of the targets being executed;
            each one gets a highlight ring.
        label: Text drawn in the top-left corner.

    Returns:
        The annotated image as a new uint8 array.
    """
    vis = np.flipud(rgb_obs.astype(np.uint8)).copy()
    n_pred = pred_pix.shape[0]
    points = [(int(px[0]), int(px[1])) for px in pred_pix]

    # Current EEF: filled white dot, clamped so it stays visible at the border.
    cv2.circle(vis,
               (int(np.clip(cur_pix[0], 0, IMAGE_SIZE - 1)),
                int(np.clip(cur_pix[1], 0, IMAGE_SIZE - 1))),
               6, (255, 255, 255), -1, cv2.LINE_AA)

    # Rainbow polyline through consecutive predictions.
    for i in range(1, n_pred):
        cv2.line(vis, points[i - 1], points[i], rainbow_bgr(i, n_pred), 2, cv2.LINE_AA)

    # Numbered crosshair on every prediction.
    for i, (x, y) in enumerate(points):
        colour = rainbow_bgr(i, n_pred)
        cv2.drawMarker(vis, (x, y), colour, cv2.MARKER_CROSS, 14, 2, cv2.LINE_AA)
        cv2.putText(vis, str(i), (x + 5, y - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, colour, 1, cv2.LINE_AA)

    # Ring around each target that is actually executed this cycle.
    for ti in exec_steps:
        cv2.circle(vis, points[ti], 10, rainbow_bgr(ti, n_pred), 2, cv2.LINE_AA)

    # Label: thick white pass underneath a thin dark pass for legibility.
    cv2.putText(vis, label, (8, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.putText(vis, label, (8, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (20, 20, 20), 1, cv2.LINE_AA)
    return vis


def main():
    """Run one closed-loop episode from the command line and save the outcome.

    Loads the checkpoint, restores the first simulator state of the selected
    demonstration, then alternates model forwards with env steps until the env
    reports ``done`` (treated as task success) or ``--max_steps`` env steps have
    been taken. Writes ``result.json`` and, unless ``--no_save_video`` is given,
    an mp4 with one overlay frame per model forward into ``--out_dir``.

    Raises:
        ValueError: If ``--actions_per_forward <= 1`` and ``--target_horizon``
            is outside the range of timesteps the model predicts.
    """
    args = _build_arg_parser().parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    print(f"Loading model: {args.checkpoint}")
    model = load_volume_model(args.checkpoint, device, args.variant)
    voxel_centers = voxel_centers_world().to(device)             # (V, 3)

    bench = bm_mod.get_benchmark_dict()[args.benchmark]()
    task = bench.get_task(args.task_id)
    print(f"Task: {task.name}; demo {args.demo_idx}; horizon t={args.target_horizon}")
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(args.task_id))
    with h5py.File(demo_path, "r") as f:
        # NOTE(review): this is a plain string sort, so "demo_10" orders before
        # "demo_2" and --demo_idx 2 selects demo_10. Left as-is in case other
        # scripts rely on the same ordering — confirm and switch to a numeric
        # sort if not.
        keys = sorted(k for k in f["data"].keys() if k.startswith("demo_"))
        init_state = f[f"data/{keys[args.demo_idx]}/states"][0]

    bddl = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
    env = OffScreenRenderEnv(bddl_file_name=bddl, camera_heights=IMAGE_SIZE, camera_widths=IMAGE_SIZE,
                             camera_names=[args.camera])

    frames = [] if args.save_video else None
    done = False
    hold_grip = -1.0  # gripper state we hold during teleport servo (open at start)
    env_step = 0
    forward_count = 0

    # try/finally so the simulator/renderer is released even if the episode raises.
    try:
        env.seed(0)
        env.reset()
        obs = env.set_init_state(init_state)
        # A few zero-action steps after restoring the state before control starts.
        for _ in range(5):
            obs, _, _, _ = env.step(np.zeros(7, dtype=np.float32))

        world_to_camera, _, _ = get_camera_params(env.sim, args.camera, IMAGE_SIZE)
        w2c_t = torch.tensor(world_to_camera, dtype=torch.float32, device=device).unsqueeze(0)

        # Past-EEF ring buffer, initialised with the current EEF repeated N_PAST_EEF times.
        eef0 = np.array(obs["robot0_eef_pos"], dtype=np.float64)
        past_buf = deque([eef0.copy() for _ in range(N_PAST_EEF)], maxlen=N_PAST_EEF)

        t_start = time.time()
        with torch.no_grad():
            while env_step < args.max_steps and not done:
                forward_count += 1
                current_eef_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)
                current_eef_quat = np.array(obs["robot0_eef_quat"], dtype=np.float64)
                rgb_obs = obs[f"{args.camera}_image"]

                # NOTE(review): the history gains one sample per model forward,
                # not per env step, so with several env steps per forward its
                # time spacing differs — confirm this matches training.
                past_buf.append(current_eef_pos)
                past_arr = np.stack(list(past_buf), axis=0).astype(np.float32)      # (N, 3)

                img_t = preprocess_obs(rgb_obs, IMAGE_SIZE).to(device)
                past_t = torch.from_numpy(past_arr).to(device).unsqueeze(0)         # (1, N, 3)
                cur_t = torch.from_numpy(current_eef_pos.astype(np.float32)).to(device).unsqueeze(0)  # (1, 3)

                out = model(img_t, past_t, cur_t, w2c_t)
                pred_idx_t = out["pred_voxel_idx"][0]
                grip_logit_t = out["grip_logit"][0]
                rot_logits_t = out["rot_logits"][0]
                T_pred = pred_idx_t.shape[0]

                # Pick which timesteps of the plan to execute this forward.
                if args.actions_per_forward <= 1:
                    if not 0 <= args.target_horizon < T_pred:
                        raise ValueError(
                            f"--target_horizon {args.target_horizon} is outside the model's "
                            f"prediction range 0..{T_pred - 1}")
                    exec_steps = [args.target_horizon]
                else:
                    exec_steps = list(range(min(args.actions_per_forward, T_pred)))

                # Decode the chosen targets into (index, target_world, action) triples.
                # NOTE(review): every action is decoded against the pose observed
                # at forward time, so later actions in the plan (and the rotation
                # part reused on each teleport servo step) are relative to a stale
                # pose. Kept unchanged to preserve eval semantics — confirm intent.
                planned = []
                for ti in exec_steps:
                    tw = voxel_centers[int(pred_idx_t[ti].item())].cpu().numpy()
                    act = decode_action(tw, current_eef_pos, current_eef_quat,
                                        grip_logit_t[ti], rot_logits_t[ti],
                                        zero_rotation=args.zero_rotation)
                    planned.append((ti, tw, act))

                if frames is not None:
                    # One overlay frame per forward: all predictions projected to pixels.
                    pred_world = voxel_centers[pred_idx_t]                          # (T, 3)
                    pred_pix = world_to_pixel_torch(pred_world.unsqueeze(0), w2c_t)[0].cpu().numpy()
                    cur_pix = world_to_pixel_torch(cur_t.reshape(1, 1, 3), w2c_t)[0, 0].cpu().numpy()
                    first_action = planned[0][2]
                    # Label the targets actually executed, not the (possibly unused) --target_horizon.
                    if len(exec_steps) == 1:
                        horizon_lbl = str(exec_steps[0])
                    else:
                        horizon_lbl = f"{exec_steps[0]}-{exec_steps[-1]}"
                    label = f"step {env_step} g={first_action[6]:+.0f} h={horizon_lbl}"
                    frames.append(_draw_overlay(rgb_obs, pred_pix, cur_pix, exec_steps, label))

                # Execute each planned action segment (teleport-servo or single OSC step).
                for _, tw, act in planned:
                    if done or env_step >= args.max_steps:
                        break

                    if not args.teleport:
                        obs, _, done, _ = env.step(act)
                        env_step += 1
                        continue

                    # Teleport: servo position towards the target while holding the
                    # previous gripper state, then spend one step on the new gripper command.
                    new_grip = act[6]
                    servo_steps = 0
                    while servo_steps < args.teleport_max_servo and env_step < args.max_steps and not done:
                        cur_p = np.array(obs["robot0_eef_pos"], dtype=np.float64)
                        delta = tw - cur_p
                        if float(np.linalg.norm(delta)) < args.teleport_threshold:
                            break
                        servo_action = np.zeros(7, dtype=np.float32)
                        servo_action[:3] = np.clip(delta / OSC_POS_SCALE, -1.0, 1.0)
                        if not args.zero_rotation:
                            servo_action[3:6] = act[3:6]
                        servo_action[6] = hold_grip
                        obs, _, done, _ = env.step(servo_action)
                        env_step += 1
                        servo_steps += 1
                    if not done and env_step < args.max_steps:
                        grip_action = np.zeros(7, dtype=np.float32)
                        grip_action[6] = new_grip
                        obs, _, done, _ = env.step(grip_action)
                        env_step += 1
                    hold_grip = new_grip
        elapsed = time.time() - t_start
    finally:
        env.close()

    # The loop only ever stops early on `done`, which this script treats as success.
    success = bool(done)
    print(f"Result: {'SUCCESS' if success else 'FAIL'} after {env_step} env steps, "
          f"{forward_count} model forwards, in {elapsed:.1f}s")

    if frames:
        import imageio.v2 as imageio
        vname = f"demo{args.demo_idx}_h{args.target_horizon}_{'success' if success else 'fail'}.mp4"
        vpath = out_dir / vname
        imageio.mimwrite(str(vpath), frames, fps=30, codec="libx264", quality=8)
        print(f"Video: {vpath}")

    result = {"success": int(success), "env_steps": env_step,
              "forward_count": forward_count, "elapsed_sec": elapsed,
              "checkpoint": args.checkpoint, "demo_idx": args.demo_idx,
              "target_horizon": args.target_horizon, "teleport": args.teleport}
    # Context manager so the file is flushed and closed deterministically.
    with open(out_dir / "result.json", "w") as fh:
        json.dump(result, fh, indent=2)


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
