"""DROID video clips for simple_uva: 4-frame, 256x256, [-1,1]. Cache or raw MP4s."""

import json
import random
import re
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

try:
    import decord
    decord.bridge.set_bridge("torch")
except Exception:
    decord = None

# Default number of frames per returned clip; used as the ``num_frames`` default of the datasets below.
NUM_FRAMES = 4


def _sample_frames(total_frames, video_fps, num_frames, sample_fps):
    interval = max(1, round(video_fps / sample_fps))
    need = (num_frames - 1) * interval + 1
    if total_frames < need:
        indices = np.linspace(0, total_frames - 1, num=num_frames).astype(int)
    else:
        start = random.randint(0, total_frames - need)
        indices = np.linspace(start, start + need - 1, num=num_frames).astype(int)
    return indices


def find_mp4s(root: str):
    """Return every ``*.mp4`` under ``root`` (recursively) as a sorted list of paths.

    Sorting makes the dataset index -> file mapping deterministic; ``rglob`` order depends on
    the filesystem. This matches ``find_cached_clips``, which also sorts.
    """
    return sorted(Path(root).rglob("*.mp4"))


def find_cached_clips(cache_dir: str):
    """List the ``*.pt`` files directly inside ``cache_dir`` (non-recursive), sorted by path."""
    matches = Path(cache_dir).glob("*.pt")
    return sorted(matches)


class CachedClipDataset(Dataset):
    """Pre-extracted clips stored as ``*.pt`` files in ``cache_dir``.

    Each item is the first ``num_frames`` frames of one cached clip, with a leading batch
    dimension of 1. If a clip fails to load, the following clips (wrapping around) are tried
    instead, for ``max_load_retries`` attempts in total (at least one).
    """

    def __init__(self, cache_dir: str, num_frames: int = NUM_FRAMES, max_load_retries: int = 5):
        self.clips = find_cached_clips(cache_dir)
        self.num_frames = num_frames
        # Total load attempts per item; values < 1 are treated as a single attempt.
        self.max_load_retries = max_load_retries
        if not self.clips:
            raise FileNotFoundError(f"No .pt clips in {cache_dir}. Run precache first.")

    def __len__(self):
        return len(self.clips)

    def _load_one(self, path):
        """Load one cached clip on CPU and return its first ``num_frames`` frames with a batch dim.

        Frames are taken along dim 1 of the stored tensor (dim 0 presumably channels -- TODO
        confirm against the precache script). A 5-D tensor is treated as batched and only its
        first entry is used.
        """
        try:
            out = torch.load(path, map_location="cpu", weights_only=True)
        except (TypeError, EOFError):
            # Retry with the unrestricted unpickler. That is presumably acceptable because the
            # cache comes from the local precache step -- do not point this at untrusted files.
            try:
                out = torch.load(path, map_location="cpu", weights_only=False)
            except TypeError:
                # torch < 1.13 rejects the ``weights_only`` keyword itself; its loader always
                # unpickles fully, which is what this branch asks for anyway.
                out = torch.load(path, map_location="cpu")
        if out.dim() == 5:
            out = out[0]
        out = out[:, : self.num_frames]
        return out.unsqueeze(0)

    def __getitem__(self, idx):
        attempts = max(1, self.max_load_retries)  # guarantees `i` is bound for the error message
        last_err = None
        i = idx
        for attempt in range(attempts):
            # First try the requested clip, then fall through to its successors.
            i = idx if attempt == 0 else (idx + attempt) % len(self.clips)
            try:
                return self._load_one(self.clips[i])
            except (EOFError, OSError, RuntimeError) as e:
                last_err = e
        raise RuntimeError(
            f"Failed to load any of {attempts} clips (last idx={i}): {last_err}"
        ) from last_err


class DroidVideoDataset(Dataset):
    """Clips decoded on the fly from ``.mp4`` files found recursively under ``root``.

    Each video is one item. ``_sample_frames`` chooses ``num_frames`` indices at roughly
    ``sample_fps``; the frames are resized to ``size`` x ``size``, scaled to [-1, 1] and
    returned channel-first as ``(1, C, T, size, size)`` (assuming decord yields ``(T, H, W, C)``
    batches under the torch bridge).
    """

    def __init__(
        self,
        root: str,
        num_frames: int = NUM_FRAMES,
        sample_fps: float = 4.0,
        size: int = 256,
    ):
        self.root = Path(root)
        self.size = size
        self.sample_fps = sample_fps
        self.num_frames = num_frames
        videos = find_mp4s(self.root)
        if not videos:
            raise FileNotFoundError(f"No .mp4 under {self.root}")
        self.videos = videos

    def __len__(self):
        return len(self.videos)

    def __getitem__(self, idx):
        video_path = self.videos[idx]
        if decord is None:
            raise RuntimeError("decord not installed")
        reader = decord.VideoReader(str(video_path), num_threads=2)
        picks = _sample_frames(len(reader), reader.get_avg_fps(), self.num_frames, self.sample_fps)
        clip = reader.get_batch(picks).float() / 255.0  # values in [0, 1]
        # interpolate wants (N, C, H, W): treat frames as the batch axis.
        clip = F.interpolate(
            clip.permute(0, 3, 1, 2),
            size=(self.size, self.size),
            mode="bilinear",
            align_corners=False,
        )
        # (T, C, H, W) -> (C, T, H, W), then map [0, 1] -> [-1, 1].
        clip = clip.permute(1, 0, 2, 3) * 2.0 - 1.0
        return clip.unsqueeze(0)


def collate_batch(batch):
    """Concatenate the per-sample tensors along dim 0 (each presumably shaped ``(1, ...)``)."""
    samples = list(batch)
    return torch.cat(samples, 0)


def _natural_sort_key(p: Path):
    """Sort by stem numerically when possible."""
    s = p.stem
    try:
        return (0, int(s))
    except ValueError:
        return (1, s)


class SelfCollectedDataset(Dataset):
    """Clips sampled from PNG episodes laid out as ``root/task/episode/00000.png, 00001.png, ...``.

    Each subdirectory of ``root`` is a task; each subdirectory of a task is an episode whose
    PNGs are ordered by ``_natural_sort_key``. An item picks a random start frame, takes a
    window of ``sequence_length`` consecutive frames (default 8), and returns its first
    ``num_frames`` frames (default 4, for the 4-frame model). The result is a
    ``(1, 3, num_frames, size, size)`` tensor in [-1, 1].

    If ``sequence_length < num_frames``, the window is widened to ``num_frames``. Every item
    then has exactly ``num_frames`` frames and items can be concatenated into a batch.
    """

    def __init__(self, root: str, num_frames: int = NUM_FRAMES, size: int = 256, sequence_length: int = 8):
        self.root = Path(root).resolve()
        self.num_frames = num_frames
        self.size = size
        # Window length a random start is drawn for; only its first num_frames frames are returned.
        self.sequence_length = sequence_length
        min_frames = self._window_length()
        self.episodes = []
        for task_dir in sorted(self.root.iterdir()):
            if not task_dir.is_dir():
                continue
            for episode_dir in sorted(task_dir.iterdir()):
                if not episode_dir.is_dir():
                    continue
                pngs = sorted(episode_dir.glob("*.png"), key=_natural_sort_key)
                if len(pngs) >= min_frames:
                    self.episodes.append((episode_dir, pngs))
        if not self.episodes:
            raise FileNotFoundError(
                f"No episodes with >={min_frames} PNGs under {self.root}. "
                "Expected structure: root/task_name/episode_name/00000.png, 00001.png, ..."
            )

    def _window_length(self):
        """Frames an episode must have: the sampling window, widened to hold ``num_frames``."""
        return max(self.sequence_length, self.num_frames)

    @staticmethod
    def _to_model_frame(img, size):
        """Turn a decoded ``(C, H, W)`` image into float RGB ``(3, size, size)`` in [0, 1].

        Assumes 8-bit samples (divides by 255). Grayscale is replicated to 3 channels and an
        alpha channel is dropped.
        """
        channels = img.shape[0]
        if channels == 1:
            img = img.repeat(3, 1, 1)
        elif channels == 2:  # grayscale + alpha
            img = img[:1].repeat(3, 1, 1)
        elif channels == 4:  # RGBA
            img = img[:3]
        img = img.float() / 255.0
        return F.interpolate(
            img.unsqueeze(0),
            size=(size, size),
            mode="bilinear",
            align_corners=False,
        ).squeeze(0)

    def __len__(self):
        return len(self.episodes)

    def __getitem__(self, idx):
        # Lazy import: the module itself does not require torchvision.
        from torchvision.io import read_image

        _, png_paths = self.episodes[idx]
        start = random.randint(0, len(png_paths) - self._window_length())
        frames = [
            self._to_model_frame(read_image(str(p)), self.size)
            for p in png_paths[start : start + self.num_frames]
        ]
        out = torch.stack(frames, dim=1)  # (3, num_frames, size, size)
        out = out * 2.0 - 1.0
        return out.unsqueeze(0)


class KeygripVideoDataset(Dataset):
    """PNG-episode clips (same layout and sampling as ``SelfCollectedDataset``) plus optional trajectories.

    Episodes live under ``root/task/episode/00000.png, 00001.png, ...``. Each item is a dict
    containing:
    - the real sampled frames, with no frame repetition;
    - for keygrip/PARA supervision, the 2-D and 3-D waypoints for those same frames.

    Waypoints come from ``trajectory_map`` or ``trajectory_root/task/episode/trajectory.json``.
    When neither has an entry, zeros are returned and ``has_trajectory`` is False.
    """

    def __init__(
        self,
        root: str,
        num_frames: int = NUM_FRAMES,
        size: int = 256,
        sequence_length: int = 8,
        trajectory_root: str | None = None,
        trajectory_map: dict | None = None,
    ):
        self.root = Path(root).resolve()
        # Mirror of root/task/episode/ holding trajectory.json files; defaults to root itself.
        self.trajectory_root = Path(trajectory_root).resolve() if trajectory_root else self.root
        # Optional lookup {episode_dir (Path or str): (t2d [N, 2], t3d [N, 3])}, consulted before
        # trajectory.json. Presumably one row per PNG frame: rows are sliced with the same start
        # index as the frames.
        self.trajectory_map = trajectory_map
        self.num_frames = num_frames
        self.size = size
        # Window length a random start is drawn for; only its first num_frames frames are returned.
        self.sequence_length = sequence_length
        min_frames = self._window_length()
        self.episodes = []
        for task_dir in sorted(self.root.iterdir()):
            if not task_dir.is_dir():
                continue
            for episode_dir in sorted(task_dir.iterdir()):
                if not episode_dir.is_dir():
                    continue
                pngs = sorted(episode_dir.glob("*.png"), key=_natural_sort_key)
                if len(pngs) >= min_frames:
                    self.episodes.append((episode_dir, pngs))
        if not self.episodes:
            raise FileNotFoundError(
                f"No episodes with >={min_frames} PNGs under {self.root}. "
                "Expected structure: root/task_name/episode_name/00000.png, 00001.png, ..."
            )

    def __len__(self):
        return len(self.episodes)

    def _window_length(self):
        """Frames an episode must have: the sampling window, widened to hold ``num_frames``."""
        return max(self.sequence_length, self.num_frames)

    @staticmethod
    def _to_model_frame(img, size):
        """Turn a decoded ``(C, H, W)`` image into float RGB ``(3, size, size)`` in [0, 1].

        Assumes 8-bit samples (divides by 255). Grayscale is replicated to 3 channels and an
        alpha channel is dropped.
        """
        channels = img.shape[0]
        if channels == 1:
            img = img.repeat(3, 1, 1)
        elif channels == 2:  # grayscale + alpha
            img = img[:1].repeat(3, 1, 1)
        elif channels == 4:  # RGBA
            img = img[:3]
        img = img.float() / 255.0
        return F.interpolate(
            img.unsqueeze(0),
            size=(size, size),
            mode="bilinear",
            align_corners=False,
        ).squeeze(0)

    @staticmethod
    def _aligned_rows(values, start, length, dim):
        """Return rows ``[start, start + length)`` of ``values`` as a new float32 tensor.

        Rows past the end of ``values`` are zero-filled. ``dim`` only gives an empty input
        (e.g. ``[]`` from JSON) the shape ``(0, dim)``.
        """
        t = torch.as_tensor(values, dtype=torch.float32)
        if t.numel() == 0:
            t = t.reshape(0, dim)
        rows = t[start : start + length]
        # Fresh tensor: never aliases storage held in trajectory_map.
        out = rows.new_zeros((length, *rows.shape[1:]))
        out[: rows.shape[0]] = rows
        return out

    def _lookup_trajectory_map(self, episode_dir: Path):
        """Return the ``(t2d, t3d)`` entry for ``episode_dir`` in ``trajectory_map``, or None."""
        if self.trajectory_map is None:
            return None
        for candidate in (episode_dir, str(episode_dir)):
            if candidate in self.trajectory_map:
                return self.trajectory_map[candidate]
        try:
            resolved = episode_dir.resolve()
        except (OSError, RuntimeError):  # e.g. symlink loops: treat as "not in the map"
            return None
        if resolved in self.trajectory_map:
            return self.trajectory_map[resolved]
        return None

    def _read_trajectory_json(self, episode_dir: Path):
        """Return ``(trajectory_2d, trajectory_3d)`` from the episode's trajectory.json, or None if absent."""
        rel = episode_dir.relative_to(self.root)
        traj_path = self.trajectory_root / rel / "trajectory.json"
        if not traj_path.exists():
            return None
        with open(traj_path, encoding="utf-8") as f:
            data = json.load(f)
        return data["trajectory_2d"], data["trajectory_3d"]

    def _load_trajectory(self, episode_dir: Path, start: int):
        """Return ``(trajectory_2d, trajectory_3d, has_trajectory)`` for frames ``start .. start + num_frames - 1``.

        Lookup order:
        1. ``trajectory_map``, if it has an entry for ``episode_dir``;
        2. ``trajectory_root/<task>/<episode>/trajectory.json``, keys ``"trajectory_2d"`` /
           ``"trajectory_3d"``.

        Waypoints are expected in pixel coordinates at ``self.size`` resolution; they are not
        rescaled here. Rows ``[start, start + num_frames)`` are returned as float32 tensors,
        ``(num_frames, 2)`` and ``(num_frames, 3)`` for well-formed inputs, with ``True``.
        Rows past the end of a short trajectory are zero-filled rather than replaced by the
        trajectory's first rows, so the waypoints stay aligned with the sampled frames.
        Without any trajectory, zero tensors and ``False`` are returned.
        """
        episode_dir = Path(episode_dir)
        entry = self._lookup_trajectory_map(episode_dir)
        if entry is None:
            entry = self._read_trajectory_json(episode_dir)
        if entry is None:
            return (
                torch.zeros(self.num_frames, 2, dtype=torch.float32),
                torch.zeros(self.num_frames, 3, dtype=torch.float32),
                False,
            )
        t2d, t3d = entry
        return (
            self._aligned_rows(t2d, start, self.num_frames, 2),
            self._aligned_rows(t3d, start, self.num_frames, 3),
            True,
        )

    def __getitem__(self, idx):
        # Lazy import: the module itself does not require torchvision.
        from torchvision.io import read_image

        episode_dir, pngs = self.episodes[idx]
        start = random.randint(0, len(pngs) - self._window_length())
        frames = [
            self._to_model_frame(read_image(str(p)), self.size)
            for p in pngs[start : start + self.num_frames]
        ]
        out = torch.stack(frames, dim=1)  # (3, num_frames, size, size)
        out = out * 2.0 - 1.0
        video = out.unsqueeze(0)
        trajectory_2d, trajectory_3d, has_trajectory = self._load_trajectory(episode_dir, start)
        return {
            "video": video,
            "trajectory_2d": trajectory_2d,
            "trajectory_3d": trajectory_3d,
            "has_trajectory": has_trajectory,
        }
