"""Pre-extract 8-frame clips from self-collected keygrip episodes to a cache dir.

Episodes live under: root/task_name/episode_name/00000.png, 00001.png, ...
We sample random consecutive 8-frame clips per episode and save them as .pt
files shaped (1, 3, T, H, W) in [-1, 1], matching simple_uva expectations.
"""

import argparse
import hashlib
import sys
from pathlib import Path

import numpy as np
import torch
from tqdm import tqdm

# Put the repo root (one level above this script's directory) on sys.path so
# sibling modules are importable when this file is run as a script.
# NOTE(review): nothing in this file uses vidgen_root afterwards — presumably
# downstream imports rely on it; confirm before removing.
vidgen_root = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(vidgen_root))


def _natural_sort_key(p: Path):
    """Sort by stem numerically when possible (00000.png, 00001.png, ...)."""
    s = p.stem
    try:
        return (0, int(s))
    except ValueError:
        return (1, s)


def find_episodes(root: Path, min_frames: int):
    """Find episode directories containing at least ``min_frames`` PNGs.

    Expects the layout ``root/task_name/episode_name/*.png``.

    Args:
        root: Dataset root; resolved before scanning, so returned paths are
            under the resolved root.
        min_frames: Episodes with fewer PNGs than this are skipped.

    Returns:
        List of ``(episode_dir, pngs)`` tuples, where ``pngs`` is the list of
        PNG paths sorted numerically by stem (via ``_natural_sort_key``).
    """
    episodes = []
    root = root.resolve()
    for task_dir in sorted(root.iterdir()):
        if not task_dir.is_dir():
            continue
        for episode_dir in sorted(task_dir.iterdir()):
            if not episode_dir.is_dir():
                continue
            # sorted() accepts any iterable; no need to materialize the glob
            # into a list first.
            pngs = sorted(episode_dir.glob("*.png"), key=_natural_sort_key)
            if len(pngs) >= min_frames:
                episodes.append((episode_dir, pngs))
    return episodes


def main():
    """CLI entry point: sample fixed-length clips per episode and cache them.

    Scans ``--data-root`` for task/episode PNG sequences, samples
    ``--clips-per-episode`` random consecutive windows of ``--num-frames``
    frames from each episode, resizes every frame to ``--size`` x ``--size``,
    rescales to [-1, 1], and saves each clip to ``--cache-dir`` as a
    (1, 3, T, H, W) float tensor in a ``.pt`` file.

    Raises:
        FileNotFoundError: if no episode under --data-root has enough frames.
    """
    parser = argparse.ArgumentParser(description="Pre-extract clips from keygrip self-collected episodes to cache.")
    parser.add_argument(
        "--data-root",
        type=str,
        default="/data/cameron/keygrip/scratch",
        help="Root with task/episode/00000.png, 00001.png, ...",
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        required=True,
        help="Output dir for .pt clip files",
    )
    parser.add_argument("--clips-per-episode", type=int, default=5, help="Random clips to extract per episode")
    parser.add_argument("--num-frames", type=int, default=8)
    parser.add_argument("--size", type=int, default=256)
    parser.add_argument("--max-episodes", type=int, default=None, help="Cap number of episodes (for testing)")
    args = parser.parse_args()

    # Hoisted out of the per-frame loop (it previously re-ran the import
    # machinery for every frame). Kept local to main() so importing this
    # module does not require torchvision.
    from torchvision.io import read_image

    cache_dir = Path(args.cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)
    episodes = find_episodes(Path(args.data_root), min_frames=args.num_frames)
    if not episodes:
        raise FileNotFoundError(
            f"No episodes with >={args.num_frames} PNGs under {args.data_root}. "
            "Expected structure: root/task_name/episode_name/00000.png, 00001.png, ..."
        )
    if args.max_episodes is not None:
        episodes = episodes[: args.max_episodes]
    print(f"Extracting clips from {len(episodes)} episodes -> {cache_dir}")

    for episode_dir, pngs in tqdm(episodes, desc="episodes"):
        # Stable per-episode id: md5 of the resolved path, truncated to 12 hex
        # chars. Used both for the output filename and the RNG seed.
        path_hash = hashlib.md5(str(episode_dir.resolve()).encode()).hexdigest()[:12]
        n = len(pngs)
        # Deterministic randomness per episode. BUGFIX: the previous
        # hash(path_hash) seed was salted by PYTHONHASHSEED and so changed
        # between runs; int(hex_digest, 16) is stable across processes.
        np.random.seed(int(path_hash, 16) % (2**32))
        for clip_idx in range(args.clips_per_episode):
            # Random window start; 0 is the only valid start when the episode
            # is exactly num_frames long (randint's upper bound is exclusive).
            if n == args.num_frames:
                start = 0
            else:
                start = np.random.randint(0, n - args.num_frames + 1)
            clip_paths = pngs[start : start + args.num_frames]
            frames = []
            for png_path in clip_paths:
                img = read_image(str(png_path))  # uint8 (C, H, W)
                if img.shape[0] == 1:
                    # Grayscale: replicate the single channel to RGB.
                    img = img.repeat(3, 1, 1)
                elif img.shape[0] == 4:
                    # RGBA: drop the alpha channel.
                    img = img[:3]
                img = img.float() / 255.0
                img = torch.nn.functional.interpolate(
                    img.unsqueeze(0),
                    size=(args.size, args.size),
                    mode="bilinear",
                    align_corners=False,
                ).squeeze(0)
                frames.append(img)
            clip = torch.stack(frames, dim=1)  # (3, T, H, W)
            clip = clip * 2.0 - 1.0  # [0, 1] -> [-1, 1]
            out = clip.unsqueeze(0)  # (1, 3, T, H, W)
            out_path = cache_dir / f"{path_hash}_{clip_idx:03d}.pt"
            torch.save(out, out_path)

    print(f"Done. Clips in {cache_dir}")


if __name__ == "__main__":
    main()
