"""OOD closed-loop LIBERO eval for DinoVolumeQuery (single-view query-MLP).

Mirrors eval_libero_2view_ood.py: adds object-shift, clean-scene, zero-rotation,
teleport, and camera-reposition flags on top of eval_libero_query.py.
"""
import os, sys, argparse, json
import numpy as np
import torch
import cv2
import h5py
from pathlib import Path
from tqdm import tqdm
from scipy.spatial.transform import Rotation as ScipyR

# Make the sibling eval/model modules in this script's directory and the local
# LIBERO checkout importable for the imports below.
sys.path.insert(0, os.path.dirname(__file__))
sys.path.insert(0, "/data/cameron/LIBERO")
# setdefault: these are only defaults; values already set in the environment win.
# DINO_* presumably tell model_dino_volume_query where the DINOv3 repo/weights live — confirm.
os.environ.setdefault("DINO_REPO_DIR",     "/data/cameron/keygrip/dinov3")
os.environ.setdefault("DINO_WEIGHTS_PATH", "/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth")
# Request headless OSMesa rendering for MuJoCo / PyOpenGL. These are set before the
# libero/robosuite imports below, presumably because the GL backend is picked at import time.
os.environ.setdefault("MUJOCO_GL", "osmesa")
os.environ.setdefault("PYOPENGL_PLATFORM", "osmesa")

from libero.libero import benchmark as bm, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import get_camera_extrinsic_matrix, get_camera_intrinsic_matrix, get_camera_transform_matrix

from model_dino_volume_query import DinoVolumeQuery, PRED_SIZE
from utils import recover_3d_from_direct_keypoint_and_height
from eval import preprocess_obs, get_camera_params, eef_to_start_kp
from eval_libero_query import decode_query_actions
from eval_libero_2view_ood import (
    shift_init_state, hide_distractors_in_state, apply_clean_scene, servo_to_pos,
    reposition_camera,
)

# Default square render / model input size, used when the checkpoint has no "image_size".
LIBERO_IMG = 448


def main():
    """Run closed-loop LIBERO rollouts of a DinoVolumeQuery checkpoint under OOD perturbations.

    Flags select the benchmark/task/camera and the number of episodes. The episode
    count is capped by the number of demos for the task, because each episode starts
    from one demo's initial state. Flags also select the perturbations:

    * ``--shift_dx/--shift_dy``: shift the demo initial state with ``shift_init_state``,
      then hide distractors with ``hide_distractors_in_state``.
    * ``--clean_scene``: apply ``apply_clean_scene`` after every reset.
    * ``--cam_theta/--cam_phi``: move the eval camera with ``reposition_camera``.
    * ``--zero_rotation``: zero the rotation components (``action[3:6]``) of every action.
    * ``--teleport``: instead of stepping the predicted actions, servo the end effector
      to each predicted 3D waypoint with ``servo_to_pos`` (at most ``--max_servo``
      sim steps per waypoint).

    Prints the success rate. If ``--out_json`` is given, it also writes per-episode
    results and the full run configuration there.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--checkpoint", type=str, required=True)
    p.add_argument("--benchmark", default="libero_spatial")
    p.add_argument("--task_id", type=int, default=0)
    p.add_argument("--camera", default="agentview")
    p.add_argument("--n_episodes", type=int, default=5)
    p.add_argument("--max_steps", type=int, default=600)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--shift_dx", type=float, default=0.0)
    p.add_argument("--shift_dy", type=float, default=0.0)
    p.add_argument("--clean_scene", action="store_true")
    p.add_argument("--zero_rotation", action="store_true")
    p.add_argument("--teleport", action="store_true")
    p.add_argument("--max_servo", type=int, default=25,
                   help="max servo sim steps per predicted waypoint in --teleport mode")
    p.add_argument("--cam_theta", type=float, default=0.0)
    p.add_argument("--cam_phi", type=float, default=0.0)
    p.add_argument("--out_json", default="")
    args = p.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE: weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    ckpt = torch.load(args.checkpoint, map_location="cpu", weights_only=False)
    min_h, max_h = float(ckpt["min_height"]), float(ckpt["max_height"])
    min_g, max_g = float(ckpt["min_grip"]),   float(ckpt["max_grip"])
    rot_pca_mean = np.asarray(ckpt["rot_pca_mean"], dtype=np.float64)
    rot_pca_axis = np.asarray(ckpt["rot_pca_axis"], dtype=np.float64)
    rot_pca_min  = float(ckpt["rot_pca_min"]);  rot_pca_max = float(ckpt["rot_pca_max"])
    n_rot_bins   = int(ckpt["n_rot_bins"])
    n_window     = int(ckpt.get("n_window", 8))
    image_size   = int(ckpt.get("image_size", LIBERO_IMG))
    print(f"Loaded ckpt: epoch={ckpt['epoch']}, n_window={n_window}")
    print(f"  shift=({args.shift_dx:+.3f},{args.shift_dy:+.3f}) clean={args.clean_scene} "
          f"zero_rot={args.zero_rotation} teleport={args.teleport} "
          f"cam=({args.cam_theta:+.2f},{args.cam_phi:+.2f})")

    model = DinoVolumeQuery(
        n_window=n_window, n_height_bins=32, n_gripper_bins=32, n_rot_bins=n_rot_bins,
        image_size=image_size, pred_size=PRED_SIZE,
        use_eef=True, rotation_mode='1d_pca',
    ).to(device).eval()
    # strict=False tolerates key mismatches. Report them so that a wrong or partial
    # checkpoint does not silently evaluate a partly-initialized model.
    incompatible = model.load_state_dict(ckpt["model_state_dict"], strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"  state_dict (strict=False): {len(incompatible.missing_keys)} missing, "
              f"{len(incompatible.unexpected_keys)} unexpected keys "
              f"(first missing: {incompatible.missing_keys[:3]}, "
              f"first unexpected: {incompatible.unexpected_keys[:3]})")

    bench = bm.get_benchmark_dict()[args.benchmark]()
    task = bench.get_task(args.task_id)
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(args.task_id))
    with h5py.File(demo_path, "r") as f:
        demo_keys = sorted([k for k in f["data"].keys() if k.startswith("demo_")])
        init_states = [f[f"data/{k}/states"][0] for k in demo_keys]
    n_episodes = min(args.n_episodes, len(init_states))
    if n_episodes <= 0:
        # Without this guard the summary below would average empty lists (NaN).
        sys.exit(f"No episodes to run: --n_episodes={args.n_episodes}, "
                 f"demos available={len(init_states)} in {demo_path}")

    bddl = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
    env = OffScreenRenderEnv(
        bddl_file_name=bddl,
        camera_heights=image_size, camera_widths=image_size,
        camera_names=[args.camera],
    )

    successes, step_counts = [], []
    try:
        env.seed(args.seed); env.reset()
        if args.clean_scene:
            apply_clean_scene(env.env.sim)

        for ep_idx in tqdm(range(n_episodes), desc="Episodes"):
            env.reset()
            if args.clean_scene:
                apply_clean_scene(env.env.sim)
            init_state = init_states[ep_idx].copy()
            if args.shift_dx != 0 or args.shift_dy != 0:
                init_state = shift_init_state(init_state, args.shift_dx, args.shift_dy)
                init_state = hide_distractors_in_state(init_state)
            obs = env.set_init_state(init_state)
            # A few no-op steps after restoring the state (presumably to let the scene settle).
            for _ in range(5):
                obs, _, _, _ = env.step(np.zeros(7, dtype=np.float32))

            if args.cam_theta != 0.0 or args.cam_phi != 0.0:
                reposition_camera(env.env.sim, args.camera, args.cam_theta, args.cam_phi)
                # Take one more no-op step so `obs` is refreshed after the camera move.
                obs, _, _, _ = env.step(np.zeros(7, dtype=np.float32))

            # The camera stays fixed for the rest of the episode, so its params are computed once.
            world_to_camera, camera_pose, cam_K = get_camera_params(env.env.sim, args.camera, image_size)
            success = False
            step_idx = 0
            while step_idx < args.max_steps and not success:
                current_eef_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)
                current_eef_quat = np.array(obs["robot0_eef_quat"], dtype=np.float64)
                rgb_obs = obs[f"{args.camera}_image"]
                img_tensor = preprocess_obs(rgb_obs, image_size).to(device)
                start_pix = eef_to_start_kp(current_eef_pos, world_to_camera, image_size).to(device)
                if start_pix.dim() == 1:
                    start_pix = start_pix.unsqueeze(0)
                with torch.no_grad():
                    out = model(img_tensor, start_pix=start_pix)
                window_actions, pred_3d_list = decode_query_actions(
                    out, camera_pose, cam_K, current_eef_pos, current_eef_quat,
                    min_h, max_h, min_g, max_g,
                    rot_pca_mean, rot_pca_axis, rot_pca_min, rot_pca_max,
                    image_size=image_size,
                )
                # Execute the predicted window open-loop; re-predict from the new obs afterwards.
                for t, action in enumerate(window_actions):
                    if args.zero_rotation:
                        action[3:6] = 0.0
                    if args.teleport:
                        target = pred_3d_list[t].astype(np.float64)
                        gripper_cmd = float(action[6])
                        obs, n_servo, ep_done = servo_to_pos(env, target, gripper_cmd, max_servo=args.max_servo)
                        step_idx += n_servo
                        success = bool(ep_done or (hasattr(env.env, '_check_success')
                                                   and env.env._check_success()))
                    else:
                        obs, _, done, _ = env.step(action)
                        step_idx += 1
                        success = bool(done)
                    if success or step_idx >= args.max_steps:
                        break
            successes.append(int(success))
            step_counts.append(step_idx)
    finally:
        # Release the simulator / offscreen render context even if a rollout raises.
        env.close()

    sr = float(np.mean(successes))
    avg_steps = float(np.mean(step_counts))
    print(f"\nSuccess Rate: {sr:.1%}  ({sum(successes)}/{n_episodes})  avg_steps={avg_steps:.1f}")
    if args.out_json:
        Path(args.out_json).parent.mkdir(parents=True, exist_ok=True)
        results = {
            "checkpoint": args.checkpoint, "shift_dx": args.shift_dx, "shift_dy": args.shift_dy,
            "n_episodes": n_episodes, "successes": successes, "step_counts": step_counts,
            "success_rate": sr,
            # Full run configuration, so every OOD condition (not only the shift) is recoverable.
            "benchmark": args.benchmark, "task_id": args.task_id, "camera": args.camera,
            "seed": args.seed, "max_steps": args.max_steps,
            "clean_scene": args.clean_scene, "zero_rotation": args.zero_rotation,
            "teleport": args.teleport, "max_servo": args.max_servo,
            "cam_theta": args.cam_theta, "cam_phi": args.cam_phi,
            "avg_steps": avg_steps,
        }
        with open(args.out_json, "w") as f:
            json.dump(results, f, indent=2)


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
