"""Same as our_vid_model: MP4s under root, 8 frames @ ~4fps, 256x256, (B, 3, 8, 256, 256) in [-1,1]."""

import random
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset

try:
    import decord
    decord.bridge.set_bridge("torch")
except Exception:
    decord = None


def _sample_frames(total_frames, video_fps, num_frames, sample_fps):
    interval = max(1, round(video_fps / sample_fps))
    need = (num_frames - 1) * interval + 1
    if total_frames < need:
        indices = np.linspace(0, total_frames - 1, num=num_frames).astype(int)
    else:
        start = random.randint(0, total_frames - need)
        indices = np.linspace(start, start + need - 1, num=num_frames).astype(int)
    return indices


def find_mp4s(root: str):
    """Recursively collect every .mp4 file under *root*.

    Returns a sorted list of Paths. Sorting matters: ``rglob`` yields files
    in filesystem-dependent order, so without it the dataset's index→video
    mapping (and thus any resumed/sharded training run) is nondeterministic.
    This also matches the sorted output of ``find_cached_clips``.
    """
    return sorted(Path(root).rglob("*.mp4"))


def find_cached_clips(cache_dir: str):
    """Return the .pt clip files sitting directly in *cache_dir*, sorted."""
    return sorted(Path(cache_dir).glob("*.pt"))


class CachedClipDataset(Dataset):
    """Serve clips that precache_clips.py already wrote to disk as .pt files.

    Items come straight from ``torch.load``, so no video decoding happens
    at train time. Raises FileNotFoundError at construction if *cache_dir*
    holds no .pt files.
    """

    def __init__(self, cache_dir: str):
        # Sorted glob keeps the index -> clip mapping deterministic.
        self.clips = sorted(Path(cache_dir).glob("*.pt"))
        if not self.clips:
            raise FileNotFoundError(f"No .pt clips in {cache_dir}. Run precache_clips.py first.")

    def __len__(self):
        return len(self.clips)

    def __getitem__(self, idx):
        target = self.clips[idx]
        try:
            # Prefer weights_only=True (refuses arbitrary pickled objects).
            return torch.load(target, map_location="cpu", weights_only=True)
        except TypeError:
            # Older torch versions do not accept the weights_only kwarg.
            return torch.load(target, map_location="cpu")


class DroidVideoDataset(Dataset):
    """Decode short clips on the fly from MP4s found anywhere under *root*.

    Each item is a float tensor of shape (1, C, num_frames, size, size)
    scaled to [-1, 1] (C is whatever the decoder yields per frame;
    presumably 3 RGB channels — see the module docstring). Requires the
    optional ``decord`` package at item-fetch time.
    """

    def __init__(self, root: str, num_frames: int = 8, sample_fps: float = 4.0, size: int = 256):
        self.root = Path(root)
        self.num_frames = num_frames
        self.sample_fps = sample_fps
        self.size = size
        # Recursive scan so nested episode folders are picked up too.
        self.videos = list(self.root.rglob("*.mp4"))
        if not self.videos:
            raise FileNotFoundError(f"No .mp4 under {self.root}")

    def __len__(self):
        return len(self.videos)

    def __getitem__(self, idx):
        target = self.videos[idx]
        if decord is None:
            raise RuntimeError("decord not installed")
        reader = decord.VideoReader(str(target), num_threads=2)
        picks = _sample_frames(len(reader), reader.get_avg_fps(), self.num_frames, self.sample_fps)
        # decord (torch bridge) yields (T, H, W, C) uint8; go to float in [0, 1].
        clip = reader.get_batch(picks).float() / 255.0
        # (T, C, H, W) is what interpolate expects for per-frame resizing.
        clip = clip.permute(0, 3, 1, 2)
        clip = torch.nn.functional.interpolate(
            clip, size=(self.size, self.size), mode="bilinear", align_corners=False
        )
        # Channels-first over time: (C, T, H, W), rescaled to [-1, 1],
        # plus a leading singleton batch dim for collate_batch's cat.
        clip = clip.permute(1, 0, 2, 3) * 2.0 - 1.0
        return clip.unsqueeze(0)


def collate_batch(batch):
    """Join per-clip tensors (each with a leading size-1 batch dim) into one batch."""
    joined = torch.cat(batch, 0)
    return joined
