"""Build a combined Motion-LoRA dataset from every 'mac teleop'-tagged session.

For each source dir:
  - if rgb_overlay/episodes.json exists: one MP4 per episode (start..end inclusive)
  - else (3 unannotated dirs): one MP4 covering the entire rgb_*.jpg sequence

Skip episodes/clips with < N_FRAMES_MIN frames (Motion-LoRA's loader requires
total >= sample_frames). Output one MP4 per clip into the dataset dir, plus one
start-frame jpg per source session into a sibling _val dir.
"""
from pathlib import Path
import json
import imageio.v2 as imageio
import numpy as np

# Clips with fewer frames than this are skipped by the export loop.
N_FRAMES_MIN = 7  # Motion-LoRA num_frames=7
# Playback rate passed to the encoder for every written MP4.
FPS = 7

# Directory holding one sub-directory per recorded session.
SRC_ROOT = Path("/data/cameron/mac_robot_datasets")

# Session directory names (relative to SRC_ROOT) to include, in processing order.
SOURCES = [
    "umi_fold_towel",
    "dataset_20260501_180125",
    "dataset_20260505_114857",
    "dataset_20260506_124503",
    "dataset_20260506_151912",
    "robot_pick_cup",
    "umi_pick_cup",
    "dataset_20260509_105300",
    "dataset_20260509_170535",
    "dataset_20260510_162906",
    "dataset_20260510_173415",
    "dataset_20260510_173718",
    "dataset_20260510_181313",
    "dataset_20260510_204602",
    "dataset_20260510_225914",
    "dataset_20260510_235505",
    "dataset_20260511_002247",
    "dataset_20260511_133840",
    "dataset_20260511_153242",
    "dataset_20260511_185505",
]

# Outputs live next to this script: DST receives the MP4 clips, VAL the
# per-session start-frame JPEGs.
ROOT = Path(__file__).resolve().parent
DST = ROOT / "mac_teleop_all_20260518"
VAL = ROOT / "mac_teleop_all_20260518_val"

# Create both output directories up front; already-existing ones are fine.
for _out_dir in (DST, VAL):
    _out_dir.mkdir(parents=True, exist_ok=True)

def write_clip(frames_np, out_mp4):
    """Encode a frame sequence to an H.264 MP4 file at ``out_mp4``.

    ``frames_np`` is handed to ``imageio.mimwrite`` unchanged (the callers in
    this script pass a stacked numpy array of decoded JPEG frames), and
    ``out_mp4`` is converted to ``str`` before being used as the output path.
    The clip is written at the module-level ``FPS``.
    """
    encoder_opts = {
        "fps": FPS,
        "codec": "libx264",
        "quality": 8,
        # NOTE(review): per imageio's ffmpeg plugin docs, a macro block size of
        # 1 keeps the frames at their original resolution — confirm.
        "macro_block_size": 1,
        # Extra flags forwarded verbatim to ffmpeg.
        "ffmpeg_params": ["-movflags", "+faststart"],
    }
    imageio.mimwrite(str(out_mp4), frames_np, **encoder_opts)

# Export loop: walk every source session, write one MP4 per usable clip into
# DST and one start-frame JPEG per session into VAL, then print a summary.
n_total = 0  # MP4s written across all sessions
n_skipped_short = 0  # clips dropped for having fewer than N_FRAMES_MIN frames
for ds in SOURCES:
    src = SRC_ROOT / ds
    if not src.is_dir():
        print(f"  [missing] {ds}")
        continue
    ep_json = src / "rgb_overlay" / "episodes.json"
    if ep_json.is_file():
        # Annotated session: one clip per episode listed in episodes.json.
        eps = json.loads(ep_json.read_text())["episodes"]
        n_kept = 0  # clips actually written for this session
        for ep in eps:
            ep_id = ep["id"]
            start, end = int(ep["start"]), int(ep["end"])
            # Episode bounds are inclusive. Frame files absent from disk are
            # left out of the clip (best effort) but reported, because they
            # introduce a temporal gap in the resulting video.
            paths = [src / f"rgb_{i:06d}.jpg" for i in range(start, end + 1)]
            frames = [imageio.imread(str(p)) for p in paths if p.is_file()]
            n_missing = len(paths) - len(frames)
            if n_missing:
                print(f"  [gap] {ds}/{ep_id}: {n_missing} of {len(paths)} frames missing on disk")
            if len(frames) < N_FRAMES_MIN:
                print(f"  [skip-short] {ds}/{ep_id}: only {len(frames)} frames")
                n_skipped_short += 1
                continue
            arr = np.stack(frames)
            write_clip(arr, DST / f"{ds}__{ep_id}.mp4")
            if n_kept == 0:
                # The first kept clip supplies the session's validation start frame.
                imageio.imwrite(str(VAL / f"{ds}.jpg"), frames[0])
            n_kept += 1
            n_total += 1
        # Report the number of clips really written, not an estimate derived
        # from the annotated bounds (which ignores frames missing on disk).
        print(f"  [annotated] {ds}: {len(eps)} eps -> {n_kept} kept")
    else:
        # No episodes.json: treat entire jpg sequence as one episode.
        jpgs = sorted(src.glob("rgb_*.jpg"))
        if len(jpgs) < N_FRAMES_MIN:
            print(f"  [skip-short] {ds} whole-seq: only {len(jpgs)} frames")
            n_skipped_short += 1
            continue
        frames = [imageio.imread(str(p)) for p in jpgs]
        arr = np.stack(frames)
        write_clip(arr, DST / f"{ds}__whole.mp4")
        imageio.imwrite(str(VAL / f"{ds}.jpg"), frames[0])
        n_total += 1
        print(f"  [unannotated] {ds}: whole {len(jpgs)} frames -> 1 mp4")

print(f"\ntotal: {n_total} MP4s, skipped {n_skipped_short} short clips")
print(f"validation start-frames: {VAL}")
