"""Dataset: list MP4s under droid_raw, sample 8 frames at ~4fps, 256x256, return (B, 3, 8, 256, 256) in [-1,1]."""

import random
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset

try:
    import decord
    decord.bridge.set_bridge("torch")
except Exception:
    decord = None


def _sample_frames(total_frames, video_fps, num_frames, sample_fps):
    interval = max(1, round(video_fps / sample_fps))
    need = (num_frames - 1) * interval + 1
    if total_frames < need:
        indices = np.linspace(0, total_frames - 1, num=num_frames).astype(int)
    else:
        start = random.randint(0, total_frames - need)
        indices = np.linspace(start, start + need - 1, num=num_frames).astype(int)
    return indices


def find_mp4s(root: str):
    """Recursively collect every ``*.mp4`` under `root`.

    The result is sorted so dataset indexing is deterministic across runs
    and filesystems (``rglob`` order is otherwise OS-dependent).

    Args:
        root: directory to search (str or Path).

    Returns:
        Sorted list of `Path` objects.
    """
    return sorted(Path(root).rglob("*.mp4"))


class DroidVideoDataset(Dataset):
    """DROID MP4 clips: `num_frames` frames at ~`sample_fps`, `size`x`size`,
    normalized to [-1, 1].

    Each item is a tensor of shape (1, 3, num_frames, size, size); use
    `collate_batch` to concatenate items into (B, 3, T, H, W).
    Expects videos under e.g. droid_raw/1.0.1/*/*/recordings/MP4/*.mp4.
    """

    def __init__(self, root: str, num_frames: int = 8, sample_fps: float = 4.0, size: int = 256):
        """Index all MP4 files under `root`.

        Args:
            root: directory scanned recursively for .mp4 files.
            num_frames: frames per clip.
            sample_fps: target effective sampling rate.
            size: output spatial resolution (square).

        Raises:
            FileNotFoundError: if no .mp4 file exists under `root`.
        """
        self.root = Path(root)
        self.num_frames = num_frames
        self.sample_fps = sample_fps
        self.size = size
        self.videos = find_mp4s(self.root)
        if not self.videos:
            raise FileNotFoundError(f"No .mp4 under {self.root}")

    def __len__(self):
        """Number of indexed videos (one clip is drawn per video)."""
        return len(self.videos)

    def __getitem__(self, idx):
        """Decode one clip.

        Returns:
            Tensor of shape (1, 3, num_frames, size, size) in [-1, 1].

        Raises:
            RuntimeError: if decord is not installed.
        """
        if decord is None:
            raise RuntimeError("decord not installed")
        path = self.videos[idx]
        vr = decord.VideoReader(str(path), num_threads=0)  # 0 = auto threads
        total = len(vr)
        fps = vr.get_avg_fps()
        if not fps or fps <= 0:
            # Some containers report a zero/unknown avg fps, which would
            # divide by zero in _sample_frames; degrade to a stride of 1.
            fps = self.sample_fps
        indices = _sample_frames(total, fps, self.num_frames, self.sample_fps)
        frames = vr.get_batch(indices)  # (T, H, W, 3) uint8 via torch bridge
        frames = frames.float() / 255.0
        frames = frames.permute(0, 3, 1, 2)  # (T, 3, H, W) for interpolate
        frames = torch.nn.functional.interpolate(
            frames, size=(self.size, self.size), mode="bilinear", align_corners=False
        )
        frames = frames.permute(1, 0, 2, 3)  # (3, T, H, W)
        frames = frames * 2.0 - 1.0  # [0, 1] -> [-1, 1]
        return frames.unsqueeze(0)  # (1, 3, T, H, W)


def collate_batch(batch):
    """Merge per-sample (1, C, T, H, W) clips into one (B, C, T, H, W) batch.

    Samples already carry a leading batch axis of 1 (added by
    ``DroidVideoDataset.__getitem__``), so concatenation along dim 0 suffices.
    """
    samples = list(batch)
    return torch.cat(samples, dim=0)
