"""Build keyframe-video MP4s from dataset_20260509_170535 for Motion-LoRA.

Same sliding 4-keyframe-window logic as smith300_keyframes_w4_20260509_build.py;
only the source dataset path differs. 18 episodes, 7-10 keyframes each, ~101 windows.
"""
from pathlib import Path
import json
import imageio.v2 as imageio
import numpy as np

# Input dataset root. Episode/keyframe annotations are read from
# rgb_overlay/episodes.json underneath it.
SRC = Path("/data/cameron/mac_robot_datasets/dataset_20260509_170535")

# Outputs are written next to this script: DST receives the window MP4s,
# VAL receives one start-frame JPEG per episode.
ROOT = Path(__file__).resolve().parent
DST = ROOT / "smith300_keyframes_w4_d20260509_170535"
VAL = ROOT / "smith300_keyframes_w4_d20260509_170535_val"
for _out_dir in (DST, VAL):
    _out_dir.mkdir(parents=True, exist_ok=True)

N_WINDOW = 4  # number of keyframes per sliding window (frames per MP4)
FPS = 7       # frame rate handed to the video writer

# List of episode records; each is expected to carry "id", "start", "end"
# and optionally "keyframes".
_episodes_path = SRC / "rgb_overlay" / "episodes.json"
eps = json.loads(_episodes_path.read_text())["episodes"]

# Build one MP4 per sliding window of N_WINDOW consecutive keyframes, for
# every episode, and save the first frame of each episode's first window as a
# validation start image.
n_total_windows = 0
for ep in eps:
    ep_id = ep["id"]
    ep_start = int(ep["start"])
    ep_end = int(ep["end"])

    # Keep only keyframes that fall inside the episode's [start, end] range,
    # in ascending frame order. Each index is parsed once.
    kf_frames = sorted(
        frame
        for frame in (int(kf["frame"]) for kf in ep.get("keyframes", []))
        if ep_start <= frame <= ep_end
    )
    if not kf_frames:
        print(f"  skip {ep_id}: no keyframes")
        continue
    if len(kf_frames) < N_WINDOW:
        # Too few keyframes for a full window: repeat the last one so the
        # episode still yields exactly one N_WINDOW-frame clip.
        pad = N_WINDOW - len(kf_frames)
        kf_frames = kf_frames + [kf_frames[-1]] * pad
        print(f"  {ep_id}: padded last keyframe x{pad}")

    n_windows = len(kf_frames) - N_WINDOW + 1
    print(f"  {ep_id}: {len(kf_frames)} kfs -> {n_windows} windows")

    # Consecutive windows overlap in N_WINDOW - 1 keyframes, so decode every
    # distinct keyframe once per episode instead of once per window. This also
    # surfaces a missing frame before any MP4 of the episode is written.
    decoded = {}
    for fi in kf_frames:
        if fi in decoded:
            continue
        p = SRC / f"rgb_{fi:06d}.jpg"
        if not p.is_file():
            raise FileNotFoundError(f"missing frame {p}")
        decoded[fi] = imageio.imread(str(p))

    for t in range(n_windows):
        win_frames_idx = kf_frames[t:t + N_WINDOW]
        # np.stack copies into a fresh array, so the cached frames are never
        # aliased by the clip that is handed to the writer.
        frames = np.stack([decoded[fi] for fi in win_frames_idx])
        out_mp4 = DST / f"{ep_id}__win{t:02d}.mp4"
        # macro_block_size=1 turns off imageio's automatic resizing to a
        # multiple of the macro block size; "+faststart" is passed through to
        # ffmpeg unchanged.
        imageio.mimwrite(
            str(out_mp4), frames, fps=FPS,
            codec="libx264", quality=8, macro_block_size=1,
            ffmpeg_params=["-movflags", "+faststart"],
        )
        if t == 0:
            # First frame of the episode's first window doubles as the
            # validation start frame.
            imageio.imwrite(str(VAL / f"{ep_id}__win0.jpg"), frames[0])
        n_total_windows += 1

# Report the window size from N_WINDOW so the summary stays correct if the
# constant is ever changed.
print(f"total {n_total_windows} {N_WINDOW}-frame MP4s -> {DST}")
print(f"validation start frames -> {VAL}")
