"""Probe ArUco detection success across ZED resolutions for the YAM exo board.

Opens the scene camera (serial from raiden's camera.json) at each of HD720 /
HD1080 / HD2K, grabs N warm-up frames + one keep frame, saves the keep frame
to disk, and runs the same detection pipeline used by `rd exo_calibrate`.
Prints per-resolution: (W, H), K diag, marker count, board-corner residual.
"""
import os
os.environ.setdefault("MUJOCO_GL", "egl")

import sys, json, time
from pathlib import Path

import cv2
import numpy as np
import pyzed.sl as sl

sys.path.insert(0, "/home/robot-lab/cameron/raiden_fork/third_party/exo_redo")
from ExoConfigs.yam_exo import YAM_BASE_ONLY_CONFIG
from ExoConfigs.exoskeleton import link_to_aruco_transform
from exo_utils import do_est_aruco_pose, ARUCO_DICT


# Look up the scene camera in raiden's camera.json: the first entry whose
# "role" is "scene" supplies both the display name and the serial number.
_camera_json = os.path.expanduser("~/.config/raiden/camera.json")
with open(_camera_json) as fh:
    cam_cfg = json.load(fh)

_scene_entry = next(
    (
        (cam_name, cam_entry)
        for cam_name, cam_entry in cam_cfg.items()
        if cam_entry.get("role") == "scene"
    ),
    None,
)
if _scene_entry is None:
    sys.exit("No scene camera in camera.json")
scene_name = _scene_entry[0]
scene_serial = int(_scene_entry[1]["serial"])
print(f"Scene cam: {scene_name} (serial {scene_serial})\n")

# ArUco board mounted on the YAM "larger_base" link, and the fixed transform
# between that link and the board (plus its inverse, used to go from a
# board pose in camera coordinates to a link pose).
cfg = YAM_BASE_ONLY_CONFIG
link_cfg, aruco_board = (
    cfg.links["larger_base"],
    cfg.aruco_board_objects["larger_base"],
)
board_length = link_cfg.board_length
T_link_to_aruco = link_to_aruco_transform(link_cfg)
T_aruco_to_link = np.linalg.inv(T_link_to_aruco)

# Raw and annotated frames for every probed resolution are written here.
OUT_DIR = Path("/home/robot-lab/cameron/yam_overlay/aruco_probe")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# (label, ZED enum) pairs to probe; each label is also the enum member name.
resolutions = [
    (res_name, getattr(sl.RESOLUTION, res_name))
    for res_name in ("HD720", "HD1080", "HD2K")
]

def _grab_last_frame(cam, n_frames=15):
    """Grab up to ``n_frames`` frames and return the last good LEFT image as BGR.

    Every successful grab before the final one is effectively warm-up.
    A failed ``grab`` sleeps 50 ms before the next attempt. A failed
    ``retrieve_image`` is skipped rather than converting whatever stale
    data the Mat still holds. Returns ``None`` if no grab/retrieve
    succeeded.
    """
    image = sl.Mat()
    rt = sl.RuntimeParameters()
    bgr = None
    for _ in range(n_frames):
        if cam.grab(rt) != sl.ERROR_CODE.SUCCESS:
            time.sleep(0.05)
            continue
        if cam.retrieve_image(image, sl.VIEW.LEFT) != sl.ERROR_CODE.SUCCESS:
            continue
        # sl.Mat -> np uint8 BGRA (H, W, 4) -> BGR. cvtColor allocates a new
        # array, so `bgr` is not tied to the reused Mat buffer.
        bgr = cv2.cvtColor(image.get_data(), cv2.COLOR_BGRA2BGR)
    return bgr


def _mean_reproj_residual(obj_cam, img_pts, rvec, tvec, K, dist):
    """Mean pixel distance between detected image points and their reprojection.

    The points in ``obj_cam`` are mapped back through the inverse of
    (rvec, tvec) and then re-projected with the same pose and intrinsics.
    The result therefore measures how well that pose explains the
    detections.

    NOTE(review): this assumes ``obj_cam`` is expressed in the camera frame,
    which is what the inverse-transform step implies. Confirm against
    ``do_est_aruco_pose``.

    Both (N, 3) and (N, 1, 3) point layouts are accepted.
    """
    R = cv2.Rodrigues(rvec)[0]
    t = np.asarray(tvec, dtype=np.float64).reshape(3, 1)
    obj = np.asarray(obj_cam, dtype=np.float64).reshape(-1, 3)
    # A rotation matrix is orthonormal, so R.T is its inverse.
    obj_board = (R.T @ (obj.T - t)).T
    proj, _ = cv2.projectPoints(obj_board, rvec, tvec, K, dist)
    err = proj.reshape(-1, 2) - np.asarray(img_pts).reshape(-1, 2)
    return float(np.linalg.norm(err, axis=1).mean())


results = []
for label, res in resolutions:
    cam = sl.Camera()
    init = sl.InitParameters()
    init.set_from_serial_number(scene_serial)
    init.camera_resolution = res
    init.camera_fps = 15  # most-conservative across HD720/HD1080/HD2K
    init.depth_mode = sl.DEPTH_MODE.NONE
    init.coordinate_units = sl.UNIT.METER

    status = cam.open(init)
    if status != sl.ERROR_CODE.SUCCESS:
        print(f"[{label}] open failed: {status}")
        continue

    # After a successful open, always release the camera, even if a later
    # step raises. Otherwise the next resolution's open() would likely find
    # the device still held.
    try:
        # Pull factory intrinsics for the chosen resolution
        info = cam.get_camera_information()
        cam_conf = info.camera_configuration
        cal = cam_conf.calibration_parameters.left_cam
        W, H = cam_conf.resolution.width, cam_conf.resolution.height
        K = np.array([[cal.fx, 0, cal.cx], [0, cal.fy, cal.cy], [0, 0, 1]],
                     dtype=np.float64)
        dist = np.array([cal.disto[i] for i in range(5)], dtype=np.float64)
        print(f"=== {label}: {W}x{H},  K diag=({cal.fx:.1f}, {cal.fy:.1f}),  "
              f"cxcy=({cal.cx:.1f}, {cal.cy:.1f}),  k1={cal.disto[0]:+.4f} ===")

        # Warm up + grab
        bgr = _grab_last_frame(cam, n_frames=15)
        if bgr is None:
            print(f"[{label}] no frame after warm-up")
            continue  # the finally clause still closes the camera

        # Save raw frame (imwrite reports failure via its return value)
        raw_path = OUT_DIR / f"frame_{label}.png"
        if cv2.imwrite(str(raw_path), bgr):
            print(f"  saved raw → {raw_path}")
        else:
            print(f"  WARNING: failed to write {raw_path}")

        # Marker count (low-level, gives detection success irrespective of board pose)
        gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
        corners, ids, rejected = cv2.aruco.detectMarkers(
            gray, ARUCO_DICT, parameters=cv2.aruco.DetectorParameters()
        )
        n_markers = 0 if ids is None else len(ids)
        n_rejected = 0 if rejected is None else len(rejected)
        print(f"  cv2.aruco.detectMarkers → {n_markers} markers, {n_rejected} rejected")

        # Full pose estimation. On success the helper returns a dict; on
        # failure the original contract is -1, and an exception is treated
        # the same way.
        try:
            result = do_est_aruco_pose(bgr, ARUCO_DICT, aruco_board, board_length,
                                       cameraMatrix=K, distCoeffs=dist)
        except Exception as e:
            result = -1
            print(f"  do_est_aruco_pose error: {e}")

        # Only a dict result can be indexed below. Any other return value
        # (-1 or otherwise) is reported as a failure instead of crashing.
        pose_ok = isinstance(result, dict)
        if not pose_ok:
            print("  POSE: FAILED (no board pose)")
            vis = bgr.copy()
            cv2.putText(vis, f"{label}: aruco pose failed (markers={n_markers})",
                        (20, 50), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)
        else:
            # Board pose in camera -> link pose in camera -> camera pose in link/base
            T_aruco_in_cam = result["est_aruco_pose"]
            T_link_in_cam = T_aruco_in_cam @ T_aruco_to_link
            T_cam_in_base = np.linalg.inv(T_link_in_cam)
            t = T_cam_in_base[:3, 3]
            obj_cam, img_pts = result["obj_img_pts"]
            rvec, tvec = result["rtvec"]
            residual = _mean_reproj_residual(obj_cam, img_pts, rvec, tvec, K, dist)
            print(f"  POSE OK: cam@base xyz=({t[0]:+.3f},{t[1]:+.3f},{t[2]:+.3f})  "
                  f"mean reproj resid = {residual:.2f} px")
            vis = result["pose_vis"]

        annot_path = OUT_DIR / f"frame_{label}_annot.png"
        if cv2.imwrite(str(annot_path), vis):
            print(f"  saved annotated → {annot_path}\n")
        else:
            print(f"  WARNING: failed to write {annot_path}\n")

        results.append({"label": label, "W": W, "H": H, "n_markers": n_markers,
                        "pose_ok": pose_ok})
    finally:
        cam.close()

print("=" * 60)
print("Summary:")
for r in results:
    print(f"  {r['label']:>6s}  {r['W']}x{r['H']}  markers={r['n_markers']:3d}  "
          f"pose_ok={r['pose_ok']}")
print(f"\nView the saved frames at:")
print(f"  /browse/yam_remote/cameron/yam_overlay/aruco_probe/")
