"""Live test trajectory predictions with 3D lifting and IK visualization from camera stream."""
import sys
import os
from pathlib import Path
import torch
import torch.nn.functional as F
import cv2
import numpy as np
import matplotlib.pyplot as plt
import math
import argparse
import mujoco
import mink
import pickle
import time
from scipy.spatial.transform import Rotation as R

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.dirname(__file__))

from data import N_WINDOW, project_3d_to_2d
from model import TrajectoryHeatmapPredictor, N_HEIGHT_BINS, N_GRIPPER_BINS
import model as model_module  # Import module to access updated MIN_HEIGHT/MAX_HEIGHT at runtime
from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from ExoConfigs import EXOSKELETON_CONFIGS
from exo_utils import (
    get_link_poses_from_robot, 
    position_exoskeleton_meshes, 
    render_from_camera_pose,
    detect_and_set_link_poses,
    estimate_robot_state
)
from robot_models.so100_controller import Arm

try:
    from torchvision.utils import make_grid as tv_make_grid  # type: ignore
except Exception:
    tv_make_grid = None

# Configuration
# Side length, in pixels, of the square image handed to the model.
IMAGE_SIZE = 448

# Candidate virtual gripper keypoints expressed in the local gripper frame
# (from dataset generation). The literals are divided by 1000 below, so the
# stored array is 1/1000 of the numbers written here.
KEYPOINTS_LOCAL_M_ALL = np.array([
    [13.25, -91.42, 15.9],
    [10.77, -99.6, 0],
    [13.25, -91.42, -15.9],
    [17.96, -83.96, 0],
    [22.86, -70.46, 0],
]) / 1000.0
# Row of KEYPOINTS_LOCAL_M_ALL that is used as the tracked keypoint.
KP_INDEX = 3
kp_local = KEYPOINTS_LOCAL_M_ALL[KP_INDEX]

# Hard-coded grasp parameters from make_datasets.py
cube_to_gripper_offset = np.array([-0.04433927, -0.00871351, 0.01834697])
initial_kp_quat_wxyz = np.array(
    [-0.01718116, -0.00338855, 0.76436812, 0.64454224]
)


def extract_pred_2d_and_height_from_volume(volume_logits):
    """Select the best pixel and height for every timestep of a logit volume.

    For each (batch, timestep) slice the height axis is collapsed with a max,
    the highest-scoring pixel is located, and the height bin with the largest
    logit at that pixel is decoded to a metric height.

    Args:
        volume_logits: tensor of shape (B, N_WINDOW, N_HEIGHT_BINS, H, W).

    Returns:
        pred_2d: (B, N_WINDOW, 2) float32 tensor holding (x, y) pixel indices.
        pred_height: (B, N_WINDOW) tensor of heights, mapped linearly from the
            winning bin onto [MIN_HEIGHT, MAX_HEIGHT] as read from the model
            module at call time.
    """
    n_batch, n_steps, _, _, n_cols = volume_logits.shape
    dev = volume_logits.device

    # Best score over the height axis, then the best pixel of each 2D map.
    best_over_height = volume_logits.max(dim=2).values  # (B, N, H, W)
    flat_best = best_over_height.reshape(n_batch, n_steps, -1).argmax(dim=-1)  # (B, N)
    rows = flat_best // n_cols
    cols = flat_best % n_cols
    pred_2d = torch.stack((cols, rows), dim=-1).to(torch.float32)  # (B, N, 2) as (x, y)

    # Height-bin logits at the chosen pixel of each (batch, timestep) pair.
    b_sel = torch.arange(n_batch, device=dev)[:, None]
    t_sel = torch.arange(n_steps, device=dev)[None, :]
    height_bins = volume_logits[b_sel, t_sel, :, rows, cols].argmax(dim=-1)  # (B, N)

    # Bin index -> normalized [0, 1] -> [MIN_HEIGHT, MAX_HEIGHT].
    lo = model_module.MIN_HEIGHT
    hi = model_module.MAX_HEIGHT
    centers = torch.linspace(0.0, 1.0, N_HEIGHT_BINS, device=dev)
    pred_height = centers[height_bins] * (hi - lo) + lo
    return pred_2d, pred_height


def decode_gripper_bins(bin_logits):
    """Turn per-bin gripper logits into continuous gripper values.

    The bin with the largest logit is chosen, mapped to an evenly spaced value
    in [0, 1], and rescaled to [MIN_GRIPPER, MAX_GRIPPER]. Both limits are
    read from the model module at call time, so values overwritten after
    import (e.g. by checkpoint loading) are honored.

    Args:
        bin_logits: (B, N_WINDOW, N_GRIPPER_BINS) logits for each bin

    Returns:
        gripper_values: (B, N_WINDOW) continuous gripper values in
            [MIN_GRIPPER, MAX_GRIPPER]
    """
    lo = model_module.MIN_GRIPPER
    hi = model_module.MAX_GRIPPER
    winning_bins = bin_logits.argmax(dim=-1)  # (B, N_WINDOW)
    centers = torch.linspace(0.0, 1.0, N_GRIPPER_BINS, device=bin_logits.device)
    return centers[winning_bins] * (hi - lo) + lo


def extract_gripper_logits_at_pixels(gripper_logits, pixel_2d):
    """Gather the gripper-bin logits found at one pixel per timestep.

    Pixel coordinates are truncated to integers and clamped into the image,
    so out-of-range predictions read the nearest border pixel.

    Args:
        gripper_logits: (B, N_WINDOW, N_GRIPPER_BINS, H, W) per-pixel logits.
        pixel_2d: (B, N_WINDOW, 2) pixel coordinates ordered as (x, y).

    Returns:
        (B, N_WINDOW, N_GRIPPER_BINS) logits at the requested pixels.
    """
    n_batch, n_steps, _, n_rows, n_cols = gripper_logits.shape
    dev = gripper_logits.device

    cols = pixel_2d[..., 0].long().clamp(0, n_cols - 1)
    rows = pixel_2d[..., 1].long().clamp(0, n_rows - 1)

    # Broadcastable batch / time selectors, shaped (B, 1) and (1, N).
    b_sel = torch.arange(n_batch, device=dev)[:, None]
    t_sel = torch.arange(n_steps, device=dev)[None, :]

    # Move the bin axis last so one indexing call yields (B, N, bins).
    bins_last = gripper_logits.permute(0, 1, 3, 4, 2)
    return bins_last[b_sel, t_sel, rows, cols]


def unproject_2d_to_ray(point_2d, camera_pose, cam_K):
    """Unproject a 2D pixel to a camera ray expressed in world coordinates.

    Args:
        point_2d: (2,) 2D pixel coordinates (x, y)
        camera_pose: (4, 4) camera pose matrix (world-to-camera)
        cam_K: (3, 3) camera intrinsics

    Returns:
        cam_pos: (3,) camera position in world frame
        ray_direction: (3,) unit-length ray direction in world frame
    """
    # Invert the world-to-camera pose once; the camera position and the
    # camera-to-world rotation are both read from the same inverse.
    cam_to_world = np.linalg.inv(camera_pose)
    cam_pos_world = cam_to_world[:3, 3]
    cam_rot_world = cam_to_world[:3, :3]

    # Back-project the homogeneous pixel through the intrinsics (camera frame).
    pixel_h = np.array([point_2d[0], point_2d[1], 1.0])
    ray_cam = np.linalg.inv(cam_K) @ pixel_h

    # Rotate into the world frame and normalize to unit length.
    ray_world = cam_rot_world @ ray_cam
    ray_world = ray_world / np.linalg.norm(ray_world)

    return cam_pos_world, ray_world


def recover_3d_from_direct_keypoint_and_height(kp_2d_image, height, camera_pose, cam_K):
    """Lift a 2D keypoint to 3D by intersecting its ray with a height plane.

    The pixel is unprojected to a world-frame ray, and the point on that ray
    whose world z-coordinate equals ``height`` is returned.

    Args:
        kp_2d_image: (2,) 2D pixel coordinates
        height: height (z-coordinate in world frame)
        camera_pose: (4, 4) camera pose matrix
        cam_K: (3, 3) camera intrinsics

    Returns:
        (3,) 3D keypoint position, or None when the ray is (nearly) parallel
        to the height plane or the intersection lies behind the camera.
    """
    origin, direction = unproject_2d_to_ray(kp_2d_image, camera_pose, cam_K)

    # A ray with no z-component never reaches the plane z = height.
    dz = direction[2]
    if abs(dz) < 1e-6:
        return None

    # Ray parameter solving origin_z + travel * dz == height.
    travel = (height - origin[2]) / dz
    if travel < 0:
        return None  # Intersection is behind the camera

    return origin + travel * direction


def ik_to_cube_grasp(model, data, configuration, target_pos, target_kp_quat_wxyz, num_iterations=50):
    """Run iterative IK that drives the virtual gripper keypoint to ``target_pos``.

    Each iteration solves one differential-IK step with a frame task on the
    ``virtual_gripper_keypoint`` body plus a low-weight posture task that is
    re-targeted to the current joint positions, then integrates the result
    into ``configuration`` and writes it back to ``data.qpos``.

    Args:
        model: MuJoCo model the IK is solved on.
        data: MuJoCo data; ``data.qpos`` is the starting point and is modified
            in place.
        configuration: mink configuration; updated and integrated in place.
        target_pos: (3,) target position for the keypoint body.
        target_kp_quat_wxyz: (4,) target orientation quaternion (w, x, y, z).
            The frame task is created with ``orientation_cost=0.000``, so this
            target carries zero orientation weight.
        num_iterations: number of IK steps to run.

    Returns:
        optimized_qpos: (nq,) array of optimized joint positions
        final_error: scalar position error in meters
    """
    dt = 0.01  # Control timestep

    # Task 1: keypoint position and orientation (virtual_gripper_keypoint).
    # The task and its target do not change between iterations, so build once.
    kp_task = mink.FrameTask("virtual_gripper_keypoint", "body", position_cost=1.0, orientation_cost=0.000)
    kp_task.set_target(mink.SE3(wxyz_xyz=np.concatenate([target_kp_quat_wxyz, target_pos])))

    # Task 2: posture regularizer; only its target is refreshed per iteration.
    posture_task = mink.PostureTask(model, cost=1e-4)

    tasks = [kp_task, posture_task]
    limits = [mink.ConfigurationLimit(model=model)]

    # Run IK iterations
    for _ in range(num_iterations):
        mujoco.mj_forward(model, data)
        configuration.update(data.qpos)

        # Regularize towards the current qpos to keep the motion smooth.
        posture_task.set_target(data.qpos)

        # Solve one IK step and integrate it.
        vel = mink.solve_ik(configuration, tasks, dt, "daqp", limits=limits)
        configuration.integrate_inplace(vel, dt)
        data.qpos[:] = configuration.q

        # Update forward kinematics so data.xpos reflects the new qpos.
        mujoco.mj_forward(model, data)

    # Compute final position error of the keypoint body.
    kp_body_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, 'virtual_gripper_keypoint')
    error_norm = np.linalg.norm(target_pos - data.xpos[kp_body_id])

    return data.qpos.copy(), error_norm


def preprocess_image(rgb, image_size=IMAGE_SIZE):
    """Preprocess an RGB image for model input.

    The image is brought to the [0, 1] float range, resized to a square of
    ``image_size`` pixels, and normalized per channel with a fixed mean/std.

    Args:
        rgb: (H, W, 3) RGB image in [0, 255] or [0, 1]. Integer-typed images
            are always treated as [0, 255].
        image_size: Target side length of the square output image.

    Returns:
        rgb_tensor: (3, image_size, image_size) normalized float tensor
        rgb_vis: (image_size, image_size, 3) un-normalized image in [0, 1],
            for visualization
    """
    # Integer images are on the 0-255 scale even when a very dark frame has
    # max <= 1, so the dtype is checked as well; the value test still covers
    # float arrays that hold 0-255 values.
    if np.issubdtype(rgb.dtype, np.integer) or rgb.max() > 1.0:
        rgb = rgb.astype(np.float32) / 255.0

    # Resize to target size only when needed.
    if rgb.shape[:2] != (image_size, image_size):
        rgb = cv2.resize(rgb, (image_size, image_size), interpolation=cv2.INTER_LINEAR)

    # Convert to a channels-first tensor and normalize for DINOv2.
    rgb_tensor = torch.from_numpy(rgb).permute(2, 0, 1).float()  # (3, H, W)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    rgb_tensor = (rgb_tensor - mean) / std

    # Visualization image: a copy of the resized image taken before
    # normalization (nothing needs to be denormalized).
    rgb_vis = rgb.copy()

    return rgb_tensor, rgb_vis


def main():
    parser = argparse.ArgumentParser(description='Live test trajectory predictions with IK from camera stream')
    parser.add_argument('--ask_for_write', action='store_true',
                       help='Ask for user input to write to robot')
    parser.add_argument('--dont_write', action='store_true',
                       help='Skip writing to robot')
    parser.add_argument('--checkpoint', type=str, default="volume_dino_tracks/checkpoints/volume_dino_tracks/latest.pth",
                       help='Path to model checkpoint')
    parser.add_argument('--camera', type=int, default=0,
                       help='Camera device ID (default: 0)')
    parser.add_argument('--num_ik_iters', type=int, default=50,
                       help='Number of IK iterations')
    parser.add_argument('--no_arm', action='store_true',
                       help='No arm connected, use image-based estimation')
    parser.add_argument('--exo', type=str, default="so100_adhesive", 
                       choices=list(EXOSKELETON_CONFIGS.keys()), 
                       help="Exoskeleton configuration to use")
    parser.add_argument('--dont_visualize', action='store_true',
                       help='Skip all visualizations (MuJoCo rendering and matplotlib plotting) to speed up inference')
    parser.add_argument('--no_render', action='store_true',
                       help='Skip MuJoCo rendering but keep matplotlib visualizations (faster than full visualization)')
    args = parser.parse_args()

    # Setup device
    device = torch.device("mps" if torch.backends.mps.is_available() else 
                         "cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Load model
    print(f"\nLoading model from {args.checkpoint}...")
    model = TrajectoryHeatmapPredictor(target_size=IMAGE_SIZE, n_window=N_WINDOW, freeze_backbone=False)
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()
    # Load MIN_HEIGHT, MAX_HEIGHT, MIN_GRIPPER, MAX_GRIPPER from checkpoint if available
    if 'min_height' in checkpoint and 'max_height' in checkpoint:
        model_module.MIN_HEIGHT = checkpoint['min_height']
        model_module.MAX_HEIGHT = checkpoint['max_height']
        print(f"✓ Loaded height range from checkpoint: [{checkpoint['min_height']:.6f}, {checkpoint['max_height']:.6f}] m")
        print(f"  ({checkpoint['min_height']*1000:.2f}mm to {checkpoint['max_height']*1000:.2f}mm)")
    else:
        # Checkpoint doesn't have height range - compute from dataset if available
        print(f"⚠ Checkpoint doesn't have min_height/max_height!")
        print(f"  Using hardcoded values: MIN_HEIGHT={model_module.MIN_HEIGHT:.6f}, MAX_HEIGHT={model_module.MAX_HEIGHT:.6f}")
        if model_module.MIN_HEIGHT == model_module.MAX_HEIGHT:
            print(f"  ⚠ WARNING: MIN_HEIGHT == MAX_HEIGHT! All height predictions will be constant!")
            print(f"  Please recompute from dataset or use a newer checkpoint with height range saved.")
    if 'min_gripper' in checkpoint and 'max_gripper' in checkpoint:
        model_module.MIN_GRIPPER = checkpoint['min_gripper']
        model_module.MAX_GRIPPER = checkpoint['max_gripper']
        print(f"✓ Loaded gripper range from checkpoint: [{checkpoint['min_gripper']:.6f}, {checkpoint['max_gripper']:.6f}]")
        print(f"  (Regression: [-0.2, 0.8], values >4.0 mapped to -0.2, values >0.8 clamped to 0.8)")
    else:
        print(f"⚠ Checkpoint doesn't have min_gripper/max_gripper!")
        print(f"  Using default regression values: MIN_GRIPPER={model_module.MIN_GRIPPER:.6f}, MAX_GRIPPER={model_module.MAX_GRIPPER:.6f}")
    print(f"✓ Loaded model from epoch {checkpoint['epoch']}")
    
    # Initialize robot if using direct robot state
    arm = None
    use_robot_state = True
    calib_path = "robot_models/arm_offsets/rescrew_school_fromimg.pkl"
    if os.path.exists(calib_path) and not args.no_arm:
        arm = Arm(pickle.load(open(calib_path, 'rb')))
        print("✓ Connected to robot for direct joint state reading")
    else:
        print(f"⚠️ Calibration file not found at {calib_path}, falling back to image-based estimation")
        use_robot_state = False

    if arm is not None:
        #start_pos = np.array([0.04078509, 4.33910383, 1.71240746, 1.58195613, 1.54817129, 0.17400006])
        start_pos=np.array([6.183330982897405, 4.051123761386936, 1.927748584159554, 1.4132532496516983, 3.0797872859107915, 0.9894176081862386])
        last_pos = arm.get_pos()
        if not args.dont_write:
            arm.write_pos(start_pos, slow=False)
        while True:  # keep writing until the position is reached
            curr_pos = arm.get_pos()
            if np.max(np.abs(curr_pos - last_pos)) < 0.01:
                break
            last_pos = curr_pos
        print("done moving to high position")
    
    # Setup MuJoCo
    robot_config = SO100AdhesiveConfig()
    mj_model = mujoco.MjModel.from_xml_string(robot_config.xml)
    mj_data = mujoco.MjData(mj_model)
    
    # Initialize camera
    cap = cv2.VideoCapture(args.camera)
    if not cap.isOpened():
        raise RuntimeError(f"Failed to open camera device {args.camera}")
    
    # Get first frame to determine resolution
    ret, frame = cap.read()
    while not ret:
        ret, frame = cap.read()
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    height, width = rgb.shape[:2]
    print(f"Camera resolution: {width}x{height}")
    
    cam_K = None
    
    print("\n" + "="*60)
    print("Live inference started. Press 'q' to quit.")
    print("="*60)
    
    # Track previous figure for cleanup
    main.prev_fig = None
    
    first_write=True
    try:
        while True:
            frame_start_time = time.time()
            
            # Capture frame
            ret, frame = cap.read()
            if not ret:
                print("Failed to read frame from camera")
                continue
            
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            
            # Get joint states and camera pose
            camera_pose_world = None
            if use_robot_state and not args.no_arm:
                joint_state = arm.get_pos()
                mj_data.qpos[:] = mj_data.ctrl[:] = joint_state
                mujoco.mj_forward(mj_model, mj_data)

                
            
            # Detect camera pose and intrinsics
            try:
                link_poses, camera_pose_world, cam_K, corners_cache, corners_vis, obj_img_pts = detect_and_set_link_poses(
                    rgb, mj_model, mj_data, robot_config, cam_K=cam_K
                )
                position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
                
                if not use_robot_state:
                    configuration, _ = estimate_robot_state(mj_model, mj_data, robot_config, link_poses, ik_iterations=35)
                    mj_data.qpos[:] = mj_data.ctrl[:] = configuration.q
                    mujoco.mj_forward(mj_model, mj_data)
            except Exception as e:
                print(f"Error detecting link poses: {e}")
                camera_pose_world = None
                continue

            if 0:
                rendered_img_gt = render_from_camera_pose(mj_model, mj_data, camera_pose_world, cam_K, rgb.shape[0], rgb.shape[1], segmentation=False)/255
                rendered_img_gt = cv2.resize(rendered_img_gt, (width, height), interpolation=cv2.INTER_LINEAR)
                rgb_vis = rgb.copy()/255
                rgb_vis = rgb_vis * 0.5 + rendered_img_gt * 0.5
                plt.imshow(rgb_vis)
                plt.show()
            
            if camera_pose_world is None or cam_K is None:
                print("⚠️ Could not detect camera pose or intrinsics, skipping frame")
                continue
            
            # Preprocess image for model
            rgb_tensor, rgb_vis_resized = preprocess_image(rgb.copy(), IMAGE_SIZE)
            
            # Get starting keypoint 2D position from current gripper pose
            fixed_gripper_pose = get_link_poses_from_robot(robot_config, mj_model, mj_data)["fixed_gripper"]
            gripper_rot = fixed_gripper_pose[:3, :3]
            gripper_pos = fixed_gripper_pose[:3, 3]
            kp_3d_start = gripper_rot @ kp_local + gripper_pos
            
            # Project starting keypoint to 2D using original image resolution
            kp_2d_start = project_3d_to_2d(kp_3d_start, camera_pose_world, cam_K)
            if kp_2d_start is None:
                print("⚠️ Starting keypoint behind camera, skipping frame")
                continue
            
            # Scale to model input resolution
            scale_x = IMAGE_SIZE / width
            scale_y = IMAGE_SIZE / height
            start_keypoint_2d = kp_2d_start * np.array([scale_x, scale_y])
            
            # Run model inference (volume model: volume_logits + gripper_logits; no current_height/current_gripper)
            with torch.no_grad():
                rgb_batch = rgb_tensor.unsqueeze(0).to(device)
                start_keypoint_tensor = torch.from_numpy(start_keypoint_2d).float().to(device)
                volume_logits, gripper_logits = model(
                    rgb_batch,
                    training=False,
                    start_keypoint_2d=start_keypoint_tensor,
                )
            
            # 3D selection from volume: extract pred 2D and height per timestep; decode gripper at pred pixel
            pred_2d, pred_height = extract_pred_2d_and_height_from_volume(volume_logits)
            gripper_logits_at_pred = extract_gripper_logits_at_pixels(gripper_logits, pred_2d)
            pred_gripper = decode_gripper_bins(gripper_logits_at_pred)
            
            pred_trajectory_2d = pred_2d[0].cpu().numpy()  # (N_WINDOW, 2)
            pred_height = pred_height[0].cpu().numpy()  # (N_WINDOW,)
            pred_gripper = pred_gripper[0].cpu().numpy()  # (N_WINDOW,)
            
            # Heatmaps for visualization: softmax over volume then max along ray
            pred_heatmaps = []
            for t in range(N_WINDOW):
                vol_t = volume_logits[0, t]  # (Nh, H, W)
                vol_probs = F.softmax(vol_t.view(-1), dim=0).view(vol_t.shape[0], vol_t.shape[1], vol_t.shape[2])
                max_along_ray = vol_probs.max(dim=0)[0].cpu().numpy()  # (H, W)
                pred_heatmaps.append(max_along_ray)
            pred_heatmaps = np.array(pred_heatmaps)  # (N_WINDOW, H, W)
            
            # Scale intrinsics for model resolution (448x448) - exactly like test_model_ik.py
            # First normalize intrinsics (like the dataset does)
            cam_K_norm = cam_K.copy()
            cam_K_norm[0] /= width   # Normalize fx, cx by width
            cam_K_norm[1] /= height  # Normalize fy, cy by height
            
            # Then denormalize to IMAGE_SIZE (exactly like test_model_ik.py)
            cam_K_model = cam_K_norm.copy()
            cam_K_model[0] *= IMAGE_SIZE  # Denormalize fx, cx to IMAGE_SIZE
            cam_K_model[1] *= IMAGE_SIZE  # Denormalize fy, cy to IMAGE_SIZE
            
            # Debug: Print intrinsics to verify they match test_model_ik.py format
            if not hasattr(main, 'printed_intrinsics'):
                print(f"\nDebug - Intrinsics scaling:")
                print(f"  Original cam_K (at {width}x{height}): fx={cam_K[0,0]:.2f}, fy={cam_K[1,1]:.2f}, cx={cam_K[0,2]:.2f}, cy={cam_K[1,2]:.2f}")
                print(f"  Normalized cam_K_norm: fx={cam_K_norm[0,0]:.4f}, fy={cam_K_norm[1,1]:.4f}, cx={cam_K_norm[0,2]:.4f}, cy={cam_K_norm[1,2]:.4f}")
                print(f"  Scaled cam_K_model (at {IMAGE_SIZE}x{IMAGE_SIZE}): fx={cam_K_model[0,0]:.2f}, fy={cam_K_model[1,1]:.2f}, cx={cam_K_model[0,2]:.2f}, cy={cam_K_model[1,2]:.2f}")
                main.printed_intrinsics = True
            
            # Lift trajectory to 3D
            pred_trajectory_3d = []
            for t in range(N_WINDOW):
                pred_2d_t = pred_trajectory_2d[t]
                pred_height_t = float(pred_height[t])  # Ensure scalar
                pred_3d_t = recover_3d_from_direct_keypoint_and_height(
                    pred_2d_t, pred_height_t, camera_pose_world, cam_K_model
                )
                
                if pred_3d_t is None:
                    print(f"✗ Failed to lift waypoint {t} to 3D")
                    continue
                
                # Validate 3D target: check if it's within reasonable workspace bounds
                # Rough workspace bounds (adjust based on your robot)
                workspace_bounds = {
                    'x': [-0.5, 0.5],  # meters
                    'y': [-0.5, 0.0],  # meters (negative y is forward)
                    'z': [0.0, 0.3]    # meters (height above ground)
                }
                
                if not (workspace_bounds['x'][0] <= pred_3d_t[0] <= workspace_bounds['x'][1] and
                        workspace_bounds['y'][0] <= pred_3d_t[1] <= workspace_bounds['y'][1] and
                        workspace_bounds['z'][0] <= pred_3d_t[2] <= workspace_bounds['z'][1]):
                    print(f"⚠️ Waypoint {t} 3D target out of workspace: {pred_3d_t}")
                    # Use previous waypoint if available, otherwise skip
                    if len(pred_trajectory_3d) > 0:
                        pred_3d_t = pred_trajectory_3d[-1].copy()
                        print(f"  Using previous waypoint: {pred_3d_t}")
                    else:
                        continue
                
                pred_trajectory_3d.append(pred_3d_t)
            
            if len(pred_trajectory_3d) == 0:
                print("⚠️ Could not lift any waypoints to 3D, skipping frame")
                continue
            
            pred_trajectory_3d = np.array(pred_trajectory_3d)  # (N, 3)
            
            print(f"✓ Predicted 3D trajectory with {len(pred_trajectory_3d)} waypoints")
            
            # Setup IK configuration - reset to current robot state before IK (exactly like test_model_ik.py)
            # Get fresh robot state to ensure mj_data is in correct state
            if use_robot_state and not args.no_arm:
                current_joint_state = arm.get_pos()
                mj_data.qpos[:len(mj_data.ctrl)] = current_joint_state[:len(mj_data.ctrl)]
                mj_data.ctrl[:] = current_joint_state[:len(mj_data.ctrl)]
            # If using image-based estimation, mj_data.qpos should already be set correctly above
            
            # Forward kinematics to update robot state (like test_model_ik.py)
            mujoco.mj_forward(mj_model, mj_data)
            
            # Create fresh configuration (like test_model_ik.py)
            configuration = mink.Configuration(mj_model)
            configuration.update(mj_data.qpos)
            
            # Perform sequential IK for each timestep
            ik_trajectory_qpos = []
            ik_errors = []
            current_qpos = mj_data.qpos.copy()  # Start from current robot state
            
            for t in range(len(pred_trajectory_3d)):
                pred_3d_t = pred_trajectory_3d[t]
                
                # Set current joint state as starting point for this timestep's IK
                mj_data.qpos[:] = current_qpos[:]
                mj_data.ctrl[:] = current_qpos[:len(mj_data.ctrl)]
                mujoco.mj_forward(mj_model, mj_data)
                configuration.update(mj_data.qpos)
                
                # Perform IK to this timestep's 3D position
                optimized_qpos, ik_error = ik_to_cube_grasp(
                    mj_model,
                    mj_data,
                    configuration, 
                    pred_3d_t,
                    initial_kp_quat_wxyz,
                    num_iterations=args.num_ik_iters
                )
                
                # If IK error is too large, use previous timestep's result instead
                max_ik_error = 0.2  # 200mm threshold
                if ik_error > max_ik_error:
                    print(f"  ⚠️ t={t}: IK error too large ({ik_error*1000:.2f} mm), using previous timestep's result")
                    if len(ik_trajectory_qpos) > 0:
                        optimized_qpos = ik_trajectory_qpos[-1].copy()
                        # Recompute error with previous position
                        mj_data.qpos[:] = optimized_qpos[:]
                        mujoco.mj_forward(mj_model, mj_data)
                        kp_body_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, 'virtual_gripper_keypoint')
                        current_kp_pos = mj_data.xpos[kp_body_id]
                        ik_error = np.linalg.norm(pred_3d_t - current_kp_pos)
                    else:
                        # First timestep failed - keep current position
                        optimized_qpos = current_qpos.copy()
                        ik_error = np.linalg.norm(pred_3d_t - mj_data.xpos[mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, 'virtual_gripper_keypoint')])
                
                ik_trajectory_qpos.append(optimized_qpos.copy())
                ik_errors.append(ik_error)
                
                # Use this result as starting point for next timestep
                current_qpos = optimized_qpos.copy()
                
                print(f"  t={t}: 3D=[{pred_3d_t[0]:.4f}, {pred_3d_t[1]:.4f}, {pred_3d_t[2]:.4f}] m, IK err={ik_error*1000:.2f} mm")
            
            ik_trajectory_qpos = np.array(ik_trajectory_qpos)  # (N_WINDOW, nq)
            ik_errors = np.array(ik_errors)  # (N_WINDOW,)
            
            print(f"✓ Sequential IK complete - Avg error: {ik_errors.mean()*1000:.2f} mm")
            
            # Visualization section - skip if --dont_visualize is set
            if not args.dont_visualize:
                # Resize model-res RGB back to original for visualization (and also keep model-res for 2D traj/heatmaps).
                rgb_vis = cv2.resize(rgb_vis_resized, (width, height), interpolation=cv2.INTER_LINEAR)
                rgb_vis_model = rgb_vis_resized.copy()  # (IMAGE_SIZE, IMAGE_SIZE, 3)

                # Final MuJoCo render overlay (or just RGB if --no_render)
                if not args.no_render:
                    mj_data.qpos[:6] = ik_trajectory_qpos[-1][:6]
                    mujoco.mj_forward(mj_model, mj_data)
                    render_res = [height // 2, width // 2]
                    render_cam_K = cam_K_norm.copy()
                    render_cam_K[0] *= render_res[1]
                    render_cam_K[1] *= render_res[0]
                    rendered_img_final = render_from_camera_pose(
                        mj_model,
                        mj_data,
                        camera_pose_world,
                        render_cam_K,
                        render_res[0],
                        render_res[1],
                        segmentation=False,
                    ) / 255.0
                    rendered_img_final = cv2.resize(rendered_img_final, (width, height), interpolation=cv2.INTER_LINEAR)
                    overlay_final = rgb_vis * 0.5 + rendered_img_final * 0.5
                else:
                    overlay_final = rgb_vis

                # Build heatmap grid image (single square-ish panel), resize to RGB size.
                heatmap_grid_vis = None
                if tv_make_grid is not None:
                    # (T,H,W) -> (T,1,H,W) tensor
                    hm = torch.from_numpy(pred_heatmaps).float().unsqueeze(1)  # (T,1,H,W)
                    nrow = int(math.ceil(math.sqrt(hm.shape[0])))
                    grid = tv_make_grid(hm, nrow=nrow, padding=2, normalize=True, scale_each=True)  # (C,Hg,Wg)
                    # Convert to HxWxC in [0,1]
                    grid_np = grid.detach().cpu().numpy()
                    if grid_np.shape[0] == 1:
                        grid_np = np.repeat(grid_np, 3, axis=0)
                    grid_np = np.transpose(grid_np, (1, 2, 0))
                    heatmap_grid_vis = cv2.resize(grid_np, (width, height), interpolation=cv2.INTER_LINEAR)
                else:
                    # Fallback: show max-over-time heatmap
                    hm = pred_heatmaps.max(axis=0)
                    hm = (hm - hm.min()) / (hm.max() - hm.min() + 1e-8)
                    hm3 = np.repeat(hm[..., None], 3, axis=2)
                    heatmap_grid_vis = cv2.resize(hm3, (width, height), interpolation=cv2.INTER_LINEAR)

                # Create figure once and update artists in-place to avoid flashing.
                if not hasattr(main, "vis_state") or (main.vis_state is None):
                    plt.ion()
                    fig, axs = plt.subplots(2, 2, figsize=(12, 10), constrained_layout=True)
                    ax_traj = axs[0, 0]
                    ax_hm = axs[0, 1]
                    ax_lines = axs[1, 0]
                    ax_render = axs[1, 1]

                    # (1) Trajectory image + trajectory line + per-timestep markers
                    im_traj = ax_traj.imshow(rgb_vis_model)
                    traj_line, = ax_traj.plot([], [], "-", color="lime", linewidth=2, alpha=0.6)
                    colors = plt.cm.viridis(np.linspace(0, 1, N_WINDOW))
                    scat = ax_traj.scatter([], [], s=40, marker="x", linewidths=2)
                    ax_traj.set_title("2D predicted trajectory (all timesteps)")
                    ax_traj.axis("off")

                    # (2) Heatmap grid image
                    im_hm = ax_hm.imshow(heatmap_grid_vis)
                    ax_hm.set_title("Heatmaps (make_grid) resized to RGB")
                    ax_hm.axis("off")

                    # (3) Height + gripper lines
                    ts = np.arange(N_WINDOW)
                    height_line, = ax_lines.plot(ts, pred_height * 1000.0, "o-", color="blue", linewidth=2, markersize=4)
                    ax_lines.axhline(y=model_module.MIN_HEIGHT * 1000.0, color="gray", linestyle="--", linewidth=1, alpha=0.4)
                    ax_lines.axhline(y=model_module.MAX_HEIGHT * 1000.0, color="gray", linestyle="--", linewidth=1, alpha=0.4)
                    ax_lines.set_xlabel("Timestep")
                    ax_lines.set_ylabel("Height (mm)", color="blue")
                    ax_lines.tick_params(axis="y", labelcolor="blue")
                    ax_lines.grid(alpha=0.3)

                    ax_g = ax_lines.twinx()
                    gripper_line, = ax_g.plot(ts, pred_gripper, "s-", color="red", linewidth=2, markersize=4)
                    ax_g.axhline(y=model_module.MIN_GRIPPER, color="red", linestyle="--", linewidth=1, alpha=0.35)
                    ax_g.axhline(y=model_module.MAX_GRIPPER, color="green", linestyle="--", linewidth=1, alpha=0.35)
                    ax_g.set_ylabel("Gripper", color="red")
                    ax_g.tick_params(axis="y", labelcolor="red")
                    ax_lines.set_title("Height + Gripper trajectory")

                    # (4) Final render overlay
                    im_render = ax_render.imshow(overlay_final)
                    title_render = ax_render.set_title("Final MuJoCo render overlay")
                    ax_render.axis("off")

                    main.vis_state = {
                        "fig": fig,
                        "ax_traj": ax_traj,
                        "im_traj": im_traj,
                        "traj_line": traj_line,
                        "scat": scat,
                        "colors": colors,
                        "ax_hm": ax_hm,
                        "im_hm": im_hm,
                        "ax_lines": ax_lines,
                        "height_line": height_line,
                        "ax_g": ax_g,
                        "gripper_line": gripper_line,
                        "ax_render": ax_render,
                        "im_render": im_render,
                        "title_render": title_render,
                    }
                else:
                    st = main.vis_state
                    fig = st["fig"]
                    ax_traj = st["ax_traj"]
                    im_traj = st["im_traj"]
                    traj_line = st["traj_line"]
                    scat = st["scat"]
                    colors = st["colors"]
                    im_hm = st["im_hm"]
                    height_line = st["height_line"]
                    gripper_line = st["gripper_line"]
                    im_render = st["im_render"]
                    title_render = st["title_render"]

                # Update (1) trajectory panel
                im_traj.set_data(rgb_vis_model)
                traj_line.set_data(pred_trajectory_2d[:, 0], pred_trajectory_2d[:, 1])
                scat.set_offsets(pred_trajectory_2d[:, :2])
                # Per-point colors (RGBA)
                scat.set_color(colors)

                # Update (2) heatmap grid
                im_hm.set_data(heatmap_grid_vis)

                # Update (3) lines
                height_line.set_ydata(pred_height * 1000.0)
                gripper_line.set_ydata(pred_gripper)

                # Update (4) final render
                im_render.set_data(overlay_final)
                title_render.set_text(f"Final MuJoCo render overlay (IK err {ik_errors[-1]*1000:.1f}mm)")

                # Draw without clearing/closing to reduce flashing
                fig.canvas.draw_idle()
                fig.canvas.flush_events()
            
            # Execute robot trajectory (regardless of visualization)
            print(pred_gripper)

            
            # Maintain framerate
            elapsed = time.time() - frame_start_time
            sleep_time = max(0, 0.1 - elapsed)  # ~10 fps
            if sleep_time > 0:
                time.sleep(sleep_time)
            
            print(arm.get_pos())
            if not first_write and not args.dont_write:
                if not args.ask_for_write or input("Write to robot? (y/n): ")== 'y':#not first_write:
                #if 1:#not first_write:
                    for traj_i,targ_pos in enumerate(ik_trajectory_qpos[:]):
                        print(targ_pos)
                        pred_grip=pred_gripper[traj_i]
                        #if pred_grip<.7: pred_grip=-.2
                        #else: pred_grip=1
                        targ_pos[-1]=pred_grip
                        if 1.4<targ_pos[3]<4:
                            for _ in range(20):print("bad ik coming, skipping")
                            #continue
                        arm.write_pos(targ_pos,slow=False)
                        while True: # keep writing until the position is reached
                            curr_pos=arm.get_pos()
                            delta=np.max(np.abs(curr_pos-last_pos))
                            print("still moving",delta)
                            if delta<0.06: break
                            last_pos=curr_pos
                        print("done moving to high position")

                            #zz
                        #import pdb; pdb.set_trace()
            first_write=False

    
    except KeyboardInterrupt:
        print("\nLive inference interrupted by user")
    
    finally:
        cap.release()
        if not args.dont_visualize:
            plt.close('all')
        print("\n✓ Live inference complete!")


# Script entry point: run main() only when this file is executed directly,
# so importing the module does not start the live inference loop.
if __name__ == "__main__":
    main()
