"""Train a CNN model to predict 3D offset field for gripper center from DINO features, RGB, and pointmap XYZ."""
import argparse
import os
import sys
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/vggt")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/MoGe")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms.functional as TF
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import viser
import trimesh
from pathlib import Path

# ---------------------------------------------------------------------------
# Module-level configuration
# ---------------------------------------------------------------------------

# Square side length (pixels) that every per-sequence tensor is resized to.
OUTPUT_RES = 224

# The network regresses one 3D offset (x, y, z) per pixel.
OUTPUT_CHANNELS = 3

# Channel count of each input modality.
RGB_DIM = 3
DINO_DIM = 32
XYZ_DIM = 3

# Width of the concatenated network input: 32 + 3 + 3 = 38 channels.
INPUT_CHANNELS = sum((DINO_DIM, RGB_DIM, XYZ_DIM))

# Upper bound on the number of training sequences used in "train" mode.
N_episodes = 5

def _as_numpy(value):
    """Return ``value`` as a NumPy array (torch tensors are detached and moved to CPU first)."""
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def _resize_hwc_to_chw(array_hwc, output_res):
    """Bilinearly resize an (H, W, C) array to a float (C, output_res, output_res) tensor."""
    batched = torch.from_numpy(array_hwc).permute(2, 0, 1).float().unsqueeze(0)
    resized = F.interpolate(
        batched,
        size=(output_res, output_res),
        mode='bilinear',
        align_corners=False,
    )
    return resized.squeeze(0)


def _load_dino_features(seq_dir, image_hw, output_res):
    """Load ``dino_features_hw.pt`` for one sequence.

    Returns ``(features, preview)`` where ``features`` is a
    (32, output_res, output_res) tensor (all zeros when the file is missing)
    and ``preview`` is the first three feature channels at native resolution
    (``None`` when the file is missing).
    """
    dino_path = seq_dir / "dino_features_hw.pt"
    if not dino_path.exists():
        return torch.zeros(DINO_DIM, output_res, output_res), None

    # NOTE: torch.load unpickles the file; only load data you trust.
    features = _as_numpy(torch.load(dino_path))  # (H, W, 32) or (H*W, 32)
    if features.ndim == 2:
        count = features.shape[0]
        img_h, img_w = image_hw
        if img_h * img_w == count:
            # Pixel-aligned features: the image grid is the only layout that
            # cannot scramble rows for non-square images.
            grid_h, grid_w = img_h, img_w
        else:
            # Features live on their own grid; guess a (near-)square layout.
            grid_h = int(np.sqrt(count))
            grid_w = count // grid_h if grid_h else 0
            if grid_h * grid_w != count:
                raise ValueError(
                    f"Cannot reshape {count} flat DINO features of {seq_dir.name} "
                    f"into a 2D grid (image is {img_h}x{img_w})"
                )
        features = features.reshape(grid_h, grid_w, DINO_DIM)

    resized = _resize_hwc_to_chw(features, output_res)  # (32, R, R)
    # Copy the 3-channel preview so the full 32-channel array can be freed.
    preview = features[:, :, :3].copy()
    return resized, preview


def _load_pointmap(seq_dir, image_hw, output_res):
    """Load ``pointmap_start_raw.pt`` for one sequence.

    Returns ``(xyz, points_2d, preview)``. ``xyz`` is a (3, output_res,
    output_res) tensor. When the file is missing or its point count does not
    match the image grid, ``xyz`` is all zeros and the other two are ``None``.
    Otherwise ``points_2d`` is the (H, W, 3) point grid at image resolution and
    ``preview`` is ``{'points': (H, W, 3), 'colors': (H, W, 3)}``.
    """
    blank = torch.zeros(XYZ_DIM, output_res, output_res)
    pointmap_path = seq_dir / "pointmap_start_raw.pt"
    if not pointmap_path.exists():
        return blank, None, None

    # NOTE: torch.load unpickles the file; only load data you trust.
    pointmap = torch.load(pointmap_path)
    # Flatten first so both (N, 3) and (H, W, 3) layouts are accepted.
    points = _as_numpy(pointmap["points"]).reshape(-1, 3)
    colors = _as_numpy(pointmap["colors"])

    img_h, img_w = image_hw
    if img_h * img_w != len(points):
        return blank, None, None

    points_2d = points.reshape(img_h, img_w, 3)
    colors_2d = colors.reshape(img_h, img_w, 3)
    xyz = _resize_hwc_to_chw(points_2d, output_res)  # (3, R, R)
    return xyz, points_2d, {'points': points_2d, 'colors': colors_2d}


def preload_dataset(sequence_dirs, processed_dir, use_dino_pointmap=True, output_res=224):
    """Preload every usable sequence of the processed keyboard grasp dataset into memory.

    For each sequence directory the function reads ``start_kprender.png`` and
    ``gripper_pose_grasp.npy`` (both required; the sequence is skipped with a
    message if either is missing or the image cannot be decoded) and, when
    ``use_dino_pointmap`` is true, the optional ``dino_features_hw.pt`` and
    ``pointmap_start_raw.pt``.

    Args:
        sequence_dirs: Sized iterable of sequence directories (``str`` or ``Path``).
        processed_dir: Unused; kept for interface compatibility.
        use_dino_pointmap: If true, the input tensor is RGB(3) + DINO(32) +
            XYZ(3) = 38 channels (in that order); otherwise it is RGB only.
        output_res: Side length every tensor is resized to.

    Returns:
        A 5-tuple of parallel lists, one entry per *kept* sequence:
        ``images_list`` ((C, R, R) tensors), ``offset_fields_list`` ((R, R, 3)
        tensors holding ``gripper_center - point``; all zeros when no usable
        pointmap exists), ``episode_ids_list`` (directory names),
        ``dino_features_list`` (3-channel previews or ``None``) and
        ``pointmap_list`` (``{'points', 'colors'}`` dicts or ``None``).

    Raises:
        ValueError: If flat DINO features cannot be arranged on a 2D grid.
    """
    print(f"Preloading {len(sequence_dirs)} sequences...")

    images_list = []
    offset_fields_list = []  # (R, R, 3) offset fields towards the gripper center
    episode_ids_list = []
    dino_features_list = []  # visualization only
    pointmap_list = []  # visualization only

    for seq_dir in tqdm(sequence_dirs, desc="Loading sequences"):
        seq_dir = Path(seq_dir)
        sequence_id = seq_dir.name

        # Validate all required inputs *before* touching any output list, so a
        # skipped sequence can never leave the five lists misaligned.
        start_kprender_path = seq_dir / "start_kprender.png"
        if not start_kprender_path.exists():
            print(f"  ⚠ Skipping {sequence_id}: start_kprender.png not found")
            continue

        gripper_pose_path = seq_dir / "gripper_pose_grasp.npy"
        if not gripper_pose_path.exists():
            print(f"  ⚠ Skipping {sequence_id}: gripper_pose_grasp.npy not found")
            continue

        bgr = cv2.imread(str(start_kprender_path))
        if bgr is None:  # cv2.imread signals decode failures by returning None
            print(f"  ⚠ Skipping {sequence_id}: start_kprender.png could not be read")
            continue

        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        if rgb.max() <= 1.0:
            # Kept from the original: rescale images whose values are in [0, 1].
            rgb = (rgb * 255).astype(np.uint8)

        image_hw = rgb.shape[:2]
        rgb_resized = cv2.resize(rgb, (output_res, output_res), interpolation=cv2.INTER_LINEAR)
        input_tensor = torch.from_numpy(rgb_resized).permute(2, 0, 1).float() / 255.0  # (3, R, R)

        points_2d = None
        dino_preview = None
        pointmap_preview = None
        if use_dino_pointmap:
            dino_resized, dino_preview = _load_dino_features(seq_dir, image_hw, output_res)
            points_resized, points_2d, pointmap_preview = _load_pointmap(seq_dir, image_hw, output_res)
            # Channel order: RGB(3) + DINO(32) + XYZ(3) = 38.
            input_tensor = torch.cat([input_tensor, dino_resized, points_resized], dim=0)

        gripper_pose = np.load(gripper_pose_path)  # (4, 4)
        gripper_center = gripper_pose[:3, 3]  # translation part of the pose

        if points_2d is not None:
            # offset[i, j] = gripper_center - point[i, j], computed at image
            # resolution and then downsampled.
            offset_full = gripper_center.reshape(1, 1, 3) - points_2d
            offset_field = _resize_hwc_to_chw(offset_full, output_res).permute(1, 2, 0)  # (R, R, 3)
        else:
            # No usable pointmap: fall back to an all-zero target.
            offset_field = torch.zeros(output_res, output_res, 3)

        images_list.append(input_tensor)
        offset_fields_list.append(offset_field)
        episode_ids_list.append(sequence_id)
        dino_features_list.append(dino_preview)
        pointmap_list.append(pointmap_preview)

    return images_list, offset_fields_list, episode_ids_list, dino_features_list, pointmap_list

class GripperOffsetDataset(Dataset):
    """Map-style dataset backed by three parallel, already-loaded lists.

    Item ``i`` is the tuple ``(images_list[i], offset_fields_list[i],
    episode_ids_list[i])``; the dataset length is the length of ``images_list``.
    """

    def __init__(self, images_list, offset_fields_list, episode_ids_list):
        # The lists are stored by reference, not copied.
        self.images_list, self.offset_fields_list, self.episode_ids_list = (
            images_list,
            offset_fields_list,
            episode_ids_list,
        )

    def __len__(self):
        sample_count = len(self.images_list)
        return sample_count

    def __getitem__(self, idx):
        image = self.images_list[idx]
        offset_field = self.offset_fields_list[idx]
        episode_id = self.episode_ids_list[idx]
        return image, offset_field, episode_id

class GripperOffsetPredictor(nn.Module):
    """Fully convolutional network mapping stacked per-pixel features to a 3D offset field.

    Every convolution is stride 1 with "same" padding, so the output keeps the
    spatial size of the input: (B, input_channels, H, W) -> (B, output_channels, H, W).
    """

    def __init__(self, input_channels=38, output_channels=3, output_res=224):
        super().__init__()
        self.output_res = output_res
        self.output_channels = output_channels

        def conv_bn_relu(c_in, c_out, kernel):
            # Conv -> BatchNorm -> ReLU, padded so the spatial size is preserved.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=kernel, padding=kernel // 2),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Encoder: three conv blocks (5x5 then two 3x3) widening to 128 channels.
        encoder_layers = []
        for c_in, c_out, kernel in ((input_channels, 64, 5), (64, 128, 3), (128, 128, 3)):
            encoder_layers.extend(conv_bn_relu(c_in, c_out, kernel))
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder: one conv block plus a 1x1 projection. The head has no
        # activation because offsets may be negative or positive.
        decoder_layers = conv_bn_relu(128, 64, 3)
        decoder_layers.append(nn.Conv2d(64, output_channels, kernel_size=1))
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return the (B, output_channels, H, W) offset field for input ``x`` of shape (B, C, H, W)."""
        raw_offsets = self.decoder(self.encoder(x))
        # Outputs are divided by 10 to keep them small for training stability.
        return raw_offsets / 10.0

def _inverse_distance_grid(offset_fields, eps=1e-6):
    """Render (B, H, W, 3) offset fields as an image grid of min-max normalised 1/distance maps."""
    inv_dist = 1.0 / (torch.norm(offset_fields, dim=3) + eps)  # (B, H, W)
    lowest, highest = inv_dist.min(), inv_dist.max()
    if highest > lowest:
        inv_dist = (inv_dist - lowest) / (highest - lowest)
    else:
        inv_dist = torch.zeros_like(inv_dist)
    # Replicate the single channel so make_grid receives (B, 3, H, W).
    frames = inv_dist.unsqueeze(1).repeat(1, 3, 1, 1)
    return vutils.make_grid(frames, nrow=2, normalize=False, pad_value=1.0)


def _dino_preview_frame(dino_pca):
    """Turn a 3-channel DINO preview into a per-channel min-max normalised (3, R, R) tensor."""
    if isinstance(dino_pca, torch.Tensor):
        dino_pca = dino_pca.numpy()
    if len(dino_pca.shape) == 2:
        # Flat (N, 3) input: try a near-square layout; if that fails the
        # indexing below raises IndexError and the caller shows a blank frame.
        count = dino_pca.shape[0]
        grid_h = int(np.sqrt(count))
        grid_w = count // grid_h
        if grid_h * grid_w == count:
            dino_pca = dino_pca.reshape(grid_h, grid_w, 3)

    normalised = dino_pca.copy()
    for channel in range(3):
        values = normalised[:, :, channel]
        lowest, highest = values.min(), values.max()
        if highest > lowest:
            normalised[:, :, channel] = (values - lowest) / (highest - lowest)
        else:
            normalised[:, :, channel] = 0.5  # constant channel -> mid grey

    resized = cv2.resize(normalised, (OUTPUT_RES, OUTPUT_RES), interpolation=cv2.INTER_LINEAR)
    return torch.from_numpy(resized).permute(2, 0, 1).float()


def _pointmap_preview_frame(pointmap_entry):
    """Turn a stored pointmap entry into a (3, R, R) tensor clipped to [0, 1]."""
    # NOTE(review): this reads 'points' (XYZ), not 'colors', although the
    # original variable was named "colors" - confirm which one is intended.
    values = pointmap_entry['points']
    if isinstance(values, torch.Tensor):
        values = values.numpy()
    resized = cv2.resize(values, (OUTPUT_RES, OUTPUT_RES), interpolation=cv2.INTER_LINEAR)
    clipped = np.clip(resized, 0.0, 1.0)
    return torch.from_numpy(clipped).permute(2, 0, 1).float()


def _preview_or_blank(entries, index, render):
    """Render ``entries[index]`` with ``render``; return a black frame when unavailable or unrenderable."""
    blank = torch.zeros(3, OUTPUT_RES, OUTPUT_RES)
    if index is None or index >= len(entries) or entries[index] is None:
        return blank
    try:
        return render(entries[index])
    except (ValueError, IndexError):
        return blank


def _log_train_visualizations(writer, model, train_loader, device, step,
                              train_dino_features, train_pointmaps):
    """Log image grids for up to four samples of the first training batch to TensorBoard."""
    sample_images, sample_offset_fields_gt, sample_ep_ids = next(iter(train_loader))
    sample_images = sample_images[:4].to(device)
    sample_offset_fields_gt = sample_offset_fields_gt[:4].cpu()  # (<=4, H, W, 3)

    with torch.no_grad():
        sample_offset_fields_pred = model(sample_images).permute(0, 2, 3, 1).cpu()  # (<=4, H, W, 3)

    # The grids are built from however many samples the batch actually has,
    # so batches smaller than four no longer raise IndexError.
    writer.add_image('Train/InvDist_GT', _inverse_distance_grid(sample_offset_fields_gt), step)
    writer.add_image('Train/InvDist_Pred', _inverse_distance_grid(sample_offset_fields_pred), step)

    # The first three input channels are the RGB image.
    images_rgb = sample_images[:, :3].clamp(0, 1).cpu()
    rgb_grid = vutils.make_grid(images_rgb, nrow=2, normalize=False, pad_value=1.0)
    writer.add_image('Train/Input_RGB_with_Keypoints', rgb_grid, step)

    if train_dino_features is None and train_pointmaps is None:
        return

    # Map each episode id to its first position in the dataset (same result
    # as list.index, but built once instead of once per sample).
    first_index = {}
    for position, episode_id in enumerate(train_loader.dataset.episode_ids_list):
        first_index.setdefault(episode_id, position)
    shown_ids = list(sample_ep_ids)[:4]

    if train_dino_features is not None:
        dino_frames = [
            _preview_or_blank(train_dino_features, first_index.get(ep_id), _dino_preview_frame)
            for ep_id in shown_ids
        ]
        if dino_frames:
            dino_grid = vutils.make_grid(torch.stack(dino_frames), nrow=2, normalize=False, pad_value=1.0)
            writer.add_image('Train/DINO_PCA_Features', dino_grid, step)

    if train_pointmaps is not None:
        pointmap_frames = [
            _preview_or_blank(train_pointmaps, first_index.get(ep_id), _pointmap_preview_frame)
            for ep_id in shown_ids
        ]
        if pointmap_frames:
            pointmap_grid = vutils.make_grid(torch.stack(pointmap_frames), nrow=2, normalize=True, pad_value=1.0)
            writer.add_image('Train/Pointmap_Visualization', pointmap_grid, step)


def train_model(model, train_loader, val_loader, lr=1e-3, epochs=1000, log_dir="runs/gripper_offset_predictor", checkpoint_path=None,
                train_dino_features=None, train_pointmaps=None, val_dino_features=None, val_pointmaps=None):
    """Train ``model`` to regress per-pixel offset fields, logging to TensorBoard.

    Training minimises the per-pixel L2 error weighted by a clipped inverse of
    the ground-truth offset magnitude (pixels near the gripper center count
    more), scaled by 1e3. Validation uses plain MSE, so the two logged losses
    are not directly comparable.

    Side effects: writes TensorBoard events to ``log_dir``; saves the
    best-validation weights to ``scratch/policies/3d_keypoint_predictor.pt``
    and a resumable checkpoint to
    ``scratch/policies/3d_keypoint_predictor_checkpoint.pt`` every 100 epochs.

    Args:
        model: Network mapping (B, C, H, W) inputs to (B, 3, H, W) offsets.
        train_loader: Loader yielding ``(images, offset_fields, episode_ids)``
            with offset fields shaped (B, H, W, 3). For the optional previews
            its dataset must expose ``episode_ids_list``.
        val_loader: Loader with the same batch layout, used for validation.
        lr: Adam learning rate.
        epochs: Total number of epochs (including those already completed when
            resuming from a checkpoint).
        log_dir: TensorBoard log directory (created if missing).
        checkpoint_path: Optional checkpoint to resume from; ignored if the
            file does not exist.
        train_dino_features: Optional per-sample DINO previews, indexed like
            the training dataset.
        train_pointmaps: Optional per-sample pointmap dicts, indexed like the
            training dataset.
        val_dino_features: Accepted but currently unused.
        val_pointmaps: Accepted but currently unused.

    Raises:
        ValueError: If either loader yields no batches.
    """
    # Fail fast instead of hitting ZeroDivisionError/StopIteration after an epoch of work.
    if len(train_loader) == 0:
        raise ValueError("train_loader is empty; nothing to train on")
    if len(val_loader) == 0:
        raise ValueError("val_loader is empty; cannot compute a validation loss")

    device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir)

    policy_dir = "scratch/policies"
    loss_weight = 1e3
    eps = 1e-6
    best_val_loss = float('inf')
    start_epoch = 0

    try:
        if checkpoint_path and os.path.exists(checkpoint_path):
            print(f"Loading checkpoint from {checkpoint_path}...")
            checkpoint = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint["model_state_dict"])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            start_epoch = checkpoint['epoch']
            best_val_loss = checkpoint['best_val_loss']
            print(f"  Resumed from epoch {start_epoch}, best val loss: {best_val_loss:.6f}")

        for epoch in range(start_epoch, epochs):
            print(f"Epoch {epoch+1}/{epochs}")

            # ---- Training ----
            model.train()
            train_loss = 0.0
            for images, offset_fields_gt, ep_ids in tqdm(train_loader, desc="Training"):
                images = images.to(device)  # (B, C, H, W)
                offset_fields_gt = offset_fields_gt.permute(0, 3, 1, 2).to(device)  # (B, 3, H, W)

                optimizer.zero_grad()
                offset_fields_pred = model(images)  # (B, 3, H, W)

                # Inverse-distance weight, clipped at 30 and rescaled to (0, 1].
                offset_magnitude_gt = torch.norm(offset_fields_gt, dim=1, keepdim=True)  # (B, 1, H, W)
                loss_mask = ((1.0 / (offset_magnitude_gt + eps)).clip(max=30.0)) / 30.0

                pixel_error = (offset_fields_pred - offset_fields_gt).norm(dim=1, keepdim=True)
                loss = (pixel_error * loss_mask * loss_weight).mean()

                loss.backward()
                optimizer.step()
                train_loss += loss.item()

            train_loss /= len(train_loader)

            # ---- Validation (unweighted MSE) ----
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for images, offset_fields_gt, ep_ids in val_loader:
                    images = images.to(device)
                    offset_fields_gt = offset_fields_gt.permute(0, 3, 1, 2).to(device)  # (B, 3, H, W)
                    offset_fields_pred = model(images)
                    val_loss += criterion(offset_fields_pred, offset_fields_gt).item()

            val_loss /= len(val_loader)

            writer.add_scalar('Loss/Train', train_loss, epoch + 1)
            writer.add_scalar('Loss/Val', val_loss, epoch + 1)

            # Image previews every 10 epochs; the model is still in eval mode here.
            if epoch % 10 == 0:
                _log_train_visualizations(
                    writer, model, train_loader, device, epoch + 1,
                    train_dino_features, train_pointmaps,
                )

            if (epoch + 1) % 50 == 0:
                print(f"Epoch {epoch+1}/{epochs}: Train Loss: {train_loss:.6f}, Val Loss: {val_loss:.6f}")

            # Save the best model seen so far.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                os.makedirs(policy_dir, exist_ok=True)
                torch.save(model.state_dict(), "scratch/policies/3d_keypoint_predictor.pt")
                if (epoch + 1) % 50 == 0:
                    print(f"  Saved best model (val_loss: {val_loss:.6f})")

            # Save a resumable checkpoint every 100 epochs.
            if (epoch + 1) % 100 == 0:
                checkpoint = {
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'best_val_loss': best_val_loss,
                }
                # The directory may not exist yet if no best model was ever saved.
                os.makedirs(policy_dir, exist_ok=True)
                checkpoint_path_save = os.path.join(policy_dir, "3d_keypoint_predictor_checkpoint.pt")
                torch.save(checkpoint, checkpoint_path_save)
                print(f"  Saved checkpoint at epoch {epoch + 1}")
    finally:
        # Flush and close the event file even if training is interrupted.
        writer.close()

    print(f"\nTensorBoard logs saved to: {log_dir}")

def predict_and_save(model, data_loader, device, output_dir, split_name="test"):
    """Run ``model`` over ``data_loader`` and save predictions, ground truth and ids.

    Files written to ``output_dir`` (created if missing):
    ``{split_name}_offset_fields_pred.pt``, ``{split_name}_offset_fields_gt.pt``
    and ``{split_name}_sequence_ids.txt`` (one id per line).

    Args:
        model: Network mapping (B, C, H, W) inputs to (B, 3, H, W) offsets;
            expected to already live on ``device``.
        data_loader: Loader yielding ``(images, offset_fields, episode_ids)``
            with offset fields shaped (B, H, W, 3).
        device: Device the inputs are moved to before the forward pass.
        output_dir: Directory receiving the result files.
        split_name: Prefix used in file names and log messages.

    Returns:
        ``(predictions, ground_truth, sequence_ids)`` where both tensors are
        (N, H, W, 3) on the CPU and ``sequence_ids`` is a list of length N.

    Raises:
        RuntimeError: If ``data_loader`` yields no batches.
    """
    model.eval()
    os.makedirs(output_dir, exist_ok=True)

    pred_batches = []
    gt_batches = []
    all_seq_ids = []

    print(f"\nPredicting offset fields for {split_name} set...")
    with torch.no_grad():
        for images, offset_fields_gt, ep_ids in tqdm(data_loader, desc=f"Predicting {split_name}"):
            offset_fields_pred = model(images.to(device))  # (B, 3, H, W)

            # Predictions go back to channels-last on the CPU. The ground truth
            # is only stored, so it never needs to visit the device.
            pred_batches.append(offset_fields_pred.permute(0, 2, 3, 1).cpu())  # (B, H, W, 3)
            gt_batches.append(offset_fields_gt.cpu())  # (B, H, W, 3)
            all_seq_ids.extend(ep_ids)

    if not pred_batches:
        raise RuntimeError(f"No samples to predict for split '{split_name}': data_loader is empty")

    all_predictions = torch.cat(pred_batches, dim=0)  # (N, H, W, 3)
    all_gts = torch.cat(gt_batches, dim=0)  # (N, H, W, 3)

    # Save raw tensors
    torch.save(all_predictions, os.path.join(output_dir, f"{split_name}_offset_fields_pred.pt"))
    torch.save(all_gts, os.path.join(output_dir, f"{split_name}_offset_fields_gt.pt"))
    print(f"  ✓ Saved raw tensors: {all_predictions.shape}")

    # Save sequence IDs, one per line
    ids_path = os.path.join(output_dir, f"{split_name}_sequence_ids.txt")
    with open(ids_path, 'w', encoding="utf-8") as f:
        f.writelines(f"{seq_id}\n" for seq_id in all_seq_ids)

    # Mean absolute difference between predicted and GT offset *magnitudes*
    # (direction errors are not captured by this number).
    pred_magnitude = torch.norm(all_predictions, dim=3)  # (N, H, W)
    gt_magnitude = torch.norm(all_gts, dim=3)  # (N, H, W)
    errors = torch.abs(pred_magnitude - gt_magnitude).mean().item()
    print(f"  ✓ Mean offset magnitude error: {errors:.4f} m")

    print(f"  ✓ Saved all results to {output_dir}")
    return all_predictions, all_gts, all_seq_ids

def visualize_predictions_viser(predictions, gt_offset_fields, seq_ids, processed_dir, output_res=224):
    """Serve an interactive Viser scene comparing predicted and GT gripper centers.

    For the selected sample the scene shows the sequence's point cloud plus one
    marker mesh at the predicted gripper center and one at the ground-truth
    center. The predicted center is the inverse-distance-weighted average of
    ``point + predicted_offset`` over all pixels. This function blocks until
    interrupted with Ctrl+C.

    Args:
        predictions: Indexable of (output_res, output_res, 3) CPU tensors.
        gt_offset_fields: Indexable of ground-truth offset fields, same layout.
        seq_ids: Sequence directory names, parallel to ``predictions``.
        processed_dir: Root directory containing one sub-directory per sequence.
        output_res: Resolution of the offset fields.
    """
    import time  # only needed for the keep-alive loop at the end

    try:
        from scipy.ndimage import gaussian_filter
    except ImportError:
        # Fallback if scipy is not available: skip the smoothing step.
        def gaussian_filter(x, sigma):
            return x

    def to_numpy(value):
        return value.numpy() if isinstance(value, torch.Tensor) else value

    print("\n" + "=" * 60)
    print("Launching Viser visualization")
    print("=" * 60)

    server = viser.ViserServer()

    # Optional marker mesh for the gripper centers; markers are skipped if the file is absent.
    ball_stl_path = "robot_models/so100_blender_testings/ball.stl"
    ball_mesh = None
    if os.path.exists(ball_stl_path):
        ball_mesh = trimesh.load(ball_stl_path)
        if isinstance(ball_mesh, trimesh.Scene):
            ball_mesh = list(ball_mesh.geometry.values())[0]
        bounds = ball_mesh.bounds
        max_extent = np.max(bounds[1] - bounds[0])
        if max_extent > 1.0:
            # Presumably the mesh is in millimetres; rescale to metres.
            ball_mesh.apply_scale(0.001)

    def update_sample(sample_idx):
        """Rebuild the scene for ``sample_idx``; out-of-range indices are ignored."""
        if not (0 <= sample_idx < len(seq_ids)):
            return

        seq_id = seq_ids[sample_idx]
        seq_dir = Path(processed_dir) / seq_id

        pointmap_path = seq_dir / "pointmap_start_raw.pt"
        if not pointmap_path.exists():
            print(f"  ⚠ Pointmap not found for {seq_id}")
            return

        # NOTE: torch.load unpickles the file; only load data you trust.
        pointmap = torch.load(pointmap_path)
        # Flatten so both (N, 3) and (H, W, 3) layouts are accepted.
        points_full = to_numpy(pointmap["points"]).reshape(-1, 3)
        colors_full = to_numpy(pointmap["colors"])

        # Image dimensions come from start.png when it is readable, otherwise
        # from a near-square guess based on the number of points.
        start_img_path = seq_dir / "start.png"
        start_img = cv2.imread(str(start_img_path)) if start_img_path.exists() else None
        if start_img is not None:
            H_orig, W_orig = start_img.shape[:2]
        else:
            H_orig = int(np.sqrt(len(points_full)))
            W_orig = len(points_full) // H_orig if H_orig else 0

        if len(points_full) == 0 or H_orig * W_orig != len(points_full):
            print(f"  ⚠ Cannot reshape points: {len(points_full)} != {H_orig * W_orig}")
            return
        points_2d = points_full.reshape(H_orig, W_orig, 3)

        offset_field_pred = predictions[sample_idx].numpy()  # (R, R, 3)
        offset_field_gt = gt_offset_fields[sample_idx].numpy()  # (R, R, 3)

        # Weight each pixel by its smoothed inverse predicted distance, so
        # pixels that claim to be close to the gripper dominate the estimate.
        eps = 1e-6
        pred_distance = np.linalg.norm(offset_field_pred, axis=2)  # (R, R)
        pred_inv_dist = 1.0 / (pred_distance + eps)
        pred_inv_dist_smooth = gaussian_filter(pred_inv_dist, sigma=2.0)
        weights = pred_inv_dist_smooth / (pred_inv_dist_smooth.sum() + eps)

        # Per-pixel estimate: point + offset, then the weighted average.
        points_2d_down = cv2.resize(points_2d, (output_res, output_res), interpolation=cv2.INTER_LINEAR)
        predicted_gripper_centers = points_2d_down + offset_field_pred  # (R, R, 3)
        predicted_gripper_center = np.sum(
            predicted_gripper_centers * weights.reshape(output_res, output_res, 1),
            axis=(0, 1),
        )

        # Prefer the GT center straight from the pose file; reconstruct it from
        # the GT offset field only when the pose file is missing.
        gripper_pose_path = seq_dir / "gripper_pose_grasp.npy"
        if gripper_pose_path.exists():
            gripper_pose = np.load(gripper_pose_path)  # (4, 4)
            gt_gripper_center = gripper_pose[:3, 3]
        else:
            gt_distance = np.linalg.norm(offset_field_gt, axis=2)
            gt_inv_dist = 1.0 / (gt_distance + eps)
            gt_inv_dist_smooth = gaussian_filter(gt_inv_dist, sigma=2.0)
            gt_weights = gt_inv_dist_smooth / (gt_inv_dist_smooth.sum() + eps)
            gt_gripper_centers = points_2d_down + offset_field_gt
            gt_gripper_center = np.sum(
                gt_gripper_centers * gt_weights.reshape(output_res, output_res, 1),
                axis=(0, 1),
            )

        points_flat = points_2d.reshape(-1, 3)
        colors_flat = colors_full.reshape(-1, 3)

        # Point-cloud colors must be uint8; rescale [0, 1] floats first.
        if colors_flat.max() <= 1.0:
            colors_flat = (colors_flat * 255).astype(np.uint8)
        else:
            colors_flat = colors_flat.astype(np.uint8)

        server.scene.add_point_cloud(
            name="/pointmap_original",
            points=points_flat.astype(np.float32),
            colors=colors_flat,
            point_size=0.002,
        )

        # Both markers reuse the same mesh; no per-marker color is applied
        # here, so they are told apart only by scene node name and position.
        if ball_mesh is not None:
            server.scene.add_mesh_trimesh(
                name="/gripper_center_pred",
                mesh=ball_mesh,
                wxyz=(1.0, 0.0, 0.0, 0.0),
                position=predicted_gripper_center.astype(np.float32),
            )
            server.scene.add_mesh_trimesh(
                name="/gripper_center_gt",
                mesh=ball_mesh,
                wxyz=(1.0, 0.0, 0.0, 0.0),
                position=gt_gripper_center.astype(np.float32),
            )

        error = np.linalg.norm(predicted_gripper_center - gt_gripper_center)
        print(f"\nSample {sample_idx+1}/{len(seq_ids)}: {seq_id}")
        print(f"  Predicted Center: [{predicted_gripper_center[0]:.4f}, {predicted_gripper_center[1]:.4f}, {predicted_gripper_center[2]:.4f}]")
        print(f"  GT Center: [{gt_gripper_center[0]:.4f}, {gt_gripper_center[1]:.4f}, {gt_gripper_center[2]:.4f}]")
        print(f"  Error: {error:.4f} m")

    # Slider for sample selection
    slider = server.gui.add_slider(
        "/sample_slider",
        min=0,
        max=max(0, len(seq_ids) - 1),
        step=1,
        initial_value=0,
    )

    @slider.on_update
    def on_slider_change(_):
        update_sample(int(slider.value))

    # Initialize with the first sample
    update_sample(0)

    print(f"\nViser server running at http://localhost:8080")
    print(f"Use the slider to navigate between {len(seq_ids)} samples")
    print(f"Press Ctrl+C to exit")
    print("=" * 60)

    # Keep the process alive so the server keeps serving until Ctrl+C.
    try:
        while True:
            time.sleep(0.1)
    except KeyboardInterrupt:
        pass

    print("\nDone!")

if __name__ == "__main__":
    # Script entry point.
    #
    #   --mode train            train on at most N_episodes training sequences
    #                           (validation is trimmed to a single sequence).
    #   --mode test             load the saved model, predict on train + val,
    #                           then show the validation predictions in Viser.
    #   --mode test_with_train  same as "test", but show the training
    #                           predictions in Viser instead.
    #
    # NOTE(review): this block intentionally stays at module scope (no main()
    # wrapper) so every name below remains a module global, as before; helpers
    # defined elsewhere in this file may rely on that -- not verified here.
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", type=str, choices=["train", "test", "test_with_train"], default="train")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--epochs", type=int, default=1000)
    parser.add_argument("--checkpoint", type=str, default=None)
    parser.add_argument("--run_name", type=str, default="3d_keypoint_predictor")
    args = parser.parse_args()

    model_path = os.path.join("scratch/policies", f"{args.run_name}.pt")
    log_dir = os.path.join("runs", args.run_name)

    # Fail fast: the test modes cannot do anything without a saved model, so
    # check for it before spending time preloading the datasets.
    if args.mode != "train" and not os.path.exists(model_path):
        print(f"Error: Model not found at {model_path}. Train first.")
        sys.exit(1)

    print("Loading dataset...")
    processed_dir = Path("scratch/processed_grasp_dataset_keyboard")
    if not processed_dir.is_dir():
        print(f"Error: Dataset directory not found at {processed_dir}.")
        sys.exit(1)

    # Find all sequences (one sub-directory per sequence, sorted by path).
    sequences = sorted([d for d in processed_dir.iterdir() if d.is_dir()])
    print(f"Found {len(sequences)} sequences")

    # With fewer than two sequences the training split below would be empty
    # and the shuffling DataLoader would fail later with an obscure error.
    if len(sequences) < 2:
        print(f"Error: Need at least 2 sequences in {processed_dir} for a train/val split, "
              f"found {len(sequences)}.")
        sys.exit(1)

    # Split train/val: hold out the last 5 sequences for validation, or just
    # the last one when there are 5 or fewer sequences in total.
    n_val = 5 if len(sequences) > 5 else 1
    train_sequences = sequences[:-n_val]
    val_sequences = sequences[-n_val:]

    if args.mode == "train":
        # Training only uses a small subset: the first N_episodes training
        # sequences and a single validation sequence.
        train_sequences = train_sequences[:N_episodes]
        val_sequences = val_sequences[:1]

    print(f"Train: {len(train_sequences)} sequences, Val: {len(val_sequences)} sequences")

    # Preload all data
    print("\nPreloading training data...")
    train_images, train_offset_fields, train_ep_ids, train_dino_features, train_pointmaps = preload_dataset(
        train_sequences, processed_dir, use_dino_pointmap=True, output_res=OUTPUT_RES
    )
    print("\nPreloading validation data...")
    val_images, val_offset_fields, val_ep_ids, val_dino_features, val_pointmaps = preload_dataset(
        val_sequences, processed_dir, use_dino_pointmap=True, output_res=OUTPUT_RES
    )

    train_dataset = GripperOffsetDataset(train_images, train_offset_fields, train_ep_ids)
    val_dataset = GripperOffsetDataset(val_images, val_offset_fields, val_ep_ids)

    # An empty training set makes DataLoader(shuffle=True) raise a ValueError;
    # report the actual problem instead.
    if len(train_dataset) == 0:
        print("Error: No training samples could be loaded from the selected sequences.")
        sys.exit(1)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0)

    print(f"Train: {len(train_dataset)} samples, Val: {len(val_dataset)} samples")

    # Create model
    model = GripperOffsetPredictor(
        input_channels=INPUT_CHANNELS,
        output_channels=OUTPUT_CHANNELS,
        output_res=OUTPUT_RES
    )

    if args.mode == "train":
        # Optional checkpoint path supplied on the command line; passed
        # through to train_model unchanged.
        checkpoint_path = args.checkpoint

        train_model(model, train_loader, val_loader, lr=args.lr, epochs=args.epochs,
                   log_dir=log_dir, checkpoint_path=checkpoint_path,
                   train_dino_features=train_dino_features, train_pointmaps=train_pointmaps,
                   val_dino_features=val_dino_features, val_pointmaps=val_pointmaps)
        print(f"\nTraining complete. Model saved to {model_path}")
    else:
        # Device preference: Apple MPS first, then CUDA, then CPU.
        if torch.backends.mps.is_available():
            device = torch.device("mps")
        elif torch.cuda.is_available():
            device = torch.device("cuda")
        else:
            device = torch.device("cpu")

        # Load model (handle both state_dict and checkpoint dict formats).
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # produced by this project.
        checkpoint = torch.load(model_path, map_location=device)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)

        model = model.to(device).eval()
        print(f"Loaded model from {model_path}")

        # Predict and save results for both train and val sets
        output_dir = os.path.join("scratch/pred", args.run_name)
        os.makedirs(output_dir, exist_ok=True)

        print("\n" + "=" * 60)
        print("Predicting and saving results")
        print("=" * 60)

        # Predict on train set
        train_pred, train_gt, train_seq_ids = predict_and_save(
            model, train_loader, device, output_dir, split_name="train"
        )

        # Predict on val set
        val_pred, val_gt, val_seq_ids = predict_and_save(
            model, val_loader, device, output_dir, split_name="val"
        )

        print(f"\n✓ Done! Results saved to {output_dir}")
        print(f"  Train: {len(train_seq_ids)} samples")
        print(f"  Val: {len(val_seq_ids)} samples")

        # Visualize predictions in Viser
        print("\n" + "=" * 60)
        print("Visualizing predictions in Viser")
        print("=" * 60)

        # "test_with_train" shows the training split, plain "test" shows the
        # validation split.
        if args.mode == "test_with_train":
            viz_pred, viz_gt, viz_seq_ids = train_pred, train_gt, train_seq_ids
        else:
            viz_pred, viz_gt, viz_seq_ids = val_pred, val_gt, val_seq_ids

        visualize_predictions_viser(
            viz_pred, viz_gt, viz_seq_ids, processed_dir, output_res=OUTPUT_RES
        )

        

