"""v2: try BOTH extrinsic conventions and label which is which.
Render both projections side-by-side so we can see which one is correct."""
import os
import pickle
import sys

import cv2
import mujoco
import numpy as np

MJCF = "/home/robot-lab/raiden/third_party/i2rt/i2rt/robot_models/arm/yam/yam.xml"
DATA = "/home/robot-lab/data/processed/pickup_apple/0000"
FRAME = int(sys.argv[1]) if len(sys.argv) > 1 else 500
CAM = sys.argv[2] if len(sys.argv) > 2 else "scene_1"
OUT = f"/home/robot-lab/cameron/yam_overlay/out_v2_{CAM}_f{FRAME:04d}.png"

fstr = f"{FRAME:010d}"

model = mujoco.MjModel.from_xml_path(MJCF)
data = mujoco.MjData(model)
with open(f"{DATA}/lowdim/{fstr}.pkl", "rb") as f:
    fd = pickle.load(f)

data.qpos[:6] = fd["joints"][:6]
mujoco.mj_kinematics(model, data)
body_pos_world = data.xpos.copy()

K = np.array(fd["intrinsics"][CAM])
T_stored = np.array(fd["extrinsics"][CAM])
print(f"T_stored ({CAM}) =\n{T_stored}")
print(f"K =\n{K}")
print(f"T_stored translation: {T_stored[:3, 3]}")
print(f"T_stored.inv() translation: {np.linalg.inv(T_stored)[:3, 3]}")

img = cv2.imread(f"{DATA}/rgb/{CAM}/{fstr}.jpg")
H, W = img.shape[:2]

# Side-by-side comparison
img_a = img.copy()  # interpret as cam2world (invert to project)
img_b = img.copy()  # interpret as world2cam (use directly)

def project(T_world2cam, K, pts_world, img, color, label):
    """Project world-frame points into ``img`` and draw them as a labeled skeleton.

    Points are transformed by ``T_world2cam``, divided by their camera-frame z
    and mapped through ``K``. Every point's pixel coordinate is printed with a
    status:

    - ``OK``: in front of the camera and inside the image.
    - ``OOB``: in front of the camera but outside the image.
    - ``BEHIND``: camera-frame z <= 0.01.

    Visible points are drawn as filled circles labeled with their MuJoCo body
    name. Each body is joined to its parent (``model.body_parentid``) when both
    endpoints are visible.

    Args:
        T_world2cam: transform applied to homogeneous world points; only the
            first three rows of the result are used.
        K: intrinsic matrix applied to the z-normalized camera coordinates.
        pts_world: (N, 3) array of world-frame points; row i is MuJoCo body i.
        img: image drawn on in place. Its own shape defines the bounds.
        color: colour tuple passed to OpenCV for markers, lines and text.
        label: caption for the printed log and the top-left image text.
    """
    # Bound against the image actually being drawn on, not module globals.
    h, w = img.shape[:2]
    n = pts_world.shape[0]

    pts_h = np.hstack([pts_world, np.ones((n, 1))])
    pts_cam = (T_world2cam @ pts_h.T).T[:, :3]
    # Points at or behind this near plane are never drawn.
    in_front = pts_cam[:, 2] > 0.01
    # Avoid division by zero; such points are excluded by `in_front` anyway.
    z = pts_cam[:, 2:3]
    denom = np.where(z != 0, z, 1)
    uv = (K @ (pts_cam / denom).T).T[:, :2]

    # One shared visibility test for logging, markers and edges.
    in_bounds = (uv[:, 0] >= 0) & (uv[:, 0] < w) & (uv[:, 1] >= 0) & (uv[:, 1] < h)
    visible = in_front & in_bounds

    names = [model.body(i).name for i in range(n)]

    print(f"\n  [{label}] projected uv:")
    for i, (u, v) in enumerate(uv):
        status = "OK" if visible[i] else ("OOB" if in_front[i] else "BEHIND")
        print(f"    {names[i]}: ({u:.1f}, {v:.1f}) {status}")

    # Convert to ints only for visible points, so far-off coordinates are never cast.
    for i in np.flatnonzero(visible):
        pt = (int(uv[i, 0]), int(uv[i, 1]))
        cv2.circle(img, pt, 8, color, -1)
        cv2.putText(img, names[i], (pt[0] + 8, pt[1] - 4),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.45, color, 1, cv2.LINE_AA)

    # Skeleton edges: body i -> its parent, only when both endpoints are visible.
    for i in range(1, n):
        p = model.body_parentid[i]
        if not (visible[i] and visible[p]):
            continue
        p0 = (int(uv[p, 0]), int(uv[p, 1]))
        p1 = (int(uv[i, 0]), int(uv[i, 1]))
        cv2.line(img, p0, p1, color, 2)

    cv2.putText(img, label, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, color, 2, cv2.LINE_AA)

# A: interpret T_stored as cam2world, so invert it to get world2cam.
T_world2cam_a = np.linalg.inv(T_stored)
project(T_world2cam_a, K, body_pos_world, img_a, (0, 255, 0), "A: T_stored=cam2world (inv applied)")

# B: interpret T_stored as world2cam and use it unchanged.
T_world2cam_b = T_stored
project(T_world2cam_b, K, body_pos_world, img_b, (0, 200, 255), "B: T_stored=world2cam (no inv)")

# Side by side: interpretation A on the left, B on the right.
out = np.hstack([img_a, img_b])
# cv2.imwrite does not create missing directories.
os.makedirs(os.path.dirname(OUT), exist_ok=True)
if not cv2.imwrite(OUT, out):
    # cv2.imwrite reports failure by returning False rather than raising.
    sys.exit(f"Failed to write {OUT}")
print(f"\nSaved: {OUT}")
