"""Run live inference on camera stream: estimate robot state, compute DINO features, and predict trajectory."""
import sys
import os
from pathlib import Path
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from PIL import Image
from torchvision.transforms import functional as TF
import pickle
import argparse
import time
import xml.etree.ElementTree as ET
import mink


sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../.."))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from robot_models.so100_controller import Arm

import mujoco
from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from exo_utils import detect_and_set_link_poses, estimate_robot_state, position_exoskeleton_meshes, render_from_camera_pose, get_link_poses_from_robot
from model import TrajectoryPredictor, MIN_HEIGHT, MAX_HEIGHT, GROUNDPLANE_X_MIN, GROUNDPLANE_X_MAX, GROUNDPLANE_Z_MIN, GROUNDPLANE_Z_MAX
from utils import project_3d_to_2d, rescale_coords, post_process_predictions, recover_3d_from_direct_keypoint_and_height, ik_to_keypoint_and_rotation
from data import KEYPOINTS_LOCAL_M_ALL, KP_INDEX, unproject_patch_to_groundplane, compute_volume_mask_for_patches

# DINO configuration.
# The DINOv3 checkout and weights default to the original author's machine; set the
# DINOV3_REPO_DIR / DINOV3_WEIGHTS_PATH environment variables to run elsewhere without
# editing this file. If only the repo dir is overridden, the weights path follows it.
REPO_DIR = os.environ.get("DINOV3_REPO_DIR", "/Users/cameronsmith/Projects/robotics_testing/random/dinov3")
WEIGHTS_PATH = os.environ.get(
    "DINOV3_WEIGHTS_PATH",
    os.path.join(REPO_DIR, "weights", "dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth"),
)
PATCH_SIZE = 16  # ViT patch edge length in pixels
IMAGE_SIZE = 768  # target input height in pixels (snapped down to the patch grid by resize_transform)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
N_LAYERS = 12  # number of intermediate layers requested from the backbone

# Model configuration
MAX_TIMESTEPS = 3  # Extrema points: close, open, end
GROUNDPLANE_RANGE = 1.0
RES_LOW = 224  # side length (px) of the square panes in the 2D visualization

device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

# Parse arguments
parser = argparse.ArgumentParser(description='Run live inference on camera stream')
parser.add_argument('--camera', '-c', type=int, default=0, help='Camera device ID (default: 0)')
parser.add_argument('--vis_2d', action='store_true', help='Show 2D visualization (like vis_eval_2d.py)')
parser.add_argument('--vis_mujoco', action='store_true', help='Show MuJoCo visualization (like test_ik.py)')
parser.add_argument('--fps', type=float, default=1.0, help='Processing framerate (default: 1.0 fps)')
parser.add_argument('--use_arm', action='store_true', help='Use arm in inference')
args = parser.parse_args()
# The processing loop computes its frame interval as 1 / fps, so zero, negative and NaN
# rates are rejected here with a usage message instead of failing later.
if not args.fps > 0:
    parser.error(f"--fps must be a positive number (got {args.fps})")

if args.use_arm:
    # Load the arm offsets inside a context manager so the file handle is closed.
    # NOTE: pickle.load can execute arbitrary code; this must remain a trusted local file.
    # The path is relative, so it resolves against the current working directory.
    with open("robot_models/arm_offsets/rescrew_fromimg.pkl", 'rb') as offsets_file:
        arm_offsets = pickle.load(offsets_file)
    arm = Arm(arm_offsets)
    print("✓ Connected to robot for direct joint state reading")

# Neither --vis_2d nor --vis_mujoco was requested: fall back to the 2D view so the
# user always gets some visual feedback.
if not args.vis_2d and not args.vis_mujoco:
    args.vis_2d = True
    print("No visualization flags specified, showing 2D visualization by default")

# Initialize camera
print(f"Initializing camera device {args.camera}...")
cap = cv2.VideoCapture(args.camera)
if not cap.isOpened():
    raise RuntimeError(f"Failed to open camera device {args.camera}")

# Get first frame to determine resolution. A camera may return unsuccessful reads at
# first, so keep retrying, but give up after a deadline instead of spinning forever.
FIRST_FRAME_TIMEOUT_S = 10.0
_first_frame_deadline = time.monotonic() + FIRST_FRAME_TIMEOUT_S
ret, frame = cap.read()
while not ret:
    if time.monotonic() > _first_frame_deadline:
        cap.release()
        raise RuntimeError(
            f"Camera device {args.camera} returned no frame within {FIRST_FRAME_TIMEOUT_S:.0f} s"
        )
    time.sleep(0.01)  # avoid a hot busy-wait while the device is not ready
    ret, frame = cap.read()
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# Only float images can be normalized to [0, 1]. A dark uint8 frame whose pixels are all
# 0 or 1 must not be "rescaled" (that would turn every 1 into 255).
if np.issubdtype(rgb.dtype, np.floating) and rgb.max() <= 1.0:
    rgb = (rgb * 255).astype(np.uint8)
H_loaded, W_loaded = rgb.shape[:2]
print(f"Camera resolution: {W_loaded}x{H_loaded}")

# Native resolution the camera intrinsics (cam_K) were calibrated at, before any
# downsampling; renders that use cam_K directly must use this size.
H_orig, W_orig = 1080, 1920

# --- Robot model: MuJoCo model/data used for per-frame state estimation ---------
robot_config = SO100AdhesiveConfig()
mj_model = mujoco.MjModel.from_xml_string(robot_config.xml)
mj_data = mujoco.MjData(mj_model)

# --- DINOv3 backbone, loaded from the local checkout via torch.hub ------------
print("Loading DINO model...")
sys.path.append(REPO_DIR)  # presumably so the checkout's own package imports resolve
dinov3_model = (
    torch.hub.load(REPO_DIR, 'dinov3_vits16plus', source='local', weights=WEIGHTS_PATH)
    .to(device)
    .eval()
)

# Load PCA embedder
pca_path = Path("scratch/dino_pca_embedder.pkl")
if not pca_path.exists():
    print(f"PCA embedder not found at {pca_path}")
    # sys.exit, not the `exit` builtin: `exit` is injected by the site module for the
    # interactive shell and is not guaranteed to exist in scripts.
    sys.exit(1)
# NOTE: pickle.load can execute arbitrary code; only load this from a trusted local path.
with open(pca_path, 'rb') as f:
    pca_data = pickle.load(f)
# Object whose .transform() reduces flattened DINO patch features
# (presumably a fitted sklearn PCA; confirm against the script that wrote the pickle).
pca_embedder = pca_data['pca']

# Process image for DINO
def resize_transform(img: Image.Image, image_size: int = IMAGE_SIZE, patch_size: int = PATCH_SIZE) -> torch.Tensor:
    """Resize a PIL image onto the ViT patch grid and convert it to a tensor.

    The output height is ``int(image_size / patch_size)`` whole patches. The width keeps
    the input aspect ratio and is truncated to a whole number of patches, so both output
    dimensions are exact multiples of ``patch_size``.

    Args:
        img: Input PIL image.
        image_size: Target height in pixels, before rounding down to the patch grid.
        patch_size: Patch edge length in pixels.

    Returns:
        The result of ``TF.to_tensor`` on the resized image, laid out as (C, H, W).
    """
    width, height = img.size
    rows = int(image_size / patch_size)
    cols = int((width * image_size) / (height * patch_size))
    target_hw = (rows * patch_size, cols * patch_size)
    resized = TF.resize(img, target_hw)
    return TF.to_tensor(resized)

# Load trajectory prediction model
model_path = Path("keypoint_testing2/tmpstorage/model.pt")
if not model_path.exists():
    print(f"Model not found at {model_path}")
    # sys.exit, not the `exit` builtin, which only the site module provides for the
    # interactive shell and may be absent in scripts.
    sys.exit(1)

# Read architecture sizes off tensor shapes stored in the checkpoint so the model built
# below matches what was trained; fall back to module defaults when a key is missing.
# NOTE(review): weights_only=False allows arbitrary unpickling; acceptable only because
# this is a trusted local checkpoint.
checkpoint = torch.load(model_path, map_location='cpu', weights_only=False)

if 'height_timestep_emb.weight' not in checkpoint:
    detected_max_timesteps = MAX_TIMESTEPS
    print(f"Could not detect max_timesteps from checkpoint, using {MAX_TIMESTEPS}")
else:
    # The row count of the timestep embedding table is taken as max_timesteps.
    detected_max_timesteps = checkpoint['height_timestep_emb.weight'].shape[0]
    print(f"Detected max_timesteps={detected_max_timesteps} from checkpoint")
    if detected_max_timesteps != MAX_TIMESTEPS:
        print(f"⚠ Warning: Checkpoint has max_timesteps={detected_max_timesteps}, but code expects {MAX_TIMESTEPS}")
        print(f"Using checkpoint's max_timesteps={detected_max_timesteps} for model loading")

if 'dino_proj.weight' in checkpoint:
    # Assumes dino_proj is an nn.Linear (weight laid out as (out, in)), so dim 1 is the
    # per-patch input feature size.
    detected_dino_feat_dim = checkpoint['dino_proj.weight'].shape[1]
    print(f"Detected dino_feat_dim={detected_dino_feat_dim} from checkpoint")
    dino_feat_dim = detected_dino_feat_dim
else:
    dino_feat_dim = 64  # Default
    print(f"Could not detect dino_feat_dim from checkpoint, using {dino_feat_dim}")

model = TrajectoryPredictor(
    dino_feat_dim=dino_feat_dim,
    max_timesteps=detected_max_timesteps,  # from the checkpoint, so strict loading succeeds
    groundplane_range=GROUNDPLANE_RANGE,
    hidden_dim=128,
    num_layers=3,
    num_heads=4,
    num_pos_bands=4,
)
model = model.to(device)
# strict=True: any missing or unexpected parameter raises instead of being skipped.
model.load_state_dict(checkpoint, strict=True)
model.eval()
print(f"✓ Loaded model from {model_path}")

# Setup MuJoCo visualization (always needed for 2D vis with MuJoCo pane)

# Fixed gripper orientation handed to the IK solver as the rotation target for every
# predicted point (the name suggests a median over the training dataset; TODO confirm).
median_dataset_rotation = np.array([[-0.99912433, -0.03007201, -0.02909046],
                                    [-0.04176828,  0.67620482,  0.73552869],
                                    [-0.00244771,  0.73609967, -0.67686874]])

# A second MuJoCo model/data pair plus a mink IK configuration, kept separate from
# mj_model/mj_data, which hold the image-estimated robot state.
mj_model_viz = mj_data_viz = ik_configuration = None
if args.vis_2d or args.vis_mujoco:
    xml_root = ET.fromstring(robot_config.xml)
    worldbody = xml_root.find('worldbody')  # NOTE(review): not used in the visible code
    # The XML is parsed and re-serialized without modification before building the model.
    mj_model_viz = mujoco.MjModel.from_xml_string(ET.tostring(xml_root, encoding='unicode'))
    mj_data_viz = mujoco.MjData(mj_model_viz)
    ik_configuration = mink.Configuration(mj_model_viz)

# Setup 2D visualization if requested.
# Module-level placeholders for matplotlib handles; several are (re)assigned when
# --vis_2d builds the figure, and all start out as None.
fig = ax_rgb = ax_dino = ax_attn = ax_traj = ax_mujoco = ax_height = None
im_rgb = im_dino = im_attn = im_traj = None
scatter_pred_rgb = scatter_pred_dino = scatter_pred_traj = None
scatter_eef_rgb = scatter_eef_dino = scatter_eef_traj = None
bar_height = None

first_write = True  # NOTE(review): not referenced in the visible portion of the script

if args.vis_2d:
    plt.ion()  # interactive mode so the figure can be redrawn without blocking the loop
    fig = plt.figure(figsize=(10, 10))
    # 5x4 grid: three equal-height image rows, then two shorter rows for full-width charts.
    gs = GridSpec(5, 4, figure=fig, hspace=0.3, wspace=0.3, height_ratios=[1, 1, 1, 0.4, 0.4])

    # Row 0: RGB overlay, DINO overlay, t=0 attention map, full-trajectory overview.
    ax_rgb, ax_dino, ax_attn, ax_traj = [fig.add_subplot(gs[0, col]) for col in range(4)]
    # Row 1: attention maps for the same timesteps as the MuJoCo renders in row 2.
    ax_attn_0, ax_attn_1, ax_attn_2, ax_attn_3 = [fig.add_subplot(gs[1, col]) for col in range(4)]
    # Row 2: MuJoCo renders.
    ax_mujoco_0, ax_mujoco_1, ax_mujoco_2, ax_mujoco_3 = [fig.add_subplot(gs[2, col]) for col in range(4)]
    # Rows 3-4: predicted height and gripper charts spanning every column.
    ax_height = fig.add_subplot(gs[3, :])
    ax_gripper = fig.add_subplot(gs[4, :])

    # Seed every image pane with a blank RES_LOW x RES_LOW frame.
    rgb_lowres_placeholder = np.zeros((RES_LOW, RES_LOW, 3), dtype=np.uint8)
    im_rgb = ax_rgb.imshow(rgb_lowres_placeholder)
    im_dino = ax_dino.imshow(rgb_lowres_placeholder)
    im_attn = ax_attn.imshow(np.zeros((RES_LOW, RES_LOW)), cmap='hot', vmin=0, vmax=1)
    im_traj = ax_traj.imshow(rgb_lowres_placeholder)
    im_attn_0, im_attn_1, im_attn_2, im_attn_3 = [
        pane.imshow(np.zeros((RES_LOW, RES_LOW)), cmap='hot', vmin=0, vmax=1)
        for pane in (ax_attn_0, ax_attn_1, ax_attn_2, ax_attn_3)
    ]
    im_mujoco_0, im_mujoco_1, im_mujoco_2, im_mujoco_3 = [
        pane.imshow(rgb_lowres_placeholder)
        for pane in (ax_mujoco_0, ax_mujoco_1, ax_mujoco_2, ax_mujoco_3)
    ]

    # Title each image pane and hide its pixel axes.
    _titled_panes = [
        (ax_rgb, "RGB with Predicted Trajectory"),
        (ax_dino, "DINO Features with Predicted Trajectory"),
        (ax_attn, "Attention Map (t=0)"),
        (ax_traj, "Full Predicted Trajectory"),
    ]
    _titled_panes += [
        (pane, f"Attention Map (t={t})")
        for t, pane in enumerate((ax_attn_0, ax_attn_1, ax_attn_2, ax_attn_3))
    ]
    _titled_panes += [
        (pane, f"MuJoCo Render (t={t})")
        for t, pane in enumerate((ax_mujoco_0, ax_mujoco_1, ax_mujoco_2, ax_mujoco_3))
    ]
    for _pane, _title in _titled_panes:
        _pane.set_title(_title)
        _pane.axis('off')

    # Static labels for the two chart rows; their contents are redrawn every frame.
    for _chart, _ylabel, _title in (
        (ax_height, 'Height (m)', 'Predicted Height Trajectory'),
        (ax_gripper, 'Gripper Value', 'Predicted Gripper Open/Close Trajectory'),
    ):
        _chart.set_xlabel('Timestep', fontsize=10)
        _chart.set_ylabel(_ylabel, fontsize=10)
        _chart.set_title(_title, fontsize=12)
        _chart.grid(alpha=0.3)
    del _titled_panes, _pane, _title, _chart, _ylabel  # keep loop temporaries out of module scope

    plt.tight_layout()
    plt.show(block=False)

# Processing loop: seconds per frame at the requested rate (also used for the
# back-off sleeps when a frame has to be skipped).
frame_interval = 1 / float(args.fps)
print(
    f"Starting live inference at {args.fps} fps. Press 'q' to quit.",
    "Processing frames...",
    sep="\n",
)

try:
    while True:
        frame_start_time = time.time()
        
        # Read frame
        ret, frame = cap.read()
        if not ret:
            print("Failed to read frame")
            continue
        
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if rgb.max() <= 1.0:
            rgb = (rgb * 255).astype(np.uint8)
        
        # Estimate robot state from image
        try:
            link_poses, camera_pose_world, cam_K, _, _, _ = detect_and_set_link_poses(rgb, mj_model, mj_data, robot_config)
            configuration, _ = estimate_robot_state(mj_model, mj_data, robot_config, link_poses, ik_iterations=55)
            mj_data.qpos[:] = configuration.q
            mj_data.ctrl[:] = configuration.q[:len(mj_data.ctrl)]
            mujoco.mj_forward(mj_model, mj_data)
            position_exoskeleton_meshes(robot_config, mj_model, mj_data, link_poses)
            mujoco.mj_forward(mj_model, mj_data)
            camera_pose = camera_pose_world
        except Exception as e:
            print(f"Error estimating robot state: {e}")
            time.sleep(frame_interval)
            continue
        
        # Get current gripper pose
        kp_local = KEYPOINTS_LOCAL_M_ALL[KP_INDEX]
        gripper_body_id = mujoco.mj_name2id(mj_model, mujoco.mjtObj.mjOBJ_BODY, "Fixed_Jaw")
        gripper_pos = mj_data.xpos[gripper_body_id].copy()
        gripper_quat = mj_data.xquat[gripper_body_id].copy()
        from scipy.spatial.transform import Rotation as R
        gripper_rot = R.from_quat(gripper_quat[[1, 2, 3, 0]]).as_matrix()
        current_kp_3d = gripper_rot @ kp_local + gripper_pos
        current_eef_height_norm = float(np.clip((current_kp_3d[2] - MIN_HEIGHT) / (MAX_HEIGHT - MIN_HEIGHT), 0.0, 1.0))
        
        # Compute DINO features
        img_pil = Image.fromarray(rgb).convert("RGB")
        image_resized = resize_transform(img_pil)
        image_resized_norm = TF.normalize(image_resized, mean=IMAGENET_MEAN, std=IMAGENET_STD)
        
        print("doing dino inference")
        with torch.no_grad():
            with torch.autocast(device_type='mps' if device.type == 'mps' else 'cpu', dtype=torch.float32):
                feats = dinov3_model.get_intermediate_layers(
                    image_resized_norm.unsqueeze(0).to(device),
                    n=range(N_LAYERS),
                    reshape=True,
                    norm=True
                )
                x = feats[-1].squeeze().detach().cpu()  # (D, H_patches, W_patches)
                dim = x.shape[0]
                x = x.view(dim, -1).permute(1, 0).numpy()  # (H_patches * W_patches, D)
                
                # Apply PCA
                pca_features_all = pca_embedder.transform(x)  # (H_patches * W_patches, dino_feat_dim)
                
                # Get patch resolution
                h_patches, w_patches = [int(d / PATCH_SIZE) for d in image_resized.shape[1:]]
                dino_feat_dim_actual = pca_features_all.shape[1]  # Get actual feature dimension from PCA
                pca_features_patches = pca_features_all.reshape(h_patches, w_patches, -1)  # (H_patches, W_patches, dino_feat_dim)
        print("done dino inference")
        
        # Convert to tensor format
        dino_features = torch.from_numpy(pca_features_patches).float()  # (H_patches, W_patches, dino_feat_dim)
        dino_tokens_flat = dino_features.view(h_patches * w_patches, dino_feat_dim_actual).float()
        
        # Compute ground-plane coordinates for each patch
        groundplane_coords_list = []
        y_coords, x_coords = np.meshgrid(np.arange(h_patches), np.arange(w_patches), indexing='ij')
        for patch_y, patch_x in zip(y_coords.flatten(), x_coords.flatten()):
            x_gp, z_gp = unproject_patch_to_groundplane(
                patch_x, patch_y, h_patches, w_patches, H_loaded, W_loaded,
                camera_pose, cam_K, ground_y=0.0
            )
            if x_gp is not None and z_gp is not None:
                groundplane_coords_list.append([x_gp, z_gp])
            else:
                groundplane_coords_list.append([0.0, 0.0])
        
        groundplane_coords = torch.from_numpy(np.array(groundplane_coords_list, dtype=np.float32))
        
        # Compute volume mask for patches
        volume_mask = compute_volume_mask_for_patches(h_patches, w_patches, H_loaded, W_loaded, camera_pose, cam_K)
        volume_mask_flat = torch.from_numpy(volume_mask.flatten().astype(np.float32))  # (num_patches,)
        
        # Compute current EEF 2D position in patch coordinates
        current_kp_2d_image = project_3d_to_2d(current_kp_3d, camera_pose, cam_K)
        if current_kp_2d_image is None:
            print("Current EEF projection failed")
            time.sleep(frame_interval)
            continue
        
        current_kp_2d_patches = rescale_coords(
            current_kp_2d_image.reshape(1, 2),
            H_loaded, W_loaded,
            h_patches, w_patches
        )[0]
        
        current_eef_patch_x = int(np.round(np.clip(current_kp_2d_patches[0], 0, w_patches - 1)))
        current_eef_patch_y = int(np.round(np.clip(current_kp_2d_patches[1], 0, h_patches - 1)))
        current_eef_patch_idx = current_eef_patch_y * w_patches + current_eef_patch_x
        
        # Run inference
        with torch.no_grad():
            dino_tokens_batch = dino_tokens_flat.unsqueeze(0).to(device)  # (1, num_patches, dino_feat_dim)
            groundplane_coords_batch = groundplane_coords.unsqueeze(0).to(device)  # (1, num_patches, 2)
            current_eef_patch_idx_batch = torch.tensor([current_eef_patch_idx], dtype=torch.long).to(device)
            current_eef_height_batch = torch.tensor([current_eef_height_norm], dtype=torch.float32).to(device)
            
            volume_mask_batch = volume_mask_flat.unsqueeze(0).to(device)  # (1, num_patches)
            
            print("doing model inference"   )
            attention_scores, heights_pred, grippers_pred = model(
                dino_tokens_batch,
                groundplane_coords_batch,
                current_eef_patch_idx_batch,
                current_eef_height_batch,
                volume_mask=volume_mask_batch,  # Pass volume mask
                use_attention_mask=True  # Use volume mask during inference
            )
            print("done model inference")
            
            attention_scores = attention_scores.squeeze(0).cpu().numpy()  # (max_timesteps, num_patches)
            heights_pred = heights_pred.squeeze(0).cpu().numpy()  # (max_timesteps,)
            grippers_pred = grippers_pred.squeeze(0).cpu().numpy()  # (max_timesteps,)
        
        # Post-process predictions
        trajectory_pred_3d, pred_image_coords, heights_pred_denorm = post_process_predictions(
            attention_scores, heights_pred, h_patches, w_patches, H_loaded, W_loaded, camera_pose, cam_K
        )
        
        # Lift predicted 2D direct keypoint + predicted height to 3D (like test_ik_lifting_raw.py)
        trajectory_3d_lifted = []
        for i in range(len(pred_image_coords)):
            kp_2d = pred_image_coords[i]
            height = heights_pred_denorm[i]
            kp_3d = recover_3d_from_direct_keypoint_and_height(kp_2d, height, camera_pose, cam_K)
            if kp_3d is not None:
                trajectory_3d_lifted.append(kp_3d)
        trajectory_3d_lifted = np.array(trajectory_3d_lifted) if len(trajectory_3d_lifted) > 0 else None
        
        # Run IK once for all trajectory points and save joint positions (used by both vis_2d and vis_mujoco)
        print("doing ik")
        ik_joint_positions = []
        if mj_model_viz is not None and trajectory_3d_lifted is not None and len(trajectory_3d_lifted) > 0:
            # Initialize MuJoCo visualization model state
            mj_data_viz.qpos[:] = mj_data.qpos
            mj_data_viz.ctrl[:] = mj_data.ctrl
            mujoco.mj_forward(mj_model_viz, mj_data_viz)
            position_exoskeleton_meshes(robot_config, mj_model_viz, mj_data_viz, link_poses)
            mujoco.mj_forward(mj_model_viz, mj_data_viz)
            ik_configuration.update(mj_data_viz.qpos)
            
            # Run IK for each trajectory point
            for i, target_pos in enumerate(trajectory_3d_lifted):
                target_gripper_rot = median_dataset_rotation
                
                # Update IK configuration from current state before solving
                link_poses_viz = get_link_poses_from_robot(robot_config, mj_model_viz, mj_data_viz)
                position_exoskeleton_meshes(robot_config, mj_model_viz, mj_data_viz, link_poses_viz)
                mujoco.mj_forward(mj_model_viz, mj_data_viz)
                ik_configuration.update(mj_data_viz.qpos)
                
                # Solve IK with median rotation constraint
                ik_to_keypoint_and_rotation(target_pos, target_gripper_rot, ik_configuration, robot_config, mj_model_viz, mj_data_viz,max_iterations=30)
                
                # Set gripper to predicted value (last joint)
                if i < len(grippers_pred):
                    predicted_gripper_val = grippers_pred[i]
                    # Set the last joint (gripper) to the predicted value
                    if len(mj_data_viz.qpos) > 0:
                        mj_data_viz.qpos[-1] = predicted_gripper_val
                    if len(mj_data_viz.ctrl) > 0:
                        mj_data_viz.ctrl[-1] = predicted_gripper_val
                
                # Update link poses and forward kinematics after IK
                link_poses_viz = get_link_poses_from_robot(robot_config, mj_model_viz, mj_data_viz)
                position_exoskeleton_meshes(robot_config, mj_model_viz, mj_data_viz, link_poses_viz)
                mujoco.mj_forward(mj_model_viz, mj_data_viz)
                
                # Save joint positions after IK (including gripper)
                ik_joint_positions.append(mj_data_viz.qpos.copy())
        print("done ik")
        # Determine render indices (evenly spaced from trajectory) - needed for both attention maps and MuJoCo renders
        num_renders = 4
        if len(trajectory_3d_lifted) > 0:
            # Select evenly spaced indices
            render_indices = np.linspace(0, len(trajectory_3d_lifted) - 1, num_renders, dtype=int)
        else:
            render_indices = [0, 0, 0, 0]
        
        # MuJoCo rendering for 2D vis panes (using pre-computed IK joint positions)
        print("doing mujoco rendering")
        mujoco_renders = []
        if args.vis_2d and len(ik_joint_positions) > 0:
            # Render 4 evenly spaced frames from the trajectory
            if len(trajectory_3d_lifted) > 0:
                
                for render_idx in render_indices:
                    # Set joint positions for this frame (from pre-computed IK, which already includes gripper)
                    mj_data_viz.qpos[:] = ik_joint_positions[render_idx]
                    mj_data_viz.ctrl[:] = ik_joint_positions[render_idx][:len(mj_data_viz.ctrl)]
                    # Ensure gripper is set (should already be in ik_joint_positions, but set explicitly for safety)
                    if render_idx < len(grippers_pred):
                        predicted_gripper_val = grippers_pred[render_idx]
                        if len(mj_data_viz.qpos) > 0:
                            mj_data_viz.qpos[-1] = predicted_gripper_val
                        if len(mj_data_viz.ctrl) > 0:
                            mj_data_viz.ctrl[-1] = predicted_gripper_val
                    mujoco.mj_forward(mj_model_viz, mj_data_viz)
                    
                    # Update link poses and forward kinematics
                    link_poses_viz = get_link_poses_from_robot(robot_config, mj_model_viz, mj_data_viz)
                    position_exoskeleton_meshes(robot_config, mj_model_viz, mj_data_viz, link_poses_viz)
                    mujoco.mj_forward(mj_model_viz, mj_data_viz)

                    cam_K_for_render = cam_K.copy()
                    cam_K_for_render[:2]/=20
                    rendered = render_from_camera_pose(mj_model_viz, mj_data_viz, camera_pose, cam_K_for_render, H_orig//20, W_orig//20)
                    # Resize rendered to match visualization resolution
                    rendered_resized = cv2.resize(rendered, (RES_LOW, RES_LOW), interpolation=cv2.INTER_LINEAR)
                    mujoco_renders.append(rendered_resized)
            else:
                # Fill with placeholder if no trajectory
                for _ in range(num_renders):
                    mujoco_renders.append(np.zeros((RES_LOW, RES_LOW, 3), dtype=np.uint8))
        else:
            # Fill with placeholder if no IK positions
            for _ in range(num_renders):
                mujoco_renders.append(np.zeros((RES_LOW, RES_LOW, 3), dtype=np.uint8))
        print("done mujoco rendering")
        # Update 2D visualization
        print("updating 2d visualization")
        if args.vis_2d:
            # Resize RGB to low-res for visualization
            rgb_lowres = cv2.resize(rgb, (RES_LOW, RES_LOW), interpolation=cv2.INTER_LINEAR)
            
            # Extract DINO vis
            dino_vis = dino_features[:, :, :3].numpy()
            for i in range(3):
                channel = dino_vis[:, :, i]
                min_val, max_val = channel.min(), channel.max()
                if max_val > min_val:
                    dino_vis[:, :, i] = (channel - min_val) / (max_val - min_val)
                else:
                    dino_vis[:, :, i] = 0.5
            dino_vis = np.clip(dino_vis, 0, 1)
            dino_vis_upscaled = cv2.resize(dino_vis, (RES_LOW, RES_LOW), interpolation=cv2.INTER_LINEAR)
            
            # Rescale predicted trajectory to low-res
            predicted_trajectory_lowres = rescale_coords(pred_image_coords, H_loaded, W_loaded, RES_LOW, RES_LOW)
            predicted_trajectory_patches = rescale_coords(pred_image_coords, H_loaded, W_loaded, h_patches, w_patches)
            
            # Update RGB plot
            ax_rgb.clear()
            ax_rgb.imshow(rgb_lowres)
            if len(predicted_trajectory_lowres) > 0:
                ax_rgb.plot(predicted_trajectory_lowres[:, 0], predicted_trajectory_lowres[:, 1], 'r-', linewidth=2, alpha=0.7, label='Pred Trajectory')
                for i, (x, y) in enumerate(predicted_trajectory_lowres):
                    color = plt.cm.plasma(i / len(predicted_trajectory_lowres))
                    ax_rgb.plot(x, y, 'x', color=color, markersize=6, markeredgewidth=1)
            ax_rgb.plot(current_kp_2d_patches[0] * (RES_LOW / w_patches), current_kp_2d_patches[1] * (RES_LOW / h_patches), 
                       'go', markersize=8, markeredgecolor='white', markeredgewidth=1, label='Current EEF', zorder=10)
            ax_rgb.set_title("RGB with Predicted Trajectory")
            ax_rgb.legend(loc='upper right', fontsize=9)
            ax_rgb.axis('off')
            
            # Update DINO plot
            ax_dino.clear()
            ax_dino.imshow(dino_vis_upscaled)
            if len(predicted_trajectory_patches) > 0:
                pred_patches_lowres = predicted_trajectory_patches * (RES_LOW / np.array([w_patches, h_patches]))
                ax_dino.plot(pred_patches_lowres[:, 0], pred_patches_lowres[:, 1], 'r-', linewidth=2, alpha=0.7, label='Pred Trajectory')
                for i, (x, y) in enumerate(pred_patches_lowres):
                    color = plt.cm.plasma(i / len(pred_patches_lowres))
                    ax_dino.plot(x, y, 'x', color=color, markersize=6, markeredgewidth=1)
            ax_dino.plot(current_kp_2d_patches[0] * (RES_LOW / w_patches), current_kp_2d_patches[1] * (RES_LOW / h_patches),
                       'go', markersize=8, markeredgecolor='white', markeredgewidth=1, label='Current EEF', zorder=10)
            ax_dino.set_title("DINO Features with Predicted Trajectory")
            ax_dino.legend(loc='upper right', fontsize=9)
            ax_dino.axis('off')
            
            # Update attention map
            attention_map = attention_scores[0].reshape(h_patches, w_patches)
            # Use volume mask for visualization
            volume_mask_2d = volume_mask  # (h_patches, w_patches) boolean
            # Replace masked values with minimum value from valid region for better visualization
            if np.any(volume_mask_2d):
                min_valid_value = attention_map[volume_mask_2d].min()
                attention_map[~volume_mask_2d] = min_valid_value
            attention_map_norm = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min() + 1e-8)
            attention_map_upscaled = cv2.resize(attention_map_norm, (RES_LOW, RES_LOW), interpolation=cv2.INTER_LINEAR)
            ax_attn.clear()
            ax_attn.imshow(attention_map_upscaled, cmap='hot', vmin=0, vmax=1)
            ax_attn.set_title("Attention Map (t=0) - Volume Masked")
            ax_attn.axis('off')
            
            # Update trajectory overview
            ax_traj.clear()
            ax_traj.imshow(rgb_lowres)
            if len(predicted_trajectory_lowres) > 0:
                ax_traj.plot(predicted_trajectory_lowres[:, 0], predicted_trajectory_lowres[:, 1], 'r-', linewidth=3, alpha=0.8, label='Predicted')
                for i, (x, y) in enumerate(predicted_trajectory_lowres):
                    color = plt.cm.plasma(i / len(predicted_trajectory_lowres))
                    ax_traj.plot(x, y, 'o', color=color, markersize=8, markeredgecolor='white', markeredgewidth=1)
            ax_traj.plot(current_kp_2d_patches[0] * (RES_LOW / w_patches), current_kp_2d_patches[1] * (RES_LOW / h_patches),
                       'go', markersize=10, markeredgecolor='white', markeredgewidth=2, label='Start EEF', zorder=10)
            ax_traj.set_title("Full Predicted Trajectory")
            ax_traj.legend(loc='upper right', fontsize=9)
            ax_traj.axis('off')
            
            # Update attention map panes (for same timesteps as MuJoCo renders)
            attention_axes = [ax_attn_0, ax_attn_1, ax_attn_2, ax_attn_3]
            for i, ax_attn_timestep in enumerate(attention_axes):
                ax_attn_timestep.clear()
                if len(render_indices) > i and render_indices[i] < attention_scores.shape[0]:
                    # Extract attention map for this timestep
                    attn_idx = render_indices[i]
                    attention_map = attention_scores[attn_idx].reshape(h_patches, w_patches)
                    # Use volume mask for visualization
                    volume_mask_2d = volume_mask  # (h_patches, w_patches) boolean
                    # Replace masked values with minimum value from valid region for better visualization
                    if np.any(volume_mask_2d):
                        min_valid_value = attention_map[volume_mask_2d].min()
                        attention_map[~volume_mask_2d] = min_valid_value
                    attention_map_norm = (attention_map - attention_map.min()) / (attention_map.max() - attention_map.min() + 1e-8)
                    attention_map_upscaled = cv2.resize(attention_map_norm, (RES_LOW, RES_LOW), interpolation=cv2.INTER_LINEAR)
                    ax_attn_timestep.imshow(attention_map_upscaled, cmap='hot', vmin=0, vmax=1)
                    ax_attn_timestep.set_title(f"Attention Map (t={attn_idx}) - Volume Masked")
                else:
                    ax_attn_timestep.imshow(np.zeros((RES_LOW, RES_LOW)), cmap='hot', vmin=0, vmax=1)
                    ax_attn_timestep.set_title(f"Attention Map (t={i})")
                ax_attn_timestep.axis('off')
            
            # Update MuJoCo rendered robot panes (4 evenly spaced frames)
            mujoco_axes = [ax_mujoco_0, ax_mujoco_1, ax_mujoco_2, ax_mujoco_3]
            for i, ax_mujoco in enumerate(mujoco_axes):
                ax_mujoco.clear()
                if i < len(mujoco_renders) and mujoco_renders[i] is not None and len(mujoco_renders[i].shape) == 3:
                    ax_mujoco.imshow(mujoco_renders[i])
                    if len(render_indices) > i:
                        ax_mujoco.set_title(f"MuJoCo Render (t={render_indices[i]})")
                    else:
                        ax_mujoco.set_title(f"MuJoCo Render (t={i})")
                else:
                    ax_mujoco.imshow(rgb_lowres)
                    ax_mujoco.set_title(f"MuJoCo Render (t={i})")
                ax_mujoco.axis('off')
            
            # Update height chart
            ax_height.clear()
            timesteps = np.arange(1, len(heights_pred_denorm) + 1)
            ax_height.bar(timesteps - 0.2, heights_pred_denorm, width=0.4, alpha=0.6, color='red', label='Pred Height')
            ax_height.set_xlabel('Timestep', fontsize=10)
            ax_height.set_ylabel('Height (m)', fontsize=10)
            ax_height.set_title('Predicted Height Trajectory', fontsize=12)
            ax_height.legend(fontsize=9)
            ax_height.grid(alpha=0.3)
            
            # Redraw the gripper chart: one bar per predicted timestep, x-axis 1-indexed.
            ax_gripper.clear()
            timesteps_gripper = np.arange(len(grippers_pred)) + 1
            ax_gripper.bar(
                timesteps_gripper,
                grippers_pred,
                width=0.4,
                alpha=0.6,
                color='orange',
                label='Pred Gripper',
            )
            ax_gripper.set_title('Predicted Gripper Open/Close Trajectory', fontsize=12)
            ax_gripper.set_xlabel('Timestep', fontsize=10)
            ax_gripper.set_ylabel('Gripper Value', fontsize=10)
            ax_gripper.grid(alpha=0.3)
            ax_gripper.legend(fontsize=9)

            # Render the figure and process pending GUI events so the window updates.
            fig.canvas.draw()
            fig.canvas.flush_events()
        print("done updating 2d visualization")
        # Update MuJoCo visualization (separate window) - using pre-computed IK joint positions.
        # Poses the viz model at the first trajectory point, renders it from the calibrated
        # camera pose, and shows a 50/50 blend of the camera frame and the render.
        print("updating mujoco visualization")
        if args.vis_mujoco and len(ik_joint_positions) > 0:
            # Use the first trajectory point's pre-computed joint positions (which already includes gripper).
            # ctrl receives only the leading len(ctrl) entries of that joint vector.
            mj_data_viz.qpos[:] = ik_joint_positions[0]
            mj_data_viz.ctrl[:] = ik_joint_positions[0][:len(mj_data_viz.ctrl)]
            # Ensure gripper is set (should already be in ik_joint_positions, but set explicitly for safety).
            # Uses the first gripper prediction, i.e. the one paired with ik_joint_positions[0].
            # NOTE(review): assumes the gripper is the LAST qpos/ctrl entry — confirm against the model XML.
            if len(grippers_pred) > 0:
                predicted_gripper_val = grippers_pred[0]
                if len(mj_data_viz.qpos) > 0:
                    mj_data_viz.qpos[-1] = predicted_gripper_val
                if len(mj_data_viz.ctrl) > 0:
                    mj_data_viz.ctrl[-1] = predicted_gripper_val
            mujoco.mj_forward(mj_model_viz, mj_data_viz)

            # Update link poses and forward kinematics: recompute link poses from the new joint
            # state, move the exoskeleton meshes to match, then run mj_forward again before rendering.
            link_poses_viz = get_link_poses_from_robot(robot_config, mj_model_viz, mj_data_viz)
            position_exoskeleton_meshes(robot_config, mj_model_viz, mj_data_viz, link_poses_viz)
            mujoco.mj_forward(mj_model_viz, mj_data_viz)

            # Render at original resolution with cam_K (calibrated for H_orig x W_orig)
            rendered = render_from_camera_pose(mj_model_viz, mj_data_viz, camera_pose, cam_K, H_orig, W_orig)
            # Resize both RGB and rendered to match for overlay
            # NOTE(review): the blend below assumes `rendered` has the same channel layout (RGB) as `rgb` — confirm.
            rgb_for_overlay = cv2.resize(rgb, (W_orig, H_orig), interpolation=cv2.INTER_LINEAR)
            rendered_resized = cv2.resize(rendered, (W_orig, H_orig), interpolation=cv2.INTER_LINEAR)

            # Display results: equal-weight blend, cast to uint8, converted RGB->BGR for cv2.imshow.
            cv2.imshow('Live Inference - MuJoCo', cv2.cvtColor(
                (rgb_for_overlay * 0.5 + rendered_resized * 0.5).astype(np.uint8), cv2.COLOR_RGB2BGR
            ))
            # Pressing 'q' in the OpenCV window exits the enclosing live-inference loop.
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
        print("done updating mujoco visualization")
        # Optionally execute the predicted trajectory on the physical arm.
        if args.use_arm:
            def _write_and_settle(target, tol=0.01, poll_s=0.05, stable_polls=3, timeout_s=10.0):
                """Command ``target`` on the arm and block until it stops moving.

                Readings are taken only after the write, and the position must change by
                less than ``tol`` (max abs per-joint delta) for ``stable_polls`` consecutive
                polls spaced ``poll_s`` seconds apart. A servo that has not started moving
                yet is therefore not mistaken for one that has finished. The target itself
                is not required to be reached exactly, so a gripper blocked by an object
                still counts as settled.

                Returns True once settled, or False (after printing a warning) if the arm
                is still moving after ``timeout_s`` seconds.
                """
                arm.write_pos(target, slow=False)
                deadline = time.monotonic() + timeout_s
                prev = np.asarray(arm.get_pos(), dtype=float)
                n_stable = 0
                while time.monotonic() < deadline:
                    time.sleep(poll_s)  # avoid hammering the servo bus with back-to-back reads
                    curr = np.asarray(arm.get_pos(), dtype=float)
                    if np.max(np.abs(curr - prev)) < tol:
                        n_stable += 1
                        if n_stable >= stable_polls:
                            return True
                    else:
                        n_stable = 0
                    prev = curr
                print(f"warning: arm did not settle within {timeout_s:.1f}s")
                return False

            # Intermediate "high" joint configuration visited between trajectory points.
            # Its last entry (gripper) is overwritten with the current gripper reading below.
            start_pos = [0, 4.14158842, 1.65811715, 1.60476692, 1.60733803, 1]
            # Execute at most the first 5 trajectory points. zip() stops early if there are
            # fewer gripper predictions than IK solutions, instead of raising IndexError.
            for i, (targ_pos, gripper_val) in enumerate(zip(ik_joint_positions[:5], grippers_pred)):
                # Keep the gripper where it currently is while the arm itself moves.
                start_pos[-1] = targ_pos[-1] = arm.get_pos()[-1]

                # Between trajectory points, retreat to the high position first.
                if i:
                    _write_and_settle(start_pos)
                    print("done moving to high position")

                # Reach the trajectory point with the gripper unchanged...
                _write_and_settle(targ_pos)

                # ...then actuate the gripper to its predicted value.
                targ_pos[-1] = gripper_val
                _write_and_settle(targ_pos)
                print("done moving gripper")
                # NOTE(review): flag not read in this chunk; presumably consumed elsewhere in the file.
                first_write = False

        
        # Maintain framerate
        #elapsed = time.time() - frame_start_time
        #sleep_time = max(0, frame_interval - elapsed)
        #if sleep_time > 0:
        #    time.sleep(sleep_time)

except KeyboardInterrupt:
    print("\nStopping live inference...")

finally:
    cap.release()
    if args.vis_mujoco:
        cv2.destroyAllWindows()
    if args.vis_2d:
        plt.ioff()
    print("✓ Live inference stopped")
