"""Test IK recovery: GT 3D keypoint → IK joints → render alongside GT.

Verifies that if our model predicts the correct 3D point, we can recover
good joint states via IK with fixed rotation.

Usage:
  cd /data/cameron/para/panda_streaming
  MUJOCO_GL=egl python test_ik_recovery.py \
    --data_dir /data/cameron/panda_data/data_20260420_115853_632_frames \
    --output test_ik_results.png
"""
import os, sys, argparse, json, glob
import numpy as np
import cv2
import mujoco
from scipy.spatial.transform import Rotation as Rot

sys.path.insert(0, os.path.dirname(__file__))
from data_panda_para import T_CAM_WORLD, CAM_K, CAM_K_448, IMAGE_SIZE, IMG_W, IMG_H
from ExoConfigs.panda_exo_handeye_4x2 import PANDA_HANDEYE_4X2_CONFIG as robot_config
from exo_utils import render_from_camera_pose, get_link_poses_from_robot, position_exoskeleton_meshes

# Number of leading qpos entries that belong to the arm; this script writes
# the two finger joints at the indices immediately after them.
N_ARM_JOINTS = 7
# Factor applied to the recorded gripper value to obtain each finger joint
# position (presumably the maximum finger travel in meters; confirm against
# the robot XML).
GRIPPER_POS_MAX = 0.04

# Fixed EEF rotation (pointing down): x axis kept, y and z axes flipped,
# i.e. diag(1, -1, -1).
FIXED_EEF_ROT = np.diag(np.array([1.0, -1.0, -1.0], dtype=np.float64))


def mujoco_ik(model, data, target_pos, target_rot, q_init, eef_body_id,
              max_iter=200, pos_weight=1.0, rot_weight=0.3, damping=1e-4):
    """Solve for arm joint angles with damped least-squares IK.

    Starting from ``q_init``, the first ``N_ARM_JOINTS`` entries of
    ``data.qpos`` are updated iteratively so that the pose of body
    ``eef_body_id`` approaches the target. ``data`` is modified in place and
    is left at the final configuration on return.

    Args:
        model: MuJoCo model.
        data: MuJoCo data matching ``model``; its arm qpos is overwritten.
        target_pos: Desired world position of the body, length-3 array.
        target_rot: Desired world orientation of the body, 3x3 rotation matrix.
        q_init: Initial joint angles; only the first ``N_ARM_JOINTS`` are used.
        eef_body_id: MuJoCo body id whose pose is driven to the target.
        max_iter: Maximum number of update iterations.
        pos_weight: Weight on the position rows of the error and Jacobian.
        rot_weight: Weight on the rotation rows of the error and Jacobian.
        damping: Tikhonov damping added to ``J^T J`` before solving.

    Returns:
        A copy of the first ``N_ARM_JOINTS`` entries of ``data.qpos`` after
        the last iteration. No convergence flag is returned, so callers that
        care should re-check the achieved pose with forward kinematics.
    """
    n = N_ARM_JOINTS
    max_step = 0.1  # cap on the largest single-joint update per iteration

    data.qpos[:n] = np.asarray(q_init, dtype=np.float64)[:n]
    mujoco.mj_forward(model, data)

    # Loop-invariant buffers: mj_jacBody overwrites the Jacobians each call.
    jacp = np.zeros((3, model.nv))
    jacr = np.zeros((3, model.nv))
    damping_eye = damping * np.eye(n)

    for _ in range(max_iter):
        cur_pos = data.xpos[eef_body_id].copy()
        cur_rot = data.xmat[eef_body_id].reshape(3, 3).copy()

        pos_err = target_pos - cur_pos
        # Rotation taking the current orientation to the target, expressed as
        # an axis-angle (rotation) vector. Using scipy's log map instead of
        # angle / (2 sin(angle)) * vee(R - R^T) keeps the error well defined
        # at and near a 180 degree offset, where the antisymmetric part of
        # the matrix vanishes and the hand-rolled formula reports no error.
        rot_err = Rot.from_matrix(target_rot @ cur_rot.T).as_rotvec()

        if np.linalg.norm(pos_err) < 5e-4 and np.linalg.norm(rot_err) < 1e-3:
            break

        mujoco.mj_jacBody(model, data, jacp, jacr, eef_body_id)
        # Only the arm columns are solved for; other DoFs stay untouched.
        J = np.vstack([pos_weight * jacp[:, :n], rot_weight * jacr[:, :n]])
        err = np.concatenate([pos_weight * pos_err, rot_weight * rot_err])

        # Damped least squares: (J^T J + lambda I) dq = J^T e.
        dq = np.linalg.solve(J.T @ J + damping_eye, J.T @ err)

        # Uniformly rescale the step so no joint moves more than max_step.
        largest = np.max(np.abs(dq))
        if largest > max_step:
            dq *= max_step / largest

        data.qpos[:n] += dq
        mujoco.mj_forward(model, data)

    return data.qpos[:n].copy()


def project_world_to_pixel_fullres(pos_3d):
    """Map a world-frame 3D point to integer pixel coordinates (1920x1080).

    The point is moved into the camera frame with ``T_CAM_WORLD`` and then
    projected with the pinhole intrinsics ``CAM_K``.

    Returns:
        ``(u, v)`` rounded to the nearest pixel, or ``None`` when the point
        has non-positive depth in the camera frame. The result is not
        clipped to the image bounds.
    """
    world_to_cam_rot = T_CAM_WORLD[:3, :3]
    world_to_cam_trans = T_CAM_WORLD[:3, 3]
    cam_xyz = world_to_cam_rot @ pos_3d + world_to_cam_trans

    depth = cam_xyz[2]
    if depth <= 0:
        # At or behind the camera plane: no valid projection.
        return None

    fx, fy = CAM_K[0, 0], CAM_K[1, 1]
    cx, cy = CAM_K[0, 2], CAM_K[1, 2]
    col = fx * cam_xyz[0] / depth + cx
    row = fy * cam_xyz[1] / depth + cy
    return int(round(col)), int(round(row))


def _apply_joint_state(model, data, arm_q, finger_pos):
    """Write arm joints and both finger joints into ``data`` and run FK."""
    data.qpos[:N_ARM_JOINTS] = arm_q[:N_ARM_JOINTS]
    data.qpos[N_ARM_JOINTS] = data.qpos[N_ARM_JOINTS + 1] = finger_pos
    mujoco.mj_forward(model, data)


def _render_robot(model, data):
    """Render the current state with T_CAM_WORLD / CAM_K at 1080x1920."""
    position_exoskeleton_meshes(robot_config, model, data,
                                get_link_poses_from_robot(robot_config, model, data))
    return render_from_camera_pose(model, data, T_CAM_WORLD, CAM_K, 1080, 1920)


def _blend_render(rgb, render):
    """Blend rendered pixels (any channel > 10) 60/40 over a copy of ``rgb``."""
    mask = np.any(render > 10, axis=2)
    blended = rgb.copy()
    blended[mask] = (rgb[mask] * 0.4 + render[mask] * 0.6).astype(np.uint8)
    return blended


def _draw_keypoint(img, pt, label, color):
    """Draw a filled circle and a label at ``pt``; do nothing if it is None."""
    if pt is None:
        return
    cv2.circle(img, pt, 15, color, -1)
    cv2.putText(img, label, (pt[0] + 18, pt[1] - 5),
                cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2)


def main():
    """Run the IK recovery test and write a comparison grid image.

    For evenly spaced frames across all episodes in ``--data_dir``:
      1. run FK on the recorded joints to get the ground-truth hand position,
      2. run IK from a fixed home pose towards that position with
         ``FIXED_EEF_ROT`` as the target orientation,
      3. report the position error and render GT and IK configurations
         over the recorded image.
    Each frame contributes one row ``RGB | GT overlay | IK overlay`` to the
    image saved at ``--output``.

    Raises:
        SystemExit: if ``--n_samples`` < 1 or no frames are listed.
        FileNotFoundError: if a frame image cannot be read.
        OSError: if the output image cannot be written.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--data_dir", required=True)
    p.add_argument("--output", default="test_ik_results.png")
    p.add_argument("--n_samples", type=int, default=8, help="Number of test frames")
    args = p.parse_args()
    if args.n_samples < 1:
        p.error("--n_samples must be >= 1")

    # Load model
    model = mujoco.MjModel.from_xml_string(robot_config.xml)
    data = mujoco.MjData(model)
    hand_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, "hand")

    # Load episodes
    episodes_path = os.path.join(args.data_dir, "episodes.json")
    with open(episodes_path) as f:
        episodes = json.load(f)["episodes"]

    # Collect test frames evenly across all episodes ("end" is inclusive).
    all_frames = []
    for ep in episodes:
        all_frames.extend(range(ep["start"], ep["end"] + 1))
    if not all_frames:
        raise SystemExit(f"No frames listed in {episodes_path}")

    # Never request more samples than frames, which would repeat rows.
    n_samples = min(args.n_samples, len(all_frames))
    test_indices = np.linspace(0, len(all_frames) - 1, n_samples, dtype=int)
    test_frames = [all_frames[i] for i in test_indices]

    # Neutral home pose used as the IK seed (NOT GT joints) — realistic test.
    home_q = np.array([0.0, -0.785, 0.0, -2.356, 0.0, 1.571, 0.785])

    rows = []
    print(f"Testing IK recovery on {len(test_frames)} frames...")

    for frame_idx in test_frames:
        ts = f"{frame_idx:06d}"
        npy_path = os.path.join(args.data_dir, f"{ts}.npy")
        img_path = os.path.join(args.data_dir, f"{ts}.png")

        js_gt = np.load(npy_path).astype(np.float64)

        # cv2.imread signals failure by returning None instead of raising.
        bgr = cv2.imread(img_path)
        if bgr is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        # Entry after the arm joints is the gripper value (presumably a
        # normalized opening; confirm against the recorder). Default: 1.0.
        gw = js_gt[N_ARM_JOINTS] if len(js_gt) > N_ARM_JOINTS else 1.0
        finger_pos = gw * GRIPPER_POS_MAX

        # GT FK → EEF position
        _apply_joint_state(model, data, js_gt, finger_pos)
        gt_eef_pos = data.xpos[hand_id].copy()

        # IK: recover joints from GT 3D position + fixed rotation
        ik_q = mujoco_ik(model, data, gt_eef_pos, FIXED_EEF_ROT, home_q, hand_id)

        # Check IK result
        _apply_joint_state(model, data, ik_q, finger_pos)
        ik_eef_pos = data.xpos[hand_id].copy()
        pos_err_mm = np.linalg.norm(ik_eef_pos - gt_eef_pos) * 1000

        print(f"  Frame {ts}: IK pos error = {pos_err_mm:.1f}mm")

        # Render GT joints
        _apply_joint_state(model, data, js_gt, finger_pos)
        gt_render = _render_robot(model, data)

        # Render IK joints
        _apply_joint_state(model, data, ik_q, finger_pos)
        ik_render = _render_robot(model, data)

        # Build comparison: RGB | GT overlay | IK overlay
        H, W = rgb.shape[:2]
        gt_overlay = _blend_render(rgb, gt_render)
        ik_overlay = _blend_render(rgb, ik_render)

        # Draw EEF keypoints
        _draw_keypoint(gt_overlay, project_world_to_pixel_fullres(gt_eef_pos),
                       "GT EEF", (255, 255, 255))
        _draw_keypoint(ik_overlay, project_world_to_pixel_fullres(ik_eef_pos),
                       f"IK EEF ({pos_err_mm:.1f}mm)", (0, 255, 0))

        # Labels
        cv2.putText(gt_overlay, f"GT joints (frame {ts})", (20, 40),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2)
        cv2.putText(ik_overlay, f"IK fixed rot (frame {ts})", (20, 40),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 255, 0), 2)

        # Downsize for grid
        thumb_size = (W // 3, H // 3)
        rows.append(np.hstack([cv2.resize(img, thumb_size)
                               for img in (rgb, gt_overlay, ik_overlay)]))

    grid = np.vstack(rows)
    # cv2.imwrite returns False on failure; do not claim success in that case.
    if not cv2.imwrite(args.output, cv2.cvtColor(grid, cv2.COLOR_RGB2BGR)):
        raise OSError(f"Failed to write {args.output}")
    print(f"\nSaved {args.output} ({grid.shape[1]}x{grid.shape[0]})")


# Run the IK recovery test only when executed as a script, not on import.
if __name__ == "__main__":
    main()
