"""Convert raw YAM data dumps into vla_foundry WebDataset (WDS) tar shards.

Input layout (per-sequence directory):
    <input>/
      <seq_id>/
        rgb/<camera>/NNNNNNNNNN.png
        lowdim/NNNNNNNNNN.pkl      # dict with keys: joints, action_joints, language_prompt, ...
        metadata.json
      metadata_shared.json

Output layout (vla_foundry WDS):
    <output>/
      manifest.jsonl                 # one line per shard: {"shard": "shard_XXXXXX", "num_sequences": N}, where N is the number of samples in that shard
      shard_000000.tar
      shard_000001.tar
      ...

Each tar holds several files per training sample, all sharing the same sample_id prefix
(sample_id = "<seq_id>_<anchor frame index, zero-padded to 10 digits>"):
    {sample_id}.{camera}_t0.jpg                # anchor frame, one per camera
    {sample_id}.lowdim.npz                     # arrays keyed by field name, shape [T, D] with T = past + 1 + future
    {sample_id}.metadata.json                  # {"anchor_relative_idx": <int>}
    {sample_id}.language_instructions.json     # {"original": "<prompt>"}

Depth images are skipped.

Usage:
    python scripts/convert_yam_raw_to_wds.py \
        --input /path/to/raw-yam-dump \
        --output /path/to/yam-wds \
        --samples-per-shard 10 \
        --stride 15 \
        --past 0 --future 14
"""

from __future__ import annotations

import argparse
import io
import json
import logging
import pickle
import tarfile
from dataclasses import dataclass
from pathlib import Path

import numpy as np
from PIL import Image


logger = logging.getLogger(__name__)
logging.basicConfig(format="%(asctime)s [%(levelname)s] %(message)s", level=logging.INFO)


# Per-frame .pkl key  ->  array name written into each sample's lowdim.npz.
LOWDIM_FIELD_MAP: dict[str, str] = dict(
    joints="yam__joint_pos",
    action_joints="yam__action_joints",
)

# Camera subdirectories under rgb/ that get converted; any other subdirectory
# (e.g. depth) is never read.
CAMERA_SUBDIRS = (
    "scene_camera",
    "left_wrist_camera",
    "right_wrist_camera",
)


@dataclass
class Sample:
    """One training example: per-camera anchor images plus a cropped lowdim window.

    Instances are produced per anchor frame by the sequence builder and serialized
    field-by-field into a shard by the tar writer.
    """

    # Identifier used as the common prefix of every tar member of this sample.
    sample_id: str
    # Camera name -> JPEG-encoded bytes of the anchor frame.
    images: dict[str, bytes]
    # Field name -> array whose first axis spans the window (past + 1 + future frames).
    lowdim: dict[str, np.ndarray]
    # Language prompt of the source sequence ("" when none was recorded).
    prompt: str


def _load_sequence_lowdim(
    seq_dir: Path,
    *,
    field_map: dict[str, str] | None = None,
) -> tuple[dict[str, np.ndarray], str]:
    """Load all per-frame ``lowdim/*.pkl`` files of one sequence and stack their fields.

    Frames are ordered by sorted filename, so row ``i`` of every returned array comes
    from the ``i``-th pkl in that order.

    Args:
        seq_dir: Sequence directory containing a ``lowdim/`` subdirectory.
        field_map: Mapping ``pkl key -> output key``. Defaults to the module-level
            ``LOWDIM_FIELD_MAP``.

    Returns:
        ``(arrays, prompt)``: ``arrays`` maps each output key to a float32 array of shape
        ``[n_frames, *per_frame_shape]``; ``prompt`` is the first non-None
        ``language_prompt`` encountered, or ``""`` if no frame has one.

    Raises:
        ValueError: If there are no pkl files, or a field's per-frame shape differs
            between frames (reported with the offending file instead of a bare
            ``np.stack`` failure).
        KeyError: If a pkl lacks one of the mapped source keys.
    """
    if field_map is None:
        field_map = LOWDIM_FIELD_MAP

    pkl_paths = sorted(seq_dir.glob("lowdim/*.pkl"))
    if not pkl_paths:
        raise ValueError(f"No lowdim pkls in {seq_dir}")

    stacked: dict[str, list[np.ndarray]] = {out_k: [] for out_k in field_map.values()}
    prompt: str | None = None

    for p in pkl_paths:
        # NOTE(security): pickle.load can execute arbitrary code embedded in the file.
        # Only run this converter on dumps that come from a trusted source.
        with open(p, "rb") as f:
            d = pickle.load(f)
        for src_k, dst_k in field_map.items():
            if src_k not in d:
                raise KeyError(f"Expected key '{src_k}' in {p}; got {list(d.keys())}")
            arr = np.asarray(d[src_k], dtype=np.float32)
            frames = stacked[dst_k]
            # Catch ragged data here, where we still know which file is at fault.
            if frames and arr.shape != frames[0].shape:
                raise ValueError(
                    f"Inconsistent shape for '{src_k}' in {p}: got {arr.shape}, "
                    f"expected {frames[0].shape} (shape of earlier frames)"
                )
            frames.append(arr)
        if prompt is None:
            lp = d.get("language_prompt")
            if lp is not None:
                prompt = str(lp)

    out = {k: np.stack(v, axis=0) for k, v in stacked.items()}
    return out, (prompt or "")


def _encode_png_to_jpeg(png_path: Path, quality: int = 95) -> bytes:
    """Open the image at *png_path*, convert it to RGB, and return it JPEG-encoded.

    ``quality`` is forwarded unchanged to Pillow's JPEG writer.
    """
    encoded = io.BytesIO()
    with Image.open(png_path) as source:
        source.convert("RGB").save(encoded, format="JPEG", quality=quality)
    return encoded.getvalue()


def _build_samples_from_sequence(
    seq_dir: Path,
    seq_id: str,
    *,
    past: int,
    future: int,
    stride: int,
    jpeg_quality: int,
) -> list[Sample]:
    """Turn one sequence directory into its list of training samples.

    Anchors run from ``past`` to ``n_frames - 1 - future`` inclusive, every ``stride``
    frames. For each anchor, every lowdim array is sliced to rows
    ``anchor - past .. anchor + future`` (inclusive), and the anchor frame of every camera
    in ``CAMERA_SUBDIRS`` is re-encoded as JPEG.

    Returns an empty list (after logging a warning) when the sequence is too short to
    hold a single window.

    Raises:
        FileNotFoundError: If a camera directory or an anchor-frame PNG is missing.
    """
    lowdim, prompt = _load_sequence_lowdim(seq_dir)
    n_frames = next(iter(lowdim.values())).shape[0]

    rgb_root = seq_dir / "rgb"
    absent = [cam for cam in CAMERA_SUBDIRS if not (rgb_root / cam).is_dir()]
    if absent:
        raise FileNotFoundError(f"Sequence {seq_id}: missing camera dirs {absent}")

    last_anchor = n_frames - 1 - future
    if last_anchor < past:
        logger.warning(f"Sequence {seq_id}: only {n_frames} frames, cannot fit window past={past} future={future}. Skipping.")
        return []

    def _make_sample(anchor: int) -> Sample:
        # NOTE(review): the PNG name is the lowdim row index zero-padded to 10 digits,
        # which assumes rgb/ frames are numbered from 0 in the same order as the
        # sorted lowdim/*.pkl files -- confirm against the dump format.
        frame_name = f"{anchor:010d}.png"
        jpegs: dict[str, bytes] = {}
        for cam in CAMERA_SUBDIRS:
            png = rgb_root / cam / frame_name
            if not png.is_file():
                raise FileNotFoundError(f"Missing image: {png}")
            jpegs[cam] = _encode_png_to_jpeg(png, quality=jpeg_quality)
        window = slice(anchor - past, anchor + future + 1)
        return Sample(
            sample_id=f"{seq_id}_{anchor:010d}",
            images=jpegs,
            lowdim={name: arr[window] for name, arr in lowdim.items()},
            prompt=prompt,
        )

    return [_make_sample(a) for a in range(past, last_anchor + 1, stride)]


def _add_bytes_member(tar: tarfile.TarFile, name: str, payload: bytes) -> None:
    """Append the in-memory *payload* to *tar* as a regular file called *name*."""
    info = tarfile.TarInfo(name=name)
    info.size = len(payload)
    tar.addfile(info, io.BytesIO(payload))


def _write_shard(tar_path: Path, samples: list[Sample], *, anchor_relative_idx: int = 0) -> None:
    """Write a list of Samples into one tar file in vla_foundry WDS format.

    For every sample, in order, the tar receives:
      * ``{sample_id}.{camera}_t0.jpg`` for each camera in ``sample.images``,
      * ``{sample_id}.lowdim.npz`` (``np.savez`` of ``sample.lowdim``),
      * ``{sample_id}.metadata.json`` = ``{"anchor_relative_idx": anchor_relative_idx}``,
      * ``{sample_id}.language_instructions.json`` = ``{"original": sample.prompt}``.

    Args:
        tar_path: Destination tar file (created, or truncated if it exists).
        samples: Samples to serialize.
        anchor_relative_idx: Index of the anchor frame inside each stored lowdim window.
            Windows are cropped as ``[anchor - past, anchor + future]``, so this must
            equal ``past``; the default 0 is correct only for ``past=0``.
    """
    # Identical for every sample in the shard, so serialize it once.
    meta_bytes = json.dumps({"anchor_relative_idx": anchor_relative_idx}).encode("utf-8")
    with tarfile.open(tar_path, "w") as tar:
        for s in samples:
            for cam, jpeg_bytes in s.images.items():
                _add_bytes_member(tar, f"{s.sample_id}.{cam}_t0.jpg", jpeg_bytes)

            npz_buf = io.BytesIO()
            np.savez(npz_buf, **s.lowdim)
            _add_bytes_member(tar, f"{s.sample_id}.lowdim.npz", npz_buf.getvalue())

            _add_bytes_member(tar, f"{s.sample_id}.metadata.json", meta_bytes)

            lang_bytes = json.dumps({"original": s.prompt}).encode("utf-8")
            _add_bytes_member(tar, f"{s.sample_id}.language_instructions.json", lang_bytes)


def convert(
    input_dir: Path,
    output_dir: Path,
    *,
    samples_per_shard: int,
    stride: int,
    past: int,
    future: int,
    jpeg_quality: int,
) -> None:
    """Convert every numeric sequence directory under *input_dir* into WDS shards.

    Sequences are processed in sorted order. Each full group of ``samples_per_shard``
    samples is written to the next ``shard_XXXXXX.tar`` as soon as it is available, so
    only about one shard's worth of samples plus one sequence is held in memory. The
    last, possibly smaller, shard is written after all sequences, followed by
    ``manifest.jsonl``. If a sequence raises mid-run, shards written so far remain in
    *output_dir* but no new manifest is written.

    Raises:
        ValueError: On ``samples_per_shard < 1``, ``stride < 1``, negative ``past`` or
            ``future``, or when *input_dir* contains no sequence directories.
        NotADirectoryError: If *input_dir* is not a directory.
        RuntimeError: If no sequence yields any sample.
    """
    # Validate before touching the filesystem: samples_per_shard < 1 cannot form
    # shards, stride == 0 makes range() fail mid-run, a negative stride silently yields
    # no anchors, and negative past/future would slice the wrong lowdim rows.
    if samples_per_shard < 1:
        raise ValueError(f"samples_per_shard must be >= 1, got {samples_per_shard}")
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")
    if past < 0 or future < 0:
        raise ValueError(f"past and future must be >= 0, got past={past} future={future}")

    if not input_dir.is_dir():
        raise NotADirectoryError(f"Input not a directory: {input_dir}")
    output_dir.mkdir(parents=True, exist_ok=True)

    # Sequence directories have purely numeric names, e.g. "0000", "0001".
    seq_dirs = sorted(p for p in input_dir.iterdir() if p.is_dir() and p.name.isdigit())
    if not seq_dirs:
        raise ValueError(f"No sequence directories found in {input_dir}")
    logger.info(f"Found {len(seq_dirs)} sequence directories")

    if past != 0:
        # The stored window is [anchor - past, anchor + future], so the anchor sits at
        # index `past`, which is what gets written as anchor_relative_idx. The downstream
        # loader re-crops around that index and must be configured with the same `past`.
        logger.warning(
            f"past={past}: writing anchor_relative_idx={past}; downstream loader config "
            f"must use lowdim_past_timesteps={past}."
        )

    manifest_entries: list[dict] = []
    pending: list[Sample] = []
    total_samples = 0

    def write_next_shard(chunk: list[Sample]) -> None:
        # Shards are numbered consecutively by how many have been written so far.
        shard_name = f"shard_{len(manifest_entries):06d}"
        tar_path = output_dir / f"{shard_name}.tar"
        logger.info(f"  Writing {tar_path.name} ({len(chunk)} samples)")
        _write_shard(tar_path, chunk, anchor_relative_idx=past)
        # The key name "num_sequences" is left unchanged (presumably read downstream),
        # although the value counts samples in the shard.
        manifest_entries.append({"shard": shard_name, "num_sequences": len(chunk)})

    for seq_dir in seq_dirs:
        seq_id = seq_dir.name
        logger.info(f"  Processing sequence {seq_id}")
        seq_samples = _build_samples_from_sequence(
            seq_dir, seq_id, past=past, future=future, stride=stride, jpeg_quality=jpeg_quality,
        )
        logger.info(f"    → {len(seq_samples)} samples")
        total_samples += len(seq_samples)
        pending.extend(seq_samples)
        # Flush every full shard right away so encoded images do not accumulate.
        while len(pending) >= samples_per_shard:
            write_next_shard(pending[:samples_per_shard])
            del pending[:samples_per_shard]

    if pending:
        write_next_shard(pending)

    logger.info(f"Total samples: {total_samples}")
    if total_samples == 0:
        raise RuntimeError("No samples produced")

    manifest_path = output_dir / "manifest.jsonl"
    with open(manifest_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(e) + "\n" for e in manifest_entries)
    logger.info(
        f"Wrote manifest: {manifest_path}  ({len(manifest_entries)} shards, {total_samples} samples)"
    )


def main():
    """Parse the command line and run :func:`convert` with the parsed options."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--input", type=Path, required=True, help="Raw YAM dump directory")
    parser.add_argument("--output", type=Path, required=True, help="Output WDS directory")
    # Integer options, registered in the same order they appear in --help.
    int_options = (
        ("--samples-per-shard", 10, "Samples per tar shard"),
        ("--stride", 15, "Frames between anchors (per sequence)"),
        ("--past", 0, "Past timesteps in lowdim window"),
        ("--future", 14, "Future timesteps in lowdim window"),
        ("--jpeg-quality", 95, "JPEG quality (1-100)"),
    )
    for flag, default, help_text in int_options:
        parser.add_argument(flag, type=int, default=default, help=help_text)
    opts = parser.parse_args()

    convert(
        opts.input,
        opts.output,
        samples_per_shard=opts.samples_per_shard,
        stride=opts.stride,
        past=opts.past,
        future=opts.future,
        jpeg_quality=opts.jpeg_quality,
    )


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when it is imported.
    main()
