"""Verify IK + rotation convention by round-tripping action EE poses through mink.

For a sampled set of frames:
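  1. Load the left EE target from action[0:3] (position) + action[3:12] (row-major 3x3 rotation), in world
     (= left_arm_base) frame, and the right EE target from action[13:16] + action[16:25], in right_base frame.
  2. Run i2rt Kinematics.ik(target_pose, site_name="grasp_site", init_q=...), seeding init_q from the data
     joints (joints[:7] left, joints[7:14] right) when --cheat_init=1, or from zeros otherwise.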
  3. FK(recovered) → compare to target → "pose round-trip" error (translation + rotation).
  4. Render YAM silhouette mask twice: once with data joints (control), once with IK joints.
  5. Compute mask IoU per arm. High IoU + low round-trip error → IK + convention solid.
  6. Save side-by-side panels for visual inspection.

Usage:
  python verify_ik_yam.py \\
    --root_dir ~/yam_para/data/test_calib_pickup_cube \\
    --episode 0000 \\
    --frames 0,500,1000,1500,2000,2500,3000 \\
    --out_dir ~/yam_para/ik_verify
"""
import argparse
import os
os.environ.setdefault("MUJOCO_GL", "egl")

import json
import pickle
from pathlib import Path

import cv2
import numpy as np
from scipy.spatial.transform import Rotation as ScipyR

from i2rt.robots.kinematics import Kinematics
from raiden._xml_paths import get_yam_4310_linear_xml_path

import sys
sys.path.insert(0, os.path.dirname(__file__))
from yam_overlay_render import render_arm_mask  # rebuilds XML per call; ok for low-N test


# Default camera key: RGB frames, intrinsics and extrinsics are looked up under this name.
SCENE_CAM = "scene_camera"
# YAM robot XML path (provided by raiden), loaded by Kinematics for IK/FK.
YAM_XML = get_yam_4310_linear_xml_path()


def pose4x4(pos, rot9):
    """Assemble a homogeneous 4x4 transform from a translation and a flattened rotation.

    Args:
        pos: 3 translation components.
        rot9: 9 rotation-matrix entries, reshaped row-major into 3x3.

    Returns:
        A (4, 4) float64 array of the form [[R, t], [0, 0, 0, 1]].
    """
    rotation = np.reshape(np.asarray(rot9, dtype=np.float64), (3, 3))
    translation = np.asarray(pos, dtype=np.float64)
    transform = np.identity(4)
    transform[:3, :3] = rotation
    transform[:3, 3] = translation
    return transform


def pose_err(T_pred, T_target):
    """Return (translation error in metres, rotation error in degrees) between two 4x4 poses.

    The translation error is the Euclidean distance between the translation
    columns. The rotation error is the geodesic angle of R_pred @ R_target^T,
    i.e. arccos((trace - 1) / 2), with the cosine clipped to [-1, 1] so that
    round-off cannot push it outside arccos's domain.
    """
    delta_pos = T_pred[:3, 3] - T_target[:3, 3]
    trans_err = float(np.linalg.norm(delta_pos))
    rel_rot = np.matmul(T_pred[:3, :3], T_target[:3, :3].T)
    cos_angle = float(np.clip(0.5 * (np.trace(rel_rot) - 1), -1.0, 1.0))
    rot_err = float(np.rad2deg(np.arccos(cos_angle)))
    return trans_err, rot_err


def iou(a, b):
    """Intersection-over-union of two masks, each interpreted as boolean.

    Both arguments must be numpy arrays (``.astype`` is called on them).
    NOTE: when both masks are empty the union is clamped to 1, so the result
    is 0.0, not 1.0.
    """
    mask_a, mask_b = a.astype(bool), b.astype(bool)
    overlap = np.count_nonzero(np.logical_and(mask_a, mask_b))
    combined = np.count_nonzero(np.logical_or(mask_a, mask_b))
    return overlap / max(combined, 1)


def contour_overlay(rgb_bgr, mask_l, mask_r):
    """Draw the outer contours of the left/right masks on a copy of a BGR image.

    Left-mask contours are drawn in (0, 255, 0) (green) and right-mask contours
    in (0, 0, 255) (red in BGR order), 2 px thick. The input image is not modified.

    Args:
        rgb_bgr: HxWx3 image in OpenCV BGR channel order.
        mask_l, mask_r: HxW masks; every nonzero pixel is treated as foreground.

    Returns:
        A new image with both contour sets drawn.
    """
    out = rgb_bgr.copy()
    for mask, color in ((mask_l, (0, 255, 0)), (mask_r, (0, 0, 255))):
        # Binarize explicitly: a plain astype(np.uint8) would truncate fractional
        # float masks to 0 and wrap integer values >= 256.
        binary = (np.asarray(mask) != 0).astype(np.uint8)
        # findContours returns (contours, hierarchy) in OpenCV 4.x but
        # (image, contours, hierarchy) in 3.x; [-2] selects contours in both.
        contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        cv2.drawContours(out, contours, -1, color, 2)
    return out


def _parse_frames(spec):
    """Parse a comma-separated frame list (e.g. "0,500,1000") into ints.

    Whitespace around entries and empty entries (such as the one produced by a
    trailing comma) are ignored.
    """
    return [int(tok) for tok in spec.split(",") if tok.strip()]


def _make_panel(bgr, overlay_data, overlay_ik, frame):
    """Build the labelled side-by-side panel: raw | data-joints overlay | IK overlay.

    The three images are joined horizontally with 16 px grey (value 60)
    separators, so they must share height and have 3 channels.
    """
    H, W = bgr.shape[:2]
    sep = np.full((H, 16, 3), 60, dtype=np.uint8)
    panel = np.concatenate([bgr, sep, overlay_data, sep, overlay_ik], axis=1)
    white = (255, 255, 255)
    font = cv2.FONT_HERSHEY_SIMPLEX
    cv2.putText(panel, f"frame {frame}", (10, 30), font, 0.8, white, 2)
    cv2.putText(panel, "raw", (W // 2 - 30, H - 12), font, 0.7, white, 2)
    # x-offsets skip the preceding image width(s) plus the 16 px separators.
    cv2.putText(panel, "data-joints FK overlay", (W + 16 + 60, H - 12), font, 0.6, white, 2)
    cv2.putText(panel, "IK-recovered overlay", (2 * W + 32 + 60, H - 12), font, 0.6, white, 2)
    return panel


def main():
    """Round-trip per-frame EE action poses through IK/FK and compare arm renders.

    For every requested frame of ``<root_dir>/<episode>``:
      - load the lowdim pickle and the camera image;
      - solve IK for both arms' action targets;
      - report the FK round-trip translation/rotation error;
      - render arm masks from the data joints and from the IK joints and
        compute their IoU;
      - write a side-by-side PNG.
    Per-arm means are printed at the end, and the per-frame results are
    written to ``ik_verify_summary.json`` in ``--out_dir``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--root_dir", type=str,
                   default=str(Path.home() / "yam_para/data/test_calib_pickup_cube"))
    p.add_argument("--episode", type=str, default="0000")
    p.add_argument("--frames", type=str, default="0,500,1000,1500,2000,2500,3000")
    p.add_argument("--out_dir", type=str, default=str(Path.home() / "yam_para/ik_verify"))
    p.add_argument("--cam", type=str, default=SCENE_CAM)
    p.add_argument("--cheat_init", type=int, default=1,
                   help="1: init mink IK from data joints (round-trip sanity). 0: init from zero pose.")
    p.add_argument("--ik_iters", type=int, default=200)
    args = p.parse_args()

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    ep_dir = Path(args.root_dir) / args.episode

    kin = Kinematics(YAM_XML, site_name="grasp_site")
    # NOTE: reads a private attribute of i2rt's Kinematics; revisit if i2rt exposes nq publicly.
    model_nq = int(kin._configuration.model.nq)
    print(f"YAM model nq = {model_nq}")

    def _pad(q, nq=model_nq):
        """Return q as a fresh float64 vector of length nq (truncated or zero-padded)."""
        q = np.asarray(q, dtype=np.float64)
        if len(q) >= nq:
            return q[:nq].copy()
        out = np.zeros(nq, dtype=np.float64)
        out[: len(q)] = q
        return out

    n_data_per_arm = 7  # entries per arm in fd["joints"]: 6 arm + 1 gripper
    results = []
    for frame in _parse_frames(args.frames):
        fstr = f"{frame:010d}"
        pkl_path = ep_dir / "lowdim" / f"{fstr}.pkl"
        rgb_path = ep_dir / "rgb" / args.cam / f"{fstr}.png"
        if not pkl_path.exists() or not rgb_path.exists():
            print(f"[frame {frame}] missing pkl/rgb, skip")
            continue
        # pickle.load can execute arbitrary code: only point this at trusted recordings.
        with open(pkl_path, "rb") as f:
            fd = pickle.load(f)
        action = np.asarray(fd["action"])
        joints = np.asarray(fd["joints"], dtype=np.float64)  # (14,) — data ground-truth FK joints
        T_lfr = np.asarray(fd["T_left_from_right"], dtype=np.float64)
        K_intr = np.asarray(fd["intrinsics"][args.cam], dtype=np.float64)
        T_c2w = np.asarray(fd["extrinsics"][args.cam], dtype=np.float64)
        bgr = cv2.imread(str(rgb_path))
        if bgr is None:
            # cv2.imread signals a read/decode failure by returning None, not by raising.
            print(f"[frame {frame}] unreadable rgb {rgb_path}, skip")
            continue
        H, W = bgr.shape[:2]
        joints_L = joints[:n_data_per_arm]
        joints_R = joints[n_data_per_arm:n_data_per_arm * 2]

        # Target poses: position + row-major 3x3 rotation taken from the action vector.
        T_target_L = pose4x4(action[0:3], action[3:12])        # left in world (= left_arm_base)
        T_target_R_rb = pose4x4(action[13:16], action[16:25])  # right in right_arm_base

        # Init q: cheat from data joints (round-trip test) or zeros (honest).
        # NOTE(review): IK presumably leaves non-arm DOFs (e.g. gripper) at their init
        # values, so with --cheat_init=0 they stay at zero and may lower IoU on their own.
        if args.cheat_init:
            init_q_L, init_q_R = _pad(joints_L), _pad(joints_R)
        else:
            init_q_L, init_q_R = np.zeros(model_nq), np.zeros(model_nq)

        ok_L, q_L = kin.ik(T_target_L, site_name="grasp_site", init_q=init_q_L, max_iters=args.ik_iters)
        ok_R, q_R = kin.ik(T_target_R_rb, site_name="grasp_site", init_q=init_q_R, max_iters=args.ik_iters)
        q_L, q_R = _pad(q_L), _pad(q_R)

        # FK round-trip: each arm is solved and evaluated in its own base frame.
        dt_L, dr_L = pose_err(kin.fk(q_L, site_name="grasp_site"), T_target_L)
        dt_R, dr_R = pose_err(kin.fk(q_R, site_name="grasp_site"), T_target_R_rb)

        # Camera pose in the right arm's base frame, assuming T_left_from_right maps
        # right-base coordinates into left-base (= world) coordinates.
        T_c2rb = np.linalg.inv(T_lfr) @ T_c2w

        # Render with DATA joints (control), then with IK joints.
        mask_L_data, _ = render_arm_mask(joints_L, T_c2w, K_intr, W, H)
        mask_R_data, _ = render_arm_mask(joints_R, T_c2rb, K_intr, W, H)
        mask_L_ik, _ = render_arm_mask(q_L, T_c2w, K_intr, W, H)
        mask_R_ik, _ = render_arm_mask(q_R, T_c2rb, K_intr, W, H)

        iou_L = iou(mask_L_data, mask_L_ik)
        iou_R = iou(mask_R_data, mask_R_ik)

        panel = _make_panel(
            bgr,
            contour_overlay(bgr, mask_L_data, mask_R_data),
            contour_overlay(bgr, mask_L_ik, mask_R_ik),
            frame,
        )
        out_path = out_dir / f"verify_ik_f{frame:04d}.png"
        if not cv2.imwrite(str(out_path), panel):
            print(f"[frame {frame}] failed to write {out_path}")

        print(
            f"frame={frame}  L: ok={ok_L} pose_err=({dt_L*1000:.1f}mm,{dr_L:.2f}°) IoU={iou_L:.3f}   "
            f"R: ok={ok_R} pose_err=({dt_R*1000:.1f}mm,{dr_R:.2f}°) IoU={iou_R:.3f}   →  {out_path}"
        )
        results.append({
            "frame": frame,
            "L": {"ok": bool(ok_L), "trans_err_mm": dt_L*1000, "rot_err_deg": dr_L, "iou": iou_L},
            "R": {"ok": bool(ok_R), "trans_err_mm": dt_R*1000, "rot_err_deg": dr_R, "iou": iou_R},
        })

    if not results:
        return

    def agg(side, k):
        """Mean of metric k over all processed frames for arm side ("L" or "R")."""
        return float(np.mean([r[side][k] for r in results]))

    print("\n=== AGG ===")
    for side in ("L", "R"):
        print(f"  {side}: trans_err={agg(side,'trans_err_mm'):.1f}mm  "
              f"rot_err={agg(side,'rot_err_deg'):.2f}°  IoU={agg(side,'iou'):.3f}")
    summary_path = out_dir / "ik_verify_summary.json"
    with open(summary_path, "w") as f:
        json.dump(results, f, indent=2)
    print(f"Summary: {summary_path}")


if __name__ == "__main__":
    # Run the verification only when executed as a script, not when imported.
    main()
