"""Unit test for the 2D->3D recovery used during model deployment.

For each frame in episode 0 of dataset_20260505_114857:
  1. q_motors[k] -> mj_forward -> ground-truth EEF position p_gt (world frame)
  2. Project p_gt through T_CAM_WORLD + K -> pixel (u, v)
  3. Take height = p_gt.z (world-z)
  4. recover_3d_from_direct_keypoint_and_height((u, v), height, c2w, K) -> p_rec
  5. Compare p_rec to p_gt

If the coordinate system is correct, continuous (un-rounded) recovery should
agree with the original to sub-micron precision (machine eps). We also
report the recovery error after rounding the pixel to integer at three
resolutions -- this is the inherent quantization the model would hit at its
argmax-on-64x64 heatmap output.
"""
from __future__ import annotations
import json
import sys
from pathlib import Path

import numpy as np
import mujoco

# Canonical solver lives in the libero training utils
sys.path.insert(0, "/data/cameron/para/libero")
from utils import recover_3d_from_direct_keypoint_and_height  # noqa: E402

# Root directory of the recorded dataset; main() reads meta.json, joints.npz
# and rgb_overlay/episodes.json from underneath it.
DATASET = Path("/data/cameron/mac_robot_datasets/dataset_20260505_114857")
# MuJoCo XML model that main() loads to run forward kinematics.
SMITH300_XML = "/home/cameronsmith/mnt/mac/smith300_para_stuff/example_twolink.xml"
# Name of the MuJoCo body whose world position serves as the ground-truth
# end-effector point.
EEF_BODY = "virtual_gripper_keypoint"
# Side length (pixels) of the square "PRED_SIZE" resolution case in main().
PRED_SIZE = 64
# Side length (pixels) of the square "trained-on" resolution case in main().
IMAGE_SIZE = 448


def project_world_to_pixel(p_world, T_cam_world, K):
    """Project a world-frame 3D point to pixel coordinates (pinhole model).

    Args:
        p_world: Length-3 point expressed in the world frame.
        T_cam_world: 4x4 homogeneous transform taking world coordinates to
            camera coordinates (only the top 3 rows are read).
        K: 3x3 intrinsics matrix; only fx, fy, cx and cy are used (any skew
            term is ignored).

    Returns:
        ``np.ndarray`` ``[u, v]`` of dtype float64, or ``None`` when the
        point has non-positive depth in the camera frame.
    """
    rotation = T_cam_world[:3, :3]
    translation = T_cam_world[:3, 3]
    in_camera = rotation @ p_world + translation

    depth = in_camera[2]
    if depth <= 0:
        # At or behind the image plane: no valid projection exists.
        return None

    fx, cx = K[0, 0], K[0, 2]
    fy, cy = K[1, 1], K[1, 2]
    pixel_u = fx * in_camera[0] / depth + cx
    pixel_v = fy * in_camera[1] / depth + cy
    return np.array((pixel_u, pixel_v), dtype=np.float64)


def main() -> int:
    """Run the project -> recover round trip over episode 0 and report errors.

    For every frame of the first episode listed in
    ``rgb_overlay/episodes.json`` the ground-truth EEF position is obtained
    from MuJoCo forward kinematics, projected to a pixel, and recovered again
    with ``recover_3d_from_direct_keypoint_and_height`` using the point's
    world-z as the height. This is repeated at three image resolutions, once
    with the continuous pixel and once with the pixel rounded to integers.

    Returns:
        0 when every resolution yielded at least one continuous sample and
        the worst continuous round-trip error is below 1 micron; 1 otherwise.
        The value is suitable for use as a process exit status.

    Raises:
        ValueError: if ``EEF_BODY`` is not a body of the loaded MuJoCo model.
    """
    # Continuous round trip should be < 1 micron (metres).
    cont_tol_m = 1e-6

    # Context managers so the JSON / NPZ file handles are closed promptly.
    with open(DATASET / "meta.json") as f:
        meta = json.load(f)
    with np.load(DATASET / "joints.npz") as joints:
        q_motors = np.asarray(joints["q_motors"], dtype=np.float64)

    # Intrinsics at the size given by meta["image_size_wh"]; rescaled below.
    K_native = np.array(meta["K"], dtype=np.float64)
    T_camera_arucoBase = np.array(meta["T_camera_arucoBase"], dtype=np.float64)
    T_W_baseBody = np.array(meta["T_W_baseBody_inv_aruco_offset"], dtype=np.float64)
    T_CAM_WORLD = T_camera_arucoBase @ T_W_baseBody
    camera_pose_c2w = np.linalg.inv(T_CAM_WORLD)
    img_w, img_h = meta["image_size_wh"]

    mj_model = mujoco.MjModel.from_xml_path(SMITH300_XML)
    mj_data = mujoco.MjData(mj_model)
    eef_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, EEF_BODY)
    if eef_id < 0:
        # mj_name2id signals "not found" with -1, which would otherwise
        # silently index the last body in xpos.
        raise ValueError(f"body {EEF_BODY!r} not found in {SMITH300_XML}")
    n_qpos = mj_model.nq

    with open(DATASET / "rgb_overlay" / "episodes.json") as f:
        eps = json.load(f)["episodes"]
    ep = eps[0]
    print(f"Episode 0 ({ep['id']}): frames {ep['start']}..{ep['end']} "
          f"= {ep['end']-ep['start']+1} frames")
    print(f"camera_pos in world (m): {camera_pose_c2w[:3, 3].round(3)}")
    print(f"image_size_wh: {img_w}x{img_h}")

    # Forward kinematics does not depend on the image resolution, so the
    # ground-truth EEF positions are computed once and reused for every case.
    n_copy = min(q_motors.shape[1], n_qpos)
    gt_positions = []
    for k in range(int(ep["start"]), int(ep["end"]) + 1):
        q = np.zeros(n_qpos)
        q[:n_copy] = q_motors[k, :n_copy]  # any remaining qpos entries stay 0
        mj_data.qpos[:n_qpos] = q
        mujoco.mj_forward(mj_model, mj_data)
        gt_positions.append(mj_data.xpos[eef_id].copy())

    # Three resolutions to exercise: native, model-input (448x448), and model
    # internal supervision (PRED_SIZE 64x64). Anisotropic K rescale matches
    # how the dataset/training scales K.
    cases = [
        (f"native     {img_w}x{img_h}", img_w, img_h),
        (f"trained-on {IMAGE_SIZE}x{IMAGE_SIZE}", IMAGE_SIZE, IMAGE_SIZE),
        (f"PRED_SIZE  {PRED_SIZE}x{PRED_SIZE}", PRED_SIZE, PRED_SIZE),
    ]

    failed_cases = []
    for name, W, H in cases:
        K = K_native.copy()
        K[0] *= W / img_w
        K[1] *= H / img_h
        print(f"\n=== {name} ===")
        print(f"K =\n{K}")

        cont_errs, round_errs = [], []
        out_of_frame = 0
        behind_camera = 0
        for p_gt in gt_positions:
            pix = project_world_to_pixel(p_gt, T_CAM_WORLD, K)
            if pix is None:
                behind_camera += 1
                continue
            u, v = pix
            if not (0 <= u < W and 0 <= v < H):
                # Counted for the warning below; the frame is still evaluated.
                out_of_frame += 1

            height = float(p_gt[2])

            # The solver result is skipped when it returns None.
            p_cont = recover_3d_from_direct_keypoint_and_height(
                (u, v), height, camera_pose_c2w, K)
            if p_cont is not None:
                cont_errs.append(float(np.linalg.norm(p_cont - p_gt)))

            ui, vi = int(round(u)), int(round(v))
            p_round = recover_3d_from_direct_keypoint_and_height(
                (ui, vi), height, camera_pose_c2w, K)
            if p_round is not None:
                round_errs.append(float(np.linalg.norm(p_round - p_gt)))

        label = " ".join(name.split())
        if cont_errs:
            worst = np.max(cont_errs)
            print(f"  continuous err  mean={np.mean(cont_errs)*1e9:7.3f} nm  "
                  f"max={np.max(cont_errs)*1e9:7.3f} nm  ({len(cont_errs)} samples)")
            # Written as "not <" so a NaN error also counts as a failure.
            if not worst < cont_tol_m:
                failed_cases.append(label)
        else:
            print("  WARN: no continuous samples were recovered; nothing verified")
            failed_cases.append(label)
        if round_errs:
            print(f"  rounded   err   mean={np.mean(round_errs)*1000:7.3f} mm  "
                  f"max={np.max(round_errs)*1000:7.3f} mm  ({len(round_errs)} samples)")
        if out_of_frame:
            print(f"  WARN: {out_of_frame} GT pixels were outside image bounds")
        if behind_camera:
            print(f"  WARN: {behind_camera} GT points were behind the camera and skipped")

    # Sanity threshold: continuous round-trip should be < 1 micron.
    print("\nVerdict: if continuous err is sub-micron at all 3 resolutions, "
          "the projection/unprojection inverse pair is consistent and "
          "recover_3d_from_direct_keypoint_and_height is the right solver "
          "for our smith300 coordinate system.")
    if failed_cases:
        print(f"FAIL: continuous round trip not verified below "
              f"{cont_tol_m * 1e6:g} micron for: {', '.join(failed_cases)}")
        return 1
    print(f"PASS: continuous err < {cont_tol_m * 1e6:g} micron at all "
          f"{len(cases)} resolutions")
    return 0


if __name__ == "__main__":
    # Use main()'s return value (if any) as the process exit status; a
    # return value of None maps to exit code 0.
    raise SystemExit(main())
