"""
Generate a 3D method visualization for PARA:
  1. Start at the policy camera view (what the robot sees)
  2. Smooth pull-back to an observer camera, showing the camera frustum
  3. Animate a ray through the GT target pixel with height bins colored by heatmap
  4. Fire more rays, then reveal the full heatmap volume
  5. Highlight the GT 3D point

Usage:
    export PYTHONPATH=/data/cameron/LIBERO:/data/cameron/para_normalized_losses/libero:$PYTHONPATH
    export DINO_REPO_DIR=/data/cameron/keygrip/dinov3
    export DINO_WEIGHTS_PATH=/data/cameron/keygrip/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth
    python ood_libero/generate_method_visualization.py
"""

import argparse
import os
import sys
import numpy as np
import torch
import torch.nn.functional as F
import cv2
from pathlib import Path
from scipy.spatial.transform import Rotation as R

# ── LIBERO imports ──
from libero.libero.envs import OffScreenRenderEnv
from libero.libero import benchmark as bm_lib, get_libero_path
from robosuite.utils.camera_utils import (
    get_camera_transform_matrix,
    get_camera_extrinsic_matrix,
    get_camera_intrinsic_matrix,
    project_points_from_world_to_camera,
)
import h5py

# ── PARA model imports ──
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'libero'))
import model as model_module
from model import TrajectoryHeatmapPredictor, N_HEIGHT_BINS, PRED_SIZE
from utils import recover_3d_from_direct_keypoint_and_height, _unproject_2d_to_ray

# ── Constants ──
# Side length (pixels) of the square images rendered by the simulator; also
# used as both camera height and width for every projection in this script.
IMAGE_SIZE = 448
# Per-channel mean/std applied to the [0, 1]-scaled policy image before it is
# fed to the model.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
# Frame rate of the output video; the phase comments below assume this value.
FPS = 30
# World-frame z of the horizontal plane where visualized rays are terminated
# and where the policy camera's look-at point is computed.
# NOTE(review): presumably the table-top height in meters — confirm.
TABLE_Z = 0.85

# ── Animation timing (in frames at 30fps) ──
# Each PHASE_* is a (start_frame, end_frame) pair. The phases are contiguous and
# the frame loop tests `frame_idx < PHASE[1]` / `frame_idx >= PHASE[0]`, i.e.
# the end frame is exclusive.
# Start at policy camera (B), smoothly transition to observer (A), frustum at every frame
PHASE_TRANSITION = (0, 90)         # 3s: camera moves from policy→observer, re-render each frame
PHASE_ESTABLISH = (90, 120)        # 1s: hold at observer with frustum
PHASE_FIRST_RAY = (120, 195)      # 2.5s: target ray with height bins
PHASE_MORE_RAYS = (195, 225)      # 1s: additional sample rays
PHASE_FULL_VOLUME = (225, 270)    # 1.5s: full 16x16 volume appears ALL AT ONCE
PHASE_HIGHLIGHT = (270, 310)      # 1.3s: GT grasp highlight
PHASE_FADE = (310, 340)           # 1s: volume/rays fade out, only argmax dot remains
PHASE_SERVO = (340, 430)          # 3s: robot servos to argmax target
# Total number of frames written; equals the end of the last phase.
TOTAL_FRAMES = 430


# ===========================================================================
# Geometry helpers
# ===========================================================================

def smoothstep(t):
    """Cubic ease-in/ease-out curve.

    The input is clamped to [0, 1] and mapped through ``x² · (3 − 2x)``, which
    is 0 at x=0, 1 at x=1 and has zero slope at both ends.

    Args:
        t: scalar or array of interpolation parameters (any range).

    Returns:
        Eased value(s) in [0, 1], same shape as ``t``.
    """
    x = np.clip(t, 0, 1)
    blend = 3 - 2 * x
    return x * x * blend


def compute_frustum_corners(camera_pose, cam_K, image_size, depth):
    """World-space corners of the image rectangle placed ``depth`` in front of the camera.

    Each image corner is back-projected through the inverse intrinsics, rescaled
    so its camera-frame z equals ``depth``, and then moved into the world frame
    with the rotation/translation taken from ``camera_pose``.

    Args:
        camera_pose: (4, 4) camera-to-world transform (rotation in [:3, :3],
            camera position in [:3, 3]).
        cam_K: (3, 3) pixel-space intrinsic matrix.
        image_size: side length of the square image in pixels.
        depth: camera-frame z at which the rectangle is placed.

    Returns:
        corners_world: (4, 3) array — TL, TR, BR, BL
    """
    rotation = camera_pose[:3, :3]
    origin = camera_pose[:3, 3]
    intrinsics_inv = np.linalg.inv(cam_K)

    side = image_size
    pixel_corners = np.array(
        [[0, 0, 1], [side, 0, 1], [side, side, 1], [0, side, 1]],
        dtype=np.float64,
    )

    def to_world(pixel_h):
        # Back-project, then stretch the ray so that it ends at z == depth.
        direction = intrinsics_inv @ pixel_h
        at_depth = direction / direction[2] * depth
        return rotation @ at_depth + origin

    return np.array([to_world(pixel_h) for pixel_h in pixel_corners])


def compute_ray_points(pixel_uv, camera_pose, cam_K, min_height, max_height, n_bins):
    """Compute 3D positions of height bins along the ray through a pixel.

    The ray is obtained from ``_unproject_2d_to_ray``. For each of ``n_bins``
    heights spaced linearly from ``min_height`` to ``max_height`` (inclusive),
    the ray is intersected with the horizontal plane ``z == height``.
    Intersections that are not strictly in front of the ray origin are dropped,
    and no bins are produced when the ray is (nearly) parallel to those planes.

    NOTE(review): because bins can be dropped, row ``k`` of ``bin_points`` is
    only guaranteed to correspond to height bin ``k`` when nothing was
    dropped — callers that index per-bin probabilities by row rely on that.

    Args:
        pixel_uv: pixel coordinate forwarded to ``_unproject_2d_to_ray``.
        camera_pose: camera pose forwarded to ``_unproject_2d_to_ray``.
        cam_K: camera intrinsics forwarded to ``_unproject_2d_to_ray``.
        min_height: world z of the first height bin.
        max_height: world z of the last height bin.
        n_bins: number of height bins.

    Returns:
        bin_points: (k, 3) array of 3D world positions with ``k <= n_bins``;
            always 2-D, i.e. shape (0, 3) when no bin is visible.
        ray_start: (3,) copy of the ray origin (camera position).
        ray_end: (3,) intersection of the ray with the plane ``z == TABLE_Z``,
            or a point 2.0 ray-direction lengths along the ray when the ray
            is (nearly) horizontal.
    """
    cam_pos, ray_dir = _unproject_2d_to_ray(pixel_uv, camera_pose, cam_K)
    dz = ray_dir[2]

    bin_points = []
    # The parallel-ray test does not depend on the bin, so it is done once.
    if not abs(dz) < 1e-6:
        denom = max(n_bins - 1, 1)  # avoids 0/0 when there is a single bin
        for i in range(n_bins):
            h = min_height + (i / denom) * (max_height - min_height)
            t = (h - cam_pos[2]) / dz
            if t > 0:  # keep only intersections in front of the camera
                bin_points.append(cam_pos + t * ray_dir)

    # Ray endpoint at table height
    if abs(dz) > 1e-6:
        t_table = (TABLE_Z - cam_pos[2]) / dz
        ray_end = cam_pos + t_table * ray_dir
    else:
        ray_end = cam_pos + ray_dir * 2.0

    # reshape keeps the documented (k, 3) layout even when k == 0
    # (np.array([]) alone would have shape (0,)).
    return np.array(bin_points).reshape(-1, 3), cam_pos.copy(), ray_end


def project_3d_to_2d(points_3d, world_to_camera, image_size):
    """Project 3D world points to 2D pixel coordinates in observer view.

    Points that are not strictly in front of the camera are rejected before the
    perspective divide; without this check they would be divided by a zero or
    negative depth and land on mirrored/garbage pixels.

    Args:
        points_3d: (N, 3) array-like of world points
        world_to_camera: (4, 4) world-to-pixel transform. The third row is
            assumed to yield the camera-frame depth (the value the projection
            divides by), as for robosuite's ``get_camera_transform_matrix``.
        image_size: int, side length of the square image

    Returns:
        pixels: list with one entry per input point — a ``(u, v)`` tuple of
            ints (column, row), or ``None`` when the point is behind / at the
            camera plane or its pixel falls outside the image.

    NOTE(review): depending on the robosuite version,
    ``project_points_from_world_to_camera`` may already clamp pixels into the
    image, in which case the bounds check below never rejects anything —
    confirm against the installed version.
    """
    min_depth = 1e-6
    w2c = np.asarray(world_to_camera, dtype=np.float64)
    pts = np.asarray(points_3d, dtype=np.float64).reshape(-1, 3)

    # Homogeneous third component of w2c @ [x, y, z, 1] for every point.
    depths = pts @ w2c[2, :3] + w2c[2, 3]

    results = []
    for pt, depth in zip(pts, depths):
        if not depth > min_depth:  # also rejects NaN depths
            results.append(None)
            continue
        pix_rc = project_points_from_world_to_camera(
            points=pt.reshape(1, 3),
            world_to_camera_transform=w2c,
            camera_height=image_size,
            camera_width=image_size,
        )[0]
        # robosuite returns (row, col); OpenCV drawing wants (col, row).
        u = int(round(float(pix_rc[1])))  # col
        v = int(round(float(pix_rc[0])))  # row
        if 0 <= u < image_size and 0 <= v < image_size:
            results.append((u, v))
        else:
            results.append(None)
    return results


def draw_line_3d(frame, p1_3d, p2_3d, world_to_camera, image_size, color, thickness=2, alpha=1.0):
    """Project two world points and draw the segment between them on ``frame``.

    Nothing is drawn unless both endpoints project to a valid pixel. With
    ``alpha`` below 1.0 the line is drawn on a copy of the frame which is then
    blended back into ``frame``; otherwise it is drawn directly. ``frame`` is
    modified in place.
    """
    start_px, end_px = project_3d_to_2d(np.array([p1_3d, p2_3d]), world_to_camera, image_size)
    if start_px is None or end_px is None:
        return

    # Draw on a scratch copy only when translucency is requested.
    canvas = frame.copy() if alpha < 1.0 else frame
    cv2.line(canvas, start_px, end_px, color, thickness, cv2.LINE_AA)
    if canvas is not frame:
        cv2.addWeighted(canvas, alpha, frame, 1 - alpha, 0, frame)


def draw_circle_3d(frame, pt_3d, world_to_camera, image_size, color, radius=5, thickness=-1, alpha=1.0):
    """Project a world point and draw a circle at its pixel location on ``frame``.

    Nothing is drawn when the point has no valid projection. With ``alpha``
    below 0.95 the circle is drawn on a copy of the frame which is then blended
    back into ``frame``; at 0.95 and above it is drawn fully opaque. ``frame``
    is modified in place.
    """
    center = project_3d_to_2d(np.array([pt_3d]), world_to_camera, image_size)[0]
    if center is None:
        return

    # Near-opaque circles skip the (whole-frame) blend and are drawn directly.
    canvas = frame.copy() if alpha < 0.95 else frame
    cv2.circle(canvas, center, radius, color, thickness, cv2.LINE_AA)
    if canvas is not frame:
        cv2.addWeighted(canvas, alpha, frame, 1 - alpha, 0, frame)


# ===========================================================================
# Colormap for heatmap bins
# ===========================================================================

def heatmap_color(value, alpha=1.0):
    """Look up the OpenCV plasma colormap colour for ``value``.

    ``value`` is scaled by 255, truncated to an integer and clamped to
    [0, 255] before the lookup, so inputs outside [0, 1] saturate.

    Args:
        value: scalar, nominally in [0, 1].
        alpha: accepted for interface compatibility; not used by this function.

    Returns:
        (B, G, R) tuple of Python ints, ready for OpenCV drawing calls.
    """
    level = np.clip(int(value * 255), 0, 255)
    # applyColorMap works on images, so look the colour up via a 1x1 image.
    swatch = np.array([[level]], dtype=np.uint8)
    blue, green, red = cv2.applyColorMap(swatch, cv2.COLORMAP_PLASMA)[0, 0]
    return (int(blue), int(green), int(red))


# ===========================================================================
# Floating image plane rendering
# ===========================================================================

def render_floating_image(frame, policy_image, frustum_corners_2d, alpha=0.7):
    """Composite the policy image onto ``frame`` as a perspective-warped quad.

    The full ``policy_image`` rectangle is mapped onto the quadrilateral given
    by ``frustum_corners_2d`` and alpha-blended over ``frame`` inside that
    quadrilateral only. Does nothing if any corner is ``None``.

    Args:
        frame: (H, W, 3) uint8 observer image, modified in-place
        policy_image: (h, w, 3) uint8 image; it is blended as-is (no colour
            conversion), so its channel order should match ``frame``
        frustum_corners_2d: 4 (u, v) pixel tuples in ``frame`` — TL, TR, BR, BL
        alpha: opacity of the warped image inside the quad
    """
    for corner in frustum_corners_2d:
        if corner is None:
            return

    src_h, src_w = policy_image.shape[:2]
    out_h, out_w = frame.shape[:2]

    quad_src = np.array([[0, 0], [src_w, 0], [src_w, src_h], [0, src_h]], dtype=np.float32)
    quad_dst = np.array(frustum_corners_2d, dtype=np.float32)

    homography = cv2.getPerspectiveTransform(quad_src, quad_dst)
    warped = cv2.warpPerspective(
        policy_image, homography, (out_w, out_h),
        flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT,
    )

    # Binary coverage mask of the destination quad (1.0 inside, 0.0 outside).
    coverage = np.zeros((out_h, out_w), dtype=np.uint8)
    cv2.fillConvexPoly(coverage, quad_dst.astype(np.int32), 255)
    weight = coverage[:, :, None].astype(np.float32) / 255.0

    # Blend in float, then write the result back into the caller's buffer.
    base = frame.astype(np.float32)
    overlay = warped.astype(np.float32)
    mixed = base * (1 - weight * alpha) + overlay * weight * alpha
    np.copyto(frame, mixed.astype(np.uint8))


# ===========================================================================
# Camera interpolation
# ===========================================================================

def interpolate_camera(pos_start, quat_start, pos_end, quat_end, t):
    """Blend between two camera poses with smoothstep easing.

    ``t`` is first eased with ``smoothstep``; the eased value drives a linear
    blend of the positions and a spherical linear interpolation (SLERP) of the
    orientations.

    Args:
        pos_start/end: (3,) camera positions
        quat_start/end: (4,) quaternions in (w, x, y, z) order
        t: interpolation factor [0, 1]

    Returns:
        pos, quat: interpolated position and (w, x, y, z) quaternion
    """
    from scipy.spatial.transform import Slerp

    eased = smoothstep(t)
    position = pos_start * (1 - eased) + pos_end * eased

    def as_scipy_rotation(quat_wxyz):
        # scipy expects scalar-last (x, y, z, w) quaternions.
        w, x, y, z = quat_wxyz[0], quat_wxyz[1], quat_wxyz[2], quat_wxyz[3]
        return R.from_quat(np.array([x, y, z, w]))

    keyframes = R.concatenate([as_scipy_rotation(quat_start), as_scipy_rotation(quat_end)])
    qx, qy, qz, qw = Slerp([0, 1], keyframes)(eased).as_quat()

    # Back to scalar-first order for the caller.
    return position, np.array([qw, qx, qy, qz])


def compute_observer_camera(policy_pos, policy_quat, look_at,
                            azimuth_deg=50.0, elevation_deg=10.0, distance_scale=2.2):
    """Compute an observer camera pose that views the policy camera and scene from the side.

    A focus point is placed between the policy camera and ``look_at`` (x/y at
    the midpoint, z at 0.45 of the summed heights, i.e. slightly below the
    midpoint). The observer sits on the line from the focus point towards the
    policy camera, after that direction has been

      1. rotated by ``azimuth_deg`` about world +Z, and
      2. rotated by ``elevation_deg`` about the horizontal axis ``Z × direction``
         (positive values reduce the elevation angle of the observer as seen
         from the focus point; skipped if the direction is vertical),

    at ``distance_scale`` times the focus-to-policy-camera distance. The
    returned orientation looks at the focus point with world +Z as up; the
    camera frame has x = right, y = up, z = backward.

    Args:
        policy_pos: (3,) policy camera position in world coordinates.
        policy_quat: policy camera quaternion; accepted for interface
            compatibility, not used in the computation.
        look_at: (3,) world point the policy camera is looking at.
        azimuth_deg: rotation about world Z in degrees (default 50).
        elevation_deg: tilt in degrees as described above (default 10).
        distance_scale: observer distance as a multiple of the
            focus-to-policy-camera distance (default 2.2).

    Returns:
        obs_pos: (3,) observer position.
        obs_quat: (4,) observer orientation as a (w, x, y, z) quaternion.
    """
    policy_pos = np.asarray(policy_pos, dtype=np.float64)
    look_at = np.asarray(look_at, dtype=np.float64)
    world_up = np.array([0, 0, 1.0])

    # Focus point between camera and scene so that both stay in frame.
    focus = (policy_pos + look_at) * 0.5
    focus[2] = (policy_pos[2] + look_at[2]) * 0.45  # slightly below midpoint

    # Start from the direction focus -> policy camera.
    offset = policy_pos - focus
    offset_len = np.linalg.norm(offset)
    radius = offset_len * distance_scale
    direction = offset / offset_len

    # Azimuth rotation around world Z
    phi_rad = np.radians(azimuth_deg)
    cos_phi, sin_phi = np.cos(phi_rad), np.sin(phi_rad)
    rot_z = np.array([
        [cos_phi, -sin_phi, 0],
        [sin_phi, cos_phi, 0],
        [0, 0, 1],
    ])
    direction = rot_z @ direction

    # Elevation rotation about the horizontal axis perpendicular to `direction`.
    tilt_axis = np.cross(world_up, direction)
    tilt_norm = np.linalg.norm(tilt_axis)
    if tilt_norm > 1e-6:
        tilt_axis = tilt_axis / tilt_norm
        rot_elev = R.from_rotvec(np.radians(elevation_deg) * tilt_axis).as_matrix()
        direction = rot_elev @ direction

    obs_pos = focus + radius * direction

    # Look-at orientation towards the focus point.
    forward = focus - obs_pos
    forward = forward / np.linalg.norm(forward)
    right = np.cross(forward, world_up)
    if np.linalg.norm(right) < 1e-6:
        # Looking straight up/down: pick an arbitrary horizontal right vector.
        right = np.array([1, 0, 0.0])
    right = right / np.linalg.norm(right)
    up = np.cross(right, forward)
    up = up / np.linalg.norm(up)

    # Columns are the camera axes in world coordinates (camera looks along -z).
    cam_mat = np.stack([right, up, -forward], axis=1)
    qx, qy, qz, qw = R.from_matrix(cam_mat).as_quat()  # scipy is scalar-last
    obs_quat = np.array([qw, qx, qy, qz])

    return obs_pos, obs_quat


# ===========================================================================
# Drawing functions for each animation phase
# ===========================================================================

def draw_frustum(frame, cam_origin, frustum_corners, world_to_camera, image_size, alpha=1.0):
    """Draw a camera frustum as a wireframe on ``frame``.

    Four edges run from ``cam_origin`` to the image-plane corners, and four
    white edges outline the image-plane rectangle. All lines are 3 px wide and
    share the given ``alpha``.
    """
    edge_color = (220, 240, 255)
    outline_color = (255, 255, 255)
    line_width = 3

    # Edges from the camera centre to each corner of the image plane.
    for corner in frustum_corners:
        draw_line_3d(frame, cam_origin, corner, world_to_camera, image_size,
                     edge_color, line_width, alpha)

    # Closed outline of the image plane: 0-1, 1-2, 2-3, 3-0.
    for first, second in ((0, 1), (1, 2), (2, 3), (3, 0)):
        draw_line_3d(frame, frustum_corners[first], frustum_corners[second],
                     world_to_camera, image_size, outline_color, line_width, alpha)


def draw_ray_with_bins(frame, cam_origin, ray_end, bin_points, bin_probs,
                       world_to_camera, image_size, ray_color=(0, 200, 255),
                       n_bins_to_show=None, highlight_max=False, ray_alpha=0.7,
                       bin_radius=7):
    """Draw one camera ray plus a dot for each height bin along it.

    The ray is a 2 px line from ``cam_origin`` to ``ray_end``. The first
    ``n_bins_to_show`` bins (all of them when ``None``) are drawn as filled
    circles of equal radius. Each bin's probability is normalised by the
    largest entry of ``bin_probs`` and square-rooted; that value selects the
    colour via ``heatmap_color`` and the opacity (0.1 for the weakest up to 1.0
    for the strongest), so faint bins do not hide strong ones. Bins are drawn
    in order of increasing probability, putting the strongest on top.

    With ``highlight_max`` set, the most probable bin is instead drawn as a
    larger solid dot with a white ring.
    """
    # Draw the ray line
    draw_line_3d(frame, cam_origin, ray_end, world_to_camera, image_size,
                 ray_color, 2, ray_alpha)

    n_points = len(bin_points)
    if n_bins_to_show is None:
        n_bins_to_show = n_points

    if n_points > 0:
        max_idx = np.argmax(bin_probs[:n_points])
    else:
        max_idx = -1
    p_max = max(bin_probs.max(), 1e-8)  # floor avoids dividing by ~0

    def prob_of(idx):
        # Bins without a matching probability entry count as zero.
        return bin_probs[idx] if idx < len(bin_probs) else 0.0

    # Weakest first, so stronger bins are painted over them.
    draw_order = sorted(range(min(n_bins_to_show, n_points)), key=prob_of)

    for idx in draw_order:
        prob_vis = np.power(prob_of(idx) / p_max, 0.5)
        color = heatmap_color(prob_vis)
        center = bin_points[idx]

        if highlight_max and idx == max_idx:
            draw_circle_3d(frame, center, world_to_camera, image_size,
                          (255, 255, 255), bin_radius + 5, 3)
            draw_circle_3d(frame, center, world_to_camera, image_size,
                          (0, 255, 200), bin_radius + 2, -1)
            continue

        # Opacity ramps from 0.1 (lowest probability) to 1.0 (highest).
        draw_circle_3d(frame, center, world_to_camera, image_size,
                      color, bin_radius, -1, alpha=0.1 + 0.9 * prob_vis)


def draw_gt_marker(frame, gt_3d, world_to_camera, image_size, pulse_t=0.0):
    """Draw a target marker on ``frame`` at the projection of ``gt_3d``.

    The marker is a double outer ring whose radius oscillates sinusoidally with
    ``pulse_t`` (one full cycle per unit of ``pulse_t``), a filled centre dot
    and a white crosshair. Nothing is drawn when the point has no valid
    projection.
    """
    center = project_3d_to_2d(np.array([gt_3d]), world_to_camera, image_size)[0]
    if center is None:
        return

    ring_color = (0, 255, 100)
    # Radius swings by +/-30% around 18 px.
    pulse = 1.0 + 0.3 * np.sin(pulse_t * 2 * np.pi)
    ring_radius = int(18 * pulse)

    cv2.circle(frame, center, ring_radius, ring_color, 2, cv2.LINE_AA)
    cv2.circle(frame, center, ring_radius + 3, (0, 180, 70), 1, cv2.LINE_AA)
    cv2.circle(frame, center, 6, ring_color, -1, cv2.LINE_AA)
    cv2.drawMarker(frame, center, (255, 255, 255), cv2.MARKER_CROSS, 28, 2, cv2.LINE_AA)


def draw_camera_icon(frame, cam_pos, world_to_camera, image_size):
    """Draw a small camera glyph on ``frame`` at the projection of ``cam_pos``.

    The glyph is a 24x18 px light rectangle with a white border and a dark
    circular "lens" in the middle. Nothing is drawn when the point has no valid
    projection.
    """
    center = project_3d_to_2d(np.array([cam_pos]), world_to_camera, image_size)[0]
    if center is None:
        return

    cx, cy = center
    top_left = (cx - 12, cy - 9)
    bottom_right = (cx + 12, cy + 9)

    # Body: filled rectangle, then its outline.
    cv2.rectangle(frame, top_left, bottom_right, (220, 220, 220), -1)
    cv2.rectangle(frame, top_left, bottom_right, (255, 255, 255), 2)
    # Lens: dark disc with a grey rim.
    cv2.circle(frame, (cx, cy), 5, (40, 40, 40), -1)
    cv2.circle(frame, (cx, cy), 5, (100, 100, 100), 1)


# ===========================================================================
# Text overlay
# ===========================================================================

def add_text(frame, text, position="bottom", font_scale=0.65, thickness=2, color=(255, 255, 255)):
    """Draw ``text`` on ``frame`` over a solid black backing box.

    Args:
        frame: image to draw on, modified in-place
        text: string to render
        position: "bottom" or "top" (horizontally centred), "top-left", or an
            explicit (x, y) pair giving the text's bottom-left baseline origin
        font_scale: OpenCV font scale
        thickness: stroke thickness of the glyphs
        color: text colour
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    frame_h, frame_w = frame.shape[:2]
    (text_w, text_h), baseline = cv2.getTextSize(text, font, font_scale, thickness)

    if position == "bottom":
        x = (frame_w - text_w) // 2
        y = frame_h - 25
    elif position == "top":
        x = (frame_w - text_w) // 2
        y = text_h + 15
    elif position == "top-left":
        x = 15
        y = text_h + 15
    else:
        x, y = position

    # Backing box with a 6 px margin around the text extents.
    pad = 6
    box_min = (x - pad, y - text_h - pad)
    box_max = (x + text_w + pad, y + baseline + pad)
    cv2.rectangle(frame, box_min, box_max, (0, 0, 0), -1)
    cv2.putText(frame, text, (x, y), font, font_scale, color, thickness, cv2.LINE_AA)


# ===========================================================================
# Compute sample ray pixels for the volume visualization
# ===========================================================================

def get_sample_pixels(gt_pixel, image_size, pred_size, n_extra=4):
    """Pick a fixed set of pixels around a reference pixel for sample rays.

    The result starts with ``gt_pixel`` itself (the same object, not a copy),
    followed by up to ``n_extra`` pixels at fixed offsets from it. Offsets are
    fractions of ``image_size``; each resulting coordinate is clamped to
    [10, image_size - 10]. At most six extra pixels are available.

    Args:
        gt_pixel: (2,) reference pixel (u, v)
        image_size: side length of the square image in pixels
        pred_size: accepted for interface compatibility; not used
        n_extra: number of additional pixels requested

    Returns:
        list of pixels: ``gt_pixel`` followed by (2,) arrays.
    """
    lower, upper = 10, image_size - 10

    # Deterministic (du, dv) offsets as fractions of the image size.
    relative_offsets = (
        (-0.2, -0.15),   # upper left
        (0.15, -0.2),    # upper right
        (-0.15, 0.2),    # lower left
        (0.2, 0.15),     # lower right
        (0.0, -0.3),     # above
        (-0.3, 0.0),     # left
    )

    extras = [
        np.array([
            np.clip(gt_pixel[0] + du * image_size, lower, upper),
            np.clip(gt_pixel[1] + dv * image_size, lower, upper),
        ])
        for du, dv in relative_offsets[:n_extra]
    ]
    return [gt_pixel, *extras]


def get_volume_pixels(image_size, grid_size=16, gt_pixel=None):
    """Get a grid of pixels for the full volume visualization.

    Generates the centres of a ``grid_size`` x ``grid_size`` grid of equal
    cells covering the image, in row-major order (v outer, u inner).
    If ``gt_pixel`` is provided, the grid point nearest to it (the first one in
    row-major order on ties) is replaced by a copy of ``gt_pixel`` so that the
    ray through that exact pixel is always part of the grid.

    Args:
        image_size: side length of the square image in pixels
        grid_size: number of grid cells per side
        gt_pixel: optional (2,) array (u, v) to substitute into the grid

    Returns:
        list of ``grid_size**2`` arrays of shape (2,), each a (u, v) pixel.
    """
    stride = image_size / grid_size
    centers = (np.arange(grid_size) + 0.5) * stride
    pixels = [np.array([u, v]) for v in centers for u in centers]

    # Replace the nearest grid point with the exact GT pixel
    if gt_pixel is not None and pixels:
        deltas = np.stack(pixels) - gt_pixel
        # Squared distance has the same minimiser as the Euclidean norm.
        nearest = int(np.argmin(np.sum(deltas * deltas, axis=1)))
        pixels[nearest] = gt_pixel.copy()

    return pixels


# ===========================================================================
# Main
# ===========================================================================

def main():
    parser = argparse.ArgumentParser(description="PARA method 3D visualization")
    parser.add_argument("--output_dir", type=str,
                        default="/data/cameron/para/.agents/reports/project_site/media/")
    parser.add_argument("--checkpoint", type=str,
                        default="/data/cameron/para_normalized_losses/libero/checkpoints/para_v2_exp4_n64/best.pth")
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--demo_idx", type=int, default=0)
    parser.add_argument("--frame_idx", type=int, default=20,
                        help="Demo frame to visualize (pick one with clear target)")
    parser.add_argument("--image_plane_depth", type=float, default=0.45,
                        help="Depth of the floating image plane from camera (meters)")
    parser.add_argument("--volume_grid", type=int, default=16,
                        help="Grid size for full volume visualization (16 = 16x16 = 256 rays)")
    args = parser.parse_args()

    device = torch.device(args.device if args.device else
                          "cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # ── Load PARA model ──
    print("Loading PARA model...")
    ckpt = torch.load(args.checkpoint, map_location=device)
    # Set module-level constants from checkpoint
    model_module.MIN_HEIGHT = float(ckpt.get("min_height", model_module.MIN_HEIGHT))
    model_module.MAX_HEIGHT = float(ckpt.get("max_height", model_module.MAX_HEIGHT))
    model_module.MIN_GRIPPER = float(ckpt.get("min_gripper", model_module.MIN_GRIPPER))
    model_module.MAX_GRIPPER = float(ckpt.get("max_gripper", model_module.MAX_GRIPPER))
    if "min_rot" in ckpt:
        model_module.MIN_ROT = ckpt["min_rot"] if isinstance(ckpt["min_rot"], list) else ckpt["min_rot"].tolist()
        model_module.MAX_ROT = ckpt["max_rot"] if isinstance(ckpt["max_rot"], list) else ckpt["max_rot"].tolist()

    # Infer n_window from volume_head
    vol_w = ckpt["model_state_dict"]["volume_head.weight"]
    n_window = vol_w.shape[0] // N_HEIGHT_BINS
    print(f"  N_WINDOW={n_window}, height=[{model_module.MIN_HEIGHT:.4f}, {model_module.MAX_HEIGHT:.4f}]")

    model = TrajectoryHeatmapPredictor(n_window=n_window)
    model.load_state_dict(ckpt["model_state_dict"], strict=False)
    model = model.to(device).eval()

    # ── Initialize LIBERO environment ──
    print("Initializing LIBERO environment...")
    benchmark = bm_lib.get_benchmark_dict()["libero_spatial"]()
    task = benchmark.get_task(0)
    bddl_file = os.path.join(get_libero_path("bddl_files"),
                              task.problem_folder, task.bddl_file)

    env = OffScreenRenderEnv(
        bddl_file_name=bddl_file,
        camera_heights=IMAGE_SIZE,
        camera_widths=IMAGE_SIZE,
        camera_names=["agentview"],
    )
    env.seed(0)
    env.reset()
    sim = env.env.sim

    # ── Load demo state ──
    demo_path = os.path.join(get_libero_path("datasets"),
                              benchmark.get_task_demonstration(0))
    with h5py.File(demo_path, "r") as f:
        demo_key = f"data/demo_{args.demo_idx}"
        states = np.array(f[f"{demo_key}/states"])
        actions = np.array(f[f"{demo_key}/actions"])
        state = states[args.frame_idx].copy()
        # Find grasp frame (first gripper close transition)
        gripper_actions = actions[:, 6]
        grasp_frame = args.frame_idx  # fallback
        for t in range(1, len(gripper_actions)):
            if gripper_actions[t - 1] < 0 and gripper_actions[t] > 0:
                grasp_frame = t
                break
        print(f"  Grasp frame: {grasp_frame}")
        obs = env.set_init_state(state)
        sim.forward()

    # ── Clean scene ──
    print("Cleaning scene...")
    # Hide furniture
    for name in ["wooden_cabinet_1_main", "flat_stove_1_main"]:
        try:
            bid = sim.model.body_name2id(name)
            sim.model.body_pos[bid] = np.array([0, 0, -5.0])
        except Exception:
            pass
    # Hide distractors
    distractor_names = ["akita_black_bowl_2_main", "cookies_1_main",
                        "glazed_rim_porcelain_ramekin_1_main"]
    distractor_bids = set()
    for name in distractor_names:
        try:
            distractor_bids.add(sim.model.body_name2id(name))
        except Exception:
            pass
    for geom_id in range(sim.model.ngeom):
        if sim.model.geom_bodyid[geom_id] in distractor_bids:
            sim.model.geom_rgba[geom_id][3] = 0.0
    sim.forward()

    # ── Re-render after cleaning ──
    obs = env.set_init_state(state)
    sim.forward()
    obs = env.env._get_observations()

    # ── Get camera parameters ──
    cam_name = "agentview"
    cam_id = sim.model.camera_name2id(cam_name)
    policy_cam_pos = sim.model.cam_pos[cam_id].copy()
    policy_cam_quat = sim.model.cam_quat[cam_id].copy()

    world_to_camera = get_camera_transform_matrix(sim, cam_name, IMAGE_SIZE, IMAGE_SIZE)
    camera_pose = get_camera_extrinsic_matrix(sim, cam_name)
    cam_K_norm = get_camera_intrinsic_matrix(sim, cam_name, IMAGE_SIZE, IMAGE_SIZE)
    cam_K_norm[0] /= IMAGE_SIZE
    cam_K_norm[1] /= IMAGE_SIZE
    cam_K = cam_K_norm.copy()
    cam_K[0] *= IMAGE_SIZE
    cam_K[1] *= IMAGE_SIZE

    print(f"  Policy camera pos: {policy_cam_pos}")

    # ── Get GT EEF position ──
    eef_pos = np.array(obs["robot0_eef_pos"], dtype=np.float64)
    print(f"  EEF position: {eef_pos}")

    # Get GT target: the grasp point (where gripper closes on the object)
    obs_grasp = env.set_init_state(states[grasp_frame])
    sim.forward()
    obs_grasp = env.env._get_observations()
    grasp_eef = np.array(obs_grasp["robot0_eef_pos"], dtype=np.float64)
    print(f"  Grasp EEF (frame {grasp_frame}): {grasp_eef}")
    grasp_height = grasp_eef[2]
    print(f"  Grasp height: {grasp_height:.4f} (range [{model_module.MIN_HEIGHT:.4f}, {model_module.MAX_HEIGHT:.4f}])")

    # Restore to visualization frame
    obs = env.set_init_state(state)
    sim.forward()
    obs = env.env._get_observations()

    # ── Render policy camera image ──
    policy_rgb = np.flipud(np.asarray(obs["agentview_image"]).copy())  # training convention

    # ── Run PARA model inference ──
    print("Running PARA inference...")
    img = policy_rgb.astype(np.float32) / 255.0
    img = (img - IMAGENET_MEAN) / IMAGENET_STD
    img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).float().unsqueeze(0).to(device)

    # Get EEF start keypoint
    pix_rc = project_points_from_world_to_camera(
        eef_pos.reshape(1, 3), world_to_camera, IMAGE_SIZE, IMAGE_SIZE
    )[0]
    start_kp = torch.tensor([float(pix_rc[1]), float(pix_rc[0])], dtype=torch.float32).to(device)

    with torch.no_grad():
        volume_logits, _, _, feats = model(img_tensor, start_kp)
    # volume_logits: (1, N_WINDOW, N_HEIGHT_BINS, PRED_SIZE, PRED_SIZE)
    print(f"  Volume logits shape: {volume_logits.shape}")

    # ── Extract heatmap for timestep 0 ──
    vol_t0 = volume_logits[0, 0]  # (N_HEIGHT_BINS, PRED_SIZE, PRED_SIZE)
    # Softmax over full volume to get probabilities
    # Global softmax over full volume (correct model output)
    vol_probs = F.softmax(vol_t0.reshape(-1), dim=0).reshape(vol_t0.shape)

    # 2D heatmap: max over height bins
    heat_2d = vol_probs.max(dim=0)[0].cpu().numpy()  # (PRED_SIZE, PRED_SIZE)
    heat_2d_upsampled = cv2.resize(heat_2d, (IMAGE_SIZE, IMAGE_SIZE))

    # Find predicted pixel (argmax of 2D heatmap)
    flat_idx = heat_2d.argmax()
    pred_py, pred_px = flat_idx // PRED_SIZE, flat_idx % PRED_SIZE
    scale = IMAGE_SIZE / PRED_SIZE
    pred_pixel = np.array([(pred_px + 0.5) * scale, (pred_py + 0.5) * scale])
    print(f"  Predicted pixel: ({pred_pixel[0]:.1f}, {pred_pixel[1]:.1f})")

    # Height probabilities at predicted pixel
    height_probs = vol_probs[:, pred_py, pred_px].cpu().numpy()
    print(f"  Max height bin: {np.argmax(height_probs)}, prob: {height_probs.max():.4f}")

    # ── Compute 3D geometry ──
    print("Computing 3D geometry...")

    # Frustum corners
    frustum_corners = compute_frustum_corners(camera_pose, cam_K, IMAGE_SIZE, args.image_plane_depth)

    # Ray through predicted pixel
    gt_ray_bins, ray_start, ray_end = compute_ray_points(
        pred_pixel, camera_pose, cam_K,
        model_module.MIN_HEIGHT, model_module.MAX_HEIGHT, N_HEIGHT_BINS
    )
    print(f"  Ray: {len(gt_ray_bins)} height bins, start={ray_start}, end={ray_end}")

    # Additional sample rays
    sample_pixels = get_sample_pixels(pred_pixel, IMAGE_SIZE, PRED_SIZE, n_extra=4)
    sample_rays = []
    for px in sample_pixels[1:]:  # skip GT pixel (already computed)
        px_pred = np.array([px[0] / scale - 0.5, px[1] / scale - 0.5])
        px_pred = np.clip(px_pred, 0, PRED_SIZE - 1).astype(int)
        h_probs = vol_probs[:, px_pred[1], px_pred[0]].cpu().numpy()
        bins, rs, re = compute_ray_points(px, camera_pose, cam_K,
                                          model_module.MIN_HEIGHT, model_module.MAX_HEIGHT, N_HEIGHT_BINS)
        sample_rays.append((bins, rs, re, h_probs, px))

    # Full volume rays
    volume_pixels = get_volume_pixels(IMAGE_SIZE, grid_size=16, gt_pixel=pred_pixel)
    volume_rays = []
    for px in volume_pixels:
        px_pred = np.array([px[0] / scale - 0.5, px[1] / scale - 0.5])
        px_pred = np.clip(px_pred, 0, PRED_SIZE - 1).astype(int)
        h_probs = vol_probs[:, px_pred[1], px_pred[0]].cpu().numpy()
        bins, rs, re = compute_ray_points(px, camera_pose, cam_K,
                                          model_module.MIN_HEIGHT, model_module.MAX_HEIGHT, N_HEIGHT_BINS)
        if len(bins) > 0:
            volume_rays.append((bins, rs, re, h_probs, px))

    # ── Compute observer camera ──
    # Look-at point: where the policy camera is looking (table intersection)
    default_rot = R.from_quat([policy_cam_quat[1], policy_cam_quat[2],
                                policy_cam_quat[3], policy_cam_quat[0]])
    default_mat = default_rot.as_matrix()
    forward_dir = -default_mat[:, 2]

    if abs(forward_dir[2]) > 1e-6:
        t_table = (TABLE_Z - policy_cam_pos[2]) / forward_dir[2]
        scene_center = policy_cam_pos + t_table * forward_dir
    else:
        scene_center = policy_cam_pos + 0.5 * forward_dir
        scene_center[2] = TABLE_Z

    observer_pos, observer_quat = compute_observer_camera(policy_cam_pos, policy_cam_quat, scene_center)
    print(f"  Observer camera pos: {observer_pos}")

    # ── Render policy view with heatmap overlay ──
    # Boost contrast: raise to power <1 to amplify low values
    heat_norm = (heat_2d_upsampled - heat_2d_upsampled.min()) / (heat_2d_upsampled.max() + 1e-8)
    heat_boosted = np.power(heat_norm, 0.4)  # gamma correction for visibility
    heat_color = cv2.applyColorMap((heat_boosted * 255).astype(np.uint8), cv2.COLORMAP_PLASMA)
    heat_color_rgb = cv2.cvtColor(heat_color, cv2.COLOR_BGR2RGB)
    policy_view_heatmap = np.clip(
        policy_rgb.astype(np.float32) * 0.4 + heat_color_rgb.astype(np.float32) * 0.6,
        0, 255
    ).astype(np.uint8)
    # Draw predicted pixel marker — bright green crosshair
    cv2.drawMarker(policy_view_heatmap, (int(pred_pixel[0]), int(pred_pixel[1])),
                   (0, 255, 0), cv2.MARKER_CROSS, 24, 3, cv2.LINE_AA)
    cv2.circle(policy_view_heatmap, (int(pred_pixel[0]), int(pred_pixel[1])),
               12, (0, 255, 0), 2, cv2.LINE_AA)

    # ── Generate frames ──
    print(f"Generating {TOTAL_FRAMES} frames at {FPS} fps...")
    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, "para_method_3d.mp4")

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(output_path, fourcc, FPS, (IMAGE_SIZE, IMAGE_SIZE))

    policy_rgb_bgr = cv2.cvtColor(policy_rgb, cv2.COLOR_RGB2BGR)
    policy_heatmap_bgr = cv2.cvtColor(policy_view_heatmap, cv2.COLOR_RGB2BGR)

    def render_at_camera(cam_p, cam_q):
        """Render the MuJoCo scene from an arbitrary camera pose.

        The pose of camera ``cam_id`` in the simulator model is overwritten
        with ``cam_p`` / ``cam_q``, the environment is put back into ``state``
        and a new set of observations is requested.

        Args:
            cam_p: Camera position written to ``sim.model.cam_pos[cam_id]``.
            cam_q: Camera quaternion written to ``sim.model.cam_quat[cam_id]``.

        Returns:
            Tuple ``(bgr_frame, w2c_matrix)``: the vertically flipped
            ``"agentview_image"`` observation converted from RGB to BGR, and
            the matrix returned by ``get_camera_transform_matrix`` for
            ``cam_name`` at ``IMAGE_SIZE`` x ``IMAGE_SIZE``.

        Note:
            Per the original author's note, calling ``_get_observations`` on
            its own returns cached results; the ``set_init_state`` +
            ``sim.forward`` sequence is what forces a fresh render. Keep the
            statement order below intact.
        """
        # The pose is written first so it is already in place when the state
        # restore below runs.
        sim.model.cam_pos[cam_id] = cam_p
        sim.model.cam_quat[cam_id] = cam_q

        # Restore the visualization state, then propagate it through the sim.
        env.set_init_state(state)
        sim.forward()

        # Fetch observations again so the image reflects the new camera pose.
        observations = env.env._get_observations()
        agentview = np.asarray(observations["agentview_image"]).copy()
        # Vertical flip (presumably the renderer returns the image upside
        # down — confirm against the renderer's convention).
        image_rgb = np.flipud(agentview)

        return (
            cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR),
            get_camera_transform_matrix(sim, cam_name, IMAGE_SIZE, IMAGE_SIZE),
        )

    # ── Pre-render observer view (for static phases after transition) ──
    observer_bgr, observer_w2c = render_at_camera(observer_pos, observer_quat)

    for frame_idx in range(TOTAL_FRAMES):
        if (frame_idx + 1) % 30 == 0:
            print(f"  Frame {frame_idx + 1}/{TOTAL_FRAMES}")

        # ─── Determine current camera position ───
        if frame_idx < PHASE_TRANSITION[1]:
            # Smooth transition from policy camera (B) → observer camera (A)
            t = frame_idx / max(PHASE_TRANSITION[1] - 1, 1)
            cur_pos, cur_quat = interpolate_camera(
                policy_cam_pos, policy_cam_quat,
                observer_pos, observer_quat, t
            )
            # Re-render MuJoCo at this intermediate camera position
            frame, cur_w2c = render_at_camera(cur_pos, cur_quat)
        elif frame_idx >= PHASE_SERVO[0]:
            # During servo: step robot toward grasp target, then render from observer
            if frame_idx == PHASE_SERVO[0]:
                # Initialize servo: reset env to the visualization state
                env.set_init_state(state)
                sim.forward()
                env.env.timestep = 0
                env.env.done = False
                env.env.horizon = 100000
            # Compute delta action toward grasp target
            cur_obs = env.env._get_observations()
            cur_eef = np.array(cur_obs["robot0_eef_pos"], dtype=np.float64)
            delta = grasp_eef - cur_eef
            dist = np.linalg.norm(delta)
            if dist > 0.005:  # only step if not yet at target
                delta_clipped = np.clip(delta / 0.05, -1.0, 1.0)
                action = np.zeros(7, dtype=np.float32)
                action[:3] = delta_clipped
                action[6] = -1.0  # gripper open during approach
                env.step(action)
            # Render from observer camera
            sim.model.cam_pos[cam_id] = observer_pos
            sim.model.cam_quat[cam_id] = observer_quat
            sim.forward()
            obs_servo = env.env._get_observations()
            frame_rgb = np.flipud(np.asarray(obs_servo["agentview_image"]).copy())
            frame = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
            cur_w2c = get_camera_transform_matrix(sim, cam_name, IMAGE_SIZE, IMAGE_SIZE)
        else:
            # Static at observer position — use pre-rendered frame
            frame = observer_bgr.copy()
            cur_w2c = observer_w2c

        # ─── Compute fade factor for overlay elements ───
        # 1.0 during normal phases, fades to 0.0 during PHASE_FADE, 0.0 during PHASE_SERVO
        if frame_idx < PHASE_FADE[0]:
            overlay_fade = 1.0
        elif frame_idx < PHASE_FADE[1]:
            overlay_fade = 1.0 - smoothstep((frame_idx - PHASE_FADE[0]) /
                                             max(PHASE_FADE[1] - PHASE_FADE[0], 1))
        else:
            overlay_fade = 0.0

        # ─── Frustum + floating image + camera icon (fade with overlay_fade) ───
        if overlay_fade > 0.01:
            fc_2d = project_3d_to_2d(frustum_corners, cur_w2c, IMAGE_SIZE)
            has_fc = all(c is not None for c in fc_2d)

            # Floating image plane
            if frame_idx < PHASE_TRANSITION[1]:
                t = frame_idx / max(PHASE_TRANSITION[1] - 1, 1)
                img_alpha = smoothstep(t) * 0.7
            else:
                img_alpha = 0.7
            img_alpha *= overlay_fade

            if has_fc and img_alpha > 0.05:
                pts = np.array([c for c in fc_2d], dtype=np.float32)
                area = cv2.contourArea(pts.astype(np.int32))
                if area > 500:
                    render_floating_image(frame, policy_heatmap_bgr, fc_2d, alpha=img_alpha)

            # Camera icon
            draw_camera_icon(frame, policy_cam_pos, cur_w2c, IMAGE_SIZE)

            # Frustum lines
            draw_frustum(frame, policy_cam_pos, frustum_corners, cur_w2c, IMAGE_SIZE, overlay_fade)

        # ─── Target ray with height bins (fades with overlay_fade) ───
        if frame_idx >= PHASE_FIRST_RAY[0] and overlay_fade > 0.01:
            ray_t = min(1.0, (frame_idx - PHASE_FIRST_RAY[0]) / (PHASE_FIRST_RAY[1] - PHASE_FIRST_RAY[0]))
            n_bins_show = max(1, int(smoothstep(ray_t) * len(gt_ray_bins)))
            highlight = frame_idx >= PHASE_HIGHLIGHT[0]
            draw_ray_with_bins(frame, policy_cam_pos, ray_end, gt_ray_bins, height_probs,
                              cur_w2c, IMAGE_SIZE, ray_color=(0, 230, 255),
                              n_bins_to_show=n_bins_show, highlight_max=highlight,
                              ray_alpha=0.8 * overlay_fade, bin_radius=8)

        # ─── More sample rays (fades with overlay_fade) ───
        if frame_idx >= PHASE_MORE_RAYS[0] and overlay_fade > 0.01:
            more_t = (frame_idx - PHASE_MORE_RAYS[0]) / (PHASE_MORE_RAYS[1] - PHASE_MORE_RAYS[0])
            n_rays_show = max(1, int(smoothstep(more_t) * len(sample_rays)))
            for i in range(min(n_rays_show, len(sample_rays))):
                bins, rs, re, hp, px = sample_rays[i]
                draw_ray_with_bins(frame, rs, re, bins, hp,
                                  cur_w2c, IMAGE_SIZE, ray_color=(0, 180, 220),
                                  highlight_max=False, ray_alpha=0.5 * overlay_fade, bin_radius=6)

        # ─── Full volume: 16x16 grid shown ALL AT ONCE, two-pass rendering ───
        # Also fades with overlay_fade
        # Pass 1: uniform purple cloud (low-prob dots, shows volume structure)
        # Pass 2: bright colored peaks (high-prob dots, shows where model predicts)
        if frame_idx >= PHASE_FULL_VOLUME[0] and overlay_fade > 0.01:
            vol_fade = smoothstep(min(1.0, (frame_idx - PHASE_FULL_VOLUME[0]) /
                                       max(PHASE_FULL_VOLUME[1] - PHASE_FULL_VOLUME[0], 1) * 2.0))
            PEAK_THRESH = 0.001  # fraction of global max to count as peak
            global_pmax = max((r[3].max() for r in volume_rays if len(r[0]) > 0), default=1e-8)

            cloud_2d = []  # low-prob bin positions (purple cloud)
            peak_2d = []   # high-prob bins (position, color)

            for bins, rs, re, hp, px in volume_rays:
                if len(bins) == 0:
                    continue
                # Show ALL height bins for this ray
                for bi in range(min(len(bins), len(hp))):
                    prob = hp[bi]
                    pt_2d = project_3d_to_2d(np.array([bins[bi]]), cur_w2c, IMAGE_SIZE)
                    if pt_2d[0] is None:
                        continue
                    pn = prob / global_pmax
                    if pn > PEAK_THRESH:
                        cv_val = int(np.clip(np.sqrt(pn) * 255, 0, 255))
                        ci = cv2.applyColorMap(np.array([[cv_val]], dtype=np.uint8), cv2.COLORMAP_PLASMA)
                        c = tuple(int(x) for x in ci[0, 0])
                        peak_2d.append((pt_2d[0], c))
                    else:
                        cloud_2d.append(pt_2d[0])

            # Pass 1: purple cloud (fades with overlay_fade)
            overlay = frame.copy()
            purple_bgr = (140, 50, 80)
            for pt in cloud_2d:
                cv2.circle(overlay, pt, 5, purple_bgr, -1, cv2.LINE_AA)
            blend = 0.3 * vol_fade * overlay_fade
            cv2.addWeighted(overlay, blend, frame, 1.0 - blend, 0, frame)

            # Pass 2: bright peaks on top (fades with overlay_fade)
            for pt, c in peak_2d:
                c_faded = tuple(int(x * overlay_fade) for x in c)
                cv2.circle(frame, pt, 7, c_faded, -1, cv2.LINE_AA)
                if overlay_fade > 0.5:
                    cv2.circle(frame, pt, 7, (255, 255, 255), 1, cv2.LINE_AA)

        # ─── GT grasp point highlight (visible from HIGHLIGHT through SERVO) ───
        if frame_idx >= PHASE_HIGHLIGHT[0]:
            pulse_t = (frame_idx - PHASE_HIGHLIGHT[0]) / FPS
            draw_gt_marker(frame, grasp_eef, cur_w2c, IMAGE_SIZE, pulse_t)

        # ─── Text overlays ───
        if frame_idx < PHASE_ESTABLISH[1]:
            add_text(frame, "Camera frustum over the scene", position="bottom")
        elif PHASE_FIRST_RAY[0] <= frame_idx < PHASE_MORE_RAYS[0]:
            add_text(frame, "Ray through predicted pixel: height bins colored by probability",
                     position="bottom", font_scale=0.48)
        elif PHASE_MORE_RAYS[0] <= frame_idx < PHASE_FULL_VOLUME[0]:
            add_text(frame, "Per-pixel height predictions along each ray", position="bottom",
                     font_scale=0.5)
        elif PHASE_FULL_VOLUME[0] <= frame_idx < PHASE_HIGHLIGHT[0]:
            add_text(frame, "Full heatmap volume", position="bottom")
        elif PHASE_HIGHLIGHT[0] <= frame_idx < PHASE_FADE[0]:
            add_text(frame, "Argmax -> 3D grasp target", position="bottom")
        elif PHASE_FADE[0] <= frame_idx < PHASE_SERVO[0]:
            add_text(frame, "Argmax -> 3D grasp target", position="bottom")
        elif frame_idx >= PHASE_SERVO[0]:
            add_text(frame, "Move robot to predicted target", position="bottom")

        writer.write(frame)

    writer.release()
    print(f"Saved raw video: {output_path}")

    # ── Re-encode to H.264 ──
    h264_path = output_path.replace(".mp4", "_h264.mp4")
    ret = os.system(
        f'ffmpeg -y -i "{output_path}" -c:v libx264 -preset ultrafast -crf 23 '
        f'-movflags +faststart "{h264_path}" 2>/dev/null'
    )
    if ret == 0:
        os.replace(h264_path, output_path)
        print(f"Re-encoded to H.264: {output_path}")

    env.close()
    print("Done!")


# Script entry point: call main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
