"""Live test trajectory predictions with direct 6D joint regression (no IK) from camera stream."""
import sys
import os
from pathlib import Path
import torch
import torch.nn.functional as F
import cv2
import numpy as np
import matplotlib.pyplot as plt
import math
import argparse
import mujoco
import pickle
import time
from scipy.spatial.transform import Rotation as R

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
sys.path.insert(0, os.path.dirname(__file__))

from data import N_WINDOW, N_JOINTS, project_3d_to_2d
from model import ACTJointsTrajectoryPredictor
import model as model_module
from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from ExoConfigs import EXOSKELETON_CONFIGS
from exo_utils import (
    get_link_poses_from_robot, 
    position_exoskeleton_meshes, 
    render_from_camera_pose,
    detect_and_set_link_poses,
    estimate_robot_state
)
from robot_models.so100_controller import Arm

try:
    from torchvision.utils import make_grid as tv_make_grid  # type: ignore
except Exception:
    tv_make_grid = None

# Configuration: side length (pixels) of the square image the model consumes;
# frames are resized to IMAGE_SIZE x IMAGE_SIZE before inference.
IMAGE_SIZE: int = 448


def draw_tracks_on_rgb(rgb_vis_raw, tracks_2d_display):
    """Overlay a 2D track on an RGB frame.

    Consecutive points are joined by a lime polyline, and each point gets an 'x'
    marker colored along the viridis colormap (first point -> last point).

    Args:
        rgb_vis_raw: (H, W, 3) RGB image; values are clipped to [0, 1] before drawing.
        tracks_2d_display: (N, 2) pixel coordinates (x, y), or None / empty to draw nothing.
            Coordinates are truncated to int and clamped into the image.

    Returns:
        (H, W, 3) float32 image in [0, 1], quantized to 8 bits.
    """
    canvas = (np.clip(rgb_vis_raw, 0.0, 1.0) * 255.0).astype(np.uint8).copy()
    if tracks_2d_display is None or len(tracks_2d_display) == 0:
        return canvas.astype(np.float32) / 255.0

    rows, cols = canvas.shape[:2]
    points = np.clip(tracks_2d_display.astype(np.int32), [0, 0], [cols - 1, rows - 1])
    marker_colors = plt.cm.viridis(np.linspace(0, 1, len(tracks_2d_display)))
    lime = (0, 255, 0)

    # Polyline through consecutive points.
    for start, end in zip(points[:-1], points[1:]):
        cv2.line(canvas, tuple(start), tuple(end), lime, 2)

    # One 'x' per point, colored by its position along the track.
    half = 6
    for (px, py), rgba in zip(points, marker_colors):
        px, py = int(px), int(py)
        color = (int(rgba[0] * 255), int(rgba[1] * 255), int(rgba[2] * 255))
        cv2.line(canvas, (px - half, py - half), (px + half, py + half), color, 2)
        cv2.line(canvas, (px - half, py + half), (px + half, py - half), color, 2)

    return canvas.astype(np.float32) / 255.0


# Candidate virtual-gripper keypoints in the local gripper frame (from dataset generation),
# written in millimetres and converted to metres.
KEYPOINTS_LOCAL_M_ALL = np.array(
    [
        [13.25, -91.42, 15.9],
        [10.77, -99.6, 0],
        [13.25, -91.42, -15.9],
        [17.96, -83.96, 0],
        [22.86, -70.46, 0],
    ]
) / 1000.0
# Index of the candidate keypoint in use.
KP_INDEX = 3
kp_local = KEYPOINTS_LOCAL_M_ALL[KP_INDEX]

# MuJoCo body whose world position is projected to 2D for the FK-based trajectory overlay.
KP_BODY_NAME = "virtual_gripper_keypoint"


def joints_to_trajectory_2d(mj_model, mj_data, trajectory_qpos, camera_pose_world, cam_K):
    """Project a joint-space trajectory to image space via forward kinematics.

    For each timestep the qpos row is written into ``mj_data``, ``mj_forward`` is run,
    and the world position of the ``KP_BODY_NAME`` body is projected with
    ``project_3d_to_2d``. Timesteps whose projection returns None map to (0, 0).

    ``mj_data.qpos`` is restored (and ``mj_forward`` re-run) before returning, so the
    caller's simulation state is not left at the last trajectory pose.

    Args:
        mj_model: MuJoCo model containing a body named ``KP_BODY_NAME``.
        mj_data: MuJoCo data for ``mj_model`` (temporarily modified, then restored).
        trajectory_qpos: (T, D) array of qpos rows, e.g. (N_WINDOW, 7) = 6 joints + gripper.
            Only the first ``min(D, mj_data.qpos.size)`` entries of each row are written.
        camera_pose_world: camera pose passed through to ``project_3d_to_2d``.
        cam_K: (3, 3) intrinsics passed through to ``project_3d_to_2d``.

    Returns:
        (T, 2) float64 array of pixel coordinates.

    Raises:
        ValueError: if the model has no body named ``KP_BODY_NAME``.
    """
    kp_body_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, KP_BODY_NAME)
    if kp_body_id < 0:
        # mj_name2id returns -1 for unknown names; xpos[-1] would silently read the last body.
        raise ValueError(f"MuJoCo model has no body named {KP_BODY_NAME!r}")

    nq = min(mj_data.qpos.size, trajectory_qpos.shape[1])
    saved_qpos = mj_data.qpos.copy()
    out = []
    try:
        for qpos_t in trajectory_qpos:
            mj_data.qpos[:nq] = qpos_t[:nq]
            mujoco.mj_forward(mj_model, mj_data)
            kp_3d = mj_data.xpos[kp_body_id].copy()
            p2d = project_3d_to_2d(kp_3d, camera_pose_world, cam_K)
            out.append([0.0, 0.0] if p2d is None else p2d)
    finally:
        # Put the caller's state back so later code sees the pose it set, not our last FK pose.
        mj_data.qpos[:] = saved_qpos
        mujoco.mj_forward(mj_model, mj_data)
    return np.array(out, dtype=np.float64).reshape(-1, 2)


def extract_pred_2d_and_height_from_volume(volume_logits, min_height=None, max_height=None):
    """Decode a per-timestep 2D pixel and height from a (height-bin x pixel) logit volume.

    For every batch element and timestep, the arg-max cell of the whole (Nh, H, W) volume
    is found. Its (x, y) is the predicted pixel. Its height bin k is mapped linearly onto
    [min_height, max_height] via k / (Nh - 1); a single bin maps to min_height.

    Args:
        volume_logits: (B, N_WINDOW, Nh, H, W) logits.
        min_height: lower end of the height range; defaults to ``model_module.MIN_HEIGHT``
            (read at call time).
        max_height: upper end of the height range; defaults to ``model_module.MAX_HEIGHT``
            (read at call time).

    Returns:
        pred_2d: (B, N_WINDOW, 2) float32 pixel coordinates (x, y).
        pred_height: (B, N_WINDOW) decoded heights.
    """
    B, N, Nh, H, W = volume_logits.shape
    device = volume_logits.device
    if min_height is None:
        min_height = model_module.MIN_HEIGHT
    if max_height is None:
        max_height = model_module.MAX_HEIGHT

    # Best score over height bins at each pixel, then the best pixel per (batch, timestep).
    best_per_pixel = volume_logits.max(dim=2).values  # (B, N, H, W)
    flat_idx = best_per_pixel.reshape(B, N, H * W).argmax(dim=-1)  # (B, N)
    py = flat_idx // W
    px = flat_idx % W
    pred_2d = torch.stack([px, py], dim=-1).float()  # (B, N, 2)

    # Height bin = arg-max over bins at the chosen pixel.
    batch_idx = torch.arange(B, device=device).view(B, 1).expand(B, N)
    time_idx = torch.arange(N, device=device).view(1, N).expand(B, N)
    pred_height_bins = volume_logits[batch_idx, time_idx, :, py, px].argmax(dim=-1)  # (B, N)

    # The bin count is taken from the logits' shape so it always matches the model output.
    bin_centers = torch.linspace(0.0, 1.0, Nh, device=device)
    pred_height = bin_centers[pred_height_bins] * (max_height - min_height) + min_height
    return pred_2d, pred_height


def decode_gripper_bins(bin_logits, min_gripper=None, max_gripper=None):
    """Decode gripper bin logits to continuous gripper values.

    The arg-max bin k of Nb bins is mapped to k / (Nb - 1) and then linearly onto
    [min_gripper, max_gripper]. The bin count is taken from the last axis of
    ``bin_logits``.

    Args:
        bin_logits: (B, N_WINDOW, Nb) logits per bin.
        min_gripper: lower end of the gripper range; defaults to ``model_module.MIN_GRIPPER``,
            read at call time so a value loaded from the checkpoint is honoured.
        max_gripper: upper end of the gripper range; defaults to ``model_module.MAX_GRIPPER``,
            read at call time.

    Returns:
        gripper_values: (B, N_WINDOW) continuous values in [min_gripper, max_gripper].
    """
    if min_gripper is None:
        min_gripper = model_module.MIN_GRIPPER
    if max_gripper is None:
        max_gripper = model_module.MAX_GRIPPER
    n_bins = bin_logits.shape[-1]
    bin_indices = bin_logits.argmax(dim=-1)  # (B, N_WINDOW)
    bin_centers = torch.linspace(0.0, 1.0, n_bins, device=bin_logits.device)  # (Nb,)
    normalized = bin_centers[bin_indices]  # (B, N_WINDOW) in [0, 1]
    return normalized * (max_gripper - min_gripper) + min_gripper


def extract_gripper_logits_at_pixels(gripper_logits, pixel_2d):
    """Sample per-pixel gripper-bin logits at one (x, y) location per timestep.

    Used at inference time to read the gripper distribution at the predicted pixel.

    Args:
        gripper_logits: (B, N_WINDOW, N_GRIPPER_BINS, H, W) logits.
        pixel_2d: (B, N_WINDOW, 2) pixel coordinates (x, y); converted with ``.long()``
            (truncation) and clamped into the image.

    Returns:
        (B, N_WINDOW, N_GRIPPER_BINS) logits at the requested pixels.
    """
    batch, steps, _, rows, cols = gripper_logits.shape
    dev = gripper_logits.device
    xs = pixel_2d[..., 0].long().clamp(0, cols - 1)
    ys = pixel_2d[..., 1].long().clamp(0, rows - 1)
    b_index = torch.arange(batch, device=dev)[:, None].expand(batch, steps)
    t_index = torch.arange(steps, device=dev)[None, :].expand(batch, steps)
    # Move the bin axis last so one advanced index over (b, t, y, x) yields (B, N, bins).
    bins_last = gripper_logits.permute(0, 1, 3, 4, 2)
    return bins_last[b_index, t_index, ys, xs]


def unproject_2d_to_ray(point_2d, camera_pose, cam_K):
    """Back-project a pixel into a world-frame viewing ray.

    Args:
        point_2d: (2,) pixel coordinates (x, y).
        camera_pose: (4, 4) camera pose matrix (world-to-camera).
        cam_K: (3, 3) camera intrinsics.

    Returns:
        cam_pos: (3,) camera position in the world frame.
        ray_direction: (3,) unit-length ray direction in the world frame.
    """
    # Invert the extrinsics once; the result is the camera-to-world transform.
    cam_to_world = np.linalg.inv(camera_pose)
    cam_pos_world = cam_to_world[:3, 3]

    # Pixel -> camera-frame direction: solve K @ d = [x, y, 1] rather than forming K^-1.
    pixel_h = np.array([point_2d[0], point_2d[1], 1.0])
    ray_cam = np.linalg.solve(cam_K, pixel_h)

    # Rotate into the world frame and normalize.
    ray_world = cam_to_world[:3, :3] @ ray_cam
    ray_world = ray_world / np.linalg.norm(ray_world)

    return cam_pos_world, ray_world


def recover_3d_from_direct_keypoint_and_height(kp_2d_image, height, camera_pose, cam_K):
    """Lift a 2D keypoint to 3D by intersecting its viewing ray with the plane z = height.

    Args:
        kp_2d_image: (2,) pixel coordinates.
        height: world-frame z of the plane the keypoint lies on.
        camera_pose: (4, 4) camera pose matrix (as accepted by ``unproject_2d_to_ray``).
        cam_K: (3, 3) camera intrinsics.

    Returns:
        (3,) world-frame point, or None when the ray is (nearly) parallel to the plane
        or the intersection lies behind the camera.
    """
    origin, direction = unproject_2d_to_ray(kp_2d_image, camera_pose, cam_K)
    dz = direction[2]

    # A ray with (almost) no vertical component never meets the horizontal plane.
    if abs(dz) < 1e-6:
        return None

    # origin + s * direction reaches z == height at this ray parameter.
    s = (height - origin[2]) / dz

    # Negative s means the intersection is behind the camera.
    return None if s < 0 else origin + s * direction


def preprocess_image(rgb, image_size=IMAGE_SIZE):
    """Resize and normalize an RGB frame for the model.

    Args:
        rgb: (H, W, 3) RGB image. Integer images are treated as 0-255. Float images are
            treated as 0-255 if any value exceeds 1.0, otherwise as already in [0, 1].
        image_size: side length of the square output.

    Returns:
        rgb_tensor: (3, image_size, image_size) float32 tensor normalized with the
            mean/std used for DINOv2.
        rgb_vis: (image_size, image_size, 3) resized image in [0, 1], for visualization.
    """
    # Use the dtype for integer frames: a very dark uint8 frame (max <= 1) must still be
    # divided by 255, which a max()-only check would skip.
    if np.issubdtype(rgb.dtype, np.integer) or rgb.max() > 1.0:
        rgb = rgb.astype(np.float32) / 255.0

    h_orig, w_orig = rgb.shape[:2]
    if h_orig != image_size or w_orig != image_size:
        rgb = cv2.resize(rgb, (image_size, image_size), interpolation=cv2.INTER_LINEAR)

    # torch.from_numpy rejects negative strides, so make the array contiguous first.
    rgb_tensor = torch.from_numpy(np.ascontiguousarray(rgb)).permute(2, 0, 1).float()  # (3, H, W)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    rgb_tensor = (rgb_tensor - mean) / std

    # The visualization copy is the resized [0, 1] image (no normalization applied).
    rgb_vis = rgb.copy()

    return rgb_tensor, rgb_vis


# Joint configuration the arm is driven to at startup and on the 'c' (reset) command.
_START_POS = np.array([0.04078509, 4.33910383, 1.71240746, 1.58195613, 1.54817129, 0.17400006])

# Debug toggle: blend the MuJoCo render of the current state over each frame and show it
# (blocking plt.show()).
_DEBUG_SHOW_RENDER_OVERLAY = False


def _to_unit_float_rgb(rgb):
    """Return an RGB frame in [0, 1]; integer frames are always divided by 255."""
    if np.issubdtype(rgb.dtype, np.integer) or rgb.max() > 1.0:
        return rgb.astype(np.float32) / 255.0
    return rgb.copy()


def _refresh_rgb_panel(rgb_vis_raw, args):
    """Redraw the raw-RGB panel (with the latest predicted tracks) if the figure exists."""
    st = getattr(main, "vis_state", None)
    if args.dont_visualize or st is None:
        return
    tracks = None if args.no_track_overlay else st.get("tracks_2d_display")
    st["im_rgb"].set_data(draw_tracks_on_rgb(rgb_vis_raw, tracks))
    st["fig"].canvas.draw_idle()
    st["fig"].canvas.flush_events()


def _pump_camera(cap, args):
    """Grab one frame and push it to the live RGB panel (used while waiting on the arm)."""
    ret, frame = cap.read()
    if ret:
        _refresh_rgb_panel(_to_unit_float_rgb(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)), args)


def _wait_until_arm_settles(arm, last_pos, tol, on_poll=None, poll_interval=0.0, report=False):
    """Poll the arm until two consecutive readings differ by less than ``tol`` (max-abs).

    Each poll optionally calls ``on_poll()`` first, optionally prints the delta, and
    sleeps ``poll_interval`` seconds between polls.

    NOTE(review): this detects "stopped changing", not "reached target"; if the arm has not
    started moving yet when polling begins, it can return immediately.

    Returns:
        The last stored reading (not the one that passed the check); callers pass it into
        the next wait as ``last_pos``.
    """
    while True:
        if on_poll is not None:
            on_poll()
        curr_pos = arm.get_pos()
        delta = np.max(np.abs(curr_pos - last_pos))
        if report:
            print("still moving", delta)
        if delta < tol:
            return last_pos
        last_pos = curr_pos
        if poll_interval > 0:
            time.sleep(poll_interval)


def main():
    """Run live camera inference with the ACT joints model and optionally execute predictions.

    Each iteration:
    - grabs a frame and reads the current joint state (from the arm, or estimated from the
      image when no arm is used);
    - detects the camera pose and intrinsics, then predicts N_WINDOW joint + gripper targets;
    - projects those targets to 2D for display;
    - after the first iteration, unless --dont_write, asks whether to execute the targets
      ('y') or reset to the start pose ('c').

    Stop with Ctrl+C.
    """
    parser = argparse.ArgumentParser(description='Live test trajectory predictions with direct joint regression (no IK)')
    # NOTE(review): --ask_for_write is parsed but never read; the loop below always prompts.
    parser.add_argument('--ask_for_write', action='store_true',
                        help='Ask for user input to write to robot')
    parser.add_argument('--dont_write', action='store_true',
                        help='Skip writing to robot')
    _script_dir = Path(__file__).resolve().parent
    CHECKPOINT_PATH = str(_script_dir / "checkpoints" / "act_baseline_joints" / "latest.pth")
    parser.add_argument('--checkpoint', type=str, default=CHECKPOINT_PATH,
                        help='Path to model checkpoint')
    parser.add_argument('--camera', type=int, default=0,
                        help='Camera device ID (default: 0)')
    parser.add_argument('--no_arm', action='store_true',
                        help='No arm connected, use image-based estimation')
    parser.add_argument('--exo', type=str, default="so100_adhesive",
                        choices=list(EXOSKELETON_CONFIGS.keys()),
                        help="Exoskeleton configuration to use")
    parser.add_argument('--dont_visualize', action='store_true',
                        help='Skip all visualizations (MuJoCo rendering and matplotlib plotting) to speed up inference')
    parser.add_argument('--no_render', action='store_true',
                        help='Skip MuJoCo rendering but keep matplotlib visualizations (faster than full visualization)')
    parser.add_argument('--just_rgb', action='store_true',
                        help='Show only the raw RGB stream panel (no 2D tracks or gripper panes)')
    parser.add_argument('--no_track_overlay', action='store_true',
                        help='Do not draw predicted tracks on the RGB image')
    args = parser.parse_args()

    # Setup device
    device = torch.device("mps" if torch.backends.mps.is_available() else
                          "cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load model (ACT joints: direct 6D joint + gripper regression, no IK)
    print(f"\nLoading model from {args.checkpoint}...")
    model = ACTJointsTrajectoryPredictor(target_size=IMAGE_SIZE, n_window=N_WINDOW, freeze_backbone=False)
    checkpoint = torch.load(args.checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()
    if 'min_gripper' in checkpoint and 'max_gripper' in checkpoint:
        # Override module-level range so gripper decoding matches training.
        model_module.MIN_GRIPPER = checkpoint['min_gripper']
        model_module.MAX_GRIPPER = checkpoint['max_gripper']
        print(f"✓ Loaded gripper range from checkpoint: [{checkpoint['min_gripper']:.6f}, {checkpoint['max_gripper']:.6f}]")
    else:
        print("⚠ Checkpoint doesn't have min_gripper/max_gripper! Using defaults.")
    print(f"✓ Loaded model from epoch {checkpoint.get('epoch', 'unknown')}")

    # Connect to the arm unless disabled; otherwise joint state is estimated from the image.
    arm = None
    use_robot_state = True
    calib_path = "robot_models/arm_offsets/rescrew_fromimg.pkl"
    if args.no_arm:
        print("⚠️ --no_arm given, using image-based estimation")
        use_robot_state = False
    elif os.path.exists(calib_path):
        # Local calibration file written by this project; the handle is closed by `with`.
        with open(calib_path, 'rb') as calib_file:
            arm = Arm(pickle.load(calib_file))
        print("✓ Connected to robot for direct joint state reading")
    else:
        print(f"⚠️ Calibration file not found at {calib_path}, falling back to image-based estimation")
        use_robot_state = False

    if arm is not None:
        last_pos = arm.get_pos()
        if not args.dont_write:
            arm.write_pos(_START_POS, slow=False)
        last_pos = _wait_until_arm_settles(arm, last_pos, tol=0.01)
        print("done moving to high position")

    # Setup MuJoCo
    robot_config = SO100AdhesiveConfig()
    mj_model = mujoco.MjModel.from_xml_string(robot_config.xml)
    mj_data = mujoco.MjData(mj_model)

    # Initialize camera
    cap = cv2.VideoCapture(args.camera)
    if not cap.isOpened():
        raise RuntimeError(f"Failed to open camera device {args.camera}")

    # Get first frame to determine resolution
    ret, frame = cap.read()
    while not ret:
        ret, frame = cap.read()
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    height, width = rgb.shape[:2]
    print(f"Camera resolution: {width}x{height}")

    cam_K = None

    print("\n" + "="*60)
    print("Live inference started. Press Ctrl+C to quit.")
    print("="*60)

    first_write = True
    try:
        while True:
            frame_start_time = time.time()

            # Capture frame
            ret, frame = cap.read()
            if not ret:
                print("Failed to read frame from camera")
                continue

            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Refresh the live panel with this frame (and the previous prediction's tracks)
            # before the slow detection / inference steps.
            rgb_vis_raw = _to_unit_float_rgb(rgb)
            _refresh_rgb_panel(rgb_vis_raw, args)

            # Current joint state straight from the arm when available.
            camera_pose_world = None
            if use_robot_state and not args.no_arm:
                joint_state = arm.get_pos()
                mj_data.qpos[:] = mj_data.ctrl[:] = joint_state
                mujoco.mj_forward(mj_model, mj_data)

            # Detect camera pose and intrinsics; cam_K from the last success is passed back in.
            try:
                link_poses, camera_pose_world, cam_K, _, _, _ = detect_and_set_link_poses(
                    rgb, mj_model, mj_data, robot_config, cam_K=cam_K
                )
                position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)

                if not use_robot_state:
                    configuration, _ = estimate_robot_state(mj_model, mj_data, robot_config, link_poses, ik_iterations=35)
                    mj_data.qpos[:] = mj_data.ctrl[:] = configuration.q
                    mujoco.mj_forward(mj_model, mj_data)
            except Exception as e:
                # Detection can fail on individual frames; skip this one and keep streaming.
                print(f"Error detecting link poses: {e}")
                continue

            if _DEBUG_SHOW_RENDER_OVERLAY:
                rendered_img_gt = render_from_camera_pose(mj_model, mj_data, camera_pose_world, cam_K, rgb.shape[0], rgb.shape[1], segmentation=False) / 255
                rendered_img_gt = cv2.resize(rendered_img_gt, (width, height), interpolation=cv2.INTER_LINEAR)
                rgb_vis = rgb.copy() / 255
                rgb_vis = rgb_vis * 0.5 + rendered_img_gt * 0.5
                plt.imshow(rgb_vis)
                plt.show()

            if camera_pose_world is None or cam_K is None:
                print("⚠️ Could not detect camera pose or intrinsics, skipping frame")
                continue

            # Preprocess image for model
            rgb_tensor, rgb_vis_resized = preprocess_image(rgb.copy(), IMAGE_SIZE)

            # Current robot state for conditioning: joints + gripper (no IK)
            current_joints = mj_data.qpos[:N_JOINTS].astype(np.float32)
            current_gripper = float(mj_data.qpos[-1]) if mj_data.qpos.size > 0 else 0.0
            current_joints_t = torch.from_numpy(current_joints).unsqueeze(0).to(device)  # (1, N_JOINTS)
            current_gripper_t = torch.tensor([current_gripper], device=device, dtype=torch.float32)  # (1,)

            # Run ACT joints model (direct joints + gripper regression)
            with torch.no_grad():
                rgb_batch = rgb_tensor.unsqueeze(0).to(device)
                pred_joints, pred_gripper = model(rgb_batch, training=False, current_joints=current_joints_t, current_gripper_state=current_gripper_t)

            pred_joints_np = pred_joints[0].cpu().numpy()  # (N_WINDOW, joints)
            pred_gripper = pred_gripper[0].cpu().numpy()  # (N_WINDOW,)

            # Full qpos trajectory: predicted joints with the predicted gripper appended per step.
            ik_trajectory_qpos = np.array(
                [np.concatenate([pred_joints_np[t], [pred_gripper[t]]]) for t in range(N_WINDOW)],
                dtype=np.float64,
            )

            # Rescale intrinsics from camera resolution to the model's IMAGE_SIZE x IMAGE_SIZE input.
            cam_K_norm = cam_K.copy()
            cam_K_norm[0] /= width
            cam_K_norm[1] /= height
            cam_K_model = cam_K_norm.copy()
            cam_K_model[0] *= IMAGE_SIZE
            cam_K_model[1] *= IMAGE_SIZE
            if not hasattr(main, 'printed_intrinsics'):
                print(f"\nDebug - Intrinsics scaling (for 2D projection): fx={cam_K_model[0,0]:.2f}, fy={cam_K_model[1,1]:.2f}")
                main.printed_intrinsics = True

            # 2D trajectory (model-image pixels) for visualization via forward kinematics
            pred_trajectory_2d = joints_to_trajectory_2d(mj_model, mj_data, ik_trajectory_qpos, camera_pose_world, cam_K_model)

            print(f"✓ Predicted joint trajectory with {len(ik_trajectory_qpos)} waypoints (no IK)")

            # Visualization: 3 panels, or a single raw-RGB panel with --just_rgb
            if not args.dont_visualize:
                if getattr(main, "vis_state", None) is None:
                    plt.ion()
                    if args.just_rgb:
                        fig, ax_rgb = plt.subplots(1, 1, figsize=(10, 6), constrained_layout=True)
                        im_rgb = ax_rgb.imshow(rgb_vis_raw)
                        ax_rgb.set_title("Raw RGB stream")
                        ax_rgb.axis("off")
                        main.vis_state = {"fig": fig, "ax_rgb": ax_rgb, "im_rgb": im_rgb, "tracks_2d_display": None}
                    else:
                        fig, axs = plt.subplots(1, 3, figsize=(14, 5), constrained_layout=True)
                        ax_traj, ax_lines, ax_rgb = axs[0], axs[1], axs[2]
                        im_traj = ax_traj.imshow(rgb_vis_resized.copy())
                        traj_line, = ax_traj.plot([], [], "-", color="lime", linewidth=2, alpha=0.6)
                        colors = plt.cm.viridis(np.linspace(0, 1, N_WINDOW))
                        scat = ax_traj.scatter([], [], s=40, marker="x", linewidths=2)
                        ax_traj.set_title("2D projected trajectory (ACT joints)")
                        ax_traj.axis("off")
                        ts = np.arange(N_WINDOW)
                        gripper_line, = ax_lines.plot(ts, pred_gripper, "s-", color="red", linewidth=2, markersize=4)
                        ax_lines.axhline(y=model_module.MIN_GRIPPER, color="red", linestyle="--", linewidth=1, alpha=0.35)
                        ax_lines.axhline(y=model_module.MAX_GRIPPER, color="green", linestyle="--", linewidth=1, alpha=0.35)
                        ax_lines.set_xlabel("Timestep")
                        ax_lines.set_ylabel("Gripper", color="red")
                        ax_lines.tick_params(axis="y", labelcolor="red")
                        ax_lines.grid(alpha=0.3)
                        ax_lines.set_title("Gripper trajectory")
                        im_rgb = ax_rgb.imshow(rgb_vis_raw)
                        ax_rgb.set_title("Raw RGB stream")
                        ax_rgb.axis("off")
                        main.vis_state = {
                            "fig": fig, "ax_traj": ax_traj, "im_traj": im_traj, "traj_line": traj_line,
                            "scat": scat, "colors": colors, "ax_lines": ax_lines, "gripper_line": gripper_line,
                            "ax_rgb": ax_rgb, "im_rgb": im_rgb, "tracks_2d_display": None,
                        }

                # Tracks are predicted in model-image pixels; scale them to camera resolution.
                tracks_2d_display = pred_trajectory_2d * np.array([width / IMAGE_SIZE, height / IMAGE_SIZE])
                main.vis_state["tracks_2d_display"] = tracks_2d_display

                if not args.just_rgb:
                    st = main.vis_state
                    st["im_traj"].set_data(rgb_vis_resized.copy())
                    st["traj_line"].set_data(pred_trajectory_2d[:, 0], pred_trajectory_2d[:, 1])
                    st["scat"].set_offsets(pred_trajectory_2d[:, :2])
                    st["scat"].set_color(st["colors"])
                    st["gripper_line"].set_ydata(pred_gripper)

                _refresh_rgb_panel(rgb_vis_raw, args)

            print(pred_gripper)

            # Maintain framerate (~10 fps)
            elapsed = time.time() - frame_start_time
            sleep_time = max(0, 0.1 - elapsed)
            if sleep_time > 0:
                time.sleep(sleep_time)

            if arm is not None:
                print(arm.get_pos())

            # Execution is only offered from the second prediction onwards.
            if not first_write and not args.dont_write and arm is not None:
                user_inp = input("Write to robot? (y/n): ")
                if user_inp == 'y':
                    # ik_trajectory_qpos[t, -1] already holds pred_gripper[t].
                    for traj_i, targ_pos in enumerate(ik_trajectory_qpos):
                        print(targ_pos)
                        # NOTE(review): despite the message, the waypoint is still sent; the
                        # skip for this joint-3 range is currently disabled.
                        if 1.4 < targ_pos[3] < 4:
                            for _ in range(20):
                                print("bad ik coming, skipping")
                        arm.write_pos(targ_pos, slow=False)
                        last_pos = _wait_until_arm_settles(
                            arm, last_pos, tol=0.06,
                            on_poll=lambda: _pump_camera(cap, args),
                            poll_interval=0.03, report=True,
                        )
                        print(f"done moving to waypoint {traj_i}")
                elif user_inp == 'c':
                    # Reset: clear the stale overlay and return to the start pose.
                    if getattr(main, "vis_state", None) is not None:
                        main.vis_state["tracks_2d_display"] = None
                    last_pos = arm.get_pos()
                    arm.write_pos(_START_POS, slow=False)
                    last_pos = _wait_until_arm_settles(
                        arm, last_pos, tol=0.01,
                        on_poll=lambda: _pump_camera(cap, args),
                        poll_interval=0.03,
                    )
                    print("done moving to high position")

            first_write = False

    except KeyboardInterrupt:
        print("\nLive inference interrupted by user")

    finally:
        cap.release()
        if not args.dont_visualize:
            plt.close('all')
        print("\n✓ Live inference complete!")


if __name__ == "__main__":
    main()
