"""Live test trajectory predictions with 3D lifting and IK visualization from camera stream."""
import sys
import os
from pathlib import Path
import torch
import torch.nn.functional as F
import cv2
import numpy as np
import matplotlib.pyplot as plt
import argparse
import mujoco
import mink
import pickle
import time
from scipy.spatial.transform import Rotation as R

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.dirname(__file__))

from data import N_WINDOW, project_3d_to_2d
from model import TrajectoryHeatmapPredictor, MIN_HEIGHT, MAX_HEIGHT, MIN_GRIPPER, MAX_GRIPPER
from matplotlib.gridspec import GridSpec
from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from ExoConfigs import EXOSKELETON_CONFIGS
from exo_utils import (
    get_link_poses_from_robot, 
    position_exoskeleton_meshes, 
    render_from_camera_pose,
    detect_and_set_link_poses,
    estimate_robot_state
)
from robot_models.so100_controller import Arm

# Configuration
IMAGE_SIZE = 448  # Square input resolution fed to the model (resized in preprocess_image)
CHECKPOINT_PATH = "real_dino_tracks/checkpoints/real_tracks/best.pth"  # default model weights

# Virtual gripper keypoint in local gripper frame (from dataset generation)
# Five candidate keypoints given in millimeters, converted to meters by the /1000.
KEYPOINTS_LOCAL_M_ALL = np.array([[13.25, -91.42, 15.9], [10.77, -99.6, 0], [13.25, -91.42, -15.9], 
                                   [17.96, -83.96, 0], [22.86, -70.46, 0]]) / 1000.0
KP_INDEX = 3  # Using index 3
kp_local = KEYPOINTS_LOCAL_M_ALL[KP_INDEX]  # (3,) active keypoint offset in the gripper frame, meters

# Hard-coded grasp parameters from make_datasets.py
cube_to_gripper_offset = np.array([-0.04433927, -0.00871351, 0.01834697])  # meters; presumably cube->gripper offset — TODO confirm against make_datasets.py
initial_kp_quat_wxyz = np.array([-0.01718116, -0.00338855, 0.76436812, 0.64454224])  # reference keypoint orientation, (w, x, y, z)


def unproject_2d_to_ray(point_2d, camera_pose, cam_K):
    """Unproject a 2D pixel to a camera ray expressed in world coordinates.

    Args:
        point_2d: (2,) 2D pixel coordinates (u, v)
        camera_pose: (4, 4) camera pose matrix (world-to-camera)
        cam_K: (3, 3) camera intrinsics

    Returns:
        cam_pos: (3,) camera position in world frame
        ray_direction: (3,) normalized ray direction in world frame
    """
    # Invert the world-to-camera pose once; the original computed this
    # inverse twice (for the translation and again for the rotation).
    cam_to_world = np.linalg.inv(camera_pose)
    cam_pos_world = cam_to_world[:3, 3]

    # Back-project the homogeneous pixel through the intrinsics to get a
    # ray direction in the camera frame.
    point_2d_h = np.array([point_2d[0], point_2d[1], 1.0])
    ray_cam = np.linalg.inv(cam_K) @ point_2d_h

    # Rotate the ray into the world frame and normalize to unit length.
    ray_world = cam_to_world[:3, :3] @ ray_cam
    ray_world = ray_world / np.linalg.norm(ray_world)

    return cam_pos_world, ray_world


def recover_3d_from_direct_keypoint_and_height(kp_2d_image, height, camera_pose, cam_K):
    """Lift a 2D keypoint back to 3D given its known world-frame height.

    Intersects the camera ray passing through the pixel with the horizontal
    plane z == height.

    Args:
        kp_2d_image: (2,) 2D pixel coordinates
        height: desired z-coordinate of the point in the world frame
        camera_pose: (4, 4) camera pose matrix
        cam_K: (3, 3) camera intrinsics

    Returns:
        (3,) 3D keypoint position, or None when the ray is (near) parallel
        to the plane or the intersection lies behind the camera.
    """
    origin, direction = unproject_2d_to_ray(kp_2d_image, camera_pose, cam_K)

    # A ray (near) parallel to the z = height plane never intersects it.
    if abs(direction[2]) < 1e-6:
        return None

    # Parametric distance along the ray where origin.z + t * dir.z == height.
    t = (height - origin[2]) / direction[2]

    # Negative t would place the intersection behind the camera.
    return None if t < 0 else origin + t * direction


def insert_intermediate_waypoints(trajectory_3d):
    """Insert one intermediate waypoint between every consecutive pair of points.

    Each intermediate waypoint copies the x/y coordinates of the segment's end
    point and takes as its height (z, index 2) the maximum of the two endpoint
    heights, producing an "up and over" motion profile.

    Args:
        trajectory_3d: (N, 3) array of 3D points [x, y, z] where z is height

    Returns:
        (2*N-1, 3) array with intermediate waypoints interleaved; a trajectory
        with fewer than two points is returned as an unmodified copy.
    """
    if len(trajectory_3d) < 2:
        return trajectory_3d.copy()

    out = []
    for seg_start, seg_end in zip(trajectory_3d[:-1], trajectory_3d[1:]):
        out.append(seg_start.copy())
        mid = seg_end.copy()
        mid[2] = max(seg_start[2], seg_end[2])  # lift to the taller endpoint
        out.append(mid)

    out.append(trajectory_3d[-1].copy())
    return np.array(out)


def ik_to_cube_grasp(model, data, configuration, target_pos, target_kp_quat_wxyz, num_iterations=50):
    """Run differential IK to drive the virtual gripper keypoint to a target pose.

    Iteratively solves IK so the 'virtual_gripper_keypoint' body reaches
    `target_pos`, with `target_kp_quat_wxyz` as a weakly weighted orientation
    preference (position dominates).

    Args:
        model: MuJoCo model.
        data: MuJoCo data; qpos is updated in place.
        configuration: mink.Configuration tracking the joint state.
        target_pos: (3,) target keypoint position in the world frame.
        target_kp_quat_wxyz: (4,) target keypoint orientation (w, x, y, z).
        num_iterations: number of differential-IK integration steps.

    Returns:
        optimized_qpos: (nq,) array of optimized joint positions
        final_error: scalar position error in meters
    """
    dt = 0.01  # Control timestep per IK integration step

    # Construct the tasks once; the keypoint target is constant across
    # iterations, and only the posture target is re-anchored each step.
    # (The original rebuilt both task objects on every iteration.)
    kp_task = mink.FrameTask("virtual_gripper_keypoint", "body", position_cost=1.0, orientation_cost=0.001)
    kp_task.set_target(mink.SE3(wxyz_xyz=np.concatenate([target_kp_quat_wxyz, target_pos])))
    posture_task = mink.PostureTask(model, cost=1e-4)

    # Run IK iterations
    for _ in range(num_iterations):
        mujoco.mj_forward(model, data)
        configuration.update(data.qpos)

        # Regularize toward the current joints to keep the motion smooth.
        posture_task.set_target(data.qpos)

        # Solve for a joint velocity and integrate it into the configuration.
        vel = mink.solve_ik(configuration, [kp_task, posture_task], dt, "daqp", limits=[mink.ConfigurationLimit(model=model)])
        configuration.integrate_inplace(vel, dt)
        data.qpos[:] = configuration.q

        # Refresh forward kinematics so the next iteration sees the new state.
        mujoco.mj_forward(model, data)

    # Compute the final keypoint position error against the target.
    kp_body_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, 'virtual_gripper_keypoint')
    pos_error = target_pos - data.xpos[kp_body_id]

    return data.qpos.copy(), np.linalg.norm(pos_error)


def preprocess_image(rgb, image_size=IMAGE_SIZE):
    """Preprocess RGB image for model input.

    Args:
        rgb: (H, W, 3) RGB image, integer-typed in [0, 255] or float-typed
            in [0, 1] (float images still in [0, 255] are also handled).
        image_size: Target (square) image size.

    Returns:
        rgb_tensor: (3, image_size, image_size) tensor normalized with
            ImageNet mean/std (as used by DINOv2).
        rgb_vis: (image_size, image_size, 3) float image in [0, 1] for
            visualization.
    """
    # Scale to [0, 1]. Checking the dtype (not only the max value) fixes a
    # bug where a nearly-black uint8 frame (max <= 1) was never scaled or
    # converted to float, producing a wrongly-scaled model input.
    if np.issubdtype(rgb.dtype, np.integer):
        rgb = rgb.astype(np.float32) / 255.0
    elif rgb.max() > 1.0:  # float image still in [0, 255]
        rgb = rgb.astype(np.float32) / 255.0
    else:
        rgb = rgb.astype(np.float32)

    H_orig, W_orig = rgb.shape[:2]

    # Resize to the square model resolution if needed
    if H_orig != image_size or W_orig != image_size:
        rgb = cv2.resize(rgb, (image_size, image_size), interpolation=cv2.INTER_LINEAR)

    # Convert to CHW tensor and normalize for DINOv2 (ImageNet statistics)
    rgb_tensor = torch.from_numpy(rgb).permute(2, 0, 1).float()  # (3, H, W)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    rgb_tensor = (rgb_tensor - mean) / std

    # Keep an unnormalized [0, 1] copy for plotting
    rgb_vis = rgb.copy()

    return rgb_tensor, rgb_vis


def main():
    parser = argparse.ArgumentParser(description='Live test trajectory predictions with IK from camera stream')
    parser.add_argument('--ask_for_write', action='store_true',
                       help='Ask for user input to write to robot')
    parser.add_argument('--checkpoint', type=str, default=CHECKPOINT_PATH,
                       help='Path to model checkpoint')
    parser.add_argument('--camera', type=int, default=0,
                       help='Camera device ID (default: 0)')
    parser.add_argument('--num_ik_iters', type=int, default=50,
                       help='Number of IK iterations')
    parser.add_argument('--no_arm', action='store_true',
                       help='No arm connected, use image-based estimation')
    parser.add_argument('--exo', type=str, default="so100_adhesive", 
                       choices=list(EXOSKELETON_CONFIGS.keys()), 
                       help="Exoskeleton configuration to use")
    parser.add_argument('--dont_visualize', action='store_true',
                       help='Skip all visualizations (MuJoCo rendering and matplotlib plotting) to speed up inference')
    parser.add_argument('--no_render', action='store_true',
                       help='Skip MuJoCo rendering but keep matplotlib visualizations (faster than full visualization)')
    args = parser.parse_args()
    
    # Setup device
    device = torch.device("mps" if torch.backends.mps.is_available() else 
                         "cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Load model
    print(f"\nLoading model from {args.checkpoint}...")
    model = TrajectoryHeatmapPredictor(target_size=IMAGE_SIZE, n_window=N_WINDOW, freeze_backbone=False)
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()
    # Load MIN_HEIGHT, MAX_HEIGHT, MIN_GRIPPER, MAX_GRIPPER from checkpoint if available
    import model as model_module
    if 'min_height' in checkpoint and 'max_height' in checkpoint:
        model_module.MIN_HEIGHT = checkpoint['min_height']
        model_module.MAX_HEIGHT = checkpoint['max_height']
        print(f"✓ Loaded height range from checkpoint: [{checkpoint['min_height']:.6f}, {checkpoint['max_height']:.6f}] m")
        print(f"  ({checkpoint['min_height']*1000:.2f}mm to {checkpoint['max_height']*1000:.2f}mm)")
    else:
        # Checkpoint doesn't have height range - compute from dataset if available
        print(f"⚠ Checkpoint doesn't have min_height/max_height!")
        print(f"  Using hardcoded values: MIN_HEIGHT={model_module.MIN_HEIGHT:.6f}, MAX_HEIGHT={model_module.MAX_HEIGHT:.6f}")
        if model_module.MIN_HEIGHT == model_module.MAX_HEIGHT:
            print(f"  ⚠ WARNING: MIN_HEIGHT == MAX_HEIGHT! All height predictions will be constant!")
            print(f"  Please recompute from dataset or use a newer checkpoint with height range saved.")
    if 'min_gripper' in checkpoint and 'max_gripper' in checkpoint:
        model_module.MIN_GRIPPER = checkpoint['min_gripper']
        model_module.MAX_GRIPPER = checkpoint['max_gripper']
        print(f"✓ Loaded gripper range from checkpoint: [{checkpoint['min_gripper']:.6f}, {checkpoint['max_gripper']:.6f}]")
        print(f"  (Regression: [-0.2, 0.8], values >4.0 mapped to -0.2, values >0.8 clamped to 0.8)")
    else:
        print(f"⚠ Checkpoint doesn't have min_gripper/max_gripper!")
        print(f"  Using default regression values: MIN_GRIPPER={model_module.MIN_GRIPPER:.6f}, MAX_GRIPPER={model_module.MAX_GRIPPER:.6f}")
    print(f"✓ Loaded model from epoch {checkpoint['epoch']}")
    
    # Initialize robot if using direct robot state
    arm = None
    use_robot_state = True
    calib_path = "robot_models/arm_offsets/rescrew2_fromimg.pkl"
    if os.path.exists(calib_path) and not args.no_arm:
        arm = Arm(pickle.load(open(calib_path, 'rb')))
        print("✓ Connected to robot for direct joint state reading")
    else:
        print(f"⚠️ Calibration file not found at {calib_path}, falling back to image-based estimation")
        use_robot_state = False

    start_pos=np.array([0.04078509, 4.33910383, 1.71240746, 1.58195613, 1.54817129, 0.07400006])
    last_pos=arm.get_pos()
    arm.write_pos(start_pos,slow=False)
    while True: # keep writing until the position is reached
        curr_pos=arm.get_pos()
        if np.max(np.abs(curr_pos-last_pos))<0.01: break
        last_pos=curr_pos
    print("done moving to high position")
    
    # Setup MuJoCo
    robot_config = SO100AdhesiveConfig()
    mj_model = mujoco.MjModel.from_xml_string(robot_config.xml)
    mj_data = mujoco.MjData(mj_model)
    
    # Initialize camera
    cap = cv2.VideoCapture(args.camera)
    if not cap.isOpened():
        raise RuntimeError(f"Failed to open camera device {args.camera}")
    
    # Get first frame to determine resolution
    ret, frame = cap.read()
    while not ret:
        ret, frame = cap.read()
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    height, width = rgb.shape[:2]
    print(f"Camera resolution: {width}x{height}")
    
    cam_K = None
    
    print("\n" + "="*60)
    print("Live inference started. Press 'q' to quit.")
    print("="*60)
    
    # Track previous figure for cleanup
    main.prev_fig = None
    
    first_write=True
    try:
        while True:
            frame_start_time = time.time()
            
            # Capture frame
            ret, frame = cap.read()
            if not ret:
                print("Failed to read frame from camera")
                continue
            
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            
            # Get joint states and camera pose
            camera_pose_world = None
            if use_robot_state and not args.no_arm:
                joint_state = arm.get_pos()
                mj_data.qpos[:] = mj_data.ctrl[:] = joint_state
                mujoco.mj_forward(mj_model, mj_data)

                
            
            # Detect camera pose and intrinsics
            try:
                link_poses, camera_pose_world, cam_K, corners_cache, corners_vis, obj_img_pts = detect_and_set_link_poses(
                    rgb, mj_model, mj_data, robot_config, cam_K=cam_K
                )
                position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
                
                if not use_robot_state:
                    configuration, _ = estimate_robot_state(mj_model, mj_data, robot_config, link_poses, ik_iterations=35)
                    mj_data.qpos[:] = mj_data.ctrl[:] = configuration.q
                    mujoco.mj_forward(mj_model, mj_data)
            except Exception as e:
                print(f"Error detecting link poses: {e}")
                camera_pose_world = None
                continue

            if 0:
                rendered_img_gt = render_from_camera_pose(mj_model, mj_data, camera_pose_world, cam_K, rgb.shape[0], rgb.shape[1], segmentation=False)/255
                rendered_img_gt = cv2.resize(rendered_img_gt, (width, height), interpolation=cv2.INTER_LINEAR)
                rgb_vis = rgb.copy()/255
                rgb_vis = rgb_vis * 0.5 + rendered_img_gt * 0.5
                plt.imshow(rgb_vis)
                plt.show()
            
            if camera_pose_world is None or cam_K is None:
                print("⚠️ Could not detect camera pose or intrinsics, skipping frame")
                continue
            
            # Preprocess image for model
            rgb_tensor, rgb_vis_resized = preprocess_image(rgb.copy(), IMAGE_SIZE)
            
            # Get starting keypoint 2D position from current gripper pose
            fixed_gripper_pose = get_link_poses_from_robot(robot_config, mj_model, mj_data)["fixed_gripper"]
            gripper_rot = fixed_gripper_pose[:3, :3]
            gripper_pos = fixed_gripper_pose[:3, 3]
            kp_3d_start = gripper_rot @ kp_local + gripper_pos
            
            # Project starting keypoint to 2D using original image resolution
            kp_2d_start = project_3d_to_2d(kp_3d_start, camera_pose_world, cam_K)
            if kp_2d_start is None:
                print("⚠️ Starting keypoint behind camera, skipping frame")
                continue
            
            # Scale to model input resolution
            scale_x = IMAGE_SIZE / width
            scale_y = IMAGE_SIZE / height
            start_keypoint_2d = kp_2d_start * np.array([scale_x, scale_y])
            
            # Get current height from starting keypoint 3D position (z-coordinate)
            current_height = kp_3d_start[2]  # meters
            
            # Get current gripper value from joint state (last dimension) and process
            joint_state = arm.get_pos()
            current_gripper_raw = float(joint_state[-1])  # Last value is gripper
            
            # Process gripper value: map >4.0 to -0.2, clamp to [-0.2, 0.8] (same as data.py)
            from data import process_gripper_value
            current_gripper = process_gripper_value(current_gripper_raw)
            
            # Run model inference
            with torch.no_grad():
                rgb_batch = rgb_tensor.unsqueeze(0).to(device)
                start_keypoint_tensor = torch.from_numpy(start_keypoint_2d).float().to(device)
                current_height_tensor = torch.tensor(current_height, dtype=torch.float32).to(device)
                current_gripper_tensor = torch.tensor(current_gripper, dtype=torch.float32).to(device)
                pred_logits, pred_height, pred_gripper = model(
                    rgb_batch, 
                    gt_target_heatmap=None, 
                    training=False,
                    start_keypoint_2d=start_keypoint_tensor,
                    current_height=current_height_tensor,
                    current_gripper=current_gripper_tensor
                )
            
            # Get predicted trajectories and heatmaps
            pred_trajectory_2d = []
            pred_heatmaps = []
            for t in range(N_WINDOW):
                pred_logits_t = pred_logits[0, t]  # (H, W)
                pred_probs_t = F.softmax(pred_logits_t.view(-1), dim=0).view_as(pred_logits_t).cpu().numpy()
                pred_heatmaps.append(pred_probs_t)  # Store heatmap for visualization
                pred_y, pred_x = np.unravel_index(pred_probs_t.argmax(), pred_probs_t.shape)
                pred_2d_t = np.array([pred_x, pred_y], dtype=np.float32)
                pred_trajectory_2d.append(pred_2d_t)
            
            pred_trajectory_2d = np.array(pred_trajectory_2d)  # (N_WINDOW, 2)
            pred_heatmaps = np.array(pred_heatmaps)  # (N_WINDOW, H, W)
            pred_height = pred_height[0].cpu().numpy()  # (N_WINDOW,)
            pred_gripper = pred_gripper[0].cpu().numpy()  # (N_WINDOW,)
            
            # Scale intrinsics for model resolution (448x448) - exactly like test_model_ik.py
            # First normalize intrinsics (like the dataset does)
            cam_K_norm = cam_K.copy()
            cam_K_norm[0] /= width   # Normalize fx, cx by width
            cam_K_norm[1] /= height  # Normalize fy, cy by height
            
            # Then denormalize to IMAGE_SIZE (exactly like test_model_ik.py)
            cam_K_model = cam_K_norm.copy()
            cam_K_model[0] *= IMAGE_SIZE  # Denormalize fx, cx to IMAGE_SIZE
            cam_K_model[1] *= IMAGE_SIZE  # Denormalize fy, cy to IMAGE_SIZE
            
            # Debug: Print intrinsics to verify they match test_model_ik.py format
            if not hasattr(main, 'printed_intrinsics'):
                print(f"\nDebug - Intrinsics scaling:")
                print(f"  Original cam_K (at {width}x{height}): fx={cam_K[0,0]:.2f}, fy={cam_K[1,1]:.2f}, cx={cam_K[0,2]:.2f}, cy={cam_K[1,2]:.2f}")
                print(f"  Normalized cam_K_norm: fx={cam_K_norm[0,0]:.4f}, fy={cam_K_norm[1,1]:.4f}, cx={cam_K_norm[0,2]:.4f}, cy={cam_K_norm[1,2]:.4f}")
                print(f"  Scaled cam_K_model (at {IMAGE_SIZE}x{IMAGE_SIZE}): fx={cam_K_model[0,0]:.2f}, fy={cam_K_model[1,1]:.2f}, cx={cam_K_model[0,2]:.2f}, cy={cam_K_model[1,2]:.2f}")
                main.printed_intrinsics = True
            
            # Lift trajectory to 3D
            pred_trajectory_3d = []
            for t in range(N_WINDOW):
                pred_2d_t = pred_trajectory_2d[t]
                pred_height_t = pred_height[t]
                pred_3d_t = recover_3d_from_direct_keypoint_and_height(
                    pred_2d_t, pred_height_t, camera_pose_world, cam_K_model
                )
                
                if pred_3d_t is None:
                    print(f"✗ Failed to lift waypoint {t} to 3D")
                    continue
                
                # Validate 3D target: check if it's within reasonable workspace bounds
                # Rough workspace bounds (adjust based on your robot)
                workspace_bounds = {
                    'x': [-0.5, 0.5],  # meters
                    'y': [-0.5, 0.0],  # meters (negative y is forward)
                    'z': [0.0, 0.3]    # meters (height above ground)
                }
                
                if not (workspace_bounds['x'][0] <= pred_3d_t[0] <= workspace_bounds['x'][1] and
                        workspace_bounds['y'][0] <= pred_3d_t[1] <= workspace_bounds['y'][1] and
                        workspace_bounds['z'][0] <= pred_3d_t[2] <= workspace_bounds['z'][1]):
                    print(f"⚠️ Waypoint {t} 3D target out of workspace: {pred_3d_t}")
                    # Use previous waypoint if available, otherwise skip
                    if len(pred_trajectory_3d) > 0:
                        pred_3d_t = pred_trajectory_3d[-1].copy()
                        print(f"  Using previous waypoint: {pred_3d_t}")
                    else:
                        continue
                
                pred_trajectory_3d.append(pred_3d_t)
            
            if len(pred_trajectory_3d) == 0:
                print("⚠️ Could not lift any waypoints to 3D, skipping frame")
                continue
            
            pred_trajectory_3d = np.array(pred_trajectory_3d)  # (N, 3)
            
            # Insert intermediate waypoints: X,Z from target, height = max(start, end)
            print("\nInserting intermediate waypoints (X,Z from target, height=max(start,end))...")
            pred_trajectory_3d_expanded = insert_intermediate_waypoints(pred_trajectory_3d)
            n_waypoints = len(pred_trajectory_3d_expanded)
            print(f"✓ Expanded trajectory from {len(pred_trajectory_3d)} to {n_waypoints} waypoints")
            
            # Expand gripper array to match expanded trajectory
            # For original waypoints: use corresponding pred_gripper[t]
            # For intermediate waypoints: use gripper from the start point of that segment (previous original waypoint)
            pred_gripper_expanded = []
            n_original = len(pred_trajectory_3d)
            for i in range(n_waypoints):
                if i % 2 == 0:
                    # Original waypoint: use corresponding gripper value
                    orig_idx = i // 2
                    if orig_idx < len(pred_gripper):
                        pred_gripper_expanded.append(pred_gripper[orig_idx])
                    else:
                        # Fallback to last gripper value if index out of bounds
                        pred_gripper_expanded.append(pred_gripper[-1])
                else:
                    # Intermediate waypoint: use gripper from start of segment (previous original waypoint)
                    orig_idx = (i - 1) // 2
                    if orig_idx < len(pred_gripper):
                        pred_gripper_expanded.append(pred_gripper[orig_idx])
                    else:
                        # Fallback to first gripper value
                        pred_gripper_expanded.append(pred_gripper[0])
            pred_gripper_expanded = np.array(pred_gripper_expanded)
            print(f"✓ Expanded gripper array from {len(pred_gripper)} to {len(pred_gripper_expanded)} values")
            
            # Setup IK configuration - reset to current robot state before IK (exactly like test_model_ik.py)
            # Get fresh robot state to ensure mj_data is in correct state
            if use_robot_state and not args.no_arm:
                current_joint_state = arm.get_pos()
                mj_data.qpos[:len(mj_data.ctrl)] = current_joint_state[:len(mj_data.ctrl)]
                mj_data.ctrl[:] = current_joint_state[:len(mj_data.ctrl)]
            # If using image-based estimation, mj_data.qpos should already be set correctly above
            
            # Forward kinematics to update robot state (like test_model_ik.py)
            mujoco.mj_forward(mj_model, mj_data)
            
            # Create fresh configuration (like test_model_ik.py)
            configuration = mink.Configuration(mj_model)
            configuration.update(mj_data.qpos)
            
            # Perform sequential IK for each timestep (using expanded trajectory)
            ik_trajectory_qpos = []
            ik_errors = []
            current_qpos = mj_data.qpos.copy()  # Start from current robot state
            
            for t in range(n_waypoints):
                pred_3d_t = pred_trajectory_3d_expanded[t]
                
                # Set current joint state as starting point for this timestep's IK
                mj_data.qpos[:] = current_qpos[:]
                mj_data.ctrl[:] = current_qpos[:len(mj_data.ctrl)]
                mujoco.mj_forward(mj_model, mj_data)
                configuration.update(mj_data.qpos)
                
                # Perform IK to this timestep's 3D position
                optimized_qpos, ik_error = ik_to_cube_grasp(
                    mj_model,
                    mj_data,
                    configuration, 
                    pred_3d_t,
                    initial_kp_quat_wxyz,
                    num_iterations=args.num_ik_iters
                )
                
                # If IK error is too large, use previous timestep's result instead
                max_ik_error = 0.2  # 200mm threshold
                if ik_error > max_ik_error:
                    print(f"  ⚠️ t={t}: IK error too large ({ik_error*1000:.2f} mm), using previous timestep's result")
                    if len(ik_trajectory_qpos) > 0:
                        optimized_qpos = ik_trajectory_qpos[-1].copy()
                        # Recompute error with previous position
                        mj_data.qpos[:] = optimized_qpos[:]
                        mujoco.mj_forward(mj_model, mj_data)
                        kp_body_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, 'virtual_gripper_keypoint')
                        current_kp_pos = mj_data.xpos[kp_body_id]
                        ik_error = np.linalg.norm(pred_3d_t - current_kp_pos)
                    else:
                        # First timestep failed - keep current position
                        optimized_qpos = current_qpos.copy()
                        ik_error = np.linalg.norm(pred_3d_t - mj_data.xpos[mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, 'virtual_gripper_keypoint')])
                
                ik_trajectory_qpos.append(optimized_qpos.copy())
                ik_errors.append(ik_error)
                
                # Use this result as starting point for next timestep
                current_qpos = optimized_qpos.copy()
                
                waypoint_type = "intermediate" if (t % 2 == 1) and (t < n_waypoints - 1) else "original"
                print(f"  t={t} ({waypoint_type}): 3D=[{pred_3d_t[0]:.4f}, {pred_3d_t[1]:.4f}, {pred_3d_t[2]:.4f}] m, IK err={ik_error*1000:.2f} mm")
            
            ik_trajectory_qpos = np.array(ik_trajectory_qpos)  # (N_WINDOW, nq)
            ik_errors = np.array(ik_errors)  # (N_WINDOW,)
            
            print(f"✓ Sequential IK complete - Avg error: {ik_errors.mean()*1000:.2f} mm")
            
            # Visualization section - skip if --dont_visualize is set
            if not args.dont_visualize:
                # Resize rgb_vis_resized back to original for visualization
                rgb_vis = cv2.resize(rgb_vis_resized, (width, height), interpolation=cv2.INTER_LINEAR)
                
                # Render final timestep's IK result for row 1, last column (skip if --no_render)
                if not args.no_render:
                    mj_data.qpos[:6] = ik_trajectory_qpos[-1][:6]
                    mujoco.mj_forward(mj_model, mj_data)
                    render_res = [height // 2, width // 2]
                    render_cam_K = cam_K_norm.copy()
                    render_cam_K[0] *= render_res[1]
                    render_cam_K[1] *= render_res[0]
                    rendered_img_final = render_from_camera_pose(
                        mj_model, mj_data, camera_pose_world, render_cam_K, 
                        render_res[0], render_res[1], segmentation=False
                    ) / 255
                    rendered_img_final = cv2.resize(rendered_img_final, (width, height), interpolation=cv2.INTER_LINEAR)
                    overlay_final = rgb_vis * 0.5 + rendered_img_final * 0.5
                else:
                    # Skip rendering, just use original RGB
                    overlay_final = rgb_vis
                
                # Create visualization with 5 rows: trajectory, heatmaps, height bars, gripper bars, IK renders
                # Larger figure size with less whitespace
                fig = plt.figure(figsize=(6*(N_WINDOW + 1), 20))
                gs = GridSpec(5, N_WINDOW + 1, figure=fig, height_ratios=[3, 3, 1.5, 1.5, 3], hspace=0.1, wspace=0.1)
                
                # Resize rgb_vis_resized for trajectory visualization (model resolution)
                rgb_vis_traj = rgb_vis_resized.copy()

                # Row 1: Trajectory visualization with crosshairs and lines
                colors = plt.cm.viridis(np.linspace(0, 1, N_WINDOW))
                for t in range(N_WINDOW):
                    ax = fig.add_subplot(gs[0, t])
                    ax.imshow(rgb_vis_traj)
                    
                    # Plot trajectory up to this timestep
                    if t > 0:
                        for t_prev in range(t):
                            ax.plot([pred_trajectory_2d[t_prev, 0], pred_trajectory_2d[t_prev+1, 0]], 
                                   [pred_trajectory_2d[t_prev, 1], pred_trajectory_2d[t_prev+1, 1]], 
                                   '-', color='lime', linewidth=2, alpha=0.5)
                    
                    # Plot current timestep keypoint with crosshair
                    ax.scatter(pred_trajectory_2d[t, 0], pred_trajectory_2d[t, 1], c=colors[t], s=100, 
                              marker='x', linewidths=3, label='Pred', zorder=10)
                    # Add crosshair lines
                    ax.axvline(x=pred_trajectory_2d[t, 0], color=colors[t], linestyle='--', linewidth=1, alpha=0.5)
                    ax.axhline(y=pred_trajectory_2d[t, 1], color=colors[t], linestyle='--', linewidth=1, alpha=0.5)
                    
                    height_val = pred_height[t] * 1000
                    gripper_val = pred_gripper[t]  # Regression value in [MIN_GRIPPER, MAX_GRIPPER]
                    ax.set_title(f"t={t}\nH: {height_val:.2f}mm | G: {gripper_val:.2f}", fontsize=8)
                    ax.axis('off')
                    if t == 0:
                        ax.legend(loc='upper right', fontsize=6)
                
                # Row 1, last column: Rendered robot at final predicted 3D IK result
                ax_ik = fig.add_subplot(gs[0, N_WINDOW])
                ax_ik.imshow(overlay_final)
                ax_ik.set_title(f"Final IK (t={len(ik_trajectory_qpos)-1})\nErr: {ik_errors[-1]*1000:.1f}mm", fontsize=10)
                ax_ik.axis('off')
                
                # Row 2: Heatmap visualization — one predicted keypoint heatmap per timestep.
                for t in range(N_WINDOW):
                    ax = fig.add_subplot(gs[1, t])
                    # Show RGB with heatmap overlay
                    #ax.imshow(rgb_vis_traj)
                    # Raw heatmap scaled by 1e2 before display — presumably to boost
                    # contrast of small probability-like values in the 'hot' colormap
                    # (imshow autoscales, so this mainly interacts with alpha blending) — TODO confirm.
                    ax.imshow(pred_heatmaps[t]*1e2, alpha=0.9, cmap='hot', interpolation='bilinear')
                    # Mark predicted location
                    #ax.scatter(pred_trajectory_2d[t, 0], pred_trajectory_2d[t, 1], c='cyan', s=50, marker='x', linewidths=2, zorder=10)
                    ax.set_title(f"Heatmap t={t}", fontsize=8)
                    ax.axis('off')
                
                # Row 2, last column: Average heatmap or empty
                # (currently intentionally left blank; axis created only to keep the grid aligned).
                ax_heatmap_ik = fig.add_subplot(gs[1, N_WINDOW])
                ax_heatmap_ik.axis('off')
                
                # Row 3: Height trajectory bar plots — per-timestep predicted height (mm)
                # shown between the MIN_HEIGHT/MAX_HEIGHT bounds imported from model.py.
                for t in range(N_WINDOW):
                    ax = fig.add_subplot(gs[2, t])
                    # Bar plot showing min, max, and current height (meters -> mm via x1000)
                    heights_to_show = [MIN_HEIGHT * 1000, pred_height[t] * 1000, MAX_HEIGHT * 1000]
                    colors_bar = ['gray', colors[t], 'gray']
                    x_pos = [0, 1, 2]
                    bars = ax.bar(x_pos, heights_to_show, color=colors_bar, alpha=0.7)
                    bars[1].set_alpha(1.0)  # Make current timestep more prominent
                    ax.set_xticks(x_pos)
                    ax.set_xticklabels(['Min', f't={t}', 'Max'], fontsize=6, rotation=45, ha='right')
                    ax.set_ylabel('Height (mm)', fontsize=7)
                    ax.set_title(f"H: {pred_height[t]*1000:.2f}mm", fontsize=7)
                    ax.grid(alpha=0.3, axis='y')
                    # Fixed y-limits (+/-5 mm margin) so all panels are visually comparable.
                    ax.set_ylim([MIN_HEIGHT * 1000 - 5, MAX_HEIGHT * 1000 + 5])
                
                # Row 3, last column: Height and Gripper trajectory over all timesteps
                ax_height_traj = fig.add_subplot(gs[2, N_WINDOW])
                timesteps = np.arange(N_WINDOW)
                ax_height_traj.plot(timesteps, pred_height * 1000, 'o-', color='blue', linewidth=2, markersize=6, label='Height (mm)')
                ax_height_traj.axhline(y=MIN_HEIGHT * 1000, color='gray', linestyle='--', linewidth=1, alpha=0.5, label='Min H')
                ax_height_traj.axhline(y=MAX_HEIGHT * 1000, color='gray', linestyle='--', linewidth=1, alpha=0.5, label='Max H')
                ax_height_traj.set_xlabel('Timestep', fontsize=8)
                ax_height_traj.set_ylabel('Height (mm)', fontsize=8)
                ax_height_traj.set_title('Height Trajectory', fontsize=8)
                ax_height_traj.grid(alpha=0.3)
                ax_height_traj.legend(fontsize=6)
                
                # Add gripper trajectory on secondary y-axis (regression), sharing the x-axis.
                # NOTE(review): the name `ax_gripper_traj` is reassigned again for row 4's
                # last column further below; the two axes are independent, but the reuse is confusing.
                ax_gripper_traj = ax_height_traj.twinx()
                ax_gripper_traj.plot(timesteps, pred_gripper, 's-', color='red', linewidth=2, markersize=6, label='Gripper')
                ax_gripper_traj.axhline(y=MIN_GRIPPER, color='red', linestyle='--', linewidth=1, alpha=0.5, label='Min G')
                ax_gripper_traj.axhline(y=MAX_GRIPPER, color='green', linestyle='--', linewidth=1, alpha=0.5, label='Max G')
                ax_gripper_traj.set_ylabel('Gripper', fontsize=8, color='red')
                ax_gripper_traj.tick_params(axis='y', labelcolor='red')
                ax_gripper_traj.set_ylim([MIN_GRIPPER - 0.1, MAX_GRIPPER + 0.1])
                ax_gripper_traj.legend(loc='upper right', fontsize=6)
                
                # Row 4: Gripper trajectory bar plots — per-timestep gripper regression value
                # against the MIN_GRIPPER baseline (bounds imported from model.py).
                for t in range(N_WINDOW):
                    ax = fig.add_subplot(gs[3, t])
                    # Bar plot showing gripper regression value
                    gripper_val = pred_gripper[t]  # Regression value in [MIN_GRIPPER, MAX_GRIPPER]
                    x_pos = [0, 1]
                    gripper_vals = [MIN_GRIPPER, gripper_val]  # Show min and prediction
                    colors_bar = ['gray', colors[t]]
                    bars = ax.bar(x_pos, gripper_vals, color=colors_bar, alpha=0.7)
                    bars[1].set_alpha(1.0)  # Make current timestep more prominent
                    ax.set_xticks(x_pos)
                    ax.set_xticklabels(['Min', f'{gripper_val:.2f}'], fontsize=6, rotation=45, ha='right')
                    ax.set_ylabel('Gripper', fontsize=7)
                    ax.set_title(f"G: {gripper_val:.2f}", fontsize=7)
                    ax.grid(alpha=0.3, axis='y')
                    # Fixed y-limits (+/-0.1 margin) so all panels are visually comparable.
                    ax.set_ylim([MIN_GRIPPER - 0.1, MAX_GRIPPER + 0.1])
                
                # Row 4, last column: Gripper trajectory over all timesteps.
                # NOTE(review): this duplicates the red gripper curve already drawn on
                # row 3's secondary axis, just without the height overlay.
                ax_gripper_traj = fig.add_subplot(gs[3, N_WINDOW])
                timesteps = np.arange(N_WINDOW)
                ax_gripper_traj.plot(timesteps, pred_gripper, 'o-', color='red', linewidth=2, markersize=6)
                ax_gripper_traj.axhline(y=MIN_GRIPPER, color='red', linestyle='--', linewidth=1, alpha=0.5, label='Min G')
                ax_gripper_traj.axhline(y=MAX_GRIPPER, color='green', linestyle='--', linewidth=1, alpha=0.5, label='Max G')
                ax_gripper_traj.set_xlabel('Timestep', fontsize=8)
                ax_gripper_traj.set_ylabel('Gripper', fontsize=8)
                ax_gripper_traj.set_title('Gripper Trajectory', fontsize=8)
                ax_gripper_traj.grid(alpha=0.3)
                ax_gripper_traj.legend(fontsize=6)
                ax_gripper_traj.set_ylim([MIN_GRIPPER - 0.1, MAX_GRIPPER + 0.1])
                
                # Row 5: IK renders for each timestep (skip rendering if --no_render)
                # Limit to N_WINDOW to fit in grid (last column reserved for error trajectory)
                n_ik_renders = min(len(ik_trajectory_qpos), N_WINDOW)
                for t in range(n_ik_renders):
                    ax = fig.add_subplot(gs[4, t])
                    
                    if not args.no_render:
                        # Set robot to this timestep's IK joint state.
                        # NOTE(review): this mutates mj_data in place — after the loop,
                        # mj_data is left at the LAST rendered timestep's pose.
                        mj_data.qpos[:6] = ik_trajectory_qpos[t][:6]
                        mujoco.mj_forward(mj_model, mj_data)
                        
                        # Render robot at this timestep's IK result.
                        # Render at half resolution then upscale — presumably to keep the
                        # per-frame render cheap; TODO confirm this trade-off is intended.
                        render_res = [height // 2, width // 2]
                        # Scale the normalized intrinsics to the render resolution:
                        # row 0 (x/fx/cx) by width, row 1 (y/fy/cy) by height.
                        render_cam_K = cam_K_norm.copy()
                        render_cam_K[0] *= render_res[1]
                        render_cam_K[1] *= render_res[0]
                        rendered_img_t = render_from_camera_pose(
                            mj_model, mj_data, camera_pose_world, render_cam_K, 
                            render_res[0], render_res[1], segmentation=False
                        ) / 255
                        rendered_img_t = cv2.resize(rendered_img_t, (width, height), interpolation=cv2.INTER_LINEAR)
                        
                        # Create overlay: 50/50 blend of the camera frame and the render.
                        overlay_t = rgb_vis * 0.5 + rendered_img_t * 0.5
                    else:
                        # Skip rendering, just use original RGB
                        overlay_t = rgb_vis
                    
                    ax.imshow(overlay_t)
                    ax.set_title(f"IK t={t}\nErr: {ik_errors[t]*1000:.1f}mm", fontsize=8)
                    ax.axis('off')
                
                # Row 5, last column: IK error trajectory (meters -> mm via x1000)
                ax_ik_traj = fig.add_subplot(gs[4, N_WINDOW])
                timesteps = np.arange(len(ik_errors))
                ax_ik_traj.plot(timesteps, ik_errors * 1000, 'o-', color='red', linewidth=2, markersize=6)
                ax_ik_traj.set_xlabel('Timestep', fontsize=8)
                ax_ik_traj.set_ylabel('IK Error (mm)', fontsize=8)
                ax_ik_traj.set_title('IK Error Trajectory', fontsize=8)
                ax_ik_traj.grid(alpha=0.3)
                # Keep at least a 10 mm ceiling so tiny errors don't produce a degenerate axis.
                ax_ik_traj.set_ylim([0, max(ik_errors.max() * 1000 * 1.2, 10)])
                
                plt.suptitle("Live Trajectory Prediction → 3D Lifting → Sequential IK", fontsize=14, fontweight='bold')
                plt.subplots_adjust(left=0.02, right=0.98, top=0.96, bottom=0.02, hspace=0.1, wspace=0.1)
                # Non-blocking redraw; pause lets the GUI event loop process the new figure.
                plt.draw()
                plt.pause(0.01)

                # Close previous figure to avoid memory issues. The handle is stashed as
                # an attribute on the `main` function object (a poor-man's static variable),
                # so the previous frame's figure is closed only after the new one is drawn.
                if hasattr(main, 'prev_fig'):
                    plt.close(main.prev_fig)
                main.prev_fig = fig
                
                # Check for quit: waitforbuttonpress returns True on a KEY press,
                # False on a mouse click, and None on timeout — so only a key press
                # (truthy) breaks out of the loop here.
                if plt.waitforbuttonpress(timeout=0.01):
                    break
            
            # Debug: print predicted gripper values (actual robot execution happens
            # below, gated by first_write / --ask_for_write — not here).
            print(pred_gripper)
            
            
            # Maintain framerate: sleep off whatever remains of the 0.1 s budget.
            elapsed = time.time() - frame_start_time
            sleep_time = max(0, 0.1 - elapsed)  # ~10 fps
            if sleep_time > 0:
                time.sleep(sleep_time)
            
            # Debug: current joint positions before (possibly) writing a new trajectory.
            print(arm.get_pos())
            # first_write gate: the very first predicted trajectory after startup is
            # never executed (first_write is cleared at the bottom of this iteration).
            if not first_write:
                # With --ask_for_write, require the operator to type exactly 'y'
                # (any other input, including 'Y' or 'yes', skips execution).
                if not args.ask_for_write or input("Write to robot? (y/n): ")== 'y':#not first_write:
                #if 1:#not first_write:
                    # NOTE(review): [:] is a shallow copy of the list — mutating
                    # targ_pos[-1] below still writes through to ik_trajectory_qpos' elements.
                    for traj_i,targ_pos in enumerate(ik_trajectory_qpos[:]):
                        print(targ_pos)
                        # Use expanded gripper array that matches the expanded trajectory
                        if traj_i < len(pred_gripper_expanded):
                            pred_grip = pred_gripper_expanded[traj_i]
                        else:
                            # Fallback to last gripper value if index out of bounds
                            pred_grip = pred_gripper_expanded[-1] if len(pred_gripper_expanded) > 0 else pred_gripper[-1]
                        #if pred_grip<.7: pred_grip=-.2
                        #else: pred_grip=1
                        # Overwrite the last joint (gripper) with the predicted gripper value.
                        targ_pos[-1]=pred_grip
                        arm.write_pos(targ_pos,slow=False)
                        # Poll until the arm stops moving (two successive reads nearly identical).
                        # NOTE(review): busy-wait with no timeout — can spin forever if the
                        # arm stalls. Also assumes `last_pos` was initialized earlier in
                        # main before this loop runs — TODO confirm.
                        while True: # keep writing until the position is reached
                            curr_pos=arm.get_pos()
                            delta=np.max(np.abs(curr_pos-last_pos))
                            print("still moving",delta)
                            if delta<0.06: break
                            last_pos=curr_pos
                        print("done moving to high position")
            first_write=False


        
    
    except KeyboardInterrupt:
        # Graceful Ctrl-C exit from the live loop; cleanup happens in finally below.
        print("\nLive inference interrupted by user")
    
    finally:
        # Always release the camera capture; only tear down matplotlib windows
        # if visualization was enabled in the first place.
        cap.release()
        if not args.dont_visualize:
            plt.close('all')
        print("\n✓ Live inference complete!")


# Script entry point: run the live prediction/IK/robot-execution loop.
if __name__ == "__main__":
    main()
