"""OOD closed-loop libero eval for DinoVolumeQuery2View — adds shift/clean/teleport flags.

Based on eval_libero_2view.py with these additions:
  --shift_dx, --shift_dy  : shift pick+place objects in init_state (deltas, not absolute)
  --clean_scene           : hide distractors + furniture (match OOD train conditions)
  --zero_rotation         : zero out predicted rotation deltas (use upright gripper)
  --teleport              : closed-loop servo to predicted 3D target, then apply gripper
"""
import os, sys, argparse, json
import numpy as np
import torch
import cv2
import h5py
from pathlib import Path
from tqdm import tqdm
from scipy.spatial.transform import Rotation as ScipyR

# Make this script's directory and the local LIBERO checkout importable.
# These inserts run before the libero / model imports further down.
sys.path.insert(0, os.path.dirname(__file__))
sys.path.insert(0, "/data/cameron/LIBERO")
# setdefault only fills in missing variables: values already present in the
# environment take precedence over these machine-specific defaults.
os.environ.setdefault("DINO_REPO_DIR",     "/data/cameron/keygrip/dinov3")
os.environ.setdefault("DINO_WEIGHTS_PATH", "/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth")
# Rendering backend defaults (OSMesa) — presumably chosen for headless
# off-screen rendering; confirm against the target machine.
os.environ.setdefault("MUJOCO_GL", "osmesa")
os.environ.setdefault("PYOPENGL_PLATFORM", "osmesa")

from libero.libero import benchmark as bm, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_extrinsic_matrix, get_camera_intrinsic_matrix, get_camera_transform_matrix,
    project_points_from_world_to_camera,
)

from model_dino_volume_query_2view import DinoVolumeQuery2View, PRED_SIZE, build_bev_world_xyz_table
from model_dino_volume_query_dualfrustum import (
    DinoVolumeQuery2ViewDualFrustum, build_wrist_world_xyz_table_batched,
)
from utils import recover_3d_from_direct_keypoint_and_height
from eval import preprocess_obs, eef_to_start_kp

# Default square image side (pixels); used when the checkpoint has no "image_size".
LIBERO_IMG = 448
# Divisor applied to a metric position delta before it is clipped to [-1, 1]
# and written into action[:3].
OSC_POS_SCALE = 0.05
# Divisor applied to a rotation-vector delta before it is clipped to [-1, 1]
# and written into action[3:6].
OSC_ROT_SCALE = 0.5

# Task 0 object qpos offsets (state has +1 prefix)
# i.e. the helpers below index the flattened state at `offset + 1`.
TASK0_PICK_PLACE_QPOS = [9, 37]   # bowl, plate
TASK0_DISTRACTOR_QPOS = [16, 23, 30]
# xyz written over each distractor's first three state slots to move it away.
DISTRACTOR_FAR = np.array([10.0, 10.0, 0.9])


def shift_init_state(state, dx, dy):
    """Return a copy of ``state`` with the pick and place objects translated.

    For every offset in ``TASK0_PICK_PLACE_QPOS`` the entry at ``offset + 1``
    is increased by ``dx`` and the following entry by ``dy``. The input is
    left untouched.
    """
    shifted = state.copy()
    for qpos_offset in TASK0_PICK_PLACE_QPOS:
        # The flattened state carries one leading entry before qpos.
        base = qpos_offset + 1
        for axis, delta in enumerate((dx, dy)):
            shifted[base + axis] += delta
    return shifted


def hide_distractors_in_state(state):
    """Return a copy of ``state`` with every distractor moved to ``DISTRACTOR_FAR``.

    For each offset in ``TASK0_DISTRACTOR_QPOS`` the three entries starting at
    ``offset + 1`` are overwritten with the far-away xyz. The input is left
    untouched.
    """
    relocated = state.copy()
    for qpos_offset in TASK0_DISTRACTOR_QPOS:
        first = qpos_offset + 1          # skip the leading state entry
        last = first + 3
        relocated[first:last] = DISTRACTOR_FAR
    return relocated


def reposition_camera(sim, camera_name, theta_deg, phi_deg):
    """Move a camera on a spherical cap around its table look-at point.

    Per the original note, this matches generate_ood_viewpoint.py's grid.
    The look-at point is where the camera's current viewing ray meets the
    plane z = 0.90. The camera keeps its distance to that point, is tilted by
    ``theta_deg`` away from its current direction (``phi_deg`` selects the
    tilt azimuth) and is re-aimed at the look-at point. The new pose is
    written to ``sim.model.cam_pos`` / ``sim.model.cam_quat`` and
    ``sim.forward()`` is called.
    """
    from scipy.spatial.transform import Rotation

    table_height = 0.90
    world_z = np.array([0.0, 0.0, 1.0])

    cam_id = sim.model.camera_name2id(camera_name)
    origin = sim.data.cam_xpos[cam_id].copy()
    # The viewing direction is the negated third column of the camera matrix.
    view_dir = -sim.data.cam_xmat[cam_id].reshape(3, 3)[:, 2]

    # Intersect the viewing ray with the table plane (epsilon avoids /0).
    ray_len = (table_height - origin[2]) / (view_dir[2] + 1e-8)
    focus = origin + ray_len * view_dir
    to_cam = origin - focus
    radius = np.linalg.norm(to_cam)
    axis = to_cam / radius

    # Orthonormal frame around the current camera direction; fall back to
    # the x axis when that direction is nearly vertical.
    ref_up = np.array([1, 0, 0.0]) if abs(np.dot(axis, world_z)) > 0.99 else world_z
    side = np.cross(axis, ref_up)
    side = side / np.linalg.norm(side)
    cap_up = np.cross(side, axis)

    theta = np.radians(theta_deg)
    phi = np.radians(phi_deg)
    direction = (np.sin(theta) * np.cos(phi) * side +
                 np.sin(theta) * np.sin(phi) * cap_up +
                 np.cos(theta) * axis)
    moved = focus + radius * direction

    # Build the camera frame: z axis points away from the look-at point.
    look = focus - moved
    look = look / (np.linalg.norm(look) + 1e-12)
    z_axis = -look
    hint = np.array([0.0, 1.0, 0.0]) if abs(np.dot(look, world_z)) > 0.99 else world_z
    x_axis = np.cross(hint, z_axis)
    x_axis = x_axis / (np.linalg.norm(x_axis) + 1e-12)
    y_axis = np.cross(z_axis, x_axis)
    frame = np.column_stack([x_axis, y_axis, z_axis])

    # scipy returns (x, y, z, w); rotate it to (w, x, y, z) before storing.
    quat_xyzw = Rotation.from_matrix(frame).as_quat()
    sim.model.cam_pos[cam_id] = moved
    sim.model.cam_quat[cam_id] = np.roll(quat_xyzw, 1)
    sim.forward()


def apply_clean_scene(sim):
    """Hide furniture + distractors to match OOD training data.

    Furniture bodies are moved to z = -5 and ``sim.forward()`` is called;
    every geom belonging to a distractor body then gets alpha 0. Bodies whose
    name lookup fails are skipped silently (best effort).
    """
    furniture_names = ("wooden_cabinet_1_main", "flat_stove_1_main")
    distractor_names = (
        "akita_black_bowl_2_main",
        "cookies_1_main",
        "glazed_rim_porcelain_ramekin_1_main",
    )

    for body_name in furniture_names:
        try:
            body_id = sim.model.body_name2id(body_name)
            sim.model.body_pos[body_id] = np.array([0, 0, -5.0])
        except Exception:
            continue
    sim.forward()

    hidden_body_ids = set()
    for body_name in distractor_names:
        try:
            hidden_body_ids.add(sim.model.body_name2id(body_name))
        except Exception:
            continue

    # Only the alpha channel changes, so the geoms stay in the simulation.
    for geom_id in range(sim.model.ngeom):
        if sim.model.geom_bodyid[geom_id] in hidden_body_ids:
            sim.model.geom_rgba[geom_id][3] = 0.0


def decode_actions(out_dict, camera_pose_bev, cam_K_bev_pixel,
                    current_eef_pos, current_eef_quat,
                    min_h, max_h, min_g, max_g,
                    rot_pca_mean, rot_pca_axis, rot_pca_min, rot_pca_max,
                    image_size=LIBERO_IMG, max_delta=0.05):
    """Turn one batch element of model logits into a window of 7-D actions.

    For each timestep the arg-max cell of ``volume_logits`` gives a height bin
    and a grid pixel; these are lifted to a 3D point with
    ``recover_3d_from_direct_keypoint_and_height``. Every action is expressed
    relative to the *current* end-effector pose (the reference is not advanced
    between timesteps).

    Returns:
        (actions, pred_3d_list): a list of float32 arrays laid out as
        [dx, dy, dz, rx, ry, rz, gripper] and the list of 3D targets.
    """
    volume_logits = out_dict["volume_logits"][0]
    gripper_logits = out_dict["gripper_logits"][0]
    rotation_logits = out_dict["rotation_logits"][0]

    n_steps, n_heights, grid_h, grid_w = volume_logits.shape
    n_grip_bins = gripper_logits.shape[-1]
    n_rot_bins = rotation_logits.shape[-1]
    pix_per_cell = image_size / grid_h

    # Arg-max over the flattened (height, row, col) volume, then unflatten.
    best_flat = volume_logits.reshape(n_steps, -1).argmax(dim=-1).cpu().numpy()
    height_bins, rows, cols = np.unravel_index(best_flat, (n_heights, grid_h, grid_w))
    grip_choice = gripper_logits.argmax(dim=-1).cpu().numpy()
    rot_choice = rotation_logits.argmax(dim=-1).cpu().numpy()

    eef_rot_inv = ScipyR.from_quat(current_eef_quat).inv()
    anchor_pos = current_eef_pos.copy()
    actions = []
    pred_3d_list = []

    for t in range(n_steps):
        # Cell centre in full-resolution pixel coordinates.
        pixel = np.array(
            [(float(cols[t]) + 0.5) * pix_per_cell, (float(rows[t]) + 0.5) * pix_per_cell],
            dtype=np.float64,
        )
        height = (height_bins[t] / max(n_heights - 1, 1)) * (max_h - min_h) + min_h
        target = recover_3d_from_direct_keypoint_and_height(
            pixel, float(height), camera_pose_bev, cam_K_bev_pixel,
        )
        if target is None:
            # Fall back to the previous target, or to the current position.
            target = pred_3d_list[-1] if pred_3d_list else anchor_pos.copy()
        pred_3d_list.append(target)

        # Position: cap the step length, then normalise to the action range.
        step = target - anchor_pos
        step_len = np.linalg.norm(step)
        if step_len > max_delta:
            step = step / step_len * max_delta
        pos_cmd = np.clip(step / OSC_POS_SCALE, -1.0, 1.0)

        # Rotation: bin centre -> PCA coordinate -> Euler angles -> delta.
        pca_coord = rot_pca_min + (rot_choice[t] + 0.5) / n_rot_bins * (rot_pca_max - rot_pca_min)
        target_rot = ScipyR.from_euler('xyz', rot_pca_mean + pca_coord * rot_pca_axis)
        rot_cmd = np.clip((target_rot * eef_rot_inv).as_rotvec() / OSC_ROT_SCALE, -1.0, 1.0)

        # Gripper: de-normalise the bin, then threshold at zero.
        grip_value = (grip_choice[t] / max(n_grip_bins - 1, 1)) * (max_g - min_g) + min_g
        if grip_value > 0.0:
            gripper_cmd = 1.0
        else:
            gripper_cmd = -1.0

        action = np.zeros(7, dtype=np.float32)
        action[:3] = pos_cmd
        action[3:6] = rot_cmd
        action[6] = gripper_cmd
        actions.append(action)

    return actions, pred_3d_list


def servo_to_pos(env, target_pos, gripper_cmd, max_servo=25, threshold=0.005):
    """Closed-loop servo to a 3D target position with given gripper command.

    Each iteration reads the current end-effector position, and unless it is
    within ``threshold`` of ``target_pos`` sends a position-only action
    (rotation deltas zero) carrying ``gripper_cmd``.

    If the end effector is already within ``threshold`` on the first check,
    one hold-position step carrying ``gripper_cmd`` is still issued. Without
    it the gripper command would never reach the environment, and a caller
    that advances its step budget by ``n_steps`` could loop forever on an
    unchanged observation.

    Args:
        env: environment wrapper exposing ``step`` and ``env._get_observations``.
        target_pos: target xyz for the end effector.
        gripper_cmd: value written to ``action[6]`` on every step.
        max_servo: maximum number of servo steps; no step is taken if <= 0.
        threshold: distance below which the target counts as reached.

    Returns (obs, n_steps, done) — n_steps is the actual env.step count consumed.
    """
    target = np.asarray(target_pos, dtype=np.float64)
    obs = None
    n_steps = 0
    done = False
    for _ in range(max_servo):
        cur_obs = env.env._get_observations()
        cur_pos = np.array(cur_obs["robot0_eef_pos"], dtype=np.float64)
        delta = target - cur_pos
        if np.linalg.norm(delta) < threshold:
            obs = cur_obs
            break
        action = np.zeros(7, dtype=np.float32)
        action[:3] = np.clip(delta / OSC_POS_SCALE, -1.0, 1.0)
        action[6] = gripper_cmd
        obs, _, done, _ = env.step(action)
        n_steps += 1
        if done:
            break

    if n_steps == 0 and max_servo > 0:
        # Already at the target: hold position but still apply the gripper.
        hold = np.zeros(7, dtype=np.float32)
        hold[6] = gripper_cmd
        obs, _, done, _ = env.step(hold)
        n_steps = 1

    if obs is None:
        obs = env.env._get_observations()
    return obs, n_steps, done


def _parse_args():
    """Parse the command-line options for an OOD evaluation run."""
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint", type=str, required=True)
    p.add_argument("--benchmark",  type=str, default="libero_spatial")
    p.add_argument("--task_id",    type=int, default=0)
    p.add_argument("--n_episodes", type=int, default=5)
    p.add_argument("--max_steps",  type=int, default=600)
    p.add_argument("--seed",       type=int, default=0)
    p.add_argument("--shift_dx",   type=float, default=0.0)
    p.add_argument("--shift_dy",   type=float, default=0.0)
    p.add_argument("--clean_scene", action="store_true")
    p.add_argument("--zero_rotation", action="store_true")
    p.add_argument("--teleport",     action="store_true")
    p.add_argument("--cam_theta",    type=float, default=0.0, help="BEV camera polar angle (deg)")
    p.add_argument("--cam_phi",      type=float, default=0.0, help="BEV camera azimuth (deg)")
    p.add_argument("--out_json",   type=str, default="")
    return p.parse_args()


def _decode_anchored_window(out, bev_xyz_np, wrist_xyz_np,
                            current_eef_pos, current_eef_quat, calib, max_delta=0.05):
    """Decode dual-frustum logits into (window_actions, pred_3d_list).

    The volume logits are (B, T, Z, 2, H, W). For each timestep the arg-max
    is taken across BOTH anchors; anchor 0 looks the 3D target up in the BEV
    xyz table, any other anchor in the wrist xyz table. All deltas are
    relative to the current end-effector pose.
    """
    vol = out["volume_logits"][0]                       # (T, Z, A, H, W)
    n_steps, n_z, n_anchor, grid_h, grid_w = vol.shape
    flat = vol.reshape(n_steps, -1).argmax(dim=-1).cpu().numpy()
    # Flat index decomposes as z * A*H*W + anchor * H*W + y * W + x.
    z_idx, anchor, py, px = np.unravel_index(flat, (n_z, n_anchor, grid_h, grid_w))

    grip = out["gripper_logits"][0]
    rot = out["rotation_logits"][0]
    grip_argmax = grip.argmax(dim=-1).cpu().numpy()
    rot_argmax = rot.argmax(dim=-1).cpu().numpy()
    n_grip = grip.shape[-1]
    n_rot = rot.shape[-1]

    min_g, max_g = calib["min_g"], calib["max_g"]
    rot_pca_min, rot_pca_max = calib["rot_pca_min"], calib["rot_pca_max"]
    r_current_inv = ScipyR.from_quat(current_eef_quat).inv()
    ref_pos = current_eef_pos.copy()

    window_actions = []
    pred_3d_list = []
    for t in range(n_steps):
        table = bev_xyz_np if anchor[t] == 0 else wrist_xyz_np
        p3d = table[z_idx[t], py[t], px[t]].astype(np.float64)
        pred_3d_list.append(p3d)

        # Position: cap the step length, then normalise to the action range.
        delta_pos = p3d - ref_pos
        norm = np.linalg.norm(delta_pos)
        if norm > max_delta:
            delta_pos = delta_pos / norm * max_delta
        delta_norm = np.clip(delta_pos / OSC_POS_SCALE, -1.0, 1.0)

        # Rotation: bin centre -> PCA coordinate -> Euler angles -> delta.
        pca_val = rot_pca_min + (rot_argmax[t] + 0.5) / n_rot * (rot_pca_max - rot_pca_min)
        euler_pred = calib["rot_pca_mean"] + pca_val * calib["rot_pca_axis"]
        r_delta = ScipyR.from_euler('xyz', euler_pred) * r_current_inv
        delta_rot_norm = np.clip(r_delta.as_rotvec() / OSC_ROT_SCALE, -1.0, 1.0)

        # Gripper: de-normalise the bin, then threshold at zero.
        grip_cont = (grip_argmax[t] / max(n_grip - 1, 1)) * (max_g - min_g) + min_g
        gripper_cmd = 1.0 if grip_cont > 0.0 else -1.0

        action = np.zeros(7, dtype=np.float32)
        action[:3] = delta_norm
        action[3:6] = delta_rot_norm
        action[6] = gripper_cmd
        window_actions.append(action)
    return window_actions, pred_3d_list


def _run_episode(env, model, device, args, init_state, image_size, calib):
    """Roll out one closed-loop episode; return (success, steps_used)."""
    min_h, max_h = calib["min_h"], calib["max_h"]

    env.reset()
    if args.clean_scene:
        apply_clean_scene(env.env.sim)
    init_state = init_state.copy()
    if args.shift_dx != 0 or args.shift_dy != 0:
        init_state = shift_init_state(init_state, args.shift_dx, args.shift_dy)
        init_state = hide_distractors_in_state(init_state)
    obs = env.set_init_state(init_state)
    for _ in range(5):
        obs, _, _, _ = env.step(np.zeros(7, dtype=np.float32))

    if args.cam_theta != 0.0 or args.cam_phi != 0.0:
        reposition_camera(env.env.sim, "agentview", args.cam_theta, args.cam_phi)
        obs, _, _, _ = env.step(np.zeros(7, dtype=np.float32))

    # Rebuild the BEV xyz table from THIS episode's actual BEV camera. It is
    # read after any repositioning and is constant for the episode, so
    # everything derived from it is computed once here, not per step.
    cur_bev_K = get_camera_intrinsic_matrix(env.env.sim, "agentview", image_size, image_size).astype(np.float32)
    cur_bev_K_norm = cur_bev_K.copy()
    cur_bev_K_norm[0] /= image_size
    cur_bev_K_norm[1] /= image_size
    cur_bev_ext = get_camera_extrinsic_matrix(env.env.sim, "agentview").astype(np.float32)
    bev_xyz_table = build_bev_world_xyz_table(
        torch.tensor(cur_bev_K_norm, dtype=torch.float32, device=device),
        torch.tensor(cur_bev_ext,    dtype=torch.float32, device=device),
        32, min_h, max_h, PRED_SIZE, PRED_SIZE, image_size, device,
    )
    if bev_xyz_table.dim() == 5:
        bev_xyz_np = bev_xyz_table[0].cpu().numpy()
    else:
        bev_xyz_np = bev_xyz_table.cpu().numpy()
    cur_bev_K_t = torch.tensor(cur_bev_K_norm, dtype=torch.float32, device=device).unsqueeze(0)
    cur_bev_ext_t = torch.tensor(cur_bev_ext, dtype=torch.float32, device=device).unsqueeze(0)

    done = False
    success = False
    step_idx = 0
    while step_idx < args.max_steps and not done:
        current_eef_pos  = np.array(obs["robot0_eef_pos"],  dtype=np.float64)
        current_eef_quat = np.array(obs["robot0_eef_quat"], dtype=np.float64)
        img_bev   = preprocess_obs(obs["agentview_image"],          image_size).to(device)
        img_wrist = preprocess_obs(obs["robot0_eye_in_hand_image"], image_size).to(device)

        world_to_cam_bev = get_camera_transform_matrix(env.env.sim, "agentview", image_size, image_size)
        start_pix = eef_to_start_kp(current_eef_pos, world_to_cam_bev, image_size).to(device)
        if start_pix.dim() == 1:
            start_pix = start_pix.unsqueeze(0)

        # The wrist camera moves with the arm, so its table is rebuilt per step.
        wrist_ext_np = get_camera_extrinsic_matrix(env.env.sim, "robot0_eye_in_hand").astype(np.float32)
        wrist_K_np   = get_camera_intrinsic_matrix(env.env.sim, "robot0_eye_in_hand", image_size, image_size).astype(np.float32)
        wrist_K_norm = wrist_K_np.copy()
        wrist_K_norm[0] /= float(image_size)
        wrist_K_norm[1] /= float(image_size)
        wrist_K_t   = torch.tensor(wrist_K_norm, dtype=torch.float32, device=device).unsqueeze(0)
        wrist_ext_t = torch.tensor(wrist_ext_np, dtype=torch.float32, device=device).unsqueeze(0)
        wrist_xyz_table = build_wrist_world_xyz_table_batched(
            wrist_K_t, wrist_ext_t, 32, min_h, max_h, PRED_SIZE, PRED_SIZE, image_size,
        )

        with torch.no_grad():
            out = model(img_bev, img_wrist, start_pix, bev_xyz_table, wrist_K_t, wrist_ext_t,
                         cur_bev_K_t, cur_bev_ext_t, wrist_xyz_table)

        window_actions, pred_3d_list = _decode_anchored_window(
            out, bev_xyz_np, wrist_xyz_table[0].cpu().numpy(),
            current_eef_pos, current_eef_quat, calib,
        )

        for t, action in enumerate(window_actions):
            if args.zero_rotation:
                action[3:6] = 0.0
            if args.teleport:
                # Servo to predicted 3D, then apply gripper
                gripper_cmd = float(action[6])
                obs, n_servo, ep_done = servo_to_pos(env, pred_3d_list[t], gripper_cmd, max_servo=25)
                if n_servo == 0 and not ep_done:
                    # The servo consumed no step (target already reached), so
                    # the gripper command never reached the env. Take one hold
                    # step: this applies the gripper and keeps the step budget
                    # advancing, so this loop cannot stall.
                    hold = np.zeros(7, dtype=np.float32)
                    hold[6] = gripper_cmd
                    obs, _, ep_done, _ = env.step(hold)
                    n_servo = 1
                step_idx += n_servo
                if ep_done or (hasattr(env.env, '_check_success') and env.env._check_success()):
                    success = True
                    done = True
                    break
            else:
                obs, _, done, _ = env.step(action)
                step_idx += 1
                if done:
                    success = True
                    break
            if step_idx >= args.max_steps:
                break

    return success, step_idx


def main():
    """Run the OOD closed-loop evaluation and optionally write a JSON report.

    Loads the checkpoint and model, builds the LIBERO environment for the
    selected task, replays up to ``--n_episodes`` demo initial states (with
    the optional shift / clean-scene / camera perturbations), prints the
    success rate and writes a JSON summary when ``--out_json`` is given.
    The environment is always closed, including on errors.
    """
    args = _parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only load checkpoints from a trusted source.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)

    calib = {
        "min_h": float(ckpt["min_height"]),
        "max_h": float(ckpt["max_height"]),
        "min_g": float(ckpt["min_grip"]),
        "max_g": float(ckpt["max_grip"]),
        "rot_pca_mean": np.asarray(ckpt["rot_pca_mean"], dtype=np.float64),
        "rot_pca_axis": np.asarray(ckpt["rot_pca_axis"], dtype=np.float64),
        "rot_pca_min": float(ckpt["rot_pca_min"]),
        "rot_pca_max": float(ckpt["rot_pca_max"]),
    }
    n_rot_bins   = int(ckpt["n_rot_bins"])
    n_window     = int(ckpt.get("n_window", 8))
    image_size   = int(ckpt.get("image_size", LIBERO_IMG))
    print(f"Loaded ckpt: epoch={ckpt['epoch']}, n_window={n_window}")
    print(f"  shift_dx={args.shift_dx:+.3f} shift_dy={args.shift_dy:+.3f}  clean={args.clean_scene}  zero_rot={args.zero_rotation}  teleport={args.teleport}")

    if (args.shift_dx != 0 or args.shift_dy != 0) and args.task_id != 0:
        # The shift / distractor indices are the TASK0_* constants.
        print(f"WARNING: shift uses task-0 qpos offsets but task_id={args.task_id}; "
              "the shifted init state may be wrong.")

    model = DinoVolumeQuery2ViewDualFrustum(
        n_window=n_window, n_height_bins=32, n_gripper_bins=32, n_rot_bins=n_rot_bins,
        image_size=image_size, pred_size=PRED_SIZE,
        rotation_mode='1d_pca',
    ).to(device).eval()
    # strict=False tolerates key mismatches; report them instead of hiding them.
    load_result = model.load_state_dict(ckpt["model_state_dict"], strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print(f"WARNING: state_dict mismatch: {len(load_result.missing_keys)} missing, "
              f"{len(load_result.unexpected_keys)} unexpected keys")

    bench = bm.get_benchmark_dict()[args.benchmark]()
    task  = bench.get_task(args.task_id)
    print(f"Task: [{args.benchmark}] {task.name}")
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(args.task_id))
    with h5py.File(demo_path, "r") as f:
        demo_keys = sorted([k for k in f["data"].keys() if k.startswith("demo_")])
        init_states = [f[f"data/{k}/states"][0] for k in demo_keys]
    n_episodes = min(args.n_episodes, len(init_states))
    bddl = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
    env = OffScreenRenderEnv(
        bddl_file_name=bddl,
        camera_heights=image_size, camera_widths=image_size,
        camera_names=["agentview", "robot0_eye_in_hand"],
    )

    successes, step_counts = [], []
    try:
        env.seed(args.seed)
        env.reset()
        if args.clean_scene:
            apply_clean_scene(env.env.sim)
            print("✓ Clean scene applied")

        for ep_idx in tqdm(range(n_episodes), desc="Episodes"):
            success, n_steps = _run_episode(
                env, model, device, args, init_states[ep_idx], image_size, calib,
            )
            successes.append(int(success))
            step_counts.append(n_steps)
    finally:
        # Release the simulator / off-screen renderer even if an episode fails.
        env.close()

    # Guard the zero-episode case so np.mean is never called on an empty list.
    sr = float(np.mean(successes)) if successes else 0.0
    avg_steps = float(np.mean(step_counts)) if step_counts else 0.0
    print(f"\nSuccess Rate: {sr:.1%}  ({sum(successes)}/{n_episodes})  avg_steps={avg_steps:.1f}")

    if args.out_json:
        out = {
            "checkpoint": args.checkpoint, "shift_dx": args.shift_dx, "shift_dy": args.shift_dy,
            "n_episodes": n_episodes, "successes": successes, "step_counts": step_counts,
            "success_rate": sr, "clean_scene": args.clean_scene,
            "zero_rotation": args.zero_rotation, "teleport": args.teleport,
            # Extra keys so runs with different tasks / viewpoints stay distinguishable.
            "benchmark": args.benchmark, "task_id": args.task_id,
            "cam_theta": args.cam_theta, "cam_phi": args.cam_phi,
        }
        Path(args.out_json).parent.mkdir(parents=True, exist_ok=True)
        with open(args.out_json, "w") as f:
            json.dump(out, f, indent=2)


if __name__ == "__main__":
    main()
