"""Render the fixed-world 32³ voxel bbox + interior dots onto a real mac-robot start frame.

Uses the same bounds (configurable) as the LIBERO sim render but with the mac dataset's
camera matrices: world → camera = T_camera_arucoBase @ T_W_baseBody_inv_aruco_offset,
then perspective + K (in pixel coords). NO flipud here — mac images are already upright.

Usage:
  python render_robot_volume_on_real.py \
    --session_dir /data/cameron/mac_robot_datasets/first_mobile_collection/dataset_20260512_161202 \
    --x_min -0.3 --x_max 0.3 --y_min -0.3 --y_max 0.3 --z_min 0.0 --z_max 0.5 \
    --out /data/cameron/para/.agents/reports/backbones/media/mac_volume_d12_161202.png
"""
import argparse, json, os
from pathlib import Path

import cv2
import numpy as np


def project(world_pts, world_to_camera, K):
    """Project world-frame points to pixel coordinates with a pinhole camera.

    Args:
        world_pts: (N, 3) array-like of world-frame points. A single (3,) point
            is also accepted and treated as N=1.
        world_to_camera: (4, 4) homogeneous transform from world to camera frame.
            The 4th homogeneous component of the result is ignored, i.e. the
            transform is assumed to be rigid/affine.
        K: (3, 3) intrinsics matrix (pixel units). Only its first two rows affect
            the output.

    Returns:
        (pix, depth): ``pix`` is (N, 2) (u, v) pixel coordinates. ``depth`` is (N,)
        camera-frame z, returned UNclipped so callers can filter. Points with
        depth <= 0 are behind the camera. Their ``pix`` values are meaningless,
        because the perspective divide uses z clipped to >= 1e-3.
    """
    world_pts = np.asarray(world_pts, dtype=np.float64).reshape(-1, 3)
    ones = np.ones((world_pts.shape[0], 1))
    pts_h = np.concatenate([world_pts, ones], axis=-1)              # (N, 4)
    cam_xyz = (world_to_camera @ pts_h.T).T[:, :3]                  # (N, 3) camera frame
    depth = cam_xyz[:, 2]
    # Clip only for the divide (avoids /0). Callers must use `depth` to drop
    # points behind or very near the camera.
    z = np.clip(depth, 1e-3, None)
    # Pinhole: pixel = K @ [x/z, y/z, 1]
    norm = cam_xyz[:, :2] / z[:, None]                              # (N, 2)
    homog = np.concatenate([norm, np.ones((norm.shape[0], 1))], axis=-1)  # (N, 3)
    pix = (K @ homog.T).T[:, :2]                                    # (N, 2)
    return pix, depth


# Voxel centres and bbox corners whose camera-frame depth is at or below this
# value (behind or extremely close to the camera) are not drawn: their
# projections are not meaningful pixels.
_MIN_DEPTH = 0.05


def _voxel_centers(lo, hi, n):
    """Return the ``n`` centres of [lo, hi] split into ``n`` equal cells (half-cell inset)."""
    half = (hi - lo) / (2 * n)
    return np.linspace(lo + half, hi - half, n)


def _load_world_to_camera(sess):
    """Read ``sess/meta.json``; return ((W, H), K, world_to_camera) with a (4, 4) transform."""
    with open(sess / "meta.json") as f:
        meta = json.load(f)
    W, H = meta["image_size_wh"]
    K = np.array(meta["K"], dtype=np.float64)                   # already scaled to image_size_wh
    T_cam_aru = np.array(meta["T_camera_arucoBase"], dtype=np.float64)
    T_W_bb = np.array(meta["T_W_baseBody_inv_aruco_offset"], dtype=np.float64)
    return (int(W), int(H)), K, T_cam_aru @ T_W_bb


def _draw_bbox(img, corners, world_to_camera, K):
    """Draw the 12 edges of the box given by 8 ``corners`` onto ``img`` in place.

    Corner order: indices 0-3 are the bottom face, 4-7 the top face, with corner
    i+4 directly above corner i. Any edge with an endpoint at depth <= _MIN_DEPTH
    is skipped, because its clipped-z projection would produce a bogus line.
    """
    edges = [(0, 1), (1, 2), (2, 3), (3, 0), (4, 5), (5, 6), (6, 7), (7, 4),
             (0, 4), (1, 5), (2, 6), (3, 7)]
    uv, depth = project(corners, world_to_camera, K)
    for a, b in edges:
        if depth[a] <= _MIN_DEPTH or depth[b] <= _MIN_DEPTH:
            continue
        cv2.line(img,
                 (int(uv[a, 0]), int(uv[a, 1])),
                 (int(uv[b, 0]), int(uv[b, 1])),
                 (255, 255, 255), 2, cv2.LINE_AA)


def _draw_voxel_dots(img, pts, z_norm, world_to_camera, K, radius):
    """Alpha-blend z-colour-coded dots for ``pts`` onto ``img`` in place.

    Returns:
        (n_drawn, mask_layer): ``n_drawn`` is the number of voxel centres that
        landed inside the image and in front of the camera. ``mask_layer`` is an
        (H, W) uint8 map, 1 where any dot was painted.
    """
    H, W = img.shape[:2]
    pix, depth = project(pts, world_to_camera, K)
    dot_layer = np.zeros_like(img)
    mask_layer = np.zeros((H, W), dtype=np.uint8)
    n_drawn = 0
    for (u, v), zn, d in zip(pix, z_norm, depth):
        iu, iv = int(round(u)), int(round(v))
        if not (0 <= iu < W and 0 <= iv < H):
            continue
        if d <= _MIN_DEPTH:
            continue   # behind / extremely close to camera
        color = (int(255 * (1 - zn)), 80, int(255 * zn))   # BGR: blue=low z, red=high z
        if radius == 1:
            dot_layer[iv, iu] = color
            mask_layer[iv, iu] = 1
        else:
            cv2.circle(dot_layer, (iu, iv), radius, color, -1, cv2.LINE_AA)
            cv2.circle(mask_layer, (iu, iv), radius, 1, -1)
        n_drawn += 1
    alpha = 0.55
    bm = mask_layer.astype(bool)
    img[bm] = (alpha * dot_layer[bm] + (1 - alpha) * img[bm]).astype(np.uint8)
    return n_drawn, mask_layer


def main():
    """CLI entry point: overlay the voxel bbox and voxel-centre dots on one frame and save it.

    Raises:
        FileNotFoundError: the requested ``rgb_XXXXXX.jpg`` frame cannot be read.
        ValueError: the frame's size differs from meta.json ``image_size_wh``,
            so K would not match the pixels.
        OSError: ``cv2.imwrite`` reports failure, e.g. for an unsupported extension.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--session_dir", type=str, required=True)
    p.add_argument("--frame_idx", type=int, default=0)
    p.add_argument("--x_min", type=float, default=-0.30)
    p.add_argument("--x_max", type=float, default= 0.30)
    p.add_argument("--y_min", type=float, default=-0.30)
    p.add_argument("--y_max", type=float, default= 0.30)
    p.add_argument("--z_min", type=float, default= 0.00)
    p.add_argument("--z_max", type=float, default= 0.50)
    p.add_argument("--n_xy", type=int, default=32)
    p.add_argument("--n_z",  type=int, default=32)
    p.add_argument("--draw_grid_stride", type=int, default=1)
    p.add_argument("--out", type=str, required=True)
    args = p.parse_args()
    if args.n_xy < 1 or args.n_z < 1:
        p.error("--n_xy and --n_z must be >= 1")
    if args.draw_grid_stride < 1:
        p.error("--draw_grid_stride must be >= 1")

    sess = Path(args.session_dir)
    (W, H), K, world_to_camera = _load_world_to_camera(sess)

    img_path = sess / f"rgb_{args.frame_idx:06d}.jpg"
    bgr = cv2.imread(str(img_path))
    if bgr is None:
        raise FileNotFoundError(img_path)
    if bgr.shape[:2] != (H, W):
        raise ValueError(
            f"{img_path} is {bgr.shape[1]}x{bgr.shape[0]} but meta.json "
            f"image_size_wh is {W}x{H}; intrinsics K would not match")
    img = bgr.copy()                                            # render in BGR; cv2 saves BGR

    # 8 bbox corners: bottom face (z_min) then top face (z_max).
    corners = np.array([
        [args.x_min, args.y_min, args.z_min], [args.x_max, args.y_min, args.z_min],
        [args.x_max, args.y_max, args.z_min], [args.x_min, args.y_max, args.z_min],
        [args.x_min, args.y_min, args.z_max], [args.x_max, args.y_min, args.z_max],
        [args.x_max, args.y_max, args.z_max], [args.x_min, args.y_max, args.z_max],
    ])
    _draw_bbox(img, corners, world_to_camera, K)

    # Interior voxel centres, subsampled by the stride, ordered z-outer / x / y-inner.
    s = args.draw_grid_stride
    xs = _voxel_centers(args.x_min, args.x_max, args.n_xy)[::s]
    ys = _voxel_centers(args.y_min, args.y_max, args.n_xy)[::s]
    zs = _voxel_centers(args.z_min, args.z_max, args.n_z)[::s]
    Zg, Xg, Yg = np.meshgrid(zs, xs, ys, indexing="ij")
    pts = np.stack([Xg.ravel(), Yg.ravel(), Zg.ravel()], axis=-1)  # (N, 3)
    z_norm = (pts[:, 2] - args.z_min) / (args.z_max - args.z_min + 1e-8)

    radius = 1 if s <= 2 else 2
    n_drawn, mask_layer = _draw_voxel_dots(img, pts, z_norm, world_to_camera, K, radius)

    label = (f"{sess.name}  fr={args.frame_idx}  "
             f"X[{args.x_min:.2f},{args.x_max:.2f}] "
             f"Y[{args.y_min:.2f},{args.y_max:.2f}] "
             f"Z[{args.z_min:.2f},{args.z_max:.2f}]")
    # Thick light pass then thin dark pass gives outlined, readable text.
    cv2.putText(img, label, (8, H - 12), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.putText(img, label, (8, H - 12), cv2.FONT_HERSHEY_SIMPLEX, 0.45, (20, 20, 20), 1, cv2.LINE_AA)

    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    if not cv2.imwrite(args.out, img):
        raise OSError(f"cv2.imwrite failed to write {args.out} (unsupported extension?)")
    print(f"Saved: {args.out}")
    n_total = len(pts)
    pct_in = (mask_layer.sum() / mask_layer.size) * 100
    print(f"voxels in-image: {n_drawn} / {n_total} ({100 * n_drawn / n_total:.1f}%)  "
          f"image coverage: {pct_in:.1f}%")


# Script entry point: main() runs only when this file is executed directly,
# so importing the module (e.g. to reuse project()) has no side effects.
if __name__ == "__main__":
    main()
