"""Custom `rd convert` variant that can skip depth and/or exclude cameras.

The slow part of `rd convert` is per-frame depth computation: either the ZED
SDK's NEURAL_LIGHT pass, or a learned stereo network (FFS / TRI Stereo). This
script monkey-patches `raiden.converter._extract_svo2_synchronized` with a
no-depth version that only writes RGB PNGs + `timestamps.npy`. The rest of
raiden's convert pipeline (timestamp alignment, lowdim pkl construction,
metadata, split files) runs unchanged.

Flags:
  --skip-depth          Skip depth entirely for ALL cameras (no `depth/` dir).
  --exclude-cameras     Comma-separated camera names to skip whole-SVO2 (no
                        rgb, no depth). Useful for skipping wrist cameras
                        when you don't need them. Excluded SVO2 files are
                        temporarily renamed to `.svo2.skip` during the run
                        and restored afterwards.
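  --reconvert           Re-convert episodes that are already marked converted.
  --output_dir          Where to write processed episodes (defaults to raiden's
                        own default location when omitted).
  --stereo_method       Depth method forwarded to raiden (zed, ffs, tri_stereo);
                        only relevant when --skip-depth is NOT set.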

Usage (run on the YAM or wherever your data lives):
    python ~/cameron/yam_control/convert_no_depth.py \\
        --task ~/cameron/yam_control/data/raw/cube_in_carton \\
        --skip-depth --exclude-cameras left_wrist_camera,right_wrist_camera
"""
import argparse
import contextlib
import os
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np

import raiden.converter as rcv


def _no_depth_extract(
    svo_paths: List[Path],
    names: List[str],
    rgb_dirs: List[Path],
    depth_dirs: List[Path],  # ignored
    flips: List[bool],
    max_frames: Optional[int] = None,
    sync_threshold_ns: int = 16_666_667,
    stereo_method: str = "zed",
    ffs_scale: float = 1.0,
    ffs_iters: int = 8,
    tri_stereo_variant: str = "c64",
) -> Dict[str, Tuple[np.ndarray, Optional[dict]]]:
    """Drop-in replacement for raiden.converter._extract_svo2_synchronized
    that writes only RGB PNGs + timestamps.npy. No depth at all.

    The parameter list mirrors the function being patched so raiden can call
    this unchanged; ``depth_dirs``, ``stereo_method``, ``ffs_scale``,
    ``ffs_iters`` and ``tri_stereo_variant`` are accepted but ignored.

    For every synchronized slot, each camera's current frame is written to
    ``<rgb_dir>/<frame_idx:010d>.png`` (rotated 180 degrees when its ``flips``
    entry is true), using one frame index shared by all cameras. Before a
    slot is written, any camera whose buffered frame is more than
    ``sync_threshold_ns`` older than the newest buffered frame is advanced
    until every camera is within the threshold. Extraction stops as soon as
    any camera runs out of frames, or after ``max_frames`` slots.

    Args:
        svo_paths: One SVO2 file per camera, parallel to ``names``.
        names: Camera names; used as keys of the returned dict.
        rgb_dirs: Output directory per camera (created if missing).
        depth_dirs: Ignored.
        flips: Per-camera flag to rotate frames by 180 degrees.
        max_frames: Optional cap on the number of synchronized slots.
        sync_threshold_ns: Max allowed lag (ns) behind the newest buffered
            frame (default ~16.7 ms).

    Returns:
        ``{name: (timestamps, camera_info)}`` where ``timestamps`` is the
        int64 array of ``frame.timestamp_ns`` values actually written (also
        saved as ``<rgb_dir>/timestamps.npy``) and ``camera_info`` is the
        result of ``get_camera_info()``, or ``None`` if that call raised.

    Raises:
        RuntimeError: If ``cv2.imwrite`` reports that a PNG write failed.
    """
    from raiden.cameras.zed import ZedCamera

    for d in rgb_dirs:
        d.mkdir(parents=True, exist_ok=True)

    rgb_dir_map = dict(zip(names, rgb_dirs))
    flip_map = dict(zip(names, flips))

    # ExitStack guarantees every camera opened so far (and the progress bar)
    # is closed even if opening a later SVO2, grabbing, or writing raises.
    with contextlib.ExitStack() as stack:
        cams: Dict[str, ZedCamera] = {}
        for name, svo_path in zip(names, svo_paths):
            opened = ZedCamera.from_svo(name, svo_path, compute_sdk_depth=False)
            stack.callback(opened.close)
            cams[name] = opened
        total_frames = {name: cam.get_total_frames() for name, cam in cams.items()}
        print(f"  [no-depth] Frames per camera: { {n: total_frames[n] for n in names} }")

        # Per-camera state
        cam_ts: Dict[str, List[int]] = {n: [] for n in names}
        cam_buf: Dict[str, Optional[Tuple[int, np.ndarray]]] = {n: None for n in names}
        done: Dict[str, bool] = {n: False for n in names}

        def _advance(name: str) -> None:
            """Buffer the next (timestamp_ns, color) frame of ``name``; None once exhausted."""
            if done[name]:
                cam_buf[name] = None
                return
            zed = cams[name]
            if not zed.grab():
                done[name] = True
                cam_buf[name] = None
                return
            frame = zed.get_frame()
            cam_buf[name] = (int(frame.timestamp_ns), frame.color)

        def _exhausted() -> bool:
            """True when there is nothing to sync (no cameras, or one ran out)."""
            return not names or any(cam_buf[n] is None for n in names)

        for name in names:
            _advance(name)

        frame_idx = 0
        t_start = time.time()
        n_expected = min(total_frames.values()) if total_frames else 0
        if max_frames is not None:
            n_expected = min(n_expected, max_frames)
        try:
            from tqdm import tqdm
        except ImportError:
            pbar = None  # progress bar is optional
        else:
            pbar = tqdm(total=n_expected, desc="  extracting (rgb only)")
            # Error-path cleanup; tqdm.close() is safe to call twice.
            stack.callback(pbar.close)

        while max_frames is None or frame_idx < max_frames:
            if _exhausted():
                break

            # Sync: latest-timestamp wins; lagging cams advance until within threshold.
            latest_ts = max(cam_buf[n][0] for n in names)
            advanced = True
            while advanced and not _exhausted():
                advanced = False
                for n in names:
                    if latest_ts - cam_buf[n][0] > sync_threshold_ns:
                        _advance(n)
                        if cam_buf[n] is None:
                            break
                        advanced = True
                        latest_ts = max(latest_ts, cam_buf[n][0])
            if _exhausted():
                break

            # Write rgb for this synchronized slot
            for n in names:
                ts, rgb = cam_buf[n]
                if flip_map[n]:
                    rgb = cv2.rotate(rgb, cv2.ROTATE_180)
                png_path = rgb_dir_map[n] / f"{frame_idx:010d}.png"
                # cv2.imwrite returns False on failure instead of raising;
                # don't let frames go missing silently.
                if not cv2.imwrite(str(png_path), rgb):
                    raise RuntimeError(f"cv2.imwrite failed for {png_path}")
                cam_ts[n].append(ts)
                _advance(n)

            frame_idx += 1
            if pbar is not None:
                pbar.update(1)

        if pbar is not None:
            pbar.close()
        dt = time.time() - t_start
        print(f"  [no-depth] {frame_idx} synchronized frames in {dt:.1f}s "
              f"({frame_idx / max(dt, 1e-6):.1f} fps)")

        # Save timestamps.npy + return camera info dicts
        out: Dict[str, Tuple[np.ndarray, Optional[dict]]] = {}
        for n in names:
            ts_arr = np.asarray(cam_ts[n], dtype=np.int64)
            np.save(rgb_dir_map[n] / "timestamps.npy", ts_arr)
            try:
                info = cams[n].get_camera_info()
            except Exception:  # best-effort: camera info is optional metadata
                info = None
            out[n] = (ts_arr, info)
        return out


@contextlib.contextmanager
def _temporarily_exclude_svos(task_dir: Path, exclude_names: List[str]):
    """Rename `<name>.svo2` → `<name>.svo2.skip` for each excluded camera so
    raiden's globbing doesn't see them. Restored on exit."""
    if not exclude_names:
        yield
        return

    renamed: List[Tuple[Path, Path]] = []
    try:
        for rec_dir in task_dir.iterdir():
            if not rec_dir.is_dir():
                continue
            cams_dir = rec_dir / "cameras"
            if not cams_dir.exists():
                continue
            for name in exclude_names:
                src = cams_dir / f"{name}.svo2"
                if src.exists():
                    dst = cams_dir / f"{name}.svo2.skip"
                    src.rename(dst)
                    renamed.append((src, dst))
        if renamed:
            print(f"  --exclude-cameras: temporarily renamed {len(renamed)} svo2 file(s)")
        yield
    finally:
        for src, dst in renamed:
            if dst.exists():
                dst.rename(src)
        if renamed:
            print(f"  --exclude-cameras: restored {len(renamed)} svo2 file(s)")


def main(argv: Optional[List[str]] = None):
    """Parse CLI arguments and run raiden's ``convert_task`` with the requested tweaks.

    With ``--skip-depth``, ``raiden.converter._extract_svo2_synchronized`` is
    replaced by the RGB-only extractor before conversion starts. With
    ``--exclude-cameras``, the named cameras' SVO2 files are hidden from
    raiden for the duration of the conversion.

    Args:
        argv: Argument list to parse; ``None`` (the default) means ``sys.argv[1:]``.

    Raises:
        SystemExit: On invalid arguments, including a ``--task`` path that is
            not an existing directory (exit code 2, via argparse).
    """
    p = argparse.ArgumentParser()
    p.add_argument("--task", required=True, type=str,
                   help="Path to data/raw/<task_name>/ (the dir containing episode subdirs).")
    p.add_argument("--output_dir", default=None, type=str,
                   help="Where to write processed episodes. Default: <task_parent>/processed/<task_name>")
    p.add_argument("--skip-depth", action="store_true",
                   help="Skip depth processing for all cameras (RGB pngs only).")
    p.add_argument("--exclude-cameras", default="", type=str,
                   help="Comma-separated camera NAMES (not roles) to skip entirely "
                        "(e.g. 'left_wrist_camera,right_wrist_camera').")
    p.add_argument("--reconvert", action="store_true")
    p.add_argument("--stereo_method", default="zed",
                   choices=["zed", "ffs", "tri_stereo"],
                   help="Only relevant when --skip-depth is NOT set.")
    args = p.parse_args(argv)

    # expanduser() covers quoted paths like "~/data/..." that the shell left alone.
    task_dir = Path(args.task).expanduser().resolve()
    if not task_dir.is_dir():
        p.error(f"--task is not an existing directory: {task_dir}")
    output_dir = (str(Path(args.output_dir).expanduser())
                  if args.output_dir is not None else None)
    exclude_names = [s.strip() for s in args.exclude_cameras.split(",") if s.strip()]

    # Patch in our no-depth extract if requested. This works because
    # convert_task resolves the helper through the module's globals at call time.
    if args.skip_depth:
        print("  --skip-depth: patching raiden.converter._extract_svo2_synchronized")
        rcv._extract_svo2_synchronized = _no_depth_extract

    with _temporarily_exclude_svos(task_dir, exclude_names):
        rcv.convert_task(
            task_dir=str(task_dir),
            output_dir=output_dir,
            stereo_method=args.stereo_method,
            reconvert=args.reconvert,
        )


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
