"""Reconstruct LIBERO robot trajectory from tuple representation and compare to GT.

Tuple per frame:
    (camera_pose, cam_K, eef_pixel_upright, eef_height_from_base, eef_quat_xyzw, gripper_amount)
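
Conventions used below:
    camera_pose: 4x4 camera-to-world extrinsic (robosuite CV convention).
    cam_K: 3x3 pinhole intrinsics at the render resolution.
    eef_pixel_upright: (u, v) eef pixel in the upright (vertically flipped) image.
    eef_height_from_base: eef z minus the robot base z, in meters.
    eef_quat_xyzw: eef orientation quaternion (x, y, z, w); the code below uses the
        equivalent site rotation matrix directly.
    gripper_amount: scalar finger opening (half the spread of the two finger joints).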

Pipeline:
1) Load one demo trajectory from LIBERO hdf5.
2) For each GT simulator state, extract the tuple above.
3) Reconstruct 3D eef target from (pixel + height + camera).
4) Solve IK (position + rotation) for robot arm joints.
5) Compose predicted state by replacing robot joints in GT state.
6) Render side-by-side GT vs predicted and save video.
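
Example invocation (script filename assumed):
    python reconstruct_libero_tuple.py --benchmark libero_spatial --task-id 0 \
        --demo-id 0 --camera agentview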
"""

import argparse
import os

import cv2
import h5py
import numpy as np

from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
    get_camera_transform_matrix,
    project_points_from_world_to_camera,
)


def to_upright_u8(img):
    """Convert any LIBERO obs image to upright uint8."""
    arr = np.asarray(img).copy()
    if arr.dtype != np.uint8:
        arr = (np.clip(arr, 0, 1) * 255).astype(np.uint8)
    return np.ascontiguousarray(np.flipud(arr))


def unproject_pixel_to_world_ray(u_raw: float, v_raw: float, camera_pose_c2w: np.ndarray, cam_k: np.ndarray):
    """Return camera origin and world ray direction for a pixel in raw image coords."""
    k_inv = np.linalg.inv(cam_k)
    ray_cam = k_inv @ np.array([u_raw, v_raw, 1.0], dtype=np.float64)
    ray_cam = ray_cam / max(np.linalg.norm(ray_cam), 1e-12)
    r_c2w = camera_pose_c2w[:3, :3]
    cam_pos = camera_pose_c2w[:3, 3].copy()
    ray_world = r_c2w @ ray_cam
    ray_world = ray_world / max(np.linalg.norm(ray_world), 1e-12)
    return cam_pos, ray_world


def recover_3d_from_upright_pixel_and_height(
    u_up: float,
    v_up: float,
    height_from_base: float,
    base_z: float,
    camera_pose_c2w: np.ndarray,
    cam_k: np.ndarray,
    img_h: int,
):
    """Recover 3D point from upright pixel + z-height (relative to robot base)."""
    # Convert upright v to raw render coordinates used in camera math.
    v_raw = (img_h - 1) - v_up
    u_raw = u_up

    target_z = base_z + height_from_base
    cam_pos, ray_world = unproject_pixel_to_world_ray(u_raw, v_raw, camera_pose_c2w, cam_k)
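    # Intersect the ray with the horizontal plane z == target_z: solve
    # cam_pos[2] + t * ray_world[2] == target_z. Reject rays (near-)parallel
    # to the plane and intersections behind the camera (t <= 0).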
    if abs(ray_world[2]) < 1e-9:
        return None
    t = (target_z - cam_pos[2]) / ray_world[2]
    if t <= 0:
        return None
    return cam_pos + t * ray_world


def rotmat_to_rotvec(r: np.ndarray):
    """Log-map for SO(3): returns axis-angle vector."""
    tr = np.trace(r)
    cos_theta = np.clip((tr - 1.0) * 0.5, -1.0, 1.0)
    theta = np.arccos(cos_theta)
    if theta < 1e-8:
        return np.zeros(3, dtype=np.float64)
    w = np.array(
        [
            r[2, 1] - r[1, 2],
            r[0, 2] - r[2, 0],
            r[1, 0] - r[0, 1],
        ],
        dtype=np.float64,
    ) / (2.0 * np.sin(theta))
    return w * theta


def solve_ik_pose_dls(
    sim,
    site_id: int,
    site_name: str,
    qpos_arm_idx,
    qvel_arm_idx,
    target_pos: np.ndarray,
    target_rot: np.ndarray,
    q_init: np.ndarray,
    n_iters: int = 60,
    damping: float = 1e-3,
    step_scale: float = 0.5,
):
    """Damped least-squares IK on eef site for position+orientation."""
    data = sim.data
    model = sim.model

    q = q_init.copy()
    for _ in range(n_iters):
        data.qpos[qpos_arm_idx] = q
        sim.forward()

        cur_pos = data.site_xpos[site_id].copy()
        cur_rot = data.site_xmat[site_id].reshape(3, 3).copy()

        pos_err = target_pos - cur_pos
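        # Orientation error as a world-frame axis-angle twist taking cur_rot onto target_rot.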
        r_err = target_rot @ cur_rot.T
        rot_err = rotmat_to_rotvec(r_err)

        err = np.concatenate([pos_err, rot_err], axis=0)
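        # Converged: position within 0.1 mm and rotation within ~0.11 deg.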
        if np.linalg.norm(err[:3]) < 1e-4 and np.linalg.norm(err[3:]) < 2e-3:
            break

        jacp = data.get_site_jacp(site_name).reshape(3, model.nv).astype(np.float64)
        jacr = data.get_site_jacr(site_name).reshape(3, model.nv).astype(np.float64)
        j = np.vstack([jacp[:, qvel_arm_idx], jacr[:, qvel_arm_idx]])  # (6, n_arm); 6x7 for the 7-DoF Panda

        # DLS: dq = J^T (J J^T + λI)^-1 e
        jj_t = j @ j.T
        dq = j.T @ np.linalg.solve(jj_t + damping * np.eye(6), err)
        q = q + step_scale * dq

    data.qpos[qpos_arm_idx] = q
    sim.forward()
    return q


def main():
    parser = argparse.ArgumentParser(description="Reconstruct LIBERO trajectory from tuple + IK")
    parser.add_argument("--benchmark", type=str, default="libero_spatial")
    parser.add_argument("--task-id", type=int, default=0)
    parser.add_argument("--demo-id", type=int, default=0)
    parser.add_argument("--camera", type=str, default="agentview")
    parser.add_argument("--camera-height", type=int, default=256)
    parser.add_argument("--camera-width", type=int, default=256)
    parser.add_argument(
        "--out-video",
        type=str,
        default="out/libero_tuple_reconstruct_demo0.mp4",
    )
    parser.add_argument("--fps", type=int, default=15)
    args = parser.parse_args()

    bench = benchmark.get_benchmark_dict()[args.benchmark]()
    task = bench.get_task(args.task_id)
    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(args.task_id))
    if not os.path.isfile(demo_path):
        raise FileNotFoundError(demo_path)

    with h5py.File(demo_path, "r") as f:
        demos = sorted([k for k in f["data"].keys() if k.startswith("demo_")])
        demo_key = demos[min(args.demo_id, len(demos) - 1)]
        states = f[f"data/{demo_key}/states"][()]
        actions = f[f"data/{demo_key}/actions"][()]
        print(f"Using {demo_key}: {states.shape[0]} states, {actions.shape[0]} actions")

    bddl_file = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)
    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file,
        camera_heights=args.camera_height,
        camera_widths=args.camera_width,
        camera_names=[args.camera],
    )
    env.seed(0)
    env.reset()

    robot = env.env.robots[0]
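    # qpos/qvel indices for the arm joints (7 for the Panda) and the gripper fingers.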
    qpos_arm_idx = np.array(robot._ref_joint_pos_indexes, dtype=np.int64)
    qvel_arm_idx = np.array(robot._ref_joint_vel_indexes, dtype=np.int64)
    qpos_grip_idx = np.array(robot._ref_gripper_joint_pos_indexes, dtype=np.int64)
    eef_site_id = int(robot.eef_site_id)
    eef_site_name = env.env.sim.model.site_id2name(eef_site_id)
    base_body_id = env.env.sim.model.body_name2id("robot0_base")

    # Initialize writer after first frame render.
    writer = None
    out_dir = os.path.dirname(args.out_video)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    q_pred = None
    pos_errs = []
    rot_errs = []

    for i in range(states.shape[0]):
        # ---------- GT state ----------
        obs_gt = env.set_init_state(states[i])
        env.env.sim.forward()
        img_key = f"{args.camera}_image" if args.camera != "robot0_eye_in_hand" else "robot0_eye_in_hand_image"
        img_gt_obs = np.asarray(obs_gt[img_key]).copy()
        h, w = img_gt_obs.shape[:2]
        img_gt = to_upright_u8(img_gt_obs)

        eef_pos_gt = np.asarray(obs_gt["robot0_eef_pos"], dtype=np.float64)
        # Use the eef site orientation directly; it is the same frame the Jacobian IK targets.
        eef_rot_gt = env.env.sim.data.site_xmat[eef_site_id].reshape(3, 3).copy()
        # Scalar gripper opening. The Panda's two finger joints are opposite-signed,
        # so a plain mean collapses to ~0; half the spread recovers the opening.
        grip_q = np.asarray(obs_gt["robot0_gripper_qpos"], dtype=np.float64)
        gripper_amount = float(0.5 * (grip_q[0] - grip_q[1])) if grip_q.size >= 2 else float(grip_q[0])
        base_z = float(env.env.sim.data.xpos[base_body_id][2])
        height_rel = float(eef_pos_gt[2] - base_z)

        # Camera tuple
        cam_pose_c2w = get_camera_extrinsic_matrix(env.env.sim, args.camera)
        cam_k = get_camera_intrinsic_matrix(env.env.sim, args.camera, h, w)
        world_to_camera = get_camera_transform_matrix(env.env.sim, args.camera, h, w)

        # Pixel tuple from GT eef projection (upright coordinates)
        pix_rc = project_points_from_world_to_camera(
            points=eef_pos_gt.reshape(1, 3),
            world_to_camera_transform=world_to_camera,
            camera_height=h,
            camera_width=w,
        )[0]
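        # robosuite returns pixels as (row, col) in raw (unflipped) render coordinates.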
        v_raw, u_raw = int(pix_rc[0]), int(pix_rc[1])
        v_up = (h - 1) - v_raw
        u_up = u_raw

        # ---------- Reconstruct from tuple ----------
        eef_pos_rec = recover_3d_from_upright_pixel_and_height(
            u_up=u_up,
            v_up=v_up,
            height_from_base=height_rel,
            base_z=base_z,
            camera_pose_c2w=cam_pose_c2w,
            cam_k=cam_k,
            img_h=h,
        )
        if eef_pos_rec is None:
            eef_pos_rec = eef_pos_gt.copy()

        # Seed IK from the previous frame's solution; on the first frame, fall back to the current (GT) arm pose.
        if q_pred is None:
            q_pred = env.env.sim.data.qpos[qpos_arm_idx].copy()

        q_pred = solve_ik_pose_dls(
            sim=env.env.sim,
            site_id=eef_site_id,
            site_name=eef_site_name,
            qpos_arm_idx=qpos_arm_idx,
            qvel_arm_idx=qvel_arm_idx,
            target_pos=eef_pos_rec,
            target_rot=eef_rot_gt,
            q_init=q_pred,
        )

        # Compose predicted state by modifying robot joints directly on top of GT state.
        env.set_init_state(states[i])
        env.env.sim.data.qpos[qpos_arm_idx] = q_pred
        # Map the scalar opening back to both finger joints (opposite signs for a symmetric gripper).
        if len(qpos_grip_idx) >= 2:
            env.env.sim.data.qpos[qpos_grip_idx[0]] = gripper_amount
            env.env.sim.data.qpos[qpos_grip_idx[1]] = -gripper_amount
        elif len(qpos_grip_idx) == 1:
            env.env.sim.data.qpos[qpos_grip_idx[0]] = gripper_amount
        env.env.sim.forward()
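        # Re-render observations from the composed state (robosuite-internal helper).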
        obs_pred = env.env._get_observations()
        img_pred_obs = np.asarray(obs_pred[img_key]).copy()
        img_pred = to_upright_u8(img_pred_obs)

        # Error metrics
        eef_pos_pred = np.asarray(obs_pred["robot0_eef_pos"], dtype=np.float64)
        eef_rot_pred = env.env.sim.data.site_xmat[eef_site_id].reshape(3, 3).copy()
        pos_err = float(np.linalg.norm(eef_pos_pred - eef_pos_gt))
        rot_err = float(np.linalg.norm(rotmat_to_rotvec(eef_rot_gt @ eef_rot_pred.T)))
        pos_errs.append(pos_err)
        rot_errs.append(rot_err)

        # Draw overlay text
        gt_vis = img_gt.copy()
        pred_vis = img_pred.copy()
        cv2.putText(gt_vis, "GT", (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2, cv2.LINE_AA)
        cv2.putText(gt_vis, "GT", (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (20, 20, 20), 1, cv2.LINE_AA)
        cv2.putText(pred_vis, "RECON (tuple->3d->IK)", (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (255, 255, 255), 2, cv2.LINE_AA)
        cv2.putText(pred_vis, "RECON (tuple->3d->IK)", (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.55, (20, 20, 20), 1, cv2.LINE_AA)
        cv2.putText(
            pred_vis,
            f"frame={i} pos_err={pos_err*1000:.1f}mm rot_err={np.rad2deg(rot_err):.2f}deg",
            (8, h - 8),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.45,
            (230, 230, 230),
            1,
            cv2.LINE_AA,
        )

        panel = np.concatenate([gt_vis, pred_vis], axis=1)

        if writer is None:
            hh, ww = panel.shape[:2]
            writer = cv2.VideoWriter(
                args.out_video,
                cv2.VideoWriter_fourcc(*"mp4v"),
                float(args.fps),
                (ww, hh),
            )
            if not writer.isOpened():
                raise RuntimeError(f"Failed opening video writer: {args.out_video}")

        writer.write(cv2.cvtColor(panel, cv2.COLOR_RGB2BGR))

    if writer is not None:
        writer.release()
    env.close()

    print(f"Wrote: {args.out_video}")
    if len(pos_errs) > 0:
        print(
            "Errors: "
            f"pos mean={np.mean(pos_errs)*1000:.2f}mm med={np.median(pos_errs)*1000:.2f}mm max={np.max(pos_errs)*1000:.2f}mm | "
            f"rot mean={np.rad2deg(np.mean(rot_errs)):.2f}deg med={np.rad2deg(np.median(rot_errs)):.2f}deg max={np.rad2deg(np.max(rot_errs)):.2f}deg"
        )


if __name__ == "__main__":
    main()

