"""Closed-loop libero eval for DinoVolumeQuery (query-MLP arch).

Adapts the existing eval.py logic to the query-MLP output format:
  - model(rgb, start_pix=current_eef_pix) returns dict with volume_logits,
    gripper_logits, rotation_logits already per-timestep (no separate predict_at_pixels)
  - Rotation is 1D PCA: decode bin → mu + val·v1 → euler XYZ
  - Volume CE is joint over (Z, H, W); argmax gives (z*, y*, x*) per timestep

Run:
  CUDA_VISIBLE_DEVICES=8 PYTHONPATH=/data/cameron/LIBERO \\
  DINO_REPO_DIR=/data/cameron/keygrip/dinov3 \\
  python eval_libero_query.py \\
    --checkpoint checkpoints/libero_query_libero_spatial_t0_v0/latest.pth \\
    --benchmark libero_spatial --task_id 0 --n_episodes 10 --save_video
"""
import os, sys, argparse, json
import numpy as np
import torch
import cv2
import h5py
from pathlib import Path
from tqdm import tqdm
from scipy.spatial.transform import Rotation as ScipyR

# Make this script's own directory (model_dino_volume_query / utils / eval modules) and the
# LIBERO checkout importable; both are prepended so they shadow any installed copies.
sys.path.insert(0, os.path.dirname(__file__))
sys.path.insert(0, "/data/cameron/LIBERO")
# Default DINOv3 repo / weight locations. setdefault keeps any value already exported in the
# shell (e.g. DINO_REPO_DIR in the Run example above).
# NOTE(review): presumably these are read by the project modules at import time, which is why
# this runs before the imports below — confirm.
os.environ.setdefault("DINO_REPO_DIR",     "/data/cameron/keygrip/dinov3")
os.environ.setdefault("DINO_WEIGHTS_PATH", "/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth")

from libero.libero import benchmark as bm, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_extrinsic_matrix, get_camera_intrinsic_matrix, get_camera_transform_matrix,
)

from model_dino_volume_query import DinoVolumeQuery, PRED_SIZE
from utils import recover_3d_from_direct_keypoint_and_height
from eval import preprocess_obs, get_camera_params, eef_to_start_kp

# Default square render / model input resolution in pixels. main() uses it only when the
# checkpoint has no "image_size" entry; decode_query_actions uses it as its default image_size.
LIBERO_IMG = 448
# Divisors that map a position delta (world units) and a rotation-vector delta (radians) into the
# normalized [-1, 1] OSC_POSE action range (see decode_query_actions).
# NOTE(review): presumably these mirror the OSC controller's output_max (0.05 position,
# 0.5 rotation) — confirm against the LIBERO controller config.
OSC_POS_SCALE = 0.05
OSC_ROT_SCALE = 0.5


def decode_query_actions(out_dict, camera_pose, cam_K,
                          current_eef_pos, current_eef_quat,
                          min_h, max_h, min_g, max_g,
                          rot_pca_mean, rot_pca_axis, rot_pca_min, rot_pca_max,
                          image_size=LIBERO_IMG, max_delta=0.05):
    """Decode all T predicted timesteps from a query-MLP forward into OSC_POSE actions.

    out_dict contains volume_logits (1, T, Z, H, W), gripper_logits (1, T, n_grip),
    rotation_logits (1, T, n_rot) — all already per-timestep.

    The T actions are executed open-loop, one env step each. Each OSC_POSE action is a
    *delta* applied to the pose at the moment it runs, so consecutive actions are chained:
    action t moves from the pose commanded by action t-1 (the current eef pose for t = 0)
    toward target t. Measuring every action from the eef pose at prediction time would
    accumulate displacement along the window and overshoot the targets.

    Args:
        out_dict: model output dict with the three logit tensors described above.
        camera_pose, cam_K: camera extrinsic pose / intrinsics passed straight to
            recover_3d_from_direct_keypoint_and_height (presumably from get_camera_params).
        current_eef_pos: (3,) eef position when the prediction was made.
        current_eef_quat: (4,) eef orientation, scalar-last (x, y, z, w) as SciPy expects.
        min_h, max_h: range that height bins are mapped onto (bin 0 -> min, bin Z-1 -> max).
        min_g, max_g: range that gripper bins are mapped onto (bin 0 -> min, last bin -> max).
        rot_pca_mean, rot_pca_axis: (3,) mean and principal axis of the euler-XYZ PCA.
        rot_pca_min, rot_pca_max: range of the PCA coordinate spanned by the rotation bins.
        image_size: full-resolution image side in pixels the logit grid is upsampled to.
        max_delta: maximum translation (world units) commanded by a single action.

    Returns:
        (actions, pred_3d_targets): list of T (7,) float32 arrays [dpos(3), drot(3), grip]
        with components in [-1, 1], and list of T (3,) unprojected 3D targets.
    """
    vol  = out_dict["volume_logits"][0]                        # (T, Z, H, W)
    grip = out_dict["gripper_logits"][0]                       # (T, n_grip)
    rot  = out_dict["rotation_logits"][0]                      # (T, n_rot) — 1D PCA bins
    T, Z, Hg, Wg = vol.shape
    n_grip = grip.shape[-1]
    n_rot  = rot.shape[-1]
    # Grid cell -> full-resolution pixel scale, per axis (identical when the grid is square).
    scale_x = image_size / Wg
    scale_y = image_size / Hg

    # Joint argmax over (Z, H, W) per t, then unravel the flat index into (z, y, x).
    flat = vol.reshape(T, -1).argmax(dim=-1)                   # (T,)
    h_bins  = (flat // (Hg * Wg)).cpu().numpy()
    yx      = (flat %  (Hg * Wg)).cpu().numpy()
    py_grid = (yx // Wg).astype(np.int64)
    px_grid = (yx %  Wg).astype(np.int64)

    grip_argmax = grip.argmax(dim=-1).cpu().numpy()             # (T,)
    rot_argmax  = rot.argmax(dim=-1).cpu().numpy()              # (T,)

    # Pose the next delta is measured from; advanced to the commanded pose after each action.
    ref_pos = np.asarray(current_eef_pos, dtype=np.float64).copy()
    R_ref = ScipyR.from_quat(current_eef_quat)

    actions = []
    pred_3d = []
    for t in range(T):
        # 3D position via unprojection of the cell centre at the decoded height.
        px_full = (float(px_grid[t]) + 0.5) * scale_x
        py_full = (float(py_grid[t]) + 0.5) * scale_y
        height  = (h_bins[t] / max(Z - 1, 1)) * (max_h - min_h) + min_h
        p3d = recover_3d_from_direct_keypoint_and_height(
            np.array([px_full, py_full], dtype=np.float64), float(height),
            camera_pose, cam_K,
        )
        if p3d is None:
            # Unprojection failed: hold the previous target (or the reference pose at t = 0).
            p3d = pred_3d[-1] if pred_3d else ref_pos.copy()
        pred_3d.append(p3d)

        # Delta position -> OSC, limited to max_delta per step.
        delta_pos = p3d - ref_pos
        norm = np.linalg.norm(delta_pos)
        if norm > max_delta:
            delta_pos = delta_pos / norm * max_delta
        delta_norm = np.clip(delta_pos / OSC_POS_SCALE, -1.0, 1.0)
        # Chain: the next action starts where this (possibly clipped) action sends the eef.
        ref_pos = ref_pos + delta_norm * OSC_POS_SCALE

        # Rotation: 1D PCA bin centre -> PCA coordinate -> mu + val·v1 -> euler -> delta.
        pca_val = rot_pca_min + (rot_argmax[t] + 0.5) / n_rot * (rot_pca_max - rot_pca_min)
        euler_pred = rot_pca_mean + pca_val * rot_pca_axis      # (3,) — absolute euler XYZ
        # Lowercase 'xyz' is SciPy's extrinsic convention; it must match how training euler angles were made.
        R_pred = ScipyR.from_euler('xyz', euler_pred)
        # Left-multiplied delta: R_pred = R_delta * R_ref.
        R_delta = R_pred * R_ref.inv()
        delta_rot_norm = np.clip(R_delta.as_rotvec() / OSC_ROT_SCALE, -1.0, 1.0)
        # Chain the orientation through the commanded (possibly clipped) rotation.
        R_ref = ScipyR.from_rotvec(delta_rot_norm * OSC_ROT_SCALE) * R_ref

        # Gripper: bin -> continuous value -> threshold sign for binary command.
        grip_continuous = (grip_argmax[t] / max(n_grip - 1, 1)) * (max_g - min_g) + min_g
        # libero gripper action is in [-1, 1] (open / close); training data was -1/+1, so threshold at 0.
        gripper_cmd = 1.0 if grip_continuous > 0.0 else -1.0

        action = np.zeros(7, dtype=np.float32)
        action[:3]  = delta_norm
        action[3:6] = delta_rot_norm
        action[6]   = gripper_cmd
        actions.append(action)

    return actions, pred_3d


def _build_model(ckpt, device, n_window, n_rot_bins, image_size):
    """Instantiate DinoVolumeQuery for this checkpoint and load its weights, reporting key mismatches."""
    model = DinoVolumeQuery(
        n_window=n_window,
        # NOTE(review): assumes these ckpt key names if present; 32 matches the previous hard-coded value.
        n_height_bins=int(ckpt.get("n_height_bins", 32)),
        n_gripper_bins=int(ckpt.get("n_gripper_bins", 32)),
        n_rot_bins=n_rot_bins,
        image_size=image_size, pred_size=PRED_SIZE,
        use_eef=True, rotation_mode='1d_pca',
    ).to(device).eval()
    # strict=False is kept, but mismatches are surfaced: a silently skipped head would mean
    # evaluating randomly initialised weights.
    result = model.load_state_dict(ckpt["model_state_dict"], strict=False)
    if result.missing_keys or result.unexpected_keys:
        print(f"  WARNING: state dict mismatch: {len(result.missing_keys)} missing, "
              f"{len(result.unexpected_keys)} unexpected keys")
        for key in result.missing_keys[:10]:
            print(f"    missing:    {key}")
        for key in result.unexpected_keys[:10]:
            print(f"    unexpected: {key}")
    print("Model loaded.")
    return model


def _load_init_states(bench, task_id):
    """Return the first simulator state of every demo_* group in the task's demo HDF5."""
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(task_id))
    with h5py.File(demo_path, "r") as fh:
        # Lexicographic order (demo_0, demo_1, demo_10, ...) is kept so results stay comparable.
        demo_keys = sorted(k for k in fh["data"].keys() if k.startswith("demo_"))
        return [fh[f"data/{k}/states"][0] for k in demo_keys]


def _video_frame(rgb, step_idx, image_size):
    """Return a uint8 RGB video frame: vertically flipped, resized if needed, step label drawn."""
    # Stay in uint8: a float /255 then *255 round trip truncated some pixel values by one.
    # The vertical flip matches the previous pipeline (presumably offscreen renders are upside down).
    vis = np.flipud(np.asarray(rgb, dtype=np.uint8)).copy()
    if vis.shape[:2] != (image_size, image_size):
        vis = cv2.resize(vis, (image_size, image_size), interpolation=cv2.INTER_LINEAR)
    cv2.putText(vis, f"step {step_idx}", (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2, cv2.LINE_AA)
    return vis


def _write_video(frames, out_path, fps=20):
    """Write RGB uint8 frames to an mp4 file, always releasing the writer."""
    h, w = frames[0].shape[:2]
    writer = cv2.VideoWriter(str(out_path), cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
    if not writer.isOpened():
        print(f"  WARNING: could not open video writer for {out_path}")
        return
    try:
        for fr in frames:
            writer.write(cv2.cvtColor(fr, cv2.COLOR_RGB2BGR))
    finally:
        writer.release()


def _run_episode(env, model, init_state, args, image_size, device, decode_kwargs):
    """Roll out one episode from init_state; return (success, n_steps, frames-or-None)."""
    env.reset()
    obs = env.set_init_state(init_state)
    # Physics settle: a few zero-action steps before the policy takes over.
    for _ in range(5):
        obs, _, _, _ = env.step(np.zeros(7, dtype=np.float32))

    world_to_camera, camera_pose, cam_K = get_camera_params(env.sim, args.camera, image_size)
    image_key = f"{args.camera}_image"

    success = False
    step_idx = 0
    frames = [] if args.save_video else None

    while step_idx < args.max_steps and not success:
        current_eef_pos  = np.array(obs["robot0_eef_pos"],  dtype=np.float64)
        current_eef_quat = np.array(obs["robot0_eef_quat"], dtype=np.float64)
        img_tensor = preprocess_obs(obs[image_key], image_size).to(device)
        start_pix = eef_to_start_kp(current_eef_pos, world_to_camera, image_size).to(device)
        # The query model expects a batched (B, 2) start pixel.
        if start_pix.dim() == 1:
            start_pix = start_pix.unsqueeze(0)

        with torch.no_grad():
            out = model(img_tensor, start_pix=start_pix)

        window_actions, _ = decode_query_actions(
            out, camera_pose, cam_K,
            current_eef_pos, current_eef_quat,
            image_size=image_size, **decode_kwargs,
        )
        if not window_actions:
            # Otherwise the while loop would never advance step_idx.
            raise RuntimeError("model produced an empty action window")

        # Execute the whole predicted window open-loop, then re-plan.
        for action in window_actions:
            if frames is not None:
                frames.append(_video_frame(obs[image_key], step_idx, image_size))
            obs, _, done, _ = env.step(action)
            step_idx += 1
            if done:
                # The loop treats env `done` as task success.
                success = True
                break
            if step_idx >= args.max_steps:
                break

    if frames is not None:
        # Also record the final observation so the outcome is visible in the video.
        frames.append(_video_frame(obs[image_key], step_idx, image_size))
    return success, step_idx, frames


def main():
    """Evaluate a DinoVolumeQuery checkpoint closed-loop on one LIBERO task.

    Parses CLI args and restores decode ranges, the rotation PCA basis and the model from
    the checkpoint. It then rolls out up to --n_episodes episodes from the task's demo
    initial states, optionally writes one mp4 per episode, and prints the success rate.

    Raises:
        FileNotFoundError: if --checkpoint does not exist.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint", type=str, required=True)
    p.add_argument("--benchmark",  type=str, default="libero_spatial")
    p.add_argument("--task_id",    type=int, default=0)
    p.add_argument("--camera",     type=str, default="agentview")
    p.add_argument("--n_episodes", type=int, default=10)
    p.add_argument("--max_steps",  type=int, default=600)
    p.add_argument("--seed",       type=int, default=0)
    p.add_argument("--save_video", action="store_true")
    p.add_argument("--video_dir",  type=str, default="/data/cameron/para/libero/eval_videos")
    args = p.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ckpt_path = Path(args.checkpoint)
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    # weights_only=False unpickles arbitrary Python objects; only load trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)

    # Restore ranges + PCA basis from ckpt (must match the training-time binning).
    min_h, max_h = float(ckpt["min_height"]), float(ckpt["max_height"])
    min_g, max_g = float(ckpt["min_grip"]),   float(ckpt["max_grip"])
    rot_pca_mean = np.asarray(ckpt["rot_pca_mean"], dtype=np.float64)
    rot_pca_axis = np.asarray(ckpt["rot_pca_axis"], dtype=np.float64)
    rot_pca_min  = float(ckpt["rot_pca_min"])
    rot_pca_max  = float(ckpt["rot_pca_max"])
    n_rot_bins   = int(ckpt["n_rot_bins"])
    n_window     = int(ckpt.get("n_window", 8))
    image_size   = int(ckpt.get("image_size", LIBERO_IMG))
    print(f"Loaded ckpt: epoch={ckpt.get('epoch', '?')}, n_window={n_window}, n_rot_bins={n_rot_bins}")
    print(f"  height [{min_h:.3f}, {max_h:.3f}]  grip [{min_g:.3f}, {max_g:.3f}]")
    print(f"  rot PCA: mu={rot_pca_mean}  axis={rot_pca_axis}  range=[{rot_pca_min:.2f}, {rot_pca_max:.2f}]")
    decode_kwargs = dict(
        min_h=min_h, max_h=max_h, min_g=min_g, max_g=max_g,
        rot_pca_mean=rot_pca_mean, rot_pca_axis=rot_pca_axis,
        rot_pca_min=rot_pca_min, rot_pca_max=rot_pca_max,
    )

    model = _build_model(ckpt, device, n_window, n_rot_bins, image_size)

    # LIBERO benchmark + task
    bench = bm.get_benchmark_dict()[args.benchmark]()
    task  = bench.get_task(args.task_id)
    print(f"Task: [{args.benchmark}] {task.name}")

    init_states = _load_init_states(bench, args.task_id)
    n_episodes = min(args.n_episodes, len(init_states))
    print(f"Running {n_episodes} / {len(init_states)} episodes...")

    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file,
        camera_heights=image_size, camera_widths=image_size,
        camera_names=[args.camera],
    )
    successes = []
    step_counts = []
    try:
        env.seed(args.seed)
        env.reset()
        print("Environment ready.")

        if args.save_video:
            Path(args.video_dir).mkdir(parents=True, exist_ok=True)

        for ep_idx in tqdm(range(n_episodes), desc="Episodes"):
            success, n_steps, frames = _run_episode(
                env, model, init_states[ep_idx], args, image_size, device, decode_kwargs,
            )
            successes.append(success)
            step_counts.append(n_steps)

            if frames:
                out_path = Path(args.video_dir) / f"ep{ep_idx:03d}_{'OK' if success else 'FAIL'}_steps{n_steps}.mp4"
                _write_video(frames, out_path)
    finally:
        # Release the simulator / renderer even if an episode raises.
        env.close()

    if not successes:
        print("\nNo episodes were run (no demos found or --n_episodes <= 0).")
        return
    sr = float(np.mean(successes))
    print(f"\nFinal: {sum(successes)}/{n_episodes} = {sr:.1%} success rate")
    print(f"  avg steps: {float(np.mean(step_counts)):.1f}  (max {args.max_steps})")


# Script entry point: run the evaluation only when executed directly, not when imported.
if __name__ == "__main__":
    main()
