"""Render YAM left arm skeleton onto image, given a frame from pickup_apple."""
import mujoco, numpy as np, pickle, cv2, sys

# Robot model description (MJCF) and the processed episode to visualise.
MJCF = "/home/robot-lab/raiden/third_party/i2rt/i2rt/robot_models/arm/yam/yam.xml"
DATA = "/home/robot-lab/data/processed/pickup_apple/0000"

# Optional positional CLI arguments: frame index first (default 500), then
# camera name (default "scene_1").
_cli_args = sys.argv[1:]
FRAME = 500 if not _cli_args else int(_cli_args[0])
CAM = "scene_1" if len(_cli_args) < 2 else _cli_args[1]

# Frame index zero-padded to 10 digits; used to build per-frame file names.
fstr = f"{FRAME:010d}"

# Destination of the rendered overlay; the name encodes camera and frame.
OUT = f"/home/robot-lab/cameron/yam_overlay/out_v1_{CAM}_f{FRAME:04d}.png"

# Load the arm model from its MJCF file and allocate the matching state.
model = mujoco.MjModel.from_xml_path(MJCF)
data = mujoco.MjData(model)

# Debug dump: shape of qpos, body count, and the name of every body.
n_bodies = model.nbody
print(f"qpos size: {data.qpos.shape}, nbody: {n_bodies}")
for body_id in range(n_bodies):
    body_name = model.body(body_id).name
    print(f"  body {body_id}: {body_name}")

# Read the low-dimensional record stored for the selected frame.
# NOTE: unpickling can execute arbitrary code; only load trusted data files.
lowdim_path = f"{DATA}/lowdim/{fstr}.pkl"
with open(lowdim_path, "rb") as pkl_file:
    fd = pickle.load(pkl_file)

# Debug dump of the recorded joint vector and the cameras with extrinsics.
joint_values = fd["joints"]
print(f"joints (full 14d): {joint_values}")
camera_names = list(fd["extrinsics"].keys())
print(f"available cams: {camera_names}")

# Pose the model: copy the first six recorded joint values (the left arm,
# according to the original note) into the start of qpos, then run MuJoCo's
# kinematics pass so the body positions in `data.xpos` reflect that pose.
data.qpos[:6] = fd["joints"][:6]
mujoco.mj_kinematics(model, data)

# Keep a private copy so later writes to `data` cannot change these values.
body_pos_world = data.xpos.copy()
print("world body positions:")
for body_id, position in enumerate(body_pos_world):
    print(f"  {model.body(body_id).name}: {position}")

# Calibration of the selected camera: the intrinsic matrix K and the stored
# extrinsic transform, which this script treats as camera-to-world.
K = np.array(fd["intrinsics"][CAM])
T_cam2world = np.array(fd["extrinsics"][CAM])
print(f"K =\n{K}", f"T_cam2world =\n{T_cam2world}", sep="\n")

# Projection needs the opposite direction (world -> camera).
T_world2cam = np.linalg.inv(T_cam2world)

# Load the RGB frame that the skeleton will be drawn onto.
img_path = f"{DATA}/rgb/{CAM}/{fstr}.jpg"
img = cv2.imread(img_path)
# cv2.imread does not raise on a missing or undecodable file; it returns
# None. Fail here with the offending path instead of an opaque
# AttributeError on `img.shape` below.
if img is None:
    raise FileNotFoundError(f"could not read image: {img_path}")
# Image height and width in pixels, used for the in-bounds checks later on.
H, W = img.shape[:2]
print(f"img: {img.shape}")

# Project every body position into the image.
# NOTE(review): assumes T_world2cam is a 4x4 homogeneous transform and K a
# 3x3 intrinsic matrix -- confirm against the dataset format.

# Append a column of ones to get homogeneous world coordinates.
ones_column = np.ones((model.nbody, 1))
world_h = np.hstack([body_pos_world, ones_column])

# Move the points into the camera frame and keep x, y, z.
cam_h = T_world2cam @ world_h.T
pts_cam = cam_h.T[:, :3]

# Only points more than 0.01 in front of the camera count as visible.
depth = pts_cam[:, 2:3]
mask = pts_cam[:, 2] > 0.01

# Divide by depth, substituting 1 where depth is exactly zero so the division
# stays finite; such points are excluded by `mask` anyway.
safe_depth = np.where(depth != 0, depth, 1)
normalized = pts_cam / safe_depth

# Apply the intrinsics and keep the two pixel coordinates.
uv = (K @ normalized.T).T[:, :2]

# Report, per body, the projected pixel and whether it falls inside the image
# ("OK"), outside it ("OOB"), or is not in front of the camera at all.
print("projected uv:")
for body_id, (in_front, (px, py)) in enumerate(zip(mask, uv)):
    body_name = model.body(body_id).name
    if not in_front:
        print(f"  {body_name}: behind cam")
        continue
    status = "OK" if (0 <= px < W and 0 <= py < H) else "OOB"
    print(f"  {body_name}: ({px:.1f}, {py:.1f}) {status}")

# Draw a filled green dot and a name label for every body that is in front of
# the camera and projects inside the image.
for body_id in range(model.nbody):
    px, py = uv[body_id]
    if mask[body_id] and 0 <= px < W and 0 <= py < H:
        cx, cy = int(px), int(py)
        cv2.circle(img, (cx, cy), 8, (0, 255, 0), -1)
        # Offset the label slightly so it does not sit on top of the dot.
        cv2.putText(
            img,
            model.body(body_id).name,
            (cx + 8, cy - 4),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            (0, 255, 0),
            1,
            cv2.LINE_AA,
        )

# Connect each body to its parent with a green segment. Body 0 is skipped as a
# child. A segment is drawn only when both ends are in front of the camera and
# both integer pixel positions lie inside the image.
for child in range(1, model.nbody):
    parent = model.body_parentid[child]
    if not mask[child] or not mask[parent]:
        continue
    start = (int(uv[parent, 0]), int(uv[parent, 1]))
    end = (int(uv[child, 0]), int(uv[child, 1]))
    if all(0 <= x < W and 0 <= y < H for x, y in (start, end)):
        cv2.line(img, start, end, (0, 255, 0), 2)

# Write the annotated frame. cv2.imwrite reports failure (for example a missing
# output directory) through its boolean return value rather than an exception,
# so check it before claiming success.
if not cv2.imwrite(OUT, img):
    raise OSError(f"cv2.imwrite failed to write: {OUT}")
print(f"\nSaved: {OUT}")
