"""v3: render left arm into all 3 cameras simultaneously to disambiguate convention."""
import mujoco, numpy as np, pickle, cv2, sys

# --- Inputs ------------------------------------------------------------------
# Machine-specific absolute paths: the YAM arm MJCF and one processed episode.
MJCF = "/home/robot-lab/raiden/third_party/i2rt/i2rt/robot_models/arm/yam/yam.xml"
DATA = "/home/robot-lab/data/processed/pickup_apple/0000"

# Frame to visualise: taken from the first CLI argument, otherwise frame 500.
if len(sys.argv) > 1:
    FRAME = int(sys.argv[1])
else:
    FRAME = 500
OUT = f"/home/robot-lab/cameron/yam_overlay/out_v3_allcam_f{FRAME:04d}.png"

# Per-frame files on disk are keyed by a 10-digit zero-padded frame index.
fstr = f"{FRAME:010d}"

# --- Model + joint state -----------------------------------------------------
model = mujoco.MjModel.from_xml_path(MJCF)
data = mujoco.MjData(model)
# NOTE(review): pickle executes arbitrary code on load; only use on trusted data.
with open(f"{DATA}/lowdim/{fstr}.pkl", "rb") as fh:
    fd = pickle.load(fh)

# Drive the model's first 6 qpos entries from the recorded joints, then run
# forward kinematics to obtain world-frame body origins.
data.qpos[:6] = fd["joints"][:6]
mujoco.mj_kinematics(model, data)
body_pos_world = np.array(data.xpos, copy=True)

# Show both halves of the joint vector so the left/right split can be checked.
print(f"joints[:7] (left arm guess): {fd['joints'][:7]}")
print(f"joints[7:14] (right arm guess): {fd['joints'][7:14]}")

# Project every MuJoCo body origin into each camera and draw a green skeleton
# (a dot + name per body, a segment from each body to its parent) on that
# camera's frame. Extrinsics are treated as camera-to-world and inverted here;
# NOTE(review): that convention is what this script is probing -- confirm it
# against the resulting overlays.
panels = []
# Homogeneous world-frame body positions; identical for every camera.
pts_h = np.hstack([body_pos_world, np.ones((model.nbody, 1))])
for CAM in ["scene_1", "left_wrist", "right_wrist"]:
    K = np.array(fd["intrinsics"][CAM])
    T_cam2world = np.array(fd["extrinsics"][CAM])
    T_world2cam = np.linalg.inv(T_cam2world)
    img_path = f"{DATA}/rgb/{CAM}/{fstr}.jpg"
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread returns None (instead of raising) for missing/unreadable files.
        raise FileNotFoundError(f"could not read image: {img_path}")
    H, W = img.shape[:2]
    print(f"\n[{CAM}] cam-in-world: {T_cam2world[:3, 3]}")

    pts_cam = (T_world2cam @ pts_h.T).T[:, :3]
    # Keep only points in front of the camera (z > 0.01; units presumably metres).
    mask = pts_cam[:, 2] > 0.01
    # Guard the perspective divide against z == 0; such points are masked anyway.
    denom = np.where(pts_cam[:, 2:3] != 0, pts_cam[:, 2:3], 1)
    uv = (K @ (pts_cam / denom).T).T[:, :2]
    # One visibility test shared by the status printout, the dots and the
    # segments, so all three agree on which bodies are on-screen.
    in_img = (
        mask
        & (uv[:, 0] >= 0) & (uv[:, 0] < W)
        & (uv[:, 1] >= 0) & (uv[:, 1] < H)
    )

    for i in range(model.nbody):
        u, v = uv[i]
        st = "OK" if in_img[i] else ("OOB" if mask[i] else "BEHIND")
        print(f"  {model.body(i).name}: ({u:.0f}, {v:.0f}) {st}")

    for i in range(model.nbody):
        if not in_img[i]:
            continue
        x, y = int(uv[i, 0]), int(uv[i, 1])
        cv2.circle(img, (x, y), 8, (0, 255, 0), -1)
        cv2.putText(img, model.body(i).name, (x + 8, y - 4),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1, cv2.LINE_AA)
    # Body 0 is skipped as a child: the loop starts at 1 and links each later
    # body to model.body_parentid.
    for i in range(1, model.nbody):
        p = model.body_parentid[i]
        if in_img[i] and in_img[p]:
            cv2.line(img, (int(uv[p, 0]), int(uv[p, 1])),
                     (int(uv[i, 0]), int(uv[i, 1])), (0, 255, 0), 2)

    cv2.putText(img, CAM, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 255, 0), 2, cv2.LINE_AA)
    panels.append(img)

# Resize every panel to a common height (preserving aspect ratio), then place
# them side by side in a single output image.
H_target = 360
out_panels = []
for p in panels:
    h, w = p.shape[:2]
    nw = int(w * H_target / h)
    # INTER_AREA avoids aliasing when shrinking (keeps thin labels legible);
    # INTER_LINEAR, OpenCV's default, is the better choice when enlarging.
    interp = cv2.INTER_AREA if h > H_target else cv2.INTER_LINEAR
    out_panels.append(cv2.resize(p, (nw, H_target), interpolation=interp))
out = np.hstack(out_panels)
# cv2.imwrite reports failure (e.g. a missing output directory) only through its
# return value, so check it before claiming the file was saved.
if not cv2.imwrite(OUT, out):
    raise OSError(f"cv2.imwrite failed to write {OUT}")
print(f"\nSaved: {OUT}")
