"""Train a ResNet model to predict next 5 gripper translations from current frame."""
import argparse
import os
import sys
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/vggt")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/MoGe")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import viser
import mujoco
import colorsys
from scipy.spatial.transform import Rotation as R
from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from ExoConfigs.alignment_board import ALIGNMENT_BOARD_CONFIG
from exo_utils import (
    get_link_poses_from_robot,
    position_exoskeleton_meshes,
    combine_xmls,
)

# Configuration
OUTPUT_RES = 224*2  # Side length (pixels) of the square model input; passed to preload_dataset as output_res
OUTPUT_DIM = 3  # Size of the regression output: one 3D coordinate (x, y, z)

def get_gripper_position_from_joint_state(joint_state, model, robot_config):
    """Compute the gripper position for a joint configuration with MuJoCo.

    The joint state is written into a fresh ``MjData`` (positions, and the
    leading entries as controls), the exoskeleton meshes are positioned from
    the resulting link poses, and the mocap position of the
    ``fixed_gripper_exo_mesh`` body is returned as a copy.
    """
    sim_data = mujoco.MjData(model)
    sim_data.qpos[:] = joint_state
    num_controls = len(sim_data.ctrl)
    sim_data.ctrl[:] = joint_state[:num_controls]
    mujoco.mj_forward(model, sim_data)

    # Move the exoskeleton meshes onto the links, then refresh the kinematics.
    position_exoskeleton_meshes(
        robot_config,
        model,
        sim_data,
        get_link_poses_from_robot(robot_config, model, sim_data),
    )
    mujoco.mj_forward(model, sim_data)

    # The gripper position is read from the mocap slot of the exoskeleton mesh body.
    body_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, "fixed_gripper_exo_mesh")
    mocap_slot = model.body_mocapid[body_id]
    return sim_data.mocap_pos[mocap_slot].copy()

def project_3d_to_2d(point_3d, camera_pose, cam_K):
    """Project a world-frame 3D point into pixel coordinates.

    Args:
        point_3d: (3,) array, point in the world frame.
        camera_pose: (4, 4) world-to-camera transformation matrix.
        cam_K: (3, 3) camera intrinsic matrix.

    Returns:
        (2,) array of image coordinates, or None when the point does not lie
        strictly in front of the camera (camera-frame z <= 0).
    """
    homogeneous_world = np.append(point_3d, 1.0)
    in_camera = (camera_pose @ homogeneous_world)[:3]

    depth = in_camera[2]
    if depth <= 0:
        # On or behind the image plane: no valid projection.
        return None

    pixel_h = cam_K @ in_camera
    return pixel_h[:2] / pixel_h[2]

def draw_keypoint_on_image(img, point_2d, color=(255, 0, 0), size=10):
    """Draw a keypoint marker on ``img`` in place and return ``img``.

    Nothing is drawn when ``point_2d`` is None or when its integer-truncated
    pixel location falls outside the image. The marker is a filled circle in
    ``color`` with a white ring and a white cross on top.
    """
    if point_2d is None:
        return img

    height, width = img.shape[:2]
    px = int(point_2d[0])
    py = int(point_2d[1])

    if px < 0 or px >= width or py < 0 or py >= height:
        return img

    white = (255, 255, 255)
    center = (px, py)
    cv2.circle(img, center, size, color, -1)       # filled disc
    cv2.circle(img, center, size + 2, white, 2)    # white outline for contrast
    cv2.line(img, (px - size, py), (px + size, py), white, 2)  # horizontal bar
    cv2.line(img, (px, py - size), (px, py + size), white, 2)  # vertical bar
    return img

def _preload_as_numpy(array_like):
    """Return ``array_like`` as a NumPy array, converting from a torch tensor if needed."""
    if isinstance(array_like, torch.Tensor):
        return array_like.detach().cpu().numpy()
    return array_like


def _preload_dino_features(seq_dir, image_hw, output_res):
    """Load the DINO feature map of one sequence.

    Returns:
        (features, vis): ``features`` is a (32, output_res, output_res) tensor
        (zeros when ``dino_features_hw.pt`` is missing); ``vis`` is the first
        three feature channels at the stored resolution, or None when missing.
    """
    dino_path = seq_dir / "dino_features_hw.pt"
    if not dino_path.exists():
        return torch.zeros(32, output_res, output_res), None

    dino_features_hw = torch.load(dino_path)  # (H, W, 32) or (H*W, 32)
    if len(dino_features_hw.shape) == 2:
        # Flattened features: prefer the image shape (as the pointmap loader
        # does); fall back to a near-square factorisation only when the image
        # shape does not account for every row.
        num_rows = dino_features_hw.shape[0]
        img_h, img_w = image_hw
        side = int(np.sqrt(num_rows))
        if img_h * img_w != num_rows and side > 0 and num_rows % side == 0:
            grid_h, grid_w = side, num_rows // side
        else:
            grid_h, grid_w = img_h, img_w  # reshape raises on a size mismatch
        dino_features_hw = dino_features_hw.reshape(grid_h, grid_w, 32)

    dino_np = _preload_as_numpy(dino_features_hw)
    dino_resized = F.interpolate(
        torch.from_numpy(dino_np).permute(2, 0, 1).float().unsqueeze(0),
        size=(output_res, output_res),
        mode='bilinear',
        align_corners=False,
    ).squeeze(0)  # (32, H, W)

    # Keep the first 3 channels for visualization.
    dino_vis = dino_np[:, :, :3] if len(dino_np.shape) == 3 else dino_np[:, :3]
    return dino_resized, dino_vis


def _preload_pointmap(seq_dir, image_hw, output_res):
    """Load the raw start-frame pointmap of one sequence.

    Returns:
        (xyz, vis): ``xyz`` is a (3, output_res, output_res) tensor (zeros when
        the file is missing or its point count does not match the image);
        ``vis`` is ``{'points': (H, W, 3), 'colors': (H, W, 3)}`` or None.
    """
    blank = torch.zeros(3, output_res, output_res)
    pointmap_path = seq_dir / "pointmap_start_raw.pt"
    if not pointmap_path.exists():
        return blank, None

    pointmap = torch.load(pointmap_path)
    points = _preload_as_numpy(pointmap["points"])  # (N, 3)
    colors = _preload_as_numpy(pointmap["colors"])  # (N, 3)

    img_h, img_w = image_hw
    if img_h * img_w != len(points):
        return blank, None

    points_2d = points.reshape(img_h, img_w, 3)
    points_resized = F.interpolate(
        torch.from_numpy(points_2d).permute(2, 0, 1).float().unsqueeze(0),
        size=(output_res, output_res),
        mode='bilinear',
        align_corners=False,
    ).squeeze(0)  # (3, H, W)
    colors_2d = colors.reshape(img_h, img_w, 3)
    return points_resized, {'points': points_2d, 'colors': colors_2d}


def preload_dataset(sequence_dirs, processed_dir, use_dino_pointmap=True, output_res=224):
    """Preload all samples of the processed keyboard grasp dataset into memory.

    A sequence contributes one sample when it contains both
    ``start_kprender.png`` (start frame with keypoints rendered) and
    ``gripper_pose_grasp.npy`` (4x4 gripper pose at grasp); otherwise it is
    skipped with a warning. Both files are checked *before* anything is
    appended, so all five returned lists always have the same length and share
    indices.

    Args:
        sequence_dirs: iterable of sequence directories.
        processed_dir: unused; kept for interface compatibility.
        use_dino_pointmap: when True, each input is RGB(3) + DINO(32) + XYZ(3)
            = 38 channels; missing DINO/pointmap files are replaced by zeros.
            When False, the input is the 3-channel RGB image only.
        output_res: side length of the square image the inputs are resized to.

    Returns:
        ``(images_list, target_positions_list, episode_ids_list,
        dino_features_list, pointmap_list)``. The last two hold per-sample
        visualization data (or None entries).
    """
    print(f"Preloading {len(sequence_dirs)} sequences...")

    images_list = []
    target_positions_list = []  # Target 3D gripper position at grasp
    episode_ids_list = []

    # For visualization; index-aligned with images_list
    dino_features_list = []
    pointmap_list = []

    for seq_dir in tqdm(sequence_dirs, desc="Loading sequences"):
        seq_dir = Path(seq_dir)
        sequence_id = seq_dir.name

        # Validate required files up front so a skipped sequence never leaves
        # stray entries in the visualization lists.
        start_kprender_path = seq_dir / "start_kprender.png"
        if not start_kprender_path.exists():
            print(f"  ⚠ Skipping {sequence_id}: start_kprender.png not found")
            continue

        gripper_pose_path = seq_dir / "gripper_pose_grasp.npy"
        if not gripper_pose_path.exists():
            print(f"  ⚠ Skipping {sequence_id}: gripper_pose_grasp.npy not found")
            continue

        bgr = cv2.imread(str(start_kprender_path))
        if bgr is None:
            # imread signals an unreadable/corrupt file by returning None.
            print(f"  ⚠ Skipping {sequence_id}: start_kprender.png could not be read")
            continue
        # imread (default flags) yields 8-bit BGR, so no [0, 1] rescaling is needed.
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        image_hw = rgb.shape[:2]

        rgb_resized = cv2.resize(rgb, (output_res, output_res), interpolation=cv2.INTER_LINEAR)
        input_tensor = torch.from_numpy(rgb_resized).permute(2, 0, 1).float() / 255.0  # (3, H, W)

        dino_vis = None
        pointmap_vis = None
        if use_dino_pointmap:
            dino_resized, dino_vis = _preload_dino_features(seq_dir, image_hw, output_res)
            points_resized, pointmap_vis = _preload_pointmap(seq_dir, image_hw, output_res)
            # Concatenate: RGB(3) + DINO(32) + XYZ(3) = 38 channels
            input_tensor = torch.cat([input_tensor, dino_resized, points_resized], dim=0)

        gripper_pose = np.load(gripper_pose_path)  # (4, 4)
        target_position = gripper_pose[:3, 3]  # translation part of the pose

        images_list.append(input_tensor)
        target_positions_list.append(torch.from_numpy(target_position).float())
        episode_ids_list.append(sequence_id)
        dino_features_list.append(dino_vis)
        pointmap_list.append(pointmap_vis)

    return images_list, target_positions_list, episode_ids_list, dino_features_list, pointmap_list

class GripperPositionDataset(Dataset):
    """Thin ``Dataset`` view over three index-aligned, preloaded lists.

    Each item is the tuple ``(image, target_position, episode_id)`` taken from
    the same index of the lists given at construction.
    """

    def __init__(self, images_list, target_positions_list, episode_ids_list):
        self.images_list = images_list
        self.target_positions_list = target_positions_list
        self.episode_ids_list = episode_ids_list

    def __len__(self):
        # The image list defines the dataset size.
        return len(self.images_list)

    def __getitem__(self, idx):
        image = self.images_list[idx]
        target_position = self.target_positions_list[idx]
        episode_id = self.episode_ids_list[idx]
        return image, target_position, episode_id

class GripperPositionPredictor(nn.Module):
    """Pretrained ResNet backbone with a linear head regressing a 3D gripper position.

    When ``input_channels`` differs from 3 the stem convolution is replaced by
    one with the requested number of input channels. For more than 3 channels
    the pretrained RGB filters are copied into the first three input channels
    and the remaining ones start at zero.
    """

    def __init__(self, input_channels=38, output_dim=3, resnet_type='resnet18'):
        super().__init__()
        self.output_dim = output_dim

        # Pick the torchvision constructor, then load pretrained weights.
        if resnet_type == 'resnet18':
            build_resnet = models.resnet18
        elif resnet_type == 'resnet34':
            build_resnet = models.resnet34
        elif resnet_type == 'resnet50':
            build_resnet = models.resnet50
        else:
            raise ValueError(f"Unknown resnet_type: {resnet_type}")
        resnet = build_resnet(pretrained=True)

        if input_channels != 3:
            # Swap the stem conv for one that accepts `input_channels` planes.
            pretrained_stem = resnet.conv1
            new_stem = nn.Conv2d(
                input_channels,
                pretrained_stem.out_channels,
                kernel_size=pretrained_stem.kernel_size,
                stride=pretrained_stem.stride,
                padding=pretrained_stem.padding,
                bias=pretrained_stem.bias is not None,
            )
            resnet.conv1 = new_stem
            if input_channels > 3:
                # Reuse the RGB filters; extra channels contribute nothing at init.
                with torch.no_grad():
                    new_stem.weight[:, :3] = pretrained_stem.weight
                    new_stem.weight[:, 3:] = 0.0

        # Everything except the classification FC becomes the feature extractor.
        feature_layers = list(resnet.children())[:-1]
        self.backbone = nn.Sequential(*feature_layers)
        self.fc = nn.Linear(resnet.fc.in_features, output_dim)

    def forward(self, x):
        """Map a (B, C, H, W) batch to (B, output_dim) predictions."""
        pooled = self.backbone(x)  # (B, F, 1, 1); F = 512 for ResNet18
        flat = pooled.view(pooled.size(0), -1)  # (B, F)
        return self.fc(flat)

def _train_first_index_by_episode(episode_ids):
    """Map each episode id to the index of its first occurrence in ``episode_ids``."""
    index_of = {}
    for position, episode_id in enumerate(episode_ids):
        index_of.setdefault(episode_id, position)
    return index_of


def _train_dino_to_image(dino_pca):
    """Turn stored 3-channel DINO features into a (3, OUTPUT_RES, OUTPUT_RES) tensor in [0, 1].

    Each channel is min-max normalised independently; a constant channel is
    rendered as 0.5.
    """
    if isinstance(dino_pca, torch.Tensor):
        dino_pca = dino_pca.numpy()
    if len(dino_pca.shape) == 2:
        # Flattened (N, 3): reshape only when N factorises near a square.
        num_rows = dino_pca.shape[0]
        side = int(np.sqrt(num_rows))
        other = num_rows // side
        if side * other == num_rows:
            dino_pca = dino_pca.reshape(side, other, 3)

    # float32 copy: keeps the source untouched and is a dtype cv2.resize accepts.
    normalised = dino_pca.astype(np.float32)
    for c in range(3):
        channel = normalised[:, :, c]
        lo = channel.min()
        hi = channel.max()
        if hi > lo:
            normalised[:, :, c] = (channel - lo) / (hi - lo)
        else:
            normalised[:, :, c] = 0.5  # Default to middle if constant

    resized = cv2.resize(normalised, (OUTPUT_RES, OUTPUT_RES), interpolation=cv2.INTER_LINEAR)
    return torch.from_numpy(resized).permute(2, 0, 1).float()


def _train_pointmap_to_image(pointmap):
    """Turn a pointmap's ``'colors'`` image into a (3, OUTPUT_RES, OUTPUT_RES) tensor in [0, 1]."""
    colors = pointmap['colors']
    if isinstance(colors, torch.Tensor):
        colors = colors.numpy()

    # Values above 1 are treated as 0-255 colors.
    if colors.max() > 1.0:
        colors = colors.astype(np.float32) / 255.0
    else:
        colors = colors.astype(np.float32)

    resized = cv2.resize(colors, (OUTPUT_RES, OUTPUT_RES), interpolation=cv2.INTER_LINEAR)
    return torch.from_numpy(np.clip(resized, 0.0, 1.0)).permute(2, 0, 1).float()


def _train_log_sample_grid(writer, tag, step, sample_ep_ids, index_of, per_episode, to_image):
    """Log a 2-column grid for up to four sampled episodes.

    Episodes without data (unknown id, index out of range, None entry) or whose
    conversion raises ValueError/IndexError are shown as black tiles.
    """
    tiles = []
    for ep_id in list(sample_ep_ids)[:4]:
        tile = None
        idx = index_of.get(ep_id)
        if idx is not None and idx < len(per_episode) and per_episode[idx] is not None:
            try:
                tile = to_image(per_episode[idx])
            except (ValueError, IndexError):
                tile = None
        if tile is None:
            tile = torch.zeros(3, OUTPUT_RES, OUTPUT_RES)
        tiles.append(tile)

    if tiles:
        grid = vutils.make_grid(torch.stack(tiles), nrow=2, normalize=False, pad_value=1.0)
        writer.add_image(tag, grid, step)


def train_model(model, train_loader, val_loader, lr=1e-3, epochs=1000, log_dir="runs/gripper_position_predictor", checkpoint_path=None,
                train_dino_features=None, train_pointmaps=None, val_dino_features=None, val_pointmaps=None,
                run_name=None):
    """Train ``model`` with Adam + MSE, logging to TensorBoard and saving the best weights.

    Args:
        model: network mapping an image batch to (B, 3) positions.
        train_loader, val_loader: loaders yielding ``(images, target_positions, episode_ids)``.
        lr: Adam learning rate.
        epochs: total number of epochs (a resumed run continues up to this count).
        log_dir: TensorBoard log directory.
        checkpoint_path: optional checkpoint to resume from (model, optimizer,
            epoch and best validation loss are restored).
        train_dino_features, train_pointmaps: optional per-sample visualization
            data, index-aligned with ``train_loader.dataset.episode_ids_list``.
        val_dino_features, val_pointmaps: accepted for interface compatibility; unused.
        run_name: stem of the best-model file ``scratch/policies/<run_name>.pt``.
            When None, the module-level ``args.run_name`` is used if the file
            runs as a script (the previous behavior), else the last component
            of ``log_dir``.

    Raises:
        ValueError: if either loader yields no batches.

    Side effects: writes ``scratch/policies/<run_name>.pt`` whenever validation
    loss improves, and ``scratch/policies/gripper_translation_predictor_checkpoint.pt``
    every 50 epochs.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    if run_name is None:
        # Previously read the global `args` directly, which only exists when
        # this file is executed as a script.
        cli_args = globals().get("args")
        run_name = getattr(cli_args, "run_name", None) or os.path.basename(os.path.normpath(log_dir))

    device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir)

    best_val_loss = float('inf')
    start_epoch = 0

    if checkpoint_path and os.path.exists(checkpoint_path):
        print(f"Loading checkpoint from {checkpoint_path}...")
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        best_val_loss = checkpoint['best_val_loss']
        print(f"  Resumed from epoch {start_epoch}, best val loss: {best_val_loss:.6f}")

    # Episode id -> dataset index, built once instead of per logged sample.
    index_of = None
    if train_dino_features is not None or train_pointmaps is not None:
        index_of = _train_first_index_by_episode(train_loader.dataset.episode_ids_list)

    for epoch in range(start_epoch, epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        step = epoch + 1

        # Training
        model.train()
        train_loss = 0.0
        for images, target_positions, ep_ids in tqdm(train_loader, desc="Training"):
            images = images.to(device)  # (B, C, H, W)
            target_positions = target_positions.to(device)  # (B, 3)

            optimizer.zero_grad()
            positions_pred = model(images)  # (B, 3)
            loss = criterion(positions_pred, target_positions)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)  # mean of per-batch losses

        # Validation
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for images, target_positions, ep_ids in val_loader:
                images = images.to(device)
                target_positions = target_positions.to(device)
                positions_pred = model(images)
                val_loss += criterion(positions_pred, target_positions).item()

        val_loss /= len(val_loader)

        writer.add_scalar('Loss/Train', train_loss, step)
        writer.add_scalar('Loss/Val', val_loss, step)

        # Log visualizations every 10 epochs
        if epoch % 10 == 0:
            sample_images, _sample_positions, sample_ep_ids = next(iter(train_loader))

            # First three channels are the RGB image (with rendered keypoints).
            rgb = np.clip(sample_images[:4, :3].cpu().numpy(), 0, 1)
            grid = vutils.make_grid(torch.from_numpy(rgb), nrow=2, normalize=False, pad_value=1.0)
            writer.add_image('Train/Input_RGB_with_Keypoints', grid, step)

            if train_dino_features is not None:
                _train_log_sample_grid(
                    writer, 'Train/DINO_PCA_Features', step, sample_ep_ids,
                    index_of, train_dino_features, _train_dino_to_image,
                )
            if train_pointmaps is not None:
                _train_log_sample_grid(
                    writer, 'Train/Pointmap_Visualization', step, sample_ep_ids,
                    index_of, train_pointmaps, _train_pointmap_to_image,
                )

        report_epoch = step % 50 == 0
        if report_epoch:
            print(f"Epoch {epoch+1}/{epochs}: Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}")

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            os.makedirs("scratch/policies", exist_ok=True)
            torch.save(model.state_dict(), f"scratch/policies/{run_name}.pt")
            if report_epoch:
                print(f"  Saved best model (val_loss: {val_loss:.6f})")

        # Save checkpoint every 50 epochs
        if report_epoch:
            checkpoint = {
                'epoch': step,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_loss': best_val_loss,
            }
            checkpoint_path_save = os.path.join("scratch/policies", "gripper_translation_predictor_checkpoint.pt")
            torch.save(checkpoint, checkpoint_path_save)
            print(f"  Saved checkpoint at epoch {epoch + 1}")

    writer.close()
    print(f"\nTensorBoard logs saved to: {log_dir}")

def _viz_load_pointcloud(pointmap_path, ep_id, stride=9):
    """Load a robot-frame pointmap file as a downsampled point cloud.

    Rows with NaN/inf coordinates are dropped, colors are converted to uint8,
    and every ``stride``-th point is kept. A missing file, an empty cloud, or a
    loading error (reported on stdout) yields an empty cloud.
    """
    empty = {'points': np.zeros((0, 3)), 'colors': np.zeros((0, 3))}
    if not pointmap_path.exists():
        return empty

    try:
        pointmap = torch.load(pointmap_path)
        points_robot = pointmap["points"].numpy() if isinstance(pointmap["points"], torch.Tensor) else pointmap["points"]
        colors = pointmap["colors"].numpy() if isinstance(pointmap["colors"], torch.Tensor) else pointmap["colors"]

        finite_rows = np.isfinite(points_robot).all(axis=1)
        points_robot = points_robot[finite_rows]
        colors = colors[finite_rows]
        if len(points_robot) == 0:
            return empty

        # Ensure colors are uint8 [0-255]
        if colors.dtype != np.uint8:
            if colors.max() <= 1.0:
                colors = (colors * 255).astype(np.uint8)
            else:
                colors = colors.astype(np.uint8)

        return {'points': points_robot[::stride], 'colors': colors[::stride]}
    except Exception as e:
        print(f"  ⚠ Error loading pointcloud for {ep_id}: {e}")
        return empty


def visualize_predictions(model, data_loader, device, processed_dir, output_dir):
    """Serve an interactive Viser scene comparing predicted and ground-truth grasp positions.

    Runs ``model`` over ``data_loader``, then starts a Viser server on port
    8080 with a slider selecting the sample. For the selected sample the scene
    shows the ``pointmap_start.pt`` point cloud, the gripper mesh at the
    ground-truth grasp pose, the gripper mesh at the predicted position (with
    the ground-truth rotation), and a text panel with both positions and their
    Euclidean distance. Blocks until interrupted with Ctrl+C.

    Args:
        model: trained predictor, already on ``device``.
        data_loader: loader yielding ``(images, positions_gt, episode_ids)``.
        device: torch device used for inference.
        processed_dir: dataset root containing one directory per episode id.
        output_dir: created if missing; nothing is written to it here.
    """
    import time
    import trimesh

    model.eval()
    os.makedirs(output_dir, exist_ok=True)

    # NOTE(review): this MuJoCo model is built but not used anywhere below in
    # this function; kept so existing behavior (including XML validation) is unchanged.
    robot_config = SO100AdhesiveConfig()
    combined_xml = combine_xmls(robot_config.xml, ALIGNMENT_BOARD_CONFIG.get_xml_addition())
    mj_model = mujoco.MjModel.from_xml_string(combined_xml)

    all_images = []
    all_positions_pred = []
    all_positions_gt = []
    all_ep_ids = []

    print("Running predictions on validation set...")
    with torch.no_grad():
        for images, positions_gt, ep_ids in tqdm(data_loader, desc="Predicting"):
            images = images.to(device)
            positions_pred = model(images)

            all_images.append(images.cpu())
            all_positions_pred.append(positions_pred.cpu())
            all_positions_gt.append(positions_gt)
            all_ep_ids.extend(ep_ids)

    if not all_images:
        # torch.cat on an empty list would raise an unhelpful error.
        print("No samples to visualize.")
        return

    images_np = torch.cat(all_images, dim=0).numpy()
    positions_pred_np = torch.cat(all_positions_pred, dim=0).numpy()
    positions_gt_np = torch.cat(all_positions_gt, dim=0).numpy()
    num_samples = len(images_np)

    server = viser.ViserServer(port=8080)

    # Load gripper mesh for visualization
    fixed_gripper_stl_path = "robot_models/so100_model/assets/Fixed_Jaw.stl"
    fixed_gripper_mesh = None
    if os.path.exists(fixed_gripper_stl_path):
        fixed_gripper_mesh = trimesh.load(fixed_gripper_stl_path)
        if isinstance(fixed_gripper_mesh, trimesh.Scene):
            fixed_gripper_mesh = list(fixed_gripper_mesh.geometry.values())[0]
        # A mesh larger than 1 unit is assumed to be in mm and scaled to meters.
        bounds = fixed_gripper_mesh.bounds
        if np.max(bounds[1] - bounds[0]) > 1.0:
            fixed_gripper_mesh.apply_scale(0.001)

    # One point cloud per sample, in the same order as the predictions.
    processed_dir_path = Path(processed_dir)
    pointclouds = [
        _viz_load_pointcloud(processed_dir_path / ep_id / "pointmap_start.pt", ep_id)
        for ep_id in all_ep_ids[:num_samples]
    ]

    def show_gripper(name, pose):
        """Show the gripper mesh at the 4x4 ``pose`` under scene node ``name``."""
        quat_xyzw = R.from_matrix(pose[:3, :3]).as_quat()  # (x, y, z, w)
        try:
            server.scene.add_mesh_trimesh(
                name=name,
                mesh=fixed_gripper_mesh,
                wxyz=quat_xyzw[[3, 0, 1, 2]],  # (w, x, y, z)
                position=pose[:3, 3].astype(np.float32),
            )
            return
        except Exception as e:
            print(f"  ⚠ add_mesh_trimesh failed for {name} ({e}); trying add_mesh_simple")

        # Fallback: bake the pose into the vertices.
        vertices = fixed_gripper_mesh.vertices
        vertices_homogeneous = np.hstack([vertices, np.ones((vertices.shape[0], 1))])
        transformed = (pose @ vertices_homogeneous.T).T[:, :3]
        try:
            server.scene.add_mesh_simple(
                name=name,
                vertices=transformed.astype(np.float32),
                faces=fixed_gripper_mesh.faces.astype(np.int32),
            )
        except Exception as e:
            # Best effort: keep the viewer alive, but say what went wrong.
            print(f"  ⚠ Could not display {name}: {e}")

    current_idx = [0]
    info_handle = [None]  # GUI text element, created on first update

    def update_visualization():
        idx = current_idx[0]
        if idx >= num_samples:
            return

        position_pred = positions_pred_np[idx]  # (3,)
        position_gt = positions_gt_np[idx]  # (3,)

        if idx < len(pointclouds):
            pc = pointclouds[idx]
            server.scene.add_point_cloud(
                name="/pointcloud",
                points=pc['points'].astype(np.float32),
                colors=pc['colors'].astype(np.uint8),
                point_size=0.002,
            )

        # The GT pose file supplies the rotation for both meshes.
        gripper_pose_path = processed_dir_path / all_ep_ids[idx] / "gripper_pose_grasp.npy"
        if gripper_pose_path.exists() and fixed_gripper_mesh is not None:
            gripper_pose_gt = np.load(gripper_pose_path)  # (4, 4)
            show_gripper("/gripper_gt", gripper_pose_gt)

            gripper_pose_pred = np.eye(4)
            gripper_pose_pred[:3, :3] = gripper_pose_gt[:3, :3]  # GT rotation
            gripper_pose_pred[:3, 3] = position_pred  # predicted position
            show_gripper("/gripper_pred", gripper_pose_pred)

        error = np.linalg.norm(position_pred - position_gt)
        info_text = "".join([
            f"Sample {idx+1}/{num_samples}\n",
            f"Sequence: {all_ep_ids[idx]}\n\n",
            "Predicted Position:\n",
            f"  X: {position_pred[0]:.4f}\n",
            f"  Y: {position_pred[1]:.4f}\n",
            f"  Z: {position_pred[2]:.4f}\n",
            "\nGT Position:\n",
            f"  X: {position_gt[0]:.4f}\n",
            f"  Y: {position_gt[1]:.4f}\n",
            f"  Z: {position_gt[2]:.4f}\n",
            f"\nError: {error:.4f} m",
        ])

        if info_handle[0] is None:
            info_handle[0] = server.gui.add_text("info", initial_value=info_text)
        else:
            info_handle[0].value = info_text

    sample_slider = server.gui.add_slider(
        "sample",
        0,
        num_samples - 1,
        initial_value=0,
        step=1
    )

    @sample_slider.on_update
    def _(_):
        current_idx[0] = int(sample_slider.value)
        update_visualization()

    # Initial visualization
    update_visualization()

    print(f"\nViser server running at http://localhost:8080")
    print(f"Press Ctrl+C to exit")

    # Keep the process alive so the server keeps serving.
    try:
        while True:
            time.sleep(0.1)
    except KeyboardInterrupt:
        pass

if __name__ == "__main__":
    # Script entry point: parse options, preload the dataset, then either
    # train the model or visualize predictions of a previously saved model.
    # NOTE: `args` intentionally stays a module-level name under this guard.
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, choices=["train", "test", "test_with_train"], default="train")
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--epochs", type=int, default=1000)
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--run_name", type=str, default="gripper_translation_predictor")
    # The flag was declared with action="store_false" and then negated, so in
    # effect passing it ENABLED the extra channels and the default was RGB
    # only. That runtime behavior is kept (existing checkpoints stay loadable);
    # the declaration and help text now say what actually happens.
    parser.add_argument("--use_dino_pointmap", action="store_true",
                        help="Add DINO features and pointmap channels to the RGB input (default: RGB only)")
    parser.add_argument("--resnet_type", type=str, default="resnet18", choices=["resnet18", "resnet34", "resnet50"])
    args = parser.parse_args()

    print("Loading dataset...")
    processed_dir = Path("scratch/processed_grasp_dataset_keyboard")
    if not processed_dir.is_dir():
        print(f"Error: Dataset directory not found at {processed_dir}.")
        sys.exit(1)

    # Find all sequences
    sequences = sorted(d for d in processed_dir.iterdir() if d.is_dir())
    print(f"Found {len(sequences)} sequences")
    if not sequences:
        print(f"Error: No sequences found in {processed_dir}.")
        sys.exit(1)

    # Split train/val: the last 5 sequences validate (only the last one when
    # there are 5 or fewer sequences in total).
    num_val = 5 if len(sequences) > 5 else 1
    train_sequences = sequences[:-num_val]
    val_sequences = sequences[-num_val:]

    # RGB(3) + DINO(32) + XYZ(3) = 38 channels when the extra inputs are enabled
    use_dino_pointmap = args.use_dino_pointmap
    input_channels = 38 if use_dino_pointmap else 3

    # Preload all data
    print("\nPreloading training data...")
    train_images, train_positions, train_ep_ids, train_dino_features, train_pointmaps = preload_dataset(
        train_sequences, processed_dir, use_dino_pointmap=use_dino_pointmap, output_res=OUTPUT_RES
    )
    print("\nPreloading validation data...")
    val_images, val_positions, val_ep_ids, val_dino_features, val_pointmaps = preload_dataset(
        val_sequences, processed_dir, use_dino_pointmap=use_dino_pointmap, output_res=OUTPUT_RES
    )

    train_dataset = GripperPositionDataset(train_images, train_positions, train_ep_ids)
    val_dataset = GripperPositionDataset(val_images, val_positions, val_ep_ids)
    print(f"Train: {len(train_dataset)} samples, Val: {len(val_dataset)} samples")

    # Fail with a clear message instead of a sampler/division error later on.
    needs_train = args.mode in ("train", "test_with_train")
    needs_val = args.mode in ("train", "test")
    if (needs_train and len(train_dataset) == 0) or (needs_val and len(val_dataset) == 0):
        print(f"Error: Not enough usable sequences for mode '{args.mode}'.")
        sys.exit(1)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)

    # Create model
    model = GripperPositionPredictor(
        input_channels=input_channels,
        output_dim=OUTPUT_DIM,
        resnet_type=args.resnet_type
    )

    model_path = os.path.join("scratch/policies", f"{args.run_name}.pt")
    log_dir = os.path.join("runs", args.run_name)

    if args.mode == "train":
        checkpoint_path = args.checkpoint

        train_model(model, train_loader, val_loader, lr=args.lr, epochs=args.epochs,
                   log_dir=log_dir, checkpoint_path=checkpoint_path,
                   train_dino_features=train_dino_features, train_pointmaps=train_pointmaps,
                   val_dino_features=val_dino_features, val_pointmaps=val_pointmaps)
        print(f"\nTraining complete. Model saved to {model_path}")
    else:
        if not os.path.exists(model_path):
            print(f"Error: Model not found at {model_path}. Train first.")
            sys.exit(1)
        device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")

        # Load model
        model.load_state_dict(torch.load(model_path, map_location=device))
        model = model.to(device).eval()
        print(f"Loaded model from {model_path}")

        # Visualize predictions
        output_dir = os.path.join("scratch/pred", args.run_name)

        if args.mode == "test":
            # Use validation dataset
            visualize_predictions(model, val_loader, device, processed_dir, output_dir)
        elif args.mode == "test_with_train":
            # Use training dataset
            visualize_predictions(model, train_loader, device, processed_dir, output_dir)

