"""Visualize a fixed robot-frame voxel grid overlaid on a LIBERO scene.

Renders:
  - the 8 corner / 12 edge wireframe of the voxel bbox (white)
  - a sparse interior grid of voxel centers as dots, color-coded by z (blue→red)
  - the EEF home position (large lime dot)
  - the recorded EEF trajectory of the same demo (cyan polyline) as a sanity check, drawn only when a cached eef_pos.npy exists under --cache_root

Usage:
  CUDA_VISIBLE_DEVICES=9 MUJOCO_GL=egl PYTHONPATH=/data/cameron/LIBERO:$PYTHONPATH \
  python render_robot_volume.py \
    --benchmark libero_spatial --task_id 0 --demo_idx 0 \
    --x_min -0.35 --x_max 0.25 --y_min -0.15 --y_max 0.35 --z_min 0.85 --z_max 1.35 \
    --n_xy 32 --n_z 32 \
    --out /tmp/robot_volume_overlay.png
"""
import argparse, os, sys
from pathlib import Path

import cv2
import h5py
import numpy as np

sys.path.insert(0, os.path.dirname(__file__))

from libero.libero import benchmark as bm_mod, get_libero_path
from libero.libero.envs import OffScreenRenderEnv
from robosuite.utils.camera_utils import (
    get_camera_transform_matrix,
    project_points_from_world_to_camera,
)

# Square render resolution in pixels: used for the env's camera height/width and for
# building the world->pixel camera matrix, so projected coords share this pixel grid.
IMAGE_SIZE = 448


def project(world_points, world_to_camera, image_size):
    """Project world-frame points to pixel (u, v) coords in the flipud-image convention.

    The projection is done directly with the 4x4 world->pixel matrix (as returned by
    robosuite's ``get_camera_transform_matrix``) instead of robosuite's
    ``project_points_from_world_to_camera``. That helper rounds to int AND clips every
    point into the image. Off-screen points would then land on the border, and callers
    could not mask them.

    Args:
        world_points: array-like of shape (N, 3) (a single (3,) point is also accepted)
            with world-frame xyz.
        world_to_camera: (4, 4) world->pixel transform (intrinsics @ inverse extrinsics).
        image_size: image side length in pixels. Kept for signature compatibility; no
            clipping is performed, so callers decide how to treat off-image points.

    Returns:
        (N, 2) float64 array of [u, v] = [column, row]. Values are NOT rounded or clipped.
        Rows for points with non-positive depth (at or behind the camera plane) are NaN.
        Callers must still mask out points outside ``[0, image_size)``.

    Raises:
        ValueError: if ``world_to_camera`` is not 4x4.
    """
    pts = np.asarray(world_points, dtype=np.float64).reshape(-1, 3)
    mat = np.asarray(world_to_camera, dtype=np.float64)
    if mat.shape != (4, 4):
        raise ValueError(f"world_to_camera must be 4x4, got shape {mat.shape}")

    # Homogeneous transform: row i becomes [u*d, v*d, d, 1], where d is the depth term.
    homo = np.concatenate([pts, np.ones((pts.shape[0], 1))], axis=1)
    cam = homo @ mat.T
    depth = cam[:, 2]

    uv = np.full((pts.shape[0], 2), np.nan)
    in_front = depth > 0
    # Perspective divide; the first two components are already (column, row) = (u, v).
    uv[in_front] = cam[in_front, :2] / depth[in_front, None]
    return uv


def _parse_args():
    """Parse and validate the command-line options for the overlay render."""
    p = argparse.ArgumentParser()
    p.add_argument("--benchmark", type=str, default="libero_spatial")
    p.add_argument("--task_id", type=int, default=0)
    p.add_argument("--camera", type=str, default="agentview")
    p.add_argument("--demo_idx", type=int, default=0)
    p.add_argument("--x_min", type=float, default=-0.35)
    p.add_argument("--x_max", type=float, default=0.25)
    p.add_argument("--y_min", type=float, default=-0.15)
    p.add_argument("--y_max", type=float, default=0.35)
    p.add_argument("--z_min", type=float, default=0.85)
    p.add_argument("--z_max", type=float, default=1.35)
    p.add_argument("--n_xy", type=int, default=32)
    p.add_argument("--n_z", type=int, default=32)
    p.add_argument("--draw_grid_stride", type=int, default=4,
                   help="Plot every Nth voxel center (4 = a 8x8x8 sampled grid of dots).")
    p.add_argument("--out", type=str, default="/tmp/robot_volume_overlay.png")
    # `store_true` with `default=True` alone made the option a no-op: the trajectory could
    # never be disabled. `--show_trajectory` is kept for backward compatibility, and
    # `--no_trajectory` (same dest) turns the overlay off.
    p.add_argument("--show_trajectory", dest="show_trajectory", action="store_true", default=True)
    p.add_argument("--no_trajectory", dest="show_trajectory", action="store_false", default=True,
                   help="Do not overlay the cached demo EEF trajectory.")
    p.add_argument("--cache_root", type=str, default="/data/libero/parsed_libero")
    args = p.parse_args()

    # Fail before the (slow) env build instead of crashing later: np.arange rejects a
    # zero step, and the cell-size report divides by n_xy / n_z.
    if args.n_xy < 1 or args.n_z < 1:
        p.error("--n_xy and --n_z must be >= 1")
    if args.draw_grid_stride < 1:
        p.error("--draw_grid_stride must be >= 1")
    return args


def _demo_sort_key(name):
    """Order 'demo_<int>' keys numerically (demo_2 before demo_10); other names go last."""
    suffix = name[len("demo_"):]
    return (0, int(suffix), name) if suffix.isdigit() else (1, 0, name)


def _render_scene(args):
    """Reset the env to the demo's first state and render the camera.

    Returns (upright RGB image, last obs dict, 4x4 world->pixel matrix). The env is
    closed before returning.
    """
    bench = bm_mod.get_benchmark_dict()[args.benchmark]()
    task = bench.get_task(args.task_id)
    bddl = os.path.join(get_libero_path("bddl_files"), task.problem_folder, task.bddl_file)

    demo_path = os.path.join(get_libero_path("datasets"), bench.get_task_demonstration(args.task_id))
    with h5py.File(demo_path, "r") as f:
        # A plain sorted() is lexicographic (demo_0, demo_1, demo_10, ...), so --demo_idx 2
        # used to load demo_10 while the trajectory overlay below read demo_2.
        keys = sorted((k for k in f["data"].keys() if k.startswith("demo_")), key=_demo_sort_key)
        if not 0 <= args.demo_idx < len(keys):
            raise IndexError(
                f"--demo_idx {args.demo_idx} is out of range: {demo_path} has {len(keys)} demos")
        init_state = f[f"data/{keys[args.demo_idx]}/states"][0]

    env = OffScreenRenderEnv(
        bddl_file_name=bddl,
        camera_heights=IMAGE_SIZE, camera_widths=IMAGE_SIZE,
        camera_names=[args.camera],
    )
    try:
        env.seed(0)
        env.reset()
        obs = env.set_init_state(init_state)
        # A few zero-action steps after restoring the state (presumably to let the sim settle).
        for _ in range(5):
            obs, _, _, _ = env.step(np.zeros(7, dtype=np.float32))
        world_to_camera = get_camera_transform_matrix(env.sim, args.camera, IMAGE_SIZE, IMAGE_SIZE)
    finally:
        env.close()

    rgb = obs[f"{args.camera}_image"].astype(np.uint8)
    img = np.flipud(rgb).copy()  # training convention; matches project()'s flipud pixel coords
    return img, obs, world_to_camera


def _pixel(uv):
    """Round one projected (u, v) to an int pixel tuple, or return None if it is non-finite."""
    u, v = float(uv[0]), float(uv[1])
    if not (np.isfinite(u) and np.isfinite(v)):
        return None
    return int(round(u)), int(round(v))


def _draw_bbox(img, args, world_to_camera):
    """Draw the 12 edges of the voxel bounding box onto img as white lines."""
    corners = np.array([
        [args.x_min, args.y_min, args.z_min], [args.x_max, args.y_min, args.z_min],
        [args.x_max, args.y_max, args.z_min], [args.x_min, args.y_max, args.z_min],
        [args.x_min, args.y_min, args.z_max], [args.x_max, args.y_min, args.z_max],
        [args.x_max, args.y_max, args.z_max], [args.x_min, args.y_max, args.z_max],
    ])
    # Bottom face, top face, then the four vertical edges (pairs of corner indices).
    edges = [(0, 1), (1, 2), (2, 3), (3, 0), (4, 5), (5, 6), (6, 7), (7, 4),
             (0, 4), (1, 5), (2, 6), (3, 7)]
    corner_px = [_pixel(uv) for uv in project(corners, world_to_camera, IMAGE_SIZE)]
    for a, b in edges:
        if corner_px[a] is None or corner_px[b] is None:
            continue  # non-finite projection (e.g. a corner behind the camera)
        # cv2.line clips segments that leave the image, so off-screen endpoints are fine.
        cv2.line(img, corner_px[a], corner_px[b], (255, 255, 255), 2, cv2.LINE_AA)


def _draw_voxel_dots(img, args, world_to_camera):
    """Alpha-blend a strided subset of voxel centers onto img, colored blue (z_min) -> red (z_max)."""
    # NOTE(review): linspace includes both bbox faces, so these "centers" sit on the bbox
    # boundary and are spaced (max-min)/(n-1) apart. The printed cell size uses (max-min)/n.
    # Confirm which convention the voxelizer uses.
    xs = np.linspace(args.x_min, args.x_max, args.n_xy)
    ys = np.linspace(args.y_min, args.y_max, args.n_xy)
    zs = np.linspace(args.z_min, args.z_max, args.n_z)
    stride = args.draw_grid_stride
    xy_idx = np.arange(0, args.n_xy, stride)
    z_idx = np.arange(0, args.n_z, stride)

    # z-major, then x, then y: the same draw order as nested z/x/y loops.
    zz, xx, yy = np.meshgrid(zs[z_idx], xs[xy_idx], ys[xy_idx], indexing="ij")
    pts = np.stack([xx.ravel(), yy.ravel(), zz.ravel()], axis=1)
    z_norm = np.clip((pts[:, 2] - args.z_min) / (args.z_max - args.z_min + 1e-8), 0.0, 1.0)
    uvs = project(pts, world_to_camera, IMAGE_SIZE)

    # Draw on a separate layer, then blend only where dots exist so the scene shows through.
    h, w = img.shape[:2]
    dot_layer = np.zeros_like(img)
    mask_layer = np.zeros((h, w), dtype=np.uint8)
    dot_radius = 1 if stride <= 2 else 2
    for uv, zn in zip(uvs, z_norm):
        px = _pixel(uv)
        if px is None:
            continue
        iu, iv = px
        if not (0 <= iu < w and 0 <= iv < h):
            continue  # outside the image: skip rather than piling up on the border
        color = (int(255 * zn), 80, int(255 * (1 - zn)))
        if dot_radius == 1:
            dot_layer[iv, iu] = color
            mask_layer[iv, iu] = 1
        else:
            cv2.circle(dot_layer, (iu, iv), dot_radius, color, -1, cv2.LINE_AA)
            cv2.circle(mask_layer, (iu, iv), dot_radius, 1, -1)

    alpha = 0.55  # 55% dot + 45% scene
    blend = mask_layer.astype(bool)
    img[blend] = (alpha * dot_layer[blend] + (1 - alpha) * img[blend]).astype(np.uint8)


def _draw_eef_and_trajectory(img, obs, args, world_to_camera):
    """Mark the current EEF position (lime) and, optionally, the cached demo trajectory (cyan).

    Returns the EEF world position read from obs["robot0_eef_pos"].
    """
    eef_now = np.asarray(obs["robot0_eef_pos"], dtype=np.float64)
    lime = (180, 255, 0)
    home = _pixel(project(eef_now.reshape(1, 3), world_to_camera, IMAGE_SIZE)[0])
    if home is not None:
        cv2.circle(img, home, 8, lime, -1, cv2.LINE_AA)
        cv2.putText(img, "EEF home", (home[0] + 10, home[1] - 6),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.45, lime, 1, cv2.LINE_AA)

    if not args.show_trajectory:
        return eef_now
    traj_path = (Path(args.cache_root) / args.benchmark / f"task_{args.task_id}"
                 / f"demo_{args.demo_idx}" / "eef_pos.npy")
    if not traj_path.exists():
        return eef_now
    traj = np.load(traj_path)  # assumed (T, 3) world-frame EEF positions -- TODO confirm
    if len(traj) == 0:
        return eef_now
    # The canvas is RGB until the final RGB->BGR conversion, so cyan is (0, 255, 255).
    # The previous (255, 255, 0) rendered yellow despite the documented "cyan".
    cyan = (0, 255, 255)
    traj_px = [_pixel(uv) for uv in project(traj, world_to_camera, IMAGE_SIZE)]
    for prev, cur in zip(traj_px, traj_px[1:]):
        if prev is not None and cur is not None:
            cv2.line(img, prev, cur, cyan, 1, cv2.LINE_AA)
    if traj_px[-1] is not None:
        cv2.circle(img, traj_px[-1], 6, cyan, 1, cv2.LINE_AA)
    return eef_now


def _draw_caption(img, args):
    """Write the bbox bounds and grid resolution along the bottom edge (outlined text)."""
    label = (f"X[{args.x_min:.2f},{args.x_max:.2f}] "
             f"Y[{args.y_min:.2f},{args.y_max:.2f}] "
             f"Z[{args.z_min:.2f},{args.z_max:.2f}] "
             f"n_xy={args.n_xy} n_z={args.n_z}")
    org = (8, IMAGE_SIZE - 12)
    # Thick white pass first, then a thin dark pass on top, for legibility on any background.
    cv2.putText(img, label, org, cv2.FONT_HERSHEY_SIMPLEX, 0.45, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.putText(img, label, org, cv2.FONT_HERSHEY_SIMPLEX, 0.45, (20, 20, 20), 1, cv2.LINE_AA)


def main():
    """Render a LIBERO scene with the voxel bbox, sampled voxel centers and EEF overlays, and save it.

    Raises:
        IndexError: if --demo_idx is outside the demos in the task's HDF5 file.
        RuntimeError: if the output image cannot be written.
    """
    args = _parse_args()
    img, obs, world_to_camera = _render_scene(args)

    _draw_bbox(img, args, world_to_camera)
    _draw_voxel_dots(img, args, world_to_camera)
    eef_now = _draw_eef_and_trajectory(img, obs, args, world_to_camera)
    _draw_caption(img, args)

    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    # cv2.imwrite reports failure (e.g. an unsupported extension) only via its return value.
    if not cv2.imwrite(args.out, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)):
        raise RuntimeError(f"cv2.imwrite failed to write {args.out}")
    print(f"Saved: {args.out}")
    print(f"voxel grid: {args.n_xy}×{args.n_xy}×{args.n_z} = {args.n_xy*args.n_xy*args.n_z:,} voxels")
    print(f"cell size: dx={(args.x_max-args.x_min)/args.n_xy*1000:.1f}mm  "
          f"dy={(args.y_max-args.y_min)/args.n_xy*1000:.1f}mm  "
          f"dz={(args.z_max-args.z_min)/args.n_z*1000:.1f}mm")
    print(f"EEF home (world): {eef_now.tolist()}")


# Script entry point: only run the render when executed directly, not when imported.
if __name__ == "__main__":
    main()
