"""Verify libero wrist camera extrinsics before building the dual-view model.

For a few frames of a libero_spatial demo, render BOTH agentview AND wrist cam side-by-side,
project the EEF world position into each image, and overlay a marker. Confirms:
  1. Both cameras can be rendered simultaneously
  2. The EEF projection lands ON the gripper in both views
  3. The wrist camera's extrinsics (moves with EEF) are correctly fetched per-frame
"""
import os, sys
sys.path.insert(0, "/data/cameron/para/libero")
sys.path.insert(0, "/data/cameron/LIBERO")
os.environ.setdefault("MUJOCO_GL", "osmesa")
import numpy as np
import h5py
import cv2
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from pathlib import Path

from libero.libero import benchmark as bm, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
    get_camera_transform_matrix,
    project_points_from_world_to_camera,
)

OUT = Path("/data/cameron/para/paper/figs/generated/libero_wrist_verify.png")
OUT.parent.mkdir(parents=True, exist_ok=True)

IMG = 448  # square render size (pixels) used for BOTH cameras and for projection

print("Setting up libero_spatial task 0 env with both cameras...")
bench = bm.get_benchmark_dict()["libero_spatial"]()
task = bench.get_task(0)
bddl = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
env = OffScreenRenderEnv(
    bddl_file_name=bddl,
    camera_heights=IMG, camera_widths=IMG,
    camera_names=["agentview", "robot0_eye_in_hand"],
)
env.seed(0); env.reset()

# Load the state SEQUENCE of demo 0 — each entry is a full sim state we can snap to.
demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(0))
with h5py.File(demo_path, "r") as f:
    demo_states = f["data/demo_0/states"][:]
T_total = demo_states.shape[0]
if T_total == 0:
    raise RuntimeError(f"{demo_path}: data/demo_0/states is empty; nothing to verify")
print(f"Demo 0 has {T_total} states")

# Sample snapshots across the demo for varied EEF poses: start, 1/4, 1/2 and just
# before the end. Every index is clamped into [0, T_total-1] so very short demos
# cannot index past the end (the raw `max(T_total - 5, 1)` would for T_total == 1).
# N_PANELS is derived from the pattern so the figure grid can never disagree with
# the number of sampled frames.
step_pattern = [
    min(t, T_total - 1)
    for t in (0, T_total // 4, T_total // 2, max(T_total - 5, 1))
]
N_PANELS = len(step_pattern)

def _project_unclipped(point_world, world_to_camera):
    """Project one world-frame 3D point through a 4x4 world->pixel transform.

    Same math as robosuite's ``project_points_from_world_to_camera`` (homogeneous
    multiply, divide by the third component) but WITHOUT its rounding and its
    clipping to the image rectangle. Clipping would make every point look
    "in-bounds", so the unclipped version is needed to check visibility.

    Returns:
        (u, v, depth): u = column, v = row (the same order the original code read
        from robosuite's (row, col) output), depth = homogeneous z. robosuite's
        transform uses an OpenCV-style camera frame, so depth <= 0 means the point
        is behind the camera and (u, v) are meaningless.
    """
    hom = world_to_camera @ np.append(np.asarray(point_world, dtype=np.float64), 1.0)
    with np.errstate(divide="ignore", invalid="ignore"):  # depth == 0 -> inf/nan
        u, v = hom[:2] / hom[2]
    return float(u), float(v), float(hom[2])


def _in_view(u, v, depth):
    """True iff the projection is in front of the camera and inside the IMG x IMG frame."""
    return depth > 0 and 0 <= u < IMG and 0 <= v < IMG


def _px(u, v):
    """Integer (col, row) for cv2 drawing.

    The value is clamped to the frame, so off-screen points are drawn pinned to
    the nearest edge. Non-finite values map to 0.
    """
    col, row = np.nan_to_num(np.clip([u, v], 0, IMG - 1))
    return int(col), int(row)


panels = []
for target_t in step_pattern:
    # Snap to that recorded state directly — no need to step
    env.reset()
    obs = env.set_init_state(demo_states[target_t])
    # Tiny physics settle with zero action (avoid action-terminates issues)
    for _ in range(2):
        obs, _, _, _ = env.step(np.zeros(7, dtype=np.float32))

    eef_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)
    # Sanity: also project EEF + 5cm world-Z to check the camera "looks down"
    eef_above = eef_pos + np.array([0.0, 0.0, 0.05])

    # AGENTVIEW (static)
    bev_img = np.flipud(obs["agentview_image"].astype(np.uint8)).copy()   # libero convention
    wtc_bev = get_camera_transform_matrix(env.sim, "agentview", IMG, IMG)
    bev_u, bev_v, bev_depth = _project_unclipped(eef_pos, wtc_bev)
    above_bev_u, above_bev_v, _ = _project_unclipped(eef_above, wtc_bev)
    in_bounds_bev = _in_view(bev_u, bev_v, bev_depth)

    # WRIST (moves with EEF — extrinsics fetched fresh per frame)
    wrist_img = np.flipud(obs["robot0_eye_in_hand_image"].astype(np.uint8)).copy()
    wtc_wrist = get_camera_transform_matrix(env.sim, "robot0_eye_in_hand", IMG, IMG)
    wrist_u, wrist_v, wrist_depth = _project_unclipped(eef_pos, wtc_wrist)
    above_wrist_u, above_wrist_v, _ = _project_unclipped(eef_above, wtc_wrist)
    in_bounds_wrist = _in_view(wrist_u, wrist_v, wrist_depth)

    # Annotate BEV
    bev_draw = bev_img.copy()
    bev_pt = _px(bev_u, bev_v)
    cv2.circle(bev_draw, bev_pt, 8, (255, 30, 30), -1)
    cv2.circle(bev_draw, bev_pt, 16, (255, 255, 255), 2)
    cv2.line(bev_draw, bev_pt, _px(above_bev_u, above_bev_v), (30, 255, 30), 2)
    cv2.putText(bev_draw, f"step {target_t}  eef={eef_pos.round(3).tolist()}",
                (8, 24), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (255, 255, 255), 1)

    # Annotate wrist: ring/text colour is green when the EEF is really visible,
    # red when it is off-frame or behind the camera (marker then pinned to the edge).
    wrist_draw = wrist_img.copy()
    color = (30, 255, 30) if in_bounds_wrist else (255, 30, 30)
    wrist_pt = _px(wrist_u, wrist_v)
    cv2.circle(wrist_draw, wrist_pt, 8, (255, 30, 30), -1)
    cv2.circle(wrist_draw, wrist_pt, 16, color, 2)
    cv2.line(wrist_draw, wrist_pt, _px(above_wrist_u, above_wrist_v), (30, 255, 30), 2)
    if wrist_depth <= 0:
        bounds_msg = "EEF behind camera"
    elif in_bounds_wrist:
        bounds_msg = "EEF in-bounds"
    else:
        bounds_msg = "EEF OUT-of-bounds"
    cv2.putText(wrist_draw, f"step {target_t}  {bounds_msg}",
                (8, 24), cv2.FONT_HERSHEY_SIMPLEX, 0.45, color, 1)

    print(f"  step {target_t}: eef=({eef_pos[0]:+.3f}, {eef_pos[1]:+.3f}, {eef_pos[2]:+.3f}) "
          f"BEV pix=({bev_u:.1f}, {bev_v:.1f}) {'in' if in_bounds_bev else 'OUT'} "
          f"wrist pix=({wrist_u:.1f}, {wrist_v:.1f}) depth={wrist_depth:+.3f} "
          f"{'in' if in_bounds_wrist else 'OUT'}")

    panels.append((target_t, bev_draw, wrist_draw))

# Compose one row per sampled frame x 2 cols (agentview | wrist).
# Rows come from the panels actually produced, not a separately maintained count.
# squeeze=False keeps `axes` 2-D even for a single row, so axes[r, c] always works.
if not panels:
    raise RuntimeError("no panels were rendered; nothing to plot")
n_rows = len(panels)
fig, axes = plt.subplots(n_rows, 2, figsize=(7.2, 3.4 * n_rows), squeeze=False)
for r, (t, bev, wrist) in enumerate(panels):
    for c, (img, cam_name) in enumerate(((bev, "agentview"), (wrist, "robot0_eye_in_hand"))):
        ax = axes[r, c]
        ax.imshow(img)
        ax.set_xticks([]); ax.set_yticks([])
        ax.set_title(f"{cam_name} — step {t}", fontsize=9)
fig.suptitle("Wrist-camera extrinsics verification: red dot = EEF world→camera projection. "
             "Green segment = EEF + 5cm Z.", fontsize=10)
fig.tight_layout()
fig.savefig(OUT, dpi=150, bbox_inches='tight', facecolor='white')
plt.close(fig)  # release the figure's memory once it is on disk
print(f"\n✓ Saved {OUT}")
