"""Verify wrist camera extrinsics/intrinsics by projecting a grid onto ground plane.

For each frame:
1. Sample a grid of 2D points on the wrist camera image (red dots)
2. Unproject each to 3D on the ground plane (z = table height)
3. Reproject those 3D points onto the agentview camera (blue dots)
4. Save side-by-side: wrist (with red grid) | agentview (with blue reprojections)

Usage:
    python debug_wrist_projection.py --benchmark libero_spatial --task_id 0
"""

import argparse
import os
import sys

import cv2
import h5py
import numpy as np

sys.path.insert(0, os.path.dirname(__file__))

from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_transform_matrix,
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
    project_points_from_world_to_camera,
)


def unproject_pixel_to_ground(u, v, cam_extrinsic, cam_K, ground_z=0.8):
    """Cast a ray through pixel (u, v) and intersect it with the plane z=ground_z.

    Args:
        u, v: pixel coordinates (column, row)
        cam_extrinsic: (4, 4) camera-to-world transform
        cam_K: (3, 3) intrinsic matrix (unnormalized, in pixel units)
        ground_z: world z-height of the horizontal ground plane

    Returns:
        (3,) world point, or None when the ray is (near-)parallel to the
        plane or the intersection lies behind the camera.
    """
    # Back-project the homogeneous pixel into a unit-length ray in the
    # camera frame (solve K x = p instead of forming K^-1 explicitly).
    pixel_h = np.array([u, v, 1.0], dtype=np.float64)
    direction = np.linalg.solve(cam_K.astype(np.float64), pixel_h)
    direction /= max(np.linalg.norm(direction), 1e-12)

    # Rotate the ray into the world frame; the camera center is the
    # translation column of the camera-to-world transform.
    rotation = cam_extrinsic[:3, :3].astype(np.float64)
    origin = cam_extrinsic[:3, 3].astype(np.float64)
    direction = rotation @ direction
    direction /= max(np.linalg.norm(direction), 1e-12)

    # Solve origin_z + t * dir_z == ground_z for t; reject parallel rays
    # and intersections behind the camera (t < 0).
    if abs(direction[2]) < 1e-6:
        return None
    t = (ground_z - origin[2]) / direction[2]
    return origin + t * direction if t >= 0 else None


def make_grid_points(h, w, n_rows=8, n_cols=8, margin=20):
    """Return row-major (u, v) pixel coordinates of an n_rows x n_cols grid.

    Points are spaced evenly and inset `margin` pixels from each image edge.
    """
    cols = np.linspace(margin, w - margin, n_cols)
    rows = np.linspace(margin, h - margin, n_rows)
    return [(int(u), int(v)) for v in rows for u in cols]


def render_frame(env, sim, states, frame_idx, image_size, ground_z):
    """Render one frame with wrist grid → 3D → agentview reprojection.

    Args:
        env: LIBERO OffScreenRenderEnv; must expose set_init_state() and
            return obs dicts with "<camera>_image" arrays and "robot0_eef_pos".
        sim: underlying MuJoCo sim, used for camera intrinsics/extrinsics.
        states: (T, state_dim) array of recorded simulator states.
        frame_idx: index into `states` to restore and render.
        image_size: accepted for call symmetry but not read here; height and
            width are taken from the rendered wrist image instead.
        ground_z: world z-height of the plane the grid rays intersect.

    Returns:
        (H, 2*W, 3) image: annotated wrist view left, agentview right.
    """

    # Restore the recorded simulator state for this frame, then refresh
    # derived quantities (camera poses etc.) with a forward pass.
    obs = env.set_init_state(states[frame_idx])
    sim.forward()

    wrist_cam = "robot0_eye_in_hand"
    agent_cam = "agentview"

    # Get images. flipud() presumably compensates for the vertically flipped
    # raw OpenGL render — NOTE(review): confirm against the env's obs layout.
    wrist_img = np.flipud(obs[f"{wrist_cam}_image"].copy())
    agent_img = np.flipud(obs[f"{agent_cam}_image"].copy())
    h, w = wrist_img.shape[:2]

    # Camera params for wrist camera: camera-to-world pose + pixel-unit K.
    wrist_extrinsic = get_camera_extrinsic_matrix(sim, wrist_cam)
    wrist_K = get_camera_intrinsic_matrix(sim, wrist_cam, h, w)

    # Full world→pixel transform for agentview (used for reprojection).
    agent_w2c = get_camera_transform_matrix(sim, agent_cam, h, w)

    # Sample grid of pixel locations on the wrist image.
    grid_points = make_grid_points(h, w)

    # Draw red dots on wrist image, unproject to 3D, reproject to agentview.
    wrist_vis = wrist_img.copy()
    agent_vis = agent_img.copy()

    for u, v_flipped in grid_points:
        # Red dot on wrist (in flipped/display image space)
        cv2.circle(wrist_vis, (u, v_flipped), 3, (255, 0, 0), -1)

        # Convert flipped row → original row so the unprojection uses the
        # same row convention the intrinsics were computed for.
        v_original = h - 1 - v_flipped
        pt3d = unproject_pixel_to_ground(u, v_original, wrist_extrinsic, wrist_K, ground_z)
        if pt3d is None:
            continue  # ray parallel to ground or intersection behind camera

        # Reproject the 3D point onto the agentview camera; robosuite returns
        # (row, col) pixel indices.
        pix_rc = project_points_from_world_to_camera(
            points=pt3d.reshape(1, 3),
            world_to_camera_transform=agent_w2c,
            camera_height=h,
            camera_width=w,
        )[0]
        # NOTE(review): assumes project_points_from_world_to_camera returns
        # coords in raw obs space (pre-flipud), which lines up with the
        # flipud'd display image — verify against the robosuite version used.
        av = int(round(pix_rc[0]))  # row
        au = int(round(pix_rc[1]))  # col

        # Blue dot on agentview (if in bounds; NOTE(review): some robosuite
        # versions clip pixels to the image, making this check always pass)
        if 0 <= au < w and 0 <= av < h:
            cv2.circle(agent_vis, (au, av), 4, (0, 0, 255), -1)

    # Also mark EEF position on both views as a sanity anchor.
    eef_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)

    # EEF on wrist: project with the full world→pixel transform; indices
    # [1]/[0] unpack robosuite's (row, col) order into (u, v).
    wrist_w2c = get_camera_transform_matrix(sim, wrist_cam, h, w)
    eef_wrist_rc = project_points_from_world_to_camera(
        eef_pos.reshape(1, 3), wrist_w2c, h, w
    )[0]
    eu = int(round(eef_wrist_rc[1]))
    ev = int(round(eef_wrist_rc[0]))
    if 0 <= eu < w and 0 <= ev < h:
        cv2.drawMarker(wrist_vis, (eu, ev), (0, 255, 0), cv2.MARKER_CROSS, 15, 2)

    # EEF on agentview (same convention as above).
    eef_agent_rc = project_points_from_world_to_camera(
        eef_pos.reshape(1, 3), agent_w2c, h, w
    )[0]
    eu = int(round(eef_agent_rc[1]))
    ev = int(round(eef_agent_rc[0]))
    if 0 <= eu < w and 0 <= ev < h:
        cv2.drawMarker(agent_vis, (eu, ev), (0, 255, 0), cv2.MARKER_CROSS, 15, 2)

    # Frame-index labels in the top-left corner of each view.
    cv2.putText(wrist_vis, f"wrist f={frame_idx}", (5, 15),
                cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
    cv2.putText(agent_vis, f"agentview f={frame_idx}", (5, 15),
                cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)

    # Side by side: wrist | agentview.
    combined = np.concatenate([wrist_vis, agent_vis], axis=1)
    return combined


def main():
    """CLI entry point: render projection-debug stills (and optionally a video)
    for one recorded LIBERO demonstration.

    Side effects: reads the demo HDF5 file, creates `--out_dir`, writes PNGs
    (and an MP4 when `--video` is set).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--benchmark", type=str, default="libero_spatial")
    parser.add_argument("--task_id", type=int, default=0)
    parser.add_argument("--demo_id", type=int, default=0)
    parser.add_argument("--image_size", type=int, default=256)
    parser.add_argument("--ground_z", type=float, default=0.85,
                        help="Ground plane z-height for unprojection")
    parser.add_argument("--out_dir", type=str, default="out/wrist_projection_debug")
    parser.add_argument("--video", action="store_true", help="Save full episode video")
    args = parser.parse_args()

    bench = benchmark.get_benchmark_dict()[args.benchmark]()
    task = bench.get_task(args.task_id)

    # Load recorded sim states for the requested demo (demo_id clamped to range).
    demo_file = os.path.join(get_libero_path("datasets"),
                             bench.get_task_demonstration(args.task_id))
    with h5py.File(demo_file, "r") as f:
        demos = sorted([k for k in f["data"].keys() if k.startswith("demo_")])
        demo_key = demos[min(args.demo_id, len(demos) - 1)]
        states = f[f"data/{demo_key}/states"][()]

    T_total = states.shape[0]
    print(f"Task: {task.name}")
    print(f"Demo: {demo_key}, {T_total} frames")

    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file,
        camera_heights=args.image_size,
        camera_widths=args.image_size,
        camera_names=["agentview", "robot0_eye_in_hand"],
    )
    try:
        env.seed(0)
        sim = env.env.sim

        os.makedirs(args.out_dir, exist_ok=True)

        # Render first, middle, last frames as still images.
        for label, idx in [("first", 0), ("middle", T_total // 2), ("last", T_total - 1)]:
            print(f"Rendering {label} frame (idx={idx})...")
            frame = render_frame(env, sim, states, idx, args.image_size, args.ground_z)
            path = os.path.join(args.out_dir, f"wrist_proj_{label}.png")
            cv2.imwrite(path, cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
            print(f"  → {path}  ({frame.shape[1]}x{frame.shape[0]})")

        # Optionally save full video
        if args.video:
            video_path = os.path.join(args.out_dir, "wrist_projection_video.mp4")
            print(f"\nRendering full video ({T_total} frames)...")
            # Render frame 0 once to size the writer, then reuse it — the old
            # code rendered frame 0 twice (writer probe + loop), doubling the
            # cost of a full state restore + render for that frame.
            first_frame = render_frame(env, sim, states, 0, args.image_size, args.ground_z)
            h_out, w_out = first_frame.shape[:2]
            writer = cv2.VideoWriter(
                video_path,
                cv2.VideoWriter_fourcc(*"mp4v"),
                15, (w_out, h_out),
            )
            try:
                writer.write(cv2.cvtColor(first_frame, cv2.COLOR_RGB2BGR))
                for i in range(1, T_total):
                    frame = render_frame(env, sim, states, i, args.image_size, args.ground_z)
                    writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                    if (i + 1) % 10 == 0:
                        print(f"  {i+1}/{T_total}")
            finally:
                # Release even on error so the partial MP4 is finalized.
                writer.release()
            print(f"  → {video_path}")
    finally:
        # Always tear down the offscreen renderer, even if rendering failed.
        env.close()
    print("Done.")


if __name__ == "__main__":
    main()
