"""Pre-extract fixed-length clips (8 frames by default) from videos into a cache dir, so training loads .pt files instead of decoding MP4s."""

import argparse
import hashlib
import multiprocessing as mp
import sys
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm

# Put the directory above this script's folder (vidgen_root) at the front of
# sys.path so the project package `dino_vid_model` can be imported when this
# file is run directly as a script.
vidgen_root = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(vidgen_root))

from dino_vid_model.dataset import find_mp4s

# decord is optional at import time: if it is missing (or fails to load),
# `decord` is set to None and main() refuses to run. The "torch" bridge makes
# decord return torch tensors, which the extraction code relies on
# (.float(), .permute()).
try:
    import decord
    decord.bridge.set_bridge("torch")
except Exception:
    decord = None


def _extract_one(args_tuple):
    """Extract random clips from one video and save each clip as a .pt file.

    This is the worker for both the serial and the multiprocessing paths of
    main(). All arguments are packed into one tuple so the function can be
    passed directly to ``Pool.imap_unordered``.

    Args:
        args_tuple: ``(path_str, cache_dir_str, clips_per_video, num_frames,
            sample_fps, size)``.

    Returns:
        0 if every clip was written. -1 if the video was skipped because it
        could not be opened, a frame batch could not be decoded, or it is
        shorter than one clip. On a mid-video decode failure, clips already
        written for this video are kept.

    Each clip is written to ``<cache_dir>/<md5(resolved path)[:12]>_<ci:03d>.pt``
    as a float32 tensor of shape (1, C, num_frames, size, size). Values are
    mapped to [-1, 1], assuming decord yields uint8 frames (TODO confirm).
    Clip start positions come from an RNG seeded by the path hash, so
    re-running on the same file selects the same clips.
    """
    import decord
    decord.bridge.set_bridge("torch")
    path_str, cache_dir_str, clips_per_video, num_frames, sample_fps, size = args_tuple
    path = Path(path_str)
    cache_dir = Path(cache_dir_str)
    try:
        vr = decord.VideoReader(str(path), num_threads=1)
    except Exception:
        return -1
    total = len(vr)
    fps = vr.get_avg_fps()
    # Source-frame span covered by num_frames samples taken every `stride`
    # frames, where stride ~= source fps / target sample_fps (at least 1).
    need = (num_frames - 1) * max(1, round(fps / sample_fps)) + 1
    if total < need:
        return -1
    path_hash = hashlib.md5(str(path.resolve()).encode()).hexdigest()[:12]
    # Seed from the hex digest itself. Built-in hash() of a str is salted per
    # interpreter (PYTHONHASHSEED), so it would pick different clips on every
    # run. A local Generator also leaves the global NumPy RNG untouched.
    rng = np.random.default_rng(int(path_hash, 16))
    for ci in range(clips_per_video):
        start = int(rng.integers(0, total - need + 1))
        indices = np.linspace(start, start + need - 1, num=num_frames).astype(int)
        try:
            frames = vr.get_batch(indices)
        except Exception:
            # Corrupt/undecodable frames: skip this video. Raising here would
            # propagate through Pool.imap_unordered and abort the whole run.
            return -1
        frames = frames.float() / 255.0
        frames = frames.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W) for interpolate
        frames = torch.nn.functional.interpolate(
            frames, size=(size, size), mode="bilinear", align_corners=False
        )
        frames = frames.permute(1, 0, 2, 3)  # (T, C, size, size) -> (C, T, size, size)
        frames = frames * 2.0 - 1.0
        out = frames.unsqueeze(0).contiguous()  # -> (1, C, T, size, size)
        out_path = cache_dir / f"{path_hash}_{ci:03d}.pt"
        # Write under a temporary name, then rename. An interrupted run then
        # never leaves a truncated .pt file that training would try to load.
        tmp_path = cache_dir / f"{path_hash}_{ci:03d}.pt.tmp"
        torch.save(out, tmp_path)
        tmp_path.replace(out_path)
    return 0


def main():
    """Command-line entry point: cache random clips from every .mp4 under --data-root.

    Every video is processed by _extract_one, in this process when
    --workers <= 1 and in a multiprocessing pool otherwise. Both paths
    therefore select the same clips and skip unreadable or too-short videos
    the same way. The number of skipped videos is reported at the end.

    Raises:
        RuntimeError: if decord is not available.
        FileNotFoundError: if no .mp4 files are found under --data-root.
    """
    p = argparse.ArgumentParser(description="Pre-extract clips to cache for fast training.")
    p.add_argument("--data-root", type=str, default="/data/weiduoyuan/droid_raw/1.0.1")
    p.add_argument("--cache-dir", type=str, required=True, help="Output dir for .pt clip files")
    p.add_argument("--clips-per-video", type=int, default=10, help="Random clips to extract per video")
    p.add_argument("--num-frames", type=int, default=8)
    p.add_argument("--sample-fps", type=float, default=4.0)
    p.add_argument("--size", type=int, default=256)
    p.add_argument("--max-videos", type=int, default=None, help="Cap number of videos (for testing)")
    p.add_argument("--workers", type=int, default=8, help="Parallel workers (multiprocessing)")
    args = p.parse_args()

    if decord is None:
        raise RuntimeError("decord required for precache")

    cache_dir = Path(args.cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)
    videos = find_mp4s(args.data_root)
    if not videos:
        raise FileNotFoundError(f"No .mp4 under {args.data_root}")
    if args.max_videos is not None:
        videos = videos[: args.max_videos]
    print(f"Extracting clips from {len(videos)} videos -> {cache_dir} (workers={args.workers})")

    # One task tuple per video, in the format _extract_one expects. Using the
    # same worker for the serial path keeps the serial and parallel outputs
    # identical; previously the serial path used a separate, unseeded copy of
    # the extraction logic.
    task_tuples = [
        (str(Path(video).resolve()), str(cache_dir), args.clips_per_video, args.num_frames, args.sample_fps, args.size)
        for video in videos
    ]
    skipped = 0
    if args.workers <= 1:
        for status in tqdm(map(_extract_one, task_tuples), total=len(task_tuples), desc="videos"):
            if status != 0:
                skipped += 1
    else:
        # Limit each worker to one torch thread. Otherwise every worker's
        # interpolate uses all cores and the processes oversubscribe the CPU.
        with mp.Pool(args.workers, initializer=torch.set_num_threads, initargs=(1,)) as pool:
            for status in tqdm(
                pool.imap_unordered(_extract_one, task_tuples, chunksize=1),
                total=len(task_tuples),
                desc="videos",
            ):
                if status != 0:
                    skipped += 1
    print(f"Skipped {skipped}/{len(task_tuples)} videos (unreadable or too short)")
    print(f"Done. Clips in {cache_dir}")


# Run the CLI only when executed as a script, not on import. This also
# matters for multiprocessing start methods that re-import the main module
# in worker processes.
if __name__ == "__main__":
    main()
