"""Closed-loop libero eval for the dual-view DinoVolumeQuery2View model.

Variant of eval_libero_query.py — uses both agentview + robot0_eye_in_hand at each step,
feeds them through the dual-view model, decodes the same way (volume argmax → world XYZ,
1D PCA rotation, gripper bin). The wrist extrinsic moves with the EEF so we fetch it
fresh from the simulator per step.
"""
import os, sys, argparse
import numpy as np
import torch
import cv2
import h5py
from pathlib import Path
from tqdm import tqdm
from scipy.spatial.transform import Rotation as ScipyR

# Make this script's directory and the local LIBERO checkout importable; must
# run before the libero / project imports that follow.
sys.path.insert(0, os.path.dirname(__file__))
sys.path.insert(0, "/data/cameron/LIBERO")
# DINOv3 repo and weight locations. setdefault keeps any value already exported
# in the environment. Set before importing the model module — presumably it
# reads these at import or construction time (TODO confirm).
os.environ.setdefault("DINO_REPO_DIR",     "/data/cameron/keygrip/dinov3")
os.environ.setdefault("DINO_WEIGHTS_PATH", "/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth")

from libero.libero import benchmark as bm, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_extrinsic_matrix, get_camera_intrinsic_matrix, get_camera_transform_matrix,
    project_points_from_world_to_camera,
)

from model_dino_volume_query_2view import DinoVolumeQuery2View, PRED_SIZE, build_bev_world_xyz_table
from utils import recover_3d_from_direct_keypoint_and_height
from eval import preprocess_obs, eef_to_start_kp

# Default square render / model-input resolution in pixels, used when the
# checkpoint does not record "image_size".
LIBERO_IMG = 448
# Divisors that map a metric position delta (OSC_POS_SCALE) and a
# rotation-vector delta (OSC_ROT_SCALE) onto the [-1, 1] normalized action
# range in decode_actions. Presumably these match the OSC controller's
# output scaling — TODO confirm against the controller config.
OSC_POS_SCALE = 0.05
OSC_ROT_SCALE = 0.5


def decode_actions(out_dict, camera_pose_bev, cam_K_bev_pixel,
                    current_eef_pos, current_eef_quat,
                    min_h, max_h, min_g, max_g,
                    rot_pca_mean, rot_pca_axis, rot_pca_min, rot_pca_max,
                    image_size=LIBERO_IMG, max_delta=0.05):
    """Decode one window of model outputs into normalized 7-D OSC actions.

    Only batch element 0 of each logit tensor is decoded. For every step t of
    the window:

    * Position: a joint argmax over the (Z, H, W) volume gives a height bin
      and a BEV grid cell. The cell centre is rescaled to image pixels, the
      height bin is mapped linearly onto [min_h, max_h], and
      ``recover_3d_from_direct_keypoint_and_height`` lifts the pair to a 3-D
      point. The delta from ``current_eef_pos`` is clamped to ``max_delta``
      in norm, divided by ``OSC_POS_SCALE`` and clipped to [-1, 1].
    * Rotation: the centre of the argmax rotation bin is mapped onto
      [rot_pca_min, rot_pca_max], expanded along ``rot_pca_axis`` around
      ``rot_pca_mean`` into xyz Euler angles, and turned into a rotation-vector
      delta relative to ``current_eef_quat``. The delta is divided by
      ``OSC_ROT_SCALE`` and clipped to [-1, 1].
    * Gripper: the argmax gripper bin is mapped linearly onto [min_g, max_g]
      and thresholded at 0 to a +1 / -1 command.

    Args:
        out_dict: Model outputs with "volume_logits" (B, T, Z, H, W),
            "gripper_logits" (B, T, n_grip) and "rotation_logits" (B, T, n_rot).
        camera_pose_bev, cam_K_bev_pixel: BEV camera pose and pixel-space
            intrinsics, forwarded to the 3-D recovery helper.
        current_eef_pos: Current end-effector position (3,).
        current_eef_quat: Current end-effector quaternion. It is passed to
            scipy's ``Rotation.from_quat``, so it is assumed to be scalar-last
            (x, y, z, w).
        min_h, max_h: Metric range covered by the height bins.
        min_g, max_g: Range covered by the gripper bins.
        rot_pca_mean, rot_pca_axis, rot_pca_min, rot_pca_max: 1-D PCA
            parameterisation of the Euler-angle rotation target.
        image_size: Side length in pixels of the (square) BEV image.
        max_delta: Maximum norm of a single position delta before normalisation.

    Returns:
        ``(actions, pred_3d)``: ``actions`` is a list of T float32 arrays of
        shape (7,) laid out as [dx, dy, dz, drx, dry, drz, grip]. ``pred_3d``
        is the list of T predicted 3-D waypoints.
    """
    vol  = out_dict["volume_logits"][0]                  # (T, Z, H, W)
    grip = out_dict["gripper_logits"][0]                 # (T, n_grip)
    rot  = out_dict["rotation_logits"][0]                # (T, n_rot)
    T, Z, Hg, Wg = vol.shape
    n_grip = grip.shape[-1]
    n_rot  = rot.shape[-1]
    # Per-axis grid->pixel scale, so a non-square prediction grid still maps
    # correctly onto the square image_size x image_size BEV image.
    scale_x = image_size / Wg
    scale_y = image_size / Hg

    # Joint argmax over (height, row, col) for each step, then split the flat
    # index back into its three coordinates.
    flat = vol.reshape(T, -1).argmax(dim=-1).cpu().numpy()
    h_bins, py_grid, px_grid = np.unravel_index(flat, (Z, Hg, Wg))
    grip_argmax = grip.argmax(dim=-1).cpu().numpy()
    rot_argmax  = rot.argmax(dim=-1).cpu().numpy()

    # Loop-invariant: inverse of the current EEF orientation.
    R_current_inv = ScipyR.from_quat(current_eef_quat).inv()
    # NOTE(review): ref_pos and R_current_inv are never advanced inside the
    # loop. Every action in the window is therefore the delta from the pose at
    # the *start* of the window to waypoint t, not the step-to-step increment.
    # Confirm this matches how the targets were built at training time.
    ref_pos = np.array(current_eef_pos, dtype=np.float64)

    actions = []
    pred_3d = []
    for t in range(T):
        # Centre of the argmax BEV cell, in full-resolution image pixels.
        px_full = (float(px_grid[t]) + 0.5) * scale_x
        py_full = (float(py_grid[t]) + 0.5) * scale_y
        # Height bin -> metric height, linear over [min_h, max_h].
        height = (h_bins[t] / max(Z - 1, 1)) * (max_h - min_h) + min_h
        p3d = recover_3d_from_direct_keypoint_and_height(
            np.array([px_full, py_full], dtype=np.float64), float(height),
            camera_pose_bev, cam_K_bev_pixel,
        )
        if p3d is None:
            # Lifting failed: hold the previous waypoint, or the start pose.
            p3d = pred_3d[-1] if pred_3d else ref_pos.copy()
        pred_3d.append(p3d)

        # Translation: clamp the step length, then normalise to [-1, 1].
        delta_pos = p3d - ref_pos
        norm = np.linalg.norm(delta_pos)
        if norm > max_delta:
            delta_pos = delta_pos / norm * max_delta
        delta_norm = np.clip(delta_pos / OSC_POS_SCALE, -1.0, 1.0)

        # Rotation: bin centre -> PCA coordinate -> Euler xyz -> relative rotvec.
        pca_val = rot_pca_min + (rot_argmax[t] + 0.5) / n_rot * (rot_pca_max - rot_pca_min)
        euler_pred = rot_pca_mean + pca_val * rot_pca_axis
        R_delta = ScipyR.from_euler('xyz', euler_pred) * R_current_inv
        delta_rot_norm = np.clip(R_delta.as_rotvec() / OSC_ROT_SCALE, -1.0, 1.0)

        # Gripper: bin -> continuous value -> binary +1 / -1 command.
        grip_continuous = (grip_argmax[t] / max(n_grip - 1, 1)) * (max_g - min_g) + min_g
        gripper_cmd = 1.0 if grip_continuous > 0.0 else -1.0

        action = np.zeros(7, dtype=np.float32)
        action[:3]  = delta_norm
        action[3:6] = delta_rot_norm
        action[6]   = gripper_cmd
        actions.append(action)

    return actions, pred_3d


def main():
    """Run a closed-loop LIBERO evaluation of a DinoVolumeQuery2View checkpoint.

    Command-line arguments select the checkpoint, the benchmark suite, the
    task id, the number of episodes, the per-episode step budget and the
    environment seed.

    Each episode proceeds as follows:

    1. The environment is restored to the initial state of one recorded demo.
    2. Five zero actions are stepped.
    3. The policy is queried repeatedly. Each query feeds the model the
       agentview and eye-in-hand images, the EEF start keypoint and the live
       wrist-camera intrinsics/extrinsics. The whole decoded action window is
       then executed open-loop before the next query.

    An episode counts as a success when ``env.step`` returns ``done``. The
    success rate and the mean number of steps are printed at the end.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint", type=str, required=True)
    p.add_argument("--benchmark",  type=str, default="libero_spatial")
    p.add_argument("--task_id",    type=int, default=0)
    p.add_argument("--n_episodes", type=int, default=10)
    p.add_argument("--max_steps",  type=int, default=500)
    p.add_argument("--seed",       type=int, default=0)
    args = p.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # weights_only=False un-pickles arbitrary Python objects: only load
    # checkpoints you produced or otherwise trust.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)

    min_h, max_h = float(ckpt["min_height"]), float(ckpt["max_height"])
    min_g, max_g = float(ckpt["min_grip"]),   float(ckpt["max_grip"])
    rot_pca_mean = np.asarray(ckpt["rot_pca_mean"], dtype=np.float64)
    rot_pca_axis = np.asarray(ckpt["rot_pca_axis"], dtype=np.float64)
    rot_pca_min  = float(ckpt["rot_pca_min"]);  rot_pca_max = float(ckpt["rot_pca_max"])
    n_rot_bins   = int(ckpt["n_rot_bins"])
    # Take the other bin counts from the checkpoint too, when it records them.
    # 32 is the value this script previously hard-coded.
    n_height_bins = int(ckpt.get("n_height_bins", 32))
    n_grip_bins   = int(ckpt.get("n_gripper_bins", 32))
    n_window     = int(ckpt.get("n_window", 8))
    image_size   = int(ckpt.get("image_size", LIBERO_IMG))
    bev_K_norm   = np.asarray(ckpt["bev_K_norm"], dtype=np.float32)
    bev_extrinsic = np.asarray(ckpt["bev_extrinsic"], dtype=np.float32)
    print(f"Loaded ckpt: epoch={ckpt.get('epoch', '?')}, n_window={n_window}")

    model = DinoVolumeQuery2View(
        n_window=n_window, n_height_bins=n_height_bins, n_gripper_bins=n_grip_bins,
        n_rot_bins=n_rot_bins, image_size=image_size, pred_size=PRED_SIZE,
        rotation_mode='1d_pca',
    ).to(device).eval()
    load_info = model.load_state_dict(ckpt["model_state_dict"], strict=False)
    # strict=False tolerates key mismatches silently. Report them so a
    # model/checkpoint mismatch does not masquerade as a poor policy.
    if load_info.missing_keys:
        print(f"WARNING: {len(load_info.missing_keys)} model keys missing from checkpoint, "
              f"e.g. {load_info.missing_keys[:5]}")
    if load_info.unexpected_keys:
        print(f"WARNING: {len(load_info.unexpected_keys)} unexpected checkpoint keys, "
              f"e.g. {load_info.unexpected_keys[:5]}")

    # Precompute the static BEV world-XYZ table. The third argument is
    # presumably the height-bin count spanning [min_h, max_h], so it is kept
    # in sync with the model's n_height_bins.
    bev_xyz_table = build_bev_world_xyz_table(
        torch.tensor(bev_K_norm,   dtype=torch.float32, device=device),
        torch.tensor(bev_extrinsic, dtype=torch.float32, device=device),
        n_height_bins, min_h, max_h, PRED_SIZE, PRED_SIZE, image_size, device,
    )

    # libero env
    bench = bm.get_benchmark_dict()[args.benchmark]()
    task  = bench.get_task(args.task_id)
    print(f"Task: [{args.benchmark}] {task.name}")
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(args.task_id))
    with h5py.File(demo_path, "r") as f:
        # NOTE(review): this is lexicographic order (demo_0, demo_1, demo_10, ...),
        # not numeric. Changing it changes which init states are evaluated.
        demo_keys = sorted([k for k in f["data"].keys() if k.startswith("demo_")])
        init_states = [f[f"data/{k}/states"][0] for k in demo_keys]
    n_episodes = min(args.n_episodes, len(init_states))
    bddl = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
    env = OffScreenRenderEnv(
        bddl_file_name=bddl,
        camera_heights=image_size, camera_widths=image_size,
        camera_names=["agentview", "robot0_eye_in_hand"],
    )

    # BEV intrinsic K in pixel space (for 3D recovery in decode)
    K_bev_pixel = bev_K_norm.copy().astype(np.float64)
    K_bev_pixel[0] *= image_size; K_bev_pixel[1] *= image_size
    camera_pose_bev = bev_extrinsic.astype(np.float64)

    successes = []
    step_counts = []
    try:
        env.seed(args.seed); env.reset()
        for ep_idx in tqdm(range(n_episodes), desc="Episodes"):
            env.reset()
            obs = env.set_init_state(init_states[ep_idx])
            for _ in range(5):
                obs, _, _, _ = env.step(np.zeros(7, dtype=np.float32))

            done = False; success = False
            step_idx = 0

            while step_idx < args.max_steps and not done:
                current_eef_pos  = np.array(obs["robot0_eef_pos"],  dtype=np.float64)
                current_eef_quat = np.array(obs["robot0_eef_quat"], dtype=np.float64)

                rgb_bev_obs   = obs["agentview_image"]
                rgb_wrist_obs = obs["robot0_eye_in_hand_image"]
                img_bev   = preprocess_obs(rgb_bev_obs,   image_size).to(device)
                img_wrist = preprocess_obs(rgb_wrist_obs, image_size).to(device)

                # Project the EEF into the agentview image with the live simulator
                # camera to get the start keypoint. Decoding below uses the
                # checkpoint's bev_extrinsic / bev_K_norm instead.
                world_to_cam_bev = get_camera_transform_matrix(env.sim, "agentview", image_size, image_size)
                start_pix = eef_to_start_kp(current_eef_pos, world_to_cam_bev, image_size).to(device)
                if start_pix.dim() == 1:
                    start_pix = start_pix.unsqueeze(0)

                # Wrist extrinsic + K are fetched fresh each step because the wrist
                # camera moves with the EEF. K is divided by image_size, the same
                # convention as bev_K_norm, which is scaled back up by image_size above.
                wrist_ext_np = get_camera_extrinsic_matrix(env.sim, "robot0_eye_in_hand").astype(np.float32)
                wrist_K_np   = get_camera_intrinsic_matrix(env.sim, "robot0_eye_in_hand", image_size, image_size).astype(np.float32)
                wrist_K_norm = wrist_K_np.copy()
                wrist_K_norm[0] /= float(image_size); wrist_K_norm[1] /= float(image_size)
                wrist_K_t   = torch.tensor(wrist_K_norm, dtype=torch.float32, device=device).unsqueeze(0)
                wrist_ext_t = torch.tensor(wrist_ext_np, dtype=torch.float32, device=device).unsqueeze(0)

                with torch.no_grad():
                    out = model(img_bev, img_wrist, start_pix, bev_xyz_table, wrist_K_t, wrist_ext_t)

                window_actions, _ = decode_actions(
                    out, camera_pose_bev, K_bev_pixel,
                    current_eef_pos, current_eef_quat,
                    min_h, max_h, min_g, max_g,
                    rot_pca_mean, rot_pca_axis, rot_pca_min, rot_pca_max,
                    image_size=image_size,
                )
                if not window_actions:
                    # An empty window never advances step_idx, which would loop forever.
                    raise RuntimeError("decode_actions returned an empty action window")
                # Execute the whole window open-loop; stop early on success or budget.
                for action in window_actions:
                    obs, _, done, _ = env.step(action)
                    step_idx += 1
                    if done:
                        success = True
                        break
                    if step_idx >= args.max_steps:
                        break

            successes.append(success)
            step_counts.append(step_idx)
    finally:
        # Release the simulator / offscreen renderer even if a rollout raises.
        env.close()

    if not successes:
        print("\nNo episodes were run (n_episodes is 0 or no demo init states were found).")
        return
    sr = float(np.mean(successes))
    print(f"\nFinal: {sum(successes)}/{n_episodes} = {sr:.1%}  avg_steps={float(np.mean(step_counts)):.1f}")


# Script entry point: run the evaluation only when executed directly, not on import.
if __name__ == "__main__":
    main()
