"""DROID dataset for PARA training.

Loads DROID episodes (cadene/droid_1.0.1) either from a local download
(DroidLocalDataset) or by streaming them from HuggingFace (DroidStreamingDataset),
and produces training samples in the same format as CachedTrajectoryDataset (LIBERO).

NOTE: Camera intrinsics are estimated (no per-camera intrinsics in the dataset).
The estimated fy=130 corresponds to ZED 2 wide mode at 320×180.

Key differences from LIBERO:
  - Real images (320×180), resized non-uniformly to 448×448 (aspect ratio is not
    preserved); no vertical flip needed
  - Camera extrinsics from DROID 6D format [x,y,z,rx,ry,rz]
  - Gripper: DROID [0,1] → mapped to [-1,+1] for PARA convention
  - EEF pose from cartesian_position [x,y,z,roll,pitch,yaw]
  - Robot base at world origin (base_z=0)
  - 15 Hz (vs LIBERO 20 Hz), frame_stride=2 gives ~7.5Hz
"""

import os
from pathlib import Path

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from scipy.spatial.transform import Rotation as ScipyR

# Number of timesteps in each training sample window.
N_WINDOW = 4

# Gripper targets are clipped into this range (PARA convention: -1 = open, +1 = closed).
MIN_GRIPPER = -1.0
MAX_GRIPPER = 1.0

# Franka base at world origin in DROID
BASE_Z = 0.0

# Estimated ZED 2 intrinsics at the native 320×180 resolution (wide mode)
DEFAULT_FY = 130.0
IMG_W_NATIVE = 320
IMG_H_NATIVE = 180

# Per-channel RGB normalisation statistics (ImageNet), applied to [0, 1] images.
IMAGENET_MEAN = np.asarray([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.asarray([0.229, 0.224, 0.225], dtype=np.float32)


def _build_camera_matrices(ext6d, fy=DEFAULT_FY, img_w=IMG_W_NATIVE, img_h=IMG_H_NATIVE,
                           target_size=448, fx=None):
    """Build camera matrices from DROID 6D extrinsics.

    DROID extrinsics: [x,y,z,rx,ry,rz] = camera pose in robot base frame.
    R = Rotation.from_euler("xyz", [rx,ry,rz]) maps camera frame → base frame.

    All matrices are in the RESIZED (target_size × target_size) image space,
    since training images are resized from img_w×img_h (320×180 by default)
    to target_size×target_size. That resize is non-uniform, so the effective
    horizontal and vertical focal lengths differ after scaling. The principal
    point is taken to be the native image centre.

    Args:
        ext6d: array-like with at least 6 values [x, y, z, rx, ry, rz]
            (position, then "xyz" Euler angles in radians); extra trailing
            values are ignored.
        fy: vertical focal length in native-resolution pixels.
        img_w, img_h: native image width / height in pixels.
        target_size: side length of the square training image.
        fx: horizontal focal length in native pixels. Defaults to ``fy``
            (square pixels), which is what the estimated intrinsics assume.

    Returns:
        camera_pose: (4,4) float32 camera→world (for unprojection / 3D recovery)
        world_to_cam_proj: (4,4) float32 full projection matrix K@[R|t] at target_size
                           (compatible with robosuite project_points_from_world_to_camera)
        cam_K_norm: (3,3) float32 intrinsic matrix at target_size whose first two
                    rows are divided by target_size

    Raises:
        ValueError: if ext6d is not 1-D or has fewer than 6 values.
    """
    # np.asarray accepts lists/tuples as well as arrays (the old code required .astype).
    ext = np.asarray(ext6d, dtype=np.float64)
    if ext.ndim != 1 or ext.size < 6:
        raise ValueError(
            f"ext6d must be a 1-D sequence of at least 6 values, got shape {ext.shape}")
    pos = ext[:3]
    R_base_cam = ScipyR.from_euler("xyz", ext[3:6]).as_matrix()

    # camera_pose = T_base_cam (camera→world) — same convention as LIBERO cam_extrinsic
    camera_pose = np.eye(4, dtype=np.float64)
    camera_pose[:3, :3] = R_base_cam
    camera_pose[:3, 3] = pos

    # world_to_cam extrinsic = T_cam_base (world→camera): inverse of a rigid transform
    R_cam_base = R_base_cam.T
    T_cam_base = np.eye(4, dtype=np.float64)
    T_cam_base[:3, :3] = R_cam_base
    T_cam_base[:3, 3] = -R_cam_base @ pos

    # Intrinsics scaled to target_size (non-uniform scaling, e.g. 320×180 → 448×448)
    fx_native = fy if fx is None else fx
    sx = target_size / img_w   # 448/320 = 1.4 by default
    sy = target_size / img_h   # 448/180 ≈ 2.489 by default
    fx_eff = fx_native * sx      # horizontal focal length at target_size
    fy_eff = fy * sy             # vertical focal length at target_size
    cx_eff = (img_w / 2.0) * sx  # = target_size / 2
    cy_eff = (img_h / 2.0) * sy  # = target_size / 2

    K_3x3 = np.array([
        [fx_eff, 0.0, cx_eff],
        [0.0, fy_eff, cy_eff],
        [0.0, 0.0, 1.0],
    ], dtype=np.float64)

    # 4×4 intrinsic matrix for robosuite-compatible projection
    K_4x4 = np.eye(4, dtype=np.float64)
    K_4x4[:3, :3] = K_3x3

    # Full projection: K @ extrinsic (world → pixel at target_size)
    world_to_cam_proj = K_4x4 @ T_cam_base

    # 3×3 intrinsic normalized by target_size (for unprojection via utils.py)
    cam_K_norm = K_3x3.copy()
    cam_K_norm[:2] /= float(target_size)

    return (camera_pose.astype(np.float32),
            world_to_cam_proj.astype(np.float32),
            cam_K_norm.astype(np.float32))


def _project_to_pixel_targetsize(p_world, world_to_cam_proj, target_size):
    """Project 3D world point to pixel (u, v) in target_size image space.

    Uses the full 4×4 projection matrix (K @ extrinsic) at target_size.
    Returns (u, v) clipped to [0, target_size-1].
    """
    p_h = np.append(p_world.astype(np.float64), 1.0)
    p_proj = world_to_cam_proj @ p_h
    if p_proj[2] <= 0:
        return np.array([target_size / 2, target_size / 2], dtype=np.float32)
    u = p_proj[0] / p_proj[2]
    v = p_proj[1] / p_proj[2]
    u = np.clip(u, 0, target_size - 1)
    v = np.clip(v, 0, target_size - 1)
    return np.array([u, v], dtype=np.float32)


class DroidLocalDataset(Dataset):
    """Reads DROID episodes from a local directory (downloaded via huggingface-cli).

    Init is fast: it only checks file paths and takes per-episode frame counts
    from the manifest or from parquet metadata. All per-episode data (parquet
    columns, projections) is loaded lazily in __getitem__.

    Every frame of every indexed episode is a sample start. Windows that run
    past the end of an episode repeat the last frame.

    Expected layout under data_root:
        data/chunk-NNN/episode_NNNNNN.parquet
        videos/chunk-NNN/observation.images.exterior_{1,2}_left/episode_NNNNNN.mp4

    Args:
        data_root: path to downloaded dataset (e.g. /data/cameron/droid)
        camera: "ext1" or "ext2"
        max_episodes: maximum number of usable episodes to index (0 = all)
        manifest_path: optional JSON manifest whose "episodes" list holds
            {"ep_idx", "num_frames"} entries. It is used instead of scanning
            when the file exists.
        image_size, n_window, frame_stride, fy: same as DroidStreamingDataset
    """

    # Number of episode indices probed when no manifest is given.
    _TOTAL_EPISODES = 95600

    def __init__(
        self,
        data_root,
        camera="ext2",
        max_episodes=0,
        manifest_path="",
        image_size=448,
        n_window=N_WINDOW,
        frame_stride=2,
        fy=DEFAULT_FY,
    ):
        import time

        self.data_root = Path(data_root)
        self.image_size = image_size
        self.n_window = n_window
        self.frame_stride = frame_stride
        self.fy = fy
        self.camera = camera
        self.cam_key = (
            "observation.images.exterior_1_left" if camera == "ext1"
            else "observation.images.exterior_2_left"
        )
        self.ext_col = (
            "camera_extrinsics.exterior_1_left" if camera == "ext1"
            else "camera_extrinsics.exterior_2_left"
        )

        t0 = time.time()
        if manifest_path and Path(manifest_path).exists():
            episodes = self._index_from_manifest(manifest_path, max_episodes)
        else:
            episodes = self._index_by_scanning(max_episodes)

        self._pq_paths = [pq_path for pq_path, _, _ in episodes]
        self._vid_paths = [vid_path for _, vid_path, _ in episodes]
        self._frame_counts = np.array([T for _, _, T in episodes], dtype=np.int32)
        # Cumulative frame counts map a flat sample index to (episode, start frame).
        self._cumsum = np.cumsum(self._frame_counts)
        self._total_samples = int(self._cumsum[-1]) if len(self._cumsum) > 0 else 0

        elapsed = time.time() - t0
        print(f"DroidLocalDataset: {len(self._pq_paths)} episodes, {self._total_samples} samples "
              f"(init {elapsed:.1f}s)")

    def _episode_paths(self, ep):
        """Return (parquet_path, video_path) for episode index ``ep`` under data_root."""
        chunk = f"chunk-{ep // 1000:03d}"
        ep_str = f"episode_{ep:06d}"
        pq_path = self.data_root / "data" / chunk / f"{ep_str}.parquet"
        vid_path = self.data_root / "videos" / chunk / self.cam_key / f"{ep_str}.mp4"
        return pq_path, vid_path

    def _index_from_manifest(self, manifest_path, max_episodes):
        """Collect (parquet, video, num_frames) for manifest episodes present on disk."""
        import json

        with open(manifest_path, encoding="utf-8") as f:
            manifest = json.load(f)
        print(f"DroidLocalDataset: loading from manifest {manifest_path}")
        print(f"  {len(manifest['episodes'])} episodes, "
              f"min_in_frame={manifest.get('min_in_frame', '?')}")

        min_frames = self.n_window * self.frame_stride
        episodes = []
        for entry in manifest["episodes"]:
            if max_episodes > 0 and len(episodes) >= max_episodes:
                break
            T = int(entry["num_frames"])
            if T < min_frames:
                continue
            pq_path, vid_path = self._episode_paths(entry["ep_idx"])
            if not (pq_path.exists() and vid_path.exists()):
                continue
            episodes.append((str(pq_path), str(vid_path), T))
        return episodes

    def _index_by_scanning(self, max_episodes):
        """Probe episode indices in order, keeping the first ``max_episodes`` usable ones.

        ``max_episodes`` counts usable episodes (as in the manifest path), not
        episode indices, so missing, unreadable or too-short episodes do not
        reduce the number returned.
        """
        import pyarrow.parquet as pq

        print(f"DroidLocalDataset: indexing {self.data_root}/data/ (no manifest)...")

        min_frames = self.n_window * self.frame_stride
        episodes = []
        for ep in range(self._TOTAL_EPISODES):
            if max_episodes > 0 and len(episodes) >= max_episodes:
                break
            pq_path, vid_path = self._episode_paths(ep)
            if not (pq_path.exists() and vid_path.exists()):
                continue
            try:
                T = pq.read_metadata(str(pq_path)).num_rows
            except Exception:
                # Unreadable parquet: skip the episode (best-effort indexing).
                continue
            if T < min_frames:
                continue
            episodes.append((str(pq_path), str(vid_path), T))
        return episodes

    def _load_episode_data(self, ep_idx):
        """Load and process a single episode's parquet. Called lazily from __getitem__.

        Raises:
            ValueError: if the episode's parquet has no rows.
        """
        import pandas as pd

        pq_path = self._pq_paths[ep_idx]
        df = pd.read_parquet(pq_path)
        # A manifest's frame count may exceed the parquet's rows; never index past them.
        T = min(int(self._frame_counts[ep_idx]), len(df))
        if T == 0:
            raise ValueError(f"Episode parquet has no rows: {pq_path}")

        cartesian_positions = np.stack(
            df["observation.state.cartesian_position"].values[:T]).astype(np.float32)
        gripper_positions = df["observation.state.gripper_position"].values[:T].astype(np.float32)
        extrinsics = np.array(df[self.ext_col].iloc[0], dtype=np.float32)
        language = (df["language_instruction"].iloc[0]
                    if "language_instruction" in df.columns else "")
        if not isinstance(language, str):
            language = ""  # missing / NA instruction (previously became the string "None")

        # Camera matrices (all in image_size space)
        camera_pose, world_to_cam_proj, cam_K_norm = _build_camera_matrices(
            extrinsics, fy=self.fy,
            img_w=IMG_W_NATIVE, img_h=IMG_H_NATIVE,
            target_size=self.image_size)

        # Project EEF positions to pixels
        eef_pos = cartesian_positions[:, :3]
        pix_uv = np.stack([
            _project_to_pixel_targetsize(p, world_to_cam_proj, self.image_size)
            for p in eef_pos.astype(np.float64)
        ], axis=0).astype(np.float32)

        # Gripper: DROID [0=closed, 1=open] → PARA [-1=open, +1=closed] (0→+1, 1→-1).
        # NOTE(review): DROID's Franka wrapper appears to report gripper_position as
        # the closed fraction (0=open, 1=closed); if so this mapping is inverted.
        # Verify against the data before relying on gripper targets.
        gripper_para = -(gripper_positions * 2 - 1)

        # EEF quaternions (xyzw), converted for all frames in one batched call
        eef_quat = ScipyR.from_euler(
            "xyz", cartesian_positions[:, 3:6].astype(np.float64)).as_quat().astype(np.float32)

        return {
            "video_path": self._vid_paths[ep_idx],
            "eef_pos": eef_pos,
            "eef_quat": eef_quat,
            "gripper": gripper_para,
            "pix_uv": pix_uv,
            "camera_pose": camera_pose,
            "world_to_cam": world_to_cam_proj,
            "cam_K_norm": cam_K_norm,
            "num_frames": T,
            "language": language,
        }

    def _decode_frames(self, video_path, frame_indices):
        """Decode specific frames from mp4 using PyAV. Returns list of (H,W,3) uint8.

        Frame indices beyond the end of a short video are filled with the last
        decoded frame. If the video cannot be decoded at all, black frames at
        native resolution are returned instead (best-effort for corrupt files).
        """
        import av

        try:
            wanted = set(frame_indices)
            container = av.open(video_path)
            try:
                last_needed = max(wanted)
                decoded = {}
                for i, frame in enumerate(container.decode(video=0)):
                    if i in wanted:
                        decoded[i] = frame.to_ndarray(format="rgb24")
                    if i >= last_needed:
                        break
            finally:
                # Close even when decoding raises, so file handles are not leaked.
                container.close()

            if not decoded:
                raise ValueError("No frames decoded")
            fallback = decoded[max(decoded)]
            return [decoded.get(i, fallback) for i in frame_indices]
        except Exception:
            # Return black frames for corrupt videos
            return [np.zeros((IMG_H_NATIVE, IMG_W_NATIVE, 3), dtype=np.uint8)
                    for _ in frame_indices]

    def __len__(self):
        return self._total_samples

    def __getitem__(self, idx):
        # Convert flat index to (episode_idx, start_t) via cumsum binary search
        demo_idx = int(np.searchsorted(self._cumsum, idx, side="right"))
        start_t = int(idx - (self._cumsum[demo_idx - 1] if demo_idx > 0 else 0))
        ep = self._load_episode_data(demo_idx)
        T = ep["num_frames"]

        # Window frame indices; clamped so windows near the end repeat the last frame
        frame_indices = [min(start_t + k * self.frame_stride, T - 1) for k in range(self.n_window)]

        # Decode video frames on-the-fly
        raw_frames = self._decode_frames(ep["video_path"], frame_indices)
        return self._assemble_sample(ep, raw_frames, frame_indices, demo_idx, start_t)

    def _assemble_sample(self, ep, raw_frames, frame_indices, demo_idx, start_t):
        """Build the LIBERO-compatible sample dict for one window of frames."""
        import model as _model_module

        size = self.image_size

        # Images: uint8 → float [0, 1], resized to size×size
        rgb_frames = []
        for frame in raw_frames:
            rgb = frame.astype(np.float32) / 255.0
            if rgb.shape[0] != size or rgb.shape[1] != size:
                rgb = cv2.resize(rgb, (size, size), interpolation=cv2.INTER_LINEAR)
            rgb_frames.append(rgb)
        rgb_ref = rgb_frames[0]
        rgb_frames_raw = np.stack(rgb_frames, axis=0).astype(np.float32)

        t_idx = np.asarray(frame_indices)
        trajectory_2d = ep["pix_uv"][t_idx].astype(np.float32)
        trajectory_3d = ep["eef_pos"][t_idx].astype(np.float32)
        trajectory_quat = ep["eef_quat"][t_idx].astype(np.float32)
        trajectory_gripper = np.clip(
            ep["gripper"][t_idx], MIN_GRIPPER, MAX_GRIPPER).astype(np.float32)

        rotations = ScipyR.from_quat(trajectory_quat)
        trajectory_euler = rotations.as_euler("xyz").astype(np.float32)

        # Delta axis-angle from reference rotation
        ref_rot = ScipyR.from_quat(np.array(_model_module.REF_ROTATION_QUAT, dtype=np.float64))
        trajectory_delta_rotvec = (ref_rot.inv() * rotations).as_rotvec().astype(np.float32)

        # One-hot heatmaps at the rounded (round-half-to-even) pixel of each timestep
        n_steps = len(frame_indices)
        cols = np.clip(np.rint(trajectory_2d[:, 0]), 0, size - 1).astype(np.int64)
        rows = np.clip(np.rint(trajectory_2d[:, 1]), 0, size - 1).astype(np.int64)
        heatmap_targets = np.zeros((n_steps, size, size), dtype=np.float32)
        heatmap_targets[np.arange(n_steps), rows, cols] = 1.0

        # ImageNet-normalised reference image (first frame of the window)
        rgb_t = torch.from_numpy(rgb_ref).permute(2, 0, 1).float()
        mean = torch.tensor(IMAGENET_MEAN, dtype=torch.float32).view(3, 1, 1)
        std = torch.tensor(IMAGENET_STD, dtype=torch.float32).view(3, 1, 1)
        rgb_t = (rgb_t - mean) / std

        return {
            "rgb": rgb_t,
            "heatmap_target": torch.from_numpy(heatmap_targets).float(),
            "trajectory_2d": torch.from_numpy(trajectory_2d).float(),
            "trajectory_3d": torch.from_numpy(trajectory_3d).float(),
            "trajectory_gripper": torch.from_numpy(trajectory_gripper).float(),
            "trajectory_quat": torch.from_numpy(trajectory_quat).float(),
            "trajectory_euler": torch.from_numpy(trajectory_euler).float(),
            "trajectory_delta_rotvec": torch.from_numpy(trajectory_delta_rotvec).float(),
            "rgb_frames_raw": torch.from_numpy(rgb_frames_raw).float(),
            "world_to_camera": torch.from_numpy(ep["world_to_cam"]).float(),
            "base_z": torch.tensor(BASE_Z, dtype=torch.float32),
            "target_3d": torch.from_numpy(trajectory_3d[-1]).float(),
            "camera_pose": torch.from_numpy(ep["camera_pose"]).float(),
            "cam_K_norm": torch.from_numpy(ep["cam_K_norm"]).float(),
            "demo_idx": torch.tensor(demo_idx, dtype=torch.long),
            "start_t": torch.tensor(start_t, dtype=torch.long),
            "clip_embedding": torch.zeros(512),
            "task_description": ep.get("language", ""),
            # No wrist camera
            "has_wrist": torch.tensor(False, dtype=torch.bool),
            "wrist_rgb": torch.zeros(3, self.image_size, self.image_size),
            "wrist_trajectory_2d": torch.zeros(self.n_window, 2),
            "wrist_camera_pose": torch.eye(4, dtype=torch.float32),
            "wrist_cam_K_norm": torch.eye(3, dtype=torch.float32),
            "wrist_world_to_camera": torch.eye(4, dtype=torch.float32),
            "wrist_in_view": torch.zeros(self.n_window),
        }


class DroidStreamingDataset(Dataset):
    """Downloads DROID episodes from HuggingFace and serves samples from memory.

    Every requested episode's parquet and video are downloaded and decoded
    eagerly in __init__. The decoded video frames are kept in RAM, so this is
    only suitable for prototyping with small episode counts. Episodes that fail
    to load are skipped with a warning.

    Every frame of every loaded episode is a sample start. Windows that run
    past the end of an episode repeat the last frame.

    Args:
        episode_indices: list of int episode indices to load (default: 0-9)
        camera: "ext1" or "ext2" (exterior cameras only, no wrist)
        image_size: target image size (default 448)
        n_window: number of future timesteps per sample
        frame_stride: stride between frames (default 2 for 15Hz → ~7.5Hz)
        fy: estimated focal length for camera intrinsics
    """

    def __init__(
        self,
        episode_indices=None,
        camera="ext2",
        image_size=448,
        n_window=N_WINDOW,
        frame_stride=2,
        fy=DEFAULT_FY,
    ):
        if episode_indices is None:
            episode_indices = list(range(10))  # default: first 10 episodes

        self.image_size = image_size
        self.n_window = n_window
        self.frame_stride = frame_stride
        self.fy = fy
        self.camera = camera

        # Load all episodes into memory
        self.episodes = []
        self.samples = []  # (demo_idx, start_t) pairs

        print(f"DroidStreamingDataset: loading {len(episode_indices)} episodes from HuggingFace...")
        for ep_idx in episode_indices:
            try:
                ep_data = self._load_episode(ep_idx)
                if ep_data is None:
                    continue
                demo_idx = len(self.episodes)
                self.episodes.append(ep_data)
                self.samples.extend((demo_idx, t) for t in range(ep_data["num_frames"]))
            except Exception as e:
                # Best-effort: a single bad episode should not abort the whole dataset.
                print(f"  Warning: failed to load episode {ep_idx}: {e}")
                continue

        print(f"DroidStreamingDataset: {len(self.episodes)} episodes, {len(self.samples)} samples")

    def _load_episode(self, ep_idx):
        """Download and parse one DROID episode.

        Raises:
            ValueError: if the episode video yields no frames.
        """
        import av
        import pandas as pd
        from huggingface_hub import hf_hub_download

        REPO_ID = "cadene/droid_1.0.1"

        chunk_str = f"chunk-{ep_idx // 1000:03d}"
        ep_str = f"episode_{ep_idx:06d}"
        cam_key = ("observation.images.exterior_1_left" if self.camera == "ext1"
                   else "observation.images.exterior_2_left")
        ext_key = ("camera_extrinsics.exterior_1_left" if self.camera == "ext1"
                   else "camera_extrinsics.exterior_2_left")

        # Download parquet, then the video for the selected camera
        parquet_path = hf_hub_download(
            REPO_ID, f"data/{chunk_str}/{ep_str}.parquet", repo_type="dataset")
        video_path = hf_hub_download(
            REPO_ID, f"videos/{chunk_str}/{cam_key}/{ep_str}.mp4", repo_type="dataset")

        df = pd.read_parquet(parquet_path)
        T = len(df)

        # Decode the whole video; close the container even if decoding fails
        container = av.open(video_path)
        try:
            frames = [frame.to_ndarray(format="rgb24") for frame in container.decode(video=0)]
        finally:
            container.close()
        if not frames:
            raise ValueError(f"no frames decoded from {video_path}")
        images = np.stack(frames)

        if images.shape[0] != T:
            print(f"  Warning: ep {ep_idx} video frames {images.shape[0]} != parquet rows {T}")
            T = min(images.shape[0], T)
            images = images[:T]

        cartesian_positions = np.stack(df["observation.state.cartesian_position"].values[:T])
        gripper_positions = df["observation.state.gripper_position"].values[:T].astype(np.float32)
        extrinsics = np.array(df[ext_key].iloc[0], dtype=np.float32)

        language = df["language_instruction"].iloc[0] if "language_instruction" in df.columns else ""
        # Normalise before use: a NaN/None instruction used to crash the log line below.
        if not isinstance(language, str):
            language = ""

        # Build camera matrices (all in target image_size space)
        camera_pose, world_to_cam_proj, cam_K_norm = _build_camera_matrices(
            extrinsics, fy=self.fy,
            img_w=IMG_W_NATIVE, img_h=IMG_H_NATIVE,
            target_size=self.image_size)

        # Project EEF positions to pixels in target image_size space
        pix_uv = np.stack([
            _project_to_pixel_targetsize(p, world_to_cam_proj, self.image_size)
            for p in cartesian_positions[:T, :3].astype(np.float64)
        ], axis=0).astype(np.float32)

        # Convert DROID gripper [0=closed, 1=open] → PARA [-1=open, +1=closed] (0→+1, 1→-1).
        # NOTE(review): DROID's Franka wrapper appears to report gripper_position as
        # the closed fraction (0=open, 1=closed); if so this mapping is inverted.
        # Verify against the data before relying on gripper targets.
        gripper_para = -(gripper_positions * 2 - 1)

        # EEF quaternions (T, 4) xyzw from "xyz" Euler angles, batched
        eef_euler = cartesian_positions[:T, 3:6].astype(np.float64)
        eef_quat = ScipyR.from_euler("xyz", eef_euler).as_quat().astype(np.float32)

        print(f"  Loaded ep {ep_idx}: {T} frames, task='{language[:50]}'")

        return {
            "images": images,
            "eef_pos": cartesian_positions[:T, :3].astype(np.float32),
            "eef_euler": eef_euler.astype(np.float32),
            "eef_quat": eef_quat,
            "gripper": gripper_para,
            "pix_uv": pix_uv,
            "camera_pose": camera_pose,
            "world_to_cam": world_to_cam_proj,
            "cam_K_norm": cam_K_norm,
            "num_frames": T,
            "language": language,
        }

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        demo_idx, start_t = self.samples[idx]
        ep = self.episodes[demo_idx]
        T = ep["num_frames"]

        # Window frame indices; clamped so windows near the end repeat the last frame
        frame_indices = [min(start_t + k * self.frame_stride, T - 1) for k in range(self.n_window)]
        raw_frames = [ep["images"][t] for t in frame_indices]
        return self._assemble_sample(ep, raw_frames, frame_indices, demo_idx, start_t)

    def _assemble_sample(self, ep, raw_frames, frame_indices, demo_idx, start_t):
        """Build the LIBERO-compatible sample dict for one window of frames."""
        import model as _model_module

        size = self.image_size

        # Images: uint8 → float [0, 1], resized to size×size
        rgb_frames = []
        for frame in raw_frames:
            rgb = frame.astype(np.float32) / 255.0
            if rgb.shape[0] != size or rgb.shape[1] != size:
                rgb = cv2.resize(rgb, (size, size), interpolation=cv2.INTER_LINEAR)
            rgb_frames.append(rgb)
        rgb_ref = rgb_frames[0]
        rgb_frames_raw = np.stack(rgb_frames, axis=0).astype(np.float32)

        t_idx = np.asarray(frame_indices)
        trajectory_2d = ep["pix_uv"][t_idx].astype(np.float32)
        trajectory_3d = ep["eef_pos"][t_idx].astype(np.float32)
        trajectory_quat = ep["eef_quat"][t_idx].astype(np.float32)
        trajectory_gripper = np.clip(
            ep["gripper"][t_idx], MIN_GRIPPER, MAX_GRIPPER).astype(np.float32)

        rotations = ScipyR.from_quat(trajectory_quat)
        trajectory_euler = rotations.as_euler("xyz").astype(np.float32)

        # Delta axis-angle from reference rotation
        ref_rot = ScipyR.from_quat(np.array(_model_module.REF_ROTATION_QUAT, dtype=np.float64))
        trajectory_delta_rotvec = (ref_rot.inv() * rotations).as_rotvec().astype(np.float32)

        # One-hot heatmaps at the rounded (round-half-to-even) pixel of each timestep
        n_steps = len(frame_indices)
        cols = np.clip(np.rint(trajectory_2d[:, 0]), 0, size - 1).astype(np.int64)
        rows = np.clip(np.rint(trajectory_2d[:, 1]), 0, size - 1).astype(np.int64)
        heatmap_targets = np.zeros((n_steps, size, size), dtype=np.float32)
        heatmap_targets[np.arange(n_steps), rows, cols] = 1.0

        # ImageNet-normalised reference image (first frame of the window)
        rgb_t = torch.from_numpy(rgb_ref).permute(2, 0, 1).float()
        mean = torch.tensor(IMAGENET_MEAN, dtype=torch.float32).view(3, 1, 1)
        std = torch.tensor(IMAGENET_STD, dtype=torch.float32).view(3, 1, 1)
        rgb_t = (rgb_t - mean) / std

        return {
            "rgb": rgb_t,
            "heatmap_target": torch.from_numpy(heatmap_targets).float(),
            "trajectory_2d": torch.from_numpy(trajectory_2d).float(),
            "trajectory_3d": torch.from_numpy(trajectory_3d).float(),
            "trajectory_gripper": torch.from_numpy(trajectory_gripper).float(),
            "trajectory_quat": torch.from_numpy(trajectory_quat).float(),
            "trajectory_euler": torch.from_numpy(trajectory_euler).float(),
            "trajectory_delta_rotvec": torch.from_numpy(trajectory_delta_rotvec).float(),
            "rgb_frames_raw": torch.from_numpy(rgb_frames_raw).float(),
            "world_to_camera": torch.from_numpy(ep["world_to_cam"]).float(),
            "base_z": torch.tensor(BASE_Z, dtype=torch.float32),
            "target_3d": torch.from_numpy(trajectory_3d[-1]).float(),
            "camera_pose": torch.from_numpy(ep["camera_pose"]).float(),
            "cam_K_norm": torch.from_numpy(ep["cam_K_norm"]).float(),
            "demo_idx": torch.tensor(demo_idx, dtype=torch.long),
            "start_t": torch.tensor(start_t, dtype=torch.long),
            "clip_embedding": torch.zeros(512),
            "task_description": ep.get("language", ""),
            # No wrist camera
            "has_wrist": torch.tensor(False, dtype=torch.bool),
            "wrist_rgb": torch.zeros(3, self.image_size, self.image_size),
            "wrist_trajectory_2d": torch.zeros(self.n_window, 2),
            "wrist_camera_pose": torch.eye(4, dtype=torch.float32),
            "wrist_cam_K_norm": torch.eye(3, dtype=torch.float32),
            "wrist_world_to_camera": torch.eye(4, dtype=torch.float32),
            "wrist_in_view": torch.zeros(self.n_window),
        }
