"""
Dataset for image-to-video fine-tuning: random N-frame windows from robot episodes.
Each sample is a contiguous clip; first frame is conditioning, full clip is target.
Output format matches DiffusionEngine: "jpg" = video in [-1, 1], shaped (T, C, H, W) per
sample and (B, T, C, H, W) after collation.
"""
import random
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset

# SVD default resolution
VIDEO_H = 576
VIDEO_W = 1024
NUM_FRAMES = 14  # SVD uses 14; we pad 12-frame clips to 14


class RobotVideoDataset(Dataset):
    """
    Samples random contiguous N-frame windows from episode directories.
    Each episode dir contains 000000.png, 000001.png, ... (and optional .npy files).
    """

    def __init__(
        self,
        dataset_root: str,
        num_frames: int = 12,
        target_num_frames: int = NUM_FRAMES,
        height: int = VIDEO_H,
        width: int = VIDEO_W,
        max_episodes: int | None = None,
        seed: int | None = None,
    ):
        self.dataset_root = Path(dataset_root)
        self.num_frames = num_frames
        self.target_num_frames = target_num_frames
        self.height = height
        self.width = width
        if seed is not None:
            random.seed(seed)

        if not self.dataset_root.exists():
            raise ValueError(f"Dataset root not found: {self.dataset_root}")

        episode_dirs = sorted(
            [d for d in self.dataset_root.iterdir() if d.is_dir() and "episode" in d.name]
        )
        if max_episodes is not None:
            episode_dirs = episode_dirs[:max_episodes]
        if len(episode_dirs) == 0:
            raise ValueError(f"No episodes in {self.dataset_root}")

        # Build list of (episode_dir, start_idx) such that we have at least num_frames from start_idx
        self.samples = []
        for ep_dir in episode_dirs:
            frame_files = sorted([f for f in ep_dir.glob("*.png") if f.stem.isdigit()])
            frame_indices = sorted([int(f.stem) for f in frame_files])
            frame_set = set(frame_indices)
            for start_idx in frame_indices:
                need = [start_idx + i for i in range(num_frames)]
                if all(i in frame_set for i in need):
                    self.samples.append((ep_dir, start_idx))
        if len(self.samples) == 0:
            raise ValueError(
                f"No valid {num_frames}-frame windows in {len(episode_dirs)} episodes. "
                "Ensure episodes have consecutive frame indices."
            )
        print(f"RobotVideoDataset: {len(episode_dirs)} episodes, {len(self.samples)} samples")

    def __len__(self):
        return len(self.samples)

    def _load_frame(self, path: Path) -> np.ndarray:
        from PIL import Image
        img = np.array(Image.open(path).convert("RGB"))
        return img

    def __getitem__(self, idx: int) -> dict:
        ep_dir, start_idx = self.samples[idx]
        frames = []
        for i in range(self.num_frames):
            frame_idx = start_idx + i
            path = ep_dir / f"{frame_idx:06d}.png"
            frames.append(self._load_frame(path))
        # (T, H, W, 3) uint8
        video = np.stack(frames, axis=0).astype(np.float32) / 255.0
        # Resize to model resolution: (T, target_H, target_W, 3)
        import cv2
        video = np.stack(
            [
                cv2.resize(
                    video[t],
                    (self.width, self.height),
                    interpolation=cv2.INTER_LINEAR,
                )
                for t in range(video.shape[0])
            ],
            axis=0,
        )
        # Pad to target_num_frames if needed (repeat last frame)
        if video.shape[0] < self.target_num_frames:
            pad = np.tile(
                video[-1:],
                (self.target_num_frames - video.shape[0], 1, 1, 1),
            )
            video = np.concatenate([video, pad], axis=0)
        elif video.shape[0] > self.target_num_frames:
            video = video[: self.target_num_frames]
        # (T, H, W, 3) -> (T, 3, H, W), scale to [-1, 1]
        video = torch.from_numpy(video).permute(0, 3, 1, 2)
        video = video * 2.0 - 1.0
        return {"jpg": video}


def collate_robot_video(batch: list, cond_aug: float = 0.02, fps_id: int = 6, motion_bucket_id: int = 127):
    """
    Stack per-sample clips into one batch and attach the conditioning entries
    expected by DiffusionEngine / StandardDiffusionLoss.

    Args:
        batch: list of dicts, each holding a clip tensor under "jpg".
        cond_aug: scale of the Gaussian noise added to the conditioning frame.
        fps_id: value repeated once per sample under "fps_id".
        motion_bucket_id: value repeated once per sample under "motion_bucket_id".

    Returns:
        Dict with the stacked clips ("jpg"), the first frame of each clip with and
        without noise, the per-sample scalar tensors, the frame count and an
        all-zero (B, T) "image_only_indicator".
    """
    clips = torch.stack([sample["jpg"] for sample in batch])
    # Unpacking five dimensions also enforces that the stacked tensor is 5-D.
    batch_size, num_frames, _channels, _height, _width = clips.shape

    def per_sample(value, dtype):
        # One copy of `value` for every clip in the batch.
        return torch.tensor([value] * batch_size, dtype=dtype)

    # Conditioning frame = first frame of every clip, shape (B, 1, C, H, W).
    clean_cond = clips[:, :1].clone()
    noise = torch.randn_like(clean_cond)
    noisy_cond = clean_cond + cond_aug * noise

    collated = {"jpg": clips}
    collated["cond_frames_without_noise"] = clean_cond
    collated["cond_frames"] = noisy_cond
    collated["fps_id"] = per_sample(fps_id, torch.long)
    collated["motion_bucket_id"] = per_sample(motion_bucket_id, torch.long)
    collated["cond_aug"] = per_sample(cond_aug, torch.float32)
    collated["num_video_frames"] = num_frames
    collated["image_only_indicator"] = torch.zeros(batch_size, num_frames, dtype=torch.float32)
    return collated
