#!/usr/bin/env python3
"""Hand-eye calibration: recover camera intrinsics and extrinsics from
ArUco-board detections on the Panda end-effector and compare them to the
MuJoCo ground truth (GT).

Usage:
  python hand_eye_calib/calibrate.py
"""
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))

import cv2
import mujoco
import numpy as np
from scipy.spatial.transform import Rotation as Rot
from scipy.optimize import minimize
import matplotlib.pyplot as plt

from panda_exo_handeye import (
    PANDA_HANDEYE_CONFIG, handeye_board, aruco_dict,
    BOARD_WIDTH, BOARD_HEIGHT,
)
from exo_utils import position_exoskeleton_meshes, get_link_poses_from_robot

# Number of leading qpos/ctrl entries that belong to the arm.
N_ARM_JOINTS = 7
# Written to the two qpos entries right after the arm joints
# (presumably the gripper fingers — confirm against the model XML).
GRIPPER_POS_MAX = 0.04

# Offscreen render resolution in pixels.
RENDER_W = 1280
RENDER_H = 960

# Arm joint configurations (7 values each, written to data.qpos[:N_ARM_JOINTS]).
# Rows 1+ perturb subsets of the joints of row 0 to give the solver a spread
# of hand poses.
CALIB_POSES = [
    [ 0.0, -0.785,  0.0, -2.356,  0.0, 1.571,  2.3],
    [ 0.0, -0.785,  0.0, -2.356, -0.8, 2.3,    2.3],
    [ 0.0, -0.785,  0.0, -2.2,   -0.1, 1.4,    2.7],
    [ 0.4, -0.785,  0.0, -2.356,  0.0, 1.571,  2.3],
    [-0.4, -0.785,  0.0, -2.356,  0.0, 1.571,  2.3],
    [ 0.3, -0.6,    0.0, -2.0,   -0.5, 1.8,    1.5],
    [-0.3, -1.0,    0.0, -2.5,    0.3, 1.2,    2.8],
    [ 0.0, -0.4,    0.0, -2.0,    0.0, 1.571,  2.3],
    [ 0.0, -1.2,    0.0, -2.5,    0.0, 1.571,  2.3],
    [ 0.0, -0.785,  0.5, -2.356,  0.0, 1.571,  2.3],
    [ 0.0, -0.785, -0.5, -2.356,  0.0, 1.571,  2.3],
    [ 0.2, -0.6,    0.3, -1.8,   -0.4, 2.0,    1.8],
    [-0.2, -0.9,   -0.3, -2.6,    0.5, 1.0,    0.5],
    [ 0.5, -0.5,    0.2, -2.1,    0.8, 1.8,    0.8],
    [-0.5, -1.1,   -0.2, -2.4,   -0.6, 2.5,    1.2],
    [ 0.1, -0.7,    0.4, -1.9,    0.3, 0.8,    2.0],
    [-0.1, -0.85,  -0.4, -2.7,   -0.3, 1.6,   -0.5],
]


def make_T(R_mat, t):
    """Assemble a 4x4 float64 homogeneous transform ``[[R, t], [0, 0, 0, 1]]``.

    Args:
        R_mat: 3x3 rotation block, copied into the upper-left corner.
        t: translation; any array-like holding 3 values (it is flattened,
            so shapes like (3,), (3, 1) or (1, 3) all work).

    Returns:
        A freshly allocated 4x4 ``np.float64`` array.
    """
    transform = np.identity(4, dtype=np.float64)
    transform[:3, :3] = R_mat
    transform[:3, 3] = np.ravel(np.asarray(t, dtype=np.float64))
    return transform


def _params_to_T(params):
    """Decode a 6-vector ``[rotvec(3), translation(3)]`` into a 4x4 transform.

    Inverse of ``_T_to_params``; used as the optimizer's parameterization of a
    rigid transform.
    """
    rotation = Rot.from_rotvec(params[:3]).as_matrix()
    out = np.identity(4, dtype=np.float64)
    out[:3, 3] = params[3:6]
    out[:3, :3] = rotation
    return out


def _T_to_params(T):
    """Encode a 4x4 rigid transform as a 6-vector ``[rotvec(3), translation(3)]``.

    Inverse of ``_params_to_T``.
    """
    rotvec = Rot.from_matrix(T[:3, :3]).as_rotvec()
    translation = T[:3, 3]
    return np.hstack((rotvec, translation))


def solve_hand_eye(T_wh_list, T_cb_list, n_restarts=15, seed=0):
    """Solve the eye-to-hand problem for T_cam_world and T_hand_board.

    No privileged information is needed: both unknowns are optimized jointly.
    With ``T_pred_i = T_cw @ T_wh_i @ T_hb`` the cost is

        sum_i ||t_pred_i - t_obs_i||^2 + 0.1 * ||R_pred_i^T R_obs_i - I||_F^2

    where ``T_obs_i = T_cb_list[i]``.

    Args:
        T_wh_list: sequence of 4x4 hand poses in the world frame (e.g. from FK).
        T_cb_list: sequence of 4x4 board poses in the camera frame (e.g. from
            PnP), same length and order as ``T_wh_list``.
        n_restarts: number of L-BFGS-B runs. The first starts from the
            closed-form TSAI initial guess; the rest start from Gaussian
            perturbations (sigma 0.3) of it.
        seed: seed for the perturbation RNG, so results are reproducible.

    Returns:
        ``(T_cam_world, T_hand_board)`` as 4x4 float64 arrays.

    Raises:
        ValueError: if the inputs are empty or differ in length, if
            ``n_restarts < 1``, or if no optimizer run reached a finite cost.
    """
    if len(T_wh_list) != len(T_cb_list):
        raise ValueError(
            f"T_wh_list and T_cb_list differ in length "
            f"({len(T_wh_list)} vs {len(T_cb_list)})")
    if len(T_wh_list) == 0:
        raise ValueError("solve_hand_eye needs at least one pose pair")
    if n_restarts < 1:
        raise ValueError("n_restarts must be >= 1")

    # Stack once so the cost can be evaluated with broadcasting instead of a
    # Python loop over frames: shapes (N, 4, 4).
    W = np.stack([np.asarray(T, dtype=np.float64) for T in T_wh_list])
    C = np.stack([np.asarray(T, dtype=np.float64) for T in T_cb_list])
    t_obs = C[:, :3, 3]
    R_obs = C[:, :3, :3]
    eye3 = np.eye(3)

    def cost(x):
        # (4,4) @ (N,4,4) @ (4,4) broadcasts to (N,4,4).
        T_pred = _params_to_T(x[:6]) @ W @ _params_to_T(x[6:12])
        pos_term = np.sum((T_pred[:, :3, 3] - t_obs) ** 2)
        dR = np.swapaxes(T_pred[:, :3, :3], 1, 2) @ R_obs
        rot_term = np.sum((dR - eye3) ** 2)
        return pos_term + 0.1 * rot_term

    # Initial guess from cv2.calibrateHandEye (TSAI). For an eye-to-hand
    # setup OpenCV expects base->gripper poses (inverse of T_wh) and then
    # returns cam->base, i.e. T_world_cam — NOT T_cam_world.
    R_base2grip, t_base2grip, R_board2cam, t_board2cam = [], [], [], []
    for T_wh, T_cb in zip(W, C):
        T_hw = np.linalg.inv(T_wh)
        R_base2grip.append(T_hw[:3, :3])
        t_base2grip.append(T_hw[:3, 3].reshape(3, 1))
        R_board2cam.append(T_cb[:3, :3])
        t_board2cam.append(T_cb[:3, 3].reshape(3, 1))

    try:
        R_wc0, t_wc0 = cv2.calibrateHandEye(
            R_base2grip, t_base2grip, R_board2cam, t_board2cam,
            method=cv2.CALIB_HAND_EYE_TSAI)
    except cv2.error:
        # e.g. too few poses for the closed-form solver; fall back to identity.
        x0 = np.zeros(12)
    else:
        T_wc0 = make_T(R_wc0.squeeze(), t_wc0.squeeze())
        T_cw0 = np.linalg.inv(T_wc0)
        # T_hand_board = T_hand_world @ T_world_cam @ T_cam_board (frame 0).
        T_hb0 = np.linalg.inv(W[0]) @ T_wc0 @ C[0]
        x0 = np.concatenate([_T_to_params(T_cw0), _T_to_params(T_hb0)])

    rng = np.random.default_rng(seed)
    best = None
    for trial in range(n_restarts):
        x_init = x0 if trial == 0 else x0 + rng.standard_normal(12) * 0.3
        result = minimize(cost, x_init, method="L-BFGS-B", options={"maxiter": 2000})
        # Ignore NaN/inf results so they can never be selected as "best".
        if np.isfinite(result.fun) and (best is None or result.fun < best.fun):
            best = result

    if best is None:
        raise ValueError("hand-eye optimization did not reach a finite cost")

    return _params_to_T(best.x[:6]), _params_to_T(best.x[6:12])


# cv2.calibrateHandEye needs at least 3 poses (2 relative motions with
# non-parallel rotation axes) to determine the hand-eye transform.
_MIN_HANDEYE_VIEWS = 3


def _rotation_error_deg(R_a, R_b):
    """Return the angle, in degrees, of the relative rotation R_a @ R_b^-1."""
    return np.degrees((Rot.from_matrix(R_a) * Rot.from_matrix(R_b).inv()).magnitude())


def _apply_pose(robot_config, model, data, qpos):
    """Put the arm at ``qpos``, run forward kinematics and place the exo meshes.

    Writes ``qpos`` to the first N_ARM_JOINTS entries of both ``data.qpos`` and
    ``data.ctrl``, and sets the two following qpos entries (presumably the
    gripper fingers) to GRIPPER_POS_MAX. It then calls ``mj_forward`` and
    repositions the exoskeleton meshes from the resulting link poses.
    """
    data.qpos[:N_ARM_JOINTS] = np.array(qpos)
    data.ctrl[:N_ARM_JOINTS] = np.array(qpos)
    data.qpos[N_ARM_JOINTS] = data.qpos[N_ARM_JOINTS + 1] = GRIPPER_POS_MAX
    mujoco.mj_forward(model, data)
    position_exoskeleton_meshes(
        robot_config, model, data,
        get_link_poses_from_robot(robot_config, model, data),
    )


def main() -> int:
    """Run the simulated hand-eye calibration and show a re-render comparison.

    1. Inject a fixed camera into the MuJoCo scene, render every pose in
       CALIB_POSES and detect the ArUco board.
    2. Estimate intrinsics with cv2.calibrateCamera and compare them to GT K.
    3. For GT K and the estimated K, run PnP per frame, then jointly solve for
       T_cam_world and T_hand_board; report errors against GT.
    4. Re-render the first detected pose from the GT and the estimated camera
       poses and plot them side by side.

    Returns:
        Process exit code: 0 on success, 1 if there are too few usable
        detections or PnP solutions to calibrate.
    """
    robot_config = PANDA_HANDEYE_CONFIG

    # Inject fixed camera
    cam_xml = ('<camera name="handeye_cam" pos="0.298 -1.736 1.314" '
               'xyaxes="1.000 0.010 0.000 -0.005 0.496 0.868" fovy="60"/>')
    xml = robot_config.xml.replace("</worldbody>", f"  {cam_xml}\n  </worldbody>")

    model = mujoco.MjModel.from_xml_string(xml)
    data = mujoco.MjData(model)

    renderer = mujoco.Renderer(model, height=RENDER_H, width=RENDER_W)
    cam_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "handeye_cam")
    hand_body_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, "hand")

    # ── GT intrinsics ────────────────────────────────────────────────────
    # fy from the vertical field of view; fx = fy and the principal point at
    # the image centre.
    fovy_rad = np.radians(model.cam_fovy[cam_id])
    fy_gt = (RENDER_H / 2.0) / np.tan(fovy_rad / 2.0)
    fx_gt = fy_gt
    K_gt = np.array([[fx_gt, 0, RENDER_W / 2.0],
                     [0, fy_gt, RENDER_H / 2.0],
                     [0, 0, 1]], dtype=np.float64)

    # ── GT extrinsics (need a forward pass to populate cam_xpos/cam_xmat) ─
    data.qpos[:N_ARM_JOINTS] = np.array(CALIB_POSES[0])
    mujoco.mj_forward(model, data)

    cam_pos_gt = data.cam_xpos[cam_id].copy()
    # cam_xmat columns = MuJoCo camera axes (x-right, y-up, z-back) in world
    R_world_cammj = data.cam_xmat[cam_id].reshape(3, 3).copy()
    # Convert to OpenCV camera convention (flip Y and Z)
    F = np.diag([1.0, -1.0, -1.0])
    R_cam_world_cv = F @ R_world_cammj.T
    t_cam_world_cv = -R_cam_world_cv @ cam_pos_gt
    T_cam_world_gt = make_T(R_cam_world_cv, t_cam_world_cv)

    def pose_err(T_cw):
        """Camera position error (mm) and rotation error (deg) vs the GT pose."""
        T_wc_ = np.linalg.inv(T_cw)
        pe = np.linalg.norm(T_wc_[:3, 3] - cam_pos_gt) * 1000
        re = _rotation_error_deg(T_cw[:3, :3], T_cam_world_gt[:3, :3])
        return pe, re

    print("=" * 60)
    print("GT camera pos (world):", cam_pos_gt)
    print("GT K:\n", K_gt)
    print("=" * 60)

    # ── GT Board-to-hand transform (for validation only) ──────────────────
    # This is the privileged sim information — we only use it to CHECK
    # how well the joint solver recovers T_hand_board.
    cfg = robot_config.links["hand_board"]
    R_hand_bc = Rot.from_euler("xyz", cfg.aruco_offset_rot).as_matrix()
    t_hand_bc = cfg.aruco_offset_pos / 1000.0  # mm → m
    T_hand_boardcenter = make_T(R_hand_bc, t_hand_bc)
    T_texflip = make_T(np.diag([1.0, -1.0, -1.0]), [0, 0, 0])
    T_bc_bo = make_T(np.eye(3), [-BOARD_WIDTH / 2, -BOARD_HEIGHT / 2, 0])
    T_hand_boardorigin_gt = T_hand_boardcenter @ T_texflip @ T_bc_bo

    # ── Collect detections ───────────────────────────────────────────────
    all_obj_pts = []
    all_img_pts = []
    all_T_world_hand = []
    detected_idx = []
    # Build the detector once instead of once per frame.
    detector = cv2.aruco.ArucoDetector(aruco_dict, cv2.aruco.DetectorParameters())

    for i, qpos in enumerate(CALIB_POSES):
        _apply_pose(robot_config, model, data, qpos)

        # FK → hand pose in world (MuJoCo quat is w,x,y,z; SciPy wants x,y,z,w)
        hp = data.xpos[hand_body_id].copy()
        hq = data.xquat[hand_body_id].copy()
        T_wh = make_T(Rot.from_quat(hq[[1, 2, 3, 0]]).as_matrix(), hp)

        # Render and detect
        renderer.update_scene(data, camera=cam_id)
        img_rgb = renderer.render()
        gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
        corners, ids, _ = detector.detectMarkers(gray)
        if ids is None or len(ids) < 3:
            continue
        obj_pts, img_pts = handeye_board.matchImagePoints(corners, ids)
        if obj_pts is None or obj_pts.size == 0:
            continue

        all_obj_pts.append(obj_pts.reshape(-1, 1, 3).astype(np.float32))
        all_img_pts.append(img_pts.reshape(-1, 1, 2).astype(np.float32))
        all_T_world_hand.append(T_wh)
        detected_idx.append(i)

    n_det = len(detected_idx)
    print(f"\nDetected board in {n_det}/{len(CALIB_POSES)} poses")
    if n_det < _MIN_HANDEYE_VIEWS:
        print(f"Need at least {_MIN_HANDEYE_VIEWS} detected poses for hand-eye "
              "calibration; aborting.", file=sys.stderr)
        return 1

    # ── Step 1: intrinsic calibration ────────────────────────────────────
    # NOTE: without CALIB_USE_INTRINSIC_GUESS, OpenCV only takes the fx/fy
    # ratio from K_init (CALIB_FIX_ASPECT_RATIO); focal length and principal
    # point are re-initialised from the data, so GT values do not leak into
    # the estimate.
    K_init = K_gt.copy()
    flags = (cv2.CALIB_FIX_ASPECT_RATIO
             | cv2.CALIB_ZERO_TANGENT_DIST
             | cv2.CALIB_FIX_K3)
    ret, K_est, dist_est, _, _ = cv2.calibrateCamera(
        all_obj_pts, all_img_pts, (RENDER_W, RENDER_H), K_init.copy(), None,
        flags=flags,
    )
    print(f"\n{'=== INTRINSICS ===':^60}")
    print(f"  GT:  fx={K_gt[0,0]:.2f}  fy={K_gt[1,1]:.2f}  "
          f"cx={K_gt[0,2]:.2f}  cy={K_gt[1,2]:.2f}")
    print(f"  Est: fx={K_est[0,0]:.2f}  fy={K_est[1,1]:.2f}  "
          f"cx={K_est[0,2]:.2f}  cy={K_est[1,2]:.2f}")
    print(f"  Δfx={abs(K_est[0,0]-K_gt[0,0]):.2f}  "
          f"Δfy={abs(K_est[1,1]-K_gt[1,1]):.2f}  "
          f"Δcx={abs(K_est[0,2]-K_gt[0,2]):.2f}  "
          f"Δcy={abs(K_est[1,2]-K_gt[1,2]):.2f}")
    print(f"  Distortion: {dist_est.flatten()}")
    print(f"  Reprojection RMS: {ret:.4f} px")

    # ── Step 2: joint hand-eye calibration (no privileged info) ──────────
    # For each intrinsic set, run PnP per frame to get T_cam_board,
    # then jointly solve for T_cam_world and T_hand_board.
    results = {}  # K_label -> (T_cam_world_est, T_hand_board_est, K_used)
    for K_label, K_use, dist_use in (("GT K", K_gt, None), ("Est K", K_est, dist_est)):
        print(f"\n  --- Joint hand-eye solver using {K_label} ---")

        # PnP per frame → T_cam_board. Frames where PnP fails are dropped
        # together with their hand pose so the two lists stay paired.
        T_wh_ok, T_cb_ok = [], []
        for obj_pts, img_pts, T_wh in zip(all_obj_pts, all_img_pts, all_T_world_hand):
            ok, rvec, tvec = cv2.solvePnP(obj_pts, img_pts, K_use, dist_use)
            if not ok:
                continue
            T_cb_ok.append(make_T(cv2.Rodrigues(rvec)[0], tvec.flatten()))
            T_wh_ok.append(T_wh)
        if len(T_cb_ok) < _MIN_HANDEYE_VIEWS:
            print(f"    Only {len(T_cb_ok)} successful PnP solves with {K_label}; "
                  "aborting.", file=sys.stderr)
            return 1

        # Joint optimization: solve for T_cam_world and T_hand_board
        # without knowing either one a priori
        T_cw_est, T_hb_est = solve_hand_eye(T_wh_ok, T_cb_ok)
        results[K_label] = (T_cw_est, T_hb_est, K_use)

        T_wc = np.linalg.inv(T_cw_est)
        pos_err, rot_err = pose_err(T_cw_est)
        print(f"    GT  cam pos: {cam_pos_gt}")
        print(f"    Est cam pos: {T_wc[:3, 3]}")
        print(f"    Position error:    {pos_err:.1f} mm")
        print(f"    Rotation error:    {rot_err:.2f} deg")

        # Validate: compare solved T_hand_board to GT
        hb_pos_err = np.linalg.norm(T_hb_est[:3, 3] - T_hand_boardorigin_gt[:3, 3]) * 1000
        hb_rot_err = _rotation_error_deg(T_hb_est[:3, :3], T_hand_boardorigin_gt[:3, :3])
        print(f"    T_hand_board pos error: {hb_pos_err:.1f} mm")
        print(f"    T_hand_board rot error: {hb_rot_err:.2f} deg")

    # ── Step 3: re-render comparison ─────────────────────────────────────
    pose_idx = detected_idx[0]
    _apply_pose(robot_config, model, data, CALIB_POSES[pose_idx])

    def render_from_T_cam_world(T_cw, K_for_fov):
        """Set MuJoCo camera from a w2c transform (OpenCV convention) and render.

        Overwrites data.cam_xpos/cam_xmat and model.cam_fovy for ``cam_id``;
        no mj_forward runs between the overwrite and update_scene here, so
        the overwritten pose is what gets rendered. model.cam_fovy keeps the
        last value set.
        """
        T_wc_ = np.linalg.inv(T_cw)
        data.cam_xpos[cam_id] = T_wc_[:3, 3]
        # Inverse of R_cam_world_cv = F @ R_world_cammj.T (F is its own inverse).
        R_world_cammj_ = (F @ T_cw[:3, :3]).T
        data.cam_xmat[cam_id] = R_world_cammj_.reshape(-1)
        model.cam_fovy[cam_id] = np.degrees(2 * np.arctan(RENDER_H / (2 * K_for_fov[1, 1])))
        renderer.update_scene(data, camera=cam_id)
        return renderer.render().copy()

    img_gt = render_from_T_cam_world(T_cam_world_gt, K_gt)

    T_cw_gtk, _, _ = results["GT K"]
    T_cw_estk, _, K_estk = results["Est K"]
    img_gtk = render_from_T_cam_world(T_cw_gtk, K_gt)
    img_estk = render_from_T_cam_world(T_cw_estk, K_estk)

    pe_gtk, re_gtk = pose_err(T_cw_gtk)
    pe_estk, re_estk = pose_err(T_cw_estk)

    fig, axes = plt.subplots(1, 4, figsize=(24, 6))

    axes[0].imshow(img_gt)
    axes[0].set_title("GT camera pose + GT K")
    axes[0].axis("off")

    axes[1].imshow(img_gtk)
    axes[1].set_title(f"Joint solver (GT K)\npos {pe_gtk:.1f}mm, rot {re_gtk:.2f}°")
    axes[1].axis("off")

    axes[2].imshow(img_estk)
    axes[2].set_title(f"Joint solver (Est K)\npos {pe_estk:.1f}mm, rot {re_estk:.2f}°")
    axes[2].axis("off")

    diff = np.abs(img_gt.astype(float) - img_gtk.astype(float)).mean(axis=2)
    axes[3].imshow(diff, cmap="hot")
    axes[3].set_title("Diff: GT vs Joint solver (GT K)")
    axes[3].axis("off")

    fig.suptitle(f"Hand-eye calibration — joint solver, no privileged info (pose {pose_idx+1})",
                 fontsize=14)
    plt.tight_layout()
    plt.show()

    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
