"""Dataset for real dense trajectory prediction."""
import random
import torch
import numpy as np
from torch.utils.data import Dataset
from pathlib import Path
import matplotlib.pyplot as plt
import cv2
import math
from PIL import Image

# Number of waypoints in trajectory
N_WINDOW = 12

# Point-track pretraining (RTX): predict N_WINDOW_POINT_TRACK future frames from the first frame
N_WINDOW_POINT_TRACK = 15
N_QUERY_POINTS = 32
HEATMAP_SIZE = 64
TRACKS_ROOT = Path("/data/RTX/tracks")
IMAGE_SIZE_PT = 448
# Motion filter: only use tracks in top 15% by total distance traveled (85th percentile)
MOTION_PERCENTILE = 85

# Hard-coded gripper range for regression
MIN_GRIPPER = -0.2
MAX_GRIPPER = 0.8

def process_gripper_value(gripper_value):
    """Process gripper value: fix wrapping and clamp to range.
    
    Args:
        gripper_value: scalar gripper joint value
    
    Returns:
        Processed gripper value in [MIN_GRIPPER, MAX_GRIPPER] = [-0.2, 0.8]
        Processing steps:
        1. Map values > 4.0 to -0.2 (wraparound fix)
        2. Clamp values > 0.8 to 0.8
        3. Ensure final range is [-0.2, 0.8]
    """
    x = gripper_value
    if x > 4.0:
        x = -0.2
    if x > 0.8:
        x = 0.8
    
    # Final clamp to ensure range [-0.2, 0.8] (handles values < -0.2 if any)
    return max(MIN_GRIPPER, min(MAX_GRIPPER, x))
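
# Illustrative worked examples (traced from the logic above):
#   process_gripper_value(6.1) -> -0.2  (wraparound fix for values > 4.0)
#   process_gripper_value(1.5) ->  0.8  (clamped to MAX_GRIPPER)
#   process_gripper_value(0.3) ->  0.3  (already in range, unchanged)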

# Commented out binarization function (kept for reference)
# def binarize_gripper(gripper_value):
#     """Binarize gripper value to open (1.0) or closed (0.0).
#     
#     Args:
#         gripper_value: scalar gripper joint value
#     
#     Returns:
#         Binary gripper value: 1.0 for open, 0.0 for closed
#         Closed if joint < 0.1 or joint > 4.0 (wrapped values), else open
#     """
#     # Fix wrapping: if value > 4, it's wrapped (2π ≈ 6.28, so values > 4 are likely wrapped)
#     #if gripper_value > 4.0:
#     #    # Convert: 2π - value (so 6.1 becomes ~-0.2)
#     #    gripper_value = 2 * math.pi - gripper_value
#     
#     # Binarize: closed if < 0.1 or > 4.0 (after wrapping fix, > 4.0 shouldn't happen, but check anyway)
#     if gripper_value < 0.1 or gripper_value > 4.0:
#         return 0.0  # Closed
#     else:
#         return 1.0  # Open

# Virtual gripper keypoints in the local gripper frame (millimeters -> meters)
KEYPOINTS_LOCAL_M_ALL = np.array([[13.25, -91.42, 15.9], [10.77, -99.6, 0], [13.25, -91.42, -15.9],
                                   [17.96, -83.96, 0], [22.86, -70.46, 0]]) / 1000.0
KP_INDEX = 3  # Index of the keypoint used for trajectory supervision
kp_local = KEYPOINTS_LOCAL_M_ALL[KP_INDEX]

def project_3d_to_2d(point_3d, camera_pose, cam_K):
    """Project 3D point to 2D pixel coordinates.
    
    Args:
        point_3d: (3,) 3D point in world coordinates
        camera_pose: (4, 4) camera pose matrix (world-to-camera)
        cam_K: (3, 3) camera intrinsics
    
    Returns:
        (2,) 2D pixel coordinates [x, y], or None if behind camera
    """
    # Transform to camera frame
    point_3d_h = np.append(point_3d, 1.0)
    point_cam = camera_pose @ point_3d_h
    
    # Check if point is behind camera
    if point_cam[2] <= 0:
        return None
    
    # Project to image plane
    point_2d_h = cam_K @ point_cam[:3]
    point_2d = point_2d_h[:2] / point_2d_h[2]
    
    return point_2d
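
# Example (illustrative, assuming an identity world-to-camera pose and a simple
# pinhole K with principal point (224, 224)): a point 2 m straight ahead of the
# camera projects to the principal point.
#   K = np.array([[500.0, 0.0, 224.0], [0.0, 500.0, 224.0], [0.0, 0.0, 1.0]])
#   project_3d_to_2d(np.array([0.0, 0.0, 2.0]), np.eye(4), K)  # -> array([224., 224.])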

class RealTrajectoryDataset(Dataset):
    """Dataset for real dense trajectory prediction.
    
    Each sample contains:
        - RGB image (from any frame)
        - Dense 2D trajectory (N_WINDOW waypoints starting from that frame)
        - Camera parameters (pose and intrinsics)
        - Ground truth heatmaps for each timestep
    
    For each episode, creates samples from every frame, with trajectories
    padded with the last observed keypoint if there aren't enough subsequent frames.
    """
    
    def __init__(self, dataset_root="scratch/", image_size=448, episode: str | None = None, max_episodes: int | None = None):
        """Initialize dataset.
        
        Args:
            dataset_root: Root directory containing episodes
            image_size: Size to resize images to (will be square, default 448)
            episode: Optional episode directory name (e.g. "episode_001") to load only.
            max_episodes: Optional limit on how many episodes to load (after sorting).
        """
        self.dataset_root = Path(dataset_root)
        self.image_size = image_size
        self.n_window = N_WINDOW
        
        # Load episode directories
        if not self.dataset_root.exists():
            raise ValueError(f"Dataset directory not found: {self.dataset_root}")
        
        episode_dirs = sorted([d for d in self.dataset_root.iterdir()
                             if d.is_dir() and "episode" in d.name])

        if episode is not None:
            episode_dirs = [d for d in episode_dirs if d.name == episode]
            if len(episode_dirs) == 0:
                raise ValueError(f"Episode '{episode}' not found in {self.dataset_root}")

        if max_episodes is not None:
            episode_dirs = episode_dirs[:max_episodes]
        
        if len(episode_dirs) == 0:
            raise ValueError(f"No episodes found in {self.dataset_root}")
        
        # Build list of (episode_dir, frame_idx) pairs
        self.samples = []
        for episode_dir in episode_dirs:
            # Find all frame images
            frame_files = sorted([f for f in episode_dir.glob("*.png") if f.stem.isdigit()])
            
            # Create a sample for each frame
            for frame_file in frame_files:
                frame_idx = int(frame_file.stem)
                self.samples.append((episode_dir, frame_idx))
        
        print(f"Loaded {len(episode_dirs)} episodes")
        print(f"Created {len(self.samples)} samples (one per frame)")
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        """Get a single sample.
        
        Returns:
            dict with keys:
                - rgb: (3, H, W) normalized RGB image
                - heatmap_target: (N_WINDOW, H, W) one-hot heatmaps for each timestep
                - trajectory_2d: (N_WINDOW, 2) 2D pixel locations for each timestep
                - trajectory_3d: (N_WINDOW, 3) 3D world positions for each timestep
                - trajectory_gripper: (N_WINDOW,) gripper values for each timestep
                - target_3d: (3,) final target 3D world position (last waypoint)
                - camera_pose: (4, 4) camera pose matrix (for the sample's frame, i.e. the first frame of the window)
                - cam_K_norm: (3, 3) normalized intrinsics (for the sample's frame)
        """
        episode_dir, frame_idx = self.samples[idx]
        frame_str = f"{frame_idx:06d}"
        
        # Load RGB image for this frame
        rgb_path = episode_dir / f"{frame_str}.png"
        rgb = plt.imread(rgb_path)[..., :3]  # (H, W, 3)
        H_orig, W_orig = rgb.shape[:2]
        
        # Find all subsequent frames
        frame_files = sorted([f for f in episode_dir.glob("*.png") if f.stem.isdigit()])
        frame_indices = [int(f.stem) for f in frame_files]
        start_frame_idx = frame_indices.index(frame_idx)
        
        # Build trajectory from subsequent frames
        trajectory_2d = []
        trajectory_3d = []
        trajectory_gripper = []  # Gripper values for each timestep
        
        # Process up to N_WINDOW frames starting from current frame
        for i in range(self.n_window):
            if start_frame_idx + i >= len(frame_indices):
                break
            
            next_frame_idx = frame_indices[start_frame_idx + i]
            next_frame_str = f"{next_frame_idx:06d}"
            
            # Load gripper pose
            gripper_pose_path = episode_dir / f"{next_frame_str}_gripper_pose.npy"
            if not gripper_pose_path.exists():
                break
            
            gripper_pose = np.load(gripper_pose_path)  # (4, 4)
            gripper_rot = gripper_pose[:3, :3]
            gripper_pos = gripper_pose[:3, 3]
            
            # Compute 3D keypoint from gripper pose
            kp_3d = gripper_rot @ kp_local + gripper_pos
            trajectory_3d.append(kp_3d)
            
            # Load gripper value from joint state file (last dimension)
            joint_state_path = episode_dir / f"{next_frame_str}.npy"
            if joint_state_path.exists():
                joint_state = np.load(joint_state_path)
                gripper_value = float(joint_state[-1])  # Last value is gripper
                gripper_value = process_gripper_value(gripper_value)  # Process: map >4.0 to -0.2, clamp to [-0.2, 0.8]
                trajectory_gripper.append(gripper_value)
            else:
                # Default to 1.0 (open) if joint state not found
                trajectory_gripper.append(1.0)
            
            # Load the sample frame's camera parameters (shared by all waypoints;
            # reloaded each iteration so a missing file truncates the trajectory)
            cam_K_norm_path = episode_dir / f"{frame_str}_cam_K.npy"
            camera_pose_path = episode_dir / f"{frame_str}_camera_pose.npy"
            
            if not (cam_K_norm_path.exists() and camera_pose_path.exists()):
                break
            
            cam_K_norm = np.load(cam_K_norm_path)  # (3, 3) normalized
            camera_pose = np.load(camera_pose_path)  # (4, 4)
            
            # Scale intrinsics to image resolution
            cam_K = cam_K_norm.copy()
            cam_K[0] *= W_orig  # fx
            cam_K[1] *= H_orig  # fy
            
            # Project 3D keypoint to 2D
            kp_2d = project_3d_to_2d(kp_3d, camera_pose, cam_K)
            if kp_2d is None:
                break
            
            trajectory_2d.append(kp_2d)
        
        if len(trajectory_2d) == 0:
            # Fallback: use current frame only
            gripper_pose = np.load(episode_dir / f"{frame_str}_gripper_pose.npy")
            gripper_rot = gripper_pose[:3, :3]
            gripper_pos = gripper_pose[:3, 3]
            kp_3d = gripper_rot @ kp_local + gripper_pos
            trajectory_3d = [kp_3d]
            
            # Load gripper value for current frame
            joint_state_path = episode_dir / f"{frame_str}.npy"
            if joint_state_path.exists():
                joint_state = np.load(joint_state_path)
                gripper_value = float(joint_state[-1])
                gripper_value = process_gripper_value(gripper_value)  # Fix wrapping and clamp
                trajectory_gripper = [gripper_value]
            else:
                trajectory_gripper = [1.0]
            
            cam_K_norm = np.load(episode_dir / f"{frame_str}_cam_K.npy")
            camera_pose = np.load(episode_dir / f"{frame_str}_camera_pose.npy")
            cam_K = cam_K_norm.copy()
            cam_K[0] *= W_orig
            cam_K[1] *= H_orig
            kp_2d = project_3d_to_2d(kp_3d, camera_pose, cam_K)
            if kp_2d is not None:
                trajectory_2d = [kp_2d]
        
        if len(trajectory_2d) == 0:
            raise ValueError(f"Could not compute trajectory for {episode_dir.name} frame {frame_str}")
        
        trajectory_2d = np.array(trajectory_2d)  # (N, 2)
        trajectory_3d = np.array(trajectory_3d)  # (N, 3)
        trajectory_gripper = np.array(trajectory_gripper)  # (N,)
        
        # Pad trajectory with last observed keypoint if needed
        if len(trajectory_2d) < self.n_window:
            # Pad with last point
            last_point_2d = trajectory_2d[-1:]  # (1, 2)
            last_point_3d = trajectory_3d[-1:]  # (1, 3)
            last_gripper = trajectory_gripper[-1:]  # (1,)
            n_pad = self.n_window - len(trajectory_2d)
            trajectory_2d = np.concatenate([trajectory_2d, np.tile(last_point_2d, (n_pad, 1))], axis=0)
            trajectory_3d = np.concatenate([trajectory_3d, np.tile(last_point_3d, (n_pad, 1))], axis=0)
            trajectory_gripper = np.concatenate([trajectory_gripper, np.tile(last_gripper, (n_pad,))], axis=0)
        elif len(trajectory_2d) > self.n_window:
            # Truncate to first N_WINDOW points
            trajectory_2d = trajectory_2d[:self.n_window]
            trajectory_3d = trajectory_3d[:self.n_window]
            trajectory_gripper = trajectory_gripper[:self.n_window]
        
        # Get target 3D position (final waypoint)
        target_3d = trajectory_3d[-1]
        
        # Load the sample frame's camera parameters (ensures they are defined even if the loop exited early)
        camera_pose = np.load(episode_dir / f"{frame_str}_camera_pose.npy")
        cam_K_norm = np.load(episode_dir / f"{frame_str}_cam_K.npy")
        
        # Resize image if needed
        if H_orig != self.image_size or W_orig != self.image_size:
            rgb = cv2.resize(rgb, (self.image_size, self.image_size), 
                           interpolation=cv2.INTER_LINEAR)
            
            # Scale trajectory_2d to new resolution
            scale_x = self.image_size / W_orig
            scale_y = self.image_size / H_orig
            trajectory_2d = trajectory_2d * np.array([scale_x, scale_y])
        
        # Create heatmaps for each timestep
        heatmap_targets = []
        for t in range(self.n_window):
            target_2d = trajectory_2d[t]
            
            # Clip to image bounds
            target_2d = np.array([
                np.clip(target_2d[0], 0, self.image_size - 1),
                np.clip(target_2d[1], 0, self.image_size - 1)
            ])
            
            # Create one-hot heatmap at the resized image resolution
            heatmap_target = np.zeros((self.image_size, self.image_size), dtype=np.float32)
            target_x = int(round(target_2d[0]))
            target_y = int(round(target_2d[1]))
            if 0 <= target_x < self.image_size and 0 <= target_y < self.image_size:
                heatmap_target[target_y, target_x] = 1.0
            
            heatmap_targets.append(heatmap_target)
        
        heatmap_targets = np.array(heatmap_targets)  # (N_WINDOW, H, W)
        
        # Convert to torch tensors
        # RGB: (H, W, 3) -> (3, H, W) and normalize to [0, 1]
        rgb_tensor = torch.from_numpy(rgb).permute(2, 0, 1).float()
        if rgb_tensor.max() > 1.0:
            rgb_tensor = rgb_tensor / 255.0
        
        # Normalize RGB to ImageNet stats (for DINOv2)
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        rgb_tensor = (rgb_tensor - mean) / std
        
        heatmap_tensor = torch.from_numpy(heatmap_targets).float()  # (N_WINDOW, H, W)
        trajectory_2d_tensor = torch.from_numpy(trajectory_2d).float()  # (N_WINDOW, 2)
        trajectory_3d_tensor = torch.from_numpy(trajectory_3d).float()  # (N_WINDOW, 3)
        trajectory_gripper_tensor = torch.from_numpy(trajectory_gripper).float()  # (N_WINDOW,)
        target_3d_tensor = torch.from_numpy(target_3d).float()
        camera_pose_tensor = torch.from_numpy(camera_pose).float()
        cam_K_norm_tensor = torch.from_numpy(cam_K_norm).float()

        # Load joint states for debugging / IK initialization
        joint_state_path_cur = episode_dir / f"{frame_str}.npy"
        joint_state_cur = None
        if joint_state_path_cur.exists():
            try:
                joint_state_cur = np.load(joint_state_path_cur)
            except Exception:
                joint_state_cur = None

        joint_state_path_ep0 = episode_dir / "000000.npy"
        joint_state_ep0 = None
        if joint_state_path_ep0.exists():
            try:
                joint_state_ep0 = np.load(joint_state_path_ep0)
            except Exception:
                joint_state_ep0 = None

        joint_state_cur_tensor = (
            torch.from_numpy(np.asarray(joint_state_cur, dtype=np.float32))
            if joint_state_cur is not None
            else None
        )
        joint_state_ep0_tensor = (
            torch.from_numpy(np.asarray(joint_state_ep0, dtype=np.float32))
            if joint_state_ep0 is not None
            else None
        )
        
        return {
            'rgb': rgb_tensor,
            'heatmap_target': heatmap_tensor,
            'trajectory_2d': trajectory_2d_tensor,
            'trajectory_3d': trajectory_3d_tensor,
            'trajectory_gripper': trajectory_gripper_tensor,
            'target_3d': target_3d_tensor,
            'camera_pose': camera_pose_tensor,
            'cam_K_norm': cam_K_norm_tensor,
            'episode_id': f"{episode_dir.name}_frame_{frame_str}",
            'episode_name': episode_dir.name,
            'frame_idx': frame_idx,
            'episode_dir': str(episode_dir),
            'joint_state': joint_state_cur_tensor,
            'episode_start_joint_state': joint_state_ep0_tensor,
        }
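

# The `joint_state` fields above may be None when the .npy files are missing,
# which torch's default collate cannot batch. A minimal collate sketch, assuming
# the None-able and string fields should simply pass through as lists (this
# helper is illustrative and not used by the datasets in this file):
def collate_keep_none(batch):
    """Stack tensor fields with default_collate; pass through None-able fields."""
    from torch.utils.data import default_collate
    passthrough = {"joint_state", "episode_start_joint_state",
                   "episode_id", "episode_name", "episode_dir"}
    out = {}
    for key in batch[0]:
        values = [sample[key] for sample in batch]
        out[key] = values if key in passthrough else default_collate(values)
    return out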


def _top_motion_track_indices(tracks, visibility, percentile=85):
    """Return indices of tracks in the top (100 - percentile)% by total distance
    traveled, so percentile=85 keeps the top 15% (the biggest movers).

    Args:
        tracks: (T, N, 2) point tracks in pixels
        visibility: (T, N) per-frame visibility (truthy = visible)

    Returns:
        (K,) array of track indices; falls back to all indices if nothing moved.
    """
    T, N, _ = tracks.shape
    if N == 0:
        return np.arange(N)
    # Per-step displacement, counted only when the point is visible in both frames
    step_dist = np.linalg.norm(tracks[1:] - tracks[:-1], axis=-1)  # (T-1, N)
    both_visible = (visibility[:-1] != 0) & (visibility[1:] != 0)  # (T-1, N)
    total_dist = np.where(both_visible, step_dist, 0.0).sum(axis=0)
    if np.all(total_dist == 0):
        return np.arange(N)
    thresh = np.percentile(total_dist, percentile)
    indices = np.where(total_dist >= thresh)[0]
    return indices if len(indices) > 0 else np.arange(N)
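
# Illustrative example: with 10 tracks whose total distances are 0..9 and
# percentile=85, np.percentile yields 7.65, so only tracks 8 and 9 (the top
# ~15% movers) are returned.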


def load_gif_frame(path, frame_idx):
    """Load a single frame from a GIF as (H, W, 3) uint8."""
    with Image.open(path) as im:
        im.seek(int(frame_idx))
        return np.array(im.convert("RGB"))


def load_gif_frames_at_indices(path, indices):
    """Load GIF frames at given indices. Returns (len(indices), H, W, 3) uint8."""
    out = []
    with Image.open(path) as im:
        for i in indices:
            im.seek(int(i))
            out.append(np.array(im.convert("RGB")))
    return np.stack(out)


class RTXPointTrackDataset(Dataset):
    """Load .pt track files from /data/RTX/tracks; first frame as input, predict next N_WINDOW heatmaps for N_QUERY_POINTS.
    Optionally filter to top motion tracks only (motion_percentile=85 -> top 15%% by total distance)."""

    def __init__(self, tracks_root=None, image_size=IMAGE_SIZE_PT,
                 n_window=N_WINDOW_POINT_TRACK, n_query=N_QUERY_POINTS,
                 heatmap_size=HEATMAP_SIZE, max_samples=None,
                 motion_percentile=MOTION_PERCENTILE, shuffle_seed=42):
        self.tracks_root = Path(tracks_root or TRACKS_ROOT)
        self.image_size = image_size
        self.n_window = n_window
        self.n_query = n_query
        self.heatmap_size = heatmap_size
        self.motion_percentile = motion_percentile
        self.pt_paths = sorted(self.tracks_root.glob("*.pt"))
        if not self.pt_paths:
            raise FileNotFoundError(f"No .pt files in {self.tracks_root}")
        if shuffle_seed is not None:
            rng = random.Random(shuffle_seed)
            rng.shuffle(self.pt_paths)
        if max_samples is not None:
            self.pt_paths = self.pt_paths[:max_samples]
        print(f"RTXPointTrackDataset: {len(self.pt_paths)} .pt files from {self.tracks_root}" + (f" (motion filter: top {100 - motion_percentile}%%)" if motion_percentile is not None else "") + (" (paths shuffled)" if shuffle_seed is not None else ""))

    def __len__(self):
        return len(self.pt_paths)

    def __getitem__(self, idx):
        pt_path = self.pt_paths[idx]
        payload = torch.load(pt_path, weights_only=False)
        tracks = payload["tracks"].numpy()   # (T, N, 2)
        T, N = tracks.shape[0], tracks.shape[1]
        # Visibility: saved by point_track_rtx_batch.py from CoTracker as (T, N, 1);
        # used to mask the CE loss and for the motion filter.
        if "visibility" in payload:
            visibility = payload["visibility"].numpy().squeeze()
        else:
            visibility = np.ones((T, N), dtype=np.float32)
        if visibility.ndim == 1:
            # Ambiguous after squeeze(): could be (T,) per-frame or (N,) per-track
            if len(visibility) == T:
                visibility = np.broadcast_to(visibility.reshape(-1, 1), (T, N))
            elif len(visibility) == N:
                visibility = np.broadcast_to(visibility.reshape(1, -1), (T, N))
            else:
                # Length matches neither T nor N; assume fully visible
                visibility = np.ones((T, N), dtype=np.float32)
        elif visibility.ndim == 2:
            if visibility.shape == (T, N):
                pass
            elif visibility.shape == (N, T):
                visibility = visibility.T
            elif visibility.shape == (T, 1):
                visibility = np.broadcast_to(visibility, (T, N)).copy()
            elif visibility.shape == (1, N):
                visibility = np.broadcast_to(visibility, (T, N)).copy()
            else:
                visibility = np.ones((T, N), dtype=np.float32)
        frame_indices = payload["frame_indices"].numpy()
        gif_path = Path(payload["gif_path"])
        H_orig = int(payload["height"])
        W_orig = int(payload["width"])
        if T < self.n_window + 1:
            # Pad with last frame
            last = np.tile(tracks[-1:], (self.n_window + 1 - T, 1, 1))
            tracks = np.concatenate([tracks, last], axis=0)
            vis_last = np.broadcast_to(visibility[-1:], (self.n_window + 1 - T, N))
            visibility = np.concatenate([visibility, vis_last], axis=0)
            T = tracks.shape[0]
        # Restrict to top motion tracks (e.g. 85th percentile = top 15%)
        if self.motion_percentile is not None:
            candidate_indices = _top_motion_track_indices(tracks, visibility, self.motion_percentile)
        else:
            candidate_indices = np.arange(N)
        n_candidates = len(candidate_indices)
        if n_candidates == 0:
            candidate_indices = np.arange(N)
            n_candidates = N
        # Sample query point indices from the candidate set only
        if n_candidates >= self.n_query:
            query_idx = np.random.choice(candidate_indices, size=self.n_query, replace=False)
        else:
            # Fewer candidates than queries: take them all, then top up with replacement
            extra = np.random.choice(candidate_indices, size=self.n_query - n_candidates, replace=True)
            query_idx = np.concatenate([candidate_indices, extra])
        # First frame as input
        first_frame = load_gif_frame(gif_path, frame_indices[0])  # (H_orig, W_orig, 3)
        first_frame = cv2.resize(first_frame, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
        rgb_tensor = torch.from_numpy(first_frame).permute(2, 0, 1).float()
        if rgb_tensor.max() > 1.0:
            rgb_tensor = rgb_tensor / 255.0
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        rgb_tensor = (rgb_tensor - mean) / std
        # Query start positions in image_size space (448)
        scale_448_x = self.image_size / W_orig
        scale_448_y = self.image_size / H_orig
        query_start = tracks[0, query_idx] * np.array([scale_448_x, scale_448_y])  # (n_query, 2)
        query_start_2d = torch.from_numpy(query_start).float()
        # Targets: future frames 1..n_window, in heatmap_size (64) grid for CE
        scale_64_x = self.heatmap_size / W_orig
        scale_64_y = self.heatmap_size / H_orig
        target_indices = []  # (n_query, n_window) flat indices into the heatmap_size*heatmap_size grid
        vis_mask = []
        for q in range(self.n_query):
            ti = []
            vm = []
            for t in range(1, self.n_window + 1):
                x, y = tracks[t, query_idx[q], 0], tracks[t, query_idx[q], 1]
                x64 = np.clip(int(round(x * scale_64_x)), 0, self.heatmap_size - 1)
                y64 = np.clip(int(round(y * scale_64_y)), 0, self.heatmap_size - 1)
                ti.append(y64 * self.heatmap_size + x64)
                vm.append(visibility[t, query_idx[q]] > 0.5)
            target_indices.append(ti)
            vis_mask.append(vm)
        target_heatmap_indices = torch.from_numpy(np.array(target_indices)).long()  # (n_query, n_window)
        visibility_mask = torch.from_numpy(np.array(vis_mask)).bool()  # (n_query, n_window)
        # For wandb: frames and tracks in original res
        frame_inds_vis = frame_indices[: self.n_window + 1]
        frames_vis = load_gif_frames_at_indices(gif_path, frame_inds_vis)  # (T_vis, H_orig, W_orig, 3)
        tracks_vis = tracks[: self.n_window + 1, query_idx]  # (T_vis, n_query, 2) original coords
        return {
            "rgb": rgb_tensor,
            "query_start_2d": query_start_2d,
            "target_heatmap_indices": target_heatmap_indices,
            "visibility": visibility_mask,
            "frames_vis": torch.from_numpy(frames_vis),
            "tracks_vis": torch.from_numpy(tracks_vis).float(),
            "pt_name": pt_path.stem,
        }
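

# Decoding sketch (an assumption about downstream use; not called in this file):
# the flat cross-entropy targets produced by RTXPointTrackDataset can be mapped
# back to (x, y) cells on the heatmap grid.
def flat_index_to_xy(flat_idx, heatmap_size=HEATMAP_SIZE):
    """Invert the `y * heatmap_size + x` flattening used for target_heatmap_indices."""
    y, x = divmod(int(flat_idx), heatmap_size)
    return x, y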


if __name__ == "__main__":
    # Test dataset loading
    print("Testing RealTrajectoryDataset...")
    
    try:
        dataset = RealTrajectoryDataset(image_size=448, dataset_root="scratch/parsed_moredata_pickplace_home")
        print(f"✓ Loaded {len(dataset)} samples")
        
        # Test first sample
        sample = dataset[0]
        print(f"\nSample 0:")
        print(f"  RGB shape: {sample['rgb'].shape}")
        print(f"  RGB range: [{sample['rgb'].min():.3f}, {sample['rgb'].max():.3f}]")
        print(f"  Heatmap shape: {sample['heatmap_target'].shape}")
        print(f"  Heatmap sum per timestep: {sample['heatmap_target'].sum(dim=(1,2))}")
        print(f"  Trajectory 2D shape: {sample['trajectory_2d'].shape}")
        print(f"  Trajectory 2D (first 3): {sample['trajectory_2d'][:3]}")
        print(f"  Trajectory 3D shape: {sample['trajectory_3d'].shape}")
        print(f"  Target 3D: {sample['target_3d']}")
        print(f"  Episode ID: {sample['episode_id']}")
        
    except Exception as e:
        print(f"✗ Error: {e}")
        import traceback
        traceback.print_exc()
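
    # Optional smoke test for RTXPointTrackDataset; a sketch that assumes
    # TRACKS_ROOT is populated with .pt files and skips otherwise.
    print("\nTesting RTXPointTrackDataset...")
    try:
        pt_dataset = RTXPointTrackDataset(max_samples=1)
        pt_sample = pt_dataset[0]
        print(f"  rgb: {pt_sample['rgb'].shape}")
        print(f"  target_heatmap_indices: {pt_sample['target_heatmap_indices'].shape}")
        print(f"  visibility: {pt_sample['visibility'].shape}")
    except Exception as e:
        print(f"  (skipped: {e})")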
    
    print(f"\n{'='*60}")
    print("✓ Dataset test complete!")
