#!/usr/bin/env python3
"""Render the Panda + hand-eye ArUco board from a fixed camera viewpoint.

Usage:
  python hand_eye_calib/render_test.py
"""
import sys
import os

sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))

import cv2
import mujoco
import numpy as np
import matplotlib.pyplot as plt

from panda_exo_handeye import PANDA_HANDEYE_CONFIG, handeye_board, aruco_dict, BOARD_WIDTH, BOARD_HEIGHT
from exo_utils import position_exoskeleton_meshes, get_link_poses_from_robot, do_est_aruco_pose

# Number of actuated arm joints on the Franka Panda (excludes the gripper fingers).
N_ARM_JOINTS = 7
# Fully-open position for each gripper finger joint, in meters.
GRIPPER_POS_MAX = 0.04

# Offscreen render resolution (width x height) in pixels; also determines the
# principal point and focal length of the synthetic camera intrinsics below.
RENDER_W, RENDER_H = 1280, 960


def main() -> int:
    """Render the Panda at a set of calibration poses and test ArUco detection.

    Injects a fixed "handeye_cam" camera into the robot XML, drives the arm
    through a list of joint configurations, renders each one offscreen, runs
    ArUco board detection on the image, and shows all results in a matplotlib
    grid.

    Returns:
        0 always (process exit code).
    """
    robot_config = PANDA_HANDEYE_CONFIG

    # Inject a fixed camera into the XML. The pose/orientation were chosen so
    # the board on the hand stays in view across the calibration poses.
    cam_xml = '<camera name="handeye_cam" pos="0.298 -1.736 1.314" xyaxes="1.000 0.010 0.000 -0.005 0.496 0.868" fovy="60"/>'
    xml = robot_config.xml.replace("</worldbody>", f"  {cam_xml}\n  </worldbody>")

    model = mujoco.MjModel.from_xml_string(xml)
    data = mujoco.MjData(model)

    # Calibration poses: 17 diverse configurations covering shoulder, elbow, and wrist
    # First 3 are the original user poses; rest add J1-J3 diversity for better coverage
    CALIB_POSES = [
        # Original 3 poses (wrist-varied)
        [0.0, -0.785, 0.0, -2.356,  0.0,  1.571, 2.3],
        [0.0, -0.785, 0.0, -2.356, -0.8,  2.3,   2.3],
        [0.0, -0.785, 0.0, -2.2,   -0.1,  1.4,   2.7],
        # J1 variation (base rotation — swings hand left/right)
        [ 0.4, -0.785, 0.0, -2.356,  0.0,  1.571, 2.3],
        [-0.4, -0.785, 0.0, -2.356,  0.0,  1.571, 2.3],
        # J1 + wrist combos
        [ 0.3, -0.6,   0.0, -2.0,   -0.5,  1.8,   1.5],
        [-0.3, -1.0,   0.0, -2.5,    0.3,  1.2,   2.8],
        # J2 variation (shoulder pitch — raises/lowers elbow)
        [ 0.0, -0.4,   0.0, -2.0,    0.0,  1.571, 2.3],
        [ 0.0, -1.2,   0.0, -2.5,    0.0,  1.571, 2.3],
        # J3 variation (elbow rotation — twists upper arm)
        [ 0.0, -0.785, 0.5, -2.356,  0.0,  1.571, 2.3],
        [ 0.0, -0.785,-0.5, -2.356,  0.0,  1.571, 2.3],
        # Full-arm combos (all joints varied)
        [ 0.2, -0.6,   0.3, -1.8,   -0.4,  2.0,   1.8],
        [-0.2, -0.9,  -0.3, -2.6,    0.5,  1.0,   0.5],
        [ 0.5, -0.5,   0.2, -2.1,    0.8,  1.8,   0.8],
        [-0.5, -1.1,  -0.2, -2.4,   -0.6,  2.5,   1.2],
        [ 0.1, -0.7,   0.4, -1.9,    0.3,  0.8,   2.0],
        [-0.1, -0.85, -0.4, -2.7,   -0.3,  1.6,  -0.5],
    ]

    renderer = mujoco.Renderer(model, height=RENDER_H, width=RENDER_W)
    try:
        cam_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "handeye_cam")

        # Compute pinhole intrinsics matching the MuJoCo camera. MuJoCo
        # specifies a vertical FOV (fovy); with square pixels fx == fy and the
        # principal point sits at the image center.
        fovy_rad = np.radians(model.cam_fovy[cam_id])
        fy = (RENDER_H / 2.0) / np.tan(fovy_rad / 2.0)
        fx = fy
        cx, cy = RENDER_W / 2.0, RENDER_H / 2.0
        cam_K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float64)

        # Render all poses in a grid. squeeze=False guarantees a 2-D Axes
        # array regardless of grid shape, so ravel() is always safe.
        n = len(CALIB_POSES)
        cols = 5
        rows = (n + cols - 1) // cols
        fig, axes = plt.subplots(rows, cols, figsize=(20, 4 * rows), squeeze=False)
        axes = axes.ravel()

        detected_count = 0
        for i, qpos in enumerate(CALIB_POSES):
            # Set both qpos and ctrl so the visual pose matches what the
            # position controllers would hold; open the gripper fully.
            data.qpos[:N_ARM_JOINTS] = np.array(qpos)
            data.ctrl[:N_ARM_JOINTS] = np.array(qpos)
            data.qpos[N_ARM_JOINTS] = GRIPPER_POS_MAX
            data.qpos[N_ARM_JOINTS + 1] = GRIPPER_POS_MAX

            # Forward kinematics only (no stepping) so mesh poses are current,
            # then snap the exoskeleton visuals onto the computed link frames.
            mujoco.mj_forward(model, data)
            position_exoskeleton_meshes(
                robot_config, model, data, get_link_poses_from_robot(robot_config, model, data)
            )

            renderer.update_scene(data, camera=cam_id)
            img_rgb = renderer.render().copy()

            # Run ArUco detection (OpenCV expects BGR).
            img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
            result = do_est_aruco_pose(
                img_bgr, aruco_dict, handeye_board, BOARD_WIDTH,
                cameraMatrix=cam_K,
            )

            if result == -1:
                # Helper signals "no board found" with -1.
                title = f"Pose {i+1}: no detection"
                vis = img_rgb
            else:
                n_markers = len(result["corners"][0])
                detected_count += 1
                title = f"Pose {i+1}: {n_markers} markers"
                # pose_vis is BGR with corners + axes drawn
                vis = cv2.cvtColor(result["pose_vis"], cv2.COLOR_BGR2RGB)

            axes[i].imshow(vis)
            axes[i].set_title(title, fontsize=9)
            axes[i].axis("off")

        # Hide unused subplots
        for j in range(n, len(axes)):
            axes[j].axis("off")

        fig.suptitle(f"Hand-eye calibration: {detected_count}/{n} poses with ArUco detection", fontsize=14)
        plt.tight_layout()
        plt.show()
        print(f"\nDetected board in {detected_count}/{n} poses")
    finally:
        # Release the offscreen GL context even if rendering/plotting fails.
        renderer.close()

    return 0


if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return value as the exit code.
    sys.exit(main())
