"""v5: project link_6 AND grasp_site (true gripper tip) for left arm only.
Add big gripper marker so we can see exactly where MJCF thinks the gripper is."""
import mujoco, numpy as np, pickle, cv2, sys

# Robot model and the processed episode directory this overlay reads from.
MJCF = "/home/robot-lab/raiden/third_party/i2rt/i2rt/robot_models/arm/yam/yam.xml"
DATA = "/home/robot-lab/data/processed/pickup_apple/0000"

# Command line: [frame index, default 500] [camera name, default "scene_1"].
_cli_args = sys.argv[1:]
FRAME = int(_cli_args[0]) if _cli_args else 500
CAM = _cli_args[1] if len(_cli_args) >= 2 else "scene_1"
OUT = f"/home/robot-lab/cameron/yam_overlay/out_v5_{CAM}_f{FRAME:04d}.png"

# Per-frame files (lowdim pickle, rgb jpg) are keyed by a 10-digit zero-padded index.
fstr = f"{FRAME:010d}"

# Build the MuJoCo model and its simulation state from the YAM MJCF.
model = mujoco.MjModel.from_xml_path(MJCF)
data = mujoco.MjData(model)

# Per-frame low-dimensional record; the script reads its "joints",
# "intrinsics" and "extrinsics" entries.
# NOTE(review): pickle.load can execute arbitrary code -- only point DATA at
# trusted recordings.
with open(f"{DATA}/lowdim/{fstr}.pkl", "rb") as f:
    fd = pickle.load(f)

site_names = [model.site(k).name for k in range(model.nsite)]
print(f"sites: {site_names}")

# Pose the model with the first six recorded joint values (left arm) and run
# forward kinematics; body (xpos) and site (site_xpos) world positions are read next.
data.qpos[:6] = fd["joints"][:6]
mujoco.mj_kinematics(model, data)

# Snapshot the world positions so later drawing code works on independent copies.
link_pos = np.array(data.xpos, copy=True)
site_pos = np.array(data.site_xpos, copy=True)

print(f"\nleft arm joints: {fd['joints'][:6]}")
print("bodies world:")
for body_idx, body_xyz in enumerate(link_pos):
    print(f"  {model.body(body_idx).name}: {body_xyz}")
print("sites world:")
for site_idx, site_xyz in enumerate(site_pos):
    print(f"  {model.site(site_idx).name}: {site_xyz}")

# Camera calibration for the chosen view. The extrinsic is treated as a
# camera->world pose (per its name) and inverted to map world -> camera.
# NOTE(review): assumes 4x4 homogeneous extrinsics -- confirm against the recorder.
if CAM not in fd["intrinsics"]:
    raise KeyError(
        f"camera {CAM!r} not found in record; available: {list(fd['intrinsics'])}"
    )
K = np.array(fd["intrinsics"][CAM])
T_cam2world = np.array(fd["extrinsics"][CAM])
T_world2cam = np.linalg.inv(T_cam2world)

img_path = f"{DATA}/rgb/{CAM}/{fstr}.jpg"
img = cv2.imread(img_path)
if img is None:
    # cv2.imread signals a missing/unreadable file by returning None instead of
    # raising; fail loudly here rather than with an AttributeError on .shape.
    raise FileNotFoundError(f"could not read image: {img_path}")
H, W = img.shape[:2]

def project_pts(pts_world, intrinsics=None, world2cam=None, min_depth=0.01):
    """Project world-frame 3-D points to pixel coordinates with a pinhole model.

    Args:
        pts_world: array-like of world-frame positions, shape (N, 3); a single
            (3,) point is also accepted and treated as N == 1.
        intrinsics: camera matrix with 3 columns (normally 3x3). Defaults to the
            module-level ``K``, looked up at call time.
        world2cam: 4x4 homogeneous world->camera transform. Defaults to the
            module-level ``T_world2cam``, looked up at call time.
        min_depth: points whose camera-frame z is not greater than this are
            flagged as not visible (behind or too close to the camera).

    Returns:
        (uv, mask): ``uv`` is an (N, 2) float array of pixel coordinates and
        ``mask`` an (N,) bool array, True where z > ``min_depth``. ``uv`` is
        still computed for masked-out points but is meaningless there, so
        callers must check ``mask`` before using it.
    """
    K_mat = K if intrinsics is None else np.asarray(intrinsics, dtype=float)
    T = T_world2cam if world2cam is None else np.asarray(world2cam, dtype=float)

    pts = np.asarray(pts_world, dtype=float).reshape(-1, 3)
    pts_h = np.hstack([pts, np.ones((pts.shape[0], 1))])
    pts_cam = (T @ pts_h.T).T[:, :3]
    mask = pts_cam[:, 2] > min_depth
    # Substitute 1 for an exactly-zero depth so the perspective divide never
    # produces inf/nan; such points are already excluded by ``mask``.
    z = pts_cam[:, 2:3]
    denom = np.where(z != 0, z, 1)
    uv = (K_mat @ (pts_cam / denom).T).T[:, :2]
    return uv, mask

# skeleton (link positions): a green dot + name at every body origin, and a
# green segment from each body to its kinematic parent.
uv_links, mask_links = project_pts(link_pos)
_LINK_BGR = (0, 255, 0)


def _inside_image(x, y):
    """Return True when pixel coordinate (x, y) lies inside the W x H image."""
    return 0 <= x < W and 0 <= y < H


for body_idx, ((u, v), visible) in enumerate(zip(uv_links, mask_links)):
    # Skip bodies behind the camera or projecting outside the frame.
    if visible and _inside_image(u, v):
        px, py = int(u), int(v)
        cv2.circle(img, (px, py), 6, _LINK_BGR, -1)
        cv2.putText(img, model.body(body_idx).name, (px + 8, py - 4),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, _LINK_BGR, 1, cv2.LINE_AA)

# Body 0 (the MuJoCo world body) is never a child; bones start at body 1.
for child in range(1, model.nbody):
    parent = model.body_parentid[child]
    if not (mask_links[child] and mask_links[parent]):
        continue
    start = tuple(int(c) for c in uv_links[parent])
    end = tuple(int(c) for c in uv_links[child])
    # Only draw segments whose (truncated) endpoints are both inside the image.
    if _inside_image(*start) and _inside_image(*end):
        cv2.line(img, start, end, _LINK_BGR, 2)

# sites (tcp + grasp_site) - big cross markers, magenta + cyan.
uv_sites, mask_sites = project_pts(site_pos)
# OpenCV draws in BGR order: (255, 0, 255) is magenta and (255, 255, 0) is cyan.
# (The previous (0, 255, 255) rendered YELLOW, contradicting the legend.)
# Colors are assigned by site index, matching the legend's "magenta=tcp,
# cyan=grasp_site" -- NOTE(review): assumes the MJCF declares tcp before
# grasp_site. Any extra sites cycle through the palette instead of raising
# IndexError.
site_colors = [(255, 0, 255), (255, 255, 0)]  # tcp magenta, grasp cyan (BGR)
for i in range(model.nsite):
    name = model.site(i).name
    color = site_colors[i % len(site_colors)]
    print(f"  site {name} uv: {uv_sites[i]}, mask: {mask_sites[i]}")
    if not mask_sites[i]:
        continue
    u, v = uv_sites[i]
    if not (0 <= u < W and 0 <= v < H):
        continue
    cv2.drawMarker(img, (int(u), int(v)), color, cv2.MARKER_CROSS, 30, 3)
    cv2.putText(img, name, (int(u) + 15, int(v) - 8),
                cv2.FONT_HERSHEY_SIMPLEX, 0.55, color, 2, cv2.LINE_AA)

# Legend: frame index and the color key for the overlays drawn above.
cv2.putText(img, f"frame {FRAME}  joints[:6]  (green=link, magenta=tcp, cyan=grasp_site)", (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2, cv2.LINE_AA)

# cv2.imwrite returns False instead of raising when it cannot save the file
# (e.g. the output directory does not exist), so check before reporting success.
if not cv2.imwrite(OUT, img):
    raise OSError(f"failed to write overlay image: {OUT}")
print(f"\nSaved: {OUT}")
