"""Before/after diagnostic for the gripper rendering bug.

Renders one frame of cube_in_carton/0000 (default 200; pass a frame index as argv[1]) twice:
  - BEFORE: using the current puget yam_overlay_render.py as-is.
  - AFTER:  using a local fixed render_arm_mask that
              (a) sets BOTH gripper-slider qpos channels (joint7, joint8)
                  symmetrically (mj_forward doesn't enforce the MJCF equality
                  constraint that couples them), and
              (b) scales raiden's normalized gripper command (~0..1) to the
                  slider's physical range (~0..0.0475 m).
"""
import os
os.environ.setdefault("MUJOCO_GL", "egl")
os.environ.setdefault("PYOPENGL_PLATFORM", "egl")

import sys
import pickle
from pathlib import Path

import cv2
import mujoco
import numpy as np

sys.path.insert(0, "/home/robot-lab/cameron/yam_control")
from yam_overlay_render import (
    YAM_XML,
    build_xml_with_camera,
    fovy_from_K,
    render_arm_mask as render_arm_mask_BEFORE,  # buggy
    overlay_contours,
)


# Full travel of each gripper finger slider, in metres
# (from the i2rt MJCF: <joint range="0 0.0475" />).
GRIPPER_SLIDER_MAX_M: float = 0.0475
# raiden gripper command value treated as "fully open"
# (raiden's FOLLOWER_HOME_POS gripper = 1.0).
GRIPPER_CMD_MAX: float = 1.0


def render_arm_mask_FIXED(joints7, T_cam2base, K, W, H):
    """Render a single-arm depth image and binary mask with both gripper fingers posed.

    Same pipeline as ``yam_overlay_render.render_arm_mask``, but the gripper
    command is scaled into the slider's physical range and written to BOTH
    finger qpos channels. ``mj_forward`` only computes kinematics and does not
    enforce the MJCF equality constraint that normally couples the fingers.

    Args:
        joints7: Per-arm joint vector. Entries 0..5 are copied directly into
            qpos[0..5]. Entry 6, if present, is raiden's normalized gripper
            command, where GRIPPER_CMD_MAX means fully open.
        T_cam2base: Homogeneous camera-to-arm-base transform; only the top
            3x4 (rotation ``[:3, :3]``, translation ``[:3, 3]``) is read.
        K: Camera intrinsics, forwarded to ``fovy_from_K``.
        W: Output image width in pixels.
        H: Output image height in pixels.

    Returns:
        ``(mask, depth)``. ``depth`` is the rendered depth image. ``mask`` is
        uint8, 1 where depth is below 99.9% of the maximum rendered depth. The
        maximum is assumed to be the empty far-plane background, so mask = 1
        marks rendered geometry.

    Raises:
        ValueError: If the generated model has no camera named "posed".
    """
    cam_pos = T_cam2base[:3, 3]
    cam_rot = T_cam2base[:3, :3]
    fovy = fovy_from_K(K, H)
    xml = build_xml_with_camera(YAM_XML, cam_pos, cam_rot, fovy, W, H)
    model = mujoco.MjModel.from_xml_string(xml)
    data = mujoco.MjData(model)

    # Arm joints 0..5 -> qpos[0..5].
    arm_n = min(6, len(joints7))
    data.qpos[:arm_n] = joints7[:arm_n]

    # Gripper: normalize the command by its own full-open value FIRST, so the
    # mapping stays correct if GRIPPER_CMD_MAX is ever not 1.0. Then scale to
    # the slider travel and drive BOTH fingers (qpos[6], qpos[7]) symmetrically.
    if len(joints7) >= 7 and model.nq >= 8:
        open_frac = float(np.clip(joints7[6] / GRIPPER_CMD_MAX, 0.0, 1.0))
        slider_m = open_frac * GRIPPER_SLIDER_MAX_M
        data.qpos[6] = slider_m
        data.qpos[7] = slider_m
    mujoco.mj_forward(model, data)

    # mj_name2id returns -1 for a missing name. Passing that to update_scene
    # would silently render from the free camera, giving a wrong mask.
    cam_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_CAMERA, "posed")
    if cam_id < 0:
        raise ValueError('camera "posed" not found in generated MJCF')

    renderer = mujoco.Renderer(model, height=H, width=W)
    try:
        renderer.enable_depth_rendering()
        renderer.update_scene(data, camera=cam_id)
        depth = renderer.render().copy()
    finally:
        # Always release the GL context, even if rendering fails.
        renderer.close()

    far = float(depth.max())
    mask = (depth < far * 0.999).astype(np.uint8)
    return mask, depth


def render_bimanual_with(joints14, T_cam2world, T_lfr, K, W, H, fn):
    """Render left- and right-arm masks with the single-arm renderer ``fn``.

    ``joints14[:7]`` drives the left arm, which is rendered with
    ``T_cam2world`` unchanged. That pose is therefore treated as expressed in
    the left arm's base frame. ``joints14[7:14]`` drives the right arm. Its
    camera pose is re-expressed in the right base frame as
    ``inv(T_lfr) @ T_cam2world``.

    ``fn`` must return a 2-tuple ``(mask, depth)``; only the masks are kept.
    Returns ``(left_mask, right_mask)``.
    """
    left_mask, _left_depth = fn(joints14[:7], T_cam2world, K, W, H)
    T_right_from_left = np.linalg.inv(T_lfr)
    cam_in_right_base = T_right_from_left @ T_cam2world
    right_mask, _right_depth = fn(joints14[7:14], cam_in_right_base, K, W, H)
    return left_mask, right_mask


def main():
    """Render one cube_in_carton/0000 frame with the BEFORE and FIXED renderers and save overlays.

    The frame index comes from ``sys.argv[1]`` (default 200). Three PNGs are
    written to ``/home/robot-lab/cameron/yam_overlay/gripper_diag``:
    ``f<frame>_before.png``, ``f<frame>_after.png``, and a side-by-side
    ``f<frame>_before_after.png`` panel.

    Raises:
        FileNotFoundError: If the RGB frame cannot be read.
        OSError: If an output PNG cannot be written.
    """
    task = "cube_in_carton"
    ep = "0000"
    cam = "scene_camera"
    frame = int(sys.argv[1]) if len(sys.argv) > 1 else 200
    out_dir = Path("/home/robot-lab/cameron/yam_overlay/gripper_diag")
    out_dir.mkdir(parents=True, exist_ok=True)

    fstr = f"{frame:010d}"
    ep_dir = Path(f"/home/robot-lab/data/processed/{task}/{ep}")
    # NOTE: pickle.load can execute arbitrary code from the file; only point
    # this at trusted, locally produced data.
    with open(ep_dir / "lowdim" / f"{fstr}.pkl", "rb") as f:
        fd = pickle.load(f)
    joints = np.array(fd["joints"])
    T_lfr = np.array(fd["T_left_from_right"])
    K = np.array(fd["intrinsics"][cam])
    T_cam2world = np.array(fd["extrinsics"][cam])

    rgb_path = ep_dir / "rgb" / cam / f"{fstr}.png"
    rgb = cv2.imread(str(rgb_path))
    if rgb is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read RGB frame: {rgb_path}")
    H, W = rgb.shape[:2]
    print(f"frame {frame}  l_grip={joints[6]:.3f}  r_grip={joints[13]:.3f}")

    # Each overlay gets its own copy of the frame. If overlay_contours draws
    # into its input (not visible here; TODO confirm), sharing one buffer
    # would leak BEFORE contours into the AFTER image.
    ml_b, mr_b = render_bimanual_with(joints, T_cam2world, T_lfr, K, W, H,
                                      render_arm_mask_BEFORE)
    overlay_b = overlay_contours(rgb.copy(), ml_b, mr_b)

    ml_a, mr_a = render_bimanual_with(joints, T_cam2world, T_lfr, K, W, H,
                                      render_arm_mask_FIXED)
    overlay_a = overlay_contours(rgb.copy(), ml_a, mr_a)

    cv2.putText(overlay_b, f"BEFORE  f={frame}  l_grip={joints[6]:.2f}  r_grip={joints[13]:.2f}",
                (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
    cv2.putText(overlay_a, f"AFTER  f={frame}  (joint7=joint8, scaled)",
                (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2, cv2.LINE_AA)
    panel = np.concatenate([overlay_b, overlay_a], axis=1)
    # Divider between the BEFORE (left) and AFTER (right) halves.
    cv2.line(panel, (W, 0), (W, H), (40, 40, 40), 2)

    out_b = out_dir / f"f{frame:04d}_before.png"
    out_a = out_dir / f"f{frame:04d}_after.png"
    out_panel = out_dir / f"f{frame:04d}_before_after.png"
    outputs = ((out_b, overlay_b), (out_a, overlay_a), (out_panel, panel))
    for path, img in outputs:
        # cv2.imwrite reports failure via a False return, not an exception.
        if not cv2.imwrite(str(path), img):
            raise OSError(f"failed to write {path}")
    for path, _img in outputs:
        print(f"  saved {path}")


if __name__ == "__main__":
    # Script entry point: run the before/after gripper diagnostic.
    main()
