"""Test trajectory predictions with 3D lifting and IK visualization."""
import sys
import os
from pathlib import Path
import torch
import torch.nn.functional as F
import cv2
import numpy as np
import argparse
import mujoco
import mink
import time
from scipy.spatial.transform import Rotation as R

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.dirname(__file__))

from data import RealTrajectoryDataset, N_WINDOW
from model import TrajectoryHeatmapPredictor, N_HEIGHT_BINS, N_GRIPPER_BINS
import model as model_module
from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from exo_utils import get_link_poses_from_robot, position_exoskeleton_meshes, render_from_camera_pose

# Configuration
# Square side length (pixels) used for the model input size, the dataset resize,
# and to scale the normalized intrinsics cam_K_norm back to pixels.
IMAGE_SIZE = 448
# Default value of the --checkpoint CLI flag.
CHECKPOINT_PATH = "volume_dino_tracks/checkpoints/volume_dino_tracks/latest.pth"

# Virtual gripper keypoint candidates in the local gripper frame (from dataset generation).
# Entries are divided by 1000 below, so the literals are presumably millimetres and the
# resulting array is in metres (matching the _M_ suffix) — confirm against make_datasets.py.
KEYPOINTS_LOCAL_M_ALL = np.array([[13.25, -91.42, 15.9], [10.77, -99.6, 0], [13.25, -91.42, -15.9], 
                                   [17.96, -83.96, 0], [22.86, -70.46, 0]]) / 1000.0
KP_INDEX = 3  # Using index 3
# NOTE(review): kp_local is not referenced anywhere else in this script.
kp_local = KEYPOINTS_LOCAL_M_ALL[KP_INDEX]

# Hard-coded grasp parameters from make_datasets.py
# NOTE(review): cube_to_gripper_offset is not referenced anywhere else in this script.
cube_to_gripper_offset = np.array([-0.04433927, -0.00871351, 0.01834697])
# (w, x, y, z) quaternion passed as the orientation part of the IK keypoint target;
# the IK FrameTask uses orientation_cost=0.0, so it does not steer the solution.
initial_kp_quat_wxyz = np.array([-0.01718116, -0.00338855, 0.76436812, 0.64454224])
# Fallback start configuration for the actuated joints (written to qpos[:nu] and ctrl),
# used when the episode's start joint state is unavailable or --use_fixed_start_pos is set.
# Presumably radians — confirm against the SO100 model.
start_pos = np.array([0, -1.57, 1.57, 1.57, 1.57, 1])


def _try_import_viewer():
    """Return the ``mujoco.viewer`` module, or ``None`` when it cannot be imported.

    The import is deferred to call time because some MuJoCo installs ship
    without viewer support, and any failure during import is treated as
    "viewer unavailable".
    """
    viewer_module = None
    try:
        import mujoco.viewer as viewer_module  # type: ignore
    except Exception:
        viewer_module = None
    return viewer_module


def _set_marker_sphere(geom, pos_xyz, rgba, size=0.01):
    """Overwrite ``geom`` in place so it becomes a sphere marker.

    Args:
        geom: an mjvGeom (e.g. an entry of ``MjvScene.geoms``) to initialize.
        pos_xyz: sphere center; anything reshapeable to 3 values.
        rgba: color; anything reshapeable to 4 values.
        size: sphere radius, stored in the first size slot.
    """
    center = np.asarray(pos_xyz, dtype=np.float64).reshape(3)
    color = np.asarray(rgba, dtype=np.float32).reshape(4)
    radius_vec = np.array([float(size), 0.0, 0.0], dtype=np.float64)
    identity_rot = np.eye(3, dtype=np.float64).reshape(9)

    # Going through mjv_initGeom fills every field, so we never touch
    # version-specific struct members directly.
    mujoco.mjv_initGeom(geom, mujoco.mjtGeom.mjGEOM_SPHERE, radius_vec, center, identity_rot, color)

    # Best effort: tag as decoration; ignore bindings that reject this assignment.
    try:
        geom.category = mujoco.mjtCatBit.mjCAT_DECOR
    except Exception:
        pass


def _viewer_lock(viewer):
    """Return the viewer's lock as a context manager, or a no-op context if unavailable."""
    import contextlib

    try:
        return viewer.lock()
    except Exception:
        # Some builds' handles lack lock(); proceed unlocked (best effort).
        return contextlib.nullcontext()


def _fit_marker_budget(pred_positions, gt_positions, max_geoms):
    """Uniformly subsample two (K, 3) point arrays so their total count fits ``max_geoms``.

    The prediction receives up to half of the budget, the ground truth the rest,
    and any budget the ground truth does not need is handed back to the
    prediction. Arrays that already fit are returned unchanged.
    """
    n_pred, n_gt = len(pred_positions), len(gt_positions)
    max_geoms = max(int(max_geoms), 0)
    if n_pred + n_gt <= max_geoms:
        return pred_positions, gt_positions
    keep_pred = min(n_pred, max_geoms // 2)
    keep_gt = min(n_gt, max_geoms - keep_pred)
    keep_pred = min(n_pred, max_geoms - keep_gt)
    # linspace with count 0 yields an empty index array, so empty inputs stay empty.
    pred_idx = np.linspace(0, n_pred - 1, keep_pred, dtype=int)
    gt_idx = np.linspace(0, n_gt - 1, keep_gt, dtype=int)
    return pred_positions[pred_idx], gt_positions[gt_idx]


def _write_trajectory_markers(scn, gt_positions, pred_positions, gt_rgba, pred_rgba, marker_size):
    """Rebuild ``scn``'s geoms as GT spheres followed by predicted spheres.

    The first point of each trajectory is drawn 1.4x larger and the last 1.8x
    larger (a single-point trajectory gets 1.8x), so the direction of travel is
    visible. Assumes the caller already fit the points into ``scn``'s capacity.
    """
    scn.ngeom = 0
    for positions, rgba in ((gt_positions, gt_rgba), (pred_positions, pred_rgba)):
        last_i = len(positions) - 1
        for i, pos in enumerate(positions):
            if i == last_i:
                size = marker_size * 1.8
            elif i == 0:
                size = marker_size * 1.4
            else:
                size = marker_size
            _set_marker_sphere(scn.geoms[scn.ngeom], pos, rgba, size=size)
            scn.ngeom += 1


def _center_camera_on_points(viewer, points):
    """Point the viewer's free camera at the centroid of ``points`` (best effort)."""
    if points.size == 0 or not hasattr(viewer, "cam"):
        return
    center = points.mean(axis=0)
    span = float(np.max(np.linalg.norm(points - center[None, :], axis=1)))
    try:
        viewer.cam.lookat[:] = center
        # A bit farther than the max span for a comfortable view.
        viewer.cam.distance = max(0.2, span * 3.0)
        viewer.cam.azimuth = -90
        viewer.cam.elevation = -30
    except Exception:
        pass


def _launch_mujoco_viewer_with_trajectory(
    mj_model,
    mj_data,
    pred_trajectory_3d,
    gt_trajectory_3d,
    ik_trajectory_qpos,
    title="MuJoCo trajectory viewer",
    animate_robot=True,
    marker_size=0.012,
    waypoint_hz=4.0,
):
    """Open a passive MuJoCo viewer showing GT/predicted 3D trajectories as spheres.

    GT points are drawn white and predicted points green; each trajectory's
    first point is enlarged 1.4x and its last point 1.8x. When ``animate_robot``
    is true, the actuated joints loop through ``ik_trajectory_qpos`` at
    ``waypoint_hz``. Blocks until the viewer window is closed.

    Args:
        mj_model: MjModel handed to ``mujoco.viewer.launch_passive``.
        mj_data: MjData handed to the viewer; its qpos/ctrl are overwritten while animating.
        pred_trajectory_3d: (N, 3) predicted world-frame points.
        gt_trajectory_3d: (M, 3) ground-truth world-frame points.
        ik_trajectory_qpos: (K, nq) joint configurations to play back.
        title: window title (best effort; not every build supports setting it).
        animate_robot: whether to play back ``ik_trajectory_qpos``.
        marker_size: sphere radius in meters.
        waypoint_hz: playback rate in waypoints per second.

    Raises:
        RuntimeError: if ``mujoco.viewer`` cannot be imported.
    """
    viewer_mod = _try_import_viewer()
    if viewer_mod is None:
        raise RuntimeError(
            "MuJoCo viewer is not available in this Python environment. "
            "Install a mujoco build with viewer support or run without --mujoco_viewer."
        )

    marker_size = float(marker_size)
    pred_rgba = np.array([0.1, 0.9, 0.1, 0.8], dtype=np.float32)  # green
    gt_rgba = np.array([1.0, 1.0, 1.0, 0.8], dtype=np.float32)  # white
    pred_positions = np.asarray(pred_trajectory_3d, dtype=np.float64).reshape(-1, 3)
    gt_positions = np.asarray(gt_trajectory_3d, dtype=np.float64).reshape(-1, 3)
    ik_qpos = np.atleast_2d(np.asarray(ik_trajectory_qpos, dtype=np.float64))
    can_animate = bool(animate_robot) and ik_qpos.size > 0
    n_ctrl = len(mj_data.ctrl)
    waypoint_dt = 1.0 / max(float(waypoint_hz), 1e-3)
    frame_period = 1.0 / 60.0  # render-loop cap (~60 fps)

    with viewer_mod.launch_passive(mj_model, mj_data, show_left_ui=False, show_right_ui=False) as viewer:
        # Best effort: set window title (API varies)
        try:
            viewer._render_window.set_title(title)  # type: ignore[attr-defined]
        except Exception:
            pass

        scn = viewer.user_scn
        max_geoms = int(getattr(scn, "maxgeom", 0) or len(scn.geoms))
        pred_positions, gt_positions = _fit_marker_budget(pred_positions, gt_positions, max_geoms)

        # MuJoCo's passive viewer renders from its own thread: scene, camera and
        # data edits are made while holding the viewer lock.
        with _viewer_lock(viewer):
            _write_trajectory_markers(scn, gt_positions, pred_positions, gt_rgba, pred_rgba, marker_size)
            _center_camera_on_points(viewer, np.concatenate([gt_positions, pred_positions], axis=0))

        # Ensure the first frame shows markers even if animation is off.
        try:
            viewer.sync()
        except Exception:
            pass

        step = 0
        t_accum = 0.0
        last = time.time()
        # Run until the user closes the viewer. Some builds lack `is_running`,
        # so we also stop when `sync()` raises.
        while True:
            frame_start = time.time()
            if hasattr(viewer, "is_running"):
                try:
                    if not viewer.is_running():
                        break
                except Exception:
                    pass

            with _viewer_lock(viewer):
                if can_animate:
                    # Advance waypoints at a fixed rate, independent of render FPS.
                    t_accum += frame_start - last
                    while t_accum >= waypoint_dt:
                        t_accum -= waypoint_dt
                        step = (step + 1) % ik_qpos.shape[0]
                    mj_data.qpos[:n_ctrl] = ik_qpos[step, :n_ctrl]
                    mj_data.ctrl[:] = mj_data.qpos[:n_ctrl]
                    mujoco.mj_forward(mj_model, mj_data)
                # Re-apply markers each frame (some viewer builds rebuild scenes on sync).
                _write_trajectory_markers(scn, gt_positions, pred_positions, gt_rgba, pred_rgba, marker_size)
            last = frame_start

            try:
                viewer.sync()
            except Exception:
                break
            # Sleep only for what is left of this frame's budget (work already done is excluded).
            time.sleep(max(0.0, frame_period - (time.time() - frame_start)))


def unproject_2d_to_ray(point_2d, camera_pose, cam_K):
    """Unproject a pixel to a viewing ray expressed in world coordinates.

    Args:
        point_2d: (2,) pixel coordinates (u, v).
        camera_pose: (4, 4) camera pose matrix (world-to-camera).
        cam_K: (3, 3) pixel-space camera intrinsics.

    Returns:
        cam_pos: (3,) camera center in the world frame.
        ray_direction: (3,) unit-length ray direction in the world frame.
    """
    # camera_pose maps world -> camera, so its inverse is camera -> world.
    # Invert once and read both the camera center and its rotation from it.
    cam_to_world = np.linalg.inv(camera_pose)
    cam_pos_world = cam_to_world[:3, 3]

    # Back-project the homogeneous pixel into a camera-frame direction.
    ray_cam = np.linalg.inv(cam_K) @ np.array([point_2d[0], point_2d[1], 1.0])

    ray_world = cam_to_world[:3, :3] @ ray_cam
    ray_world = ray_world / np.linalg.norm(ray_world)

    return cam_pos_world, ray_world


def recover_3d_from_direct_keypoint_and_height(kp_2d_image, height, camera_pose, cam_K):
    """Lift a pixel to 3D by intersecting its camera ray with the plane z = height.

    Args:
        kp_2d_image: (2,) pixel coordinates.
        height: world-frame z value of the horizontal plane to hit.
        camera_pose: (4, 4) camera pose matrix, as consumed by unproject_2d_to_ray.
        cam_K: (3, 3) camera intrinsics.

    Returns:
        (3,) world-frame point where the ray reaches z == height, or None when
        the ray is (nearly) parallel to that plane or the plane lies behind
        the camera.
    """
    origin, direction = unproject_2d_to_ray(kp_2d_image, camera_pose, cam_K)
    dz = direction[2]

    # A near-horizontal ray never meets the plane at a usable distance.
    if abs(dz) < 1e-6:
        return None

    # origin[2] + depth * dz == height  ->  solve for depth along the ray.
    depth = (height - origin[2]) / dz
    if depth < 0:
        return None  # intersection is behind the camera

    return origin + depth * direction


def extract_pred_2d_and_height_from_volume(volume_logits, min_height=None, max_height=None):
    """Decode the per-timestep 2D pixel and height from volume logits.

    For each (batch, timestep) the pixel is the argmax over H*W of the max over
    height bins; the height bin is then the argmax over bins at that pixel.
    Bin ``k`` of ``Nh`` maps linearly onto [min_height, max_height]
    (bin 0 -> min_height, bin Nh-1 -> max_height).

    Args:
        volume_logits: (B, N_WINDOW, Nh, H, W) tensor.
        min_height: lower end of the height range; defaults to ``model.MIN_HEIGHT``.
        max_height: upper end of the height range; defaults to ``model.MAX_HEIGHT``.

    Returns:
        pred_2d: (B, N_WINDOW, 2) float32 tensor of (x, y) pixel coordinates.
        pred_height: (B, N_WINDOW) tensor of decoded heights.
    """
    if min_height is None:
        min_height = model_module.MIN_HEIGHT
    if max_height is None:
        max_height = model_module.MAX_HEIGHT

    B, N, Nh, H, W = volume_logits.shape
    device = volume_logits.device

    flat = volume_logits.reshape(B, N, Nh, H * W)
    best_over_h = flat.max(dim=2).values  # (B, N, H*W)
    flat_idx = best_over_h.argmax(dim=-1)  # (B, N)
    py = flat_idx // W
    px = flat_idx % W
    pred_2d = torch.stack([px, py], dim=-1).float()

    # Gather all height-bin logits at the selected pixel of each (b, t).
    gather_idx = flat_idx[:, :, None, None].expand(B, N, Nh, 1)
    logits_at_pixel = flat.gather(3, gather_idx).squeeze(3)  # (B, N, Nh)
    pred_height_bins = logits_at_pixel.argmax(dim=-1)  # (B, N)

    # Use the tensor's own bin count so decoding stays consistent with the logits.
    bin_centers = torch.linspace(0.0, 1.0, Nh, device=device)
    normalized = bin_centers[pred_height_bins]
    pred_height = normalized * (max_height - min_height) + min_height
    return pred_2d, pred_height


def decode_gripper_bins(bin_logits, min_gripper=None, max_gripper=None):
    """Decode gripper bin logits to continuous values in [min_gripper, max_gripper].

    The argmax bin ``k`` of ``n_bins`` (taken from the last axis) maps linearly
    onto the range: bin 0 -> min_gripper, bin n_bins-1 -> max_gripper.

    Args:
        bin_logits: (..., n_bins) tensor, e.g. (B, N_WINDOW, N_GRIPPER_BINS).
        min_gripper: lower end of the range; defaults to ``model.MIN_GRIPPER``.
        max_gripper: upper end of the range; defaults to ``model.MAX_GRIPPER``.

    Returns:
        Tensor with the last axis reduced, e.g. (B, N_WINDOW).
    """
    min_g = model_module.MIN_GRIPPER if min_gripper is None else min_gripper
    max_g = model_module.MAX_GRIPPER if max_gripper is None else max_gripper
    # Derive the bin count from the logits so decoding always matches their layout.
    n_bins = bin_logits.shape[-1]
    bin_indices = bin_logits.argmax(dim=-1)
    bin_centers = torch.linspace(0.0, 1.0, n_bins, device=bin_logits.device)
    normalized = bin_centers[bin_indices]
    return normalized * (max_g - min_g) + min_g


def extract_gripper_logits_at_pixels(gripper_logits, pixel_2d):
    """Pick the per-pixel gripper logits at one (x, y) location per timestep.

    Used at inference to decode the gripper at the predicted pixel.
    Coordinates are truncated to integers and clamped into the image.

    Args:
        gripper_logits: (B, N_WINDOW, N_GRIPPER_BINS, H, W) tensor.
        pixel_2d: (B, N_WINDOW, 2) tensor of (x, y) pixel coordinates.

    Returns:
        (B, N_WINDOW, N_GRIPPER_BINS) tensor of logits at those pixels.
    """
    n_batch, n_steps, n_bins, height, width = gripper_logits.shape
    xs = pixel_2d[..., 0].long().clamp(0, width - 1)
    ys = pixel_2d[..., 1].long().clamp(0, height - 1)

    # Address each pixel by its row-major offset in the flattened H*W plane.
    linear = ys * width + xs  # (B, N)
    flat = gripper_logits.flatten(start_dim=3)  # (B, N, Ng, H*W)
    index = linear[:, :, None, None].expand(n_batch, n_steps, n_bins, 1)
    return flat.gather(3, index).squeeze(3)


def ik_to_cube_grasp(model, data, configuration, target_pos, target_kp_quat_wxyz, num_iterations=50):
    """Iteratively solve IK so the ``virtual_gripper_keypoint`` body reaches ``target_pos``.

    Only the keypoint position is tracked: the frame task uses
    ``orientation_cost=0.0``, so ``target_kp_quat_wxyz`` is part of the target
    pose but does not steer the solution. No cube state is read here.

    Args:
        model: MjModel containing a body named ``virtual_gripper_keypoint``.
        data: MjData whose qpos is the starting point; it is overwritten with the
            solution and forward kinematics are recomputed.
        configuration: mink.Configuration for ``model``; synced from ``data.qpos``
            and integrated in place.
        target_pos: (3,) world-frame target position of the keypoint.
        target_kp_quat_wxyz: (4,) target orientation quaternion (w, x, y, z).
        num_iterations: number of differential-IK steps.

    Returns:
        optimized_qpos: (nq,) copy of ``data.qpos`` after the last step.
        final_error: Euclidean distance in meters between target and keypoint.
    """
    dt = 0.01  # integration timestep of each differential-IK step

    # The target is fixed for the whole solve, so build the task and limits once.
    kp_task = mink.FrameTask("virtual_gripper_keypoint", "body", position_cost=1.0, orientation_cost=0.0)
    kp_task.set_target(mink.SE3(wxyz_xyz=np.concatenate([target_kp_quat_wxyz, target_pos])))
    limits = [mink.ConfigurationLimit(model=model)]
    # NOTE(review): no posture regularization is applied; add a mink.PostureTask to the
    # task list (and pass damping= to solve_ik) if joint drift becomes a problem.

    for _ in range(num_iterations):
        configuration.update(data.qpos)
        vel = mink.solve_ik(configuration, [kp_task], dt, "daqp", limits=limits)
        configuration.integrate_inplace(vel, dt)
        data.qpos[:] = configuration.q

    # Refresh kinematics so xpos matches qpos (also when num_iterations == 0).
    mujoco.mj_forward(model, data)
    kp_body_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, 'virtual_gripper_keypoint')
    error_norm = np.linalg.norm(np.asarray(target_pos) - data.xpos[kp_body_id])

    return data.qpos.copy(), error_norm

def _select_device():
    """Return the preferred torch device: Apple MPS, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def _load_model(checkpoint_path, device):
    """Build the predictor, load checkpoint weights, and apply stored decode ranges.

    When the checkpoint carries ``min_height``/``max_height`` or
    ``min_gripper``/``max_gripper``, they are written onto the ``model`` module
    globals that the height/gripper decoders read by default.

    Returns:
        The model moved to ``device`` and set to eval mode.
    """
    print(f"\nLoading model from {checkpoint_path}...")
    net = TrajectoryHeatmapPredictor(target_size=IMAGE_SIZE, n_window=N_WINDOW, freeze_backbone=False)
    checkpoint = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(checkpoint['model_state_dict'])
    net = net.to(device)
    net.eval()

    if 'min_height' in checkpoint and 'max_height' in checkpoint:
        min_h, max_h = checkpoint['min_height'], checkpoint['max_height']
        model_module.MIN_HEIGHT = min_h
        model_module.MAX_HEIGHT = max_h
        print(f"✓ Loaded height range from checkpoint: [{min_h:.6f}, {max_h:.6f}] m")
        print(f"  ({min_h*1000:.2f}mm to {max_h*1000:.2f}mm)")
        if min_h == max_h:
            print("  ⚠ WARNING: min_height == max_height; height predictions will be constant.")
    else:
        print("⚠ Checkpoint missing min_height/max_height; using model defaults.")

    if 'min_gripper' in checkpoint and 'max_gripper' in checkpoint:
        min_g, max_g = checkpoint['min_gripper'], checkpoint['max_gripper']
        model_module.MIN_GRIPPER = min_g
        model_module.MAX_GRIPPER = max_g
        print(f"✓ Loaded gripper range from checkpoint: [{min_g:.6f}, {max_g:.6f}]")
    else:
        print("⚠ Checkpoint missing min_gripper/max_gripper; using model defaults.")

    print(f"✓ Loaded model from epoch {checkpoint.get('epoch', '?')}")
    return net


def _gripper_status(value):
    """Label a gripper value as "open" (>= 0.5) or "closed"."""
    return "open" if value >= 0.5 else "closed"


def _initial_qpos(sample, n_ctrl, use_fixed_start_pos):
    """Choose the actuated-joint start configuration for MuJoCo.

    Uses the first ``n_ctrl`` values of ``sample['episode_start_joint_state']``
    when present and long enough (and ``use_fixed_start_pos`` is False);
    otherwise falls back to the module-level ``start_pos``.
    """
    init_qpos = start_pos.copy()
    if use_fixed_start_pos:
        return init_qpos
    try:
        js = sample.get("episode_start_joint_state")
        if js is None:
            return init_qpos
        if hasattr(js, "numpy"):
            js = js.numpy()
        js = np.asarray(js).reshape(-1)
        if js.size >= n_ctrl:
            init_qpos = js[:n_ctrl].astype(np.float64)
            print("✓ Using episode start joint state for MuJoCo initialization")
        else:
            print(f"⚠ Episode start joint state has {js.size} values (< {n_ctrl}); using start_pos.")
    except Exception as exc:
        # Best effort: an unreadable joint state must not abort the test run.
        print(f"⚠ Could not use episode start joint state ({exc}); using start_pos.")
        init_qpos = start_pos.copy()
    return init_qpos


def _reset_robot(mj_model, mj_data, qpos):
    """Write ``qpos`` into the actuated joints and ctrl, then recompute kinematics."""
    n_ctrl = len(mj_data.ctrl)
    mj_data.qpos[:n_ctrl] = qpos
    mj_data.ctrl[:] = qpos
    mujoco.mj_forward(mj_model, mj_data)


def _lift_trajectory_to_3d(traj_2d, heights, camera_pose, cam_K, fallback_3d):
    """Lift each (pixel, height) pair to 3D; failed waypoints fall back to ``fallback_3d[t]``."""
    lifted = []
    for t, (pt_2d, h) in enumerate(zip(traj_2d, heights)):
        pt_3d = recover_3d_from_direct_keypoint_and_height(pt_2d, h, camera_pose, cam_K)
        if pt_3d is None:
            print(f"✗ Failed to lift waypoint {t} to 3D (point behind camera or ray parallel to height plane)")
            pt_3d = fallback_3d[t].copy()  # use GT as fallback
        lifted.append(pt_3d)
    return np.array(lifted)


def _run_sequential_ik(mj_model, mj_data, configuration, init_qpos, targets, num_iters):
    """Solve IK for each target in order, warm-starting each step from the previous one.

    Returns:
        (qpos_per_step, error_per_step) as numpy arrays.
    """
    _reset_robot(mj_model, mj_data, init_qpos)
    configuration.update(mj_data.qpos)
    qpos_list, err_list = [], []
    for t, targ_pos in enumerate(targets):
        qpos_t, err_t = ik_to_cube_grasp(
            mj_model,
            mj_data,
            configuration,
            targ_pos,
            initial_kp_quat_wxyz,
            num_iterations=num_iters,
        )
        qpos_list.append(qpos_t.copy())
        err_list.append(float(err_t))
        # Carry the solution forward as the next step's starting point.
        mj_data.qpos[:] = qpos_t
        mujoco.mj_forward(mj_model, mj_data)
        configuration.update(mj_data.qpos)
        print(f"  t={t}: 3D=[{targ_pos[0]:.4f}, {targ_pos[1]:.4f}, {targ_pos[2]:.4f}] m, IK err={err_t*1000:.2f} mm")
    return np.array(qpos_list), np.array(err_list)


def _render_ik_overlay(mj_model, mj_data, qpos, camera_pose, cam_K_norm, rgb_vis, render_hw=(540 // 2, 960 // 2)):
    """Render the robot at ``qpos`` from ``camera_pose`` and blend it 50/50 over ``rgb_vis``."""
    render_h, render_w = render_hw
    render_cam_K = cam_K_norm.copy()
    render_cam_K[0] *= render_w
    render_cam_K[1] *= render_h

    # Same actuated-joint slice as the rest of the script (not a hard-coded 6).
    n_ctrl = len(mj_data.ctrl)
    mj_data.qpos[:n_ctrl] = qpos[:n_ctrl]
    mujoco.mj_forward(mj_model, mj_data)
    rendered = render_from_camera_pose(
        mj_model, mj_data, camera_pose, render_cam_K, render_h, render_w, segmentation=False
    ) / 255
    H, W = rgb_vis.shape[:2]
    rendered = cv2.resize(rendered, (W, H), interpolation=cv2.INTER_LINEAR)
    return rgb_vis * 0.5 + rendered * 0.5


def _plot_3d_trajectory(gt_3d, pred_3d, ik_overlay, ik_error_last, episode_id):
    """Figure with the GT/pred 3D trajectories (left) and the last-step IK render (right)."""
    import matplotlib.pyplot as plt

    # White (used for GT in the 2D image overlays) is invisible on the default
    # white 3D-axes background, so GT is drawn in gray here.
    gt_color = "dimgray"
    fig = plt.figure(figsize=(12, 5))
    ax = fig.add_subplot(1, 2, 1, projection="3d")
    ax.plot(gt_3d[:, 0], gt_3d[:, 1], gt_3d[:, 2], "--", color=gt_color, linewidth=2, label="GT 3D")
    ax.plot(pred_3d[:, 0], pred_3d[:, 1], pred_3d[:, 2], "-", color="lime", linewidth=2, label="Pred 3D")
    ax.scatter(gt_3d[0, 0], gt_3d[0, 1], gt_3d[0, 2], c="cyan", s=40, label="Start")
    ax.scatter(gt_3d[-1, 0], gt_3d[-1, 1], gt_3d[-1, 2], c=gt_color, s=40, label="GT End")
    ax.scatter(pred_3d[-1, 0], pred_3d[-1, 1], pred_3d[-1, 2], c="lime", s=40, label="Pred End")
    ax.set_title("3D trajectory", fontsize=12, fontweight="bold")
    ax.set_xlabel("X (m)")
    ax.set_ylabel("Y (m)")
    ax.set_zlabel("Z (m)")
    ax.legend(fontsize=8, loc="best")
    ax.view_init(elev=20, azim=-60)

    # Roughly equal aspect: a cube around both trajectories.
    mins = np.minimum(gt_3d.min(axis=0), pred_3d.min(axis=0))
    maxs = np.maximum(gt_3d.max(axis=0), pred_3d.max(axis=0))
    half = float(np.maximum(maxs - mins, 1e-6).max()) / 2
    centers = (mins + maxs) / 2.0
    ax.set_xlim(centers[0] - half, centers[0] + half)
    ax.set_ylim(centers[1] - half, centers[1] + half)
    ax.set_zlim(centers[2] - half, centers[2] + half)

    ax_render = fig.add_subplot(1, 2, 2)
    ax_render.imshow(ik_overlay)
    ax_render.set_title(f"Rendered robot @ last timestep\nIK err: {ik_error_last*1000:.1f}mm", fontsize=11)
    ax_render.axis("off")
    fig.suptitle(f"3D Lifted Trajectory + Last-Step IK Render | {episode_id}", fontsize=13, fontweight="bold")
    fig.tight_layout()
    return fig


def _plot_trajectory_panels(rgb_vis, pred_2d, gt_2d, gt_3d, pred_height, gt_gripper, pred_gripper,
                            ik_overlay, ik_error_last, episode_id):
    """One image pane per timestep (trajectory so far + per-step errors) plus the IK render pane."""
    import matplotlib.pyplot as plt

    n_steps = len(gt_2d)
    # squeeze=False keeps `axes` indexable even for a single pane.
    fig, axes = plt.subplots(1, n_steps + 1, figsize=(4 * (n_steps + 1), 5), squeeze=False)
    axes = axes[0]
    colors = plt.cm.viridis(np.linspace(0, 1, n_steps))

    for t in range(n_steps):
        ax = axes[t]
        ax.imshow(rgb_vis)
        if t > 0:
            # One polyline per trajectory up to this timestep.
            ax.plot(pred_2d[:t + 1, 0], pred_2d[:t + 1, 1], '-', color='lime', linewidth=2, alpha=0.5)
            ax.plot(gt_2d[:t + 1, 0], gt_2d[:t + 1, 1], '--', color='white', linewidth=2, alpha=0.5)

        ax.scatter(gt_2d[t, 0], gt_2d[t, 1], c='white', s=100,
                   marker='o', edgecolors='black', linewidths=2, label='GT', zorder=10)
        ax.scatter(pred_2d[t, 0], pred_2d[t, 1], c=[colors[t]], s=100,
                   marker='x', linewidths=3, label='Pred', zorder=10)

        pixel_err = np.linalg.norm(gt_2d[t] - pred_2d[t])
        height_err = abs(gt_3d[t, 2] - pred_height[t]) * 1000
        gripper_err = abs(float(gt_gripper[t]) - float(pred_gripper[t]))
        ax.set_title(
            f"t={t}\nPx:{pixel_err:.1f}px | H Err:{height_err:.1f}mm\nG:{pred_gripper[t]:.2f} (GT:{gt_gripper[t]:.2f}, Err:{gripper_err:.2f})",
            fontsize=8,
        )
        ax.axis('off')
        if t == 0:
            ax.legend(loc='upper right', fontsize=6)

    axes[n_steps].imshow(ik_overlay)
    axes[n_steps].set_title(f"Pred last-step IK\nIK Err: {ik_error_last*1000:.1f}mm", fontsize=10)
    axes[n_steps].axis('off')

    fig.suptitle(f"Trajectory Prediction → 3D Lifting → IK | {episode_id}", fontsize=14, fontweight='bold')
    fig.tight_layout()
    return fig


def main():
    """CLI entry point: predict (or load GT) a trajectory, lift it to 3D, run IK, visualize."""
    parser = argparse.ArgumentParser(description='Test trajectory predictions with IK')
    parser.add_argument('--checkpoint', type=str, default=CHECKPOINT_PATH,
                        help='Path to model checkpoint')
    parser.add_argument('--split', type=str, default='val_viewpoints',
                        choices=['train', 'val_viewpoints', 'val_cube_pos'],
                        help='Dataset split')
    parser.add_argument('--dataset_root', type=str, default='scratch/parsed_school_cap',
                        help='Root directory of dataset')
    parser.add_argument('--sample_idx', type=int, default=0,
                        help='Sample index to test')
    parser.add_argument('--num_ik_iters', type=int, default=50,
                        help='Number of IK iterations')
    parser.add_argument('--mujoco_viewer', action='store_true',
                        help='Show 3D trajectory + IK animation in MuJoCo viewer instead of Matplotlib rendering')
    parser.add_argument('--no_viewer_animate', action='store_true',
                        help='When using --mujoco_viewer, disable robot animation (show markers only)')
    parser.add_argument('--viewer_marker_size', type=float, default=0.012,
                        help='Marker sphere radius (meters) for MuJoCo viewer trajectory points')
    parser.add_argument('--viewer_fps', type=float, default=4.0,
                        help='Playback rate for IK waypoint animation in MuJoCo viewer (waypoints/sec)')
    parser.add_argument('--use_gt', action='store_true',
                        help='Skip model inference and use dataset GT 2D+height(+gripper) trajectories for debugging')
    parser.add_argument('--use_fixed_start_pos', action='store_true',
                        help='Always initialize MuJoCo from hard-coded start_pos (ignore episode start joint state)')
    args = parser.parse_args()

    # Import matplotlib lazily so `mjpython --mujoco_viewer` doesn't initialize any GUI toolkits.
    # Imported here (not only in the plotting helpers) so a missing install fails before inference.
    if not args.mujoco_viewer:
        import matplotlib.pyplot as plt  # noqa: F401
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (needed for 3D projection)

    device = None
    model = None
    if args.use_gt:
        print("Using GT trajectories (skipping model load).")
    else:
        device = _select_device()
        print(f"Using device: {device}")
        model = _load_model(args.checkpoint, device)

    print("\nLoading dataset...")
    # NOTE(review): --split is parsed but not forwarded to RealTrajectoryDataset;
    # confirm whether the dataset supports selecting a split.
    dataset = RealTrajectoryDataset(
        dataset_root=args.dataset_root,
        image_size=IMAGE_SIZE
    )
    print(f"✓ Loaded {len(dataset)} samples")

    sample = dataset[args.sample_idx]
    episode_id = sample['episode_id']
    print(f"\nTesting sample {args.sample_idx}: {episode_id}")

    # GT trajectory for comparison
    trajectory_2d = sample['trajectory_2d'].numpy()  # (N_WINDOW, 2)
    trajectory_3d = sample['trajectory_3d'].numpy()  # (N_WINDOW, 3)
    trajectory_gripper = sample['trajectory_gripper'].numpy()  # (N_WINDOW,)

    if args.use_gt:
        # Use GT 2D + height (+ gripper) as "predicted" for debugging.
        pred_trajectory_2d = trajectory_2d.copy()
        pred_height = trajectory_3d[:, 2].copy()
        pred_gripper = trajectory_gripper.copy()
        print("\nUsing dataset GT trajectory for debugging (2D + height + gripper).")
    else:
        # Volume model: returns volume_logits + per-pixel gripper_logits
        print("\nRunning model inference...")
        start_keypoint_2d = torch.from_numpy(trajectory_2d[0]).float().to(device)  # (2,)
        with torch.no_grad():
            rgb = sample['rgb'].unsqueeze(0).to(device)
            volume_logits, gripper_logits = model(
                rgb,
                training=False,
                start_keypoint_2d=start_keypoint_2d,
            )
        pred_2d, pred_height_t = extract_pred_2d_and_height_from_volume(volume_logits)
        gripper_logits_at_pred = extract_gripper_logits_at_pixels(gripper_logits, pred_2d)
        pred_gripper_t = decode_gripper_bins(gripper_logits_at_pred)
        pred_trajectory_2d = pred_2d[0].cpu().numpy()  # (N_WINDOW, 2)
        pred_height = pred_height_t[0].cpu().numpy()  # (N_WINDOW,)
        pred_gripper = pred_gripper_t[0].cpu().numpy()  # (N_WINDOW,)

    print("✓ Predicted trajectory (first and last):")
    print(f"  Start: [{pred_trajectory_2d[0, 0]:.1f}, {pred_trajectory_2d[0, 1]:.1f}] px, H: {pred_height[0]*1000:.2f} mm, G: {_gripper_status(pred_gripper[0])} ({pred_gripper[0]:.3f})")
    print(f"  Final: [{pred_trajectory_2d[-1, 0]:.1f}, {pred_trajectory_2d[-1, 1]:.1f}] px, H: {pred_height[-1]*1000:.2f} mm, G: {_gripper_status(pred_gripper[-1])} ({pred_gripper[-1]:.3f})")
    print(f"  GT Start: [{trajectory_2d[0, 0]:.1f}, {trajectory_2d[0, 1]:.1f}] px, H: {trajectory_3d[0, 2]*1000:.2f} mm, G: {_gripper_status(trajectory_gripper[0])} ({trajectory_gripper[0]:.3f})")
    print(f"  GT Final: [{trajectory_2d[-1, 0]:.1f}, {trajectory_2d[-1, 1]:.1f}] px, H: {trajectory_3d[-1, 2]*1000:.2f} mm, G: {_gripper_status(trajectory_gripper[-1])} ({trajectory_gripper[-1]:.3f})")

    # Camera parameters; intrinsics are stored normalized and scaled back to pixels here.
    camera_pose = sample['camera_pose'].numpy()
    cam_K_norm = sample['cam_K_norm'].numpy()
    cam_K = cam_K_norm.copy()
    cam_K[0] *= IMAGE_SIZE
    cam_K[1] *= IMAGE_SIZE

    print("\nLifting trajectory to 3D...")
    pred_trajectory_3d = _lift_trajectory_to_3d(
        pred_trajectory_2d[:N_WINDOW], pred_height[:N_WINDOW], camera_pose, cam_K, trajectory_3d
    )  # (N_WINDOW, 3)

    pred_3d = pred_trajectory_3d[-1]
    target_3d = trajectory_3d[-1]
    print(f"✓ Predicted 3D (final): [{pred_3d[0]:.4f}, {pred_3d[1]:.4f}, {pred_3d[2]:.4f}] m")
    print(f"  GT 3D (final):        [{target_3d[0]:.4f}, {target_3d[1]:.4f}, {target_3d[2]:.4f}] m")
    print(f"  3D error (final):     {np.linalg.norm(pred_3d - target_3d)*1000:.2f} mm")

    print("\nSetting up MuJoCo and IK...")
    robot_config = SO100AdhesiveConfig()
    mj_model = mujoco.MjModel.from_xml_string(robot_config.xml)
    mj_data = mujoco.MjData(mj_model)
    init_qpos = _initial_qpos(sample, len(mj_data.ctrl), args.use_fixed_start_pos)
    _reset_robot(mj_model, mj_data, init_qpos)
    configuration = mink.Configuration(mj_model)
    configuration.update(mj_data.qpos)
    print("✓ MuJoCo initialized")

    # Sequential IK along the *predicted* 3D trajectory (robot trajectory)
    print(f"\nPerforming sequential IK along predicted 3D trajectory ({N_WINDOW} steps, {args.num_ik_iters} iters/step)...")
    ik_trajectory_qpos, ik_trajectory_err = _run_sequential_ik(
        mj_model, mj_data, configuration, init_qpos, pred_trajectory_3d, args.num_ik_iters
    )
    ik_error_last = float(ik_trajectory_err[-1])
    optimized_qpos_last = ik_trajectory_qpos[-1]

    if args.mujoco_viewer:
        print("\nLaunching MuJoCo viewer...")
        _reset_robot(mj_model, mj_data, init_qpos)
        _launch_mujoco_viewer_with_trajectory(
            mj_model,
            mj_data,
            pred_trajectory_3d,
            trajectory_3d,
            ik_trajectory_qpos,
            title=f"Trajectory spheres (Pred=green, GT=white) | {episode_id}",
            animate_robot=(not args.no_viewer_animate),
            marker_size=float(args.viewer_marker_size),
            waypoint_hz=float(args.viewer_fps),
        )
        print("\n✓ Viewer closed.")
        return

    # Matplotlib path
    print("\nRendering result...")
    import matplotlib.pyplot as plt

    # NOTE(review): sample['rgb'] appears ImageNet-normalized (mean/std below are undone
    # for display); confirm against RealTrajectoryDataset.
    rgb_vis = sample['rgb'].permute(1, 2, 0).numpy()  # (H, W, 3)
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    rgb_vis = np.clip(rgb_vis * std + mean, 0, 1)

    ik_overlay = _render_ik_overlay(mj_model, mj_data, optimized_qpos_last, camera_pose, cam_K_norm, rgb_vis)
    _plot_3d_trajectory(trajectory_3d, pred_trajectory_3d, ik_overlay, ik_error_last, episode_id)
    _plot_trajectory_panels(
        rgb_vis, pred_trajectory_2d, trajectory_2d, trajectory_3d, pred_height,
        trajectory_gripper, pred_gripper, ik_overlay, ik_error_last, episode_id,
    )
    plt.show()

    print("\n✓ Test complete!")


if __name__ == "__main__":
    main()
