"""Project LIBERO EEF keypoint into camera image for sanity checking.

This script uses only LIBERO demos + LIBERO env state replay:
- Loads one demo from a LIBERO HDF5 file
- Sets simulator state to a selected timestep
- Reads camera extrinsics from MuJoCo
- Builds camera intrinsics from MuJoCo camera fovy + image size
- Projects EEF (from simulator FK / observation) to pixel
- Saves an overlay image
"""

import argparse
import os
from typing import Tuple

import cv2
import h5py
import numpy as np

from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_transform_matrix,
    project_points_from_world_to_camera,
)
import robosuite.utils.transform_utils as T


def image_obs_key(camera_name: str) -> str:
    """Return the observation-dict key under which a camera's image is looked up.

    The key is the camera name followed by ``_image``. This single rule also
    covers ``robot0_eye_in_hand`` (-> ``robot0_eye_in_hand_image``), so no
    per-camera special case is needed.

    Args:
        camera_name: Camera name as passed to the environment.

    Returns:
        The image key, ``"<camera_name>_image"``.
    """
    return f"{camera_name}_image"


def _demo_sort_key(demo_name):
    """Sort key ordering ``demo_<n>`` names by integer ``n`` (non-numeric suffixes last)."""
    suffix = demo_name[len("demo_"):]
    if suffix.isdigit():
        return (0, int(suffix), demo_name)
    return (1, 0, demo_name)


def _ensure_parent_dir(path):
    """Create the parent directory of ``path`` if it has one (bare filenames are fine)."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def main():
    """Replay a LIBERO demo state and overlay the projected EEF on a camera image.

    Command-line flow:
      1. Resolve the benchmark/task and open its demonstration HDF5 file.
      2. Pick a demo (``--demo-id``, ordered by numeric ``demo_<n>`` suffix and
         clamped to the last demo) and a frame (``--frame-idx``, clamped to the
         valid range).
      3. Build an ``OffScreenRenderEnv`` with the requested camera, set the
         simulator to the stored state, and project ``robot0_eef_pos`` into
         the image.
      4. Write the overlay to ``--out-image`` and, if ``--out-video`` is given,
         an mp4 covering every stored state of the demo.

    Raises:
        ValueError: Unknown benchmark name.
        FileNotFoundError: Demonstration file is missing.
        RuntimeError: No demos / no states in the file, or the image or video
            could not be written.
        KeyError: The camera's image key is absent from the observation.
    """
    parser = argparse.ArgumentParser(description="LIBERO camera/EFF projection debug")
    parser.add_argument("--benchmark", type=str, default="libero_spatial")
    parser.add_argument("--task-id", type=int, default=0)
    parser.add_argument("--demo-id", type=int, default=0)
    parser.add_argument("--frame-idx", type=int, default=0)
    parser.add_argument("--camera", type=str, default="agentview")
    # NOTE(review): this default is a machine-specific absolute path; pass
    # --out-image explicitly on other machines.
    parser.add_argument("--out-image", type=str, default="/Users/cameronsmith/Projects/robotics_testing/LIBERO/out/libero_proj_debug.png")
    parser.add_argument("--out-video", type=str, default="", help="Optional output mp4 for full demo")
    parser.add_argument("--camera-height", type=int, default=256)
    parser.add_argument("--camera-width", type=int, default=256)
    parser.add_argument("--debug-panel", action="store_true", help="Write 3-panel convention debug image")
    parser.add_argument("--fps", type=int, default=15)
    args = parser.parse_args()

    bench_dict = benchmark.get_benchmark_dict()
    if args.benchmark not in bench_dict:
        raise ValueError(f"Unknown benchmark {args.benchmark}. options={list(bench_dict.keys())}")
    suite = bench_dict[args.benchmark]()
    task = suite.get_task(args.task_id)

    demo_file = os.path.join(get_libero_path("datasets"), suite.get_task_demonstration(args.task_id))
    if not os.path.isfile(demo_file):
        raise FileNotFoundError(f"Demo file not found: {demo_file}")

    with h5py.File(demo_file, "r") as f:
        # Sort numerically: a plain string sort would put demo_10 before demo_2,
        # making --demo-id select a different demo than its number suggests.
        demos = sorted(
            (k for k in f["data"].keys() if k.startswith("demo_")),
            key=_demo_sort_key,
        )
        if not demos:
            raise RuntimeError("No demo_* entries in HDF5")
        demo_key = demos[min(args.demo_id, len(demos) - 1)]
        # [()] reads the whole dataset into memory so it outlives the file handle.
        states = f[f"data/{demo_key}/states"][()]
        if states.shape[0] == 0:
            raise RuntimeError(f"{demo_key} has no states")
        idx = min(max(args.frame_idx, 0), states.shape[0] - 1)

    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
    env_args = {
        "bddl_file_name": bddl_file,
        "camera_heights": args.camera_height,
        "camera_widths": args.camera_width,
        "camera_names": [args.camera],
    }
    env = OffScreenRenderEnv(**env_args)

    def to_u8(x):
        """Return a contiguous uint8 copy of ``x``; non-uint8 input is clipped to [0, 1] and scaled by 255."""
        y = np.ascontiguousarray(x.copy())
        if y.dtype != np.uint8:
            y = (np.clip(y, 0, 1) * 255).astype(np.uint8)
        return np.ascontiguousarray(y)

    def render_projection_frame(frame_i: int):
        """Set the sim to ``states[frame_i]`` and return the annotated RGB overlay image."""
        obs = env.set_init_state(states[frame_i])
        sim = env.env.sim
        sim.forward()
        eef_pos = np.asarray(obs["robot0_eef_pos"], dtype=np.float64)
        img_key = image_obs_key(args.camera)
        if img_key not in obs:
            raise KeyError(f"Observation key {img_key} not present. keys={list(obs.keys())}")
        img_obs = np.asarray(obs[img_key]).copy()
        h, w = img_obs.shape[:2]

        world_to_camera = get_camera_transform_matrix(
            sim=sim,
            camera_name=args.camera,
            camera_height=h,
            camera_width=w,
        )

        def project(point):
            """Project one world-space 3D point; the result is read as (row, col)."""
            rc = project_points_from_world_to_camera(
                points=point.reshape(1, 3),
                world_to_camera_transform=world_to_camera,
                camera_height=h,
                camera_width=w,
            )[0]
            return int(rc[0]), int(rc[1])

        v, u = project(eef_pos)
        # Row index mirrored about the horizontal centre line, for checking the
        # vertical-flip convention of each image variant.
        vf = (h - 1) - v

        vis_obs = to_u8(img_obs)

        if args.debug_panel:
            # The direct sim render is only needed for the comparison panels.
            img_raw = sim.render(camera_name=args.camera, height=h, width=w)
            vis_raw = to_u8(img_raw)
            vis_raw_flip = np.ascontiguousarray(np.flipud(vis_raw.copy()))
            for vis, name in [(vis_obs, "obs"), (vis_raw, "raw"), (vis_raw_flip, "raw_flipud")]:
                if 0 <= u < w and 0 <= v < h:
                    cv2.circle(vis, (u, v), 6, (0, 255, 0), -1)
                    cv2.putText(vis, "proj", (u + 8, v - 8), cv2.FONT_HERSHEY_SIMPLEX, 0.42, (0, 255, 0), 1, cv2.LINE_AA)
                if 0 <= u < w and 0 <= vf < h:
                    cv2.circle(vis, (u, vf), 6, (255, 0, 255), 2)
                    cv2.putText(vis, "proj_yflip", (u + 8, vf + 12), cv2.FONT_HERSHEY_SIMPLEX, 0.36, (255, 0, 255), 1, cv2.LINE_AA)
                # White text under dark text gives an outlined, readable label.
                cv2.putText(vis, name, (10, 18), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2, cv2.LINE_AA)
                cv2.putText(vis, name, (10, 18), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (20, 20, 20), 1, cv2.LINE_AA)
            vis = np.concatenate([vis_obs, vis_raw, vis_raw_flip], axis=1)
        else:
            # Draw on the vertically flipped observation.
            # NOTE(review): assumes the observation image is stored upside down
            # relative to the projection's pixel convention — confirm with --debug-panel.
            vis = np.ascontiguousarray(np.flipud(vis_obs))
            if 0 <= u < w and 0 <= v < h:
                cv2.circle(vis, (u, v), 6, (0, 255, 0), -1)
                cv2.putText(vis, "eef proj", (u + 8, v - 8), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (0, 255, 0), 1, cv2.LINE_AA)

                # Project same XY at robot-base height plane (z = robot0_base z), then connect.
                base_body_name = "robot0_base"
                base_body_id = sim.model.body_name2id(base_body_name)
                base_z = float(sim.data.xpos[base_body_id][2]) if base_body_id >= 0 else 0.0
                eef_base_plane = eef_pos.copy()
                eef_base_plane[2] = base_z
                vg, ug = project(eef_base_plane)
                if 0 <= ug < w and 0 <= vg < h:
                    cv2.circle(vis, (ug, vg), 6, (0, 255, 255), 2)  # cyan ring (RGB)
                    cv2.putText(
                        vis,
                        f"base z={base_z:.3f}",
                        (ug + 8, vg + 12),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.4,
                        (0, 255, 255),
                        1,
                        cv2.LINE_AA,
                    )
                    cv2.line(vis, (u, v), (ug, vg), (255, 255, 0), 2, cv2.LINE_AA)  # yellow link (RGB)

                # Draw projected EEF local coordinate axes from quaternion.
                # robosuite uses quaternion convention (x, y, z, w).
                eef_quat = np.asarray(obs["robot0_eef_quat"], dtype=np.float64)
                eef_rot = T.quat2mat(eef_quat)
                axis_len = 0.08  # meters
                axes_world = {
                    "x": eef_pos + eef_rot[:, 0] * axis_len,
                    "y": eef_pos + eef_rot[:, 1] * axis_len,
                    "z": eef_pos + eef_rot[:, 2] * axis_len,
                }
                axis_colors = {"x": (255, 0, 0), "y": (0, 255, 0), "z": (0, 0, 255)}  # RGB
                for axis_name, endpoint in axes_world.items():
                    va, ua = project(endpoint)
                    if 0 <= ua < w and 0 <= va < h:
                        cv2.line(vis, (u, v), (ua, va), axis_colors[axis_name], 2, cv2.LINE_AA)
                        cv2.circle(vis, (ua, va), 3, axis_colors[axis_name], -1)
            else:
                cv2.putText(vis, f"proj out of frame: ({u},{v})", (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 180, 0), 1, cv2.LINE_AA)

        # Pinhole intrinsics from the camera's vertical field of view (degrees)
        # and the image size; shown in the footer for reference only.
        cam_id = sim.model.camera_name2id(args.camera)
        fovy = float(sim.model.cam_fovy[cam_id])
        focal = 0.5 * h / np.tan(fovy * np.pi / 360.0)
        fx, fy, cx, cy = focal, focal, w / 2.0, h / 2.0
        header = f"{args.benchmark} task={args.task_id} {demo_key} frame={frame_i} cam={args.camera} u={u} v={v} vf={vf}"
        cv2.putText(vis, header, (10, h - 24), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (255, 255, 255), 2, cv2.LINE_AA)
        cv2.putText(vis, header, (10, h - 24), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (20, 20, 20), 1, cv2.LINE_AA)
        intr = f"fx={fx:.1f} fy={fy:.1f} cx={cx:.1f} cy={cy:.1f}"
        cv2.putText(vis, intr, (10, h - 8), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (230, 230, 230), 1, cv2.LINE_AA)
        return vis

    out_path = args.out_image
    # Always close the environment, even if rendering or writing fails.
    try:
        env.seed(0)

        # Single-frame image output
        vis = render_projection_frame(idx)
        _ensure_parent_dir(out_path)
        # Overlays are drawn in RGB; OpenCV writes BGR. imwrite reports failure
        # through its return value, so check it instead of assuming success.
        if not cv2.imwrite(out_path, cv2.cvtColor(vis, cv2.COLOR_RGB2BGR)):
            raise RuntimeError(f"Failed to write image: {out_path}")

        # Optional full-episode video output
        if args.out_video:
            _ensure_parent_dir(args.out_video)
            h0, w0 = vis.shape[:2]
            writer = cv2.VideoWriter(
                args.out_video,
                cv2.VideoWriter_fourcc(*"mp4v"),
                float(args.fps),
                (w0, h0),
            )
            try:
                if not writer.isOpened():
                    raise RuntimeError(f"Failed to open video writer: {args.out_video}")
                for frame_i in range(states.shape[0]):
                    vis_i = render_projection_frame(frame_i)
                    # Every frame must match the size the writer was opened with.
                    if vis_i.shape[0] != h0 or vis_i.shape[1] != w0:
                        vis_i = cv2.resize(vis_i, (w0, h0), interpolation=cv2.INTER_LINEAR)
                    writer.write(cv2.cvtColor(vis_i, cv2.COLOR_RGB2BGR))
            finally:
                writer.release()
            print(f"Wrote video: {args.out_video}")
    finally:
        env.close()
    print(f"Wrote image: {out_path}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

