"""Build keyframe-video MP4s from dataset_20260509_105300 for Motion-LoRA.

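Mirror of `Smith300TrajectoryDataset(use_keyframes=True)` sampling logic from
/data/cameron/para/para_mac/data_smith300_para.py (this script additionally
sorts the keyframe indices and skips episodes with no keyframes in range):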

  for each episode:
      kf_frames = [int(kf['frame']) for kf in ep['keyframes']
                   if ep_start <= kf['frame'] <= ep_end]
      if len(kf_frames) < n_window:
          kf_frames = kf_frames + [kf_frames[-1]] * (n_window - len(kf_frames))
      for t in range(len(kf_frames) - n_window + 1):
          sample = kf_frames[t:t+n_window]   # 4 keyframes -> one 4-frame mp4

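Outputs (directories are created next to this script; the paths below assume
it lives in dataset/):
    dataset/smith300_keyframes_w4_20260509/<ep_id>__win<tt>.mp4  (4 frames each; <tt> = window index, zero-padded to 2 digits)
    dataset/smith300_keyframes_w4_20260509_val/<ep_id>__win0.jpg (first kf of first window)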
"""
from pathlib import Path
import json
import imageio.v2 as imageio
import numpy as np

# Clip layout: N_WINDOW keyframes are packed into each MP4.
N_WINDOW = 4
FPS = 7  # cosmetic; SVD pipeline + decord both use frame indices

# Input dataset root, and the output directories that sit beside this script.
SRC = Path("/data/cameron/mac_robot_datasets/dataset_20260509_105300")
ROOT = Path(__file__).resolve().parent
DST = ROOT.joinpath("smith300_keyframes_w4_20260509")      # windowed MP4 clips
VAL = ROOT.joinpath("smith300_keyframes_w4_20260509_val")  # validation start frames
DST.mkdir(parents=True, exist_ok=True)
VAL.mkdir(parents=True, exist_ok=True)

# Episode metadata: the "episodes" entry of rgb_overlay/episodes.json.
eps = json.loads(
    SRC.joinpath("rgb_overlay", "episodes.json").read_text()
)["episodes"]

# Sliding-window export: every run of N_WINDOW consecutive keyframes in an
# episode becomes one MP4 clip; the first frame of each episode's first window
# is additionally saved as a JPEG in VAL.
n_total_windows = 0
for ep in eps:
    ep_id = ep["id"]
    ep_start = int(ep["start"])
    ep_end = int(ep["end"])

    # Keep only keyframes inside the episode's [start, end] range, in frame order.
    kf_frames = sorted(int(kf["frame"]) for kf in ep.get("keyframes", [])
                       if ep_start <= int(kf["frame"]) <= ep_end)
    if not kf_frames:
        print(f"  skip {ep_id}: no keyframes")
        continue
    if len(kf_frames) < N_WINDOW:
        # Too few keyframes for a single window: repeat the last one to fill it.
        pad = N_WINDOW - len(kf_frames)
        kf_frames = kf_frames + [kf_frames[-1]] * pad
        print(f"  {ep_id}: padded last keyframe x{pad}")

    n_windows = len(kf_frames) - N_WINDOW + 1
    print(f"  {ep_id}: {len(kf_frames)} kfs -> {n_windows} windows ({kf_frames})")

    # Consecutive windows share N_WINDOW - 1 frames (and padding repeats the
    # last frame), so decode each distinct keyframe once per episode instead of
    # once per window. This also surfaces a missing frame before any clip of
    # the episode has been written.
    decoded = {}
    for fi in kf_frames:
        if fi in decoded:
            continue
        p = SRC / f"rgb_{fi:06d}.jpg"
        if not p.is_file():
            raise FileNotFoundError(f"missing frame {p}")
        decoded[fi] = imageio.imread(str(p))

    for t in range(n_windows):
        # np.stack copies, so the cached arrays are never modified downstream.
        frames = np.stack([decoded[fi] for fi in kf_frames[t:t + N_WINDOW]])
        out_mp4 = DST / f"{ep_id}__win{t:02d}.mp4"
        # NOTE(review): macro_block_size=1 is presumably there to stop
        # imageio-ffmpeg from resizing frames to a multiple of its default
        # block size -- confirm against the imageio-ffmpeg docs.
        imageio.mimwrite(
            str(out_mp4), frames, fps=FPS,
            codec="libx264", quality=8, macro_block_size=1,
            ffmpeg_params=["-movflags", "+faststart"],
        )
        if t == 0:
            # Save first-window's first keyframe for val
            imageio.imwrite(str(VAL / f"{ep_id}__win0.jpg"), frames[0])
        n_total_windows += 1

# Report the window length from N_WINDOW so the message stays correct if the
# constant is changed.
print(f"total {n_total_windows} {N_WINDOW}-frame MP4s -> {DST}")
print(f"validation start frames -> {VAL}")
