"""Batch process all sequences in grasp_dataset_keyboard directory."""
import sys
import os
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/vggt")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/MoGe")
sys.path.append("/Users/cameronsmith/Projects/robotics_testing/random/dinov3")
sys.path.append("Demos")
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import transforms
from torchvision.transforms import functional as TF
from PIL import Image
from tqdm import tqdm
from scipy.spatial.transform import Rotation as R
import argparse
import pickle

import mujoco

from moge.model.v2 import MoGeModel
import utils3d

from ExoConfigs.so100_adhesive import SO100AdhesiveConfig
from ExoConfigs.alignment_board import ALIGNMENT_BOARD_CONFIG
from exo_utils import (
    detect_and_set_link_poses,
    estimate_robot_state,
    position_exoskeleton_meshes,
    render_from_camera_pose,
    detect_and_position_alignment_board,
    combine_xmls,
    get_link_poses_from_robot,
)
from demo_utils import procrustes_alignment

# Configuration constants
# DINO input geometry: resize_transform scales images so the height becomes
# IMAGE_SIZE (a whole number of PATCH_SIZE-pixel patches).
PATCH_SIZE = 16
IMAGE_SIZE = 768

# Per-channel (R, G, B) normalization statistics applied before DINO inference.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# range(N_LAYERS) is requested from get_intermediate_layers; only the last
# returned layer is used downstream.
N_LAYERS = 12

# Five gripper keypoints in the gripper's local frame, one (x, y, z) point per
# row, in millimetres.  KEYPOINTS_LOCAL_M is the same array in metres.
KEYPOINTS_LOCAL_MM = np.array(
    [
        [13.25, -91.42, 15.9],
        [10.77, -99.6, 0],
        [13.25, -91.42, -15.9],
        [17.96, -83.96, 0],
        [22.86, -70.46, 0],
    ]
)
KEYPOINTS_LOCAL_M = KEYPOINTS_LOCAL_MM / 1000.0


def _load_rgb(path):
    """Read an image with OpenCV and return it as an RGB array.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path`` (``cv2.imread`` returns None).
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    # cv2.imread with default flags always yields 8-bit data, so no [0, 1] -> [0, 255]
    # rescaling is needed (such a rescale would corrupt near-black images).
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


def _gripper_pose_from_qpos(model, robot_config, qpos):
    """Return the 4x4 pose of the ``fixed_gripper_exo_mesh`` mocap body for joint state ``qpos``.

    A fresh ``MjData`` is used so the caller's simulation state is untouched.

    Raises:
        ValueError: if the body does not exist in ``model`` or is not a mocap body.
    """
    exo_mesh_body_name = "fixed_gripper_exo_mesh"
    data_grasp = mujoco.MjData(model)
    data_grasp.qpos[:] = qpos
    data_grasp.ctrl[:] = qpos[:len(data_grasp.ctrl)]
    mujoco.mj_forward(model, data_grasp)

    # Place the exoskeleton meshes according to the forward-kinematics link poses.
    link_poses_grasp = get_link_poses_from_robot(robot_config, model, data_grasp)
    position_exoskeleton_meshes(robot_config, model, data_grasp, link_poses_grasp)
    mujoco.mj_forward(model, data_grasp)

    body_id = mujoco.mj_name2id(model, mujoco.mjtObj.mjOBJ_BODY, exo_mesh_body_name)
    if body_id < 0:
        raise ValueError(f"Body '{exo_mesh_body_name}' not found in MuJoCo model")
    mocap_id = model.body_mocapid[body_id]
    if mocap_id < 0:
        raise ValueError(f"Body '{exo_mesh_body_name}' is not a mocap body")

    quat_wxyz = data_grasp.mocap_quat[mocap_id]
    pose = np.eye(4)
    # MuJoCo quaternions are scalar-first (w, x, y, z); scipy expects (x, y, z, w).
    pose[:3, :3] = R.from_quat(quat_wxyz[[1, 2, 3, 0]]).as_matrix()
    pose[:3, 3] = data_grasp.mocap_pos[mocap_id]
    return pose


def _compute_offset_fields(keypoints_robot, points_robot_full):
    """Return ``keypoint - point`` for every keypoint and pixel as a (K, H, W, 3) float32 array."""
    return (keypoints_robot[:, None, None, :] - points_robot_full[None]).astype(np.float32)


def process_episode(seq_id, input_dir, output_dir, model, robot_config, device, model_seg, model_moge, preprocess, model_dino, pca_embedder):
    """Process one grasp episode and write its derived artifacts to ``output_dir/seq_id``.

    Inputs read from ``input_dir``: ``{seq_id}_helper_scene.png`` / ``.npy`` (scene
    image and starting joint state) and ``{seq_id}_helper_grasp.png`` / ``.npy``
    (grasp image and grasp joint state).

    Outputs include:
    - the camera pose;
    - robot, human and depth-edge masks;
    - the MoGe pointmap in camera and robot frames;
    - the grasp gripper pose;
    - per-keypoint offset fields;
    - DINO features;
    - copies of the input images and the grasp joint state.

    Args:
        seq_id: Sequence identifier used as the filename prefix.
        input_dir: Directory containing the raw episode files.
        output_dir: Root output directory; a ``seq_id`` subdirectory is created.
        model: MuJoCo ``MjModel`` with the robot and alignment board.
        robot_config: Robot/exoskeleton config passed to the exo_utils helpers.
        device: Torch device for network inference.
        model_seg: Segmentation network used for the human mask.
        model_moge: MoGe model used for the pointmap.
        preprocess: Transform applied to the image before ``model_seg``.
        model_dino: DINOv3 backbone.
        pca_embedder: Fitted PCA used to reduce DINO feature dimensionality.

    Raises:
        FileNotFoundError: if either input image cannot be read.
    """
    seq_output_dir = os.path.join(output_dir, seq_id)
    os.makedirs(seq_output_dir, exist_ok=True)

    # --- Inputs -------------------------------------------------------------
    rgb_scene = _load_rgb(os.path.join(input_dir, f"{seq_id}_helper_scene.png"))
    rgb_grasp = _load_rgb(os.path.join(input_dir, f"{seq_id}_helper_grasp.png"))
    qpos_start = np.load(os.path.join(input_dir, f"{seq_id}_helper_scene.npy"))
    qpos_grasp = np.load(os.path.join(input_dir, f"{seq_id}_helper_grasp.npy"))
    qpos_grasp[-1] = 1.2  # Set gripper opening (last joint) to a fixed value.

    # --- Scene-state robot and camera calibration ----------------------------
    data = mujoco.MjData(model)
    data.qpos[:] = qpos_start
    data.ctrl[:] = qpos_start[:len(data.ctrl)]
    mujoco.mj_forward(model, data)

    # A single ArUco pass over the scene image provides both the camera
    # calibration and the link poses used to place the exoskeleton meshes.
    link_poses, camera_pose_world, cam_K, corners_cache, _, obj_img_pts = detect_and_set_link_poses(
        rgb_scene, model, data, robot_config
    )
    np.save(os.path.join(seq_output_dir, "robot_camera_pose.npy"), camera_pose_world)

    board_result = detect_and_position_alignment_board(
        rgb_scene, model, data, ALIGNMENT_BOARD_CONFIG, cam_K, camera_pose_world, corners_cache, visualize=False
    )
    if board_result is not None:
        _, board_pts = board_result
        obj_img_pts["alignment_board"] = board_pts

    position_exoskeleton_meshes(robot_config, model, data, link_poses)
    mujoco.mj_forward(model, data)

    # --- Masks ----------------------------------------------------------------
    robot_mask = get_robot_mask(model, data, camera_pose_world, cam_K, rgb_scene.shape[:2])
    human_mask = get_human_mask(rgb_scene, model_seg, preprocess, device)

    # --- MoGe pointmap (get_moge_pointcloud reuses a cached result if present) -
    if os.path.exists(os.path.join(seq_output_dir, "moge_output_raw.pt")):
        print("  ⏭ Skipping MoGe inference (moge_output_raw.pt exists)")
    points_cam, mask, moge_edge_mask, H, W = get_moge_pointcloud(rgb_scene, model_moge, device, seq_output_dir)

    # Valid (masked) points in the robot frame, plus the camera->robot transform.
    points_robot, valid_colors, valid_y_coords, valid_x_coords, T_procrustes = align_pointcloud_to_robot_frame(
        points_cam, mask, rgb_scene, robot_mask, human_mask, camera_pose_world, obj_img_pts, H, W
    )

    # Apply the same transform to every pixel (valid or not) so the robot-frame
    # pointmap and the offset fields stay dense H x W grids.
    points_cam_flat = points_cam.reshape(-1, 3)
    points_cam_h = np.hstack([points_cam_flat, np.ones((len(points_cam_flat), 1))])
    points_robot_full = (T_procrustes @ points_cam_h.T).T[:, :3].reshape(H, W, 3)

    moge_robot_output_path = os.path.join(seq_output_dir, "moge_output_robot_frame.pt")
    moge_robot_output = {
        "points": points_robot_full.astype(np.float32),  # (H, W, 3) float32
        "colors": rgb_scene,                             # (H, W, 3), pixel-aligned with "points"
        "mask": mask,                                    # (H, W) bool MoGe validity mask (edges removed)
        "moge_edge_mask": moge_edge_mask,
        "H": H,
        "W": W,
    }
    torch.save(moge_robot_output, moge_robot_output_path)
    print(f"  ✓ Saved robot-frame MoGE pointcloud: {moge_robot_output_path}")

    # --- Grasp gripper pose and keypoint offset fields -----------------------
    gripper_pose = _gripper_pose_from_qpos(model, robot_config, qpos_grasp)
    gripper_rot = gripper_pose[:3, :3]
    gripper_pos = gripper_pose[:3, 3]
    keypoints_robot = (gripper_rot @ KEYPOINTS_LOCAL_M.T).T + gripper_pos.reshape(1, 3)  # (K, 3)

    offset_fields = _compute_offset_fields(keypoints_robot, points_robot_full)  # (K, H, W, 3)
    torch.save(torch.from_numpy(offset_fields), os.path.join(seq_output_dir, "offset_fields.pt"))
    print(f"  ✓ Saved offset fields: {offset_fields.shape}")

    # --- DINO features (skipped when all three outputs already exist) ---------
    dino_features_path = os.path.join(seq_output_dir, "dino_features.pt")
    dino_features_hw_path = os.path.join(seq_output_dir, "dino_features_hw.pt")
    dino_vis_path = os.path.join(seq_output_dir, "dino_features_vis.png")
    if os.path.exists(dino_features_path) and os.path.exists(dino_features_hw_path) and os.path.exists(dino_vis_path):
        print(f"  ⏭ Skipping DINO feature extraction (dino_features.pt, dino_features_hw.pt and dino_features_vis.png exist)")
    else:
        dino_features, dino_features_hw, pca_features_patches = run_dino_features(
            rgb_scene, model_dino, pca_embedder, device, H, W, valid_y_coords, valid_x_coords
        )
        torch.save(torch.from_numpy(dino_features.astype(np.float32)), dino_features_path)
        torch.save(torch.from_numpy(dino_features_hw.astype(np.float32)), dino_features_hw_path)
        visualize_dino_features(rgb_scene, pca_features_patches, H, W, seq_output_dir)
        print(f"  ✓ Saved DINO features: {len(dino_features)} points, {dino_features_hw.shape} full resolution, {pca_embedder.n_components_} dimensions")

    # --- Remaining outputs ------------------------------------------------------
    print(f"  Saving outputs to {seq_output_dir}")

    plt.imsave(os.path.join(seq_output_dir, "start.png"), rgb_scene)
    plt.imsave(os.path.join(seq_output_dir, "grasp.png"), rgb_grasp)

    plt.imsave(os.path.join(seq_output_dir, "robot_mask.png"), robot_mask.astype(np.uint8) * 255, cmap='gray')
    plt.imsave(os.path.join(seq_output_dir, "human_mask.png"), human_mask.astype(np.uint8) * 255, cmap='gray')
    plt.imsave(os.path.join(seq_output_dir, "moge_edge_mask.png"), moge_edge_mask.astype(np.uint8) * 255, cmap='gray')

    np.save(os.path.join(seq_output_dir, "joint_states_grasp.npy"), qpos_grasp)
    np.save(os.path.join(seq_output_dir, "gripper_pose_grasp.npy"), gripper_pose)

    # Raw pointmap in the camera frame, flattened to (H*W, ...).
    pointmap_raw = {
        "points": torch.from_numpy(points_cam_flat.astype(np.float32)),
        "colors": torch.from_numpy(rgb_scene.reshape(-1, 3).astype(np.uint8)),
        "mask": torch.from_numpy(mask.reshape(-1).astype(bool)),
    }
    torch.save(pointmap_raw, os.path.join(seq_output_dir, "pointmap_start_raw.pt"))

    # Robot-frame pointmap restricted to valid (non-robot, non-human) pixels.
    pointmap_robot = {
        "points": torch.from_numpy(points_robot.astype(np.float32)),
        "colors": torch.from_numpy(valid_colors),
    }
    torch.save(pointmap_robot, os.path.join(seq_output_dir, "pointmap_start.pt"))

    print(f"  ✓ Saved: {len(points_cam_flat)} raw (camera frame), {len(points_robot)} robot-aligned points")


def get_robot_mask(model, data, camera_pose_world, cam_K, image_shape):
    """Build a robot mask by rendering MuJoCo segmentation from the calibrated camera.

    Args:
        model: MuJoCo model to render.
        data: MuJoCo data holding the current (already forwarded) state.
        camera_pose_world: Camera pose passed through to ``render_from_camera_pose``.
        cam_K: Camera intrinsics passed through to ``render_from_camera_pose``.
        image_shape: Image dimensions, unpacked positionally into
            ``render_from_camera_pose`` (the caller passes ``rgb.shape[:2]``).

    Returns:
        Boolean array that is True wherever channel 0 of the segmentation render
        is positive. Presumably channel 0 holds per-pixel object ids, with the
        background <= 0; confirm against ``render_from_camera_pose``.
    """
    seg_render = render_from_camera_pose(model, data, camera_pose_world, cam_K, *image_shape, segmentation=True)
    first_channel = seg_render[..., 0]
    return first_channel > 0


def get_human_mask(rgb_scene, model_seg, preprocess, device):
    """Segment people in an RGB image with a torchvision-style segmentation network.

    Args:
        rgb_scene: RGB image array accepted by ``PIL.Image.fromarray``.
        model_seg: Network whose output dict has an ``'out'`` logits tensor
            (DeepLabV3 in this script).
        preprocess: Transform turning the PIL image into a (C, H, W) tensor.
        device: Device the input batch is moved to.

    Returns:
        uint8 array at the resolution of the ``'out'`` map: 1 where the arg-max
        class is 15, else 0.
    """
    # Presumably 15 is "person" in the Pascal VOC label set of the loaded weights.
    person_class = 15
    batch = preprocess(Image.fromarray(rgb_scene)).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model_seg(batch)['out'][0]
    labels = logits.argmax(0).cpu().numpy()
    return (labels == person_class).astype(np.uint8)


def get_moge_pointcloud(rgb_scene, model_moge, device, seq_output_dir):
    """Run MoGe on ``rgb_scene`` (or load its cached output) and return the pointmap.

    The raw MoGe output dict is cached at ``seq_output_dir/moge_output_raw.pt``
    and reused on later calls.

    Args:
        rgb_scene: (H, W, 3) RGB image with values in [0, 255].
        model_moge: MoGe model exposing ``infer``.
        device: Device used for inference.
        seq_output_dir: Directory holding the cache file.

    Returns:
        Tuple ``(points, mask, moge_edge_mask, H, W)``:
        - ``points``: (H, W, 3) pointmap;
        - ``mask``: bool MoGe validity mask with depth edges removed;
        - ``moge_edge_mask``: depth-edge mask;
        - ``H``, ``W``: pointmap dimensions.
    """
    moge_output_path = os.path.join(seq_output_dir, "moge_output_raw.pt")

    if not os.path.exists(moge_output_path):
        input_tensor_moge = torch.tensor(rgb_scene / 255.0, dtype=torch.float32, device=device).permute(2, 0, 1)
        with torch.no_grad():
            output_moge = model_moge.infer(input_tensor_moge)
        torch.save(output_moge, moge_output_path)
    else:
        # map_location="cpu" lets a cache written from an MPS/GPU run load on any
        # machine. weights_only=False unpickles arbitrary objects; this is acceptable
        # only because the file is one this function wrote itself.
        output_moge = torch.load(moge_output_path, map_location="cpu", weights_only=False)

    points = output_moge["points"].cpu().numpy()  # (H, W, 3)
    mask = output_moge["mask"].cpu().numpy().astype(bool)

    # Remove depth discontinuities (flying pixels) from the validity mask.
    moge_edge_mask = utils3d.np.depth_map_edge(points[:, :, 2], rtol=0.005)
    mask = mask & ~moge_edge_mask

    H, W = points.shape[:2]
    return points, mask, moge_edge_mask, H, W


def align_pointcloud_to_robot_frame(points, mask, rgb_scene, robot_mask, human_mask, camera_pose_world, obj_img_pts, H, W):
    """Select valid MoGe points and align them to the robot frame with Procrustes.

    Pixels are kept when they are valid in ``mask`` and covered by neither the
    robot mask nor the human mask (both resized to (W, H)). Correspondences come
    from the ``"alignment_board"`` and ``"larger_base"`` entries of
    ``obj_img_pts``. Each entry is ``(obj_pts, img_pts_px)``; row i of
    ``obj_pts`` is assumed to correspond to pixel ``img_pts_px[i]``.

    Returns:
        Tuple ``(points_robot, valid_colors, valid_y_coords, valid_x_coords, T_procrustes)``:
        - ``points_robot``: (N, 3) valid points after alignment;
        - ``valid_colors``: (N, 3) uint8 colors of those points;
        - ``valid_y_coords``, ``valid_x_coords``: (N,) pixel coordinates of those points;
        - ``T_procrustes``: 4x4 transform applied to homogeneous camera points.
        With fewer than 3 usable correspondences, the points stay in the camera
        frame and ``T_procrustes`` is the identity.
    """
    colors = rgb_scene.reshape(-1, 3)
    points_flat = points.reshape(-1, 3)

    # Pixel coordinates of every flattened point.
    y_coords, x_coords = np.meshgrid(np.arange(H), np.arange(W), indexing='ij')
    y_coords_flat = y_coords.reshape(-1)
    x_coords_flat = x_coords.reshape(-1)

    def _resize_mask(m):
        # Resize a 0/1 mask to the pointmap resolution, then re-binarize it.
        img = Image.fromarray(m.astype(np.uint8) * 255)
        return np.array(img.resize((W, H), Image.Resampling.LANCZOS)) > 127

    exclude_mask_flat = (_resize_mask(robot_mask) | _resize_mask(human_mask)).reshape(-1)
    mask_flat = mask.reshape(-1) & ~exclude_mask_flat

    valid_points_cam = points_flat[mask_flat]
    valid_colors = colors[mask_flat].astype(np.uint8)
    valid_y_coords = y_coords_flat[mask_flat]
    valid_x_coords = x_coords_flat[mask_flat]

    # Build paired correspondences. Both sides are appended together so that a
    # pixel outside the image (or a non-finite MoGe point) drops the whole pair
    # instead of shifting every later correspondence out of alignment.
    inv_camera_pose = np.linalg.inv(camera_pose_world)
    aruco_corners_robot_frame = []
    moge_aruco_corners = []
    for obj_name, (obj_img_pts_cam, img_pts_px) in obj_img_pts.items():
        if obj_name not in ("alignment_board", "larger_base"):
            continue
        obj_cam_h = np.hstack([obj_img_pts_cam, np.ones((obj_img_pts_cam.shape[0], 1))])
        obj_robot = (inv_camera_pose @ obj_cam_h.T).T[:, :3]

        for pt_robot, pt_px in zip(obj_robot, img_pts_px):
            x, y = int(pt_px[0]), int(pt_px[1])
            if not (0 <= y < H and 0 <= x < W):
                continue
            moge_pt = points[y, x]
            if not np.all(np.isfinite(moge_pt)):
                continue
            aruco_corners_robot_frame.append(pt_robot)
            moge_aruco_corners.append(moge_pt)

    aruco_corners_robot_frame = np.array(aruco_corners_robot_frame)
    moge_aruco_corners = np.array(moge_aruco_corners)

    if len(moge_aruco_corners) >= 3:
        T_procrustes, _scale, _rotation, _translation = procrustes_alignment(
            aruco_corners_robot_frame, moge_aruco_corners
        )
        points_h = np.hstack([valid_points_cam, np.ones((len(valid_points_cam), 1))])
        points_robot = (T_procrustes @ points_h.T).T[:, :3]
    else:
        print(f"  ⚠ Insufficient correspondences, using camera frame")
        points_robot = valid_points_cam
        T_procrustes = np.eye(4)  # Identity so callers can apply it unconditionally.

    return points_robot, valid_colors, valid_y_coords, valid_x_coords, T_procrustes


def resize_transform(img: Image.Image, image_size: int = IMAGE_SIZE, patch_size: int = PATCH_SIZE) -> torch.Tensor:
    """Resize ``img`` onto a whole-patch grid and convert it to a tensor.

    The output height is ``int(image_size / patch_size) * patch_size``. The output
    width follows the input aspect ratio and is truncated to a whole number of
    patches.
    """
    width, height = img.size
    rows = int(image_size / patch_size)
    cols = int((width * image_size) / (height * patch_size))
    target_hw = (rows * patch_size, cols * patch_size)
    resized = TF.resize(img, target_hw)
    return TF.to_tensor(resized)


def run_dino_features(rgb_scene, model_dino, pca_embedder, device, H, W, valid_y_coords, valid_x_coords):
    """Extract PCA-reduced DINO features at full resolution and at the given pixels.

    Args:
        rgb_scene: RGB image array.
        model_dino: DINOv3 backbone exposing ``get_intermediate_layers``.
        pca_embedder: Fitted PCA exposing ``transform`` and ``n_components_``.
        device: Device for the backbone forward pass.
        H, W: Output resolution for the upsampled feature map.
        valid_y_coords, valid_x_coords: Integer pixel coordinates to sample.

    Returns:
        Tuple ``(dino_features, pca_features_upsampled, pca_features_patches)``:
        - ``dino_features``: (N, n_components) float64; row i is the upsampled
          feature at pixel (y_i, x_i), or zeros if that pixel is out of bounds;
        - ``pca_features_upsampled``: (H, W, n_components) float32;
        - ``pca_features_patches``: (h_patches, w_patches, n_components) patch-grid
          features.
    """
    img_pil = Image.fromarray(rgb_scene).convert("RGB")
    image_resized = resize_transform(img_pil)
    image_resized_norm = TF.normalize(image_resized, mean=IMAGENET_MEAN, std=IMAGENET_STD)
    h_patches, w_patches = [int(d / PATCH_SIZE) for d in image_resized.shape[1:]]

    with torch.inference_mode():
        feats = model_dino.get_intermediate_layers(
            image_resized_norm.unsqueeze(0).to(device),
            n=range(N_LAYERS),
            reshape=True,
            norm=True,
        )
        # Last requested layer; [0] drops only the batch dimension (squeeze()
        # would also drop a size-1 patch dimension).
        last_layer = feats[-1][0].detach().cpu()  # (D, h_patches, w_patches)

    patch_tokens = last_layer.reshape(last_layer.shape[0], -1).permute(1, 0).numpy()  # (h*w, D)

    pca_features_patches = pca_embedder.transform(patch_tokens).reshape(h_patches, w_patches, -1)

    # Bilinearly upsample the patch grid to the full pointmap resolution.
    pca_features_upsampled = TF.resize(
        torch.from_numpy(pca_features_patches).permute(2, 0, 1).float(),
        (H, W),
        interpolation=TF.InterpolationMode.BILINEAR,
    ).permute(1, 2, 0).numpy()  # (H, W, n_components)

    # Vectorized sampling at each requested pixel; out-of-bounds pixels stay zero.
    ys = np.asarray(valid_y_coords)
    xs = np.asarray(valid_x_coords)
    dino_features = np.zeros((len(ys), pca_embedder.n_components_))
    in_bounds = (ys >= 0) & (ys < H) & (xs >= 0) & (xs < W)
    dino_features[in_bounds] = pca_features_upsampled[ys[in_bounds], xs[in_bounds]]

    return dino_features, pca_features_upsampled, pca_features_patches


def visualize_dino_features(rgb_scene, pca_features_patches, H, W, seq_output_dir):
    """Write ``dino_features_vis.png``: the image, its first 3 PCA components as RGB, and a 50/50 blend.

    The first three PCA channels are squashed into (0, 1) with ``sigmoid(2 * x)``
    and bilinearly upsampled to (H, W) before plotting.
    """
    first_three = torch.from_numpy(pca_features_patches[:, :, :3])
    squashed = torch.sigmoid(first_three * 2.0)
    pca_as_rgb = TF.resize(
        squashed.permute(2, 0, 1).float(),
        (H, W),
        interpolation=TF.InterpolationMode.BILINEAR,
    ).permute(1, 2, 0).numpy()

    blended = (rgb_scene.astype(float) / 255.0) * 0.5 + pca_as_rgb * 0.5

    panels = (
        (rgb_scene, 'Original Image'),
        (pca_as_rgb, 'DINO PCA (first 3 components)'),
        (blended, 'Overlay: RGB + DINO'),
    )
    fig, axes = plt.subplots(1, len(panels), figsize=(15, 5))
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis('off')

    fig.tight_layout()
    fig.savefig(os.path.join(seq_output_dir, "dino_features_vis.png"), dpi=150, bbox_inches='tight')
    plt.close(fig)


if __name__ == "__main__":
    # Command-line configuration; defaults reproduce the previous hard-coded paths.
    parser = argparse.ArgumentParser(description="Batch process all sequences in a grasp dataset directory.")
    parser.add_argument("--input_dir", default="scratch/grasp_dataset_keyboard",
                        help="Directory containing *_helper_scene/grasp .png/.npy files.")
    parser.add_argument("--output_dir", default="scratch/processed_grasp_dataset_keyboard",
                        help="Root directory for per-sequence outputs.")
    parser.add_argument("--pca_path", default="scratch/dino_pca_embedder.pkl",
                        help="Pickle produced by get_dino_pca_emb.py containing a 'pca' entry.")
    args = parser.parse_args()

    input_dir = args.input_dir
    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    # DINOv3 configuration
    REPO_DIR = "/Users/cameronsmith/Projects/robotics_testing/random/dinov3"
    WEIGHTS_PATH = "/Users/cameronsmith/Projects/robotics_testing/random/dinov3/weights/dinov3_vits16plus_pretrain_lvd1689m-4057cbaa.pth"

    print("=" * 60)
    print("Finding all sequences")
    print("=" * 60)

    # A sequence is identified by the prefix of its *_helper_scene.png file.
    scene_images = [f for f in os.listdir(input_dir) if f.endswith("_helper_scene.png")]
    sequence_ids = sorted({f.replace("_helper_scene.png", "") for f in scene_images})

    print(f"Found {len(sequence_ids)} sequences")
    for seq_id in sequence_ids:
        print(f"  - {seq_id}")

    # Robot config and MuJoCo model are built once and reused for every sequence.
    SO100AdhesiveConfig.exo_alpha = 0.2
    SO100AdhesiveConfig.aruco_alpha = 0.2
    robot_config = SO100AdhesiveConfig()
    combined_xml = combine_xmls(robot_config.xml, ALIGNMENT_BOARD_CONFIG.get_xml_addition())
    model = mujoco.MjModel.from_xml_string(combined_xml)

    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    print(f"\nUsing device: {device}")

    print("Loading DeepLabV3 segmentation model...")
    model_seg = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=True)
    model_seg.eval()
    model_seg = model_seg.to(device)

    print("Loading MoGE model...")
    model_moge = MoGeModel.from_pretrained("Ruicheng/moge-2-vitl-normal").to(device)
    model_moge.eval()

    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=list(IMAGENET_MEAN), std=list(IMAGENET_STD)),
    ])

    print("Loading DINOv3 model...")
    with torch.inference_mode():
        model_dino = torch.hub.load(REPO_DIR, 'dinov3_vits16plus', source='local', weights=WEIGHTS_PATH).to(device)
        model_dino.eval()

    pca_path = args.pca_path
    if not os.path.exists(pca_path):
        raise FileNotFoundError(f"PCA embedder not found: {pca_path}. Please run get_dino_pca_emb.py first.")

    print(f"Loading PCA embedder from {pca_path}")
    # pickle.load runs arbitrary code; only load PCA files produced locally by get_dino_pca_emb.py.
    with open(pca_path, 'rb') as f:
        pca_data = pickle.load(f)
    pca_embedder = pca_data['pca']
    print(f"PCA embedder: {pca_embedder.n_components_} dimensions")

    print("\n" + "=" * 60)
    print("Processing sequences")
    print("=" * 60)

    n_processed = 0
    n_skipped = 0
    for seq_id in tqdm(sequence_ids, desc="Processing sequences"):
        print(f"\n{'='*60}")
        print(f"Processing sequence: {seq_id}")
        print(f"{'='*60}")

        required_paths = [
            os.path.join(input_dir, f"{seq_id}_helper_scene.png"),
            os.path.join(input_dir, f"{seq_id}_helper_grasp.png"),
            os.path.join(input_dir, f"{seq_id}_helper_grasp.npy"),
            os.path.join(input_dir, f"{seq_id}_helper_scene.npy"),
        ]
        if not all(os.path.exists(p) for p in required_paths):
            print(f"  ✗ Missing input files, skipping")
            n_skipped += 1
            continue

        process_episode(
            seq_id,
            input_dir,
            output_dir,
            model,
            robot_config,
            device,
            model_seg,
            model_moge,
            preprocess,
            model_dino,
            pca_embedder,
        )
        n_processed += 1

    print(f"\n{'='*60}")
    # Report actually-processed sequences; skipped ones are counted separately.
    print(f"Done! Processed {n_processed} sequences ({n_skipped} skipped due to missing inputs)")
    print(f"Output directory: {output_dir}")
    print(f"{'='*60}")
